In [1]:
# MNIST with an MLP, from scratch

# - Step 1: build an MLP from scratch to solve MNIST. Question set: https://fleuret.org/dlc/materials/dlc-practical-3.pdf
# - Step 2: debug your network with backprop ninja and a reference implementation using torch's .backward()
# - Step 3: build the same MLP but will full pytorch code (nn.Linear, etc.)

In [58]:
import math
import torch
from torch import nn
import matplotlib.pyplot as plt
%matplotlib inline

In [59]:
from utils import load_data

In [60]:
train_input, train_target, test_input, test_target = load_data(one_hot_labels = True, normalize = True)

* Using MNIST
** Reduce the data-set (use --full for the full thing)
** Use 1000 train and 1000 test samples


In [8]:
train_input.shape

torch.Size([1000, 784])

In [9]:
plt.imshow(train_input[4].view((28,28)))

<matplotlib.image.AxesImage at 0x14b5c2900>

In [10]:
def compute_accuracy (preds, targets):
    """ Computes the accuracy between predictions and targets. Data is expected to be one-hot encoded. """
    _, idx1 = torch.max(preds, dim=1)
    _, idx2 = torch.max(targets, dim=1)
    d = idx1 == idx2
    return d.int().float().mean().item()

In [11]:
# unit test
preds = torch.zeros((4,7))
preds[0,1] = 1
preds[1,4] = 1
preds[2,2] = 1
preds[3,6] = 1
targets = torch.zeros((4,7))
targets[0,1] = 1
targets[1,4] = 1
targets[2,2] = 1
targets[3,2] = 1
compute_accuracy(preds, targets)

0.75

In [12]:
def sigma(x):
    return torch.tanh(x)

def dsigma(x):
    return (1.0 - torch.tanh(x)**2)

In [13]:
def loss (v,t):
    out = torch.sum(torch.pow(v-t, 2))
    return out

def dloss(v,t):
    return -2.0 * (t-v)

In [17]:
v = torch.randn((3, 6), dtype=torch.float32)
t = torch.randn((3, 6), dtype=torch.float32)
l=loss(v,t)
dloss(v,t)

tensor([[ 0.1337, -1.4099, -0.7194,  0.9427,  5.7936,  2.3108],
        [ 1.4946,  1.2805,  0.9905, -1.9410, -1.7132,  1.4164],
        [-0.4312,  1.1042,  6.3547,  0.2399,  0.9457, -0.8823]])

In [18]:
# multiply targets by 0.9 to be in the range of tanh
train_target *= 0.9
test_target *= 0.9

### Step 1: Backprop ninja

In [19]:
# utility function we will use later when comparing manual gradients to PyTorch gradients
def cmp(s, dt, t):
  ex = torch.all(dt == t.grad).item()
  app = torch.allclose(dt, t.grad)
  maxdiff = (dt - t.grad).abs().max().item()
  print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')

In [20]:
torch.manual_seed(1337)
w1 = torch.randn((784, 50))
b1 = torch.randn((50))
w2 = torch.randn((50, 10))
b2 = torch.randn((10))
parameters = [w1, b1, w2, b2]
for p in parameters:
    p.requires_grad = True

In [21]:
x1 = train_input[:5]
y1 = train_target[:5]
z1 = x1 @ w1 + b1
h1 = sigma(z1)
z2 = h1 @ w2 + b2
h2 = sigma(z2)
l = loss(h2, y1)
h2.shape, l

(torch.Size([5, 10]), tensor(43.2825, grad_fn=<SumBackward0>))

In [22]:
others = [h2,z2,h1,z1]
for p in parameters:
    p.grad = None
for t in others:
    t.retain_grad()
l.backward()
print(f'loss={l}')

loss=43.282527923583984


In [23]:
b2.grad.shape

torch.Size([10])

In [24]:
u = torch.randn((50,))
z1.shape
#(z1 + torch.randn((50,))).shape

torch.Size([5, 50])

In [25]:
# here we compare our gradient to the reference gradient computed by pytorch
dl = 1.0
dh2 = dloss(h2, y1) * dl
dh2.shape
cmp('h2',dh2,h2)
dz2 = dsigma(z2) * dh2
dz2.shape
cmp('z2',dz2, z2)
dw2 = h1.T @ dz2
cmp('w2',dw2, w2)
db2 = dz2.sum(0, keepdim=True)
#print(b2.shape)
cmp('b2',db2, b2)
dh1 = dz2 @ w2.T
cmp('h1',dh1, h1)
dz1 = dsigma(z1) * dh1
cmp('z1', dz1, z1)
dw1 = x1.T @ dz1
cmp('w1', dw1, w1)
db1 = dz1.sum(0, keepdim=True)
cmp('b1', db1, b1)


