# Engine
> Execution of spannerlog commands

In [None]:
#| default_exp engine

In [None]:
#| hide
from nbdev.showdoc import show_doc
from IPython.display import display, HTML
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from abc import ABC, abstractmethod
import pytest
from collections import defaultdict
from spannerflow.engine import Engine as SpannerflowEngine
from spannerflow.span import Span
from numbers import Real
import pandas as pd
from pathlib import Path
from typing import no_type_check, Set, Sequence, Any,Optional,List,Callable,Dict,Union
from pydantic import BaseModel
import networkx as nx
import itertools
import logging
logger = logging.getLogger(__name__)

from graph_rewrite import draw, draw_match, rewrite, rewrite_iter
from spannerlib.utils import (
    serialize_graph,
    assert_df_equals,
    checkLogs,
    get_new_node_name
    )


from spannerlib.data_types import (
    Var, 
    FreeVar, 
    RelationDefinition, 
    Relation, 
    IEFunction,
    AGGFunction,
    IERelation, 
    Rule, 
    pretty
)
from spannerlib.ra import (
    _col_names,
    get_const,
    select,
    project,
    rename,
    union,
    intersection,
    difference,
    join,
    product,
    groupby,
    ie_map,
    merge_rows
)

from spannerlib.term_graph import graph_compose, merge_term_graphs_pair,rule_to_graph,add_relation,add_project_uniq_free_vars



## Helper functions

In [None]:
#| export
def _pd_drop_row(df,row_vals):
    new_df = df[(df!=row_vals).all(axis=1)]
    return new_df



In [None]:
# sanity check for _pd_drop_row: removing the row (3, 4) leaves only the (1, '2fs') row
df = pd.DataFrame([
    [1,'2fs'],[3,4]
])
assert list(_pd_drop_row(df,[3,4]).itertuples(index=False,name=None))==[(1,'2fs')]

In [None]:
#| export
class DB(dict):
    """A dict of relations whose repr lists only the keys, not the stored values."""
    def __repr__(self):
        # stringify keys so that repr never raises, even for non-string keys
        key_str=', '.join(str(key) for key in self)
        return f'DB({key_str})'

## Engine Class

In [None]:
#| export
from copy import deepcopy
from time import sleep
import atexit
import os

import numpy as np

