In [1]:
import math
import torch
import os

In [2]:
os.getcwd()

'C:\\Users\\zeogo\\Documents\\ML books\\Pytorch course\\exercises practical'

In [3]:
os.chdir("..")

In [4]:
import torch
import torchvision
import dlc_practical_prologue as prologue

In [6]:
def sigma(x):
    active = x.tanh()
    return active

In [18]:
def dsigma(x):
    return (1 - torch.pow(x.tanh(), 2))

In [25]:
def loss(v, t):
    return (v - t).pow(2).sum()

In [26]:
def dloss(v, t):
    return 2*(v - t)

In [27]:
def forward_pass(w1, b1, w2, b2, x):
    x0 = x
    s1 = w1.mv(x0) + b1
    x1 = sigma(s1)
    s2 = w2.mv(x1) + b2
    x2 = sigma(s2)

    return x0, s1, x1, s2, x2

def backward_pass(w1, b1, w2, b2,
                  t,
                  x, s1, x1, s2, x2,
                  dl_dw1, dl_db1, dl_dw2, dl_db2):
    x0 = x
    dl_dx2 = dloss(x2, t)
    dl_ds2 = dsigma(s2) * dl_dx2
    dl_dx1 = w2.t().mv(dl_ds2)
    dl_ds1 = dsigma(s1) * dl_dx1

    dl_dw2.add_(dl_ds2.view(-1, 1).mm(x1.view(1, -1)))
    dl_db2.add_(dl_ds2)
    dl_dw1.add_(dl_ds1.view(-1, 1).mm(x0.view(1, -1)))
    dl_db1.add_(dl_ds1)

######################################################################

train_input, train_target, test_input, test_target = prologue.load_data(one_hot_labels = True,
                                                                        normalize = True)

nb_classes = train_target.size(1)
nb_train_samples = train_input.size(0)

zeta = 0.90

train_input = train_input * zeta
test_input = test_input * zeta

nb_hidden = 50
eta = 1e-1 / nb_train_samples
epsilon = 1e-6

w1 = torch.empty(nb_hidden, train_input.size(1)).normal_(0, epsilon)
b1 = torch.empty(nb_hidden).normal_(0, epsilon)
w2 = torch.empty(nb_classes, nb_hidden).normal_(0, epsilon)
b2 = torch.empty(nb_classes).normal_(0, epsilon)

dl_dw1 = torch.empty(w1.size())
dl_db1 = torch.empty(b1.size())
dl_dw2 = torch.empty(w2.size())
dl_db2 = torch.empty(b2.size())

for k in range(1000):

    # Back-prop

    acc_loss = 0
    nb_train_errors = 0

    dl_dw1.zero_()
    dl_db1.zero_()
    dl_dw2.zero_()
    dl_db2.zero_()

    for n in range(nb_train_samples):
        x0, s1, x1, s2, x2 = forward_pass(w1, b1, w2, b2, train_input[n])

        pred = x2.max(0)[1].item()
        if train_target[n, pred] < 0.5: nb_train_errors = nb_train_errors + 1
        acc_loss = acc_loss + loss(x2, train_target[n])

        backward_pass(w1, b1, w2, b2,
                      train_target[n],
                      x0, s1, x1, s2, x2,
                      dl_dw1, dl_db1, dl_dw2, dl_db2)

    # Gradient step

    w1 = w1 - eta * dl_dw1
    b1 = b1 - eta * dl_db1
    w2 = w2 - eta * dl_dw2
    b2 = b2 - eta * dl_db2

    # Test error

    nb_test_errors = 0

    for n in range(test_input.size(0)):
        _, _, _, _, x2 = forward_pass(w1, b1, w2, b2, test_input[n])

        pred = x2.max(0)[1].item()
        if test_target[n, pred] < 0.5: nb_test_errors = nb_test_errors + 1

    print('{:d} acc_train_loss {:.02f} acc_train_error {:.02f}% test_error {:.02f}%'
          .format(k,
                  acc_loss,
                  (100 * nb_train_errors) / train_input.size(0),
                  (100 * nb_test_errors) / test_input.size(0)))

* Using MNIST
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/mnist/MNIST\raw\train-images-idx3-ubyte.gz


 99%|███████████████████████████████████████████████████████████████████▌| 9846784/9912422 [00:23<00:00, 853730.44it/s]

Extracting ./data/mnist/MNIST\raw\train-images-idx3-ubyte.gz to ./data/mnist/MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/mnist/MNIST\raw\train-labels-idx1-ubyte.gz