h2              | exact: True  | approximate: True  | maxdiff: 0.0
z2              | exact: False | approximate: True  | maxdiff: 1.6298145055770874e-09
w2              | exact: True  | approximate: True  | maxdiff: 0.0
b2              | exact: True  | approximate: True  | maxdiff: 0.0
h1              | exact: False | approximate: True  | maxdiff: 1.4901161193847656e-08
z1              | exact: False | approximate: True  | maxdiff: 4.423782229423523e-09
w1              | exact: False | approximate: True  | maxdiff: 5.960464477539063e-08
b1              | exact: True  | approximate: True  | maxdiff: 0.0


In [26]:
lr = 0.1
with torch.no_grad():
    w1 += -lr * dw1
    b1 += -lr * db1.squeeze()
    w2 += -lr * dw2
    b2 += -lr * db2.squeeze()

In [27]:
l = loss(h2, y1)
l.item()

43.282527923583984

##### Now that we've checked our gradients are correct, we can implement the network

In [28]:
def forward(w1, b1, w2, b2, x):
    z1 = x @ w1 + b1
    h1 = sigma(z1)
    z2 = h1 @ w2 + b2
    h2 = sigma(z2)
    return z1, h1, z2, h2


In [29]:
def backward(w1, b1, w2, b2, x1, y1, h2, z2, h1, z1):
    dl = 1.0
    dh2 = dloss(h2, y1) * dl
    dz2 = dsigma(z2) * dh2
    dw2 = h1.T @ dz2
    db2 = dz2.sum(0, keepdim=True)
    dh1 = dz2 @ w2.T
    dz1 = dsigma(z1) * dh1
    dw1 = x1.T @ dz1
    db1 = dz1.sum(0, keepdim=True)
    return dw1, db1, dw2, db2

In [30]:
def update(w1, b1, w2, b2, dw1, db1, dw2, db2, lr):
    with torch.no_grad():
        w1 += -lr * dw1
        b1 += -lr * db1.squeeze()
        w2 += -lr * dw2
        b2 += -lr * db2.squeeze()
    return w1, b1, w2, b2

In [31]:
def init():
    """ init a network """
    torch.manual_seed(1337)
    istd = math.sqrt(1e-6)
    w1 = torch.zeros((784, 50))
    torch.nn.init.normal_(w1, mean=.0, std=istd)
    b1 = torch.zeros((50,))
    w2 = torch.zeros((50, 10))
    torch.nn.init.normal_(w2, mean=.0, std=istd)
    b2 = torch.zeros((10,))
    return w1, b1, w2, b2

In [32]:
w1, b1, w2, b2 = init()
parameters = [w1, b1, w2, b2]
for p in parameters:
    p.requires_grad_(True)

In [33]:
train_input.shape

torch.Size([1000, 784])

In [34]:
# main training loop
torch.set_printoptions(linewidth=200)
def train(w1, b1, w2, b2):
    lossi = []
    for step in range(10000):
        xb = train_input
        yb = train_target
        num_samples = xb.shape[0]
        #print(f'{num_samples=}')
        # forward
        z1, h1, z2, h2 = forward(w1, b1, w2, b2, xb)
        lsi = loss(h2, yb)
        # backward
        dw1, db1, dw2, db2 = backward(w1, b1, w2, b2, xb, yb, h2, z2, h1, z1)
        # update
        lr = 0.1 / num_samples if step < 5000 else 0.01 / num_samples
        w1, b1, w2, b2 = update(w1, b1, w2, b2, dw1, db1, dw2, db2, lr)
        if step % 100 == 0: print(f'step = {step}, loss = {lsi}')
        lossi.append(lsi.item())
    # compute accuracy
    _, _, _, preds = forward(w1, b1, w2, b2, train_input)
    train_accuracy = compute_accuracy(preds, train_target)
    _, _, _, preds = forward(w1, b1, w2, b2, test_input)
    test_accuracy = compute_accuracy(preds, test_target)
    print(f'{train_accuracy=}')
    print(f'{test_accuracy=}')
    return lossi

        
    

In [35]:
lossi = train(w1, b1, w2, b2)


step = 0, loss = 809.9498291015625
step = 100, loss = 288.54022216796875
step = 200, loss = 195.53509521484375
step = 300, loss = 139.04873657226562
step = 400, loss = 135.3380126953125
step = 500, loss = 99.42034149169922
step = 600, loss = 82.91433715820312
step = 700, loss = 72.19515991210938
step = 800, loss = 62.638893127441406
step = 900, loss = 61.924476623535156
step = 1000, loss = 58.28996276855469
step = 1100, loss = 51.019866943359375
step = 1200, loss = 41.54011535644531
step = 1300, loss = 55.92261505126953
step = 1400, loss = 31.951839447021484
step = 1500, loss = 37.646663665771484
step = 1600, loss = 33.600311279296875
step = 1700, loss = 29.174652099609375
step = 1800, loss = 24.88680648803711
step = 1900, loss = 29.827388763427734
step = 2000, loss = 27.54845428466797
step = 2100, loss = 27.90268325805664
step = 2200, loss = 24.179628372192383
step = 2300, loss = 24.80556869506836
step = 2400, loss = 21.696060180664062
step = 2500, loss = 18.241796493530273
step = 260

