# Neural Network Multi-class Classification

## **1 Introduction**

This notebook is my learning material for keeping track of the concepts covered in the [Advanced Learning Algorithms](https://www.coursera.org/learn/advanced-learning-algorithms?specialization=machine-learning-introduction) course from the [Machine Learning Specialization](https://www.coursera.org/specializations/machine-learning-introduction) offered by DeepLearning.AI and Stanford University.

Throughout this notebook, I use the [TEM virus dataset](https://data.mendeley.com/datasets/x4dwwfwtw3/3) created by Damian Matuszewski and
Ida-Maria Sintorn.

### **1.0.1 Imports**

In [None]:
import os
import wget
import shutil
import zipfile

# Data manipulation
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Image manipulation
from skimage.io import imread 
from skimage.transform import rescale

# Machine Learning
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten, Rescaling
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.metrics import SparseCategoricalAccuracy
from tensorflow.keras.optimizers import Adam 
from tensorflow.keras.regularizers import L2

# Options for pandas
pd.options.display.max_columns = 50
pd.options.display.max_rows = 30

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Options for seaborn
sns.set_style('darkgrid')
%matplotlib inline

from IPython import get_ipython
ipython = get_ipython()

# Autoreload extensions
if 'autoreload' not in ipython.extension_manager.loaded:
    %load_ext autoreload

### **1.1 Data**

#### **1.1.0.1 Download**

In [None]:
url = 'https://md-datasets-cache-zipfiles-prod.s3.eu-west-1.amazonaws.com/x4dwwfwtw3-3.zip'

# Download the archive only once. wget does not overwrite an existing file: it
# saves a numbered copy instead, so every re-run of this cell would fetch the
# whole dataset again and leave another duplicate on disk.
filename = os.path.basename(url)
if not os.path.exists(filename):
    filename = wget.download(url, out=filename)

#### **1.1.0.2 Reorganize**

In [None]:
data_root = 'TEM_virus'
extracted_256 = os.path.join(data_root, 'context_virus_1nm_256x256')

# Only unpack when the final layout is missing. Re-running this cell on an
# already reorganized tree would otherwise move 'augmented_train' *inside* the
# existing 'TEM_virus/train' folder and make shutil.move raise on the existing
# validation/test folders.
if not os.path.isdir(os.path.join(data_root, 'train')):
    with zipfile.ZipFile('x4dwwfwtw3-3.zip') as zf:
        zf.extractall(data_root)

    # Keep the 256x256 image data
    shutil.move(os.path.join(extracted_256, 'augmented_train'),
                os.path.join(data_root, 'train'))
    shutil.move(os.path.join(extracted_256, 'validation'), data_root)
    shutil.move(os.path.join(extracted_256, 'test'), data_root)

    shutil.rmtree(extracted_256)
    shutil.rmtree(os.path.join(data_root, 'context_virus_RAW'))

    # Remove excluded data from test and validation set
    shutil.rmtree(os.path.join(data_root, 'test', '_EXCLUDED'))
    shutil.rmtree(os.path.join(data_root, 'validation', '_EXCLUDED'))

#### **1.1.0.3 Import**

In [None]:
columns = ['Image', 'Pixels', 'Type']
root = 'TEM_virus/'

# Rows are collected in plain lists and turned into DataFrames once at the end:
# concatenating a one-row frame per image copies the whole frame every time.
rows_by_split = {'train': [], 'validation': [], 'test': []}

# os.listdir returns entries in arbitrary order; sorting keeps the row order
# (and everything derived from it) identical from one run to the next.
for d in sorted(os.listdir(root)):
    if d not in rows_by_split:
        continue
    d_path = os.path.join(root, d)

    # Each sub-directory of a split is named after the virus type it contains
    for sd in sorted(os.listdir(d_path)):
        sd_path = os.path.join(d_path, sd)

        for f in sorted(os.listdir(sd_path)):
            f_path = os.path.join(sd_path, f)

            # Delete crop_outlines folder
            if f == 'crop_outlines':
                shutil.rmtree(f_path)
                continue

            # One row per image: path, pixel array, virus type
            rows_by_split[d].append([f_path, imread(f_path), sd])

virus_train = pd.DataFrame(rows_by_split['train'], columns=columns)
virus_validation = pd.DataFrame(rows_by_split['validation'], columns=columns)
virus_test = pd.DataFrame(rows_by_split['test'], columns=columns)

#### **1.1.1 Exploratory Data Analysis**

In [None]:
# Print the column summary of every split, in train / validation / test order.
for split_frame in (virus_train, virus_validation, virus_test):
    split_frame.info()

In [None]:
# Preview the first rows of the training split.
virus_train.head()

In [None]:
fig, axes = plt.subplots(1, 3,
                         sharey=True,
                         figsize=(15, 5))

splits = (('Train', virus_train),
          ('Validation', virus_validation),
          ('Test', virus_test))

# One panel per split showing how many images each virus type has;
# the x tick labels are rotated so the type names do not overlap.
for ax, (title, frame) in zip(axes, splits):
    sns.countplot(data=frame, x='Type', ax=ax)
    ax.set_title(title)
    ax.tick_params(axis='x', rotation=90)

In [None]:
# Stack all images of the same virus type into a single array per type.
pixels_by_type = virus_train.groupby('Type')['Pixels']
virus_train_pixels_group = pixels_by_type.apply(np.stack)
virus_train_pixels_group

#### **1.1.2 Mean pixel values**

In [None]:
fig, axes = plt.subplots(2, 7,
                         sharey=True,
                         figsize=(25, 7))

# Shared rendering options: fixed 0-255 scale, no ticks, no colour bar.
heatmap_options = dict(vmin=0, vmax=255,
                       cmap='binary',
                       square=True,
                       xticklabels=False, yticklabels=False,
                       cbar=False)

# One panel per virus type showing the pixel-wise mean over its stacked images.
for ax, (virus_type, images) in zip(fig.axes, virus_train_pixels_group.items()):
    sns.heatmap(images.mean(axis=0), ax=ax, **heatmap_options)
    ax.set_title(virus_type)

fig.suptitle('Viruses mean pixel values', fontsize=16)

## **2 Classification**

### **2.1 Preprocessing**

#### **2.1.1 Image downscale**

In [None]:
RESIZE_FACTOR = 0.125


def downscale(img):
    """Shrink ``img`` by ``RESIZE_FACTOR`` with anti-aliasing and multiply the result by 255.

    NOTE(review): the multiplication presumably brings the rescaled output back
    to a 0-255 range -- confirm against the dtype of the loaded images.
    """
    shrunk = rescale(img, RESIZE_FACTOR, anti_aliasing=True)
    return shrunk * 255

In [None]:
# Store the downscaled version of every image next to the original, for each split.
for split_frame in (virus_train, virus_validation, virus_test):
    split_frame['Pixels downscaled'] = split_frame['Pixels'].apply(downscale)

virus_train[['Pixels', 'Pixels downscaled']]

In [None]:
# Stack the downscaled images of the same virus type into a single array per type.
downscaled_by_type = virus_train.groupby('Type')['Pixels downscaled']
virus_train_pixels_downscaled_group = downscaled_by_type.apply(np.stack)

fig, axes = plt.subplots(2, 7,
                         sharey=True,
                         figsize=(25, 7))

# Shared rendering options: fixed 0-255 scale, no ticks, no colour bar.
heatmap_options = dict(vmin=0, vmax=255,
                       cmap='binary',
                       square=True,
                       xticklabels=False, yticklabels=False,
                       cbar=False)

# One panel per virus type showing the pixel-wise mean over its stacked images.
for ax, (virus_type, images) in zip(fig.axes, virus_train_pixels_downscaled_group.items()):
    sns.heatmap(images.mean(axis=0), ax=ax, **heatmap_options)
    ax.set_title(virus_type)

fig.suptitle('Viruses mean downscaled pixel values', fontsize=16)

#### **2.1.2 Split data**

In [None]:
def _pixels_and_labels(frame):
    """Return the downscaled images of ``frame`` as one array, plus its 'Type' values."""
    pixels = np.array(frame['Pixels downscaled'].tolist())
    labels = frame['Type'].values
    return pixels, labels


X_train, y_train = _pixels_and_labels(virus_train)
X_validation, y_validation = _pixels_and_labels(virus_validation)
X_test, y_test = _pixels_and_labels(virus_test)

# Shapes of the feature and label arrays, one line per split.
for features, labels in ((X_train, y_train),
                         (X_validation, y_validation),
                         (X_test, y_test)):
    print(features.shape, labels.shape)

#### **2.1.3 Feature scaling and label encoding**

In [None]:
le = LabelEncoder()

# Learn the label -> integer mapping on the training labels only and reuse it
# for the other splits. Re-fitting on validation/test would silently shift the
# codes whenever a split misses a class, and would leave `le` describing the
# last split fitted instead of the classes the model is trained on.
# transform() raises on a label that was never seen during fitting.
# Reading from the 'Type' columns (rather than re-encoding y_*) keeps this cell
# re-runnable: `le` always maps integers back to the original type names.
y_train = le.fit_transform(virus_train['Type'].values)
y_validation = le.transform(virus_validation['Type'].values)
y_test = le.transform(virus_test['Type'].values)

### **2.2 Model**

#### **2.2.1 Building**

In [None]:
# Derive the layer sizes from the data instead of hard-coding 32x32 and 14, so
# the network keeps matching the inputs when RESIZE_FACTOR or the set of virus
# types changes.
input_shape = X_train.shape[1:]
n_classes = len(np.unique(y_train))

model = Sequential([
    # Scale the inputs by 1/255; only the first layer needs the input shape
    Rescaling(1./255, input_shape=input_shape),
    Flatten(),
    Dense(units=100, activation='relu'), # FIND RANDOMLY (original note; presumably: picked by trial, not tuned)
    # Linear output = one raw score (logit) per class
    Dense(units=n_classes, activation='linear'),
], name='virus_classification')

model.summary()

#### **2.2.2 Training and validation**

In [None]:
# from_logits=True: the loss receives raw scores and applies softmax itself
model.compile(loss=SparseCategoricalCrossentropy(from_logits=True),
              optimizer=Adam(1e-3),
              metrics=[SparseCategoricalAccuracy()])

# Seed TensorFlow as well as NumPy: np.random.seed alone does not touch
# TensorFlow's random generator, which the training loop relies on.
# NOTE(review): the layer weights were already initialised when the model was
# built, so fully repeatable runs also need the seed set before that point.
SEED = 100
np.random.seed(SEED)
tf.random.set_seed(SEED)

history = model.fit(X_train, y_train,
                    epochs=100,
                    validation_data=(X_validation, y_validation),
                    verbose=2)

#### **2.2.3 Test**

In [None]:
# Evaluate the trained model on the test split and report what evaluate() returns.
score = model.evaluate(x=X_test, y=y_test)

print('Test loss, Test acc: ', score)

## **3 Results**

### **3.1 History**

In [None]:
fig, axes = plt.subplots(2, 2,
                         sharex=True,
                         figsize=(10, 5))
palette = sns.color_palette()

# One panel per recorded metric, each drawn with the next colour of the palette.
for ax, (metric, values), colour in zip(fig.axes, history.history.items(), palette):
    sns.lineplot(x=history.epoch, y=values, color=colour, ax=ax)
    ax.set_title(metric)

fig.suptitle('Training history', fontsize=16)
fig.supxlabel('epoch')

### **3.2 Accuracy**

In [None]:
fig, axes = plt.subplots(8, 8,
                         figsize=(8, 8))

# Pick one random test image per panel and classify them all with a single
# predict call: predicting image by image runs the model 64 times and floods
# the cell output with progress bars. Indexing the batch directly also removes
# the hard-coded (1, 32, 32) reshape.
sample_indices = np.random.randint(len(X_test), size=axes.size)
sample_images = X_test[sample_indices]
sample_labels = y_test[sample_indices]
sample_predictions = model.predict(sample_images).argmax(axis=1)

for ax, img, y, y_hat in zip(axes.flat, sample_images, sample_labels, sample_predictions):
    # Green when the predicted class matches the label, red otherwise
    ax.imshow(img, cmap='Greens' if y_hat == y else 'Reds')

    ax.set_title(f'{y},{y_hat}')
    ax.set_axis_off()

fig.suptitle('Label, Prediction', fontsize=16)
plt.tight_layout()

In [None]:
# True and predicted virus type names for every test image
predictions = pd.DataFrame({
    'y_test': le.inverse_transform(y_test),
    'prediction': le.inverse_transform(model.predict(X_test).argmax(axis=1)),
})

# Count every (true type, predicted type) pair. size() is used because grouping
# by both columns leaves nothing for value_counts() to count, which makes that
# call depend on the installed pandas version.
accuracy = predictions.groupby(['y_test', 'prediction']).size().reset_index(name='count')
# Share of each true type's images that received a given prediction
accuracy['percentage'] = accuracy['count'] / accuracy.groupby('y_test')['count'].transform('sum')

fig, ax = plt.subplots(figsize=(8, 8))

# Rows: true type, columns: predicted type; pairs that never occur are shown as 0
sns.heatmap(data=pd.pivot_table(data=accuracy,
                                index='y_test', columns='prediction', values='percentage',
                                fill_value=0),
            square=True,
            cmap='binary',
            ax=ax)

fig.suptitle('Prediction distribution', fontsize=16)