# NeuroCluster:
<font size="4">Non-parametric cluster-based permutation testing to identify neurophysiological encoding of continuous variables with time-frequency resolution</font>

Authors: Christina Maher & Alexandra Fink-Skular \
Updated: 06/24/2024 by AFS

In [1]:
import numpy as np
import pandas as pd
import mne
from glob import glob
from scipy.stats import zscore, t, linregress, ttest_ind, ttest_rel, ttest_1samp 
import os 
import re
import h5io
import pickle 
import time 
import datetime 
from joblib import Parallel, delayed
import statsmodels.api as sm 
from scipy.ndimage import label 
import statsmodels.formula.api as smf
import tqdm
import operator
import matplotlib.pyplot as plt


import warnings
# NOTE(review): this silences *all* Python warnings for the session, including library deprecation warnings.
warnings.filterwarnings('ignore')

# keep this so we can use our respective paths for testing
# (selects which branch of the path / data-loading cells below runs: 'alie' or 'christina')
current_user = 'alie'

In [4]:
# IPython magics: re-import edited local modules (e.g. tfr_cluster_test, helper_utils) before each cell executes
%load_ext autoreload
%autoreload 2

In [5]:
# today's date as an MMDDYYYY string (the commented-out pickle.dump further down embeds it in a filename)
date = datetime.date.today().strftime('%m%d%Y')
print(date)

06252024


In [6]:
# Per-user directory layout. Every later cell relies on these path variables (base_dir, tfr_dir, anat_dir, ...).
if current_user == 'christina':
    base_dir = '/Users/christinamaher/Documents/GitHub/NeuroCluster/scripts/'
    data_dir = '/Users/christinamaher/Documents/GitHub/NeuroCluster/'
    tfr_dir  = f'{data_dir}tfr/'
    anat_dir = f'{data_dir}anat/'
elif current_user == 'alie':
    # base_dir = '/hpc/users/finka03/NeuroCluster/NeuroCluster/'
    # swb_dir  = '/sc/arion/projects/guLab/Alie/SWB/'
    # tfr_dir  = f'{swb_dir}ephys_analysis/data/'
    # beh_dir  = f'{swb_dir}swb_behav_models/data/behavior_preprocessed/'
    # anat_dir = f'{swb_dir}ephys_analysis/recon_labels/'
    # save_dir = f'{base_dir}/data/'
    
    base_dir = '/Users/alexandrafink/Documents/GraduateSchool/SaezLab/NeuroCluster/NeuroCluster/NeuroCluster/scripts/'
    data_dir = '/Users/alexandrafink/Documents/GraduateSchool/SaezLab/SWB/'
    tfr_dir  = f'{data_dir}ephys_analysis/data/'
    beh_dir  = f'{data_dir}behavior_analysis/behavior_preprocessed/'
    anat_dir = f'{data_dir}anat_recons/'
else:
    # Fail fast: without this, an unknown user leaves every path undefined and the
    # notebook only breaks several cells later with a confusing NameError.
    raise ValueError(f"Unknown current_user {current_user!r}; expected 'christina' or 'alie'")

In [15]:
# load functions 
import sys
# make the project scripts directory (set per user above) importable
sys.path.append(f'{base_dir}')
# sys.path.append(f'{base_dir}scripts/')

# NOTE(review): star imports. The cell output further down shows TFR_Cluster_Test resolving to
# tfr_cluster_test.TFR_Cluster_Test; re-running the class-definition cell later in this notebook shadows it.
from tfr_cluster_test import *
from helper_utils import *
# from plotting_utils import * 

# Step 1: Format Input Data (Currently within-subject)
- neural input: np.array (n_epochs x n_channels x n_freqs x n_times); a single channel slice is (n_epochs x n_freqs x n_times)
- regressor data: pd.DataFrame (n_epochs x n_regressors)

In [8]:
# load epoched data for single subj
# (time-frequency power epochs via mne; also sets the regressor of interest `permute_var` for this user)
if current_user == 'alie':
    permute_var = 'decisionCPE'
    subj_id     = 'MS002'   
    # read_tfrs returns a list of TFR objects; keep the first
    power_epochs = mne.time_frequency.read_tfrs(fname=f'{tfr_dir}{subj_id}/{subj_id}_CpeOnset-tfr.h5')[0]
elif current_user == 'christina':
    permute_var = 'ev_zscore'
    subj_id     = 'MS009'   
    # NOTE(review): tfr_dir already ends in '/', so this path contains '//' (harmless on POSIX)
    power_epochs = mne.time_frequency.read_tfrs(fname=f'{tfr_dir}/{subj_id}_tfr.h5')[0]

Reading /Users/alexandrafink/Documents/GraduateSchool/SaezLab/SWB/ephys_analysis/data/MS002/MS002_CpeOnset-tfr.h5 ...
Adding metadata with 19 columns


In [None]:
# set ROI for single ROI analysis and build the behavioral regressor DataFrame (beh_df)
if current_user == 'alie':
#     roi = 'ains'
    # set all variables included in the multiple regression 
    multi_reg_vars = ['GambleChoice','TrialEV','TotalProfit','decisionCPE']
    # set main variable of interest for permutations 
    permute_var = 'decisionCPE'
    # load subj behavior data 
#     beh_df = pd.read_csv(f'{beh_dir}{subj_id}_task_data')
    beh_df = power_epochs.metadata.copy()
    # beh_df['subj_id'] = subj_id
    # add TrialEV to df
    beh_df['TrialEV'] = beh_df.GambleEV - beh_df.SafeBet
    # clean subj dataframe from fail trials/nan values in vars of interest     
    # beh_df = beh_df[(beh_df.GambleChoice=='gamble')|(beh_df.GambleChoice=='safe')]
#     beh_df = beh_df[(beh_df.Outcome=='good')|(beh_df.Outcome=='bad')]
    
    # zscore continuous variables (every regressor except the first, GambleChoice).
    # Assign plain arrays column by column (positional). Assigning a freshly built DataFrame would align
    # its RangeIndex against the metadata index and silently misalign rows / insert NaNs whenever that
    # index is not 0..N-1.
    # NOTE(review): scipy's zscore propagates NaN, so a column containing any NaN becomes all-NaN.
    for var in multi_reg_vars[1:]:
        beh_df[var] = zscore(beh_df[var].to_numpy())
    # format final beh_df
    beh_df = beh_df[multi_reg_vars].reset_index(drop=True) 
    # convert choice to categorical variable (entered as C(GambleChoice) in the pixel-wise formula)
    beh_df['GambleChoice'] = beh_df['GambleChoice'].astype('category')

elif current_user == 'christina':
    beh_df = prepare_regressor_df(power_epochs)
    ## new function for getting elecs in ROI
    roi = ['lpfc','ofc']
    roi_subj_elecs = prepare_anat_dic(roi, f'{anat_dir}master_labels.csv')
    # NOTE(review): bare expression inside an if-branch; it is not displayed by Jupyter
    roi_subj_elecs


In [20]:
#### class TFR_Cluster_Test dev + debugging
# Select one channel's TFR data (n_epochs x n_freqs x n_times) and the matching regressor DataFrame.

