In [25]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        (os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [26]:
# run only if timm missing
!pip install timm --quiet


In [27]:
# Cell 1
import os
from pathlib import Path
import random
import math
import numpy as np
from PIL import Image
from collections import Counter
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import timm
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, roc_curve, auc, precision_recall_curve
from sklearn.manifold import TSNE

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

# CONFIG - edit these
DATA_DIR = "/kaggle/input/breakhis-400x"   # <-- change if necessary
OUT_DIR = Path("/kaggle/working/output")
OUT_DIR.mkdir(parents=True, exist_ok=True)
IMG_SIZE = 224
BATCH_SIZE = 8           # reduce if OOM
EPOCHS = 12
LR = 1e-4
RANDOM_SEED = 42
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Multi-scale + adaptive selection
SCALES = [1.0, 0.5]      # scales relative to IMG_SIZE; reduce if OOM
TOP_K = 16               # select top-K tokens per image
GRU_HIDDEN = 512
GRU_LAYERS = 1
BIDIRECTIONAL = False

# Mix/Cut params
MIXPROB = 0.5
MIXUP_ALPHA = 0.2
CUTMIX_ALPHA = 1.0

# Progressive freezing schedule (epochs)
FREEZE_STAGE_1 = 2   # freeze backbone initially
FREEZE_STAGE_2 = 5   # partially unfreeze later

SEED = RANDOM_SEED
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print("Device:", DEVICE)


Device: cuda


In [28]:
# Cell 2
VALID_EXTS = {'.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff'}

def safe_load_image(path, size=IMG_SIZE):
    try:
        img = Image.open(path).convert("RGB")
        img = img.resize((size, size), Image.BILINEAR)
        arr = np.array(img)
        if arr.dtype != np.uint8:
            arr = arr.astype(np.uint8)
        if arr.ndim != 3 or arr.shape[2] != 3:
            arr = np.zeros((size, size, 3), dtype=np.uint8)
        return arr
    except Exception:
        return np.zeros((size, size, 3), dtype=np.uint8)

data_root = Path(DATA_DIR)
assert data_root.exists(), f"{DATA_DIR} not found"

classes = [d.name for d in sorted(data_root.iterdir()) if d.is_dir()]
print("Detected classes:", classes)

filepaths = []
labels = []
for i, cls in enumerate(classes):
    folder = data_root / cls
    for p in sorted(folder.rglob("*")):
        if p.is_file() and p.suffix.lower() in VALID_EXTS:
            filepaths.append(str(p))
            labels.append(i)
print("Total samples:", len(filepaths))


Detected classes: ['test', 'train', 'valid']
Total samples: 1693


In [29]:
# Cell 3
trainval_paths, test_paths, trainval_labels, test_labels = train_test_split(
    filepaths, labels, test_size=0.2, random_state=SEED, stratify=labels
)
train_paths, val_paths, train_labels, val_labels = train_test_split(
    trainval_paths, trainval_labels, test_size=0.125, random_state=SEED, stratify=trainval_labels
)  # ~10% val overall

print("train:", len(train_paths), "val:", len(val_paths), "test:", len(test_paths))


train: 1184 val: 170 test: 339


In [30]:
# Cell 4
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

class ImageFileDataset(Dataset):
    def __init__(self, paths, labels, transform=None, img_size=IMG_SIZE):
        self.paths = paths
        self.labels = labels
        self.transform = transform
        self.img_size = img_size
    def __len__(self):
        return len(self.paths)
    def __getitem__(self, idx):
        p = self.paths[idx]
        arr = safe_load_image(p, size=self.img_size)
        img = Image.fromarray(arr)
        if self.transform:
            img = self.transform(img)
        label = int(self.labels[idx])
        return img, label

train_ds = ImageFileDataset(train_paths, train_labels, transform=train_transform)
val_ds = ImageFileDataset(val_paths, val_labels, transform=eval_transform)
test_ds = ImageFileDataset(test_paths, test_labels, transform=eval_transform)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)


In [31]:
# Cell 5
def rand_bbox(W, H, lam):
    cut_rat = math.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2

def mixup_data(x, y, alpha=MIXUP_ALPHA):
    if alpha <= 0:
        return x, y, None, 1.0
    lam = np.random.beta(alpha, alpha)
    index = torch.randperm(x.size(0)).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def cutmix_data(x, y, alpha=CUTMIX_ALPHA):
    if alpha <= 0:
        return x, y, None, 1.0
    lam = np.random.beta(alpha, alpha)
    index = torch.randperm(x.size(0)).to(x.device)
    B, C, H, W = x.size()
    bbx1, bby1, bbx2, bby2 = rand_bbox(W, H, lam)
    x[:, :, bby1:bby2, bbx1:bbx2] = x[index, :, bby1:bby2, bbx1:bbx2]
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (W * H))
    y_a, y_b = y, y[index]
    return x, y_a, y_b, lam

