In [1]:
import torch
import math
from torch.nn import functional as F
from torch import nn

In [2]:
def generate_disc_set(nb, center=(0.5, 0.5), radius=None):
    """Sample a train and a test set of 2-D points labelled by disc membership.

    Inputs are drawn uniformly from the square [-1, 1] x [-1, 1]. A point is
    labelled "inside" when its Euclidean distance to ``center`` is strictly
    smaller than ``radius``.

    Args:
        nb: number of samples in each of the two sets.
        center: (x, y) coordinates of the disc centre. Defaults to (0.5, 0.5),
            the value that was previously hard-coded.
        radius: disc radius. Defaults to sqrt(2 / pi), i.e. a disc of area 2,
            which is half the area of the sampling square.

    Returns:
        ``(train_input, train_target, test_input, test_target)``. Inputs are
        float32 tensors of shape (nb, 2); targets are one-hot float32 tensors
        of shape (nb, 2) holding ``[0, 1]`` for points inside the disc and
        ``[1, 0]`` for points outside.

    NOTE(review): with the default centre the disc is not centred on the
    sampling square, so part of it lies outside [-1, 1]^2 and the two classes
    are not balanced. Pass ``center=(0, 0)`` if a centred disc was intended.
    """
    if radius is None:
        radius = (2 / math.pi) ** 0.5
    # Shape (1, 2) so it broadcasts against a whole (nb, 2) batch of points.
    center = torch.as_tensor(center, dtype=torch.float32).view(1, 2)

    def one_hot_labels(points):
        # Distance of every point to the centre, computed for the full batch.
        dist = (points - center).pow(2).sum(1).pow(0.5)
        inside = (dist < radius).float()
        # Column 0 flags "outside", column 1 flags "inside".
        return torch.stack((1 - inside, inside), dim=1)

    train_input = torch.empty(nb, 2, dtype=torch.float32).uniform_(-1, 1)
    test_input = torch.empty(nb, 2, dtype=torch.float32).uniform_(-1, 1)

    return (train_input, one_hot_labels(train_input),
            test_input, one_hot_labels(test_input))

In [3]:
# Draw 1000 training samples and 1000 test samples with their one-hot targets.
# NOTE(review): no random seed is set in the visible cells, so the sampled
# data (and everything trained on it) differ from one run to the next.
train_input, train_target, test_input, test_target = generate_disc_set(1000)

In [4]:
def ReLu(x):
    """Rectified linear unit: every entry of ``x`` below 0 is replaced by 0."""
    return torch.clamp(x, min=0)

def dReLu(x):
    """Element-wise derivative of ``ReLu``.

    Maps negative entries to 0 and positive entries to 1. An entry that is
    exactly 0 yields 0.5, because ``sign(0)`` is 0.
    """
    return x.sign().add(1).div(2)

In [5]:
def loss(v, t):
    """Sum of squared differences between prediction ``v`` and target ``t``."""
    return torch.sum(torch.pow(v - t, 2))

def dloss(v, t):
    """Gradient of ``loss(v, t)`` with respect to ``v``: twice the residual."""
    return torch.mul(v - t, 2)

In [6]:
def sigma(x):
    """Hyperbolic-tangent activation, applied element-wise."""
    return torch.tanh(x)

def dsigma(x):
    """Element-wise derivative of ``tanh``, written as 4 / (e^x + e^-x)^2."""
    exp_sum = torch.exp(x) + torch.exp(-x)
    return 4 * exp_sum.pow(-2)

In [8]:
def forward_pass(w1, b1, w2, b2, w3, b3, x):
    """Run one sample through the three-layer ReLu network.

    Each layer computes ``s = w.mv(input) + b`` followed by ``ReLu(s)``; the
    output layer is passed through ``ReLu`` as well.

    Args:
        w1, b1, w2, b2, w3, b3: weight matrix and bias vector of each layer.
        x: input vector of a single sample.

    Returns:
        The 7-tuple ``(x0, s1, x1, s2, x2, s3, x3)`` where ``x0`` is the input,
        ``s_i`` the pre-activation and ``x_i`` the activation of layer ``i``.
    """
    trace = [x]
    for weight, bias in ((w1, b1), (w2, b2), (w3, b3)):
        pre_activation = weight.mv(trace[-1]) + bias
        trace.append(pre_activation)
        trace.append(ReLu(pre_activation))
    return tuple(trace)

