# Utils

In [None]:
#| default_exp utils

In [None]:
#| export
import shlex
import logging
import psutil
import requests
import os
from pathlib import Path
import git
from configparser import ConfigParser
from subprocess import Popen, PIPE
from sys import platform
from threading import Timer
from typing import no_type_check, get_type_hints, Iterable, Any, Optional, Callable
from fastcore.basics import patch
import itertools
from singleton_decorator import singleton

In [None]:
#| export 
@singleton
class UniqueId:
    """
    Counter that hands out consecutive integer ids.

    The class is wrapped in `singleton`, so every `UniqueId()` call is expected
    to return the same shared instance (and therefore the same counter).
    """

    def __init__(self):
        # begin with a fresh sequence
        self.reset()

    def reset(self):
        """Restarts the id sequence from zero."""
        self.counter = itertools.count()

    def __call__(self):
        """Returns the next id of the sequence."""
        return next(self.counter)
    

def uniq_id():
    """Returns the next id produced by the `UniqueId` counter."""
    id_source = UniqueId()
    return id_source()

In [None]:
# #TODO see how to use this when debugging
# logging.config.dictConfig(
#     {
#         "version": 1,
#         "disable_existing_loggers": True,
#         "formatters": {
#             "standard": {"format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s"},
#             "terse": {"format": "%(message)s"},
#         },
#         "handlers": {
#             "default": {
#                 "level": "INFO",
#                 "formatter": "standard",
#                 "class": "logging.StreamHandler",
#                 "stream": "ext://sys.stderr",  # Default is stderr
#             },
#             "terse": {
#                 "level": "WARNING",
#                 "formatter": "terse",
#                 "class": "logging.StreamHandler",
#                 "stream": "ext://sys.stderr",  # Default is stderr
#             },
#         },
#         "loggers": {
#             "root": {
#                 "handlers": ["default"],
#             },
#             "spannerlib": {
#                 "level": "INFO",
#             },
#             "spannerlib.session": {
#                 "handlers": ["terse"],
#                 "incremental": False,
#                 "level": "DEBUG",
#             },
#         },
#     }
# )

In [None]:
#| export
# module-level logger, named after this module
logger = logging.getLogger(__name__)

# `sys.platform` value that identifies Windows; any other platform is treated as POSIX
WINDOWS_OS = "win32"
IS_POSIX = (platform != WINDOWS_OS)

# google drive
# download endpoint, and the chunk size passed to `iter_content` when a response is written to disk
GOOGLE_DRIVE_URL = "https://docs.google.com/uc?export=download"
GOOGLE_DRIVE_CHUNK_SIZE = 32768

In [None]:
#| export
def get_git_root(path='.'):
    """
    Returns the top-level directory of the git repository that contains `path`.
    The repository is looked up in `path` and in its parent directories.
    """
    repo = git.Repo(path, search_parent_directories=True)
    top_level = repo.git.rev_parse("--show-toplevel")
    return Path(top_level)

In [None]:
#| export

def get_base_file_path() -> Path: # The absolute path of parent folder of nbs
    """Returns the root of the git repository found from the current directory."""
    repo_root = get_git_root()
    return repo_root


def get_lib_name() -> str:
    """Reads `lib_name` from the `DEFAULT` section of the repository's `settings.ini`."""
    settings_path = get_base_file_path()/'settings.ini'
    parser = ConfigParser()
    parser.read(settings_path)
    return parser['DEFAULT']['lib_name']

In [None]:
#| export
from contextlib import contextmanager
import logging

@contextmanager
def checkLogs(level: int=logging.DEBUG, name: Optional[str]=None, toFile=None):
    """context manager for temporarily changing logging levels. used for debugging purposes

    On exit the logger's own level is restored and only the handlers that were added
    by this context manager are removed (and closed). Handlers that were already
    attached to the logger are left untouched.

    Args:
        level (logging.Level: optional): logging level to change the logger to. Defaults to logging.DEBUG.
        name (str: optional): module name to raise logging level for. Defaults to root logger
        toFile (Path: optional): File to output logs to. Defaults to None

    Yields:
        [logging.Logger]: the logger object that we raised the level of
    """
    target_logger = logging.getLogger(name)
    # remember the logger's own level rather than the effective one, so that a logger
    # which inherits its level from a parent keeps inheriting once we restore it
    previous_level = target_logger.level
    log_format = "%(name)s - %(levelname)s - %(message)s"
    # handlers created here; these are the only ones removed on exit
    added_handlers = []

    try:
        target_logger.setLevel(level)
        if not target_logger.handlers:
            # no handler attached: add a temporary one so the records show up somewhere
            stream_handler = logging.StreamHandler()
            stream_handler.setFormatter(logging.Formatter(log_format))
            target_logger.addHandler(stream_handler)
            added_handlers.append(stream_handler)
        if toFile is not None:
            file_handler = logging.FileHandler(toFile)
            file_handler.setFormatter(logging.Formatter(log_format))
            target_logger.addHandler(file_handler)
            added_handlers.append(file_handler)
        yield target_logger
    finally:
        target_logger.setLevel(previous_level)
        for handler in added_handlers:
            target_logger.removeHandler(handler)
            # releases the log file held by a `FileHandler`
            handler.close()

