# The Session Object

In [None]:
#| default_exp session

In [None]:
#| hide
from nbdev.showdoc import show_doc

%load_ext autoreload
%autoreload 2

In [None]:
#| export
import csv

import os
import re
from pathlib import Path
from typing import Tuple, List, Union, Optional, Callable, Type, Iterable, no_type_check, Sequence
from fastcore.basics import patch
from IPython import display
from singleton_decorator import singleton
from numbers import Real
import pandas as pd
import os
from itables import init_notebook_mode,show
import logging
logger = logging.getLogger(__name__)

from graph_rewrite import draw

from spannerflow.span import Span
from spannerlib.utils import checkLogs,get_base_file_path,assert_df_equals,DefaultIEs,DefaultAGGs
from spannerlib.grammar import parse_spannerlog,reconstruct
from spannerlib.data_types import (
    _infer_relation_schema,
     Var,
    FreeVar,
    RelationDefinition,
    Relation,
    IEFunction,
    AGGFunction,
    IERelation,
    Rule,
    pretty,
)
from spannerlib.engine import Engine

from spannerlib.micro_passes import (
    convert_primitive_values_to_objects,
    CheckReservedRelationNames,
    dereference_vars,
    check_referenced_paths_exist,
    inline_aggregation,
    relations_to_dataclasses,
    verify_referenced_relations_and_functions,
    rules_to_dataclasses,
    check_rule_safety,
    consistent_free_var_types_in_rule,
    assignments_to_name_val_tuple,
)


In [None]:
#| export

def _load_stdlib():
    """Import the modules that register the standard library.

    The imports are done only for their side effects; nothing is returned.
    """
    # make sure we import the modules that register the stdlib
    # NOTE(review): presumably importing these modules fills DefaultIEs / DefaultAGGs - confirm
    import spannerlib.ie_func.basic 
    import spannerlib.ie_func.json_path


In [None]:
#| export

def _class_repr(x):
    """Make a single cell value display friendly.

    Plain primitives (str, float, int, bool) are returned untouched; any other
    object (a Span for example) is replaced by its `repr` string, so that such
    objects are shown in a more readable way in pandas.
    """
    is_primitive = isinstance(x,(str,float,int,bool))
    return x if is_primitive else f"{x!r}"

In [None]:
#| export
class Session():
    def __init__(self,
    register_stdlib=True, # if True, registers the standard library of IEs and AGGs
    ):
        """
        A Session object is the main interface to the spannerlog engine.
        It is used to parse, check semantics, plan and execute queries.
        It allows importing data and callbacks to the Spannerlog engine and exporting data from the engine back to python.
        """

        # Semantic passes, kept in the order in which they are applied to the AST of
        # each statement (every pass is called as `pass_(ast, engine)` by `_check_semantics`).
        self.pass_stack = [
            convert_primitive_values_to_objects,
            CheckReservedRelationNames('spanner_'),
            check_referenced_paths_exist,
            dereference_vars,
            inline_aggregation,
            relations_to_dataclasses,
            verify_referenced_relations_and_functions,
            rules_to_dataclasses,
            check_rule_safety,
            consistent_free_var_types_in_rule,
            assignments_to_name_val_tuple,
        ]

        # `clear` creates `self.engine` and, if requested, registers the stdlib functions
        self.clear(register_stdlib=register_stdlib)

In [None]:
#| export
@patch
def clear(self:Session,
    register_stdlib=True, # if True, registers the standard library of IEs and AGGs
    ):
    """Replaces the engine with a fresh one, dropping all relations, functions and rules.

    When `register_stdlib` is truthy, every default IE function and every default
    AGG function is registered again on the new engine.
    """
    self.engine = Engine()
    if register_stdlib:
        # the stdlib modules must be imported before the defaults are listed
        _load_stdlib()
        for ie_args in DefaultIEs().as_list():
            self.register(*ie_args)
        for agg_args in DefaultAGGs().as_list():
            self.register_agg(*agg_args)


## Importing information to spannerlog

