In [None]:
# %%writefile pretrain_simsiam_cbam.py
import os
import json
import time
import random
from dataclasses import dataclass, asdict
from typing import Iterable, Tuple, Optional

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models.resnet import ResNet, Bottleneck


# ----------------------------
# Config
# ----------------------------
@dataclass
class PretrainConfig:
    """Settings for ``pretrain``; the whole object is dumped to ``pretrain_config.json``."""
    root_path: str = "/kaggle/input/minida/mini_output1/pretrain"  # walked recursively for images
    out_dir: str = "/kaggle/working/simsiam_cbam"  # config, resumable checkpoint and final weights
    epochs: int = 200
    batch_size: int = 64
    num_workers: int = 2
    accumulation_steps: int = 1  # micro-batches accumulated per optimizer step

    # optimizer
    wd: float = 1e-4
    momentum: float = 0.9
    base_lr_ref: float = 0.05  # SimSiam style
    lr_ref_bs: int = 256
    # effective lr = base_lr_ref * clamp(batch_size, 64, 1024) / lr_ref_bs

    # reproducibility
    seed: int = 42
    deterministic: bool = True  # sets cudnn.deterministic=True and cudnn.benchmark=False

    # backbone BN behavior in pretrain
    # Keep default as your working run (True) to avoid changing accuracy behavior unexpectedly.
    # When True, SimSiam puts every backbone BatchNorm2d in eval mode and freezes its
    # affine parameters.
    # NOTE(review): cbam_resnet50 zero-initialises each bottleneck's bn3.weight, so with this
    # flag those weights stay frozen at 0 and the residual branches (conv1-3 + CBAM) receive
    # no gradient during pretraining. Confirm this is intended.
    fix_backbone_bn: bool = True

    # torch.compile safety
    enable_compile: bool = True  # maybe_compile still skips CPU-only runs and CC < 7.0 GPUs

    # checkpoint every N epochs
    save_every: int = 10


