In [104]:
%matplotlib inline
import pandas as pd
import numpy as np
from scipy.signal import correlate
from scipy.stats import zscore, norm, ttest_rel
from statsmodels.sandbox.stats.multicomp import multipletests
import matplotlib.pyplot as plt
import seaborn as sns
#sns.set(style="whitegrid", color_codes=True)
import mne
import itertools
from mne.connectivity import spectral_connectivity

In [2]:
event_id = dict(left_hand=1, right_hand=2, feet=3, tongue=4)  # event ID

In [3]:
fs = 250 # Sampling frequency
info = mne.create_info(["ch"+str(i) for i in range(25)], fs) #  Channel labels and sampling freq.

In [4]:
subjects = ['A01T', 'A02T', 'A03T', 'A04T', 'A05T', 'A06T', 'A07T', 'A08T', 'A09T']
epochs = [0] * len(subjects)

In [5]:
for (i, subject) in enumerate(subjects):
    events = pd.read_csv('./datasets/events_'+subject+'.csv', header=None)
    events = np.array(events.astype(int))
    eegs = pd.read_csv('./datasets/eeg_'+subject+'.csv', header=None)
    raw = mne.io.array.RawArray(eegs, info)
    epochs[i] = mne.Epochs(raw, events, event_id, tmin=0.0, tmax=6.0, baseline=(None,0), preload=True, add_eeg_ref=False)

Creating RawArray with float64 data, n_channels=25, n_times=677145
    Range : 0 ... 677144 =      0.000 ...  2708.576 secs
Ready.
336 matching events found
0 projection items activated
Loading data for 336 events and 1501 original time points ...
0 bad epochs dropped
Creating RawArray with float64 data, n_channels=25, n_times=677145
    Range : 0 ... 677144 =      0.000 ...  2708.576 secs
Ready.
336 matching events found
0 projection items activated
Loading data for 336 events and 1501 original time points ...
0 bad epochs dropped
Creating RawArray with float64 data, n_channels=25, n_times=677145
    Range : 0 ... 677144 =      0.000 ...  2708.576 secs
Ready.
336 matching events found
0 projection items activated
Loading data for 336 events and 1501 original time points ...
0 bad epochs dropped
Creating RawArray with float64 data, n_channels=25, n_times=677145
    Range : 0 ... 677144 =      0.000 ...  2708.576 secs
Ready.
336 matching events found
0 projection items activated
Loading

In [6]:
def calc_mean_coef_z(epochs, event, pre, ch1, ch2):
    if(pre):
        signal = epochs[event].get_data()[:,[ch1,ch2],0:500]
    else:
        signal = epochs[event].get_data()[:,[ch1,ch2],750:1250]
    trial_num = signal.shape[0]
    coefs = [max(correlate(signal[trial,0,:], signal[trial,1,:])) / (np.linalg.norm(signal[trial,0,:])*np.linalg.norm(signal[trial,1,:])) for trial in range(trial_num)]
    mean_r = np.array(coefs).mean()
    # Fisher's z transformation
    z_score = np.arctanh(mean_r)
    return z_score

In [7]:
def make_comparable_data(epochs, event, ch1, ch2):
    data = {"pre":[], "post":[]}
    for i in range(len(epochs)):
        data["pre"].append(calc_mean_coef_z(epochs[i], event, True, ch1, ch2))
        data["post"].append(calc_mean_coef_z(epochs[i], event, False, ch1, ch2))
        title = "ch{0}-ch{1}".format(ch1, ch2)
    df = pd.DataFrame(data)
    return df, event, title

In [8]:
def df2sns(df):
    data = {"z_score":[], "condition":[]}
    for ix, row in df.iterrows():
        data['z_score'].append(row.pre)
        data['z_score'].append(row.post)
        data['condition'].append('pre')
        data['condition'].append('post')
    return pd.DataFrame(data)

In [9]:
def srs2sns(series):
    data = {"z_score":[], "condition":['pre', 'post']}
    data['z_score'].append(series.pre)
    data['z_score'].append(series.post)
    return pd.DataFrame(data)

In [10]:
def test_corr(epochs, event, ch1, ch2):
    df, eventname, title = make_comparable_data(epochs, event, ch1, ch2)
    t, p = ttest_rel(df.pre, df.post)
    return p

In [14]:
results = {}
ch_num = 22
for event in event_id.keys():
    for ch1,ch2 in itertools.combinations(range(ch_num),2):
        p = test_corr(epochs, event, ch1, ch2)
        results["{0}_ch{1}-ch{2}".format(event, ch1, ch2)] = p

In [17]:
reject, q_values,_,_ = multipletests(list(results.values()), alpha=0.05, method="fdr_tsbky")

In [18]:
for i, (condition, pvals) in enumerate(results.items()):
    if(reject[i]):
        print("{0} q={1}".format(condition, q_values[i]))