class Engine():
    """State of a spannerlog session.

    Keeps the symbol table, relation definitions, IE / aggregation functions and the
    rules (merged into a single term graph), and delegates fact storage and query
    execution to a `SpannerflowEngine` instance.
    """
    def __init__(self, rewrites=None):
        """Create an empty engine.

        `rewrites` is stored on the instance and used by `plan_query` as the default
        rewrite list; `None` means an empty list.
        """
        # Always define the attribute. It used to be set only when `rewrites` was None,
        # so an engine built with explicit rewrites crashed in `plan_query`.
        self.rewrites = [] if rewrites is None else rewrites
        self.symbol_table={
            # key : type,val
        }
        self.Relation_defs={
            # key : RelationDefinition for both real and derived relations
        }
        self.ie_functions={
            # name : IEFunction class
        }

        self.agg_functions={
        }

        self.term_graph = nx.DiGraph()
        
        self.node_counter = itertools.count()
        self.rule_counter = itertools.count()

        self.db = DB(
            # relation_name: dataframe
        )
        # names registered through set_relation(rule=False) / set_relation(rule=True)
        self.collections = set()
        self.rules = set()
        # lets skip this for now and keep it a an attribute in the node graph
        self.rules_to_ids = {
            # rule pretty string: ( node id in term_graph, head_name)
        }
        self.head_to_rules = defaultdict(set)
        # head relation name to rule pretty string

        # self.rels_to_nodes() = {
        #     # relation name to node that represents it
        # }
        self.spannerflow_engine = SpannerflowEngine()
        # make sure the spannerflow engine is closed when the interpreter exits
        atexit.register(self.spannerflow_engine.close)


    def set_var(self,var_name,value,read_from_file=False):
        """Bind `var_name` to `value`, remembering the value's type.

        Raises ValueError when the variable already exists with a different type.
        `read_from_file` is accepted but not used here.
        """
        symbol_table = self.symbol_table
        if var_name in symbol_table:
            existing_type,existing_value = symbol_table[var_name]
            if type(value) != existing_type:
                raise ValueError(f"Variable {var_name} was previously defined with {existing_value}({pretty(existing_type)})"
                                f" but is trying to be redefined to {value}({pretty(type(value))}) of a different type which might interfere with previous rule definitions")    
        symbol_table[var_name] = type(value),value
        return
    def get_var(self,var_name):
        """Return the `(type, value)` pair stored for `var_name`, or None if undefined."""
        return self.symbol_table.get(var_name,None)
    
    def del_var(self,var_name):
        """Remove `var_name` from the symbol table (KeyError if it is not defined)."""
        del self.symbol_table[var_name]

    def get_relation(self,rel_name:str)-> RelationDefinition:
        """Return the definition registered under `rel_name`, or None if there is none."""
        return self.Relation_defs.get(rel_name,None)

    def set_relation(self,rel_def:RelationDefinition, rule=False):
        """Register a relation definition.

        With `rule=False` the relation is also created as a spannerflow collection;
        with `rule=True` it is only recorded as a rule head.

        Raises ValueError when the name is already a spannerflow collection and either
        the definition differs or `rule` is False, and when a column type has no
        spannerflow equivalent.
        """
        if rel_def.name in self.spannerflow_engine.get_collections():
            existing_def = self.Relation_defs[rel_def.name]
            if existing_def != rel_def:
                raise ValueError(f"Relation {rel_def.name} was previously defined with {existing_def}"
                                f"but is trying to be redefined to {rel_def} which might interfere with previous rule definitions")
            elif not rule:
                raise ValueError(f"Relation {rel_def.name} was previously defined")     
        SPANNER_LIB_TO_SPANNER_FLOW_TYPES_DICT = {
            str: "DATA_TYPE_STRING",
            int: "DATA_TYPE_INT",
            float: "DATA_TYPE_FLOAT",
            bool: "DATA_TYPE_BOOL",
            Span: "DATA_TYPE_SPAN",
            Real: "DATA_TYPE_FLOAT",
            np.int64: "DATA_TYPE_INT64",
        }

        # translate the python column types into spannerflow type names
        spannerflow_schema = []
        for col_type in rel_def.scheme:
            if col_type not in SPANNER_LIB_TO_SPANNER_FLOW_TYPES_DICT:
                raise ValueError(f"Type {col_type} not supported by spannerflow")
            spannerflow_schema.append(SPANNER_LIB_TO_SPANNER_FLOW_TYPES_DICT[col_type])
        
        self.term_graph.add_node(rel_def.name, rel=rel_def.name, rule_id={'fact'})
        self.Relation_defs[rel_def.name] = rel_def
        if not rule:
            self.collections.add(rel_def.name)
            self.spannerflow_engine.add_collection(rel_def.name, spannerflow_schema)
        else:
            self.rules.add(rel_def.name)
        
    def del_relation(self,rel_name:str):
        """Delete the spannerflow collection `rel_name` and its local bookkeeping.

        Raises ValueError when `rel_name` is not a spannerflow collection.
        """
        if rel_name not in self.spannerflow_engine.get_collections():
            # `rel_name` is a plain string; formatting `rel_name.name` raised
            # AttributeError and masked this ValueError
            raise ValueError(f"Relation {rel_name} is not defined")
        
        if rel_name in self.Relation_defs:
            self.Relation_defs.pop(rel_name)
            if rel_name in self.rules:
                self.rules.remove(rel_name)
            elif rel_name in self.collections:
                self.collections.remove(rel_name)

        self.spannerflow_engine.delete_collection(rel_name)

    def add_fact(self,fact:Relation):
        """Add a single row (`fact.terms`) to the collection `fact.name`."""
        self.spannerflow_engine.add_row(fact.name, fact.terms)
        
    def add_facts(self,rel_name,facts:pd.DataFrame):
        """Add every row of the dataframe `facts` to the collection `rel_name`."""
        self.spannerflow_engine.add_rows(rel_name, facts.values.tolist())
        
    def load_csv(self, rel_name:str , path: str|Path, delim: str = ',', has_header: bool = False):
        """Load the csv file at `path` into the collection `rel_name`.

        Raises ValueError when `path` does not exist.
        """
        if not os.path.exists(path):
            raise ValueError(f"Path {path} does not exist")
        self.spannerflow_engine.load_from_csv(rel_name, path, delim, has_header)
        
    def del_fact(self,fact:Relation):
        """Delete the row `fact.terms` from the collection `fact.name`."""
        self.spannerflow_engine.delete_row(fact.name, fact.terms)
        # self.db[fact.name] = _pd_drop_row(df = self.db[fact.name],row_vals=fact.terms)

    def get_ie_function(self,name:str):
        """Return the IE function registered under `name`, or None."""
        return self.ie_functions.get(name,None)

    def set_ie_function(self,ie_func:IEFunction):
        """Register `ie_func` locally and in the spannerflow engine."""
        self.ie_functions[ie_func.name]=ie_func
        self.spannerflow_engine.set_ie_function(ie_func.name, ie_func.func, ie_func.in_schema, ie_func.out_schema)

    def del_ie_function(self,name:str):
        """Forget the locally registered IE function `name` (KeyError if unknown)."""
        del self.ie_functions[name]

    def get_agg_function(self,name:str):
        """Return the aggregation function registered under `name`, or None."""
        return self.agg_functions.get(name,None)
    
    def set_agg_function(self,agg_func:AGGFunction):
        """Register `agg_func` locally and in the spannerflow engine."""
        self.agg_functions[agg_func.name]=agg_func
        self.spannerflow_engine.set_agg_function(agg_func.name, agg_func.func, agg_func.in_schema, agg_func.out_schema)
    
    def del_agg_function(self,name:str):
        """Forget the locally registered aggregation function `name` (KeyError if unknown)."""
        del self.agg_functions[name]

    def add_rule(self,rule:Rule,schema:RelationDefinition=None):
        """Add `rule` to the term graph.

        `schema` defines the head relation; it may be omitted only when the head
        relation is already defined, otherwise ValueError is raised.
        Adding a rule whose pretty string is already registered is a no-op.
        """
        if not self.get_relation(rule.head.name) and schema is None:
            raise ValueError(f"Relation {rule.head.name} not defined before adding the rule with it's head\n"
                             f"And an relation schema was not supplied."
                             f"existing relations are {self.Relation_defs.keys()}")
        # if already defined, do nothing.
        if pretty(rule) in self.rules_to_ids:
            return

        if schema is not None:
            self.set_relation(schema, rule=True)

        rule_id = next(self.rule_counter)

        self.rules_to_ids[pretty(rule)] = rule_id,rule.head.name
        self.head_to_rules[rule.head.name].add(pretty(rule))

        # build the rule's own graph and merge it into the engine-wide term graph
        g2 = rule_to_graph(rule,rule_id)

        merge_term_graph = merge_term_graphs_pair(self.term_graph,g2)
        self.term_graph = merge_term_graph
        

    def del_rule(self,rule_str:str):
        """Remove the rule whose pretty string is `rule_str`.

        Term-graph nodes that belonged only to this rule are removed; when the head is
        left without rules its definition and graph node are removed as well.
        Raises ValueError when the rule is unknown.
        """
        #TODO here we need to save rules by their head and when removing the last rule of a head, remove its definition from db as well
        if rule_str not in self.rules_to_ids:
            raise ValueError(f"Rule {rule_str} does not exist\n"
                             f"existing rules are {self.rules_to_ids.keys()}")
        rule_id,rule_head = self.rules_to_ids[rule_str]
        self.rules_to_ids.pop(rule_str)
        self.head_to_rules[rule_head].remove(rule_str)

        g = self.term_graph

        # if the head has no more rules, remove it from the relation defs and the term graph
        if len(self.head_to_rules[rule_head])==0:
            self.Relation_defs.pop(rule_head)
            g.remove_node(rule_head)

        # drop this rule's id from every node; nodes left with no owner are deleted
        # after the loop so the graph is not modified while iterating over it
        nodes_to_delete=[]
        for u in g.nodes:
            node_rule_ids = g.nodes[u].get('rule_id',set())
            if rule_id in node_rule_ids:
                node_rule_ids.remove(rule_id)
                if len(node_rule_ids) == 0:
                    nodes_to_delete.append(u)
        g.remove_nodes_from(nodes_to_delete)
            
        return

    def del_head(self,head_name:str):
        """Deletes all rules whose head is head_name
        """
        # iterate over a copy: del_rule mutates head_to_rules[head_name]
        rules_to_delete = self.head_to_rules[head_name].copy()
        for rule_str in rules_to_delete:
            self.del_rule(rule_str)

    def _inline_db_and_ies_in_graph(self,g:nx.DiGraph):
        """Return a deep copy of `g` with names resolved to concrete objects.

        Leaf relation nodes become `get_rel` nodes carrying `self.db` and a schema,
        `ie_map` nodes get the registered function and schemas, and `groupby` nodes
        get their aggregation function names replaced by the registered functions.
        """
        g=deepcopy(g)
        for u in g.nodes:
            if g.out_degree(u)==0 and 'rel' in g.nodes[u]:
                g.nodes[u]['op'] = 'get_rel'
                g.nodes[u]['db'] = self.db
                g.nodes[u]['schema'] = _col_names(len(self.Relation_defs[g.nodes[u]['rel']].scheme))
            elif g.nodes[u]['op'] == 'ie_map':
                ie_func_name = g.nodes[u]['func']
                ie_definition = self.ie_functions[ie_func_name]
                g.nodes[u]['func'] = ie_definition.func
                g.nodes[u]['name'] = ie_definition.name
                g.nodes[u]['in_schema'] = ie_definition.in_schema
                g.nodes[u]['out_schema'] = ie_definition.out_schema
            elif g.nodes[u]['op'] == 'groupby':
                aggregate_func_names = g.nodes[u]['agg']
                # None entries (columns without an aggregation) are kept as None
                aggregate_funcs = [self.agg_functions[name].func if name is not None else None for name in aggregate_func_names]
                g.nodes[u]['agg'] = aggregate_funcs
        return g


    def plan_query(self,q_rel:Relation,rewrites=None):
        """Build the query graph for `q_rel` and return `(query_graph, root_node)`.

        `rewrites` defaults to `self.rewrites`; rewrites are not applied yet (see TODO).
        """
        if rewrites is None:
            rewrites = self.rewrites
        query_graph = self._inline_db_and_ies_in_graph(self.term_graph)

        # get the sub term graph induced by the relation head
        root_node = q_rel.name
        connected_nodes = list(nx.shortest_path(query_graph,root_node).keys())
        query_graph = nx.DiGraph(nx.subgraph(query_graph,connected_nodes))
        
        # add selects renames etc based on the query relation
        root_node,_ = add_relation(query_graph,name='query',terms=q_rel.terms,source=root_node)

        # TODO for all rewrites, run them
        return query_graph,root_node

    def execute_plan(self,query_graph,root_node,return_intermediate=False, output_csv_path: Path| str | None = None):
        """Run `query_graph` on the spannerflow engine and return the result as a DataFrame.

        The graph is passed reversed to spannerflow; the result columns are named after
        the schema of `root_node`. `return_intermediate` is accepted but currently
        ignored, so a single DataFrame is always returned.
        """
        if isinstance(output_csv_path, Path):
            output_csv_path = str(output_csv_path.resolve())
        res =  self.spannerflow_engine.run_dataflow(nx.reverse(query_graph), output_csv_path=output_csv_path)
        return pd.DataFrame(columns=query_graph.nodes[root_node]['schema'], data=res)

    def run_query(self,q:Relation,rewrites=None,return_intermediate=False):
        """Plan and execute the query `q`; returns whatever `execute_plan` returns."""
        query_graph, root_node = self.plan_query(q,rewrites)
        return self.execute_plan(query_graph,root_node,return_intermediate)


