# Engine
> Execution of spannerlog commands

In [None]:
#| default_exp engine

In [None]:
#| hide
from nbdev.showdoc import show_doc
from IPython.display import display, HTML
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from abc import ABC, abstractmethod
import pytest
from collections import defaultdict

import pandas as pd
from pathlib import Path
from typing import no_type_check, Set, Sequence, Any,Optional,List,Callable,Dict,Union
from pydantic import BaseModel
import networkx as nx
import itertools
from graph_rewrite import draw, draw_match, rewrite, rewrite_iter
from spannerlib.utils import serialize_graph,serialize_df_values,checkLogs,get_new_node_name
from spannerlib.span import Span,SpanParser

import logging
logger = logging.getLogger(__name__)

## Basic Datatypes

In [None]:
#| export
from enum import Enum
from typing import Any
from pydantic import ConfigDict


class Var(BaseModel):
    """A term that refers to a variable by name.

    NOTE(review): presumably this names an entry stored through
    ``Engine.set_var`` -- confirm against the parser/session code.
    """
    name: str
    def __hash__(self):
        # hash on the name only, so instances can be stored in sets / dict keys
        return hash(self.name)

class FreeVar(BaseModel):
    """A free (logical) variable appearing as a term of a relation in a rule or query."""
    name: str
    def __hash__(self):
        # hash on the name only; needed because free variables are collected in sets
        # when computing which variables are already bound
        return hash(self.name)

PrimitiveType=Union[str,int,Span]
Type = Union[str,int,Span,Var,FreeVar]

class RelationDefinition(BaseModel):
    """Schema declaration of a relation: its name and one python type per column."""
    # arbitrary types are allowed so that non-pydantic classes such as Span can appear in the scheme
    model_config = ConfigDict(arbitrary_types_allowed=True)
    # name of the relation
    name: str
    # column types, in positional order
    scheme: List[type]

class Relation(BaseModel):
    """A relation atom: a relation name applied to a positional list of terms.

    Terms are either constants (str, int, Span) or variables (Var, FreeVar).
    """
    model_config = ConfigDict(arbitrary_types_allowed=True)
    # name of the relation being referenced
    name: str
    # positional terms of the atom
    terms: List[Type]

class IEFunction(BaseModel):
    """Registration record of an IE function: the callable plus its input and output schemas."""
    model_config = ConfigDict(arbitrary_types_allowed=True)
    # name under which the function is referenced from IE relations
    name: str
    # the python callable implementing the function
    func: Callable
    # types of the positional inputs
    in_schema: List[type]
    # types of the positional outputs
    out_schema: List[type]


class IERelation(BaseModel):
    """An IE atom inside a rule body: a named IE function with input and output terms."""
    model_config = ConfigDict(arbitrary_types_allowed=True)
    name: str
    in_terms: List[Type]
    out_terms: List[Type]

    def __hash__(self):
        # The hash is derived from a textual signature made of the name and the
        # string form of every input and output term.
        inputs = '_'.join(map(str, self.in_terms))
        outputs = '_'.join(map(str, self.out_terms))
        signature = self.name + '_in_' + inputs + '_out_' + outputs
        return hash(signature)
class Rule(BaseModel):
    """A rule: a head relation derived from a body of relations and IE relations."""
    model_config = ConfigDict(arbitrary_types_allowed=True)
    # the relation atom being defined by this rule
    head: Relation
    # the atoms whose combination derives the head
    body: List[Union[Relation,IERelation]]

In [None]:
#| export
def pretty(obj):
    """Render an engine object the way it would be written in spannerlog.

    Used for user-facing messages, so that objects read like spannerlog
    source rather than like python reprs. Spans, variables, relation
    definitions, relations, IE relations, IE functions, rules and python
    types each get a dedicated format; anything else falls back to ``str``.
    """

    def _joined(items):
        # comma separated pretty form of every item
        return ','.join(pretty(item) for item in items)

    if isinstance(obj,Span):
        return f"[{obj.start},{obj.end})"
    if isinstance(obj,(Var,FreeVar)):
        return obj.name
    if isinstance(obj,RelationDefinition):
        return f"{obj.name}({_joined(obj.scheme)})"
    if isinstance(obj,Relation):
        return f"{obj.name}({_joined(obj.terms)})"
    if isinstance(obj,IERelation):
        return f"{obj.name}({_joined(obj.in_terms)}) -> ({_joined(obj.out_terms)})"
    if isinstance(obj,IEFunction):
        return f"{obj.name}({_joined(obj.in_schema)}) -> ({_joined(obj.out_schema)})"
    if isinstance(obj,Rule):
        return f"{pretty(obj.head)} <- {_joined(obj.body)}"
    if isinstance(obj,type):
        return obj.__name__
    return str(obj)

In [None]:
# A rule mixing a regular relation (with a Span constant) and an IE relation (with an int constant);
# checks that pretty() renders it in spannerlog syntax.
rule = Rule(
    head=Relation(name='R', terms=[FreeVar(name='X'), FreeVar(name='Y'), FreeVar(name='Z')]),
    body=[
        Relation(name='S', terms=[FreeVar(name='X'), Span(start=1,end=4)]),
        IERelation(name='T', in_terms=[FreeVar(name='X'), 1], out_terms=[FreeVar(name='Y'), FreeVar(name='Z')])
    ])
assert pretty(rule) == 'R(X,Y,Z) <- S(X,[1,4)),T(X,1) -> (Y,Z)'

In [None]:
schema = RelationDefinition(name='R', scheme=[int, str, Span])
assert pretty(schema) == 'R(int,str,Span)'
ie_func_schema = IEFunction(name='f', in_schema=[int, str], out_schema=[str, Span],func=lambda x,y: (y,Span(1,2)))
assert pretty(ie_func_schema) == 'f(int,str) -> (str,Span)'

## Term graph

This section contains the functions that convert rules into term graphs and merge those graphs.

In [None]:
#| export
def _get_bounding_order(rule:Rule)->List[Union[Relation,IERelation]]:
    """Get an order of evaluation for the body of a rule.

    All regular relations come first (in body order), since they bind their
    free variables unconditionally. IE relations follow: an IE relation is
    scheduled as soon as every free variable among its input terms is already
    bound, and its output free variables then become bound as well. Among
    several schedulable IE relations the one appearing first in the body wins,
    so the result is deterministic.

    This is a very naive ordering that can be heavily optimized.

    Raises:
        ValueError: if some IE relation can never be scheduled because one of
            its input free variables is not bound by any other body element.
    """

    # we start with all relations since they can be bound at once
    order = []
    bounded_vars = set()
    for rel in rule.body:
        if isinstance(rel,Relation):
            order.append(rel)
            bounded_vars.update(term for term in rel.terms if isinstance(term,FreeVar))

    # dict.fromkeys drops duplicate IE relations while keeping body order
    pending = list(dict.fromkeys(rel for rel in rule.body if isinstance(rel,IERelation)))
    while pending:
        for ie_rel in pending:
            in_free_vars = {term for term in ie_rel.in_terms if isinstance(term,FreeVar)}
            if in_free_vars.issubset(bounded_vars):
                break
        else:
            # no progress is possible; looping again would never terminate
            unbound_names = ', '.join(ie_rel.name for ie_rel in pending)
            raise ValueError(
                f"cannot order rule body: the inputs of IE relations [{unbound_names}] "
                f"use free variables that are never bound")

        order.append(ie_rel)
        bounded_vars.update(term for term in ie_rel.out_terms if isinstance(term,FreeVar))
        pending.remove(ie_rel)

    return order

In [None]:
# The body is deliberately listed out of evaluation order:
# T2 consumes Y, which is only produced by T, and both IE relations consume X,
# which is bound by the regular relations S and S2.
r = Rule(
    head=Relation(name='R', terms=[FreeVar(name='X'), FreeVar(name='Y'), FreeVar(name='Z')]),
    body=[
        IERelation(name='T2', in_terms=[FreeVar(name='X'), FreeVar(name='Y')], out_terms=[FreeVar(name='W'), FreeVar(name='Z')]),
        IERelation(name='T', in_terms=[FreeVar(name='X'), 1], out_terms=[FreeVar(name='Y'), FreeVar(name='Z')]),
        Relation(name='S', terms=[FreeVar(name='X'), Span(start=1,end=4)]),
        Relation(name='S2', terms=[FreeVar(name='X'), FreeVar(name='A'),FreeVar(name='B')]),

    ])

