In [None]:
# %% Deep learning - Section 11.107
#    FFN to classify digits

# This code pertains to a deep learning course provided by Mike X. Cohen on Udemy:
#   > https://www.udemy.com/course/deeplearning_x
# The "base" code in this repository is adapted (with very minor modifications)
# from code developed by the course instructor (Mike X. Cohen), while the
# "exercises" and the "code challenges" contain more original solutions and
# creative input from my side. If you are interested in DL (and if you are
# reading this statement, chances are that you are), go check out the course, it
# is singularly good.


In [None]:
# %% Libraries and modules
import numpy               as np
import matplotlib.pyplot   as plt
import torch
import torch.nn            as nn
import seaborn             as sns
import copy
import torch.nn.functional as F
import pandas              as pd
import scipy.stats         as stats
import time

from torch.utils.data                 import DataLoader,TensorDataset
from sklearn.model_selection          import train_test_split
from google.colab                     import files
from torchsummary                     import summary
from IPython                          import display
from matplotlib_inline.backend_inline import set_matplotlib_formats
set_matplotlib_formats('svg')


In [None]:
# %% Architecture of FFN model

# We will here use a fairly simple model with :
# > Input layer (N=28x28=784)
# > Hidden fully connected layer 1 (N=64)
# > Hidden fully connected layer 2 (N=32)
# > Hidden fully connected layer 3 (N=32)
# > Output layer (N=10)
# > ReLU after each hidden layer, and softmax after the output layer (to get probs)

# Extra note :
# > We will also use the log of the softmax to stretch out small probabilities
#   and increase the penalty for incorrect guesses; this can increase category
#   separability and numerical stability (but standard/linear softmax also works
#   fine sometimes, especially for a relatively small number of categories)


In [None]:
# %% Data

# Load data; the path is handed straight to NumPy so that it opens and closes
# the file itself (wrapping it in open(...) left the handle unclosed)
data = np.loadtxt('sample_data/mnist_train_small.csv',delimiter=',')

# Split labels (first column) from data (remaining columns)
labels = data[:,0]
data   = data[:,1:]

# Normalise data by its maximum (original range is (0,255))
data_norm = data / np.max(data)


In [None]:
# %% Plotting

phi = ( 1 + np.sqrt(5) ) / 2
fig,ax = plt.subplots(1,2,figsize=(6*phi,6))

# Same histogram layout for both panels: raw pixel values on the left,
# normalised values on the right, counts on a log scale
panel_specs = ( (data,'Original data'),(data_norm,'Normalised data') )

for panel,(values,panel_title) in zip(ax,panel_specs):
    panel.hist(values.flatten(),50)
    panel.set_xlabel('Pixel intensity values')
    panel.set_ylabel('Pixel count')
    panel.set_title(panel_title)
    panel.set_yscale('log')

# Save to disk, display inline, then hand the file to the Colab download helper
plt.savefig('figure5_mnist_classify_digits.png')

plt.show()

files.download('figure5_mnist_classify_digits.png')


In [None]:
# %% Create train and test datasets

# Tensors: pixel data as float32, labels as int64
data_tensor   = torch.tensor(data_norm,dtype=torch.float32)
labels_tensor = torch.tensor(labels,dtype=torch.long)

# Hold out 10% of the samples as test data (scikit-learn split)
train_data,test_data,train_labels,test_labels = train_test_split(
    data_tensor,labels_tensor,test_size=0.1)

# Pair each sample with its label as PyTorch datasets
train_data = TensorDataset(train_data,train_labels)
test_data  = TensorDataset(test_data,test_labels)

# Loaders: shuffled mini-batches for training (incomplete last batch dropped),
# and the whole test set served as one single batch
batch_size   = 32
train_loader = DataLoader(train_data,batch_size=batch_size,shuffle=True,drop_last=True)
test_loader  = DataLoader(test_data,batch_size=len(test_data))


In [None]:
# %% Check variables in workspace so far

# IPython line magic that lists the names currently defined in the interactive
# namespace; it only works inside IPython/Jupyter, not in a plain Python script
%whos


In [None]:
# %% Function to generate the model

