In [None]:
import time
import torch
import torch.nn as nn
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from tqdm.notebook import trange, tqdm

# The CIFAR10 Dataset
In this example, we classify images into one of the 10 classes of the `CIFAR10` dataset. First, let's download the dataset and visualize a few samples.

In [None]:
root_path = "~/Data/CIFAR10"

def show(imgs):
    """Draw one image tensor, or a list of image tensors, side by side in a single row.

    Every tensor is detached from autograd, converted with
    ``transforms.functional.to_pil_image`` and plotted with all axis ticks hidden.
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    _, axes = plt.subplots(ncols=len(images), squeeze=False)
    for col, tensor in enumerate(images):
        pil_image = transforms.functional.to_pil_image(tensor.detach())
        ax = axes[0, col]
        ax.imshow(np.asarray(pil_image))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

In [None]:
# Both splits share the same transform, so samples from either one are tensors that
# `show` (which calls `.detach()` and `to_pil_image`) can display.
preview_transform = transforms.Compose([
    transforms.ToTensor(),
])
train_dataset = torchvision.datasets.CIFAR10(root_path, download=True, transform=preview_transform)
test_dataset = torchvision.datasets.CIFAR10(root_path, train=False, download=True, transform=preview_transform)

In [None]:
# Take the image tensors of the first 64 training samples and tile them into one grid image.
samples = [train_dataset[idx][0] for idx in range(64)]
grid = torchvision.utils.make_grid(samples)
show(grid)

# Helper Functions

We define a few helper functions for tracking running statistics and measuring model accuracy.

In [None]:
class AverageMeter:
    """Track the most recent value and the running, count-weighted average of a metric.

    Attributes:
        val: value passed to the most recent ``update`` call.
        sum: running sum of ``val * n`` over all updates.
        count: total weight ``n`` accumulated so far.
        avg: ``sum / count``; keeps its previous value (0 after ``reset``) while
            ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset all statistics to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the value observed for ``n`` samples.

        Args:
            val: value to record (e.g. a per-batch mean).
            n: weight of ``val`` in the running average, typically the batch size.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Avoid ZeroDivisionError when only zero-weight updates have been recorded.
        if self.count:
            self.avg = self.sum / self.count

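# Example 10-class vectors: a one-hot label for class 0 ...
# [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
# ... and a score vector whose top-1 class is 1 (so a top-1 prediction would miss class 0).
# [0, 0.9, 0.1, 0, 0, 0, 0, 0, 0, 0]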

def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, as a percentage, for each k in ``topk``.

    Args:
        output: score tensor of shape (batch, num_classes); higher scores rank first.
        target: tensor of shape (batch,) with the true class index of each sample.
        topk: iterable of k values; each k must not exceed num_classes.

    Returns:
        A list with one 0-d tensor per k: the percentage of samples whose true
        class is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # pred[j, b] is the j-th best class for sample b (shape: maxk x batch).
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` inherits the transposed, non-contiguous layout of `pred`, so
        # `.view(-1)` fails for k > 1; `.reshape(-1)` copies only when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


# Model Creation

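We now create a fully connected network (a multilayer perceptron) for image classification. It takes each image as a flattened 1024-dimensional grayscale vector (32 × 32 pixels) and outputs one score per CIFAR10 class.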

In [None]:
# Multilayer perceptron over flattened 32x32 grayscale images (32 * 32 = 1024 inputs),
# producing one score per CIFAR10 class (10 outputs).
model = nn.Sequential(
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

# Use the GPU when one is available; fall back to the CPU instead of failing in `.cuda()`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

In [None]:
# Training hyperparameters
epochs = 20               # number of passes over the training loader
batch_size = 256          # samples per mini-batch in both data loaders
learning_rate = 1e-3      # SGD step size
print_frequency = 100     # refresh the progress-bar text every N batches

# Data Loader

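For batch sampling, we wrap each dataset in a `DataLoader`. Every image is converted to a tensor, normalized, converted to grayscale and flattened into a 1024-dimensional vector, so that it matches the model's input layer.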

In [None]:
from torchvision.transforms.transforms import ToPILImage
from sklearn.model_selection import train_test_split
# NOTE: these are the widely used ImageNet channel statistics, not statistics computed
# on CIFAR10 itself; they are kept unchanged here.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Shared preprocessing: tensor -> per-channel normalization -> grayscale -> flatten,
# yielding 1024-dim vectors that match the model's first nn.Linear layer.
preprocess = transforms.Compose([
    transforms.ToTensor(),
    normalize,
    transforms.Grayscale(),
    torch.flatten,
])

dataset = torchvision.datasets.CIFAR10(root_path, train=True, transform=preprocess)

# `dataset` already is the training split with the right transform, so the loader
# reuses it instead of constructing a second, identical CIFAR10 instance.
train_dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size, shuffle=True,
    num_workers=4, pin_memory=False)

test_dataloader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR10(root_path, train=False, transform=preprocess),
    batch_size=batch_size, shuffle=False,
    num_workers=4, pin_memory=False)

In [None]:
# Cross-entropy over the 10 class scores; plain SGD (no momentum) over every model parameter.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)

In [None]:
def train(model, loader, criterion, optimizer, epoch=None):
    """Run one training epoch of ``model`` over ``loader`` and return its averages.

    Args:
        model: network to optimize; switched to training mode here.
        loader: iterable of ``(inputs, targets)`` batches.
        criterion: loss function called as ``criterion(outputs, targets)``.
        optimizer: optimizer that steps the model's parameters.
        epoch: epoch index shown in the progress bar. When omitted, the
            notebook-level ``epoch`` loop variable is used if it exists
            (preserving the previous behaviour), otherwise ``?`` is shown.

    Returns:
        ``(average_loss, average_top1_percent)``, both weighted by batch size.
    """
    if epoch is None:
        # Backward compatibility: this function used to read the global loop variable.
        epoch = globals().get("epoch")
    epoch_label = "?" if epoch is None else epoch

    # Move batches to wherever the model lives instead of hard-coding `.cuda()`.
    device = next(model.parameters()).device

    losses = AverageMeter()
    top1 = AverageMeter()

    model.train()

    pbar = tqdm(enumerate(loader), total=len(loader))
    for i, (inputs, targets) in pbar:
        inputs = inputs.to(device)
        targets = targets.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Update step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metrics use a detached float copy so they never touch the autograd graph.
        prec = accuracy(outputs.detach().float(), targets)[0]
        losses.update(loss.item(), inputs.shape[0])
        top1.update(prec.item(), inputs.shape[0])

        if i % print_frequency == 0:
            pbar.set_description("Epoch [%s]\t Loss %.2f\t Prec@1 %.3f (%.3f)" % (epoch_label, losses.avg, top1.val, top1.avg))

    # Show end-of-epoch averages; otherwise the text is only as fresh as the last
    # batch index divisible by `print_frequency`.
    pbar.set_description("Epoch [%s]\t Loss %.2f\t Prec@1 %.3f (%.3f)" % (epoch_label, losses.avg, top1.val, top1.avg))
    return losses.avg, top1.avg


def test(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without updating it and return its averages.

    Args:
        model: network to evaluate; switched to eval mode here.
        loader: iterable of ``(inputs, targets)`` batches.
        criterion: loss function called as ``criterion(outputs, targets)``.

    Returns:
        ``(average_loss, average_top1_percent)``, both weighted by batch size.
    """
    # Move batches to wherever the model lives instead of hard-coding `.cuda()`.
    device = next(model.parameters()).device

    losses = AverageMeter()
    top1 = AverageMeter()

    model.eval()

    pbar = tqdm(enumerate(loader), total=len(loader))
    # Evaluation needs no gradients; disabling autograd avoids building graphs,
    # which saves memory and time.
    with torch.no_grad():
        for i, (inputs, targets) in pbar:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            prec = accuracy(outputs.float(), targets)[0]
            losses.update(loss.item(), inputs.shape[0])
            top1.update(prec.item(), inputs.shape[0])

            if i % print_frequency == 0:
                pbar.set_description("Loss %.2f\t Prec@1 %.3f (%.3f)" % (losses.avg, top1.val, top1.avg))

    # Show the final averages over the whole loader.
    pbar.set_description("Loss %.2f\t Prec@1 %.3f (%.3f)" % (losses.avg, top1.val, top1.avg))
    return losses.avg, top1.avg

In [None]:
# `train` reads the notebook-level `epoch` variable for its progress-bar label,
# so the loop variable must keep exactly this name.
for epoch in range(epochs):
    train(model=model, loader=train_dataloader, criterion=criterion, optimizer=optimizer)

In [None]:
# Evaluate the trained model on the CIFAR10 test split (train=False).
test(model=model, loader=test_dataloader, criterion=criterion)