In [None]:
import networkx as nx, graphviz
import pandas as pd, numpy as np,xarray as xr, plotly
from pathlib import Path
import re, yaml, copy, json
import helper, config_adapter
from helper import RenderJSON
from copy import deepcopy
plotly.offline.init_notebook_mode()


In [None]:
import itables

# Turn on itables rendering for every table displayed below (all_interactive=True).
itables.init_notebook_mode(all_interactive=True)

# Notebook-wide itables defaults: size cap, page-length choices, export buttons,
# and layout (page-length selector in "topEnd", search builder in "top1").
opts = itables.options
opts.maxBytes = "1MB"
opts.lengthMenu = [25, 10, 50, 100, 200]
opts.buttons = ["copyHtml5", "csvHtml5", "excelHtml5"]
opts.layout = {"topEnd": "pageLength", "top1": "searchBuilder"}

In [None]:
# Run parameters. params.yaml is expected to provide:
#   config_path: path of the processing config loaded below (e.g. a mk_graph.yaml template)
#   variables:   optional entries for the evaluation context, e.g. task_file (a Poly task
#                .xls export) and dat_file (a Poly .dat recording) -- presumably the values
#                the graph builders read as params["task_file"] / params["dat_file"].
# Parse inside a context manager so params.yaml is closed right away
# (the previous bare .open() left the handle to the garbage collector).
with Path("params.yaml").open("r") as params_file:
    params = yaml.safe_load(params_file)

In [None]:
# Resolve the config location and the caller-supplied variables (empty list when
# params.yaml declares none), then show the raw params for reference.
config_path = Path(params["config_path"])
variables = (
    config_adapter.normalize_yaml_paramlist(params["variables"], format=config_adapter.variable_param_format)
    if "variables" in params
    else []
)
RenderJSON(params)

In [None]:
# Load the processing configuration named by params["config_path"]. Later cells read
# its optional "variables" entry and its "processing" -> "graphs" list.
config = config_adapter.load(config_path)
RenderJSON(config)

In [None]:
# Merge the params.yaml variables with the config's own "variables" entry and register
# them all in a fresh evaluation context.
# `variables` (built in an earlier cell) is deliberately NOT mutated: the previous
# `variables += ...` appended the config variables again on every re-run of this cell.
if "variables" in config:
    config_variables = config_adapter.normalize_yaml_paramlist(config["variables"], format=config_adapter.variable_param_format)
else:
    config_variables = []
all_variables = [*variables, *config_variables]
display(RenderJSON(all_variables))
ctx = config_adapter.Context()
for var in all_variables:
    config_adapter.add_variable_context(ctx, var)
RenderJSON(ctx.variables)

In [None]:
def _find_poly_header_row(task_path):
    """Return the 0-based index of the header line of a Poly task file.

    The header is the first line whose tab-separated cells include more than one
    cell containing "NEXT".

    Raises:
        ValueError: if no such line exists. (The previous readline() loop kept
            reading "" at end-of-file and never terminated.)
    """
    with Path(task_path).open("r") as task_file:
        for line_no, line in enumerate(task_file):
            if sum("NEXT" in cell for cell in line.split("\t")) > 1:
                return line_no
    raise ValueError(f"no header line with several NEXT columns found in {task_path}")


def _is_time_column(col):
    """True for Poly "T<digits>" columns (T1, T2, ...); matched as a prefix, as before."""
    return isinstance(col, str) and re.match(r"T\d+", col) is not None


def _resolve_poly_next(text, node_id, df, labels):
    """Split a NEXT cell of the form ``cond(target)`` into ``(cond, next_line)``.

    - No trailing ``(...)``: the whole cell is the condition and the target is ``node_id + 1``.
    - An integer target is used directly as the line number.
    - Any other target is a label, looked up in ``labels`` (T columns with leading "_"
      stripped). It must match exactly one row of ``df``; otherwise ValueError is raised.
    """
    text = str(text)
    # Anchored at the end, so at most one match exists (search == the old findall[0]).
    suffix = re.search(r"\(.+\)$", text)
    if suffix is None:
        return text, node_id + 1
    cond = text[:suffix.start()]
    target = suffix.group()[1:-1]
    # fullmatch: "12abc" is a label, not line 12 (re.match accepted it and int() then crashed).
    if re.fullmatch(r"\s*\d+\s*", target):
        return cond, int(target)
    matches = df.loc[labels.eq(target).any(axis=1), "task_node"]
    if len(matches) != 1:
        raise ValueError(f"problem {len(matches)} {target}")
    # Plain int so node ids match the int(target) branch (not a numpy scalar).
    return cond, int(matches.iat[0])


