SWAG solver, as outlined in "A Simple Baseline for Bayesian Uncertainty in Deep Learning" (Maddox et al, 2019). SWAG model code modified from the associated [GitHub](https://github.com/wjmaddox/swa_gaussian/tree/master)

In [1]:
from loaders import tiny_imagenet_loader, tiny_imagenet_corrupted_loader
import torch
from torchvision import models
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.optim.swa_utils import AveragedModel, SWALR
from tqdm.notebook import tqdm

from swag.posteriors import SWAG

In [3]:
NUM_CLASSES = 200

# Load in hotstarted model
hotstart_path = "./models/sgd_model.pt"
base_model = models.resnet18()
# base_model.fc = torch.nn.Linear(base_model.fc.in_features, NUM_CLASSES)

# TODO: Check on the number of models
model = SWAG(
    base_model, 
    no_cov_mat=True,
)

model.load_state_dict(torch.load(hotstart_path, weights_only=True))

# Training details
optimizer = SGD(model.parameters(), lr=0.01)
loss_fn = CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}")
model.to(device)

# SWAG model details
swa_model = AveragedModel(model)
swa_model.to(device)
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

# Remember this is the number of epochs after pretraining
epochs = 30

# Create Loaders
train_loader = tiny_imagenet_loader(split="train", batch_size=128, num_workers=1)
val_loader = tiny_imagenet_loader(split="val", batch_size=128, num_workers=1)

