In [1]:
%load_ext nb_black

<IPython.core.display.Javascript object>

In [2]:
from dj.construction.compile import compile_query, compile_select, compile_node
from dj.construction.extract import extract_dependencies
from dj.construction.build import build_query
from dj.construction.exceptions import CompoundBuildException
from dj.sql.parsing.backends.sqloxide import parse
from dj.sql.parsing import ast
from dj.models.node import Node, NodeType
from dj.models.column import Column, ColumnType
from dj.models.database import Database
from dj.sql.dag import get_cheapest_online_database
from dj.utils import get_session

session = next(get_session())

<IPython.core.display.Javascript object>

In [3]:
from typing import Set, List, Dict, Optional, Tuple
from functools import reduce, lru_cache
from sqlalchemy import select
from sqlmodel import Session

<IPython.core.display.Javascript object>

In [4]:
query = """
    select Count(id) as count_id from dbt.source.jaffle_shop.orders group by dbt.dimension.customers.first_name
    """

<IPython.core.display.Javascript object>

In [5]:
node = list(session.exec(select(Node).filter(Node.type == NodeType.TRANSFORM)))[0][0]

<IPython.core.display.Javascript object>

In [6]:
def get_node_materialized_databases(
    node: Node,
    columns: Set[str],
) -> Set[Database]:
    """
    Return all the databases where the node is explicitly materialized
    """
    tables = [
        table
        for table in node.tables
        if columns <= {column.name for column in table.columns}
    ]
    return {table.database for table in tables}

BuildPlan = Tuple[ast.Query, Dict[Node, Tuple[Set[Database], "BuildPlan"]]]

def generate_build_plan(node: Node, dialect: Optional[str] = None) -> BuildPlan:
    """
    Returns the list of optimal databases in order of decreasing goodness

    gooder firster
    """

    tree, deps, _ = extract_dependencies(session, node.query, dialect)
    databases = {}
    for node, tables in deps.items():
        columns = {col.name.name for table in tables for col in table.columns}

        node_mat_dbs = get_node_materialized_databases(node, columns)
        build_plan = None
        if node.type != NodeType.SOURCE:
            build_plan = generate_build_plan(node, dialect)
        databases[node] = (node_mat_dbs, build_plan)

    return tree, databases

def _level_database(bp: BuildPlan, levels: List[List[Set[Database]]], level: int = 0):
    if levels is None:
        levels = []
    tree, sbp = bp
    nodes = set((node for node, _ in sbp.items()))
    dbi = reduce(lambda a, b: a & b, (dbs for _, (dbs, _) in sbp.items()))

    while level >= len(levels):
        levels.append([])
    levels[level].append(dbi)

    for _, (_, ssbp) in sbp.items():
        if ssbp:
            level_database(ssbp, levels, level + 1)
    return levels

async def optimize_level_by_cost(bp: BuildPlan) -> Tuple[int, Database]:
    levels = []
    _level_database(bp, levels)
    some_db = False
    cheapest_levels = []
    for level in levels:
        try:
            cheapest_levels.append(
                await get_cheapest_online_database(reduce(lambda a, b: a & b, level))
            )
            some_db = True
        except Exception as exc:
            if "No active database found" in str(exc):
                cheapest_levels.append(None)
            else:
                raise exc

    if not some_db:
        raise Exception("No database found")

    return sorted(
        ((i, cl) for i, cl in enumerate(cheapest_levels)),
        key=lambda icl: icl[1].cost if icl[1] else float("-inf"),
    )[0]

async def optimize_level_by_database_id(bp: BuildPlan, database_id: int) -> Tuple[int, Database]:
    levels = []
    _level_database(bp, levels)
    some_db = False
    combined_levels = [reduce(lambda a, b: a & b, level) for level in levels]
    for i, level in enumerate(combined_levels):
        for database in level:
            if database.id == database_id and await database.do_ping():
                return i, database
    raise Exception(f"The requested database with id {database_id} cannot run this query.")


<IPython.core.display.Javascript object>

In [8]:
bp = generate_build_plan(node)

<IPython.core.display.Javascript object>

In [9]:
await optimize_level_by_cost(bp)

(0,
 Database(async_=False, extra_params={'connect_args': {'sslmode': 'prefer'}}, updated_at=datetime.datetime(2022, 12, 27, 18, 42, 36, 239735), description='A Postgres database', read_only=False, uuid=UUID('f8f6a72f-adca-4f99-9791-a18bd0fb30b8'), name='postgres', created_at=datetime.datetime(2022, 12, 27, 18, 42, 36, 237177), id=2, URI='postgresql://username:FoolishPassword@postgres_examples:5432/examples', cost=10.0))

