# EEG based Brain-Computer Interface (BCI) using Visual Imagery 

## MSc Project for Computational Cognitive Neuroscience 2020/2021

A brief description of the dataset has been included in Datasets/Description.txt

## Processing steps

1. Load Datasets
2. Exclude unwanted channels 
3. Apply a bandpass FIR filter using Hamming window to the raw signal (between 1Hz-40Hz) 
4. Create Epochs
5. ICA
6. Remove bad epochs
7. Time-Frequency Analysis
8. Temporal Decoding
9. Temporal Generalisation

### Import Libraries

In [None]:
%%capture libraries   
import sys
import os
!{sys.executable} -m pip install numpy
!{sys.executable} -m pip install mne
!{sys.executable} -m pip install mne-features
import numpy as np
import matplotlib 
import pathlib
import mne
from mne.io import read_raw_edf
from mne.preprocessing import ICA, create_eog_epochs, create_ecg_epochs,corrmap
from mne.time_frequency import tfr_morlet, psd_multitaper, psd_welch, tfr_stockwell
matplotlib.use('Qt5Agg') #allow interactive plots
from sklearn.preprocessing import StandardScaler
from mne.decoding import SlidingEstimator
from sklearn.pipeline import make_pipeline
from mne.decoding import Scaler, Vectorizer, cross_val_multiscore
from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt
from mne.decoding import GeneralizingEstimator, Scaler,cross_val_multiscore, LinearModel, get_coef, Vectorizer, CSP


### 1. Loading EEG Datasets

The code below defines a function to load all data files with extensions '.edf':

In [None]:
def load_data(path, reader=None):
    '''
    Load every .edf dataset found in a directory.

    Parameters
    ----------
    path : str or path-like
        The directory where the .edf files are stored.
    reader : callable, optional
        Function used to load each file. It is called as
        ``reader(file_path, preload=True)``. Defaults to ``read_raw_edf``.

    Returns
    -------
    list_load_dataset : list
        One loaded dataset per .edf file, sorted by file name.
    '''
    if reader is None:
        reader = read_raw_edf

    # Match on the file suffix (not a substring anywhere in the name), and sort
    # so that the session order is reproducible across operating systems.
    edf_files = sorted(
        name for name in os.listdir(path) if name.lower().endswith('.edf')
    )

    # os.listdir returns bare file names: join them with the directory so that
    # loading works even when `path` is not the current working directory.
    return [reader(os.path.join(path, name), preload=True) for name in edf_files]

In [None]:
# Apply the function: load every .edf file stored in the current working directory.
data_dir = os.getcwd()
raw_datasets = load_data(data_dir)

In [None]:
n_sessions = len(raw_datasets)  # one loaded dataset per session
print('Overall, we have', n_sessions, 'experimental sessions')

To inspect the information available for a given dataset, run the following cell:

In [None]:
first_info = raw_datasets[0].info  # measurement info of the first session
print(first_info)

In [None]:
# Read the dimensions from the Raw object itself: calling get_data() three times
# would copy the whole recording into a new array on every call.
first_raw = raw_datasets[0]
n_channels, n_timepoints = len(first_raw.ch_names), first_raw.n_times
print('The shape of the following dataset:', (n_channels, n_timepoints), 'indicates that we have', n_channels,
      'channels and', n_timepoints, 'timepoints')

Visually explore the first raw unfiltered dataset:

In [None]:
# Interactive browser of the first raw, unfiltered session.
first_raw = raw_datasets[0]
first_raw.plot()

### 2. Exclude unused channels

The function below iterates through each dataset and removes the unwanted channels (i.e. every channel not listed in `include_channels`):

In [None]:
# Channels to keep for the analysis; every other channel is dropped by excl_chan.
include_channels = [
    'AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1',
    'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4',
]
#reference_channels = ['CQ_CMS', 'CQ_DRL']

