# Core

In [None]:
#| default_exp core

In [None]:
#| hide
from nbdev.showdoc import show_doc
%load_ext autoreload
%autoreload 2

### Overview
This basic module contains utility functions used by the other modules, as well as the definition of the library's exception class.

### Requirements

In [None]:
#| export
from pathlib import Path

import networkx as nx
from networkx import DiGraph, planar_layout, spring_layout, draw_networkx_nodes, draw_networkx_labels, draw_networkx_edges

import html
import pandas as pd

from IPython.display import display
from typing import *

from jinja2 import Template, Environment, PackageLoader, meta


### GraphRewriteException
Exceptions raised by the library's modules are defined using this exception class.

In [None]:
#| export
class GraphRewriteException(Exception):
    """Error type raised by the graph_rewrite library.

    Besides the standard ``args`` tuple, the text passed to the constructor is
    kept on the ``message`` attribute for convenient access.
    """

    def __init__(self, msg: str):
        super().__init__(msg)
        self.message = msg

### Test Utilities
In our modules, we construct various graphs for testing and for explanations. The following functions construct such graphs as NetworkX DiGraphs (like the graphs this library takes as input and produces as output), plot them, and compare them.

#### Graph Construction
We want to be able to construct new DiGraphs from lists of nodes and edges. Begin by defining the allowed types for node and edge names, based on NetworkX's restrictions:

In [None]:
#| export
# Type alias for node names. NOTE(review): the alias is not enforced at runtime, and
# the examples in this module also use int node names (e.g. `_create_graph`'s docstring).
NodeName = str
# When defining an edge, the first node is the source and the second is the target (as we use directed graphs).
EdgeName = Tuple[NodeName, NodeName]

Now, we can construct new graphs out of nodes/edges lists:

In [None]:
#| export
def _create_graph(nodes: list[Union[NodeName, Tuple[NodeName, dict]]], edges: list[Union[EdgeName, Tuple[NodeName, NodeName, dict]]]) -> DiGraph:
    """Build a NetworkX DiGraph from a list of nodes and a list of edges.

    Args:
        nodes (list[Union[NodeName, Tuple[NodeName, dict]]]):
            node names, each optionally paired with an attribute dict,
            e.g. ['A', 'B', (1, {'attr': 5}), 2].
        edges (list[Union[EdgeName, Tuple[NodeName, NodeName, dict]]]):
            (source, target) pairs, each optionally followed by an attribute dict,
            e.g. [('A', 'B'), (1, 'A', {'attr': 5})].

    Returns:
        DiGraph: a new graph holding the given nodes and edges (nodes are added first).
    """
    graph = DiGraph()
    # Nodes go in before edges so their attribute dicts are attached first.
    graph.add_nodes_from(nodes)
    graph.add_edges_from(edges)
    return graph

#### Graph Plotting
We will use the following constants when plotting graphs in the modules:

In [None]:
#| export
# Plot settings read by `_plot_graph`; the size/color/width values are passed straight
# to networkx's draw_networkx_nodes / draw_networkx_labels / draw_networkx_edges.
plot_consts = {
    "node_size": 300,
    "node_color": 'g',
    # Highlighted nodes can have different colors
    "hl_node_color": 'r',

    "font_size": 7,
    "font_color": 'w',

    "arrow_size": 10,
    "edge_color": 'k',
    "edge_width": 1,
    # Highlighted edges can have different colors
    "hl_edge_color": 'r',
    "hl_edge_width": 2,

    # The plotter has some optional layouting modes, we choose one here.
    # A callable taking the graph and returning node positions; `_plot_graph`
    # falls back to spring_layout when this layout fails.
    "layouting_method": planar_layout
}

The following function plots graphs and can optionally highlight a subset of nodes and edges:

In [None]:
#| export
def _split_highlighted(items, highlighted) -> tuple[list, list]:
    """Split ``items`` into (members of ``highlighted``, the rest), preserving the order of ``items``."""
    hl, rest = [], []
    for item in items:
        (hl if item in highlighted else rest).append(item)
    return hl, rest


def _display_attrs_table(attrs: list, index: list):
    """Display a one-column, left-aligned table with one attribute dict per row."""
    table = pd.DataFrame([[a] for a in attrs], columns=['Attributes'], index=index)
    styled = (
        table.style
        .set_properties(**{'text-align': 'left', 'max_colwidth': None})
        .set_table_styles([dict(selector='th', props=[('text-align', 'left')])])
    )
    display(styled)


