In [None]:
import copy
import json
import re
import subprocess
from pathlib import Path

import numpy as np
import pandas as pd
import plotly
import plotly.graph_objects as go
import xarray as xr
import yaml

import events_methods
import helper
from helper import RenderJSON
# Load plotly.js into the notebook (offline mode: embedded, not fetched from a CDN)
# so the figures below render.
plotly.offline.init_notebook_mode(connected=False)

In [None]:
import itables

# Render every DataFrame displayed below as an interactive itables widget.
itables.init_notebook_mode(all_interactive=True)

# Table look & feel: page-size selector and search builder on top, export
# buttons, the page-size choices offered, and the size (maxBytes) above which
# itables downsamples a table before embedding it.
itables.options.layout = {"topEnd": "pageLength", "top1": "searchBuilder"}
itables.options.buttons = ["copyHtml5", "csvHtml5", "excelHtml5"]
itables.options.lengthMenu = [25, 10, 50, 100, 200]
itables.options.maxBytes = "1MB"

In [None]:
# Run parameters: path of the Spike2 recording, of the processing config and of
# the output events table. The file is opened in a `with` block so the handle
# is closed right after parsing (the previous bare `.open()` leaked it).
with Path("params.yaml").open("r") as params_file:
    params = yaml.safe_load(params_file)
spike2_path = Path(params["smrx_path"])
info_path = Path(params["config_path"])
res_events_path = Path(params["dest_path"])
RenderJSON(params)

In [None]:
# Convert the .smrx recording with the external `smrx2python` tool unless a
# previous run already produced the "<stem>_data" directory next to it.
# The cells below expect "<stem>.tsv" (channel table) and one
# "<stem>_data/<channel>.npy" file per channel.
# The converter location can be overridden with the optional
# "smrx2python_bin" key of params.yaml; the default keeps the previous path.
SMRX2PYTHON = params.get("smrx2python_bin", "/home/julienb/miniconda3/envs/spike2/bin/smrx2python")
spike2_data_dir = spike2_path.with_suffix("").with_stem(spike2_path.stem + "_data")
if not spike2_data_dir.exists():
    # check=True: a failed conversion raises CalledProcessError here instead of
    # surfacing later as a confusing missing .tsv/.npy file.
    subprocess.run([str(SMRX2PYTHON), "-i", str(spike2_path)], check=True)
channels = pd.read_csv(spike2_path.with_suffix(".tsv"), sep="\t")
channels

In [None]:
# Processing configuration; later cells read its "discretize", "processing" and
# optional "display" sections. Opened in a `with` block so the handle is closed.
with info_path.open("r") as info_file:
    info = yaml.safe_load(info_file)
info

In [None]:
# Resolve the "discretize" config section against the channel table and check
# that every source channel (the spec's "channel" key) exists and is sampled
# regularly ("RegularSampling"), since the next cell loads it as a continuous signal.
channel_names = channels["name"].to_list()
discretize = events_methods.EventProcessing.process_info(channel_names, info["discretize"], dest_name="dest_channel")
known_chans = set(channel_names)
regular_chans = set(channels.loc[channels["data_kind"] == "RegularSampling", "name"])
source_chans = {spec["channel"] for spec in discretize.values()}

missing_chans = source_chans - known_chans
if missing_chans:
    raise ValueError(f"Some source channels were not found {sorted(missing_chans, key=str)}")
irregular_chans = source_chans - regular_chans
if irregular_chans:
    raise ValueError(f"Some source channels are not regular (continuous) {sorted(irregular_chans, key=str)}")
pd.DataFrame(list(discretize.values()))

In [None]:
# Discretize each configured continuous source channel into events (rows with a
# time "t" and a "State") and plot the signal with those events overlaid as
# vertical lines: green where State is truthy, red otherwise.

# Signals longer than MAX_PLOT_SAMPLES samples are interpolated onto
# PLOT_SAMPLES evenly spaced times for plotting only; event detection always
# uses the full-resolution signal.
MAX_PLOT_SAMPLES = 10**7
PLOT_SAMPLES = 10**6
spike2_data_dir = spike2_path.with_suffix("").with_stem(spike2_path.stem + "_data")


def _channel_meta(name):
    """Return the unique row of `channels` describing channel `name`.

    Raises:
        ValueError: if `name` matches zero or several rows of the channel table.
    """
    rows = channels.loc[channels["name"] == name]
    if len(rows.index) != 1:
        raise ValueError(f"Expected exactly one channel named {name!r} in the channel table, found {len(rows.index)}")
    return rows.iloc[0]


def _display_signal(data, name):
    """Return (signal, trace name) to plot, interpolating overly long signals."""
    if data.size <= MAX_PLOT_SAMPLES:
        return data, name
    t_grid = np.linspace(data["t"].min().item(), data["t"].max().item(), PLOT_SAMPLES)
    return data.interp(t=t_grid), name + "_interp"


def _event_segments(times, y_lo, y_hi):
    """Return x/y lists drawing one vertical segment per time, separated by None gaps."""
    xs = [v for t in times for v in (t, t, None)]
    ys = [v for _ in times for v in (y_lo, y_hi, None)]
    return xs, ys


def _plot_discretization(data, events, name, title):
    """Build a figure of `data` with each event drawn as a vertical line spanning the signal range."""
    shown, shown_name = _display_signal(data, name)
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=shown["t"].to_numpy(), y=shown.to_numpy(), name=shown_name))
    if len(events.index) > 0:
        y_lo, y_hi = data.min().item(), data.max().item()
        is_on = events["State"].astype(bool).to_numpy()
        event_times = events["t"].to_numpy()
        # One trace per colour instead of one per event: thousands of traces
        # make plotly figures very slow to build and render.
        for mask, color in ((is_on, "green"), (~is_on, "red")):
            if mask.any():
                xs, ys = _event_segments(event_times[mask], y_lo, y_hi)
                fig.add_trace(go.Scatter(x=xs, y=ys, mode="lines", line_color=color, showlegend=False))
    fig.update_layout(title=title, xaxis_title="t")
    return fig


