In [None]:
import numpy as np, pandas as pd, xarray as xr
from pathlib import Path
import datetime, networkx as nx, yaml, warnings, sys, logging
from disjoint_set import DisjointSet
from helper import singleglob, nxrender, Step, dict_merge, PdfWriter, DictWriter, TableWriter, TextWriter
import matplotlib.pyplot as plt
import matplotlib as mpl, seaborn as sns
# Default matplotlib figure size used by the plots in this notebook.
mpl.rcParams['figure.figsize'] = [15, 6]

In [None]:
# File name of this notebook: its stem names the result folder, and it is the input of the final HTML export.
notebook_name = Path("events.ipynb")

In [None]:
# Resolve the working directory once; input files are searched from its parent.
notebook_dir = Path(".").resolve()
base_folder = notebook_dir.parent

# Everything this notebook produces goes into a folder named after the notebook.
result_folder = Path(".")/notebook_name.stem
result_folder.mkdir(parents=True, exist_ok=True)

# Step wrapping the task_graph notebook; it is executed later (if necessary) to obtain node_metadata.yaml.
node_metadata_step = Step(notebook_dir / "task_graph.ipynb")

# Input files, located with glob patterns.
metadata_path = singleglob(base_folder, "metadata.yaml", "metadata --*.yaml", "metadata--*.yaml")
poly_dat_path = singleglob(base_folder, "poly/events*.dat")
fiber_events_path = singleglob(base_folder, "**/Events.csv", "**/Events --*.csv", "**/Events--*.csv")

# Output file for the per-trial dataset.
trial_dataset_path = result_folder / "trial_events.nc"

base_folder, node_metadata_step, metadata_path, poly_dat_path, fiber_events_path, trial_dataset_path

In [None]:
# Destination of the HTML export made at the end of the notebook.
notebook_save_path = result_folder / "notebook.html"

# Writers collecting the notebook's outputs inside the result folder.
tables = TableWriter(result_folder / "tables.xlsx")
figures = PdfWriter(result_folder / "figures.pdf")
dicts = DictWriter(result_folder / "dicts.yaml")
warn = TextWriter(result_folder / "warnings.txt")


In [None]:
# Load the session metadata; the context manager closes the file handle once parsing is done.
with metadata_path.open("r") as metadata_file:
    all_metadata = yaml.safe_load(metadata_file)

# Event definitions of the task; the "fiber" entry is used further down to classify fiber events.
event_metadata = all_metadata["task"]["events"]
dicts.write(event_metadata = event_metadata)


# Handling Poly events

In [None]:
# Run the task_graph step if needed and read the node metadata it produces.
# The context manager closes the file handle once parsing is done.
node_metadata_path = node_metadata_step.exec_if_necessary()/"node_metadata.yaml"
with node_metadata_path.open("r") as node_metadata_file:
    node_metadata = yaml.safe_load(node_metadata_file)
dicts.write(node_metadata = node_metadata)

In [None]:
# Column names for the tab-separated poly event file (its first 13 lines are skipped).
poly_columns = ['time (ms)', 'family', 'nbre', '_P', '_V', '_L', '_R', '_T', '_W', '_X', '_Y', '_Z']
poly_event_df = pd.read_csv(poly_dat_path, sep="\t", names=poly_columns, skiprows=13)
# Event time in seconds, stored as the first column.
poly_event_df.insert(0, "t", poly_event_df["time (ms)"]/1000)

if "valid_recording_intervals" in all_metadata["recordings"]["poly"]:
    poly_intervals = all_metadata["recordings"]["poly"]["valid_recording_intervals"]
    if len(poly_intervals) > 1:
        raise Exception("For now, only one valid recording interval allowed")
    # A missing (falsy) bound leaves that side of the interval open.
    start = poly_intervals[0]["start"] or -np.inf
    end = poly_intervals[0]["end"] or np.inf
    in_valid_interval = (poly_event_df["t"] >= start)&(poly_event_df["t"] <= end)
    poly_event_df = poly_event_df.loc[in_valid_interval]

