Swin Transformer

In [1]:
!pip install grad-cam

Collecting grad-cam
  Downloading grad-cam-1.5.5.tar.gz (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m76.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m0:01[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting ttach (from grad-cam)
  Downloading ttach-0.0.3-py3-none-any.whl.metadata (5.2 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.7.1->grad-cam)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.7.1->grad-cam)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.7.1->grad-cam)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collect

In [6]:
# === First, in a separate cell if needed ===
# !pip install pytorch-grad-cam timm

import os
import gc
import random
import json
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import datasets, transforms
from sklearn.metrics import (
    classification_report, confusion_matrix, f1_score, roc_auc_score,
    accuracy_score, balanced_accuracy_score, roc_curve
)
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import timm
# Use new torch.amp API to avoid deprecation warnings
from torch.amp import autocast, GradScaler
from timm.data.mixup import Mixup
from timm.loss import SoftTargetCrossEntropy
import datetime
import logging
from sklearn.preprocessing import label_binarize
import cv2
from PIL import Image  # <-- needed for Grad-CAM step
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
# Try to import Swin reshape; if unavailable, define a fallback
try:
    from pytorch_grad_cam.utils.reshape_transforms import swin_reshape_transform as _swin_reshape_transform
    _HAVE_SWN = True
except Exception:
    _HAVE_SWN = False

def _timm_swin_fallback_reshape(tensor):
    """
    Fallback reshape for timm Swin/SwinV2.
    Accepts [B, L, C] and returns [B, C, h, w] by inferring h*w=L.
    If tensor is already [B, C, H, W], returns as-is.
    Raises a clear error if L is not factorizable into a near-square grid.
    """
    if tensor.dim() == 4:
        return tensor
    B, L, C = tensor.shape
    h = int(round(L ** 0.5))
    w = L // h if h > 0 else 0
    if h * w != L:
        raise ValueError(
            f"Grad-CAM reshape fallback cannot form HxW from sequence length L={L}. "
            "Provide a proper reshape_transform for this backbone/IMG_SIZE."
        )
    return tensor.transpose(1, 2).reshape(B, C, h, w)

# Unified reshape_transform to use with GradCAM
reshape_transform = _swin_reshape_transform if _HAVE_SWN else _timm_swin_fallback_reshape

# ---------------------------- Configuration ----------------------------
class Config:
    DATA_PATHS = {
        "train": "/kaggle/input/minida/mini_output1/train",
        "val":   "/kaggle/input/minida/mini_output1/val",
        "test":  "/kaggle/input/minida/mini_output1/test"
    }
    CLASS_NAMES = sorted(os.listdir(DATA_PATHS["train"]))
    NUM_CLASSES = len(CLASS_NAMES)

    MODEL_NAME = "swinv2_small_window16_256"
    IMG_SIZE = 256

    DROP_RATE = 0.2
    DROP_PATH_RATE = 0.2

    USE_MIXUP = True
    MIXUP_ALPHA = 0.3
    CUTMIX_ALPHA = 1.0

    USE_TTA = True

    ACCUM_STEPS = 2
    TRAIN_BATCH_SIZE = 32 // ACCUM_STEPS
    VAL_BATCH_SIZE = 64

    EPOCHS = 40
    LR = 1e-4
    WEIGHT_DECAY = 0.05

    LABEL_SMOOTHING = 0.1  # will be set to 0.0 automatically if USE_MIXUP

    CONTRASTIVE_LOSS_WEIGHT = 0.3

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    NUM_WORKERS = 2
    MIXED_PRECISION = True

    OUTPUT_DIR = "./output"
    MODEL_SAVE = f"./output/best_swinv2.pth"

    EARLY_STOP_PATIENCE = 7
    GRAD_CLIP = 1.0

    LOG_FILE = "training.log"

    # Determinism vs speed
    DETERMINISTIC = True  # set False for speed (enables cudnn.benchmark)

    def __init__(self):
        os.makedirs(self.OUTPUT_DIR, exist_ok=True)
        self._set_seed()
        self._set_timestamp()
        self._setup_logging()

    def _set_seed(self):
        torch.manual_seed(42)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(42)
        np.random.seed(42)
        random.seed(42)
        torch.backends.cudnn.deterministic = self.DETERMINISTIC
        torch.backends.cudnn.benchmark = not self.DETERMINISTIC
        os.environ["PYTHONHASHSEED"] = "42"

    def _set_timestamp(self):
        self.TIMESTAMP = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    def _setup_logging(self):
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        # avoid duplicate handlers across notebook reruns
        if not self.logger.handlers:
            file_handler = logging.FileHandler(os.path.join(self.OUTPUT_DIR, self.LOG_FILE))
            file_handler.setFormatter(formatter)
            self.logger.addHandler(file_handler)
            console_handler = logging.StreamHandler()
            console_handler.setFormatter(formatter)
            self.logger.addHandler(console_handler)

cfg = Config()
VIS_BATCH_SIZE = min(16, cfg.VAL_BATCH_SIZE)

# If mixup is enabled, avoid double-smoothing
if cfg.USE_MIXUP:
    cfg.LABEL_SMOOTHING = 0.0

def config_to_serializable_dict(cfg):
    skip_types = (logging.Logger,)
    out = {}
    for k, v in cfg.__dict__.items():
        if k.startswith("_") or isinstance(v, skip_types) or callable(v):
            continue
        try:
            json.dumps(v)
            out[k] = v
        except Exception:
            continue
    return out

with open(os.path.join(cfg.OUTPUT_DIR, f'run_config_{cfg.TIMESTAMP}.json'), 'w') as f:
    json.dump(config_to_serializable_dict(cfg), f, indent=4)

# --------- Supervised Contrastive Loss ----------
class SupConLoss(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
    def forward(self, features, labels):
        features = nn.functional.normalize(features, dim=1)
        similarity_matrix = torch.div(torch.matmul(features, features.T), self.temperature)
        labels = labels.contiguous().view(-1, 1)
        mask = torch.eq(labels, labels.T).float().to(features.device)
        logits_max, _ = torch.max(similarity_matrix, dim=1, keepdim=True)
        logits = similarity_matrix - logits_max.detach()
        exp_logits = torch.exp(logits) * (1 - torch.eye(labels.shape[0], device=features.device))
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12)
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1).clamp(min=1)
        loss = -mean_log_prob_pos.mean()
        return loss