# relations first (in body order), then IE relations once their inputs are bound
order = _get_bounding_order(r)
assert [o.name for o in order ] == ['S','S2', 'T', 'T2']
order

[Relation(name='S', terms=[FreeVar(name='X'), [1,4)]),
 Relation(name='S2', terms=[FreeVar(name='X'), FreeVar(name='A'), FreeVar(name='B')]),
 IERelation(name='T', in_terms=[FreeVar(name='X'), 1], out_terms=[FreeVar(name='Y'), FreeVar(name='Z')]),
 IERelation(name='T2', in_terms=[FreeVar(name='X'), FreeVar(name='Y')], out_terms=[FreeVar(name='W'), FreeVar(name='Z')])]

In [None]:

x=itertools.count()
type(x)

itertools.count

In [None]:
#| export
def _name_node(counter):
    if isinstance(counter,itertools.count):
        return next(counter)
    else: # if its just the name to give
        return counter

def _select_if_needed(g,node_counter,source_node,terms):
    """add a project node as a father of source_node if the terms are not all free variables
    returns the source_node if no project is needed, or the project node if it is needed
    the name of the project node should be supplied
    """

    need_select = any(not isinstance(term,FreeVar) for term in terms)
    if not need_select:
        return source_node
    
    select_pos_val = list()
    for i,term in enumerate(terms):
        if not isinstance(term,FreeVar):
            select_pos_val.append((i,term))
    
    select_name = _name_node(node_counter)
    g.add_node(select_name, op='select',theta=select_pos_val)
    g.add_edge(select_name,source_node)
    return select_name

def _product_if_needed(g,node_counter,source_node,terms):
    """add a product node as a father of source_node if the terms are not all free variables
    returns the source_node if no product is needed, or the product node if it is needed
    the name of the product node should be supplied
    """

    need_product = any(not isinstance(term,FreeVar) for term in terms)
    if not need_product:
        return source_node
    
    product_pos_val = list()
    for i,term in enumerate(terms):
        if not isinstance(term,FreeVar):
            product_pos_val.append((i,term))
    
    product_name = _name_node(node_counter)
    g.add_node(product_name, op='product',theta=product_pos_val)
    g.add_edge(product_name,source_node)
    return product_name

# TODO from here, iteratively build the joins, each time using _project_if_needed on the outrel and inrel of the relations/ierelations
def _rule_to_term_graph(rule:Rule,rule_id) -> nx.DiGraph:
    """Convert a rule to a directed RA+IE term graph

    Edges point from an operator node to the node it takes its input from.
    Relation nodes and the head union node are named after their relation;
    every other node is named with the next integer of a per-rule counter.
    Each node gets a ``rule_id`` attribute holding the set ``{rule_id}``.

    The body is processed in the order given by ``_get_bounding_order``:
    * a relation becomes ``rel <- rename [<- select]``
    * an IE relation becomes ``project [<- product] <- calc <- rename [<- select]``
    The per-atom sub graphs are then chained with join nodes, the result is
    projected onto the head terms and put under a union node named after the head.
    """
    node_counter = itertools.count()
    G = nx.DiGraph()
    # add nodes for all relations
    # one (bottom, top) pair per body atom; bottom is None for regular relations
    body_term_connectors = list()
    body_rels = _get_bounding_order(rule)

    # create derivation for each rel in the body
    for rel_idx,rel in enumerate(body_rels):
        if isinstance(rel,Relation):
            G.add_node(rel.name,rel=rel.name)
            rename_node = _name_node(node_counter)
            # name the columns at free variable positions after those variables
            G.add_node(rename_node, op='rename',names=[(i,term.name) for i,term in enumerate(rel.terms) if isinstance(term,FreeVar)])
            G.add_edge(rename_node,rel.name)
            top_rel_node = _select_if_needed(G,node_counter,rename_node,rel.terms)
            
            body_term_connectors.append((None,top_rel_node))

        elif isinstance(rel,IERelation):
            get_input_node_name =_name_node(node_counter)
            calc_node_name = _name_node(node_counter)
            # the project node collects the input free variables of the IE function
            G.add_node(get_input_node_name, op='project', on=[term.name for term in rel.in_terms if isinstance(term,FreeVar)])
            G.add_node(calc_node_name, op='calc',func=rel.name)

            # NOTE: product_name is never used, but drawing it advances node_counter,
            # so removing this line would renumber all following nodes
            product_name = _name_node(node_counter)
            calc_son = _product_if_needed(G,node_counter,get_input_node_name,rel.in_terms)
            G.add_edge(calc_node_name,calc_son)
            rename_node = _name_node(node_counter)
            G.add_node(rename_node, op='rename',names=[(i,term.name) for i,term in enumerate(rel.out_terms) if isinstance(term,FreeVar)])
            G.add_edge(rename_node,calc_node_name)
            # NOTE: select_name is never used either; same numbering caveat as product_name
            select_name = _name_node(node_counter)
            top_rel_node = _select_if_needed(G,node_counter,rename_node,rel.out_terms)
            body_term_connectors.append((get_input_node_name,top_rel_node))

    # connect outputs of different rels via joins
    # and connect input of ie functons into the join
    for i,(connectors,rel) in enumerate(zip(body_term_connectors,body_rels)):
        if i == 0:
            # the first atom has nothing to be joined with yet
            # NOTE(review): if the first atom is an IE relation its input project node
            # is not connected to anything here -- confirm this is intended
            prev_top = connectors[1]
            continue

        current_top = connectors[1]

        join_node_name = _name_node(node_counter)
        G.add_node(join_node_name, op='join')
        G.add_edge(join_node_name,prev_top)
        G.add_edge(join_node_name,current_top)

        if isinstance(rel,IERelation):
            # the IE function reads its inputs from everything computed so far
            ie_bottom = connectors[0]
            G.add_edge(ie_bottom,prev_top)


        prev_top = join_node_name

    # project all assignments into the head
    # NOTE(review): prev_top is only assigned inside the loop above, so a rule whose
    # body yields no atoms fails on the add_edge call below with an unbound local
    head_project_name = _name_node(node_counter)
    G.add_node(head_project_name, op='project', on=[term.name for term in rule.head.terms],rel=f'_{rule.head.name}_{rule_id}')
    G.add_edge(head_project_name,prev_top)

    # add a union for each rule for the given head
    G.add_node(rule.head.name,op='union',rel=rule.head.name)
    G.add_edge(rule.head.name,head_project_name)

    # add rule id for each node
    for u in G.nodes:
        G.nodes[u]['rule_id'] = {rule_id}
    return G

In [None]:
#TODO FROM HERE add labels to nodes we can labels
# add HEAD projection node 
# maybe add free vars
g = _rule_to_term_graph(r,0)
draw(g)
serialize_graph(g)
assert serialize_graph(g) == ([('S', { 'rel': 'S', 'rule_id': {0}}),
  (0, {'op': 'rename', 'names': [(0, 'X')], 'rule_id': {0}}),
  (1, {'op': 'select', 'theta': [(1, Span(1,4))], 'rule_id': {0}}),
  ('S2', { 'rel': 'S2', 'rule_id': {0}}),
  (2,
   {'op': 'rename', 'names': [(0, 'X'), (1, 'A'), (2, 'B')], 'rule_id': {0}}),
  (3, {'op': 'project', 'on': ['X'], 'rule_id': {0}}),
  (4, {'op': 'calc', 'func': 'T', 'rule_id': {0}}),
  (6, {'op': 'product', 'theta': [(1, 1)], 'rule_id': {0}}),
  (7, {'op': 'rename', 'names': [(0, 'Y'), (1, 'Z')], 'rule_id': {0}}),
  (9, {'op': 'project', 'on': ['X', 'Y'], 'rule_id': {0}}),
  (10, {'op': 'calc', 'func': 'T2', 'rule_id': {0}}),
  (12, {'op': 'rename', 'names': [(0, 'W'), (1, 'Z')], 'rule_id': {0}}),
  (14, {'op': 'join', 'rule_id': {0}}),
  (15, {'op': 'join', 'rule_id': {0}}),
  (16, {'op': 'join', 'rule_id': {0}}),
  (17,
   {'op': 'project', 'on': ['X', 'Y', 'Z'], 'rel': '_R_0', 'rule_id': {0}}),
  ('R', {'op': 'union', 'rel': 'R', 'rule_id': {0}})],
 [(0, 'S', {}),
  (1, 0, {}),
  (2, 'S2', {}),
  (3, 14, {}),
  (4, 6, {}),
  (6, 3, {}),
  (7, 4, {}),
  (9, 15, {}),
  (10, 9, {}),
  (12, 10, {}),
  (14, 1, {}),
  (14, 2, {}),
  (15, 14, {}),
  (15, 7, {}),
  (16, 15, {}),
  (16, 12, {}),
  (17, 16, {}),
  ('R', 17, {})])

