# Micro Passes
> Passes over the AST of a statement that perform semantic checks and register state in the session object

In [None]:
#| default_exp micro_passes

In [None]:
#| hide
from nbdev.showdoc import show_doc

%load_ext autoreload
%autoreload 2

In [None]:
#| export
import os
import pytest
from pathlib import Path
from typing import no_type_check, Set, Sequence, Any,Optional,List,Callable,Dict,Union
from pydantic import BaseModel
import networkx as nx
from deepdiff import DeepDiff
from numbers import Real
import logging
logger = logging.getLogger(__name__)

from graph_rewrite import draw, draw_match, rewrite, rewrite_iter
from spannerlib.utils import (
    checkLogs,serialize_df_values,serialize_tree,
    schema_match,is_of_schema,type_merge,schema_merge
)
from spannerlib.grammar import parse_spannerlog
from spannerlib.span import Span
from spannerlib.engine import (
    Engine,
    Var,
    FreeVar,
    RelationDefinition,
    Relation,
    IEFunction,
    AGGFunction,
    IERelation,
    Rule,
    pretty,
    )

## Scaffolding

In [None]:
class DummySession():
    """Minimal stand-in for a Spannerlog session, used to exercise micro passes.

    Each statement of a query is parsed into a networkx AST and every pass in
    ``passes`` is applied to it in order as ``pass_(ast, engine)``. If an
    ``execution_function`` was supplied it is then called on each processed AST.
    """

    def __init__(self,passes=None,execution_function=None):
        self.passes = [] if passes is None else passes
        self.engine = Engine()
        self.execution_function = execution_function
        self.should_execute = execution_function is not None

    def run_query(self,query):
        """Apply all passes to every statement in ``query``.

        Returns the list of execution results when an execution function was
        given, otherwise the list of processed statement ASTs.
        """
        processed_asts = []
        execution_results = []
        for statement_ast, _lark_tree in parse_spannerlog(query,split_statements=True):
            for micro_pass in self.passes:
                micro_pass(statement_ast,self.engine)
            processed_asts.append(statement_ast)
            if self.should_execute:
                execution_results.append(self.execution_function(statement_ast,self.engine))
        return execution_results if self.should_execute else processed_asts

    

In [None]:
def assert_asts(asts,expected=None):
    """Compare serialized ASTs against expected serializations.

    If ``expected`` is None, return the serialized form of every AST (useful
    for producing the expected value of a new snapshot test). Otherwise check
    that ``asts`` and ``expected`` have the same length and that every
    serialized AST equals the corresponding expected entry.

    Raises:
        AssertionError: on a length mismatch, or on the first AST that differs
            (the message includes a DeepDiff of the mismatch).
    """
    asts = list(asts)
    if expected is None:
        return [serialize_tree(ast) for ast in asts]
    expected = list(expected)
    # zip() stops at the shorter input, so without this check a missing or
    # extra statement would pass silently.
    assert len(asts) == len(expected), (
        f"Got {len(asts)} ASTs but expected {len(expected)}")
    for i,(ast,expected_ast) in enumerate(zip(asts,expected)):
        serialized = serialize_tree(ast)
        assert serialized == expected_ast,(f"AST {i} does not match expected"
            f"\nAST:\n{serialized}\nExpected:\n{expected_ast}\n"
            f"Diff is:{DeepDiff(serialized,expected_ast)}")
            
        

In [None]:

# Parse a few statements without any passes, to inspect the raw ASTs.
sess = DummySession()
asts = sess.run_query("""
            new string(str)
            string("a")
            string_length(Str, Len) <- string(Str), Length(Str) -> (Len)
            """)
for ast in asts:
    draw(ast)

# Semantic Checks

## convert_primitive_values_to_objects

In [None]:
# bool() is True for every non-empty string, including 'False', so bool
# literals must be parsed by comparing against the text 'True' instead.
bool('False')

True

In [None]:
#| export
def convert_primitive_values_to_objects(ast,session):
    """Replace raw token values in ``ast`` with Python objects, in place.

    Two rewrites are applied:

    * A node whose ``type`` is a primitive node type, and which points to a
      child holding the raw token in ``val``, receives the converted value as
      its own ``val``. The conversions are: ``Var``, ``FreeVar``, unquoted
      ``str``, ``int``, ``float`` and ``bool``. The child is not kept in
      ``p``; relation and aggregation names keep their raw value.
    * Every node whose ``val`` is a declaration type name (``decl_string``,
      ``decl_int``, ``decl_float``, ``decl_bool``) gets the matching Python
      class as its ``val``.

    ``session`` is not used; it is accepted so the function fits the
    ``pass_(ast, engine)`` micro pass call signature.
    """
    # raw token -> Python value, keyed by the type of the node owning the token;
    # primitive types missing here (relation_name, agg_name) keep the raw value
    caster_by_type = {
        'var_name': lambda raw: Var(name=raw),
        'free_var_name': lambda raw: FreeVar(name=raw),
        'string': lambda raw: str(raw)[1:-1],  # strip the surrounding quotes
        'int': int,
        'int_neg': lambda raw: -int(raw),
        'float': float,
        'float_neg': lambda raw: -float(raw),
        # bool('False') is True, so compare with the literal text instead
        'bool': lambda raw: True if raw == 'True' else False,
    }

    def cast_new_value(match):
        raw_value = match['val_node']['val']
        caster = caster_by_type.get(match['var']['type'])
        if caster is None:
            return raw_value
        return caster(raw_value)

    primitive_node_types = (
        'string', 'int', 'int_neg',
        'float', 'float_neg', 'bool',
        'var_name', 'relation_name',
        'free_var_name', 'agg_name',
    )

    # move each primitive's value from its child node onto the node itself
    rewrite(ast,
        lhs='var[type]->val_node[val]',
        p='var[type]',
        rhs='var[type,val={{new_val}}]',
        condition=lambda m: m['var']['type'] in primitive_node_types,
        render_rhs={'new_val': cast_new_value},
        )

    # declared column types become the corresponding Python classes
    decl_type_to_class = {
        'decl_string': str,
        'decl_int': int,
        'decl_float': float,
        'decl_bool': bool,
    }
    for decl_name, python_class in decl_type_to_class.items():
        for decl_match in rewrite_iter(ast,lhs=f'x[val="{decl_name}"]'):
            decl_match['x']['val'] = python_class


In [None]:
# Convert assignments of every literal kind.
# NOTE: the backslash before the newline in `"hello \` is a Python line
# continuation inside this triple-quoted string, so the parser receives
# `x= "hello world"` on a single line.
sess = DummySession(passes=[
  convert_primitive_values_to_objects
  ])
asts = sess.run_query("""
x="a"
x=True
x=False
x=1.3
x=-3.5
x=1
x=-2
x=y
x= "hello \
world"
            """)
for ast in asts:
    draw(ast)



In [None]:
# Snapshot of the converted ASTs for one statement of each kind: literals become
# Python objects, declared types become classes, relation/agg names stay strings.
sess = DummySession(passes=[convert_primitive_values_to_objects])
asts = sess.run_query("""
            x=1
            S("a",1,2)
            new R(int,str)
            ?R($x,X)
            R(X,sum(Y))<-S(X,Y)
            """)
for ast in asts:
    draw(ast)

