In [None]:
from header import *
import psutil
from mne.stats import spatio_temporal_cluster_1samp_test,spatio_temporal_cluster_test
from scipy.stats.distributions import f,t
from tqdm import tqdm
import visbrain
from visbrain.objects import BrainObj, ColorbarObj, SceneObj, SourceObj
import xarray as xr
# Silence DeprecationWarnings emitted by the imported libraries for the rest of the session.
warnings.filterwarnings(action="ignore", category=DeprecationWarning)

In [None]:
t0 = time.perf_counter()
task = 'SMEG'  # alternative task: 'MIMOSA'
states = ['RS', 'FA', 'OM']
subjects = get_subjlist(task)

# Drop excluded subjects in place (absent IDs are skipped), then sort the remainder.
reject = ['004', '010', '072', '109']
for sub in reject:
    if sub not in subjects:
        continue
    subjects.remove(sub)
subjects.sort()

# Split subjects by expertise label ('N' = novices, 'E' = experts), keeping both
# the subject IDs and their positions in `subjects`.
experts, experts_i = [], []
novices, novices_i = [], []
by_label = (('N', novices, novices_i), ('E', experts, experts_i))
for s, sub in enumerate(subjects):
    for label, ids, positions in by_label:
        if expertise(sub) == label:
            ids.append(sub)
            positions.append(s)

In [None]:
# Preprocessing variants available on disk; only `name` is analysed below.
names = ['R_ECG_included', 'R_ECG_excluded', 'T_ECG_included', 'T_ECG_excluded']
name = names[0]

# Filename fragments used to glob the source-estimate files.
noise_cov = 'baseline_cov'
fsaverage = '-fsaverage'  # or False to omit the '-fsaverage' filename fragment
stc_ext = '-lh.stc'

# Resampling rate applied after loading (falsy = skip resampling).
sfreq = 200
# Analysis window bounds in seconds; None = filled from the first loaded file.
start, stop = None, None

stc_path = op.join(Analysis_path, task, 'meg', 'SourceEstimate')

In [None]:
# Load every subject's source estimate for each state into one xarray cube
# `stc` with dims (state, subject, time, src), for the pipeline variant `name`.
# The first file read fixes the time axis, the vertex labels and the default
# analysis window bounds.


def _stc_stem(fname, ext):
    """Return `fname` without its trailing `ext` suffix (unchanged if absent).

    `str.strip(ext)` is NOT a suffix removal: it strips any of the characters
    in `ext` from both ends, e.g. the leading '.' of './data/...' or a trailing
    's', 't', 'c', 'l' or 'h' of the stem, which corrupts the path.
    """
    if ext and fname.endswith(ext):
        return fname[:-len(ext)]
    return fname


for st, state in enumerate(states):
    print(state)
    for su, sub in enumerate(tqdm(subjects)):
        stc_file = op.join(stc_path, sub, state+'*'+name+'*'+noise_cov+'*'+(fsaverage if fsaverage else '')+stc_ext)
        # Sorted so the chosen file is deterministic if the pattern matches several.
        matches = sorted(glob.glob(stc_file))
        if not matches:
            raise FileNotFoundError('No source estimate matches ' + stc_file)
        data = mne.read_source_estimate(_stc_stem(matches[0], stc_ext))
        if sfreq:
            data.resample(sfreq)
        if st == 0 and su == 0:
            # First file: build the coordinates and allocate the output cube.
            times = data.times
            # `is None` (not truthiness) so an explicit 0.0 bound is respected.
            if start is None:
                start = times[0]
            if stop is None:
                stop = times[-1]
            vertices = np.concatenate([['lh_' + str(x) for x in data.lh_vertno],
                                       ['rh_' + str(x) for x in data.rh_vertno]])
            # Release a cube left over from a previous run before allocating a new one.
            try:
                del stc
            except NameError:
                pass
            stc = xr.DataArray(
                np.zeros((len(states), len(subjects), data.data.shape[1], data.data.shape[0])),
                dims=['state', 'subject', 'time', 'src'],
                coords={'state': states, 'subject': subjects, 'time': times, 'src': vertices})
        # data.data is laid out (src, time); store it transposed as (time, src).
        stc[st, su] = data.data.T
        del data