In [None]:
r1 = Rule(
    head=Relation(name='R', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='S', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
        Relation(name='S2', terms=[FreeVar(name='X'), FreeVar(name='A'),1]),
    ])

r2 = Rule(
    head=Relation(name='R', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='S', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
    ])

r3 = Rule(
    head=Relation(name='R2', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='S3', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
        Relation(name='S2', terms=[FreeVar(name='X'), FreeVar(name='A'),1]),
    ])
rules = [r1,r2,r3]
for r in rules:
    print(pretty(r))
t1,t2,t3 = [_rule_to_term_graph(r,i) for i,r in enumerate(rules)]

R(X,Y) <- S(X,Y),S2(X,A,1)
R(X,Y) <- S(X,Y)
R2(X,Y) <- S3(X,Y),S2(X,A,1)


In [None]:
#| export
def graph_compose(g1,g2,mapping_dict,debug=False):
    """Compose two graphs, identifying nodes of ``g2`` with nodes of ``g1``.

    Args:
        g1: the graph being merged into.
        g2: the graph being merged in.
        mapping_dict: maps node names of ``g2`` to the node names of ``g1``
            they should be identified with. It is not modified.
        debug: when True, return the complete relabeling computed for ``g2``
            instead of the merged graph.

    Every node of ``g2`` that is not in ``mapping_dict`` but whose name
    collides with a node of ``g1`` is relabeled to a fresh name. For each
    explicitly mapped node, the ``rule_id`` sets of both sides are unioned
    in the merged graph.

    Returns:
        The merged graph, or the full ``g2`` relabeling dict when ``debug``.
    """
    # work on copies so the caller's dict is left untouched
    requested_mapping = dict(mapping_dict)
    full_mapping = dict(mapping_dict)

    # making new names into a digraph is a dirty hack, TODO resolve this
    # (get_new_node_name only knows how to avoid names held by graphs)
    reserved_names = nx.DiGraph()
    for u2 in g2.nodes():
        if u2 not in full_mapping and u2 in g1.nodes():
            fresh_name = get_new_node_name(g2,avoid_names_from=[g1,reserved_names])
            full_mapping[u2] = fresh_name
            reserved_names.add_node(fresh_name)
    if debug:
        return full_mapping

    g2 = nx.relabel_nodes(g2,full_mapping,copy=True)
    merged_graph = nx.compose(g1,g2)

    # after relabeling, a mapped node carries its g1 name in both graphs,
    # so both lookups must use the target name and not the original g2 name
    for g1_name in requested_mapping.values():
        rule_ids1 = g1.nodes[g1_name].get('rule_id',set())
        rule_ids2 = g2.nodes[g1_name].get('rule_id',set())
        merged_graph.nodes[g1_name]['rule_id'] = rule_ids1.union(rule_ids2)

    return merged_graph


In [None]:
draw(t1)
draw(t2)
draw(t3)

In [None]:
assert graph_compose(t1,t3,{
    'S':'S',0:0,1:1,
},debug=True) == {'S': 'S', 0: 0, 1: 1, 'S2': 5, 2: 6, 3: 7, 4: 8}

In [None]:
assert graph_compose(t1,t2,
    mapping_dict = {'S':'S','R':'R',0:0}
    ,debug=True) == {'S': 'S', 'R': 'R', 0: 0, 1: 5}

In [None]:
m= graph_compose(t1,t2,
    mapping_dict = {'S':'S','R':'R',0:0}) 
draw(m)

In [None]:
#| export
def merge_term_graphs_pair(g1,g2,exclude_props = ['label'],debug=False):
    """Merge two term graphs into one term graph.

    A node of ``g2`` is identified with a node of ``g1`` when both carry a
    ``rel`` attribute with the same value; this is what lets rules with the
    same head, or rules that read the same relation, share nodes. When
    several nodes of ``g1`` carry the same ``rel``, the first one in
    ``g1.nodes()`` order is used. All other nodes are kept separate.

    Merging structurally identical operator nodes was dropped on purpose:
    once merged it is hard to remember which node belongs to which rule, so
    that optimisation is left to be done per query.

    Args:
        g1: the graph being merged into.
        g2: the graph being merged in; traversed in DFS postorder.
        exclude_props: currently unused, kept for interface compatibility.
        debug: when True, return the ``g2 node -> g1 node`` mapping instead
            of the merged graph.
    """
    # first g1 node for every rel value, so each g2 node needs one lookup
    rel_to_g1_node = {}
    for u1,u1_data in g1.nodes(data=True):
        if 'rel' in u1_data:
            rel_to_g1_node.setdefault(u1_data['rel'],u1)

    node_mappings = dict()  # g2 node name to g1 node name
    for u2 in nx.dfs_postorder_nodes(g2):
        u2_data = g2.nodes[u2]
        if 'rel' in u2_data and u2_data['rel'] in rel_to_g1_node:
            node_mappings[u2] = rel_to_g1_node[u2_data['rel']]

    if debug:
        return node_mappings
    return graph_compose(g1,g2,node_mappings)



def merge_term_graphs(gs,exclude_props = ['label'],debug=False):
    """Merge a list of term graphs into one term graph.

    The graphs are merged left to right. When ``debug`` is set, only the
    last merge runs in debug mode, so a list of merges can be debugged
    iteratively by growing the list.

    A single graph needs no merging: it is returned as is, or an empty
    mapping is returned in debug mode.

    Raises:
        IndexError: if ``gs`` is empty.
    """
    merge = gs[0]
    if len(gs) == 1:
        # without this guard the only graph would be merged with itself
        return {} if debug else merge

    for g in gs[1:-1]:
        merge = merge_term_graphs_pair(merge,g,exclude_props,debug=False)
    # if debug, we run debug only on the last merge so we can iteratively debug a list of merges
    return merge_term_graphs_pair(merge,gs[-1],exclude_props,debug=debug)


#### Tests

In [None]:
r1 = Rule(
    head=Relation(name='A', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='B', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
    ])

r2 = Rule(
    head=Relation(name='A', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='C', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
    ])

r3 = Rule(
    head=Relation(name='B', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='D', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
    ])

r4 = Rule(
    head=Relation(name='B', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='A', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
    ])



In [None]:
rules = [r1,r2,r3,r4]
print([pretty(r) for r in rules])
for r in rules:
    draw(_rule_to_term_graph(r,0))


['A(X,Y) <- B(X,Y)', 'A(X,Y) <- C(X,Y)', 'B(X,Y) <- D(X,Y)', 'B(X,Y) <- A(X,Y)']


In [None]:
m = merge_term_graphs([_rule_to_term_graph(r,i) for i,r in enumerate(rules)])
draw(m)
# assert serialize_graph(m) == ([('B', {'op': 'union', 'rel': 'B', 'rule_id': {0, 2, 3}}),
#   (0, {'op': 'rename', 'names': [(0, 'X'), (1, 'Y')], 'rule_id': {0}}),
#   (1, {'op': 'project', 'on': ['X', 'Y'], 'rel': '_A_0', 'rule_id': {0}}),
#   ('A', { 'rel': 'A', 'rule_id': {0, 1, 3}}),
#   ('C', { 'rel': 'C', 'rule_id': {1}}),
#   (2, {'op': 'rename', 'names': [(0, 'X'), (1, 'Y')], 'rule_id': {1}}),
#   (3, {'op': 'project', 'on': ['X', 'Y'], 'rel': '_A_1', 'rule_id': {1}}),
#   ('D', { 'rel': 'D', 'rule_id': {2}}),
#   (4, {'op': 'rename', 'names': [(0, 'X'), (1, 'Y')], 'rule_id': {2}}),
#   (5, {'op': 'project', 'on': ['X', 'Y'], 'rel': '_B_2', 'rule_id': {2}}),
#   (6, {'op': 'rename', 'names': [(0, 'X'), (1, 'Y')], 'rule_id': {3}}),
#   (7, {'op': 'project', 'on': ['X', 'Y'], 'rel': '_B_3', 'rule_id': {3}})],
#  [('B', 5, {}),
#   ('B', 7, {}),
#   (0, 'B', {}),
#   (1, 0, {}),
#   ('A', 1, {}),
#   ('A', 3, {}),
#   (2, 'C', {}),
#   (3, 2, {}),
#   (4, 'D', {}),
#   (5, 4, {}),
#   (6, 'A', {}),
#   (7, 6, {})])