assert_asts(asts,[{'type': 'assignment',
  'id': 0,
  'children': [{'type': 'var_name', 'val': Var(name='x'), 'id': 1},
   {'type': 'int', 'val': 1, 'id': 3}]},
 {'type': 'add_fact',
  'id': 0,
  'children': [{'type': 'relation_name', 'val': 'S', 'id': 1},
   {'type': 'term_list',
    'id': 3,
    'children': [{'type': 'string', 'val': 'a', 'id': 4},
     {'type': 'int', 'val': 1, 'id': 6},
     {'type': 'int', 'val': 2, 'id': 8}]}]},
 {'type': 'relation_declaration',
  'id': 0,
  'children': [{'type': 'relation_name', 'val': 'R', 'id': 1},
   {'type': 'decl_term_list',
    'id': 3,
    'children': [{'val': int, 'id': 4}, {'val': str, 'id': 5}]}]},
 {'type': 'query',
  'id': 0,
  'children': [{'type': 'relation_name', 'val': 'R', 'id': 1},
   {'type': 'term_list',
    'id': 3,
    'children': [{'type': 'var_name', 'val': Var(name='x'), 'id': 4},
     {'type': 'free_var_name', 'val': FreeVar(name='X'), 'id': 6}]}]},
 {'type': 'rule',
  'id': 0,
  'children': [{'type': 'rule_head',
    'id': 1,
    'children': [{'type': 'relation_name', 'val': 'R', 'id': 2},
     {'type': 'term_list',
      'id': 4,
      'children': [{'type': 'free_var_name',
        'val': FreeVar(name='X'),
        'id': 5},
       {'type': 'aggregated_free_var',
        'id': 7,
        'children': [{'type': 'agg_name', 'val': 'sum', 'id': 8},
         {'type': 'free_var_name', 'val': FreeVar(name='Y'), 'id': 10}]}]}]},
   {'type': 'rule_body_relation_list',
    'id': 12,
    'children': [{'type': 'relation',
      'id': 13,
      'children': [{'type': 'relation_name', 'val': 'S', 'id': 14},
       {'type': 'term_list',
        'id': 16,
        'children': [{'type': 'free_var_name',
          'val': FreeVar(name='X'),
          'id': 17},
         {'type': 'free_var_name', 'val': FreeVar(name='Y'), 'id': 19}]}]}]}]}])


## Check reserved relation names

In [None]:
#| export
class CheckReservedRelationNames():
    """Micro pass that rejects relation names starting with a reserved prefix.

    Args:
        reserved_prefix: prefix that relation names in a statement may not use.

    Raises:
        ValueError: when called on an AST with a ``relation_name`` node whose
            value starts with ``reserved_prefix``.
    """

    def __init__(self,reserved_prefix):
        self.reserved_prefix = reserved_prefix

    def __call__(self,ast,engine):
        prefix = self.reserved_prefix
        name_matches = rewrite_iter(ast,lhs='X[type="relation_name",val]')
        for name_match in name_matches:
            name = name_match['X']['val']
            if not name.startswith(prefix):
                continue
            raise ValueError(f"Relation name '{name}' starts with reserved prefix '{prefix}'")

In [None]:
# Ordinary relation names pass; a name with the reserved 'spanner_' prefix
# makes the pass raise ValueError.
sess = DummySession(passes=[
    convert_primitive_values_to_objects,
    # remove_new_lines_from_strings,
    CheckReservedRelationNames('spanner_'),
    ])
asts = sess.run_query("""
            S("a",1)
            R(X,Y)<-S(X,Y),T(X,Y)
            """)

with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query("""
            spanner_S("a",1)
            """)
print(exc_info.value)


Relation name 'spanner_S' starts with reserved prefix 'spanner_'


In [None]:
# `asts` still holds the successful query's ASTs: the failing query above
# raised before its assignment to `asts` completed.
draw(asts[0])

## Dereference vars

In [None]:
#| export
def dereference_vars(ast,engine):
    """Replace variable references in ``ast`` with their values from ``engine``, in place.

    Assignment targets (the ``idx=0`` child of ``assignment`` and
    ``read_assignment`` nodes) are first retyped to ``var_name_lhs`` so they
    are not treated as references. Every remaining ``var_name`` node is then
    looked up with ``engine.get_var``, and its ``type``/``val`` are replaced by
    the stored ``(type, value)`` pair.

    Raises:
        ValueError: if a referenced variable is not defined in ``engine``.
    """
    # First retype left hand side variables as "var_name_lhs",
    # so we can separate them from reference variables.
    for assignment_type in ["assignment","read_assignment"]:
        for match in rewrite_iter(ast,
                lhs=f"""X[type="{assignment_type}"]-[idx=0]->LHS[type="var_name",val]"""
                ):
            match['LHS']['type'] = "var_name_lhs"

    # Now replace each reference variable with its entry from the symbol table.
    for match in rewrite_iter(ast,lhs="""X[type="var_name",val]"""):
        var_name = match['X']['val'].name
        # look the variable up once instead of querying the engine twice
        var_entry = engine.get_var(var_name)
        if not var_entry:
            raise ValueError(f'Variable {var_name} is not defined')
        var_type,var_value = var_entry
        match['X']['type'] = var_type
        match['X']['val'] = var_value


In [None]:
# Referenced variables are replaced by their (type, value) from the engine,
# assignment targets become 'var_name_lhs', and an undefined variable raises.
sess = DummySession(passes=[
    convert_primitive_values_to_objects,
    # remove_new_lines_from_strings,
    CheckReservedRelationNames('spanner_'),
    dereference_vars,
    ])

sess.engine.set_var('y',1)
sess.engine.set_var('x',"hello")

asts = sess.run_query(f"""
            z=1
            x=$y
            R($x,$y)
            """)
for ast in asts:
    draw(ast)

assert_asts(asts,[{'type': 'assignment',
  'id': 0,
  'children': [{'type': 'var_name_lhs', 'val': Var(name='z'), 'id': 1},
   {'type': 'int', 'val': 1, 'id': 3}]},
 {'type': 'assignment',
  'id': 0,
  'children': [{'type': 'var_name_lhs', 'val': Var(name='x'), 'id': 1},
   {'type': int, 'val': 1, 'id': 3}]},
 {'type': 'add_fact',
  'id': 0,
  'children': [{'type': 'relation_name', 'val': 'R', 'id': 1},
   {'type': 'term_list',
    'id': 3,
    'children': [{'type': str, 'val': 'hello', 'id': 4},
     {'type': int, 'val': 1, 'id': 6}]}]}])


with pytest.raises(ValueError) as exc_info:
    _ = sess.run_query(f"""
                R($x,$z)
                """)
assert 'Variable z is not defined' in str(exc_info.value)
print(exc_info.value)




Variable z is not defined


## Check that read assignments reference existing paths

In [None]:
#| export
def check_referenced_paths_exist(ast,engine):
    """Check that every ``read_assignment`` reads from an existing path.

    The path is the ``val`` of the ``read_assignment`` node's ``idx=1`` child;
    relative paths are resolved against the current working directory.

    Raises:
        ValueError: if the path does not exist.
    """
    for match in rewrite_iter(ast,
            lhs='X[type="read_assignment"]-[idx=1]->PathNode[val]',
            ):
        path = Path(match['PathNode']['val'])
        # report through the module logger rather than printing from a library pass
        logger.debug("checking that read path %r exists", path)
        if not path.exists():
            raise ValueError(f'path {path} was not found in {os.getcwd()}')


In [None]:
# check that read assignments got a string which is an existing path
sess = DummySession(passes=[
    convert_primitive_values_to_objects,
    # remove_new_lines_from_strings,
    CheckReservedRelationNames('spanner_'),
    dereference_vars,
    check_referenced_paths_exist,
    ])

sess.engine.set_var('y','file.txt')


file = Path("file.txt")
file.touch()

# TODO figure out why this doesnt work
asts = sess.run_query(f"""
            x=read("file.txt")
            z=read(y)
            """)


with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query("""
            x=read("not_existing_file.txt")
            """)
print(exc_info.value)


sess.engine.set_var('w','not_existing_file.txt')
with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query("""
            x=read(w)
            """)
print(exc_info.value)

file.unlink()

Path('file.txt')
Path('file.txt')
Path('not_existing_file.txt')
path not_existing_file.txt was not found in /Users/dean/tdk/spannerlib/nbs
Path('not_existing_file.txt')
path not_existing_file.txt was not found in /Users/dean/tdk/spannerlib/nbs


## Inline aggregation func names on free var nodes

In [None]:
#| export
def inline_aggregation(ast,engine):
    """Collapse ``aggregated_free_var`` nodes into annotated free variables, in place.

    An ``aggregated_free_var`` node with an ``agg_name`` child and a
    ``free_var_name`` child is retyped to ``free_var_name``. It takes the
    variable as its ``val`` and the aggregation name as a new ``agg``
    attribute. Only the marker node is kept in ``p``.
    """
    aggregation_pattern = '''
        agg_marker[type="aggregated_free_var"];
        agg_marker->agg_func[type="agg_name",val];
        agg_marker->agg_var[type="free_var_name",val]
        '''
    for agg_match in rewrite_iter(ast, lhs=aggregation_pattern, p='agg_marker[type]'):
        marker = agg_match['agg_marker']
        variable = agg_match['agg_var']['val']
        aggregation = agg_match['agg_func']['val']
        marker['type'] = 'free_var_name'
        marker['val'] = variable
        marker['agg'] = aggregation


