Building a multi-layer perceptron with one hidden layer from scratch and test it on MNIST data

In [1]:
import torch

# Activation function

In [2]:
def sigma(x):
    return torch.tanh(x)

In [3]:
def dsigma(x):
    return 1 - torch.pow(sigma(x), 2)

# Loss

In [4]:
def loss(v, t):
    return torch.pow(torch.norm(v-t), 2)

In [5]:
def dloss(v, t):
    return 2*(v-t)

# Importing the data

In [6]:
import dlc_practical_prologue as prologue

train_input, train_target, test_input,test_target = prologue.load_data(one_hot_labels = True,
                                                                      normalize = True)

train_target = train_target * 0.9
test_target  = test_target * 0.9
print(train_input.shape)

nb_hidden = 50
epsilon = 1e-6
nb_classes = train_target.size(1)
n_train_samples = train_input.size(0)
#n_test_samples = test_input.size(0)
learning_rate = 0.01
step = learning_rate/n_train_samples

* Using MNIST
** Reduce the data-set (use --full for the full thing)
** Use 1000 train and 1000 test samples
torch.Size([1000, 784])




In [7]:
train_input.shape

torch.Size([1000, 784])

In [8]:
train_target.shape

torch.Size([1000, 10])

In [9]:
test_input.shape

torch.Size([1000, 784])

In [10]:
test_target.shape 

torch.Size([1000, 10])

In [11]:
w1 = torch.empty(nb_hidden, train_input.size(1)).normal_(0, epsilon)
b1 = torch.empty(nb_hidden).normal_(0, epsilon)
w2 = torch.empty(nb_classes, nb_hidden).normal_(0, epsilon)
b2 = torch.empty(nb_classes).normal_(0, epsilon)


dl_dw1 = torch.empty(w1.size())
dl_db1 = torch.empty(b1.size())
dl_dw2 = torch.empty(w2.size())
dl_db2 = torch.empty(b2.size())

In [12]:
b2.shape

torch.Size([10])

In [13]:
def forward_pass(w1, b1, w2, b2, x):
    x0 = x
    s1 = torch.mv(w1, x0) + b1
    x1 = sigma(s1)
    s2 = torch.mv(w2,x1) + b2
    x2 = sigma(s2)
    
    return x0, s1, x1, s2, x2

In [14]:
def backward_pass(w1, b1, w2, b2,
                  t,
                  x, s1, x1, s2, x2,
                 dl_dw1, dl_db1, dl_dw2, dl_db2):
    
    x0 = x
    dl_dx2 = loss(x2, t)
    dl_ds2 = dsigma(s2) * dl_dx2
    dl_dx1 = torch.mv(w2.t(), dl_ds2)
    dl_ds1 = dsigma(s1) * dl_dx1
    
    dl_dw2.add_(torch.mm(dl_ds2.view(-1,1), x1.view(1, -1)))
    dl_db2.add_(dl_ds2)
    dl_dw1.add_(torch.mm(dl_ds1.view(-1, 1), x0.view(1,-1)))
    dl_db1.add_(dl_ds1)
    

In [16]:
for k in range(1000):
    
    accumulated_loss = 0
    n_train_errors = 0
    
    """
    Performing 1,000 gradient steps with a step size equal to 0.1 divided
    by the number of training samples(variable--steps)
    First reset the tensors to zero for summing up the gradients and doing
    a forward and backward pass for each training example
    """
    dl_dw1.zero_()
    dl_db1.zero_()
    dl_dw2.zero_()
    dl_db2.zero_()
    
    for n in range(n_train_samples):
        #forward prop
        x0, s1, x1, s2, x2 = forward_pass(w1, b1, w2, b2, train_input[n])
        
        pred = x2.max(0)[1].item() #takes class with maximum prob(x2, returns prob)
        if train_target[n, pred] < 0.5:
            n_train_errors = n_train_errors + 1
        accumulated_loss = accumulated_loss + loss(x2, train_target[n])
            
    
        #backward prop
        backward_pass(w1, b1, w2, b2,
                      train_target[n],
                      x0, s1, x1, s2, x2,
                      dl_dw1, dl_db1, dl_dw2, dl_db2)
    
        #update rule
    w1 = w1 - step * dl_dw1
    b1 = b1 - step * dl_db1
    w2 = w2 - step * dl_dw2
    b2 = b2 - step * dl_db2
    
     # Test error

    n_test_errors = 0

    for n in range(test_input.size(0)):
        _, _, _, _, x2 = forward_pass(w1, b1, w2, b2, test_input[n])

        pred = x2.max(0)[1].item()
        if test_target[n, pred] < 0.5: n_test_errors = n_test_errors + 1

    print('{:d} acc_train_loss {:.02f} acc_train_error {:.02f}% test_error {:.02f}%'
          .format(k,
                  accumulated_loss,
                  (100 * n_train_errors) / train_input.size(0),
                  (100 * n_test_errors) / test_input.size(0)))