In [36]:
plt.plot(lossi)
#print(lossi)

[<matplotlib.lines.Line2D at 0x169cad970>]

### Step 2: Reference implementation using pytorch's .backward()

In [37]:
w1, b1, w2, b2 = init()
parameters = [w1, b1, w2, b2]
for p in parameters:
    p.requires_grad_(True)

In [38]:
# reference code
torch.set_printoptions(linewidth=200)
import torch.nn as F

def train(w1, b1, w2, b2):
    lossi = []
    for step in range(10000):
        xb = train_input
        yb = train_target
        num_samples = xb.shape[0]
        # forward
        z1, h1, z2, h2 = forward(w1, b1, w2, b2, xb)
        xloss = F.MSELoss()
        lsi = xloss(h2, yb) * yb.nelement()
        # backward
        for p in parameters:
            p.grad = None
        lsi.backward()
        # update
        lr = 0.1 / num_samples
        for p in parameters:
            p.data += -lr * p.grad
        if step % 100 == 0: print(f'step = {step}, loss = {lsi}')
        lossi.append(lsi.item())
    # compute accuracy
    _, _, _, preds = forward(w1, b1, w2, b2, train_input)
    train_accuracy = compute_accuracy(preds, train_target)
    _, _, _, preds = forward(w1, b1, w2, b2, test_input)
    test_accuracy = compute_accuracy(preds, test_target)
    print(f'{train_accuracy=}')
    print(f'{test_accuracy=}')
    return lossi

In [39]:
lossi = train(w1, b1, w2, b2)

step = 0, loss = 809.9498901367188
step = 100, loss = 288.32464599609375
step = 200, loss = 181.5811767578125
step = 300, loss = 145.126220703125
step = 400, loss = 126.37185668945312
step = 500, loss = 95.3000717163086
step = 600, loss = 79.75885772705078
step = 700, loss = 72.97013092041016
step = 800, loss = 60.800933837890625
step = 900, loss = 67.0023422241211
step = 1000, loss = 54.57250213623047
step = 1100, loss = 67.70767974853516
step = 1200, loss = 47.989532470703125
step = 1300, loss = 42.321044921875
step = 1400, loss = 37.51057815551758
step = 1500, loss = 47.89775085449219
step = 1600, loss = 40.352420806884766
step = 1700, loss = 34.35886764526367
step = 1800, loss = 32.882774353027344
step = 1900, loss = 25.800189971923828
step = 2000, loss = 27.818504333496094
step = 2100, loss = 24.97629737854004
step = 2200, loss = 34.46820068359375
step = 2300, loss = 24.227466583251953
step = 2400, loss = 25.058338165283203
step = 2500, loss = 20.51760482788086
step = 2600, loss =

In [50]:
plt.plot(lossi)
plt.plot()

[]

### Step 3: Build the same MLP layer but with fully pytorch code (nn.Linear(), etc.)

In [54]:
# network dimensions
n_in = 784
n_hidden = 200
n_out = 10

In [55]:
X_tr, Y_tr = train_input, train_target
X_test, Y_test = test_input, test_target

In [56]:
class MLP(nn.Module):

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList((
            nn.Linear(n_in, n_hidden, bias=True),
            nn.Tanh(),
            nn.Linear(n_hidden, n_out, bias=True), 
            nn.Tanh(), 
        ))

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
    
    def __parameters__(self):
        return [p for layer in self.layers for p in layer.parameters]

model = MLP()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

In [57]:
# training
num_epochs = 10000

for n in range(num_epochs):
    y_pred = model(X_tr)
    loss = loss_fn(y_pred, Y_tr)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if n % 1000 == 0: 
        with torch.no_grad():
            # train accuracy
            acc_train = compute_accuracy(y_pred, Y_tr)
            # test accuracy
            y_test_preds = model(X_test)
            acc_test = compute_accuracy(y_test_preds, Y_test)
            print(f'step = {n:6d}\tloss={loss.item():.5f}\taccuracy (train, test): {acc_train:.5f}\t{acc_test:.5f}')




step =      0	loss=0.14046	accuracy (train, test): 0.11500	0.19800
step =   1000	loss=0.00009	accuracy (train, test): 1.00000	0.87300
step =   2000	loss=0.00005	accuracy (train, test): 1.00000	0.87500
step =   3000	loss=0.00009	accuracy (train, test): 1.00000	0.87300
step =   4000	loss=0.00002	accuracy (train, test): 1.00000	0.87700
step =   5000	loss=0.00003	accuracy (train, test): 1.00000	0.87800
step =   6000	loss=0.00012	accuracy (train, test): 1.00000	0.88100
step =   7000	loss=0.00000	accuracy (train, test): 1.00000	0.88000
step =   8000	loss=0.00004	accuracy (train, test): 1.00000	0.88100
step =   9000	loss=0.00003	accuracy (train, test): 1.00000	0.88200


##### Exercise: try to improve accuracy!