In [None]:
m = merge_term_graphs([t1,t2])
draw(m)
assert serialize_graph(m) == ([('S', { 'rel': 'S', 'rule_id': {0, 1}}),
  (0, {'op': 'rename', 'names': [(0, 'X'), (1, 'Y')], 'rule_id': {0}}),
  ('S2', { 'rel': 'S2', 'rule_id': {0}}),
  (1, {'op': 'rename', 'names': [(0, 'X'), (1, 'A')], 'rule_id': {0}}),
  (2, {'op': 'select', 'theta': [(2, 1)], 'rule_id': {0}}),
  (3, {'op': 'join', 'rule_id': {0}}),
  (4, {'op': 'project', 'on': ['X', 'Y'], 'rel': '_R_0', 'rule_id': {0}}),
  ('R', {'op': 'union', 'rel': 'R', 'rule_id': {0, 1}}),
  (5, {'op': 'rename', 'names': [(0, 'X'), (1, 'Y')], 'rule_id': {1}}),
  (6, {'op': 'project', 'on': ['X', 'Y'], 'rel': '_R_1', 'rule_id': {1}})],
 [(0, 'S', {}),
  (1, 'S2', {}),
  (2, 1, {}),
  (3, 0, {}),
  (3, 2, {}),
  (4, 3, {}),
  ('R', 4, {}),
  ('R', 6, {}),
  (5, 'S', {}),
  (6, 5, {})])

In [None]:
m = merge_term_graphs([t1,t3])
draw(m)
assert serialize_graph(m) == ([('S', { 'rel': 'S', 'rule_id': {0}}),
  (0, {'op': 'rename', 'names': [(0, 'X'), (1, 'Y')], 'rule_id': {0}}),
  ('S2', { 'rel': 'S2', 'rule_id': {0, 2}}),
  (1, {'op': 'rename', 'names': [(0, 'X'), (1, 'A')], 'rule_id': {0}}),
  (2, {'op': 'select', 'theta': [(2, 1)], 'rule_id': {0}}),
  (3, {'op': 'join', 'rule_id': {0}}),
  (4, {'op': 'project', 'on': ['X', 'Y'], 'rel': '_R_0', 'rule_id': {0}}),
  ('R', {'op': 'union', 'rel': 'R', 'rule_id': {0}}),
  ('S3', { 'rel': 'S3', 'rule_id': {2}}),
  (5, {'op': 'rename', 'names': [(0, 'X'), (1, 'Y')], 'rule_id': {2}}),
  (6, {'op': 'rename', 'names': [(0, 'X'), (1, 'A')], 'rule_id': {2}}),
  (7, {'op': 'select', 'theta': [(2, 1)], 'rule_id': {2}}),
  (8, {'op': 'join', 'rule_id': {2}}),
  (9, {'op': 'project', 'on': ['X', 'Y'], 'rel': '_R2_2', 'rule_id': {2}}),
  ('R2', {'op': 'union', 'rel': 'R2', 'rule_id': {2}})],
 [(0, 'S', {}),
  (1, 'S2', {}),
  (2, 1, {}),
  (3, 0, {}),
  (3, 2, {}),
  (4, 3, {}),
  ('R', 4, {}),
  (5, 'S3', {}),
  (6, 'S2', {}),
  (7, 6, {}),
  (8, 5, {}),
  (8, 7, {}),
  (9, 8, {}),
  ('R2', 9, {})])

In [None]:
m = merge_term_graphs([t1,t2,t3])
draw(m)
assert serialize_graph(m) == ([('S', { 'rel': 'S', 'rule_id': {0, 1}}),
  (0, {'op': 'rename', 'names': [(0, 'X'), (1, 'Y')], 'rule_id': {0}}),
  ('S2', { 'rel': 'S2', 'rule_id': {0, 2}}),
  (1, {'op': 'rename', 'names': [(0, 'X'), (1, 'A')], 'rule_id': {0}}),
  (2, {'op': 'select', 'theta': [(2, 1)], 'rule_id': {0}}),
  (3, {'op': 'join', 'rule_id': {0}}),
  (4, {'op': 'project', 'on': ['X', 'Y'], 'rel': '_R_0', 'rule_id': {0}}),
  ('R', {'op': 'union', 'rel': 'R', 'rule_id': {0, 1}}),
  (5, {'op': 'rename', 'names': [(0, 'X'), (1, 'Y')], 'rule_id': {1}}),
  (6, {'op': 'project', 'on': ['X', 'Y'], 'rel': '_R_1', 'rule_id': {1}}),
  ('S3', { 'rel': 'S3', 'rule_id': {2}}),
  (7, {'op': 'rename', 'names': [(0, 'X'), (1, 'Y')], 'rule_id': {2}}),
  (8, {'op': 'rename', 'names': [(0, 'X'), (1, 'A')], 'rule_id': {2}}),
  (9, {'op': 'select', 'theta': [(2, 1)], 'rule_id': {2}}),
  (10, {'op': 'join', 'rule_id': {2}}),
  (11, {'op': 'project', 'on': ['X', 'Y'], 'rel': '_R2_2', 'rule_id': {2}}),
  ('R2', {'op': 'union', 'rel': 'R2', 'rule_id': {2}})],
 [(0, 'S', {}),
  (1, 'S2', {}),
  (2, 1, {}),
  (3, 0, {}),
  (3, 2, {}),
  (4, 3, {}),
  ('R', 4, {}),
  ('R', 6, {}),
  (5, 'S', {}),
  (6, 5, {}),
  (7, 'S3', {}),
  (8, 'S2', {}),
  (9, 8, {}),
  (10, 7, {}),
  (10, 9, {}),
  (11, 10, {}),
  ('R2', 11, {})])

## Engine

Helper functions to get schema from dfs and verify them

In [None]:
#| export
import re
STRING_PATTERN = re.compile(r"^[^\r\n]+$")


def _infer_relation_schema(row) -> Sequence[type]: # Inferred type list of the given relation
    """
    Guess the relation type based on the data.
    We support both the actual types (e.g. 'Span'), and their string representation ( e.g. `"[0,8)"`).

    **@raise** ValueError: if there is a cell inside `row` of an illegal type.
    """
    relation_types = []
    for cell in row:
        try:
            int(cell)  # check if the cell can be converted to integer
            relation_types.append(int)
        except (ValueError, TypeError):
            if isinstance(cell, Span) or SpanParser.parse(cell):
                relation_types.append(Span)
            elif re.match(STRING_PATTERN, cell):
                relation_types.append(str)
            else:
                raise ValueError(f"value doesn't match any datatype: {cell}")

    return relation_types

In [None]:
assert _infer_relation_schema([1, 2, 3]) == [ int,int,int]
assert _infer_relation_schema([1, 'a']) == [ int,str]
assert _infer_relation_schema(['[0,1)','[0, 1)',Span(1,3)]) == [Span,Span,Span]

In [None]:
#| export
def _pd_drop_row(df,row_vals):
    new_df = df[(df!=row_vals).all(axis=1)]
    return new_df

def _pd_append_row(df,row_vals):
    return pd.concat([df,pd.DataFrame([row_vals],columns=df.columns)])

In [None]:
# Two rows that differ from each other in every column:
# dropping [3,4] must leave only the first row, appending [5,6] must add a third one.
df = pd.DataFrame([
    [1,'2fs'],[3,4]
])
assert list(_pd_drop_row(df,[3,4]).itertuples(index=False,name=None))==[(1,'2fs')]
assert list(_pd_append_row(df,[5,6]).itertuples(index=False,name=None)) == [(1, '2fs'), (3, 4), (5, 6)]

