In [None]:
import numpy as np, pandas as pd, xarray as xr
from pathlib import Path
import datetime, networkx as nx, yaml, warnings, sys, logging
from disjoint_set import DisjointSet
from helper import singleglob, nxrender, Step, dict_merge, PdfWriter, DictWriter, TableWriter, TextWriter
import matplotlib.pyplot as plt
import matplotlib as mpl, seaborn as sns
# Default size (width, height) for figures that do not pass their own figsize.
mpl.rcParams['figure.figsize'] = [15, 6]

In [None]:
# File name of this notebook; its stem is used below to name the output folder.
notebook_name = Path("plot_rt_mt.ipynb")

In [None]:
# Sessions are searched starting from the parent of the working directory.
base_folder = Path(".").resolve().parent

# All outputs of this notebook go to a folder named after the notebook,
# created inside the working directory if it does not exist yet.
result_folder = Path(".") / notebook_name.stem
result_folder.mkdir(parents=True, exist_ok=True)


In [None]:
# Output sinks for this notebook, all located inside result_folder.
# The writer classes come from the project-local `helper` module.
tables = TableWriter(result_folder/"tables.xlsx")
figures = PdfWriter(result_folder/"figures.pdf")
dicts = DictWriter(result_folder/"dicts.yaml")
# Target path for an HTML export of the notebook (not used by any cell shown here).
notebook_save_path = result_folder/"notebook.html"
# NOTE(review): `warn` is created but never written to or deleted in the visible cells — confirm it is needed.
warn =  TextWriter(result_folder/"warnings.txt")

In [None]:
# Discover the session folders to analyse: either base_folder itself when it is a
# "Session--*" folder, or every "Session--*" folder found recursively below it.
if base_folder.stem.startswith("Session--"):
    session_folders = [base_folder]
else:
    session_folders = list(base_folder.glob("**/Session--*"))

# One row per session, with the two input files each session is expected to provide.
input_data = pd.DataFrame({"sessions": pd.Series(session_folders, dtype=object)})
input_data["metadata_file"] = input_data["sessions"].apply(lambda f: f/"metadata.yaml")
input_data["trial_event_file"] = input_data["sessions"].apply(lambda f: f/"analysis"/"events"/"trial_events.nc")

# A session is runnable only when both input files exist.
# Built as an explicit boolean Series: a row-wise DataFrame.apply returns an empty
# DataFrame when no session was found, which cannot be assigned to a single column.
input_data["run"] = pd.Series(
    [
        metadata_file.exists() and trial_event_file.exists()
        for metadata_file, trial_event_file in zip(input_data["metadata_file"], input_data["trial_event_file"])
    ],
    index=input_data.index,
    dtype=bool,
)
tables.write(input_data=input_data)


In [None]:
# Load every runnable session exactly once and stack all trials along "trial".
# (The selection used to be concatenated with itself, which loaded each session
# twice and doubled every trial downstream.)
run_data = input_data.loc[input_data["run"]]
datasets = []
for _, row in run_data.iterrows():
    data = xr.load_dataset(row["trial_event_file"])
    # Context manager so the metadata file handle is closed after parsing.
    with row["metadata_file"].open("r") as metadata_stream:
        metadata = yaml.safe_load(metadata_stream)
    # Attach the session-level metadata to the dataset as additional variables.
    data = data.assign(session_date=metadata["date"], subject=metadata["subject"]["name"], opsin=metadata["subject"]["opsin"], handedness=metadata["subject"]["handedness"], task=metadata["task"]["name"])
    # Keep the per-session trial labels as a data variable and drop the "trial"
    # coordinate before sessions are concatenated along that dimension.
    data["trial_num"] = data["trial"]
    data = data.drop_vars("trial")
    datasets.append(data)

if not datasets:
    raise ValueError(
        f"No session with both metadata.yaml and trial_events.nc found under {base_folder}"
    )

# NOTE: `all` shadows the Python builtin; the name is kept because later cells use it.
all = xr.concat(datasets, dim="trial")
all

In [None]:
# Event times, indexed by event_name.
event_t = all["event_t"]

