In [1]:
import matplotlib
matplotlib.use('Qt5Agg')
from header import *
from mne.stats import spatio_temporal_cluster_1samp_test, spatio_temporal_cluster_test, permutation_cluster_1samp_test, permutation_cluster_test, summarize_clusters_stc
from scipy.stats.distributions import f,t
from tqdm import tqdm
import xarray as xr
#warnings.filterwarnings("ignore",category=DeprecationWarning)

In [2]:
t0 = time.perf_counter()
task = 'SMEG' #'MIMOSA'
states = ['RS','FA','OM']

no_blk2 = ['002', '004', '007', '016']  # presumably subjects lacking block 2; not used in this cell
no_mri = ['019', '021']
reject = ['002', '004', '010', '011']

# Build a fresh sorted list instead of mutating the list returned by get_subjlist;
# every occurrence of an excluded subject is dropped.
excluded_subjects = set(no_mri + reject)
subjects = sorted(sub for sub in get_subjlist(task) if sub not in excluded_subjects)  #, include_all=True)

# Split subjects by expertise, also recording each subject's position in `subjects`.
# expertise() is evaluated once per subject; 'N' and 'E' are mutually exclusive.
experts, novices = [], []
experts_i, novices_i = [], []
for s, sub in enumerate(subjects):
    level = expertise(sub)
    if level == 'N':
        novices.append(sub)
        novices_i.append(s)
    elif level == 'E':
        experts.append(sub)
        experts_i.append(s)

In [3]:
psd_path = op.join(Analysis_path, task, 'meg', 'Alpha', 'PSD.nc')
# load() reads the values into memory and returns the same DataArray, which is then
# reordered to (state, subject, freq, chan).
# NOTE(review): the netCDF handle opened here is never closed explicitly.
PSD = xr.open_dataarray(psd_path).load().transpose('state', 'subject', 'freq', 'chan')

# Normalise each (state, subject) spectrum by its total power summed over freq and chan.
# Entries that are NaN in PSD remain NaN after the division.
total_power = PSD.sum(['freq', 'chan'])
PSD_norm = PSD / total_power
print(PSD_norm)

<xarray.DataArray (state: 5, subject: 55, freq: 2049, chan: 275)>
array([[[[1.520515e-06, ..., 8.324049e-07],
         ...,
         [7.038912e-11, ..., 5.819002e-11]],

        ...,

        [[1.743167e-06, ..., 5.782697e-07],
         ...,
         [4.646666e-11, ..., 5.195834e-11]]],


       ...,


       [[[         nan, ...,          nan],
         ...,
         [         nan, ...,          nan]],

        ...,

        [[7.007803e-07, ..., 3.977057e-07],
         ...,
         [7.731861e-11, ..., 7.051495e-11]]]])
Coordinates:
  * subject  (subject) object '007' '012' '014' '016' '018' '028' '030' ...
  * state    (state) object 'RS1' 'FA1' 'FA2' 'OM1' 'OM2'
  * chan     (chan) object 'MLC11' 'MLC12' 'MLC13' 'MLC14' 'MLC15' 'MLC16' ...
  * freq     (freq) float64 0.0 0.1465 0.293 0.4395 0.5859 0.7324 0.8789 ...


In [4]:
# Same axes and coordinates as PSD_norm, except that the numbered state repetitions
# (e.g. FA1, FA2) collapse onto their prefix in `states`.
coords = {dim: PSD_norm.coords[dim].values for dim in PSD_norm.dims}
coords['state'] = states
# Preallocate (float64) and fill in place, so only one full-size array is ever held.
PSD_ave = xr.DataArray(np.empty((len(states),) + PSD.shape[1:]), dims=PSD_norm.dims, coords=coords)
for prefix in states:
    repeats = fnmatch.filter(PSD_norm.state.values, prefix + '*')
    # xarray's mean skips NaN, so a missing repetition does not blank out the average.
    PSD_ave.loc[prefix] = PSD_norm.loc[repeats].mean('state').values
print(PSD_ave)

<xarray.DataArray (state: 3, subject: 55, freq: 2049, chan: 275)>
array([[[[1.520515e-06, ..., 8.324049e-07],
         ...,
         [7.038912e-11, ..., 5.819002e-11]],

        ...,

        [[1.743167e-06, ..., 5.782697e-07],
         ...,
         [4.646666e-11, ..., 5.195834e-11]]],


       ...,


       [[[9.130033e-07, ..., 1.135848e-06],
         ...,
         [1.008259e-10, ..., 7.461664e-11]],

        ...,

        [[7.469942e-07, ..., 4.607345e-07],
         ...,
         [7.579208e-11, ..., 6.598015e-11]]]])
Coordinates:
  * subject  (subject) object '007' '012' '014' '016' '018' '028' '030' ...
  * freq     (freq) float64 0.0 0.1465 0.293 0.4395 0.5859 0.7324 0.8789 ...
  * chan     (chan) object 'MLC11' 'MLC12' 'MLC13' 'MLC14' 'MLC15' 'MLC16' ...
  * state    (state) <U2 'RS' 'FA' 'OM'


