Create a baseline model which will be used to hotstart training for SWA-related implementations

In [1]:
from loaders import (
    tiny_imagenet_train_loader, 
    tiny_imagenet_val_loader,
    tiny_imagenet_corrupted_loader
)
import torch
from torchvision import models
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.optim.swa_utils import AveragedModel, SWALR
from tqdm.notebook import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}")

Using device cuda


In [2]:
# Base model parameters
MODEL_NAME = "sgd_solver_new"

NUM_CLASSES = 200

train_loader = tiny_imagenet_train_loader(batch_size=128, num_workers=1)
val_loader = tiny_imagenet_val_loader(batch_size=128, num_workers=1)
model = models.resnet18()
model.fc = torch.nn.Linear(model.fc.in_features, NUM_CLASSES)
optimizer = SGD(model.parameters(), lr=0.01)
loss_fn = CrossEntropyLoss()

epochs = 80

model.to(device)

Folders:


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [3]:
# Check the model parameters
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(num_params)

11279112


In [19]:
# Validation
def validate(model, loader, device):
    model.eval()
    total_loss = 0
    total_correct = 0
    total_samples = 0
    
    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Testing", leave=False):
            images, labels = images.to(device), labels.to(device)
            
            prediction = model(images)
            total_correct += (torch.argmax(prediction, dim=1) == labels).sum()
            total_samples += labels.size(0)

    return total_correct / total_samples
            
            

In [None]:
# Train the model on ImageNet
for epoch in tqdm(range(epochs), desc="Training", leave=False):
    model.train()

    best_val_acc = 0
    
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        loss_fn(model(images), labels).backward()
        optimizer.step() 

    if epoch % 5 == 4:
        val_acc = validate(model, val_loader, device)
        print(f"Val acc: {val_acc.item()}")

        if val_acc > best_val_acc:
            torch.save(model.state_dict(), f"./models/{MODEL_NAME}.pt")
            best_val_acc = val_acc

In [26]:
def evaluate_model(model, loader):
    """ Return the accuracy of the model on the given data set (percentage correct labels assigned) """
    total_correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Testing", leave=False):
            images, labels = images.to(device), labels.to(device)
            prediction = model(images)
            correct = (torch.argmax(prediction, dim=1) == labels).sum()
            total += 128
            total_correct += correct

    return total_correct/total

In [27]:
model = models.resnet18()
model.fc = torch.nn.Linear(model.fc.in_features, NUM_CLASSES)
model.load_state_dict(torch.load("./models/sgd_solver_new.pt", weights_only=True))
model.to(device)
model.eval()

train_accuracy = eval_model(model, train_loader)
print(f"Train Accuracy: {val_accuracy}")

val_accuracy = evaluate_model(model, val_loader)
print(f"Val Accuracy: {val_accuracy}")


Testing:   0%|          | 0/79 [00:00<?, ?it/s]

Val Accuracy: 0.19907042384147644


In [28]:
print("Test Accuracy:")

corruptions = ["brightness", "contrast", "defocus_blur"]

for corruption in corruptions:
    for level in range(1, 3):
        test_loader = tiny_imagenet_corrupted_loader(
            corruption,
            severity=level,
            batch_size=128,
            root='./data/Tiny-ImageNet-C',
            num_workers=1
        )
        test_accuracy = evaluate_model(model, test_loader)
    
        print(f"{corruption} L{level}: {test_accuracy}")

Test Accuracy:
Folders:


Testing:   0%|          | 0/79 [00:00<?, ?it/s]

brightness L1: 0.10462816804647446
Folders:


Testing:   0%|          | 0/79 [00:00<?, ?it/s]

brightness L2: 0.08445411920547485
Folders:


Testing:   0%|          | 0/79 [00:00<?, ?it/s]

contrast L1: 0.034909021109342575
Folders:


Testing:   0%|          | 0/79 [00:00<?, ?it/s]

contrast L2: 0.02571202628314495
Folders:


Testing:   0%|          | 0/79 [00:00<?, ?it/s]

defocus_blur L1: 0.06981804221868515
Folders:


Testing:   0%|          | 0/79 [00:00<?, ?it/s]

defocus_blur L2: 0.06190664693713188
