##  CNN Model for SSVEP 
classification on EEG dataset for BCI

In this notebook, we will be using SSVEP EEG data to classify an ssvep stimuli of a particular frequency

The dataset used will be in the WATOLINK-data format (refer to https://docs.google.com/document/d/1iVEE2In7eUX1bruULMqP5EWzkU26FuvoLZINJcEDC5U/edit)

In [6]:
#!pip install -q -r requirements.txt

In [1]:
import os
import zipfile
import numpy as np
import numpy.matlib as npm
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import scipy.io as sio
import pandas as pd
import warnings
import itertools


from CNN_files.Preprocess import Preprocess
from CNN_files import ssvep_utils as su

from scipy.signal import butter, filtfilt

from sklearn.model_selection import KFold 
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.utils import shuffle

import tensorflow as tf
from tensorflow import keras

from keras.layers import Dense, LSTM, Input
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten, Dropout, Conv2D, BatchNormalization
from keras.layers import Input,Flatten, Dense
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.layers.core import Dropout, Activation
from keras.layers.pooling import GlobalAveragePooling2D
from keras.models import Model
from keras.models import load_model
from keras import optimizers
from keras.losses import categorical_crossentropy
from keras.layers import Dense, Activation, Flatten, Dropout, Conv2D, BatchNormalization
from keras.utils.np_utils import to_categorical
from keras import initializers, regularizers

# Feature Extraction


In [6]:
data_path = os.path.abspath('../data')


# define params
CNN_PARAMS = {
    'batch_size': 64,
    'epochs': 250,
    'droprate': 0.25,
    'learning_rate': 0.001,
    'lr_decay': 0.0,
    'l2_lambda': 0.0001,
    'momentum': 0.9,
    'kernel_f': 10,
    'n_ch': 8, 
    'num_classes': 5} # can be changed

FFT_PARAMS =  {'resolution': 0.2930,
                        'start_frequency': 3.0,
                        'end_frequency': 35.0,
                        'sampling_rate': 250 }

window_len = 1
shift_len = 1
    
all_acc = np.zeros((10, 1))

magnitude_spectrum_features = dict()


NOTE: a fourier transform essentially uses math to identify underlying frequencies within a specific signal

* magnitude_spectrum: the magnitudes of frequency attained from the transform

* window_len: the size of the window we apply the transofrm to, everwhere else is tapered to around 0

* shift_len: the length we want to shift after applying the window function

* flicker_freq: the labels of frequencies we are attempting to classify

In [7]:
data_path = "/content/Brain-computer-interfaces/data"
magnitude_spectrum = dict()
window_len = 1
shift_len = 1
sample_rate = FFT_PARAMS['sampling_rate']
#flicker_freq = np.array([0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 14.0, 16.0, 18.0, 20.0]) # for watolink
#flicker_freq = np.array([5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 14.0, 16.0, 18.0, 20.0]) # for watolink
flicker_freq = np.array([0,10.25,11.75,12.75,14.75])


In [10]:
all_segmented_data = dict()
for subject in range(1):
  eeg = d 
  total_trial_len = eeg.shape[2]
  num_trials = eeg.shape[3]
  sample_rate = 250 
  filtered_data = su.get_filtered_eeg(eeg, 6, 80, 4, sample_rate)
  all_segmented_data[f's{subject+1}'] = su.get_segmented_epochs(filtered_data, window_len, 
                                                                  shift_len, sample_rate)



In [11]:
all_segmented_data['s1'].shape

(5, 8, 15, 4, 250)

In [12]:
for subject in all_segmented_data.keys():
    magnitude_spectrum_features[subject] = su.magnitude_spectrum_features(all_segmented_data[subject], 
                                                                          FFT_PARAMS)
   

In [13]:
magnitude_spectrum_features['s1'].shape

(110, 8, 5, 15, 4)

In [14]:
mcnn_training_data = dict()
mcnn_results = dict()


In [15]:

#function to get training data

def get_training_data(features_data):
    features_data = np.reshape(features_data, (features_data.shape[0], features_data.shape[1], 
                                               features_data.shape[2], 
                                               features_data.shape[3]*features_data.shape[4]))
    train_data = features_data[:, :, 0, :].T
    for target in range(1, features_data.shape[2]):
        train_data = np.vstack([train_data, np.squeeze(features_data[:, :, target, :]).T])

    train_data = np.reshape(train_data, (train_data.shape[0], train_data.shape[1], 
                                         train_data.shape[2], 1))
    total_epochs_per_class = features_data.shape[3]
    features_data = []
    class_labels = np.arange(CNN_PARAMS['num_classes'])
    labels = (npm.repmat(class_labels, total_epochs_per_class, 1).T).ravel()
    labels = to_categorical(labels)
    
    return train_data, labels

In [16]:
mcnn_training_data = dict()
mcnn_results = dict()

for subject in all_segmented_data.keys():
    mcnn_training_data[subject] = dict()
   
    train_data, labels = get_training_data(magnitude_spectrum_features[subject])
    mcnn_training_data[subject]['train_data'] = train_data
    mcnn_training_data[subject]['label'] = labels
    
   
    

In [17]:
mcnn_training_data['s1']['train_data'].shape

(300, 8, 110, 1)

In [18]:
mcnn_training_data['s1']['label'].shape

(300, 5)

# Display Data

In [None]:
# Mika you should probably input your script here

Remember data is structured as [Number of targets, Number of channels, Number of sampling points, Number of trials] = size(eeg)

* Number of targets: 12
* Number of channels: 8 
* Number of sampling points: 1114 
* Number of trials: 2 (WATOLINK DATA FORMAT RN)
* Sampling rate [Hz] : 250

The data needs to be reshaped so that the 8 channels data at a particular time is grouped together

In [None]:
magnitude_spectrum_features['s1'].shape

(110, 8, 5, 15, 4)

In [22]:
def make_confusion_matrix(y_true, y_pred, classes=None, figsize=(10, 10), text_size=15, norm=False, savefig=False): 
  """Makes a labelled confusion matrix comparing predictions and ground truth labels.
  If classes is passed, confusion matrix will be labelled, if not, integer class values
  will be used.
  Args:
    y_true: Array of truth labels (must be same shape as y_pred).
    y_pred: Array of predicted labels (must be same shape as y_true).
    classes: Array of class labels (e.g. string form). If `None`, integer labels are used.
    figsize: Size of output figure (default=(10, 10)).
    text_size: Size of output figure text (default=15).
    norm: normalize values or not (default=False).
    savefig: save confusion matrix to file (default=False).
  
  Returns:
    A labelled confusion matrix plot comparing y_true and y_pred.
  Example usage:
    make_confusion_matrix(y_true=test_labels, # ground truth test labels
                          y_pred=y_preds, # predicted labels
                          classes=class_names, # array of class label names
                          figsize=(15, 15),
                          text_size=10)
  """  
  # Create the confustion matrix
  cm = confusion_matrix(y_true, y_pred)
  cm_norm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis] # normalize it
  n_classes = cm.shape[0] # find the number of classes we're dealing with

  # Plot the figure and make it pretty
  fig, ax = plt.subplots(figsize=figsize)
  cax = ax.matshow(cm, cmap=plt.cm.Blues) # colors will represent how 'correct' a class is, darker == better
  fig.colorbar(cax)

  # Are there a list of classes?
  if classes:
    labels = classes
  else:
    labels = np.arange(cm.shape[0])
  
  # Label the axes
  ax.set(title="Confusion Matrix",
         xlabel="Predicted label",
         ylabel="True label",
         xticks=np.arange(n_classes), # create enough axis slots for each class
         yticks=np.arange(n_classes), 
         xticklabels=labels, # axes will labeled with class names (if they exist) or ints
         yticklabels=labels)
  
  # Make x-axis labels appear on bottom
  ax.xaxis.set_label_position("bottom")
  ax.xaxis.tick_bottom()

  # Set the threshold for different colors
  threshold = (cm.max() + cm.min()) / 2.

  # Plot the text on each cell
  for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    if norm:
      plt.text(j, i, f"{cm[i, j]} ({cm_norm[i, j]*100:.1f}%)",
              horizontalalignment="center",
              color="white" if cm[i, j] > threshold else "black",
              size=text_size)
    else:
      plt.text(j, i, f"{cm[i, j]}",
              horizontalalignment="center",
              color="white" if cm[i, j] > threshold else "black",
              size=text_size)

  # Save the figure to the current working directory
  if savefig:
    fig.savefig("confusion_matrix.png")

