# ResNet20 Trained On CIFAR100

In [None]:
!pip install torchmetrics

In [None]:
# Import Libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms
import torchmetrics
import numpy as np
import matplotlib.pyplot as plt

In [None]:
def get_transforms():
    """Build the torchvision preprocessing pipelines.

    Returns:
        A ``(train_transform, test_val_transform)`` tuple. The training
        pipeline applies random crop, horizontal flip and rotation on the
        input image, converts it to a tensor, normalises it, and finally
        applies random erasing. The evaluation pipeline only converts to a
        tensor and normalises with the same statistics.
    """
    # Per-channel statistics shared by both pipelines.
    channel_mean = [0.5071, 0.4867, 0.4408]
    channel_std = [0.2675, 0.2565, 0.2761]

    def make_normalize():
        # A fresh Normalize per pipeline, as in a hand-written Compose list.
        return transforms.Normalize(mean=channel_mean, std=channel_std)

    augmentation_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
    ]
    # RandomErasing is placed after ToTensor/Normalize, so it acts on the
    # normalised tensor rather than on the raw image.
    tensor_steps = [
        transforms.ToTensor(),
        make_normalize(),
        transforms.RandomErasing(),
    ]
    train_transform = transforms.Compose(augmentation_steps + tensor_steps)

    test_val_transform = transforms.Compose(
        [transforms.ToTensor(), make_normalize()]
    )

    return train_transform, test_val_transform

In [None]:
def get_loaders(batch_size=128, root="./data", num_workers=4, pin_memory=True):
    """Create the CIFAR-100 training and validation data loaders.

    The datasets are downloaded into ``root`` if they are not already there.
    The "validation" loader is built from the CIFAR-100 ``train=False`` split.

    Args:
        batch_size: Number of samples per batch for both loaders.
        root: Directory where the dataset is stored / downloaded.
        num_workers: Worker processes used by each ``DataLoader``. Use ``0``
            to load data in the main process.
        pin_memory: Forwarded to ``DataLoader``; set to ``False`` when no
            accelerator is available.

    Returns:
        A ``(train_loader, val_loader)`` tuple. Only the training loader
        shuffles its samples.
    """
    train_transform, val_transform = get_transforms()

    train_dataset = CIFAR100(root=root, train=True, transform=train_transform, download=True)
    val_dataset = CIFAR100(root=root, train=False, transform=val_transform, download=True)

    # Settings common to both loaders; only `shuffle` differs between them.
    loader_kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
    }
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader


In [None]:
def train_setup(
        model, lr=0.1,
        momentum=0.9,
        weight_decay=1e-4,
        milestones=(100, 150), gamma=0.1
    ):
    """Create the loss, optimizer and LR scheduler used for training.

    Args:
        model: Module whose ``parameters()`` are handed to the optimizer.
        lr: Initial learning rate for SGD.
        momentum: SGD momentum factor.
        weight_decay: L2 penalty coefficient passed to SGD.
        milestones: Scheduler step indices at which the learning rate is
            multiplied by ``gamma``. Any iterable of ints is accepted; it is
            copied, so the caller's object is never shared with the scheduler.
            (Presumably these are epoch indices -- this depends on how often
            the caller invokes ``scheduler.step()``.)
        gamma: Multiplicative learning-rate decay factor.

    Returns:
        A ``(criterion, optimizer, scheduler)`` tuple: an
        ``nn.CrossEntropyLoss``, an ``optim.SGD`` over the model parameters,
        and an ``optim.lr_scheduler.MultiStepLR`` wrapping that optimizer.
    """
    criterion = nn.CrossEntropyLoss()

    optimizer = optim.SGD(
        model.parameters(),
        lr=lr,
        momentum=momentum,
        weight_decay=weight_decay
    )

    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer=optimizer,
        # The default is an immutable tuple (no shared mutable default);
        # MultiStepLR is documented to take a list, so convert here.
        milestones=list(milestones),
        gamma=gamma
    )

    return criterion, optimizer, scheduler