# Statistical Analysis

## Repeated measures one-way ANOVA - Session Types

Repeated measures ANOVA on source data with spatio-temporal clustering

## Import

In [None]:
import os.path as op
import numpy as np
import matplotlib.pyplot as plt

import mne
import os
import vtk

from mne.datasets import fetch_fsaverage

from mne import spatial_src_adjacency

from mne_bids import make_report

from mne.minimum_norm import read_inverse_operator, apply_inverse

from mne.stats import (spatio_temporal_cluster_test,
                      summarize_clusters_stc, f_mway_rm, f_threshold_mway_rm)

## Main Directory

In [None]:
# Recordings included in the analysis: subject IDs, working-memory load
# sessions, and ISI task types.
# NOTE(review): subject '01' appears four times - presumably padding used while
# troubleshooting; confirm the list before running the real analysis.
subjects = ['01', '12', '13', '14', '15', '20', '22', '01', '01', '01']
sessions = ['WM1', 'WM3', 'WM5']
# Full task list, kept for reference:
#tasks = ['isi0', 'isi50', 'isi250', 'isi500', 'isi750', 'isi1000', 'isi1250', 'isi1500']
tasks = ['isi1000', 'isi1250', 'isi1500']

# Project root and the two data trees underneath it.
root_path = "C:/Users/trevo/OneDrive/Desktop/Lab_Files/VWM_LAB"
bids_root = op.join(root_path, 'EGI_BIDS')
output_path = op.join(root_path, 'EGI_OUTPUTS')

In [None]:
stats = 'stats'
preprocessed_reports = "03_preprocessed"

# Folder that receives the statistical outputs (cluster images and HTML
# reports). It is created directly under the VWM_LAB project folder.
stats_path = os.path.join("C:/Users/trevo/OneDrive/Desktop/Lab_Files/VWM_LAB/"+stats)

# exist_ok=True keeps the cell safe to re-run and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(stats_path, exist_ok=True)

## Global Variables

In [None]:
# Subject count handed to the F-threshold computation; initialised to zero
# here and assigned its real value in the next cell.
n_subjects = 0

# A single repeated-measures factor with three levels (one per session).
factor_levels = [3]

# Number of permutations used by the spatio-temporal cluster test.
# NOTE(review): 50 is low for a permutation test - presumably a
# troubleshooting value; confirm before the final analysis.
n_permutations = 50

In [None]:
# The ANOVA is run separately for every task, and each session is one level of
# the repeated-measures factor, so every condition array holds exactly one
# observation per entry in `subjects`. Counting every subject/session/task
# combination would inflate the degrees of freedom used for the F threshold,
# and an incrementing counter would also keep growing when the cell is re-run.
n_subjects = len(subjects)

## Set Parameters

In [None]:
# Fetch the fsaverage template files; the directory that contains the
# returned folder is used as `subjects_dir` for plotting.
fs_dir = fetch_fsaverage(verbose=True)
subjects_dir = op.dirname(fs_dir)

# Load the fsaverage ico-5 source space shipped with the template.
src_fname = op.join(fs_dir, 'bem', 'fsaverage-ico-5-src.fif')
src = mne.read_source_spaces(src_fname)

# Spatial adjacency for the cluster test, and the vertex numbers of every
# source space in `src` for building the cluster summary.
adjacency = spatial_src_adjacency(src)
fsave_vertices = [hemi['vertno'] for hemi in src]

## Statistic Pipeline

### Initialize

In [None]:
def initialize():
    """Locate the preprocessed epochs file of the current recording and load it.

    Reads the module-level loop variables ``subject``, ``session`` and
    ``task`` together with ``root_path``, prints a banner naming the file,
    and hands the resulting path to ``epochs_load``.
    """
    printname = f'sub-{subject}_ses-{session}_task-{task}_preprocessed_epo.fif'
    print('######################## ['+printname+'] ########################')

    # Build the path from components rather than hard-coded "\\" separators so
    # it resolves on any operating system (the result is unchanged on Windows).
    preprocessed_report_path = os.path.join(
        root_path, 'EGI_OUTPUTS', f'sub-{subject}', f'ses-{session}',
        'eeg', '03_preprocessed')
    epochs_source_path = os.path.join(preprocessed_report_path, printname)

    epochs_load(epochs_source_path, printname)

### Load Epochs 

In [None]:
def epochs_load(epochs_source_path, printname):
    """Read an epochs file, keep the 'WMD+' and 'TsD-' events, and continue.

    Parameters
    ----------
    epochs_source_path : str
        Path of the epochs ``.fif`` file to read.
    printname : str
        File name of the recording; accepted but not used here.

    The selected epochs are passed on to ``inv_operator``.
    """
    print('######################## Loading Epochs ########################')

    # Read the file and immediately restrict it to the two event types.
    selected = mne.read_epochs(epochs_source_path)['WMD+', 'TsD-']

    print(selected.event_id)

    inv_operator(selected)

### Load Inverse Operator

