In [1]:
import torch
import math
from torch.nn import functional as F
from torch import nn

In [66]:
def generate_disc_set(nb, center=(0.5, 0.5), radius=(2 / math.pi) ** 0.5):
    """Sample a train set and a test set for the "inside the disc" toy problem.

    Points are drawn uniformly in [-1, 1)^2. Each point gets a one-hot float
    target: column 1 is set when the point lies strictly inside the disc of
    ``radius`` around ``center``, column 0 otherwise.

    Args:
        nb: number of points in EACH of the train and test sets.
        center: (x, y) centre of the disc. The default (0.5, 0.5) reproduces
            the centre the original implementation used.
        radius: disc radius. The default sqrt(2/pi) gives a disc of area 2,
            i.e. half of the [-1, 1]^2 square, but only when the disc lies
            entirely inside the square.

    Returns:
        (train_input, train_target, test_input, test_target).
        Inputs are float32 tensors of shape (nb, 2). Targets are float32
        one-hot tensors of shape (nb, 2).
    """
    # NOTE(review): the original comment described the disc as "in the middle
    # of the points", but (0.5, 0.5) is not the middle of [-1, 1]^2. With this
    # centre part of the disc falls outside the square, so only ~37% of the
    # points are inside. That matches the ~36% error of the constant predictor
    # in the logged run. Pass center=(0, 0) for the balanced 50/50 variant.
    axis = torch.tensor(center, dtype=torch.float32).view(1, 2)

    train_input = torch.empty(nb, 2, dtype=torch.float32).uniform_(-1, 1)
    test_input = torch.empty(nb, 2, dtype=torch.float32).uniform_(-1, 1)

    def _one_hot_inside(points):
        # Vectorised Euclidean distance of every row to the centre.
        # Columns are [outside, inside].
        inside = ((points - axis).pow(2).sum(1).sqrt() < radius).float()
        return torch.stack((1 - inside, inside), dim=1)

    return (train_input, _one_hot_inside(train_input),
            test_input, _one_hot_inside(test_input))

In [80]:
# Draw 1000 training points and an independent set of 1000 test points.
(train_input, train_target,
 test_input, test_target) = generate_disc_set(nb=1000)

In [91]:
def ReLu(x):
    """Element-wise rectified linear unit: negative entries become 0, the rest pass through."""
    return torch.clamp(x, min=0)

def dReLu(x):
    """Element-wise derivative of ReLu.

    Returns 1 where x > 0 and 0 where x < 0. At exactly 0 it returns 0.5,
    because sign(0) == 0.
    """
    return torch.sign(x).add(1).div(2)

In [92]:
def loss(v, t):
    """Sum of squared differences between the output v and the target t (0-dim tensor)."""
    residual = v - t
    return torch.sum(torch.pow(residual, 2))

def dloss(v, t):
    """Gradient of the summed squared error with respect to v."""
    return (v - t).mul(2)

In [93]:
def sigma(x):
    """Element-wise hyperbolic tangent activation."""
    return torch.tanh(x)

def dsigma(x):
    """Derivative of tanh, computed as 4 / (e^x + e^-x)^2.

    This equals sech(x)^2 = 1 - tanh(x)^2. For large |x|, e^x overflows to
    inf and the result correctly goes to 0.
    """
    return torch.exp(x).add(torch.exp(-x)).pow(-2).mul(4)

In [95]:
def forward_pass(w1, b1, w2, b2, w3, b3, x):
    """Run a single input vector through three fully-connected ReLu layers.

    Each layer computes s = w.mv(h) + b and then h = ReLu(s), starting from
    h = x. Because `mv` is used, x must be a 1-D vector. The output layer is
    also passed through ReLu.

    Returns:
        The tuple (x0, s1, x1, s2, x2, s3, x3): the input, then the
        pre-activation and activation of each layer. This is the order in
        which backward_pass takes them.
    """
    recorded = [x]
    hidden = x
    for weight, bias in ((w1, b1), (w2, b2), (w3, b3)):
        pre_activation = weight.mv(hidden) + bias
        hidden = ReLu(pre_activation)
        recorded.extend((pre_activation, hidden))
    return tuple(recorded)