poly_event_df = poly_event_df.sort_values("t")
# Rows with family == 10 increment the state counter and provide, through "_T",
# the current node that is carried forward to the following rows.
is_state_row = poly_event_df["family"]==10
poly_event_df["state_count"] = is_state_row.cumsum()
poly_event_df["curr_node"] = poly_event_df["_T"].where(is_state_row).ffill()
print(tables.write(poly_event_df = poly_event_df))

In [None]:
def reduce_grp(grp: pd.DataFrame):
    """Summarise one (state_count, curr_node) group of poly rows as a single record.

    Returns a Series with the earliest time of the group (``t_start``), the
    event name declared for the node in ``node_metadata`` (``None`` when the
    node has no entry or no "event" key), the node's metadata dict, and
    whether the group contains at least one row with family == 11
    (``has_pause``).
    """
    _, node = grp.name
    if node in node_metadata:
        metadata = node_metadata[node]
    else:
        metadata = {}
    if "event" in metadata:
        event = metadata["event"]
    else:
        event = None
    return pd.Series({
        "t_start": grp["t"].min(),
        "event_name": event,
        "metadata": metadata,
        "has_pause": bool((grp["family"]==11).any()),
    })

# One record per (state_count, curr_node) group, ordered by start time.
poly_state_groups = poly_event_df.groupby(["state_count", "curr_node"])
poly_processed_events_df = (
    poly_state_groups.apply(reduce_grp, include_groups=False)
    .sort_values("t_start")
    .reset_index()
)
# A record ends where the next one starts (the last record gets NaN).
next_start = poly_processed_events_df["t_start"].shift(-1)
poly_processed_events_df.insert(3, "t_end", next_start)
print(tables.write(poly_processed_events_df = poly_processed_events_df))


In [None]:
def _add_poly_details(row):
    """Merge the row's node metadata with its poly-specific fields (pause flag and node)."""
    poly_details = dict(poly=dict(pause=row["has_pause"], curr_node=row["curr_node"]))
    return dict_merge(row["metadata"], poly_details)

# Keep only the records that correspond to a named event.
final_poly_events = poly_processed_events_df.copy().dropna(subset="event_name")
final_poly_events["metadata"] = final_poly_events.apply(_add_poly_details, axis=1)
final_poly_events = final_poly_events.drop(columns=["state_count", "curr_node", "t_end", "has_pause"])
final_poly_events["event_id"] = np.arange(len(final_poly_events.index))
print(tables.write(final_poly_events = final_poly_events))

# Handling of Fiber Events

In [None]:
raw_fiber_events = pd.read_csv(fiber_events_path).sort_values("TimeStamp")
# Event time in seconds (TimeStamp divided by 1000), stored as the first column.
raw_fiber_events.insert(0, "t", raw_fiber_events["TimeStamp"]/1000)

if "valid_recording_intervals" in all_metadata["recordings"]["fiber"]:
    fiber_intervals = all_metadata["recordings"]["fiber"]["valid_recording_intervals"]
    if len(fiber_intervals) > 1:
        raise Exception("For now, only one valid recording interval allowed")
    # A missing (falsy) bound leaves that side of the interval open.
    start = fiber_intervals[0]["start"] or -np.inf
    end = fiber_intervals[0]["end"] or np.inf
    in_valid_interval = (raw_fiber_events["t"] >= start)&(raw_fiber_events["t"] <= end)
    raw_fiber_events = raw_fiber_events.loc[in_valid_interval]
tables.write(raw_fiber_events=raw_fiber_events)

In [None]:
# Number the occurrences of each (Name, State) pair so that the n-th State 0 row
# of an input can be paired with its n-th State 1 row.
# `assign` builds a new frame instead of writing a column into what may be a
# filtered slice of the raw table (avoids chained-assignment warnings) and keeps
# the cell idempotent on re-run.
raw_fiber_events = raw_fiber_events.assign(
    ev_num=raw_fiber_events.groupby(["Name", "State"]).cumcount()
)
# One row per (occurrence, Name): State 0 time becomes t_start, State 1 time becomes t_end.
fiber_events = (
    raw_fiber_events.set_index(["ev_num", "Name", "State"])["t"]
    .unstack("State")
    .reset_index()
    .rename(columns={0: "t_start", 1: "t_end"})
    .drop(columns="ev_num")
)
fiber_events.columns.name = None
tables.write(fiber_events=fiber_events.sort_values("t_start"))

