In [18]:
import torch
from torch import tensor, nn
from loading_datas import  generate_pair_sets
import torch.nn.functional as F



In [19]:
train_pairs, train_target, train_classes, test_pairs, test_target, test_classes = generate_pair_sets(1000)

In [20]:
class Net1(nn.Module):
  def __init__(self):
    super(Net1,self).__init__()

    self.conv11 = nn.Conv2d(1,16,3)
    self.conv12 = nn.Conv2d(16,32,3)

    self.conv21 = nn.Conv2d(1,16,3)
    self.conv22 = nn.Conv2d(16,32,3)

    self.pool = nn.MaxPool2d(kernel_size=(2,2),stride=2)

    self.fc1 = nn.Linear(64*4*4,64)
    self.fc2 = nn.Linear(64,32)
    self.fc3 = nn.Linear(32,2)

  def forward(self,x): 
    # spliting the channels
    c1 = torch.narrow(x,1,0,1)
    c2 = torch.narrow(x,1,1,1)

    # Channel 1
    c1 = F.relu(self.conv11(c1))
    c1 = self.pool(c1)
    c1 = F.relu(self.conv12(c1))
    

    # Channel 2
    c2 = F.relu(self.conv21(c2))
    c2 = self.pool(c2)
    c2 = F.relu(self.conv22(c2))
    


    output = torch.cat((c1,c2),1)
    output = output.view(-1,64*4*4)
   
    output = F.relu(self.fc1(output))
    output = F.relu(self.fc2(output))
    output = self.fc3(output)

    return output

class Net2(nn.Module):
  def __init__(self):
    super(Net2,self).__init__()
    
    self.conv1 = nn.Conv2d(1,16,3)
    self.conv2 = nn.Conv2d(16,32,3)
    self.pool = nn.MaxPool2d(kernel_size=(2,2), stride=2)

    self.fc1 = nn.Linear(64*4*4,64)
    self.fc2 = nn.Linear(64,32)
    self.fc3 = nn.Linear(32,2)

  def forward(self,x): 
    # spliting the channels
    c1 = torch.narrow(x,1,0,1)
    c2 = torch.narrow(x,1,1,1)

    # Channel 1
    c1 = F.relu(self.conv1(c1))
    c1 = self.pool(c1)
    c1 = F.relu(self.conv2(c1))

    # Channel 2
    c2 = F.relu(self.conv1(c2))
    c2 = self.pool(c2)
    c2 = F.relu(self.conv2(c2))


    output = torch.cat((c1,c2),1)
    output = output.view(-1,64*4*4)
   
    output = F.relu(self.fc1(output))
    output = F.relu(self.fc2(output))
    output = self.fc3(output)

    return output


In [21]:
def compute_nb_errors(model, data_input, data_target,batch_size):

    nb_data_errors = 0

    for inputs, targets in zip(data_input.split(batch_size), data_target.split(batch_size)):
        output = model(inputs)
        for k in range(len(targets)):
            
            if torch.argmax(output[k]) != torch.argmax(targets[k]):
                nb_data_errors = nb_data_errors + 1
                

    return nb_data_errors

In [22]:
lr, nb_epochs, batch_size = 1e-2, 100, 1000
model = Net2()
optimizer = torch.optim.SGD(model.parameters(), lr = lr)
criterion = nn.CrossEntropyLoss()
for e in range(nb_epochs):
    for input, targets in zip(train_pairs.split(batch_size), train_target.split(batch_size)):
        output = model(input)
        loss = criterion(output, targets)
        # print(loss)
        # print(loss)
        optimizer.zero_grad()
        loss.backward()

    nb_error = compute_nb_errors(model, train_pairs, train_target, batch_size)
    # print(nb_error)

    train_errors = 100 * (1 - compute_nb_errors(model, train_pairs, train_target,batch_size)/train_pairs.size(0))
    test_errors = 100 * (1 - compute_nb_errors(model, test_pairs, test_target,batch_size)/test_pairs.size(0))
    print(f"Epoch # {e+1} / Train accuracy (%): {train_errors:.2f} / Test accuracy (%): {test_errors:.2f}")

        
    optimizer.step()

525
Epoch # 1 / Train accuracy (%): 47.50 / Test accuracy (%): 47.90
502
Epoch # 2 / Train accuracy (%): 49.80 / Test accuracy (%): 49.80
498
Epoch # 3 / Train accuracy (%): 50.20 / Test accuracy (%): 50.40
500
Epoch # 4 / Train accuracy (%): 50.00 / Test accuracy (%): 50.80
490
Epoch # 5 / Train accuracy (%): 51.00 / Test accuracy (%): 53.00
465
Epoch # 6 / Train accuracy (%): 53.50 / Test accuracy (%): 55.40
443
Epoch # 7 / Train accuracy (%): 55.70 / Test accuracy (%): 57.30
426
Epoch # 8 / Train accuracy (%): 57.40 / Test accuracy (%): 60.60
410
Epoch # 9 / Train accuracy (%): 59.00 / Test accuracy (%): 61.50
392
Epoch # 10 / Train accuracy (%): 60.80 / Test accuracy (%): 63.50
365
Epoch # 11 / Train accuracy (%): 63.50 / Test accuracy (%): 65.20
344
Epoch # 12 / Train accuracy (%): 65.60 / Test accuracy (%): 66.10
330
Epoch # 13 / Train accuracy (%): 67.00 / Test accuracy (%): 67.50
317
Epoch # 14 / Train accuracy (%): 68.30 / Test accuracy (%): 69.10
295
Epoch # 15 / Train accura

In [23]:
output = model(test_pairs)


In [24]:
error = 0

for i in range(1000):
    if torch.argmax(output[i]) != torch.argmax(test_target[i]):
        error += 1

print(error)

238
