In [1]:
import math
import time
import random
import matplotlib.pyplot as plt
import os
import warnings
import dataclasses as dc
import typing as tp
warnings.filterwarnings("ignore") # W0901 12:56:55.922000 133054240231424 torch/fx/experimental/symbolic_shapes.py:4449] [0/1] xindex is not in var_ranges, defaulting to unknown range.

import torch
from torch import Tensor, nn, _assert
from torch.nn import functional as F
import torch.utils.data
from torchvision import datasets, transforms
DEVICE = torch.device("cpu")

# DataLoading

In [2]:
class DataLoader:
    def __init__(self, train:bool):
        transform = transforms.Compose([
            transforms.ToTensor(), # (H, W, C)/(H, W) -> (C, H, W) AND [0, 255] -> [0.0, 1.0]
        ])
        if train:
            self.ds = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
        else:
            self.ds = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
        
    def iter_batches(self, batch_size):
        while True:
            self.dataset = torch.utils.data.DataLoader(
                dataset=self.ds,
                batch_size=batch_size,
                shuffle=True,
                pin_memory=True,
                drop_last=True
            )
            for X_batch, y_batch in self.dataset:
                yield X_batch.to(DEVICE), y_batch.to(DEVICE)

# LeNet

In [3]:
FAN_IN:tp.TypeAlias = int
FAN_OUT:tp.TypeAlias = int
KERNEL_SIZE:tp.TypeAlias = tuple[int, int]
STRIDES:tp.TypeAlias = tuple[int, int]

## Initialize Weights

In [4]:
@dc.dataclass
class Wei:
    wei:Tensor
    grad:tp.Optional[Tensor] = None

In [5]:
def init_weights(
    kernel1:tuple[FAN_OUT, FAN_IN, KERNEL_SIZE] = (6, 1, (5, 5)),
    kernel2:tuple[FAN_OUT, FAN_IN, KERNEL_SIZE] = (16, 6, (5, 5)),
    kernel3:tuple[FAN_OUT, FAN_IN, KERNEL_SIZE] = (120, 16, (5, 5)),
    weight4:tuple[FAN_OUT, FAN_IN] = (84, 120),
    weight5:tuple[FAN_OUT, FAN_IN] = (10, 84)
):
    sqrt2 = 2**0.5
    prod = lambda x: x[0]*x[1]

    # First Convolutional Layer: Stride 1
    bound = sqrt2/prod(kernel1[-1])
    W1 = Wei(torch.empty(size=(kernel1[0], kernel1[1], *kernel1[-1]), requires_grad=True).uniform_(-bound, bound))
    B1 = Wei(torch.zeros(size=kernel1[0], requires_grad=True))

    # Max Pooling Layer: Stride 2, Kernel 2x2, No Weights

    # Second Convolutional Layer: Stride 1
    bound = sqrt2/prod(kernel2[-1])
    W2 = Wei(torch.empty(size=(kernel2[0], kernel2[1], *kernel2[-1]), requires_grad=True).uniform_(-bound, bound))
    B2 = Wei(torch.zeros(size=kernel2[0], requires_grad=True))

    # Max Pooling Layer: Stride 2, Kernel 2x2, No Weights

    # Third Convolutional Layer
    bound = sqrt2/prod(kernel3[-1])
    W3 = Wei(torch.empty(size=(kernel3[0], kernel3[1], *kernel3[-1]), requires_grad=True).uniform_(-bound, bound))
    B3 = Wei(torch.zeros(size=kernel3[0], requires_grad=True))

    # First Linear Layer
    bound = sqrt2/prod(weight4[-1])
    W4 = Wei(torch.empty(size=weight4, requires_grad=True).uniform_(-bound, bound))
    B4 = Wei(torch.zeros(size=weight4[0], requires_grad=True))

    # Second Linear Layer
    bound = sqrt2/prod(weight5[-1])
    W5 = Wei(torch.empty(size=weight5, requires_grad=True).uniform_(-bound, bound))
    B5 = Wei(torch.zeros(size=weight5[0], requires_grad=True))

    return W1, B1, W2, B2, W3, B3, W4, B4, W5, B5

