In [0]:
import torch
import torch.nn as nn
import fastai.datasets as datasets
import gzip
import pickle
import math
import torch.nn.functional as F
import numpy as np
import torch.nn.init as init
import pdb

In [0]:
MNIST_URL = 'http://deeplearning.net/data/mnist/mnist.pkl'
fpath = datasets.download_data(MNIST_URL, ext='.gz')

with gzip.open(fpath, 'rb') as fp:
  ((x_train, y_train), (x_val, y_val), _) = pickle.load(fp, encoding='latin-1')
x_train.shape, y_train.shape, x_val.shape, y_val.shape

((50000, 784), (50000,), (10000, 784), (10000,))

In [0]:
x_train.mean(), x_train.std(), x_val.mean(), x_val.std()

(0.13044983, 0.3072898, 0.12865187, 0.3049646)

### Normalize

In [0]:
def norm(x, m, s): return (x - m) / s

In [0]:
m, s = x_train.mean(), x_train.std()
x_train, x_val = norm(x_train, m, s), norm(x_val, m, s)
x_val.mean(), x_val.std()

(-0.005850922, 0.99243325)

In [0]:
x_train[0], x_val[0]

(array([-0.424517, -0.424517, -0.424517, -0.424517, ..., -0.424517, -0.424517, -0.424517, -0.424517], dtype=float32),
 array([-0.424517, -0.424517, -0.424517, -0.424517, ..., -0.424517, -0.424517, -0.424517, -0.424517], dtype=float32))

In [0]:
x_train, y_train, x_val, y_val = map(torch.tensor, [x_train, y_train, x_val, y_val])

### Model: linear -> relu -> linear -> mse

In [0]:
def mse(y_hat, y):
  return (y_hat.squeeze() - y.float()).pow(2).mean()

def linear(x, w, b):
  return x @ w + b

def relu(x):
  return x.clamp_min(0.) - 0.5

def model(x):
  x = relu(linear(x, W1, b1))
  return linear(x, W2, b2)

nh = [100, 100]
W1 = torch.zeros(784, nh[0])
b1 = torch.zeros(nh[0])
W2 = torch.zeros(nh[0], 1)
b2 = torch.zeros(1)

init.kaiming_normal_(W1, mode='fan_out')
init.kaiming_normal_(W2, mode='fan_out')
y_hat = model(x_train)
mse(y_hat, y_train)

tensor(31.3463)

### Gradient Descent

In [0]:
def mse_grad(outp, targ):
  outp.g = 2 / outp.shape[0] * (outp.squeeze() - targ.float()).unsqueeze(-1)
  
def lin_grad(inp, outp, w, b):
#   pdb.set_trace()
  inp.g = outp.g @ w.t()
  w.g = inp.t() @ outp.g
  b.g = outp.g.sum(0)
  
def relu_grad(inp, outp):
  inp.g = (inp > 0.).float() * outp.g

### Chain Rule

In [0]:
def forward_backward(x, y):
  z1 = x @ W1 + b1
  a1 = relu(z1)
  outp = a1 @ W2 + b2
  loss = mse(outp, y)
  
  mse_grad(outp, y)
  lin_grad(a1, outp, W2, b2)
  relu_grad(z1, a1)
  lin_grad(x, z1, W1, b1)

In [0]:
forward_backward(x_train, y_train)

In [0]:
xg = x_train.g.clone()
w1g = W1.g.clone()
w2g = W2.g.clone()
b1g = b1.g.clone()
b2g = b2.g.clone()

In [0]:
x2 = x_train.clone().requires_grad_(True)
w11 = W1.clone().requires_grad_(True)
b11 = b1.clone().requires_grad_(True)
w22 = W2.clone().requires_grad_(True)
b22 = b2.clone().requires_grad_(True)

In [0]:
def forward_backward2(x, y):
  l1 = x @ w11 + b11
  z1 = relu(l1)
  outp = z1 @ w22 + b22
  return mse(outp, y)

In [0]:
loss = forward_backward2(x_train, y_train)
loss.backward()

In [0]:
def test_near(a, b): return np.allclose(a, b)

assert test_near(w1g, w11.grad)
assert test_near(b1g, b11.grad)
assert test_near(w2g, w22.grad)
assert test_near(b2g, b22.grad)

### Refactor Model

In [0]:
class Relu():
  def __call__(self, inp):
    self.inp = inp
    self.outp = inp.clamp_min(0.) - 0.5
    return self.outp
  
  def backward(self):
    self.inp.g = (self.inp > 0.).float() * self.outp.g

