# Defining parameters

In [5]:
import pandas as pd, numpy as np, functools, scipy, xarray as xr
import toolbox, tqdm, pathlib
from typing import List, Tuple, Dict, Any, Literal
import matplotlib.pyplot as plt
import matplotlib as mpl, seaborn as sns
from scipy.stats import gaussian_kde
import math, itertools, pickle, logging, beautifullogger
from xarray_helper import apply_file_func, auto_remove_dim, nunique, apply_file_func_decorator, extract_unique, mk_bins
import scipy.signal
from matplotlib.backends.backend_pdf import PdfPages

xr.set_options(use_flox=True, display_expand_coords=True, display_max_rows=100, display_expand_data_vars=True, display_width=150)
logger = logging.getLogger(__name__)
beautifullogger.setup(displayLevel=logging.INFO)
logging.getLogger("flox").setLevel(logging.WARNING)
tqdm.tqdm.pandas(desc="Computing")

MODE: Literal["TEST", "ALL", "SMALL", "BALANCED"]="ALL"
FIGS = ["coherence_phase_figs_structure"] #["pwelch_mean_figs_structure", "pwelch_mean_figs_species", "coherence_mean_figs_structure"]
DISPLAY=True

match MODE:
    case "TEST":
        cache_path = "/media/julien/data1/JulienCache/Test/"
    case "ALL" | "BALANCED":
        cache_path = f"/media/julien/data1/JulienCache/{'All' if MODE=='ALL' else 'Balanced'}/"
    case "SMALL":
        cache_path = "/media/julien/data1/JulienCache/Small/"


group_cols = ["Species", "Structure", "Healthy"]
pair_group_cols = [x+"_1" for x in group_cols] + [x+"_2" for x in group_cols]
species_order = ["Rat", "Monkey", "Human"]
structure_order = ["GPe", "STN", "STR"]
condition_order = [0, 1]
sig_type_order=["bua", "lfp", "spike_times"]

# Loading dataset

In [6]:
signals: xr.Dataset = pickle.load(open(cache_path+"signals_computed.pkl", "rb"))
signal_pairs: xr.Dataset = pickle.load(open(cache_path+"signal_pairs_computed.pkl", "rb"))
signal_pairs = signal_pairs.where((signal_pairs["FullStructure_1"] != "STN_VMNR") & (signal_pairs["FullStructure_2"] != "STN_VMNR"), drop=True)
signals

# Add group coordinates + diverse stuff to dataset

In [7]:
dataset = xr.merge([signals, signal_pairs])
dataset = dataset[[var for var in dataset.variables if ("pwelch" in var) or "coherence" in var or "duration" in var]]
scipy_freq_coords = dataset["f2"].to_numpy()
dataset = dataset.interp(f=np.linspace(3, 50, 94, endpoint=False), f2=np.linspace(3, 50, 94, endpoint=False))
for col in dataset.variables:
    if "f2" in dataset[col].dims:
        dataset[col] = dataset[col].rename(f2="f_interp")
    if "f" in dataset[col].dims:
        dataset[col] = dataset[col].rename(f="f_interp")
dataset = dataset.drop(["f", "f2"])
dataset["is_scipy_freq"] = dataset["f_interp"].isin(np.round(scipy_freq_coords*2)/2)
dataset = dataset.set_coords("is_scipy_freq")
dataset["pwelch"] = dataset[["pwelch_cwt", "pwelch_spectrogram"]].to_array(dim="spectral_analysis_function")
dataset.drop_vars(["pwelch_cwt", "pwelch_spectrogram"])
dataset["spectral_analysis_function"] = dataset["spectral_analysis_function"].str.replace("pwelch_", "").str.replace("spectrogram", "scipy.spectrogram").str.replace("cwt", "pycwt.cwt")
coherence = dataset[["coherence_scipy"]].to_array(dim="spectral_analysis_function", name="coherence")
coherence["spectral_analysis_function"] = coherence["spectral_analysis_function"].str.replace("coherence_", "").str.replace("scipy", "scipy.coherence").str.replace("wct", "pycwt.wct")
dataset=xr.merge([dataset, coherence])
dataset = dataset.drop_vars(["coherence_scipy"])



for dim, groupcols in dict(Contact=group_cols, Contact_pair=pair_group_cols, sig_preprocessing=["sig_type"], sig_preprocessing_pair=["sig_type_1", "sig_type_2"]).items():
    grpname = dim+"_grp"
    dataset[grpname] = xr.DataArray(
        pd.MultiIndex.from_arrays([dataset[a].data for a in groupcols],names=groupcols), 
        dims=[dim], coords=[dataset[dim]]
    )
    dataset = dataset.set_coords(grpname)

dataset["coherence_norm"] = np.abs(dataset["coherence"])
dataset

# Creating dataset_contact

In [8]:
dataset_contact = xr.Dataset()
dataset_contact["pwelch"] = dataset["pwelch"].groupby("sig_preprocessing_grp").mean()
dataset_contact["coherence"] = dataset["coherence"].groupby("sig_preprocessing_pair_grp").mean()
dataset_contact["coherence_norm"] = dataset["coherence_norm"].groupby("sig_preprocessing_pair_grp").mean()
dataset_contact["duration"] = dataset["duration"].groupby("sig_preprocessing_grp").mean()
dataset_contact["common_duration"] = dataset["common_duration"].groupby("sig_preprocessing_pair_grp").mean()
dataset_contact


# Stats

## Creating bands

