Swin Transformer

In [1]:
!pip install grad-cam

Collecting grad-cam
  Downloading grad-cam-1.5.5.tar.gz (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m76.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m0:01[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting ttach (from grad-cam)
  Downloading ttach-0.0.3-py3-none-any.whl.metadata (5.2 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.7.1->grad-cam)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.7.1->grad-cam)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.7.1->grad-cam)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collect

SimSiam ResNet PreTrain

In [14]:
%%writefile simsiam_model.py
import torch
import torch.nn as nn
from typing import Optional
import torchvision.models as tv_models


class MLPHead(nn.Module):
    """
    SimSiam projection MLP:
      Linear -> BN -> ReLU -> Linear -> BN(affine=False)
    """
    def __init__(self, in_dim: int = 2048, hidden_dim: int = 2048, out_dim: int = 2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim, bias=False),
            nn.BatchNorm1d(out_dim, affine=False),  # per SimSiam
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class PredictionHead(nn.Module):
    """
    SimSiam prediction MLP:
      Linear -> BN -> ReLU -> Linear
    """
    def __init__(self, in_dim: int = 2048, hidden_dim: int = 512, out_dim: int = 2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),  # bias=True default is fine
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def _build_resnet50_backbone() -> nn.Sequential:
    """
    Create a ResNet-50 backbone up to the global avgpool (drops the FC).
    Handles torchvision API differences around 'weights' vs 'pretrained'.
    """
    try:
        # torchvision >= 0.13
        resnet = tv_models.resnet50(weights=None)
    except TypeError:
        # older torchvision
        resnet = tv_models.resnet50(pretrained=False)
    backbone = nn.Sequential(*list(resnet.children())[:-1])  # up to avgpool
    return backbone


class SimSiam(nn.Module):
    """
    SimSiam with a ResNet-50 backbone, projection head, and prediction head.

    Args:
        fix_backbone_bn: If True, keep backbone BatchNorm layers in eval mode
                         and freeze their params; helpful for small batches.
    """
    def __init__(self, fix_backbone_bn: bool = True):
        super().__init__()
        self.backbone = _build_resnet50_backbone()
        self.projector = MLPHead(2048)
        self.predictor = PredictionHead()
        self.fix_backbone_bn = fix_backbone_bn

        if self.fix_backbone_bn:
            # Freeze BN params and set them to eval once.
            for m in self.backbone.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    m.requires_grad_(False)

    @torch.no_grad()
    def _forward_backbone(self, x: torch.Tensor) -> torch.Tensor:
        x = self.backbone(x)                # (N, 2048, 1, 1)
        x = torch.flatten(x, 1)             # (N, 2048)
        return x

    def forward(self, x1: torch.Tensor, x2: torch.Tensor):
        """
        Returns:
            p1, p2: predictions (gradients flow through)
            z1, z2: stop-grad targets (detached)
        """
        z1 = self.projector(self._forward_backbone(x1))
        z2 = self.projector(self._forward_backbone(x2))
        p1 = self.predictor(z1)
        p2 = self.predictor(z2)
        return p1, p2, z1.detach(), z2.detach()

    def train(self, mode: bool = True):
        """
        Override to keep backbone BN layers in eval mode when requested.
        """
        super().train(mode)
        if self.fix_backbone_bn:
            for m in self.backbone.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
        return self

    def __repr__(self) -> str:
        return (
            f"SimSiam(backbone=ResNet50, "
            f"projector={self.projector}, predictor={self.predictor}, "
            f"fix_backbone_bn={self.fix_backbone_bn})"
        )

Overwriting simsiam_model.py


In [16]:
%%writefile simsiam_pretrain.py
import os
from typing import Tuple
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from PIL import Image

from simsiam_model import SimSiam