if current_user == 'alie':

    # subset single electrode tfr data + behav data
    ch_name    = 'laims2-laims3'
    dev_ch_idx = power_epochs.ch_names.index(ch_name)
    tfr_data   = np.squeeze(power_epochs._data[:,dev_ch_idx,:,:].copy())
    predictor_data = beh_df.copy()
    
    # predictor_data = predictor_data.drop(columns='subj_id')

elif current_user == 'christina':
    
    # subset single electrode tfr data (first channel) + behav data
    # predictor_data = predictor_data.drop(columns=['condition','chosen_shape_current_trial','chosen_color_current_trial','chosen_shape_previous_trial','chosen_color_previous_trial','ev'])
    tfr_data = np.squeeze(power_epochs._data[:,0,:,:].copy())
    ch_name = power_epochs.info['ch_names'][0]
    # Every downstream cell passes `predictor_data`; this branch previously never defined it.
    # NOTE(review): assumes prepare_regressor_df returns only regressor columns — confirm.
    predictor_data = beh_df.copy()

In [21]:
# inspect which regressors will enter the pixel-wise model (every column becomes a formula term)
predictor_data.columns

Index(['subj_id', 'GambleChoice', 'TrialEV', 'TotalProfit', 'decisionCPE'], dtype='object')

## Step 2: Find Real Clusters
- Use the TFR_Cluster_Test class to run a pixel-wise multiple regression (one OLS model per time-frequency pixel).
- Supports multiple regressors and parallelizes the regressions across pixels; with further speed improvements this approach should ultimately be worth its computational cost.

In [23]:
# build the single-electrode test object (per the output below, this is tfr_cluster_test.TFR_Cluster_Test)
# NOTE(review): the notebook's own TFR_Cluster_Test.__init__ accepts **kwargs without storing them, so
# `alternative` may be ignored depending on which class version is in scope — confirm.
cluster_test  = TFR_Cluster_Test(tfr_data,predictor_data,permute_var,ch_name,alternative='two-sided')
cluster_test

<tfr_cluster_test.TFR_Cluster_Test at 0x7f9cda4e7ac0>

In [24]:
# pixel-wise OLS over the whole TFR (parallelized with joblib, see log below): beta and t-statistic maps for permute_var
betas, tstats = cluster_test.tfr_regression()


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done   2 tasks      | elapsed:    4.0s
[Parallel(n_jobs=-1)]: Done  64 tasks      | elapsed:    4.2s
[Parallel(n_jobs=-1)]: Done 1072 tasks      | elapsed:    5.9s
[Parallel(n_jobs=-1)]: Done 3088 tasks      | elapsed:    8.3s
[Parallel(n_jobs=-1)]: Done 5680 tasks      | elapsed:   11.3s
[Parallel(n_jobs=-1)]: Done 8848 tasks      | elapsed:   14.8s
[Parallel(n_jobs=-1)]: Done 12592 tasks      | elapsed:   21.2s
[Parallel(n_jobs=-1)]: Done 16912 tasks      | elapsed:   26.0s
[Parallel(n_jobs=-1)]: Done 21808 tasks      | elapsed:   31.9s
[Parallel(n_jobs=-1)]: Done 27280 tasks      | elapsed:   41.4s
[Parallel(n_jobs=-1)]: Done 33328 tasks      | elapsed:   53.0s
[Parallel(n_jobs=-1)]: Done 39952 tasks      | elapsed:  1.0min
[Parallel(n_jobs=-1)]: Done 45030 out of 45030 | elapsed:  1.1min finished


In [None]:
# largest supra-threshold cluster(s) in the t-map; output='all' keeps the cluster statistic plus freq/time index ranges
max_cluster_data  = cluster_test.max_tfr_cluster(tstats,output='all')


In [None]:
    # NOTE(review): stray line pasted from another script. As live code it was an IndentationError at cell
    # top level and referenced names never defined in this notebook (n_jobs, min_fn, fit_fn, subj_df,
    # param_bounds, param_combo_guesses). Kept for reference only; restore together with those definitions.
    # results = Parallel(n_jobs=n_jobs, verbose=5)(delayed(min_fn)(fit_fn, param_values, (subj_df), param_bounds) for param_values in param_combo_guesses)


In [None]:
### NeuroCluster single electrode workflow: 

# Step 1: Create TFR_Cluster_Test Object
cluster_test  = TFR_Cluster_Test(tfr_data,predictor_data,permute_var,ch_name,alternative='two-sided')

# Step 2: Run TFR regression to extract beta coefficients for predictor of interest (permute_var) & tstats for each pixel in TFR
betas, tstats = cluster_test.tfr_regression()

# Step 3: Find largest cluster(s) and return the max cluster statistic(s) and cluster's  frequencies x times indices
max_cluster_data  = cluster_test.max_tfr_cluster(tstats,output='all')

# NOTE(review): compute_null_cluster_stats and cluster_significance_test are not defined by the TFR_Cluster_Test
# class cell in this notebook (its cluster_significance_test is commented out). Presumably they come from the
# tfr_cluster_test module version of the class — confirm which class is in scope before running.
# Step 4: Create null distribution of maximum cluster statistics from permuted data
null_cluster_distribution = cluster_test.compute_null_cluster_stats(max_cluster_data,num_permutations=10)

# Step 5: Use null cluster statistic distribution from permutations to compute non-parametric p value 
cluster_pvalue = cluster_test.cluster_significance_test(null_cluster_distribution,alpha=0.05) #compute_cluster_pvalue cluster_significance_test