In [9]:
bands = xr.Dataset()
bands["band_start"] = xr.DataArray([4, 8, 15, 30], dims="band")
bands["band_end"] = xr.DataArray([8, 15, 30, 49], dims="band")
bands["band"] = ("band", ["before_beta", "low_beta", "high_beta", "after_beta"])
bands = bands.set_coords(["band_start", "band_end"])
dataset_contact = xr.merge([dataset_contact, bands])
dataset_contact

In [10]:
dataset_contact["pwelch_band"] = dataset_contact["pwelch"].where(
    (dataset_contact["f_interp"] >= dataset_contact["band_start"]) & (dataset_contact["f_interp"] <= dataset_contact["band_end"])
).mean("f_interp")
dataset_contact

## Bootstrapping

In [88]:
import concurrent
bootstrap_dataset = xr.Dataset()
n_boostrap = 100
pd.set_option('display.float_format', lambda x: f'{x:.2e}')
# bootstrap_dataset = bootstrap_dataset.assign_coords(significance=[0.05, 10**(-2), 10**(-3), 10**(-6)])

In [106]:
import concurrent
def bootstrap(f, arr: xr.DataArray, sample_dim, n_resamples=10000,  bt_dist="bt_dist", vectorize=False, executor_ctr=None):
    n = arr.sizes[sample_dim]
    if not executor_ctr is None:
        import concurrent
        futures = {}
        executor = executor_ctr().__enter__()
        

    if not vectorize:
        res=xr.Dataset()
        for i in tqdm.tqdm(np.arange(n_resamples), disable=False):
            choices = np.random.choice(np.arange(n), size=n, replace=True)
            sample = arr.isel({sample_dim:choices})
            if executor_ctr is None:
                res[i] = f(sample)
            else:
                futures[i] = executor.submit(f, sample)
        if not executor_ctr is None:
            for i, f in enumerate(concurrent.futures.as_completed(futures.values())):
                res[i] = f.result()
            executor.__exit__()
        return res.to_array(dim=bt_dist)
    else:
        if vectorize is True:
            vectorize=n_resamples
        all_res = []
        for i in range(int(np.ceil(n_resamples/vectorize))):
            samples = np.arange(i*vectorize, min((i+1)* vectorize, n_resamples))
            choices = np.random.choice(np.arange(n), size=len(samples)*arr.size, replace=True).reshape(arr.shape+(len(samples),))
            choices = xr.DataArray(data=choices, dims=[d if d != sample_dim else "bt_sample" for d in arr.dims] + [bt_dist], coords={bt_dist: samples, "bt_sample":np.arange(n)})
            sample: xr.DataArray = arr.isel({sample_dim:choices}, drop=True)
            sample = sample.drop([sample_dim])
            sample = sample.rename(bt_sample = sample_dim)
            if executor_ctr is None:
                all_res.append(f(sample))
            else:
                futures[i] = executor.submit(f, sample)
        if not executor_ctr is None:
            for f in concurrent.futures.as_completed(futures.values()):
                all_res.append(f.result())
            executor.__exit__()
        final= xr.concat(all_res, dim=bt_dist)
        # print(final)
        return final
    
def progress_map(self, f, *args, **kwargs):
    import tqdm
    bar = tqdm.notebook.tqdm(total=len(self.groups))
    def new_f(*a, **kw):
        bar.update(1)
        return f(*a, **kw)
    res = self.map(new_f, *args, **kwargs)
    bar.close()
    return res

xr.core.groupby.DataArrayGroupBy.progress_map = progress_map

def get_p_value(a):
    def pvalue(x):
        if np.isnan(x).any():
            return np.nan, np.nan
        nover = (x > 0).sum()
        nbelow = (x < 0).sum()
        if nover > nbelow:
            return "+", 1 - nover/x.size
        else:
            return "-", 1 - nbelow/x.size
    ret = xr.Dataset()
    ret["dir"], ret["pvalue"] = xr.apply_ufunc(pvalue,a,
        input_core_dims=[["bt_dist"]], output_core_dims=[[], []], output_dtypes=[object, float],  vectorize=True)
    return ret
    
def get_confidence_interval(a, q=[0.05, 0.5, 0.95]):
    return a.quantile(q, dim="bt_dist")



### Pwelch

In [108]:

bootstrap_dataset["pwelch_dist"] = dataset_contact["pwelch_band"].sel(spectral_analysis_function="scipy.spectrogram", sig_type="bua").groupby("Contact_grp").progress_map(
    lambda a: bootstrap(lambda ar: ar.mean(dim="Contact"), a, sample_dim = "Contact", n_resamples=n_boostrap, vectorize=1000, 
                        # executor_ctr=lambda: concurrent.futures.ThreadPoolExecutor(10)
    )).unstack()

