# Step-by-step Guide to Train a Network
## 0. Summary
This tutorial includes the following three parts:

1. Train a LeNet-5 on MNIST;

2. Train a VGG and a ResNet on CIFAR-10;

3. Train a network using the MixUp strategy.

This tutorial is based on the [Deep Learning with PyTorch: A 60 Minute Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html).

## 1. Preparation

We will do the following steps to train a network:

1. Load and normalize the dataset;

2. Define the network;

3. Decide which loss function, optimizer, and/or other strategies to use;

4. Train the network on the training data;

5. Test the network on the test data.

In [1]:
import errno
import os
import os.path as osp
import shutil
from collections import OrderedDict
import time

import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.datasets.mnist import MNIST
from torchvision.models.vgg import vgg11
from tqdm import tqdm
import numpy as np
from torch.autograd import Variable

In [2]:
# Pick the compute device in priority order: a CUDA GPU (Windows/Linux),
# then Apple-silicon MPS, falling back to the CPU.
device = (
    'cuda:0' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print(device)

mps


### 1.1 Load and normalize the dataset
In this lecture, we will use two datasets, MNIST and CIFAR-10. In our experiments, we will use *torchvision* to load the datasets.

In [3]:
# MNIST: resize every image to 32x32 (the input size the LeNet5 layers below
# are dimensioned for) and convert to tensors. No mean/std normalization is
# applied to MNIST here, unlike the CIFAR-10 pipeline.
mnist_data_path = 'data/mnist'
mnist_transform = T.Compose([
    T.Resize((32, 32)),
    T.ToTensor(),
])
# Only the training split passes download=True; the test split relies on the
# files fetched by the line above.
mnist_train = MNIST(mnist_data_path, download=True, transform=mnist_transform)
mnist_test = MNIST(mnist_data_path, train=False, transform=mnist_transform)
# Training batches are shuffled every epoch; evaluation uses larger, unshuffled batches.
mnist_train_loader = DataLoader(mnist_train, batch_size=256, shuffle=True, num_workers=2, pin_memory=True)
mnist_test_loader = DataLoader(mnist_test, batch_size=1024, num_workers=2, pin_memory=True)

In [4]:
cifar_data_path = 'data/cifar10'
# Per-channel (RGB) mean/std used to standardize inputs.
# NOTE(review): presumably statistics of the CIFAR-10 training set — confirm.
normalize = T.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
# Training-time augmentation: random 32x32 crop from a 4-pixel zero-padded
# image plus random horizontal flip, then tensor conversion and normalization.
cifar_train_transform = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    normalize,
])
# Evaluation uses no augmentation, only the same normalization.
cifar_test_transform = T.Compose([
    T.ToTensor(),
    normalize,
])
cifar_train = CIFAR10(cifar_data_path, download=True, transform=cifar_train_transform)
cifar_test = CIFAR10(cifar_data_path, train=False, transform=cifar_test_transform)
cifar_train_loader = DataLoader(cifar_train, batch_size=256, shuffle=True, num_workers=2, pin_memory=True)
cifar_test_loader = DataLoader(cifar_test, batch_size=1024, num_workers=2, pin_memory=True)

Files already downloaded and verified


### 1.2 Define the network