def _plot_graph(g: DiGraph, hl_nodes: Optional[set[NodeName]] = None, hl_edges: Optional[set[EdgeName]] = None, node_attrs: bool = False, edge_attrs: bool = False):
    """Plot a graph, and potentially highlight certain nodes and edges.

    Args:
        g (DiGraph): a graph to plot
        hl_nodes (set[NodeName], optional): node names to highlight; names not in ``g`` are ignored.
            Defaults to None (no highlighting).
        hl_edges (set[EdgeName], optional): edges to highlight; edges not in ``g`` are ignored.
            Defaults to None (no highlighting).
        node_attrs (bool, optional): If true, also display a table of node attributes. Defaults to False.
        edge_attrs (bool, optional): If true, also display a table of edge attributes. Defaults to False.
    """
    # One pass per element kind; membership tests go against the caller's container.
    hl_node_list, plain_nodes = _split_highlighted(g.nodes(), hl_nodes or ())
    hl_edge_list, plain_edges = _split_highlighted(g.edges(), hl_edges or ())

    # Only the layout call is guarded: planar_layout raises NetworkXException for
    # non-planar graphs, and we fall back to spring_layout. Errors from drawing or
    # displaying are no longer masked (or misreported as "not planar").
    try:
        pos = plot_consts["layouting_method"](g)
    except nx.NetworkXException:
        print("Graph isn't planar, priniting in spring layout mode.")
        pos = spring_layout(g)

    draw_networkx_nodes(g, pos, nodelist=plain_nodes, node_size=plot_consts["node_size"],
                        node_color=plot_consts["node_color"])
    draw_networkx_nodes(g, pos, nodelist=hl_node_list, node_size=plot_consts["node_size"],
                        node_color=plot_consts["hl_node_color"])
    draw_networkx_labels(g, pos, font_size=plot_consts["font_size"], font_color=plot_consts["font_color"])
    draw_networkx_edges(g, pos, edgelist=plain_edges, arrowsize=plot_consts["arrow_size"],
                        node_size=plot_consts["node_size"], edge_color=plot_consts["edge_color"],
                        width=plot_consts["edge_width"])
    draw_networkx_edges(g, pos, edgelist=hl_edge_list, arrowsize=plot_consts["arrow_size"],
                        node_size=plot_consts["node_size"], edge_color=plot_consts["hl_edge_color"],
                        width=plot_consts["hl_edge_width"])

    if node_attrs:
        node_items = list(g.nodes(data=True))
        _display_attrs_table([attrs for _, attrs in node_items],
                             [node for node, _ in node_items])

    if edge_attrs:
        edge_items = list(g.edges(data=True))
        _display_attrs_table([attrs for _, _, attrs in edge_items],
                             [f'({src}, {dst})' for src, dst, _ in edge_items])

#### Graph Comparison

In [None]:
#| export
def _graphs_equal(graph1: DiGraph, graph2: DiGraph) -> bool:
    """Compare two graphs - nodes, edges and attributes.

    The graphs are equal when they contain exactly the same nodes and edges and
    every node and edge carries an identical attribute dict in both graphs.
    Graph-level attributes (``g.graph``) and insertion order are not compared.

    Args:
        graph1 (DiGraph): A NetworkX graph
        graph2 (DiGraph): A NetworkX graph

    Returns:
        bool: True if the graphs are equal, False otherwise.
    """
    # Equal sizes plus "every element of graph1 is in graph2" means the sets are
    # identical. Checking only the subset direction would wrongly accept a graph2
    # that has extra nodes or edges.
    if graph1.number_of_nodes() != graph2.number_of_nodes():
        return False
    for node, attrs in graph1.nodes(data=True):
        if node not in graph2 or graph2.nodes[node] != attrs:
            return False

    if graph1.number_of_edges() != graph2.number_of_edges():
        return False
    for src, dst, attrs in graph1.edges(data=True):
        if not graph2.has_edge(src, dst) or graph2.edges[src, dst] != attrs:
            return False

    return True

### jinja2 rendering 

In [None]:
#| export

def template_undeclared_vars(template):
    """Return the set of variable names a Jinja template uses but never defines.

    Args:
        template (Path or str): a Path to a template file, or a string holding
            the template source itself.

    Returns:
        set: names the template expects to receive from its rendering context.
    """
    source = template.read_text() if isinstance(template, Path) else template
    syntax_tree = Environment().parse(source)
    return meta.find_undeclared_variables(syntax_tree)

def render_jinja(template, params: dict, silent=True, to_file: Optional[Path] = None):
    """Render a Jinja template with the given parameters.

    Args:
        template (Path or str): a Path to a template file, or a string holding the
            template content itself (a plain string is never treated as a file name).
        params (dict): variables to render into the template.
        silent (bool, optional): if False, also print the rendered text. Defaults to True.
        to_file (Path or str, optional): if given, write the rendered text to this
            path instead of returning it. Defaults to None.

    Returns:
        str or None: the rendered text, or None when ``to_file`` is given.
    """
    if isinstance(template, Path):
        template = template.read_text()
    rendered = Template(template).render(**params)

    if not silent:
        print(rendered)

    if to_file:
        # Accept plain string paths as well as Path objects.
        Path(to_file).write_text(rendered)
        return None
    return rendered
    