get_confidence_interval(bootstrap_dataset["pwelch_dist"].sel(Healthy=0) - bootstrap_dataset["pwelch_dist"].sel(Healthy=1)).to_dataframe().unstack("quantile")

  0%|          | 0/16 [00:00<?, ?it/s]

  return function_base._ureduce(a,



Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,band_start,band_start,band_end,band_end,pwelch_dist,pwelch_dist
Unnamed: 0_level_1,Unnamed: 1_level_1,quantile,5.00e-02,9.50e-01,5.00e-02,9.50e-01,5.00e-02,9.50e-01
band,Species,Structure,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2
before_beta,Human,GPe,4,4,8,8,,
before_beta,Human,STN,4,4,8,8,0.00297,0.00368
before_beta,Human,STR,4,4,8,8,,
before_beta,Monkey,GPe,4,4,8,8,-0.000583,0.00151
before_beta,Monkey,STN,4,4,8,8,0.00215,0.00361
before_beta,Monkey,STR,4,4,8,8,0.0006,0.00222
before_beta,Rat,GPe,4,4,8,8,0.000927,0.00141
before_beta,Rat,STN,4,4,8,8,-0.00854,-0.00344
before_beta,Rat,STR,4,4,8,8,6.69e-05,0.000534
low_beta,Human,GPe,8,8,15,15,,


In [105]:
def compute_max_f(a):
    # print(a)
    # from time import sleep
    ret = a.sel(f_interp=slice(8, 34)).mean(dim="Contact").idxmax("f_interp")
    # print(ret)
    # sleep(1)
    return ret

bootstrap_dataset["pwelch_max_f_dist"] = dataset_contact["pwelch"].sel(spectral_analysis_function="scipy.spectrogram", sig_type="bua").groupby("Contact_grp").progress_map(
    lambda a: bootstrap(compute_max_f, a, sample_dim = "Contact", n_resamples=n_boostrap, vectorize=1000, 
                        # executor_ctr=lambda: concurrent.futures.ThreadPoolExecutor(10)
    )).unstack()

bootstrap_dataset["pwelch_max_f_dist"].mean("bt_dist").to_dataframe()

  0%|          | 0/16 [00:00<?, ?it/s]

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,spectral_analysis_function,sig_preprocessing_grp,sig_type,pwelch_max_f_dist
Species,Structure,Healthy,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
Human,GPe,0,scipy.spectrogram,"(bua,)",bua,16.2
Human,GPe,1,scipy.spectrogram,"(bua,)",bua,
Human,STN,0,scipy.spectrogram,"(bua,)",bua,11.3
Human,STN,1,scipy.spectrogram,"(bua,)",bua,25.4
Human,STR,0,scipy.spectrogram,"(bua,)",bua,18.6
Human,STR,1,scipy.spectrogram,"(bua,)",bua,
Monkey,GPe,0,scipy.spectrogram,"(bua,)",bua,10.1
Monkey,GPe,1,scipy.spectrogram,"(bua,)",bua,8.26
Monkey,STN,0,scipy.spectrogram,"(bua,)",bua,10.7
Monkey,STN,1,scipy.spectrogram,"(bua,)",bua,8.74


In [43]:
import scipy.stats




def mbootstrap(h, p):
    if np.count_nonzero(~np.isnan(h)) <1 or np.count_nonzero(~np.isnan(p))<1:
        return np.full(n_boostrap , np.nan)
    h = h[~np.isnan(h)]
    p = p[~np.isnan(p)]
    r = scipy.stats.bootstrap([p, h], lambda x, y, axis: np.nanmean(x, axis=axis)-np.nanmean(y, axis=axis), n_resamples=n_boostrap, vectorized=True)
    return r.bootstrap_distribution 



def test(x):
    healthy =  x.where(dataset_contact["Healthy"], drop=True)
    park = x.where(~dataset_contact["Healthy"], drop=True)
    r = xr.apply_ufunc(mbootstrap, healthy, park, input_core_dims=[["Contact"]]*2, exclude_dims={"Contact"}, output_core_dims=[["bt_dist"]], vectorize=True)
    return r



data = dataset_contact[["pwelch_band"]].sel(spectral_analysis_function="scipy.spectrogram", sig_type="bua")
data["Structure_grp"] = xr.apply_ufunc(lambda x: str(x[0:-1]), data["Contact_grp"], output_core_dims=[[]], vectorize=True)
data = data.set_coords("Structure_grp")
bootstrap_dataset["pwelch_val"]  = data["pwelch_band"].groupby("Structure_grp").map(test)
bootstrap_dataset

In [148]:
import scipy.stats

class stupid:
    def __init__(self, a):
        self.a = a

def mbootstrap(x, y, f):
    if np.count_nonzero(~np.isnan(x)) <1 or np.count_nonzero(~np.isnan(y))<1:
        return np.full(n_boostrap , np.nan)
    x = x[~np.isnan(x).all(axis=1)]
    y = y[~np.isnan(y).all(axis=1)]

    res = np.empty(n_boostrap)
    for i in tqdm.tqdm(np.arange(n_boostrap), disable=True):
        sample1_index = np.random.choice(np.arange(x.shape[0]), size=x.shape[0])
        sample1 = x[sample1_index]
        sample2_index = np.random.choice(np.arange(y.shape[0]), size=y.shape[0])
        sample2 = y[sample2_index]
        freqs1 = sample1.mean(axis=0)
        freqs2 = sample2.mean(axis=0)
        f1 =f[np.argmax(freqs1)]
        f2 = f[np.argmax(freqs2)]
        res[i] = f1 - f2
    return res



data = dataset_contact[["pwelch"]].sel(spectral_analysis_function="scipy.spectrogram", sig_type="bua", f_interp=slice(8, 34)).where(data["Healthy"]<1, drop=True)
data["Structure_grp"] = xr.apply_ufunc(lambda x: str(x[0:-1]), data["Contact_grp"], output_core_dims=[[]], vectorize=True)
data = data.set_coords("Structure_grp")
res = {}
for grpname1, grp1 in tqdm.tqdm(data["pwelch"].groupby("Structure_grp")):
    
    res[grpname1] = xr.Dataset()
    for grpname2, grp2 in data["pwelch"].groupby("Structure_grp"):
        # tmp1 = xr.apply_ufunc(lambda a: stupid(a) if not np.isnan(a).any() else np.nan, grp1, input_core_dims=[["f_interp"]], output_core_dims=[[]], vectorize=True)
        # tmp2 = xr.apply_ufunc(lambda a: stupid(a) if not np.isnan(a).any() else np.nan, grp2, input_core_dims=[["f_interp"]], output_core_dims=[[]], vectorize=True)
        r = xr.apply_ufunc(mbootstrap, grp1, grp2, data["f_interp"], input_core_dims=[["Contact", "f_interp"]]*2 +[["f_interp"]], exclude_dims={"Contact"}, output_core_dims=[["bt_dist"]], vectorize=True)
        res[grpname1][grpname2] = r

tmp=xr.Dataset()
for n1, arr in tqdm.tqdm(res.items()):
    tmp[n1] = arr.to_array(dim="grp2")

res = tmp.to_array(dim="grp1")
bootstrap_dataset["pwelch_f"]  =  res
bootstrap_dataset

100%|██████████| 9/9 [01:49<00:00, 12.15s/it]
100%|██████████| 9/9 [00:00<00:00, 916.21it/s]


### Computing significance levels

In [149]:
def pvalue(x):
    if np.isnan(x).any():
        return np.nan, np.nan
    nover = (x > 0).sum()
    nbelow = (x < 0).sum()
    if nover > nbelow:
        return "+", 1 - nover/x.size
    else:
        return "-", 1 - nbelow/x.size
    
significance_dataset = xr.Dataset()
for v in bootstrap_dataset.data_vars:
    significance_dataset[[f"{v}_direction", f"{v}_pvalue"]] = xr.apply_ufunc(pvalue,bootstrap_dataset[v],
        input_core_dims=[["bt_dist"]], output_core_dims=[[], []], output_dtypes=[object, float],  vectorize=True)

In [150]:
pd.set_option('display.float_format', lambda x: f'{x:.2e}')
print(significance_dataset[["pwelch_f_direction", "pwelch_f_pvalue"]].to_dataframe().to_string())

                                    pwelch_f_direction  pwelch_f_pvalue
grp1              grp2                                                 
('Human', 'GPe')  ('Human', 'GPe')                   +         6.19e-01
                  ('Human', 'STN')                   +         4.69e-01
                  ('Human', 'STR')                   -         2.85e-01
                  ('Monkey', 'GPe')                  +         4.17e-01
                  ('Monkey', 'STN')                  +         4.57e-01
                  ('Monkey', 'STR')                  +         3.86e-01
                  ('Rat', 'GPe')                     -         2.86e-01
                  ('Rat', 'STN')                     +         4.99e-01
                  ('Rat', 'STR')                     +         3.08e-01
('Human', 'STN')  ('Human', 'GPe')                   -         4.71e-01
                  ('Human', 'STN')                   +         7.20e-01
                  ('Human', 'STR')                   -         1

In [57]:
import scipy.stats
def mbootstrap(x, y):
    print(x.shape, y.shape)

    # x = x[~np.isnan(x)]
    # y = y[~np.isnan(y)]
    def get_max(x, y, axis):
        print(x.shape, y.shape, axis)
        raise Exception("Stop")
    r = scipy.stats.bootstrap((x,y), get_max, axis=[0, 1])
    
    # raise Exception("Stop")
    if np.isnan(r.confidence_interval[1]):
        res=  np.nan
    else:
        if r.confidence_interval[1] < 0:
            res = 1.0
        elif r.confidence_interval[0] > 0:
            res = -1.0
        else: 
            res = 0.0
    # print(res)
    return res

data = dataset_contact[["pwelch"]].sel(spectral_analysis_function="scipy.spectrogram", sig_type="bua").sel(f_interp = slice(8, 34))
data = data.where(data["Healthy"]<1, drop=True)

data["Structure_grp"] = xr.apply_ufunc(lambda x: str(x[1:]), dataset_contact["Contact_grp"], output_core_dims=[[]], vectorize=True)
data = data.set_coords("Structure_grp")

def test(x):
    # print(x)
    vals = ["Monkey", "Human", "Rat"]
    r = []
    for i in range(3):
        for j in range(i+1, 3):
            r.append(xr.apply_ufunc(
                mbootstrap, x.where(x["Species"] == vals[i], drop=True), x.where(x["Species"] == vals[j], drop=True), 
                input_core_dims=[["Contact", "f_interp"]]*2, 
                exclude_dims={"Contact"}, vectorize=True).expand_dims(Sp1=[vals[i]], Sp2 = vals[j]))
    # raise Exception("Stop")
    return xr.merge(r)

res = data["pwelch"].groupby("Structure_grp").map(test)
print(res)
res.sel(spectral_analysis_function="scipy.spectrogram", sig_type="bua").to_dataframe(name="boostrap significative")

(104, 53) (30, 53)


TypeError: int() argument must be a string, a bytes-like object or a real number, not 'list'

# Creating grouped dataset

In [None]:
grouped_dataset = xr.Dataset()

grouped_dataset["pwelch"] = dataset["pwelch"].groupby("sig_preprocessing_grp").mean().groupby("Contact_grp").mean()
grouped_dataset["n_contacts"] = dataset["pwelch"].groupby("sig_preprocessing_grp").mean().groupby("Contact_grp").count("Contact").sel(f_interp=4)
grouped_dataset["duration"] = dataset["duration"].groupby("sig_preprocessing_grp").mean().groupby("Contact_grp").mean("Contact")
grouped_dataset["coherence"] = dataset["coherence"].groupby("sig_preprocessing_pair_grp").mean().groupby("Contact_pair_grp").mean()
grouped_dataset["coherence_norm"] = dataset["coherence_norm"].groupby("sig_preprocessing_pair_grp").mean().groupby("Contact_pair_grp").mean()
grouped_dataset["n_contact_pairs"] = dataset["coherence"].groupby("sig_preprocessing_pair_grp").mean().groupby("Contact_pair_grp").count("Contact_pair").sel(f_interp=4)
grouped_dataset["common_duration"] = dataset["common_duration"].groupby("sig_preprocessing_pair_grp").mean().groupby("Contact_pair_grp").mean("Contact_pair")
grouped_dataset = grouped_dataset.set_coords(["n_contact_pairs", "n_contacts"])

grouped_dataset["coherence_phase"] = xr.apply_ufunc(np.angle, grouped_dataset["coherence"])
grouped_dataset["coherence_validity"] = np.abs(grouped_dataset["coherence"])/grouped_dataset["coherence_norm"]
grouped_dataset["f_max"] = grouped_dataset["coherence_norm"].sel(f_interp=slice(7, 40)).idxmax("f_interp")
grouped_dataset

# Adding the column of interest

In [None]:
print(dataset)
print(grouped_dataset)


<xarray.Dataset>
Dimensions:                     (Contact: 5456, sig_preprocessing: 5, sig_preprocessing_pair: 15, Contact_pair: 27559, f_interp: 94,
                                 spectral_analysis_function: 3)
Coordinates:
  * Contact                     (Contact) int32 0 1 2 3 4 5 6 7 8 9 10 11 12 13 ... 5455 5456 5457 5458 5459 5460 5461 5462 5463 5464 5465 5466 5467
  * sig_preprocessing           (sig_preprocessing) object 'bua' 'lfp' 'neuron_0' 'neuron_1' 'neuron_2'
  * sig_preprocessing_pair      (sig_preprocessing_pair) object MultiIndex
  * sig_preprocessing_1         (sig_preprocessing_pair) object 'bua' 'bua' 'bua' 'bua' 'bua' ... 'neuron_0' 'neuron_1' 'neuron_1' 'neuron_2'
  * sig_preprocessing_2         (sig_preprocessing_pair) object 'bua' 'lfp' 'neuron_0' 'neuron_1' ... 'neuron_2' 'neuron_1' 'neuron_2' 'neuron_2'
  * Contact_pair                (Contact_pair) object MultiIndex
  * Contact_1                   (Contact_pair) int32 139 140 144 145 147 148 150 151 157 158

In [None]:
def f(x: np.ndarray, coord, dim):
    other = dataset[dim].to_numpy()
    print(x.shape, other.shape)
    
    positions = np.where(coord[:, None] == other[None, :])[0]
    res = np.apply_along_axis(lambda x: x[positions], -1, x)
    return res

f_max_tmp = xr.apply_ufunc(lambda x, y: f(x, y, "Contact_pair_grp"), grouped_dataset["f_max"], grouped_dataset["Contact_pair_grp"], input_core_dims=[["Contact_pair_grp"], ["Contact_pair_grp"]], output_core_dims=[["Contact_pair"]])
f_max_tmp = xr.apply_ufunc(lambda x, y: f(x, y, "sig_preprocessing_pair_grp"), f_max_tmp, grouped_dataset["sig_preprocessing_pair_grp"], input_core_dims=[["sig_preprocessing_pair_grp"], ["sig_preprocessing_pair_grp"]], output_core_dims=[["sig_preprocessing_pair"]])
dataset["f_max_coherence"] = f_max_tmp
print(dataset)

(3, 6, 35) (27559,)
(3, 27559, 6) (15,)
<xarray.Dataset>
Dimensions:                     (Contact: 5456, sig_preprocessing: 5, sig_preprocessing_pair: 15, Contact_pair: 27559, f_interp: 94,
                                 spectral_analysis_function: 3)
Coordinates:
  * Contact                     (Contact) int32 0 1 2 3 4 5 6 7 8 9 10 11 12 13 ... 5455 5456 5457 5458 5459 5460 5461 5462 5463 5464 5465 5466 5467
  * sig_preprocessing           (sig_preprocessing) object 'bua' 'lfp' 'neuron_0' 'neuron_1' 'neuron_2'
  * sig_preprocessing_pair      (sig_preprocessing_pair) object MultiIndex
  * sig_preprocessing_1         (sig_preprocessing_pair) object 'bua' 'bua' 'bua' 'bua' 'bua' ... 'neuron_0' 'neuron_1' 'neuron_1' 'neuron_2'
  * sig_preprocessing_2         (sig_preprocessing_pair) object 'bua' 'lfp' 'neuron_0' 'neuron_1' ... 'neuron_2' 'neuron_1' 'neuron_2' 'neuron_2'
  * Contact_pair                (Contact_pair) object MultiIndex
  * Contact_1                   (Contact_pair) int32

In [None]:

tmp_group = dataset["coherence"].sel(f_interp=dataset["f_max_coherence"].fillna(4)).groupby("sig_preprocessing_pair_grp").mean().groupby("Contact_pair_grp")


In [None]:

phase_bins = np.linspace(-np.pi, np.pi, 180)

def mk_phase_info(x: xr.DataArray):
    m = x*np.exp(-1j * xr.DataArray(phase_bins, dims=["phase_bins"], coords=dict(phase_bins=phase_bins)))
    angle = xr.apply_ufunc(np.angle, m)
    dot = np.real(m)
    res = dot.where(abs(angle) <np.pi/5).sum("Contact_pair")
    return res

grouped_dataset["coherence_phase_bins"] = tmp_group.map(mk_phase_info)
print(grouped_dataset["coherence_phase_bins"])

<xarray.DataArray 'coherence_phase_bins' (spectral_analysis_function: 3, sig_preprocessing_pair_grp: 6, Contact_pair_grp: 35, phase_bins: 180)>
array([[[[ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ],
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ],
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ],
         ...,
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ],
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ],
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ]],

        [[ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ],
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ],
        

