In [None]:
#| default_exp lhs

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# %pip install lark
# %pip install networkx
from lark import Lark

## LHS parser ##

Parses the pattern passed as the LHS into a networkX graph that represents the template to search for.

The module converts the declarative constraints on the properties of the nodes and edges in the LHS into imperative functions, which are checked together with the 'condition' parameter.

## grammar

In [None]:
#| hide
#attributes: allow optional \n here, for imperative syntax)
#attribute: #["=" value] 

#    attr_name: /[a-zA-Z0-9]+/ #TODO: lark-imported
#    type:  "int" | "string" | "bool" #TODO: escaped string or word
#    value: /[0-9a-zA-Z]/

#    %import common.WS #CHANGE to allow \n in the imperative option.


In [None]:
#| export

# lhs_parser = Lark(r"""
#     %import common.NUMBER -> NATURAL_NUMBER 
#     %import common.ESCAPED_STRING
#     %import common.WS 
#     %ignore WS

#     NAMED_VERTEX: /[a-zA-Z0-9]+/
#     ANONYMUS: "_"
#     ATTR_NAME: /[a-zA-Z0-9]+/
#     TYPE:  "int" | "string" | "bool"
#     VALUE: /[0-9a-zA-Z]/

#     attribute: ATTR_NAME [":" TYPE] ["=" VALUE]
#     attributes: "[" attribute ("," attribute)* "]"

#     multi_connection: "-" NATURAL_NUMBER "+" [attributes] "->" 
#                     | "-" NATURAL_NUMBER [attributes] "->" 
#     connection: "-" [attributes "-"] ">"
#               | multi_connection
    
#     index_vertex: NAMED_VERTEX "<" NATURAL_NUMBER ("," NATURAL_NUMBER)* ">"

#     vertex: NAMED_VERTEX [attributes]
#         | index_vertex [attributes]
#         | ANONYMUS [attributes]

#     pattern: vertex (connection vertex)*
#     patterns: pattern (";" pattern)*
        
#     """, parser="lalr", start='patterns' , debug=True)


In [None]:
#| export
# The parser is used by the exported `lhs_to_graph`, so this cell must be
# exported too; the import is repeated here so the cell is self-contained.
from lark import Lark

# LALR grammar for a single LHS pattern: a chain of vertices joined by
# connections, e.g. `a<1,2>[x=5] -[w: int]-> b -3-> _`.
#   * a vertex is a name, an indexed name (`a<1,2>`) or the anonymous `_`,
#     optionally followed by an attribute list in square brackets;
#   * a connection is `->`, `-[attributes]->`, or `-N->` / `-N[attributes]->`
#     where N is a duplication count (NATURAL_NUMBER allows 1-99).
# `WS` has to be imported explicitly: `%ignore` only accepts a terminal that
# is defined in (or imported into) the grammar.
lhs_parser = Lark(r"""
    %import common.INT -> INT 
    %import common.FLOAT -> FLOAT
    %import common.ESCAPED_STRING -> STRING
    %import common.WS -> WS
    %ignore WS

    NAMED_VERTEX: /[a-zA-Z0-9]+/
    ANONYMUS: "_"
    ATTR_NAME: /[a-zA-Z0-9]+/
    TYPE:  "int" | "string"
    BOOLEAN: "True" | "False"
    NATURAL_NUMBER: /[1-9][0-9]?/
    INDEX: /[0-9]+/

    value: FLOAT | STRING | INT | BOOLEAN

    attribute: ATTR_NAME [":" TYPE] ["=" value]
    attributes: "[" attribute ("," attribute)* "]"

    multi_connection: "-" NATURAL_NUMBER [attributes] "->" 
    connection: "-" [attributes "-"] ">"
              | multi_connection
    
    index_vertex: NAMED_VERTEX "<" INDEX ("," INDEX)* ">"

    vertex: NAMED_VERTEX [attributes]
    | index_vertex [attributes]
    | ANONYMUS [attributes]

    pattern: vertex (connection vertex)*
        
    """, parser="lalr", start='pattern' , debug=True)

# multi_connection: "-" NATURAL_NUMBER "+" [attributes] "->"  - setting for the "-num+->" feature