### Test

In [None]:
s = pd.DataFrame([
    [1,1],
    [2,2],
    [3,3],
    [4,5]
])

s2 = pd.DataFrame([
    [1,2,3],
    [2,3,4],
    [3,4,5],
    [4,5,6]
])

In [None]:
r1 = Rule(
    head=Relation(name='R', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='S', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
        Relation(name='S2', terms=[FreeVar(name='X'), FreeVar(name='A'),3]),
    ])

r2 = Rule(
    head=Relation(name='R', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='S', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
        IERelation(name='T', in_terms=[FreeVar(name='X'),FreeVar(name='Y')], out_terms=[FreeVar(name='X'),FreeVar(name='Y')]),
    ])

r3 = Rule(
    head=Relation(name='R2', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='S3', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
        Relation(name='S2', terms=[FreeVar(name='X'), FreeVar(name='A'),1]),
    ])


rec_r1 = Rule(
    head=Relation(name='A', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='B', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
    ])

rec_r2 = Rule(
    head=Relation(name='A', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='C', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
    ])

rec_r3 = Rule(
    head=Relation(name='B', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='D', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
    ])

rec_r4 = Rule(
    head=Relation(name='B', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='A', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
    ])


In [None]:
e = Engine()
e.set_relation(RelationDefinition(name='S', scheme=[int,int]))
e.set_relation(RelationDefinition(name='S2', scheme=[int,int,int]))
e.set_relation(RelationDefinition(name='S3', scheme=[int,int]))