# rt: time from "cue" to "mouvement_start"; mt: time from "mouvement_start" to "leverpress".
all["rt"] = event_t.sel(event_name="mouvement_start") - event_t.sel(event_name="cue")
all["mt"] = event_t.sel(event_name="leverpress") - event_t.sel(event_name="mouvement_start")

# A trial counts as successful when it has no "error" event, both durations are
# defined, and both are below the cut-off (same unit as event_t).
max_duration = 2
rt = all["rt"]
mt = all["mt"]
no_error = event_t.sel(event_name="error").isnull()
all["sucess"] = no_error & rt.notnull() & mt.notnull() & (rt < max_duration) & (mt < max_duration)
all

In [None]:
# Keep only the trials flagged by the boolean "sucess" variable; drop=True removes
# the trials that end up entirely masked.
sucess = all.where(all["sucess"], drop=True)
sucess

In [None]:
# Drop everything that depends on the event_name dimension and convert the
# remaining variables to a DataFrame.
motor_df = sucess.drop_dims("event_name").to_dataframe()
# Prints the complete table (every row and column) as plain text.
print(motor_df.to_string())

In [None]:

# Keep only the rows whose "go" value is above 0.5 (missing values count as 0).
go_trials = motor_df.loc[motor_df["go"].fillna(0) > 0.5]

# Reshape to long format: every column except rt/mt becomes an identifier, and
# rt/mt are stacked into (var, value) pairs.
id_columns = [name for name in go_trials.columns if name not in ("rt", "mt")]
stacked = go_trials.set_index(id_columns).stack()
stacked.index = stacked.index.set_names("var", level=-1)
motor_df = stacked.rename("value").reset_index()

# Label combining cue type and stimulation.
motor_df["trial_type"] = motor_df["cue_type"].astype(str) + ", " + motor_df["stimulation"].astype(str)
# Protocol is "_Cue-RT" when the task name contains that substring, "_MT" otherwise.
motor_df["protocol"] = motor_df["task"].str.extract('(_Cue-RT)', expand=False).fillna("_MT")
print(motor_df.to_string())

In [None]:
# Reference statistics for z-scoring: mean and std of "value" per
# (var, session_date, subject), computed only on rows with
# cue_type == "LowPitch" and stimulation == "None".
# NOTE(review): the layout produced by as_index=False combined with a list agg and
# reset_index(level=0, drop=True) may depend on the pandas version; the merge in the
# next cell needs var/session_date/subject to remain available — confirm.
zscore_df = motor_df.loc[(motor_df["cue_type"]=="LowPitch") & (motor_df["stimulation"]=="None")].groupby(["var", "session_date", "subject"], as_index=False)["value"].agg(["mean", "std"]).reset_index(level=0, drop=True)
zscore_df

In [None]:
# Attach the reference mean/std to every row; the left join keeps rows that have
# no reference statistics (their mean/std are NaN).
merged_df = motor_df.merge(zscore_df, how="left", on=["var", "session_date", "subject"])
merged_df["value_zscored"] = (merged_df["value"] - merged_df["mean"]) / merged_df["std"]

# Discard rows whose reference std is zero or missing, since no z-score is defined for them.
has_usable_std = merged_df["std"].ne(0) & merged_df["std"].notna()
merged_df = merged_df[has_usable_std]
merged_df

In [None]:
# Turn the current row index into a "trial" column and drop the reference statistics.
# NOTE(review): the index at this point is the row index of merged_df, so "trial"
# identifies a row of the long table rather than an original trial number — confirm intent.
final_df = merged_df.reset_index(names=["trial"]).drop(columns=["mean", "std"])
final_df


In [None]:
# Inspection aid: print every row of final_df belonging to one hard-coded task name.
print(final_df[final_df["task"]=="HF_55_RandomTrial_NoGoGo_left_both_sound_LaserChrim_S1_Cue-RT1_4p15Hz_2rew_NoPadCheck_3000"].to_string())

In [None]:
# Distinct values of each grouping column, in order of first appearance.
# NOTE: `vars` shadows the Python builtin; the name is kept because the plotting cell uses it.
vars, opsins, trial_types, tasks, subjects = (
    final_df[column].drop_duplicates().to_list()
    for column in ("var", "opsin", "trial_type", "task", "subject")
)
vars, opsins, trial_types, tasks, subjects

