# Exercise 4. Convolutional networks

## Part 3. ResNet

In the third part you need to train a convolutional neural network with a ResNet architecture.

In [None]:
# Set this flag to True before validation and submission: the notebook then
# loads the saved model instead of training a new one.
skip_training = False

In [None]:
# During evaluation, this cell sets skip_training to True
# skip_training = True

In [None]:
# Select data directory: prefer the shared course directory, then a local
# ../data directory next to the notebook.
import os
if os.path.isdir('/coursedata'):
    course_data_dir = '/coursedata'
elif os.path.isdir('../data'):
    course_data_dir = '../data'
else:
    # Specify course_data_dir on your machine, e.g.
    # course_data_dir = '/path/to/data'
    raise NotImplementedError(
        "No data directory found (tried '/coursedata' and '../data'). "
        "Set course_data_dir to the directory that contains 'fashion_mnist'.")

print('The data directory is %s' % course_data_dir)

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
# Select the device for training: use the first GPU when CUDA is available,
# otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# The models are always evaluated on CPU, so override the device chosen above
# whenever training is skipped.
if skip_training:
    device = torch.device('cpu')

## FashionMNIST dataset

Let us use the FashionMNIST dataset. It consists of 60,000 training images and 10,000 test images from 10 classes: 'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'.

In [None]:
# Preprocessing: ToTensor gives float tensors with values in [0, 1];
# Normalize then applies (x - 0.5) / 0.5, mapping them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

data_dir = os.path.join(course_data_dir, 'fashion_mnist')
print('Data stored in %s' % data_dir)

