<a href="https://colab.research.google.com/github/arnavsinghal09/GSoC-QMAML/blob/main/Quark_Gluon_Classification_MAML_v3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# 1. Imports

import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import random

In [3]:
# 2. Hyperparameters

DATA_DIR = "/content/drive/MyDrive/quark-gluon-dataset"
FILES = [
    "quark-gluon_train-set_n793900.hdf5",
    "quark-gluon_test-set_n139306.hdf5",
    "quark-gluon_test-set_n10000.hdf5"
]
FILE_PATHS = [f"{DATA_DIR}/{fname}" for fname in FILES]
FILE_LABELS = ["Train", "Test1", "Test2"]

BATCH_SIZE = 128
LEARNING_RATE = 2e-4
EPOCHS = 30
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMG_SHAPE = (3, 125, 125)
N_WAY = 2
K_SHOT = 64
K_QUERY = 128
META_BATCH_SIZE = 8
INNER_STEPS = 5
INNER_LR = 1e-2
NUM_WORKERS = 4

In [4]:
# 3. Dataset and Task Sampler

class JetImageDataset(Dataset):
    def __init__(self, X, y, pt=None, pt_bins=None):
        self.X = X
        self.y = y
        self.pt = pt
        self.pt_bins = pt_bins
        if self.pt is not None and self.pt_bins is not None:
            self.bin_indices = []
            for i in range(len(pt_bins) - 1):
                idx = np.where((pt >= pt_bins[i]) & (pt < pt_bins[i+1]))[0]
                self.bin_indices.append(idx)
        else:
            self.bin_indices = [np.arange(len(y))]
    def __len__(self):
        return len(self.y)
    def __getitem__(self, idx):
        x = self.X[idx]
        if x.shape != IMG_SHAPE:
            x = np.transpose(x, (2, 0, 1))  # (3, 125, 125)
        label = int(self.y[idx])
        return torch.tensor(x, dtype=torch.float32), torch.tensor(label, dtype=torch.long)

def sample_task(dataset, bin_idx, k_shot, k_query):
    idxs = dataset.bin_indices[bin_idx]
    idxs = np.random.permutation(idxs)
    support_idxs = idxs[:k_shot]
    query_idxs = idxs[k_shot:k_shot + k_query]
    X_s, y_s = zip(*[dataset[i] for i in support_idxs])
    X_q, y_q = zip(*[dataset[i] for i in query_idxs])
    return (torch.stack(X_s), torch.tensor(y_s)), (torch.stack(X_q), torch.tensor(y_q))

In [5]:
# 4. CNN Module

class JetCNN(nn.Module):
    def __init__(self, n_classes=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, 5, stride=2, padding=2), nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(256, 512, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(512, 512, 3, stride=2, padding=1), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(512 * 4 * 4, 1024), nn.ReLU(),
            nn.Linear(1024, 256), nn.ReLU(),
            nn.Linear(256, n_classes)
        )
    def forward(self, x):
        return self.net(x)

In [6]:
# 5. Data Preparation and Meta-Task Definition

with h5py.File(FILE_PATHS[0], "r") as f:
    X = np.array(f["X_jets"][:100000])  # Use more data for better GPU usage
    y = np.array(f["y"][:100000])
    pt = np.array(f["pt"][:100000])

pt_bins = np.percentile(pt, np.linspace(0, 100, 6))
print("pT bins:", pt_bins)

jet_dataset = JetImageDataset(X, y, pt=pt, pt_bins=pt_bins)

indices = np.arange(len(y))
np.random.shuffle(indices)
split = int(0.8 * len(indices))
train_idx, val_idx = indices[:split], indices[split:]
train_loader = DataLoader(
    Subset(jet_dataset, train_idx),
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
    num_workers=NUM_WORKERS
)
val_loader = DataLoader(
    Subset(jet_dataset, val_idx),
    batch_size=BATCH_SIZE,
    pin_memory=True,
    num_workers=NUM_WORKERS
)