In [0]:
class Lin():
  def __init__(self, w, b):
    self.w, self.b = w, b
    
  def __call__(self, x):
    self.inp = x
    self.outp = x @ self.w + self.b
    return self.outp
  
  def backward(self):
    self.inp.g = self.outp.g @ self.w.t()
    self.w.g = self.inp.t() @ self.outp.g
    self.b.g = self.outp.g.sum(0)

In [0]:
class MSE():
  def __call__(self, inp, targ):
    self.inp, self.targ = inp, targ
    self.outp = (inp.squeeze() - targ.float()).pow(2).mean()
    return self.outp
  
  def backward(self):
    self.inp.g = 2 / self.inp.shape[0] * (self.inp.squeeze() - self.targ.float()).unsqueeze(-1)

In [0]:
class Model():
  def __init__(self, w1, b1, w2, b2):
    self.layers = [Lin(w1, b1), Relu(), Lin(w2, b2)]
    self.loss_fn = MSE()
    
  def __call__(self, x, y):
    for o in self.layers:
      x = o(x)
    return self.loss_fn(x, y)
  
  def backward(self):
    self.loss_fn.backward()
#     pdb.set_trace()
    for o in reversed(self.layers):
      o.backward()

In [0]:
nh = 100
w1 = torch.zeros(784, nh)
b1 = torch.zeros(nh)
w2 = torch.zeros(nh, 1)
b2 = torch.zeros(1)

init.kaiming_normal_(w1, mode='fan_out')
init.kaiming_normal_(w2, mode='fan_out');

In [0]:
m = Model(w1, b1, w2, b2)

In [0]:
m(x_train, y_train)
m.backward()

In [0]:
def forward_backward2(x, y):
  l1 = x @ w11 + b11
  z1 = relu(l1)
  outp = z1 @ w22 + b22
  return mse(outp, y)

In [0]:
x2 = x_train.clone().requires_grad_(True)
w11 = w1.clone().requires_grad_(True)
b11 = b1.clone().requires_grad_(True)
w22 = w2.clone().requires_grad_(True)
b22 = b2.clone().requires_grad_(True)

loss = forward_backward2(x_train, y_train)
loss.backward()

In [0]:
loss = forward_backward2(x_train, y_train)
loss.backward()

In [0]:
print(test_near(w1.g, w11.grad))
print(test_near(b1.g, b11.grad))
print(test_near(w2.g, w22.grad))
print(test_near(b2.g, b22.grad))

False
False
False
False


In [0]:
%time m = Model(w1, b1, w2, b2)

CPU times: user 198 µs, sys: 1.91 ms, total: 2.11 ms
Wall time: 3.25 ms


In [0]:
%time m(x_train, y_train)

CPU times: user 152 ms, sys: 976 µs, total: 153 ms
Wall time: 156 ms


tensor(25.1401)

In [0]:
%time m.backward()

CPU times: user 279 ms, sys: 1.13 ms, total: 280 ms
Wall time: 282 ms


In [0]:
%time loss = forward_backward2(x_train, y_train)

CPU times: user 152 ms, sys: 979 µs, total: 153 ms
Wall time: 156 ms


In [0]:
%time loss.backward()

CPU times: user 157 ms, sys: 3.17 ms, total: 160 ms
Wall time: 163 ms


In [0]:
class Module():
  def __call__(self, *args):
    self.args = args
    self.outp = self.forward(*self.args)
    return self.outp
  
  def forward(self, *args):
    raise Exception('Not Implemented')
    
  def backward(self):
    self.bwd(self.outp, *self.args)

In [0]:
class Relu(Module):
  def forward(self, x):
    return x.clamp_min(0.) - 0.5
  
  def bwd(self, outp, inp):
    inp.g = (inp > 0.).float() * outp.g

In [0]:
class Lin(Module):
  def __init__(self, w, b):
    self.w, self.b = w, b
    
  def forward(self, x):
    return x @ self.w + self.b
  
  def bwd(self, outp, inp):
    self.w.g = inp.t() @ outp.g
    inp.g = outp.g @ self.w.t()
    self.b.g = outp.g.sum(0)

In [0]:
class MSE(Module):
  def forward(self, outp, targ):
    return (outp.squeeze() - targ.float()).pow(2).mean()
  
  def bwd(self, outp, *args):
    outp, targ = args
    outp.g = 2 / outp.shape[0] * (outp.squeeze() - targ.float()).unsqueeze(-1)