<IPython.core.display.Javascript object>

In [10]:
await optimize_level_by_database_id(bp, 3)

(0,
 Database(async_=False, extra_params={'catalog': {'comments': 'https://docs.google.com/spreadsheets/d/1SkEZOipqjXQnxHLMr2kZ7Tbn7OiHSgO99gOCS5jTQJs/edit#gid=1811447072', 'users': 'https://docs.google.com/spreadsheets/d/1SkEZOipqjXQnxHLMr2kZ7Tbn7OiHSgO99gOCS5jTQJs/edit#gid=0'}}, updated_at=datetime.datetime(2022, 12, 27, 18, 42, 36, 267317), description='A Google Sheets connector', read_only=True, uuid=UUID('af47f3f9-99f6-456b-9d17-1d6641f667b7'), name='gsheets', created_at=datetime.datetime(2022, 12, 27, 18, 42, 36, 266959), id=3, URI='gsheets://', cost=100.0))

<IPython.core.display.Javascript object>

In [None]:
"""Functions to add to an ast DJ node queries"""

from dataclasses import dataclass, field
from itertools import chain
from functools import reduce
from typing import Dict, Generator, Iterable, List, Optional, Set, Tuple, Union, cast

from sqlmodel import Session

from dj.errors import DJError, DJException, ErrorCode
from dj.models.node import Node, NodeType
from dj.models.column import Column
from dj.models.database import Database
from dj.sql.parsing import ast
from dj.sql.parsing.backends.sqloxide import parse
from dj.construction.compile import compile_query
from functools import reduce
from string import ascii_letters, digits

ACCEPTABLE_CHARS = set(ascii_letters + digits + "_")
LOOKUP_CHARS = {
    ".": "DOT",
    "'": "QUOTE",
    '"': "DQUOTE",
    "`": "BTICK",
    "!": "EXCL",
    "@": "AT",
    "#": "HASH",
    "$": "DOLLAR",
    "%": "PERC",
    "^": "CARAT",
    "&": "AMP",
    "*": "STAR",
    "(": "LPAREN",
    ")": "RPAREN",
    "[": "LBRACK",
    "]": "RBRACK",
    "-": "MINUS",
    "+": "PLUS",
    "=": "EQ",
}


def amenable_name(name: str) -> str:
    """Takes a string and makes it have only alphanumerics/_"""
    ret = []
    cont = []
    for c in name:
        if c in ACCEPTABLE_CHARS:
            cont.append(c)
        elif c in LOOKUP_CHARS:
            ret.append("".join(cont))
            ret.append(LOOKUP_CHARS[c])
            cont = []
        else:
            ret.append("".join(cont))
            ret.append("_")
            cont = []

    return "_".join(ret) + "_" + "".join(cont)


