In [1]:
import os

import graphviz
import pandas as pd
import jax.numpy as jnp

import appletree as apt
from appletree.share import DATAPATH
from appletree.components import ERBand

In [2]:
# Value handed to appletree's GPU memory setting.
# NOTE(review): 0.2 is presumably a fraction of the device memory — confirm.
GPU_MEMORY_USAGE = 0.2
apt.set_gpu_memory_usage(GPU_MEMORY_USAGE)

In [3]:
# Read the Rn220 data CSV that lives under appletree's DATAPATH.
data_file = os.path.join(
    DATAPATH,
    'data_XENONnT_Rn220_v8_strax_v1.2.2_straxen_v1.7.1_cutax_v1.9.0.csv',
)
data = pd.read_csv(data_file)

# Derive the cs1 / cs2 bin edges from the data itself.
# NOTE(review): [15, 15] is presumably the number of bins per axis and the
# clips the (cs1, cs2) ranges considered — confirm against
# apt.utils.get_equiprob_bins_2d.
cs1_cs2 = data[['cs1', 'cs2']].to_numpy()
bins_cs1, bins_cs2 = apt.utils.get_equiprob_bins_2d(
    cs1_cs2,
    [15, 15],
    order=[0, 1],
    x_clip=[0, 100],
    y_clip=[1e2, 1e4],
    which_np=jnp,
)

In [4]:
# ERBand component built on the cs1 / cs2 bin edges computed above.
er = ERBand(
    bins=[bins_cs1, bins_cs2],
    bins_type='irreg',
)

In [5]:
# Let the component work out what is needed to produce cs1 and cs2 under the
# function name 'simulate'.
er.deduce(data_names=['cs1', 'cs2'], func_name='simulate')

# NOTE(review): er.compile() is deliberately left disabled in this cell;
# confirm it is not required before drawing the graphs below.

In [6]:
def add_spaces(x):
    """Indent every line of ``x`` by four spaces.

    This is needed to make html raw blocks in rst format correctly.

    Parameters
    ----------
    x : str or iterable of str
        Either one (possibly multi-line) string, or an iterable of lines such
        as the list returned by ``file.readlines()``. Items of an iterable are
        used as-is, so they are expected to carry their own line endings.

    Returns
    -------
    str
        All lines concatenated into a single string, each one prefixed with
        four spaces.
    """
    if isinstance(x, str):
        # Keep the line endings: a plain split('\n') throws them away, which
        # used to glue every indented line together on one single line.
        x = x.splitlines(keepends=True)
    # join() builds the result in one pass instead of repeated string +=.
    return ''.join('    ' + line for line in x)

def tree_to_svg(graph_tree, save_as='data_types'):
    """Render ``graph_tree`` and return the resulting SVG text, indented.

    Parameters
    ----------
    graph_tree :
        Object exposing ``render(path)``; the rendered output is read back
        from ``<save_as>.svg``.
    save_as : str
        Path stem passed to ``render``. The file at exactly this path is
        deleted afterwards, while ``<save_as>.svg`` is left on disk.

    Returns
    -------
    str
        Contents of the SVG file without its first five lines, every
        remaining line indented by four spaces via ``add_spaces``.
    """
    graph_tree.render(save_as)
    svg_path = f'{save_as}.svg'
    with open(svg_path, mode='r') as svg_file:
        # Drop the first five lines of the rendered file.
        # NOTE(review): presumably the XML/DOCTYPE preamble — confirm.
        body_lines = svg_file.readlines()[5:]
    indented_svg = add_spaces(body_lines)
    # Only the extension-less file is removed (presumably the source file
    # written by render); the .svg itself is intentionally kept.
    os.remove(save_as)
    return indented_svg

