# Notebook 04: Comprehensive Machine Learning Classification

This notebook implements and compares **9+ machine learning models** including:

## Traditional ML
- Random Forest
- XGBoost
- SVM
- MLP

## Deep Learning
- Baseline CNN
- ResNet-1D

## **Advanced Models (Novel Contributions)**
- **Physics-Informed Neural Networks (PINNs)** ⭐
- **Transformer (Waveform Attention)** ⭐
- **Vision Transformer (ViT) for Waveforms** ⭐
- **Wavelet Scattering Networks** ⭐
- **CNN-Transformer Hybrid** ⭐

## Comprehensive Evaluation
- Accuracy, precision, recall, F1-score
- Inference speed
- Data efficiency
- Noise robustness
- Model interpretability

In [None]:
# Setup
import sys
sys.path.append('../')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch.utils.data import DataLoader, random_split
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')

# Local imports
from src.io import WaveformDataset
from src.pulse_analysis import PulseFeatureExtractor
from src.ml import *

# Seed NumPy and PyTorch with one shared value so repeated runs are comparable
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Query CUDA once: the answer drives both the extra GPU seeding and the device name
cuda_available = torch.cuda.is_available()
if cuda_available:
    torch.cuda.manual_seed_all(SEED)
    device = 'cuda'
else:
    device = 'cpu'

print(f"Using device: {device}")
print("✓ Imports successful")

## Part 1: Data Preparation

Load the data produced by Notebook 01 (real CAEN digitizer CSV files or synthetic data)
and prepare it for both traditional ML (extracted features) and deep learning (raw waveforms).

In [None]:
# Read the waveform matrix and its label vector written out by Notebook 01
waveforms, labels = (
    np.load('../data/processed/waveforms.npy'),
    np.load('../data/processed/labels.npy'),
)

# Quick sanity report: array shapes plus the number of samples per class id
class_counts = np.bincount(labels)
print(f"Dataset shape: {waveforms.shape}")
print(f"Labels shape: {labels.shape}")
print(f"Class distribution: {class_counts}")
print("Scintillators: LYSO(0), BGO(1), NaI(2), Plastic(3)")

In [None]:
# Build the hand-crafted feature table used by the traditional ML models
print("Extracting pulse shape features...")

# 250 MHz is the CAEN DT5720D sampling rate
extractor = PulseFeatureExtractor(sampling_rate_MHz=250)

n_waveforms = len(waveforms)
feature_list = []
for position in range(n_waveforms):
    # Progress heartbeat every 500 waveforms
    if position % 500 == 0:
        print(f"  Processing {position}/{n_waveforms}...")

    feature_list.append(extractor.extract_features(waveforms[position]))

# One row per waveform, one column per extracted feature
features_df = pd.DataFrame(feature_list)
# Capture the feature names BEFORE the label column is appended below
feature_columns = features_df.columns.tolist()

print(f"\n✓ Extracted {len(feature_columns)} features")
print(f"Features: {feature_columns}")

# Persist features together with their class labels
features_df['label'] = labels
features_df.to_csv('../data/processed/pulse_features.csv', index=False)
print("✓ Features saved")

In [None]:
# Feature matrix (label column excluded) and target vector for the classical models
X_features = features_df[feature_columns].to_numpy()
y = labels

# Two stratified splits sharing one seed:
#   1) hold out 15% of all samples as the test set
#   2) take 17.6% of the remaining 85% (~15% of the total) as the validation set
# which leaves roughly 70% for training.
split_seed = dict(random_state=42)
X_temp, X_test_feat, y_temp, y_test = train_test_split(
    X_features, y, test_size=0.15, stratify=y, **split_seed
)
X_train_feat, X_val_feat, y_train, y_val = train_test_split(
    X_temp, y_temp, test_size=0.176, stratify=y_temp, **split_seed
)

print(f"Training set: {X_train_feat.shape}")
print(f"Validation set: {X_val_feat.shape}")
print(f"Test set: {X_test_feat.shape}")

## Part 2: Traditional Machine Learning Models

In [None]:
# Train traditional ML models
from src.ml.traditional_ml import TraditionalMLClassifier

