In [None]:
import numpy as np, pandas as pd, xarray as xr
from pathlib import Path
import datetime, networkx as nx, yaml
from helper import singleglob, nxrender, Step, dict_merge
import matplotlib.pyplot as plt
import matplotlib as mpl, seaborn as sns

In [None]:
# Resolve the folder two levels above the working directory, then locate every
# input this notebook reads. `Step` and `singleglob` come from the local
# `helper` module.
base_folder = Path(".").resolve().parent.parent
node_metadata_step = Step(
    Path(".").resolve().parent / "task_graph" / "task_graph.ipynb",
    "node_metadata.yaml",
)
metadata_path = singleglob(
    base_folder,
    "metadata.yaml",
    "metadata --*.yaml",
    "metadata--*.yaml",
)
poly_dat_path = singleglob(base_folder, "*.dat")
fiber_events_path = singleglob(
    base_folder,
    "**/Events.csv",
    "**/Events --*.csv",
    "**/Events--*.csv",
)
# Echo the resolved locations as the cell output.
base_folder, node_metadata_step, metadata_path, poly_dat_path, fiber_events_path

In [None]:
# Read the event definitions stored under task -> events in the metadata file.
# The file is opened in a `with` block so the handle is closed once parsed.
with metadata_path.open("r") as metadata_file:
    event_metadata = yaml.safe_load(metadata_file)["task"]["events"]
event_metadata


# Handling Poly events

In [None]:
# Parse the YAML file whose location is returned by the task-graph step.
# The file is opened in a `with` block so the handle is closed once parsed.
with node_metadata_step.exec_if_necessary().open("r") as node_metadata_file:
    node_metadata = yaml.safe_load(node_metadata_file)
node_metadata

In [None]:
# Read the tab-separated .dat file, skipping its first 13 lines and naming the
# columns explicitly.
poly_event_df = pd.read_csv(
    poly_dat_path,
    sep="\t",
    skiprows=13,
    names=['time (ms)', 'family', 'nbre', '_P', '_V', '_L', '_R', '_T', '_W', '_X', '_Y', '_Z'],
)
# Add "t" (the millisecond column divided by 1000) as the first column and
# order the rows by it.
poly_event_df.insert(loc=0, column="t", value=poly_event_df["time (ms)"].div(1000))
poly_event_df = poly_event_df.sort_values(by="t")
print(poly_event_df.to_string())

In [None]:
# Rows with family == 10 drive the segmentation: their running count becomes
# "state_count" and their "_T" value is carried forward as "curr_node".
# NOTE(review): presumably family 10 marks entry into a task-graph node — confirm.
poly_event_df["state_count"] = poly_event_df["family"].eq(10).cumsum()
poly_event_df["curr_node"] = (
    poly_event_df["_T"]
    .mask(poly_event_df["family"].ne(10))
    .ffill()
)
print(poly_event_df.to_string())

In [None]:
def reduce_grp(grp: pd.DataFrame):
    """Collapse one (state_count, curr_node) group of poly rows into a summary row.

    Parameters
    ----------
    grp : pd.DataFrame
        Group produced by ``groupby(...).apply``; ``grp.name`` is the
        ``(state_count, curr_node)`` key.

    Returns
    -------
    pd.Series
        ``t_start`` (smallest ``t`` in the group), ``event_name`` (the
        ``"event"`` entry of the node's metadata, or ``None``), ``metadata``
        (the node's entry in ``node_metadata``, or ``{}``) and ``has_pause``
        (whether any row of the group has ``family == 11``).
    """
    _, node = grp.name

    # Nodes missing from node_metadata get an empty description.
    if node in node_metadata:
        node_info = node_metadata[node]
    else:
        node_info = {}

    if "event" in node_info:
        event_name = node_info["event"]
    else:
        event_name = None

    pause_rows = grp.loc[grp["family"] == 11]
    return pd.Series(
        {
            "t_start": grp["t"].min(),
            "event_name": event_name,
            "metadata": node_info,
            "has_pause": len(pause_rows.index) > 0,
        }
    )

