<a href="https://colab.research.google.com/github/arnavsinghal09/GSoC-QMAML/blob/main/Quark_Gluon_Classification_QMAML_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# 1. Imports

import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import random

In [3]:
# 2. Hyperparameters

DATA_DIR = "/content/drive/MyDrive/quark-gluon-dataset"
FILES = [
    "quark-gluon_train-set_n793900.hdf5",
    "quark-gluon_test-set_n139306.hdf5",
    "quark-gluon_test-set_n10000.hdf5"
]
FILE_PATHS = [f"{DATA_DIR}/{fname}" for fname in FILES]
FILE_LABELS = ["Train", "Test1", "Test2"]

BATCH_SIZE = 128
LEARNING_RATE = 5e-4
EPOCHS = 20
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMG_SHAPE = (3, 125, 125)
N_WAY = 2
K_SHOT = 32
K_QUERY = 64
META_BATCH_SIZE = 8
INNER_STEPS = 3
INNER_LR = 5e-3
NUM_WORKERS = 2

In [4]:
# 3. Dataset and Task Sampler

class JetImageDataset(Dataset):
    def __init__(self, X, y, pt=None, pt_bins=None):
        self.X = X
        self.y = y
        self.pt = pt
        self.pt_bins = pt_bins
        if self.pt is not None and self.pt_bins is not None:
            self.bin_indices = []
            for i in range(len(pt_bins) - 1):
                idx = np.where((pt >= pt_bins[i]) & (pt < pt_bins[i+1]))[0]
                self.bin_indices.append(idx)
        else:
            self.bin_indices = [np.arange(len(y))]
    def __len__(self):
        return len(self.y)
    def __getitem__(self, idx):
        x = self.X[idx]
        if x.shape != IMG_SHAPE:
            x = np.transpose(x, (2, 0, 1))  # (3, 125, 125)
        label = int(self.y[idx])
        return torch.tensor(x, dtype=torch.float32), torch.tensor(label, dtype=torch.long)

def sample_task(dataset, bin_idx, k_shot, k_query):
    idxs = dataset.bin_indices[bin_idx]
    idxs = np.random.permutation(idxs)
    support_idxs = idxs[:k_shot]
    query_idxs = idxs[k_shot:k_shot + k_query]
    X_s, y_s = zip(*[dataset[i] for i in support_idxs])
    X_q, y_q = zip(*[dataset[i] for i in query_idxs])
    return (torch.stack(X_s), torch.tensor(y_s)), (torch.stack(X_q), torch.tensor(y_q))

In [5]:
# 4. Model (Deeper CNN for Jet Images)

class JetCNN(nn.Module):
    def __init__(self, n_classes=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 32, 5, stride=2, padding=2), nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(256 * 8 * 8, 512), nn.ReLU(),
            nn.Linear(512, n_classes)
        )
    def forward(self, x):
        return self.net(x)

In [6]:
# 5. Data Preparation and Meta-Task Definition

with h5py.File(FILE_PATHS[0], "r") as f:
    X = np.array(f["X_jets"][:50000])  # Use more data for better GPU usage
    y = np.array(f["y"][:50000])
    pt = np.array(f["pt"][:50000])

pt_bins = np.percentile(pt, np.linspace(0, 100, 6))
print("pT bins:", pt_bins)

jet_dataset = JetImageDataset(X, y, pt=pt, pt_bins=pt_bins)

indices = np.arange(len(y))
np.random.shuffle(indices)
split = int(0.8 * len(indices))
train_idx, val_idx = indices[:split], indices[split:]
train_loader = DataLoader(
    Subset(jet_dataset, train_idx),
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
    num_workers=NUM_WORKERS
)
val_loader = DataLoader(
    Subset(jet_dataset, val_idx),
    batch_size=BATCH_SIZE,
    pin_memory=True,
    num_workers=NUM_WORKERS
)

pT bins: [ 70.23306274  95.28910828 105.83418427 117.69684753 135.86664734
 308.84353638]


In [7]:
# 6. Training Loop for Classical Baseline (with All Stats)

def compute_metrics(y_true, y_pred):
    y_true = y_true.cpu().numpy()
    y_pred = y_pred.cpu().numpy()
    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred)
    rec = recall_score(y_true, y_pred)
    return acc, f1, prec, rec

