# Match Class

In [52]:
#| default_exp match_class

In [53]:
#| hide
from nbdev.showdoc import show_doc
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [54]:
from enum import Enum
from itertools import product
from collections import defaultdict

### Overview
The following module defines the **Match class**, that will be used by the Matcher module later. This class provides a subview to the graph which we attempt to transform, and is created based on some match from the LHS pattern to a subgraph of the input graph (Further explanations on the meaning of matches are included in the next module).

This subview is set such that read-write operations can be done on nodes and edges imperatively, based on their symbolic name (the node/edge they match in the pattern) rather than their actual name in the input graph. That is, we can access nodes and edges in the input graph which were matched to the pattern (and only them), by using their corresponding names in the pattern. 

For each match found, the user of the library will receive its corresponding Match object, which can be used for imperative side effects.

### Requirements

In [55]:
#| export
from itertools import product
from typing import *

import networkx as nx
from networkx import DiGraph

from graph_rewrite.core import _create_graph, draw, GraphRewriteException, NodeName, EdgeName

### Example
Assume that we have an input graph $G$:

In [56]:
draw(_create_graph(['A','B','C','D','E'], [('A','B'), ('A','C'),('D','E')]))

And a pattern which looks like:

In [57]:
draw(_create_graph(['1','2','3'], [('1','2'), ('1','3')]))

Intuitively, we can see that the pattern can be found in the input graph $G$, such that the pattern node $1$ corresponds to the graph node $A$, and the same goes for $2$ and $B$, $3$ and $C$. That mapping defines a match of the pattern in graph $G$ (We will dive into the definition of a match in the next module).

The corresponding Match object will allow us, for example, to change an attribute of node $A$ in the input graph, by accessing the symbolic name $1$ (which matches to $A$) and setting it to some dictionary of attributes. Note that the user will be able to access this class only after the transformation is done - therefore, if the transformation removed $A$ from the graph, the user won't be able to access the symbolic name $1$ anymore.

### Utils
Edge $(u,v)$ is represented in a Match with the name "u->v".

In [58]:
#| export
def convert_to_edge_name(src: NodeName, dest: NodeName) -> str:
    """Build the symbolic name of the edge going from `src` to `dest`.

    The name has the form "{src}->{dest}", which is the format the parser
    uses when it creates edges in the pattern graph.

    Args:
        src (NodeName): Name of the edge's source node
        dest (NodeName): Name of the edge's destination node

    Returns:
        str: The name representing the edge (src, dest).
    """
    edge_name_template = "{}->{}"
    return edge_name_template.format(src, dest)

A node is **anonymous** if its name begins with '_'. This notion allows us to define patterns with anonymous nodes: nodes whose existence we want to enforce, but which we do not use in the RHS part of the pattern, so we can ignore them by not assigning any symbolic name to them.

In [59]:
#| export
def is_anonymous_node(node_name: NodeName) -> bool:
    """Tell whether a pattern node is anonymous, i.e. its name begins with '_'.

    Args:
        node_name (NodeName): A node name in the pattern

    Returns:
        bool: True if the name is non-empty and its first character is '_',
        False otherwise.
    """
    # An empty name has no first character, so it can never be anonymous.
    if len(node_name) == 0:
        return False
    first_char = node_name[0]
    return first_char == '_'

In [60]:
#| export
class NodeAttributeAccessor:
    """
    This class acts as a wrapper for a set of nodes, allowing access to specific attributes.

    Each wrapped node may be an attribute mapping (e.g. a dict of attributes), in
    which case the attribute is looked up as a key, or any other object, in which
    case it is looked up with `getattr`.
    """
    def __init__(self, nodes):
        self.nodes = nodes

    def __getitem__(self, attribute):
        """
        Return a list of the requested attribute from each node in the set

        Nodes which do not have the requested attribute are skipped, so the result
        may be shorter than the wrapped collection.

        Args:
            attribute: The attribute to retrieve from each node

        Returns:
            List: A list of the requested attribute from each node in the set
        """
        # Imported locally so the class does not depend on extra module-level imports.
        from collections.abc import Mapping

        values = []
        for node in self.nodes:
            if isinstance(node, Mapping):
                # Attribute dicts store attributes as keys; hasattr/getattr would
                # only see dict methods and never the stored attributes.
                if attribute in node:
                    values.append(node[attribute])
            elif hasattr(node, attribute):
                values.append(getattr(node, attribute))
        return values