# --------- Augmentations ----------
class PathologyAugment:
    @staticmethod
    def get_train_transform():
        return transforms.Compose([
            transforms.RandomResizedCrop(cfg.IMG_SIZE, scale=(0.6, 1.0), ratio=(0.8, 1.2)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ColorJitter(0.2, 0.2, 0.2, 0.05),
            transforms.RandomApply([transforms.GaussianBlur(3)], p=0.3),
            transforms.RandomApply([transforms.RandomRotation(15)], p=0.5),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            transforms.RandomErasing(p=0.25, scale=(0.02, 0.2), ratio=(0.3, 3.3))
        ])

    @staticmethod
    def get_test_transform():
        return transforms.Compose([
            transforms.Resize(int(cfg.IMG_SIZE * 1.15)),
            transforms.CenterCrop(cfg.IMG_SIZE),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    @staticmethod
    def get_tta_transforms():
        base = PathologyAugment.get_test_transform()
        hflip = transforms.Compose([
            transforms.Resize(int(cfg.IMG_SIZE * 1.15)),
            transforms.CenterCrop(cfg.IMG_SIZE),
            transforms.RandomHorizontalFlip(p=1.0),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        vflip = transforms.Compose([
            transforms.Resize(int(cfg.IMG_SIZE * 1.15)),
            transforms.CenterCrop(cfg.IMG_SIZE),
            transforms.RandomVerticalFlip(p=1.0),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        return [base, hflip, vflip]

# --------- Metrics Tracker ----------
class MetricsTracker:
    def __init__(self):
        self.reset()
    def reset(self):
        self.losses, self.preds, self.targets, self.probs = [], [], [], []
    def update(self, loss, outputs, targets):
        self.losses.append(loss)
        probs = outputs.float().softmax(1).detach().cpu()
        self.probs.append(probs)
        self.preds.append(probs.argmax(1))
        self.targets.append(targets.detach().cpu())
    def compute(self):
        probs = torch.cat(self.probs).numpy()
        preds = torch.cat(self.preds).numpy()
        targets = torch.cat(self.targets).numpy()
        probs = probs / (probs.sum(axis=1, keepdims=True) + 1e-10)
        try:
            auc_score = roc_auc_score(targets, probs, multi_class='ovo')
        except Exception:
            auc_score = float('nan')
        return {
            "loss": np.mean(self.losses),
            "accuracy": accuracy_score(targets, preds),
            "balanced_accuracy": balanced_accuracy_score(targets, preds),
            "f1_macro": f1_score(targets, preds, average='macro'),
            "auc": auc_score,
            "targets": targets,
            "preds": preds,
            "probs": probs
        }

# --------- Data helpers ----------

def create_weighted_sampler(dataset):
    indices = [label for _, label in dataset.samples]
    class_counts = np.bincount(indices, minlength=len(dataset.classes))
    class_counts[class_counts == 0] = 1  # avoid div-by-zero
    weights_per_class = 1.0 / class_counts
    sample_weights = np.array([weights_per_class[label] for label in indices], dtype=np.float64)
    sample_weights = torch.as_tensor(sample_weights, dtype=torch.double)
    return WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

# --------- Model ----------

def create_model():
    model = timm.create_model(
        cfg.MODEL_NAME, pretrained=True, num_classes=cfg.NUM_CLASSES,
        drop_rate=cfg.DROP_RATE, drop_path_rate=cfg.DROP_PATH_RATE, img_size=cfg.IMG_SIZE
    ).to(cfg.DEVICE)

    # Torch 2.x compile: skip on older GPUs (e.g., Tesla P100, CC 6.0) and fall back on eager safely
    if hasattr(torch, "compile"):
        try:
            if cfg.DEVICE.type == 'cuda':
                major, minor = torch.cuda.get_device_capability()
                if major < 7:
                    cfg.logger.info("Skipping torch.compile: CUDA capability < 7.0; falling back to eager.")
                    return model
            model = torch.compile(model)
        except Exception as e:
            cfg.logger.info(f"torch.compile failed; falling back to eager. Reason: {e}")
            # keep eager model
            pass
    return model

# --------- Eval / Plots ----------

def evaluate(model, loader, criterion=None):
    model.eval()
    tracker = MetricsTracker()
    with torch.inference_mode(), autocast(device_type='cuda' if cfg.DEVICE.type=='cuda' else 'cpu', enabled=cfg.MIXED_PRECISION):
        for inputs, targets in tqdm(loader, desc="Evaluating", leave=False):
            inputs, targets = inputs.to(cfg.DEVICE), targets.to(cfg.DEVICE)
            outputs = model(inputs)
            loss = criterion(outputs, targets).item() if criterion else 0
            tracker.update(loss, outputs, targets)
    return tracker.compute()

def save_confusion_matrix(metrics, phase):
    cm = confusion_matrix(metrics['targets'], metrics['preds'])
    cm_norm = cm.astype('float') / (cm.sum(axis=1)[:, np.newaxis] + 1e-9)
    plt.figure(figsize=(8,7))
    sns.heatmap(cm_norm, annot=True, fmt=".2f", cmap="viridis",
                xticklabels=cfg.CLASS_NAMES, yticklabels=cfg.CLASS_NAMES)
    plt.xlabel("Predicted", fontsize=14)
    plt.ylabel("True", fontsize=14)
    plt.title(f"{phase.capitalize()} Confusion Matrix (Normalized)", fontsize=16)
    plt.tight_layout()
    save_path = f"{cfg.OUTPUT_DIR}/{phase}_confusion_matrix_{cfg.TIMESTAMP}.png"
    plt.savefig(save_path)
    plt.close()
    # also save raw counts as CSV
    np.savetxt(f"{cfg.OUTPUT_DIR}/{phase}_confusion_matrix_raw_{cfg.TIMESTAMP}.csv", cm, fmt="%d", delimiter=",")


def save_classification_report(metrics, phase):
    report = classification_report(
        metrics['targets'], metrics['preds'], target_names=cfg.CLASS_NAMES, digits=4
    )
    save_path = f"{cfg.OUTPUT_DIR}/{phase}_report_{cfg.TIMESTAMP}.txt"
    with open(save_path, "w") as f:
        f.write(f"{phase} Classification Report:\n")
        f.write(report)
        f.write("\nBalanced Accuracy: {:.4f}\n".format(metrics['balanced_accuracy']))
        f.write("AUC: {:.4f}\n".format(metrics.get('auc', float('nan'))))
    print(f"\n{phase.capitalize()} Classification Report:\n{report}")


def plot_curves(history, timestamp, output_dir):
    plt.figure()
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss Curve')
    plt.legend()
    plt.tight_layout()
    plt.savefig(f"{output_dir}/loss_curve_{timestamp}.png")
    plt.close()

    plt.figure()
    plt.plot(history['train_acc'], label='Train Acc')
    plt.plot(history['val_acc'], label='Val Acc')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Accuracy Curve')
    plt.legend()
    plt.tight_layout()
    plt.savefig(f"{output_dir}/acc_curve_{timestamp}.png")
    plt.close()


def plot_roc_auc(targets, probs, phase):
    targets_bin = label_binarize(targets, classes=list(range(cfg.NUM_CLASSES)))
    plt.figure(figsize=(8, 6))
    for i, class_name in enumerate(cfg.CLASS_NAMES):
        try:
            fpr, tpr, _ = roc_curve(targets_bin[:, i], probs[:, i])
            auc_ = roc_auc_score(targets_bin[:, i], probs[:, i])
            plt.plot(fpr, tpr, label=f"{class_name} (AUC={auc_:.2f})")
        except Exception as e:
            print(f"ROC Curve failed for class {class_name}: {e}")
    plt.plot([0, 1], [0, 1], "k--", lw=1)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve - {phase.capitalize()}")
    plt.legend()
    plt.tight_layout()
    plt.savefig(f"{cfg.OUTPUT_DIR}/{phase}_roc_{cfg.TIMESTAMP}.png")
    plt.close()


def tta_distribution_plot(model, dataset):
    tta_transforms = PathologyAugment.get_tta_transforms()
    loader = DataLoader(dataset, batch_size=VIS_BATCH_SIZE, shuffle=False, num_workers=cfg.NUM_WORKERS)
    fig, axs = plt.subplots(len(tta_transforms), 1, figsize=(8, 4*len(tta_transforms)))
    with torch.inference_mode():
        for idx, tta_tf in enumerate(tta_transforms):
            dataset.transform = tta_tf
            all_probs = []
            for inputs, _ in loader:
                inputs = inputs.to(cfg.DEVICE)
                outputs = model(inputs)
                probs = outputs.softmax(dim=1).detach().cpu().numpy()
                all_probs.append(probs)
            all_probs = np.concatenate(all_probs, axis=0)
            mean_probs = all_probs.mean(axis=0)
            axs[idx].bar(cfg.CLASS_NAMES, mean_probs)
            axs[idx].set_title(f"TTA Transform {idx+1}: Mean Class Probs")
            axs[idx].set_ylabel("Probability")
            axs[idx].set_ylim([0, 1])
    plt.tight_layout()
    plt.savefig(f"{cfg.OUTPUT_DIR}/tta_mean_probs_{cfg.TIMESTAMP}.png")
    plt.close()


def plot_tta_variance(model, dataset):
    tta_transforms = PathologyAugment.get_tta_transforms()
    loader = DataLoader(dataset, batch_size=VIS_BATCH_SIZE, shuffle=False, num_workers=cfg.NUM_WORKERS)
    all_probs = []
    with torch.inference_mode():
        for tta_tf in tta_transforms:
            dataset.transform = tta_tf
            probs_run = []
            for inputs, _ in loader:
                inputs = inputs.to(cfg.DEVICE)
                outputs = model(inputs)
                probs = outputs.softmax(dim=1).detach().cpu().numpy()
                probs_run.append(probs)
            all_probs.append(np.concatenate(probs_run, axis=0))
    all_probs = np.stack(all_probs, axis=0)
    max_class_idx = all_probs.mean(axis=0).argmax(axis=1)
    tta_var = []
    for i, idx in enumerate(max_class_idx):
        tta_var.append(np.var(all_probs[:, i, idx]))
    plt.figure(figsize=(8,4))
    plt.hist(tta_var, bins=30)
    plt.title('Variance in Predicted Probability (Most Confident Class) Across TTA')
    plt.xlabel('Variance')
    plt.ylabel('Num Samples')
    plt.tight_layout()
    plt.savefig(f"{cfg.OUTPUT_DIR}/tta_variance_hist_{cfg.TIMESTAMP}.png")
    plt.close()


def show_tta_flip_examples(model, dataset, num_examples=6):
    tta_transforms = PathologyAugment.get_tta_transforms()
    loader = DataLoader(dataset, batch_size=VIS_BATCH_SIZE, shuffle=False, num_workers=cfg.NUM_WORKERS)

    with torch.inference_mode():
        # Base
        dataset.transform = tta_transforms[0]
        base_probs = []
        for inputs, _ in loader:
            inputs = inputs.to(cfg.DEVICE)
            outputs = model(inputs)
            probs = outputs.softmax(dim=1).detach().cpu().numpy()
            base_probs.append(probs)
        base_probs = np.concatenate(base_probs, axis=0)
        base_preds = np.argmax(base_probs, axis=1)

        # All TTA runs
        all_probs = []
        for tta_tf in tta_transforms:
            dataset.transform = tta_tf
            tta_probs = []
            for inputs, _ in loader:
                inputs = inputs.to(cfg.DEVICE)
                outputs = model(inputs)
                probs = outputs.softmax(dim=1).detach().cpu().numpy()
                tta_probs.append(probs)
            all_probs.append(np.concatenate(tta_probs, axis=0))

    mean_probs = np.mean(np.stack(all_probs, axis=0), axis=0)
    tta_preds = np.argmax(mean_probs, axis=1)
    changed = np.where(base_preds != tta_preds)[0]
    if len(changed) == 0:
        print("No TTA flip examples found!")
        return
    sample_idxs = np.random.choice(changed, size=min(num_examples, len(changed)), replace=False)
    plt.figure(figsize=(15, 3 * len(sample_idxs)))
    for i, idx in enumerate(sample_idxs):
        img_path, true_label = dataset.samples[idx]
        img = plt.imread(img_path)
        plt.subplot(len(sample_idxs), 1, i+1)
        plt.imshow(img)
        plt.title(f"True: {cfg.CLASS_NAMES[true_label]}, Base Pred: {cfg.CLASS_NAMES[base_preds[idx]]}, TTA Pred: {cfg.CLASS_NAMES[tta_preds[idx]]}")
        plt.axis('off')
    plt.tight_layout()
    plt.savefig(f"{cfg.OUTPUT_DIR}/tta_flip_examples_{cfg.TIMESTAMP}.png")
    plt.close()


# --------- Grad-CAM (Swin-safe) ----------

def gradcam_for_all_classes(model, dataset, output_dir, timestamp):
    class2idx = {cls: [] for cls in range(cfg.NUM_CLASSES)}
    for idx, (_, label) in enumerate(dataset.samples):
        class2idx[label].append(idx)

    # choose robust target layer for timm SwinV2
    try:
        target_layers = [model.stages[-1].blocks[-1].norm2]
    except Exception:
        # fallback
        target_layers = [getattr(model, 'norm', None) or list(model.modules())[-1]]

    model.eval()
    # Grad-CAM needs gradients; enable grad context (do NOT use inference_mode here)
    with torch.enable_grad():
        for cls in range(cfg.NUM_CLASSES):
            if len(class2idx[cls]) == 0:
                continue
            idx = random.choice(class2idx[cls])
            img_path, _ = dataset.samples[idx]
            img = cv2.imread(img_path)
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img_disp = cv2.resize(img_rgb, (cfg.IMG_SIZE, cfg.IMG_SIZE)) / 255.0
            img_pil = Image.fromarray(img_rgb)

            transform = transforms.Compose([
                transforms.Resize(int(cfg.IMG_SIZE*1.15)),
                transforms.CenterCrop(cfg.IMG_SIZE),
                transforms.ToTensor(),
                transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
            ])
            img_tensor = transform(img_pil).unsqueeze(0).to(cfg.DEVICE)
            img_tensor.requires_grad_(True)
            model.zero_grad(set_to_none=True)

            with autocast(device_type='cuda' if cfg.DEVICE.type=='cuda' else 'cpu', enabled=False):
                cam = GradCAM(model=model, target_layers=target_layers, reshape_transform=reshape_transform)
                grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(cls)])
            del cam
            cam_image = show_cam_on_image(img_disp.astype(np.float32), grayscale_cam[0], use_rgb=True)
            plt.imshow(cam_image)
            plt.title(f'Grad-CAM for class: {cfg.CLASS_NAMES[cls]}')
            plt.axis('off')
            plt.tight_layout()
            fname = f"{output_dir}/gradcam_{cfg.CLASS_NAMES[cls]}_{timestamp}.png"
            plt.savefig(fname, bbox_inches='tight', pad_inches=0)
            plt.close()
            print(f"Saved Grad-CAM for class {cfg.CLASS_NAMES[cls]}: {fname}")


# --------- Train/Eval loops ----------

def train_epoch(model, loader, optimizer, criterion, scaler, mixup_fn, contrastive_loss, epoch, use_contrastive):
    model.train()
    tracker = MetricsTracker()
    optimizer.zero_grad(set_to_none=True)

    for step, (inputs, targets) in enumerate(tqdm(loader, desc=f"Epoch {epoch+1}/{cfg.EPOCHS}", dynamic_ncols=True)):
        inputs = inputs.to(cfg.DEVICE)
        orig_targets = targets.to(cfg.DEVICE)

        if mixup_fn is not None:
            inputs, mixed_targets = mixup_fn(inputs, orig_targets)
        else:
            mixed_targets = orig_targets

        with autocast(device_type='cuda' if cfg.DEVICE.type=='cuda' else 'cpu', enabled=cfg.MIXED_PRECISION):
            outputs = model(inputs)
            ce_loss = criterion(outputs, mixed_targets)

            # features for contrastive (no pooling => pool if needed)
            if use_contrastive:
                features = model.forward_features(inputs)
                if isinstance(features, (tuple, list)):
                    features = features[0]
                if features.dim() > 2:
                    features = features.mean(dim=(2, 3))
                con_loss = contrastive_loss(features, orig_targets)
            else:
                con_loss = torch.tensor(0.0, device=cfg.DEVICE)

            loss = (1 - cfg.CONTRASTIVE_LOSS_WEIGHT) * ce_loss + cfg.CONTRASTIVE_LOSS_WEIGHT * con_loss
            # correct scaling for grad accumulation
            loss = loss / cfg.ACCUM_STEPS

        if scaler is not None:
            scaler.scale(loss).backward()
            if (step + 1) % cfg.ACCUM_STEPS == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.GRAD_CLIP)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
        else:
            loss.backward()
            if (step + 1) % cfg.ACCUM_STEPS == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.GRAD_CLIP)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

        # Track using ORIGINAL targets (hard labels)
        tracker.update((loss.item() * cfg.ACCUM_STEPS), outputs, orig_targets)

    return tracker.compute()


def tta_predict(model, dataset):
    tta_transforms = PathologyAugment.get_tta_transforms()
    loader = DataLoader(dataset, batch_size=VIS_BATCH_SIZE, shuffle=False, num_workers=cfg.NUM_WORKERS)
    all_probs = []
    model.eval()
    with torch.inference_mode():
        for tta_tf in tta_transforms:
            dataset.transform = tta_tf
            probs_run = []
            for inputs, _ in loader:
                inputs = inputs.to(cfg.DEVICE)
                outputs = model(inputs)
                probs = outputs.softmax(dim=1).detach().cpu().numpy()
                probs_run.append(probs)
            all_probs.append(np.concatenate(probs_run, axis=0))
    avg_probs = np.mean(np.stack(all_probs, axis=0), axis=0)
    targets = np.array([label for _, label in dataset.samples])
    preds = np.argmax(avg_probs, axis=1)
    return avg_probs, preds, targets


# === OOM-SAFE MAIN ===

def main():
    cfg.logger.info(f"Starting Swin Transformer training at {cfg.TIMESTAMP}")

    # Datasets
    train_ds = datasets.ImageFolder(cfg.DATA_PATHS['train'], PathologyAugment.get_train_transform())
    val_ds   = datasets.ImageFolder(cfg.DATA_PATHS['val'],   PathologyAugment.get_test_transform())
    test_ds  = datasets.ImageFolder(cfg.DATA_PATHS['test'],  PathologyAugment.get_test_transform())

    # Log dataset sizes for reproducibility
    cfg.logger.info(f"Dataset sizes | train: {len(train_ds)} | val: {len(val_ds)} | test: {len(test_ds)}")

    # Sampler and Loaders
    sampler = create_weighted_sampler(train_ds)
    train_loader = DataLoader(train_ds, batch_size=cfg.TRAIN_BATCH_SIZE, sampler=sampler,
                              num_workers=cfg.NUM_WORKERS, pin_memory=True, drop_last=True)
    val_loader   = DataLoader(val_ds, batch_size=cfg.VAL_BATCH_SIZE, shuffle=False,
                              num_workers=cfg.NUM_WORKERS, pin_memory=True)
    test_loader  = DataLoader(test_ds, batch_size=cfg.VAL_BATCH_SIZE, shuffle=False,
                              num_workers=cfg.NUM_WORKERS, pin_memory=True)

    # Model & Optimizer
    model = create_model()
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.LR, weight_decay=cfg.WEIGHT_DECAY)

    # Losses
    if cfg.USE_MIXUP:
        criterion = SoftTargetCrossEntropy()
        mixup_fn = Mixup(
            mixup_alpha=cfg.MIXUP_ALPHA, cutmix_alpha=cfg.CUTMIX_ALPHA,
            label_smoothing=0.0, num_classes=cfg.NUM_CLASSES
        )
    else:
        criterion = nn.CrossEntropyLoss(label_smoothing=cfg.LABEL_SMOOTHING)
        mixup_fn = None

    val_criterion = nn.CrossEntropyLoss()  # hard labels for val/test
    contrastive_loss = SupConLoss(temperature=0.07)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.EPOCHS)
    scaler = GradScaler(device='cuda', enabled=(cfg.MIXED_PRECISION and cfg.DEVICE.type == 'cuda'))

    # Disable contrastive if mixup is on (consistency)
    use_contrastive = (cfg.CONTRASTIVE_LOSS_WEIGHT > 0) and (not cfg.USE_MIXUP)

    # History & logging helpers
    history = {'train_loss':[], 'val_loss':[], 'train_acc':[], 'val_acc':[]}
    best_auc = -np.inf
    best_epoch = -1
    metrics_csv = os.path.join(cfg.OUTPUT_DIR, f"metrics_{cfg.TIMESTAMP}.csv")
    with open(metrics_csv, 'w') as f:
        f.write("epoch,train_loss,train_acc,val_loss,val_acc,val_auc\n")

    # Train loop
    for epoch in range(cfg.EPOCHS):
        train_metrics = train_epoch(
            model, train_loader, optimizer, criterion, scaler, mixup_fn, contrastive_loss, epoch, use_contrastive
        )
        val_metrics = evaluate(model, val_loader, val_criterion)

        cfg.logger.info(
            f"Epoch {epoch+1} Train | Loss: {train_metrics['loss']:.4f} | Acc: {train_metrics['accuracy']:.4f} | AUC: {train_metrics['auc']:.4f}"
        )
        cfg.logger.info(
            f"Epoch {epoch+1} Val   | Loss: {val_metrics['loss']:.4f} | Acc: {val_metrics['accuracy']:.4f} | AUC: {val_metrics['auc']:.4f}"
        )

        history['train_loss'].append(train_metrics['loss'])
        history['val_loss'].append(val_metrics['loss'])
        history['train_acc'].append(train_metrics['accuracy'])
        history['val_acc'].append(val_metrics['accuracy'])

        with open(metrics_csv, 'a') as f:
            f.write(f"{epoch+1},{train_metrics['loss']:.6f},{train_metrics['accuracy']:.6f},"
                    f"{val_metrics['loss']:.6f},{val_metrics['accuracy']:.6f},{val_metrics['auc']:.6f}\n")

        scheduler.step()

        if val_metrics['auc'] > best_auc:
            best_auc = val_metrics['auc']
            best_epoch = epoch + 1
            torch.save({
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'best_auc': best_auc,
                'config': config_to_serializable_dict(cfg)
            }, cfg.MODEL_SAVE)
            # Also save a pure state_dict for safe loading with weights_only=True
            torch.save(model.state_dict(), cfg.MODEL_SAVE.replace('.pth', '_weights.pth'))
            cfg.logger.info(f"Saved best model with AUC: {best_auc:.4f} at epoch {best_epoch}")
            no_improve = 0
        else:
            no_improve = (no_improve + 1) if 'no_improve' in locals() else 1
            if no_improve >= cfg.EARLY_STOP_PATIENCE:
                cfg.logger.info(f"Early stopping at epoch {epoch+1}")
                break

    plot_curves(history, cfg.TIMESTAMP, cfg.OUTPUT_DIR)
    cfg.logger.info(f"Best Val AUC: {best_auc:.4f} (epoch {best_epoch}) | checkpoint: {cfg.MODEL_SAVE}")

    # Load best weights
    state = torch.load(cfg.MODEL_SAVE, map_location=cfg.DEVICE, weights_only=False)
    model.load_state_dict(state['model'])

    # ----- Test & TTA -----
    if cfg.USE_TTA:
        avg_probs, preds, targets = tta_predict(model, test_ds)
        f1 = f1_score(targets, preds, average='macro')
        try:
            auc_ = roc_auc_score(targets, avg_probs, multi_class='ovo')
        except Exception:
            auc_ = float('nan')
        print(f"\nTest (TTA) F1 Macro: {f1:.4f} | AUC: {auc_:.4f}")
        print(classification_report(targets, preds, target_names=cfg.CLASS_NAMES))
        save_confusion_matrix({'targets': targets, 'preds': preds}, "test_tta")
        save_classification_report({'targets': targets, 'preds': preds, 'balanced_accuracy': balanced_accuracy_score(targets, preds), 'auc': auc_}, "test_tta")
        plot_roc_auc(targets, avg_probs, phase="test_tta")

        subset_size = min(200, len(test_ds))
        subset_ds = torch.utils.data.Subset(test_ds, range(subset_size))
        if torch.cuda.is_available(): torch.cuda.empty_cache()
        gc.collect()
        tta_distribution_plot(model, subset_ds)
        if torch.cuda.is_available(): torch.cuda.empty_cache()
        gc.collect()
        plot_tta_variance(model, subset_ds)
        if torch.cuda.is_available(): torch.cuda.empty_cache()
        gc.collect()
        show_tta_flip_examples(model, subset_ds)
    else:
        test_metrics = evaluate(model, test_loader, val_criterion)
        save_confusion_matrix(test_metrics, "test")
        save_classification_report(test_metrics, "test")
        plot_roc_auc(test_metrics['targets'], test_metrics['probs'], phase="test")
        print(f"\nFinal Test Accuracy: {test_metrics['accuracy']:.4f}")
        print(f"F1 Macro: {test_metrics['f1_macro']:.4f}")
        print(f"AUC: {test_metrics['auc']:.4f}")

    # Cleanup & Grad-CAM visualizations
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    gradcam_for_all_classes(model, test_ds, cfg.OUTPUT_DIR, cfg.TIMESTAMP)

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


if __name__ == "__main__":
    main()

2025-08-27 04:12:09,428 - INFO - Starting Swin Transformer training at 20250827_041209
2025-08-27 04:12:09,728 - INFO - Dataset sizes | train: 473 | val: 99 | test: 99
2025-08-27 04:12:10,805 - INFO - Skipping torch.compile: CUDA capability < 7.0; falling back to eager.
Epoch 1/40: 100%|██████████| 29/29 [00:40<00:00,  1.39s/it]
2025-08-27 04:12:59,714 - INFO - Epoch 1 Train | Loss: 0.6463 | Acc: 0.4957 | AUC: 0.6961
2025-08-27 04:12:59,714 - INFO - Epoch 1 Val   | Loss: 0.1719 | Acc: 1.0000 | AUC: 1.0000
2025-08-27 04:13:00,871 - INFO - Saved best model with AUC: 1.0000 at epoch 1
Epoch 2/40: 100%|██████████| 29/29 [00:41<00:00,  1.44s/it]
2025-08-27 04:13:51,435 - INFO - Epoch 2 Train | Loss: 0.3261 | Acc: 0.6573 | AUC: 0.7866
2025-08-27 04:13:51,436 - INFO - Epoch 2 Val   | Loss: 0.1508 | Acc: 0.9596 | AUC: 0.9973
Epoch 3/40: 100%|██████████| 29/29 [00:38<00:00,  1.33s/it]
2025-08-27 04:14:38,439 - INFO - Epoch 3 Train | Loss: 0.3140 | Acc: 0.6853 | AUC: 0.8274
2025-08-27 04:14:38,4


Test (TTA) F1 Macro: 0.9705 | AUC: 1.0000
              precision    recall  f1-score   support

  Alternaria       1.00      0.92      0.96        37
Healthy Leaf       0.91      1.00      0.95        31
  straw_mite       1.00      1.00      1.00        31

    accuracy                           0.97        99
   macro avg       0.97      0.97      0.97        99
weighted avg       0.97      0.97      0.97        99


Test_tta Classification Report:
              precision    recall  f1-score   support

  Alternaria     1.0000    0.9189    0.9577        37
Healthy Leaf     0.9118    1.0000    0.9538        31
  straw_mite     1.0000    1.0000    1.0000        31

    accuracy                         0.9697        99
   macro avg     0.9706    0.9730    0.9705        99
weighted avg     0.9724    0.9697    0.9698        99

No TTA flip examples found!
Saved Grad-CAM for class Alternaria: ./output/gradcam_Alternaria_20250827_041209.png
Saved Grad-CAM for class Healthy Leaf: ./output/g

In [7]:
import os, random, numpy as np, torch
import torch.nn as nn
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms
from torchvision.datasets import ImageFolder
from timm import create_model
from sklearn.metrics import f1_score, roc_auc_score, classification_report, confusion_matrix, roc_curve
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from tqdm import tqdm

# Modern AMP API
from torch.amp import autocast, GradScaler

# ----------- Reproducibility ----------
def seed_all(seed=42, deterministic=True):
    torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic
    os.environ["PYTHONHASHSEED"] = str(seed)

# ----------- Custom Noise -------------
class AddGaussianNoise(object):
    def __init__(self, mean=0.0, std=0.05): self.mean = float(mean); self.std = float(std)
    def __call__(self, tensor): return tensor + torch.randn_like(tensor) * self.std + self.mean

class AddSaltPepperNoise(object):
    def __init__(self, salt_prob=0.01, pepper_prob=0.01):
        self.salt_prob = float(salt_prob); self.pepper_prob = float(pepper_prob)
    def __call__(self, tensor):
        c, h, w = tensor.shape
        mask = torch.rand((h, w), device=tensor.device)
        salt = (mask < self.salt_prob).float(); pepper = (mask > 1 - self.pepper_prob).float()
        for i in range(c):
            tensor[i] = tensor[i] * (1 - salt - pepper) + salt + 0.0 * pepper
        return tensor

# ----------- Mixup --------------------
def mixup_data(x, y, alpha=0.2):
    lam = np.random.beta(alpha, alpha) if alpha and alpha > 0 else 1.0
    b = x.size(0)
    index = torch.randperm(b, device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, float(lam)

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

# ----------- Supervised Contrastive Loss -------------
class SupConLoss(nn.Module):
    def __init__(self, temperature=0.07): super().__init__(); self.temperature = temperature
    def forward(self, features, labels):
        z = nn.functional.normalize(features, dim=1)
        sim = torch.matmul(z, z.T) / self.temperature
        labels = labels.contiguous().view(-1, 1)
        mask = torch.eq(labels, labels.T).float().to(z.device)
        logits_max, _ = torch.max(sim, dim=1, keepdim=True)
        logits = sim - logits_max.detach()
        exp_logits = torch.exp(logits) * (1 - torch.eye(labels.shape[0], device=z.device))
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12)
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1).clamp(min=1)
        return -mean_log_prob_pos.mean()