In [None]:
# Aggregated head terms become free_var_name nodes carrying an 'agg' attribute.
sess = DummySession(passes=[
    convert_primitive_values_to_objects,
    CheckReservedRelationNames('spanner_'),
    check_referenced_paths_exist,
    dereference_vars,
    inline_aggregation,
    ])

with checkLogs():
    asts = sess.run_query("""
    R(X,sum(Y),count(Z))<-S(X,Y,Z)
    """)
for ast in asts:
    draw(ast)

assert_asts(asts,[{'type': 'rule',
  'id': 0,
  'children': [{'type': 'rule_head',
    'id': 1,
    'children': [{'type': 'relation_name', 'val': 'R', 'id': 2},
     {'type': 'term_list',
      'id': 4,
      'children': [{'type': 'free_var_name',
        'val': FreeVar(name='X'),
        'id': 5},
       {'type': 'free_var_name',
        'val': FreeVar(name='Y'),
        'agg': 'sum',
        'id': 7},
       {'type': 'free_var_name',
        'val': FreeVar(name='Z'),
        'agg': 'count',
        'id': 12}]}]},
   {'type': 'rule_body_relation_list',
    'id': 17,
    'children': [{'type': 'relation',
      'id': 18,
      'children': [{'type': 'relation_name', 'val': 'S', 'id': 19},
       {'type': 'term_list',
        'id': 21,
        'children': [{'type': 'free_var_name',
          'val': FreeVar(name='X'),
          'id': 22},
         {'type': 'free_var_name', 'val': FreeVar(name='Y'), 'id': 24},
         {'type': 'free_var_name', 'val': FreeVar(name='Z'), 'id': 26}]}]}]}]}])

In [None]:
# Aggregations combined with an IE relation in the body and a constant head term.
sess = DummySession(passes=[
    convert_primitive_values_to_objects,
    CheckReservedRelationNames('spanner_'),
    check_referenced_paths_exist,
    dereference_vars,
    inline_aggregation,
    ])

with checkLogs():
    asts = sess.run_query("""
    R(X,sum(Y),count(Z))<-S(X,Y),T(X,Y)->(Z)
    R(X,sum(Y),count(Z),"yes")<-S(X,Y,Z)
    """)
for ast in asts:
    draw(ast)



## Cast relations to dataclasses

In [None]:
#| export
def relations_to_dataclasses(ast,engine):
    """Fold relation sub-trees of ``ast`` into dataclass ``val``s, in place.

    * ``add_fact``, ``remove_fact``, ``relation``, ``rule_head`` and ``query``
      nodes whose terms child is a ``term_list`` or ``free_var_name_list`` get
      a ``Relation``. If any term node carries an ``agg`` attribute, ``agg`` is
      a per-term list (``None`` for plain terms); otherwise it is ``None``.
    * ``relation_declaration`` nodes get a ``RelationDefinition`` whose scheme
      is the ``val`` of each declared term.
    * ``ie_relation`` nodes get an ``IERelation`` built from their ``idx=1``
      (input) and ``idx=2`` (output) term lists.

    Consumed term nodes are removed from ``ast``. The name and term-list nodes
    are matched but not preserved in ``p``.

    Raises:
        ValueError: if an aggregation appears outside a rule head, including
            inside an IE relation.
    """
    # statement types folded into a Relation, and the term-list types they may carry
    relation_statement_types = ['add_fact','remove_fact','relation','rule_head','query']
    relation_term_types = ['term_list','free_var_name_list']

    # regular relations
    # TODO another example where i need to edit the graph imperatively because i dont have horizontal recursion in LHS
    # TODO i expect to be able to put an rhs here only, and if a p is not given, assume it is the identity over nodes in LHS
    relation_pattern = '''
         statement[type]->name[type="relation_name",val];
         statement->terms[type]
         '''
    for match in rewrite_iter(ast,
            lhs=relation_pattern,
            p='statement[type]',
            condition=lambda m: (m['statement']['type'] in relation_statement_types
                                 and m['terms']['type'] in relation_term_types),
            ):
        # TODO check we iterate in order on the children
        term_nodes = list(ast.successors(match.mapping['terms']))
        logger.debug("casting relation to dataclasses - term_nodes: %s", term_nodes)
        terms = [ast.nodes[term_node]['val'] for term_node in term_nodes]

        has_agg = any('agg' in ast.nodes[term_node] for term_node in term_nodes)
        agg_by_term = ([ast.nodes[term_node].get('agg',None) for term_node in term_nodes]
                       if has_agg else None)

        rel_object = Relation(name=match['name']['val'],terms=terms,agg=agg_by_term)
        statement_type = match['statement']['type']
        if has_agg and statement_type != 'rule_head':
            raise ValueError(f'''Aggregations are only allowed in rule heads, not in {statement_type}, found in {pretty(rel_object)}''')
        match['statement']['val'] = rel_object
        ast.remove_nodes_from(term_nodes)

    # relation declarations
    declaration_pattern = '''
         statement[type="relation_declaration"]->name[type="relation_name",val];
         statement->terms[type="decl_term_list"]
         '''
    for match in rewrite_iter(ast,lhs=declaration_pattern,p='statement[type]'):
        term_nodes = list(ast.successors(match.mapping['terms']))
        scheme = [ast.nodes[term_node]['val'] for term_node in term_nodes]
        match['statement']['val'] = RelationDefinition(name=match['name']['val'],scheme=scheme)
        ast.remove_nodes_from(term_nodes)

    # ie relations
    ie_pattern = '''
         statement[type="ie_relation"]->name[type="relation_name",val];
         statement-[idx=1]->in_terms[type="term_list"];
         statement-[idx=2]->out_terms[type="term_list"]
      '''
    for match in rewrite_iter(ast,lhs=ie_pattern,p='statement[type]'):
        in_term_nodes = list(ast.successors(match.mapping['in_terms']))
        out_term_nodes = list(ast.successors(match.mapping['out_terms']))
        all_term_nodes = in_term_nodes + out_term_nodes

        ie_obj = IERelation(name=match['name']['val'],
                            in_terms=[ast.nodes[term_node]['val'] for term_node in in_term_nodes],
                            out_terms=[ast.nodes[term_node]['val'] for term_node in out_term_nodes],
                            )
        if any('agg' in ast.nodes[term_node] for term_node in all_term_nodes):
            raise ValueError(f'Aggregations are not allowed in IE relations, found in {pretty(ie_obj)}')
        match['statement']['val'] = ie_obj
        ast.remove_nodes_from(all_term_nodes)

In [None]:
# Fold relations in rule heads and bodies (including IE relations and
# aggregated heads) into dataclasses.
sess = DummySession(passes=[
    convert_primitive_values_to_objects,
    CheckReservedRelationNames('spanner_'),
    check_referenced_paths_exist,
    dereference_vars,
    inline_aggregation,
    relations_to_dataclasses
    ])

# with checkLogs():
asts = sess.run_query("""
# R("hello",6)
# R("hello",[4,5))<-True
# R("hello",[4,5))<-False
# new R(str,span,int,int)
# ?R("hello",Y)
R(X,Y)<-S(X,Y),T(X,Y)->(Y,Z)
R(X,sum(Y))<-S(X,Y)
R(X,sum(Y),"yes")<-S(X,Y)
""")
for ast in asts:
    draw(ast)


In [None]:

# Aggregations outside a rule head are rejected, in regular and IE relations alike.
with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query("""
        R(X,Y)<-S(X,sum(Y)),T(X,Y)->(Y,Z)
            """)
print(exc_info.value)
assert "Aggregations are only allowed in rule heads, not in relation" in str(exc_info.value)

with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query("""
        R(X,Y)<-S(X,Y),T(X,sum(Y))->(Y,Z)
            """)