0 acc_train_loss 1475.75 acc_train_error 90.30% test_error 91.30%
1 acc_train_loss 1554.80 acc_train_error 90.70% test_error 88.80%
2 acc_train_loss 1645.54 acc_train_error 90.70% test_error 89.00%
3 acc_train_loss 1759.16 acc_train_error 90.50% test_error 89.20%
4 acc_train_loss 1933.35 acc_train_error 90.70% test_error 89.30%
5 acc_train_loss 2304.88 acc_train_error 90.70% test_error 89.30%
6 acc_train_loss 3408.90 acc_train_error 90.70% test_error 89.30%
7 acc_train_loss 7022.70 acc_train_error 90.70% test_error 90.40%
8 acc_train_loss 11533.95 acc_train_error 89.80% test_error 90.50%
9 acc_train_loss 12116.67 acc_train_error 89.90% test_error 90.20%
10 acc_train_loss 12281.89 acc_train_error 89.40% test_error 90.50%
11 acc_train_loss 12363.32 acc_train_error 89.90% test_error 90.10%
12 acc_train_loss 12412.28 acc_train_error 90.00% test_error 90.60%
13 acc_train_loss 12445.05 acc_train_error 89.50% test_error 90.20%
14 acc_train_loss 12468.58 acc_train_error 89.60% test_error 90.40

121 acc_train_loss 12601.78 acc_train_error 89.90% test_error 90.60%
122 acc_train_loss 12601.85 acc_train_error 90.00% test_error 90.80%
123 acc_train_loss 12601.92 acc_train_error 90.00% test_error 90.50%
124 acc_train_loss 12602.00 acc_train_error 90.00% test_error 90.60%
125 acc_train_loss 12602.08 acc_train_error 90.00% test_error 90.80%
126 acc_train_loss 12602.15 acc_train_error 89.90% test_error 90.60%
127 acc_train_loss 12602.22 acc_train_error 89.80% test_error 90.60%
128 acc_train_loss 12602.29 acc_train_error 90.10% test_error 90.60%
129 acc_train_loss 12602.35 acc_train_error 90.00% test_error 90.60%
130 acc_train_loss 12602.42 acc_train_error 90.00% test_error 90.70%
131 acc_train_loss 12602.47 acc_train_error 90.00% test_error 90.70%
132 acc_train_loss 12602.53 acc_train_error 90.00% test_error 90.60%
133 acc_train_loss 12602.58 acc_train_error 90.00% test_error 90.60%
134 acc_train_loss 12602.64 acc_train_error 90.10% test_error 90.60%
135 acc_train_loss 12602.69 acc_tr