RuntimeError: Error(s) in loading state_dict for SWAG:
	Missing key(s) in state_dict: "n_models", "base.conv1.weight_mean", "base.conv1.weight_sq_mean", "base.bn1.running_mean", "base.bn1.running_var", "base.bn1.weight_mean", "base.bn1.weight_sq_mean", "base.bn1.bias_mean", "base.bn1.bias_sq_mean", "base.layer1.0.conv1.weight_mean", "base.layer1.0.conv1.weight_sq_mean", "base.layer1.0.bn1.running_mean", "base.layer1.0.bn1.running_var", "base.layer1.0.bn1.weight_mean", "base.layer1.0.bn1.weight_sq_mean", "base.layer1.0.bn1.bias_mean", "base.layer1.0.bn1.bias_sq_mean", "base.layer1.0.conv2.weight_mean", "base.layer1.0.conv2.weight_sq_mean", "base.layer1.0.bn2.running_mean", "base.layer1.0.bn2.running_var", "base.layer1.0.bn2.weight_mean", "base.layer1.0.bn2.weight_sq_mean", "base.layer1.0.bn2.bias_mean", "base.layer1.0.bn2.bias_sq_mean", "base.layer1.1.conv1.weight_mean", "base.layer1.1.conv1.weight_sq_mean", "base.layer1.1.bn1.running_mean", "base.layer1.1.bn1.running_var", "base.layer1.1.bn1.weight_mean", "base.layer1.1.bn1.weight_sq_mean", "base.layer1.1.bn1.bias_mean", "base.layer1.1.bn1.bias_sq_mean", "base.layer1.1.conv2.weight_mean", "base.layer1.1.conv2.weight_sq_mean", "base.layer1.1.bn2.running_mean", "base.layer1.1.bn2.running_var", "base.layer1.1.bn2.weight_mean", "base.layer1.1.bn2.weight_sq_mean", "base.layer1.1.bn2.bias_mean", "base.layer1.1.bn2.bias_sq_mean", "base.layer2.0.conv1.weight_mean", "base.layer2.0.conv1.weight_sq_mean", "base.layer2.0.bn1.running_mean", "base.layer2.0.bn1.running_var", "base.layer2.0.bn1.weight_mean", "base.layer2.0.bn1.weight_sq_mean", "base.layer2.0.bn1.bias_mean", "base.layer2.0.bn1.bias_sq_mean", "base.layer2.0.conv2.weight_mean", "base.layer2.0.conv2.weight_sq_mean", "base.layer2.0.bn2.running_mean", "base.layer2.0.bn2.running_var", "base.layer2.0.bn2.weight_mean", "base.layer2.0.bn2.weight_sq_mean", "base.layer2.0.bn2.bias_mean", "base.layer2.0.bn2.bias_sq_mean", "base.layer2.0.downsample.0.weight_mean", "base.layer2.0.downsample.0.weight_sq_mean", "base.layer2.0.downsample.1.running_mean", "base.layer2.0.downsample.1.running_var", "base.layer2.0.downsample.1.weight_mean", "base.layer2.0.downsample.1.weight_sq_mean", "base.layer2.0.downsample.1.bias_mean", "base.layer2.0.downsample.1.bias_sq_mean", "base.layer2.1.conv1.weight_mean", "base.layer2.1.conv1.weight_sq_mean", "base.layer2.1.bn1.running_mean", "base.layer2.1.bn1.running_var", "base.layer2.1.bn1.weight_mean", "base.layer2.1.bn1.weight_sq_mean", "base.layer2.1.bn1.bias_mean", "base.layer2.1.bn1.bias_sq_mean", "base.layer2.1.conv2.weight_mean", "base.layer2.1.conv2.weight_sq_mean", "base.layer2.1.bn2.running_mean", "base.layer2.1.bn2.running_var", "base.layer2.1.bn2.weight_mean", "base.layer2.1.bn2.weight_sq_mean", "base.layer2.1.bn2.bias_mean", "base.layer2.1.bn2.bias_sq_mean", "base.layer3.0.conv1.weight_mean", "base.layer3.0.conv1.weight_sq_mean", "base.layer3.0.bn1.running_mean", "base.layer3.0.bn1.running_var", "base.layer3.0.bn1.weight_mean", "base.layer3.0.bn1.weight_sq_mean", "base.layer3.0.bn1.bias_mean", "base.layer3.0.bn1.bias_sq_mean", "base.layer3.0.conv2.weight_mean", "base.layer3.0.conv2.weight_sq_mean", "base.layer3.0.bn2.running_mean", "base.layer3.0.bn2.running_var", "base.layer3.0.bn2.weight_mean", "base.layer3.0.bn2.weight_sq_mean", "base.layer3.0.bn2.bias_mean", "base.layer3.0.bn2.bias_sq_mean", "base.layer3.0.downsample.0.weight_mean", "base.layer3.0.downsample.0.weight_sq_mean", "base.layer3.0.downsample.1.running_mean", "base.layer3.0.downsample.1.running_var", "base.layer3.0.downsample.1.weight_mean", "base.layer3.0.downsample.1.weight_sq_mean", "base.layer3.0.downsample.1.bias_mean", "base.layer3.0.downsample.1.bias_sq_mean", "base.layer3.1.conv1.weight_mean", "base.layer3.1.conv1.weight_sq_mean", "base.layer3.1.bn1.running_mean", "base.layer3.1.bn1.running_var", "base.layer3.1.bn1.weight_mean", "base.layer3.1.bn1.weight_sq_mean", "base.layer3.1.bn1.bias_mean", "base.layer3.1.bn1.bias_sq_mean", "base.layer3.1.conv2.weight_mean", "base.layer3.1.conv2.weight_sq_mean", "base.layer3.1.bn2.running_mean", "base.layer3.1.bn2.running_var", "base.layer3.1.bn2.weight_mean", "base.layer3.1.bn2.weight_sq_mean", "base.layer3.1.bn2.bias_mean", "base.layer3.1.bn2.bias_sq_mean", "base.layer4.0.conv1.weight_mean", "base.layer4.0.conv1.weight_sq_mean", "base.layer4.0.bn1.running_mean", "base.layer4.0.bn1.running_var", "base.layer4.0.bn1.weight_mean", "base.layer4.0.bn1.weight_sq_mean", "base.layer4.0.bn1.bias_mean", "base.layer4.0.bn1.bias_sq_mean", "base.layer4.0.conv2.weight_mean", "base.layer4.0.conv2.weight_sq_mean", "base.layer4.0.bn2.running_mean", "base.layer4.0.bn2.running_var", "base.layer4.0.bn2.weight_mean", "base.layer4.0.bn2.weight_sq_mean", "base.layer4.0.bn2.bias_mean", "base.layer4.0.bn2.bias_sq_mean", "base.layer4.0.downsample.0.weight_mean", "base.layer4.0.downsample.0.weight_sq_mean", "base.layer4.0.downsample.1.running_mean", "base.layer4.0.downsample.1.running_var", "base.layer4.0.downsample.1.weight_mean", "base.layer4.0.downsample.1.weight_sq_mean", "base.layer4.0.downsample.1.bias_mean", "base.layer4.0.downsample.1.bias_sq_mean", "base.layer4.1.conv1.weight_mean", "base.layer4.1.conv1.weight_sq_mean", "base.layer4.1.bn1.running_mean", "base.layer4.1.bn1.running_var", "base.layer4.1.bn1.weight_mean", "base.layer4.1.bn1.weight_sq_mean", "base.layer4.1.bn1.bias_mean", "base.layer4.1.bn1.bias_sq_mean", "base.layer4.1.conv2.weight_mean", "base.layer4.1.conv2.weight_sq_mean", "base.layer4.1.bn2.running_mean", "base.layer4.1.bn2.running_var", "base.layer4.1.bn2.weight_mean", "base.layer4.1.bn2.weight_sq_mean", "base.layer4.1.bn2.bias_mean", "base.layer4.1.bn2.bias_sq_mean", "base.fc.weight_mean", "base.fc.weight_sq_mean", "base.fc.bias_mean", "base.fc.bias_sq_mean". 
	Unexpected key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "bn1.num_batches_tracked", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.bn1.num_batches_tracked", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.bn2.num_batches_tracked", "layer1.1.conv1.weight", "layer1.1.bn1.weight", "layer1.1.bn1.bias", "layer1.1.bn1.running_mean", "layer1.1.bn1.running_var", "layer1.1.bn1.num_batches_tracked", "layer1.1.conv2.weight", "layer1.1.bn2.weight", "layer1.1.bn2.bias", "layer1.1.bn2.running_mean", "layer1.1.bn2.running_var", "layer1.1.bn2.num_batches_tracked", "layer2.0.conv1.weight", "layer2.0.bn1.weight", "layer2.0.bn1.bias", "layer2.0.bn1.running_mean", "layer2.0.bn1.running_var", "layer2.0.bn1.num_batches_tracked", "layer2.0.conv2.weight", "layer2.0.bn2.weight", "layer2.0.bn2.bias", "layer2.0.bn2.running_mean", "layer2.0.bn2.running_var", "layer2.0.bn2.num_batches_tracked", "layer2.0.downsample.0.weight", "layer2.0.downsample.1.weight", "layer2.0.downsample.1.bias", "layer2.0.downsample.1.running_mean", "layer2.0.downsample.1.running_var", "layer2.0.downsample.1.num_batches_tracked", "layer2.1.conv1.weight", "layer2.1.bn1.weight", "layer2.1.bn1.bias", "layer2.1.bn1.running_mean", "layer2.1.bn1.running_var", "layer2.1.bn1.num_batches_tracked", "layer2.1.conv2.weight", "layer2.1.bn2.weight", "layer2.1.bn2.bias", "layer2.1.bn2.running_mean", "layer2.1.bn2.running_var", "layer2.1.bn2.num_batches_tracked", "layer3.0.conv1.weight", "layer3.0.bn1.weight", "layer3.0.bn1.bias", "layer3.0.bn1.running_mean", "layer3.0.bn1.running_var", "layer3.0.bn1.num_batches_tracked", "layer3.0.conv2.weight", "layer3.0.bn2.weight", "layer3.0.bn2.bias", "layer3.0.bn2.running_mean", "layer3.0.bn2.running_var", "layer3.0.bn2.num_batches_tracked", "layer3.0.downsample.0.weight", "layer3.0.downsample.1.weight", "layer3.0.downsample.1.bias", "layer3.0.downsample.1.running_mean", "layer3.0.downsample.1.running_var", "layer3.0.downsample.1.num_batches_tracked", "layer3.1.conv1.weight", "layer3.1.bn1.weight", "layer3.1.bn1.bias", "layer3.1.bn1.running_mean", "layer3.1.bn1.running_var", "layer3.1.bn1.num_batches_tracked", "layer3.1.conv2.weight", "layer3.1.bn2.weight", "layer3.1.bn2.bias", "layer3.1.bn2.running_mean", "layer3.1.bn2.running_var", "layer3.1.bn2.num_batches_tracked", "layer4.0.conv1.weight", "layer4.0.bn1.weight", "layer4.0.bn1.bias", "layer4.0.bn1.running_mean", "layer4.0.bn1.running_var", "layer4.0.bn1.num_batches_tracked", "layer4.0.conv2.weight", "layer4.0.bn2.weight", "layer4.0.bn2.bias", "layer4.0.bn2.running_mean", "layer4.0.bn2.running_var", "layer4.0.bn2.num_batches_tracked", "layer4.0.downsample.0.weight", "layer4.0.downsample.1.weight", "layer4.0.downsample.1.bias", "layer4.0.downsample.1.running_mean", "layer4.0.downsample.1.running_var", "layer4.0.downsample.1.num_batches_tracked", "layer4.1.conv1.weight", "layer4.1.bn1.weight", "layer4.1.bn1.bias", "layer4.1.bn1.running_mean", "layer4.1.bn1.running_var", "layer4.1.bn1.num_batches_tracked", "layer4.1.conv2.weight", "layer4.1.bn2.weight", "layer4.1.bn2.bias", "layer4.1.bn2.running_mean", "layer4.1.bn2.running_var", "layer4.1.bn2.num_batches_tracked", "fc.weight", "fc.bias". 

