# Result Set

In [12]:
#| default_exp result_set
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Overview
The Matcher finds a list of mappings from the names of the LHS pattern's nodes to the names of the matching nodes in the input graph. Each mapping represents a single match.

Next, we convert this list of mappings into a **ResultSet**, a subview of the original graph that lets users get and set the attributes of the nodes and edges included in each match. It can also be used to filter matches and to specify an RHS per match. The ResultSet is eventually returned to the user once the rewrite is done, allowing imperative side effects as well as imperative changes to the graph.

### Requirements

In [13]:
#| export
import networkx as nx
from networkx import DiGraph
from typing import *

### Utils
Edge $(u,v)$ is represented in the ResultSet with the name "u->v".

In [14]:
#| export
def convert_to_edge_name(src: Hashable, dest: Hashable) -> str:
    """Return the ResultSet key of the edge (src, dest), e.g. ('u', 'v') -> "u->v"."""
    return "->".join(format(endpoint) for endpoint in (src, dest))

A node is anonymous if its name begins with '$'.

In [15]:
#| export
def is_anonymous_node(node_name: Hashable) -> bool:
    """Return True if `node_name` denotes an anonymous pattern node.

    A node is anonymous when its name is a string starting with '$'.
    Node names may be any hashable value (networkx accepts e.g. ints);
    non-string names are never considered anonymous and do not raise.
    """
    return isinstance(node_name, str) and node_name.startswith('$')

### Type Definitions

#### Match
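A **Match** is a subview of the original graph, limited to the nodes and edges specified in the pattern (anonymous pattern nodes, and the edges touching them, are left out). It holds the original graph, the lists of pattern nodes and pattern edges it exposes, and the mapping from pattern node names to the actual nodes. Indexing a Match with a pattern node name, or with an edge name `"u->v"`, returns the corresponding attribute dictionary of the original graph, so changes made through the Match are applied to the graph itself.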

In [16]:
#| export
class Match:
    """One match of a pattern: a view of `graph` restricted to the pattern's nodes and edges.

    Indexing with a pattern node name returns that node's attribute dict in
    `graph`; indexing with an edge name "u->v" returns the attribute dict of
    the graph edge between the nodes that pattern nodes u and v are mapped to.
    The returned dicts are the graph's own attribute dicts, so mutating them
    mutates the graph.

    Args:
        graph: the graph the match was found in.
        nodes: pattern node names exposed by this match.
        edges: pattern edges ``(src, dst)`` exposed by this match.
        mapping: pattern node name -> node of `graph`.
    """

    def __init__(self, graph: DiGraph, nodes: List[Hashable], edges: List[Tuple[Hashable, Hashable]], mapping: Dict[str, Hashable]):
        self.graph: DiGraph = graph
        self.__nodes: List[Hashable] = nodes
        self.__edges: List[Tuple[Hashable, Hashable]] = edges
        self.mapping: Dict[str, Hashable] = mapping

    class DoesNotExist(Exception):
        """Raised when a key names no node or edge of this match."""

    def __get_node(self, pattern_node):
        # Raises KeyError when `pattern_node` is not in the mapping.
        return self.graph.nodes[self.mapping[pattern_node]]

    def __get_edge(self, pattern_src, pattern_dst):
        # Only pattern edges exposed by this match are reachable, even if the
        # underlying graph happens to contain the edge.
        if (pattern_src, pattern_dst) not in self.__edges:
            raise self.DoesNotExist((pattern_src, pattern_dst))
        return self.graph.edges[self.mapping[pattern_src], self.mapping[pattern_dst]]

    def get_nodes(self):
        """Return {pattern node name: attribute dict of the matched node}."""
        return {pattern_node: self.__get_node(pattern_node) for pattern_node in self.__nodes}

    def get_edges(self):
        """Return {"src->dst": attribute dict of the matched edge} for every exposed pattern edge."""
        return {convert_to_edge_name(pattern_src, pattern_dest): self.__get_edge(pattern_src, pattern_dest) for (pattern_src, pattern_dest) in self.__edges}

    def __getitem__(self, key: Hashable):
        """Return the attribute dict for a pattern node name or an edge name "u->v".

        A key whose string form contains exactly one "->" is treated as an edge
        name; any other key is looked up as a pattern node name.

        Raises:
            Match.DoesNotExist: if the key names no node/edge of this match
                (the underlying KeyError/TypeError is kept as ``__cause__``).
        """
        try:
            end_nodes = str(key).split("->")
            if len(end_nodes) == 2:
                return self.__get_edge(end_nodes[0], end_nodes[1])
            return self.__get_node(key)
        except (KeyError, TypeError) as err:
            # Translate only lookup failures; anything else (including
            # KeyboardInterrupt) propagates unchanged.
            raise self.DoesNotExist(key) from err

#### ResultSet
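A **ResultSet** contains all the matches found for a pattern. It supports `len()` and positional indexing (`rs[i]` returns the i-th Match).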

In [17]:
#| export
class ResultSet:
    """Ordered collection of the Match objects found for a pattern.

    Supports `len(rs)` and positional indexing `rs[i]`; the index (including
    negative indices and slices) is forwarded to the underlying list.
    """

    def __init__(self, matches: List[Match]):
        # The caller's list is stored as-is (no copy is made).
        self.__matches: List[Match] = matches

    def __len__(self):
        stored = self.__matches
        return len(stored)

    def __getitem__(self, index: int):
        stored = self.__matches
        return stored[index]

### Convert list of mappings to a ResultSet

In [18]:
#| export
def mapping_to_match(input: DiGraph, pattern: DiGraph, mapping: Dict[str, Hashable]) -> Match:
    """Build the Match for a single matcher mapping, hiding anonymous pattern nodes.

    Args:
        input: the graph the mapping points into.
        pattern: the LHS pattern graph; its edges become the Match's edges.
        mapping: pattern node name -> node of `input`.

    Returns:
        A Match over `input` whose node list and mapping omit every pattern node
        for which `is_anonymous_node` is true, and whose edge list omits every
        pattern edge touching such a node. `mapping` itself is not modified.
    """
    visible_nodes = []
    visible_mapping = mapping.copy()
    for name in mapping:
        if is_anonymous_node(name):
            del visible_mapping[name]
        else:
            visible_nodes.append(name)

    visible_edges = [
        (src, dst)
        for (src, dst) in pattern.edges
        if not (is_anonymous_node(src) or is_anonymous_node(dst))
    ]

    return Match(input, visible_nodes, visible_edges, visible_mapping)

In [19]:
#| export
def mappings_to_results_set(input: DiGraph, pattern: DiGraph, mappings: List[Dict[str, Hashable]]) -> ResultSet:
    """Convert each matcher mapping into a Match and wrap them all in a ResultSet.

    Matches keep the order of `mappings`.
    """
    return ResultSet([mapping_to_match(input, pattern, mapping) for mapping in mappings])

### Tests

#### Test Utils

In [20]:
from graph_rewrite.matcher import find_matches

def create_graph(nodes, edges):
    """Test helper: build a DiGraph by adding `nodes` first, then `edges`.

    Arguments take the forms accepted by networkx `add_nodes_from` / `add_edges_from`.
    """
    graph = DiGraph()
    graph.add_nodes_from(nodes)
    graph.add_edges_from(edges)
    return graph

def get_result_set(input, pattern):
    """Test helper: run the matcher on `input` with `pattern` and wrap the mappings in a ResultSet."""
    found_mappings = find_matches(input, pattern)
    return mappings_to_results_set(input, pattern, found_mappings)

#### Test Cases

In [21]:
# The pattern X -> X is a single node with a self-loop; in `input`,
# A and C are the only nodes with a self-loop.
input = create_graph(
    ['A', 'B', 'C', 'D'],
    [
        ('A', 'B'),
        ('A', 'C'),
        ('A', 'A'),
        ('C', 'C'),
    ]
)

pattern = create_graph(['X'], [('X', 'X')])
rs = get_result_set(input, pattern)
assert len(rs) == 2

# Look matches up by the node X was mapped to, so the assertions do not
# depend on the order in which the matcher reports matches.
match_for = {rs[i].mapping['X']: rs[i] for i in range(len(rs))}
assert set(match_for) == {'A', 'C'}

# Writes made through a match land on the original graph.
match_for['A']['X']['hello'] = 5
assert input.nodes['A']['hello'] == 5
match_for['C']['X->X']['attr'] = "hello my attr"
assert 'attr' in input.edges['C', 'C']
assert input.edges['C', 'C']['attr'] == "hello my attr"
assert 'attr' not in input.edges['A', 'A']