# NeuroCluster:
<font size= 4> Non-parametric cluster-based permutation testing to identify neurophysiological encoding of continuous variables with time-frequency resolution

Authors: Christina Maher & Alexandra Fink-Skular \
Updated: 05/23/2024 by AFS

In [1]:
import numpy as np
import pandas as pd
import mne
from glob import glob
from scipy.stats import zscore, t, linregress, ttest_ind, ttest_rel, ttest_1samp 
import os 
import re
import h5io
import pickle 
import time 
import datetime 
from joblib import Parallel, delayed
import statsmodels.api as sm 
from scipy.ndimage import label
import statsmodels.formula.api as smf


import warnings
warnings.filterwarnings('ignore')


In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
base_dir = '/Users/alexandrafink/Documents/GraduateSchool/SaezLab/NeuroCluster/'

In [4]:
# load helper functions 
import sys
sys.path.append(base_dir)
from cluster_utils import *
from helper_utils import *

In [5]:
date = datetime.date.today().strftime('%m%d%Y')
print(date)

05262024


In [6]:
# Set directory path for epoched data + set subj_id(s) for analysis
epochs_path = f'{base_dir}test_data/'
subj_id     = 'MS002' #if loading multiple subjects make list & load epochs into dict
# set directory for results 
results_dir = f'{base_dir}results/'

# Step 1: Format Input Data (Currently within-subject)
- neural input: np.array (n_channels x n_epochs x n_freqs x n_times)
- regressor data: np.array (numpy array: n_epochs x n_features)

In [7]:
roi_list =['acc', 'dmpfc', 'amy', 'ains', 'vlpfc', 'hpc', 'smg','mtg', 'ofc','pins', 'dlpfc', 'motor','vmpfc','mcc']

In [7]:
# load epoched data for single subj
power_epochs = mne.time_frequency.read_tfrs(fname=f'{epochs_path}{subj_id}_CpeOnset-tfr.h5')[0]


Reading /Users/alexandrafink/Documents/GraduateSchool/SaezLab/NeuroCluster/test_data/MS002_CpeOnset-tfr.h5 ...
Adding metadata with 19 columns


In [8]:
# explore the TFR data structure 
power_epochs._data.shape #should be shape: trial x channel x freq x time(ms)

(150, 94, 30, 1501)

In [9]:
# store single subj data dim info (num_epochs, num_channels, num_freq, num_time)
num_epochs   = power_epochs._data.shape[0]  # number of epochs
num_channels = power_epochs._data.shape[1] # number of channels
num_freq     = power_epochs._data.shape[2] # number of freq bins
num_time     = power_epochs._data.shape[3] # number of time points

#### need to update below to work for multivariate regressor matrices

In [10]:
# extract behavioral regressors from epochs metadata
regressors   = [col for col in power_epochs.metadata if col not in power_epochs.ch_names] # extract behavioral regressors, we are on the first dimension because there is an option to store different epoch types 
regressors 

['Round',
 'RT',
 'CpeOnset',
 'GambleChoice',
 'TrialType',
 'SafeBet',
 'LowBet',
 'HighBet',
 'GambleEV',
 'Outcome',
 'Profit',
 'TotalProfit',
 'CR',
 'choiceEV',
 'RPE',
 'decisionCPE',
 'decisionRegret',
 'decisionRelief',
 'decisionCF']

In [11]:
# set regressor of interest for analysis 
reg_name    = 'decisionCPE' # make list if multiple 
# reg_data = np.array(power_epochs[0].metadata[reg_name]).reshape(-1, 1)
reg_data = np.array(power_epochs.metadata[reg_name])

In [12]:
reg_data.shape # check shape of regressor data (should be equal to the number of epochs)

(150,)

In [13]:
# predictor_dims = reg_data.shape #### NEED TO FIX FOR MULTIVAR REGS!
predictor_dims = (len(reg_data),1)
tcritical = compute_tcritical(predictor_dims,tails=2, alternative = 'two-sided',alpha=0.05) # this is the critical t value for the given alpha level and degrees of freedom, we will use this to threshold the t values corresponding to the regression coefficients.
tcritical 

1.9761224936033632

# Step 2: Extract Clusters from Univariate Pixel-Wise Regression
- Calculate critical t-statistic (both positive and negative)
- Perform univariate pixel-wise regression with real data for each freq & timepoint 
- Compute t-statistic for pixel-wise regression and compare to critical t-statistic
- Extract largest positive + negative cluster from single-electrode TFR

DEPENDENCIES - tfr_cluster_test, pixel_regression, get_max_cluster

In [63]:
# # get channels that contain 'OFC' in the name
# ofc_channels = [ch for ch in ch_names if 'olf' in ch]
# ofc_data = subset_channels(power_epochs[0], ofc_channels)
# num_ofc_channels = len(ofc_channels)

