In [1]:
import torch
from dlc_practical_prologue import load_data

In [3]:
# Fetch the train/test split with one-hot encoded labels and normalized inputs.
X_train, y_train, X_test, y_test = load_data(one_hot_labels=True, normalize=True)

# Scale both target tensors in place by 0.9. The network's output layer goes
# through tanh (see forward_pass); presumably this keeps the targets away from
# the activation's asymptote -- TODO confirm the intent.
y_train.mul_(0.9)
y_test.mul_(0.9)

* Using MNIST
** Reduce the data-set (use --full for the full thing)
** Use 1000 train and 1000 test samples


In [4]:
# Layer widths: input and output sizes are read from the data, the hidden
# width is a fixed hyper-parameter.
feature_size = X_train.shape[1]
hidden_size = 50
output_size = y_train.shape[1]

# Standard deviation of the zero-mean Gaussian used to initialise every parameter.
epsilon = 1e-6

# First layer (input -> hidden). Biases are stored as column vectors.
w1 = torch.empty((hidden_size, feature_size)).normal_(mean=0, std=epsilon)
b1 = torch.empty((hidden_size, 1)).normal_(mean=0, std=epsilon)

# Second layer (hidden -> output).
w2 = torch.empty((output_size, hidden_size)).normal_(mean=0, std=epsilon)
b2 = torch.empty((output_size, 1)).normal_(mean=0, std=epsilon)

In [5]:
# Gradient accumulators, one per parameter, zero-filled and shaped like the
# parameter they belong to.
dl_dw1 = torch.zeros_like(w1)
dl_db1 = torch.zeros_like(b1)

dl_dw2 = torch.zeros_like(w2)
dl_db2 = torch.zeros_like(b2)

In [6]:
def sigma(x):
    """Activation function: element-wise hyperbolic tangent of ``x``."""
    return x.tanh()

def dsigma(x):
    """Element-wise derivative of ``sigma`` at ``x``: ``1 - sigma(x)**2``."""
    activation_squared = sigma(x).pow(2)
    return 1 - activation_squared

In [7]:
def loss(v,t):
    """Sum of squared differences between prediction ``v`` and target ``t``.

    Returns a 0-dimensional tensor.
    """
    residual = v - t
    return residual.pow(2).sum()

def dloss(v,t):
    """Gradient of ``loss`` with respect to ``v``: ``2 * (v - t)``, element-wise."""
    residual = v - t
    return residual.mul(2)

In [8]:
def forward_pass(w1,b1,w2,b2,x):
    """Run one sample through the two-layer network.

    Each layer is a matrix-vector product plus a (flattened) bias, followed
    by the ``sigma`` activation.

    Args:
        w1, b1: weight matrix and bias of the first layer.
        w2, b2: weight matrix and bias of the second layer.
        x: input vector (it is fed to ``torch.Tensor.mv``, so it must be 1-D).

    Returns:
        Tuple ``(x0, s1, x1, s2, x2)``: the input, then the pre-activation
        and activation of the first layer, then those of the second layer.
    """
    hidden_pre = w1.mv(x) + b1.flatten()
    hidden_act = sigma(hidden_pre)

    output_pre = w2.mv(hidden_act) + b2.flatten()
    output_act = sigma(output_pre)

    return x, hidden_pre, hidden_act, output_pre, output_act

def backward_pass(w1,b1,w2,b2,t,x,s1,x1,s2,x2,dl_dw1,dl_db1,dl_dw2,dl_db2):
    """Back-propagate the loss for one sample and accumulate the gradients.

    The gradients are *added in place* to the accumulators passed in, so the
    caller is responsible for zeroing them when a new accumulation starts.

    Args:
        w1, b1: first-layer parameters (accepted for signature symmetry with
            ``forward_pass``; not read here).
        w2, b2: second-layer parameters (only ``w2`` is read).
        t: target vector for this sample.
        x: input vector for this sample.
        s1, x1: first-layer pre-activation and activation from the forward pass.
        s2, x2: second-layer pre-activation and activation from the forward pass.
        dl_dw1, dl_db1, dl_dw2, dl_db2: gradient accumulators, updated in place.

    Returns:
        The same four accumulator tensors ``(dl_dw1, dl_db1, dl_dw2, dl_db2)``.
    """
    x0 = x

    # Chain rule, from the loss back to the first layer's pre-activation.
    dl_dx2 = dloss(x2, t)
    dl_ds2 = dsigma(s2) * dl_dx2
    dl_dx1 = w2.t().mv(dl_ds2)
    dl_ds1 = dsigma(s1) * dl_dx1

    # Weight gradients are outer products (layer delta) x (layer input).
    dl_dw2.add_(dl_ds2.view(-1, 1).mm(x1.view(1, -1)))
    dl_dw1.add_(dl_ds1.view(-1, 1).mm(x0.view(1, -1)))

    # Add straight into the bias accumulators. The previous
    # `dl_db.flatten().add_(...)` only worked when flatten() happened to return
    # a view; flatten() may return a copy, in which case the update would be
    # silently discarded.
    dl_db2.add_(dl_ds2.reshape(dl_db2.shape))
    dl_db1.add_(dl_ds1.reshape(dl_db1.shape))

    return dl_dw1, dl_db1, dl_dw2, dl_db2
    