# flake8: noqa: C901
def build_select(
    session: Session,
    select: ast.Select,
    build_plan: BuildPlan,
    build_plan_depth: int,
    database: Database,
    dialect: Optional[str] = None,    
) -> ast.Select:
    """transforms a select ast by replacing dj node references with their asts"""
    dimension_columns: Dict[Node, List[ast.Column]] = {}
    tables: Dict[Node, List[ast.Table]] = {}

    for table in select.find_all(ast.Table):
        if node := table.dj_node:
            if node.type!=NodeType.SOURCE:
                tables[node] = tables.get(node) or []
                tables[node].append(table)
        else:
            raise Exception(f"To build, Table {table} requires a pointer to a DJ Node.")

    for col in select.find_all(ast.Column):
        if isinstance(col.table, ast.Table):
            if node := col.table.dj_node:
                if node.type == NodeType.DIMENSION:
                    dimension_columns[node] = dimension_columns.get(node) or []
                    dimension_columns[node].append(col)

    for dim_node, dim_cols in dimension_columns.items():
        if dim_node not in tables:  # need to join dimension
            alias = amenable_name(dim_node.name)
            join_info: Dict[str, Tuple[Node, List[Column]]] = {}
            for table_node in tables:
                join_dim_cols = [
                    col for col in table_node.columns if col.dimension == dim_node
                ]
                join_info[table_node] = join_dim_cols
            dim_query = build_node(session, dim_node, dialect)

            if dim_query.ctes:  # will have to build ctes in as subqueries to the select
                raise Exception("DJ does not currently support ctes here.")

            dim_select = dim_query.select
            dim_ast = ast.Alias(ast.Name(alias), child=dim_select)
            for dim_col in dim_cols:
                dim_col.add_table(dim_select)

            joins: List[ast.Join] = []

            for table_node, cols in join_info.items():
                ast_tables = tables[table_node]
                for table in ast_tables:
                    on = []
                    for col in cols:
                        on.append(
                            ast.BinaryOp(
                                ast.BinaryOpKind.Eq,
                                ast.Column(ast.Name(col.name), _table=table),
                                ast.Column(
                                    ast.Name(col.dimension_column), _table=dim_ast
                                ),
                            ),
                        )
                joins.append(
                    ast.Join(
                        ast.JoinKind.LeftOuter,
                        dim_ast,
                        reduce(
                            lambda l, r: ast.BinaryOp(ast.BinaryOpKind.And, l, r), on
                        ),
                    ),
                )

            select.from_.joins += joins

    for node, tbls in tables.items():
        node_query = build_node(session, node, dialect)

        if (
            node_query.ctes
        ):  # will have to build ctes in as subqueries to the select
            raise Exception("DJ does not currently support ctes here.")

        node_select = node_query.select
        node_ast = ast.Alias(ast.Name(alias), child=node_select)
        for tbl in tbls:
            select.replace(tbl, node_ast)


    #sources
    sources: Dict[Node, List[ast.Table]] = {}

    for table in select.find_all(ast.Table):
        if node := table.dj_node:
            if node.type==NodeType.SOURCE:
                tables[node] = tables.get(node) or []
                tables[node].append(table)
        else:
            raise Exception(f"To build, Table {table} requires a pointer to a DJ Node.")

    source_dbs_intersect = []
    source_db_costs = {}
    for node, tbls in sources.items():
        db_names = set()
        for db_table in node.tables:
            name = db_table.table
            source_db_costs[name] = source_db_costs.get(name, 0) + db_table.cost
            db_names.add(name)
        source_dbs_intersect.append(db_names)

    source_intersect = reduce(lambda a, b: a&b, source_dbs_intersect)

    if not source_intersect:
        raise Exception("Sources do not share a common database.")
    return select


def build_query(
    session: Session,
    query: ast.Query,
    build_plan: BuildPlan,
    build_plan_depth: int,
    database: Database,
    dialect: Optional[str] = None, 
) -> ast.Query:
    """transforms a query ast by replacing dj node references with their asts"""
    select = query.to_select()
    build_select(session, select, dialect)
    for i, exp in enumerate(select.projection):
        if not isinstance(exp, ast.Named):
            name = f"_col{i}"
            aliased = ast.Alias(ast.Name(name), child=exp)
            select.replace(exp, aliased)
    return query


await def build_node(
    session: Session,
    node: Node,
    dialect: Optional[str] = None,
    database_id: Optional[int] = None
) -> ast.Query:
    """transforms a query ast by replacing dj node references with their asts"""
    query, build_plan = generate_build_plan(node, dialect)
    if database_id is not None:
        build_plan_depth, database = optimize_level_by_database_id(build_plan, database_id)
    else:
        build_plan_depth, database = optimize_level_by_cost(build_plan)
    
    query = parse(node.query, dialect)
    return build_query(session, query, build_plan, build_plan_depth, database, dialect)


In [4]:
query = parse(query)

In [5]:
compiled = compile_query(session, query)

In [7]:
list(compiled.find_all(ast.Table, True))

IndexError: list index out of range

In [None]:
def level_database(bp: BuildPlan, levels: List[List[Set[Database]]]):
    tree, sbp = bp
    nodes = set(ssbp for node, _ in sbp)
    ssbps = (ssbp for _, (_, ssbp) in sbp)
    dbi = reduce(lambda a, b: a&b, (dbs for _, (dbs, _) in sbp))
    dbs = sorted(dbi, key=lambda db: db.cost)
    db_select = dbs[0] if dbs else None
    levels.append((nodes, tree, db_select))
    
        
        

In [8]:
CompoundBuildException().errors

[]

In [None]:
def get_computable_databases(query: Query)->Set[Database]:
    tables: Dict[Node, Set[str]] = {}

    for table in query.find_all(ast.Table):
        if node := table.dj_node:
            tables[node] = tables.get(node) or []
            tables[node].append(table)
        else:
            raise Exception(f"To build, Table {table} requires a pointer to a DJ Node.")
    

In [5]:
built = build_query(session, query)

TypeError: reduce() of empty iterable with no initial value

