In [1]:
import torch
from torch import nn, optim
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import time
from utils import load_data, split_data, shuffle, weight_initialization
from metrics import compute_accuracy
from torchsummary import summary

In [2]:
# Load train and test data
train_data, test_data = load_data()

# Split data in train-validation-test
train_loader, valid_loader, test_loader = split_data(train_data, test_data)

In [3]:
def trial(net, train_loader, test_loader, n_epochs=25, n_trials=20, alpha=1, alpha_decay=1, verbose=True):
    all_losses = torch.zeros((n_trials, n_epochs))
    tr_accuracies = torch.zeros(n_trials)
    te_accuracies = torch.zeros(n_trials)
    
    for trial in range(n_trials):
        # Shuffle data
        train_loader, valid_loader, test_loader = split_data(train_data, test_data)
        
        # Reset weights
        net.train()
        net.apply(weight_initialization)
        
        # Train
        start = time.time()
        train_loss = train(net, train_loader, alpha=alpha, alpha_decay=alpha_decay, verbose=False)
        print('Trial %d/%d... Training time: %.2f s' % (i+1, n_trials, time.time()-start))
        
        # Collect data
        all_losses[trial] = train_loss
        
        net.eval()
        with torch.no_grad():
            tr_accuracies[i] = compute_accuracy(net, train_loader)
            te_accuracies[i] = compute_accuracy(net, test_loader)
        
        if verbose:
            print('Loss: %.3f, Train acc: %.3f, Test acc: %.3f' % 
                  (train_loss[-1], tr_accuracies[i], te_accuracies[i]))
    
    return all_losses, tr_accuracies, te_accuracies

In [4]:
def train(net, train_loader, eta=1e-3, decay=1e-5, n_epochs=25, alpha=1, alpha_decay=1, verbose=True):
    aux_crit = nn.CrossEntropyLoss()
    binary_crit = nn.BCELoss()
    optimizer = optim.Adam(net.parameters(), lr=eta, weight_decay=decay)
    #optimizer = optim.RMSprop(net.parameters(), lr=eta, weight_decay=decay)
    #optimizer = optim.SGD(net.parameters(), lr=eta, weight_decay=decay, momentum=0.9, nesterov=True)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)

    tr_losses, val_losses = torch.zeros(n_epochs), torch.zeros(n_epochs)
    
    for e in range(n_epochs):
        tr_loss, val_loss, aux_loss = 0, 0, 0
        net.train()
        
        for (trainX, trainY, trainC) in train_loader:
            out, aux = net(trainX)
            
            aux1, aux2 = aux.unbind(1)
            c1, c2 = trainC.unbind(1)
            
            aux_loss = aux_crit(aux1, c1) + aux_crit(aux2, c2)

            binary_loss = binary_crit(out, trainY.float())
            
            total_loss = binary_loss + alpha*aux_loss
            
            tr_loss += total_loss.item()
           
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
        
        with torch.no_grad():
            for (valX, valY, trainC) in valid_loader:
                out, _ = net(valX)
                loss = binary_crit(out, valY.float())
                val_loss += loss.item()
        
        tr_losses[e] = tr_loss
        val_losses[e] = val_loss
                
        scheduler.step(val_loss)
        alpha *= alpha_decay
        
        if verbose:
            print('Epoch %d/%d, Train loss: %.3f, Val loss: %.3f, Aux loss: %.3f' % 
                  (e+1, n_epochs, tr_loss, val_loss, aux_loss))
    
    return tr_losses

In [5]:
class CNN(nn.Module):
    def __init__(self, verbose=True):
        super(CNN, self).__init__()
        
        ## Siamese block
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 10)
        
        # Decision block
        self.fc3 = nn.Linear(20, 10)
        self.fc4 = nn.Linear(10, 1)
        
        # Regularizers
        self.drop = nn.Dropout(0.2)
        self.pool = nn.MaxPool2d(2,2)
        
        # Activation fcns
        self.relu = nn.ReLU()
        self.selu = nn.SELU()
        self.sigmoid = nn.Sigmoid()
        
        if verbose:
            print(f'Parameters: {self.count_params()}')

    def count_params(self):
        return sum(p.numel() for p in self.parameters())
    
    def forward(self, x):
        x1, x2 = x.unbind(1)
        
        x1 = self.pool(self.relu(self.conv1(x1.unsqueeze(1))))
        x1 = self.pool(self.relu(self.conv2(x1)))
        x1 = self.relu(self.fc1(x1.flatten(start_dim=1)))
        x1 = self.drop(x1)
        x1 = self.relu(self.fc2(x1))
        x1 = self.drop(x1)
        
        x2 = self.pool(self.relu(self.conv1(x2.unsqueeze(1))))
        x2 = self.pool(self.relu(self.conv2(x2)))
        x2 = self.relu(self.fc1(x2.flatten(start_dim=1)))
        x2 = self.drop(x2)
        x2 = self.relu(self.fc2(x2))
        x2 = self.drop(x2)
        
        # Dim x1: Nx1x10
        
        aux = torch.stack([x1, x2], dim=1)
        
        # Dim aux: Nx2x10
        
        x = torch.cat([x1, x2], dim=1)
        
        # Dim x: Nx20
        
        x = self.relu(self.fc3(x.flatten(start_dim=1)))
        x = self.drop(x)
        x = self.sigmoid(self.fc4(x))
        return x.squeeze(), aux

In [7]:
net = CNN()

Parameters: 71975


In [None]:
train(net, train_loader, alpha=1, alpha_decay=0.9)

Epoch 1/25, Train loss: 79.515, Val loss: 2.794, Aux loss: 4.098
Epoch 2/25, Train loss: 60.516, Val loss: 2.615, Aux loss: 3.231


In [None]:
compute_accuracy(net, test_loader)

In [None]:
_, tr_accuracies, te_accuracies, trial(net, train_loader, test_loader, alpha_decay=.9)

In [None]:
print(f'Training Acc - Mean: {tr_accuracies.mean()}')
print(f'Training Acc - SD: {tr_accuracies.std()}')
print(f'Test Acc - Mean: {te_accuracies.mean()}')
print(f'Test Acc - SD: {te_accuracies.std()}')
print(f'Test Acc - Median: {tr_accuracies.median()}')

In [None]:
plt.hist(te_accuracies)