In [None]:
class TFR_Cluster_Test(object):
    """ 
    Single-electrode neurophysiology object class to identify time-frequency resolved neural activity correlates of complex
    behavioral variables using non-parametric cluster-based permutation testing.

    Attributes
    ----------
    tfr_data       : (np.array) Single electrode tfr data matrix. Array of floats (n_epochs,n_freqs,n_times). 
    tfr_dims       : (tuple) Frequency and time dimensions of tfr_data. Tuple of integers (n_freqs,n_times). 
    ch_name        : (str) Unique electrode identification label.
    predictor_data : (pd.DataFrame) Regressors from task behavior with continuous, discrete, or categorical data. DataFrame of (rows=n_epochs,columns=n_regressors). 
    permute_var    : (str) Column label for primary regressor of interest.
      
    Methods
    ----------
    tfr_regression()      : Pixel-wise OLS over the whole tfr; returns beta and t-statistic maps for permute_var.
    pixel_regression()    : Fit the OLS model for a single pixel; returns (beta, tstat) for permute_var.
    max_tfr_cluster()     : Largest supra-threshold cluster(s) in a t-statistic map.
    compute_tcritical()   : Critical t-value for the regression's residual degrees of freedom.
    threshold_tfr_tstat() : Binarize a t-statistic map against the critical t-value.
    expand_tfr_indices()  : All (freq_idx, time_idx) pixel indices of the tfr, in row-major order.
    make_pixel_df()       : Regression DataFrame (predictors + 'pow' column) for a single pixel.
    """

    def __init__(self, tfr_data, predictor_data, permute_var, ch_name, **kwargs):
        """
        Args:
        - tfr_data       : (np.array) Single electrode tfr data matrix. Array of floats (n_epochs,n_freqs,n_times). 
        - predictor_data : (pd.DataFrame) Task-based regressor data with dtypes continuous/discrete(int64/float) or categorical(pd.Categorical). DataFrame of (n_epochs,n_regressors).
        - permute_var    : (str) Column label in predictor_data for the primary regressor of interest.
        - ch_name        : (str) Unique electrode identification label.
        - **kwargs       : (optional) Reserved for alternative, alpha, cluster_shape. Currently accepted but NOT stored or
                           used; pass `alternative` to max_tfr_cluster / threshold_tfr_tstat directly.
        """

        self.tfr_data       = tfr_data  # single electrode tfr data
        self.predictor_data = predictor_data # single subject behav data
        self.tfr_dims       = self.tfr_data.shape[1:] # time-frequency dims of electrode data (n_freqs x n_times)
        self.permute_var    = permute_var # variable to permute in regression model 
        self.ch_name        = ch_name # channel name for single electrode tfr data

    def tfr_regression(self):
        """
        Performs univariate or multiple OLS regression across the tfr matrix for all pixel-level time-frequency power data and
        task-based predictor variables. Regressions are parallelized across pixels.

        Returns:
        - tfr_betas  : (np.array) Matrix of beta coefficients for predictor of interest for each pixel regression. Array of (n_freqs,n_times). 
        - tfr_tstats : (np.array) Matrix of t-statistics from coefficient estimates for predictor of interest for each pixel regression. Array of (n_freqs,n_times). 
        """

        # compute the pixel index list once; it fixes both the job order and the write-back order
        pixel_indices = self.expand_tfr_indices()
        pixel_args = [self.make_pixel_df(self.tfr_data[:,freq_idx,time_idx]) for freq_idx,time_idx in pixel_indices]

        # run pixel regressions in parallel (results come back in submission order)
        expanded_results = Parallel(n_jobs=-1, verbose=5)(
                        delayed(self.pixel_regression)(args)
                            for args in pixel_args)

        # preallocate np arrays for betas + tstats
        tfr_betas  = np.zeros(self.tfr_dims)
        tfr_tstats = np.zeros(self.tfr_dims)

        # expanded_results is a list of (beta,tstat) tuples, one per pixel, aligned with pixel_indices
        for (freq_idx,time_idx),(beta,tstat) in zip(pixel_indices,expanded_results):
            tfr_betas[freq_idx,time_idx]  = beta
            tfr_tstats[freq_idx,time_idx] = tstat

        return tfr_betas, tfr_tstats

    def pixel_regression(self,pixel_df):
        """
        Fit pixel-wise univariate or multiple OLS regression model and extract beta coefficient and t-statistic for predictor of interest (self.permute_var). 

        Args:
        - pixel_df   : (pd.DataFrame) Pixel-level regression dataframe with power epochs data ('pow') and behavioral regressors. DataFrame of (n_epochs, n_regressors+1). 
                                      Regressor column data must be continuous(dtype=float), discrete(dtype=int), or categorical(dtype=pd.Categorical). 
        
        Returns:
        - (pixel_beta, pixel_tval) : tuple of scalar floats — beta coefficient and t-statistic for self.permute_var.
                                     NOTE(review): if permute_var is categorical its coefficient is named 'C(var)[T.level]'
                                     and the lookup below would raise KeyError.
        """

        # formula terms: 'col' for numeric columns, 'C(col)' for categorical columns; 'pow' is the outcome
        regressor_terms = [f'C({col})' if isinstance(pixel_df[col].dtype, pd.CategoricalDtype) else col
                           for col in pixel_df.columns if col != 'pow']
        formula = 'pow ~ 1 + ' + ' + '.join(regressor_terms)

        pixel_model = smf.ols(formula,pixel_df,missing='drop').fit()

        return (pixel_model.params[self.permute_var],pixel_model.tvalues[self.permute_var])

    def max_tfr_cluster(self,tfr_tstats,alternative='two-sided',output='all',clust_struct=None):
        """
        Identify time-frequency clusters of neural activity that are significantly correlated with the predictor of interest (self.permute_var).
        Clusters are groups of neighboring pixels whose regression t-statistics exceed the tcritical threshold of the alternative hypothesis.

        Args:
        - tfr_tstats       : (np.array) Pixel regression tstatistic from coefficient estimates for predictor of interest. Array of floats (n_freqs,n_times). 
        - alternative      : (str) Alternate hypothesis for t-test. Must be 'two-sided','greater', or 'less'. Default is 'two-sided'. 
        - output           : (str) Output format for max cluster statistics. Must be 'all', 'cluster_stat', or 'freq_time'. Default is 'all'.
        - clust_struct     : (np.array) Binary matrix to specify cluster structure for scipy.ndimage.label. Array of (3,3). 
                                        Default (None) means np.ones((3,3)), to allow diagonal cluster pixels (Not the scipy.ndimage.label default).
                                        https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.label.html

        Returns:
        - max_cluster_data : (list) One dict per tested direction that contains at least one supra-threshold pixel
                                    (for 'two-sided': positive direction first, then negative).
                                    If output = 'all', each dict has 'cluster_stat' (sum of the cluster's pixel t-statistics),
                                    'freq_idx' ((min,max) frequency index of the cluster) and 'time_idx' ((min,max) time index).
                                    If output = 'cluster_stat', only {'cluster_stat'}. If output = 'freq_time', only {'freq_idx','time_idx'}.
                                    ** If no clusters are found, max_cluster_data = []

        Raises:
        - ValueError : if output or alternative is not one of the allowed values.
        """

        if output not in ('all','cluster_stat','freq_time'):
            raise ValueError(f"output must be 'all', 'cluster_stat', or 'freq_time', not {output!r}")
        if clust_struct is None:
            clust_struct = np.ones(shape=(3,3))

        max_cluster_data = []
        # Binary matrices from thresholding pixel t-statistics by tcritical (1 = pixel exceeded the threshold)
        for binary_mat in self.threshold_tfr_tstat(tfr_tstats,alternative):
            # skip directions without any supra-threshold pixels
            if not np.any(binary_mat):
                continue
            # label connected supra-threshold pixels: labels 1..num_clusters, 0 = background
            cluster_label, num_clusters = label(binary_mat,clust_struct)
            cluster_sums = [np.sum(tfr_tstats[cluster_label==lab]) for lab in range(1,num_clusters+1)]
            # label id (1-based) of the cluster with the largest absolute t-statistic sum
            max_label = int(np.argmax(np.abs(cluster_sums))) + 1
            # signed cluster statistic (no absolute value)
            max_clust_stat = cluster_sums[max_label-1]
            # frequency / time extent of that cluster (must use the same 1-based label id)
            freq_rows, time_cols = np.where(cluster_label == max_label)
            clust_freqs = (int(freq_rows.min()),int(freq_rows.max()))
            clust_times = (int(time_cols.min()),int(time_cols.max()))

            if output == 'all':
                max_cluster_data.append({'cluster_stat':max_clust_stat,'freq_idx':clust_freqs,'time_idx':clust_times})
            elif output == 'cluster_stat':
                max_cluster_data.append({'cluster_stat':max_clust_stat})
            else: # output == 'freq_time'
                max_cluster_data.append({'freq_idx':clust_freqs,'time_idx':clust_times})

        return max_cluster_data

    def compute_tcritical(self,alternative ='two-sided',alpha=0.05):
        """
        Calculate critical t-values for regression model.
        
        Args:
        - alternative : (str) Alternate hypothesis for t-test. Must be 'two-sided','greater', or 'less'. Default is 'two-sided'.
        - alpha       : (float) Significance level. Default is 0.05.

        Returns:
        - tcritical   : (float) Critical t-statistic for hypothesis test. Positive value when alternative = 'two-sided' or 'greater'. Negative when alternative = 'less'. 
        """

        # two tailed test splits alpha across both tails; 'greater'/'less' are one tailed
        tails = 2 if alternative == 'two-sided' else 1

        # Degrees of freedom (N-k-1), counting one parameter per predictor column.
        # NOTE(review): a categorical column with >2 levels adds more than one parameter, so df is overestimated there.
        deg_free = float(len(self.predictor_data)-len(self.predictor_data.columns)-1) #### predictor data must only include regressors in columns

        tcritical = t.ppf(1-(alpha/tails),deg_free)
        return np.negative(tcritical) if alternative == 'less' else tcritical

    def threshold_tfr_tstat(self,tfr_tstats,alternative='two-sided'):
        """
        Threshold tfr t-statistic matrix using tcritical.

        Args:
        - tfr_tstats  : (np.array) Matrix of t-statistics from pixel-wise regressions. Array of floats (n_freqs, n_times). 
        - alternative : (str) Type of hypothesis test for t-distribution. Must be 'two-sided', 'greater', 'less'. Default is 'two-sided'.

        Returns:
        - binary_mats : (list) Integer 0/1 array(s) of shape (n_freqs, n_times).
                        'two-sided' -> [t > tcrit, t < -tcrit]; 'greater' -> [t > tcrit]; 'less' -> [t < tcrit] (tcrit negative).

        Raises:
        - ValueError : if alternative is not 'two-sided', 'greater', or 'less'.
        """

        if alternative == 'two-sided':
            tcritical = self.compute_tcritical(alternative='two-sided')
            return [(tfr_tstats>tcritical).astype(int), (tfr_tstats<np.negative(tcritical)).astype(int)]

        elif alternative == 'greater':
            return [(tfr_tstats>self.compute_tcritical(alternative='greater')).astype(int)]

        elif alternative == 'less':
            return [(tfr_tstats<self.compute_tcritical(alternative='less')).astype(int)]

        raise ValueError(f'Alternative hypothesis must be two-sided, greater, or less not {alternative}')
    
    def expand_tfr_indices(self):
        """
        Create list of tfr pixel indices for parallelized tfr_regression.

        Returns:
        - iter_tup : (list) (freq_idx, time_idx) tuples for all pixels in tfr_data, row-major (frequency outer, time inner).
        """

        return list(np.ndindex(*self.tfr_dims))

    def make_pixel_df(self,epoch_data):
        """
        Format input data for pixel regression: add single-pixel (frequency x timepoint) tfr power as a 'pow' column to a copy of predictor_data.

        Args:
        - epoch_data : (np.array) Power for one pixel across epochs. Array of 1d integers or floats (n_epochs,).
        
        Returns:
        - pixel_df   : (pd.DataFrame) Pixel regression input dataframe containing power epochs and task-based behavioral regressor data (dtype=int/float/pd.Categorical). 
                                      DataFrame of (n_epochs, n_regressors+1). 
        """
        return self.predictor_data.assign(pow=epoch_data)