def apply_mix_augment(images, labels):
    r = random.random()
    if r < MIXPROB/2:
        return mixup_data(images, labels, alpha=MIXUP_ALPHA)
    elif r < MIXPROB:
        return cutmix_data(images, labels, alpha=CUTMIX_ALPHA)
    else:
        return images, labels, None, 1.0


In [32]:
# Cell 6
class BalancedFocalLoss(nn.Module):
    def __init__(self, gamma=2.0, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets, class_weights=None):
        # logits: [B, C]; targets: [B]
        logp = F.log_softmax(logits, dim=1)
        p = torch.exp(logp)
        # cross entropy per-sample
        if class_weights is not None:
            ce = F.nll_loss(logp, targets, reduction='none', weight=class_weights)
        else:
            ce = F.nll_loss(logp, targets, reduction='none')
        pt = p.gather(1, targets.unsqueeze(1)).squeeze(1)
        focal = (1 - pt) ** self.gamma
        loss = focal * ce
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        return loss


In [33]:
# Cell 7 (FIXED)
class MultiScaleHybridModel(nn.Module):
    def __init__(self, swin_name='swin_tiny_patch4_window7_224', pretrained=True,
                 scales=SCALES, top_k=TOP_K, gru_hidden=GRU_HIDDEN, gru_layers=GRU_LAYERS,
                 bidirectional=BIDIRECTIONAL, num_classes=None, device=DEVICE):
        super().__init__()
        self.scales = scales
        self.top_k = top_k
        self.device = device
        self.gru_hidden = gru_hidden
        self.num_directions = 2 if bidirectional else 1

        # Backbone (feature map)
        self.backbone_feats = timm.create_model(
            swin_name, pretrained=pretrained, features_only=True, out_indices=[-1]
        ).to(device)

        # Backbone (global pooled)
        self.backbone_pool = timm.create_model(
            swin_name, pretrained=pretrained, num_classes=0, global_pool="avg"
        ).to(device)

        # Detect feature map channels
        dummy = torch.zeros(1, 3, IMG_SIZE, IMG_SIZE).to(device)
        with torch.no_grad():
            feat_out = self.backbone_feats(dummy)
            feat_out = feat_out[0] if isinstance(feat_out, list) else feat_out
            self.feat_c = feat_out.shape[1]   # feature map channels

            pooled_out = self.backbone_pool(dummy)
            if pooled_out.dim() == 4:
                pooled_out = pooled_out.mean([2, 3])
            self.pooled_dim = pooled_out.shape[1]   # pooled feature dim

        # GRU
        self.gru = nn.GRU(
            input_size=self.feat_c,
            hidden_size=gru_hidden,
            num_layers=gru_layers,
            batch_first=True,
            bidirectional=bidirectional
        ).to(device)

        # adapter for pooled â†’ GRU dimension
        self.pooled_adapter = nn.Linear(self.pooled_dim, gru_hidden*self.num_directions).to(device)

        # NOW fusion_in is correct
        fusion_in = gru_hidden*self.num_directions + gru_hidden*self.num_directions

        self.fusion_gate = nn.Sequential(
            nn.Linear(fusion_in, fusion_in // 2),
            nn.GELU(),
            nn.Linear(fusion_in // 2, gru_hidden * self.num_directions),
            nn.Sigmoid()
        ).to(device)

        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(gru_hidden*self.num_directions, gru_hidden),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(gru_hidden, num_classes)
        ).to(device)

    def _get_feature_map(self, imgs, scale):
        if scale != 1.0:
            size = int(IMG_SIZE * scale)
            imgs = F.interpolate(imgs, size=(size, size), mode='bilinear', align_corners=False)
        out = self.backbone_feats(imgs)
        return out[0] if isinstance(out, list) else out

    def forward(self, x):
        B = x.size(0)

        if x.size(2) != IMG_SIZE or x.size(3) != IMG_SIZE:
            x = F.interpolate(x, size=(IMG_SIZE, IMG_SIZE), mode='bilinear')

        pooled = self.backbone_pool(x)
        if pooled.dim() == 4:
            pooled = pooled.mean([2, 3])

        pooled_proj = self.pooled_adapter(pooled)

        # Multi-scale patches
        tokens_multi = []
        for scale in self.scales:
            fmap = self._get_feature_map(x, scale)
            Bf, Cf, Hf, Wf = fmap.shape
            tokens = fmap.view(Bf, Cf, -1).permute(0, 2, 1)
            tokens_multi.append(tokens)

        tokens_all = torch.cat(tokens_multi, dim=1)

        # Adaptive slice-selection
        var = tokens_all.var(dim=2)
        k = min(self.top_k, tokens_all.size(1))
        _, idx = torch.topk(var, k, dim=1)
        batch = torch.arange(B).unsqueeze(1).to(x.device)
        selected = tokens_all[batch, idx]

        gru_out, _ = self.gru(selected)
        last = gru_out[:, -1, :]

        # Hybrid Attention Fusion
        fusion_input = torch.cat([pooled_proj, last], dim=1)
        gate = self.fusion_gate(fusion_input)
        fused = gate * last + (1 - gate) * pooled_proj

        return self.classifier(fused)


In [34]:
# Cell 8
num_classes = len(classes)
model = MultiScaleHybridModel(
    swin_name='swin_tiny_patch4_window7_224',
    pretrained=True,
    scales=SCALES,
    top_k=TOP_K,
    gru_hidden=GRU_HIDDEN,
    gru_layers=GRU_LAYERS,
    bidirectional=BIDIRECTIONAL,
    num_classes=num_classes,
    device=DEVICE
)

# class weights (inverse freq)
cnt = Counter(train_labels)
freq = np.array([cnt[i] for i in range(num_classes)], dtype=float)
inv = (freq.sum() / (freq + 1e-8))
class_weights = torch.tensor(inv / inv.sum(), dtype=torch.float32).to(DEVICE)

criterion = BalancedFocalLoss(gamma=2.0).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

print("Model created. Device:", DEVICE)


Model created. Device: cuda


In [35]:
# Cell 9
def compute_loss(logits, y_info, criterion, class_weights=None):
    # y_info: (y_a, y_b, lam) or (labels, None, 1.0)
    y_a, y_b, lam = y_info
    if y_b is None:
        return criterion(logits, y_a, class_weights)
    else:
        la = criterion(logits, y_a, class_weights)
        lb = criterion(logits, y_b, class_weights)
        return lam * la + (1 - lam) * lb

def set_requires_grad(module, flag):
    if module is None:
        return
    for p in module.parameters():
        p.requires_grad = flag


In [36]:
# Cell 10
best_val_f1 = -1.0
best_path = OUT_DIR/"best_hybrid.pth"

history = {'train_loss':[], 'train_acc':[], 'val_loss':[], 'val_acc':[], 'val_f1':[]}

# Stage 1 freeze backbone
set_requires_grad(model.backbone_feats, False)
set_requires_grad(model.backbone_pool, False)
# pooled_adapter may be Identity currently

for epoch in range(1, EPOCHS+1):
    model.train()
    # progressive unfreeze
    if epoch == FREEZE_STAGE_1 + 1:
        set_requires_grad(model.backbone_pool, True)
    if epoch == FREEZE_STAGE_2 + 1:
        set_requires_grad(model.backbone_feats, True)
        set_requires_grad(model.backbone_pool, True)
        if not isinstance(model.pooled_adapter, nn.Identity):
            set_requires_grad(model.pooled_adapter, True)

    running_loss = 0.0
    preds_all = []
    targets_all = []

    for imgs, labels in tqdm(train_loader, leave=False):
        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)

        imgs_aug, y_a, y_b, lam = apply_mix_augment(imgs, labels)
        # imgs_aug already device-matched
        # ensure correct size
        if imgs_aug.size(2) != IMG_SIZE or imgs_aug.size(3) != IMG_SIZE:
            imgs_aug = F.interpolate(imgs_aug, size=(IMG_SIZE, IMG_SIZE), mode='bilinear', align_corners=False)

        optimizer.zero_grad()
        logits = model(imgs_aug)
        y_info = (y_a, y_b, lam) if y_b is not None else (y_a, None, 1.0)
        loss = compute_loss(logits, y_info, criterion, class_weights)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        preds_all.extend(torch.argmax(logits, dim=1).detach().cpu().numpy().tolist())
        targets_all.extend(labels.detach().cpu().numpy().tolist())

    train_loss = running_loss / len(train_loader)
    train_acc = accuracy_score(targets_all, preds_all)

    # validation
    model.eval()
    v_loss = 0.0
    v_preds = []
    v_targets = []
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs = imgs.to(DEVICE)
            labels = labels.to(DEVICE)
            if imgs.size(2) != IMG_SIZE or imgs.size(3) != IMG_SIZE:
                imgs = F.interpolate(imgs, size=(IMG_SIZE, IMG_SIZE), mode='bilinear', align_corners=False)
            logits = model(imgs)
            loss = criterion(logits, labels, class_weights)
            v_loss += loss.item()
            v_preds.extend(torch.argmax(logits, dim=1).cpu().numpy().tolist())
            v_targets.extend(labels.cpu().numpy().tolist())

    val_loss = v_loss / len(val_loader)
    val_acc = accuracy_score(v_targets, v_preds)
    val_prec, val_rec, val_f1, _ = precision_recall_fscore_support(v_targets, v_preds, average='weighted', zero_division=0)

    history['train_loss'].append(train_loss)
    history['train_acc'] = history.get('train_acc', []) + [train_acc]
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['val_f1'].append(val_f1)

    print(f"Epoch {epoch}/{EPOCHS} TrainLoss {train_loss:.4f} TrainAcc {train_acc:.4f} ValLoss {val_loss:.4f} ValAcc {val_acc:.4f} ValF1 {val_f1:.4f}")

    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save(model.state_dict(), best_path)
        print("Saved best model:", best_path)

    scheduler.step()

