SWA solver as outlined in "Averaging Weights Leads to Wider Optima and Better Generalization" (Izmailov et al, 2019). Some code taken from [Pytorch SWA Blog](https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/)

In [1]:
from loaders import (
    tiny_imagenet_train_loader, 
    tiny_imagenet_val_loader,
    tiny_imagenet_corrupted_loader
)
import torch
from torchvision import models
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.optim.swa_utils import AveragedModel, SWALR
from tqdm.notebook import tqdm
import json

In [2]:
NUM_CLASSES = 200

# Load in hotstarted model
hotstart_path = "./models/sgd_model.pt"
model = models.resnet18()
model.fc = torch.nn.Linear(model.fc.in_features, NUM_CLASSES)
model.load_state_dict(torch.load(hotstart_path, weights_only=True))

# Training details
optimizer = SGD(model.parameters(), lr=0.01)
loss_fn = CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}")
model.to(device)

# SWA model details
swa_model = AveragedModel(model)
swa_model.to(device)
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

# Remember this is the number of epochs after pretraining
epochs = 30

# Create Loaders
train_loader = tiny_imagenet_train_loader(batch_size=128, num_workers=1)
val_loader = tiny_imagenet_val_loader(batch_size=128, num_workers=1)

Using device cuda


In [3]:
# Implement SAW portion of training for the model on ImageNet
for epoch in tqdm(range(epochs), desc="Training", leave=False):
    
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        loss_fn(model(images), labels).backward()
        optimizer.step()
    
    # Schedule learning rate with SWA
    swa_model.update_parameters(model)
    swa_scheduler.step()

    # Periodically save the model and scheduler in case the cluster crashes
    torch.save(model.state_dict(), "./models/swa_model_temp.pt")
    torch.save(swa_scheduler.state_dict(), "./models/swa_scheduler_temp.pt")
    

Training:   0%|          | 0/30 [00:00<?, ?it/s]

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30


In [4]:
# Save the model
torch.save(model.state_dict(), "./models/swa_model3.pt")
torch.save(swa_scheduler.state_dict(), "./models/swa_scheduler3.pt")

In [5]:
def evaluate_model(model, loader):
    """ Return the accuracy of the model on the given data set (percentage correct labels assigned) """
    total_correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Testing", leave=False):
            images, labels = images.to(device), labels.to(device)
            prediction = model(images)
            correct = (torch.argmax(prediction, dim=1) == labels).sum()
            total += 128
            total_correct += correct

    return total_correct/total

In [9]:
model = models.resnet18()
model.fc = torch.nn.Linear(model.fc.in_features, NUM_CLASSES)
model.load_state_dict(torch.load("./models/swa_model.pt", weights_only=True))
model.to(device)

val_accuracy = evaluate_model(model, val_loader)
print(f"Val Accuracy: {val_accuracy}")


Testing:   0%|          | 0/79 [00:00<?, ?it/s]

Val Accuracy: 0.19907042384147644


In [10]:
print("Test Accuracy:")

corruptions = ["brightness", "contrast", "defocus_blur"]

keys = []
values = []

for corruption in corruptions:
    for level in range(1, 3):
        test_loader = tiny_imagenet_corrupted_loader(
            corruption,
            severity=level,
            batch_size=128,
            root='./data/Tiny-ImageNet-C',
            num_workers=1
        )
        test_accuracy = evaluate_model(model, test_loader)

        keys.append(corruption + "-" + str(level))
        values.append(test_accuracy.item())
    
        print(f"{corruption} L{level}: {test_accuracy}")

# Save all values
with open("results/swa_results.json", "w") as file:
    json.dump(dict(zip(keys, values)), file, indent=4)

Test Accuracy:


Testing:   0%|          | 0/79 [00:00<?, ?it/s]

brightness L1: 0.10462816804647446


Testing:   0%|          | 0/79 [00:00<?, ?it/s]

brightness L2: 0.08445411920547485


Testing:   0%|          | 0/79 [00:00<?, ?it/s]

contrast L1: 0.034909021109342575


Testing:   0%|          | 0/79 [00:00<?, ?it/s]

contrast L2: 0.02571202628314495


Testing:   0%|          | 0/79 [00:00<?, ?it/s]

defocus_blur L1: 0.06981804221868515


Testing:   0%|          | 0/79 [00:00<?, ?it/s]

defocus_blur L2: 0.06190664693713188