0it [00:00, ?it/s]
  0%|                                                                                        | 0/28881 [00:00<?, ?it/s]
 57%|████████████████████████████████████████▊                               | 16384/28881 [00:00<00:00, 105801.38it/s]
32768it [00:00, 63438.54it/s]                                                                                          

Extracting ./data/mnist/MNIST\raw\train-labels-idx1-ubyte.gz to ./data/mnist/MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/mnist/MNIST\raw\t10k-images-idx3-ubyte.gz



0it [00:00, ?it/s]
  0%|                                                                                      | 0/1648877 [00:00<?, ?it/s]
  1%|▋                                                                      | 16384/1648877 [00:00<00:20, 80475.78it/s]
  3%|██                                                                     | 49152/1648877 [00:00<00:16, 99465.60it/s]
  6%|████▏                                                                 | 98304/1648877 [00:00<00:12, 126418.47it/s]
 10%|██████▊                                                              | 163840/1648877 [00:00<00:09, 160757.73it/s]
 15%|██████████▎                                                          | 245760/1648877 [00:01<00:06, 202949.04it/s]
 20%|█████████████▋                                                       | 327680/1648877 [00:01<00:05, 250984.07it/s]
 25%|█████████████████▏                                                   | 409600/1648877 [00:01<00:04, 302567.84it/s]
 30%|███████████████

Extracting ./data/mnist/MNIST\raw\t10k-images-idx3-ubyte.gz to ./data/mnist/MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/mnist/MNIST\raw\t10k-labels-idx1-ubyte.gz



0it [00:00, ?it/s]
  0%|                                                                                         | 0/4542 [00:00<?, ?it/s]
8192it [00:00, 24162.35it/s]                                                                                           

Extracting ./data/mnist/MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data/mnist/MNIST\raw
Processing...
Done!
** Reduce the data-set (use --full for the full thing)
** Use 1000 train and 1000 test samples
0 acc_train_loss 1000.00 acc_train_error 91.30% test_error 90.10%
1 acc_train_loss 963.68 acc_train_error 88.30% test_error 90.10%
2 acc_train_loss 940.46 acc_train_error 88.30% test_error 90.10%
3 acc_train_loss 925.61 acc_train_error 88.30% test_error 90.10%
4 acc_train_loss 916.12 acc_train_error 88.30% test_error 90.10%
5 acc_train_loss 910.04 acc_train_error 88.30% test_error 90.10%
6 acc_train_loss 906.14 acc_train_error 88.30% test_error 90.10%
7 acc_train_loss 903.63 acc_train_error 88.30% test_error 90.10%
8 acc_train_loss 902.02 acc_train_error 88.30% test_error 90.10%
9 acc_train_loss 900.98 acc_train_error 88.30% test_error 90.10%
10 acc_train_loss 900.32 acc_train_error 88.30% test_error 90.10%
11 acc_train_loss 899.88 acc_train_error 88.30% test_error 90.10%
12 acc_train_los

9920512it [00:40, 853730.44it/s]                                                                                       

16 acc_train_loss 899.19 acc_train_error 88.30% test_error 90.10%
17 acc_train_loss 899.16 acc_train_error 88.30% test_error 90.10%
18 acc_train_loss 899.13 acc_train_error 88.30% test_error 90.10%
19 acc_train_loss 899.12 acc_train_error 88.30% test_error 90.10%
20 acc_train_loss 899.10 acc_train_error 88.30% test_error 90.10%
21 acc_train_loss 899.08 acc_train_error 88.30% test_error 90.10%
22 acc_train_loss 899.05 acc_train_error 88.30% test_error 90.10%
23 acc_train_loss 899.01 acc_train_error 88.30% test_error 90.10%
24 acc_train_loss 898.93 acc_train_error 88.30% test_error 90.10%
25 acc_train_loss 898.80 acc_train_error 88.30% test_error 90.10%
26 acc_train_loss 898.56 acc_train_error 88.30% test_error 85.00%
27 acc_train_loss 898.15 acc_train_error 81.00% test_error 79.20%
28 acc_train_loss 897.41 acc_train_error 78.50% test_error 79.20%
29 acc_train_loss 896.12 acc_train_error 78.60% test_error 79.10%
30 acc_train_loss 893.90 acc_train_error 77.20% test_error 75.50%
31 acc_tra

KeyboardInterrupt: 