##### start permutation functions here

    # def permuted_tfr_regression(self):
    #     """
    #     Run permuted tfr regression 

    #     """

    #     iter_tup = self.expand_tfr_indices()

    #     # either precompute pixel_args before passing to parallel, or run all together in loop. - check later!! 
    #     perm_args = [self.make_pixel_df(self.tfr_data[:,freq_idx,time_idx]) for freq_idx,time_idx in iter_tup]
        
    #     # run pixel permutations in parallel 
    #     permuted_results = Parallel(n_jobs=-1, verbose=5)(
    #                     delayed(self.pixel_regression)(args)
    #                         for args in perm_args)      
        
    #     # preallocate np arrays for betas + tstats
    #     tfr_betas = np.zeros((self.tfr_dims))
    #     tfr_tstats = np.zeros((self.tfr_dims))

    #     # expanded_results is a list of tuples (beta,tstat) for every pixel 
    #     for count,(freq_idx,time_idx) in enumerate(iter_tup):
    #         tfr_betas[freq_idx,time_idx] = permuted_results[count][0]
    #         tfr_tstats[freq_idx,time_idx] = permuted_results[count][1]
        
    #     # return tfr_betas, tfr_tstats
        

    #     # permute predictor data
    #     self.predictor_data = self.permute_predictor() # Permute predictor variable
        
    #     # Run regression on permuted data + extract tstats only
    #     perm_tstats = self.tfr_regression(permutation=True)  
    #     # extract cluster statistics for permutation
    #     perm_cluster_stat = self.max_tfr_cluster(perm_tstats, output='cluster_stat')  # Get cluster statistics
    #     del perm_tstats  # Delete objects to free up memory
    #     # check if any clusters detected in permuted tfr regression 
    #     if len(perm_cluster_stat) !=0 :
    #         return perm_cluster_stat
    #     else: # return list of empty cluster stats with length of alternative hypothesis ('two_sided'=length 2, less or greater = length 1)
    #         return [{'cluster_stat':0}]*len(alternative.split('_'))

    # def permuted_tfr_regression(self):
    #     """
    #     Run permuted tfr regression 
    #     """

    #     # permute predictor data
    #     self.predictor_data = self.permute_predictor() # Permute predictor variable

    #     # Run regression on permuted data + extract tstats only
    #     perm_tstats = self.tfr_regression(permutation=True)  
    #     # extract cluster statistics for permutation
    #     perm_cluster_stat = self.max_tfr_cluster(perm_tstats, output='cluster_stat')  # Get cluster statistics
    #     del perm_tstats  # Delete objects to free up memory
    #     # check if any clusters detected in permuted tfr regression 
    #     if len(perm_cluster_stat) !=0 :
    #         return perm_cluster_stat
    #     else: # return list of empty cluster stats with length of alternative hypothesis ('two_sided'=length 2, less or greater = length 1)
    #         return [{'cluster_stat':0}]*len(alternative.split('_'))



    # def permute_predictor(self):
    #     """
    #     Permute predictor variable of interest for permutation testing.
    #     """

    #     permuted_predictor_data = self.predictor_data.copy()
    #     permuted_predictor_data[self.permute_var] = np.random.permutation(permuted_predictor_data[self.permute_var].values)

    #     return permuted_predictor_data
    

    # def cluster_significance_test(self, null_distribution,alpha=0.05,alternative='two-sided'):
    #     """
    #     Compute non-parametric pvalue from cluster permutation data 

    #             - alpha (float): Significance level. Default is 0.05.

    #     """
        
    #     return cluster_pvalue