Now, finally train and evauate the model!

# Displaying Metrics

In [None]:
def calculate_results(y_true, y_pred):
  """
  Calculates model accuracy, precision, recall and f1 score of a binary classification model.
  Args:
      y_true: true labels in the form of a 1D array
      y_pred: predicted labels in the form of a 1D array
  Returns a dictionary of accuracy, precision, recall, f1-score.
  """
  # Calculate model accuracy
  model_accuracy = accuracy_score(y_true, y_pred) * 100
  # Calculate model precision, recall and f1 score using "weighted average
  model_precision, model_recall, model_f1, _ = precision_recall_fscore_support(y_true, y_pred, average="weighted")
  model_results = {"accuracy": model_accuracy,
                  "precision": model_precision,
                  "recall": model_recall,
                  "f1": model_f1}
  return model_results

# Pretrained Model and Fine Tuning

*make sure h5 model is named "model.h5" and in working directory

In [4]:
model = load_model('./CNN_files/model.h5')




In [7]:
model.summary()




In [None]:
# Build new model
new_model = Sequential()

for layer in model.layers[:-1]: # go through until last layer
    new_model.add(layer)
#new_model.add(Dense(13, activation='softmax'))





new_model.add(Dense(5, activation='softmax'))
new_model.summary()

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_200 (Conv2D)          (None, 1, 110, 16)        144       
_________________________________________________________________
batch_normalization_200 (Bat (None, 1, 110, 16)        64        
_________________________________________________________________
activation_200 (Activation)  (None, 1, 110, 16)        0         
_________________________________________________________________
dropout_200 (Dropout)        (None, 1, 110, 16)        0         
_________________________________________________________________
conv2d_201 (Conv2D)          (None, 1, 101, 16)        2576      
_________________________________________________________________
batch_normalization_201 (Bat (None, 1, 101, 16)        64        
_________________________________________________________________
activation_201 (Activation)  (None, 1, 101, 16)       

In [None]:
for layer in new_model.layers[:-1]:
  layer.trainable = False 


## Compile The New Model

In [None]:
new_model.compile(optimizer = 'adam', loss = 'categorical_crossentropy', metrics = ['accuracy'])

In [None]:
mcnn_training_data['s1']['train_data'][8].shape

(8, 110, 1)

# Split Data for Train and Val

In [None]:


def split_train_test(data, labels):
   
   """Splits data into training and validation set"""
   n_data, n_labels = shuffle(data,labels, random_state = 0)

   split_index = int(np.round(len(n_data)*0.8) )
   print(split_index)

   train_data = n_data[:split_index]
   test_data = n_data[split_index:]

   train_labels = n_labels[:split_index]
   test_labels = n_labels[split_index:]

   return (train_data, test_data, train_labels, test_labels) 

   


In [None]:
train_data, test_data, train_labels, test_labels = split_train_test(mcnn_training_data['s1']['train_data'], mcnn_training_data['s1']['label'])

240


In [None]:
train_labels.shape

(240, 5)

In [None]:
history = new_model.fit(train_data, train_labels, epochs = 100, validation_data = (test_data, test_labels))

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

In [None]:
test_data.shape

(60, 8, 110, 1)

# Making a prediction

In [None]:
preds = new_model.predict(test_data)



In [None]:
preds.shape

(60, 5)

In [None]:
new_preds = []
for i in range(0, len(preds)):
  new_preds.append(np.argmax(preds[i]))

In [None]:
new_preds

[0,
 3,
 2,
 0,
 1,
 4,
 3,
 4,
 1,
 1,
 2,
 2,
 4,
 3,
 0,
 2,
 3,
 4,
 0,
 0,
 3,
 2,
 2,
 4,
 3,
 0,
 3,
 2,
 3,
 4,
 4,
 1,
 4,
 2,
 2,
 3,
 4,
 1,
 0,
 1,
 4,
 3,
 2,
 2,
 4,
 0,
 3,
 3,
 1,
 1,
 4,
 2,
 1,
 1,
 3,
 4,
 3,
 1,
 2,
 2]

In [None]:
g = np.argmax(test_labels, axis = -1)
g

array([2, 3, 1, 0, 2, 4, 3, 0, 1, 2, 1, 0, 4, 3, 0, 2, 0, 4, 2, 0, 3, 2,
       2, 4, 3, 0, 0, 2, 3, 4, 4, 2, 4, 2, 1, 3, 4, 1, 4, 1, 0, 2, 4, 2,
       4, 0, 3, 1, 1, 1, 4, 4, 3, 0, 3, 4, 3, 1, 0, 2])

In [None]:
calculate_results(g, new_preds)

{'accuracy': 65.0,
 'f1': 0.6443738977072311,
 'precision': 0.6526307026307026,
 'recall': 0.65}

In [None]:
new_model.save('4_freq_model.h5')

In [None]:
new_new_model = tf.keras.models.load_model('4_freq_model.h5')

In [None]:
new_new_model.summary()

Model: "sequential_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_200 (Conv2D)         (None, 1, 110, 16)        144       
                                                                 
 batch_normalization_200 (Ba  (None, 1, 110, 16)       64        
 tchNormalization)                                               
                                                                 
 activation_200 (Activation)  (None, 1, 110, 16)       0         
                                                                 
 dropout_200 (Dropout)       (None, 1, 110, 16)        0         
                                                                 
 conv2d_201 (Conv2D)         (None, 1, 101, 16)        2576      
                                                                 
 batch_normalization_201 (Ba  (None, 1, 101, 16)       64        
 tchNormalization)                                    

In [None]:
len(mcnn_training_data['s1']['label'])

100

In [None]:
labels = []
for i in range(0, len(mcnn_training_data['s1']['label']), 8):
  labels.append(np.argmax(mcnn_training_data['s1']['label'][i]))

labels

[0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4]

In [None]:
results = calculate_results(labels, newer_preds)

In [None]:
results

{'accuracy': 100.0, 'f1': 1.0, 'precision': 1.0, 'recall': 1.0}

In [None]:
new_preds 

array([[2.5976470e-01, 3.4223524e-01, 6.8242834e-06, ..., 4.9288565e-04,
        5.0280016e-04, 4.3352794e-02],
       [3.9175245e-01, 4.5865629e-02, 8.3481958e-03, ..., 5.1062245e-02,
        4.8718777e-02, 1.3213304e-02],
       [2.5130564e-01, 3.8929675e-02, 3.7654120e-02, ..., 4.5531169e-02,
        3.8480416e-02, 4.2067297e-02],
       ...,
       [4.9853826e-01, 7.8318447e-02, 6.9522765e-04, ..., 3.9624986e-03,
        5.9457425e-02, 7.8338981e-03],
       [2.9208639e-01, 5.3282484e-02, 8.8120513e-03, ..., 1.1730568e-02,
        4.1607544e-02, 4.1292764e-02],
       [2.8679876e-02, 7.9060203e-01, 2.1538297e-07, ..., 2.7583601e-05,
        5.4401833e-05, 3.7048198e-02]], dtype=float32)