# One row per (state_count, curr_node) segment, ordered by start time.
poly_processed_events_df = (
    poly_event_df
    .groupby(["state_count", "curr_node"])
    .apply(reduce_grp)
    .sort_values("t_start")
    .reset_index()
)
# A segment ends where the next one starts; the last segment gets NaN.
poly_processed_events_df.insert(
    loc=3,
    column="t_end",
    value=poly_processed_events_df["t_start"].shift(-1),
)
poly_processed_events_df


In [None]:
# Keep only segments that map to a named event, and fold the poly-specific
# details (pause flag, node) into each row's metadata under the "poly" key.
final_poly_events = poly_processed_events_df.copy().dropna(subset="event_name")
final_poly_events["metadata"] = final_poly_events.apply(
    lambda row: dict_merge(
        row["metadata"],
        {"poly": {"pause": row["has_pause"], "curr_node": row["curr_node"]}},
    ),
    axis=1,
)
final_poly_events = final_poly_events.drop(
    columns=["state_count", "curr_node", "t_end", "has_pause"]
)
# Sequential identifier in the current row order.
final_poly_events["event_id"] = np.arange(len(final_poly_events))
print(final_poly_events.to_string())

# Handling of Fiber Events

In [None]:
# Load the fiber event log ordered by timestamp, and add "t" (TimeStamp / 1000)
# as the first column.
raw_fiber_events = pd.read_csv(fiber_events_path).sort_values(by="TimeStamp")
raw_fiber_events.insert(
    loc=0,
    column="t",
    value=raw_fiber_events["TimeStamp"].div(1000),
)
print(raw_fiber_events.to_string())

In [None]:
# Number the occurrences of each (Name, State) pair so the k-th State-0 row and
# the k-th State-1 row of a given Name end up on the same output row.
raw_fiber_events["ev_num"] = raw_fiber_events.groupby(["Name", "State"]).cumcount()
fiber_events = (
    raw_fiber_events
    .set_index(["ev_num", "Name", "State"])["t"]
    .unstack("State")
    .reset_index()
    # State 0 supplies t_start and State 1 supplies t_end.
    .rename(columns={0: "t_start", 1: "t_end"})
    .drop(columns="ev_num")
)
fiber_events.columns.name = None
print(fiber_events.to_string())

In [None]:
def process_fiber_event(name, start, end, rules=None):
    """Classify one fiber event by matching it against detection rules.

    Parameters
    ----------
    name
        Input name of the event; compared with a rule's ``detection["name"]``.
    start, end
        Start and end times; their difference is the event duration.
    rules : iterable of dict, optional
        Rules to test, each with a ``"detection"`` mapping (optional keys
        ``"name"``, ``"duration_min"``, ``"duration_max"``) and a
        ``"description"`` mapping containing at least ``"event"``.
        Defaults to ``event_metadata["fiber"]``.

    Returns
    -------
    tuple
        ``(event, metadata)`` taken from the single matching rule's
        description, or ``(None, None)`` when no rule matches.

    Raises
    ------
    Exception
        If more than one rule matches the event.
    """
    if rules is None:
        rules = event_metadata["fiber"]

    duration = end - start
    event = None
    metadata = None
    for rule in rules:
        detect = rule["detection"]
        desc = rule["description"]

        # A criterion absent from the rule never rejects the event. The bounds
        # are written as "not (limit > duration)" so that a NaN duration, which
        # compares false against anything, is not rejected by them.
        name_ok = "name" not in detect or name == detect["name"]
        long_enough = "duration_min" not in detect or not (detect["duration_min"] > duration)
        short_enough = "duration_max" not in detect or not (detect["duration_max"] < duration)
        if not (name_ok and long_enough and short_enough):
            continue

        if event is not None or metadata is not None:
            raise Exception(f"conflict {event} {metadata}, {detect}, {duration}")
        event = desc["event"]
        metadata = desc
    return (event, metadata)
