# Match Class

In [1]:
#| default_exp match_class

In [2]:
#| hide
from nbdev.showdoc import show_doc
%load_ext autoreload
%autoreload 2

### Overview
The following module defines the **Match class**, which is used later by the Matcher module. This class provides a subview of the graph we attempt to transform, and is created from a match of the LHS pattern to a subgraph of the input graph (the next module explains the meaning of matches in more detail).

This subview is set such that read-write operations can be done on nodes and edges imperatively, based on their symbolic name (the node/edge they match in the pattern) rather than their actual name in the input graph. That is, we can access nodes and edges in the input graph which were matched to the pattern (and only them), by using their corresponding names in the pattern. 

For each match found, the user of the library will receive its corresponding Match object, which can be used for imperative side effects.

### Requirements

In [3]:
#| export
import networkx as nx
from networkx import DiGraph
from typing import *
from graph_rewrite.core import _create_graph, draw, GraphRewriteException, NodeName, EdgeName
from itertools import product
from collections import defaultdict
import logging

### Logger ###
This logger is used for debugging purposes.
By default, it prints warnings and errors to the screen, but it can be configured to redirect the output to a file.
In this module it warns when the same input node is matched to different pattern nodes, which might happen when using the collections feature.

In [4]:
#| export
logger = logging.getLogger(__name__)    
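Redirecting the output to a file could look like the minimal sketch below. The logger name `graph_rewrite.match_class` is an assumption about how this notebook is exported (package name plus the `default_exp` value), and the log path is hypothetical.

```python
import logging
import os
import tempfile

# Assumption: the exported module is importable as graph_rewrite.match_class,
# so logging.getLogger(__name__) inside it resolves to this name.
match_logger = logging.getLogger("graph_rewrite.match_class")

# Hypothetical log location; any writable path works.
log_path = os.path.join(tempfile.mkdtemp(), "match_class.log")

handler = logging.FileHandler(log_path)
handler.setLevel(logging.WARNING)
handler.setFormatter(logging.Formatter("%(levelname)s:%(name)s:%(message)s"))
match_logger.addHandler(handler)
match_logger.propagate = False  # keep these records away from the root logger's handlers

# Emit a record shaped like the collision warning and read it back from the file.
match_logger.warning("Input node a is mapped to multiple pattern nodes")
handler.flush()

with open(log_path) as f:
    log_text = f.read()
```

Attaching the handler to the module's logger rather than the root logger leaves the logging setup of the rest of the application untouched.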

### Example
Assume that we have an input graph $G$:

In [5]:
draw(_create_graph(['A','B','C','D','E'], [('A','B'), ('A','C'),('D','E')]))

And a pattern which looks like:

In [6]:
draw(_create_graph(['1','2','3'], [('1','2'), ('1','3')]))

Intuitively, we can see that the pattern can be found in the input graph $G$, such that the pattern node $1$ corresponds to the graph node $A$, and the same goes for $2$ and $B$, $3$ and $C$. That mapping defines a match of the pattern in graph $G$ (We will dive into the definition of a match in the next module).

The corresponding Match object allows us, for example, to change an attribute of node $A$ in the input graph by accessing the symbolic name $1$ (which is matched to $A$) and setting the attribute there. Note that the user will be able to access this class only after the transformation is done; therefore, if the transformation removed $A$ from the graph, the user won't be able to access the symbolic name $1$ anymore.
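The idea of addressing input nodes through symbolic names can be sketched with plain dictionaries. This is only an illustration: `graph_nodes`, `mapping` and `lookup` are hypothetical stand-ins, the `color` attribute is made up, and the real class maps each pattern node to a set of input nodes rather than to a single one.

```python
# Stand-in for the input graph's node attributes (plain dicts instead of a DiGraph).
graph_nodes = {"A": {"color": "red"}, "B": {}, "C": {}, "D": {}, "E": {}}

# The match from the example: pattern node -> input node.
mapping = {"1": "A", "2": "B", "3": "C"}

def lookup(symbolic_name):
    """Return the attribute dict of the input node matched by a pattern node."""
    return graph_nodes[mapping[symbolic_name]]

# Writing through the symbolic name '1' changes node 'A' of the input graph.
lookup("1")["color"] = "blue"

# Nodes 'D' and 'E' are not part of the match, so no symbolic name reaches them.
unmatched = sorted(set(graph_nodes) - set(mapping.values()))
```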

### Utils
Edge $(u,v)$ is represented in a Match with the name "u->v".

In [7]:
#| export
def _convert_to_edge_name(src: NodeName, dest: NodeName) -> str:
    """Given a pair of node names, source and destination, return the name of the edge
    connecting the two in the format {src}->{dest}, which is the same format the parser
    uses to create edges in the pattern graph.

    Args:
        src (NodeName): A node name
        dest (NodeName): A node name

    Returns:
        str: A representative name for the edge (src, dest).
    """
    return f"{src}->{dest}"
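A short sketch of the naming convention and its inverse follows. `edge_name` repeats the format used above, while `split_edge_name` is a hypothetical helper that is not part of the module.

```python
def edge_name(src, dest):
    # Same "{src}->{dest}" format as _convert_to_edge_name.
    return f"{src}->{dest}"

def split_edge_name(name):
    # Hypothetical inverse: recover the endpoints from an edge name.
    parts = name.split("->")
    if len(parts) != 2:
        raise ValueError(f"{name!r} is not of the form 'src->dest'")
    return parts[0], parts[1]

name = edge_name("x", "y")
endpoints = split_edge_name(name)
```

Because the separator is part of the name, this convention is only unambiguous when node names do not themselves contain `->`.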

A node is **anonymous** if its name begins with '_' (in this module, `is_anonymous_node` checks for the prefix `_anonymous_node_`). This notion allows us to define patterns with anonymous nodes: nodes whose existence we want to enforce, but which we do not use in the RHS part of the pattern, and which we can therefore ignore by not assigning them any symbolic name.

In [8]:
#| export
def is_anonymous_node(node_name: NodeName) -> bool:
    """Given a name of a node in the pattern graph, return True if it begins with the
    prefix '_anonymous_node_', which is the notation the parser uses to denote anonymous nodes.

    Args:
        node_name (NodeName): A node name in the pattern

    Returns:
        bool: Returns True if the node is anonymous, False otherwise.
    """
    return len(node_name) >= 1 and node_name.startswith('_anonymous_node_')
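As a small illustration, the prefix check can be used to separate named pattern nodes from anonymous ones. The node names below are made up for the example, and `is_anonymous` is a local copy of the check so that the snippet runs on its own.

```python
ANONYMOUS_PREFIX = "_anonymous_node_"  # prefix checked by is_anonymous_node

def is_anonymous(name):
    # Local stand-in for is_anonymous_node.
    return name.startswith(ANONYMOUS_PREFIX)

# Hypothetical pattern node names.
pattern_nodes = ["x", "_anonymous_node_0", "y", "_tmp"]

named = [n for n in pattern_nodes if not is_anonymous(n)]
anonymous = [n for n in pattern_nodes if is_anonymous(n)]
```

Note that a name such as `_tmp` is not treated as anonymous by this check, since only the full prefix counts.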

***Collection View***
The NodeCollectionView and EdgeCollectionView classes provide a way to retrieve the input nodes/edges mapped to a collection pattern node/edge, along with their attributes.

In [9]:
#| export
class NodeCollectionView:
    """
    This class acts as a wrapper for a set of node names, allowing to get information about the nodes
    matched to the collection pattern node in the input graph.
    """
    def __init__(self, input_nodes):
        """
        Initialize a NodeCollectionView object.

        Args:
            input_nodes (list): The attribute dictionaries of the input nodes matched to the collection pattern node.
        """
        self.input_nodes = input_nodes

    def __getitem__(self, attribute):
        """
        Return a list of the requested attribute from each node in the set.
        If the attribute is not found in a node, it is not included in the result, and no warning is raised.

        Args:
            attribute (str): The attribute to retrieve from each node.

        Returns:
            List: A list of the requested attribute from each node in the set.

        """
        result = []
        for node in self.input_nodes:
            if attribute in node:
                result.append(node[attribute])
        return result

    def __str__(self):
        """
        Return a user-friendly string representation of the NodeCollectionView object.
        This will be used when calling print().
        """
        node_names = [str(node) for node in self.input_nodes]
        return f"Collection nodes: {', '.join(node_names)}"
    
    def _get_nodes(self):
        """
        Return the nodes in the object.

        Returns:
            List: A list of the nodes in the object.
        """
        return list(self.input_nodes)

class EdgeCollectionView:
    """
    This class acts as a wrapper for a set of edges, allowing to get information about the edges
    matched to the collection pattern edge in the input graph.
    """
    def __init__(self, input_edges):
        """
        Initialize an EdgeCollectionView object.

        Args:
            input_edges (list): The attribute dictionaries of the input edges matched to the collection pattern edge.
        """
        self.input_edges = input_edges

    def __getitem__(self, attribute):
        """
        Return a list of the requested attribute from each edge in the set.
        If the attribute is not found in an edge, it is not included in the result, and no warning is raised.

        Args:
            attribute (str): The attribute to retrieve from each edge.

        Returns:
            List: A list of the requested attribute from each edge in the set.
        """
        result = []
        for edge in self.input_edges:
            if attribute in edge:
                result.append(edge[attribute])
        return result

    def __str__(self):
        """
        Return a user-friendly string representation of the EdgeCollectionView object.
        This will be used when calling print().
        """
        edge_names = [str(edge) for edge in self.input_edges]  
        return f"Collection edges: {', '.join(edge_names)}"
    
    def _get_edges(self):
        """
        Return the edges in the object.

        Returns:
            List: A list of the edges in the object.
        """
        return list(self.input_edges)
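The attribute lookup of both view classes can be illustrated with a small stand-in. `AttributeCollection` is a hypothetical class written for this example; it only mirrors the `__getitem__` logic shown above, and the attribute values are made up.

```python
class AttributeCollection:
    """Hypothetical stand-in mirroring the __getitem__ of the collection views."""

    def __init__(self, items):
        self.items = items  # list of attribute dictionaries

    def __getitem__(self, attribute):
        # Items lacking the attribute are skipped silently.
        return [item[attribute] for item in self.items if attribute in item]

view = AttributeCollection([{"name": "c", "weight": 3}, {"name": "d"}])

names = view["name"]      # both items have a name
weights = view["weight"]  # only the first item has a weight
colors = view["color"]    # no item has a color
```

One consequence of skipping missing attributes is that the returned list may be shorter than the collection, so its positions cannot be relied on to line up with specific nodes or edges.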

### The Match Class

A **Match** is a subview of the original graph, limited to the nodes, edges and attributes specified in the pattern. It includes the original graph, dictionaries which map nodes/edges to their corresponding attributes in the pattern, and the mapping from pattern nodes to real ones.

In [10]:
#| export

class Match:
    """
    Represents a single match of a pattern inside an input graph.
    Provides a subview to a graph, limited to the nodes, edges, and attributes specified in the pattern.
    """
    def __init__(self, input_graph: DiGraph, pattern_nodes: List[NodeName], single_nodes: set[NodeName], 
                 pattern_edges: List[EdgeName], mapping: Dict[NodeName, set[NodeName]], 
                 warn_on_collisions: bool = True):
        """
        Initialize the Match object.

        Args:
            input_graph (DiGraph): The input graph.
            pattern_nodes (List[NodeName]): The nodes in the pattern (including anonymous nodes).
            single_nodes (Set[NodeName]): The single nodes in the pattern (including anonymous nodes).
            pattern_edges (List[EdgeName]): The edges in the pattern (including anonymous edges).
            mapping (Dict[NodeName, Set[NodeName]]): The mapping from pattern nodes to input graph nodes.
            warn_on_collisions (bool): Whether to warn if there are collisions in the mapping.
        """
        self.graph = input_graph
        self._nodes = pattern_nodes
        self._single_nodes = single_nodes
        self._edges = pattern_edges
        self.mapping = mapping
        if warn_on_collisions:
            self._check_for_collisions()

    def _check_for_collisions(self) -> None:
        """
        Check if there are any collisions in the mapping - i.e., if the same input node is mapped to multiple
        pattern nodes. If there are, print a warning specifying the input nodes that are being mapped to multiple
        pattern nodes.
        """
        input_to_pattern_nodes = defaultdict(set)
        for pattern_node in self._nodes:
            for input_node in self.mapping[pattern_node]:
                input_to_pattern_nodes[input_node].add(pattern_node)
                if len(input_to_pattern_nodes[input_node]) > 1:
                    logger.warning(f"Input node {input_node} is mapped to multiple pattern nodes: {input_to_pattern_nodes[input_node]}") 
        return 
                
    def _check_node_in_pattern(self, pattern_node: NodeName):
        """Ensure that the pattern node exists in the pattern being matched."""
        if pattern_node not in self._nodes:
            raise GraphRewriteException(f"Node {pattern_node} does not exist in the pattern")

    def _check_edge_in_pattern(self, pattern_src: NodeName, pattern_dst: NodeName):
        """Ensure that the pattern edge exists in the pattern being matched."""
        if (pattern_src, pattern_dst) not in self._edges:
            raise GraphRewriteException(f"Edge {(pattern_src, pattern_dst)} does not exist in the pattern")

    def _check_node_in_graph(self, input_node: NodeName):
        """Ensure that the input node exists in the input graph."""
        if input_node not in self.graph.nodes:
            raise GraphRewriteException(f"Node {input_node} does not exist in the input graph")

    def _check_edge_in_graph(self, input_src: NodeName, input_dst: NodeName):
        """Ensure that the input edge exists in the input graph."""
        if (input_src, input_dst) not in self.graph.edges:
            raise GraphRewriteException(f"Edge {(input_src, input_dst)} does not exist in the input graph")

    def is_single(self, pattern_node: NodeName) -> bool:
        """Check if the pattern node is a single node (as opposed to a collection node)."""
        return pattern_node in self._single_nodes

    def __get_node(self, pattern_node):
        """
        Returns the node or the collection of nodes mapped to the pattern node.

        Args:
            pattern_node: The node in the pattern.

        Returns:
            The node or the collection of nodes mapped to the pattern node.
        """
        if pattern_node == '_':
            raise GraphRewriteException("Anonymous nodes cannot be accessed directly")

        self._check_node_in_pattern(pattern_node)
        input_nodes = list(self.mapping[pattern_node])

        # Ensure all input nodes exist in the graph
        for input_node in input_nodes:
            self._check_node_in_graph(input_node)

        # Return the single matched node if it's a single node
        if self.is_single(pattern_node):
            return self.graph.nodes[input_nodes[0]]

        # Otherwise, return the collection of nodes
        return [self.graph.nodes[input_node] for input_node in input_nodes]

    def __get_edge(self, pattern_src, pattern_dst):
        """
        Returns the edge or the collection of edges mapped to the pattern edge.

        Args:
            pattern_src: The source node in the pattern.
            pattern_dst: The destination node in the pattern.

        Returns:
            The edge (if both source and destination nodes are single nodes) or collection of edges (if collections).
        """
        self._check_edge_in_pattern(pattern_src, pattern_dst)
        input_src_nodes = self.mapping[pattern_src]
        input_dst_nodes = self.mapping[pattern_dst]

        # Return the single edge if both source and destination nodes are single nodes
        if self.is_single(pattern_src) and self.is_single(pattern_dst):
            input_src_node = list(input_src_nodes)[0]
            input_dst_node = list(input_dst_nodes)[0]
            self._check_edge_in_graph(input_src_node, input_dst_node)
            return self.graph.edges[(input_src_node, input_dst_node)]

        input_edges = []
        for (src, dst) in product(input_src_nodes, input_dst_nodes):
            if self.graph.has_edge(src, dst):
                input_edges.append(self.graph.edges[(src, dst)])  # Add matching edge to the collection

        if not input_edges:
            raise GraphRewriteException(f"No edges found between {pattern_src} and {pattern_dst}.")
        return input_edges

    def set_graph(self, graph: DiGraph):
        """Update the graph associated with this match."""
        self.graph = graph

    def remove_anonymous_nodes_and_edges(self):
        """Remove anonymous nodes and edges from the match."""
        for node in list(self._nodes): # Iterate over a copy, since items are removed during the loop
            if is_anonymous_node(node):
                self._nodes.remove(node)

        for edge in list(self._edges): # Iterate over a copy, since items are removed during the loop
            if is_anonymous_node(edge[0]) or is_anonymous_node(edge[1]):
                self._edges.remove(edge)

        for node in list(self.mapping.keys()):
            if is_anonymous_node(node):
                self.mapping.pop(node)

    def __eq__(self, other):
        """Check if two matches are equal based on their node mappings."""
        if isinstance(other, Match) and len(other.mapping) == len(self.mapping):
            return all(other.mapping.get(k) == v for k, v in self.mapping.items())
        return False

    def __getitem__(self, key: Union[NodeName, str]):
        """
        Access the node or edge (single or collection) from the input graph based on the pattern.

        Args:
            key (Union[NodeName, str]): The pattern node or edge name to access.

        Returns:
            The corresponding node(s) or edge(s).
        """
        try:
            # Handle edges (using the "node1->node2" format)
            if "->" in str(key) and len(str(key).split("->")) == 2:
                src, dst = key.split("->")
                edge_or_edges = self.__get_edge(src, dst)

                if self.is_single(src) and self.is_single(dst):
                    return edge_or_edges
                return EdgeCollectionView(edge_or_edges)  # set of edges

            # Handle nodes
            node_or_nodes = self.__get_node(key)

            if self.is_single(key):
                return node_or_nodes
            return NodeCollectionView(node_or_nodes)  # set of nodes

        except KeyError:
            raise GraphRewriteException(f"The symbol {key} does not exist in the pattern, or it was removed from the graph")

    def __str__(self):
        """Return a string representation of the match's node mapping."""
        return str(self.mapping)
    
    def get_pattern_nodes(self):
        """Return the pattern nodes."""
        return self._nodes
    
    def get_pattern_edges(self):
        """Return the pattern edges."""
        return self._edges
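How a collection pattern edge is resolved to input edges can be sketched on plain dictionaries. The sketch follows the same idea as `__get_edge` (take the Cartesian product of the mapped endpoints and keep the pairs that are actual edges). `graph_edges` and `collect_edges` are hypothetical stand-ins, and the endpoints are sorted here only to make the result order predictable, which the class itself does not do.

```python
from itertools import product

# Stand-in for the input graph's edges and their attributes.
graph_edges = {
    ("a", "b"): {},
    ("b", "c"): {"name": "(b,c)"},
    ("b", "d"): {"name": "(b,d)"},
}

# Pattern node -> set of matched input nodes; 'z' is a collection node.
mapping = {"x": {"a"}, "y": {"b"}, "z": {"c", "d"}}

def collect_edges(pattern_src, pattern_dst):
    """Return the attribute dicts of all input edges covered by a pattern edge."""
    found = []
    for src, dst in product(sorted(mapping[pattern_src]), sorted(mapping[pattern_dst])):
        if (src, dst) in graph_edges:  # keep only pairs that are real edges
            found.append(graph_edges[(src, dst)])
    return found

edges = collect_edges("y", "z")
```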

### Create a Match
This function will be used in the Matcher module, in order to convert a mapping from the pattern graph to the input graph into a corresponding instance of the Match class.

In [11]:
#| export
def mapping_to_match(input_graph: DiGraph, single_pattern: DiGraph, collections_pattern: DiGraph, 
                     mapping: Dict[NodeName, Set[NodeName]], warn_on_collisions: bool=True) -> Match:
    """
    Convert a given mapping (which represents a match between the pattern and the input graph) 
    into an instance of the Match class, which provides a subgraph view based on nodes and edges.

    Args:
        input_graph (DiGraph): The input graph where matches are found.
        single_pattern (DiGraph): A pattern graph containing the nodes that each match exactly one node in the input graph.
        collections_pattern (DiGraph): A pattern graph representing nodes that match multiple input nodes (collections).
        mapping (Dict[NodeName, Set[NodeName]]): The mapping of pattern nodes to corresponding input graph nodes.
        warn_on_collisions (bool): If True, log a warning when the same input node is mapped to multiple pattern nodes.

    Returns:
        Match: An instance of the Match class representing the match for the given mapping.
    """

    pattern_nodes = list(mapping.keys())

    single_nodes = set(single_pattern.nodes)

    pattern_edges = list(set(single_pattern.edges) | set(collections_pattern.edges))

    return Match(input_graph, pattern_nodes, single_nodes, pattern_edges, mapping, warn_on_collisions)
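The collision check triggered by `warn_on_collisions` can be sketched as a standalone function. `find_collisions` is a hypothetical helper that returns the collisions instead of logging them, which is what the module does; the mapping is made up for the example.

```python
from collections import defaultdict

def find_collisions(mapping):
    """Return {input_node: pattern nodes} for input nodes matched by more than one pattern node."""
    owners = defaultdict(set)
    for pattern_node, input_nodes in mapping.items():
        for input_node in input_nodes:
            owners[input_node].add(pattern_node)
    return {node: pats for node, pats in owners.items() if len(pats) > 1}

# Input node 'b' is matched both by the single node 'y' and by the collection 'z'.
collisions = find_collisions({"x": {"a"}, "y": {"b"}, "z": {"b", "c"}})
no_collisions = find_collisions({"x": {"a"}, "y": {"b"}})
```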

### Tests
We take the example graph and pattern described above, add a few attributes to the nodes and edges, and create a Match instance based on the corresponding mapping:

In [12]:
# Define the input graph G with nodes a, b, c, d, e (referred to below as input nodes)
G = _create_graph(
    [('a', {'name': 'a'}), ('b', {'name': 'b'}), ('c', {'name': 'c'}), ('d', {'name': 'd'}), ('e', {'name': 'e'})],
    [('a', 'b', {'edge_attr': 10}), ('a', 'c'), ('d', 'e')]
)

# Define the pattern graph with nodes x, y, z (referred to below as pattern nodes)
pattern = _create_graph(['x', 'y', 'z'], [('x', 'y'), ('x', 'z')])

# The match mapping for the pattern in G
mapping = {'x': {'a'}, 'y': {'b'}, 'z': {'c'}}
# mapping_to_match derives the single nodes from the pattern passed as single_pattern
# (here all of x, y, z, since we don't use collections in this example)

# Create a match object
mapping_match = mapping_to_match(G, pattern, DiGraph(), mapping)

assert mapping_match['x']['name'] == 'a'  # Verify that the pattern node x is mapped to the input node a 
assert mapping_match['y']['name'] == 'b'  # Verify that the pattern node y is mapped to the input node b
assert mapping_match['z']['name'] == 'c'  # Verify that the pattern node z is mapped to the input node c


In graph $G$, each node has a single attribute "name" whose value is the node's name. Therefore, node $a$ has the attribute "name" with the value "a". We can access its attributes using the Match, by accessing the name of the corresponding matched node $x$:

In [13]:
assert mapping_match['x'] == {'name': 'a'}
assert mapping_match['x']['name'] == 'a'
print(mapping_match['x'])

{'name': 'a'}


This indicates that as expected, pattern node $x$ is matched in this Match instance to the input graph node $a$, as that node is the only one in the graph whose "name" attribute is equal to "a".

We can access the edges similarly, by using the format {src}->{dst} for edge $(src, dst)$:

In [14]:
assert mapping_match['x->y'] == {'edge_attr': 10}
print(mapping_match['x->y'])  # Accesses the attributes of the corresponding edge a->b in G

{'edge_attr': 10}


Say we want to add an attribute to $a$. We can do so in the same way, using the Match:

In [15]:
mapping_match['x']['attr'] = 5
assert mapping_match['x'] == {'name': 'a', 'attr': 5}
mapping_match['x']

{'name': 'a', 'attr': 5}

And then see that the attributes of $a$ in the original graph $G$ have changed, as the Match is a subview of that graph, and so changes in the subview are reflected in $G$:

In [16]:
assert G.nodes(data=True)['a'] == {'name': 'a', 'attr': 5}
G.nodes(data=True)['a']

{'name': 'a', 'attr': 5}

We can also modify existing attributes in the same way using the Match, and see those changes reflected in $G$:

In [17]:
mapping_match['y']['name'] = 'B*'
assert mapping_match['y']['name'] == 'B*'
assert G.nodes(data=True)['b']['name'] == 'B*'
G.nodes(data=True)['b']

{'name': 'B*'}

This also works in the reverse direction: changes in $G$ are reflected in the subview. If we set an attribute for edge $(a,b)$ in $G$, that change is visible when accessing the edge 'x->y' in the Match object that refers to graph $G$:

In [18]:
G.edges()[('a','b')]['attr2'] = 20
assert mapping_match['x->y']['attr2'] == 20
mapping_match['x->y']

{'edge_attr': 10, 'attr2': 20}
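The two-way behaviour shown above follows from the lookup handing back the graph's own attribute dictionary rather than a copy. A minimal sketch with plain dictionaries, where `graph_nodes` is a hypothetical stand-in for the graph's node attributes:

```python
# Stand-in for the graph's node attributes.
graph_nodes = {"a": {"name": "a"}}

# A lookup by symbolic name returns the dictionary itself, not a copy.
view = graph_nodes["a"]
view["attr"] = 5                      # write through the view
graph_nodes["a"]["attr2"] = 20        # write through the graph

# A copy, by contrast, is frozen at the moment it was taken.
snapshot = dict(graph_nodes["a"])
graph_nodes["a"]["later"] = True

same_object = view is graph_nodes["a"]
```

If a frozen record of the attributes is needed (for example, to compare values before and after a transformation), copying the dictionary as done for `snapshot` is one option.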

## Visualizing matches

In [19]:
list(G.nodes(data=True)),list(G.edges(data=True))

([('a', {'name': 'a', 'attr': 5}),
  ('b', {'name': 'B*'}),
  ('c', {'name': 'c'}),
  ('d', {'name': 'd'}),
  ('e', {'name': 'e'})],
 [('a', 'b', {'edge_attr': 10, 'attr2': 20}), ('a', 'c', {}), ('d', 'e', {})])

In [20]:
mapping_match.mapping

{'x': {'a'}, 'y': {'b'}, 'z': {'c'}}

In [21]:
#| export
def draw_match(g, m, **kwargs):
    """
    Draw the input graph with the nodes and edges that are part of the match highlighted.
    
    Args:
        g (DiGraph): The input graph to draw.
        m (Match): The match object representing the nodes and edges to highlight.
        **kwargs: Additional keyword arguments to pass to the draw function

    Returns:
        None
    """
    g_copy = g.copy()
    node_styles = {}
    edge_styles = {}

    # Reverse mapping with defaultdict for handling multiple pattern nodes per input node
    reverse_mapping = defaultdict(list)
    for pattern_node, input_nodes in m.mapping.items():
        for input_node in input_nodes:
            reverse_mapping[input_node].append(pattern_node)

    # Handle node styling
    for input_node, pattern_nodes in reverse_mapping.items():
        if len(pattern_nodes) > 1: # Input node matched by several pattern nodes
            for pattern_node in pattern_nodes:
                node_styles[input_node] = 'stroke:red,stroke-width:4px;'
                g_copy.nodes[input_node]['label'] = f"{pattern_node}_{input_node}"
        else: # Single node
            node_styles[input_node] = 'stroke:blue,stroke-width:4px;'
            g_copy.nodes[input_node]['label'] = pattern_nodes[0]

    for u, v in m._edges:
        edge_styles.update({(input_u, input_v): 'stroke:blue,stroke-width:4px;'
                            for input_u, input_v in product(m.mapping.get(u, []), m.mapping.get(v, []))})

    draw(g_copy, node_styles=node_styles, edge_styles=edge_styles, **kwargs)


In [22]:
draw_match(G,mapping_match)

### Collections Feature Tests

Test #1:

In [23]:
# Define the input graph G: a -> b, b -> c, b -> d
G = _create_graph(
    [('a', {'name': 'a'}), ('b', {'name': 'b'}), ('c', {'name': 'c'}), ('d', {'name': 'd'})],
    [('a', 'b', {'edge_attr': 5}), ('b', 'c'), ('b', 'd')]
)

# Define the single pattern: x -> y
single_pattern = _create_graph(['x', 'y'], [('x', 'y')])  # 'x' -> 'y' single

# Define the collection pattern: y -> z (collection where y points to multiple nodes)
collections_pattern = _create_graph(['y', 'z'], [('y', 'z')])  # 'z' matches multiple nodes

# In this case:
# - 'x' matches node 'a' (single),
# - 'y' matches node 'b' (single and part of a collection),
# - 'z' matches nodes 'c' and 'd' (collection).
mapping = {'x': {'a'}, 'y': {'b'}, 'z': {'c', 'd'}}
# The single nodes are taken from single_pattern ('x' and 'y'); 'z' is a collection node

# Create the match object
collection_match = mapping_to_match(G, single_pattern, collections_pattern, mapping)

In [24]:
# Print the full mapping to see if 'z' is correctly mapped to 'c' and 'd'
print("Full mapping: ", collection_match.mapping)

Full mapping:  {'x': {'a'}, 'y': {'b'}, 'z': {'c', 'd'}}


In [25]:
# Test that 'x' is matched to node 'a' as a single node
print(collection_match['x']) 
assert collection_match['x']['name'] == 'a'
print(collection_match['x']['name'])  # Should print 'a'

{'name': 'a'}
a


In [26]:
# Test that 'y', which also appears in the collection pattern, is matched to node 'b' as a single node
print(collection_match['y'])
assert collection_match['y']['name'] == 'b'
print(collection_match['y']['name'])  # Should print 'b'

{'name': 'b'}
b


In [27]:
# Test that 'z' is part of the collection and matched to nodes 'c' and 'd'
print(collection_match['z'])
assert set(collection_match['z']['name']) == {'c', 'd'}  # Check using set comparison to account for order
print(collection_match['z']['name'])  # Should print ['c', 'd'] in any order

Collection nodes: {'name': 'c'}, {'name': 'd'}
['c', 'd']


Test #2:

In [28]:
# Define the input graph: a -> b -> c, a -> d, b -> d
G = _create_graph(
    [('a', {'name': 'a'}), ('b', {'name': 'b'}), ('c', {'name': 'c'}), ('d', {'name': 'd'})],
    [('a','b'), ('a','d'), ('b','d', {'name': '(b,d)'}), ('b','c', {'name': '(b,c)'})]
)

In [29]:
# Define the single pattern: x -> y
single_pattern = _create_graph(['x', 'y'], [('x', 'y')])

# Define the collection pattern: y -> z (where z forms multiple edges)
collections_pattern = _create_graph(['y', 'z'], [('y', 'z')])

# In this case:
# - 'x' matches node 'a' (single),
# - 'y' matches node 'b' (single),
# - 'z' matches nodes 'c' and 'd' (collection of edges).
mapping = {'x': {'a'}, 'y': {'b'}, 'z': {'c', 'd'}}
# The single nodes are taken from single_pattern ('x' and 'y'); 'z' forms a collection of edges

# Create the match object
edge_collection_match = mapping_to_match(G, single_pattern, collections_pattern, mapping)

In [30]:
# Test that 'x' matches node 'a'
assert edge_collection_match['x']['name'] == 'a'
print(edge_collection_match['x']['name'])  # Should print 'a'

a


In [31]:
# Test that 'y' matches node 'b'
assert edge_collection_match['y']['name'] == 'b'
print(edge_collection_match['y']['name'])  # Should print 'b'

b


In [32]:
# Test that 'z' matches nodes 'c' and 'd' as a collection of edges
print(edge_collection_match['y->z']) 

Collection edges: {'name': '(b,c)'}, {'name': '(b,d)'}


In [33]:
# Edges that lack the requested attribute are skipped silently (no exception is raised)
assert set(edge_collection_match['y->z']['name']) == {'(b,c)', '(b,d)'}  # Check using set comparison to account for order
print(edge_collection_match['y->z']['name'])  # Should print ['(b,c)', '(b,d)'] in any order

['(b,c)', '(b,d)']


# Export

In [34]:
#|hide
import nbdev; nbdev.nbdev_export()