In [None]:
#| export
def patch_method(func : Callable, *args, **kwargs) -> None:
    """
    Applies fastcore's `patch` decorator and removes `func` from `cls.__abstractmethods__` in case <br>
    `func` is an `abstractmethod`
    """
    # the patched class is taken from the first type hint of `func`
    hints = get_type_hints(func)
    cls = next(iter(hints.values()))
    try:
        still_abstract = {method_name for method_name in cls.__abstractmethods__
                          if method_name != func.__name__}
        cls.__abstractmethods__ = still_abstract
    except AttributeError: # If the class does not inherit from an abstract class
        pass
    finally:
        # Apply the original `patch` decorator
        decorator = patch(*args, **kwargs)
        decorator(func)

In [None]:
#| export
def kill_process_and_children(process: Popen) -> None:
    """
    Kills `process` together with all of its descendant processes.
    Does nothing when the process has already exited.
    """
    # `poll` returns `None` while the process is still running
    if process.poll() is not None:
        return

    logger.info("~~~~ process timed out ~~~~")
    try:
        children = psutil.Process(process.pid).children(recursive=True)
    except psutil.NoSuchProcess:
        # the process exited between the `poll` check and the lookup
        children = []

    for child in children:  # first, kill the children :)
        try:
            child.kill()  # not recommended in real life
        except psutil.NoSuchProcess:
            # the child is already gone, nothing to do
            pass
    process.kill()  # lastly, kill the process

In [None]:
#| export
def run_cli_command(command: str, # a single command string
                    stderr: bool = False, # if true, capture stderr and log it once the process ends. default: `False`
                    # if true, spawn shell process (e.g. /bin/sh), which allows using system variables (e.g. $HOME),
                    # but is considered a security risk (see: https://docs.python.org/3/library/subprocess.html#security-considerations)
                    shell: bool = False, 
                    timeout: float = -1 # if positive, kill the process after `timeout` seconds. default: `-1`
                    ) -> Iterable[str]: # string iterator
    """
    This utility can be used to run any cli command, and iterate over the output.

    The process is started on the first iteration and runs to completion (or until the
    timeout kills it) before the first line is yielded. Lines are stripped and empty
    lines are skipped.
    """
    # `shlex.split` just splits the command into a list properly
    command_list = shlex.split(command, posix=IS_POSIX)
    stderr_channel = PIPE if stderr else None

    # a shell must receive the command as a single string: given a list, a POSIX shell
    # runs only the first item and treats the rest as arguments of the shell itself
    popen_args = command if shell else command_list
    process = Popen(popen_args, stdout=PIPE, stderr=stderr_channel, shell=shell)

    # set timer to kill the process
    process_timer = None
    if timeout > 0:
        process_timer = Timer(timeout, kill_process_and_children, [process])
        process_timer.start()

    # get output
    try:
        raw_stdout, raw_stderr = process.communicate()
    finally:
        # the process is done, so the pending timer must not fire anymore
        if process_timer is not None:
            process_timer.cancel()

    # `communicate` returns `None` for a stream that was not piped
    process_stdout = raw_stdout.decode("utf-8") if raw_stdout is not None else ""
    process_stderr = raw_stderr.decode("utf-8") if raw_stderr is not None else ""

    for output in process_stdout.splitlines():
        output = output.strip()
        if output:
            yield output

    if stderr:
        logger.info(f"stderr from process {command_list[0]}: {process_stderr}")

