# EEG Classification - Tensorflow
updated: Sep. 01, 2018

Data: https://www.physionet.org/pn4/eegmmidb/

## 1. Data Downloads

### Warning: Executing these blocks will automatically create directories and download datasets.

In [None]:
from keras import backend as K
# Ask the Keras TensorFlow backend which GPUs it can see.
# `_get_available_gpus` is a private helper (leading underscore).
K.tensorflow_backend._get_available_gpus()

In [None]:
# Tensorflow Style Guide
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# System
import os
import pathlib
import re
import urllib
from glob import glob
from math import ceil

import requests

# Essential Data Handling & Plotting
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Modeling & Preprocessing
# NOTE(review): `tf` / `tfe` are used by section 4.2 and `plt` by section 5;
# section 4.2 presumably also needs eager execution enabled at start-up.
import tensorflow as tf
import tensorflow.contrib.eager as tfe
from keras.layers import Conv2D, BatchNormalization, Activation, Flatten, Dense, Dropout, LSTM, Input, TimeDistributed
from keras import initializers, Model, optimizers, callbacks
from keras.utils.training_utils import multi_gpu_model
from keras import backend as K
from keras.models import load_model
from keras.callbacks import Callback
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score

# EEG package
from mne import pick_types
from mne.io import read_raw_edf

In [None]:
CONTEXT = 'pn4/'
MATERIAL = 'eegmmidb/'
URL = 'https://www.physionet.org/' + CONTEXT + MATERIAL

# Change this directory according to your setting
USERDIR = './data/'

# Fetch the dataset index page. Fail loudly on an HTTP error instead of
# scraping an error page, which would silently leave FOLDERS empty.
response = requests.get(URL)
response.raise_for_status()
page = response.text

# Subject folder names (S001, S002, ...) found on the index page,
# de-duplicated and sorted.
FOLDERS = sorted(set(re.findall(r'S[0-9]+', page)))

# One index URL per subject folder, in the same order as FOLDERS.
URLS = [URL+x+'/' for x in FOLDERS]

In [None]:
# Warning: Executing this block will create folders
# One directory per subject below USERDIR; existing directories are left alone.
for folder in FOLDERS:
    subject_dir = pathlib.Path(USERDIR +'/'+ folder)
    subject_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Warning: Executing this block will start downloading data
# `import urllib` alone does not guarantee that the `urllib.request`
# submodule is loaded, so import it explicitly.
import urllib.request

for i, (folder, folder_url) in enumerate(zip(FOLDERS, URLS)):
    page = requests.get(folder_url).text
    # Recording names (e.g. S001R01) listed on the subject's index page,
    # de-duplicated and sorted so the download order is reproducible.
    subs = sorted(set(re.findall(r'S[0-9]+R[0-9]+', page)))

    print('Working on {}, {:.1%} completed'.format(folder, (i+1)/len(FOLDERS)))
    for sub in subs:
        urllib.request.urlretrieve(folder_url+sub+'.edf', os.path.join(USERDIR, folder, sub+'.edf'))

## Data Description

