In [1]:
from collections import OrderedDict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from mne import Epochs, find_events
from mne.decoding import Vectorizer

from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.model_selection import cross_val_score, StratifiedShuffleSplit

from pyriemann.estimation import ERPCovariances
from pyriemann.tangentspace import TangentSpace
from pyriemann.classification import MDM
from pyriemann.spatialfilters import Xdawn

from experiments import eventRelatedPotential
from dataset import brainflowDataset
from utils import plot_conditions

# P300

The P300 is a positive event-related potential (ERP) that occurs around 300ms after perceiving a novel or unexpected 
stimulus. It is most commonly elicited through 'oddball' experimental paradigms, where a certain subtype of stimulus is 
presented rarely amidst a background of another more common type of stimulus. Interestingly, the P300 is able to be 
elicited by multiple sensory modalities (e.g. visual, auditory, somatosensory). Thus, it is believed that the P300 may 
be a signature of higher level cognitive processing such as conscious attention.

## Set up the experiment

In [2]:
# Create the experiment object from the project-local `experiments` module,
# selecting the P300 paradigm through the `erp` argument.
p300_exp = eventRelatedPotential(erp='p300')

## Initialize the EEG signal

In [3]:
# For testing without a hardware connection
p300_exp.initialize_eeg(board_type='synthetic')

# For using the 8-channel Cyton board
#p300_exp.initialize_eeg(board_type='cyton')

# For using the 16-channel Cyton+Daisy combo
#p300_exp.initialize_eeg(board_type='daisy')

## Run Experiment

In [5]:
# Recording parameters for a single run. `subject_name` is reused by the
# dataset-loading cell further down, so keep it defined here.
# NOTE(review): `duration` is presumably in seconds -- confirm against
# eventRelatedPotential.run_trial.
subject_name, duration, trial_num = 'test_subject', 10, 0

# As written this records run 0 only; change `trial_num` and re-run the cell
# to record additional runs.
trial_settings = dict(duration=duration, subject=subject_name, run=trial_num)
p300_exp.run_trial(**trial_settings)

Beginning EEG Stream; Wait 5 seconds for signal to settle... 



## Load the Dataset

In [None]:
# Run indices to load for this subject (0, 1 and 2).
# NOTE(review): the "Run Experiment" cell records only run 0 per execution;
# presumably runs 1 and 2 must be recorded first (re-run it with
# trial_num = 1, 2) -- confirm how load_subject_to_raw treats missing runs.
runs = list(range(3))

# Load the recorded runs through the project's brainflowDataset helper,
# explicitly skipping its built-in preprocessing (preprocess=False).
dataset_p300 = brainflowDataset(subject=subject_name, erp='p300')
raw = dataset_p300.load_subject_to_raw(
    subject_name,
    runs,
    preprocess=False,
)

## Filter the data
The justification for band-pass filtering at 1-16 Hz is taken from the Riemannian Geometric Classifier paper (**TODO:** find and cite the exact reference).

In [None]:
# Band-pass the recording between 1 and 16 using an IIR filter.
# The result is not reassigned, so later cells rely on `raw` itself having been
# filtered; re-running this cell would filter the already-filtered data again.
raw.filter(l_freq=1, h_freq=16, method='iir')

## Epoch the data

In [None]:
# Find the stimulus events in the recording and name the two marker codes.
events = find_events(raw)
event_id = {'Non-Target': 1, 'Target': 2}

# Epoching options: window from tmin=-0.1 to tmax=0.8 around each event, no
# baseline correction, and a rejection threshold of 100e-6 on EEG channels
# (100 uV if the data are in volts, as MNE assumes -- TODO confirm).
epoch_settings = dict(
    events=events,
    event_id=event_id,
    tmin=-0.1,
    tmax=0.8,
    baseline=None,
    reject={'eeg': 100e-6},
    preload=True,
    verbose=False,
)
epochs = Epochs(raw, **epoch_settings)

# Share of detected events that did not end up as an epoch.
kept_fraction = len(epochs.events)/len(events)
print('sample drop %: ', (1 - kept_fraction) * 100)

## Analyze data

#### Epoch Averages

In [None]:
%matplotlib inline
conditions = OrderedDict()
conditions['Non-target'] = [1]
conditions['Target'] = [2]

fig, ax = plot_conditions(epochs, conditions=conditions, 
                                ci=97.5, n_boot=1000, title='',
                                diff_waveform=(1, 2))

### Classify

In [None]:
# Candidate decoding pipelines, keyed by the label used in the results plot.
clfs = OrderedDict([
    ('Vect + LR',
     make_pipeline(Vectorizer(), StandardScaler(), LogisticRegression())),
    ('Vect + RegLDA',
     make_pipeline(Vectorizer(), LDA(shrinkage='auto', solver='eigen'))),
    ('Xdawn + RegLDA',
     make_pipeline(Xdawn(2, classes=[1]), Vectorizer(), LDA(shrinkage='auto', solver='eigen'))),
    ('ERPCov + TS',
     make_pipeline(ERPCovariances(), TangentSpace(), LogisticRegression())),
    ('ERPCov + MDM',
     make_pipeline(ERPCovariances(), MDM())),
    ('ERPCov + RegLDA',
     make_pipeline(ERPCovariances(), LDA(shrinkage='auto', solver='eigen'))),
])

# Format the data: keep EEG channels only (this call is not reassigned, so it
# acts on `epochs` itself), scale the samples by 1e6 (presumably volts ->
# microvolts -- TODO confirm), and read the marker code of every epoch from
# the last column of the events array.
epochs.pick_types(eeg=True)
X = epochs.get_data() * 1e6
times = epochs.times
y = epochs.events[:, -1]

# Cross-validation scheme: 10 stratified random splits, 25% held out each
# time, with a fixed seed so the splits are repeatable.
cv = StratifiedShuffleSplit(n_splits=10, test_size=0.25, random_state=42)

# Score every pipeline. The label is `y == 2`, i.e. marker code 2 is the
# positive class for the ROC AUC.
auc, methods = [], []
for name, pipeline in clfs.items():
    fold_scores = cross_val_score(pipeline, X, y==2, scoring='roc_auc', cv=cv, n_jobs=-1)
    auc.extend(fold_scores)
    methods.extend([name]*len(fold_scores))

# One row per (pipeline, split) score.
results = pd.DataFrame({'AUC': auc, 'Method': methods})

# Horizontal bar chart of AUC per pipeline.
# NOTE(review): the fixed x-limits hide anything outside 0.2-0.85, so bars for
# pipelines scoring above 0.85 are cut off at the edge -- confirm this zoom is
# intended.
plt.figure(figsize=(8, 4))
sns.barplot(data=results, x='AUC', y='Method')
plt.xlim(0.2, 0.85)
sns.despine()