In [None]:
#| export
class DB(dict):
    """A dict whose repr lists only its keys.

    Keeps the repr short, since the values (dataframes, per the engine's
    usage) would otherwise be printed in full.
    """
    def __repr__(self):
        # str() each key so that a non-string key cannot make repr() raise
        key_str=', '.join(str(key) for key in self)
        return f'DB({key_str})'

In [None]:
#| export
from copy import deepcopy
class Engine():
    """Holds the state of a spannerlog session and answers queries over it.

    The engine keeps:
    * ``symbol_table``: variable name -> ``(type, value)``
    * ``Relation_defs``: relation name -> RelationDefinition
    * ``ie_functions``: IE function name -> IEFunction
    * ``db``: relation name -> dataframe of facts
    * ``term_graph``: the merged term graph of all rules added so far
    * ``rules_to_ids``: pretty string of a rule -> its rule id
    """
    def __init__(self,rewrites=None):
        """Create an empty engine; ``rewrites`` defaults to an empty list."""
        # always set the attribute, plan_query falls back to it
        self.rewrites = [] if rewrites is None else rewrites
        self.symbol_table={
            # key : type,val
        }
        self.Relation_defs={
            # key : RelationDefinition for both real and derived relations
        }
        self.ie_functions={
            # name : IEFunction class
        }

        self.term_graph = nx.DiGraph()

        self.node_counter = itertools.count()
        self.rule_counter = itertools.count()

        self.db = DB(
            # relation_name: dataframe
        )

        # lets skip this for now and keep it a an attribute in the node graph
        self.rules_to_ids = {
            # rule pretty string, to node id in term_graph
        }

        # self.rels_to_nodes() = {
        #     # relation name to node that represents it
        # }

    @classmethod
    def _col_names(cls,length):
        """Return the positional column names ``col0 .. col{length-1}``."""
        # these names wont conflict with logical variables since they must always start with Uppercase letters
        return [f'col{i}' for i in range(length)]


    def set_var(self,var_name,value,read_from_file=False):
        """Define or redefine a variable.

        Redefinition is only allowed with a value of exactly the same type.
        ``read_from_file`` is accepted but currently not used.

        Raises:
            ValueError: if the variable exists with a different type.
        """
        symbol_table = self.symbol_table
        if var_name in symbol_table:
            existing_type,existing_value = symbol_table[var_name]
            if type(value) != existing_type:
                raise ValueError(f"Variable {var_name} was previously defined with {existing_value}({pretty(existing_type)})"
                                f" but is trying to be redefined to {value}({pretty(type(value))}) of a different type which might interfere with previous rule definitions")
        symbol_table[var_name] = type(value),value

    def get_var(self,var_name):
        """Return the ``(type, value)`` pair of a variable, or None if undefined."""
        return self.symbol_table.get(var_name,None)

    def del_var(self,var_name):
        """Delete a variable; raises KeyError if it does not exist."""
        del self.symbol_table[var_name]

    def get_relation(self,rel_name:str):
        """Return the RelationDefinition of ``rel_name``, or None if undefined."""
        return self.Relation_defs.get(rel_name,None)

    def set_relation(self,rel_def:RelationDefinition):
        """Register a relation, creating its empty table and its term graph node.

        Re-registering an identical definition is a no-op.

        Raises:
            ValueError: if the relation exists with a different definition.
        """
        if rel_def.name in self.Relation_defs:
            existing_def = self.Relation_defs[rel_def.name]
            if existing_def != rel_def:
                raise ValueError(f"Relation {rel_def.name} was previously defined with {existing_def}"
                                f"but is trying to be redefined to {rel_def} which might interfere with previous rule definitions")
        else:
            self.Relation_defs[rel_def.name] = rel_def
            #TODO fix make sure that the empty df has the correct types based on the rel_def
            empty_df = pd.DataFrame(columns=self._col_names(len(rel_def.scheme)))
            self.db[rel_def.name] = empty_df
            self.term_graph.add_node(rel_def.name,rel=rel_def.name)

    def del_relation(self,rel_name:str):
        """Not supported yet; always raises NotImplementedError."""
        # TODO we need to think about what to do with all relations that used this rule
        raise NotImplementedError("deleting relations is not supported yet")

    def add_fact(self,fact:Relation):
        """Append the terms of ``fact`` as one row of its relation's table."""
        self.db[fact.name] = _pd_append_row(self.db[fact.name],fact.terms)

    def add_facts(self,rel_name,facts:pd.DataFrame):
        """Append all rows of ``facts`` to the table of ``rel_name``.

        The columns are renamed positionally on a copy; ``facts`` is not modified.
        """
        facts= facts.copy()
        facts.columns = self._col_names(len(facts.columns))
        self.db[rel_name] = pd.concat([self.db[rel_name],facts])

    def del_fact(self,fact:Relation):
        """Remove the rows holding the terms of ``fact`` from its relation's table."""
        self.db[fact.name] = _pd_drop_row(df = self.db[fact.name],row_vals=fact.terms)

    def get_ie_function(self,name:str):
        """Return the IEFunction registered under ``name``, or None."""
        return self.ie_functions.get(name,None)

    def set_ie_function(self,ie_func:IEFunction):
        """Register an IE function under its name, replacing any previous one."""
        self.ie_functions[ie_func.name]=ie_func

    def del_ie_function(self,name:str):
        """Unregister an IE function; raises KeyError if it does not exist."""
        del self.ie_functions[name]

    def add_rule(self,rule:Rule,schema:RelationDefinition=None):
        """Add a rule and merge its term graph into the engine's term graph.

        ``schema`` declares the head relation; it may be omitted when the
        head relation is already defined.

        Raises:
            ValueError: if the head relation is undefined and no schema is
                given, or if ``schema`` conflicts with an existing definition.
        """
        if not self.get_relation(rule.head.name) and schema is None:
            raise ValueError(f"Relation {rule.head.name} not defined before adding the rule with it's head\n"
                             f"And an relation schema was not supplied."
                             f"existing relations are {self.Relation_defs.keys()}")

        if schema is not None:
            self.set_relation(schema)

        rule_id = next(self.rule_counter)

        self.rules_to_ids[pretty(rule)] = rule_id

        g2 = _rule_to_term_graph(rule,rule_id)

        self.term_graph = merge_term_graphs_pair(self.term_graph,g2)


    def del_rule(self,rule_str:str):
        """Delete a rule, identified by its pretty string.

        The rule id is removed from every node of the term graph; nodes
        that are left without any rule id are removed from the graph.

        Raises:
            ValueError: if no such rule exists.
        """
        if rule_str not in self.rules_to_ids:
            raise ValueError(f"Rule {rule_str} does not exist\n"
                             f"existing rules are {self.rules_to_ids.keys()}")
        rule_id = self.rules_to_ids[rule_str]
        g = self.term_graph
        nodes_to_delete=[]
        for u in g.nodes:
            node_rule_ids = g.nodes[u].get('rule_id',set())
            if rule_id in node_rule_ids:
                node_rule_ids.remove(rule_id)
                if len(node_rule_ids) == 0:
                    nodes_to_delete.append(u)
        g.remove_nodes_from(nodes_to_delete)
        # forget the rule, otherwise it would still be listed as existing
        del self.rules_to_ids[rule_str]

    def _inline_db_and_ies_in_graph(self,g:nx.DiGraph):
        """Return a deep copy of ``g`` with everything needed for execution attached.

        Nodes without outgoing edges become ``get_rel`` nodes holding the db;
        ``calc`` nodes get their function name replaced by the actual
        callable and receive the function's output schema.
        """
        g=deepcopy(g)
        for u in g.nodes:
            if g.out_degree(u)==0:
                g.nodes[u]['op'] = 'get_rel'
                g.nodes[u]['db'] = self.db
            elif g.nodes[u]['op'] == 'calc':
                ie_func_name = g.nodes[u]['func']
                g.nodes[u]['func'] = self.ie_functions[ie_func_name].func
                g.nodes[u]['out_schema'] = self.ie_functions[ie_func_name].out_schema
        return g


    def plan_query(self,q_rel:Relation,rewrites=None):
        """Build the query graph for ``q_rel``.

        Returns:
            ``(query_graph, root_node)`` where ``root_node`` is the node
            whose result answers the query.
        """
        if rewrites is None:
            rewrites = self.rewrites
        query_graph = self._inline_db_and_ies_in_graph(self.term_graph)

        # get the sub term graph induced by the relation head
        root_node = q_rel.name
        connected_nodes = list(nx.shortest_path(query_graph,root_node).keys())
        query_graph = nx.DiGraph(nx.subgraph(query_graph,connected_nodes))

        # based on the asked relation, add:
        # select node if there are constants
        # project node to project to the remaining free variables
        # rename node to rename the cols to the Free vars the query is asking for
        select_node = _select_if_needed(query_graph,get_new_node_name(query_graph),root_node,q_rel.terms)
        rename_node = get_new_node_name(query_graph)
        query_graph.add_node(rename_node, op='rename',names=[(i,term.name) for i,term in enumerate(q_rel.terms) if isinstance(term,FreeVar)])
        query_graph.add_edge(rename_node,select_node)
        project_node = get_new_node_name(query_graph)
        query_graph.add_node(project_node, op='project', on=[term.name for term in q_rel.terms if isinstance(term,FreeVar)])
        query_graph.add_edge(project_node,rename_node)

        # TODO for all rewrites, run them
        return query_graph,project_node

    def execute_plan(self,query_graph,root_node):
        """Compute ``root_node`` of ``query_graph`` and return its result."""
        results_dict = defaultdict(list)
        return compute_node(query_graph,root_node,results_dict)

    def run_query(self,q:Relation,rewrites=None):
        """Plan and execute the query ``q``."""
        query_graph,root_node = self.plan_query(q,rewrites)
        return self.execute_plan(query_graph,root_node)