# ----------- Model with Progressive Unfreezing --------
class FineTuneSWIN(nn.Module):
    def __init__(self, model_name="swinv2_small_window16_256", num_classes=3, pretrained=True):
        super().__init__()
        self.backbone = create_model(model_name, pretrained=pretrained, num_classes=0)
        in_features = self.backbone.num_features
        self.classifier = nn.Sequential(
            nn.Linear(in_features, 512), nn.ReLU(inplace=True), nn.Dropout(0.3), nn.Linear(512, num_classes)
        )
    def forward(self, x):
        feats = self.backbone(x)
        return self.classifier(feats)
    def forward_features(self, x):
        return self.backbone(x)

# ----------- Data Loader -------------
def get_loaders(batch_size=32, noise_cfg=None, img_size=256, data_root="/kaggle/input/minida/mini_output1"):
    train_dir, val_dir, test_dir = [os.path.join(data_root, x) for x in ["train", "val", "test"]]

    tfms = [
        transforms.RandomResizedCrop(img_size, scale=(0.6, 1.0), ratio=(0.8, 1.2)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    if noise_cfg:
        if noise_cfg.get('use_gaussian'): tfms.append(AddGaussianNoise(std=float(noise_cfg.get('gaussian_std', 0.05))))
        if noise_cfg.get('use_saltpepper'):
            tfms.append(AddSaltPepperNoise(salt_prob=float(noise_cfg.get('salt_prob', 0.01)), pepper_prob=float(noise_cfg.get('pepper_prob', 0.01))))
    tfms.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
    train_transform = transforms.Compose(tfms)

    val_transform = transforms.Compose([
        transforms.Resize(int(img_size * 1.15)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ])

    train_ds = ImageFolder(train_dir, train_transform)
    val_ds   = ImageFolder(val_dir,   val_transform)
    test_ds  = ImageFolder(test_dir,  val_transform)

    counts = np.bincount(np.array(train_ds.targets), minlength=len(train_ds.classes))
    counts[counts == 0] = 1
    weights_per_class = 1.0 / counts
    weights = weights_per_class[np.array(train_ds.targets)]
    sampler = WeightedRandomSampler(torch.as_tensor(weights, dtype=torch.double), num_samples=len(train_ds), replacement=True)

    train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler, num_workers=2, pin_memory=True, drop_last=True)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    return train_loader, val_loader, test_loader, train_ds.classes, train_ds, val_ds, test_ds

# ----------- Temperature Scaling -----------
class _TempScaleModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.T = nn.Parameter(torch.ones(()))
    def forward(self, logits):
        return logits / self.T.clamp_min(1e-3)

def _gather_logits_labels(model, loader, device):
    model.eval(); all_logits=[]; all_labels=[]
    with torch.inference_mode(), autocast(device_type='cuda' if device.type=='cuda' else 'cpu', enabled=(device.type=='cuda')):
        for x, y in loader:
            x = x.to(device); y = y.to(device)
            logits = model(x)
            all_logits.append(logits.detach().float().cpu())
            all_labels.append(y.detach().cpu())
    return torch.cat(all_logits), torch.cat(all_labels)

def fit_temperature(model, val_loader, device):
    logits, labels = _gather_logits_labels(model, val_loader, device)
    Tmod = _TempScaleModule()
    Tmod.to(device)
    # LBFGS tends to work well for 1D; fall back to Adam if needed
    optimizer = torch.optim.LBFGS(Tmod.parameters(), lr=0.1, max_iter=50)
    nll = nn.CrossEntropyLoss()

    logits = logits.to(device)
    labels = labels.to(device)

    def closure():
        optimizer.zero_grad()
        loss = nll(Tmod(logits), labels)
        loss.backward()
        return loss

    try:
        optimizer.step(closure)
    except Exception:
        opt = torch.optim.Adam(Tmod.parameters(), lr=0.01)
        for _ in range(200):
            opt.zero_grad(); loss = nll(Tmod(logits), labels); loss.backward(); opt.step()

    T_value = float(Tmod.T.detach().cpu().item())
    return T_value

# ----------- Training & Eval Loop (Hybrid Loss, Mixup, AMP) -------------
def train_epoch(model, loader, ce_criterion, con_criterion, optimizer, scaler, use_mixup, supcon_weight=0.3, device=torch.device("cpu")):
    use_contrastive = (supcon_weight > 0) and (not use_mixup)
    model.train(); total_loss=0.0; correct=0.0
    for imgs, labels in tqdm(loader, desc='Train', leave=False):
        imgs, labels = imgs.to(device), labels.to(device)
        if use_mixup:
            imgs, y_a, y_b, lam = mixup_data(imgs, labels, alpha=0.2)
        optimizer.zero_grad(set_to_none=True)
        with autocast(device_type='cuda' if device.type=='cuda' else 'cpu', enabled=(device.type=='cuda')):
            outputs = model(imgs)
            ce_loss = mixup_criterion(ce_criterion, outputs, y_a, y_b, lam) if use_mixup else ce_criterion(outputs, labels)
            if use_contrastive:
                feats = model.forward_features(imgs)
                if isinstance(feats, (tuple, list)): feats = feats[0]
                con_loss = con_criterion(feats, labels)
            else:
                con_loss = torch.tensor(0.0, device=device)
            loss = (1 - supcon_weight) * ce_loss + supcon_weight * con_loss
        if device.type == 'cuda':
            scaler.scale(loss).backward(); scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer); scaler.update()
        else:
            loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); optimizer.step()
        total_loss += loss.item() * imgs.size(0)
        preds = outputs.argmax(1)
        if use_mixup:
            correct += (lam * preds.eq(y_a).sum().item() + (1 - lam) * preds.eq(y_b).sum().item())
        else:
            correct += preds.eq(labels).sum().item()
    return total_loss / len(loader.dataset), correct / len(loader.dataset)

