In [6]:
"""
Tests for building nodes and extracting dependencies
"""
from typing import Optional

# pylint: disable=too-many-lines
import pytest
from sqlmodel import Session, select

from dj.construction.extract import (
    extract_dependencies,
    extract_dependencies_from_query,
)
from dj.errors import DJError, DJException, ErrorCode
from dj.models import Column
from dj.models.node import Node, NodeType
from dj.sql.parsing.ast import Alias, BinaryOp, BinaryOpKind
from dj.sql.parsing.ast import Column as ASTColumn
from dj.sql.parsing.ast import (
    From,
    Join,
    JoinKind,
    Name,
    Namespace,
    Select,
    String,
    Table,
)
from dj.sql.parsing.backends.sqloxide import parse
from dj.typing import ColumnType

In [8]:
from dj.utils import get_session
session = next(get_session())

In [9]:
purchases = Node(
    name="purchases",
    type=NodeType.SOURCE,
    columns=[
        Column(name="transaction_id", type=ColumnType.INT),
        Column(name="transaction_time", type=ColumnType.DATETIME),
        Column(name="transaction_amount", type=ColumnType.FLOAT),
        Column(name="customer_id", type=ColumnType.INT),
        Column(name="system_version", type=ColumnType.STR),
    ],
)
customer_events = Node(
    name="customer_events",
    type=NodeType.SOURCE,
    columns=[
        Column(name="event_id", type=ColumnType.INT),
        Column(name="event_time", type=ColumnType.DATETIME),
        Column(name="event_type", type=ColumnType.STR),
        Column(name="customer_id", type=ColumnType.INT),
        Column(name="message", type=ColumnType.STR),
    ],
)
returns = Node(
    name="returns",
    type=NodeType.SOURCE,
    columns=[
        Column(name="transaction_id", type=ColumnType.INT),
        Column(name="transaction_time", type=ColumnType.DATETIME),
        Column(name="purchase_transaction_id", type=ColumnType.INT),
    ],
)

eligible_purchases = Node(
    name="eligible_purchases",
    query="""
            SELECT transaction_id, transaction_time, transaction_amount, customer_id, system_version
            FROM purchases
            WHERE transaction_amount > 100.0
        """,
    type=NodeType.TRANSFORM,
    columns=[
        Column(name="transaction_id", type=ColumnType.INT),
        Column(name="transaction_time", type=ColumnType.DATETIME),
        Column(name="transaction_amount", type=ColumnType.FLOAT),
        Column(name="customer_id", type=ColumnType.INT),
        Column(name="system_version", type=ColumnType.STR),
    ],
)

eligible_purchases_new_system = Node(
    name="eligible_purchases_new_system",
    query="""
            SELECT transaction_id, transaction_time, transaction_amount, customer_id
            FROM eligible_purchases
            WHERE transaction_amount > 100.0
            AND system_version = 'v2'
        """,
    type=NodeType.TRANSFORM,
    columns=[
        Column(name="transaction_id", type=ColumnType.INT),
        Column(name="transaction_time", type=ColumnType.DATETIME),
        Column(name="transaction_amount", type=ColumnType.FLOAT),
        Column(name="customer_id", type=ColumnType.INT),
        Column(name="system_version", type=ColumnType.STR),
    ],
)

returned_transactions = Node(
    name="returned_transactions",
    query="""
            SELECT transaction_id, transaction_time, transaction_amount, customer_id
            FROM purchases p
            LEFT JOIN returns r
            ON p.transaction_id = r.purchase_transaction_id
            WHERE r.purchase_transaction_id is not null
        """,
    type=NodeType.TRANSFORM,
    columns=[
        Column(name="transaction_id", type=ColumnType.INT),
        Column(name="transaction_time", type=ColumnType.DATETIME),
        Column(name="transaction_amount", type=ColumnType.FLOAT),
        Column(name="customer_id", type=ColumnType.INT),
    ],
)

event_type = Node(
    name="event_type",
    type=NodeType.DIMENSION,
    query="SELECT DISTINCT event_type FROM customer_events",
    columns=[
        Column(name="event_type", type=ColumnType.STR),
    ],
)
event_type_id = Node(
    name="event_type_id",
    type=NodeType.DIMENSION,
    query="SELECT DISTINCT event_id, event_type FROM customer_events",
    columns=[
        Column(name="event_id", type=ColumnType.INT),
        Column(name="event_type", type=ColumnType.STR),
    ],
)
customer_events2 = Node(
    name="customer_events2",
    type=NodeType.SOURCE,
    columns=[
        Column(
            name="event_id",
            type=ColumnType.INT,
            dimension=event_type_id,
            dimension_column="event_id",
        ),
        Column(name="event_time", type=ColumnType.DATETIME),
        Column(name="event_type", type=ColumnType.STR),
        Column(name="customer_id", type=ColumnType.INT),
        Column(name="message", type=ColumnType.STR),
    ],
)

session.add(purchases)
session.add(customer_events)
session.add(returns)
session.add(eligible_purchases)
session.add(eligible_purchases_new_system)
session.add(returned_transactions)
session.add(event_type)
session.add(customer_events2)
session.add(event_type_id)

In [1]:
from dj.construction.build import build_node
from dj.models.node import Node, NodeType
from dj.utils import get_session
from sqlalchemy import select

session = next(get_session())



In [2]:
node = list(session.exec(select(Node).filter(Node.type == NodeType.TRANSFORM)))[1][0]