In [None]:
def process_fiber_event(name, start, end):
    """Classify one fiber event using the detection rules in ``event_metadata["fiber"]``.

    A rule matches when each criterion present in its "detection" entry holds:
    "name" equals ``name``, and the duration (``end - start``) is neither below
    "duration_min" nor above "duration_max".

    Returns ``(event, metadata)`` where ``metadata`` is the matching rule's
    "description" and ``event`` its "event" entry, or ``(None, None)`` when no
    rule matches.

    Raises:
        Exception: when a second rule matches the same event.
    """
    duration = end-start
    event = None
    metadata = None
    for rule in event_metadata["fiber"]:
        detect = rule["detection"]
        desc = rule["description"]
        # All three criteria are evaluated; a criterion missing from the rule never rejects.
        wrong_name = "name" in detect and name != detect["name"]
        too_short = "duration_min" in detect and detect["duration_min"] > duration
        too_long = "duration_max" in detect and detect["duration_max"] < duration
        if wrong_name or too_short or too_long:
            continue
        if event is not None or metadata is not None:
            raise Exception(f"conflict {event} {metadata}, {detect}, {duration}")
        event = desc["event"]
        metadata = desc
    return (event, metadata)
processed_fiber_events = fiber_events.copy()
# Classify every fiber event; the (event, metadata) result is expanded into two columns.
fiber_classification = processed_fiber_events.apply(
    lambda row: process_fiber_event(row["Name"], row["t_start"], row["t_end"]),
    axis=1,
    result_type="expand",
)
processed_fiber_events[["event_name", "metadata"]] = fiber_classification
tables.write(processed_fiber_events=processed_fiber_events.sort_values("t_start"))

In [None]:
def _add_fiber_details(row):
    """Merge the row's rule metadata with the name of the fiber input that produced it."""
    fiber_details = dict(fiber=dict(FiberInputNum=row["Name"]))
    return dict_merge(row["metadata"], fiber_details)

# Keep only the fiber events that matched a rule, in chronological order.
final_fiber_events = processed_fiber_events.copy().dropna(subset="event_name").sort_values("t_start")
final_fiber_events["metadata"] = final_fiber_events.apply(_add_fiber_details, axis=1)
final_fiber_events = final_fiber_events.drop(columns=["t_end", "Name"])
final_fiber_events["event_id"] = np.arange(len(final_fiber_events.index))
tables.write(final_fiber_events=final_fiber_events)

# Synchronizing and merging

In [None]:
# Event tables per source; event ids are renumbered so they are unique across sources.
df_dict = {"fiber": final_fiber_events, "poly": final_poly_events}
n_evs = 0
for k, v in df_dict.items():
    df_dict[k] = v.assign(event_id=np.arange(n_evs, n_evs+len(v.index)), source=k)
    n_evs+=len(v.index)

# Dataset with one DataFrame per source along the "source" dimension.
dfs = pd.Series(df_dict, name="df").to_frame()
dfs.index.name="source"
merge_data = dfs.to_xarray()

# Every event name seen in any source becomes the "event" coordinate.
# The names are sorted so that the coordinate order is the same on every run
# (iteration order of a set of strings changes between kernel sessions).
event_name_list = xr.apply_ufunc(lambda df: set(df["event_name"].dropna().to_list()), merge_data["df"], vectorize=True)
all_event_names = set.union(*list(event_name_list.to_numpy()))
merge_data["event"] = xr.DataArray(sorted(all_event_names, key=str), dims="event")
merge_data