In [None]:
from xarray_helper import normalize


# Post handling

In [None]:


# print(grouped_dataset["coherence_phase_bins"].sel(f_interp=grouped_dataset["f_max"].fillna(grouped_dataset["f_interp"].min())))
# exit()

selected = (grouped_dataset["f_interp"] > grouped_dataset["f_max"]-2) & (grouped_dataset["f_interp"] < grouped_dataset["f_max"]+2) & grouped_dataset["is_scipy_freq"]
grouped_dataset["coherence_phase"] = grouped_dataset["coherence_phase"].where(selected)

for col in group_cols:
    if grouped_dataset[f"{col}_1"].equals(grouped_dataset[f"{col}_2"]):
        grouped_dataset[col+"(common)"] = grouped_dataset[f"{col}_1"]
        grouped_dataset=grouped_dataset.set_coords(col+"(common)")
    if dataset_contact[f"{col}_1"].equals(dataset_contact[f"{col}_2"]):
        dataset_contact[col+"(common)"] = dataset_contact[f"{col}_1"]
        dataset_contact=dataset_contact.set_coords(col+"(common)")

print(grouped_dataset)

<xarray.Dataset>
Dimensions:                     (f_interp: 94, spectral_analysis_function: 3, sig_preprocessing_grp: 3, Contact_grp: 16,
                                 sig_preprocessing_pair_grp: 6, Contact_pair_grp: 35, phase_bins: 180)
