In [1]:
import copy
from pathlib import Path

import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.classification import (
    MulticlassAccuracy,
    # MulticlassNegativeLogLikelihood,
    # MulticlassBrierScoreLoss,
    MulticlassCalibrationError,
)

from loaders import tiny_imagenet_loader, tiny_imagenet_corrupted_loader  # type: ignore
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

from tqdm.notebook import tqdm

In [2]:
# ---- Data --------------------------------------------------------------------
# Number of target classes for the classification task.
NUM_CLASSES = 200

# Mini-batch sizes: one for the training loader, one shared by val and test.
BATCH_TRAIN, BATCH_TEST = 128, 128

# One loader per split, all built by the project-local `tiny_imagenet_loader`.
train_loader = tiny_imagenet_loader(batch_size=BATCH_TRAIN, split="train")
val_loader = tiny_imagenet_loader(batch_size=BATCH_TEST, split="val")
test_loader = tiny_imagenet_loader(batch_size=BATCH_TEST, split="test")

In [3]:
# ---- α‑BatchNorm --------------------------------------------------------------
class AlphaBN(nn.BatchNorm2d):
    """BatchNorm2d variant that mixes stored and per-batch statistics in eval mode.

    While ``self.training`` is true the layer is a plain ``nn.BatchNorm2d``.
    In eval mode each channel is normalised with

        mean = alpha * running_mean + (1 - alpha) * batch_mean
        var  = alpha * running_var  + (1 - alpha) * batch_var

    where the batch statistics are taken over dimensions (0, 2, 3) of the
    current input and the variance is the biased estimate.

    Args:
        num_features: number of channels, forwarded to ``nn.BatchNorm2d``.
        alpha: weight given to the stored running statistics; ``1 - alpha``
            goes to the statistics of the incoming batch.
        **kwargs: forwarded to ``nn.BatchNorm2d``. ``affine`` and
            ``track_running_stats`` are always set to ``True`` here, so they
            must not be passed again.
    """

    def __init__(self, num_features, alpha=0.9, **kwargs):
        super().__init__(
            num_features,
            affine=True,
            track_running_stats=True,
            **kwargs,
        )
        self.alpha = alpha

    def forward(self, x):
        if not self.training:
            # Target-domain statistics, measured on the batch being evaluated.
            reduce_dims = [0, 2, 3]
            target_mean = x.mean(reduce_dims)
            target_var = x.var(reduce_dims, unbiased=False)

            # Convex combination of source (running) and target (batch) stats.
            fused_mean = self.alpha * self.running_mean + (1 - self.alpha) * target_mean
            fused_var = self.alpha * self.running_var + (1 - self.alpha) * target_var

            # training=False makes F.batch_norm normalise with exactly the
            # statistics passed in, without updating anything.
            return F.batch_norm(
                x,
                fused_mean,
                fused_var,
                weight=self.weight,
                bias=self.bias,
                training=False,
                momentum=0.0,
                eps=self.eps,
            )

        # Training: fall back to the stock BatchNorm2d behaviour.
        return super().forward(x)

In [4]:
def convert_to_alpha_bn(module, alpha=0.9):
    """Replace every ``nn.BatchNorm2d`` inside ``module`` with an ``AlphaBN``, in place.

    The replacement inherits everything the original layer had learned: its
    affine weight/bias and running statistics are copied over, and ``eps``,
    ``momentum``, device, dtype and train/eval mode are preserved. Without
    that copy the converted network would normalise with freshly initialised
    statistics (mean 0, var 1) and an identity affine transform.

    Args:
        module: root ``nn.Module`` whose children are walked recursively.
        alpha: weight of the stored running statistics in each new ``AlphaBN``.

    Returns:
        None. ``module`` is modified in place.
    """
    for name, child in module.named_children():
        if not isinstance(child, nn.BatchNorm2d):
            convert_to_alpha_bn(child, alpha)
            continue

        replacement = AlphaBN(
            child.num_features,
            alpha=alpha,
            eps=child.eps,
            momentum=child.momentum,
        )

        source_state = child.state_dict()
        # The first state entry of a BatchNorm2d is a floating-point tensor
        # (weight, or running_mean when the layer is not affine); use it to
        # place the new layer on the same device/dtype as the one it replaces.
        reference = next(iter(source_state.values()), None)
        if reference is not None:
            replacement.to(device=reference.device, dtype=reference.dtype)

        # strict=False: a source layer built with affine=False or
        # track_running_stats=False lacks some keys; AlphaBN keeps its
        # defaults for those.
        replacement.load_state_dict(source_state, strict=False)
        replacement.train(child.training)

        setattr(module, name, replacement)