# ----------------------------
# Utilities
# ----------------------------
def seed_all(seed: int = 42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    import random, numpy as np
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# ----------------------------
# Dataset
# ----------------------------
class UnlabeledDataset(Dataset):
    """
    Recursively loads images from a root directory and returns two augmented views.
    """
    def __init__(self, root_dir: str, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        exts = ('.png', '.jpg', '.jpeg', '.bmp', '.webp')
        filepaths = []
        for dp, _, fns in os.walk(root_dir):
            for fn in fns:
                if fn.lower().endswith(exts):
                    filepaths.append(os.path.join(dp, fn))
        self.filepaths = sorted(filepaths)
        if len(self.filepaths) == 0:
            raise RuntimeError(f"No images found in {root_dir}")

    def __len__(self):
        return len(self.filepaths)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        img = Image.open(self.filepaths[idx]).convert('RGB')
        if self.transform is None:
            raise RuntimeError("Transform must be provided for SimSiam pretraining.")
        v1 = self.transform(img)
        v2 = self.transform(img)
        return v1, v2


# ----------------------------
# Loss
# ----------------------------
def negative_cosine_similarity(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """
    SimSiam loss: -cosine(p, z.detach())
    """
    p = F.normalize(p, dim=1)
    z = F.normalize(z, dim=1)
    return -(p * z).sum(dim=1).mean()


# ----------------------------
# Training
# ----------------------------
def pretrain(
    root_path: str = "/kaggle/input/minida/mini_output1/pretrain",
    checkpoint_dir: str = "/kaggle/working/simsiam_vanilla_resnet",
    epochs: int = 150,
    batch_size: int = 64,
    fix_backbone_bn: bool = True,   # flip to False to let BN use batch stats
    num_workers: int = 2,
    seed: int = 42,
):
    os.makedirs(checkpoint_dir, exist_ok=True)
    seed_all(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    writer = SummaryWriter(log_dir=os.path.join(checkpoint_dir, "logs"))

    # SimSiam-style augmentations
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),  # odd kernel; 23 works well for 224 crops
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225]),
    ])

    dataset = UnlabeledDataset(root_dir=root_path, transform=train_transform)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        persistent_workers=True if num_workers > 0 else False,
        drop_last=True,
    )

    model = SimSiam(fix_backbone_bn=fix_backbone_bn).to(device)

    # Linear LR scaling with batch size (ImageNet folklore; works fine here)
    base_lr = 0.05 * batch_size / 256
    optimizer = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # AMP for faster training on GPU
    scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))

    checkpoint_path = os.path.join(checkpoint_dir, "simsiam_checkpoint.pth")
    start_epoch = 0
    if os.path.exists(checkpoint_path):
        print("Resuming from checkpoint...")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.backbone.load_state_dict(checkpoint['backbone'])
        model.projector.load_state_dict(checkpoint['projector'])
        model.predictor.load_state_dict(checkpoint['predictor'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        if 'scaler' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"Resumed at epoch {start_epoch}")

    print(f"Starting SimSiam pretraining for {epochs} epochs (from epoch {start_epoch})...")

    for epoch in range(start_epoch, epochs):
        model.train()
        total_loss = 0.0

        for batch_idx, (x1, x2) in enumerate(dataloader):
            x1 = x1.to(device, non_blocking=True)
            x2 = x2.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)

            with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):
                p1, p2, z1, z2 = model(x1, x2)
                loss = 0.5 * (
                    negative_cosine_similarity(p1, z2) +
                    negative_cosine_similarity(p2, z1)
                )

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()

            if batch_idx % 10 == 0:
                print(f"Epoch [{epoch+1}/{epochs}] Batch [{batch_idx}] Loss: {loss.item():.4f}")

        avg_loss = total_loss / len(dataloader)
        writer.add_scalar("Loss/train", avg_loss, epoch)
        scheduler.step()
        print(f"Epoch [{epoch+1}/{epochs}] Average Loss: {avg_loss:.4f}")

        # periodic checkpoint
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch,
                'backbone': model.backbone.state_dict(),
                'projector': model.projector.state_dict(),
                'predictor': model.predictor.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'scaler': scaler.state_dict(),
            }, checkpoint_path)

    # final weights for downstream use (no optimizer/scheduler)
    final_path = os.path.join(checkpoint_dir, "simsiam_pretrained.pth")
    torch.save({
        'backbone': model.backbone.state_dict(),
        'projector': model.projector.state_dict(),
        'predictor': model.predictor.state_dict(),
    }, final_path)
    writer.close()
    print("Pretraining complete! Model saved to", final_path)


if __name__ == "__main__":
    pretrain()

Overwriting simsiam_pretrain.py


In [18]:
!python simsiam_pretrain.py

2025-08-28 02:11:13.291002: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1756347073.311885    1906 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1756347073.318152    1906 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Using device: cuda
  scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))
Resuming from checkpoint...
Resumed at epoch 100
Starting SimSiam pretraining for 150 epochs (from epoch 100)...
  with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):
