# Condition Filter

In [None]:
#| default_exp condition_filter
# The `#|` directive above tells nbdev which module the exported cells go to.
# autoreload level 2 reloads imported modules before each cell runs, so edits to
# graph_rewrite.* source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Overview
So far, the user has searched the input graph for matches of one of the LHS patterns passed as a parameter, and the matches found have been converted to a ResultSet.

Before these matches are passed to the transformation phase of rewrite, the user can filter them with a **Condition** function, which receives a Match instance, checks its attributes and returns a boolean value.
If the condition function returns False for a match, that match is filtered out by rewrite and will not be transformed in the next phase.

### Requirements

In [None]:
#| export
from graph_rewrite.result_set import Match, ResultSet
from typing import *

### Filter Matches based on a Condition function
Receives a ResultSet object and a Condition function, and returns a new ResultSet containing only the matches for which the condition returns True.

In [None]:
#| export
def filter_matches(unfiltered_results: ResultSet, condition: Callable[[Match], bool]) -> ResultSet:
    """Keep only the matches that satisfy `condition`.

    Args:
        unfiltered_results: the ResultSet produced by the matching phase.
        condition: predicate called once per Match, in iteration order; a match
            is kept when the returned value is truthy.

    Returns:
        A new ResultSet holding the kept Match objects themselves (not copies),
        in their original order.
    """
    kept_matches = []
    for candidate in unfiltered_results:
        if condition(candidate):
            kept_matches.append(candidate)
    return ResultSet(kept_matches)

### Tests

#### Test Utils

In [None]:
from networkx import DiGraph
from graph_rewrite.matcher import find_matches
from graph_rewrite.result_set import mappings_to_results_set

def create_graph(nodes, edges):
    """Build a networkx DiGraph from node and edge specifications.

    `nodes` is forwarded to `DiGraph.add_nodes_from` (bare labels or
    `(label, attr_dict)` tuples) and `edges` to `DiGraph.add_edges_from`
    (`(source, target)` pairs). Nodes are added before edges.
    """
    graph = DiGraph()
    graph.add_nodes_from(nodes)
    graph.add_edges_from(edges)
    return graph

def get_result_set(input, pattern):
    """Find all matches of `pattern` in the graph `input` and wrap them in a ResultSet."""
    # NOTE: the parameter name shadows the builtin input(); kept for call compatibility.
    raw_mappings = find_matches(input, pattern)
    return mappings_to_results_set(input, pattern, raw_mappings)

#### Test Cases

In [None]:
# Only 'A' (attr=5) and 'C' (attr=10) carry self-loops, so the single-node
# self-loop pattern below matches exactly those two nodes.
# The repeated ('A', 'C') edge is collapsed by DiGraph (no parallel edges).
input_graph = create_graph(
    [('A', {'attr': 5}), 'B', ('C', {'attr': 10}), 'D'],
    [
        ('A', 'B'),
        ('A', 'C'),
        ('A', 'A'),
        ('C', 'C'),
        ('A', 'C'),
    ]
)

pattern = create_graph(['X'], [('X', 'X')])
rs = get_result_set(input_graph, pattern)
# The rs_index / (1 - rs_index) arithmetic below relies on exactly two matches.
assert len(rs) == 2

rs_filtered = filter_matches(rs, lambda match: match['X']['attr'] > 5)

# Only 'C' (attr=10) satisfies attr > 5.
assert len(rs_filtered) == 1
assert rs_filtered[0]['X']['attr'] == 10

# filter_matches must keep references to the original matches (not copies):
# mutating the filtered match has to be visible through the unfiltered ResultSet.
rs_index = 0 if rs[0]['X']['attr'] == 10 else 1
rs_filtered[0]['X']['attr'] = 0

assert rs_filtered[0]['X']['attr'] == 0
assert rs[rs_index]['X']['attr'] == 0
assert rs[1 - rs_index]['X']['attr'] == 5
# After the mutation no match has attr > 5 ('A' is 5, 'C' is now 0).
assert len(filter_matches(rs, lambda match: match['X']['attr'] > 5)) == 0