# Exploration of spatial filtering

This notebook is an exploration of spatial filtering. The goal is to understand how spatial filtering works and how it can be used to enhance multichannel EEG signals before classification.

In [1]:
!pip install pyriemann

Note: you may need to restart the kernel to use updated packages.


**Exploration using techniques from BE_MI_Filled adapted to our data**

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from src.preprocessing.train import process_train_data


# --- Data location -----------------------------------------------------------
PATH_DATA_FOLDER = './data'
PATH_DATA_RAW_FOLDER = f'{PATH_DATA_FOLDER}/raw'
PATH_DATA_001_TRIAL1_FILE = f'{PATH_DATA_RAW_FOLDER}/DATA_001_Trial1.npy'

# Raw recording for subject 001, trial 1.
# allow_pickle=True lets np.load deserialize pickled object arrays; only use it
# on trusted files.
array_001_t1 = np.load(PATH_DATA_001_TRIAL1_FILE, allow_pickle=True)

# --- Processing parameters (fed to process_train_data) ------------------------
# NOTE(review): units are presumably Hz for fs / low_freq / high_freq and
# seconds for the epoch, window and duration values; 'G' presumably means
# "gauche" (left). Confirm against src.preprocessing.train.
side = 'G'
fs = 1024
low_freq, high_freq = 10, 50
epoch_start_time, epoch_end_time = -1, 1
window_size, step_size = 0.5, 0.1
pre_duration, post_duration = 0.5, 0

# TODO: read the MNE documentation (ICA, preprocessing, the Epochs class, ...).
# TODO: review and reuse Julien's train.py: ICA for denoising, then blind source
#       separation (BSS) to work on sources rather than on electrodes.
# TODO: check xDAWN in MNE.


In [None]:
# Build windowed training data from the raw array loaded in the configuration
# cell (array_001_t1); the remaining arguments are the parameters defined there.
data_labels, meta_data = process_train_data(
    array_001_t1, side, fs, low_freq, high_freq, window_size,
    step_size, epoch_start_time, epoch_end_time, pre_duration, post_duration,
    verbose=True,
)

In [3]:
from mne import Epochs, find_events, pick_types

def load_epoch(raws, subject_nb, event_id, fmin=7., fmax=35., tmin=-1., tmax=4.):
    """Load band-pass filtered, epoched EEG data for one subject.

    Parameters
    ----------
    raws : dict
        Nested mapping of recordings; the raw object is read from
        ``raws[subject_nb]['0']['0']`` (presumably session '0', run '0' --
        TODO confirm the layout). It is not modified.
    subject_nb : hashable
        Key of the subject in ``raws``.
    event_id : dict | int | list | None
        Passed unchanged to ``mne.Epochs`` to select which events to epoch.
    fmin, fmax : float
        Lower / upper band-pass edges passed to ``Raw.filter``.
    tmin, tmax : float
        Epoch window around each event, in seconds. The defaults (-1, 4) are
        the values this function previously hard-coded.

    Returns
    -------
    epochs : mne.Epochs
        Preloaded epochs restricted to EEG channels, bad channels excluded,
        no baseline correction.
    labels : numpy.ndarray
        Event codes of the kept epochs minus 1 (assumes codes start at 1 --
        TODO confirm).
    """
    # Raw.filter operates in place: filter a copy so the caller's ``raws`` is
    # left untouched and re-running this function does not filter twice.
    raw = raws[subject_nb]['0']['0'].copy()
    raw.filter(fmin, fmax, fir_design='firwin', skip_by_annotation='edge')

    # Events (presumably left / right hand) are read from the stim channel.
    events = find_events(raw, shortest_event=0, verbose=True)

    picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                       exclude='bads')

    # Epochs span the full [tmin, tmax] window; callers crop further for
    # training (e.g. 1-2 s) and can use the full window for a running classifier.
    epochs = Epochs(raw, events, event_id, tmin, tmax, proj=True, picks=picks,
                    baseline=None, preload=True)
    labels = epochs.events[:, -1] - 1
    return epochs, labels

In [None]:
from pyriemann.estimation import Covariances
from pyriemann.tangentspace import TangentSpace
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Seed for the random forest so cross-validation scores are reproducible.
RANDOM_STATE = 42

# NOTE(review): `raws` and `event_id` are not defined in any visible cell; they
# must be created before this cell runs.
# Arguments follow load_epoch's signature: (raws, subject_nb, event_id, ...).
epochs, labels = load_epoch(raws, 1, event_id, fmin=5., fmax=50.)

# Train only on the 1-2 s window. Epochs.crop works in place, so crop a copy and
# keep the full-length `epochs` available.
epoch_train = epochs.copy().crop(tmin=1., tmax=2.)

# Convert from MNE object to a numpy array for scikit-learn / pyriemann.
epochs_data_train = epoch_train.get_data()

# Feature extractor: covariance matrices -> tangent-space vectors -> scaling.
cov = Covariances(estimator='scm')
ts = TangentSpace()
ss = StandardScaler()

# Classifier
rf = RandomForestClassifier(random_state=RANDOM_STATE)

# Chain the steps so each CV fold fits the whole pipeline on its training split.
clf = Pipeline([('cov', cov), ('ts', ts), ('ss', ss), ('rf', rf)])

# Evaluate the resulting classifier using 10-fold cross-validation.
scores = cross_val_score(clf, epochs_data_train, labels, cv=10, n_jobs=1, verbose=False)
print('Mean score:', np.mean(scores))