In [1]:
import torch, matplotlib.pyplot as plt, math, gzip, pickle
from torch import tensor
from pathlib import Path
from fastcore.test import test_close

In [2]:
data_path = Path('data') / 'mnist.pkl.gz'
# NOTE: pickle.load can execute arbitrary code — only load this file from a trusted source.
with gzip.open(data_path, 'rb') as f:
    data = pickle.load(f, encoding='latin-1')
# Unpack the (train, validation, <third split, unused here>) pairs, then convert each array to a tensor.
(x_train, y_train), (x_val, y_val), _ = data
x_train, y_train, x_val, y_val = (tensor(arr) for arr in (x_train, y_train, x_val, y_val))
x_train.shape, y_train.shape, x_val.shape, y_val.shape

(torch.Size([50000, 784]),
 torch.Size([50000]),
 torch.Size([10000, 784]),
 torch.Size([10000]))

In [35]:
n, m = x_train.shape   # n training rows, m input features per row
c = 1 + y_train.max()  # number of classes — assumes labels run 0..c-1 (TODO confirm)
nh = 50                # width of the hidden layer

In [36]:
# Seed the global RNG so the weight init (and every output derived from it)
# is identical after Restart & Run All.
torch.manual_seed(42)
# NOTE(review): plain N(0, 1) init with no fan-in scaling, so activations and the
# initial MSE are large; that is left as-is here.
w1 = torch.randn(m, nh)   # (m, nh)  input -> hidden weights
b1 = torch.zeros(nh)      # (nh,)    hidden bias
w2 = torch.randn(nh, 1)   # (nh, 1)  hidden -> single output weights
b2 = torch.zeros(1)       # (1,)     output bias

In [5]:
def lin(x, w, b):
    """Affine layer: matrix-multiply `x` by `w`, then add the bias `b` (broadcast over rows)."""
    projected = x @ w
    return projected + b

In [6]:
# Push the validation set through the first linear layer only.
t = lin(x=x_val, w=w1, b=b1)
t.shape

torch.Size([10000, 50])

In [7]:
def relu(x):
    """Rectified linear unit: replace every negative entry of `x` with 0."""
    return x.clamp(min=0.)

In [8]:
# Rebinds `t` to its activated version; relu is idempotent, so re-running this cell is harmless.
t = relu(x=t)
t

tensor([[ 4.5253,  0.6355,  3.6919,  ...,  3.1195,  0.0000,  0.2296],
        [ 1.6364,  0.0000,  0.0000,  ...,  8.0429,  0.0000,  0.0000],
        [ 0.0000, 11.7669,  0.0000,  ...,  3.1532,  2.2159,  8.6557],
        ...,
        [ 0.0000,  0.0000,  9.5680,  ...,  0.2190,  0.2477,  0.0000],
        [ 0.0000,  0.0000,  7.9458,  ...,  0.0000,  5.1125,  7.2693],
        [ 0.0000,  5.8269,  4.1611,  ...,  3.2949,  8.5024,  0.0000]])

In [9]:
def model(xb):
    """Two-layer MLP: lin(w1, b1) -> relu -> lin(w2, b2), using the global weights."""
    hidden = relu(lin(xb, w1, b1))
    return lin(hidden, w2, b2)

In [10]:
# Full forward pass on the validation set; output has one column per row.
res = model(xb=x_val)
res.shape

torch.Size([10000, 1])

In [11]:
res.sub(y_val).shape  # pitfall: a (N, 1) column minus an (N,) vector broadcasts to (N, N)

torch.Size([10000, 10000])

In [12]:
res[:, 0].sub(y_val).shape  # take column 0 first, so shapes match element-wise

torch.Size([10000])

In [13]:
# Squeeze only the trailing size-1 output dim: a bare .squeeze() would also drop the
# batch dim for a single-row batch and silently change the broadcast.
(res.squeeze(-1)-y_val).shape

torch.Size([10000])

In [14]:
# Use the digit labels as float regression targets.
y_train = y_train.float()
y_val = y_val.float()
preds = model(x_train)
preds.shape

