In [None]:
import pandas as pd, numpy as np,xarray as xr
from pathlib import Path
import re, yaml, copy, json
from helper import singleglob, Waveform, json_merge

In [None]:
# Session folder to process. The commented line below is an alternative dataset.
base = Path("/home/julienb/Documents/database_scripts/database_scripts_test/poly_dat_files/Rats/Test_Julien_ForcedInput/")
# base = Path("/home/julienb/Documents/database_scripts/database_scripts_test/poly_dat_files/Rats/Luisa/Rat101_0729_opto_01")

# Locate the inputs with the project helper `singleglob`
# (presumably returns the single path matching the pattern -- confirm in helper.py).
dat_path = singleglob(base, "*.dat")
task_path = singleglob(base, "*.xls")
info_path = singleglob(
    base,
    "*.yaml",
    search_upward_limit=Path("/home/julienb/Documents/database_scripts/database_scripts_test/poly_dat_files"),
)
# Output file written by the export cell.
res_events_path = base / "events.tsv"

# Fail early, showing which input is missing.
exists = {
    label: candidate.exists()
    for label, candidate in (
        ("dat_path", dat_path),
        ("task_path", task_path),
        ("info_path", info_path),
    )
}
if not all(exists.values()):
    display(exists)
    raise Exception("Missing some input files...")
info_path

In [None]:
# Read the tab-separated .dat event log: the first 13 lines are skipped and
# every column is parsed as an integer.
event_df = pd.read_csv(
    dat_path,
    sep="\t",
    names=["time (ms)", "family", "nbre", "_P", "_V", "_L", "_R", "_T", "_W", "_X", "_Y", "_Z"],
    skiprows=13,
    dtype=int,
)
# Replace the millisecond timestamp by a leading column "t" in seconds.
event_df.insert(0, "t", event_df.pop("time (ms)") / 1000)
# Keep the original row number as "poly_evnum" so simultaneous events keep
# their file order once sorted by time.
event_df = (
    event_df
    .reset_index(names="poly_evnum")
    .sort_values(["t", "poly_evnum"])
    .reset_index(drop=True)
)
# Rows of family 10 carry a task node in "_T"; propagate it forward so every
# later event knows the most recent node (NaN before the first such row).
event_df["task_node"] = event_df["_T"].where(event_df["family"].eq(10)).ffill()
print(event_df.to_string())

In [None]:
# Read the task table (tab-separated despite the .xls extension); row 11 holds
# the column headers. The first column is renamed to "task_node".
task_df = (
    pd.read_csv(task_path, sep="\t", header=11)
    .pipe(lambda frame: frame.rename(columns={frame.columns[0]: "task_node"}))
)
display(task_df.columns)
task_df

In [None]:


# Task columns named like "<channel_name> (<family>, <nbre>)" describe a channel.
# Columns that do not match the pattern yield NaN groups and are dropped.
channels = (
    pd.Series(task_df.columns)
    .str.extract(r'\s*(?P<channel_name>\w+)\s*\((?P<family>\d+)\s*,\s*(?P<nbre>\d+)\)\s*')
    .assign(taskcol_name=task_df.columns)
    .dropna(how="any")
    # The regex captures strings; the event log stores family/nbre as integers.
    .astype({"family": int, "nbre": int})
)
channels

    

In [None]:
# Cells of the form "on(<int>[,<int>...])" carry numeric task parameters.
pattern = r'on\(\d+(,\d+)*\)'

# One row per (task_node, channel column): lower-cased, stripped cell text.
# Cells that are not strings become NaN under the .str accessor and are dropped.
stacked = (
    task_df.set_index("task_node")[channels["taskcol_name"].to_list()]
    .stack()
    .str.lower()
    .str.strip()
    .dropna()
)
stacked.index.names = ["task_node", "taskcol_name"]

# Parse the parameter list of the "on(...)" cells only. Working on a Series
# chain avoids adding columns to a filtered slice of a DataFrame, which raised
# SettingWithCopyWarning in the previous implementation.
is_on_command = stacked.str.fullmatch(pattern)
task_params = (
    stacked[is_on_command]
    .str.slice(3, -1)  # strip the leading "on(" and the trailing ")"
    .str.split(",")
    .apply(lambda parts: [float(part) for part in parts])
    .rename("task_params")
)

# Outer join so that every non-empty cell keeps its raw text in "task_data",
# with "task_params" set only where the cell matched the pattern.
task_info = (
    task_params.to_frame()
    .join(stacked.rename("task_data"), how="outer")
    .reset_index()
)
# Cast to float to match event_df["task_node"], which is float because of NaNs.
task_info["task_node"] = task_info["task_node"].astype(float)
task_info


