In [1]:
from loaders import tiny_imagenet_loader, tiny_imagenet_corrupted_loader
import torch
from torchvision import models
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.optim.swa_utils import AveragedModel, SWALR
from tqdm.notebook import tqdm

In [2]:
# Base model parameters
train_loader = tiny_imagenet_loader(train=True, batch_size=128, num_workers=1)
val_loader = tiny_imagenet_loader(train=True, batch_size=128, num_workers=1)
model = models.resnet18()
optimizer = SGD(model.parameters(), lr=0.01)
loss_fn = CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}")
epochs = 150

# SWA parameters
swa_model = AveragedModel(model)

model.to(device)
swa_model.to(device)

# TODO: Review documentation (Comes from SWA paper)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
swa_start = 160
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

Using device cuda


In [6]:
# Check the model parameters
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(num_params)

11689512


In [None]:
# Validation
def validate(model, loader):
    model.evaluate()
    with torch.no_grad():
        for images, labels in loader:
            prediction = model(images)
            

In [None]:
# Train the model on ImageNet
for epoch in tqdm(range(epochs), desc="Training", leave=False):
    print(f"Epoch {epoch + 1}/{epochs}")
    
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        loss_fn(model(images), labels).backward()
        optimizer.step()
    
    # Schedule learning rate with SWA
    if epoch > swa_start:
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step()

In [9]:
# Save the model
torch.save(model.state_dict(), "./models/swa_implementation")

In [8]:
def evaluate_model(model, loader):
    """ Return the accuracy of the model on the given data set (percentage correct labels assigned) """
    total_correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Testing", leave=False):
            images, labels = images.to(device), labels.to(device)
            prediction = model(images)
            correct = (torch.argmax(prediction, dim=1) == labels).sum()
            total += 128
            total_correct += correct

    return total_correct/total

In [10]:
model = models.resnet18()
model.load_state_dict(torch.load("./models/swa_implementation", weights_only=True))
model.to(device)

val_accuracy = evaluate_model(model, val_loader)
print(f"Val Accuracy: {val_accuracy}")

print("Test Accuracy:")

corruptions = ["brightness", "contrast", "defocus_blur", "elastic_trasform"

for corruption in corruptions:
    test_loader = tiny_imagenet_corrupted_loader(
        corruption='brightness',
        severity=1,
        batch_size=128,
        root='./data/Tiny-ImageNet-C',
        num_workers=1
    )
    test_accuracy = evaluate_model(model, test_loader)

    print(f"{corruption} L1: {test_accuracy}")


SyntaxError: unterminated string literal (detected at line 24) (2561125294.py, line 24)

In [4]:
print(total_correct/total)

tensor(0.9988, device='cuda:0')


In [None]:
print("test")