class EdgeAttributeAccessor:
    """
    This class acts as a wrapper for a set of edges, allowing access to specific attributes.

    Each wrapped edge may be an attribute mapping (e.g. a dict of attributes), in
    which case the attribute is looked up as a key, or any other object, in which
    case it is looked up with `getattr`.
    """
    def __init__(self, edges):
        self.edges = edges

    def __getitem__(self, attribute):
        """
        Return a list of the requested attribute from each edge in the set

        Edges which do not have the requested attribute are skipped, so the result
        may be shorter than the wrapped collection.

        Args:
            attribute: The attribute to retrieve from each edge

        Returns:
            List: A list of the requested attribute from each edge in the set
        """
        # Imported locally so the class does not depend on extra module-level imports.
        from collections.abc import Mapping

        values = []
        for edge in self.edges:
            if isinstance(edge, Mapping):
                # Attribute dicts store attributes as keys; hasattr/getattr would
                # only see dict methods and never the stored attributes.
                if attribute in edge:
                    values.append(edge[attribute])
            elif hasattr(edge, attribute):
                values.append(getattr(edge, attribute))
        return values

### The Match Class

A **Match** is a subview of the original graph, limited to the nodes, edges and attributes specified in the pattern. It includes the original graph, dictionaries which map nodes/edges to their corresponding attributes in the pattern, and the mapping from pattern nodes to real ones.

