In [19]:
import mne
import numpy as np
import matplotlib.pyplot as plt
import scipy
import os
from collections import OrderedDict
import seaborn as sns
import pandas as pd
import gzip
from scipy.signal import filtfilt, butter
import pickle
from pyriemann.estimation import Covariances
from pyriemann.utils.mean import mean_riemann
from pyriemann.utils.distance import distance_riemann
from sklearn.model_selection import KFold


In [3]:
#Parameters : experimental setup
tmin=2
tmax=5
sfreq = 256
freq_band=0.1
frequencies= [13,17,21]
event_code = [33024,33025,33026,33027]
channels = np.array(['Oz','O1','O2','PO3','POz','PO7','PO8','PO4'])
names=['resting','stim13','stim21','stim17']

In [4]:
#loading data
def organize_data_into_dict(data_path= 'ssvep_dataset/'):
    subj_list = os.listdir(data_path)
    records = {k: [] for k in subj_list}
    for subj in subj_list:
        record_all = os.listdir(data_path+subj+'/')
        n = len(record_all)//4#number of records of a given subject
        for i in range(n):
            records[subj].append(record_all[i*4][:28])
    return records

def load_data(chosen_subject,chosen_index,records,data_path= 'ssvep_dataset/'):
    subj_list = list(records.keys())
    assert chosen_subject in subj_list,"The chosen subject doesn't exist in the dataset."
    assert chosen_index in range(len(records[chosen_subject])),"The chosen record doesn't exist for the subject "+chosen_subject

    fname = chosen_subject+'/'+records[chosen_subject][chosen_index]
    with gzip.open(data_path + fname + '.pz', 'rb') as f:
        o = pickle.load(f, encoding='latin1')
    raw_signal = o['raw_signal'].T
    event_pos = o['event_pos'].reshape((o['event_pos'].shape[0]))
    event_type = o['event_type'].reshape((o['event_type'].shape[0]))
    return raw_signal,event_pos,event_type

In [5]:
#The Butterworth filter : band-pass filter, flat in the passband , the passband is concentrated on 
def filter_bandpass(signal, fmin, fmax, fs, order=4, filttype='forward-backward'):
    nyq = 0.5 * fs
    low = fmin / nyq
    high = fmax / nyq
    b, a = butter(order, [low, high], btype='band')
    #filter tpe : forwaard-backward
    filtered = filtfilt(b, a, signal, axis=-1)  
    return filtered

In [6]:
def extended_trials(raw_signal,event_pos,event_type,sfreq=sfreq,frequencies=frequencies,tmin=tmin,tmax=tmax,
                    freq_band=freq_band,channels = channels):
    ext_signal = np.empty_like(raw_signal[0,:])    
    for f in frequencies:
        ext_signal = np.vstack((ext_signal, filter_bandpass(raw_signal, f-freq_band, f+freq_band, fs=sfreq)))
    ext_signal = ext_signal[1:,:]
    ext_trials = list()
    for e, t in zip(event_type, event_pos):
        if e == 32779: # start of a trial
            start = t + tmin*sfreq
            stop  = t + tmax*sfreq
            ext_trials.append(ext_signal[:, start:stop])
    ext_trials = np.array(ext_trials)
    ext_trials = ext_trials - np.tile(ext_trials.mean(axis=2).reshape(ext_trials.shape[0], 
                                ext_trials.shape[1], 1), (1, 1, ext_trials.shape[2]))
    return ext_trials
    

In [7]:
def make_labels(event_type,event_code=event_code):
    labels = []
    n_events = len(event_code)
    for e in event_type:
        for i in range(n_events):
            if e==event_code[i]:
                labels.append(i)
    return labels

In [8]:
#visualisation of the extended signals of a given trial
def visualisation_of_ext_trails(ext_trials,trial,n_seconds,frequencies=frequencies,sfreq=sfreq,
                               channels = channels):
    n_trials = ext_trials.shape[0]
    assert trial in range(n_trials),"The selected trial is out of range."
    assert n_seconds <=3, "The duration of a trial is 3s. Make sure to select less seconds for visualization"
    n_channels = len(channels)
    time = np.linspace(0, n_seconds, n_seconds * sfreq).reshape((1, n_seconds * sfreq))
    fig, axs = plt.subplots(4,2,figsize=(15,15))
    axs =axs.flatten()
    for i in range(n_channels):
        for j in range(len(frequencies)):
            axs[i].plot(time.T, ext_trials[trial, n_channels*j+i, :].T, label=str(frequencies[j])+' Hz')
        if i%2==0:
            axs[i].set_ylabel("$\\mu$V")
        axs[i].set_title(channels[i])
        axs[i].legend(loc='upper left')
    axs[i-1].set_xlabel('Time (s)')
    axs[i].set_xlabel('Time (s)')

In [9]:
def covariances(ext_trials,estimator='scm'):
    cov_ext_trials = Covariances(estimator='scm').transform(ext_trials)
    return cov_ext_trials