Epoch [101/150] Batch [0] Loss: -0.2153
Epoch [101/150] Average Loss: -0.2193
Epoch [102/150] Batch [0] Loss: -0.2458
Epoch [102/150] Average Loss: -0.2179
Epoch [103/150

FineTune SimSiam Resnet

In [19]:
%%writefile finetune_simsiam.py
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, List
from tqdm import tqdm
from PIL import Image

import torchvision.models as models
from torchvision import transforms
from torchvision.transforms import RandAugment
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, WeightedRandomSampler

from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
import matplotlib.pyplot as plt


# ----------------------------
# Reproducibility
# ----------------------------
def seed_all(seed: int = 42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# ----------------------------
# Mixup helpers
# ----------------------------
def mixup_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 0.4, device=None):
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1.0
    batch_size = x.size(0)
    index = torch.randperm(batch_size, device=device)
    mixed_x = lam * x + (1.0 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam: float):
    return lam * criterion(pred, y_a) + (1.0 - lam) * criterion(pred, y_b)


# ----------------------------
# Dataloaders
# ----------------------------
def get_loaders(
    train_dir: str,
    val_dir: str,
    test_dir: str,
    batch_size: int = 32,
    num_workers: int = 2,
):
    pin = torch.cuda.is_available()

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        RandAugment(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])
    eval_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])

    train_ds = ImageFolder(train_dir, transform=train_transform)
    val_ds   = ImageFolder(val_dir,   transform=eval_transform)
    test_ds  = ImageFolder(test_dir,  transform=eval_transform)

    # Derive class names from the data to avoid ordering surprises
    class_names = train_ds.classes

    # Balanced sampling
    targets_np = np.array(train_ds.targets)
    classes = np.unique(targets_np)
    class_sample_count = np.array([(targets_np == t).sum() for t in classes], dtype=np.float64)
    # Avoid division by zero in pathological cases
    class_sample_count[class_sample_count == 0] = 1.0
    weights = 1.0 / class_sample_count
    samples_weights = weights[targets_np.astype(int)]
    sampler = WeightedRandomSampler(
        weights=torch.as_tensor(samples_weights, dtype=torch.double),
        num_samples=len(samples_weights),
        replacement=True
    )

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, sampler=sampler,
        num_workers=num_workers, pin_memory=pin
    )
    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin
    )
    test_loader = DataLoader(
        test_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin
    )
    return train_loader, val_loader, test_loader, class_names


# ----------------------------
# Model definition
# ----------------------------
class FineTuneModel(nn.Module):
    """
    Load a ResNet-50 backbone and initialize a small classifier head.
    backbone weights come from SimSiam pretraining ('backbone' key).
    """
    def __init__(self, pretrained_path: str, num_classes: int = 3):
        super().__init__()
        resnet = models.resnet50(pretrained=False)
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])  # up to avgpool
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )
        # Load SimSiam backbone weights (robust to minor key diffs)
        ckpt = torch.load(pretrained_path, map_location="cpu")
        missing, unexpected = self.backbone.load_state_dict(ckpt.get("backbone", ckpt), strict=False)
        if missing or unexpected:
            print(f"[state_dict notice] missing: {missing} | unexpected: {unexpected}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.backbone(x).flatten(1)
        return self.classifier(x)


# ----------------------------
# Train / Eval
# ----------------------------
def train_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    scaler: torch.cuda.amp.GradScaler,
    use_mixup: bool = True,
    mixup_alpha: float = 0.4
) -> Tuple[float, float]:
    model.train()
    total_loss, correct = 0.0, 0

    for imgs, labels in tqdm(loader, desc="Train", leave=False):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):
            if use_mixup:
                imgs, y_a, y_b, lam = mixup_data(imgs, labels, alpha=mixup_alpha, device=device)
                outputs = model(imgs)
                loss = mixup_criterion(criterion, outputs, y_a, y_b, lam)
                preds = outputs.argmax(1)
                correct_batch = (lam * preds.eq(y_a).sum().item() +
                                 (1.0 - lam) * preds.eq(y_b).sum().item())
            else:
                outputs = model(imgs)
                loss = criterion(outputs, labels)
                preds = outputs.argmax(1)
                correct_batch = preds.eq(labels).sum().item()

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item() * imgs.size(0)
        correct += int(correct_batch)

    return total_loss / len(loader.dataset), correct / len(loader.dataset)


@torch.no_grad()
def eval_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion,
    device: torch.device
) -> Tuple[float, float, np.ndarray, np.ndarray, np.ndarray]:
    model.eval()
    total_loss, correct = 0.0, 0
    all_labels: List[int] = []
    all_preds:  List[int] = []
    all_probs:  List[np.ndarray] = []

    for imgs, labels in tqdm(loader, desc="Eval", leave=False):
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        probs = F.softmax(outputs, dim=1)
        loss = criterion(outputs, labels)

        total_loss += loss.item() * imgs.size(0)
        preds = outputs.argmax(1)
        correct += preds.eq(labels).sum().item()

        all_labels.extend(labels.cpu().numpy().tolist())
        all_preds.extend(preds.cpu().numpy().tolist())
        all_probs.extend(probs.cpu().numpy())

    return (total_loss / len(loader.dataset),
            correct / len(loader.dataset),
            np.array(all_labels),
            np.array(all_preds),
            np.array(all_probs))


