In [None]:
import copy
import time
import torch
import tensorflow as tf
import torch.nn as nn
import datetime
import time
import numpy as np
import math
import matplotlib.pyplot as plt
from matplotlib import ticker

# Smallest positive float64 (a subnormal); not referenced elsewhere in this notebook.
EPSILON = np.nextafter(0, 1)
# Images are square: IMAGE_SIDE_LEN x IMAGE_SIDE_LEN pixels (see plot_failed's reshape).
IMAGE_SIDE_LEN = 28
# Pixels per image, i.e. the sequence length each sample is fed to the models as.
NB_PIXEL = IMAGE_SIDE_LEN * IMAGE_SIDE_LEN
# Presumably the number of distinct 8-bit grey levels per pixel (unused here) — confirm.
NB_PIXEL_VALUE = 256
# Number of output classes (used for one-hot encoding and the confusion matrix).
NB_CATEGORY = 10

################################## Utilities ###################################
# Some utilities no longer make sense and should be removed.

def _batchify(x, y, batch_size, device):
    """Convert (x, y) to tensors on ``device`` and cut them into non-overlapping batches."""
    x_tensor = torch.tensor(x, dtype=torch.float, device=device)[:, :, None]
    y_tensor = torch.tensor(y, dtype=torch.long, device=device)
    # Trailing samples that do not fill a whole batch are dropped.
    nb_batch = x_tensor.shape[0] // batch_size
    x_batches = np.empty(nb_batch, dtype=object)
    y_batches = np.empty(nb_batch, dtype=object)
    for i in range(nb_batch):
        # Batch i covers samples [i*batch_size, (i+1)*batch_size). The original
        # sliced [i, i+batch_size), so consecutive batches overlapped almost
        # entirely and most of the data was never used.
        lo = i * batch_size
        x_batches[i] = x_tensor[lo:lo + batch_size]
        y_batches[i] = y_tensor[lo:lo + batch_size]
    return x_batches, y_batches


def get_sets(batch_size, train_size, test_size, load_data=None, device='cuda:0'):
    """Load the dataset and split it into train / validation / test batches.

    Args:
        batch_size: number of samples per batch.
        train_size: number of leading training samples to keep.
        test_size: number of leading test samples to keep. The resulting test
            batches are split in half: first half validation, second half test.
        load_data: zero-argument callable returning
            ``((xtrain, ytrain), (xtest, ytest))`` array-likes. This replaces the
            redacted ``#db_file`` placeholder of the original notebook. Each x row
            is presumably a flattened image of NB_PIXEL values — TODO confirm
            against the real loader.
        device: torch device the tensors are created on (default 'cuda:0',
            as in the original).

    Returns:
        ``(xtrain_batches, ytrain_batches), (xvalid_batches, yvalid_batches),
        (xtest_batches, ytest_batches)``. Each element is a 1-D numpy object
        array of tensors: x batches have shape (batch_size, features, 1) and
        dtype float; y batches have shape (batch_size,) and dtype long.

    Raises:
        ValueError: if ``load_data`` is not supplied.
    """
    if load_data is None:
        raise ValueError("get_sets needs a load_data callable returning "
                         "((xtrain, ytrain), (xtest, ytest)) (originally '#db_file')")
    (xtrain, ytrain), (xtest, ytest) = load_data()

    xtrain_batches, ytrain_batches = _batchify(xtrain[0:train_size], ytrain[0:train_size], batch_size, device)
    xtest_batches, ytest_batches = _batchify(xtest[0:test_size], ytest[0:test_size], batch_size, device)

    # First half of the test batches becomes the validation set.
    xvalid_batches, xtest_batches = np.array_split(xtest_batches, 2)
    yvalid_batches, ytest_batches = np.array_split(ytest_batches, 2)

    return (xtrain_batches, ytrain_batches), (xvalid_batches, yvalid_batches), (xtest_batches, ytest_batches)


def time_since(since):
    """Format the time elapsed since the ``time.time()`` timestamp ``since`` as 'Xm Ys'.

    Seconds are truncated to whole numbers by the '%d' format.
    """
    elapsed = time.time() - since
    whole_minutes = math.floor(elapsed / 60)
    leftover_seconds = elapsed - whole_minutes * 60
    return '%dm %ds' % (whole_minutes, leftover_seconds)


