In [1]:
import torch
from dlc_practical_prologue import load_data

In [2]:
def sigma(x):
    return torch.tanh(x) 

In [3]:
def dsigma(x):
    return 1-(sigma(x))**2

In [4]:
def loss(v,t):
    return torch.sum(torch.pow((t-v),2))

In [5]:
def dloss(v,t):
    return 2*(t-v)

In [6]:
train_input, train_target, test_input, test_target = load_data(one_hot_labels=True,normalize = True)

* Using MNIST
** Reduce the data-set (use --full for the full thing)
** Use 1000 train and 1000 test samples




In [7]:
def forward_pass(w1,b1,w2,b2,x):
    x = x.view(784,1)
    x0 = x
    #print(x0.shape,w1.shape,b1.shape)
    s1 = torch.add(torch.mm(w1,x),b1)
    x1 = sigma(s1)
    s2 = torch.add(torch.mm(w2,x1),b2)
    x2 = sigma(s2)
    return x0,s1,x1,s2,x2
   

In [74]:
# k = train_input[0].view(784,1)
# k.shape

In [8]:
n_hidden_layers=50
x_size = train_input.size(1)
number_classes = train_target.size(1)
mu = 0
std = 1e-6
w1 = torch.normal(mu,std,size = (n_hidden_layers,x_size))
w2 = torch.normal(mu,std,size = (number_classes,n_hidden_layers))
b1 = torch.normal(mu,std,size = (n_hidden_layers,1))
b2 = torch.normal(mu,std,size = (number_classes,1))
dl_dw1 = torch.empty(w1.size())
dl_db1 = torch.empty(b1.size())
dl_dw2 = torch.empty(w2.size())
dl_db2 = torch.empty(b2.size())

In [9]:
def backward_pass(w1,b1,w2,b2,t,x,s1,x1,s2,x2,dl_dw1,dl_db1,dl_dw2,dl_db2):

#_____using chain rule to compute back prop    
    dl_dx2 = dloss(x2,t)
    dl_ds2 = torch.mm(dl_dx2,dsigma(s2))
    dl_dw2.add_(dl_ds2.view(-1, 1).mm(x1.view(1, -1)))
    dl_db2.add_(dl_ds2)
    
    dl_dx1 = torch.mm(dl_ds2.t(),w2)
    dl_ds1 = torch.mm(dl_dx1,dsigma(s1))
    dl_dw1.add_(dl_ds1.view(-1, 1).mm(x.view(1, -1)))
    dl_db1.add_(dl_ds1)
    

In [11]:
nb_classes = train_target.size(1)
nb_train_samples = train_input.size(0)

zeta = 0.90

train_input = train_input * zeta
test_input = test_input * zeta

nb_hidden = 50
eta = 1e-1 / nb_train_samples
epsilon = 1e-6
for k in range(1000):

    # Back-prop

    acc_loss = 0
    nb_train_errors = 0

    dl_dw1.zero_()
    dl_db1.zero_()
    dl_dw2.zero_()
    dl_db2.zero_()

    for n in range(nb_train_samples):
        x0, s1, x1, s2, x2 = forward_pass(w1, b1, w2, b2, train_input[n])

        pred = x2.max(0)[1].item()
        if train_target[n, pred] < 0.5: nb_train_errors = nb_train_errors + 1
        acc_loss = acc_loss + loss(x2, train_target[n])

        backward_pass(w1, b1, w2, b2,
                      train_target[n],
                      x0, s1, x1, s2, x2,
                      dl_dw1, dl_db1, dl_dw2, dl_db2)

    # Gradient step

    w1 = w1 - eta * dl_dw1
    b1 = b1 - eta * dl_db1
    w2 = w2 - eta * dl_dw2
    b2 = b2 - eta * dl_db2

    # Test error

    nb_test_errors = 0

    for n in range(test_input.size(0)):
        _, _, _, _, x2 = forward_pass(w1, b1, w2, b2, test_input[n])

        pred = x2.max(0)[1].item()
        if test_target[n, pred] < 0.5: nb_test_errors = nb_test_errors + 1

    print('{:d} acc_train_loss {:.02f} acc_train_error {:.02f}% test_error {:.02f}%'
          .format(k,
                  acc_loss,
                  (100 * nb_train_errors) / train_input.size(0),
                  (100 * nb_test_errors) / test_input.size(0)))

