In [1]:
import torch, matplotlib.pyplot as plt, math, gzip, pickle
from torch import tensor
from pathlib import Path
from fastcore.test import test_close

In [2]:
# Load the pickled MNIST splits and turn every array into a torch tensor.
# NOTE(review): pickle.load can execute arbitrary code — only load a file from a trusted source.
data_path = Path('data')/'mnist.pkl.gz'
with gzip.open(data_path, 'rb') as f:
    data = pickle.load(f, encoding='latin-1')
# The pickle unpacks into three items: the first two are (x, y) pairs, the third is unused here.
((x_train, y_train), (x_val, y_val), _) = data
(x_train, y_train, x_val, y_val) = map(tensor, (x_train, y_train, x_val, y_val))
x_train.shape, y_train.shape, x_val.shape, y_val.shape

(torch.Size([50000, 784]),
 torch.Size([50000]),
 torch.Size([10000, 784]),
 torch.Size([10000]))

In [3]:
# n = number of training rows, m = features per row (50000 x 784 per the shapes above)
n,m = x_train.shape
c = y_train.max()+1  # one more than the largest label — the class count if labels start at 0
nh = 50  # hidden-layer width

In [6]:
# Parameters for an m -> nh -> 1 network (a single output column, scored with MSE below).
# Seed the global RNG first so Restart & Run All reproduces the same weights and outputs.
torch.manual_seed(42)
w1 = torch.randn(m, nh)
b1 = torch.zeros(nh)
w2 = torch.randn(nh, 1)
b2 = torch.zeros(1)

In [7]:
def lin(x, w, b):
    """Affine map: matrix-multiply ``x`` by ``w`` and add the bias ``b`` (broadcast)."""
    return x.matmul(w).add(b)

In [8]:
# First layer on the validation set: (10000, 784) @ (784, nh) -> (10000, nh)
t = lin(x_val, w1, b1)
t.shape

torch.Size([10000, 50])

In [9]:
def relu(x):
    """Rectified linear unit: negative entries become 0, the rest pass through."""
    return x.clamp(min=0.)

In [10]:
# Apply ReLU to the first-layer activations; negative values become 0 (visible below).
t = relu(t)
t

tensor([[ 9.0405, 12.9419,  2.5688,  ...,  0.0000, 10.7064,  0.0000],
        [ 0.0000,  0.1317,  3.3706,  ...,  0.0000,  5.4220,  0.0000],
        [ 0.0000,  1.8912, 10.8913,  ..., 12.1424,  0.1268,  0.0000],
        ...,
        [12.4659,  0.9316,  0.0000,  ...,  3.4006,  1.9964,  0.0000],
        [ 1.6960,  2.2690,  4.2956,  ...,  2.7973,  7.6176,  0.0000],
        [ 1.4008,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]])

In [11]:
def model(xb):
    """Forward pass of the 2-layer MLP: lin(w1, b1) -> relu -> lin(w2, b2)."""
    hidden = relu(lin(xb, w1, b1))
    return lin(hidden, w2, b2)

In [12]:
# Full forward pass on the validation set: one output column per row.
res = model(x_val)
res.shape

torch.Size([10000, 1])

In [13]:
# Pitfall: res is (10000, 1) and y_val is (10000,), so broadcasting yields (10000, 10000).
(res-y_val).shape

torch.Size([10000, 10000])

In [14]:
# Selecting the single output column gives the intended elementwise (10000,) difference.
(res[:,0]-y_val).shape

torch.Size([10000])

In [15]:
# squeeze() works here too, but it drops every size-1 dim (it would also collapse a batch of 1).
(res.squeeze()-y_val).shape

torch.Size([10000])

In [16]:
# Cast the targets to float (presumably integer labels from the pickle) before regressing on them.
y_train,y_val = y_train.float(),y_val.float()
preds = model(x_train)
preds.shape

torch.Size([50000, 1])

In [17]:
def mse(outp, targ):
    """Mean squared error between the first column of ``outp`` and ``targ``."""
    return outp[:, 0].sub(targ).pow(2).mean()

In [18]:
# MSE of the untrained (randomly initialised) network on the training set.
mse(preds, y_train)

tensor(2053.6414)

In [19]:
def lin_grad(inp, out, w, b):
    """Backward pass of ``out = inp @ w + b`` computed with explicit broadcasting.

    Reads the upstream gradient ``out.g`` and stores ``inp.g``, ``w.g`` and ``b.g``.

    NOTE: for 2-D ``inp`` (batch, n_in) and ``out.g`` (batch, n_out), the weight
    gradient materialises a (batch, n_in, n_out) tensor before summing, so memory
    grows with batch * n_in * n_out. ``lin_grad_v2`` gets the same result with a matmul.
    """
    upstream = out.g
    inp.g = upstream @ w.t()
    # per-example outer products inp[i] x upstream[i], summed over the batch axis
    w.g = (inp[..., None] * upstream[:, None]).sum(0)
    b.g = upstream.sum(0)