In [None]:
#| export
import os
def download_file_from_google_drive(file_id: str, # the id of the file to download
                                     destination: Path # the path to which the file will be downloaded
                                     ) -> None:
    """
    [Downloads a file from Google Drive](https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive/39225039#39225039)

    The file is requested once; if the response carries a `download_warning*` cookie, its
    value is sent back as the `confirm` parameter of a second request. The content of the
    final response is written to `destination`.
    """

    def get_confirm_token(response) -> Optional[Any]:
        # the token is the value of the first cookie whose name starts with `download_warning`
        for key, value in response.cookies.items():
            if key.startswith('download_warning'):
                return value

        return None

    def save_response_content(response) -> None:
        with open(destination, "wb") as f:
            for chunk in response.iter_content(GOOGLE_DRIVE_CHUNK_SIZE):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)

    # the session is closed as soon as the download is finished
    with requests.Session() as requests_session:
        response = requests_session.get(GOOGLE_DRIVE_URL, params={'id': file_id}, stream=True)

        token = get_confirm_token(response)
        logger.debug(f"got token from google: {token}")

        if token:
            params = {'id': file_id, 'confirm': token}
            response = requests_session.get(GOOGLE_DRIVE_URL, params=params, stream=True)

        save_response_content(response)

In [None]:
#| export
def df_to_list(df):
    """Converts a dataframe into a list with one `{column: value}` dict per row."""
    records = df.to_dict("records")
    return records

In [None]:
#| export
def serialize_tree(g):
    """
    Serializes the tree `g` with `nx.tree_data`, using the first node of a
    topological ordering of `g` as the root.
    """
    # NOTE(review): `nx` is not imported in the visible part of this file - presumably `networkx`; confirm
    ordering = nx.topological_sort(g)
    root = next(ordering)
    return nx.tree_data(g, root)


## Old general utils

In [None]:
#| hide
from nbdev.showdoc import show_doc
from __future__ import annotations
%load_ext autoreload
%autoreload 2

In [None]:
#| export
#| output: false
import functools
import re
from typing import (Union, Tuple, Set, Dict, List, Optional, Callable, Any, no_type_check, Sequence)

from spannerlib.ast_node_types import (Relation, IERelation, Rule)
from spannerlib.primitive_types import DataTypes, Span
from spannerlib.symbol_table import SymbolTableBase, SymbolTable

In [None]:
#| export
# names of the capture groups of `SPAN_PATTERN`
SPAN_GROUP1 = "start"
SPAN_GROUP2 = "end"

# as of now, we don't support negative/float numbers (for both spans and integers)
# matches a span written as `[start, end)`, with an optional space after the comma
SPAN_PATTERN = re.compile(r"^\[(?P<start>\d+), ?(?P<end>\d+)\)$")
# NOTE(review): not referenced in the visible part of this file; presumably the prefix of printed query results - confirm
QUERY_RESULT_PREFIX = "printing results for query "

In [None]:
#| export
def strip_lines(text: str) -> str:
    """
    Strips the surrounding whitespace of every line of `text` and drops the lines
    that end up empty. The remaining lines are joined with a newline.
    """
    stripped_lines = (line.strip() for line in text.splitlines())
    return "\n".join(filter(None, stripped_lines))

In [None]:
#| export
def fixed_point(start: Any, # a starting value
                 step: Callable, # a step function
                   distance: Callable, # a function that measures distance between the input and the output of the step function
                     thresh: int = 0 # a distance threshold
                     ) -> Any:
    """
    Implementation of a generic fixed point algorithm - an algorithm that takes a step function and runs it until
    some distance is zero or below a threshold.

    Returns the first value `x` for which `distance(x, step(x))` does not exceed the threshold.
    A negative `thresh` is treated as zero.
    """
    # clamp negative thresholds to zero; a positive threshold must be honoured as given
    limit = max(thresh, 0)
    x = start
    y = step(x)
    while distance(x, y) > limit:
        x = y
        y = step(x)
    return x

In [None]:
#| export
def get_free_var_names(term_list: Sequence, # a list of terms
                       type_list: Sequence # a list of the term types
                       ) -> Set[str]: # a set of all the free variable names in term_list
    """ 
    Collects the terms whose matching type is `DataTypes.free_var_name`.

    @raise Exception: if length of term list doesn't match the length of type list.
    """
    lengths_match = len(term_list) == len(type_list)
    if not lengths_match:
        raise Exception(f"received different lengths of term_list ({len(term_list)}) "
                        f"and type_list ({len(type_list)})")

    return {term for term, term_type in zip(term_list, type_list)
            if term_type is DataTypes.free_var_name}

#| hide
##### TEST

In [None]:
#| hide
import pytest

