In [2]:
from helper_functions import *
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F

In [3]:
N = 1000
train_input, train_target, train_classes, test_input, test_target, test_classes = generate_pair_sets(N)

#### Data preprocessing

In [4]:
train_input, test_input, train_classes, test_classes = split_img_data(train_input, test_input, train_classes, test_classes)

train_input1 = train_input[0]
train_input2 = train_input[1]

test_input1 = test_input[0]
test_input2 = test_input[1]

train_classes1 = train_classes[0]
train_classes2 = train_classes[1]

test_classes1 = test_classes[0]
test_classes2 = test_classes[1]

train_input1 = 0.9*train_input1
train_input2 = 0.9*train_input2

test_input1 = 0.9*test_input1
test_input2 = 0.9*test_input2

train_input1, test_input1 = normalize(train_input1, test_input1)
train_input2, test_input2 = normalize(train_input2, test_input2)

train_input1_reshape = torch.unsqueeze(train_input1, 1)
test_input1_reshape = torch.unsqueeze(test_input1, 1)

train_input2_reshape = torch.unsqueeze(train_input2, 1)
test_input2_reshape = torch.unsqueeze(test_input2, 1)

### Convolutional Neuronal Network

#### Predict digit classes ∈ {0, . . . , 9} and compare¶

#### Define model

In [5]:
class Net2(nn.Module):
    def __init__(self):
        super(Net2, self).__init__()
        nb_hidden = 200
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv1_bn = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=2)
        self.conv2_bn = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=2)
        self.conv3_bn = nn.BatchNorm2d(64)
        self.drop1 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(64, nb_hidden)
        self.drop2 = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(nb_hidden, 10)
        

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1_bn(self.conv1(x)),kernel_size=2))
        x = F.relu(F.max_pool2d(self.conv2_bn(self.conv2(x)),kernel_size=2))
        x = F.relu(self.conv3_bn(self.conv3(x)))
        x = self.drop1(x)
        x = F.relu(self.fc1(x.view(-1, 64)))
        x = self.drop2(x)
        x = self.fc2(x)
        return x
    

def train_model(model, train_input, train_target, mini_batch_size):
    eta = 1e-1
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=eta)

    for e in range(25):
        sum_loss = 0
        for b in range(0, train_input.size(0), mini_batch_size):
            output = model(train_input.narrow(0, b, mini_batch_size))
            loss = criterion(output, train_target.narrow(0, b, mini_batch_size))
            model.zero_grad()
            loss.backward()
            sum_loss = sum_loss + loss.item()
            optimizer.step()
        #print(e, sum_loss)
        

def compare_and_predict(output1, output2):
    predict = []
    
    for (a,b) in zip(output1, output2):
        if a <= b:
            predict.append(1)
        else:
            predict.append(0)
    return predict
       
    
def compute_error_(predicted, test):
    error = 0
    
    for (a,b) in zip(predicted, test):
        if a != b:
            error+=1
    return 100*(error/len(predicted))

#### Train and test model

In [6]:
mini_batch_size = 100
model = Net2()

for k in range(15):
    train_model(model.train(), train_input1_reshape, train_classes1, mini_batch_size)
    
    _, output1 = model(test_input1_reshape).max(1)
    _, output2 = model(test_input2_reshape).max(1)
    print('(It. # {}) Final digit comparison error on test set: {:.2f}%'.format(k, 
                                                        compute_error_(compare_and_predict(output1, output2), test_target)))

(It. # 0) Final digit comparison error on test set: 10.20%
(It. # 1) Final digit comparison error on test set: 6.70%
(It. # 2) Final digit comparison error on test set: 6.90%
(It. # 3) Final digit comparison error on test set: 6.70%
(It. # 4) Final digit comparison error on test set: 5.90%
(It. # 5) Final digit comparison error on test set: 5.60%
(It. # 6) Final digit comparison error on test set: 5.60%
(It. # 7) Final digit comparison error on test set: 5.80%
(It. # 8) Final digit comparison error on test set: 5.40%
(It. # 9) Final digit comparison error on test set: 5.30%
(It. # 10) Final digit comparison error on test set: 5.50%
(It. # 11) Final digit comparison error on test set: 5.40%
(It. # 12) Final digit comparison error on test set: 5.70%
(It. # 13) Final digit comparison error on test set: 5.80%
(It. # 14) Final digit comparison error on test set: 5.50%
