# Inspect prediction results from a CNN

In [None]:
import os
import pandas as pd
import numpy as np
import seaborn as sn
import pickle
import glob
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, balanced_accuracy_score, precision_score, recall_score
import read_settings

Set input and output directories

In [None]:
# Read the global and CNN settings through the project's read_settings helpers
global_settings = read_settings.check_global()
instrument = global_settings['input_data']['instrument']
cnn_settings = read_settings.check_cnn()
use_weights = cnn_settings['data']['use_weights']

# Directory holding the input data for this instrument
data_dir = os.path.join('data', instrument)

# Training output directories are named cnn_<w|nw>_<instrument>_<suffix>;
# list every match, in lexicographic order
weight_tag = 'w' if use_weights else 'nw'
run_pattern = os.path.join('output', '_'.join(['cnn', weight_tag, instrument, '*']))
output_dirs = sorted(glob.glob(run_pattern))
output_dirs

Choose the output directory to inspect (by default, the last one in sorted order)

In [None]:
# Inspect the last directory in sorted order (presumably the most recent run if the
# directory suffix is a timestamp — confirm against the training script).
# Fail with an explicit message rather than a bare IndexError when nothing matched.
if not output_dirs:
    raise FileNotFoundError("No CNN output directory found under 'output' for the current settings")
output_dir = output_dirs[-1]

Take a glimpse at the settings

In [None]:
# Load the settings pickled alongside this run's outputs and display them
settings_path = os.path.join(output_dir, 'settings.pickle')
with open(settings_path, 'rb') as settings_file:
    settings = pickle.load(settings_file)
settings

## Input data

In [None]:
# Dataset composition table for this run, indexed by class
comp_path = os.path.join(output_dir, 'df_comp.csv')
df_comp = pd.read_csv(comp_path)
df_comp = df_comp.set_index('classif_id')

In [None]:
# Stacked bar chart of the dataset composition: one bar per class (index of df_comp),
# one stacked segment per column.
# DataFrame.plot.bar creates its own figure when given figsize and returns its Axes,
# so a separate plt.figure() call would only leave an empty figure behind.
ax = df_comp.plot.bar(stacked=True, figsize=(16, 8), fontsize=14)
ax.set_xlabel("Classes", fontsize=14)
ax.set_ylabel("Image number", fontsize=14)
ax.legend(loc="best")
ax.set_title("Dataset composition for CNN", fontsize=16)
plt.show()

## Training

Read training file

In [None]:
# Load the training history pickled for this run. It is assumed to be a Keras
# `History.history`-style dict of per-epoch lists (keys as read below); confirm
# against the training script.
with open(os.path.join(output_dir, 'train_results.pickle'), 'rb') as results_file:
    train_results = pickle.load(results_file)

# Per-epoch accuracy and loss on the training and validation sets
train_acc = train_results.get('accuracy')
val_acc = train_results.get('val_accuracy')

train_loss = train_results.get('loss')
val_loss = train_results.get('val_loss')

# Per-epoch learning rate. Older Keras versions log it as 'lr', newer ones as
# 'learning_rate'; None if it was not recorded under either key.
lr = train_results.get('lr', train_results.get('learning_rate'))

# Compute number of training epochs
epochs = len(train_acc)

Look for the best epoch, i.e. the one where the validation loss is smallest

In [None]:
# Index (0-based) of the epoch with the lowest validation loss
best_idx = np.argmin(val_loss)
# Epochs are numbered from 1, hence the shift
best_epoch = best_idx + 1
print(f'Best epoch is number {best_epoch}')

Plot training evolution

In [None]:
# Plot accuracy, loss and (when recorded) learning rate against the epoch number,
# with a dotted vertical line marking the best epoch.
epoch_numbers = list(range(1, epochs + 1))  # epochs are numbered from 1
# The learning-rate panel is skipped if no learning rate was recorded (lr is None)
n_panels = 3 if lr is not None else 2

plt.figure(figsize=(10, 5 * n_panels))
plt.subplot(n_panels, 1, 1)
plt.plot(epoch_numbers, train_acc, label='Training Accuracy')
plt.plot(epoch_numbers, val_acc, label='Validation Accuracy')
plt.axvline(best_epoch, color='k', ls='dotted', label='best epoch')
plt.legend(loc='best')
plt.ylabel('Accuracy')
# Keep the automatic lower bound but cap the axis at 1 (accuracy's maximum)
plt.ylim([min(plt.ylim()), 1])
plt.title('Training and Validation Accuracy')

plt.subplot(n_panels, 1, 2)
plt.plot(epoch_numbers, train_loss, label='Training Loss')
plt.plot(epoch_numbers, val_loss, label='Validation Loss')
plt.axvline(best_epoch, color='k', ls='dotted', label='best epoch')
plt.legend(loc='best')
plt.ylabel('Cross Entropy')
plt.title('Training and Validation Loss')

if lr is not None:
    plt.subplot(n_panels, 1, 3)
    plt.plot(epoch_numbers, lr)
    plt.ylabel('Learning rate')
    plt.title('Learning rate evolution')

# x label on the bottom panel only
plt.xlabel('epoch')
plt.suptitle('CNN training', fontsize=20)
plt.show()

## Testing

Read test file

In [None]:
# Load the pickled test outputs for this run; each value below is read with .get,
# so a missing key yields None
with open(os.path.join(output_dir, 'test_results.pickle'), 'rb') as results_file:
    test_results = pickle.load(results_file)

# Class lists, at original and regrouped ("_g") resolution
classes, classes_g, plankton_classes, plankton_classes_g = (
    test_results.get(key)
    for key in ('classes', 'classes_g', 'plankton_classes', 'plankton_classes_g')
)