def backward_pass(w1, b1, w2, b2, w3, b3,
                  t,
                  x, s1, x1, s2, x2, s3, x3,
                  dl_dw1, dl_db1, dl_dw2, dl_db2, dl_dw3, dl_db3):
    """Back-propagate one sample's squared-error loss and accumulate its gradients.

    Args:
        w1, b1, w2, b2, w3, b3: network parameters. Only w2 and w3 are read,
            to carry the error signal back through the layers.
        t: target vector for this sample.
        x, s1, x1, s2, x2, s3, x3: the values returned by forward_pass.
        dl_dw1, dl_db1, dl_dw2, dl_db2, dl_dw3, dl_db3: gradient buffers.
            They are updated in place with `add_`, so calling this for
            several samples sums their gradients.

    Returns:
        None.
    """
    # Error signal w.r.t. each layer's pre-activation, from the output inwards.
    delta3 = dReLu(s3) * dloss(x3, t)
    delta2 = dReLu(s2) * w3.t().mv(delta3)
    delta1 = dReLu(s1) * w2.t().mv(delta2)

    # Weight gradient = outer product (delta column x layer-input row).
    # Bias gradient = delta itself.
    layer_terms = (
        (dl_dw3, dl_db3, delta3, x2),
        (dl_dw2, dl_db2, delta2, x1),
        (dl_dw1, dl_db1, delta1, x),
    )
    for grad_w, grad_b, delta, layer_in in layer_terms:
        grad_w.add_(delta.view(-1, 1).mm(layer_in.view(1, -1)))
        grad_b.add_(delta)

######################################################################

# Train the train_input.size(1) -> nb_hidden -> nb_hidden -> nb_classes ReLu
# network with full-batch gradient descent. Each epoch sums the per-sample
# gradients over the whole training set, takes one step, and then measures
# the test error.
nb_classes = train_target.size(1)
nb_train_samples = train_input.size(0)

# Multiplier applied to the one-hot targets (zeta = 1 leaves them unchanged).
# NOTE(review): this rebinds the global targets, so re-running this cell with
# zeta != 1 would compound the scaling. The `< 0.5` error test below also
# assumes zeta > 0.5.
zeta = 1

train_target = train_target * zeta
test_target = test_target * zeta

nb_hidden = 25      # width of both hidden layers
eta = 0.0005        # step size, applied to the SUMMED (not averaged) gradient
epsilon = 1e-1      # std of the normal initialisation of weights and biases

# NOTE(review): no random seed is set, so the initialisation differs per run.
w1 = torch.empty(nb_hidden, train_input.size(1)).normal_(0, epsilon)
b1 = torch.empty(nb_hidden).normal_(0, epsilon)

w2 = torch.empty(nb_hidden, nb_hidden).normal_(0, epsilon)
b2 = torch.empty(nb_hidden).normal_(0, epsilon)

w3 = torch.empty(nb_classes, nb_hidden).normal_(0, epsilon)
b3 = torch.empty(nb_classes).normal_(0, epsilon)

# Gradient buffers. They are allocated uninitialised here and zeroed at the
# start of every epoch.
dl_dw1 = torch.empty(w1.size())
dl_db1 = torch.empty(b1.size())

dl_dw2 = torch.empty(w2.size())
dl_db2 = torch.empty(b2.size())

dl_dw3 = torch.empty(w3.size())
dl_db3 = torch.empty(b3.size())



