# Preprocessing

## Defining parameters

In [1]:
import pandas as pd, numpy as np, functools, scipy, xarray as xr
import toolbox, tqdm, pathlib, concurrent, tqdm.notebook
from typing import List, Tuple, Dict, Any, Literal
import matplotlib.pyplot as plt
import matplotlib as mpl, seaborn as sns
from scipy.stats import gaussian_kde
import math, itertools, pickle, logging, beautifullogger
from xarray_helper import apply_file_func, auto_remove_dim, nunique, apply_file_func_decorator, extract_unique, mk_bins
import scipy.signal
from matplotlib.backends.backend_pdf import PdfPages
from time import sleep
from  ipydatagrid import DataGrid

# xarray options: enable flox and make reprs wide and fully expanded.
xr.set_options(
    use_flox=True,
    display_expand_coords=True,
    display_max_rows=100,
    display_expand_data_vars=True,
    display_width=150,
)

# Logging: module-level logger, INFO display level via beautifullogger,
# and the "flox" logger restricted to warnings and above.
logger = logging.getLogger(__name__)
beautifullogger.setup(displayLevel=logging.INFO)
logging.getLogger("flox").setLevel(logging.WARNING)

# Register tqdm's pandas integration with a default description.
tqdm.tqdm.pandas(desc="Computing")

MODE: Literal["TEST", "ALL", "SMALL", "BALANCED"]="ALL"
FIGS = ["coherence_phase_figs_structure"] #["pwelch_mean_figs_structure", "pwelch_mean_figs_species", "coherence_mean_figs_structure"]
DISPLAY=True

match MODE:
    case "TEST":
        cache_path = "/media/julien/data1/JulienCache/Test/"
    case "ALL" | "BALANCED":
        cache_path = f"/media/julien/data1/JulienCache/{'All' if MODE=='ALL' else 'Balanced'}/"
    case "SMALL":
        cache_path = "/media/julien/data1/JulienCache/Small/"


group_cols = ["Species", "Structure", "Healthy"]
pair_group_cols = [x+"_1" for x in group_cols] + [x+"_2" for x in group_cols]
species_order = ["Rat", "Monkey", "Human"]
structure_order = ["GPe", "STN", "STR"]
condition_order = [0, 1]
sig_type_order=["bua", "lfp", "spike_times"]

inotify_add_watch(/home/julien/.config/ibus/bus/99bd62d910774e6fb8ee433829d5a5b8-unix-1) failed: (No space left on device)


## Loading dataset

In [2]:
# NOTE(review): pickle executes arbitrary code while loading; only open cache files
# that were produced by this project.
# The context managers close the file handles as soon as loading is done.
with open(cache_path+"signals_computed.pkl", "rb") as signals_file:
    signals: xr.Dataset = pickle.load(signals_file)
with open(cache_path+"signal_pairs_computed.pkl", "rb") as signal_pairs_file:
    signal_pairs: xr.Dataset = pickle.load(signal_pairs_file)
# Keep only the pairs in which neither member has FullStructure "STN_VMNR".
signal_pairs = signal_pairs.where((signal_pairs["FullStructure_1"] != "STN_VMNR") & (signal_pairs["FullStructure_2"] != "STN_VMNR"), drop=True)
signals

## Basic preprocessing

In [3]:
# Put the per-signal and per-pair results in a single dataset.
dataset = xr.merge([signals, signal_pairs])

# Keep the variables whose name mentions one of these keywords.
kept_keywords = ("pwelch", "coherence", "duration")
dataset = dataset[[name for name in dataset.variables if any(keyword in name for keyword in kept_keywords)]]

# Drop everything defined on "f", then let "f2" take over the name "f".
dataset = dataset.drop_dims("f").rename(f2="f")

# Disabled alternative: interpolate onto a common frequency grid.
# scipy_freq_coords = dataset["f2"].to_numpy()
# dataset = dataset.interp(f=np.linspace(3, 50, 94, endpoint=False), f2=np.linspace(3, 50, 94, endpoint=False))
# for col in dataset.variables:
#     if "f2" in dataset[col].dims:
#         dataset[col] = dataset[col].rename(f2="f_interp")
#     if "f" in dataset[col].dims:
#         dataset[col] = dataset[col].rename(f="f_interp")
# dataset = dataset.drop(["f", "f2"])
dataset

In [4]:
# Expose the chosen estimators under the short names "pwelch" and "coherence",
# turn the durations into coordinates, and keep only the two short-named variables.
dataset = (
    dataset
    .assign(pwelch=dataset["pwelch_spectrogram"], coherence=dataset["coherence_scipy"])
    .set_coords(["duration", "common_duration"])
)[["pwelch", "coherence"]]
dataset

## Averaging along sig_type, sig_type_pair

In [5]:
from typing import Any

