In [None]:
import pandas as pd, numpy as np,xarray as xr
from pathlib import Path
import re, yaml, copy, json, datetime
import helper, events_methods
from sonpy import lib as sp
import subprocess
import plotly
from helper import RenderJSON
# Initialise plotly's offline notebook mode so that figures shown later in this
# notebook can be rendered inline.
plotly.offline.init_notebook_mode()

In [None]:
# Configure itables: with all_interactive=True every DataFrame shown as a cell
# result is rendered as an interactive table.
import itables
itables.init_notebook_mode(all_interactive=True )
# NOTE(review): presumably the size limit above which itables downsamples a table — confirm.
itables.options.maxBytes = "1MB"
# Choices offered in the "rows per page" menu.
itables.options.lengthMenu = [25, 10, 50, 100, 200]
# Export buttons (copy / CSV / Excel) attached to each table.
itables.options.buttons = ["copyHtml5", "csvHtml5", "excelHtml5"]
# Alternative options, currently disabled:
# itables.options.scrollY="200px"
# itables.options.scrollCollapse=True
# itables.options.paging=False
# itables.options.column_filters = "footer"
# Place the page-length menu at the top end and a search builder on its own top row.
itables.options.layout={"topEnd": "pageLength", "top1": "searchBuilder"}

In [None]:
# Load the run parameters. The file is opened in a `with` block so the handle is
# closed as soon as the YAML has been parsed.
with Path("params.yaml").open("r") as params_file:
    params = yaml.safe_load(params_file)

# Locations used by the rest of the notebook, all taken from params.yaml.
spike2_path = Path(params["smrx_path"])   # Spike2 recording opened with sonpy below
info_path = Path(params["config_path"])   # YAML file describing processing/display
res_path = Path(params["dest_path"])      # destination path
RenderJSON(params)

In [None]:

# Open the Spike2 recording and fail early, with the library's own error text,
# when it cannot be opened (otherwise later calls fail with unrelated errors).
# NOTE(review): the second constructor argument is presumably the read-only flag — confirm.
MyFile = sp.SonFile(str(spike2_path), True)
open_error = MyFile.GetOpenError()
if open_error != 0:
    raise IOError(f"Could not open {spike2_path}: {sp.GetErrorString(open_error)}")

time_base = MyFile.GetTimeBase()

# The last six entries of GetTimeDate() are used, from the end backwards, as
# year, month, day, hour, minute and second of the recording date.
date = MyFile.GetTimeDate()
date = datetime.datetime(
    year=date[-1],
    month=date[-2],
    day=date[-3],
    hour=date[-4],
    minute=date[-5],
    second=date[-6],
)
source_file = str(spike2_path)
# File-level comments, slots 0 to 7; empty strings are filtered out when exporting.
file_comments = [MyFile.GetFileComment(i) for i in range(8)]

# One record per channel whose type is not `Off`.
channel_records = []
for chan in range(MyFile.MaxChannels()):
    chan_kind = MyFile.ChannelType(chan)
    if chan_kind != sp.DataType.Off:
        channel_records.append(dict(
            name=MyFile.GetChannelTitle(chan),
            id=chan,
            # str(chan_kind) looks like "DataType.Adc"; keep only the part after the prefix
            chan_type=str(chan_kind)[len("DataType."):],
            unit=MyFile.GetChannelUnits(chan),
            size=MyFile.ChannelBytes(chan),
            item_size=MyFile.ItemSize(chan),
            # scale and offset are applied later as raw_int * scale + offset
            scale=MyFile.GetChannelScale(chan)/6553.6,
            offset=MyFile.GetChannelOffset(chan),
            # interval between two samples: ticks per sample times the file time base
            divide=MyFile.ChannelDivide(chan)*time_base,
            comment=MyFile.GetChannelComment(chan),
        ))
all_channels = pd.DataFrame(channel_records)
all_channels



In [None]:
# Load the processing/display description. The `with` block closes the file
# handle once the YAML has been parsed.
with info_path.open("r") as info_file:
    info = yaml.safe_load(info_file)
info

In [None]:
# Channel names grouped by channel type, keyed as "<chan_type>_channel".
channels_dict = {
    chan_type + "_channel": members["name"].to_list()
    for chan_type, members in all_channels.groupby("chan_type")
}