# ----------------------------
# Early stopping
# ----------------------------
class EarlyStopping:
    def __init__(self, patience: int = 7, verbose: bool = False):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_acc = None
        self.early_stop = False

    def __call__(self, val_acc: float, model: nn.Module, path: str):
        if self.best_acc is None or val_acc > self.best_acc:
            self.best_acc = val_acc
            self.counter = 0
            torch.save(model.state_dict(), path)
            if self.verbose:
                print("Validation accuracy improved, saving model.")
        else:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} / {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True


# ----------------------------
# Reliability diagram
# ----------------------------
def plot_reliability(y_true: np.ndarray, y_prob: np.ndarray, class_names: List[str], n_bins: int = 10):
    from sklearn.calibration import calibration_curve
    plt.figure(figsize=(5, 5))
    for i, name in enumerate(class_names):
        prob_true, prob_pred = calibration_curve((y_true == i).astype(int), y_prob[:, i],
                                                 n_bins=n_bins, strategy='uniform')
        plt.plot(prob_pred, prob_true, marker='o', label=f"{name}")
    plt.plot([0, 1], [0, 1], '--')
    plt.xlabel("Mean Predicted Probability")
    plt.ylabel("Fraction of Positives")
    plt.title("Reliability Diagram")
    plt.legend()
    plt.tight_layout()
    plt.show()


# ----------------------------
# Main
# ----------------------------
def main(
    data_root: str = "/kaggle/input/minida/mini_output1",
    pretrained_path: str = "/kaggle/working/simsiam_vanilla_resnet/simsiam_pretrained.pth",
    epochs: int = 50,
    batch_size: int = 32,
    num_workers: int = 2,
    mixup_alpha: float = 0.4,
    patience: int = 7
):
    seed_all(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    train_dir = os.path.join(data_root, "train")
    val_dir   = os.path.join(data_root, "val")
    test_dir  = os.path.join(data_root, "test")
    if not os.path.exists(pretrained_path):
        raise FileNotFoundError(f"Pretrained weights not found at {pretrained_path}")

    train_loader, val_loader, test_loader, class_names = get_loaders(
        train_dir, val_dir, test_dir, batch_size=batch_size, num_workers=num_workers
    )

    # Build model
    model = FineTuneModel(pretrained_path, num_classes=len(class_names)).to(device)

    # Ensure BN uses batch stats during supervised finetune
    for m in model.backbone.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.train()
            m.requires_grad_(True)

    # Parameter groups: smaller LR for backbone
    backbone_params, classifier_params = [], []
    for n, p in model.named_parameters():
        (classifier_params if "classifier" in n else backbone_params).append(p)

    optimizer = torch.optim.AdamW(
        [
            {"params": backbone_params, "lr": 3e-5},
            {"params": classifier_params, "lr": 1e-4},
        ],
        weight_decay=1e-4
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.05)

    scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))
    early_stopper = EarlyStopping(patience=patience, verbose=True)
    best_model_path = "best_model.pth"

    # Train loop
    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, device, scaler,
            use_mixup=True, mixup_alpha=mixup_alpha
        )
        val_loss, val_acc, _, _, _ = eval_epoch(model, val_loader, criterion, device)
        scheduler.step()

        print(f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")

        early_stopper(val_acc, model, best_model_path)
        if early_stopper.early_stop:
            print("Early stopping triggered.")
            break

    # Load best model
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    model.to(device).eval()

    # Final test eval
    print("\nTest set results:")
    test_loss, test_acc, test_labels, test_preds, test_probs = eval_epoch(model, test_loader, criterion, device)
    print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")
    print(classification_report(test_labels, test_preds, target_names=class_names))
    print("Confusion Matrix:\n", confusion_matrix(test_labels, test_preds))

    # ROC-AUC (macro)
    try:
        test_labels_onehot = np.eye(len(class_names))[test_labels]
        roc_macro = roc_auc_score(test_labels_onehot, test_probs, average='macro', multi_class='ovr')
        print(f"Test ROC-AUC (macro): {roc_macro:.4f}")
    except Exception as e:
        print(f"ROC-AUC calculation failed: {e}")

    # Reliability diagram
    plot_reliability(test_labels, test_probs, class_names, n_bins=10)

    # TTA evaluation
    print("\nTTA Evaluation:")
    tta_transforms = [
        transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224),
                            transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406],
                                                                        [0.229, 0.224, 0.225])]),
        transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.RandomHorizontalFlip(1.0),
                            transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406],
                                                                        [0.229, 0.224, 0.225])]),
        transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.RandomRotation(15),
                            transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406],
                                                                        [0.229, 0.224, 0.225])]),
        transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ColorJitter(0.3, 0.3, 0.3),
                            transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406],
                                                                        [0.229, 0.224, 0.225])]),
        transforms.Compose([transforms.Resize(280), transforms.CenterCrop(224),
                            transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406],
                                                                        [0.229, 0.224, 0.225])]),
        transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.GaussianBlur(3),
                            transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406],
                                                                        [0.229, 0.224, 0.225])]),
    ]

    # Build a raw dataset we can retarget transforms onto
    tta_base_ds = ImageFolder(os.path.join(data_root, "test"))  # no transform here
    y_true = np.array(tta_base_ds.targets)
    y_agg = []

    with torch.no_grad():
        for t in tta_transforms:
            tta_base_ds.transform = t
            loader = DataLoader(tta_base_ds, batch_size=batch_size, shuffle=False,
                                num_workers=num_workers, pin_memory=torch.cuda.is_available())
            probs_chunks = []
            for imgs, _ in loader:
                imgs = imgs.to(device)
                logits = model(imgs)
                probs_chunks.append(F.softmax(logits, dim=1).cpu().numpy())
            y_agg.append(np.concatenate(probs_chunks, axis=0))

    final_probs = np.mean(y_agg, axis=0)
    final_preds = final_probs.argmax(axis=1)
    print(classification_report(y_true, final_preds, target_names=class_names))
    print("Confusion Matrix:\n", confusion_matrix(y_true, final_preds))