def gen_model():
    """Build a fresh digit classifier together with its loss and optimizer.

    Returns
    -------
    tuple
        (ANN, loss_fun, optimizer): a newly initialised `mnist_FFN` network,
        an `nn.NLLLoss` criterion, and an SGD optimizer (lr=0.01) over the
        network parameters.
    """

    class mnist_FFN(nn.Module):
        """Fully connected network 784 -> 64 -> 32 -> 32 -> 10 with log-softmax output."""

        def __init__(self):
            super().__init__()

            # Layers, listed in the order in which forward() applies them
            self.input  = nn.Linear(784,64)
            self.fc1    = nn.Linear( 64,32)
            self.fc2    = nn.Linear( 32,32)
            self.output = nn.Linear( 32,10)

        def forward(self,x):
            """Return log-probabilities over the 10 classes for each row of x."""

            # ReLU after every layer but the last one
            for layer in (self.input,self.fc1,self.fc2):
                x = F.relu(layer(x))

            # Log-softmax over the class dimension, because the loss is NLLLoss
            # rather than CrossEntropyLoss
            return F.log_softmax(self.output(x),dim=1)

    # New model instance
    ANN = mnist_FFN()

    # Negative log-likelihood loss, paired with the log-softmax output above;
    # SGD is used to slow down learning for illustration purposes
    return ANN,nn.NLLLoss(),torch.optim.SGD(ANN.parameters(),lr=0.01)


In [None]:
# Test the model on one batch

ANN,loss_fun,optimizer = gen_model()

# Single forward pass on the first mini-batch served by the training loader
X,y  = next(iter(train_loader))
yHat = ANN(X)

# Log-softmax output and its size (batch_size by output nodes), then a blank line
print(yHat,yHat.shape,'',sep='\n')

# Exponentiate the log-softmax to get back probabilities, then a blank line
print(yHat.exp(),'',sep='\n')

# Loss for this batch
loss = loss_fun(yHat,y)
print('Loss: ',loss,sep='\n')


In [None]:
# %% Function to train the model

def train_model(num_epochs=60):
    """Train a freshly generated model and track loss and accuracy per epoch.

    Relies on the module-level `train_loader` and `test_loader`, and on
    `gen_model()` for the network, loss function and optimizer.

    Parameters
    ----------
    num_epochs : int, optional
        Number of passes over `train_loader` (default 60).

    Returns
    -------
    train_acc : list of float
        Mean mini-batch training accuracy (%) for each epoch.
    test_acc : list of float
        Accuracy (%) on the batch served by `test_loader` after each epoch.
    losses : list of float
        Mean mini-batch training loss for each epoch.
    ANN : nn.Module
        The trained network.
    """

    # Model instance and initialise vars
    ANN,loss_fun,optimizer = gen_model()

    losses    = []
    train_acc = []
    test_acc  = []

    # Loop over epochs
    for epoch_i in range(num_epochs):

        # Loop over training batches
        batch_acc  = []
        batch_loss = []

        for X,y in train_loader:

            # Forward propagation and loss
            yHat = ANN(X)
            loss = loss_fun(yHat,y)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Loss and accuracy from this batch, stored as plain Python floats
            # (.item()) so that the averages below do not depend on NumPy
            # implicitly converting 0-d tensors
            batch_loss.append(loss.item())

            matches  = (torch.argmax(yHat,dim=1) == y).float()
            accuracy = 100 * torch.mean(matches).item()
            batch_acc.append(accuracy)

        losses.append( np.mean(batch_loss) )
        train_acc.append( np.mean(batch_acc) )

        # Test accuracy; .eval() and no_grad() are left out here on purpose,
        # they are the topic of Exercise 4 further down
        X,y = next(iter(test_loader))
        yHat = ANN(X)
        test_acc.append( 100*torch.mean((torch.argmax(yHat,dim=1)==y).float()).item() )

    return train_acc,test_acc,losses,ANN


In [None]:
# %% Run the training

# Trains a new model; unpacks the train accuracy, test accuracy and loss
# histories, plus the trained network (ANN) that the following cells inspect
train_acc,test_acc,losses,ANN = train_model()


In [None]:
# %% Plotting

