# Imports

In [1]:
import torch.nn as nn
from torch.nn import init
from torch.nn import Linear
import torch.nn.functional as F

In [2]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

# The forward and backward passes

In [3]:
#export
from exp.nb_01 import *


def get_data():
    '''Fetch the MNIST archive and return its train/valid arrays as tensors.

    The archive is obtained through `datasets.download_data(MNIST_URL, ext='.gz')`
    and read as a gzipped pickle holding three splits; the third split is
    discarded.

    Returns a lazy `map` yielding, in order, `x_train`, `y_train`, `x_valid`
    and `y_valid`, each passed through `tensor`. The iterator can be consumed
    only once (e.g. by tuple unpacking).
    '''
    archive = datasets.download_data(MNIST_URL, ext='.gz')

    # NOTE(review): unpickling can execute arbitrary code; this is only safe
    # as long as MNIST_URL points at a trusted source.
    with gzip.open(archive, 'rb') as fh:
        train_split, valid_split, _ = pickle.load(fh, encoding='latin-1')

    x_train, y_train = train_split
    x_valid, y_valid = valid_split
    return map(tensor, (x_train, y_train, x_valid, y_valid))


def normalize(x, mean, std):
    '''Standardize `x`: subtract `mean`, then divide the result by `std`.'''
    centered = x - mean
    return centered / std

In [4]:
# Get MNIST data
x_train, y_train, x_valid, y_valid = get_data()
x_train.shape, x_valid.shape

(torch.Size([50000, 784]), torch.Size([10000, 784]))

In [5]:
train_mean, train_std = x_train.mean(), x_train.std()
train_mean, train_std

(tensor(0.1304), tensor(0.3073))

In [6]:
x_train = normalize(x_train, train_mean, train_std)
x_valid = normalize(x_valid, train_mean, train_std)

In [7]:
train_mean, train_std = x_train.mean(), x_train.std()
train_mean, train_std

(tensor(0.0001), tensor(1.))

After normalization, the training images have a mean of approximately 0 and a standard deviation of approximately 1.

In [8]:
#export
def test_near_zero(a, tol=1e-3):
    assert a.abs() < tol, f'Near zero: {a}'

In [9]:
test_near_zero(x_train.mean())
test_near_zero(1 - x_train.std())

In [10]:
n, m = x_train.shape
c = y_train.max() + 1
n, m, c

(50000, 784, tensor(10))

We have 10 classes.

In [11]:
# Number of hidden units (width of the single hidden layer)
nh = 50

In [12]:
# Initialize weights and biases.
# Each weight matrix is a standard-normal draw divided by sqrt(number of rows),
# i.e. by sqrt(number of inputs to that layer); biases start at zero.
w1 = torch.randn((m, nh)) / math.sqrt(m)
b1 = torch.zeros(nh)
# The second layer has a single output column: one prediction per example.
w2 = torch.randn((nh, 1)) / math.sqrt(nh)
b2 = torch.zeros(1)

# Display the shapes as a sanity check.
w1.shape, b1.shape, w2.shape, b2.shape

(torch.Size([784, 50]), torch.Size([50]), torch.Size([50, 1]), torch.Size([1]))

In [13]:
test_near_zero(w1.mean())
test_near_zero(w1.std() - 1 / math.sqrt(m))

In [14]:
# Affine transformation
def lin(x, w, b):
    '''Affine map: matrix-multiply `x` by `w`, then add the bias `b`.'''
    projected = x @ w
    return projected + b

In [15]:
# Nonlinear activation function RELU
def relu(x):
    '''Rectified linear unit: clamp every element of `x` to be at least 0.'''
    rectified = x.clamp_min(0.)
    return rectified

In [16]:
x_train.mean(), x_valid.std()

(tensor(0.0001), tensor(0.9924))

In [17]:
t = lin(x_train, w1, b1)
t.mean(), t.std()

(tensor(0.0500), tensor(0.9977))

In [18]:
t = relu(lin(x_train, w1, b1))
t.mean(), t.std()

(tensor(0.4170), tensor(0.6064))

In [19]:
# Kaiming/He initialization
w1 = torch.randn((m, nh)) * math.sqrt(2 / m)
w1.mean(), w1.std()

(tensor(-0.0003), tensor(0.0504))

In [20]:
t = relu(lin(x_train, w1, b1))
t.mean(), t.std()

(tensor(0.5096), tensor(0.7652))

In [21]:
# Repeat the experiment with PyTorch's in-place Kaiming initializer.
# NOTE(review): mode='fan_out' is presumably needed because w1 is stored as
# (inputs, outputs), the transpose of nn.Linear's (outputs, inputs) weight
# layout — confirm against the torch.nn.init documentation.
w1 = torch.zeros((m, nh))
init.kaiming_normal_(w1, mode='fan_out')
t = relu(lin(x_train, w1, b1))
t.mean(), t.std()

(tensor(0.5056), tensor(0.7994))

In [22]:
Linear(m, nh).weight.shape

torch.Size([50, 784])

In [23]:
def relu(x):
    '''Shifted ReLU: clamp `x` at 0 from below, then subtract 0.5.

    This redefinition replaces the plain `relu` defined earlier in the notebook.
    '''
    rectified = x.clamp_min(0.)
    return rectified - 0.5

In [24]:
w1 = torch.randn((m, nh)) * math.sqrt(2 / m)
t = relu(lin(x_train, w1, b1))
t.mean(), t.std()

(tensor(0.0457), tensor(0.8223))

In [25]:
def model(x, w1, b1, w2, b2):
    '''Two-layer network: affine (`w1`, `b1`) -> `relu` -> affine (`w2`, `b2`).'''
    hidden = relu(lin(x, w1, b1))
    return lin(hidden, w2, b2)

In [26]:
def mse(pred, target):
    '''Mean squared error between `pred` and `target`.

    The trailing dimension of `pred` is squeezed and `target` is cast to float
    before the element-wise difference is squared and averaged.
    '''
    residual = pred.squeeze(-1) - target.float()
    return residual.pow(2).mean()

In [27]:
preds = model(x_train, w1, b1, w2, b2)

In [28]:
mse(preds, y_train)

tensor(31.0276)

In [29]:
def mse_grad(pred, target):
    '''Store the gradient of the MSE loss with respect to `pred` in `pred.g`.

    The gradient is `2 * (pred - target) / n`, where `n` is `target.shape[0]`,
    with a trailing dimension added back after the squeeze.
    '''
    n_samples = target.shape[0]
    residual = pred.squeeze(-1) - target.float()
    pred.g = 2 * residual.unsqueeze(-1) / n_samples

In [30]:
def relu_grad(inp, out):
    '''Backward pass of ReLU: store in `inp.g` the upstream gradient `out.g`,
    zeroed wherever `inp` is not strictly positive.'''
    positive_mask = (inp > 0).float()
    inp.g = positive_mask * out.g

In [31]:
def lin_grad(inp, out, w, b):
    '''Backward pass of the affine layer `out = inp @ w + b`.

    Reads the upstream gradient from `out.g` and stores the gradients with
    respect to the weights, bias and input in `w.g`, `b.g` and `inp.g`.
    Asserts that the parameter gradients have the same shapes as the parameters.
    '''
    upstream = out.g
    w.g = inp.t() @ upstream
    # Sum over the batch dimension: the bias is shared by every row.
    b.g = upstream.sum(dim=0)
    assert w.g.shape == w.shape
    assert b.g.shape == b.shape
    inp.g = upstream @ w.t()

In [32]:
def forward_and_backward(inp, targ):
    '''Run one forward and one manual backward pass through the global network.

    Uses the notebook-level parameters `w1`, `b1`, `w2`, `b2`. Returns nothing;
    gradients are left in the `.g` attribute of `inp` and of each parameter.
    '''
    # Forward: affine -> relu -> affine.
    pre_act = lin(inp, w1, b1)
    hidden = relu(pre_act)
    out = lin(hidden, w2, b2)
    # The loss value itself is never consumed by the backward pass below.
    loss = mse(out, targ)

    # Backward: walk the same layers in reverse, each filling in `.g`.
    mse_grad(out, targ)
    lin_grad(hidden, out, w2, b2)
    relu_grad(pre_act, hidden)
    lin_grad(inp, pre_act, w1, b1)

In [33]:
forward_and_backward(x_train, y_train)

In [34]:
# Save for testing against later
w1g = w1.g.clone()
w2g = w2.g.clone()
b1g = b1.g.clone()
b2g = b2.g.clone()
ig  = x_train.g.clone()

We cheat a little bit and use PyTorch autograd to check our results.

In [35]:
xt2 = x_train.clone().requires_grad_(True)
w12 = w1.clone().requires_grad_(True)
w22 = w2.clone().requires_grad_(True)
b12 = b1.clone().requires_grad_(True)
b22 = b2.clone().requires_grad_(True)

In [36]:
def forward(inp, targ):
    '''Forward pass through the autograd-tracked copies of the parameters
    (`w12`, `b12`, `w22`, `b22`); returns the MSE loss against `targ`.'''
    pre_act = inp @ w12 + b12
    hidden = relu(pre_act)
    out = hidden @ w22 + b22
    # The loss is returned so the caller can invoke `.backward()` on it.
    return mse(out, targ)

In [37]:
loss = forward(xt2, y_train)

In [38]:
loss.backward()

In [39]:
test_near(w22.grad, w2g)
test_near(b22.grad, b2g)
test_near(w12.grad, w1g)
test_near(b12.grad, b1g)
test_near(xt2.grad, ig )

Having compared our gradients with those computed by PyTorch's autograd, we can be confident that our implementation of both the forward and backward passes is correct.

In [40]:
class Relu:
    '''Shifted ReLU layer (`clamp_min(0) - 0.5`) with a manual backward pass.'''

    def __call__(self, x):
        '''Apply the activation, remembering input and output for `backward`.'''
        self.inp = x
        rectified = x.clamp_min(0.)
        self.out = rectified - 0.5
        return self.out

    def backward(self):
        '''Store in `inp.g` the upstream gradient `out.g`, masked to positive inputs.'''
        positive_mask = (self.inp > 0).float()
        self.inp.g = positive_mask * self.out.g

In [41]:
class MSE:
    '''Mean-squared-error loss with a manual backward pass.'''

    def __call__(self, preds, target):
        '''Return the mean squared difference between squeezed `preds` and `target`.'''
        self.preds, self.target = preds, target
        residual = preds.squeeze() - target.float()
        return residual.pow(2).mean()

    def backward(self):
        '''Store `2 * (preds - target) / n` in `preds.g`, with `n = target.shape[0]`.'''
        n_samples = self.target.shape[0]
        residual = self.preds.squeeze() - self.target.float()
        self.preds.g = 2 * residual.unsqueeze(1) / n_samples

In [42]:
class Linear:
    '''Affine layer `inp @ w + b` with a manual backward pass.'''

    def __init__(self, w, b):
        '''Keep references to the weight matrix `w` and bias `b`.'''
        self.w, self.b = w, b

    def __call__(self, inp):
        '''Compute the layer output, remembering input and output for `backward`.'''
        self.inp = inp
        self.out = inp @ self.w + self.b
        return self.out

    def backward(self):
        '''Propagate `out.g` into `w.g`, `b.g` and `inp.g`.'''
        upstream = self.out.g
        self.w.g = self.inp.t() @ upstream
        # Sum over the batch dimension: the bias is shared by every row.
        self.b.g = upstream.sum(dim=0)
        self.inp.g = upstream @ self.w.t()

In [43]:
class Model:
    '''Linear -> Relu -> Linear network with an MSE loss, built from the
    hand-written layer classes and the supplied parameters.'''

    def __init__(self, w1, b1, w2, b2):
        '''Assemble the layer stack and the loss object.'''
        self.layers = [Linear(w1, b1), Relu(), Linear(w2, b2)]
        self.loss = MSE()

    def __call__(self, x, target):
        '''Feed `x` through every layer in order and return the loss against `target`.'''
        activation = x
        for layer in self.layers:
            activation = layer(activation)
        return self.loss(activation, target)

    def backward(self):
        '''Backpropagate: the loss first, then the layers from last to first.'''
        self.loss.backward()
        for layer in reversed(self.layers):
            layer.backward()

In [44]:
w1.g, b1.g, w2.g, b2.g = [None] * 4
model = Model(w1, b1, w2, b2)

In [45]:
%time loss = model(x_train, y_train)

CPU times: user 160 ms, sys: 5.22 ms, total: 166 ms
Wall time: 42.7 ms


In [46]:
%time model.backward()

CPU times: user 455 ms, sys: 52.9 ms, total: 508 ms
Wall time: 127 ms


In [47]:
test_near(w2g, w2.g)
test_near(b2g, b2.g)
test_near(w1g, w1.g)
test_near(b1g, b1.g)
test_near(ig, x_train.g)

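`Module` mimics the core of PyTorch's `nn.Module`. Calling an instance stores the arguments, runs that layer's `forward` method, and caches the output; `backward` then passes the cached output and arguments to the layer's `bwd` method.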

In [48]:
class Module:
    '''Minimal layer base class for the manual forward/backward passes.

    Calling an instance records the positional arguments in `self.args`, runs
    `forward(*args)` and caches the result in `self.out`. `backward()` then
    hands the cached output and arguments to `bwd(out, *args)`.

    Subclasses must override both `forward` and `bwd`; the base versions raise
    `NotImplementedError`.
    '''

    def __call__(self, *args):
        # Cache inputs and output: `backward` needs both later on.
        self.args = args
        self.out = self.forward(*self.args)
        return self.out

    def forward(self, *args):
        '''Compute the layer output from `args`. Must be overridden.'''
        # Accept *args so that calling a subclass that forgot to override
        # `forward` fails with NotImplementedError rather than a TypeError
        # about the number of positional arguments.
        raise NotImplementedError()

    def bwd(self, out, *args):
        '''Store gradients given the cached `out` and `args`. Must be overridden.'''
        raise NotImplementedError()

    def backward(self):
        '''Run the subclass's `bwd` with the cached output and call arguments.'''
        self.bwd(self.out, *self.args)

In [49]:
class Linear(Module):
    '''Affine layer `inp @ w + b` expressed through the `Module` base class.'''

    def __init__(self, w, b):
        '''Keep references to the weight matrix `w` and bias `b`.'''
        self.w, self.b = w, b

    def forward(self, inp):
        '''Return `inp @ w + b`.'''
        projected = inp @ self.w
        return projected + self.b

    def bwd(self, out, inp):
        '''Propagate `out.g` into `w.g`, `b.g` and `inp.g`.'''
        upstream = out.g
        self.w.g = inp.t() @ upstream
        # Sum over the batch dimension: the bias is shared by every row.
        self.b.g = upstream.sum(dim=0)
        inp.g = upstream @ self.w.t()

In [50]:
class Relu(Module):
    '''Shifted ReLU (`clamp_min(0) - 0.5`) expressed through the `Module` base class.'''

    def forward(self, inp):
        '''Clamp `inp` at 0 from below, then subtract 0.5.'''
        rectified = inp.clamp_min(0.)
        return rectified - 0.5

    def bwd(self, out, inp):
        '''Store in `inp.g` the upstream gradient `out.g`, masked to positive inputs.'''
        positive_mask = (inp > 0).float()
        inp.g = positive_mask * out.g

In [51]:
class Mse(Module):
    '''Mean-squared-error loss expressed through the `Module` base class.'''

    def forward(self, inp, target):
        '''Return the mean squared difference between squeezed `inp` and `target`.'''
        residual = inp.squeeze() - target.float()
        return residual.pow(2).mean()

    def bwd(self, out, inp, target):
        '''Store `2 * (inp - target) / n` in `inp.g`, with `n = target.shape[0]`.'''
        n_samples = target.shape[0]
        residual = inp.squeeze() - target.float()
        inp.g = 2 * residual.unsqueeze(1) / n_samples

In [52]:
class Model:
    '''Linear -> Relu -> Linear network with an Mse loss, wired to the
    notebook-level parameters `w1`, `b1`, `w2` and `b2`.'''

    def __init__(self):
        '''Assemble the layer stack and the loss from the global parameters.'''
        self.layers = [Linear(w1, b1), Relu(), Linear(w2, b2)]
        self.loss = Mse()

    def __call__(self, x, target):
        '''Feed `x` through every layer in order and return the loss against `target`.'''
        activation = x
        for layer in self.layers:
            activation = layer(activation)
        return self.loss(activation, target)

    def backward(self):
        '''Backpropagate: the loss first, then the layers from last to first.'''
        self.loss.backward()
        for layer in self.layers[::-1]:
            layer.backward()

In [53]:
w1.g, b1.g, w2.g, b2.g = [None] * 4
model = Model()

In [54]:
%time loss = model(x_train, y_train)

CPU times: user 117 ms, sys: 1.89 ms, total: 119 ms
Wall time: 30.7 ms


In [55]:
%time model.backward()

CPU times: user 444 ms, sys: 49.7 ms, total: 494 ms
Wall time: 123 ms


In [56]:
test_near(w2g, w2.g)
test_near(b2g, b2.g)
test_near(w1g, w1.g)
test_near(b1g, b1.g)
test_near(ig, x_train.g)

In [57]:
class Model(nn.Module):
    '''Linear -> ReLU -> Linear network built from PyTorch layers, returning
    the `mse` loss directly from the forward pass.

    Args:
        n_in: number of input features.
        nh: number of hidden units.
        n_out: number of outputs of the final linear layer.
    '''

    def __init__(self, n_in, nh, n_out):
        super().__init__()
        # Use nn.ModuleList rather than a plain list so the layers are
        # registered as submodules; otherwise `parameters()`, `state_dict()`,
        # `.to(device)` and optimizers would not see their weights.
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, nh), nn.ReLU(), nn.Linear(nh, n_out)])
        self.loss = mse

    def forward(self, x, targ):
        '''Run `x` through the layers in order and return the loss against `targ`.'''
        # Defining `forward` (instead of overriding `__call__`) keeps
        # nn.Module's own call machinery intact; `model(x, targ)` still works.
        for l in self.layers:
            x = l(x)
        return self.loss(x.squeeze(), targ)

In [58]:
model = Model(m, nh, 1)

In [59]:
%time loss = model(x_train, y_train)

CPU times: user 85.1 ms, sys: 1.03 ms, total: 86.1 ms
Wall time: 22.2 ms


In [60]:
%time loss.backward()

CPU times: user 268 ms, sys: 4.31 ms, total: 272 ms
Wall time: 40.3 ms


In [61]:
# NOTE(review): these checks compare the saved gradients with the `.g`
# attributes left over from the earlier hand-written backward pass. The
# nn.Module model above creates its own nn.Linear parameters, and
# `loss.backward()` appears not to touch w1/b1/w2/b2 or x_train, so this cell
# does not validate the PyTorch model — confirm whether it is still wanted.
test_near(w2g, w2.g)
test_near(b2g, b2.g)
test_near(w1g, w1.g)
test_near(b1g, b1.g)
test_near(ig, x_train.g)