In [1]:
# BLOCK 1 — Setup & Data (stable, safe normalization, and correct autocast setup)

import os, math, random, time, pickle, json
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
from torch import nn, Tensor
from torch.utils.data import Dataset, DataLoader, random_split
from torch.backends import cudnn
from torch import optim, GradScaler
from torchvision.transforms import v2

# ✅ Correct device and autocast setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
enable_half = device.type == "cuda"
scaler = GradScaler(enabled=enable_half)
print("Grad scaler is enabled:", enable_half)
print("Device:", device)

# ✅ Stable seed & deterministic control
seed = 1337
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if device.type == "cuda":
    torch.cuda.manual_seed_all(seed)
cudnn.benchmark = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# === File paths ===
if os.path.exists("/kaggle/input") and os.path.exists("/kaggle/working"):
    SVHN_test = "/kaggle/input/fii-atnn-2025-competition-2/SVHN_test.pkl"
    SVHN_train = "/kaggle/input/fii-atnn-2025-competition-2/SVHN_train.pkl"
    STATS_PATH = "/kaggle/working/svhn_stats.json"
else:
    SVHN_test = "data/SVHN_test.pkl"
    SVHN_train = "data/SVHN_train.pkl"
    STATS_PATH = "svhn_stats.json"


# === Dataset classes ===
class SVHN_Dataset(Dataset):
    def __init__(self, train: bool, transforms: v2.Transform):
        path = SVHN_train if train else SVHN_test
        with open(path, "rb") as fd:
            self.data = pickle.load(fd)
        self.transforms = transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i: int):
        image, label = self.data[i]
        if self.transforms is None:
            return image, label
        x = self.transforms(image)
        if torch.isnan(x).any() or torch.isinf(x).any():
            # ✅ Skip bad samples
            x = torch.clamp(x, 0, 1)
        return x, label


class TransformedSubset(torch.utils.data.Subset):
    def __init__(self, subset, transform):
        super().__init__(subset.dataset, subset.indices)
        self.transform = transform

    def __getitem__(self, idx):
        image, label = self.dataset[self.indices[idx]]
        if self.transform is None:
            return image, label
        x = self.transform(image)
        if torch.isnan(x).any() or torch.isinf(x).any():
            x = torch.clamp(x, 0, 1)
        return x, label


# === Compute or load dataset mean/std safely ===
def compute_mean_std_from_dataset(dataset):
    to_tensor = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
    n_pixels = 0
    sum_c = torch.zeros(3, dtype=torch.float64)
    sumsq_c = torch.zeros(3, dtype=torch.float64)

    for i in tqdm(range(len(dataset)), desc="Computing mean/std", leave=False):
        img, _ = dataset[i]
        t = to_tensor(img)
        c, h, w = t.shape
        n = h * w
        n_pixels += n
        t = t.to(torch.float64)
        sum_c += t.reshape(c, -1).sum(dim=1)
        sumsq_c += (t.reshape(c, -1) ** 2).sum(dim=1)

    mean = (sum_c / n_pixels).to(torch.float32)
    var = (sumsq_c / n_pixels) - (mean.to(torch.float64) ** 2)
    std = torch.sqrt(var.clamp_min(1e-6)).to(torch.float32)
    return mean.tolist(), std.tolist()


if os.path.exists(STATS_PATH):
    with open(STATS_PATH, "r") as f:
        stats = json.load(f)
    mean, std = stats["mean"], stats["std"]
    print(f"Loaded stats: mean={mean}, std={std}")
else:
    base_train_for_stats = SVHN_Dataset(train=True, transforms=None)
    mean, std = compute_mean_std_from_dataset(base_train_for_stats)
    with open(STATS_PATH, "w") as f:
        json.dump({"mean": mean, "std": std}, f)
    print(f"Computed stats: mean={mean}, std={std}")

# ✅ Sanity check and clamp bad stats
mean = tuple(float(m) for m in mean)
std = tuple(max(float(s), 1e-3) for s in std)
print(f"Using mean={mean}, std={std}")


# === Augmentations ===
train_transforms = v2.Compose([
    v2.ToImage(),
    v2.RandomCrop(32, padding=4, padding_mode="reflect"),
    v2.RandomAffine(degrees=10, translate=(0.12, 0.12), scale=(0.9, 1.1)),
    v2.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.07),
    v2.ToDtype(torch.float32, scale=True),  # ✅ single scaling only
    v2.Normalize(mean, std, inplace=True),
    v2.RandomErasing(p=0.4, scale=(0.02, 0.15), ratio=(0.3, 3.3), inplace=True),
    v2.RandAugment(num_ops=2, magnitude=9),
])

eval_transforms = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),  # ✅ only one scaling op
    v2.Normalize(mean, std, inplace=True),
])


# === Dataset and loaders ===
base_train = SVHN_Dataset(train=True, transforms=None)
val_ratio = 0.10
val_len = int(len(base_train) * val_ratio)
train_len = len(base_train) - val_len
train_subset, val_subset = random_split(base_train, [train_len, val_len], generator=torch.Generator().manual_seed(seed))

train_set = TransformedSubset(train_subset, train_transforms)
val_set = TransformedSubset(val_subset, eval_transforms)
test_set = SVHN_Dataset(train=False, transforms=eval_transforms)

