In [None]:
import numpy as np, pandas as pd, xarray as xr
from pathlib import Path
import datetime, networkx as nx, yaml, warnings, sys, logging
from disjoint_set import DisjointSet
from helper import singleglob, nxrender, Step, dict_merge, PdfWriter, DictWriter, TableWriter, TextWriter
import matplotlib.pyplot as plt
import matplotlib as mpl, seaborn as sns
# Default size (width, height) for figures created without an explicit figsize.
mpl.rcParams['figure.figsize'] = [15, 6]

In [None]:
# File name of this notebook, kept as a Path so its parts (e.g. the stem) can be reused.
notebook_name = Path("plot_heatmap.ipynb")

In [None]:
# Input root: the parent of the (resolved) current working directory.
base_folder = Path(".").resolve().parent

# Output folder: named after the notebook (without extension), relative to the
# current working directory. Created on demand, including missing parents.
result_folder = Path(".").joinpath(notebook_name.stem)
result_folder.mkdir(parents=True, exist_ok=True)


In [None]:
# Output sinks, all located inside the result folder.
# NOTE(review): the writer classes come from the local `helper` module; their
# exact flushing behaviour is not visible from this notebook.
tables = TableWriter(result_folder.joinpath("tables.xlsx"))
figures = PdfWriter(result_folder.joinpath("figures.pdf"))
dicts = DictWriter(result_folder.joinpath("dicts.yaml"))
notebook_save_path = result_folder.joinpath("notebook.html")
warn = TextWriter(result_folder.joinpath("warnings.txt"))

In [None]:
# Discover the session folders to process.
# If the base folder is itself a session folder, use it alone; otherwise search
# it recursively. The glob result is sorted so the session order (and therefore
# the order of everything derived from it) is reproducible between runs.
if base_folder.stem.startswith("Session--"):
    session_folders = [base_folder]
else:
    session_folders = sorted(base_folder.glob("**/Session--*"))

# One row per session with the two files the analysis needs.
input_data = pd.DataFrame({"sessions": pd.Series(session_folders, dtype=object)})
input_data["metadata_file"] = [folder / "metadata.yaml" for folder in session_folders]
input_data["trial_data_file"] = [
    folder / "analysis" / "trial_data" / "trial_data.nc" for folder in session_folders
]

# A session is runnable only when both files exist. Building the column from a
# plain list (instead of DataFrame.apply(axis=1)) keeps it a boolean Series even
# when no session was found.
input_data["run"] = pd.Series(
    [
        metadata_file.exists() and trial_data_file.exists()
        for metadata_file, trial_data_file in zip(
            input_data["metadata_file"], input_data["trial_data_file"]
        )
    ],
    index=input_data.index,
    dtype=bool,
)
tables.write(input_data=input_data)


In [None]:
# Load every runnable session exactly once and stack them along "trial".
# (The previous version concatenated `run_data` with itself, which loaded each
# session twice and doubled every trial.)
run_data = input_data.loc[input_data["run"]]

datasets = []
for _, row in run_data.iterrows():
    data = xr.load_dataset(row["trial_data_file"])

    # Close the metadata file deterministically instead of leaking the handle.
    with row["metadata_file"].open("r") as metadata_stream:
        metadata = yaml.safe_load(metadata_stream)

    # Attach session-level metadata as variables of the session dataset.
    subject_info = metadata["subject"]
    data = data.assign(
        session_date=metadata["date"],
        subject=subject_info["name"],
        handedness=subject_info["handedness"],
        opsin=subject_info["opsin"],
        task=metadata["task"]["name"],
    )

    # Keep the per-session trial labels as a regular variable, then drop the
    # "trial" variable itself before concatenating along that dimension.
    data["trial_num"] = data["trial"]
    data = data.drop_vars("trial")
    datasets.append(data)

if not datasets:
    # Fail with an actionable message rather than from inside xr.concat.
    raise ValueError(
        f"No session with both metadata.yaml and trial_data.nc was found under {base_folder}"
    )

# NOTE: `all` shadows the Python builtin of the same name; the name is kept
# because the following cells read it.
all = xr.concat(datasets, dim="trial")
all