print("Training finished. Best val f1:", best_val_f1)


  0%|          | 0/148 [00:00<?, ?it/s]

AssertionError: Input height (112) doesn't match model (224).

In [None]:
# Cell 11
model.load_state_dict(torch.load(best_path, map_location=DEVICE))
model.eval()

y_true = []
y_pred = []
y_proba = []

with torch.no_grad():
    for imgs, labels in tqdm(test_loader, leave=False):
        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)
        if imgs.size(2) != IMG_SIZE or imgs.size(3) != IMG_SIZE:
            imgs = F.interpolate(imgs, size=(IMG_SIZE, IMG_SIZE), mode='bilinear', align_corners=False)
        logits = model(imgs)
        probs = F.softmax(logits, dim=1)[:,1].cpu().numpy()  # prob class1
        preds = torch.argmax(logits, dim=1).cpu().numpy().tolist()
        y_true.extend(labels.cpu().numpy().tolist())
        y_pred.extend(preds)
        y_proba.extend(probs.tolist())

acc = accuracy_score(y_true, y_pred)
prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted', zero_division=0)
cm = confusion_matrix(y_true, y_pred)
print("Test Acc:", acc, "F1:", f1)
print("Confusion Matrix:\n", cm)


In [None]:
# Cell 12
epochs = list(range(1, len(history['train_loss'])+1))