# One row per estimator:
# (banner printed before training, key in the results dict, backend name, hyperparameters)
model_plan = [
    ("Training Random Forest...", 'Random Forest', 'random_forest',
     dict(n_estimators=100, max_depth=20)),
    ("\nTraining XGBoost...", 'XGBoost', 'xgboost',
     dict(n_estimators=100, learning_rate=0.1)),
    ("\nTraining SVM...", 'SVM', 'svm',
     dict(C=10, kernel='rbf')),
]

traditional_models = {}
raw_scores = {}  # unscaled score() results, kept so the *_acc names below stay exact
for banner, display_name, backend, hyperparams in model_plan:
    print(banner)
    classifier = TraditionalMLClassifier(backend, **hyperparams)
    classifier.fit(X_train_feat, y_train)
    # Scored on the held-out test features; the value is multiplied by 100 for reporting
    score = classifier.score(X_test_feat, y_test)
    raw_scores[display_name] = score
    traditional_models[display_name] = {'model': classifier, 'accuracy': score * 100}
    print(f"  Accuracy: {score*100:.2f}%")

# Keep the per-model handles available under their individual names
rf, xgb_model, svm = (
    traditional_models[key]['model'] for key in ('Random Forest', 'XGBoost', 'SVM')
)
rf_acc, xgb_acc, svm_acc = (
    raw_scores[key] for key in ('Random Forest', 'XGBoost', 'SVM')
)

print("\n✓ Traditional ML models trained")

## Part 3: Deep Learning Models

### 3.1 Data Preparation for Deep Learning

In [None]:
# Prepare waveform data for PyTorch
#
# Waveforms and labels are split in ONE call so every waveform keeps its own label.
# (Previously the labels were rebuilt with an element-wise np.isin over the waveform
# matrix, and the tensors were then built from the feature split's y_train, whose
# length differs from X_train_wave -- TensorDataset rejects mismatched sizes.)
#
# test_size, random_state and stratify match the first feature split, so the held-out
# rows should be the same rows as X_test_feat / y_test (assumes one feature row per
# waveform, in the same order).
X_train_wave, X_test_wave, y_train_wave, y_test_wave = train_test_split(
    waveforms, labels, test_size=0.15, random_state=42, stratify=labels
)

# Fail early, with a readable message, if samples and labels ever get out of step
assert len(X_train_wave) == len(y_train_wave), "train waveforms/labels length mismatch"
assert len(X_test_wave) == len(y_test_wave), "test waveforms/labels length mismatch"

# Convert to tensors (float inputs, integer class targets)
X_train_tensor = torch.FloatTensor(X_train_wave)
y_train_tensor = torch.LongTensor(y_train_wave)
X_test_tensor = torch.FloatTensor(X_test_wave)
y_test_tensor = torch.LongTensor(y_test_wave)

# Create data loaders
from torch.utils.data import TensorDataset

train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

# Shuffle only the training batches; keep evaluation order deterministic
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# NOTE(review): the training cells below pass test_loader as the validation loader, so
# early stopping and the reported "Best Val Acc" are computed on the test set. Consider
# carving a separate validation split out of the training portion.
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

print("✓ Data loaders created")

### 3.2 Baseline CNN

In [None]:
from src.ml.cnn_models import SimpleCNN
from src.ml.training import ModelTrainer

print("Training Baseline CNN...")

# Baseline network: 1024-sample input, 4 output classes, dropout 0.3
cnn_model = SimpleCNN(input_length=1024, num_classes=4, dropout=0.3)
cnn_trainer = ModelTrainer(cnn_model, device=device, learning_rate=0.001)

# Up to 50 epochs, with early stopping (patience 10) and quiet logging
cnn_fit_options = dict(epochs=50, early_stopping_patience=10, verbose=False)
history_cnn = cnn_trainer.train(train_loader, test_loader, **cnn_fit_options)

print(f"✓ CNN trained - Best Val Acc: {max(history_cnn['val_acc']):.2f}%")

### 3.3 Physics-Informed Neural Network (PINN) ⭐

**Novel contribution**: Incorporates physics constraints into the loss function.

In [None]:
from src.ml.physics_informed import PhysicsInformedCNN