for k in range(1000):   # 1000 epochs

    # Forward + backward pass over every training sample. The buffers are
    # reset once per epoch, so after the inner loop they hold the gradient
    # of the loss summed over the training set.

    acc_loss = 0
    nb_train_errors = 0

    dl_dw1.zero_()
    dl_db1.zero_()
    
    dl_dw2.zero_()
    dl_db2.zero_()
    
    dl_dw3.zero_()
    dl_db3.zero_()

    for n in range(nb_train_samples):
        x0, s1, x1, s2, x2, s3, x3 = forward_pass(w1, b1, w2, b2, w3, b3, train_input[n])

        # Predicted class = index of the largest output. It counts as an
        # error when the target at that index is below 0.5 (not the true class).
        pred = x3.max(0)[1].item()
        if train_target[n, pred] < 0.5: nb_train_errors = nb_train_errors + 1
        acc_loss = acc_loss + loss(x3, train_target[n])

        backward_pass(w1, b1, w2, b2,w3,b3,
                      train_target[n],
                      x0, s1, x1, s2, x2,s3,x3,
                      dl_dw1, dl_db1, dl_dw2, dl_db2,dl_dw3,dl_db3)

    # Gradient step. This rebinds each parameter to a new tensor instead of
    # updating it in place.

    w1 = w1 - eta * dl_dw1
    b1 = b1 - eta * dl_db1
    
    w2 = w2 - eta * dl_dw2
    b2 = b2 - eta * dl_db2
    
    w3 = w3 - eta * dl_dw3
    b3 = b3 - eta * dl_db3

    # Test error, measured with the freshly updated parameters. The train
    # error above was measured with the parameters from before this step.

    nb_test_errors = 0

    for n in range(test_input.size(0)):
        _, _, _, _,_,_, x3 = forward_pass(w1, b1, w2, b2,w3,b3, test_input[n])

        pred = x3.max(0)[1].item()
        if test_target[n, pred] < 0.5: nb_test_errors = nb_test_errors + 1

    print('{:d} acc_train_loss {:.02f} acc_train_error {:.02f}% test_error {:.02f}%'
          .format(k,
                  acc_loss,
                  (100 * nb_train_errors) / train_input.size(0),
                  (100 * nb_test_errors) / test_input.size(0)))

0 acc_train_loss 510.62 acc_train_error 36.30% test_error 36.90%
1 acc_train_loss 397.53 acc_train_error 36.30% test_error 36.90%
2 acc_train_loss 303.23 acc_train_error 36.30% test_error 36.90%
3 acc_train_loss 296.94 acc_train_error 36.30% test_error 36.90%
4 acc_train_loss 294.20 acc_train_error 36.30% test_error 36.90%
5 acc_train_loss 291.32 acc_train_error 36.30% test_error 36.90%
6 acc_train_loss 287.79 acc_train_error 36.30% test_error 36.90%
7 acc_train_loss 283.82 acc_train_error 36.30% test_error 36.90%
8 acc_train_loss 279.06 acc_train_error 36.30% test_error 36.90%
9 acc_train_loss 273.24 acc_train_error 36.30% test_error 36.90%
10 acc_train_loss 266.05 acc_train_error 36.30% test_error 36.90%
11 acc_train_loss 257.12 acc_train_error 36.30% test_error 36.90%
12 acc_train_loss 246.09 acc_train_error 36.30% test_error 31.10%
13 acc_train_loss 232.58 acc_train_error 29.60% test_error 23.80%
14 acc_train_loss 216.45 acc_train_error 22.90% test_error 17.50%
15 acc_train_loss 19

128 acc_train_loss 38.75 acc_train_error 3.30% test_error 2.80%
129 acc_train_loss 38.41 acc_train_error 2.80% test_error 3.80%
130 acc_train_loss 38.06 acc_train_error 3.30% test_error 2.90%
131 acc_train_loss 37.76 acc_train_error 2.60% test_error 3.60%
132 acc_train_loss 37.41 acc_train_error 3.30% test_error 2.60%
133 acc_train_loss 37.11 acc_train_error 2.60% test_error 3.60%
134 acc_train_loss 36.75 acc_train_error 3.10% test_error 2.60%
135 acc_train_loss 36.49 acc_train_error 2.40% test_error 3.60%
136 acc_train_loss 36.17 acc_train_error 3.10% test_error 2.50%
137 acc_train_loss 35.94 acc_train_error 2.50% test_error 3.60%
138 acc_train_loss 35.63 acc_train_error 3.00% test_error 2.40%
139 acc_train_loss 35.42 acc_train_error 2.50% test_error 3.60%
140 acc_train_loss 35.12 acc_train_error 2.80% test_error 2.40%
141 acc_train_loss 34.90 acc_train_error 2.40% test_error 3.50%
142 acc_train_loss 34.66 acc_train_error 2.70% test_error 2.30%
143 acc_train_loss 34.45 acc_train_error