if __name__ == "__main__":
    main()

Overwriting finetune_simsiam.py


In [20]:
!python finetune_simsiam.py

Using device: cuda
  scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))

Epoch 1/50
  with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):
Train Loss: 1.1041, Acc: 0.4038 | Val Loss: 1.1329, Acc: 0.3131                 
Validation accuracy improved, saving model.

Epoch 2/50
Train Loss: 1.0602, Acc: 0.4207 | Val Loss: 1.1042, Acc: 0.3131                 
EarlyStopping counter: 1 / 7

Epoch 3/50
Train Loss: 1.0708, Acc: 0.4017 | Val Loss: 1.0249, Acc: 0.3434                 
Validation accuracy improved, saving model.

Epoch 4/50
Train Loss: 1.0383, Acc: 0.4292 | Val Loss: 0.8619, Acc: 0.6566                 
Validation accuracy improved, saving model.

Epoch 5/50
Train Loss: 1.0468, Acc: 0.4482 | Val Loss: 0.9029, Acc: 0.6162                 
EarlyStopping counter: 1 / 7

Epoch 6/50
Train Loss: 1.0258, Acc: 0.4545 | Val Loss: 0.8405, Acc: 0.6667                 
Validation accuracy improved, saving model.

Epoch 7/50
Train Loss: 1.0044, Acc: 0.5159 | Val Loss

Hybrid loss: Supervised Contrastive + CrossEntropy

In [21]:
%%writefile finetune_hybrid.py
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, List
from tqdm import tqdm

import torchvision.models as models
from torchvision import transforms
from torchvision.transforms import RandAugment
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, WeightedRandomSampler

from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score


# ----------------------------
# Reproducibility
# ----------------------------
def seed_all(seed: int = 42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# ----------------------------
# Mixup helpers
# ----------------------------
def mixup_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 0.4, device=None):
    if alpha and alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1.0
    batch_size = x.size(0)
    index = torch.randperm(batch_size, device=device)
    mixed_x = lam * x + (1.0 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam: float):
    return lam * criterion(pred, y_a) + (1.0 - lam) * criterion(pred, y_b)


# ----------------------------
# Dataloaders
# ----------------------------
def get_loaders(
    train_dir: str,
    val_dir: str,
    test_dir: str,
    batch_size: int = 32,
    num_workers: int = 2,
):
    pin = torch.cuda.is_available()

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        RandAugment(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])
    eval_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])

    train_ds = ImageFolder(train_dir, transform=train_transform)
    val_ds   = ImageFolder(val_dir,   transform=eval_transform)
    test_ds  = ImageFolder(test_dir,  transform=eval_transform)

    # Derive class names from data
    class_names = train_ds.classes

    # Balanced sampling
    targets_np = np.array(train_ds.targets, dtype=np.int64)
    classes = np.unique(targets_np)
    class_sample_count = np.array([(targets_np == t).sum() for t in classes], dtype=np.float64)
    class_sample_count[class_sample_count == 0] = 1.0
    weights = 1.0 / class_sample_count
    samples_weights = weights[targets_np]
    sampler = WeightedRandomSampler(
        weights=torch.as_tensor(samples_weights, dtype=torch.double),
        num_samples=len(samples_weights),
        replacement=True
    )

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, sampler=sampler,
        num_workers=num_workers, pin_memory=pin
    )
    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin
    )
    test_loader = DataLoader(
        test_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin
    )
    return train_loader, val_loader, test_loader, class_names, test_ds


