In [None]:
import datetime
import re
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import xarray as xr
import yaml
from IPython.display import SVG

from helper import singleglob, nxrender

In [None]:
# Inputs (task table, metadata) are looked up with `singleglob` in the folder two levels above
# the current working directory; outputs are written into the working directory itself.
notebook_dir = Path(".").resolve()
base_folder = notebook_dir.parent.parent
task_path = singleglob(base_folder, "task.xls", "task --*.xls", "task--*.xls")
metadata_path = singleglob(base_folder, "metadata.yaml", "metadata --*.yaml", "metadata--*.yaml")
annotated_task_graph_path = notebook_dir / "annotated_task_graph.svg"
node_metadata_path = notebook_dir / "node_metadata.yaml"
base_folder, task_path, metadata_path, annotated_task_graph_path, node_metadata_path

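# Task

Parse the exported task table into a directed graph of line transitions (one edge per non-empty `NEXT` cell), then report and drop the lines that cannot be reached from line 1.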

In [None]:

# The task export is read as tab-separated text despite its .xls extension; pandas uses row 11
# (0-based) as the header. NOTE(review): assumes the export always has the same preamble length.
task_df = (
    pd.read_csv(task_path, sep="\t", header=11)
    # the first, header-less column holds the task line number
    .rename(columns={"Unnamed: 0": "line_num"})
    # keep only rows that carry a line number, then make it an integer node id
    .pipe(lambda frame: frame[frame["line_num"].notna()])
    .astype({"line_num": int})
    # drop rows whose every column other than the first (line_num) is empty
    .pipe(lambda frame: frame.dropna(subset=frame.columns[1:], how="all"))
)
task_df

In [None]:
# Build the task's control-flow graph: one node per task line (carrying that row's non-empty
# cells as attributes) and one edge per non-empty "NEXT*" cell.
# A NEXT cell reads "<condition>" or "<condition>(<target>)", where <target> is either a line
# number or a label from the T1/T2/T3 columns (leading underscores ignored). Without an explicit
# target the edge goes to the following line (line_num + 1).
NEXT_TARGET_RE = re.compile(r"\((.+)\)$")
next_cols = [col for col in task_df.columns if "NEXT" in col]


def find_line_by_label(label):
    """Return the line_num of the single row whose T1/T2/T3 cell equals `label` (leading '_' stripped).

    Raises Exception when zero or several rows carry that label.
    """
    stripped_labels = task_df[["T1", "T2", "T3"]].apply(lambda s: s.str.lstrip("_"))
    matches = task_df.loc[(stripped_labels == label).any(axis=1), "line_num"]
    if len(matches) != 1:
        raise Exception(f"problem {len(matches)} {label}")
    return matches.iat[0]


def parse_next_cell(line_num, text):
    """Split the NEXT cell `text` of line `line_num` into (condition, target line number)."""
    match = NEXT_TARGET_RE.search(text)
    if match is None:
        return text, line_num + 1
    target = match.group(1)
    # fullmatch, not match: a label such as "2nd_try" must go to label lookup instead of
    # crashing in int().
    next_line = int(target) if re.fullmatch(r"\d+", target) else find_line_by_label(target)
    # Slice at match.start() so the condition is exactly the text before "(target)", even when
    # `$` matched just before a trailing newline.
    return text[:match.start()], next_line


graph = nx.DiGraph(size="9, 16")
for _, row in task_df.iterrows():
    line_num = row["line_num"]
    graph.add_node(line_num, **row.dropna().to_dict())
    for col in next_cols:
        if pd.isna(row[col]):
            continue
        cond, next_line = parse_next_cell(line_num, row[col])
        graph.add_edge(line_num, next_line, cond=cond)

# Keep only lines reachable from line 1 (treated as the task's entry point) and report the rest.
descendants = nx.descendants(graph, 1).union({1})
useless_nodes = set(graph.nodes) - set(descendants)
print(f"Found the following useless nodes in your task... {useless_nodes}")
graph = graph.subgraph(descendants)
SVG(nxrender(graph, nodeautolabel=[..., "+dot", re.compile(r"NEXT\d*")], format="svg", args="-Gsize=20 -Gratio=1.4"))



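# Poly Event Metadata

Match each `task.events.poly` entry of the metadata file to the graph nodes selected by its `detection` rules, save the resulting per-node descriptions, and render the annotated task graph.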

In [None]:
# Load the experiment metadata and keep the "poly" event definitions, which are consumed below
# through their "detection" and "description" mappings. The `with` block closes the file
# handle (the original left it open).
with metadata_path.open("r") as metadata_file:
    metadata = yaml.safe_load(metadata_file)
poly_event_metadata = metadata["task"]["events"]["poly"]
poly_event_metadata

In [None]:
# Assign each "poly" event description to the reachable task lines its "detection" rules select.
# The rules of one event are applied in sequence (i.e. ANDed); an event without rules selects
# every reachable line. Supported rules, each taking a scalar or a list:
#   poly_line_num    - keep lines whose line number is listed
#   poly_name        - keep lines with an attribute value matched by the regex "_<name>$"
#   poly_ignore_name - drop lines with an attribute value matched by the regex "_<name>$"
# A line selected by two events is an error. The result is written to node_metadata_path.


def as_list(value):
    """Return `value` unchanged if it is a list, otherwise wrap it in a one-element list."""
    return value if isinstance(value, list) else [value]


def suffix_patterns(names):
    """Compile one "_<name>$" regex per name (scalar or list); names are used as regex fragments."""
    return [re.compile("_" + name + "$") for name in as_list(names)]


def has_suffixed_attr(node, patterns):
    """True if some attribute value of `node` (converted to str) is matched by one of `patterns`."""
    return any(
        pattern.search(str(attr)) is not None
        for attr in graph.nodes[node].values()
        for pattern in patterns
    )


line_nums = list(descendants)
node_metadata = {}
for ev in poly_event_metadata:
    nodes = line_nums
    for k, v in ev["detection"].items():
        if k == "poly_line_num":
            wanted_lines = as_list(v)
            nodes = [n for n in nodes if n in wanted_lines]
        elif k == "poly_name":
            patterns = suffix_patterns(v)
            nodes = [n for n in nodes if has_suffixed_attr(n, patterns)]
        elif k == "poly_ignore_name":
            patterns = suffix_patterns(v)
            nodes = [n for n in nodes if not has_suffixed_attr(n, patterns)]
        else:
            raise Exception(f"unrecognized detection key {k}")
    for n in nodes:
        if n in node_metadata:
            raise Exception(f"double event metadata information for node {n}")
        node_metadata[n] = dict(ev["description"])

# `with` guarantees the YAML is flushed and the handle closed (the original left it open).
with node_metadata_path.open("w") as node_metadata_file:
    yaml.dump(node_metadata, node_metadata_file)
node_metadata

In [None]:
# Copy the reachable graph and style every node that received event metadata: it gets a
# rectangular shape and, unless it already has a fillcolor, a fill colour and legend entry
# keyed by its event name.
annotated_graph = graph.copy()
cmap = mpl.colormaps["tab10"].colors
color_list = [mpl.colors.to_hex(c) for c in cmap[3:]]
colors = {"error": "red", "reward": "lightgreen"}
for n, meta in node_metadata.items():
    attr_dict = annotated_graph.nodes[n]
    attr_dict["event_metadata"] = meta
    attr_dict["shape"] = "rectangle"
    if "fillcolor" not in attr_dict and "event" in meta:
        ev_name = meta["event"]
        if ev_name not in colors:
            # NOTE(review): indexing by len(colors) skips color_list[0] and [1] because `colors`
            # starts with two entries; confirm this is intended. The modulo reuses colours
            # instead of raising IndexError once there are more events than colours.
            colors[ev_name] = color_list[len(colors) % len(color_list)]
        attr_dict["fillcolor"] = colors[ev_name]
        attr_dict["style"] = "filled"
        attr_dict["legend"] = ev_name
svg = nxrender(annotated_graph, nodeautolabel=[..., "+dot", re.compile(r"NEXT\d*")], format="svg", args="-Gsize=22 -Gratio=1.4")
SVG(svg)

In [None]:
# Persist the rendered annotated graph (binary mode: `svg` is written as-is) to annotated_task_graph_path.
with open(annotated_task_graph_path, "wb") as svg_out:
    svg_out.write(svg)