In [1]:
import math
import torch

import dlc_practical_prologue as prologue

## 1. Activation function

In [2]:
def sigma(x):
    """Activation function: element-wise hyperbolic tangent of tensor ``x``."""
    return torch.tanh(x)

def dsigma(x):
    """Derivative of ``sigma`` (tanh), evaluated element-wise at tensor ``x``.

    Uses the identity tanh'(x) = sech(x)^2 = 4 / (e^x + e^-x)^2.
    For large |x| one exponential overflows to inf, and inf^-2 == 0 gives the
    correct limit.
    """
    exp_sum = x.exp() + (-x).exp()
    return 4 * exp_sum.pow(-2)

## 2. Loss

In [3]:
def loss(v, t):
    """Summed squared error between prediction ``v`` and target ``t``.

    Returns a 0-dim tensor equal to sum((v - t)^2).
    """
    residual = v - t
    return torch.sum(residual ** 2)

def dloss(v, t):
    """Gradient of ``loss`` with respect to the prediction ``v``: 2 * (v - t)."""
    return (v - t).mul(2)

## 3. Forward and backward passes

In [4]:
def forward_pass(w1, b1, w2, b2, x):
    """Run the two-layer network forward on one input vector ``x``.

    Computes
        s1 = w1 @ x + b1,   x1 = sigma(s1),
        s2 = w2 @ x1 + b2,  x2 = sigma(s2).

    Returns:
        Tuple ``(x0, s1, x1, s2, x2)``. ``x0`` is ``x`` itself (not a copy).
        All intermediates are returned because ``backward_pass`` needs them.
    """
    x0 = x
    # Hidden layer: affine map followed by the activation.
    s1 = torch.mv(w1, x0) + b1
    x1 = sigma(s1)
    # Output layer.
    s2 = torch.mv(w2, x1) + b2
    return x0, s1, x1, s2, sigma(s2)

def backward_pass(w1, b1, w2, b2,
                  t,
                  x, s1, x1, s2, x2,
                  dl_dw1, dl_db1, dl_dw2, dl_db2):
    """Back-propagate one sample's loss and ACCUMULATE the parameter gradients.

    Args:
        w1, b1, w2, b2: network parameters. Only ``w2`` is read. The others
            are accepted so the signature mirrors ``forward_pass``.
        t: target for this sample.
        x, s1, x1, s2, x2: the tuple returned by ``forward_pass`` for this sample.
        dl_dw1, dl_db1, dl_dw2, dl_db2: gradient buffers, updated in place with
            ``add_``. The caller is responsible for zeroing them.

    Returns:
        None. The only effect is the in-place update of the gradient buffers.
    """
    def _outer(col, row):
        # Rank-1 matrix col (x) row, built as an (n,1) x (1,m) matrix product.
        return col.view(-1, 1).mm(row.view(1, -1))

    # Output layer: chain rule through the loss, then the second activation.
    grad_s2 = dloss(x2, t) * dsigma(s2)
    # Hidden layer: back through w2, then through the first activation.
    grad_s1 = w2.t().mv(grad_s2) * dsigma(s1)

    dl_dw2.add_(_outer(grad_s2, x1))
    dl_db2.add_(grad_s2)
    dl_dw1.add_(_outer(grad_s1, x))
    dl_db1.add_(grad_s1)

## 4. Training the network

In [None]:
# Seed the global RNG so the weight initialisation below, and therefore every
# printed loss/error, is reproducible under Restart & Run All. Seeding before
# load_data also covers any random sampling it might do (not visible here).
SEED = 0
torch.manual_seed(SEED)

train_input, train_target, test_input, test_target = prologue.load_data(one_hot_labels = True,
                                                                        normalize = True)

nb_classes = train_target.size(1)
nb_train_samples = train_input.size(0)
nb_test_samples = test_input.size(0)

# Shrink the one-hot targets by zeta. This presumably keeps them strictly inside
# the (-1, 1) output range of tanh, so the network is not driven into saturation.
zeta = 0.90