## Transformer
The transformer is designed to return the networkX graph representing the patterns.

For each branch, the appropriate method will be called with the children of the branch as its argument, and its return value will replace the branch in the tree.

In [None]:
#| export
import itertools
import copy
import networkx as nx

In [None]:
#| export
cnt:int = 0 # unique id for anonymous vertices
from lark import Token, Tree, Transformer
class lhsTransformer(Transformer):
    """Turn the parse tree of an LHS pattern into a networkX ``DiGraph``.

    Every callback receives the (already transformed) children of its branch
    and returns the value that replaces the branch. The intermediate shapes are:

    * terminals ``NAMED_VERTEX`` / ``ANONYMUS`` and rule ``index_vertex``
      -> ``(name, indices)``
    * ``attribute``  -> ``(attr_name, required_type, required_value)``
    * ``attributes`` -> ``(attr_names, constraints)`` (two dicts)
    * ``vertex``     -> ``(node_name, attr_names)``
    * ``connection`` -> ``(edge_attrs, is_deterministic, constraints)``
    * ``pattern``    -> ``nx.DiGraph``

    The type/value constraints met while building the graph are gathered in
    ``self.constraints``: node name -> constraints for vertices and
    ``(source, target)`` -> constraints for edges.
    """

    def __init__(self, visit_tokens: bool = True) -> None:
        super().__init__(visit_tokens)
        self.constraints = {}

    @staticmethod
    def _unpack_attributes(children):
        """Return copies of the ``(attr_names, constraints)`` pair found in *children*.

        The optional ``[attributes]`` child is either absent or ``None``
        (depending on lark's ``maybe_placeholders`` setting) when the pattern
        has no attribute list; both cases yield two empty dicts.
        """
        for child in children:
            if child is not None:
                attr_names, constraints = child
                return dict(attr_names), dict(constraints)
        return {}, {}

    def STRING(self, arg):
        """Strip the surrounding double quotes from an escaped string token."""
        return arg[1:-1] # remove " "
    
    def BOOLEAN(self, arg):
        """Convert the token ``True`` / ``False`` to the matching bool."""
        # bool("False") is True, so compare the text instead of casting it.
        return arg == "True"
    
    def INT(self, arg):
        """Convert an INT token to ``int``."""
        return int(arg)
    
    def FLOAT(self, arg):
        """Convert a FLOAT token to ``float``."""
        return float(arg)
    
    def NATURAL_NUMBER(self, number): # for duplications
        """Convert the duplication count of a multi-connection to ``int``."""
        return int(number)
    
    def value(self, arg):
        """Unwrap the single converted terminal held by a ``value`` branch."""
        return arg[0]
    
    def attribute(self, args): #(attr_name, *rest):
        """Return ``(attr_name, required_type, required_value)``.

        Missing parts are reported as ``None``. The children are either
        ``[name, type_or_None, value_or_None]`` (placeholders enabled) or just
        the parts that were actually written.
        """
        attr_name, *rest = args
        attr_type, attr_value = None, None
        if len(rest) == 2:
            attr_type, attr_value = rest
        elif len(rest) == 1:
            # Only one optional part was given: a TYPE token is still a raw
            # Token, whereas values were already converted by their callbacks.
            only = rest[0]
            if isinstance(only, Token) and only.type == "TYPE":
                attr_type = only
            else:
                attr_value = only
        return (attr_name, attr_type, attr_value)
    
    def attributes(self, attributes): # a list of triples 
        """Split the attribute triples into ``(attr_names, constraints)``.

        ``attr_names`` maps every attribute name to ``None`` and is what gets
        attached to the graph; ``constraints`` maps the same names to their
        ``(required_type, required_value)`` pair.
        """
        attr_names, constraints = {}, {}
        for attr_name, attr_type, attr_value in attributes:
            key = str(attr_name)
            attr_names[key] = None # will be added to the graph itself
            constraints[key] = (attr_type, attr_value) # will be added to the condition function
        return (attr_names, constraints)

    def multi_connection(self, args): # +
        """Return ``(attr_names, constraints)`` with the duplication count stored under ``"$dup"``."""
        number, *rest = args
        attr_names, constraints = self._unpack_attributes(rest)
        attr_names["$dup"] = number # removed in graph construction
        return (attr_names, constraints)

    def connection(self, attributes):
        """Return ``(edge_attrs, is_deterministic, constraints)`` for one connection.

        ``edge_attrs`` always carries ``"$dup"``: the count set by
        ``multi_connection``, or 1 for a plain connection.
        """
        attr_names, constraints = self._unpack_attributes(attributes)
        attr_names.setdefault("$dup", 1)
        return (attr_names, True, constraints)

    def ANONYMUS(self, _token=None):
        """Return a fresh ``("$<id>", [])`` pair for an anonymous vertex."""
        # Terminal callbacks receive the matched token; it is not needed here.
        global cnt
        cnt += 1
        return ("$" + str(cnt), [])

    def index_vertex(self, args):
        """Return the main name of the vertex and the list of indices specified."""
        main_name_tup, *numbers = args # main_name_tup comes from NAMED_VERTEX
        return (main_name_tup[0], list(numbers))
    
    def NAMED_VERTEX(self, name):
        """Return the vertex name together with an empty index list."""
        return (name, [])

    def vertex(self, args): # (vertex_tuple: tuple, attributes: dict = {})
        """Return ``(node_name, attr_names)`` in the form ``add_nodes_from`` accepts.

        Indexed vertices are named ``name<i,j,...>``; vertices without indices
        keep their plain name. Constraints are recorded in ``self.constraints``.
        """
        vertex_tuple, *rest = args
        name, indices_list = vertex_tuple
        new_name = str(name)
        if indices_list:
            new_name += "<" + ",".join(str(num) for num in indices_list) + ">"
        attr_names, constraints = self._unpack_attributes(rest)
        if constraints:
            self.constraints.setdefault(new_name, {}).update(constraints)
        return (new_name, attr_names)

    def pattern(self, args):
        """Build the graph of one pattern: ``vertex (connection vertex)*``.

        Simplified version: the duplication count of a connection is ignored
        and its ``"$dup"`` marker is dropped from the edge attributes.
        """
        # 1) unpack lists of vertices and connections (they alternate after the first vertex).
        first_vertex, *rest = args
        connections = rest[::2]
        vertices = [first_vertex] + rest[1::2]

        # 2) create a networkX graph.
        G = nx.DiGraph()
        G.add_nodes_from(vertices)
        edge_list = []
        for (source, _), (target, _), connection in zip(vertices, vertices[1:], connections):
            edge_attrs = dict(connection[0])
            edge_attrs.pop("$dup", None)
            edge_constraints = connection[2]
            if edge_constraints:
                self.constraints.setdefault((source, target), {}).update(edge_constraints)
            # connection[1] is the determinism flag, unused for now.
            edge_list.append((source, target, edge_attrs))

        # more complex version - duplications:
        # create a recursive function that adds the vertices and edges,
        # that calls itself by the number of duplications on each level.
        G.add_edges_from(edge_list)
        return G

    # def patterns(self, g, *graphs):
    #     patterns = list(graphs)
    #     patterns.insert(0,g)
    #     # unite all the patterns into a single graph
    #     G = nx.DiGraph()

    #     combined_attributes = dict() # dict of dicts (node_name -> attribute -> value)
    #     new_nodes = []
    #     new_edges = []
    #     for graph in patterns:
    #         for node in graph.nodes:
    #             combined_attributes[node] = combined_attributes[node] | graph.nodes.data()[node]
    #             new_nodes.append(node) #unite the dicts for each
    #         for edge in graph.edges:
    #             combined_attributes[edge[0]+","+edge[1]] = combined_attributes[edge[0]+","+edge[1]] | graph.edges[edge[0],edge[1]]
    #             new_edges.append(edge)

    #     G.add_nodes_from([(node, combined_attributes[node]) for node in new_nodes])
    #     G.add_edges_from([(node1, node2, combined_attributes[node1+","+node2]) for (node1,node2) in new_edges])
    #     return G