Subjects performed different motor/imagery tasks while 64-channel EEG were recorded using the BCI2000 system (http://www.bci2000.org). Each subject performed 14 experimental runs: two one-minute baseline runs (one with eyes open, one with eyes closed), and three two-minute runs of each of the four following tasks:
A target appears on either the left or the right side of the screen. The subject opens and closes the corresponding fist until the target disappears. Then the subject relaxes.
A target appears on either the left or the right side of the screen. The subject imagines opening and closing the corresponding fist until the target disappears. Then the subject relaxes.
A target appears on either the top or the bottom of the screen. The subject opens and closes either both fists (if the target is on top) or both feet (if the target is on the bottom) until the target disappears. Then the subject relaxes.
A target appears on either the top or the bottom of the screen. The subject imagines opening and closing either both fists (if the target is on top) or both feet (if the target is on the bottom) until the target disappears. Then the subject relaxes.

The data are provided here in EDF+ format (containing 64 EEG signals, each sampled at 160 samples per second, and an annotation channel). For use with PhysioToolkit software, rdedfann generated a separate PhysioBank-compatible annotation file (with the suffix .event) for each recording. The .event files and the annotation channels in the corresponding .edf files contain identical data.

## 2. Raw Data Import

I will use an EEG data handling package named MNE (https://martinos.org/mne/stable/index.html) to import the raw data and the event annotations from the edf files. This package also provides essential signal analysis features, e.g. band-pass filtering. The raw data were filtered with a 1 Hz high-pass filter.

In this research, there are 5 classes for the data: imagined motion of right fist, left fist, both fists, both feet, and rest with eyes closed. The data from one of the 109 subjects were excluded because the recording was severely corrupted.

In [None]:
# Get file paths
# NOTE(review): the download cells save into USERDIR ('./data/') while this
# cell reads from './PhysioNet/' - point PATH at wherever the data really is.
PATH = './PhysioNet/'
SUBS = glob(PATH + 'S[0-9]*')
# Folder name of every subject (e.g. 'S001'), independent of the length of PATH
FNAMES = sorted(os.path.basename(os.path.normpath(x)) for x in SUBS)

# Remove subject #89 with damaged data (nothing to do if it is not present)
if 'S089' in FNAMES:
    FNAMES.remove('S089')

In [None]:
def get_data(subj_num=FNAMES, epoch_sec=0.0625):
    """Import EEG segments and their class labels from the edf files.

    Every selected run is high-pass filtered at 1 Hz and cut into windows of
    ``epoch_sec`` seconds that start every ``epoch_sec / 2`` seconds
    (50% overlap), beginning 0.5 s after the start of the run / event.

    Args:
        subj_num: sequence of subject folder names ('S001', 'S002', ...)
            found below ``PATH``.
        epoch_sec: length of one window in seconds.

    Returns:
        X: array of shape (segments, channels, time frames).
        y: array of shape (segments, 1) holding the class of each segment:
            0     - run 02,
            1 / 2 - events 'T1' / any other non-'T0' code of runs 04, 08, 12,
            3 / 4 - events 'T1' / any other non-'T0' code of runs 06, 10, 14.
            Events coded 'T0' (rest) are skipped.

    Raises:
        ValueError: if no segment could be extracted at all.

    Subjects whose recordings are not sampled at 160 Hz are skipped from the
    first such file on.
    """
    # Event codes mean different actions for the different groups of runs
    run_type_0 = ['02']
    run_type_1 = ['04', '08', '12']
    run_type_2 = ['06', '10', '14']
    wanted_runs = set(run_type_0 + run_type_1 + run_type_2)

    # (label for 'T1', label for the other task code) per task run
    task_labels = {}
    for run in run_type_1:
        task_labels[run] = (1, 2)
    for run in run_type_2:
        task_labels[run] = (3, 4)

    # Initiate X, y
    X = []
    y = []

    # To compute the completion rate
    count = len(subj_num)

    # fixed numbers
    nChan = 64
    sfreq = 160
    sliding = epoch_sec/2              # hop between consecutive windows
    timeFromQue = 0.5                  # seconds skipped after the cue
    n_samples = int(sfreq*epoch_sec)   # samples per window

    def cut_windows(n_segments, data, event_start=0):
        """Return the full-sized windows of `data` as a list of 2D arrays.

        `event_start` is the sample index the windows are counted from.
        """
        segments = []
        for n in range(n_segments):
            start = int(timeFromQue*sfreq) + int(sfreq*sliding*n) + event_start
            # The width is derived from epoch_sec; a hard-coded width only
            # works for the default and never terminates for other values.
            segment = data[:, start:start + n_samples]
            # Windows running past the end of the recording come out short
            if segment.shape == (nChan, n_samples):
                segments.append(segment)
        return segments

    # Iterate over subj_num: S001, S002, S003...
    for i, subj in enumerate(subj_num):
        # Return completion rate
        if i%((len(subj_num)//10)+1) == 0:
            print('working on {}, {:.0%} completed'.format(subj, i/count))

        # Get file names of the runs that are used (run number = name[-6:-4])
        fnames = glob(os.path.join(PATH, subj, subj+'R*.edf'))
        fnames = sorted(name for name in fnames if name[-6:-4] in wanted_runs)

        for fname in fnames:

            # Import data into MNE raw object
            raw = read_raw_edf(fname, preload=True, verbose=False)

            picks = pick_types(raw.info, eeg=True)

            if raw.info['sfreq'] != sfreq:
                print('{} is sampled at 128Hz so will be excluded.'.format(subj))
                break

            # High-pass filtering
            raw.filter(l_freq=1, h_freq=None, picks=picks)

            # Get annotation; a run without readable annotations is skipped,
            # but the reason is reported instead of being swallowed silently.
            try:
                events = raw.find_edf_events()
            except Exception as err:
                print('{}: annotations could not be read ({}), run skipped.'.format(fname, err))
                continue

            # Get data
            data = raw.get_data(picks=picks)

            # Number of this run
            which_run = fname[-6:-4]

            # run 2 - baseline, labelled 0 as a whole
            if which_run in run_type_0:

                # Number of sliding windows
                n_segments = int(raw.n_times/(epoch_sec*sfreq))

                segments = cut_windows(n_segments, data)
                X.extend(segments)
                y.extend([0] * len(segments))

            # runs 4,8,12 (fists) and 6,10,14 (both fists / both feet)
            else:
                t1_label, other_label = task_labels[which_run]

                for event in events:
                    onset, duration, code = event[0], event[1], event[2]

                    # Rest excluded
                    if code == 'T0':
                        continue

                    label = t1_label if code == 'T1' else other_label
                    n_segments = int(duration/epoch_sec)
                    segments = cut_windows(n_segments, data, ceil(onset * sfreq))

                    X.extend(segments)
                    y.extend([label] * len(segments))

    if not X:
        raise ValueError('No EEG segment could be extracted - check PATH and subj_num.')

    X = np.stack(X)
    y = np.array(y).reshape((-1,1))
    return X, y

In [None]:
# Import every remaining subject, cut into 0.0625 s windows
X, y = get_data(subj_num=FNAMES, epoch_sec=0.0625)

In [None]:
# Shapes of the data tensor and of the label column, in that order
for arr in (X, y):
    print(arr.shape)

## 3. Data Preprocessing

The original goal of applying neural networks is to avoid hand-crafted algorithms and preprocessing as much as possible. I did not use any preprocessing technique beyond standardization, in order to build an end-to-end classifier from the dataset.

In [None]:
import numpy as np
from sklearn.preprocessing import OneHotEncoder, scale

#%%
def convert_mesh(X):
    """Arrange the 64 channels of every time frame on a 10 x 11 grid.

    X is indexed as (trial, channel, time frame). The result has the shape
    (trial, time frame, 10, 11, 1); grid cells without a channel stay 0.
    A progress message is printed after each grid row.
    """
    n_trials, n_frames = X.shape[0], X.shape[2]
    mesh = np.zeros((n_trials, n_frames, 10, 11, 1))

    # Channels last, so that a channel selection lines up with grid columns
    X = np.swapaxes(X, 1, 2)

    # One entry per grid row, top to bottom:
    # (ordinal for the progress message, grid columns, channel indices)
    layout = [
        ('1st', slice(4, 7), list(range(21, 24))),
        ('2nd', slice(3, 8), list(range(24, 29))),
        ('3rd', slice(1, 10), list(range(29, 38))),
        ('4th', slice(1, 10), [38] + list(range(0, 7)) + [39]),
        ('5th', slice(0, 11), [42, 40] + list(range(7, 14)) + [41, 43]),
        ('6th', slice(1, 10), [44] + list(range(14, 21)) + [45]),
        ('7th', slice(1, 10), list(range(46, 55))),
        ('8th', slice(3, 8), list(range(55, 60))),
        ('9th', slice(4, 7), list(range(60, 63))),
        ('10th', slice(5, 6), [63]),
    ]

    for row, (ordinal, columns, channels) in enumerate(layout):
        mesh[:, :, row, columns, 0] = X[:, :, channels]
        print(ordinal + ' finished')

    return mesh

#%%
def prepare_data(X, y, test_ratio=0.2, return_mesh=True, set_seed=42):
    """One-hot encode the labels, shuffle, split and standardize the data.

    Args:
        X: data indexed as (trial, channel, time frame).
        y: label column of shape (trial, 1).
        test_ratio: share of the trials that goes into the test set.
        return_mesh: if True, both sets are passed through `convert_mesh`.
        set_seed: seed for numpy's global random generator, which is
            re-seeded here to make the shuffle reproducible.

    Returns:
        X_train, y_train, X_test, y_test (note the order).
    """
    # y encoding
    oh = OneHotEncoder()
    y = oh.fit_transform(y).toarray()

    # Shuffle trials (fancy indexing copies, so the caller's X stays untouched)
    np.random.seed(set_seed)
    trials = X.shape[0]
    shuffle_indices = np.random.permutation(trials)
    X = X[shuffle_indices]
    y = y[shuffle_indices]

    # Test set separation
    train_size = int(trials*(1-test_ratio))
    X_train, X_test = X[:train_size,:,:], X[train_size:,:,:]
    y_train, y_test = y[:train_size,:], y[train_size:,:]

    # Z-score Normalization
    def scale_data(X):
        """Standardize every trial in place with sklearn's `scale` defaults."""
        n_trials = X.shape[0]
        # Report roughly every 10%. At least 1, so that sets with fewer than
        # 10 trials do not cause a division by zero in the modulo below.
        report_every = max(1, n_trials//10)
        for i in range(n_trials):
            X[i,:, :] = scale(X[i,:, :])
            if i % report_every == 0:
                print('{:.0%} done'.format((i+1)/n_trials))
        return X

    X_train, X_test  = scale_data(X_train), scale_data(X_test)
    if return_mesh:
        X_train, X_test = convert_mesh(X_train), convert_mesh(X_test)

    return X_train, y_train, X_test, y_test
    
    

In [None]:
# Default settings: 20% test share, mesh conversion on, seed 42
splits = prepare_data(X, y)
X_train, y_train, X_test, y_test = splits

As the electrodes of the EEG recording instrument have 3D locations over the subject's scalp, it is essential for the model to learn from the spatial pattern as well as the temporal pattern. I transformed the data into 2D meshes that represent the locations of the electrodes so that stacked convolutional neural networks can grasp the spatial information.

## 4. Modeling - Time-Distributed CNN + RNN

Training Plan:

+ 4 GPU units (Nvidia Tesla P100) were used to train this neural network.
+ Instead of training the whole model at once, I trained the first block (CNN) first. Then, using the trained parameters as initial values, I trained the next blocks step by step. This approach can greatly reduce the time required for training and helps to avoid falling into local minima.
+ The first blocks (CNN) can be applied for other EEG classification models as a pre-trained base.

+ The initial learning rate is set to $10^{-4}$ with Adam optimization. I used several callbacks such as ReduceLROnPlateau, which reduces the learning rate when the validation loss stops improving. Also, I record the log for tensorboard to monitor the training process.

In [None]:
# Drop every singleton axis, then append exactly one trailing axis of size 1
X_train = np.expand_dims(X_train.squeeze(), axis=-1)
X_test = np.expand_dims(X_test.squeeze(), axis=-1)

In [None]:
# Make another dimension, 1, to apply CNN for each time frame.
# The model input is (time frame, 10, 11, 1) per trial, i.e. 5-D with the
# trial axis. The axis is only added while it is still missing, so running
# this cell on data that already has it does not produce a 6-D array.
if X_train.ndim == 4:
    X_train = X_train.reshape(*X_train.shape, 1)
if X_test.ndim == 4:
    X_test = X_test.reshape(*X_test.shape, 1)

### 4.1 Keras Implementation

The Keras functional API is the way to go for defining complex models, such as multi-output models, directed acyclic graphs, or models with shared layers.

In [None]:
## Complicated Model - the same as Zhang`s
# Per-trial input: 10 time frames of a 10 x 11 mesh with 1 channel
input_shape = (10, 10, 11, 1)
lecun = initializers.lecun_normal(seed=42)

# TimeDistributed Wrapper
def timeDist(layer, prev_layer, name):
    """Apply `layer` to every time frame of `prev_layer`."""
    return TimeDistributed(layer, name=name)(prev_layer)

# Input layer
inputs = Input(shape=input_shape)

# Convolutional layers block: (Conv2D per frame -> BatchNorm -> ELU) x 3
# NOTE: the intermediate tensors are deliberately not called `y`, because
# that would overwrite the label array returned by get_data.
conv_out = inputs
for block_no, n_filters in enumerate((32, 64, 128), start=1):
    conv_layer = Conv2D(n_filters, (3,3), padding='same',
                        data_format='channels_last', kernel_initializer=lecun)
    conv_out = timeDist(conv_layer, conv_out, name='CNN{}'.format(block_no))
    conv_out = BatchNormalization(name='batch{}'.format(block_no))(conv_out)
    conv_out = Activation('elu', name='act{}'.format(block_no))(conv_out)
conv_out = timeDist(Flatten(), conv_out, name='flatten')

# Fully connected layer block
fc_out = Dense(1024, kernel_initializer=lecun, name='FC')(conv_out)
fc_out = Dropout(0.5, name='dropout1')(fc_out)
fc_out = BatchNormalization(name='batch4')(fc_out)
fc_out = Activation(activation='elu')(fc_out)

# Recurrent layers block (only the last LSTM collapses the time axis)
rnn_out = LSTM(64, kernel_initializer=lecun, return_sequences=True, name='LSTM1')(fc_out)
rnn_out = LSTM(64, kernel_initializer=lecun, name='LSTM2')(rnn_out)

# Fully connected layer block
head_out = Dense(1024, kernel_initializer=lecun, activation='elu', name='FC2')(rnn_out)
head_out = Dropout(0.5, name='dropout2')(head_out)

# Output layer: one probability per class
outputs = Dense(5, activation='softmax')(head_out)

# Model compile
model = Model(inputs=inputs, outputs=outputs)
model.summary()

In [None]:
# Replace the model built above with the one stored in './model_1230.h5',
# the file name the ModelCheckpoint callback of the training cell writes to.
model = load_model('./model_1230.h5')

In [None]:
# NOTE: everything between the triple quotes below is a bare string literal,
# i.e. disabled code and notes kept for reference; none of it is executed.
# NOTE(review): before re-enabling it, check `model.load('CNN_3blocks.h5')`
# (the file imports `load_model` for loading models) and `prec`, whose body
# counts true negatives although its comment says precision.
'''

# Load a model to transfer pre-trained parameters
trans_model = model.load('CNN_3blocks.h5')

# Transfer learning - parameter copy & paste
which_layer = 'CNN1,CNN2,CNN3,batch1,batch2,batch3'.split(',')
layer_names = [layer.name for layer in model.layers]
trans_layer_names = [layer.name for layer in trans_model.layers]

for layer in which_layer:
    ind = layer_names.index(layer)
    trans_ind = trans_layer_names.index(layer)
    model.layers[ind].set_weights(trans_model.layers[trans_ind].get_weights())
    
for layer in model.layers[:9]: # Freeze the first 9 layers(CNN block)
    layer.trainable = False
    
    
# Turn on multi-GPU mode
model = multi_gpu_model(model, gpus=4)


This metrics calculate sensitivity and specificity batch-wise.
Keras development team removed this feature because
these metrics should be understood as global metrics.

I am not using it this time.

# Metrics - sensitivity, specificity, accuracy
def sens(y_true, y_pred): # Sensitivity
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    return true_positives / (possible_positives + K.epsilon())

def prec(y_true, y_pred): # Precision
    true_negatives = K.sum(K.round(K.clip((1-y_true) * (1-y_pred), 0, 1)))
    possible_negatives = K.sum(K.round(K.clip(1-y_true, 0, 1)))
    return true_negatives / (possible_negatives + K.epsilon())
'''

In [None]:
from keras.callbacks import Callback

class CustomModelCheckPoint(Callback):
    """Callback that records the training accuracy and loss of every epoch.

    After training, `epoch_accuracy` and `epoch_loss` map the epoch index to
    the values found under "acc" and "loss" in the logs of that epoch.
    """

    def __init__(self, **kargs):
        super(CustomModelCheckPoint, self).__init__(**kargs)
        self.epoch_accuracy = {}  # accuracy at given epoch
        self.epoch_loss = {}  # loss at given epoch

    # The two hooks must be methods of the class. Defined inside __init__
    # they would only be local functions and never be called.
    def on_epoch_begin(self, epoch, logs=None):
        # Things done on beginning of epoch.
        return

    def on_epoch_end(self, epoch, logs=None):
        # things done on end of the epoch
        logs = logs or {}
        self.epoch_accuracy[epoch] = logs.get("acc")
        self.epoch_loss[epoch] = logs.get("loss")
        #self.model.save_weights("name-of-model-%d.h5" %epoch) # save the model

In [None]:
# Keep the best model (lowest validation loss) on disk
checkpoint = callbacks.ModelCheckpoint('model_1230.h5', save_best_only=True, monitor='val_loss')
# Multiply the learning rate by 0.1 after 5 epochs without val_loss improvement
reduce_lr = callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5)
# Append the per-epoch metrics to a CSV file
csv_logger = callbacks.CSVLogger('model_1230_res.csv', separator=',', append=True)
callbacks_list = [checkpoint, reduce_lr, csv_logger]

# Start training
model.compile(loss='categorical_crossentropy', optimizer=optimizers.adam(lr=1e-4), metrics=['acc'])
history = model.fit(
    X_train, y_train,
    batch_size=64, epochs=5000, shuffle=True,
    validation_split=0.2, callbacks=callbacks_list,
)

### 4.2 Tensorflow Eager Execution API

TensorFlow's eager execution is an imperative programming environment that evaluates operations immediately, without building graphs: operations return concrete values instead of constructing a computational graph to run later. This makes it easy to get started with TensorFlow and debug models, and it reduces boilerplate as well. To follow along with this guide, run the code samples below in an interactive python interpreter.

In [None]:
# Parameters
learning_rate = 0.001  # passed to the Adam optimizer of the eager section
num_steps = 1000       # number of training steps (one batch per step)
batch_size = 128       # batch size of the tf.data pipeline
display_step = 100     # progress is printed every `display_step` steps

# Network Parameters
# NOTE(review): X.shape[0] is the number of segments returned by get_data,
# not the 10*11 mesh size. num_input is not used in the cells shown here;
# confirm the intent.
num_input = X.shape[0]
num_classes = 5 # PhysioNet total classes

In [None]:
# Using TF Dataset to split data into batches
dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(batch_size)
# Iterator over the batches; the training loop creates a new one whenever
# this one is exhausted.
dataset_iter = tfe.Iterator(dataset)

In [None]:
class ZhangModel(tf.keras.Model):
    """TensorFlow implementation of Zhang`s model (2018).

    Layout with the default `n_nodes` (3 CNN, 2 LSTM, 2 dense layers):
    [Conv2D per time frame -> BatchNorm -> ELU] x 3 -> Flatten per frame ->
    [Dense -> Dropout -> BatchNorm -> ELU] -> LSTM x 2 ->
    [Dense -> Dropout -> BatchNorm -> ELU] -> Dense(5, softmax).

    Every layer is created once in `__init__`. Layers created inside `call`
    would be re-initialized on each forward pass and could never be trained.
    """

    def __init__(self, n_nodes=(3, 2, 2),
                 initializer=tf.contrib.layers.variance_scaling_initializer(mode="FAN_AVG")):
        """
        n_nodes [sequence] : specifies the number of layers for each block.
                        [CNN, LSTM, DENSE] respectively.
        initializer : kernel initializer of the Conv2D, LSTM and hidden Dense
                        layers; by default a variance scaling initializer
                        in "FAN_AVG" mode.
        """
        super().__init__()
        self.n_CNN, self.n_LSTM, self.n_dense = n_nodes
        self.CNN, self.LSTM, self.dense = [[] for i in range(len(n_nodes))]
        self.initializer = initializer

        # Convolutional block. The number of filters doubles per layer:
        # 64, 128, 256 for three layers.
        # NOTE(review): the Keras cell of this notebook uses 32, 64, 128.
        self.CNN_td, self.CNN_norm = [], []
        for n in range(self.n_CNN):
            n_filter = 32*2**(n+1)

            conv = tf.keras.layers.Conv2D(n_filter, (3, 3), padding='same',
                                          data_format='channels_last',
                                          kernel_initializer=self.initializer)
            self.CNN.append(conv)
            # Wrapper that applies the convolution to every time frame
            self.CNN_td.append(tf.keras.layers.TimeDistributed(conv, name='CNN' + str(n+1)))
            self.CNN_norm.append(tf.keras.layers.BatchNormalization(name='batch' + str(n+1)))

        self.flatten = tf.keras.layers.TimeDistributed(tf.keras.layers.Flatten(), name='flatten')

        # Recurrent block: only the last LSTM collapses the time axis
        for n in range(self.n_LSTM):
            n_hidden = 64
            return_sequences = n != self.n_LSTM-1

            name = 'LSTM' + str(n+1)
            self.LSTM.append(tf.keras.layers.LSTM(n_hidden, kernel_initializer=self.initializer,
                                                  return_sequences=return_sequences, name=name))

        # Dense blocks: Dense -> Dropout -> BatchNorm (ELU is applied in call)
        self.dense_drop, self.dense_norm = [], []
        for n in range(self.n_dense):
            n_node = 1024

            name = 'Dense' + str(n+1)
            self.dense.append(tf.keras.layers.Dense(n_node,
                                                    kernel_initializer=self.initializer, name=name))
            self.dense_drop.append(tf.keras.layers.Dropout(0.5, name='drop' + str(n+1)))
            self.dense_norm.append(tf.keras.layers.BatchNormalization(name='dense_batch' + str(n+1)))

        # Output layer: one probability per class
        self.classifier = tf.keras.layers.Dense(5, activation='softmax', name='classifier')

    def call(self, input_tensor, training=None):
        """Run the model and return the class probabilities (softmax output).

        `training` is handed to the Dropout and BatchNormalization layers;
        with the default None the layers decide for themselves.
        """
        assert self.CNN, 'No CNN blocks defined!'
        assert self.dense, 'No Dense Blocks defined!'
        assert self.LSTM, 'No LSTM blocks defined!'

        def dense_block(i, x):
            x = self.dense[i](x)
            x = self.dense_drop[i](x, training=training)
            x = self.dense_norm[i](x, training=training)
            return tf.nn.elu(x)

        x = input_tensor
        for conv, norm in zip(self.CNN_td, self.CNN_norm):
            x = conv(x)
            x = norm(x, training=training)
            x = tf.nn.elu(x)

        x = self.flatten(x)

        # All dense blocks but the last one come before the recurrent block
        last = len(self.dense) - 1
        for i in range(last):
            x = dense_block(i, x)

        for layer in self.LSTM:
            x = layer(x)

        x = dense_block(last, x)

        return self.classifier(x)

In [None]:
model = ZhangModel([3, 2, 2])
# Forward pass on one random sample: 10 time frames of a 10 x 11 x 1 mesh
dummy_batch = tf.random_normal([1, 10, 10, 11, 1])
print(model(dummy_batch))

In [None]:
# Cross-Entropy loss function
def loss_fn(inference_fn, inputs, labels):
    """Mean categorical cross-entropy of `inference_fn(inputs)` vs. `labels`.

    `inference_fn` is expected to return class probabilities (ZhangModel ends
    in a softmax layer) and `labels` to be one-hot encoded. The cross-entropy
    is therefore taken on the probabilities directly. Feeding them to a
    `..._with_logits` loss would apply the softmax a second time.
    """
    probs = inference_fn(inputs)
    labels = tf.cast(labels, probs.dtype)
    # Clip to keep log() finite for probabilities that underflow to 0
    probs = tf.clip_by_value(probs, 1e-7, 1.0)
    return tf.reduce_mean(-tf.reduce_sum(labels * tf.log(probs), axis=-1))

# Calculate accuracy
def accuracy_fn(inference_fn, inputs, labels):
    """Share of samples whose arg-max prediction equals the arg-max label."""
    predicted_class = tf.argmax(inference_fn(inputs), 1)
    true_class = tf.argmax(labels, 1)
    hits = tf.cast(tf.equal(predicted_class, true_class), tf.float32)
    return tf.reduce_mean(hits)

# Adam optimizer (AdamOptimizer is what is instantiated, not plain SGD)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

# Compute gradients: `grad` takes the same arguments as loss_fn; its result
# is handed to optimizer.apply_gradients in the training loop.
grad = tfe.implicit_gradients(loss_fn)

In [None]:
# Training
average_loss = 0.
average_acc = 0.
# Number of batches accumulated since the last report. The averages are
# divided by this count, because the first report window after step 0 holds
# fewer than `display_step` batches.
n_accumulated = 0
for step in range(num_steps):

    # Iterate through the dataset
    try:
        d = dataset_iter.next()
    except StopIteration:
        # Refill queue
        dataset_iter = tfe.Iterator(dataset)
        d = dataset_iter.next()

    # EEGs
    x_batch = d[0]
    # Labels
    y_batch = tf.cast(d[1], dtype=tf.int64)

    # Compute the batch loss
    batch_loss = loss_fn(model, x_batch, y_batch)
    average_loss += batch_loss

    # Compute the batch accuracy
    batch_accuracy = accuracy_fn(model, x_batch, y_batch)
    average_acc += batch_accuracy

    n_accumulated += 1

    if step == 0:
        # Display the initial cost, before optimizing
        print("Initial loss= {:.9f}".format(average_loss))

    # Update the variables following gradients info
    optimizer.apply_gradients(grad(model, x_batch, y_batch))

    # Display info
    if (step + 1) % display_step == 0 or step == 0:
        average_loss /= n_accumulated
        average_acc /= n_accumulated
        print("Step:", '%04d' % (step + 1), " loss=",
              "{:.9f}".format(average_loss), " accuracy=",
              "{:.4f}".format(average_acc))
        average_loss = 0.
        average_acc = 0.
        n_accumulated = 0

## 5. Evaluation

In [None]:
# load in libraries
import pickle
import itertools
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score

In [None]:
# make directories (no error if './metrics/' already exists)
os.makedirs('./metrics/', exist_ok=True)

In [None]:
def plot_history(history):
    """Plot the loss and accuracy curves of a training run and save them.

    Args:
        history: dict mapping a metric name to its per-epoch values, or an
            object that exposes such a dict as `.history` (e.g. the object
            returned by `model.fit`).

    The figures are written to ./metrics/loss.png and ./metrics/acc.png.
    Nothing is plotted if the history holds no training loss.
    """
    # Accept the History object itself as well as its plain dict
    history = getattr(history, 'history', history)

    loss_list = [s for s in history if 'loss' in s and 'val' not in s]
    val_loss_list = [s for s in history if 'loss' in s and 'val' in s]
    acc_list = [s for s in history if 'acc' in s and 'val' not in s]
    val_acc_list = [s for s in history if 'acc' in s and 'val' in s]

    if not loss_list:
        print('Loss is missing in history')
        return

    ## As loss always exists
    epochs = range(1, len(history[loss_list[0]]) + 1)

    def last_value(key):
        """Final value of a metric, formatted for the legend."""
        return format(history[key][-1], '.5f')

    ## Loss
    plt.figure(1)
    for l in loss_list:
        plt.plot(epochs, history[l], 'b', label='Training loss (' + last_value(l) + ')')
    for l in val_loss_list:
        plt.plot(epochs, history[l], 'g', label='Validation loss (' + last_value(l) + ')')

    plt.title('Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig("./metrics/loss.png")

    ## Accuracy
    plt.figure(2)
    for l in acc_list:
        plt.plot(epochs, history[l], 'b', label='Training accuracy (' + last_value(l) + ')')
    for l in val_acc_list:
        plt.plot(epochs, history[l], 'g', label='Validation accuracy (' + last_value(l) + ')')

    plt.title('Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    # Save before showing: once show() has returned the figure may already
    # be closed, which would leave an empty image on disk.
    plt.savefig("./metrics/acc.png")
    plt.show()
    
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          cmap=plt.cm.Blues):
    """
    This function plots the confusion matrix and saves the figure to
    ./metrics/confuMat.png.

    cm: square confusion matrix (rows: true labels, columns: predictions).
    classes: tick labels, one per row / column of `cm`.
    normalize: if True, every row is divided by its sum. Rows without any
        sample are shown as zeros instead of dividing by zero.
    cmap: matplotlib colour map used for the cells.
    """
    if normalize:
        row_sums = cm.sum(axis=1)[:, np.newaxis]
        # Divide empty rows by 1 so they stay 0 instead of turning into NaN
        cm = cm.astype('float') / np.where(row_sums == 0, 1, row_sums)
        title='Normalized confusion matrix'
    else:
        title='Confusion matrix'

    plt.figure(3)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Write every count into its cell; white text on the dark cells
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.savefig("./metrics/confuMat.png")
    plt.show()
    
def full_multiclass_report(model,
                           x,
                           y_true,
                           classes):
    """Print accuracy, classification report and confusion matrix of `model`.

    model: object whose `predict(x)` returns one score per class and sample.
    x: input data handed to `model.predict`.
    y_true: true class index of every sample.
    classes: tick labels for the confusion matrix plot.
    """
    # Predicted class = index of the highest score
    scores = model.predict(x)
    y_pred = scores.argmax(axis=1)

    # Overall accuracy
    accuracy = accuracy_score(y_true,y_pred)
    print("Accuracy : "+ str(accuracy))

    print("")

    # Per-class precision / recall / F1
    print("Classification Report")
    report = classification_report(y_true,y_pred,digits=4)
    print(report)

    # Confusion matrix, printed and plotted
    cnf_matrix = confusion_matrix(y_true,y_pred)
    print(cnf_matrix)
    plot_confusion_matrix(cnf_matrix,classes=classes)

In [None]:
# Load in the data
howManyTest = 0.2

# Random subset (drawn with replacement) holding `howManyTest` of the test set.
# The size has to be an integer; `len(X_test)//howManyTest` is a float equal
# to five times the test set size.
# NOTE(review): X_conf / y_conf are not used by the rest of this cell.
n_conf = int(len(X_test) * howManyTest)
thisInd = np.random.randint(0, len(X_test), size=n_conf)
X_conf, y_conf = X_test[thisInd], y_test[thisInd]

'''
## Only if you have a previous model + history
# Get the model
model = models.load_model('./model/model0.h5')

# Get the history
with open('./history/history0.pkl', 'rb') as hist:
    history = pickle.load(hist)
'''

# Get the graphics
plot_history(history)
# Give the test set the same per-trial layout as the training set
X_test = X_test.reshape(X_test.shape[0], X_train.shape[1], X_train.shape[2], X_train.shape[3], 1)
full_multiclass_report(model,
                       X_test,
                       y_test.argmax(axis=1),
                       [1,2,3,4,5])