class stupid:
    """Proxy around a groupby object that merges results with a set-aside dataset.

    Attribute access is forwarded to ``grp``. Callable attributes are wrapped so
    that their return value is merged (``xr.merge``) with ``dataset``;
    non-callable attributes are returned unchanged.
    """

    def __init__(self, grp, dataset):
        self.grp = grp          # object the attribute lookups are forwarded to
        self.dataset = dataset  # merged into the result of every forwarded call

    def __getattr__(self, name: str) -> Any:
        # __getattr__ only runs when normal lookup fails. If `grp`/`dataset`
        # themselves are missing (instance created without __init__, e.g. by
        # copy or unpickling), forwarding would call __getattr__ again forever.
        if name in ("grp", "dataset"):
            raise AttributeError(name)
        attr = getattr(self.grp, name)
        if not callable(attr):
            # Plain attributes cannot be called and merged; hand them back as-is.
            return attr

        @functools.wraps(attr)
        def new_f(*args, **kwargs):
            ret = attr(*args, **kwargs)
            return xr.merge([ret, self.dataset])
        return new_f

def nicegroupby(self: xr.Dataset, val, *args, **kwargs):
    """Group only the data variables that span every dimension of ``self[val]``.

    The remaining content (``self`` with the dimensions of ``self[val]``
    dropped) is handed to the ``stupid`` proxy, which merges it back into the
    result of each forwarded groupby call. Extra positional and keyword
    arguments are passed to ``groupby``.
    """
    key_dims = set(self[val].dims)
    groupable = [name for name in self.data_vars if key_dims.issubset(set(self[name].dims))]
    grouped = self[groupable].groupby(val, *args, **kwargs)
    leftover = self.drop_dims(self[val].dims)
    return stupid(grouped, leftover)


# Make the helper available as a method on every xarray Dataset.
xr.Dataset.nicegroupby = nicegroupby


In [6]:

# Group along the "sig_type" coordinate and, inside each group, average over the
# "sig_preprocessing" dimension. nicegroupby merges the variables that do not
# carry that dimension back into the result.
tmp = dataset.nicegroupby("sig_type").map(lambda x: x.mean("sig_preprocessing"))
# tmp = tmp.set_xindex(["sig_type_1", "sig_type_2"])
# Build a (sig_type_1, sig_type_2) MultiIndex along "sig_preprocessing_pair" so that
# pairs can be grouped by their combination of signal types. The level values are
# read from `dataset` (the state before the averaging above), not from `tmp`.
tmp["sig_type_pair"] = xr.DataArray(pd.MultiIndex.from_arrays(
        [dataset["sig_type_1"].data, dataset["sig_type_2"].data],
        names=["sig_type_1", "sig_type_2"]), 
    dims=["sig_preprocessing_pair"])
tmp = tmp.set_coords("sig_type_pair")
# Average all entries sharing the same (sig_type_1, sig_type_2) combination.
tmp = tmp.nicegroupby("sig_type_pair").mean()
# Re-create the two levels as plain coordinates along "sig_type_pair" and index by them.
tmp["sig_type_1"] =("sig_type_pair", tmp.get_index("sig_type_pair").get_level_values("sig_type_1").to_numpy())
tmp["sig_type_2"] =("sig_type_pair", tmp.get_index("sig_type_pair").get_level_values("sig_type_2").to_numpy())
tmp = tmp.set_coords(["sig_type_1", "sig_type_2"])
tmp = tmp.set_xindex(["sig_type_1", "sig_type_2"])
# From here on `dataset` is the averaged version.
dataset = tmp
tmp

## Fitting curves

