In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import os
import math

In [3]:
import torch 
from torchvision import transforms
from torchvision.datasets import MNIST

In [4]:
dataset = MNIST('/workspace/data/', download=True, transform=transforms.ToTensor())
dataset

Dataset MNIST
    Number of datapoints: 60000
    Root location: /workspace/data/
    Split: Train
    StandardTransform
Transform: ToTensor()

In [5]:
dataset.data.shape

torch.Size([60000, 28, 28])

In [6]:
n_train = 50_000
n_valid = dataset.data.shape[0] - n_train
x_train, y_train = dataset.data[:n_train, :, :].view(n_train, -1) / 255, dataset.targets[:n_train]
x_valid, y_valid = dataset.data[n_train:, :, :].view(n_valid, -1) / 255, dataset.targets[n_train:]

In [7]:
train_mean, train_std = x_train.mean(), x_train.std()
train_mean, train_std

(tensor(0.1310), tensor(0.3085))

In [8]:
def normalize(x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor):
    return (x - mean) / std

In [9]:
x_train = normalize(x_train, train_mean, train_std)
# NOTE: use training, not validation mean and std for validation set
x_valid = normalize(x_valid, train_mean, train_std)

In [10]:
x_train.mean(), x_train.std()

(tensor(2.1126e-08), tensor(1.))

In [11]:
def test_near_zero(x, tol=1e-3):
    return x.abs() < tol

In [12]:
test_near_zero(x_train.mean()), test_near_zero(x_train.std() - 1)

(tensor(True), tensor(True))

In [13]:
n, m = x_train.shape
c = y_train.max() + 1
n, m, c

(50000, 784, tensor(10))

## Basic architecture

In [14]:
# num hidden
nh = 50

In [15]:
w1 = torch.randn(m, nh)
b1 = torch.zeros(nh)  # it gets broadcasted to (n, nh)
w2 = torch.randn(nh, 1)
b2 = torch.zeros(1)

In [16]:
# this should roughly be (0, 1)
x_valid.mean(), x_valid.std()

(tensor(-0.0059), tensor(0.9924))

In [17]:
def linear(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor):
    return (x @ w) + b

In [18]:
t = linear(x_valid, w1, b1)
t.mean(), t.std()

(tensor(-1.9614), tensor(26.5270))

This is a pretty terrible result. Let's use a simplified Kaiming He init to make mean and std of the output activation closer to (0, 1).

In [19]:
# simplified kaiming he init
w1 = torch.randn(m, nh) / math.sqrt(m)
b1 = torch.zeros(nh)  # it gets broadcasted to (n, nh)
w2 = torch.randn(nh, 1) / math.sqrt(nh)
b2 = torch.zeros(1)

In [20]:
test_near_zero(w1.mean()), test_near_zero(w1.std() - 1 / math.sqrt(m))

(tensor(True), tensor(True))

In [21]:
def linear(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor):
    return (x @ w) + b

In [22]:
# ... and so should this, since we used kaiming he init, which is designed to do this
t = linear(x_valid, w1, b1)
t.mean(), t.std()

(tensor(0.0690), tensor(0.9727))

This is promising, however, it doesn't take into account the non-linearity activation. Modern networks use ReLu, Swish, Mish, etc. The Kaiming He init doesn't work well with such non-linearities.

In [23]:
def relu(x: torch.Tensor):
    return x.clamp_min(0.)

In [24]:
t = relu(linear(x_valid, w1, b1))

In [25]:
t.mean(), t.std()

(tensor(0.4197), tensor(0.5770))

As you can notice, the output is not centered at 0 and with unit standard deviation.

In [26]:
# kaiming he init for relu
# Delving Deep into Rectifiers: Surpassing Human-Level Performance on 
#   ImageNet Classification (https://arxiv.org/abs/1502.01852)
w1 = torch.randn(m, nh) * math.sqrt(2/m)
b1 = torch.zeros(nh)  # it gets broadcasted to (n, nh)
w2 = torch.randn(nh, 1) * math.sqrt(2/nh)
b2 = torch.zeros(1)

In [27]:
w1.mean(), w1.std()

(tensor(0.0002), tensor(0.0503))

In [28]:
t = relu(linear(x_valid, w1, b1))
t.mean(), t.std()

(tensor(0.5602), tensor(0.8502))

This gives us a much better standard deviation ― closer to 1. The mean is not close to zero, but that is intentional. The ReLu activation removed every value below 0, thus the mean cannot be zero. Something closer to 0.5 is now expected.

In [29]:
from torch.nn import init

In [30]:
w1 = torch.randn(m, nh)
w1 = init.kaiming_normal_(w1, mode='fan_out')
t = relu(linear(x_valid, w1, b1))
t.mean(), t.std()

(tensor(0.5532), tensor(0.8531))

What if we change the definition of ReLu to also subtract 0.5, to bring the mean back to 0...

In [31]:
# what if...
def relu_new(x: torch.Tensor):
    return x.clamp_min(0.) - 0.5