In [None]:
# Annotate every logged event with its channel (when the (family, nbre) pair is
# a known channel) and with the task cell of that channel at the current node.
event_channels_df = (
    channels
    # right join: keep every event, even those without a matching channel
    .merge(event_df, on=["family", "nbre"], how="right")
    # left join: attach task_data / task_params where available
    .merge(task_info, on=["taskcol_name", "task_node"], how="left")
    .sort_values("t")
)
print(event_channels_df.to_string())

In [None]:
# Load the processing configuration. The context manager closes the file
# handle, which the previous `yaml.safe_load(info_path.open("r"))` left open.
with info_path.open("r") as info_file:
    info = yaml.safe_load(info_file)
info

In [None]:
# Project helpers, imported once instead of on every loop iteration.
from helper import generate_duplicate_table, replace_vals

# Optional mapping event_name -> display name from the "display.rename" section.
display_rename = info["display"]["rename"] if "display" in info and "rename" in info["display"] else {}

# Expand every "processing" entry into one concrete spec per row of its
# duplication table (generate_duplicate_table / replace_vals are project
# helpers; presumably they substitute the row values into the entry -- confirm).
event_spec = []
for item in info["processing"]:
    item.setdefault("duplicate_over", {})
    duplication = generate_duplicate_table(item["duplicate_over"], dict(channel_name=channels["channel_name"]))
    for _, row in duplication.iterrows():
        final_d = replace_vals(item, row.to_dict())
        ev_name = final_d["event_name"]
        final_d["display_name"] = display_rename[ev_name] if ev_name in display_rename else ev_name
        del final_d["duplicate_over"]
        event_spec.append(final_d)

# Event names must be unique because they become dictionary keys below.
# Column names are set explicitly so the check does not depend on the pandas
# version (value_counts().reset_index() only yields a "count" column from
# pandas 2.0 on), and an empty spec list no longer raises KeyError.
unique_df = (
    pd.Series([spec["event_name"] for spec in event_spec], dtype=object)
    .value_counts()
    .rename_axis("event_name")
    .rename("count")
    .reset_index()
)
if not (unique_df["count"] == 1).all():
    display(unique_df.loc[unique_df["count"] > 1])
    raise Exception("Event name duplication")


display(pd.DataFrame(event_spec))
event_spec = {v["event_name"]: v for v in event_spec}


In [None]:
from ipywidgets import widgets

# Deterministic ordering (time, then original file order) plus a "used" flag
# initialised to False for every event.
event_channels_df = event_channels_df.sort_values(["t", "poly_evnum"]).assign(used=False)

# One updatable display handle per event spec; compute_relevant() refreshes it.
outs = {name: display(display_id=f'{name}') for name in event_spec}

def myeval(df, expr):
    """Evaluate ``expr`` on ``df`` with support for ``task[i]`` references.

    Every ``task[<i>]`` token in ``expr`` is replaced by a temporary column
    ``__task_<i>`` holding element ``i`` of the row's ``task_params`` list, or
    NaN when the row has no list or the list is too short. The rewritten
    expression is then passed to ``DataFrame.eval``.

    NOTE(review): ``expr`` comes from the YAML configuration and is evaluated
    by pandas; only use configuration files you trust.
    """
    helper_cols = pd.DataFrame()

    def _substitute(token):
        position = int(token["col_num"])
        column = f'__task_{position}'

        def _pick(params):
            if isinstance(params, list) and len(params) > position:
                return params[position]
            return np.nan

        helper_cols[column] = df["task_params"].apply(_pick)
        return column

    rewritten = re.sub(r'task\[(?P<col_num>\d+)\]', _substitute, expr)
    augmented = pd.concat([df, helper_cols], axis=1)
    return augmented.eval(rewritten)