def backward_pass(w1, b1, w2, b2, w3, b3,
                  t,
                  x, s1, x1, s2, x2, s3, x3,
                  dl_dw1, dl_db1, dl_dw2, dl_db2, dl_dw3, dl_db3):
    """Back-propagate the loss of one sample and accumulate its gradients.

    Args:
        w1, b1, w2, b2, w3, b3: layer parameters; only ``w2`` and ``w3`` are
            read here, the others are accepted but unused.
        t: target vector of the sample.
        x, s1, x1, s2, x2, s3, x3: values returned by ``forward_pass`` for the
            same sample.
        dl_dw1 ... dl_db3: gradient accumulators, updated in place (the new
            contribution is added, nothing is reset).

    Returns:
        None.
    """
    # Gradients with respect to the pre-activations, from the output backwards.
    grad_s3 = dloss(x3, t) * dReLu(s3)
    grad_s2 = w3.t().mv(grad_s3) * dReLu(s2)
    grad_s1 = w2.t().mv(grad_s2) * dReLu(s1)

    # Weight gradient = outer product of the layer's pre-activation gradient
    # with the layer's input; bias gradient = the pre-activation gradient.
    per_layer = (
        (grad_s3, x2, dl_dw3, dl_db3),
        (grad_s2, x1, dl_dw2, dl_db2),
        (grad_s1, x, dl_dw1, dl_db1),
    )
    for grad_s, layer_input, weight_acc, bias_acc in per_layer:
        weight_acc.add_(grad_s.view(-1, 1).mm(layer_input.view(1, -1)))
        bias_acc.add_(grad_s)

######################################################################

# Problem dimensions, read from the generated data.
nb_classes = train_target.size(1)
nb_train_samples = train_input.size(0)

# Scale factor applied to the one-hot targets (1 leaves them unchanged).
zeta = 1

train_target = train_target * zeta
test_target = test_target * zeta

nb_hidden = 25    # width of both hidden layers
eta = 0.0005      # step size of the parameter update
epsilon = 1e-1    # standard deviation of the normal weight/bias initialisation

# Momentum coefficient of the update rule
momentum = 0.9


# Weights and biases, drawn from N(0, epsilon^2)
w1 = torch.empty(nb_hidden, train_input.size(1)).normal_(0, epsilon)
b1 = torch.empty(nb_hidden).normal_(0, epsilon)

w2 = torch.empty(nb_hidden, nb_hidden).normal_(0, epsilon)
b2 = torch.empty(nb_hidden).normal_(0, epsilon)

w3 = torch.empty(nb_classes, nb_hidden).normal_(0, epsilon)
b3 = torch.empty(nb_classes).normal_(0, epsilon)


# Gradient accumulators, one per parameter. Zero-initialised so they never
# expose uninitialised memory, even if read before the per-epoch reset.
dl_dw1 = torch.zeros_like(w1)
dl_db1 = torch.zeros_like(b1)

dl_dw2 = torch.zeros_like(w2)
dl_db2 = torch.zeros_like(b2)

dl_dw3 = torch.zeros_like(w3)
dl_db3 = torch.zeros_like(b3)


# Momentum velocities, one per parameter. They MUST start at zero: the update
# reads the previous velocity on the very first step, and torch.empty would
# hand back arbitrary memory contents (possibly huge values or NaN).
velocity_w1 = torch.zeros_like(w1)
velocity_b1 = torch.zeros_like(b1)

velocity_w2 = torch.zeros_like(w2)
velocity_b2 = torch.zeros_like(b2)

velocity_w3 = torch.zeros_like(w3)
velocity_b3 = torch.zeros_like(b3)







def _momentum_step(param, velocity, grad):
    """Return the updated ``(param, velocity)`` pair for one momentum update."""
    velocity = momentum * velocity - eta * grad
    return param + velocity, velocity


# Full-batch training: every epoch sums the gradients over the whole training
# set, applies a single momentum update, then measures the test error.
for k in range(1000):

    # Reset the per-epoch statistics and the gradient accumulators.
    acc_loss = 0
    nb_train_errors = 0

    for grad_buffer in (dl_dw1, dl_db1, dl_dw2, dl_db2, dl_dw3, dl_db3):
        grad_buffer.zero_()

    # Forward + backward pass, one training sample at a time.
    for n in range(nb_train_samples):
        x0, s1, x1, s2, x2, s3, x3 = forward_pass(w1, b1, w2, b2, w3, b3, train_input[n])

        # Predicted class = index of the largest output; true class is 1 when
        # the first target component is 0, else 0.
        pred = x3.max(0)[1].item()
        target = 1 if train_target[n][0] == 0 else 0
        if target != pred:
            nb_train_errors += 1
        acc_loss = acc_loss + loss(x3, train_target[n])

        backward_pass(w1, b1, w2, b2, w3, b3,
                      train_target[n],
                      x0, s1, x1, s2, x2, s3, x3,
                      dl_dw1, dl_db1, dl_dw2, dl_db2, dl_dw3, dl_db3)

    # One momentum update per parameter, using the summed gradients.
    # (A plain gradient step would be: param = param - eta * gradient.)
    w1, velocity_w1 = _momentum_step(w1, velocity_w1, dl_dw1)
    b1, velocity_b1 = _momentum_step(b1, velocity_b1, dl_db1)

    w2, velocity_w2 = _momentum_step(w2, velocity_w2, dl_dw2)
    b2, velocity_b2 = _momentum_step(b2, velocity_b2, dl_db2)

    w3, velocity_w3 = _momentum_step(w3, velocity_w3, dl_dw3)
    b3, velocity_b3 = _momentum_step(b3, velocity_b3, dl_db3)

    # Test error with the freshly updated parameters.
    nb_test_errors = 0

    for n in range(test_input.size(0)):
        x3 = forward_pass(w1, b1, w2, b2, w3, b3, test_input[n])[-1]

        pred = x3.max(0)[1].item()
        target = 1 if test_target[n][0] == 0 else 0
        if target != pred:
            nb_test_errors += 1

    train_error_pct = (100 * nb_train_errors) / train_input.size(0)
    test_error_pct = (100 * nb_test_errors) / test_input.size(0)
    print('{:d} acc_train_loss {:.02f} acc_train_error {:.02f}% test_error {:.02f}%'
          .format(k, acc_loss, train_error_pct, test_error_pct))