In [5]:
# ---- Core loss ----------------------------------------------------------------
def core_loss(logits):
    """Sum of the off-diagonal entries of the outer product of mean class probabilities.

    The logits are soft-maxed over dim 1 and averaged over dim 0 to give a
    single probability vector ``m``; the result is ``sum_{i != j} m_i * m_j``.

    Args:
        logits: 2-D tensor of unnormalised class scores, one row per sample.

    Returns:
        A scalar tensor.
    """
    class_marginal = torch.softmax(logits, dim=1).mean(dim=0)   # (C,)
    pairwise = torch.outer(class_marginal, class_marginal)      # (C, C)
    everything = pairwise.sum()
    on_diagonal = pairwise.diag().sum()
    return everything - on_diagonal

In [6]:
# ---- Train baseline -----------------------------------------------------------
def train_baseline(epochs=90, lr=0.1):
    """Train a randomly initialised ResNet-18 on ``train_loader`` and return it.

    Optimisation uses SGD (momentum 0.9, weight decay 1e-4) with a cosine
    learning-rate schedule over ``epochs``. Every fifth epoch the model is
    scored on ``val_loader``:

    * if the validation loss is the best seen so far, the weights are written
      to ``./models/bn_train2.pt`` and the best loss is updated;
    * if it is worse than the best by more than 1, training stops early.

    Args:
        epochs: maximum number of passes over ``train_loader``.
        lr: initial learning rate for SGD.

    Returns:
        The model in its state at the end of training. This is the last epoch
        that ran, not necessarily the best checkpoint on disk.
    """
    model = torchvision.models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
    model = model.to(device)

    # Make sure the checkpoint directory exists before the first save.
    ckpt_path = Path("./models/bn_train2.pt")
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)

    best_val_loss = None

    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in tqdm(range(epochs), desc="Training", leave=False):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            loss = loss_fn(model(x), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
        sched.step()

        # Periodically validate and save in case of crash
        if epoch % 5 == 0:
            model.eval()
            val_acc, val_loss, _ = evaluate(model, val_loader, print_results=False)
            model.train()

            # `is None` rather than truthiness, so a loss of exactly 0.0 is
            # not mistaken for "no best value yet".
            if best_val_loss is None or val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), ckpt_path)
            elif val_loss > best_val_loss + 1:
                print("Early stop because of degredation in validation")
                break

            print(f"Training loss: {loss.item()}")
            print(f"Validation loss: {val_loss}")
            print(f"Validation accuracy: {val_acc}")

    return model

In [7]:
# ---- Adapt BN affine params on Tiny‑ImageNet‑C --------------------------------
def adapt_alpha_bn(model, test_loader, device, alpha=0.9, lr=1e-3, epochs=1):
    """Convert ``model`` to α-BN and fine-tune only the BN affine parameters.

    The model is modified in place: its ``nn.BatchNorm2d`` layers are replaced
    by ``AlphaBN``, every parameter is frozen, and then only the ``AlphaBN``
    weights and biases are re-enabled and optimised with Adam to minimise
    ``core_loss`` on the unlabeled batches of ``test_loader``. The model stays
    in eval mode throughout, so each ``AlphaBN`` blends its stored statistics
    with those of the incoming batch.

    Args:
        model: network to adapt (mutated and returned).
        test_loader: iterable of ``(inputs, labels)`` pairs; labels are ignored.
        device: device the model and each input batch are moved to.
        alpha: blending weight passed to ``convert_to_alpha_bn``.
        lr: Adam learning rate.
        epochs: number of passes over ``test_loader``.

    Returns:
        The same ``model`` object, adapted.
    """
    convert_to_alpha_bn(model, alpha)
    # The conversion constructs new layer objects; make sure every parameter
    # and buffer lives on the device the inputs are about to be moved to.
    model.to(device)

    for p in model.parameters():
        p.requires_grad_(False)

    alpha_bn_layers = [m for m in model.modules() if isinstance(m, AlphaBN)]
    bn_params = [m.weight for m in alpha_bn_layers] + \
                [m.bias for m in alpha_bn_layers]
    for p in bn_params:
        p.requires_grad_(True)
    opt = torch.optim.Adam(bn_params, lr=lr)

    model.eval()
    for _ in range(epochs):
        for x, _ in test_loader:
            x = x.to(device)
            logits = model(x)
            loss = core_loss(logits)
            opt.zero_grad()
            loss.backward()
            opt.step()
    return model