def eval_epoch(model, loader, ce_criterion, device=torch.device("cpu"), temperature: float | None = None):
    model.eval(); total_loss=0.0; correct=0.0; all_labels=[]; all_probs=[]
    with torch.inference_mode(), autocast(device_type='cuda' if device.type=='cuda' else 'cpu', enabled=(device.type=='cuda')):
        for imgs, labels in tqdm(loader, desc='Eval', leave=False):
            imgs, labels = imgs.to(device), labels.to(device)
            logits = model(imgs)
            if temperature is not None:
                logits = logits / max(temperature, 1e-3)
            probs = torch.softmax(logits.float(), dim=1)
            loss = ce_criterion(logits, labels)
            total_loss += loss.item() * imgs.size(0)
            correct += (logits.argmax(1) == labels).sum().item()
            all_labels.extend(labels.detach().cpu().numpy())
            all_probs.extend(probs.detach().cpu().numpy())
    acc = correct / len(loader.dataset)
    return total_loss / len(loader.dataset), acc, np.array(all_labels), np.array(all_probs)

# ----------- Test-Time Augmentation (optional) -------------
def tta_predict(model, dataset, transforms_list, batch_size=32, device=torch.device("cpu"), temperature: float | None = None):
    model.eval(); all_probs = []
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    with torch.inference_mode(), autocast(device_type='cuda' if device.type=='cuda' else 'cpu', enabled=(device.type=='cuda')):
        for tfm in transforms_list:
            dataset.transform = tfm
            batch_probs = []
            for imgs, _ in loader:
                imgs = imgs.to(device)
                logits = model(imgs)
                if temperature is not None:
                    logits = logits / max(temperature, 1e-3)
                probs = torch.softmax(logits.float(), dim=1).cpu().numpy()
                batch_probs.append(probs)
            all_probs.append(np.concatenate(batch_probs, axis=0))
    avg_probs = np.mean(np.stack(all_probs, axis=0), axis=0)
    return avg_probs

