# Defining parameters

In [1]:
import pandas as pd, numpy as np, functools, scipy, xarray as xr
import toolbox, tqdm, pathlib
from typing import List, Tuple, Dict, Any, Literal
import matplotlib.pyplot as plt
import matplotlib as mpl, seaborn as sns
from scipy.stats import gaussian_kde
import math, itertools, pickle, logging, beautifullogger
from xarray_helper import apply_file_func, auto_remove_dim, nunique, apply_file_func_decorator, extract_unique, mk_bins
import scipy.signal
from matplotlib.backends.backend_pdf import PdfPages

# Notebook-wide xarray options: flox-backed groupby reductions and wide, fully expanded reprs.
xr.set_options(
    use_flox=True,
    display_width=150,
    display_max_rows=100,
    display_expand_coords=True,
    display_expand_data_vars=True,
)

# Logging: beautifullogger displays INFO and above; the "flox" logger is limited to warnings.
logger = logging.getLogger(__name__)
beautifullogger.setup(displayLevel=logging.INFO)
logging.getLogger("flox").setLevel(logging.WARNING)

# Enable tqdm's pandas integration, with progress bars labelled "Computing".
tqdm.tqdm.pandas(desc="Computing")

# Which cached dataset the notebook works on; every mode maps to one sub-directory of the cache root.
MODE: Literal["TEST", "ALL", "SMALL", "BALANCED"]="ALL"
FIGS = ["coherence_phase_figs_structure"] #["pwelch_mean_figs_structure", "pwelch_mean_figs_species", "coherence_mean_figs_structure"]
DISPLAY=True

# NOTE(review): absolute, machine-specific path; adjust CACHE_ROOT when running elsewhere.
CACHE_ROOT = "/media/julien/data1/JulienCache/"
_CACHE_SUBDIRS = {"TEST": "Test", "ALL": "All", "BALANCED": "Balanced", "SMALL": "Small"}

# Fail here, with a clear message, rather than leaving `cache_path` undefined
# and hitting a NameError in a later cell when MODE is mistyped.
if MODE not in _CACHE_SUBDIRS:
    raise ValueError(f"Unknown MODE {MODE!r}; expected one of {sorted(_CACHE_SUBDIRS)}")
cache_path = f"{CACHE_ROOT}{_CACHE_SUBDIRS[MODE]}/"


# Metadata columns whose combination identifies a group of contacts.
group_cols = ["Species", "Structure", "Healthy"]
# The same columns for both members of a contact pair: every "_1" name first, then every "_2" name.
pair_group_cols = [f"{col}_{side}" for side in (1, 2) for col in group_cols]

# Fixed orderings of the category values (e.g. for plot rows/columns and hue/style levels).
species_order = ["Rat", "Monkey", "Human"]
structure_order = ["GPe", "STN", "STR"]
condition_order = [0, 1]
sig_type_order = ["bua", "lfp", "spike_times"]

# Loading dataset

In [2]:
# Load the two pre-computed datasets from the cache directory.
# The files are opened in `with` blocks so the handles are closed right after unpickling.
# NOTE(review): pickle.load can execute arbitrary code; only load cache files you produced yourself.
with open(cache_path+"signals_computed.pkl", "rb") as signals_file:
    signals: xr.Dataset = pickle.load(signals_file)
with open(cache_path+"signal_pairs_computed.pkl", "rb") as signal_pairs_file:
    signal_pairs: xr.Dataset = pickle.load(signal_pairs_file)

# Drop every pair in which either member has FullStructure equal to "STN_VMNR".
signal_pairs = signal_pairs.where((signal_pairs["FullStructure_1"] != "STN_VMNR") & (signal_pairs["FullStructure_2"] != "STN_VMNR"), drop=True)

KeyboardInterrupt: 

# Add group coordinates and derived spectral variables (pwelch, coherence, coherence_norm) to the dataset

In [None]:
# Merge per-signal and per-pair data, keeping only the variables whose name mentions
# "pwelch" or "coherence".
dataset = xr.merge([signals, signal_pairs])
dataset = dataset[[var for var in dataset.variables if ("pwelch" in var) or "coherence" in var]]