# Training split first, then the test split (downloaded if missing).
trainset, testset = (
    torchvision.datasets.FashionMNIST(
        root=data_dir, train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
           'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# Shuffle only the training data; the test order stays fixed.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=5, shuffle=False)

In [None]:
# This function computes the accuracy on the test dataset
def compute_accuracy(net, testloader, device=None):
    """Return the classification accuracy of `net` on the batches of `testloader`.

    Args:
      net (nn.Module):    Classifier whose output holds one score per class
                          along dimension 1.
      testloader (iterable): Yields (images, labels) batches, e.g. a DataLoader.
      device (torch.device, optional): Device to evaluate on. Defaults to the
                          device of the network's parameters (CPU if it has none).

    Returns:
      float: Fraction of correctly classified samples.

    Raises:
      ZeroDivisionError: If `testloader` yields no samples.

    The network is run in eval mode without gradient tracking. Its previous
    train/eval mode is restored afterwards, so calling this between training
    epochs does not leave batch normalization in eval mode for the next epoch.
    """
    if device is None:
        # Evaluate wherever the weights live so inputs and weights always match.
        first_param = next(net.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = net.training
    net.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in testloader:
                images, labels = images.to(device), labels.to(device)
                # Predicted class = index of the highest score
                predicted = net(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        net.train(was_training)
    return correct / total

## ResNet

Let us train a network with an architecture inspired by [ResNet](https://arxiv.org/pdf/1512.03385.pdf).

### ResNet block
Our ResNet consists of blocks with two convolutional layers and a skip connection.

In the most general case, our implementation should have:

<img src="resnet_block_04.png" width=220 style="float: right;">

* Two convolutional layers with:
    * 3x3 kernel
    * no bias terms
    * padding with one pixel on both sides
    * 2d batch normalization after each convolutional layer.

* **The first convolutional layer also (optionally) has:**
    * different number of input channels and output channels
    * change of the resolution with stride.

* The skip connection:
    * simply copies the input if the resolution and the number of channels do not change.
    * If either the resolution or the number of channels change, the skip connection should have one convolutional layer with:
        * 1x1 convolution **without bias**
        * change of the resolution with stride (optional)
        * different number of input channels and output channels (optional)
    * If either the resolution or the number of channels change, the 1x1 convolutional layer is followed by 2d batch normalization.

* The ReLU nonlinearity is applied after the first convolutional layer and at the end of the block.

**Note: Batch normalization is expected to be right after a convolutional layer.**

<img src="resnet_blocks_123.png" width=650 style="float: top;">

The implementation should also handle specific cases such as:

Left: The number of channels and the resolution do not change.
There are no computations in the skip connection.

Middle: The number of channels changes, the resolution does not change.

Right: The number of channels does not change, the resolution changes.

Your task is to implement this block. You should use the implementations of layers in `nn.Conv2d`, `nn.BatchNorm2d` as the tests rely on those implementations.

In [None]:
class Block(nn.Module):
    """Basic residual block: two 3x3 convolutions with batch norm plus a skip connection.

    Main path:  conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN
    Skip path:  identity if neither the channels nor the resolution change,
                otherwise conv1x1(stride) -> BN
    Output:     ReLU(main path + skip path)
    """
    def __init__(self, in_channels, out_channels, stride=1):
        """
        Args:
          in_channels (int):  Number of input channels.
          out_channels (int): Number of output channels.
          stride (int):       Controls the stride.
        """
        super(Block, self).__init__()
        # First convolution may change the number of channels and, via the
        # stride, the resolution. No bias: batch norm follows immediately.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # Second convolution keeps both the channels and the resolution.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        if stride != 1 or in_channels != out_channels:
            # Project the input to the output shape so it can be added to the main path.
            self.skip = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            # An empty nn.Sequential returns its input unchanged (identity skip).
            self.skip = nn.Sequential()

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Add the skip connection, then apply the final nonlinearity.
        out = out + self.skip(x)
        return F.relu(out)

In [None]:
# Test your implementation of the Block
batch_size = 20
x = torch.zeros(batch_size, 16, 28, 28)

# Each case: the Block constructor arguments and the expected output shape.
shape_cases = [
    # The number of channels and resolution do not change
    (dict(in_channels=16, out_channels=16), [batch_size, 16, 28, 28]),
    # Increase the number of channels
    (dict(in_channels=16, out_channels=32), [batch_size, 32, 28, 28]),
    # Decrease the resolution
    (dict(in_channels=16, out_channels=16, stride=2), [batch_size, 16, 14, 14]),
    # Increase the number of channels and decrease the resolution
    (dict(in_channels=16, out_channels=32, stride=2), [batch_size, 32, 14, 14]),
]
for block_kwargs, expected_shape in shape_cases:
    block = Block(**block_kwargs)
    y = block(x)
    assert y.shape == torch.Size(expected_shape), "Bad shape of y: y.shape={}".format(y.shape)

print('The shapes seem to be ok.')

In [None]:
# This is a cell used for grading

In [None]:
# This is a cell used for grading

### Group of blocks

ResNet consists of several groups of blocks. The first block in a group may change the number of channels (often multiplying it by 2) and subsample the feature maps (using a stride).

<img src="resnet_group.png" width=500 style="float: left;">

In [None]:
# Let us implement a group of blocks in this cell
class GroupOfBlocks(nn.Module):
    """A chain of residual blocks applied one after another.

    Only the first block may change the number of channels
    (in_channels -> out_channels) and the resolution (stride). Every remaining
    block maps out_channels -> out_channels with stride 1. The group always
    contains at least the first block.
    """
    def __init__(self, in_channels, out_channels, n_blocks, stride=1):
        super(GroupOfBlocks, self).__init__()

        blocks = [Block(in_channels, out_channels, stride)]
        blocks.extend(Block(out_channels, out_channels) for _ in range(n_blocks - 1))
        self.group = nn.Sequential(*blocks)

    def forward(self, x):
        # Run the input through every block of the group in order.
        return self.group(x)

In [None]:
# Let's test it: a group of three blocks that maps 10 input channels to 20,
# printed to inspect its structure
group = GroupOfBlocks(10, 20, 3)
print(group)

### ResNet

Let us implement a ResNet with the following architecture. It contains three groups of blocks, each with two basic blocks.

<img src="resnet.png" width=900 style="float: left;">

The cell below contains the implementation of all the layers except for the groups of blocks. Your task is to insert the groups of blocks in the middle.

Note:
* The number of channels in the second **group** should be double the number of channels in the first **group**, the number of channels in the third **group** should be four times the number of channels in the first **group**.
* The second and the third **group** should reduce the resolution by `stride=2`, as shown in the figure.

In [None]:
class ResNet(nn.Module):
    """ResNet-style classifier for single-channel images.

    conv5x5 -> BN -> ReLU -> max-pool(stride 2) -> three groups of residual
    blocks -> 4x4 average pooling -> fully-connected layer.
    The second and third groups double the number of channels and halve the
    resolution, so the last group outputs 4 * n_channels channels.
    With 28x28 inputs (FashionMNIST) the spatial size goes
    28 -> 14 (max-pool) -> 14 -> 7 -> 4 (groups) -> 1 (average pooling).
    """
    def __init__(self, n_blocks, n_channels=64, num_classes=10):
        """
        Args:
          n_blocks (list):   A list with three elements which contains the number of blocks in 
                             each of the three groups of blocks in ResNet.
                             For instance, n_blocks = [2, 4, 6] means that the first group has two blocks,
                             the second group has four blocks and the third one has six blocks.
          n_channels (int):  Number of channels in the first group of blocks.
          num_classes (int): Number of classes.
        """
        assert len(n_blocks) == 3, "The number of groups should be three."
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=n_channels, kernel_size=5, stride=1, padding=2, bias=False)
        self.bn1 = nn.BatchNorm2d(n_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Groups of residual blocks: the first keeps the shape, the second and
        # third double the number of channels and halve the resolution.
        self.group1 = GroupOfBlocks(n_channels, n_channels, n_blocks[0])
        self.group2 = GroupOfBlocks(n_channels, 2*n_channels, n_blocks[1], stride=2)
        self.group3 = GroupOfBlocks(2*n_channels, 4*n_channels, n_blocks[2], stride=2)

        self.avgpool = nn.AvgPool2d(kernel_size=4, stride=1)
        self.fc = nn.Linear(4*n_channels, num_classes)

        # Initialize weights: He-normal for convolutions (std sqrt(2 / fan_out)),
        # identity affine transform for batch normalization.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, np.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x, verbose=False):
        if verbose: print(x.shape)
        x = self.conv1(x)
        if verbose: print('conv1:  ', x.shape)
        x = self.bn1(x)
        if verbose: print('bn1:    ', x.shape)
        x = self.relu(x)
        if verbose: print('relu:   ', x.shape)
        x = self.maxpool(x)
        if verbose: print('maxpool:', x.shape)

        x = self.group1(x)
        if verbose: print('group1: ', x.shape)
        x = self.group2(x)
        if verbose: print('group2: ', x.shape)
        x = self.group3(x)
        if verbose: print('group3: ', x.shape)

        x = self.avgpool(x)
        if verbose: print('avgpool:', x.shape)

        # Flatten per sample. If the pooled map is not 1x1 (input not 28x28),
        # the fc layer then raises a shape error instead of spatial positions
        # being silently folded into the batch dimension.
        x = x.view(x.size(0), -1)
        if verbose: print('x.view: ', x.shape)
        x = self.fc(x)
        if verbose: print('out:    ', x.shape)

        return x

In [None]:
# Create a network with 2 blocks in each of the three groups
n_blocks = [2, 2, 2]  # number of blocks in the three groups
net = ResNet(n_blocks, n_channels=10)
net.to(device)

# Feed a batch of images from the training data to test the network
with torch.no_grad():
    dataiter = iter(trainloader)
    # Use the built-in next(): DataLoader iterators no longer provide a
    # .next() method in recent PyTorch versions.
    images, labels = next(dataiter)
    images = images.to(device)
    print('Shape of the input tensor:', images.shape)

    y = net(images, verbose=True)
    print(y.shape)
    # Expect one row of 10 class scores per image in the batch.
    expected_shape = torch.Size([images.shape[0], 10])
    assert y.shape == expected_shape, "Bad shape of y: y.shape={}".format(y.shape)

print('The shapes seem to be ok.')

In [None]:
# Let us display the architecture of the network
# (the notebook renders the value of the cell's last expression)
net

In [None]:
# Let us now train the ResNet using the same training loop:
# a fresh network with two blocks per group and 16 channels in the first group.
n_blocks = [2, 2, 2]  # number of blocks in the three groups
net = ResNet(n_blocks, n_channels=16).to(device)

# Cross-entropy loss on the class scores, optimized with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.01)

In [None]:
# Number of full passes over the training set
n_epochs = 10

In [None]:
# Train for n_epochs epochs, printing the average loss every `print_every`
# mini-batches and the test accuracy after every epoch. With skip_training set,
# only a single mini-batch is processed and no accuracy is computed.
print_every = 200  # mini-batches
for epoch in range(n_epochs):
    # compute_accuracy() may leave the network in evaluation mode, so training
    # mode is (re-)enabled at the start of every epoch; otherwise later epochs
    # would train with batch normalization using its running statistics.
    net.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        # Transfer the mini-batch to the selected device
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Print statistics
        running_loss += loss.item()
        if (i + 1) % print_every == 0:
            print('[%d, %5d] loss: %.3f' % (epoch+1, i+1, running_loss/print_every))
            running_loss = 0.0

        if skip_training:
            break
    if skip_training:
        break

    # Print accuracy after every epoch
    accuracy = compute_accuracy(net, testloader)
    print('Accuracy of the network on the test images: %.3f' % accuracy)

print('Finished Training')

You should get a test accuracy of about 90–91%.

In [None]:
# Save the network to a file, submit this file together with your notebook
filename = '4_resnet.pth'
if not skip_training:
    try:
        do_save = input('Do you want to save the model (type yes to confirm)? ').strip().lower()
    except Exception as err:
        # Only the prompt is guarded. Presumably input() fails when the notebook
        # is executed without an interactive stdin (e.g. during validation).
        # Errors from torch.save below are reported as they are.
        raise Exception('The notebook should be run or validated with skip_training=True.') from err
    if do_save == 'yes':
        torch.save(net.state_dict(), filename)
        print('Model saved to %s' % filename)
    else:
        print('Model not saved')
else:
    # Rebuild the trained architecture and load the saved weights;
    # map_location keeps the tensors on the CPU regardless of where they were saved.
    net = ResNet(n_blocks, n_channels=16)
    net.load_state_dict(torch.load(filename, map_location=lambda storage, loc: storage))
    net.to(device)
    print('Model loaded from %s' % filename)

In [None]:
# Evaluate the trained (or loaded) network on the test set
accuracy = compute_accuracy(net, testloader)
print('Accuracy of the network on the test images: %.3f' % accuracy)