Current to-dos: 

- reformat max_cluster_data freq/time indices returns
- permutation class functions 
- speed:
    https://joblib.readthedocs.io/en/latest/parallel.html

In [None]:
# from joblib import Parallel, delayed, parallel_config

# with parallel_config(backend="loky", inner_max_num_threads=2):
#     results = Parallel(n_jobs=4)(delayed(func)(x, y) for x, y in data)

# with parallel_config(backend='custom', endpoint='http://compute',
#                      api_key='42'):
#     Parallel()(delayed(some_function)(i) for i in range(10))

In [None]:
# import multiprocessing

# import numpy as np

# def parallel_apply_along_axis(func1d, axis, arr, *args, **kwargs):
#     """
#     Like numpy.apply_along_axis(), but takes advantage of multiple
#     cores.
#     """        
#     # Effective axis where apply_along_axis() will be applied by each
#     # worker (any non-zero axis number would work, so as to allow the use
#     # of `np.array_split()`, which is only done on axis 0):
#     effective_axis = 1 if axis == 0 else axis
#     if effective_axis != axis:
#         arr = arr.swapaxes(axis, effective_axis)

#     # Chunks for the mapping (only a few chunks):
#     chunks = [(func1d, effective_axis, sub_arr, args, kwargs)
#               for sub_arr in np.array_split(arr, multiprocessing.cpu_count())]

#     pool = multiprocessing.Pool()
#     individual_results = pool.map(unpacking_apply_along_axis, chunks)
#     # Freeing the workers:
#     pool.close()
#     pool.join()

#     return np.concatenate(individual_results)

# def unpacking_apply_along_axis(all_args):
#     """…"""
#     (func1d, axis, arr, args, kwargs) = all_args
#     …

# https://stackoverflow.com/questions/45526700/easy-parallelization-of-numpy-apply-along-axis

In [None]:
# https://stackoverflow.com/questions/76339092/convert-joblib-parallel-to-multiprocessing-pool

In [None]:
# from multiprocessing import Pool
# import numpy as np

# def my_function(x):
#     pass     # do something and return something

# if __name__ == '__main__':    
#     X = np.arange(6).reshape((3,2))
#     pool = Pool(processes = 4)
#     results = pool.map(my_function, map(lambda x: x, X))
#     pool.close()
#     pool.join()

# https://stackoverflow.com/questions/16468717/iterating-over-numpy-matrix-rows-to-apply-a-function-each

# https://stackoverflow.com/questions/8079061/function-application-over-numpys-matrix-row-column

# https://docs.python.org/3/library/itertools.html#itertools.filterfalse

# https://medium.com/pythoneers/vectorization-in-python-an-alternative-to-python-loops-2728d6d7cd3e

# https://realpython.com/python-map-function/

In [None]:
# https://medium.com/pythoneers/vectorization-in-python-an-alternative-to-python-loops-2728d6d7cd3e
# https://realpython.com/python-map-function/

# https://joblib.readthedocs.io/en/latest/auto_examples/parallel_memmap.html#sphx-glr-auto-examples-parallel-memmap-py

In [None]:
# # https://medium.com/@nirmalya.ghosh/13-ways-to-speedup-python-loops-e3ee56cd6b73

# def test_08_v1(n):
#   # Improved version
#   # (Efficiently calculates the nth Fibonacci
#   # number using a generator)
#   a, b = 0, 1
#   for _ in range(n):
#     yield a
#     a, b = b, a + b


#     def some_function_X(x):
#   # This would normally be a function containing application logic
#   # which required it to be made into a separate function
#   # (for the purpose of this test, just calculate and return the square)
#   return x**2

# def test_09_v0(numbers):
#   # Baseline version (Inefficient way)
#   output = []
#   for i in numbers:
#     output.append(some_function_X(i))

#   return output

# def test_09_v1(numbers):
#   # Improved version
#   # (Using Python's built-in map() function)
#   output = map(some_function_X, numbers)
#   return output


# def test_12_v0(numbers):
#   # Baseline version (Inefficient way)
#   filtered_data = []
#   for i in numbers:
#     filtered_data.extend(list(
#         filter(lambda x: x % 5 == 0,
#                 range(1, i**2))))
  
#   return filtered_data

# from itertools import filterfalse



# def test_12_v1(numbers):
#   # Improved version
#   # (using filterfalse)
#   filtered_data = []
#   for i in numbers:
#     filtered_data.extend(list(
#         filterfalse(lambda x: x % 5 != 0,
#                     range(1, i**2))))
    
#     return filtered_data

In [None]:
# Parallel(n_jobs=2, prefer="threads")(
#     delayed(sqrt)(i ** 2) for i in range(10))


# from joblib import parallel_config
# with parallel_config(backend='threading', n_jobs=2):
#    Parallel()(delayed(sqrt)(i ** 2) for i in range(10))

#    shared_set = set()
# def collect(x):
#    shared_set.add(x)

# Parallel(n_jobs=2, require='sharedmem')(
#     delayed(collect)(i) for i in range(5))
# [None, None, None, None, None]


# with Parallel(n_jobs=2) as parallel:
#    accumulator = 0.
#    n_iter = 0
#    while accumulator < 1000:
#        results = parallel(delayed(sqrt)(accumulator + i ** 2)
#                           for i in range(5))
#        accumulator += sum(results)  # synchronization barrier
#        n_iter += 1

# (accumulator, n_iter) 
# # https://joblib.readthedocs.io/en/latest/parallel.html

In [None]:
# #import Pool
# from multiprocessing import Pool
# #Define a worker — a function which will be executed in parallel
# def worker(x):
#  return x*x
# #Assuming you want to use 3 processors
# num_processors = 3
# #Create a pool of processors
# p=Pool(processes = num_processors)
# #get them to work in parallel
# output = p.map(worker,[i for i in range(0,3)])
# print(output)

# from multiprocessing import Pool
# import workers
# if __name__ ==  '__main__': 
#  num_processors = 3
#  p=Pool(processes = num_processors)
#  output = p.map(workers.worker,[i for i in range(0,3)])
#  print(output)
# https://medium.com/@grvsinghal/speed-up-your-code-using-multiprocessing-in-python-36e4e703213e
# https://medium.com/@grvsinghal/speed-up-your-python-code-using-multiprocessing-on-windows-and-jupyter-or-ipython-2714b49d6fac

In [None]:
# # https://stackoverflow.com/questions/42220458/what-does-the-delayed-function-do-when-used-with-joblib-in-python

# import joblib

# @joblib.delayed
# def getHog(image):
#     """Some time-consuming function on an image"""
#     ...

# # Running this in parallel
# with joblib.Parallel(backend="loky", n_jobs=8) as parallel:
#     result = parallel(getHog(img) for img in allImages)

In [None]:
# def my_function(x):
#     """The function you want to compute in parallel."""
#     x += 1
#     return x

# import multiprocessing
# import workers