pT bins: [ 70.23306274  95.22527161 105.78971405 117.68426208 135.88970642
 323.42160034]


In [7]:
# 6. Training Loop for Classical Baseline (with All Stats)

def compute_metrics(y_true, y_pred):
    y_true = y_true.cpu().numpy()
    y_pred = y_pred.cpu().numpy()
    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred)
    rec = recall_score(y_true, y_pred)
    return acc, f1, prec, rec

def train_baseline(model, train_loader, val_loader, epochs=EPOCHS):
    model = model.to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        y_true_train, y_pred_train = [], []
        for X, y in train_loader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            logits = model(X)
            loss = loss_fn(logits, y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * X.size(0)
            preds = logits.argmax(dim=1)
            y_true_train.append(y)
            y_pred_train.append(preds)
        y_true_train = torch.cat(y_true_train)
        y_pred_train = torch.cat(y_pred_train)
        train_loss /= len(train_loader.dataset)
        train_acc, train_f1, train_prec, train_rec = compute_metrics(y_true_train, y_pred_train)

        # Validation
        model.eval()
        val_loss = 0
        y_true_val, y_pred_val = [], []
        with torch.no_grad():
            for X, y in val_loader:
                X, y = X.to(DEVICE), y.to(DEVICE)
                logits = model(X)
                loss = loss_fn(logits, y)
                val_loss += loss.item() * X.size(0)
                preds = logits.argmax(dim=1)
                y_true_val.append(y)
                y_pred_val.append(preds)
        y_true_val = torch.cat(y_true_val)
        y_pred_val = torch.cat(y_pred_val)
        val_loss /= len(val_loader.dataset)
        val_acc, val_f1, val_prec, val_rec = compute_metrics(y_true_val, y_pred_val)

        print(f"Epoch {epoch+1:2d} | "
              f"Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | F1: {train_f1:.4f} | "
              f"Prec: {train_prec:.4f} | Rec: {train_rec:.4f} || "
              f"Val Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | F1: {val_f1:.4f} | "
              f"Prec: {val_prec:.4f} | Rec: {val_rec:.4f}")
    return model

In [8]:
# 7. MAML Meta-Learning Loop (with All Stats)

def maml_train(model, dataset, pt_bins, meta_batch_size=META_BATCH_SIZE, epochs=EPOCHS):
    model = model.to(DEVICE)
    meta_optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        meta_loss = 0
        all_y_true, all_y_pred = [], []
        for _ in range(meta_batch_size):
            bin_idx = random.randint(0, len(pt_bins)-2)
            (X_s, y_s), (X_q, y_q) = sample_task(dataset, bin_idx, K_SHOT, K_QUERY)
            X_s, y_s, X_q, y_q = X_s.to(DEVICE), y_s.to(DEVICE), X_q.to(DEVICE), y_q.to(DEVICE)
            # Clone model for inner loop
            fast_weights = [p.clone().detach().requires_grad_(True) for p in model.parameters()]
            # Inner loop
            for _ in range(INNER_STEPS):
                logits = model(X_s)
                loss = loss_fn(logits, y_s)
                grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
                fast_weights = [w - INNER_LR * g for w, g in zip(fast_weights, grads)]
            # Outer loop: evaluate on query set
            def forward_with_weights(x, weights):
              # Conv1
              x = nn.functional.conv2d(x, weights[0], weights[1], stride=2, padding=2)
              x = nn.functional.relu(x)
              # Conv2
              x = nn.functional.conv2d(x, weights[2], weights[3], stride=2, padding=1)
              x = nn.functional.relu(x)
              # Conv3
              x = nn.functional.conv2d(x, weights[4], weights[5], stride=2, padding=1)
              x = nn.functional.relu(x)
              # Conv4
              x = nn.functional.conv2d(x, weights[6], weights[7], stride=2, padding=1)
              x = nn.functional.relu(x)
              # Conv5
              x = nn.functional.conv2d(x, weights[8], weights[9], stride=2, padding=1)
              x = nn.functional.relu(x)
              # Flatten
              x = x.view(x.size(0), -1)
              # FC1
              x = nn.functional.linear(x, weights[10], weights[11])
              x = nn.functional.relu(x)
              # FC2
              x = nn.functional.linear(x, weights[12], weights[13])
              x = nn.functional.relu(x)
              # FC3
              x = nn.functional.linear(x, weights[14], weights[15])
              return x
            logits_q = forward_with_weights(X_q, fast_weights)
            loss_q = loss_fn(logits_q, y_q)
            meta_loss += loss_q
            preds = logits_q.argmax(dim=1)
            all_y_true.append(y_q)
            all_y_pred.append(preds)
        meta_optimizer.zero_grad()
        meta_loss.backward()
        meta_optimizer.step()
        all_y_true = torch.cat(all_y_true)
        all_y_pred = torch.cat(all_y_pred)
        acc, f1, prec, rec = compute_metrics(all_y_true, all_y_pred)
        print(f"Epoch {epoch+1:2d} | Meta Loss: {meta_loss.item()/meta_batch_size:.4f} | "
              f"Acc: {acc:.4f} | F1: {f1:.4f} | Prec: {prec:.4f} | Rec: {rec:.4f}")
    return model

In [9]:
# 8. Train and Evaluate Baseline

baseline_model = JetCNN()
baseline_model = train_baseline(baseline_model, train_loader, val_loader)

Epoch  1 | Train Loss: 0.6932 | Acc: 0.4986 | F1: 0.4677 | Prec: 0.4988 | Rec: 0.4402 || Val Loss: 0.6931 | Acc: 0.5034 | F1: 0.0000 | Prec: 0.0000 | Rec: 0.0000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch  2 | Train Loss: 0.6932 | Acc: 0.4990 | F1: 0.4445 | Prec: 0.4992 | Rec: 0.4006 || Val Loss: 0.6931 | Acc: 0.5034 | F1: 0.0000 | Prec: 0.0000 | Rec: 0.0000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch  3 | Train Loss: 0.6932 | Acc: 0.4993 | F1: 0.5368 | Prec: 0.4997 | Rec: 0.5799 || Val Loss: 0.6931 | Acc: 0.5034 | F1: 0.0000 | Prec: 0.0000 | Rec: 0.0000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch  4 | Train Loss: 0.6931 | Acc: 0.5008 | F1: 0.4562 | Prec: 0.5014 | Rec: 0.4184 || Val Loss: 0.6933 | Acc: 0.4965 | F1: 0.6635 | Prec: 0.4965 | Rec: 0.9996
Epoch  5 | Train Loss: 0.6931 | Acc: 0.4999 | F1: 0.5991 | Prec: 0.5002 | Rec: 0.7468 || Val Loss: 0.6937 | Acc: 0.5034 | F1: 0.0006 | Prec: 0.5000 | Rec: 0.0003
Epoch  6 | Train Loss: 0.6931 | Acc: 0.4996 | F1: 0.3457 | Prec: 0.5000 | Rec: 0.2641 || Val Loss: 0.6935 | Acc: 0.4964 | F1: 0.6634 | Prec: 0.4965 | Rec: 0.9995
Epoch  7 | Train Loss: 0.6777 | Acc: 0.5566 | F1: 0.5879 | Prec: 0.5495 | Rec: 0.6321 || Val Loss: 0.6000 | Acc: 0.6896 | F1: 0.6899 | Prec: 0.6845 | Rec: 0.6954
Epoch  8 | Train Loss: 0.5947 | Acc: 0.6906 | F1: 0.6974 | Prec: 0.6829 | Rec: 0.7126 || Val Loss: 0.5873 | Acc: 0.6991 | F1: 0.7068 | Prec: 0.6847 | Rec: 0.7305
Epoch  9 | Train Loss: 0.5878 | Acc: 0.6956 | F1: 0.7043 | Prec: 0.6854 | Rec: 0.7243 || Val Loss: 0.5865 | Acc: 0.6995 | F1: 0.7132 | Prec: 0.6779 | Rec: 0.7524
Epoch 10 | Train Loss: 0.584

In [10]:
# 9. Train and Evaluate MAML

maml_model = JetCNN()
maml_model = maml_train(maml_model, jet_dataset, pt_bins)

Epoch  1 | Meta Loss: 0.6951 | Acc: 0.4639 | F1: 0.6338 | Prec: 0.4639 | Rec: 1.0000
Epoch  2 | Meta Loss: 0.6909 | Acc: 0.5371 | F1: 0.6989 | Prec: 0.5371 | Rec: 1.0000
Epoch  3 | Meta Loss: 0.6909 | Acc: 0.5332 | F1: 0.6955 | Prec: 0.5332 | Rec: 1.0000
Epoch  4 | Meta Loss: 0.6932 | Acc: 0.5059 | F1: 0.6719 | Prec: 0.5059 | Rec: 1.0000
Epoch  5 | Meta Loss: 0.6969 | Acc: 0.4727 | F1: 0.6419 | Prec: 0.4727 | Rec: 1.0000
Epoch  6 | Meta Loss: 0.6967 | Acc: 0.4834 | F1: 0.6517 | Prec: 0.4834 | Rec: 1.0000
Epoch  7 | Meta Loss: 0.6892 | Acc: 0.5430 | F1: 0.7038 | Prec: 0.5430 | Rec: 1.0000
Epoch  8 | Meta Loss: 0.7045 | Acc: 0.4727 | F1: 0.6419 | Prec: 0.4727 | Rec: 1.0000
Epoch  9 | Meta Loss: 0.7171 | Acc: 0.4453 | F1: 0.6162 | Prec: 0.4453 | Rec: 1.0000
Epoch 10 | Meta Loss: 0.7156 | Acc: 0.4922 | F1: 0.6597 | Prec: 0.4922 | Rec: 1.0000
Epoch 11 | Meta Loss: 0.7313 | Acc: 0.5039 | F1: 0.6701 | Prec: 0.5039 | Rec: 1.0000
Epoch 12 | Meta Loss: 0.7595 | Acc: 0.4795 | F1: 0.6482 | Prec: 0

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 16 | Meta Loss: 0.6951 | Acc: 0.5166 | F1: 0.0000 | Prec: 0.0000 | Rec: 0.0000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 17 | Meta Loss: 0.6889 | Acc: 0.5527 | F1: 0.0000 | Prec: 0.0000 | Rec: 0.0000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 18 | Meta Loss: 0.7111 | Acc: 0.4893 | F1: 0.0000 | Prec: 0.0000 | Rec: 0.0000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 19 | Meta Loss: 0.7042 | Acc: 0.5068 | F1: 0.0000 | Prec: 0.0000 | Rec: 0.0000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 20 | Meta Loss: 0.7007 | Acc: 0.5049 | F1: 0.0000 | Prec: 0.0000 | Rec: 0.0000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 21 | Meta Loss: 0.7063 | Acc: 0.4600 | F1: 0.0000 | Prec: 0.0000 | Rec: 0.0000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 22 | Meta Loss: 0.6932 | Acc: 0.5254 | F1: 0.0000 | Prec: 0.0000 | Rec: 0.0000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 23 | Meta Loss: 0.6949 | Acc: 0.4951 | F1: 0.0000 | Prec: 0.0000 | Rec: 0.0000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 24 | Meta Loss: 0.6930 | Acc: 0.5088 | F1: 0.0000 | Prec: 0.0000 | Rec: 0.0000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 25 | Meta Loss: 0.6924 | Acc: 0.5107 | F1: 0.0000 | Prec: 0.0000 | Rec: 0.0000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 26 | Meta Loss: 0.6932 | Acc: 0.4883 | F1: 0.0000 | Prec: 0.0000 | Rec: 0.0000
Epoch 27 | Meta Loss: 0.6927 | Acc: 0.5234 | F1: 0.2254 | Prec: 0.5547 | Rec: 0.1414
Epoch 28 | Meta Loss: 0.6929 | Acc: 0.5381 | F1: 0.4679 | Prec: 0.5417 | Rec: 0.4119
Epoch 29 | Meta Loss: 0.6925 | Acc: 0.5225 | F1: 0.3311 | Prec: 0.4727 | Rec: 0.2547
Epoch 30 | Meta Loss: 0.6929 | Acc: 0.5312 | F1: 0.5229 | Prec: 0.5137 | Rec: 0.5324


In [11]:
# 10. Few-Shot Evaluation Function

def few_shot_eval(model, dataset, pt_bins, n_tasks=None, k_shot=K_SHOT, k_query=K_QUERY, inner_steps=INNER_STEPS, inner_lr=INNER_LR):
    model.eval()
    all_acc, all_f1, all_prec, all_rec = [], [], [], []
    n_bins = len(pt_bins) - 1 if n_tasks is None else n_tasks
    for bin_idx in range(n_bins):
        (X_s, y_s), (X_q, y_q) = sample_task(dataset, bin_idx, k_shot, k_query)
        X_s, y_s, X_q, y_q = X_s.to(DEVICE), y_s.to(DEVICE), X_q.to(DEVICE), y_q.to(DEVICE)
        # Fast adaptation (inner loop)
        fast_weights = [p.clone().detach().requires_grad_(True) for p in model.parameters()]
        loss_fn = nn.CrossEntropyLoss()
        for _ in range(inner_steps):
            logits = model(X_s)
            loss = loss_fn(logits, y_s)
            grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
            fast_weights = [w - inner_lr * g for w, g in zip(fast_weights, grads)]
        # Evaluate on query set
        def forward_with_weights(x, weights):
          # Conv1
          x = nn.functional.conv2d(x, weights[0], weights[1], stride=2, padding=2)
          x = nn.functional.relu(x)
          # Conv2
          x = nn.functional.conv2d(x, weights[2], weights[3], stride=2, padding=1)
          x = nn.functional.relu(x)
          # Conv3
          x = nn.functional.conv2d(x, weights[4], weights[5], stride=2, padding=1)
          x = nn.functional.relu(x)
          # Conv4
          x = nn.functional.conv2d(x, weights[6], weights[7], stride=2, padding=1)
          x = nn.functional.relu(x)
          # Conv5
          x = nn.functional.conv2d(x, weights[8], weights[9], stride=2, padding=1)
          x = nn.functional.relu(x)
          # Flatten
          x = x.view(x.size(0), -1)
          # FC1
          x = nn.functional.linear(x, weights[10], weights[11])
          x = nn.functional.relu(x)
          # FC2
          x = nn.functional.linear(x, weights[12], weights[13])
          x = nn.functional.relu(x)
          # FC3
          x = nn.functional.linear(x, weights[14], weights[15])
          return x
        logits_q = forward_with_weights(X_q, fast_weights)
        preds = logits_q.argmax(dim=1)
        acc = accuracy_score(y_q.cpu().numpy(), preds.cpu().numpy())
        f1 = f1_score(y_q.cpu().numpy(), preds.cpu().numpy())
        prec = precision_score(y_q.cpu().numpy(), preds.cpu().numpy())
        rec = recall_score(y_q.cpu().numpy(), preds.cpu().numpy())
        all_acc.append(acc)
        all_f1.append(f1)
        all_prec.append(prec)
        all_rec.append(rec)
        print(f"Bin {bin_idx}: Acc={acc:.4f} | F1={f1:.4f} | Prec={prec:.4f} | Rec={rec:.4f}")
    print(f"\nMean Few-Shot: Acc={np.mean(all_acc):.4f} | F1={np.mean(all_f1):.4f} | Prec={np.mean(all_prec):.4f} | Rec={np.mean(all_rec):.4f}")
    return all_acc, all_f1, all_prec, all_rec

In [12]:
# 11. Few-Shot Evaluation for MAML and Baseline

print("MAML Few-Shot Evaluation:")
few_shot_eval(maml_model, jet_dataset, pt_bins)

print("\nClassical Baseline Few-Shot Evaluation (no adaptation):")
# For baseline, skip adaptation: just forward pass
def baseline_few_shot_eval(model, dataset, pt_bins, n_tasks=None, k_query=K_QUERY):
    model.eval()
    all_acc, all_f1, all_prec, all_rec = [], [], [], []
    n_bins = len(pt_bins) - 1 if n_tasks is None else n_tasks
    for bin_idx in range(n_bins):
        idxs = dataset.bin_indices[bin_idx]
        idxs = np.random.permutation(idxs)
        query_idxs = idxs[:k_query]
        X_q, y_q = zip(*[dataset[i] for i in query_idxs])
        X_q = torch.stack(X_q).to(DEVICE)
        y_q = torch.tensor(y_q).to(DEVICE)
        with torch.no_grad():
            logits_q = model(X_q)
            preds = logits_q.argmax(dim=1)
        acc = accuracy_score(y_q.cpu().numpy(), preds.cpu().numpy())
        f1 = f1_score(y_q.cpu().numpy(), preds.cpu().numpy())
        prec = precision_score(y_q.cpu().numpy(), preds.cpu().numpy())
        rec = recall_score(y_q.cpu().numpy(), preds.cpu().numpy())
        all_acc.append(acc)
        all_f1.append(f1)
        all_prec.append(prec)
        all_rec.append(rec)
        print(f"Bin {bin_idx}: Acc={acc:.4f} | F1={f1:.4f} | Prec={prec:.4f} | Rec={rec:.4f}")
    print(f"\nMean Few-Shot: Acc={np.mean(all_acc):.4f} | F1={np.mean(all_f1):.4f} | Prec={np.mean(all_prec):.4f} | Rec={np.mean(all_rec):.4f}")
    return all_acc, all_f1, all_prec, all_rec

baseline_few_shot_eval(baseline_model, jet_dataset, pt_bins)

MAML Few-Shot Evaluation:
Bin 0: Acc=0.5625 | F1=0.0000 | Prec=0.0000 | Rec=0.0000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Bin 1: Acc=0.4688 | F1=0.0000 | Prec=0.0000 | Rec=0.0000
Bin 2: Acc=0.4531 | F1=0.0000 | Prec=0.0000 | Rec=0.0000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Bin 3: Acc=0.5391 | F1=0.7005 | Prec=0.5391 | Rec=1.0000
Bin 4: Acc=0.5312 | F1=0.0000 | Prec=0.0000 | Rec=0.0000

Mean Few-Shot: Acc=0.5109 | F1=0.1401 | Prec=0.1078 | Rec=0.2000

Classical Baseline Few-Shot Evaluation (no adaptation):
Bin 0: Acc=0.7656 | F1=0.7222 | Prec=0.8125 | Rec=0.6500
Bin 1: Acc=0.6875 | F1=0.6774 | Prec=0.7119 | Rec=0.6462
Bin 2: Acc=0.7891 | F1=0.7970 | Prec=0.8281 | Rec=0.7681
Bin 3: Acc=0.7891 | F1=0.7874 | Prec=0.7937 | Rec=0.7812
Bin 4: Acc=0.8125 | F1=0.8310 | Prec=0.8551 | Rec=0.8082

Mean Few-Shot: Acc=0.7688 | F1=0.7630 | Prec=0.8002 | Rec=0.7307


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


([0.765625, 0.6875, 0.7890625, 0.7890625, 0.8125],
 [0.7222222222222222,
  0.6774193548387096,
  0.7969924812030075,
  0.7874015748031497,
  0.8309859154929577],
 [0.8125, 0.711864406779661, 0.828125, 0.7936507936507936, 0.855072463768116],
 [0.65, 0.6461538461538462, 0.7681159420289855, 0.78125, 0.8082191780821918])