train_target = train_target * zeta
test_target = test_target * zeta

nb_hidden = 50
nb_epochs = 1000
# Gradients are summed (not averaged) over the whole training set below,
# so the learning rate is divided by the number of samples.
eta = 1e-1 / nb_train_samples
# Standard deviation of the (tiny) Gaussian weight initialisation.
epsilon = 1e-6

w1 = torch.empty(nb_hidden, train_input.size(1)).normal_(0, epsilon)
b1 = torch.empty(nb_hidden).normal_(0, epsilon)
w2 = torch.empty(nb_classes, nb_hidden).normal_(0, epsilon)
b2 = torch.empty(nb_classes).normal_(0, epsilon)

# Gradient accumulators, same shapes as the parameters; reset every epoch.
dl_dw1 = torch.zeros_like(w1)
dl_db1 = torch.zeros_like(b1)
dl_dw2 = torch.zeros_like(w2)
dl_db2 = torch.zeros_like(b2)

for k in range(nb_epochs):

    acc_loss = 0
    nb_train_errors = 0

    dl_dw1.zero_()
    dl_db1.zero_()
    dl_dw2.zero_()
    dl_db2.zero_()

    # Full-batch pass: accumulate loss, training errors and gradients.
    for n in range(nb_train_samples):
        x0, s1, x1, s2, x2 = forward_pass(w1, b1, w2, b2, train_input[n])

        # The predicted class is the arg-max of the outputs. It is counted as an
        # error when the (scaled) target at that index is below 0.5.
        pred = x2.max(0)[1].item()
        if train_target[n, pred] < 0.5:
            nb_train_errors += 1
        # Accumulate as a Python float (float64) rather than a float32 tensor.
        acc_loss += loss(x2, train_target[n]).item()

        backward_pass(w1, b1, w2, b2,
                      train_target[n],
                      x0, s1, x1, s2, x2,
                      dl_dw1, dl_db1, dl_dw2, dl_db2)

    # Gradient step, in place (the parameter tensors are not re-allocated).
    w1.sub_(eta * dl_dw1)
    b1.sub_(eta * dl_db1)
    w2.sub_(eta * dl_dw2)
    b2.sub_(eta * dl_db2)

    # Test error, measured with the freshly updated parameters.
    nb_test_errors = 0

    for n in range(nb_test_samples):
        _, _, _, _, x2 = forward_pass(w1, b1, w2, b2, test_input[n])

        pred = x2.max(0)[1].item()
        if test_target[n, pred] < 0.5:
            nb_test_errors += 1

    print('{:d} acc_train_loss {:.02f} acc_train_error {:.02f}% test_error {:.02f}%'
          .format(k,
                  acc_loss,
                  (100 * nb_train_errors) / nb_train_samples,
                  (100 * nb_test_errors) / nb_test_samples))

* Using MNIST
** Reduce the data-set (use --full for the full thing)
** Use 1000 train and 1000 test samples
0 acc_train_loss 810.00 acc_train_error 90.30% test_error 90.10%
1 acc_train_loss 780.58 acc_train_error 88.30% test_error 90.10%
2 acc_train_loss 761.77 acc_train_error 88.30% test_error 90.10%
3 acc_train_loss 749.74 acc_train_error 88.30% test_error 90.10%
4 acc_train_loss 742.04 acc_train_error 88.30% test_error 90.10%
5 acc_train_loss 737.11 acc_train_error 88.30% test_error 90.10%
6 acc_train_loss 733.95 acc_train_error 88.30% test_error 90.10%
7 acc_train_loss 731.92 acc_train_error 88.30% test_error 90.10%
8 acc_train_loss 730.62 acc_train_error 88.30% test_error 90.10%
9 acc_train_loss 729.78 acc_train_error 88.30% test_error 90.10%
10 acc_train_loss 729.25 acc_train_error 88.30% test_error 90.10%
11 acc_train_loss 728.90 acc_train_error 88.30% test_error 90.10%
12 acc_train_loss 728.68 acc_train_error 88.30% test_error 90.10%
13 acc_train_loss 728.53 acc_train_error 88