plt.figure(figsize=(8,5))
plt.plot(epochs, history['train_loss'], marker='o', label='train_loss')
plt.plot(epochs, history['val_loss'], marker='o', label='val_loss')
plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.legend(); plt.grid(True)
plt.savefig(OUT_DIR/"loss_curve.png", dpi=300); plt.show()

plt.figure(figsize=(8,5))
plt.plot(epochs, history['train_acc'], marker='o', label='train_acc')
plt.plot(epochs, history['val_acc'], marker='o', label='val_acc')
plt.xlabel("Epoch"); plt.ylabel("Accuracy"); plt.legend(); plt.grid(True)
plt.savefig(OUT_DIR/"acc_curve.png", dpi=300); plt.show()

plt.figure(figsize=(8,5))
plt.plot(epochs, history['val_f1'], marker='o', label='val_f1')
plt.xlabel("Epoch"); plt.ylabel("F1"); plt.legend(); plt.grid(True)
plt.savefig(OUT_DIR/"f1_curve.png", dpi=300); plt.show()


In [None]:
# Cell 13
plt.figure(figsize=(6,5))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=classes, yticklabels=classes, cmap='Blues')
plt.xlabel("Predicted"); plt.ylabel("Actual"); plt.title("Confusion Matrix")
plt.savefig(OUT_DIR/"confusion_matrix.png", dpi=300); plt.show()

fpr, tpr, _ = roc_curve(y_true, y_proba)
roc_auc = auc(fpr, tpr)
plt.figure(figsize=(6,5)); plt.plot(fpr, tpr, label=f"AUC={roc_auc:.4f}"); plt.xlabel("FPR"); plt.ylabel("TPR"); plt.legend(); plt.grid(True)
plt.savefig(OUT_DIR/"roc_curve.png", dpi=300); plt.show()