# ----------- ECE and Reliability Diagram ----------
def compute_ece(labels, probs, n_bins=10):
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    confidences = np.max(probs, axis=1); predictions = np.argmax(probs, axis=1); accuracies = predictions == labels
    ece = 0.0
    for i in range(n_bins):
        mask = (confidences > bins[i]) & (confidences <= bins[i+1])
        if np.any(mask):
            bin_acc = np.mean(accuracies[mask]); bin_conf = np.mean(confidences[mask])
            ece += np.abs(bin_acc - bin_conf) * np.sum(mask) / len(probs)
    return 100 * ece

def plot_reliability_diagram(labels, probs, phase, out_dir=".", n_bins=10, dpi=300):
    confidences = np.max(probs, axis=1); predictions = np.argmax(probs, axis=1); accuracies = predictions == labels
    bins = np.linspace(0.0, 1.0, n_bins + 1); bin_accs, bin_confs = [], []
    for i in range(n_bins):
        mask = (confidences > bins[i]) & (confidences <= bins[i + 1])
        if np.any(mask):
            bin_accs.append(np.mean(accuracies[mask])); bin_confs.append(np.mean(confidences[mask]))
        else:
            bin_accs.append(0.0); bin_confs.append((bins[i] + bins[i+1]) / 2)
    plt.figure(figsize=(6, 6))
    plt.plot([0, 1], [0, 1], 'k--')
    plt.bar(bin_confs, bin_accs, width=0.08, align='center', alpha=0.7, edgecolor='black')
    plt.xlabel("Confidence"); plt.ylabel("Accuracy"); plt.title(f"Reliability Diagram - {phase}")
    plt.tight_layout(); save_path = os.path.join(out_dir, f"{phase}_reliability_diagram.png")
    plt.savefig(save_path, dpi=dpi); plt.close()