# Classify every start/end pair; the returned (event, metadata) tuple is
# expanded into the two new columns.
processed_fiber_events = fiber_events.copy()
processed_fiber_events[["event_name", "metadata"]] = processed_fiber_events.apply(
    lambda row: process_fiber_event(row["Name"], row["t_start"], row["t_end"]),
    axis=1,
    result_type="expand",
)
processed_fiber_events

In [None]:
# Keep only classified events in chronological order, and record the input
# name in each row's metadata under the "fiber" key.
final_fiber_events = (
    processed_fiber_events
    .copy()
    .dropna(subset="event_name")
    .sort_values("t_start")
)
final_fiber_events["metadata"] = final_fiber_events.apply(
    lambda row: dict_merge(
        row["metadata"],
        {"fiber": {"FiberInputNum": row["Name"]}},
    ),
    axis=1,
)
final_fiber_events = final_fiber_events.drop(columns=["t_end", "Name"])
# Sequential identifier in chronological order.
final_fiber_events["event_id"] = np.arange(len(final_fiber_events))
final_fiber_events

# Synchronizing and merging

In [None]:
# Estimate, per source and per event type, the time offset relative to the
# first source, and build time-shifted copies of each event table.
# When True, one median offset per source is applied to every event type;
# when False, each event type uses its own offset where one is available.
single_shift=False
# One row per source; the "df" column holds that source's whole event table.
dfs = pd.Series({"fiber": final_fiber_events, "poly": final_poly_events}, name="df").to_frame()
dfs.index.name="source"
merge_data = dfs.to_xarray()
# Set of event names present in each source.
merge_data["event_list"] = xr.apply_ufunc(lambda df: set(df["event_name"].dropna().to_list()), merge_data["df"], vectorize=True)
# The union of those names becomes the "event" dimension.
merge_data["event"] = xr.DataArray(list(set.union(*list(merge_data["event_list"].to_numpy()))), dims="event")
# Per (source, event): the rows of that source carrying that event name.
merge_data["event_df"] = xr.apply_ufunc(lambda df, ev: df.loc[df["event_name"] == ev], merge_data["df"], merge_data["event"], vectorize=True)
merge_data["num_values"] = xr.apply_ufunc(lambda df: len(df.index), merge_data["event_df"], vectorize=True)
# An event type is usable for a source when it has as many occurrences there
# as in the first source, or none at all.
merge_ok = ((merge_data["num_values"] == merge_data["num_values"].isel(source=0)) | (merge_data["num_values"]==0))
if not merge_ok.all():
    print(f'Possible problem for merge, problem matching number of events...')
    display(merge_ok)
# Mean and variance of the start times per (source, event).
merge_data["avg_value"] = xr.apply_ufunc(lambda df: df["t_start"].mean(), merge_data["event_df"], vectorize=True)
merge_data["var_value"] = xr.apply_ufunc(lambda df: df["t_start"].var(), merge_data["event_df"], vectorize=True)
# Offset = difference of mean start times versus the first source, blanked
# wherever the occurrence counts did not agree.
merge_data["shift_value"] = (merge_data["avg_value"] - merge_data["avg_value"].isel(source=0)).where(merge_ok)
if single_shift:
    merge_data["applied_shift_value"] = merge_data["shift_value"].median("event")
else:
    # Use the per-event offset when defined, otherwise the median over events.
    merge_data["applied_shift_value"] = xr.where(merge_data["shift_value"].notnull(), merge_data["shift_value"], merge_data["shift_value"].median("event"))
# Subtract the offset from t_start into "t_shifted" and sort by it.
merge_data["shifted_event_df"] = xr.apply_ufunc(lambda df, shift: df.copy().assign(t_shifted=df["t_start"]-shift).sort_values("t_shifted"), merge_data["event_df"], merge_data["applied_shift_value"], vectorize=True)
merge_data