phi = ( 1 + np.sqrt(5) ) / 2
fig,ax = plt.subplots(1,2,figsize=(1.5*6*phi,6))
loss_ax,acc_ax = ax

# Left panel: training loss per epoch
loss_ax.plot(losses)
loss_ax.set(xlabel='Epochs',ylabel='Loss',ylim=[0,3],title='Model loss')

# Right panel: train and test accuracy per epoch, with the final test accuracy
# reported in the title
for curve,curve_label in ( (train_acc,'Train accuracy'),(test_acc,'Test accuracy') ):
    acc_ax.plot(curve,label=curve_label)

acc_ax.set(xlabel='Epochs',ylabel='Accuracy (%)',ylim=[10,100])
acc_ax.set_title(f'Final model test accuracy: {test_acc[-1]:.2f}%')
acc_ax.legend()

# Save to disk, display inline, then hand the file to the Colab download helper
plt.savefig('figure6_mnist_classify_digits.png')

plt.show()

files.download('figure6_mnist_classify_digits.png')


In [None]:
# %% Inspect output in more detail

# Run model through for test data (detached: no gradients needed for inspection)
X,y   = next(iter(test_loader))
preds = ANN(X).detach()

print(preds)
print(torch.exp(preds))

# Evidence for all numbers from one sample (log-softmax and softmax)
sample2show = 120

phi = ( 1 + np.sqrt(5) ) / 2
plt.figure(figsize=(6*phi,6))

plt.bar(range(10),preds[sample2show].numpy())
plt.xticks(range(10))
plt.xlabel('Number')
plt.ylabel('Evidence for that number (log-softmax)')
plt.title(f'True number was {y[sample2show].item()}')

plt.savefig('figure7_mnist_classify_digits.png')

plt.show()

files.download('figure7_mnist_classify_digits.png')

phi = ( 1 + np.sqrt(5) ) / 2
plt.figure(figsize=(6*phi,6))

# Exponentiate with torch and convert explicitly, instead of applying NumPy
# functions directly to a tensor
plt.bar(range(10),torch.exp(preds[sample2show]).numpy())
plt.xticks(range(10))
plt.xlabel('Number')
plt.ylabel('Evidence for that number (softmax)')
plt.title(f'True number was {y[sample2show].item()}')

plt.savefig('figure8_mnist_classify_digits.png')

plt.show()

files.download('figure8_mnist_classify_digits.png')

# Find and print the errors (indices of the misclassified test samples)
errors = torch.where( torch.argmax(preds,dim=1) != y )[0].numpy()
print(errors)

# Evidence for all numbers from one misclassified sample (position within `errors`)
sample2show = 2

# Guard: there may be fewer misclassified samples than the requested position
if len(errors) > sample2show:

    error_idx = int(errors[sample2show])

    phi = ( 1 + np.sqrt(5) ) / 2
    fig,ax = plt.subplots(1,2,figsize=(6*phi,6))

    ax[0].bar(range(10),torch.exp(preds[error_idx]).numpy())
    ax[0].set_xticks(range(10))
    ax[0].set_xlabel('Number')
    ax[0].set_ylabel('Evidence for that number')
    ax[0].set_title(f'True number: {y[error_idx].item()}, model guessed {torch.argmax(preds[error_idx]).item()}')

    # The flat 784-pixel row is reshaped back into a 28x28 image
    ax[1].imshow( X[error_idx].reshape(28,28).numpy() ,cmap='gray')

    plt.savefig('figure9_mnist_classify_digits.png')

    plt.show()

    files.download('figure9_mnist_classify_digits.png')

else:

    print(f'Only {len(errors)} misclassified test samples; cannot show error number {sample2show}.')


In [None]:
# %% Exercise 1
#    Average together the correct 7's and the error 7's, and make images of them (that is, one image
#    of all correct 7's and one image of all incorrectly labeled 7's). How do they look?

# A pretty nice platonic seven for the right ones, and an ugly blob for the incorrect
# classifications

# Run the test set through the model and take the most likely class per sample
X,y   = next(iter(test_loader))
preds = ANN(X).detach()

pred_y    = torch.argmax(preds,axis=1)
pred_y_np = pred_y.numpy()
true_y_np = y.numpy()