In [None]:
#| export
@patch
def register(self:Session,
    name, # name of the IE function in spannerlog
    func, # the python function that implements the IE
    in_schema, # the schema of the input relation
    out_schema, # the schema of the output relation
    ):
    """Wraps `func` in an `IEFunction` and registers it with the spannerlog engine."""
    self.engine.set_ie_function(
        IEFunction(name=name,func=func,in_schema=in_schema,out_schema=out_schema)
    )


In [None]:
#| export
@patch
def register_agg(self:Session,
        name, # name of the AGG function in spannerlog
        func, # the python function that implements the AGG
        in_schema, # the schema of the input relation, can be of arity 1 only
        out_schema # the schema of the output relation, can be of arity 1 only
    ):
    """Wraps `func` in an `AGGFunction` and registers it with the spannerlog engine."""
    self.engine.set_agg_function(
        AGGFunction(name=name,func=func,in_schema=in_schema,out_schema=out_schema)
    )


In [None]:
#| export
@patch
def import_csv(self:Session,
        name:str, # name of the relation in spannerlog
        csv_filepath:Union[str,Path], # path to the csv file
        delim:str = None, # the delimiter of the csv file, if None the csv module default is used to read the first row
        has_header: bool = False, # whether the first line of the file is a header line
    ):
    """Imports a csv file into the current session.

    The schema of the relation is inferred from the first data row of the file
    (the row after the header when `has_header` is set). If a relation called
    `name` already exists it must have the same definition, otherwise the
    relation is created. The file is then loaded through `engine.load_csv`,
    which receives `delim` and `has_header` exactly as they were passed in.

    Raises:
        IOError: if the file does not exist, is empty or has no data row.
        ValueError: if `name` already exists with a different schema.
    """
    csv_file_name = Path(csv_filepath)
    if not csv_file_name.is_file():
        raise IOError("csv file does not exist")
    if csv_file_name.stat().st_size == 0:
        raise IOError("csv file is empty")

    # csv.reader raises a TypeError for delimiter=None, so the delimiter is
    # only forwarded when the caller actually supplied one.
    reader_kwargs = {} if delim is None else {'delimiter': delim}
    # newline='' is the mode the csv module documents for files handed to csv.reader
    with open(csv_file_name, newline='') as csv_file:
        reader = csv.reader(csv_file, **reader_kwargs)
        if has_header:
            next(reader, None)  # skip the header line
        first_row = next(reader, None)
    if first_row is None:
        # previously this case leaked a bare StopIteration
        raise IOError("csv file has no data rows")

    scheme = _infer_relation_schema(first_row)
    rel_def = RelationDefinition(name=name,scheme=scheme)

    existing_rel = self.engine.get_relation(name)
    if existing_rel:
        if existing_rel != rel_def:
            raise ValueError(f"Relation {name} already exists with a different schema")
    else:
        self.engine.set_relation(rel_def)
    self.engine.load_csv(name, csv_file_name, delim=delim, has_header=has_header)




In [None]:
#| export
@patch
def import_rel(self:Session,
    name:str, # name of the relation in spannerlog
    data:Union[str,Path,pd.DataFrame], # either a pandas dataframe or a path to a csv file
    delim:str = None, # the delimiter of the csv file
    has_header: bool = None, # whether the first line of the csv file is a header line
    ):
    """Imports a relation into the current session, either from a dataframe or from a csv file.

    A `str` or `Path` is treated as a csv path and delegated to `import_csv`
    together with `delim` and `has_header`. Otherwise `data` is treated as a
    dataframe: the schema is inferred from its first row, the relation is
    created if it does not exist yet (an existing relation must have the same
    definition) and the rows are added as facts.

    Raises:
        ValueError: if the dataframe has no rows, or if `name` already exists
            with a different schema.
    """

    if isinstance(data, (Path,str)):
        self.import_csv(name,data,delim=delim,has_header=has_header)
        return

    # the schema is inferred from the first row, so a dataframe without rows cannot
    # be imported; fail with a clear message instead of an IndexError from iloc
    if len(data.index) == 0:
        raise ValueError(
            f"Cannot import relation {name} from an empty DataFrame, "
            "its schema is inferred from the first row"
        )

    first_row = list(data.iloc[0,:])
    scheme = _infer_relation_schema(first_row)
    rel_def = RelationDefinition(name=name,scheme=scheme)

    existing_rel = self.engine.get_relation(name)
    if existing_rel:
        if existing_rel != rel_def:
            raise ValueError(f"Relation {name} already exists with a different schema")
    else:
        self.engine.set_relation(rel_def)
    self.engine.add_facts(name,data)


