In [1]:
import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.classification import (
    MulticlassAccuracy,
    # MulticlassNegativeLogLikelihood,
    # MulticlassBrierScoreLoss,
    MulticlassCalibrationError,
)
from pathlib import Path
from loaders import (
    tiny_imagenet_train_loader, 
    tiny_imagenet_val_loader,
    tiny_imagenet_corrupted_loader
)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

from tqdm.notebook import tqdm

cuda


In [2]:
# ---- Data --------------------------------------------------------------------
BATCH_TRAIN = 128
BATCH_VAL  = 128

train_loader = tiny_imagenet_train_loader(batch_size=BATCH_TRAIN)
val_loader   = tiny_imagenet_val_loader(batch_size=BATCH_VAL)

NUM_CLASSES = 200

In [3]:
# ---- α‑BatchNorm --------------------------------------------------------------
class AlphaBN(nn.BatchNorm2d):
    """BatchNorm2d that fuses source & target stats at test‑time."""
    def __init__(self, num_features, alpha=0.9, **kwargs):
        super().__init__(num_features, affine=True, track_running_stats=True, **kwargs)
        self.alpha = alpha

    def forward(self, x):

        x = x.to(device)
        
        if self.training:               # standard BN in training
            return super().forward(x)

        # Evaluation → blend source stats with batch stats
        batch_mean = x.mean([0, 2, 3])
        batch_var  = x.var([0, 2, 3], unbiased=False)

        # print(self.running_mean.device)
        # print(self.running_var.device)
        # print(batch_mean.device)
        # print(batch_var.device)
        # print(self.weight.device)
        # print(self.bias.device)

        # batch_mean is the target mean
        # self.running_mean is the source mean
        mean = (self.alpha * self.running_mean + (1 - self.alpha) * batch_mean).detach()
        var  = (self.alpha * self.running_var  + (1 - self.alpha) * batch_var).detach()

        return F.batch_norm(x, mean, var, self.weight, self.bias,
                            False, 0.0, self.eps)

In [4]:
def convert_to_alpha_bn(module, alpha=0.9):
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, AlphaBN(child.num_features, alpha=alpha))
        else:
            convert_to_alpha_bn(child, alpha)

In [5]:
# ---- Core loss ----------------------------------------------------------------
def core_loss(logits):
    probs = logits.softmax(dim=1)        # (B, C)
    m = probs.mean(dim=0)                # (C,)
    outer = torch.outer(m, m)
    return outer.sum() - outer.diag().sum()

In [6]:
# ---- Train baseline -----------------------------------------------------------
def train_baseline(epochs=90, lr=0.1):
    model = torchvision.models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
    model.to(device)

    best_val_loss = None

    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in tqdm(range(epochs), desc="Training", leave=False):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            loss = loss_fn(model(x), y)
            opt.zero_grad(); loss.backward(); opt.step()
        sched.step()

        # Periodically validate and save in case of crash
        if epoch % 5 == 0:
            model.eval()
            val_acc, val_loss, val_ece = evaluate(model, val_loader, print_results=False)
            model.train()

            if not best_val_loss or val_loss < best_val_loss:
                torch.save(model.state_dict(), f"./models/bn_train3.pt")
                best_val = val_acc
            elif val_loss > best_val_loss + 1:
                print("Early stop because of degredation in validation loss")
                break
            
            print(f"Training loss: {loss.item()}")
            print(f"Validation loss: {val_loss}")
            print(f"Validation accuracy: {val_acc}")
            

    return model

In [7]:
# ---- Adapt BN affine params on Tiny‑ImageNet‑C --------------------------------
def adapt_alpha_bn(model, test_loader, device, alpha=0.9, lr=1e-3, epochs=1):
    convert_to_alpha_bn(model, alpha)
    model.to(device)
    for p in model.parameters(): p.requires_grad_(False)
    bn_params = [m.weight for m in model.modules() if isinstance(m, AlphaBN)] + \
                [m.bias   for m in model.modules() if isinstance(m, AlphaBN)]
    for p in bn_params: p.requires_grad_(True)
    opt = torch.optim.Adam(bn_params, lr=lr)

    model.eval()
    for _ in range(epochs):
        for x, _ in test_loader:
            x = x.to(device)
            logits = model(x)
            loss = core_loss(logits)
            opt.zero_grad(); loss.backward(); opt.step()
    return model

In [8]:
# ---- Metrics ------------------------------------------------------------------
def evaluate(model, loader, print_results=True):
    model.eval()
    ece = MulticlassCalibrationError(NUM_CLASSES, n_bins=15, norm='l1').to(device)
    loss_fn = torch.nn.CrossEntropyLoss()

    total_correct = 0
    total_loss = 0
    total_samples = 0

    with torch.no_grad():
        for x, y in tqdm(loader, desc="Evaluating", leave=False):
            x, y = x.to(device), y.to(device)
            logits = model(x)
            probs  = logits.softmax(dim=1)
            preds = torch.argmax(probs, dim=1)

            total_correct += (preds==y).sum()
            total_loss += loss_fn(probs, y).item()
            total_samples += y.size(0)
            ece.update(probs, y)

    loss = total_loss / total_samples
    accuracy = total_correct/total_samples

    if print_results:
        print(f"Accuracy                : {accuracy:.4f}")
        print(f"Cross-Entropy Loss      : {loss:.4f}")
        print(f"Expected calibration err: {ece.compute():.4f}")

    print(f"{total_correct} / {total_samples}")

    return accuracy, loss, ece.compute()