In [6]:
print(str(built))

SELECT  Count(dbt.source.jaffle_shop.orders.id) 
 FROM dbt.source.jaffle_shop.orders
LEFT JOIN (SELECT  dbt.source.jaffle_shop.customers.id,
	dbt.source.jaffle_shop.customers.first_name,
	dbt.source.jaffle_shop.customers.last_name 
 FROM dbt.source.jaffle_shop.customers
 
) AS dbt_DOT_dimension_DOT_customers
        ON dbt.source.jaffle_shop.orders.user_id = dbt_DOT_dimension_DOT_customers.id 
 GROUP BY  first_name


In [7]:
d={}
d['x']=d.get('x', 0)+1

In [8]:
d

{'x': 1}

In [1]:
from dj.construction.extract import extract_dependencies
from dj.construction.inference import get_type_of_expression
from dj.sql.parsing import ast

In [2]:

from sqlmodel import select
from dj import models


In [3]:
from typing import Dict, Generator, Iterable, List, Optional, Set, Tuple, Union, cast

In [4]:
from functools import reduce

In [5]:
from dj.models.node import Node, NodeType

In [6]:
from string import ascii_letters, digits

ACCEPTABLE_CHARS = set(ascii_letters + digits + "_")
LOOKUP_CHARS = {
    ".": "DOT",
    "'": "QUOTE",
    '"': "DQUOTE",
    "`": "BTICK",
    "!": "EXCL",
    "@": "AT",
    "#": "HASH",
    "$": "DOLLAR",
    "%": "PERC",
    "^": "CARAT",
    "&": "AMP",
    "*": "STAR",
    "(": "LPAREN",
    ")": "RPAREN",
    "[": "LBRACK",
    "]": "RBRACK",
    "-": "MINUS",
    "+": "PLUS",
    "=": "EQ",
}


def amenable_name(name: str) -> str:
    """Takes a string and makes it have only alphanumerics/_"""
    ret = []
    cont = []
    for c in name:
        if c in ACCEPTABLE_CHARS:
            cont.append(c)
        elif c in LOOKUP_CHARS:
            ret.append("".join(cont))
            ret.append(LOOKUP_CHARS[c])
            cont = []
        else:
            ret.append("".join(cont))
            ret.append("_")
            cont = []

    return "_".join(ret) + "_" + "".join(cont)

In [7]:
tree, deps, danglers = extract_dependencies(session, """
    select Sum(revenue) from purchases_over_a_grand where revenue>1000.0 group by country
    """, raise_=False)

print("All Column types ", [(c, c.type) for c in list(tree.find_all(ast.Column))])
print()
print("Column Expressions", [(c, c.expression) for c in list(tree.find_all(ast.Column))])
print()
print("Projection Exp types ", [(exp, get_type_of_expression(exp)) for exp in tree.select.projection])
print()
deps, danglers

RecursionError: maximum recursion depth exceeded in comparison

In [8]:
from dj.sql.parsing.backends.sqloxide import parse

from dj.construction.compile import compile_query, get_dj_node

# from dj.construction.build import build_query

In [9]:
from copy import deepcopy

In [10]:
from itertools import chain

In [27]:
query = parse("""
    select Count(id) from dbt.source.jaffle_shop.orders group by dbt.dimension.customers.first_name
    """)

In [28]:
print(str(query))

SELECT  Count(id) 
 FROM dbt.source.jaffle_shop.orders
 
 GROUP BY  dbt.dimension.customers.first_name


In [29]:
q2 = deepcopy(query)

In [30]:
print(str(q2))#precompile

SELECT  Count(id) 
 FROM dbt.source.jaffle_shop.orders
 
 GROUP BY  dbt.dimension.customers.first_name


In [31]:
query_deps = compile_query(session, query)

In [32]:
q3 = deepcopy(query)

In [33]:
print(str(q3))#post compile
# notice there is no ambiguity in column references

SELECT  Count(dbt.source.jaffle_shop.orders.id) 
 FROM dbt.source.jaffle_shop.orders
 
 GROUP BY  dimension.customers.dbt.first_name


In [34]:
# you can see there is a column with a table that references a dj node that is a dimension
# the column's .table, not being found in tables in the FROM would mean it would need to be joined
[c.table.dj_node.type for c in chain(*(exp.find_all(ast.Column) for exp in q3.select.group_by))]

[<NodeType.DIMENSION: 'dimension'>]

In [35]:
from dj.models.column import Column

In [36]:
select = query.select

