## Implementing a (simple) Convolutional Neural Network

In this notebook we implement a simple CNN architecture.
The focus lies on understanding the key components of the network, such as Conv2d layers, max pooling and (most importantly) how all the dimensions play out.

We test our architecture on scikit-learn's digits dataset, a small MNIST-like collection of 8x8 handwritten digits (as if there was any other option).

```python
# load the digits data set
from sklearn import datasets

digits = datasets.load_digits()
```
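
A quick look at the shapes explains the reshape to `(1, 1, 8, 8)` used further down: the digits dataset holds 1797 grayscale images of 8x8 pixels, each pixel an integer intensity between 0 and 16, plus one integer label per image. A minimal check, assuming scikit-learn and NumPy are installed:

```python
from sklearn import datasets

digits = datasets.load_digits()

# images: [n_samples, height, width]; target: one integer label (0-9) per image
print(digits.images.shape)   # (1797, 8, 8)
print(digits.target.shape)   # (1797,)

# pixel intensities should lie within 0..16
print(digits.images.min(), digits.images.max())
```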

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
```

```python
class CNN(nn.Module):

    def __init__(self):
        super().__init__()
        # conv1: 1 input channel (grayscale), 3 output channels, 2x2 kernel
        self.conv1 = nn.Conv2d(1, 3, 2)  # [N, 1, 8, 8] -> [N, 3, 7, 7]
        # conv2: 3 input channels, 6 output channels, 1x1 kernel
        self.conv2 = nn.Conv2d(3, 6, 1)  # [N, 3, 3, 3] -> [N, 6, 3, 3]

        # affine operations: y = Wx + b
        self.fc1 = nn.Linear(6 * 3 * 3, 120)  # 6 channels * 3 * 3 spatial positions after conv2
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)  # one output per digit class

    def forward(self, x):
        # conv1 + ReLU, then max pooling over a (2, 2) window: [N, 3, 7, 7] -> [N, 3, 3, 3]
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # conv2 + ReLU, no pooling here: [N, 3, 3, 3] -> [N, 6, 3, 3]
        x = F.relu(self.conv2(x))
        x = torch.flatten(x, 1)  # flatten all dimensions except the batch dimension -> [N, 54]
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
```
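
To see where the `6 * 3 * 3` in `fc1` comes from, the spatial size after each layer can be computed by hand. For a convolution or pooling window with kernel size k, stride s and padding p (dilation 1), an input of side length n produces an output of side length floor((n + 2p - k) / s) + 1. `Conv2d` defaults to s = 1 and p = 0, while `max_pool2d` defaults its stride to the kernel size. The helper `out_size` below is a hypothetical sketch that applies this formula to the layers of this network:

```python
def out_size(n, kernel, stride=1, padding=0):
    """Spatial output size of a conv/pool layer (dilation 1, floor rounding)."""
    return (n + 2 * padding - kernel) // stride + 1

n = 8                                  # input images are 8x8
n = out_size(n, kernel=2)              # conv1, 2x2 kernel, stride 1 -> 7
n = out_size(n, kernel=2, stride=2)    # max pooling 2x2, stride 2 -> 3
n = out_size(n, kernel=1)              # conv2, 1x1 kernel -> 3

channels = 6                           # conv2 has 6 output channels
flat_features = channels * n * n       # input size of fc1
print(n, flat_features)                # 3 54
```

Note that the 7x7 feature map after `conv1` shrinks to 3x3 (not 3.5x3.5) because pooling rounds down and drops the last row and column.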

### Look at model architecture and test if it works

```python
model = CNN()
print(model)
```

```
CNN(
  (conv1): Conv2d(1, 3, kernel_size=(2, 2), stride=(1, 1))
  (conv2): Conv2d(3, 6, kernel_size=(1, 1), stride=(1, 1))
  (fc1): Linear(in_features=54, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
```
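
The printed summary also fixes the number of trainable parameters: a `Conv2d(c_in, c_out, k)` holds c_out * c_in * k * k weights plus c_out biases, and a `Linear(f_in, f_out)` holds f_out * f_in weights plus f_out biases. A small sketch that rebuilds the same layer configuration (as a standalone `nn.ModuleDict`, an assumption made for self-containment, not the `CNN` class itself) and counts the parameters with `numel()`:

```python
import torch.nn as nn

# same layer configuration as printed above, rebuilt standalone for this check
layers = nn.ModuleDict({
    "conv1": nn.Conv2d(1, 3, 2),
    "conv2": nn.Conv2d(3, 6, 1),
    "fc1": nn.Linear(54, 120),
    "fc2": nn.Linear(120, 84),
    "fc3": nn.Linear(84, 10),
})

# number of weights + biases per layer
counts = {name: sum(p.numel() for p in layer.parameters()) for name, layer in layers.items()}
total = sum(counts.values())
print(counts)  # {'conv1': 15, 'conv2': 24, 'fc1': 6600, 'fc2': 10164, 'fc3': 850}
print(total)   # 17653
```

The convolutional layers contribute only 39 of these parameters; almost all of them sit in the fully connected layers. For the notebook's own model, `sum(p.numel() for p in model.parameters())` should give the same total.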


```python
image1 = digits.images[0]
image1 = image1.reshape(1, 1, 8, 8)  # [batch, channels, height, width]
image1t = torch.from_numpy(image1).float()
model(image1t)
```

```
tensor([[-0.0182, -0.2177, -0.4765,  0.2735,  0.0639,  0.2117,  0.0493, -0.2719,
         -0.0214, -0.3175]], grad_fn=<AddmmBackward>)
```
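
Since the focus of this notebook is how the dimensions play out, it can help to push a dummy batch through the same layers one step at a time and record the shape after each operation. A minimal sketch, assuming a random batch of 4 inputs instead of real digits and recreating the two convolutional layers outside the `CNN` class:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
conv1 = nn.Conv2d(1, 3, 2)
conv2 = nn.Conv2d(3, 6, 1)

x = torch.rand(4, 1, 8, 8)  # a batch of 4 random "images": [N, C, H, W]
shapes = [("input", tuple(x.shape))]

x = F.relu(conv1(x))
shapes.append(("conv1 + relu", tuple(x.shape)))   # (4, 3, 7, 7)
x = F.max_pool2d(x, (2, 2))
shapes.append(("max_pool2d", tuple(x.shape)))     # (4, 3, 3, 3)
x = F.relu(conv2(x))
shapes.append(("conv2 + relu", tuple(x.shape)))   # (4, 6, 3, 3)
x = torch.flatten(x, 1)
shapes.append(("flatten", tuple(x.shape)))        # (4, 54)

for name, shape in shapes:
    print(f"{name:>14}: {shape}")
```

The batch dimension (4) passes through every layer untouched, which is why the single image above had to be reshaped to `(1, 1, 8, 8)` rather than `(8, 8)`.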

## Training the Network

```python
import torch.optim as optim

# Simple SGD optimizer with a static learning rate, no weight decay, no momentum
optimizer = optim.SGD(model.parameters(), lr=0.05)

# mean-squared-error loss between the 10 network outputs and a one-hot target vector
mse = nn.MSELoss()
```
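
`MSELoss` works here because the targets are encoded as one-hot vectors. For classification, a more common choice in PyTorch is `nn.CrossEntropyLoss`, which takes the raw (unnormalized) outputs of a layer like `fc3` together with integer class labels and applies log-softmax internally. A sketch of that swapped-in alternative on random dummy outputs (not what this notebook uses):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(5, 10)              # raw network outputs for a batch of 5 images
labels = torch.tensor([0, 3, 9, 1, 1])   # integer class labels, no one-hot encoding needed

ce = nn.CrossEntropyLoss()
loss = ce(logits, labels)

# the same value by hand: mean negative log-probability of the true class
manual = -torch.log_softmax(logits, dim=1)[torch.arange(5), labels].mean()
print(loss.item(), manual.item())
```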

```python
from rich.progress import track

for epoch in track(range(10), 'Training ...'):

    for idx in range(digits.images.shape[0]):

        # transform the image to a [1, 1, 8, 8] tensor
        image = digits.images[idx]
        image_t = torch.from_numpy(image.reshape(1, 1, 8, 8)).float()

        # build the one-hot target vector: set the entry of the true class to 1
        target = torch.zeros((1, 10))
        target[0, digits.target[idx]] = 1

        # reset gradients and calculate the prediction
        optimizer.zero_grad()
        prediction = model(image_t)

        # calculate the loss and take one SGD step per image
        loss = mse(prediction, target)
        loss.backward()
        optimizer.step()

    print(f"Finished {epoch=}")
```
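
The one-hot target built by hand inside the loop can also be produced with `torch.nn.functional.one_hot`. It returns an integer (int64) tensor, so it has to be cast to float before `MSELoss` can compare it with the network output. A small sketch with made-up labels:

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([0, 7, 3])  # hypothetical digit labels

# manual version, as in the training loop above
manual = torch.zeros((3, 10))
for i, label in enumerate(labels):
    manual[i, label] = 1

# vectorized version; one_hot returns int64, so cast to float for MSELoss
vectorized = F.one_hot(labels, num_classes=10).float()

print(torch.equal(manual, vectorized))  # True
```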

## Batch training

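Because training in batches is such an important concept in PyTorch (layers and modules treat the first dimension of their input as the batch dimension), I have included an example of how to train in batches.

When executed, one can see just how much faster this runs: the 100 epochs in the cell below finish faster than the 10 epochs of the cell before. Keep in mind that this is full-batch gradient descent. All 1797 images form a single batch, so each epoch performs exactly one parameter update (100 updates in total), whereas the per-image loop above performs 1797 updates per epoch (17,970 in total).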

```python
from rich.progress import track

# one-hot targets for all images: [1797, 10]
targets = torch.zeros(digits.images.shape[0], 10)
for idx in range(targets.shape[0]):
    targets[idx, digits.target[idx]] = 1

# batch all images into a single [1797, 1, 8, 8] tensor (done once, outside the loop)
images_t = torch.from_numpy(digits.images).float()
images_t = images_t.view(-1, 1, 8, 8)

for epoch in track(range(100), 'Training ...'):

    # predict for all images at once (full-batch gradient descent)
    optimizer.zero_grad()
    predictions = model(images_t)

    # one loss value averaged over the whole batch, one SGD step per epoch
    loss = mse(predictions, targets)
    loss.backward()
    optimizer.step()
```
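
Between the two extremes above (one image per step, and all 1797 images per step) sits mini-batch training: each epoch shuffles the data and takes one optimizer step per slice of, say, 64 images. A minimal self-contained sketch, assuming random stand-in data shaped like the digits and a small hypothetical model in place of `CNN`; the batch size of 64 is an arbitrary choice:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# stand-in data with the same shapes as the digits set (hypothetical, not the real images)
images_t = torch.rand(1797, 1, 8, 8)
labels = torch.randint(0, 10, (1797,))
targets = nn.functional.one_hot(labels, num_classes=10).float()

# hypothetical small model; any module mapping [N, 1, 8, 8] -> [N, 10] works here
model = nn.Sequential(nn.Flatten(), nn.Linear(64, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
mse = nn.MSELoss()

batch_size = 64
steps = 0
for epoch in range(5):
    perm = torch.randperm(images_t.shape[0])   # shuffle once per epoch
    for start in range(0, images_t.shape[0], batch_size):
        idx = perm[start:start + batch_size]   # the last batch may be smaller
        optimizer.zero_grad()
        loss = mse(model(images_t[idx]), targets[idx])
        loss.backward()
        optimizer.step()
        steps += 1

print(steps)  # 5 epochs * ceil(1797 / 64) = 5 * 29 = 145
```

This shuffling and slicing is what `torch.utils.data.DataLoader` automates with its `batch_size` and `shuffle` arguments.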

## Final Notes

There are things missing here that will be covered in other documents, such as:
- DataLoaders
- Cross Validation
- Validation Plots
- Better Optimizers
- Hyperparameter Tuning
- ...

Thanks for reading!