# pool = multiprocessing.Pool()
# results = pool.map(workers.my_function, [1,2,3,4,5,6])
# # print(results)
# # 
# # https://stackoverflow.com/questions/23641475/multiprocessing-working-in-python-but-not-in-ipython/23641560#23641560



In [None]:
# list the channel (electrode) labels available in the loaded TFR
power_epochs.ch_names

In [None]:
# args = [(np.squeeze(power_epochs._data[:,ch_ix,:,:].copy()),
#                             predictor_data,permute_var,ch_name) for ch_ix, ch_name
#                             in enumerate(power_epochs.ch_names)]

In [None]:
# https://joblib.readthedocs.io/en/latest/parallel.html

In [None]:
### TEST PERMUTATIONS 
start = time.time() # start timer


# one (tfr_data, predictor_data, permute_var, ch_name) argument tuple per channel
args = [(np.squeeze(power_epochs._data[:,ch_ix,:,:].copy()),
         predictor_data,permute_var,ch_name) for ch_ix, ch_name
         in enumerate(power_epochs.ch_names)]

# Build one TFR_Cluster_Test object per channel. The constructor only stores references, so a plain
# comprehension suffices. Dispatching it through joblib worker processes only pickled every channel's
# tfr array out to the workers and back without doing any real work there.
elec_Cluster_objs = [TFR_Cluster_Test(*a) for a in args]

#         pickle.dump(elec_permuted_data, open(f'{results_dir}{subj_id}_{c}_perm_clusters{date}.pkl', "wb")) 

end = time.time()    
print('{:.4f} s'.format(end-start)) # print time elapsed for slicing data + building the per-channel objects

In [None]:
start = time.time() # start timer

## run the tfr regression + max cluster search for every electrode object in parallel.
# Pass the bound method itself to delayed(). The previous `delayed(elec_obj.tfr_cluster_results())` called the
# method eagerly in this process (with its default arguments) and handed its result, which is not callable,
# to delayed.
subj_all_elec_data = Parallel(n_jobs=-1,verbose=5)(
    delayed(elec_obj.tfr_cluster_results)(num_permutations=None)
    for elec_obj in elec_Cluster_objs)

end = time.time()    
print('{:.4f} s'.format(end-start)) # print time elapsed for computation (approx 20 seconds per channel)

# save subj cluster data for all electrodes
# pickle.dump(subj_all_elec_data, open(f'{save_dir}{subj_id}_all_elec_real_clusters.pkl', "wb")) 



In [None]:
# start = time.time() # start timer

# ## run simple linear regression on electrodes in parallel to speed up computation - I did this for just the subset OFC electrodes.
# subj_all_elec_data = Parallel(n_jobs=-1,verbose=5)(
#     delayed(TFR_Cluster_Test.tfr_cluster_results())(
#         np.squeeze(power_epochs._data[:,ch_ix,:,:].copy()),predictor_data,permute_var,ch_name
#         ) for ch_ix, ch_name in enumerate(power_epochs.ch_names))

# end = time.time()    
# print('{:.4f} s'.format(end-start)) # print time elapsed for computation (approx 20 seconds per channel)

# # save subj cluster data for all electrodes
# pickle.dump(subj_all_elec_data, open(f'{save_dir}{subj_id}_all_elec_real_clusters.pkl', "wb")) 



In [None]:
# quick dev check on one channel, passing beh_df directly rather than predictor_data
# NOTE(review): the channel label is hard-coded; it only describes tfr_data when the 'alie' branch selected that channel.
cluster_test  = TFR_Cluster_Test(tfr_data,beh_df,permute_var,'laims2-laims3')
betas, tstats = cluster_test.tfr_regression()
cluster_data  = cluster_test.max_tfr_cluster(tstats)

cluster_data

In [None]:
# smoke test of the full per-channel pipeline with only 10 permutations
# NOTE(review): tfr_cluster_results is not defined by the first TFR_Cluster_Test class cell in this notebook;
# it must come from the later class cell or the tfr_cluster_test module — confirm which is in scope.
test_perm = TFR_Cluster_Test(tfr_data,beh_df,permute_var,'laims2-laims3').tfr_cluster_results(num_permutations=10)
test_perm

In [None]:
# regressor whose coefficient / t-statistic is tested (and permuted)
permute_var

In [None]:
# single-channel TFR shape: (n_epochs, n_freqs, n_times)
tfr_data.shape

In [None]:
# https://docs.python.org/3/tutorial/classes.html#generators

# Step 3: Extract Surrogate Clusters from Pixel-wise Permutation
- Loop over electrodes
- Within each electrode, run the permutations (1000x) in parallel
- Calculate the max-cluster p-value for each positive and negative cluster of each electrode
- Save permuted cluster statistics for each electrode 

DEPENDENCIES: permuted_tfr_cluster_test, tfr_cluster_test

In [None]:
# repeat of the 10-permutation smoke test (same call as earlier in the notebook)
test_perm = TFR_Cluster_Test(tfr_data,beh_df,permute_var,'laims2-laims3').tfr_cluster_results(num_permutations=10)
test_perm

In [None]:
# display the smoke-test result
test_perm

In [None]:
# larger run: 200 permutations for the single dev channel
perm_cluster_results = TFR_Cluster_Test(tfr_data,beh_df,permute_var,'laims2-laims3').tfr_cluster_results(num_permutations=200)
perm_cluster_results

In [None]:
# display the 200-permutation result
perm_cluster_results

In [None]:
# test_full_perm = TFR_Cluster_Test(tfr_data,beh_df,permute_var,ch_name).tfr_cluster_results(num_permutations=200)
# test_full_perm

In [None]:

