In [None]:
!pip install -q timm==0.9.2 transformers sentence-transformers torchmetrics scikit-learn statsmodels
# BLIP captioning / sentence-embedding experiments need the packages below (already installed by the line above, so this line is redundant):
!pip install -q transformers sentence-transformers

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

In [None]:
!mkdir -p ~/.kaggle
# !cp '/content/drive/MyDrive/projectResearchPaper/plant/kaggle.json' ~/.kaggle/
!cp '/content/drive/MyDrive/test/kaggle.json' ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

In [None]:
#!/bin/bash
!kaggle datasets download warcoder/indian-medicinal-plant-image-dataset

In [None]:
# Extract the Kaggle archive downloaded in the previous cell into /content.
# NOTE(review): presumably this yields the 'Medicinal plant dataset' folder used as DATA_DIR — confirm.
import zipfile
with zipfile.ZipFile('/content/indian-medicinal-plant-image-dataset.zip', 'r') as zip_ref:
    zip_ref.extractall('/content')

In [None]:
# --- Imports ---
import os, random, math, time, json, zipfile
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# TensorFlow / Keras
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model, applications

# PyTorch & timm
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import timm
from torchvision import transforms
from PIL import Image

# Sklearn
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
from statsmodels.stats.contingency_tables import mcnemar

In [None]:
# Repro
# Seed the Python, NumPy, TensorFlow and PyTorch global RNGs from one constant.
# NOTE(review): GPU kernels (e.g. cuDNN) may still be non-deterministic; seeding alone does not
# guarantee bit-identical runs.
SEED = 42
random.seed(SEED); np.random.seed(SEED); tf.random.set_seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Configuration (edit as needed)
# -------------------------
DATA_DIR = '/content/Medicinal plant dataset'  # change if different
OUT_ROOT = '/content/experiments'              # where results will go
os.makedirs(OUT_ROOT, exist_ok=True)
PLOTS_DIR = os.path.join(OUT_ROOT, 'plots'); os.makedirs(PLOTS_DIR, exist_ok=True)
IMG_SIZE = (224,224)            # H,W
BATCH_SIZE_TF = 24
NUM_WORKERS = 4
EPOCHS_TF = 12                  # Keras head-training epochs
EPOCHS_PY = 12                  # PyTorch training epochs
FINE_TUNE_EPOCHS_TF = 10        # Keras fine-tuning epochs after the head phase
LR_TF = 1e-3
LR_PY = 2e-5
FINE_TUNE_LR_TF = 1e-5
WEIGHT_DECAY = 1e-5
REPEATS = 1    # reduce during debugging; set >1 if you want repeats
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device:", DEVICE)
import pathlib, os, random, shutil
# Lower-case aliases kept for later cells; derived from DATA_DIR so the dataset path is defined once.
data_dir = DATA_DIR
data_dir_path = pathlib.Path(data_dir)
# Sorted so the order matches the label indices assigned by FolderDatasetListing.
classes = sorted(d.name for d in data_dir_path.iterdir() if d.is_dir())

In [None]:
# -------------------------
# Dataset (Keras-style listing)
# -------------------------
class FolderDatasetListing:
    """Index an image-folder dataset laid out as ``root/<class_name>/<image file>``.

    Attributes set on construction:
        root: ``Path`` of the dataset root.
        classes: sorted list of class (sub-directory) names.
        class_to_idx: class name -> integer label, assigned in sorted (alphabetical) order.
        items: list of ``(file_path_str, label)`` tuples, grouped by class, files sorted by path.
        num_classes: ``len(classes)``.

    Files are listed in sorted order. Otherwise the item list, and therefore any index-based
    train/val/test split built on it, would depend on filesystem enumeration order and differ
    between machines even with a fixed seed.
    """

    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.webp')

    def __init__(self, root_dir, extensions=None):
        """Scan ``root_dir``.

        Args:
            root_dir: dataset root directory.
            extensions: iterable of accepted file suffixes (case-insensitive, with leading dot).
                Defaults to ``IMAGE_EXTENSIONS``.
        """
        self.root = Path(root_dir)
        exts = tuple(e.lower() for e in (extensions if extensions is not None else self.IMAGE_EXTENSIONS))
        classes = sorted(p.name for p in self.root.iterdir() if p.is_dir())
        self.class_to_idx = {c: i for i, c in enumerate(classes)}
        self.items = []
        for c in classes:
            # Path.glob order is filesystem-dependent; sort for a reproducible item order.
            for f in sorted((self.root / c).glob('*')):
                if f.is_file() and f.suffix.lower() in exts:
                    self.items.append((str(f), self.class_to_idx[c]))
        print(f"Found {len(self.items)} images across {len(classes)} classes.")
        self.classes = classes
        self.num_classes = len(classes)

# Index every image under DATA_DIR as (path, label); labels follow the sorted class-folder names.
full_dataset = FolderDatasetListing(DATA_DIR)
num_classes = full_dataset.num_classes
# Reverse lookup: integer label -> class folder name.
idx_to_class = {v:k for k,v in full_dataset.class_to_idx.items()}


In [None]:
# Create train/val/test split indices
# 80/10/10 random split over indices into full_dataset.items.
# A dedicated RNG seeded with SEED makes the split independent of the global `random` state.
# The global state advances every time it is used, so re-running this cell used to produce a
# different split, mixing previous test images into training. On a fresh kernel nothing
# consumes the global RNG between seeding and this cell, so the split is the same as before.
n = len(full_dataset.items)
train_n = int(0.8 * n)
val_n = int(0.1 * n)
test_n = n - train_n - val_n
indices = list(range(n))
random.Random(SEED).shuffle(indices)
train_idxs = indices[:train_n]
val_idxs   = indices[train_n:train_n + val_n]
test_idxs  = indices[train_n + val_n:]
assert len(train_idxs) + len(val_idxs) + len(test_idxs) == n