e.add_rule(r1,RelationDefinition(name='R', scheme=[int,int]))
e.add_rule(r2,RelationDefinition(name='R', scheme=[int,int]))
e.add_rule(r3,RelationDefinition(name='R2', scheme=[int,int]))

In [None]:
e.add_facts('S',s)
e.add_facts('S2',s2)

In [None]:
#e.db['S']

In [None]:
#e.db['S2']

In [None]:
draw(e.term_graph)

In [None]:
e.rules_to_ids

In [None]:
e.del_rule(pretty(r3))

In [None]:
draw(e.term_graph)

In [None]:
# assert serialize_graph(e.term_graph) ==([('S', {'rel': 'S', 'rule_id': {0, 1}}),
#   ('S2', {'rel': 'S2', 'rule_id': {0}}),
#   ('R', {'rel': 'R', 'op': 'union', 'rule_id': {0, 1}}),
#   (0, {'op': 'rename', 'names': [(0, 'X'), (1, 'Y')], 'rule_id': {0}}),
#   (1, {'op': 'rename', 'names': [(0, 'X'), (1, 'A')], 'rule_id': {0}}),
#   (2, {'op': 'select', 'theta': [(2, 3)], 'rule_id': {0}}),
#   (3, {'op': 'join', 'rule_id': {0}}),
#   (4, {'op': 'project', 'on': ['X', 'Y'], 'rel': '_R_0', 'rule_id': {0}}),
#   (8, {'op': 'rename', 'names': [(0, 'X'), (1, 'Y')], 'rule_id': {1}}),
#   (9, {'op': 'project', 'on': ['X', 'Y'], 'rule_id': {1}}),
#   (10, {'op': 'calc', 'func': 'T', 'rule_id': {1}}),
#   (11, {'op': 'rename', 'names': [(0, 'X'), (1, 'Y')], 'rule_id': {1}}),
#   (6, {'op': 'join', 'rule_id': {1}}),
#   (7, {'op': 'project', 'on': ['X', 'Y'], 'rel': '_R_1', 'rule_id': {1}})],
#  [('R', 4, {}),
#   ('R', 7, {}),
#   (0, 'S', {}),
#   (1, 'S2', {}),
#   (2, 1, {}),
#   (3, 0, {}),
#   (3, 2, {}),
#   (4, 3, {}),
#   (8, 'S', {}),
#   (9, 8, {}),
#   (10, 9, {}),
#   (11, 10, {}),
#   (6, 8, {}),
#   (6, 11, {}),
#   (7, 6, {})])