class TFR_Cluster_Test(object):
    """Class for time-frequency resolved cluster permutation testing of one electrode.

    At every (frequency, time) pixel an OLS model ``pow ~ 1 + predictors`` is fit
    and the t-statistic of ``permute_var`` is kept. The t-map is thresholded,
    contiguous suprathreshold pixels are labelled as clusters, and the largest
    cluster per tail is compared against the largest clusters obtained after
    shuffling ``permute_var``.

    Parameters
    ----------
    tfr_data : array, float, shape (n_epochs, n_freqs, n_times)
        Single electrode tfr data.
    predictor_data : DataFrame, shape (n_epochs, n_regressors)
        Single subject behav data; one row per epoch.
    permute_var : str
        Column of ``predictor_data`` whose coefficient is tested and shuffled.
    ch_name : str
        Channel name, used for printing and labelling results.
    """

    def __init__(self, tfr_data, predictor_data, permute_var, ch_name, **kwargs):
        """Store the data for one electrode.

        Extra keyword arguments are accepted for backward compatibility and ignored.
        """
        self.tfr_data       = tfr_data  # single electrode tfr data
        self.predictor_data = predictor_data  # single subject behav data
        self.tfr_dims       = self.tfr_data.shape[1:]  # (n_freqs, n_times)
        self.permute_var    = permute_var  # variable to permute in regression model
        self.ch_name        = ch_name  # channel name for single electrode tfr data

    def tfr_cluster_results(self, num_permutations=1000, alternative='two-sided'):
        """Find the largest real cluster per tail and compute its permutation p-value.

        Parameters
        ----------
        num_permutations : int or None
            Number of permutations. A falsy value (0/None) skips the permutation
            test and returns the real clusters without p-values.
        alternative : {'two-sided', 'greater', 'less'}
            Tail(s) to test; the default matches the previous hard-coded behaviour.

        Returns
        -------
        list of dict
            One dict per tail that has a suprathreshold cluster, with keys
            'cluster_stat', 'freq_idx', 'time_idx' and, after permutations,
            'elec_id' and 'pvalue'. Empty list when no cluster is found.
        """
        _, tfr_tstats = self.tfr_regression()
        # One entry per tail (None where nothing crosses threshold). Keeping the
        # tail position lets each real cluster be compared to the SAME tail of
        # every permutation, even when one tail is empty.
        tail_clusters = self._tail_clusters(tfr_tstats, alternative=alternative)
        max_cluster_data = [c for c in tail_clusters if c is not None]

        if not max_cluster_data:
            print(f'{self.ch_name} has no clusters')
            return max_cluster_data

        if not num_permutations:  # num_permutations == None/0: do not run permutations
            return max_cluster_data

        print(f'starting tfr cluster permutation tests for {self.ch_name}')

        perm_start = time.time()
        # each permutation returns one {'cluster_stat': ...} per tail (0 when empty)
        permutation_cluster_stats = Parallel(n_jobs=-1, verbose=12)(
            delayed(self.permuted_tfr_regression)(alternative) for _ in range(num_permutations))

        print(f'{self.ch_name} permutation time: ', '{:.2f}'.format(time.time()-perm_start))

        # p-value = fraction of permutations whose same-tail max cluster is larger in magnitude
        for tail_ix, cluster in enumerate(tail_clusters):
            if cluster is None:
                continue
            cluster['elec_id'] = self.ch_name
            real_stat = np.abs(cluster['cluster_stat'])
            perm_counter = sum(np.abs(perm_result[tail_ix]['cluster_stat']) > real_stat
                               for perm_result in permutation_cluster_stats)
            # floor at 1/num_permutations: the minimum possible p value for this many permutations
            cluster['pvalue'] = max(perm_counter, 1) / num_permutations

        # the dicts were updated in place, so max_cluster_data carries the p-values
        return max_cluster_data

    def tfr_regression(self, permutation=False, predictor_data=None):
        """Fit the pixel-wise regression at every (freq, time) pixel.

        Parameters
        ----------
        permutation : bool
            If True, return only the t-statistic map (used by permutations).
        predictor_data : DataFrame or None
            Predictors to regress against; defaults to ``self.predictor_data``.
            Permutations pass a shuffled copy here so the real predictors held
            by the instance are never overwritten.

        Returns
        -------
        (tfr_betas, tfr_tstats) : tuple of arrays, shape (n_freqs, n_times)
            When ``permutation`` is False.
        perm_tstats : array, shape (n_freqs, n_times)
            When ``permutation`` is True.
        """
        iter_tup = self.expand_tfr_indices()
        start = time.time()

        # one regression DataFrame per pixel: predictors + that pixel's power across epochs
        pixel_args = [self.make_pixel_df(self.tfr_data[:, freq_idx, time_idx], predictor_data)
                      for freq_idx, time_idx in iter_tup]

        # run pixel regressions in parallel; joblib returns results in input (iter_tup) order
        expanded_results = Parallel(n_jobs=-1, verbose=12)(
            delayed(self.pixel_regression)(args) for args in pixel_args)

        # expanded_results is a list of (beta, tstat) tuples, one per pixel
        tfr_tstats = np.zeros(self.tfr_dims)
        for (freq_idx, time_idx), (_, tstat) in zip(iter_tup, expanded_results):
            tfr_tstats[freq_idx, time_idx] = tstat

        if permutation:
            # return only tstats for permutation regressions
            return tfr_tstats

        tfr_betas = np.zeros(self.tfr_dims)
        for (freq_idx, time_idx), (beta, _) in zip(iter_tup, expanded_results):
            tfr_betas[freq_idx, time_idx] = beta

        print(f'tfr regression time: ', '{:.2f}'.format(time.time()-start))

        return tfr_betas, tfr_tstats

    def _tail_clusters(self, tfr_tstats, alternative='two-sided', clust_struct=None):
        """Return the largest cluster of each tail, or None for a tail with no suprathreshold pixels.

        Each non-None entry is a dict with 'cluster_stat' (signed sum of t-values in
        the cluster), 'freq_idx' and 'time_idx' ((min, max) index bounds of the cluster).
        """
        if clust_struct is None:
            # full 3x3 structuring element: diagonal neighbours are connected too
            clust_struct = np.ones(shape=(3, 3))

        tail_clusters = []
        for binary_mat in self.threshold_tfr_tstat(tfr_tstats, alternative=alternative):
            if not np.any(binary_mat):  # no pixels beyond t critical in this tail
                tail_clusters.append(None)
                continue
            cluster_label, num_clusters = label(binary_mat, clust_struct)
            # labels run 1..num_clusters (0 is background)
            clust_sums = [np.sum(tfr_tstats[cluster_label == lab]) for lab in range(1, num_clusters + 1)]
            # largest cluster by absolute summed t; the reported stat keeps its sign
            max_label = int(np.argmax(np.abs(clust_sums))) + 1
            clust_freqs, clust_times = [(np.min(arr), np.max(arr))
                                        for arr in np.where(cluster_label == max_label)]
            tail_clusters.append({'cluster_stat': clust_sums[max_label - 1],
                                  'freq_idx': clust_freqs,
                                  'time_idx': clust_times})
        return tail_clusters

    def max_tfr_cluster(self, tfr_tstats, alternative='two-sided', clust_struct=None, output='all'):
        """Return the largest cluster for each tail that has suprathreshold pixels.

        Parameters
        ----------
        tfr_tstats : array, shape (n_freqs, n_times)
            Pixel-wise t-statistics.
        alternative : {'two-sided', 'greater', 'less'}
            Tail(s) to threshold.
        clust_struct : array or None
            Structuring element for ``scipy.ndimage.label``; None means a 3x3 array of ones.
        output : {'all', 'cluster_stat', 'freq_time'}
            Which keys each returned dict contains; any other value yields an empty list.

        Returns
        -------
        list of dict
            One dict per non-empty tail (empty tails are skipped).
        """
        max_cluster_data = []
        for cluster in self._tail_clusters(tfr_tstats, alternative=alternative, clust_struct=clust_struct):
            if cluster is None:
                continue
            if output == 'all':
                max_cluster_data.append(cluster)
            elif output == 'cluster_stat':
                max_cluster_data.append({'cluster_stat': cluster['cluster_stat']})
            elif output == 'freq_time':
                max_cluster_data.append({'freq_idx': cluster['freq_idx'], 'time_idx': cluster['time_idx']})
        return max_cluster_data

    def permuted_tfr_regression(self, alternative='two-sided'):
        """Run one permuted tfr regression and return its max cluster stat per tail.

        Returns
        -------
        list of dict
            Exactly one ``{'cluster_stat': value}`` per tested tail, in the same
            order as ``threshold_tfr_tstat`` (2 for 'two-sided', 1 otherwise);
            the value is 0 for a tail without suprathreshold pixels.
        """
        # shuffled COPY of the predictors; self.predictor_data is left untouched
        permuted_predictors = self.permute_predictor()
        perm_tstats = self.tfr_regression(permutation=True, predictor_data=permuted_predictors)
        tail_clusters = self._tail_clusters(perm_tstats, alternative=alternative)
        del perm_tstats  # free memory before returning
        return [{'cluster_stat': 0 if c is None else c['cluster_stat']} for c in tail_clusters]

    def pixel_regression(self, pixel_df):
        """
        Run pixel-wise OLS regression model to extract beta coefficient and t-statistic.

        Args:
        - pixel_df (pandas df): one row per epoch; a 'pow' column (the pixel's power)
          plus one column per predictor. Categorical-dtype columns are wrapped in C().

        Returns:
        - beta_coeff: coefficient of ``self.permute_var``.
        - tstat_pixel: t-statistic of ``self.permute_var``.
        """
        predictors = [col for col in pixel_df.columns if col != 'pow']
        # formula of the form 'pow ~ 1 + col + C(cat_col)'
        formula = 'pow ~ 1 + ' + ' + '.join(
            f'C({col})' if isinstance(pixel_df[col].dtype, pd.CategoricalDtype) else col
            for col in predictors)

        pixel_model = smf.ols(formula, pixel_df, missing='drop').fit()  # fit regression model

        return (pixel_model.params[self.permute_var], pixel_model.tvalues[self.permute_var])

    def expand_tfr_indices(self):
        """Return every (freq_idx, time_idx) pair of the tfr map in row-major order."""
        return list(np.ndindex(*self.tfr_dims))

    def make_pixel_df(self, epoch_data, predictor_data=None):
        """
        Make pixel-level (frequency x timepoint) dataframe: predictors plus a 'pow' column.

        ``predictor_data`` defaults to ``self.predictor_data``; ``assign`` returns a
        new frame, so the source predictors are not modified.
        """
        if predictor_data is None:
            predictor_data = self.predictor_data
        return predictor_data.assign(pow=epoch_data)

    def compute_tcritical(self, tails=2, alternative='two-sided', alpha=0.05):
        """
        Calculate critical t-value for the pixel regression model.

        Args:
        - tails (int): Number of tails for t-distribution. Default is 2. Options are 1 or 2.
        - alternative (str): 'two-sided', 'greater' or 'less'. 'less' returns the negative value.
        - alpha (float): Significance level. Default is 0.05.

        Returns:
        - tcritical (float): Critical t-value.
        """
        # Residual degrees of freedom of OLS with an intercept: n_obs - n_predictors - 1.
        # NOTE(review): assumes each predictor column adds exactly one coefficient;
        # a categorical column with >2 levels adds more.
        deg_free = float(len(self.predictor_data) - len(self.predictor_data.columns) - 1)
        tcritical = t.ppf(1 - alpha / tails, deg_free)

        return np.negative(tcritical) if alternative == 'less' else tcritical

    def threshold_tfr_tstat(self, tfr_tstats, alternative='two-sided'):
        """Return 0/1 masks of suprathreshold pixels, one per tail.

        'two-sided' -> [positive mask, negative mask]; 'greater' -> [positive mask];
        'less' -> [negative mask].

        Raises:
        - ValueError: if ``alternative`` is not one of the three options.
        """
        if alternative == 'two-sided':
            tcritical = self.compute_tcritical()
            return [(tfr_tstats > tcritical).astype(int), (tfr_tstats < np.negative(tcritical)).astype(int)]

        if alternative == 'greater':
            return [(tfr_tstats > self.compute_tcritical(tails=1, alternative='greater')).astype(int)]

        if alternative == 'less':
            return [(tfr_tstats < self.compute_tcritical(tails=1, alternative='less')).astype(int)]

        raise ValueError(f"alternative must be 'two-sided', 'greater' or 'less', got {alternative!r}")

    def permute_predictor(self):
        """
        Return a copy of the predictor data with ``permute_var`` randomly shuffled.

        The instance's own ``predictor_data`` (and the caller's DataFrame) is not modified.
        """
        permuted = self.predictor_data.copy()
        permuted[self.permute_var] = np.random.permutation(permuted[self.permute_var].values)

        return permuted

