In [1]:
from spanner_workbench.src.rgxlog_interpreter.src.rgxlog.engine.session import Session, queries_to_string
from spanner_workbench.src.rgxlog_interpreter.src.rgxlog.engine.datatypes.primitive_types import DataTypes
from spanner_workbench.src.rgxlog_interpreter.src.rgxlog.engine.utils.general_utils import QUERY_RESULT_PREFIX
from typing import Optional, Iterable, Dict, List

In [2]:
def split_to_tables(result: str) -> List[str]:
    """
    Splits rgxlog's printed output into one string per table.

    In rgxlog's output, tables are separated by two consecutive newlines. Segments that are
    empty or whitespace-only are not tables and are dropped, so callers never receive an
    empty "table". Such segments come from leading/trailing blank lines, or from runs of
    several blank lines.

    @param result: rgxlog's output.
    @return: List of strings, each string represents a table (empty list for blank input).
    """
    # without the filter, e.g. "a\n\n\n\nb" would yield ["a", "", "b"], and the empty
    # segment later fails table parsing with "illegal output received"
    return [table for table in result.split("\n\n") if table.strip()]

In [3]:
def table_to_query_free_vars_tuples(table: str) -> Iterable:
    """
    Converts one printed rgxlog table into a structure that can be compared with ==.

    Blank lines are discarded and every remaining line is stripped. The first remaining
    line is always the query's print statement.

    @param table: the string that represents a single table.
    @return: if the second line is "[()]" or "[]" (a true/false answer), the list of
             stripped lines itself; otherwise a tuple of
             (print statement, free-vars header line, set of row lines).
             The third line is skipped; it is assumed to be the "----+----" separator.
    @raise ValueError: if the table has fewer than two non-blank lines.
    """
    lines = []
    for raw_line in table.split("\n"):
        stripped = raw_line.strip()
        if stripped:
            lines.append(stripped)

    if len(lines) < 2:
        raise ValueError("illegal output received: \n\"" + '\n'.join(lines) + '"')

    query_line, second_line = lines[0], lines[1]
    # a boolean answer carries no header or rows; hand back the raw lines
    if second_line in ("[()]", "[]"):
        return lines
    # row order is not meaningful, so rows are compared as a set
    return query_line, second_line, set(lines[3:])

In [4]:
def compare_strings(expected: str, output: str) -> bool:
    """
    Decides whether two rgxlog outputs describe the same results.

    Every line of both strings is stripped first. Both outputs are then cut into tables,
    and the tables are compared pairwise, in order, via table_to_query_free_vars_tuples.

    @param expected: expected output.
    @param output: actual output.
    @return: True if output and expected represent the same result, False otherwise.
    """
    def normalize(text: str) -> str:
        # strip per-line indentation so hand-indented expectations can match
        return "\n".join(line.strip() for line in text.splitlines())

    normalized_expected = normalize(expected)
    normalized_output = normalize(output)
    expected_tables = split_to_tables(normalized_expected)
    output_tables = split_to_tables(normalized_output)

    # a differing number of tables can never be a match
    if len(expected_tables) != len(output_tables):
        return False

    return all(
        table_to_query_free_vars_tuples(expected_table) == table_to_query_free_vars_tuples(output_table)
        for expected_table, output_table in zip(expected_tables, output_tables)
    )

In [5]:
def run_test(commands: str, expected_output: Optional[str] = None, functions_to_import: Iterable[Dict] = (),
             session: Optional[Session] = None) -> Session:
    """
    Executes rgxlog commands in a session and optionally checks their printed result.

    @param commands: the rgxlog commands to run.
    @param expected_output: the expected output of the commands. If it is None, then the output is not checked.
    @param functions_to_import: an iterable of dicts, each passed as keyword arguments to ``session.register``
                                (e.g. keys ie_function, ie_function_name, in_rel, out_rel).
    @param session: the session in which we run the commands. If it is None, a new Session is created.
    @return: the session it created or got as an argument, so a later call can reuse its state.
    @raise AssertionError: if expected_output is given and does not match the actual output.
    """
    # if session wasn't passed as an arg then we create it
    if session is None:
        session = Session()

    # import all ie functions
    for ie_function in functions_to_import:
        session.register(**ie_function)
    commands_result = session.run_commands(commands, print_results=True)

    if expected_output is not None:
        commands_result_string = queries_to_string(commands_result)
        # an explicit raise (not `assert`) so the check is not stripped under `python -O`;
        # both strings are included so a failure shows what differed
        if not compare_strings(expected_output, commands_result_string):
            raise AssertionError(
                "expected string != result string\n"
                f"expected:\n{expected_output}\n"
                f"result:\n{commands_result_string}"
            )

    return session