In [None]:
e.add_rule(rec_r1,RelationDefinition(name='A', scheme=[int,int]))
e.add_rule(rec_r2,RelationDefinition(name='A', scheme=[int,int]))
e.add_rule(rec_r3,RelationDefinition(name='B', scheme=[int,int]))
e.add_rule(rec_r4,RelationDefinition(name='B', scheme=[int,int]))


In [None]:
draw(e.term_graph)

In [None]:
e.del_head('R')

In [None]:
assert e.rules_to_ids == {'A(X,Y) <- B(X,Y).': (3, 'A'),
 'A(X,Y) <- C(X,Y).': (4, 'A'),
 'B(X,Y) <- D(X,Y).': (5, 'B'),
 'B(X,Y) <- A(X,Y).': (6, 'B')}

In [None]:
draw(e.term_graph)

## Naive execution

A recursive least-fixed-point algorithm mimicking bottom-up evaluation.

In [None]:
#| export

def get_rel(rel,db,**kwargs):
    """Look up the externally stored relation named `rel` in `db`.

    Any further keyword arguments (the remaining node attributes) are ignored.
    """
    relation = db[rel]
    return relation

# dispatch table: the 'op' attribute of a term-graph node -> function that evaluates it
op_to_func = dict(
    union=union,
    intersection=intersection,
    difference=difference,
    select=select,
    project=project,
    rename=rename,
    join=join,
    ie_map=ie_map,
    get_rel=get_rel,
    get_const=get_const,
    product=product,
    groupby=groupby,
)

In [None]:
#| export
def _in_cycle(g):
    """Return a list (without duplicates) of the nodes of `g` that lie on a simple cycle."""
    nodes_on_cycles = set()
    for cycle in nx.cycles.simple_cycles(g):
        nodes_on_cycles.update(cycle)
    return list(nodes_on_cycles)

def _depends_on_cycle(g):
    """Return the set of nodes of `g` that are on a cycle or have a descendant on a cycle."""
    # a set gives O(1) membership tests; the list returned by _in_cycle was being
    # scanned once per node
    in_cycle_nodes = set(_in_cycle(g))
    if not in_cycle_nodes:
        # acyclic graph: nothing can depend on a cycle, skip the descendants traversal
        return set()
    depends_on_cycle = {
        node for node in g.nodes
        if node in in_cycle_nodes
        or not in_cycle_nodes.isdisjoint(nx.descendants(g,node))
    }
    return depends_on_cycle