def do_confusion(model, number_batches, category_batches):
    """Compute and plot the row-normalised confusion matrix of ``model``.

    Rows are actual classes and columns are predicted classes. Each row is divided
    by its sum. Classes that never occur in the data keep an all-zero row instead of
    NaN. The model is evaluated in eval mode (dropout off), and its previous
    train/eval mode is restored afterwards.
    """
    confusion = torch.zeros(NB_CATEGORY, NB_CATEGORY)
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for numbers, categories in zip(number_batches, category_batches):
                batch_output = model.forward(numbers, categories)
                guesses = batch_output.topk(1, dim=1)[1].squeeze(1).cpu()
                actual = categories.cpu()
                # accumulate=True makes repeated (actual, guess) pairs add up instead of overwrite.
                confusion.index_put_((actual, guesses), torch.ones(len(actual)), accumulate=True)
    finally:
        model.train(was_training)
    # Normalise by row sums; clamping to 1 avoids 0/0 for classes absent from the data.
    row_sums = confusion.sum(dim=1, keepdim=True)
    confusion = confusion / row_sums.clamp(min=1)
    plot_confusion(confusion)


def plot_confusion(confusion):
    """Draw ``confusion`` (an NB_CATEGORY x NB_CATEGORY CPU tensor) as a heat map.

    Rows are actual classes and columns are predicted classes. The figure is saved
    to confusion.png and shown.
    """
    fig, ax = plt.subplots()
    cax = ax.matshow(confusion.numpy())
    fig.colorbar(cax)
    # One explicit tick per class. The original padded the labels with '' and relied
    # on MultipleLocator placing an off-screen tick at -1. That is fragile and makes
    # matplotlib warn about a FixedFormatter without a FixedLocator.
    positions = list(range(NB_CATEGORY))
    labels = [str(c) for c in positions]
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=90)
    ax.set_yticks(positions)
    ax.set_yticklabels(labels)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Actual")
    ax.set_title("Confusion matrix")
    fig.savefig("confusion.png")
    plt.show()


def plot_failed(model, number_batches, category_batches):
    """Plot up to ten misclassified samples in a 2x5 grid and save it to failed.png.

    Samples come from find_failed. Each is reshaped to an IMAGE_SIDE_LEN x
    IMAGE_SIDE_LEN image, and each panel title shows the prediction (P) and the
    actual label (A).
    """
    columns = 5
    rows = 2
    failed, guessed, actual = find_failed(model, number_batches, category_batches)
    fig = plt.figure(figsize=(11, 4))
    # An accurate model can produce fewer failures than grid cells. The original
    # indexed failed[i] unconditionally and raised IndexError in that case.
    for i in range(min(columns * rows, len(failed))):
        fig.add_subplot(rows, columns, i + 1)
        plt.imshow(failed[i].cpu().reshape(IMAGE_SIDE_LEN, IMAGE_SIDE_LEN))
        plt.title("P: " + str(guessed[i]) + "  A: " + str(actual[i]))
        plt.axis('off')
    plt.savefig("failed.png")
    plt.show()


def find_failed(model, number_batches, category_batches, limit=10):
    """Collect the first ``limit`` misclassified samples, in batch order.

    The model is evaluated in eval mode (dropout off), and its previous
    train/eval mode is restored afterwards.

    Args:
        model: object with ``forward(input, target)`` returning per-class scores.
        number_batches: sequence of input batches.
        category_batches: matching sequence of label batches.
        limit: maximum number of failures to return (default 10, as before).

    Returns:
        (failed, guessed, actual): the input tensors of the misclassified
        samples, their predicted class indices and their actual class indices.
        Each list has at most ``limit`` entries.
    """
    failed, guessed, actual = [], [], []
    if limit <= 0:
        return failed, guessed, actual
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for numbers, categories in zip(number_batches, category_batches):
                batch_output = model.forward(numbers, categories)
                guesses = batch_output.topk(1, dim=1)[1].squeeze(1)
                wrong = torch.nonzero(guesses != categories).flatten().tolist()
                for j in wrong:
                    failed.append(numbers[j])
                    guessed.append(guesses[j].item())
                    actual.append(categories[j].item())
                    if len(failed) == limit:
                        return failed, guessed, actual
    finally:
        model.train(was_training)
    return failed, guessed, actual