print(exc_info.value)
assert "Aggregations are not allowed in IE relations, found in T(X,Y) -> (Y,Z)" in str(exc_info.value)


Aggregations are only allowed in rule heads, not in relation, found in S(X,sum(Y))
Aggregations are not allowed in IE relations, found in T(X,Y) -> (Y,Z)


In [None]:
# Snapshot of the dataclass values for facts, fact removal, declarations,
# queries and rules (with and without aggregation).
sess = DummySession(passes=[
    convert_primitive_values_to_objects,
    CheckReservedRelationNames('spanner_'),
    check_referenced_paths_exist,
    dereference_vars,
    inline_aggregation,
    relations_to_dataclasses
    ])
asts = sess.run_query("""
R("hello",6)
+R("hello",6)
-R("hello",6)
new R(str,str,int,int)
?R("hello",Y)
R(X,Y)<-S(X,Y),T(X,Y)->(Y,Z)
R(X,sum(Y))<-S(X,Y)
R(X,sum(Y),"yes")<-S(X,Y)

""")
for ast in asts:
    draw(ast)

assert_asts(asts,[{'type': 'add_fact',
  'val': Relation(name='R', terms=['hello', 6], agg=None),
  'id': 0,
  'children': []},
 {'type': 'add_fact',
  'val': Relation(name='R', terms=['hello', 6], agg=None),
  'id': 0,
  'children': []},
 {'type': 'remove_fact',
  'val': Relation(name='R', terms=['hello', 6], agg=None),
  'id': 0,
  'children': []},
 {'type': 'relation_declaration',
  'val': RelationDefinition(name='R', scheme=[str,str,int,int]),
  'id': 0,
  'children': []},
 {'type': 'query',
  'val': Relation(name='R', terms=['hello', FreeVar(name='Y')], agg=None),
  'id': 0,
  'children': []},
 {'type': 'rule',
  'id': 0,
  'children': [{'type': 'rule_head',
    'val': Relation(name='R', terms=[FreeVar(name='X'), FreeVar(name='Y')], agg=None),
    'id': 1},
   {'type': 'rule_body_relation_list',
    'id': 9,
    'children': [{'type': 'relation',
      'val': Relation(name='S', terms=[FreeVar(name='X'), FreeVar(name='Y')], agg=None),
      'id': 10},
     {'type': 'ie_relation',
      'val': IERelation(name='T', in_terms=[FreeVar(name='X'), FreeVar(name='Y')], out_terms=[FreeVar(name='Y'), FreeVar(name='Z')]),
      'id': 18}]}]},
 {'type': 'rule',
  'id': 0,
  'children': [{'type': 'rule_head',
    'val': Relation(name='R', terms=[FreeVar(name='X'), FreeVar(name='Y')], agg=[None, 'sum']),
    'id': 1},
   {'type': 'rule_body_relation_list',
    'id': 12,
    'children': [{'type': 'relation',
      'val': Relation(name='S', terms=[FreeVar(name='X'), FreeVar(name='Y')], agg=None),
      'id': 13}]}]},
 {'type': 'rule',
  'id': 0,
  'children': [{'type': 'rule_head',
    'val': Relation(name='R', terms=[FreeVar(name='X'), FreeVar(name='Y'), 'yes'], agg=[None, 'sum', None]),
    'id': 1},
   {'type': 'rule_body_relation_list',
    'id': 14,
    'children': [{'type': 'relation',
      'val': Relation(name='S', terms=[FreeVar(name='X'), FreeVar(name='Y')], agg=None),
      'id': 15}]}]}]
    )



## verify referenced relations and functions

* check that referenced relations and ie relations:
  * exist in the symbol table 
  * are called with the correct arity
  * are called with correct constants or vars types

In [None]:
#| export
def verify_referenced_relations_and_functions(ast,engine):
    """Check that relations, IE functions and aggregations used in ``ast`` exist and type-check.

    Reads the ``Relation`` / ``IERelation`` objects that
    ``relations_to_dataclasses`` stores in node ``val``s, and checks in order
    that:

    * ``add_fact`` / ``remove_fact`` statements contain no free variables;
    * ``add_fact``, ``remove_fact``, ``relation`` and ``query`` relations are
      defined in ``engine`` and their terms match the declared scheme (free
      variables are ignored when matching);
    * ``ie_relation`` nodes name a registered IE function and their input and
      output terms match its schemas; a callable schema is first called with
      the number of terms to obtain the concrete schema;
    * every aggregation named in a ``rule_head`` is a registered aggregation
      function.

    Raises:
        ValueError: on the first failed check.
    """
    # check no free vars in adding or removing facts
    for match in rewrite_iter(ast,
            lhs='''rel[type]''',
            condition=lambda m: m['rel']['type'] in ['add_fact','remove_fact'],
            ):
        fact = match['rel']['val']
        if any(isinstance(term,FreeVar) for term in fact.terms):
            raise ValueError(f"Adding or removing facts cannot have free variables, found in {pretty(fact)}")

    # regular relations must be declared and match their declared scheme
    for match in rewrite_iter(ast,
            lhs='''rel[type]''',
            condition=lambda m: m['rel']['type'] in ['add_fact','remove_fact','relation','query'],
            ):
        rel:Relation = match['rel']['val']
        # look the definition up once and reuse it for the check and the message
        rel_def = engine.get_relation(rel.name)
        if not rel_def:
            raise ValueError(f"Relation '{rel.name}' is not defined")
        if not is_of_schema(rel.terms,rel_def.scheme,ignore_types=[FreeVar]):
            raise ValueError(f"Relation '{rel.name}' expected schema {pretty(rel_def)} but got called with {pretty(rel)}")

    # ie relations must name a registered ie function and match its schemas
    for match in rewrite_iter(ast,
            lhs='''rel[type="ie_relation"]''',
            ):
        rel:IERelation = match['rel']['val']
        ie_func = engine.get_ie_function(rel.name)
        if not ie_func:
            raise ValueError(f"ie function '{rel.name}' was not registered, registered functions are {list(engine.ie_functions.keys())}")
        # a schema may be a callable that builds the schema from the number of terms
        in_schema = ie_func.in_schema
        if callable(in_schema):
            in_schema = in_schema(len(rel.in_terms))
        out_schema = ie_func.out_schema
        if callable(out_schema):
            out_schema = out_schema(len(rel.out_terms))
        if not is_of_schema(rel.in_terms,in_schema,ignore_types=[FreeVar]):
            raise ValueError(f"IERelation '{rel.name}' input expected schema {pretty(in_schema)} but got called with {pretty(rel.in_terms)}")
        if not is_of_schema(rel.out_terms,out_schema,ignore_types=[FreeVar]):
            raise ValueError(f"IERelation '{rel.name}' output expected schema {pretty(out_schema)} but got called with {pretty(rel.out_terms)}")

    # aggregation functions used in rule heads must be registered
    for match in rewrite_iter(ast,
            lhs='''rel[type="rule_head"]''',
            ):
        rel = match['rel']['val']
        if rel.agg is None:
            continue
        for agg_func in rel.agg:
            if agg_func is None:
                continue
            if not engine.get_agg_function(agg_func):
                raise ValueError(f"agg function '{agg_func}' was not registered, registered functions are {list(engine.agg_functions.keys())}")


In [None]:
# Look up an aggregation function that was never registered; the verification
# pass treats a falsy result as "not registered".
sess.engine.get_agg_function('sums')

In [None]:
# Pipeline up to and including the reference verification pass.
sess = DummySession(passes=[
    convert_primitive_values_to_objects,
    CheckReservedRelationNames('spanner_'),
    check_referenced_paths_exist,
    dereference_vars,
    inline_aggregation,
    relations_to_dataclasses,
    verify_referenced_relations_and_functions
    ])


def string_schema(n):
    """Return a schema of ``n`` ``str`` columns, for use as a callable IE schema."""
    return [str for _ in range(n)]