In [61]:
#| export
class Match:
    """Represents a single match of a pattern inside an input graph.

    Provides a subview to a graph, limited to the nodes, edges and attributes specified in the pattern.
    Nodes and edges are addressed by their symbolic (pattern) names; the objects returned are looked
    up in `self.graph` on every access.

    Defining `__eq__` without `__hash__` makes instances unhashable.
    """
    def __init__(self, graph: DiGraph, nodes: List[NodeName], edges: List[EdgeName], 
                 mapping: Dict[NodeName, Set[NodeName]], collection_nodes: List[NodeName]):
        """
        Args:
            graph (DiGraph): The input graph this match refers to
            nodes (List[NodeName]): The pattern nodes exposed by this match
            edges (List[EdgeName]): The pattern edges exposed by this match, as (src, dst) pairs
            mapping (Dict[NodeName, Set[NodeName]]): For each pattern node, the set of input nodes it is mapped to
            collection_nodes (List[NodeName]): The pattern nodes which denote collections of input nodes
        """
        self.graph: DiGraph = graph
        self._nodes: List[NodeName] = nodes
        self._edges: List[EdgeName] = edges
        # Node names and edges can represent either single nodes or collections of nodes,
        # so each node name is mapped to a set of input nodes.
        self.mapping: Dict[NodeName, Set[NodeName]] = mapping
        #TODO: Show Dean - I think it is needed to distinguish between single pattern nodes with a single match and a collection pattern node with a single match
        # If we only check if there is only one input node, we can't distinguish between a single pattern node with a single match and a collection pattern node with a single match
        self._collection_nodes: Set[NodeName] = set(collection_nodes)

    def _check_node_in_pattern(self, pattern_node: NodeName):
        """Raise GraphRewriteException if the node is not one of the pattern nodes of this match."""
        if pattern_node not in self._nodes:
            raise GraphRewriteException(f"Node {pattern_node} does not exist in the pattern")

    def _check_edge_in_pattern(self, pattern_src: NodeName, pattern_dst: NodeName):
        """Raise GraphRewriteException if (pattern_src, pattern_dst) is not one of the pattern edges of this match."""
        if (pattern_src, pattern_dst) not in self._edges:
            raise GraphRewriteException(f"Edge {(pattern_src, pattern_dst)} does not exist in the pattern")

    def _is_collection(self, pattern_node: NodeName) -> bool:
        """Return True if the pattern node was declared as a collection node."""
        return pattern_node in self._collection_nodes

    def _is_single(self, pattern_node: NodeName) -> bool:
        """Return True if the pattern node was not declared as a collection node."""
        return not self._is_collection(pattern_node)

    def __get_node(self, pattern_node):
        """
        Returns the node or the collection of nodes mapped to the pattern node

        Args:
            pattern_node: The node in the pattern

        Raises:
            GraphRewriteException: If the node is not one of the pattern nodes of this match.

        Returns:
            The node or the collection of nodes mapped to the pattern node
        """
        self._check_node_in_pattern(pattern_node)
        input_nodes = self.mapping[pattern_node]

        if self._is_single(pattern_node):
            # A single pattern node is expected to be mapped to exactly one input node.
            return self.graph.nodes[list(input_nodes)[0]]
        return [self.graph.nodes[input_node] for input_node in input_nodes]

    def __get_edge(self, pattern_src, pattern_dst):
        """
        Returns the edge or the collection of edges mapped to the pattern edge

        Args:
            pattern_src: The source node in the pattern
            pattern_dst: The destination node in the pattern

        Raises:
            GraphRewriteException: If the edge is not one of the pattern edges of this match.

        Returns:
            The edge or the collection of edges mapped to the pattern edge
        """
        self._check_edge_in_pattern(pattern_src, pattern_dst)
        input_src_nodes = self.mapping[pattern_src]
        input_dst_nodes = self.mapping[pattern_dst]
        if self._is_collection(pattern_src) or self._is_collection(pattern_dst):
            # One lookup per (src, dst) pair in the Cartesian product of the mapped nodes.
            # NOTE(review): a pair with no edge in the graph makes the lookup fail - confirm this is intended.
            return [self.graph.edges[input_src_node, input_dst_node]
                    for input_src_node, input_dst_node in product(input_src_nodes, input_dst_nodes)]
        return self.graph.edges[list(input_src_nodes)[0], list(input_dst_nodes)[0]]

    def nodes(self):
        """Return a dictionary from each pattern node to the node(s) it is mapped to in the graph."""
        return {pattern_node: self.__get_node(pattern_node) for pattern_node in self._nodes}

    def edges(self):
        """Return a dictionary from each pattern edge name ("src->dst") to the edge(s) it is mapped to in the graph."""
        return {convert_to_edge_name(pattern_src, pattern_dest): self.__get_edge(pattern_src, pattern_dest)
                for (pattern_src, pattern_dest) in self._edges}

    def set_graph(self, graph: DiGraph):
        """Replace the graph this match refers to."""
        self.graph = graph

    def __eq__(self, other):
        """Two matches are equal if they map the same pattern nodes to the same input nodes."""
        if not isinstance(other, Match):
            return False
        return self.mapping == other.mapping

    def __getitem__(self, key: Union[NodeName, str]):
        """
        Returns the node/edge (single or collections) of the input graph, which was mapped by the key in the pattern during matching.
        Supports nested access when the result is a set of nodes or edges (i.e. a collection).

        Args:
            key (Union[NodeName, str]): A symbolic name used by the pattern (for a node / edge)

        Raises:
            GraphRewriteException: If the key doesn't exist in the pattern, or is mapped to a node/edge
            which does not exist anymore (due to removal by the transformation, for example).

        Returns:
            The corresponding node/edge of the input graph, or a list of 'name' attributes if requested.
            If the result is a set of nodes or edges (if it is a collection), returns a NodeAttributeAccessor or EdgeAttributeAccessor respectively, 
            to allow access to specific attributes of each node/edge in the collection.
        """
        try:
            # A key of the form "node1->node2" (exactly one "->") denotes an edge
            end_nodes = str(key).split("->")
            if len(end_nodes) == 2:
                pattern_src, pattern_dst = end_nodes
                edge_or_edges = self.__get_edge(pattern_src, pattern_dst)

                # If the edge is connected to a collection of nodes, wrap the edges in an accessor
                if self._is_collection(pattern_src) or self._is_collection(pattern_dst):
                    return EdgeAttributeAccessor(edge_or_edges)
                return edge_or_edges

            # Otherwise, assume it's for a node
            node_or_nodes = self.__get_node(key)

            # A collection is wrapped in an accessor, which creates a list of attributes for each node in the set
            if self._is_collection(key):
                return NodeAttributeAccessor(node_or_nodes)
            return node_or_nodes
        except Exception as err:
            # Any lookup failure is reported uniformly; the original error is kept as the cause.
            # Only Exception is caught, so KeyboardInterrupt / SystemExit still propagate.
            raise GraphRewriteException(f"The symbol {key} does not exist in the pattern, or it was removed from the graph") from err

    def __str__(self):
        return str(self.mapping)

### Create a Match
This function will be used in the Matcher module, in order to convert a mapping from the pattern graph to the input graph into a corresponding instance of the Match class.