Coordinates:
  * f_interp                    (f_interp) float64 3.0 3.5 4.0 4.5 5.0 5.5 6.0 6.5 7.0 7.5 8.0 ... 45.0 45.5 46.0 46.5 47.0 47.5 48.0 48.5 49.0 49.5
  * spectral_analysis_function  (spectral_analysis_function) object 'pycwt.cwt' 'scipy.coherence' 'scipy.spectrogram'
    is_scipy_freq               (f_interp) bool True False True False True False True False True ... False True False True False True False True False
  * sig_preprocessing_grp       (sig_preprocessing_grp) object MultiIndex
  * sig_type                    (sig_preprocessing_grp) object 'bua' 'lfp' 'spike_times'
  * Contact_grp                 (Contact_grp) object MultiIndex
  * Species                     (Contact_grp) object 'Human' 'Human' 'Human' 'Human' 'Monkey' 'Monkey

In [None]:
print(grouped_dataset)

<xarray.Dataset>
Dimensions:                     (f_interp: 94, spectral_analysis_function: 3, sig_preprocessing_grp: 3, Contact_grp: 16,
                                 sig_preprocessing_pair_grp: 6, Contact_pair_grp: 35, phase_bins: 180)
Coordinates:
  * f_interp                    (f_interp) float64 3.0 3.5 4.0 4.5 5.0 5.5 6.0 6.5 7.0 7.5 8.0 ... 45.0 45.5 46.0 46.5 47.0 47.5 48.0 48.5 49.0 49.5
  * spectral_analysis_function  (spectral_analysis_function) object 'pycwt.cwt' 'scipy.coherence' 'scipy.spectrogram'
    is_scipy_freq               (f_interp) bool True False True False True False True False True ... False True False True False True False True False
  * sig_preprocessing_grp       (sig_preprocessing_grp) object MultiIndex
  * sig_type                    (sig_preprocessing_grp) object 'bua' 'lfp' 'spike_times'
  * Contact_grp                 (Contact_grp) object MultiIndex
  * Species                     (Contact_grp) object 'Human' 'Human' 'Human' 'Human' 'Monkey' 'Monkey