In [None]:

def match_event_df(df: pd.DataFrame, target_df: pd.DataFrame):
    """Pair the events of ``df`` with events of ``target_df``.

    Both frames must provide ``event_id`` and ``t_shifted``. A copy of ``df``
    is returned with an extra ``match`` column holding the ``event_id`` of the
    paired target event:

    * empty target: every ``match`` is ``pd.NA``;
    * same number of rows: rows are paired by position;
    * more target rows: each row of ``df`` takes the target with the nearest
      ``t_shifted``;
    * fewer target rows: each target picks the row of ``df`` with the nearest
      ``t_shifted``, and the pairs are left-merged onto ``df`` by ``event_id``
      (rows never picked get a missing ``match``).
    """
    matched = df.copy()
    n_target = len(target_df.index)
    n_own = len(df.index)

    if n_target == 0:
        matched["match"] = pd.NA
        return matched

    if n_target == n_own:
        matched["match"] = target_df["event_id"].to_numpy()
        return matched

    if n_target > n_own:
        # Target ids indexed by time, looked up at each of our own times.
        target_ids_by_time = xr.DataArray(
            target_df["event_id"].to_numpy(),
            dims="t",
            coords=dict(t=target_df["t_shifted"].to_numpy()),
        )
        own_times = matched["t_shifted"].to_numpy()
        matched["match"] = target_ids_by_time.sel(t=own_times, method="nearest").to_numpy()
        return matched

    # Fewer targets than own events: look up our own ids at each target time.
    own_ids_by_time = xr.DataArray(
        matched["event_id"].to_numpy(),
        dims="t",
        coords=dict(t=matched["t_shifted"].to_numpy()),
    )
    target_times = xr.DataArray(
        target_df["t_shifted"].to_numpy(),
        dims="event_id",
        coords=dict(event_id=target_df["event_id"].to_numpy()),
    )
    nearest_own = own_ids_by_time.sel(t=target_times, method="nearest")
    # The index (target ids) becomes "match"; the values are our own event ids.
    pairs = (
        nearest_own
        .to_dataframe(name="event_id")
        .reset_index(names="match")
        .drop(columns="t")
    )
    return pd.merge(matched, pairs, how="left", on="event_id")

# Pair each (source, event) table with the matching table of the first source.
merge_data["match_df"] = xr.apply_ufunc(match_event_df, merge_data["shifted_event_df"], merge_data["shifted_event_df"].isel(source=0), vectorize=True)
# Stitch the per-event tables of each source back into one table ordered by
# shifted time.
merge_data["reconstructed_df"] = xr.apply_ufunc(lambda dfs: pd.concat(dfs).sort_values("t_shifted"), merge_data["match_df"], input_core_dims=[["event"]], vectorize=True)
# Flag sources whose event_id sequence differs after the shift. np.array_equal
# also handles tables whose row counts differ (the left merge in
# match_event_df can duplicate rows), where an element-wise `!=` would fail.
merge_data["order_changed"] = xr.apply_ufunc(
    lambda init, rec: not np.array_equal(init["event_id"].to_numpy(), rec["event_id"].to_numpy()),
    merge_data["df"],
    merge_data["reconstructed_df"],
    vectorize=True,
)
if merge_data["order_changed"].any():
    print(f'Possible problem for merge... order is not preserved')
    display(merge_data["order_changed"])
merge_data


In [None]:
# Stack the reconstructed tables of all sources, tagged with their source name,
# for visual inspection.
debug_df = pd.concat(
    [
        merge_data["reconstructed_df"].isel(source=i).item().assign(
            source=merge_data["source"].isel(source=i).item()
        )
        for i in range(merge_data.sizes["source"])
    ]
).sort_values("t_shifted")
# Blank out event_id for every source other than the first one.
debug_df["event_id"] = debug_df["event_id"].mask(
    debug_df["source"] != merge_data["source"].isel(source=0).item()
)
print(debug_df.to_string())

