# Engine
> Executing spannerlog commands

In [None]:
#| default_exp engine

In [None]:
#| hide
from nbdev.showdoc import show_doc

%load_ext autoreload
%autoreload 2

In [None]:
#| export
from abc import ABC, abstractmethod
import pytest

import pandas as pd
from pathlib import Path
from typing import no_type_check, Set, Sequence, Any,Optional,List,Callable,Dict,Union
from pydantic import BaseModel
import networkx as nx
import itertools
from graph_rewrite import draw, draw_match, rewrite, rewrite_iter
from spannerlib.utils import serialize_graph

## Utils (TODO: move to a utils module)

In [None]:
#| export
def _biggest_int_node_name(g:nx.Graph):
    """Return the largest integer node name in ``g``, or 0 when ``g`` has no integer node names."""
    int_names = [node for node in g.nodes if isinstance(node, int)]
    return max(int_names) if int_names else 0

def is_node_in_graphs(name,gs):
    """Return True if ``name`` is a node of at least one graph in ``gs``."""
    for graph in gs:
        if name in graph.nodes:
            return True
    return False

def get_new_node_name(g,prefix=None,avoid_names_from=None):
    """Return a node name that is not used by ``g`` nor by any graph in ``avoid_names_from``.

    Args:
        g: the graph the new node is meant for.
        prefix: if None, an integer name is generated; otherwise a string name
            based on ``prefix``.
        avoid_names_from: optional list of additional graphs whose node names
            must also be avoided.

    Returns:
        * without ``prefix``: the first int, counting up from one past the
          largest int node name of ``g``, that no checked graph uses;
        * with ``prefix``: ``prefix`` itself if free, otherwise the first free
          ``f"{prefix}_{i}"`` for ``i = 0, 1, 2, ...``.
    """
    graphs = [g] + (avoid_names_from if avoid_names_from is not None else [])

    def taken(name):
        return is_node_in_graphs(name, graphs)

    if prefix is None:
        candidate = _biggest_int_node_name(g) + 1
        while taken(candidate):
            candidate += 1
        return candidate

    if not taken(prefix):
        return prefix
    suffixed = (f"{prefix}_{i}" for i in itertools.count())
    return next(name for name in suffixed if not taken(name))

In [None]:
# two small graphs for testing the node-naming helpers:
# g has int nodes 1, 2 and 'hello'; g2 has int nodes 1, 2, 3 and 'hello_1'
g = nx.Graph()
g.add_node(1)
g.add_node(2)
g.add_node('hello')

g2 = nx.Graph()
g2.add_node(1)
g2.add_node(2)
g2.add_node(3)
g2.add_node('hello_1')

assert _biggest_int_node_name(g) == 2

In [None]:
# 3 only exists in g2, 4 exists in neither graph
assert is_node_in_graphs(3,[g,g2])
assert not is_node_in_graphs(4,[g,g2])

In [None]:
# int names continue after g's largest int node name
assert get_new_node_name(g) == 3
# string names get the first free _<i> suffix
assert get_new_node_name(g,'hello') == 'hello_0'
g.add_node('hello_0')
assert get_new_node_name(g,'hello') == 'hello_1'
# 'hello_1' is taken in g2, so avoiding g2 skips to 'hello_2'
assert get_new_node_name(g,'hello',avoid_names_from=[g2]) == 'hello_2'
# an unused prefix is returned as-is
assert get_new_node_name(g,'world') == 'world'

## Basic Datatypes

In [None]:
#| export
from enum import Enum
from typing import Any
from pydantic import ConfigDict

class Span(BaseModel):
    """A half-open interval ``[start, end)``.

    Spans are ordered lexicographically: by ``start`` first, then by ``end``.
    """
    start: int
    end: int

    def __lt__(self, other) -> bool:
        # tuple comparison yields exactly the (start, end) lexicographic order
        return (self.start, self.end) < (other.start, other.end)

    # NOTE: sorting only needs __lt__. A value-based __hash__ over
    # (start, end), meant for sorting Spans inside dataframes, is not enabled.

class Var(BaseModel):
    """A named variable used as a term.

    NOTE(review): presumably resolved through the engine's symbol table
    (Engine.set_var / get_var) — confirm.
    """
    name: str
    # hashed by name so Vars can be stored in sets / used as dict keys
    def __hash__(self):
        return hash(self.name)

class FreeVar(BaseModel):
    """A free variable in a rule; it becomes bound while evaluating the rule body."""
    name: str
    # hashed by name so FreeVars can be collected in sets (e.g. sets of bound variables)
    def __hash__(self):
        return hash(self.name)

# constant values that may appear as relation terms
PrimitiveType=Union[str,int,Span]
# any relation term: a constant, a Var or a FreeVar
Type = Union[PrimitiveType,Var,FreeVar]

class RelationDefinition(BaseModel):
    """Declared schema of a relation: its name and the python type of each column."""
    model_config = ConfigDict(arbitrary_types_allowed=True)
    name: str
    scheme: List[type]

class Relation(BaseModel):
    """A relation atom ``name(terms...)``; each term is a constant, a Var or a FreeVar.

    Used for rule heads, rule body atoms and facts.
    """
    model_config = ConfigDict(arbitrary_types_allowed=True)
    name: str
    terms: List[Type]