torch.Size([50000, 1])

In [15]:
def mse(outp, targ):
    """Mean squared error between column 0 of `outp` and the target vector `targ`."""
    err = outp[:, 0] - targ
    return err.square().mean()

In [16]:
mse(outp=preds, targ=y_train)  # training-set loss of the untrained model

tensor(2209.5806)

In [17]:
def lin_grad(inp, out, w, b):
    """Backward pass of `lin` (out = inp @ w + b); gradients are stored on `.g` attributes.

    Reads the upstream gradient ``out.g`` (2-D, one row per sample) and writes:
      - ``inp.g = out.g @ w.T``
      - ``w.g`` = sum over the batch of the per-sample outer products inp[n] x out.g[n]
      - ``b.g = out.g.sum(0)``
    """
    inp.g = out.g @ w.t()
    # Same contraction as (inp.unsqueeze(-1) * out.g.unsqueeze(1)).sum(0), but einsum
    # never materialises the (batch, in, out) intermediate. For the MNIST training set
    # (50000 x 784 inputs, 50 hidden units) that would be ~7.8 GB of float32.
    w.g = torch.einsum('ni,nj->ij', inp, out.g)
    b.g = out.g.sum(0)

In [18]:
def lin_grad_v2(inp, out, w, b):
    """Backward pass of `lin` using plain matrix products; results go on `.g` attributes.

    Given upstream ``out.g``: ``inp.g = out.g @ w.T``, ``w.g = inp.T @ out.g``,
    ``b.g`` = column sums of ``out.g``.
    """
    inp.g, w.g, b.g = (
        out.g @ w.t(),
        inp.t() @ out.g,
        out.g.sum(0),
    )

In [19]:
def forward_and_backward(inp, targ):
    """Run the two-layer model forward and back-propagate the MSE gradient by hand.

    Uses the global weights w1, b1, w2, b2 and stores gradients on their `.g`
    attributes (plus ``inp.g`` and the intermediates' ``.g``) via `lin_grad_v2`.

    Returns:
        The scalar MSE loss of the forward pass (previously computed and discarded).
    """
    # forward pass
    l1 = lin(inp, w1, b1)
    l2 = relu(l1)
    out = lin(l2, w2, b2)
    diff = out[:,0]-targ
    loss = diff.pow(2).mean()

    # backward pass
    # d mean(diff**2) / d out = 2*diff/N, kept as an (N, 1) column to match `out`
    out.g = 2.*diff[:,None]/inp.shape[0]
    lin_grad_v2(l2, out, w2, b2)
    # relu lets gradient through only where its input was positive
    l1.g = (l1>0).float() * l2.g
    lin_grad_v2(inp, l1, w1, b1)
    return loss

In [20]:
# Time one hand-written forward + backward pass over the full training set.
%time forward_and_backward(x_train, y_train)

CPU times: user 775 ms, sys: 229 ms, total: 1 s
Wall time: 181 ms


In [21]:
def get_grad(x):
    """Return an independent copy of the hand-computed gradient stored on `x.g`."""
    g = x.g
    return g.clone()
# Snapshot the manual gradients before the autograd run below.
chks = (w1, w2, b1, b2, x_train)
grads = tuple(get_grad(p) for p in chks)
w1g, w2g, b1g, b2g, ig = grads

In [22]:
def mkgrad(x):
    """Return a copy of `x` with autograd tracking switched on (the original is untouched)."""
    tracked = x.clone()
    tracked.requires_grad_(True)
    return tracked
# Autograd-tracked copies of the parameters/inputs (despite the name, these hold tensors, not grads).
ptgrads = tuple(mkgrad(p) for p in chks)
w12, w22, b12, b22, xt2 = ptgrads

In [23]:
def forward(inp, targ):
    """Forward pass + MSE loss using the autograd-tracked copies w12, b12, w22, b22."""
    hidden = relu(lin(inp, w12, b12))
    prediction = lin(hidden, w22, b22)
    return mse(prediction, targ)