# ----------------------------
# Supervised Contrastive Loss (corrected)
# ----------------------------
class SupConLoss(nn.Module):
    def __init__(self, temperature: float = 0.07, eps: float = 1e-8):
        super().__init__()
        self.temperature = temperature
        self.eps = eps

    def forward(self, features: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        features: (B, D) float
        labels:   (B,)   int
        """
        device = features.device
        # Compute in float32 for stability (even under autocast)
        f = F.normalize(features, dim=1).float()
        B = f.size(0)
        labels = labels.contiguous().view(-1, 1)  # (B,1)
        mask = torch.eq(labels, labels.T).float().to(device)  # (B,B)
        mask.fill_diagonal_(0)  # remove self-positives

        logits = (f @ f.T) / self.temperature  # (B,B)
        logits_max, _ = torch.max(logits, dim=1, keepdim=True)
        logits = logits - logits_max.detach()

        # Exclude self from denominator
        exp_logits = torch.exp(logits) * (1 - torch.eye(B, device=device))
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + self.eps)

        pos_count = mask.sum(1)  # (B,)
        mean_log_prob_pos = (mask * log_prob).sum(1) / (pos_count + self.eps)
        loss = -mean_log_prob_pos.mean()
        return loss


# ----------------------------
# Model definition
# ----------------------------
class FineTuneModel(nn.Module):
    """
    ResNet-50 backbone (loaded from SimSiam 'backbone') +
    - feature_layer (D=2048 -> 128) for SupCon
    - classifier head for CE
    """
    def __init__(self, pretrained_path: str, num_classes: int = 3):
        super().__init__()
        resnet = models.resnet50(pretrained=False)
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])  # up to avgpool
        self.feature_layer = nn.Linear(2048, 128)
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

        # Load SimSiam backbone weights (robust to minor key diffs)
        ckpt = torch.load(pretrained_path, map_location="cpu")
        sd = ckpt.get("backbone", ckpt)
        missing, unexpected = self.backbone.load_state_dict(sd, strict=False)
        if missing or unexpected:
            print(f"[state_dict notice] missing: {missing} | unexpected: {unexpected}")

    def forward(self, x: torch.Tensor, return_features: bool = False):
        feats = self.backbone(x).flatten(1)           # (B, 2048)
        logits = self.classifier(feats)               # (B, C)
        proj = F.normalize(self.feature_layer(feats), dim=1)  # (B, 128)
        return (logits, proj) if return_features else logits


# ----------------------------
# Train / Eval
# ----------------------------
def train_epoch(
    model: nn.Module,
    loader: DataLoader,
    ce_loss_fn,
    supcon_loss_fn: SupConLoss,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    scaler: torch.cuda.amp.GradScaler,
    use_mixup: bool = True,
    mixup_alpha: float = 0.4,
    supcon_weight: float = 0.5
) -> Tuple[float, float, float, float]:
    """
    Hybrid loss:
      - SupCon on CLEAN images (features from a clean forward)
      - CE on MIXED images (logits from a second forward)
    """
    model.train()
    total_loss = total_ce = total_sup = 0.0
    correct = 0

    for imgs, labels in tqdm(loader, desc="Train", leave=False):
        imgs = imgs.to(device); labels = labels.to(device)
        optimizer.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):
            # Clean forward for SupCon features
            logits_clean, feats_clean = model(imgs, return_features=True)
            loss_sup = supcon_loss_fn(feats_clean, labels)

            # CE branch: mixup (optional)
            if use_mixup and mixup_alpha > 0.0:
                mixed, y_a, y_b, lam = mixup_data(imgs, labels, alpha=mixup_alpha, device=device)
                logits_mixed = model(mixed)
                loss_ce = mixup_criterion(ce_loss_fn, logits_mixed, y_a, y_b, lam)
                preds = logits_mixed.argmax(1)
                correct += (lam * preds.eq(y_a).sum().item()
                            + (1.0 - lam) * preds.eq(y_b).sum().item())
            else:
                loss_ce = ce_loss_fn(logits_clean, labels)
                preds = logits_clean.argmax(1)
                correct += preds.eq(labels).sum().item()

            loss = (1.0 - supcon_weight) * loss_ce + supcon_weight * loss_sup

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        bs = imgs.size(0)
        total_loss += float(loss.item()) * bs
        total_ce   += float(loss_ce.item()) * bs
        total_sup  += float(loss_sup.item()) * bs

    n = len(loader.dataset)
    return total_loss / n, correct / n, total_ce / n, total_sup / n


@torch.no_grad()
def eval_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion,
    device: torch.device
) -> Tuple[float, float, np.ndarray, np.ndarray, np.ndarray]:
    model.eval()
    total_loss, correct = 0.0, 0
    all_labels: List[int] = []
    all_preds:  List[int] = []
    all_probs:  List[np.ndarray] = []

    for imgs, labels in tqdm(loader, desc="Eval", leave=False):
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        probs = F.softmax(outputs, dim=1)
        loss = criterion(outputs, labels)

        total_loss += float(loss.item()) * imgs.size(0)
        preds = outputs.argmax(1)
        correct += preds.eq(labels).sum().item()

        all_labels.extend(labels.cpu().numpy().tolist())
        all_preds.extend(preds.cpu().numpy().tolist())
        all_probs.extend(probs.cpu().numpy())

    return (total_loss / len(loader.dataset),
            correct / len(loader.dataset),
            np.array(all_labels),
            np.array(all_preds),
            np.array(all_probs))


# ----------------------------
# Early stopping
# ----------------------------
class EarlyStopping:
    def __init__(self, patience: int = 7, verbose: bool = True):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_acc = None
        self.early_stop = False

    def __call__(self, val_acc: float, model: nn.Module, path: str):
        if self.best_acc is None or val_acc > self.best_acc:
            self.best_acc = val_acc
            self.counter = 0
            torch.save(model.state_dict(), path)
            if self.verbose:
                print("Validation accuracy improved, saving model.")
        else:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} / {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True


# ----------------------------
# Grad-CAM (robust, optional)
# ----------------------------
def save_gradcams(model: nn.Module, test_ds: ImageFolder, class_names: List[str], device: torch.device):
    try:
        from pytorch_grad_cam import GradCAM
        from pytorch_grad_cam.utils.image import show_cam_on_image
        from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
    except Exception as e:
        print(f"[grad-cam] Skipping (package not available): {e}")
        return

    # Target the last conv inside layer4
    # backbone = [.., layer4, avgpool]; pick last Bottleneck in layer4
    try:
        last_block = model.backbone[-2][-1]  # layer4[-1]
        target_layer = getattr(last_block, "conv3", last_block)  # conv3 if present
    except Exception:
        target_layer = model.backbone[-2]  # fallback: whole layer4

    cam = GradCAM(model=model, target_layers=[target_layer])
    model.eval()

    from PIL import Image as PILImage
    # Build a raw (untransformed) dataset view to re-transform
    base_ds = ImageFolder(test_ds.root)

    # Use the same eval transform as loaders
    eval_transform = transforms.Compose([
        transforms.Resize(256), transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])

    for i, cls in enumerate(class_names):
        idxs = [j for j, t in enumerate(base_ds.targets) if t == i]
        if not idxs:
            print(f"[grad-cam] No samples for class {cls}")
            continue
        idx = random.choice(idxs)
        img, _ = base_ds[idx]
        img_tensor = eval_transform(PILImage.open(base_ds.samples[idx][0]).convert("RGB")).unsqueeze(0).to(device)
        grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(i)])[0]
        # De-normalize for visualization
        img_np = np.transpose(img_tensor[0].cpu().numpy(), (1, 2, 0))
        img_np = (img_np * np.array([0.229, 0.224, 0.225])) + np.array([0.485, 0.456, 0.406])
        img_np = np.clip(img_np, 0, 1)
        cam_image = show_cam_on_image(img_np, grayscale_cam, use_rgb=True)

        import matplotlib.pyplot as plt
        plt.imshow(cam_image)
        plt.title(f"Grad-CAM: {cls} (hybrid loss)")
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(f"gradcam_{cls}_hybrid_loss.png")
        plt.close()


# ----------------------------
# Main
# ----------------------------
def main(
    data_root: str = "/kaggle/input/minida/mini_output1",
    pretrained_path: str = "/kaggle/working/simsiam_vanilla_resnet/simsiam_pretrained.pth",
    epochs: int = 50,
    batch_size: int = 32,
    num_workers: int = 2,
    mixup_alpha: float = 0.4,
    supcon_weight: float = 0.5
):
    seed_all(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    train_dir = os.path.join(data_root, "train")
    val_dir   = os.path.join(data_root, "val")
    test_dir  = os.path.join(data_root, "test")
    if not os.path.exists(pretrained_path):
        raise FileNotFoundError(f"Pretrained weights not found at {pretrained_path}")

    train_loader, val_loader, test_loader, class_names, test_ds = get_loaders(
        train_dir, val_dir, test_dir, batch_size=batch_size, num_workers=num_workers
    )

    # Build model
    model = FineTuneModel(pretrained_path, num_classes=len(class_names)).to(device)

    # Ensure BN uses batch stats during supervised finetune
    for m in model.backbone.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.train()
            m.requires_grad_(True)

    # Parameter groups: smaller LR for backbone
    backbone_params, head_params = [], []
    for n, p in model.named_parameters():
        (head_params if ("classifier" in n or "feature_layer" in n) else backbone_params).append(p)

    optimizer = torch.optim.AdamW(
        [
            {"params": backbone_params, "lr": 3e-5},
            {"params": head_params,     "lr": 1e-4},
        ],
        weight_decay=1e-4
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    ce_loss_fn = nn.CrossEntropyLoss(label_smoothing=0.05)
    supcon_loss_fn = SupConLoss(temperature=0.07)

    scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))
    early_stopper = EarlyStopping(patience=7, verbose=True)
    best_model_path = "best_model_hybrid_loss.pth"

    # Train loop
    train_losses, train_accs, train_ce_losses, train_sup_losses = [], [], [], []
    val_losses, val_accs = [], []

    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")
        tr_loss, tr_acc, tr_ce, tr_sup = train_epoch(
            model, train_loader, ce_loss_fn, supcon_loss_fn, optimizer, device, scaler,
            use_mixup=True, mixup_alpha=mixup_alpha, supcon_weight=supcon_weight
        )
        va_loss, va_acc, _, _, _ = eval_epoch(model, val_loader, ce_loss_fn, device)
        scheduler.step()

        print(f"Train Loss: {tr_loss:.4f}, Acc: {tr_acc:.4f} | "
              f"Val Loss: {va_loss:.4f}, Acc: {va_acc:.4f}")

        train_losses.append(tr_loss);  train_accs.append(tr_acc)
        train_ce_losses.append(tr_ce); train_sup_losses.append(tr_sup)
        val_losses.append(va_loss);    val_accs.append(va_acc)

        early_stopper(va_acc, model, best_model_path)
        if early_stopper.early_stop:
            print("Early stopping triggered.")
            break

    # Evaluate best model
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    model.to(device).eval()

    print("\nTest set results:")
    te_loss, te_acc, te_labels, te_preds, te_probs = eval_epoch(model, test_loader, ce_loss_fn, device)
    print(f"Test Loss: {te_loss:.4f}, Test Acc: {te_acc:.4f}")
    print(classification_report(te_labels, te_preds, target_names=class_names))
    try:
        te_onehot = np.eye(len(class_names))[te_labels]
        roc_macro = roc_auc_score(te_onehot, te_probs, average='macro', multi_class='ovr')
        print(f"Test ROC-AUC (macro): {roc_macro:.4f}")
    except Exception as e:
        print(f"ROC-AUC calculation failed: {e}")

    # Optional Grad-CAM visualizations
    print("\nRunning Grad-CAMs (if package available)...")
    save_gradcams(model, test_ds, class_names, device)


if __name__ == "__main__":
    main()

Overwriting finetune_hybrid.py


In [22]:
!python finetune_hybrid.py

Using device: cuda
  scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))

Epoch 1/50
  with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):
Train Loss: 2.2731, Acc: 0.4010 | Val Loss: 1.1937, Acc: 0.3131                 
Validation accuracy improved, saving model.

Epoch 2/50
Train Loss: 2.2401, Acc: 0.4201 | Val Loss: 1.0096, Acc: 0.5758                 
Validation accuracy improved, saving model.

Epoch 3/50
Train Loss: 2.2215, Acc: 0.4805 | Val Loss: 0.9161, Acc: 0.5556                 
EarlyStopping counter: 1 / 7

Epoch 4/50
Train Loss: 2.2052, Acc: 0.4573 | Val Loss: 0.8967, Acc: 0.6667                 
Validation accuracy improved, saving model.

Epoch 5/50
Train Loss: 2.1989, Acc: 0.4714 | Val Loss: 0.8267, Acc: 0.6162                 
EarlyStopping counter: 1 / 7

Epoch 6/50
Train Loss: 2.2027, Acc: 0.4704 | Val Loss: 0.8486, Acc: 0.6263                 
EarlyStopping counter: 2 / 7

Epoch 7/50
Train Loss: 2.1861, Acc: 0.5162 | Val Loss: 0.8167, Acc: 