In [None]:
# For every (source, event) pair, keep the rows of that source's table whose event_name equals the event.
# Each selection is returned wrapped in a {"df": ...} dict.
# NOTE(review): presumably the dict wrapper keeps the vectorized call from treating a DataFrame as an array — confirm.
merge_data["event_df"] = xr.apply_ufunc(lambda df, event: {"df": df.loc[df["event_name"] == event]}, merge_data["df"], merge_data["event"], vectorize=True)
# Number of rows in each (source, event) selection.
merge_data["num_events"] = xr.apply_ufunc(lambda df: len(df["df"].index), merge_data["event_df"], vectorize=True)
merge_data

## Computing t_shifted

In [None]:
# Source whose clock is taken as the common time, and the event used to align the other sources on it.
base_time_source = "fiber"
shift_using = "cue"

# Pairing the alignment events by rank requires every source to have as many of them as the reference.
n_sync_per_source = merge_data["num_events"].sel(event=shift_using)
n_sync_reference = merge_data["num_events"].sel(source=base_time_source, event=shift_using)
if not (n_sync_per_source == n_sync_reference).all():
    raise Exception(f"Problem not same number of {shift_using} event")

def get_shift(df, target_df):
    """Build the table of clock offsets between a source and the reference source.

    Both arguments are ``{"df": DataFrame}`` wrappers whose frames hold the
    alignment events (same number of rows, paired by position). The result is a
    ``{"df": DataFrame}`` wrapper with columns ``t`` (start time in the source's
    clock) and ``shift`` (source start minus reference start). A first row with
    ``t = -inf`` repeats the first offset so that times before the first
    alignment event also get a shift.
    """
    source_t = df["df"]["t_start"].to_numpy()
    reference_t = target_df["df"]["t_start"].to_numpy()
    offsets = pd.DataFrame({"t": source_t, "shift": source_t - reference_t})
    sentinel = pd.DataFrame([{"t": -np.inf, "shift": offsets["shift"].iat[0]}])
    return {"df": pd.concat([sentinel, offsets])}

# Alignment events of the reference source, and of every source.
reference_sync_events = merge_data["event_df"].sel(source=base_time_source, event=shift_using, drop=True)
per_source_sync_events = merge_data["event_df"].sel(event=shift_using, drop=True)
merge_data["shift_data"] = xr.apply_ufunc(get_shift, per_source_sync_events, reference_sync_events, vectorize=True)

# Export one shift table per source.
shift_data_dfs = {}
for source_idx in range(merge_data.sizes["source"]):
    source_name = merge_data["source"].isel(source=source_idx).item()
    shift_data_dfs[source_name] = merge_data["shift_data"].isel(source=source_idx).item()["df"]
tables.write(**{f"shift_data_{k}":v for k,v in shift_data_dfs.items()})

def shift_data(df, shift_data):
    """Add a ``t_shifted`` column holding each event's start time minus its clock offset.

    ``df`` is an event table with a ``t_start`` column; ``shift_data`` is a
    ``{"df": DataFrame}`` wrapper with columns ``t`` and ``shift``. Each event
    uses the offset of the last table row whose ``t`` is at or before the
    event's ``t_start``. The input table is not modified; the result is
    returned as ``{"df": DataFrame}`` with ``t_shifted`` as second column.
    """
    offsets = shift_data["df"]
    # Position of the last offset row at or before each start time.
    row_pos = np.searchsorted(offsets["t"], df["t_start"], side="right") - 1
    applied = offsets["shift"].iloc[row_pos].to_numpy()
    res = df.copy()
    res.insert(1, "t_shifted", res["t_start"].to_numpy() - applied)
    return dict(df=res)

# Apply each source's shift table to that source's events.
merge_data["shifted_df"] = xr.apply_ufunc(shift_data, merge_data["df"], merge_data["shift_data"], vectorize=True)
display(merge_data)

# Stack the shifted events of all sources, ordered by common time.
shifted_frames = [cell["df"] for cell in merge_data["shifted_df"].to_numpy()]
all_events_df = pd.concat(shifted_frames).sort_values("t_shifted")
# Pop and re-insert so that the metadata column ends up last.
all_events_df["metadata"] = all_events_df.pop("metadata")
tables.write(all_events_df=all_events_df)

