In [None]:
# %% Deep learning - Section 11.115
#    Code challenge 18: the mystery of the missing 7

#    1) Train the MNIST-FFN on all the numbers except the 7s
#    2) Test the model on the 7s and see how it behaves
#       > Plot final accuracy for 7s
#       > Plot some examples of misclassification
#       > Plot the proportion of 7s labelled as other digits
#    3) Bonus: try dropout regularisation

# This code pertains to a deep learning course provided by Mike X. Cohen on Udemy:
#   > https://www.udemy.com/course/deeplearning_x
# The "base" code in this repository is adapted (with very minor modifications)
# from code developed by the course instructor (Mike X. Cohen), while the
# "exercises" and the "code challenges" contain more original solutions and
# creative input from my side. If you are interested in DL (and if you are
# reading this statement, chances are that you are), go check out the course, it
# is singularly good.


In [None]:
# %% Libraries and modules
import numpy               as np
import matplotlib.pyplot   as plt
import torch
import torch.nn            as nn
import seaborn             as sns
import copy
import torch.nn.functional as F
import pandas              as pd
import scipy.stats         as stats
import time

from torch.utils.data                 import DataLoader,TensorDataset
from sklearn.model_selection          import train_test_split
from google.colab                     import files
from torchsummary                     import summary
from IPython                          import display
from matplotlib_inline.backend_inline import set_matplotlib_formats
set_matplotlib_formats('svg')


In [None]:
# %% Data

# Load data; given a path, np.loadtxt opens and closes the file itself, so no
# file handle is left dangling
data = np.loadtxt('sample_data/mnist_train_small.csv',delimiter=',')

# Split labels (first column) from pixel data (remaining columns)
labels = data[:,0]
data   = data[:,1:]

# Normalise data by its maximum value (original range is (0,255))
data_norm = data / np.max(data)


In [None]:
# %% Create train and test datasets
#    Here train = all but 7s, and test = 7s

# Digit withheld from training and used as the only test class. To explore
# other digits (e.g. 8 for Exercise 2), change this value and re-run the
# notebook from this cell onwards.
held_out_digit = 7

# Convert the normalised data to a float tensor and the labels to an integer tensor
data_tensor   = torch.tensor(data_norm).float()
labels_tensor = torch.tensor(labels).long()

# Split data in held-out digit (test) and all remaining digits (train)
is_held_out = labels_tensor == held_out_digit
train_images,train_labels = data_tensor[~is_held_out], labels_tensor[~is_held_out]
test_images,test_labels   = data_tensor[is_held_out],  labels_tensor[is_held_out]

# Convert to PyTorch datasets
train_data = TensorDataset(train_images,train_labels)
test_data  = TensorDataset(test_images,test_labels)

# Convert into DataLoader objects (the test loader serves the whole test set
# as a single batch)
batch_size   = 32
train_loader = DataLoader(train_data,batch_size=batch_size,shuffle=True,drop_last=True)
test_loader  = DataLoader(test_data,batch_size=test_data.tensors[0].shape[0])


In [None]:
# %% Function to generate the model

def gen_model(drop_rate):
    """Build an untrained MNIST feed-forward classifier with its loss and optimizer.

    Parameters
    ----------
    drop_rate : float
        Dropout probability applied after each of the three hidden activations.

    Returns
    -------
    tuple
        (model instance, nn.CrossEntropyLoss instance, SGD optimizer with lr=0.01
        bound to the model parameters).
    """

    class mnist_FFN(nn.Module):
        """Fully connected 784-64-32-32-10 network with ReLU and dropout."""

        def __init__(self,dropout_rate):
            super().__init__()

            # Dropout probability used in forward()
            self.dr = dropout_rate

            # Architecture (layers are created in input-to-output order)
            self.input  = nn.Linear(784,64)
            self.fc1    = nn.Linear( 64,32)
            self.fc2    = nn.Linear( 32,32)
            self.output = nn.Linear( 32,10)

        def forward(self,x):
            """Return the raw output-layer activations for a batch of inputs."""

            # Every hidden layer follows the same recipe: linear -> ReLU -> dropout.
            # Dropout is only active while the module is in training mode.
            for hidden in (self.input,self.fc1,self.fc2):
                x = F.dropout(F.relu(hidden(x)),p=self.dr,training=self.training)

            return self.output(x)

    # Model instance
    net = mnist_FFN(drop_rate)

    # Loss function
    criterion = nn.CrossEntropyLoss()

    # Plain SGD (chosen to slow down learning for illustration purposes)
    sgd = torch.optim.SGD(net.parameters(),lr=0.01)

    return net,criterion,sgd