def compute_relevant(config):
    """Select and annotate the rows of ``event_channels_df`` relevant to one event spec.

    Reads the module-level ``event_channels_df`` and ``outs``.

    Parameters
    ----------
    config : dict
        One entry of ``event_spec``. Keys used: ``event_name``,
        ``method_params`` (keys ending in ``_expr`` are evaluated with
        ``myeval`` and stored as ``<name>_value``; other keys are stored as
        constant columns) and optionally ``metadata`` (same ``_expr``
        convention, suffix removed).

    Returns
    -------
    pandas.DataFrame
        Columns ``t``, ``task_data``, ``task_node``, one column per method
        parameter and a ``metadata`` column of dicts. Rows are restricted to
        those where ``filter_value`` is true when a ``filter_expr`` is
        configured. When a ``state_expr`` is configured, only rows whose state
        differs from the previous kept row remain, and ``duration`` holds the
        time until the next state change.
    """
    ev_name = config["event_name"]
    method_params = config["method_params"]

    df = pd.DataFrame()
    df["t"] = event_channels_df["t"]
    df["task_data"] = event_channels_df["task_data"]
    df["task_node"] = event_channels_df["task_node"]
    for param, expr in method_params.items():
        if param.endswith("_expr"):
            df[param.replace("_expr", "_value")] = myeval(event_channels_df, expr)
        else:
            df[param] = expr

    metadata_config = config.get("metadata")
    if metadata_config:
        # The frame must carry the event index from the start: assigning a
        # constant to an index-less frame creates a zero-row column, and the
        # constant would be lost (turned into NaN) by the next assignment.
        metadata = pd.DataFrame(index=event_channels_df.index)
        for k, v in metadata_config.items():
            if k.endswith("_expr"):
                metadata[k.replace("_expr", "")] = myeval(event_channels_df, v)
            else:
                metadata[k] = v
        # One dict per row, omitting missing values.
        df["metadata"] = metadata.apply(lambda row: {k: v for k, v in row.items() if not pd.isna(v)}, axis=1)
    else:
        # A fresh dict per row: `[{}] * n` would share a single dict object.
        df["metadata"] = [{} for _ in range(len(df.index))]

    filtered = df.loc[df["filter_value"]] if "filter_expr" in method_params else df
    if "state_expr" in method_params:
        # Keep only the rows where the state changes.
        relevant = filtered.loc[filtered["state_value"] != filtered["state_value"].shift(1)].copy()
        relevant["duration"] = relevant["t"].shift(-1) - relevant["t"]
    else:
        relevant = filtered.copy()
    # Refresh the display handle of this event with the selected rows.
    outs[config["event_name"]].update(relevant.rename_axis([ev_name]))
    return relevant



In [None]:

def compute_metadata(join_metadata, **warnings):
    """Merge metadata dicts and attach the warnings that actually fired.

    Parameters
    ----------
    join_metadata : iterable
        Metadata objects handed to ``json_merge`` (project helper) in order.
    **warnings
        Candidate warnings, by name. A dict value is kept when its non-null
        values are not all equal (for example ``read`` vs ``expected``). Any
        other value is treated as a ``(condition, message)`` pair and
        ``message`` is kept when ``condition`` is truthy.

    Returns
    -------
    The result of ``json_merge``; a ``{"warnings": ...}`` dict is appended to
    the merged inputs only when at least one warning was kept.
    """
    triggered = {}
    for name, payload in warnings.items():
        if not isinstance(payload, dict):
            if payload[0]:
                triggered[name] = payload[1]
            continue
        distinct = set()
        for candidate in payload.values():
            if not pd.isna(candidate):
                distinct.add(candidate)
        if len(distinct) > 1:
            triggered[name] = payload

    parts = list(join_metadata)
    if triggered:
        parts.append(dict(warnings=triggered))
    return json_merge(*parts)


def output_accumulator_binary_wave(relevant: pd.DataFrame, config):
    """Build one binary-wave event per run that starts with ``state_value == 1``.

    Rows are grouped so that each group begins at a row whose state equals 1.
    Within a group, every later row whose state is greater than the previous
    row's state counts as an additional rise.

    Parameters
    ----------
    relevant : pandas.DataFrame
        Output of ``compute_relevant``; needs ``t``, ``state_value``,
        ``duration_on_value``, ``expected_count_value``, ``task_data`` and
        ``metadata``.
    config : dict
        Event spec (unused here; kept for a uniform signature).

    Returns
    -------
    pandas.DataFrame
        Columns ``t``, ``metadata``, ``duration`` and ``waveform_info``; empty
        when no group starts with state 1.

    Raises
    ------
    Exception
        If a group contains more than one row with state 1.
    """
    res = []
    for _, grp in relevant.groupby((relevant["state_value"] == 1).cumsum()):
        starts = grp[grp["state_value"] == 1]
        if len(starts.index) == 0:
            # Rows before the first state-1 row do not belong to any event.
            continue
        if len(starts.index) > 1:
            display(grp)
            raise Exception(f"Problem {len(starts.index)}")

        start = starts["t"].iat[0]
        # Rise times relative to the start; the start itself is the rise at 0.
        rises = [0] + (grp["t"].loc[grp["state_value"] > grp["state_value"].shift(1)] - start).to_list()
        duration_on = grp["duration_on_value"].iat[0]
        expected_count = grp["expected_count_value"].iat[0]
        duration = rises[-1] + duration_on
        first_task_data = grp["task_data"].iat[0]
        # The former `metadata_join` list was removed: it was never passed to
        # compute_metadata and tested the opposite `startswith("on")` condition.
        metadata = compute_metadata(
            grp["metadata"],
            count_mismatch=dict(read=len(rises), expected=expected_count),
            free_event=(
                pd.isnull(first_task_data) or not first_task_data.startswith("on"),
                f'task_data_at_event_is {first_task_data}',
            ),
        )
        res.append(dict(t=start, metadata=metadata, duration=duration, waveform_info=
            dict(type="binary", rises=rises, durations=[duration_on]*len(rises))))
    return pd.DataFrame(res)

