# AlexNet on CIFAR10

As in the `lenet5.ipynb` notebook, we import the AlexNet architecture from `torchvision` (randomly initialized, not pretrained) and train it to classify CIFAR10 images.

In [1]:
import time
import torch
import torch.nn as nn
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from tqdm.notebook import trange, tqdm
from torch.utils.tensorboard import SummaryWriter

## Helper Functions

In [2]:
def show(imgs):
    """Display one or more image tensors side by side in a single row.

    Args:
        imgs: an image tensor, or a list/tuple of image tensors, in a format
            accepted by ``transforms.functional.to_pil_image``.
    """
    if not isinstance(imgs, (list, tuple)):
        imgs = [imgs]
    # Only the axes grid is used; squeeze=False keeps it 2-D even for one image.
    _, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = img.detach()
        img = transforms.functional.to_pil_image(img)
        axs[0, i].imshow(np.asarray(img))
        # Hide ticks and tick labels so only the image is shown.
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
        
        
class AverageMeter(object):
    """Keep the latest value and a count-weighted running mean.

    Attributes (all start at 0):
        val: value passed to the most recent ``update`` call.
        sum: running total of ``val * n`` across updates.
        count: running total of ``n`` across updates.
        avg: ``sum / count``, refreshed on every update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        # Zero every statistic so the meter starts fresh.
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new observations."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each requested k.

    Args:
        output: score tensor; the top-k is taken along dim 1, i.e. it is
            indexed as ``(batch, classes)``.
        target: tensor of ground-truth class indices, one per row of
            ``output`` (flattened with ``view(1, -1)``).
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 0-dim float tensor per k: the percentage of samples
        whose target is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Pure metric computation: no autograd graph is needed.
    with torch.no_grad():
        # Indices of the maxk largest scores (sorted), transposed so that
        # row j holds every sample's j-th best guess.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # ``correct`` follows the transposed, non-contiguous layout of
            # ``pred``; ``view(-1)`` fails there for k > 1, ``reshape`` does not.
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
    return res

## Training and Evaluation Functions

In [3]:
def train_loop(dataloader, model, loss_fn, optimizer, logger=None, *,
               epoch=None, print_frequency=None):
    """Train ``model`` for one pass over ``dataloader``.

    Every ``print_frequency`` batches the tqdm bar shows the running loss and
    the current/running top-1 precision. At the same points the current batch
    loss is written to ``logger`` under "training loss", if a logger is given.

    Args:
        dataloader: yields ``(input, target)`` batches.
        model: network to train. Batches are moved to the device of its first
            parameter (assumes the model has at least one parameter).
        loss_fn: callable ``loss_fn(output, target)`` returning a scalar loss.
        optimizer: optimizer stepping ``model``'s parameters.
        logger: optional object with ``add_scalar(tag, value, step)``
            (e.g. a TensorBoard ``SummaryWriter``).
        epoch: epoch index used in the progress text and the logging step.
            Defaults to the module-level ``epoch`` variable.
        print_frequency: report every this many batches. Defaults to the
            module-level ``print_frequency`` variable.
    """
    # Fall back to the module-level variables so existing calls that omit
    # these arguments keep working.
    if epoch is None:
        epoch = globals()["epoch"]
    if print_frequency is None:
        print_frequency = globals()["print_frequency"]

    device = next(model.parameters()).device
    losses = AverageMeter()
    top1 = AverageMeter()

    model.train()

    pbar = tqdm(enumerate(dataloader), total=len(dataloader))
    for i, (inputs, target) in pbar:
        inputs = inputs.to(device)
        target = target.to(device)

        output = model(inputs)
        loss = loss_fn(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metrics are computed on a detached float32 copy of the output.
        prec = accuracy(output.detach().float(), target)[0]
        batch_loss = loss.item()
        losses.update(batch_loss, inputs.shape[0])
        top1.update(prec.item(), inputs.shape[0])

        if i % print_frequency == 0:
            pbar.set_description("Epoch [%d]\t Loss %.2f\t Prec@1 %.3f (%.3f)" % (epoch, losses.avg, top1.val, top1.avg))
            if logger is not None:
                # Global step = batches seen across all epochs so far.
                logger.add_scalar("training loss",
                                  batch_loss,
                                  epoch * len(dataloader) + i)
           
        
def val_loop(dataloader, model, loss_fn, logger=None, *,
             epoch=None, print_frequency=None):
    """Evaluate ``model`` on ``dataloader`` without updating its parameters.

    Every ``print_frequency`` batches the tqdm bar shows the running loss and
    the current/running top-1 precision. At the same points the batch loss is
    logged as "validation loss". At the end, the epoch's mean top-1 precision
    is logged as "validation accuracy" (only when a logger is given).

    Args:
        dataloader: yields ``(input, target)`` batches.
        model: network to evaluate. Batches are moved to the device of its
            first parameter (assumes the model has at least one parameter).
        loss_fn: callable ``loss_fn(output, target)`` returning a scalar loss.
        logger: optional object with ``add_scalar(tag, value, step)``
            (e.g. a TensorBoard ``SummaryWriter``).
        epoch: epoch index used in the progress text and logging steps.
            Defaults to the module-level ``epoch`` variable.
        print_frequency: report every this many batches. Defaults to the
            module-level ``print_frequency`` variable.
    """
    # Fall back to the module-level variables so existing calls that omit
    # these arguments keep working.
    if epoch is None:
        epoch = globals()["epoch"]
    if print_frequency is None:
        print_frequency = globals()["print_frequency"]

    device = next(model.parameters()).device
    losses = AverageMeter()
    top1 = AverageMeter()

    model.eval()

    pbar = tqdm(enumerate(dataloader), total=len(dataloader))
    # Evaluation only: skip recording the autograd graph for each forward
    # pass, which saves memory and time.
    with torch.no_grad():
        for i, (inputs, target) in pbar:
            inputs = inputs.to(device)
            target = target.to(device)

            output = model(inputs)
            loss = loss_fn(output, target)

            prec = accuracy(output.float(), target)[0]
            batch_loss = loss.item()
            losses.update(batch_loss, inputs.shape[0])
            top1.update(prec.item(), inputs.shape[0])

            if i % print_frequency == 0:
                pbar.set_description("Epoch [%d]\t Loss %.2f\t Prec@1 %.3f (%.3f)" % (epoch, losses.avg, top1.val, top1.avg))
                if logger is not None:
                    logger.add_scalar("validation loss",
                                      batch_loss,
                                      epoch * len(dataloader) + i)

    if logger is not None:
        logger.add_scalar("validation accuracy",
                          top1.avg,
                          epoch)
            

def test_loop(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and print its average loss and top-1 accuracy.

    Args:
        dataloader: yields ``(input, target)`` batches.
        model: network to evaluate. Batches are moved to the device of its
            first parameter (assumes the model has at least one parameter).
        loss_fn: callable ``loss_fn(output, target)`` returning a scalar loss.
    """
    device = next(model.parameters()).device
    losses = AverageMeter()
    top1 = AverageMeter()

    model.eval()

    pbar = tqdm(dataloader, total=len(dataloader))
    # Inference only: skip recording the autograd graph to save memory/time.
    with torch.no_grad():
        for inputs, target in pbar:
            inputs = inputs.to(device)
            target = target.to(device)

            output = model(inputs)
            loss = loss_fn(output, target)

            prec = accuracy(output.float(), target)[0]
            losses.update(loss.item(), inputs.shape[0])
            top1.update(prec.item(), inputs.shape[0])

    # Print result
    print(f"Average Loss: {losses.avg:>8f}\nAccuracy: {top1.avg}\n")