# Register variables, relations, IE functions and aggregation functions. Then
# check that valid statements pass and that each kind of bad reference raises
# ValueError with a descriptive message.
sess.engine.set_var('a',1)
sess.engine.set_var('b','hello')
sess.engine.set_relation(RelationDefinition(name='R',scheme=[str,int]))
sess.engine.set_relation(RelationDefinition(name='F',scheme=[float,Real]))
sess.engine.set_relation(RelationDefinition(name='S',scheme=[str,int,int]))
sess.engine.add_fact(Relation(name='R',terms=[Span("hello"),5]))
sess.engine.set_ie_function(IEFunction(name='T',in_schema=[str,Real],out_schema=[Real,str],func=lambda x,y:(y,x)))
sess.engine.set_ie_function(IEFunction(name='T2',in_schema=string_schema,out_schema=string_schema,func=lambda x,y:(y,x)))
sess.engine.set_agg_function(AGGFunction(name='sum',func='sum',in_schema=[int],out_schema=[int]))
sess.engine.set_agg_function(AGGFunction(name='count',func='count',in_schema=[object],out_schema=[int]))


asts = sess.run_query("""
R("hello",6)
R("hello",$a)
?R("hello",Y)
F(1.3,5)
F(1.3,5.5)
NewRel(X,Y)<-S(X,Y,3),T(X,Y)->(Y,Z)
NewRel2(X,Y)<-S(X,Y,3),T2(X,Y)->(Y,Z)
NewRel3(X,sum(Y))<-R(X,Y)
NewRel4(X,count(Y))<-R(X,Y)
NewRel5(count(X),Y)<-R(X,Y)
""")

# for ast in asts:
#     draw(ast)

with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query(f"""
                x=3
                Z("hello",x)
                """)
print(exc_info.value)
assert "Adding or removing facts cannot have free variables" in str(exc_info.value)

with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query(f"""
                Z("hello",4)
                """)
print(exc_info.value)
assert "Relation 'Z' is not defined" in str(exc_info.value)

with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query(f"""
                R("hello",4,3)
                """)
print(exc_info.value)
assert "Relation 'R' expected schema R(str,int) but got called with R(hello,4,3)" in str(exc_info.value)

with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query(f"""
                R(4,4)
                """)
print(exc_info.value)
assert "Relation 'R' expected schema R(str,int) but got called with R(4,4)" in str(exc_info.value)


with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query(f"""
                R("hello",$b)
                """)
print(exc_info.value)
assert "Relation 'R' expected schema R(str,int) but got called with R(hello,hello)" in str(exc_info.value)


with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query(f"""
                ?R(4,Y)
                """)
print(exc_info.value)
assert "Relation 'R' expected schema R(str,int) but got called with R(4,Y)" in str(exc_info.value)

with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query(f"""
                NewRel(X,Y)<-S(X,Y,3),T(X,Y,X)->(Y,Z)
                """)
print(exc_info.value)
assert "IERelation 'T' input expected schema" in str(exc_info.value)

with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query(f"""
                NewRel(X,Y)<-S(X,Y,3),T(X,Y)->(Y,2)
                """)
print(exc_info.value)
assert "IERelation 'T' output expected schema" in str(exc_info.value)

with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query(f"""
                NewRel(X,Y)<-S(X,Y,3),T3(X,Y)->(Y,X)
                """)
print(exc_info.value)
assert "ie function 'T3' was not registered" in str(exc_info.value)

with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query(f"""
            NewRel3(X,sums(Y))<-R(X,Y)
            """)
print(exc_info.value)
assert "agg function 'sums' was not registered" in str(exc_info.value)


# assert  [serialize_tree(ast) for ast in asts] 
# [serialize_tree(ast) for ast in asts]

Adding or removing facts cannot have free variables, found in Z(hello,x)
Relation 'Z' is not defined
Relation 'R' expected schema R(str,int) but got called with R(hello,4,3)
Relation 'R' expected schema R(str,int) but got called with R(4,4)
Relation 'R' expected schema R(str,int) but got called with R(hello,hello)
Relation 'R' expected schema R(str,int) but got called with R(4,Y)
IERelation 'T' input expected schema [str,Real] but got called with [X,Y,X]
IERelation 'T' output expected schema [Real,str] but got called with [Y,2]
ie function 'T3' was not registered, registered functions are ['T', 'T2']
agg function 'sums' was not registered, registered functions are ['sum', 'count']


## cast rules to data classes

In [None]:
#| export
def rules_to_dataclasses(ast,engine):
    """Fold each ``rule`` node of ``ast`` into a ``Rule`` value, in place.

    The ``Rule`` head is the ``val`` of the ``rule_head`` child. Its body is
    the list of ``val``s of the ``rule_body_relation_list`` children, so those
    nodes must already carry a ``val`` (as set by ``relations_to_dataclasses``).
    The body relation nodes are removed from ``ast``.

    Returns:
        The same ``ast`` object that was passed in.
    """
    rule_pattern = '''
         statement[type="rule"]->head[type="rule_head",val];
         statement->body[type="rule_body_relation_list"]
      '''
    for match in rewrite_iter(ast,lhs=rule_pattern,p='statement[type]'):
        body_nodes = list(ast.successors(match.mapping['body']))
        head = match['head']['val']
        body = [ast.nodes[body_node]['val'] for body_node in body_nodes]
        match['statement']['val'] = Rule(head=head,body=body)
        ast.remove_nodes_from(body_nodes)
    return ast

In [None]:
# Rules are folded into Rule dataclasses. The verification pass is commented out
# here, presumably because this fresh engine has no relations registered.
sess = DummySession(passes=[
    convert_primitive_values_to_objects,
    CheckReservedRelationNames('spanner_'),
    check_referenced_paths_exist,
    dereference_vars,
    inline_aggregation,
    relations_to_dataclasses,
    # verify_referenced_relations_and_functions,
    rules_to_dataclasses
    ])

asts = sess.run_query("""
R(X,Y,Z)<-S(X,Y),T(X,Y)
R(X,Y,Z)<-S(X,Y),T(X,Y)->(Y,Z)
""")
for ast in asts:
    draw(ast)

assert  [serialize_tree(ast) for ast in asts] == [{'type': 'rule',
  'val': Rule(head=Relation(name='R', terms=[FreeVar(name='X'), FreeVar(name='Y'), FreeVar(name='Z')]),
               body=[Relation(name='S', terms=[FreeVar(name='X'), FreeVar(name='Y')]), 
                     Relation(name='T', terms=[FreeVar(name='X'), FreeVar(name='Y')])]),
  'id': 0,
  'children': []},
 {'type': 'rule',
  'val': Rule(head=Relation(name='R', terms=[FreeVar(name='X'), FreeVar(name='Y'), FreeVar(name='Z')]), 
              body=[Relation(name='S', terms=[FreeVar(name='X'), FreeVar(name='Y')]), 
                    IERelation(name='T', in_terms=[FreeVar(name='X'), FreeVar(name='Y')], out_terms=[FreeVar(name='Y'), FreeVar(name='Z')])]),
  'id': 0,
  'children': []}]

## Rule safety

