# Session

In [None]:
#| default_exp session

In [None]:
#| hide
from nbdev.showdoc import show_doc

%load_ext autoreload
%autoreload 2

In [None]:
#| export
import csv

import os
import re
from pathlib import Path
from typing import Tuple, List, Union, Optional, Callable, Type, Iterable, no_type_check, Sequence
from IPython import display

import pandas as pd
import os

import logging
logger = logging.getLogger(__name__)

from graph_rewrite import draw

from spannerlib.utils import checkLogs,get_base_file_path,assert_df_equals
from spannerlib.grammar import parse_spannerlog,reconstruct
from spannerlib.span import Span
from spannerlib.data_types import (
    _infer_relation_schema,
     Var,
    FreeVar,
    RelationDefinition,
    Relation,
    IEFunction,
    IERelation,
    Rule,
    pretty,
)
from spannerlib.engine import Engine

from spannerlib.micro_passes import (
    convert_primitive_values_to_objects,
    remove_new_lines_from_strings,
    CheckReservedRelationNames,
    check_referenced_paths_exist,
    dereference_vars,
    relations_to_dataclasses,
    verify_referenced_relations,
    rules_to_dataclasses,
    consistent_free_var_types_in_rule,
    check_rule_safety,
    assignments_to_name_val_tuple,
    execute_statement,
)


In [None]:
#| export
#| hide


# from spannerlib.ie_func.json_path import JsonPath, JsonPathFull
# from spannerlib.ie_func.nlp import (Tokenize, SSplit, POS, Lemma, NER, EntityMentions, CleanXML, Parse, DepParse, Coref, OpenIE, KBP, Quote, Sentiment, TrueCase)
from spannerlib.ie_func.python_regex import PYRGX, PYRGX_STRING
# from spannerlib.ie_func.rust_spanner_regex import RGX, RGX_STRING, RGX_FROM_FILE, RGX_STRING_FROM_FILE


# # ordered by rgx, json, nlp, etc.
# PREDEFINED_IE_FUNCS = [PYRGX, PYRGX_STRING, RGX, RGX_STRING, RGX_FROM_FILE, RGX_STRING_FROM_FILE,
#                        JsonPath, JsonPathFull,
#                        Tokenize, SSplit, POS, Lemma, NER, EntityMentions, CleanXML, Parse, DepParse, Coref, OpenIE, KBP, Quote, Sentiment,
#                        TrueCase]





## Session class