In [10]:
def GeometricCenters(x_train,y_train,nb_classes=len(frequencies)+1):
    cov_centers = np.empty((nb_classes, x_train.shape[1], x_train.shape[1]))
    x_trains=[[] for i in range(nb_classes) ]
    for i in range(nb_classes):
        for j in range(x_train.shape[0]):
            if y_train[j]==i:
                x_trains[i].append(x_train[j,:,:])
    for i in range(nb_classes):
        x_trains[i]=np.asarray(x_trains[i])
    for i in range(nb_classes):
        cov_centers[i, :, :] = mean_riemann(x_trains[i])
    return cov_centers


In [11]:
def accuracy(x,y,cov_centers):
    nb_classes= cov_centers.shape[0]
    classes=list(range(cov_centers.shape[0]))
    accuracies = list()
    for sample, true_label in zip(x, y):
        dist = [distance_riemann(sample, cov_centers[m]) for m in range(nb_classes)]
        if classes[np.array(dist).argmin()] == true_label:
            accuracies.append(1)
        else: accuracies.append(0)
    accuracy_ = 100.*np.array(accuracies).sum()/len(y)
    return accuracy_


In [13]:
def describe_label(y,nb_classes=len(frequencies)+1):
    for i in range(nb_classes):
        count=0
        for z in y:
            if z==i:
                count +=1
        print("Occurence of class ",i," = ",count)

In [34]:
def classify_single_recording(chosen_subject,chosen_index,records,m=28,shuffling=True):
    raw_signal,event_pos,event_type = load_data(chosen_subject,chosen_index,records)
    ext_trials = extended_trials(raw_signal,event_pos,event_type)
    labels = make_labels(event_type)
    cov_ext_trials = covariances(ext_trials)
    new_cov_ext_trials = np.empty_like(cov_ext_trials)
    new_labels = []
    indx = list(range(cov_ext_trials.shape[0]))
    if shuffling:
        np.random.shuffle(indx)
            
    for i in range(len(indx)):
        new_cov_ext_trials[i,:,:] = cov_ext_trials[indx[i],:,:] 
        new_labels.append(labels[indx[i]])
    x_train = new_cov_ext_trials[:m,:,:]
    y_train = new_labels[:m]
    x_test= new_cov_ext_trials[m:,:,:]
    y_test= new_labels[m:]
    print("Occurences of classes in the train labels :")
    describe_label(y_train)
    cov_centers = GeometricCenters(x_train,y_train)
    train_accuracy = accuracy(x_train,y_train,cov_centers)
    test_accuracy = accuracy(x_test,y_test,cov_centers)
    print("Train Accuracy  = ",round(train_accuracy,2),"%")
    print("Test  Accuracy  = ",round(test_accuracy,2),"%")
    return train_accuracy,test_accuracy
  

In [55]:
def classify_single_recording_with_cross_val(chosen_subject,chosen_index,records,n_splits=8,shuffle=True):
    raw_signal,event_pos,event_type = load_data(chosen_subject,chosen_index,records)
    ext_trials = extended_trials(raw_signal,event_pos,event_type)
    labels = make_labels(event_type)
    cov_ext_trials = covariances(ext_trials)
    kf = KFold(n_splits=n_splits, shuffle=shuffle)
    train_accuracy, test_accuracy = [], []
    for train_index , test_index in kf.split(labels):
        #print(np.asarray(labels)[test_index])
        
        x_train,x_test,y_train,y_test = [],[],[],[]
        
        for i in train_index:
            x_train.append(cov_ext_trials[i,:,:])
            y_train.append(labels[i])

        for i in test_index:
            x_test.append(cov_ext_trials[i,:,:])
            y_test.append(labels[i])
        
        x_train = np.asarray(x_train)
        x_test  = np.asarray(x_test)
        
        cov_centers = GeometricCenters(x_train,y_train)
        
        train_accuracy.append(accuracy(x_train,y_train,cov_centers))
        test_accuracy.append(accuracy(x_test,y_test,cov_centers))
    
    print(train_accuracy)
    print(test_accuracy)
    
    train_accuracy = np.asarray(train_accuracy)
    test_accuracy = np.asarray(test_accuracy)
    
    print("Train Accuracy  = ",round(np.mean(train_accuracy),2),"%  +/- ", round(np.std(train_accuracy),2),"%")
    print("Test  Accuracy  = ",round(np.mean(test_accuracy),2),"%   +/- ", round(np.std(test_accuracy),2),"%")
    return np.mean(train_accuracy),np.mean(test_accuracy)
    

In [56]:
records = organize_data_into_dict()
key_list =[]
for subj in records.keys():
    for session in range(len(records[subj])):
        key_list.append((subj,records[subj][session]))
accuracies = { x: {'train accuracy':0,'test accuracy':0} for x in key_list}
for subj in records.keys():
    print("______________Subject : ",subj,"______________")
    for session in range(len(records[subj])):
        print("*****Record : ",records[subj][session],"*****")
        train,test = classify_single_recording(subj,session,records,m=24,shuffling=False)
        accuracies[(subj,records[subj][session])]['train accuracy'] = train
        accuracies[(subj,records[subj][session])]['test accuracy'] = test