In [None]:
def inv_operator(epochs):
    """Load the inverse operator of the current recording and continue.

    Reads the module-level loop variables ``subject``, ``session`` and
    ``task`` together with ``root_path`` to locate the ``-inv.fif`` file,
    then calls ``compute_conditions`` with the fixed inverse settings.

    Parameters
    ----------
    epochs : mne.Epochs
        Epochs to pass through to ``compute_conditions``.
    """
    inv_title = f'sub-{subject}_ses-{session}_task-{task}-inv.fif'

    # Build the path from components rather than hard-coded "\\" separators so
    # it resolves on any operating system (the result is unchanged on Windows).
    inv_path = os.path.join(
        root_path, 'EGI_OUTPUTS', f'sub-{subject}', f'ses-{session}',
        'eeg', '06_inv', inv_title)

    snr = 3.0
    lambda2 = 1.0 / snr ** 2  # regularisation derived from the assumed SNR
    method = "dSPM"  # use dSPM method (could also be MNE, sLORETA, or eLORETA)
    inverse_operator = read_inverse_operator(inv_path)

    compute_conditions(epochs, inverse_operator, lambda2, method)

### Compute Conditions

In [None]:
 def compute_conditions(epochs, inverse_operator, lambda2, method):
    print('########################  Computing Condition ########################')
    
    evoked = epochs.average()
    evoked.resample(30).crop(0., None) #for troubleshooting, remove during analysis
    stc = apply_inverse(evoked, inverse_operator, lambda2, method)
     
    stcs.append(stc)
    
    if session == 'WM1':
        A1.append(stc.data)
    elif session == 'WM3':
        A2.append(stc.data)
    elif session == 'WM5':
        A3.append(stc.data)
        

### Compute Dataset

In [None]:
def compute_dataset():
    """Turn the per-session lists into arrays and start the cluster test.

    The module-level lists ``A1``, ``A2`` and ``A3`` are replaced in place by
    stacked arrays whose last two axes are swapped, the resulting shapes and
    the time step of the first entry of ``stcs`` are printed, and
    ``compute_cluster`` is called with the three arrays, the time step and
    the time axis in milliseconds.
    """
    global A1, A2, A3, stcs

    # Stack each list along a new leading axis: observations, space, times
    A1 = np.stack(A1)
    A2 = np.stack(A2)
    A3 = np.stack(A3)
    print(A1.shape)

    # Swap the last two axes: observations, times, space
    A1 = A1.transpose(0, 2, 1)
    A2 = A2.transpose(0, 2, 1)
    A3 = A3.transpose(0, 2, 1)
    print(A1.shape)

    X = [A1, A2, A3]

    n_obs, n_times, n_space = X[0].shape
    print(f'X shape 1 = {n_obs}. X shape 2 = {n_times}. X shape 3 = {n_space}.')

    tstep = stcs[0].tstep
    print(task+' task type tstep = '+str(tstep))

    # Sample index times the step in seconds, converted to milliseconds.
    times = np.arange(n_times) * tstep * 1e3

    compute_cluster(X, tstep, times)

### Spatio-temporal Cluster Test

In [None]:
def stat_fun(*args):
    """Return only the F-values of effect 'A' for the cluster test.

    ``args`` holds one array per condition. The first two axes are swapped
    before the data reach ``f_mway_rm``, so the condition axis ends up second.
    """
    swapped = np.swapaxes(args, 1, 0)
    f_values = f_mway_rm(swapped, factor_levels=factor_levels,
                         effects='A', return_pvals=False)
    # Keep the first entry only, as the cluster test expects a single statistic.
    return f_values[0]

In [None]:
def compute_cluster(X, tstep, times):
    """Run the spatio-temporal cluster test and visualise significant clusters.

    Parameters
    ----------
    X : list of ndarray
        One array per condition; the first axis indexes observations.
    tstep : float
        Time step of the source estimates, forwarded to the visualisation.
    times : ndarray
        Time axis in milliseconds, forwarded to the visualisation.

    When no cluster reaches p < 0.05 a message is printed and the function
    returns without calling ``visualize_clusters``.
    """
    # Take the subject count from the data itself so the F threshold always
    # matches the number of observations that actually enter the test.
    n_observations = X[0].shape[0]
    f_thresh = f_threshold_mway_rm(n_observations, factor_levels=factor_levels,
                                   effects='A', pvalue=.05)
    print('F Threshold for A = '+str(f_thresh))

    print('Clustering.')
    clu = spatio_temporal_cluster_test(X, adjacency=adjacency, n_jobs=1,
                                       threshold=f_thresh, stat_fun=stat_fun,
                                       n_permutations=n_permutations,
                                       buffer_size=None)
    # Only the clusters and their p-values are needed here; the full result
    # tuple is kept in `clu` for the cluster summary.
    _, clusters, cluster_p_values, _ = clu

    # Now select the clusters that are sig. at p < 0.05 (note that this value
    # is multiple-comparisons corrected).
    good_cluster_inds = np.where(cluster_p_values < 0.05)[0]

    # Without a significant cluster there is nothing to summarise or plot;
    # stop here so the remaining tasks can still be processed.
    if good_cluster_inds.size == 0:
        print('No significant clusters (p < 0.05); skipping visualization.')
        return

    visualize_clusters(X, clu, tstep, clusters, good_cluster_inds, times)