In [None]:
#| export
@patch
def import_var(self:Session,
    name, # name of the variable in spannerlog
    value, # the value of the variable
    ):
    """Imports a variable into the current session.

    Thin wrapper that forwards `name` and `value` to `engine.set_var`.
    """
    self.engine.set_var(name,value)

## Exporting data from spannerlog to python

In [None]:
#| export
# parsing statements
# TODO: from here, separate parsing from semantic checks so that we know the number of statements
# before we iterate over the semantic checks, which might depend on the execution of previous statements
@patch
def _parse_code(self:Session,code):
    """Parses a spannerlog code snippet into a list of statements.

    A parsing failure is reported on stdout and the exception is re-raised.
    """
    try:
        return parse_spannerlog(code,split_statements=True)
    except Exception as err:
        print(f"Syntax ERROR:\n{err}\n")
        raise

@patch
def _check_semantics(self:Session,statements):
    """Generator that runs the semantic passes over each statement in turn.

    For every `(ast, lark_tree)` pair in `statements`, all passes of
    `self.pass_stack` are applied to the AST in order, and the pair is then yielded.
    A failing pass is reported on stdout and its exception is re-raised.

    The caller has to execute each yielded statement before asking for the next
    one, since the checks of a statement may depend on the side effects of the
    statements that precede it.
    """
    for ast,lark_tree in statements:
        for semantic_pass in self.pass_stack:
            try:
                semantic_pass(ast,self.engine)
            except Exception as err:
                print(
                    f"SEMANTIC ERROR:\n"
                    f"During semantic checks for statement \n\"{reconstruct(lark_tree)}\"\n"
                    f"in pass {semantic_pass} the following exception was raised:\n{err}\n"
                    )
                raise
        yield ast,lark_tree

In [None]:
#| export

## executing statements
def _statement_type_and_value(ast):
    """Extracts the statement kind and its payload from a single node ast.

    Reads the 'type' and 'val' attributes of the first node of `ast`
    (the ast is assumed to hold exactly one node) and returns them as a pair.
    """
    first_node = list(ast.nodes)[0]
    attrs = ast.nodes[first_node]
    return attrs['type'],attrs['val']


def _execute_statement(
    ast, # networkx ast after semantic checks, should have a single node with a single statement
    engine, # the spannerlog engine to execute the statement on
    plan_only=False, # if True, plans queries returns the graph and root, but does not execute them
    draw_graph=False, # if True, draws the graph of the query plan
    ):
    """Runs the single statement held by `ast` on `engine`.

    Queries are planned (and drawn when `draw_graph` is set); the function then
    returns either the `(graph, root)` plan when `plan_only` is set, or the
    result of executing that plan. Every other known statement type only
    mutates the engine and returns None. An unknown type raises ValueError.
    """
    statement,value = _statement_type_and_value(ast)

    # queries are the only statements that produce a return value
    if statement == 'query':
        graph,root = engine.plan_query(value)
        if draw_graph:
            draw(graph)
        if plan_only:
            return graph,root
        return engine.execute_plan(graph,root)

    if statement == 'assignment':
        engine.set_var(*value)
    elif statement == 'read_assignment':
        engine.set_var(*value,read_from_file=True)
    elif statement == 'add_fact':
        engine.add_fact(value)
    elif statement == 'remove_fact':
        engine.del_fact(value)
    elif statement == 'relation_declaration':
        engine.set_relation(value)
    elif statement == 'rule':
        engine.add_rule(value)
    else:
        raise ValueError(f"Unknown statement type {statement}")
    return None


    