In [None]:
def excl_chan(data, channels=None):
    '''
    Remove the channels we don't need from every dataset.

    Each dataset is modified in place (via ``drop_channels``) and also
    returned. To add or remove channels, modify the list "include_channels"
    (or pass ``channels`` explicitly).

    Parameters
    ----------
    data : list
        Our raw datasets.
    channels : list of str, optional
        Names of the channels to keep. Defaults to the module-level
        ``include_channels`` list.

    Returns
    -------
    list_datasets : list
        The same dataset objects, each appearing once, with unused channels removed.
    '''
    keep = set(include_channels if channels is None else channels)

    list_datasets = []
    for dataset in data:
        # Collect the unwanted names first and drop them in a single call,
        # instead of dropping one channel at a time while iterating ch_names.
        to_drop = [name for name in dataset.ch_names if name not in keep]
        if to_drop:
            dataset.drop_channels(to_drop)
        list_datasets.append(dataset)

    return list_datasets

In [None]:
# Apply the function: the unused channels are dropped from every dataset in
# raw_datasets in place (the cell below checks raw_datasets[0] directly).
excl_chan(raw_datasets)

Double-check the new list of channels:

In [None]:
remaining_channels = raw_datasets[0].ch_names  # channels left after excl_chan
print(remaining_channels)

Inspect the Power Spectrum Density (PSD) of the first unfiltered dataset:

In [None]:
# Per-channel PSD of the first session before filtering.
first_raw = raw_datasets[0]
first_raw.plot_psd(average=False)

 ### 3. Filtering datasets