In [7]:
def fit(a: xr.DataArray):
    """Return ``a`` with and without a fitted ``a/f + b`` background removed.

    The magnitude ``|a|`` is fitted on the coordinate ranges ``f`` in [4, 9] and
    [34, 37] by a linear regression of ``|a| * f`` against ``f`` (equivalent to
    modelling ``|a|`` as ``slope + intercept / f``). The fitted curve is then
    evaluated on every ``f`` and ``a`` is rescaled so that its magnitude becomes
    ``|a| - fitted`` while its sign/phase is kept.

    Parameters
    ----------
    a : xr.DataArray
        Array with an ``f`` dimension and coordinate; may be complex.

    Returns
    -------
    xr.DataArray
        ``a`` extended by a ``fit_method`` dimension with the entries
        ``"a/f + b"`` (background removed) and ``"nofit"`` (``a`` unchanged).
        Slices that are entirely NaN in the fit ranges yield NaN for ``"a/f + b"``.
    """
    import sklearn.linear_model

    norm = np.abs(a)
    fit_part = xr.concat([norm.sel(f=slice(4, 9)), norm.sel(f=slice(34, 37))], dim="f")

    def fit_model(arr, f):
        # No data at all in the fit ranges: mark the slice as "no model".
        if np.isnan(arr).all():
            return np.nan
        model = sklearn.linear_model.LinearRegression()
        # Regressing arr*f on f turns the a/f + b shape into a straight line.
        return model.fit(f.reshape(-1, 1), arr * f)

    def predict(model, f):
        if pd.isna(model):
            return np.full(f.shape, np.nan)
        # Undo the multiplication by f applied before fitting.
        return model.predict(f.reshape(-1, 1)) / f

    # output_dtypes is required: without it np.vectorize infers the dtype from the
    # first call, so an all-NaN first slice would make the output float and the
    # fitted model objects of later slices could not be stored.
    model = xr.apply_ufunc(
        fit_model, fit_part, fit_part["f"],
        input_core_dims=[["f"]]*2, vectorize=True, output_dtypes=[object],
    )
    fit_curve = xr.apply_ufunc(
        predict, model, norm["f"],
        input_core_dims=[[], ["f"]], output_core_dims=[["f"]], vectorize=True, output_dtypes=[float],
    )

    res = xr.Dataset()
    res["a/f + b"] = a* (norm-fit_curve)/norm
    res["nofit"] = a
    return res.to_array(dim="fit_method")

# Replace every data variable by its fit() result. The names are snapshotted
# first because the dataset is modified inside the loop.
for v in list(dataset.data_vars):
    dataset[v] = fit(dataset[v])

dataset

## Creating Contact Group

In [8]:
def _grouping_coord(ds, levels, dim):
    """Build a MultiIndex-valued DataArray along ``dim``.

    ``levels`` maps each MultiIndex level name to the variable of ``ds`` that
    provides its values; the result carries ``ds[dim]`` as its coordinate.
    """
    index = pd.MultiIndex.from_arrays(
        [ds[source].data for source in levels.values()],
        names=list(levels),
    )
    return xr.DataArray(index, dims=["" + dim], coords=[ds[dim]])


# One group per (Species, Structure, Healthy) combination of a contact.
dataset["Contact_grp"] = _grouping_coord(
    dataset,
    {"Species": "Species", "Structure": "Structure", "Healthy": "Healthy"},
    "Contact",
)
dataset = dataset.set_coords("Contact_grp")

# Both members of a pair must share the species, since only Species_1 is used below.
if not (dataset["Species_1"] == dataset["Species_2"]).all():
    raise Exception("Strange")
dataset["Contact_pair_grp"] = _grouping_coord(
    dataset,
    {
        "Species": "Species_1",
        "Structure_1": "Structure_1",
        "Structure_2": "Structure_2",
        "Healthy_1": "Healthy_1",
        "Healthy_2": "Healthy_2",
    },
    "Contact_pair",
)
dataset = dataset.set_coords("Contact_pair_grp")
dataset

# Functions

In [9]:
# Registry: analysis name -> {"on": source variable name, "funcs": {stage: callable}}.
functions = {}

def register_func(name: str, on: str, t: Literal["preprocess", "compute"]):
    """Return a decorator that registers a function in ``functions``.

    Parameters
    ----------
    name : str
        Name of the analysis the function belongs to.
    on : str
        Name of the variable the analysis is applied to. It is stored the first
        time ``name`` is registered.
    t : {"preprocess", "compute"}
        Stage the decorated function implements.

    Returns
    -------
    callable
        Decorator that stores the function under ``functions[name]["funcs"][t]``
        and returns it unchanged.

    Raises
    ------
    ValueError
        If ``t`` is not a known stage, or if ``name`` is already registered on a
        different variable.

    Warns
    -----
    UserWarning
        If the slot already holds a function with a different ``__name__``
        (re-running a cell re-registers the same name and stays silent).
    """
    import warnings

    if t not in ("preprocess", "compute"):
        raise ValueError(f"Unknown stage {t!r} for {name!r}; expected 'preprocess' or 'compute'")
    if name not in functions:
        functions[name] = dict(on=on, funcs={})
    elif functions[name]["on"] != on:
        # Previously the new `on` was silently ignored.
        raise ValueError(
            f"{name!r} is already registered on {functions[name]['on']!r}, cannot register it on {on!r}"
        )

    def decorator(f):
        previous = functions[name]["funcs"].get(t)
        if previous is not None and getattr(previous, "__name__", None) != getattr(f, "__name__", None):
            warnings.warn(
                f"{name!r}/{t!r}: {getattr(f, '__name__', f)!r} replaces "
                f"{getattr(previous, '__name__', previous)!r}",
                stacklevel=2,
            )
        functions[name]["funcs"][t] = f
        return f
    return decorator

In [10]:
@register_func("pwelch", "pwelch", "compute")
def pwelch(a: xr.DataArray):
    """Average ``a`` over the "Contact" dimension."""
    contact_mean = a.mean(dim="Contact")
    return contact_mean