In [None]:
#| export
def _collect_children_and_run(G,u,results,stack,log=False):
    """Evaluate node `u` from the latest results of its successors.

    The node's 'op' attribute selects the function in `op_to_func`; it is called with
    the children's latest results as positional arguments and all node attributes as
    keyword arguments. The result is appended to `results[u]` and returned.

    `stack` is only used for logging. Any error raised by the operation is re-raised
    as an `Exception` that names the node and its arguments.
    """
    children = list(G.successors(u))
    u_data = G.nodes[u]

    # latest value computed for each child, in successor order
    children_results = [results[v][-1] for v in children]
    op_func = op_to_func[u_data['op']]

    if log:
        logger.debug(f"computing node {u} with children {children} and data {u_data} , stack = {stack}")
        logger.debug(f"children results are {children_results}")
        logger.debug(f"children_data is {[G.nodes[v] for v in children]}")
    try:
        res = op_func(*children_results,**u_data)
    except Exception as e:
        # chain explicitly so the original traceback is reported as the direct cause
        raise Exception(f'During excution of node {u} with args {children_results} and kwargs {u_data}'
                        f' got error {e}'
        ) from e
    if log:
        logger.debug(f"result of node {u} is {res}")
    results[u].append(res)
    return res


In [None]:
#| export
def compute_acyclic_node(G,u,results,stack=None):
    """Evaluate a node that does not depend on any cycle and mark it as final.

    `stack` is accepted for signature parity with `compute_recursive_node` but is not
    used: an empty stack is always passed on.
    """
    value = _collect_children_and_run(G,u,results,[])
    logger.debug(f"setting {u} to final since it is acyclic\n")
    G.nodes[u].update(final=True)
    return value

def compute_recursive_node(G,u,results,stack=None):
    """Run one evaluation pass of node `u`, recursing into its successors first.

    `results` maps each node to the list of values computed for it so far and `stack`
    holds the nodes currently being evaluated on the path from the root.
    A node is marked 'final' once all of its children are final or once two
    consecutive passes produced equal results; final nodes are not recomputed.
    Returns the value computed (or the last stored value for a final node).
    """

    if stack is None:
        stack = []

    children = list(G.successors(u))
    u_data = G.nodes[u]
    # not used below, but the lookup raises KeyError early for a missing/unknown 'op'
    op_func = op_to_func[u_data['op']]

    # already settled: reuse the last stored value
    if u_data.get('final',False):
        return results[u][-1]


    logger.debug(f"computing node {u} with stack {stack}")


    went_in_a_cycle = u in stack
    if went_in_a_cycle:
        logger.debug(f"went in a cycle at {u}, computing op with empty children if necessary\n")
        # we came back to a node that is still being evaluated: do not recurse again,
        # evaluate it with whatever latest value `results` currently holds for each child
        res = _collect_children_and_run(G,u,results,stack,log=True)
        return res

    # if we are here we are in a cycle but didnt return to an old position yet
    # then we compute all our children first
    for v in children:
        # u is on the stack only while its children are being evaluated
        stack.append(u)
        compute_recursive_node(G,v,results,stack)
        stack.pop()

    # compute and mark as final if reached fixed point
    res = _collect_children_and_run(G,u,results,stack,log=True)

    all_children_final = all(G.nodes[v].get('final',False) for v in children)
    # fixed point: this pass produced the same value as the previous one
    fixed_point_reached = len(results[u])>1 and results[u][-1].equals(results[u][-2])

    if all_children_final:
        logger.debug(f"setting {u} to final since all children are final\n")
        G.nodes[u]['final'] = True
    elif fixed_point_reached:
        logger.debug(f"setting {u} to final since fixed point has been achieved\n")
        # if u==9:
        #     logger.debug(f"graph nodes are{g.nodes(data=True)}")
        G.nodes[u]['final'] = True
    else:
        logger.debug(f"{u} not final yet so we will need to run another iteration\n")

    return res



def compute_node(G,root,ret_inter=False):
    """Evaluate the term graph `G` bottom-up and return the value of `root`.

    Nodes that do not depend on a cycle are computed once in post-order; the rest are
    computed by repeatedly calling `compute_recursive_node` on `root` until the root is
    marked final. The 'final' node attribute of `G` is modified in place.

    With `ret_inter=True` returns `(result, results_dict)` where `results_dict` maps
    every node to the list of values computed for it (the first entry is a None
    placeholder).
    """
    # 'final' flags are stored on the graph itself. Flags left over from an earlier
    # call would make the recursion return the None placeholder from the fresh
    # results dict, so start from a clean state.
    for u in G.nodes:
        G.nodes[u].pop('final',None)

    # makes sure there is always a last value in the list for each key
    # which is None
    def _none_seeded_list():
        return [None]
    results_dict = defaultdict(_none_seeded_list)

    depends_on_cycle = _depends_on_cycle(G)
    not_depends_on_cycle = [u for u in G.nodes if u not in depends_on_cycle]

    # compute non cyclic nodes in postorder
    non_cycle_topological_sort = list(nx.topological_sort(nx.DiGraph(nx.subgraph(G,not_depends_on_cycle))))
    for u in non_cycle_topological_sort[::-1]:
        compute_acyclic_node(G,u,results_dict)

    logger.debug(f"the following nodes were computed non cyclically {non_cycle_topological_sort}")
    # now that all initial conditions for recursions are set
    # run the compute_recursive_node on u
    logger.debug(f"running compute_recursive_node on {root}")

    # iterate until the root reaches its fixed point
    while True:
        res = compute_recursive_node(G,root,results_dict)
        if G.nodes[root].get('final',False):
            break

    if ret_inter:
        return res,results_dict
    else:
        return res