In [None]:
# %% Function to train the model

def train_model(drop_rate,num_epochs=60):
    """Train a fresh model and track per-epoch loss and accuracy.

    A new model, loss function and optimizer are obtained from
    gen_model(drop_rate). Training iterates over the module-level train_loader;
    after every epoch the model is switched to evaluation mode and scored on
    the first batch served by the module-level test_loader.

    Parameters
    ----------
    drop_rate : float
        Dropout probability forwarded to gen_model.
    num_epochs : int, optional
        Number of passes over train_loader (default 60).

    Returns
    -------
    train_acc : list of float
        Per-epoch mean of the batch accuracies (%).
    test_acc : list of float
        Per-epoch accuracy (%) on the test batch.
    losses_trn : list of float
        Per-epoch mean of the batch training losses.
    losses_tst : list of float
        Per-epoch loss on the test batch.
    ANN : nn.Module
        The trained model, left in training mode.
    """

    # Model instance and initialise bookkeeping lists
    ANN,loss_fun,optimizer = gen_model(drop_rate)

    losses_trn = []
    losses_tst = []
    train_acc  = []
    test_acc   = []

    # Loop over epochs
    for epoch_i in range(num_epochs):

        # Loop over training batches
        batch_acc  = []
        batch_loss = []

        for X,y in train_loader:

            # Forward propagation and loss
            yHat = ANN(X)
            loss = loss_fun(yHat,y)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Loss and accuracy from this batch, stored as plain Python floats
            # so the returned lists never hold tensors
            batch_loss.append(loss.item())

            matches  = (torch.argmax(yHat,dim=1) == y).float()
            accuracy = (100 * torch.mean(matches)).item()
            batch_acc.append(accuracy)

        losses_trn.append( np.mean(batch_loss) )
        train_acc.append( np.mean(batch_acc) )

        # Test loss and accuracy (evaluation mode disables dropout; no
        # gradients are needed for scoring)
        ANN.eval()

        with torch.no_grad():
            X,y  = next(iter(test_loader))
            yHat = ANN(X)
            loss = loss_fun(yHat,y)

        test_matches = (torch.argmax(yHat,dim=1) == y).float()
        test_acc.append( (100 * torch.mean(test_matches)).item() )
        losses_tst.append( loss.item() )

        # Back to training mode for the next epoch
        ANN.train()

    return train_acc,test_acc,losses_trn,losses_tst,ANN


In [None]:
# %% Fit the model

# Seed the global torch and NumPy generators before the first random
# operations of the notebook (weight initialisation, batch shuffling, and the
# np.random.choice calls in the plotting cells), so a restart-and-run-all
# reproduces the same figures
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

# Baseline run without dropout (raise drop_rate to try the dropout bonus)
drop_rate = 0
train_acc,test_acc,losses_trn,losses_tst,ANN = train_model(drop_rate)


In [None]:
# %% Plotting

# Two side-by-side panels; the figure width is scaled by the golden ratio
phi = (1 + np.sqrt(5)) / 2
fig,ax = plt.subplots(1,2,figsize=(1.5*6*phi,6))

# Left panel: loss per epoch for the training digits and the held-out digit
ax[0].plot(losses_trn,label='Non-7s loss')
ax[0].plot(losses_tst,label='7s loss')
ax[0].set(xlabel='Epochs',ylabel='Loss',ylim=[-2.5,20],title='Model loss')
ax[0].legend()

