In [1]:
from sklearn.model_selection import ShuffleSplit, cross_val_score
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.pipeline import Pipeline

from mne.decoding import CSP
from mne.channels import read_layout
from loaddatamat import *

print(__doc__)

Automatically created module for IPython interactive environment


In [2]:
fp = {
    'al': '../DATA/Competencia BCI/III/Dataset_IV/mat/data_set_IVa_al.mat',
}

pick_chan = {
    'al': ['C3', 'Cz', 'C5'],
}

low_freq, high_freq = 7., 30.
tmin, tmax = 0., 3.5

# event_id
event_id = {'right': 1, 'foot': 2}

In [3]:
#Se carga set de datos train
raw = creatRawArray(fp['al'])
events, labels = creatEventsArray(fp['al'])

#raw.plot(block=True, scalings='auto', n_channels=1)

# Se aplica filtros band-pass
raw.filter(low_freq, high_freq, fir_design='firwin', skip_by_annotation='edge')

#raw.plot(block=True, scalings='auto', n_channels=1)
print(events)

Creating RawArray with float64 data, n_channels=118, n_times=283574
    Range : 0 ... 283573 =      0.000 ...  2835.730 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 7 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 7.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 6.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 165 samples (1.650 sec)

[[  5171      0      1]
 [  5748      0      1]
 [  7385      0      2]
 [  9083      0      1]
 [  9623      0      1]
 [ 10183      0      1]
 [ 10763      0      2]
 [ 12448      0      2]
 [ 13006      0      1]
 [ 13547      0      2]
 [ 14649      0      2]
 [ 15221      0      1]
 [ 15774      0   

In [4]:
# event_train = eventsTrain(fp[f])
epochs = mne.Epochs(raw, events=events, event_id=event_id, tmin=tmin, tmax=tmax, baseline=None, preload=True,
                    verbose=False)

print(events)
epochs_train = epochs.copy().crop(tmin=1., tmax=2.)
labels = epochs.events[:, -1] - 2
#print(epochs)
#print(epochs_train)
#print(len(epochs_train.events))


#print(len(epochs.events))


[[  5171      0      1]
 [  5748      0      1]
 [  7385      0      2]
 [  9083      0      1]
 [  9623      0      1]
 [ 10183      0      1]
 [ 10763      0      2]
 [ 12448      0      2]
 [ 13006      0      1]
 [ 13547      0      2]
 [ 14649      0      2]
 [ 15221      0      1]
 [ 15774      0      2]
 [ 16873      0      2]
 [ 17450      0      1]
 [ 18561      0      2]
 [ 21816      0      1]
 [ 22364      0      2]
 [ 25122      0      2]
 [ 25662      0      1]
 [ 27320      0      1]
 [ 27855      0      1]
 [ 28425      0      1]
 [ 28987      0      2]
 [ 29534      0      2]
 [ 30087      0      1]
 [ 30621      0      2]
 [ 31201      0      2]
 [ 31779      0      1]
 [ 32327      0      2]
 [ 32889      0      2]
 [ 33995      0      2]
 [ 34544      0      1]
 [ 37220      0      1]
 [ 37796      0      1]
 [ 38358      0      2]
 [ 38941      0      1]
 [ 40055      0      2]
 [ 40632      0      1]
 [ 42884      0      1]
 [ 43993      0      1]
 [ 44552      0 

In [5]:
# Define a monte-carlo cross-validation generator (reduce variance):
scores = []
epochs_data = epochs.get_data()
epochs_data_train = epochs_train.get_data()

#divide el set en train 
cv = ShuffleSplit(5, test_size=0.2, random_state=42)
cv_split = cv.split(epochs_data_train)

# Assemble a classifier
lda = LinearDiscriminantAnalysis()
csp = CSP(n_components=len(epochs.ch_names), reg=None, log=True, norm_trace=False)

# Use scikit-learn Pipeline with cross_val_score function
clf = Pipeline([('CSP', csp), ('LDA', lda)])
scores = cross_val_score(clf, epochs_data_train, labels, cv=cv, n_jobs=1, verbose=False)

Computing rank from data with rank=None
    Using tolerance 1.8e+03 (2.2e-16 eps * 118 dim * 6.9e+16  max singular value)
    Estimated rank (mag): 118
    MAG: rank 118 computed from 118 data channels with 0 projectors
Reducing data rank from 118 -> 118
Estimating covariance using EMPIRICAL
Done.
Computing rank from data with rank=None
    Using tolerance 1.8e+03 (2.2e-16 eps * 118 dim * 6.9e+16  max singular value)
    Estimated rank (mag): 118
    MAG: rank 118 computed from 118 data channels with 0 projectors
Reducing data rank from 118 -> 118
Estimating covariance using EMPIRICAL
Done.
Computing rank from data with rank=None
    Using tolerance 1.8e+03 (2.2e-16 eps * 118 dim * 7.1e+16  max singular value)
    Estimated rank (mag): 118
    MAG: rank 118 computed from 118 data channels with 0 projectors
Reducing data rank from 118 -> 118
Estimating covariance using EMPIRICAL
Done.
Computing rank from data with rank=None
    Using tolerance 1.8e+03 (2.2e-16 eps * 118 dim * 6.9e+16  m

In [6]:
# Printing the results
print(scores)
class_balance = np.mean(labels == labels[0])
print(class_balance)
class_balance = max(class_balance, 1. - class_balance)
print("Classification accuracy: %f / Chance level: %f" % (np.mean(scores), class_balance))

# plot CSP patterns estimated on full data for visualization
csp.fit_transform(epochs_data, labels)

[0.86666667 0.86666667 0.88888889 0.82222222 0.95555556]
0.5
Classification accuracy: 0.880000 / Chance level: 0.500000
Computing rank from data with rank=None
    Using tolerance 3.6e+03 (2.2e-16 eps * 118 dim * 1.4e+17  max singular value)
    Estimated rank (mag): 118
    MAG: rank 118 computed from 118 data channels with 0 projectors
Reducing data rank from 118 -> 118
Estimating covariance using EMPIRICAL
Done.
Computing rank from data with rank=None
    Using tolerance 3.5e+03 (2.2e-16 eps * 118 dim * 1.3e+17  max singular value)
    Estimated rank (mag): 118
    MAG: rank 118 computed from 118 data channels with 0 projectors
Reducing data rank from 118 -> 118
Estimating covariance using EMPIRICAL
Done.


array([[-2.65046877, -2.13129238, -1.90507502, ..., -0.75593205,
        -0.70342905, -0.92328317],
       [-1.74802089, -2.07887   , -1.73949991, ..., -0.63707919,
        -0.49402031, -0.8406706 ],
       [-0.26542195, -0.60702867, -0.00924283, ..., -0.75505953,
        -0.49000489, -0.83445092],
       ...,
       [-2.47295022, -1.77263681, -1.78428844, ..., -0.65181098,
        -0.84794432, -0.81582713],
       [-2.21593186, -1.98774539, -1.54457096, ..., -0.49382871,
        -0.58332832, -0.59852983],
       [-0.21388883, -0.10332413, -0.36824982, ..., -0.70319295,
        -0.6511371 , -0.43434828]])