In [None]:
def cluster_perm_test(X1, X2, stat_file, p_threshold=0.01, connectivity='ico5', paired=False, n_jobs=4, n_lh=None):
    """
    Run a spatio-temporal cluster permutation test and save the maps as TSV files.

    If paired, test X2-X1 with MNE's one-sample cluster test on the per-subject
    differences; otherwise compare the two groups with MNE's two-sample
    (F statistic) cluster test.

    Parameters
    ----------
    X1, X2 : arrays of shape (subjects, time, space)
        For a paired test both must have the same shape (same subjects, same order).
    stat_file : str
        Output prefix. Writes '<stat_file>-<hemi>-T_stat.tsv' and
        '<stat_file>-<hemi>-p_val.tsv' for hemi in ('lh', 'rh'). Missing parent
        directories are created; a bare name writes to the working directory.
    p_threshold : float
        Two-tailed p-value from which the cluster-forming threshold is derived.
    connectivity : str or adjacency object
        A string such as 'ico5' is converted to the triangulation connectivity of
        that ico grade (its last character is the grade); anything else is
        passed to MNE unchanged.
    paired : bool
        Select the paired (one-sample on X2-X1) or unpaired two-group test.
    n_jobs : int
        Number of parallel jobs passed to the MNE test (default 4, as before).
    n_lh : int or None
        Number of left-hemisphere columns in the space axis; defaults to half
        of them (left hemisphere first, right hemisphere second).

    Returns
    -------
    T_obs, clusters, cluster_pv, H0, mask
        The four outputs of the MNE test, plus `mask`: an array shaped like
        T_obs holding each cluster's p-value at its points and 1 elsewhere.
    """
    if isinstance(connectivity, str):
        connectivity = mne.spatial_tris_connectivity(mne.grade_to_tris(int(connectivity[-1])))

    if paired:
        X_diff = X2 - X1
        # Two-tailed t threshold with n_subjects - 1 degrees of freedom.
        t_threshold = -t.ppf(p_threshold / 2, X_diff.shape[0] - 1)
        T_obs, clusters, cluster_pv, H0 = spatio_temporal_cluster_1samp_test(
            X_diff, connectivity=connectivity, threshold=t_threshold, n_jobs=n_jobs)
    else:
        # NOTE(review): these degrees of freedom follow MNE's two-sample example;
        # for a one-way F statistic with two groups the textbook df are
        # (1, n1 + n2 - 2) -- confirm the intended threshold.
        f_threshold = f.ppf(1 - p_threshold / 2, X1.shape[0] - 1, X2.shape[0] - 1)
        T_obs, clusters, cluster_pv, H0 = spatio_temporal_cluster_test(
            [X1, X2], connectivity=connectivity, threshold=f_threshold, n_jobs=n_jobs)

    # Paint each cluster's p-value onto its (time, space) points; 1 elsewhere.
    mask = np.ones(T_obs.shape)
    for clu, pv in zip(clusters, cluster_pv):
        mask[clu] = pv

    # dirname is '' for a bare filename; os.makedirs('') would raise.
    out_dir = op.dirname(stat_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    n_src = T_obs.shape[1]
    if n_lh is None:
        n_lh = n_src // 2
    for hemi, (lo, hi) in (('lh', (0, n_lh)), ('rh', (n_lh, n_src))):
        np.savetxt(stat_file+'-'+hemi+'-T_stat.tsv', T_obs[:, lo:hi], delimiter='\t')
        np.savetxt(stat_file+'-'+hemi+'-p_val.tsv', mask[:, lo:hi], delimiter='\t')

    return T_obs, clusters, cluster_pv, H0, mask

In [None]:
spacing = 'ico5'  # icosahedral source-space grade; its trailing digit is used below
# Spatial adjacency of the source vertices derived from the ico triangulation,
# built once here and reused by every cluster test.
connectivity = mne.spatial_tris_connectivity(
    mne.grade_to_tris(int(spacing[-1]))
)

In [None]:
stat_path = op.join(Analysis_path, task, 'meg', 'Stats')
# One dict per cluster_perm_test output, each keyed by contrast name.
T_obs, clusters, cluster_pv, H0, mask = {}, {}, {}, {}, {}

In [None]:
t0 = time.perf_counter()
ti = t0
# Analysis time window in seconds (label-based, inclusive slice of `stc`'s time coordinate).
start = .3
stop = .5


def _window(state, group=slice(None)):
    """Return `stc` values for `state` and the subjects in `group`, restricted to [start, stop]."""
    return stc.loc[state, group, start:stop].values


def _cond_data(cond, group=slice(None)):
    """Windowed data for a condition; 'A-B' yields the per-subject difference A minus B."""
    if '-' in cond:
        minuend, subtrahend = cond.split('-')
        return _window(minuend, group) - _window(subtrahend, group)
    return _window(cond, group)


def _stat_file(test_key):
    """Output prefix for one contrast, encoding pipeline settings, contrast name and time window."""
    return op.join(stat_path, name+'-'+noise_cov+'-'+test_key+'-'+str(round(start,4))+'_'+str(round(stop,4))+'_t_window'+(fsaverage if fsaverage else ''))


# Each entry: (test_key, cond1, cond2, group1, group2, paired).
everyone = slice(None)
contrasts = []
# Paired within-subject state contrasts (tests cond2 - cond1): whole sample, experts, novices.
for suffix, group in (('', everyone), ('+E', experts), ('+N', novices)):
    for cond1, cond2 in (('RS', 'FA'), ('RS', 'OM'), ('OM', 'FA')):
        contrasts.append((cond1+'_vs_'+cond2+suffix, cond1, cond2, group, group, True))
# Unpaired novices-vs-experts contrasts on single states and on state differences.
for cond in ('RS', 'FA', 'OM', 'FA-RS', 'OM-RS', 'FA-OM'):
    contrasts.append(('N_vs_E+'+cond, cond, cond, novices, experts, False))

for test_key, cond1, cond2, group1, group2, paired in contrasts:
    stat_file = _stat_file(test_key)
    (T_obs[test_key], clusters[test_key], cluster_pv[test_key],
     H0[test_key], mask[test_key]) = cluster_perm_test(
        _cond_data(cond1, group1), _cond_data(cond2, group2),
        stat_file=stat_file, connectivity=connectivity, paired=paired)
    print('\t', test_key, time.perf_counter()-ti, '\n\t*****\n')
    ti = time.perf_counter()


T = ti - t0
print('\n*****\nTotal running time:', T)