In [20]:
def lin_grad_v2(inp, out, w, b):
    """Backward pass of ``out = inp @ w + b`` using matrix products.

    Reads the upstream gradient ``out.g`` and stores ``inp.g`` (passed to the
    previous layer), ``w.g`` and ``b.g``.
    """
    upstream = out.g
    inp.g = upstream @ w.t()
    # the sum over the batch of outer products, written as a single matmul
    w.g = inp.t() @ upstream
    b.g = upstream.sum(0)

In [21]:
def forward_and_backward_v1(inp, targ):
    """Forward pass plus hand-written backward pass for the 2-layer MLP, using ``lin_grad``.

    Sets ``.g`` on ``w1, b1, w2, b2`` and on ``inp``, and returns the MSE loss of
    the forward pass.

    NOTE: ``lin_grad`` builds a (batch, n_in, n_out) broadcast product for each
    weight gradient, so this version is very memory-hungry on large batches.
    """
    # forward
    l1 = lin(inp, w1, b1)
    l2 = relu(l1)
    out = lin(l2, w2, b2)
    diff = out[:,0]-targ
    loss = diff.pow(2).mean()

    # backward, from the loss down to the input
    out.g = 2*diff[:,None]/inp.shape[0]  # d mean(diff**2) / d out, as an (N, 1) column
    lin_grad(l2, out, w2, b2)
    l1.g = (l1>0).float() * l2.g         # ReLU passes gradient only where its input was > 0
    lin_grad(inp, l1, w1, b1)
    return loss

In [22]:
def forward_and_backward_v2(inp, targ):
    """Forward pass plus hand-written backward pass for the 2-layer MLP, using ``lin_grad_v2``.

    Sets ``.g`` on ``w1, b1, w2, b2`` and on ``inp``, and returns the MSE loss of
    the forward pass.
    """
    # forward
    l1 = lin(inp, w1, b1)
    l2 = relu(l1)
    out = lin(l2, w2, b2)
    diff = out[:,0]-targ
    loss = diff.pow(2).mean()

    # backward, from the loss down to the input
    out.g = 2*diff[:,None]/inp.shape[0]  # d mean(diff**2) / d out, as an (N, 1) column
    lin_grad_v2(l2, out, w2, b2)
    l1.g = (l1>0).float() * l2.g         # ReLU passes gradient only where its input was > 0
    lin_grad_v2(inp, l1, w1, b1)
    return loss

In [23]:
# Disabled. lin_grad's broadcast product has shape (batch, n_in, n_out). For x_train and w1
# that is 50000 x 784 x 50 elements (~7.8 GB in float32), which is presumably why it is off.
# %time forward_and_backward_v1(x_train, y_train)

In [24]:
# Same forward/backward, but using lin_grad_v2 (matmul-based weight gradient).
%time forward_and_backward_v2(x_train, y_train)

CPU times: user 870 ms, sys: 230 ms, total: 1.1 s
Wall time: 197 ms


In [25]:
def get_grad(x):
    """Return an independent copy of the hand-computed gradient stored in ``x.g``."""
    stored = x.g
    return stored.clone()
# Snapshot the hand-computed gradients as clones, independent of the tensors held in .g.
chks = w1,w2,b1,b2,x_train
grads = w1g,w2g,b1g,b2g,ig = tuple(map(get_grad, chks))

In [26]:
def mkgrad(x):
    """Return a copy of ``x`` that autograd tracks (``requires_grad=True``)."""
    tracked = x.clone()
    return tracked.requires_grad_()
# Autograd-tracked copies of the same tensors, used to get reference gradients from PyTorch.
ptgrads = w12,w22,b12,b22,xt2 = tuple(map(mkgrad, chks))

In [27]:
def forward(inp, targ):
    """Forward pass through the autograd-tracked copies (w12, b12, w22, b22); returns the MSE loss."""
    hidden = relu(lin(inp, w12, b12))
    return mse(lin(hidden, w22, b22), targ)

In [28]:
# Let autograd populate .grad on the tracked copies.
loss = forward(xt2, y_train)
loss.backward()

In [29]:
# Compare each hand-derived gradient with autograd's; both checks use a 0.01 tolerance.
for a, b in zip(grads, ptgrads):
    print(torch.allclose(a, b.grad, rtol=0.01))
    test_close(a, b.grad, eps=0.01)