In [None]:
#| export
def is_rule_safe(rule:Rule):
    """
    Checks that the Spannerlog Rule is safe, raising a ValueError if it is not.
    ---
    In spannerlog, rule safety is a semantic property that ensures that the inputs of IE relations are limited
    in the values they can be assigned to by other relations in the rule body.
    This could include outputs of other IE relations.

    We call a free variable in a rule body "bound" if it exists in the output of any safe relation in the rule body.
    Normal relations only have output terms, so all their free variables are considered bound.

    We call a relation in a rule's body safe if all its input free variables are bound.
    Normal relations have no input terms, so they are always considered safe.

    We call a rule safe if all of its body relations are safe.

    In other words, there must be at least one evaluation order of the IE relations in which
    each IE relation's input variables are bound by the normal relations and by the outputs of the
    previously evaluated IE relations.

    Examples:
    * `rel2(X,Y) <- rel1(X,Z), ie1(X)->(Y)` is a safe rule as the only input free variable, `X`, exists in the output of the safe relation `rel1(X, Z)`.  
    * `rel2(Y) <- ie1(Z)->(Y)` is not safe as the input free variable `Z` does not exist in the output of any safe relation.
    * `rel2(Z,W) <- rel1(X,Y),ie1(Z,Y)->(W),ie2(W,Y)->(Z)` is not safe as both ie functions require each other's output as input, creating a circular dependency.

    Returns True when the rule is safe. Otherwise raises a ValueError that lists the variables
    bound by normal relations, the IE relations found safe, and the IE relations that could not be bound.
    ---
    """

    # free vars of normal relations are bound from the start
    normal_relations_free_vars = {
        term.name
        for body in rule.body if isinstance(body,Relation)
        for term in body.terms if isinstance(term,FreeVar)
    }

    # map each IE relation to ({input free var names}, {output free var names})
    free_vars_per_ie_relation = {}
    for body in rule.body:
        if isinstance(body,IERelation):
            input_vars = {term.name for term in body.in_terms if isinstance(term,FreeVar)}
            output_vars = {term.name for term in body.out_terms if isinstance(term,FreeVar)}
            free_vars_per_ie_relation[body]=(input_vars,output_vars)

    safe_vars = normal_relations_free_vars.copy()
    safe_ie_relations = set()

    # Fixpoint: each round, every pending IE relation whose inputs are all bound becomes safe
    # and its outputs become bound. A round without progress means the rest can never be bound.
    while free_vars_per_ie_relation:
        safe_ie_relations_in_this_iteration = [
            ie_relation
            for ie_relation,(input_vars,_) in free_vars_per_ie_relation.items()
            if input_vars.issubset(safe_vars)
        ]

        if not safe_ie_relations_in_this_iteration:
            # pretty-print the safe IE relations like the unbindable ones (an empty set still renders as set())
            safe_ie_relation_names = {pretty(ie) for ie in safe_ie_relations}
            raise ValueError(f"Rule \'{pretty(rule)}\' is not safe:\n"
                            f"the following free vars where bound by normal relations: {normal_relations_free_vars}\n"
                            f"the following ie relations where safe: {safe_ie_relation_names}\n"
                            f"leading to the following free vars being bound: {safe_vars}\n"
                            f"However the following ie relations could not be bound: {[pretty(ie) for ie in free_vars_per_ie_relation.keys()]}\n"
                             )

        for ie_relation in safe_ie_relations_in_this_iteration:
            _,output_vars = free_vars_per_ie_relation.pop(ie_relation)
            safe_vars.update(output_vars)
            safe_ie_relations.add(ie_relation)

    return True

In [None]:
#| export
def check_rule_safety(ast,engine):
    """Run ``is_rule_safe`` on the ``val`` of every ``rule`` node in ``ast``.

    Any ``ValueError`` raised by ``is_rule_safe`` propagates to the caller.
    ``engine`` is not used. Returns ``ast``.
    """
    rule_matches = rewrite_iter(ast,lhs='X[type="rule",val]')
    for rule_match in rule_matches:
        is_rule_safe(rule_match['X']['val'])
    return ast

In [None]:
# Only the passes needed to build Rule dataclasses, followed by the safety check.
sess = DummySession(passes=[
    convert_primitive_values_to_objects,
    relations_to_dataclasses,
    rules_to_dataclasses,
    check_rule_safety,
    ])


# safe rules
asts = sess.run_query("""
S(X,Y)<-R(X,Y)    
rel2(X,Y) <- rel1(X,Z), ie1(X)->(Y)
rel2(X,Y) <- rel1(X,Z), ie1(X)->(Y), ie2(Y)->(Z)
rel2(X,Y) <- rel1(X), ie1(X)->(Y), ie2(Y)->(Z)
""")
t = serialize_tree(asts[0])

# unsafe: the input Z of ie1 is not bound by any relation in the body
with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query("""
    rel2(Y) <- ie1(Z)->(Y)
    """)
print(exc_info.value)
assert "is not safe" in str(exc_info.value)


# unsafe: ie1 needs Z, which only ie2 outputs, and ie2 needs W, which only ie1 outputs
with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query("""
    rel2(Z,W) <- rel1(X,Y),ie1(Z,Y)->(W),ie2(W,Y)->(Z)    
    """)
print(exc_info.value)
assert "is not safe" in str(exc_info.value)





Rule 'rel2(Y) <- ie1(Z) -> (Y)' is not safe:
the following free vars where bound by normal relations: set()
the following ie relations where safe: set()
leading to the following free vars being bound: set()
However the following ie relations could not be bound: ['ie1(Z) -> (Y)']

Rule 'rel2(Z,W) <- rel1(X,Y),ie1(Z,Y) -> (W),ie2(W,Y) -> (Z)' is not safe:
the following free vars where bound by normal relations: {'X', 'Y'}
the following ie relations where safe: set()
leading to the following free vars being bound: {'X', 'Y'}
However the following ie relations could not be bound: ['ie1(Z,Y) -> (W)', 'ie2(W,Y) -> (Z)']



## Consistent Free Var types

In [None]:
# TODO: from here, add tests with complex data types and subtypes using the utils
from spannerlib.term_graph import get_bounding_order
# IPython help syntax: shows the signature and docstring of get_bounding_order
get_bounding_order?