In [None]:
# Implement SAW portion of training for the model on ImageNet
for epoch in tqdm(range(epochs), desc="Training", leave=False):
    print(f"Epoch {epoch + 1}/{epochs}")
    
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        loss_fn(model(images), labels).backward()
        optimizer.step()
    
    # Schedule learning rate with SWA
    swa_model.update_parameters(model)
    swa_scheduler.step()

In [None]:
# Save the model
torch.save(model.state_dict(), "./models/swa_model.pt")

In [None]:
def evaluate_model(model, loader):
    """ Return the accuracy of the model on the given data set (percentage correct labels assigned) """
    total_correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Testing", leave=False):
            images, labels = images.to(device), labels.to(device)
            prediction = model(images)
            correct = (torch.argmax(prediction, dim=1) == labels).sum()
            total += 128
            total_correct += correct

    return total_correct/total

In [None]:
model = models.resnet18()
model.load_state_dict(torch.load("./models/swa_model.pt", weights_only=True))
model.to(device)

val_accuracy = evaluate_model(model, val_loader)
print(f"Val Accuracy: {val_accuracy}")


In [None]:
print("Test Accuracy:")

corruptions = ["brightness", "contrast", "defocus_blur"]

for corruption in corruptions:
    for level in range(1, 3):
        test_loader = tiny_imagenet_corrupted_loader(
            corruption,
            severity=level,
            batch_size=128,
            root='./data/Tiny-ImageNet-C',
            num_workers=1
        )
        test_accuracy = evaluate_model(model, test_loader)
    
        print(f"{corruption} L{level}: {test_accuracy}")