In [65]:
import dlc_practical_prologue as prologue
import torch
from torch import nn
from torch.nn import functional as F

N_PAIRS = 1000          # number of pairs requested from prologue.generate_pair_sets
MINI_BATCH_SIZE = 100
EPOCHS = 25
LEFT = 0                # index of the left image / left digit class in a pair
RIGHT = 1               # index of the right image / right digit class in a pair

# Seed torch's RNG once so Restart & Run All is reproducible. Rounds in the
# main cell still differ from each other because weight initialisation keeps
# advancing the RNG.
SEED = 0
torch.manual_seed(SEED)

# TODO: add comments
# TODO: normalize?
# TODO: perf?
# Averaging over multiple runs is done by the `rounds` loop in the main cell.

def train_model(model, train_input, train_target, train_classes,
                epochs=None, batch_size=None, aux_weight=0.4):
    """Train `model` in place with Adam and cross-entropy.

    If `model.binary` is true, the model output is trained against
    `train_target`. Otherwise the model returns one logit tensor per image and
    each is trained against the matching column of `train_classes`. If
    `model.aux` is true, the outputs of `model.get_subnetwork_output()` add an
    auxiliary digit-classification loss per side, weighted by `aux_weight`.

    Args:
        model: module exposing `binary` and `aux` attributes (and
            `get_subnetwork_output()` when `aux` is true).
        train_input: inputs, batched along dim 0.
        train_target: class-index targets used when `model.binary` is true.
        train_classes: digit classes, indexed as [:, LEFT] / [:, RIGHT].
        epochs: number of passes over the data (defaults to EPOCHS).
        batch_size: mini-batch size (defaults to MINI_BATCH_SIZE). The last
            batch is smaller when N is not a multiple of it.
        aux_weight: weight of each auxiliary loss term.

    Prints, per epoch, the epoch index and the summed main (non-auxiliary) loss.
    """
    if epochs is None:
        epochs = EPOCHS
    if batch_size is None:
        batch_size = MINI_BATCH_SIZE

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters())
    n_samples = train_input.size(0)

    for e in range(epochs):
        sum_loss = 0
        for b in range(0, n_samples, batch_size):
            # Clamp the final batch so `narrow` never reads past the end.
            size = min(batch_size, n_samples - b)
            batch_input = train_input.narrow(0, b, size)
            batch_classes = train_classes.narrow(0, b, size)

            optimizer.zero_grad()
            output = model(batch_input)

            if model.binary:
                loss = criterion(output, train_target.narrow(0, b, size))
            else:
                # Non-binary models return (left_logits, right_logits).
                loss = criterion(output[LEFT], batch_classes[:, LEFT])
                loss = loss + criterion(output[RIGHT], batch_classes[:, RIGHT])

            # The logged loss deliberately excludes the auxiliary terms.
            sum_loss += loss.item()

            if model.aux:
                left, right = model.get_subnetwork_output()
                loss = loss + aux_weight * criterion(left, batch_classes[:, LEFT])
                loss = loss + aux_weight * criterion(right, batch_classes[:, RIGHT])

            loss.backward()
            optimizer.step()

        print(e, sum_loss)

def compute_nb_errors(model, _input, target, batch_size=None):
    """Count samples whose arg-max prediction differs from `target`.

    `model(batch)` must return a single logit tensor of shape
    (batch, n_classes); tuple-returning models are not supported.

    Args:
        model: callable mapping a batch of `_input` to logits.
        _input: inputs, batched along dim 0.
        target: tensor of class indices, one per sample of `_input`.
        batch_size: evaluation batch size (defaults to MINI_BATCH_SIZE); the
            last batch is smaller when N is not a multiple of it.

    Returns:
        Number of misclassified samples, as an int.
    """
    if batch_size is None:
        batch_size = MINI_BATCH_SIZE

    n_samples = _input.size(0)
    nb_errors = 0
    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for b in range(0, n_samples, batch_size):
            size = min(batch_size, n_samples - b)
            predicted = model(_input.narrow(0, b, size)).argmax(dim=1)
            nb_errors += (predicted != target[b:b + size]).sum().item()
    return nb_errors