0 acc_train_loss 1000.00 acc_train_error 35.40% test_error 38.40%
1 acc_train_loss 1000.00 acc_train_error 35.40% test_error 38.40%
2 acc_train_loss 957.87 acc_train_error 35.40% test_error 38.40%
3 acc_train_loss 1125.64 acc_train_error 35.40% test_error 61.60%
4 acc_train_loss 756.31 acc_train_error 64.60% test_error 38.40%
5 acc_train_loss 456.36 acc_train_error 35.40% test_error 38.40%
6 acc_train_loss 522.21 acc_train_error 35.40% test_error 38.40%
7 acc_train_loss 509.26 acc_train_error 35.40% test_error 38.40%
8 acc_train_loss 460.60 acc_train_error 35.40% test_error 38.40%
9 acc_train_loss 483.66 acc_train_error 35.40% test_error 20.70%
10 acc_train_loss 403.18 acc_train_error 17.30% test_error 38.40%
11 acc_train_loss 344.19 acc_train_error 35.40% test_error 30.80%
12 acc_train_loss 386.05 acc_train_error 33.00% test_error 13.60%
13 acc_train_loss 326.50 acc_train_error 13.90% test_error 9.70%
14 acc_train_loss 331.74 acc_train_error 10.00% test_error 12.00%
15 acc_train_loss 

128 acc_train_loss 648.73 acc_train_error 2.10% test_error 2.00%
129 acc_train_loss 648.71 acc_train_error 2.10% test_error 2.00%
130 acc_train_loss 648.69 acc_train_error 2.10% test_error 2.00%
131 acc_train_loss 648.67 acc_train_error 2.10% test_error 2.00%
132 acc_train_loss 648.65 acc_train_error 2.10% test_error 2.00%
133 acc_train_loss 648.63 acc_train_error 2.10% test_error 2.00%
134 acc_train_loss 648.61 acc_train_error 2.10% test_error 2.00%
135 acc_train_loss 648.59 acc_train_error 2.10% test_error 2.00%
136 acc_train_loss 648.58 acc_train_error 2.00% test_error 2.00%
137 acc_train_loss 648.56 acc_train_error 2.00% test_error 2.00%
138 acc_train_loss 648.54 acc_train_error 2.00% test_error 2.00%
139 acc_train_loss 648.52 acc_train_error 2.00% test_error 2.00%
140 acc_train_loss 648.50 acc_train_error 2.00% test_error 2.00%
141 acc_train_loss 648.48 acc_train_error 2.00% test_error 2.00%
142 acc_train_loss 648.46 acc_train_error 2.00% test_error 2.00%
143 acc_train_loss 648.44

255 acc_train_loss 647.09 acc_train_error 1.50% test_error 1.40%
256 acc_train_loss 647.09 acc_train_error 1.50% test_error 1.40%
257 acc_train_loss 647.08 acc_train_error 1.50% test_error 1.40%
258 acc_train_loss 647.07 acc_train_error 1.50% test_error 1.40%
259 acc_train_loss 647.06 acc_train_error 1.50% test_error 1.40%
260 acc_train_loss 647.06 acc_train_error 1.50% test_error 1.40%
261 acc_train_loss 647.05 acc_train_error 1.50% test_error 1.30%
262 acc_train_loss 647.04 acc_train_error 1.50% test_error 1.40%
263 acc_train_loss 647.04 acc_train_error 1.50% test_error 1.30%
264 acc_train_loss 647.03 acc_train_error 1.50% test_error 1.40%
265 acc_train_loss 647.02 acc_train_error 1.50% test_error 1.40%
266 acc_train_loss 647.02 acc_train_error 1.50% test_error 1.40%
267 acc_train_loss 647.01 acc_train_error 1.50% test_error 1.40%
268 acc_train_loss 647.01 acc_train_error 1.50% test_error 1.30%
269 acc_train_loss 647.00 acc_train_error 1.50% test_error 1.30%
270 acc_train_loss 646.99

KeyboardInterrupt: 