In [None]:
#| export
class Session():
    """Front end for running spannerlog code: parses it, runs the semantic
    passes on every statement and executes the statements on an `Engine`."""

    def __init__(self):
        # Semantic passes, applied in this order to every parsed statement
        # (see `parse_and_check_semantics`).
        self.pass_stack = [
            convert_primitive_values_to_objects,
            remove_new_lines_from_strings,
            CheckReservedRelationNames('spanner_'),
            check_referenced_paths_exist,
            dereference_vars,
            relations_to_dataclasses,
            verify_referenced_relations,
            rules_to_dataclasses,
            consistent_free_var_types_in_rule,
            check_rule_safety,
            assignments_to_name_val_tuple
        ]
        self.engine=Engine()

        # TODO define all the default ie funcs

    def get_pass_stack(self):
        """
        @return: a shallow copy of the current pass stack.
        """
        # Reads the same attribute `__init__` populates and
        # `parse_and_check_semantics` iterates over.
        return self.pass_stack.copy()

    def set_pass_stack(self,
        user_stack #  a user supplied pass stack
        ): 
        """
        Sets a new pass stack instead of the current one.

        Raises `TypeError` if `user_stack` is not a list. The list is copied,
        so later changes to `user_stack` do not affect the session.
        """
        if not isinstance(user_stack, list):
            raise TypeError('user stack should be a list of passes')
        self.pass_stack = user_stack.copy()

    def _display_result(self,result,statement_lark):
        """Display a DataFrame result, preceded by the reconstructed statement text.

        `None` and any non-DataFrame result are not displayed.
        """
        if isinstance(result,pd.DataFrame):
            display.display(reconstruct(statement_lark))
            display.display(result)

    def register(self,name,func,in_schema,out_schema):
        """Wrap `func` in an `IEFunction` and set it on the engine under `name`."""
        #TODO add option for outschema to be a callable that gets the input and the output and confirms if the output is valid
        ie_func_obj = IEFunction(name=name,func=func,in_schema=in_schema,out_schema=out_schema)
        self.engine.set_ie_function(ie_func_obj)

    def parse_and_check_semantics(self,code):
        """Parse `code` and lazily yield `(ast, statement_lark)` per statement.

        Every pass in the pass stack is called as `pass_(ast, engine)` before the
        statement is yielded. This is a generator, so the passes for a statement
        only run when the consumer asks for that statement.

        Raises `Exception` naming the statement and the failing pass when a
        pass raises; the original exception is attached as the cause.
        """
        statements = parse_spannerlog(code,split_statements=True)
        for statement_nx,statement_lark in statements:
            ast = statement_nx
            for pass_ in self.pass_stack:
                try:
                    pass_(ast,self.engine)
                except Exception as e:
                    raise Exception(
                        f"During semantic checks for statement \n\"{reconstruct(statement_lark)}\"\n"
                        f"in pass {pass_} the following exception was raised:\n{e}\n"
                        ).with_traceback(e.__traceback__) from e
            yield ast,statement_lark

    def handle_boolean_results(self,res):
        """Map the DataFrame shapes (1, 0) and (0, 0) to `True` and `False`.

        Anything else, including non-DataFrame values, is returned unchanged.
        NOTE(review): presumably these shapes are how the engine reports the
        outcome of a query without free variables - confirm against `Engine`.
        """
        if not isinstance(res,pd.DataFrame):
            return res
        if res.shape == (1,0):
            return True
        if res.shape == (0,0):
            return False
        return res

    def plan_query(self,code):
        """Build the engine plan for a single query statement.

        Returns `(query_graph, root)` as produced by `engine.plan_query`.
        Raises `ValueError` if `code` does not contain exactly one statement or
        if that statement is not a query.
        """
        statements = list(self.parse_and_check_semantics(code))
        # Also rejects empty input, which would otherwise fail with a bare
        # IndexError on `statements[0]`.
        if len(statements) != 1:
            raise ValueError(f"Only one statement is allowed in plan_query, got {len(statements)}")
        ast,_ = statements[0]
        statement_node = list(ast.nodes)[0]
        node_data = ast.nodes[statement_node]
        statement = node_data['type']
        value = node_data['val']
        if statement != 'query':
            raise ValueError(f"Expected a query statement, got {statement}")

        query_graph,root =  self.engine.plan_query(value)
        return query_graph,root

    def execute_plan(self,query_graph,root,return_intermediate=False):
        """Execute a plan built by `plan_query`.

        Returns the result after `handle_boolean_results`; when
        `return_intermediate` is true, returns `(result, intermediate)` instead.
        """
        # The engine is always asked for the intermediate results; they are
        # only handed back to the caller on request.
        res,inter = self.engine.execute_plan(query_graph,root,return_intermediate=True)
        res = self.handle_boolean_results(res)
        if return_intermediate:
            return res,inter
        return res

    def export(self,code,display_results=False):
        """Execute every statement in `code` and return the last statement's result.

        Returns `None` when `code` contains no statements. With
        `display_results` set, each result is passed to `_display_result`.
        Raises `Exception` naming the statement when its execution raises; the
        original exception is attached as the cause.
        """
        #TODO reconstruct the code for each statement using lark,reconstruct so we can print the query string together with the result
        results = []
        for clean_ast,statement_lark in self.parse_and_check_semantics(code):
            try:
                result = execute_statement(clean_ast,self.engine)
                result = self.handle_boolean_results(result)
            except Exception as e:
                # Closing quote comes before the newline, matching the message
                # format used in `parse_and_check_semantics`.
                raise Exception(
                    f"During execution of statement \n\"{reconstruct(statement_lark)}\"\n"
                    f"the following exception was raised:\n{e}\n"
                    ).with_traceback(e.__traceback__) from e
            results.append(result)
            if display_results:
                self._display_result(result,statement_lark)

        return results[-1] if results else None

    def import_rel(self,name:str,data:Union[str,Path,pd.DataFrame],delim:Optional[str] = None):
        """Imports a relation into the current session, either from a dataframe or from a csv file.

        The relation schema is inferred from the first data row. `delim` is
        forwarded to `pd.read_csv` and is only used when `data` is a path.

        Raises `IOError` if the csv file is missing or has size zero, and
        `ValueError` if the data contains no rows.
        """
        if isinstance(data, (Path,str)):
            csv_file_name = Path(data)
            if not csv_file_name.is_file():
                raise IOError("csv file does not exist")
            if csv_file_name.stat().st_size == 0:
                raise IOError("csv file is empty")
            data = pd.read_csv(csv_file_name, delimiter=delim)

        # Without a first row there is nothing to infer the schema from.
        if len(data) == 0:
            raise ValueError(f"cannot import relation '{name}': data has no rows to infer a schema from")

        first_row = list(data.iloc[0,:])
        scheme = _infer_relation_schema(first_row)
        rel_def = RelationDefinition(name=name,scheme=scheme)
        self.engine.set_relation(rel_def)
        self.engine.add_facts(name,data)

    def print_rules(self):
        """Print every key of `engine.rules_to_ids` and return them as a list."""
        rules = list(self.engine.rules_to_ids.keys())
        for rule in rules:
            print(rule)
        return rules

    def remove_rule(self,rule:str):
        """Delete `rule` from the engine."""
        self.engine.del_rule(rule)

    def remove_relation(self,relation:str):
        """Delete `relation` from the engine."""
        self.engine.del_relation(relation)

    def clear(self):
        """Replace the engine with a fresh `Engine()`.

        The pass stack is left as is. IE functions added through `register`
        were set on the previous engine object.
        """
        self.engine = Engine()

    def get_all_functions(self):
        """Return a copy of the engine's `ie_functions`."""
        return self.engine.ie_functions.copy()