class IEFunction(BaseModel):
    """An IE function: its name, input and output type schemas, and the callable implementing it."""
    model_config = ConfigDict(arbitrary_types_allowed=True)
    name: str
    in_schema: List[type]
    out_schema: List[type]
    func: Callable


class IERelation(BaseModel):
    """An IE atom ``name(in_terms) -> (out_terms)`` in a rule body."""
    model_config = ConfigDict(arbitrary_types_allowed=True)
    name: str
    in_terms: List[Type]
    out_terms: List[Type]
    def __hash__(self):
        # hash a string built from the name and the str() of every term
        ins = '_'.join(str(term) for term in self.in_terms)
        outs = '_'.join(str(term) for term in self.out_terms)
        return hash(f'{self.name}_in_{ins}_out_{outs}')
class Rule(BaseModel):
    """A rule ``head <- body``; the body mixes regular relations and IE relations."""
    model_config = ConfigDict(arbitrary_types_allowed=True)
    head: Relation
    body: List[Union[Relation,IERelation]]

In [None]:
#| export
def pretty(obj):
    """Render a spannerlog object as spannerlog-like source text for user messages.

    Spans print as ``[start,end)``, variables as their name, relations and IE
    relations/functions as ``name(...)`` (IE ones with ``-> (...)`` outputs),
    rules as ``head <- body``, python types as their ``__name__``; anything
    else falls back to ``str``.
    """
    def _join(items):
        return ','.join(pretty(item) for item in items)

    if isinstance(obj, Span):
        return f"[{obj.start},{obj.end})"
    if isinstance(obj, (Var, FreeVar)):
        return obj.name
    if isinstance(obj, RelationDefinition):
        return f"{obj.name}({_join(obj.scheme)})"
    if isinstance(obj, Relation):
        return f"{obj.name}({_join(obj.terms)})"
    if isinstance(obj, IERelation):
        return f"{obj.name}({_join(obj.in_terms)}) -> ({_join(obj.out_terms)})"
    if isinstance(obj, IEFunction):
        return f"{obj.name}({_join(obj.in_schema)}) -> ({_join(obj.out_schema)})"
    if isinstance(obj, Rule):
        return f"{pretty(obj.head)} <- {_join(obj.body)}"
    if isinstance(obj, type):
        return obj.__name__
    return str(obj)

In [None]:
# pretty() renders a rule mixing a Span constant and an IE relation as spannerlog text
rule = Rule(
    head=Relation(name='R', terms=[FreeVar(name='X'), FreeVar(name='Y'), FreeVar(name='Z')]),
    body=[
        Relation(name='S', terms=[FreeVar(name='X'), Span(start=1,end=4)]),
        IERelation(name='T', in_terms=[FreeVar(name='X'), 1], out_terms=[FreeVar(name='Y'), FreeVar(name='Z')])
    ])
assert pretty(rule) == 'R(X,Y,Z) <- S(X,[1,4)),T(X,1) -> (Y,Z)'

In [None]:
# pretty() on schemas prints type names
schema = RelationDefinition(name='R', scheme=[int, str, Span])
assert pretty(schema) == 'R(int,str,Span)'
# pydantic models only accept keyword arguments, so Span must be built with start=/end=
ie_func_schema = IEFunction(name='f', in_schema=[int, str], out_schema=[str, Span],func=lambda x,y: (y,Span(start=1,end=2)))
assert pretty(ie_func_schema) == 'f(int,str) -> (str,Span)'

## Term graph

Here we have functions for converting rules into term graphs.

In [None]:
def _get_bounding_order(rule:Rule)->List[Union[Relation,IERelation]]:
    """Get an order of evaluation for the body of a rule.

    All regular relations come first, in body order, since they can be bound
    at once; their FreeVar terms become bound. IE relations are then appended
    one at a time: at each step the first remaining IE relation (in body
    order) whose FreeVar inputs are all bound is chosen, and its FreeVar
    outputs become bound. Identical IE relations are kept only once.
    This is a very naive ordering that can be heavily optimized.

    Args:
        rule: the rule whose body should be ordered.

    Returns:
        The body relations in evaluation order.

    Raises:
        ValueError: if some IE relations can never have all their inputs bound.
    """
    order = [rel for rel in rule.body if isinstance(rel, Relation)]
    bounded_vars = {term for rel in order for term in rel.terms if isinstance(term, FreeVar)}

    # A list (not a set) keeps the result deterministic: IERelation.__hash__
    # hashes a str, and str hashes are randomized per process. Duplicates are
    # still dropped, as a set would have done.
    pending = []
    for rel in rule.body:
        if isinstance(rel, IERelation) and rel not in pending:
            pending.append(rel)

    while pending:
        for idx, ie_rel in enumerate(pending):
            in_free_vars = {term for term in ie_rel.in_terms if isinstance(term, FreeVar)}
            if in_free_vars.issubset(bounded_vars):
                break
        else:
            # no remaining IE relation can ever become ready; fail instead of looping forever
            raise ValueError(
                f"Cannot order the body of rule {pretty(rule)}: the inputs of "
                f"{', '.join(pretty(r) for r in pending)} can never all be bound")
        order.append(pending.pop(idx))
        bounded_vars.update(term for term in ie_rel.out_terms if isinstance(term, FreeVar))

    return order