In [8]:
# ---- Metrics ------------------------------------------------------------------
def evaluate(model, loader, print_results=True):
    """Score ``model`` on ``loader``: accuracy, mean cross-entropy and calibration error.

    The model is switched to eval mode and run without gradient tracking.

    Args:
        model: classifier returning logits of shape (batch, NUM_CLASSES).
        loader: iterable of ``(inputs, integer_labels)`` batches.
        print_results: when true, print the three metrics.

    Returns:
        ``(accuracy, loss, ece)``. ``accuracy`` and ``ece`` are the tensors
        returned by the torchmetrics objects; ``loss`` is a Python float
        holding the per-sample mean cross-entropy.
    """
    model.eval()
    # NOTE(review): MulticlassAccuracy is used with library defaults; its
    # default averaging is believed to be "macro". Confirm that is the
    # accuracy you intend to report.
    acc = MulticlassAccuracy(NUM_CLASSES).to(device)
    ece = MulticlassCalibrationError(NUM_CLASSES, n_bins=15, norm='l1').to(device)
    # Sum (not mean) per batch, so that dividing by the sample count below
    # yields a true per-sample average even when the last batch is smaller.
    loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")

    total_loss = 0.0
    total_samples = 0

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            probs  = logits.softmax(dim=1)

            acc.update(probs, y)
            ece.update(probs, y)
            # Cross-entropy expects raw logits as input and the labels as target.
            total_loss += loss_fn(logits, y).item()
            total_samples += y.size(0)

    loss = total_loss / total_samples
    accuracy = acc.compute()
    calibration_error = ece.compute()

    if print_results:
        print(f"Accuracy                : {accuracy:.4f}")
        print(f"Cross-Entropy Loss      : {loss:.4f}")
        print(f"Expected calibration err: {calibration_error:.4f}")

    return accuracy, loss, calibration_error

In [9]:
# ---- Full run -----------------------------------------------------------------
# 1. Train or load baseline
# The trained weights are cached on disk so a full re-run of the notebook
# does not have to repeat the training.
baseline_ckpt = Path('models/bn_baseline_model2.pt')
if baseline_ckpt.exists():
    print('Loading pretrained model')
    model = torchvision.models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
    model.load_state_dict(torch.load(baseline_ckpt, map_location=device))
    model = model.to(device)
else:
    print('Training model froms scratch')
    # Create the checkpoint directory before training starts: both the
    # periodic checkpoints written during training and the final save below
    # go into it, and a missing directory would otherwise abort the run.
    baseline_ckpt.parent.mkdir(parents=True, exist_ok=True)
    model = train_baseline(epochs=100)
    torch.save(model.state_dict(), baseline_ckpt)

# Reported for the train split first, then the test split.
print('\nPerformance on clean test set (before adaptation)')
evaluate(model, train_loader)
evaluate(model, test_loader)

Training model froms scratch


Training:   0%|          | 0/100 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# 2. α‑BN + Core adaptation (look at brightness-1 for example)

corrupt_loader = tiny_imagenet_corrupted_loader(
    corruption='brightness',
    severity=1,
    batch_size=BATCH_TEST
)

# adapt_alpha_bn converts and re-trains the BN layers of the model it is
# given, in place. Adapt a copy so that `model` stays the untouched baseline
# and this cell gives the same result when it is re-run.
adapted_model = adapt_alpha_bn(
    copy.deepcopy(model), corrupt_loader, device, alpha=0.9, lr=1e-3, epochs=1
)
adapted_model.to(device)

print('\nPerformance on Tiny‑ImageNet‑C (after adaptation)')
evaluate(adapted_model, corrupt_loader)