______________Subject :  subject01 ______________
*****Record :  record-[2012.07.06-19.02.16] *****
Occurences of classes in the train labels :
Occurence of class  0  =  8
Occurence of class  1  =  5
Occurence of class  2  =  5
Occurence of class  3  =  6
Train Accuracy  =  100.0 %
Test  Accuracy  =  12.5 %
*****Record :  record-[2012.07.06-19.06.14] *****
Occurences of classes in the train labels :
Occurence of class  0  =  8
Occurence of class  1  =  5
Occurence of class  2  =  5
Occurence of class  3  =  6
Train Accuracy  =  100.0 %
Test  Accuracy  =  37.5 %
______________Subject :  subject02 ______________
*****Record :  record-[2012.07.19-17.36.23] *****
Occurences of classes in the train labels :
Occurence of class  0  =  8
Occurence of class  1  =  5
Occurence of class  2  =  5
Occurence of class  3  =  6
Train Accuracy  =  95.83 %
Test  Accuracy  =  25.0 %
*****Record :  record-[2012.07.19-17.41.14] *****
Occurences of classes in the train labels :
Occurence of class  0  =  8
O

In [57]:
pd.DataFrame(accuracies).T

Unnamed: 0,Unnamed: 1,train accuracy,test accuracy
subject01,record-[2012.07.06-19.02.16],100.0,12.5
subject01,record-[2012.07.06-19.06.14],100.0,37.5
subject02,record-[2012.07.19-17.36.23],95.833333,25.0
subject02,record-[2012.07.19-17.41.14],100.0,37.5
subject03,record-[2012.07.11-15.25.23],100.0,62.5
subject03,record-[2012.07.11-15.33.08],100.0,75.0
subject04,record-[2012.07.18-17.52.30],100.0,50.0
subject04,record-[2012.07.18-17.56.53],100.0,37.5
subject05,record-[2012.07.19-11.24.02],95.833333,75.0
subject05,record-[2012.07.19-11.28.18],91.666667,37.5


In [51]:
records = organize_data_into_dict()
key_list =[]
for subj in records.keys():
    for session in range(len(records[subj])):
        key_list.append((subj,records[subj][session]))
accuracies_ = { x: {'train accuracy':0,'test accuracy':0} for x in key_list}
for subj in records.keys():
    print("______________Subject : ",subj,"______________")
    for session in range(len(records[subj])):
        print("*****Record : ",records[subj][session],"*****")
        train,test = classify_single_recording_with_cross_val(subj,session,records,n_splits=8,shuffle=True)
        accuracies_[(subj,records[subj][session])]['train accuracy'] = train
        accuracies_[(subj,records[subj][session])]['test accuracy'] = test

______________Subject :  subject01 ______________
*****Record :  record-[2012.07.06-19.02.16] *****
[100.          92.85714286  92.85714286  89.28571429 100.
  89.28571429  92.85714286 100.        ]
[25. 50. 25. 50. 50. 50.  0. 25.]
Train Accuracy  =  94.64 %  +/-  4.37 %
Test  Accuracy  =  34.38 %   +/-  17.4 %
*****Record :  record-[2012.07.06-19.06.14] *****
[ 92.85714286  96.42857143  92.85714286 100.          89.28571429
  92.85714286 100.          92.85714286]
[ 0. 25. 50. 50. 50. 25. 50.  0.]
Train Accuracy  =  94.64 %  +/-  3.57 %
Test  Accuracy  =  31.25 %   +/-  20.73 %
______________Subject :  subject02 ______________
*****Record :  record-[2012.07.19-17.36.23] *****
[ 96.42857143  96.42857143 100.         100.          96.42857143
  92.85714286 100.         100.        ]
[ 0. 25.  0. 25. 25. 25. 25. 25.]
Train Accuracy  =  97.77 %  +/-  2.49 %
Test  Accuracy  =  18.75 %   +/-  10.83 %
*****Record :  record-[2012.07.19-17.41.14] *****
[ 92.85714286  89.28571429  92.85714286 

In [52]:
pd.DataFrame(accuracies_).T

Unnamed: 0,Unnamed: 1,train accuracy,test accuracy
subject01,record-[2012.07.06-19.02.16],94.642857,34.375
subject01,record-[2012.07.06-19.06.14],94.642857,31.25
subject02,record-[2012.07.19-17.36.23],97.767857,18.75
subject02,record-[2012.07.19-17.41.14],92.410714,18.75
subject03,record-[2012.07.11-15.25.23],100.0,43.75
subject03,record-[2012.07.11-15.33.08],99.553571,37.5
subject04,record-[2012.07.18-17.52.30],100.0,31.25
subject04,record-[2012.07.18-17.56.53],97.321429,34.375
subject05,record-[2012.07.19-11.24.02],97.767857,28.125
subject05,record-[2012.07.19-11.28.18],92.410714,40.625