In [None]:
# T2 needs Y, which only T produces, so T must come before T2;
# plain relations (S, S2) are always ordered first
r = Rule(
    head=Relation(name='R', terms=[FreeVar(name='X'), FreeVar(name='Y'), FreeVar(name='Z')]),
    body=[
        IERelation(name='T2', in_terms=[FreeVar(name='X'), FreeVar(name='Y')], out_terms=[FreeVar(name='W'), FreeVar(name='Z')]),
        IERelation(name='T', in_terms=[FreeVar(name='X'), 1], out_terms=[FreeVar(name='Y'), FreeVar(name='Z')]),
        Relation(name='S', terms=[FreeVar(name='X'), Span(start=1,end=4)]),
        Relation(name='S2', terms=[FreeVar(name='X'), FreeVar(name='A'),FreeVar(name='B')]),

    ])

order = _get_bounding_order(r)
assert [o.name for o in order ] == ['S','S2', 'T', 'T2']
order

[Relation(name='S', terms=[FreeVar(name='X'), Span(start=1, end=4)]),
 Relation(name='S2', terms=[FreeVar(name='X'), FreeVar(name='A'), FreeVar(name='B')]),
 IERelation(name='T', in_terms=[FreeVar(name='X'), 1], out_terms=[FreeVar(name='Y'), FreeVar(name='Z')]),
 IERelation(name='T2', in_terms=[FreeVar(name='X'), FreeVar(name='Y')], out_terms=[FreeVar(name='W'), FreeVar(name='Z')])]

In [None]:
def _name_node(counter,op,rule_id,rel_idx,rel_name):
    """Return a fresh term-graph node id by advancing ``counter``.

    ``op``, ``rule_id``, ``rel_idx`` and ``rel_name`` describe the node but are
    currently unused; a descriptive id such as
    ``f'{op}_{rel_name}_rule_{rule_id}_rel_{rel_idx}'`` could be built from them.
    """
    node_id = next(counter)
    return node_id

def _non_free_positions(terms):
    """Return ``(position, term)`` pairs for every term that is not a FreeVar."""
    return [(i, term) for i, term in enumerate(terms) if not isinstance(term, FreeVar)]

def _add_parent_if_needed(g,node_counter,source_node,terms,rel,rel_idx,rule_id,op):
    """Shared implementation of ``_select_if_needed`` / ``_product_if_needed``.

    If some of ``terms`` are not FreeVars, add a node with operation ``op`` and
    ``theta`` = the non-FreeVar ``(position, term)`` pairs as a parent of
    ``source_node``, and return it. Otherwise return ``source_node`` unchanged.
    """
    pos_val = _non_free_positions(terms)
    if not pos_val:
        return source_node
    parent_name = _name_node(node_counter,op,rule_id,rel_idx,rel.name)
    g.add_node(parent_name, op=op, theta=pos_val)
    g.add_edge(parent_name,source_node)
    return parent_name

def _select_if_needed(g,node_counter,source_node,terms,rel,rel_idx,rule_id):
    """Add a ``select`` node as a parent of ``source_node`` if the terms are not all free variables.

    The select node's ``theta`` lists the ``(position, term)`` pairs of the
    non-FreeVar terms, and its name is taken from ``node_counter``.

    Returns:
        The select node if one was added, otherwise ``source_node``.
    """
    return _add_parent_if_needed(g,node_counter,source_node,terms,rel,rel_idx,rule_id,'select')

def _product_if_needed(g,node_counter,source_node,terms,rel,rel_idx,rule_id):
    """Add a ``product`` node as a parent of ``source_node`` if the terms are not all free variables.

    The product node's ``theta`` lists the ``(position, term)`` pairs of the
    non-FreeVar terms, and its name is taken from ``node_counter``.

    Returns:
        The product node if one was added, otherwise ``source_node``.
    """
    return _add_parent_if_needed(g,node_counter,source_node,terms,rel,rel_idx,rule_id,'product')