In [14]:
def sensor_perm_test(X1, X2, stat_file, test_key, freqs, sensors, mode='a', p_threshold=0.01, connectivity=None, paired=False, fif_significance=0.05, n_jobs=4, seed=None):
    """
    Cluster-based permutation test over (freq, sensor) space.

    If paired, test X1-X2 against 0 with spatio_temporal_cluster_1samp_test
    (two-sided t cluster-forming threshold at p_threshold).
    Otherwise, compare X1 and X2 as independent groups with
    spatio_temporal_cluster_test (default one-way ANOVA F statistic); the
    cluster-forming threshold comes from F(1, n1 + n2 - 2) at 1 - p_threshold,
    which for two groups equals the squared two-sided t threshold.
    If X2 is not an ndarray / DataArray / list (e.g. None), it is replaced by zeros.

    A summary Evoked of the stats is saved if there is a significant cluster (p-value <= fif_significance).
    (Time is replaced by freqs.)
    Saving can be forced by setting fif_significance to 1, or disabled by setting it to 0.
    Evokeds go to '<stat_file without extension>_<test_key>_stat-ave.fif'.

    The statistics (T_stat, p_val, clu_inds over dims (freq, sensor)) are written to
    group `test_key` of the netCDF file `stat_file`, using `mode` if the file exists,
    otherwise 'w'.

    Input: arrays of shape (subjects, freq, space)
    n_jobs and seed are forwarded to the MNE cluster test.
    Returns the tuple from the MNE cluster test: (T_obs, clusters, cluster_pv, H0).
    """
    os.makedirs(op.dirname(stat_file), exist_ok=True)
    evoked_file = op.splitext(stat_file)[0] + '_' + test_key + '_stat-ave.fif'

    # No usable second condition: test X1 against zero.
    if not isinstance(X2, (np.ndarray, xr.DataArray, list)):
        X2 = np.zeros(X1.shape)

    if paired:
        X = X1 - X2
        # Two-sided cluster-forming threshold of a one-sample t-test with n - 1 df.
        t_threshold = -t.ppf(p_threshold / 2, X.shape[0] - 1)
        T_obs, clusters, cluster_pv, H0 = clu_all = spatio_temporal_cluster_1samp_test(
            X, connectivity=connectivity, threshold=t_threshold, n_jobs=n_jobs, seed=seed)
    else:
        n1, n2 = len(X1), len(X2)
        # The default stat_fun is a one-way ANOVA F: df = (n_groups - 1, n_total - n_groups).
        # F is non-negative, so the threshold is one-sided.
        f_threshold = f.ppf(1 - p_threshold, 1, n1 + n2 - 2)
        T_obs, clusters, cluster_pv, H0 = clu_all = spatio_temporal_cluster_test(
            [X1, X2], connectivity=connectivity, threshold=f_threshold, n_jobs=n_jobs, seed=seed)

    # Per-point cluster p-value (1 outside clusters) and 1-based cluster index (0 = none).
    p_val = np.ones_like(T_obs)
    clu_inds = np.zeros_like(T_obs)

    # The Evoked "sampling rate" is the inverse frequency step, so its time axis maps onto freqs.
    sfreq = 1 / (freqs[1] - freqs[0])
    info_file = op.join(Analysis_path, 'MEG', 'meta', 'mag-info.fif')
    if op.isfile(info_file):
        # NOTE(review): assumes the stored info lists exactly `sensors`, in the same order — confirm.
        info = mne.io.read_info(info_file)
        info['sfreq'] = sfreq
    else:
        info = mne.create_info(sensors, sfreq)

    evokeds = []
    for c, clu in enumerate(clusters):
        p_val[clu] = cluster_pv[c]
        clu_inds[clu] = c + 1
        if cluster_pv[c] <= fif_significance:
            # T values inside this cluster only, zero elsewhere.
            data = np.zeros_like(T_obs)
            data[clu] = T_obs[clu]
            evokeds.append(mne.EvokedArray(data.T, info, freqs[0], 'cluster_{}'.format(c + 1)))

    if np.any(p_val <= fif_significance):
        evokeds.append(mne.EvokedArray(np.where(p_val <= fif_significance, T_obs, 0).T, info, freqs[0], 'all_clusters'))
        mne.write_evokeds(evoked_file, evokeds)

    stats = xr.DataArray(np.zeros((3, *T_obs.shape)), dims=['data', 'freq', 'sensor'],
                         coords={'data': ['T_stat', 'p_val', 'clu_inds'], 'freq': freqs, 'sensor': sensors})
    stats.loc['T_stat'] = T_obs
    stats.loc['p_val'] = p_val
    stats.loc['clu_inds'] = clu_inds

    stats.to_netcdf(path=stat_file, group=test_key, mode=mode if op.isfile(stat_file) else 'w')
    return clu_all

