Code based on https://github.com/pytorch/examples/blob/master/mnist/main.py

In this exercise, we are going to implement a [ResNet-like](https://arxiv.org/pdf/1512.03385.pdf) architecture for the image classification task.
The model is trained on the [MNIST](https://en.wikipedia.org/wiki/MNIST_database) dataset.

Tasks:

    1. Implement residual connections in the missing places in the code.

    2. Check that the given implementation reaches 97% test accuracy after a few epochs.

    3. Check that when extending the residual blocks to 20 (having 40+ layers total), the model still trains well, i.e., achieves 97+% accuracy after three epochs.


Note: in this lab scenario we are using mypy for typing. You can disable easily by not running the cell below.
Typing in python is not mandatory, but if the types are natural, it can lead to less debugging, especially
that types can be checked statically without running the code (typically done even within IDE).


In [1]:
!pip install nb-mypy -qqq
%load_ext nb_mypy

Version 1.0.5


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# There is no typing for torchvision yet.
from torchvision import datasets, transforms  # type: ignore
from torch.utils.data import DataLoader
from typing_extensions import TypedDict
from tqdm import tqdm

In [3]:
class ResidualConnection(nn.Module):
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.conv_block_1 = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=1,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
        self.conv_block_2 = nn.Sequential(
            nn.Conv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=1,
            ),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # TODO: implement forward function
        y = self.conv_block_2(self.conv_block_1(x))
        z = x + y
        return nn.functional.relu(z)

In [10]:
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.rc = nn.Sequential(
            ResidualConnection(1, 16),
            # TODO: verify that after increasing 3 to 19 still trains
            *(ResidualConnection(16, 16) for _ in range(19)),
        )
        self.fc = nn.Linear(
            28 * 28 * 16, 10
        )  # 28 * 28 * 16 is the size of flattened output of the last ResidualConnection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.rc(x)
        x = nn.Flatten(start_dim=1)(x)
        x = self.fc(x)
        output = nn.LogSoftmax(dim=1)(x)
        return output

In [5]:
def train(
    model: nn.Module,
    device: torch.device,
    train_loader: DataLoader,
    optimizer: optim.Optimizer,
    epoch: int,
    log_interval: int,
) -> None:
    model.train()
    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch}")
    for batch_idx, (data, target) in pbar:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            pbar.set_postfix(loss=loss.item())


def test(model: nn.Module, device: torch.device, test_loader: DataLoader) -> None:
    model.eval()
    test_loss = 0.0
    correct = 0
    test_set_size = 0
    with torch.no_grad():
        for data, target in test_loader:
            test_set_size += data.shape[0]
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(
                output, target, reduction="sum"
            ).item()  # sum up batch loss
            pred = output.argmax(
                dim=1, keepdim=True
            )  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= test_set_size

    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss,
            correct,
            test_set_size,
            100.0 * correct / test_set_size,
        )
    )

In [6]:
batch_size = 256
test_batch_size = 1000
epochs = 3
lr = 1e-2
seed = 1
log_interval = 10

In [7]:
use_cuda = torch.cuda.is_available()

torch.manual_seed(seed)
device = torch.device("cuda" if use_cuda else "cpu")

DataloaderArgs = TypedDict(
    "DataloaderArgs",
    {"batch_size": int, "shuffle": bool, "num_workers": int, "pin_memory": bool},
    total=False,
)

train_kwargs: DataloaderArgs = {"batch_size": batch_size}
test_kwargs: DataloaderArgs = {"batch_size": test_batch_size}
if use_cuda:
    cuda_kwargs: DataloaderArgs = {
        "num_workers": 1,
        "pin_memory": True,
        "shuffle": True,
    }
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

In [8]:
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform)
dataset2 = datasets.MNIST("../data", train=False, transform=transform)
train_loader = DataLoader(dataset1, **train_kwargs)
test_loader = DataLoader(dataset2, **test_kwargs)

In [11]:
model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, epoch, log_interval)
    test(model, device, test_loader)

Epoch 1: 100%|██████████| 235/235 [02:25<00:00,  1.61it/s, loss=0.0684]



Test set: Average loss: 1.0374, Accuracy: 8902/10000 (89%)



Epoch 2: 100%|██████████| 235/235 [02:25<00:00,  1.62it/s, loss=0.357] 



Test set: Average loss: 0.8623, Accuracy: 8641/10000 (86%)



Epoch 3: 100%|██████████| 235/235 [02:26<00:00,  1.60it/s, loss=0.0113]



Test set: Average loss: 0.1524, Accuracy: 9572/10000 (96%)