124 acc_train_loss 309.72 acc_train_error 12.00% test_error 22.20%
125 acc_train_loss 308.06 acc_train_error 12.60% test_error 24.10%
126 acc_train_loss 313.37 acc_train_error 12.80% test_error 21.70%
127 acc_train_loss 318.32 acc_train_error 12.10% test_error 24.30%
128 acc_train_loss 318.30 acc_train_error 13.30% test_error 20.30%
129 acc_train_loss 306.42 acc_train_error 11.20% test_error 23.20%
130 acc_train_loss 298.55 acc_train_error 12.80% test_error 21.40%
131 acc_train_loss 294.96 acc_train_error 10.00% test_error 22.30%
132 acc_train_loss 300.73 acc_train_error 12.90% test_error 23.20%
133 acc_train_loss 309.22 acc_train_error 11.10% test_error 23.90%
134 acc_train_loss 313.19 acc_train_error 13.50% test_error 23.20%
135 acc_train_loss 307.75 acc_train_error 11.00% test_error 21.00%
136 acc_train_loss 304.09 acc_train_error 11.00% test_error 24.50%
137 acc_train_loss 314.62 acc_train_error 13.50% test_error 22.70%
138 acc_train_loss 324.68 acc_train_error 10.80% test_error 25

248 acc_train_loss 203.71 acc_train_error 4.90% test_error 16.50%
249 acc_train_loss 200.48 acc_train_error 5.00% test_error 16.40%
250 acc_train_loss 192.69 acc_train_error 5.10% test_error 16.40%
251 acc_train_loss 186.63 acc_train_error 5.00% test_error 16.40%
252 acc_train_loss 181.59 acc_train_error 5.00% test_error 16.20%
253 acc_train_loss 178.20 acc_train_error 4.40% test_error 16.70%
254 acc_train_loss 174.64 acc_train_error 4.60% test_error 16.30%
255 acc_train_loss 171.33 acc_train_error 4.30% test_error 16.10%
256 acc_train_loss 167.74 acc_train_error 4.50% test_error 16.50%
257 acc_train_loss 164.65 acc_train_error 3.80% test_error 15.50%
258 acc_train_loss 162.06 acc_train_error 4.20% test_error 16.40%
259 acc_train_loss 160.28 acc_train_error 4.00% test_error 15.80%
260 acc_train_loss 159.43 acc_train_error 4.30% test_error 16.20%
261 acc_train_loss 159.66 acc_train_error 4.20% test_error 15.80%
262 acc_train_loss 161.48 acc_train_error 4.30% test_error 15.70%
263 acc_tr

374 acc_train_loss 154.48 acc_train_error 3.10% test_error 16.90%
375 acc_train_loss 159.67 acc_train_error 2.90% test_error 15.90%
376 acc_train_loss 158.88 acc_train_error 3.10% test_error 15.60%
377 acc_train_loss 154.02 acc_train_error 3.10% test_error 15.10%
378 acc_train_loss 148.20 acc_train_error 2.90% test_error 14.50%
379 acc_train_loss 144.64 acc_train_error 3.00% test_error 15.80%
380 acc_train_loss 145.92 acc_train_error 2.60% test_error 14.30%
381 acc_train_loss 147.75 acc_train_error 3.20% test_error 17.10%
382 acc_train_loss 153.55 acc_train_error 2.90% test_error 14.80%
383 acc_train_loss 153.97 acc_train_error 3.00% test_error 16.50%
384 acc_train_loss 157.01 acc_train_error 3.40% test_error 15.80%
385 acc_train_loss 153.90 acc_train_error 3.20% test_error 15.60%
386 acc_train_loss 152.80 acc_train_error 3.40% test_error 16.70%
387 acc_train_loss 151.59 acc_train_error 3.30% test_error 14.80%
388 acc_train_loss 149.02 acc_train_error 2.80% test_error 16.60%
389 acc_tr