## Type and constant value checking
The transformer is designed to collect the node type and constant node value constraints, such that they are added to the 'condition' parameter to be checked later.

This transformer works on a copy of the tree to keep it intact.

In [None]:
#| export

# lark merge transformers
from graph_rewrite.result_set import Match
class collectTypeConstraints(Transformer):
    """Work-in-progress transformer that is meant to collect the type and
    constant-value constraints of an LHS pattern into a condition function.

    Most callbacks are still placeholders: they return ``None``, which is the
    value lark then substitutes for the corresponding branch. Each callback is
    defined exactly once so that no definition silently shadows another.
    """

    def attribute(self, args): #(attr_name, *rest):
        """Placeholder for a single attribute."""
        # TODO: return a mapping from attr_name -> required type and value
        pass

    def attributes(self, args): # a list of tuples
        """Return the packed list of attribute mappings (empty if no attributes)."""
        return args

    def multi_connection(self, args): # +
        """Placeholder for a connection with a duplication count."""
        # TODO: not implemented yet
        pass

    def connection(self, attributes):
        """Return ``(attributes, True)``: the packed attributes with ``"$dup"`` and the determinism flag.

        ``attributes`` is the list of transformed children; any dict among
        them is merged over the default duplication count of 1.
        """
        packed = {"$dup": 1}
        for child in attributes:
            if isinstance(child, dict):
                packed.update(child)
        return (packed, True)

    def ANONYMUS(self, _token=None):
        """Return a dedicated ``("$<id>", [])`` pair for an anonymous vertex."""
        # NOTE(review): this advances the same module-level counter as
        # lhsTransformer, so the generated names will presumably differ from
        # the ones used in the graph - confirm the intended behaviour.
        global cnt
        cnt += 1
        return ("$" + str(cnt), [])

    def index_vertex(self, args):
        """Return the main name of the vertex and the list of indices specified."""
        main_name_tup, *numbers = args #numbers is a list
        return (main_name_tup[0], list(numbers))
    
    def NAMED_VERTEX(self, name):
        """Return the vertex name together with an empty index list."""
        return (name, [])

    def vertex(self, args): # (vertex_tuple: tuple, attributes: dict = {})
        """Placeholder for a vertex; only unpacks its children for now."""
        vertex_tuple, *attributes = args # attributes is an empty list / a list containing the attributes
        name, indices_list = vertex_tuple
        # TODO: same as lhsTransformer
        pass

    def pattern(self, vertex, *connections_to_vertex):
        """Placeholder for a pattern."""
        # NOTE: unless the callback is wrapped with v_args(inline=True), lark
        # passes the whole children list as the first positional argument.
        # TODO: return arguments
        pass

    def patterns(self, *patterns):
        """Return the condition function built from the collected constraints."""
        # unpack lists of vertices and connections.
        def typeCondition(match):
            # NOTE: not implemented yet - currently returns None for every match.
            # for every vertex in vertex list:
                # create full_vertex_name by the attached indices list
                # for every attr, type, name required for the vertex:
                    # constructor = getName(type) - get the constructor for the type
                    # 1) check that the required type and value match together.
                    # try:
                    #     instance = constructor(value)
                    # Except:
                        # flag = False: value does not match the type.

                    # 2) check that the value constraint holds
                    # if getattr(instance, __eq__) == None:
                        # flag = False. the type must implement __eq__
                    # if not (instance == match[full_vertex_name][attr])
                        # flag = false

                    # no need to check the type constraint(?), if the value fits. (python)

            # TODO: perform the same iterations in the connections list.

            #return flag and condition(match)
            pass

        return typeCondition #sent as a module output and replaces condition.