In [None]:
# # initialize list to store cluster data
# cluster_list = []

# for p in range(1000):
#     uni_test = TFR_Cluster_Test(tfr_data,pd.DataFrame(test_univar),permute_var,1000)
#     _, uni_tstats = uni_test.tfr_regression()
#     cluster_data = uni_test.max_tfr_cluster(uni_tstats,output='cluster_stat') 
#     # add permutation number to cluster data
#     cluster_data['perm_num'] = p
#     del uni_test, uni_tstats # clear memory
#     cluster_list.append(cluster_data) 



In [None]:
# ### TEST PERMUTATIONS 
# num_permutations = 1000
# start = time.time() # start timer

# all_ch_perm = {}

# for c in range(num_channels):
#         ch_start = time.time() # start timer

#         # Prepare arguments for the permutation function
#         permutation_args = [
#         (np.squeeze(power_epochs._data[:,c,:,:]), reg_data, tcritical)
#         for _ in range(num_permutations)]
    
#         # Perform permutations in parallel
#         elec_permuted_data = Parallel(n_jobs=-1, verbose=12)(
#         delayed(permuted_tfr_cluster_test)(*args)
#         for args in permutation_args)
        
#         # save in all elec dict 
#         all_ch_perm[ch_names[c]] = elec_permuted_data
#         pickle.dump(elec_permuted_data, open(f'{results_dir}{subj_id}_{ch_names[c]}_perm_clusters.pkl', "wb")) 

#         ch_end = time.time() 
#         print(f'{ch_names[c]} permute time: ', '{:.2f}'.format(ch_end-ch_start))
        
        

# end = time.time()    
# print('{:.2f} s'.format(end-start)) # print time elapsed for computation (approx 4 seconds per permutation)


In [None]:
# num_permutations = 1000
# ch_start = time.time() # start timer

# # Prepare arguments for the permutation function
# permutation_args = [
# (np.squeeze(power_epochs._data[:,c,:,:]), reg_data, tcritical)
# for _ in range(num_permutations)]

# # Perform permutations in parallel
# elec_permuted_data_reduc = Parallel(n_jobs=-1, verbose=12)(
# delayed(permuted_tfr_cluster_test)(*args)
# for args in permutation_args)

# # save in all elec dict 
# # all_ch_perm[ch_names[c]] = elec_permuted_data
# pickle.dump(elec_permuted_data_reduc, open(f'{results_dir}{subj_id}_{ch_names[c]}_reduced_output_perm_clusters.pkl', "wb")) 

# ch_end = time.time() 
# print(f'{ch_names[c]} permute time: ', '{:.2f}'.format(ch_end-ch_start))

In [None]:
# To implement FDR correction: 
# https://www.statsmodels.org/dev/generated/statsmodels.stats.multitest.multipletests.html
# multitest.multipletests(p_upper, method='fdr_bh')