True
True
True
True
True


In [30]:
class Relu:
    """ReLU layer that caches its input and output so ``backward()`` can set ``inp.g``."""
    def __call__(self, inp):
        self.inp = inp
        self.out = inp.clamp(min=0.)
        return self.out

    def backward(self):
        # the derivative is 1 where the input was positive and 0 elsewhere
        mask = (self.inp > 0.).float()
        self.inp.g = mask * self.out.g

In [31]:
class Lin:
    """Affine layer ``inp @ w + b`` that caches inp/out for a hand-written backward pass."""
    def __init__(self, w, b):
        self.w = w
        self.b = b

    def __call__(self, inp):
        self.inp = inp
        self.out = lin(inp, self.w, self.b)
        return self.out

    def backward(self):
        # out.g is the upstream gradient; fill gradients for inp, w and b
        upstream = self.out.g
        self.inp.g = upstream @ self.w.t()
        self.w.g = self.inp.t() @ upstream
        self.b.g = upstream.sum(0)

In [32]:
class Mse:
    """MSE loss layer comparing the first output column with the targets.

    Caches ``inp`` and ``targ`` when called so ``backward()`` can write ``inp.g``.
    """
    def __call__(self, inp, targ):
        self.inp = inp
        self.targ = targ
        self.out = mse(inp, targ)
        return self.out

    def backward(self):
        # gradient of mean((inp[:,0]-targ)**2) w.r.t. inp, as an (N, 1) column
        residual = (self.inp[:, 0] - self.targ).unsqueeze(-1)
        self.inp.g = 2. * residual / self.inp.shape[0]

In [33]:
class Model:
    """Lin -> Relu -> Lin network with an Mse loss and a hand-written backward pass."""
    def __init__(self, w1, b1, w2, b2):
        self.layers = [Lin(w1, b1), Relu(), Lin(w2, b2)]
        self.loss = Mse()

    def __call__(self, x, targ):
        activation = x
        for layer in self.layers:
            activation = layer(activation)
        return self.loss(activation, targ)

    def backward(self):
        # the loss seeds the gradient chain, then the layers run last-to-first
        self.loss.backward()
        for layer in self.layers[::-1]:
            layer.backward()

In [34]:
# Class-based version of the same network, reusing the existing weight tensors.
# NOTE: rebinding `model` shadows the `model` function defined earlier in the notebook.
model = Model(w1, b1, w2, b2)

In [35]:
# Matches the functional loss above (2053.6414) because the weights are shared.
loss = model(x_train, y_train)
loss

tensor(2053.6414)

In [36]:
# Runs Mse.backward then each layer's backward in reverse, setting .g on w1, b1, w2, b2.
model.backward()

In [42]:
# Check the class-based gradients against autograd's gradients on the tracked copies.
test_close(w12.grad, w1.g, eps=0.01)
test_close(w22.grad, w2.g, eps=0.01)
test_close(b12.grad, b1.g, eps=0.01)
test_close(b22.grad, b2.g, eps=0.01)

In [66]:
class Module:
    """Minimal layer base class with a hand-written backward pass.

    ``__call__`` caches the call arguments and the output. ``backward()`` then hands
    the cached output (whose ``.g`` holds the upstream gradient) and the original
    arguments to ``bwd``. Subclasses implement ``forward(*args)`` and ``bwd(out, *args)``.
    """
    def __call__(self, *args):
        self.args = args
        self.out = self.forward(*args)
        return self.out

    def forward(self, *args):
        # Accept the same arguments __call__ passes, so an un-overridden forward raises
        # this error rather than a TypeError about argument counts.
        raise NotImplementedError('not implemented')

    def backward(self):
        self.bwd(self.out, *self.args)

    def bwd(self, out, *args):
        # Same signature that backward() calls with.
        raise NotImplementedError('not implemented')

In [67]:
class Relu(Module):
    """ReLU on top of Module: forward zeroes negatives, bwd masks the upstream gradient."""
    def forward(self, inp):
        return inp.clamp(min=0.)

    def bwd(self, out, inp):
        positive = (inp > 0).float()
        inp.g = positive * out.g

In [68]:
class Lin(Module):
    """Affine layer ``inp @ w + b`` on top of Module; ``bwd`` sets ``inp.g``, ``w.g`` and ``b.g``."""
    def __init__(self, w, b):
        self.w,self.b = w,b

    def forward(self, inp):
        return lin(inp, self.w, self.b)

    def bwd(self, out, inp):
        inp.g = out.g @ self.w.t()
        self.w.g = inp.t() @ out.g
        # Store the bias gradient on the bias tensor. Assigning to self.b instead would
        # replace the layer's bias with its gradient and leave b.g unset.
        self.b.g = out.g.sum(0)

