In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import math

In [3]:
import torch 
from torchvision import transforms
from torchvision.datasets import MNIST

In [4]:
dataset = MNIST('/workspace/data/', download=True, transform=transforms.ToTensor())
dataset

Dataset MNIST
    Number of datapoints: 60000
    Root location: /workspace/data/
    Split: Train
    StandardTransform
Transform: ToTensor()

In [5]:
dataset.data.shape

torch.Size([60000, 28, 28])

In [6]:
n_train = 50_000
n_valid = dataset.data.shape[0] - n_train
x_train, y_train = dataset.data[:n_train, :, :].view(n_train, -1) / 255, dataset.targets[:n_train]
x_valid, y_valid = dataset.data[n_train:, :, :].view(n_valid, -1) / 255, dataset.targets[n_train:]

In [7]:
train_mean, train_std = x_train.mean(), x_train.std()
train_mean, train_std

(tensor(0.1310), tensor(0.3085))

In [8]:
def normalize(x, m, s):
    return (x - m) / s

In [9]:
x_train = normalize(x_train, train_mean, train_std)
# NOTE: use training, not validation mean and std for validation set
x_valid = normalize(x_valid, train_mean, train_std)

In [10]:
x_train.mean(), x_train.std()

(tensor(2.1126e-08), tensor(1.))

In [11]:
def test_near_zero(x, tol=1e-3):
    assert x.abs() < tol

In [12]:
test_near_zero(x_train.mean())
test_near_zero(x_train.std() - 1)

In [13]:
n, m = x_train.shape
c = y_train.max() + 1
n, m, c

(50000, 784, tensor(10))

## Basic architecture

In [14]:
# num hidden units
nh = 50

In [15]:
# random init (0, 1)
w1 = torch.randn(m, nh)
b1 = torch.zeros(nh)  # it gets broadcasted to (n, nh)
w2 = torch.randn(nh, 1)
b2 = torch.zeros(1)

In [16]:
# this should roughly be (0, 1)
x_valid.mean(), x_valid.std()

(tensor(-0.0059), tensor(0.9924))

In [17]:
def lin(x, w, b):
    return (x @ w) + b

In [18]:
t = lin(x_valid, w1, b1)
t.mean(), t.std()

(tensor(1.9039), tensor(27.9104))

This is a pretty terrible result, as it will lead to exploding gradients after just a few layers. Let's use Xavier init to make mean and std of the output activation closer to (0, 1).

In [19]:
# simplified xavier init
w1 = torch.randn(m, nh) / math.sqrt(m)
b1 = torch.zeros(nh)  # it gets broadcasted to (n, nh)
w2 = torch.randn(nh, 1) / math.sqrt(nh)
b2 = torch.zeros(1)

In [20]:
test_near_zero(w1.mean())
test_near_zero(w1.std() - 1 / math.sqrt(m))

In [21]:
# ... so should this, since we used kaiming he init, which is designed to do this
t = lin(x_valid, w1, b1)
t.mean(), t.std()

(tensor(-0.0703), tensor(1.0330))

This is promising, however, it doesn't take into account the non-linearity activation. Modern networks use ReLu, Swish, Mish, etc. Xavier init doesn't work well with such non-linearities.

In [22]:
def relu(x: torch.Tensor):
    return x.clamp_min(0.)

In [23]:
t = relu(lin(x_valid, w1, b1))

In [24]:
t.mean(), t.std()

(tensor(0.3750), tensor(0.5805))

As you can notice, the output is not centered at 0 and the std is far from 1.

In [25]:
# kaiming he init for relu
# Delving Deep into Rectifiers: Surpassing Human-Level Performance on 
#   ImageNet Classification (https://arxiv.org/abs/1502.01852)
w1 = torch.randn(m, nh) * math.sqrt(2/m)
b1 = torch.zeros(nh)  # it gets broadcasted to (n, nh)
w2 = torch.randn(nh, 1) * math.sqrt(2/nh)
b2 = torch.zeros(1)

In [26]:
w1.mean(), w1.std()

(tensor(-8.1577e-05), tensor(0.0506))

In [27]:
t = relu(lin(x_valid, w1, b1))
t.mean(), t.std()

(tensor(0.6055), tensor(0.8364))

This gives us a much better standard deviation ― closer to 1. The mean is not close to zero, but that is intentional. The ReLu activation removed every value below 0, thus the mean cannot be zero. Something closer to 0.5 is now expected.

In [28]:
from torch.nn import init

In [29]:
w1 = torch.randn(m, nh)
init.kaiming_normal_(w1, mode='fan_out')
t = relu(lin(x_valid, w1, b1))

In [30]:
w1.mean(), w1.std()

(tensor(7.3505e-05), tensor(0.0506))

In [31]:
t.mean(), t.std()

(tensor(0.5569), tensor(0.8265))

What if we change the definition of ReLu to also subtract 0.5, to bring the mean back to 0...

