In [None]:
# %% Deep learning - Section 7.52
#    Multi-output ANN (iris dataset)

# This code pertains to a deep learning course provided by Mike X. Cohen on Udemy:
#   > https://www.udemy.com/course/deeplearning_x
# The "base" code in this repository is adapted (with very minor modifications)
# from code developed by the course instructor (Mike X. Cohen), while the
# "exercises" and the "code challenges" contain more original solutions and
# creative input from my side. If you are interested in DL (and if you are
# reading this statement, chances are that you are), go check out the course, it
# is singularly good.


In [None]:
# %% Libraries and modules
import numpy             as np
import matplotlib.pyplot as plt
import torch
import torch.nn          as nn
import seaborn           as sns
import copy

from google.colab                     import files
from matplotlib_inline.backend_inline import set_matplotlib_formats
set_matplotlib_formats('svg')


In [None]:
# %% Import Iris dataset (it comes with seaborn)
#    sns.load_dataset returns the example dataset as a pandas dataframe

iris = sns.load_dataset(name='iris')

# Peek at the first rows of the dataframe
iris.head(n=5)


In [None]:
# %% Plotting
#    Pairwise relationships between the dataframe columns, coloured by species

pairplot_png = 'figure56_multioutput_ann_iris.png'

sns.pairplot(data=iris,hue='species')
plt.savefig(pairplot_png)
plt.show()

# Colab-only: push the saved figure to the browser
files.download(pairplot_png)


In [None]:
# %% Organise data

# Features: first four dataframe columns, converted from pandas to a float tensor
data = torch.tensor( iris[iris.columns[0:4]].values ).float()

# Labels: one integer class index per sample (the target format passed to
# nn.CrossEntropyLoss below), NOT one-hot "dummy" variables.
# Initialise with -1 so that a species name missing from species_names is
# caught by the assertion instead of being silently counted as class 0 (setosa)
species_names = ['setosa','versicolor','virginica']
labels = torch.full( (len(data),), -1, dtype=torch.long )
for class_idx,name in enumerate(species_names):
    labels[iris.species==name] = class_idx
assert (labels >= 0).all(), 'found a species name outside species_names'

labels


In [None]:
# %% Build the model

# Fix the RNG before the layers are created so the weight initialisation (and
# therefore every result below) is reproducible on Restart & Run All
torch.manual_seed(42)

# Architecture: 4 inputs -> 64 -> 64 -> 3 outputs (one raw score per class)
ANNiris = nn.Sequential(
             nn.Linear(4,64),   # input layer
             nn.ReLU(),         # a.f.
             nn.Linear(64,64),  # hidden layer
             nn.ReLU(),         # a.f.
             nn.Linear(64,3)    # output layer
             )

# Loss function (includes [Log]Softmax, so the model outputs raw scores)
loss_fun = nn.CrossEntropyLoss()

# Optimizer
optimizer = torch.optim.SGD(ANNiris.parameters(),lr=0.01)


In [None]:
# %% Train the model
#    Full-batch gradient descent: every epoch uses all samples in `data`

num_epochs  = 1000
losses      = torch.zeros(num_epochs)
ongoing_acc = []

for epoch_i in range(num_epochs):

    # Forward propagation
    yHat = ANNiris(data)

    # Loss
    loss = loss_fun(yHat,labels)
    # Store a plain float: assigning the graph-attached `loss` tensor would make
    # `losses` itself part of the autograd graph and keep every epoch's graph alive
    losses[epoch_i] = loss.item()

    # Backpropagation
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Accuracy of this epoch's (pre-update) outputs; argmax picks the index of
    # the highest value in each row of yHat
    matches     = torch.argmax(yHat,axis=1) == labels  # booleans
    matches_num = matches.float()                      # booleans2numeric
    acc_perc    = 100*torch.mean(matches_num)          # percent
    ongoing_acc.append(acc_perc.item())                # plain float for plotting

# Final forward pass with the trained weights (no gradients needed)
with torch.no_grad():
    predictions = ANNiris(data)