print("Training Physics-Informed CNN...")

# Loss-term weights handed to the model
pinn_loss_weights = dict(
    alpha=0.7,  # Classification loss weight
    beta=0.2,   # Decay time loss weight
    gamma=0.1,  # Energy conservation weight
)
pinn_model = PhysicsInformedCNN(input_length=1024, num_classes=4, **pinn_loss_weights)

pinn_trainer = ModelTrainer(pinn_model, device=device, learning_rate=0.001)

# Same schedule as the baseline CNN, with the trainer's physics-informed flag switched on
pinn_fit_options = dict(
    epochs=50,
    early_stopping_patience=10,
    is_physics_informed=True,
    verbose=False,
)
history_pinn = pinn_trainer.train(train_loader, test_loader, **pinn_fit_options)

print(f"✓ PINN trained - Best Val Acc: {max(history_pinn['val_acc']):.2f}%")
print("  Physics constraints: Decay time, Energy conservation, Rise time")

### 3.4 Transformer Models ⭐

In [None]:
from src.ml.transformer_models import WaveformTransformer, VisionTransformerWaveform

# Hyperparameters common to both attention models
shared_attention_config = dict(waveform_length=1024, d_model=64, nhead=8, num_classes=4)
# Both models use the same training schedule and learning rate
attention_fit_options = dict(epochs=50, verbose=False)
attention_lr = 0.0005

# Standard Transformer
print("Training Waveform Transformer...")
transformer = WaveformTransformer(num_layers=4, **shared_attention_config)
transformer_trainer = ModelTrainer(transformer, device=device, learning_rate=attention_lr)
history_transformer = transformer_trainer.train(
    train_loader, test_loader, **attention_fit_options
)
print(f"✓ Transformer trained - Best Val Acc: {max(history_transformer['val_acc']):.2f}%")

# Vision Transformer (patch size 16)
print("\nTraining Vision Transformer (ViT)...")
vit = VisionTransformerWaveform(patch_size=16, **shared_attention_config)
vit_trainer = ModelTrainer(vit, device=device, learning_rate=attention_lr)
history_vit = vit_trainer.train(train_loader, test_loader, **attention_fit_options)
print(f"✓ ViT trained - Best Val Acc: {max(history_vit['val_acc']):.2f}%")

### 3.5 Hybrid CNN-Transformer ⭐

In [None]:
from src.ml.hybrid_models import CNNTransformerHybrid

print("Training CNN-Transformer Hybrid...")

# Hybrid model: 1024-sample input, 4 classes, model width 64
hybrid = CNNTransformerHybrid(input_length=1024, num_classes=4, d_model=64)
hybrid_trainer = ModelTrainer(hybrid, device=device, learning_rate=0.001)

# 50 epochs, quiet logging
hybrid_fit_options = dict(epochs=50, verbose=False)
history_hybrid = hybrid_trainer.train(train_loader, test_loader, **hybrid_fit_options)

print(f"✓ Hybrid trained - Best Val Acc: {max(history_hybrid['val_acc']):.2f}%")

## Part 4: Comprehensive Model Comparison

In [None]:
from src.ml.evaluation import ModelComparison

# Display name -> trained deep learning model, in the order they were trained above
dl_models = dict(
    [
        ('CNN', cnn_model),
        ('PINN', pinn_model),
        ('Transformer', transformer),
        ('ViT', vit),
        ('CNN-Transformer', hybrid),
    ]
)

print("Evaluating all models...\n")
comparison = ModelComparison(dl_models, device=device)
# Every model is evaluated on the same held-out loader
results_df = comparison.evaluate_all(test_loader)

display(results_df)

# Persist the comparison table next to the other result artifacts
results_df.to_csv('../results/tables/model_comparison.csv')
print("\n✓ Results saved")

In [None]:
# Draw the comparison figure, then export it as a PDF
# (300 dpi, tight bounding box) before showing it inline
export_options = dict(dpi=300, bbox_inches='tight')
fig = comparison.plot_comparison(results_df, title="Deep Learning Model Comparison")
plt.savefig('../results/figures/ml_model_comparison.pdf', **export_options)
plt.show()

## Part 5: Model Interpretability