## Viz

In [None]:
#| export 

# visualizing the graph
import base64
from IPython.display import Image, display

## mermaid

In [None]:
#| export
def mm_ink(graphbytes):
    """Given a bytes object holding a Mermaid-format graph, return a URL that will generate the image.

    The payload uses the URL-safe base64 alphabet ('-' and '_' instead of '+' and
    '/'). With the standard alphabet, a '/' in the payload would be read as a
    URL path separator and break the link.

    Args:
        graphbytes (bytes): the Mermaid graph definition.

    Returns:
        str: ``https://mermaid.ink/img/<base64 payload>``.
    """
    payload = base64.urlsafe_b64encode(graphbytes).decode("ascii")
    return "https://mermaid.ink/img/" + payload


def mm_display(graphbytes):
    """Display a Mermaid graph, given as bytes, as an image pointing at its mermaid.ink URL."""
    image_url = mm_ink(graphbytes)
    display(Image(url=image_url))


def mm(graph):
    """Given a string containing a Mermaid-format graph, display it.

    The text is encoded as UTF-8 (a superset of ASCII), so graphs whose labels
    contain non-ASCII characters no longer raise ``UnicodeEncodeError``.
    """
    mm_display(graph.encode("utf-8"))


def mm_link(graph):
    """Given a string containing a Mermaid-format graph, return URL for display.

    The text is encoded as UTF-8 (a superset of ASCII), so graphs whose labels
    contain non-ASCII characters no longer raise ``UnicodeEncodeError``.
    """
    return mm_ink(graph.encode("utf-8"))


def mm_path(path):
    """Display the Mermaid graph stored in the file at ``path`` (read as raw bytes)."""
    with open(path, "rb") as source:
        contents = source.read()
    # The file is closed before the (possibly slow) display call.
    mm_display(contents)

In [None]:
# Quick check that the Mermaid helpers render a small cyclic flowchart.
demo_graph = """
graph LR;
A[hello]
A--> B & C & D;
B--> A & E;
C--> A & E;
D--> A & E;
E--> B & C & D;
"""
mm(demo_graph)

In [None]:
#| export
# Jinja template producing Mermaid flowchart source; rendered by `draw`.
# Context variables:
#   direction          - flowchart direction, e.g. 'TB' or 'LR'
#   nodes              - iterable of (index, name, description, style) tuples;
#                        a truthy style adds a `style <name> ...` line
#   edges              - iterable of (index, source, target, description, style) tuples;
#                        `index` is the Mermaid link number used by `linkStyle`,
#                        so it must match the order in which edges are emitted here
#   default_node_style - optional CSS for `classDef default`
#   default_edge_style - optional CSS for `linkStyle default`
# NOTE(review): node names are emitted verbatim as Mermaid ids, so names containing
# spaces or Mermaid syntax characters will likely break the diagram — confirm inputs.
graph_template = """
flowchart {{direction}}
{% for i,name,desc,style in nodes -%}
{{name}}["{{desc}}"]
{% if style -%}
style {{name}} {{style}}
{% endif -%}
{% endfor -%}

{% for i,s,t,desc,style in edges -%}
{% if desc -%}
{{s}} -->|"{{desc}}"| {{t}}
{% else -%}
{{s}} --> {{t}}
{% endif -%}
{% if style -%}
linkStyle {{i}} {{style}}
{% endif -%}
{% endfor -%}

{% if default_node_style -%}
classDef default {{default_node_style}}
{% endif -%}

{% if default_edge_style -%}
linkStyle default {{default_edge_style}}
{% endif -%}


"""

def _escaped_html_format(s):
    s = repr(s)
    s = s.replace('\'','#quot;').replace('\"','#quot;')
    s = html.escape(s)
    s = s.replace('&#','#').replace('&','#')
    return s

def _get_node_description(node,data,props=None):
    label = data.pop('label',None)
    if props is None:
        keys = data.keys()
    else:
        keys = props
    
    attrs = ', '.join([f'{k}={_escaped_html_format(v)}' for k,v in data.items() if k in keys])
    if label is None:
        return f'{node}\n{attrs}'
    else:
        return f'{node}({label})\n{attrs}'

def _get_edge_description(data,props=None):
    if props is None:
        keys = data.keys()
    else:
        keys = props
    attrs = ', '.join([f'{k}={_escaped_html_format(v)}' for k,v in data.items() if k in keys])
    return f'{attrs}'