pred_labels = torch.argmax(predictions,axis=1)
tot_acc     = 100*torch.mean((pred_labels == labels).float())

print(f'Final accuracy = {tot_acc.item():.4f}%')


In [None]:
# %% Plotting
#    Training curves: loss (left panel) and accuracy in % (right panel) per epoch

fig,ax = plt.subplots(1,2,figsize=(12,4))

curves = (
    (losses.detach(), 'Loss',     'Losses over epochs'),
    (ongoing_acc,     'Accuracy', 'Accuracy over epochs'),
)
for panel,(series,y_label,panel_title) in zip(ax,curves):
    panel.plot(series)
    panel.set_ylabel(y_label)
    panel.set_xlabel('Epoch')
    panel.set_title(panel_title)

curves_png = 'figure57_multioutput_ann_iris.png'
plt.savefig(curves_png)
plt.show()

# Colab-only: push the saved figure to the browser
files.download(curves_png)


In [None]:
# %% Confirm that predictions sum to 1 after softmax
#    Row sums of the raw outputs are arbitrary; after a softmax across dim 1
#    each row becomes a probability distribution, so its sum is 1

sm = nn.Softmax(dim=1)

raw_row_sums     = yHat.sum(axis=1)
softmax_row_sums = sm(yHat).sum(axis=1)

print(raw_row_sums)
print()
print(softmax_row_sums)


In [None]:
# %% Plotting
#    Softmax probability of each class (one line per class) for every sample

fig = plt.figure(figsize=(10,4))

class_probs = sm(yHat.detach())

plt.plot(class_probs,'s-',markerfacecolor='w')
plt.xlabel('Stimulus number')
plt.ylabel('Probability')
plt.legend(['Setosa','Versicolor','Virginica'])
plt.title('Classification probabilities')

probs_png = 'figure59_multioutput_ann_iris.png'
plt.savefig(probs_png)
plt.show()

# Colab-only: push the saved figure to the browser
files.download(probs_png)


In [None]:
# %% Plotting
#    Raw network outputs (logits, i.e. before softmax) for every sample. The last
#    layer is linear, so these values are unbounded and are NOT probabilities

fig = plt.figure(figsize=(10,4))

plt.plot(yHat.detach(),'s-',markerfacecolor='w')
plt.xlabel('Stimulus number')
plt.ylabel('Raw output (logit)')   # previously mislabelled as 'Probability'
plt.legend(['Setosa','Versicolor','Virginica'])
plt.title('Classification raw output')

plt.savefig('figure60_multioutput_ann_iris.png')

plt.show()

files.download('figure60_multioutput_ann_iris.png')


In [None]:
# %% Exercise 1
#    When the loss does not reach an asymptote, it's a good idea to train the model for more epochs. Increase the number of
#    epochs until the plot of the losses seems to hit a "floor" (that's a statistical term for being as small as possible).

# It is hard to say exactly when the loss hits the floor, but even with 2500
# epochs the accuracy does not seem to improve noticeably


In [None]:
# %% Exercise 2
#    We used a model with 64 hidden units. Modify the code to have 16 hidden units. How does this model perform? If there
#    is a decrease in accuracy, is that decrease distributed across all three iris types, or does the model learn some
#    iris types and not others?

# Epochs reset to 1000. The model seems to perform equally well with 16 hidden
# nodes or with 64 hidden nodes, at least from a visual inspection of accuracy
# and plots; iris setosa is well classified, while the other two species are
# a bit more difficult to disentangle. If anything, it may take a few more
# epochs to reach near-ceiling accuracy

# Architecture: same layout as the main model, but with a configurable hidden
# width. NOTE: this cell only builds the model; re-run a training cell to train it
n_hidden = 16   # units in each hidden layer (the main model uses 64)

torch.manual_seed(42)   # reproducible weight initialisation
ANNiris = nn.Sequential(
             nn.Linear(4,n_hidden),         # input layer
             nn.ReLU(),                     # a.f.
             nn.Linear(n_hidden,n_hidden),  # hidden layer
             nn.ReLU(),                     # a.f.
             nn.Linear(n_hidden,3)          # output layer
             )

