In [1]:
import gzip, pickle, torch
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(42)

<torch._C.Generator at 0x10fad0bb0>

In [2]:
data = None
with gzip.open('./data/mnist.pkl.gz', 'rb') as f:
    data = pickle.load(f, encoding='latin')

In [3]:
((x_train, y_train), (x_val, y_val), _) = data
x_train.shape, y_train.shape, x_val.shape, y_val.shape

((50000, 784), (50000,), (10000, 784), (10000,))

In [4]:
(x_train, y_train, x_val, y_val) = map(torch.tensor, (x_train, y_train, x_val, y_val))
x_train.shape, y_train.shape, x_val.shape, y_val.shape

(torch.Size([50000, 784]),
 torch.Size([50000]),
 torch.Size([10000, 784]),
 torch.Size([10000]))

In [35]:
n_hidden = 50

w1 = torch.randn((x_train.shape[1], n_hidden))
b1 = torch.zeros(n_hidden)
w2 = torch.randn((n_hidden, 1))
b2 = torch.zeros(1)

w1.shape, b1.shape, w2.shape, b2.shape

(torch.Size([784, 50]), torch.Size([50]), torch.Size([50, 1]), torch.Size([1]))

#### anatomy

In [6]:
# linear function
def lin(x, w, b):
    return x@w+b

In [7]:
temp_x = x_train.clone()

In [8]:
temp_x.shape

torch.Size([50000, 784])

In [9]:
t = lin(temp_x, w1, b1)
t.shape

torch.Size([50000, 50])

In [10]:
def relu(x):
    return x.clamp_min(0)

In [11]:
t2 = relu(t)
t2.shape

torch.Size([50000, 50])

In [12]:
t2.max()

tensor(56.5448)

In [13]:
def model(x):
    l1 = lin(x, w1, b1)
    l2 = relu(l1)
    return lin(l2, w2, b2)

In [14]:
t3 = model(temp_x)

In [15]:
t3.shape

torch.Size([50000, 1])

In [16]:
# loss function (mse)

In [17]:
# formula: mean((pred - actual)**2)
# total sample size here means batch size

In [18]:
temp_y = y_train.clone()
temp_y.shape

torch.Size([50000])

In [19]:
temp_y[:4]

tensor([5, 0, 4, 1])

In [20]:
temp_y.view(-1, 1).shape

torch.Size([50000, 1])

In [21]:
(t3 - temp_y.view(-1, 1)).pow(2).mean()

tensor(4308.7598)

In [22]:
# gradient and backward pass

In [36]:
def lingrad(out, x, w, b):
    b.g = out.g.sum(0)
    x.g = out.g @ w.t()
    w.g = (x.t() @ out.g)
    
def lingrad(out, inp, w, b):
    # grad of matmul with respect to input
    inp.g = out.g @ w.t()
    w.g = (inp.unsqueeze(-1) * out.g.unsqueeze(1)).sum(0)
    b.g = out.g.sum(0)

In [37]:
def forward_and_backward(xb, target):
    # forward pass and loss
    l1 = lin(xb, w1, b1)
    l2 = relu(l1)
    out = lin(l2, w2, b2)
    diff = (out - target.view(-1, 1))
    loss = diff.pow(2).mean()
    
    # calculate gradient
    out.g = 2*diff/xb.shape[0]
    lingrad(out, l2, w2, b2)
    l1.g = (l1 > 0).float() * l2.g
    lingrad(l1, xb, w1, b1)
    

In [None]:
forward_and_backward(x_train, y_train)

In [None]:
#  # forward pass:
#     l1 = lin(inp, w1, b1)
#     l2 = relu(l1)
#     out = lin(l2, w2, b2)
#     diff = out[:,0]-targ
#     loss = diff.pow(2).mean()
    
#     # backward pass:
#     out.g = 2.*diff[:,None] / inp.shape[0]
#     lin_grad(l2, out, w2, b2)
#     l1.g = (l1>0).float() * l2.g
#     lin_grad(inp, l1, w1, b1)

In [26]:
def get_grad(x):
    return x.g.clone()

chks = w1,b1,w2,b2,x_train
grads = tuple(map(get_grad, chks))

In [27]:
# calculate params grad using pytorch

In [28]:
def make_grad(x):
    return x.clone().requires_grad_(True)

ptgrads = ptw1, ptb1, ptw2, ptb2, ptx_train = tuple(map(make_grad, chks))

In [29]:
def pytorch_forward_backward(xb, target):
    l1 = lin(xb, ptw1, ptb1)
    l2 = relu(l1)
    out = lin(l2, ptw2, ptb2)
    diff = (out - target.view(-1, 1))
    loss =  diff.pow(2).mean()
    return loss

In [30]:
pytorch_loss = pytorch_forward_backward(ptx_train, y_train)

In [31]:
pytorch_loss.backward()

In [32]:
for a, b in zip(grads, ptgrads):
    print(torch.allclose(a, b))

False
False
False
False
False


In [33]:
grads[0]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [34]:
ptgrads[0]

tensor([[ 1.9269,  1.4873,  0.9007,  ..., -0.4879, -0.9138, -0.6581],
        [ 0.0780,  0.5258, -0.4880,  ...,  0.4880,  0.7846,  0.0286],
        [ 0.6408,  0.5832,  1.0669,  ...,  1.4506,  0.2695, -0.2104],
        ...,
        [-1.0870,  0.5770,  1.0833,  ..., -0.0442,  0.3561, -0.0153],
        [-0.7371, -0.1260, -1.2686,  ..., -0.2711,  0.3694,  0.2720],
        [ 0.4637, -0.7807,  2.4361,  ..., -1.6001, -0.0425, -1.2497]],
       requires_grad=True)