# Frequency bands: consecutive edges, so each band ends where the next one starts.
band_edges = [4, 8, 15, 30, 49]
band_names = ["before_beta", "low_beta", "high_beta", "after_beta"]

bands = xr.Dataset()
bands["band_start"] = xr.DataArray(band_edges[:-1], dims="band")
bands["band_end"] = xr.DataArray(band_edges[1:], dims="band")
bands["band"] = ("band", band_names)
# Start and end are coordinates of the "band" dimension rather than data.
bands = bands.set_coords(["band_start", "band_end"])

@register_func("pwelch_band", "pwelch", "preprocess")
def pwelch_band_preprocess(a: xr.DataArray):
    """Average ``a`` over ``f`` within each band of ``bands`` (both edges included)."""
    freqs = a["f"]
    in_band = (freqs >= bands["band_start"]) & (freqs <= bands["band_end"])
    # Values outside a band are masked before averaging over f.
    return a.where(in_band).mean("f")

@register_func("pwelch_band", "pwelch", "compute")
def pwelch_band(a: xr.DataArray):
    """Average the per-band values over the "Contact" dimension."""
    contact_mean = a.mean(dim="Contact")
    return contact_mean

@register_func("pwelch_max_f", "pwelch", "compute")
def pwelch_max_f(a: xr.DataArray):
    """Return the ``f`` in [7, 34] at which the "Contact"-averaged ``a`` is largest."""
    contact_mean = a.mean(dim="Contact")
    search_window = contact_mean.sel(f=slice(7, 34))
    return search_window.idxmax("f")


In [11]:
@register_func("coherence_norm", "coherence", "preprocess")
def coherence_norm_preprocess(a: xr.DataArray):
    """Take the element-wise absolute value of ``a``."""
    magnitude = np.abs(a)
    return magnitude

@register_func("coherence_norm", "coherence", "compute")
def coherence_norm(a: xr.DataArray):
    """Average ``a`` over the "Contact_pair" dimension."""
    pair_mean = a.mean(dim="Contact_pair")
    return pair_mean

@register_func("coherence_norm_band", "coherence", "preprocess")
def coherence_norm_band_preprocess(a: xr.DataArray):
    """Average ``|a|`` over ``f`` within each band of ``bands`` (both edges included)."""
    freqs = a["f"]
    in_band = (freqs >= bands["band_start"]) & (freqs <= bands["band_end"])
    # Mask first, take the magnitude, then average over f.
    masked = a.where(in_band)
    return np.abs(masked).mean("f")

@register_func("coherence_norm_band", "coherence", "compute")
def coherence_norm_band(a: xr.DataArray):
    """Average the per-band values over the "Contact_pair" dimension."""
    pair_mean = a.mean(dim="Contact_pair")
    return pair_mean
    

# The magnitude step must be registered as "preprocess". Registered as "compute"
# it was overwritten by the function below, so np.abs was never applied before
# the search for the maximum.
@register_func("coherence_norm_max_f", "coherence", "preprocess")
def coherence_norm_max_f_preprocess(a: xr.DataArray):
    """Take the element-wise absolute value of ``a``."""
    return np.abs(a)

@register_func("coherence_norm_max_f", "coherence", "compute")
def coherence_norm_max_f(a: xr.DataArray):
    """Return the ``f`` in [7, 34] at which the "Contact_pair"-averaged ``a`` is largest."""
    return a.mean(dim="Contact_pair").sel(f=slice(7, 34)).idxmax("f")

# Computations

## Bootstrapping library