In [None]:
#| hide
# mixed term list: only the terms typed as free variables are expected back
term_list = ["X", 10, "Z"]
type_list = [DataTypes.free_var_name, DataTypes.integer, DataTypes.free_var_name]

result = get_free_var_names(term_list, type_list)
assert result == {"X", "Z"}

In [None]:
#| hide
# empty term and type lists are expected to give an empty set
term_list = []
type_list = []

result = get_free_var_names(term_list, type_list)
assert result == set()

In [None]:
#| hide
# three terms but only two types: the length check is expected to raise
term_list = ["X", "y", "Z"]
type_list = [DataTypes.free_var_name, DataTypes.integer]

with pytest.raises(Exception, match="received different lengths of term_list"):
    get_free_var_names(term_list, type_list)

::: {.callout-note collapse="true"}

##### Example

In [None]:
# print the free variable names of a mixed term list
term_list = ["X", 10, "Z"]
type_list = [DataTypes.free_var_name, DataTypes.integer, DataTypes.free_var_name]    
print(get_free_var_names(term_list, type_list))

{'Z', 'X'}


:::

In [None]:
#| export
@no_type_check
def position_freevar_pairs(relation: Union[Relation, IERelation] # a relation (either a normal relation or an ie relation)
                           ) -> List[Tuple[int, str]]: # a list of all (index, free_var) pairs based on term_list
    """
    Lists every free variable of `relation` together with the index of the term it occupies,
    in the order of the relation's term list.
    """
    terms = relation.get_term_list()
    types = relation.get_type_list()

    pairs = []
    for index, (term, term_type) in enumerate(zip(terms, types)):
        if term_type is DataTypes.free_var_name:
            pairs.append((index, term))
    return pairs

#| hide
##### TEST

In [None]:
#| hide
# free variables at positions 0, 2 and 4 are expected, in term order
term_list = ["X", "abc", "Y", "def", "Z"]
type_list = [DataTypes.free_var_name, DataTypes.string, DataTypes.free_var_name, DataTypes.integer, DataTypes.free_var_name]

relation = Relation("relation1",term_list, type_list)

result = position_freevar_pairs(relation)
expected_result = [(0, "X"), (2, "Y"), (4, "Z")]

assert result == expected_result

In [None]:
#| hide
# a relation without terms is expected to give no pairs
term_list = []
type_list = []

relation = Relation("relation1",term_list, type_list)

result = position_freevar_pairs(relation)
assert result == []

In [None]:
#| hide
# a relation made only of string terms is expected to give no pairs
term_list = ["abc", "def", "ghi"]
type_list = [DataTypes.string, DataTypes.string, DataTypes.string]

relation = Relation("relation1",term_list, type_list)

result = position_freevar_pairs(relation)
assert result == []

::: {.callout-note collapse="true"}

##### Example

In [None]:
# print the (index, free variable) pairs of a relation with mixed terms
term_list = ["X", "abc", "Y", "def", "Z"]
type_list = [DataTypes.free_var_name, DataTypes.string, DataTypes.free_var_name, DataTypes.integer, DataTypes.free_var_name]

relation = Relation("relation1",term_list, type_list)

print(position_freevar_pairs(relation))

[(0, 'X'), (2, 'Y'), (4, 'Z')]


:::

In [None]:
#| export
def get_input_free_var_names(relation: Union[Relation, IERelation] # a relation (either a normal relation or an ie relation)
                             ) -> Set[Any]: # a set of the free variables used as input terms in the relation.
    """
    Returns the free variables among the input terms of an ie relation.
    Any other relation gives an empty set.
    """
    if not isinstance(relation, IERelation):
        return set()
    return get_free_var_names(relation.input_term_list, relation.input_type_list)

::: {.callout-note}
If the input is a regular relation, an empty set is returned, since regular relations don't have input free variables.
If the input is an ie-relation, its input free variables are returned.
:::

#| hide
##### TEST

In [None]:
#| hide
# of the two input terms only "X" is typed as a free variable
input_term_list = ["X", "Y"]
input_type_list = [DataTypes.free_var_name, DataTypes.integer]

ie_relation = IERelation("relation1",input_term_list, input_type_list, [], [])

result = get_input_free_var_names(ie_relation)
assert result == {"X"}

In [None]:
#| hide
# a regular relation is expected to give an empty set
term_list = ["abc", "def", "ghi"]
type_list = [DataTypes.string, DataTypes.string, DataTypes.string]
relation = Relation("relation1",term_list, type_list)

result = get_input_free_var_names(relation)
assert result == set()  

::: {.callout-note collapse="true"}