In [37]:
dimension_columns: Dict[Node, List[ast.Column]] = {}
tables: Dict[Node, List[ast.Table]] = {}

for table in select.find_all(ast.Table):
    node = table.dj_node
    node_type = node.type
    tables[node] = (tables.get(node) or [])
    tables[node].append(table)
        
for col in select.find_all(ast.Column):
    if isinstance(col.table, ast.Table):
        if node:= col.table.dj_node:
            if node.type == NodeType.DIMENSION:
                dimension_columns[node] = (dimension_columns.get(node) or [])
                dimension_columns[node].append(col)

for dim_node, dim_cols in dimension_columns.items():
    if dim_node in tables:
        pass
    else:# need to join dimension
        alias = amenable_name(dim_node.name)
        join_info: Dict[str, Tuple[Node, List[Column]]] = {}
        for table_node in tables:
            join_dim_cols = [col for col in table_node.columns if col.dimension == dim_node]
            join_info[table_node] = join_dim_cols
        dim_query = parse(dim_node.query, dialect)
        
        
        if dim_query.ctes: #will have to build ctes in as subqueries to the select
            raise Exception("DJ does not currently support ctes here.")
        
        dim_select = dim_query.select
        dim_ast = ast.Alias(ast.Name(alias), child=dim_select)
        for dim_col in dim_cols:
            dim_col.add_table(dim_select)
        
        joins: List[ast.Join] = []

        for table_node, cols in join_info.items():
            ast_tables = tables[table_node]
            for table in ast_tables:
                on = []
                for col in cols:
                    on.append(
                        ast.BinaryOp(
                            ast.BinaryOpKind.Eq,
                            ast.Column(ast.Name(col.name), _table=table),
                            ast.Column(
                                ast.Name(col.dimension_column), _table = dim_ast
                            ),
                        ),
                    )
            joins.append(
                ast.Join(
                    ast.JoinKind.LeftOuter,
                    dim_ast,
                    reduce(lambda l, r: ast.BinaryOp(ast.BinaryOpKind.And, l, r), on),
                ),
            )
        
        select.from_.joins += joins
        
for node, tables in tables.items():
    if node.type==NodeType.TRANSFORM:
        tfm_query = parse(node.query, dialect)
        
        
        if tfm_query.ctes: #will have to build ctes in as subqueries to the select
            raise Exception("DJ does not currently support ctes here.")
        
        tfm_select = tfm_query.select
        tfm_ast = ast.Alias(ast.Name(alias), child=tfm_select)
        for table in tables:
            select.replace(table, tfm_ast)

In [41]:
print(str(query))

SELECT  Count(dbt.source.jaffle_shop.orders.id) 
 FROM dbt.source.jaffle_shop.orders
LEFT JOIN (SELECT  id,
	first_name,
	last_name 
 FROM dbt.source.jaffle_shop.customers
 
) AS dbt_DOT_dimension_DOT_customers
        ON dbt.source.jaffle_shop.orders.user_id = dbt_DOT_dimension_DOT_customers.id 
 GROUP BY  first_name


In [31]:
query.select.from_.joins[0].on.right.table.name.name

'dbt_DOT_dimension_DOT_customers'

<IPython.core.display.Javascript object>

In [57]:
print(str(query.select.from_.joins[0].on.right))

(SELECT  id,
	first_name,
	last_name 
 FROM dbt.source.jaffle_shop.customers
 
) AS dbt_DOT_dimension_DOT_customers.id


<IPython.core.display.Javascript object>

In [51]:
print(str((query.select.from_.joins[0].on)))

dbt.source.jaffle_shop.orders.user_id = (SELECT  id,
	first_name,
	last_name 
 FROM dbt.source.jaffle_shop.customers
 
) AS dbt_DOT_dimension_DOT_customers.id


<IPython.core.display.Javascript object>