In [62]:
#| export
def mapping_to_match(input: DiGraph, pattern: DiGraph, collections_pattern: DiGraph, mapping: Dict[NodeName, Set[NodeName]],
                      node_type_mapping: Dict[NodeName,NodeType], filter: bool=True) -> Match:
    """Given a mapping, which denotes a match of the pattern in the input graph,
    create a corresponding instance of the Match class.

    Args:
        input (DiGraph): An input graph
        pattern (DiGraph): A pattern graph
        collections_pattern (DiGraph): Not read by this function; kept in the signature for its callers
        mapping (Dict[NodeName, Set[NodeName]]): A mapping of nodes in the pattern to nodes in the input graph
        node_type_mapping (Dict[NodeName, NodeType]): A mapping of nodes in the pattern to their type (single or collection)
        filter (bool): If True, filter out anonymous nodes from the match

    Returns:
        Match: A corresponding instance of the Match class
    """
    def is_kept(pattern_node) -> bool:
        # Anonymous nodes are dropped from the Match only when filtering is requested
        return not (filter and is_anonymous_node(pattern_node))

    # A new dictionary is built, so the caller's mapping is left untouched
    cleared_mapping = {pattern_node: input_nodes
                       for pattern_node, input_nodes in mapping.items() if is_kept(pattern_node)}
    nodes_list = list(cleared_mapping)

    # An edge is dropped as soon as one of its endpoints is dropped
    edges_list = [(n1, n2) for (n1, n2) in pattern.edges if is_kept(n1) and is_kept(n2)]

    # Match expects the names of the collection nodes, not the whole type mapping:
    # handing it the dictionary itself would mark every pattern node as a collection.
    # A node type is either single or collection, so anything which is not SINGLE is a collection.
    collection_nodes = [pattern_node for pattern_node, node_type in node_type_mapping.items()
                        if node_type != NodeType.SINGLE]

    return Match(input, nodes_list, edges_list, cleared_mapping, collection_nodes)

### Tests
We take the example graph and pattern described above, add a few attributes to the nodes and edges, and create a Match instance based on the single mapping:

In [63]:
# Input graph G: five nodes, each carrying a 'name' attribute, and three edges,
# of which only (A, B) carries an attribute ('edge_attr').
G = _create_graph(
    [('A', {'name': 'A'}),('B', {'name': 'B'}),('C', {'name': 'C'}),('D', {'name': 'D'}),('E', {'name': 'E'})],
    [('A','B', {'edge_attr': 10}), ('A','C'),('D','E')])
# Pattern: node '1' with outgoing edges to nodes '2' and '3'.
pattern = _create_graph(['1','2','3'], [('1','2'), ('1','3')])
# It has a single match, which is defined by the following mapping:
mapping = {'1': {'A'}, '2': {'B'}, '3': {'C'}}
# Every pattern node in this example is a single (non-collection) node.
# NOTE(review): NodeType is not imported or defined in the cells shown here - confirm where it comes from.
node_type_mapping = {'1': NodeType.SINGLE, '2': NodeType.SINGLE, '3': NodeType.SINGLE}
# An empty DiGraph is passed as collections_pattern, which mapping_to_match does not read.
mapping_match = mapping_to_match(G, pattern, DiGraph(), mapping, node_type_mapping)

In graph $G$, each node has a single attribute - "name" - whose value is the node's name. Therefore, node $A$ has the attribute "name" with the value "A". We can access its attributes using the Match, by accessing the name of the corresponding matched node $1$:

In [64]:
assert mapping_match['1'] == {'name': 'A'}
assert mapping_match['1']['name'] == 'A'
mapping_match['1']

{'name': 'A'}

This indicates that as expected, pattern node $1$ is matched in this Match instance to the input graph node $A$, as that node is the only one in the graph whose "name" attribute is equal to "A".

We can access the edges similarly, by using the format {src}->{dst} for edge $(src, dst)$:

In [65]:
assert mapping_match['1->2'] == {'edge_attr': 10}
mapping_match['1->2'] # accesses the attributes of the corresponding edge A->B in G

{'edge_attr': 10}

Say we want to add an attribute to $A$, we can change it in the same way using the Match:

In [66]:
mapping_match['1']['attr'] = 5
assert mapping_match['1'] == {'name': 'A', 'attr': 5}
mapping_match['1']

{'name': 'A', 'attr': 5}

And then see that the attributes of $A$ in the original graph $G$ have changed, as the Match is a subview of that graph, and so changes in the subview are reflected in $G$:

In [67]:
assert G.nodes(data=True)['A'] == {'name': 'A', 'attr': 5}
G.nodes(data=True)['A']

{'name': 'A', 'attr': 5}

We can also modify existing attributes in the same way using the Match, and see those changes reflected in $G$

In [68]:
mapping_match['2']['name'] = 'B*'
assert mapping_match['2']['name'] == 'B*'
assert G.nodes(data=True)['B']['name'] == 'B*'
G.nodes(data=True)['B']

{'name': 'B*'}

This also works in the reverse direction: changes in $G$ are reflected in the subview. Say that we set an attribute for edge $(A,B)$ in $G$, then that change would be reflected by accessing the edge '1->2' in the Match object that refers to graph $G$:

In [69]:
G.edges()[('A','B')]['attr2'] = 20
assert mapping_match['1->2']['attr2'] == 20
mapping_match['1->2']

{'edge_attr': 10, 'attr2': 20}

## Visualizing matches

In [70]:
list(G.nodes(data=True)),list(G.edges(data=True))

([('A', {'name': 'A', 'attr': 5}),
  ('B', {'name': 'B*'}),
  ('C', {'name': 'C'}),
  ('D', {'name': 'D'}),
  ('E', {'name': 'E'})],
 [('A', 'B', {'edge_attr': 10, 'attr2': 20}), ('A', 'C', {}), ('D', 'E', {})])

In [71]:
mapping_match.nodes(),mapping_match._edges

({'1': {'name': 'A', 'attr': 5}, '2': {'name': 'B*'}, '3': {'name': 'C'}},
 [('1', '2'), ('1', '3')])

In [72]:
mapping_match.mapping

{'1': {'A'}, '2': {'B'}, '3': {'C'}}

In [73]:
# We want to color the nodes and edges that are part of the match in blue, and specify the name of the pattern node
# Also, if more than one pattern node is mapped to the same input node, we want to color them in red, and specify the names of the pattern nodes
# and also notify the user (as a warning/notification) 
# To do that, we need to use a reverse mapping, where the key is the input node and the value is the pattern node
# TODO: Show Dean the idea in the meeting notes

def draw_match(g, m, **kwargs):
    """Draw graph `g` with the nodes and edges of match `m` highlighted.

    The drawing is done on a copy of `g` (obtained with `g.copy()`), on which a
    'label' attribute is set for every matched node.

    Args:
        g: The input graph
        m: A match whose `mapping` maps pattern nodes to sets of nodes of `g`,
            and whose `_edges` lists the pattern edges
        **kwargs: Passed through to `draw`
    """
    g_copy = g.copy()
    node_styles = {}
    edge_styles = {}

    # Reverse mapping with defaultdict for handling multiple pattern nodes per input node
    # TODO: We still need to decide how to handle multiple pattern nodes per input node
    reverse_mapping = defaultdict(list)
    for pattern_node, input_nodes in m.mapping.items():
        for input_node in input_nodes:
            reverse_mapping[input_node].append(pattern_node)

    # Handle node styling
    for input_node, pattern_nodes in reverse_mapping.items():
        if len(pattern_nodes) > 1: # Several pattern nodes are mapped to this input node
            node_styles[input_node] = 'stroke:red,stroke-width:4px;'
            # List every pattern node in the label; assigning the label once per
            # pattern node would keep only the last one.
            pattern_names = ",".join(str(pattern_node) for pattern_node in pattern_nodes)
            g_copy.nodes[input_node]['label'] = f"{pattern_names}_{input_node}"
            #TODO: remove the warning, we just want to see a list of all the pattern nodes that are mapped to the input node
            print(f"Warning: Node {input_node} is mapped to multiple pattern nodes: {pattern_nodes}")
        else: # Exactly one pattern node is mapped to this input node
            node_styles[input_node] = 'stroke:blue,stroke-width:4px;'
            g_copy.nodes[input_node]['label'] = pattern_nodes[0]

    # Style every (src, dst) pair of input nodes matched by the endpoints of a pattern edge
    for u, v in m._edges:
        edge_styles.update({(input_u, input_v): 'stroke:blue,stroke-width:4px;'
                            for input_u, input_v in product(m.mapping.get(u, []), m.mapping.get(v, []))})

    draw(g_copy, node_styles=node_styles, edge_styles=edge_styles, **kwargs)


In [74]:
draw_match(G,mapping_match)

# Export

In [75]:
#|hide
import nbdev; nbdev.nbdev_export()