In [None]:
# Trial selection: flagged "sucess", "go" above 0.5 (missing treated as 0),
# and cue type "LowPitch". Entries failing the mask are dropped.
# NOTE(review): "sucess" and "mouvement_start" are reproduced as written; they
# presumably match the names stored in the data files - confirm before renaming.
is_success = all["sucess"]
is_go = all["go"].fillna(0) > 0.5
is_low_pitch = all["cue_type"] == "LowPitch"
kept = all.where(is_success & is_go & is_low_pitch, drop=True)

# Mask (without dropping) everything that is not one of the three events of interest.
event_mask = kept["event_name"].isin(["cue", "leverpress", "mouvement_start"])
kept = kept.where(event_mask)

# Derived labels.
# NOTE(review): the "_MT" fallback relies on str.extract yielding a missing
# value when the pattern does not match - confirm against the xarray version in use.
protocol_suffix = kept["task"].str.extract('(_Cue-RT)', dim=None).fillna("_MT")
kept["protocol"] = "Protocol" + protocol_suffix
kept["trial_type"] = kept["cue_type"] + ", " + kept["stimulation"]

# "ipsi" when the channel hemisphere equals the handedness (case-insensitive), else "contra".
hemi_matches_hand = kept["channel_hemi"].str.lower() == kept["handedness"].str.lower()
kept["side"] = xr.where(hemi_matches_hand, "ipsi", "contra")

In [None]:
def _unique_values(variable):
    """Return the distinct non-null values of an xarray variable as a list.

    Values are flattened over all dimensions and listed in order of first
    appearance. Missing values (e.g. introduced by `.where`) are excluded.

    `DataArray.drop_duplicates` is not suitable here: it removes duplicated
    *dimension index labels*, not duplicated data values.
    NOTE(review): `DataArray` also does not appear to provide `to_list` -
    verify against the installed xarray version.
    """
    flat_values = pd.Series(np.asarray(variable.values, dtype=object).ravel())
    return flat_values.dropna().unique().tolist()


# Distinct labels used to lay out and iterate over the figures.
sides = _unique_values(kept["side"])
opsins = _unique_values(kept["opsin"])
trial_types = _unique_values(kept["trial_type"])
subjects = _unique_values(kept["subject"])
events = _unique_values(kept["event_name"])
protocols = _unique_values(kept["protocol"])
sides, opsins, trial_types, subjects, events, protocols

In [None]:
# print(tables.write(counts=final_df.groupby(["opsin","protocol", "trial_type"]).size(name="count").reset_index()).to_string())

In [None]:

# if  final_df["subject"].nunique() > 1:
#     points_col="session_date"
#     avg_col = "subject"
# else:
#     points_col = "trial"
#     avg_col="session_date"
# avg_col_order =  sorted(final_df[avg_col].drop_duplicates().to_list())
# display_lim = [-3, 5]
from itertools import product

# Work-in-progress preview of the heatmap inputs.
# The drawing code is not written yet: this cell only prints the first
# (side, opsin, protocol, trial_type, event) selection and then stops.
# It stops with `break` instead of raising, so "Restart & Run All" still reaches
# the cells below, and it closes the empty figure grid instead of leaking it.
preview_shown = False
for side, opsin, protocol in product(sides, opsins, protocols):
    # Grid layout: one row per event, one column per trial type plus one extra
    # column (its purpose is not defined in this cell yet).
    f, axs = plt.subplots(len(events), len(trial_types)+1, squeeze=False, figsize=(16, 5*len(events)))
    for col, trial_type in enumerate(trial_types):
        for row, event in enumerate(events):
            selection = (
                (kept["side"] == side)
                & (kept["opsin"] == opsin)
                & (kept["protocol"] == protocol)
                & (kept["trial_type"] == trial_type)
            )
            arr = kept.where(selection).sel(event_name=event)
            print(arr)
            preview_shown = True
            break
        if preview_shown:
            break
    # Nothing has been drawn on the grid, so release it.
    # TODO: draw `arr` on axs[row, col] and write the figure once plotting is implemented.
    plt.close(f)
    if preview_shown:
        break