### 5.1 Physics Validation (PINN)

In [None]:
from src.ml.interpretability import ModelInterpretability

interp = ModelInterpretability(pinn_model, device=device)

# Validate learned physics on the held-out waveforms and their labels
physics_results = interp.validate_physics_learning(pinn_model, X_test_wave, y_test)

print("Physics-Informed Model Validation:")
print("="*60)

# Display names indexed by class id
scint_names = ['LYSO', 'BGO', 'NaI', 'Plastic']
for class_key, metrics in physics_results.items():
    # The class id is the second '_'-separated token of the result key
    idx = int(class_key.split('_')[1])
    print(f"\n{scint_names[idx]}:")
    print(f"  Known τ: {metrics['known_tau']:.1f} ns")
    print(f"  Learned τ: {metrics['learned_tau_mean']:.1f} ± {metrics['learned_tau_std']:.1f} ns")
    print(f"  Relative Error: {metrics['relative_error']:.1f}%")

# Plot the validation results and export the figure as a PDF
fig = interp.plot_physics_validation(physics_results)
plt.savefig('../results/figures/pinn_physics_validation.pdf', dpi=300, bbox_inches='tight')
plt.show()

### 5.2 Saliency Maps (CNN)

In [None]:
# Saliency maps for the baseline CNN: one example waveform per class
interp_cnn = ModelInterpretability(cnn_model, device=device)

for class_idx in range(4):
    # Position of the first test sample labelled with this class
    idx = np.flatnonzero(y_test == class_idx)[0]
    example_wave = X_test_wave[idx]
    waveform = torch.FloatTensor(example_wave)

    saliency = interp_cnn.compute_saliency_map(waveform)

    # The raw waveform (not the tensor) is passed to the plotting helper
    fig = interp_cnn.plot_saliency_map(
        example_wave,
        saliency,
        title=f"Saliency Map - {['LYSO', 'BGO', 'NaI', 'Plastic'][class_idx]}"
    )
    # One PDF per class id
    plt.savefig(f'../results/figures/saliency_class_{class_idx}.pdf', dpi=300, bbox_inches='tight')
    plt.show()

## Part 6: Summary and Recommendations

In [None]:
# Final report. Every model name printed as a "best" pick is derived from results_df,
# so the text cannot disagree with the measured numbers after a re-run.
print("="*80)
print("COMPREHENSIVE ML CLASSIFICATION SUMMARY")
print("="*80)

print("\n🏆 TOP PERFORMING MODELS:\n")
# Rows come back sorted by accuracy, best first; the index holds the model names
top_3 = results_df.nlargest(3, 'Accuracy (%)')
for rank, (model_name, row) in enumerate(top_3.iterrows(), 1):
    print(f"{rank}. {model_name}")
    print(f"   Accuracy: {row['Accuracy (%)']:.2f}%")
    print(f"   Inference: {row['Inference Time (ms)']:.2f} ms")
    print(f"   Size: {row['Model Size (MB)']:.2f} MB\n")

print("\n⚡ FASTEST MODEL (Real-time applications):")
fastest = results_df.nsmallest(1, 'Inference Time (ms)')
print(f"   {fastest.index[0]} - {fastest['Inference Time (ms)'].values[0]:.2f} ms")

print("\n🔬 MOST INTERPRETABLE:")
print("   Physics-Informed CNN (PINN)")
print("   - Incorporates known physics")
# Point at the measured comparison instead of asserting agreement unconditionally
print("   - Learned vs. known decay times: see Section 5.1")

# Data-driven picks (previously hard-coded model names):
#   best overall        = highest accuracy
#   best speed-accuracy = fastest model among the three most accurate
best_overall = top_3.index[0]
best_tradeoff = top_3['Inference Time (ms)'].idxmin()

print("\n💡 RECOMMENDATIONS:")
print(f"   • Best overall: {best_overall} (highest accuracy)")
print("   • Best interpretability: Physics-Informed CNN")
print(f"   • Best speed-accuracy: {best_tradeoff} (fastest of the top 3 by accuracy)")
print("   • Novel research: Transformer models")

print("\n✓ Analysis complete!")
print("="*80)