In [32]:
# what if...
def relu(x):
    return x.clamp_min(0.) - 0.5

In [33]:
# kaiming-he init for relu
w1 = torch.randn(m, nh) * math.sqrt(2./m)
t1 = relu(lin(x_valid, w1, b1))
t1.mean(), t.std()

(tensor(0.0906), tensor(0.8265))

The mean is now closer to 0 and the standard deviation is more stable and closer to 0.8 ― not perfect, but it's better.

## Forward pass

In [34]:
def model(xb):
    l1 = lin(xb, w1, b1)
    l2 = relu(l1)
    l3 = lin(l2, w2, b2)
    return l3

In [35]:
y_pred = model(x_valid)
y_pred

tensor([[ 0.5270],
        [-1.2309],
        [-0.4345],
        ...,
        [-0.4670],
        [ 0.4274],
        [ 1.5621]])

In [36]:
assert y_pred.shape == torch.Size([x_valid.shape[0], 1])

## Loss function: MSE

Of course, MSE is not a suitable loss function for multi-class classification; we will use a better loss function soon. For now, let's use MSE to keep things simple.

In [37]:
y_pred.shape

torch.Size([10000, 1])

In [38]:
def mse(output, targ): return (output.squeeze(-1) - targ).pow(2).mean()

In [39]:
mse(y_pred, y_valid)

tensor(26.3738)

## Gradients and backward pass

In [40]:
def mse_grad(inp, targ):
    # grad of loss function w.r.t. output of previous layer
    inp.g = 2. * (inp.squeeze() - targ).unsqueeze(-1) / inp.shape[0]

In [41]:
def relu_grad(inp, out):
    # grad of ReLu w.r.t. input activations
    inp.g = (inp>0).float() * out.g

In [42]:
def lin_grad(inp, out, w, b):
    # grad of matmul w.r.t. input
    inp.g = out.g @ w.t()
    w.g = (inp.unsqueeze(-1) * out.g.unsqueeze(1)).sum(0)
    b.g = out.g.sum(0)

In [43]:
def forward_and_backward(inp, targ):
    # forward pass
    l1 = inp @ w1 + b1
    l2 = relu(l1)
    out = l2 @ w2 + b2
    # we don't actually need the loss in the backward pass!
    loss = mse(out, targ)
    
    # backward pass
    mse_grad(out, targ)
    lin_grad(l2, out, w2, b2)
    relu_grad(l1, l2)
    lin_grad(inp, l1, w1, b1)

In [44]:
forward_and_backward(x_train, y_train)

In [45]:
# save for testing later
w1g = w1.g.clone()
w2g = w2.g.clone()
b1g = b1.g.clone()
b2g = b2.g.clone()
ig  = x_train.g.clone()

Let's check the results against PyTorch.

In [46]:
xt2 = x_train.clone().requires_grad_(True)
w12 = w1.clone().requires_grad_(True)
w22 = w2.clone().requires_grad_(True)
b12 = b1.clone().requires_grad_(True)
b22 = b2.clone().requires_grad_(True)

In [47]:
def forward(inp, targ):
    # forward pass
    l1 = inp @ w12 + b12
    l2 = relu(l1)
    out = l2 @ w22 + b22
    # we don't actually need the loss in backward!
    return mse(out, targ)

In [48]:
loss = forward(xt2, y_train)

In [49]:
loss.backward()

In [50]:
def test_near(a: torch.tensor, b:torch.tensor):
    return torch.allclose(a, b, rtol=1e-3, atol=1e-5)

In [51]:
test_near(w22.grad, w2g)
test_near(b22.grad, b2g)
test_near(w12.grad, w1g)
test_near(b12.grad, b1g)
test_near(xt2.grad, ig )

True

## Refactor model

In [52]:
class Relu():
    def __call__(self, inp):
        self.inp = inp
        self.out = inp.clamp_min(0.) - 0.5
        return self.out
    
    def backward(self):
        self.inp.g = (self.inp > 0).float() * self.out.g

In [53]:
class Lin():
    def __init__(self, w, b):
        self.w = w
        self.b = b
        
    def __call__(self, inp):
        self.inp = inp
        self.out = inp @ self.w + self.b
        return self.out
    
    def backward(self):
        self.inp.g = self.out.g @ self.w.t()
        # creating a giant outer product, just to sum it, is inefficient!
        self.w.g = (self.inp.unsqueeze(-1) * self.out.g.unsqueeze(1)).sum(0)
        self.b.g = self.out.g.sum(0)

In [54]:
class Mse:
    def __call__(self, inp, targ):
        self.inp = inp
        self.targ = targ
        self.out = (inp.squeeze(-1) - targ).pow(2).mean()
        return self.out
    
    def backward(self):
        self.inp.g = 2. * (self.inp.squeeze() - self.targ).unsqueeze(-1) / self.inp.shape[0]