#### Test - path query

In [None]:
graph  = nx.DiGraph()
graph.add_nodes_from([
    0,1,2,3,
])
graph.add_edges_from(
    [(0,1),(0,2),(1,3),(2,3),(3,4)]
)
draw(graph)
edges_df = pd.DataFrame(list(graph.edges),columns=['S','T'])
edges_df
db = DB({
    'edges':edges_df
})

In [None]:
expected_paths = pd.DataFrame(
    [
        [0,1],
        [0,2],
        [1,3],
        [2,3],
        [3,4],
        [0,3],
        [0,4],
        [1,4],
        [2,4]
    ],
    columns=['S','T']
)

In [None]:
g = nx.DiGraph()
g.add_nodes_from([
    ('edges',{'rel':'edges','op':'get_rel','db':db}),
    (1,{'op':'rename','schema':['S','T']}),
    (2,{'op':'rename','schema':['S','X']}),
    (3,{'op':'rename','schema':['X','T']}),
    (4,{'op':'join','schema':['S','X','T']}),
    (5,{'op':'project','schema':['S','T']}),
    ('reachable',{'op':'union','schema':[0,1]}),
    (6,{'op':'rename','schema':['S','T']})]
)
g.add_edges_from([
    (1,'edges'),
    (2,'edges'),
    (4,2),
    (4,3),
    (5,4),
    ('reachable',5),
    ('reachable',1),
    (3,'reachable'),
    (6,'reachable')
])
draw(g)

In [None]:
root = 6
# with checkLogs():
res,inter = compute_node(g,root,True)
assert_df_equals(res,expected_paths)

## e2e tests

#### Case 0 - path queries

In [None]:
e=Engine()
e.set_relation(RelationDefinition(name='edges',scheme=[int,int]))
e.add_facts('edges',edges_df)

base_rule = Rule(
    head=Relation(name='reachable',terms=[FreeVar(name='S'),FreeVar(name='T')]),
    body=[
        Relation(name='edges',terms=[FreeVar(name='S'),FreeVar(name='T')]),
    ])

rec_rule = Rule(
    head=Relation(name='reachable',terms=[FreeVar(name='S'),FreeVar(name='T')]),
    body=[
        Relation(name='edges',terms=[FreeVar(name='S'),FreeVar(name='X')]),
        Relation(name='reachable',terms=[FreeVar(name='X'),FreeVar(name='T')]),
    ])

e.add_rule(base_rule,RelationDefinition(name='reachable',scheme=[int,int]))
e.add_rule(rec_rule,RelationDefinition(name='reachable',scheme=[int,int]))

print(e.rules_to_ids)


In [None]:
list(e.run_query(Relation(name='reachable',terms=[FreeVar(name='S'),FreeVar(name='T')])))

In [None]:
q,r = e.plan_query(Relation(name='reachable',terms=[FreeVar(name='S'),FreeVar(name='T')]))
draw(q)
# with checkLogs():
res,inter = e.run_query(Relation(name='reachable',terms=[FreeVar(name='S'),FreeVar(name='T')]),return_intermediate=True)

assert_df_equals(pd.Dataframe(columns=("S", "T"), data=res),expected_paths)

In [None]:
# make sure we actually got the intermediate results
assert len(inter)!=0

#### Case1

In [None]:
e = Engine()
e.set_relation(RelationDefinition(name='S', scheme=[int,int]))
e.set_relation(RelationDefinition(name='S2', scheme=[int,int,int]))

e.add_rule(r1,RelationDefinition(name='R', scheme=[int,int]))
e.add_rule(r2,RelationDefinition(name='R', scheme=[int,int]))

e.add_facts('S',s)
e.add_facts('S2',s2)

def func(x,y):
    return [(y,x)]

ie_def = IEFunction(name='T',func=func,in_schema=[int,int],out_schema=[int,int])

e.set_ie_function(ie_def)
g = e._inline_db_and_ies_in_graph(e.term_graph)
print(e.rules_to_ids)
display(s)
display(s2)
# draw(g)

In [None]:
res = pd.DataFrame(columns=("X", "Y"), data=e.run_query(Relation(name='R',terms=[FreeVar(name='X'),FreeVar(name='Y')])))