### Visualize Clusters

In [None]:
def visualize_clusters(X, clu, tstep, clusters, good_cluster_inds, times):
    """Plot the cluster summary on both hemispheres and save the images.

    A summary source estimate is built from ``clu``, plotted in lateral view
    for the left and then the right hemisphere, and each view is saved as a
    PNG under ``stats_path`` with the current ``task`` in the file name.
    ``X``, ``tstep``, ``clusters``, ``good_cluster_inds`` and ``times`` are
    then forwarded to ``interaction_effect``.
    """
    print('Visualizing clusters.')

    # Summarise the clusters (p < 0.05) as a SourceEstimate on the fsaverage
    # vertices.
    stc_all_cluster_vis = summarize_clusters_stc(
        clu, tstep=tstep, p_thresh=0.05,
        vertices=fsave_vertices, subject='fsaverage')

    # One lateral view per hemisphere; the figures are kept in a dict so both
    # stay referenced until the function returns.
    brains = {}
    for hemi in ('lh', 'rh'):
        brains[hemi] = stc_all_cluster_vis.plot(
            subjects_dir=subjects_dir, views='lat', hemi=hemi,
            time_label='temporal extent (ms)',
            clim=dict(kind='value', lims=[0, 1, 40]))
        brains[hemi].save_image(
            stats_path+'/'+hemi+'_clusters_'+task+'_statistic.png')

    interaction_effect(X, tstep, clusters, good_cluster_inds, times)

### Visualize the Interaction Effect

In [None]:
def interaction_effect(X, tstep, clusters, good_cluster_inds, times):
    """Plot the condition time courses of the first significant cluster.

    For every condition in ``X`` the data are restricted to the vertices of
    the first cluster listed in ``good_cluster_inds``, averaged over vertices
    and observations, and drawn with a shaded spread band. The cluster's time
    range is highlighted, the figure is shown, and it is saved as an HTML
    report under ``stats_path``.

    Parameters
    ----------
    X : list of ndarray
        One array per condition, indexed as observations, times, space.
    tstep : float
        Time step of the source estimates; not used here.
    clusters : list
        Clusters returned by the cluster test, each a pair of time indices
        and vertex indices.
    good_cluster_inds : ndarray
        Indices into ``clusters`` of the significant clusters.
    times : ndarray
        Time axis in milliseconds.

    When ``good_cluster_inds`` is empty a message is printed and nothing is
    plotted.
    """
    if len(good_cluster_inds) == 0:
        print('No significant clusters to plot.')
        return

    inds_t, inds_v = clusters[good_cluster_inds[0]]  # first cluster

    fig = plt.figure()
    colors = ['y', 'b', 'g']
    work_loads = ['WM1', 'WM3', 'WM5']

    # Track the range of every condition so the y-limits fit all curves, not
    # just the one plotted last.
    lows, highs = [], []
    for condition, color, work_load in zip(X, colors, work_loads):
        # extract time course at cluster vertices
        condition = condition[:, :, inds_v]
        # Values are not normalised across subjects; the time series is a
        # plain average across observations and vertices.
        mean_tc = condition.mean(axis=2).mean(axis=0)
        std_tc = condition.std(axis=2).std(axis=0)
        plt.plot(times, mean_tc, color=color, label=work_load)
        plt.fill_between(times, mean_tc + std_tc, mean_tc - std_tc, color='gray',
                         alpha=0.5, label='')
        lows.append(mean_tc.min())
        highs.append(mean_tc.max())

    ymin, ymax = min(lows) - .5, max(highs) + .5
    title = 'Interaction between working memory load sets, '+task+' task type.'

    plt.xlabel('Time (ms)')
    plt.ylabel('Activation (F-values)')
    plt.xlim(times[[0, -1]])
    plt.ylim(ymin, ymax)
    # Highlight the span between the cluster's first and last time index.
    plt.fill_betweenx((ymin, ymax), times[inds_t[0]],
                      times[inds_t[-1]], color='orange', alpha=0.3)
    plt.legend()
    plt.title(title)
    plt.show()

    report = mne.Report(title=title)
    report.add_figure(
        fig=fig, title=title,
        image_format='PNG'
    )
    report.save(stats_path+'/interaction_effect_'+task+'.html', overwrite=True)

## Compute Data - All Tasks

In [None]:
# Run the whole pipeline once per task. The pipeline functions read `task`,
# `subject` and `session` as module-level variables, so the loop variables
# themselves carry the current recording.
for task in tasks:
    # Fresh accumulators for this task.
    stcs, X = [], []
    A1, A2, A3 = [], [], []

    # Load every subject/session recording of the task.
    for subject in subjects:
        for session in sessions:
            initialize()

    # Assemble the arrays and run the cluster statistics.
    compute_dataset()