In [5]:
class LeNet5(nn.Module):
    """LeNet-5 for single-channel 32x32 images.

    Feature extractor: three 5x5 convolutions (c1, c3, c5) with ReLU, and 2x2
    max-pooling (s2, s4) after the first two. For a 1x32x32 input the sizes
    go 32 -> 28 -> 14 -> 10 -> 5 -> 1, so c5 yields a 120x1x1 map that is
    flattened into the classifier (f6 -> ReLU -> f7).

    Args:
        num_classes: number of logits produced by f7.
    """

    def __init__(self, num_classes):
        super().__init__()
        feature_layers = [
            ('c1', nn.Conv2d(1, 6, kernel_size=(5, 5))),
            ('relu1', nn.ReLU()),
            ('s2', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
            ('c3', nn.Conv2d(6, 16, kernel_size=(5, 5))),
            ('relu2', nn.ReLU()),
            ('s4', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
            ('c5', nn.Conv2d(16, 120, kernel_size=(5, 5))),
            ('relu3', nn.ReLU()),
        ]
        self.extractor = nn.Sequential(OrderedDict(feature_layers))

        head_layers = [
            ('f6', nn.Linear(120, 84)),
            ('relu4', nn.ReLU()),
            ('f7', nn.Linear(84, num_classes)),
        ]
        self.classifier = nn.Sequential(OrderedDict(head_layers))

    def forward(self, x):
        features = self.extractor(x)
        # Keep the batch dimension, merge (channels, H, W) into one vector.
        return self.classifier(torch.flatten(features, 1))

lenet = LeNet5(num_classes=10)

In [6]:
# torchvision's VGG-11 with a 10-way output head; it is trained on CIFAR-10 in section 3.
# NOTE(review): this torchvision architecture was designed for larger (ImageNet-sized)
# inputs — confirm its behaviour/efficiency on 32x32 CIFAR images is as intended.
vgg = vgg11(num_classes=10)

In [7]:
'''ResNet in PyTorch.
BasicBlock and Bottleneck module is from the original ResNet paper:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
PreActBlock and PreActBottleneck module is from the later paper:
[2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Identity Mappings in Deep Residual Networks. arXiv:1603.05027
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.autograd import Variable


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1.

    With padding 1 the spatial size is kept when ``stride`` is 1 and divided
    (rounding up) by ``stride`` otherwise.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )


class BasicBlock(nn.Module):
    """Two-convolution residual block from [1] (post-activation).

    Main path: conv3x3 -> BN -> ReLU -> conv3x3 -> BN. It is summed with the
    shortcut and passed through a final ReLU. The shortcut is the identity
    unless the stride or channel count changes, in which case it is a strided
    1x1 convolution followed by BN.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        return F.relu(h + self.shortcut(x))


class PreActBlock(nn.Module):
    """Pre-activation version of BasicBlock (see [2]).

    BN -> ReLU is applied to the input first. The activated tensor feeds
    both the main path (conv3x3 -> BN -> ReLU -> conv3x3) and the shortcut.
    No activation follows the residual sum. A projection shortcut (strided 1x1
    conv, no BN) is used only when the stride or channel count changes.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)

        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        pre = F.relu(self.bn1(x))
        # NOTE(review): with the identity (empty) shortcut this adds relu(bn1(x)),
        # not x itself as in the identity mapping of [2] — confirm intended.
        residual = self.shortcut(pre)
        h = self.conv1(pre)
        h = self.conv2(F.relu(self.bn2(h)))
        return h + residual


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block from [1] (post-activation).

    The first 1x1 convolution reduces to ``planes`` channels. The 3x3
    convolution carries the stride. The last 1x1 convolution expands to
    ``expansion * planes`` channels. The shortcut is the identity unless the
    stride or width changes, in which case it is a strided 1x1 convolution
    followed by BN. The residual sum goes through a final ReLU.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        width_out = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, width_out, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width_out)

        if stride == 1 and in_planes == width_out:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, width_out, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width_out),
            )

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        h = F.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        return F.relu(h + self.shortcut(x))


class PreActBottleneck(nn.Module):
    """Pre-activation version of the Bottleneck block (see [2]).

    BN -> ReLU is applied to the input first. The activated tensor feeds the
    shortcut and the 1x1 -> 3x3 -> 1x1 main path, with BN -> ReLU before the
    second and third convolutions. No activation follows the residual sum.
    A projection shortcut (strided 1x1 conv, no BN) is used only when the
    stride or width changes.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        width_out = self.expansion * planes
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, width_out, kernel_size=1, bias=False)

        if stride == 1 and in_planes == width_out:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, width_out, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        pre = F.relu(self.bn1(x))
        # NOTE(review): with the identity (empty) shortcut this adds relu(bn1(x)),
        # not x itself as in the identity mapping of [2] — confirm intended.
        residual = self.shortcut(pre)
        h = self.conv1(pre)
        h = self.conv2(F.relu(self.bn2(h)))
        h = self.conv3(F.relu(self.bn3(h)))
        return h + residual


class ResNet(nn.Module):
    """ResNet for 3-channel images with a stride-1 3x3 stem (CIFAR-style).

    Args:
        block: residual block class (BasicBlock, PreActBlock, Bottleneck or
            PreActBottleneck). It must expose an ``expansion`` attribute and be
            constructible as ``block(in_planes, planes, stride)``.
        num_blocks: four ints, the number of blocks in layer1..layer4.
        num_classes: number of logits produced by the final linear layer.
    """
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        # Running channel count; _make_layer advances it stage by stage.
        self.in_planes = 64

        self.conv1 = conv3x3(3, 64)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``.

        Updates ``self.in_planes`` to this stage's output width so the next
        stage is built with the right input channel count.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, lin=0, lout=5):
        """Run the network from stage ``lin`` through stage ``lout``.

        Stages: 0 = stem (conv1, bn1, ReLU), 1-4 = layer1..layer4,
        5 = global average pooling + linear head. For integer arguments, stage
        k in 0..4 runs when ``lin <= k <= lout``. The head runs whenever
        ``lout >= 5`` and does not check ``lin``. The defaults run the whole
        network. ``x`` must be whatever the first executed stage expects.
        """
        out = x
        if lin < 1 and lout > -1:
            out = self.conv1(out)
            out = self.bn1(out)
            out = F.relu(out)
        if lin < 2 and lout > 0:
            out = self.layer1(out)
        if lin < 3 and lout > 1:
            out = self.layer2(out)
        if lin < 4 and lout > 2:
            out = self.layer3(out)
        if lin < 5 and lout > 3:
            out = self.layer4(out)
        if lout > 4:
            # Global average pooling. For 32x32 inputs the final map is 4x4, so
            # this matches the former avg_pool2d(out, 4). Unlike that fixed
            # kernel, it also averages the whole map for other input sizes,
            # keeping the flattened width equal to 512 * expansion.
            out = F.adaptive_avg_pool2d(out, 1)
            out = out.view(out.size(0), -1)
            out = self.linear(out)
        return out


def ResNet18():
    """18-layer layout ([2, 2, 2, 2] blocks per stage) built from PreActBlock.

    NOTE(review): unlike ResNet34 below, this uses pre-activation blocks, and
    ResNet.forward applies no BN/ReLU after layer4 — confirm this is intended.
    """
    return ResNet(block=PreActBlock, num_blocks=[2, 2, 2, 2])

def ResNet34():
    """34-layer layout ([3, 4, 6, 3] blocks per stage) built from BasicBlock."""
    return ResNet(block=BasicBlock, num_blocks=[3, 4, 6, 3])

def ResNet50():
    """50-layer layout ([3, 4, 6, 3] blocks per stage) built from Bottleneck."""
    return ResNet(block=Bottleneck, num_blocks=[3, 4, 6, 3])

def ResNet101():
    """101-layer layout ([3, 4, 23, 3] blocks per stage) built from Bottleneck."""
    return ResNet(block=Bottleneck, num_blocks=[3, 4, 23, 3])

def ResNet152():
    """152-layer layout ([3, 8, 36, 3] blocks per stage) built from Bottleneck."""
    return ResNet(block=Bottleneck, num_blocks=[3, 8, 36, 3])

### 1.3 Decide which loss function to use
In our experiments, we will use the cross-entropy loss.

In [8]:
# Cross-entropy on raw logits, averaged over the batch.
criterion = nn.CrossEntropyLoss(reduction="mean")

### 1.4 Define the training and test process

In [9]:
class AverageMeter(object):
    """Track the latest value and an ``n``-weighted running average.

    Attributes: ``val`` (last value passed to ``update``), ``sum`` (weighted
    total), ``count`` (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counting it ``n`` times (e.g. batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: score matrix whose dim 1 indexes classes (one row per sample).
        target: ground-truth class indices, one per row of ``output``.
        topk: the k values to evaluate. Each k must not exceed the number of
            classes, since ``topk`` is taken along dim 1.

    Returns:
        A list with one 0-dim tensor per k: the percentage of rows whose
        target is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # pred: (maxk, batch) class indices, best first after the transpose.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` follows the non-contiguous layout of `pred.t()`, so
        # `.view(-1)` fails for k > 1. `.reshape(-1)` copies when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def mkdir_if_missing(directory):
    """Create ``directory`` (with missing parents) unless the path already exists.

    An existing path is left untouched, even if it is not a directory. If
    another process creates the directory between the check and the call
    (``EEXIST``), that is ignored. Any other ``OSError`` propagates.
    """
    if osp.exists(directory):
        return
    try:
        os.makedirs(directory)
    except OSError as err:
        if err.errno == errno.EEXIST:
            return
        raise

def save_checkpoint(state, is_best=False, fpath=''):
    """Serialize ``state`` to ``fpath`` with ``torch.save``.

    The parent directory of ``fpath`` is created if needed. When ``is_best``
    is true, the file is also copied to ``best_model.pth.tar`` in the same
    directory. If ``fpath`` already is that file, the copy is skipped, because
    copying a file onto itself raises ``shutil.SameFileError``.
    """
    dirname = osp.dirname(fpath)
    if dirname:
        mkdir_if_missing(dirname)
    torch.save(state, fpath)
    if is_best:
        best_path = osp.join(dirname, 'best_model.pth.tar')
        try:
            shutil.copy(fpath, best_path)
        except shutil.SameFileError:
            # fpath is already the best-model file; nothing to copy.
            pass

def adjust_learning_rate(optimizer, epoch, initial_lr, milestones=(100, 150), decay_factor=10):
    """Set a step-decayed learning rate on every param group of ``optimizer``.

    The rate starts at ``initial_lr`` and is divided by ``decay_factor`` once
    for each milestone that ``epoch`` has reached. The defaults reproduce the
    original schedule: divide by 10 at epoch 100 and again at epoch 150.

    Returns:
        The learning rate that was applied.
    """
    lr = initial_lr
    for milestone in milestones:
        if epoch >= milestone:
            # Repeated division (rather than multiplying by 0.1) keeps the
            # default schedule bit-identical to the original /10 steps.
            lr /= decay_factor
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

def warp_tqdm(data_loader, disable_tqdm):
    """Optionally wrap an iterable in a tqdm progress bar.

    Returns ``data_loader`` itself when ``disable_tqdm`` is truthy. Otherwise
    returns ``tqdm(data_loader, ncols=0)``.
    """
    return data_loader if disable_tqdm else tqdm(data_loader, ncols=0)

In [10]:
def train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch and return a one-line summary string.

    Each batch is moved to the global ``device`` and takes one optimizer
    step. The sample-weighted mean loss is tracked. The progress bar is
    always disabled (``warp_tqdm(..., True)``).

    Returns:
        A string with the epoch index, elapsed wall-clock seconds and the mean
        training loss.
    """
    loss_meter = AverageMeter()
    model.train()  # put modules such as BatchNorm into training mode
    tic = time.time()

    for images, labels in warp_tqdm(train_loader, True):
        images = images.to(device)
        labels = labels.to(device)

        batch_loss = criterion(model(images), labels)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        loss_meter.update(batch_loss.item(), images.size(0))

    elapsed = time.time() - tic
    return 'Epoch:{:03} Time:{:.3f}s Loss: {loss.avg:.4f} '.format(epoch, elapsed, loss=loss_meter)


In [11]:
def test(test_loader, model, criterion):
    """Measure the top-1 accuracy of ``model`` over ``test_loader``.

    ``criterion`` is accepted for signature parity with ``train`` but is not
    used. The whole evaluation runs without gradient tracking.

    Returns:
        ``(top1_avg, log)``: the sample-weighted top-1 accuracy in percent and
        a formatted summary string.
    """
    top1 = AverageMeter()
    model.eval()

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            logits = model(images)
            batch_acc = accuracy(logits.data, labels)[0]
            top1.update(batch_acc.item(), images.size(0))

    log = 'Test Acc@1: {top1.avg:.3f}'.format(top1=top1)
    return top1.avg, log

## 2. Train a LeNet-5 on MNIST

In [12]:
# Train LeNet-5 on MNIST with SGD + momentum at a constant learning rate (no
# adjust_learning_rate call) and keep only the checkpoint with the best test
# accuracy seen so far.
ckpt_dir = 'experiments/mnist'
num_epochs = 15
model = lenet.to(device)
train_loader = mnist_train_loader
test_loader = mnist_test_loader
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
best_acc = 0
for epoch in range(num_epochs):
    train_log = train(train_loader, model, criterion, optimizer, epoch)
    acc, test_log = test(test_loader, model, criterion)
    log = train_log + test_log
    print(log)
    is_best = acc > best_acc
    best_acc = max(acc, best_acc)
    if is_best:
        # The file is written straight to best_model.pth.tar, so save_checkpoint's
        # own is_best copy is disabled (False); it would copy the file onto itself.
        save_checkpoint({'epoch':epoch,
        'state_dict':model.state_dict(),
        'acc': acc,
        }, False,  os.path.join(ckpt_dir, 'best_model.pth.tar'))

Epoch:000 Time:3.582s Loss: 2.3022 Test Acc@1: 12.190
Epoch:001 Time:2.997s Loss: 2.2976 Test Acc@1: 18.230
Epoch:002 Time:2.850s Loss: 2.2912 Test Acc@1: 33.600
Epoch:003 Time:2.990s Loss: 2.2777 Test Acc@1: 37.680
Epoch:004 Time:2.945s Loss: 2.2210 Test Acc@1: 47.120
Epoch:005 Time:2.966s Loss: 1.5398 Test Acc@1: 79.520
Epoch:006 Time:2.853s Loss: 0.5888 Test Acc@1: 88.070
Epoch:007 Time:2.860s Loss: 0.4033 Test Acc@1: 89.920
Epoch:008 Time:2.989s Loss: 0.3297 Test Acc@1: 91.870
Epoch:009 Time:2.857s Loss: 0.2816 Test Acc@1: 92.800
Epoch:010 Time:2.923s Loss: 0.2458 Test Acc@1: 93.670
Epoch:011 Time:2.913s Loss: 0.2188 Test Acc@1: 94.120
Epoch:012 Time:2.954s Loss: 0.1972 Test Acc@1: 94.780
Epoch:013 Time:2.881s Loss: 0.1795 Test Acc@1: 95.490
Epoch:014 Time:2.962s Loss: 0.1651 Test Acc@1: 95.530


## 3. Train a VGG/ResNet on CIFAR-10

In [13]:
# Train torchvision's VGG-11 on CIFAR-10 with SGD (momentum 0.9, weight decay
# 5e-4) and the step schedule from adjust_learning_rate.
ckpt_dir = 'experiments/cifar10-vgg'
# num_epochs = 200  (full schedule: adjust_learning_rate divides the LR by 10 at
# epochs 100 and 150; with the 10 epochs below the LR stays at 0.1 throughout)
num_epochs = 10
model = vgg.to(device)
train_loader = cifar_train_loader
test_loader = cifar_test_loader
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
best_acc = 0
for epoch in range(num_epochs):
    # Must match the lr passed to SGD above.
    adjust_learning_rate(optimizer, epoch, 0.1)
    train_log = train(train_loader, model, criterion, optimizer, epoch)
    acc, test_log = test(test_loader, model, criterion)
    log = train_log + test_log
    print(log)
    is_best = acc > best_acc
    best_acc = max(acc, best_acc)
    if is_best:
        # Saved directly as best_model.pth.tar; save_checkpoint's is_best copy
        # is disabled (False) since it would copy the file onto itself.
        save_checkpoint({'epoch':epoch,
        'state_dict':model.state_dict(),
        'acc': acc,
        }, False,  os.path.join(ckpt_dir, 'best_model.pth.tar'))

Epoch:000 Time:84.020s Loss: 2.2953 Test Acc@1: 10.650
Epoch:001 Time:84.599s Loss: 2.0956 Test Acc@1: 28.860
Epoch:002 Time:84.586s Loss: 1.7746 Test Acc@1: 41.350
Epoch:003 Time:85.737s Loss: 1.5418 Test Acc@1: 45.290
Epoch:004 Time:84.972s Loss: 1.4598 Test Acc@1: 49.550
Epoch:005 Time:85.266s Loss: 1.4081 Test Acc@1: 52.110
Epoch:006 Time:85.476s Loss: 1.2246 Test Acc@1: 55.140
Epoch:007 Time:85.660s Loss: 1.1308 Test Acc@1: 64.620
Epoch:008 Time:85.526s Loss: 1.0546 Test Acc@1: 68.370
Epoch:009 Time:86.257s Loss: 0.9687 Test Acc@1: 69.850


In [14]:
# Train a fresh ResNet18 (pre-activation blocks) on CIFAR-10 with SGD
# (momentum 0.9, weight decay 1e-4) and the step schedule from adjust_learning_rate.
ckpt_dir = 'experiments/cifar10-resnet'
# num_epochs = 200  (full schedule: adjust_learning_rate divides the LR by 10 at
# epochs 100 and 150; with the 10 epochs below the LR stays at 0.1 throughout)
num_epochs = 10
model = ResNet18().to(device)
train_loader = cifar_train_loader
test_loader = cifar_test_loader
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
best_acc = 0
for epoch in range(num_epochs):
    # Must match the lr passed to SGD above.
    adjust_learning_rate(optimizer, epoch, 0.1)
    train_log = train(train_loader, model, criterion, optimizer, epoch)
    acc, test_log = test(test_loader, model, criterion)
    log = train_log + test_log
    print(log)
    is_best = acc > best_acc
    best_acc = max(acc, best_acc)
    if is_best:
        # Saved directly as best_model.pth.tar; save_checkpoint's is_best copy
        # is disabled (False) since it would copy the file onto itself.
        save_checkpoint({'epoch':epoch,
        'state_dict':model.state_dict(),
        'acc': acc,
        }, False,  os.path.join(ckpt_dir, 'best_model.pth.tar'))

Epoch:000 Time:115.617s Loss: 1.6182 Test Acc@1: 49.540
Epoch:001 Time:116.192s Loss: 1.0852 Test Acc@1: 62.220
Epoch:002 Time:116.093s Loss: 0.8269 Test Acc@1: 69.960
Epoch:003 Time:115.787s Loss: 0.6583 Test Acc@1: 76.500
Epoch:004 Time:115.703s Loss: 0.5580 Test Acc@1: 80.520
Epoch:005 Time:116.433s Loss: 0.4834 Test Acc@1: 77.030
Epoch:006 Time:117.487s Loss: 0.4288 Test Acc@1: 82.980
Epoch:007 Time:118.290s Loss: 0.3894 Test Acc@1: 80.960
Epoch:008 Time:117.545s Loss: 0.3476 Test Acc@1: 85.240
Epoch:009 Time:117.586s Loss: 0.3209 Test Acc@1: 85.870


## 4. MixUp

MixUp is a simple strategy to alleviate the problems of *memorization* and *sensitivity to adversarial examples* when training deep neural networks. A network trained with MixUp is fed convex combinations of pairs of examples (images and their labels). This regularizes the network to favor simple linear behavior in between training examples.

Suppose we have two data pairs $(x_i, y_i)$ and $(x_j, y_j)$, the MixUp virtual training example is constructed as 
$$\tilde{x}=\lambda x_i + (1-\lambda) x_j,\quad\textrm{where }x_i,x_j \textrm{ are raw input vectors}$$
$$\tilde{y}=\lambda y_i + (1-\lambda) y_j,\quad\textrm{where }y_i,y_j \textrm{ are one-hot label encodings}$$
where $\lambda\sim\mathrm{Beta}(\alpha,\alpha),\alpha\in(0,\infty)$.

In [15]:
def mixup_data(x, y, alpha=1.0, use_cuda=True):
    """Mix a batch with a randomly permuted copy of itself (MixUp).

    Args:
        x: input batch; dim 0 indexes samples (any number of trailing dims,
            including none).
        y: labels aligned with dim 0 of ``x``.
        alpha: Beta(alpha, alpha) concentration for the mixing weight ``lam``.
            When ``alpha <= 0`` no mixing happens (``lam = 1``).
        use_cuda: accepted for backward compatibility only. The permutation is
            always placed on ``x.device``, so it can no longer end up on the
            notebook-global ``device`` while ``x`` lives elsewhere.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where
        ``mixed_x = lam * x + (1 - lam) * x[perm]``, ``y_a = y`` and
        ``y_b = y[perm]``.
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size(0)
    # Drawn on the CPU generator (as before) and then moved next to x.
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Return the MixUp loss: ``lam``-weighted blend of the losses on both label sets."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

In [16]:
def train_mixup(train_loader, model, criterion, optimizer, alpha, epoch):
    """Run one MixUp training epoch and return a one-line summary string.

    Each batch is moved to the global ``device`` and mixed with a permuted
    copy of itself via ``mixup_data``. The mixing weight is drawn from
    Beta(alpha, alpha), or set to 1 when ``alpha <= 0``. The loss is the
    matching blend of ``criterion`` on both label sets. Gradients are clipped
    to a total norm of 5 before each optimizer step.

    Returns:
        A string with the epoch index, elapsed wall-clock seconds and the
        sample-weighted mean mixed loss (same format as ``train``).
    """
    losses = AverageMeter()
    model.train()
    start_time = time.time()

    for images, targets in warp_tqdm(train_loader, True):
        images = images.to(device)
        targets = targets.to(device)

        # Plain tensors carry autograd state; no Variable wrapping is needed.
        mixed, targets_a, targets_b, lam = mixup_data(images, targets, alpha, True)

        output = model(mixed)
        loss = mixup_criterion(criterion, output, targets_a, targets_b, lam)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
        optimizer.step()

        losses.update(loss.item(), mixed.size(0))

    log = 'Epoch:{:03} Time:{:.3f}s Loss: {loss.avg:.4f} '.format(epoch, time.time() - start_time, loss=losses)
    return log


## Train a ResNet on CIFAR-10 with MixUp

In [17]:
# Same ResNet18 / optimizer / schedule setup as the plain CIFAR-10 run, but each
# epoch uses train_mixup with Beta(alpha, alpha) mixing weights (alpha = 1.0).
ckpt_dir = 'experiments/cifar10-resnet-mixup'
alpha = 1.0
# num_epochs = 200  (full schedule: adjust_learning_rate divides the LR by 10 at
# epochs 100 and 150; with the 10 epochs below the LR stays at 0.1 throughout)
num_epochs = 10
model =ResNet18().to(device)
train_loader = cifar_train_loader
test_loader = cifar_test_loader
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
best_acc = 0
for epoch in range(num_epochs):
    # Must match the lr passed to SGD above.
    adjust_learning_rate(optimizer, epoch, 0.1)
    train_log = train_mixup(train_loader, model, criterion, optimizer, alpha, epoch)
    # Evaluation uses the clean (unmixed) test set.
    acc, test_log = test(test_loader, model, criterion)
    log = train_log + test_log
    print(log)
    is_best = acc > best_acc
    best_acc = max(acc, best_acc)
    if is_best:
        # Saved directly as best_model.pth.tar; save_checkpoint's is_best copy
        # is disabled (False) since it would copy the file onto itself.
        save_checkpoint({'epoch':epoch,
        'state_dict':model.state_dict(),
        'acc': acc,
        }, False, os.path.join(ckpt_dir, 'best_model.pth.tar'))

Epoch:000 Time:119.481 Loss: 2.1024 Test Acc@1: 32.690
Epoch:001 Time:120.603 Loss: 1.8866 Test Acc@1: 50.470
Epoch:002 Time:119.001 Loss: 1.7622 Test Acc@1: 56.830
Epoch:003 Time:117.712 Loss: 1.6800 Test Acc@1: 59.980
Epoch:004 Time:117.717 Loss: 1.6073 Test Acc@1: 60.210
Epoch:005 Time:118.680 Loss: 1.5544 Test Acc@1: 66.940
Epoch:006 Time:117.109 Loss: 1.5052 Test Acc@1: 66.130
Epoch:007 Time:115.651 Loss: 1.4903 Test Acc@1: 70.520
Epoch:008 Time:115.897 Loss: 1.4607 Test Acc@1: 73.250
Epoch:009 Time:115.572 Loss: 1.4349 Test Acc@1: 75.310