### Test scaffold

In [None]:
#| export
def test_session(
    queries,
    expected_outputs,# list of expected dfs
    ie_funcs=None,# List of [name,func,in_scheme,out_scheme]
    csvs=None,# List of [name,df]
    debug=False,
    display_results=True
    ):
    """Run `queries` in a fresh `Session` and assert each result against its expected output.

    A non-list `queries` / `expected_outputs` is treated as a single item.
    Queries and expectations are paired with `zip`, so surplus items on either
    side are ignored. DataFrame expectations are compared with
    `assert_df_equals`, everything else with `==`.

    With `debug` set, every query equal to the last one is printed, planned
    with `plan_query`, drawn, and executed through `execute_plan` instead of
    `export`.
    """
    session = Session()

    # relations are imported before the IE functions are registered
    for rel_name, rel_df in csvs or ():
        session.import_rel(rel_name, rel_df)
    for ie_name, ie_func, ie_in, ie_out in ie_funcs or ():
        session.register(ie_name, ie_func, ie_in, ie_out)

    query_list = queries if isinstance(queries, list) else [queries]
    expected_list = expected_outputs if isinstance(expected_outputs, list) else [expected_outputs]

    for query, expected in zip(query_list, expected_list):
        if not (debug and query == query_list[-1]):
            actual = session.export(query, display_results=display_results)
        else:
            print(query)
            plan, root = session.plan_query(query)
            draw(plan)
            actual = session.execute_plan(plan, root)

        if not isinstance(expected, pd.DataFrame):
            assert actual == expected, f"expected {expected}, got {actual}"
        else:
            assert_df_equals(actual, expected)
        

## Tests

In [None]:
# TODO: add tests for relation removal etc. (rule removal is covered by the next cell)

In [None]:
# Rule removal: after deleting one rule, the other two must still be listed.
sess = Session()
_ = sess.export('''
    new parent(str, str)
    new grandparent(str, str)
    parent("Liam", "Noah")
    parent("Noah", "Oliver")
    parent("James", "Lucas")
    parent("Noah", "Benjamin")
    parent("Benjamin", "Mason")
    grandparent("Tom", "Avi")
    ancestor(X,Y) <- parent(X,Y)
    ancestor(X,Y) <- grandparent(X,Y)
    ancestor(X,Y) <- parent(X,Z), ancestor(Z,Y)
''')

# all three rules are listed, in definition order
rules = sess.print_rules()
assert rules == [
    'ancestor(X,Y) <- parent(X,Y)',
    'ancestor(X,Y) <- grandparent(X,Y)',
    'ancestor(X,Y) <- parent(X,Z),ancestor(Z,Y)',
]

sess.remove_rule("ancestor(X,Y) <- parent(X,Y)")
print("=" * 50)

# only the removed rule is gone
rules = sess.print_rules()
assert rules == [
    'ancestor(X,Y) <- grandparent(X,Y)',
    'ancestor(X,Y) <- parent(X,Z),ancestor(Z,Y)',
]

ancestor(X,Y) <- parent(X,Y)
ancestor(X,Y) <- grandparent(X,Y)
ancestor(X,Y) <- parent(X,Z),ancestor(Z,Y)
ancestor(X,Y) <- grandparent(X,Y)
ancestor(X,Y) <- parent(X,Z),ancestor(Z,Y)


In [None]:
# Clearing the session: no rules may be listed after `clear()`.
session = Session()
commands = """
    new parent(str, str)
    new grandparent(str, str)
    parent("Liam", "Noah")
    grandparent("Tom", "Avi")
    ancestor(X,Y) <- parent(X,Y)
    ancestor(X,Y) <- grandparent(X,Y)
    ancestor(X,Y) <- parent(X,Z), ancestor(Z,Y)
    """