### Test

In [None]:
s = pd.DataFrame([
    [1,1],
    [2,2],
    [3,3],
    [4,4]
])

s2 = pd.DataFrame([
    [1,2,3],
    [2,3,4],
    [3,4,5],
    [4,5,6]
])

In [None]:
# R(X,Y) <- S(X,Y),S2(X,A,3)   -- body with a constant in the last position of S2
r1 = Rule(
    head=Relation(name='R', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='S', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
        Relation(name='S2', terms=[FreeVar(name='X'), FreeVar(name='A'),3]),
    ])

# R(X,Y) <- S(X,Y),T(X,Y) -> (X,Y)   -- second rule for R, this time with an IE relation in the body
r2 = Rule(
    head=Relation(name='R', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='S', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
        IERelation(name='T', in_terms=[FreeVar(name='X'),FreeVar(name='Y')], out_terms=[FreeVar(name='X'),FreeVar(name='Y')]),
    ])

# R2(X,Y) <- S3(X,Y),S2(X,A,1)   -- deleted again further down to exercise del_rule
r3 = Rule(
    head=Relation(name='R2', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='S3', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
        Relation(name='S2', terms=[FreeVar(name='X'), FreeVar(name='A'),1]),
    ])


# The four rules below are mutually recursive: A is defined from B and B from A.
# A(X,Y) <- B(X,Y)
rec_r1 = Rule(
    head=Relation(name='A', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='B', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
    ])

# A(X,Y) <- C(X,Y)
rec_r2 = Rule(
    head=Relation(name='A', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='C', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
    ])

# B(X,Y) <- D(X,Y)
rec_r3 = Rule(
    head=Relation(name='B', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='D', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
    ])

# B(X,Y) <- A(X,Y)
rec_r4 = Rule(
    head=Relation(name='B', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='A', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
    ])


In [None]:
# declare the extensional relations used in the rule bodies
e = Engine()
e.set_relation(RelationDefinition(name='S', scheme=[int,int]))
e.set_relation(RelationDefinition(name='S2', scheme=[int,int,int]))
e.set_relation(RelationDefinition(name='S3', scheme=[int,int]))

# register the rules; each call also passes the definition of the rule's head relation
e.add_rule(r1,RelationDefinition(name='R', scheme=[int,int]))
e.add_rule(r2,RelationDefinition(name='R', scheme=[int,int]))
e.add_rule(r3,RelationDefinition(name='R2', scheme=[int,int]))

In [None]:
e.add_facts('S',s)
e.add_facts('S2',s2)

In [None]:
e.db['S']

Unnamed: 0,col0,col1
0,1,1
1,2,2
2,3,3
3,4,4


In [None]:
e.db['S2']

Unnamed: 0,col0,col1,col2
0,1,2,3
1,2,3,4
2,3,4,5
3,4,5,6


In [None]:
draw(e.term_graph)

In [None]:
e.rules_to_ids

{'R(X,Y) <- S(X,Y),S2(X,A,3)': 0,
 'R(X,Y) <- S(X,Y),T(X,Y) -> (X,Y)': 1,
 'R2(X,Y) <- S3(X,Y),S2(X,A,1)': 2}

In [None]:
e.del_rule(pretty(r3))

In [None]:
draw(e.term_graph)

In [None]:
# r3 was deleted above, so the expected graph only carries nodes and edges
# tagged with rule ids 0 and 1
assert serialize_graph(e.term_graph) ==([('S', {'rel': 'S', 'rule_id': {0, 1}}),
  ('S2', {'rel': 'S2', 'rule_id': {0}}),
  ('R', {'rel': 'R', 'op': 'union', 'rule_id': {0, 1}}),
  (0, {'op': 'rename', 'names': [(0, 'X'), (1, 'Y')], 'rule_id': {0}}),
  (1, {'op': 'rename', 'names': [(0, 'X'), (1, 'A')], 'rule_id': {0}}),
  (2, {'op': 'select', 'theta': [(2, 3)], 'rule_id': {0}}),
  (3, {'op': 'join', 'rule_id': {0}}),
  (4, {'op': 'project', 'on': ['X', 'Y'], 'rel': '_R_0', 'rule_id': {0}}),
  (8, {'op': 'rename', 'names': [(0, 'X'), (1, 'Y')], 'rule_id': {1}}),
  (9, {'op': 'project', 'on': ['X', 'Y'], 'rule_id': {1}}),
  (10, {'op': 'calc', 'func': 'T', 'rule_id': {1}}),
  (11, {'op': 'rename', 'names': [(0, 'X'), (1, 'Y')], 'rule_id': {1}}),
  (6, {'op': 'join', 'rule_id': {1}}),
  (7, {'op': 'project', 'on': ['X', 'Y'], 'rel': '_R_1', 'rule_id': {1}})],
 [('R', 4, {}),
  ('R', 7, {}),
  (0, 'S', {}),
  (1, 'S2', {}),
  (2, 1, {}),
  (3, 0, {}),
  (3, 2, {}),
  (4, 3, {}),
  (8, 'S', {}),
  (9, 8, {}),
  (10, 9, {}),
  (11, 10, {}),
  (6, 8, {}),
  (6, 11, {}),
  (7, 6, {})])

In [None]:
# add the mutually recursive rules (A depends on B and B depends on A)
e.add_rule(rec_r1,RelationDefinition(name='A', scheme=[int,int]))
e.add_rule(rec_r2,RelationDefinition(name='A', scheme=[int,int]))
e.add_rule(rec_r3,RelationDefinition(name='B', scheme=[int,int]))
e.add_rule(rec_r4,RelationDefinition(name='B', scheme=[int,int]))


In [None]:
draw(e.term_graph)

## RA operations

For each relational algebra (RA) operation, and for the CalcIE operation, we define a function that implements it:

* Select
* Project
* Rename
* Union
* Intersection
* Difference
* Join
* SemiJoin
* Calc (for calcing ie relations)
* GetRel (for accessing relations from the DB)
  


In [None]:
# rebinds `s` and `s2` with rows that differ from the earlier fixtures:
# the last row of `s` is no longer a pair of equal values ...
s = pd.DataFrame([
    [1,1],
    [2,2],
    [3,3],
    [4,5]
])

# ... and two rows of `s2` share the same first two columns
s2 = pd.DataFrame([
    [1,2,3],
    [2,3,4],
    [2,3,5],
    [4,5,6]
])

In [None]:
df = pd.DataFrame([
    [1,2],
    [1,3],
    [1,2],
    ])
df

Unnamed: 0,0,1
0,1,2
1,1,3
2,1,2


In [None]:
#| export
def select(df,theta,**kwargs):
    """Keep the rows of `df` that satisfy `theta`.

    `theta` is either a callable, which is applied to the whole DataFrame and
    whose result is used to index `df`, or an iterable of `(position, value)`
    pairs, each keeping the rows whose column at `position` equals `value`.
    An empty `df` is returned as is. Extra keyword arguments are ignored.
    """
    if df.empty:
        return df

    if not callable(theta):
        filtered = df
        # apply the equality filters one after the other
        for position, expected in theta:
            column = filtered.iloc[:, position]
            filtered = filtered[column == expected]
        return filtered

    return df[theta(df)]