In [69]:
class Mse(Module):
    """Mean-squared error between the first output column and the targets, on top of Module."""
    def forward(self, inp, targ):
        residual = inp[:, 0] - targ
        return residual.pow(2).mean()

    def bwd(self, out, inp, targ):
        # gradient of mean((inp[:,0]-targ)**2) w.r.t. inp, as an (N, 1) column
        residual = (inp[:, 0] - targ)[:, None]
        inp.g = 2 * residual / inp.shape[0]

In [70]:
# Rebuild Model: its __init__ looks up Lin/Relu/Mse when called, so it now uses the Module-based classes.
model = Model(w1, b1, w2, b2)

In [71]:
# Forward pass through the Module-based layers.
loss = model(x_train, y_train)

In [72]:
# Backward pass through the Module-based layers' bwd methods.
model.backward()

In [73]:
# Compare the Module-based gradients with autograd's.
# NOTE(review): w1.g/b1.g/w2.g/b2.g were already set by the earlier Model.backward() (In [36]),
# so these checks can pass even if a Module-based bwd never writes one of them. Clearing the
# .g attributes before model.backward() would make this a real test.
test_close(w12.grad, w1.g, eps=0.01)
test_close(w22.grad, w2.g, eps=0.01)
test_close(b12.grad, b1.g, eps=0.01)
test_close(b22.grad, b2.g, eps=0.01)

In [74]:
from torch import nn
import torch.nn.functional as F

In [85]:
class Linear(nn.Module):
    """Affine layer ``x @ w + b`` with ``w`` of shape (n_inp, n_out) and ``b`` of shape (n_out,).

    ``w`` and ``b`` are wrapped in ``nn.Parameter`` so they are registered with the
    module. That makes them appear in ``parameters()`` (and so in optimizers), move
    with ``.to()``, and get saved by ``state_dict()``. Plain ``requires_grad_()``
    tensors still receive gradients but are invisible to all of those.
    """
    def __init__(self, n_inp, n_out):
        super().__init__()
        self.w = nn.Parameter(torch.randn(n_inp, n_out))
        self.b = nn.Parameter(torch.zeros(n_out))

    def forward(self, x):
        return x@self.w+self.b

In [86]:
class Model(nn.Module):
    """Linear -> ReLU -> Linear network whose call returns the MSE loss against ``targ``.

    ``layers`` is an ``nn.ModuleList``, so the sub-layers are registered as children
    of this module and are reached by ``parameters()``, ``.to()`` and ``state_dict()``.
    The computation lives in ``forward`` so that ``nn.Module.__call__`` (with its hooks)
    still runs; ``model(x, targ)`` and ``model.layers[i]`` work as before.
    """
    def __init__(self, n_in, nh, n_out):
        super().__init__()
        self.layers = nn.ModuleList([Linear(n_in, nh), nn.ReLU(), Linear(nh, n_out)])

    def forward(self, x, targ):
        for l in self.layers:
            x = l(x)
        # add a trailing axis to targ so it lines up with the (N, n_out) prediction
        return F.mse_loss(x, targ[:,None])

In [87]:
# nn.Module version: fresh random weights; autograd computes the gradients.
model = Model(m, nh, 1)
loss = model(x_train, y_train)
loss.backward()

In [89]:
# First Linear layer, kept for inspecting its gradients.
m0 = model.layers[0]

In [93]:
# Bias gradient of the first layer after loss.backward().
m0.b.grad

tensor([-5.9069e-02, -4.7124e+00, -3.2190e+00, -6.5474e+01,  2.7367e+01,
        -1.2808e+01,  2.7148e-01, -1.3550e+00,  1.3779e+00,  1.4171e+01,
        -4.6234e+01, -2.2148e+01,  1.7744e+01, -2.9474e+01,  4.9907e+01,
         7.2142e+00, -6.2156e+00, -1.3886e+01,  7.5250e+01,  1.1279e+02,
         2.0816e+01,  4.5769e+00, -8.3816e+00,  1.0587e+01, -1.1703e+01,
         1.9713e-01, -2.9629e+01,  7.8672e-01,  1.0509e+01,  1.0844e+02,
         2.5222e+01,  1.2354e+01,  8.1031e+00,  7.5347e+00, -3.2608e+00,
         1.0320e+01,  4.6792e+01,  1.5885e+02,  1.4045e+01, -4.4256e+01,
        -4.2168e+01, -7.9100e+00, -4.9124e+01, -4.8970e+01,  6.2878e+00,
         8.1540e+00, -1.1929e+01,  4.5437e+01, -2.0906e+01,  1.3813e-02])