In [24]:
# Let autograd compute reference gradients into each tracked copy's `.grad`.
loss = forward(inp=xt2, targ=y_train)
loss.backward()

In [25]:
# Compare each hand-computed gradient with autograd's `.grad` on the matching tracked copy.
# The 0.01 tolerance presumably absorbs float rounding from the different computation orders.
for a, b in zip(grads, ptgrads):
    print(torch.allclose(a, b.grad, rtol=0.01))
    test_close(a, b.grad, eps=0.01)

True
True
True
True
True


In [37]:
def lin_grad(inp, outp, w, b):
    """Backward pass of `lin`: from upstream ``outp.g``, store ``inp.g``, ``w.g`` and ``b.g``.

    NOTE(review): redefines the earlier `lin_grad`, and this version is identical to `lin_grad_v2`.
    """
    upstream = outp.g
    inp.g = upstream @ w.t()
    w.g = inp.t() @ upstream
    b.g = upstream.sum(0)

In [38]:
def forward_and_backward(inp, targ):
    """Run the two-layer model forward and back-propagate the MSE gradient by hand.

    Uses the global weights w1, b1, w2, b2 and stores gradients on their `.g`
    attributes (plus ``inp.g`` and the intermediates' ``.g``) via `lin_grad`.

    Returns:
        The scalar MSE loss of the forward pass (previously computed and discarded).
    """
    # forward pass
    l1 = lin(inp, w1, b1)
    l2 = relu(l1)
    outp = lin(l2, w2, b2)
    diff = outp[:,0]-targ
    loss = diff.pow(2).mean()

    # backward pass
    # d mean(diff**2) / d outp = 2*diff/N, kept as an (N, 1) column to match `outp`
    outp.g = 2.*diff[:,None]/inp.shape[0]
    lin_grad(l2, outp, w2, b2)
    # relu lets gradient through only where its input was positive
    l1.g = (l1 > 0).float()*l2.g
    lin_grad(inp, l1, w1, b1)
    return loss

In [39]:
# Re-time the forward + backward pass after redefining lin_grad / forward_and_backward.
%time forward_and_backward(x_train, y_train)

CPU times: user 793 ms, sys: 282 ms, total: 1.08 s
Wall time: 188 ms


In [40]:
def get_grad(x):
    """Copy out the manual gradient `x.g` so later passes cannot overwrite it.

    NOTE(review): duplicate of the earlier `get_grad` definition.
    """
    manual = x.g
    return manual.clone()
# Snapshot the manual gradients (w1, w2, b1, b2, input) in that order.
chks = (w1, w2, b1, b2, x_train)
grad = tuple(get_grad(p) for p in chks)

In [41]:
def mk_grad(x):
    """Clone `x` and enable autograd tracking on the clone."""
    copy_ = x.clone()
    copy_.requires_grad_(True)
    return copy_
# Autograd-tracked copies of the parameters/inputs, in the same order as `chks`.
ptgrads = tuple(mk_grad(p) for p in chks)
w12, w22, b12, b22, xt2 = ptgrads

In [42]:
def forward(inp, targ):
    """Loss of the two-layer model computed with the autograd-tracked weights."""
    activated = relu(lin(inp, w12, b12))
    return mse(lin(activated, w22, b22), targ)

In [43]:
# Populate `.grad` on each tracked copy for comparison with the manual gradients.
loss = forward(inp=xt2, targ=y_train)
loss.backward()

In [45]:
# For each parameter/input, check the manual gradient against autograd's within rtol=0.01.
for a,b in zip(grad, ptgrads):
    print(torch.allclose(a, b.grad, rtol=0.01))

True
True
True
True
True


In [46]:
from fastcore.test import test_close

In [47]:
# Same comparison via fastcore's test_close, which is expected to raise if any pair
# differs by more than eps (verify against fastcore docs).
for a,b in zip(grad, ptgrads):
    test_close(a, b.grad, eps=0.01)

None
None
None
None
None