# Indices of the true 7s, split by whether the model also answered 7
num  = 7
true_indices = np.flatnonzero(true_y_np == num)
hit_mask     = pred_y_np[true_indices] == num
correct = list(true_indices[hit_mask])
errors  = list(true_indices[~hit_mask])

# Pixel-wise average over each group
avg_errors  = X[errors].mean(dim=0)
avg_correct = X[correct].mean(dim=0)

# Plot: wrongly classified 7s on the left, correctly classified 7s on the right
phi = ( 1 + np.sqrt(5) ) / 2
fig,ax = plt.subplots(1,2,figsize=(6*phi,6))

panels = ( (avg_errors, 'Average of all wrongly\nclassified 7s'),
           (avg_correct,'Average of all correctly classified 7s\n("The Platonic 7")') )

for panel,(avg_img,panel_title) in zip(ax,panels):
    panel.imshow(avg_img.numpy().reshape(28,28),cmap='gray')
    panel.set_title(panel_title)
    panel.axis('off')

plt.savefig('figure10_mnist_classify_digits_extra1.png')

plt.show()

files.download('figure10_mnist_classify_digits_extra1.png')


In [None]:
# %% Exercise 2
#    Repeat #1 for all numbers to produce a 2x10 matrix of images with corrects on top
#    and errors on the bottom.

# Same as above, but loop over all the numbers; some of the errors are surprising
# but some averages are indeed pretty unrecognisable

# Get correct and wrong numbers and average
X,y   = next(iter(test_loader))
preds = ANN(X).detach()

pred_y    = torch.argmax(preds,dim=1)
pred_y_np = pred_y.numpy()
true_y_np = y.numpy()

nums = np.arange(10)

# One 28x28 average image per digit; stays all-zero (black) for an empty group
avg_errors_img  = np.zeros((28, 28, 10))
avg_correct_img = np.zeros((28, 28, 10))

for i, num in enumerate(nums):

    # Samples whose true label is `num`, split by whether the model agreed
    true_indices = np.flatnonzero(true_y_np == num)
    hit_mask     = pred_y_np[true_indices] == num

    correct = true_indices[hit_mask].tolist()
    errors  = true_indices[~hit_mask].tolist()

    # Averaging an empty selection would give NaN, so empty groups are skipped
    if errors:
        avg_errors = torch.mean(X[errors],dim=0)
        avg_errors_img[:,:,i] = avg_errors.numpy().reshape(28,28)

    if correct:
        avg_correct = torch.mean(X[correct],dim=0)
        avg_correct_img[:,:,i] = avg_correct.numpy().reshape(28,28)

# Plot (corrects on the top row and errors on the bottom row, as the exercise asks)
phi = ( 1 + np.sqrt(5) ) / 2
fig,ax = plt.subplots(2,10,figsize=(2*8*phi,8))

for i in range(10):

    ax[0,i].imshow(avg_correct_img[:,:,i],cmap='gray')
    ax[0,i].set_title(f'Correct\n{i}s')
    ax[0,i].axis('off')

    ax[1,i].imshow(avg_errors_img[:,:,i],cmap='gray')
    ax[1,i].set_title(f'Wrong\n{i}s')
    ax[1,i].axis('off')

plt.savefig('figure11_mnist_classify_digits_extra2.png')

plt.show()

files.download('figure11_mnist_classify_digits_extra2.png')


In [None]:
# %% Exercise 3
#    Identify "almost errors," which we can define as correct categorizations that had a probability of
#    e.g., >.1 for any other number. Make images of some of these numbers. Can you understand why the model
#    was confused?

# To a human eye it's quite obvious which one is the correct classification, but
# sometimes one can see why the model couldn't get it right

# Turn the log-softmax output back into probabilities, flag the misclassified
# digits, compute the margin between the top-2 predicted probabilities, and keep
# the misclassified digits whose margin is small ("nearly right" errors).
# NOTE: the exercise text defines "almost errors" as *correct* categorizations
# with a probability >.1 for another number; this cell looks at misclassified
# samples with a low top-2 margin instead
X,y    = next(iter(test_loader))
preds  = ANN(X).detach()
probs  = torch.exp(preds)

pred_labels = torch.argmax(probs,dim=1)
true_labels = y