In [6]:
def test_issue_80_1() -> None:
    """
    Test named after issue 80: two IE functions used in two run_test calls that share one session.

    The first run registers ``which_century`` and queries important events per century.
    The second run, on the same session, registers ``which_era``. It then builds a new
    relation on top of ``important_events_per_cet``, which was defined by the first run.
    """
    def which_century(year) -> Iterable[int]:
        # truncating division of the year by 100, plus one
        century = int(year / 100) + 1
        yield century

    def which_era(cet) -> Iterable[str]:
        # (inclusive lower bound, exclusive upper bound or None for unbounded, era name);
        # the first matching range wins, and nothing is yielded for values below 1
        eras = (
            (1, 4, "Targerian Regime"),
            (4, 8, "Lanister Regime"),
            (8, 12, "Stark Regime"),
            (12, 16, "Barathion Regime"),
            (16, None, "Long Winter"),
        )
        for low, high, era in eras:
            if low <= cet and (high is None or cet < high):
                yield era
                break

    # one shared list object serves as both the input and output relation types
    integer_types = [DataTypes.integer]
    which_century_dict = {
        "ie_function": which_century,
        "ie_function_name": 'which_century',
        "in_rel": integer_types,
        "out_rel": integer_types,
    }
    which_era_dict = {
        "ie_function": which_era,
        "ie_function_name": 'which_era',
        "in_rel": [DataTypes.integer],
        "out_rel": [DataTypes.string],
    }

    commands = """new event(str, int)
                        event("First Dragon", 250)
                        event("Mad king", 390)
                        event("Winter came", 1750)
                        event("Hodor", 999)
                        event("Joffery died", 799)
                        
                        new important_year(int)
                        important_year(999)
                        important_year(1750)
                        important_year(250)
                        
                        
                        important_events(EVE, Y) <- event(EVE, Y), important_year(Y)
                        
                        important_events_per_cet(EVE, CET) <- important_events(EVE, Y), which_century(Y) -> (CET)
                        ?important_events_per_cet(EVE, CET)
            """
    commands2 = """
                        important_events_per_era(EVE, ERA) <- important_events_per_cet(EVE, CET), which_era(CET) -> (ERA)
                        ?important_events_per_era(EVE, ERA)
            """
    expected_result = f"""{QUERY_RESULT_PREFIX}'important_events_per_cet(EVE, CET)':
                         EVE      |   CET
                    --------------+-------
                     First Dragon |     3
                     Winter came  |    18
                        Hodor     |    10
         """

    expected_result2 = f"""{QUERY_RESULT_PREFIX}'important_events_per_era(EVE, ERA)':
                         EVE      |       ERA
                    --------------+------------------
                        Hodor     |   Stark Regime
                     Winter came  |   Long Winter
                     First Dragon | Targerian Regime
        """

    first_session = run_test(commands, expected_result, [which_century_dict])
    # reuse the same session so important_events_per_cet from the first run is still defined
    run_test(commands2, expected_result2, [which_era_dict], session=first_session)

In [7]:
# Run the test; run_test fails with AssertionError if a query's printed output differs from the expectation.
test_issue_80_1()

printing results for query 'important_events_per_cet(EVE, CET)':
     EVE      |   CET
--------------+-------
    Hodor     |    10
 Winter came  |    18
 First Dragon |     3

printing results for query 'important_events_per_era(EVE, ERA)':
     EVE      |       ERA
--------------+------------------
    Hodor     |   Stark Regime
 Winter came  |   Long Winter
 First Dragon | Targerian Regime