In [0]:
class Model(Module):
  def __init__(self, w1, b1, w2, b2):
    self.layers = [Lin(w1, b1), Relu(), Lin(w2, b2)]
    self.crit = MSE()
    
  def forward(self, x, y):
    for o in self.layers:
      x = o(x)
    return self.crit(x, y)
  
  def bwd(self, outp, *args):
    self.crit.backward()
    for o in reversed(self.layers):
      o.backward()

In [0]:
w1.g, b1.g, w2.g, b2.g = [None] * 4
m = Model(w1, b1, w2, b2)

In [0]:
%time loss = m(x_train, y_train)

CPU times: user 152 ms, sys: 962 µs, total: 153 ms
Wall time: 156 ms


In [0]:
%time m.backward()

CPU times: user 285 ms, sys: 2.63 ms, total: 288 ms
Wall time: 297 ms


In [0]:
x2 = x_train.clone().requires_grad_(True)
w11 = w1.clone().requires_grad_(True)
b11 = b1.clone().requires_grad_(True)
w22 = w2.clone().requires_grad_(True)
b22 = b2.clone().requires_grad_(True)

loss = forward_backward2(x_train, y_train)
loss.backward()

In [0]:
assert test_near(w1.g, w11.grad)
assert test_near(b1.g, b11.grad)
assert test_near(w2.g, w22.grad)
assert test_near(b2.g, b22.grad)

In [0]:
class Optim():
  def __init__(self, lr, *parameters):
    self.lr, self.parameters = lr, parameters
  
  def step(self):
    for o in self.parameters:
      o = o - self.lr * o.g
  
  def zero_grad(self):
    for o in self.parameters:
      o.g = 0.

In [0]:
def train_loop(x, y, model, crit, opt, bs=128, epoch=1):
  for n in range(epoch):
    nr_batch = math.ceil(x.shape[0] / bs)
    for i in range(nr_batch):
      start = bs * i
      end = start + bs
      x_batch = x[start:end]
      y_batch = y[start:end]
      outp = model(x_batch)
      loss = crit(outp, y_batch)
      if (i + 1) % 30 == 0:
        print("mini-batch %d, loss: %.3f" % (i + 1, loss))
      opt.zero_grad()
      crit.backward()
      model.backward()
      opt.step()

In [0]:
class Model(Module):
  def __init__(self, w1, b1, w2, b2):
    self.layers = [Lin(w1, b1), Relu(), Lin(w2, b2)]
    
  def forward(self, x):
    for o in self.layers:
      x = o(x)
    return x
  
  def bwd(self, outp, *args):
    for o in reversed(self.layers):
      o.backward()

In [0]:
lr = 1e-3
m = Model(w1, b1, w2, b2)
opt = Optim(lr, w1, b1, w2, b2)
crit = MSE()

train_loop(x_train, y_train, m, crit, opt)

mini-batch 30, loss: 27.630
mini-batch 60, loss: 20.840
mini-batch 90, loss: 24.904
mini-batch 120, loss: 25.661
mini-batch 150, loss: 25.589
mini-batch 180, loss: 21.758
mini-batch 210, loss: 26.606
mini-batch 240, loss: 27.184
mini-batch 270, loss: 22.656
mini-batch 300, loss: 25.626
mini-batch 330, loss: 27.401
mini-batch 360, loss: 25.355
mini-batch 390, loss: 25.346


In [0]:
class PytorchModel(nn.Module):
  def __init__(self):
    super().__init__()
    self.layers = [nn.Linear(784, 100), nn.ReLU(), nn.Linear(100, 1)]
    self.crit = nn.MSELoss()
    
  def forward(self, x, y):
    for o in self.layers:
      x = o(x)
    return self.crit(x, y)

In [0]:
pm = PytorchModel().cuda()

In [0]:
%time loss = pm(x_train, y_train.float())

  return F.mse_loss(input, target, reduction=self.reduction)


CPU times: user 3.61 s, sys: 0 ns, total: 3.61 s
Wall time: 3.62 s


In [0]:
%time loss.backward()

CPU times: user 9.34 s, sys: 49.7 ms, total: 9.39 s
Wall time: 9.41 s


In [0]:
a = pm.layers[0]

In [0]:
list(a.parameters())[0].shape

torch.Size([100, 784])