0 acc_train_loss 10000.00 acc_train_error 90.30% test_error 91.50%
1 acc_train_loss 17843.07 acc_train_error 90.30% test_error 91.50%
2 acc_train_loss 64945.97 acc_train_error 90.30% test_error 91.40%
3 acc_train_loss 128922.86 acc_train_error 90.30% test_error 90.70%
4 acc_train_loss 129876.70 acc_train_error 89.20% test_error 90.10%
5 acc_train_loss 129905.71 acc_train_error 89.80% test_error 90.50%
6 acc_train_loss 129921.73 acc_train_error 89.50% test_error 90.10%
7 acc_train_loss 129931.55 acc_train_error 90.00% test_error 91.80%
8 acc_train_loss 129943.27 acc_train_error 91.10% test_error 90.20%
9 acc_train_loss 129947.08 acc_train_error 89.80% test_error 90.50%
10 acc_train_loss 129953.08 acc_train_error 89.90% test_error 90.10%
11 acc_train_loss 129958.98 acc_train_error 90.10% test_error 91.70%
12 acc_train_loss 129960.90 acc_train_error 91.40% test_error 90.20%
13 acc_train_loss 129962.75 acc_train_error 89.80% test_error 90.50%
14 acc_train_loss 129967.21 acc_train_error 89.

119 acc_train_loss 129994.15 acc_train_error 89.90% test_error 90.60%
120 acc_train_loss 129994.15 acc_train_error 90.00% test_error 90.60%
121 acc_train_loss 129994.15 acc_train_error 90.00% test_error 90.60%
122 acc_train_loss 129994.15 acc_train_error 90.00% test_error 90.50%
123 acc_train_loss 129994.59 acc_train_error 90.00% test_error 90.60%
124 acc_train_loss 129997.98 acc_train_error 90.00% test_error 90.60%
125 acc_train_loss 129998.02 acc_train_error 90.00% test_error 90.60%
126 acc_train_loss 129998.02 acc_train_error 90.00% test_error 90.60%
127 acc_train_loss 129998.02 acc_train_error 90.00% test_error 90.60%
128 acc_train_loss 129998.02 acc_train_error 90.00% test_error 90.60%
129 acc_train_loss 129998.02 acc_train_error 90.00% test_error 90.60%
130 acc_train_loss 129998.02 acc_train_error 90.00% test_error 90.60%
131 acc_train_loss 129998.02 acc_train_error 90.00% test_error 90.60%
132 acc_train_loss 129998.04 acc_train_error 90.00% test_error 90.60%
133 acc_train_loss 1

237 acc_train_loss 129998.52 acc_train_error 90.00% test_error 90.50%
238 acc_train_loss 129998.52 acc_train_error 90.10% test_error 90.60%
239 acc_train_loss 129998.52 acc_train_error 90.00% test_error 90.60%
240 acc_train_loss 129998.52 acc_train_error 90.00% test_error 90.60%
241 acc_train_loss 129998.52 acc_train_error 90.00% test_error 90.60%
242 acc_train_loss 129998.52 acc_train_error 90.00% test_error 90.60%
243 acc_train_loss 129998.52 acc_train_error 90.00% test_error 90.60%
244 acc_train_loss 129998.52 acc_train_error 90.00% test_error 90.60%
245 acc_train_loss 129998.52 acc_train_error 90.00% test_error 90.60%
246 acc_train_loss 129999.16 acc_train_error 90.00% test_error 90.60%
247 acc_train_loss 129999.20 acc_train_error 90.00% test_error 90.60%
248 acc_train_loss 129999.49 acc_train_error 90.00% test_error 90.60%
249 acc_train_loss 129999.49 acc_train_error 90.00% test_error 90.60%
250 acc_train_loss 129999.50 acc_train_error 90.00% test_error 90.40%
251 acc_train_loss 1

