# MLP Pion Classifier — Summary Statistics

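Standalone MLP trained on track-level summary statistics — by default all four listed below (setting `all_summary_stats = False` restricts it to `track_length` and `dEdX_median`):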
- `chi²/ndof_proton` — chi-squared per degree of freedom under the proton Bethe-Bloch hypothesis
- `track_length` — total track length (cm)
- `track_score` — Pandora track quality score
- `dEdX_median` — median energy loss per unit length (computed from hit-level sequence)

This is the **contemporary method** baseline for comparison with the hit-level CNN and hybrid models.

In [1]:
import os, sys

# Notebooks are usually opened from notebooks/; step up to the project root so the
# relative 'prepared-data/' and 'results/' paths used below resolve correctly.
_in_notebooks_dir = os.path.basename(os.getcwd()) == 'notebooks'
if _in_notebooks_dir:
    os.chdir('..')

# Put the project's utils/ directory first on the import path so that
# general_utils and evaluation_utils can be imported.
_utils = os.path.join(os.getcwd(), 'utils')
if _utils not in sys.path:
    sys.path.insert(0, _utils)

import numpy as np
import pickle
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from sklearn.metrics import roc_curve, auc
from general_utils import purity, efficiency, create_confusion_matrix
from evaluation_utils import (
    plot_training_curves,
    optimise_threshold,
    plot_roc_and_purity_efficiency,
    plot_confusion_matrix,
    save_results,
)

# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
if torch.cuda.is_available():
    _device_name = "cuda"
elif torch.backends.mps.is_available():
    _device_name = "mps"
else:
    _device_name = "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")

# Seed the numpy and torch global RNGs for reproducible runs.
np.random.seed(42)
torch.manual_seed(42)

Using device: mps


## Data Loading

In [6]:
data_size = "50000"  # "50000" or "all"
all_summary_stats = True

# NOTE: pickle.load can execute arbitrary code embedded in the file; only load
# prepared-data/ files produced by this project's own preparation step.
with open(f"prepared-data/train_{data_size}.pkl", "rb") as f:
    train = pickle.load(f)
with open("prepared-data/test.pkl", "rb") as f:
    test = pickle.load(f)

y_train = train["labels"].astype(int)
y_test  = test["labels"].astype(int)

if all_summary_stats:
    summ_train_n = train["summary"]          # (N, 4): chi2, length, score, dEdX_median
    summ_test_n  = test["summary"]
else:
    summ_train_n = train["summary"][:, [1, 3]]   # track_length, dEdX_median
    summ_test_n  = test["summary"][:, [1, 3]]

# Sanity checks: fail fast here rather than deep inside training
# (np.bincount / BatchNorm silently misbehave on bad labels or NaN/inf features).
assert len(summ_train_n) == len(y_train), "train: summary/labels length mismatch"
assert len(summ_test_n) == len(y_test), "test: summary/labels length mismatch"
assert summ_train_n.shape[1] == summ_test_n.shape[1], "train/test feature count mismatch"
assert np.isin(y_train, (0, 1)).all() and np.isin(y_test, (0, 1)).all(), "labels must be binary 0/1"
assert np.isfinite(summ_train_n).all() and np.isfinite(summ_test_n).all(), "non-finite summary features"

num_features = summ_train_n.shape[1]
print(f"Train: {y_train.sum():,} pions / {len(y_train):,} ({100*y_train.mean():.1f}%)")
print(f"Test:  {y_test.sum():,} pions / {len(y_test):,} ({100*y_test.mean():.1f}%)")
print(f"summary: {summ_train_n.shape}")

Train: 13,657 pions / 50,000 (27.3%)
Test:  15,245 pions / 55,815 (27.3%)
summary: (50000, 4)


## Dataset & DataLoader

In [7]:
class SummaryDataset(Dataset):
    """Map-style dataset pairing per-track summary features with binary labels.

    Args:
        summary: 2-D array-like, one row of summary features per track.
        labels: array-like with one label per row of ``summary``. Stored as
            float32 so it can be fed directly to logits-based BCE/focal losses.

    Raises:
        ValueError: if ``summary`` and ``labels`` have different lengths.
    """

    def __init__(self, summary, labels):
        # torch.tensor always copies (unlike the legacy FloatTensor constructor's
        # murkier semantics), so later edits to the source arrays cannot leak in.
        self.summary = torch.tensor(np.asarray(summary), dtype=torch.float32)
        self.labels  = torch.tensor(np.asarray(labels), dtype=torch.float32)
        if len(self.summary) != len(self.labels):
            raise ValueError(
                f"summary has {len(self.summary)} rows but labels has {len(self.labels)}"
            )

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair for track ``idx``."""
        return self.summary[idx], self.labels[idx]


train_dataset = SummaryDataset(summ_train_n, y_train)
test_dataset = SummaryDataset(summ_test_n, y_test)

# Class rebalancing: each sample is weighted by 1 / (size of its class), so every
# class carries the same total weight and is drawn equally often (with replacement).
class_counts = np.bincount(y_train)
inverse_freq = 1.0 / class_counts
sample_weights = inverse_freq[y_train]
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
)

BATCH_SIZE = 256
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Train batches: {len(train_loader)}, Test batches: {len(test_loader)}")

Train batches: 196, Test batches: 219


## Model Definition

In [8]:
class MLPClassifier(nn.Module):
    """Small feed-forward binary classifier over track summary features.

    Each hidden layer is Linear -> BatchNorm1d -> ReLU -> Dropout; a final
    Linear layer emits one raw logit per sample (no sigmoid, so pair it with a
    logits-based loss).

    Args:
        n_features: number of input features per sample.
        dropout: dropout probability applied after every hidden layer.
        hidden_sizes: widths of the hidden layers. The default ``(32, 16)``
            reproduces the original architecture, including its state_dict keys.
    """

    def __init__(self, n_features=4, dropout=0.3, hidden_sizes=(32, 16)):
        super().__init__()
        layers = []
        in_dim = n_features
        for width in hidden_sizes:
            layers += [nn.Linear(in_dim, width), nn.BatchNorm1d(width), nn.ReLU(), nn.Dropout(dropout)]
            in_dim = width
        layers.append(nn.Linear(in_dim, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map a (batch, n_features) input to a (batch,) tensor of logits."""
        # squeeze(1) rather than squeeze() so a batch of one keeps its batch dim.
        return self.net(x).squeeze(1)


class FocalLoss(nn.Module):
    """Binary focal loss computed from raw logits.

    Per sample: ``alpha * (1 - p_t) ** gamma * BCE``, averaged over the batch,
    where ``p_t = exp(-BCE)`` is the probability the model gives the true class.
    ``alpha`` is a single scalar applied to both classes.
    """

    def __init__(self, alpha=1.0, gamma=2.0):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma

    def forward(self, logits, targets):
        ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        p_true = ce.neg().exp()                   # probability assigned to the true class
        modulator = (1 - p_true).pow(self.gamma)  # down-weights well-classified samples
        return (self.alpha * modulator * ce).mean()


model = MLPClassifier(n_features=num_features)
model = model.to(device)
n_params = sum(param.numel() for param in model.parameters())
print(f"\nParameters: {n_params:,}")


Parameters: 801


## Training

In [9]:
# Focal loss on top of the class-balanced sampler; Adam with light L2 regularisation.
criterion = FocalLoss(gamma=2.0)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
# Halve the learning rate after 5 epochs without a drop in validation loss.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
)

N_EPOCHS = 80
EARLY_STOP_PATIENCE = 15

# Early-stopping bookkeeping and per-epoch metric history.
best_val_loss, patience_counter, best_state = float('inf'), 0, None
history = {key: [] for key in ('train_loss', 'val_loss', 'purity', 'efficiency')}


def train_one_epoch():
    """Run one optimisation pass over ``train_loader``.

    Returns:
        float: per-sample mean training loss for the epoch (each batch's mean
        loss weighted by the batch size).
    """
    model.train()
    running_loss = 0
    seen = 0
    for batch_x, batch_y in train_loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_x), batch_y)
        batch_loss.backward()
        optimizer.step()
        batch_size = len(batch_y)
        running_loss += batch_loss.item() * batch_size
        seen += batch_size
    return running_loss / seen


def evaluate(threshold=0.5, loader=None):
    """Score the model on a loader and compute loss, purity and efficiency.

    Args:
        threshold: probability cut; samples with sigmoid(logit) >= threshold are
            predicted as class 1 (pion).
        loader: iterable of (features, labels) batches. Defaults to
            ``test_loader``, looked up at call time.

    Returns:
        (mean_loss, purity, efficiency, probs, labels): ``mean_loss`` is the
        per-sample mean of ``criterion``; ``probs`` and ``labels`` are 1-D numpy
        arrays in loader order.

    Raises:
        ValueError: if the loader yields no samples.

    NOTE(review): the training loop calls this with the default (test) loader for
    early stopping / LR scheduling, and the threshold is later tuned on the same
    set, so reported test metrics are optimistically biased. Pass a held-out
    validation loader via ``loader`` to avoid this.
    """
    if loader is None:
        loader = test_loader
    model.eval()
    total_loss, n = 0.0, 0
    prob_batches, label_batches = [], []
    with torch.no_grad():
        for summ, y in loader:
            summ, y = summ.to(device), y.to(device)
            logits = model(summ)
            # criterion returns a batch mean; re-weight by batch size for a per-sample mean.
            total_loss += criterion(logits, y).item() * len(y)
            n += len(y)
            prob_batches.append(torch.sigmoid(logits).cpu().numpy())
            label_batches.append(y.cpu().numpy().astype(int))
    if n == 0:
        raise ValueError("evaluate() received an empty loader")
    all_probs  = np.concatenate(prob_batches)
    all_labels = np.concatenate(label_batches)
    preds = (all_probs >= threshold).astype(int)
    pur   = purity(preds, all_labels, [1], [1])
    eff   = efficiency(preds, all_labels, [1], [1])
    return total_loss / n, pur, eff, all_probs, all_labels


for epoch in range(N_EPOCHS):
    train_loss = train_one_epoch()
    # NOTE(review): evaluate() defaults to test_loader, so early stopping and the LR
    # schedule are driven by the test set; a held-out validation split would keep
    # the final test metrics unbiased.
    val_loss, pur, eff, _, _ = evaluate()
    scheduler.step(val_loss)

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['purity'].append(pur)
    history['efficiency'].append(eff)

    if (epoch + 1) % 10 == 0 or epoch == 0:
        print(f"Epoch {epoch+1:3d} | Train: {train_loss:.4f} | Val: {val_loss:.4f} | "
              f"Purity: {100*pur:.1f}% | Efficiency: {100*eff:.1f}%")

    # A NaN val_loss compares False here, so it counts towards early stopping.
    if val_loss < best_val_loss:
        best_val_loss    = val_loss
        patience_counter = 0
        best_state = {k: v.clone() for k, v in model.state_dict().items()}
    else:
        patience_counter += 1
        if patience_counter >= EARLY_STOP_PATIENCE:
            print(f"Early stopping at epoch {epoch+1}")
            break

# Without a single finite improvement there is no checkpoint; fail clearly instead
# of letting load_state_dict(None) raise an obscure error.
if best_state is None:
    raise RuntimeError(
        "Validation loss never improved from inf (non-finite losses?); no checkpoint to restore."
    )
model.load_state_dict(best_state)
print(f"\nBest validation loss: {best_val_loss:.4f}")

Epoch   1 | Train: 0.1454 | Val: 0.1281 | Purity: 55.9% | Efficiency: 80.0%
Epoch  10 | Train: 0.1219 | Val: 0.1176 | Purity: 55.9% | Efficiency: 82.3%
Epoch  20 | Train: 0.1175 | Val: 0.1150 | Purity: 56.6% | Efficiency: 82.7%
Epoch  30 | Train: 0.1161 | Val: 0.1139 | Purity: 58.8% | Efficiency: 80.5%
Early stopping at epoch 36

Best validation loss: 0.1134


## Training Curves

In [None]:
# Per-epoch losses and purity/efficiency collected in `history` by the training loop.
plot_training_curves(
    history,
    'MLP',
)

## Threshold Optimisation & Final Evaluation

In [None]:
# Test-set probabilities from the restored best checkpoint.
# NOTE(review): the cut is tuned on the same test set whose metrics are then
# reported, which biases the quoted purity/efficiency optimistically; a held-out
# validation split would give an unbiased operating point.
*_, test_probs, test_labels = evaluate()

best_threshold = optimise_threshold(
    test_probs, test_labels, label='MLP', color='forestgreen',
)
final_preds = np.asarray(test_probs >= best_threshold, dtype=int)

## ROC Curve

In [None]:
# Title reflects the feature set actually used (controlled by all_summary_stats);
# previously it always claimed all four features.
feature_desc = ('chi\u00b2/ndof_p + track length + score + dEdX median'
                if all_summary_stats else 'track length + dEdX median')
plot_roc_and_purity_efficiency([
    {'probs': test_probs, 'labels': test_labels, 'threshold': best_threshold,
     'color': 'forestgreen', 'label': 'MLP'}
], title=f'MLP \u2014 {feature_desc}')

## Confusion Matrix

In [None]:
# Confusion matrix of the thresholded test-set predictions.
plot_confusion_matrix(
    test_labels,
    final_preds,
    best_threshold,
    title='MLP Pion Classification',
)

## Save Results

In [None]:
# Label and file name follow the feature set actually used, so a reduced-feature
# run neither mislabels its results nor overwrites the 4-feature baseline file.
if all_summary_stats:
    feature_desc = "chi\u00b2/ndof_p + track length + score + dEdX median"
    results_path = "results/mlp_summary.pkl"
else:
    feature_desc = "track length + dEdX median"
    results_path = "results/mlp_summary_length_dedx.pkl"

results = save_results(
    test_probs, test_labels, best_threshold,
    model_name=f"MLP ({feature_desc})",
    save_path=results_path,
)