# CNNs (Convolutional Neural Networks) — Motivation → Details → MNIST Implementation

Lecture flow:
1. **Why CNNs?** (motivation)
2. **Convolution layer** (filters, channels, feature maps)
3. **Stride & padding** + output size examples
4. **Pooling**
5. **Train a CNN on MNIST**


## 1) Motivation: why CNN instead of a fully-connected MLP for images?

Images have **spatial structure**:
- Nearby pixels are related (edges, corners, textures).
- The same pattern can appear anywhere (translation).

A fully-connected layer ignores locality, and its size grows with the image:
- A $28\times 28$ grayscale image has $784$ inputs.
- A hidden layer of size $512$ would need $784\cdot 512 = 401{,}408$ weights (plus $512$ biases).

CNNs rely on three ideas:
1. **Local connectivity**: each output neuron looks only at a small window (the kernel).
2. **Weight sharing**: the same kernel slides across the whole image → far fewer parameters.
3. **Translation equivariance** (a consequence of weight sharing): shifting the input shifts the feature map by the same amount, up to border effects.
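A minimal sketch of translation equivariance, using a 1D cross-correlation in plain Python for brevity. The function `correlate1d` and the two-tap kernel are illustrative choices, not PyTorch code. Shifting the input by one step shifts the output by one step, except at the border.

```python
# 1D cross-correlation (what deep-learning "convolution" layers compute),
# no padding, stride 1.
def correlate1d(signal, kernel):
    k = len(kernel)
    return [
        sum(signal[i + j] * kernel[j] for j in range(k))
        for i in range(len(signal) - k + 1)
    ]


kernel = [1, -1]                 # tiny hypothetical "edge" detector
x = [0, 0, 1, 2, 0, 0, 0, 0]
x_shifted = [0] + x[:-1]         # same pattern moved one step to the right

y = correlate1d(x, kernel)
y_shifted = correlate1d(x_shifted, kernel)

print(y)          # [0, -1, -1, 2, 0, 0, 0]
print(y_shifted)  # [0, 0, -1, -1, 2, 0, 0]

# Away from the border, the new output is the old output shifted by one step.
print(y_shifted[1:] == y[:-1])  # True
```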

Parameter example:
- One $3\times 3$ kernel on a 1-channel image has $3\cdot 3\cdot 1 = 9$ weights (+ 1 bias),
  regardless of whether the image is $28\times 28$ or $256\times 256$.
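Both counts can be checked with two small helpers (`linear_params` and `conv2d_params` are hypothetical names used only for illustration). A convolution layer stores one $(C_{in}, k_H, k_W)$ filter plus one bias per output channel, so its count depends only on $C_{in}$, $C_{out}$ and the kernel size, never on $H$ or $W$.

```python
def linear_params(n_in, n_out):
    # one weight per (input, output) pair + one bias per output
    return n_in * n_out + n_out


def conv2d_params(c_in, c_out, k_h, k_w):
    # one (c_in, k_h, k_w) filter + one bias per output channel
    return c_out * (c_in * k_h * k_w) + c_out


fc = linear_params(28 * 28, 512)
conv = conv2d_params(1, 1, 3, 3)
print(fc)                          # 401920
print(conv)                        # 10
print(conv2d_params(1, 32, 3, 3))  # 320 (32 filters, still independent of image size)
```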


## 2) What a convolution layer does (structure)

A 2D convolution layer takes input of shape:

$$
(N, C_{in}, H_{in}, W_{in})
$$

and produces output of shape:

$$
(N, C_{out}, H_{out}, W_{out})
$$

- $C_{in}$ = input channels (MNIST has $1$ channel)
- $C_{out}$ = number of filters (feature maps)
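- Each filter has shape $(C_{in}, k_H, k_W)$ and slides across the image; at each position it sums its products over all $C_{in}$ channels (plus a bias), producing one output feature map.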

In PyTorch, `nn.Conv2d` exposes `kernel_size`, `stride`, `padding`, `dilation`, and `groups` as constructor arguments.
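To make the shapes concrete, here is a deliberately naive pure-Python sketch of one convolution layer. The name `conv2d_naive` and the nested-list layout are assumptions for illustration; PyTorch uses optimized kernels, and dilation and groups are omitted here. Like `nn.Conv2d`, it computes a cross-correlation (the kernel is not flipped) and, for each filter, sums over all input channels.

```python
def conv2d_naive(x, weight, bias, stride=1, padding=0):
    """x: [C_in][H][W], weight: [C_out][C_in][kH][kW], bias: [C_out].
    Returns a nested list of shape [C_out][H_out][W_out]."""
    c_in, h_in, w_in = len(x), len(x[0]), len(x[0][0])
    k_h, k_w = len(weight[0][0]), len(weight[0][0][0])

    # Zero-pad every input channel by `padding` pixels on each side.
    hp, wp = h_in + 2 * padding, w_in + 2 * padding
    xp = [[[0.0] * wp for _ in range(hp)] for _ in range(c_in)]
    for c in range(c_in):
        for i in range(h_in):
            for j in range(w_in):
                xp[c][i + padding][j + padding] = x[c][i][j]

    # Output size: the d = 1 case of the Conv2d output-size formula.
    h_out = (hp - k_h) // stride + 1
    w_out = (wp - k_w) // stride + 1

    out = []
    for o, filt in enumerate(weight):  # one filter -> one output feature map
        fmap = [[0.0] * w_out for _ in range(h_out)]
        for i in range(h_out):
            for j in range(w_out):
                acc = bias[o]
                for c in range(c_in):  # sum over ALL input channels
                    for u in range(k_h):
                        for v in range(k_w):
                            acc += filt[c][u][v] * xp[c][i * stride + u][j * stride + v]
                fmap[i][j] = acc
        out.append(fmap)
    return out


# One 4x4 input channel holding 0..15, and two 3x3 filters (C_in=1, C_out=2).
x = [[[float(4 * r + col) for col in range(4)] for r in range(4)]]
ones = [[1.0] * 3 for _ in range(3)]                           # sums the window
center = [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]  # copies the center pixel
weight = [[ones], [center]]
bias = [0.0, 0.0]

y = conv2d_naive(x, weight, bias)
print(y)  # [[[45.0, 54.0], [81.0, 90.0]], [[5.0, 6.0], [9.0, 10.0]]]
print(len(conv2d_naive(x, weight, bias, padding=1)[0]))  # 4 (padding=1 keeps 4x4)
```

Two filters give $C_{out}=2$ feature maps, each $2\times 2$ because a $3\times 3$ window fits in a $4\times 4$ input in only $2\times 2$ positions.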


## 3) Stride & padding (output size)

PyTorch output-size formula for `Conv2d`:

$$
H_{out} = \left\lfloor \frac{H_{in} + 2p_H - d_H (k_H - 1) - 1}{s_H} + 1 \right\rfloor
$$

$$
W_{out} = \left\lfloor \frac{W_{in} + 2p_W - d_W (k_W - 1) - 1}{s_W} + 1 \right\rfloor
$$
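To see what dilation does: a kernel of size $k$ with dilation $d$ covers $d(k-1)+1$ input pixels. A worked instance with illustrative values $H_{in}=28$, $k=3$, $d=2$, $p=0$, $s=1$:

$$
H_{out} = \left\lfloor \frac{28 + 2\cdot 0 - 2\,(3 - 1) - 1}{1} + 1 \right\rfloor = 24
$$

This matches an ordinary $5\times 5$ kernel ($28 - 5 + 1 = 24$), since $2(3-1)+1 = 5$.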

Common special case (no dilation, $d=1$):

$$
H_{out} = \left\lfloor \frac{H_{in} + 2p - k}{s} + 1 \right\rfloor
$$

`padding="same"` keeps the output height and width equal to the input's, but PyTorch only supports it with stride $1$.
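A small sketch of why "same" is tied to stride 1, assuming symmetric padding. `same_padding` is a hypothetical helper, not a PyTorch function. Setting $H_{out}=H_{in}$ with $s=1$ in the formula gives $2p = d(k-1)$, so $p = d(k-1)/2$ per side; with stride 2 the same padding roughly halves the size instead.

```python
def out_size(n_in, k, s=1, p=0, d=1):
    # Conv2d output-size formula for one dimension (// is floor division).
    return (n_in + 2 * p - d * (k - 1) - 1) // s + 1


def same_padding(k, d=1):
    # With stride 1, output == input requires total padding d*(k-1),
    # split evenly across both sides (only possible when d*(k-1) is even).
    total = d * (k - 1)
    if total % 2:
        raise ValueError("d*(k-1) is odd: no symmetric 'same' padding")
    return total // 2


for k in (1, 3, 5, 7):
    p = same_padding(k)
    print(f"k={k}: p={p}, out={out_size(28, k, s=1, p=p)}")  # out is always 28

# The same padding with stride 2 does not preserve the size:
print(out_size(28, 3, s=2, p=same_padding(3)))  # 14
```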


In [1]:
# Helper to compute conv output size (same formula as the docs)
import math

def conv2d_out(h_in, w_in, k=3, s=1, p=0, d=1):
    h_out = math.floor((h_in + 2*p - d*(k-1) - 1)/s + 1)
    w_out = math.floor((w_in + 2*p - d*(k-1) - 1)/s + 1)
    return h_out, w_out

print("MNIST 28x28, k=3")
print("  p=0, s=1:", conv2d_out(28, 28, k=3, s=1, p=0))  # (26,26)
print("  p=1, s=1:", conv2d_out(28, 28, k=3, s=1, p=1))  # (28,28)
print("  p=1, s=2:", conv2d_out(28, 28, k=3, s=2, p=1))  # (14,14)


MNIST 28x28, k=3
  p=0, s=1: (26, 26)
  p=1, s=1: (28, 28)
  p=1, s=2: (14, 14)


## 4) Pooling (downsampling)

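Pooling reduces the spatial size (H, W) of each feature map, and with it the computation in later layers.
Max pooling is the most common choice: it keeps the strongest activation in each window.

PyTorch `nn.MaxPool2d` uses the same output-size formula as `Conv2d` (with the default `ceil_mode=False`), and its `stride` defaults to `kernel_size`.
Why use pooling?
- smaller feature maps → less computation, and fewer weights in any fully-connected layer that follows (pooling itself has no learnable parameters)
- some robustness to small shifts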
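A plain-Python sketch of 2×2 max pooling with stride 2. `maxpool2d` is a hypothetical helper; like `nn.MaxPool2d`, its stride defaults to the kernel size, while padding and dilation are omitted. It also shows the limited shift robustness: moving the largest value by one pixel *within* its pooling window leaves the output unchanged, whereas a shift across a window boundary generally would not.

```python
def maxpool2d(fmap, k=2, s=None):
    s = k if s is None else s  # stride defaults to the kernel size
    h_out = (len(fmap) - k) // s + 1
    w_out = (len(fmap[0]) - k) // s + 1
    return [
        [max(fmap[i * s + u][j * s + v] for u in range(k) for v in range(k))
         for j in range(w_out)]
        for i in range(h_out)
    ]


fmap = [
    [1, 3, 2, 0],
    [4, 0, 1, 5],
    [0, 2, 9, 1],
    [6, 1, 3, 2],
]
print(maxpool2d(fmap))  # [[4, 5], [6, 9]]

shifted = [row[:] for row in fmap]
shifted[2][2], shifted[2][3] = shifted[2][3], shifted[2][2]  # move the 9 one pixel right
print(maxpool2d(shifted))  # [[4, 5], [6, 9]] (unchanged: the 9 stayed in its window)
```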



## 5) A simple CNN for MNIST (dimension tracking)

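- Input: $(N, 1, 28, 28)$
- Conv(1→32, k=3, p=1): $(N, 32, 28, 28)$
- MaxPool(2×2, s=2) + ReLU: $(N, 32, 14, 14)$
- Conv(32→64, k=3, p=1): $(N, 64, 14, 14)$
- MaxPool(2×2, s=2) + ReLU: $(N, 64, 7, 7)$
- Flatten: $(N, 64\cdot 7\cdot 7) = (N, 3136)$
- Linear(3136→64) + ReLU: $(N, 64)$
- Linear(64→10): $(N, 10)$ logits

In the code below, ReLU comes after each max-pool. Because ReLU is non-decreasing, taking the max and then applying ReLU gives the same result as the more common Conv → ReLU → MaxPool order.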
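The shapes in the list follow from the output-size formula, and the parameter count from the weights-plus-bias rule. A plain-Python sketch (the helper `out_size` is hypothetical; dilation $d=1$ throughout):

```python
def out_size(n, k, s=1, p=0):
    # d = 1 special case of the Conv2d / MaxPool2d output-size formula
    return (n + 2 * p - k) // s + 1


h = 28
h = out_size(h, k=3, p=1)   # conv1 -> 28
h = out_size(h, k=2, s=2)   # pool1 -> 14
h = out_size(h, k=3, p=1)   # conv2 -> 14
h = out_size(h, k=2, s=2)   # pool2 -> 7
flat = 64 * h * h           # flatten -> 3136
print(h, flat)              # 7 3136

# ReLU and Flatten have no parameters.
params = {
    "conv1": 32 * (1 * 3 * 3) + 32,   # 320
    "conv2": 64 * (32 * 3 * 3) + 64,  # 18496
    "fc1": flat * 64 + 64,            # 200768
    "fc2": 64 * 10 + 10,              # 650
}
print(sum(params.values()))  # 220234
```

With PyTorch installed, `sum(p.numel() for p in model.parameters())` on the model below should report the same total; most of it sits in the first linear layer, which is why the two pooling steps matter.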


In [11]:
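# Construct the model
import torch
import torch.nn as nn

# Use the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # Feature extractor: two (conv -> max-pool -> ReLU) stages.
        # ReLU after max-pooling gives the same output as ReLU before it (ReLU is monotonic).
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),   # (N, 1, 28, 28) -> (N, 32, 28, 28)
            nn.MaxPool2d(kernel_size=2, stride=2),        # -> (N, 32, 14, 14)
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),  # -> (N, 64, 14, 14)
            nn.MaxPool2d(kernel_size=2, stride=2),        # -> (N, 64, 7, 7)
            nn.ReLU(),
        )
        # Classifier head: flatten the 64 x 7 x 7 feature maps, then two linear layers
        self.classifier = nn.Sequential(
            nn.Flatten(),                # -> (N, 3136)
            nn.Linear(64 * 7 * 7, 64),
            nn.ReLU(),
            nn.Linear(64, 10),           # -> (N, 10) raw logits (CrossEntropyLoss applies log-softmax)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x


model = SimpleCNN().to(device)
print(model)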


SimpleCNN(
  (features): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): ReLU()
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): ReLU()
  )
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=3136, out_features=64, bias=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=10, bias=True)
  )
)


In [12]:
# load the training data and train
import pandas as pd
from torch.utils.data import Dataset, DataLoader

class MNISTCSV(Dataset):
    def __init__(self, path):
        data = pd.read_csv(path)

        self.y = torch.tensor(data.iloc[:, 0].values, dtype=torch.long)
        self.X = torch.tensor(data.iloc[:, 1:].values, dtype=torch.float32)

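        # Reshape each flat 784-pixel row to (1, 28, 28), so X has shape (N, 1, 28, 28),
        # and scale pixel values from [0, 255] to [0, 1].
        self.X = self.X.view(-1, 1, 28, 28) / 255.0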

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

dataset = MNISTCSV("/content/sample_data/mnist_train_small.csv")

train_loader = DataLoader(dataset, batch_size=64, shuffle=True)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epochs = 5

for epoch in range(epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        # forward
        outputs = model(x)
        loss = criterion(outputs, y)

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

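        # Running training accuracy: the predicted class is the index of the largest logit.
        # It is measured while the weights are still changing, so it is not a test-set score.
        predicted = outputs.argmax(dim=1)
        total += y.size(0)
        correct += (predicted == y).sum().item()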

    print(f"Epoch [{epoch+1}/{epochs}] "
          f"Loss: {total_loss/len(train_loader):.4f} "
          f"Accuracy: {100*correct/total:.2f}%")


Epoch [1/5] Loss: 0.3999 Accuracy: 87.44%
Epoch [2/5] Loss: 0.0917 Accuracy: 97.17%
Epoch [3/5] Loss: 0.0617 Accuracy: 98.07%
Epoch [4/5] Loss: 0.0463 Accuracy: 98.50%
Epoch [5/5] Loss: 0.0346 Accuracy: 98.87%
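The `criterion` in the training cell is `nn.CrossEntropyLoss`, which takes raw logits and integer class labels. Per sample it computes $-\log \mathrm{softmax}(z)_y = \log\sum_j e^{z_j} - z_y$, and with the default `reduction="mean"` it averages over the batch. A plain-Python sketch of that computation (the helper names are hypothetical; PyTorch's options such as class weights, label smoothing and `ignore_index` are omitted):

```python
import math


def cross_entropy(logits, target):
    # Numerically stable log-sum-exp: subtract the max logit before exponentiating.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]  # = -log softmax(logits)[target]


def batch_loss(batch_logits, targets):
    # Mean over the batch, like the default reduction="mean".
    return sum(cross_entropy(l, t) for l, t in zip(batch_logits, targets)) / len(targets)


logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
targets = [0, 2]
print(round(batch_loss(logits, targets), 4))
```

Uniform logits give a loss of $\log C$ (about $2.30$ for the 10 MNIST classes), a useful sanity check for the first batches of training.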
