In [1]:
from dj.construction.extract import extract_dependencies
from dj.construction.inference import get_type_of_expression
from dj.sql.parsing import ast

In [34]:
from dj.utils import get_session
from sqlmodel import select
from dj import models
# get_session() is a generator. Keep a reference to it: if the generator object
# is a temporary, it is finalized right after next() returns, and any cleanup
# written after its `yield` runs immediately.
# NOTE(review): if get_session wraps the session in a context manager, that
# cleanup would close the session straight away — confirm in dj.utils.
# Call `session_gen.close()` when finished to release the session explicitly.
session_gen = get_session()
session = next(session_gen)

In [3]:
# Extract the DJ node dependencies of an aggregation over the
# `purchases_over_a_grand` transform. With raise_=False, unresolved table
# references are (presumably) collected in `danglers` rather than raised.
tree, deps, danglers = extract_dependencies(
    session,
    """
    select Sum(revenue) from purchases_over_a_grand where revenue>1000.0 group by country
    """,
    raise_=False,
)

# Inspect inferred types: every column, what each column points at, and each projection.
print(
    "All Column types ",
    [(col, col.type) for col in tree.find_all(ast.Column)],
    end="\n\n",
)
print(
    "Column Expressions",
    [(col, col.expression) for col in tree.find_all(ast.Column)],
    end="\n\n",
)
print(
    "Projection Exp types ",
    [(exp, get_type_of_expression(exp)) for exp in tree.select.projection],
    end="\n\n",
)
deps, danglers

All Column types  [(Column(name=Name(name='country', quote_style=''), namespace=None), <ColumnType.STR: 'STR'>), (Column(name=Name(name='revenue', quote_style=''), namespace=None), <ColumnType.FLOAT: 'FLOAT'>), (Column(name=Name(name='revenue', quote_style=''), namespace=None), <ColumnType.FLOAT: 'FLOAT'>)]

Column Expressions [(Column(name=Name(name='country', quote_style=''), namespace=None), None), (Column(name=Name(name='revenue', quote_style=''), namespace=None), None), (Column(name=Name(name='revenue', quote_style=''), namespace=None), None)]

Projection Exp types  [(Function(name=Name(name='Sum', quote_style=''), namespace=Namespace(names=[]), args=[Column(name=Name(name='revenue', quote_style=''), namespace=None)]), <ColumnType.FLOAT: 'FLOAT'>)]



({Node(created_at=datetime.datetime(2023, 1, 12, 14, 19, 12, 301204), updated_at=datetime.datetime(2023, 1, 12, 14, 19, 12, 301210), id=23, description='', name='purchases_over_a_grand', type=<NodeType.TRANSFORM: 'transform'>, query='select country, revenue from revenue where revenue>1000.0', columns=[Column(type=<ColumnType.STR: 'STR'>, dimension_id=None, name='country', id=129, dimension_column=None), Column(type=<ColumnType.FLOAT: 'FLOAT'>, dimension_id=None, name='revenue', id=130, dimension_column=None)])},
 set())

In [4]:
from dj.sql.parsing.backends.sqloxide import parse

from dj.construction.compile import compile_query, get_dj_node

from dj.construction.build import build_query

In [None]:
# Build a statement that selects a single node by name. The notebook only
# imports the `models` package, so Node is reached through it.
# `node_name` must be set for the cell to run on a fresh kernel.
node_name = "purchases_over_a_grand"  # example: a transform node seen in earlier outputs
select(models.Node).filter(models.Node.name == node_name)

In [46]:
# All nodes of type DIMENSION currently stored in the session's database.
session.exec(
    select(models.Node).where(models.Node.type == models.node.NodeType.DIMENSION)
).all()

