Create a baseline model which will be used to hotstart training for SWA-related implementations

In [6]:
from loaders import (
    tiny_imagenet_train_loader, 
    tiny_imagenet_val_loader,
    tiny_imagenet_corrupted_loader
)
import torch
from torchvision import models
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.optim.swa_utils import AveragedModel, SWALR
from tqdm.notebook import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}")

Using device cuda


In [None]:
# Base model parameters
NUM_CLASSES = 200

train_loader = tiny_imagenet_train_loader(batch_size=128, num_workers=1)
val_loader = tiny_imagenet_val_loader(batch_size=128, num_workers=1)
model = models.resnet18()
model.fc = torch.nn.Linear(model.fc.in_features, NUM_CLASSES)
optimizer = SGD(model.parameters(), lr=0.01)
loss_fn = CrossEntropyLoss()

epochs = 100

model.to(device)

In [None]:
# Check the model parameters
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(num_params)

In [None]:
# Validation
def validate(model, loader):
    model.evaluate()
    with torch.no_grad():
        for images, labels in loader:
            prediction = model(images)
            

In [None]:
# Train the model on ImageNet
for epoch in tqdm(range(epochs), desc="Training", leave=False):
    print(f"Epoch {epoch + 1}/{epochs}")
    
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        loss_fn(model(images), labels).backward()
        optimizer.step() 

    if epoch % 5 == 0:
        torch.save(model.state_dict(), "./models/sgd_model2.pt")

In [None]:
# Save the model
torch.save(model.state_dict(), "./models/sgd_model2.pt")

In [15]:
def evaluate_model(model, loader):
    """ Return the accuracy of the model on the given data set (percentage correct labels assigned) """
    total_correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Testing", leave=False):
            images, labels = images.to(device), labels.to(device)
            prediction = model(images)
            correct = (torch.argmax(prediction, dim=1) == labels).sum()
            total += 128
            total_correct += correct
            print(torch.argmax(prediction, dim=1))
            print(labels)
            break

    return total_correct/total

In [16]:
model = models.resnet18()
# model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
model.load_state_dict(torch.load("./models/sgd_model.pt", weights_only=True))
model.to(device)

val_accuracy = evaluate_model(model, val_loader)
print(f"Val Accuracy: {val_accuracy}")


Testing:   0%|          | 0/79 [00:00<?, ?it/s]

tensor([183, 168, 125,   3, 112, 192, 194, 153,  24, 115, 128,  95, 148,   4,
        129,  68,  94, 100,   4,  56, 132,  10, 134, 169,  64, 178, 199,  44,
        106,  63, 156, 129,  18, 173, 184,   5,   7, 117, 197,  34,   5,  68,
         20,  58, 114,   1, 151,  59, 149, 105,  17,  65, 196,   9, 138,  78,
        131,  23,  20, 178, 100, 116, 170, 197, 117, 142,  61,   1,  57,  44,
         75,  19,  18, 116,  98, 146, 135,  85, 116,   6, 170, 127, 146,  87,
         99, 162,  11,   5, 141,  43, 170,  49, 113, 144, 108, 127,   0, 105,
          8,  19,  29, 169, 187,  80,  63, 113, 108,  17, 128, 102,  70, 188,
        178,  53,  99,  11,  72,  95, 159,  20,  25, 169,  23,  34, 122, 103,
        146,  70], device='cuda:0')
tensor([ 84,  36,  80, 132, 191, 192,  32, 166,  73, 135, 116, 189,  13, 114,
          2, 199,  26,  67, 181,  89, 114,  97, 175,   6,  26, 183, 192,  35,
         75, 197,  37,  22,  20,  57,  51,  29,  36,  99, 180, 100, 151, 128,
         21, 124,  92,  42, 

In [None]:
print("Test Accuracy:")

corruptions = ["brightness", "contrast", "defocus_blur"]

for corruption in corruptions:
    for level in range(1, 3):
        test_loader = tiny_imagenet_corrupted_loader(
            corruption,
            severity=level,
            batch_size=128,
            root='./data/Tiny-ImageNet-C',
            num_workers=1
        )
        test_accuracy = evaluate_model(model, test_loader)
    
        print(f"{corruption} L{level}: {test_accuracy}")