# TODO from here, iteratively build the joins each time using _project_if_needed on the outrel and inrel of the relations/ierelaitons
def _rule_to_term_graph(rule:Rule,rule_id) -> nx.DiGraph:
    """Convert a rule to a directed RA+IE term graph.

    Edges point from an operation node to its input (child) nodes. Body
    relations are processed in the order returned by ``_get_bounding_order``:

    * a ``Relation`` becomes a ``get_rel`` node (named after the relation, so
      repeated uses share it) under a ``rename`` node for its FreeVar
      positions, with a ``select`` node on top if any term is not a FreeVar;
    * an ``IERelation`` becomes a ``project`` node of its FreeVar inputs,
      under an optional ``product`` node (non-FreeVar inputs), under a
      ``calc`` node, with an optional ``select`` node on top (non-FreeVar outputs);
    * the derivations are chained left to right with ``join`` nodes, and each
      IE relation's input ``project`` node gets the join built so far as its child;
    * finally a ``project`` onto the head terms (``rel='_<head>_<rule_id>'``)
      and a ``union`` node named after the head relation are added on top.

    Args:
        rule: the rule to convert.
        rule_id: id of the rule, used in the head projection's ``rel`` name.

    Returns:
        The term graph. Nodes other than ``get_rel`` and the head ``union``
        get consecutive integer ids from a counter local to this call.
    """
    node_counter = itertools.count()
    G = nx.DiGraph()
    # add nodes for all relations
    body_term_connectors = list()
    body_rels = _get_bounding_order(rule)

    # create derivation for each rel in the body
    # body_term_connectors gets one (bottom input node or None, top node) pair per body rel
    for rel_idx,rel in enumerate(body_rels):
        if isinstance(rel,Relation):
            G.add_node(rel.name, op='get_rel',rel=rel.name)
            rename_node = _name_node(node_counter,'rename',rule_id,rel_idx,rel.name)
            G.add_node(rename_node, op='rename',names=[(i,term.name) for i,term in enumerate(rel.terms) if isinstance(term,FreeVar)])
            G.add_edge(rename_node,rel.name)
            top_rel_node = _select_if_needed(G,node_counter,rename_node,rel.terms,rel,rel_idx,rule_id)
            
            body_term_connectors.append((None,top_rel_node))

        elif isinstance(rel,IERelation):
            get_input_node_name =_name_node(node_counter,'get_input',rule_id,rel_idx,rel.name)
            calc_node_name = _name_node(node_counter,'calc_ie',rule_id,rel_idx,rel.name)
            G.add_node(get_input_node_name, op='project', on=[term.name for term in rel.in_terms if isinstance(term,FreeVar)])
            G.add_node(calc_node_name, op='calc',func=rel.name)

            # NOTE: product_name is never used, but the call consumes an id from
            # node_counter; the expected node ids in this notebook's tests
            # (e.g. the gap at 5) depend on it
            product_name = _name_node(node_counter,'product_input',rule_id,rel_idx,rel.name)
            calc_son = _product_if_needed(G,node_counter,get_input_node_name,rel.in_terms,rel,rel_idx,rule_id)
            G.add_edge(calc_node_name,calc_son)
            # NOTE: select_name is likewise unused but consumes a node id
            select_name = _name_node(node_counter,'select_output',rule_id,rel_idx,rel.name)
            top_rel_node = _select_if_needed(G,node_counter,calc_node_name,rel.out_terms,rel,rel_idx,rule_id)
            body_term_connectors.append((get_input_node_name,top_rel_node))

    # connect outputs of different rels via joins
    # and connect input of ie functons into the join
    # prev_top is the root of the derivation built so far.
    # NOTE(review): with an empty rule body prev_top is never assigned, and if the
    # first rel in the order is an IERelation its input project node is never
    # connected (the i == 0 branch skips that) — confirm these cases cannot occur
    for i,(connectors,rel) in enumerate(zip(body_term_connectors,body_rels)):
        if i == 0:
            prev_top = connectors[1]
            continue

        current_top = connectors[1]

        join_node_name = _name_node(node_counter,'join',rule_id,i,rel.name)
        G.add_node(join_node_name, op='join')
        G.add_edge(join_node_name,prev_top)
        G.add_edge(join_node_name,current_top)

        if isinstance(rel,IERelation):
            # the IE input projection reads from everything joined before it
            ie_bottom = connectors[0]
            G.add_edge(ie_bottom,prev_top)


        prev_top = join_node_name

    # project all assignments into the head
    head_project_name = _name_node(node_counter,'project',rule_id,'head',rule.head.name)
    G.add_node(head_project_name, op='project', on=[term.name for term in rule.head.terms],rel=f'_{rule.head.name}_{rule_id}')
    G.add_edge(head_project_name,prev_top)

    # add a union for each rule for the given head
    G.add_node(rule.head.name,op='union',rel=rule.head.name)
    G.add_edge(rule.head.name,head_project_name)

    return G

In [None]:
#TODO FROM HERE add labels to nodes we can labels
# add HEAD projection node 
# maybe add free vars
g = _rule_to_term_graph(r,0)
draw(g)
serialize_graph(g)
# expected term graph of r; ids 5, 7, 10 and 11 are drawn from the node counter
# by unused _name_node calls in _rule_to_term_graph, hence the gaps
assert serialize_graph(g) == ([('S', {'op': 'get_rel', 'rel': 'S'}),
  (0, {'op': 'rename', 'names': [(0, 'X')]}),
  (1, {'op': 'select', 'theta': [(1, Span(start=1, end=4))]}),
  ('S2', {'op': 'get_rel', 'rel': 'S2'}),
  (2, {'op': 'rename', 'names': [(0, 'X'), (1, 'A'), (2, 'B')]}),
  (3, {'op': 'project', 'on': ['X']}),
  (4, {'op': 'calc', 'func': 'T'}),
  (6, {'op': 'product', 'theta': [(1, 1)]}),
  (8, {'op': 'project', 'on': ['X', 'Y']}),
  (9, {'op': 'calc', 'func': 'T2'}),
  (12, {'op': 'join'}),
  (13, {'op': 'join'}),
  (14, {'op': 'join'}),
  (15, {'op': 'project', 'on': ['X', 'Y', 'Z'], 'rel': '_R_0'}),
  ('R', {'op': 'union', 'rel': 'R'})],
 [(0, 'S', {}),
  (1, 0, {}),
  (2, 'S2', {}),
  (3, 12, {}),
  (4, 6, {}),
  (6, 3, {}),
  (8, 13, {}),
  (9, 8, {}),
  (12, 1, {}),
  (12, 2, {}),
  (13, 12, {}),
  (13, 4, {}),
  (14, 13, {}),
  (14, 9, {}),
  (15, 14, {}),
  ('R', 15, {})])