def project(df,on,**kwargs):
    """Return the columns of `df` named in `on`, in the order given.

    An empty `df` is returned untouched; its columns are not checked against
    `on`. Extra keyword arguments are ignored.
    """
    return df if df.empty else df[on]

def rename(df,names,**kwargs):
    """Rename columns of `df` by position.

    `names` is an iterable of `(position, new_name)` pairs; the column that
    currently sits at `position` is renamed to `new_name`. An empty `df` is
    returned untouched. Extra keyword arguments are ignored.
    """
    if df.empty:
        return df
    existing = list(df.columns)
    # translate positions into the labels currently found at those positions
    label_map = {existing[position]: new_name for position, new_name in names}
    return df.rename(columns=label_map)

def union(*dfs,**kwargs):
    """Stack all given DataFrames and drop repeated rows (first occurrence kept).

    Extra keyword arguments are ignored.
    """
    stacked = pd.concat(dfs)
    return stacked.drop_duplicates()

def intersection(df1,df2,**kwargs):
    """Return the rows found in both frames, via an inner merge on every column of `df1`.

    Extra keyword arguments are ignored.
    """
    every_column = list(df1.columns)
    return df1.merge(df2, how='inner', on=every_column)

def difference(df1,df2,**kwargs):
    """Return the rows that occur exactly once across `df1` and `df2` stacked together.

    NOTE(review): rows that appear only in `df2` are kept as well, and a row
    repeated inside `df1` is dropped, so this behaves like a symmetric
    difference rather than `df1 - df2` -- confirm this is intended.
    Extra keyword arguments are ignored.
    """
    stacked = pd.concat([df1, df2])
    return stacked.drop_duplicates(keep=False)

def join(df1,df2,**kwargs):
    """Natural join of `df1` and `df2` on their shared logical variables.

    A column counts as a logical variable when its name is a string that
    starts with an upper-case letter. The frames are inner-merged on every
    such column present in both; when there is none, the cross product is
    returned. Extra keyword arguments are ignored.
    """
    right_columns = set(df2.columns)
    # Walk df1's columns (instead of iterating a set) so the key order is
    # deterministic; dict.fromkeys drops repeated labels while keeping order.
    # `col[:1]` rather than `col[0]` so an empty-string column name is simply
    # not a variable instead of raising IndexError.
    on = list(dict.fromkeys(
        col for col in df1.columns
        if isinstance(col,str) and col[:1].isupper() and col in right_columns
    ))
    if not on:
        return pd.merge(df1,df2,how='cross')
    return pd.merge(df1,df2,how='inner',on=on)

def calc(df,func,out_schema,semantics='per_row',**kwargs):
    """Apply the IE function `func` to every row of `df` and collect its outputs.

    With the default `per_row` semantics each row's values are passed to
    `func` as positional arguments. The schema inferred from each returned
    value must equal `out_schema`. The outputs form a new DataFrame whose
    column names come from `Engine._col_names(len(out_schema))`.

    An empty `df` is returned untouched. Extra keyword arguments are ignored.

    Raises:
        ValueError: `func` returned a value whose inferred schema differs from `out_schema`.
        NotImplementedError: `semantics` is anything other than `'per_row'`.
    """
    if df.empty:
        return df
    if semantics != 'per_row':
        # bulk semantics
        raise NotImplementedError('non per row semantics are not supported for calculating ie functions')

    def _checked_outputs():
        # lazily: take a row, call func on it, validate the output, hand it on
        for _, row in df.iterrows():
            output = func(*list(row))
            actual_schema = _infer_relation_schema(output)
            if actual_schema != out_schema:
                pretty_actual_schema = [pretty(term) for term in actual_schema]
                pretty_expected_schema = [pretty(term) for term in out_schema]
                raise ValueError(f"Function {func} returned a value of an unexpected schema "
                        f"{output}({pretty_actual_schema}) expected {pretty_expected_schema}")
            yield output

    return pd.DataFrame(_checked_outputs(),columns=Engine._col_names(len(out_schema)))

def get_rel(rel,db,**kwargs):
    """Look up and return the entry stored under `rel` in `db`.

    Extra keyword arguments are ignored.
    """
    stored = db[rel]
    return stored


In [None]:
#| export

# Dispatch table: the `op` attribute of a term graph node -> the function that computes it.
op_to_func = dict(
    union=union,
    intersection=intersection,
    difference=difference,
    select=select,
    project=project,
    rename=rename,
    join=join,
    calc=calc,
    get_rel=get_rel,
)

### Tests

In [None]:
s3 = pd.DataFrame([
    [4,5,6],
    [5,6,7],
    [1,2,3],
    [7,8,9]
])
s3

Unnamed: 0,0,1,2
0,4,5,6
1,5,6,7
2,1,2,3
3,7,8,9


In [None]:
s

Unnamed: 0,0,1
0,1,1
1,2,2
2,3,3
3,4,5


In [None]:
s2_copy = s2.copy()
s2

Unnamed: 0,0,1,2
0,1,2,3
1,2,3,4
2,2,3,5
3,4,5,6


In [None]:
empty = pd.DataFrame()

truth = pd.DataFrame([()])

In [None]:
empty.empty

True

In [None]:
# select, rename and project hand an empty DataFrame straight back, even when
# the requested positions/columns do not exist in it
assert serialize_df_values(select(empty,[(0,1)]))==set()
assert serialize_df_values(rename(empty,names=[(0,'X')]))==set()
assert serialize_df_values(project(empty,on=['X','Y']))==set()

In [None]:
# single (position, value) filter
res = select(s,[(0,1)])
assert serialize_df_values(res)=={(1,1)}

# several filters are applied one after the other
res = select(s2,[(0,2),(1,3)])
assert serialize_df_values(res) == {(2,3,4),(2,3,5)}

# callable theta: it is called with the whole DataFrame, so `row[0]` and
# `row[1]` are columns here and the comparison yields a boolean mask
res = select(s,lambda row: row[0] == row[1])
assert serialize_df_values(res) == {(1, 1), (2, 2), (3, 3)}

In [None]:
res = project(s2,on=[2,1])
assert serialize_df_values(res) == {(3, 2), (4, 3), (5, 3), (6, 5)}
res

Unnamed: 0,2,1
0,3,2
1,4,3
2,5,3
3,6,5


In [None]:
assert list(rename(s2,[(0,'X')]).columns) == ['X',1,2]
assert list(rename(s2,[(0,'X'),(2,'Z')]).columns) == ['X',1,'Z']

In [None]:
# rows shared by s2 and s3 appear only once in the union
res = union(s2,s3)
assert serialize_df_values(res) == {(1, 2, 3), (2, 3, 4), (2, 3, 5), (4, 5, 6), (5, 6, 7), (7, 8, 9)}

res = intersection(s2,s3) 
assert serialize_df_values(res) == {(1, 2, 3), (4, 5, 6),}

# note: the expected rows include (5, 6, 7) and (7, 8, 9), which occur only in s3
res = difference(s2,s3)
assert serialize_df_values(res) == { (2, 3, 4), (2, 3, 5),(5, 6, 7), (7, 8, 9)}


In [None]:
left = rename(s,[(1,'Y')])
left

Unnamed: 0,0,Y
0,1,1
1,2,2
2,3,3
3,4,5


In [None]:
right = rename(s2,[(0,'Y'),(1,'X')])
right

Unnamed: 0,Y,X,2
0,1,2,3
1,2,3,4
2,2,3,5
3,4,5,6


In [None]:
# `left` and `right` share only the variable column 'Y', so this is an inner join on Y
res = join(left,right)
# the comparison was previously a bare expression and therefore checked nothing
assert serialize_df_values(res) == {(1, 1, 2, 3), (2, 2, 3, 4), (2, 2, 3, 5)}
res

Unnamed: 0,0,Y,X,2
0,1,1,2,3
1,2,2,3,4
2,2,2,3,5


In [None]:
# 'a' and 'b' are lower case and the other frame keeps its integer labels, so
# there is no shared variable column and join falls back to the cross product (4 x 4 rows)
res = join(rename(s,[(0,'a'),(1,'b')]),s)
assert len(res)==16
assert list(res.columns) == ['a', 'b', 0, 1]
res.head()

