<a href="https://colab.research.google.com/github/EnesAgirman/ArUco_Markers/blob/main/Camelyon17_Wilds.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive at /content/drive; the checkpoint cell near the end of this
# notebook writes model weights underneath that mount point.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# ─── Cell 1: Install dependencies ─────────────────────────────────────────────
!pip install --quiet wilds torch torchvision


In [3]:
# ─── Cell 2: Imports & Dataset ────────────────────────────────────────────────
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from wilds import get_dataset
from wilds.common.data_loaders import get_train_loader, get_eval_loader
from torchvision import transforms
from torchvision import models
from torch.optim import SGD, Adam, RMSprop
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import os
from collections import Counter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 2048

import torch.multiprocessing as mp
# NOTE: the multiprocessing start method is intentionally left at the platform
# default instead of being forced to "spawn". With "spawn", every DataLoader
# worker is a fresh interpreter that has to re-import the dataset class, and a
# class defined in a notebook cell (IndexedDataset below) cannot be imported
# from the notebook's __main__, so the workers of rs_loader would fail to start.
# `mp` stays imported for any later cell that references it.

# Download (or reuse) into WILDS’s cache automatically
dataset = get_dataset(
    dataset="camelyon17",
    download=True
)

# Standard normalization (ImageNet channel statistics) to help training stability
transform = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std =[0.229, 0.224, 0.225]),
])

# Create train/val/test splits
train_ds = dataset.get_subset("train", transform=transform)
val_ds   = dataset.get_subset("val",   transform=transform)
test_ds  = dataset.get_subset("test",  transform=transform)

# Wrap in standard loaders (shuffled for baseline, used by other trainers)
train_loader = get_train_loader("standard", train_ds, batch_size=batch_size)
val_loader   = get_eval_loader ("standard", val_ds,   batch_size=batch_size)
test_loader  = get_eval_loader ("standard", test_ds,  batch_size=batch_size)

# empirical uniform distribution over the N samples (RS uses its own P_hat)
N = len(train_ds)
P_hat = torch.full((N,), 1.0/N, device=device)

class IndexedDataset(Dataset):
    """Dataset adapter that swaps the third element of each sample for its index.

    The wrapped ``subset`` must yield 3-tuples ``(x, y, meta)``; this wrapper
    yields ``(x, y, idx)`` where ``idx`` is the key used for the lookup.
    """

    def __init__(self, subset):
        # The underlying dataset is kept by reference, not copied.
        self.subset = subset

    def __len__(self):
        """Number of samples in the wrapped dataset."""
        return len(self.subset)

    def __getitem__(self, idx):
        """Return ``(x, y, idx)`` for the sample stored at ``idx``."""
        features, label, _metadata = self.subset[idx]
        return features, label, idx

# RS needs a loader *without* shuffling so idx→position stays consistent.
# Each batch is (x, y, idx); train_rs uses idx to read/write its per-sample
# tensors (the loss vector F and the weight vector w).
# NOTE(review): num_workers is always >= 1 here, so samples are loaded in worker
# processes and IndexedDataset must be transferable to them under the active
# multiprocessing start method — confirm workers start correctly.
rs_loader = DataLoader(
    IndexedDataset(train_ds),
    batch_size=512,                  # tuning knob: try 1024; 512 if CPU is slow, 2048 if CPU is strong
    shuffle=False,
    num_workers=min(12, os.cpu_count() or 4),
    pin_memory=True,
    persistent_workers=True,          # keeps workers alive between epochs
    prefetch_factor=4,                # prefetches per worker
)


In [4]:
# ─── Cell 3 ────────────────────────────────────────────────

# Simplified model
"""
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = models.resnet18(weights=None, num_classes=2)
    def forward(self, x):
        return self.backbone(x)
"""