## Conv2d

In [6]:
def _conv2d(
    x:Tensor, # (H, W)
    w:Tensor, # (h, w)
    full:bool=False,
    convolve:bool=False
):
    x = x[None, None, ...] # (1, 1, H, W)
    w = w[None, None, ...] # (1, 1, h, w)
    if full:
        pad_h = w.size(-2) - 1
        pad_w = w.size(-1) - 1
        x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), mode='constant')
    if convolve:
        w = w.flip((2, 3))
    return F.conv2d(
        input=x, # (1, 1, H, W)
        weight=w, # (1, 1, h, w)
        stride=1
    )[0, 0, :, :]

```python
from scipy.signal import correlate2d, convolve2d
x = torch.randn(size=(2, 1, 28, 28), requires_grad=True)
w = torch.randn(size=(6, 1, 5, 5), requires_grad=True)
b = torch.randn(size=(6,), requires_grad=True)

# "full" testing
ful = correlate2d(x[0, 0].detach().numpy(), w[0, 0].detach().numpy(), mode='full')
myful = _conv2d(x[0, 0].detach(), w[0, 0].detach(), full=True)
diff_full = (ful-myful.numpy())
diff_full.mean(), diff_full.std(), abs(diff_full).max(), abs(diff_full).min()

# "convolve" testing
my_convolve = _conv2d(x[0, 0].detach(), w[0, 0].detach(), convolve=True, full=True)
convolve = convolve2d(x[0, 0].detach().numpy(), w[0, 0].detach().numpy(), mode='full')
diff_convolve = (convolve-my_convolve.numpy())
diff_convolve.mean(), diff_convolve.std(), abs(diff_convolve).max(), abs(diff_convolve).min()
```