## Computing matching

In [None]:
# Group events that describe the same occurrence seen by different sources.
# Every event starts in its own set; sets are joined whenever two events of the
# same name from two sources are matched.
ev_sets = DisjointSet.from_iterable(all_events_df["event_id"].to_numpy())
for ev, g in all_events_df.groupby("event_name"):
    sources = g.groupby("source")
    for s1, g1 in sources:
        for s2, g2 in sources:
            # Matching a source with itself would only join each event with itself.
            if s1 == s2 or len(g1.index) == 0 or len(g2.index) == 0:
                continue
            if len(g1.index) == len(g2.index):
                # Same count in both sources: pair the events by rank.
                for id1, id2 in zip(g1["event_id"], g2["event_id"]):
                    ev_sets.union(id1, id2)
            else:
                # Different counts: attach each event of the smaller group to the
                # event of the larger group that is nearest in common time.
                # Local names are used so that the loop variables g1/g2 are never
                # rebound (the previous in-place swap assigned g2 = g1 after
                # g1 = g2, leaving both on the same group and altering g1 for the
                # remaining inner iterations).
                if len(g1.index) < len(g2.index):
                    smaller, larger = g1, g2
                else:
                    smaller, larger = g2, g1
                larger_ids = xr.DataArray(larger["event_id"].to_numpy(), dims="t", coords=dict(t=larger["t_shifted"].to_numpy()))
                match = larger_ids.sel(t=smaller["t_shifted"].to_numpy(), method="nearest").to_numpy()
                for id1, id2 in zip(match, smaller["event_id"]):
                    ev_sets.union(id1, id2)
ev_sets = list(ev_sets.itersets())
print(" ; ".join([("\n" if i % 12==0 and i> 0 else "") + str(s).replace(" ", "") for i, s in enumerate(ev_sets)]))

In [None]:
# When several sources report the same merged event, its time is taken from the
# first source of this list that has it.
t_priority = ["fiber", "poly"]
source_rank = {src: rank for rank, src in enumerate(t_priority)}

merged_ev_dict = []
dfs = []
for merged_id, id_set in enumerate(ev_sets):
    members = all_events_df.loc[all_events_df["event_id"].isin(id_set)]
    metadata = dict_merge(*members["metadata"].to_list())
    ev_name = members["event_name"].iat[0]
    by_priority = members.sort_values("source", key=lambda src: src.map(source_rank))
    merged_ev_dict.append(dict(
        merged_id=merged_id,
        event_ids=id_set,
        event_name=ev_name,
        t=by_priority["t_shifted"].iat[0],
        sources="+".join(members["source"].to_list()),
        metadata=metadata,
    ))
    dfs.append(members.assign(merged_id=merged_id, event_name=ev_name))
merge_events_df = pd.DataFrame(merged_ev_dict)

# Per-source view of the merged events, indexed by (source, merged_id).
all_merged_dataset = pd.concat(dfs).set_index(["source", "merged_id"]).to_xarray()
# The event name depends on merged_id only. It is taken from the merged table
# rather than from the first source, whose entry is missing (NaN) for merged
# events that this source did not record.
all_merged_dataset["event_name"] = merge_events_df.set_index("merged_id")["event_name"].to_xarray()
all_merged_dataset = all_merged_dataset.set_coords("event_name")
display(all_merged_dataset)
tables.write(merge_events_df=merge_events_df)


## Checks

In [None]:
# Check, source by source, that shifting the times did not change the order of the events.
alterated = []
for source_idx in range(all_merged_dataset.sizes["source"]):
    per_source = all_merged_dataset.isel(source=source_idx)
    # Keep only the merged events that this source actually has.
    per_source = per_source.where(per_source["event_id"].notnull(), drop=True)
    order_after = per_source.sortby("t_shifted")["event_id"].to_numpy()
    order_before = per_source.sortby("t_start")["event_id"].to_numpy()
    if (order_after != order_before).any():
        alterated.append(all_merged_dataset["source"].isel(source=source_idx).item())

