In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        (os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
# run if timm not present
!pip install timm --quiet


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m104.1 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m84.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m43.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m31.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [3]:
import os
from pathlib import Path
import random
import math
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import timm

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, roc_curve, auc, precision_recall_curve
import matplotlib.pyplot as plt
import seaborn as sns

# CONFIG - edit paths / hyperparams
DATA_DIR = "/kaggle/input/breakhis-400x/train"   # change if needed
OUT_DIR = Path("./outputs")
OUT_DIR.mkdir(exist_ok=True)
IMG_SIZE = 224
BATCH_SIZE = 8
EPOCHS = 12
LR = 1e-4
RANDOM_SEED = 42
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Patch extraction / selection settings
PATCH_SCALES = [1/4, 1/8]   # fractions of image side -> coarse and fine (e.g., 1/4 -> 56x56 for 224)
TOP_K_PATCHES = 16          # number of patches selected per image
PATCH_RESIZE = 128          # each patch will be resized to this before feeding backbone (to balance detail)

# Progressive freezing schedule (in epochs)
FREEZE_STAGE_1 = 2  # freeze backbone, train GRU+classifier
FREEZE_STAGE_2 = 4  # unfreeze backbone last layer
# Stage 3: unfreeze all (remaining epochs)

# Mixup / CutMix
MIXUP_ALPHA = 0.2
CUTMIX_ALPHA = 1.0
MIXPROB = 0.5  # probability to apply mix augmentation

SEED = RANDOM_SEED
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)




<torch._C.Generator at 0x7c142bcf4830>

In [4]:
VALID_EXTS = {'.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff'}

def safe_load_image(p, size=IMG_SIZE):
    try:
        img = Image.open(p).convert("RGB")
        img = img.resize((size, size), Image.BILINEAR)
        arr = np.array(img)
        if arr.dtype != np.uint8:
            arr = arr.astype(np.uint8)
        if arr.ndim != 3 or arr.shape[2] != 3:
            arr = np.zeros((size, size, 3), dtype=np.uint8)
        return arr
    except Exception:
        return np.zeros((size, size, 3), dtype=np.uint8)

# Build file list and classes (assumes structure DATA_DIR/class_x/images)
data_root = Path(DATA_DIR)
assert data_root.exists(), f"{DATA_DIR} not found"
classes = [d.name for d in sorted(data_root.iterdir()) if d.is_dir()]
print("Classes detected:", classes)
filepaths = []
labels = []
for i, cls in enumerate(classes):
    p = data_root/cls
    for f in sorted(p.rglob("*")):
        if f.suffix.lower() in VALID_EXTS and f.is_file():
            filepaths.append(str(f))
            labels.append(i)
print("Total samples:", len(filepaths))


Classes detected: ['benign', 'malignant']
Total samples: 1184


In [5]:
train_val_paths, test_paths, train_val_labels, test_labels = train_test_split(
    filepaths, labels, test_size=0.2, random_state=SEED, stratify=labels
)
train_paths, val_paths, train_labels, val_labels = train_test_split(
    train_val_paths, train_val_labels, test_size=0.125, random_state=SEED, stratify=train_val_labels
)  # ~10% val of full

print("train", len(train_paths), "val", len(val_paths), "test", len(test_paths))


train 828 val 119 test 237


In [6]:
def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2

def mixup_data(x, y, alpha=MIXUP_ALPHA):
    if alpha <= 0:
        return x, y, 1.0, None
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def cutmix_data(x, y, alpha=CUTMIX_ALPHA):
    if alpha <= 0:
        return x, y, 1.0, None
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(x.device)
    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    x[:, :, bby1:bby2, bbx1:bbx2] = x[index, :, bby1:bby2, bbx1:bbx2]
    # adjust lambda according to pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
    y_a, y_b = y, y[index]
    return x, y_a, y_b, lam


In [7]:
class BalancedFocalLoss(nn.Module):
    def __init__(self, gamma=2.0, reduction='mean', eps=1e-9):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction
        self.eps = eps
        self.ce = nn.CrossEntropyLoss(reduction='none')

    def forward(self, logits, targets, class_weights=None):
        """
        logits: [B, C], targets: [B]
        class_weights: tensor of shape [C] or None
        """
        logp = F.log_softmax(logits, dim=1)
        p = torch.exp(logp)
        ce_loss = F.nll_loss(logp, targets, reduction='none', weight=class_weights)
        pt = p.gather(1, targets.unsqueeze(1)).squeeze(1)  # p_t
        focal_factor = (1 - pt) ** self.gamma
        loss = focal_factor * ce_loss
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        return loss