257 acc_train_loss 21.82 acc_train_error 1.60% test_error 2.40%
258 acc_train_loss 21.77 acc_train_error 1.70% test_error 2.20%
259 acc_train_loss 21.71 acc_train_error 1.60% test_error 2.30%
260 acc_train_loss 21.66 acc_train_error 1.70% test_error 2.20%
261 acc_train_loss 21.60 acc_train_error 1.60% test_error 2.30%
262 acc_train_loss 21.55 acc_train_error 1.70% test_error 2.20%
263 acc_train_loss 21.50 acc_train_error 1.60% test_error 2.30%
264 acc_train_loss 21.45 acc_train_error 1.80% test_error 2.20%
265 acc_train_loss 21.40 acc_train_error 1.60% test_error 2.30%
266 acc_train_loss 21.35 acc_train_error 1.80% test_error 2.20%
267 acc_train_loss 21.30 acc_train_error 1.60% test_error 2.30%
268 acc_train_loss 21.25 acc_train_error 1.70% test_error 2.20%
269 acc_train_loss 21.20 acc_train_error 1.60% test_error 2.30%
270 acc_train_loss 21.15 acc_train_error 1.70% test_error 2.20%
271 acc_train_loss 21.10 acc_train_error 1.50% test_error 2.30%
272 acc_train_loss 21.05 acc_train_error

386 acc_train_loss 17.14 acc_train_error 1.60% test_error 1.90%
387 acc_train_loss 17.12 acc_train_error 1.30% test_error 2.00%
388 acc_train_loss 17.10 acc_train_error 1.60% test_error 1.90%
389 acc_train_loss 17.07 acc_train_error 1.30% test_error 2.00%
390 acc_train_loss 17.04 acc_train_error 1.60% test_error 1.90%
391 acc_train_loss 17.02 acc_train_error 1.30% test_error 2.10%
392 acc_train_loss 16.99 acc_train_error 1.60% test_error 1.90%
393 acc_train_loss 16.96 acc_train_error 1.30% test_error 2.10%
394 acc_train_loss 16.93 acc_train_error 1.60% test_error 1.90%
395 acc_train_loss 16.90 acc_train_error 1.30% test_error 2.10%
396 acc_train_loss 16.87 acc_train_error 1.60% test_error 1.90%
397 acc_train_loss 16.84 acc_train_error 1.30% test_error 2.00%
398 acc_train_loss 16.81 acc_train_error 1.60% test_error 1.90%
399 acc_train_loss 16.78 acc_train_error 1.30% test_error 2.00%
400 acc_train_loss 16.76 acc_train_error 1.60% test_error 1.90%
401 acc_train_loss 16.73 acc_train_error

515 acc_train_loss 14.39 acc_train_error 0.90% test_error 1.60%
516 acc_train_loss 14.37 acc_train_error 1.30% test_error 1.40%
517 acc_train_loss 14.36 acc_train_error 0.90% test_error 1.60%
518 acc_train_loss 14.34 acc_train_error 1.30% test_error 1.40%
519 acc_train_loss 14.33 acc_train_error 0.90% test_error 1.60%
520 acc_train_loss 14.32 acc_train_error 1.30% test_error 1.40%
521 acc_train_loss 14.30 acc_train_error 0.90% test_error 1.60%
522 acc_train_loss 14.29 acc_train_error 1.30% test_error 1.40%
523 acc_train_loss 14.28 acc_train_error 0.90% test_error 1.60%
524 acc_train_loss 14.27 acc_train_error 1.20% test_error 1.40%
525 acc_train_loss 14.26 acc_train_error 0.90% test_error 1.60%
526 acc_train_loss 14.24 acc_train_error 1.20% test_error 1.40%
527 acc_train_loss 14.24 acc_train_error 0.90% test_error 1.60%
528 acc_train_loss 14.22 acc_train_error 1.20% test_error 1.40%
529 acc_train_loss 14.22 acc_train_error 0.90% test_error 1.60%
530 acc_train_loss 14.20 acc_train_error