def input_binary_wave(relevant: pd.DataFrame, config):
    """Turn alternating state-change rows into one binary pulse per "on" period.

    A leading row with ``state_value == 0`` is discarded, so that the rows at
    even positions are the rises and the rows at odd positions the falls. The
    metadata of each rise is merged with that of the following row.

    Parameters
    ----------
    relevant : pandas.DataFrame
        Output of ``compute_relevant``; needs ``t``, ``state_value``,
        ``duration`` and ``metadata``. It is not modified.
    config : dict
        Event spec (unused here; kept for a uniform signature).

    Returns
    -------
    pandas.DataFrame
        Columns ``t``, ``duration``, ``metadata`` and ``waveform_info``; empty
        when no rise remains.
    """
    out_columns = ["t", "duration", "metadata", "waveform_info"]
    if relevant.empty:
        return pd.DataFrame(columns=out_columns)

    # Always work on a copy: the previous version only copied when the first
    # row was dropped and otherwise added a column to the caller's frame.
    if relevant["state_value"].iat[0] == 0:
        relevant = relevant.iloc[1:, :].copy()
    else:
        relevant = relevant.copy()
    if relevant.empty:
        # Only a single "off" row was present: there is no pulse to report.
        return pd.DataFrame(columns=out_columns)

    relevant["next_metadata"] = relevant["metadata"].shift(-1)
    rises = relevant.iloc[::2].copy()
    rises["metadata"] = rises.apply(
        lambda row: compute_metadata([row["metadata"], row["next_metadata"]]), axis=1
    )
    return rises[["t", "duration", "metadata"]].assign(
        waveform_info=rises["duration"].apply(lambda d: dict(type="binary", rises=[0], durations=[d]))
    )
    

def output_binary_wave(relevant: pd.DataFrame, config):
    """Turn alternating state-change rows into binary pulse-train events.

    A leading row with ``state_value == 0`` is discarded, so that rows at even
    positions are the rises. For every rise:

    * ``count_value`` is taken from the following row, with 0 replaced by 1;
    * a measured ``duration`` of at most 0.001 is replaced by
      ``duration_on_value`` when that value is available;
    * a missing ``duration_on_value`` falls back to the (corrected) duration;
    * pulse onsets are ``i * (duration_on_value + duration_off_value)`` for
      ``i`` in ``range(count_value)``.

    Parameters
    ----------
    relevant : pandas.DataFrame
        Output of ``compute_relevant``; needs ``t``, ``state_value``,
        ``duration``, ``metadata``, ``task_data``, ``count_value``,
        ``expected_count_value``, ``duration_on_value`` and
        ``duration_off_value``. It is not modified.
    config : dict
        Event spec (unused here; kept for a uniform signature).

    Returns
    -------
    pandas.DataFrame
        Columns ``t``, ``metadata``, ``duration`` and ``waveform_info``; empty
        when no rise remains.
    """
    out_columns = ["t", "metadata", "duration", "waveform_info"]
    if relevant.empty:
        return pd.DataFrame(columns=out_columns)

    if relevant["state_value"].iat[0] == 0:
        relevant = relevant.iloc[1:, :].copy()
    else:
        relevant = relevant.copy()
    if relevant.empty:
        # Only a single "off" row was present: there is no pulse to report.
        return pd.DataFrame(columns=out_columns)

    relevant["next_metadata"] = relevant["metadata"].shift(-1)
    relevant["count_value"] = relevant["count_value"].shift(-1).replace(0, 1)
    rises = relevant.iloc[::2].copy()

    rises["read_duration"] = rises["duration"]
    rises["duration"] = np.where(
        (rises["duration"] <= 0.001) & rises["duration_on_value"].notna(),
        rises["duration_on_value"], rises["duration"])
    rises["duration_on_value"] = rises["duration_on_value"].fillna(rises["duration"])
    rises["cycle_duration"] = rises["duration_on_value"] + rises["duration_off_value"]
    rises["rises"] = rises.apply(
        lambda row: [i * row["cycle_duration"] for i in range(int(row["count_value"]))]
        if pd.notna(row["count_value"]) else np.nan,
        axis=1)

    rises["metadata"] = rises.apply(lambda row: compute_metadata(
        [row["metadata"], row["next_metadata"]],
        count_mismatch=dict(read=row["count_value"], expected=row["expected_count_value"]),
        duration_correction=dict(read=row["read_duration"], corrected=row["duration"]),
        free_event=(pd.isnull(row["task_data"]) or not row["task_data"].startswith("on"), f'task_data_at_event_is {row["task_data"]}')
    ), axis=1)

    def _waveform(row):
        onsets = row["rises"]
        if not isinstance(onsets, list):
            # Pulse count unknown (for example the last rise has no following
            # row): keep the NaN marker instead of failing on len(NaN).
            return dict(type="binary", rises=onsets, durations=np.nan)
        return dict(type="binary", rises=onsets, durations=[row["duration_on_value"]] * len(onsets))

    rises["waveform_info"] = rises.apply(_waveform, axis=1)
    return rises[out_columns].copy()