BATCH_SIZE = 256
NUM_WORKERS = min(4, (os.cpu_count() or 2) // 2)
pin = device.type == "cuda"

train_loader = DataLoader(
    train_set, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=NUM_WORKERS, pin_memory=pin, persistent_workers=NUM_WORKERS > 0, drop_last=True
)
val_loader = DataLoader(
    val_set, batch_size=512, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=pin, persistent_workers=NUM_WORKERS > 0
)
test_loader = DataLoader(
    test_set, batch_size=1024, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=pin, persistent_workers=NUM_WORKERS > 0
)


Grad scaler is enabled: True
Device: cuda
Loaded stats: mean=[0.5070751905441284, 0.48654890060424805, 0.44091787934303284], std=[0.2673342823982239, 0.2564384639263153, 0.2761504650115967]
Using mean=(0.5070751905441284, 0.48654890060424805, 0.44091787934303284), std=(0.2673342823982239, 0.2564384639263153, 0.2761504650115967)


In [2]:
# BLOCK 2 — Model and optimizer updates (stable setup with safe learning dynamics)

import math
from torch.optim.swa_utils import AveragedModel, SWALR

# === Model (unchanged) ===
class VGG13(nn.Module):
    def __init__(self):
        super(VGG13, self).__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Flatten(),
            nn.Linear(512, 100)
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.layers(x)


# === Experiment configurations (safe & diverse) ===
experiments = [
    # 1️⃣ Baseline — balanced setup
    {
        "name": "exp1_baseline",
        "base_lr": 0.15,
        "ema_decay": 0.9995,
        "rand_mag": 9,
        "mix_p": 0.8,
        "extra_aug": False,
        "remove_ema": False,
        "wd": 5e-4,
    },

    # 2️⃣ Strong augmentation — robustness-focused
    {
        "name": "exp2_strong_aug",
        "base_lr": 0.12,
        "ema_decay": 0.9995,
        "rand_mag": 14,
        "mix_p": 0.8,
        "extra_aug": True,
        "remove_ema": False,
        "wd": 5e-4,
    },

    # 3️⃣ Fast optimization — slightly higher LR but stable decay
    {
        "name": "exp3_fast_opt",
        "base_lr": 0.10,          
        "ema_decay": 0.9975,      
        "rand_mag": 9,
        "mix_p": 0.6,
        "extra_aug": False,
        "remove_ema": False,
        "wd": 6e-4,
    },

    # 4️⃣ Weak augmentation + light regularization
    {
        "name": "exp4_light_aug_fast_lr",
        "base_lr": 0.10,          
        "ema_decay": 0.998,
        "rand_mag": 6,
        "mix_p": 0.5,
        "extra_aug": False,
        "remove_ema": False,
        "wd": 5e-4,
    },
]


# === Global training constants ===
criterion = nn.CrossEntropyLoss(label_smoothing=0.02)

EPOCHS = 100
warmup_pct = 0.20             
min_lr_ratio = 1e-4
grad_clip = 0.5               
patience = 20
SWA_START_RATIO = 0.75        

# ✅ For reproducibility
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True


In [3]:
# BLOCK 3 — Training Loop (stabilized: FP16-safe, NaN-guarded, consistent EMA/SWA updates)

from torch.optim.swa_utils import update_bn
import os, torch

def ensure_nchw(x: torch.Tensor) -> torch.Tensor:
    if x.ndim == 4 and x.shape[1] != 3 and x.shape[-1] == 3:
        return x.permute(0, 3, 1, 2).contiguous()
    return x


def accuracy(outputs: torch.Tensor, targets: torch.Tensor) -> float:
    with torch.no_grad():
        preds = outputs.argmax(1)
        return (preds == targets).float().mean().item()


# === Mixup / CutMix helpers ===
def rand_bbox(W, H, lam):
    cut_rat = (1.0 - lam) ** 0.5
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    cx, cy = np.random.randint(W), np.random.randint(H)
    x1, y1 = np.clip(cx - cut_w // 2, 0, W), np.clip(cy - cut_h // 2, 0, H)
    x2, y2 = np.clip(cx + cut_w // 2, 0, W), np.clip(cy + cut_h // 2, 0, H)
    return x1, y1, x2, y2


def apply_mixup_cutmix(inputs, targets, mixup_alpha=0.4, cutmix_alpha=1.0, p=0.9):
    if np.random.rand() > p:
        return inputs, (targets, targets, 1.0), None
    bs, c, h, w = inputs.size()
    idx = torch.randperm(bs, device=inputs.device)
    y_a, y_b = targets, targets[idx]
    if np.random.rand() < 0.5:
        lam = np.random.beta(mixup_alpha, mixup_alpha)
        mixed = inputs * lam + inputs[idx] * (1.0 - lam)
        return mixed, (y_a, y_b, float(lam)), "mixup"
    else:
        lam = np.random.beta(cutmix_alpha, cutmix_alpha)
        x1, y1, x2, y2 = rand_bbox(w, h, lam)
        mixed = inputs.clone()
        mixed[:, :, y1:y2, x1:x2] = inputs[idx, :, y1:y2, x1:x2]
        lam = 1.0 - ((x2 - x1) * (y2 - y1) / (w * h + 1e-9))
        return mixed, (y_a, y_b, float(lam)), "cutmix"


def mix_criterion(crit, preds, y_a, y_b, lam):
    return lam * crit(preds, y_a) + (1.0 - lam) * crit(preds, y_b)


# === Stable training step ===
def run_epoch(loader, model, optimizer, scheduler, ema, criterion, train_mode=True, use_swa=False, mix_p=0.9):
    model.train(train_mode)
    total_loss, total_acc, total_n = 0.0, 0.0, 0
    bar = tqdm(loader, leave=False)
    for inputs, targets in bar:
        inputs = ensure_nchw(inputs.to(device, non_blocking=True)).float()
        targets = targets.to(device, non_blocking=True)

        if train_mode:
            inputs_mixed, (y_a, y_b, lam), mode = apply_mixup_cutmix(inputs, targets, p=mix_p)

            with torch.autocast(device_type=device.type, enabled=enable_half):
                outputs = model(inputs_mixed)
                loss = criterion(outputs, targets) if mode is None else mix_criterion(criterion, outputs, y_a, y_b, lam)

            # ✅ skip if NaN/Inf
            if not torch.isfinite(loss):
                print("⚠️  Skipping batch due to NaN/Inf loss")
                optimizer.zero_grad(set_to_none=True)
                continue

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

            if not use_swa:
                scheduler.step()
            if ema is not None:
                ema.update(model)

            batch_acc = accuracy(outputs, targets)

        else:
            with torch.inference_mode(), torch.autocast(device_type=device.type, enabled=enable_half):
                net = ema.ema if ema is not None else model
                outputs = net(inputs)
                loss = criterion(outputs, targets)
            batch_acc = accuracy(outputs, targets)

        bs = targets.size(0)
        total_loss += loss.item() * bs
        total_acc += batch_acc * bs
        total_n += bs
        bar.set_description(f"{'Train' if train_mode else 'Valid'} loss {total_loss/total_n:.4f} acc {total_acc/total_n:.4f}")

    return total_loss / total_n, total_acc / total_n


# === EMA wrapper ===
class ModelEMA:
    def __init__(self, model, decay=0.9995):
        self.ema = VGG13().to(device)
        self.ema.load_state_dict(model.state_dict(), strict=True)
        for p in self.ema.parameters():
            p.requires_grad_(False)
        self.decay = decay

    @torch.no_grad()
    def update(self, model):
        d = self.decay
        msd = model.state_dict()
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:
                v.copy_(v * d + (1.0 - d) * msd[k])

    def state_dict(self):
        return self.ema.state_dict()


# === Run all experiments ===
for cfg in experiments:
    print(f"\n\n=== Running {cfg['name']} ===")

    # Dynamic augmentations
    extra_layers = []
    if cfg.get("extra_aug", False):
        extra_layers += [
            v2.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 1.0)),
            v2.RandomPerspective(distortion_scale=0.5, p=0.5),
        ]

    train_transforms = v2.Compose([
        v2.ToImage(),
        v2.RandomCrop(32, padding=4, padding_mode="reflect"),
        v2.RandomAffine(degrees=10, translate=(0.12, 0.12), scale=(0.9, 1.1)),
        v2.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        *extra_layers,
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean, std, inplace=True),
        v2.RandomErasing(p=0.4, scale=(0.02, 0.15), ratio=(0.3, 3.3), inplace=True),
        v2.RandAugment(num_ops=2, magnitude=cfg["rand_mag"]),
    ])

    # Fresh datasets
    base_train = SVHN_Dataset(train=True, transforms=None)
    val_len = int(len(base_train) * 0.10)
    train_len = len(base_train) - val_len
    train_subset, val_subset = random_split(base_train, [train_len, val_len], generator=torch.Generator().manual_seed(seed))
    train_set = TransformedSubset(train_subset, train_transforms)
    val_set = TransformedSubset(val_subset, eval_transforms)

    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=pin, drop_last=True)
    val_loader = DataLoader(val_set, batch_size=512, shuffle=False, num_workers=NUM_WORKERS, pin_memory=pin)

    # === Model & optimizer ===
    model = VGG13().to(device)
    model = torch.jit.script(model)
    ema = None if cfg.get("remove_ema", False) else ModelEMA(model, decay=cfg["ema_decay"])
    base_lr = cfg["base_lr"]

    optimizer = optim.SGD(
        model.parameters(),
        lr=base_lr,
        momentum=0.9,
        weight_decay=cfg.get("wd", 5e-4),
        nesterov=True,
        fused=(device.type == "cuda")
    )

    # === Scheduler ===
    total_steps = EPOCHS * len(train_loader)
    warmup_steps = max(1, int(total_steps * warmup_pct))

    def lr_lambda(step):
        if step < warmup_steps:
            return float(step + 1) / float(warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return min_lr_ratio + (1.0 - min_lr_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    # === SWA setup ===
    swa_start_epoch = int(EPOCHS * SWA_START_RATIO)
    swa_model = AveragedModel(model)
    swa_scheduler = SWALR(optimizer, swa_lr=base_lr * 0.05)

    best_val_loss, best_epoch, no_improve = float("inf"), -1, 0
    save_dir = f"/kaggle/working/{cfg['name']}"
    os.makedirs(save_dir, exist_ok=True)

    # === Epoch loop ===
    for epoch in range(1, EPOCHS + 1):
        use_swa = epoch >= swa_start_epoch
        tr_loss, tr_acc = run_epoch(train_loader, model, optimizer, scheduler, ema, criterion, train_mode=True, use_swa=use_swa, mix_p=cfg["mix_p"])
        va_loss, va_acc = run_epoch(val_loader, model, optimizer, scheduler, ema, criterion, train_mode=False, use_swa=use_swa)

        lr_now = optimizer.param_groups[0]["lr"]
        print(f"[{cfg['name']}] Epoch {epoch:03d}/{EPOCHS} | lr={lr_now:.5f} | train_acc={tr_acc*100:.2f}% | val_acc={va_acc*100:.2f}% | val_loss={va_loss:.4f}")

        if use_swa:
            swa_model.update_parameters(model)
            swa_scheduler.step()

        if torch.isnan(torch.tensor(va_loss)) or va_loss == float("inf"):
            print(f"⚠️  NaN validation loss detected in {cfg['name']} — reducing LR.")
            for g in optimizer.param_groups:
                g["lr"] *= 0.5
            continue

        improved = va_loss < best_val_loss - 1e-4
        if improved:
            best_val_loss, best_epoch, no_improve = va_loss, epoch, 0
            torch.save((ema.state_dict() if ema else model.state_dict()), f"{save_dir}/best_vgg13.pth")
        else:
            no_improve += 1
            if no_improve >= patience:
                print(f"[{cfg['name']}] Early stopping at epoch {epoch} (best {best_epoch}, val_loss {best_val_loss:.4f}).")
                break

    # === SWA finalize ===
    print(f"Finalizing {cfg['name']} with SWA update...")
    swa_model.train()
    update_bn(train_loader, swa_model)
    swa_model.eval()

    @torch.inference_mode()
    def eval_on_val(module):
        tot_loss, tot_acc, tot_n = 0.0, 0.0, 0
        for inputs, targets in val_loader:
            inputs = ensure_nchw(inputs.to(device, non_blocking=True)).float()
            targets = targets.to(device, non_blocking=True)
            with torch.autocast(device_type=device.type, enabled=enable_half):
                outputs = module(inputs)
                loss = criterion(outputs, targets)
            bs = targets.size(0)
            tot_loss += loss.item() * bs
            tot_acc += accuracy(outputs, targets) * bs
            tot_n += bs
        return tot_loss / tot_n, tot_acc / tot_n

    swa_val_loss, swa_val_acc = eval_on_val(swa_model)
    print(f"[{cfg['name']}] SWA val_loss={swa_val_loss:.4f}, val_acc={swa_val_acc*100:.2f}% | best EMA loss={best_val_loss:.4f}")

    if swa_val_loss < best_val_loss - 1e-4:
        torch.save(swa_model.state_dict(), f"{save_dir}/best_vgg13.pth")
        print(f"[{cfg['name']}] SWA outperformed EMA — saved SWA weights.")
    else:
        print(f"[{cfg['name']}] Keeping EMA weights.")

print("\nAll 4 experiments completed successfully ✅")




=== Running exp1_baseline ===


                                                                               

[exp1_baseline] Epoch 001/100 | lr=0.00754 | train_acc=3.22% | val_acc=0.90% | val_loss=4.8012


                                                                               

[exp1_baseline] Epoch 002/100 | lr=0.01504 | train_acc=10.85% | val_acc=1.40% | val_loss=4.7015


                                                                               

[exp1_baseline] Epoch 003/100 | lr=0.02254 | train_acc=16.24% | val_acc=2.08% | val_loss=4.5687


                                                                               

[exp1_baseline] Epoch 004/100 | lr=0.03004 | train_acc=19.64% | val_acc=4.42% | val_loss=4.4046


                                                                               

[exp1_baseline] Epoch 005/100 | lr=0.03754 | train_acc=23.50% | val_acc=9.36% | val_loss=4.2133


                                                                               

[exp1_baseline] Epoch 006/100 | lr=0.04504 | train_acc=25.17% | val_acc=14.50% | val_loss=4.0007


                                                                               

[exp1_baseline] Epoch 007/100 | lr=0.05254 | train_acc=30.09% | val_acc=19.30% | val_loss=3.7787


                                                                               

[exp1_baseline] Epoch 008/100 | lr=0.06004 | train_acc=30.37% | val_acc=23.32% | val_loss=3.5543


                                                                               

[exp1_baseline] Epoch 009/100 | lr=0.06754 | train_acc=34.87% | val_acc=27.56% | val_loss=3.3355


                                                                               

[exp1_baseline] Epoch 010/100 | lr=0.07504 | train_acc=34.76% | val_acc=31.42% | val_loss=3.1281


                                                                               

[exp1_baseline] Epoch 011/100 | lr=0.08254 | train_acc=38.25% | val_acc=34.00% | val_loss=2.9345


                                                                               

[exp1_baseline] Epoch 012/100 | lr=0.09004 | train_acc=41.49% | val_acc=36.94% | val_loss=2.7564


                                                                               

[exp1_baseline] Epoch 013/100 | lr=0.09754 | train_acc=44.40% | val_acc=39.60% | val_loss=2.5943


                                                                               

[exp1_baseline] Epoch 014/100 | lr=0.10504 | train_acc=40.91% | val_acc=42.86% | val_loss=2.4472


                                                                               

[exp1_baseline] Epoch 015/100 | lr=0.11254 | train_acc=46.85% | val_acc=45.32% | val_loss=2.3132


                                                                               

[exp1_baseline] Epoch 016/100 | lr=0.12004 | train_acc=43.79% | val_acc=47.64% | val_loss=2.1914


                                                                               

[exp1_baseline] Epoch 017/100 | lr=0.12754 | train_acc=44.27% | val_acc=49.62% | val_loss=2.0846


                                                                               

[exp1_baseline] Epoch 018/100 | lr=0.13504 | train_acc=48.09% | val_acc=51.10% | val_loss=1.9892


                                                                               

[exp1_baseline] Epoch 019/100 | lr=0.14254 | train_acc=47.10% | val_acc=52.86% | val_loss=1.9067


                                                                               

[exp1_baseline] Epoch 020/100 | lr=0.15000 | train_acc=51.03% | val_acc=54.46% | val_loss=1.8345


                                                                               

[exp1_baseline] Epoch 021/100 | lr=0.14994 | train_acc=49.49% | val_acc=55.70% | val_loss=1.7722


                                                                               

[exp1_baseline] Epoch 022/100 | lr=0.14977 | train_acc=47.98% | val_acc=56.84% | val_loss=1.7182


                                                                               

[exp1_baseline] Epoch 023/100 | lr=0.14948 | train_acc=49.98% | val_acc=57.82% | val_loss=1.6717


                                                                               

[exp1_baseline] Epoch 024/100 | lr=0.14908 | train_acc=45.62% | val_acc=58.84% | val_loss=1.6331


                                                                               

[exp1_baseline] Epoch 025/100 | lr=0.14856 | train_acc=46.84% | val_acc=59.76% | val_loss=1.6006


                                                                               

[exp1_baseline] Epoch 026/100 | lr=0.14793 | train_acc=51.04% | val_acc=60.20% | val_loss=1.5719


                                                                               

[exp1_baseline] Epoch 027/100 | lr=0.14718 | train_acc=54.60% | val_acc=60.64% | val_loss=1.5490


                                                                               

[exp1_baseline] Epoch 028/100 | lr=0.14633 | train_acc=56.22% | val_acc=61.14% | val_loss=1.5282


                                                                               

[exp1_baseline] Epoch 029/100 | lr=0.14536 | train_acc=54.79% | val_acc=61.88% | val_loss=1.5101


                                                                               

[exp1_baseline] Epoch 030/100 | lr=0.14429 | train_acc=50.53% | val_acc=62.42% | val_loss=1.4950


                                                                               

[exp1_baseline] Epoch 031/100 | lr=0.14311 | train_acc=53.42% | val_acc=62.60% | val_loss=1.4806


                                                                               

[exp1_baseline] Epoch 032/100 | lr=0.14183 | train_acc=56.88% | val_acc=63.00% | val_loss=1.4684


                                                                               

[exp1_baseline] Epoch 033/100 | lr=0.14044 | train_acc=56.80% | val_acc=63.14% | val_loss=1.4566


                                                                               

[exp1_baseline] Epoch 034/100 | lr=0.13895 | train_acc=55.24% | val_acc=63.54% | val_loss=1.4469


                                                                               

[exp1_baseline] Epoch 035/100 | lr=0.13736 | train_acc=57.22% | val_acc=63.80% | val_loss=1.4388


                                                                               

[exp1_baseline] Epoch 036/100 | lr=0.13568 | train_acc=57.51% | val_acc=64.06% | val_loss=1.4298


                                                                               

[exp1_baseline] Epoch 037/100 | lr=0.13390 | train_acc=55.56% | val_acc=64.30% | val_loss=1.4215


                                                                               

[exp1_baseline] Epoch 038/100 | lr=0.13203 | train_acc=59.71% | val_acc=64.58% | val_loss=1.4140


                                                                               

[exp1_baseline] Epoch 039/100 | lr=0.13008 | train_acc=57.12% | val_acc=64.70% | val_loss=1.4071


                                                                               

[exp1_baseline] Epoch 040/100 | lr=0.12804 | train_acc=55.63% | val_acc=64.82% | val_loss=1.4010


                                                                               

[exp1_baseline] Epoch 041/100 | lr=0.12591 | train_acc=56.18% | val_acc=65.20% | val_loss=1.3952


                                                                               

[exp1_baseline] Epoch 042/100 | lr=0.12371 | train_acc=60.22% | val_acc=65.56% | val_loss=1.3893


                                                                               

[exp1_baseline] Epoch 043/100 | lr=0.12143 | train_acc=60.26% | val_acc=65.72% | val_loss=1.3835


                                                                               

[exp1_baseline] Epoch 044/100 | lr=0.11909 | train_acc=53.82% | val_acc=66.12% | val_loss=1.3789


                                                                               

[exp1_baseline] Epoch 045/100 | lr=0.11667 | train_acc=56.82% | val_acc=66.16% | val_loss=1.3751


                                                                               

[exp1_baseline] Epoch 046/100 | lr=0.11419 | train_acc=58.83% | val_acc=66.30% | val_loss=1.3714


                                                                               

[exp1_baseline] Epoch 047/100 | lr=0.11165 | train_acc=56.36% | val_acc=66.52% | val_loss=1.3687


                                                                               

[exp1_baseline] Epoch 048/100 | lr=0.10905 | train_acc=59.88% | val_acc=66.70% | val_loss=1.3658


                                                                               

[exp1_baseline] Epoch 049/100 | lr=0.10640 | train_acc=60.47% | val_acc=66.98% | val_loss=1.3637


                                                                               

[exp1_baseline] Epoch 050/100 | lr=0.10371 | train_acc=60.14% | val_acc=66.98% | val_loss=1.3616


                                                                               

[exp1_baseline] Epoch 051/100 | lr=0.10096 | train_acc=66.35% | val_acc=66.96% | val_loss=1.3603


                                                                               

[exp1_baseline] Epoch 052/100 | lr=0.09818 | train_acc=58.46% | val_acc=67.32% | val_loss=1.3595


                                                                               

[exp1_baseline] Epoch 053/100 | lr=0.09536 | train_acc=54.05% | val_acc=67.38% | val_loss=1.3591


                                                                               

[exp1_baseline] Epoch 054/100 | lr=0.09251 | train_acc=64.65% | val_acc=67.32% | val_loss=1.3586


                                                                               

[exp1_baseline] Epoch 055/100 | lr=0.08964 | train_acc=60.45% | val_acc=67.62% | val_loss=1.3584


                                                                               

[exp1_baseline] Epoch 056/100 | lr=0.08674 | train_acc=55.29% | val_acc=67.58% | val_loss=1.3583


                                                                               

[exp1_baseline] Epoch 057/100 | lr=0.08382 | train_acc=57.85% | val_acc=67.74% | val_loss=1.3578


                                                                               

[exp1_baseline] Epoch 058/100 | lr=0.08089 | train_acc=65.10% | val_acc=67.52% | val_loss=1.3571


                                                                               

[exp1_baseline] Epoch 059/100 | lr=0.07795 | train_acc=55.93% | val_acc=67.74% | val_loss=1.3559


                                                                               

[exp1_baseline] Epoch 060/100 | lr=0.07501 | train_acc=61.25% | val_acc=67.84% | val_loss=1.3556


                                                                               

[exp1_baseline] Epoch 061/100 | lr=0.07206 | train_acc=60.38% | val_acc=67.98% | val_loss=1.3559


                                                                               

[exp1_baseline] Epoch 062/100 | lr=0.06912 | train_acc=64.30% | val_acc=67.98% | val_loss=1.3561


                                                                               

[exp1_baseline] Epoch 063/100 | lr=0.06619 | train_acc=63.25% | val_acc=67.94% | val_loss=1.3562


                                                                               

[exp1_baseline] Epoch 064/100 | lr=0.06328 | train_acc=62.54% | val_acc=67.82% | val_loss=1.3569


                                                                               

[exp1_baseline] Epoch 065/100 | lr=0.06038 | train_acc=61.58% | val_acc=68.12% | val_loss=1.3577


                                                                               

[exp1_baseline] Epoch 066/100 | lr=0.05750 | train_acc=64.48% | val_acc=68.10% | val_loss=1.3589


                                                                               

[exp1_baseline] Epoch 067/100 | lr=0.05465 | train_acc=64.49% | val_acc=68.20% | val_loss=1.3604


                                                                               

[exp1_baseline] Epoch 068/100 | lr=0.05183 | train_acc=66.65% | val_acc=68.08% | val_loss=1.3625


                                                                               

[exp1_baseline] Epoch 069/100 | lr=0.04905 | train_acc=70.37% | val_acc=67.98% | val_loss=1.3645


                                                                               

[exp1_baseline] Epoch 070/100 | lr=0.04631 | train_acc=62.02% | val_acc=68.10% | val_loss=1.3661


                                                                               

[exp1_baseline] Epoch 071/100 | lr=0.04361 | train_acc=63.44% | val_acc=68.18% | val_loss=1.3673


                                                                               

[exp1_baseline] Epoch 072/100 | lr=0.04096 | train_acc=67.29% | val_acc=68.20% | val_loss=1.3695


                                                                               

[exp1_baseline] Epoch 073/100 | lr=0.03836 | train_acc=68.27% | val_acc=68.28% | val_loss=1.3721


                                                                               

[exp1_baseline] Epoch 074/100 | lr=0.03582 | train_acc=67.61% | val_acc=68.28% | val_loss=1.3744


                                                                               

[exp1_baseline] Epoch 075/100 | lr=0.03582 | train_acc=62.29% | val_acc=68.38% | val_loss=1.3765


                                                                               

[exp1_baseline] Epoch 076/100 | lr=0.03513 | train_acc=68.85% | val_acc=68.42% | val_loss=1.3777


                                                                               

[exp1_baseline] Epoch 077/100 | lr=0.03312 | train_acc=64.46% | val_acc=68.54% | val_loss=1.3797


                                                                               

[exp1_baseline] Epoch 078/100 | lr=0.02999 | train_acc=70.97% | val_acc=68.48% | val_loss=1.3827


                                                                               

[exp1_baseline] Epoch 079/100 | lr=0.02604 | train_acc=71.88% | val_acc=68.68% | val_loss=1.3863


                                                                               

[exp1_baseline] Epoch 080/100 | lr=0.02166 | train_acc=68.12% | val_acc=68.78% | val_loss=1.3899
[exp1_baseline] Early stopping at epoch 80 (best 60, val_loss 1.3556).
Finalizing exp1_baseline with SWA update...




[exp1_baseline] SWA val_loss=1.7216, val_acc=63.10% | best EMA loss=1.3556
[exp1_baseline] Keeping EMA weights.


=== Running exp2_strong_aug ===


                                                                               

[exp2_strong_aug] Epoch 001/100 | lr=0.00603 | train_acc=2.89% | val_acc=0.86% | val_loss=4.8178


                                                                               

[exp2_strong_aug] Epoch 002/100 | lr=0.01203 | train_acc=9.48% | val_acc=1.30% | val_loss=4.7229


                                                                               

[exp2_strong_aug] Epoch 003/100 | lr=0.01803 | train_acc=14.77% | val_acc=2.36% | val_loss=4.5951


                                                                               

[exp2_strong_aug] Epoch 004/100 | lr=0.02403 | train_acc=18.22% | val_acc=5.00% | val_loss=4.4415


                                                                               

[exp2_strong_aug] Epoch 005/100 | lr=0.03003 | train_acc=22.59% | val_acc=8.52% | val_loss=4.2684


                                                                               

[exp2_strong_aug] Epoch 006/100 | lr=0.03603 | train_acc=22.68% | val_acc=12.96% | val_loss=4.0781


                                                                               

[exp2_strong_aug] Epoch 007/100 | lr=0.04203 | train_acc=27.95% | val_acc=16.92% | val_loss=3.8773


                                                                               

[exp2_strong_aug] Epoch 008/100 | lr=0.04803 | train_acc=30.45% | val_acc=21.30% | val_loss=3.6720


                                                                               

[exp2_strong_aug] Epoch 009/100 | lr=0.05403 | train_acc=33.06% | val_acc=24.62% | val_loss=3.4680


                                                                               

[exp2_strong_aug] Epoch 010/100 | lr=0.06003 | train_acc=34.76% | val_acc=28.44% | val_loss=3.2708


                                                                               

[exp2_strong_aug] Epoch 011/100 | lr=0.06603 | train_acc=38.26% | val_acc=31.98% | val_loss=3.0815


                                                                               

[exp2_strong_aug] Epoch 012/100 | lr=0.07203 | train_acc=39.75% | val_acc=34.50% | val_loss=2.9042


                                                                               

[exp2_strong_aug] Epoch 013/100 | lr=0.07803 | train_acc=39.64% | val_acc=37.28% | val_loss=2.7398


                                                                               

[exp2_strong_aug] Epoch 014/100 | lr=0.08403 | train_acc=42.03% | val_acc=40.18% | val_loss=2.5879


                                                                               

[exp2_strong_aug] Epoch 015/100 | lr=0.09003 | train_acc=44.43% | val_acc=42.74% | val_loss=2.4486


                                                                               

[exp2_strong_aug] Epoch 016/100 | lr=0.09603 | train_acc=42.02% | val_acc=45.00% | val_loss=2.3221


                                                                               

[exp2_strong_aug] Epoch 017/100 | lr=0.10203 | train_acc=45.85% | val_acc=46.74% | val_loss=2.2085


                                                                               

[exp2_strong_aug] Epoch 018/100 | lr=0.10803 | train_acc=44.17% | val_acc=49.30% | val_loss=2.1083


                                                                               

[exp2_strong_aug] Epoch 019/100 | lr=0.11403 | train_acc=45.58% | val_acc=50.92% | val_loss=2.0190


                                                                               

[exp2_strong_aug] Epoch 020/100 | lr=0.12000 | train_acc=46.74% | val_acc=52.34% | val_loss=1.9386


                                                                               

[exp2_strong_aug] Epoch 021/100 | lr=0.11995 | train_acc=48.40% | val_acc=53.70% | val_loss=1.8680


                                                                               

[exp2_strong_aug] Epoch 022/100 | lr=0.11982 | train_acc=49.97% | val_acc=55.08% | val_loss=1.8078


                                                                               

[exp2_strong_aug] Epoch 023/100 | lr=0.11958 | train_acc=50.36% | val_acc=56.62% | val_loss=1.7547


                                                                               

[exp2_strong_aug] Epoch 024/100 | lr=0.11926 | train_acc=46.24% | val_acc=57.76% | val_loss=1.7092


                                                                               

[exp2_strong_aug] Epoch 025/100 | lr=0.11885 | train_acc=51.20% | val_acc=58.50% | val_loss=1.6704


                                                                               

[exp2_strong_aug] Epoch 026/100 | lr=0.11834 | train_acc=52.88% | val_acc=59.20% | val_loss=1.6372


                                                                               

[exp2_strong_aug] Epoch 027/100 | lr=0.11775 | train_acc=55.03% | val_acc=59.90% | val_loss=1.6073


                                                                               

[exp2_strong_aug] Epoch 028/100 | lr=0.11706 | train_acc=58.36% | val_acc=60.46% | val_loss=1.5808


                                                                               

[exp2_strong_aug] Epoch 029/100 | lr=0.11629 | train_acc=55.84% | val_acc=61.00% | val_loss=1.5577


                                                                               

[exp2_strong_aug] Epoch 030/100 | lr=0.11543 | train_acc=61.11% | val_acc=61.58% | val_loss=1.5376


                                                                               

[exp2_strong_aug] Epoch 031/100 | lr=0.11449 | train_acc=54.06% | val_acc=62.00% | val_loss=1.5216


                                                                               

[exp2_strong_aug] Epoch 032/100 | lr=0.11346 | train_acc=59.51% | val_acc=62.60% | val_loss=1.5082


                                                                               

[exp2_strong_aug] Epoch 033/100 | lr=0.11235 | train_acc=66.29% | val_acc=62.82% | val_loss=1.4964


                                                                               

[exp2_strong_aug] Epoch 034/100 | lr=0.11116 | train_acc=59.92% | val_acc=63.06% | val_loss=1.4863


                                                                               

[exp2_strong_aug] Epoch 035/100 | lr=0.10989 | train_acc=54.54% | val_acc=63.44% | val_loss=1.4783


                                                                               

[exp2_strong_aug] Epoch 036/100 | lr=0.10854 | train_acc=55.62% | val_acc=63.64% | val_loss=1.4716


                                                                               

[exp2_strong_aug] Epoch 037/100 | lr=0.10712 | train_acc=54.63% | val_acc=63.88% | val_loss=1.4648


                                                                               

[exp2_strong_aug] Epoch 038/100 | lr=0.10563 | train_acc=54.78% | val_acc=63.94% | val_loss=1.4581


                                                                               

[exp2_strong_aug] Epoch 039/100 | lr=0.10406 | train_acc=58.24% | val_acc=64.12% | val_loss=1.4517


                                                                               

[exp2_strong_aug] Epoch 040/100 | lr=0.10243 | train_acc=62.59% | val_acc=64.04% | val_loss=1.4456


                                                                               

[exp2_strong_aug] Epoch 041/100 | lr=0.10073 | train_acc=58.07% | val_acc=64.00% | val_loss=1.4410


                                                                               

[exp2_strong_aug] Epoch 042/100 | lr=0.09897 | train_acc=60.64% | val_acc=64.26% | val_loss=1.4365


                                                                               

[exp2_strong_aug] Epoch 043/100 | lr=0.09715 | train_acc=60.10% | val_acc=64.66% | val_loss=1.4319


                                                                               

[exp2_strong_aug] Epoch 044/100 | lr=0.09527 | train_acc=59.19% | val_acc=64.92% | val_loss=1.4270


                                                                               

[exp2_strong_aug] Epoch 045/100 | lr=0.09334 | train_acc=63.85% | val_acc=64.86% | val_loss=1.4233


                                                                               

[exp2_strong_aug] Epoch 046/100 | lr=0.09135 | train_acc=57.16% | val_acc=65.08% | val_loss=1.4193


                                                                               

[exp2_strong_aug] Epoch 047/100 | lr=0.08932 | train_acc=63.10% | val_acc=65.42% | val_loss=1.4159


                                                                               

[exp2_strong_aug] Epoch 048/100 | lr=0.08724 | train_acc=65.01% | val_acc=65.68% | val_loss=1.4132


                                                                               

[exp2_strong_aug] Epoch 049/100 | lr=0.08512 | train_acc=61.75% | val_acc=65.82% | val_loss=1.4111


                                                                               

[exp2_strong_aug] Epoch 050/100 | lr=0.08296 | train_acc=67.21% | val_acc=65.82% | val_loss=1.4095


                                                                               

[exp2_strong_aug] Epoch 051/100 | lr=0.08077 | train_acc=55.15% | val_acc=65.96% | val_loss=1.4084


                                                                               

[exp2_strong_aug] Epoch 052/100 | lr=0.07855 | train_acc=56.00% | val_acc=65.90% | val_loss=1.4083


                                                                               

[exp2_strong_aug] Epoch 053/100 | lr=0.07629 | train_acc=59.24% | val_acc=65.84% | val_loss=1.4073


                                                                               

[exp2_strong_aug] Epoch 054/100 | lr=0.07401 | train_acc=63.37% | val_acc=66.02% | val_loss=1.4065


                                                                               

[exp2_strong_aug] Epoch 055/100 | lr=0.07171 | train_acc=64.37% | val_acc=66.18% | val_loss=1.4068


                                                                               

[exp2_strong_aug] Epoch 056/100 | lr=0.06939 | train_acc=61.22% | val_acc=66.38% | val_loss=1.4071


                                                                               

[exp2_strong_aug] Epoch 057/100 | lr=0.06706 | train_acc=60.48% | val_acc=66.34% | val_loss=1.4077


                                                                               

[exp2_strong_aug] Epoch 058/100 | lr=0.06471 | train_acc=63.40% | val_acc=66.60% | val_loss=1.4078


                                                                               

[exp2_strong_aug] Epoch 059/100 | lr=0.06236 | train_acc=64.29% | val_acc=66.56% | val_loss=1.4073


                                                                               

[exp2_strong_aug] Epoch 060/100 | lr=0.06001 | train_acc=67.50% | val_acc=66.72% | val_loss=1.4067


                                                                               

[exp2_strong_aug] Epoch 061/100 | lr=0.05765 | train_acc=61.21% | val_acc=66.92% | val_loss=1.4070


                                                                               

[exp2_strong_aug] Epoch 062/100 | lr=0.05530 | train_acc=62.67% | val_acc=66.96% | val_loss=1.4080


                                                                               

[exp2_strong_aug] Epoch 063/100 | lr=0.05295 | train_acc=64.40% | val_acc=67.22% | val_loss=1.4088


                                                                               

[exp2_strong_aug] Epoch 064/100 | lr=0.05062 | train_acc=67.93% | val_acc=67.20% | val_loss=1.4100


                                                                               

[exp2_strong_aug] Epoch 065/100 | lr=0.04830 | train_acc=66.66% | val_acc=67.12% | val_loss=1.4124


                                                                               

[exp2_strong_aug] Epoch 066/100 | lr=0.04600 | train_acc=64.23% | val_acc=67.18% | val_loss=1.4149


                                                                               

[exp2_strong_aug] Epoch 067/100 | lr=0.04372 | train_acc=66.27% | val_acc=67.38% | val_loss=1.4165


                                                                               

[exp2_strong_aug] Epoch 068/100 | lr=0.04147 | train_acc=68.23% | val_acc=67.46% | val_loss=1.4178


                                                                               

[exp2_strong_aug] Epoch 069/100 | lr=0.03924 | train_acc=65.27% | val_acc=67.40% | val_loss=1.4196


                                                                               

[exp2_strong_aug] Epoch 070/100 | lr=0.03705 | train_acc=66.39% | val_acc=67.38% | val_loss=1.4219


                                                                               

[exp2_strong_aug] Epoch 071/100 | lr=0.03489 | train_acc=67.83% | val_acc=67.42% | val_loss=1.4243


                                                                               

[exp2_strong_aug] Epoch 072/100 | lr=0.03277 | train_acc=65.85% | val_acc=67.36% | val_loss=1.4271


                                                                               

[exp2_strong_aug] Epoch 073/100 | lr=0.03069 | train_acc=65.92% | val_acc=67.34% | val_loss=1.4297


                                                                               

[exp2_strong_aug] Epoch 074/100 | lr=0.02866 | train_acc=66.55% | val_acc=67.44% | val_loss=1.4327
[exp2_strong_aug] Early stopping at epoch 74 (best 54, val_loss 1.4065).
Finalizing exp2_strong_aug with SWA update...




[exp2_strong_aug] SWA val_loss=4.6059, val_acc=0.98% | best EMA loss=1.4065
[exp2_strong_aug] Keeping EMA weights.


=== Running exp3_fast_opt ===


                                                                               

[exp3_fast_opt] Epoch 001/100 | lr=0.00503 | train_acc=2.80% | val_acc=0.86% | val_loss=4.7181


                                                                               

[exp3_fast_opt] Epoch 002/100 | lr=0.01003 | train_acc=9.88% | val_acc=6.12% | val_loss=4.3543


                                                                               

[exp3_fast_opt] Epoch 003/100 | lr=0.01503 | train_acc=15.33% | val_acc=14.58% | val_loss=3.9385


                                                                               

[exp3_fast_opt] Epoch 004/100 | lr=0.02003 | train_acc=21.08% | val_acc=20.64% | val_loss=3.5620


                                                                               

[exp3_fast_opt] Epoch 005/100 | lr=0.02503 | train_acc=23.95% | val_acc=25.28% | val_loss=3.2511


                                                                               

[exp3_fast_opt] Epoch 006/100 | lr=0.03003 | train_acc=25.65% | val_acc=29.54% | val_loss=2.9901


                                                                               

[exp3_fast_opt] Epoch 007/100 | lr=0.03503 | train_acc=32.36% | val_acc=34.50% | val_loss=2.7566


                                                                               

[exp3_fast_opt] Epoch 008/100 | lr=0.04003 | train_acc=37.82% | val_acc=38.28% | val_loss=2.5487


                                                                               

[exp3_fast_opt] Epoch 009/100 | lr=0.04503 | train_acc=37.30% | val_acc=42.56% | val_loss=2.3721


                                                                               

[exp3_fast_opt] Epoch 010/100 | lr=0.05003 | train_acc=45.06% | val_acc=45.56% | val_loss=2.2234


                                                                               

[exp3_fast_opt] Epoch 011/100 | lr=0.05503 | train_acc=43.09% | val_acc=48.36% | val_loss=2.0980


                                                                               

[exp3_fast_opt] Epoch 012/100 | lr=0.06003 | train_acc=46.71% | val_acc=50.44% | val_loss=1.9949


                                                                               

[exp3_fast_opt] Epoch 013/100 | lr=0.06503 | train_acc=49.16% | val_acc=52.40% | val_loss=1.9049


                                                                               

[exp3_fast_opt] Epoch 014/100 | lr=0.07003 | train_acc=54.75% | val_acc=54.46% | val_loss=1.8295


                                                                               

[exp3_fast_opt] Epoch 015/100 | lr=0.07503 | train_acc=55.86% | val_acc=55.94% | val_loss=1.7654


                                                                               

[exp3_fast_opt] Epoch 016/100 | lr=0.08003 | train_acc=52.69% | val_acc=57.00% | val_loss=1.7174


                                                                               

[exp3_fast_opt] Epoch 017/100 | lr=0.08503 | train_acc=51.58% | val_acc=58.28% | val_loss=1.6783


                                                                               

[exp3_fast_opt] Epoch 018/100 | lr=0.09003 | train_acc=63.14% | val_acc=59.24% | val_loss=1.6393


                                                                               

[exp3_fast_opt] Epoch 019/100 | lr=0.09503 | train_acc=57.23% | val_acc=59.98% | val_loss=1.6089


                                                                               

[exp3_fast_opt] Epoch 020/100 | lr=0.10000 | train_acc=59.51% | val_acc=60.50% | val_loss=1.5838


                                                                               

[exp3_fast_opt] Epoch 021/100 | lr=0.09996 | train_acc=56.50% | val_acc=61.06% | val_loss=1.5603


                                                                               

[exp3_fast_opt] Epoch 022/100 | lr=0.09985 | train_acc=59.77% | val_acc=61.64% | val_loss=1.5404


                                                                               

[exp3_fast_opt] Epoch 023/100 | lr=0.09965 | train_acc=63.82% | val_acc=62.08% | val_loss=1.5180


                                                                               

[exp3_fast_opt] Epoch 024/100 | lr=0.09938 | train_acc=59.76% | val_acc=62.90% | val_loss=1.5042


                                                                               

[exp3_fast_opt] Epoch 025/100 | lr=0.09904 | train_acc=56.76% | val_acc=63.18% | val_loss=1.4924


                                                                               

[exp3_fast_opt] Epoch 026/100 | lr=0.09862 | train_acc=63.44% | val_acc=63.00% | val_loss=1.4846


                                                                               

[exp3_fast_opt] Epoch 027/100 | lr=0.09812 | train_acc=67.59% | val_acc=63.14% | val_loss=1.4826


                                                                               

[exp3_fast_opt] Epoch 028/100 | lr=0.09755 | train_acc=64.49% | val_acc=63.04% | val_loss=1.4808


                                                                               

[exp3_fast_opt] Epoch 029/100 | lr=0.09691 | train_acc=67.45% | val_acc=63.42% | val_loss=1.4795


                                                                               

[exp3_fast_opt] Epoch 030/100 | lr=0.09619 | train_acc=72.11% | val_acc=63.88% | val_loss=1.4798


                                                                               

[exp3_fast_opt] Epoch 031/100 | lr=0.09541 | train_acc=69.75% | val_acc=63.92% | val_loss=1.4835


                                                                               

[exp3_fast_opt] Epoch 032/100 | lr=0.09455 | train_acc=60.54% | val_acc=64.22% | val_loss=1.4875


                                                                               

[exp3_fast_opt] Epoch 033/100 | lr=0.09363 | train_acc=70.58% | val_acc=63.96% | val_loss=1.4912


                                                                               

[exp3_fast_opt] Epoch 034/100 | lr=0.09263 | train_acc=62.83% | val_acc=64.30% | val_loss=1.4980


                                                                               

[exp3_fast_opt] Epoch 035/100 | lr=0.09157 | train_acc=70.71% | val_acc=64.06% | val_loss=1.5012


                                                                               

[exp3_fast_opt] Epoch 036/100 | lr=0.09045 | train_acc=66.23% | val_acc=64.30% | val_loss=1.5033


                                                                               

[exp3_fast_opt] Epoch 037/100 | lr=0.08927 | train_acc=68.52% | val_acc=63.90% | val_loss=1.5058


                                                                               

[exp3_fast_opt] Epoch 038/100 | lr=0.08802 | train_acc=67.92% | val_acc=64.08% | val_loss=1.5112


                                                                               

[exp3_fast_opt] Epoch 039/100 | lr=0.08672 | train_acc=66.92% | val_acc=64.22% | val_loss=1.5145


                                                                               

[exp3_fast_opt] Epoch 040/100 | lr=0.08536 | train_acc=67.41% | val_acc=64.64% | val_loss=1.5193


                                                                               

[exp3_fast_opt] Epoch 041/100 | lr=0.08394 | train_acc=68.09% | val_acc=64.94% | val_loss=1.5165


                                                                               

[exp3_fast_opt] Epoch 042/100 | lr=0.08247 | train_acc=70.31% | val_acc=65.14% | val_loss=1.5124


                                                                               

[exp3_fast_opt] Epoch 043/100 | lr=0.08096 | train_acc=67.90% | val_acc=64.86% | val_loss=1.5125


                                                                               

[exp3_fast_opt] Epoch 044/100 | lr=0.07939 | train_acc=68.29% | val_acc=65.16% | val_loss=1.5149


                                                                               

[exp3_fast_opt] Epoch 045/100 | lr=0.07778 | train_acc=62.18% | val_acc=64.82% | val_loss=1.5189


                                                                               

[exp3_fast_opt] Epoch 046/100 | lr=0.07613 | train_acc=68.56% | val_acc=64.76% | val_loss=1.5199


                                                                               

[exp3_fast_opt] Epoch 047/100 | lr=0.07443 | train_acc=68.11% | val_acc=64.80% | val_loss=1.5228


                                                                               

[exp3_fast_opt] Epoch 048/100 | lr=0.07270 | train_acc=68.71% | val_acc=64.88% | val_loss=1.5200


                                                                               

[exp3_fast_opt] Epoch 049/100 | lr=0.07094 | train_acc=69.94% | val_acc=64.84% | val_loss=1.5193
[exp3_fast_opt] Early stopping at epoch 49 (best 29, val_loss 1.4795).
Finalizing exp3_fast_opt with SWA update...
[exp3_fast_opt] SWA val_loss=4.6058, val_acc=0.94% | best EMA loss=1.4795
[exp3_fast_opt] Keeping EMA weights.


=== Running exp4_light_aug_fast_lr ===


                                                                               

[exp4_light_aug_fast_lr] Epoch 001/100 | lr=0.00503 | train_acc=2.77% | val_acc=0.68% | val_loss=4.8111


                                                                               

[exp4_light_aug_fast_lr] Epoch 002/100 | lr=0.01003 | train_acc=11.77% | val_acc=3.20% | val_loss=4.4673


                                                                               

[exp4_light_aug_fast_lr] Epoch 003/100 | lr=0.01503 | train_acc=17.75% | val_acc=12.00% | val_loss=4.0674


                                                                               

[exp4_light_aug_fast_lr] Epoch 004/100 | lr=0.02003 | train_acc=22.98% | val_acc=19.04% | val_loss=3.6829


                                                                               

[exp4_light_aug_fast_lr] Epoch 005/100 | lr=0.02503 | train_acc=27.52% | val_acc=24.70% | val_loss=3.3433


                                                                               

[exp4_light_aug_fast_lr] Epoch 006/100 | lr=0.03003 | train_acc=31.66% | val_acc=29.38% | val_loss=3.0514


                                                                               

[exp4_light_aug_fast_lr] Epoch 007/100 | lr=0.03503 | train_acc=34.82% | val_acc=33.46% | val_loss=2.8028


                                                                               

[exp4_light_aug_fast_lr] Epoch 008/100 | lr=0.04003 | train_acc=37.88% | val_acc=37.36% | val_loss=2.5922


                                                                               

[exp4_light_aug_fast_lr] Epoch 009/100 | lr=0.04503 | train_acc=40.94% | val_acc=41.34% | val_loss=2.4121


                                                                               

[exp4_light_aug_fast_lr] Epoch 010/100 | lr=0.05003 | train_acc=44.94% | val_acc=44.74% | val_loss=2.2617


                                                                               

[exp4_light_aug_fast_lr] Epoch 011/100 | lr=0.05503 | train_acc=46.61% | val_acc=47.44% | val_loss=2.1396


                                                                               

[exp4_light_aug_fast_lr] Epoch 012/100 | lr=0.06003 | train_acc=52.47% | val_acc=49.70% | val_loss=2.0357


                                                                               

[exp4_light_aug_fast_lr] Epoch 013/100 | lr=0.06503 | train_acc=53.04% | val_acc=51.34% | val_loss=1.9504


                                                                               

[exp4_light_aug_fast_lr] Epoch 014/100 | lr=0.07003 | train_acc=55.30% | val_acc=52.42% | val_loss=1.8826


                                                                               

[exp4_light_aug_fast_lr] Epoch 015/100 | lr=0.07503 | train_acc=60.52% | val_acc=53.88% | val_loss=1.8242


                                                                               

[exp4_light_aug_fast_lr] Epoch 016/100 | lr=0.08003 | train_acc=59.08% | val_acc=55.14% | val_loss=1.7762


                                                                               

[exp4_light_aug_fast_lr] Epoch 017/100 | lr=0.08503 | train_acc=58.86% | val_acc=55.94% | val_loss=1.7394


                                                                               

[exp4_light_aug_fast_lr] Epoch 018/100 | lr=0.09003 | train_acc=58.59% | val_acc=56.42% | val_loss=1.7082


                                                                               

[exp4_light_aug_fast_lr] Epoch 019/100 | lr=0.09503 | train_acc=62.65% | val_acc=57.34% | val_loss=1.6794


                                                                               

[exp4_light_aug_fast_lr] Epoch 020/100 | lr=0.10000 | train_acc=59.40% | val_acc=58.00% | val_loss=1.6542


                                                                               

[exp4_light_aug_fast_lr] Epoch 021/100 | lr=0.09996 | train_acc=67.48% | val_acc=58.60% | val_loss=1.6274


                                                                               

[exp4_light_aug_fast_lr] Epoch 022/100 | lr=0.09985 | train_acc=70.12% | val_acc=59.56% | val_loss=1.6087


                                                                               

[exp4_light_aug_fast_lr] Epoch 023/100 | lr=0.09965 | train_acc=67.71% | val_acc=60.44% | val_loss=1.5911


                                                                               

[exp4_light_aug_fast_lr] Epoch 024/100 | lr=0.09938 | train_acc=70.93% | val_acc=60.80% | val_loss=1.5794


                                                                               

[exp4_light_aug_fast_lr] Epoch 025/100 | lr=0.09904 | train_acc=66.52% | val_acc=61.18% | val_loss=1.5672


                                                                               

[exp4_light_aug_fast_lr] Epoch 026/100 | lr=0.09862 | train_acc=69.54% | val_acc=61.38% | val_loss=1.5577


                                                                               

[exp4_light_aug_fast_lr] Epoch 027/100 | lr=0.09812 | train_acc=72.07% | val_acc=61.68% | val_loss=1.5527


                                                                               

[exp4_light_aug_fast_lr] Epoch 028/100 | lr=0.09755 | train_acc=72.38% | val_acc=61.72% | val_loss=1.5462


                                                                               

[exp4_light_aug_fast_lr] Epoch 029/100 | lr=0.09691 | train_acc=70.25% | val_acc=62.22% | val_loss=1.5394


                                                                               

[exp4_light_aug_fast_lr] Epoch 030/100 | lr=0.09619 | train_acc=76.34% | val_acc=62.28% | val_loss=1.5363


                                                                               

[exp4_light_aug_fast_lr] Epoch 031/100 | lr=0.09541 | train_acc=76.08% | val_acc=62.60% | val_loss=1.5384


                                                                               

[exp4_light_aug_fast_lr] Epoch 032/100 | lr=0.09455 | train_acc=75.23% | val_acc=62.72% | val_loss=1.5378


                                                                               

[exp4_light_aug_fast_lr] Epoch 033/100 | lr=0.09363 | train_acc=75.64% | val_acc=62.70% | val_loss=1.5377


                                                                               

[exp4_light_aug_fast_lr] Epoch 034/100 | lr=0.09263 | train_acc=72.58% | val_acc=63.14% | val_loss=1.5330


                                                                               

[exp4_light_aug_fast_lr] Epoch 035/100 | lr=0.09157 | train_acc=74.84% | val_acc=63.40% | val_loss=1.5339


                                                                               

[exp4_light_aug_fast_lr] Epoch 036/100 | lr=0.09045 | train_acc=68.08% | val_acc=63.26% | val_loss=1.5311


                                                                               

[exp4_light_aug_fast_lr] Epoch 037/100 | lr=0.08927 | train_acc=72.46% | val_acc=63.44% | val_loss=1.5348


                                                                               

[exp4_light_aug_fast_lr] Epoch 038/100 | lr=0.08802 | train_acc=78.38% | val_acc=63.38% | val_loss=1.5366


                                                                               

[exp4_light_aug_fast_lr] Epoch 039/100 | lr=0.08672 | train_acc=74.78% | val_acc=63.36% | val_loss=1.5395


                                                                               

[exp4_light_aug_fast_lr] Epoch 040/100 | lr=0.08536 | train_acc=63.19% | val_acc=63.86% | val_loss=1.5413


                                                                               

[exp4_light_aug_fast_lr] Epoch 041/100 | lr=0.08394 | train_acc=75.61% | val_acc=63.54% | val_loss=1.5419


                                                                               

[exp4_light_aug_fast_lr] Epoch 042/100 | lr=0.08247 | train_acc=79.93% | val_acc=63.30% | val_loss=1.5461


                                                                               

[exp4_light_aug_fast_lr] Epoch 043/100 | lr=0.08096 | train_acc=74.80% | val_acc=63.62% | val_loss=1.5498


                                                                               

[exp4_light_aug_fast_lr] Epoch 044/100 | lr=0.07939 | train_acc=73.33% | val_acc=63.76% | val_loss=1.5530


                                                                               

[exp4_light_aug_fast_lr] Epoch 045/100 | lr=0.07778 | train_acc=71.59% | val_acc=63.64% | val_loss=1.5540


                                                                               

[exp4_light_aug_fast_lr] Epoch 046/100 | lr=0.07613 | train_acc=73.72% | val_acc=63.86% | val_loss=1.5563


                                                                               

[exp4_light_aug_fast_lr] Epoch 047/100 | lr=0.07443 | train_acc=72.16% | val_acc=63.80% | val_loss=1.5584


                                                                               

[exp4_light_aug_fast_lr] Epoch 048/100 | lr=0.07270 | train_acc=73.17% | val_acc=64.02% | val_loss=1.5566


                                                                               

[exp4_light_aug_fast_lr] Epoch 049/100 | lr=0.07094 | train_acc=73.63% | val_acc=63.88% | val_loss=1.5567


                                                                               

[exp4_light_aug_fast_lr] Epoch 050/100 | lr=0.06914 | train_acc=74.71% | val_acc=63.90% | val_loss=1.5603


                                                                               

[exp4_light_aug_fast_lr] Epoch 051/100 | lr=0.06731 | train_acc=73.45% | val_acc=63.84% | val_loss=1.5575


                                                                               

[exp4_light_aug_fast_lr] Epoch 052/100 | lr=0.06545 | train_acc=73.99% | val_acc=64.14% | val_loss=1.5588


                                                                               

[exp4_light_aug_fast_lr] Epoch 053/100 | lr=0.06358 | train_acc=76.31% | val_acc=64.32% | val_loss=1.5633


                                                                               

[exp4_light_aug_fast_lr] Epoch 054/100 | lr=0.06168 | train_acc=75.07% | val_acc=64.24% | val_loss=1.5632


                                                                               

[exp4_light_aug_fast_lr] Epoch 055/100 | lr=0.05976 | train_acc=79.97% | val_acc=64.14% | val_loss=1.5682


                                                                               

[exp4_light_aug_fast_lr] Epoch 056/100 | lr=0.05783 | train_acc=73.98% | val_acc=64.40% | val_loss=1.5694
[exp4_light_aug_fast_lr] Early stopping at epoch 56 (best 36, val_loss 1.5311).
Finalizing exp4_light_aug_fast_lr with SWA update...




[exp4_light_aug_fast_lr] SWA val_loss=4.6061, val_acc=0.96% | best EMA loss=1.5311
[exp4_light_aug_fast_lr] Keeping EMA weights.

All 4 experiments completed successfully ✅


In [4]:
# BLOCK 4 — Robust Inference (multi-experiment TTA with BN recalibration and bias debiasing)

import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from torch.optim.swa_utils import update_bn

def preprocess_for_model(x: torch.Tensor) -> torch.Tensor:
    x = x.to(device, non_blocking=True)
    if x.ndim == 4 and x.shape[1] != 3 and x.shape[-1] == 3:
        x = x.permute(0, 3, 1, 2).contiguous()
    return x.float()


@torch.inference_mode()
def inference_tta(model):
    """Performs TTA inference with adaptive BN + logit debiasing."""
    # ✅ Recalibrate BatchNorm statistics before inference
    print("↻ Updating BatchNorm stats on test loader for better TTA consistency...")
    update_bn(test_loader, model, device=device)

    model.train()  # allow BN to adapt for each TTA variation
    preds = []

    tta_transforms = [
        v2.Identity(),
        v2.RandomHorizontalFlip(p=1.0),
        v2.RandomAffine(degrees=0, translate=(0.1, 0), scale=None),
        v2.RandomAffine(degrees=0, translate=(0, 0.1), scale=None),
    ]

    for inputs, _ in tqdm(test_loader, leave=False):
        x_raw = inputs
        augmented_logits = []
        for t in tta_transforms:
            # Apply transform safely
            x = torch.stack([t(xi) for xi in x_raw])
            x = preprocess_for_model(x)

            # ✅ Force FP32 precision for stable inference
            with torch.no_grad():
                logits = model(x)

                # ✅ Debias the final linear layer
                last_linear = next((mod for mod in reversed(list(model.modules())) if isinstance(mod, nn.Linear)), None)
                if last_linear is not None and hasattr(last_linear, "bias") and last_linear.bias is not None:
                    bias_vec = last_linear.bias.detach().view(1, -1)
                    logits = logits - bias_vec

                augmented_logits.append(logits)

        # Average predictions from all TTA views
        avg_logits = torch.stack(augmented_logits).mean(dim=0)
        preds.extend(torch.argmax(avg_logits, dim=1).tolist())

    model.eval()
    return preds


# === Run inference for all experiments ===
for cfg in experiments:
    exp_dir = f"/kaggle/working/{cfg['name']}"
    best_path = f"{exp_dir}/best_vgg13.pth"

    if not os.path.exists(best_path):
        print(f"[{cfg['name']}] ⚠️ Skipped — no best_vgg13.pth found.")
        continue

    print(f"\n=== Inference for {cfg['name']} ===")

    # ✅ Load model safely in FP32
    model = VGG13().to(device).float()
    state_dict = torch.load(best_path, map_location=device)
    model.load_state_dict(state_dict, strict=True)
    model.eval()

    # ✅ Optional TorchScript compilation
    scripted = torch.jit.script(model)
    scripted_path = f"{exp_dir}/best_vgg13_scripted.pt"
    scripted.save(scripted_path)
    print(f"[{cfg['name']}] Saved scripted model → {scripted_path}")

    torch.cuda.empty_cache()

    # ✅ Run robust TTA inference
    preds = inference_tta(model)

    # 🔍 Fallback check: if model collapsed to single-class output
    vals, cnts = np.unique(np.array(preds), return_counts=True)
    if len(vals) == 1:
        print(f"⚠️ Warning: model predicted only one class ({vals[0]}). Check training stability.")
        # Try disabling bias correction and re-run once
        preds = []
        model.eval()
        for inputs, _ in tqdm(test_loader, leave=False):
            x = preprocess_for_model(inputs)
            with torch.no_grad():
                logits = model(x)
            preds.extend(torch.argmax(logits, dim=1).tolist())
        vals, cnts = np.unique(np.array(preds), return_counts=True)

    # === Log results ===
    print(f"[{cfg['name']}] Sample preds (first 20): {preds[:20]}")
    hist = list(zip(vals[:10].tolist(), cnts[:10].tolist()))
    print(f"[{cfg['name']}] Pred histogram (top 10): {hist}")

    # === Save submission ===
    out_csv = f"{exp_dir}/submission.csv"
    pd.DataFrame({"ID": list(range(len(preds))), "target": preds}).to_csv(out_csv, index=False)
    print(f"[{cfg['name']}] Saved submission → {out_csv}")

print("\n✅ All experiment inferences completed successfully with stable TTA and BN recalibration.")



=== Inference for exp1_baseline ===
[exp1_baseline] Saved scripted model → /kaggle/working/exp1_baseline/best_vgg13_scripted.pt
↻ Updating BatchNorm stats on test loader for better TTA consistency...


                                               

[exp1_baseline] Sample preds (first 20): [98, 50, 61, 6, 7, 33, 15, 21, 0, 98, 28, 25, 60, 27, 3, 5, 19, 57, 83, 40]
[exp1_baseline] Pred histogram (top 10): [(0, 99), (1, 92), (2, 115), (3, 114), (4, 93), (5, 107), (6, 94), (7, 100), (8, 96), (9, 93)]
[exp1_baseline] Saved submission → /kaggle/working/exp1_baseline/submission.csv

=== Inference for exp2_strong_aug ===
[exp2_strong_aug] Saved scripted model → /kaggle/working/exp2_strong_aug/best_vgg13_scripted.pt
↻ Updating BatchNorm stats on test loader for better TTA consistency...


                                               

[exp2_strong_aug] Sample preds (first 20): [98, 50, 61, 6, 59, 33, 15, 21, 94, 98, 28, 41, 60, 27, 3, 83, 19, 57, 83, 40]
[exp2_strong_aug] Pred histogram (top 10): [(0, 100), (1, 74), (2, 103), (3, 132), (4, 97), (5, 120), (6, 95), (7, 85), (8, 94), (9, 94)]
[exp2_strong_aug] Saved submission → /kaggle/working/exp2_strong_aug/submission.csv

=== Inference for exp3_fast_opt ===
[exp3_fast_opt] Saved scripted model → /kaggle/working/exp3_fast_opt/best_vgg13_scripted.pt
↻ Updating BatchNorm stats on test loader for better TTA consistency...


                                               

[exp3_fast_opt] Sample preds (first 20): [98, 50, 61, 6, 64, 33, 18, 21, 98, 98, 28, 41, 60, 27, 3, 5, 19, 57, 83, 40]
[exp3_fast_opt] Pred histogram (top 10): [(0, 99), (1, 86), (2, 100), (3, 113), (4, 101), (5, 108), (6, 94), (7, 89), (8, 100), (9, 82)]
[exp3_fast_opt] Saved submission → /kaggle/working/exp3_fast_opt/submission.csv

=== Inference for exp4_light_aug_fast_lr ===
[exp4_light_aug_fast_lr] Saved scripted model → /kaggle/working/exp4_light_aug_fast_lr/best_vgg13_scripted.pt
↻ Updating BatchNorm stats on test loader for better TTA consistency...


                                               

[exp4_light_aug_fast_lr] Sample preds (first 20): [1, 50, 61, 6, 59, 33, 18, 9, 98, 98, 28, 41, 60, 27, 3, 5, 19, 57, 83, 40]
[exp4_light_aug_fast_lr] Pred histogram (top 10): [(0, 90), (1, 95), (2, 100), (3, 123), (4, 87), (5, 110), (6, 93), (7, 76), (8, 115), (9, 97)]
[exp4_light_aug_fast_lr] Saved submission → /kaggle/working/exp4_light_aug_fast_lr/submission.csv

✅ All experiment inferences completed successfully with stable TTA and BN recalibration.