# Right panel: accuracy per epoch, with the final held-out accuracy in the title
ax[1].plot(train_acc,label='Non-7s accuracy')
ax[1].plot(test_acc,label='7s accuracy')
ax[1].set(
    xlabel='Epochs',
    ylabel='Accuracy (%)',
    ylim=[-10,100],
    title=f'Final model test accuracy on 7s: {test_acc[-1]:.2f}%',
)
ax[1].legend()

# Write the figure to disk, render it, then hand the file to google.colab's downloader
fig.savefig('figure72_code_challenge_18.png')
plt.show()
files.download('figure72_code_challenge_18.png')


In [None]:
# %% Plotting

# Model in evaluation mode, fetch the full test set, and predict without gradients
ANN.eval()
X7s, y7s = test_data.tensors
with torch.no_grad():
    predictions = ANN(X7s).argmax(dim=1)

# 3x4 grid of example images
phi = (1 + np.sqrt(5)) / 2
fig,axs = plt.subplots(3,4,figsize=(6*phi,6))

# Twelve distinct random positions in the test set, one per panel
rand_indices = np.random.choice(X7s.shape[0],size=12,replace=False)

for ax, idx in zip(axs.ravel(),rand_indices):
    # Each sample is a flat 784-vector; reshape to 28x28 for display
    img        = X7s[idx].reshape(28,28)
    pred_label = predictions[idx].item()

    ax.imshow(img,cmap='gray')
    ax.set_title(f'Predicted digit: {pred_label}')
    ax.set_axis_off()

# Leave the top 5% of the canvas free for the super-title
fig.suptitle('Model predictions on 7s',fontsize=14)
fig.tight_layout(rect=[0,0,1,0.95])

fig.savefig('figure73_code_challenge_18.png')
plt.show()
files.download('figure73_code_challenge_18.png')


In [None]:
# %% Plotting

# Tally how often each label 0-9 was predicted, then normalise by the number
# of predictions to obtain proportions
pred_counts      = predictions.bincount(minlength=10).numpy()
pred_proportions = pred_counts / len(predictions)

# Bar chart, one bar per possible label
phi = (1 + np.sqrt(5)) / 2
plt.figure(figsize=(6*phi,6))

# zorder=3 keeps the bars in front of the horizontal grid lines
plt.grid(axis='y',linestyle='--',alpha=0.7)
plt.bar(np.arange(10),pred_proportions,zorder=3)
plt.xticks(np.arange(10))
plt.title('Misclassification distribution of digit 7')
plt.xlabel('Predicted label')
plt.ylabel('Proportion of 7s classified as other digits')

plt.savefig('figure74_code_challenge_18.png')
plt.show()
files.download('figure74_code_challenge_18.png')


In [None]:
# %% Exercise 1
#    In the image matrix we created above, we picked random 7's and showed their labels. Create another image
#    matrix to show all of the times that a 7 was labeled "0". How do those 7's look? You can re-run this for
#    the other numbers.

# Sometimes the misclassification is nearly "understandable", in some other
# cases it just feels quite random

# Positions of the test samples that the model labelled as 0; show at most 12
# of them, drawn at random without repetition
misclassified_as_0 = torch.where(predictions == 0)[0]
n_examples         = min(12, len(misclassified_as_0))
chosen_indices     = np.random.choice(misclassified_as_0.numpy(),size=n_examples,replace=False)

# Plotting
phi = (1 + np.sqrt(5)) / 2
fig,axs = plt.subplots(3,4,figsize=(6*phi,6))

for ax, idx in zip(axs.ravel(),chosen_indices):
    # Flat 784-vector back to a 28x28 image
    img        = X7s[idx].reshape(28,28)
    pred_label = predictions[idx].item()

    ax.imshow(img,cmap='gray')
    ax.set_title(f'Predicted: {pred_label}')
    ax.set_axis_off()

# Panels beyond the available examples stay empty, so hide their frames too
for ax in axs.ravel()[n_examples:]:
    ax.set_axis_off()

fig.suptitle('7s Misclassified as 0',fontsize=14)
fig.tight_layout(rect=[0,0,1,0.95])

fig.savefig('figure75_code_challenge_18_extra1.png')
plt.show()
files.download('figure75_code_challenge_18_extra1.png')


