In [1]:
import torch, matplotlib.pyplot as plt, math, gzip, pickle
from torch import tensor
from pathlib import Path
from fastcore.test import test_close

In [2]:
data_path = Path('data')/'mnist.pkl.gz'
with gzip.open(data_path, 'rb') as f:
    data = pickle.load(f, encoding='latin-1')
((x_train, y_train), (x_val, y_val), _) = data
(x_train, y_train, x_val, y_val) = map(tensor, (x_train, y_train, x_val, y_val))
x_train.shape, y_train.shape, x_val.shape, y_val.shape

(torch.Size([50000, 784]),
 torch.Size([50000]),
 torch.Size([10000, 784]),
 torch.Size([10000]))

In [4]:
n,m = x_train.shape
c = y_train.max()+1
nh = 50

In [5]:
w1 = torch.randn(m, nh)
b1 = torch.zeros(nh)
w2 = torch.randn(nh, 1)
b2 = torch.zeros(1)

In [6]:
def lin(x, w, b):
    return x@w+b

In [7]:
t = lin(x_val, w1, b1)
t.shape

torch.Size([10000, 50])

In [8]:
def relu(x):
    return x.clamp_min(0.)

In [9]:
t = relu(t)
t

tensor([[ 9.1507,  0.0000,  0.0000,  ...,  2.9785,  0.0000, 13.6512],
        [ 0.8934,  0.0000,  1.9529,  ...,  0.9084,  0.0000,  2.8889],
        [ 0.0000,  1.7810, 10.7273,  ...,  6.8667,  9.9371,  4.4541],
        ...,
        [ 0.0000,  0.0000,  0.5849,  ..., 20.3842, 11.5470,  0.0000],
        [ 0.0000,  0.0000,  3.8437,  ...,  0.0000, 12.7863,  0.0000],
        [ 0.4631,  0.0000, 18.1102,  ..., 14.2793,  3.7057,  1.5324]])

In [13]:
def model(xb):
    l1 = lin(xb, w1, b1)
    l1 = relu(l1)
    return lin(l1, w2, b2)

In [14]:
res = model(x_val)
res.shape

torch.Size([10000, 1])

In [15]:
(res-y_val).shape

torch.Size([10000, 10000])

In [16]:
(res[:,0]-y_val).shape

torch.Size([10000])

In [17]:
(res.squeeze()-y_val).shape

torch.Size([10000])

In [18]:
y_train,y_val = y_train.float(),y_val.float()
preds = model(x_train)
preds.shape

torch.Size([50000, 1])

In [20]:
def mse(outp, targ):
    return (outp[:,0]-targ).pow(2).mean()

In [21]:
mse(preds, y_train)

tensor(6286.4141)

In [24]:
def lin_grad(inp, out, w, b):
    inp.g = out.g @ w.t()
    w.g = (inp.unsqueeze(-1) * out.g.unsqueeze(1)).sum(0)
    b.g = out.g.sum(0)

In [25]:
def forward_and_backward(inp, targ):
    l1 = lin(inp, w1, b1)
    l2 = relu(l1)
    out = lin(l2, w2, b2)
    diff = out[:,0]-targ
    loss = diff.pow(2).mean()
    
    out.g = 2.*diff[:,None]/inp.shape[0]
    lin_grad(l2, out, w2, b2)
    l1.g = (l1>0).float() * l2.g
    lin_grad(inp, l1, w1, b1)

In [26]:
%time forward_and_backward(x_train, y_train)

In [28]:
def get_grad(x):
    return x.g.clone()
chks = w1,w2,b1,b2,x_train
grads = w1g,w2g,b1g,b2g,ig = tuple(map(get_grad, chks))

In [29]:
def mkgrad(x):
    return x.clone().requires_grad_(True)
ptgrads = w12,w22,b12,b22,xt2 = tuple(map(mkgrad, chks))

