In [6]:
from loaders import tiny_imagenet_loader, tiny_imagenet_corrupted_loader
import torch
from torchvision import models
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.optim.swa_utils import AveragedModel, SWALR
from tqdm.notebook import tqdm

In [11]:
# Base model parameters
# Training batches come from the project-local Tiny ImageNet loader (train split).
loader = tiny_imagenet_loader(train=True, batch_size=128, num_workers=4)
# NOTE(review): resnet18() is built with torchvision's default classifier head;
# confirm its output size matches the number of classes yielded by `loader`.
model = models.resnet18()
optimizer = SGD(model.parameters(), lr=0.01)
loss_fn = CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}")
epochs = 300

# SWA parameters
# AveragedModel wraps `model` and holds the running average of its weights.
swa_model = AveragedModel(model)

model.to(device)
swa_model.to(device)

# TODO: Review documentation (Comes from SWA paper)
# The cosine schedule spans the whole run, so T_max is tied to `epochs` rather
# than repeating the literal; changing `epochs` keeps the schedule in sync.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
# Epoch index (0-based) after which the training loop switches to the SWA scheduler.
swa_start = 160
# NOTE(review): swa_lr (0.05) is larger than the base lr (0.01) set on the
# optimizer above — confirm this is intended.
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

Using device cuda


In [None]:
# Report how many scalar weights in `model` are trainable
# (i.e. belong to parameters with requires_grad set).
num_params = sum(
    map(
        lambda tensor: tensor.numel(),
        filter(lambda tensor: tensor.requires_grad, model.parameters()),
    )
)
print(num_params)

11689512


In [None]:
# Train the model on Tiny ImageNet while accumulating the SWA weight average.
# Explicitly select train mode so the cell behaves the same when it is re-run
# after the model has been switched to eval mode elsewhere.
model.train()
for epoch in tqdm(range(epochs), desc="Training", leave=False):
    print(f"Epoch {epoch + 1}/{epochs}")

    # One full pass over the training loader.
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        loss_fn(model(images), labels).backward()
        optimizer.step()

    # Schedule learning rate with SWA: once past swa_start, fold the current
    # weights into the running average and follow the SWA schedule; before
    # that, follow the regular scheduler.
    if epoch > swa_start:
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step()

# The averaged weights were never used in a forward pass during training, so
# any BatchNorm running statistics held by swa_model are stale. Recompute them
# with one pass over the training data before swa_model is used for inference.
torch.optim.swa_utils.update_bn(loader, swa_model, device=device)

2. On the test set of corrupted images, compute the adapted parameters
$$\hat{\theta_i} = \arg\max_{\theta} \frac{\tilde{\alpha}}{m}\sum_{j=1}^m -H(Y\mid \tilde{x}_j,\theta) + \log q_i(\theta)$$

In [None]:
# TODO: implement step 2 — compute the adapted parameters on the corrupted test set.

3. For each test input $\tilde{x}_j$, marginalize over the ensemble
$$P(y\mid\tilde{x}_j) = \frac{1}{k}\sum_{i=1}^k P(y\mid\tilde{x}_j, \hat{\theta_i})$$