def from_poly_task(ctx, params):
    """Build a directed task-structure graph from a Poly task file (tab-separated text).

    Args:
        ctx: evaluation context; ``ctx.evaluate(params)`` is applied first.
        params: must contain "task_file", the path of the task file.

    Returns:
        nx.DiGraph whose nodes are task lines (the first column, renamed "task_node").
        Node attributes:
        - every column other than NEXT*/T* is stored as an attribute;
        - for a T<k> cell of the form "<time>_<name>", <name> is appended to "poly_names"
          and a non-empty <time> is stored under T<k>; other T<k> values are stored verbatim.
        Each NEXT* cell "cond(target)" adds an edge with attribute ``cond``. The target is
        a line number or a T-column label; when it is missing, the target is the next line.

    Raises:
        ValueError: if the header line cannot be found, or a label target does not
            match exactly one line.
    """
    params = ctx.evaluate(params)
    task_path = params["task_file"]
    header_row = _find_poly_header_row(task_path)
    task_df = pd.read_csv(task_path, sep="\t", skiprows=header_row)
    # The first (unnamed) column holds the task line number.
    task_df = task_df.rename(columns={task_df.columns[0]: "task_node"})
    df = task_df.loc[~pd.isna(task_df["task_node"])]
    df = df.dropna(subset=df.columns[1:], how="all")
    # assign() returns a new frame, avoiding chained assignment on a filtered slice.
    df = df.assign(task_node=df["task_node"].astype(int))

    # Label table for NEXT targets, built once: every T column as text with leading "_"
    # stripped. (Previously T1/T2/T3 were hard-coded and the table was rebuilt per cell.)
    time_cols = [col for col in df.columns if _is_time_column(col)]
    labels = pd.DataFrame({col: df[col].astype(str).str.lstrip("_") for col in time_cols}, index=df.index)

    graph = nx.DiGraph()
    for _, series in df.iterrows():
        row = series.dropna().to_dict()
        node_id = row["task_node"]
        graph.add_node(node_id)
        node = graph.nodes[node_id]
        names = []
        for col, value in row.items():
            if col.startswith("NEXT"):
                cond, next_line = _resolve_poly_next(value, node_id, df, labels)
                # NOTE(review): DiGraph keeps one edge per (src, dst). If two NEXT cells of
                # this row share a target, the later cond overwrites the earlier one.
                graph.add_edge(node_id, next_line, cond=cond)
            elif _is_time_column(col):
                m = re.match(r"(?P<time>\d*-?\d*)_(?P<name>\w+)$", str(value))
                if m is None:
                    node[col] = value
                else:
                    names.append(m["name"])
                    if m["time"]:
                        node[col] = m["time"]
            else:
                node[col] = value
        node["poly_names"] = names
    return graph



In [None]:
def _poly_dat_table(segments, keys, spec):
    """Summarise each event segment, then aggregate the summaries per key.

    For every group in ``segments``, a row is built holding the first value of each
    ``keys`` column plus ``group`` = ``segment.eval(spec["group_expr"])``. Those rows are
    then grouped by ``keys``, and ``spec["agg_expr"]`` is evaluated on each group.
    ``keys`` is a single column name or a list of names, and is passed to groupby as given.
    """
    key_list = [keys] if isinstance(keys, str) else list(keys)

    def summarise(segment):
        summary = {key: segment[key].iat[0] for key in key_list}
        summary["group"] = segment.eval(str(spec["group_expr"]))
        return pd.Series(summary)

    per_segment = segments.apply(summarise, include_groups=False).reset_index(drop=True)
    return per_segment.groupby(keys).apply(lambda rows: rows.eval(str(spec["agg_expr"])), include_groups=False)