In [None]:
# %% Exercise 2
#    It's not surprising that most of the 7's were labeled as "9". You can now repeat this code file with other numbers
#    left out. What other pair of numbers do you expect to be commonly misclassified based on typographical similarity?

# Let's try to leave out 8, it might be similar to 6. But surprisingly no, most
# of the 8s are classified as either 3s or 5s.

# NOTE: this cell only re-plots the results currently in memory (losses_trn,
# losses_tst, train_acc, test_acc); the "8s" labels are accurate only when the
# notebook has been re-run with the digit 8 left out of the training set.

# Plot: two side-by-side panels, width scaled by the golden ratio
phi = (1 + np.sqrt(5)) / 2
fig,ax = plt.subplots(1,2,figsize=(1.5*6*phi,6))

# Left panel: loss per epoch
ax[0].plot(losses_trn,label='Non-8s loss')
ax[0].plot(losses_tst,label='8s loss')
ax[0].set(xlabel='Epochs',ylabel='Loss',ylim=[-2.5,20],title='Model loss')
ax[0].legend()

# Right panel: accuracy per epoch, final held-out accuracy in the title
ax[1].plot(train_acc,label='Non-8s accuracy')
ax[1].plot(test_acc,label='8s accuracy')
ax[1].set(
    xlabel='Epochs',
    ylabel='Accuracy (%)',
    ylim=[-10,100],
    title=f'Final model test accuracy on 8s: {test_acc[-1]:.2f}%',
)
ax[1].legend()

fig.savefig('figure76_code_challenge_18_extra2.png')
plt.show()
files.download('figure76_code_challenge_18_extra2.png')

# Plot: random held-out samples with the label the model assigned to them

# Evaluation mode and gradient-free prediction on the full test set (the
# variable names still say "7s", but they hold whatever digit is in test_data)
ANN.eval()
X7s, y7s = test_data.tensors
with torch.no_grad():
    predictions = ANN(X7s).argmax(dim=1)

phi = (1 + np.sqrt(5)) / 2
fig,axs = plt.subplots(3,4,figsize=(6*phi,6))

# Twelve distinct random positions in the test set, one per panel
rand_indices = np.random.choice(X7s.shape[0],size=12,replace=False)

for ax, idx in zip(axs.ravel(),rand_indices):
    # Flat 784-vector back to a 28x28 image
    img        = X7s[idx].reshape(28,28)
    pred_label = predictions[idx].item()

    ax.imshow(img,cmap='gray')
    ax.set_title(f'Predicted digit: {pred_label}')
    ax.set_axis_off()

# Leave the top 5% of the canvas free for the super-title
fig.suptitle('Model predictions on 8s',fontsize=14)
fig.tight_layout(rect=[0,0,1,0.95])

fig.savefig('figure77_code_challenge_18_extra2.png')
plt.show()
files.download('figure77_code_challenge_18_extra2.png')

# Plot: share of held-out samples assigned to each label 0-9
pred_counts      = predictions.bincount(minlength=10).numpy()
pred_proportions = pred_counts / len(predictions)

phi = (1 + np.sqrt(5)) / 2
plt.figure(figsize=(6*phi,6))

# zorder=3 keeps the bars in front of the horizontal grid lines
plt.grid(axis='y',linestyle='--',alpha=0.7)
plt.bar(np.arange(10),pred_proportions,zorder=3)
plt.xticks(np.arange(10))
plt.title('Misclassification distribution of digit 8')
plt.xlabel('Predicted label')
plt.ylabel('Proportion of 8s classified as other digits')

plt.savefig('figure78_code_challenge_18_extra2.png')
plt.show()
files.download('figure78_code_challenge_18_extra2.png')


In [None]:
# %% Exercise 3
#    Add dropout regularization to fc1 and fc2 (what else do you need to modify in the code to make sure the dropout
#    is applied only during training?). Does that affect how the model categorizes 7's?

# Easy fix adding the appropriate lines as in previous code challenges, and of
# course .eval() and .train() modes. That said, even with a massive 0.5 dropout
# rate the model is still totally failing