In [55]:
class Model():
    def __init__(self, w1, b1, w2, b2):
        self.layers = [Lin(w1, b1), Relu(), Lin(w2, b2)]
        self.loss = Mse()
        
    def __call__(self, x, targ):
        for l in self.layers:
            x = l(x)
        return self.loss(x, targ)
    
    def backward(self):
        self.loss.backward()
        for l in reversed(self.layers):
            l.backward()

In [56]:
w1.g, b1.g, w2.g, b2.g = [None] * 4
model = Model(w1, b1, w2, b2)

In [57]:
%time loss = model(x_train, y_train)

CPU times: user 81.9 ms, sys: 0 ns, total: 81.9 ms
Wall time: 16.8 ms


In [58]:
%time model.backward()

CPU times: user 4.85 s, sys: 1.9 s, total: 6.76 s
Wall time: 1.29 s


In [59]:
test_near(w2g, w2.g)
test_near(b2g, b2.g)
test_near(w1g, w1.g)
test_near(b1g, b1.g)
test_near(ig, x_train.g)

True

In [60]:
class Module():
    def __call__(self, *args):
        self.args = args
        self.out = self.forward(*args)
        return self.out
    
    def forward(self):
        raise NotImplemented()
    
    def backward(self):
        self.bwd(self.out, *self.args)

In [61]:
class Relu(Module):
    def forward(self, inp):
        return inp.clamp_min(0.) - 0.5
    
    def bwd(self, out, inp):
        inp.g = (inp > 0).float() * out.g

In [62]:
class Lin(Module):
    def __init__(self, w, b):
        self.w = w
        self.b = b
        
    def forward(self, inp):
        return inp @ self.w + self.b
        
    def bwd(self, out, inp):
        inp.g = out.g @ self.w.t()
        # creating a giant outer product, just to sum it, is inefficient!
        self.w.g = torch.einsum('bi,bj->ij', inp, out.g)
        self.b.g = out.g.sum(0)

In [63]:
class Mse(Module):
    def forward(self, inp, targ):
        return (inp.squeeze(-1) - targ).pow(2).mean()
        
    def bwd(self, out, inp, targ):
        inp.g = 2. * (inp.squeeze() - targ).unsqueeze(-1) / inp.shape[0]

In [64]:
class Model():
    def __init__(self):
        self.layers = [Lin(w1, b1), Relu(), Lin(w2, b2)]
        self.loss = Mse()
        
    def __call__(self, x, targ):
        for l in self.layers:
            x = l(x)
        return self.loss(x, targ)
    
    def backward(self):
        self.loss.backward()
        for l in reversed(self.layers):
            l.backward()

In [65]:
w1.g, b1.g, w2.g, b2.g = [None] * 4
model = Model()

In [66]:
%time loss = model(x_train, y_train)

CPU times: user 83.5 ms, sys: 0 ns, total: 83.5 ms
Wall time: 17.1 ms


In [67]:
%time model.backward()

CPU times: user 169 ms, sys: 67 ms, total: 236 ms
Wall time: 42.5 ms


In [68]:
test_near(w2g, w2.g)
test_near(b2g, b2.g)
test_near(w1g, w1.g)
test_near(b1g, b1.g)
test_near(ig, x_train.g)

True

## Without Einsum

In [69]:
class Lin(Module):
    def __init__(self, w, b):
        self.w = w
        self.b = b
        
    def forward(self, inp):
        return inp @ self.w + self.b
        
    def bwd(self, out, inp):
        inp.g = out.g @ self.w.t()
        # creating a giant outer product, just to sum it, is inefficient!
        self.w.g = inp.t() @ out.g
        self.b.g = out.g.sum(0)

In [70]:
w1.g, b1.g, w2.g, b2.g = [None] * 4
model = Model()

In [71]:
%time loss = model(x_train, y_train)

CPU times: user 91.4 ms, sys: 2.21 ms, total: 93.6 ms
Wall time: 16.8 ms


In [72]:
%time model.backward()

CPU times: user 181 ms, sys: 68 ms, total: 249 ms
Wall time: 43 ms


In [73]:
test_near(w2g, w2.g)
test_near(b2g, b2.g)
test_near(w1g, w1.g)
test_near(b1g, b1.g)
test_near(ig, x_train.g)

True

## nn.Linear and nn.Module

In [74]:
from torch import nn

In [75]:
class Model(nn.Module):
    def __init__(self, n_in, nh, n_out):
        super().__init__()
        self.layers = [
            nn.Linear(n_in, nh),
            nn.ReLU(),
            nn.Linear(nh, n_out)]
        self.loss = mse
        
    def __call__(self, x, targ):
        for l in self.layers:
            x = l(x)
        return self.loss(x.squeeze(), targ)

In [76]:
model = Model(m, nh, 1)

In [77]:
%time loss = model(x_train, y_train)

CPU times: user 104 ms, sys: 0 ns, total: 104 ms
Wall time: 19.9 ms


In [78]:
%time loss.backward()

CPU times: user 136 ms, sys: 0 ns, total: 136 ms
Wall time: 22.9 ms