In [None]:
# ---- Full run -----------------------------------------------------------------
# 1. Train or load baseline
baseline_ckpt = Path('models/bn_baseline_model3.pt')
if baseline_ckpt.exists():
    print('Loading pretrained model')
    model = torchvision.models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
    model.load_state_dict(torch.load(baseline_ckpt, map_location=device))
    model.to(device)
else:
    print('Training model froms scratch')
    model = train_baseline(epochs=100)
    torch.save(model.state_dict(), baseline_ckpt)

print('\nTraining Results')
# accl, loss, ece = evaluate(model, train_loader)

print('\nTest Results')
acc, loss, ece = evaluate(model, val_loader)

Training model froms scratch


Training:   0%|          | 0/100 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/79 [00:00<?, ?it/s]

33 / 10000
Training loss: 3.572425603866577
Validation loss: 0.04186589994430542
Validation accuracy: 0.0032999999821186066


Evaluating:   0%|          | 0/79 [00:00<?, ?it/s]

33 / 10000
Training loss: 2.0659303665161133
Validation loss: 0.04187268214225769
Validation accuracy: 0.0032999999821186066


Evaluating:   0%|          | 0/79 [00:00<?, ?it/s]

44 / 10000
Training loss: 1.18183434009552
Validation loss: 0.0418758994102478
Validation accuracy: 0.004399999976158142


Evaluating:   0%|          | 0/79 [00:00<?, ?it/s]

51 / 10000
Training loss: 1.0925800800323486
Validation loss: 0.04186864032745361
Validation accuracy: 0.0050999997183680534


Evaluating:   0%|          | 0/79 [00:00<?, ?it/s]

38 / 10000
Training loss: 0.2996460199356079
Validation loss: 0.04187883596420288
Validation accuracy: 0.0037999998312443495


Evaluating:   0%|          | 0/79 [00:00<?, ?it/s]

55 / 10000
Training loss: 0.5105670094490051
Validation loss: 0.04187023587226868
Validation accuracy: 0.005499999970197678


Evaluating:   0%|          | 0/79 [00:00<?, ?it/s]

41 / 10000
Training loss: 0.48891907930374146
Validation loss: 0.04187819495201111
Validation accuracy: 0.004100000020116568


Evaluating:   0%|          | 0/79 [00:00<?, ?it/s]

54 / 10000
Training loss: 0.5407708287239075
Validation loss: 0.04186956348419189
Validation accuracy: 0.005399999674409628


Evaluating:   0%|          | 0/79 [00:00<?, ?it/s]

34 / 10000
Training loss: 0.42343729734420776
Validation loss: 0.0418816900730133
Validation accuracy: 0.003399999812245369


Evaluating:   0%|          | 0/79 [00:00<?, ?it/s]

44 / 10000
Training loss: 0.035032011568546295
Validation loss: 0.041879119968414306
Validation accuracy: 0.004399999976158142


Evaluating:   0%|          | 0/79 [00:00<?, ?it/s]

39 / 10000
Training loss: 0.01116199605166912
Validation loss: 0.0418783745765686
Validation accuracy: 0.0038999998942017555


Evaluating:   0%|          | 0/79 [00:00<?, ?it/s]

42 / 10000
Training loss: 0.006530293263494968
Validation loss: 0.04187610306739807
Validation accuracy: 0.00419999985024333


Evaluating:   0%|          | 0/79 [00:00<?, ?it/s]

45 / 10000
Training loss: 0.007940402254462242
Validation loss: 0.04187455854415893
Validation accuracy: 0.0044999998062849045


Evaluating:   0%|          | 0/79 [00:00<?, ?it/s]

44 / 10000
Training loss: 0.02485503815114498
Validation loss: 0.041874964570999144
Validation accuracy: 0.004399999976158142


Evaluating:   0%|          | 0/79 [00:00<?, ?it/s]

41 / 10000
Training loss: 0.02109863981604576
Validation loss: 0.04186923699378967
Validation accuracy: 0.004100000020116568


In [None]:
# 1.5 Test model without adaptation

corrupt_loader = tiny_imagenet_corrupted_loader(
    corruption='brightness', 
    severity=1, 
    batch_size=128
)

model.eval()

acc, loss, ece = evaluate(model, corrupt_loader)

In [None]:
# 2. α‑BN + Core adaptation (look at brightness-1 for example)
adapted_model = adapt_alpha_bn(model, test_loader=corrupt_loader, device=device, alpha=0.9, lr=1e-3, epochs=1)

print('\nPerformance on Tiny‑ImageNet‑C (after adaptation)')
acc, loss, ece = evaluate(adapted_model, corrupt_loader)