# Description

This notebook trains a convolutional neural network (CNN) for image classification on a typical dataset (CIFAR-100 for now). The dataset consists of 100 labels with 600 images per label. 500 images/label are in the train set and 100 images/label are in the test set.

Training can be offloaded to a GPU by choosing the appropriate value for the variable 'device' below.

Since the goal is to compare run-times and pricing, we aim for consistency (fixed model architecture, dataset, hyperparameters etc.) rather than tricks to speed up learning or improve model accuracy.

### Imports

In [5]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import numpy as np
import matplotlib.pylab as plt
import pickle

from torchvision.datasets import CIFAR100
from torchvision.models import resnet34
from torchvision import transforms

In [7]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [12]:
print(device)

cuda


In [13]:
torch.cuda.get_device_name()

'GeForce GTX 1080 Ti'

### Dataset and Data Loaders

In [17]:
#Since we are not using a pretrained network, we do **not** need to use these normalization constants 
#but we will do so anyway in case we want to compare the performance to a pre-trained network.

MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

def load_data():
    '''Download dataset and apply preprocessing transforms.
    '''
    LOCAL_PATH = './data/CIFAR100'
    
    train_data = CIFAR100(LOCAL_PATH, 
                          train=True, 
                          download=True,
                          transform=transforms.Compose([
                              transforms.Resize((224, 224)),
                              transforms.ToTensor(),
                              transforms.Normalize(mean=MEAN, std=STD)                              
                          ])
                         )
    
    test_data = CIFAR100(LOCAL_PATH,
                         train=False, 
                         download=True,
                         transform=transforms.Compose([
                             transforms.Resize((224, 224)),
                             transforms.ToTensor(),
                             transforms.Normalize(mean=MEAN, std=STD)                              
                         ])
                        )
    
    label_names = pickle.load(open(f'{LOCAL_PATH}/cifar-100-python/meta', "rb"), encoding='ISO-8859-1')["fine_label_names"]
    
    return train_data, test_data, label_names

In [18]:
train_data, test_data, label_names = load_data()

Files already downloaded and verified
Files already downloaded and verified


In [19]:
def create_dataloaders(train_data, test_data, batch_size, pin_memory=True):
    '''Create dataloaders that create batches for training.
    '''
    train_dl = DataLoader(train_data, 
                          batch_size=batch_size, 
                          pin_memory=pin_memory)
    
    test_dl = DataLoader(test_data, 
                         batch_size=batch_size, 
                         pin_memory=pin_memory)
    
    return train_dl, test_dl

In [20]:
train_dl, test_dl = create_dataloaders(train_data, test_data, 128, pin_memory=True)

### Model and Loss Initialization

In [21]:
model = resnet34(pretrained=False)
criterion = nn.CrossEntropyLoss()

### Model Training and Validation Loops

In [22]:
def train_model(train_dl, test_dl, model, criterion, N_epochs, print_freq, lr=1e-3):
    '''Loop over dataset in batches, compute loss, backprop and update weights
    '''
    
    model.train() #switch to train model (for dropout, batch normalization etc.)
    
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    acc_dict, loss_dict = {}, {}
    for epoch in range(N_epochs): #loop over epochs i.e. sweeps over full data
        curr_loss = 0
        N = 0
        
        for idx, (images, labels) in enumerate(train_dl): #loop over batches
            images = images.to(device)
            labels = labels.to(device)
            
            preds = model(images)
            loss = criterion(preds, labels)
            
            curr_loss += loss.item() #accumulate loss
            N += len(labels) #accumulate number of images seen in this epoch
                
            #backprop and updates
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        if epoch % print_freq == 0 or epoch==N_epochs-1:
            val_loss, val_acc = validate(test_dl, model, criterion) #get model perf metrics from test set
            
            acc_dict[epoch] = val_acc
            loss_dict[epoch] = val_loss
            
            print(f'Iter = {epoch} Train Loss = {curr_loss / N} val_loss = {val_loss} val_acc = {val_acc}')
            
    return model, acc_dict, loss_dict

def validate(test_dl, model, criterion):
    '''Loop over test dataset and compute loss and accuracy
    '''
    model.eval() #switch to eval model
    
    loss = 0
    N = 0

    N_correct = 0
    
    with torch.no_grad(): #no need to keep variables for backprop computations
        for idx, (images, labels) in enumerate(test_dl):
            images = images.to(device)
            labels = labels.to(device)
            
            preds = model(images)
            preds_nonprob = preds.argmax(dim=1)
            
            N_correct += (labels==preds_nonprob).sum().item() #accuracy computation
            
            loss += criterion(preds, labels) #cumulative loss
            N += len(labels)
    
    return loss / N, N_correct/N

In [23]:
%time model, acc_dict, loss_dict = train_model(train_dl, test_dl, model, criterion, 20, 1)

Iter = 0 Train Loss = 0.029962537961006164 val_loss = 0.027555299922823906 val_acc = 0.1675
Iter = 1 Train Loss = 0.029955543751716614 val_loss = 0.025803718715906143 val_acc = 0.2005
Iter = 2 Train Loss = 0.023886292572021486 val_loss = 0.022289138287305832 val_acc = 0.29
Iter = 3 Train Loss = 0.0202695588350296 val_loss = 0.02001195400953293 val_acc = 0.3484
Iter = 4 Train Loss = 0.017271596004962923 val_loss = 0.017855264246463776 val_acc = 0.4118
Iter = 5 Train Loss = 0.014594539954662323 val_loss = 0.01709928922355175 val_acc = 0.4385
Iter = 6 Train Loss = 0.012312490348815918 val_loss = 0.01777018792927265 val_acc = 0.4387
Iter = 7 Train Loss = 0.010540116629600525 val_loss = 0.018000656738877296 val_acc = 0.4478
Iter = 8 Train Loss = 0.008653313562870026 val_loss = 0.01896815374493599 val_acc = 0.4467
Iter = 9 Train Loss = 0.006887058638334274 val_loss = 0.021302219480276108 val_acc = 0.4512
Iter = 10 Train Loss = 0.005184352113008499 val_loss = 0.022613035514950752 val_acc = 0.