output = session.export(commands)

# print the rules once before clearing, for visual comparison in the notebook
session.print_rules()

session.clear()
assert session.print_rules() == []

ancestor(X,Y) <- parent(X,Y)
ancestor(X,Y) <- grandparent(X,Y)
ancestor(X,Y) <- parent(X,Z),ancestor(Z,Y)


In [None]:
# Import a relation from a csv file, add one more fact, then query everything.
session = Session()
session.import_rel("enrolled", "./sample_data/enrolled.csv", ",")
commands = """
enrolled("abigail", "chemistry")
?enrolled(X,Y)
"""
res = session.export(commands)
assert_df_equals(
    res,
    pd.DataFrame(
        [
            ["abigail", "chemistry"],
            ["gale", "operating_systems"],
            ["howard", "chemistry"],
            ["howard", "physics"],
            ["jordan", "chemistry"],
        ],
        columns=["X", "Y"],
    ),
)

Unnamed: 0,X,Y
0,jordan,chemistry
1,gale,operating_systems
2,howard,chemistry
3,howard,physics
0,abigail,chemistry


In [None]:
# Import a relation directly from an in-memory dataframe and query it back.
session = Session()
lecturer_df = pd.DataFrame([
    ["walter", "chemistry"],
    ["linus", "operating_systems"],
])
session.import_rel(name="lecturer", data=lecturer_df)
commands = """ 
?lecturer(X,Y)
"""
res = session.export(commands)
assert_df_equals(
    res,
    pd.DataFrame(
        [
            ["walter", "chemistry"],
            ["linus", "operating_systems"],
        ],
        columns=["X", "Y"],
    ),
)


Unnamed: 0,X,Y
0,walter,chemistry
1,linus,operating_systems


In [None]:
# The rule defined in the first program must also answer the follow-up query.
test_session(
    [
        """
        new Parent(str, str)
        Parent("Sam", "Noah")
        Parent("Noah", "Austin")
        Parent("Austin", "Stephen")

        GrandParent(G, C) <- Parent(G, M), Parent(M, C)
        ?GrandParent(X, "Austin")
        """,
        """?GrandParent(X, "Austin")""",
    ],
    [
        pd.DataFrame({'X': ['Sam']}),
        pd.DataFrame({'X': ['Sam']}),
    ],
    debug=False,
)



'?GrandParent(X,"Austin")'

Unnamed: 0,X
0,Sam


'?GrandParent(X,"Austin")'

Unnamed: 0,X
0,Sam


In [None]:
def length(string: str) -> Iterable[Tuple[int]]:
    """IE function: yield a single 1-tuple holding the length of `string`."""
    yield (len(string),)

# First program only defines facts and a rule (expected result None);
# the second one queries the rule and is run through the debug path.
_ = test_session(
    queries=["""new string(str)
    string("a")
    string("d")
    string("a")
    string("ab")
    string("abc")
    string("abcd")

    string_length(Str, Len) <- string(Str), Length(Str) -> (Len)

    """,
    """
    ?string_length(Str, Len)
    """],
    expected_outputs=[
        None,
        pd.DataFrame({
            'Str': ['a', 'd', 'ab', 'abc', 'abcd'],
            'Len': [1, 1, 2, 3, 4],
        }),
    ],
    ie_funcs=[['Length', length, [str], [int]]],
    debug=True,
)


    ?string_length(Str, Len)
    


In [None]:
def ID(string: str):
    """IE function: yield one 1-tuple with `string` followed by the suffix `_id`."""
    tagged = f'{string}_id'
    yield (tagged,)

def ID2(string: str):
    """IE function: yield one 2-tuple of `string` suffixed with `_id2_z` and `_id2_w`."""
    z_part = f'{string}_id2_z'
    w_part = f'{string}_id2_w'
    yield (z_part, w_part)

In [None]:
# IE functions that receive no input rows: only the first D rule contributes.
test_session(
    queries="""
        new A(str, str)
        new B(str, str)
        A("1", "2")
        B("1", "3_id")
        C(X, Y) <- A(X, Y)
        D(X, Y, X) <- C(X, Y)
        # nothing will feed into ID but we still need output in the same schema as the first D rule
        D(X, Y, Z) <- A(X, "1"), B(X, Y), ID(X) -> (Y), ID2(Y)->(Z,W)
        ?D(X, Y, Z)
    """,
    expected_outputs=pd.DataFrame([['1', '2', '1']], columns=['X', 'Y', 'Z']),
    ie_funcs=[
        ['ID', ID, [str], [str]],
        ['ID2', ID2, [str], [str, str]],
    ],
)