# ----------- Confusion, ROC -------------
def plot_confusion_matrix(labels, preds, phase, class_names, out_dir=".", dpi=300):
    cm = confusion_matrix(labels, preds)
    cm_norm = cm.astype('float') / (cm.sum(axis=1)[:, np.newaxis] + 1e-9)
    plt.figure(figsize=(7, 6))
    sns.heatmap(cm_norm, annot=True, fmt=".2f", cmap="viridis", xticklabels=class_names, yticklabels=class_names)
    plt.xlabel("Predicted"); plt.ylabel("True"); plt.title(f"{phase} Confusion Matrix (Normalized)")
    plt.tight_layout(); save_path = os.path.join(out_dir, f"{phase}_confusion_matrix.png")
    plt.savefig(save_path, dpi=dpi); plt.close()

def plot_roc_curve(labels, probs, phase, class_names, out_dir=".", dpi=300):
    labels_onehot = np.eye(len(class_names))[labels]
    plt.figure(figsize=(8, 6))
    for i, cls in enumerate(class_names):
        try:
            fpr, tpr, _ = roc_curve(labels_onehot[:, i], probs[:, i])
            auc_ = roc_auc_score(labels_onehot[:, i], probs[:, i])
            plt.plot(fpr, tpr, label=f"{cls} (AUC={auc_:.2f})")
        except Exception:
            continue
    plt.plot([0, 1], [0, 1], "k--"); plt.xlabel("False Positive Rate"); plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve - {phase}"); plt.legend(); plt.tight_layout(); save_path = os.path.join(out_dir, f"{phase}_roc_curve.png")
    plt.savefig(save_path, dpi=dpi); plt.close()

# ----------- Early Stopping ---------------
class EarlyStopping:
    def __init__(self, patience=10):
        self.patience = int(patience); self.counter = 0; self.best_acc = None; self.best_state = None
    def __call__(self, val_acc, model):
        if (self.best_acc is None) or (val_acc > self.best_acc):
            self.best_acc = float(val_acc)
            self.best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
            self.counter = 0; return False
        self.counter += 1; return self.counter >= self.patience