# Resolve the "processing" section of the info file against the available channels.
processing = events_methods.EventProcessing.process_info(
    channels_dict, info["processing"], dest_name="dest_channel"
)

# One row per processing step, annotated with the metadata of the channel it reads.
processing_df = (
    pd.DataFrame(list(processing.values()))
    .assign(name=lambda steps: steps["method_params"].map(lambda step_params: step_params["channel_name"]))
    .merge(all_channels, how="left", on="name")
)
processing_df

In [None]:
# Extract every step flagged "adc_extract" in `processing_df` from the Spike2 file
# and save it as a netCDF dataset. `datas` maps each destination channel name to
# the file written for it.
datas = {}
read_block_size = 10**6  # maximum number of samples requested per ReadInts call
for _, row in processing_df.iterrows():
    if row["method"] != "adc_extract":
        continue
    if row["chan_type"] != "Adc":
        raise ValueError(
            f'adc_extract needs an Adc channel, but channel {row["name"]!r} has type {row["chan_type"]!r}'
        )

    chan_id = int(row["id"])
    # ReadInts takes its start position in clock ticks, not in samples, so the
    # read position must advance by ChannelDivide ticks for every sample read.
    ticks_per_sample = MyFile.ChannelDivide(chan_id)
    chunks = []
    next_tick = 0
    # NOTE(review): assumes the first sample sits at tick 0 and the channel has no
    # gaps; a read shorter than the block size ends the loop — confirm for your files.
    while True:
        fetched = sp.SonFile.ReadInts(MyFile, chan_id, read_block_size, next_tick)
        if len(fetched) > 0:
            chunks.append(np.array(fetched))
        if len(fetched) < read_block_size:
            break
        next_tick += len(fetched) * ticks_per_sample

    # An empty channel gives an empty signal instead of a concatenate error.
    samples = np.concatenate(chunks) if chunks else np.array([], dtype=int)

    # Convert raw integers to physical units and attach a time axis in seconds.
    signal = xr.DataArray(samples*row["scale"]+row["offset"], dims="t")
    signal["t"] = np.arange(signal.size)*row["divide"]
    signal["t"].attrs = dict(fs=1/row["divide"])
    signal.attrs = dict(
        unit=row["unit"],
        spike2_chan_type=row["chan_type"],
        spike2_id=row["id"],
        spike2_name=row["name"],
        comment=row["comment"],
    )

    # File-level metadata: recording date, source file and the non-empty file comments.
    dataset = signal.to_dataset(name=row["dest_channel"])
    dataset.attrs = dict(date=date.isoformat(), source_file=source_file) | {
        f"comment_{i}": v for i, v in enumerate(file_comments) if v != ""
    }
    display(dataset)

    dest_file = row["method_params"]["dest_file"]
    dataset.to_netcdf(dest_file)
    datas[row["dest_channel"]] = dest_file

In [None]:
channels_dict = {k+"_channel": grp["name"].to_list() for k, grp in all_channels.groupby("chan_type")}
display = events_methods.EventProcessing.process_info(channels_dict, info["display"] , dest_name="dest_trace")
display = pd.DataFrame(list(display.values()))
display[["fig", "label"]] = display["dest_trace"].str.split(":", expand=True)
display

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
for title, grp in display.groupby("fig"):
    fig = make_subplots(specs=[[{"secondary_y": grp["method_params"].apply(lambda d: d["secondary_y"]).any()}]])
    for _, row in grp.iterrows():
        if row["method"] == 'continuous_xrarray':
            data = xr.load_dataset(row["method_params"]["file"])[row["method_params"]["variable"]]
            if data.size > row["method_params"]["max_numpoints"]:
                display_data = data.interp(t=np.linspace(data["t"].min().item(),data["t"].max().item(), row["method_params"]["max_numpoints"]))
                display_name = row["label"] + "_interp_" + f'{data["t"].attrs["fs"]}Hz -> {np.round(row["method_params"]["max_numpoints"] /(data["t"].max().item() - data["t"].min().item()))}Hz'
            else:
                display_data = data
                display_name = row["label"]
            fig.add_trace(go.Scatter(x=display_data["t"].to_numpy(), y=display_data.to_numpy(), name=display_name), secondary_y=row["method_params"]["secondary_y"])
                
    fig.show()