prec_vals, recall_vals, _ = precision_recall_curve(y_true, y_proba)
plt.figure(figsize=(6,5)); plt.plot(recall_vals, prec_vals); plt.xlabel("Recall"); plt.ylabel("Precision"); plt.grid(True)
plt.savefig(OUT_DIR/"pr_curve.png", dpi=300); plt.show()


In [None]:
# Cell 14
embs = []
labels_list = []
with torch.no_grad():
    for imgs, labels in tqdm(test_loader, leave=False):
        imgs = imgs.to(DEVICE)
        if imgs.size(2) != IMG_SIZE or imgs.size(3) != IMG_SIZE:
            imgs = F.interpolate(imgs, size=(IMG_SIZE, IMG_SIZE), mode='bilinear', align_corners=False)
        feat = model.backbone_pool(imgs)
        if feat.dim() == 4:
            feat = feat.mean(dim=[2,3])
        embs.append(feat.cpu().numpy())
        labels_list.extend(labels)
embs = np.concatenate(embs, axis=0)
tsne = TSNE(n_components=2, random_state=SEED)
tsne_feats = tsne.fit_transform(embs)

plt.figure(figsize=(8,6))
plt.scatter(tsne_feats[:,0], tsne_feats[:,1], c=labels_list, cmap='coolwarm', s=6)
plt.colorbar(ticks=range(len(classes))); plt.title("t-SNE of global features")
plt.savefig(OUT_DIR/"tsne.png", dpi=300); plt.show()


In [None]:
# Cell 15
def grad_cam_simple(model, img_tensor, target_class=None):
    model.eval()
    activations = []
    gradients = []
    def forward_hook(module, inp, out):
        activations.append(out.detach())
    def backward_hook(module, grad_in, grad_out):
        gradients.append(grad_out[0].detach())

    # Hook heuristic: last module of backbone_feats
    target_module = None
    for name, m in model.backbone_feats.named_modules():
        target_module = m
    if target_module is None:
        print("No suitable module for hook")
        return None

    h_f = target_module.register_forward_hook(forward_hook)
    h_b = target_module.register_backward_hook(backward_hook)

    img = img_tensor.unsqueeze(0).to(DEVICE)
    img.requires_grad = True
    logits = model(img)
    pred = logits.argmax(dim=1).item()
    if target_class is None:
        target_class = pred
    loss = logits[0, target_class]
    model.zero_grad()
    loss.backward(retain_graph=True)

    if len(activations) == 0 or len(gradients) == 0:
        h_f.remove(); h_b.remove()
        return None

    act = activations[-1][0]  # [C,H,W]
    grad = gradients[-1][0]   # [C,H,W]
    weights = grad.mean(dim=(1,2))  # [C]
    cam = (weights.view(-1,1,1) * act).sum(dim=0).cpu().numpy()
    cam = np.maximum(cam, 0)
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
    h_f.remove(); h_b.remove()
    return cam, pred

# example use:
sample_path = test_paths[0]
arr = safe_load_image(sample_path, IMG_SIZE)
pil = transforms.ToTensor()(Image.fromarray(arr))
pil = transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])(pil)
res = grad_cam_simple(model, pil, target_class=None)
if res is not None:
    cam, pred = res
    plt.imshow(arr); plt.imshow(cam, cmap='jet', alpha=0.4); plt.title(f"Pred {classes[pred]}"); plt.axis('off')
    plt.savefig(OUT_DIR/"gradcam_example.png", dpi=300); plt.show()
else:
    print("Grad-CAM failed to produce map")


In [None]:
# Cell 16
torch.save(model.state_dict(), OUT_DIR/"final_model.pth")

import pandas as pd
df_hist = pd.DataFrame({
    'epoch': list(range(1, len(history['train_loss'])+1)),
    'train_loss': history['train_loss'],
    'val_loss': history['val_loss'],
    'train_acc': history['train_acc'],
    'val_acc': history['val_acc'],
    'val_f1': history['val_f1']
})
df_hist.to_csv(OUT_DIR/"training_history.csv", index=False)
print("Saved model and history to", OUT_DIR)


In [None]:
# Cell 17
torch.save(model.state_dict(), OUT_DIR/"final_model.pth")
import pandas as pd
df_hist = pd.DataFrame({
    'epoch': list(range(1, len(history['train_loss'])+1)),
    'train_loss': history['train_loss'],
    'val_loss': history['val_loss'],
    'train_acc': history['train_acc'],
    'val_acc': history['val_acc'],
    'val_f1': history['val_f1']
})
df_hist.to_csv(OUT_DIR/"training_history.csv", index=False)
print("Saved model and history to", OUT_DIR)