In [None]:
#| export
# formatting query results
def _sort_df(df):
    """Returns `df` sorted by all of its columns.

    Rows are compared by value when possible; if the values cannot be compared
    (a TypeError is raised) the rows are sorted by the string form of their values.
    """
    sort_columns = list(df.columns)
    try:
        return df.sort_values(by=sort_columns)
    except TypeError:
        return df.sort_values(by=sort_columns,key=lambda col: tuple(str(v) for v in col))

def _format_results(res):
    """Normalizes the result of a statement.

    Anything that is not a dataframe is returned as is. A dataframe without
    columns encodes a boolean answer: one row means True, zero rows means False.
    Any other dataframe is sorted and gets a fresh index.
    """
    if not isinstance(res,pd.DataFrame):
        return res
    shape = res.shape
    if shape in ((1,0),(0,0)):
        return shape == (1,0)
    return _sort_df(res).reset_index(drop=True)

def _display_result(result,statement_lark):
    """Formats a statement result and displays it together with the statement that produced it.

    `result` is first normalized with `_format_results`. A dataframe is shown
    with itables, a boolean is shown with IPython's display, and anything else
    (None included) is not displayed at all.
    """
    if result is None:
        return

    # Format first and dispatch on the formatted value: a dataframe without columns
    # is turned into a bool by `_format_results`, and a bool has no `.map`.
    formatted = _format_results(result)
    if isinstance(formatted,pd.DataFrame):
        display.display(reconstruct(statement_lark))
        show(formatted.map(_class_repr)
            .style.set_properties(**{
                'overflow-wrap': 'break-word',
                'max-width': '800px',
                'text-align': 'left'}),
            columnDefs=[{
                "targets": list(formatted.columns),
                "render": """function(data, type, row) {
                    return '<div style="white-space: normal; word-wrap: break-word;">' + data + '</div>';
                }""",
                "width": "300px"
            }],
            eval_functions=True,
            escape=True)
    elif isinstance(formatted,bool):
        display.display(reconstruct(statement_lark))
        display.display(formatted)



In [None]:
#| export
@patch
def export(self:Session,
    code:str , # the spannerlog code to execute
    display_results=False, # if True, displays the results of the query to screen
    draw_query=False, # if True, draws the query graph of queries to screen
    plan_query=False, # if True, if last statement is a query, plans the query and returns the query graph and root node.
    return_statements_meta=False, # if True, returns both the return value and the statements meta data, used internally.
    ):
    """Executes a string of spannerlog code and returns the value of its last statement.

    The code is parsed up front; each statement is then checked semantically,
    executed and formatted before the next one is checked. Statements that are
    not queries evaluate to None, and empty code returns None as well.
    With `return_statements_meta` the return value is a pair of the last value and
    a list of `(statement type, statement dataclass, statement source)` tuples.
    """
    parsed_statements = self._parse_code(code)
    last_index = len(parsed_statements) - 1

    results = []
    statements = []
    for statement_index,(clean_ast,statement_lark) in enumerate(self._check_semantics(parsed_statements)):
        # only the last statement can be planned without being executed
        plan_only = plan_query and statement_index == last_index
        try:
            raw_result = _execute_statement(clean_ast,self.engine,draw_graph=draw_query,plan_only=plan_only)
            result = _format_results(raw_result)
        except Exception as e:
            print(f"RUNTIME ERROR:\n"
                f"During execution of statement \n\"{reconstruct(statement_lark)}\"\n"
                f"the following exception was raised:\n"
                )
            raise e

        s_type,s_dataclass = _statement_type_and_value(clean_ast)
        statements.append((s_type,s_dataclass,reconstruct(statement_lark)))
        results.append(result)
        if display_results:
            _display_result(result,statement_lark)

    ret_val = results[-1] if results else None
    return (ret_val,statements) if return_statements_meta else ret_val

In [None]:
#| export
@patch  
def print_rules(self:Session):
    """Prints all the rules in the engine, one per line, and returns them as a list."""
    # the rules are the keys of the engine's `rules_to_ids`
    rules = list(self.engine.rules_to_ids.keys())
    for rule in rules:
        print(rule)
    return rules

