# Utils

In [None]:
#| default_exp tests.utils

In [None]:
#| hide
from nbdev.showdoc import show_doc

%load_ext autoreload
%autoreload 2

In [None]:
#| export
#| output: false
import numpy as np
import tempfile
from pandas import DataFrame
from pathlib import Path
from typing import List, Optional, Iterable, Dict, no_type_check, Type
import pandas as pd

from graph_rewrite import draw
from spannerlib.utils import assert_df_equals
from spannerlib.span import Span
from spannerlib.session import Session

## Scaffolding

In [None]:
#| export

def test_session(
    queries,  # query string, or a list/tuple of query strings, run in order
    expected_outputs,  # expected DataFrame, or a list/tuple of them; `None` returns that query's result instead
    ie_funcs=None,  # optional list of [name, func, in_scheme, out_scheme] registered before querying
    csvs=None,  # optional list of [name, df] relations imported before querying
    debug=False,  # print, draw and execute the plan of the last query step by step
    ):
    """Run `queries` in a fresh `Session` and check each result with `assert_df_equals`.

    Queries and expected outputs are paired positionally with `zip`, so surplus
    entries in the longer of the two sequences are silently ignored.

    Returns the result of the first query whose expected output is `None`
    (a debugging aid); otherwise returns `None`. Any exception raised by
    `assert_df_equals` on a mismatch propagates to the caller.
    """
    sess = Session()

    # Load input relations and IE functions before any query runs.
    if csvs:
        for name, df in csvs:
            sess.import_rel(name, df)
    if ie_funcs:
        for name, func, in_scheme, out_scheme in ie_funcs:
            sess.register_function(name, func, in_scheme, out_scheme)

    # A single query / expected frame is shorthand for a one-element sequence.
    if not isinstance(queries, (list, tuple)):
        queries = [queries]
    if not isinstance(expected_outputs, (list, tuple)):
        expected_outputs = [expected_outputs]

    last_index = len(queries) - 1
    for index, (query, expected) in enumerate(zip(queries, expected_outputs)):
        # Select the debug query by position, not by text: an identical query
        # string earlier in the list must not also take the debug path.
        if debug and index == last_index:
            print(query)
            plan, _root = sess.plan_query(query)
            draw(plan)
            res = sess.execute_plan(plan)
        else:
            res = sess.export(query)
        # Debugging aid: hand back the result of the first query that has no
        # expected output instead of asserting on it.
        if expected is None:
            return res
        assert_df_equals(res, expected)


# The `test_` prefix would make pytest collect this helper as a test (and fail
# on its missing "fixtures") wherever it is imported into a test module.
test_session.__test__ = False
        

## Basic tests

In [None]:
# Assignment: variables bound to constants (or to other variables) can be used
# as fact arguments and as constants inside a query.
test_session(
    queries="""
        new Relation(int, int)
        x = 1
        y = 2
        z = y
        Relation(x, y)
        Relation(y, x)
        ?Relation(X, x)
    """,
    expected_outputs=pd.DataFrame({'X': [2]}),
)

In [None]:
# Copy rule: A(X, Y) <- B(X, Y) reproduces every fact of B in A.
test_session(
    queries="""
        new B(int, int)
        B(1, 1)
        B(1, 2)
        B(2, 3)
        A(X, Y) <- B(X, Y)
        ?A(X, Y)
    """,
    expected_outputs=pd.DataFrame({'X': [1, 1, 2], 'Y': [1, 2, 3]}),
)

In [None]:
# Join of two relations on their shared variable Y.
test_session(
    queries="""
        new B(int, int)
        new C(int, int)
        B(1, 1)
        B(1, 2)
        B(2, 3)
        C(2, 2)
        C(1, 1)
        D(X, Y, Z) <- B(X, Y), C(Y, Z)
        ?D(X, Y, Z)
    """,
    expected_outputs=pd.DataFrame({'X': [1, 1], 'Y': [2, 1], 'Z': [2, 1]}),
)

In [None]:
# Repeated free variable inside one atom: B(X, X) keeps only the facts whose
# two arguments are equal.
test_session(
    queries="""
        new B(int, int)
        B(1, 1)
        B(1, 2)
        B(2, 2)
        A(X) <- B(X, X)
        ?A(X)
    """,
    expected_outputs=pd.DataFrame({'X': [1, 2]}),
)

In [None]:
# Union: two rules share a head and use the same variable names; the fact
# (1, 1) present in both B and C appears only once in the expected result.
test_session(
    queries="""
        new B(int, int)
        new C(int, int)
        B(1, 1)
        B(1, 2)
        B(2, 3)
        C(2, 2)
        C(1, 1)

        A(X, Y) <- B(X, Y)
        A(X, Y) <- C(X, Y)
        ?A(X, Y)
    """,
    expected_outputs=pd.DataFrame({'X': [1, 1, 2, 2], 'Y': [1, 2, 2, 3]}),
)

In [None]:
# Union where the two rules name their variables differently; the result must
# match the same-variable union above.
test_session(
    queries="""
        new B(int, int)
        new C(int, int)
        B(1, 1)
        B(1, 2)
        B(2, 3)
        C(2, 2)
        C(1, 1)

        A(X, Y) <- B(X, Y)
        A(Z, W) <- C(Z, W)
        ?A(X, Y)
    """,
    expected_outputs=pd.DataFrame({'X': [1, 1, 2, 2], 'Y': [1, 2, 2, 3]}),
)

In [None]:
# Projection: A keeps only B's first column, so the two facts with X = 1
# collapse into a single row.
test_session(
    queries="""
            new B(int, int)
            B(1, 1)
            B(1, 2)

            A(X) <- B(X, Y)
            ?A(X)
    """,
    expected_outputs=pd.DataFrame({'X': [1]}),
)

In [None]:
# A fact inserted after the rule is declared must still be visible through it;
# the query also renames the output columns to Z and W.
test_session(
    queries="""
            new B(int, int)
            B(1, 1)
            A(X, Y) <- B(X, Y)
            B(1, 2)
            ?A(Z, W)
    """,
    expected_outputs=pd.DataFrame({'Z': [1, 1], 'W': [1, 2]}),
)

In [None]:
# Column types: an int, a str and a span literal [1, 2) are read back as
# 1, "2" and Span(1, 2).
test_session(
    queries="""
            new B(int, str, span)
            B(1, "2", [1, 2))
            ?B(X, Y, Z)
    """,
    expected_outputs=pd.DataFrame(
        data=[[1, "2", Span(1, 2)]],
        columns=['X', 'Y', 'Z'],
    ),
)

In [None]:
# Self-join: Parent is joined with itself through the middle variable M, and a
# string constant in the query selects the grandchild "Austin".
test_session(
    queries="""
            new Parent(str, str)
            Parent("Sam", "Noah")
            Parent("Noah", "Austin")
            Parent("Austin", "Stephen")


            GrandParent(G, C) <- Parent(G, M), Parent(M, C)
            ?GrandParent(X, "Austin")
    """,
    expected_outputs=pd.DataFrame(data=[['Sam']], columns=['X']),
)