##### Example


In [None]:
# print the input free variables of an ie relation
input_term_list = ["X", "Y"]
input_type_list = [DataTypes.free_var_name, DataTypes.integer]

ie_relation = IERelation("relation1",input_term_list, input_type_list, [], [])

print(get_input_free_var_names(ie_relation))

{'X'}


:::

In [None]:
#| export
def get_output_free_var_names(relation: Union[Relation, IERelation] # a relation (either a normal relation or an ie relation)
                              ) -> Set[str]: # a set of the free variables used as output terms in the relation
    """Returns the free variables found in the term list reported by `relation.get_term_list()`."""
    terms = relation.get_term_list()
    types = relation.get_type_list()
    return get_free_var_names(terms, types)

::: {.callout-note}
Returns the free variables found in the relation's term list (the list returned by `get_term_list()`).
This applies to both regular relations and ie-relations.
:::

In [None]:
#| export
def get_free_var_to_relations_dict(relations: Set[Union[Relation, IERelation]] # a set of relations
                                   ) -> (Dict[str, List[Tuple[Union[Relation, IERelation], int]]]): # a mapping between each free var to the relations and corresponding columns in which it appears
    """
    Finds for each free var in any of the relations, all the relations that contain it.
    also return the free vars' index in each relation (as pairs). <br>
    for example: <br>
        relations = [a(X,Y), b(Y)] ->
        dict = {X:[(a(X,Y),0)], Y:[(a(X,Y),1),(b(Y),0)]}
    """
    # note: don't remove variables with less than 2 uses here, we need them as well
    # keyed by relation, so a relation that is passed more than once is only counted once
    free_var_positions = {relation: position_freevar_pairs(relation) for relation in relations}

    # group the (relation, position) pairs by free var in a single pass,
    # instead of rescanning every pair once per free var
    var_dict: Dict[str, List[Tuple[Union[Relation, IERelation], int]]] = {}
    for relation, pair_list in free_var_positions.items():
        for pos, free_var in pair_list:
            var_dict.setdefault(free_var, []).append((relation, pos))

    return var_dict

::: {.callout-note collapse="true"}

##### Example

In [None]:
# two relations sharing the free variables X and Y; Z appears only in the first one
term_list_1 = ["X", "abc", "Y", "def", "Z"]
type_list_1 = [DataTypes.free_var_name, DataTypes.string, DataTypes.free_var_name, DataTypes.integer, DataTypes.free_var_name]

relation_1 = Relation("relation_1",term_list_1, type_list_1)

term_list_2 = ["X", "Y"]
type_list_2 = [DataTypes.free_var_name, DataTypes.free_var_name]

relation_2 = Relation("relation_2",term_list_2, type_list_2)
relations_set = {relation_1,relation_2}
get_free_var_to_relations_dict(relations_set)

{'Z': [(Relation(relation_1(X, "abc", Y, def, Z)), 4)],
 'Y': [(Relation(relation_2(X, Y)), 1),
  (Relation(relation_1(X, "abc", Y, def, Z)), 2)],
 'X': [(Relation(relation_2(X, Y)), 0),
  (Relation(relation_1(X, "abc", Y, def, Z)), 0)]}

:::

In [None]:
#| export
def check_properly_typed_term_list(term_list: Sequence, # the term list to be type checked
                                    type_list: Sequence, # the types of the terms in term_list
                                   correct_type_list: Sequence, # a list of the types that the terms must have to pass the type check
                                     symbol_table: SymbolTableBase # a symbol table (used to get the types of variables)
                                     ) -> bool: # True if the type check passed, else False
    """
    Checks if the term list is properly typed.
    Free variables are skipped (their type is assumed to be correct), and the type of a
    regular variable is looked up in the symbol table before it is compared.
    """
    if not (len(term_list) == len(type_list) == len(correct_type_list)):
        raise Exception("the length of term_list, type_list and correct_type_list should be the same")

    for term, declared_type, expected_type in zip(term_list, type_list, correct_type_list):
        # a variable carries the type that the symbol table holds for it
        if declared_type is DataTypes.var_name:
            actual_type = symbol_table.get_variable_type(term)
        else:
            actual_type = declared_type

        # free variables are never a reason to fail the check
        if actual_type is DataTypes.free_var_name:
            continue

        if actual_type != expected_type:
            # found a term that is not properly typed
            return False

    # no badly typed term was found
    return True

#| hide
##### TEST