[Node(created_at=datetime.datetime(2022, 12, 27, 18, 42, 38, 434993), updated_at=datetime.datetime(2022, 12, 27, 18, 42, 38, 434997), id=5, description='User dimension', name='basic.dimension.users', type=<NodeType.DIMENSION: 'dimension'>, query='SELECT id,\n       full_name,\n       age,\n       country,\n       gender,\n       preferred_language,\n       secret_number\nFROM basic.source.users'),
 Node(created_at=datetime.datetime(2022, 12, 27, 18, 42, 38, 475059), updated_at=datetime.datetime(2022, 12, 27, 18, 42, 38, 475063), id=6, description='User dimension', name='dbt.dimension.customers', type=<NodeType.DIMENSION: 'dimension'>, query='SELECT id,\n       first_name,\n       last_name\nFROM dbt.source.jaffle_shop.customers', columns=[Column(type=<ColumnType.INT: 'INT'>, dimension_id=None, name='id', id=38, dimension_column=None), Column(type=<ColumnType.STR: 'STR'>, dimension_id=None, name='first_name', id=39, dimension_column=None), Column(type=<ColumnType.STR: 'STR'>, dimension_

In [20]:
from copy import deepcopy

In [57]:
# Parse a query over the orders source that groups by a column living on a
# dimension node (dbt.dimension.customers), not on the table in FROM.
query = parse("""
    select Count(id) from dbt.source.jaffle_shop.orders group by dbt.dimension.customers.first_name
    """)

In [58]:
# SQL text of the freshly parsed AST, before compile_query resolves references.
print(query)

SELECT  Count(id) 
 FROM dbt.source.jaffle_shop.orders
 
 GROUP BY  dbt.dimension.customers.first_name


In [59]:
# Snapshot the parsed AST before compilation. The before/after prints show that
# compile_query rewrites `query` itself, so a deep copy is needed for comparison.
q2 = deepcopy(query)

In [60]:
print(q2)  # pre-compile snapshot

SELECT  Count(id) 
 FROM dbt.source.jaffle_shop.orders
 
 GROUP BY  dbt.dimension.customers.first_name


In [61]:
# Resolve the table/column references in `query` against DJ nodes in the session.
# Judging by the pre/post-compile prints, this mutates `query` in place (columns
# become qualified). The return value is kept for build_query.
query_deps = compile_query(session, query)

In [62]:
# Snapshot of the AST after compilation, kept separately for inspection.
q3 = deepcopy(query)

In [63]:
print(q3)  # post-compile snapshot
# Column references are now qualified, e.g. Count(id) -> Count(dbt.source.jaffle_shop.orders.id).
# NOTE(review): the GROUP BY column prints as `dimension.customers.dbt.first_name`
# although the source wrote `dbt.dimension.customers.first_name` -- the namespace
# parts look rotated. TODO: check compile_query / the column's str() rendering.

SELECT  Count(dbt.source.jaffle_shop.orders.id) 
 FROM dbt.source.jaffle_shop.orders
 
 GROUP BY  dimension.customers.dbt.first_name


In [67]:
from itertools import chain

In [28]:
# Presumably assembles the final SQL from the compiled AST and its dependencies.
# NOTE(review): this cell's saved execution count (28) is lower than the
# compile_query cell's (61), so its saved state came from an earlier query.
# Re-run the notebook top to bottom.
build_query(session, query, query_deps)

In [68]:
# The GROUP BY column's table resolves to a DIMENSION node. That table is not in
# the FROM clause, so it would have to be joined in when the query is built.
[
    col.table.dj_node.type
    for exp in q3.select.group_by
    for col in exp.find_all(ast.Column)
]

[<NodeType.DIMENSION: 'dimension'>]

In [48]:
# After compilation, the FROM table records which of its columns the query referenced.
# NOTE(review): the saved output (country, revenue) is from execution [48], earlier
# than the [62] snapshot of the orders query, whose SQL references `id`.
# The output is stale; re-run to refresh.
q3.select.from_.tables[0].columns

{Column(name=Name(name='country', quote_style=''), namespace=None),
 Column(name=Name(name='revenue', quote_style=''), namespace=None)}

In [32]:
# NOTE(review): the saved output is from execution [32], which used an earlier
# purchases_over_a_grand query, and it looks malformed:
# - column references were replaced by the whole `(SELECT ...) AS _purchases_over_a_grand` subquery;
# - WHERE uses `purchases_over_a_grand.revenue` while the FROM alias is `_purchases_over_a_grand`.
# TODO: investigate the substitution done by build_query.
print(query)

SELECT  Sum((SELECT  country,
	revenue 
 FROM revenue
 
 WHERE  revenue > 1000.0 
) AS _purchases_over_a_grand.revenue) 
 FROM (SELECT  country,
	revenue 
 FROM revenue
 
 WHERE  revenue > 1000.0 
) AS _purchases_over_a_grand
 
 WHERE  purchases_over_a_grand.revenue > 1000.0 
 GROUP BY  (SELECT  country,
	revenue 
 FROM revenue
 
 WHERE  revenue > 1000.0 
) AS _purchases_over_a_grand.country


In [15]:
# Table that the WHERE clause's left-hand column was resolved to.
# NOTE(review): `_table` is a private attribute. The group-by cell uses the public
# `.table`, which is probably the intended accessor here too -- confirm and switch.
query.select.where.left._table

Table(name=Name(name='purchases_over_a_grand', quote_style=''), namespace=None)

In [7]:
# Wrap the first projection (the Sum(revenue) call) in an alias named `sum_revenue`.
# `from_` and `to` are consumed by the replace cell that follows.
from_=tree.select.projection[0]
to=ast.Alias(ast.Name('sum_revenue'), child=from_)

In [8]:
# Swap the original projection node for its aliased wrapper. The displayed result
# is a Select whose projection contains the `sum_revenue` alias.
# NOTE(review): unclear from here whether `replace` mutates `tree.select` in place
# or only returns a modified copy -- confirm before relying on `tree` afterwards.
tree.select.replace(from_, to)

Select(from_=From(tables=[Table(name=Name(name='purchases_over_a_grand', quote_style=''), namespace=None)], joins=[]), group_by=[Column(name=Name(name='country', quote_style=''), namespace=None)], having=None, projection=[Alias(name=Name(name='sum_revenue', quote_style=''), namespace=None, child=Function(name=Name(name='Sum', quote_style=''), namespace=Namespace(names=[]), args=[Column(name=Name(name='revenue', quote_style=''), namespace=None)]))], where=BinaryOp(op=BinaryOpKind.Gt, left=Column(name=Name(name='revenue', quote_style=''), namespace=None), right=Number(value=1000.0)), limit=None, distinct=False)

In [2]:
# Render the AST back to SQL text.
# NOTE(review): the saved NameError comes from a fresh kernel where the
# extract_dependencies cell had not run yet. Run cells top to bottom.
print(tree)

NameError: name 'tree' is not defined

In [4]:
def name_exp(exp: ast.Expression, default: str) -> str:
    """Return the output column name for a projection expression.

    An ``ast.Alias`` is named by its alias. When the alias carries a namespace,
    the namespace parts are prepended, dot-joined. Any other expression gets
    ``default``.

    Args:
        exp: a projection expression from a parsed query.
        default: name to use when ``exp`` is not an ``ast.Alias``.

    Returns:
        The (possibly namespace-qualified) alias name, or ``default``.
    """
    if not isinstance(exp, ast.Alias):
        return default
    # `make_name` is not defined or imported anywhere in this notebook, so
    # build the dotted name here instead of raising NameError.
    parts = []
    if exp.namespace is not None:
        # NOTE(review): assumes Namespace.names holds Name objects (plain strings also work).
        parts.extend(getattr(part, "name", part) for part in exp.namespace.names)
    parts.append(exp.name.name)
    return ".".join(parts)

In [5]:
# Map each projection to an output name (its alias, or positional `_col{i}`)
# and its inferred column type.
dict(
    (name_exp(exp, f"_col{i}"), get_type_of_expression(exp))
    for i, exp in enumerate(tree.select.projection)
)

{'_col0': <ColumnType.FLOAT: 'FLOAT'>}

In [6]:
# Same aggregation with Count(*), to check type inference when the function
# argument is a wildcard rather than a column.
tree, deps, danglers = extract_dependencies(
    session,
    """
    select Count(*) from purchases_over_a_grand where revenue>1000.0 group by country
    """,
    raise_=False,
)

print(
    "All Column types ",
    [(col, col.type) for col in tree.find_all(ast.Column)],
    end="\n\n",
)
print(
    "Column Expressions",
    [(col, col.expression) for col in tree.find_all(ast.Column)],
    end="\n\n",
)
print(
    "Projection Exp types ",
    [(exp, get_type_of_expression(exp)) for exp in tree.select.projection],
    end="\n\n",
)
deps, danglers

All Column types  [(Column(name=Name(name='country', quote_style=''), namespace=None), <ColumnType.STR: 'STR'>), (Column(name=Name(name='revenue', quote_style=''), namespace=None), <ColumnType.FLOAT: 'FLOAT'>)]

Column Expressions [(Column(name=Name(name='country', quote_style=''), namespace=None), None), (Column(name=Name(name='revenue', quote_style=''), namespace=None), None)]

Projection Exp types  [(Function(name=Name(name='Count', quote_style=''), namespace=Namespace(names=[]), args=[Wildcard()]), <ColumnType.INT: 'INT'>)]



({Node(name='purchases_over_a_grand', type=<NodeType.TRANSFORM: 'transform'>, id=23, query='select country, revenue from revenue where revenue>1000.0', created_at=datetime.datetime(2023, 1, 12, 14, 19, 12, 301204), updated_at=datetime.datetime(2023, 1, 12, 14, 19, 12, 301210), description='', columns=[Column(id=129, name='country', dimension_column=None, dimension_id=None, type=<ColumnType.STR: 'STR'>), Column(id=130, name='revenue', dimension_column=None, dimension_id=None, type=<ColumnType.FLOAT: 'FLOAT'>)])},
 set())

In [7]:
# Put the aggregation in a subquery and select its alias from outside. The outer
# `sum_revenue` column should resolve to the inner Alias expression.
tree, deps, danglers = extract_dependencies(
    session,
    """
    select sum_revenue from (select Sum(revenue) as sum_revenue from purchases_over_a_grand where revenue>1000.0 group by country)
    """,
    raise_=False,
)

print(
    "All Column types ",
    [(col, col.type) for col in tree.find_all(ast.Column)],
    end="\n\n",
)
print(
    "Column Expressions",
    [(col, col.expression) for col in tree.find_all(ast.Column)],
    end="\n\n",
)
print(
    "Projection Exp types ",
    [(exp, get_type_of_expression(exp)) for exp in tree.select.projection],
    end="\n\n",
)
deps, danglers

All Column types  [(Column(name=Name(name='country', quote_style=''), namespace=None), <ColumnType.STR: 'STR'>), (Column(name=Name(name='revenue', quote_style=''), namespace=None), <ColumnType.FLOAT: 'FLOAT'>), (Column(name=Name(name='revenue', quote_style=''), namespace=None), <ColumnType.FLOAT: 'FLOAT'>), (Column(name=Name(name='sum_revenue', quote_style=''), namespace=None), None)]

Column Expressions [(Column(name=Name(name='country', quote_style=''), namespace=None), None), (Column(name=Name(name='revenue', quote_style=''), namespace=None), None), (Column(name=Name(name='revenue', quote_style=''), namespace=None), None), (Column(name=Name(name='sum_revenue', quote_style=''), namespace=None), Alias(name=Name(name='sum_revenue', quote_style=''), namespace=None, child=Function(name=Name(name='Sum', quote_style=''), namespace=Namespace(names=[]), args=[Column(name=Name(name='revenue', quote_style=''), namespace=None)])))]

Projection Exp types  [(Column(name=Name(name='sum_revenue', q

({Node(name='purchases_over_a_grand', type=<NodeType.TRANSFORM: 'transform'>, id=23, query='select country, revenue from revenue where revenue>1000.0', created_at=datetime.datetime(2023, 1, 12, 14, 19, 12, 301204), updated_at=datetime.datetime(2023, 1, 12, 14, 19, 12, 301210), description='', columns=[Column(id=129, name='country', dimension_column=None, dimension_id=None, type=<ColumnType.STR: 'STR'>), Column(id=130, name='revenue', dimension_column=None, dimension_id=None, type=<ColumnType.FLOAT: 'FLOAT'>)])},
 set())

In [8]:
# Same subquery, plus a reference to an unknown table `oops`. With raise_=False,
# that name is reported in `danglers` instead of stopping extraction.
tree, deps, danglers = extract_dependencies(
    session,
    """
    select sum_revenue from (select Sum(revenue) as sum_revenue from purchases_over_a_grand where revenue>1000.0 group by country) as a, oops
    """,
    raise_=False,
)

print(
    "All Column types ",
    [(col, col.type) for col in tree.find_all(ast.Column)],
    end="\n\n",
)
print(
    "Column Expressions",
    [(col, col.expression) for col in tree.find_all(ast.Column)],
    end="\n\n",
)
print(
    "Projection Exp types ",
    [(exp, get_type_of_expression(exp)) for exp in tree.select.projection],
    end="\n\n",
)
deps, danglers

All Column types  [(Column(name=Name(name='country', quote_style=''), namespace=None), <ColumnType.STR: 'STR'>), (Column(name=Name(name='revenue', quote_style=''), namespace=None), <ColumnType.FLOAT: 'FLOAT'>), (Column(name=Name(name='revenue', quote_style=''), namespace=None), <ColumnType.FLOAT: 'FLOAT'>), (Column(name=Name(name='sum_revenue', quote_style=''), namespace=None), None)]

Column Expressions [(Column(name=Name(name='country', quote_style=''), namespace=None), None), (Column(name=Name(name='revenue', quote_style=''), namespace=None), None), (Column(name=Name(name='revenue', quote_style=''), namespace=None), None), (Column(name=Name(name='sum_revenue', quote_style=''), namespace=None), Alias(name=Name(name='sum_revenue', quote_style=''), namespace=None, child=Function(name=Name(name='Sum', quote_style=''), namespace=Namespace(names=[]), args=[Column(name=Name(name='revenue', quote_style=''), namespace=None)])))]

Projection Exp types  [(Column(name=Name(name='sum_revenue', q

({Node(name='purchases_over_a_grand', type=<NodeType.TRANSFORM: 'transform'>, id=23, query='select country, revenue from revenue where revenue>1000.0', created_at=datetime.datetime(2023, 1, 12, 14, 19, 12, 301204), updated_at=datetime.datetime(2023, 1, 12, 14, 19, 12, 301210), description='', columns=[Column(id=129, name='country', dimension_column=None, dimension_id=None, type=<ColumnType.STR: 'STR'>), Column(id=130, name='revenue', dimension_column=None, dimension_id=None, type=<ColumnType.FLOAT: 'FLOAT'>)])},
 {'oops'})