In [None]:
assert_df_equals(res,pd.DataFrame([
    [1,1],
    [2,2],
    [3,3]
],columns=['X','Y']))
res

In [None]:
res = pd.DataFrame(columns=("S",), data=e.run_query(Relation(name='R',terms=[FreeVar(name='S'),3])))
assert_df_equals(res,pd.DataFrame([
    [3]
],columns=['S']))
res


#### case 2

In [None]:
e2 = Engine()
e2.set_relation(RelationDefinition(name='C', scheme=[int,int]))
e2.set_relation(RelationDefinition(name='D', scheme=[int,int]))

for rule in [rec_r1,rec_r2,rec_r3,rec_r4]:
    e2.add_rule(rule,RelationDefinition(name=rule.head.name, scheme=[int,int]))

e2.add_fact(Relation(name='C',terms=[1,2]))
e2.add_fact(Relation(name='D',terms=[3,4]))

g2 = e2._inline_db_and_ies_in_graph(e2.term_graph)
e2.rules_to_ids

In [None]:
# with checkLogs():
res = pd.DataFrame(columns=("X", "Y"), data=e2.run_query(Relation(name='A',terms=[FreeVar(name='X'),FreeVar(name='Y')])))
assert_df_equals(res,pd.DataFrame([
    [3,4],
    [1,2]
],columns=['X','Y']))

In [None]:
res = pd.DataFrame(columns=("X", "Y"), data=e2.run_query(Relation(name='B',terms=[FreeVar(name='X'),FreeVar(name='Y')])))
assert_df_equals(res,pd.DataFrame([
    [3,4],
    [1,2]
],columns=['X','Y']))

### Case 3

In [None]:
e = Engine()
e.set_relation(RelationDefinition(name='string', scheme=[str]))
e.add_fact(Relation(name='string',terms=['a']))
e.add_fact(Relation(name='string',terms=['aa']))
def func(str):
    yield (len(str),)

e.set_ie_function(IEFunction(name='Length',func=func,in_schema=[str],out_schema=[int]))

r = Rule(
    head=Relation(name='string_length', terms=[FreeVar(name='Str'), FreeVar(name='Len')]),
    body=[
        Relation(name='string', terms=[FreeVar(name='Str')]),
        IERelation(name='Length', in_terms=[FreeVar(name='Str')], out_terms=[FreeVar(name='Len')]),
    ])

e.add_rule(r,RelationDefinition(name='string_length', scheme=[str,int]))
# check that adding the same rule twice does nothing.
e.add_rule(r,RelationDefinition(name='string_length', scheme=[str,int]))



g = e._inline_db_and_ies_in_graph(e.term_graph)
print(e.rules_to_ids)
draw(g)


In [None]:
# with checkLogs():
res = pd.DataFrame(columns=("Str", "Len"), data=e.run_query(Relation(name='string_length', terms=[FreeVar(name='Str'), FreeVar(name='Len')])))
assert_df_equals(res,pd.DataFrame([
    ['a',1],
    ['aa',2]
],columns=['Str','Len']))

#### Aggregation

In [None]:
s3 = pd.DataFrame([
    [1,2,"str"],
    [1,3,"str"],
    [3,4,"str"],
    [3,5,"str"]
])

In [None]:
e = Engine()
e.set_relation(RelationDefinition(name='S3', scheme=[int,int,str]))
e.add_facts('S3',s3)

e.set_agg_function(AGGFunction(name='max',func='max',in_schema=[int],out_schema=[float]))
e.set_agg_function(AGGFunction(name='count',func='count',in_schema=[str],out_schema=[int]))

agg_rule = Rule(
    head=Relation(name='R', terms=[FreeVar(name='X'), FreeVar(name='Y'), FreeVar(name='Z')],
    agg=[None,'max','count']
    ),
    body=[
        Relation(name='S3', terms=[FreeVar(name='X'),FreeVar(name='Y'),FreeVar(name='Z')]),
    ])

e.add_rule(agg_rule,RelationDefinition(name='R', scheme=[int,int,int]))

g = e._inline_db_and_ies_in_graph(e.term_graph)
print(e.rules_to_ids)
draw(g)

In [None]:
res = pd.DataFrame(columns=("X", "Y", "Z"), data=e.run_query(Relation(name='R',terms=[FreeVar(name='X'),FreeVar(name='Y'),FreeVar(name='Z')])))
assert_df_equals(res,pd.DataFrame([
    [1,3.0,2],
    [3,5.0,2]
],columns=['X','Y','Z']))
res

In [None]:
#|hide
import nbdev; nbdev.nbdev_export()
     