tongue_ch3-ch12 q=0.01575791878515035
tongue_ch9-ch11 q=0.010141435926614357
tongue_ch11-ch17 q=0.04700949734524107
feet_ch18-ch20 q=0.046884319908835494
tongue_ch10-ch17 q=0.0499466755690663
right_hand_ch8-ch12 q=0.017011884151979073
tongue_ch10-ch16 q=0.018741140610515447
tongue_ch11-ch13 q=0.009283014024122557
right_hand_ch2-ch12 q=0.019092403884888927
tongue_ch6-ch10 q=0.03109633965739195
left_hand_ch6-ch7 q=0.028534597804323825
tongue_ch2-ch6 q=0.043375266062629446
tongue_ch7-ch9 q=0.007311250013262306
left_hand_ch6-ch13 q=0.007311250013262306
tongue_ch5-ch9 q=0.005151785444824716
tongue_ch8-ch10 q=0.003791103534632635
right_hand_ch12-ch15 q=0.013993598584167554
tongue_ch11-ch15 q=0.006905619259277276
tongue_ch5-ch12 q=0.03700834913123362
right_hand_ch2-ch3 q=0.045877624127536105
tongue_ch15-ch16 q=0.006083865907936958
tongue_ch3-ch14 q=0.009283014024122557
tongue_ch17-ch19 q=0.021531580683618795
tongue_ch0-ch7 q=0.012703751423425105
tongue_ch9-ch17 q=0.01974120502896505
tongue_ch

分類器の作成
+ 教師データ　tongueかそうじゃないか
+ 入力データ　pre-postでのチャンネル同士のコネクティビティの変化(22*21 / 2 次元)

In [43]:
from sklearn import svm
from sklearn import cross_validation

In [124]:
# 教師データ
y_train = []
for epoch in epochs:
    classes = np.array(epoch.events[:,2])
    #classes[classes != 4] = 0
    y_train = y_train + list(classes)
len(y_train)

3024

In [None]:
con, freqs, times, n_epochs, n_tapers = spectral_connectivity(epochs[k][i], method='imcoh', mode='multitaper', sfreq=sfreq, fmin=fmin, fmax=fmax,
        faverage=True, tmin=tmin, tmax=tmax, mt_adaptive=False, n_jobs=1)

In [83]:
# 学習データ
X_train = []
for epoch in epochs:
    for i in range(len(epoch)):
        con_vec = []
        pre = epoch[i].get_data()[:,:,0:500]
        post = epoch[i].get_data()[:,:,750:1250]
        for ch1,ch2 in itertools.combinations(range(ch_num),2):
            pre_coef = max(correlate(pre[0,ch1,:], pre[0,ch2,:])) / (np.linalg.norm(pre[0,ch1,:])*np.linalg.norm(pre[0,ch2,:]))
            post_coef = max(correlate(post[0,ch1,:], post[0,ch2,:])) / (np.linalg.norm(post[0,ch1,:])*np.linalg.norm(post[0,ch2,:]))
            con_vec.append(post_coef - pre_coef)
        con_vec = np.asarray(con_vec)
        con_normed = (con_vec - con_vec.mean())/con_vec.std()
        X_train.append([con_normed.mean()])

In [None]:
# 学習データ
fmin, fmax = 14., 30.
sfreq = raw.info['sfreq']  # the sampling frequency
X_train = []
for epoch in epochs:
    for i in range(len(epoch)):
        con_vec = []
        con_pre, freqs, times, n_epochs, n_tapers = spectral_connectivity(epoch[i], method='imcoh', mode='multitaper', sfreq=sfreq, fmin=fmin, fmax=fmax,
        faverage=True, tmin=0., tmax=2., mt_adaptive=False, n_jobs=1)
        con_post, freqs, times, n_epochs, n_tapers = spectral_connectivity(epoch[i], method='imcoh', mode='multitaper', sfreq=sfreq, fmin=fmin, fmax=fmax,
        faverage=True, tmin=3., tmax=5., mt_adaptive=False, n_jobs=1)
        con_vec.append(con_post[:,:,0] - con_pre[:,:,0])
        con_vec = np.asarray(con_vec).reshape(1,-1)
        con_normed = (con_vec - con_vec.mean())/con_vec.std()
        X_train.append(con_normed)

Connectivity computation...
    computing connectivity for 300 connections
    using t=0.000s..2.000s for estimation (501 points)
    frequencies: 14.5Hz..29.9Hz (32 points)
    connectivity scores will be averaged for each band
    using multitaper spectrum estimation with 7 DPSS windows
    the following metrics will be computed: Imaginary Coherence
    computing connectivity for epoch 1
    assembling connectivity matrix
[Connectivity computation done]
Connectivity computation...
    computing connectivity for 300 connections
    using t=3.000s..5.000s for estimation (501 points)
    frequencies: 14.5Hz..29.9Hz (32 points)
    connectivity scores will be averaged for each band
    using multitaper spectrum estimation with 7 DPSS windows
    the following metrics will be computed: Imaginary Coherence
    computing connectivity for epoch 1
    assembling connectivity matrix
[Connectivity computation done]
Connectivity computation...
    computing connectivity for 300 connections
    u

In [117]:
a = []
for x in X_train:
    a.append(x[0])

In [119]:
len(a[0])

625

In [109]:
clf = svm.SVC(C=20, cache_size=200, class_weight="balanced", \
                          gamma="auto", kernel="rbf", \
                          max_iter=-1, probability=False,random_state=None, \
                          shrinking=True, tol=0.001, verbose=False, decision_function_shape='ovr')

In [128]:
Xc_train, Xc_test, yc_train, yc_test = cross_validation.train_test_split(a, y_train, test_size=0.4, random_state=None)
clf.fit(Xc_train, yc_train)
print(clf.score(Xc_test, yc_test))

0.380991735537