def get_inference_time(model, number_batches, category_batches):
    """Return the average per-sample inference time as a string, e.g. '12 microseconds (μs)'.

    Uses a monotonic high-resolution clock. The original used only the
    second/microsecond fields of datetime, so results broke across minute
    boundaries. When CUDA is available, the device is synchronised before and
    after the loop so that asynchronous kernels are included in the measurement.
    The model runs in eval mode, and its previous mode is restored.

    Raises:
        ValueError: if the batches contain no samples.
    """
    nb_samples = sum(len(batch) for batch in number_batches)
    if nb_samples == 0:
        raise ValueError("get_inference_time needs at least one sample")
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            start = time.perf_counter()
            for numbers, categories in zip(number_batches, category_batches):
                model.forward(numbers, categories)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            elapsed = time.perf_counter() - start
    finally:
        model.train(was_training)
    us = elapsed * 1e6 / nb_samples
    return str(round(us)) + " microseconds (μs)"


def plot_accuracy(train, test, valid):
    """Draw per-epoch accuracy curves for the train, test and validation sets.

    The figure is saved to accuracy.png and then shown.
    """
    fig, ax = plt.subplots()
    for series, name in ((train, "Train"), (test, "Test"), (valid, "Validation")):
        ax.plot(series, label=name)
    ax.legend(loc='best')
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Accuracy")
    ax.set_title("Accuracy")
    fig.savefig("accuracy.png")
    plt.show()


def plot_loss(train, test, valid):
    """Draw per-epoch loss curves for the train, test and validation sets.

    Every value is passed through float() first. Losses can arrive as 0-d
    (possibly CUDA) tensors returned by the loss criterion, and matplotlib cannot
    convert CUDA tensors to numpy. The figure is saved to loss.png and shown.
    """
    plt.figure()
    plt.plot([float(v) for v in train], label="Train")
    plt.plot([float(v) for v in test], label="Test")
    plt.plot([float(v) for v in valid], label="Validation")
    plt.legend(loc='best')
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Loss")
    plt.savefig("loss.png")
    plt.show()

In [None]:
################################ Neural Network ################################


class NN(nn.Module):
    """Classification head on top of a batch-first recurrent base model.

    The setup cell uses an nn.LSTM / nn.GRU created with batch_first=True. The
    base model's output at the last time step is projected to ``output_size``
    scores and turned into log-probabilities, matching the NLLLoss stored in
    ``self.loss``.

    Args:
        model: recurrent module returning ``(output, state)``, with output of
            shape (batch, seq, hidden_size).
        hidden_size: feature size of the base model's output.
        output_size: number of classes.
        lr: learning rate for the Adam optimizer. The original ignored it and
            used Adam's default.
        device: device for the projection layer (default 'cuda:0', as before).
    """

    def __init__(self, model, hidden_size, output_size, lr, device='cuda:0'):
        super(NN, self).__init__()
        self.base_model = model
        self.h2o = nn.Linear(hidden_size, output_size).to(device)
        # LogSoftmax and NLLLoss hold no parameters, so they need no device move.
        self.softmax = nn.LogSoftmax(dim=1)
        self.loss = nn.NLLLoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)

    def forward(self, input, _):  # 2nd argument (target) is unused; it mirrors Transformer.forward
        output, _ = self.base_model.forward(input)
        output = self.h2o(output[:, -1, :])  # last time step only
        return self.softmax(output)


