# Residual network (ResNet-V2) activity

In this activity, you will build a residual convolutional neural network (ResNet-V2) in the PyTorch deep learning package to classify hand-written digits in the classic MNIST dataset.

In [None]:
import torch  # Main torch import for torch tensors
import torch.nn as nn  # Neural network module for building deep learning models
import torch.nn.functional as F  # Functional module, includes activation functions
import torch.optim as optim  # Optimization module
import torchvision  # Vision / image processing package built on top of torch

from matplotlib import pyplot as plt  # Plotting and visualization
from sklearn.metrics import accuracy_score  # Computing accuracy metric

In [None]:
import os

# Run on the GPU when one is available, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(DEVICE)

# Directory where torchvision stores (and, if needed, downloads) MNIST.
# The default is machine-specific. Set the MNIST_DATA_PATH environment variable
# to override it without editing the notebook (on Colab, e.g. "/content/datafiles").
DATA_PATH = os.environ.get("MNIST_DATA_PATH", "/Users/trevoryu/Code/data/mnist/")

# 1. Setup the data

In [None]:
# Standardize the inputs (roughly zero mean, unit variance), the usual preprocessing
# for neural networks. ToTensor yields pixel values in [0, 1]. Normalize then computes
# (x - 0.1307) / 0.3081; these are the commonly quoted MNIST mean/std.
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),  # PyTorch networks consume torch.Tensor inputs
        torchvision.transforms.Normalize(mean=0.1307, std=0.3081),
    ]
)

# Build the train split, then the test split, downloading if necessary.
# The transform is applied lazily, each time a sample is indexed.
train_data, test_data = (
    torchvision.datasets.MNIST(DATA_PATH, train=split_is_train, download=True, transform=transform)
    for split_is_train in (True, False)
)

# Iterating a loader yields (inputs, targets) batches:
# inputs has shape (B, 1, 28, 28) and targets has shape (B,).
# Only the training loader reshuffles each epoch.
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=1000)

In [None]:
fig, axs = plt.subplots(4, 5, figsize=(5, 6))

plot_images = []
plot_labels = []

# Show 20 test examples (indices 1000-1019) and keep them for the prediction plot later.
for i, ax in enumerate(axs.flatten(), start=1000):
    image, label = test_data[i]

    # Save this data for later
    plot_images.append(image)
    plot_labels.append(label)

    # Plot each image; squeeze() drops the singleton channel dim for imshow
    ax.imshow(image.squeeze(), cmap="gray")
    ax.set_title(f"Label: {label}")
    ax.axis("off")
plt.show()

# Stack the (1, 28, 28) images into a single (20, 1, 28, 28) batch for later.
# torch.cat would instead join them along the channel dim into (20, 28, 28).
# The model's Conv2d(1, ...) input layer would read that as one unbatched
# 20-channel image and raise a channel-mismatch error.
plot_images = torch.stack(plot_images)

print(f"Each image is a torch.Tensor and has shape {image.shape}.")
print("The labels are the integers 0 to 9, representing the digits.")

# 2a. Define the ResNetV2 Block