In [None]:
#| hide
# "x" and "y" are variables whose registered types match the expected ones, "z" is a free variable
symbol_table = SymbolTable()

term_list = ["x", "y", "z"]
type_list = [DataTypes.var_name, DataTypes.var_name, DataTypes.free_var_name]
correct_type_list = [DataTypes.integer, DataTypes.string, DataTypes.free_var_name]

symbol_table.set_var_value_and_type("x", None, DataTypes.integer)
symbol_table.set_var_value_and_type("y", None, DataTypes.string)

result = check_properly_typed_term_list(term_list, type_list, correct_type_list, symbol_table)
assert result is True

In [None]:
#| hide
# the expected types are swapped with respect to the registered variable types, so the check must fail
symbol_table = SymbolTable()

term_list = ["x", "y", "z"]
type_list = [DataTypes.var_name, DataTypes.var_name, DataTypes.free_var_name]
correct_type_list = [DataTypes.string, DataTypes.integer, DataTypes.free_var_name]

symbol_table.set_var_value_and_type("x", None, DataTypes.integer)
symbol_table.set_var_value_and_type("y", None, DataTypes.string)

result = check_properly_typed_term_list(term_list, type_list, correct_type_list, symbol_table)
assert result is False

In [None]:
#| hide
# three terms but only two types: the length check is expected to raise
symbol_table = SymbolTable()

term_list = ["x", "y", "z"]
type_list = [DataTypes.var_name, DataTypes.var_name]
correct_type_list = [DataTypes.integer, DataTypes.string]

symbol_table.set_var_value_and_type("x", None, DataTypes.integer)
symbol_table.set_var_value_and_type("y", None, DataTypes.string)

with pytest.raises(Exception, match="the length of term_list, type_list and correct_type_list should be the same"):
    check_properly_typed_term_list(term_list, type_list, correct_type_list, symbol_table)

#| hide
#### check_properly_typed_relation

In [None]:
#| export
@no_type_check
def check_properly_typed_relation(relation: Union[Relation, IERelation] # the relation to be checked
                                  , symbol_table: SymbolTableBase # a symbol table (to check the types of regular variables)
                                  ) -> bool: # true if the relation is properly typed, else false
    """
    Checks if a relation is properly typed, this check ignores free variables.
    A regular relation is checked against its schema; an ie relation is checked against
    both the input and the output types of its ie function.
    """
    if isinstance(relation, Relation):
        # compare the relation's terms with the schema stored for its name
        schema = symbol_table.get_relation_schema(relation.relation_name)
        return check_properly_typed_term_list(
            relation.term_list, relation.type_list, schema, symbol_table)

    if not isinstance(relation, IERelation):
        raise Exception(f'unexpected relation type: {type(relation)}')

    # fetch the input and output schemas of the ie function
    ie_func_data = symbol_table.get_ie_func_data(relation.relation_name)
    input_schema = ie_func_data.get_input_types()
    output_arity = len(relation.output_term_list) + len(relation.input_term_list)
    output_schema = ie_func_data.get_output_types(output_arity)

    # both checks always run; the relation is properly typed only if both of them pass
    inputs_ok = check_properly_typed_term_list(
        relation.input_term_list, relation.input_type_list, input_schema, symbol_table)
    outputs_ok = check_properly_typed_term_list(
        relation.input_term_list + relation.output_term_list,
        relation.input_type_list + relation.output_type_list, output_schema, symbol_table)
    return inputs_ok and outputs_ok

#| hide
##### TEST

In [None]:
#| hide
# a relation whose terms are all free variables is expected to pass the check
symbol_table = SymbolTable()

symbol_table.add_relation_schema("relation1", [DataTypes.integer, DataTypes.string], is_rule=False)

relation = Relation("relation1", ["x", "y"], [DataTypes.free_var_name, DataTypes.free_var_name])
result = check_properly_typed_relation(relation, symbol_table)
assert result is True

In [None]:
#| hide
# "ie_relation1" is only added as a relation schema, no ie function is registered under that name;
# the check on the ie relation is expected to raise `ValueError`
symbol_table = SymbolTable()

symbol_table.add_relation_schema("ie_relation1", [DataTypes.integer, DataTypes.string], is_rule=True)

ie_relation = IERelation("ie_relation1", ["x"], [DataTypes.free_var_name], ["y"], [DataTypes.free_var_name])
with pytest.raises(ValueError):
    check_properly_typed_relation(ie_relation, symbol_table)



In [None]:
#| hide
# same setup as the previous cell, but the output term is typed as a string; `ValueError` is still expected
symbol_table = SymbolTable()