def train_baseline(model, train_loader, val_loader, epochs=EPOCHS):
    model = model.to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        y_true_train, y_pred_train = [], []
        for X, y in train_loader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            logits = model(X)
            loss = loss_fn(logits, y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * X.size(0)
            preds = logits.argmax(dim=1)
            y_true_train.append(y)
            y_pred_train.append(preds)
        y_true_train = torch.cat(y_true_train)
        y_pred_train = torch.cat(y_pred_train)
        train_loss /= len(train_loader.dataset)
        train_acc, train_f1, train_prec, train_rec = compute_metrics(y_true_train, y_pred_train)

        # Validation
        model.eval()
        val_loss = 0
        y_true_val, y_pred_val = [], []
        with torch.no_grad():
            for X, y in val_loader:
                X, y = X.to(DEVICE), y.to(DEVICE)
                logits = model(X)
                loss = loss_fn(logits, y)
                val_loss += loss.item() * X.size(0)
                preds = logits.argmax(dim=1)
                y_true_val.append(y)
                y_pred_val.append(preds)
        y_true_val = torch.cat(y_true_val)
        y_pred_val = torch.cat(y_pred_val)
        val_loss /= len(val_loader.dataset)
        val_acc, val_f1, val_prec, val_rec = compute_metrics(y_true_val, y_pred_val)

        print(f"Epoch {epoch+1:2d} | "
              f"Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | F1: {train_f1:.4f} | "
              f"Prec: {train_prec:.4f} | Rec: {train_rec:.4f} || "
              f"Val Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | F1: {val_f1:.4f} | "
              f"Prec: {val_prec:.4f} | Rec: {val_rec:.4f}")
    return model

In [8]:
# 7. MAML Meta-Learning Loop (with All Stats)

def maml_train(model, dataset, pt_bins, meta_batch_size=META_BATCH_SIZE, epochs=EPOCHS):
    model = model.to(DEVICE)
    meta_optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        meta_loss = 0
        all_y_true, all_y_pred = [], []
        for _ in range(meta_batch_size):
            bin_idx = random.randint(0, len(pt_bins)-2)
            (X_s, y_s), (X_q, y_q) = sample_task(dataset, bin_idx, K_SHOT, K_QUERY)
            X_s, y_s, X_q, y_q = X_s.to(DEVICE), y_s.to(DEVICE), X_q.to(DEVICE), y_q.to(DEVICE)
            # Clone model for inner loop
            fast_weights = [p.clone().detach().requires_grad_(True) for p in model.parameters()]
            # Inner loop
            for _ in range(INNER_STEPS):
                logits = model(X_s)
                loss = loss_fn(logits, y_s)
                grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
                fast_weights = [w - INNER_LR * g for w, g in zip(fast_weights, grads)]
            # Outer loop: evaluate on query set
            def forward_with_weights(x, weights):
                x = nn.functional.conv2d(x, weights[0], weights[1], stride=2, padding=2)
                x = nn.functional.relu(x)
                x = nn.functional.conv2d(x, weights[2], weights[3], stride=2, padding=1)
                x = nn.functional.relu(x)
                x = nn.functional.conv2d(x, weights[4], weights[5], stride=2, padding=1)
                x = nn.functional.relu(x)
                x = nn.functional.conv2d(x, weights[6], weights[7], stride=2, padding=1)
                x = nn.functional.relu(x)
                x = x.view(x.size(0), -1)
                x = nn.functional.linear(x, weights[8], weights[9])
                x = nn.functional.relu(x)
                x = nn.functional.linear(x, weights[10], weights[11])
                return x
            logits_q = forward_with_weights(X_q, fast_weights)
            loss_q = loss_fn(logits_q, y_q)
            meta_loss += loss_q
            preds = logits_q.argmax(dim=1)
            all_y_true.append(y_q)
            all_y_pred.append(preds)
        meta_optimizer.zero_grad()
        meta_loss.backward()
        meta_optimizer.step()
        all_y_true = torch.cat(all_y_true)
        all_y_pred = torch.cat(all_y_pred)
        acc, f1, prec, rec = compute_metrics(all_y_true, all_y_pred)
        print(f"Epoch {epoch+1:2d} | Meta Loss: {meta_loss.item()/meta_batch_size:.4f} | "
              f"Acc: {acc:.4f} | F1: {f1:.4f} | Prec: {prec:.4f} | Rec: {rec:.4f}")
    return model

In [9]:
# 8. Train and Evaluate Baseline

baseline_model = JetCNN()
baseline_model = train_baseline(baseline_model, train_loader, val_loader)

Epoch  1 | Train Loss: 0.6935 | Acc: 0.4951 | F1: 0.2211 | Prec: 0.4867 | Rec: 0.1430 || Val Loss: 0.6931 | Acc: 0.5041 | F1: 0.0000 | Prec: 0.0000 | Rec: 0.0000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch  2 | Train Loss: 0.6345 | Acc: 0.6364 | F1: 0.6496 | Prec: 0.6280 | Rec: 0.6729 || Val Loss: 0.5966 | Acc: 0.6897 | F1: 0.6776 | Prec: 0.6989 | Rec: 0.6576
Epoch  3 | Train Loss: 0.5943 | Acc: 0.6928 | F1: 0.7059 | Prec: 0.6784 | Rec: 0.7358 || Val Loss: 0.5942 | Acc: 0.6950 | F1: 0.6962 | Prec: 0.6879 | Rec: 0.7046
Epoch  4 | Train Loss: 0.5887 | Acc: 0.6973 | F1: 0.7073 | Prec: 0.6860 | Rec: 0.7300 || Val Loss: 0.5912 | Acc: 0.6972 | F1: 0.7110 | Prec: 0.6749 | Rec: 0.7512
Epoch  5 | Train Loss: 0.5862 | Acc: 0.6977 | F1: 0.7079 | Prec: 0.6861 | Rec: 0.7311 || Val Loss: 0.5955 | Acc: 0.6940 | F1: 0.7145 | Prec: 0.6649 | Rec: 0.7721
Epoch  6 | Train Loss: 0.5837 | Acc: 0.7007 | F1: 0.7108 | Prec: 0.6888 | Rec: 0.7342 || Val Loss: 0.5898 | Acc: 0.6984 | F1: 0.7097 | Prec: 0.6789 | Rec: 0.7433
Epoch  7 | Train Loss: 0.5814 | Acc: 0.7023 | F1: 0.7125 | Prec: 0.6903 | Rec: 0.7361 || Val Loss: 0.5867 | Acc: 0.6973 | F1: 0.7010 | Prec: 0.6870 | Rec: 0.7157
Epoch  8 | Train Loss: 0.577

In [10]:
# 9. Train and Evaluate MAML

maml_model = JetCNN()
maml_model = maml_train(maml_model, jet_dataset, pt_bins)

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch  1 | Meta Loss: 0.6927 | Acc: 0.5117 | F1: 0.0000 | Prec: 0.0000 | Rec: 0.0000
Epoch  2 | Meta Loss: 0.6983 | Acc: 0.5000 | F1: 0.6667 | Prec: 0.5000 | Rec: 1.0000
Epoch  3 | Meta Loss: 0.7421 | Acc: 0.5215 | F1: 0.6855 | Prec: 0.5215 | Rec: 1.0000
Epoch  4 | Meta Loss: 0.8504 | Acc: 0.4746 | F1: 0.6437 | Prec: 0.4746 | Rec: 1.0000
Epoch  5 | Meta Loss: 0.9558 | Acc: 0.4844 | F1: 0.6526 | Prec: 0.4844 | Rec: 1.0000
Epoch  6 | Meta Loss: 0.7888 | Acc: 0.5039 | F1: 0.6701 | Prec: 0.5039 | Rec: 1.0000
Epoch  7 | Meta Loss: 0.7572 | Acc: 0.4531 | F1: 0.6237 | Prec: 0.4531 | Rec: 1.0000
Epoch  8 | Meta Loss: 0.7116 | Acc: 0.5078 | F1: 0.2174 | Prec: 0.5469 | Rec: 0.1357


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch  9 | Meta Loss: 0.7959 | Acc: 0.4785 | F1: 0.0000 | Prec: 0.0000 | Rec: 0.0000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 10 | Meta Loss: 0.7976 | Acc: 0.5312 | F1: 0.0000 | Prec: 0.0000 | Rec: 0.0000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 11 | Meta Loss: 0.8223 | Acc: 0.4785 | F1: 0.0000 | Prec: 0.0000 | Rec: 0.0000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 12 | Meta Loss: 0.7412 | Acc: 0.5371 | F1: 0.0000 | Prec: 0.0000 | Rec: 0.0000
Epoch 13 | Meta Loss: 0.7148 | Acc: 0.5254 | F1: 0.2085 | Prec: 0.5000 | Rec: 0.1317
Epoch 14 | Meta Loss: 0.7019 | Acc: 0.5371 | F1: 0.4552 | Prec: 0.5156 | Rec: 0.4074
Epoch 15 | Meta Loss: 0.6993 | Acc: 0.5195 | F1: 0.4252 | Prec: 0.4740 | Rec: 0.3856
Epoch 16 | Meta Loss: 0.6981 | Acc: 0.4844 | F1: 0.4027 | Prec: 0.4635 | Rec: 0.3560
Epoch 17 | Meta Loss: 0.6880 | Acc: 0.5645 | F1: 0.5584 | Prec: 0.5508 | Rec: 0.5663
Epoch 18 | Meta Loss: 0.6932 | Acc: 0.5234 | F1: 0.6176 | Prec: 0.5130 | Rec: 0.7756
Epoch 19 | Meta Loss: 0.7414 | Acc: 0.4668 | F1: 0.4699 | Prec: 0.4727 | Rec: 0.4672
Epoch 20 | Meta Loss: 0.7072 | Acc: 0.5078 | F1: 0.6135 | Prec: 0.5208 | Rec: 0.7463


In [11]:
# 10. Few-Shot Evaluation Function

def few_shot_eval(model, dataset, pt_bins, n_tasks=None, k_shot=K_SHOT, k_query=K_QUERY, inner_steps=INNER_STEPS, inner_lr=INNER_LR):
    model.eval()
    all_acc, all_f1, all_prec, all_rec = [], [], [], []
    n_bins = len(pt_bins) - 1 if n_tasks is None else n_tasks
    for bin_idx in range(n_bins):
        (X_s, y_s), (X_q, y_q) = sample_task(dataset, bin_idx, k_shot, k_query)
        X_s, y_s, X_q, y_q = X_s.to(DEVICE), y_s.to(DEVICE), X_q.to(DEVICE), y_q.to(DEVICE)
        # Fast adaptation (inner loop)
        fast_weights = [p.clone().detach().requires_grad_(True) for p in model.parameters()]
        loss_fn = nn.CrossEntropyLoss()
        for _ in range(inner_steps):
            logits = model(X_s)
            loss = loss_fn(logits, y_s)
            grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
            fast_weights = [w - inner_lr * g for w, g in zip(fast_weights, grads)]
        # Evaluate on query set
        def forward_with_weights(x, weights):
            x = nn.functional.conv2d(x, weights[0], weights[1], stride=2, padding=2)
            x = nn.functional.relu(x)
            x = nn.functional.conv2d(x, weights[2], weights[3], stride=2, padding=1)
            x = nn.functional.relu(x)
            x = nn.functional.conv2d(x, weights[4], weights[5], stride=2, padding=1)
            x = nn.functional.relu(x)
            x = nn.functional.conv2d(x, weights[6], weights[7], stride=2, padding=1)
            x = nn.functional.relu(x)
            x = x.view(x.size(0), -1)
            x = nn.functional.linear(x, weights[8], weights[9])
            x = nn.functional.relu(x)
            x = nn.functional.linear(x, weights[10], weights[11])
            return x
        logits_q = forward_with_weights(X_q, fast_weights)
        preds = logits_q.argmax(dim=1)
        acc = accuracy_score(y_q.cpu().numpy(), preds.cpu().numpy())
        f1 = f1_score(y_q.cpu().numpy(), preds.cpu().numpy())
        prec = precision_score(y_q.cpu().numpy(), preds.cpu().numpy())
        rec = recall_score(y_q.cpu().numpy(), preds.cpu().numpy())
        all_acc.append(acc)
        all_f1.append(f1)
        all_prec.append(prec)
        all_rec.append(rec)
        print(f"Bin {bin_idx}: Acc={acc:.4f} | F1={f1:.4f} | Prec={prec:.4f} | Rec={rec:.4f}")
    print(f"\nMean Few-Shot: Acc={np.mean(all_acc):.4f} | F1={np.mean(all_f1):.4f} | Prec={np.mean(all_prec):.4f} | Rec={np.mean(all_rec):.4f}")
    return all_acc, all_f1, all_prec, all_rec

In [12]:
# 11. Few-Shot Evaluation for MAML and Baseline

print("MAML Few-Shot Evaluation:")
few_shot_eval(maml_model, jet_dataset, pt_bins)

print("\nClassical Baseline Few-Shot Evaluation (no adaptation):")
# For baseline, skip adaptation: just forward pass
def baseline_few_shot_eval(model, dataset, pt_bins, n_tasks=None, k_query=K_QUERY):
    model.eval()
    all_acc, all_f1, all_prec, all_rec = [], [], [], []
    n_bins = len(pt_bins) - 1 if n_tasks is None else n_tasks
    for bin_idx in range(n_bins):
        idxs = dataset.bin_indices[bin_idx]
        idxs = np.random.permutation(idxs)
        query_idxs = idxs[:k_query]
        X_q, y_q = zip(*[dataset[i] for i in query_idxs])
        X_q = torch.stack(X_q).to(DEVICE)
        y_q = torch.tensor(y_q).to(DEVICE)
        with torch.no_grad():
            logits_q = model(X_q)
            preds = logits_q.argmax(dim=1)
        acc = accuracy_score(y_q.cpu().numpy(), preds.cpu().numpy())
        f1 = f1_score(y_q.cpu().numpy(), preds.cpu().numpy())
        prec = precision_score(y_q.cpu().numpy(), preds.cpu().numpy())
        rec = recall_score(y_q.cpu().numpy(), preds.cpu().numpy())
        all_acc.append(acc)
        all_f1.append(f1)
        all_prec.append(prec)
        all_rec.append(rec)
        print(f"Bin {bin_idx}: Acc={acc:.4f} | F1={f1:.4f} | Prec={prec:.4f} | Rec={rec:.4f}")
    print(f"\nMean Few-Shot: Acc={np.mean(all_acc):.4f} | F1={np.mean(all_f1):.4f} | Prec={np.mean(all_prec):.4f} | Rec={np.mean(all_rec):.4f}")
    return all_acc, all_f1, all_prec, all_rec

baseline_few_shot_eval(baseline_model, jet_dataset, pt_bins)

MAML Few-Shot Evaluation:
Bin 0: Acc=0.5625 | F1=0.0000 | Prec=0.0000 | Rec=0.0000
Bin 1: Acc=0.6094 | F1=0.7573 | Prec=0.6094 | Rec=1.0000
Bin 2: Acc=0.5000 | F1=0.6667 | Prec=0.5000 | Rec=1.0000
Bin 3: Acc=0.3906 | F1=0.0000 | Prec=0.0000 | Rec=0.0000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Bin 4: Acc=0.4844 | F1=0.6526 | Prec=0.4844 | Rec=1.0000

Mean Few-Shot: Acc=0.5094 | F1=0.4153 | Prec=0.3187 | Rec=0.6000

Classical Baseline Few-Shot Evaluation (no adaptation):
Bin 0: Acc=0.7031 | F1=0.6275 | Prec=0.6154 | Rec=0.6400
Bin 1: Acc=0.7031 | F1=0.7164 | Prec=0.8571 | Rec=0.6154
Bin 2: Acc=0.7656 | F1=0.7887 | Prec=0.7778 | Rec=0.8000
Bin 3: Acc=0.7969 | F1=0.8116 | Prec=0.8750 | Rec=0.7568
Bin 4: Acc=0.7188 | F1=0.7188 | Prec=0.6571 | Rec=0.7931

Mean Few-Shot: Acc=0.7375 | F1=0.7326 | Prec=0.7565 | Rec=0.7210


([0.703125, 0.703125, 0.765625, 0.796875, 0.71875],
 [0.6274509803921569,
  0.7164179104477612,
  0.7887323943661971,
  0.8115942028985508,
  0.71875],
 [0.6153846153846154,
  0.8571428571428571,
  0.7777777777777778,
  0.875,
  0.6571428571428571],
 [0.64, 0.6153846153846154, 0.8, 0.7567567567567568, 0.7931034482758621])