<a href="https://colab.research.google.com/github/arnavsinghal09/GSoC-QMAML/blob/main/Quark_Gluon_Classification_MAML_v1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [9]:
# 1. Imports and Hyperparameters

import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
import random

DATA_DIR = "/content/drive/MyDrive/quark-gluon-dataset"
FILES = [
    "quark-gluon_train-set_n793900.hdf5",
    "quark-gluon_test-set_n139306.hdf5",
    "quark-gluon_test-set_n10000.hdf5"
]
FILE_PATHS = [f"{DATA_DIR}/{fname}" for fname in FILES]
FILE_LABELS = ["Train", "Test1", "Test2"]

# Hyperparameters
BATCH_SIZE = 32
LEARNING_RATE = 1e-3
EPOCHS = 10
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMG_SHAPE = (3, 125, 125)  # For jet images
N_WAY = 2
K_SHOT = 16
K_QUERY = 32
META_BATCH_SIZE = 4
INNER_STEPS = 1
INNER_LR = 1e-2

In [14]:
# 2.1. Jet Image Dataset

class JetImageDataset(Dataset):
    def __init__(self, X, y, pt=None, pt_bins=None):
        self.X = X
        self.y = y
        self.pt = pt
        self.pt_bins = pt_bins
        if self.pt is not None and self.pt_bins is not None:
            self.bin_indices = []
            for i in range(len(pt_bins) - 1):
                idx = np.where((pt >= pt_bins[i]) & (pt < pt_bins[i+1]))[0]
                self.bin_indices.append(idx)
        else:
            self.bin_indices = [np.arange(len(y))]
    def __len__(self):
        return len(self.y)
    def __getitem__(self, idx):
      x = self.X[idx]
      if x.shape != IMG_SHAPE:
          x = np.transpose(x, (2, 0, 1))  # (3, 125, 125)
      label = int(self.y[idx])  # <-- Cast to int here!
      return torch.tensor(x, dtype=torch.float32), torch.tensor(label, dtype=torch.long)

In [15]:
# 2.2. Task Sampler

def sample_task(dataset, bin_idx, k_shot, k_query):
    idxs = dataset.bin_indices[bin_idx]
    idxs = np.random.permutation(idxs)
    support_idxs = idxs[:k_shot]
    query_idxs = idxs[k_shot:k_shot + k_query]
    X_s, y_s = zip(*[dataset[i] for i in support_idxs])
    X_q, y_q = zip(*[dataset[i] for i in query_idxs])
    return (torch.stack(X_s), torch.tensor(y_s)), (torch.stack(X_q), torch.tensor(y_q))

In [16]:
# 3.1. CNN for Jet Images

class JetCNN(nn.Module):
    def __init__(self, n_classes=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 16, 5, stride=2, padding=2), nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 16 * 16, 128), nn.ReLU(),
            nn.Linear(128, n_classes)
        )
    def forward(self, x):
        return self.net(x)

In [17]:
# 4. Classical Supervised Baseline

def train_baseline(model, train_loader, val_loader, epochs=EPOCHS):
    model = model.to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        model.train()
        for X, y in train_loader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            logits = model(X)
            loss = loss_fn(logits, y)
            loss.backward()
            optimizer.step()
        # Validation
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for X, y in val_loader:
                X, y = X.to(DEVICE), y.to(DEVICE)
                preds = model(X).argmax(dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)
        print(f"Epoch {epoch+1}: Val Acc = {correct/total:.4f}")
    return model

In [18]:
# 5. MAML Meta-Learning Loop

def maml_train(model, dataset, pt_bins, meta_batch_size=META_BATCH_SIZE, epochs=EPOCHS):
    model = model.to(DEVICE)
    meta_optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        meta_loss = 0
        for _ in range(meta_batch_size):
            # Sample a meta-task (random pt bin)
            bin_idx = random.randint(0, len(pt_bins)-2)
            (X_s, y_s), (X_q, y_q) = sample_task(dataset, bin_idx, K_SHOT, K_QUERY)
            X_s, y_s, X_q, y_q = X_s.to(DEVICE), y_s.to(DEVICE), X_q.to(DEVICE), y_q.to(DEVICE)
            # Clone model for inner loop
            fast_weights = [p.clone().detach().requires_grad_(True) for p in model.parameters()]
            # Inner loop
            for _ in range(INNER_STEPS):
                logits = model(X_s)
                loss = loss_fn(logits, y_s)
                grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
                fast_weights = [w - INNER_LR * g for w, g in zip(fast_weights, grads)]
            # Outer loop: evaluate on query set
            def forward_with_weights(x, weights):
                x = nn.functional.conv2d(x, weights[0], weights[1], stride=2, padding=2)
                x = nn.functional.relu(x)
                x = nn.functional.conv2d(x, weights[2], weights[3], stride=2, padding=1)
                x = nn.functional.relu(x)
                x = nn.functional.conv2d(x, weights[4], weights[5], stride=2, padding=1)
                x = nn.functional.relu(x)
                x = x.view(x.size(0), -1)
                x = nn.functional.linear(x, weights[6], weights[7])
                x = nn.functional.relu(x)
                x = nn.functional.linear(x, weights[8], weights[9])
                return x
            logits_q = forward_with_weights(X_q, fast_weights)
            loss_q = loss_fn(logits_q, y_q)
            meta_loss += loss_q
        meta_optimizer.zero_grad()
        meta_loss.backward()
        meta_optimizer.step()
        print(f"Epoch {epoch+1}: Meta Loss = {meta_loss.item()/meta_batch_size:.4f}")
    return model