[0;31mSignature:[0m [0mget_bounding_order[0m[0;34m([0m[0mrule[0m[0;34m:[0m [0mspannerlib[0m[0;34m.[0m[0mdata_types[0m[0;34m.[0m[0mRule[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Get an order of evaluation for the body of a rule
this is a very naive ordering that can be heavily optimized
[0;31mFile:[0m      ~/tdk/spannerlib/spannerlib/term_graph.py
[0;31mType:[0m      function

In [None]:
#| export

from spannerlib.term_graph import get_bounding_order

def _infer_body_free_var_types(rule,engine):
    """Return ``{free_var_name: type}`` for the free vars in ``rule``'s body.

    Relations are visited in ``get_bounding_order(rule)`` order. Each free var is typed
    by the schema column it appears in, and repeated uses are combined with ``type_merge``.

    Raises:
        ValueError: if a free var is used with a type that cannot be merged with its earlier type.
    """
    # for each free var we encounter, its type according to the relation schemas seen so far
    free_var_to_type = {}
    # the first relation each var was found in, used in error messages
    first_rel_to_define_free_var = {}

    def verify_freevar_type(free_var,col_type,relation):
        # first occurrence defines the type; later occurrences must merge with it
        if free_var.name not in free_var_to_type:
            free_var_to_type[free_var.name] = col_type
            first_rel_to_define_free_var[free_var.name] = relation.name
            return
        try:
            new_type = type_merge(free_var_to_type[free_var.name],col_type)
        except Exception as e:
            raise ValueError(f"FreeVar {free_var.name} is used with type {pretty(col_type)}\n"
                    f"but was previously defined with type {pretty(free_var_to_type[free_var.name])}\n"
                    f"in relation {first_rel_to_define_free_var[free_var.name]}") from e
        free_var_to_type[free_var.name] = new_type

    def verify_relation_types(rel_type,relation,terms,expected_schema):
        logger.debug(f"verifying relation types for {rel_type} {pretty(relation)} expected_schema={pretty(expected_schema)}")
        if callable(expected_schema):
            # a callable schema is resolved for the actual number of terms
            expected_schema = expected_schema(len(terms))
        # only free vars carry type information; constants are skipped
        for term,expected_type in zip(terms,expected_schema):
            if not isinstance(term,FreeVar):
                continue
            logger.debug(f"verifying free var type for {term} with type {expected_type}")
            try:
                verify_freevar_type(term,expected_type,relation)
            except ValueError as e:
                raise ValueError(f"In rule {pretty(rule)}\nin {rel_type} {relation.name}\n{e}") from e

    for relation in get_bounding_order(rule):
        if isinstance(relation,Relation):
            verify_relation_types('relation',relation,relation.terms,engine.get_relation(relation.name).scheme)
        elif isinstance(relation,IERelation):
            ie_function = engine.get_ie_function(relation.name)
            verify_relation_types('ie input relation',relation,relation.in_terms,ie_function.in_schema)
            verify_relation_types('ie output relation',relation,relation.out_terms,ie_function.out_schema)

    return free_var_to_type


def _derive_head_scheme(rule,free_var_to_type,engine):
    """Return the column types of ``rule``'s head, given the free var types inferred from the body.

    A constant head term contributes ``type(term)``. A free var contributes its inferred type,
    or the aggregation's output type when that head position is aggregated.

    Raises:
        ValueError: if a head free var is not defined in the body, or an aggregation
            does not accept the free var's type.
    """
    head_name, head_terms = rule.head.name, rule.head.terms
    head_agg = rule.head.agg

    # every free var in the head must be bound by the body
    for term in head_terms:
        if isinstance(term,FreeVar) and term.name not in free_var_to_type:
            raise ValueError(f"In rule {pretty(rule)}\nFreeVar {term.name}\n"
                f"is used in the head but was not defined in the body")

    # without aggregations, the head schema is just the free var types
    if head_agg is None:
        return [free_var_to_type[term.name] if isinstance(term,FreeVar) else type(term) for term in head_terms]

    # with aggregations, an aggregated free var takes the type of the aggregation's output
    head_scheme = []
    for term,agg_name in zip(head_terms,head_agg):
        if not isinstance(term,FreeVar):
            head_scheme.append(type(term))
            continue

        if agg_name is None:
            head_scheme.append(free_var_to_type[term.name])
            continue

        agg_func = engine.get_agg_function(agg_name)
        in_schema = agg_func.in_schema
        out_schema = agg_func.out_schema
        if not schema_match([free_var_to_type[term.name]],in_schema):
            raise ValueError(f"In rule {pretty(rule)}\n"
                f"in head clause {head_name}\n"
                f"FreeVar {term.name} is aggregated with {agg_name}\n"
                f"which expects input type {pretty(in_schema[0])}\n"
                f"but got {pretty(free_var_to_type[term.name])}")
        head_scheme.append(out_schema[0])
    return head_scheme


def _check_rule_consistency(rule,engine):
    """Type-check ``rule`` against the engine and register or verify its head schema.

    Infers free var types from the body, derives the head schema from them, and then either
    registers that schema with ``engine.set_relation`` (when the head relation is not yet
    known to the engine) or checks that it equals the schema already registered under the head's name.

    Args:
        rule: a ``Rule`` whose head and body are relation dataclasses.
        engine: used for ``get_relation``, ``get_ie_function``, ``get_agg_function`` and ``set_relation``.

    Raises:
        ValueError: on a free var type conflict, a head free var missing from the body,
            an aggregation input type mismatch, or a head schema that differs from a
            previously registered one.
    """
    free_var_to_type = _infer_body_free_var_types(rule,engine)
    logger.debug(f"after deriving types from body clauses:free_var_to_type={free_var_to_type}")

    head_name = rule.head.name
    head_scheme = _derive_head_scheme(rule,free_var_to_type,engine)
    current_head_schema = RelationDefinition(name=head_name,scheme=head_scheme)

    # a head relation already known to the engine must keep the same schema
    expected_head_schema = engine.get_relation(head_name)
    if expected_head_schema:
        if expected_head_schema != current_head_schema:
            raise ValueError(f"In rule {pretty(rule)}\n"
                f"expected schema {pretty(expected_head_schema)}\n"
                f"from a previously defined rule to {head_name}\n"
                f"but got {pretty(current_head_schema)}")
    else:
        engine.set_relation(current_head_schema)


In [None]:
#| export
def consistent_free_var_types_in_rule(ast,engine):
    """Type-check the ``val`` of every ``rule`` node in ``ast`` via ``_check_rule_consistency``.

    ``_check_rule_consistency`` may register the rule head's schema on ``engine``
    and raises ``ValueError`` on type conflicts. Returns ``ast``.
    """
    rule_matches = rewrite_iter(ast,lhs='X[type="rule",val]')
    for rule_match in rule_matches:
        _check_rule_consistency(rule_match['X']['val'],engine)
    return ast

In [None]:
# Full front-end pipeline including the free var type consistency check.
sess = DummySession(passes=[
    convert_primitive_values_to_objects,
    CheckReservedRelationNames('spanner_'),
    check_referenced_paths_exist,
    dereference_vars,
    inline_aggregation,
    relations_to_dataclasses,
    verify_referenced_relations_and_functions,
    rules_to_dataclasses,
    check_rule_safety,
    consistent_free_var_types_in_rule,
    ])


# register the vars, relations, IE function and aggregations the queries below refer to
sess.engine.set_var('a',1)
sess.engine.set_var('b','hello')
sess.engine.set_relation(RelationDefinition(name='R',scheme=[str,int]))
sess.engine.set_relation(RelationDefinition(name='S',scheme=[str,int,int]))
sess.engine.set_relation(RelationDefinition(name='S2',scheme=[str,float,float]))
sess.engine.set_relation(RelationDefinition(name='NewRel',scheme=[str,int]))
sess.engine.set_ie_function(IEFunction(name='T',in_schema=[str,Real],out_schema=[Real,str],func=lambda x,y:(y,x)))
sess.engine.set_agg_function(AGGFunction(name='sum',func='sum',in_schema=[int],out_schema=[int]))
sess.engine.set_agg_function(AGGFunction(name='count',func='count',in_schema=[object],out_schema=[int]))

# legal query
# with checkLogs():
asts = sess.run_query("""
NewRel(X,Y)<-S(X,Y,3),T(X,Y)->(Y,Z),R(Z,Y)
NewRel(X,sum(Y))<-S(X,Y,3),T(X,Y)->(Y,Z),R(Z,Y)
NewRel2(X,sum(Y),"yes")<-S(X,Y,3),T(X,Y)->(Y,Z),R(Z,Y)
S3(X,Y,'yes')<-R(X,Y)    
""")
t = serialize_tree(asts[0])

# head schemas of previously unknown relations are derived and registered on the engine
assert sess.engine.get_relation('NewRel').scheme == [str,int]
assert sess.engine.get_relation('NewRel2').scheme == [str,int,str]
assert sess.engine.get_relation('S3').scheme == [str,int,str]

# R is used with its arguments swapped, so Y (int from S) is used as R's str column
with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query("""
    NewRel(X,Y)<-S(X,Y,3),T(X,Y)->(Y,Z),R(Y,Z)
    """)
exc_msg = str(exc_info.value).replace('\n',' ')
print(exc_msg)
assert "FreeVar Y is used with type str but was previously defined with type int in relation S" in exc_msg

# T's second output column is str, so binding it to Y (int from S) conflicts
with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query("""
    NewRel(X,Y)<-S(X,Y,3),T(X,Y)->(Y,Y)
    """)
exc_msg = str(exc_info.value).replace('\n',' ')
print(exc_msg)
assert "Y is used with type str but was previously defined with type int in relation S" in exc_msg

# free var in head, not bound by body
with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query("""
    NewRel(X,Y,W)<-S(X,Y,3),T(X,Y)->(Y,Z),R(Z,Y)
    """)
exc_msg = str(exc_info.value).replace('\n',' ')
print(exc_msg)
assert "FreeVar W is used in the head but was not defined in the body" in exc_msg

# head free var type mismatch
with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query("""
    NewRel(Y,X)<-S(X,Y,3),T(X,Y)->(Y,Z),R(Z,Y)
    """)
exc_msg = str(exc_info.value).replace('\n',' ')
print(exc_msg)
assert "expected schema NewRel(str,int) from a previously defined rule to NewRel but got NewRel(int,str)" in exc_msg
# NOTE(review): stray expression; it has no effect in the middle of a cell
t

# aggregation got wrong input type
with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query("""
    NewRel(sum(X),Y)<-S(X,Y,3),T(X,Y)->(Y,Z),R(Z,Y)
    """)
exc_msg = str(exc_info.value).replace('\n',' ')
print(exc_msg)
assert "FreeVar X is aggregated with sum which expects input type int but got str" in exc_msg

# aggregation caused conflict with previous type
with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query("""
    NewRel(count(X),Y)<-S(X,Y,3),T(X,Y)->(Y,Z),R(Z,Y)
    """)
exc_msg = str(exc_info.value).replace('\n',' ')
print(exc_msg)
assert "expected schema NewRel(str,int) from a previously defined rule to NewRel but got NewRel(int,int)" in exc_msg

In rule NewRel(X,Y) <- S(X,Y,3),T(X,Y) -> (Y,Z),R(Y,Z) in relation R FreeVar Y is used with type str but was previously defined with type int in relation S
In rule NewRel(X,Y) <- S(X,Y,3),T(X,Y) -> (Y,Y) in ie output relation T FreeVar Y is used with type str but was previously defined with type int in relation S
In rule NewRel(X,Y,W) <- S(X,Y,3),T(X,Y) -> (Y,Z),R(Z,Y) FreeVar W is used in the head but was not defined in the body
In rule NewRel(Y,X) <- S(X,Y,3),T(X,Y) -> (Y,Z),R(Z,Y) expected schema NewRel(str,int) from a previously defined rule to NewRel but got NewRel(int,str)
In rule NewRel(sum(X),Y) <- S(X,Y,3),T(X,Y) -> (Y,Z),R(Z,Y) in head clause NewRel FreeVar X is aggregated with sum which expects input type int but got str
In rule NewRel(count(X),Y) <- S(X,Y,3),T(X,Y) -> (Y,Z),R(Z,Y) expected schema NewRel(str,int) from a previously defined rule to NewRel but got NewRel(int,int)


## Preprocess assignments

In [None]:
#| export
def assignments_to_name_val_tuple(ast,engine):
    """Store ``(variable_name, value)`` as the ``val`` of each assignment statement node.

    Applies to statements of type ``assignment`` or ``read_assignment``. The child on the
    ``idx=0`` edge holds the variable (its ``val`` exposes ``.name``), and the child on the
    ``idx=1`` edge holds the assigned value. ``engine`` is not used. Returns ``ast``.
    """
    assignment_kinds = ['assignment','read_assignment']

    def is_assignment(candidate):
        return candidate['statement']['type'] in assignment_kinds

    for found in rewrite_iter(ast,lhs='''
                                statement[type]-[idx=0]->var_name_node[val];
                                statement-[idx=1]->val_node[val]''',p='statement[type]',
                                condition=is_assignment,
                                ):
        var_name = found['var_name_node']['val'].name
        assigned_value = found['val_node']['val']
        found['statement']['val'] = (var_name, assigned_value)
    return ast

In [None]:
# Preprocessing only (no var dereferencing or type checks): assignment statements
# should end up as childless nodes holding a (name, value) tuple.
sess = DummySession(passes=[
    convert_primitive_values_to_objects,
    CheckReservedRelationNames('spanner_'),
    # check_referenced_paths_exist,
    # dereference_vars,
    inline_aggregation,
    relations_to_dataclasses,
    # verify_referenced_relations_and_functions,
    rules_to_dataclasses,
    check_rule_safety,
    # consistent_free_var_types_in_rule,
    assignments_to_name_val_tuple
    ])


# NOTE(review): file.txt is created in the working directory and not removed by this cell
file = Path("file.txt")
file.write_text("hello file")

# plain assignments and a read() assignment; the backslash-newline inside the Python
# string literal is consumed by Python, so the query sees z="hello world"
asts = sess.run_query("""
x=3
y="hello"
z="hello \
world"
f=read("file.txt")
""")

for ast in asts:
    draw(ast)
assert [serialize_tree(ast) for ast in asts] == [{'type': 'assignment', 'val': ('x', 3), 'id': 0, 'children': []},
 {'type': 'assignment', 'val': ('y', 'hello'), 'id': 0, 'children': []},
 {'type': 'assignment', 'val': ('z', 'hello world'), 'id': 0, 'children': []},
 {'type': 'read_assignment',
  'val': ('f', 'file.txt'),
  'id': 0,
  'children': []}]

## Execution passes

In [None]:
#| export
def statement_type_and_value(ast):
    """Return the ``(type, val)`` attributes of the first node of a statement AST.

    The first node in ``ast.nodes`` iteration order is taken to be the statement node.
    An AST with no nodes raises ``IndexError``.
    """
    root = [*ast.nodes][0]
    attrs = ast.nodes[root]
    return attrs['type'], attrs['val']


def execute_statement(ast,engine):
    """Apply one preprocessed statement AST to ``engine``.

    Returns the result of ``engine.run_query`` for ``query`` statements and ``None``
    for every other supported statement type.

    Raises:
        ValueError: if the statement type is not one of the supported kinds.
    """
    statement,value = statement_type_and_value(ast)

    # queries are the only statements that produce a result
    if statement == 'query':
        return engine.run_query(value)

    # handlers are lambdas so engine attributes are only looked up for the chosen statement
    side_effect_handlers = {
        'assignment': lambda v: engine.set_var(*v),
        'read_assignment': lambda v: engine.set_var(*v,read_from_file=True),
        'add_fact': lambda v: engine.add_fact(v),
        'remove_fact': lambda v: engine.del_fact(v),
        'relation_declaration': lambda v: engine.set_relation(v),
        'rule': lambda v: engine.add_rule(v),
    }
    handler = side_effect_handlers.get(statement)
    if handler is None:
        raise ValueError(f"Unknown statement type {statement}")
    handler(value)
    return None

    

In [None]:
# sess.engine.set_relation(RelationDefinition(name='R',scheme=[str,int]))
# inspect the relation definitions currently held by this session's engine
sess.engine.Relation_defs

{}

In [None]:
# End-to-end: full preprocessing pipeline plus execute_statement, so run_query
# returns one result per statement (None for everything except queries).
sess = DummySession(passes=[
    convert_primitive_values_to_objects,
    CheckReservedRelationNames('spanner_'),
    check_referenced_paths_exist,
    dereference_vars,
    inline_aggregation,
    relations_to_dataclasses,
    verify_referenced_relations_and_functions,
    rules_to_dataclasses,
    consistent_free_var_types_in_rule,
    check_rule_safety,
    assignments_to_name_val_tuple,
    ],
    execution_function=execute_statement
    )

# NOTE(review): DummySession never reads symbol_table; this assignment has no effect
sess.symbol_table ={}

file = Path("file.txt")
file.write_text("hello file")

# declarations, facts, queries, a rule, and a fact removal
results = sess.run_query("""
x=3
y=read("file.txt")
new R(str,int)
R("hello",4)
?R("hello",x)
S(X,Y)<-R(X,Y)    
?R(X,Y)
-R("hello",4)
?S(X,Y)
""")
# NOTE(review): not reached if run_query raises, leaving file.txt behind
file.unlink()

# results[-3] is ?R(X,Y) before the removal; results[-1] is ?S(X,Y) after it
assert serialize_df_values(results[-3])=={('hello',4)}
assert serialize_df_values(results[-1])==set()

# a rule with a constant in its head, queried with a free var in that position
results = sess.run_query("""
S2(X,Y,'yes')<-R(X,Y)    
?S2(X,Y,A)
""")
results

Path('file.txt')


[None,
 Empty DataFrame
 Columns: [X, Y, A]
 Index: []]

In [None]:
# Same program as the execution cell but with execution (and the passes that depend on
# engine state, presumably) disabled: only the preprocessed ASTs are drawn.
sess = DummySession(passes=[
    convert_primitive_values_to_objects,
    CheckReservedRelationNames('spanner_'),
    check_referenced_paths_exist,
    # dereference_vars,
    inline_aggregation,
    relations_to_dataclasses,
    # verify_referenced_relations_and_functions,
    rules_to_dataclasses,
    # consistent_free_var_types_in_rule,
    check_rule_safety,
    assignments_to_name_val_tuple,
    ],
    # execution_function=execute_statement
    )

# NOTE(review): DummySession never reads symbol_table; this assignment has no effect
sess.symbol_table ={}

file = Path("file.txt")
file.write_text("hello file")

# declarations, facts, queries, rules, and a fact removal
asts = sess.run_query("""
x=3
y=read("file.txt")
new R(str,int)
R("hello",4)
?R("hello",x)
S(X,Y)<-R(X,Y)    
S2(X,Y,"yes")<-R(X,Y)
?R(X,Y)
-R("hello",4)
?S(X,Y)
?S2(X,Y,A)
""")
file.unlink()

for ast in asts:
    draw(ast)


Path('file.txt')


In [None]:
#|hide
# write the `#| export` cells of this notebook out to the library module
import nbdev; nbdev.nbdev_export()
     