Working Memory Decoding
============================================

In [None]:
import mne
from sklearn.decomposition import FastICA, PCA
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

from mne import create_info, EpochsArray
from mne.baseline import rescale
from mne.time_frequency import (tfr_multitaper, tfr_stockwell, tfr_morlet,
                                tfr_array_morlet)

import warnings
from mne.preprocessing import ICA
warnings.filterwarnings('ignore')
from mne import viz
from mne.channels import Layout
from mne.decoding import (SlidingEstimator, GeneralizingEstimator,
                          cross_val_multiscore, LinearModel, get_coef)
import numpy as np
import matplotlib.pyplot as plt

from mne import Epochs, find_events, create_info
from mne.io import concatenate_raws, read_raw_edf
from mne.datasets import eegbci
from mne.decoding import CSP
from mne.time_frequency import AverageTFR

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import LabelEncoder

from scipy import signal

In [None]:
# Without an argument, IPython activates matplotlib's default backend (typically an
# interactive GUI window rather than inline figures); use `%matplotlib inline` to
# embed the figures produced below in the notebook itself.
%matplotlib

1 - IO
------------------------------------

In [None]:
#Reading fif file
import os

# All inputs for one subject live in a single directory. Override it with the
# MEG_DATA_DIR environment variable instead of editing absolute paths in place.
DATA_DIR = os.environ.get('MEG_DATA_DIR', '/media/analogicalnexus/2568212B752CDB3B/MEG_Data')
SUBJECT = 'R2504'
input_fname = os.path.join(DATA_DIR, SUBJECT + '_WMD-Filtered-raw.fif')
input_ica = os.path.join(DATA_DIR, SUBJECT + '_ica_filtered.fif')  # not read in this notebook
# preload=True loads the data into memory up front (equivalent to raw.load_data());
# the filtering and ICA steps below operate on the loaded data.
raw = mne.io.read_raw_fif(input_fname, preload=True)
# raw.plot()


2 - Filtering - band-pass filter (1–45 Hz)
------------------------------------

In [None]:
#Band pass filter (1-45 Hz, FIR).
# Raw.filter works in place and returns the same object. Filtering a copy keeps
# `raw` untouched, and re-running this cell does not filter the data twice.
raw_filtered = raw.copy().filter(l_freq=1, h_freq=45.0, fir_design='firwin')
# raw_filtered.plot()

In [None]:
#Noise Cancellation - Already done in the data collection step


3 - ICA
------------------------------------

In [None]:
# ICA settings adapted from the MNE "ICA from raw" tutorial:
# https://martinos.org/mne/dev/auto_tutorials/plot_ica_from_raw.html
n_components = 0.95  # a float keeps as many PCA components as needed for this fraction of variance
method = 'fastica'   # 'extended-infomax' would be the closer match to EEGLAB
decim = 3            # fit on every 3rd sample: enough statistics, much less time

# ICA starts from a random initialisation; a fixed seed gives the same decomposition
# and component order on every run, so the excluded component indices stay valid.
random_state = 23

# Decompose MEG sensors only.
picks = mne.pick_types(raw_filtered.info, meg=True, eeg=False)

In [None]:
# Fit ICA, then inspect components for artifacts (eye blinks, heartbeat, ...).
# Peak-to-peak rejection limits per channel type are passed to ICA.fit so that
# very large transients do not drive the decomposition.
reject = {'mag': 5e-12, 'grad': 4000e-13}
ica = ICA(method=method, n_components=n_components, random_state=random_state)
ica.fit(raw_filtered, reject=reject, picks=picks, decim=decim)
# Inspection helpers:
# print(ica)
# ica.plot_components()
# ica.plot_sources(raw_filtered, picks=range(ica.n_components_))

In [None]:
#Exclude components
# Artifact components chosen by visual inspection of the previous cell. The indices
# are only valid for this subject and this ICA fit (fixed random_state); edit them per subject.
bad_ica_components = [1, 2, 9, 12]
# Assign instead of `+=`, so re-running this cell does not append duplicate indices.
ica.exclude = list(bad_ica_components)
ica.plot_overlay(raw_filtered, exclude=bad_ica_components)
# ica.save(input_ica)  # NOTE(review): earlier save path pointed at R505 while loading R2504 -- confirm subject
# Removes the excluded components from raw_filtered in place.
ica.apply(raw_filtered)