@patch
def get_all_functions(self:Session):
    """Returns all the IEs and AGGs in the engine, as a nested dictionary of the form:
    {
        'ie':{name:IEFunction},
        'agg':{name:AGGFunction}
    }
    The inner values are produced by calling `.copy()` on the engine's
    `ie_functions` and `agg_functions`.
    """
    return {
        'ie':self.engine.ie_functions.copy(),
        'agg':self.engine.agg_functions.copy()
    }


## Removing information

These functions are mostly used when debugging spannerlog code, to remove rules and relations we want to redefine.

In [None]:
#| export
@patch
def remove_rule(self:Session,
    rule:str # the rule string to remove
    ):
    """Removes a rule from the engine; the rule string must be identical to the rule defined previously.

    Thin wrapper around `engine.del_rule`.
    """
    self.engine.del_rule(rule)


In [None]:
#| export

@patch
def remove_head(self:Session,head:str):
    """Removes all rules of a given head relation.

    Thin wrapper around `engine.del_head`.
    """
    self.engine.del_head(head)


In [None]:
#| export

@patch
def remove_all_rules(self:Session):
    """Removes every rule that is currently registered in the engine."""
    # iterate over a snapshot of the keys, rules are removed while looping
    for rule in list(self.engine.rules_to_ids.keys()):
        self.remove_rule(rule)


In [None]:
#| export

@patch
def remove_relation(self:Session,relation:str):
    """Removes a relation from the engine, either an extrinsic or an intrinsic relation.

    Thin wrapper around `engine.del_relation`.
    """
    self.engine.del_relation(relation)



#| hide
## Test scaffold

In [None]:
#| exporti
def test_session(
    code_strings, # a spannerlog code string, or a list of code strings that are executed in order
    expected_outputs=None,# list of expected dfs
    ie_funcs=None,# List of [name,func,in_scheme,out_scheme]
    agg_funcs=None,# List of [name,func,in_scheme,out_scheme]
    csvs=None,# List of [name,df]
    debug=False, # if True, draws the query graph of every query
    display_results=True, # if True, displays the result of every statement
    ):
    """Test scaffold: runs code strings in a fresh session and compares their results.

    A new `Session` is created, `csvs`, `ie_funcs` and `agg_funcs` are imported
    into it, and every code string is then exported in order. When
    `expected_outputs` is given, the result of each code string is compared with
    the matching expected value (dataframes with `assert_df_equals`, anything else
    with `==`). When it is None the code is only executed. Returns the session.
    """

    sess=Session()

    # add data
    if csvs:
        for name,df in csvs:
            sess.import_rel(name,df)
    # add ies
    if ie_funcs:
        for name,func,in_scheme,out_scheme in ie_funcs:
            sess.register(name,func,in_scheme,out_scheme)

    if agg_funcs:
        for name,func,in_scheme,out_scheme in agg_funcs:
            sess.register_agg(name,func,in_scheme,out_scheme)

    # normalize code strings and expected outputs to lists
    if not isinstance(code_strings,list):
        code_strings = [code_strings]
    dont_assert = expected_outputs is None
    if dont_assert:
        # one placeholder per code string, so that zip below still runs all of them
        expected_outputs = [None]*len(code_strings)
    elif not isinstance(expected_outputs,list):
        expected_outputs = [expected_outputs]

    for code,expected in zip(code_strings,expected_outputs):
        try:
            res = sess.export(code,display_results=display_results,draw_query=debug)
        except Exception as e:
            print(f"Error in code {code}")
            raise e

        if dont_assert:
            continue
        if isinstance(expected,pd.DataFrame) and isinstance(res,pd.DataFrame):
            assert_df_equals(res,expected)
        else:
            assert res == expected, f"expected {expected}, got {res}"
    return sess
        

## Examples

In [None]:
# Two dataframes with the same columns, imported one after the other into one relation.
sess = Session()
df1 = pd.DataFrame([['John Doe', 35],['Jane Smith', 28]],columns=['X','Y'])
df2 = pd.DataFrame([['John Doe', 30]],columns=['X','Y'])