# Loss function (includes [Log]Softmax)
loss_fun = nn.CrossEntropyLoss()

# Optimizer
optimizer = torch.optim.SGD(ANNiris.parameters(),lr=0.01)


In [None]:
# %% Exercise 3
#    Write code to compute three accuracy scores, one for each iris type. In real DL projects, category-specific accuracies
#    are often more informative than the aggregated accuracy.

# Model reset to 64 hidden nodes and 1000 epochs: the model and optimizer are
# rebuilt below, so this cell does not depend on whatever ANNiris was left in
# the kernel (after Exercise 2 that would be the 16-unit network). In the end,
# compute accuracy separately for each category to show whether there is
# variability; as hinted by the plots, iris versicolor is not always classified
# in agreement with the labels provided in the dataset

# Rebuild the 64-unit model, its loss and its optimizer
torch.manual_seed(42)   # reproducible weight initialisation
ANNiris = nn.Sequential(
             nn.Linear(4,64),   # input layer
             nn.ReLU(),         # a.f.
             nn.Linear(64,64),  # hidden layer
             nn.ReLU(),         # a.f.
             nn.Linear(64,3)    # output layer
             )
loss_fun  = nn.CrossEntropyLoss()                          # includes [Log]Softmax
optimizer = torch.optim.SGD(ANNiris.parameters(),lr=0.01)

# Train the model (full batch: every epoch uses all of `data`)
num_epochs  = 1000
losses      = torch.zeros(num_epochs)
ongoing_acc = []

for epoch_i in range(num_epochs):

    # Forward propagation and loss
    yHat = ANNiris(data)
    loss = loss_fun(yHat,labels)
    # Store a plain float so `losses` does not become part of the autograd graph
    losses[epoch_i] = loss.item()

    # Backpropagation
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Accuracy (%) of this epoch's pre-update outputs; argmax picks the
    # highest-scoring class in each row of yHat
    matches = torch.argmax(yHat,axis=1) == labels
    ongoing_acc.append( 100*torch.mean(matches.float()).item() )

# Final forward pass with the trained weights (no gradients needed)
with torch.no_grad():
    predictions = ANNiris(data)

# Overall accuracy
pred_labels = torch.argmax(predictions,axis=1)
tot_acc     = 100 * torch.mean((pred_labels == labels).float())

# Per-class accuracy: restrict to the samples whose true label is c (boolean
# mask) and measure the fraction of them that were predicted as c
num_classes = 3
class_acc   = torch.zeros(num_classes)
for c in range(num_classes):
    class_mask = ( labels == c )
    if class_mask.any():
        class_acc[c] = 100 * torch.mean( (pred_labels[class_mask] == c).float() )
    # a class with no samples is left at 0

print(f'Final overall accuracy = {tot_acc.item():.2f}%')
print(f'Final accuracy for iris setosa = {class_acc[0].item():.2f}%')
print(f'Final accuracy for iris versicolor = {class_acc[1].item():.2f}%')
print(f'Final accuracy for iris virginica = {class_acc[2].item():.2f}%')


In [None]:
# %% Plotting
#    Per-class accuracy (%) as a bar chart

plt.figure(figsize=(6,4))

plt.bar(['Setosa','Versicolor','Virginica'],class_acc,color=['blue','green','red'])
plt.xlabel('Class')
plt.ylabel('Accuracy (%)')
plt.title('Classification accuracy per category')
# Zoom in on the high-accuracy range, but lower the bottom limit when needed so
# no bar drops out of view (a fixed bottom of 75 would hide any class scoring
# below it); clamp at 0 because accuracy cannot be negative
y_bottom = max(0, min(75, class_acc.min().item() - 5))
plt.ylim(y_bottom,105)

plt.savefig('figure66_multioutput_ann_iris_extra3.png')

plt.show()

files.download('figure66_multioutput_ann_iris_extra3.png')
