# CNNs (Convolutional Neural Networks) — Motivation → Details → MNIST Implementation

Lecture flow:
1. **Why CNNs?** (motivation)
2. **Convolution layer** (filters, channels, feature maps)
3. **Stride & padding** + output size examples
4. **Pooling**
5. **Train a CNN on MNIST**


## 1) Motivation: why CNN instead of a fully-connected MLP for images?

Images have **spatial structure**:
- Nearby pixels are related (edges, corners, textures).
- The same pattern can appear anywhere (translation).

A fully-connected layer ignores locality and becomes huge:
- A $28\times 28$ grayscale image has $784$ inputs.
- A hidden layer of size $512$ would need $784\cdot 512 = 401{,}408$ weights (plus $512$ biases).

CNNs use three ideas:
1. **Local connectivity**: each neuron looks at a small window (kernel).
2. **Weight sharing**: the same kernel slides across the image → far fewer parameters.
3. **Translation equivariance**: shifting the input shifts the feature map.
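Translation equivariance can be checked numerically. The sketch below is illustrative only: `cross_correlate2d` and `shift_right` are hypothetical helpers written in plain Python (they are not PyTorch functions), and the toy image and edge kernel are made up for the demonstration. It checks that filtering then shifting gives the same feature map as shifting then filtering, as long as the pattern stays away from the border.

```python
def cross_correlate2d(image, kernel):
    """Slide `kernel` over `image` (stride 1, no padding) and return the feature map."""
    k_h, k_w = len(kernel), len(kernel[0])
    h_out = len(image) - k_h + 1
    w_out = len(image[0]) - k_w + 1
    return [
        [
            sum(image[i + a][j + b] * kernel[a][b]
                for a in range(k_h) for b in range(k_w))
            for j in range(w_out)
        ]
        for i in range(h_out)
    ]


def shift_right(grid, n):
    """Shift every row n positions to the right, filling the left with zeros."""
    return [[0] * n + row[:len(row) - n] for row in grid]


# Toy 5x7 image with a single bright pixel, and a vertical-edge kernel.
image = [[0] * 7 for _ in range(5)]
image[2][2] = 1
kernel = [[1, 0, -1],
          [1, 0, -1],
          [1, 0, -1]]

filter_then_shift = shift_right(cross_correlate2d(image, kernel), 2)
shift_then_filter = cross_correlate2d(shift_right(image, 2), kernel)
print(filter_then_shift == shift_then_filter)  # True
```

Near the image border the equality can break, because part of the pattern is cut off by the edge; this is why equivariance is usually stated for the interior of the feature map.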

Parameter example:
- One $3\times 3$ kernel on a 1-channel image has $3\cdot 3\cdot 1 = 9$ weights (+ bias),
  no matter if the image is $28\times 28$ or $256\times 256$.
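The parameter counts above can be reproduced with simple arithmetic. This is a small sketch; `dense_params` and `conv_params` are hypothetical helper names introduced here, and square kernels are assumed.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully-connected layer."""
    return n_in * n_out + n_out


def conv_params(c_in, c_out, k):
    """Weights plus biases of a conv layer with square k x k kernels."""
    return c_out * (c_in * k * k) + c_out


print(dense_params(28 * 28, 512))    # 401920
print(dense_params(256 * 256, 512))  # 33554944
print(conv_params(1, 1, 3))          # 10, independent of image size
print(conv_params(1, 32, 3))         # 320
```

The dense layer grows with the number of pixels, while the conv layer depends only on kernel size and channel counts.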


## 2) What a convolution layer does (structure)

A 2D convolution layer takes input of shape:

$$
(N, C_{in}, H_{in}, W_{in})
$$

and produces output of shape:

$$
(N, C_{out}, H_{out}, W_{out})
$$

- $C_{in}$ = input channels (MNIST has $1$ channel)
- $C_{out}$ = number of filters (feature maps)
- Each filter has shape $(C_{in}, k_H, k_W)$ and slides across the image.

In PyTorch, `nn.Conv2d` exposes `kernel_size`, `stride`, `padding`, `dilation`, and `groups` as constructor arguments.
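To make the shapes concrete, here is a deliberately naive sketch of what one convolution layer computes for a single sample (stride 1, no padding, no dilation). `conv2d_single` is a hypothetical helper using nested lists, not how PyTorch implements the operation; the all-ones input and weights are made-up values chosen so the result is easy to check by hand.

```python
def conv2d_single(x, weight, bias):
    """x: (C_in, H, W); weight: (C_out, C_in, kH, kW); bias: (C_out,)."""
    c_in, h, w = len(x), len(x[0]), len(x[0][0])
    k_h, k_w = len(weight[0][0]), len(weight[0][0][0])
    h_out, w_out = h - k_h + 1, w - k_w + 1  # stride 1, no padding
    out = []
    for f, filt in enumerate(weight):  # one feature map per filter
        fmap = [[bias[f]] * w_out for _ in range(h_out)]
        for i in range(h_out):
            for j in range(w_out):
                # Each output value sums over ALL input channels and the window.
                for c in range(c_in):
                    for a in range(k_h):
                        for b in range(k_w):
                            fmap[i][j] += x[c][i + a][j + b] * filt[c][a][b]
        out.append(fmap)
    return out


# 2 input channels of 4x4 ones, 3 filters of shape (2, 3, 3) filled with ones.
x = [[[1] * 4 for _ in range(4)] for _ in range(2)]
weight = [[[[1] * 3 for _ in range(3)] for _ in range(2)] for _ in range(3)]
bias = [0, 1, 2]

out = conv2d_single(x, weight, bias)
print((len(out), len(out[0]), len(out[0][0])))  # (3, 2, 2)
print([fmap[0][0] for fmap in out])             # [18, 19, 20]
```

Each of the 3 filters covers both input channels ($2\cdot 3\cdot 3 = 18$ weights), so $C_{out}$ filters produce $C_{out}$ feature maps.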


## 3) Stride & padding (output size)

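PyTorch output-size formula for `Conv2d`, with kernel size $k$, stride $s$, padding $p$, and dilation $d$ (subscripts $H$ and $W$ denote the values along height and width):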

$$
H_{out} = \left\lfloor \frac{H_{in} + 2p_H - d_H (k_H - 1) - 1}{s_H} + 1 \right\rfloor
$$

$$
W_{out} = \left\lfloor \frac{W_{in} + 2p_W - d_W (k_W - 1) - 1}{s_W} + 1 \right\rfloor
$$

Common special case (no dilation, $d=1$):

$$
H_{out} = \left\lfloor \frac{H_{in} + 2p - k}{s} + 1 \right\rfloor
$$

`padding="same"` pads the input so that the output keeps the input's height and width (the channel count still changes from $C_{in}$ to $C_{out}$). PyTorch supports it only for stride $=1$.
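For stride $1$ the formula reduces to $H_{out} = H_{in} + P - d(k-1)$, where $P$ is the total padding along that axis, so keeping the size requires $P = d(k-1)$. The sketch below works through that arithmetic. `same_padding` and `out_size` are hypothetical helpers, and the way an odd total is split between the two sides is an assumption made for illustration, not a statement about PyTorch internals.

```python
def same_padding(k, d=1):
    """Split the total padding d*(k-1) into (before, after) amounts for one axis."""
    total = d * (k - 1)
    before = total // 2
    return before, total - before


def out_size(n_in, k, pad_before, pad_after, d=1):
    """Stride-1 output length along one axis with explicit padding on each side."""
    return n_in + pad_before + pad_after - d * (k - 1)


for k, d in [(3, 1), (5, 1), (3, 2), (4, 1)]:
    before, after = same_padding(k, d)
    print(k, d, (before, after), out_size(28, k, before, after, d))
```

For odd kernels the padding is symmetric ($p = d(k-1)/2$); an even kernel such as $k=4$ needs an uneven split.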


In [None]:
# Helper to compute conv output size (same formula as the docs)
import math

def conv2d_out(h_in, w_in, k=3, s=1, p=0, d=1):
    h_out = math.floor((h_in + 2*p - d*(k-1) - 1)/s + 1)
    w_out = math.floor((w_in + 2*p - d*(k-1) - 1)/s + 1)
    return h_out, w_out

print("MNIST 28x28, k=3")
print("  p=0, s=1:", conv2d_out(28, 28, k=3, s=1, p=0))  # (26,26)
print("  p=1, s=1:", conv2d_out(28, 28, k=3, s=1, p=1))  # (28,28)
print("  p=1, s=2:", conv2d_out(28, 28, k=3, s=2, p=1))  # (14,14)


## 4) Pooling (downsampling)

Pooling reduces spatial size (H, W) and computation.
Max pooling is common: it keeps the strongest activation in each window.

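PyTorch `nn.MaxPool2d` uses the same output-size formula as `Conv2d`; its `stride` defaults to `kernel_size`, so a $2\times 2$ pool halves $H$ and $W$ (rounding down).

Why use pooling?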
- smaller feature maps → faster + fewer parameters later
- some robustness to small shifts
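A minimal sketch of $2\times 2$ max pooling on a single feature map, written in plain Python for illustration. `max_pool2d` here is a hypothetical helper (no padding, no dilation), not the PyTorch function, and the input values are made up.

```python
def max_pool2d(x, k=2, s=2):
    """Max pooling on a 2D nested list with a k x k window and stride s (no padding)."""
    h_out = (len(x) - k) // s + 1
    w_out = (len(x[0]) - k) // s + 1
    return [
        [
            max(x[i * s + a][j * s + b] for a in range(k) for b in range(k))
            for j in range(w_out)
        ]
        for i in range(h_out)
    ]


feature_map = [
    [1, 3, 2, 0],
    [4, 2, 1, 1],
    [0, 0, 5, 6],
    [1, 2, 7, 8],
]
print(max_pool2d(feature_map))  # [[4, 2], [2, 8]]
```

The $4\times 4$ map shrinks to $2\times 2$, and each output keeps only the largest value in its window, so moving a strong activation by one pixel inside a window leaves the output unchanged.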



## 5) A simple CNN for MNIST (dimension tracking)

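- Input: $(N, 1, 28, 28)$
- Conv(1→32, k=3, p=1) → $(N, 32, 28, 28)$
- MaxPool(2×2, s=2) → $(N, 32, 14, 14)$
- Conv(32→64, k=3, p=1) → $(N, 64, 14, 14)$
- MaxPool(2×2, s=2) → $(N, 64, 7, 7)$
- Flatten → $(N, 64\cdot 7\cdot 7) = (N, 3136)$
- Linear(3136→10) → $(N, 10)$ logits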
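The shapes through this stack can be traced with the output-size formula from section 3. The snippet below is a sketch in plain Python; `conv_or_pool_out` and the `layers` table are hypothetical names introduced here, and stride 1 is assumed for the conv layers since the outline does not state one.

```python
def conv_or_pool_out(size, k, s=1, p=0, d=1):
    """Output length along one axis (same formula for conv and max pooling)."""
    return (size + 2 * p - d * (k - 1) - 1) // s + 1


# (name, output channels or None to keep them, kernel, stride, padding)
layers = [
    ("conv1", 32, 3, 1, 1),
    ("pool1", None, 2, 2, 0),
    ("conv2", 64, 3, 1, 1),
    ("pool2", None, 2, 2, 0),
]

channels, size = 1, 28
shapes = []
for name, c_out, k, s, p in layers:
    channels = c_out if c_out is not None else channels
    size = conv_or_pool_out(size, k, s, p)
    shapes.append((name, channels, size, size))
    print(name, (channels, size, size))

flat_features = channels * size * size
print("flatten", flat_features)  # 3136
```

The final value is the number of input features the linear layer needs, which is the easiest number to get wrong when writing the model by hand.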


In [None]:
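import torch
import torch.nn as nn

# Use a GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # ReLU after each conv is an assumed choice; the outline above names no activation
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),   # (N, 32, 28, 28)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),        # (N, 32, 14, 14)
            nn.Conv2d(32, 64, kernel_size=3, padding=1),  # (N, 64, 14, 14)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),        # (N, 64, 7, 7)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),                                 # (N, 64*7*7) = (N, 3136)
            nn.Linear(64 * 7 * 7, 10),                    # (N, 10) logits
        )

    def forward(self, x):
        x = self.features(x)    # conv/pool feature extractor
        x = self.classifier(x)  # flatten + linear head
        return x

model = SimpleCNN().to(device)
print(model)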