# Remember the original "f2" frequencies before resampling both frequency axes
# onto one common grid (3 to 49.5 in steps of 0.5).
scipy_freq_coords = dataset["f2"].to_numpy()
interp_freqs = np.linspace(3, 50, 94, endpoint=False)
dataset = dataset.interp(f=interp_freqs, f2=interp_freqs)

# Rename both frequency dimensions to the shared "f_interp" dimension, variable by variable.
# Iterate over a snapshot of the names because the dataset is modified inside the loop.
for col in list(dataset.variables):
    if "f2" in dataset[col].dims:
        dataset[col] = dataset[col].rename(f2="f_interp")
    if "f" in dataset[col].dims:
        dataset[col] = dataset[col].rename(f="f_interp")
# `Dataset.drop` is deprecated for dropping variables; `drop_vars` is the supported equivalent.
dataset = dataset.drop_vars(["f", "f2"])

# Flag the interpolated frequencies that coincide with an original "f2" frequency
# (rounded to the nearest 0.5).
dataset["is_scipy_freq"] = dataset["f_interp"].isin(np.round(scipy_freq_coords*2)/2)
dataset = dataset.set_coords("is_scipy_freq")

# Stack the two power-spectrum estimates along a new "spectral_analysis_function" dimension.
dataset["pwelch"] = dataset[["pwelch_cwt", "pwelch_spectrogram"]].to_array(dim="spectral_analysis_function")
# The result must be assigned back: drop_vars returns a new dataset, so the bare call
# left the two source variables in place.
dataset = dataset.drop_vars(["pwelch_cwt", "pwelch_spectrogram"])
dataset["spectral_analysis_function"] = dataset["spectral_analysis_function"].str.replace("pwelch_", "").str.replace("spectrogram", "scipy.spectrogram").str.replace("cwt", "pycwt.cwt")

# Same stacking for coherence (currently a single estimator), then merge it back in.
coherence = dataset[["coherence_scipy"]].to_array(dim="spectral_analysis_function", name="coherence")
coherence["spectral_analysis_function"] = coherence["spectral_analysis_function"].str.replace("coherence_", "").str.replace("scipy", "scipy.coherence").str.replace("wct", "pycwt.wct")
dataset=xr.merge([dataset, coherence])
dataset = dataset.drop_vars(["coherence_scipy"])



# For each listed dimension, the metadata columns whose combination defines a group along it.
group_definitions = {
    "Contact": group_cols,
    "Contact_pair": pair_group_cols,
    "sig_preprocessing": ["sig_type"],
    "sig_preprocessing_pair": ["sig_type_1", "sig_type_2"],
}

for member_dim, key_cols in group_definitions.items():
    grp_coord_name = member_dim+"_grp"
    # One MultiIndex entry per element of the dimension, built from its key columns.
    group_index = pd.MultiIndex.from_arrays(
        [dataset[key_col].data for key_col in key_cols],
        names=key_cols,
    )
    dataset[grp_coord_name] = xr.DataArray(group_index, dims=[member_dim], coords=[dataset[member_dim]])
    dataset = dataset.set_coords(grp_coord_name)

# Modulus of the "coherence" variable.
dataset["coherence_norm"] = np.abs(dataset["coherence"])

# Creating dataset_contact

In [None]:
# Average each variable within its signal-preprocessing groups.
dataset_contact = xr.Dataset()
for var_name, preprocessing_grp in (
    ("pwelch", "sig_preprocessing_grp"),
    ("coherence", "sig_preprocessing_pair_grp"),
    ("coherence_norm", "sig_preprocessing_pair_grp"),
):
    dataset_contact[var_name] = dataset[var_name].groupby(preprocessing_grp).mean()


# Creating grouped dataset

In [None]:
grouped_dataset = xr.Dataset()

# First-level averages over the signal-preprocessing groups. Each is computed once and reused
# below, instead of repeating the same groupby for every derived variable.
pwelch_by_preprocessing = dataset["pwelch"].groupby("sig_preprocessing_grp").mean()
coherence_by_preprocessing = dataset["coherence"].groupby("sig_preprocessing_pair_grp").mean()
coherence_norm_by_preprocessing = dataset["coherence_norm"].groupby("sig_preprocessing_pair_grp").mean()

