In [None]:
import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics import Metric
from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassCalibrationError,
)
from pathlib import Path
from loaders import tiny_imagenet_loader, tiny_imagenet_corrupted_loader  # type: ignore
device = "cuda" if torch.cuda.is_available() else "cpu"


# torchmetrics ships no negative-log-likelihood or Brier-score metric, so
# importing `MulticlassNegativeLogLikelihood` / `MulticlassBrierScoreLoss` from
# `torchmetrics.classification` fails with ImportError. The two small metrics
# below provide those names with the same construct / .to() / update / compute
# interface as the torchmetrics classes imported above.
class MulticlassNegativeLogLikelihood(Metric):
    """Mean negative log-likelihood of the target class.

    ``update(logits, target)`` takes raw (unnormalised) logits of shape
    ``(B, C)`` and integer class labels of shape ``(B,)``; ``compute()``
    returns the NLL summed over all samples seen, divided by their count.
    """

    is_differentiable = False
    higher_is_better = False
    full_state_update = False

    def __init__(self, num_classes: int, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.add_state("nll_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, logits: torch.Tensor, target: torch.Tensor) -> None:
        # cross_entropy on logits == -log softmax(logits)[target]; summed here,
        # averaged in compute() so batch sizes may differ.
        self.nll_sum += F.cross_entropy(logits.float(), target, reduction="sum").detach()
        self.count += target.numel()

    def compute(self) -> torch.Tensor:
        return self.nll_sum / self.count


class MulticlassBrierScoreLoss(Metric):
    """Mean multiclass Brier score: ``sum_c (p_c - 1[c == y])**2`` averaged over samples.

    ``update(probs, target)`` takes class probabilities of shape ``(B, C)``
    and integer class labels of shape ``(B,)``.
    """

    is_differentiable = False
    higher_is_better = False
    full_state_update = False

    def __init__(self, num_classes: int, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.add_state("brier_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, probs: torch.Tensor, target: torch.Tensor) -> None:
        onehot = F.one_hot(target, num_classes=self.num_classes).float()
        self.brier_sum += ((probs.float() - onehot) ** 2).sum().detach()
        self.count += target.numel()

    def compute(self) -> torch.Tensor:
        return self.brier_sum / self.count

In [None]:
# ---- Data --------------------------------------------------------------------
# Tiny-ImageNet label count; the classifier heads built later use this value.
NUM_CLASSES = 200

ROOT = Path("data")  # adjust as needed
BATCH_TRAIN = 256    # mini-batch size for training
BATCH_TEST = 128     # mini-batch size for evaluation and test-time adaptation

# Training data is shuffled; every evaluation loader keeps a fixed order.
train_loader = tiny_imagenet_loader(
    ROOT,
    split="train",
    batch_size=BATCH_TRAIN,
    shuffle=True,
    num_workers=8,
)
val_loader = tiny_imagenet_loader(
    ROOT,
    split="val",
    batch_size=BATCH_TEST,
    shuffle=False,
    num_workers=8,
)
test_loader = tiny_imagenet_loader(
    ROOT,
    split="test",
    batch_size=BATCH_TEST,
    shuffle=False,
    num_workers=8,
)

# Corrupted variant, all severity levels 1 through 5.
corrupt_loader = tiny_imagenet_corrupted_loader(
    ROOT,
    severity_levels=list(range(1, 6)),
    batch_size=BATCH_TEST,
    shuffle=False,
    num_workers=8,
)

In [None]:
# ---- α‑BatchNorm --------------------------------------------------------------
class AlphaBN(nn.BatchNorm2d):
    """BatchNorm2d that mixes stored (source) and current-batch (target) statistics in eval mode.

    In training mode the layer is exactly ``nn.BatchNorm2d``. In eval mode it
    normalises with
        mean = alpha * running_mean + (1 - alpha) * batch_mean
        var  = alpha * running_var  + (1 - alpha) * batch_var
    and leaves the running statistics untouched.

    Args:
        num_features: number of channels.
        alpha: weight given to the stored running statistics at test time.
        **kwargs: forwarded to ``nn.BatchNorm2d``; ``affine`` and
            ``track_running_stats`` are always ``True`` and must not be passed.
    """

    def __init__(self, num_features, alpha=0.9, **kwargs):
        super().__init__(
            num_features,
            affine=True,
            track_running_stats=True,
            **kwargs,
        )
        self.alpha = alpha

    def forward(self, x):
        # Training: plain BatchNorm2d (batch stats + running-stat update).
        if self.training:
            return super().forward(x)

        # Per-channel statistics over axes 0, 2, 3 (batch and spatial for NCHW input).
        reduce_dims = [0, 2, 3]
        target_mean = x.mean(reduce_dims)
        target_var = x.var(reduce_dims, unbiased=False)

        source_weight = self.alpha
        target_weight = 1 - self.alpha
        mixed_mean = source_weight * self.running_mean + target_weight * target_mean
        mixed_var = source_weight * self.running_var + target_weight * target_var

        # training=False: normalise with the supplied stats and do not update them.
        return F.batch_norm(
            x,
            mixed_mean,
            mixed_var,
            weight=self.weight,
            bias=self.bias,
            training=False,
            momentum=0.0,
            eps=self.eps,
        )

In [None]:
def convert_to_alpha_bn(module, alpha=0.9):
    """Recursively replace every ``nn.BatchNorm2d`` inside ``module`` with ``AlphaBN``, in place.

    Each replacement copies the original layer's ``eps``, ``momentum``, running
    statistics and affine parameters. It is placed on the same device with the
    same floating dtype and keeps the same train/eval mode. This preserves the
    trained (source) statistics that AlphaBN blends with batch statistics at test
    time. A freshly constructed layer would instead start from zero mean / unit
    variance and would live on the CPU.

    Layers that are already ``AlphaBN`` (a ``BatchNorm2d`` subclass) are rebuilt
    the same way, which keeps their state and applies the new ``alpha``.

    Args:
        module: root module to convert; its submodules are modified in place.
        alpha: weight on the source running statistics for every new AlphaBN.
    """
    for name, child in module.named_children():
        if not isinstance(child, nn.BatchNorm2d):
            convert_to_alpha_bn(child, alpha)
            continue

        replacement = AlphaBN(
            child.num_features, alpha=alpha, eps=child.eps, momentum=child.momentum
        )
        source_state = child.state_dict()
        # strict=False: a source layer built with affine=False or
        # track_running_stats=False lacks some entries; AlphaBN then keeps its
        # default values for the missing ones.
        replacement.load_state_dict(source_state, strict=False)

        # Match the source layer's device and floating dtype.
        float_tensors = [t for t in source_state.values() if t.is_floating_point()]
        if float_tensors:
            replacement = replacement.to(
                device=float_tensors[0].device, dtype=float_tensors[0].dtype
            )
        replacement.train(child.training)
        setattr(module, name, replacement)

In [None]:
# ---- Core loss ----------------------------------------------------------------
def core_loss(logits):
    """Core (pairwise class-correlation) loss for test-time adaptation.

    Builds the batch class-correlation matrix ``P^T P / B`` from the softmax
    probabilities ``P`` (shape ``(B, C)``) and returns the sum of its
    off-diagonal entries. Entry ``(i, j)`` is large when classes ``i`` and
    ``j`` receive probability mass in the *same* sample. The loss is therefore
    small when each sample is confidently assigned to one class, whichever
    class that is. It equals ``mean_b(1 - ||p_b||^2)``.

    Why not ``outer(mean(P), mean(P))``: its off-diagonal sum is
    ``1 - ||mean(P)||^2``, which is minimised only when the whole batch is
    predicted as a single class. That objective rewards collapse.

    Args:
        logits: unnormalised scores of shape ``(B, C)``.

    Returns:
        Scalar tensor (differentiable w.r.t. ``logits``).
    """
    probs = logits.softmax(dim=1)                   # (B, C)
    corr = probs.T @ probs / probs.shape[0]         # (C, C)
    return corr.sum() - corr.diagonal().sum()

In [None]:
# ---- Train baseline -----------------------------------------------------------
def train_baseline(epochs=90, lr=0.1):
    """Fine-tune an ImageNet-pretrained ResNet-50 on ``train_loader``.

    The ImageNet classifier head is replaced by a new ``NUM_CLASSES``-way
    linear layer. The whole network is trained with SGD (momentum 0.9, weight
    decay 1e-4), using a cosine learning-rate schedule stepped once per epoch.

    Args:
        epochs: number of passes over ``train_loader``; also the cosine period.
        lr: initial SGD learning rate.

    Returns:
        The trained model, on ``device``.
    """
    model = torchvision.models.resnet50(weights='IMAGENET1K_V2')
    model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
    model = model.to(device)

    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            loss = criterion(model(x), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
        sched.step()
        # Progress report after the first epoch and every tenth epoch.
        if epoch == 0 or (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{epochs} done")
    return model

In [None]:
# ---- Adapt BN affine params on Tiny‑ImageNet‑C --------------------------------
def adapt_alpha_bn(model, alpha=0.9, lr=1e-3, epochs=1, loader=None):
    """Adapt ``model`` to unlabeled data with α-BN plus the Core loss, in place.

    Steps:
    1. Replace every ``BatchNorm2d`` with ``AlphaBN`` via ``convert_to_alpha_bn``.
    2. Freeze all parameters except the AlphaBN affine ``weight``/``bias``.
    3. Optimise those parameters with Adam to minimise ``core_loss``.

    The model stays in eval mode throughout, so AlphaBN normalises with blended
    source/batch statistics.

    Args:
        model: network to adapt; modified in place.
        alpha: AlphaBN weight on the source running statistics.
        lr: Adam learning rate.
        epochs: number of passes over the adaptation data.
        loader: iterable of ``(x, label)`` batches; labels are ignored.
            Defaults to ``corrupt_loader``.

    Returns:
        The same ``model`` object, adapted.

    Raises:
        ValueError: if the model contains no BatchNorm2d layers to adapt.
    """
    data = corrupt_loader if loader is None else loader
    convert_to_alpha_bn(model, alpha)

    alpha_bns = [m for m in model.modules() if isinstance(m, AlphaBN)]
    if not alpha_bns:
        raise ValueError("adapt_alpha_bn: model has no BatchNorm2d layers to adapt")

    for p in model.parameters():
        p.requires_grad_(False)
    bn_params = [m.weight for m in alpha_bns] + [m.bias for m in alpha_bns]
    for p in bn_params:
        p.requires_grad_(True)
    opt = torch.optim.Adam(bn_params, lr=lr)

    model.eval()
    for _ in range(epochs):
        for x, _label in data:
            x = x.to(device)
            loss = core_loss(model(x))
            opt.zero_grad()
            loss.backward()
            opt.step()
    return model

In [None]:
# ---- Metrics ------------------------------------------------------------------
def evaluate(model, loader):
    """Print top-1 accuracy, NLL, Brier score and expected calibration error.

    The model is switched to eval mode and run under ``torch.no_grad()`` over
    every ``(x, y)`` batch of ``loader``. Metrics accumulate across all batches
    and are printed once at the end.

    Args:
        model: classifier; its outputs are treated as logits over
            ``NUM_CLASSES`` classes (softmax taken over dim 1).
        loader: iterable of ``(x, y)`` batches with integer labels ``y``.
    """
    model.eval()
    # average="micro" gives the usual top-1 accuracy (fraction of correct
    # predictions). torchmetrics' default "macro" reports mean per-class recall.
    acc = MulticlassAccuracy(NUM_CLASSES, average="micro").to(device)
    nll = MulticlassNegativeLogLikelihood(NUM_CLASSES).to(device)
    brier = MulticlassBrierScoreLoss(NUM_CLASSES).to(device)
    ece = MulticlassCalibrationError(NUM_CLASSES, n_bins=15, norm='l1').to(device)

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            probs = logits.softmax(dim=1)

            acc.update(probs, y)
            nll.update(logits, y)   # NLL metric consumes raw logits
            brier.update(probs, y)
            ece.update(probs, y)

    print(f"Accuracy                : {acc.compute():.4f}")
    print(f"Negative log-likelihood : {nll.compute():.4f}")
    print(f"Brier score             : {brier.compute():.4f}")
    print(f"Expected calibration err: {ece.compute():.4f}")

In [None]:
# ---- Full run -----------------------------------------------------------------
# 1. Train or load baseline
baseline_ckpt = Path('baseline_resnet50.pth')
if baseline_ckpt.exists():
    model = torchvision.models.resnet50(weights=None)
    model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    model.load_state_dict(torch.load(baseline_ckpt, map_location=device))
    model = model.to(device)
    print('Loaded pretrained baseline.')
else:
    model = train_baseline()
    torch.save(model.state_dict(), baseline_ckpt)

print('\nPerformance on clean test set (before adaptation)')
evaluate(model, test_loader)

# Source-only reference on the corrupted data. This gives the post-adaptation
# numbers below a like-for-like baseline (standard BN, no adaptation).
print('\nPerformance on Tiny‑ImageNet‑C (before adaptation)')
evaluate(model, corrupt_loader)

# 2. α‑BN + Core adaptation (modifies `model` in place)
model = adapt_alpha_bn(model, alpha=0.9, lr=1e-3, epochs=1)

print('\nPerformance on Tiny‑ImageNet‑C (after adaptation)')
evaluate(model, corrupt_loader)