In [None]:



# Build the events of every spec and gather them in one table.
# NOTE: `all` shadows the Python builtin; the name is kept because the
# following cells read it.
all = []
for ev_name, item in event_spec.items():
    relevant = compute_relevant(item)
    if len(relevant.index) == 0:
        continue
    method = item["method"]
    if method == "output_accumulator_binary_wave":
        events = output_accumulator_binary_wave(relevant, item)
    elif method == "input_binary_wave":
        events = input_binary_wave(relevant, item)
    elif method == "output_binary_wave":
        events = output_binary_wave(relevant, item)
    elif method == "step_wave":
        events = relevant.assign(curr_value=relevant["state_value"], prev_value=relevant["state_value"].shift(1), next_value=relevant["state_value"].shift(-1))
        events["waveform_info"] = events[["curr_value", "prev_value", "next_value"]].apply(lambda row: row.to_dict(), axis=1)
        events = events[["t", "metadata", "duration", "waveform_info"]]
    elif method == "event":
        events = relevant[["t", "metadata"]].assign(duration=np.nan, waveform_info=None)
    else:
        # Without this branch an unknown method silently reused the `events`
        # of the previous spec (or raised NameError on the first iteration).
        raise ValueError(f"Unknown method {method!r} for event {ev_name!r}")
    if len(events.index) != 0:
        # Drop the columns that are entirely null for this event type.
        all.append(events[[c for c in events.columns if events[c].count() > 0]].assign(event_name=ev_name))

if len(all) == 0:
    # pd.concat would raise a ValueError anyway; state the reason explicitly.
    raise ValueError("No objects to concatenate: no entry of event_spec produced any event")
all = pd.concat(all).sort_values("t")
display(all)
  

In [None]:
# Replace event names by their display names when the configuration provides a
# "display.rename" mapping; names without an entry are kept unchanged.
if "display" in info and "rename" in info["display"]:
    rename_map = info["display"]["rename"]
    all["event_name"] = all["event_name"].map(lambda name: rename_map.get(name, name))

In [None]:

# Serialise the dict-valued columns to JSON text so they survive the TSV export.
# Missing values become "{}". A column can be absent altogether: all-null
# columns are dropped per event type before concatenation, so "waveform_info"
# does not exist when only plain events were produced. That case used to raise
# KeyError here.
for _column in ("metadata", "waveform_info"):
    if _column in all.columns:
        _payload = all[_column]
    else:
        _payload = pd.Series(np.nan, index=all.index, dtype=object)
    all[f"{_column}_json"] = _payload.apply(lambda d: json.dumps(d) if not pd.isna(d) else "{}")

all.drop(columns=["waveform_info", "metadata"], errors="ignore").to_csv(res_events_path, sep="\t", index=False)

In [None]:
# Round-trip check: read the exported file back and decode the JSON columns.
reloaded = pd.read_csv(res_events_path, sep="\t", index_col=False)
reloaded = reloaded.assign(
    metadata=reloaded["metadata_json"].apply(json.loads),
    waveform_info=reloaded["waveform_info_json"].apply(json.loads),
)
reloaded.drop(columns=["waveform_info_json", "metadata_json"])