In [19]:
# 6. Data Preparation and Meta-Task Definition

# Load data (example for train file)
with h5py.File(FILE_PATHS[0], "r") as f:
    X = np.array(f["X_jets"][:10000])  # Use a subset for speed
    y = np.array(f["y"][:10000])
    pt = np.array(f["pt"][:10000])

# Define pt bins for meta-tasks (e.g., 5 bins)
pt_bins = np.percentile(pt, np.linspace(0, 100, 6))
print("pT bins:", pt_bins)

# Prepare dataset
jet_dataset = JetImageDataset(X, y, pt=pt, pt_bins=pt_bins)

# For baseline: use all data, random split
indices = np.arange(len(y))
np.random.shuffle(indices)
split = int(0.8 * len(indices))
train_idx, val_idx = indices[:split], indices[split:]
train_loader = DataLoader(Subset(jet_dataset, train_idx), batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(Subset(jet_dataset, val_idx), batch_size=BATCH_SIZE)

pT bins: [ 70.55924225  95.12071228 106.02068024 117.48006592 135.62369995
 302.24349976]


In [20]:
# 7. Train and Evaluate Baseline

baseline_model = JetCNN()
baseline_model = train_baseline(baseline_model, train_loader, val_loader)

Epoch 1: Val Acc = 0.6350
Epoch 2: Val Acc = 0.6760
Epoch 3: Val Acc = 0.6770
Epoch 4: Val Acc = 0.6915
Epoch 5: Val Acc = 0.6875
Epoch 6: Val Acc = 0.6910
Epoch 7: Val Acc = 0.6805
Epoch 8: Val Acc = 0.6900
Epoch 9: Val Acc = 0.6880
Epoch 10: Val Acc = 0.6850


In [21]:
# 8. Train and Evaluate MAML

maml_model = JetCNN()
maml_model = maml_train(maml_model, jet_dataset, pt_bins)

Epoch 1: Meta Loss = 0.6928
Epoch 2: Meta Loss = 0.7013
Epoch 3: Meta Loss = 0.7049
Epoch 4: Meta Loss = 0.8919
Epoch 5: Meta Loss = 0.8584
Epoch 6: Meta Loss = 1.0428
Epoch 7: Meta Loss = 1.1687
Epoch 8: Meta Loss = 1.1612
Epoch 9: Meta Loss = 0.9326
Epoch 10: Meta Loss = 0.7353


In [22]:
# 9. Few-Shot Evaluation on Held-Out Jet Conditions

def few_shot_eval(model, dataset, pt_bins, n_tasks=5):
    model.eval()
    accs = []
    for bin_idx in range(len(pt_bins)-1):
        (X_s, y_s), (X_q, y_q) = sample_task(dataset, bin_idx, K_SHOT, K_QUERY)
        X_s, y_s, X_q, y_q = X_s.to(DEVICE), y_s.to(DEVICE), X_q.to(DEVICE), y_q.to(DEVICE)
        # Fast adaptation
        fast_weights = [p.clone().detach().requires_grad_(True) for p in model.parameters()]
        loss_fn = nn.CrossEntropyLoss()
        for _ in range(INNER_STEPS):
            logits = model(X_s)
            loss = loss_fn(logits, y_s)
            grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
            fast_weights = [w - INNER_LR * g for w, g in zip(fast_weights, grads)]
        # Evaluate on query set
        def forward_with_weights(x, weights):
            x = nn.functional.conv2d(x, weights[0], weights[1], stride=2, padding=2)
            x = nn.functional.relu(x)
            x = nn.functional.conv2d(x, weights[2], weights[3], stride=2, padding=1)
            x = nn.functional.relu(x)
            x = nn.functional.conv2d(x, weights[4], weights[5], stride=2, padding=1)
            x = nn.functional.relu(x)
            x = x.view(x.size(0), -1)
            x = nn.functional.linear(x, weights[6], weights[7])
            x = nn.functional.relu(x)
            x = nn.functional.linear(x, weights[8], weights[9])
            return x
        logits_q = forward_with_weights(X_q, fast_weights)
        preds = logits_q.argmax(dim=1)
        acc = (preds == y_q).float().mean().item()
        accs.append(acc)
        print(f"Bin {bin_idx}: Few-shot accuracy = {acc:.4f}")
    print(f"Mean few-shot accuracy: {np.mean(accs):.4f}")
    return accs

# Evaluate MAML
few_shot_eval(maml_model, jet_dataset, pt_bins)
# Evaluate Baseline (no adaptation, just forward pass)
few_shot_eval(baseline_model, jet_dataset, pt_bins)

Bin 0: Few-shot accuracy = 0.6250
Bin 1: Few-shot accuracy = 0.5000
Bin 2: Few-shot accuracy = 0.3750
Bin 3: Few-shot accuracy = 0.5000
Bin 4: Few-shot accuracy = 0.3438
Mean few-shot accuracy: 0.4688
Bin 0: Few-shot accuracy = 0.5938
Bin 1: Few-shot accuracy = 0.6562
Bin 2: Few-shot accuracy = 0.7812
Bin 3: Few-shot accuracy = 0.6250
Bin 4: Few-shot accuracy = 0.7188
Mean few-shot accuracy: 0.6750


[0.59375, 0.65625, 0.78125, 0.625, 0.71875]