# ----------------------------
# Reproducibility
# ----------------------------
def seed_all(seed: int = 42, deterministic: bool = True):
    """Seed torch (CPU + all CUDA devices), ``random`` and NumPy with ``seed``.

    When ``deterministic`` is true, additionally force deterministic cuDNN kernels
    and turn off cuDNN autotuning; otherwise the cuDNN flags are left untouched.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, random.seed, np.random.seed):
        seeder(seed)
    if not deterministic:
        return
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# ----------------------------
# CBAM-ResNet50 backbone
# ----------------------------
class ChannelAttention(nn.Module):
    """CBAM channel-attention gate.

    The average- and max-pooled channel descriptors go through one shared 1x1-conv
    MLP (``in_planes -> max(in_planes // ratio, 1) -> in_planes``); the two results
    are summed and squashed by a sigmoid. ``forward`` returns only the gate, shaped
    (B, C, 1, 1) for a (B, C, H, W) input -- the caller multiplies it in.
    """

    def __init__(self, in_planes: int, ratio: int = 16):
        super().__init__()
        squeezed = max(in_planes // ratio, 1)
        # Registration order is deliberate: it fixes the seeded init order and the
        # state_dict keys (fc1.weight, fc2.weight), which must match between the
        # pretrain and fine-tune copies of this class.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_planes, squeezed, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(squeezed, in_planes, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def _shared_mlp(self, descriptor: torch.Tensor) -> torch.Tensor:
        """Bottleneck MLP shared by both pooled descriptors."""
        return self.fc2(self.relu(self.fc1(descriptor)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # avg-pool branch first, then max-pool branch
        scores = self._shared_mlp(self.avg_pool(x)) + self._shared_mlp(self.max_pool(x))
        return self.sigmoid(scores)


class SpatialAttention(nn.Module):
    """CBAM spatial-attention gate.

    The channel-wise mean and max maps are stacked into a 2-channel image, convolved
    with one ``kernel_size`` x ``kernel_size`` filter and passed through a sigmoid.
    Returns the gate, shaped (B, 1, H, W) for odd ``kernel_size``.
    """

    def __init__(self, kernel_size: int = 7):
        super().__init__()
        # padding = k // 2 keeps H x W unchanged for odd k
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_map = torch.mean(x, dim=1, keepdim=True)
        # torch.max rather than amax: its backward sends the gradient to a single
        # arg-max channel, which matters when channels tie (e.g. an all-zero input).
        max_map = torch.max(x, dim=1, keepdim=True).values
        return self.sigmoid(self.conv(torch.cat((mean_map, max_map), dim=1)))


class CBAMBottleneck(Bottleneck):
    """torchvision ``Bottleneck`` with CBAM applied to the residual branch.

    Channel attention and then spatial attention are applied to the ``bn3`` output
    before the shortcut is added; all layers, strides and the optional downsample
    come from ``Bottleneck.__init__``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        expanded = self.conv3.out_channels
        self.ca = ChannelAttention(expanded)
        self.sa = SpatialAttention()

    def _residual(self, x: torch.Tensor) -> torch.Tensor:
        """conv1-bn1-relu -> conv2-bn2-relu -> conv3-bn3 -> channel gate -> spatial gate."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        y = self.ca(y) * y
        return self.sa(y) * y

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        branch = self._residual(x)
        shortcut = x if self.downsample is None else self.downsample(x)
        branch += shortcut  # in-place add, as in torchvision's Bottleneck
        return self.relu(branch)


def cbam_resnet50(*, norm_layer=None, **kwargs) -> ResNet:
    """Build the ResNet-50 layout ([3, 4, 6, 3] stages) out of ``CBAMBottleneck`` blocks.

    Every bottleneck's ``bn3.weight`` (when it exists) is zeroed, so each residual
    branch starts out near zero and the block is initially close to an identity
    mapping. Remaining keyword arguments are forwarded to torchvision's ``ResNet``.
    """
    net = ResNet(CBAMBottleneck, [3, 4, 6, 3], norm_layer=norm_layer, **kwargs)
    last_bns = (
        blk.bn3 for blk in net.modules()
        if isinstance(blk, Bottleneck) and hasattr(blk, "bn3")
    )
    for bn in last_bns:
        if bn.weight is not None:
            nn.init.zeros_(bn.weight)
    return net


# ----------------------------
# Optim helper: param-wise WD
# ----------------------------
def exclude_from_wd(named_params: Iterable[Tuple[str, torch.nn.Parameter]], wd: float):
    """Split parameters into a weight-decay group and a no-weight-decay group.

    Parameters with ``requires_grad=False`` are dropped entirely. 1-D tensors
    (norm scales, biases), names ending in ``.bias`` and names containing ``bn``
    (case-insensitive) get ``weight_decay=0.0``; everything else gets ``wd``.

    Returns:
        ``[{"params": decayed, "weight_decay": wd}, {"params": exempt, "weight_decay": 0.0}]``
    """
    decayed, exempt = [], []
    for name, param in named_params:
        if not param.requires_grad:
            continue
        skip_decay = param.ndim == 1 or name.endswith(".bias") or "bn" in name.lower()
        (exempt if skip_decay else decayed).append(param)
    return [{"params": decayed, "weight_decay": wd},
            {"params": exempt, "weight_decay": 0.0}]


# ----------------------------
# SimSiam heads
# ----------------------------
class MLPHead(nn.Module):
    """SimSiam projector: Linear -> BN -> ReLU -> Linear -> BN (last BN without affine).

    Both linear layers have no bias (the following BN makes it redundant). The layer
    indices inside ``self.net`` define the state_dict keys of saved checkpoints.
    """

    def __init__(self, in_dim=2048, hidden_dim=2048, out_dim=2048):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim, bias=False),
            nn.BatchNorm1d(out_dim, affine=False),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class PredictionHead(nn.Module):
    """SimSiam predictor: bottleneck MLP Linear -> BN -> ReLU -> Linear (with bias).

    Defaults map 2048 -> 512 -> 2048. The layer indices inside ``self.net`` define the
    state_dict keys of saved checkpoints.
    """

    def __init__(self, in_dim=2048, hidden_dim=512, out_dim=2048):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class SimSiam(nn.Module):
    """SimSiam model: CBAM-ResNet50 encoder (up to global avgpool) + projector + predictor.

    ``forward(x1, x2)`` returns ``(p1, p2, z1.detach(), z2.detach())``; the detached
    projections serve as the stop-gradient targets of the symmetric loss.

    Args:
        fix_backbone_bn: if True, every ``BatchNorm2d`` inside the backbone is put in
            eval mode (running statistics are not updated) and its parameters are
            frozen; ``train()`` re-applies eval mode to those layers on every call.
            If any frozen layer has an all-zero weight (cbam_resnet50 zero-initialises
            each bottleneck's ``bn3.weight``), a RuntimeWarning is emitted because
            that layer's output no longer depends on its input.
    """
    def __init__(self, fix_backbone_bn: bool = True):
        super().__init__()
        resnet = cbam_resnet50(num_classes=1000)
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])  # drop fc; ends at avgpool
        self.projector = MLPHead(2048)
        self.predictor = PredictionHead()
        self.fix_backbone_bn = fix_backbone_bn

        self._frozen_bn_modules = []
        if self.fix_backbone_bn:
            for m in self.backbone.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    m.requires_grad_(False)
                    self._frozen_bn_modules.append(m)
            self._warn_if_frozen_zero_gamma()

    def _warn_if_frozen_zero_gamma(self):
        """Warn when a frozen BN layer has an all-zero weight (gamma).

        In eval mode a BN computes ``weight * x_hat + bias``; with ``weight == 0``
        frozen, the output is the constant ``bias`` for any input. Gradient stops
        there, so the layers feeding it (a bottleneck's conv1-3 and CBAM gates)
        never learn. Their weights are then only changed by weight decay.
        """
        dead = sum(
            1 for m in self._frozen_bn_modules
            if m.weight is not None and not bool(m.weight.detach().any())
        )
        if dead:
            import warnings  # local import keeps this script's import block unchanged
            warnings.warn(
                f"[SimSiam] fix_backbone_bn=True froze {dead} BatchNorm2d layer(s) whose "
                "weight is all zeros (zero-init residual bn3). Their outputs no longer "
                "depend on the input, so the corresponding residual branches (conv1-3 + "
                "CBAM) receive zero gradient during pretraining.",
                RuntimeWarning,
                stacklevel=3,
            )

    def _forward_backbone(self, x):
        x = self.backbone(x)      # (B, 2048, 1, 1) after the adaptive avgpool
        return torch.flatten(x, 1)

    def forward(self, x1, x2):
        z1 = self.projector(self._forward_backbone(x1))
        z2 = self.projector(self._forward_backbone(x2))
        p1 = self.predictor(z1)
        p2 = self.predictor(z2)
        return p1, p2, z1.detach(), z2.detach()

    def train(self, mode: bool = True):
        """Switch train/eval mode, keeping frozen backbone BNs in eval mode regardless."""
        super().train(mode)
        if self.fix_backbone_bn:
            for m in self._frozen_bn_modules:
                m.eval()
        return self


def neg_cos_sim(p, z):
    """Negative mean cosine similarity between matching rows of ``p`` and ``z``.

    Rows are L2-normalised with ``F.normalize`` (dim=1) first; the result lies in [-1, 1].
    """
    p_unit = F.normalize(p, dim=1)
    z_unit = F.normalize(z, dim=1)
    row_cos = torch.sum(p_unit * z_unit, dim=1)
    return row_cos.mean().neg()


# ----------------------------
# Dataset
# ----------------------------
class UnlabeledDataset(Dataset):
    """Label-free image dataset that returns two independently augmented views.

    Recursively collects every file under ``root_dir`` whose lower-cased name ends in
    one of ``EXTENSIONS``; paths are sorted so the index -> file mapping is stable.

    Raises:
        RuntimeError: if no matching image file is found.
    """
    EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".webp")

    def __init__(self, root_dir: str, transform):
        self.transform = transform
        self.files = sorted(
            os.path.join(dirpath, fn)
            for dirpath, _, filenames in os.walk(root_dir)
            for fn in filenames
            if fn.lower().endswith(self.EXTENSIONS)
        )
        if not self.files:
            raise RuntimeError(f"No images found under: {root_dir}")

    def __len__(self): return len(self.files)

    def __getitem__(self, idx: int):
        # The context manager always closes the file handle. Pillow only closes it
        # itself after load() for single-frame images, so multi-frame files (e.g.
        # animated WebP) would otherwise leak one descriptor per sample.
        with Image.open(self.files[idx]) as im:
            img = im.convert("RGB")
        return self.transform(img), self.transform(img)


# ----------------------------
# Training
# ----------------------------
def maybe_compile(model: nn.Module, enable: bool = True) -> nn.Module:
    """Return ``torch.compile(model)`` when it is enabled and likely to work, else ``model``.

    Compilation is skipped (model returned unchanged) when ``enable`` is falsy, when
    this torch has no ``torch.compile``, when CUDA is unavailable, or when the GPU's
    compute capability is below 7.0. Any exception raised here is printed and the
    eager model is returned.

    NOTE(review): torch.compile is documented to compile lazily on the first forward
    call, so compilation errors may surface there rather than inside this try/except.
    """
    can_compile = enable and hasattr(torch, "compile") and torch.cuda.is_available()
    if not can_compile:
        return model
    try:
        cc_major, _ = torch.cuda.get_device_capability()
        if cc_major >= 7:
            return torch.compile(model)
        print(f"[compile] Skipping torch.compile (CC {cc_major}.x < 7.0)")
        return model
    except Exception as e:
        print("[compile] failed -> eager:", e)
        return model


def _unwrap_compiled(model: nn.Module) -> nn.Module:
    """Return the eager module behind a ``torch.compile`` wrapper, or ``model`` itself."""
    return getattr(model, "_orig_mod", model)


def _atomic_torch_save(obj, path: str) -> None:
    """``torch.save`` to ``<path>.tmp``, then ``os.replace`` it over ``path``.

    A process killed mid-write therefore leaves the previous file intact instead of a
    truncated checkpoint that would make the next resume fail.
    """
    tmp_path = f"{path}.tmp"
    torch.save(obj, tmp_path)
    os.replace(tmp_path, path)


def pretrain(cfg: PretrainConfig):
    """SimSiam pretraining of the CBAM-ResNet50 backbone on an unlabeled image folder.

    Writes ``pretrain_config.json`` to ``cfg.out_dir``. Resumes from
    ``simsiam_cbam_checkpoint.pth`` there if it exists. Checkpoints every
    ``cfg.save_every`` epochs (a value <= 0 disables periodic checkpoints). Finally
    saves the backbone/projector/predictor state_dicts to
    ``simsiam_cbam_pretrained_final.pth``.

    Raises:
        RuntimeError: no images under ``cfg.root_path`` (from ``UnlabeledDataset``),
            or fewer images than ``cfg.batch_size`` so ``drop_last=True`` leaves no batch.
    """
    os.makedirs(cfg.out_dir, exist_ok=True)
    seed_all(cfg.seed, cfg.deterministic)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)
    use_amp = device.type == "cuda"  # autocast + GradScaler are only enabled on CUDA

    # augmentations (SimSiam strong augs)
    crop_size = 224
    blur_kernel = max(int(0.1 * crop_size) // 2 * 2 + 1, 3)  # odd kernel ~10% of the crop (23 for 224)
    tfm = transforms.Compose([
        transforms.RandomResizedCrop(crop_size, scale=(0.2, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.GaussianBlur(kernel_size=blur_kernel, sigma=(0.1, 2.0)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])

    ds = UnlabeledDataset(cfg.root_path, tfm)
    dl = DataLoader(
        ds, batch_size=cfg.batch_size, shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=torch.cuda.is_available(),
        persistent_workers=(cfg.num_workers > 0),
        drop_last=True,
    )
    if len(dl) == 0:
        # Fail before building the model instead of dividing by len(dl) == 0 later.
        raise RuntimeError(
            f"Only {len(ds)} image(s) under {cfg.root_path}, fewer than "
            f"batch_size={cfg.batch_size}; drop_last=True leaves no batch to train on."
        )

    model = SimSiam(fix_backbone_bn=cfg.fix_backbone_bn).to(device)
    model = maybe_compile(model, cfg.enable_compile)
    eager = _unwrap_compiled(model)  # shares parameters with `model`; used for train()/state_dict

    # LR scaling (keep your behavior): linear in the batch size clamped to [64, 1024]
    global_bs = cfg.batch_size
    lr = cfg.base_lr_ref * max(min(global_bs, 1024), 64) / cfg.lr_ref_bs

    param_groups = exclude_from_wd(model.named_parameters(), wd=cfg.wd)
    opt = torch.optim.SGD(param_groups, lr=lr, momentum=cfg.momentum)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.epochs)  # stepped once per epoch

    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    # save config
    with open(os.path.join(cfg.out_dir, "pretrain_config.json"), "w") as f:
        json.dump(asdict(cfg), f, indent=2)

    ckpt_path = os.path.join(cfg.out_dir, "simsiam_cbam_checkpoint.pth")
    start_epoch = 0
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path, map_location=device)
        eager.load_state_dict(ckpt["model"])
        opt.load_state_dict(ckpt["optimizer"])
        sched.load_state_dict(ckpt["scheduler"])
        if "scaler" in ckpt:
            scaler.load_state_dict(ckpt["scaler"])
        start_epoch = int(ckpt.get("epoch", -1)) + 1
        print("Resumed from epoch:", start_epoch)

    print(f"Pretraining {cfg.epochs} epochs (from {start_epoch})... | lr={lr:.6f}")
    t0 = time.time()
    accum = max(cfg.accumulation_steps, 1)  # non-positive values mean "no accumulation"

    for epoch in range(start_epoch, cfg.epochs):
        eager.train(True)  # SimSiam.train keeps frozen backbone BNs in eval mode

        total = 0.0
        opt.zero_grad(set_to_none=True)
        micro = 0

        for x1, x2 in dl:
            x1 = x1.to(device, non_blocking=True)
            x2 = x2.to(device, non_blocking=True)

            with torch.amp.autocast("cuda", enabled=use_amp):
                p1, p2, z1, z2 = model(x1, x2)
                # symmetric loss; z1/z2 are already detached (stop-gradient) by SimSiam.forward
                loss_full = 0.5 * (neg_cos_sim(p1, z2) + neg_cos_sim(p2, z1))
                loss = loss_full / accum

            scaler.scale(loss).backward()
            micro += 1

            if micro == accum:
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)
                micro = 0

            total += float(loss_full.item())

        if micro > 0:
            # Leftover micro-batches (< accum): each loss was divided by `accum`, so
            # rescale the summed gradient by accum / micro to use their mean instead.
            scaler.unscale_(opt)
            for group in opt.param_groups:
                for p in group["params"]:
                    if p.grad is not None:
                        p.grad.mul_(accum / micro)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)

        avg_loss = total / len(dl)
        sched.step()

        if cfg.save_every > 0 and (epoch + 1) % cfg.save_every == 0:
            _atomic_torch_save({
                "epoch": epoch,
                "model": eager.state_dict(),
                "optimizer": opt.state_dict(),
                "scheduler": sched.state_dict(),
                "scaler": scaler.state_dict(),
            }, ckpt_path)

        print(f"Epoch {epoch+1:03d}/{cfg.epochs} | loss={avg_loss:.4f}")

    final_path = os.path.join(cfg.out_dir, "simsiam_cbam_pretrained_final.pth")
    _atomic_torch_save({
        "backbone": eager.backbone.state_dict(),
        "projector": eager.projector.state_dict(),
        "predictor": eager.predictor.state_dict(),
    }, final_path)

    print("Saved:", final_path)
    print("Total time (min):", (time.time() - t0) / 60.0)


if __name__ == "__main__":
    # A Jupyter kernel also runs cells with __name__ == "__main__", so while the
    # %%writefile magic at the top of this cell stays commented out, executing the
    # cell runs the full pretraining inside the notebook process.
    cfg = PretrainConfig()
    pretrain(cfg)

In [5]:
%%writefile finetune_cbam_eval.py
import os
# Set at the very top, before any CUDA work: the PyTorch reproducibility notes require
# this cuBLAS workspace config for deterministic matmuls once
# torch.use_deterministic_algorithms(True) is enabled (seed_all does that).
# setdefault keeps a value the user may already have exported.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

import warnings
# seed_all enables deterministic algorithms with warn_only=True because CBAM's
# AdaptiveMaxPool2d has no deterministic CUDA backward; silence that expected warning.
warnings.filterwarnings(
    "ignore",
    message=r"adaptive_max_pool2d_backward_cuda does not have a deterministic implementation.*"
)

import json
import random
import argparse
from dataclasses import dataclass, asdict
from typing import List, Dict, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.transforms import RandAugment

from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, accuracy_score, f1_score

# =========================================================
# AMP compatibility wrapper
# =========================================================
def make_amp(device: torch.device):
    """Return ``(autocast_factory, grad_scaler)`` for CUDA mixed precision.

    Both are disabled (no-op autocast, pass-through scaler) unless ``device`` is
    CUDA. The device-generic ``torch.amp`` API is preferred; torch builds without it
    fall back to the legacy ``torch.cuda.amp`` API.
    """
    enabled = device.type == "cuda"
    has_generic_api = hasattr(torch, "amp") and all(
        hasattr(torch.amp, attr) for attr in ("autocast", "GradScaler")
    )
    if not has_generic_api:
        return (lambda: torch.cuda.amp.autocast(enabled=enabled),
                torch.cuda.amp.GradScaler(enabled=enabled))
    return (lambda: torch.amp.autocast(device_type="cuda", enabled=enabled),
            torch.amp.GradScaler("cuda", enabled=enabled))

# =========================================================
# Determinism (CBAM-safe): warn_only avoids crash from AdaptiveMaxPool2d backward
# =========================================================
def seed_all(seed: int):
    """Seed every RNG used here and switch PyTorch into reproducible mode.

    Sets PYTHONHASHSEED (this only affects Python processes started afterwards) and
    seeds ``random``, NumPy and torch (CPU + all CUDA devices). Disables cuDNN
    autotuning and TF32, and enables deterministic algorithms in warn-only mode.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    # TF32 off for more repeatability (best effort on builds lacking these flags)
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        cudnn.allow_tf32 = False
    except Exception:
        pass

    # CBAM's AdaptiveMaxPool2d has no deterministic CUDA backward, so determinism is
    # requested in warn-only mode instead of letting that op raise at runtime.
    try:
        torch.use_deterministic_algorithms(True, warn_only=True)
        return
    except TypeError:
        pass  # torch without the warn_only keyword
    try:
        # NOTE: strict mode -- the CBAM max-pool backward may then raise on CUDA
        torch.use_deterministic_algorithms(True)
    except Exception:
        pass

def seed_worker(worker_id: int):
    """DataLoader ``worker_init_fn``: seed NumPy and ``random`` from torch's seed.

    The 64-bit ``torch.initial_seed()`` is reduced mod 2**32 (NumPy's seed limit).
    In a DataLoader worker PyTorch presumably gives each worker its own initial
    seed, so the NumPy/``random`` streams differ per worker (see the DataLoader
    randomness docs).
    """
    del worker_id  # unused: the per-worker variation comes from torch.initial_seed()
    derived = torch.initial_seed() % (1 << 32)
    for seeder in (np.random.seed, random.seed):
        seeder(derived)

# =========================================================
# Mixup
# =========================================================
def mixup_data(x: torch.Tensor, y: torch.Tensor, alpha: float, device: torch.device):
    """Mixup: blend every sample with a randomly permuted partner from the batch.

    ``lam`` is drawn from Beta(alpha, alpha) with NumPy's global RNG when ``alpha > 0``,
    otherwise it is 1.0 (no mixing). A permutation is drawn on ``device`` in both cases.

    Returns:
        ``(lam * x + (1 - lam) * x[perm], y, y[perm], lam)``
    """
    lam = np.random.beta(alpha, alpha) if (alpha and alpha > 0) else 1.0
    perm = torch.randperm(x.size(0), device=device)
    return lam * x + (1.0 - lam) * x[perm], y, y[perm], lam

def mixup_ce_loss(ce, logits, y_a, y_b, lam: float):
    """Mixup loss: ``lam * ce(logits, y_a) + (1 - lam) * ce(logits, y_b)``."""
    loss_a = ce(logits, y_a)
    loss_b = ce(logits, y_b)
    return lam * loss_a + (1.0 - lam) * loss_b

# =========================================================
# SupCon (multi-view, correct labels/mask)
# =========================================================
class SupConLoss(nn.Module):
    """
    Multi-view supervised contrastive loss.
    Expects feats of shape (B, V, D) where V>=2, and labels of shape (B,).

    Each of the B*V embeddings is an anchor. Its positives are all other embeddings
    whose sample has the same label, including its own other views. The
    denominator covers every embedding except the anchor itself. Returns the mean
    over anchors of ``-mean(log p(positive))``.
    """
    def __init__(self, temperature: float = 0.07, eps: float = 1e-8):
        super().__init__()
        self.t = float(temperature)
        self.eps = float(eps)

    def forward(self, feats: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        device = feats.device
        B, V, D = feats.shape

        # Compute in fp32 even when called under autocast: a bare .float() is undone
        # by autocast, which would run the similarity matmul in half precision.
        # NOTE(review): assumes device.type is one autocast accepts ("cuda"/"cpu").
        with torch.autocast(device_type=device.type, enabled=False):
            feats = F.normalize(feats.float(), dim=2)       # (B,V,D)
            feats = feats.reshape(B * V, D)                 # (BV,D); reshape tolerates non-contiguous input

            labels = labels.to(device).reshape(B, 1)        # (B,1)
            labels = labels.repeat(1, V).reshape(B * V, 1)  # (BV,1): row b*V+v -> labels[b]
            mask = torch.eq(labels, labels.T).float()       # (BV,BV) same-label pairs
            mask.fill_diagonal_(0)                          # an anchor is not its own positive

            logits = (feats @ feats.T) / self.t
            logits = logits - logits.max(dim=1, keepdim=True).values.detach()  # numerical stability

            eye = torch.eye(B * V, device=device)
            exp_logits = torch.exp(logits) * (1.0 - eye)    # drop self-similarity from the denominator
            log_prob = logits - torch.log(exp_logits.sum(dim=1, keepdim=True) + self.eps)

            pos_count = mask.sum(dim=1)
            mean_log_prob_pos = (mask * log_prob).sum(dim=1) / (pos_count + self.eps)
            return -mean_log_prob_pos.mean()

# =========================================================
# Two-view wrapper for HYBRID training
# =========================================================
class TwoCropTransform:
    """Apply ``base_transform`` twice to the same image and return both results.

    With a random ``base_transform`` this yields two independent augmented views.
    """
    def __init__(self, base_transform):
        self.base = base_transform

    def __call__(self, img):
        first = self.base(img)
        second = self.base(img)
        return first, second

class ImageFolderTwoView(ImageFolder):
    """``ImageFolder`` whose ``transform`` yields two views (e.g. ``TwoCropTransform``).

    ``__getitem__`` returns ``((view1, view2), target)``. Like ``DatasetFolder``,
    ``target_transform`` is applied to the target when it is set.

    Raises:
        TypeError: if the dataset was built without a transform.
    """
    def __getitem__(self, index):
        if self.transform is None:
            raise TypeError(
                "ImageFolderTwoView needs a transform that returns two views, e.g. TwoCropTransform"
            )
        path, y = self.samples[index]
        img = self.loader(path)
        x1, x2 = self.transform(img)
        if self.target_transform is not None:
            y = self.target_transform(y)
        return (x1, x2), y

# =========================================================
# CBAM-ResNet50 definition (must match pretrain architecture)
# =========================================================
from torchvision.models.resnet import ResNet, Bottleneck

class ChannelAttention(nn.Module):
    """CBAM channel-attention gate.

    The average- and max-pooled channel descriptors go through one shared 1x1-conv
    MLP (``in_planes -> max(in_planes // ratio, 1) -> in_planes``); the two results
    are summed and squashed by a sigmoid. ``forward`` returns only the gate, shaped
    (B, C, 1, 1) for a (B, C, H, W) input -- the caller multiplies it in.
    """

    def __init__(self, in_planes: int, ratio: int = 16):
        super().__init__()
        squeezed = max(in_planes // ratio, 1)
        # Registration order is deliberate: it fixes the state_dict keys
        # (fc1.weight, fc2.weight) that must match the pretrain checkpoint.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_planes, squeezed, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(squeezed, in_planes, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def _shared_mlp(self, descriptor: torch.Tensor) -> torch.Tensor:
        """Bottleneck MLP shared by both pooled descriptors."""
        return self.fc2(self.relu(self.fc1(descriptor)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # avg-pool branch first, then max-pool branch
        scores = self._shared_mlp(self.avg_pool(x)) + self._shared_mlp(self.max_pool(x))
        return self.sigmoid(scores)

class SpatialAttention(nn.Module):
    """CBAM spatial-attention gate.

    The channel-wise mean and max maps are stacked into a 2-channel image, convolved
    with one ``kernel_size`` x ``kernel_size`` filter and passed through a sigmoid.
    Returns the gate, shaped (B, 1, H, W) for odd ``kernel_size``.
    """

    def __init__(self, kernel_size: int = 7):
        super().__init__()
        # padding = k // 2 keeps H x W unchanged for odd k
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_map = torch.mean(x, dim=1, keepdim=True)
        # torch.max rather than amax: its backward sends the gradient to a single
        # arg-max channel, which matters when channels tie (e.g. an all-zero input).
        max_map = torch.max(x, dim=1, keepdim=True).values
        return self.sigmoid(self.conv(torch.cat((mean_map, max_map), dim=1)))

class CBAMBottleneck(Bottleneck):
    """torchvision ``Bottleneck`` with CBAM applied to the residual branch.

    Channel attention and then spatial attention are applied to the ``bn3`` output
    before the shortcut is added; all layers, strides and the optional downsample
    come from ``Bottleneck.__init__``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        expanded = self.conv3.out_channels
        self.ca = ChannelAttention(expanded)
        self.sa = SpatialAttention()

    def _residual(self, x: torch.Tensor) -> torch.Tensor:
        """conv1-bn1-relu -> conv2-bn2-relu -> conv3-bn3 -> channel gate -> spatial gate."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        y = self.ca(y) * y
        return self.sa(y) * y

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        branch = self._residual(x)
        shortcut = x if self.downsample is None else self.downsample(x)
        branch += shortcut  # in-place add, as in torchvision's Bottleneck
        return self.relu(branch)

def cbam_resnet50(**kwargs) -> ResNet:
    """Build the ResNet-50 layout ([3, 4, 6, 3] stages) out of ``CBAMBottleneck`` blocks.

    Every bottleneck's ``bn3.weight`` (when it exists) is zeroed, as in the pretrain
    script, so the architecture and initialisation match. Keyword arguments
    (e.g. ``num_classes``) are forwarded to torchvision's ``ResNet``.
    """
    net = ResNet(CBAMBottleneck, [3, 4, 6, 3], **kwargs)
    last_bns = (
        blk.bn3 for blk in net.modules()
        if isinstance(blk, Bottleneck) and hasattr(blk, "bn3")
    )
    for bn in last_bns:
        if bn.weight is not None:
            nn.init.zeros_(bn.weight)
    return net

# =========================================================
# Fine-tune model (CBAM backbone + classifier + optional SupCon head)
# =========================================================
class FineTuneCBAM(nn.Module):
    """CBAM-ResNet50 backbone (initialised from SimSiam pretraining) + MLP classifier.

    Args:
        pretrained_path: file read with ``torch.load``. Accepted layouts:
            - the final export ``{"backbone": ..., "projector": ..., "predictor": ...}``;
            - the resumable pretrain checkpoint ``{"model": <SimSiam state_dict>, ...}``,
              whose ``backbone.*`` entries are used;
            - a bare backbone state_dict.
        num_classes: classifier output size.
        hybrid: if True, also create ``supcon_proj`` (2048 -> 128) for the SupCon branch.

    Raises:
        RuntimeError: if the checkpoint matches none of the backbone's keys, which
            would otherwise silently leave the backbone randomly initialised.
    """
    def __init__(self, pretrained_path: str, num_classes: int, hybrid: bool):
        super().__init__()
        resnet = cbam_resnet50(num_classes=1000)
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])  # to avgpool

        self.classifier = nn.Sequential(
            nn.Linear(2048, 512, bias=True),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.30),
            nn.Linear(512, num_classes, bias=True),
        )

        self.hybrid = bool(hybrid)
        self.supcon_proj = nn.Linear(2048, 128, bias=True) if self.hybrid else None

        ckpt = torch.load(pretrained_path, map_location="cpu")
        state = self._backbone_state(ckpt)
        missing, unexpected = self.backbone.load_state_dict(state, strict=False)

        if missing or unexpected:
            print(f"[state_dict notice] missing: {missing} | unexpected: {unexpected}")
        if len(missing) == len(self.backbone.state_dict()):
            raise RuntimeError(
                f"No backbone weights matched in {pretrained_path!r}; refusing to "
                "fine-tune from a randomly initialised backbone."
            )

    @staticmethod
    def _backbone_state(ckpt) -> dict:
        """Pick the backbone state_dict out of one of the supported checkpoint layouts."""
        if "backbone" in ckpt:  # final export written at the end of pretraining
            return ckpt["backbone"]
        if "model" in ckpt:  # resumable checkpoint: full SimSiam state_dict
            prefix = "backbone."
            return {k[len(prefix):]: v for k, v in ckpt["model"].items() if k.startswith(prefix)}
        return ckpt  # backbone-only state_dict

    def forward_feats(self, x: torch.Tensor) -> torch.Tensor:
        return self.backbone(x).flatten(1)  # (B,2048)

    def logits_from_feats(self, feats: torch.Tensor) -> torch.Tensor:
        return self.classifier(feats)

    def supcon_from_feats(self, feats: torch.Tensor) -> torch.Tensor:
        """L2-normalised 128-d SupCon embedding; only available when ``hybrid=True``."""
        if not self.hybrid:
            raise RuntimeError("supcon_from_feats called while hybrid=False")
        return F.normalize(self.supcon_proj(feats), dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.forward_feats(x)
        return self.logits_from_feats(feats)

def count_params_m(model: nn.Module) -> float:
    """Total parameter count (trainable or not, shared ones counted once) in millions."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total / 1e6

# =========================================================
# Data
# =========================================================
def build_datasets(data_root: str):
    """Create the train/val/test ``ImageFolder`` datasets under ``data_root``.

    Expects ``data_root/{train,val,test}/<class>/...``. Returns, in order:
        1. the single-view CE training set;
        2. the two-view (hybrid) training set over the same directory;
        3. val; 4. test;
        5. the class names;
        6-8. the three split directories.
    Asserts that every split has the training set's class list and order.
    """
    train_dir, val_dir, test_dir = (os.path.join(data_root, split) for split in ("train", "val", "test"))
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]  # ImageNet normalisation

    train_tf = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.7, 1.0), ratio=(0.9, 1.1)),
        transforms.RandomHorizontalFlip(p=0.5),
        RandAugment(num_ops=2, magnitude=7),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    eval_tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_ds_ce = ImageFolder(train_dir, transform=train_tf)
    train_ds_hybrid = ImageFolderTwoView(train_dir, transform=TwoCropTransform(train_tf))
    val_ds = ImageFolder(val_dir, transform=eval_tf)
    test_ds = ImageFolder(test_dir, transform=eval_tf)

    class_names = train_ds_ce.classes
    assert val_ds.classes == class_names and test_ds.classes == class_names, "Split class mismatch."
    assert train_ds_hybrid.classes == class_names, "HYBRID train classes mismatch."

    return train_ds_ce, train_ds_hybrid, val_ds, test_ds, class_names, train_dir, val_dir, test_dir

def build_loaders(train_ds, val_ds, test_ds, batch_size, num_workers, seed, use_weighted_sampler: bool):
    """Build the train/val/test DataLoaders.

    The training loader either shuffles or, with ``use_weighted_sampler``, samples
    with replacement using weights inversely proportional to class frequency.
    Val/test are never shuffled. All three share one ``torch.Generator`` seeded with
    ``seed``, ``seed_worker`` as ``worker_init_fn``, and persistent workers when
    ``num_workers > 0``.
    """
    gen = torch.Generator()
    gen.manual_seed(seed)
    common = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        worker_init_fn=seed_worker,
        generator=gen,
        persistent_workers=num_workers > 0,
    )

    sampler = None
    if use_weighted_sampler:
        labels = np.array(train_ds.targets, dtype=np.int64)
        per_class = np.bincount(labels)
        per_class[per_class == 0] = 1  # guard against division by zero
        per_sample = (1.0 / per_class)[labels]
        # NOTE(review): no generator is passed here, so this sampler presumably draws
        # from the global torch RNG rather than `gen`.
        sampler = WeightedRandomSampler(
            weights=torch.as_tensor(per_sample, dtype=torch.double),
            num_samples=len(per_sample),
            replacement=True,
        )

    train_loader = DataLoader(train_ds, shuffle=sampler is None, sampler=sampler, **common)
    val_loader = DataLoader(val_ds, shuffle=False, **common)
    test_loader = DataLoader(test_ds, shuffle=False, **common)
    return train_loader, val_loader, test_loader

# =========================================================
# Eval helpers
# =========================================================
def macro_roc_auc(y_true: np.ndarray, probs: np.ndarray, n_classes: int) -> float:
    """Macro-averaged one-vs-rest ROC-AUC.

    Args:
        y_true: integer class ids in ``[0, n_classes)``, shape (N,).
        probs: class probabilities, shape (N, n_classes), same column order.
        n_classes: number of classes.

    ROC-AUC is undefined for a class that is absent from (or is the only class in)
    ``y_true``, and sklearn raises ValueError in that case. Rather than crash the
    evaluation, such classes are skipped with a warning and the macro mean is taken
    over the remaining ones. Returns NaN when no class can be scored. When every
    class is scorable the result equals the plain sklearn macro OvR AUC.
    """
    y_true = np.asarray(y_true, dtype=np.int64)
    probs = np.asarray(probs)
    onehot = np.eye(n_classes)[y_true]

    positives = onehot.sum(axis=0)
    scorable = (positives > 0) & (positives < len(y_true))
    if scorable.all():
        return float(roc_auc_score(onehot, probs, average="macro", multi_class="ovr"))

    skipped = np.flatnonzero(~scorable).tolist()
    keep = np.flatnonzero(scorable)
    if keep.size == 0:
        warnings.warn(f"macro_roc_auc: ROC-AUC undefined for every class {skipped}; returning NaN.")
        return float("nan")
    warnings.warn(
        f"macro_roc_auc: skipping classes {skipped} (absent from, or the only class in, y_true)."
    )
    if keep.size == 1:
        c = keep[0]
        return float(roc_auc_score(onehot[:, c], probs[:, c]))
    return float(roc_auc_score(onehot[:, keep], probs[:, keep], average="macro"))

@torch.no_grad()
def eval_loader(model, loader, device) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
    """Run ``model`` over ``loader`` without gradients.

    The model is switched to eval mode and left there.

    Returns:
        (labels, argmax predictions, softmax probabilities, per-sample mean of the
        unweighted cross-entropy; 0.0 for an empty loader's loss term)
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    label_parts, pred_parts, prob_parts = [], [], []
    loss_sum, seen = 0.0, 0

    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        logits = model(inputs)
        probs = F.softmax(logits, dim=1)
        batch_loss = criterion(logits, targets)

        n_batch = inputs.size(0)
        loss_sum += float(batch_loss.item()) * n_batch  # undo the batch mean
        seen += n_batch

        label_parts.append(targets.cpu().numpy())
        pred_parts.append(logits.argmax(1).cpu().numpy())
        prob_parts.append(probs.cpu().numpy())

    return (np.concatenate(label_parts), np.concatenate(pred_parts),
            np.concatenate(prob_parts), loss_sum / max(seen, 1))

# =========================================================
# TTA (val selection only, then apply to test)
# =========================================================
def _tta_views_transforms(n_views: int):
    """Return ``(list_of_transforms, policy_description)`` for 1, 2 or 4 TTA views.

    Every view resizes the short side (256 or 288), center-crops 224, optionally
    flips horizontally (p=1.0), converts to a tensor and applies ImageNet
    normalisation.

    Raises:
        ValueError: if ``n_views`` is not 1, 2 or 4.
    """
    norm = transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])

    def make_view(resize: int, flip: bool):
        steps = [transforms.Resize(resize), transforms.CenterCrop(224)]
        if flip:
            steps.append(transforms.RandomHorizontalFlip(p=1.0))
        steps += [transforms.ToTensor(), norm]
        return transforms.Compose(steps)

    policies = {
        1: ([(256, False)], "V1: center@256"),
        2: ([(256, False), (256, True)], "V2: center@256 + hflip@256"),
        4: ([(256, False), (256, True), (288, False), (288, True)], "V4: center/hflip @256 and @288"),
    }
    if n_views not in policies:
        raise ValueError("n_views must be one of {1,2,4}")
    specs, description = policies[n_views]
    return [make_view(size, flip) for size, flip in specs], description

@torch.no_grad()
def tta_probs(model, dataset_dir: str, class_names: List[str], device, batch_size, num_workers, seed, n_views: int):
    """Softmax probabilities on an ImageFolder split, averaged over TTA views.

    One un-shuffled pass is made per view from ``_tta_views_transforms(n_views)``, so
    rows line up across views and with ``y_true`` (ImageFolder sample order). The
    model is put in eval mode and left there.

    Returns:
        (y_true (N,), mean probabilities (N, C), policy description)
    """
    gen = torch.Generator()
    gen.manual_seed(seed)

    dataset = ImageFolder(dataset_dir)
    assert dataset.classes == class_names, "TTA dataset class order mismatch."

    view_tfms, policy = _tta_views_transforms(n_views)
    y_true = np.array(dataset.targets, dtype=np.int64)

    model.eval()
    per_view = []
    for view_tfm in view_tfms:
        dataset.transform = view_tfm  # same samples and order, different preprocessing
        view_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=torch.cuda.is_available(),
            worker_init_fn=seed_worker,
            generator=gen,
            persistent_workers=num_workers > 0,
        )
        batch_probs = [
            F.softmax(model(images.to(device)), dim=1).cpu().numpy()
            for images, _ in view_loader
        ]
        per_view.append(np.concatenate(batch_probs, axis=0))

    return y_true, np.mean(per_view, axis=0), policy

def pick_tta_policy_on_val(model, val_dir: str, class_names: List[str], device, batch_size, num_workers, seed,
                           candidate_views=(1, 2, 4)):
    """Choose the TTA view count with the best validation accuracy.

    Candidates are evaluated in order; a later one replaces the current best only if
    it is strictly better, so ties keep the earlier candidate.

    Returns:
        (best = {"views", "acc", "policy"}, per_policy = {str(views): {"val_acc", "policy"}})
    """
    per_policy = {}
    best = {"views": 1, "acc": -1.0, "policy": ""}

    for views in candidate_views:
        labels, mean_probs, policy_name = tta_probs(
            model=model, dataset_dir=val_dir, class_names=class_names,
            device=device, batch_size=batch_size, num_workers=num_workers, seed=seed, n_views=views
        )
        val_acc = float(accuracy_score(labels, mean_probs.argmax(axis=1)))
        per_policy[str(views)] = {"val_acc": val_acc, "policy": policy_name}
        if val_acc > best["acc"]:
            best = {"views": int(views), "acc": val_acc, "policy": policy_name}

    return best, per_policy

# =========================================================
# Train one seed
# =========================================================
def train_one_seed(
    mode: str,
    seed: int,
    pretrained_path: str,
    train_ds, val_ds, test_ds,
    class_names: List[str],
    val_dir: str,
    test_dir: str,
    device: torch.device,
    out_dir: str,
    epochs: int,
    patience: int,
    batch_size: int,
    num_workers: int,
    use_weighted_sampler: bool,
    lr_backbone: float,
    lr_head: float,
    weight_decay: float,
    label_smoothing: float,
    mixup_alpha: float,
    supcon_weight: float,
    temperature: float,
    tta_candidates: List[int],
) -> Dict:
    """Fine-tune one FineTuneCBAM model (CE or HYBRID) for a single seed and evaluate it.

    Training uses AdamW with separate backbone/head learning rates and a per-epoch cosine
    schedule. It keeps the checkpoint with the best validation accuracy and stops early after
    ``patience`` consecutive epochs without improvement. That checkpoint is then evaluated on
    ``test_loader`` (single pass) and with the TTA view count selected on ``val_dir``.

    Args:
        mode: "CE" for cross-entropy only. "HYBRID" (case-insensitive) adds
            ``supcon_weight * SupCon`` on two views per sample; ``train_ds`` must then yield
            ``((x1, x2), y)`` batches.
        seed: RNG seed for this run (also used for the TTA loaders).
        pretrained_path: SimSiam-pretrained backbone weights passed to FineTuneCBAM.
        train_ds, val_ds, test_ds: datasets handed to ``build_loaders``.
        class_names: class order; its length sets the classifier width.
        val_dir, test_dir: ImageFolder roots re-read for TTA.
        epochs, patience: training length and early-stopping patience (epochs).
        use_weighted_sampler: if True the loss is unweighted (the sampler already rebalances);
            otherwise inverse-frequency class weights are used.
        mixup_alpha: > 0 enables mixup (on view 1 only in HYBRID mode).
        tta_candidates: view counts to try on validation (subset of {1, 2, 4}).

    Returns:
        A JSON-serializable dict of metrics plus the best checkpoint path.

    Raises:
        ValueError: if ``epochs < 1`` (no checkpoint would ever be written).
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    seed_all(seed)
    os.makedirs(out_dir, exist_ok=True)

    hybrid = (mode.upper() == "HYBRID")
    model = FineTuneCBAM(pretrained_path, num_classes=len(class_names), hybrid=hybrid).to(device)

    # BN trainable: force BN affine params to require grad (overrides any freezing the
    # FineTuneCBAM constructor may have done -- its body is not visible here).
    model.train()
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            m.train()
            for p in m.parameters():
                p.requires_grad = True

    train_loader, val_loader, test_loader = build_loaders(
        train_ds, val_ds, test_ds,
        batch_size=batch_size, num_workers=num_workers, seed=seed,
        use_weighted_sampler=use_weighted_sampler
    )

    # sampler => unweighted CE (avoid double-compensation)
    if use_weighted_sampler:
        ce = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
    else:
        targets = np.array(train_ds.targets, dtype=np.int64)
        # minlength keeps one weight per class even if the trailing classes have no samples;
        # otherwise the weight vector is shorter than the logits and CE fails.
        counts = np.bincount(targets, minlength=len(class_names))
        counts[counts == 0] = 1
        w = (1.0 / counts)
        w = w / w.mean()  # inverse-frequency weights normalized to mean 1
        ce = nn.CrossEntropyLoss(
            weight=torch.tensor(w, dtype=torch.float32, device=device),
            label_smoothing=label_smoothing
        )

    supcon = SupConLoss(temperature=temperature) if hybrid else None

    # param groups: backbone vs head (+ supcon head is included in head group)
    backbone_params, head_params = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if n.startswith("backbone"):
            backbone_params.append(p)
        else:
            head_params.append(p)

    optimizer = torch.optim.AdamW(
        [{"params": backbone_params, "lr": lr_backbone},
         {"params": head_params, "lr": lr_head}],
        weight_decay=weight_decay
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    autocast_ctx, scaler = make_amp(device)

    best_val_acc = -1.0
    bad = 0
    best_path = os.path.join(out_dir, f"best_{mode.lower()}_seed{seed}.pth")

    for ep in range(1, epochs + 1):
        model.train()
        total_loss = 0.0
        correct = 0.0  # float: mixup credits fractional correctness
        nseen = 0

        for batch in train_loader:
            optimizer.zero_grad(set_to_none=True)

            if hybrid:
                (x1, x2), y = batch
                x1 = x1.to(device, non_blocking=True)
                x2 = x2.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)

                with autocast_ctx():
                    # CE on view-1 (optional mixup on view-1 only)
                    if mixup_alpha and mixup_alpha > 0:
                        xmix, ya, yb, lam = mixup_data(x1, y, mixup_alpha, device)
                        feats_mix = model.forward_feats(xmix)
                        logits_mix = model.logits_from_feats(feats_mix)
                        loss_ce = mixup_ce_loss(ce, logits_mix, ya, yb, lam)

                        pred = logits_mix.argmax(1)
                        correct += lam * pred.eq(ya).sum().item() + (1.0 - lam) * pred.eq(yb).sum().item()

                        # SupCon needs un-mixed view-1 features: one extra backbone pass.
                        feats1_clean = model.forward_feats(x1)
                    else:
                        feats1 = model.forward_feats(x1)
                        logits1 = model.logits_from_feats(feats1)
                        loss_ce = ce(logits1, y)

                        pred = logits1.argmax(1)
                        correct += pred.eq(y).sum().item()

                        # View 1 is already clean here: reuse its features instead of running
                        # an identical second backbone pass on the same batch.
                        feats1_clean = feats1

                    # SupCon on clean views
                    feats2_clean = model.forward_feats(x2)
                    z1 = model.supcon_from_feats(feats1_clean)
                    z2 = model.supcon_from_feats(feats2_clean)

                    feats = torch.stack([z1, z2], dim=1)  # (B,2,128)
                    loss_sup = supcon(feats, y)

                    loss = loss_ce + supcon_weight * loss_sup

            else:
                x, y = batch
                x = x.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)

                with autocast_ctx():
                    if mixup_alpha and mixup_alpha > 0:
                        xmix, ya, yb, lam = mixup_data(x, y, mixup_alpha, device)
                        feats_mix = model.forward_feats(xmix)
                        logits = model.logits_from_feats(feats_mix)
                        loss = mixup_ce_loss(ce, logits, ya, yb, lam)

                        pred = logits.argmax(1)
                        correct += lam * pred.eq(ya).sum().item() + (1.0 - lam) * pred.eq(yb).sum().item()
                    else:
                        feats = model.forward_feats(x)
                        logits = model.logits_from_feats(feats)
                        loss = ce(logits, y)

                        pred = logits.argmax(1)
                        correct += pred.eq(y).sum().item()

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            bs = y.size(0)
            total_loss += float(loss.item()) * bs
            nseen += bs

        scheduler.step()

        train_loss = total_loss / max(nseen, 1)
        train_acc = float(correct) / max(nseen, 1)

        yv, pv, _, vloss = eval_loader(model, val_loader, device)
        val_acc = float(accuracy_score(yv, pv))
        val_f1 = float(f1_score(yv, pv, average="macro"))

        saved = ""
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            bad = 0
            torch.save(model.state_dict(), best_path)
            saved = "(saved)"
        else:
            bad += 1

        print(f"[{mode}][seed={seed}] Epoch {ep:02d}/{epochs} | "
              f"train loss {train_loss:.4f} acc {train_acc:.4f} || "
              f"val loss {vloss:.4f} acc {val_acc:.4f} macroF1 {val_f1:.4f} {saved}")

        if bad >= patience:
            print(f"[{mode}][seed={seed}] Early stopping.")
            break

    # Reload the best-on-val weights before any test-time evaluation.
    model.load_state_dict(torch.load(best_path, map_location=device))
    model.eval()

    yt, pt, prt, tloss = eval_loader(model, test_loader, device)
    test_acc = float(accuracy_score(yt, pt))
    auc = float(macro_roc_auc(yt, prt, len(class_names)))

    # TTA view count is chosen on VAL only, then applied once to TEST.
    best_tta, per_policy = pick_tta_policy_on_val(
        model=model,
        val_dir=val_dir,
        class_names=class_names,
        device=device,
        batch_size=batch_size,
        num_workers=num_workers,
        seed=seed,
        candidate_views=tuple(tta_candidates)
    )

    ytta, prob_tta, chosen_policy = tta_probs(
        model=model,
        dataset_dir=test_dir,
        class_names=class_names,
        device=device,
        batch_size=batch_size,
        num_workers=num_workers,
        seed=seed,
        n_views=best_tta["views"]
    )
    pred_tta = prob_tta.argmax(axis=1)
    tta_acc = float(accuracy_score(ytta, pred_tta))
    auc_tta = float(macro_roc_auc(ytta, prob_tta, len(class_names)))

    rep = classification_report(yt, pt, target_names=class_names, digits=4)
    cm = confusion_matrix(yt, pt).tolist()

    return {
        "mode": mode,
        "seed": int(seed),
        "epochs_ran": int(ep),
        "best_val_acc": float(best_val_acc),
        "test_loss": float(tloss),
        "test_acc": float(test_acc),
        "tta_acc": float(tta_acc),
        "macro_auc": float(auc),
        "macro_auc_tta": float(auc_tta),
        "params_m": float(round(count_params_m(model), 4)),
        "tta_candidates": list(map(int, tta_candidates)),
        "tta_val_selection": {
            "picked_views": int(best_tta["views"]),
            "picked_val_acc": float(best_tta["acc"]),
            "picked_policy": best_tta["policy"],
            "all_candidates": per_policy
        },
        "tta_policy_applied_to_test": chosen_policy,
        "confusion_matrix": cm,
        "classification_report": rep,
        "checkpoint_path": best_path
    }

# =========================================================
# Summary
# =========================================================
@dataclass
class SummaryStats:
    """Mean and spread of one metric across seeds."""
    mean: float  # arithmetic mean of the values (0.0 when there are none)
    std: float   # sample standard deviation (ddof=1); 0.0 with fewer than two values


def summarize(values: List[float]) -> SummaryStats:
    """Collapse per-seed values into a SummaryStats (mean, sample std).

    An empty input gives mean=0.0/std=0.0, and a single value gives std=0.0, because the
    ddof=1 estimate is undefined below two samples.
    """
    arr = np.array(values, dtype=np.float64)
    count = len(arr)
    if count == 0:
        return SummaryStats(mean=0.0, std=0.0)
    if count == 1:
        return SummaryStats(mean=float(arr.mean()), std=0.0)
    return SummaryStats(mean=float(arr.mean()), std=float(arr.std(ddof=1)))

# =========================================================
# Main
# =========================================================
def _build_finetune_arg_parser() -> argparse.ArgumentParser:
    """Command-line interface for the multi-seed CE vs HYBRID fine-tuning comparison."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--data_root", type=str, required=True)
    ap.add_argument("--pretrained_path", type=str, required=True)
    ap.add_argument("--out_json", type=str, default="cbam_results_3seeds.json")
    ap.add_argument("--out_dir", type=str, default="./runs_cbam_reviewproof")

    ap.add_argument("--seeds", type=int, nargs="+", default=[0, 1, 2])
    ap.add_argument("--epochs", type=int, default=60)
    ap.add_argument("--patience", type=int, default=10)
    ap.add_argument("--batch_size", type=int, default=32)
    ap.add_argument("--num_workers", type=int, default=2)

    ap.add_argument("--use_weighted_sampler", action="store_true")

    ap.add_argument("--label_smoothing", type=float, default=0.0)
    ap.add_argument("--weight_decay", type=float, default=1e-4)

    ap.add_argument("--ce_lr_backbone", type=float, default=3e-5)
    ap.add_argument("--ce_lr_head", type=float, default=2e-4)
    ap.add_argument("--ce_mixup_alpha", type=float, default=0.0)

    ap.add_argument("--hy_lr_backbone", type=float, default=3e-5)
    ap.add_argument("--hy_lr_head", type=float, default=1e-4)
    ap.add_argument("--hy_mixup_alpha", type=float, default=0.05)
    ap.add_argument("--supcon_weight", type=float, default=0.04)
    ap.add_argument("--temperature", type=float, default=0.07)

    ap.add_argument("--tta_candidates", type=int, nargs="+", default=[1, 2, 4],
                    help="TTA candidates on VAL (subset of {1,2,4}).")
    return ap


def _dump_results_json(results: dict, path: str) -> None:
    """Write ``results`` to ``path`` as indented JSON, atomically.

    The parent directory is created if needed. The data goes to ``path + ".tmp"`` first and is
    then renamed over ``path``, so an interrupted write never leaves a truncated results file.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    tmp_path = path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(results, f, indent=2)
    os.replace(tmp_path, path)


def main():
    """Fine-tune CE and HYBRID models for every seed, then write per-seed metrics and a summary.

    The results JSON is rewritten after every finished run, so a crash or session timeout
    part-way through keeps the runs that already completed. The final write adds "summary".
    """
    args = _build_finetune_arg_parser().parse_args()

    allowed = {1, 2, 4}
    cand = [int(x) for x in args.tta_candidates]
    if any(x not in allowed for x in cand) or len(cand) == 0:
        raise ValueError("--tta_candidates must be a non-empty subset of {1,2,4} (e.g., 1 2 4).")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    if not os.path.exists(args.pretrained_path):
        raise FileNotFoundError(f"Pretrained weights not found: {args.pretrained_path}")

    train_ds_ce, train_ds_hybrid, val_ds, test_ds, class_names, train_dir, val_dir, test_dir = build_datasets(args.data_root)
    print("Classes:", class_names)
    print("Split dirs:", {"train": train_dir, "val": val_dir, "test": test_dir})

    all_results = {"config": vars(args), "class_names": class_names, "per_seed": []}

    for mode in ["CE", "HYBRID"]:
        is_ce = (mode == "CE")
        train_ds = train_ds_ce if is_ce else train_ds_hybrid
        for seed in args.seeds:
            r = train_one_seed(
                mode=mode,
                seed=seed,
                pretrained_path=args.pretrained_path,
                train_ds=train_ds,
                val_ds=val_ds,
                test_ds=test_ds,
                class_names=class_names,
                val_dir=val_dir,
                test_dir=test_dir,
                device=device,
                out_dir=args.out_dir,
                epochs=args.epochs,
                patience=args.patience,
                batch_size=args.batch_size,
                num_workers=args.num_workers,
                use_weighted_sampler=args.use_weighted_sampler,
                lr_backbone=(args.ce_lr_backbone if is_ce else args.hy_lr_backbone),
                lr_head=(args.ce_lr_head if is_ce else args.hy_lr_head),
                weight_decay=args.weight_decay,
                label_smoothing=args.label_smoothing,
                mixup_alpha=(args.ce_mixup_alpha if is_ce else args.hy_mixup_alpha),
                supcon_weight=args.supcon_weight,
                temperature=args.temperature,
                tta_candidates=cand
            )
            all_results["per_seed"].append(r)
            # Persist progress so completed runs survive a later crash/timeout.
            _dump_results_json(all_results, args.out_json)

    def collect(mode: str, key: str) -> List[float]:
        return [x[key] for x in all_results["per_seed"] if x["mode"] == mode]

    summary = {}
    for mode in ["CE", "HYBRID"]:
        s = {
            "test_acc_%": asdict(summarize([v * 100.0 for v in collect(mode, "test_acc")])),
            "tta_acc_%": asdict(summarize([v * 100.0 for v in collect(mode, "tta_acc")])),
            "macro_auc": asdict(summarize(collect(mode, "macro_auc"))),
            "macro_auc_tta": asdict(summarize(collect(mode, "macro_auc_tta"))),
            "params_m": asdict(summarize(collect(mode, "params_m"))),
        }
        summary[mode] = s

    all_results["summary"] = summary

    print("\n================ SUMMARY (mean ± std over seeds) ================\n")
    for mode in ["CE", "HYBRID"]:
        s = summary[mode]
        print(f"MODE: {mode}")
        print(f"Test Acc (%): {s['test_acc_%']['mean']:.2f} ± {s['test_acc_%']['std']:.2f}")
        print(f"TTA  Acc (%): {s['tta_acc_%']['mean']:.2f} ± {s['tta_acc_%']['std']:.2f}")
        print(f"Macro ROC-AUC (single): {s['macro_auc']['mean']:.4f} ± {s['macro_auc']['std']:.4f}")
        print(f"Macro ROC-AUC (TTA):    {s['macro_auc_tta']['mean']:.4f} ± {s['macro_auc_tta']['std']:.4f}")
        print(f"# Params (M): {s['params_m']['mean']:.2f} ± {s['params_m']['std']:.2f}")
        print()

    _dump_results_json(all_results, args.out_json)
    print(f"Saved per-seed metrics: {args.out_json}")

if __name__ == "__main__":
    main()

Overwriting finetune_cbam_eval.py


In [6]:
# Run the fine-tuning script: trains CE and then HYBRID for each of seeds 0/1/2 and writes
# per-seed metrics (+ summary) to --out_json (default: cbam_results_3seeds.json in the cwd).
# Every optional flag below restates its argparse default; it is spelled out for the record.
!python finetune_cbam_eval.py \
  --data_root /kaggle/input/minida/mini_output1 \
  --pretrained_path /kaggle/working/simsiam_cbam/simsiam_cbam_pretrained_final.pth \
  --seeds 0 1 2 \
  --epochs 60 --patience 10 --batch_size 32 --num_workers 2 \
  --label_smoothing 0.0 --weight_decay 1e-4 \
  --ce_lr_backbone 3e-5 --ce_lr_head 2e-4 --ce_mixup_alpha 0.0 \
  --hy_lr_backbone 3e-5 --hy_lr_head 1e-4 --hy_mixup_alpha 0.05 \
  --supcon_weight 0.04 --temperature 0.07 \
  --tta_candidates 1 2 4

Device: cuda
Classes: ['Alternaria', 'Healthy Leaf', 'straw_mite']
Split dirs: {'train': '/kaggle/input/minida/mini_output1/train', 'val': '/kaggle/input/minida/mini_output1/val', 'test': '/kaggle/input/minida/mini_output1/test'}
[CE][seed=0] Epoch 01/60 | train loss 0.6994 acc 0.6850 || val loss 0.9797 acc 0.3232 macroF1 0.1803 (saved)
[CE][seed=0] Epoch 02/60 | train loss 0.4927 acc 0.7970 || val loss 0.4994 acc 0.8687 macroF1 0.8637 (saved)
[CE][seed=0] Epoch 03/60 | train loss 0.4752 acc 0.8055 || val loss 0.2702 acc 0.9091 macroF1 0.9106 (saved)
[CE][seed=0] Epoch 04/60 | train loss 0.4743 acc 0.8055 || val loss 0.3512 acc 0.8687 macroF1 0.8637 
[CE][seed=0] Epoch 05/60 | train loss 0.4388 acc 0.8161 || val loss 0.2270 acc 0.9192 macroF1 0.9210 (saved)
[CE][seed=0] Epoch 06/60 | train loss 0.3820 acc 0.8224 || val loss 0.2605 acc 0.8990 macroF1 0.8962 
[CE][seed=0] Epoch 07/60 | train loss 0.3761 acc 0.8266 || val loss 0.2641 acc 0.8990 macroF1 0.8962 
[CE][seed=0] Epoch 08/60 | t

In [1]:
# --- Single cell: generate Confusion Matrix + ROC (CBAM CE & HYBRID) from saved checkpoints ---

import os, json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, roc_curve, auc
from sklearn.preprocessing import label_binarize

# -----------------------------
# USER SETTINGS (edit if needed)
# -----------------------------
# Paths: DATA_ROOT holds the train/ val/ test/ ImageFolder splits; figures go to OUT_DIR.
DATA_ROOT = "/kaggle/input/minida/mini_output1"
RESULTS_JSON = "/kaggle/working/cbam_results_3seeds.json"
OUT_DIR = "/kaggle/working"

# Runtime settings for test-set inference.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
BATCH_SIZE = 64
NUM_WORKERS = 2

os.makedirs(OUT_DIR, exist_ok=True)

# -----------------------------
# CBAM-ResNet50 definition (MUST match finetune_cbam_eval.py)
# -----------------------------
from torchvision.models.resnet import ResNet, Bottleneck

class ChannelAttention(nn.Module):
    """CBAM channel gate: a shared 1x1-conv MLP applied to global avg- and max-pooled features.

    For a (B, C, H, W) input it returns a (B, C, 1, 1) sigmoid mask. Submodule names
    (fc1/fc2) must stay as-is because saved checkpoints key on them.
    """

    def __init__(self, in_planes: int, ratio: int = 16):
        super().__init__()
        squeezed = max(in_planes // ratio, 1)  # never reduce to zero channels
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_planes, squeezed, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(squeezed, in_planes, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def _shared_mlp(self, pooled: torch.Tensor) -> torch.Tensor:
        """fc1 -> ReLU -> fc2 bottleneck shared by both pooling branches."""
        return self.fc2(self.relu(self.fc1(pooled)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg_branch = self._shared_mlp(self.avg_pool(x))
        max_branch = self._shared_mlp(self.max_pool(x))
        return self.sigmoid(avg_branch + max_branch)

class SpatialAttention(nn.Module):
    """CBAM spatial gate: convolve stacked channel-mean/channel-max maps into a (B, 1, H, W) mask."""

    def __init__(self, kernel_size: int = 7):
        super().__init__()
        # "same" padding for odd kernel sizes keeps the spatial size unchanged
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = x.max(dim=1, keepdim=True).values
        stacked = torch.cat((mean_map, max_map), dim=1)
        return self.sigmoid(self.conv(stacked))

class CBAMBottleneck(Bottleneck):
    """torchvision Bottleneck with CBAM gating (channel, then spatial) applied after bn3,
    before the residual addition."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        expanded_channels = self.conv3.out_channels
        self.ca = ChannelAttention(expanded_channels)
        self.sa = SpatialAttention()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Main 1x1 -> 3x3 -> 1x1 path.
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        # CBAM refinement of the residual branch.
        y = y * self.ca(y)
        y = y * self.sa(y)

        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(y + shortcut)

def cbam_resnet50(**kwargs) -> ResNet:
    """Build a ResNet-50 layout ([3, 4, 6, 3] stages) out of CBAMBottleneck blocks.

    Each bottleneck's last BN scale (bn3.weight) is zeroed after construction so every
    residual branch starts as identity (stability). ``kwargs`` go straight to ``ResNet``.
    """
    net = ResNet(CBAMBottleneck, [3, 4, 6, 3], **kwargs)
    last_bns = [
        block.bn3
        for block in net.modules()
        if isinstance(block, Bottleneck) and hasattr(block, "bn3") and block.bn3.weight is not None
    ]
    for bn in last_bns:
        nn.init.constant_(bn.weight, 0)
    return net


# -----------------------------
# Model definition (must match finetune_cbam_eval.py)
# -----------------------------
class FineTuneCBAM(nn.Module):
    """Inference mirror of the fine-tuned model: CBAM-ResNet50 trunk (through avgpool) + MLP head.

    ``forward`` returns class logits only. When ``hybrid`` is True a ``supcon_proj`` layer is
    also registered. It is not used by ``forward``; presumably it exists so HYBRID checkpoints
    load with ``strict=True``.
    """

    def __init__(self, num_classes: int, hybrid: bool):
        super().__init__()
        trunk = cbam_resnet50(num_classes=1000)
        # Keep every child except the final fc -> output is (B, 2048, 1, 1) after avgpool.
        self.backbone = nn.Sequential(*list(trunk.children())[:-1])

        self.classifier = nn.Sequential(
            nn.Linear(2048, 512, bias=True),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.30),
            nn.Linear(512, num_classes, bias=True),
        )

        self.hybrid = bool(hybrid)
        self.supcon_proj = nn.Linear(2048, 128, bias=True) if self.hybrid else None

    def forward(self, x):
        pooled = torch.flatten(self.backbone(x), 1)
        return self.classifier(pooled)


# -----------------------------
# Data
# -----------------------------
# Deterministic evaluation preprocessing (ImageNet normalization statistics).
eval_tf = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

test_dir = os.path.join(DATA_ROOT, "test")
test_ds = ImageFolder(test_dir, transform=eval_tf)
class_names = test_ds.classes
n_classes = len(class_names)

# File order (shuffle=False) so predictions line up with ImageFolder targets.
test_loader = DataLoader(
    test_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
    persistent_workers=NUM_WORKERS > 0,
)

# -----------------------------
# Helper: pick best checkpoint per mode from cbam_results_3seeds.json
# -----------------------------
# Per-seed metrics written by finetune_cbam_eval.py (entries under R["per_seed"]).
with open(RESULTS_JSON, "r") as f:
    R = json.loads(f.read())

def pick_best(mode: str, records=None, tie_break_on_test: bool = False):
    """Return the per-seed result entry with the highest ``best_val_acc`` for ``mode``.

    Selection is test-blind by default: ties on validation accuracy go to the earliest entry in
    the results list, i.e. seed order. The test figures produced from the pick are therefore not
    optimistically biased. Set ``tie_break_on_test=True`` to reproduce the old behaviour, which
    broke ties by the higher ``test_acc``.

    Args:
        mode: "CE" or "HYBRID" (case-insensitive).
        records: list of per-seed dicts; defaults to ``R["per_seed"]`` loaded above.
        tie_break_on_test: opt-in to the legacy test-accuracy tie-breaker.

    Raises:
        RuntimeError: if no entry matches ``mode``.
    """
    pool = R["per_seed"] if records is None else records
    candidates = [x for x in pool if x["mode"].upper() == mode.upper()]
    if not candidates:
        raise RuntimeError(f"No entries found for mode={mode} in {RESULTS_JSON}")

    if tie_break_on_test:
        def rank(x):
            return (x.get("best_val_acc", -1), x.get("test_acc", -1))
    else:
        def rank(x):
            return x.get("best_val_acc", -1)
    # max() returns the first maximal element, so ties keep the original list order.
    return max(candidates, key=rank)

best_ce = pick_best("CE")
best_hy = pick_best("HYBRID")

# Report which seed/checkpoint was selected for each mode.
for _mode_label, _picked in (("CE", best_ce), ("HYBRID", best_hy)):
    print(
        f"Picked {_mode_label} checkpoint:", _picked["checkpoint_path"],
        "| seed:", _picked["seed"],
        "| best_val_acc:", _picked["best_val_acc"],
    )


# -----------------------------
# Inference: get y_true, probs, preds
# -----------------------------
@torch.no_grad()
def infer(checkpoint_path: str, hybrid: bool):
    """Load a fine-tuned checkpoint (strict) and score the whole test set.

    Returns:
        (y_true, probs, preds): integer labels in loader order, softmax probabilities of shape
        (N, n_classes), and the argmax class per sample.
    """
    model = FineTuneCBAM(num_classes=n_classes, hybrid=hybrid).to(DEVICE)
    model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE), strict=True)
    model.eval()

    label_chunks, prob_chunks = [], []
    for images, labels in test_loader:
        batch_probs = F.softmax(model(images.to(DEVICE)), dim=1).cpu().numpy()
        prob_chunks.append(batch_probs)
        label_chunks.append(labels.numpy())

    probs = np.concatenate(prob_chunks)
    return np.concatenate(label_chunks), probs, np.argmax(probs, axis=1)