misclassified = (pred_labels != true_labels)
top2          = torch.topk(probs,2,dim=1)
margin        = top2.values[:,0] - top2.values[:,1]

margin_threshold    = 0.1
almost_right_margin = misclassified & (margin <= margin_threshold)

print(f"Total misclassified digits: {misclassified.sum().item()}")
print(f"Digits misclassified with a low-margin error (≤ {margin_threshold}): {almost_right_margin.sum().item()}")

# Plot
X_almost     = X[almost_right_margin]
y_almost     = y[almost_right_margin]
preds_almost = pred_labels[almost_right_margin]

indices = torch.where(almost_right_margin)[0]

if len(indices) >= 2:

    # Draw two *different* samples; without replace=False the same image could
    # be picked twice
    selected = np.random.choice(indices.tolist(),2,replace=False)

    phi = ( 1 + np.sqrt(5) ) / 2
    fig,ax = plt.subplots(1,2,figsize=(6*phi,6))

    for i,img2show in enumerate(selected):

        img      = X[img2show].reshape(28,28).squeeze().numpy()
        true_lbl = y[img2show].item()
        pred_lbl = pred_labels[img2show].item()
        m        = margin[img2show].item()

        ax[i].imshow(img, cmap='gray')
        ax[i].set_title(f"True: {true_lbl}, Pred: {pred_lbl}\nMargin: {m:.3f}")
        ax[i].axis('off')

    plt.savefig('figure12_mnist_classify_digits_extra3.png')

    plt.show()

    files.download('figure12_mnist_classify_digits_extra3.png')

else:

    print("Not enough almost-right misclassified samples in this batch.")


In [None]:
# %% Exercise 4
#    I didn't use .train(), .eval(), or no_grad() here. Is that a problem? Can you add those in without checking
#    other notebooks?

# It shouldn't be a problem here because we are not using regularisation methods
# such as dropout or batch normalisation, but one could place the .train()
# and .eval() switches before training starts and before evaluation starts,
# respectively; no_grad() could also be used to disable gradient tracking
# during the test phase and save memory


# Modified function to train the model
def train_model(num_epochs=60):
    """Train a freshly generated model, evaluating in eval mode without gradients.

    Exercise 4 version of the training function: the test pass runs under
    `ANN.eval()` and `torch.no_grad()`, and the model is switched back to
    training mode afterwards. Relies on the module-level `train_loader` and
    `test_loader`, and on `gen_model()`.

    Parameters
    ----------
    num_epochs : int, optional
        Number of passes over `train_loader` (default 60).

    Returns
    -------
    train_acc : list of float
        Mean mini-batch training accuracy (%) for each epoch.
    test_acc : list of float
        Accuracy (%) on the batch served by `test_loader` after each epoch.
    losses : list of float
        Mean mini-batch training loss for each epoch.
    ANN : nn.Module
        The trained network (left in training mode).
    """

    # Model instance and initialise vars
    ANN,loss_fun,optimizer = gen_model()

    losses    = []
    train_acc = []
    test_acc  = []

    # Loop over epochs
    for epoch_i in range(num_epochs):

        # Loop over training batches
        batch_acc  = []
        batch_loss = []

        for X,y in train_loader:

            # Forward propagation and loss
            yHat = ANN(X)
            loss = loss_fun(yHat,y)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Loss and accuracy from this batch, stored as plain Python floats
            # (.item()) so that the averages below do not depend on NumPy
            # implicitly converting 0-d tensors
            batch_loss.append(loss.item())

            matches  = (torch.argmax(yHat,dim=1) == y).float()
            accuracy = 100 * torch.mean(matches).item()
            batch_acc.append(accuracy)

        losses.append( np.mean(batch_loss) )
        train_acc.append( np.mean(batch_acc) )

        # Test accuracy: evaluation mode, no gradient tracking
        ANN.eval()

        with torch.no_grad():
            X,y = next(iter(test_loader))
            yHat = ANN(X)
            test_acc.append( 100*torch.mean((torch.argmax(yHat,dim=1)==y).float()).item() )

        # Back to training mode for the next epoch
        ANN.train()

    return train_acc,test_acc,losses,ANN
