In [1]:
import pytest
from sqlmodel import Session, select

from dj.construction.build import (
    ColumnDependencies,
    CompoundBuildException,
    InvalidSQLException,
    MissingColumnException,
    NodeTypeException,
    UnknownNodeException,
    extract_dependencies,
    extract_dependencies_from_query,
    extract_dependencies_from_select,
    get_dj_node,
    make_name,
)
from dj.models import Column
from dj.models.node import Node, NodeType
from dj.sql.parsing.ast import Alias, BinaryOp, BinaryOpKind
from dj.sql.parsing.ast import Column as ASTColumn
from dj.sql.parsing.ast import (
    From,
    Join,
    JoinKind,
    Name,
    Namespace,
    Query,
    Select,
    String,
    Table,
)
from dj.sql.parsing.backends.sqloxide import parse
from dj.typing import ColumnType
from dj.utils import get_session

In [2]:
# Parse a self-join of dbt.transform.customer_agg (aliased r1 and r2) and bind the
# result to `query`; "hive" is passed to parse() as its second argument.
# NOTE(review): presumably "hive" selects the SQL dialect — confirm against parse().
query = parse(
    """
SELECT r2.id as matched_id
FROM dbt.transform.customer_agg r1
LEFT JOIN dbt.transform.customer_agg r2
ON r1.purchase_id = r2.id
WHERE r2.id is not null
""",
    "hive",
)

In [3]:
# Take one session from get_session(); it is handed to
# extract_dependencies_from_query() further down the notebook.
# NOTE(review): the session is obtained with a bare next(); if get_session() is a
# generator that performs teardown after its yield, that teardown never runs
# here — presumably acceptable for a scratch session, but confirm in dj.utils.
session = next(get_session())

In [11]:
# Toggle IPython's automatic pdb calling (the recorded output shows it being
# switched ON).
# NOTE(review): debugging aid whose execution count is out of sequence with the
# surrounding cells — drop it before sharing the notebook.
%pdb

Automatic pdb calling has been turned ON


In [4]:
# Run dependency extraction for the parsed self-join query, passing the session.
# As the recorded output shows, this call currently raises InvalidSQLException
# (namespace `r1` has no column `purchase_id`), so `query_dependencies` is not
# assigned when the cell fails.
query_dependencies = extract_dependencies_from_query(session, query)

InvalidSQLException: Namespace `r1` has no column `purchase_id`. `r1.purchase_id` from `r1.purchase_id = r2.id`

In [2]:
# SQL text used by the parse()/str() cells that follow: a CTE
# (customer_total_return), an INNER JOIN and a LEFT JOIN, a correlated scalar
# subquery in WHERE, and a LIMIT.
# NOTE(review): this rebinds `query` — earlier in the notebook the same name holds
# the object returned by parse(); from this cell on it is a plain str.
query="""WITH customer_total_return AS
(
           SELECT     sr_customer_sk     AS ctr_customer_sk,
                      sr_store_sk        AS ctr_store_sk,
                      Sum(sr_return_amt) AS ctr_total_return
           FROM       store_returns
           INNER JOIN date_dim
           ON         store_returns.sr_customer_sk=date_dim.d_date_sk
           WHERE      sr_returned_date_sk = d_date_sk
           AND        d_year = 2001
           GROUP BY   sr_customer_sk,
                      sr_store_sk)
SELECT     c_customer_id
FROM       customer_total_return ctr1
INNER JOIN store
ON         s_store_sk = ctr1.ctr_store_sk
AND        s_state = 'TN'
AND        ctr1.ctr_customer_sk = c_customer_sk
LEFT JOIN customer
ON         s_store_sk = ctr1.ctr_store_sk
AND        s_state = 'TN'
AND        ctr1.ctr_customer_sk = c_customer_sk
WHERE      ctr1.ctr_total_return >
           (
                  SELECT DISTINCT Avg(ctr_total_return) * 1.2
                  FROM   customer_total_return ctr2
                  WHERE  ctr1.ctr_store_sk = ctr2.ctr_store_sk)
AND        s_store_sk = ctr1.ctr_store_sk
OR         s_state <> 'TN'
AND        ctr1.ctr_customer_sk = c_customer_sk
LIMIT 100
"""

In [4]:
# Round-trip check: parse the SQL text, render the resulting AST with str(), then
# parse that rendering again; the cell displays the re-parsed AST.
# Unlike the first parse cell, parse() is called here without a second ("hive")
# argument.
parse(str(parse(query)))

Query(select=Select(from_=From(tables=[Alias(name=Name(name='ctr1', quote_style=''), namespace=None, child=Table(name=Name(name='customer_total_return', quote_style=''), namespace=None))], joins=[Join(kind=JoinKind.Inner, table=Table(name=Name(name='store', quote_style=''), namespace=None), on=BinaryOp(op=BinaryOpKind.And, left=BinaryOp(op=BinaryOpKind.And, left=BinaryOp(op=BinaryOpKind.Eq, left=Column(name=Name(name='s_store_sk', quote_style=''), namespace=None), right=Column(name=Name(name='ctr_store_sk', quote_style=''), namespace=Namespace(names=[Name(name='ctr1', quote_style='')]))), right=BinaryOp(op=BinaryOpKind.Eq, left=Column(name=Name(name='s_state', quote_style=''), namespace=None), right=String(value='TN'))), right=BinaryOp(op=BinaryOpKind.Eq, left=Column(name=Name(name='ctr_customer_sk', quote_style=''), namespace=Namespace(names=[Name(name='ctr1', quote_style='')])), right=Column(name=Name(name='c_customer_sk', quote_style=''), namespace=None)))), Join(kind=JoinKind.LeftO

In [6]:
# Show the string rendering of the first CTE of the parsed query.
# print() already applies str() to its argument, so no explicit conversion is
# written out.
print(parse(query).ctes[0])

> [0;32m/code/dj/sql/parsing/ast.py[0m(736)[0;36m__str__[0;34m()[0m
[0;32m    734 [0;31m        [0mselect[0m [0;34m=[0m [0;34m" "[0m[0;34m.[0m[0mjoin[0m[0;34m([0m[0mparts[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    735 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 736 [0;31m        [0;32mif[0m [0msubselect[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    737 [0;31m            [0;32mreturn[0m [0;34m"("[0m [0;34m+[0m [0mselect[0m [0;34m+[0m [0;34m")"[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    738 [0;31m        [0;32mreturn[0m [0mselect[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> c
(SELECT  sr_customer_sk AS ctr_customer_sk,
	sr_store_sk AS ctr_store_sk,
	Sum(sr_return_amt) AS ctr_total_return 
 FROM store_returns
INNER JOIN date_dim
        ON store_returns.sr_customer_sk = date_dim.d_date_sk 
 WHERE  sr_returned_

In [None]:
# Read the `select` attribute off the extraction result from the earlier cell.
# NOTE(review): this cell has no execution count, and the recorded run of the
# extraction cell ended in InvalidSQLException — on a fresh kernel
# `query_dependencies` stays unbound here until that failure is resolved.
dependencies = query_dependencies.select