The following function iterates through each dataset and applies a Hamming-windowed FIR band-pass filter (MNE's default FIR window):

In [None]:
 def filter_data(data):
    
    '''
    This function filter the raw datasets. Because we are interested in low frequencies, 
    in the alpha-beta frequency range, we can band-pass filter between 1Hz-40Hz.
    
    Parameters
    ----------
    data: our raw continuous unfiltered datasets
    
    Returns
    -------
    filtered_data: the datasets containing the filtered data.
    '''


    filtered_data=[]
    for file in range(0, len(data)):
        #data[file].notch_filter(freqs=(50), filter_length='auto', phase='zero') #uncomment to apply notch filter to remove power-noise
        #data[file].filter(0.16, None, fir_design='firwin') #high-pass filter to remove slow drifts
        data[file].filter(1., 40., fir_design='firwin')  #apply band-pass filter between 1 and 40 HZ, our freqs range of interest
        filtered_data.append(data[file])
        
    return filtered_data
        

In [None]:
# Apply the function: every dataset in raw_datasets is band-pass filtered
# (1-40 Hz) in place, so raw_datasets now holds the filtered data.
filter_data(raw_datasets)

Now inspect the PSD of a filtered dataset:

In [None]:
# Per-channel PSD of the first session after band-pass filtering.
first_raw = raw_datasets[0]
first_raw.plot_psd(average=False)

In [None]:
#Function to plot the PSD for each session to visually inspect them:

#def plot_data(data):   #the input is our filtered dataset
#    for file in range(0, len(data)):
#        data[file].plot_psd(average=True)   


#plot_data(raw_datasets)  #call the function to plot 30 PSD

### 4. Create Epochs

The function below epochs the continuous data into 10-second segments, alternately labelled Relax and Push, and then crops the first and last 250 ms of each epoch (to avoid potential overlap with neighbouring events), leaving 9.498-second segments.
With 10 epochs per session and 30 sessions, we will have 300 epochs in total.

In [None]:
def make_epochs(list_prep_dataset, duration, start=65.0, stop=165.0, crop=0.25):
    """
    Epoch each preprocessed dataset and concatenate all sessions into one mne.Epochs object.

    Fixed-length events are placed every ``duration`` seconds between ``start``
    and ``stop``; the conditions alternate, starting with Relax:
    Relax = 0
    Push = 1
    --> Note: This function should be called after the preprocessing, but before the ICA <--

    :param list_prep_dataset: a list containing the preprocessed datasets
    :param duration: the duration of each epoch and the spacing between events, in seconds (10 s here)
    :param start: time (s) of the first event in each recording
    :param stop: time (s) at which event generation stops in each recording
    :param crop: seconds removed from both the start and the end of every epoch
    :return: epochs: the mne.Epochs object containing the epoched data of all sessions.
    :raises ValueError: if list_prep_dataset is empty.
    """
    if len(list_prep_dataset) == 0:
        raise ValueError('make_epochs needs at least one dataset')

    event_dict = {'Relax': 0, 'Push': 1}
    # Each epoch ends just before the next event (9.998 s for 10 s spacing),
    # so consecutive epochs do not overlap.
    tmax = duration - 0.002

    list_epochs = []
    for prep_dataset in list_prep_dataset:
        # make fixed-length events for each dataset, all initially labelled Relax
        events = mne.make_fixed_length_events(prep_dataset, id=event_dict['Relax'], start=start,
                                              stop=stop, duration=duration)
        # the conditions alternate: every second event is a Push trial
        events[1::2, 2] = event_dict['Push']

        epochs = mne.Epochs(prep_dataset, events, tmin=0.0, tmax=tmax, event_id=event_dict,
                            baseline=(0, 0), preload=True)
        list_epochs.append(epochs)

    # combine the epochs of all sessions once, after the loop
    epochs = mne.concatenate_epochs(list_epochs)

    # crop the start and end of every epoch
    epochs.crop(tmin=crop, tmax=tmax - crop)

    # Standard 10-20 montage (useful for ICA and time-frequency analyses)
    montage = mne.channels.make_standard_montage('standard_1020')
    epochs.set_montage(montage)

    return epochs

In [None]:
#Apply the function

epoched_data = make_epochs(raw_datasets, 10)
# Independent copy used to compare the data before and after epoch rejection.
# Copying is equivalent to re-running make_epochs on the same datasets, but
# avoids repeating the whole epoching pipeline.
epoched_data_copy = epoched_data.copy()


In [None]:
# np.shape() on an Epochs object does not give the data dimensions;
# ask for the data array shape explicitly: (n_epochs, n_channels, n_times).
epoched_data.get_data().shape

Plot the PSD for the two conditions: Relax and Push

In [None]:
# PSD of each condition, one figure per condition.
relax_epochs = epoched_data['Relax']
push_epochs = epoched_data['Push']
relax_epochs.plot_psd()
push_epochs.plot_psd()

Compute evoked responses by averaging each epochs' conditions and plot it:

In [None]:
# Average the epochs of each condition into an evoked response.
evoked_relax = epoched_data['Relax'].average()
evoked_push = epoched_data['Push'].average()
# Butterfly plot of each channel type.
evoked_relax.plot(titles='Relax Condition')
evoked_push.plot(titles='Push Condition')


### 5. Apply Independent Component Analysis (ICA) to remove artifacts

In [None]:
#Apply ICA to all epochs

picks = raw_datasets[0].info['ch_names']  # channels to include in the decomposition
# One component per picked channel (14 with the current channel list); a
# hard-coded count would break as soon as include_channels changes.
ica = ICA(n_components=len(picks), method='fastica', max_iter=1000, random_state=89)
# Epochs whose peak-to-peak amplitude exceeds 200 µV are ignored while fitting.
ica.fit(epoched_data, picks=picks, reject=dict(eeg=200e-6))

In [None]:
# Plot the topography of every fitted component (n_components_ is set by ica.fit).
ica.plot_components(picks=range(ica.n_components_), inst=epoched_data)

After inspecting each component individually, also explore their time courses:

In [None]:
# Time course of every ICA component across the epochs.
ica.plot_sources(inst=epoched_data)

Identify the components to remove:

In [None]:
# Components to remove (eye movements and heartbeat), chosen by visually
# inspecting the topographies and time courses above. The indices refer to the
# ICA fitted above; re-check them whenever the fit changes.
ica.exclude=[0,1,13]

In [None]:
# Remove all the components listed in ica.exclude (three of them) from epoched_data.
ica.apply(epoched_data, exclude=ica.exclude)

### 6. Reject Bad Epochs

Before rejecting epochs based on their peak-to-peak (PTP) amplitude, let's investigate in which temporal segments the
PTP amplitude is most likely to exceed the threshold.

In [None]:
# For example, let's explore the epoch at index 9 (i.e. the tenth epoch):
epo = epoched_data.get_data()[9]  # (n_channels, n_times); try also epochs 18, 50, 51, 52, 76...

window_size = 243  # samples per window (~0.95 s, assuming the 256 Hz rate used in Part 2)

# Split every channel into consecutive, non-overlapping windows of
# `window_size` samples; trailing samples that do not fill a whole window are ignored.
n_channels, n_times = epo.shape
n_windows = n_times // window_size
windows = epo[:, :n_windows * window_size].reshape(n_channels, n_windows, window_size)

# PTP amplitude of every window, ordered channel by channel (all windows of
# channel 0, then channel 1, ...). With 14 channels we will have 140 values.
ampl = np.ptp(windows, axis=-1).ravel().tolist()
# Start sample of each window, in the same order as `ampl`.
idx = np.tile(np.arange(n_windows) * window_size, n_channels).tolist()

print(len(ampl))  # check you have 140 PTP values (i.e. 10 PTP values for each channel)

In [None]:
# PTP rejection threshold: 200 µV, expressed in volts like the data.
thr = 200e-6
ptp_values = np.asarray(ampl)
bad_segments = np.argwhere(ptp_values > thr)
# Positions in `ampl` whose amplitude exceeds the threshold; use them below for plotting.
print(bad_segments)

In [None]:
#Define the t_start and t_end points for plotting:

t_start=idx[26] #use the bad_segments value to retrieve the corresponding idx and plotting the segment
t_end=t_start+243 #define the end of the bad segment

In [None]:
#Plot them
plt.plot(epo[0,t_start:t_end]) #select the first channel
plt.xlabel('Samples') #243 samples correspond to 10 seconds
plt.ylabel('uV')
plt.show()


After exploring which individual segments exceed a given value, we can reject, using MNE tools, every epoch whose peak-to-peak (PTP) amplitude exceeds 200 µV (with this threshold, about 10% of the epochs are rejected):

In [None]:
reject_criteria = dict(eeg=200e-6)  # 200 µV PTP rejection threshold
flat_criteria = dict(eeg=1e-6)  # 1 µV, minimum acceptable peak-to-peak amplitude
epoched_data.drop_bad(reject=reject_criteria, flat=flat_criteria)

print(epoched_data.get_data().shape)  # new data shape: (n_epochs, n_channels, n_times)

# drop_log holds one tuple of rejection reasons per original epoch; an empty
# tuple means the epoch was kept. Count the non-empty ones explicitly.
n_dropped = sum(1 for reasons in epoched_data.drop_log if reasons)
print('Rejected', n_dropped, 'of', len(epoched_data.drop_log), 'epochs')
print(epoched_data.drop_log)  # the reason(s) each epoch was rejected

In [None]:
epoched_data.plot_drop_log() # plot the percentage of epochs rejected, broken down by rejection reason

### 7. Time-Frequencies Analysis (TFR) and Inter-Trial Coherence (ITC)

Compute time-frequency representations (TFRs) from our epoched data:
 

In [None]:
# 40 log-spaced frequencies of interest between 4 and 30 Hz.
freqs = np.logspace(*np.log10([4, 30]), num=40)
# Number of cycles proportional to frequency: every wavelet spans the same
# duration (n_cycles / freq = 0.5 s).
n_cycles = freqs / 2.

tfr_params = dict(freqs=freqs, n_cycles=n_cycles, use_fft=True, average=True,
                  return_itc=True, decim=3, n_jobs=1)

# Power and ITC for the RELAX condition, then for the PUSH condition.
power_r, itc_r = mne.time_frequency.tfr_morlet(epoched_data['Relax'], **tfr_params)
power_p, itc_p = mne.time_frequency.tfr_morlet(epoched_data['Push'], **tfr_params)

Average the (trial-averaged) power over the frequencies of the Alpha, Beta and Theta bands, for each channel and condition:

In [None]:
# Average the TFR power over the frequencies of each band, separately for every
# channel. Each list holds one time course per channel, in TFR channel order.
# Band edges: theta 4 <= f < 8 Hz, alpha 8 <= f <= 12 Hz, beta 12 < f <= 30 Hz.

def band_average(tfr, low, high, include_low=True, include_high=True):
    '''Return a list with, per channel, tfr.data averaged over the [low, high] Hz band.'''
    freqs = tfr.freqs
    above = (freqs >= low) if include_low else (freqs > low)
    below = (freqs <= high) if include_high else (freqs < high)
    # tfr.data is indexed (channel, frequency, time): average over frequency.
    return list(tfr.data[:, above & below].mean(axis=1))


########################## RELAX #############
relax_pow_a = band_average(power_r, 8, 12)                      # alpha band
relax_pow_b = band_average(power_r, 12, 30, include_low=False)  # beta band
relax_pow_t = band_average(power_r, 4, 8, include_high=False)   # theta band

######################### PUSH ###########
push_pow_a = band_average(power_p, 8, 12)                       # alpha band
push_pow_b = band_average(power_p, 12, 30, include_low=False)   # beta band
push_pow_t = band_average(power_p, 4, 8, include_high=False)    # theta band

In [None]:
# Uncomment and run the lines below to double-check if the right frequency range is being selected:

#print(power_r.freqs[np.argwhere((power_r.freqs>=8) & (power_r.freqs<=12)).flatten()]) #alpha band
#print(power_r.freqs[np.argwhere((power_r.freqs>12) & (power_r.freqs<=30)).flatten()]) #beta band

Plot the power difference between Relax vs Push condition:

In [None]:
# Relax vs Push band power over time for selected channels.
# The band-averaged power is linear, so it is converted to dB (10*log10) to match
# the y-axis label, and it is plotted against the TFR time axis in seconds.

def plot_pair(times, first, second, labels, title, ylabel):
    '''Plot two time courses on the same axes, with legend, axis labels and title.'''
    plt.plot(times, first)
    plt.plot(times, second)
    plt.legend(labels)
    plt.xlabel('Time[s]')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.show()


band_plots = (
    ('Alpha', relax_pow_a, push_pow_a, ('F3', 'O1')),
    ('Beta', relax_pow_b, push_pow_b, ('F3', 'O1')),
    ('Theta', relax_pow_t, push_pow_t, ('F3',)),
)
for band, relax_band, push_band, channels in band_plots:
    for ch_name in channels:
        ch = power_r.ch_names.index(ch_name)  # look channels up by name, not by hard-coded index
        plot_pair(power_r.times,
                  10 * np.log10(relax_band[ch]),
                  10 * np.log10(push_band[ch]),
                  ["Relax", "Push"],
                  f'Channel {ch_name} - {band} Band',
                  'Power[dB]')



#### ITC

Compute the ITC averaged over the Alpha, Beta and Theta band frequencies, for each channel and both conditions:

In [None]:
# Average the ITC over the frequencies of each band, separately for every
# channel. Each list holds one time course per channel, in TFR channel order.
# Band edges: theta 4 <= f < 8 Hz, alpha 8 <= f <= 12 Hz, beta 12 < f <= 30 Hz.

def band_average(tfr, low, high, include_low=True, include_high=True):
    '''Return a list with, per channel, tfr.data averaged over the [low, high] Hz band.'''
    freqs = tfr.freqs
    above = (freqs >= low) if include_low else (freqs > low)
    below = (freqs <= high) if include_high else (freqs < high)
    # tfr.data is indexed (channel, frequency, time): average over frequency.
    return list(tfr.data[:, above & below].mean(axis=1))


########################## RELAX #############
relax_itc_a = band_average(itc_r, 8, 12)                       # alpha band
relax_itc_b = band_average(itc_r, 12, 30, include_low=False)   # beta band
relax_itc_t = band_average(itc_r, 4, 8, include_high=False)    # theta band

########################## PUSH #############
push_itc_a = band_average(itc_p, 8, 12)                        # alpha band
push_itc_b = band_average(itc_p, 12, 30, include_low=False)    # beta band
push_itc_t = band_average(itc_p, 4, 8, include_high=False)     # theta band

Plot the ITC for a given channel, relax vs push:

In [None]:
# ITC over time at channel F3, Relax vs Push, plotted against the TFR time axis.

def plot_pair(times, first, second, labels, title, ylabel):
    '''Plot two time courses on the same axes, with legend, axis labels and title.'''
    plt.plot(times, first)
    plt.plot(times, second)
    plt.legend(labels)
    plt.xlabel('Time[s]')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.show()


f3 = itc_r.ch_names.index('F3')
plot_pair(itc_r.times, relax_itc_a[f3], push_itc_a[f3], ["Relax", "Push"], 'Channel F3 - Alpha Band', 'ITC')
plot_pair(itc_r.times, relax_itc_b[f3], push_itc_b[f3], ["Relax", "Push"], 'Channel F3 - Beta Band', 'ITC')

Compare the ITC for the same condition across channels:

In [None]:
# Compare the ITC of two channels within the same condition.
# Channels are looked up by name so that each legend always matches the plotted data.

def plot_pair(times, first, second, labels, title, ylabel):
    '''Plot two time courses on the same axes, with legend, axis labels and title.'''
    plt.plot(times, first)
    plt.plot(times, second)
    plt.legend(labels)
    plt.xlabel('Time[s]')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.show()


itc_comparisons = (
    # (condition, band, per-channel ITC list, (first channel, second channel))
    ('Relax', 'Alpha', relax_itc_a, ('FC5', 'O2')),
    ('Relax', 'Alpha', relax_itc_a, ('F3', 'O1')),
    ('Relax', 'Beta', relax_itc_b, ('FC5', 'O2')),
    ('Relax', 'Beta', relax_itc_b, ('F3', 'O1')),
    ('Relax', 'Theta', relax_itc_t, ('FC5', 'P7')),
    ('Relax', 'Theta', relax_itc_t, ('F3', 'P7')),
    ('Push', 'Alpha', push_itc_a, ('FC5', 'O2')),
    ('Push', 'Alpha', push_itc_a, ('F3', 'O1')),
    ('Push', 'Beta', push_itc_b, ('FC5', 'O2')),
    ('Push', 'Beta', push_itc_b, ('F3', 'O1')),
    ('Push', 'Theta', push_itc_t, ('F3', 'O1')),
)
for condition, band, itc_list, (ch_a, ch_b) in itc_comparisons:
    i_a = itc_r.ch_names.index(ch_a)
    i_b = itc_r.ch_names.index(ch_b)
    plot_pair(itc_r.times, itc_list[i_a], itc_list[i_b],
              [f'{condition} {ch_a}', f'{condition} {ch_b}'],
              f'ITC - {band} Band', 'ITC')



## Multivariate Pattern Analysis (MVPA)

### 8. Decoding over time 
#### Fit the classifier at every single time point to see at which time points  it can discriminate between the two conditions (through Logistic Regression).

In [None]:
#Create X and y.

X = epoched_data.get_data()  # (n_epochs, n_channels, n_times)
y = epoched_data.events[:, 2]  # condition codes: 0 = Relax, 1 = Push

# Classifier pipeline: standardise the channel values, then logistic regression.
clf = make_pipeline(StandardScaler(),
                    LogisticRegression())

scoring = 'roc_auc'
# The "sliding estimator" trains and scores the classifier separately at each time point.
time_decoder = SlidingEstimator(clf, scoring=scoring, n_jobs=1, verbose=True)

# Run cross-validation with the declared number of folds.
n_splits = 5
scores = cross_val_multiscore(time_decoder, X, y, cv=n_splits, n_jobs=1)  # (n_splits, n_times)

# Mean scores across cross-validation splits, for each time point.
mean_scores = np.mean(scores, axis=0)

# Mean score across all time points.
mean_across_all_times = round(np.mean(scores), 3)
print(f'\n=> Mean CV score across all time points: {mean_across_all_times:.3f}')

Plot the result:

In [None]:
fig = plt.figure()
ax = fig.add_subplot(111)

# Reference lines: chance level (AUC = 0.5) and time point zero.
ax.axhline(0.5, color='k', linestyle='--', label='chance')
ax.axvline(0, color='k', linestyle='-')
# Cross-validated ROC AUC at each time point.
ax.plot(epoched_data.times, mean_scores, label='score')

ax.set(xlabel='Time (s)', ylabel='Mean ROC AUC')
ax.legend()
ax.set_title('Relax vs Push')
fig.suptitle('Sensor Space Decoding')
plt.show()

### 9. Temporal generalization


Evaluate whether the model estimated at a particular time instant accurately predicts any other time instant. 


In [None]:

X = epoched_data.get_data()  # (n_epochs, n_channels, n_times)
y = epoched_data.events[:, 2]  # condition code of each epoch

# Temporal generalization: a classifier trained at each time point is tested
# at every other time point.
time_gen = GeneralizingEstimator(clf, n_jobs=1, scoring='roc_auc', verbose=True)

# Cross-validated scores, then their mean over the folds.
fold_scores = cross_val_multiscore(time_gen, X, y, cv=5, n_jobs=1)
scores = fold_scores.mean(axis=0)  # (training times, testing times)

In [None]:
# Plot the diagonal: training and testing at the same time point
# (equivalent to the time-by-time decoding above).
fig, ax = plt.subplots()
ax.plot(epoched_data.times, np.diag(scores), label='score')
ax.axhline(.5, color='k', linestyle='--', label='chance')
ax.set_xlabel('Time (s)')
ax.set_ylabel('AUC')
ax.legend()
ax.axvline(.0, color='k', linestyle='-')
ax.set_title('Decoding EEG sensors over time')
plt.show()

In [None]:
#Plot the full (generalization) matrix:

# ROC AUC lies in [0, 1]: with vmin=0 and vmax=1 the diverging colormap is
# centred on chance level (0.5). A vmax above 1 would squeeze all scores into
# the lower part of the colour scale.
fig, ax = plt.subplots(1, 1)
im = ax.imshow(scores, interpolation='lanczos', origin='lower', cmap='RdBu_r',
               extent=epoched_data.times[[0, -1, 0, -1]], vmin=0., vmax=1.)
ax.set_xlabel('Testing Time (s)')
ax.set_ylabel('Training Time (s)')
ax.set_title('Temporal generalization')
ax.axvline(0, color='k')
ax.axhline(0, color='k')
plt.colorbar(im, ax=ax)
plt.show()

#### Part 2.

In [None]:

## Create a single continuous epoch spanning the whole first recording:

raw = raw_datasets[0]
dur = len(raw) / raw.info['sfreq']  # recording length in seconds, from the real sampling rate
epoch = mne.make_fixed_length_epochs(raw, duration=dur, preload=True)
epoch.crop(tmin=65.0, tmax=165.0)  # keep the same 65-165 s window used for the trials
print(epoch.get_data().shape)  # (n_epochs, n_channels, n_times)

montage = mne.channels.make_standard_montage('standard_1020')
epoch.set_montage(montage)


#Apply ICA to the epoch

picks = raw.info['ch_names']  # channels to include in the decomposition
ica = ICA(n_components=len(picks), method='fastica', max_iter=1000, random_state=89)
ica.fit(epoch, picks=picks, reject=dict(eeg=200e-6))

#ica.plot_components(picks=range(ica.n_components_), inst=epoch)  # plot the components

# NOTE(review): these indices were picked for the ICA fitted on the trial
# epochs above; this is a new fit on different data, so confirm that they
# still correspond to eye movements and heartbeat before removing them.
ica.exclude = [0, 1, 13]
ica.apply(epoch, exclude=ica.exclude)  # remove all components listed in ica.exclude

#Reject bad epochs:

reject_criteria = dict(eeg=200e-6)  # 200 µV PTP rejection threshold
flat_criteria = dict(eeg=1e-6)  # 1 µV, minimum acceptable peak-to-peak amplitude
epoch.drop_bad(reject=reject_criteria, flat=flat_criteria)

print(epoch.get_data().shape)  # new data shape
print(epoch.drop_log)  # rejection reasons, one tuple per original epoch

if len(epoch) == 0:
    raise RuntimeError('The continuous epoch was rejected; the TFR below needs at least one epoch.')

#TFR

freqs = np.logspace(*np.log10([4, 30]), num=40)  # log-spaced frequencies of interest
n_cycles = freqs / 2.  # different number of cycles per frequency

# Power of each epoch (average=False keeps the epoch dimension).
power = mne.time_frequency.tfr_morlet(epoch, freqs=freqs, n_cycles=n_cycles,
                                      use_fft=True, average=False,
                                      return_itc=False, decim=3, n_jobs=1)

# Alpha-band (8-12 Hz) power of the first epoch, averaged over frequency, per channel.
alpha = (power.freqs >= 8) & (power.freqs <= 12)
epo_pow = list(power.data[0][:, alpha].mean(axis=1))

#Plot

o1 = power.ch_names.index('O1')
plt.plot(power.times, epo_pow[o1])
plt.xlabel('Time[s]')
plt.ylabel('Power')
plt.title('Channel O1 - Alpha Band')
plt.show()