# ----------- Main Ablation Runner -------------
def run_ablation_experiments(ablation_configs, epochs=50, batch_size=32, patience=10, supcon_weight=0.3,
                             data_root="/kaggle/input/minida/mini_output1", img_size=256, use_tempscale=True):
    results = []
    os.makedirs("plots", exist_ok=True)
    seed_all()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    for cfg in ablation_configs:
        print(f"\n==== Running: {cfg.get('name','exp')} ====")
        train_loader, val_loader, test_loader, class_names, train_ds, val_ds, test_ds = \
            get_loaders(batch_size, noise_cfg=cfg, img_size=img_size, data_root=data_root)
        print(f"Dataset sizes | train: {len(train_ds)} | val: {len(val_ds)} | test: {len(test_ds)}")

        model = FineTuneSWIN(num_classes=len(class_names)).to(device)
        for p in model.backbone.parameters(): p.requires_grad = False
        for p in model.classifier.parameters(): p.requires_grad = True

        ce_criterion  = nn.CrossEntropyLoss(label_smoothing=0.05)
        con_criterion = SupConLoss(temperature=0.07)
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
        scaler = GradScaler(device='cuda', enabled=(device.type=='cuda'))
        early_stopper = EarlyStopping(patience=patience)

        for epoch in range(epochs):
            if epoch == 5:
                for p in model.backbone.parameters(): p.require_grad = True if hasattr(p, 'require_grad') else None
                for p in model.backbone.parameters(): p.requires_grad = True

            train_loss, train_acc = train_epoch(
                model, train_loader, ce_criterion, con_criterion, optimizer, scaler,
                use_mixup=True, supcon_weight=supcon_weight, device=device
            )
            val_loss, val_acc, _, _ = eval_epoch(model, val_loader, ce_criterion, device=device)
            print(f"Epoch {epoch+1:02d}: Train Acc {train_acc:.4f} | Val Acc {val_acc:.4f}")
            if early_stopper(val_acc, model):
                print(f"Early stopping at epoch {epoch+1}")
                break

        model.load_state_dict(early_stopper.best_state)

        # Temperature scaling on val (optional)
        T_value = None
        if use_tempscale:
            try:
                T_value = fit_temperature(model, val_loader, device)
                print(f"Fitted temperature: T = {T_value:.4f}")
            except Exception as e:
                print(f"Temperature scaling failed: {e}")
                T_value = None

        # Test
        test_loss, test_acc, labels, probs = eval_epoch(model, test_loader, ce_criterion, device=device, temperature=T_value)
        preds = np.argmax(probs, axis=1)
        try:
            labels_onehot = np.eye(len(class_names))[labels]
            roc_macro = roc_auc_score(labels_onehot, probs, average='macro', multi_class='ovr')
        except Exception:
            roc_macro = None
        f1_macro = f1_score(labels, preds, average='macro')
        ece = compute_ece(labels, probs)

        print(f"Test Accuracy: {test_acc:.4f}")
        print(f"F1 Macro: {f1_macro:.4f}")
        print(f"ROC-AUC (macro): {roc_macro:.4f}" if roc_macro is not None else "ROC-AUC N/A")
        print(f"ECE: {ece:.2f}%")
        print(classification_report(labels, preds, target_names=class_names))

        phase = cfg.get('name','exp')
        plot_confusion_matrix(labels, preds, phase=phase, class_names=class_names, out_dir="plots")
        plot_roc_curve(labels, probs, phase=phase, class_names=class_names, out_dir="plots")
        plot_reliability_diagram(labels, probs, phase=phase, out_dir="plots")

        row = dict(cfg)
        row.update({"test_acc": test_acc, "f1_macro": f1_macro, "roc_auc_macro": roc_macro, "ece": ece, "T": T_value})
        results.append(row)

        # --- NaN-safe printing/saving block ---
        df = pd.DataFrame(results)
        # Keep metrics numeric; convert config columns to object to avoid NaN float formatting warnings
        for col in ["use_gaussian", "use_saltpepper", "gaussian_std", "salt_prob", "pepper_prob", "T"]:
            if col in df.columns:
                df[col] = df[col].astype("object")
        with pd.option_context("display.float_format", "{:.6f}".format):
            print("\nIntermediate results:\n", df.fillna(""))
        df.to_csv("ablation_results.csv", index=False)
        print("Intermediate results saved to ablation_results.csv")

    print("\nAll ablation runs complete.")
    df = pd.DataFrame(results)
    for col in ["use_gaussian", "use_saltpepper", "gaussian_std", "salt_prob", "pepper_prob", "T"]:
        if col in df.columns:
            df[col] = df[col].astype("object")
    with pd.option_context("display.float_format", "{:.6f}".format):
        print(df.fillna(""))
    df.to_csv("ablation_results.csv", index=False)
    return results

# ----------- Example configs -------------
ablation_configs = [
    {"name": "NoNoise"},
    {"name": "GaussianOnly", "use_gaussian": True, "gaussian_std": 0.05},
    {"name": "SaltPepperOnly", "use_saltpepper": True, "salt_prob": 0.01, "pepper_prob": 0.01},
    {"name": "BothNoises", "use_gaussian": True, "gaussian_std": 0.05, "use_saltpepper": True, "salt_prob": 0.01, "pepper_prob": 0.01},
]

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _ = run_ablation_experiments(ablation_configs, epochs=50, batch_size=32, patience=10, supcon_weight=0.3)

Device: cuda

==== Running: NoNoise ====
Dataset sizes | train: 473 | val: 99 | test: 99


                                                      

Epoch 01: Train Acc 0.5660 | Val Acc 0.7778


                                                      

Epoch 02: Train Acc 0.8098 | Val Acc 0.8990


                                                      

Epoch 03: Train Acc 0.7879 | Val Acc 0.9394


                                                      

Epoch 04: Train Acc 0.8190 | Val Acc 0.9495


                                                      

Epoch 05: Train Acc 0.8151 | Val Acc 0.9394


                                                      

Epoch 06: Train Acc 0.7755 | Val Acc 0.9697


                                                      

Epoch 07: Train Acc 0.8630 | Val Acc 0.9798


                                                      

Epoch 08: Train Acc 0.8404 | Val Acc 0.9899


                                                      

Epoch 09: Train Acc 0.8777 | Val Acc 0.9798


                                                      

Epoch 10: Train Acc 0.8187 | Val Acc 0.9899


                                                      

Epoch 11: Train Acc 0.8741 | Val Acc 0.9697


                                                      

Epoch 12: Train Acc 0.8710 | Val Acc 0.9798


                                                      

Epoch 13: Train Acc 0.8945 | Val Acc 0.9697


                                                      

Epoch 14: Train Acc 0.8876 | Val Acc 0.9798


                                                      

Epoch 15: Train Acc 0.8677 | Val Acc 0.9798


                                                      

Epoch 16: Train Acc 0.9251 | Val Acc 0.9697


                                                      

Epoch 17: Train Acc 0.8602 | Val Acc 0.9697


                                                      

Epoch 18: Train Acc 0.9197 | Val Acc 0.9697
Early stopping at epoch 18


                                                   

Test Accuracy: 1.0000
F1 Macro: 1.0000
ROC-AUC (macro): 1.0000
ECE: 6.29%
              precision    recall  f1-score   support

  Alternaria       1.00      1.00      1.00        37
Healthy Leaf       1.00      1.00      1.00        31
  straw_mite       1.00      1.00      1.00        31

    accuracy                           1.00        99
   macro avg       1.00      1.00      1.00        99
weighted avg       1.00      1.00      1.00        99

Intermediate results saved to ablation_results.csv

==== Running: GaussianOnly ====
Dataset sizes | train: 473 | val: 99 | test: 99


                                                      

Epoch 01: Train Acc 0.5723 | Val Acc 0.7879


                                                      

Epoch 02: Train Acc 0.7400 | Val Acc 0.8485


                                                      