In [None]:
# Helper to build tf.data from indices
def make_tf_dataset_from_indices(indices_list, batch_size=BATCH_SIZE_TF, transform=None, shuffle=False):
    """Build a batched, prefetched ``tf.data.Dataset`` of ``(image, label)`` pairs.

    Args:
        indices_list: indices into ``full_dataset.items`` selecting the split.
        batch_size: examples per batch.
        transform: unused; kept only for call-site compatibility.
        shuffle: when True, example order is reshuffled every epoch.

    Images are decoded to 3 channels, converted to float32 in [0, 1] (``convert_image_dtype``)
    and resized to ``IMG_SIZE``. Labels are the integer class indices.

    NOTE(review): ``tf.io.decode_image`` is documented for BMP/GIF/JPEG/PNG. ``.webp`` files
    accepted by ``FolderDatasetListing`` may fail to decode here — confirm for your TF version.
    """
    paths = [full_dataset.items[i][0] for i in indices_list]
    labels = [full_dataset.items[i][1] for i in indices_list]
    ds = tf.data.Dataset.from_tensor_slices((paths, labels))
    if shuffle:
        # Shuffle (path, label) pairs *before* decoding. The buffer then holds short strings
        # instead of up to 1024 decoded float32 images (~600 MB at 224x224x3), so it can span
        # the whole split and give a full, uniform shuffle.
        ds = ds.shuffle(buffer_size=max(1, len(paths)), reshuffle_each_iteration=True)

    def _load(path, label):
        image = tf.io.read_file(path)
        image = tf.image.decode_image(image, channels=3, expand_animations=False)
        image = tf.image.convert_image_dtype(image, tf.float32)
        image = tf.image.resize(image, IMG_SIZE)
        return image, label

    ds = ds.map(_load, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Batched tf.data pipelines for the Keras models; only the training split is shuffled.
train_ds_tf = make_tf_dataset_from_indices(train_idxs, batch_size=BATCH_SIZE_TF, shuffle=True)
val_ds_tf   = make_tf_dataset_from_indices(val_idxs, batch_size=BATCH_SIZE_TF, shuffle=False)
test_ds_tf  = make_tf_dataset_from_indices(test_idxs, batch_size=BATCH_SIZE_TF, shuffle=False)


In [None]:
# Also keep lists for PyTorch inference
# (path, label) tuples per split, using the same indices as the tf.data pipelines.
# These are consumed by the PyTorch datasets and the fusion pipeline further down.
train_items = [full_dataset.items[i] for i in train_idxs]
val_items = [full_dataset.items[i] for i in val_idxs]
test_items = [full_dataset.items[i] for i in test_idxs]


In [None]:
# -------------------------
# TensorFlow / Keras model builders (frozen ImageNet backbones + trainable classification heads)
# -------------------------
from tensorflow.keras import applications, optimizers, callbacks, regularizers

def build_mobilenetv2(input_shape=(*IMG_SIZE,3)):
    """Frozen ImageNet MobileNetV2 backbone topped with a small trainable classifier.

    Head: GlobalAveragePooling -> Dense(128, relu, L2=WEIGHT_DECAY) -> Dropout(0.2)
    -> Dense(num_classes, softmax). Compiled with Adam(LR_TF) and sparse categorical
    cross-entropy, so labels are integer class indices.

    NOTE(review): the notebook's input pipelines feed float images in [0, 1]. Keras documents
    MobileNetV2 weights as expecting ``mobilenet_v2.preprocess_input`` ([-1, 1]) — confirm
    this is intended.
    """
    backbone = applications.MobileNetV2(include_top=False, weights='imagenet', input_shape=input_shape)
    backbone.trainable = False
    pooled = layers.GlobalAveragePooling2D()(backbone.output)
    hidden = layers.Dense(128, activation='relu', kernel_regularizer=regularizers.l2(WEIGHT_DECAY))(pooled)
    hidden = layers.Dropout(0.2)(hidden)
    probs = layers.Dense(num_classes, activation='softmax')(hidden)
    classifier = Model(backbone.input, probs)
    classifier.compile(
        optimizer=optimizers.Adam(LR_TF),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return classifier

def build_resnet152(input_shape=(*IMG_SIZE,3)):
    """Frozen ImageNet ResNet152 backbone topped with a small trainable classifier.

    Head: GlobalAveragePooling -> Dense(128, relu, L2=WEIGHT_DECAY) -> Dropout(0.2)
    -> Dense(num_classes, softmax). Compiled with Adam(LR_TF) and sparse categorical
    cross-entropy.

    NOTE(review): inputs arrive as [0, 1] floats, while Keras' ResNet weights are documented
    for ``applications.resnet.preprocess_input`` (caffe-style) inputs — confirm this is intended.
    """
    backbone = applications.ResNet152(include_top=False, weights='imagenet', input_shape=input_shape)
    backbone.trainable = False
    pooled = layers.GlobalAveragePooling2D()(backbone.output)
    hidden = layers.Dense(128, activation='relu', kernel_regularizer=regularizers.l2(WEIGHT_DECAY))(pooled)
    hidden = layers.Dropout(0.2)(hidden)
    probs = layers.Dense(num_classes, activation='softmax')(hidden)
    classifier = Model(backbone.input, probs)
    classifier.compile(
        optimizer=optimizers.Adam(LR_TF),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return classifier

def build_xception(input_shape=(*IMG_SIZE,3)):
    """Frozen ImageNet Xception backbone topped with a trainable classifier.

    Head: GlobalAveragePooling -> Dense(256, relu, L2=WEIGHT_DECAY) -> Dropout(0.3)
    -> Dense(num_classes, softmax). Compiled with Adam(LR_TF) and sparse categorical
    cross-entropy.
    """
    backbone = applications.Xception(include_top=False, weights='imagenet', input_shape=input_shape)
    backbone.trainable = False
    pooled = layers.GlobalAveragePooling2D()(backbone.output)
    hidden = layers.Dense(256, activation='relu', kernel_regularizer=regularizers.l2(WEIGHT_DECAY))(pooled)
    hidden = layers.Dropout(0.3)(hidden)
    probs = layers.Dense(num_classes, activation='softmax')(hidden)
    classifier = Model(backbone.input, probs)
    classifier.compile(
        optimizer=optimizers.Adam(LR_TF),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return classifier

def build_mobilenetv3(input_shape=(*IMG_SIZE,3), model_type='large'):
    """Frozen ImageNet MobileNetV3 (Large or Small) backbone topped with a trainable classifier.

    Args:
        input_shape: (H, W, 3) of the float images fed to the model.
        model_type: 'large' for MobileNetV3Large; any other value selects MobileNetV3Small.

    The notebook's pipelines feed float images in [0, 1] (``convert_image_dtype`` in the tf.data
    loader, ``/255.0`` in the test-time loader). Keras' MobileNetV3 applications apply their own
    input preprocessing and expect raw [0, 255] pixels. Fed [0, 1] inputs, every image collapses
    to roughly -1. A ``Rescaling(255)`` in front of the backbone restores the expected range, and
    the model's contract stays the same: [0, 1] images in, class probabilities out.

    Head: GlobalAveragePooling -> Dense(128, relu, L2=WEIGHT_DECAY) -> Dropout(0.2)
    -> Dense(num_classes, softmax); compiled with Adam(LR_TF) and sparse categorical cross-entropy.
    Falls back to ``build_mobilenetv2`` if MobileNetV3 cannot be built.
    """
    inputs = layers.Input(shape=input_shape)
    # Undo the /255 scaling of the input pipelines; MobileNetV3's built-in preprocessing wants [0, 255].
    pixels_0_255 = layers.Rescaling(255.0)(inputs)
    try:
        backbone_cls = applications.MobileNetV3Large if model_type == 'large' else applications.MobileNetV3Small
        # Passing input_tensor (rather than calling the backbone as a nested sub-model) keeps its
        # layers flat in the final model's `.layers`, which the layer-count-based fine-tuning relies on.
        base = backbone_cls(weights='imagenet', include_top=False, input_tensor=pixels_0_255)
    except Exception as e:
        print("MobileNetV3 not available; falling back to MobileNetV2. Error:", e)
        return build_mobilenetv2(input_shape)
    base.trainable = False
    x = layers.GlobalAveragePooling2D()(base.output)
    x = layers.Dense(128, activation='relu',
                     kernel_regularizer=regularizers.l2(WEIGHT_DECAY))(x)
    x = layers.Dropout(0.2)(x)
    out = layers.Dense(num_classes, activation='softmax')(x)
    model = Model(inputs, out)
    model.compile(optimizer=optimizers.Adam(LR_TF), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model

def build_resnet50_for_feats(input_shape=(*IMG_SIZE,3)):
    """Frozen ImageNet ResNet50 mapping an image to its global-average-pooled feature vector.

    The returned model has no classifier and is not compiled; callers attach and compile a head.
    """
    backbone = applications.ResNet50(include_top=False, weights='imagenet', input_shape=input_shape)
    backbone.trainable = False
    pooled = layers.GlobalAveragePooling2D()(backbone.output)
    return Model(backbone.input, pooled)

class _LearnablePositionEmbedding(layers.Layer):
    """Add a trainable ``(1, num_patches, projection_dim)`` position embedding to patch tokens.

    This is a layer with its own weight, so the embedding is part of the model graph and is
    trained. Calling ``layers.Embedding`` on a constant ``tf.range`` while building a functional
    model evaluates it once, outside the graph. The result is then baked in as a constant, and
    its weights never reach ``model.trainable_weights``.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim

    def build(self, input_shape):
        # 'uniform' matches the default initializer of the Embedding layer previously used here.
        self.position_embedding = self.add_weight(
            name='position_embedding',
            shape=(1, self.num_patches, self.projection_dim),
            initializer='uniform',
            trainable=True,
        )
        super().build(input_shape)

    def call(self, patch_tokens):
        return patch_tokens + self.position_embedding

    def get_config(self):
        config = super().get_config()
        config.update({'num_patches': self.num_patches, 'projection_dim': self.projection_dim})
        return config


def build_vit_custom_improved(input_shape=(*IMG_SIZE,3), num_classes=num_classes,
                              patch_size=16, projection_dim=128, transformer_layers=6,
                              num_heads=8, mlp_dim=256):
    """Small Vision Transformer trained from scratch (no pretrained weights).

    Non-overlapping ``patch_size`` patches are embedded with a strided Conv2D, get a learnable
    position embedding, and pass through ``transformer_layers`` pre-norm encoder blocks
    (MHSA + GELU MLP, residual connections). Tokens are then global-average-pooled into a
    Dense(mlp_dim) -> softmax head.

    Returns a model compiled with Adam(3e-4) and sparse categorical cross-entropy.
    Raises ValueError if ``input_shape`` does not give a static patch count.

    NOTE(review): ``key_dim=projection_dim`` gives each head the full projection width
    (ViT commonly uses ``projection_dim // num_heads``) — confirm this is intended.
    """
    inputs = layers.Input(shape=input_shape)
    patches = layers.Conv2D(filters=projection_dim, kernel_size=patch_size, strides=patch_size, padding="valid")(inputs)
    patches = layers.Reshape((-1, projection_dim))(patches)
    num_patches = patches.shape[1]
    if num_patches is None:
        raise ValueError("build_vit_custom_improved needs a fixed input_shape (static number of patches).")
    encoded_patches = _LearnablePositionEmbedding(num_patches, projection_dim, name="pos_embedding")(patches)
    for _ in range(transformer_layers):
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        attention_output = layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim, dropout=0.1)(x1, x1)
        x2 = layers.Add()([attention_output, encoded_patches])
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        mlp_output = keras.Sequential([
            layers.Dense(mlp_dim, activation='gelu'),
            layers.Dropout(0.1),
            layers.Dense(projection_dim),
            layers.Dropout(0.1)
        ])(x3)
        encoded_patches = layers.Add()([x2, mlp_output])
    x = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    x = layers.GlobalAveragePooling1D()(x)
    x = layers.Dropout(0.4)(x)
    x = layers.Dense(mlp_dim, activation='relu')(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = Model(inputs, outputs, name="VisionTransformer_Custom_Improved")
    optimizer = optimizers.Adam(learning_rate=3e-4)
    model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model

# -------------------------
# Training helper for Keras: safe compile & fit
# -------------------------
def compile_and_fit(model, name, epochs=EPOCHS_TF, train_ds=train_ds_tf, val_ds=val_ds_tf, outdir=OUT_ROOT):
    """Fit an already-compiled Keras model with checkpointing, LR reduction and early stopping.

    Despite the name, the model is NOT compiled here; callers must compile it beforehand.

    Writes ``{outdir}/{name}_best.h5`` (best val_accuracy) during training and
    ``{outdir}/{name}_final.h5`` afterwards. EarlyStopping restores the best weights, so the
    final file holds those. Returns the ``History`` from ``model.fit``.
    """
    best_path = os.path.join(outdir, f'{name}_best.h5')
    final_path = os.path.join(outdir, f'{name}_final.h5')
    checkpoint = callbacks.ModelCheckpoint(best_path, save_best_only=True, monitor='val_accuracy', mode='max')
    plateau = callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, verbose=1)
    early_stop = callbacks.EarlyStopping(monitor='val_accuracy', patience=6, restore_best_weights=True)
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
        callbacks=[checkpoint, plateau, early_stop],
        verbose=2,
    )
    model.save(final_path)
    return history

In [None]:
# -------------------------
# Build models list (Keras)
# -------------------------
# name -> Keras model. The classifier builders return compiled models with frozen backbones.
# 'ResNet50_feats' is a headless feature extractor; the training cell attaches its classifier head.
# NOTE: every model is instantiated up front, so all of them are held in memory at once.
models_to_train = {}
models_to_train['MobileNetV2'] = build_mobilenetv2()
models_to_train['ResNet152'] = build_resnet152()
models_to_train['MobileNetV3'] = build_mobilenetv3(model_type='large')
models_to_train['ResNet50_feats'] = build_resnet50_for_feats()
# models_to_train['VisionTransformer_Custom'] = build_vit_custom_improved()

# Optional: don't instantiate Xception on small runtimes (it is heavy)
# models_to_train['Xception'] = build_xception()

In [None]:
# -------------------------
# Training Keras classifiers (with head training + fine-tune)
# -------------------------
trained_classifiers = {}   # name -> Keras model used for test evaluation
histories = {}             # name or f"{name}_head" -> Keras History
reports = {}               # name -> sklearn classification_report dict
cms = {}                   # name -> confusion matrix (np.ndarray)
# NOTE: rebinds `data_dir` (previously the dataset path) to the output directory;
# the Keras evaluation artifacts below are written here.
data_dir = OUT_ROOT

def plot_confusion(cm, class_names, save_path, title='Confusion Matrix'):
    """Render ``cm`` as an annotated heatmap and save it to ``save_path``.

    Rows are true labels and columns predicted labels; both axes are tick-labelled with
    ``class_names``. The figure is closed after saving so nothing is displayed inline.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, ax=ax, cmap='Blues', annot=True, fmt='d',
                xticklabels=class_names, yticklabels=class_names)
    ax.set_title(title)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    fig.tight_layout()
    fig.savefig(save_path)
    plt.close(fig)

def save_plot_history(hist, save_path):
    """Save side-by-side loss and accuracy curves from a Keras ``History`` to ``save_path``.

    Missing metrics are plotted as empty series. The figure is closed after saving.
    """
    curves = hist.history
    panels = (
        ('loss', 'val_loss', 'train_loss', 'val_loss'),
        ('accuracy', 'val_accuracy', 'train_acc', 'val_acc'),
    )
    plt.figure(figsize=(10, 4))
    for position, (train_key, val_key, train_label, val_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(curves.get(train_key, []), label=train_label)
        plt.plot(curves.get(val_key, []), label=val_label)
        plt.legend()
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def keras_predict_from_paths(model, paths, batch=32, target_size=IMG_SIZE):
    """Predict class indices for image files with a Keras classifier.

    Preprocessing mirrors the tf.data training pipeline: 3-channel RGB, resized to
    ``target_size`` (H, W) and scaled to float32 in [0, 1].

    Args:
        model: Keras model returning per-class probabilities.
        paths: list of image file paths.
        batch: number of images per ``model.predict`` call.
        target_size: (H, W) output size.

    Returns:
        1-D integer array of predicted class indices (empty when ``paths`` is empty).
    """
    height, width = target_size
    preds = []
    for start in range(0, len(paths), batch):
        imgs = []
        for p in paths[start:start + batch]:
            # convert('RGB') makes grayscale/RGBA/palette files (H, W, 3), like decode_image(channels=3)
            # in training. `with` closes the file. PIL's resize takes (width, height).
            with Image.open(p) as im:
                rgb = im.convert('RGB').resize((width, height))
            imgs.append(np.asarray(rgb, dtype='float32') / 255.0)
        probs = model.predict(np.stack(imgs, 0), verbose=0)
        preds.append(np.argmax(probs, axis=1))
    if not preds:
        return np.zeros((0,), dtype=int)
    return np.concatenate(preds, axis=0)


# The test split is fixed, so gather its paths/labels once instead of on every model iteration.
test_paths_local = [full_dataset.items[i][0] for i in test_idxs]
y_true = np.array([full_dataset.items[i][1] for i in test_idxs])
# Explicit label list so every confusion matrix is num_classes x num_classes and its rows/columns
# line up with `full_dataset.classes`, even when a class is absent from the random test split.
all_labels = list(range(num_classes))

for name, model in list(models_to_train.items()):
    if name == 'ResNet50_feats':
        # Headless feature extractor: attach a classifier head and train only that head.
        # The ResNet50 backbone stays frozen; there is no fine-tune phase for this model.
        print(f"Building classifier for {name}")
        base = model
        inputs = base.input
        x = base.output
        x = layers.Dense(256, activation='relu',
                         kernel_regularizer=regularizers.l2(WEIGHT_DECAY))(x)
        x = layers.Dropout(0.2)(x)
        out = layers.Dense(num_classes, activation='softmax')(x)
        clf = Model(inputs, out)
        clf.compile(optimizer=optimizers.Adam(LR_TF), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
        print(f"Training head for {name} ...")
        hist_head = compile_and_fit(clf, name, epochs=EPOCHS_TF)
        histories[name] = hist_head
        trained_classifiers[name] = clf
    else:
        print(f"Training {name} - head first, then fine-tune")
        # ---- Phase 1: head training (the builders already froze the backbone and compiled) ----
        cb_head = [
            callbacks.ModelCheckpoint(os.path.join(OUT_ROOT, f'{name}_head_best.h5'), save_best_only=True, monitor='val_accuracy', mode='max'),
            callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, verbose=1),
            callbacks.EarlyStopping(monitor='val_accuracy', patience=6, restore_best_weights=True)
        ]
        hist_head = model.fit(train_ds_tf, validation_data=val_ds_tf, epochs=EPOCHS_TF, callbacks=cb_head, verbose=2)
        histories[f"{name}_head"] = hist_head
        model.save(os.path.join(OUT_ROOT, f'{name}_head.h5'))

        # ---- Phase 2: fine-tune the last UNFREEZE_AT layers (None = unfreeze every layer) ----
        UNFREEZE_AT = 100
        try:
            if UNFREEZE_AT is None:
                for layer in model.layers:
                    layer.trainable = True
            else:
                # reversed(model.layers) starts from the output side, so the head plus the
                # top-most backbone layers become trainable.
                for depth, layer in enumerate(reversed(model.layers)):
                    layer.trainable = depth < UNFREEZE_AT
            # Keep BatchNormalization frozen. In Keras, trainable=False also runs BN in inference
            # mode, so small fine-tuning batches do not overwrite the pretrained statistics.
            for layer in model.layers:
                if isinstance(layer, layers.BatchNormalization):
                    layer.trainable = False
            # Recompile so the trainable-flag changes take effect, with a lower LR for fine-tuning.
            model.compile(optimizer=optimizers.Adam(FINE_TUNE_LR_TF),
                          loss='sparse_categorical_crossentropy',
                          metrics=['accuracy'])
            # Continue the epoch counter after the head phase and run exactly FINE_TUNE_EPOCHS_TF more
            # epochs. Previously `epochs=EPOCHS_TF + FINE_TUNE_EPOCHS_TF` ran extra fine-tune epochs
            # whenever head training stopped early.
            start_epoch = len(hist_head.epoch) if getattr(hist_head, 'epoch', None) else 0
            hist_ft = model.fit(train_ds_tf, validation_data=val_ds_tf,
                                initial_epoch=start_epoch,
                                epochs=start_epoch + FINE_TUNE_EPOCHS_TF,
                                callbacks=[
                                    callbacks.ModelCheckpoint(os.path.join(OUT_ROOT, f'{name}_ft_best.h5'), save_best_only=True, monitor='val_accuracy', mode='max'),
                                    callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, verbose=1),
                                    callbacks.EarlyStopping(monitor='val_accuracy', patience=8, restore_best_weights=True)
                                ],
                                verbose=2)
            histories[name] = hist_ft
            trained_classifiers[name] = model
            model.save(os.path.join(OUT_ROOT, f'{name}_finetuned.h5'))
        except Exception as e:
            print("Warning: fine-tuning step failed for", name, " — skipping fine-tune. Error:", e)
            trained_classifiers[name] = model
            histories[name] = hist_head

    # ---- Test-set evaluation and artifacts ----
    preds = keras_predict_from_paths(trained_classifiers[name], test_paths_local, batch=32)
    rep = classification_report(y_true, preds, output_dict=True, zero_division=0)
    cm = confusion_matrix(y_true, preds, labels=all_labels)
    reports[name] = rep
    cms[name] = cm
    pd.DataFrame(rep).transpose().to_csv(os.path.join(data_dir, f"{name}_classification_report.csv"))
    np.save(os.path.join(data_dir, f"{name}_cm.npy"), cm)
    plot_confusion(cm, full_dataset.classes, os.path.join(data_dir, f"{name}_confusion.png"), title=f"{name} Confusion Matrix")
    # Prefer the fine-tune history when present, else the head-phase history.
    save_plot_history(histories.get(name, histories.get(f"{name}_head")), os.path.join(data_dir, f"{name}_train_curve.png"))

print("Keras training/evaluation done. Artifacts in", data_dir)

In [None]:
# -------------------------
# PyTorch section: Swin, DeiT-small, CoAtNet0 (timm)
# -------------------------
# Dataset wrapper for PyTorch using same items lists
class ImageFolderDatasetFromList(Dataset):
    """PyTorch ``Dataset`` over a list of ``(image_path, label)`` pairs.

    ``__getitem__`` returns ``(image, label, path)``. The image is loaded as RGB and passed
    through ``transform`` when one is given. The label is cast to ``int``, and the path is
    returned so predictions can be traced back to files.
    """

    def __init__(self, items_list, transform=None):
        self.items = items_list
        self.transform = transform

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        p, lbl = self.items[idx]
        # `with` closes the file as soon as the RGB copy exists. Otherwise each DataLoader
        # worker keeps handles open until garbage collection ("Too many open files").
        with Image.open(p) as im:
            img = im.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, int(lbl), p

# Training augmentation: random resized crop (square side IMG_SIZE[0]; assumes IMG_SIZE is square),
# horizontal flip and color jitter, then tensor conversion and normalization with the standard
# ImageNet mean/std.
train_transform_pt = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE[0], scale=(0.6,1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.4,0.4,0.4,0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
# Deterministic eval transform for val/test: resize straight to IMG_SIZE (no crop), same normalization.
val_transform_pt = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

# Same split item lists as the Keras pipelines.
train_ds_pt = ImageFolderDatasetFromList(train_items, transform=train_transform_pt)
val_ds_pt   = ImageFolderDatasetFromList(val_items, transform=val_transform_pt)
test_ds_pt  = ImageFolderDatasetFromList(test_items, transform=val_transform_pt)

# Only the training loader shuffles; each batch yields (images, labels, paths).
train_loader_pt = DataLoader(train_ds_pt, batch_size=32, shuffle=True, num_workers=NUM_WORKERS)
val_loader_pt   = DataLoader(val_ds_pt, batch_size=32, shuffle=False, num_workers=NUM_WORKERS)
test_loader_pt  = DataLoader(test_ds_pt, batch_size=32, shuffle=False, num_workers=NUM_WORKERS)


In [None]:
# factories
def make_swin_tiny(num_classes, pretrained=True):
    """timm Swin-T (patch 4, window 7, 224 px) with a freshly initialised ``num_classes``-way head."""
    return timm.create_model('swin_tiny_patch4_window7_224', pretrained=pretrained, num_classes=num_classes)


def make_deit_small_distilled(num_classes, pretrained=True):
    """timm DeiT-S distilled (patch 16, 224 px) with a freshly initialised ``num_classes``-way head."""
    return timm.create_model('deit_small_distilled_patch16_224', pretrained=pretrained, num_classes=num_classes)


def make_coatnet0(num_classes, pretrained=True):
    """timm CoAtNet-0 (rw variant, 224 px, ImageNet-1k weights) with a fresh ``num_classes``-way head.

    A plain ``'coatnet_0'`` is not a registered timm model name, so ``create_model`` fails on it.
    ``coatnet_0_rw_224.sw_in1k`` is the CoAtNet-0 variant with published weights. It is the same
    name the override cell later in this notebook uses.
    """
    return timm.create_model('coatnet_0_rw_224.sw_in1k', pretrained=pretrained, num_classes=num_classes)


# Name -> factory(num_classes) used by the PyTorch training loop and the fusion pipeline.
pytorch_factories = {'swin_tiny': make_swin_tiny, 'deit_small_distilled': make_deit_small_distilled, 'coatnet_0': make_coatnet0}


In [None]:
### >>> NEW: Print Model Summary (Swin / DeiT / any timm model)

import torch
from torchsummary import summary  # pip install torchsummary if needed

def show_model_summary(factory, img_size=(3,224,224)):
    """Print a factory-built model's module tree and, when possible, a torchsummary table.

    Args:
        factory: callable ``factory(num_classes) -> nn.Module`` (e.g. an entry of ``pytorch_factories``).
        img_size: (C, H, W) input size passed to ``torchsummary.summary``.

    The model is built with the notebook's ``num_classes`` head; the classifier is NOT removed.
    Errors from torchsummary are reported and otherwise ignored.
    """
    model = factory(num_classes)
    model.eval()
    model = model.to(DEVICE)

    print("\n================ MODEL ARCHITECTURE ================")
    print(model)

    try:
        # torchsummary defaults to device="cuda"; pass DEVICE so CPU-only runtimes work as well.
        summary(model, img_size, device=DEVICE)
    except Exception as e:
        print("\n(torchsummary could not parse some models — safe to ignore)")
        print("  reason:", repr(e))


# ---- FIX: Use correct timm CoAtNet model name ----
# Replaces pytorch_factories["coatnet_0"] with the pretrained "coatnet_0_rw_224.sw_in1k" variant.
# The lambda accepts only num_classes and always requests pretrained weights.
pytorch_factories["coatnet_0"] = lambda num_classes: timm.create_model(
    "coatnet_0_rw_224.sw_in1k", pretrained=True, num_classes=num_classes
)


# Print each architecture. Every call instantiates a full model (pretrained=True for all three).
show_model_summary(pytorch_factories["swin_tiny"])
show_model_summary(pytorch_factories["deit_small_distilled"])
show_model_summary(pytorch_factories["coatnet_0"])


In [None]:
# PyTorch training / evaluation for the timm models (partial freezing, AdamW, best-checkpoint test evaluation)
import os, math, traceback
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch, torch.nn as nn, torch.optim as optim

# assume pytorch_factories, train_loader_pt, val_loader_pt, test_loader_pt, num_classes, OUT_ROOT, DEVICE, EPOCHS_PY, LR_PY, WEIGHT_DECAY exist

def freeze_model_fraction(model, fraction=0.65):
    """Freeze the first ``fraction`` of ``model``'s parameter tensors in registration order.

    Tensors are counted, not scalar elements. The remaining tensors are marked trainable,
    and at least the final parameter tensor always stays trainable (when the model has any).
    """
    params = list(model.parameters())
    n_frozen = min(int(len(params) * fraction), max(0, len(params) - 1))
    for position, param in enumerate(params):
        param.requires_grad = position >= n_frozen

def unfreeze_last_n_params(model, n=4):
    """Mark the last ``n`` parameter tensors of ``model`` as trainable.

    Earlier tensors keep their current ``requires_grad``. Non-positive ``n`` is a no-op.
    """
    params = list(model.parameters())
    for param in params[max(0, len(params) - n):]:
        param.requires_grad = True

def _pt_run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over ``loader``: trains when ``optimizer`` is given, otherwise evaluates.

    Returns ``(mean_loss, accuracy)`` over all samples; both are 0.0 for an empty loader.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for imgs, labels, _paths in loader:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            if training:
                optimizer.zero_grad()
            logits = model(imgs)
            loss = criterion(logits, labels)
            if training:
                loss.backward()
                optimizer.step()
            loss_sum += float(loss.item()) * imgs.size(0)
            correct += (logits.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)
    if seen == 0:
        return 0.0, 0.0
    return loss_sum / seen, correct / seen


def _pt_predict(model, loader):
    """Return ``(y_true, y_pred)`` as Python int lists for every sample in ``loader``."""
    y_true, y_pred = [], []
    model.eval()
    with torch.no_grad():
        for imgs, labels, _paths in loader:
            logits = model(imgs.to(DEVICE))
            y_pred.extend(logits.argmax(dim=1).cpu().tolist())
            y_true.extend(labels.tolist())
    return y_true, y_pred


def _pt_save_history_plot(history, name, out_dir):
    """Best-effort: save train/val loss and accuracy curves to ``{out_dir}/{name}_history.png``."""
    fig = plt.figure(figsize=(10, 4))
    try:
        plt.subplot(1, 2, 1)
        plt.plot(history['train_loss'], label='train_loss'); plt.plot(history['val_loss'], label='val_loss'); plt.legend()
        plt.subplot(1, 2, 2)
        plt.plot(history['train_acc'], label='train_acc'); plt.plot(history['val_acc'], label='val_acc'); plt.legend()
        plt.suptitle(name)
        plt.tight_layout()
        fig.savefig(os.path.join(out_dir, f'{name}_history.png'))
    except Exception as e:
        print(f"Warning: could not save history plot for {name}: {e}")
    finally:
        plt.close(fig)


def train_and_evaluate_pytorch(factory, name, epochs=EPOCHS_PY, lr=LR_PY, out_dir=OUT_ROOT):
    """Train a timm model on the notebook's PyTorch loaders and evaluate it on the test split.

    The first 65% of parameter tensors are frozen (falling back to unfreezing the last 8 if
    nothing is trainable). The rest are trained with AdamW for ``epochs`` epochs. The last and
    best (by val accuracy) state_dicts are saved as ``{out_dir}/{name}_last.pth`` /
    ``{name}_best.pth``. Test evaluation uses the best checkpoint *from this call*, falling back
    to the last one, then to the in-memory model. Training errors are logged, not raised.

    Returns:
        (history, report, cm): per-epoch metrics dict, sklearn ``classification_report`` dict,
        and a ``num_classes x num_classes`` confusion matrix (rows = true, cols = predicted).
    """
    model = factory(num_classes).to(DEVICE)
    freeze_model_fraction(model, fraction=0.65)
    trainable = [p for p in model.parameters() if p.requires_grad]
    if not trainable:
        unfreeze_last_n_params(model, n=8)
        trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(trainable, lr=lr, weight_decay=WEIGHT_DECAY)
    criterion = nn.CrossEntropyLoss()

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    last_ckpt = os.path.join(out_dir, f'{name}_last.pth')
    best_ckpt = os.path.join(out_dir, f'{name}_best.pth')
    best_val = -1.0
    # Track checkpoints written by *this* call. Files left in out_dir by an earlier run with the
    # same name belong to a different training run and must never be evaluated as this one's.
    saved_best = saved_last = False

    try:
        for ep in range(epochs):
            train_loss, train_acc = _pt_run_epoch(model, train_loader_pt, criterion, optimizer)
            val_loss, val_acc = _pt_run_epoch(model, val_loader_pt, criterion)

            history['train_loss'].append(train_loss); history['train_acc'].append(train_acc)
            history['val_loss'].append(val_loss); history['val_acc'].append(val_acc)
            print(f'{name} Ep{ep+1}/{epochs} train_acc={train_acc:.4f} val_acc={val_acc:.4f}')

            torch.save(model.state_dict(), last_ckpt)
            saved_last = True
            if val_acc > best_val:
                best_val = val_acc
                torch.save(model.state_dict(), best_ckpt)
                saved_best = True
    except Exception as e:
        print(f"Error during training {name}: {e}")
        traceback.print_exc()

    ckpt_to_load = best_ckpt if saved_best else (last_ckpt if saved_last else None)
    if ckpt_to_load is not None:
        state = torch.load(ckpt_to_load, map_location=DEVICE)
        try:
            model.load_state_dict(state)
        except Exception as e:
            print(f"Warning: strict load of {ckpt_to_load} failed ({e}); retrying with strict=False.")
            model.load_state_dict(state, strict=False)
    else:
        print(f"Warning: No checkpoint saved for {name}; using current in-memory model for testing.")

    y_true, y_pred = _pt_predict(model, test_loader_pt)
    rep = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
    # Explicit labels keep the matrix num_classes x num_classes even if a class is absent from the test split.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))
    pd.DataFrame(rep).transpose().to_csv(os.path.join(out_dir, f'{name}_classification_report.csv'))
    np.save(os.path.join(out_dir, f'{name}_cm.npy'), cm)
    _pt_save_history_plot(history, name, out_dir)

    return history, rep, cm

# Results per model: {'history', 'report', 'cm'} on success, or {'error': message} on failure.
pytorch_results = {}

# Train every factory in turn; a failure is logged and recorded, and the remaining models still run.
for name, fac in pytorch_factories.items():
    try:
        print("Training PyTorch model:", name)
        hist, rep, cm = train_and_evaluate_pytorch(fac, name, epochs=EPOCHS_PY, lr=LR_PY, out_dir=OUT_ROOT)
    except Exception as e:
        print("Error training", name, "->", e)
        traceback.print_exc()
        pytorch_results[name] = {'error': str(e)}
    else:
        pytorch_results[name] = {'history': hist, 'report': rep, 'cm': cm}

print("PyTorch training done. Artifacts in", OUT_ROOT)
print("Summary of PyTorch results keys:", list(pytorch_results))


In [None]:
# ============================================================
# FULL: INTEGRATED FEW-SHOT + HYBRID FUSION PIPELINE (READY-TO-PASTE)
# - Robust _make_embedding_extractor and _embed_batch (probe backbone outputs)
# - Handles varied backbone output shapes (spatial / token / pooled)
# - Prints inferred feature dims for debugging
# - Marks new/updated sections with comments
# ============================================================

### Imports
import os, random, math, time, warnings
import numpy as np, pandas as pd
import matplotlib.pyplot as plt, seaborn as sns
from collections import defaultdict, Counter
from tqdm import tqdm
from PIL import Image

# --- Configurable params (change if needed) ---
FEWSHOT_EPISODES = 200        # episodes per setting; increase to 1000+ for final results
FUSION_EMBED_DIM = 256        # per-backbone embedding size (output width of each projection head)
FUSION_BACKBONES = ['swin_tiny', 'deit_small_distilled', 'coatnet_0']  # must be keys in pytorch_factories
# Output locations fall back to defaults when OUT_ROOT / PLOTS_DIR were not set earlier.
SAVE_OUTDIR = globals().get('OUT_ROOT', '/content/experiments')
PLOTS_DIR = globals().get('PLOTS_DIR', os.path.join(SAVE_OUTDIR, 'plots'))
os.makedirs(PLOTS_DIR, exist_ok=True)

# Fail fast when the data/model cells earlier in the notebook have not been run in this kernel.
required_globals = ['pytorch_factories', 'train_items', 'val_items', 'test_items', 'num_classes', 'DEVICE']
missing = [name for name in required_globals if name not in globals()]
if len(missing) > 0:
    raise RuntimeError(f"Missing required globals in notebook environment: {missing}. Run earlier cells first.")

# --------------------------
# Robust fallback: _make_embedding_extractor (probes backbone to infer features)
# --------------------------
if 'make_embedding_extractor' in globals():
    _make_embedding_extractor = make_embedding_extractor
else:
    import torch, torch.nn as nn, torch.nn.functional as F
    from torchvision import transforms

    class _PooledFeatureBackbone(nn.Module):
        """Expose a timm-style model's pooled pre-logits features through ``forward_features``.

        ``forward_head(forward_features(x), pre_logits=True)`` applies the model's OWN
        pooling (whatever memory layout its feature map uses) and skips the classifier,
        so callers of ``forward_features`` always receive a 2-D ``(batch, feat_dim)`` tensor.
        This avoids guessing which axis of a 4-D map holds the channels.
        """

        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward_features(self, x):
            return self.model.forward_head(self.model.forward_features(x), pre_logits=True)

        def forward(self, x):
            return self.forward_features(x)

    def _probe_pooled_features(model, dummy):
        """Return pooled (B, D) features via ``forward_head(pre_logits=True)``, or None if unsupported."""
        if not (hasattr(model, 'forward_features') and hasattr(model, 'forward_head')):
            return None
        try:
            pooled = model.forward_head(model.forward_features(dummy), pre_logits=True)
        except Exception:
            return None
        return pooled if torch.is_tensor(pooled) and pooled.ndim == 2 else None

    def _make_embedding_extractor(factory, embed_dim=FUSION_EMBED_DIM, img_size=(224,224)):
        """
        Create ``(backbone, proj)``: a frozen backbone plus a trainable projection head whose
        input size matches the backbone's actual feature output.

        The backbone is probed with a zero image of ``img_size`` (H, W). When the model offers
        ``forward_head(..., pre_logits=True)``, it is wrapped so ``forward_features`` returns the
        model's own pooled vector and ``proj`` is a single Linear. Otherwise the raw
        ``forward_features`` output shape decides the head:
          * 4-D -> AdaptiveAvgPool2d + Flatten + Linear (assumes channels at dim 1)
          * 3-D -> Flatten + Linear over the last dim (callers mean-pool tokens first)
          * 2-D -> Linear

        Args:
            factory: callable ``factory(num_classes) -> nn.Module``.
            embed_dim: output width of the projection head.
            img_size: (H, W) of the dummy probe input.

        Returns:
            (backbone, proj), both on DEVICE. Backbone params are frozen and proj params are
            trainable. ``backbone._inferred_feat_dim`` records the detected feature width.
        """
        model = factory(num_classes)
        if hasattr(model, "reset_classifier"):
            try:
                model.reset_classifier(0)
            except Exception as exc:
                warnings.warn(f"reset_classifier(0) failed ({exc!r}); continuing with the original head.", UserWarning)

        model = model.to(DEVICE)
        model.eval()

        with torch.no_grad():
            dummy = torch.zeros(1, 3, img_size[0], img_size[1], device=DEVICE)
            feats = _probe_pooled_features(model, dummy)
            if feats is not None:
                # Use the model's own pooling; forward_features now yields (B, D).
                model = _PooledFeatureBackbone(model).eval()
            else:
                try:
                    feats = model.forward_features(dummy) if hasattr(model, 'forward_features') else model(dummy)
                except Exception:
                    # last resort: try normal forward and accept result
                    feats = model(dummy)

        # Deduce feature vector dim and create projection head accordingly
        if torch.is_tensor(feats) and feats.ndim == 4:
            feat_dim = feats.shape[1]  # assumes (B, C, H, W)
            proj = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(feat_dim, embed_dim))
        elif torch.is_tensor(feats) and feats.ndim == 3:
            feat_dim = feats.shape[2]
            proj = nn.Sequential(nn.Flatten(), nn.Linear(feat_dim, embed_dim))
        elif torch.is_tensor(feats) and feats.ndim == 2:
            feat_dim = feats.shape[1]
            proj = nn.Sequential(nn.Linear(feat_dim, embed_dim))
        else:
            shape = tuple(feats.shape) if torch.is_tensor(feats) else type(feats).__name__
            warnings.warn(f"Could not infer feature dim from backbone output {shape}; "
                          f"projection head assumes {embed_dim} channels.", UserWarning)
            feat_dim = embed_dim
            proj = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(feat_dim, embed_dim))

        proj = proj.to(DEVICE)
        for p in model.parameters(): p.requires_grad = False
        for p in proj.parameters(): p.requires_grad = True

        # attach attribute for later inspection
        try:
            model._inferred_feat_dim = int(feat_dim)
        except Exception:
            model._inferred_feat_dim = None
        return model, proj

# --------------------------
# Robust fallback: _embed_batch (handles 4D/3D/2D features)
# --------------------------
if 'embed_batch' in globals():
    _embed_batch = embed_batch
else:
    import torch, torch.nn.functional as F
    from torchvision import transforms

    def _embed_batch(backbone, proj, paths, img_size=(224,224), batch_size=16):
        """Embed image files with a frozen backbone and an optional projection head.

        Args:
            backbone: module; ``forward_features`` is used when present, otherwise ``forward``.
            proj: projection head applied to the backbone features, or None to use the
                (pooled) backbone features directly.
            paths: sequence of image file paths.
            img_size: (H, W) resize target.
            batch_size: images per forward pass.

        Returns:
            numpy array of shape (len(paths), D) with L2-normalised rows, or an empty
            (0, 0) array when ``paths`` is empty.

        Feature handling by rank:
            * 4-D maps: average-pooled first when ``proj`` starts with Linear (or is None),
              otherwise passed to ``proj`` as-is (the head is expected to pool).
            * 3-D token sequences: mean over dim 1, then ``proj``. If ``proj`` cannot take the
              pooled vector, the pooled vector is used unprojected. A fixed, deterministic
              fallback keeps support and query embeddings comparable.
            * anything else: flattened per sample, then ``proj``.
        """
        backbone.eval()
        if proj is not None:
            proj.eval()

        tfm = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
        ])
        # A leading Linear means the head expects an already-pooled vector.
        first_mod = proj[0] if isinstance(proj, torch.nn.Sequential) and len(proj) > 0 else None
        fallback_warned = False
        embeds = []
        with torch.no_grad():
            for start in range(0, len(paths), batch_size):
                batch = []
                for p in paths[start:start + batch_size]:
                    # context manager closes the underlying file handle
                    with Image.open(p) as im:
                        batch.append(tfm(im.convert('RGB')))
                x = torch.stack(batch, dim=0).to(DEVICE)

                feats = backbone.forward_features(x) if hasattr(backbone, 'forward_features') else backbone(x)

                if feats.ndim == 4:
                    if proj is None or isinstance(first_mod, torch.nn.Linear):
                        vec = F.adaptive_avg_pool2d(feats, 1).flatten(1)
                        emb = vec if proj is None else proj(vec)
                    else:
                        emb = proj(feats)
                elif feats.ndim == 3:
                    vec = feats.mean(dim=1)
                    if proj is None:
                        emb = vec
                    else:
                        try:
                            emb = proj(vec)
                        except Exception as exc:
                            if not fallback_warned:
                                warnings.warn(f"Projection head incompatible with pooled tokens {tuple(vec.shape)} "
                                              f"({exc!r}); using unprojected features.", UserWarning)
                                fallback_warned = True
                            emb = vec
                else:
                    vec = feats if feats.ndim == 2 else feats.reshape(feats.shape[0], -1)
                    emb = vec if proj is None else proj(vec)

                emb = F.normalize(emb, dim=1)
                embeds.append(emb.cpu().numpy())

        if not embeds:
            return np.empty((0, 0), dtype=np.float32)
        return np.vstack(embeds)

# --------------------------
# compute_prototypes fallback
# --------------------------
if 'compute_prototypes' in globals():
    _compute_prototypes = compute_prototypes
else:
    def _compute_prototypes(support_emb, support_labels, N_way):
        D = support_emb.shape[1]
        prot = np.zeros((N_way, D), dtype=np.float32)
        for c in range(N_way):
            idxs = [i for i,l in enumerate(support_labels) if l==c]
            if len(idxs) == 0:
                continue
            prot[c] = support_emb[idxs].mean(axis=0)
        prot /= (np.linalg.norm(prot, axis=1, keepdims=True) + 1e-9)
        return prot

# --------------------------
# FewShotDataset (robust): reduces N_way when insufficient classes; samples with replacement for small classes
# --------------------------
if 'FewShotDataset' in globals():
    _FewShotDataset = FewShotDataset
else:
    class _FewShotDataset:
        def __init__(self, items, img_size=(224,224)):
            self.items = items
            self.by_class = defaultdict(list)
            for p,l in items:
                self.by_class[l].append(p)
            self.classes = sorted(self.by_class.keys())

        def sample_episode(self, N_way=5, K_shot=5, Q_query=5):
            avail_classes = len(self.classes)
            if avail_classes == 0:
                raise ValueError("FewShotDataset has no classes (items list empty).")
            if avail_classes < N_way:
                warnings.warn(f"Requested N_way={N_way} but only {avail_classes} classes available. Reducing N_way -> {avail_classes}.", UserWarning)
                N_way = avail_classes
            chosen = random.sample(self.classes, N_way)
            support=[]; s_lbl=[]; query=[]; q_lbl=[]
            for i,cls in enumerate(chosen):
                imgs = self.by_class[cls]
                required = K_shot + Q_query
                if len(imgs) >= required:
                    sel = random.sample(imgs, required)
                else:
                    warnings.warn(f"Class {cls} has {len(imgs)} images but required {required}. Sampling with replacement.", UserWarning)
                    sel = [random.choice(imgs) for _ in range(required)]
                support += sel[:K_shot]; query += sel[K_shot:]
                s_lbl += [i]*K_shot; q_lbl += [i]*Q_query
            return support, s_lbl, query, q_lbl

# Episode sampler over the validation split: reuse a notebook-level `fs_val` if one exists.
_fs_val = fs_val if 'fs_val' in globals() else _FewShotDataset(val_items, img_size=globals().get('IMG_SIZE', (224, 224)))

# Quick sanity stats: total images, number of classes and the smallest classes.
try:
    per_class_counts = {cls: len(imgs) for cls, imgs in _fs_val.by_class.items()}
    print("Validation set: total images =", len(val_items), "; classes =", len(per_class_counts))
    by_size = sorted(per_class_counts.items(), key=lambda kv: kv[1])
    small = by_size[:6]
    print("Few smallest class counts (class:count):", small)
except Exception as e:
    print("Could not print fs_val stats:", e)

# --------------------------
# 1) Base few-shot evaluation (prototypical)
# --------------------------
if 'accs_5shot' not in globals() or 'accs_1shot' not in globals():
    print("Running base few-shot evaluation (prototypical) ...")

    # The base extractor is built ONCE and shared by every episode. Rebuilding it per
    # episode re-instantiated the backbone 2*FEWSHOT_EPISODES times and drew a new random
    # projection head each time, so episodes were not measured with the same embedding.
    _BASE_EXTRACTOR_CACHE = {}

    def _get_base_extractor():
        """Build (on first call) and return the (backbone, proj) pair used for base episodes."""
        if 'extractor' not in _BASE_EXTRACTOR_CACHE:
            name = FUSION_BACKBONES[0] if FUSION_BACKBONES[0] in pytorch_factories else next(iter(pytorch_factories))
            bb, ph = _make_embedding_extractor(pytorch_factories[name], embed_dim=FUSION_EMBED_DIM)
            print(f"Base backbone '{name}' inferred feat dim:", getattr(bb, '_inferred_feat_dim', None))
            _BASE_EXTRACTOR_CACHE['extractor'] = (bb, ph)
        return _BASE_EXTRACTOR_CACHE['extractor']

    def _evaluate_episode_base(fs_dataset, N_way=5, K_shot=5, Q_query=10, extractor=None):
        """Accuracy of one prototypical episode using a single backbone.

        ``extractor`` is an optional (backbone, proj) pair; the cached base extractor is
        used when it is None.
        """
        support, s_lbl, query, q_lbl = fs_dataset.sample_episode(N_way, K_shot, Q_query)
        bb, ph = extractor if extractor is not None else _get_base_extractor()
        s_emb = _embed_batch(bb, ph, support)
        q_emb = _embed_batch(bb, ph, query)
        # sample_episode may reduce N_way; size prototypes by the labels actually used
        # so no all-zero prototype competes in the argmax.
        n_way_eff = max(s_lbl) + 1 if s_lbl else N_way
        proto = _compute_prototypes(s_emb, s_lbl, n_way_eff)
        preds = np.argmax(q_emb @ proto.T, axis=1)
        return (preds == np.array(q_lbl)).mean()

    base_extractor = _get_base_extractor()
    accs_5shot = []
    for _ in tqdm(range(FEWSHOT_EPISODES), desc='base 5-shot'):
        accs_5shot.append(_evaluate_episode_base(_fs_val, N_way=5, K_shot=5, Q_query=10, extractor=base_extractor))
    accs_1shot = []
    for _ in tqdm(range(FEWSHOT_EPISODES), desc='base 1-shot'):
        accs_1shot.append(_evaluate_episode_base(_fs_val, N_way=5, K_shot=1, Q_query=10, extractor=base_extractor))
    print(f"Base 5-shot mean: {np.mean(accs_5shot)*100:.2f}%  |  Base 1-shot mean: {np.mean(accs_1shot)*100:.2f}%")
else:
    print("Base few-shot results found; skipping base evaluation.")

# --------------------------
# 2) Build hybrid fusion extractors (HFPN)
# --------------------------
# One frozen (backbone, projection head) pair per fusion backbone, keyed by backbone name.
fusion_extractors = {}
for bk in FUSION_BACKBONES:
    if bk not in pytorch_factories:
        raise RuntimeError(f"Requested fusion backbone '{bk}' not found in pytorch_factories.")
    fac = pytorch_factories[bk]
    bb, ph = _make_embedding_extractor(fac, embed_dim=FUSION_EMBED_DIM)
    bb.eval()
    ph.eval()
    fusion_extractors[bk] = (bb, ph)
    print(f"Backbone '{bk}' inferred feat dim:", getattr(bb, '_inferred_feat_dim', None))

print("Built hybrid-fusion extractors for:", list(fusion_extractors.keys()))

def _embed_fusion(paths):
    """Concatenate per-backbone embeddings (in FUSION_BACKBONES order) and L2-normalise each row."""
    per_backbone = [_embed_batch(*fusion_extractors[name], paths) for name in FUSION_BACKBONES]
    fused = np.concatenate(per_backbone, axis=1)
    row_norms = np.linalg.norm(fused, axis=1, keepdims=True)
    return fused / (row_norms + 1e-9)

# --------------------------
# 3) Optional proj-head finetune (short)
# --------------------------
DO_FINETUNE_PROJ = False
if DO_FINETUNE_PROJ:
    print("Fine-tuning projection heads (small supervised head training on train_items).")
    import torch, torch.nn as nn, torch.optim as optim
    import torch.nn.functional as F
    from torchvision import transforms

    # NOTE: deliberately NOT named `tf`. That name is the TensorFlow module imported at
    # the top of the notebook, and later Keras cells would break if it were shadowed.
    _ft_tfm = transforms.Compose([transforms.Resize(globals().get('IMG_SIZE',(224,224))), transforms.ToTensor(),
                                  transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
    # Random subset drawn with a private seeded RNG (the global RNG stream is untouched),
    # instead of the first N items, which may cover only a few classes.
    _ft_rng = random.Random(globals().get('SEED', 0))
    small_train = _ft_rng.sample(list(train_items), min(500, len(train_items)))
    X_paths = [p for p,_ in small_train]; Y = [lbl for _,lbl in small_train]

    class _SmallDS(torch.utils.data.Dataset):
        """(normalised image tensor, int label) pairs for the head fine-tune."""
        def __init__(self, paths, labels): self.paths=paths; self.labels=labels
        def __len__(self): return len(self.paths)
        def __getitem__(self, idx):
            with Image.open(self.paths[idx]) as im:
                return _ft_tfm(im.convert('RGB')), int(self.labels[idx])

    def _features_for_head(feats, head):
        """Shape backbone output for ``head``: mean-pool 3-D tokens; avg-pool 4-D maps when head starts with Linear."""
        if feats.ndim == 3:
            return feats.mean(dim=1)
        if feats.ndim == 4 and isinstance(head, nn.Sequential) and len(head) > 0 and isinstance(head[0], nn.Linear):
            return F.adaptive_avg_pool2d(feats, 1).flatten(1)
        return feats

    def _head_out_dim(head, default=FUSION_EMBED_DIM):
        """Output width of the last Linear layer in ``head`` (``default`` if it has none)."""
        linears = [m for m in head.modules() if isinstance(m, nn.Linear)]
        return linears[-1].out_features if linears else default

    ds = _SmallDS(X_paths, Y)
    dl = torch.utils.data.DataLoader(ds, batch_size=16, shuffle=True, num_workers=2)
    ce = nn.CrossEntropyLoss()
    for bk,(bb,ph) in fusion_extractors.items():
        ph.train()
        # One temporary classifier per backbone, created ONCE and optimised together with
        # the head. A fresh nn.Linear per batch, outside the optimiser, would train the
        # head against a different random classifier at every step.
        clf = nn.Linear(_head_out_dim(ph), num_classes).to(DEVICE)
        params = [p for p in ph.parameters() if p.requires_grad] + list(clf.parameters())
        opt = optim.Adam(params, lr=1e-3, weight_decay=1e-5)
        for epoch in range(2):
            for imgs, lbls in dl:
                imgs = imgs.to(DEVICE); lbls = lbls.to(DEVICE)
                with torch.no_grad():
                    feats = bb.forward_features(imgs) if hasattr(bb,'forward_features') else bb(imgs)
                logits = clf(ph(_features_for_head(feats, ph)))
                loss = ce(logits, lbls)
                opt.zero_grad(); loss.backward(); opt.step()
        ph.eval()
    print("Proj-head fine-tune done.")

# --------------------------
# 4) Hybrid Fusion few-shot evaluation
# --------------------------
if 'accs_fusion_5shot' not in globals() or 'accs_fusion_1shot' not in globals():
    print("Running Hybrid Fusion few-shot evaluation (HFPN) ...")

    def _run_fusion_episodes(K_shot, N_way=5, Q_query=10, n_episodes=FEWSHOT_EPISODES):
        """Per-episode accuracies of prototypical classification on fused embeddings.

        Episodes are drawn from ``_fs_val``. Prototypes are sized by the labels actually
        present, because ``sample_episode`` may reduce ``N_way`` when classes are scarce.
        """
        accs = []
        for _ in tqdm(range(n_episodes), desc=f'fusion {K_shot}-shot'):
            support, s_lbl, query, q_lbl = _fs_val.sample_episode(N_way, K_shot, Q_query)
            s_emb = _embed_fusion(support)
            q_emb = _embed_fusion(query)
            n_way_eff = max(s_lbl) + 1 if s_lbl else N_way
            prot = _compute_prototypes(s_emb, s_lbl, n_way_eff)
            preds = np.argmax(q_emb @ prot.T, axis=1)
            accs.append((preds == np.array(q_lbl)).mean())
        return accs

    accs_fusion_5shot = _run_fusion_episodes(K_shot=5)
    accs_fusion_1shot = _run_fusion_episodes(K_shot=1)

    print(f"Fusion 5-shot mean: {np.mean(accs_fusion_5shot)*100:.2f}%  |  Fusion 1-shot mean: {np.mean(accs_fusion_1shot)*100:.2f}%")
else:
    print("Fusion few-shot results found; skipping fusion evaluation.")

# --------------------------
# 5) Plotting & Export
# --------------------------
# Per-setting arrays of episode accuracies (fusion settings only when they were computed).
accs_grouped = {
    'base_5shot': np.array(accs_5shot),
    'base_1shot': np.array(accs_1shot),
}
if 'accs_fusion_5shot' in globals():
    accs_grouped.update(fusion_5shot=np.array(accs_fusion_5shot),
                        fusion_1shot=np.array(accs_fusion_1shot))

# One summary row per setting; std is the sample std (ddof=1) when >1 episode exists.
rows = [
    {
        'setting': setting,
        'mean_acc': float(np.mean(arr)),
        'std_acc': float(np.std(arr, ddof=1) if len(arr) > 1 else np.std(arr)),
        'median': float(np.median(arr)),
        'n_episodes': int(len(arr)),
    }
    for setting, arr in accs_grouped.items()
]
df_summary = pd.DataFrame(rows)
csv_out = os.path.join(SAVE_OUTDIR, 'fewshot_summary_integrated.csv')
df_summary.to_csv(csv_out, index=False)
print("Saved integrated few-shot summary CSV ->", csv_out)

sns.set(style="whitegrid")

# Settings to plot (display order), their labels and per-plot transparency.
if 'fusion_5shot' in accs_grouped:
    _plot_keys = ['base_5shot', 'fusion_5shot', 'base_1shot', 'fusion_1shot']
    _plot_labels = ['base 5-shot', 'fusion 5-shot', 'base 1-shot', 'fusion 1-shot']
    _hist_alphas = [0.5, 0.5, 0.35, 0.35]
    _line_alphas = [0.8, 0.8, 0.6, 0.6]
else:
    _plot_keys = ['base_5shot', 'base_1shot']
    _plot_labels = ['base 5-shot', 'base 1-shot']
    _hist_alphas = [0.6, 0.6]
    _line_alphas = [0.8, 0.8]

# (a) overlaid density histograms of episode accuracy
plt.figure(figsize=(9,5))
for _key, _label, _alpha in zip(_plot_keys, _plot_labels, _hist_alphas):
    sns.histplot(accs_grouped[_key], label=_label, stat='density', kde=True, alpha=_alpha)
plt.xlabel('Episode Accuracy'); plt.title('Few-Shot Episode Accuracy Distribution (Integrated)')
plt.legend()
hist_p = os.path.join(PLOTS_DIR, 'fewshot_integrated_histogram.png')
plt.tight_layout(); plt.savefig(hist_p); plt.close()
print("Saved histogram ->", hist_p)

# (b) boxplot per setting
plt.figure(figsize=(7,5))
labels = list(_plot_labels)
data = [accs_grouped[_key] for _key in _plot_keys]
sns.boxplot(data=data); plt.xticks(range(len(labels)), labels, rotation=15)
plt.ylabel('Episode Accuracy'); plt.title('Few-Shot Accuracy Boxplot (Integrated)')
box_p = os.path.join(PLOTS_DIR, 'fewshot_integrated_boxplot.png')
plt.tight_layout(); plt.savefig(box_p); plt.close()
print("Saved boxplot ->", box_p)

# (c) mean ± sample std bars (error 0 when a setting has a single episode)
plt.figure(figsize=(6,4))
bars = [accs_grouped[_key].mean()*100 for _key in _plot_keys]
errs = [accs_grouped[_key].std(ddof=1)*100 if len(accs_grouped[_key]) > 1 else 0 for _key in _plot_keys]
bl = list(_plot_labels)
x = np.arange(len(bars)); plt.bar(x, bars, yerr=errs, capsize=5); plt.xticks(x, bl, rotation=12)
plt.ylabel('Mean Accuracy (%)'); plt.title('Few-Shot Mean Accuracy ± Std (Integrated)')
for i, v in enumerate(bars): plt.text(i, v + (errs[i]+0.5), f"{v:.2f}% ± {errs[i]:.2f}%", ha='center', fontsize=9)
bar_p = os.path.join(PLOTS_DIR, 'fewshot_integrated_meanstdbar.png')
plt.tight_layout(); plt.savefig(bar_p); plt.close()
print("Saved mean±std bar ->", bar_p)

# (d) episode accuracies sorted high -> low
plt.figure(figsize=(10,4))
for _key, _label, _alpha in zip(_plot_keys, _plot_labels, _line_alphas):
    plt.plot(np.sort(accs_grouped[_key])[::-1], label=_label, alpha=_alpha)
plt.xlabel('Episode index (sorted)'); plt.ylabel('Accuracy'); plt.legend(); plt.title('Sorted Episode Accuracies (Integrated)')
sorted_p = os.path.join(PLOTS_DIR, 'fewshot_integrated_sortedepisodes.png')
plt.tight_layout(); plt.savefig(sorted_p); plt.close()
print("Saved sorted-episodes plot ->", sorted_p)

print("\nIntegrated few-shot summary (printed):")
print(df_summary.to_string(index=False))

# Append the few-shot means to the combined model table, if it already exists.
# combined_summary.csv is written by the combined-comparison cell further down, so on a
# fresh top-to-bottom run this step is skipped.
combined_csv = os.path.join(SAVE_OUTDIR, 'combined_summary.csv')
if os.path.exists(combined_csv):
    try:
        comb = pd.read_csv(combined_csv)
        _fewshot_models = [('fewshot::base_proto', 'base_5shot')]
        if 'fusion_5shot' in accs_grouped:
            _fewshot_models.append(('fewshot::hybrid_fusion', 'fusion_5shot'))
        to_add = [
            {
                'model': model_name,
                'accuracy': float(df_summary.loc[df_summary['setting'] == setting, 'mean_acc'].values[0]),
                'macro_precision': np.nan,
                'macro_recall': np.nan,
                'macro_f1': np.nan,
            }
            for model_name, setting in _fewshot_models
        ]
        comb = pd.concat([comb, pd.DataFrame(to_add)], ignore_index=True)
        comb.to_csv(os.path.join(SAVE_OUTDIR,'combined_summary_with_fewshot_integrated.csv'), index=False)
        print("Appended few-shot rows to combined_summary -> combined_summary_with_fewshot_integrated.csv")
    except Exception as e:
        print("Could not append to combined_summary:", e)

print("\nAll integrated few-shot artifacts saved in:", PLOTS_DIR)
# ============================================================
# END CELL
# ============================================================


In [None]:
# -------------------------
# Combined comparison: run both Keras and PyTorch models on the same test set and save combined metrics & plots
# -------------------------
# Split the (path, label) pairs of test_items into parallel path list / label array.
test_paths = [path for path, _label in test_items]
y_test = np.array([label for _path, label in test_items])
print("Test samples:", len(test_paths))

In [None]:
# Keras predictions (use trained_classifiers dict)
def keras_predict_on_paths(keras_model, paths, batch=32, target_size=IMG_SIZE):
    """Run a Keras classifier over image files.

    Each image is converted to RGB and resized to ``target_size``, given as (H, W) like
    IMG_SIZE. Pixels are scaled to [0, 1], then the images are fed in batches of ``batch``.
    Assumes the Keras models were trained on [0, 1]-scaled RGB inputs (TODO confirm
    against the Keras training pipeline).

    Returns:
        (probs, preds): ``keras_model.predict`` outputs stacked over all paths, and their
        per-row argmax.
    """
    height, width = target_size
    probs_list = []
    for start in range(0, len(paths), batch):
        imgs = []
        for p in paths[start:start + batch]:
            with Image.open(p) as im:
                # convert('RGB') guarantees 3 channels (grayscale/RGBA/palette files exist
                # in image datasets); PIL's resize takes (width, height).
                rgb = im.convert('RGB').resize((width, height))
            imgs.append(np.asarray(rgb, dtype='float32') / 255.0)
        probs_list.append(keras_model.predict(np.stack(imgs, axis=0), verbose=0))
    probs = np.concatenate(probs_list, axis=0)
    return probs, np.argmax(probs, axis=1)

# Test-set predictions for every trained Keras classifier, keyed by model name.
keras_preds = {}
for name, model in trained_classifiers.items():
    print("Predict (Keras):", name)
    probs, preds = keras_predict_on_paths(model, test_paths, batch=32, target_size=IMG_SIZE)
    keras_preds[name] = dict(probs=probs, preds=preds)


In [None]:
# PyTorch predictions (load best ckpts)
val_tf_pt = transforms.Compose([transforms.Resize(IMG_SIZE), transforms.ToTensor(), transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
def pytorch_predict_ckpt(factory, ckpt, paths, batch=32):
    """Load a checkpoint into ``factory(num_classes)`` and predict on image files.

    The checkpoint may be a plain state_dict or a dict with a ``'state_dict'`` entry. As a
    last resort it is loaded with ``strict=False``, and any missing/unexpected keys are
    reported with a warning instead of being ignored silently.

    Returns:
        (probs, preds): softmax probabilities of shape (len(paths), C) and their argmax.
        Empty arrays are returned for an empty ``paths``.
    """
    import warnings

    model = factory(num_classes).to(DEVICE)
    st = torch.load(ckpt, map_location=DEVICE)
    try:
        model.load_state_dict(st)
    except Exception:
        if isinstance(st, dict) and 'state_dict' in st:
            model.load_state_dict(st['state_dict'])
        else:
            result = model.load_state_dict(st, strict=False)
            if result.missing_keys or result.unexpected_keys:
                warnings.warn(f"Partial checkpoint load from {ckpt}: {len(result.missing_keys)} missing, "
                              f"{len(result.unexpected_keys)} unexpected keys; predictions may be unreliable.",
                              UserWarning)
    model.eval()
    all_probs=[]; all_preds=[]
    with torch.no_grad():
        for start in range(0, len(paths), batch):
            imgs = []
            for p in paths[start:start + batch]:
                with Image.open(p) as im:
                    imgs.append(val_tf_pt(im.convert('RGB')))
            xb = torch.stack(imgs).to(DEVICE)
            probs = torch.softmax(model(xb), dim=1).cpu().numpy()
            all_probs.append(probs); all_preds.append(probs.argmax(axis=1))
    if not all_probs:
        return np.zeros((0, num_classes), dtype=np.float32), np.zeros((0,), dtype=np.int64)
    return np.vstack(all_probs), np.concatenate(all_preds)

# Test-set predictions for every PyTorch model that has a saved best checkpoint.
pytorch_preds = {}
for name, fac in pytorch_factories.items():
    ckpt_path = os.path.join(OUT_ROOT, f'{name}_best.pth')
    if os.path.exists(ckpt_path):
        print("Predict (PyTorch):", name)
        probs, preds = pytorch_predict_ckpt(fac, ckpt_path, test_paths, batch=32)
        pytorch_preds[name] = {'probs': probs, 'preds': preds}
    else:
        print("PyTorch checkpoint missing for", name, "-> skipping predict")


In [None]:
# Combine all preds: framework-prefixed keys keep Keras and PyTorch names distinct.
all_preds = {
    **{f"keras::{k}": v for k, v in keras_preds.items()},
    **{f"torch::{k}": v for k, v in pytorch_preds.items()},
}
if not all_preds:
    raise RuntimeError("No models produced predictions. Check earlier steps.")


In [None]:
# Compute metrics and save
rows=[]
for key, v in all_preds.items():
    y_pred = v['preds']
    rows.append({
        'model': key,
        'accuracy': accuracy_score(y_test, y_pred),
        'macro_precision': precision_score(y_test, y_pred, average='macro', zero_division=0),
        'macro_recall': recall_score(y_test, y_pred, average='macro', zero_division=0),
        'macro_f1': f1_score(y_test, y_pred, average='macro', zero_division=0),
    })
summary_df = pd.DataFrame(rows).sort_values('macro_f1', ascending=False).reset_index(drop=True)
summary_df.to_csv(os.path.join(OUT_ROOT,'combined_summary.csv'), index=False)
print("Saved combined_summary.csv")

# Save per-model confusion and classification reports & plots.
# Labels are pinned to every class index so each matrix is n_classes x n_classes and lines
# up with the tick labels, even when a class never occurs in y_test or in the predictions.
# (Assumes labels are 0..n_classes-1 indices into full_dataset.classes.)
class_names = list(full_dataset.classes)
class_ids = np.arange(len(class_names))
cm_w = max(8, 0.3 * len(class_names))  # grow the figure with the number of classes
for key, v in all_preds.items():
    preds = v['preds']
    file_key = key.replace('::', '_')
    cm = confusion_matrix(y_test, preds, labels=class_ids)
    rep = classification_report(y_test, preds, labels=class_ids, output_dict=True, zero_division=0)
    pd.DataFrame(rep).transpose().to_csv(os.path.join(OUT_ROOT, f"{file_key}_classification_report.csv"))
    np.save(os.path.join(OUT_ROOT, f"{file_key}_cm.npy"), cm)
    # plot
    plt.figure(figsize=(cm_w, cm_w * 0.75))
    sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names, cmap='Blues')
    plt.title(f'Confusion: {key}'); plt.tight_layout()
    plt.savefig(os.path.join(PLOTS_DIR, f"{file_key}_confusion.png")); plt.close()

# Plot overall ranking bars
plt.figure(figsize=(max(6, len(summary_df)*0.6),4))
sns.barplot(data=summary_df, x='model', y='accuracy'); plt.xticks(rotation=45, ha='right'); plt.tight_layout(); plt.savefig(os.path.join(PLOTS_DIR,'combined_accuracy.png')); plt.close()
plt.figure(figsize=(max(6, len(summary_df)*0.6),4))
sns.barplot(data=summary_df, x='model', y='macro_f1'); plt.xticks(rotation=45, ha='right'); plt.tight_layout(); plt.savefig(os.path.join(PLOTS_DIR,'combined_macro_f1.png')); plt.close()
print("Saved combined plots to", PLOTS_DIR)


In [None]:
# ---------------------------
# Add this cell: Summary report + bar plots for all models
# Paste & run after training + combined evaluation cells
# ---------------------------
import time, math
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import torch

sns.set(style="whitegrid")

# Where the model-comparison table and its plots are written.
OUT_CSV = os.path.join(OUT_ROOT, 'models_comparison_summary.csv')
PLOTS_DIR = PLOTS_DIR if 'PLOTS_DIR' in globals() else os.path.join(OUT_ROOT, 'plots')
os.makedirs(PLOTS_DIR, exist_ok=True)

# ---- Helpers to (re)compute predictions if needed ----
need_recompute = 'all_preds' not in globals() or not isinstance(all_preds, dict) or len(all_preds)==0


def _fallback_keras_predict(km, paths, batch=32):
    """Predict with a Keras model: RGB, resize to IMG_SIZE (H, W), scale to [0, 1].

    Returns (probs, preds) stacked over all paths.
    """
    height, width = IMG_SIZE
    probs_chunks = []
    for start in range(0, len(paths), batch):
        imgs = []
        for p in paths[start:start + batch]:
            with Image.open(p) as im:
                # PIL's resize takes (width, height)
                imgs.append(np.asarray(im.convert('RGB').resize((width, height)), dtype='float32') / 255.0)
        probs_chunks.append(km.predict(np.stack(imgs, 0), verbose=0))
    probs = np.vstack(probs_chunks)
    return probs, np.argmax(probs, axis=1)


def _fallback_torch_predict(fac, ckpt_path, paths, device, tfm, batch=32):
    """Load ``ckpt_path`` into ``fac(num_classes)`` and return (softmax probs, argmax preds)."""
    model = fac(num_classes).to(device)
    st = torch.load(ckpt_path, map_location=device)
    try:
        model.load_state_dict(st)
    except Exception:
        if isinstance(st, dict) and 'state_dict' in st:
            model.load_state_dict(st['state_dict'])
        else:
            model.load_state_dict(st, strict=False)
    model.eval()
    # Local names deliberately differ from the notebook-level `all_preds` dict.
    probs_chunks, pred_chunks = [], []
    with torch.no_grad():
        for start in range(0, len(paths), batch):
            imgs = []
            for p in paths[start:start + batch]:
                with Image.open(p) as im:
                    imgs.append(tfm(im.convert('RGB')))
            probs_b = torch.softmax(model(torch.stack(imgs).to(device)), dim=1).cpu().numpy()
            probs_chunks.append(probs_b); pred_chunks.append(probs_b.argmax(axis=1))
    return np.vstack(probs_chunks), np.concatenate(pred_chunks)


# Use existing all_preds if present
preds_map = {}
if not need_recompute:
    preds_map = dict(all_preds)  # keys like 'keras::Model', 'torch::name'
else:
    # fallback: try to compute using helpers if defined
    test_paths_local = globals().get('test_paths', [p for p,_ in test_items])
    y_test_local = np.array([lbl for _,lbl in test_items])
    print("Recomputing predictions (this may take a short while)...")
    # Keras
    if 'trained_classifiers' in globals() and isinstance(trained_classifiers, dict):
        for name, km in trained_classifiers.items():
            try:
                # reuse keras_predict_on_paths if available
                if 'keras_predict_on_paths' in globals():
                    probs, preds = keras_predict_on_paths(km, test_paths_local, batch=32, target_size=IMG_SIZE)
                else:
                    probs, preds = _fallback_keras_predict(km, test_paths_local, batch=32)
                preds_map[f"keras::{name}"] = {'probs':probs, 'preds':preds}
                print(" - Keras preds:", name)
            except Exception as e:
                print("  Failed Keras predict for", name, e)
    # PyTorch
    if 'pytorch_factories' in globals():
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        val_tf_pt = globals().get('val_tf_pt', None)
        if val_tf_pt is None:
            from torchvision import transforms
            val_tf_pt = transforms.Compose([transforms.Resize(IMG_SIZE), transforms.ToTensor(), transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
        for name, fac in pytorch_factories.items():
            ckpt_path = os.path.join(OUT_ROOT, f'{name}_best.pth')
            if not os.path.exists(ckpt_path):
                print(" - skip torch:", name, " (checkpoint missing )")
                continue
            try:
                # reuse pytorch_predict_ckpt if present
                if 'pytorch_predict_ckpt' in globals():
                    probs, preds = pytorch_predict_ckpt(fac, ckpt_path, test_paths_local, batch=32)
                else:
                    probs, preds = _fallback_torch_predict(fac, ckpt_path, test_paths_local, device, val_tf_pt, batch=32)
                preds_map[f"torch::{name}"] = {'probs':probs, 'preds':preds}
                print(" - Torch preds:", name)
            except Exception as e:
                print("  Failed Torch predict for", name, e)

# ---- Build summary metrics (accuracy, macro_f1) ----
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
# Explicit columns keep the frame well-formed (and sortable) even when no model has predictions.
_SUMMARY_COLUMNS = ['model', 'accuracy', 'macro_precision', 'macro_recall', 'macro_f1']
summary_rows = []
y_true = np.array([lbl for _,lbl in test_items])

for model_key, d in preds_map.items():
    preds = np.array(d['preds'])
    summary_rows.append({
        'model': model_key,
        'accuracy': float(accuracy_score(y_true, preds)),
        'macro_precision': float(precision_score(y_true, preds, average='macro', zero_division=0)),
        'macro_recall': float(recall_score(y_true, preds, average='macro', zero_division=0)),
        'macro_f1': float(f1_score(y_true, preds, average='macro', zero_division=0)),
    })

summary_df = (pd.DataFrame(summary_rows, columns=_SUMMARY_COLUMNS)
              .sort_values('macro_f1', ascending=False)
              .reset_index(drop=True))
if summary_df.empty:
    print("No model predictions available; summary table is empty.")

# ---- Parameter counts ----
param_rows = []
# Keras params
if 'trained_classifiers' in globals():
    for name, km in trained_classifiers.items():
        try:
            params = km.count_params()
        except Exception:
            # fallback: count trainable & non-trainable weights directly
            params = sum(int(np.prod(w.shape)) for w in km.weights)
        param_rows.append({'model': f'keras::{name}', 'params': int(params)})
# PyTorch params (a fresh instance is built only to count its parameters)
if 'pytorch_factories' in globals():
    for name, fac in pytorch_factories.items():
        try:
            m = fac(num_classes)
            pcount = sum(p.numel() for p in m.parameters())
            param_rows.append({'model': f'torch::{name}', 'params': int(pcount)})
            del m   # release the throw-away instance before building the next one
        except Exception as e:
            print("Failed param count for", name, e)

# Explicit columns so the merge below still works when no parameter count was collected.
df_params = pd.DataFrame(param_rows, columns=['model', 'params'])

# merge param info into summary
summary_df = summary_df.merge(df_params, how='left', on='model')
summary_df['params_M'] = (summary_df['params'] / 1e6).round(3)

# ---- Inference timing (approx avg ms per image) ----
def measure_inference_time_keras(km, paths, n_samples=50, batch=8, warmup=1):
    """
    Approximate average wall-clock inference time of a Keras model, in ms per image.

    The timed region covers image loading, preprocessing and km.predict for the first
    n_samples paths, processed in chunks of `batch`.

    km: Keras model exposing .predict
    paths: image file paths
    n_samples: number of leading paths to time
    batch: images per predict() call
    warmup: untimed predict() calls made first, so one-off graph building/tracing is
            not billed to the measurement (0 disables)
    Returns the average ms/image as a float, or NaN when there is nothing to time.
    """
    sel = list(paths[:n_samples])
    if not sel:
        return float('nan')

    def _load_batch(batch_paths):
        imgs = []
        for p in batch_paths:
            with Image.open(p) as im:
                # convert('RGB') so grayscale / RGBA / palette files still give 3 channels;
                # PIL resize takes (W, H) whereas IMG_SIZE is (H, W)
                rgb = im.convert('RGB').resize((IMG_SIZE[1], IMG_SIZE[0]))
                imgs.append(np.asarray(rgb).astype('float32') / 255.0)
        return np.stack(imgs, 0)

    for _ in range(max(0, warmup)):
        km.predict(_load_batch(sel[:batch]), verbose=0)

    t0 = time.perf_counter()
    for i in range(0, len(sel), batch):
        km.predict(_load_batch(sel[i:i + batch]), verbose=0)
    total = time.perf_counter() - t0
    return (total / len(sel)) * 1000.0

def measure_inference_time_torch(factory, ckpt, paths, n_samples=50, batch=8, warmup=1):
    """
    Approximate average wall-clock inference time of a PyTorch checkpoint, in ms per image.

    factory: callable(num_classes) -> nn.Module (the global `num_classes` is passed)
    ckpt: .pth checkpoint; a raw state_dict or a dict with a 'state_dict' entry
    paths: image file paths; the first n_samples are timed in chunks of `batch`
    warmup: untimed forward passes run first, so CUDA context / kernel selection
            start-up cost is excluded (0 disables)
    Preprocessing uses the global `val_tf_pt` when defined, else Resize + ImageNet normalisation.
    The timed region includes image loading and preprocessing. On CUDA the device is
    synchronised before both clock reads, because kernel launches are asynchronous.
    Returns the average ms/image as a float, or NaN when there is nothing to time.
    """
    sel = list(paths[:n_samples])
    if not sel:
        return float('nan')

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = factory(num_classes).to(device)
    st = torch.load(ckpt, map_location=device)
    try:
        model.load_state_dict(st)
    except Exception:
        if isinstance(st, dict) and 'state_dict' in st:
            model.load_state_dict(st['state_dict'])
        else:
            model.load_state_dict(st, strict=False)
    model.eval()

    val_tf = globals().get('val_tf_pt', None)
    if val_tf is None:
        from torchvision import transforms
        val_tf = transforms.Compose([transforms.Resize(IMG_SIZE), transforms.ToTensor(), transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])

    def _sync():
        if device == 'cuda':
            torch.cuda.synchronize()

    def _load_batch(batch_paths):
        imgs = []
        for p in batch_paths:
            with Image.open(p) as im:
                imgs.append(val_tf(im.convert('RGB')))
        return torch.stack(imgs).to(device)

    try:
        with torch.no_grad():
            for _ in range(max(0, warmup)):
                model(_load_batch(sel[:batch]))
            _sync()
            t0 = time.perf_counter()
            for i in range(0, len(sel), batch):
                model(_load_batch(sel[i:i + batch]))
            _sync()
            total = time.perf_counter() - t0
    finally:
        # free the model so successive timings do not accumulate GPU memory
        del model
        if device == 'cuda':
            torch.cuda.empty_cache()
    return (total / len(sel)) * 1000.0

# Only measure for models where we can (limit to top N to save time)
measure_limit = 6   # number of models to time (set smaller if compute-limited)
timing_paths = [p for p, _ in test_items]   # same image list for every model, built once
timing_rows = []
model_list_for_timing = summary_df['model'].tolist()[:measure_limit]
print("Measuring inference time for (up to) {} models: {}".format(len(model_list_for_timing), model_list_for_timing))
for mk in model_list_for_timing:
    avg_ms = np.nan   # stays NaN when the model cannot be located or timing fails
    try:
        if mk.startswith('keras::') and 'trained_classifiers' in globals():
            km = trained_classifiers.get(mk.split('::', 1)[1])
            if km is not None:
                avg_ms = measure_inference_time_keras(km, timing_paths, n_samples=50, batch=8)
        elif mk.startswith('torch::') and 'pytorch_factories' in globals():
            tn = mk.split('::', 1)[1]
            fac = pytorch_factories.get(tn)
            ckpt = os.path.join(OUT_ROOT, f'{tn}_best.pth')
            if fac is not None and os.path.exists(ckpt):
                avg_ms = measure_inference_time_torch(fac, ckpt, timing_paths, n_samples=50, batch=8)
    except Exception as e:
        print("Timing failed for", mk, e)
        avg_ms = np.nan
    timing_rows.append({'model': mk, 'inf_ms': round(avg_ms, 2)})

# Explicit columns so the merge works even when there was nothing to time.
df_time = pd.DataFrame(timing_rows, columns=['model', 'inf_ms'])
summary_df = summary_df.merge(df_time, on='model', how='left')

# ---- Persist the summary table and render it ----
SUMMARY_COLUMNS = ['model', 'accuracy', 'macro_precision', 'macro_recall', 'macro_f1', 'params', 'params_M', 'inf_ms']
summary_df = summary_df[SUMMARY_COLUMNS]
summary_df.to_csv(OUT_CSV, index=False)
print("Saved summary CSV:", OUT_CSV)
display(summary_df.sort_values('macro_f1', ascending=False).reset_index(drop=True))


def _save_model_bar(df, metric, title, filename, ascending=False, log_y=False, ylabel=None):
    """Save a one-bar-per-model chart of `metric` (rows sorted by it) as PLOTS_DIR/filename."""
    plt.figure(figsize=(max(6, len(df) * 0.6), 4))
    sns.barplot(data=df.sort_values(metric, ascending=ascending), x='model', y=metric)
    if log_y:
        plt.yscale('log')   # parameter counts span orders of magnitude
    plt.xticks(rotation=45, ha='right')
    if ylabel is not None:
        plt.ylabel(ylabel)
    plt.title(title)
    plt.tight_layout()
    plt.savefig(os.path.join(PLOTS_DIR, filename))
    plt.close()


# ---- Bar charts: accuracy, macro-F1, parameter count (log), inference time ----
_save_model_bar(summary_df, 'accuracy', 'Model test accuracy', 'summary_accuracy_bar.png')
_save_model_bar(summary_df, 'macro_f1', 'Model macro F1', 'summary_macrof1_bar.png')
_save_model_bar(summary_df, 'params_M', 'Model Params (millions, log scale)', 'summary_params_bar.png', log_y=True)

has_timing = 'inf_ms' in summary_df.columns and summary_df['inf_ms'].notna().any()
if has_timing:
    _save_model_bar(summary_df, 'inf_ms', 'Inference time (approx)', 'summary_inference_time.png',
                    ascending=True, ylabel='Avg inference (ms/image)')

print("Saved plots to", PLOTS_DIR)


In [None]:
# Pairwise McNemar tests between every pair of evaluated models.
#
# Predictions come from `preds_map` / `y_true`, built by the summary cell (the same data the
# summary metrics use). `all_preds` / `y_test` are only used when `preds_map` does not exist.
# NOTE(review): the fallback torch-predict loop in the summary cell assigns `all_preds = []`,
# which would shadow a dict of that name if that loop runs at notebook top level.
if 'preds_map' in globals():
    mcnemar_preds, mcnemar_labels = preds_map, y_true
else:
    mcnemar_preds, mcnemar_labels = all_preds, y_test
mcnemar_labels = np.asarray(mcnemar_labels)

# With few discordant pairs the chi-square approximation is unreliable. With none at all, the
# continuity-corrected statistic divides by zero (statistic inf, p-value 0), which would report
# identical models as "significantly different". Below this many discordant pairs use the
# exact binomial test (its statistic is min(n01, n10) rather than a chi-square value).
MCNEMAR_EXACT_BELOW = 25

names = list(mcnemar_preds.keys()); n = len(names)
pvals = np.ones((n, n)); stats = np.zeros((n, n))
# per-model boolean "prediction is correct" vectors, computed once
correct = {k: np.asarray(mcnemar_preds[k]['preds']) == mcnemar_labels for k in names}
for i in range(n):
    for j in range(i + 1, n):
        a_corr, b_corr = correct[names[i]], correct[names[j]]
        n11 = int(np.sum(a_corr & b_corr))     # both correct
        n01 = int(np.sum(a_corr & ~b_corr))    # only model i correct
        n10 = int(np.sum(~a_corr & b_corr))    # only model j correct
        n00 = int(np.sum(~a_corr & ~b_corr))   # both wrong
        res = mcnemar([[n11, n01], [n10, n00]], exact=(n01 + n10) < MCNEMAR_EXACT_BELOW)
        pvals[i, j] = pvals[j, i] = float(res.pvalue)
        stats[i, j] = stats[j, i] = float(res.statistic)
pd.DataFrame(pvals, index=names, columns=names).to_csv(os.path.join(OUT_ROOT,'mcnemar_pvalues.csv'))
pd.DataFrame(stats, index=names, columns=names).to_csv(os.path.join(OUT_ROOT,'mcnemar_stats.csv'))
print("Saved McNemar p-values and stats to", OUT_ROOT)


In [None]:
# -------------------------
# Robust Grad-CAM auto-run for top PyTorch models
# Paste & run this cell (replaces previous Grad-CAM cell)
# -------------------------
import os, math, traceback
import cv2
import numpy as np
from PIL import Image
import torch
from torchvision import transforms

# ---------- Robust Grad-CAM function (supports conv outputs and square-token ViT-like outputs) ----------
def save_gradcam_grid_pytorch(factory, ckpt, layer_name, paths, save_path, max_images=12, img_size=(224,224), device='cuda'):
    """
    Compute Grad-CAM overlays for up to max_images images and save them as a 4-column grid.

    factory: callable(num_classes) -> model (the global `num_classes` is passed)
    ckpt: path to a .pth checkpoint (a state_dict, or a dict holding one under 'state_dict')
    layer_name: exact module name from model.named_modules() to hook
    paths: image file paths; unreadable images and images whose CAM cannot be computed
           are skipped, and iteration continues until max_images overlays are collected
    save_path: path of the grid image to write
    img_size: size images are resized to (the grid layout assumes square sizes)
    device: torch device string

    Hooked-output layouts handled (one image per forward, so the batch axis is 1):
      (1, C, H, W)  conv feature maps
      (1, H, W, C)  channels-last maps (heuristic: two equal leading dims that differ
                    from the last one; intended for Swin-style backbones)
      (1, L, C)     token sequences whose patch tokens form a square grid after dropping
                    up to two leading class/distillation tokens; (1, C, L) is tried as fallback
    Other layouts are reported and skipped instead of producing a meaningless map.
    Returns None.
    """
    # load model and weights
    model = factory(num_classes).to(device)
    st = torch.load(ckpt, map_location=device)
    # handle common wrappers
    if isinstance(st, dict) and 'state_dict' in st:
        st = st['state_dict']
    try:
        model.load_state_dict(st)
    except Exception:
        # try non-strict
        try:
            model.load_state_dict(st, strict=False)
        except Exception:
            print("Failed to load checkpoint strictly or loosely for", ckpt)
    model.eval()

    modules = dict(model.named_modules())
    if layer_name not in modules:
        print(f"[GradCAM] Layer '{layer_name}' not found in model modules. Available tail modules: {list(modules.keys())[-40:]}")
        return

    target = modules[layer_name]

    # holders for hooks
    features_holder = {'feat': None}
    grads_holder = {'grad': None}

    def forward_hook(module, inp, out):
        # store first tensor output if tuple/list
        if isinstance(out, torch.Tensor):
            features_holder['feat'] = out.detach()
        elif isinstance(out, (list, tuple)) and len(out)>0 and isinstance(out[0], torch.Tensor):
            features_holder['feat'] = out[0].detach()
        else:
            features_holder['feat'] = None

    def backward_hook(module, grad_input, grad_output):
        if isinstance(grad_output, torch.Tensor):
            grads_holder['grad'] = grad_output.detach()
        elif isinstance(grad_output, (list, tuple)) and len(grad_output)>0 and isinstance(grad_output[0], torch.Tensor):
            grads_holder['grad'] = grad_output[0].detach()
        else:
            grads_holder['grad'] = None

    # register the forward hook once; prefer the full backward hook API (avoids the
    # partial-grad warning) and fall back to the legacy API when it is unavailable
    fh = target.register_forward_hook(forward_hook)
    try:
        bh = target.register_full_backward_hook(backward_hook)
    except Exception:
        bh = target.register_backward_hook(backward_hook)

    tf_img = transforms.Compose([transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])

    def _token_grid(n_tokens):
        """Return (n_prefix, side) with n_tokens - n_prefix == side**2 and side > 1, else None."""
        for n_prefix in (0, 1, 2):
            rest = n_tokens - n_prefix
            if rest > 1:
                side = math.isqrt(rest)
                if side > 1 and side * side == rest:
                    return n_prefix, side
        return None

    def compute_cam(feat, grad, target_hw):
        """
        Grad-CAM map from hooked activations and gradients (both with batch axis 1).
        Returns a 2D float array in [0, 1] resized to target_hw=(H, W), or None when the
        layout is unsupported or the map has no positive evidence.
        """
        if feat is None or grad is None or tuple(feat.shape) != tuple(grad.shape):
            return None
        if feat.ndim not in (3, 4) or feat.shape[0] != 1:
            return None
        f, g = feat[0].float(), grad[0].float()

        if f.ndim == 3:
            # spatial map: (C,H,W) by default, channels-last (H,W,C) by the heuristic above
            if f.shape[0] == f.shape[1] and f.shape[2] != f.shape[1]:
                f, g = f.permute(2, 0, 1), g.permute(2, 0, 1)
            if f.shape[1] <= 1 or f.shape[2] <= 1:
                return None
            w = g.mean(dim=(1, 2), keepdim=True)          # (C,1,1) channel weights
            cam = (w * f).sum(dim=0)                      # (H,W)
        else:
            # token sequence: (L,C) first, then (C,L)
            grid = _token_grid(f.shape[0])
            if grid is None:
                f, g = f.t(), g.t()
                grid = _token_grid(f.shape[0])
                if grid is None:
                    return None
            n_prefix, side = grid
            f, g = f[n_prefix:], g[n_prefix:]             # drop class/distillation tokens
            w = g.mean(dim=0)                             # (C,) channel weights
            cam = (f * w).sum(dim=1).reshape(side, side)  # (side, side)

        cam = np.maximum(cam.cpu().numpy(), 0)
        if cam.max() <= 1e-9:
            return None
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-9)
        return cv2.resize(cam, (target_hw[1], target_hw[0]))  # cv2 takes (width, height)

    overlays = []
    try:
        for p in paths:
            if len(overlays) >= max_images:
                break
            try:
                pil = Image.open(p).convert('RGB').resize(img_size)
            except Exception:
                continue

            x = tf_img(pil).unsqueeze(0).to(device)
            x.requires_grad_(True)
            features_holder['feat'] = None; grads_holder['grad'] = None

            out = model(x)
            # back-propagate the predicted class score
            pred = int(out.argmax(dim=1).item())
            model.zero_grad()
            out[0, pred].backward()

            feat = features_holder.get('feat', None)
            grad = grads_holder.get('grad', None)
            cam = compute_cam(feat, grad, target_hw=(pil.height, pil.width))
            if cam is None:
                print(f"[GradCAM] skipping image {p}: unsupported feat/grad shapes -> feat {None if feat is None else tuple(feat.shape)}, grad {None if grad is None else tuple(grad.shape)}")
                continue

            # colorize and overlay
            heat = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
            heat = cv2.cvtColor(heat, cv2.COLOR_BGR2RGB).astype(np.float32)/255.0
            base = np.array(pil).astype(np.float32)/255.0
            overlay = np.clip(0.5 * base + 0.5 * heat, 0, 1)
            overlays.append((overlay*255).astype(np.uint8))
    finally:
        # always detach the hooks, even when a forward/backward pass raised
        fh.remove(); bh.remove()

    if len(overlays) == 0:
        print("[GradCAM] No overlays created for layer", layer_name)
        return

    cols = 4; rows = math.ceil(len(overlays)/cols)
    grid = Image.new('RGB', (cols*img_size[0], rows*img_size[1]))
    for i, arr in enumerate(overlays):
        r = i // cols; c = i % cols
        grid.paste(Image.fromarray(arr), (c*img_size[0], r*img_size[1]))
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    grid.save(save_path)
    print("[GradCAM] Saved grid to", save_path)


# ---------- Helper to automatically find a spatial module by testing forward pass ----------
def find_spatial_module_name(factory, sample_image_path, num_classes_local=num_classes, device='cuda', img_size=(224,224), which='first'):
    """
    Find a sub-module whose forward output is a 4D spatial map, to use as a Grad-CAM target.

    Runs ONE forward pass on sample_image_path with a forward hook on every sub-module and
    records each module's output shape (the tensor itself, or the first element of a
    tuple/list output). A module qualifies when that output is 4D with batch 1 and both
    trailing dims > 1, i.e. (1, C, H, W). Token outputs such as (1, L, C) do not qualify.

    which: 'first' (default) -> first qualifying module in model.named_modules() order;
           'last' -> qualifying module whose forward finished last, typically the deepest
           spatial stage, which is the layer Grad-CAM conventionally targets.
    Returns the module name, or None if nothing qualifies or the forward pass fails.
    Note: channels-last (1, H, W, C) outputs are also 4D and therefore also qualify.
    """
    if which not in ('first', 'last'):
        raise ValueError(f"which must be 'first' or 'last', got {which!r}")

    model = factory(num_classes_local).to(device)
    model.eval()
    named = [(name, module) for name, module in model.named_modules() if name != '']  # skip root

    tf_img = transforms.Compose([transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
    pil = Image.open(sample_image_path).convert('RGB').resize(img_size)
    x = tf_img(pil).unsqueeze(0).to(device)

    out_shapes = {}    # module name -> shape of its most recent output (None if not a tensor)
    finish_seq = {}    # module name -> sequence number of its most recent forward completion
    counter = [0]

    def make_hook(name):
        def hook(m, inp, out):
            if isinstance(out, torch.Tensor):
                t = out
            elif isinstance(out, (list, tuple)) and len(out) > 0 and isinstance(out[0], torch.Tensor):
                t = out[0]
            else:
                t = None
            out_shapes[name] = None if t is None else tuple(t.shape)
            counter[0] += 1
            finish_seq[name] = counter[0]
        return hook

    handles = []
    try:
        for name, module in named:
            handles.append(module.register_forward_hook(make_hook(name)))
        with torch.no_grad():
            _ = model(x)
    except Exception:
        # a failing forward pass yields no candidate
        out_shapes.clear()
    finally:
        for h in handles:
            h.remove()

    def is_spatial(shape):
        return shape is not None and len(shape) == 4 and shape[0] == 1 and shape[2] > 1 and shape[3] > 1

    candidates = [name for name, _ in named if is_spatial(out_shapes.get(name))]
    if not candidates:
        found = None
    elif which == 'last':
        found = max(candidates, key=lambda nm: finish_seq[nm])
    else:
        found = candidates[0]

    # clean up
    del model, named
    torch.cuda.empty_cache()
    return found

# ---------- Run automatic Grad-CAM for top torch models ----------
# Images used both to probe for a spatial layer and to build the grid: `test_paths` when it
# exists and is non-empty, otherwise the paths from `test_items`.
if 'test_paths' in globals() and len(test_paths) > 0:
    gradcam_paths = list(test_paths)
else:
    gradcam_paths = [p for p, _ in test_items]

top3 = summary_df['model'].tolist()[:3]
for m in top3:
    if not m.startswith('torch::'):
        continue   # Grad-CAM here is implemented for PyTorch models only
    mn = m.split('::',1)[1]
    if mn not in pytorch_factories:
        print("[GradCAM] no factory for", mn); continue
    fac = pytorch_factories[mn]
    ckpt = os.path.join(OUT_ROOT, f'{mn}_best.pth')
    if not os.path.exists(ckpt):
        print("[GradCAM] checkpoint missing for", mn, ckpt); continue

    print(f"[GradCAM] Trying model {mn} with checkpoint {ckpt}")
    try:
        # probe with the first test image
        sample_img = gradcam_paths[0]
        chosen_layer = find_spatial_module_name(fac, sample_img, num_classes_local=num_classes, device=DEVICE, img_size=IMG_SIZE)
        if chosen_layer is None:
            print(f"[GradCAM] No spatial module automatically found for {mn}; you can choose a conv-like module name manually.")
            continue
        print(f"[GradCAM] Auto-chosen layer for {mn}: {chosen_layer}")
        save_path = os.path.join(PLOTS_DIR, f'{mn}_gradcam_grid.png')
        save_gradcam_grid_pytorch(fac, ckpt, chosen_layer, gradcam_paths, save_path, max_images=12, img_size=IMG_SIZE, device=DEVICE)
    except Exception as e:
        print("[GradCAM] Error for model", mn, e)
        traceback.print_exc()

print("Grad-CAM auto-run complete. Check", PLOTS_DIR)


In [None]:
# ============================================
#   UNIVERSAL GRAD-CAM (CNN + ViT + Swin)
#   - Saves individual images
#   - Saves grid
# ============================================
import cv2, math, os
from PIL import Image
import numpy as np
import torch

def save_gradcam_grid_pytorch(factory, ckpt, layer_name, paths, save_path,
                              max_images=12, img_size=(224,224),
                              device=DEVICE):
    """
    Grad-CAM for CNN / ViT / Swin-style models: saves one overlay per image plus a grid.

    NOTE: this re-defines, and from this cell on replaces, the save_gradcam_grid_pytorch
    defined in the earlier Grad-CAM cell.

    factory: callable(num_classes) -> model (the global `num_classes` is passed)
    ckpt: .pth checkpoint (a state_dict, or a dict holding one under 'state_dict')
    layer_name: exact module name from model.named_modules() to hook
    paths: image paths; only the first max_images are tried and failures are skipped,
           so the grid may hold fewer than max_images overlays
    save_path: grid image path; individual overlays go to
               "<save_path without extension>_individual/img_XX.png"

    Hooked-output layouts handled (batch axis 1): (1,C,H,W) conv maps; (1,H,W,C)
    channels-last maps (heuristic: equal leading dims differing from the last one);
    (1,L,C) token sequences whose patch tokens form a square grid after dropping up to
    two leading class/distillation tokens, with (1,C,L) tried as fallback.
    Returns None.
    """
    # -----------------------
    # Load model + checkpoint
    # -----------------------
    model = factory(num_classes).to(device)
    st = torch.load(ckpt, map_location=device)

    if isinstance(st, dict) and 'state_dict' in st:
        st = st['state_dict']

    try:
        model.load_state_dict(st)
    except Exception:
        model.load_state_dict(st, strict=False)

    model.eval()

    # -----------------------------------
    # Find chosen layer
    # -----------------------------------
    modules = dict(model.named_modules())
    if layer_name not in modules:
        print("Layer", layer_name, "not found. Available:", list(modules.keys())[-30:])
        return

    target = modules[layer_name]

    # holders
    feat_holder = {'feat': None}
    grad_holder = {'grad': None}

    # forward hook: keep the (first) tensor output of the target layer
    def fh(m, inp, out):
        if isinstance(out, torch.Tensor):
            feat_holder['feat'] = out.detach()
        elif isinstance(out, (tuple,list)) and len(out)>0 and isinstance(out[0], torch.Tensor):
            feat_holder['feat'] = out[0].detach()

    # backward hook: keep the gradient w.r.t. that output
    def bh(m, gin, gout):
        g = gout[0] if isinstance(gout, (tuple,list)) else gout
        if isinstance(g, torch.Tensor):
            grad_holder['grad'] = g.detach()

    # forward hook once; full backward hook when available, legacy API otherwise
    h1 = target.register_forward_hook(fh)
    try:
        h2 = target.register_full_backward_hook(bh)
    except Exception:
        h2 = target.register_backward_hook(bh)

    # -----------------------
    # Preprocessing
    # -----------------------
    tf_img = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ])

    overlays = []
    # splitext (not str.replace) so a non-.png save_path cannot collide with this directory
    individual_dir = os.path.splitext(save_path)[0] + "_individual"
    os.makedirs(individual_dir, exist_ok=True)

    # -----------------------
    # Grad-CAM utility
    # -----------------------
    def token_grid(n_tokens):
        """(n_prefix, side) with n_tokens - n_prefix == side**2 and side > 1, else None."""
        for n_prefix in (0, 1, 2):
            rest = n_tokens - n_prefix
            if rest > 1:
                side = math.isqrt(rest)
                if side > 1 and side * side == rest:
                    return n_prefix, side
        return None

    def make_cam(feat, grad, target_hw):
        if feat is None or grad is None or tuple(feat.shape) != tuple(grad.shape):
            return None
        if feat.ndim not in (3, 4) or feat.shape[0] != 1:
            return None
        f, g = feat[0].float(), grad[0].float()

        if f.ndim == 3:
            # Spatial map: (C,H,W), or channels-last (H,W,C) per the heuristic
            if f.shape[0] == f.shape[1] and f.shape[2] != f.shape[1]:
                f, g = f.permute(2, 0, 1), g.permute(2, 0, 1)
            if f.shape[1] <= 1 or f.shape[2] <= 1:
                return None
            w = g.mean(dim=(1, 2), keepdim=True)
            cam = (w * f).sum(dim=0)
        else:
            # Tokens: try (L,C), then (C,L)
            grid = token_grid(f.shape[0])
            if grid is None:
                f, g = f.t(), g.t()
                grid = token_grid(f.shape[0])
                if grid is None:
                    return None
            n_prefix, side = grid
            f, g = f[n_prefix:], g[n_prefix:]   # drop class/distillation tokens
            w = g.mean(dim=0)
            cam = (f * w).sum(dim=1).reshape(side, side)

        # normalize
        cam = np.maximum(cam.cpu().numpy(), 0)
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-9)
        return cv2.resize(cam, (target_hw[1], target_hw[0]))   # cv2 takes (width, height)

    # -----------------------
    # Process each test image
    # -----------------------
    idx = 0
    try:
        for p in paths[:max_images]:
            try:
                pil = Image.open(p).convert("RGB").resize(img_size)
            except Exception:
                continue

            x = tf_img(pil).unsqueeze(0).to(device)
            x.requires_grad_(True)

            feat_holder['feat'] = None
            grad_holder['grad'] = None

            out = model(x)
            pred = int(out.argmax(1))
            model.zero_grad()
            out[0, pred].backward()

            cam = make_cam(feat_holder['feat'], grad_holder['grad'], (pil.height, pil.width))
            if cam is None:
                print("⚠️ CAM failed for:", p)
                continue

            # make overlay
            heat = cv2.applyColorMap(np.uint8(255*cam), cv2.COLORMAP_JET)
            heat = cv2.cvtColor(heat, cv2.COLOR_BGR2RGB) / 255.0
            base = np.array(pil) / 255.0
            overlay = (0.5*base + 0.5*heat)
            overlay = np.clip(overlay*255,0,255).astype(np.uint8)

            overlays.append(overlay)

            # save individual
            Image.fromarray(overlay).save(os.path.join(individual_dir, f"img_{idx+1:02d}.png"))
            idx += 1
    finally:
        # always remove hooks, even if a pass raised
        h1.remove(); h2.remove()

    if len(overlays)==0:
        print("❌ No Grad-CAM images created for:", layer_name)
        return

    # -----------------------
    # Make grid
    # -----------------------
    cols = 4
    rows = math.ceil(len(overlays)/cols)
    grid = Image.new("RGB", (cols*img_size[0], rows*img_size[1]))

    for i,arr in enumerate(overlays):
        r = i//cols
        c = i%cols
        grid.paste(Image.fromarray(arr), (c*img_size[0], r*img_size[1]))

    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    grid.save(save_path)
    print("✅ Saved Grad-CAM grid:", save_path)
    print("📂 Individual images saved to:", individual_dir)
    print("📂 Individual images saved to:", individual_dir)


In [None]:
# 🔵 NEW: Replacement Grad-CAM function that RETURNS overlays, paths, preds
# Paste this cell into your notebook to replace the old save_gradcam_grid_pytorch OR keep it as a new function.

import cv2, math, os
from PIL import Image
import numpy as np
import torch
from torchvision import transforms

def save_gradcam_grid_pytorch_return(factory, ckpt, layer_name, paths, save_path,
                                     max_images=12, img_size=(224,224),
                                     device=DEVICE, num_classes=num_classes, idx_to_class=idx_to_class):
    """
    Grad-CAM overlays like save_gradcam_grid_pytorch, but also returns them.

    Returns (overlays, paths_used, preds):
       overlays   -> list of (H,W,3) uint8 numpy arrays
       paths_used -> source image path of each overlay
       preds      -> predicted class index (int) of each overlay
    All three are empty lists when the layer is missing or no overlay could be made.
    Saves each overlay to "<save_path without extension>_individual/img_XX.png" and a
    4-column grid to save_path. Only the first max_images paths are tried.

    num_classes: number of classes passed to factory
    idx_to_class: accepted for interface compatibility; not used inside this function

    Hooked-output layouts handled (batch axis 1): (1,C,H,W) conv maps; (1,H,W,C)
    channels-last maps (heuristic: equal leading dims differing from the last one);
    (1,L,C) token sequences whose patch tokens form a square grid after dropping up to
    two leading class/distillation tokens, with (1,C,L) tried as fallback.
    """
    # -----------------------
    # Load model + checkpoint
    # -----------------------
    model = factory(num_classes).to(device)
    st = torch.load(ckpt, map_location=device)

    if isinstance(st, dict) and 'state_dict' in st:
        st = st['state_dict']

    try:
        model.load_state_dict(st)
    except Exception:
        model.load_state_dict(st, strict=False)

    model.eval()

    # -----------------------------------
    # Find chosen layer
    # -----------------------------------
    modules = dict(model.named_modules())
    if layer_name not in modules:
        print("Layer", layer_name, "not found. Available (last 30):", list(modules.keys())[-30:])
        return [], [], []

    target = modules[layer_name]

    # holders
    feat_holder = {'feat': None}
    grad_holder = {'grad': None}

    # forward hook: keep the (first) tensor output of the target layer
    def fh(m, inp, out):
        if isinstance(out, torch.Tensor):
            feat_holder['feat'] = out.detach()
        elif isinstance(out, (tuple,list)) and len(out)>0 and isinstance(out[0], torch.Tensor):
            feat_holder['feat'] = out[0].detach()

    # backward hook: keep the gradient w.r.t. that output
    def bh(m, gin, gout):
        g = gout[0] if isinstance(gout, (tuple,list)) else gout
        if isinstance(g, torch.Tensor):
            grad_holder['grad'] = g.detach()

    # forward hook once; full backward hook when available, legacy API otherwise
    h1 = target.register_forward_hook(fh)
    try:
        h2 = target.register_full_backward_hook(bh)
    except Exception:
        h2 = target.register_backward_hook(bh)

    # -----------------------
    # Preprocessing
    # -----------------------
    tf_img = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ])

    overlays = []
    preds = []
    paths_used = []
    # splitext (not str.replace) so a non-.png save_path cannot collide with this directory
    individual_dir = os.path.splitext(save_path)[0] + "_individual"
    os.makedirs(individual_dir, exist_ok=True)

    # -----------------------
    # Grad-CAM utility
    # -----------------------
    def token_grid(n_tokens):
        """(n_prefix, side) with n_tokens - n_prefix == side**2 and side > 1, else None."""
        for n_prefix in (0, 1, 2):
            rest = n_tokens - n_prefix
            if rest > 1:
                side = math.isqrt(rest)
                if side > 1 and side * side == rest:
                    return n_prefix, side
        return None

    def make_cam(feat, grad, target_hw):
        if feat is None or grad is None or tuple(feat.shape) != tuple(grad.shape):
            return None
        if feat.ndim not in (3, 4) or feat.shape[0] != 1:
            return None
        f, g = feat[0].float(), grad[0].float()

        if f.ndim == 3:
            # Spatial map: (C,H,W), or channels-last (H,W,C) per the heuristic
            if f.shape[0] == f.shape[1] and f.shape[2] != f.shape[1]:
                f, g = f.permute(2, 0, 1), g.permute(2, 0, 1)
            if f.shape[1] <= 1 or f.shape[2] <= 1:
                return None
            w = g.mean(dim=(1, 2), keepdim=True)
            cam = (w * f).sum(dim=0)
        else:
            # Tokens: try (L,C), then (C,L)
            grid = token_grid(f.shape[0])
            if grid is None:
                f, g = f.t(), g.t()
                grid = token_grid(f.shape[0])
                if grid is None:
                    return None
            n_prefix, side = grid
            f, g = f[n_prefix:], g[n_prefix:]   # drop class/distillation tokens
            w = g.mean(dim=0)
            cam = (f * w).sum(dim=1).reshape(side, side)

        # normalize
        cam = np.maximum(cam.cpu().numpy(), 0)
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-9)
        return cv2.resize(cam, (target_hw[1], target_hw[0]))   # cv2 takes (width, height)

    # -----------------------
    # Process each test image
    # -----------------------
    idx = 0
    try:
        for p in paths[:max_images]:
            try:
                pil = Image.open(p).convert("RGB").resize(img_size)
            except Exception as e:
                print("Skipping", p, " — read error:", e)
                continue

            x = tf_img(pil).unsqueeze(0).to(device)
            x.requires_grad_(True)

            feat_holder['feat'] = None
            grad_holder['grad'] = None

            out = model(x)
            pred = int(out.argmax(1))
            model.zero_grad()
            out[0, pred].backward()

            cam = make_cam(feat_holder['feat'], grad_holder['grad'], (pil.height, pil.width))
            if cam is None:
                print("⚠️ CAM failed for:", p)
                continue

            # make overlay
            heat = cv2.applyColorMap(np.uint8(255*cam), cv2.COLORMAP_JET)
            heat = cv2.cvtColor(heat, cv2.COLOR_BGR2RGB) / 255.0
            base = np.array(pil) / 255.0
            overlay = (0.5*base + 0.5*heat)
            overlay = np.clip(overlay*255,0,255).astype(np.uint8)

            overlays.append(overlay)
            preds.append(pred)
            paths_used.append(p)

            # save individual
            Image.fromarray(overlay).save(os.path.join(individual_dir, f"img_{idx+1:02d}.png"))
            idx += 1
    finally:
        # always remove hooks, even if a pass raised
        h1.remove(); h2.remove()

    if len(overlays)==0:
        print("❌ No Grad-CAM images created for:", layer_name)
        return [], [], []

    # -----------------------
    # Make grid
    # -----------------------
    cols = 4
    rows = math.ceil(len(overlays)/cols)
    grid = Image.new("RGB", (cols*img_size[0], rows*img_size[1]))

    for i,arr in enumerate(overlays):
        r = i//cols
        c = i%cols
        grid.paste(Image.fromarray(arr), (c*img_size[0], r*img_size[1]))

    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    grid.save(save_path)
    print("✅ Saved Grad-CAM grid:", save_path)
    print("📂 Individual images saved to:", individual_dir)

    return overlays, paths_used, preds


# 🔵 NEW: Interactive slider viewer that shows filename + predicted class
def interactive_overlay_slider_with_labels(overlays, paths, preds, idx_to_class_map=None, start=0, figsize=(6,6)):
    """Show Grad-CAM overlays one at a time behind an ipywidgets index slider.

    Each frame is titled "<k>/<total>  |  <filename>  |  pred: <label>"; the
    filename / prediction parts are omitted when `paths` / `preds` have no
    entry for that index.

    Args:
        overlays: sequence of (H, W, 3) images. uint8 images are shown as-is;
            any other dtype is assumed to be in [0, 1], scaled by 255 and
            clipped to [0, 255] before conversion to uint8.
        paths: path-like source files aligned with `overlays` (may be shorter).
        preds: predicted class indices aligned with `overlays` (may be shorter).
        idx_to_class_map: optional mapping (dict or list) from class index to
            class name; indices it cannot resolve are shown as the raw index.
        start: initial slider position; clamped to the valid index range.
        figsize: matplotlib figure size used for each frame.

    Returns:
        None. Prints a message and returns early when ipywidgets/matplotlib
        cannot be imported or when there is nothing to display.
    """
    try:
        import os
        from ipywidgets import widgets, interact
        import matplotlib.pyplot as plt
        import numpy as np
    except Exception as e:
        print("ipywidgets not available. Install with: !pip install ipywidgets")
        print("Error:", e)
        return

    if overlays is None or len(overlays) == 0:
        print("No overlays to display.")
        return

    n_total = len(overlays)
    max_index = n_total - 1
    # Clamp so an out-of-range start cannot break slider construction.
    start = min(max(int(start), 0), max_index)

    def _lookup(seq, key):
        """Return seq[key], or None when seq is None or has no such entry."""
        if seq is None:
            return None
        try:
            return seq[key]
        except (IndexError, KeyError, TypeError):
            return None

    def _to_uint8(img):
        """Convert an overlay to uint8 for display (see docstring above)."""
        img = np.asarray(img)
        if img.dtype == np.uint8:
            return img
        # Clip before casting: a plain uint8 cast wraps out-of-range values.
        return np.clip(img * 255, 0, 255).astype(np.uint8)

    def _title(i):
        """Build the frame title from position, filename and predicted label."""
        parts = [f"{i + 1}/{n_total}"]
        path = _lookup(paths, i)
        if path is not None:
            parts.append(os.path.basename(path))
        pred = _lookup(preds, i)
        if pred is not None:
            label = _lookup(idx_to_class_map, pred)
            parts.append(f"pred: {pred if label is None else label}")
        return "  |  ".join(parts)

    def _show(i=0):
        i = int(i)
        _, ax = plt.subplots(figsize=figsize)
        ax.imshow(_to_uint8(overlays[i]))
        ax.axis('off')
        ax.set_title(_title(i))
        plt.show()

    interact(_show,
             i=widgets.IntSlider(min=0, max=max_index, step=1, value=start, description='Index'))


In [None]:
# # ============================================================
# # FEW-SHOT PLOTTING & EXPORT (READY TO PASTE)
# # Run this only AFTER the few-shot prototype cell that produces accs_5shot & accs_1shot (in this notebook that cell appears further down)
# # ============================================================

# ### >>> NEW
# import os
# import pandas as pd
# import matplotlib.pyplot as plt
# import seaborn as sns
# import numpy as np

# ### >>> UPDATED
# # use existing PLOTS_DIR and OUT_ROOT variables from your notebook
# PLOTS_DIR = globals().get('PLOTS_DIR', os.path.join(OUT_ROOT, 'plots'))
# os.makedirs(PLOTS_DIR, exist_ok=True)

# ### >>> NEW
# # Ensure accs_5shot and accs_1shot exist
# if 'accs_5shot' not in globals() or 'accs_1shot' not in globals():
#     raise RuntimeError("accs_5shot or accs_1shot not found. Run the few-shot cell first.")

# # Stats
# accs5 = np.array(accs_5shot)
# accs1 = np.array(accs_1shot)

# summary = {
#     'setting': ['5-way-5-shot', '5-way-1-shot'],
#     'mean_acc': [accs5.mean(), accs1.mean()],
#     'std_acc':  [accs5.std(ddof=1), accs1.std(ddof=1)],
#     'median':   [np.median(accs5), np.median(accs1)],
#     'n_episodes':[len(accs5), len(accs1)]
# }
# df_summary = pd.DataFrame(summary)
# ### >>> NEW
# # Save summary CSV
# csv_out = os.path.join(OUT_ROOT, 'fewshot_summary.csv')
# df_summary.to_csv(csv_out, index=False)
# print("Saved few-shot summary CSV ->", csv_out)

# # ---------- Plot 1: Histogram + KDE overlay ----------
# plt.figure(figsize=(8,5))
# sns.histplot(accs5, label='5-shot', stat='density', kde=True, alpha=0.6)
# sns.histplot(accs1, label='1-shot', stat='density', kde=True, alpha=0.6)
# plt.xlabel('Episode Accuracy')
# plt.title('Few-Shot Episode Accuracy Distribution')
# plt.legend()
# hist_path = os.path.join(PLOTS_DIR, 'fewshot_accuracy_histogram.png')
# plt.tight_layout(); plt.savefig(hist_path); plt.close()
# print("Saved histogram ->", hist_path)

# # ---------- Plot 2: Boxplot comparison ----------
# plt.figure(figsize=(6,5))
# sns.boxplot(data=[accs5, accs1])
# plt.xticks([0,1], ['5-way-5-shot', '5-way-1-shot'])
# plt.ylabel('Episode Accuracy')
# plt.title('Few-Shot Accuracy Boxplot')
# box_path = os.path.join(PLOTS_DIR, 'fewshot_accuracy_boxplot.png')
# plt.tight_layout(); plt.savefig(box_path); plt.close()
# print("Saved boxplot ->", box_path)

# # ---------- Plot 3: Mean ± STD bar plot ----------
# plt.figure(figsize=(6,4))
# means = [accs5.mean()*100, accs1.mean()*100]
# stds  = [accs5.std(ddof=1)*100, accs1.std(ddof=1)*100]
# sns.barplot(x=['5-shot','1-shot'], y=means, yerr=stds, capsize=0.15)
# plt.ylabel('Mean Accuracy (%)')
# plt.title('Few-Shot Mean Accuracy ± Std')
# for i, v in enumerate(means):
#     plt.text(i, v + stds[i] + 0.5, f"{v:.2f}% ± {stds[i]:.2f}%", ha='center')
# bar_path = os.path.join(PLOTS_DIR, 'fewshot_mean_std_bar.png')
# plt.tight_layout(); plt.savefig(bar_path); plt.close()
# print("Saved mean±std bar ->", bar_path)

# # ---------- Plot 4: Episode-level scatter (optional insight) ----------
# plt.figure(figsize=(10,4))
# plt.plot(np.arange(len(accs5)), np.sort(accs5)[::-1], label='5-shot (sorted)', alpha=0.8)
# plt.plot(np.arange(len(accs1)), np.sort(accs1)[::-1], label='1-shot (sorted)', alpha=0.8)
# plt.xlabel('Episode index (sorted)')
# plt.ylabel('Accuracy')
# plt.title('Sorted Episode Accuracies (descending)')
# plt.legend()
# scatter_path = os.path.join(PLOTS_DIR, 'fewshot_sorted_episodes.png')
# plt.tight_layout(); plt.savefig(scatter_path); plt.close()
# print("Saved sorted-episodes plot ->", scatter_path)

# # ---------- Print quick summary to notebook output ----------
# print("\nFew-shot summary (printed):")
# print(df_summary.to_string(index=False))

# ### >>> NEW
# # Optional: add results to your combined summary CSV if it exists
# combined_csv = os.path.join(OUT_ROOT, 'combined_summary.csv')
# if os.path.exists(combined_csv):
#     try:
#         combined = pd.read_csv(combined_csv)
#         fewshot_row = {
#             'model': 'fewshot::swin_proto' if BACKBONE_NAME.startswith('swin') else f'fewshot::{BACKBONE_NAME}_proto',
#             'accuracy': df_summary.loc[df_summary['setting']=='5-way-5-shot','mean_acc'].values[0],
#             'macro_precision': np.nan, 'macro_recall': np.nan, 'macro_f1': np.nan
#         }
#         combined = pd.concat([combined, pd.DataFrame([fewshot_row])], ignore_index=True)
#         combined.to_csv(os.path.join(OUT_ROOT,'combined_summary_with_fewshot.csv'), index=False)
#         print("Appended few-shot summary to combined_summary -> combined_summary_with_fewshot.csv")
#     except Exception as e:
#         print("Could not append to combined_summary:", e)

# print("\nAll few-shot plots saved in:", PLOTS_DIR)


In [None]:
# # ============================================================
# #  FEW-SHOT PROTOTYPE NETWORK (READY TO PASTE)
# #  - integrates with your existing PyTorch pipeline
# #  - few-shot 5-way-5-shot & 5-way-1-shot classification
# #  - uses your existing timm backbones (Swin/DeiT/CoAtNet)
# # ============================================================

# ### >>> NEW
# import random, math
# import numpy as np
# import torch, torch.nn as nn, torch.optim as optim
# from torchvision import transforms
# from PIL import Image
# from collections import defaultdict

# device = DEVICE  # uses your existing device

# ### >>> NEW
# # ------------------------------------------------------------
# # 1. BUILD EMBEDDING EXTRACTOR (BACKBONE WITHOUT CLASSIFIER)
# # ------------------------------------------------------------
# def make_embedding_extractor(factory, embed_dim=512):
#     """
#     Converts any timm model into an embedding extractor by removing the classifier.
#     """
#     model = factory(num_classes)  # your factory already expects num_classes
#     model.reset_classifier(0) if hasattr(model, "reset_classifier") else None

#     # Projection head to get fixed-dim embeddings
#     feat_dim = getattr(model, "num_features", embed_dim)
#     proj = nn.Sequential(
#         nn.AdaptiveAvgPool2d(1),
#         nn.Flatten(),
#         nn.Linear(feat_dim, embed_dim)
#     )

#     model = model.to(device)
#     proj = proj.to(device)

#     # Freeze backbone, train projection only
#     for p in model.parameters():
#         p.requires_grad = False
#     for p in proj.parameters():
#         p.requires_grad = True

#     return model, proj


# ### >>> NEW
# # Choose backbone (you can switch between swin_tiny, deit_small, coatnet)
# BACKBONE_NAME = "swin_tiny"   # CHANGE if you want deit_small_distilled / coatnet_0
# backbone_factory = pytorch_factories[BACKBONE_NAME]

# backbone, proj_head = make_embedding_extractor(backbone_factory, embed_dim=512)


# # ------------------------------------------------------------
# # 2. FEW-SHOT EPISODE SAMPLER
# # ------------------------------------------------------------
# ### >>> NEW
# class FewShotDataset:
#     def __init__(self, items, img_size=(224,224)):
#         self.items = items
#         self.img_size = img_size

#         self.by_class = defaultdict(list)
#         for p, lbl in items:
#             self.by_class[lbl].append(p)

#         self.classes = sorted(self.by_class.keys())

#         self.transform = transforms.Compose([
#             transforms.Resize(img_size),
#             transforms.RandomResizedCrop(img_size[0], scale=(0.7,1.0)),
#             transforms.RandomHorizontalFlip(),
#             transforms.ColorJitter(0.3,0.3,0.2,0.05),
#             transforms.ToTensor(),
#             transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
#         ])

#     # sample N-way/K-shot episode
#     def sample_episode(self, N_way=5, K_shot=5, Q_query=5):
#         chosen = random.sample(self.classes, N_way)
#         support, support_labels = [], []
#         query, query_labels = [], []

#         for i, cls in enumerate(chosen):
#             imgs = random.sample(self.by_class[cls], K_shot + Q_query)
#             support.extend(imgs[:K_shot])
#             query.extend(imgs[K_shot:])
#             support_labels.extend([i]*K_shot)
#             query_labels.extend([i]*Q_query)

#         return support, support_labels, query, query_labels


# fs_train = FewShotDataset(train_items, img_size=IMG_SIZE)
# fs_val   = FewShotDataset(val_items, img_size=IMG_SIZE)


# # ------------------------------------------------------------
# # 3. EMBEDDING FUNCTION
# # ------------------------------------------------------------
# ### >>> NEW
# def embed_batch(backbone, proj, paths):
#     backbone.eval(); proj.eval()
#     embeds = []
#     tf = transforms.Compose([
#         transforms.Resize(IMG_SIZE),
#         transforms.ToTensor(),
#         transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
#     ])

#     with torch.no_grad():
#         for p in paths:
#             im = Image.open(p).convert("RGB")
#             x = tf(im).unsqueeze(0).to(device)

#             feats = backbone.forward_features(x) if hasattr(backbone,"forward_features") else backbone(x)
#             embedding = proj(feats)
#             embedding = nn.functional.normalize(embedding, dim=1)

#             embeds.append(embedding.cpu().numpy())

#     return np.vstack(embeds)


# # ------------------------------------------------------------
# # 4. PROTOTYPE CREATION
# # ------------------------------------------------------------
# ### >>> NEW
# def compute_prototypes(support_emb, support_labels, N_way):
#     D = support_emb.shape[1]
#     prototypes = np.zeros((N_way, D), dtype=np.float32)

#     for c in range(N_way):
#         idxs = [i for i,l in enumerate(support_labels) if l == c]
#         prototypes[c] = support_emb[idxs].mean(axis=0)

#     # normalize
#     prototypes /= (np.linalg.norm(prototypes, axis=1, keepdims=True) + 1e-9)
#     return prototypes


# # ------------------------------------------------------------
# # 5. EVALUATE SINGLE EPISODE
# # ------------------------------------------------------------
# ### >>> NEW
# def evaluate_episode(backbone, proj, fs_dataset, N_way=5, K_shot=5, Q_query=5):
#     support, s_lbl, query, q_lbl = fs_dataset.sample_episode(N_way, K_shot, Q_query)

#     s_emb = embed_batch(backbone, proj, support)
#     q_emb = embed_batch(backbone, proj, query)
#     prototypes = compute_prototypes(s_emb, s_lbl, N_way)

#     # cosine similarities (query embeddings and prototypes are both L2-normalized); higher = closer
#     logits = q_emb @ prototypes.T
#     preds = np.argmax(logits, axis=1)
#     acc = (preds == np.array(q_lbl)).mean()

#     return acc


# # ------------------------------------------------------------
# # 6. FULL FEW-SHOT EVALUATION (RUN THIS CELL)
# # ------------------------------------------------------------
# ### >>> NEW
# print("\n==================================================")
# print(" FEW-SHOT TESTING (5-way-5-shot and 5-way-1-shot)")
# print("==================================================\n")

# episodes = 200  # you can increase to 1000

# # 5-way-5-shot
# accs_5shot = []
# for _ in range(episodes):
#     accs_5shot.append(evaluate_episode(backbone, proj_head, fs_val, N_way=5, K_shot=5, Q_query=10))

# # 5-way-1-shot
# accs_1shot = []
# for _ in range(episodes):
#     accs_1shot.append(evaluate_episode(backbone, proj_head, fs_val, N_way=5, K_shot=1, Q_query=10))


# print(f"5-way 5-shot accuracy: {np.mean(accs_5shot)*100:.2f}%")
# print(f"5-way 1-shot accuracy: {np.mean(accs_1shot)*100:.2f}%")


In [None]:
# # ============================================================
# # HYBRID FUSION FEW-SHOT PROTOTYPE NETWORK
# # - Fuse embeddings from Swin, DeiT, CoAtNet
# # - Evaluate 5-way 5-shot and 5-way 1-shot
# # - Produces plots and appends to combined_summary if present
# # Paste this cell AFTER your existing FEW-SHOT block
# # ============================================================

# import numpy as np, os
# from tqdm import tqdm

# # --- 1) Build embedding extractors (frozen backbones + small proj heads) ---
# # use embed_dim 256 per backbone to keep fused dim manageable
# EMBED_DIM = 256

# # create three independent embedding extractors (backbone + proj)
# backbone_swin, proj_swin = make_embedding_extractor(pytorch_factories['swin_tiny'], embed_dim=EMBED_DIM)
# backbone_deit, proj_deit = make_embedding_extractor(pytorch_factories['deit_small_distilled'], embed_dim=EMBED_DIM)
# backbone_coat, proj_coat = make_embedding_extractor(pytorch_factories['coatnet_0'], embed_dim=EMBED_DIM)

# # Put them in eval mode. Note: make_embedding_extractor creates fresh, randomly initialised proj heads and nothing here trains them, so they act as fixed random projections of the frozen backbone features.
# backbone_swin.eval(); proj_swin.eval()
# backbone_deit.eval(); proj_deit.eval()
# backbone_coat.eval(); proj_coat.eval()

# # --- 2) fused-embedding helper that calls your existing embed_batch() ---
# def embed_fusion(paths):
#     """
#     Given list of image file paths -> returns L2-normalized fused embeddings (N x (3*EMBED_DIM))
#     Uses embed_batch(backbone, proj, paths) already defined in notebook.
#     """
#     # compute each embedding (these return numpy arrays N x EMBED_DIM)
#     emb_s = embed_batch(backbone_swin, proj_swin, paths)
#     emb_d = embed_batch(backbone_deit, proj_deit, paths)
#     emb_c = embed_batch(backbone_coat, proj_coat, paths)

#     # concat
#     fused = np.concatenate([emb_s, emb_d, emb_c], axis=1)
#     # L2 normalize per-sample
#     fused = fused / (np.linalg.norm(fused, axis=1, keepdims=True) + 1e-9)
#     return fused

# # --- 3) reuse compute_prototypes (already in notebook) ---
# # compute_prototypes(support_emb, support_labels, N_way) -> prototypes normalized

# # --- 4) evaluate a single fusion episode ---
# def evaluate_episode_fusion(fs_dataset, N_way=5, K_shot=5, Q_query=10):
#     support, s_lbl, query, q_lbl = fs_dataset.sample_episode(N_way=N_way, K_shot=K_shot, Q_query=Q_query)
#     s_emb = embed_fusion(support)
#     q_emb = embed_fusion(query)

#     prototypes = compute_prototypes(s_emb, s_lbl, N_way)
#     logits = q_emb @ prototypes.T   # cosine-like scores because embeddings normalized
#     preds = np.argmax(logits, axis=1)
#     acc = (preds == np.array(q_lbl)).mean()
#     return acc

# # --- 5) Run evaluation (episodes) ---
# episodes = 200   # increase to 1000 for final runs if you have time

# print("\nRunning Hybrid-Fusion few-shot evaluation (this may take a while)...")
# accs_fusion_5shot = []
# for _ in tqdm(range(episodes), desc='fusion 5-shot'):
#     accs_fusion_5shot.append(evaluate_episode_fusion(fs_val, N_way=5, K_shot=5, Q_query=10))

# accs_fusion_1shot = []
# for _ in tqdm(range(episodes), desc='fusion 1-shot'):
#     accs_fusion_1shot.append(evaluate_episode_fusion(fs_val, N_way=5, K_shot=1, Q_query=10))

# # --- 6) Summarize & save results + plots (integrates with your PLOTS_DIR + OUT_ROOT) ---
# import matplotlib.pyplot as plt
# import seaborn as sns
# import pandas as pd

# PLOTS_DIR = globals().get('PLOTS_DIR', os.path.join(OUT_ROOT, 'plots'))
# os.makedirs(PLOTS_DIR, exist_ok=True)

# acc5 = np.array(accs_fusion_5shot)
# acc1 = np.array(accs_fusion_1shot)

# summary = {
#     'setting': ['5-way-5-shot', '5-way-1-shot'],
#     'mean_acc': [acc5.mean(), acc1.mean()],
#     'std_acc':  [acc5.std(ddof=1), acc1.std(ddof=1)],
#     'median':   [np.median(acc5), np.median(acc1)],
#     'n_episodes':[len(acc5), len(acc1)]
# }
# df_summary = pd.DataFrame(summary)
# csv_out = os.path.join(OUT_ROOT, 'fewshot_hybrid_fusion_summary.csv')
# df_summary.to_csv(csv_out, index=False)
# print("Saved fusion summary CSV ->", csv_out)

# # Histogram
# plt.figure(figsize=(8,5))
# sns.histplot(acc5, label='fusion 5-shot', stat='density', kde=True, alpha=0.6)
# sns.histplot(acc1, label='fusion 1-shot', stat='density', kde=True, alpha=0.6)
# plt.xlabel('Episode Accuracy'); plt.title('Fusion Few-Shot Episode Accuracy')
# plt.legend()
# hist_path = os.path.join(PLOTS_DIR, 'fusion_fewshot_accuracy_histogram.png')
# plt.tight_layout(); plt.savefig(hist_path); plt.close()
# print("Saved histogram ->", hist_path)

# # Boxplot
# plt.figure(figsize=(6,5))
# sns.boxplot(data=[acc5, acc1])
# plt.xticks([0,1], ['5-way-5-shot', '5-way-1-shot'])
# plt.ylabel('Episode Accuracy')
# plt.title('Fusion Few-Shot Accuracy Boxplot')
# box_path = os.path.join(PLOTS_DIR, 'fusion_fewshot_boxplot.png')
# plt.tight_layout(); plt.savefig(box_path); plt.close()
# print("Saved boxplot ->", box_path)

# # Mean ± std bar
# plt.figure(figsize=(6,4))
# means = [acc5.mean()*100, acc1.mean()*100]
# stds  = [acc5.std(ddof=1)*100, acc1.std(ddof=1)*100]
# sns.barplot(x=['5-shot','1-shot'], y=means, yerr=stds, capsize=0.15)
# plt.ylabel('Mean Accuracy (%)'); plt.title('Fusion Few-Shot Mean Accuracy ± Std')
# for i, v in enumerate(means):
#     plt.text(i, v + stds[i] + 0.5, f"{v:.2f}% ± {stds[i]:.2f}%", ha='center')
# bar_path = os.path.join(PLOTS_DIR, 'fusion_fewshot_mean_std_bar.png')
# plt.tight_layout(); plt.savefig(bar_path); plt.close()
# print("Saved mean±std bar ->", bar_path)

# # Print summary
# print("\nHybrid Fusion few-shot summary:")
# print(df_summary.to_string(index=False))

# # --- 7) Append to combined_summary.csv if present (safe) ---
# combined_csv = os.path.join(OUT_ROOT, 'combined_summary.csv')
# try:
#     if os.path.exists(combined_csv):
#         combined = pd.read_csv(combined_csv)
#         fewshot_row = {
#             'model': 'fewshot::hybrid_fusion',
#             'accuracy': float(df_summary.loc[df_summary['setting']=='5-way-5-shot','mean_acc'].values[0]),
#             'macro_precision': np.nan, 'macro_recall': np.nan, 'macro_f1': np.nan
#         }
#         combined = pd.concat([combined, pd.DataFrame([fewshot_row])], ignore_index=True)
#         combined.to_csv(os.path.join(OUT_ROOT,'combined_summary_with_fewshot_hybrid.csv'), index=False)
#         print("Appended fusion summary to combined_summary -> combined_summary_with_fewshot_hybrid.csv")
# except Exception as e:
#     print("Could not append to combined_summary:", e)

# print("\nFusion few-shot artifacts saved in:", PLOTS_DIR)
