In [1]:
import torch
from torch import nn
import torch.optim as optim
from torch.nn import functional as F
from torch.optim.lr_scheduler import StepLR

import dlc_practical_prologue as prolog

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
def nb_errors(pred, truth):
    
    pred_class = pred.argmax(1)
    return (pred_class - truth != 0).sum().item()
        
    
def train_model(model, train_input, train_target, test_input, test_target,  epochs=500, batch_size=100, lr=0.1):
    
    torch.nn.init.xavier_uniform_(model.conv1.weight)
    torch.nn.init.xavier_uniform_(model.conv2.weight)
    
    optimizer = torch.optim.Adam(model.parameters())
    #scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
    train_loss = []
    test_loss = []
    test_accuracy = []
    best_accuracy = 0
    best_epoch = 0
    
    for i in range(epochs):
        for b in range(0, train_input.size(0), batch_size):
            output = model(train_input.narrow(0, b, batch_size))
            criterion = torch.nn.CrossEntropyLoss()
            loss = criterion(output, train_target.narrow(0, b, batch_size))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            #scheduler.step()
        
        output_train = model(train_input)
        output_test = model(test_input)
        train_loss.append(criterion(output_train, train_target).item())
        test_loss.append(criterion(output_test, test_target).item())
        accuracy = 1 - nb_errors(output_test, test_target) / 1000
        
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_epoch = i+1
        test_accuracy.append(accuracy)
        
        if i%5 == 0:
            print('Epoch : ',i+1, '\t', 'test loss :', test_loss[-1], '\t', 'train loss', train_loss[-1])
        
    return train_loss, test_loss, test_accuracy, best_accuracy        

In [3]:
class ConvNet3(nn.Module):
    def __init__(self, nb_hidden):
        super(ConvNet3, self).__init__()
        self.conv1 = nn.Conv2d(in_channels = 2, out_channels = 4, kernel_size=2, stride = 1)
        self.conv2 = nn.Conv2d(4, 8, kernel_size=3, stride = 1, padding=2)
        self.conv3 = nn.Conv2d(8, 16, kernel_size = 3, stride = 1, padding=2)
        self.fc1 = nn.Linear(16*3*3, nb_hidden)
        self.fc2 = nn.Linear(nb_hidden, 2)
        self.dropout1 = nn.Dropout2d(0.25)
        #self.dropout2 = nn.Dropout2d(0.5)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = F.max_pool2d(F.relu(self.conv3(x)), 2)
        x = self.dropout1(x)
        x = x.view(-1, 16*3*3)
        x = F.relu(self.fc1(x))
        #x = self.dropout2(x)
        x = F.relu(self.fc2(x))
        return x



In [4]:
class ConvNet2(nn.Module):
    def __init__(self, nb_hidden):
        super(ConvNet2, self).__init__()
        self.conv1 = nn.Conv2d(in_channels = 2, out_channels = 8, kernel_size=3)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, stride = 1)
        self.fc1 = nn.Linear(64, nb_hidden)
        self.fc2 = nn.Linear(nb_hidden, 2)
        #self.dropout1 = nn.Dropout2d(0.25)
        #self.dropout2 = nn.Dropout2d(0.5)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), (2, 2)))
        #x = self.dropout1(x)
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 64)
        #x = self.dropout1(x)
        x = F.relu(self.fc1(x))
        #x = self.dropout2(x)
        x = self.fc2(x)
        return x

In [5]:
train_input, train_target, train_classes, test_input, test_target, test_classes = prolog.generate_pair_sets(1000)

In [6]:
model = ConvNet2(10)

In [7]:
_, _, _, best_accuracy = train_model(model, train_input, train_target, test_input,\
                                             test_target, epochs=51, lr = 0.005)

Epoch :  1 	 test loss : 1.7596607208251953 	 train loss 1.5987602472305298
Epoch :  6 	 test loss : 0.8459348082542419 	 train loss 0.6801041960716248
Epoch :  11 	 test loss : 0.6623099446296692 	 train loss 0.47203686833381653
Epoch :  16 	 test loss : 0.6065219640731812 	 train loss 0.37997451424598694
Epoch :  21 	 test loss : 0.5845654010772705 	 train loss 0.3194357752799988
Epoch :  26 	 test loss : 0.5801838636398315 	 train loss 0.2742043435573578
Epoch :  31 	 test loss : 0.5859231352806091 	 train loss 0.22964411973953247
Epoch :  36 	 test loss : 0.6033982634544373 	 train loss 0.19105957448482513
Epoch :  41 	 test loss : 0.6262567043304443 	 train loss 0.15902751684188843
Epoch :  46 	 test loss : 0.657523512840271 	 train loss 0.13236689567565918
Epoch :  51 	 test loss : 0.6922580003738403 	 train loss 0.10970384627580643


In [8]:
best_accuracy

0.771

In [14]:
model1 = ConvNet3(200)
model2 = ConvNet3(100)
model3 = ConvNet3(350)
model4 = ConvNet3(700)

model5 = ConvNet2(50)
model6 = ConvNet2(100)
model7 = ConvNet2(350)
model8 = ConvNet2(700)

In [10]:
models = [model1, model2, model3, model4, model5, model6, model7, model8]

In [8]:
epochs = 200
accuracies = torch.empty(8, 10, dtype=torch.float)

for i in range(10):
    train_input, train_target, train_classes, test_input, test_target, test_classes = prolog.generate_pair_sets(1000)

    for j in range(8):
        _, _, _, best_accuracy = train_model(models[j], train_input, train_target, test_input,\
                                             test_target, epochs=epochs, lr = 0.01)
        print(best_accuracy)
        accuracies[j][i] = best_accuracy

Epoch :  1 	 test loss : 0.6930968761444092 	 train loss 0.6928606033325195
Epoch :  11 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  21 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  31 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  41 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  51 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  61 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  71 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  81 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  91 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  101 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  111 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  121 	 test loss : 0.6931537985801697 	 train loss 0.69315379858016

Epoch :  51 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  61 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  71 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  81 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  91 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  101 	 test loss : 0.6930237412452698 	 train loss 0.6931537985801697
Epoch :  111 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  121 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  131 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  141 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  151 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  161 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  171 	 test loss : 0.6931537985801697 	 train loss 0.69315379

Epoch :  101 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  111 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  121 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  131 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  141 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  151 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  161 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  171 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  181 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  191 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
0.546
Epoch :  1 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  11 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  21 	 test loss : 0.6931537985801697 	 train loss 0.6

Epoch :  161 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  171 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  181 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  191 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
0.546
Epoch :  1 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  11 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  21 	 test loss : 0.6931537985801697 	 train loss 0.6931537985801697
Epoch :  31 	 test loss : 0.6929537057876587 	 train loss 0.6931537985801697
Epoch :  41 	 test loss : 0.6915331482887268 	 train loss 0.6973340511322021
Epoch :  51 	 test loss : 0.6702325344085693 	 train loss 0.6733124256134033
Epoch :  61 	 test loss : 0.6610875725746155 	 train loss 0.6573403477668762
Epoch :  71 	 test loss : 0.658228874206543 	 train loss 0.6621551513671875
Epoch :  81 	 test loss : 0.6536559462547302 	 train loss 0.65249454

KeyboardInterrupt: 

In [None]:
accuracies.mean(1)