In [1]:
import pytest
from sqlmodel import Session, select

from dj.construction.build import (
    ColumnDependencies,
    CompoundBuildException,
    InvalidSQLException,
    MissingColumnException,
    NodeTypeException,
    UnknownNodeException,
    extract_dependencies,
    extract_dependencies_from_query,
    extract_dependencies_from_select,
    get_dj_node,
    make_name,
)
from dj.models import Column
from dj.models.node import Node, NodeType
from dj.sql.parsing.ast import Alias, BinaryOp, BinaryOpKind
from dj.sql.parsing.ast import Column as ASTColumn
from dj.sql.parsing.ast import (
    From,
    Join,
    JoinKind,
    Name,
    Namespace,
    Query,
    Select,
    String,
    Table,
)
from dj.sql.parsing.backends.sqloxide import parse
from dj.typing import ColumnType

In [43]:
        purchases = Node(
            name="purchases",
            type=NodeType.SOURCE,
            columns=[
                Column(name="transaction_id", type=ColumnType.INT),
                Column(name="transaction_time", type=ColumnType.DATETIME),
                Column(name="transaction_amount", type=ColumnType.FLOAT),
                Column(name="customer_id", type=ColumnType.INT),
            ],
        )
        customer_events = Node(
            name="customer_events",
            type=NodeType.SOURCE,
            columns=[
                Column(name="event_id", type=ColumnType.INT),
                Column(name="event_time", type=ColumnType.DATETIME),
                Column(name="event_type", type=ColumnType.STR),
                Column(name="customer_id", type=ColumnType.INT),
                Column(name="message", type=ColumnType.STR),
            ],
        )
        returns = Node(
            name="returns",
            type=NodeType.SOURCE,
            columns=[
                Column(name="transaction_id", type=ColumnType.INT),
                Column(name="transaction_time", type=ColumnType.DATETIME),
                Column(name="purchase_transaction_id", type=ColumnType.INT),
            ],
        )

        eligible_purchases = Node(
            name="eligible_purchases",
            query="""
                    SELECT transaction_id, transaction_time, transaction_amount, customer_id
                    FROM purchases
                    WHERE transaction_amount > 100.0
                """,
            type=NodeType.TRANSFORM,
            columns=[
                Column(name="transaction_id", type=ColumnType.INT),
                Column(name="transaction_time", type=ColumnType.DATETIME),
                Column(name="transaction_amount", type=ColumnType.FLOAT),
                Column(name="customer_id", type=ColumnType.INT),
            ],
        )

        returned_transactions = Node(
            name="returned_transactions",
            query="""
                    SELECT transaction_id, transaction_time, transaction_amount, customer_id
                    FROM purchases p
                    LEFT JOIN returns r
                    ON p.transaction_id = r.purchase_transaction_id
                    WHERE r.purchase_transaction_id is not null
                """,
            type=NodeType.TRANSFORM,
            columns=[
                Column(name="transaction_id", type=ColumnType.INT),
                Column(name="transaction_time", type=ColumnType.DATETIME),
                Column(name="transaction_amount", type=ColumnType.FLOAT),
                Column(name="customer_id", type=ColumnType.INT),
            ],
        )

        event_type = Node(
            name="event_type",
            type=NodeType.DIMENSION,
            query="SELECT DISTINCT event_type FROM customer_events",
            columns=[
                Column(name="event_type", type=ColumnType.STR),
            ],
        )
        session.add(purchases)
        session.add(customer_events)
        session.add(returns)
        session.add(eligible_purchases)
        session.add(returned_transactions)
        session.add(event_type)

In [15]:
hash(Name('a'))

-771431139317362362

In [19]:
query = parse(
    """
SELECT r2.id as matched_id
FROM dbt.transform.customer_agg r1
LEFT JOIN dbt.transform.customer_agg r2
ON r1.id = r2.id
WHERE r2.id is not null
""",
    "hive",
)

In [23]:
%timeit hash(query)==hash(query)

1.03 ms ± 62.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [24]:
%timeit query.compare(query)

1.68 ms ± 79.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [20]:
parse(query)#.compare(parse(query))

TypeError: argument 'sql': 'Query' object cannot be converted to 'PyString'

In [26]:
from dj.utils import get_session
session = next(get_session())

In [32]:
%pdb

Automatic pdb calling has been turned OFF


In [33]:
query_dependencies = extract_dependencies_from_query(session, query)

In [16]:
from dj.sql.parsing.ast import flatten

In [18]:
list(flatten(['hello', ['world']]))

['hello', 'world']

In [15]:
query_dependencies.select.columns