Unnamed: 0,a,b,0,1
0,1,1,1,1
1,1,1,2,2
2,1,1,3,3
3,1,1,4,5
4,2,2,1,1


In [None]:
# two-argument IE function returning a pair of ints
def func(x,y): return (x+y,x-y)
# three-argument IE function that always returns two values (used by the schema check test)
def func2(x,y,z): return (x,y)

# one output row per input row of `s`
res = calc(s,func,[int,int])
assert serialize_df_values(res) == {(2, 0), (4, 0), (6, 0), (9, -1)}
res

Unnamed: 0,col0,col1
0,2,0
1,4,0
2,6,0
3,9,-1


In [None]:
# test checking of schema
with pytest.raises(ValueError) as exc_info:
    calc(s2,func2,[int,int,int])
assert 'returned a value of an unexpected schema' in str(exc_info.value)
print(exc_info.value)

with pytest.raises(ValueError) as exc_info:
    calc(s2,func2,[int,str])
assert 'returned a value of an unexpected schema' in str(exc_info.value)
print(exc_info.value)

Function <function func2> returned a value of an unexpected schema (1, 2)(['int', 'int']) expected ['int', 'int', 'int']
Function <function func2> returned a value of an unexpected schema (1, 2)(['int', 'int']) expected ['int', 'str']


In [None]:
res =  get_rel('S',e.db)
res

Unnamed: 0,col0,col1
0,1,1
1,2,2
2,3,3
3,4,4


## SemiNaive Operations

For every RA operation and for the CalcIE operation, we have a function that takes the children as nodes, which can get the underlying data from a specific iteration, and computes the differential version of the operation.

In [None]:
#TODO this will be implemented in the future

## Naive execution

A recursive least fixed point algorithm mimicking bottom-up evaluation.

In [None]:
#| export
def compute_node(G,u,results_dict):
    """Recursively compute the relation of node `u` in the term graph `G`.

    The successors of `u` are computed first and their results are passed, in
    successor order, as positional arguments to the function registered in
    `op_to_func` under the node's `op` attribute; all node attributes are
    passed along as keyword arguments.

    Bookkeeping written onto the node attributes of `G`:
      * `visited` -- set when a node that has children starts being computed.
        Reaching such a node again while it is not final (which is what happens
        when a cycle is closed) yields an empty DataFrame instead of recursing.
      * `final`   -- set when the node's result can no longer change; a final
        node returns its latest stored result without recomputation.

    Args:
        G: directed graph whose nodes carry an `op` attribute plus that
           operation's keyword arguments.
        u: id of the node to compute.
        results_dict: mapping from node id to the list of results computed so
           far for it; it is appended to on every computation (callers in this
           file pass a `defaultdict(list)`).

    Returns:
        The result of applying the node's operation to its children's results.
    """
    u_data = G.nodes[u]
    # a final node is never recomputed
    if u_data.get('final',False):
        return results_dict[u][-1]

    children_ids = list(G.successors(u))

    if not children_ids:
        children_results = []
    else:
        # this block helps avoid infinite recursion when we need to initialize
        # the 0th iteration of a node in a cycle as the empty relation
        if u_data.get('visited',False):
            return pd.DataFrame()
        u_data['visited'] = True
        children_results = [compute_node(G,child_id,results_dict) for child_id in children_ids]

    # Sampled AFTER the children were computed: a node over freshly computed
    # final children is itself final. Sampling before the recursion left such a
    # node non-final, so reaching it again through a second parent returned an
    # empty relation. A leaf has no children and is therefore always final.
    all_children_final = all(G.nodes[child_id].get('final',False) for child_id in children_ids)

    op_code = u_data['op']
    op_func = op_to_func[op_code]
    # lazy %-formatting: the (possibly large) frames are only rendered when debug logging is on
    logger.debug('computing node %s with op %s and children results %s and data %s',
                 u, op_code, children_results, u_data)
    current_results = op_func(*children_results,**u_data)
    history = results_dict[u]
    history.append(current_results)

    if all_children_final:
        # nothing below this node can change any more
        u_data['final'] = True
    else:
        # fixed point check: final only once two consecutive results are equal;
        # after the first computation there is nothing to compare with yet
        u_data['final'] = len(history) > 1 and pd.DataFrame.equals(history[-1],history[-2])

    return current_results





#### Case1

In [None]:
# fresh engine holding only the two rules for R
e = Engine()
e.set_relation(RelationDefinition(name='S', scheme=[int,int]))
e.set_relation(RelationDefinition(name='S2', scheme=[int,int,int]))

e.add_rule(r1,RelationDefinition(name='R', scheme=[int,int]))
e.add_rule(r2,RelationDefinition(name='R', scheme=[int,int]))

# `s` and `s2` are the frames most recently bound in the notebook (displayed below)
e.add_facts('S',s)
e.add_facts('S2',s2)

# IE function behind T: swaps its two arguments (rebinds the name `func` used earlier)
def func(x,y):
    return y,x

ie_def = IEFunction(name='T',func=func,in_schema=[int,int],out_schema=[int,int])

e.set_ie_function(ie_def)
# graph that is evaluated directly with compute_node in the next cell
g = e._inline_db_and_ies_in_graph(e.term_graph)
print(e.rules_to_ids)
display(s)
display(s2)

{'R(X,Y) <- S(X,Y),S2(X,A,3)': 0, 'R(X,Y) <- S(X,Y),T(X,Y) -> (X,Y)': 1}


Unnamed: 0,0,1
0,1,1
1,2,2
2,3,3
3,4,5


Unnamed: 0,0,1,2
0,1,2,3
1,2,3,4
2,2,3,5
3,4,5,6


In [None]:
# with checkLogs():
# compute_node writes `visited`/`final` flags onto the graph it is given;
# presumably the copy is there to keep those flags off `g` itself
results_dict = defaultdict(list)
query_res = compute_node(g.copy(),'R',results_dict)
assert serialize_df_values(query_res) == {(1, 1), (2, 2), (3, 3)}

In [None]:
draw(g)

In [None]:
res = e.run_query(Relation(name='R',terms=[FreeVar(name='X'),FreeVar(name='Y')]))
assert serialize_df_values(res) == {(1, 1), (2, 2), (3, 3)}

In [None]:
q,root = e.plan_query(Relation(name='R',terms=[FreeVar(name='X'),3]))
draw(q)

In [None]:
# with checkLogs():
# query "?R(S,3)": the constant 3 filters on the second column of R and the
# only free variable, S, names the single column of the answer
res = e.run_query(Relation(name='R',terms=[FreeVar(name='S'),3]))
assert serialize_df_values(res)=={(3,)}
assert list(res.columns) == ['S']

#### case 2

In [None]:
# second engine: only C and D are declared as stored relations, A and B are
# defined by the mutually recursive rules
e2 = Engine()
e2.set_relation(RelationDefinition(name='C', scheme=[int,int]))
e2.set_relation(RelationDefinition(name='D', scheme=[int,int]))

# each rule is registered together with the definition of its head relation
for rule in [rec_r1,rec_r2,rec_r3,rec_r4]:
    e2.add_rule(rule,RelationDefinition(name=rule.head.name, scheme=[int,int]))

# one fact per stored relation
e2.add_fact(Relation(name='C',terms=[1,2]))
e2.add_fact(Relation(name='D',terms=[3,4]))

g2 = e2._inline_db_and_ies_in_graph(e2.term_graph)
e2.rules_to_ids

{'A(X,Y) <- B(X,Y)': 0,
 'A(X,Y) <- C(X,Y)': 1,
 'B(X,Y) <- D(X,Y)': 2,
 'B(X,Y) <- A(X,Y)': 3}

In [None]:
# with checkLogs():
query_res = e2.run_query(Relation(name='A',terms=[FreeVar(name='X'),FreeVar(name='Y')]))
assert serialize_df_values(query_res) == {(3, 4), (1, 2)}
query_res = e2.run_query(Relation(name='B',terms=[FreeVar(name='X'),FreeVar(name='Y')]))
assert serialize_df_values(query_res) == {(3, 4), (1, 2)}

In [None]:
draw(g2)

In [None]:
#|hide
import nbdev; nbdev.nbdev_export()
     