In [10]:
# Full-batch gradient descent: gradients are summed over every training
# sample before a single parameter update, so the step size is divided by
# the number of samples. It does not change between epochs.
learning_rate = 0.1/X_train.size(0)

for k in range(100):

    # Train
    acc_loss = 0
    nb_train_errors = 0

    # Start each epoch from empty gradient accumulators.
    dl_dw1.zero_()
    dl_db1.zero_()
    dl_dw2.zero_()
    dl_db2.zero_()

    for n in range(X_train.size(0)):

        x0, s1, x1, s2, x2 = forward_pass(w1, b1, w2, b2, X_train[n])

        # Predicted class = index of the largest output. The targets were
        # scaled by 0.9 after loading, hence the 0.5 threshold to decide
        # whether the predicted index is the labelled class.
        pred_train = x2.max(0)[1].item()
        if y_train[n, pred_train] < 0.5: nb_train_errors = nb_train_errors + 1

        # x2 is 1-D, so no transpose is needed (`.T` on a non-2-D tensor is
        # deprecated and was a no-op here).
        acc_loss = acc_loss + loss(x2, y_train[n])

        dl_dw1, dl_db1, dl_dw2, dl_db2 = backward_pass(w1, b1, w2, b2,
                     y_train[n],
                     x0, s1, x1, s2, x2,
                     dl_dw1, dl_db1, dl_dw2, dl_db2)

    # One gradient step per epoch, using the accumulated gradients.
    w1 = w1 - learning_rate * dl_dw1
    b1 = b1 - learning_rate * dl_db1
    w2 = w2 - learning_rate * dl_dw2
    b2 = b2 - learning_rate * dl_db2

    # Test error, measured with the freshly updated parameters.
    nb_test_errors = 0

    for n in range(X_test.size(0)):

        _, _, _, _, x2 = forward_pass(w1, b1, w2, b2, X_test[n])
        pred = x2.max(0)[1].item()
        if y_test[n, pred] < 0.5: nb_test_errors = nb_test_errors + 1

    print('{:d} acc_train_loss {:.02f}'.format(k, acc_loss))

    # Fuller report that also shows the train/test error rates computed
    # above; swap it in for the print above when those numbers are wanted.
#     print('{:d} acc_train_loss {:.02f} acc_train_error {:.02f}% test_error {:.02f}%'
#          .format(k, acc_loss,
#                  (100 * nb_train_errors) / X_train.size(0),
#                  (100 * nb_test_errors) / X_test.size(0)))

0 acc_train_loss 722.31
1 acc_train_loss 718.15
2 acc_train_loss 711.64
3 acc_train_loss 702.24
4 acc_train_loss 690.00
5 acc_train_loss 675.95
6 acc_train_loss 661.58
7 acc_train_loss 647.80
8 acc_train_loss 634.84
9 acc_train_loss 623.36
10 acc_train_loss 614.29
11 acc_train_loss 607.62
12 acc_train_loss 602.23
13 acc_train_loss 596.84
14 acc_train_loss 590.66
15 acc_train_loss 583.46
16 acc_train_loss 575.50
17 acc_train_loss 567.42
18 acc_train_loss 560.01
19 acc_train_loss 553.79
20 acc_train_loss 548.79
21 acc_train_loss 544.62
22 acc_train_loss 540.80
23 acc_train_loss 536.91
24 acc_train_loss 532.70
25 acc_train_loss 528.03
26 acc_train_loss 522.84
27 acc_train_loss 517.17
28 acc_train_loss 511.12
29 acc_train_loss 504.82
30 acc_train_loss 498.40
31 acc_train_loss 491.95
32 acc_train_loss 485.50
33 acc_train_loss 479.02
34 acc_train_loss 472.42
35 acc_train_loss 465.65
36 acc_train_loss 458.66
37 acc_train_loss 451.50
38 acc_train_loss 444.24
39 acc_train_loss 436.99
40 acc_tra