# DenseNet
In this notebook, we train a DenseNet classifier on SVHN (Street View House Numbers) digits. (https://arxiv.org/abs/1608.06993)

In [8]:
"""
Script adapted from: https://github.com/kuangliu/pytorch-cifar
"""
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torch.optim import lr_scheduler
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets, models, transforms
import sys
import os
from tqdm import tqdm

sys.path.append('../..')
from models import densenet121

## Dataloader

Here, we load the SVHN dataset, which is provided through torchvision. If you wish to use your own, ...

In [12]:
# Both splits currently need nothing beyond the PIL image -> tensor conversion.
# To augment the training data, chain more transforms inside the first list:
# https://pytorch.org/docs/stable/torchvision/transforms.html
transform_train = transforms.Compose([transforms.ToTensor()])
transform_test = transforms.Compose([transforms.ToTensor()])

# Alternative dataset -- CIFAR10 (https://www.cs.toronto.edu/~kriz/cifar.html):
# trainset = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_train)
# testset = torchvision.datasets.CIFAR10(root='../data', train=False, download=True, transform=transform_test)

# SVHN (http://ufldl.stanford.edu/housenumbers/), stored under ../data.
trainset = datasets.SVHN('../data', split='train', download=True, transform=transform_train)
testset = datasets.SVHN('../data', split='test', download=True, transform=transform_test)

Using downloaded and verified file: ../data/train_32x32.mat
Using downloaded and verified file: ../data/test_32x32.mat


If making a proof of concept application, we can choose to overfit on a data subset for quick training.

In [13]:
# Proof-of-concept knobs. Set a count to 0 (or None) to keep the full split.
train_ct = 10  # Number of training samples to keep
test_ct = 20  # Number of test samples to keep
batch_sz = 10
num_workers = 4

# Keep only the first N samples of each split. The count is clamped to the
# dataset length so an oversized request cannot produce out-of-range indices
# (which would otherwise only surface later as an IndexError while iterating
# the DataLoader).
# NOTE: this rebinds trainset/testset, so re-running this cell subsets the
# already-reduced data; re-run the dataset cell first to start from the full split.
if train_ct:
    trainset = data.Subset(trainset, range(min(train_ct, len(trainset))))

if test_ct:
    testset = data.Subset(testset, range(min(test_ct, len(testset))))

# Shuffle only the training data; evaluation order stays fixed.
trainloader = data.DataLoader(trainset, batch_size=batch_sz, shuffle=True, num_workers=num_workers)
testloader = data.DataLoader(testset, batch_size=batch_sz, shuffle=False, num_workers=num_workers)


## Training

Configure model

In [14]:
# Select the compute device, build the network, and set up the loss and optimizer.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): the first logged training loss (~6.91, i.e. about ln 1000) suggests
# the model may emit 1000 logits -- confirm densenet121() is configured for 10 classes.
net = densenet121()
net = net.to(device)

if device == 'cuda':
    net = torch.nn.DataParallel(net, [0])
    # Optional speed-up for fixed input sizes: torch.backends.cudnn.benchmark = True

resume = False  # Set to True to continue from the checkpoint written by test()

# Defaults for a fresh run; overwritten below when resuming.
best_acc = 0
start_epoch = 0

if resume:
    # test() saves {'net', 'acc', 'epoch'} to ./checkpoint/ckpt.t7, so the same
    # file and keys are read back here.
    print('Resuming from checkpoint at checkpoint/ckpt.t7...')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    # map_location places the stored tensors on the device selected above.
    checkpoint = torch.load('checkpoint/ckpt.t7', map_location=device)
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

loss_fn = nn.CrossEntropyLoss()
# NOTE(review): lr=0.1 is far above Adam's documented default of 1e-3 -- confirm
# this value is intended.
optimizer = optim.Adam(net.parameters(), lr=0.1)

In [15]:
def train(epoch):
    """Run one optimisation pass over ``trainloader``.

    Switches the module-level ``net`` to training mode, then for every batch
    runs a forward pass, backpropagates ``loss_fn`` and steps ``optimizer``.
    A tqdm bar advances by the number of samples processed and displays the
    mean per-batch loss so far together with the cumulative accuracy.

    Args:
        epoch: Current epoch index. It is not used inside the function.
    """
    net.train()
    running_loss = 0
    correct = 0
    total = 0
    num_samples = len(trainloader.dataset)
    with tqdm(total=num_samples) as progress_bar:
        for step, (inputs, targets) in enumerate(trainloader, start=1):
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Clear stale gradients, then forward -> loss -> backward -> update.
            optimizer.zero_grad()
            logits = net(inputs)
            batch_loss = loss_fn(logits, targets)
            batch_loss.backward()
            optimizer.step()

            # Running statistics shown on the progress bar.
            running_loss += batch_loss.item()
            predicted = logits.max(1)[1]  # index of the largest logit per sample
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

            progress_bar.set_postfix(
                loss=running_loss / step,
                accuracy=f'{(100. * correct / total)}%',
            )
            progress_bar.update(inputs.size(0))
            
        
def test(epoch):
    """Evaluate ``net`` on ``testloader`` and checkpoint it when accuracy improves.

    Runs the model in eval mode under ``torch.no_grad()``, showing the mean
    per-batch loss and cumulative accuracy on a tqdm bar. If the final accuracy
    beats the module-level ``best_acc``, the weights, accuracy and ``epoch`` are
    written to ``./checkpoint/ckpt.t7`` and ``best_acc`` is updated.

    Args:
        epoch: Epoch index stored alongside the weights in the checkpoint.
    """
    global best_acc
    # Fall back to 0 only when no earlier cell or call has defined best_acc.
    # The value has to persist between calls: resetting it here would make
    # every epoch look like the best one and overwrite the checkpoint each time.
    best_acc = globals().get('best_acc', 0)

    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        with tqdm(total=len(testloader.dataset)) as progress_bar:
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)
                loss = loss_fn(outputs, targets)

                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

                acc = f'{(100. * correct / total)}%'
                progress_bar.set_postfix(loss=test_loss/(batch_idx+1), accuracy=acc)
                progress_bar.update(inputs.size(0))

    # An empty test loader leaves total at 0; report 0% instead of dividing by zero.
    acc = 100. * correct / total if total else 0.

    # Save checkpoint only when this evaluation is the best seen so far.
    if acc > best_acc:
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # exist_ok avoids the check-then-create race of isdir() + mkdir().
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, './checkpoint/ckpt.t7')
        best_acc = acc