ColumnDependencies(projection=[(Column(name=Name(name='id', quote_style=''), namespace=Namespace(names=[Name(name='r2', quote_style='')])), Table(name=Name(name='customer_agg', quote_style=''), namespace=Namespace(names=[Name(name='dbt', quote_style=''), Name(name='transform', quote_style='')])))], group_by=[], filters=[(Column(name=Name(name='id', quote_style=''), namespace=Namespace(names=[Name(name='r2', quote_style='')])), Table(name=Name(name='customer_agg', quote_style=''), namespace=Namespace(names=[Name(name='dbt', quote_style=''), Name(name='transform', quote_style='')])))], ons=[(Column(name=Name(name='id', quote_style=''), namespace=Namespace(names=[Name(name='r1', quote_style='')])), Table(name=Name(name='customer_agg', quote_style=''), namespace=Namespace(names=[Name(name='dbt', quote_style=''), Name(name='transform', quote_style='')]))), (Column(name=Name(name='id', quote_style=''), namespace=Namespace(names=[Name(name='r2', quote_style='')])), Table(name=Name(name='custo

In [2]:
query="""WITH customer_total_return AS
(
           SELECT     sr_customer_sk     AS ctr_customer_sk,
                      sr_store_sk        AS ctr_store_sk,
                      Sum(sr_return_amt) AS ctr_total_return
           FROM       store_returns
           INNER JOIN date_dim
           ON         store_returns.sr_customer_sk=date_dim.d_date_sk
           WHERE      sr_returned_date_sk = d_date_sk
           AND        d_year = 2001
           GROUP BY   sr_customer_sk,
                      sr_store_sk)
SELECT     c_customer_id
FROM       customer_total_return ctr1
INNER JOIN store
ON         s_store_sk = ctr1.ctr_store_sk
AND        s_state = 'TN'
AND        ctr1.ctr_customer_sk = c_customer_sk
LEFT JOIN customer
ON         s_store_sk = ctr1.ctr_store_sk
AND        s_state = 'TN'
AND        ctr1.ctr_customer_sk = c_customer_sk
WHERE      ctr1.ctr_total_return >
           (
                  SELECT DISTINCT Avg(ctr_total_return) * 1.2
                  FROM   customer_total_return ctr2
                  WHERE  ctr1.ctr_store_sk = ctr2.ctr_store_sk)
AND        s_store_sk = ctr1.ctr_store_sk
OR         s_state <> 'TN'
AND        ctr1.ctr_customer_sk = c_customer_sk
LIMIT 100
"""

In [4]:
parse(str(parse(query)))

Query(select=Select(from_=From(tables=[Alias(name=Name(name='ctr1', quote_style=''), namespace=None, child=Table(name=Name(name='customer_total_return', quote_style=''), namespace=None))], joins=[Join(kind=JoinKind.Inner, table=Table(name=Name(name='store', quote_style=''), namespace=None), on=BinaryOp(op=BinaryOpKind.And, left=BinaryOp(op=BinaryOpKind.And, left=BinaryOp(op=BinaryOpKind.Eq, left=Column(name=Name(name='s_store_sk', quote_style=''), namespace=None), right=Column(name=Name(name='ctr_store_sk', quote_style=''), namespace=Namespace(names=[Name(name='ctr1', quote_style='')]))), right=BinaryOp(op=BinaryOpKind.Eq, left=Column(name=Name(name='s_state', quote_style=''), namespace=None), right=String(value='TN'))), right=BinaryOp(op=BinaryOpKind.Eq, left=Column(name=Name(name='ctr_customer_sk', quote_style=''), namespace=Namespace(names=[Name(name='ctr1', quote_style='')])), right=Column(name=Name(name='c_customer_sk', quote_style=''), namespace=None)))), Join(kind=JoinKind.LeftO

In [6]:
print(str(parse(query).ctes[0]))

> [0;32m/code/dj/sql/parsing/ast.py[0m(736)[0;36m__str__[0;34m()[0m
[0;32m    734 [0;31m        [0mselect[0m [0;34m=[0m [0;34m" "[0m[0;34m.[0m[0mjoin[0m[0;34m([0m[0mparts[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    735 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 736 [0;31m        [0;32mif[0m [0msubselect[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    737 [0;31m            [0;32mreturn[0m [0;34m"("[0m [0;34m+[0m [0mselect[0m [0;34m+[0m [0;34m")"[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    738 [0;31m        [0;32mreturn[0m [0mselect[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> c
(SELECT  sr_customer_sk AS ctr_customer_sk,
	sr_store_sk AS ctr_store_sk,
	Sum(sr_return_amt) AS ctr_total_return 
 FROM store_returns
INNER JOIN date_dim
        ON store_returns.sr_customer_sk = date_dim.d_date_sk 
 WHERE  sr_returned_

In [44]:
query = parse(
    """
SELECT r2.transaction_id as matched_id
FROM returns r1
LEFT JOIN returns r2
ON r1.purchase_transaction_id = r2.transaction_id
WHERE r2.transaction_id is not null
""",
    "hive",
)
query_dependencies = extract_dependencies_from_query(session, query)
dependencies = query_dependencies.select

UnknownNodeException: No NodeType.TRANSFORM or NodeType.DIMENSION or NodeType.SOURCE node `returns` exists. `returns`

In [40]:
hash(tuple(query_dependencies.select.columns.all_columns))

2507612997478913879

In [41]:
hash(tuple(query_dependencies.select.columns.all_columns))

2507612997478913879

In [None]:
dependencies = query_dependencies.select