# True and predicted labels of the test set, original then regrouped
true_classes, predicted_classes, true_classes_g, predicted_classes_g = (
    test_results.get(key)
    for key in ('true_classes', 'predicted_classes', 'true_classes_g', 'predicted_classes_g')
)

# Scores computed on the original classes
accuracy, balanced_accuracy, plankton_precision, plankton_recall = (
    test_results.get(key)
    for key in ('accuracy', 'balanced_accuracy', 'plankton_precision', 'plankton_recall')
)

# Same scores computed on the regrouped classes
accuracy_g, balanced_accuracy_g, plankton_precision_g, plankton_recall_g = (
    test_results.get(key)
    for key in ('accuracy_g', 'balanced_accuracy_g', 'plankton_precision_g', 'plankton_recall_g')
)

### Accuracy, precision and recall scores

In [None]:
# Report the stored test scores, one per line
print(
    f'Accuracy score is {accuracy}',
    f'Balanced accuracy score is {balanced_accuracy}',
    f'Weighted plankton precision score is {plankton_precision}',
    f'Weighted plankton recall is {plankton_recall}',
    sep='\n',
)

### Confusion matrix

Plot a confusion matrix

In [None]:
# Create confusion matrix, normalised over the true labels (each row divided by the
# number of test samples of that true class).
# labels=classes fixes the row/column order to that of `classes`, so the tick labels
# below always line up with the matrix, and keeps a row/column for classes that are
# absent from this test set. This assumes `classes` holds the same label values as
# true_classes / predicted_classes (as its use for the tick labels implies).
cm = confusion_matrix(true_classes, predicted_classes, labels=classes, normalize='true')

# Plot it
plt.figure(figsize=(20, 20))
plt.imshow(cm, cmap='Greys')
plt.colorbar()
plt.clim(0, 1)
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=90, fontsize=14)
plt.yticks(tick_marks, classes, fontsize=14)
plt.ylabel('True label', fontsize=14)
plt.xlabel('Predicted label', fontsize=14)
plt.title("Confusion matrix for CNN", fontsize=20)
plt.show()

### Classification report

Plot a classification report

In [None]:
# Create classification report as a dict: one entry per label present in the true or
# predicted classes (keyed by the label as a string), followed by the
# 'accuracy', 'macro avg' and 'weighted avg' summary entries
report = classification_report(true_classes, predicted_classes, output_dict=True)

# Convert report to dataframe (one row per entry above); drop 'support' (sample counts)
# so every remaining value lies in [0, 1] and fits the colour scale
df_report = pd.DataFrame(report).transpose().drop('support', axis=1)

# Plot figure. Row labels are taken from the report's own index, so they always match
# the rows, even when `classes` contains classes absent from the test labels and
# predictions or is ordered differently; yticklabels=True forces every label to be drawn.
plt.figure(figsize=(8, 15))
sn.heatmap(df_report, annot=True, vmin=0, vmax=1.0, yticklabels=True, cmap="Greys")
plt.title("Classification report for CNN", fontsize=16)
plt.show()

### Accuracy, precision and recall scores after regrouping classes

In [None]:
# Report the stored test scores computed after regrouping classes, one per line
print(
    f'Grouped accuracy score is {accuracy_g}',
    f'Grouped balanced accuracy score is {balanced_accuracy_g}',
    f'Grouped weighted plankton precision score is {plankton_precision_g}',
    f'Grouped weighted plankton recall is {plankton_recall_g}',
    sep='\n',
)

### Confusion matrix after regrouping classes

Plot a confusion matrix

In [None]:
# Create confusion matrix on the regrouped classes, normalised over the true labels
# (each row divided by the number of test samples of that true class).
# labels=classes_g fixes the row/column order to that of `classes_g`, so the tick
# labels below always line up with the matrix, and keeps a row/column for classes
# absent from this test set. This assumes `classes_g` holds the same label values as
# true_classes_g / predicted_classes_g (as its use for the tick labels implies).
cm_g = confusion_matrix(true_classes_g, predicted_classes_g, labels=classes_g, normalize='true')

# Plot it
plt.figure(figsize=(20, 20))
plt.imshow(cm_g, cmap='Greys')
plt.colorbar()
plt.clim(0, 1)
tick_marks = np.arange(len(classes_g))
plt.xticks(tick_marks, classes_g, rotation=90, fontsize=14)
plt.yticks(tick_marks, classes_g, fontsize=14)
plt.ylabel('True label', fontsize=14)
plt.xlabel('Predicted label', fontsize=14)
plt.title("Confusion matrix for CNN after grouping ecological classes", fontsize=20)
plt.show()

### Classification report after regrouping classes

Plot a classification report

In [None]:
# Create classification report on the regrouped classes as a dict: one entry per label
# present in the true or predicted classes (keyed by the label as a string), followed
# by the 'accuracy', 'macro avg' and 'weighted avg' summary entries
report = classification_report(true_classes_g, predicted_classes_g, output_dict=True)

# Convert report to dataframe (one row per entry above); drop 'support' (sample counts)
# so every remaining value lies in [0, 1] and fits the colour scale
df_report = pd.DataFrame(report).transpose().drop('support', axis=1)

# Plot figure. Row labels are taken from the report's own index, so they always match
# the rows, even when `classes_g` contains classes absent from the test labels and
# predictions or is ordered differently; yticklabels=True forces every label to be drawn.
plt.figure(figsize=(8, 15))
sn.heatmap(df_report, annot=True, vmin=0, vmax=1.0, yticklabels=True, cmap="Greys")
plt.title("Classification report for CNN after grouping ecological classes", fontsize=16)
plt.show()