if alterated:
    warnings.warn(warn.write(f'Merge warning: Shifting alters order for sources {alterated}'))
else:
    print("All ok")

In [None]:
# For every pair of sources, look at the difference in common time between the
# two sources' versions of each merged event: warn when it leaves +/- 20 ms and
# plot its distribution per event name.
for i in range(all_merged_dataset.sizes["source"]):
    for j in range(i+1, all_merged_dataset.sizes["source"]):
        s1 = all_merged_dataset["source"].isel(source=i).item()
        s2 = all_merged_dataset["source"].isel(source=j).item()
        d = (all_merged_dataset["t_shifted"].isel(source=i) - all_merged_dataset["t_shifted"].isel(source=j)).rename("delta_t").to_dataframe()
        # Named delta_min/delta_max so that the builtins min()/max() are not
        # overwritten for the rest of the kernel session.
        delta_min = d["delta_t"].dropna().min()
        delta_max = d["delta_t"].dropna().max()
        if delta_min < -0.020 or delta_max > 0.020:
            warnings.warn(warn.write(f"Merge interval between source {s1} and {s2} is [{delta_min}, {delta_max}]"))
        # Fixed bin edges, restricted to the observed range and closed by its two extremes.
        bins = [-0.020, -0.010, -0.005, -0.002, 0.002, 0.005, 0.010, 0.020]
        bins = [b for b in bins if b>delta_min and b<delta_max]
        bins = [delta_min] + bins + [delta_max]
        sns.displot(d, x="delta_t", hue="event_name", bins=bins, multiple="dodge", stat="probability", common_norm=False, shrink=.8)

        plt.xticks(bins, labels = [f'{b:.2g}' for b in bins], rotation=25)
        plt.suptitle(f't_shifted variation between {s1} and {s2}')
        figures.write()


In [None]:
# Plot, for every source, the shift applied as a function of time in that source's clock.
# The running extremes are named shift_min/shift_max so that the builtins
# min()/max() are not overwritten for the rest of the kernel session.
shift_min = np.inf
shift_max = -np.inf
for i in range(merge_data.sizes["source"]):
    s = merge_data["source"].isel(source=i).item()
    shift_df = merge_data["shift_data"].isel(source=i).item()["df"]
    c_min = shift_df["shift"].min()
    c_max = shift_df["shift"].max()
    if c_min < shift_min:
        shift_min = c_min
    if c_max > shift_max:
        shift_max = c_max
    # The table's times are the step edges; one value fewer than edges is required.
    plt.stairs(shift_df["shift"].to_numpy()[:-1],shift_df["t"].to_numpy(), color=f"C{i}", baseline=None, label = s)
plt.suptitle(f'Shift amount over time\nAlignment is done on {shift_using} events. Common time is that of {base_time_source}')
plt.xlabel("t\nin time of source")
plt.ylabel("shift amount\nt_common = t_source - shift")
plt.legend(title="source")
# Leave a 5% margin above and below the observed shifts.
margin = (shift_max - shift_min)*0.05
plt.ylim([shift_min - margin, shift_max + margin])
figures.write()



# Reducing to trial