discretized_parts = []
chans = []  # every channel name available to the processing step
for dest_chan, item in discretize.items():
    src_chan = item["channel"]
    meta = _channel_meta(src_chan)
    samples = np.load(spike2_data_dir / (src_chan + ".npy"))
    # NOTE(review): meta["fs"] is presumably the sampling rate in Hz, making "t" seconds — confirm.
    data = xr.DataArray(samples, dims=["t"], coords={"t": np.arange(samples.size) / meta["fs"]})
    events = events_methods.Discretize.call(item["method"], data, item)
    _plot_discretization(data, events, src_chan, title=f"{dest_chan} (source: {src_chan})").show()
    discretized_parts.append(events)
    chans.append(dest_chan)

if discretized_parts:
    discretized = pd.concat(discretized_parts).sort_values("t").reset_index(drop=True)
else:
    # No "discretize" entries: start from an empty table with the columns used
    # by the native event channels in the next cell (pd.concat([]) would raise).
    discretized = pd.DataFrame({
        "t": pd.Series(dtype=float),
        "channel_name": pd.Series(dtype=object),
        "State": pd.Series(dtype=int),
    })
discretized

In [None]:
# Append the native event channels (data_kind == "Event") to the discretized
# events, producing one table sorted by time, and register their names in
# `chans` for the processing step.
# NOTE(review): EventBoth .npy files are read as 2-column arrays; column 0 is
# tagged State=1 and column 1 State=0 (presumably rising/falling edge times —
# confirm against the smrx2python output format).
spike2_data_dir = spike2_path.with_suffix("").with_stem(spike2_path.stem + "_data")
event_parts = [discretized]
for _, row in channels.loc[channels["data_kind"] == "Event"].iterrows():
    chan_name = row["name"]
    if row["smrx_type"] != "DataType.EventBoth":
        raise NotImplementedError(f"unhandled event type {row['smrx_type']!r} for channel {chan_name!r}")
    # Guarded append keeps this cell idempotent: re-running it must not
    # register the same channel twice.
    if chan_name not in chans:
        chans.append(chan_name)
    edges = np.load(spike2_data_dir / (chan_name + ".npy"))
    falling = pd.DataFrame({"t": edges[:, 1], "channel_name": chan_name, "State": 0})
    rising = pd.DataFrame({"t": edges[:, 0], "channel_name": chan_name, "State": 1})
    event_parts.append(pd.concat([falling, rising]))
event_df = pd.concat(event_parts).sort_values("t").reset_index(drop=True)
event_df

        

In [None]:
# Resolve the "processing" config section against every available channel name
# (discretized destination channels + native event channels), then show the
# resolved specs as a table.
event_spec = events_methods.EventProcessing.process_info(
    chans,
    info["processing"],
)
pd.DataFrame([*event_spec.values()])

In [None]:
# Run every processing spec on the combined event table and gather the
# resulting events into one table sorted by time.
# NOTE(review): the result is bound to `all`, which shadows the Python builtin
# `all()`; the name is kept only because the export cell reads it — TODO rename
# across the notebook.
processed_parts = []
for item in event_spec.values():
    ev_dataframe = events_methods.FiberEventProcessing.compute_evdataframe(event_df, item)
    if len(ev_dataframe.index) == 0:
        continue
    events = events_methods.FiberEventProcessing.call(item["method"], ev_dataframe, item)
    if len(events.index) != 0:
        processed_parts.append(events)
if not processed_parts:
    # pd.concat([]) would raise an opaque "No objects to concatenate".
    raise ValueError("No processing spec produced any event; check the 'processing' section of the config")
all = pd.concat(processed_parts).sort_values("t").reset_index(drop=True)
all

In [None]:
# Apply optional display renames, write the events as a TSV (nested columns
# stored as JSON strings under "<col>_json"), then reload the file to check the
# round trip. Works on a copy so `all` is left as the processing cell built it
# and this cell can be re-run safely.
json_cols = ["metadata", "waveform_changes", "waveform_values"]
rename_map = (info.get("display") or {}).get("rename") or {}


def _json_default(obj):
    """json.dumps fallback for numpy scalars/arrays (anything exposing .tolist())."""
    if hasattr(obj, "tolist"):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _json_or_none(text):
    """Parse one reloaded JSON cell; missing cells (pandas reads JSON 'null' as NaN) become None."""
    return None if pd.isna(text) else json.loads(text)


exported = all.copy()
if rename_map:
    exported["event_name"] = exported["event_name"].map(lambda e: rename_map.get(e, e))
for col in json_cols:
    exported[f"{col}_json"] = exported[col].apply(lambda d: json.dumps(d, default=_json_default))
res_events_path.parent.mkdir(parents=True, exist_ok=True)
exported.drop(columns=json_cols).to_csv(res_events_path, sep="\t", index=False)

# Read JSON columns as plain strings: otherwise pandas infers numbers/booleans
# for cells such as "3" or "true", which json.loads rejects.
reloaded = pd.read_csv(
    res_events_path, sep="\t", index_col=False, dtype={f"{col}_json": str for col in json_cols}
)
for col in list(reloaded.columns):
    if col.endswith("_json"):
        reloaded[col[: -len("_json")]] = reloaded.pop(col).apply(_json_or_none)
reloaded

In [None]:
# Summarize the reloaded (round-tripped) event table with the project's helper.
summary = events_methods.EventProcessing.summarize(
    reloaded,
)
summary
