# 12-Class SSVEP EEG Dataset - Classification Using Convolutional Neural Network
User-Dependent Training using Magnitude Spectrum Features and Complex Spectrum Features
(10-Fold Cross-validation)

The following implementation is an asynchronous SSVEP BCI that uses Convolutional Neural Network classification with a data length of 1 second.

Reference Paper: [Comparing user-dependent and user-independent training of CNN for SSVEP BCI](https://iopscience.iop.org/article/10.1088/1741-2552/ab6a67)


In [1]:
import sys
import os
sys.path.insert(0, os.path.abspath('..'))
%config Completer.use_jedi = False

In [3]:
%%capture
import warnings
import numpy as np
import numpy.matlib as npm
import pandas as pd
import matplotlib.pyplot as plt
import scipy.io as sio
import pickle
from sklearn.model_selection import KFold

from keras.utils.np_utils import to_categorical
from keras import optimizers
from keras.losses import categorical_crossentropy

from bcilib import ssvep_utils as su
warnings.filterwarnings('ignore')

In [4]:
def get_training_data(features_data):
    """Arrange per-class spectral features into CNN samples and one-hot labels.

    Args:
        features_data: 5-D ``numpy.ndarray``. Axis 2 is treated as the target
            class, and axes 3 and 4 are merged into a single per-class epoch
            axis. Axes 0 and 1 are carried through as the two feature
            dimensions of every sample (presumably spectral components and
            channels -- confirm against ``bcilib.ssvep_utils``).

    Returns:
        tuple: ``(train_data, labels)``.

        * ``train_data`` has shape
          ``(num_classes * epochs_per_class, shape[1], shape[0], 1)``; samples
          are grouped by class (all epochs of class 0 first, then class 1, ...).
        * ``labels`` is the one-hot matrix returned by ``to_categorical`` with
          one row per sample, in the same order.
    """
    dim0, dim1, num_classes = features_data.shape[:3]
    total_epochs_per_class = features_data.shape[3] * features_data.shape[4]

    # Merge the last two axes into one epoch axis.
    features_data = np.reshape(
        features_data, (dim0, dim1, num_classes, total_epochs_per_class))

    # Bring (class, epoch) to the front and swap the two feature axes, then
    # flatten class/epoch into the sample axis. No squeeze is involved, so
    # length-1 axes (e.g. a single channel or a single epoch) are preserved.
    train_data = np.transpose(features_data, (2, 3, 1, 0))
    train_data = np.reshape(
        train_data, (num_classes * total_epochs_per_class, dim1, dim0, 1))

    # Derive the labels from the data itself so they always line up with the
    # rows of train_data: [0, 0, ..., 1, 1, ..., num_classes - 1, ...].
    labels = np.repeat(np.arange(num_classes), total_epochs_per_class)
    labels = to_categorical(labels, num_classes=num_classes)

    return train_data, labels

In [5]:
def train_CNN_cross_val_predict(train_data, labels, num_folds=10):
    """Estimate CNN accuracy with shuffled k-fold cross-validation.

    For every fold a new model is built with ``su.CNN_model``, compiled with
    an SGD optimizer and categorical cross-entropy loss, trained on the
    training split and evaluated on the held-out split. All hyper-parameters
    are read from the module-level ``CNN_PARAMS`` dictionary. The accuracy of
    each fold is printed on a single line as it becomes available.

    Args:
        train_data: 4-D array of samples indexed along axis 0; the remaining
            three axes are passed to ``su.CNN_model`` as the input shape.
        labels: Label matrix aligned with ``train_data`` along axis 0 (as
            produced by ``get_training_data``).
        num_folds: Number of cross-validation folds.

    Returns:
        numpy.ndarray: Array of shape ``(num_folds, 1)`` holding the test
        accuracy of each fold in percent.
    """
    kf = KFold(n_splits=num_folds, shuffle=True)
    cv_acc = np.zeros((num_folds, 1))

    for fold, (train_index, test_index) in enumerate(kf.split(train_data)):
        x_tr, x_ts = train_data[train_index], train_data[test_index]
        y_tr, y_ts = labels[train_index], labels[test_index]
        input_shape = np.array([x_tr.shape[1], x_tr.shape[2], x_tr.shape[3]])

        # A new model per fold so no weights leak between folds.
        model = su.CNN_model(input_shape, CNN_PARAMS)

        sgd = optimizers.SGD(lr=CNN_PARAMS['learning_rate'], decay=CNN_PARAMS['lr_decay'],
                             momentum=CNN_PARAMS['momentum'], nesterov=False)
        model.compile(loss=categorical_crossentropy, optimizer=sgd, metrics=["accuracy"])
        model.fit(x_tr, y_tr, batch_size=CNN_PARAMS['batch_size'],
                  epochs=CNN_PARAMS['epochs'], verbose=0)

        # evaluate() returns [loss, accuracy] because of metrics=["accuracy"].
        score = model.evaluate(x_ts, y_ts, verbose=0)
        cv_acc[fold, :] = score[1]*100
        print(f'cv{fold+1}:{score[1]*100:.2f}%', end=" ")

    return cv_acc

In [13]:
# Directory that holds the per-subject .mat recordings.
data_path = os.path.abspath('data')

# Hyper-parameters handed to su.CNN_model and used by the training loop.
# 'n_ch' and 'num_classes' are overwritten from the loaded recordings later on.
CNN_PARAMS = dict(
    batch_size=64,
    epochs=50,
    droprate=0.25,
    learning_rate=0.001,
    lr_decay=0.0,
    l2_lambda=0.0001,
    momentum=0.9,
    kernel_f=10,
    n_ch=8,
    num_classes=12,
)

# Settings passed to the spectral feature extractors.
FFT_PARAMS = dict(
    resolution=0.2930,
    start_frequency=3.0,
    end_frequency=35.0,
    sampling_rate=256,
)

# Segmentation window and shift (presumably in seconds -- confirm against
# su.get_segmented_epochs).
window_len = shift_len = 1

all_acc = np.zeros((10, 1))

# Per-subject feature stores, filled by the feature-extraction step.
magnitude_spectrum_features = {}
complex_spectrum_features = {}

In [14]:
# Per-subject training arrays and labels for the two CNN variants.
mcnn_training_data, ccnn_training_data = {}, {}

In [15]:
# Per-subject mean cross-validation accuracy for the two CNN variants.
mcnn_results, ccnn_results = {}, {}

# Load Dataset and Segment

In [16]:
# Load every subject's recording, filter it and cut it into epochs.
# The sampling rate comes from FFT_PARAMS so that filtering, segmentation and
# feature extraction all agree on one value.
sample_rate = FFT_PARAMS['sampling_rate']

all_segmented_data = dict()
for subject in range(10):
    dataset = sio.loadmat(f'{data_path}/s{subject+1}.mat')
    eeg = np.array(dataset['eeg'], dtype='float32')

    # Take the class and channel counts from the recording itself.
    CNN_PARAMS['num_classes'] = eeg.shape[0]
    CNN_PARAMS['n_ch'] = eeg.shape[1]
    # Not used further in this cell.
    total_trial_len = eeg.shape[2]
    num_trials = eeg.shape[3]

    # 6, 80 and 4 are presumably the low cut-off, high cut-off and filter
    # order -- confirm against su.get_filtered_eeg.
    filtered_data = su.get_filtered_eeg(eeg, 6, 80, 4, sample_rate)
    all_segmented_data[f's{subject+1}'] = su.get_segmented_epochs(filtered_data, window_len,
                                                                  shift_len, sample_rate)

# Feature Extraction

In [10]:
# Compute both feature representations from each subject's segmented epochs.
for subject, segmented_epochs in all_segmented_data.items():
    magnitude_spectrum_features[subject] = su.magnitude_spectrum_features(
        segmented_epochs, FFT_PARAMS)
    complex_spectrum_features[subject] = su.complex_spectrum_features(
        segmented_epochs, FFT_PARAMS)

In [11]:
# Convert each subject's features into CNN-ready samples and labels, once for
# the magnitude-spectrum features and once for the complex-spectrum features.
for subject in all_segmented_data:
    train_data, labels = get_training_data(magnitude_spectrum_features[subject])
    mcnn_training_data[subject] = {'train_data': train_data, 'label': labels}

    train_data, labels = get_training_data(complex_spectrum_features[subject])
    ccnn_training_data[subject] = {'train_data': train_data, 'label': labels}

# M-CNN Training and Results

In [12]:
# Cross-validate the magnitude-spectrum CNN on every subject and record the
# mean fold accuracy (in percent) per subject.
for subject, subject_data in mcnn_training_data.items():
    print(f'\nMCNN - Subject: {subject}')
    train_data = subject_data['train_data']
    labels = subject_data['label']

    cv_acc = train_CNN_cross_val_predict(train_data, labels, 10)
    mcnn_results[subject] = np.mean(cv_acc)
    # Two decimal places, matching the per-fold and overall accuracy output.
    print(f'\nAccuracy: {mcnn_results[subject]:.2f}%')

# Average of the per-subject means.
mcnn_overall_accuracy = np.mean(np.fromiter(mcnn_results.values(), dtype=float))
print(f'Overall Accuracy MCNN - {mcnn_overall_accuracy:.2f}%')


MCNN - Subject: s1
cv1:61.11% cv2:73.61% cv3:70.83% cv4:72.22% cv5:55.56% cv6:66.67% cv7:65.28% cv8:66.67% cv9:63.89% cv10:59.72% 
Accuracy: 65.555556%

MCNN - Subject: s2
cv1:30.56% cv2:40.28% cv3:40.28% cv4:41.67% cv5:30.56% cv6:33.33% cv7:48.61% cv8:29.17% cv9:37.50% cv10:31.94% 
Accuracy: 36.388889%

MCNN - Subject: s3
cv1:83.33% cv2:81.94% cv3:87.50% cv4:73.61% cv5:80.56% cv6:79.17% cv7:77.78% cv8:80.56% cv9:80.56% cv10:81.94% 
Accuracy: 80.694445%

MCNN - Subject: s4
cv1:91.67% cv2:93.06% cv3:87.50% cv4:87.50% cv5:88.89% cv6:94.44% cv7:90.28% cv8:93.06% cv9:91.67% cv10:93.06% 
Accuracy: 91.111112%

MCNN - Subject: s5
cv1:91.67% cv2:97.22% cv3:95.83% cv4:97.22% cv5:94.44% cv6:93.06% cv7:91.67% cv8:98.61% cv9:95.83% cv10:98.61% 
Accuracy: 95.416666%

MCNN - Subject: s6
cv1:95.83% cv2:98.61% cv3:97.22% cv4:95.83% cv5:94.44% cv6:94.44% cv7:97.22% cv8:91.67% cv9:98.61% cv10:98.61% 
Accuracy: 96.249999%

MCNN - Subject: s7
cv1:84.72% cv2:88.89% cv3:86.11% cv4:88.89% cv5:91.67% cv6:91.

# C-CNN Training and Results

In [1]:
# Cross-validate the complex-spectrum CNN on every subject and record the
# mean fold accuracy (in percent) per subject.
for subject, subject_data in ccnn_training_data.items():
    print(f'\nCCNN - Subject: {subject}')
    train_data, labels = subject_data['train_data'], subject_data['label']

    cv_acc = train_CNN_cross_val_predict(train_data, labels, 10)
    ccnn_results[subject] = np.mean(cv_acc)
    print(f'\nAccuracy: {ccnn_results[subject]:.2f}%')

# Average of the per-subject means.
ccnn_subject_means = np.fromiter(ccnn_results.values(), dtype=float)
ccnn_overall_accuracy = np.mean(ccnn_subject_means)
print(f'Overall Accuracy CCNN - {ccnn_overall_accuracy:.2f}%')

NameError: name 'ccnn_training_data' is not defined

# Summary

In [None]:
# Collect the per-subject accuracies of both models side by side; rows follow
# the insertion order of the result dictionaries.
mcnn_scores = np.fromiter(mcnn_results.values(), dtype=float)
ccnn_scores = np.fromiter(ccnn_results.values(), dtype=float)
results = pd.DataFrame({'mcnn': mcnn_scores, 'ccnn': ccnn_scores})

In [None]:
# Show the per-subject accuracy table (columns 'mcnn' and 'ccnn').
print(results)

In [None]:
# Box plot of the accuracy distribution across subjects for each model.
results.boxplot(figsize=(12, 4), column=['mcnn', 'ccnn'])
# Accuracies were stored as percentages (score * 100) during cross-validation.
plt.ylabel('Accuracy')
plt.show()

In [None]:
# Grouped bar chart with one MCNN/CCNN pair of bars per row of `results`.
results.plot.bar(figsize=(12, 4), title='Comparing User-Dependent Training of MCNN and CCNN')
plt.xlabel('Subject')
plt.ylabel('Accuracy')
plt.show()