4 - Epoching (segmenting into trials)
------------------------------------

In [None]:
# Epoching parameters for the active contrast: event codes 187 ('r') vs. 188 ('nr'),
# presumably rhyme vs. non-rhyme -- confirm against the trigger table.
# Earlier syllable contrasts, kept for reference:
# event_id_long = dict(nw1=173,nw3=175,w1=163,w3=165)
# event_id_long = dict(s1=[173,163],s2=[174,164],s3=[175,165])
event_id_long = {'r': 187, 'nr': 188}
tmin, tmax = -0.2, 0.5  # epoch window around each trigger (s)
# NOTE(review): (None, None) baselines on the mean of the *entire* epoch, including the
# post-stimulus part; a pre-stimulus baseline would be (None, 0). Confirm this is intended.
baseline = (None, None)
picks = mne.pick_types(raw_filtered.info, meg=True)

In [None]:
# Trigger events from the stim channel; column 2 holds the event code (the decoding labels).
events = find_events(raw_filtered)

In [None]:
# Cut the continuous MEG data into trials around each r/nr trigger. decim=2 keeps
# every 2nd sample, so `epochs` has half the sampling rate of raw_filtered.
epochs = Epochs(raw_filtered, events,
                event_id=event_id_long,
                tmin=tmin, tmax=tmax,
                baseline=baseline,
                picks=picks,
                proj=False,
                decim=2)

In [None]:
# epochs = mne.epochs.combine_event_ids(epochs,['w1','nw1'],{'s1':190})
# epochs = mne.epochs.combine_event_ids(epochs,['w3','nw3'],{'s2':191})
# epochs = mne.epochs.combine_event_ids(epochs,['nw3','w3'],{'s3':192})

In [None]:
# epochs.event_id

5 - Sensor space analysis
------------------------------------

In [None]:
# rhyme.plot(spatial_colors=True, gfp=True, ylim=dict(mag=[-300,300]))
# non_rhyme.plot(spatial_colors=True, gfp=True, ylim=dict(mag=[-300,300]))
# rhyme.plot_topomap(times=[.0, .17, .4],vmin=-300,vmax=300)
# non_rhyme.plot_topomap(times=[.0, .17, .4],vmin=-300,vmax=300)
# evoked_dict = dict() 
# evoked_dict['rhyme'] = rhyme
# evoked_dict['non_rhyme'] = non_rhyme
# colors=dict(rhyme="Crimson",non_rhyme="CornFlowerBlue") 
# mne.viz.plot_compare_evokeds(evoked_dict, colors=colors,
# picks=picks, gfp=True)

In [None]:

# la=[0,1,2,3,39,41,42,43,44,52,58,67,71,80,82,83,84,85,108,130,131,132,133,134,135,136,151]
# lp=[4,5,6,7,8,9,34,36,37,38,40,45,46,47,48,49,50,75,76,77,79,87,88,90,127,129,137]
# ra=[20,22,23,24,26,59,60,61,62,63,65,89,92,95,99,100,114,115,116,117,118,145,147,148,152,155]
# rp=[14,15,16,17,18,19,25,27,28,30,53,54,56,57,66,68,69,70,94,96,97,119,121,122,143,144]
# lh=[0,1,2,3,39,41,42,43,44,52,58,67,71,80,82,83,84,85,108,130,131,132,133,134,135,136,151,4,5,6,7,8,9,34,36,37,38,40,45,46,47,48,49,50,75,76,77,79,87,88,90,127,129,137]
# rh=[20,22,23,24,26,59,60,61,62,63,65,89,92,95,99,100,114,115,116,117,118,145,147,148,152,155, 14,15,16,17,18,19,25,27,28,30,53,54,56,57,66,68,69,70,94,96,97,119,121,122,143,144]
# mne.viz.plot_compare_evokeds(evoked_dict, colors=colors,
# picks=lh, gfp=True, ylim=dict(mag=[0,100]))

In [None]:
# Epoch data as a NumPy array, (n_epochs, n_channels, n_times) per MNE's Epochs.get_data.
# Section 6 recomputes the same array.
X = epochs.get_data()