'?D(X,Y,Z)'

Unnamed: 0,X,Y,Z
0,1,2,1


In [None]:
# Two chained IE functions in one rule body; both D rules contribute a row.
test_session(
    queries="""
        new A(str, str)
        new B(str, str)
        A("1", "2")
        B("1", "1_id")
        C(X, Y) <- A(X, Y)
        D(X, Y, X) <- C(X, Y)
        # nothing will feed into ID but we still need output in the same schema as the first D rule
        D(X, Y, Z) <- A(X, "2"), B(X, Y), ID(X) -> (Y), ID2(Y)->(Z,W)
        ?D(X, Y, Z)
    """,
    expected_outputs=pd.DataFrame(
        [('1', '1_id', '1_id_id2_z'), ('1', '2', '1')],
        columns=['X', 'Y', 'Z'],
    ),
    ie_funcs=[
        ['ID', ID, [str], [str]],
        ['ID2', ID2, [str], [str, str]],
    ],
)


'?D(X,Y,Z)'

Unnamed: 0,X,Y,Z
0,1,2,1
1,1,1_id,1_id_id2_z


In [None]:
# ie functions bounded by constants
def split(string: str):
    """IE function: yield every whitespace-separated token of `string` as a 1-tuple."""
    yield from ((token,) for token in string.split())

# `Split` is called with a constant input, both with a free and a bound output.
test_session(
    queries=["""
    new String(str)
    String("he")
    String("hehe")
    Text(T) <- Split("he ho hehe hoho")->(T)
    StringFromText(S) <- String(S), Split("he ho hehe hoho")->(S)
    ?StringFromText(S)
    """,
    """?Text(T)""",
    ],
    expected_outputs=[
        pd.DataFrame({'S': ['he', 'hehe']}),
        pd.DataFrame({'T': ['he', 'ho', 'hehe', 'hoho']}),
    ],
    ie_funcs=[['Split', split, [str], [str]]],
)

'?StringFromText(S)'

Unnamed: 0,S
0,he
1,hehe


'?Text(T)'

Unnamed: 0,T
0,he
1,ho
2,hehe
3,hoho


In [None]:
# Queries without free variables evaluate to plain booleans.
test_session(
    queries=["""
    new Parent(str, str)
    Parent("Sam", "Noah")
    Parent("Noah", "Austin")
    Parent("Austin", "Stephen")
    GrandParent(G, C) <- Parent(G, M), Parent(M, C)
    ?GrandParent("Sam", "Austin")
    """,
    """?GrandParent("Bob", "Austin")""",
    ],
    expected_outputs=[True, False],
)


In [None]:
# use of ie functions with variable output arity
# The program is a raw string so the regex backslashes (\w, \s, \d) are not
# parsed as (invalid) Python escape sequences; the string value is unchanged.
test_session(
    r"""
    input_string = "John Doe: 35 years old, Jane Smith: 28 years old"
    AgeOf(Name,Age) <- py_rgx_string(input_string, "(\w+\s\w+):\s(\d+)") -> (Name,Age)
    ?AgeOf(X,Y)
    """,
    pd.DataFrame({'X':['John Doe','Jane Smith'],'Y':['35','28']}),
    ie_funcs=[PYRGX_STRING]
)

'?AgeOf(X,Y)'

Unnamed: 0,X,Y
0,John Doe,35
1,Jane Smith,28


In [None]:
def format_ie(f_string,*params):
    """IE function: yield one 1-tuple holding `f_string.format(*params)`."""
    rendered = f_string.format(*params)
    yield (rendered,)

# Input schema given as a callable: takes a number x, returns a list of x `str` entries.
string_schema = lambda x: [str] * x

test_session(
    queries="""
new AgeOf(str, str)
AgeOf("John Doe", "35")
AgeOf("Jane Smith", "28")
age_description(Desc) <- AgeOf(Name, Age), format("{} is {} years old",Name,Age) -> (Desc)
?age_description(D)
""",
    expected_outputs=pd.DataFrame({
        'D': ['Jane Smith is 28 years old', 'John Doe is 35 years old'],
    }),
    ie_funcs=[['format', format_ie, string_schema, [str]]],
)

'?age_description(D)'

Unnamed: 0,D
0,John Doe is 35 years old
1,Jane Smith is 28 years old


In [None]:
#|hide
# Run nbdev's export step for this project.
import nbdev

nbdev.nbdev_export()
     