# for task, plot_df in final_df.groupby("protocol"):
#     f, axs = plt.subplots(len(opsins), len(vars), squeeze=False, figsize=(16, 5*len(opsins)))
#     for col, var in enumerate(vars):
#         for row, opsin in enumerate(opsins):
#             ax: plt.Axes=axs[row, col]
#             grp = plot_df[(plot_df["var"]==var) & (plot_df["opsin"]==opsin)]
#             sns.violinplot(data=grp,  y="value_zscored", x="trial_type", common_norm=False, order=trial_types, ax=ax, cut=0, color="wheat", alpha=0.5, zorder=5, inner=None)
#             points = grp.groupby(["trial_type", points_col, avg_col])["value_zscored"].mean().reset_index()
#             sns.swarmplot(data=points,  y="value_zscored", x="trial_type", hue=avg_col, hue_order=avg_col_order, order=trial_types, ax=ax, size=1.5, dodge=True, legend=False, zorder=1)
#             avgs = grp.groupby(["trial_type", avg_col])["value_zscored"].mean().reset_index()
#             sns.swarmplot(data=avgs,  y="value_zscored", x="trial_type", 
#                         hue=avg_col, hue_order=avg_col_order, order=trial_types, linewidth=10, size=1, ax=ax, dodge=True, legend=False if col!=0 or row!=0 else "auto"
#                         , zorder=0, marker="|")
#             total = grp.groupby(["trial_type"])["value_zscored"].agg(["mean", "median"]).reset_index()
#             sns.scatterplot(data=total, x="trial_type", y="mean", color="black",  linewidth=10, size=1.2, zorder=6, marker="|", ax=ax, legend=False)
#             counts = grp.groupby("trial_type")["value_zscored"].count().reset_index()
#             for _, r in counts.iterrows():
#                 ax.text(x=trial_types.index(r["trial_type"]), y=display_lim[0]*0.9, s= f'n_trials={r["value_zscored"]}', horizontalalignment="center")
#             ax.axhline(y=0, color="gray", alpha=0.5, linewidth=0.7) #linestyle=(0, (1,3))
#             ax.set_title(f"{var}, {opsin}")
#             ax.set_ylim(display_lim)
#             if col==0 and row==0:
#                 lgd = f.legend(handles=ax.legend().legend_handles, ncols=5, loc='lower center', bbox_to_anchor=(0.5, -0.04*3/(len(opsins))))
#                 ax.legend().remove()
#     suptitle = plt.suptitle(task)
#     figures.write(bbox_extra_artists=(lgd,suptitle), bbox_inches='tight')
#     plt.tight_layout()
#     plt.show()
    


In [None]:
# for task, plot_df in final_df.groupby("protocol"):
#     f, axs = plt.subplots(len(opsins), 1, squeeze=False, figsize=(16, 5*len(opsins)))
#     for row, opsin in enumerate(opsins):
#         ax: plt.Axes=axs[row, 0]
#         grp = plot_df[(plot_df["opsin"]==opsin)][["trial", "var", "value_zscored", "trial_type"]].set_index(["trial", "var", "trial_type"])["value_zscored"].unstack("var").reset_index()
#         grp=grp.rename(columns=dict(rt="rt_zscored", mt="mt_zscored"))
#         def get_regression_lines(d):
#             a, b = np.polyfit(d["rt_zscored"], d["mt_zscored"], 1)
#             return pd.DataFrame([dict(rt_zscored=x, mt_zscored=a*x +b) for x in [d["rt_zscored"].min(), d["rt_zscored"].max()]])
#         regressions = grp.groupby("trial_type").apply(get_regression_lines, include_groups=False).reset_index()
#         # print(regressions)
#         sns.scatterplot(data=grp, x="rt_zscored", y="mt_zscored", hue="trial_type", hue_order=trial_types, legend="auto" if row==0 else False, ax=ax)
#         sns.lineplot(data=regressions, x="rt_zscored", y="mt_zscored", hue="trial_type", hue_order=trial_types, legend=False, ax=ax)
#         if row==0:
#             lgd = f.legend(handles=ax.legend().legend_handles, ncols=5, loc='lower center', bbox_to_anchor=(0.5, -0.04*3/(len(opsins))))
#             ax.legend().remove()
#         ax.set_title(f"{opsin}")
#         ax.set_xlim(display_lim)
#         ax.set_ylim(display_lim)
#     suptitle = plt.suptitle(f"{task}")
#     figures.write(bbox_extra_artists=(lgd,suptitle), bbox_inches='tight')
#     plt.tight_layout()
#     plt.show()

In [None]:
# Drop this notebook's references to the output writers (same order as before).
# NOTE(review): presumably the helper writers finalize their files when released -
# confirm in `helper`. The `warn` writer is not released here.
del tables, dicts, figures