In [30]:
def forward(inp, targ):
    l1 = lin(inp, w12, b12)
    l2 = relu(l1)
    out = lin(l2, w22, b22)
    return mse(out, targ)

In [31]:
loss = forward(xt2, y_train)
loss.backward()

In [34]:
for a, b in zip(grads, ptgrads):
    print(torch.allclose(a, b.grad, rtol=0.01))
    test_close(a, b.grad, eps=0.01)

True
True
True
True
True


In [17]:
from sympy import symbols,diff

In [18]:
x,y = symbols('x y')

In [19]:
diff(x**2, x)

2*x

In [20]:
diff(x**2)

2*x

In [21]:
diff(3*x**2+9, x)

6*x

In [22]:
w1 = torch.randn(m, nh)
b1 = torch.zeros(nh)
w2 = torch.randn(nh, 1)
b2 = torch.zeros(1)

In [213]:
def lin_grad(inp, outp, w, b):
    inp.g = outp.g@w.t()
    w.g = inp.t()@outp.g
    b.g = outp.g.sum(0)  
    return (w,b,inp)

def lin_gradv2(inp, out, w, b):
    # grad of matmul with respect to input
    inp.g = out.g @ w.t()
    w.g = (inp.unsqueeze(-1) * out.g.unsqueeze(1)).sum(0)
    b.g = out.g.sum(0)
    return (w,b,inp)    

In [215]:
def lin_gradv3(inp, out, w, b):
    # grad of matmul with respect to input
    inp.g = out.g @ w.t()
    w.g = (inp1[:,:,None] * out.g[:,None]).sum(0)
    b.g = out.g.sum(0)
    return (w,b,inp)    

In [216]:
sz = 1000*10*2
inp1 = x_train[:sz]
test_outp = torch.randn((sz, 50))
test_outp.g = test_outp.clone()

In [218]:
%time o1 = lin_grad(inp1.clone(), test_outp, w1.clone(), b1.clone())

CPU times: user 219 ms, sys: 163 ms, total: 382 ms
Wall time: 98.5 ms


In [219]:
%time o2 = lin_gradv2(inp1.clone(), test_outp, w1.clone(), b1.clone())

CPU times: user 11.2 s, sys: 3min 28s, total: 3min 40s
Wall time: 1min 28s


In [221]:
%time o4 = lin_gradv3(inp1.clone(), test_outp, w1.clone(), b1.clone())

CPU times: user 9.67 s, sys: 2min 49s, total: 2min 59s
Wall time: 1min 13s


In [226]:
for a,b in zip(o1, o2):
    print(torch.allclose(a.g, b.g, rtol=0.01))
    test_close(a.g, b.g, eps=0.01)

False
True
True


In [229]:
for a,b in zip(o1, o4):
    print(torch.allclose(a.g, b.g, rtol=0.01))
    test_close(a.g, b.g, eps=0.01)

False
True
True


In [111]:
o1[0].shape, o2[0].shape

(torch.Size([784, 50]), torch.Size([784, 50]))

In [125]:
o1[0].g[100]

tensor([ 10.8480, -39.3720,  30.6547, -17.2523,  12.8699,  18.2000, -28.5093,
          3.7808, -56.2174, -54.4469, -36.9470, -45.6147,  15.0857, -11.0759,
         -3.8133,  27.7800,  -8.7611,  44.8471, -55.6126,   4.0851, -33.8279,
         18.1730, -58.0049,  16.0285,  25.0503,  29.9208,  -1.4544, -51.3259,
          9.1065,  22.0640,   4.1045,   1.7488, -17.3718,  24.7957, -19.8788,
        -14.1662,  19.8739,  -2.7411, -58.3739,  70.6327,  20.9583,  16.4796,
         24.4096, -67.1682,  14.2389, -27.2982,  -3.5598,  -5.9210, -57.4400,
        -29.1226])

In [126]:
o2[0].g[100]