# the first import defines the relation AgeOfKids and adds the rows of df1
sess.import_rel("AgeOfKids",df1)
display.display(sess.export("?AgeOfKids(X,Y)"))
# the second import targets the relation that already exists and adds the rows of df2 to it
sess.import_rel("AgeOfKids",df2)
display.display(sess.export("?AgeOfKids(X,Y)"))

#| hide
## Tests

In [None]:
# basic export: a relation declaration is not a query, so `export` returns None
sess = Session()

res = sess.export("""
new A(int)
""")
display.display(res)

In [None]:
# show the engine's `Relation_defs` after the declaration made in the previous cell
sess.engine.Relation_defs

In [None]:
#| hide
# basic export with metadata
sess = Session()

# with return_statements_meta=True, export returns the value of the last statement
# together with one (type, dataclass, source string) tuple per statement
res,meta = sess.export("""
new A(int)
A(1)
?A(X)
""",return_statements_meta=True)
display.display(res)

assert meta == [
    ('relation_declaration',
        RelationDefinition(name='A', scheme=[int]),
        'new A(int)'),
 ('add_fact', Relation(name='A', terms=[1], agg=None), 'A(1)'),
 ('query', Relation(name='A', terms=[FreeVar(name='X')], agg=None), '?A(X)')]

In [None]:
#| hide

# test rule removal
sess = Session()
_ = sess.export('''
    new parent(str, str)
    new grandparent(str, str)
    parent("Liam", "Noah")
    parent("Noah", "Oliver")
    parent("James", "Lucas")
    parent("Noah", "Benjamin")
    parent("Benjamin", "Mason")
    grandparent("Tom", "Avi")
    ancestor(X,Y) <- parent(X,Y).
    ancestor(X,Y) <- grandparent(X,Y).
    ancestor(X,Y) <- parent(X,Z), ancestor(Z,Y).
''')


# the expected rule strings have no space between the body atoms,
# unlike the source text of the third rule above
rules = sess.print_rules()
assert rules == ['ancestor(X,Y) <- parent(X,Y).',
'ancestor(X,Y) <- grandparent(X,Y).',
'ancestor(X,Y) <- parent(X,Z),ancestor(Z,Y).',]

# remove the first rule, the other two must stay in place and keep their order
sess.remove_rule("ancestor(X,Y) <- parent(X,Y).")
print("="*50)
rules = sess.print_rules()
assert rules == ['ancestor(X,Y) <- grandparent(X,Y).',
'ancestor(X,Y) <- parent(X,Z),ancestor(Z,Y).',]

In [None]:
#| hide

# test clearing the engine
commands = """
    new parent(str, str)
    new grandparent(str, str)
    parent("Liam", "Noah")
    grandparent("Tom", "Avi")
    ancestor(X,Y) <- parent(X,Y).
    ancestor(X,Y) <- grandparent(X,Y).
    ancestor(X,Y) <- parent(X,Z), ancestor(Z,Y).
    """
session = Session()
output = session.export(commands)
session.print_rules()
# `clear` replaces the engine with a new one, so no rule may be left afterwards
session.clear()
assert session.print_rules() == []

In [None]:
#| hide

# importing relations from csv
# the path is relative, so this cell needs ./sample_data/enrolled.csv next to the working directory
session = Session()
# a string is treated as a csv path by import_rel and handed to import_csv
session.import_rel(name="enrolled",data="./sample_data/enrolled.csv", delim=",")
commands = """
enrolled("abigail", "chemistry")
?enrolled(X,Y)
"""
res = session.export(commands)
# the expected rows include the fact added by the code above
assert_df_equals(res,pd.DataFrame([
    ["abigail", "chemistry"],
    ["gale", "operating_systems"],
    ["howard", "chemistry"],
    ["howard", "physics"],
    ["jordan", "chemistry"],
    ["abigail", "operating_systems"],
],columns=["X","Y"]))

In [None]:
#| hide

