# ResNet20 Trained On CIFAR100

In [1]:
!pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Downloading torchmetrics-1.8.2-py3-none-any.whl (983 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m21.6 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hDownloading lightning_utilities-0.15.2-py3-none-any.whl (29 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.15.2 torchmetrics-1.8.2


In [2]:
# Import Libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms
import torchmetrics
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Default to the CPU and switch to the GPU only when CUDA is usable.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [None]:
def get_transforms():
    """Build the image preprocessing pipelines for training and evaluation.

    Returns:
        tuple: ``(train_transform, test_val_transform)``.

        * ``train_transform`` applies a padded random crop, a random horizontal
          flip and a random rotation, converts the image to a tensor,
          normalizes it and finally applies random erasing.
        * ``test_val_transform`` only converts the image to a tensor and
          normalizes it with the same constants.
    """

    def make_normalize():
        # Both pipelines use identical normalization constants; each call
        # builds a fresh transform so the two pipelines share no objects.
        return transforms.Normalize(
            mean=[0.5071, 0.4867, 0.4408],
            std=[0.2675, 0.2565, 0.2761],
        )

    # Augmentations that operate on the image before it becomes a tensor.
    augmentation_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
    ]

    # RandomErasing comes last because it runs on the normalized tensor.
    train_steps = augmentation_steps + [
        transforms.ToTensor(),
        make_normalize(),
        transforms.RandomErasing(),
    ]
    eval_steps = [transforms.ToTensor(), make_normalize()]

    train_transform = transforms.Compose(train_steps)
    test_val_transform = transforms.Compose(eval_steps)
    return train_transform, test_val_transform

In [None]:
def get_loaders(batch_size=128, num_workers=2, root="./data"):
    """Create the CIFAR-100 training and validation data loaders.

    The datasets are downloaded into ``root`` if they are not already there.
    The validation loader is built from the ``train=False`` split.

    Args:
        batch_size: Number of samples per batch for both loaders.
        num_workers: Number of worker processes per loader. ``0`` loads data
            in the main process.
        root: Directory where the dataset is stored / downloaded.

    Returns:
        tuple: ``(train_loader, val_loader)``. Only the training loader
        shuffles its samples.
    """
    train_transform, val_transform = get_transforms()

    train_dataset = CIFAR100(root=root, train=True, transform=train_transform, download=True)
    val_dataset = CIFAR100(root=root, train=False, transform=val_transform, download=True)

    # Settings shared by both loaders, kept in one place so they cannot drift.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        # DataLoader raises if persistent_workers=True is combined with
        # num_workers=0, so only ask for it when workers actually exist.
        persistent_workers=num_workers > 0,
    )

    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader


In [None]:
def train_setup(
        model, lr=0.1,
        momentum=0.9,
        weight_decay=1e-4,
        milestones=(100, 150), gamma=0.1
    ):
    """Create the loss function, optimizer and LR scheduler for training.

    Args:
        model: Module whose ``parameters()`` are handed to the optimizer.
        lr: Initial learning rate for SGD.
        momentum: SGD momentum factor.
        weight_decay: L2 penalty passed to SGD.
        milestones: Iterable of scheduler-step indices at which the learning
            rate is multiplied by ``gamma``. The default is an immutable tuple
            so the default value cannot be mutated between calls.
        gamma: Multiplicative learning-rate decay factor.

    Returns:
        tuple: ``(criterion, optimizer, scheduler)`` where ``criterion`` is
        ``nn.CrossEntropyLoss``, ``optimizer`` is ``optim.SGD`` and
        ``scheduler`` is ``optim.lr_scheduler.MultiStepLR``.
    """
    criterion = nn.CrossEntropyLoss()

    optimizer = optim.SGD(
        model.parameters(),
        lr=lr,
        momentum=momentum,
        weight_decay=weight_decay
    )

    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer=optimizer,
        # Hand the scheduler its own list so it never aliases the caller's
        # object (or the function default).
        milestones=list(milestones),
        gamma=gamma
    )

    return criterion, optimizer, scheduler

In [None]:
def get_metrics(device, num_classes=100):
    """Build the top-1 / top-5 accuracy metric collection.

    Args:
        device: Device the metric states are moved to.
        num_classes: Number of target classes.

    Returns:
        torchmetrics.MetricCollection: Collection whose results are keyed by
        ``"top1_accuracy"`` and ``"top5_accuracy"``.
    """
    # A list-form MetricCollection keys each metric by its class name, and both
    # Accuracy objects here are the same class, so the list form raises
    # "Encountered two metrics both named ...". A dict gives each metric its
    # own explicit key.
    metrics = torchmetrics.MetricCollection(
        {
            "top1_accuracy": torchmetrics.Accuracy(
                task="multiclass", num_classes=num_classes, top_k=1
            ),
            "top5_accuracy": torchmetrics.Accuracy(
                task="multiclass", num_classes=num_classes, top_k=5
            ),
        },
        # Keep the two metrics' states independent: automatic compute groups
        # merge metrics whose states look equal after the first update, which
        # must never happen for top-1 vs. top-5.
        compute_groups=False,
    ).to(device)

    return metrics

In [None]:
def train_one_epoch(model, train_loader, optimizer, criterion, scheduler, metrics, device):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: Module to optimize; it is switched to training mode.
        train_loader: Sized iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Callable ``criterion(output, labels)`` returning the loss.
        scheduler: Optional LR scheduler, stepped once after the last batch.
            Pass ``None`` to skip scheduling.
        metrics: Object exposing ``reset()`` and ``update(output, labels)``;
            it is reset at the start and updated after every batch.
        device: Device the batches are moved to.

    Returns:
        tuple: ``(metrics, loss_per_epoch)`` where ``loss_per_epoch`` is the
        sum of the per-batch losses divided by the number of batches.
    """
    model.train()
    loss_per_epoch = 0.0

    metrics.reset()

    for images, labels in train_loader:
        # non_blocking lets copies from pinned host memory overlap with
        # compute; it has no effect when the data is already on ``device``.
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        loss_per_epoch += loss.item()
        metrics.update(output, labels)

    # Compare against None explicitly so a scheduler object is never skipped
    # just because it happens to evaluate as falsy.
    if scheduler is not None:
        scheduler.step()

    loss_per_epoch /= len(train_loader)
    return metrics, loss_per_epoch

In [None]:
def validate(model, val_loader, criterion, metrics, device):
    """Evaluate ``model`` on ``val_loader`` without computing gradients.

    Args:
        model: Module to evaluate; it is switched to evaluation mode.
        val_loader: Sized iterable yielding ``(images, labels)`` batches.
        criterion: Callable ``criterion(output, labels)`` returning the loss.
        metrics: Object exposing ``reset()`` and ``update(output, labels)``;
            it is reset at the start and updated after every batch.
        device: Device the batches are moved to.

    Returns:
        tuple: ``(metrics, loss_per_epoch)`` where ``loss_per_epoch`` is the
        sum of the per-batch losses divided by the number of batches.
    """
    model.eval()
    loss_per_epoch = 0.0

    metrics.reset()

    with torch.no_grad():
        for images, labels in val_loader:
            # non_blocking lets copies from pinned host memory overlap with
            # compute; it has no effect when the data is already on ``device``.
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            output = model(images)
            loss = criterion(output, labels)

            loss_per_epoch += loss.item()

            metrics.update(output, labels)

    loss_per_epoch /= len(val_loader)

    return metrics, loss_per_epoch