# Second-level averages over the contact / contact-pair groups.
grouped_dataset["pwelch"] = pwelch_by_preprocessing.groupby("Contact_grp").mean()
# Group sizes: number of non-null values per group, read at f_interp=4.
grouped_dataset["n_contacts"] = pwelch_by_preprocessing.groupby("Contact_grp").count("Contact").sel(f_interp=4)
grouped_dataset["coherence"] = coherence_by_preprocessing.groupby("Contact_pair_grp").mean()
grouped_dataset["coherence_norm"] = coherence_norm_by_preprocessing.groupby("Contact_pair_grp").mean()
grouped_dataset["n_contact_pairs"] = coherence_by_preprocessing.groupby("Contact_pair_grp").count("Contact_pair").sel(f_interp=4)
grouped_dataset = grouped_dataset.set_coords(["n_contact_pairs", "n_contacts"])

# Phase of the group-averaged coherence.
grouped_dataset["coherence_phase"] = xr.apply_ufunc(np.angle, grouped_dataset["coherence"])
# Ratio |mean(coherence)| / mean(|coherence|).
grouped_dataset["coherence_validity"] = np.abs(grouped_dataset["coherence"])/grouped_dataset["coherence_norm"]
# Frequency, within the 7 to 40 slice of f_interp, at which coherence_norm is largest.
grouped_dataset["f_max"] = grouped_dataset["coherence_norm"].sel(f_interp=slice(7, 40)).idxmax("f_interp")

# Adding the columns of interest (f_max_coherence, coherence_phase_bins)

In [None]:
# Text dump of both datasets, per-element first and grouped second.
for current_ds in (dataset, grouped_dataset):
    print(current_ds)


In [None]:
def f(x: np.ndarray, coord, dim):
    """Expand per-group values to one value per element of ``dataset[dim]``.

    Parameters
    ----------
    x : np.ndarray
        Values whose last axis is indexed by group, in the same order as ``coord``.
    coord : np.ndarray
        1-D array of group labels, one per entry of the last axis of ``x``.
    dim : str
        Name of the variable of the module-level ``dataset`` that holds the group
        label of every element.

    Returns
    -------
    np.ndarray
        ``x`` with its last axis replaced by one entry per element of ``dataset[dim]``,
        each carrying the value of that element's group.

    Raises
    ------
    ValueError
        If some element's label matches no entry, or more than one entry, of ``coord``.
    """
    other = dataset[dim].to_numpy()
    logging.getLogger(__name__).debug("f: x.shape=%s, labels.shape=%s", x.shape, other.shape)

    # Rows are elements and columns are groups, so the column indices come out in element order.
    # Comparing the other way round (groups as rows) and taking the row indices gives group
    # indices sorted by group, which misassigns values unless elements are already sorted by group.
    member_idx, positions = np.nonzero(other[:, None] == coord[None, :])
    if not np.array_equal(member_idx, np.arange(other.shape[0])):
        raise ValueError(f"every element of dataset[{dim!r}] must match exactly one group label")
    return np.take(x, positions, axis=-1)

def _expand_groups(values, grp_dim, member_dim):
    """Run `f` along `grp_dim` of `values`, producing the new dimension `member_dim`."""
    return xr.apply_ufunc(
        lambda grp_values, grp_labels: f(grp_values, grp_labels, grp_dim),
        values,
        grouped_dataset[grp_dim],
        input_core_dims=[[grp_dim], [grp_dim]],
        output_core_dims=[[member_dim]],
    )


# Expand the per-group peak frequency first along the contact pairs,
# then along the signal-preprocessing pairs.
f_max_tmp = _expand_groups(grouped_dataset["f_max"], "Contact_pair_grp", "Contact_pair")
f_max_tmp = _expand_groups(f_max_tmp, "sig_preprocessing_pair_grp", "sig_preprocessing_pair")
dataset["f_max_coherence"] = f_max_tmp
print(dataset)

In [None]:

# Coherence sampled at each element's "f_max_coherence" frequency (f_interp=4 where it is missing),
# averaged over the signal-preprocessing pair groups and then grouped by contact-pair group.
peak_freqs = dataset["f_max_coherence"].fillna(4)
coherence_at_peak = dataset["coherence"].sel(f_interp=peak_freqs)
tmp_group = (
    coherence_at_peak
    .groupby("sig_preprocessing_pair_grp")
    .mean()
    .groupby("Contact_pair_grp")
)


In [None]:

# 180 reference phases evenly spaced over [-pi, pi], both endpoints included.
phase_bins = np.linspace(-np.pi, np.pi, 180)