# # check the amount of memory being used for power_epochs[0] versus ofc_data. 
# print(f'Power epochs: {power_epochs[0]._data.nbytes/1e6} MB')
# print(f'OFC data: {ofc_data.nbytes/1e6} MB')

# del power_epochs # free up memory once you have the data you need 

In [18]:
# Define channels of interest - for now we will use all channels
# save ch names as list 
ch_names     = power_epochs.ch_names # all channels, or make list of desired channels
num_channels = len(ch_names)



In [19]:
ch_names

['lacas1-lacas2',
 'lacas2-lacas3',
 'lacas3-lacas4',
 'lacas4-lacas5',
 'lacas5-lacas6',
 'lacas6-lacas7',
 'lacas7-lacas8',
 'lacas8-lacas9',
 'lacas9-lacas10',
 'lagit1-lagit2',
 'lagit2-lagit3',
 'lagit3-lagit4',
 'lagit6-lagit7',
 'lagit7-lagit8',
 'lagit8-lagit9',
 'laims1-laims2',
 'laims2-laims3',
 'laims3-laims4',
 'laims4-laims5',
 'laims5-laims6',
 'laims6-laims7',
 'laims10-laims11',
 'laims11-laims12',
 'laims12-laims13',
 'laims13-laims14',
 'lhplt1-lhplt2',
 'lhplt2-lhplt3',
 'lhplt3-lhplt4',
 'lhplt4-lhplt5',
 'lhplt8-lhplt9',
 'lhplt9-lhplt10',
 'lhplt10-lhplt11',
 'lhplt11-lhplt12',
 'lloif1-lloif2',
 'lloif2-lloif3',
 'lloif3-lloif4',
 'lloif6-lloif7',
 'lloif7-lloif8',
 'lloif8-lloif9',
 'lmoif2-lmoif3',
 'lmoif8-lmoif9',
 'lmoif9-lmoif10',
 'lmoif10-lmoif11',
 'lmoif11-lmoif12',
 'lmoif12-lmoif13',
 'lpips1-lpips2',
 'lpips2-lpips3',
 'lpips3-lpips4',
 'lpips4-lpips5',
 'lpips9-lpips10',
 'lpips10-lpips11',
 'lpips11-lpips12',
 'lpips12-lpips13',
 'lsif3-lsif4',
 '

In [29]:
start = time.time() # start timer

## run simple linear regression on electrodes in parallel to speed up computation - I did this for just the subset OFC electrodes.
subj_all_elec_data = Parallel(n_jobs=-1,verbose=12)(
    delayed(tfr_cluster_test)(
        np.squeeze(power_epochs._data[:,c,:,:]),reg_data,tcritical
        ) for c in range(num_channels))

end = time.time()    
print('{:.4f} s'.format(end-start)) # print time elapsed for computation (approx 20 seconds per channel)

# save subj cluster data for all electrodes
pickle.dump(subj_all_elec_data, open(f'{results_dir}{subj_id}_all_elec_clusters.pkl', "wb")) 

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done   1 tasks      | elapsed:   17.7s
[Parallel(n_jobs=-1)]: Done   2 tasks      | elapsed:   17.9s
[Parallel(n_jobs=-1)]: Done   3 tasks      | elapsed:   18.0s
[Parallel(n_jobs=-1)]: Done   4 tasks      | elapsed:   18.1s
[Parallel(n_jobs=-1)]: Done   5 tasks      | elapsed:   18.1s
[Parallel(n_jobs=-1)]: Done   6 tasks      | elapsed:   18.1s
[Parallel(n_jobs=-1)]: Done   7 tasks      | elapsed:   18.3s
[Parallel(n_jobs=-1)]: Done   8 tasks      | elapsed:   18.3s
[Parallel(n_jobs=-1)]: Done   9 tasks      | elapsed:   35.3s
[Parallel(n_jobs=-1)]: Done  10 tasks      | elapsed:   35.9s
[Parallel(n_jobs=-1)]: Done  11 tasks      | elapsed:   36.1s
[Parallel(n_jobs=-1)]: Done  12 tasks      | elapsed:   36.2s
[Parallel(n_jobs=-1)]: Done  13 tasks      | elapsed:   36.4s
[Parallel(n_jobs=-1)]: Done  14 tasks      | elapsed:   36.4s
[Parallel(n_jobs=-1)]: Done  15 tasks      | elapsed:   

219.6341 s


In [30]:
subj_all_elec_data

[{'results_betas': array([[-0.16831299, -0.16751641, -0.16669723, ..., -0.1008478 ,
          -0.09928747, -0.09771148],
         [-0.21702832, -0.21621908, -0.21537974, ..., -0.10172934,
          -0.10025334, -0.09876587],
         [-0.21850114, -0.21786115, -0.21718907, ..., -0.04504799,
          -0.04361677, -0.04218972],
         ...,
         [ 0.15554734,  0.16282363,  0.17470995, ...,  0.08015444,
           0.06853781,  0.04299185],
         [ 0.11016591,  0.13260834,  0.15825853, ...,  0.01136862,
          -0.01409919, -0.04441674],
         [-0.01235756,  0.00242384,  0.01988116, ..., -0.05393414,
          -0.08598995, -0.10535038]]),
  'tstat_observed': array([[-1.67115235, -1.6623342 , -1.65331339, ..., -1.04429586,
          -1.02820362, -1.0119635 ],
         [-2.03933363, -2.02944701, -2.01933345, ..., -0.97699585,
          -0.96123758, -0.94544172],
         [-2.07553664, -2.06741306, -2.05898581, ..., -0.38059662,
          -0.36741414, -0.35436289],
         ...,

In [31]:
subj_all_elec_data[0].keys() # check keys for data structure - this is what we save for each elec. results_betas', 'tstat_observed', 'sig_tstat_observed_pos', 'sig_tstat_observed_neg', 'pos_clust_data', 'neg_clust_data'



dict_keys(['results_betas', 'tstat_observed', 'sig_tstat_observed_pos', 'sig_tstat_observed_neg', 'pos_clust_data', 'neg_clust_data'])

# Step 3: Extract Surrogate Clusters from Pixel-wise Permutation
- For loop for each electrode- 
- Run each permutation (1000x) in parallel within electrode loop
- Calculate max cluster p value for each +/- cluster for each electrode
- Save permuted cluster statistics for each electrode 

DEPENDENCIES: permuted_tfr_cluster_test, tfr_cluster_test

In [14]:
# If you need to load the data back in, you can use the following code:
subj_all_elec_data = pickle.load(open(f'{results_dir}{subj_id}_all_elec_clusters.pkl','rb'))
# reload pow_epochs if necessary 
# power_epochs = mne.time_frequency.read_tfrs(fname=f'{epochs_path}{subj_id}_CpeOnset-tfr.h5')[0]

In [22]:
len(subj_all_elec_data)

94

In [76]:
### TEST PERMUTATIONS 
num_permutations = 1000
start = time.time() # start timer

all_ch_perm = {}

for c in range(num_channels):
        ch_start = time.time() # start timer

        # Prepare arguments for the permutation function
        permutation_args = [
        (np.squeeze(power_epochs._data[:,c,:,:]), reg_data, tcritical)
        for _ in range(num_permutations)]
    
        # Perform permutations in parallel
        elec_permuted_data = Parallel(n_jobs=-1, verbose=12)(
        delayed(permuted_tfr_cluster_test)(*args)
        for args in permutation_args)
        
        # save in all elec dict 
        all_ch_perm[ch_names[c]] = elec_permuted_data
        pickle.dump(elec_permuted_data, open(f'{results_dir}{subj_id}_{ch_names[c]}_perm_clusters.pkl', "wb")) 

        ch_end = time.time() 
        print(f'{ch_names[c]} permute time: ', '{:.2f}'.format(ch_end-ch_start))
        
        

end = time.time()    
print('{:.2f} s'.format(end-start)) # print time elapsed for computation (approx 4 seconds per permutation)


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done   1 tasks      | elapsed:   18.2s
[Parallel(n_jobs=-1)]: Done   2 tasks      | elapsed:   18.6s
[Parallel(n_jobs=-1)]: Done   3 tasks      | elapsed:   18.6s
[Parallel(n_jobs=-1)]: Done   4 tasks      | elapsed:   18.6s
[Parallel(n_jobs=-1)]: Done   5 tasks      | elapsed:   18.6s
[Parallel(n_jobs=-1)]: Done   6 tasks      | elapsed:   18.8s
[Parallel(n_jobs=-1)]: Done   7 tasks      | elapsed:   19.1s
[Parallel(n_jobs=-1)]: Done   8 tasks      | elapsed:   19.1s
[Parallel(n_jobs=-1)]: Done   9 tasks      | elapsed:   35.9s
[Parallel(n_jobs=-1)]: Done  10 tasks      | elapsed:   36.3s
[Parallel(n_jobs=-1)]: Done  11 tasks      | elapsed:   36.3s
[Parallel(n_jobs=-1)]: Done  12 tasks      | elapsed:   36.6s
[Parallel(n_jobs=-1)]: Done  13 tasks      | elapsed:   36.6s
[Parallel(n_jobs=-1)]: Done  14 tasks      | elapsed:   36.7s
[Parallel(n_jobs=-1)]: Done  15 tasks      | elapsed:   

lacas1-lacas2 permute time:  2222.30


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done   1 tasks      | elapsed:    3.2s
[Parallel(n_jobs=-1)]: Done   2 tasks      | elapsed:    3.2s
[Parallel(n_jobs=-1)]: Done   3 tasks      | elapsed:    3.2s
[Parallel(n_jobs=-1)]: Done   4 tasks      | elapsed:    3.2s
[Parallel(n_jobs=-1)]: Done   5 tasks      | elapsed:    3.2s
[Parallel(n_jobs=-1)]: Done   6 tasks      | elapsed:    3.2s
[Parallel(n_jobs=-1)]: Done   7 tasks      | elapsed:    3.2s
[Parallel(n_jobs=-1)]: Done   8 tasks      | elapsed:    3.2s
[Parallel(n_jobs=-1)]: Done   9 tasks      | elapsed:    3.2s
[Parallel(n_jobs=-1)]: Done  10 tasks      | elapsed:    3.2s
[Parallel(n_jobs=-1)]: Done  11 tasks      | elapsed:    3.2s
[Parallel(n_jobs=-1)]: Done  12 tasks      | elapsed:    3.2s
[Parallel(n_jobs=-1)]: Done  13 tasks      | elapsed:    3.2s
[Parallel(n_jobs=-1)]: Done  14 tasks      | elapsed:    3.2s
[Parallel(n_jobs=-1)]: Done  15 tasks      | elapsed:   

KeyboardInterrupt: 

In [23]:
c=16
ch_names[c]

'laims2-laims3'

In [24]:
num_permutations = 1000
ch_start = time.time() # start timer

# Prepare arguments for the permutation function
permutation_args = [
(np.squeeze(power_epochs._data[:,c,:,:]), reg_data, tcritical)
for _ in range(num_permutations)]

# Perform permutations in parallel
elec_permuted_data_reduc = Parallel(n_jobs=-1, verbose=12)(
delayed(permuted_tfr_cluster_test)(*args)
for args in permutation_args)

# save in all elec dict 
# all_ch_perm[ch_names[c]] = elec_permuted_data
pickle.dump(elec_permuted_data_reduc, open(f'{results_dir}{subj_id}_{ch_names[c]}_reduced_output_perm_clusters.pkl', "wb")) 

ch_end = time.time() 
print(f'{ch_names[c]} permute time: ', '{:.2f}'.format(ch_end-ch_start))

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done   1 tasks      | elapsed:   17.0s
[Parallel(n_jobs=-1)]: Done   2 tasks      | elapsed:   17.1s
[Parallel(n_jobs=-1)]: Done   3 tasks      | elapsed:   17.1s
[Parallel(n_jobs=-1)]: Done   4 tasks      | elapsed:   17.3s
[Parallel(n_jobs=-1)]: Done   5 tasks      | elapsed:   17.4s
[Parallel(n_jobs=-1)]: Done   6 tasks      | elapsed:   17.5s
[Parallel(n_jobs=-1)]: Done   7 tasks      | elapsed:   17.7s
[Parallel(n_jobs=-1)]: Done   8 tasks      | elapsed:   18.0s
[Parallel(n_jobs=-1)]: Done   9 tasks      | elapsed:   31.9s
[Parallel(n_jobs=-1)]: Done  10 tasks      | elapsed:   32.0s
[Parallel(n_jobs=-1)]: Done  11 tasks      | elapsed:   32.0s
[Parallel(n_jobs=-1)]: Done  12 tasks      | elapsed:   32.1s
[Parallel(n_jobs=-1)]: Done  13 tasks      | elapsed:   32.6s
[Parallel(n_jobs=-1)]: Done  14 tasks      | elapsed:   32.6s
[Parallel(n_jobs=-1)]: Done  15 tasks      | elapsed:   

laims2-laims3 permute time:  2178.48


In [25]:
elec_permuted_data_reduc

[(298.09901001049246, 136.65549636093147),
 (416.62770479299354, 294.02828639583845),
 (213.59620000754634, 161.3735007138127),
 (776.4162480069654, 346.0772862745674),
 (699.5711094275869, 291.25891852679166),
 (263.43614340489853, 482.1334299376421),
 (321.7094325851797, 814.2876139906707),
 (1046.84402514575, 2333.8754487420783),
 (503.92539785009524, 318.03396552392513),
 (3267.6333142172816, 306.45038172574954),
 (1071.8413092866156, 450.88592825587494),
 (342.4957075736979, 541.4328474638614),
 (884.0766553697839, 692.284183355802),
 (1137.2591985985878, 433.34819297661335),
 (660.0653655560752, 619.1936027582708),
 (656.5377522142544, 301.4506391071268),
 (707.5103430220606, 158.99476310451865),
 (532.8078319548499, 148.69888641633457),
 (342.15703060688145, 459.0734808127743),
 (334.6597627149725, 424.63972204863467),
 (103.05975358418993, 1339.9864464398865),
 (446.6581833826948, 242.67833451315664),
 (538.7768915190258, 557.3786981854928),
 (352.86551793200977, 1739.855157561

In [78]:
pickle.dump(elec_permuted_data, open(f'{results_dir}{subj_id}_{ch_names[c]}_perm_clusters.pkl', "wb")) 

In [89]:
elec_permuted_data_reduc

{'results_betas': array([[ 0.07496082,  0.07568205,  0.07640095, ..., -0.0100412 ,
         -0.00988225, -0.00972245],
        [ 0.11049633,  0.11199644,  0.11349419, ...,  0.01570515,
          0.01619038,  0.01668231],
        [ 0.15378925,  0.15533521,  0.15687356, ...,  0.04369581,
          0.0441821 ,  0.04468103],
        ...,
        [ 0.13442341,  0.14565859,  0.15230832, ..., -0.02047693,
          0.02107548,  0.06610168],
        [ 0.07638172,  0.10577456,  0.12961218, ..., -0.08342112,
         -0.05628506, -0.04392064],
        [-0.01453783, -0.02366936, -0.03528101, ..., -0.1266133 ,
         -0.10745436, -0.09402729]]),
 'tstat_observed': array([[ 0.77350381,  0.78096577,  0.78837237, ..., -0.10668448,
         -0.1048461 , -0.1030016 ],
        [ 1.04259185,  1.05652562,  1.07038852, ...,  0.17455808,
          0.17978833,  0.1850744 ],
        [ 1.52371944,  1.53667189,  1.54944371, ...,  0.46651023,
          0.4716797 ,  0.47696856],
        ...,
        [ 1.4442022

In [26]:
# https://github.com/mne-tools/mne-python/blob/maint/1.7/mne/stats/cluster_level.py

# https://numpy.org/doc/stable/reference/generated/numpy.linalg.lstsq.html

# def _get_components(x_in, adjacency, return_list=True):
#     """Get connected components from a mask and a adjacency matrix."""
#     if adjacency is False:
#         components = np.arange(len(x_in))
#     else:
#         mask = np.logical_and(x_in[adjacency.row], x_in[adjacency.col])
#         data = adjacency.data[mask]
#         row = adjacency.row[mask]
#         col = adjacency.col[mask]
#         shape = adjacency.shape
#         idx = np.where(x_in)[0]
#         row = np.concatenate((row, idx))
#         col = np.concatenate((col, idx))
#         data = np.concatenate((data, np.ones(len(idx), dtype=data.dtype)))
#         adjacency = sparse.coo_matrix((data, (row, col)), shape=shape)
#         _, components = connected_components(adjacency)
#     if return_list:
#         start = np.min(components)
#         stop = np.max(components)
#         comp_list = [list() for i in range(start, stop + 1, 1)]
#         mask = np.zeros(len(comp_list), dtype=bool)
#         for ii, comp in enumerate(components):
#             comp_list[comp].append(ii)
#             mask[comp] += x_in[ii]
#         clusters = [np.array(k) for k, m in zip(comp_list, mask) if m]
#         return clusters
#     else:
#         return components


# def _find_clusters(
#     x,
#     threshold,
#     tail=0,
#     adjacency=None,
#     max_step=1,
#     include=None,
#     partitions=None,
#     t_power=1,
#     show_info=False,
# ):
#     """Find all clusters which are above/below a certain threshold.

#     When doing a two-tailed test (tail == 0), only points with the same
#     sign will be clustered together.

#     Parameters
#     ----------
#     x : 1D array
#         Data
#     threshold : float | dict
#         Where to threshold the statistic. Should be negative for tail == -1,
#         and positive for tail == 0 or 1. Can also be an dict for
#         threshold-free cluster enhancement.
#     tail : -1 | 0 | 1
#         Type of comparison
#     adjacency : scipy.sparse.coo_matrix, None, or list
#         Defines adjacency between features. The matrix is assumed to
#         be symmetric and only the upper triangular half is used.
#         If adjacency is a list, it is assumed that each entry stores the
#         indices of the spatial neighbors in a spatio-temporal dataset x.
#         Default is None, i.e, a regular lattice adjacency.
#         False means no adjacency.
#     max_step : int
#         If adjacency is a list, this defines the maximal number of steps
#         between vertices along the second dimension (typically time) to be
#         considered adjacent.
#     include : 1D bool array or None
#         Mask to apply to the data of points to cluster. If None, all points
#         are used.
#     partitions : array of int or None
#         An array (same size as X) of integers indicating which points belong
#         to each partition.
#     t_power : float
#         Power to raise the statistical values (usually t-values) by before
#         summing (sign will be retained). Note that t_power == 0 will give a
#         count of nodes in each cluster, t_power == 1 will weight each node by
#         its statistical score.
#     show_info : bool
#         If True, display information about thresholds used (for TFCE). Should
#         only be done for the standard permutation.

#     Returns
#     -------
#     clusters : list of slices or list of arrays (boolean masks)
#         We use slices for 1D signals and mask to multidimensional
#         arrays.
#     sums : array
#         Sum of x values in clusters.
#     """
#     _check_option("tail", tail, [-1, 0, 1])

#     x = np.asanyarray(x)

#     if not np.isscalar(threshold):
#         if not isinstance(threshold, dict):
#             raise TypeError(
#                 "threshold must be a number, or a dict for "
#                 "threshold-free cluster enhancement"
#             )
#         if not all(key in threshold for key in ["start", "step"]):
#             raise KeyError(
#                 "threshold, if dict, must have at least " '"start" and "step"'
#             )
#         tfce = True
#         use_x = x[np.isfinite(x)]
#         if use_x.size == 0:
#             raise RuntimeError(
#                 "No finite values found in the observed statistic values"
#             )
#         if tail == -1:
#             if threshold["start"] > 0:
#                 raise ValueError('threshold["start"] must be <= 0 for ' "tail == -1")
#             if threshold["step"] >= 0:
#                 raise ValueError('threshold["step"] must be < 0 for ' "tail == -1")
#             stop = np.min(use_x)
#         elif tail == 1:
#             stop = np.max(use_x)
#         else:  # tail == 0
#             stop = max(np.max(use_x), -np.min(use_x))
#         del use_x
#         thresholds = np.arange(threshold["start"], stop, threshold["step"], float)
#         h_power = threshold.get("h_power", 2)
#         e_power = threshold.get("e_power", 0.5)
#         if show_info is True:
#             if len(thresholds) == 0:
#                 warn(
#                     f'threshold["start"] ({threshold["start"]}) is more extreme '
#                     f"than data statistics with most extreme value {stop}"
#                 )
#             else:
#                 logger.info(
#                     "Using %d thresholds from %0.2f to %0.2f for TFCE "
#                     "computation (h_power=%0.2f, e_power=%0.2f)"
#                     % (len(thresholds), thresholds[0], thresholds[-1], h_power, e_power)
#                 )
#         scores = np.zeros(x.size)
#     else:
#         thresholds = [threshold]
#         tfce = False

#     # include all points by default
#     if include is None:
#         include = np.ones(x.shape, dtype=bool)

#     if tail in [0, 1] and not np.all(np.diff(thresholds) > 0):
#         raise ValueError("Thresholds must be monotonically increasing")
#     if tail == -1 and not np.all(np.diff(thresholds) < 0):
#         raise ValueError("Thresholds must be monotonically decreasing")

#     # set these here just in case thresholds == []
#     clusters = list()
#     sums = list()
#     for ti, thresh in enumerate(thresholds):
#         # these need to be reset on each run
#         clusters = list()
#         if tail == 0:
#             x_ins = [
#                 np.logical_and(x > thresh, include),
#                 np.logical_and(x < -thresh, include),
#             ]
#         elif tail == -1:
#             x_ins = [np.logical_and(x < thresh, include)]
#         else:  # tail == 1
#             x_ins = [np.logical_and(x > thresh, include)]
#         # loop over tails
#         for x_in in x_ins:
#             if np.any(x_in):
#                 out = _find_clusters_1dir_parts(
#                     x, x_in, adjacency, max_step, partitions, t_power, ndimage
#                 )
#                 clusters += out[0]
#                 sums.append(out[1])
#         if tfce:
#             # the score of each point is the sum of the h^H * e^E for each
#             # supporting section "rectangle" h x e.
#             if ti == 0:
#                 h = abs(thresh)
#             else:
#                 h = abs(thresh - thresholds[ti - 1])
#             h = h**h_power
#             for c in clusters:
#                 # triage based on cluster storage type
#                 if isinstance(c, slice):
#                     len_c = c.stop - c.start
#                 elif isinstance(c, tuple):
#                     len_c = len(c)
#                 elif c.dtype == np.dtype(bool):
#                     len_c = np.sum(c)
#                 else:
#                     len_c = len(c)
#                 scores[c] += h * (len_c**e_power)
#     # turn sums into array
#     sums = np.concatenate(sums) if sums else np.array([])
#     if tfce:
#         # each point gets treated independently
#         clusters = np.arange(x.size)
#         if adjacency is None or adjacency is False:
#             if x.ndim == 1:
#                 # slices
#                 clusters = [slice(c, c + 1) for c in clusters]
#             else:
#                 # boolean masks (raveled)
#                 clusters = [(clusters == ii).ravel() for ii in range(len(clusters))]
#         else:
#             clusters = [np.array([c]) for c in clusters]
#         sums = scores
#     return clusters, sums

In [87]:
elec_permuted_data[0]

{'results_betas': array([[-0.08016932, -0.07956401, -0.07894769, ...,  0.15106719,
          0.15045894,  0.14985837],
        [-0.0703888 , -0.06882809, -0.06725523, ...,  0.06779217,
          0.06671937,  0.06565613],
        [-0.04358736, -0.04087346, -0.03814483, ...,  0.00626315,
          0.00505182,  0.00385288],
        ...,
        [ 0.09267326,  0.07968578,  0.06926128, ..., -0.08249287,
         -0.07606   , -0.06480062],
        [ 0.00348834,  0.04125545,  0.07354837, ..., -0.15519246,
         -0.17495711, -0.18154361],
        [-0.00987438,  0.02467814,  0.05528575, ..., -0.09622962,
         -0.07429381, -0.07025493]]),
 'tstat_observed': array([[-0.79024462, -0.78389939, -0.77745869, ...,  1.5682283 ,
          1.56208862,  1.55605349],
        [-0.65325513, -0.6380954 , -0.6228677 , ...,  0.64852259,
          0.63724497,  0.62610791],
        [-0.40836779, -0.3825791 , -0.35670152, ...,  0.05277846,
          0.0424461 ,  0.03227953],
        ...,
        [ 0.8036269

# ALIE STOPPED HERE 
To-dos: 
1. Format output dataframes
2. FDR correct cluster p values within-subject across electrodes
3. Visualize within-subj significant clusters


## Format Outputs and Filter clusters based on permutation p value
check later - 
https://mne.tools/stable/auto_examples/stats/sensor_regression.html
https://github.com/mne-tools/mne-python/issues/8832

https://github.com/mne-tools/mne-python/blob/maint/1.7/mne/stats/cluster_level.py#L1167-L1262

In [None]:
# store permutation cluster statistics - one cluster statistic (pos or neg) for each permutation 
cluster_perm_matrix_pos = np.zeros((num_channels, num_permutations))
cluster_perm_matrix_neg = np.zeros((num_channels, num_permutations))
# store real cluster max (pos or neg) values as list (1 value/elec)
real_cluster_max_pos    = []
real_cluster_max_neg    = []
# store p value of real cluster maxes (1 value/elec)
real_cluster_pos_pvals  = []
real_cluster_neg_pvals  = []

In [None]:
# start = time.time()
# num_permutations = 1000

# # store permutation cluster statistics - one cluster statistic (pos or neg) for each permutation 
# cluster_perm_matrix_pos = np.zeros((num_channels, num_permutations))
# cluster_perm_matrix_neg = np.zeros((num_channels, num_permutations))
# # store real cluster max (pos or neg) values as list (1 value/elec)
# real_cluster_max_pos    = []
# real_cluster_max_neg    = []
# # store p value of real cluster maxes (1 value/elec)
# real_cluster_pos_pvals  = []
# real_cluster_neg_pvals  = []

# # set number of permutations 


# for c in range(num_channels):
#     # subset single electrode data
#     elec_data = np.squeeze(power_epochs[0]._data[:,c,:,:]) # matrix is num_epochs x num_freqs x num_time
        
#     #### run permutations in parallel, compute permuted matrix within parallel fn 
#     # elec_permuted_data is list of dictionaries of regression output for permutations 
#     elec_permuted_data = Parallel(n_jobs=-1,verbose=12)(
#         delayed(cont_clust_perm_utils.parallel_permutation_tstat_simple_regression)(
#             elec_data,reward_data,tcritical) for p in range(num_permutations)) 
    
#     # Compute p values for positive/negative real clusters:
    
#     # get real data results 
#     elec_results = subj_all_elec_data[c] # list indices of subj_data correspond to c in num channels
#     # extract maximum positive cluster 
#     real_max_pos = elec_results['pos_clust_data']['max_cluster_tstat']
#     # extract maximum negative cluster 
#     real_max_neg = elec_results['neg_clust_data']['max_cluster_tstat']
    
#     #list of surrogate positive cluster stats only 
#     perm_max_pos_clust = [elec_permuted_data[p]['perm_max_pos_clust_stat'] for p in range(num_permutations)]
#     #list of surrogate negative cluster stats only
#     perm_max_neg_clust = [elec_permuted_data[p]['perm_max_neg_clust_stat'] for p in range(num_permutations)]
    
#     # compute p value for positive cluster max 
#     pos_clust_pval = (np.sum(perm_max_pos_clust>real_max_pos)/num_permutations)
#     # compute p value for negative clutster max 
#     neg_clust_pval = (np.sum(perm_max_neg_clust<real_max_neg)/num_permutations)
    
#     # save real max cluster stat and p value for clusters 
#     real_cluster_max_pos.append(real_max_pos)
#     real_cluster_max_neg.append(real_max_neg)
#     real_cluster_pos_pvals.append(pos_clust_pval)
#     real_cluster_neg_pvals.append(neg_clust_pval)
    
#     # save permutation max pos/min cluster statistics 
#     cluster_perm_matrix_pos[c,:] = pos_clust_pval # list of size n perm
#     cluster_perm_matrix_neg[c,:] = neg_clust_pval


# end = time.time()    
# print('{:.4f} s'.format(end-start))

In [None]:
# from skimage.measure import label, regionprops

# l = label(sig_perm_matrix[0,0,:,:], connectivity=2)
# clusters = [i.coords for i in regionprops(l)]

# # visualize the clusters by plotting sig_perm_matrix[0,0,:,:] and drawing a line around the coordinates in clusters 
# plt.imshow(sig_perm_matrix[0,0,:,:], interpolation = 'Bicubic',cmap='Spectral_r', aspect='auto',origin='lower')
# for i in range(len(clusters)):
#     plt.plot(clusters[i][:,1], clusters[i][:,0], 'r')
# plt.show()

In [None]:
# # plot the the t statistic "TFR" and include the channel name for one example PERMUTATION REP
# num_permutations = 1
# for p in range(num_permutations):
#     tstat_data = tstat_perm_matrix[0,p,:,:] # extract beta coefficients 
#     plt.imshow(tstat_data, interpolation = 'Bicubic',cmap='Spectral_r', aspect='auto',origin='lower',vmin=-.5,vmax=.5) 
#     plt.colorbar()
#     plt.xlabel('Time')
#     plt.ylabel('Freq')
#     plt.title('Permuted Data')
#     plt.show()


In [None]:
# # plot a histogram of the cluster_perm_matrix
# # Flatten the array (if it's a 2D array) to get a 1D array
# cluster_perm_matrix = cluster_perm_matrix.flatten()
# # Plotting the histogram
# plt.hist(cluster_perm_matrix, bins=50, color='blue', edgecolor='black')
# # Adding labels and title
# plt.xlabel('Cluster statistic')
# plt.ylabel('Frequency')
# plt.title('Null distribution')
# plt.axvline(x=cluster_observed_matrix[0], color='r', linestyle='--')
# # Display the plot
# plt.show()


In [None]:
# # plot the the t statistic "TFR" and include the channel name for the OBSERVED DATA 
# ch_names = power_epochs[0].ch_names
# ch_names = ch_names[16:89] # clean this up (temp fix - stop saving these for future reference)

# num_channels = 1
# for c in range(num_channels):
#     tstat_data = tstat_observed_matrix[c,:,:] # extract beta coefficients 
#     plt.imshow(tstat_data, interpolation = 'Bicubic',cmap='Spectral_r', aspect='auto',origin='lower',vmin=-.5,vmax=.5) 
#     plt.colorbar()
#     plt.xlabel('Time')
#     plt.ylabel('Freq')
#     plt.title(f'channel name:{ch_names[c]}')
#     plt.show()

In [None]:
# # Plot "TFR" of beta coefficients for regressor of interest
# # for each channel plot the beta coefficients and include channel name 
# # ch_names = power_epochs[0].ch_names
# # ch_names = ch_names[16:89] # clean this up (temp fix - stop saving these for future reference)

# num_channels = 1
# for c in range(num_channels):
#     beta_data = results_matrix[c,:,:] # extract beta coefficients 
#     plt.imshow(beta_data, interpolation = 'Bicubic',cmap='Spectral_r', aspect='auto',origin='lower',vmin=-.5,vmax=.5) 
#     plt.colorbar()
#     plt.xlabel('Time')
#     plt.ylabel('Freq')
#     plt.title(f'channel name:{ch_names[c]}')
#     plt.show()

In [None]:
# tstat_data = sig_tstat_observed_pos[c,:,:] # extract beta coefficients 
# plt.imshow(tstat_data, interpolation = 'Bicubic',cmap='Spectral_r', aspect='auto',origin='lower',vmin=-.5,vmax=.5) 
# plt.colorbar()
# plt.xlabel('Time')
# plt.ylabel('Freq')
# plt.title(f'channel name:{ch_names[c]}')
# plt.show()

In [None]:
# beta_data = results_betas[c,:,:] # extract beta coefficients 
# plt.imshow(beta_data, interpolation = 'Bicubic',cmap='Spectral_r', aspect='auto',origin='lower',vmin=-.5,vmax=.5) 
# plt.colorbar()
# plt.xlabel('Time')
# plt.ylabel('Freq')
# plt.title(f'channel name:{ch_names[c]}')
# plt.show()

## Step 4: Group-level Significance Testing
- make distribution of surrogate clusters across electrodes 
- determine significance threshold (95% cluster statistic)
- identify significant clusters across electrodes 

*** To decide: SET CLUSTER THRESHOLD WITHIN SUBJ/ WITHIN ROI/ACROSS SUBJ??