240 acc_train_loss 12606.06 acc_train_error 90.00% test_error 90.60%
241 acc_train_loss 12606.09 acc_train_error 89.90% test_error 90.60%
242 acc_train_loss 12606.10 acc_train_error 90.00% test_error 90.60%
243 acc_train_loss 12606.12 acc_train_error 90.00% test_error 90.60%
244 acc_train_loss 12606.14 acc_train_error 90.00% test_error 90.70%
245 acc_train_loss 12606.17 acc_train_error 90.20% test_error 90.60%
246 acc_train_loss 12606.18 acc_train_error 90.10% test_error 90.60%
247 acc_train_loss 12606.20 acc_train_error 90.00% test_error 90.60%
248 acc_train_loss 12606.22 acc_train_error 90.00% test_error 90.60%
249 acc_train_loss 12606.23 acc_train_error 90.10% test_error 90.60%
250 acc_train_loss 12606.24 acc_train_error 90.00% test_error 90.70%
251 acc_train_loss 12606.26 acc_train_error 90.00% test_error 90.60%
252 acc_train_loss 12606.28 acc_train_error 90.00% test_error 90.60%
253 acc_train_loss 12606.29 acc_train_error 90.10% test_error 90.80%
254 acc_train_loss 12606.31 acc_tr

359 acc_train_loss 12607.44 acc_train_error 90.00% test_error 90.60%
360 acc_train_loss 12607.44 acc_train_error 90.00% test_error 90.60%
361 acc_train_loss 12607.45 acc_train_error 90.00% test_error 90.60%
362 acc_train_loss 12607.46 acc_train_error 89.80% test_error 90.60%
363 acc_train_loss 12607.46 acc_train_error 90.00% test_error 90.60%
364 acc_train_loss 12607.47 acc_train_error 89.90% test_error 90.50%
365 acc_train_loss 12607.48 acc_train_error 90.00% test_error 90.50%
366 acc_train_loss 12607.49 acc_train_error 90.00% test_error 90.60%
367 acc_train_loss 12607.50 acc_train_error 90.10% test_error 90.50%
368 acc_train_loss 12607.51 acc_train_error 89.90% test_error 90.60%
369 acc_train_loss 12607.52 acc_train_error 90.00% test_error 90.60%
370 acc_train_loss 12607.53 acc_train_error 89.90% test_error 90.60%
371 acc_train_loss 12607.53 acc_train_error 90.00% test_error 90.50%
372 acc_train_loss 12607.53 acc_train_error 89.90% test_error 90.70%
373 acc_train_loss 12607.54 acc_tr

478 acc_train_loss 12608.05 acc_train_error 90.00% test_error 90.60%
479 acc_train_loss 12608.05 acc_train_error 90.00% test_error 90.60%
480 acc_train_loss 12608.06 acc_train_error 90.00% test_error 90.60%
481 acc_train_loss 12608.06 acc_train_error 90.00% test_error 90.60%
482 acc_train_loss 12608.06 acc_train_error 90.00% test_error 90.60%
483 acc_train_loss 12608.07 acc_train_error 89.90% test_error 90.60%
484 acc_train_loss 12608.07 acc_train_error 89.90% test_error 90.50%
485 acc_train_loss 12608.08 acc_train_error 90.00% test_error 90.60%
486 acc_train_loss 12608.08 acc_train_error 90.00% test_error 90.60%
487 acc_train_loss 12608.08 acc_train_error 90.00% test_error 90.60%
488 acc_train_loss 12608.09 acc_train_error 90.00% test_error 90.60%
489 acc_train_loss 12608.09 acc_train_error 90.00% test_error 90.60%
490 acc_train_loss 12608.10 acc_train_error 90.00% test_error 90.60%
491 acc_train_loss 12608.10 acc_train_error 90.00% test_error 90.60%
492 acc_train_loss 12608.10 acc_tr

597 acc_train_loss 12608.41 acc_train_error 90.00% test_error 90.60%
598 acc_train_loss 12608.41 acc_train_error 90.00% test_error 90.60%
599 acc_train_loss 12608.42 acc_train_error 89.90% test_error 90.60%
600 acc_train_loss 12608.42 acc_train_error 90.00% test_error 90.50%
601 acc_train_loss 12608.43 acc_train_error 90.00% test_error 90.60%
602 acc_train_loss 12608.43 acc_train_error 89.90% test_error 90.60%
603 acc_train_loss 12608.43 acc_train_error 90.00% test_error 90.60%
604 acc_train_loss 12608.43 acc_train_error 90.00% test_error 90.60%


KeyboardInterrupt: 