**CBAM SIMSIAM RESNET**

In [2]:
%%writefile cbam_resnet.py
import torch
import torch.nn as nn
from typing import Optional, Callable
from torchvision.models.resnet import ResNet, Bottleneck


class ChannelAttention(nn.Module):
    def __init__(self, in_planes: int, ratio: int = 16):
        super().__init__()
        hidden = max(in_planes // ratio, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_planes, hidden, kernel_size=1, bias=False)
        self.relu1 = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(hidden, in_planes, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size: int = 7):
        super().__init__()
        padding = kernel_size // 2
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x_cat = torch.cat([avg_out, max_out], dim=1)
        out = self.conv(x_cat)
        return self.sigmoid(out)


class CBAMBottleneck(Bottleneck):
    """
    ResNet Bottleneck with CBAM after the third BN and before the residual add.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        planes_out = self.conv3.out_channels
        self.ca = ChannelAttention(planes_out)
        self.sa = SpatialAttention()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x

        out = self.conv1(x); out = self.bn1(out); out = self.relu(out)
        out = self.conv2(out); out = self.bn2(out); out = self.relu(out)
        out = self.conv3(out); out = self.bn3(out)

        out = self.ca(out) * out
        out = self.sa(out) * out

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


def cbam_resnet50(*, norm_layer: Optional[Callable] = None, **kwargs) -> ResNet:
    """
    Return a ResNet-50 that uses CBAMBottleneck blocks.
    kwargs forwarded to torchvision.models.resnet.ResNet (e.g., num_classes).
    """
    model = ResNet(CBAMBottleneck, [3, 4, 6, 3], norm_layer=norm_layer, **kwargs)
    # Zero-init last BN in each residual branch (improves training stability).
    for m in model.modules():
        if isinstance(m, Bottleneck):
            nn.init.constant_(m.bn3.weight, 0)
    return model

Overwriting cbam_resnet.py


In [6]:
%%writefile simsiam_cbam_pretrain.py
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
from typing import Iterable, Tuple

from cbam_resnet import cbam_resnet50


# ----------------------------
# Reproducibility
# ----------------------------
def seed_all(seed: int = 42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_all(42)


# ----------------------------
# Optim helper (param-wise WD)
# ----------------------------
def exclude_from_wd(named_params: Iterable[Tuple[str, torch.nn.Parameter]], wd: float = 1e-4):
    """Create two param groups: with and without weight decay (no wd for BN/bias/1D)."""
    wd_params, no_wd_params = [], []
    for n, p in named_params:
        if not p.requires_grad:
            continue
        if p.ndim == 1 or n.endswith(".bias") or "bn" in n.lower():
            no_wd_params.append(p)
        else:
            wd_params.append(p)
    return [{"params": wd_params, "weight_decay": wd},
            {"params": no_wd_params, "weight_decay": 0.0}]


# ----------------------------
# SimSiam heads
# ----------------------------
class MLPHead(nn.Module):
    def __init__(self, in_dim=2048, hidden_dim=2048, out_dim=2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim, bias=False),
            nn.BatchNorm1d(out_dim, affine=False),
        )

    def forward(self, x):
        return self.net(x)


class PredictionHead(nn.Module):
    def __init__(self, in_dim=2048, hidden_dim=512, out_dim=2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        return self.net(x)


class SimSiam(nn.Module):
    """
    SimSiam with CBAM-ResNet50 backbone.
    """
    def __init__(self, fix_backbone_bn: bool = True):
        super().__init__()
        resnet = cbam_resnet50(num_classes=1000)
        # children: conv1, bn1, relu, maxpool, layer1..4, avgpool, fc
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])  # up to avgpool
        self.projector = MLPHead(2048)
        self.predictor = PredictionHead()
        self.fix_backbone_bn = fix_backbone_bn

        # Freeze BN params and set them eval; keep handles to re-freeze in .train()
        self._frozen_bn_modules = []
        if self.fix_backbone_bn:
            for m in self.backbone.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    m.requires_grad_(False)
                    self._frozen_bn_modules.append(m)

    def _forward_backbone(self, x):
        x = self.backbone(x)           # (B, 2048, 1, 1)
        x = torch.flatten(x, 1)        # (B, 2048)
        return x

    def forward(self, x1, x2):
        z1 = self.projector(self._forward_backbone(x1))
        z2 = self.projector(self._forward_backbone(x2))
        p1 = self.predictor(z1)
        p2 = self.predictor(z2)
        # stop-grad on targets
        return p1, p2, z1.detach(), z2.detach()

    def train(self, mode: bool = True):
        super().train(mode)
        if self.fix_backbone_bn:
            for m in self._frozen_bn_modules:
                m.eval()  # keep BN frozen regardless of mode
        return self


# ----------------------------
# Dataset (recursive)
# ----------------------------
class UnlabeledDataset(Dataset):
    def __init__(self, root_dir: str, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        exts = (".png", ".jpg", ".jpeg", ".bmp", ".webp")
        filepaths = []
        for dp, _, fns in os.walk(root_dir):
            for fn in fns:
                if fn.lower().endswith(exts):
                    filepaths.append(os.path.join(dp, fn))
        self.filepaths = sorted(filepaths)
        if len(self.filepaths) == 0:
            raise RuntimeError(f"No images found under {root_dir}")

    def __len__(self):
        return len(self.filepaths)

    def __getitem__(self, idx: int):
        img = Image.open(self.filepaths[idx]).convert("RGB")
        if self.transform is None:
            raise RuntimeError("Transform must be provided for SimSiam pretraining.")
        v1 = self.transform(img)
        v2 = self.transform(img)
        return v1, v2


# ----------------------------
# Loss
# ----------------------------
def negative_cosine_similarity(p, z):
    p = F.normalize(p, dim=1)
    z = F.normalize(z, dim=1)
    return -(p * z).sum(dim=1).mean()


# ----------------------------
# Training
# ----------------------------
def pretrain(
    root_path: str = "/kaggle/input/minida/mini_output1/pretrain",
    checkpoint_dir: str = "/kaggle/working/simsiam_cbam",
    epochs: int = 200,
    batch_size: int = 64,
    num_workers: int = 2,
    accumulation_steps: int = 1,
    fix_backbone_bn: bool = True,
    seed: int = 42,
):
    os.makedirs(checkpoint_dir, exist_ok=True)
    seed_all(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    writer = SummaryWriter(log_dir=os.path.join(checkpoint_dir, "logs_cbam"))

    # SimSiam-style strong augs (GaussianBlur is important)
    crop_size = 224
    blur_kernel = max(int(0.1 * crop_size) // 2 * 2 + 1, 3)  # odd, >=3
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(crop_size, scale=(0.2, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.GaussianBlur(kernel_size=blur_kernel, sigma=(0.1, 2.0)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])

    dataset = UnlabeledDataset(root_dir=root_path, transform=train_transform)
    pin = torch.cuda.is_available()
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin,
        persistent_workers=True if num_workers > 0 else False,
        drop_last=True,
    )

    model = SimSiam(fix_backbone_bn=fix_backbone_bn).to(device)

    # --- safe torch.compile (skip on older GPUs like P100 CC 6.0) ---
    use_compile = False
    if hasattr(torch, "compile") and torch.cuda.is_available():
        try:
            major, _ = torch.cuda.get_device_capability()
            if major >= 7:
                use_compile = True
            else:
                print(f"Skipping torch.compile: compute capability {major}.x < 7.0")
        except Exception as e:
            print(f"Could not query device capability, skipping compile: {e}")

    if use_compile:
        try:
            model = torch.compile(model)
        except Exception as e:
            print(f"torch.compile failed, falling back to eager: {e}")

    # Linear LR scaling with batch size (clamped)
    global_bs = batch_size  # adjust if using DDP
    base_lr = 0.05 * max(min(global_bs, 1024), 64) / 256.0

    # Param-wise weight decay
    param_groups = exclude_from_wd(model.named_parameters(), wd=1e-4)
    optimizer = torch.optim.SGD(param_groups, lr=base_lr, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # New AMP API
    scaler = torch.amp.GradScaler('cuda', enabled=(device.type == "cuda"))

    ckpt_path = os.path.join(checkpoint_dir, "simsiam_cbam_checkpoint.pth")
    start_epoch = 0
    if os.path.exists(ckpt_path):
        print("Resuming from checkpoint...")
        ckpt = torch.load(ckpt_path, map_location=device)
        target = model._orig_mod if hasattr(model, "_orig_mod") else model
        target.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["scheduler"])
        if "scaler" in ckpt:
            scaler.load_state_dict(ckpt["scaler"])
        start_epoch = int(ckpt.get("epoch", -1)) + 1
        print(f"Resumed at epoch {start_epoch}")

    print(f"Starting SimSiam+CBAM pretraining for {epochs} epochs (from {start_epoch})...")

    for epoch in range(start_epoch, epochs):
        target = model._orig_mod if hasattr(model, "_orig_mod") else model
        target.train(True)

        total_loss = 0.0
        optimizer.zero_grad(set_to_none=True)
        micro = 0

        for step, (x1, x2) in enumerate(dataloader, start=1):
            x1 = x1.to(device, non_blocking=True)
            x2 = x2.to(device, non_blocking=True)

            with torch.amp.autocast('cuda', enabled=(device.type == "cuda")):
                p1, p2, z1, z2 = model(x1, x2)
                loss_full = 0.5 * (negative_cosine_similarity(p1, z2) +
                                   negative_cosine_similarity(p2, z1))
                loss = loss_full / max(accumulation_steps, 1)

            scaler.scale(loss).backward()
            micro += 1

            if micro == accumulation_steps:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
                micro = 0

            total_loss += float(loss_full.item())

        # Finalize residue grads if loop ended mid-accumulation
        if micro > 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        avg_loss = total_loss / len(dataloader)
        writer.add_scalar("Loss/train", avg_loss, epoch)
        scheduler.step()

        if (epoch + 1) % 10 == 0:
            ckpt_target = model._orig_mod if hasattr(model, "_orig_mod") else model
            torch.save({
                "epoch": epoch,
                "model": ckpt_target.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "scaler": scaler.state_dict(),
            }, ckpt_path)

        print(f"Epoch [{epoch + 1}/{epochs}] Avg Loss: {avg_loss:.4f}")

    final_path = os.path.join(checkpoint_dir, "simsiam_cbam_pretrained_final.pth")
    ckpt_target = model._orig_mod if hasattr(model, "_orig_mod") else model
    torch.save({
        "backbone": ckpt_target.backbone.state_dict(),
        "projector": ckpt_target.projector.state_dict(),
        "predictor": ckpt_target.predictor.state_dict(),
    }, final_path)
    writer.close()
    print(f"Pretraining complete! Model saved to {final_path}")


if __name__ == "__main__":
    pretrain()

Overwriting simsiam_cbam_pretrain.py


In [1]:
!python simsiam_cbam_pretrain.py

2025-08-29 03:23:40.088698: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1756437820.289227      74 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1756437820.347267      74 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Using device: cuda
Skipping torch.compile: compute capability 6.x < 7.0
Resuming from checkpoint...
Resumed at epoch 190
Starting SimSiam+CBAM pretraining for 200 epochs (from 190)...
Epoch [191/200] Avg Loss: -0.4395
Epoch [192/200] Avg Loss: -0.4541
Epoch [193/200] Avg Loss: -0.4229
Epoch [194/200] Avg Loss: -0.4344
Epoch [195/200] Avg Loss: -0.4363
Epoch [196/200] Avg Loss: -0.4456
Epoch [197/200] Avg Loss: -0.4710
Epoch [198/200]

**CBAM-ResNet fine-tuning**

In [13]:
# finetune_cbam.py
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, List, Iterable
from tqdm import tqdm

from torchvision import transforms
from torchvision.transforms import RandAugment
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, WeightedRandomSampler

from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, ConfusionMatrixDisplay, roc_curve, auc
import matplotlib.pyplot as plt

from cbam_resnet import cbam_resnet50  # ensure cbam_resnet.py is available


# ----------------------------
# Reproducibility
# ----------------------------
def seed_all(seed: int = 42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# ----------------------------
# Mixup helpers
# ----------------------------
def mixup_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 0.3, device=None):
    lam = np.random.beta(alpha, alpha) if alpha and alpha > 0 else 1.0
    idx = torch.randperm(x.size(0), device=device)
    mixed_x = lam * x + (1.0 - lam) * x[idx]
    y_a, y_b = y, y[idx]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam: float):
    return lam * criterion(pred, y_a) + (1.0 - lam) * criterion(pred, y_b)


# ----------------------------
# Optim helper (param-wise WD)
# ----------------------------
def exclude_from_wd(named_params: Iterable, wd: float = 1e-4):
    wd_params, no_wd_params = [], []
    for n, p in named_params:
        if not p.requires_grad:
            continue
        if p.ndim == 1 or n.endswith(".bias") or "bn" in n.lower():
            no_wd_params.append(p)
        else:
            wd_params.append(p)
    return [{"params": wd_params, "weight_decay": wd},
            {"params": no_wd_params, "weight_decay": 0.0}]


# ----------------------------
# Data
# ----------------------------
def make_train_transform(img_size: int) -> transforms.Compose:
    return transforms.Compose([
        transforms.RandomResizedCrop(img_size),
        transforms.RandomHorizontalFlip(),
        RandAugment(),
        transforms.ToTensor(),
        transforms.RandomErasing(p=0.2, scale=(0.02, 0.15), ratio=(0.3, 3.3), value='random'),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])


def make_eval_transform(img_size: int) -> transforms.Compose:
    resize_size = int(round(img_size * 1.14))  # common 224->256 style scaling
    return transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])


def get_loaders(
    train_dir: str,
    val_dir: str,
    test_dir: str,
    batch_size: int = 32,
    num_workers: int = 2,
    img_size: int = 224,
):
    pin = torch.cuda.is_available()

    train_transform = make_train_transform(img_size)
    eval_transform  = make_eval_transform(img_size)

    train_ds = ImageFolder(train_dir, transform=train_transform)
    val_ds   = ImageFolder(val_dir,   transform=eval_transform)
    test_ds  = ImageFolder(test_dir,  transform=eval_transform)

    class_names = train_ds.classes

    # Class-balanced sampling
    targets_np = np.array(train_ds.targets, dtype=np.int64)
    classes = np.unique(targets_np)
    class_counts = np.array([(targets_np == c).sum() for c in classes], dtype=np.float64)
    class_counts[class_counts == 0] = 1.0
    weights_per_class = 1.0 / class_counts
    sample_weights = weights_per_class[targets_np]
    sampler = WeightedRandomSampler(
        weights=torch.as_tensor(sample_weights, dtype=torch.double),
        num_samples=len(sample_weights),
        replacement=True
    )

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, sampler=sampler,
        num_workers=num_workers, pin_memory=pin, drop_last=True  # stabilizes BN + MixUp
    )
    val_loader   = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin
    )
    test_loader  = DataLoader(
        test_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin
    )
    return train_loader, val_loader, test_loader, test_ds, class_names


# ----------------------------
# Model
# ----------------------------
class FineTuneCBAM(nn.Module):
    """
    Loads a CBAM-ResNet50 backbone and attaches a small classifier head.
    Expects SimSiam pretrain checkpoint with key 'backbone' (or raw state_dict).
    """
    def __init__(self, pretrained_path: str, num_classes: int):
        super().__init__()
        backbone = cbam_resnet50(num_classes=1000)
        self.backbone = nn.Sequential(*list(backbone.children())[:-1])  # up to avgpool
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )
        ckpt = torch.load(pretrained_path, map_location="cpu")
        sd = ckpt.get("backbone", ckpt)
        loaded = self.backbone.load_state_dict(sd, strict=False)
        missing, unexpected = loaded.missing_keys, loaded.unexpected_keys
        if missing and any(k.startswith(("layer1", "layer2", "layer3", "layer4")) for k in missing):
            raise RuntimeError(f"Backbone weights incompatible, missing critical keys: {missing[:10]}")
        if unexpected:
            print(f"[state_dict notice] unexpected: {unexpected[:10]}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.backbone(x).flatten(1)
        return self.classifier(x)


# ----------------------------
# Scheduler (warmup + cosine)
# ----------------------------
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR

def get_scheduler(optimizer, total_epochs: int, warmup_epochs: int = 5):
    warmup = LinearLR(optimizer, start_factor=0.2, total_iters=warmup_epochs)
    cosine = CosineAnnealingLR(optimizer, T_max=max(total_epochs - warmup_epochs, 1))
    return SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs])


# ----------------------------
# Train / Eval
# ----------------------------
def train_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    scaler: torch.amp.GradScaler,
    use_mixup: bool = True,
    mixup_alpha: float = 0.3
) -> Tuple[float, float]:
    model.train()
    total_loss, correct = 0.0, 0

    for imgs, labels in tqdm(loader, desc="Train", leave=False):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast('cuda', enabled=(device.type == "cuda")):
            if use_mixup and mixup_alpha > 0:
                imgs, y_a, y_b, lam = mixup_data(imgs, labels, alpha=mixup_alpha, device=device)
                logits = model(imgs)
                loss = mixup_criterion(criterion, logits, y_a, y_b, lam)
                preds = logits.argmax(1)
                correct += int(lam * preds.eq(y_a).sum().item()
                               + (1.0 - lam) * preds.eq(y_b).sum().item())
            else:
                logits = model(imgs)
                loss = criterion(logits, labels)
                preds = logits.argmax(1)
                correct += int(preds.eq(labels).sum().item())

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += float(loss.item()) * imgs.size(0)

    n = len(loader.dataset)
    return total_loss / n, correct / n


@torch.no_grad()
def eval_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion,
    device: torch.device
) -> Tuple[float, float, np.ndarray, np.ndarray]:
    model.eval()
    total_loss, correct = 0.0, 0
    all_labels: List[int] = []
    all_probs:  List[np.ndarray] = []

    for imgs, labels in tqdm(loader, desc="Eval", leave=False):
        imgs, labels = imgs.to(device), labels.to(device)
        logits = model(imgs)
        probs = F.softmax(logits, dim=1)
        loss = criterion(logits, labels)

        total_loss += float(loss.item()) * imgs.size(0)
        preds = logits.argmax(1)
        correct += int(preds.eq(labels).sum().item())

        all_labels.extend(labels.cpu().numpy().tolist())
        all_probs.extend(probs.cpu().numpy())

    n = len(loader.dataset)
    return total_loss / n, correct / n, np.array(all_labels), np.array(all_probs)


# ----------------------------
# Early stopping
# ----------------------------
class EarlyStopping:
    def __init__(self, patience: int = 7, verbose: bool = True):
        self.patience = patience
        self.counter = 0
        self.best_acc = None
        self.best_state = None
        self.best_epoch = -1
        self.verbose = verbose

    def __call__(self, val_acc: float, model: nn.Module, epoch: int):
        if self.best_acc is None or val_acc > self.best_acc:
            self.best_acc = val_acc
            self.counter = 0
            self.best_epoch = epoch
            self.best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
            if self.verbose:
                print("Validation accuracy improved, saving best state.")
            return False
        else:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} / {self.patience}")
            return self.counter >= self.patience


# ----------------------------
# Plots
# ----------------------------
def plot_confusion_matrix(y_true, y_pred, class_names, save_path=None):
    cm = confusion_matrix(y_true, y_pred, normalize='true')
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    fig, ax = plt.subplots(figsize=(6, 6))
    disp.plot(ax=ax, cmap='Blues', colorbar=False)
    plt.title('Normalized Confusion Matrix (CBAM)')
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, bbox_inches='tight')
    plt.close()


def plot_roc_per_class(y_true, y_score, class_names, save_path=None):
    plt.figure(figsize=(8, 6))
    for i, name in enumerate(class_names):
        if np.sum(y_true == i) == 0:
            continue
        try:
            fpr, tpr, _ = roc_curve((y_true == i).astype(int), y_score[:, i])
            roc_auc = auc(fpr, tpr)
            plt.plot(fpr, tpr, label=f'{name} (AUC = {roc_auc:.2f})')
        except Exception as e:
            print(f"ROC error for class {name}: {e}")
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Per-class ROC Curves (CBAM)')
    plt.legend()
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, bbox_inches='tight')
    plt.close()


def plot_reliability(y_true, y_prob, class_names, n_bins=10, save_path="reliability_diagram_cbam.png"):
    from sklearn.calibration import calibration_curve
    y_prob = np.clip(y_prob, 1e-6, 1-1e-6)
    plt.figure(figsize=(5, 5))
    for i, name in enumerate(class_names):
        try:
            prob_true, prob_pred = calibration_curve((y_true == i).astype(int),
                                                     y_prob[:, i],
                                                     n_bins=n_bins,
                                                     strategy='uniform')
            plt.plot(prob_pred, prob_true, marker='o', label=name)
        except Exception as e:
            print(f"Reliability curve failed for {name}: {e}")
    plt.plot([0, 1], [0, 1], '--', color='gray')
    plt.xlabel("Mean Predicted Probability")
    plt.ylabel("Fraction of Positives")
    plt.title("Reliability Diagram (CBAM)")
    plt.legend()
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight')
    plt.close()


def plot_loss_acc_curves(train_losses, val_losses, train_accs, val_accs, save_path="loss_acc_curves_cbam.png"):
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.title("Loss Curves (CBAM)")
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label='Train Acc')
    plt.plot(val_accs, label='Val Acc')
    plt.xlabel("Epoch"); plt.ylabel("Accuracy"); plt.title("Accuracy Curves (CBAM)")
    plt.legend()
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight')
    plt.close()


# ----------------------------
# TTA
# ----------------------------
def tta_eval(
    model: nn.Module,
    test_ds: ImageFolder,
    batch_size: int,
    class_names: List[str],
    device: torch.device,
    img_size: int
):
    resize_base = int(round(img_size * 1.14))
    resize_up   = int(round(img_size * 1.25))

    tta_transforms = [
        transforms.Compose([transforms.Resize(resize_base), transforms.CenterCrop(img_size),
                            transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406],
                                                                        [0.229, 0.224, 0.225])]),
        transforms.Compose([transforms.Resize(resize_base), transforms.CenterCrop(img_size),
                            transforms.RandomHorizontalFlip(1.0),
                            transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406],
                                                                        [0.229, 0.224, 0.225])]),
        transforms.Compose([transforms.Resize(resize_up), transforms.CenterCrop(img_size),
                            transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406],
                                                                        [0.229, 0.224, 0.225])]),
        transforms.Compose([transforms.Resize(resize_base), transforms.CenterCrop(img_size),
                            transforms.GaussianBlur(3),
                            transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406],
                                                                        [0.229, 0.224, 0.225])]),
    ]
    pin = torch.cuda.is_available()
    model.eval()
    probs_all = []

    for t in tta_transforms:
        test_ds.transform = t
        loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False,
                            num_workers=2, pin_memory=pin)
        probs_list = []
        with torch.no_grad():
            for imgs, _ in loader:
                imgs = imgs.to(device)
                logits = model(imgs)
                probs_list.append(F.softmax(logits, dim=1).cpu().numpy())
        probs_all.append(np.concatenate(probs_list, axis=0))

    tta_probs = np.array(probs_all)
    mean_probs = np.mean(tta_probs, axis=0)
    final_preds = mean_probs.argmax(axis=1)
    return mean_probs, final_preds


# ----------------------------
# Main
# ----------------------------
def main(
    data_root: str = "/kaggle/input/minida/mini_output1",
    pretrained_path: str = "/kaggle/working/simsiam_cbam/simsiam_cbam_pretrained_final.pth",
    epochs: int = 60,
    batch_size: int = 24,
    num_workers: int = 2,
    patience: int = 10,
    mixup_alpha: float = 0.3,
    head_only_warmup_epochs: int = 5,
    img_size: int = 256,
    save_dir: str = "."  # NEW: where to write model files
):
    seed_all(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device} | IMG_SIZE={img_size}")
    print(f"Loading pretrained backbone from: {pretrained_path}")

    # NEW: ensure output directory exists
    os.makedirs(save_dir, exist_ok=True)

    train_dir = os.path.join(data_root, "train")
    val_dir   = os.path.join(data_root, "val")
    test_dir  = os.path.join(data_root, "test")

    train_loader, val_loader, test_loader, test_ds, class_names = get_loaders(
        train_dir, val_dir, test_dir, batch_size=batch_size, num_workers=num_workers, img_size=img_size
    )

    # Build model
    model = FineTuneCBAM(pretrained_path, num_classes=len(class_names)).to(device)

    # Ensure BN uses batch stats during supervised finetune
    for m in model.backbone.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.train()
            m.requires_grad_(True)

    # Parameter groups: smaller LR for backbone, param-wise WD
    backbone_named = [(n, p) for n, p in model.named_parameters() if "classifier" not in n]
    head_named     = [(n, p) for n, p in model.named_parameters() if "classifier" in n]
    groups_backbone = exclude_from_wd(backbone_named, wd=1e-4)
    groups_head     = exclude_from_wd(head_named, wd=1e-4)
    for g in groups_backbone: g["lr"] = 3e-5
    for g in groups_head:     g["lr"] = 1e-4

    optimizer = torch.optim.AdamW(groups_backbone + groups_head, betas=(0.9, 0.999))
    scheduler = get_scheduler(optimizer, total_epochs=epochs, warmup_epochs=5)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.02)

    scaler = torch.amp.GradScaler('cuda', enabled=(device.type == "cuda"))

    # Optional: head-only warmup (freeze backbone for first few epochs)
    for p in model.backbone.parameters():
        p.requires_grad = False

    train_losses, train_accs, val_losses, val_accs = [], [], [], []

    # --- safe torch.compile (skip on CC < 7.0) ---
    use_compile = False
    if hasattr(torch, "compile") and torch.cuda.is_available():
        try:
            major, _ = torch.cuda.get_device_capability()
            if major >= 7:
                use_compile = True
            else:
                print(f"Skipping torch.compile: compute capability {major}.x < 7.0")
        except Exception as e:
            print(f"Could not query device capability, skipping compile: {e}")

    if use_compile:
        try:
            model = torch.compile(model)
        except Exception as e:
            print(f"torch.compile failed, falling back to eager: {e}")

    for epoch in range(epochs):
        # Unfreeze backbone after warmup
        if epoch == head_only_warmup_epochs:
            target = model._orig_mod if hasattr(model, "_orig_mod") else model
            for p in target.backbone.parameters():
                p.requires_grad = True
            print(f"Unfroze backbone at epoch {epoch}")

        # MixUp annealing
        decay_start = int(epochs * 2 / 3)
        if epoch < decay_start:
            cur_alpha = mixup_alpha
        else:
            remaining = max(epochs - epoch - 1, 0)
            span = max(epochs - decay_start, 1)
            cur_alpha = mixup_alpha * (remaining / span)
        use_mixup_flag = cur_alpha > 1e-6

        print(f"\nEpoch {epoch + 1}/{epochs} | mixup_alpha={cur_alpha:.4f}")
        tr_loss, tr_acc = train_epoch(
            model, train_loader, criterion, optimizer, device, scaler,
            use_mixup=use_mixup_flag, mixup_alpha=cur_alpha
        )
        va_loss, va_acc, va_labels, va_probs = eval_epoch(model, val_loader, criterion, device)
        scheduler.step()

        print(f"Train Loss: {tr_loss:.4f}, Acc: {tr_acc:.4f} | "
              f"Val Loss: {va_loss:.4f}, Acc: {va_acc:.4f}")

        train_losses.append(tr_loss); train_accs.append(tr_acc)
        val_losses.append(va_loss);   val_accs.append(va_acc)

        # Early stopping
        target = model._orig_mod if hasattr(model, "_orig_mod") else model
        if not hasattr(main, "_early"):
            main._early = EarlyStopping(patience=patience, verbose=True)
        if main._early(va_acc, target, epoch):
            print("Early stopping triggered.")
            break

    # === Load best state and SAVE to disk (NEW) ===
    target = model._orig_mod if hasattr(model, "_orig_mod") else model
    target.load_state_dict(main._early.best_state)
    target.to(device).eval()
    print(f"Best epoch (val acc): {main._early.best_epoch + 1}")

    # NEW: save-only-weights file (for inference)
    best_weights_path = os.path.join(save_dir, "finetuned_cbam_best.pth")
    torch.save(target.state_dict(), best_weights_path)
    print(f"Best model weights saved to {best_weights_path}")

    # NEW: full checkpoint (resume training capability)
    best_ckpt_path = os.path.join(save_dir, "finetuned_cbam_checkpoint.pth")
    torch.save({
        "epoch": main._early.best_epoch,
        "model_state_dict": target.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "best_val_acc": main._early.best_acc,
        "class_names": class_names,
        "img_size": img_size,
    }, best_ckpt_path)
    print(f"Best model checkpoint saved to {best_ckpt_path}")

    # Curves
    plot_loss_acc_curves(train_losses, val_losses, train_accs, val_accs)

    # Final Test Evaluation
    print("\nTest set results (CBAM):")
    te_loss, te_acc, te_labels, te_probs = eval_epoch(target, test_loader, criterion, device)
    te_preds = te_probs.argmax(axis=1)
    print(f"Test Loss: {te_loss:.4f}, Test Acc: {te_acc:.4f}")
    print(classification_report(te_labels, te_preds, target_names=class_names))
    plot_confusion_matrix(te_labels, te_preds, class_names, save_path="cbam_norm_confmat.png")

    try:
        te_onehot = np.eye(len(class_names))[te_labels]
        roc_macro = roc_auc_score(te_onehot, te_probs, average='macro', multi_class='ovr')
        print(f"Test ROC-AUC (macro): {roc_macro:.4f}")
        plot_roc_per_class(te_labels, te_probs, class_names, save_path="cbam_perclass_roc.png")
    except Exception as e:
        print(f"ROC-AUC calculation failed: {e}")

    plot_reliability(te_labels, te_probs, class_names)

    # TTA Evaluation
    print("\nTest-Time Augmentation (TTA) Evaluation (CBAM):")
    mean_probs, final_preds = tta_eval(target, test_ds, batch_size, class_names, device, img_size)
    print(classification_report(test_ds.targets, final_preds, target_names=class_names))
    plot_confusion_matrix(test_ds.targets, final_preds, class_names, save_path="cbam_tta_confmat.png")


if __name__ == "__main__":
    main()

Using device: cuda | IMG_SIZE=256
Loading pretrained backbone from: /kaggle/working/simsiam_cbam/simsiam_cbam_pretrained_final.pth
Skipping torch.compile: compute capability 6.x < 7.0

Epoch 1/60 | mixup_alpha=0.3000


                                                      

Train Loss: 1.0524, Acc: 0.3700 | Val Loss: 1.0759, Acc: 0.6263
Validation accuracy improved, saving best state.

Epoch 2/60 | mixup_alpha=0.3000


                                                      

Train Loss: 1.0348, Acc: 0.3932 | Val Loss: 1.0009, Acc: 0.4747
EarlyStopping counter: 1 / 10

Epoch 3/60 | mixup_alpha=0.3000


                                                      

Train Loss: 1.0196, Acc: 0.3679 | Val Loss: 0.9624, Acc: 0.6162
EarlyStopping counter: 2 / 10

Epoch 4/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.9877, Acc: 0.4440 | Val Loss: 0.8948, Acc: 0.6566
Validation accuracy improved, saving best state.

Epoch 5/60 | mixup_alpha=0.3000




Train Loss: 0.9629, Acc: 0.4440 | Val Loss: 0.8513, Acc: 0.7475
Validation accuracy improved, saving best state.
Unfroze backbone at epoch 5

Epoch 6/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.9339, Acc: 0.5159 | Val Loss: 0.7557, Acc: 0.7475
EarlyStopping counter: 1 / 10

Epoch 7/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.9267, Acc: 0.4820 | Val Loss: 0.7080, Acc: 0.8081
Validation accuracy improved, saving best state.

Epoch 8/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.8988, Acc: 0.4968 | Val Loss: 0.6711, Acc: 0.7071
EarlyStopping counter: 1 / 10

Epoch 9/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.8959, Acc: 0.5032 | Val Loss: 0.6340, Acc: 0.7879
EarlyStopping counter: 2 / 10

Epoch 10/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.8764, Acc: 0.5243 | Val Loss: 0.5914, Acc: 0.8283
Validation accuracy improved, saving best state.

Epoch 11/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.7812, Acc: 0.6321 | Val Loss: 0.4867, Acc: 0.8788
Validation accuracy improved, saving best state.

Epoch 12/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.8259, Acc: 0.5814 | Val Loss: 0.4969, Acc: 0.8283
EarlyStopping counter: 1 / 10

Epoch 13/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.8112, Acc: 0.5708 | Val Loss: 0.4747, Acc: 0.8788
EarlyStopping counter: 2 / 10

Epoch 14/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.8214, Acc: 0.5793 | Val Loss: 0.4739, Acc: 0.8687
EarlyStopping counter: 3 / 10

Epoch 15/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.7960, Acc: 0.5835 | Val Loss: 0.5007, Acc: 0.8182
EarlyStopping counter: 4 / 10

Epoch 16/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.7463, Acc: 0.6321 | Val Loss: 0.4040, Acc: 0.8889
Validation accuracy improved, saving best state.

Epoch 17/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.7655, Acc: 0.5920 | Val Loss: 0.4335, Acc: 0.8384
EarlyStopping counter: 1 / 10

Epoch 18/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.7939, Acc: 0.5856 | Val Loss: 0.4085, Acc: 0.8687
EarlyStopping counter: 2 / 10

Epoch 19/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.7677, Acc: 0.5983 | Val Loss: 0.4048, Acc: 0.8687
EarlyStopping counter: 3 / 10

Epoch 20/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.8363, Acc: 0.5603 | Val Loss: 0.4403, Acc: 0.8586
EarlyStopping counter: 4 / 10

Epoch 21/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.7535, Acc: 0.6025 | Val Loss: 0.4203, Acc: 0.8990
Validation accuracy improved, saving best state.

Epoch 22/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.7789, Acc: 0.6152 | Val Loss: 0.3978, Acc: 0.8889
EarlyStopping counter: 1 / 10

Epoch 23/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.7862, Acc: 0.6216 | Val Loss: 0.4498, Acc: 0.8081
EarlyStopping counter: 2 / 10

Epoch 24/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.7886, Acc: 0.5814 | Val Loss: 0.3914, Acc: 0.8586
EarlyStopping counter: 3 / 10

Epoch 25/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.6754, Acc: 0.6744 | Val Loss: 0.3707, Acc: 0.8990
EarlyStopping counter: 4 / 10

Epoch 26/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.7498, Acc: 0.6195 | Val Loss: 0.3908, Acc: 0.8384
EarlyStopping counter: 5 / 10

Epoch 27/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.7536, Acc: 0.6321 | Val Loss: 0.3866, Acc: 0.8485
EarlyStopping counter: 6 / 10

Epoch 28/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.7702, Acc: 0.5856 | Val Loss: 0.3676, Acc: 0.8687
EarlyStopping counter: 7 / 10

Epoch 29/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.7065, Acc: 0.6575 | Val Loss: 0.3862, Acc: 0.8485
EarlyStopping counter: 8 / 10

Epoch 30/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.7560, Acc: 0.6195 | Val Loss: 0.3580, Acc: 0.9091
Validation accuracy improved, saving best state.

Epoch 31/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.7527, Acc: 0.5983 | Val Loss: 0.3863, Acc: 0.8586
EarlyStopping counter: 1 / 10

Epoch 32/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.7455, Acc: 0.6216 | Val Loss: 0.3674, Acc: 0.8788
EarlyStopping counter: 2 / 10

Epoch 33/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.7306, Acc: 0.6237 | Val Loss: 0.3599, Acc: 0.8889
EarlyStopping counter: 3 / 10

Epoch 34/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.7185, Acc: 0.6575 | Val Loss: 0.3537, Acc: 0.8990
EarlyStopping counter: 4 / 10

Epoch 35/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.7114, Acc: 0.6364 | Val Loss: 0.3881, Acc: 0.8384
EarlyStopping counter: 5 / 10

Epoch 36/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.8030, Acc: 0.5708 | Val Loss: 0.3780, Acc: 0.8990
EarlyStopping counter: 6 / 10

Epoch 37/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.6783, Acc: 0.6638 | Val Loss: 0.3497, Acc: 0.8990
EarlyStopping counter: 7 / 10

Epoch 38/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.7579, Acc: 0.6131 | Val Loss: 0.3784, Acc: 0.8586
EarlyStopping counter: 8 / 10

Epoch 39/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.6906, Acc: 0.6512 | Val Loss: 0.3494, Acc: 0.8990
EarlyStopping counter: 9 / 10

Epoch 40/60 | mixup_alpha=0.3000


                                                      

Train Loss: 0.7295, Acc: 0.6406 | Val Loss: 0.3486, Acc: 0.8788
EarlyStopping counter: 10 / 10
Early stopping triggered.
Best epoch (val acc): 30
Best model weights saved to ./finetuned_cbam_best.pth
Best model checkpoint saved to ./finetuned_cbam_checkpoint.pth

Test set results (CBAM):


                                                   

Test Loss: 0.2868, Test Acc: 0.9596
              precision    recall  f1-score   support

  Alternaria       1.00      0.89      0.94        37
Healthy Leaf       0.89      1.00      0.94        31
  straw_mite       1.00      1.00      1.00        31

    accuracy                           0.96        99
   macro avg       0.96      0.96      0.96        99
weighted avg       0.96      0.96      0.96        99

Test ROC-AUC (macro): 0.9991

Test-Time Augmentation (TTA) Evaluation (CBAM):
              precision    recall  f1-score   support

  Alternaria       1.00      0.95      0.97        37
Healthy Leaf       0.94      1.00      0.97        31
  straw_mite       1.00      1.00      1.00        31

    accuracy                           0.98        99
   macro avg       0.98      0.98      0.98        99
weighted avg       0.98      0.98      0.98        99