In [None]:
# Sanity check: epoch data must be 3-D (n_epochs, n_channels, n_times).
assert X.ndim == 3, X.shape
X.shape

6 - Time-frequency analysis (spectrograms)
------------------------------------

In [None]:
# Short-time spectrogram of every epoch and channel; S is the feature array used for
# decoding in section 7, with shape (n_epochs, n_channels, n_freqs, n_segments).

# Use the sampling rate stored in `epochs` (it already reflects decim=2) rather than
# a hard-coded value, so the f/t axes and the PSD scaling are correct for any raw rate.
sfreq = epochs.info['sfreq']

# Parameters for the MNE TFR alternatives commented out below.
freqs = np.arange(8., 12., 1.)
vmin, vmax = -.3e-25, .3e-25  # Define our color limits.
n_cycles = freqs / 2.
time_bandwidth = 8.0  # Same time-smoothing as (1), 7 tapers.

# Segments of 250 samples that advance by 250 - 240 = 10 samples, each zero-padded
# to a 500-point FFT.
spec_kwargs = dict(fs=sfreq, nperseg=250, noverlap=240, nfft=500)

X = epochs.get_data()  # (n_epochs, n_channels, n_times)
n_epochs = X.shape[0]
# Transform one epoch at a time with all channels at once (time is the last axis).
# This replaces the per-channel Python loop while keeping the temporary STFT
# buffers to roughly one epoch's worth of memory.
f, t, first_epoch_spec = signal.spectrogram(X[0], **spec_kwargs)
S = np.empty((n_epochs,) + first_epoch_spec.shape, dtype=float)
S[0] = first_epoch_spec
for e in range(1, n_epochs):
    S[e] = signal.spectrogram(X[e], **spec_kwargs)[2]

# Alternatives tried (MNE TFR functions), kept for reference:
# power = tfr_multitaper(epochs, freqs=freqs, n_cycles=n_cycles,
#                          time_bandwidth=time_bandwidth, return_itc=False,average=False)
# power = tfr_array_morlet(epochs.get_data(), sfreq=epochs.info['sfreq'],
#                          freqs=freqs, n_cycles=n_cycles)
# Baseline the output
# rescale(power, epochs.times, (0., 0.1), mode='mean', copy=False)
# Plot results. Baseline correct based on first 100 ms.
# power.plot([0], baseline=(0., 0.1), mode='mean', vmin=vmin, vmax=vmax,
#            title='Sim: Less time smoothing, more frequency smoothing')

In [None]:
# Sanity check: one spectrogram per epoch and channel -> (n_epochs, n_channels, n_freqs, n_segments).
assert S.ndim == 4 and S.shape[:2] == X.shape[:2], (S.shape, X.shape)
S.shape

7 - Decoding (MVPA)
------------------------------------

In [None]:
# Decode r vs. nr separately for each spectrogram frequency bin. At bin i, a
# SlidingEstimator fits one StandardScaler + LogisticRegression per time segment,
# using the spectral power at every channel as features.
# trf_scores[i, j] is the mean 5-fold ROC AUC for frequency f[i] and segment t[j].
trf_scores = np.zeros((S.shape[2], S.shape[3]))

# Loop invariants: labels and estimator. Cross-validation fits the estimator anew in every fold.
y = epochs.events[:, 2]  # event codes: 187 ('r') or 188 ('nr')
clf = make_pipeline(StandardScaler(), LogisticRegression())
time_decod = SlidingEstimator(clf, n_jobs=1, scoring='roc_auc')

# Loop through each frequency bin
for i in range(S.shape[2]):
    X = S[:, :, i, :]  # power at bin i: (n_epochs, n_channels, n_segments)
    scores = cross_val_multiscore(time_decod, X, y, cv=5, n_jobs=1)
    # Mean scores across cross-validation splits
    trf_scores[i, :] = np.mean(scores, axis=0)

# Spatial patterns: refit with LinearModel(LogisticRegression()) inside the pipeline, then
# get_coef(time_decod, 'patterns_', inverse_transform=True) -> mne.EvokedArray(...).plot_joint().