def from_poly_dat(ctx, params):
    """Build a directed graph of visited task nodes and transitions from a Poly .dat log.

    Args:
        ctx: evaluation context; ``ctx.evaluate(params)`` is applied first.
        params: must contain "dat_file". Optional "node_info" / "edge_info" are lists of
            specs with keys "name", "group_expr" and "agg_expr" (DataFrame.eval expressions).

    The file is read as tab-separated integers after 13 skipped lines. Time is converted
    from ms to seconds ("t"), and "next_t" holds the following event's time. Rows with
    family == 10 start a new segment; their _T value is treated as the task node entered.
    "task_node" is forward-filled, and "next_node" is the next such _T value.

    Returns:
        nx.DiGraph. One node per task_node with one attribute per node_info spec, and
        one edge per (task_node, next_node) with one attribute per edge_info spec.
    """
    params = ctx.evaluate(params)
    dat_columns = ['time (ms)', 'family', 'nbre', '_P', '_V', '_L', '_R', '_T', '_W', '_X', '_Y', '_Z']
    events = pd.read_csv(params["dat_file"], sep="\t", names=dat_columns, skiprows=13, dtype=int)
    events.insert(0, "t", events.pop("time (ms)") / 1000)
    events.insert(1, "next_t", events["t"].shift(-1))

    entered = events["_T"].where(events["family"] == 10)
    events["task_node"] = entered.ffill()
    events["next_node"] = entered.shift(-1).bfill()
    events["node_change"] = (events["family"] == 10).cumsum()
    segments = events.groupby("node_change")

    graph = nx.DiGraph()
    if "node_info" in params:
        node_table = pd.DataFrame()
        for spec in params["node_info"]:
            node_table[spec["name"]] = _poly_dat_table(segments, "task_node", spec)
        for node_id, attrs in node_table.iterrows():
            graph.add_node(node_id, **attrs.to_dict())
    if "edge_info" in params:
        edge_table = pd.DataFrame()
        for spec in params["edge_info"]:
            edge_table[spec["name"]] = _poly_dat_table(segments, ["task_node", "next_node"], spec)
        for (src, dst), attrs in edge_table.iterrows():
            graph.add_edge(src, dst, **attrs.to_dict())
    return graph

In [None]:
def _json_default(obj):
    """json.dumps fallback: unwrap numpy scalars (e.g. ids/values coming from pandas) to Python values."""
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Expose the graph builders to the config's expressions only while building graphs.
ctx.methods["from_poly_task"] = from_poly_task
ctx.methods["from_poly_dat"] = from_poly_dat
combined_graph = nx.DiGraph()
try:
    for g in config["processing"]["graphs"]:
        graph = ctx.evaluate(g)
        node_infos_df = pd.DataFrame([dict(node=n) | v for n, v in graph.nodes(data=True)])
        edge_infos_df = pd.DataFrame([dict(src=n1, dest=n2) | v for n1, n2, v in graph.edges(data=True)])
        display(node_infos_df)
        display(edge_infos_df)
        combined_graph = nx.compose(combined_graph, graph)
finally:
    # Remove the temporary registrations even if building a graph raised.
    del ctx.methods["from_poly_task"]
    del ctx.methods["from_poly_dat"]

# Export the merged graph. Previously this dumped `graph`, i.e. only the last graph
# built in the loop (and raised NameError when the config listed no graphs).
json_graph = json.dumps(nx.cytoscape_data(combined_graph), indent=4, default=_json_default)
with Path("graph.json").open("w") as f:
    f.write(json_graph)
RenderJSON(json_graph)

In [None]:
def _attrs_to_label(attrs):
    """Render an attribute dict as one "key: value" line per attribute."""
    return "\n".join(f"{key}: {value}" for key, value in attrs.items())


# Copy of combined_graph for Graphviz rendering: each node's/edge's attributes are
# collapsed into a single "label" attribute (the copy is rewritten in place).
display_graph = deepcopy(combined_graph)
for _, node_attrs in display_graph.nodes(data=True):
    node_label = _attrs_to_label(node_attrs)
    node_attrs.clear()
    node_attrs["label"] = node_label

for _, _, edge_attrs in display_graph.edges(data=True):
    edge_label = _attrs_to_label(edge_attrs)
    edge_attrs.clear()
    edge_attrs["label"] = edge_label

In [None]:
from IPython.display import SVG

# Write the labelled graph as DOT (via pydot), render it to SVG with Graphviz's
# "dot" engine, and show the SVG inline.
dot_file, svg_file = "graph.dot", "graph.svg"
nx.nx_pydot.write_dot(display_graph, dot_file)
graphviz.render("dot", filepath=dot_file, outfile=svg_file, format="svg")
SVG(filename=svg_file)