Epoch 03: Train Acc 0.7705 | Val Acc 0.8485


                                                      

Epoch 04: Train Acc 0.7840 | Val Acc 0.8687


                                                      

Epoch 05: Train Acc 0.8374 | Val Acc 0.8485


                                                      

Epoch 06: Train Acc 0.8745 | Val Acc 0.9495


                                                      

Epoch 07: Train Acc 0.8705 | Val Acc 0.9596


                                                      

Epoch 08: Train Acc 0.8707 | Val Acc 0.9394


                                                      

Epoch 09: Train Acc 0.8703 | Val Acc 0.9596


                                                      

Epoch 10: Train Acc 0.8378 | Val Acc 0.9697


                                                      

Epoch 11: Train Acc 0.8634 | Val Acc 0.9899


                                                      

Epoch 12: Train Acc 0.9036 | Val Acc 0.9495


                                                      

Epoch 13: Train Acc 0.8973 | Val Acc 0.9899


                                                      

Epoch 14: Train Acc 0.8840 | Val Acc 0.9899


                                                      

Epoch 15: Train Acc 0.8869 | Val Acc 0.9495


                                                      

Epoch 16: Train Acc 0.8753 | Val Acc 0.9596


                                                      

Epoch 17: Train Acc 0.8944 | Val Acc 0.9899


                                                      

Epoch 18: Train Acc 0.8944 | Val Acc 0.9798


                                                      

Epoch 19: Train Acc 0.8996 | Val Acc 0.9495


                                                      

Epoch 20: Train Acc 0.8721 | Val Acc 0.9495


                                                      

Epoch 21: Train Acc 0.8409 | Val Acc 0.9899
Early stopping at epoch 21


                                                   

Test Accuracy: 1.0000
F1 Macro: 1.0000
ROC-AUC (macro): 1.0000
ECE: 9.77%
              precision    recall  f1-score   support

  Alternaria       1.00      1.00      1.00        37
Healthy Leaf       1.00      1.00      1.00        31
  straw_mite       1.00      1.00      1.00        31

    accuracy                           1.00        99
   macro avg       1.00      1.00      1.00        99
weighted avg       1.00      1.00      1.00        99

Intermediate results saved to ablation_results.csv

==== Running: SaltPepperOnly ====
Dataset sizes | train: 473 | val: 99 | test: 99


                                                      

Epoch 01: Train Acc 0.5999 | Val Acc 0.8990


                                                      

Epoch 02: Train Acc 0.7921 | Val Acc 0.9192


                                                      

Epoch 03: Train Acc 0.7635 | Val Acc 0.8485


                                                      

Epoch 04: Train Acc 0.7956 | Val Acc 0.9293


                                                      

Epoch 05: Train Acc 0.8132 | Val Acc 0.9495


                                                      

Epoch 06: Train Acc 0.7281 | Val Acc 0.9091


                                                      

Epoch 07: Train Acc 0.8716 | Val Acc 0.9495


                                                      

Epoch 08: Train Acc 0.8808 | Val Acc 0.9697


                                                      

Epoch 09: Train Acc 0.8485 | Val Acc 0.9596


                                                      

Epoch 10: Train Acc 0.8868 | Val Acc 0.9798


                                                      

Epoch 11: Train Acc 0.8314 | Val Acc 0.9697


                                                      

Epoch 12: Train Acc 0.8180 | Val Acc 0.9596


                                                      

Epoch 13: Train Acc 0.8302 | Val Acc 0.9697


                                                      

Epoch 14: Train Acc 0.8861 | Val Acc 0.9798


                                                      

Epoch 15: Train Acc 0.8782 | Val Acc 0.9798


                                                      

Epoch 16: Train Acc 0.8901 | Val Acc 0.9798


                                                      

Epoch 17: Train Acc 0.8914 | Val Acc 0.9697


                                                      

Epoch 18: Train Acc 0.8756 | Val Acc 0.9697


                                                      

Epoch 19: Train Acc 0.9091 | Val Acc 0.9798


                                                      

Epoch 20: Train Acc 0.8815 | Val Acc 0.9798
Early stopping at epoch 20


                                                   

Test Accuracy: 0.9899
F1 Macro: 0.9901
ROC-AUC (macro): 1.0000
ECE: 4.45%
              precision    recall  f1-score   support

  Alternaria       1.00      0.97      0.99        37
Healthy Leaf       0.97      1.00      0.98        31
  straw_mite       1.00      1.00      1.00        31

    accuracy                           0.99        99
   macro avg       0.99      0.99      0.99        99
weighted avg       0.99      0.99      0.99        99

Intermediate results saved to ablation_results.csv

==== Running: BothNoises ====
Dataset sizes | train: 473 | val: 99 | test: 99


                                                      

Epoch 01: Train Acc 0.4968 | Val Acc 0.5960


                                                      

Epoch 02: Train Acc 0.7163 | Val Acc 0.6768


                                                      

Epoch 03: Train Acc 0.7636 | Val Acc 0.7475


                                                      

Epoch 04: Train Acc 0.7415 | Val Acc 0.8081


                                                      

Epoch 05: Train Acc 0.8149 | Val Acc 0.8384


                                                      

Epoch 06: Train Acc 0.8243 | Val Acc 0.9192


                                                      

Epoch 07: Train Acc 0.8070 | Val Acc 0.9697


                                                      

Epoch 08: Train Acc 0.8738 | Val Acc 0.9495


                                                      

Epoch 09: Train Acc 0.8362 | Val Acc 0.9192


                                                      

Epoch 10: Train Acc 0.9234 | Val Acc 0.9596


                                                      

Epoch 11: Train Acc 0.8202 | Val Acc 0.9596


                                                      

Epoch 12: Train Acc 0.8946 | Val Acc 0.9596


                                                      

Epoch 13: Train Acc 0.9090 | Val Acc 0.9495


                                                      

Epoch 14: Train Acc 0.8688 | Val Acc 0.9899


                                                      

Epoch 15: Train Acc 0.8992 | Val Acc 0.9596


                                                      

Epoch 16: Train Acc 0.8755 | Val Acc 0.9495


                                                      

Epoch 17: Train Acc 0.8657 | Val Acc 0.9091


                                                      

Epoch 18: Train Acc 0.8520 | Val Acc 0.9596


                                                      

Epoch 19: Train Acc 0.9096 | Val Acc 0.9697


                                                      

Epoch 20: Train Acc 0.8760 | Val Acc 0.9899


                                                      

Epoch 21: Train Acc 0.8925 | Val Acc 0.9596


                                                      

Epoch 22: Train Acc 0.8948 | Val Acc 0.9394


                                                      

Epoch 23: Train Acc 0.8693 | Val Acc 0.9596


                                                      

Epoch 24: Train Acc 0.9141 | Val Acc 0.9596
Early stopping at epoch 24


                                                   

Test Accuracy: 0.9798
F1 Macro: 0.9803
ROC-AUC (macro): 1.0000
ECE: 6.69%
              precision    recall  f1-score   support

  Alternaria       1.00      0.95      0.97        37
Healthy Leaf       0.94      1.00      0.97        31
  straw_mite       1.00      1.00      1.00        31

    accuracy                           0.98        99
   macro avg       0.98      0.98      0.98        99
weighted avg       0.98      0.98      0.98        99

Intermediate results saved to ablation_results.csv

All ablation runs complete.
             name  test_acc  f1_macro  roc_auc_macro       ece use_gaussian  \
0         NoNoise  1.000000  1.000000            1.0  6.290144          NaN   
1    GaussianOnly  1.000000  1.000000            1.0  9.767211         True   
2  SaltPepperOnly  0.989899  0.990143            1.0  4.445658          NaN   
3      BothNoises  0.979798  0.980324            1.0  6.691181         True   

   gaussian_std use_saltpepper  salt_prob  pepper_prob  
0           

  has_large_values = (abs_vals > 1e6).any()
  has_small_values = ((abs_vals < 10 ** (-self.digits)) & (abs_vals > 0)).any()
  has_small_values = ((abs_vals < 10 ** (-self.digits)) & (abs_vals > 0)).any()