In [None]:
# Start from the first source's events, then fold every other source into it:
# matched events enrich the existing row, unmatched events are appended.
all_events = merge_data["reconstructed_df"].isel(source=0).item().assign(source=merge_data["source"].isel(source=0).item()).drop(columns=["match", "t_start"])
for i in range(1, merge_data.sizes["source"]):
    d = merge_data["reconstructed_df"].isel(source=i).item()
    s = merge_data["source"].isel(source=i).item()
    d_matched = d.loc[~pd.isna(d["match"])]
    # Bring the other source's metadata and time alongside the row it matched.
    all_events = pd.merge(
        all_events,
        d_matched[["match", "metadata", "t_shifted"]]
        .rename(columns=dict(metadata="additional_metadata", match="event_id", t_shifted="other_t"))
        .assign(other_src=s),
        how="left",
        on="event_id",
    )
    try:
        all_events["metadata"] = all_events.apply(
            lambda row: dict_merge(
                row["metadata"],
                row["additional_metadata"] if not pd.isna(row["additional_metadata"]) else {},
                {s: dict(t=row["other_t"])},
            ),
            axis=1,
        )
    except Exception:
        # Dump the table to help diagnose the failure, then re-raise.
        # `Exception` (not a bare except) leaves KeyboardInterrupt alone.
        print(all_events.to_string())
        raise
    # Rows seen in both sources are labelled "<source>+<other source>".
    all_events["source"] = np.where(~pd.isna(all_events["other_src"]), all_events["source"] + "+" + all_events["other_src"], all_events["source"])
    all_events = all_events.drop(columns=["additional_metadata", "other_t", "other_src"])
    # Events of this source without a counterpart are appended as-is.
    d_unmatched = d.loc[pd.isna(d["match"])].copy().assign(source=s).drop(columns=["event_id", "match", "t_start"])
    all_events = pd.concat([all_events, d_unmatched])
all_events = all_events.sort_values("t_shifted").drop(columns=["event_id"]).reset_index(drop=True).rename(columns=dict(t_shifted="t"))
# Move the metadata column to the end.
all_events["metadata"] = all_events.pop("metadata")
print(all_events.to_string())



In [None]:
# For each additional source, plot how closely its event times agree with the
# merged timeline, per event type.
compare_to = merge_data["source"].isel(source=0).item()
for i in range(1, merge_data.sizes["source"]):
    s = merge_data["source"].isel(source=i).item()
    data = all_events.copy()
    # Time recorded by source `s` for this event, when its metadata holds one.
    data["other_t"] = data["metadata"].apply(
        lambda d: d[s]["t"] if s in d and "t" in d[s] else np.nan
    )
    # NOTE(review): a zero time difference gives an infinite value here —
    # confirm the histogram copes with that.
    data["diff"] = np.log10(1 / (data["other_t"] - data["t"]).abs())
    data = data.loc[data["other_t"].notna()]
    sns.displot(
        data,
        x="diff",
        hue="event_name",
        bins=5,
        multiple="dodge",
        stat="probability",
        common_norm=False,
        shrink=.8,
    )
    plt.suptitle(f'log10(1/abs(t_{s} - t_{compare_to})) after alignment distribution for each event type')
    plt.show()
    

In [None]:
%matplotlib qt
# Raster of the aligned events: one horizontal band per source, one colour per
# event type.
nsources = len(merge_data["source"])
for i, s in enumerate(merge_data["source"]):
    for j, ev in enumerate(merge_data["event"]):
        df = merge_data["shifted_event_df"].sel(source=s, event=ev).item()
        # Label only the first band so each event type appears once in the legend.
        if i == 0:
            other_drawargs = dict(label=ev.item())
        else:
            other_drawargs = {}
        plt.vlines(
            df["t_shifted"],
            [i] * len(df.index),
            [i + 1 - 0.02] * len(df.index),
            color=f"C{j}",
            **other_drawargs,
        )