class Transformer(nn.Module):
    """Sequence classifier built on a sequence-first ``nn.Transformer``.

    Each input feature vector becomes one encoder token, padded with NB_CATEGORY
    zeros so the token width matches the base model's d_model. The decoder gets a
    single all-zero query token. The last NB_CATEGORY features of its output are
    the class scores, returned as log-probabilities for the NLLLoss in ``self.loss``.

    Args:
        model: an ``nn.Transformer`` whose d_model equals
            input features + NB_CATEGORY.
        batch_size: kept for interface compatibility; padding is now sized from
            the actual input, so any batch size works.
    """

    def __init__(self, model, batch_size):
        super(Transformer, self).__init__()
        self.base_model = model
        self.loss = nn.NLLLoss()
        self.optimizer = torch.optim.Adam(self.parameters())
        self.batch_size = batch_size

    def forward(self, input, target):
        """Return (batch, NB_CATEGORY) class log-probabilities for ``input``.

        Args:
            input: (batch, seq_len, features) tensor, as produced by get_sets.
            target: (batch,) labels. They are deliberately NOT given to the
                network. The original fed their one-hot encoding as the decoder
                token, which let the model read the answer (label leakage), even
                when measuring test accuracy.
        """
        values = input.float()
        batch, seq_len, features = values.shape
        padding = values.new_zeros((batch, seq_len, NB_CATEGORY))
        src = torch.cat((values, padding), 2).transpose(0, 1)  # (seq_len, batch, d_model)
        query = values.new_zeros((1, batch, features + NB_CATEGORY))  # label-free decoder token
        output = self.base_model.forward(src, query)  # (1, batch, d_model)
        logits = output[0, :, -NB_CATEGORY:]
        # NLLLoss expects log-probabilities; the original returned raw scores.
        return nn.functional.log_softmax(logits, dim=1)

In [None]:
################################ Training ######################################

def train_batch(model, number_batch, category_batch):
    """Run one optimisation step of ``model`` on a single batch.

    ``model`` must provide ``forward(input, target)``, a ``loss`` criterion and an
    ``optimizer``, as NN and Transformer do.

    Returns:
        (output, loss): the batch output and the batch loss as a Python float.
    """
    # Evaluation helpers may leave the model in eval mode; dropout etc. must be active here.
    model.train()
    # Clears every parameter's gradient, which covers everything the optimizer updates.
    model.zero_grad()
    output = model.forward(number_batch, category_batch)
    loss = model.loss(output, category_batch)
    loss.backward()
    model.optimizer.step()
    return output, loss.item()


def get_accuracy_loss(model, number_batches, category_batches):
    """Evaluate ``model`` over paired sequences of input and label batches.

    The model runs in eval mode (dropout off), and its previous train/eval mode is
    restored afterwards.

    Returns:
        (accuracy, loss): the fraction of correctly classified samples, and the
        sample-weighted mean of ``model.loss`` over all batches. Both are Python
        floats, and both are NaN when there are no samples. The original returned
        only the last batch's loss, as a tensor.
    """
    nb_correct = 0
    nb_samples = 0
    total_loss = 0.0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for numbers, categories in zip(number_batches, category_batches):
                batch_output = model.forward(numbers, categories)
                batch_len = len(categories)
                # The NLLLoss criteria used here average over the batch, so
                # re-weight by batch size before averaging over all samples.
                total_loss += model.loss(batch_output, categories).item() * batch_len
                guesses = batch_output.topk(1, dim=1)[1].squeeze(1)
                nb_correct += (guesses == categories).sum().item()
                nb_samples += batch_len
    finally:
        model.train(was_training)
    if nb_samples == 0:
        return float('nan'), float('nan')
    return nb_correct / nb_samples, total_loss / nb_samples


def log_accuracy_loss(model, number_batches, category_batches, accuracy_array, loss_array):
    """Evaluate ``model`` with get_accuracy_loss and append the results to the history lists.

    Appends the accuracy to ``accuracy_array`` and the loss to ``loss_array``, then
    returns the ``(accuracy, loss)`` pair.
    """
    accuracy, loss = get_accuracy_loss(model, number_batches, category_batches)
    for history, value in ((accuracy_array, accuracy), (loss_array, loss)):
        history.append(value)
    return accuracy, loss