symbol_table.add_relation_schema("ie_relation1", [DataTypes.integer, DataTypes.string], is_rule=True)

ie_relation = IERelation("ie_relation1", ["x"], [DataTypes.free_var_name], ["y"], [DataTypes.string])
with pytest.raises(ValueError):
    check_properly_typed_relation(ie_relation, symbol_table)

In [None]:
#| hide
# one output term but two output types: building the `IERelation` itself is expected to raise
symbol_table = SymbolTable()

symbol_table.add_relation_schema("ie_relation1", [DataTypes.integer, DataTypes.string], is_rule=True)

with pytest.raises(Exception):
    IERelation("ie_relation1", ["x"], [DataTypes.free_var_name], ["y"], [DataTypes.free_var_name, DataTypes.integer])

In [None]:
#| hide
# a plain string is neither a `Relation` nor an `IERelation`, so the check is expected to raise
symbol_table = SymbolTable()

with pytest.raises(Exception, match="unexpected relation type:"):
    check_properly_typed_relation("unexpected_relation", symbol_table)

#| hide
#### type_check_rule_free_vars

In [None]:
#| export
def type_check_rule_free_vars_aux(term_list: Sequence, # the term list of a rule body relation
                                   type_list: Sequence, # the types of the terms in term_list
                                     correct_type_list: Sequence, # a list of the types that the terms in the term list should have
                                  free_var_to_type: Dict, # a mapping of free variables to their type (those that are currently known)
                                    # a set of the free variables that are found to have conflicting types
                                    conflicted_free_vars: Set # this function adds conflicting free variables that it finds to this set
                                    ) -> None:
    """
    A helper function for the method `type_check_rule_free_vars`
    performs the free variables type checking on term_list.
    Both `free_var_to_type` and `conflicted_free_vars` are updated in place.
    """
    if not (len(term_list) == len(type_list) == len(correct_type_list)):
        raise Exception("the length of term_list, type_list and correct_type_list should be the same")

    for term, term_type, expected_type in zip(term_list, type_list, correct_type_list):
        # only free variables take part in this check
        if term_type is not DataTypes.free_var_name:
            continue

        assert isinstance(term, str), "a free_var must be of type str"
        if term not in free_var_to_type:
            # first time this free var is seen, record the type it is expected to have
            free_var_to_type[term] = expected_type
        elif free_var_to_type[term] != expected_type:
            # the recorded type disagrees with the expected one, mark the free var as conflicted
            conflicted_free_vars.add(term)

In [None]:
#| export
def type_check_rule_free_vars(rule: Rule, # The rule to be checked
                               symbol_table: SymbolTableBase # a symbol table (used to get the schema of the relation)
                                # a tuple (free_var_to_type, conflicted_free_vars) where
                                # free_var_to_type: a mapping from a free variable to its type
                                # conflicted_free_vars: a set of all the conflicted free variables
                               ) -> Tuple[Dict[str, DataTypes], Set[str]]:
    """
    Free variables in rules get their type from the relations in the rule body. <br>
    A free variable may be expected to have more than one type, in which case it is conflicted. <br>
    This method goes over the rule body relations, records the type of each free variable and
    collects the free variables that turn out to have conflicting types.
    """
    var_types: Dict[str, DataTypes] = {}
    conflicts: Set[str] = set()

    body = zip(rule.body_relation_list, rule.body_relation_type_list)
    for relation, relation_type in body:

        if isinstance(relation, Relation):
            # a regular relation is checked against the schema stored for its name
            schema = symbol_table.get_relation_schema(relation.relation_name)
            type_check_rule_free_vars_aux(relation.term_list, relation.type_list, schema,
                                          var_types, conflicts)
            continue

        if not isinstance(relation, IERelation):
            raise Exception(f'unexpected relation type: {relation_type}')

        # an ie relation is checked against the input and output types of its ie function
        ie_func_data = symbol_table.get_ie_func_data(relation.relation_name)
        input_schema = ie_func_data.get_input_types()
        output_arity = len(relation.output_term_list) + len(relation.input_term_list)
        output_schema = ie_func_data.get_output_types(output_arity)

        # first the input terms on their own, then the input terms followed by the output terms
        type_check_rule_free_vars_aux(relation.input_term_list, relation.input_type_list,
                                      input_schema, var_types, conflicts)
        type_check_rule_free_vars_aux(relation.input_term_list + relation.output_term_list,
                                      relation.input_type_list + relation.output_type_list,
                                      output_schema, var_types, conflicts)

    return var_types, conflicts