Apply the Transformers

In [None]:
#| export
def lhs_to_graph(lhs):
    """Parse the LHS pattern string and return the networkX graph built from it."""
    parse_tree = lhs_parser.parse(lhs)
    return lhsTransformer().transform(parse_tree)

### Tests

#### Grammar Testing

In [None]:
# tree = lhs_parser.parse("aaaaa->b") #.pretty(indent_str = " $ ")
# assert(tree != None)
# assert(tree)
# assert(tree == 
#     Tree(Token('RULE', 'patterns'),[
#       Tree(Token('RULE', 'pattern'),[
#         Tree(Token('RULE', 'vertex'),[
#           Tree(Token('RULE', 'vertex'),[]),
#           None
#         ])
#       ])
#     ]))

#### Graph Testing

In [None]:
def create_graph(nodes, edges=[]):
    """Build a directed graph holding the given nodes and (optionally) edges."""
    graph = nx.DiGraph()
    # Nodes first, then edges - the same order as the construction in the parser.
    graph.add_nodes_from(nodes)
    graph.add_edges_from(edges)
    return graph

def graphs_equal(graph1, graph2):
    """Return True when both graphs agree on adjacency, nodes, edges and graph-level data."""
    # Compared in this order; the first mismatch short-circuits the check.
    compared_views = ("adj", "nodes", "edges", "graph")
    return all(
        getattr(graph1, view) == getattr(graph2, view)
        for view in compared_views
    )