In [7]:
def add_deps_to_graph_tree(graph_tree,
                           data_names: list = ['cs1', 'cs2', 'eff'],
                           _seen = None):
    """Recursively add data-name nodes to the graph based on plugin dependencies.

    For every name in ``data_names`` a node is added to ``graph_tree``,
    followed by one edge per entry of the ``depends_on`` attribute of the
    plugin registered for that name in ``er._plugin_class_registry``. The
    dependencies are then processed the same way.

    Parameters
    ----------
    graph_tree :
        Graph object exposing ``node(name, **attrs)`` and ``edge(tail, head)``.
    data_names : list
        Data names to start from. The default list is only read, never
        mutated.
    _seen : list, optional
        Names that already have a node; used by the recursion to avoid
        declaring a node twice. Callers normally leave it as ``None``.

    Returns
    -------
    tuple
        ``(graph_tree, _seen)``: the same graph object and the list of all
        names that received a node.

    Raises
    ------
    KeyError
        If a name other than ``'batch_size'`` has no entry in
        ``er._plugin_class_registry``.
    """
    if _seen is None:
        _seen = []
    for data_name in data_names:
        if data_name in _seen:
            continue
        # Mark the name before descending so it is declared exactly once,
        # including 'batch_size', which used to be re-declared on every visit.
        _seen.append(data_name)

        # Link target is '#' + name with underscores turned into hyphens.
        graph_tree.node(data_name,
                        style='filled',
                        href='#' + data_name.replace('_', '-'),
                        fillcolor='white')
        if data_name == 'batch_size':
            # 'batch_size' is not looked up in the plugin registry.
            continue

        dep_plugin = er._plugin_class_registry[data_name]
        for dep in dep_plugin.depends_on:
            graph_tree.edge(data_name, dep)
        # One recursive call covers all dependencies; it does not depend on
        # the loop variable, so it is not repeated per dependency.
        graph_tree, _seen = add_deps_to_graph_tree(graph_tree,
                                                   dep_plugin.depends_on,
                                                   _seen)
    return graph_tree, _seen

def add_plugins_to_graph_tree(graph_tree,
                              data_names: list = ['cs1', 'cs2', 'eff'],
                              _seen = None,
                              with_data_names=False):
    """Recursively add plugin nodes to the graph based on plugin dependencies.

    Every name in ``data_names`` is resolved to its plugin class through
    ``er._plugin_class_registry``. Each plugin class gets one node (keyed by
    the class ``__name__``) and one edge to the plugin class of each of its
    ``depends_on`` entries; those dependencies are then processed the same
    way. ``'batch_size'`` is skipped everywhere.

    Parameters
    ----------
    graph_tree :
        Graph object exposing ``node(name, **attrs)`` and ``edge(tail, head)``.
    data_names : list
        Data names to start from. The default list is only read, never
        mutated.
    _seen : list, optional
        Plugin class names that already have a node; used by the recursion.
        Callers normally leave it as ``None``.
    with_data_names : bool
        If true, the node label additionally lists the plugin's
        ``depends_on`` and ``provides`` names on two extra lines. The flag is
        applied to every plugin reached by the recursion.

    Returns
    -------
    tuple
        ``(graph_tree, _seen)``: the same graph object and the list of plugin
        class names that received a node.

    Raises
    ------
    KeyError
        If a name other than ``'batch_size'`` has no entry in
        ``er._plugin_class_registry``.
    """
    if _seen is None:
        _seen = []
    for data_name in data_names:
        if data_name == 'batch_size':
            continue

        plugin = er._plugin_class_registry[data_name]
        plugin_name = plugin.__name__
        if plugin_name in _seen:
            continue
        # Track the plugin name (the key the check above uses), not the data
        # name; otherwise the check never matches and plugins are revisited.
        _seen.append(plugin_name)

        label = f'{plugin_name}'
        if with_data_names:
            label += f"\n{', '.join(plugin.depends_on)}\n{', '.join(plugin.provides)}"
        # Link target is '#' + plugin name with underscores turned into hyphens.
        graph_tree.node(plugin_name,
                        label=label,
                        style='filled',
                        href='#' + plugin_name.replace('_', '-'),
                        fillcolor='white')

        for dep in plugin.depends_on:
            if dep == 'batch_size':
                continue
            dep_plugin = er._plugin_class_registry[dep]
            graph_tree.edge(plugin_name, dep_plugin.__name__)
        # One recursive call covers all dependencies, and the label option is
        # forwarded so nested plugins are labelled like the top-level ones.
        graph_tree, _seen = add_plugins_to_graph_tree(
            graph_tree,
            plugin.depends_on,
            _seen,
            with_data_names=with_data_names,
        )
    return graph_tree, _seen

In [8]:
# Data-name dependency graph, rendered to 'dtypes.svg'.
# Return values are bound to `_` only to keep the cell output empty.
graph_tree = graphviz.Digraph(strict=True, format='svg')
_ = add_deps_to_graph_tree(graph_tree)
_ = tree_to_svg(graph_tree, save_as='dtypes')

In [9]:
# Plugin dependency graph, rendered to 'plugins.svg'.
# Return values are bound to `_` only to keep the cell output empty.
graph_tree = graphviz.Digraph(strict=True, format='svg')
_ = add_plugins_to_graph_tree(graph_tree)
_ = tree_to_svg(graph_tree, save_as='plugins')