In [None]:
# Number of non-null values per (var, opsin, protocol, task, trial_type) combination.
print(final_df.groupby(["var", "opsin","protocol", "task", "trial_type", ])["value"].count().to_string())

In [None]:

# One figure per protocol, with one panel per (opsin row, var column) showing the
# z-scored values by trial type.
# With several subjects: small points are per-session means and the hue/average
# level is the subject. With a single subject: small points are individual rows
# ("trial") and the hue/average level is the session date.
if  final_df["subject"].nunique() > 1:
    points_col="session_date"
    avg_col = "subject"
else:
    points_col = "trial"
    avg_col="session_date"
# Fixed hue order so that colours are consistent across panels and figures.
avg_col_order =  sorted(final_df[avg_col].drop_duplicates().to_list())
# y-axis limits applied to every panel; values outside this range are not visible.
display_lim = [-3, 5]
# NOTE(review): the loop variable is called `task` but holds the "protocol" group key.
for task, plot_df in final_df.groupby("protocol"):
    f, axs = plt.subplots(len(opsins), len(vars), squeeze=False, figsize=(16, 5*len(opsins)))
    for col, var in enumerate(vars):
        for row, opsin in enumerate(opsins):
            ax: plt.Axes=axs[row, col]
            grp = plot_df[(plot_df["var"]==var) & (plot_df["opsin"]==opsin)]
            # Distribution of all z-scored values per trial type.
            sns.violinplot(data=grp,  y="value_zscored", x="trial_type", common_norm=False, order=trial_types, ax=ax, cut=0, color="wheat", alpha=0.5, zorder=5, inner=None)
            # Small points: mean per (trial_type, points_col, avg_col).
            points = grp.groupby(["trial_type", points_col, avg_col])["value_zscored"].mean().reset_index()
            sns.swarmplot(data=points,  y="value_zscored", x="trial_type", hue=avg_col, hue_order=avg_col_order, order=trial_types, ax=ax, size=1.5, dodge=True, legend=False, zorder=1)
            # "|" markers: mean per (trial_type, avg_col). Only the first panel keeps
            # its legend, which is reused below as the figure-level legend.
            avgs = grp.groupby(["trial_type", avg_col])["value_zscored"].mean().reset_index()
            sns.swarmplot(data=avgs,  y="value_zscored", x="trial_type", 
                        hue=avg_col, hue_order=avg_col_order, order=trial_types, linewidth=10, size=1, ax=ax, dodge=True, legend=False if col!=0 or row!=0 else "auto"
                        , zorder=0, marker="|")
            # Black "|" markers: overall mean per trial type (the median is computed but not drawn).
            total = grp.groupby(["trial_type"])["value_zscored"].agg(["mean", "median"]).reset_index()
            sns.scatterplot(data=total, x="trial_type", y="mean", color="black",  linewidth=10, size=1.2, zorder=6, marker="|", ax=ax, legend=False)
            # Reference line at z = 0.
            ax.axhline(y=0, color="gray", alpha=0.5, linewidth=0.7) #linestyle=(0, (1,3))
            ax.set_title(f"{var}, {opsin}")
            ax.set_ylim(display_lim)
            if col==0 and row==0:
                # Promote the first panel's legend to a single legend below the figure,
                # then remove the axes-level legend.
                lgd = f.legend(handles=ax.legend().legend_handles, ncols=5, loc='lower center', bbox_to_anchor=(0.5, -0.04*3/(len(opsins))))
                ax.legend().remove()
    suptitle = plt.suptitle(task)
    # NOTE(review): figures.write is called before plt.tight_layout(); if write saves
    # the current figure, the saved page will not include the tight_layout adjustment — confirm.
    figures.write(bbox_extra_artists=(lgd,suptitle), bbox_inches='tight')
    plt.tight_layout()
    plt.show()
    


In [None]:
# Drop the notebook's references to the writers, in the same order as they were
# released before (tables, dicts, figures).
del tables, dicts, figures