In [32]:
w1 = torch.randn(m, nh)
w1 = init.kaiming_normal_(w1, mode='fan_out')
t = relu_new(linear(x_valid, w1, b1))
t.mean(), t.std()

(tensor(0.0494), tensor(0.8318))

The mean is now closer to 0 and the standard deviation is more stable and closer to 0.8 ― not perfect, but it's better.

## Forward pass

In [33]:
def model(x):
    l1 = linear(x, w1, b1)
    l2 = relu(l1)
    l3 = linear(l2, w2, b2)
    return l3

In [34]:
y_pred = model(x_valid)
y_pred

tensor([[1.0466],
        [0.8612],
        [2.0450],
        ...,
        [0.3272],
        [1.6357],
        [1.4171]])

In [35]:
assert y_pred.shape == torch.Size([x_valid.shape[0], 1])

## Loss function: MSE

Of course, MSE is not a suitable loss function for multi-class classification; we will use a better loss function soon. For now, let's use MSE to keep things simple.

In [36]:
y_pred.shape

torch.Size([10000, 1])

In [37]:
def mse(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    return (y_pred - y_true).pow(2).mean()

In [38]:
mse(y_pred=y_pred, y_true=y_valid)

tensor(20.7159)

## Gradients and backward pass

In [39]:
def mse_grad(inp, targ):
    # grad of loss function w.r.t. output of previous layer
    inp.g = 2. * (inp.squeeze() - targ).unsqueeze(-1) / inp.shape[0]

In [40]:
def relu_grad(inp, out):
    # grad of ReLu w.r.t. input activations
    inp.g = (inp>0).float() * out.g

In [41]:
def lin_grad(inp, out, w, b):
    # grad of matmul w.r.t. input
    inp.g = out.g @ w.t()
    w.g = (inp.unsqueeze(-1) * out.g.unsqueeze(1)).sum(0)
    b.g = out.g.sum(0)

In [42]:
def forward_and_backward(inp, targ):
    l1 = inp @ w1 + b1
    l2 = relu(l1)
    out = l2 @ w2 + b2
    # we don't actually need the loss in the backward pass!
    loss = mse(out, targ)
    
    mse_grad(out, targ)
    lin_grad(l2, out, w2, b2)
    relu_grad(l1, l2)
    lin_grad(inp, l1, w1, b1)

In [43]:
forward_and_backward(x_train, y_train)

In [44]:
# save for testing later
w1g = w1.g.clone()
w2g = w2.g.clone()
b1g = b1.g.clone()
b2g = b2.g.clone()
ig = x_train.g.clone()

Let's check the results against PyTorch.

In [45]:
xt2 = x_train.clone().requires_grad_(True)
w12 = w1.clone().requires_grad_(True)
w22 = w2.clone().requires_grad_(True)
b12 = b1.clone().requires_grad_(True)
b22 = b2.clone().requires_grad_(True)

In [46]:
def forward(inp, targ):
    l1 = inp @ w12 + b12
    l2 = relu(l1)
    out = l2 @ w22 + b22
    # we don't actually need the loss in backward!
    return mse(out, targ)

In [47]:
loss = forward(xt2, y_train)

In [48]:
loss.backward()

In [58]:
def test_near(a: torch.tensor, b:torch.tensor):
    return torch.allclose(a, b, rtol=1e-3, atol=1e-5)

test_near(w22.grad, w2g)
test_near(b22.grad, b2g)
test_near(w12.grad, w1g)
test_near(b12.grad, b1g)
test_near(xt2.grad, ig )

False

In [56]:
w12.grad[:3], w1g[:3]

(tensor([[-0.0594, -0.3999,  0.1522, -0.3356,  0.2963,  0.8817, -0.1711, -0.0693,
          -0.5706, -0.2508,  0.1139, -0.2435,  0.0273,  0.3508, -0.0809,  0.0195,
          -0.0805,  0.2433,  0.7586, -0.3073, -0.1211,  0.3047,  0.2199,  0.0074,
           0.0019, -0.3475, -0.0736,  0.3731,  0.1078,  0.1344,  0.0013,  0.2045,
           0.2504, -0.1791,  0.0068, -0.1838,  0.2169, -0.0431, -0.1259,  0.0087,
           0.1221,  0.1463, -0.0648,  0.0897,  0.0583, -0.0732, -0.0294,  0.0917,
          -0.2673, -0.1678],
         [-0.0594, -0.3999,  0.1522, -0.3356,  0.2963,  0.8817, -0.1711, -0.0693,
          -0.5706, -0.2508,  0.1139, -0.2435,  0.0273,  0.3508, -0.0809,  0.0195,
          -0.0805,  0.2433,  0.7586, -0.3073, -0.1211,  0.3047,  0.2199,  0.0074,
           0.0019, -0.3475, -0.0736,  0.3731,  0.1078,  0.1344,  0.0013,  0.2045,
           0.2504, -0.1791,  0.0068, -0.1838,  0.2169, -0.0431, -0.1259,  0.0087,
           0.1221,  0.1463, -0.0648,  0.0897,  0.0583, -0.0732, -0.02

Gradients are similar, but not as close as we want them to be. Why is that?