# -----------------------------
# Plot: Confusion Matrix
# -----------------------------
def plot_confusion(y_true, y_pred, title, out_path):
    """Save a count-annotated confusion-matrix heatmap (rows = true, columns = predicted).

    ``labels=arange(n_classes)`` forces a full n_classes x n_classes matrix in ``class_names``
    order. Without it, sklearn shrinks the matrix to the labels actually present, which would
    misalign the tick labels whenever a class never occurs in ``y_true`` or ``y_pred``.
    """
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(n_classes))
    fig, ax = plt.subplots(figsize=(6.2, 5.4))
    im = ax.imshow(cm, interpolation="nearest")
    fig.colorbar(im)

    ax.set_title(title)
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    ax.set_xticks(np.arange(n_classes))
    ax.set_yticks(np.arange(n_classes))
    ax.set_xticklabels(class_names, rotation=45, ha="right")
    ax.set_yticklabels(class_names)

    # Switch annotation color at half the max count so text stays readable on dark cells.
    thresh = cm.max() / 2.0 if cm.max() > 0 else 0.5
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(
                j, i, format(cm[i, j], "d"),
                ha="center", va="center",
                color="white" if cm[i, j] > thresh else "black",
                fontsize=11
            )

    fig.tight_layout()
    fig.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print("Saved:", out_path)