In [None]:
# Names of all merged events: one column per name in the per-trial table.
all_evs = {name for name in merge_events_df["event_name"].dropna().to_list()}
# Every key found in a "trial_type" metadata entry of any merged event.
trial_type_cols = set.union(*(set(meta["trial_type"]) for meta in merge_events_df["metadata"].values if "trial_type" in meta))
def reduce_trial(grp: pd.DataFrame):
    """Reduce the merged events of one trial group to a single-row table.

    Only the events from the last "trial_start" of the group onwards are used
    (all of them when the group has no "trial_start"). The result has two
    column groups: "events" with the time of each event name in ``all_evs``
    (``pd.NA`` when absent) plus ``first_start`` (earliest "trial_start") and
    ``prev_trial_end`` (earliest time of the group), and "trial_type" with one
    column per key of ``trial_type_cols``. A ``trial`` column holds the group key.

    Raises:
        Exception: when an event name occurs more than once in the retained events.
    """
    trial = grp.name
    prev_trial_end = grp["t"].min()
    start_times = grp.loc[grp["event_name"]=="trial_start", "t"]
    real_start = start_times.max()
    first_start = start_times.min()
    real_grp = grp if pd.isna(real_start) else grp.loc[grp["t"] >=real_start]

    events = real_grp[["event_name", "t"]].dropna(subset="event_name").set_index("event_name")["t"]
    if events.index.duplicated().any():
        raise Exception("Duplicated events")
    event_dict = {k: (events[k] if k in events else pd.NA) for k in all_evs}

    # Metadata entries on which the retained events disagree are removed by the merge.
    metadata = dict_merge(*real_grp["metadata"].to_list(), incompatible="remove")
    declared_types = metadata["trial_type"] if "trial_type" in metadata else {}
    trial_type = {k: (declared_types[k] if k in declared_types else pd.NA) for k in trial_type_cols}

    events_row = pd.DataFrame([event_dict | dict(first_start=first_start, prev_trial_end = prev_trial_end)])
    trial_type_row = pd.DataFrame([trial_type])
    r = pd.concat({"events": events_row, "trial_type": trial_type_row}, axis=1)
    r["trial"] = trial
    return r

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning) # Seems to be a pandas bug, see https://github.com/pandas-dev/pandas/issues/55928
    # A new group starts at every "trial_end" event; events before the first one get key -1.
    trial_key = (merge_events_df["event_name"]=="trial_end").cumsum()-1
    trial_event_df = merge_events_df.groupby(trial_key).apply(reduce_trial).reset_index(drop=True)

# The earliest time of the next group is used as the end of the current trial.
next_group_start = trial_event_df.pop(("events", "prev_trial_end")).shift(-1)
trial_event_df[("events", "trial_end")] = next_group_start
trial_event_df = trial_event_df.set_index("trial")

tables.write(trial_event_df=trial_event_df)
# test

In [None]:
# Event times as a (trial, event_name) array, merged with the per-trial trial-type variables.
event_times = (
    trial_event_df["events"]
    .rename_axis('event_name', axis=1)
    .stack()
    .rename("event_t")
    .to_xarray()
)
trial_types = trial_event_df["trial_type"].to_xarray()
trial_dataset = xr.merge([event_times, trial_types])
trial_dataset

# Exporting

In [None]:
# Replace missing values with NaN and write the trial dataset to the NetCDF output file.
# NOTE(review): presumably needed because pd.NA placeholders cannot be serialized — confirm.
trial_dataset.fillna(np.nan).to_netcdf(trial_dataset_path)

In [None]:
# Read the exported file back and display it, as a check that the export can be loaded.
xr.load_dataset(trial_dataset_path)

# Finishing

In [None]:
%%javascript
// Ask the notebook frontend to save this notebook (presumably so that the HTML export in the next cell sees up-to-date content).
// NOTE(review): depends on a global `IPython.notebook` object being provided by the frontend — confirm it exists in the environment used.
IPython.notebook.save_notebook()

In [None]:
import os
import subprocess

# Export this notebook to HTML inside the result folder.
# The command is given as an argument list and run without a shell, so paths
# containing spaces or shell metacharacters are passed through unchanged.
# A failing export does not raise (check=False); its exit code is displayed instead.
nbconvert_result = subprocess.run(
    [
        "jupyter", "nbconvert",
        "--to", "html", str(notebook_name),
        "--output", str(notebook_save_path),
        "--no-prompt",
        "--Application.log_level=40",
    ],
    check=False,
)
nbconvert_result.returncode


In [None]:
# Drop the notebook's references to the output writers, in creation order.
# NOTE(review): presumably the writers finalize their files when released — confirm in helper.
del tables, dicts, figures, warn