In [None]:
# three small rules for the merge tests: r1 and r2 share the head R and the
# body atom S(X,Y); r3 shares the body atom S2(X,A,1) with r1
r1 = Rule(
    head=Relation(name='R', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='S', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
        Relation(name='S2', terms=[FreeVar(name='X'), FreeVar(name='A'),1]),
    ])

r2 = Rule(
    head=Relation(name='R', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='S', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
    ])

r3 = Rule(
    head=Relation(name='R2', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='S3', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
        Relation(name='S2', terms=[FreeVar(name='X'), FreeVar(name='A'),1]),
    ])
rules = [r1,r2,r3]
for r in rules:
    print(pretty(r))
t1,t2,t3 = [_rule_to_term_graph(r,i) for i,r in enumerate(rules)]

R(X,Y) <- S(X,Y),S2(X,A,1)
R(X,Y) <- S(X,Y)
R2(X,Y) <- S3(X,Y),S2(X,A,1)


In [None]:
#| export
def graph_compose(g1,g2,mapping_dict,debug=False):
    """Compose two graphs after relabeling ``g2``'s nodes with ``mapping_dict``.

    Nodes of ``g2`` listed in ``mapping_dict`` are renamed to their mapped
    value; mapping a node onto a node of ``g1`` (as ``merge_term_graphs_pair``
    does) makes ``nx.compose`` merge the two. Every other ``g2`` node whose
    name already exists in ``g1`` gets a fresh integer name, so it is not
    merged by accident.

    Args:
        g1: the first graph.
        g2: the graph to relabel and compose into ``g1``.
        mapping_dict: ``g2`` node -> new name. It is not modified; the
            completed mapping is built on a copy.
        debug: if True, return the completed mapping instead of composing.

    Returns:
        The composed graph (``nx.compose(g1, relabeled g2)``), or the completed
        mapping when ``debug`` is True.
    """
    # copy so the caller's dict is not silently extended
    mapping_dict = dict(mapping_dict)
    # get_new_node_name only checks node membership in graphs, so the names
    # issued here are tracked in a scratch graph to avoid reusing one.
    # This is a dirty hack, TODO resolve this
    issued_names = nx.DiGraph()
    for u2 in g2.nodes():
        if u2 not in mapping_dict and u2 in g1.nodes():
            new_name = get_new_node_name(g2,avoid_names_from=[g1,issued_names])
            mapping_dict[u2] = new_name
            issued_names.add_node(new_name)
    if debug:
        return mapping_dict
    relabeled = nx.relabel_nodes(g2,mapping_dict,copy=True)
    return nx.compose(g1,relabeled)


In [None]:
# visual check of the three rule term graphs
draw(t1)
draw(t2)
draw(t3)

In [None]:
# debug=True returns the completed mapping: t3 nodes that clash with t1 names
# and are not in the mapping ('S2', 2, 3, 4) get fresh integer names
assert graph_compose(t1,t3,{
    'S':'S',0:0,1:1,
},debug=True) == {'S': 'S', 0: 0, 1: 1, 'S2': 5, 2: 6, 3: 7, 4: 8}

In [None]:
# t2's node 1 clashes with t1's node 1 and is not mapped, so it is renamed to 5
assert graph_compose(t1,t2,
    mapping_dict = {'S':'S','R':'R',0:0}
    ,debug=True) == {'S': 'S', 'R': 'R', 0: 0, 1: 5}

In [None]:
# compose t1 and t2, sharing S, R and rename node 0
m= graph_compose(t1,t2,
    mapping_dict = {'S':'S','R':'R',0:0}) 
draw(m)

In [None]:
#| export
def merge_term_graphs_pair(g1,g2,exclude_props = ('label',),debug=False):
    """Merge two term graphs into one term graph.

    Two nodes are considered equal if their data (ignoring the keys in
    ``exclude_props``) is identical and all children of the ``g2`` node have
    already been matched. Nodes that both carry a ``rel`` attribute are
    matched on ``rel`` alone, so that rules for the same head (and uses of the
    same relation) are merged into a single node.

    Args:
        g1: the graph to merge into.
        g2: the graph to merge; traversed in DFS postorder, which assumes it
            is acyclic.
        exclude_props: node attribute keys ignored when comparing node data.
        debug: if True, return the ``g2`` -> ``g1`` node mapping instead of
            the merged graph.

    Returns:
        The merged graph (see ``graph_compose``), or the node mapping when
        ``debug`` is True.
    """

    def _are_nodes_equal(g1,u1,g2,u2):
        u1_data = g1.nodes[u1]
        u2_data = g2.nodes[u2]

        # nodes naming a relation are matched by that name alone
        if 'rel' in u1_data and 'rel' in u2_data:
            return u1_data['rel'] == u2_data['rel']

        u1_clean_data = {k:v for k,v in u1_data.items() if k not in exclude_props}
        u2_clean_data = {k:v for k,v in u2_data.items() if k not in exclude_props}

        # NOTE(review): this only checks that u2's children were already matched,
        # not that they were matched to u1's children, so a node can be merged
        # with an equal-looking node over different inputs — confirm intended
        return u1_clean_data == u2_clean_data and all(v2 in node_mappings for v2 in g2.successors(u2))

    # g2 node name -> equal g1 node name
    node_mappings=dict()
    # postorder visits a node after its children, so their matches are known
    # when the node itself is compared (relies on g2 being acyclic)
    for u2 in nx.dfs_postorder_nodes(g2):
        for u1 in g1.nodes():
            if _are_nodes_equal(g1,u1,g2,u2):
                node_mappings[u2] = u1
                break

    if debug:
        return node_mappings
    return graph_compose(g1,g2,node_mappings)