# Drawing coherence phase

In [None]:
dataset

In [None]:

n_contacts: xr.DataArray = grouped_dataset[["n_contacts", "duration"]]
n_contacts = n_contacts.sel(sig_preprocessing_grp="bua", drop=True).isel(spectral_analysis_function=0).unstack().squeeze()
res = n_contacts.to_dataframe()[["n_contacts", "duration"]]
res["n_data"] = res["n_contacts"] * res["duration"]/1000
res = res.apply(np.round).dropna().astype(int)
res = res.sort_values("n_data")
res["n_data"]="~"+ res["n_data"].astype(str) + " 000"
res

# ["n_contacts"].iloc[:, 0:1].sort_values("n_contacts")

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,n_contacts,duration,n_data
Species,Structure,Healthy,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Human,STR,0,57,17,~1 000
Human,GPe,0,30,51,~2 000
Rat,STN,0,18,96,~2 000
Rat,STN,1,21,96,~2 000
Rat,STR,1,149,99,~15 000
Rat,GPe,1,199,99,~20 000
Rat,STR,0,314,100,~31 000
Human,STN,0,1538,21,~32 000
Human,STN,1,1958,17,~33 000
Rat,GPe,0,402,100,~40 000


In [None]:
n_contacts_pairs: xr.DataArray = grouped_dataset[["n_contact_pairs", "common_duration"]].unstack("sig_preprocessing_pair_grp")