In [7]:
def conv2d_forward(
    x:Tensor,               # (B, fi, H, W) 
    wie:Tensor,               # (fo, fi, h, w)
    bias:tp.Optional[Tensor],  # (fo,)
    stride:STRIDES = (1, 1)
):
    B, C, H, W = x.shape
    fo, fi, h, w = wie.shape
    sh, sw = stride
    assert C == fi, f"Expected {C} == {fi}"
    assert H >= h, f"Expected {H} >= {h}"
    assert W >= w, f"Expected {W} >= {w}"
    assert bias.shape[0] == fo, f"Expected {bias.shape[0]} == {fo}"

    output_shape = (B, fo, int((H-h)//sh + 1), int((W-w)//sw + 1))
    O = torch.zeros(output_shape) # (B, C_out, H1, W1)
    for fan_out in range(fo):
        for fan_in in range(fi):
            for bdim in range(B):
                O[bdim, fan_out] += _conv2d(x[bdim, fan_in], wie[fan_out, fan_in])[::sh, ::sw]

    if bias is not None:
        O += bias.view(1, -1, 1, 1)
    return O

In [8]:
def _dilate_matrix(x:Tensor, dilation:tuple[int, int]):
    """`x: shape(B, C, H, W)`\n `dilation:tuple`"""
    (B, C, H, W), (Hd, Wd)  = x.shape, dilation
    dilated = torch.zeros((B, C, Hd*(H-1)+1, Wd*(W-1)+1 ))
    dilated[:, :, ::Hd, ::Wd] = x
    return dilated

In [9]:
def conv2d_backward(
    x:Tensor,                   # (B, fi, H, W)
    wei:Tensor,                 # (fo, fi, h, w)
    bias:tp.Optional[Tensor],   # (fo,)
    dL_dO:Tensor,               # (B, fo, H1, W1)
    stride:STRIDES = (1, 1)
):  
    fo, fi, h, w = wei.shape
    B, C, H, W = x.shape

    dL_dx, dL_dwei = torch.zeros_like(x), torch.zeros_like(wei)
    dL_dO = _dilate_matrix(dL_dO, dilation=stride)
    for fan_out in range(fo): # C_out
        for bdim in range(B): # B
            for fan_in in range(fi): # C_in
                dL_dwei[fan_out, fan_in] += _conv2d(x[bdim, fan_in], dL_dO[bdim, fan_out]) # (H, W)*(H1, W1) => (Hk, Wk)
                dL_dx[bdim, fan_in] += _conv2d(dL_dO[bdim, fan_out], wei[fan_out, fan_in], full=True, convolve=True) # (H1, W1)*(Hk, Wk) => (H, W)
    
    dL_db = dL_dO.sum(dim=(0, 2, 3)) if bias is not None else None
    return dL_dx, dL_dwei, dL_db

### Test conv2d forward and backward methods

In [10]:
# x = torch.randn(size=(2, 1, 28, 28), requires_grad=True)
# w = torch.randn(size=(6, 1, 5, 5), requires_grad=True)
# b = torch.randn(size=(6,), requires_grad=True)

# my_conv2d = conv2d_forward(x, w, b)
# my_conv2d.retain_grad()

# torch_conv2d = F.conv2d(x, w, b, stride=1)
# torch.testing.assert_close(my_conv2d.detach(), torch_conv2d)

# loss = my_conv2d.mean()
# loss.backward()

# torch_dL_dx, torch_dL_dw, torch_dL_db = x.grad, w.grad, b.grad
# dL_dO = my_conv2d.grad.clone()

# my_dL_dx, my_dL_dw, my_dL_db = conv2d_backward(x, w, b, dL_dO)
# torch.testing.assert_close(my_dL_dx, torch_dL_dx)
# torch.testing.assert_close(my_dL_dw, torch_dL_dw)
# torch.testing.assert_close(my_dL_db, torch_dL_db)

## MaxPool2d

In [11]:
def _maxpool(matrix:Tensor, kernel_size:KERNEL_SIZE, strides:STRIDES): # (H, W)
        (H, W), (Hk, Wk), (Hs, Ws) = matrix.shape, kernel_size, strides
        output_shape = ((H-Hk+1)//Hs + 1, (W-Wk+1)//Ws + 1)
        indices, maxpooled = [], []
        for i in range(0, H - Hk + 1, Hs):
            for j in range(0, W - Wk + 1, Ws):
                window = matrix[i:i+Hk, j:j+Wk]
                max_index = torch.unravel_index(torch.argmax(window), window.shape)
                max_index_global = (max_index[0] + i, max_index[1] + j)
                indices.append(max_index_global)
                maxpooled.append(window[max_index])
        maxpooled = torch.tensor(maxpooled).reshape(output_shape)
        indices = torch.tensor(indices) # (H1*W1, 2)
        # (H1, W1), ((H1, W1), (H1, W1))
        return maxpooled, (indices[:, 0].reshape(output_shape), indices[:, 1].reshape(output_shape))

channeled_maxpool = torch.vmap(_maxpool, in_dims=(0, None, None), out_dims=0) # (C, H, W)
vmaxpool = torch.vmap(channeled_maxpool, in_dims=(0, None, None), out_dims=0) # (B, C, H, W)

In [12]:
def channeled_maxpool(matrix:Tensor, kernel_size:KERNEL_SIZE, strides:STRIDES): # (C, H, W)
    (C, H, W) = matrix.shape
    maxpooled, Rindices, Cindices = [], [], []
    for c in range(C):
        maxpooled_, (indices_r, indices_c) = _maxpool(matrix[c], kernel_size, strides)
        maxpooled.append(maxpooled_)
        Rindices.append(indices_r)
        Cindices.append(indices_c)
    return torch.stack(maxpooled), (torch.stack(Rindices), torch.stack(Cindices))

def vmaxpool(matrix:Tensor, kernel_size:KERNEL_SIZE, strides:STRIDES): # (B, C, H, W)
    (B, C, H, W) = matrix.shape
    maxpooled, Rindices, Cindices = [], [], []
    for b in range(B):
        maxpooled_b, (indices_r, indices_c) = channeled_maxpool(matrix[b], kernel_size, strides)
        maxpooled.append(maxpooled_b)
        Rindices.append(indices_r)
        Cindices.append(indices_c)
    return torch.stack(maxpooled), (torch.stack(Rindices), torch.stack(Cindices))

In [13]:
def maxpool2d_forward(
    x:Tensor,
    kernel_size:KERNEL_SIZE,
    strides:STRIDES,
) -> tuple[Tensor, tuple[Tensor, Tensor], torch.Size]:
    O, (ridx, cidx) = vmaxpool(x, kernel_size, strides)
    return O, (ridx, cidx), x.shape

In [14]:
def maxpool_backward(
    dL_dO:Tensor,
    x_shape:torch.Size,
    indices:tuple[Tensor, Tensor],
):
    """SOMETHING IS WRONG"""
    (ridx, cidx) = indices

    dL_dY = torch.zeros(x_shape) # (B, C, H, W)
    dL_dY[:, :, ridx, cidx] += dL_dO

    dY_dX = torch.zeros(x_shape) # (B, C, H, W)
    dY_dX[:, :, ridx, cidx] += 1.0

    dL_dX = dL_dY*dY_dX # (B, C, H, W)
    return dL_dX # (B, C, H, W)

### Test maxpool2d forward and backward methods

In [15]:
x = torch.randn(size=(2, 1, 28, 28), requires_grad=True)
x.requires_grad = True

torch_maxpool, torch_idx = F.max_pool2d_with_indices(x, kernel_size=(2, 2), stride=(2, 2))
torch_maxpool.retain_grad()

In [16]:
torch_loss = torch_maxpool.mean()
torch_loss.backward()

torch_dL_dX, torch_dL_dO = x.grad, torch_maxpool.grad.clone()
assert not any([torch_dL_dX is None, torch_dL_dO is None])

In [17]:
my_maxpool, indices, x_shape = maxpool2d_forward(x.clone(), kernel_size=(2, 2), strides=(2, 2))
torch.testing.assert_close(my_maxpool, torch_maxpool)
my_dL_dx = maxpool_backward(torch_dL_dO, x_shape, indices)

diff = torch_dL_dX - my_dL_dx
diff.mean(), diff.std(), diff.abs().max(), diff.abs().min()

(tensor(-0.0005), tensor(0.0010), tensor(0.0026), tensor(0.))

In [18]:
torch_idx

tensor([[[[ 29,   3,   5,  34,   9,  38,  41,  43,  45,  18,  21,  22,  52,  54],
          [ 85,  58,  89,  91,  64,  67,  97,  98,  72, 102,  76,  79,  81,  83],
          [141, 143, 145, 119, 120, 151, 125, 126, 129, 159, 161, 162, 165, 138],
          [169, 171, 200, 175, 205, 206, 209, 211, 184, 215, 216, 219, 220, 195],
          [252, 254, 228, 230, 232, 262, 265, 239, 241, 270, 273, 274, 277, 251],
          [280, 283, 284, 315, 316, 319, 292, 322, 296, 326, 300, 330, 333, 306],
          [364, 339, 369, 371, 373, 347, 376, 378, 380, 383, 385, 359, 388, 391],
          [392, 422, 397, 426, 429, 403, 405, 407, 436, 411, 412, 443, 444, 418],
          [449, 479, 480, 454, 457, 486, 488, 491, 493, 495, 468, 470, 500, 502],
          [532, 535, 508, 539, 513, 515, 544, 546, 549, 550, 553, 555, 528, 530],
          [561, 591, 564, 594, 596, 570, 601, 602, 604, 606, 609, 611, 585, 615],
          [616, 618, 648, 651, 652, 655, 657, 659, 661, 662, 636, 638, 640, 670],
          [673, 

In [19]:
indices

(tensor([[[[ 1,  0,  0,  1,  0,  1,  1,  1,  1,  0,  0,  0,  1,  1],
           [ 3,  2,  3,  3,  2,  2,  3,  3,  2,  3,  2,  2,  2,  2],
           [ 5,  5,  5,  4,  4,  5,  4,  4,  4,  5,  5,  5,  5,  4],
           [ 6,  6,  7,  6,  7,  7,  7,  7,  6,  7,  7,  7,  7,  6],
           [ 9,  9,  8,  8,  8,  9,  9,  8,  8,  9,  9,  9,  9,  8],
           [10, 10, 10, 11, 11, 11, 10, 11, 10, 11, 10, 11, 11, 10],
           [13, 12, 13, 13, 13, 12, 13, 13, 13, 13, 13, 12, 13, 13],
           [14, 15, 14, 15, 15, 14, 14, 14, 15, 14, 14, 15, 15, 14],
           [16, 17, 17, 16, 16, 17, 17, 17, 17, 17, 16, 16, 17, 17],
           [19, 19, 18, 19, 18, 18, 19, 19, 19, 19, 19, 19, 18, 18],
           [20, 21, 20, 21, 21, 20, 21, 21, 21, 21, 21, 21, 20, 21],
           [22, 22, 23, 23, 23, 23, 23, 23, 23, 23, 22, 22, 22, 23],
           [24, 25, 25, 25, 24, 24, 25, 24, 25, 24, 24, 24, 25, 25],
           [26, 27, 27, 27, 26, 26, 26, 26, 26, 26, 26, 26, 27, 27]]],
 
 
         [[[ 0,  1,  0,  0, 

In [20]:
torch_dL_dX, torch_dL_dX.shape

(tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0026, 0.0000,  ..., 0.0000, 0.0026, 0.0000],
           [0.0000, 0.0000, 0.0026,  ..., 0.0026, 0.0000, 0.0026],
           ...,
           [0.0000, 0.0000, 0.0000,  ..., 0.0026, 0.0000, 0.0026],
           [0.0026, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0026]]],
 
 
         [[[0.0000, 0.0026, 0.0000,  ..., 0.0000, 0.0000, 0.0026],
           [0.0000, 0.0000, 0.0026,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           ...,
           [0.0000, 0.0026, 0.0000,  ..., 0.0000, 0.0026, 0.0000],
           [0.0000, 0.0000, 0.0026,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0026, 0.0000,  ..., 0.0000, 0.0026, 0.0000]]]]),
 torch.Size([2, 1, 28, 28]))

In [21]:
my_dL_dx, my_dL_dx.shape

(tensor([[[[0.0000, 0.0026, 0.0000,  ..., 0.0000, 0.0000, 0.0026],
           [0.0000, 0.0026, 0.0026,  ..., 0.0000, 0.0026, 0.0000],
           [0.0000, 0.0000, 0.0026,  ..., 0.0026, 0.0000, 0.0026],
           ...,
           [0.0000, 0.0026, 0.0000,  ..., 0.0026, 0.0026, 0.0026],
           [0.0026, 0.0000, 0.0026,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0026, 0.0000,  ..., 0.0000, 0.0026, 0.0026]]],
 
 
         [[[0.0000, 0.0026, 0.0000,  ..., 0.0000, 0.0000, 0.0026],
           [0.0000, 0.0026, 0.0026,  ..., 0.0000, 0.0026, 0.0000],
           [0.0000, 0.0000, 0.0026,  ..., 0.0026, 0.0000, 0.0026],
           ...,
           [0.0000, 0.0026, 0.0000,  ..., 0.0026, 0.0026, 0.0026],
           [0.0026, 0.0000, 0.0026,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0026, 0.0000,  ..., 0.0000, 0.0026, 0.0026]]]]),
 torch.Size([2, 1, 28, 28]))

In [24]:
(my_dL_dx == torch_dL_dX).sum(), torch_dL_dX.numel()

(tensor(1280), 1568)

In [20]:
x

tensor([[[[-0.7207, -0.7193, -0.4922,  ...,  0.4852,  0.5819, -0.0300],
          [ 0.9409,  0.4120,  0.6129,  ...,  0.7283, -0.8180,  1.4888],
          [ 0.0620,  0.1983,  0.0675,  ..., -0.8164, -2.1203,  0.5834],
          ...,
          [-0.4766,  0.1509, -1.9746,  ..., -1.1059, -0.4576, -0.5482],
          [ 0.1515,  0.2605,  1.2307,  ...,  0.4600,  0.2387, -0.8263],
          [-0.2699,  0.1249, -1.4637,  ...,  1.7832, -0.1992, -0.9448]]],


        [[[-0.8269, -0.6305,  0.8517,  ...,  0.1559,  0.4775,  0.2800],
          [ 0.3817,  1.0025,  0.3259,  ..., -0.0707, -1.1159,  0.0803],
          [ 0.3949,  1.8837, -0.5904,  ...,  1.0609, -0.0464,  0.5574],
          ...,
          [ 1.2029, -0.4394, -0.2655,  ..., -0.4808, -0.0988,  0.9696],
          [-1.8645,  0.0060, -0.3333,  ..., -0.6156, -0.9566,  0.2655],
          [-1.9624,  0.2632, -0.4373,  ...,  0.5423, -0.6441, -0.5565]]]],
       requires_grad=True)

In [17]:
my_loss, torch_loss = my_maxpool.mean(), torch_maxpool.mean()
my_loss.backward()

torch_dL_dX, torch_dL_dO = x.grad, my_maxpool.grad.clone()
my_dL_dX = maxpool_backward(torch_dL_dO, x_shape, indices)

## Linear

In [None]:
def linear_forward(
    x:Tensor,      # (B, fi)
    wie:Tensor,    # (fo, fi)
    bias:Tensor,   # (fo,)
):
    return x @ wie.T + bias.unsqueeze(0) 

def linear_backward(
    x:Tensor,       # (B, fi)
    wie:Tensor,     # (fo, fi)
    bias:Tensor,    # (fo,)
    dL_dO:Tensor,   # (B, fo)
):
    dL_dx = dL_dO @ wie      # (B, fi) <= (B, fo) @ (fo, fi)
    dL_dwie = dL_dO.T @ x    # (fo, fi) <= (B, fo).T @ (B, fi)
    dL_db = dL_dO.sum(dim=0) # (fo,) <= (B, fo)
    return dL_dx, dL_dwie, dL_db

## Reshape

In [None]:
def reshape_forward(x:Tensor, shape:tuple): # (B, C, H, W)
    return x.reshape(shape), x.shape        # (B, C*H*W)

def reshape_backward(dL_dO:Tensor, x_shape:torch.Size): # (B, C*H*W)
    return dL_dO.reshape(x_shape)                       # (B, C, H, W)

## ReLU

In [23]:
def relu_forward(x:Tensor):
    return torch.maximum(x, torch.tensor(0))

def relu_backward(relu:Tensor, dL_dO:Tensor):
    dO_dx = relu * dL_dO
    dL_dx = dL_dO * dO_dx
    return dL_dx

## SoftMax

In [None]:
def softmax_forward(logits:Tensor):
    max_val, idx = logits.max(-1, keepdim=True) 
    logits -= max_val
    exp = torch.exp(logits)
    proba = exp / exp.sum(-1, keepdim=True)
    return proba

def softmax_backward(probs:Tensor, dL_dprobs:Tensor):
    nc = probs.shape[-1]
    t1 = torch.einsum("ij,ik->ijk", probs, probs) # (B, nc, nc)
    t2 = torch.einsum("ij,jk->ijk", probs, torch.eye(nc, nc)) # (B, nc, nc)
    dprobs_dlogits = t2 - t1 # (B, nc, nc)

    dL_dlogits = (dL_dprobs[:, None, :] @ dprobs_dlogits)[:, 0, :] # ((B, 1, nc) @ (B, nc, nc))[:, 0, :]
    return dL_dlogits # (B, nc)

## Cross Entropy

In [None]:
def cross_entropy_forward(y_true:Tensor, y_proba:Tensor):
    log_probs = torch.log(y_proba) # (B, nc)
    loss = -log_probs[torch.arange(len(y_true)), y_true].mean()
    return loss

def cross_entropy_backward(y_true:Tensor, y_proba:Tensor):
    B = len(y_true)
    dL_dlogprobas = torch.zeros_like(y_proba) # (B, nc)
    dL_dlogprobas[torch.arange(B), y_true] = -1/B

    dlogprobas_dprobas = 1/y_proba # (B, nc)

    dL_dprobas = dL_dlogprobas * dlogprobas_dprobas # (B, nc)
    return dL_dprobas # (B, nc)

## LeNet Module

In [None]:
class Module:
    def forward(self, *args):
        raise NotImplementedError
    def backward(self, *args):
        raise NotImplementedError

In [27]:
class LeNet(Module):
    def __init__(self):
        super().__init__()
        self.parameters = init_weights(
            (6, 1, (5, 5)),
            (16, 6, (5, 5)),
            (120, 16, (5, 5)),
            (84, 120),
            (10, 84)
        )
        (
            self.W1, self.B1, # Convolutional Layer 1
            self.W2, self.B2, # Convolutional Layer 2
            self.W3, self.B3, # Convolutional Layer 3
            self.W4, self.B4, # Linear Layer 1
            self.W5, self.B5  # Linear Layer 2
        ) = self.parameters

    def forward(self, x:Tensor): # (B, 1, 32, 32)
        self.x = x

        z1 = conv2d_forward(self.x, self.W1.wei, self.B1.wei) # (B, 6, 28, 28)
        self.h1 = relu_forward(z1) # (B, 6, 28, 28)

        self.hp1, (self.ridx1, self.cidx1), self.h1_shape = maxpool2d_forward(self.h1, (2, 2), (2, 2)) # (B, 6, 14, 14)

        z2 = conv2d_forward(self.hp1, self.W2.wei, self.B2.wei) # (B, 16, 10, 10) ##################
        self.h2 = relu_forward(z2) # (B, 16, 10, 10)

        self.hp2, (self.ridx2, self.cidx2), self.h2_shape = maxpool2d_forward(self.h2, (2, 2), (2, 2)) # (B, 16, 5, 5) 

        self.z3 = conv2d_forward(self.hp2, self.W3.wei, self.B3.wei) # (B, 120, 1, 1) 
        self.h3 = relu_forward(self.z3) # (B, 120, 1, 1) 

        h3_reshaped, self.h3_shape = reshape_forward(self.h3, (-1, self.h3.size(1))) # (B, 120, 1, 1) -> (B, 120) 

        self.z4 = linear_forward(h3_reshaped, self.W4.wei, self.B4.wei) # (B, 84)
        self.h4 = relu_forward(self.z4) # (B, 84)

        self.z5 = linear_forward(self.h4, self.W5.wei, self.B5.wei) # (B, 10)
        self.y_proba = softmax_forward(self.z5) # (B, 10)

        return self.y_proba
    
    def backward(self, y_true:Tensor):
        dL_dprobs = cross_entropy_backward(y_true, self.y_proba) # (B, 10)

        dL_dz5 = softmax_backward(self.y_proba, dL_dprobs) # (B, 10)
        dL_dh4, dL_dW5, dL_dB5 = linear_backward(self.z5, self.W5.wei, self.B5.wei, dL_dz5) # (B, 84)
        self.W5.grad, self.B5.grad = dL_dW5, dL_dB5

        dL_dz4 = relu_backward(self.h4, dL_dh4) # (B, 84)
        dL_h3reshaped, dL_dW4, dL_dB4  = linear_backward(self.z4, self.W4.wei, self.B4.wei, dL_dz4) # (B, 120)
        self.W4.grad, self.B4.grad = dL_dW4, dL_dB4

        dL_dh3 = reshape_backward(dL_h3reshaped, self.h3_shape) # (B, 120, 1, 1)

        dL_dz3 = relu_backward(self.h3, dL_dh3) # (B, 120, 1, 1)
        dL_dhp2, dL_dW3, dL_dB3 = conv2d_backward(self.hp2, self.W3.wei, self.B3.wei, dL_dz3) # (B, 16, 5, 5)
        self.W3.grad, self.B3.grad = dL_dW3, dL_dB3

        dL_dh2 = maxpool_backward(dL_dhp2, self.h2_shape, (self.ridx2, self.cidx2)) # (B, 16, 10, 10)

        dL_dz2 = relu_backward(self.h2, dL_dh2) # (B, 16, 10, 10)
        dL_dhp1, dL_dW2, dL_dB2 = conv2d_backward(self.hp1, self.W2.wei, self.B2.wei, dL_dz2) # (B, 6, 14, 14)
        self.W2.grad, self.B2.grad = dL_dW2, dL_dB2

        dL_dh1 = maxpool_backward(dL_dhp1, self.h1_shape, (self.ridx1, self.cidx1)) # (B, 6, 28, 28)

        dL_dz1 = relu_backward(self.h1, dL_dh1) # (B, 6, 28, 28)
        _, dL_dW1, dL_dB1 = conv2d_backward(self.x, self.W1.wei, self.B1.wei, dL_dz1) # (B, 1, 32, 32)
        self.W1.grad, self.B1.grad = dL_dW1, dL_dB1

# Optimizer

In [None]:
class SGD:
    def __init__(self, parameters:list[Wei], lr:float):
        self.parameters = parameters
        self.lr = lr

    def step(self):
        for p in self.parameters:
            p.wei -= self.lr*p.grad

    def zero_grad(self):
        for p in self.parameters:
            p.grad = None

# Training

In [None]:
model = LeNet()
optimizer = SGD(model.parameters, lr=0.01)

In [None]:
# @torch.compile # comment while debugging
def one_train_step(X_batch:Tensor, y_batch:Tensor):
    y_proba = model.forward(X_batch)
    loss = cross_entropy_forward(y_batch, y_proba)
    model.backward(y_batch)
    optimizer.step()
    optimizer.zero_grad()
    return loss

```python
def _conv2d(
    x:Tensor, # (H, W)
    w:Tensor, # (h, w)
    full:bool=False,
    convolve:bool=False
):
    x = x[None, None, ...] # (1, 1, H, W)
    w = w[None, None, ...] # (1, 1, h, w)
    if full:
        pad_h = w.size(-2) - 1
        pad_w = w.size(-1) - 1
        x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), mode='constant')
    if convolve:
        w = w.flip((2, 3))
    return F.conv2d(
        input=x, # (1, 1, H, W)
        weight=w, # (1, 1, h, w)
        stride=1
    )[0, 0, :, :]

def conv2d(
    inputs:Tensor, # (H, W)
    kernel:Tensor, # (hk, wk)
    stride:tuple[int, int] # (sh, sw)
):
    H, W = inputs.size()
    hk, wk = kernel.size()
    out = torch.zeros(((H-hk+1)//stride[0], (W-wk+1)//stride[1]))
    for i in range(0, H, stride[0]):
        for j in range(0, W, stride[1]):
            portion = inputs[i:i+hk, j:j+wk]
            out[i//stride[0], j//stride[1]] = (portion * kernel).sum()
    return out
```