# -----------------------------
# Plot: ROC curves (per-class + micro + macro)
# -----------------------------
def plot_roc(y_true, probs, title, out_path):
    """Save one-vs-rest ROC curves: per class, micro-average and macro-average (with AUCs).

    Labels are one-hot encoded with ``np.eye`` rather than ``label_binarize``. For two classes
    ``label_binarize`` returns a single column, and ``y_bin[:, 1]`` would then raise IndexError.
    The macro curve averages the per-class TPRs interpolated onto the union of their FPR grids.
    """
    y_bin = np.eye(n_classes, dtype=int)[np.asarray(y_true)]  # (N, C) one-hot

    fpr, tpr, roc_auc = {}, {}, {}

    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_bin[:, i], probs[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    fpr["micro"], tpr["micro"], _ = roc_curve(y_bin.ravel(), probs.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(n_classes):
        mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
    mean_tpr /= n_classes
    fpr["macro"] = all_fpr
    tpr["macro"] = mean_tpr
    roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

    fig, ax = plt.subplots(figsize=(6.2, 5.4))

    ax.plot(fpr["micro"], tpr["micro"], linewidth=2.5, label=f"micro-average (AUC = {roc_auc['micro']:.4f})")
    ax.plot(fpr["macro"], tpr["macro"], linewidth=2.5, label=f"macro-average (AUC = {roc_auc['macro']:.4f})")

    for i, name in enumerate(class_names):
        ax.plot(fpr[i], tpr[i], linewidth=1.8, label=f"{name} (AUC = {roc_auc[i]:.4f})")

    ax.plot([0, 1], [0, 1], linestyle="--", linewidth=1.5)  # chance diagonal
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title(title)
    ax.legend(loc="lower right", fontsize=9)
    fig.tight_layout()
    fig.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print("Saved:", out_path)


# -----------------------------
# Generate for CE
# -----------------------------
def render_mode_figures(best_entry: dict, hybrid: bool, display_name: str, file_tag: str):
    """Run test-set inference for one picked checkpoint and save its confusion matrix and ROC figure.

    Args:
        best_entry: per-seed result dict from ``pick_best`` (its "checkpoint_path" is loaded).
        hybrid: forwarded to ``infer``; True builds the model with the ``supcon_proj`` layer.
        display_name: mode label used in figure titles (e.g. "CE", "Hybrid").
        file_tag: mode label used in output file names (e.g. "CE", "HYBRID").

    Returns:
        (y_true, probs, preds) from ``infer`` so later cells can reuse them.
    """
    y_true, probs, preds = infer(best_entry["checkpoint_path"], hybrid=hybrid)
    model_desc = f"SimSiam-CBAM-ResNet-50, {display_name}"
    plot_confusion(
        y_true, preds,
        f"Confusion Matrix ({model_desc})",
        os.path.join(OUT_DIR, f"CBAM_SSM_CM_{file_tag}.png")
    )
    plot_roc(
        y_true, probs,
        f"ROC Curve ({model_desc})",
        os.path.join(OUT_DIR, f"CBAM_SSM_ROC_{file_tag}.png")
    )
    return y_true, probs, preds


# Same pipeline for both fine-tuning modes; keep the per-mode names for downstream cells.
y_ce, pr_ce, pd_ce = render_mode_figures(best_ce, hybrid=False, display_name="CE", file_tag="CE")
y_hy, pr_hy, pd_hy = render_mode_figures(best_hy, hybrid=True, display_name="Hybrid", file_tag="HYBRID")

Picked CE checkpoint: ./runs_cbam_reviewproof/best_ce_seed0.pth | seed: 0 | best_val_acc: 0.9595959595959596
Picked HYBRID checkpoint: ./runs_cbam_reviewproof/best_hybrid_seed2.pth | seed: 2 | best_val_acc: 0.9595959595959596
Saved: /kaggle/working/CBAM_SSM_CM_CE.png
Saved: /kaggle/working/CBAM_SSM_ROC_CE.png
Saved: /kaggle/working/CBAM_SSM_CM_HYBRID.png
Saved: /kaggle/working/CBAM_SSM_ROC_HYBRID.png


In [16]:
%%writefile cbam_temp_scaling_eval_from_ablation_json.py
import os, json, argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
warnings.filterwarnings(
    "ignore",
    message="adaptive_max_pool2d_backward_cuda does not have a deterministic implementation*",
)
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from sklearn.metrics import accuracy_score, roc_auc_score

# -----------------------------
# CBAM-ResNet50 (same as your finetune_cbam_eval.py)
# -----------------------------
from torchvision.models.resnet import ResNet, Bottleneck

class ChannelAttention(nn.Module):
    """CBAM channel gate: a shared 1x1-conv MLP applied to global avg- and max-pooled features.

    For a (B, C, H, W) input it returns a (B, C, 1, 1) sigmoid mask. Submodule names
    (fc1/fc2) must stay as-is because saved checkpoints key on them.
    """

    def __init__(self, in_planes: int, ratio: int = 16):
        super().__init__()
        squeezed = max(in_planes // ratio, 1)  # never reduce to zero channels
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_planes, squeezed, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(squeezed, in_planes, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def _shared_mlp(self, pooled):
        """fc1 -> ReLU -> fc2 bottleneck shared by both pooling branches."""
        return self.fc2(self.relu(self.fc1(pooled)))

    def forward(self, x):
        summed = self._shared_mlp(self.avg_pool(x)) + self._shared_mlp(self.max_pool(x))
        return self.sigmoid(summed)

class SpatialAttention(nn.Module):
    """CBAM spatial gate: convolve stacked channel-mean/channel-max maps into a (B, 1, H, W) mask."""

    def __init__(self, kernel_size: int = 7):
        super().__init__()
        # "same" padding for odd kernel sizes keeps the spatial size unchanged
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        pooled = torch.cat(
            (x.mean(dim=1, keepdim=True), x.max(dim=1, keepdim=True).values),
            dim=1,
        )
        return self.sigmoid(self.conv(pooled))

class CBAMBottleneck(Bottleneck):
    """torchvision Bottleneck with CBAM gating (channel, then spatial) applied after bn3,
    before the residual addition."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        expanded_channels = self.conv3.out_channels
        self.ca = ChannelAttention(expanded_channels)
        self.sa = SpatialAttention()

    def forward(self, x):
        # Main 1x1 -> 3x3 -> 1x1 path (ReLU after bn1 and bn2, none after bn3).
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        # CBAM refinement of the residual branch.
        y = y * self.ca(y)
        y = y * self.sa(y)

        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(y + shortcut)

def cbam_resnet50(**kwargs) -> ResNet:
    """Build a ResNet-50 layout ([3, 4, 6, 3] stages) out of CBAMBottleneck blocks.

    Each bottleneck's last BN scale (bn3.weight) is zeroed after construction so every
    residual branch starts as identity. ``kwargs`` go straight to ``ResNet``.
    """
    net = ResNet(CBAMBottleneck, [3, 4, 6, 3], **kwargs)
    for block in net.modules():
        if not isinstance(block, Bottleneck):
            continue
        if hasattr(block, "bn3") and block.bn3.weight is not None:
            nn.init.constant_(block.bn3.weight, 0)
    return net

class FineTuneCBAM(nn.Module):
    """Classifier-only fine-tuned model: CBAM-ResNet50 trunk (through avgpool) + MLP head.

    ``forward`` returns class logits. NOTE(review): no ``supcon_proj`` layer is registered here.
    A checkpoint saved from a model that has one (e.g. a HYBRID run) would have extra
    ``supcon_proj.*`` keys and would not load with ``strict=True``. Confirm how the checkpoints
    consumed by this script are loaded.
    """

    def __init__(self, num_classes: int):
        super().__init__()
        trunk = cbam_resnet50(num_classes=1000)
        # Drop only the final fc; output after avgpool is (B, 2048, 1, 1).
        self.backbone = nn.Sequential(*list(trunk.children())[:-1])
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512, bias=True),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.30),
            nn.Linear(512, num_classes, bias=True),
        )

    def forward(self, x):
        return self.classifier(torch.flatten(self.backbone(x), 1))

# -----------------------------
# Metrics / helpers
# -----------------------------
def macro_roc_auc(y_true: np.ndarray, probs: np.ndarray, n_classes: int) -> float:
    """Macro-averaged one-vs-rest ROC-AUC from integer labels and (N, n_classes) probabilities.

    Labels are one-hot encoded before scoring, so each class column is scored separately and
    the per-class AUCs are averaged.
    """
    labels = np.asarray(y_true)
    onehot = np.zeros((labels.shape[0], n_classes))
    onehot[np.arange(labels.shape[0]), labels] = 1.0
    score = roc_auc_score(onehot, probs, average="macro", multi_class="ovr")
    return float(score)

def ece_from_probs(probs: np.ndarray, y_true: np.ndarray, n_bins: int = 15) -> float:
    """Expected calibration error over ``n_bins`` equal-width confidence bins.

    Confidence is the max class probability. The first bin is closed, [0, 1/n_bins]; the others
    are right-closed, (lo, hi]. ECE = sum over non-empty bins of (bin size / N) * |acc - conf|.
    """
    confidence = probs.max(axis=1)
    correct = (probs.argmax(axis=1) == y_true).astype(np.float64)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    total = len(y_true)

    ece = 0.0
    for b, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        above_lo = (confidence >= lo) if b == 0 else (confidence > lo)
        in_bin = above_lo & (confidence <= hi)
        if not in_bin.any():
            continue
        gap = abs(correct[in_bin].mean() - confidence[in_bin].mean())
        ece += (in_bin.sum() / total) * gap
    return float(ece)

@torch.no_grad()
def collect_logits(model, loader, device):
    """Run ``model`` in eval mode over ``loader`` and return (logits, labels) concatenated on CPU."""
    model.eval()
    logit_batches, label_batches = [], []
    for inputs, labels in loader:
        logit_batches.append(model(inputs.to(device)).detach().cpu())
        label_batches.append(labels.detach().cpu())
    return torch.cat(logit_batches, dim=0), torch.cat(label_batches, dim=0)

def fit_temperature(val_logits: torch.Tensor, val_y: torch.Tensor, device, max_iter: int = 200) -> float:
    """Fit a scalar temperature T by minimizing validation NLL of ``logits / T`` with LBFGS.

    The optimizer works on log T (so T stays positive), starting from T = 1. T is clamped to
    at least 1e-3 both during optimization and in the returned value.
    """
    logits = val_logits.to(device)
    targets = val_y.to(device)

    log_temperature = torch.zeros((), device=device, requires_grad=True)
    lbfgs = torch.optim.LBFGS([log_temperature], lr=0.5, max_iter=max_iter, line_search_fn="strong_wolfe")
    criterion = nn.CrossEntropyLoss()

    def current_temperature():
        return torch.exp(log_temperature).clamp_min(1e-3)

    def closure():
        lbfgs.zero_grad(set_to_none=True)
        loss = criterion(logits / current_temperature(), targets)
        loss.backward()
        return loss

    lbfgs.step(closure)
    return float(current_temperature().detach().cpu().item())

# -----------------------------
# Extract checkpoints from your ablation JSON (rows[])
# -----------------------------
def extract_from_rows(ab_json: dict):
    """Pull ``(display_name, checkpoint_path, row)`` triples from an ablation JSON.

    The display name is the first truthy value of ``pretty_name`` / ``noise_condition``.
    The checkpoint path is the first truthy value of ``checkpoint`` / ``checkpoint_path`` /
    ``ckpt``. Rows that are not dicts, or that lack either value, are skipped.

    Raises:
        KeyError: if ``rows`` is missing, is not a list, or is empty, or if no row
            yields both a name and a checkpoint.
    """
    rows = ab_json.get("rows", None)
    if not isinstance(rows, list) or not rows:
        raise KeyError("JSON must contain a non-empty list at key 'rows'.")

    def _first_truthy(row, *keys):
        for key in keys:
            value = row.get(key)
            if value:
                return value
        return None

    entries = []
    for row in rows:
        if not isinstance(row, dict):
            continue
        name = _first_truthy(row, "pretty_name", "noise_condition")
        path = _first_truthy(row, "checkpoint", "checkpoint_path", "ckpt")
        if name and path:
            entries.append((name, path, row))

    if not entries:
        raise KeyError("No usable entries found in rows[]. Expected each row to have 'pretty_name' and 'checkpoint'.")
    return entries

def main():
    """CLI: temperature-scale every checkpoint listed in an ablation JSON and re-score it.

    For each usable row returned by ``extract_from_rows``:
      1. load the checkpoint into a fresh ``FineTuneCBAM``;
      2. fit a scalar temperature T on the ``val`` split (``fit_temperature``);
      3. compute accuracy, macro one-vs-rest ROC-AUC and ECE on the ``test`` split
         from ``softmax(test_logits / T)``.
    Each output row is a copy of the original ablation row, extended with those
    metrics. All rows are written to ``--out_json``.

    Raises:
        FileNotFoundError: if any listed checkpoint does not exist. All paths are
            checked before evaluation starts, so a missing file late in the list
            no longer discards the results of checkpoints already evaluated.
        ValueError: if the val and test splits have different class orders.
        KeyError: propagated from ``extract_from_rows`` for a malformed JSON.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--data_root", type=str, required=True)
    ap.add_argument("--ablation_json", type=str, required=True)
    ap.add_argument("--out_json", type=str, default="cbam_noise_bestseed_temp_scaled.json")
    ap.add_argument("--batch_size", type=int, default=64)
    ap.add_argument("--num_workers", type=int, default=2)
    ap.add_argument("--ece_bins", type=int, default=15)
    ap.add_argument("--max_iter", type=int, default=200)
    args = ap.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)

    eval_tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
    ])

    val_ds = ImageFolder(os.path.join(args.data_root, "val"), transform=eval_tf)
    test_ds = ImageFolder(os.path.join(args.data_root, "test"), transform=eval_tf)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if val_ds.classes != test_ds.classes:
        raise ValueError("val/test class order mismatch")
    class_names = val_ds.classes
    n_classes = len(class_names)
    print("Classes:", class_names)

    def _make_loader(ds):
        # shuffle=False so logits stay aligned with labels. Each loader is iterated
        # once per checkpoint, so persistent workers avoid re-spawning them.
        return DataLoader(ds, batch_size=args.batch_size, shuffle=False,
                          num_workers=args.num_workers, pin_memory=torch.cuda.is_available(),
                          persistent_workers=(args.num_workers > 0))

    val_loader = _make_loader(val_ds)
    test_loader = _make_loader(test_ds)

    with open(args.ablation_json, "r") as f:
        ab = json.load(f)

    entries = extract_from_rows(ab)

    # Fail fast: validate every checkpoint path before any (expensive) evaluation.
    for pretty_name, ckpt_path, _ in entries:
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"{pretty_name}: checkpoint not found: {ckpt_path}")

    out = {
        "source_ablation_json": args.ablation_json,
        "class_names": class_names,
        "ece_bins": int(args.ece_bins),
        "rows": []
    }

    for pretty_name, ckpt_path, original_row in entries:
        model = FineTuneCBAM(num_classes=n_classes).to(device)
        # Load to CPU: load_state_dict copies into the on-device parameters, so
        # there is no need for a second full copy of the weights on the GPU.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
        # (consider weights_only=True if the installed torch supports it).
        sd = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(sd, strict=True)
        del sd

        # fit T on VAL
        val_logits, val_y = collect_logits(model, val_loader, device)
        T = fit_temperature(val_logits, val_y, device=device, max_iter=args.max_iter)

        # evaluate on TEST with temp-scaled probs
        test_logits, test_y = collect_logits(model, test_loader, device)
        probs = F.softmax((test_logits.to(device) / T), dim=1).detach().cpu().numpy()
        y_true = test_y.numpy()
        y_pred = probs.argmax(axis=1)

        acc = float(accuracy_score(y_true, y_pred))
        auc = float(macro_roc_auc(y_true, probs, n_classes))
        ece = float(ece_from_probs(probs, y_true, n_bins=args.ece_bins))

        row_out = dict(original_row)
        row_out.update({
            "temperature_T": float(T),
            "temp_scaled_test_acc_%": float(acc * 100.0),
            "temp_scaled_macro_roc_auc": float(auc),
            "ece_temp_scaled_%": float(ece * 100.0),
        })
        out["rows"].append(row_out)

        print(f"{pretty_name:<22} | T={T:.3f} | "
              f"Acc={acc*100:.2f}% | AUC={auc:.4f} | ECE(TS)={ece*100:.2f}%")

        # Release this checkpoint's model before the next one is constructed.
        del model

    # Create the parent directory of --out_json if the path includes one.
    out_dir = os.path.dirname(args.out_json)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.out_json, "w") as f:
        json.dump(out, f, indent=2)
    print("Saved:", args.out_json)

if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script
    # (e.g. `!python cbam_temp_scaling_eval_from_ablation_json.py ...`), not on import.
    main()


Overwriting cbam_temp_scaling_eval_from_ablation_json.py


In [17]:
!python cbam_temp_scaling_eval_from_ablation_json.py \
  --data_root /kaggle/input/minida/mini_output1 \
  --ablation_json /kaggle/working/cbam_noise_bestseed_results.json \
  --out_json /kaggle/working/cbam_noise_bestseed_temp_scaled.json \
  --batch_size 64 --num_workers 2 --ece_bins 15

Device: cuda
Classes: ['Alternaria', 'Healthy Leaf', 'straw_mite']
No Noise               | T=0.702 | Acc=97.98% | AUC=0.9986 | ECE(TS)=5.02%
Gaussian Only          | T=0.628 | Acc=95.96% | AUC=0.9948 | ECE(TS)=3.31%
Salt-and-Pepper Only   | T=0.714 | Acc=95.96% | AUC=0.9977 | ECE(TS)=5.97%
Both Noises            | T=0.382 | Acc=93.94% | AUC=0.9920 | ECE(TS)=6.36%
Saved: /kaggle/working/cbam_noise_bestseed_temp_scaled.json
