# DenseNet
In this notebook, we train a DenseNet classifier on the SVHN or CIFAR10 dataset and then serve it through a web API.

In [2]:
"""
Script adapted from: https://github.com/kuangliu/pytorch-cifar
"""
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torch.autograd import Variable
from torch.optim import lr_scheduler
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets, models, transforms
import sys
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

from densenet import densenet121

## Dataloader

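Here, we load either the SVHN or the CIFAR10 dataset, both of which are provided through torchvision. If you wish to use your own data, wrap it in a `torch.utils.data.Dataset` (for a folder of labelled images, `torchvision.datasets.ImageFolder` works) and assign it to `trainset` / `testset` instead.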

Note that we are treating the test set as a validation set.

In [None]:
# Transform from PIL image format to tensor format
transform_train = transforms.Compose([
    # You can add more data augmentation techniques in series:
    # https://pytorch.org/docs/stable/torchvision/transforms.html
    transforms.ToTensor()
])

transform_test = transforms.Compose([
    transforms.ToTensor()
])

# Which torchvision dataset to train on: 'cifar10' or 'svhn'.
dataset_name = 'cifar10'
data_root = '../data'  # Where torchvision downloads / caches the data

if dataset_name == 'cifar10':
    # CIFAR10 Dataset: https://www.cs.toronto.edu/~kriz/cifar.html
    trainset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform_test)
elif dataset_name == 'svhn':
    # SVHN Dataset: http://ufldl.stanford.edu/housenumbers/
    trainset = torchvision.datasets.SVHN(root=data_root, split='train', transform=transform_train, download=True)
    testset = torchvision.datasets.SVHN(root=data_root, split='test', transform=transform_test, download=True)
else:
    raise ValueError(f"Unknown dataset_name {dataset_name!r}; expected 'cifar10' or 'svhn'")

For a proof-of-concept application, we can train on a small subset of the data (deliberately overfitting it) so that training is quick. Set `train_ct` / `test_ct` to `None` to use the full datasets.

In [None]:
train_ct = 128  # Size of train data (None or 0 = use the full training set)
test_ct = 32  # Size of test data (None or 0 = use the full test set)
batch_sz = 32  # Size of batch for one gradient step (bigger batches take more memory but are faster)
num_workers = 4  # 4 * number of GPUs

# Keep only the first N examples of each split. N is clamped to the dataset
# length so that asking for more examples than exist cannot produce
# out-of-range indices (which would only fail later, inside a loader worker).
if train_ct:
    trainset = data.Subset(trainset, range(min(train_ct, len(trainset))))

if test_ct:
    testset = data.Subset(testset, range(min(test_ct, len(testset))))

trainloader = data.DataLoader(trainset, batch_size=batch_sz, shuffle=True, num_workers=num_workers)
testloader = data.DataLoader(testset, batch_size=batch_sz, shuffle=False, num_workers=num_workers)

## Training

First, we configure the model for training, and define our loss and optimizer.

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gpu_ids = [0]  # On Colab, we have access to one GPU. Change this value as you see fit

net = densenet121()
net = net.to(device)

if device == 'cuda':
    net = torch.nn.DataParallel(net, gpu_ids)

resume = False  # Resume training from a saved checkpoint
checkpoint_path = './checkpoint/best_model.pth.tar'  # Same file that save_state writes

if resume:
    print(f'Resuming from checkpoint at {checkpoint_path}')
    assert os.path.isfile(checkpoint_path), f'Error: no checkpoint found at {checkpoint_path}!'
    # map_location lets a checkpoint written on a GPU machine load on CPU (and vice versa).
    # NOTE(review): a checkpoint saved from the DataParallel-wrapped net has 'module.'-prefixed
    # keys and will only load into a net wrapped the same way — confirm if switching devices.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(checkpoint['net'])
    # save_state stores the accuracy under 'acc' (there is no 'test_loss' key).
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']  # Epoch at which the checkpoint was saved

# Loss function: https://pytorch.org/docs/stable/nn.html#torch.nn.CrossEntropyLoss
loss_fn = nn.CrossEntropyLoss()
# 0.1 is the usual SGD learning rate; Adam's default (and typical) value is 1e-3,
# and 0.1 tends to make Adam diverge.
learning_rate = 1e-3
optimizer = optim.Adam(net.parameters(), lr=learning_rate)

We define `train`, which performs one forward/backward propagation pass over the training set per epoch. Similarly, `test` evaluates the model on the test set and saves a checkpoint when the accuracy improves.

In [None]:
def train(epoch):
    """
    Run one optimisation pass of the notebook-level `net` over `trainloader`.

    Relies on the globals `net`, `trainloader`, `optimizer`, `loss_fn` and
    `device`. `epoch` is accepted for symmetry with `test` but is not used.
    The tqdm bar shows the running mean batch loss and the running accuracy.
    """
    net.train()
    running_loss = 0
    n_correct = 0
    n_seen = 0
    with tqdm(total=len(trainloader.dataset)) as bar:
        for step, (images, labels) in enumerate(trainloader, start=1):
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()  # Discard gradients accumulated by the previous step
            logits = net(images)
            batch_loss = loss_fn(logits, labels)  # Prediction-vs-label loss
            batch_loss.backward()  # Propagate gradients back through the net
            optimizer.step()  # Apply the gradient update to the weights

            running_loss += batch_loss.item()
            _, guesses = logits.max(1)
            n_seen += labels.size(0)
            n_correct += guesses.eq(labels).sum().item()
            accuracy = 100. * n_correct / n_seen

            bar.set_postfix(loss=running_loss / step, accuracy=f'{accuracy}%')
            bar.update(images.size(0))
            
        
def test(epoch):
    """
    Run net in inference mode on test data and checkpoint it if it improved.

    Evaluates the notebook-level `net` on `testloader` in eval mode under
    `torch.no_grad()`, showing the running mean batch loss and accuracy in a
    tqdm bar. After the whole test set has been processed, the accuracy is
    compared with the best accuracy from previous calls (module-level
    `best_acc`). The model is saved via `save_state` when it is better, or when
    no best has been recorded yet. In that second case a checkpoint always
    exists after the first evaluation, which the inference cell needs.

    Args:
        epoch: epoch number recorded in the checkpoint.
    """
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    acc = 0.
    # Ensures the net will not update weights
    with torch.no_grad():
        with tqdm(total=len(testloader.dataset)) as progress_bar:
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)
                loss = loss_fn(outputs, targets)

                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

                acc = (100. * correct / total)
                progress_bar.set_postfix(loss=test_loss/(batch_idx+1), accuracy=f'{acc}%')
                progress_bar.update(inputs.size(0))

    # Save best model: decide once, on the full-test-set accuracy, and keep the
    # best across calls instead of resetting it on every evaluation.
    previous_best = globals().get('best_acc')
    if previous_best is None or acc > previous_best:
        print("Saving...")
        save_state(net, acc, epoch)
        best_acc = acc

def save_state(net, acc, epoch, path='./checkpoint/best_model.pth.tar'):
    """
    Save the current net state, accuracy and epoch to a checkpoint file.

    The file holds a dict with keys 'net' (``net.state_dict()``), 'acc' and
    'epoch'.

    Args:
        net: module whose state dict is saved.
        acc: accuracy to record alongside the weights.
        epoch: epoch number to record alongside the weights.
        path: destination file. Missing parent directories are created. The
              default is the location the rest of this notebook loads from.
    """
    state = {
        'net': net.state_dict(),
        'acc': acc,
        'epoch': epoch,
    }
    # makedirs(exist_ok=True) avoids the check-then-create race of isdir + mkdir
    # and also creates nested parent directories when needed.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(state, path)

In [None]:
num_epochs = 100  # Total number of training epochs
test_freq = 5  # Frequency to run model on validation data

# When resuming, continue with the epoch after the one stored in the checkpoint.
first_epoch = start_epoch + 1 if resume else 0

for epoch in range(first_epoch, num_epochs):
    train(epoch)
    # Also evaluate after the final epoch, otherwise the last
    # (num_epochs - 1) % test_freq epochs could never become the saved best model.
    if epoch % test_freq == 0 or epoch == num_epochs - 1:
        test(epoch)

## Inference
Now that we have trained a model, we can use it for inference on new data! With each run, we save the best model weights at:
`./checkpoint/best_model.pth.tar`

In [None]:
def classify_image(x, model=None):
    """
    Predict class indices for a batch of inputs.

    Args:
        x: batch tensor passed straight to the model. For the DenseNet in this
           notebook this is presumably (batch, channels, H, W) — TODO confirm.
        model: module to run. Defaults to the notebook-level `net`.

    Returns:
        Tensor of predicted class indices: the argmax over dim 1 of the model
        outputs, one per input row.

    The caller is responsible for putting the model in eval mode.
    """
    if model is None:
        model = net
    # Put the input on the same device as the model's weights (if it has any).
    first_param = next(model.parameters(), None)
    if first_param is not None:
        x = x.to(first_param.device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(x)
    _, predicted = outputs.max(1)
    return predicted

In [None]:
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
cp = torch.load('./checkpoint/best_model.pth.tar', map_location=device)
net.load_state_dict(cp['net'])
net.eval()

sample = trainset[1][0]  # An image tensor from the (possibly subsetted) training set

# .item() turns the 0-d index tensor into a plain int for display.
y = classify_image(sample.unsqueeze(0))[0].item()

plt.imshow(sample.permute(1, 2, 0))  # CHW -> HWC for matplotlib
plt.title(f'Predicted class: {y}')
plt.axis('off')
plt.show()

print(f'Predicted class: {y}')

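(If you are running on CIFAR10, the class indices map to these labels (see https://www.cs.toronto.edu/~kriz/cifar.html): 0 = airplane, 1 = automobile, 2 = bird, 3 = cat, 4 = deer, 5 = dog, 6 = frog, 7 = horse, 8 = ship, 9 = truck. On SVHN, the predicted class is the digit itself.)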

## Export (API)
We want to use our models inside our apps. One way to do this is to wrap our calls in an API. We included an `app.py` that loads the model saved at `models/densenet/checkpoint/best_model.pth.tar` and wraps it in a simple Flask server, which receives images and returns their classification. Check out our web API hackpack for more information on how this works! (https://github.com/TreeHacks/hackpack-web-api)

Congratulations! You have just trained and launched your own ML model.