355 acc_train_loss 129999.59 acc_train_error 90.00% test_error 90.50%
356 acc_train_loss 129999.59 acc_train_error 90.00% test_error 90.60%
357 acc_train_loss 129999.59 acc_train_error 90.00% test_error 90.60%
358 acc_train_loss 129999.59 acc_train_error 90.00% test_error 90.60%
359 acc_train_loss 129999.59 acc_train_error 90.00% test_error 90.60%
360 acc_train_loss 129999.59 acc_train_error 90.00% test_error 90.60%
361 acc_train_loss 129999.59 acc_train_error 90.00% test_error 90.60%
362 acc_train_loss 129999.59 acc_train_error 90.00% test_error 90.60%
363 acc_train_loss 129999.59 acc_train_error 90.00% test_error 90.60%
364 acc_train_loss 129999.59 acc_train_error 90.00% test_error 90.60%
365 acc_train_loss 129999.59 acc_train_error 90.00% test_error 90.60%
366 acc_train_loss 129999.60 acc_train_error 90.00% test_error 90.60%
367 acc_train_loss 129999.60 acc_train_error 90.00% test_error 90.60%
368 acc_train_loss 129999.60 acc_train_error 90.00% test_error 90.60%
369 acc_train_loss 1

473 acc_train_loss 129999.62 acc_train_error 90.00% test_error 90.60%
474 acc_train_loss 129999.62 acc_train_error 90.00% test_error 90.60%
475 acc_train_loss 129999.62 acc_train_error 90.00% test_error 90.60%
476 acc_train_loss 129999.62 acc_train_error 90.00% test_error 90.60%
477 acc_train_loss 129999.62 acc_train_error 90.00% test_error 90.60%
478 acc_train_loss 129999.62 acc_train_error 90.00% test_error 90.60%
479 acc_train_loss 129999.62 acc_train_error 90.00% test_error 90.60%
480 acc_train_loss 129999.62 acc_train_error 90.00% test_error 90.60%
481 acc_train_loss 129999.62 acc_train_error 90.00% test_error 90.60%
482 acc_train_loss 129999.62 acc_train_error 90.00% test_error 90.60%
483 acc_train_loss 129999.62 acc_train_error 90.00% test_error 90.60%
484 acc_train_loss 129999.62 acc_train_error 90.00% test_error 90.60%
485 acc_train_loss 129999.62 acc_train_error 90.00% test_error 90.60%
486 acc_train_loss 129999.62 acc_train_error 90.00% test_error 90.60%
487 acc_train_loss 1

591 acc_train_loss 129999.88 acc_train_error 90.00% test_error 90.60%
592 acc_train_loss 129999.88 acc_train_error 90.00% test_error 90.60%
593 acc_train_loss 129999.88 acc_train_error 90.00% test_error 90.60%
594 acc_train_loss 129999.88 acc_train_error 90.00% test_error 90.60%
595 acc_train_loss 129999.88 acc_train_error 90.00% test_error 90.60%
596 acc_train_loss 129999.88 acc_train_error 90.00% test_error 90.60%
597 acc_train_loss 129999.88 acc_train_error 90.00% test_error 90.60%
598 acc_train_loss 129999.88 acc_train_error 90.00% test_error 90.60%
599 acc_train_loss 129999.88 acc_train_error 90.00% test_error 90.60%
600 acc_train_loss 129999.88 acc_train_error 90.00% test_error 90.60%
601 acc_train_loss 129999.88 acc_train_error 90.00% test_error 90.60%
602 acc_train_loss 129999.88 acc_train_error 90.00% test_error 90.60%
603 acc_train_loss 129999.88 acc_train_error 90.00% test_error 90.60%
604 acc_train_loss 129999.88 acc_train_error 90.00% test_error 90.60%
605 acc_train_loss 1