In [16]:
# Train for 100 epochs. Evaluation/checkpointing via test() is currently switched off.
for epoch in range(100):
    train(epoch)
    # test(epoch)

100%|██████████| 10/10 [00:00<00:00, 16.79it/s, accuracy=0.0%, loss=6.91]
100%|██████████| 10/10 [00:00<00:00, 15.85it/s, accuracy=30.0%, loss=6.5]
100%|██████████| 10/10 [00:00<00:00, 15.50it/s, accuracy=20.0%, loss=9.11]
100%|██████████| 10/10 [00:00<00:00, 15.49it/s, accuracy=30.0%, loss=5.09]
100%|██████████| 10/10 [00:00<00:00, 15.71it/s, accuracy=30.0%, loss=5.33]
100%|██████████| 10/10 [00:00<00:00, 16.07it/s, accuracy=40.0%, loss=3.74]
100%|██████████| 10/10 [00:00<00:00, 15.79it/s, accuracy=40.0%, loss=2.49]
100%|██████████| 10/10 [00:00<00:00, 15.88it/s, accuracy=40.0%, loss=2.31]
100%|██████████| 10/10 [00:00<00:00, 15.90it/s, accuracy=60.0%, loss=1.42]
100%|██████████| 10/10 [00:00<00:00, 15.99it/s, accuracy=60.0%, loss=1.04]
100%|██████████| 10/10 [00:00<00:00, 15.86it/s, accuracy=50.0%, loss=0.918]
100%|██████████| 10/10 [00:00<00:00, 15.81it/s, accuracy=60.0%, loss=0.95]
100%|██████████| 10/10 [00:00<00:00, 15.83it/s, accuracy=60.0%, loss=0.861]
100%|██████████| 10/10 [0

KeyboardInterrupt: 

## Inference

## Export (API)

## Export (CoreML)