# Centre each tick on its band and name it after the source.
plt.yticks(
    [0.5 + band for band in range(merge_data.sizes["source"])],
    merge_data["source"].to_numpy(),
)
plt.legend()
plt.suptitle("Events after alignment")
plt.show()

# Reducing to trial

In [None]:
# Every event name seen in the merged table.
all_evs = set(all_events["event_name"].dropna().to_list())
# Every key used under "trial_type" in any event's metadata. Starting from an
# empty set keeps this valid when no event has "trial_type" at all; the unbound
# form set.union(*[]) raises TypeError in that case.
trial_type_cols = set().union(
    *[set(d["trial_type"].keys()) for d in all_events["metadata"].values if "trial_type" in d]
)
def reduce_trial(grp: pd.DataFrame): 
    """Summarise the events of one trial group as a single-row frame.

    ``grp`` holds the rows of ``all_events`` sharing one group key
    (``grp.name``). Only rows from the last ``"trial_start"`` onwards are used
    for the per-event times and the metadata; when the group has no
    ``"trial_start"``, the whole group is used.

    Returns a one-row DataFrame with two column groups, ``"events"`` (one time
    per name in ``all_evs``, plus ``first_start`` and ``prev_trial_end``) and
    ``"trial_type"`` (one value per key in ``trial_type_cols``), plus a
    ``"trial"`` column holding the group key. Missing entries are ``pd.NA``.

    Raises ``Exception`` if an event name occurs more than once in the rows
    that are kept.
    """
    start_times = grp.loc[grp["event_name"] == "trial_start", "t"]
    last_start = start_times.max()
    first_start = start_times.min()
    # Earliest time in the group, stored as "prev_trial_end".
    prev_trial_end = grp["t"].min()

    if pd.isna(last_start):
        real_grp = grp
    else:
        real_grp = grp.loc[grp["t"] >= last_start]

    named_rows = real_grp[["event_name", "t"]].dropna(subset="event_name")
    events = named_rows.set_index("event_name")["t"]
    if events.index.duplicated().any():
        raise Exception("Duplicated events")
    event_dict = {name: (events[name] if name in events else pd.NA) for name in all_evs}

    metadata = dict_merge(*real_grp["metadata"].to_list(), incompatible="remove")
    declared = metadata["trial_type"] if "trial_type" in metadata else {}
    trial_type = {col: (declared[col] if col in declared else pd.NA) for col in trial_type_cols}

    event_row = {**event_dict, "first_start": first_start, "prev_trial_end": prev_trial_end}
    r = pd.concat(
        {
            "events": pd.DataFrame([event_row]),
            "trial_type": pd.DataFrame([trial_type]),
        },
        axis=1,
    )
    r["trial"] = grp.name
    return r
# The group key is the running count of "trial_end" events minus one, so each
# "trial_end" row opens a new group.
trial_event_df = (
    all_events
    .groupby((all_events["event_name"] == "trial_end").cumsum() - 1)
    .apply(reduce_trial)
    .reset_index(drop=True)
)
# A trial's end is the "prev_trial_end" recorded by the following group.
trial_event_df[("events", "trial_end")] = (
    trial_event_df.pop(("events", "prev_trial_end")).shift(-1)
)
trial_event_df = trial_event_df.set_index("trial")

print(trial_event_df.to_string())

In [None]:
# Combine the per-trial event times (stacked into "event_t" over trial and
# event_name) with the trial-type columns into one dataset.
trial_dataset = xr.merge(
    [
        trial_event_df["events"]
        .rename_axis('event_name', axis=1)
        .stack()
        .rename("event_t")
        .to_xarray(),
        trial_event_df["trial_type"].to_xarray(),
    ]
)
trial_dataset