def merge_term_graphs(gs,exclude_props = ('label',),debug=False):
    """Merge a list of term graphs into one term graph, pairwise from left to right.

    Args:
        gs: non-empty sequence of term graphs; a single graph is merged with itself.
        exclude_props: node attribute keys ignored when comparing node data.
        debug: if True, return the node mapping of the last pairwise merge
            instead of the merged graph, so a list of merges can be debugged
            step by step.

    Raises:
        ValueError: if ``gs`` is empty.
    """
    if not gs:
        raise ValueError("merge_term_graphs needs at least one graph")
    merge = gs[0]
    for g in gs[1:-1]:
        merge = merge_term_graphs_pair(merge,g,exclude_props,debug=False)
    # only the last merge honours debug, so merges can be debugged incrementally
    return merge_term_graphs_pair(merge,gs[-1],exclude_props,debug=debug)


#### Tests

In [None]:
# merging a term graph with itself must leave it unchanged; check every graph,
# not just the last one
for orig in [t1,t2,t3]:
    merge_self = merge_term_graphs([orig,orig])
    assert serialize_graph(merge_self) == serialize_graph(orig)

In [None]:
# t1 and t2 share S, the rename over S and the union R;
# only t2's head projection is new (renamed to 5)
m = merge_term_graphs([t1,t2])
draw(m)
assert serialize_graph(m) == ([('S', {'op': 'get_rel', 'rel': 'S'}),
  (0, {'op': 'rename', 'names': [(0, 'X'), (1, 'Y')]}),
  ('S2', {'op': 'get_rel', 'rel': 'S2'}),
  (1, {'op': 'rename', 'names': [(0, 'X'), (1, 'A')]}),
  (2, {'op': 'select', 'theta': [(2, 1)]}),
  (3, {'op': 'join'}),
  (4, {'op': 'project', 'on': ['X', 'Y'], 'rel': '_R_0'}),
  ('R', {'op': 'union', 'rel': 'R'}),
  (5, {'op': 'project', 'on': ['X', 'Y'], 'rel': '_R_1'})],
 [(0, 'S', {}),
  (1, 'S2', {}),
  (2, 1, {}),
  (3, 0, {}),
  (3, 2, {}),
  (4, 3, {}),
  ('R', 4, {}),
  ('R', 5, {}),
  (5, 0, {})])

In [None]:
# t1 and t3 share S2 and the rename/select branch over it (nodes 1, 2);
# S3 with its rename, join, head projection and the R2 union are new
m = merge_term_graphs([t1,t3])
draw(m)
assert serialize_graph(m) == ([('S', {'op': 'get_rel', 'rel': 'S'}),
  (0, {'op': 'rename', 'names': [(0, 'X'), (1, 'Y')]}),
  ('S2', {'op': 'get_rel', 'rel': 'S2'}),
  (1, {'op': 'rename', 'names': [(0, 'X'), (1, 'A')]}),
  (2, {'op': 'select', 'theta': [(2, 1)]}),
  (3, {'op': 'join'}),
  (4, {'op': 'project', 'on': ['X', 'Y'], 'rel': '_R_0'}),
  ('R', {'op': 'union', 'rel': 'R'}),
  ('S3', {'op': 'get_rel', 'rel': 'S3'}),
  (5, {'op': 'rename', 'names': [(0, 'X'), (1, 'Y')]}),
  (6, {'op': 'join'}),
  (7, {'op': 'project', 'on': ['X', 'Y'], 'rel': '_R2_2'}),
  ('R2', {'op': 'union', 'rel': 'R2'})],
 [(0, 'S', {}),
  (1, 'S2', {}),
  (2, 1, {}),
  (3, 0, {}),
  (3, 2, {}),
  (4, 3, {}),
  ('R', 4, {}),
  (5, 'S3', {}),
  (6, 5, {}),
  (6, 2, {}),
  (7, 6, {}),
  ('R2', 7, {})])

In [None]:
# merging all three graphs combines the sharing of the two previous merges
m = merge_term_graphs([t1,t2,t3])
draw(m)
assert serialize_graph(m) == ([('S', {'op': 'get_rel', 'rel': 'S'}),
  (0, {'op': 'rename', 'names': [(0, 'X'), (1, 'Y')]}),
  ('S2', {'op': 'get_rel', 'rel': 'S2'}),
  (1, {'op': 'rename', 'names': [(0, 'X'), (1, 'A')]}),
  (2, {'op': 'select', 'theta': [(2, 1)]}),
  (3, {'op': 'join'}),
  (4, {'op': 'project', 'on': ['X', 'Y'], 'rel': '_R_0'}),
  ('R', {'op': 'union', 'rel': 'R'}),
  (5, {'op': 'project', 'on': ['X', 'Y'], 'rel': '_R_1'}),
  ('S3', {'op': 'get_rel', 'rel': 'S3'}),
  (6, {'op': 'rename', 'names': [(0, 'X'), (1, 'Y')]}),
  (7, {'op': 'join'}),
  (8, {'op': 'project', 'on': ['X', 'Y'], 'rel': '_R2_2'}),
  ('R2', {'op': 'union', 'rel': 'R2'})],
 [(0, 'S', {}),
  (1, 'S2', {}),
  (2, 1, {}),
  (3, 0, {}),
  (3, 2, {}),
  (4, 3, {}),
  ('R', 4, {}),
  ('R', 5, {}),
  (5, 0, {}),
  (6, 'S3', {}),
  (7, 6, {}),
  (7, 2, {}),
  (8, 7, {}),
  ('R2', 8, {})])

