# Micro Passes
> Passes over the AST of a statement that perform semantic checks and register state in the session object

In [None]:
#| default_exp micro_passes

In [None]:
#| hide
from nbdev.showdoc import show_doc

%load_ext autoreload
%autoreload 2

In [None]:
#| export
from abc import ABC, abstractmethod
import pytest

# from lark import Transformer, Token
# from lark import Tree as LarkNode
# from lark.visitors import Interpreter, Visitor_Recursive, Visitor
from pathlib import Path
from typing import no_type_check, Set, Sequence, Any,Optional,List,Callable,Dict,Union
from pydantic import BaseModel
import networkx as nx

# from spannerlib.ast_node_types import (Assignment, ReadAssignment, AddFact, RemoveFact, Query, Rule, IERelation, RelationDeclaration, Relation)
from spannerlib.primitive_types import Span, DataTypes, DataTypeMapping
# from spannerlib.engine import RESERVED_RELATION_PREFIX
# # from spannerlib.graphs import NetxStateGraph
# from spannerlib.symbol_table import SymbolTableBase
# from spannerlib.general_utils import (get_free_var_names, get_output_free_var_names, get_input_free_var_names, fixed_point, check_properly_typed_relation, type_check_rule_free_vars)
# from spannerlib.passes_utils import assert_expected_node_structure, unravel_lark_node,ParseNodeType

import logging
logger = logging.getLogger(__name__)

from graph_rewrite import draw, draw_match, rewrite, rewrite_iter
from spannerlib.utils import checkLogs,serialize_df_values,serialize_tree
from spannerlib.grammar import parse_spannerlog
from spannerlib.span import Span
from spannerlib.engine import (
    Engine,
    Var,
    FreeVar,
    RelationDefinition,
    Relation,
    IEFunction,
    IERelation,
    Rule,
    pretty,
    )

## Scaffolding

In [None]:
class DummySession():
    """Minimal stand-in for a session, used to exercise the micro passes in isolation.

    Each query is split into statements; every pass is called as
    `pass_(ast, engine)` on each statement AST (the pass's return value is
    ignored, so passes must work in place). If an execution function was
    supplied, `run_query` returns its per-statement results; otherwise it
    returns the processed ASTs.
    """

    def __init__(self,passes=None,execution_function=None):
        self.passes = [] if passes is None else passes
        self.engine = Engine()
        # execution happens only when the caller supplied something to execute with
        self.should_execute = execution_function is not None
        self.execution_function = execution_function

    def run_query(self,query):
        """Run all passes over every statement of `query` (and optionally execute them)."""
        processed_asts = []
        execution_results = []
        for statement_ast in parse_spannerlog(query,split_statements=True):
            for micro_pass in self.passes:
                micro_pass(statement_ast,self.engine)
            processed_asts.append(statement_ast)
            if self.should_execute:
                execution_results.append(self.execution_function(statement_ast,self.engine))
        return execution_results if self.should_execute else processed_asts

    

In [None]:

# smoke test: parse a few statements without any passes and draw their raw ASTs
sess = DummySession()
asts = sess.run_query("""
            new string(str)
            string("a")
            string_length(Str, Len) <- string(Str), Length(Str) -> (Len)
            """)
for ast in asts:
    draw(ast)

# Semantic Checks

## convert_primitive_values_to_objects

In [None]:
def convert_primitive_values_to_objects(ast,session):
    """Replace raw parsed literal values in `ast` with Python objects, in place.

    * For `string`, `integer`, `var_name`, `relation_name` and `free_var_name`
      nodes, the value held by their single value child is moved onto the node
      itself (the child is removed). Integers become `int`, variable names become
      `Var` / `FreeVar`, and everything else is kept as `str`.
    * `span` nodes with children at edge idx 0 and 1 get `val = Span(start, end)`
      and lose those children.
    * Schema type names that are children of a `decl_term_list` (`decl_string`,
      `decl_int`, `decl_span`) are replaced by the classes `str`, `int`, `Span`.

    Args:
        ast: statement AST graph; modified in place.
        session: not used by this pass; kept for the common `(ast, engine)` pass signature.
    """

    # primitive values
    def cast_new_value(match):
        # the parent's type says how to interpret the raw value stored on its child
        val_type = match['var']['type']
        value = match['val_node']['val']
        if val_type == 'integer':
            return int(value)
        if val_type == 'var_name':
            return Var(name=value)
        if val_type == 'free_var_name':
            return FreeVar(name=value)
        # 'string' and 'relation_name' values are kept as strings
        return str(value)

    rewrite(ast,
        lhs='var[type]->val_node[val]',
        p='var[type]',
        rhs='var[type,val={{new_val}}]',
        condition= lambda match: match['var']['type'] in ['string','integer','var_name','relation_name','free_var_name'],
        render_rhs={'new_val': cast_new_value},
        )

    # span object from 2 integers
    for match in rewrite_iter(ast,
        lhs='u[type="span"]-[idx=0]->v;u-[idx=1]->w',
        p='u[type]'):
        match['u']['val']=Span(match['v']['val'],match['w']['val'])

    # schema types into class types
    decl_type_to_class = {
        'decl_string':str,
        'decl_int':int,
        'decl_span':Span,
    }

    for decl_type,decl_class in decl_type_to_class.items():
        # only match schema entries of a declaration; matching on `val` alone would also
        # hit e.g. a relation_name node whose name happens to be "decl_int"
        for match in rewrite_iter(ast,lhs=f'decl_list[type="decl_term_list"]->x[val="{decl_type}"]'):
            match['x']['val']=decl_class


In [None]:
# draw the raw AST of a simple assignment (the conversion pass is intentionally commented out)
sess = DummySession(passes=[
  # convert_primitive_values_to_objects
  ])
asts = sess.run_query("""
            x="a"
            """)
for ast in asts:
    draw(ast)



In [None]:
# check convert_primitive_values_to_objects on one statement of each kind:
# integers/vars/free vars/spans become Python objects and decl types become classes.
# Strings still carry their quotes here ('"a"') because remove_new_lines_from_strings is not run.
sess = DummySession(passes=[convert_primitive_values_to_objects])
asts = sess.run_query("""
            x=1
            S("a",1,[4,5))
            new R(int,str)
            ?R(x,X)
            R(X,Y)<-S(X,Y)
            """)
for ast in asts:
    draw(ast)

assert ([serialize_tree(ast) for ast in asts] ==  [{'type': 'assignment',
  'id': 0,
  'children': [{'type': 'var_name', 'val': Var(name='x'), 'id': 1},
   {'type': 'integer', 'val': 1, 'id': 3}]},
 {'type': 'add_fact',
  'id': 0,
  'children': [{'type': 'relation_name', 'val': 'S', 'id': 1},
   {'type': 'const_term_list',
    'id': 3,
    'children': [{'type': 'string', 'val': '"a"', 'id': 4},
     {'type': 'integer', 'val': 1, 'id': 6},
     {'type': 'span', 'val': Span(start=4, end=5), 'id': 8}]}]},
 {'type': 'relation_declaration',
  'id': 0,
  'children': [{'type': 'relation_name', 'val': 'R', 'id': 1},
   {'type': 'decl_term_list',
    'id': 3,
    'children': [{'val': int, 'id': 4}, {'val': str, 'id': 5}]}]},
 {'type': 'query',
  'id': 0,
  'children': [{'type': 'relation_name', 'val': 'R', 'id': 1},
   {'type': 'term_list',
    'id': 3,
    'children': [{'type': 'var_name', 'val': Var(name='x'), 'id': 4},
     {'type': 'free_var_name', 'val': FreeVar(name='X'), 'id': 6}]}]},
 {'type': 'rule',
  'id': 0,
  'children': [{'type': 'rule_head',
    'id': 1,
    'children': [{'type': 'relation_name', 'val': 'R', 'id': 2},
     {'type': 'free_var_name_list',
      'id': 4,
      'children': [{'type': 'free_var_name',
        'val': FreeVar(name='X'),
        'id': 5},
       {'type': 'free_var_name', 'val': FreeVar(name='Y'), 'id': 7}]}]},
   {'type': 'rule_body_relation_list',
    'id': 9,
    'children': [{'type': 'relation',
      'id': 10,
      'children': [{'type': 'relation_name', 'val': 'S', 'id': 11},
       {'type': 'term_list',
        'id': 13,
        'children': [{'type': 'free_var_name',
          'val': FreeVar(name='X'),
          'id': 14},
         {'type': 'free_var_name', 'val': FreeVar(name='Y'), 'id': 16}]}]}]}]}] )

## Remove newlines from strings

In [None]:
#| export
def remove_new_lines_from_strings(ast,engine):
    """Normalise every `string` node of `ast` in place.

    Deletes backslash-newline sequences (line continuations) from the value and
    then drops its first and last character, i.e. the surrounding quotes that the
    parser leaves on string literals.
    """
    for string_match in rewrite_iter(ast,
        lhs='v[type="string",val]'):
        string_node = string_match['v']
        joined = string_node['val'].replace('\\\n','')
        # TODO the quotes should disappear in the parsing stage instead of here
        string_node['val'] = joined[1:-1]


In [None]:
# check that string literals lose their quotes (and line continuations)
sess = DummySession(passes=[
    convert_primitive_values_to_objects,
    remove_new_lines_from_strings,
    ])
# NOTE(review): this is a non-raw Python literal, so Python itself consumes the
# backslash-newline below and the parser receives x="hello world"; this cell therefore
# does not exercise the line-continuation removal of remove_new_lines_from_strings.
asts = sess.run_query("""
x="hello \
world"
""")
for ast in asts:
    draw(ast)

ast = asts[0]
assert serialize_tree(ast) == {'type': 'assignment',
 'id': 0,
 'children': [{'type': 'var_name', 'val': Var(name='x'), 'id': 1},
  {'type': 'string', 'val': 'hello world', 'id': 3}]}

## Check reserved relation names

In [None]:
class CheckReservedRelationNames():
    """Pass that rejects relation names starting with a reserved prefix.

    Args:
        reserved_prefix: prefix that relation names in user statements may not start with.
    """

    def __init__(self,reserved_prefix):
        self.reserved_prefix = reserved_prefix

    def __call__(self,ast,engine):
        """Raise ValueError for the first relation_name node of `ast` that uses the reserved prefix."""
        prefix = self.reserved_prefix
        for name_match in rewrite_iter(ast,lhs='X[type="relation_name",val]'):
            name = name_match['X']['val']
            if not name.startswith(prefix):
                continue
            raise ValueError(f"Relation name '{name}' starts with reserved prefix '{prefix}'")

In [None]:
# regular relation names pass; a name with the reserved 'spanner_' prefix raises
sess = DummySession(passes=[
    convert_primitive_values_to_objects,
    remove_new_lines_from_strings,
    CheckReservedRelationNames('spanner_'),
    ])
asts = sess.run_query("""
            S("a",1)
            R(X,Y)<-S(X,Y),T(X,Y)
            """)

with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query("""
            spanner_S("a",1)
            """)
print(exc_info.value)


Relation name 'spanner_S' starts with reserved prefix 'spanner_'


In [None]:
# `asts` still holds the successful query above; the failing query never reassigned it
draw(asts[0])

## Check that read assignments reference existing paths

In [None]:
import os
def check_referenced_paths_exist(ast,engine):
    """Raise ValueError if a read assignment names a path that does not exist.

    The path is the `val` of the child at edge idx 1 of each `read_assignment`
    node. Relative paths are checked against the current working directory,
    which is also reported in the error message.
    """
    read_pattern = 'X[type="read_assignment"]-[idx=1]->PathNode[val]'
    for read_match in rewrite_iter(ast,lhs=read_pattern):
        candidate = Path(read_match['PathNode']['val'])
        if candidate.exists():
            continue
        raise ValueError(f'path {candidate} was not found in {os.getcwd()}')


In [None]:
# check that read assignments got a string which is an existing path
sess = DummySession(passes=[
    convert_primitive_values_to_objects,
    remove_new_lines_from_strings,
    CheckReservedRelationNames('spanner_'),
    check_referenced_paths_exist,
    ])

file = Path("file.txt")
file.touch()

# TODO figure out why this doesnt work
asts = sess.run_query(f"""
            x=read("file.txt")
            """)

with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query("""
            x=read("not_existing_file.txt")
            """)
print(exc_info.value)

file.unlink()

path not_existing_file.txt was not found in /Users/dean/tdk/spannerlib/nbs


## Check that referenced vars are defined

In [None]:
#| export
def check_referenced_vars_exist(ast,engine):
    """Raise ValueError when a statement references a variable unknown to `engine`.

    Side effect: the target variable of every `assignment` / `read_assignment`
    (the child at edge idx 0) is retyped from "var_name" to "var_name_lhs". This
    separates assignment targets from references, so every remaining "var_name"
    node is a reference that must already be defined.
    """
    # step 1: mark assignment targets so they are not treated as references
    for kind in ["assignment","read_assignment"]:
        target_pattern = f"""X[type="{kind}"]-[idx=0]->LHS[type="var_name",val]"""
        for target_match in rewrite_iter(ast,lhs=target_pattern):
            target_match['LHS']['type'] = "var_name_lhs"

    # step 2: every remaining var_name reference must be known to the engine
    for ref_match in rewrite_iter(ast,lhs="""X[type="var_name",val]"""):
        referenced = ref_match['X']['val'].name
        if engine.get_var(referenced):
            continue
        raise ValueError(f'Variable {referenced} is not defined')


In [None]:
# referenced vars must be defined in the engine; assignment targets are exempt
sess = DummySession(passes=[
    convert_primitive_values_to_objects,
    remove_new_lines_from_strings,
    CheckReservedRelationNames('spanner_'),
    check_referenced_paths_exist,
    check_referenced_vars_exist,
    ])

sess.engine.set_var('y',1)
sess.engine.set_var('x',"hello")

asts = sess.run_query(f"""
            z=1
            x=y
            R(x,y)
            """)
# `z=1` above was only checked, not executed (this session has no execution function),
# so z is still undefined here
with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query(f"""
                R(x,z)
                """)
assert 'Variable z is not defined' in str(exc_info.value)
print(exc_info.value)

# for ast in asts:
#     draw(ast)


Variable z is not defined


## Cast relations to python objects

In [None]:
def _ordered_successors(ast,node):
    """Return the children of `node`, ordered by the `idx` attribute of their edges.

    If any child edge has no `idx` (or the indices cannot be compared), the graph's
    own successor order is returned unchanged, so the result always contains exactly
    the nodes of `list(ast.successors(node))`.
    """
    children = list(ast.successors(node))
    # get_edge_data returns None for a missing edge; `or {}` keeps .get() safe
    indices = [(ast.get_edge_data(node,child) or {}).get('idx') for child in children]
    if any(index is None for index in indices):
        return children
    try:
        # sorted() is stable: children with equal indices keep their relative order
        positions = sorted(range(len(children)),key=lambda pos: indices[pos])
    except TypeError:
        return children
    return [children[pos] for pos in positions]


def relations_to_dataclasses(ast,engine):
    """Collapse relation sub-trees of `ast` into dataclass values, in place.

    * `add_fact`, `remove_fact`, `relation`, `rule_head` and `query` nodes with a
      relation-name child and a term-list child get `val = Relation(name, terms)`.
    * `relation_declaration` nodes get `val = RelationDefinition(name, scheme)`.
    * `ie_relation` nodes get `val = IERelation(name, in_terms, out_terms)`.

    In each case only the statement node is kept (`p='statement[type]'`), and the
    individual term nodes are removed explicitly. Terms are collected in the order
    given by their edges' `idx` attribute when every edge has one (see
    `_ordered_successors`), otherwise in the graph's successor order.
    """
    # regular relations
    # TODO another example where i need to edit the graph imperatively because i dont have horizontal recursion in LHS
    for match in rewrite_iter(ast,
        lhs='''
         statement[type]->name[type="relation_name",val];
         statement->terms[type]
         ''',
        # TODO i expect to be able to put an rhs here only, and if a p is not given, assume it is the identity over nodes in LHS
        p='statement[type]',
        condition=lambda match: (match['statement']['type'] in ['add_fact','remove_fact','relation','rule_head','query']
                                 and match['terms']['type'] in ['const_term_list','term_list','free_var_name_list']),
        ):
        term_nodes = _ordered_successors(ast,match.mapping['terms'])
        logger.debug("casting relation to dataclasses - term_nodes: %s",term_nodes)
        match['statement']['val'] = Relation(name=match['name']['val'],terms=[ast.nodes[term_node]['val'] for term_node in term_nodes])
        ast.remove_nodes_from(term_nodes)

    # relation declarations
    for match in rewrite_iter(ast,
        lhs='''
         statement[type="relation_declaration"]->name[type="relation_name",val];
         statement->terms[type="decl_term_list"]
         ''',
        p='statement[type]'):
        term_nodes = _ordered_successors(ast,match.mapping['terms'])
        match['statement']['val'] = RelationDefinition(name=match['name']['val'],scheme=[ast.nodes[term_node]['val'] for term_node in term_nodes])
        ast.remove_nodes_from(term_nodes)

    # ie relations: edge idx 1 holds the input terms, edge idx 2 the output terms
    for match in rewrite_iter(ast,
        lhs='''
         statement[type="ie_relation"]->name[type="relation_name",val];
         statement-[idx=1]->in_terms[type="term_list"];
         statement-[idx=2]->out_terms[type="term_list"]
      ''',p='statement[type]'):
        in_term_nodes = _ordered_successors(ast,match.mapping['in_terms'])
        out_term_nodes = _ordered_successors(ast,match.mapping['out_terms'])

        match['statement']['val'] = IERelation(name=match['name']['val'],
                                               in_terms=[ast.nodes[term_node]['val'] for term_node in in_term_nodes],
                                               out_terms=[ast.nodes[term_node]['val'] for term_node in out_term_nodes]
                                               )
        ast.remove_nodes_from(in_term_nodes+out_term_nodes)

In [None]:
# run relations_to_dataclasses with debug logging captured, to inspect the collected term nodes
sess = DummySession(passes=[
    convert_primitive_values_to_objects,
    remove_new_lines_from_strings,
    CheckReservedRelationNames('spanner_'),
    check_referenced_paths_exist,
    check_referenced_vars_exist,
    relations_to_dataclasses
    ])

with checkLogs():
    asts = sess.run_query("""
    R("hello",6)
    R("hello",[4,5))<-True
    R("hello",[4,5))<-False
    new R(str,span,int,int)
    ?R("hello",Y)
    R(X,Y)<-S(X,Y),T(X,Y)->(Y,Z)
    """)
for ast in asts:
    draw(ast)

__main__ - DEBUG - casting relation to dataclasses - term_nodes: [4, 6]
__main__ - DEBUG - casting relation to dataclasses - term_nodes: [4, 6]
__main__ - DEBUG - casting relation to dataclasses - term_nodes: [4, 6]
__main__ - DEBUG - casting relation to dataclasses - term_nodes: [4, 6]
__main__ - DEBUG - casting relation to dataclasses - term_nodes: [5, 7]
__main__ - DEBUG - casting relation to dataclasses - term_nodes: [14, 16]


In [None]:
# every relation-like statement collapses into a single node carrying a Relation /
# RelationDefinition / IERelation value, with its children removed
sess = DummySession(passes=[
    convert_primitive_values_to_objects,
    remove_new_lines_from_strings,
    CheckReservedRelationNames('spanner_'),
    check_referenced_paths_exist,
    check_referenced_vars_exist,
    relations_to_dataclasses
    ])
asts = sess.run_query("""
R("hello",6)
R("hello",[4,5))<-True
R("hello",[4,5))<-False
new R(str,span,int,int)
?R("hello",Y)
R(X,Y)<-S(X,Y),T(X,Y)->(Y,Z)
""")
for ast in asts:
    draw(ast)

assert  [serialize_tree(ast) for ast in asts] == [{'type': 'add_fact',
  'val': Relation(name='R', terms=['hello', 6]),
  'id': 0,
  'children': []},
 {'type': 'add_fact',
  'val': Relation(name='R', terms=['hello', Span(start=4, end=5)]),
  'id': 0,
  'children': []},
 {'type': 'remove_fact',
  'val': Relation(name='R', terms=['hello', Span(start=4, end=5)]),
  'id': 0,
  'children': []},
 {'type': 'relation_declaration',
  'val': RelationDefinition(name='R', scheme=[str,Span,int,int]),
  'id': 0,
  'children': []},
 {'type': 'query',
  'val': Relation(name='R', terms=['hello', FreeVar(name='Y')]),
  'id': 0,
  'children': []},
 {'type': 'rule',
  'id': 0,
  'children': [{'type': 'rule_head',
    'val': Relation(name='R', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
    'id': 1},
   {'type': 'rule_body_relation_list',
    'id': 9,
    'children': [{'type': 'relation',
      'val': Relation(name='S', terms=[FreeVar(name='X'), FreeVar(name='Y')]),
      'id': 10},
     {'type': 'ie_relation',
      'val': IERelation(name='T', in_terms=[FreeVar(name='X'), FreeVar(name='Y')], out_terms=[FreeVar(name='Y'), FreeVar(name='Z')]),
      'id': 18}]}]}]

## Relation referencing

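* check that referenced relations and ie relations:
  * exist in the symbol table
  * are called with the correct arity
  * are called with constants or vars of the correct types

  (currently `verify_referenced_relations` below only checks regular relations — `add_fact`, `remove_fact`, `query` and rule-body `relation` nodes; ie relations are not checked yet)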

In [None]:
def verify_referenced_relations(ast,engine):
    """Check that referenced regular relations exist and match their declared scheme.

    Applies to `add_fact`, `remove_fact`, `relation` (rule body) and `query`
    nodes, whose `val` must already be a `Relation` (see `relations_to_dataclasses`).
    Rule heads and IE relations are not checked by this pass.

    Raises:
        ValueError: if the relation is not defined in `engine`, is called with a
            number of terms different from its scheme, or a constant / Var term
            does not match the scheme type.
    """

    def schema_match(types,vals):
        for type_,val in zip(types,vals):
            if isinstance(val,FreeVar):
                continue # free vars can be anything
            elif isinstance(val,Var):
                # element 0 of engine.get_var(...) is compared against the scheme type
                # (presumably the type the variable was registered with)
                var_type = engine.get_var(val.name)[0]
                if not type_ == var_type:
                    return False
            elif not isinstance(val,type_):
                return False
        return True

    for match in rewrite_iter(ast,
            lhs='''rel[type]''',
            condition=lambda match: match['rel']['type'] in ['add_fact','remove_fact','relation','query'],
            ):
        rel:Relation = match['rel']['val']
        # look the definition up once instead of once per check
        definition = engine.get_relation(rel.name)
        if not definition:
            raise ValueError(f"Relation '{rel.name}' is not defined")
        expected_len = len(definition.scheme)
        if len(rel.terms) != expected_len:
            raise ValueError(f"Relation '{pretty(rel)}' was called with {len(rel.terms)} terms but it was defined with {expected_len} terms")
        if not schema_match(definition.scheme,rel.terms):
            raise ValueError(f"Relation '{rel.name}' expected schema {pretty(definition)} but got called with {pretty(rel)}")
      


In [None]:
# referenced relations must exist, have the declared arity, and match the declared types
sess = DummySession(passes=[
    convert_primitive_values_to_objects,
    remove_new_lines_from_strings,
    CheckReservedRelationNames('spanner_'),
    check_referenced_paths_exist,
    check_referenced_vars_exist,
    relations_to_dataclasses,
    verify_referenced_relations
    ])

sess.engine.set_var('a',1)
sess.engine.set_var('b','hello')
sess.engine.set_relation(RelationDefinition(name='R',scheme=[str,int]))
sess.engine.set_relation(RelationDefinition(name='S',scheme=[str,int,int]))
sess.engine.set_ie_function(IEFunction(name='T',in_schema=[str,int],out_schema=[int,str],func=lambda x,y:(y,x)))

# legal statements (the rule head NewRel is not checked by this pass)
asts = sess.run_query("""
R("hello",6)
R("hello",a)
?R("hello",Y)
NewRel(X,Y)<-S(X,Y,3),T(X,Y)->(Y,Z)
""")

# for ast in asts:
#     draw(ast)

# undefined relation
with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query(f"""
                Z("hello",4)
                """)
print(exc_info.value)
assert "Relation 'Z' is not defined" in str(exc_info.value)

# wrong arity
with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query(f"""
                R("hello",4,[4,5))
                """)
print(exc_info.value)
assert "Relation 'R(hello,4,[4,5))' was called with 3 terms but it was defined with 2 terms" in str(exc_info.value)


# wrong constant type
with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query(f"""
                R(4,4)
                """)
print(exc_info.value)
assert "Relation 'R' expected schema R(str,int) but got called with R(4,4)" in str(exc_info.value)


# wrong variable type (b holds a str where int is expected)
with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query(f"""
                R("hello",b)
                """)
print(exc_info.value)
assert "Relation 'R' expected schema R(str,int) but got called with R(hello,b)" in str(exc_info.value)


# wrong constant type inside a query
with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query(f"""
                ?R(4,Y)
                """)
print(exc_info.value)
assert "Relation 'R' expected schema R(str,int) but got called with R(4,Y)" in str(exc_info.value)

# assert  [serialize_tree(ast) for ast in asts] 
# [serialize_tree(ast) for ast in asts]

Relation 'Z' is not defined
Relation 'R(hello,4,[4,5))' was called with 3 terms but it was defined with 2 terms
Relation 'R' expected schema R(str,int) but got called with R(4,4)
Relation 'R' expected schema R(str,int) but got called with R(hello,b)
Relation 'R' expected schema R(str,int) but got called with R(4,Y)


## Cast rules to dataclasses

In [None]:
#| export
def rules_to_dataclasses(ast,engine):
    """Collapse each `rule` sub-tree of `ast` into a single node with a `Rule` value, in place.

    The `rule_head` child and the body relation nodes must already carry their
    dataclass `val` (see `relations_to_dataclasses`). The rule node gets
    `val = Rule(head, body)`. The head and body-list nodes are dropped via
    `p='statement[type]'`, and the body relation nodes are removed explicitly.

    Returns:
        the same `ast` object.
    """
    for match in rewrite_iter(ast,
        lhs='''
         statement[type="rule"]->head[type="rule_head",val];
         statement->body[type="rule_body_relation_list"]
      ''',p='statement[type]'):
        body_nodes = list(ast.successors(match.mapping['body']))
        match['statement']['val'] = Rule(head=match['head']['val'],
                                         body=[ast.nodes[body_node]['val'] for body_node in body_nodes])
        ast.remove_nodes_from(body_nodes)
    return ast

In [None]:
# rules collapse into a single node holding a Rule value (head + body relations)
sess = DummySession(passes=[
    convert_primitive_values_to_objects,
    remove_new_lines_from_strings,
    CheckReservedRelationNames('spanner_'),
    check_referenced_paths_exist,
    check_referenced_vars_exist,
    relations_to_dataclasses,
    # verify_referenced_relations,
    rules_to_dataclasses
    ])

asts = sess.run_query("""
R(X,Y,Z)<-S(X,Y),T(X,Y)
R(X,Y,Z)<-S(X,Y),T(X,Y)->(Y,Z)
""")
for ast in asts:
    draw(ast)

assert  [serialize_tree(ast) for ast in asts] == [{'type': 'rule',
  'val': Rule(head=Relation(name='R', terms=[FreeVar(name='X'), FreeVar(name='Y'), FreeVar(name='Z')]),
               body=[Relation(name='S', terms=[FreeVar(name='X'), FreeVar(name='Y')]), 
                     Relation(name='T', terms=[FreeVar(name='X'), FreeVar(name='Y')])]),
  'id': 0,
  'children': []},
 {'type': 'rule',
  'val': Rule(head=Relation(name='R', terms=[FreeVar(name='X'), FreeVar(name='Y'), FreeVar(name='Z')]), 
              body=[Relation(name='S', terms=[FreeVar(name='X'), FreeVar(name='Y')]), 
                    IERelation(name='T', in_terms=[FreeVar(name='X'), FreeVar(name='Y')], out_terms=[FreeVar(name='Y'), FreeVar(name='Z')])]),
  'id': 0,
  'children': []}]

## Consistent Free Var types

In [None]:
#| export
def _check_rule_consistency(rule,engine):
    """Check that each FreeVar in `rule` has one consistent type, and verify/register the head schema.

    Body relations are visited in order. A FreeVar takes the type of the schema
    position where it first appears (a regular relation's `scheme`, or an IE
    function's `in_schema` / `out_schema`). A later occurrence at a position with a
    different type is an error. Every head term must be a FreeVar that appears in
    the body. The head schema built from those types must equal the definition
    already registered in `engine` for the head name; if none is registered, it is
    registered with `engine.set_relation`.

    Raises:
        ValueError: on a FreeVar type conflict, a non-FreeVar head term, a head
            FreeVar missing from the body, or a head schema that differs from a
            previously registered one.
    """
    # FreeVar name -> type it was first bound to
    free_var_to_type = {}
    # FreeVar name -> (relation name, term position) where it was first bound; used in error messages
    first_rel_to_define_free_var = {}

    for relation in rule.body:
        # an ie relation is checked as two parts: its input terms and its output terms
        if isinstance(relation,Relation):
            rel_type_terms_and_schema_list = [
                ('relation',relation.terms,engine.get_relation(relation.name).scheme),
            ]
        elif isinstance(relation,IERelation):
            ie_function = engine.get_ie_function(relation.name)
            rel_type_terms_and_schema_list = [
                ('ie input relation',relation.in_terms,ie_function.in_schema),
                ('ie output relation',relation.out_terms,ie_function.out_schema),
            ]
        else:
            rel_type_terms_and_schema_list = []

        for rel_type,terms,expected_schema in rel_type_terms_and_schema_list:
            # NOTE: zip stops at the shorter of terms / schema; arity is not checked here
            for term_idx,(term,expected_type) in enumerate(zip(terms,expected_schema)):
                if not isinstance(term,FreeVar):
                    continue
                if term.name not in free_var_to_type:
                    # first occurrence: this position defines the FreeVar's type
                    free_var_to_type[term.name] = expected_type
                    first_rel_to_define_free_var[term.name] = (relation.name,term_idx)
                elif free_var_to_type[term.name] != expected_type:
                    predefined_relation,predefined_term_idx = first_rel_to_define_free_var[term.name]
                    raise ValueError(f"In rule {pretty(rule)}, in body {rel_type} {pretty(relation)}, FreeVar {term.name} position {term_idx} expects type {pretty(expected_type)} "
                                    f"but was previously defined in relation {pretty(predefined_relation)} position {predefined_term_idx} with type {pretty(free_var_to_type[term.name])}")

    # rule head: only FreeVars, each bound in the body; its schema must agree with earlier rules
    head_name, head_terms = rule.head.name, rule.head.terms

    for term in head_terms:
        if not isinstance(term,FreeVar):
            raise ValueError(f"In rule {pretty(rule)}, in head clause {head_name}, only FreeVars are allowed")
        if term.name not in free_var_to_type:
            raise ValueError(f"In rule {pretty(rule)}, FreeVar {term.name} is used in the head but was not defined in the body")

    current_head_schema = RelationDefinition(name=head_name,scheme=[free_var_to_type[term.name] for term in head_terms])

    expected_head_schema = engine.get_relation(head_name)
    if not expected_head_schema:
        # first rule for this head: its schema becomes the relation's definition
        engine.set_relation(current_head_schema)
    elif expected_head_schema != current_head_schema:
        raise ValueError(f"In rule {pretty(rule)}, expected schema {pretty(expected_head_schema)} from a previously defined rule to {head_name} but got {pretty(current_head_schema)}")

def consistent_free_var_types_in_rule(ast,engine):
    """Apply `_check_rule_consistency` to the Rule value of every `rule` node in `ast`.

    Returns the same `ast`; raises ValueError on the first inconsistent rule.
    """
    rule_matches = rewrite_iter(ast,lhs='X[type="rule",val]')
    for rule_match in rule_matches:
        _check_rule_consistency(rule_match['X']['val'],engine)
    return ast

In [None]:
# FreeVars must keep one type across the rule body, head vars must be bound in the body,
# and the head schema must match earlier definitions of the head relation
sess = DummySession(passes=[
    convert_primitive_values_to_objects,
    remove_new_lines_from_strings,
    CheckReservedRelationNames('spanner_'),
    check_referenced_paths_exist,
    check_referenced_vars_exist,
    relations_to_dataclasses,
    verify_referenced_relations,
    rules_to_dataclasses,
    consistent_free_var_types_in_rule,
    ])


sess.engine.set_var('a',1)
sess.engine.set_var('b','hello')
sess.engine.set_relation(RelationDefinition(name='R',scheme=[str,int]))
sess.engine.set_relation(RelationDefinition(name='S',scheme=[str,int,int]))
sess.engine.set_relation(RelationDefinition(name='NewRel',scheme=[str,int]))
sess.engine.set_ie_function(IEFunction(name='T',in_schema=[str,int],out_schema=[int,str],func=lambda x,y:(y,x)))

# legal query
asts = sess.run_query("""
NewRel(X,Y)<-S(X,Y,3),T(X,Y)->(Y,Z),R(Z,Y)
""")
t = serialize_tree(asts[0])

# swap R's arguments so Y (an int from S) lands on R's str position
with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query("""
    NewRel(X,Y)<-S(X,Y,3),T(X,Y)->(Y,Z),R(Y,Z)
    """)
print(exc_info.value)
assert "FreeVar Y position 0 expects type str but was previously defined in relation S position 1 with type int" in str(exc_info.value)

# use Y for both outputs of T, so its second occurrence lands on T's str output
with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query("""
    NewRel(X,Y)<-S(X,Y,3),T(X,Y)->(Y,Y)
    """)
print(exc_info.value)
assert "FreeVar Y position 1 expects type str but was previously defined in relation S position 1 with type int" in str(exc_info.value)

# free var in head that is NOT bound by the body
with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query("""
    NewRel(X,Y,W)<-S(X,Y,3),T(X,Y)->(Y,Z),R(Z,Y)
    """)
print(exc_info.value)
assert "FreeVar W is used in the head but was not defined in the body" in str(exc_info.value)

# head free var type mismatch
with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query("""
    NewRel(Y,X)<-S(X,Y,3),T(X,Y)->(Y,Z),R(Z,Y)
    """)
print(exc_info.value)
assert "expected schema NewRel(str,int) from a previously defined rule to NewRel but got NewRel(int,str)" in str(exc_info.value)
# display the serialized legal rule
t

In rule NewRel(X,Y) <- S(X,Y,3),T(X,Y) -> (Y,Z),R(Y,Z), in body relation R(Y,Z), FreeVar Y position 0 expects type str but was previously defined in relation S position 1 with type int
In rule NewRel(X,Y) <- S(X,Y,3),T(X,Y) -> (Y,Y), in body ie output relation T(X,Y) -> (Y,Y), FreeVar Y position 1 expects type str but was previously defined in relation S position 1 with type int
In rule NewRel(X,Y,W) <- S(X,Y,3),T(X,Y) -> (Y,Z),R(Z,Y), FreeVar W is used in the head but was not defined in the body
In rule NewRel(Y,X) <- S(X,Y,3),T(X,Y) -> (Y,Z),R(Z,Y), expected schema NewRel(str,int) from a previously defined rule to NewRel but got NewRel(int,str)


{'type': 'rule',
 'val': Rule(head=Relation(name='NewRel', terms=[FreeVar(name='X'), FreeVar(name='Y')]), body=[Relation(name='S', terms=[FreeVar(name='X'), FreeVar(name='Y'), 3]), IERelation(name='T', in_terms=[FreeVar(name='X'), FreeVar(name='Y')], out_terms=[FreeVar(name='Y'), FreeVar(name='Z')]), Relation(name='R', terms=[FreeVar(name='Z'), FreeVar(name='Y')])]),
 'id': 0,
 'children': []}

## Rule safety

In [None]:
#| export
def is_rule_safe(rule:Rule):
    """
    Checks that the Spannerlog Rule is safe
    ---
    In spannerlog, rule safety is a semantic property ensuring that the inputs of every IE relation
    are limited in the values they can be assigned to by other relations in the rule body.
    This can include outputs of other IE relations.

    We call a free variable in a rule body "bound" if it exists in the output of any safe relation in the rule body.
    Normal relations only have output terms, so all their free variables are considered bound.

    We call a relation in a rule's body safe if all its input free variables are bound.
    Normal relations don't have input terms, so they are always considered safe.

    We call a rule safe if all of its body relations are safe.

    This means there must be at least one order of IE relation evaluation in which the input
    variables of each IE relation are bound by the normal relations and the outputs of the
    previously evaluated IE relations.

    Examples:
    * `rel2(X,Y) <- rel1(X,Z), ie1(X)->(Y)` is a safe rule as the only input free variable, `X`, exists in the output of the safe relation `rel1(X, Z)`.
    * `rel2(Y) <- ie1(Z)->(Y)` is not safe as the input free variable `Z` does not exist in the output of any safe relation.
    * `rel2(Z,W) <- rel1(X,Y),ie1(Z,Y)->(W),ie2(W,Y)->(Z)` is not safe as both ie functions require each other's output as input, creating a circular dependency.
    ---
    Returns:
        True if the rule is safe.
    Raises:
        ValueError: if some IE relations' input free vars can never be bound.
    """

    # free vars bound by regular relations
    normal_relations_free_vars = {
        term.name
        for body in rule.body if isinstance(body,Relation)
        for term in body.terms if isinstance(term,FreeVar)
    }

    # ie relation -> ({input free var names}, {output free var names}) for not-yet-safe ie relations
    free_vars_per_ie_relation = {}
    for body in rule.body:
        if isinstance(body,IERelation):
            input_vars = {term.name for term in body.in_terms if isinstance(term,FreeVar)}
            output_vars = {term.name for term in body.out_terms if isinstance(term,FreeVar)}
            free_vars_per_ie_relation[body] = (input_vars,output_vars)

    # repeatedly mark as safe every ie relation whose inputs are all bound,
    # binding its outputs in turn, until none remain or no progress is possible
    safe_vars = set(normal_relations_free_vars)
    safe_ie_relations = []

    while free_vars_per_ie_relation:
        newly_safe = [ie_relation for ie_relation,(input_vars,_) in free_vars_per_ie_relation.items()
                      if input_vars <= safe_vars]

        if not newly_safe:
            # sorted() keeps the message deterministic (set order depends on hash randomization)
            raise ValueError(f"Rule '{pretty(rule)}' is not safe:\n"
                             f"the following free vars were bound by normal relations: {sorted(normal_relations_free_vars)}\n"
                             f"the following ie relations were safe: {safe_ie_relations}\n"
                             f"leading to the following free vars being bound: {sorted(safe_vars)}\n"
                             f"However the following ie relations could not be bound: {[pretty(ie) for ie in free_vars_per_ie_relation]}\n"
                             )

        for ie_relation in newly_safe:
            _,output_vars = free_vars_per_ie_relation.pop(ie_relation)
            safe_vars.update(output_vars)
            safe_ie_relations.append(pretty(ie_relation))

    return True

In [None]:
def check_rule_safety(ast,engine):
    """Micro pass: reject any rule statement that is not safe.

    Every node of ``ast`` whose ``type`` attribute is ``"rule"`` has its
    ``val`` (the rule object produced by an earlier pass such as
    ``rules_to_dataclasses``) passed to ``is_rule_safe``, which raises
    ``ValueError`` when the rule is unsafe. ``engine`` is accepted only to
    match the micro-pass signature and is not used.

    Returns:
        The same ``ast`` object that was passed in.
    """
    rule_matches = rewrite_iter(ast, lhs='X[type="rule",val]')
    for rule_match in rule_matches:
        # is_rule_safe signals failure by raising; its return value is not needed
        is_rule_safe(rule_match['X']['val'])
    return ast

In [None]:
# Rule-safety checks: each IE relation's input free vars must be bound, either by
# a normal relation in the body or by the output of an IE relation already shown safe.
sess = DummySession(passes=[
    convert_primitive_values_to_objects,
    remove_new_lines_from_strings,
    relations_to_dataclasses,
    rules_to_dataclasses,
    check_rule_safety,
    ])


# safe rules: ie1's input X is bound by rel1, and ie2's input Y is bound by
# ie1's output, so every IE relation eventually becomes safe
asts = sess.run_query("""
S(X,Y)<-R(X,Y)    
rel2(X,Y) <- rel1(X,Z), ie1(X)->(Y)
rel2(X,Y) <- rel1(X,Z), ie1(X)->(Y), ie2(Y)->(Z)
rel2(X,Y) <- rel1(X), ie1(X)->(Y), ie2(Y)->(Z)
""")
t = serialize_tree(asts[0])

# unsafe: ie1's input Z is not bound by any relation in the rule body
with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query("""
    rel2(Y) <- ie1(Z)->(Y)
    """)
print(exc_info.value)
assert "is not safe" in str(exc_info.value)


# unsafe: ie1 needs Z, which only ie2 outputs, and ie2 needs W, which only ie1
# outputs -- a cycle that the normal relation (binding X, Y) never breaks
with pytest.raises(ValueError) as exc_info:
    asts = sess.run_query("""
    rel2(Z,W) <- rel1(X,Y),ie1(Z,Y)->(W),ie2(W,Y)->(Z)    
    """)
print(exc_info.value)
assert "is not safe" in str(exc_info.value)





Rule 'rel2(Y) <- ie1(Z) -> (Y)' is not safe:
the following free vars where bound by normal relations: set()
the following ie relations where safe: set()
leading to the following free vars being bound: set()
However the following ie relations could not be bound: ['ie1(Z) -> (Y)']

Rule 'rel2(Z,W) <- rel1(X,Y),ie1(Z,Y) -> (W),ie2(W,Y) -> (Z)' is not safe:
the following free vars where bound by normal relations: {'Y', 'X'}
the following ie relations where safe: set()
leading to the following free vars being bound: {'Y', 'X'}
However the following ie relations could not be bound: ['ie1(Z,Y) -> (W)', 'ie2(W,Y) -> (Z)']



In [None]:
def assignemnts_to_name_val_tuple(ast,engine):
    """Micro pass: fold each assignment statement into a ``(name, value)`` tuple.

    Matches statement nodes whose ``type`` is ``"assignment"`` or
    ``"read_assignment"`` and that have an ``idx=0`` child (the variable-name
    node) and an ``idx=1`` child (the value node). The statement's ``val``
    becomes ``(var_name_node.val.name, val_node.val)``. The match is rewritten
    to the lone ``statement[type]`` node, so the two children are dropped.
    ``engine`` is part of the micro-pass signature and is not used.

    Returns:
        The same ``ast`` object, as edited by ``rewrite_iter``.
    """
    assignment_types = ['assignment','read_assignment']

    def is_assignment(candidate):
        # only plain and read() assignments are folded; other statements are skipped
        return candidate['statement']['type'] in assignment_types

    lhs_pattern = '''
                                statement[type]-[idx=0]->var_name_node[val];
                                statement-[idx=1]->val_node[val]'''
    matches = rewrite_iter(ast, lhs=lhs_pattern, p='statement[type]', condition=is_assignment)
    for found in matches:
        var_name = found['var_name_node']['val'].name
        var_value = found['val_node']['val']
        found['statement']['val'] = (var_name, var_value)
    return ast

In [None]:
# Assignment statements: after these passes every `name = value` and
# `name = read(path)` statement is reduced to one node whose `val` is a
# `(name, value)` tuple. No execution_function is given, so nothing runs on the engine.
sess = DummySession(passes=[
    convert_primitive_values_to_objects,
    remove_new_lines_from_strings,
    CheckReservedRelationNames('spanner_'),
    # check_referenced_paths_exist,
    # check_referenced_vars_exist,
    relations_to_dataclasses,
    # verify_referenced_relations,
    rules_to_dataclasses,
    # consistent_free_var_types_in_rule,
    check_rule_safety,
    assignemnts_to_name_val_tuple
    ])


# fixture for the `read("file.txt")` statement below; removed in `finally`
# so the cell never leaves it behind in the working directory
file = Path("file.txt")
file.write_text("hello file")

# assignment statements (not rules).
# NOTE(review): inside this non-raw Python literal the backslash-newline after
# `hello ` is a Python line continuation, so the parser receives
# `z="hello world"` on a single line. If the intent is to exercise
# remove_new_lines_from_strings on a multi-line spannerlog string, the
# literal probably needs `\\` or an `r` prefix -- confirm before changing.
try:
    asts = sess.run_query("""
x=3
y="hello"
z="hello \
world"
w=[3,4)
f=read("file.txt")
""")
finally:
    file.unlink()

for ast in asts:
    draw(ast)
assert [serialize_tree(ast) for ast in asts] == [{'type': 'assignment', 'val': ('x', 3), 'id': 0, 'children': []},
 {'type': 'assignment', 'val': ('y', 'hello'), 'id': 0, 'children': []},
 {'type': 'assignment', 'val': ('z', 'hello world'), 'id': 0, 'children': []},
 {'type': 'assignment',
  'val': ('w', Span(start=3, end=4)),
  'id': 0,
  'children': []},
 {'type': 'read_assignment',
  'val': ('f', 'file.txt'),
  'id': 0,
  'children': []}]

# Execution passes

In [None]:
#| export
def execute_statement(ast,engine):
    """Execution pass: apply the statement held in ``ast`` to ``engine``.

    The statement is read from the first node of ``ast``. Presumably the
    earlier micro passes have collapsed it into that single node (confirm).
    The node's ``type`` attribute selects the engine call and its ``val``
    attribute is the argument:

    - ``assignment``: ``engine.set_var(*val)``
    - ``read_assignment``: ``engine.set_var(*val, read_from_file=True)``
    - ``add_fact``: ``engine.add_fact(val)``
    - ``remove_fact``: ``engine.del_fact(val)``
    - ``relation_declaration``: ``engine.set_relation(val)``
    - ``rule``: ``engine.add_rule(val)`` (result returned)
    - ``query``: ``engine.run_query(val)`` (result returned)

    Returns:
        The engine's result for ``rule`` and ``query`` statements, otherwise ``None``.

    Raises:
        ValueError: if the statement type is none of the above.
    """
    root = list(ast.nodes)[0]
    attrs = ast.nodes[root]
    kind = attrs['type']
    payload = attrs['val']

    if kind == 'assignment':
        engine.set_var(*payload)
    elif kind == 'read_assignment':
        engine.set_var(*payload, read_from_file=True)
    elif kind == 'add_fact':
        engine.add_fact(payload)
    elif kind == 'remove_fact':
        engine.del_fact(payload)
    elif kind == 'relation_declaration':
        engine.set_relation(payload)
    elif kind == 'rule':
        return engine.add_rule(payload)
    elif kind == 'query':
        return engine.run_query(payload)
    else:
        raise ValueError(f"Unknown statement type {kind}")
    return None

    

In [None]:
# Inspect the relation definitions registered on the current session's engine.
# The `sess` in scope here was built without an execution_function and its
# query declared no relations, so nothing has been registered (empty result).
# The commented-out call below presumably registers one by hand for experimenting.
# sess.engine.set_relation(RelationDefinition(name='R',scheme=[str,int]))
sess.engine.Relation_defs

{}

In [None]:
# End-to-end run: all semantic-check passes enabled, plus `execute_statement`
# as the execution function, so each statement is applied to `sess.engine`
# and `run_query` returns one result per statement.
sess = DummySession(passes=[
    convert_primitive_values_to_objects,
    remove_new_lines_from_strings,
    CheckReservedRelationNames('spanner_'),
    check_referenced_paths_exist,
    check_referenced_vars_exist,
    relations_to_dataclasses,
    verify_referenced_relations,
    rules_to_dataclasses,
    consistent_free_var_types_in_rule,
    check_rule_safety,
    assignemnts_to_name_val_tuple,
    ],
    execution_function=execute_statement
    )

sess.symbol_table ={}

# fixture read by `y=read("file.txt")` (and checked by
# check_referenced_paths_exist); removed in `finally` so a failing
# statement does not leave it behind
file = Path("file.txt")
file.write_text("hello file")

# mixed program: assignments, a relation declaration, facts, a rule and queries
try:
    results = sess.run_query("""
x=3
y=read("file.txt")
new R(str,int)
R("hello",4)
?R("hello",x)
S(X,Y)<-R(X,Y)    
?R(X,Y)
R("hello",4)<-False
?S(X,Y)
""")
finally:
    file.unlink()

# results[-3] is `?R(X,Y)`; results[-1] is `?S(X,Y)`, which is empty after
# `R("hello",4)<-False` (presumably the remove-fact form) drops the only R fact
assert serialize_df_values(results[-3])=={('hello',4)}
assert serialize_df_values(results[-1])==set()

In [None]:
#|hide
# regenerate the exported module(s) from this notebook's `#| export` cells
import nbdev
nbdev.nbdev_export()
     