In [37]:
Query(
    select=Select(
        from_=From(
            tables=[
                Table(
                    name=Name(name="orders", quote_style=""),
                    namespace=Namespace(
                        names=[
                            Name(name="dbt", quote_style=""),
                            Name(name="source", quote_style=""),
                            Name(name="jaffle_shop", quote_style=""),
                        ]
                    ),
                )
            ],
            joins=[
                Join(
                    kind=JoinKind.LeftOuter,
                    table=Alias(
                        name=Name(
                            name="dbt_DOT_dimension_DOT_customers", quote_style=""
                        ),
                        namespace=None,
                        child=Select(
                            from_=From(
                                tables=[
                                    Table(
                                        name=Name(name="customers", quote_style=""),
                                        namespace=Namespace(
                                            names=[
                                                Name(name="dbt", quote_style=""),
                                                Name(name="source", quote_style=""),
                                                Name(
                                                    name="jaffle_shop", quote_style=""
                                                ),
                                            ]
                                        ),
                                    )
                                ],
                                joins=[],
                            ),
                            group_by=[],
                            having=None,
                            projection=[
                                Column(
                                    name=Name(name="id", quote_style=""), namespace=None
                                ),
                                Column(
                                    name=Name(name="first_name", quote_style=""),
                                    namespace=None,
                                ),
                                Column(
                                    name=Name(name="last_name", quote_style=""),
                                    namespace=None,
                                ),
                            ],
                            where=None,
                            limit=None,
                            distinct=False,
                        ),
                    ),
                    on=BinaryOp(
                        op=BinaryOpKind.Eq,
                        left=Column(
                            name=Name(name="user_id", quote_style=""), namespace=None
                        ),
                        right=Column(
                            name=Name(name="id", quote_style=""), namespace=None
                        ),
                    ),
                )
            ],
        ),
        group_by=[Column(name=Name(name="first_name", quote_style=""), namespace=None)],
        having=None,
        projection=[
            Function(
                name=Name(name="Count", quote_style=""),
                namespace=Namespace(names=[]),
                args=[Column(name=Name(name="id", quote_style=""), namespace=None)],
            )
        ],
        where=None,
        limit=None,
        distinct=False,
    ),
    ctes=[],
)