In [None]:
# mutually recursive rules: A <- B, A <- C, B <- D, B <- A
r1 = Rule(
    head=Relation(name='A', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='B', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
    ])

r2 = Rule(
    head=Relation(name='A', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='C', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
    ])

r3 = Rule(
    head=Relation(name='B', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='D', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
    ])

r4 = Rule(
    head=Relation(name='B', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    body=[
        Relation(name='A', terms=[FreeVar(name='X'),FreeVar(name='Y')]),
    ])



In [None]:
# the merged graph of the recursive rules is cyclic (A -> 1 -> 0 -> A, B -> 6 -> 0 -> B).
# NOTE(review): rename node 0 ends up with two children, 'B' and 'A': r4's rename
# over A was merged into r1's rename over B, because merge_term_graphs_pair only
# checks that children are already matched. This expectation looks like it pins a
# merge bug — confirm before relying on it.
rules = [r1,r2,r3,r4]
print([pretty(r) for r in rules])
m = merge_term_graphs([_rule_to_term_graph(r,i) for i,r in enumerate(rules)])
draw(m)
assert serialize_graph(m) ==([('B', {'op': 'union', 'rel': 'B'}),
  (0, {'op': 'rename', 'names': [(0, 'X'), (1, 'Y')]}),
  (1, {'op': 'project', 'on': ['X', 'Y'], 'rel': '_A_0'}),
  ('A', {'op': 'get_rel', 'rel': 'A'}),
  ('C', {'op': 'get_rel', 'rel': 'C'}),
  (2, {'op': 'rename', 'names': [(0, 'X'), (1, 'Y')]}),
  (3, {'op': 'project', 'on': ['X', 'Y'], 'rel': '_A_1'}),
  ('D', {'op': 'get_rel', 'rel': 'D'}),
  (4, {'op': 'rename', 'names': [(0, 'X'), (1, 'Y')]}),
  (5, {'op': 'project', 'on': ['X', 'Y'], 'rel': '_B_2'}),
  (6, {'op': 'project', 'on': ['X', 'Y'], 'rel': '_B_3'})],
 [('B', 5, {}),
  ('B', 6, {}),
  (0, 'B', {}),
  (0, 'A', {}),
  (1, 0, {}),
  ('A', 1, {}),
  ('A', 3, {}),
  (2, 'C', {}),
  (3, 2, {}),
  (4, 'D', {}),
  (5, 4, {}),
  (6, 0, {})])

['A(X,Y) <- B(X,Y)', 'A(X,Y) <- C(X,Y)', 'B(X,Y) <- D(X,Y)', 'B(X,Y) <- A(X,Y)']


## Engine

In [None]:
#TODO from here, make a notebook that will make a Span extension type for pandas
# placeholder: an empty dataframe, not used further in the code shown here
df= pd.DataFrame([
])

In [None]:
class Engine():
    """State of a spannerlog session.

    Holds the variable symbol table, the relation definitions, the IE
    functions, the term graph of all rules and the stored relations (``db``).
    The fact, IE-function, rule and query methods are still unimplemented stubs.
    """
    def __init__(self):
        self.symbol_table={
            # key : (type, value)
        }
        self.Relation_defs={
            # key : RelationDefinition for both real and derived relations
        }
        self.ie_functions={
            # name : IEFunction class
        }

        # term graph of all rules added to the engine
        self.term_graph = nx.DiGraph()
        # source of fresh integer node ids for the term graph
        self.node_counter = itertools.count()

        self.db = {
            # relation_name: dataframe
        }

        # lets skip this for now and keep it a an attribute in the node graph
        # self.rules_to_nodes = {
        #     # rule pretty string, to node id in term_graph
        # }

        # self.rels_to_nodes() = {
        #     # relation name to node that represents it
        # }


    def set_var(self,var_name,value,read_from_file=False):
        """Bind ``var_name`` to ``value`` in the symbol table.

        The symbol table stores ``(type(value), value)``; rebinding a variable
        to a value of a different type is rejected.

        Args:
            var_name: name of the variable.
            value: the value to bind.
            read_from_file: currently unused.

        Raises:
            ValueError: if ``var_name`` is already bound to a value of another type.
        """
        symbol_table = self.symbol_table
        if var_name in symbol_table:
            existing_type,existing_value = symbol_table[var_name]
            if type(value) is not existing_type:
                raise ValueError(f"Variable {var_name} was previously defined with {existing_value}({pretty(existing_type)})"
                                f" but is trying to be redefined to {value}({pretty(type(value))}) of a different type which might interfere with previous rule definitions")
        symbol_table[var_name] = type(value),value

    def get_var(self,var_name):
        """Return the ``(type, value)`` pair bound to ``var_name``; raises KeyError if unbound."""
        return self.symbol_table[var_name]

    def del_var(self,var_name):
        """Remove ``var_name`` from the symbol table; raises KeyError if unbound."""
        del self.symbol_table[var_name]

    def get_relation(self,rel_name:str):
        """Return the RelationDefinition registered as ``rel_name``; raises KeyError if unknown."""
        return self.Relation_defs[rel_name]

    def set_relation(self,rel_def:RelationDefinition):
        """Register a relation definition and create its empty table in ``db``.

        Registering an identical definition again is a no-op.

        Raises:
            ValueError: if a different definition with the same name already exists.
        """
        if rel_def.name in self.Relation_defs:
            existing_def = self.Relation_defs[rel_def.name]
            if existing_def != rel_def:
                raise ValueError(f"Relation {rel_def.name} was previously defined with {pretty(existing_def)}"
                                f" but is trying to be redefined to {pretty(rel_def)} which might interfere with previous rule definitions")
            return
        self.Relation_defs[rel_def.name] = rel_def
        # NOTE(review): columns are named after the scheme's type names, so a scheme
        # with repeated types (e.g. (str, str)) gets duplicate column labels
        self.db[rel_def.name] = pd.DataFrame(columns=[pretty(s) for s in rel_def.scheme])

    def del_relation(self,rel_name:str):
        # TODO we need to think about what to do with all relations that used this rule
        raise NotImplementedError("deleting relations is not supported yet")

    def add_fact(self,fact:Relation):
        # TODO: not implemented yet
        pass

    def del_fact(self,fact:Relation):
        # TODO: not implemented yet
        pass

    def get_ie_function(self,name:str):
        # TODO: not implemented yet
        pass

    def set_ie_function(self,ie_func:IEFunction):
        # TODO: not implemented yet
        pass

    def del_ie_function(self,name:str):
        # TODO: not implemented yet
        pass

    def add_rule(self,rule:Rule):
        # make term graph of the rule, make sure the nodes of relations have the same name
        # merge the term graph with the existing one via graph union

        # TODO extension, make the graph a semiring graph, so we can share common expressions
        # give the rule head the id of the rule string so we can find it for removal

        pass

    def del_rule(self,name:str):
        # find the node in the term graph that has the name of the rule
        # remove the nodes and all of it's ancestors if they are not connected to any other node
        # if the node of the head relation has no sons now, remove it as well
        pass

    def run_query(self,q:Relation,rewrites=None):
        # get the subgraph of the term graph that has the query relation as a node
        # call all rewrites on it
        # TODO add verbose and display options for optimizations

        # call the semi-naive evaluation on the subgraph
        pass
        


In [None]:
# * Resolve Vars  # let's put this in the term graph
# * Register new Relations
# * Add/remove facts
# * Add rules
# * Run Queries

## DB operations

For each RA operation, and for the CalcIE operation, we have a corresponding DB operation:

* Select
* Project
* Rename
* Union
* Intersection
* Difference
* Join
* SemiJoin
* Calc (for calcing ie relations)
* GetRel (for accessing relations from the DB)
  


## SemiNaive Operations

For every RA operation and the CalcIE operation, we have a function that takes the node's children (which can fetch their underlying data from a specific iteration) and computes the differential version of the operation:

* Select
* Project
* Rename
* Union
* Intersection
* Difference
* Join
* SemiJoin
* Calc (for calcing ie relations)
* GetRel (for accessing relations from the DB)
  


## Semi naive execution

A recursive least-fixed-point algorithm that mimics semi-naive bottom-up evaluation.

In [None]:
# TODO agg - build tests for engine and query based on current implementation to be used as a regression test
# these tests should check the resulting computation graph and the output of the query


In [None]:
# compute bottom up
# start at root
# if

In [None]:
#compute_node

    # if the node is not part of a cycle,
    # compute it by computing all of its children and performing the nodes operation on them.

    # if it is part of a cycle,
        #compute current iterations


# compute_current_iteration (i)
    # take the (i-1) value of children that are in the cycle
    # take the final value of children that are not in the cycle

    # compute the node based on the children values, assign it to the values of the ith iteration of the node
    # if the value of the node didn't change since the last iteration, mark the node as finished and return


In [None]:
def compute_node(G,u,iter=0):
    """Draft of the recursive (semi-naive style) evaluation of a term-graph node.

    Sketch only, not runnable as written.
    NOTE(review): a networkx DiGraph has no ``children`` method (``successors``
    is the closest equivalent), and ``u`` is used as an object with ``final``,
    ``answers`` and ``op`` attributes rather than as a graph node id — confirm
    the intended node model before implementing.
    """
    if u.final == True:
        return u.answers[-1]
    for child in G.children(u):
        # NOTE(review): vals is overwritten on every iteration, so only the last
        # child's value reaches u.op; vals is unbound if u has no children
        vals = compute_node(G,child,iter)

    u.answers[iter] = u.op(vals)

    # NOTE(review): with iter=0 this compares against answers[-1]
    # (the last entry, if answers is a list)
    if u.answers[iter] == u.answers[iter-1]:
        u.final = True
    # NOTE(review): nothing is returned when the node is not final

    
    