In [12]:
def bootstrap(f, arr: xr.DataArray, sample_dim, n_resamples=10000,  bt_dist="bt_dist", vectorize=False, executor_ctr=None, progress=(), seed=None):
    """Apply ``f`` to resamples (with replacement) of ``arr`` along ``sample_dim``.

    Parameters
    ----------
    f : callable
        Statistic to bootstrap. It receives a resampled array and its results
        are concatenated along ``bt_dist``.
    arr : xr.DataArray
        Data to resample.
    sample_dim : str
        Dimension of ``arr`` whose entries are drawn with replacement.
    n_resamples : int
        Total number of resamples.
    bt_dist : str
        Name of the dimension indexing the resamples in the result.
    vectorize : bool or int
        Number of resamples handed to ``f`` per call. ``True`` means all at
        once, ``False`` means one at a time; in the latter case ``f`` receives
        an array without the ``bt_dist`` dimension.
    executor_ctr : callable, optional
        Zero-argument factory returning a ``concurrent.futures`` executor usable
        as a context manager. When given, the calls to ``f`` are submitted to it.
    progress : collection of str
        Progress bars to show: ``"Submit"`` and/or ``"Compute"``.
    seed : optional
        Seed for ``numpy.random.default_rng``. ``None`` keeps drawing from the
        global ``numpy.random`` state.

    Returns
    -------
    xr.DataArray
        Results of ``f`` concatenated along ``bt_dist``. If ``arr`` is entirely
        null, ``f`` is not called and a NaN array with ``sample_dim`` replaced
        by ``bt_dist`` (length ``n_resamples``) is returned.

    Raises
    ------
    ValueError
        If ``n_resamples`` or the resolved chunk size is smaller than 1.
    """
    # `import concurrent` alone does not guarantee the `futures` submodule is loaded.
    import concurrent.futures
    import tqdm.notebook

    if n_resamples < 1:
        raise ValueError(f"n_resamples must be at least 1, got {n_resamples}")
    if arr.isnull().all():
        # Use the requested output dimension name, not a hard-coded "bt_dist".
        return xr.apply_ufunc(
            lambda a: np.full(a.shape[:-1] + (n_resamples,), np.nan),
            arr, input_core_dims=[[sample_dim]], output_core_dims=[[bt_dist]],
        )
    if vectorize is True:
        vectorize=n_resamples
    if vectorize is False:
        vectorize=1
    if vectorize < 1:
        raise ValueError(f"vectorize must be a bool or a positive integer, got {vectorize}")
    n = arr.sizes[sample_dim]
    rng = np.random if seed is None else np.random.default_rng(seed)

    def mk_sample(start, end):
        """Draw resamples number ``start`` .. ``end - 1`` (capped at ``n_resamples``)."""
        samples = np.arange(start, min(end, n_resamples))
        choices = rng.choice(np.arange(n), len(samples) * n, replace=True).reshape((n, len(samples)))
        choices = xr.DataArray(data=choices, dims=["sample", bt_dist], coords={"sample":np.arange(n), bt_dist:samples})
        # The original coordinate of sample_dim is meaningless after resampling.
        picked = arr.isel({sample_dim:choices}).drop_vars([sample_dim], errors="ignore")
        return picked.rename(sample=sample_dim)

    def iter_samples(desc, disabled):
        """Yield the successive chunks of resamples, with an optional progress bar."""
        for start in tqdm.notebook.tqdm(range(0, n_resamples, vectorize), desc=desc, disable=disabled):
            sample = mk_sample(start, start + vectorize)
            if vectorize == 1:
                # Drop the length-1 resample dimension. The previous
                # `isel(sample_dim=0)` indexed a dimension literally called "sample_dim".
                sample = sample.isel({bt_dist: 0})
            yield sample

    if executor_ctr is not None:
        all_res = []
        with executor_ctr() as executor:
            # NOTE(review): every resample is materialised before any result is
            # collected; for large inputs this may need a lot of memory.
            futures = [executor.submit(f, sample) for sample in iter_samples("Submitting", "Submit" not in progress)]
            for future in tqdm.notebook.tqdm(
                concurrent.futures.as_completed(futures),
                total=len(futures), desc="computing", disable="Compute" not in progress,
            ):
                all_res.append(future.result())
    else:
        all_res = [f(sample) for sample in iter_samples("Computing", "Compute" not in progress)]
    return xr.concat(all_res, dim=bt_dist)
    
def progress_map(self, f, *args, desc=None, **kwargs):
    """Call ``self.map(f, ...)`` while showing one progress-bar step per call of ``f``.

    The bar total is ``len(self.groups)``. Extra positional and keyword
    arguments are forwarded to ``map``, whose result is returned.
    """
    import tqdm.notebook

    # The context manager closes the bar even when `f` raises.
    with tqdm.notebook.tqdm(total=len(self.groups), desc=desc) as bar:
        def new_f(*a, **kw):
            bar.update(1)
            return f(*a, **kw)
        return self.map(new_f, *args, **kwargs)

# Make progress_map available on DataArray groupby objects.
xr.core.groupby.DataArrayGroupBy.progress_map = progress_map
    
def progress_map(self, f, *args, desc=None, **kwargs):
    import tqdm, tqdm.notebook
    bar = tqdm.notebook.tqdm(total=len(self.groups), desc=desc)
    def new_f(*a, **kw):
        bar.update(1)
        return f(*a, **kw)
    res = self.map(new_f, *args, **kwargs)
    bar.close()
    return res

xr.core.groupby.DataArrayGroupBy.progress_map = progress_map




## Dataset selection

In [13]:

# Start from the full dataset (disabled restriction: .sel(sig_type=["bua"])).
analysis_dataset = dataset
# Set the coherence aside, drop everything defined on "sig_type_pair", then put
# the coherence back.
tmp = analysis_dataset["coherence"]
# Disabled restriction of the coherence to bua/bua pairs:
# .where(
#     analysis_dataset["sig_type_1"].isin(["bua"]) & analysis_dataset["sig_type_2"].isin(["bua"]), drop=True)
analysis_dataset = analysis_dataset.drop_dims("sig_type_pair").assign(coherence=tmp)
analysis_dataset