In [8]:
# transforms for raw input (basic augmentations)
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])

class ImageFileDataset(Dataset):
    def __init__(self, paths, labels, transform=None):
        self.paths = paths
        self.labels = labels
        self.transform = transform
    def __len__(self):
        return len(self.paths)
    def __getitem__(self, idx):
        p = self.paths[idx]
        arr = safe_load_image(p, size=IMG_SIZE)
        img = Image.fromarray(arr)
        if self.transform:
            img = self.transform(img)
        label = self.labels[idx]
        return img, label

train_ds = ImageFileDataset(train_paths, train_labels, transform=train_transform)
val_ds = ImageFileDataset(val_paths, val_labels, transform=eval_transform)
test_ds = ImageFileDataset(test_paths, test_labels, transform=eval_transform)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)


In [9]:
class HybridSwinGRU(nn.Module):
    def __init__(self, swin_name='swin_tiny_patch4_window7_224', pretrained=True, 
                 patch_scales=PATCH_SCALES, top_k=TOP_K_PATCHES, patch_resize=PATCH_RESIZE,
                 gru_hidden=512, gru_layers=1, num_classes=2, bidirectional=False, dropout=0.3):
        super().__init__()
        self.patch_scales = patch_scales
        self.top_k = top_k
        self.patch_resize = patch_resize

        # backbone for patches & global image (use same weights)
        # use features_only=False but create a model exposing forward_features if available
        self.backbone = timm.create_model(swin_name, pretrained=pretrained, num_classes=0, global_pool='avg')
        # some timm models expose forward_features; otherwise we call backbone(x) which gives pooled features
        # get feature dim by a dummy pass later
        # GRU and classifier will be initialized after we detect feat_dim
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1
        self.gru_hidden = gru_hidden
        self.gru_layers = gru_layers
        self.dropout = nn.Dropout(dropout)
        self.gelu = nn.GELU()

        # We'll create GRU & classifier lazily once we know the feature dim
        self.feat_dim = None
        self.gru = None
        self.classifier = None

        # Hybrid attention fusion: small MLP that takes [global_feat || gru_feat] -> gate
        self.fusion_gate = None

    def _extract_patches(self, imgs):
        """
        imgs: tensor [B, C, H, W] in range [0,1] normalized already
        returns: list of lists of patch tensors per image:
            patches_tensor: [B, num_patches_total, C, patchH, patchW]
        We'll produce patches per scale, non-overlapping grid, resize each patch to patch_resize.
        """
        B, C, H, W = imgs.shape
        device = imgs.device
        all_patches = []
        for b in range(B):
            img = imgs[b]  # C,H,W
            patches_for_image = []
            # convert to PIL for cropping easier (or use tensor ops). We'll use tensor ops for speed.
            # convert to [H,W,C] numpy for simple crop: but that would move to CPU.
            # Instead do tensor cropping.
            for scale in self.patch_scales:
                side = int(H * scale)
                if side < 8:
                    continue
                # number of patches per side
                num = max(1, H // side)
                stride = side
                for y in range(0, H - side + 1, stride):
                    for x in range(0, W - side + 1, stride):
                        p = img[:, y:y+side, x:x+side]  # C,side,side
                        # resize patch to patch_resize
                        p_resized = F.interpolate(p.unsqueeze(0), size=(self.patch_resize, self.patch_resize), mode='bilinear', align_corners=False).squeeze(0)
                        patches_for_image.append(p_resized)
            if len(patches_for_image) == 0:
                # fallback: whole image resized
                p_resized = F.interpolate(img.unsqueeze(0), size=(self.patch_resize, self.patch_resize), mode='bilinear', align_corners=False).squeeze(0)
                patches_for_image.append(p_resized)
            all_patches.append(torch.stack(patches_for_image, dim=0))  # [num_patches, C, patch_resize, patch_resize]
        # pad to same count across batch: convert to list of tensors (variable lengths)
        return all_patches  # list length B of tensors [Ni, C, ph, pw]

    @staticmethod
    def _score_patches(patches_tensor):
        """
        simple importance score per patch: per-patch variance over RGB channels (higher->more informative)
        patches_tensor: [N, C, h, w] (cpu or device)
        returns: scores numpy array shape [N]
        """
        # compute variance per patch (on CPU or GPU)
        with torch.no_grad():
            # flatten spatial and channel dims
            vals = patches_tensor.view(patches_tensor.size(0), -1)
            var = vals.var(dim=1)  # variance
            # also consider entropy like measure: approximate by normalized histogram? skip for performance
            return var  # tensor N

    def _ensure_feat_dim(self, device):
        if self.feat_dim is None:
            self.backbone.eval()
            dummy = torch.zeros(1, 3, self.patch_resize, self.patch_resize).to(device)
            with torch.no_grad():
                feat = self.backbone(dummy)  # returns pooled features [1, feat_dim]
            if feat.dim() == 4:
                # some backbones may return feature map, pool it
                feat = feat.mean(dim=[2,3])
            self.feat_dim = feat.shape[1]
            # now create GRU and classifier with correct dims
            self.gru = nn.GRU(input_size=self.feat_dim, hidden_size=self.gru_hidden, num_layers=self.gru_layers,
                              batch_first=True, bidirectional=self.bidirectional).to(device)
            self.classifier = nn.Sequential(
                nn.Linear(self.gru_hidden * self.num_directions, self.gru_hidden // 2),
                nn.GELU(),
                nn.Dropout(self.dropout.p if isinstance(self.dropout, nn.Dropout) else 0.3),
                nn.Linear(self.gru_hidden // 2, len(classes))
            ).to(device)
            # fusion gate MLP: takes concat [global_feat, gru_feat] -> sigmoid gate scalar (or vector)
            fusion_in = self.feat_dim + (self.gru_hidden * self.num_directions)
            self.fusion_gate = nn.Sequential(
                nn.Linear(fusion_in, fusion_in // 2),
                nn.GELU(),
                nn.Linear(fusion_in // 2, self.num_directions * self.gru_hidden if False else 1),
                nn.Sigmoid()
            ).to(device)

    def forward(self, x):
        """
        x: [B, C, H, W], normalized tensor
        returns logits [B, num_classes]
        """
        B = x.size(0)
        device = x.device
        # ensure feature dim structures created
        self._ensure_feat_dim(device)

        # 1) global feature: use backbone pooled features on whole image
        # For timm create_model(..., num_classes=0, global_pool='avg') returns pooled features when calling model(x)
        global_feat = self.backbone(x)  # [B, feat_dim] (pooled)
        if global_feat.dim() == 4:
            global_feat = global_feat.mean(dim=[2,3])  # pool

        # 2) multi-scale patches extraction
        all_patches = self._extract_patches(x)  # list B of [Ni, C, ph, pw]

        # 3) for each image: score patches and select top_k, then run backbone on selected patches to get embeddings
        seq_embeddings = []
        for b in range(B):
            patches_b = all_patches[b]  # [Ni, C, ph, pw]
            scores = self._score_patches(patches_b)  # tensor [Ni]
            # pick top_k indices
            topk = min(self.top_k, patches_b.size(0))
            vals, idxs = torch.topk(scores, k=topk, largest=True)
            selected = patches_b[idxs]  # [topk, C, ph, pw]
            # pass selected patches through backbone to get embeddings
            # process in batch
            with torch.no_grad():
                emb = self.backbone(selected.to(device))  # [topk, feat_dim] or maybe pooled output
            if emb.dim() == 4:
                emb = emb.mean(dim=[2,3])
            seq_embeddings.append(emb)  # list of [topk, feat_dim]

        # pad sequence to same length top_k if variable; stack to [B, T, feat_dim]
        T = max([s.size(0) for s in seq_embeddings])
        feat_dim = self.feat_dim
        seq_tensor = torch.zeros(B, T, feat_dim, device=device)
        for b in range(B):
            s = seq_embeddings[b]
            seq_tensor[b, :s.size(0), :] = s

        # 4) GRU over sequence of patch embeddings
        gru_out, h_n = self.gru(seq_tensor)  # gru_out [B, T, hidden*dir]
        # aggregate: take last time-step output
        last_out = gru_out[:, -1, :]  # [B, hidden*dir]

        # 5) attention fusion w/ global feature
        # concat
        concat = torch.cat([global_feat, last_out], dim=1)  # [B, feat_dim + hidden*dir]
        gate = self.fusion_gate(concat)  # [B, 1] or [B, hidden]
        if gate.size(1) == 1:
            fused = gate * last_out + (1 - gate) * global_feat[:, :last_out.size(1)]
        else:
            fused = gate * last_out + (1 - gate) * global_feat

        fused = self.dropout(fused)
        logits = self.classifier(fused)
        return logits


In [10]:
model = HybridSwinGRU(swin_name='swin_tiny_patch4_window7_224', pretrained=True,
                      patch_scales=PATCH_SCALES, top_k=TOP_K_PATCHES, patch_resize=PATCH_RESIZE,
                      gru_hidden=512, gru_layers=1, num_classes=len(classes),
                      bidirectional=False, dropout=0.3).to(DEVICE)

# compute class weights inverse frequency for balanced focal
from collections import Counter
cnt = Counter(train_labels)
freq = np.array([cnt[i] for i in range(len(classes))], dtype=float)
class_weights = torch.tensor((freq.sum() / (freq + 1e-8)), dtype=torch.float32).to(DEVICE)  # inverse freq
class_weights = class_weights / class_weights.sum()  # normalized

criterion = BalancedFocalLoss(gamma=2.0).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)


model.safetensors:   0%|          | 0.00/114M [00:00<?, ?B/s]

In [11]:
def apply_mix_augment(x, y):
    # randomly choose mixup or cutmix or none
    r = random.random()
    if r < MIXPROB/2:
        # mixup
        mixed_x, y_a, y_b, lam = mixup_data(x, y, alpha=MIXUP_ALPHA)
        return mixed_x, (y_a, y_b, lam), 'mixup'
    elif r < MIXPROB:
        # cutmix
        mixed_x, y_a, y_b, lam = cutmix_data(x.clone(), y, alpha=CUTMIX_ALPHA)
        return mixed_x, (y_a, y_b, lam), 'cutmix'
    else:
        return x, (y, None, 1.0), 'none'

def compute_loss(logits, y_info, criterion, class_weights=None):
    # y_info: (y_a, y_b, lam) or (y, None, 1)
    y_a, y_b, lam = y_info
    if y_b is None:
        return criterion(logits, y_a, class_weights)
    else:
        # combine focal losses: lam*loss(a) + (1-lam)*loss(b)
        loss_a = criterion(logits, y_a, class_weights)
        loss_b = criterion(logits, y_b, class_weights)
        return lam * loss_a + (1 - lam) * loss_b


In [12]:
def set_requires_grad(module, flag):
    for p in module.parameters():
        p.requires_grad = flag

# freeze schedule:
# Stage 1: freeze backbone entirely for first FREEZE_STAGE_1 epochs
# Stage 2: unfreeze backbone last block (if possible) for next (FREEZE_STAGE_2 - FREEZE_STAGE_1) epochs
# Stage 3: unfreeze all remaining epochs

# Helper to unfreeze last layer blocks: for timm swin, try to unfreeze layers in model if named
def unfreeze_backbone_last(module):
    # try to unfreeze topmost parameters by name heuristics
    for name, param in module.named_parameters():
        if any(x in name for x in ['layers', 'block', 'stages', 'layer4', 'norm', 'head']):
            param.requires_grad = True
        else:
            param.requires_grad = False

# Stage 1
set_requires_grad(model.backbone, False)
set_requires_grad(model.gru, True)
set_requires_grad(model.classifier, True)
set_requires_grad(model.fusion_gate, True)

history = {'train_loss':[], 'train_acc':[], 'val_loss':[], 'val_acc':[], 'val_f1':[]}
best_val_f1 = -1.0
best_path = OUT_DIR/"best_hybrid_model.pth"

for epoch in range(1, EPOCHS+1):
    model.train()
    # progressive unfreeze logic
    if epoch == FREEZE_STAGE_1 + 1:
        print("Stage 2: Partially unfreezing backbone last blocks")
        # try to unfreeze last stage (heuristic)
        try:
            unfreeze_backbone_last(model.backbone)
        except:
            set_requires_grad(model.backbone, True)  # fallback
    if epoch == FREEZE_STAGE_2 + 1:
        print("Stage 3: Unfreezing entire backbone for fine-tuning")
        set_requires_grad(model.backbone, True)

    running_loss = 0.0
    preds_all = []
    targets_all = []
    for imgs, labels in tqdm(train_loader, leave=False):
        imgs = imgs.to(DEVICE)
        labels_t = torch.tensor(labels, dtype=torch.long).to(DEVICE)

        # apply mix augment
        imgs_aug, y_info, aug_type = apply_mix_augment(imgs, labels_t)

        optimizer.zero_grad()
        logits = model(imgs_aug)  # model handles patch extraction -> GRU -> fusion
        loss = compute_loss(logits, y_info, criterion, class_weights)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        preds_all.extend(torch.argmax(logits, dim=1).detach().cpu().numpy().tolist())
        targets_all.extend(labels_t.detach().cpu().numpy().tolist())

    train_loss = running_loss / len(train_loader)
    train_acc = accuracy_score(targets_all, preds_all)
    history['train_loss'].append(train_loss)
    history['train_acc'] = history.get('train_acc', []) + [train_acc]

    # validation
    model.eval()
    v_loss = 0.0
    v_preds = []
    v_targets = []
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs = imgs.to(DEVICE)
            labels_t = torch.tensor(labels, dtype=torch.long).to(DEVICE)
            logits = model(imgs)
            loss = criterion(logits, labels_t, class_weights)
            v_loss += loss.item()
            v_preds.extend(torch.argmax(logits, dim=1).cpu().numpy().tolist())
            v_targets.extend(labels_t.cpu().numpy().tolist())

    val_loss = v_loss / len(val_loader)
    val_acc = accuracy_score(v_targets, v_preds)
    val_prec, val_rec, val_f1, _ = precision_recall_fscore_support(v_targets, v_preds, average='weighted', zero_division=0)

    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['val_f1'].append(val_f1)

    print(f"Epoch {epoch}/{EPOCHS} TrainLoss {train_loss:.4f} TrainAcc {train_acc:.4f} ValLoss {val_loss:.4f} ValAcc {val_acc:.4f} ValF1 {val_f1:.4f}")

    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save(model.state_dict(), best_path)
        print("Saved best model:", best_path)

    scheduler.step()
print("Training complete. Best val f1:", best_val_f1)


AttributeError: 'NoneType' object has no attribute 'parameters'

In [None]:
# load best model
model.load_state_dict(torch.load(best_path, map_location=DEVICE))
model.eval()

y_true = []
y_pred = []
y_proba = []

with torch.no_grad():
    for imgs, labels in tqdm(test_loader, leave=False):
        imgs = imgs.to(DEVICE)
        labels_t = torch.tensor(labels, dtype=torch.long).to(DEVICE)
        logits = model(imgs)
        probs = F.softmax(logits, dim=1)[:,1].cpu().numpy()  # prob of class 1
        preds = torch.argmax(logits, dim=1).cpu().numpy().tolist()
        y_true.extend(labels)
        y_pred.extend(preds)
        y_proba.extend(probs.tolist())

acc = accuracy_score(y_true, y_pred)
prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted', zero_division=0)
cm = confusion_matrix(y_true, y_pred)
print("Test Acc:", acc, "F1:", f1)


In [None]:
# History arrays
epochs_range = list(range(1, len(history['train_loss'])+1))

plt.figure(figsize=(8,5))
plt.plot(epochs_range, history['train_loss'], marker='o', label='train_loss')
plt.plot(epochs_range, history['val_loss'], marker='o', label='val_loss')
plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.legend(); plt.grid(True)
plt.savefig(OUT_DIR/"loss_curve.png", dpi=300); plt.show()

plt.figure(figsize=(8,5))
plt.plot(epochs_range, history['train_acc'], marker='o', label='train_acc')
plt.plot(epochs_range, history['val_acc'], marker='o', label='val_acc')
plt.xlabel("Epoch"); plt.ylabel("Accuracy"); plt.legend(); plt.grid(True)
plt.savefig(OUT_DIR/"acc_curve.png", dpi=300); plt.show()

plt.figure(figsize=(8,5))
plt.plot(epochs_range, history['val_f1'], marker='o', label='val_f1')
plt.xlabel("Epoch"); plt.ylabel("F1"); plt.legend(); plt.grid(True)
plt.savefig(OUT_DIR/"f1_curve.png", dpi=300); plt.show()

# Confusion matrix
plt.figure(figsize=(6,5))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=classes, yticklabels=classes, cmap='Blues')
plt.xlabel("Predicted"); plt.ylabel("Actual"); plt.title("Confusion Matrix")
plt.savefig(OUT_DIR/"confusion_matrix.png", dpi=300); plt.show()

# ROC + PR
fpr, tpr, _ = roc_curve(y_true, y_proba)
roc_auc = auc(fpr, tpr)
plt.figure(figsize=(6,5)); plt.plot(fpr, tpr, label=f"AUC={roc_auc:.4f}"); plt.xlabel("FPR"); plt.ylabel("TPR"); plt.legend(); plt.grid(True)
plt.savefig(OUT_DIR/"roc_curve.png", dpi=300); plt.show()

prec_vals, recall_vals, _ = precision_recall_curve(y_true, y_proba)
plt.figure(figsize=(6,5)); plt.plot(recall_vals, prec_vals); plt.xlabel("Recall"); plt.ylabel("Precision"); plt.grid(True)
plt.savefig(OUT_DIR/"pr_curve.png", dpi=300); plt.show()


In [None]:
from sklearn.manifold import TSNE

# collect embeddings from backbone global features for test set
embs = []
labels_list = []
with torch.no_grad():
    for imgs, labels in test_loader:
        imgs = imgs.to(DEVICE)
        # get global feature (pooled) via backbone call inside our model: call model.backbone
        feat = model.backbone(imgs)  # pooled
        if feat.dim() == 4:
            feat = feat.mean(dim=[2,3])
        embs.append(feat.cpu().numpy())
        labels_list.extend(labels)
embs = np.concatenate(embs, axis=0)
tsne = TSNE(n_components=2, random_state=SEED)
tsne_feats = tsne.fit_transform(embs)

plt.figure(figsize=(8,6))
plt.scatter(tsne_feats[:,0], tsne_feats[:,1], c=labels_list, cmap='coolwarm', s=6)
plt.colorbar(ticks=range(len(classes))); plt.title("t-SNE of global features")
plt.savefig(OUT_DIR/"tsne.png", dpi=300); plt.show()


In [None]:
def grad_cam_on_image(model, img_tensor, class_idx=1):
    model.eval()
    # hook to capture features and gradients from backbone (attempt to access feature maps)
    gradients = []
    activations = []
    def forward_hook(module, inp, out):
        activations.append(out.detach())
    def backward_hook(module, grad_in, grad_out):
        gradients.append(grad_out[0].detach())

    # try to register hook on backbone's layer before pooling if possible
    # heuristics: try layer named "layers" or last conv layer
    target_module = None
    for name, m in model.backbone.named_modules():
        if any(k in name for k in ['layer4', 'stages', 'layers', 'blocks']) :
            target_module = m
    if target_module is None:
        target_module = list(model.backbone.children())[-1]

    h_f = target_module.register_forward_hook(forward_hook)
    h_b = target_module.register_backward_hook(backward_hook)

    img = img_tensor.unsqueeze(0).to(DEVICE)
    img.requires_grad = True
    logits = model(img)
    pred_class = logits.argmax(dim=1).item()
    loss = logits[0, class_idx] if class_idx is not None else logits[0, pred_class]
    model.zero_grad()
    loss.backward(retain_graph=True)

    # get gradients and activations
    if len(gradients) == 0 or len(activations) == 0:
        print("Could not get activations/gradients from backbone for Grad-CAM")
        h_f.remove(); h_b.remove()
        return None

    grads = gradients[-1][0]  # [C, H, W]
    acts = activations[-1][0]  # [C, H, W]
    weights = grads.mean(dim=(1,2))  # [C]
    cam = (weights.view(-1,1,1) * acts).sum(dim=0).cpu().numpy()
    cam = np.maximum(cam, 0)
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
    h_f.remove(); h_b.remove()
    return cam, pred_class

# Example usage:
sample_img_path = test_paths[0]
arr = safe_load_image(sample_img_path, size=IMG_SIZE)
pil = transforms.ToTensor()(Image.fromarray(arr))
result = grad_cam_on_image(model, pil, class_idx=None)
if result is not None:
    cam, pred = result
    plt.imshow(arr); plt.imshow(cam, cmap='jet', alpha=0.4); plt.title(f"Pred {pred}"); plt.axis('off')
    plt.savefig(OUT_DIR/"gradcam_example.png", dpi=300); plt.show()