# importing relations from dataframe
session = Session()
# the dataframe has no explicit column names, the result columns are named by the query's free variables
lecturer_df = pd.DataFrame(([["walter","chemistry"], ["linus", "operating_systems"]]))
session.import_rel("lecturer",lecturer_df)
commands = """ 
?lecturer(X,Y)
"""
res = session.export(commands)
assert_df_equals(res,pd.DataFrame([
    ["walter","chemistry"],
    ["linus", "operating_systems"]
],columns=["X","Y"]))


In [None]:
#| hide

# a rule that joins a relation with itself
# the first code string holds no query and is expected to return None
test_session(
    [
        """
        new Parent(str, str)
        Parent("Sam", "Noah")
        Parent("Noah", "Austin")
        Parent("Austin", "Stephen")

        GrandParent(G, C) <- Parent(G, M), Parent(M, C).
        """,
        """?GrandParent(X, "Austin")"""
    ],
    expected_outputs = [
        None,
        pd.DataFrame({'X':['Sam']})
    ],
    # debug=True

)



In [None]:
#| hide

# constants in rule heads
# the constant "yes" in the head of FunGrandParent shows up as the AreFun column of the query
test_session(
    [
        """
        new Parent(str, str)
        Parent("Sam", "Noah")
        Parent("Noah", "Austin")
        Parent("Austin", "Stephen")

        GrandParent(G, C) <- Parent(G, M), Parent(M, C).
        # all grand parents are fun
        FunGrandParent(G,C,"yes")<- GrandParent(G,C).
        """,
        """?FunGrandParent(X, "Austin",AreFun)"""
    ],
    expected_outputs = [
        None,
        pd.DataFrame([
            ["Sam","yes"]
        ],columns=["X","AreFun"])
    ],
    # debug=True

)



In [None]:
#| hide

def length(string: str) -> Iterable[int]:
    """IE function used by the tests: yields one 1-tuple with the number of characters of `string`."""
    n_chars = len(string)
    yield (n_chars,)

# a user defined IE function in a rule body
# the fact string("a") is added twice but is expected only once in the output
_ =test_session(
    ["""new string(str)
    string("a")
    string("d")
    string("a")
    string("ab")
    string("abc")
    string("abcd")

    string_length(Str, Len) <- string(Str), Length(Str) -> (Len).

    """,
    """
    ?string_length(Str, Len)
    """],
    [
        None,
        pd.DataFrame({'Str':['a','d','ab','abc','abcd'],'Len':[1,1,2,3,4]}),
    ],
    # Length is registered with one str input and one int output
    ie_funcs=[
        ['Length',length,[str],[int]]
    ],
    # debug=True
)

In [None]:
#| hide

def ID(string: str):
    """IE function used by the tests: yields one 1-tuple holding `string` followed by '_id'."""
    tagged = f'{string}_id'
    yield (tagged,)

def ID2(string: str):
    """IE function used by the tests: yields one pair made of `string` with two different suffixes."""
    first = f'{string}_id2_z'
    second = f'{string}_id2_w'
    yield (first,second)

In [None]:
#| hide

# empty input to ie functions
# code and expected output are passed as single values here, test_session wraps them in lists
test_session(
        """
        new A(str, str)
        new B(str, str)
        A("1", "2")
        B("1", "3_id")
        C(X, Y) <- A(X, Y).
        D(X, Y, X) <- C(X, Y).
        # nothing will feed into ID but we still need output in the same schema as the first D rule
        D(X, Y, Z) <- A(X, "1"), B(X, Y), ID(X) -> (Y), ID2(Y)->(Z,W).
        ?D(X, Y, Z)
    """,
    # only the row derived by the first D rule is expected
    pd.DataFrame([['1','2','1']],columns=['X','Y','Z']),
    ie_funcs=[
        ['ID',ID,[str],[str]],
        ['ID2',ID2,[str],[str,str]]
    ] 
)

In [None]:
#| hide