ResNet-V1: [Deep Residual Learning for Image Recognition (He et al., 2015)](https://arxiv.org/pdf/1512.03385.pdf)

ResNet-V2: [Identity Mappings in Deep Residual Networks (He et al., 2016)](https://arxiv.org/pdf/1603.05027.pdf)
- BN-ReLU-Conv-BN-ReLU-Conv
- First conv does stride (spatial downsampling) and channel changes
- Input: (B, C_in, H, W)
- Output: (B, C_out, ceil(H/stride), ceil(W/stride)), i.e. (B, C_out, H/stride, W/stride) when H and W are divisible by the stride

Relevant documentation:

- [PyTorch nn.Conv2d](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html)

- [PyTorch nn.BatchNorm2d](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html#batchnorm2d)

- [PyTorch nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html#torch.nn.Linear)

In [None]:
class ResidualBlock(nn.Module):
    """Pre-activation (ResNet-V2) residual block: BN-ReLU-Conv-BN-ReLU-Conv plus a shortcut.

    The first 3x3 convolution maps ``in_channels -> channels`` and applies ``stride``.
    The second 3x3 convolution keeps ``channels``. The shortcut is parameter-free:
    the input is subsampled by ``stride`` and zero-padded with the extra channels.

    Input shape ``(B, in_channels, H, W)``; output shape
    ``(B, channels, ceil(H / stride), ceil(W / stride))``.

    :param in_channels: Number of channels of the input tensor
    :param channels: Number of output channels. Must be >= ``in_channels``, because
        the shortcut can only append zero channels, not remove any.
    :kwarg stride: Spatial stride of the first convolution and of the shortcut
    :raises ValueError: If ``channels < in_channels``
    """

    def __init__(self, in_channels, channels, stride=1):
        super().__init__()
        if channels < in_channels:
            raise ValueError(
                f"channels ({channels}) must be >= in_channels ({in_channels}): "
                "the zero-padded shortcut cannot reduce the channel count"
            )
        conv_kwargs = {
            "kernel_size": (3, 3),
            "padding": 1,  # a 3x3 kernel with padding=1 keeps H, W unchanged at stride 1
            "bias": False,
        }
        self.stride = stride
        self.in_channels = in_channels
        self.channels = channels
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU()
        # This conv is in_channels -> channels and applies stride
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=channels, stride=stride, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(channels)
        # This conv is channels -> channels
        self.conv2 = nn.Conv2d(in_channels=channels, out_channels=channels, **conv_kwargs)

    def strided_identity(self, x):
        """Parameter-free shortcut whose shape matches the conv pathway's output.

        Keeps every ``stride``-th row and column (starting at 0). That yields
        ``ceil(H / stride)`` rows, the same size a 3x3, padding-1 conv with that
        stride produces. The channel dimension is then zero-padded up to ``channels``.
        """
        if self.stride != 1:
            # Slicing instead of F.interpolate(scale_factor=1/stride): interpolate floors
            # the output size. Odd inputs (e.g. 7x7 at stride 2 -> 3x3) would then not
            # match the conv output (4x4), and the residual addition would fail.
            # For even sizes both pick the same pixels.
            x = x[:, :, ::self.stride, ::self.stride]
        num_pad_channels = self.channels - self.in_channels
        if num_pad_channels > 0:
            (b, _, h, w) = x.shape
            # new_zeros matches x's dtype and device, so the concat also works for
            # non-float32 inputs.
            pad = x.new_zeros((b, num_pad_channels, h, w))
            # Append the zero channels after the (downsampled) identity channels
            x = torch.cat((x, pad), dim=1)
        return x

    def forward(self, x):
        """Return ``strided_identity(x) + conv_path(x)`` using pre-activation ordering."""
        # Shortcut branch
        identity = self.strided_identity(x)
        # Residual branch: BN-ReLU-Conv, then BN-ReLU-Conv
        z = self.conv1(self.relu(self.bn1(x)))
        z = self.conv2(self.relu(self.bn2(z)))
        # Add the shortcut back in
        return identity + z
      


In [None]:
# Inputs are of shape (B, 1, 28, 28) and we expect an output tensor of shape (B, 32, 28, 28)
block = ResidualBlock(1, 32)
x = torch.randn(2, 1, 28, 28)
z = block(x)
# Verify the expectation instead of only displaying the shape
assert z.shape == (2, 32, 28, 28), f"unexpected output shape {tuple(z.shape)}"
z.shape

In [None]:
# Inputs are of shape (B, 32, 28, 28) and we expect an output tensor of shape (B, 64, 14, 14)
block = ResidualBlock(32, 64, stride=2)
x = torch.randn(2, 32, 28, 28)
z = block(x)
# Verify the expectation instead of only displaying the shape
assert z.shape == (2, 64, 14, 14), f"unexpected output shape {tuple(z.shape)}"
z.shape

# 2b. Define CNN architecture

Input layer:
- Conv(1, 32)
- BatchNorm

Processing layers:
1. ResNetBlock(32, 32)
2. ResNetBlock(32, 32)
3. ResNetBlock(32, 64, s=2)
4. ResNetBlock(64, 64)
5. ResNetBlock(64, 128, s=2)
6. ResNetBlock(128, 128)

Output layers:
- AdaptiveAvgPool2d to 1×1 (global average pooling), giving one 128-dimensional feature vector per image
- Linear(128, 10)


In [None]:
class ResNetV2(nn.Module):
    """Small pre-activation ResNet (ResNet-V2) classifier.

    Architecture:
    - Conv3x3(in_channels -> 32) followed by BatchNorm.
    - Six ResidualBlocks: 32, 32, 32->64 (stride 2), 64, 64->128 (stride 2), 128.
    - Adaptive average pooling to 1x1.
    - Linear(128 -> num_classes) head producing logits.

    :kwarg in_channels: Number of input image channels (1 for MNIST)
    :kwarg in_shape: Expected (H, W) of the input; stored as an attribute only,
        not used by ``forward``
    :kwarg num_classes: Number of output logits (10 digit classes by default)
    """

    def __init__(self, in_channels=1, in_shape=(28, 28), num_classes=10):
        super().__init__()
        self.in_channels = in_channels
        self.in_shape = in_shape
        self.num_classes = num_classes
        # Input layers
        self.input_conv = nn.Conv2d(in_channels, 32, kernel_size=(3, 3), padding=1)
        self.input_bn = nn.BatchNorm2d(32)
        # Processing blocks
        self.layer_1 = ResidualBlock(32, 32)
        self.layer_2 = ResidualBlock(32, 32)
        self.layer_3 = ResidualBlock(32, 64, stride=2)
        self.layer_4 = ResidualBlock(64, 64)
        self.layer_5 = ResidualBlock(64, 128, stride=2)
        self.layer_6 = ResidualBlock(128, 128)
        # Output layers
        self.pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.output_layer = nn.Linear(128, num_classes)

    def forward(self, x):
        """
        :param x: Tensor of shape (B, in_channels, H, W), e.g. (B, 1, 28, 28) for MNIST
        :returns: Logits tensor of shape (B, num_classes)
        """
        # Input layers
        x = self.input_conv(x)
        x = self.input_bn(x)
        # Processing blocks
        x = self.layer_1(x)
        x = self.layer_2(x)
        x = self.layer_3(x)
        x = self.layer_4(x)
        x = self.layer_5(x)
        x = self.layer_6(x)
        # Output layers
        x = self.pool(x)  # (B, 128, 1, 1)
        # flatten from dim 1 keeps the batch dim. squeeze() would also drop it when
        # B == 1, producing shape (num_classes,) instead of (1, num_classes).
        x = torch.flatten(x, start_dim=1)  # (B, 128)
        x = self.output_layer(x)
        return x

In [None]:
# Sanity check: a batch of 2 MNIST-sized images should produce 2 rows of 10 logits
model = ResNetV2()
inputs = torch.randn(2, 1, 28, 28)
outputs = model(inputs)
assert outputs.shape == (2, 10), f"unexpected output shape {tuple(outputs.shape)}"
outputs.shape

# 3. Setup optimizer and loss function

The current standard optimizer in deep learning is the Adam optimizer. Use a learning rate of $1\times 10^{-3}$. 

The task we are performing is multiclass classification: each image belongs to exactly one of 10 mutually exclusive classes, one for each digit. The loss function to use for this task is cross entropy loss.

Relevant documentation:
- [PyTorch optimizers](https://pytorch.org/docs/stable/optim.html)

- [PyTorch loss functions](https://pytorch.org/docs/stable/nn.html#loss-functions)

In [None]:
# Hyperparameters
NUM_EPOCHS = 5
LEARNING_RATE = 1e-3

# Model on the selected device, cross-entropy loss for multiclass classification,
# and an Adam optimizer over all model parameters
model = ResNetV2().to(DEVICE)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# 4. Setup training loop

During the training loop, we perform the following steps:

1. Fetch the next batch of inputs and targets from the dataloader
2. Zero the parameter gradients
3. Compute the model output predictions from the inputs
4. Compute the loss between the model outputs and the targets
5. Compute the parameter gradients with backpropagation
6. Perform a gradient descent step with the optimizer to update the model parameters

Relevant documentation:
- [PyTorch optimization step](https://pytorch.org/docs/stable/optim.html#taking-an-optimization-step)

In [None]:
def train(model, train_loader, loss_fn, optimizer, device="cpu", epoch=-1):
    """
    Run one training epoch, i.e. a single full pass over ``train_loader``.

    :param model: PyTorch model to optimize; it is moved to ``device`` and put in training mode
    :param train_loader: PyTorch DataLoader yielding (inputs, targets) batches
    :param loss_fn: PyTorch loss function, called as ``loss_fn(outputs, targets)``
    :param optimizer: PyTorch optimizer, initialized with the model's parameters
    :kwarg device: Device to run on, e.g. "cpu" or "cuda"
    :kwarg epoch: Zero-based epoch index; printed as ``epoch + 1``
    :returns: Accuracy score over all training batches of the epoch
    """
    model = model.to(device)
    model.train()  # training mode (affects layers such as BatchNorm)

    loss_sum = 0
    predicted, expected = [], []
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        # One optimization step: clear grads, forward pass, loss, backprop, update
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()

        # Accumulate statistics; the highest-scoring output is the predicted class
        loss_sum += loss.item()
        predicted += outputs.cpu().argmax(dim=-1).tolist()
        expected += targets.tolist()

        # Every 100 batches (including the first), report the mean batch loss so far
        if batch_idx % 100 == 0:
            print(f"Epoch {epoch + 1}, batch {batch_idx + 1}: loss = {loss_sum / (batch_idx + 1):.2f}")

    acc = accuracy_score(expected, predicted)
    print(f"Epoch {epoch + 1} done. Average train loss = {loss_sum / len(train_loader):.2f}, average train accuracy = {acc * 100:.3f}%")
    return acc

In testing, we don't need to compute gradients or do an optimization step.

In [None]:
def test(model, test_loader, loss_fn, device="cpu", epoch=-1):
    """
    Evaluate a model over one full pass of ``test_loader``.

    Note:
        Evaluation needs only the forward pass and the loss (training-loop steps 1, 3 and 4).
        Nothing is optimized, so steps 2, 5 and 6 are skipped. The whole pass runs under
        ``torch.no_grad()`` so autograd does not record the computation.

    :param model: PyTorch model; it is moved to ``device`` and put in evaluation mode
    :param test_loader: PyTorch DataLoader yielding (inputs, targets) batches
    :param loss_fn: PyTorch loss function, called as ``loss_fn(outputs, targets)``
    :kwarg device: Device to run on, e.g. "cpu" or "cuda"
    :kwarg epoch: Zero-based epoch index; printed as ``epoch + 1``

    :returns: Accuracy score over all test batches
    """
    model = model.to(device)
    model.eval()  # evaluation mode (affects layers such as BatchNorm)

    loss_sum = 0
    predicted, expected = [], []
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss_sum += loss_fn(outputs, targets).item()
            # The highest-scoring output is the predicted class
            predicted += outputs.cpu().argmax(dim=-1).tolist()
            expected += targets.tolist()

    acc = accuracy_score(expected, predicted)
    print(f"Epoch {epoch + 1} done. Average test loss = {loss_sum / len(test_loader):.2f}, average test accuracy = {acc * 100:.3f}%")
    return acc

# 6. Train the model

In [None]:
# Alternate one training epoch with one evaluation pass, recording both accuracies
train_metrics, test_metrics = [], []
for epoch in range(NUM_EPOCHS):
    train_acc = train(model, train_loader, loss_fn, optimizer, device=DEVICE, epoch=epoch)
    train_metrics.append(train_acc)

    test_acc = test(model, test_loader, loss_fn, device=DEVICE, epoch=epoch)
    test_metrics.append(test_acc)

# 5. Visually compare the model predictions

We will lastly see the trained model's predictions on the 20 examples we visualized in the beginning.

In [None]:
# Predict the 20 saved examples on the CPU, in evaluation mode, without autograd
model.eval()
model = model.to("cpu")

with torch.no_grad():
    plot_outputs = model(plot_images)
    plot_preds = torch.argmax(plot_outputs, dim=-1)

# Show every image with its predicted and true label
fig, axs = plt.subplots(4, 5, figsize=(7, 8))

for ax, image, label, pred in zip(axs.flatten(), plot_images, plot_labels, plot_preds):
    ax.imshow(image.squeeze(), cmap="gray")
    ax.set_title(f"Prediction: {pred}\nLabel: {label}")
    ax.axis("off")
plt.show()

In [None]:
# Plot per-epoch train and test accuracy.
# The x-values come from the recorded metrics rather than NUM_EPOCHS, so the plot
# still works if training was interrupted or NUM_EPOCHS was changed afterwards.
# Each series gets its own x-range because an interrupt during test() leaves
# test_metrics one entry shorter than train_metrics.
xs = list(range(1, len(train_metrics) + 1))
test_xs = list(range(1, len(test_metrics) + 1))
plt.plot(xs, train_metrics, "o-", label="Train accuracy")
plt.plot(test_xs, test_metrics, "o-", label="Test accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy score")
plt.legend()
plt.show()