709 acc_train_loss 129999.90 acc_train_error 90.00% test_error 90.60%
710 acc_train_loss 129999.90 acc_train_error 90.00% test_error 90.60%
711 acc_train_loss 129999.90 acc_train_error 90.00% test_error 90.60%
712 acc_train_loss 129999.90 acc_train_error 90.00% test_error 90.60%
713 acc_train_loss 129999.90 acc_train_error 90.00% test_error 90.60%
714 acc_train_loss 129999.90 acc_train_error 90.00% test_error 90.60%
715 acc_train_loss 129999.90 acc_train_error 90.00% test_error 90.60%
716 acc_train_loss 129999.90 acc_train_error 90.00% test_error 90.60%
717 acc_train_loss 129999.90 acc_train_error 90.00% test_error 90.60%
718 acc_train_loss 129999.90 acc_train_error 90.00% test_error 90.60%
719 acc_train_loss 129999.90 acc_train_error 90.00% test_error 90.60%
720 acc_train_loss 129999.90 acc_train_error 90.00% test_error 90.60%
721 acc_train_loss 129999.90 acc_train_error 90.00% test_error 90.60%
722 acc_train_loss 129999.90 acc_train_error 90.00% test_error 90.60%
723 acc_train_loss 1

827 acc_train_loss 129999.91 acc_train_error 90.00% test_error 90.60%
828 acc_train_loss 129999.91 acc_train_error 90.00% test_error 90.60%
829 acc_train_loss 129999.91 acc_train_error 90.00% test_error 90.60%
830 acc_train_loss 129999.91 acc_train_error 90.00% test_error 90.60%
831 acc_train_loss 129999.91 acc_train_error 90.00% test_error 90.60%
832 acc_train_loss 129999.91 acc_train_error 90.00% test_error 90.60%
833 acc_train_loss 129999.91 acc_train_error 90.00% test_error 90.60%
834 acc_train_loss 129999.91 acc_train_error 90.00% test_error 90.60%
835 acc_train_loss 129999.91 acc_train_error 90.00% test_error 90.60%
836 acc_train_loss 129999.91 acc_train_error 90.00% test_error 90.60%
837 acc_train_loss 129999.91 acc_train_error 90.00% test_error 90.60%
838 acc_train_loss 129999.91 acc_train_error 90.00% test_error 90.60%
839 acc_train_loss 129999.91 acc_train_error 90.00% test_error 90.60%
840 acc_train_loss 129999.91 acc_train_error 90.00% test_error 90.60%
841 acc_train_loss 1

945 acc_train_loss 129999.91 acc_train_error 89.90% test_error 90.60%
946 acc_train_loss 129999.91 acc_train_error 90.00% test_error 90.60%
947 acc_train_loss 129999.91 acc_train_error 90.00% test_error 90.60%
948 acc_train_loss 129999.91 acc_train_error 90.00% test_error 90.60%
949 acc_train_loss 129999.91 acc_train_error 90.00% test_error 90.60%
950 acc_train_loss 129999.91 acc_train_error 90.00% test_error 90.60%
951 acc_train_loss 129999.91 acc_train_error 90.00% test_error 90.60%
952 acc_train_loss 129999.91 acc_train_error 90.00% test_error 90.60%
953 acc_train_loss 129999.91 acc_train_error 90.00% test_error 90.60%
954 acc_train_loss 129999.91 acc_train_error 90.00% test_error 90.60%
955 acc_train_loss 129999.91 acc_train_error 90.00% test_error 90.60%
956 acc_train_loss 129999.91 acc_train_error 90.00% test_error 90.60%
957 acc_train_loss 129999.91 acc_train_error 90.00% test_error 90.60%
958 acc_train_loss 129999.91 acc_train_error 90.00% test_error 90.60%
959 acc_train_loss 1