class SimpleCNN(nn.Module):
    """Small CNN: two 3x3 conv + 2x2 max-pool stages, then a two-layer MLP head.

    The first linear layer is sized for 32 channels at 24x24, i.e. for 96x96
    inputs that have been halved twice by pooling.

    Args:
        num_classes: width of the final logits layer.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # Attribute names define the state_dict keys; creation order defines
        # the order in which random initialisation consumes the RNG.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)   # keeps 96x96
        self.pool = nn.MaxPool2d(2, 2)                            # halves H and W
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)  # keeps 48x48
        self.fc1 = nn.Linear(32 * 24 * 24, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a (B, 3, 96, 96) batch to (B, num_classes) logits."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))      # 96 → 48 → 24
        flat = x.view(x.shape[0], -1)           # (B, 32*24*24)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

# Projection onto simplex
def proj_simplex(v):
    """Euclidean projection of a 1-D tensor onto the probability simplex.

    Sort-and-threshold algorithm from https://arxiv.org/pdf/1101.6081.pdf.

    Args:
        v: 1-D tensor.

    Returns:
        Tensor shaped like ``v`` with non-negative entries summing to 1. If no
        sorted entry exceeds its threshold, the uniform vector is returned.

    Raises:
        ValueError: if ``v`` is not 1-D.
    """
    if v.dim() != 1:
        raise ValueError("proj_simplex expects a 1-D tensor")
    n = v.numel()
    sorted_desc = torch.sort(v, descending=True).values
    ranks = torch.arange(1, n + 1, device=v.device, dtype=v.dtype)
    # thresholds[k] = (sum of the k+1 largest entries - 1) / (k+1)
    thresholds = (torch.cumsum(sorted_desc, 0) - 1) / ranks
    active = sorted_desc > thresholds
    if not active.any():
        return torch.full_like(v, 1.0 / n)
    # The shift is the threshold at the last position that is still active.
    last_active = torch.nonzero(active, as_tuple=False).max()
    return torch.clamp(v - thresholds[last_active], min=0.0)


class KappaFunctionImproved(torch.autograd.Function):
    """Custom autograd op computing the scalar κ from per-sample losses.

    Forward, with ``F`` the per-sample loss vector and ``P_hat`` the reference
    distribution:

        h      = clamp((tau - <F, P_hat>) / ||F||^2, -h_max, h_max)
        P_star = proj_simplex(P_hat + h * F)
        D_eff  = distance_scale * 0.5 * ||P_star - P_hat||^2 + min_divergence
        raw_κ  = (<P_star, F> - tau) / D_eff
        κ      = softplus_temp * softplus(raw_κ / softplus_temp)  if raw_κ >= 0
                 neg_scale * raw_κ                                otherwise

    Backward returns a hand-derived, partly heuristic gradient w.r.t. ``F``
    only; all other inputs receive ``None``.

    Note: the argument name ``F`` shadows the usual ``torch.nn.functional``
    alias inside these methods, so softplus is called by its full path.
    """

    @staticmethod
    def forward(ctx, F, P_hat, tau, distance_scale, min_divergence,
                softplus_temp=2.0, neg_scale=0.1, h_max=0.5):
        """Return ``(kappa, raw_kappa)``; ``raw_kappa`` is for diagnostics."""
        # a = E_{P_hat}[F], b = ||F||^2 (floored to avoid division by zero)
        a = torch.dot(F, P_hat)
        b = torch.clamp(torch.sum(F * F), min=1e-12)
        num = tau - a
        # clamp h to avoid extreme steps when b is tiny
        h = torch.clamp(num / b, min=-h_max, max=h_max)  # scalar

        w_vec = P_hat + h * F
        P_star = proj_simplex(w_vec)

        diff = P_star - P_hat
        D_base = 0.5 * torch.sum(diff * diff)
        D_eff = distance_scale * D_base + min_divergence

        E_f = torch.dot(P_star, F)
        raw_kappa = (E_f - tau) / D_eff

        # activation: positive via tempered softplus, negative shrunk
        pos = torch.nn.functional.softplus(raw_kappa / softplus_temp) * softplus_temp
        neg = raw_kappa * neg_scale
        kappa = torch.where(raw_kappa >= 0, pos, neg)

        # Only tensors go through save_for_backward; the scalar hyper-parameters
        # are stored as plain attributes, which avoids allocating scalar tensors
        # and reading them back with .item() in backward.
        ctx.save_for_backward(F, P_hat, P_star, raw_kappa)
        ctx.tau = tau
        ctx.distance_scale = distance_scale
        ctx.min_divergence = min_divergence
        ctx.softplus_temp = softplus_temp
        ctx.neg_scale = neg_scale
        ctx.h_max = h_max
        return kappa, raw_kappa  # return raw for diagnostics

    @staticmethod
    def backward(ctx, grad_out_kappa, grad_out_raw_unused=None):
        """Gradient of ``kappa`` w.r.t. ``F``.

        The incoming gradient for the diagnostic ``raw_kappa`` output is ignored.
        """
        F, P_hat, P_star, raw_kappa = ctx.saved_tensors
        tau = ctx.tau
        distance_scale = ctx.distance_scale
        min_divergence = ctx.min_divergence
        softplus_temp = ctx.softplus_temp
        neg_scale = ctx.neg_scale
        h_max = ctx.h_max

        diff = P_star - P_hat
        D_base = 0.5 * torch.sum(diff * diff)
        D_eff = distance_scale * D_base + min_divergence

        E_f = torch.dot(P_star, F)
        a = torch.dot(F, P_hat)
        b = torch.clamp(torch.sum(F * F), min=1e-12)
        num = tau - a
        # keep consistent with forward
        h = torch.clamp(num / b, min=-h_max, max=h_max)

        grad_b = 2.0 * F

        # Projection Jacobian on active support (simplex proj, local linearization):
        # on the support it subtracts the support mean, elsewhere it is zero.
        support = P_star > 0

        def apply_J_proj(u):
            u_s = u[support]
            if u_s.numel() == 0:
                return torch.zeros_like(u)
            mean_s = u_s.mean()
            out = torch.zeros_like(u)
            out[support] = u_s - mean_s
            return out

        t = apply_J_proj(F)      # dP*/d(h*F) · F
        s = apply_J_proj(diff)   # dP*/d(P*) · (P* - P_hat)  (heuristic)

        Ft = torch.dot(F, t)
        Fs = torch.dot(F, s)

        # dh/dF for h = (tau - a)/b; the clamp on h is ignored in the gradient
        grad_h = (-b * P_hat - num * grad_b) / (b * b)

        A = P_star + h * t + grad_h * Ft
        B = h * s + grad_h * Fs

        # quotient rule for raw_kappa = (E_f - tau) / D_eff
        numerator = A * D_eff - (E_f - tau) * B
        grad_raw_kappa = numerator / (D_eff * D_eff)

        # derivative of the activation: sigmoid on the softplus branch,
        # constant neg_scale on the linear negative branch
        sigmoid_part = torch.sigmoid(raw_kappa / softplus_temp)
        grad_activation = torch.where(raw_kappa >= 0, sigmoid_part,
                                      torch.full_like(sigmoid_part, neg_scale))

        grad_raw = grad_out_kappa * grad_activation
        grad_F = grad_raw * grad_raw_kappa

        # Exactly one slot per forward input after ctx:
        # F, P_hat, tau, distance_scale, min_divergence, softplus_temp, neg_scale, h_max
        return grad_F, None, None, None, None, None, None, None




In [5]:
# ─── Cell 4: Train & Eval Helpers ─────────────────────────────────────────────
def train_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    Args:
        model: network to train (put into train mode here).
        loader: iterable of ``(x, y, meta)`` batches exposing ``.dataset``.
        optimizer: optimizer stepped once per batch.
        criterion: callable ``(logits, y) -> scalar loss``.

    Returns:
        Sum of ``batch_loss * batch_size`` divided by ``len(loader.dataset)``.
    """
    model.train()
    weighted_loss_sum = 0
    for inputs, targets, _meta in tqdm(loader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()
        # weight by batch size so a short final batch does not skew the mean
        weighted_loss_sum += batch_loss.item() * inputs.size(0)
    return weighted_loss_sum / len(loader.dataset)

def evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` over ``loader``.

    Args:
        model: network to evaluate (put into eval mode here).
        loader: iterable of ``(x, y, meta)`` batches.

    Returns:
        ``correct / total`` as a float, or ``0.0`` when the loader yields no
        samples (same convention as ``evaluate_with_diagnostics``).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y, _ in loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    # Guard the empty case instead of raising ZeroDivisionError.
    return correct / total if total > 0 else 0.0


import torch.nn.functional as F

from collections import Counter

def evaluate_with_diagnostics(model, loader, num_classes=2):
    """Evaluate ``model`` and print accuracy, class histograms and a confusion matrix.

    Args:
        model: network to evaluate (put into eval mode here).
        loader: iterable of ``(x, y, meta)`` batches with integer class labels.
        num_classes: minimum side length of the confusion matrix; it grows
            automatically if a larger class index is observed.

    Returns:
        Top-1 accuracy as a float; ``0.0`` when the loader yields no samples.
    """
    model.eval()
    pred_chunks = []
    label_chunks = []
    with torch.no_grad():
        for x, y, _ in loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            pred_chunks.append(model(x).argmax(dim=1).cpu())
            label_chunks.append(y.cpu())

    # torch.cat rejects an empty list, so the empty loader is handled explicitly.
    if pred_chunks:
        pred_t = torch.cat(pred_chunks).long()
        label_t = torch.cat(label_chunks).long()
    else:
        pred_t = torch.empty(0, dtype=torch.int64)
        label_t = torch.empty(0, dtype=torch.int64)

    total = label_t.numel()
    correct = int((pred_t == label_t).sum().item())
    acc = correct / total if total > 0 else 0.0
    print(f"[Eval] Accuracy: {acc:.4f}")
    print("Prediction distribution:", Counter(pred_t.tolist()))
    print("True label distribution:", Counter(label_t.tolist()))

    # Vectorised confusion matrix: encode each (true, pred) pair as one integer
    # and count occurrences, instead of incrementing cells in a Python loop.
    k = num_classes
    if total > 0:
        k = max(k, int(pred_t.max().item()) + 1, int(label_t.max().item()) + 1)
    cm = torch.bincount(label_t * k + pred_t, minlength=k * k).reshape(k, k)
    print("Confusion matrix (rows=true, cols=pred):")
    print(cm)
    return acc


# --- Comparison baselines (optional) ---
def train_standard(model, loader, epochs, lr):
    """Plain SGD + cross-entropy baseline.

    Args:
        model: network to train; moved to the global ``device``.
        loader: iterable of ``(x, y, meta)`` batches exposing ``.dataset``.
        epochs: number of passes over ``loader``.
        lr: SGD learning rate.

    Returns:
        List with one sample-weighted mean CE loss per epoch.
    """
    model.to(device)
    optimizer = SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    epoch_losses = []
    for epoch_idx in range(epochs):
        model.train()
        loss_sum = 0.0
        for inputs, targets, _meta in loader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()
            loss_sum += batch_loss.item() * inputs.size(0)
        epoch_loss = loss_sum / len(loader.dataset)
        epoch_losses.append(epoch_loss)
        print(f"[SGD] Epoch {epoch_idx+1}/{epochs} CE loss: {epoch_loss:.4f}", flush=True)
    return epoch_losses

def train_sam(model, loader, epochs, lr, rho=0.05):
    """Sharpness-aware (two-pass) SGD baseline with cross-entropy loss.

    Per batch: compute gradients, move the weights by ``rho * grad / ||grad||``,
    recompute gradients at the perturbed point, restore the weights, then take
    the SGD step with the second set of gradients.

    Args:
        model: network to train; moved to the global ``device``.
        loader: iterable of ``(x, y, meta)`` batches exposing ``.dataset``.
        epochs: number of passes over ``loader``.
        lr: SGD learning rate.
        rho: size of the weight perturbation.

    Returns:
        List with one sample-weighted mean of the perturbed-point CE loss per epoch.
    """
    model.to(device)
    optimizer = SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    epoch_losses = []
    for epoch_idx in range(epochs):
        model.train()
        loss_sum = 0.0
        for inputs, targets, _meta in loader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            # Pass 1: gradient at the current weights.
            optimizer.zero_grad()
            criterion(model(inputs), targets).backward()
            per_param_norms = [p.grad.norm() for p in model.parameters() if p.grad is not None]
            step_scale = rho / (torch.norm(torch.stack(per_param_norms)) + 1e-12)

            # Climb along the normalised gradient, remembering each offset.
            perturbations = []
            with torch.no_grad():
                for p in model.parameters():
                    if p.grad is None:
                        perturbations.append(None)
                    else:
                        delta = p.grad * step_scale
                        p.add_(delta)
                        perturbations.append(delta)

            # Pass 2: gradient at the perturbed weights.
            optimizer.zero_grad()
            perturbed_loss = criterion(model(inputs), targets)
            perturbed_loss.backward()

            # Undo the perturbation before stepping.
            with torch.no_grad():
                for p, delta in zip(model.parameters(), perturbations):
                    if delta is not None:
                        p.sub_(delta)
            optimizer.step()
            loss_sum += perturbed_loss.item() * inputs.size(0)
        epoch_loss = loss_sum / len(loader.dataset)
        epoch_losses.append(epoch_loss)
        print(f"[SAM] Epoch {epoch_idx+1}/{epochs} CE loss: {epoch_loss:.4f}", flush=True)
    return epoch_losses


In [6]:
from math import ceil
from tqdm import tqdm

def train_rs(
    model,
    rs_loader,
    epochs,
    lr=2e-5,
    initial_tau_multiplier=0.8,
    tau_target_multiplier=0.9,
    tau_smooth_lr=0.3,
    P_hat_momentum=0.3,
    damp_surrogate=0.1,
    ce_reg_weight=0.1,
    warmup_ce_weight=1.0,
    entropy_reg_weight=0.01,
    use_subset=False,
    subset_size=2048,
    normalize_w=False,
    debug=False,
    base_tau=0.1,
    distance_scale=1.0,
    softplus_temp=2.0,
    warmup_epochs=5,
    kappa_shrink_threshold=1e3,
    tau_smooth_lr_boost=2.0,
    neg_scale=0.1,
    start_at_base=False,
    fixed_tau=False,
    val_loader=None,
    val_interval=5,
    early_stop_patience=3,
    min_epochs_before_early_stop=5,
    negative_kappa_damping=0.1,
    shrink_factor_on_approx_violation=0.7,
    shrink_tau_on_zero_kappa=True,
    zero_kappa_decay_factor=0.95,
    # NEW stabilizers:
    h_max=0.5,
    min_div_from_var_scale=1e-2,
    min_div_floor=1e-4,
    stratify_seed=123,
    # NEW quality-of-life:
    F_eval_fraction=1.0,   # 1.0 = full pass; use 0.25 to debug faster
):
    """Train ``model`` with the RS (κ-based) objective, one optimizer step per epoch.

    Each epoch:
      1. recomputes the per-sample CE vector ``F`` over the whole loader;
      2. adapts the threshold ``tau`` (unless ``fixed_tau``);
      3. either runs a warm-up step (CE minus entropy bonus) while
         ``ep <= warmup_epochs``, or differentiates κ w.r.t. ``F`` with
         ``KappaFunctionImproved`` and uses that gradient as per-sample weights
         ``w`` in a surrogate loss ``<w, F>`` (plus CE and entropy terms);
      4. accumulates gradients over all batches, clips them to norm 0.5 and
         takes a single Adam step;
      5. moves ``P_hat`` towards the projected tilted distribution ``P_star``.

    Args:
        model: network to train; moved to the global ``device``.
        rs_loader: loader yielding ``(x, y, idx)`` where ``idx`` indexes the
            per-sample tensors (see ``IndexedDataset``).
        epochs: maximum number of epochs.
        lr: Adam learning rate.
        use_subset: not implemented; passing True raises ``NotImplementedError``.
            ``subset_size`` and ``stratify_seed`` are accepted for interface
            compatibility and currently unused.
        F_eval_fraction: fraction of the data visited for the *initial* F pass
            only; unvisited entries stay 0. Per-epoch passes are always full.
        val_loader / val_interval: if a loader is given, validation runs at
            epoch 1 and every ``val_interval`` epochs.
        early_stop_patience / min_epochs_before_early_stop: training stops once
            at least ``min_epochs_before_early_stop`` epochs have run and, for
            ``early_stop_patience`` consecutive RS (non-warm-up) epochs, tau
            changed by < 1e-4 and |κ| < 1e-8.
        Remaining keyword arguments tune the tau schedule, the κ activation and
        the loss weights exactly as named.

    Returns:
        ``(kappa_hist, ce_hist, val_accs)``: per-epoch κ values, per-epoch mean
        CE, and a list of ``(epoch, accuracy)`` validation records.

    Raises:
        NotImplementedError: if ``use_subset`` is True.
    """
    import time  # function-scope import: only needed for epoch timing here

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    ce_element = nn.CrossEntropyLoss(reduction='none')

    # --- subset handling ---
    if use_subset:
        # The subset code path is not present; fail with a clear message
        # instead of an UnboundLocalError on active_loader / N further down.
        raise NotImplementedError(
            "train_rs(use_subset=True) is not implemented; pass use_subset=False."
        )
    active_loader = rs_loader
    N = len(rs_loader.dataset)

    # initialize P_hat
    P_hat = torch.full((N,), 1.0 / N, device=device)

    # show we’re about to compute F (this can take time on full data)
    est_batches = ceil((N * F_eval_fraction) / active_loader.batch_size)
    print(f"[RS] Computing F on {F_eval_fraction*100:.0f}% of training data "
          f"({int(N*F_eval_fraction)}/{N}) ~{est_batches} batches...", flush=True)

    def compute_F(model, loader, fraction=1.0, desc="F pass"):
        """Compute per-example CE losses for a fraction of the dataset, with progress."""
        model.eval()
        take_items = int(len(loader.dataset) * fraction)
        if take_items <= 0:
            take_items = len(loader.dataset)
        F_cpu = torch.zeros(len(loader.dataset), device='cpu')
        seen = 0
        for x, y, idx in tqdm(loader, total=ceil(take_items/loader.batch_size), desc=desc, leave=False):
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            with torch.no_grad():
                losses = ce_element(model(x), y).cpu()
            F_cpu[idx] = losses
            seen += x.size(0)
            if seen >= take_items:
                break
        return F_cpu.to(device)

    def tilted_distribution(F_now, P_now, tau_now):
        """Return (h, P_star, D_base) for the clamped tilt of P_now along F_now."""
        a = torch.dot(F_now, P_now)
        b = torch.clamp(torch.sum(F_now * F_now), min=1e-12)
        h = torch.clamp((tau_now - a) / b, min=-h_max, max=h_max)
        P_star = proj_simplex(P_now + h * F_now)
        delta = P_star - P_now
        return h, P_star, 0.5 * torch.sum(delta * delta)

    def epoch_clock():
        """Wall-clock seconds, after draining queued GPU work when CUDA is present."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return time.perf_counter()

    # initial values (now with progress)
    F_vals = compute_F(model, active_loader, fraction=F_eval_fraction, desc="F (init)").clone().requires_grad_(True)
    E_f = torch.dot(P_hat, F_vals).item()
    ef_ema = E_f
    if start_at_base:
        tau = base_tau
    else:
        tau = max(base_tau, initial_tau_multiplier * E_f)
    print(f"[RS init] N={N} E_f={E_f:.4f}, tau={tau:.4f}, "
          f"subset={use_subset} (F_eval_fraction={F_eval_fraction})", flush=True)

    kappa_hist, ce_hist, val_accs = [], [], []
    prev_kappa_val = None
    prev_tau = tau
    stalled_count = 0
    prev_kappa_zeroish_count = 0
    patience = max(1, early_stop_patience)

    for ep in range(1, epochs + 1):
        # Host-side timing works on CPU-only runtimes too (CUDA events do not).
        epoch_t0 = epoch_clock()

        tau_reasons = []
        if fixed_tau:
            tau_reasons.append("fixed_tau")

        # compute F and E_f
        F_vals = compute_F(model, active_loader).clone().requires_grad_(True)
        E_f = torch.dot(P_hat, F_vals).item()
        varF = float(F_vals.var(unbiased=False).item())

        # asymmetric EMA for E_f: follow decreases faster than increases
        if E_f < ef_ema:
            beta_down = 0.7
            ef_ema = beta_down * ef_ema + (1 - beta_down) * E_f
        else:
            beta_up = 0.9
            ef_ema = beta_up * ef_ema + (1 - beta_up) * E_f

        # adapt τ if not fixed
        if not fixed_tau:
            effective_tau_smooth_lr = tau_smooth_lr
            if prev_kappa_val is not None and prev_kappa_val > kappa_shrink_threshold:
                effective_tau_smooth_lr = min(1.0, tau_smooth_lr * tau_smooth_lr_boost)
                tau_reasons.append("accelerated_due_to_large_kappa")
            if prev_kappa_zeroish_count >= 1:
                effective_tau_smooth_lr = effective_tau_smooth_lr * 0.1
                tau_reasons.append("damped_upward_due_to_zero_kappa")
            tau_target = tau_target_multiplier * ef_ema
            tau_pre = tau
            tau = tau + effective_tau_smooth_lr * (tau_target - tau)
            tau = min(tau, ef_ema * 0.99)
            if tau > tau_pre + 1e-8:
                tau_reasons.append("relaxed_upward")
            elif tau < tau_pre - 1e-8:
                tau_reasons.append("shrunk_toward_target")

        # min_divergence from variance of F; varF is fixed for the whole epoch,
        # so the same value serves the approximation, κ and the diagnostics.
        min_divergence = max(min_div_floor, min_div_from_var_scale * varF)

        # --- approximate P_star block for provisional τ correction ---
        with torch.no_grad():
            _, P_star_approx, D_base = tilted_distribution(F_vals, P_hat, tau)
            D_eff = distance_scale * D_base + min_divergence
            worst_case_approx = torch.dot(P_star_approx, F_vals)
            raw_kappa_approx = (worst_case_approx - tau) / D_eff

            if raw_kappa_approx.item() < 0 and not fixed_tau:
                target_tau = worst_case_approx.item() - 1e-6
                desired_tau = max(base_tau, target_tau)
                tau_before = tau
                tau_candidate = min(tau * shrink_factor_on_approx_violation, desired_tau)
                tau = max(base_tau, tau_candidate)
                tau_reasons.append(f"approx_shrink(factor={shrink_factor_on_approx_violation})")
                if tau < tau_before - 1e-12:
                    tau_reasons.append("shrunk_due_to_approx_violation")

        # zero-kappa based mild decay of tau
        if prev_kappa_zeroish_count >= 2 and not fixed_tau and shrink_tau_on_zero_kappa:
            old_tau = tau
            tau = max(base_tau, tau * zero_kappa_decay_factor)
            if tau < old_tau:
                tau_reasons.append("zero_kappa_decay")

        # --- compute κ or warmup ---
        model.train()
        optimizer.zero_grad()
        use_rs = ep > warmup_epochs
        running_ce = 0.0
        raw_kappa_val = None
        kappa_scalar_val = 0.0

        if not use_rs:
            # warmup CE with entropy regularization; gradients accumulate over
            # all batches and a single optimizer step is taken per epoch
            for x, y, idx in active_loader:
                x = x.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)
                logits = model(x)
                F_b = ce_element(logits, y)
                ce_loss = F_b.mean()
                probs = torch.nn.functional.softmax(logits, dim=1)
                entropy = - (probs * torch.log(probs + 1e-12)).sum(dim=1).mean()
                loss = warmup_ce_weight * ce_loss - entropy_reg_weight * entropy
                loss.backward()
                running_ce += ce_loss.item() * x.size(0)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            kappa_scalar_val = 0.0
        else:
            # RS path (variance-based min_div and h_max are passed through)
            result = KappaFunctionImproved.apply(
                F_vals, P_hat, tau,
                distance_scale, min_divergence,
                softplus_temp, neg_scale, h_max
            )
            if isinstance(result, (tuple, list)) and len(result) == 2:
                kappa, raw_kappa = result
                raw_kappa_scalar = raw_kappa if raw_kappa.dim() == 0 else raw_kappa.mean()
                raw_kappa_val = raw_kappa_scalar.item()
            else:
                kappa = result
                raw_kappa_val = None

            kappa_scalar = kappa if kappa.dim() == 0 else kappa.mean()
            kappa_scalar_val = kappa_scalar.item()

            if kappa_scalar_val < 0:
                if debug:
                    print(f"[RS] Ep {ep:2d} negative κ={kappa_scalar_val:.3e}; damping by {negative_kappa_damping}", flush=True)
                kappa_to_backprop = kappa_scalar * negative_kappa_damping
            else:
                kappa_to_backprop = kappa_scalar

            # dκ/dF becomes the per-sample weight vector of the surrogate loss
            kappa_to_backprop.backward()
            w = F_vals.grad.detach()
            if normalize_w and w.abs().mean() > 0:
                w = w / (w.abs().mean() + 1e-8)

            if abs(kappa_scalar_val) < 1e-8:
                if shrink_tau_on_zero_kappa and not fixed_tau:
                    old_tau = tau
                    tau = max(base_tau, 0.9 * E_f)
                    if tau < old_tau:
                        tau_reasons.append("zero_kappa_fallback")
                w = torch.zeros_like(F_vals)

            # surrogate + CE + entropy update
            optimizer.zero_grad()
            for x, y, idx in active_loader:
                x = x.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)
                logits = model(x)
                F_b = ce_element(logits, y)
                surrogate = torch.dot(w[idx], F_b)
                ce_mean = F_b.mean()
                probs = torch.nn.functional.softmax(logits, dim=1)
                entropy = - (probs * torch.log(probs + 1e-12)).sum(dim=1).mean()
                loss = damp_surrogate * surrogate + ce_reg_weight * ce_mean - entropy_reg_weight * entropy
                loss.backward()
                running_ce += ce_mean.item() * x.size(0)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()

        # --- update P_hat and exact τ enforcement (clamped h) ---
        with torch.no_grad():
            _, P_star, D_base_exact_t = tilted_distribution(F_vals, P_hat, tau)
            D_base_exact = D_base_exact_t.item()

            worst_case_exp = torch.dot(P_star, F_vals).item()
            if worst_case_exp - 1e-6 < tau and not fixed_tau:
                new_tau = max(base_tau, worst_case_exp - 1e-6)
                if new_tau < tau:
                    tau_reasons.append("shrunk_due_to_exact_violation")
                    tau = new_tau

            P_hat_new = (1 - P_hat_momentum) * P_hat + P_hat_momentum * P_star
            P_hat.copy_(P_hat_new / P_hat_new.sum())

        # timing
        elapsed_ms = (epoch_clock() - epoch_t0) * 1000.0

        # --- exact snapshot for diagnostics (uses the updated P_hat and tau) ---
        with torch.no_grad():
            h_final, P_star_final, D_base_final = tilted_distribution(F_vals, P_hat, tau)
            D_eff_final = distance_scale * D_base_final + min_divergence
            raw_kappa_exact_final = (torch.dot(P_star_final, F_vals) - tau) / D_eff_final

        ce_avg = running_ce / N
        numerator = E_f - tau
        h_val = h_final.item()
        D_val = D_base_exact
        reason_str = ";".join(tau_reasons) if tau_reasons else "none"

        if raw_kappa_val is not None:
            print(f"[RS] Ep {ep:2d} κ={kappa_scalar_val:.3e} raw_kappa={raw_kappa_val:.3e} "
                  f"E_f={E_f:.4f} num={numerator:.4e} h={h_val:.3e} D={D_val:.3e} "
                  f"τ={tau:.4f} CE={ce_avg:.4f} epoch_time={elapsed_ms/1000:.2f}s", flush=True)
        else:
            print(f"[RS] Ep {ep:2d} κ={kappa_scalar_val:.3e} raw_kappa_exact={raw_kappa_exact_final:.3e} "
                  f"E_f={E_f:.4f} num={numerator:.4e} h={h_val:.3e} D={D_val:.3e} "
                  f"τ={tau:.4f} CE={ce_avg:.4f} epoch_time={elapsed_ms/1000:.2f}s", flush=True)

        print(f"[debug] E_f={E_f:.4f} tau={tau:.4f} reasons={reason_str} "
              f"Var(F)={varF:.4e} D_base={float(D_base_final):.4e} D_eff={float(D_eff_final):.4e} "
              f"raw_kappa_exact={float(raw_kappa_exact_final):.4e} raw_kappa_approx={float(raw_kappa_approx):.4e}",
              flush=True)

        kappa_hist.append(kappa_scalar_val)
        ce_hist.append(ce_avg)

        # validation
        if val_loader is not None and (ep % val_interval == 0 or ep == 1):
            val_acc = evaluate_with_diagnostics(model, val_loader)
            val_accs.append((ep, val_acc))

        # update zero-kappa tracker (feeds the τ damping/decay of the next epoch)
        kappa_zeroish = abs(kappa_scalar_val) < 1e-8
        if kappa_zeroish:
            prev_kappa_zeroish_count += 1
        else:
            prev_kappa_zeroish_count = 0

        # early stopping: count consecutive stalled RS epochs. Warm-up epochs
        # are excluded because κ is 0 there by construction, not by convergence.
        tau_stable = abs(tau - prev_tau) < 1e-4
        if use_rs and tau_stable and kappa_zeroish:
            stalled_count += 1
        else:
            stalled_count = 0
        if ep >= min_epochs_before_early_stop and stalled_count >= patience:
            print(f"[RS] Early stopping at epoch {ep} because τ stabilized and κ≈0 for {early_stop_patience} epochs.", flush=True)
            break

        prev_kappa_val = kappa_scalar_val
        prev_tau = tau

    return kappa_hist, ce_hist, val_accs


In [None]:
# ─── Cell 5 ─────────────────────────────────────────────

# One freshly initialised SimpleCNN per training method, so the three runs
# do not share weights.
model_std = SimpleCNN().to(device)
model_rs = SimpleCNN().to(device)
model_sam = SimpleCNN().to(device)

# RS run: returns per-epoch κ, per-epoch mean CE and (epoch, val accuracy) pairs.
print("\nStarting RS training (improved)…", flush=True)
kappa_list, train_losses_rs, val_history = train_rs(
    model_rs,
    rs_loader,
    epochs=80,
    lr=5e-5,                          # slightly larger learning rate
    initial_tau_multiplier=0.3,       # start tau lower relative to E_f (not used while start_at_base=True)
    tau_target_multiplier=0.6,        # target stays under E_f
    tau_smooth_lr=0.1,                # slower smoothing, less jitter
    P_hat_momentum=0.2,
    damp_surrogate=0.1,
    ce_reg_weight=1.0,                # stronger supervised signal
    warmup_ce_weight=1.0,
    entropy_reg_weight=0.05,          # stronger entropy regularization to avoid collapse
    use_subset=False,
    subset_size=4096,                 # only meaningful with use_subset=True, which is off here
    normalize_w=False,
    debug=True,
    base_tau=0.1,
    distance_scale=1.0,
    softplus_temp=1.0,               # sharper kappa activation
    warmup_epochs=10,                # longer pure CE warmup
    start_at_base=True,
    fixed_tau=False,
    val_loader=val_loader,
    val_interval=5,
    early_stop_patience=5,
    min_epochs_before_early_stop=10,
    negative_kappa_damping=0.2,       # gentler handling of negative kappa
    shrink_factor_on_approx_violation=0.85,  # less aggressive tau shrink
    zero_kappa_decay_factor=0.98,     # decay factor for κ≈0; inactive while shrink_tau_on_zero_kappa=False
    shrink_tau_on_zero_kappa=False,    # allow tau to recover upward if κ=0
)



# Baselines use the shuffled WILDS train_loader and run for 40 epochs
# (the RS run above is configured for up to 80).
print("\nStarting SAM training…", flush=True)
train_losses_sam = train_sam(model_sam, train_loader, epochs=40, lr=2e-5, rho=0.05)

print("Starting SGD training…", flush=True)
train_losses_std = train_standard(model_std, train_loader, epochs=40, lr=2e-5)

# Evaluate all three models on the validation split
acc_std = evaluate(model_std, val_loader)
acc_rs = evaluate(model_rs, val_loader)
print(f"RS Validation Accuracy: {acc_rs}")
acc_sam = evaluate(model_sam, val_loader)
print(f"\nValidation acc · SGD: {acc_std:.4f}, RS: {acc_rs:.4f}, SAM: {acc_sam:.4f}", flush=True)





Starting RS training (improved)…
[RS] Computing F on 100% of training data (302436/302436) ~591 batches...


F (init):   0%|          | 0/591 [00:00<?, ?it/s]

In [None]:
# Per-epoch training loss of the three methods. The curves may have different
# lengths because each run is configured with its own number of epochs.
fig_loss, ax_loss = plt.subplots(figsize=(7, 4))
ax_loss.plot(train_losses_std, label="SGD")
ax_loss.plot(train_losses_rs, label="RS")
ax_loss.plot(train_losses_sam, label="SAM")
ax_loss.set_xlabel("Epoch")
ax_loss.set_ylabel("Mean train CE loss")
ax_loss.set_title("Train CE / κ comparison")
ax_loss.legend()
plt.show()

# κ recorded by train_rs, one value per epoch (0 during warm-up epochs).
fig_kappa, ax_kappa = plt.subplots(figsize=(6, 3))
ax_kappa.plot(kappa_list, label="RS κ")
ax_kappa.set_xlabel("Epoch")
ax_kappa.set_ylabel("κ")
ax_kappa.set_title("RS κ over time")
ax_kappa.legend()
plt.show()

In [None]:
import os

# 1) checkpoint directory in Drive
save_dir = '/content/drive/My Drive/camelyon17_checkpoints'
os.makedirs(save_dir, exist_ok=True)

# 2) collect them. The dict is deliberately NOT called `models`: that name is
#    the torchvision module imported at the top of the notebook, and rebinding
#    it would break any cell that later uses `models.<arch>`.
trained_models = {
    'baseline_sgd': model_std,
    'rs_on_data':   model_rs,
    'sam_model':    model_sam,
}

# 3) save only the weights (state_dict) of each model
for name, mdl in trained_models.items():
    path = os.path.join(save_dir, f'{name}.pth')
    torch.save(mdl.state_dict(), path)
    print(f'Saved {name} weights to {path}')

In [None]:
# Final cell: hand the Colab runtime back once everything above has run.
# NOTE(review): presumably this disconnects and deletes the VM, so anything not
# already written to Drive would be lost — confirm before running.
from google.colab import runtime
runtime.unassign()