# LHS Parsing

In [62]:
#| default_exp lhs

In [63]:
#| hide
from nbdev.showdoc import show_doc
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Overview

This module defines the grammar of the LHS (left-hand side) pattern that the user passes to the library's *rewrite* function.
It is also responsible for parsing that LHS pattern into a networkX graph representing the template to search for.

The module also extracts the declarative constraints on the properties of the nodes and edges in the LHS (their types and constant values) and attaches them to the resulting graph, so that they can be enforced later together with the 'condition' parameter of *rewrite*.

### Requirements

In [64]:
from lark import Lark

In [65]:
# Display the imported class to confirm that lark is installed and importable.
Lark

lark.lark.Lark

In [66]:
#| export
import copy
from collections.abc import Callable
import networkx as nx
from lark import Transformer, Lark
from lark import UnexpectedCharacters, UnexpectedToken
from graph_rewrite.match_class import Match
from graph_rewrite.core import GraphRewriteException
from graph_rewrite.core import _create_graph,  _graphs_equal, draw
from collections import defaultdict
from graph_rewrite.match_class import _convert_to_edge_name

### Grammar
The grammar defines the syntax of a legal LHS string that the user can provide.

In [67]:
#| export
# LALR parser for the pattern language. Start rule `patterns`: comma-separated patterns, each a chain
# `vertex (connection vertex)*`. A vertex is a name, an indexed name `name<i,j,...>` or the anonymous `_`,
# optionally followed by `[attr, attr:type, attr=value, attr:type=value, ...]`. A connection is `->`,
# `-[attrs]->`, or the multi-connection `-<n>->` / `-<n>[attrs]->` with n >= 1.
# NOTE(review): NAMED_VERTEX/ATTR_NAME share a regex and overlap with ANONYMUS, INDEX, TYPE and BOOLEAN;
# this presumably relies on lark's contextual lexer (the LALR default) to pick the terminal expected at
# each position - confirm before changing the parser options.
# debug=True asks lark for extra debug output/warnings while building the parser.
lhs_parser = Lark(r"""
    %import common.INT -> INT 
    %import common.FLOAT -> FLOAT
    %import common.ESCAPED_STRING -> STRING
    %import common.WS -> WS
    %ignore WS

    NAMED_VERTEX: /[_a-zA-Z0-9]+/
    ANONYMUS: "_"
    ATTR_NAME: /[_a-zA-Z0-9]+/
    TYPE:  "int" | "str" | "bool" | "float"
    BOOLEAN: "True" | "False"
    NATURAL_NUMBER: /[1-9][0-9]*/
    INDEX: /[0-9]+/

    value: FLOAT | INT | BOOLEAN | STRING

    attribute: ATTR_NAME [":" TYPE] ["=" value]
    attributes: "[" attribute ("," attribute)* "]"

    multi_connection: "-" NATURAL_NUMBER [attributes] "->" 
    connection: ["-" attributes]"->"
              | multi_connection
    
    index_vertex: NAMED_VERTEX "<" INDEX ("," INDEX)* ">"

    vertex: NAMED_VERTEX [attributes]
    | index_vertex [attributes]
    | ANONYMUS [attributes]

    pattern: vertex (connection vertex)*
    patterns: pattern ("," pattern)* 

    """, parser="lalr", start='patterns' , debug=True)

# TODO: Add the ";" delimiter to the lark grammar (for now lhs_to_graph splits the LHS on ";" itself) - coordinate with Dean before changing it.

# Planned "-num+->" feature - the multi_connection rule would become:
# multi_connection: "-" NATURAL_NUMBER "+" [attributes] "->"

### Transformer
The transformer is designed to return the networkX graph representing the patterns received by the user.

For each grammar rule (and each token type), lark calls the transformer method of the same name with the node's already-transformed children as its argument, and the method's return value replaces that node in the tree.

The secondary task of the transformer is to collect the node/edge type constraints and constant-value constraints, so that they can be enforced later. To do so, graphRewriteTransformer holds a python dictionary *constraints*, which accumulates the constraints from all components of the parsed graph and is returned together with the graph.

In [68]:
#| export
from typing import Any

# Type of a function that renders an RHS parameter value from a Match.
# typing.Any is the "any type" annotation; the builtin `any` is a function, not a type.
RenderFunc = Callable[[Match], Any]

In [69]:
#| export
cnt:int = 0 # module-level counter kept for backward compatibility; anonymous vertices are numbered by each transformer's own self.cnt
class graphRewriteTransformer(Transformer):
    """
    Lark transformer that converts a parse tree produced by `lhs_parser` into a networkX graph.

    Each rule/token method below receives the already-transformed children of the matching tree node,
    and its return value replaces that node. The top-level `patterns` rule returns a tuple
    (graph, constraints).

    Args:
    - visit_tokens: bool - forwarded to lark's Transformer; when True the token methods
      (STRING, BOOLEAN, INT, ...) are called as well.
    - component: str - which part of a rule is being parsed:
        "LHS" - attribute names are stored on the graph with a None value, and their (type, value)
                pairs are collected into `self.constraints`;
        "P"   - only attribute names are kept (types and values are ignored, stored as None);
        any other value - attribute values are stored directly on the graph.
    - match: Match - the match passed to the render functions (used by USER_VALUE).
    - render_funcs: dict[str, RenderFunc] - variable name -> function rendering its value from `match`.
      Defaults to a new empty dict.
    """
    def __init__(self, visit_tokens: bool = True, component: str = "LHS", match: Match = None, render_funcs: dict[str, RenderFunc] = None) -> None:
        super().__init__(visit_tokens)
        # general
        self.component = component
        # RHS parameters
        self.match = match
        # a fresh dict per instance instead of a shared mutable default argument
        self.render_funcs = {} if render_funcs is None else render_funcs
        # LHS parameters: vertex name or "u->v" edge key -> {attr_name: (type, value)}
        self.constraints = {}
        # per-instance counter used to name anonymous vertices _0, _1, ...
        self.cnt = 0

    def STRING(self, arg):
        # ESCAPED_STRING keeps its surrounding quotes - strip them
        return arg[1:-1]

    def BOOLEAN(self, arg):
        # The token text is either "True" or "False". bool() of any non-empty string is True,
        # so compare against the literal instead.
        return arg == "True"

    def INT(self, arg):
        # lark's common.INT matches digits only (SIGNED_INT is the signed variant),
        # so the parsed value is non-negative.
        return int(arg)

    def FLOAT(self, arg):
        return float(arg)

    def NATURAL_NUMBER(self, number):
        # number of duplications requested by a multi-connection ("-<n>->")
        return int(number)

    def USER_VALUE(self, arg):
        # NOTE(review): lhs_parser defines no USER_VALUE terminal; presumably used with another grammar.
        # The variable name is the token text without two delimiter characters on each side.
        variable = arg[2:-2]
        # extract the actual value supplied by the user - can be of any type.
        return self.render_funcs[variable](self.match)

    def value(self, args):
        # exactly one child (FLOAT | INT | BOOLEAN | STRING), already converted
        return args[0]

    def attribute(self, args):
        """Returns (attr_name, type, value); optional parts that were not parsed are None."""
        if self.component == "P":
            # types and values are not allowed here - only the name is used
            return (args[0], None, None)
        attr_name, attr_type, attr_value = args
        return (attr_name, attr_type, attr_value)

    def attributes(self, attributes):
        """
        Packs a list of (attr_name, type, value) triples into a tuple of two dicts:
        - attr_names: attr_name -> None (LHS) or the parsed value (other components); placed on the graph.
        - constraints: attr_name -> (type, value); filled only for the LHS (added to self.constraints later).
        """
        attr_names, constraints = {}, {}
        for attr_name, attr_type, attr_value in attributes:
            key = str(attr_name)
            if self.component == "LHS":
                attr_names[key] = None
                constraints[key] = (attr_type, attr_value)
            else:
                attr_names[key] = attr_value
        return (attr_names, constraints)

    def multi_connection(self, args):
        """Returns the connection's (attr_names, constraints), with the duplication count stored under "$dup"."""
        number, attributes = args
        if attributes is None:
            attributes = ({}, {})
        # special attribute used to handle duplications during construction
        attributes[0]["$dup"] = number
        return attributes

    def connection(self, args):
        """
        Returns ((attr_names, constraints), True); True is the determinism flag.
        The single child is either the optional attributes tuple of a plain connection
        (None when absent) or the result of multi_connection.
        """
        attributes = args[0]
        if attributes is None:
            attributes = ({}, {})
        # a plain connection is a single edge; keep the count already set by multi_connection
        attributes[0].setdefault("$dup", 1)
        return (attributes, True)

    def ANONYMUS(self, _):
        # anonymous vertices get a unique generated name ("_0", "_1", ...) and no indices
        x = "_" + str(self.cnt)
        self.cnt += 1
        return (x, [])

    def index_vertex(self, args):
        # args: the (name, []) tuple from NAMED_VERTEX followed by the INDEX tokens
        main_name_tup, *numbers = args
        return (main_name_tup[0], list(numbers))

    def NAMED_VERTEX(self, name):
        # return the main name of the vertex, and an empty indices list.
        return (str(name), [])

    def vertex(self, args):
        """
        Returns (vertex_name, attr_names). Indexed vertices are named "name<i,j,...>".
        For the LHS, the vertex's attribute constraints are merged into self.constraints.
        """
        vertex_tuple, *attributes = args
        name, indices_list = vertex_tuple

        indices = ",".join(str(num) for num in indices_list)
        new_name = str(name) if not indices else name + "<" + indices + ">"

        # no attributes to handle (lark places None for the missing optional [attributes])
        if attributes[0] is None:
            return (new_name, {})

        # constraints: attr_name -> (type, value)
        attribute_names, constraints = attributes[0]
        if self.component == "LHS":
            # a vertex may appear multiple times in the LHS, so its constraints are united.
            # Contradicting constraints are assumed not to occur (the later one would win).
            self.constraints[new_name] = self.constraints.get(new_name, {}) | constraints
        return (new_name, attribute_names)

    def pattern(self, args):
        """
        Builds a DiGraph from a chain "v0 conn0 v1 conn1 v2 ...".
        args alternate between vertices (name, attr_names) and connections ((attr_names, constraints), True).
        """
        first_vertex, *rest = args
        conns, vertices = rest[::2], [first_vertex] + rest[1::2]

        G = nx.DiGraph()
        # (name, attrs) tuples are added as nodes carrying those attributes
        G.add_nodes_from(vertices)

        # Simplified version: duplications are ignored. Future feature: if the determinism flag is True,
        # duplicate the connection "$dup" times (e.g. with a recursive construction per level).
        edge_list = []
        for i, (edge_attrs, _deterministic) in enumerate(conns):
            attribute_names, constraints = edge_attrs
            attribute_names.pop("$dup", 0)
            source, target = vertices[i][0], vertices[i + 1][0]
            edge_list.append((source, target, attribute_names))

            if self.component == "LHS":
                # keep only concrete constraints (a type and/or a value was given);
                # we assume an edge only appears once in the LHS
                filtered_cons = {attr: cons for attr, cons in constraints.items() if cons != (None, None)}
                if filtered_cons:
                    self.constraints[str(source) + "->" + str(target)] = filtered_cons

        G.add_edges_from(edge_list)
        return G

    def empty(self, _):
        return nx.DiGraph()

    def patterns(self, args):
        """
        Unites all comma-separated patterns into a single DiGraph.
        Returns (graph, deep copy of the accumulated constraints).
        """
        G = nx.DiGraph()
        # node name / "u->v" edge key -> united attribute dict
        combined_attributes = {}
        new_nodes, new_edges = [], []
        for graph in args:
            for node in graph.nodes:
                # unite the attribute dicts of a node that appears in several patterns
                combined_attributes[node] = combined_attributes.get(node, {}) | graph.nodes[node]
                new_nodes.append(node)
            for u, v in graph.edges:
                # we assume edges cannot appear more than once in the LHS
                combined_attributes[u + "->" + v] = graph.edges[u, v]
                new_edges.append((u, v))
        G.add_nodes_from((node, combined_attributes[node]) for node in new_nodes)
        G.add_edges_from((u, v, combined_attributes[u + "->" + v]) for u, v in new_edges)

        # module output; the copy keeps the returned constraints independent of self.constraints
        return (G, copy.deepcopy(self.constraints))

### Transformer Application
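The following function applies the transformer to an LHS-formatted string provided by the user, extracting the resulting networkx graph and the constraints. The LHS may be split by `;` into a single-nodes part and an optional collections part, and each part is converted into its own graph. The extracted constraint values are then written onto the nodes and edges of the graphs, so that they can be enforced later together with the *condition* function supplied by the user.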

In [70]:
#| export
#TODO: Once Dean approves the solution, remove the condition parameter from this func and from all caller functions
# (it is only used for the POC, to avoid changing the entire code before even approving the code)

def lhs_to_graph(lhs: str, condition=lambda x: True, debug=False):
    """
    Converts a LHS string to networkx graphs and embeds the attribute constraints in them.

    The LHS may consist of up to two parts separated by ';': the single-nodes pattern and an
    optional collections pattern. Each part is parsed with `lhs_parser`, transformed with a fresh
    `graphRewriteTransformer(component="LHS")`, and its attribute constraints are written onto the
    resulting graph by `_process_constraints`.

    Args:
    - lhs: str - a string representing the LHS of a rule.
    - condition: callable - the user's condition; currently returned unchanged.
    - debug: bool - if True, returns the parse trees instead of the graphs.

    Returns:
    - nx.DiGraph - the graph representing the single nodes pattern
      (debug mode: the parse tree of the single nodes part).
    - nx.DiGraph - the graph representing the collections pattern, or an empty graph when there is
      no collections part (debug mode: its parse tree, or '' when absent).
    - the `condition` argument, unchanged (debug mode: None).

    Raises:
    - GraphRewriteException - if the LHS contains more than one ';' or cannot be parsed/converted.
    """
    try:
        lhs_strs = lhs.split(';')
        # At most 2 parts: single nodes and collections (str.split always yields at least one part).
        # An explicit check is used instead of `assert`, which is removed under `python -O`.
        if len(lhs_strs) > 2:
            raise ValueError("expected at most one ';' separating single nodes from collections, got {}".format(len(lhs_strs) - 1))
        has_collections = len(lhs_strs) == 2

        # Parse both parts up front so that debug mode can return both trees.
        single_nodes_tree = lhs_parser.parse(lhs_strs[0])
        collections_tree = lhs_parser.parse(lhs_strs[1]) if has_collections else ''
        if debug:
            return single_nodes_tree, collections_tree, None

        # Single nodes
        single_nodes_graph, single_nodes_constraints = graphRewriteTransformer(component="LHS").transform(single_nodes_tree)
        _process_constraints(single_nodes_graph, single_nodes_constraints)

        # Collections
        if has_collections:
            collections_graph, collections_constraints = graphRewriteTransformer(component="LHS").transform(collections_tree)
            _process_constraints(collections_graph, collections_constraints)
        else:
            collections_graph = nx.DiGraph()

        # TODO: remove condition once Dean approves the solution - it is only used for working on the POC
        return single_nodes_graph, collections_graph, condition

    except Exception as e:
        # lark's UnexpectedCharacters/UnexpectedToken are Exception subclasses, so they are covered here.
        # BaseException is deliberately not caught, so KeyboardInterrupt/SystemExit still propagate.
        raise GraphRewriteException('Unable to convert LHS: {}'.format(e)) from e


def _process_constraints(graph: nx.DiGraph, constraints: dict):
    """
    Writes the parsed attribute constraints onto the nodes/edges of `graph`, in place.

    Args:
    - graph: nx.DiGraph - the graph produced by graphRewriteTransformer for the same LHS part.
    - constraints: dict - maps a node name, or an edge key of the form "source->target", to a dict
      attr_name -> (type_str, value), as collected by graphRewriteTransformer.

    Each (type_str, value) pair is converted with `_convert_value`, and the result is stored under
    attr_name using the usual networkx attribute layout (graph.nodes[n][attr] / graph.edges[u, v][attr]).
    """
    for key, attr_specs in constraints.items():
        for attr_name, spec in attr_specs.items():
            type_str, raw_value = spec
            stored_value = _convert_value(raw_value, type_str)
            if key not in graph.nodes:
                # not a node name, so the key encodes an edge as "source->target"
                source, target = key.split("->")
                graph.edges[source, target][attr_name] = stored_value
            else:
                graph.nodes[key][attr_name] = stored_value


def _convert_value(attr_value: str, attr_type_str: str):
    """
    Converts the attribute value to the specified type. If the type is unrecognized,
    it wraps the value in `TypedAttribute`. If the type is unmentioned (None), it 
    returns the value as is.

    Args:
    - attr_value: str - the attribute value as a string.
    - attr_type_str: str - the string representing the expected type.

    Returns:
    - The converted value or an instance of the `TypedAttribute` class.
    """
    if attr_value is None:
        return None

    # We need to take the type into account when examining the value
    type_mapping = {
        "str": str,
        "int": int,
        "float": float,
        "bool": bool
    }

    if attr_type_str is None:
        return attr_value
    
    # TODO: show Dean - do we want to keep ignoring unrecognized types, or should we raise an exception?
    if attr_type_str not in type_mapping:
        return TypedAttribute(attr_type_str, str(attr_value), False)

    return TypedAttribute(attr_type_str, type_mapping[attr_type_str](attr_value))

class TypedAttribute:
    """
    Wrapper class for attributes with a specified type.
    This class holds the attribute value and the name of its declared type (e.g. "int").

    Args:
    - attr_type_str: str - the declared type name.
    - attr_value - the (already converted) attribute value.
    - is_recognized: bool - False when attr_type_str is not a known type; comparing against a plain
      value then ignores the type.
    """
    def __init__(self, attr_type_str, attr_value, is_recognized=True):
        self._attr_type_str = attr_type_str
        self._attr_value = attr_value
        self._is_recognized = is_recognized

    @property
    def value(self):
        """The wrapped attribute value."""
        return self._attr_value

    # Transparent access to the value
    def __getitem__(self, key=None):
        # Subscription (attr[k]) always passes a key, so one must be accepted.
        # Without a key (direct call) return the wrapped value; with a key, index into the value.
        if key is None:
            return self._attr_value
        return self._attr_value[key]

    def __eq__(self, other):
        # Two wrappers are equal when they declare the same type and hold equal values.
        if isinstance(other, TypedAttribute):
            return self._attr_type_str == other._attr_type_str and self._attr_value == other._attr_value
        # Against a plain value: compare both the value and the type name (when the type is recognized)
        if self._is_recognized:
            return self._attr_value == other and self._attr_type_str == other.__class__.__name__
        return self._attr_value == other

    def __ne__(self, other):
        return not self.__eq__(other)

    def __repr__(self):
        return str(self._attr_value)

### Tests
Note that throughout these tests, we use the naive condition which returns True for all matches. We chose to do that since this module is all about parsing, which is not affected by the condition.
The condition will be checked appropriately in the module that actually uses it, the Matcher module.

#### Basic Connections

In [71]:
naive_condition = lambda x: True

# Each case: (LHS string, expected nodes, expected edges).
basic_cases = [
    ("a", ['a'], []),
    ("a->b", ['a','b'], [('a','b')]),
    ("a -> b", ['a','b'], [('a','b')]),
    ("a->b -> c", ['a','b','c'], [('a','b'),('b','c')]),
    ("a->b -> a", ['a','b'], [('a','b'),('b','a')]),
    # anonymous vertices are named _0, _1, ... in order of appearance
    ("a->_->b->_", ['a','b','_0','_1'], [('a','_0'),('_0','b'),('b','_1')]),
]
for lhs_str, nodes, edges in basic_cases:
    res, _, _ = lhs_to_graph(lhs_str, naive_condition)
    expected = _create_graph(nodes, edges)
    # the message identifies which case failed
    assert _graphs_equal(expected, res), lhs_str

#### Attributes

In [72]:
# vertex with a constant-valued attribute
res, _, _ = lhs_to_graph("a[x=5]", naive_condition)
expected = _create_graph([('a', {'x': 5})], [])
assert(_graphs_equal(expected, res))

# edge with a constant-valued attribute
res, _, _ = lhs_to_graph("a-[x=5]->b", naive_condition)
expected = _create_graph(['a', 'b'], [('a','b',{'x': 5})])
assert(_graphs_equal(expected, res))

# indexed vertex with an untyped and a typed attribute; the typed value is compared with its type
res, _, _ = lhs_to_graph("a<1,2>[x=5, y: int = 6]", naive_condition)
expected = _create_graph([('a<1,2>',{'x':5, 'y':6})],[])
assert(_graphs_equal(res,expected))

# attributes named without a type or value are stored with a None value; whitespace is ignored
res, _, _ = lhs_to_graph("a[a]-[x]->b[ b ] -> c[ c ]", naive_condition)
expected = _create_graph([('a',{'a':None}), ('b',{'b':None}), ('c',{'c':None})],[('a','b',{'x':None}),('b','c')])
assert(_graphs_equal(res,expected))


In [73]:
# Show the raw parse tree (debug=True) and the resulting graph for a pattern with typed, constant-valued attributes.
# `c` receives the condition returned by lhs_to_graph.
t2, _, _ = lhs_to_graph('''rel[val:str="relation"]->z[val:str="relation_name"]''',naive_condition,debug=True)
g2, _, c = lhs_to_graph('''rel[val:str="relation"]->z[val:str="relation_name"]''' ,naive_condition)
draw(g2)
print(t2.pretty())

patterns
  pattern
    vertex
      rel
      attributes
        attribute
          val
          str
          value	"relation"
    connection	None
    vertex
      z
      attributes
        attribute
          val
          str
          value	"relation_name"



#### Multiple Patterns

In [74]:
# Each case: (LHS string, expected nodes, expected edges). Comma-separated patterns are merged into one graph.
multi_pattern_cases = [
    ("a->b -> c, c-> d", ['a','b','c','d'], [('a','b'),('b','c'),('c','d')]),
    ("a->b -> c, d", ['a','b','c', 'd'], [('a','b'),('b','c')]),
    # attributes of a vertex that appears in several patterns are united
    ("a->b[z] -> c[y], c[x=5]->b[r]", ['a',('b',{"z":None, "r":None}),('c',{'x':5,'y':None})], [('a','b'),('b','c'),('c','b')]),
]
for lhs_str, nodes, edges in multi_pattern_cases:
    res, _, _ = lhs_to_graph(lhs_str, naive_condition)
    expected = _create_graph(nodes, edges)
    # the message identifies which case failed
    assert _graphs_equal(expected, res), lhs_str

In [75]:
# ';' separates the single-nodes pattern (before it) from the collections pattern (after it); each part becomes its own graph.
res_single, res_collection, _ = lhs_to_graph('c[type="course"];s[type="student"]->c', naive_condition) 
print(res_single.nodes.data())
print(res_collection.nodes.data())

[('c', {'type': 'course'})]
[('s', {'type': 'student'}), ('c', {})]


# Export

In [76]:
#|hide
# Write the cells marked `#| export` to the python module.
import nbdev
nbdev.nbdev_export()
     