def log_sets(model, epoch, nb_epochs):
    """Evaluate ``model`` on the train, test and validation sets, record the results and print progress.

    Reads these notebook-level globals: xtrain/ytrain, xtest/ytest, xvalid/yvalid,
    the *_accuracies / *_losses history lists, and ``start``. Returns the validation
    accuracy.
    """
    train_accuracy, train_loss = log_accuracy_loss(model, xtrain, ytrain, train_accuracies, train_losses)
    test_accuracy = log_accuracy_loss(model, xtest, ytest, test_accuracies, test_losses)[0]
    validation_accuracy = log_accuracy_loss(model, xvalid, yvalid, validation_accuracies, validation_losses)[0]
    percent_done = int(epoch / nb_epochs * 100)
    print('%d [%d%%][%s] Loss: %.4f Train:%.2f%% Test:%.2f%%' % (
        epoch, percent_done, time_since(start), train_loss, train_accuracy * 100, test_accuracy * 100))
    return validation_accuracy


def train(model):
    """Train ``model`` for ``nb_epochs`` epochs; return the snapshot with the best validation accuracy.

    Reads the notebook-level globals nb_epochs and xtrain/ytrain, plus those used by
    log_sets. The untrained (epoch-0) model is the initial candidate, and a later
    epoch replaces it only if it strictly improves validation accuracy.
    """
    best_validation_accuracy = log_sets(model, 0, nb_epochs)
    # Snapshot rather than alias. The original kept a reference to the live model,
    # so if validation never improved it returned the last-epoch model, not the best one.
    best_model = copy.deepcopy(model)
    nb_batch = xtrain.shape[0]
    for epoch in range(nb_epochs):
        for i in range(nb_batch):
            train_batch(model, xtrain[i], ytrain[i])
        validation_accuracy = log_sets(model, epoch + 1, nb_epochs)
        if validation_accuracy > best_validation_accuracy:
            best_validation_accuracy = validation_accuracy
            best_model = copy.deepcopy(model)
    return best_model

In [None]:
################################## Setup #######################################

cuda0 = torch.device('cuda:0')

# Parameters
batch_size = 100
nb_epochs = 30
train_size = 60000
test_size = 10000
hidden_size = 6
num_layers = 1
lr = 0.001  # passed to NN; the Transformer wrapper takes no learning-rate argument

# Reproducibility: seed the RNGs behind weight initialisation, dropout and numpy.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# !! uncomment the model you want to run, training is generic and models are interchangeable !!

# LSTM, GRU
#base_model = nn.LSTM(1, hidden_size, num_layers, batch_first=True).cuda()
#base_model = nn.GRU(1, hidden_size, num_layers, batch_first=True).cuda()
#model = NN(base_model, hidden_size, NB_CATEGORY, lr)
#(xtrain, ytrain), (xvalid, yvalid), (xtest, ytest) = get_sets(batch_size, train_size, test_size)

# Transformer
# d_model = 1 pixel value + NB_CATEGORY padding features per token (see Transformer.forward).
transformer_model = nn.Transformer(d_model=1+NB_CATEGORY, nhead=1, dim_feedforward=hidden_size).cuda()
model = Transformer(transformer_model, batch_size)
(xtrain, ytrain), (xvalid, yvalid), (xtest, ytest) = get_sets(batch_size, train_size, test_size)

# Logs
train_accuracies = []
train_losses = []
validation_accuracies = []
validation_losses = []
test_accuracies = []
test_losses = []

# Train
start = time.time()
best_model = train(model)
final_time = time_since(start)

# Results
print('Final train accuracy: %.2f%%' % (get_accuracy_loss(best_model, xtrain, ytrain)[0]*100))
print('Final test accuracy: %.2f%%' % (get_accuracy_loss(best_model, xtest, ytest)[0]*100))
print('Overall training time: %s' % (final_time))
print('Average inference time: %s' % (get_inference_time(best_model, xtest, ytest)))
print('Number of parameters: %d' % (sum(p.numel() for p in best_model.parameters())))
do_confusion(best_model, xtest, ytest)
plot_loss(train_losses, test_losses, validation_losses)
plot_accuracy(train_accuracies, test_accuracies, validation_accuracies)
plot_failed(best_model, xtest, ytest)