In [15]:
# Frequency band (Hz) of the tests; the full range would be PSD_norm.freq.values[[0, -1]].
fmin = 3
fmax = 45
stat_path = op.join(Analysis_path, task, 'meg', 'Stats', 'PSD')
os.makedirs(stat_path, exist_ok=True)
stat_file = op.join(stat_path, '{}-{}Hz.nc'.format(fmin, fmax))

# Every state contrast, run on all subjects (no suffix), experts (+E) and novices (+N).
# Values are (first state, second state, subject list).
state_pairs = [('RS', 'FA'), ('RS', 'OM'), ('OM', 'FA')]
subject_groups = [('', subjects), ('+E', experts), ('+N', novices)]
paired_tests = {'{}_vs_{}{}'.format(cond1, cond2, suffix): (cond1, cond2, members)
                for suffix, members in subject_groups
                for cond1, cond2 in state_pairs}
clu = {}

In [16]:
# Frequency axis and sensor names are identical for every contrast, so compute them once.
band = PSD_ave.sel(freq=slice(fmin, fmax))
band_freqs = band.freq.values
sensor_names = PSD_ave.chan.values.tolist()
for key, (cond1, cond2, members) in paired_tests.items():
    logger.info(key)
    # Same subject list for both conditions, so rows are paired subject-by-subject.
    clu[key] = sensor_perm_test(band.loc[cond1, members].values,
                                band.loc[cond2, members].values,
                                stat_file=stat_file, test_key=key, freqs=band_freqs,
                                sensors=sensor_names, paired=True)

OM_vs_FA
stat_fun(H1): min=-4.039363 max=4.240101
Running initial clustering
Found 370 clusters
Permuting 1023 times...
[........................................................... ] 99.22%  |   

[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:   22.8s remaining:   22.8s


Computing cluster p-values
Done.
Isotrak not found


[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   23.8s remaining:    0.0s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   23.8s finished


RS_vs_FA+N
stat_fun(H1): min=-5.949462 max=6.240852
Running initial clustering
Found 1968 clusters
Permuting 1023 times...
[....................................................        ] 86.82%  \   

[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:   16.9s remaining:   16.9s


[........................................................... ] 99.22%  |   Computing cluster p-values


[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   17.3s remaining:    0.0s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   17.3s finished


Done.
Isotrak not found
RS_vs_FA
stat_fun(H1): min=-6.317864 max=6.428724
Running initial clustering
Found 1751 clusters
Permuting 1023 times...
[....................................................        ] 86.82%  \   

[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:   21.2s remaining:   21.2s


[........................................................... ] 99.22%  |   Computing cluster p-values


[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   22.2s remaining:    0.0s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   22.2s finished


Done.
Isotrak not found
RS_vs_OM+E
stat_fun(H1): min=-6.102058 max=6.604395
Running initial clustering
Found 1733 clusters
Permuting 1023 times...
[....................................................        ] 86.82%  \   

[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:   15.5s remaining:   15.5s


[........................................................... ] 99.22%  |   

[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   16.0s remaining:    0.0s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   16.0s finished


Computing cluster p-values
Done.
Isotrak not found
OM_vs_FA+E
stat_fun(H1): min=-5.098516 max=4.330124
Running initial clustering
Found 247 clusters
Permuting 1023 times...
[....................................................        ] 86.82%  \   

[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:   16.0s remaining:   16.0s


[........................................................... ] 99.22%  |   Computing cluster p-values
Done.
Isotrak not found


[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   16.7s remaining:    0.0s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   16.7s finished


OM_vs_FA+N
stat_fun(H1): min=-4.454789 max=3.820508
Running initial clustering
Found 192 clusters
Permuting 1023 times...
[........................................................... ] 99.22%  |   

[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:   17.5s remaining:   17.5s


Computing cluster p-values
Done.


[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   18.6s remaining:    0.0s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   18.6s finished


Isotrak not found
RS_vs_OM
stat_fun(H1): min=-6.762470 max=7.542230
Running initial clustering
Found 1484 clusters
Permuting 1023 times...
[....................................................        ] 87.84%  \   

[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:   20.2s remaining:   20.2s


[........................................................... ] 99.22%  |   

[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   21.8s remaining:    0.0s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   21.8s finished


Computing cluster p-values
Done.
Isotrak not found
RS_vs_OM+N
stat_fun(H1): min=-6.539066 max=7.260325
Running initial clustering
Found 2022 clusters
Permuting 1023 times...
[....................................................        ] 87.84%  \   

[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:   16.6s remaining:   16.6s


[........................................................... ] 99.22%  |   Computing cluster p-values


[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   17.7s remaining:    0.0s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   17.7s finished


Done.
Isotrak not found
RS_vs_FA+E
stat_fun(H1): min=-6.254659 max=7.159777
Running initial clustering
Found 1846 clusters
Permuting 1023 times...
[........................................................... ] 99.22%  |   

[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:   15.9s remaining:   15.9s


Computing cluster p-values


[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   16.8s remaining:    0.0s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   16.8s finished


Done.
Isotrak not found