In [76]:
# Scratch arithmetic, not referenced elsewhere: a 3x3 conv turns a 14x14 image
# into 12x12, so 32 such feature maps hold 32 * 12 * 12 values. The layers in
# DigitSubnetwork use the pooled/convolved sizes instead (6*6*32 and 4*4*64).
32 * 12 * 12

4608

In [75]:
# Inspect the input layout (N_PAIRS pairs, two 14x14 images per pair).
# A sample is generated here so this cell does not depend on `train_input`,
# which is only created by the main cell at the bottom and would therefore be
# undefined on Restart & Run All.
sample_input, *_ = prologue.generate_pair_sets(N_PAIRS)
sample_input.shape

torch.Size([1000, 2, 14, 14])

In [77]:
class DigitSubnetwork(nn.Module):
    """Convolutional classifier producing 10 logits for a single-channel image.

    Architecture (hard-coded sizes assume a 14x14 input):
    conv1 3x3 (32 ch) -> 2x2 max-pool -> ReLU -> conv2 3x3 (64 ch) -> ReLU
    -> fc1 (nb_hidden) -> ReLU -> fc2 (10).

    With ``aux=True`` an extra linear head ``fcaux`` classifies directly from
    the pooled conv1 features; its logits are returned as the second output.
    """

    def __init__(self, aux=False, nb_hidden=200):
        super(DigitSubnetwork, self).__init__()
        self.aux = aux
        # Most recent auxiliary logits (None when aux is disabled). Kept for
        # callers that read the attribute; forward() also returns them.
        self.aux_out = None
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        # 14x14 -> conv1 -> 12x12 -> pool -> 6x6 -> conv2 -> 4x4, 64 channels.
        self.fc1 = nn.Linear(4 * 4 * 64, nb_hidden)
        self.fc2 = nn.Linear(nb_hidden, 10)

        if self.aux:
            # Auxiliary head on the pooled conv1 map: 32 channels x 6x6.
            self.fcaux = nn.Linear(6 * 6 * 32, 10)

    def forward(self, x):
        """Return ``(logits, aux_logits)``.

        Args:
            x: batch of single-channel images, expected (B, 1, 14, 14); other
               spatial sizes do not match the hard-coded fc input sizes.

        Returns:
            Logits of shape (B, 10), plus auxiliary logits of shape (B, 10)
            when ``aux`` is enabled, otherwise ``None``.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=2))

        aux_out = None
        if self.aux:
            # flatten(1) keeps the batch dimension, so a wrong input size
            # raises a shape error instead of silently changing the batch
            # size the way view(-1, N) can.
            aux_out = self.fcaux(torch.flatten(x, 1))
            self.aux_out = aux_out

        x = F.relu(self.conv2(x))
        x = F.relu(self.fc1(torch.flatten(x, 1)))
        x = self.fc2(x)

        return x, aux_out

class Model(nn.Module):
    """Pair model built from DigitSubnetwork branches.

    Args:
        weight_sharing: if True, both images go through the same branch
            (``self.left``); otherwise a separate ``self.right`` is created.
        aux: enable the auxiliary heads of the branches.
        binary: if True, ``forward`` returns 2-way logits from a small MLP on
            the concatenated digit logits; otherwise it returns the two digit
            logit tensors ``(left, right)``.
    """

    def __init__(self, weight_sharing=False, aux=False, binary=False):
        super(Model, self).__init__()
        self.left = DigitSubnetwork(aux)
        self.weight_sharing = weight_sharing
        self.aux = aux
        self.binary = binary

        if not weight_sharing:
            # Must receive `aux` as well: otherwise the right branch returns
            # no auxiliary output and the aux loss would be computed on None.
            self.right = DigitSubnetwork(aux)

        # Binary head: concatenated 10 + 10 digit logits -> 2 classes.
        self.fc = nn.Linear(20, 200)
        self.fc2 = nn.Linear(200, 2)
        # Per-branch outputs of the last forward pass (get_subnetwork_output).
        self.left_out = None
        self.right_out = None

    def get_subnetwork_output(self):
        """Return ``(left, right)`` from the last forward pass.

        These are the auxiliary logits when ``aux`` is enabled, otherwise the
        branches' main digit logits.
        """
        return (self.left_out, self.right_out)

    def forward(self, x):
        """Run the branches on the image pair ``x``.

        Args:
            x: tensor indexed as ``x[:, LEFT]`` / ``x[:, RIGHT]`` for the two
               images; presumably (B, 2, 14, 14) — TODO confirm.

        Returns:
            2-way logits if ``binary``, else ``(left, right)`` digit logits.
        """
        right_branch = self.left if self.weight_sharing else self.right
        left, left_aux_out = self.left(x[:, LEFT, :, :].unsqueeze(1))
        right, right_aux_out = right_branch(x[:, RIGHT, :, :].unsqueeze(1))

        if left_aux_out is not None:
            self.left_out = left_aux_out
            self.right_out = right_aux_out
        else:
            self.left_out = left
            self.right_out = right

        if not self.binary:
            return left, right

        x = torch.cat((left, right), dim=1)
        x = F.relu(self.fc(x))
        return self.fc2(x)

In [67]:
def compute_accuracy(inp, targ, model):
    """Return the percentage of pairs whose 0/1 target is predicted correctly.

    For a binary model, the arg-max of its 2-way output is compared with
    `targ`. Otherwise the model returns two digit-logit tensors, and a pair is
    predicted as 1 when the right digit is >= the left digit, else 0.

    Runs under ``torch.no_grad()`` with the model in eval mode; the model's
    previous train/eval mode is restored afterwards.

    Args:
        inp: batch of image pairs, fed to ``model`` in a single call.
        targ: 0/1 target tensor, one entry per pair.
        model: module exposing a ``binary`` attribute.

    Returns:
        Accuracy in percent, as a float.
    """
    was_training = model.training
    model.eval()
    try:
        # Evaluation only: avoid building an autograd graph over the whole set.
        with torch.no_grad():
            if model.binary:
                pred = torch.argmax(model(inp), dim=1)
            else:
                out_1, out_2 = model(inp)
                pred_1 = torch.argmax(out_1, dim=1)
                pred_2 = torch.argmax(out_2, dim=1)
                # 1 when the right digit is predicted >= the left one.
                pred = (pred_2 >= pred_1).long()
            errors = (pred != targ).sum().item()
    finally:
        model.train(was_training)

    return 100 * (1 - errors / inp.shape[0])

In [68]:
from IPython.display import clear_output

In [79]:
if __name__ == "__main__":
    import statistics

    rounds = 10
    accuracies = []

    for _round in range(rounds):
        # A fresh model and a regenerated data set every round.
        model = Model(weight_sharing=True, binary=False, aux=True)

        train_input, train_target, train_classes, test_input, test_target, test_classes = \
            prologue.generate_pair_sets(N_PAIRS)

        train_model(model, train_input, train_target, train_classes)

        acc = compute_accuracy(test_input, test_target, model)
        accuracies.append(acc)
        print("Test accuracy: {}".format(acc))

        # wait=True: this round's log is cleared once the next output arrives.
        clear_output(wait=True)

    avg_acc = statistics.mean(accuracies)
    print("Avg test acc %: {}".format(avg_acc))
    # The spread across rounds shows whether differences between runs matter.
    if len(accuracies) > 1:
        print("Std test acc %: {}".format(statistics.stdev(accuracies)))

Avg test acc %: 96.78