tensor([ 10.8480, -39.3720,  30.6547, -17.2522,  12.8699,  18.2000, -28.5093,
          3.7808, -56.2174, -54.4469, -36.9470, -45.6147,  15.0857, -11.0759,
         -3.8133,  27.7800,  -8.7610,  44.8471, -55.6126,   4.0851, -33.8279,
         18.1730, -58.0049,  16.0285,  25.0503,  29.9208,  -1.4544, -51.3259,
          9.1065,  22.0640,   4.1045,   1.7488, -17.3718,  24.7957, -19.8788,
        -14.1662,  19.8739,  -2.7411, -58.3739,  70.6327,  20.9583,  16.4796,
         24.4096, -67.1682,  14.2389, -27.2982,  -3.5598,  -5.9210, -57.4399,
        -29.1226])

In [156]:
test_close(o1[0].g[100], o2[0].g[100], eps=0.001)

In [162]:
for i, ele in enumerate(o1[0].g[100]):
    print(ele.item() == o2[0].g[100][i].item())
    print(ele.item(), o2[0].g[100][i].item())
    break

False
10.848006248474121 10.848001480102539


In [166]:
inp1.unsqueeze(-1).stride()

(784, 1, 1)

In [168]:
inp1.unsqueeze(-1).shape

torch.Size([20000, 784, 1])

In [176]:
inp1[:,:,None].shape

torch.Size([20000, 784, 1])

In [107]:
inp1.shape, inp1.unsqueeze(-1).shape

(torch.Size([20000, 784]), torch.Size([20000, 784, 1]))

In [110]:
test_outp.g.shape,test_outp.g.unsqueeze(1).shape,

(torch.Size([20000, 50]), torch.Size([20000, 1, 50]))

In [163]:
inp1.unsqueeze(-1).shape, test_outp.g.unsqueeze(1).shape,

(torch.Size([20000, 784, 1]), torch.Size([20000, 1, 50]))

In [23]:
def lin_grad(inp, outp, w, b):
    inp.g = outp.g@w.t()
    w.g = inp.t()@outp.g
    b.g = outp.g.sum(0)  

def lin_gradv2(inp, out, w, b):
    # grad of matmul with respect to input
    inp.g = out.g @ w.t()
    w.g = (inp.unsqueeze(-1) * out.g.unsqueeze(1)).sum(0)
    b.g = out.g.sum(0)   

In [24]:
def forward_and_backward(inp, targ):
    l1 = lin(inp, w1, b1)
    l2 = relu(l1)
    outp = lin(l1, w2, b2)
    diff = outp[:,0]-targ
    loss = diff.pow(2).mean()
    
    outp.g = 2*diff[:,None]/inp.shape[0]
#     lin_grad(l1, outp, w2, b2)
    lin_gradv2(l1, outp, w2, b2)
    l1.g = (l1 > 0).float()*l1.g
#     lin_grad(inp, l1, w1, b1)    
    lin_gradv2(inp, l1, w1, b1)

In [25]:
%time forward_and_backward(x_train, y_train)

CPU times: user 32.6 s, sys: 10min 18s, total: 10min 50s
Wall time: 4min 7s


In [33]:
def get_grad(x):
    return x.g.clone()
chks = w1,w2,b1,b2,x_train
grad = tuple(map(get_grad, chks))

In [34]:
def mk_grad(x):
    return x.clone().requires_grad_(True)
ptgrads = w12,w22,b12,b22,xt2 = tuple(map(mk_grad, chks))

In [35]:
def forward(inp, targ):
    l1 = lin(inp, w12, b12)
    l2 = relu(l1)
    outp = lin(l2, w22, b22)
    return mse(outp, targ)

In [36]:
loss = forward(xt2, y_train)
loss.backward()

In [37]:
for a,b in zip(grad, ptgrads):
    print(torch.allclose(a, b, rtol=0.01))

False
False
False
False
False


In [38]:
from fastcore.test import test_close

In [39]:
for a,b in zip(grad, ptgrads):
    print(test_close(a, b.grad, eps=0.01))

AssertionError: close:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])