644 acc_train_loss 14.08 acc_train_error 1.20% test_error 1.60%
645 acc_train_loss 14.04 acc_train_error 1.00% test_error 1.80%
646 acc_train_loss 13.99 acc_train_error 1.20% test_error 1.60%
647 acc_train_loss 13.95 acc_train_error 1.00% test_error 1.60%
648 acc_train_loss 13.90 acc_train_error 1.20% test_error 1.60%
649 acc_train_loss 13.87 acc_train_error 1.00% test_error 1.60%
650 acc_train_loss 13.84 acc_train_error 1.20% test_error 1.60%
651 acc_train_loss 13.81 acc_train_error 1.00% test_error 1.60%
652 acc_train_loss 13.78 acc_train_error 1.20% test_error 1.50%
653 acc_train_loss 13.75 acc_train_error 1.00% test_error 1.60%
654 acc_train_loss 13.73 acc_train_error 1.10% test_error 1.50%
655 acc_train_loss 13.69 acc_train_error 1.00% test_error 1.60%
656 acc_train_loss 13.67 acc_train_error 1.10% test_error 1.40%
657 acc_train_loss 13.65 acc_train_error 1.00% test_error 1.60%
658 acc_train_loss 13.63 acc_train_error 1.10% test_error 1.40%
659 acc_train_loss 13.61 acc_train_error

773 acc_train_loss 12.64 acc_train_error 1.00% test_error 1.30%
774 acc_train_loss 12.60 acc_train_error 0.90% test_error 1.10%
775 acc_train_loss 12.61 acc_train_error 1.00% test_error 1.30%
776 acc_train_loss 12.57 acc_train_error 0.90% test_error 1.10%
777 acc_train_loss 12.59 acc_train_error 1.00% test_error 1.30%
778 acc_train_loss 12.54 acc_train_error 0.90% test_error 1.10%
779 acc_train_loss 12.56 acc_train_error 0.90% test_error 1.30%
780 acc_train_loss 12.51 acc_train_error 0.90% test_error 1.10%
781 acc_train_loss 12.55 acc_train_error 1.00% test_error 1.30%
782 acc_train_loss 12.50 acc_train_error 0.90% test_error 1.10%
783 acc_train_loss 12.51 acc_train_error 0.90% test_error 1.30%
784 acc_train_loss 12.46 acc_train_error 0.90% test_error 1.10%
785 acc_train_loss 12.49 acc_train_error 0.90% test_error 1.30%
786 acc_train_loss 12.44 acc_train_error 0.80% test_error 1.10%
787 acc_train_loss 12.48 acc_train_error 0.90% test_error 1.30%
788 acc_train_loss 12.42 acc_train_error

902 acc_train_loss 11.83 acc_train_error 0.80% test_error 1.00%
903 acc_train_loss 11.88 acc_train_error 0.80% test_error 1.20%
904 acc_train_loss 11.81 acc_train_error 0.80% test_error 1.00%
905 acc_train_loss 11.84 acc_train_error 0.80% test_error 1.20%
906 acc_train_loss 11.77 acc_train_error 0.80% test_error 1.00%
907 acc_train_loss 11.80 acc_train_error 0.80% test_error 1.20%
908 acc_train_loss 11.72 acc_train_error 0.90% test_error 1.00%
909 acc_train_loss 11.77 acc_train_error 0.80% test_error 1.20%
910 acc_train_loss 11.70 acc_train_error 0.80% test_error 1.00%
911 acc_train_loss 11.76 acc_train_error 0.80% test_error 1.20%
912 acc_train_loss 11.70 acc_train_error 0.80% test_error 1.00%
913 acc_train_loss 11.73 acc_train_error 0.80% test_error 1.20%
914 acc_train_loss 11.67 acc_train_error 0.80% test_error 1.00%
915 acc_train_loss 11.75 acc_train_error 0.80% test_error 1.20%
916 acc_train_loss 11.70 acc_train_error 0.80% test_error 1.00%
917 acc_train_loss 11.77 acc_train_error

1.0001