::: {.callout-note}
The helper `type_check_rule_free_vars_aux` updates the `free_var_to_type` mapping in place whenever it finds new free variables in `term_list`.
:::

#| hide
##### TEST

In [None]:
#| hide
# rule with a single ie relation whose terms are all free variables;
# the registered ie function types are expected to give x -> integer and y -> string, without conflicts
def doSomething(x,y):
    yield x
rule = Rule(
    "rule2",
    [IERelation("doSomething", ["x","y"], [DataTypes.free_var_name,DataTypes.free_var_name], ["x"], [DataTypes.free_var_name])],
    ["IERelation"]
)

symbol_table = SymbolTable()
symbol_table.register_ie_function(doSomething,"doSomething", [DataTypes.integer, DataTypes.string], [DataTypes.integer])
free_var_to_type, conflicted_free_vars = type_check_rule_free_vars(rule, symbol_table)
assert free_var_to_type == {"x": DataTypes.integer, "y": DataTypes.string}
assert conflicted_free_vars == set()

In [None]:
#| hide
# rule with a single regular relation: the free variables take their types from the relation schema
rule = Rule(
    "rule1",
    [Relation("Person", ["x", "y"], [DataTypes.free_var_name, DataTypes.free_var_name])],
    ["Relation"]
)
symbol_table = SymbolTable()
symbol_table.add_relation_schema("Person", [DataTypes.string, DataTypes.integer], False)
free_var_to_type, conflicted_free_vars = type_check_rule_free_vars(rule, symbol_table)
assert free_var_to_type == {"x": DataTypes.string, "y": DataTypes.integer}
assert conflicted_free_vars == set()

In [None]:
#| hide
# rule with a regular relation and an ie relation; the terms of the ie relation are not typed as
# free variables, so only the "Person" schema is expected to determine the free variable types
def doSomething(x,y):
    yield x

rule = Rule(
    "rule3",
    [
        Relation("Person", ["x", "y"], [DataTypes.free_var_name, DataTypes.free_var_name]),
        IERelation("doSomething", ["x", "y"], [DataTypes.integer, DataTypes.string], ["x"], [DataTypes.integer])
    ],
    ["Relation", "IERelation"]
)
symbol_table = SymbolTable()
symbol_table.add_relation_schema("Person", [DataTypes.string, DataTypes.integer], False)
symbol_table.register_ie_function(doSomething, "doSomething", [DataTypes.string, DataTypes.integer], [DataTypes.string])
free_var_to_type, conflicted_free_vars = type_check_rule_free_vars(rule, symbol_table)
assert free_var_to_type == {"x": DataTypes.string, "y": DataTypes.integer}
assert conflicted_free_vars == set()

#| hide
#### rule_to_relation_name

In [None]:
#| export
def rule_to_relation_name(rule: str # a string that represents a rule
                          ) -> str: # the name of the rule relation
    """
    Returns the part of the (whitespace-stripped) rule string that precedes the first `(`.
    """
    trimmed = rule.strip()
    name, _, _ = trimmed.partition('(')
    return name

#| hide
#### string_to_span

In [None]:
#| export
def string_to_span(string_of_span: str # str representation of a `Span` object
                   ) -> Optional[Span]: # `Span` object initialized based on the `string_of_span` it received as input
    """
    Parses a string that matches `SPAN_PATTERN` into a `Span`.
    Returns `None` when the string does not match the pattern.
    """
    parsed = SPAN_PATTERN.match(string_of_span)
    if parsed is None:
        return None
    return Span(span_start=int(parsed.group(SPAN_GROUP1)),
                span_end=int(parsed.group(SPAN_GROUP2)))

#| hide
#### extract_one_relation

In [None]:
#| export
def extract_one_relation(func: Callable) -> Callable:
    """
    Decorator for engine operators that expect exactly one input relation but may receive a list of relations.
    """

    @functools.wraps(func)
    def wrapper(ref: Any, input_relations: Any, *args: Any, **kwargs: Any) -> Any:
        """
        Unwraps a one-item list of relations before calling the operator.
        """
        if not isinstance(input_relations, Relation):
            # anything that is not a single relation must hold exactly one relation
            assert len(input_relations) == 1
            input_relations = input_relations[0]
        return func(ref, input_relations, *args, **kwargs)

    return wrapper

In [None]:
#|hide
# run nbdev's export step
import nbdev; nbdev.nbdev_export()
     