n_contacts_pairs = n_contacts_pairs.sel(sig_type_1="bua", sig_type_2="bua", drop=True).squeeze()
n_contacts_pairs = n_contacts_pairs.isel(spectral_analysis_function=1)
res = n_contacts_pairs.rename(Species_1="Species", Healthy_1="Healthy").to_dataframe()[["n_contact_pairs", "common_duration"]]
res = res.reset_index(["Species_2", "Healthy_2"], drop=True).reset_index()
res = res[res["Structure_1"] == res["Structure_2"]]
res = res.set_index(["Species", "Healthy", "Structure_1", "Structure_2"])
res = res[res["n_contact_pairs"] >10]

res["n_data"] = res["n_contact_pairs"] * res["common_duration"]/1000
res = res.apply(np.round).dropna().astype(int)
res = res.sort_values("n_data")
res["n_data"]="~"+ res["n_data"].astype(str) + " 000"
res
# n_contacts_pairs = n_contacts_pairs[n_contacts_pairs["n_contact_pairs"]>10].reset_index()
# n_contacts_pairs["Species"] = n_contacts_pairs["Species_1"]
# n_contacts_pairs["Healthy"] = n_contacts_pairs["Healthy_1"]
# n_contacts_pairs = n_contacts_pairs.drop(columns=["Species_1", "Species_2", "Healthy_1", "Healthy_2"])
# n_contacts_pairs = n_contacts_pairs.set_index(["Species",	"Healthy", "Structure_1", "Structure_2"]).sort_values("n_contact_pairs")
# n_contacts_pairs.to_dataframe()["n_contact_pairs"].iloc[:, 0:1]


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,n_contact_pairs,common_duration,n_data
Species,Healthy,Structure_1,Structure_2,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
Human,0.0,STN,STN,72,26,~2 000
Rat,0.0,STN,STN,26,95,~2 000
Rat,1.0,STN,STN,35,97,~3 000
Monkey,0.0,STN,STN,14,209,~3 000
Monkey,1.0,STN,STN,18,640,~12 000
Monkey,0.0,STR,STR,52,337,~18 000
Monkey,0.0,GPe,GPe,78,407,~32 000
Monkey,1.0,STR,STR,53,611,~32 000
Rat,1.0,STR,STR,698,100,~69 000
Rat,1.0,GPe,GPe,1286,99,~128 000


In [None]:
data = grouped_dataset["pwelch"].where(grouped_dataset["n_contacts"] > 0).sel(spectral_analysis_function="scipy.spectrogram").unstack().sel(sig_type="bua").rename(f_interp="f")
toolbox.FigurePlot(
    data = data.to_dataframe(name="pwelch").dropna(axis="index", subset="pwelch"),
    figures="spectral_analysis_function", col="Structure", row="sig_type", sharey=False, margin_titles=True, fig_title="mean_pwelch_structure_{spectral_analysis_function}",
).map(sns.lineplot, x="f", y="pwelch", hue="Species", style="Healthy", hue_order=species_order, style_order=condition_order).add_legend()
plt.show()

In [None]:
data = grouped_dataset["pwelch"].where(grouped_dataset["n_contacts"] > 0).sel(spectral_analysis_function="scipy.spectrogram").unstack().sel(sig_type="bua")
toolbox.FigurePlot(
    data = data.to_dataframe(name="pwelch").dropna(axis="index", subset="pwelch"),
    figures="spectral_analysis_function", col="Species", row="sig_type", sharey=False, margin_titles=True, fig_title="mean_pwelch_species_{spectral_analysis_function}",
).map(sns.lineplot, x="f_interp", y="pwelch", hue="Structure", style="Healthy", 
      hue_order=structure_order, style_order=condition_order, palette=sns.color_palette("dark", n_colors=len(structure_order))).add_legend()
plt.show()

In [None]:
data = grouped_dataset[["coherence_norm", "coherence_phase", "coherence_validity", "f_max"]].where(grouped_dataset["n_contact_pairs"] > 10)
data = data.sel(spectral_analysis_function="scipy.coherence").sel(sig_type_1="bua").sel(sig_type_2="bua").rename(f_interp="f")
data = data.where(data["Structure_1"]==data["Structure_2"]).rename(Structure_1="Structure").rename({"Species(common)":"Species", "Healthy(common)":"Healthy"})
data = data.where(data["Healthy"]==0)
data["f_max_amp"] = data["coherence_norm"].sel(f= data["f_max"].fillna(4))
(toolbox.FigurePlot(
    data = data.to_dataframe().dropna(axis="index", subset="coherence_norm"),
    figures=["spectral_analysis_function", "sig_type_1", "sig_type_2"], col="Structure",
    sharey=False, margin_titles=True, fig_title="mean_coherence_structure_{spectral_analysis_function}, {sig_type_1}, {sig_type_2}", 
    col_order=structure_order,
).map(sns.scatterplot, x="f_max", y="f_max_amp", c="red").map(sns.lineplot, x="f", y="coherence_norm", hue="Species", hue_order=species_order)

).add_legend()
plt.show()

