Create a baseline model which will be used to hotstart training for SWA-related implementations

In [1]:
from loaders import tiny_imagenet_loader, tiny_imagenet_corrupted_loader
import torch
from torchvision import models
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.optim.swa_utils import AveragedModel, SWALR
from tqdm.notebook import tqdm

In [2]:
# Base model parameters
train_loader = tiny_imagenet_loader(train=True, batch_size=128, num_workers=1)
val_loader = tiny_imagenet_loader(train=True, batch_size=128, num_workers=1)
model = models.resnet18()
optimizer = SGD(model.parameters(), lr=0.01)
loss_fn = CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}")
epochs = 100

model.to(device)

Using device cuda


In [6]:
# Check the model parameters
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(num_params)

11689512


In [None]:
# Validation
def validate(model, loader):
    model.evaluate()
    with torch.no_grad():
        for images, labels in loader:
            prediction = model(images)
            

In [None]:
# Train the model on ImageNet
for epoch in tqdm(range(epochs), desc="Training", leave=False):
    print(f"Epoch {epoch + 1}/{epochs}")
    
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        loss_fn(model(images), labels).backward()
        optimizer.step() 
    scheduler.step()

In [9]:
# Save the model
torch.save(model.state_dict(), "./models/sgd_model.pt")

In [8]:
def evaluate_model(model, loader):
    """ Return the accuracy of the model on the given data set (percentage correct labels assigned) """
    total_correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Testing", leave=False):
            images, labels = images.to(device), labels.to(device)
            prediction = model(images)
            correct = (torch.argmax(prediction, dim=1) == labels).sum()
            total += 128
            total_correct += correct

    return total_correct/total

In [None]:
model = models.resnet18()
model.load_state_dict(torch.load("./models/sgd_model.pt", weights_only=True))
model.to(device)

val_accuracy = evaluate_model(model, val_loader)
print(f"Val Accuracy: {val_accuracy}")


In [17]:
print("Test Accuracy:")

corruptions = ["brightness", "contrast", "defocus_blur"]

for corruption in corruptions:
    for level in range(1, 3):
        test_loader = tiny_imagenet_corrupted_loader(
            corruption,
            severity=level,
            batch_size=128,
            root='./data/Tiny-ImageNet-C',
            num_workers=1
        )
        test_accuracy = evaluate_model(model, test_loader)
    
        print(f"{corruption} L{level}: {test_accuracy}")

Test Accuracy:


Testing:   0%|          | 0/79 [00:00<?, ?it/s]

brightness L1: 0.10957279056310654


Testing:   0%|          | 0/79 [00:00<?, ?it/s]

brightness L2: 0.09285996854305267


Testing:   0%|          | 0/79 [00:00<?, ?it/s]

contrast L1: 0.037480223923921585


Testing:   0%|          | 0/79 [00:00<?, ?it/s]

contrast L2: 0.026602057740092278


Testing:   0%|          | 0/79 [00:00<?, ?it/s]

defocus_blur L1: 0.0738726258277893


Testing:   0%|          | 0/79 [00:00<?, ?it/s]

defocus_blur L2: 0.06457674503326416