## Computing averages

In [14]:
# Apply every registered analysis to each contact group (no resampling).
averages_dataset = xr.Dataset()
for fname, f in functions.items():
    a = analysis_dataset[f["on"]]
    # Choose the grouping coordinate from the dimension the variable lives on;
    # "Contact" is tested before "Contact_pair".
    for dim_name, grp in (("Contact", "Contact_grp"), ("Contact_pair", "Contact_pair_grp")):
        if dim_name in a.dims:
            break
    else:
        raise Exception("Strange")
    # The preprocessing stage is optional.
    if "preprocess" in f["funcs"]:
        preprocess = f["funcs"]["preprocess"]
        a = preprocess(a)
    try:
        grouped_result = a.groupby(grp).map(f["funcs"]["compute"])
        averages_dataset[f"{fname}_avg"] = grouped_result.unstack()
    except Exception as e:
        # Tell which analysis failed, then propagate the same exception.
        e.add_note(f"During computation of {fname}")
        raise

averages_dataset

## Computing bootstrap

In [15]:

# Container for the bootstrap distributions, and the number of resamples per group.
bootstrap_dataset = xr.Dataset()
n_bootstrap = 10_000

In [16]:



# Bootstrap every registered analysis within each contact group.
for fname, f in tqdm.notebook.tqdm(functions.items(), desc="Computing functions"):
    a = analysis_dataset[f["on"]]
    # Grouping coordinate, resampled dimension and selection name, chosen from
    # the dimension the variable lives on.
    if "Contact" in a.dims:
        grp, dim, sel = "Contact_grp", "Contact", "Contact_bootstrap_sel"
    elif "Contact_pair" in a.dims:
        grp, dim, sel = "Contact_pair_grp", "Contact_pair", "Contact_pair_bootstrap_sel"
    else:
        raise Exception("Strange")
    if "preprocess" in f["funcs"]:
        a = f["funcs"]["preprocess"](a)

    def resample_group(group):
        """Bootstrap the compute stage of the current analysis on one group."""
        return bootstrap(
            f["funcs"]["compute"],
            group,
            sample_dim=dim,
            n_resamples=n_bootstrap,
            vectorize=200,
            executor_ctr=lambda: concurrent.futures.ThreadPoolExecutor(10),
        )

    per_group = a.groupby(grp).progress_map(resample_group, desc=f"Computing {fname}")
    bootstrap_dataset[f"{fname}_dist"] = per_group.unstack()

bootstrap_dataset

Computing functions:   0%|          | 0/6 [00:00<?, ?it/s]

Computing pwelch:   0%|          | 0/16 [00:00<?, ?it/s]

  return np.nanmean(a, axis=axis, dtype=dtype)



Computing pwelch_band:   0%|          | 0/16 [00:00<?, ?it/s]

Computing pwelch_max_f:   0%|          | 0/16 [00:00<?, ?it/s]

Computing coherence_norm:   0%|          | 0/35 [00:00<?, ?it/s]

: 

## Computing confidence intervals

In [None]:
# Lower and upper quantiles of every bootstrap distribution.
ci_quantiles = [0.05, 0.95]
confidence_intervals = bootstrap_dataset.quantile(ci_quantiles, dim=["bt_dist"], skipna=True)

# Name the results "..._quantile" instead of "..._dist".
quantile_names = {}
for k in confidence_intervals.data_vars:
    quantile_names[k] = k.replace("dist", "quantile")
confidence_intervals = confidence_intervals.rename(quantile_names)
confidence_intervals

## Merging all results

In [None]:
# Gather averages, bootstrap distributions and confidence intervals, then cache them.
all_results = xr.merge([averages_dataset, bootstrap_dataset, confidence_intervals])
# The context manager flushes and closes the file even if pickling fails part-way.
with open(cache_path + "results_for_plotting.pkl", "wb") as results_file:
    pickle.dump(all_results, results_file)

# Plotting

# PValues

## Pvalue function