def draw(
    g: nx.DiGraph,
    props=None,
    ret_mermaid=False,
    default_node_style=None,
    default_edge_style=None,
    node_styles=None,
    edge_styles=None,
    direction='TB',
):
    """Render a NetworkX graph as a Mermaid flowchart and display it via mermaid.ink.

    Each node box shows the node name, its ``label`` attribute (if any) in
    parentheses, and its other attributes as ``key=value`` pairs. Edge labels
    show edge attributes the same way.

    Args:
        g (nx.DiGraph): the graph to draw. A copy is described, not ``g`` itself.
        props (iterable of str, optional): if given, only attributes whose keys are
            in ``props`` are listed (for nodes and edges). Defaults to all attributes.
        ret_mermaid (bool, optional): if True, also print the generated Mermaid
            source. Nothing is returned either way. Defaults to False.
        default_node_style (str, optional): CSS applied to all nodes (``classDef default``).
        default_edge_style (str, optional): CSS applied to all edges (``linkStyle default``).
        node_styles (dict, optional): maps a node name to CSS for that node.
        edge_styles (dict, optional): maps a ``(source, target)`` pair to CSS for that edge.
        direction (str, optional): Mermaid flowchart direction, e.g. 'TB' or 'LR'.
            Defaults to 'TB'.
    """
    # Work on a copy so the description helpers cannot modify the caller's graph.
    graph = g.copy()
    per_node_style = {} if node_styles is None else node_styles
    per_edge_style = {} if edge_styles is None else edge_styles

    node_rows = [
        (idx, name, _get_node_description(name, attrs, props), per_node_style.get(name, None))
        for idx, (name, attrs) in enumerate(graph.nodes(data=True))
    ]
    # The enumeration index doubles as the Mermaid link number used by linkStyle.
    edge_rows = [
        (idx, src, dst, _get_edge_description(attrs, props), per_edge_style.get((src, dst), None))
        for idx, (src, dst, attrs) in enumerate(graph.edges(data=True))
    ]

    context = {
        'nodes': node_rows,
        'edges': edge_rows,
        'default_node_style': default_node_style,
        'default_edge_style': default_edge_style,
        'direction': direction,
    }
    mermaid_text = render_jinja(graph_template, context)

    if ret_mermaid:
        print(mermaid_text)
    mm(mermaid_text)

In [None]:
# One node whose string attribute holds quotes, a backslash and a newline,
# to exercise the Mermaid escaping.
g = _create_graph(
    nodes=[('stringnode1', {'type': 'STRING', 'args': '"hello \\\nworld"', 'idx': 0})],
    edges=[],
)
draw(g, ret_mermaid=True)


flowchart TB
stringnode1["stringnode1
type=#quot;STRING#quot;, args=#quot;#quot;hello \\\nworld#quot;#quot;, idx=0"]



In [None]:
from pydantic import BaseModel

# Minimal pydantic model used below to check how a class object and a model
# instance are rendered as node attribute values.
# (A `#` comment rather than a docstring, so the model's schema is unchanged.)
class Class(BaseModel):
    name: str
    type: str

# Attribute values: the model class itself and an instance of it.
person = Class(name='bob', type='person')
g = _create_graph([('stringnode1', {'type': Class, 'val': person, 'idx': 0})], [])
draw(g, ret_mermaid=True)


flowchart TB
stringnode1["stringnode1
type=#lt;class #quot;__main__.Class#quot;#gt;, val=Class(name=#quot;bob#quot;, type=#quot;person#quot;), idx=0"]



In [None]:
# Three-node example: node 1 carries a 'label', and edge (1, 3) carries an attribute.
g = _create_graph(
    [
        (1, {'color': 'blue', 'size': 10, 'label': 'one'}),
        (2, {'color': 'red', 'size': 20}),
        (3, {'color': 'green', 'size': 30}),
    ],
    [(1, 2), (2, 3), (1, 3, {'edge_attr': 'foo'})],
)

In [None]:
# Default top-to-bottom layout (pass ret_mermaid=True to also print the Mermaid source).
draw(g, direction='TB')

In [None]:
# Same graph laid out left-to-right.
draw(g, direction="LR")

In [None]:
draw(
    g,
    default_node_style='fill:#f9f,stroke:#333,stroke-width:4px;',
    default_edge_style='stroke:red,stroke-width:4px;',
)

In [None]:
draw(
    g,
    node_styles={1: 'stroke:red,stroke-width:4px;'},
    edge_styles={(1, 3): 'stroke:red,stroke-width:4px;'},
)

# Export

In [None]:
#|hide
import nbdev
nbdev.nbdev_export()
     