In [None]:
# Heat map of the sliding-decoding ROC AUC (0.5 = chance) over frequency x time.
# origin='lower' puts low frequencies at the bottom. The extent converts pixel indices
# to Hz and to seconds relative to the trigger: spectrogram times `t` count from the
# first epoch sample, which sits at epochs.times[0]. Labels use bin centres, so the
# axes are accurate to within half a bin.
fig, ax = plt.subplots()
im = ax.imshow(trf_scores, cmap='hot', interpolation='nearest', aspect='auto',
               origin='lower',
               extent=[t[0] + epochs.times[0], t[-1] + epochs.times[0], f[0], f[-1]])
ax.set(xlabel='Time (s)', ylabel='Frequency (Hz)',
       title='Time-frequency decoding (ROC AUC)')
fig.colorbar(im, ax=ax, label='AUC')
plt.show()

In [None]:
# freqs

In [None]:
from sklearn.svm import SVC
from sklearn.model_selection import ShuffleSplit
from mne.decoding import CSP

# CSP + linear SVM per frequency bin. At bin i, CSP spatial filters are fit on the
# training epochs only and turn the (channels x segments) power at that bin into
# n_csp_components features. An SVM then classifies r vs. nr.
# A dedicated name keeps the ICA `n_components` (0.95) from being overwritten, which
# would silently change the ICA if those cells were re-run afterwards.
n_csp_components = 3
n_freqs_to_decode = 20  # decode the first 20 spectrogram bins (bin i is f[i] Hz)
svc = SVC(C=1, kernel='linear')
csp = CSP(n_components=n_csp_components, norm_trace=False)
cv = ShuffleSplit(n_splits=10, test_size=0.2, random_state=42)

labels = epochs.events[:, 2]  # event codes: 187 ('r') or 188 ('nr')
# Chance level = accuracy of always predicting the majority class.
class_balance = np.mean(labels == labels[0])
class_balance = max(class_balance, 1. - class_balance)

trf_scores = np.zeros((S.shape[2], S.shape[3]))
mean_scores = []
for i in range(min(n_freqs_to_decode, S.shape[2])):
    X = S[:, :, i, :]  # power at bin i: (n_epochs, n_channels, n_segments)
    scores = []
    for train_idx, test_idx in cv.split(labels):
        y_train, y_test = labels[train_idx], labels[test_idx]
        X_train = csp.fit_transform(X[train_idx], y_train)
        X_test = csp.transform(X[test_idx])
        svc.fit(X_train, y_train)
        scores.append(svc.score(X_test, y_test))  # accuracy, not AUC

    mean_score = np.mean(scores)
    mean_scores.append(mean_score)
    # CSP features have no time axis, so one accuracy fills the whole row.
    trf_scores[i, :] = mean_score
    print("%.1f Hz - Classification accuracy: %f / Chance level: %f"
          % (f[i], mean_score, class_balance))

# One summary figure instead of one per-split figure per frequency.
fig, ax = plt.subplots()
ax.plot(f[:len(mean_scores)], mean_scores, marker='o', label='score')
ax.axhline(class_balance, color='k', linestyle='--', label='chance')
ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('Accuracy')
ax.set_title('CSP + SVM decoding per frequency bin')
ax.legend()
plt.show()

In [None]:
# Heat map of the CSP + SVM accuracies in trf_scores over frequency x time. Each
# decoded row is constant because the CSP score fills the whole row. Rows that were
# not decoded stay 0. origin='lower' puts low frequencies at the bottom. The extent
# maps bin centres to Hz and to seconds relative to the trigger.
fig, ax = plt.subplots()
im = ax.imshow(trf_scores, cmap='hot', interpolation='nearest', aspect='auto',
               origin='lower',
               extent=[t[0] + epochs.times[0], t[-1] + epochs.times[0], f[0], f[-1]])
ax.set(xlabel='Time (s)', ylabel='Frequency (Hz)',
       title='CSP + SVM decoding accuracy')
fig.colorbar(im, ax=ax, label='Accuracy')
plt.show()

In [None]:
trf_scores  # full (n_freqs, n_segments) score matrix

In [None]:
np.shape(X_train)  # CSP features of the last training split: (n_train_epochs, n_csp_components)