In this exercise, we are going to implement a [ResNet-like](https://arxiv.org/pdf/1512.03385.pdf) architecture for the image classification task.
The model is trained on the [MNIST](https://en.wikipedia.org/wiki/MNIST_database) dataset.

Tasks:

1. Implement residual connections in the missing places in the code.
2. Check that the given implementation reaches 97% test accuracy after a few epochs.
3. Check that when extending the residual blocks to 20 (having 40+ layers total), the model still trains well, i.e., achieves 97+% accuracy after three epochs.

We recommend switching to GPU, after initial testing on CPU.
With 20 residual blocks, one training epoch takes ~15 min on Colab CPU and ~30s on Colab GPU.
Remember to "disconnect and delete runtime" when you finish using it.

In [1]:
from typing import TypedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import v2
from tqdm import tqdm

In [2]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.conv_block_1 = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=1,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
        self.conv_block_2 = nn.Sequential(
            nn.Conv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=1,
            ),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Input: shape (B, in_channels, H, W).
        Output: shape (B, out_channels, H, W).
        """
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        return x

In [3]:
class Net(nn.Module):
    def __init__(self, num_residual_blocks: int) -> None:
        super().__init__()
        self.backbone = nn.Sequential(
            ResidualBlock(1, 16),
            *(ResidualBlock(16, 16) for _ in range(num_residual_blocks - 1)),
        )
        self.head = nn.Linear(28 * 28 * 16, 10)

    def forward(self, x: Tensor) -> Tensor:
        """
        Input shape: shape (B, 1, 28, 28).
        Output shape: log probabilities, shape (B, 10).
        """
        x = self.backbone(x)  # shape (B, 16, 28, 28)
        x = nn.Flatten(start_dim=1)(x)  # shape (B, 28 * 28 * 16)
        x = self.head(x)  # shape (B, 10)
        output = nn.LogSoftmax(dim=1)(x)
        return output

In [4]:
def train(
    model: nn.Module,
    device: torch.device,
    train_loader: DataLoader,
    optimizer: optim.Optimizer,
    epoch: int,
    log_interval: int,
) -> None:
    model.train()
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch}")
    for batch_idx, (data, target) in progress_bar:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            progress_bar.set_postfix(loss=loss.item())


def test(model: nn.Module, device: torch.device, test_loader: DataLoader) -> None:
    model.eval()
    test_loss = 0.0
    correct = 0
    test_set_size = 0
    with torch.no_grad():
        for data, target in test_loader:
            test_set_size += data.shape[0]
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction="sum").item()
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()

    test_loss /= test_set_size

    print(f"Test loss: {test_loss:.4f}, accuracy: {correct}/{test_set_size} ({correct / test_set_size:.1%})")

In [5]:
batch_size = 256
test_batch_size = 1000
epochs = 3
lr = 1e-2
seed = 1
log_interval = 10

# Check for CUDA / MPS (Apple) / XPU (Intel) / ... accelerator.
# This does not detect XLA devices (Google TPUs), they'd need separate checks.
device = torch.accelerator.current_accelerator(check_available=True) or torch.device("cpu")
use_accel = device != torch.device("cpu")
print(use_accel, device)

True cuda


In [6]:
class DataloaderArgs(TypedDict, total=False):
    batch_size: int
    shuffle: bool
    num_workers: int
    pin_memory: bool

train_kwargs: DataloaderArgs = {"batch_size": batch_size, "num_workers": 2, "shuffle": True, "pin_memory": use_accel}
test_kwargs: DataloaderArgs = {"batch_size": test_batch_size, "num_workers": 2, "pin_memory": use_accel}

In [7]:
torch.manual_seed(seed)

transform = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize((0.1307,), (0.3081,)),
    ]
)
train_dataset = datasets.MNIST(
    "../data", train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST("../data", train=False, transform=transform)
train_loader = DataLoader(train_dataset, **train_kwargs)
test_loader = DataLoader(test_dataset, **test_kwargs)

100%|██████████| 9.91M/9.91M [00:07<00:00, 1.27MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 194kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.45MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 10.4MB/s]


In [10]:
model = Net(num_residual_blocks=10).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, epoch, log_interval)
    test(model, device, test_loader)

Epoch 1: 100%|██████████| 235/235 [00:07<00:00, 29.96it/s, loss=0.209]


Test loss: 0.4166, accuracy: 9261/10000 (92.6%)


Epoch 2: 100%|██████████| 235/235 [00:07<00:00, 29.38it/s, loss=0.0704]


Test loss: 0.1278, accuracy: 9681/10000 (96.8%)


Epoch 3: 100%|██████████| 235/235 [00:08<00:00, 29.18it/s, loss=0.0236]


Test loss: 0.0868, accuracy: 9727/10000 (97.3%)