In [None]:
def get_p_value(a):
    """Summarise bootstrap distributions of a difference by direction and p-value.

    Along the "bt_dist" dimension of ``a``, the strictly positive and strictly
    negative values are counted. The direction is "+" when positives outnumber
    negatives and "-" otherwise; the p-value is one minus the fraction of values
    on the winning side. A distribution containing any NaN yields NaN for both.

    Returns an ``xr.Dataset`` with the variables "dir" and "pvalue".
    """
    def pvalue(x):
        if np.isnan(x).any():
            return np.nan, np.nan
        n_positive = (x > 0).sum()
        n_negative = (x < 0).sum()
        # Ties between the two counts are reported as "-".
        if n_positive > n_negative:
            direction, winning = "+", n_positive
        else:
            direction, winning = "-", n_negative
        return direction, 1 - winning/x.size

    directions, pvalues = xr.apply_ufunc(
        pvalue,
        a,
        input_core_dims=[["bt_dist"]],
        output_core_dims=[[], []],
        output_dtypes=[object, float],
        vectorize=True,
    )
    ret = xr.Dataset()
    ret["dir"] = directions
    ret["pvalue"] = pvalues
    return ret

## Pvalue for healthy vs park

In [None]:
# Bootstrap distribution of the band-power difference Healthy=0 minus Healthy=1.
band_diff = all_results["pwelch_band_dist"].sel(Healthy=0) - all_results["pwelch_band_dist"].sel(Healthy=1)
# Average of the Healthy=0 groups, shown next to the p-values.
# drop_vars replaces the deprecated DataArray.drop.
band_avg = all_results["pwelch_band_avg"].sel(Healthy=0, drop=True).drop_vars(["band_start", "band_end"])
res = get_p_value(band_diff).to_dataframe().join(band_avg.to_dataframe())
widget = DataGrid(res.reset_index())
widget