## Model Hyperparameters and Initialization

In [None]:
# Model Parameters
batch_size = 256
learning_rate = 1e-3
epochs = 20
print_frequency = 100  # report/log progress every this many batches
num_classes = 10  # CIFAR10 classes

# Build AlexNet with random weights. Calling the builder without
# `pretrained=`/`weights=` gives random init on both the old and new
# torchvision APIs; `pretrained` is deprecated since torchvision 0.13.
model = torchvision.models.alexnet()

# Replace the final classifier layer with a `num_classes`-way output, reusing
# the existing layer's input width instead of hard-coding it.
model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)
model.cuda()

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), learning_rate)

## Dataset Preparation

In [None]:
root_path = "./Data/CIFAR10"

# NOTE(review): these look like the standard ImageNet channel statistics, not
# CIFAR10's own; acceptable when training from scratch, but worth revisiting.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Training-time preprocessing: upscale, then take a random 224x224 crop.
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.ToTensor(),
    normalize
])

# Deterministic preprocessing for validation/test at the same scale the model
# is trained on. Without the resize, the small CIFAR10 images shrink to zero
# spatial size inside AlexNet's pooling layers and the forward pass fails.
eval_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize
])

dataset = torchvision.datasets.CIFAR10(root_path, train=True, transform=train_transform, download=True)
# Same training images with the evaluation transform, so the held-out
# validation split is not randomly augmented.
eval_dataset = torchvision.datasets.CIFAR10(root_path, train=True, transform=eval_transform)

dataset_size = len(dataset)
train_size = int(dataset_size * .95)
val_size = dataset_size - train_size

# Seeded split so the train/validation partition is reproducible across runs.
split_generator = torch.Generator().manual_seed(0)
train_split, val_split = torch.utils.data.random_split(
    range(dataset_size), [train_size, val_size], generator=split_generator)
train_dataset = torch.utils.data.Subset(dataset, train_split.indices)
val_dataset = torch.utils.data.Subset(eval_dataset, val_split.indices)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size, shuffle=True,
    num_workers=8, pin_memory=False)


val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size, shuffle=False,
    num_workers=8, pin_memory=False)


test_dataloader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR10(root_path, train=False, transform=eval_transform),
    batch_size=batch_size, shuffle=False,
    num_workers=8, pin_memory=False)

## Train model

In [None]:
# Log to TensorBoard under runs/alexnet. The `with` block closes the writer,
# flushing pending events, when training finishes or is interrupted.
with SummaryWriter("runs/alexnet") as logger:
    # `epoch` is a notebook-level variable here; the loop functions use it
    # for their progress text and logging steps.
    for epoch in range(epochs):
        train_loop(train_dataloader, model, criterion, optimizer, logger)
        val_loop(val_dataloader, model, criterion, logger)

In [None]:
# Final evaluation on the CIFAR10 test split (train=False); prints the
# average loss and top-1 accuracy.
test_loop(test_dataloader, model, criterion)