# multiple ie functions
# same program as the previous test, but A(X, "2") and B("1", "1_id") let the second D rule derive a row
test_session(
        """
        new A(str, str)
        new B(str, str)
        A("1", "2")
        B("1", "1_id")
        C(X, Y) <- A(X, Y).
        D(X, Y, X) <- C(X, Y).
        # nothing will feed into ID but we still need output in the same schema as the first D rule
        D(X, Y, Z) <- A(X, "2"), B(X, Y), ID(X) -> (Y), ID2(Y)->(Z,W).
        ?D(X, Y, Z)
    """,
    pd.DataFrame([('1', '1_id', '1_id_id2_z'), ('1', '2', '1')],columns=['X','Y','Z']),
    ie_funcs=[
        ['ID',ID,[str],[str]],
        ['ID2',ID2,[str],[str,str]]
    ] 
)


In [None]:
#| hide

# ie functions bounded by constants
def split(string: str):
    """IE function used by the tests: yields a 1-tuple for every whitespace separated part of `string`."""
    yield from ((part,) for part in string.split())

# Split is called with a constant string as input, both on its own (Text)
# and with an output variable that is also bound by another relation (StringFromText)
test_session(
    ["""
    new String(str)
    String("he")
    String("hehe")
    Text(T) <- Split("he ho hehe hoho")->(T).
    StringFromText(S) <- String(S), Split("he ho hehe hoho")->(S).
    ?StringFromText(S)
    """,
    """?Text(T)"""
    ],
    [
        pd.DataFrame({'S':['he','hehe']}),
        pd.DataFrame({'T':['he','ho','hehe','hoho']})
    ],   
    ie_funcs=[
        ['Split',split,[str],[str]]
    
    ]
)

In [None]:
#| hide

# Boolean queries
# a query without free variables is expected to return a plain bool

test_session(
    ["""
    new Parent(str, str)
    Parent("Sam", "Noah")
    Parent("Noah", "Austin")
    Parent("Austin", "Stephen")
    GrandParent(G, C) <- Parent(G, M), Parent(M, C).
    ?GrandParent("Sam", "Austin")
    """,
    """?GrandParent("Bob", "Austin")"""
    ],
    [
        True,
        False
    ]
)


In [None]:
#| hide

# use of ie functions with variable output arity
# The first code string is a raw string: the regex contains \w, \s and \d, which are
# invalid escape sequences in a normal string literal. The string value is unchanged.
sess = test_session(
    [r"""
    input_string = "John Doe: 35 years old, Jane Smith: 28 years old"
    AgeOf(Name,Age) <- 
        rgx("(\w+\s\w+):\s(\d+)",$input_string) -> (NameSpan,AgeSpan),
        as_str(NameSpan)->(Name),
        as_str(AgeSpan)->(Age).
    ""","""
    ?AgeOf(X,Y)
    """],
    [
        None,
        pd.DataFrame({'X':["John Doe","Jane Smith"],'Y':["35","28"]}),
    ],
)

In [None]:
#| hide
def format_ie(f_string,*params):
    """IE function used by the tests: yields one 1-tuple with `f_string` formatted by `params`."""
    rendered = f_string.format(*params)
    yield (rendered,)

def string_schema(x):
    """Returns a schema made of `x` str entries, used as the input schema of `format`."""
    return [str]*x

# an IE function whose input schema is given as a callable (`string_schema`) instead of a list
test_session(
"""
new AgeOf(str, str)
AgeOf("John Doe", "35")
AgeOf("Jane Smith", "28")
age_description(Desc) <- AgeOf(Name, Age), format("{} is {} years old",Name,Age) -> (Desc).
?age_description(D)
""",
pd.DataFrame({'D':['Jane Smith is 28 years old','John Doe is 35 years old']}),
ie_funcs=[['format',format_ie,string_schema,[str]]]
)

In [None]:
#| hide

# aggregation in a rule head: sum(Y) is expected to add up the ages per name
test_session(
"""
new AgeOfKids(str, int)
AgeOfKids("John Doe", 35)
AgeOfKids("John Doe", 30)
AgeOfKids("Jane Smith", 28)
total_age(X,sum(Y)) <- AgeOfKids(X,Y).
?total_age(X,T)
""",
pd.DataFrame({'X':['John Doe','Jane Smith'],'T':[65,28]}),
)

In [None]:
#|hide
# run the nbdev export for this project
import nbdev; nbdev.nbdev_export()
     