DataGrid(auto_fit_params={'area': 'all', 'padding': 30, 'numCols': None}, corner_renderer=None, default_render…

## Pvalue for comparing f_max

In [None]:
# Bootstrap distribution of the peak-frequency difference Rat minus Monkey.
f_max_diff = all_results["pwelch_max_f_dist"].sel(Species="Rat") - all_results["pwelch_max_f_dist"].sel(Species="Monkey")
# Peak frequency of the Rat groups, shown next to the p-values.
f_max_rat = all_results["pwelch_max_f_avg"].sel(Species="Rat", drop=True)
res = get_p_value(f_max_diff).to_dataframe().join(f_max_rat.to_dataframe())
widget = DataGrid(res.reset_index())
widget

DataGrid(auto_fit_params={'area': 'all', 'padding': 30, 'numCols': None}, corner_renderer=None, default_render…

In [None]:
# Total number of elements in the coherence bootstrap distribution.
all_results["coherence_norm_dist"].size

311040000

# Old

### Pwelch

In [None]:
# def compute_mean(a):
#     return a.mean(dim="Contact")

# grp = analysis_dataset["pwelch"]

# bootstrap_dataset["pwelch_dist"] = grp.groupby("Contact_grp", squeeze=False).progress_map(
#     lambda a: bootstrap(compute_mean, a, sample_dim = "Contact", n_resamples=n_boostrap, vectorize=1000, executor_ctr=lambda: concurrent.futures.ThreadPoolExecutor(10)
#     )).unstack()

# bootstrap_dataset["pwelch"] = grp.groupby("Contact_grp", squeeze=False).map(
#     compute_mean).unstack()

# bootstrap_dataset

### Pwelch max f

In [None]:
# def compute_max_f(a):
#     return a.sel(f_interp=slice(8, 34)).mean(dim="Contact").idxmax("f_interp")

# grp = analysis_dataset["pwelch"]

# bootstrap_dataset["pwelch_max_f_dist"] = grp.groupby("Contact_grp", squeeze=False).progress_map(
#     lambda a: bootstrap(compute_max_f, a, sample_dim = "Contact", n_resamples=n_boostrap, vectorize=1000, executor_ctr=lambda: concurrent.futures.ThreadPoolExecutor(10)
#     )).unstack()

# bootstrap_dataset["pwelch_max_f"] = grp.groupby("Contact_grp", squeeze=False).map(
#     compute_max_f).unstack()

# bootstrap_dataset

In [None]:
# def compute_max_f(a):
#     averaged = a.mean(dim="Contact")
#     fit_part = xr.concat([averaged.sel(f=slice(4, 9)), averaged.sel(f=slice(34, 37))], dim="f")

#     def fit(arr, f): 
#         Y = arr*f
#         X = f
#         import sklearn, sklearn.linear_model, sklearn.compose
#         model = sklearn.linear_model.LinearRegression()
#         return model.fit(X.reshape(-1, 1), Y)
    
#     def predict(model, f):
#         X = f
#         Y = model.predict(X.reshape(-1, 1))
#         res = Y/f
#         return res
    
#     model = xr.apply_ufunc(fit, fit_part, fit_part["f"], input_core_dims=[["f"]]*2 + [[]], vectorize=True)
#     fitted = xr.apply_ufunc(predict, model, averaged["f"], input_core_dims=[[], ["f"]], output_core_dims=[["f"]], vectorize=True)

#     res = xr.Dataset()
#     res["fitted"] = (averaged-fitted).sel(f=slice(7, 34)).idxmax("f")
#     res["none"] = averaged.sel(f=slice(7, 34)).idxmax("f")

#     return res.to_array(dim="fit_method")

# grp = analysis_dataset["pwelch"]

# bootstrap_dataset["pwelch_max_f_dist"] = grp.groupby("Contact_grp", squeeze=False).progress_map(
#     lambda a: bootstrap(compute_max_f, a, sample_dim = "Contact", n_resamples=n_boostrap, vectorize=1000, executor_ctr=lambda: concurrent.futures.ThreadPoolExecutor(10)
#     )).unstack()

# bootstrap_dataset["pwelch_max_f"] = grp.groupby("Contact_grp", squeeze=False).map(
#     compute_max_f).unstack()

# bootstrap_dataset

### Pwelch band val

In [None]:
# def compute_mean_f(a):
#     return a.mean(dim="Contact")

# grp = analysis_dataset["pwelch_band"]

# bootstrap_dataset["pwelch_band_dist"] = grp.groupby("Contact_grp", squeeze=False).progress_map(
#     lambda a: bootstrap(compute_mean_f, a, sample_dim = "Contact", n_resamples=n_boostrap, vectorize=1000, executor_ctr=lambda: concurrent.futures.ThreadPoolExecutor(10)
#     )).unstack()

# bootstrap_dataset["pwelch_band"] = grp.groupby("Contact_grp", squeeze=False).map(
#     compute_mean_f).unstack()

# bootstrap_dataset

### Coherence norm

In [None]:
# def compute_mean(a):
#     return np.abs(a).mean(dim="Contact_pair")

# grp = analysis_dataset["coherence"]

# bootstrap_dataset["coherence_norm_dist"] = grp.groupby("Contact_pair_grp", squeeze=False).progress_map(
#     lambda a: bootstrap(compute_mean, a, sample_dim = "Contact_pair", n_resamples=n_boostrap, vectorize=200, executor_ctr=lambda: concurrent.futures.ThreadPoolExecutor(10)
#     )).unstack()


# bootstrap_dataset["coherence_norm"] = grp.groupby("Contact_pair_grp", squeeze=False).map(
#     compute_mean).unstack()

# bootstrap_dataset

### Coherence_phase

In [None]:
# def compute_mean(a):
#     return xr.apply_ufunc(np.angle, a).mean(dim="Contact_pair")

# grp = analysis_dataset["coherence"]

# bootstrap_dataset["coherence_phase_dist"] = grp.groupby("Contact_pair_grp", squeeze=False).progress_map(
#     lambda a: bootstrap(compute_mean, a, sample_dim = "Contact_pair", n_resamples=n_boostrap, vectorize=500, executor_ctr=lambda: concurrent.futures.ThreadPoolExecutor(10)
#     )).unstack()


# bootstrap_dataset["coherence_phase"] = grp.groupby("Contact_pair_grp", squeeze=False).map(
#     compute_mean).unstack()

# bootstrap_dataset

## Generating confidence intervals

In [None]:
# to_compute = bootstrap_dataset[[v for v in bootstrap_dataset.data_vars if not v.endswith("quantile")]]
# tmp = to_compute.quantile([0.05, 0.95], dim=["bt_dist"], skipna=True)
# tmp = tmp[[v for v in tmp.data_vars if "quantile" in tmp[v].dims and not v.endswith("quantile")]]
# tmp=tmp.rename({k:f"{k}_quantile" for k in tmp.data_vars})
# tmp
# bootstrap_dataset = xr.merge([tmp, to_compute])
# bootstrap_dataset

In [None]:
# bootstrap_dataset["pwelch_max_f2_dist_quantile"].to_dataset("quantile").to_dataframe().join(
#     bootstrap_dataset["pwelch_max_f2"].to_dataframe(), 
# ).unstack("method")

## Getting p_values

In [None]:
# get_p_value(
#     bootstrap_dataset["pwelch_band_dist"].sel(Healthy=0) - bootstrap_dataset["pwelch_band_dist"].sel(Healthy=1)
# ).to_dataframe().merge(
#     bootstrap_dataset["pwelch_band"].sel(Healthy=0, drop=True).drop(["band_start", "band_end"]).to_dataframe(), left_index=True, right_index=True,
# )

In [None]:
# get_p_value(
#     bootstrap_dataset["pwelch_max_f2_dist"].sel(Species="Rat", Healthy=0, method="fitted") - bootstrap_dataset["pwelch_max_f2_dist"].sel(Species="Monkey", Healthy=0, method="fitted")
# ).to_dataframe().merge(
#     bootstrap_dataset["pwelch_max_f2"].sel(Species="Rat", Healthy=0, method="fitted", drop=True).to_dataframe(), left_index=True, right_index=True,
# )