node.name

IndexError: list index out of range

In [5]:
print(node.query)

NameError: name 'node' is not defined

In [5]:
node.query = """
SELECT c.id,
       c.first_name,
       c.last_name,
       COUNT(1) AS order_cnt
FROM dbt.source.jaffle_shop.orders o
JOIN dbt.source.jaffle_shop.customers c ON o.user_id = c.id
GROUP BY c.id,
         c.first_name,
         c.last_name
"""

<IPython.core.display.Javascript object>

In [6]:
query, db = await build_node(session, node)

print(db)
print()
print(str(query))

async_=False extra_params={'connect_args': {'sslmode': 'prefer'}} updated_at=datetime.datetime(2022, 12, 27, 18, 42, 36, 239735) description='A Postgres database' read_only=False uuid=UUID('f8f6a72f-adca-4f99-9791-a18bd0fb30b8') name='postgres' created_at=datetime.datetime(2022, 12, 27, 18, 42, 36, 237177) id=2 URI='postgresql://username:FoolishPassword@postgres_examples:5432/examples' cost=10.0

SELECT  c.id,
	c.first_name,
	c.last_name,
	COUNT(1) AS order_cnt 
 FROM jaffle_shop.orders AS o
INNER JOIN jaffle_shop.customers AS c
        ON o.user_id = c.id 
 GROUP BY  c.id, c.first_name, c.last_name


<IPython.core.display.Javascript object>

In [7]:
query, db = await build_node(session, node, database_id=3)

print(db)
print()
print(str(query))

Exception: The requested database with id 3 cannot run this query.

<IPython.core.display.Javascript object>

In [8]:
query, db = await build_node(session, node, database_id=2)

print(db)
print()
print(str(query))

async_=False extra_params={'connect_args': {'sslmode': 'prefer'}} updated_at=datetime.datetime(2022, 12, 27, 18, 42, 36, 239735) description='A Postgres database' read_only=False uuid=UUID('f8f6a72f-adca-4f99-9791-a18bd0fb30b8') name='postgres' created_at=datetime.datetime(2022, 12, 27, 18, 42, 36, 237177) id=2 URI='postgresql://username:FoolishPassword@postgres_examples:5432/examples' cost=10.0

SELECT  c.id,
	c.first_name,
	c.last_name,
	COUNT(1) AS order_cnt 
 FROM jaffle_shop.orders AS o
INNER JOIN jaffle_shop.customers AS c
        ON o.user_id = c.id 
 GROUP BY  c.id, c.first_name, c.last_name


<IPython.core.display.Javascript object>

In [9]:
node = list(session.exec(select(Node).filter(Node.type == NodeType.TRANSFORM)))[0][0]

node.name

'basic.transform.country_agg'

<IPython.core.display.Javascript object>

In [10]:
query, db = await build_node(session, node)

print(db)
print()
print(str(query))

async_=False extra_params={'connect_args': {'sslmode': 'prefer'}} updated_at=datetime.datetime(2022, 12, 27, 18, 42, 36, 239735) description='A Postgres database' read_only=False uuid=UUID('f8f6a72f-adca-4f99-9791-a18bd0fb30b8') name='postgres' created_at=datetime.datetime(2022, 12, 27, 18, 42, 36, 237177) id=2 URI='postgresql://username:FoolishPassword@postgres_examples:5432/examples' cost=10.0

SELECT  basic.dim_users.country,
	COUNT(basic.dim_users.id) AS num_users 
 FROM basic.dim_users
 
 GROUP BY  1


<IPython.core.display.Javascript object>

In [11]:
query, db = await build_node(session, node, database_id=2)

print(db)
print()
print(str(query))

async_=False extra_params={'connect_args': {'sslmode': 'prefer'}} updated_at=datetime.datetime(2022, 12, 27, 18, 42, 36, 239735) description='A Postgres database' read_only=False uuid=UUID('f8f6a72f-adca-4f99-9791-a18bd0fb30b8') name='postgres' created_at=datetime.datetime(2022, 12, 27, 18, 42, 36, 237177) id=2 URI='postgresql://username:FoolishPassword@postgres_examples:5432/examples' cost=10.0

SELECT  basic.dim_users.country,
	COUNT(basic.dim_users.id) AS num_users 
 FROM basic.dim_users
 
 GROUP BY  1


<IPython.core.display.Javascript object>

In [12]:
query, db = await build_node(session, node, database_id=3)

print(db)
print()
print(str(query))

async_=False extra_params={'catalog': {'comments': 'https://docs.google.com/spreadsheets/d/1SkEZOipqjXQnxHLMr2kZ7Tbn7OiHSgO99gOCS5jTQJs/edit#gid=1811447072', 'users': 'https://docs.google.com/spreadsheets/d/1SkEZOipqjXQnxHLMr2kZ7Tbn7OiHSgO99gOCS5jTQJs/edit#gid=0'}} updated_at=datetime.datetime(2022, 12, 27, 18, 42, 36, 267317) description='A Google Sheets connector' read_only=True uuid=UUID('af47f3f9-99f6-456b-9d17-1d6641f667b7') name='gsheets' created_at=datetime.datetime(2022, 12, 27, 18, 42, 36, 266959) id=3 URI='gsheets://' cost=100.0

SELECT  users.country,
	COUNT(users.id) AS num_users 
 FROM users
 
 GROUP BY  1


<IPython.core.display.Javascript object>