In [None]:
data = grouped_dataset[["coherence_norm", "coherence_phase", "coherence_validity", "f_max"]].where(grouped_dataset["n_contact_pairs"] > 10)
data = data.sel(spectral_analysis_function="scipy.coherence").sel(sig_type_1="bua").sel(sig_type_2="bua").rename(f_interp="f")
data = data.where(data["Structure_1"]!=data["Structure_2"]).rename({"Species(common)":"Species", "Healthy(common)":"Healthy"})
data = data.where(data["Healthy"]==0)
data["f_max_amp"] = data["coherence_norm"].sel(f= data["f_max"].fillna(4))
data["Structure"] = data["Structure_1"].astype(str).astype(object) +", " +  data["Structure_2"].astype(str).astype(object)
(toolbox.FigurePlot(
    data = data.to_dataframe().dropna(axis="index", subset="coherence_norm"),
    figures=["spectral_analysis_function", "sig_type_1", "sig_type_2"],
    sharey=False, margin_titles=True, fig_title="mean_coherence_structure_{spectral_analysis_function}, {sig_type_1}, {sig_type_2}", 
).map(sns.scatterplot, x="f_max", y="f_max_amp", c="red").map(sns.lineplot, x="f", y="coherence_norm", hue="Structure")

).add_legend()
plt.show()

In [None]:
grouped_dataset["coherence_phase_bins"] = grouped_dataset["coherence_phase_bins"]/grouped_dataset["coherence_phase_bins"].max("phase_bins")
data = (
    grouped_dataset["coherence_phase_bins"]
# .where(grouped_dataset["sig_type_1"]=="bua")

    .where(grouped_dataset["n_contact_pairs"] > 10)
    .where(grouped_dataset["Healthy(common)"] <0.5)
)
data = data.sel(spectral_analysis_function="scipy.coherence").sel(sig_type_1="bua").sel(sig_type_2="bua")
data = data.where(data["Structure_1"]!=data["Structure_2"]).rename({"Species(common)":"Species", "Healthy(common)":"Healthy"})
data["Structure"] = data["Structure_1"].astype(str).astype(object) +", " +  data["Structure_2"].astype(str).astype(object)
data = data.to_dataframe().dropna(axis="index", subset="coherence_phase_bins")
(toolbox.FigurePlot(
    data = data,
    figures=["spectral_analysis_function", "sig_type_1", "sig_type_2"],
    sharey=False, margin_titles=True, fig_title="coherence_phase_bins_structure_{spectral_analysis_function}, {sig_type_1}, {sig_type_2}", 
    subplot_kws=dict(projection='polar'), despine=False,
).map(sns.lineplot, x="phase_bins", y="coherence_phase_bins", hue="Structure")
.add_legend().set(yticks=[]))
# .maximize().save_pdf(f"{cache_path}Figures/coherence_phase_bins_structure.pdf")
plt.show()

In [None]:
grouped_dataset["coherence_phase_bins"] = grouped_dataset["coherence_phase_bins"]/grouped_dataset["coherence_phase_bins"].max("phase_bins")
data = (
    grouped_dataset["coherence_phase_bins"]
# .where(grouped_dataset["sig_type_1"]=="bua")

    .where(grouped_dataset["n_contact_pairs"] > 10)
    .where(grouped_dataset["Healthy(common)"] <0.5)
)
data = data.sel(spectral_analysis_function="scipy.coherence").sel(sig_type_1="bua").sel(sig_type_2="bua")
data = data.where(data["Structure_1"]==data["Structure_2"]).rename(Structure_1="Structure").rename({"Species(common)":"Species", "Healthy(common)":"Healthy"})
print(data)
data = data.to_dataframe().dropna(axis="index", subset="coherence_phase_bins")
print(data)
(toolbox.FigurePlot(
    data = data,
    figures=["spectral_analysis_function", "sig_type_1", "sig_type_2"], col="Structure",
    sharey=False, margin_titles=True, fig_title="coherence_phase_bins_structure_{spectral_analysis_function}, {sig_type_1}, {sig_type_2}", 
    col_order=structure_order,
    subplot_kws=dict(projection='polar'), despine=False,
).map(sns.lineplot, x="phase_bins", y="coherence_phase_bins", hue="Species", style="Healthy", hue_order=species_order, style_order=condition_order)
.add_legend().set(yticks=[]))
# .maximize().save_pdf(f"{cache_path}Figures/coherence_phase_bins_structure.pdf")
plt.show()

<xarray.DataArray 'coherence_phase_bins' (Contact_pair_grp: 35, phase_bins: 180)>
array([[0.03800728, 0.03784132, 0.03740026, ..., 0.03819858, 0.03812641,
        0.03800728],
       [       nan,        nan,        nan, ...,        nan,        nan,
               nan],
       [       nan,        nan,        nan, ...,        nan,        nan,
               nan],
       ...,
       [       nan,        nan,        nan, ...,        nan,        nan,
               nan],
       [       nan,        nan,        nan, ...,        nan,        nan,
               nan],
       [       nan,        nan,        nan, ...,        nan,        nan,
               nan]])
Coordinates:
    spectral_analysis_function  <U15 'scipy.coherence'
    sig_type_2                  <U3 'bua'
  * Contact_pair_grp            (Contact_pair_grp) object MultiIndex
  * Species_1                   (Contact_pair_grp) object 'Human' 'Human' 'Human' 'Human' 'Human' 'Human' ... 'Rat' 'Rat' 'Rat' 'Rat' 'Rat' 'Rat'
  * Structure   