#### Basic Connections

In [None]:
# Each case: (LHS string, expected nodes, expected edges).
basic_cases = [
    ("a", ['a'], []),
    ("a->b", ['a','b'], [('a','b')]),
    ("a -> b", ['a','b'], [('a','b')]),
    ("a->b -> c", ['a','b','c'], [('a','b'),('b','c')]),
]

for lhs_text, nodes, edges in basic_cases:
    res = lhs_to_graph(lhs_text)
    expected = create_graph(nodes, edges)
    assert(graphs_equal(expected, res))

#### Attributes

In [None]:
# Attribute lists on vertices: every attribute named in the LHS is expected to
# show up as a node attribute of the resulting graph.
# NOTE(review): the expected graphs use "default" as the attribute value, while
# lhsTransformer.attributes stores None for every attribute name - confirm
# which placeholder is intended before relying on these assertions.
res = lhs_to_graph("a<1,2>[x=5, y: int = 6]")
expected = create_graph(['a<1,2>'])
expected.nodes["a<1,2>"]["x"] = "default"
expected.nodes["a<1,2>"]["y"] = "default"
assert(graphs_equal(res,expected))
# print(res.nodes)
# print(res.nodes.data())
# print(res.edges)
# print("------")
# print(expected.nodes)
# print(expected.nodes.data())
# print(expected.edges)

# A chain in which every vertex carries a single untyped attribute.
res = lhs_to_graph("a[a]->b[ b ] -> c[ c ]")
expected = create_graph(['a','b','c'], [('a','b'),('b','c')])
expected.nodes["a"]["a"] = "default"
expected.nodes["b"]["b"] = "default"
expected.nodes["c"]["c"] = "default"
assert(graphs_equal(expected, res))

#### multiple patterns

In [None]:
# Several ';'-separated patterns are expected to be merged into one graph.
# NOTE(review): the active grammar starts at `pattern` and defines no ';'
# separator or `patterns` rule (only the commented-out grammar does), so these
# inputs are presumably rejected by the parser - TODO confirm / re-enable
# together with the `patterns` rule.

# Two patterns sharing the vertex c.
res = lhs_to_graph("a->b -> c; c-> d")
expected = create_graph(['a','b','c','d'], [('a','b'),('b','c'),('c','d')])
assert(graphs_equal(expected, res))

# A second pattern that only adds an isolated vertex.
res = lhs_to_graph("a->b -> c; d")
expected = create_graph(['a','b','c', 'd'], [('a','b'),('b','c')])
assert(graphs_equal(expected, res))

# A second pattern that only adds an attribute to an existing vertex.
res = lhs_to_graph("a->b -> c; c[x=5]")
expected = create_graph(['a','b','c'], [('a','b'),('b','c')])
expected.nodes["c"]["x"] = "default"
assert(graphs_equal(expected, res))

In [None]:
# required_syntax =  """
# a -> b

# a -[x:int = ...]-> b

# a -> b[x:int = ...]

# a -> b -6+[weight:int]-> c -> d[value:int]
# d<0> -> e
# d<5> -> e

# b -+-> d[value:int]
# d<0> -7-> e
# e<0,5> -> _

# b[ \
# value: str = \"hello\", \
# id: int \
# ]

# b -[
# ...
# ]-> c 

# """