500 acc_train_loss 125.33 acc_train_error 2.10% test_error 15.40%
501 acc_train_loss 117.78 acc_train_error 1.80% test_error 15.90%
502 acc_train_loss 110.98 acc_train_error 1.40% test_error 14.90%
503 acc_train_loss 104.95 acc_train_error 1.90% test_error 15.50%
504 acc_train_loss 101.41 acc_train_error 1.50% test_error 14.20%
505 acc_train_loss 99.33 acc_train_error 1.90% test_error 15.00%
506 acc_train_loss 99.28 acc_train_error 1.60% test_error 13.80%
507 acc_train_loss 100.76 acc_train_error 1.90% test_error 15.40%
508 acc_train_loss 103.69 acc_train_error 1.60% test_error 14.10%
509 acc_train_loss 108.24 acc_train_error 1.70% test_error 15.30%
510 acc_train_loss 112.58 acc_train_error 1.60% test_error 14.70%
511 acc_train_loss 117.16 acc_train_error 1.70% test_error 15.00%
512 acc_train_loss 118.55 acc_train_error 1.70% test_error 15.00%
513 acc_train_loss 118.54 acc_train_error 1.90% test_error 14.50%
514 acc_train_loss 115.97 acc_train_error 1.80% test_error 15.50%
515 acc_trai

625 acc_train_loss 88.04 acc_train_error 1.00% test_error 15.00%
626 acc_train_loss 85.92 acc_train_error 1.00% test_error 14.70%
627 acc_train_loss 85.12 acc_train_error 1.00% test_error 15.00%
628 acc_train_loss 83.83 acc_train_error 1.00% test_error 14.50%
629 acc_train_loss 83.61 acc_train_error 1.10% test_error 15.00%
630 acc_train_loss 82.68 acc_train_error 1.00% test_error 14.50%
631 acc_train_loss 82.49 acc_train_error 0.90% test_error 15.00%
632 acc_train_loss 81.63 acc_train_error 1.00% test_error 14.40%
633 acc_train_loss 81.24 acc_train_error 0.80% test_error 15.00%
634 acc_train_loss 80.38 acc_train_error 1.00% test_error 14.30%
635 acc_train_loss 79.84 acc_train_error 0.80% test_error 15.10%
636 acc_train_loss 79.11 acc_train_error 0.90% test_error 14.40%
637 acc_train_loss 78.63 acc_train_error 0.70% test_error 15.00%
638 acc_train_loss 78.21 acc_train_error 0.90% test_error 14.40%
639 acc_train_loss 78.02 acc_train_error 0.70% test_error 15.00%
640 acc_train_loss 78.06 

752 acc_train_loss 107.99 acc_train_error 0.50% test_error 14.00%
753 acc_train_loss 103.57 acc_train_error 0.80% test_error 17.20%
754 acc_train_loss 94.28 acc_train_error 0.60% test_error 14.40%
755 acc_train_loss 85.61 acc_train_error 0.70% test_error 16.30%
756 acc_train_loss 79.91 acc_train_error 0.80% test_error 14.50%
757 acc_train_loss 76.32 acc_train_error 0.60% test_error 16.00%
758 acc_train_loss 75.09 acc_train_error 0.60% test_error 14.70%
759 acc_train_loss 74.51 acc_train_error 0.50% test_error 15.60%
760 acc_train_loss 75.22 acc_train_error 0.60% test_error 15.00%
761 acc_train_loss 75.83 acc_train_error 0.50% test_error 15.80%
762 acc_train_loss 77.23 acc_train_error 0.60% test_error 15.00%
763 acc_train_loss 78.00 acc_train_error 0.50% test_error 16.10%
764 acc_train_loss 79.14 acc_train_error 0.60% test_error 15.10%
765 acc_train_loss 79.24 acc_train_error 0.50% test_error 16.10%
766 acc_train_loss 79.26 acc_train_error 0.60% test_error 15.20%
767 acc_train_loss 78.2