Query(select=Select(from_=From(tables=[Table(name=Name(name='orders', quote_style=''), namespace=Namespace(names=[Name(name='dbt', quote_style=''), Name(name='source', quote_style=''), Name(name='jaffle_shop', quote_style='')]))], joins=[Join(kind=JoinKind.LeftOuter, table=Alias(name=Name(name='dbt_DOT_dimension_DOT_customers', quote_style=''), namespace=None, child=Select(from_=From(tables=[Table(name=Name(name='customers', quote_style=''), namespace=Namespace(names=[Name(name='dbt', quote_style=''), Name(name='source', quote_style=''), Name(name='jaffle_shop', quote_style='')]))], joins=[]), group_by=[], having=None, projection=[Column(name=Name(name='id', quote_style=''), namespace=None), Column(name=Name(name='first_name', quote_style=''), namespace=None), Column(name=Name(name='last_name', quote_style=''), namespace=None)], where=None, limit=None, distinct=False)), on=BinaryOp(op=BinaryOpKind.Eq, left=Column(name=Name(name='user_id', quote_style=''), namespace=None), right=Column(

<IPython.core.display.Javascript object>

In [33]:
print(str(query))

SELECT  Count(dbt.source.jaffle_shop.orders.id) 
 FROM dbt.source.jaffle_shop.orders
LEFT JOIN (SELECT  id,
	first_name,
	last_name 
 FROM dbt.source.jaffle_shop.customers
 
) AS dbt_DOT_dimension_DOT_customers
        ON dbt.source.jaffle_shop.orders.user_id = (SELECT  id,
	first_name,
	last_name 
 FROM dbt.source.jaffle_shop.customers
 
) AS dbt_DOT_dimension_DOT_customers.id 
 GROUP BY  dbt_DOT_dimension_DOT_customers.first_name


In [29]:
ast.JoinKind??

In [37]:
for dim_col in dimension_columns:
    dim_node = dim_col.table.dj_node
    if dim_node in tables_nodes: # dimension does not require joining
        ...
    else:# need to join dimension
        alias = amenable_name(dim_node.name)
        join_info: Dict[str, Tuple[Node, List[Column]]] = {}
        for table_node in tables_nodes:
            dim_cols = [col for col in table_node.columns if col.dimension == dim_node]
            join_info[table_node.name] = (
                table_node,
                dim_cols,
            )
        
        for col in dimension_columns
        
        dim_ast = ast.Alias(ast.Name(alias), child=parse(dim.query, dialect))
        
        joins: List[ast.Join] = []

        for tables, cols in join_info.values():
            for table in tables:
                on = []
                for col in cols:
                    on.append(
                        ast.BinaryOp(
                            ast.BinaryOpKind.Eq,
                            ast.Column(ast.Name(col.name), _table=table),
                            ast.Column(
                                ast.Name(col.dimension_column),
                                ast.Namespace([Name(alias)]),
                            ),
                        ),
                    )
            joins.append(
                ast.Join(
                    ast.JoinKind.LeftOuterdim,
                    dim_ast,
                    reduce(lambda l, r: ast.BinaryOp(ast.BinaryOpKind.And, l, r), on),
                ),
            )
        
        select.from_.joins += joins

NameError: name 'Alias' is not defined

In [48]:
# after compile the table knows that several columns were referenced from it
q3.select.from_.tables[0].columns

{Column(name=Name(name='country', quote_style=''), namespace=None),
 Column(name=Name(name='revenue', quote_style=''), namespace=None)}

In [32]:
print(str(query))

SELECT  Sum((SELECT  country,
	revenue 
 FROM revenue
 
 WHERE  revenue > 1000.0 
) AS _purchases_over_a_grand.revenue) 
 FROM (SELECT  country,
	revenue 
 FROM revenue
 
 WHERE  revenue > 1000.0 
) AS _purchases_over_a_grand
 
 WHERE  purchases_over_a_grand.revenue > 1000.0 
 GROUP BY  (SELECT  country,
	revenue 
 FROM revenue
 
 WHERE  revenue > 1000.0 
) AS _purchases_over_a_grand.country


In [15]:
query.select.where.left._table

Table(name=Name(name='purchases_over_a_grand', quote_style=''), namespace=None)

In [7]:
from_=tree.select.projection[0]
to=ast.Alias(ast.Name('sum_revenue'), child=from_)

In [8]:
tree.select.replace(from_, to)

Select(from_=From(tables=[Table(name=Name(name='purchases_over_a_grand', quote_style=''), namespace=None)], joins=[]), group_by=[Column(name=Name(name='country', quote_style=''), namespace=None)], having=None, projection=[Alias(name=Name(name='sum_revenue', quote_style=''), namespace=None, child=Function(name=Name(name='Sum', quote_style=''), namespace=Namespace(names=[]), args=[Column(name=Name(name='revenue', quote_style=''), namespace=None)]))], where=BinaryOp(op=BinaryOpKind.Gt, left=Column(name=Name(name='revenue', quote_style=''), namespace=None), right=Number(value=1000.0)), limit=None, distinct=False)

In [2]:
print(str(tree))

NameError: name 'tree' is not defined

In [4]:
def name_exp(exp: ast.Expression, default: str):
    if isinstance(exp, ast.Alias):
        return make_name(exp.namespace, exp.name.name)
    return default

In [5]:
{name_exp(exp, f"_col{i}"): get_type_of_expression(exp) for i, exp in enumerate(tree.select.projection)}

{'_col0': <ColumnType.FLOAT: 'FLOAT'>}

In [6]:
tree, deps, danglers = extract_dependencies(session,     """
    select Count(*) from purchases_over_a_grand where revenue>1000.0 group by country
    """, raise_=False)

print("All Column types ", [(c, c.type) for c in list(tree.find_all(ast.Column))])
print()
print("Column Expressions", [(c, c.expression) for c in list(tree.find_all(ast.Column))])
print()
print("Projection Exp types ", [(exp, get_type_of_expression(exp)) for exp in tree.select.projection])
print()
deps, danglers

All Column types  [(Column(name=Name(name='country', quote_style=''), namespace=None), <ColumnType.STR: 'STR'>), (Column(name=Name(name='revenue', quote_style=''), namespace=None), <ColumnType.FLOAT: 'FLOAT'>)]

Column Expressions [(Column(name=Name(name='country', quote_style=''), namespace=None), None), (Column(name=Name(name='revenue', quote_style=''), namespace=None), None)]

Projection Exp types  [(Function(name=Name(name='Count', quote_style=''), namespace=Namespace(names=[]), args=[Wildcard()]), <ColumnType.INT: 'INT'>)]



({Node(name='purchases_over_a_grand', type=<NodeType.TRANSFORM: 'transform'>, id=23, query='select country, revenue from revenue where revenue>1000.0', created_at=datetime.datetime(2023, 1, 12, 14, 19, 12, 301204), updated_at=datetime.datetime(2023, 1, 12, 14, 19, 12, 301210), description='', columns=[Column(id=129, name='country', dimension_column=None, dimension_id=None, type=<ColumnType.STR: 'STR'>), Column(id=130, name='revenue', dimension_column=None, dimension_id=None, type=<ColumnType.FLOAT: 'FLOAT'>)])},
 set())

In [7]:
tree, deps, danglers = extract_dependencies(session, """
    select sum_revenue from (select Sum(revenue) as sum_revenue from purchases_over_a_grand where revenue>1000.0 group by country)
    """, raise_=False)

print("All Column types ", [(c, c.type) for c in list(tree.find_all(ast.Column))])
print()
print("Column Expressions", [(c, c.expression) for c in list(tree.find_all(ast.Column))])
print()
print("Projection Exp types ", [(exp, get_type_of_expression(exp)) for exp in tree.select.projection])
print()
deps, danglers

All Column types  [(Column(name=Name(name='country', quote_style=''), namespace=None), <ColumnType.STR: 'STR'>), (Column(name=Name(name='revenue', quote_style=''), namespace=None), <ColumnType.FLOAT: 'FLOAT'>), (Column(name=Name(name='revenue', quote_style=''), namespace=None), <ColumnType.FLOAT: 'FLOAT'>), (Column(name=Name(name='sum_revenue', quote_style=''), namespace=None), None)]

Column Expressions [(Column(name=Name(name='country', quote_style=''), namespace=None), None), (Column(name=Name(name='revenue', quote_style=''), namespace=None), None), (Column(name=Name(name='revenue', quote_style=''), namespace=None), None), (Column(name=Name(name='sum_revenue', quote_style=''), namespace=None), Alias(name=Name(name='sum_revenue', quote_style=''), namespace=None, child=Function(name=Name(name='Sum', quote_style=''), namespace=Namespace(names=[]), args=[Column(name=Name(name='revenue', quote_style=''), namespace=None)])))]

Projection Exp types  [(Column(name=Name(name='sum_revenue', q

({Node(name='purchases_over_a_grand', type=<NodeType.TRANSFORM: 'transform'>, id=23, query='select country, revenue from revenue where revenue>1000.0', created_at=datetime.datetime(2023, 1, 12, 14, 19, 12, 301204), updated_at=datetime.datetime(2023, 1, 12, 14, 19, 12, 301210), description='', columns=[Column(id=129, name='country', dimension_column=None, dimension_id=None, type=<ColumnType.STR: 'STR'>), Column(id=130, name='revenue', dimension_column=None, dimension_id=None, type=<ColumnType.FLOAT: 'FLOAT'>)])},
 set())

In [8]:
tree, deps, danglers = extract_dependencies(session, """
    select sum_revenue from (select Sum(revenue) as sum_revenue from purchases_over_a_grand where revenue>1000.0 group by country) as a, oops
    """, raise_=False)

print("All Column types ", [(c, c.type) for c in list(tree.find_all(ast.Column))])
print()
print("Column Expressions", [(c, c.expression) for c in list(tree.find_all(ast.Column))])
print()
print("Projection Exp types ", [(exp, get_type_of_expression(exp)) for exp in tree.select.projection])
print()
deps, danglers

All Column types  [(Column(name=Name(name='country', quote_style=''), namespace=None), <ColumnType.STR: 'STR'>), (Column(name=Name(name='revenue', quote_style=''), namespace=None), <ColumnType.FLOAT: 'FLOAT'>), (Column(name=Name(name='revenue', quote_style=''), namespace=None), <ColumnType.FLOAT: 'FLOAT'>), (Column(name=Name(name='sum_revenue', quote_style=''), namespace=None), None)]

Column Expressions [(Column(name=Name(name='country', quote_style=''), namespace=None), None), (Column(name=Name(name='revenue', quote_style=''), namespace=None), None), (Column(name=Name(name='revenue', quote_style=''), namespace=None), None), (Column(name=Name(name='sum_revenue', quote_style=''), namespace=None), Alias(name=Name(name='sum_revenue', quote_style=''), namespace=None, child=Function(name=Name(name='Sum', quote_style=''), namespace=Namespace(names=[]), args=[Column(name=Name(name='revenue', quote_style=''), namespace=None)])))]

Projection Exp types  [(Column(name=Name(name='sum_revenue', q

({Node(name='purchases_over_a_grand', type=<NodeType.TRANSFORM: 'transform'>, id=23, query='select country, revenue from revenue where revenue>1000.0', created_at=datetime.datetime(2023, 1, 12, 14, 19, 12, 301204), updated_at=datetime.datetime(2023, 1, 12, 14, 19, 12, 301210), description='', columns=[Column(id=129, name='country', dimension_column=None, dimension_id=None, type=<ColumnType.STR: 'STR'>), Column(id=130, name='revenue', dimension_column=None, dimension_id=None, type=<ColumnType.FLOAT: 'FLOAT'>)])},
 {'oops'})