def mk_phase_info(x: xr.DataArray):
    """Project the complex values of `x` onto every reference phase in `phase_bins`.

    Each value is rotated by minus the reference phase. Its real part is kept only where the
    rotated phase lies strictly within pi/5 of zero, and the kept values are summed over the
    "Contact_pair" dimension. The result gains a "phase_bins" dimension.
    """
    reference_phases = xr.DataArray(phase_bins, dims=["phase_bins"], coords=dict(phase_bins=phase_bins))
    rotated = x*np.exp(-1j * reference_phases)
    rotated_phase = xr.apply_ufunc(np.angle, rotated)
    within_window = abs(rotated_phase) <np.pi/5
    return np.real(rotated).where(within_window).sum("Contact_pair")

# Apply mk_phase_info to every contact-pair group and store the result.
phase_info_per_group = tmp_group.map(mk_phase_info)
grouped_dataset["coherence_phase_bins"] = phase_info_per_group
print(grouped_dataset["coherence_phase_bins"])

In [None]:
from xarray_helper import normalize


# Post handling

In [None]:


# print(grouped_dataset["coherence_phase_bins"].sel(f_interp=grouped_dataset["f_max"].fillna(grouped_dataset["f_interp"].min())))
# exit()

# Keep the coherence phase only at original scipy frequencies lying strictly within 2 of f_max.
above_lower = grouped_dataset["f_interp"] > grouped_dataset["f_max"]-2
below_upper = grouped_dataset["f_interp"] < grouped_dataset["f_max"]+2
selected = above_lower & below_upper & grouped_dataset["is_scipy_freq"]
grouped_dataset["coherence_phase"] = grouped_dataset["coherence_phase"].where(selected)


def _with_common_coords(ds):
    """Add a "<col>(common)" coordinate for each group column whose "_1" and "_2" versions are equal."""
    for col in group_cols:
        first, second = ds[f"{col}_1"], ds[f"{col}_2"]
        if first.equals(second):
            common_name = col+"(common)"
            ds[common_name] = ds[f"{col}_1"]
            ds = ds.set_coords(common_name)
    return ds


grouped_dataset = _with_common_coords(grouped_dataset)
dataset_contact = _with_common_coords(dataset_contact)

print(grouped_dataset)

In [None]:
# Text dump of the grouped dataset (the previous cell ends with the same print).
print(grouped_dataset)

# Drawing coherence phase

In [None]:
# Rescale every phase profile by its maximum over the phase bins.
raw_phase_profile = grouped_dataset["coherence_phase_bins"]
grouped_dataset["coherence_phase_bins"] = raw_phase_profile/raw_phase_profile.max("phase_bins")

# Keep only entries whose first signal type is "bua" and whose group has more than 10 contact pairs.
first_is_bua = grouped_dataset["sig_type_1"]=="bua"
enough_pairs = grouped_dataset["n_contact_pairs"] > 10
phase_profile = grouped_dataset["coherence_phase_bins"].where(first_is_bua).where(enough_pairs)
print(phase_profile)

# Long-format table without the masked-out rows, as input for the figure.
data = phase_profile.to_dataframe().dropna(axis="index", subset="coherence_phase_bins")
print(data)

# One figure per (spectral_analysis_function, sig_type_1, sig_type_2), laid out as a
# Structure_2 x Structure_1 grid of polar axes.
phase_figure = toolbox.FigurePlot(
    data=data,
    figures=["spectral_analysis_function", "sig_type_1", "sig_type_2"],
    col="Structure_1",
    row="Structure_2",
    sharey=False,
    margin_titles=True,
    fig_title="coherence_phase_bins_structure_{spectral_analysis_function}, {sig_type_1}, {sig_type_2}",
    row_order=structure_order,
    col_order=structure_order,
    subplot_kws={"projection": "polar"},
    despine=False,
)
phase_figure = phase_figure.map(
    sns.lineplot,
    x="phase_bins",
    y="coherence_phase_bins",
    hue="Species(common)",
    style="Healthy(common)",
    hue_order=species_order,
    style_order=condition_order,
)
phase_figure = phase_figure.add_legend().set(yticks=[])
phase_figure.maximize().save_pdf(f"{cache_path}Figures/coherence_phase_bins_structure.pdf")
plt.show()