In [1]:
%load_ext nb_black

<IPython.core.display.Javascript object>

In [2]:
from dj.sql.parsing.backends.sqloxide import parse
from dj.sql.parsing.backends.exceptions import DJParseException
from dj.construction.utils import make_name, get_dj_node, amenable_name
from dj.construction.build import build_ast_for_database
from dj.construction.compile import _compile_select_ast
from dj.sql.parsing import ast
from typing import Optional, List, Tuple
from sqlmodel import Session
from dj.models.node import Node, NodeType, NodeRevision
from dj.models.database import Database
from dj.errors import DJException



<IPython.core.display.Javascript object>

In [3]:
from sqlalchemy import select as built_select

<IPython.core.display.Javascript object>

In [4]:
from dj.utils import get_session

<IPython.core.display.Javascript object>

In [5]:
# Take the first value produced by get_session() and use it as the session
# for the rest of the notebook.
# NOTE(review): only the first item is consumed and the iterator is never
# resumed, so any cleanup get_session() performs after yielding (if it does)
# will not run here - confirm that is acceptable for a scratch notebook.
session = next(get_session())

<IPython.core.display.Javascript object>

In [6]:
# Example input for build_dj_query: the inner SELECT reads FROM the
# placeholder table `metrics` (the only source build_dj_query accepts for a
# SELECT that references a metric) and refers to columns by their
# dot-separated names; the outer SELECT aggregates the inner result by age.
query = """
SELECT Avg(n_comments),
       age
FROM   (SELECT basic.num_comments            AS n_comments,
               basic.dimension.users.country country,
               basic.dimension.users.age     AS age
        FROM   metrics
        GROUP  BY basic.dimension.users.country)
GROUP  BY age 
"""

<IPython.core.display.Javascript object>

In [7]:
async def build_dj_query(
    session: Session,
    query: str,
    dialect: Optional[str] = None,
    database_id: Optional[int] = None,
) -> Tuple[ast.Query, Database]:
    """
    Build a dj query.

    The query is parsed, then rewritten in two passes before being handed to
    ``build_ast_for_database``:

    1. Every column whose full name resolves to a METRIC node is replaced by
       a column pointing at an aliased sub-select built from that metric's
       own query. The SELECT containing such a column must read from a single
       unaliased table named ``metrics``; that placeholder FROM is emptied
       and refilled with the metric sub-select plus the known DJ tables the
       metric reads from.
    2. Every column whose namespace resolves to a DIMENSION node is flagged
       through ``set_api_column(True)``.

    Args:
        session: Session used for all DJ node look-ups.
        query: SQL text of the DJ query.
        dialect: Optional dialect passed to the parser and to
            ``build_ast_for_database``.
        database_id: Optional database id passed to ``build_ast_for_database``.

    Returns:
        Whatever ``build_ast_for_database`` returns (annotated as the built
        query AST together with a ``Database``).

    Raises:
        DJParseException: If a SELECT that references a metric does not read
            from exactly one unaliased, join-free table named ``metrics``.
    """
    query_ast = parse(query, dialect)
    select = query_ast._to_select()

    # NOTE(review): the tree is modified (replace / FROM rewrite) while the
    # result of find_all is still being consumed. The original iteration is
    # kept as-is; confirm find_all tolerates this.
    for col in select.find_all(ast.Column):
        col_name = make_name(col.namespace, col.name.name)
        metric_node = get_dj_node(session, col_name, {NodeType.METRIC}, raise_=False)
        if not metric_node:
            continue

        parent_select = col.get_nearest_parent_of_type(ast.Select)
        # A SELECT may reference several metrics; its placeholder FROM is
        # validated and emptied only the first time one of them is seen.
        if not getattr(parent_select, "_validated", False):
            _validate_metrics_source(parent_select)
            parent_select.from_ = ast.From([])
            parent_select._validated = True

        metric_table_expression, froms = _expand_metric(session, metric_node)
        # The metric's value is taken from the first column of the metric node.
        metric_column = ast.Column(
            ast.Name(metric_node.columns[0].name), _table=metric_table_expression
        )
        parent_select.replace(col, metric_column)
        parent_select.from_.tables += froms

    for col in select.find_all(ast.Column):
        # Only the namespace is looked up here: for a reference such as
        # a.b.c.column the candidate dimension node is a.b.c.
        col_name = make_name(col.namespace)
        if get_dj_node(session, col_name, {NodeType.DIMENSION}, raise_=False):
            col.set_api_column(True)

    return await build_ast_for_database(
        session, query=ast.Query(select), dialect=dialect, database_id=database_id
    )


def _validate_metrics_source(parent_select):
    """
    Ensure a metric-referencing SELECT reads from one unaliased ``metrics`` table.

    Raises:
        DJParseException: If the FROM clause has joins, does not have exactly
            one table, has an aliased or sub-select source, or names a table
            other than ``metrics``.
    """
    single_source_message = (
        "Any SELECT referencing a Metric must source from a single unaliased "
        "Table named 'metrics'."
    )
    source = parent_select.from_
    if len(source.tables) != 1 or source.joins:
        raise DJParseException(single_source_message)

    metrics_ref = source.tables[0]
    # Reject aliases and sub-selects explicitly: an alias called "metrics"
    # would otherwise satisfy the name check below, and a bare sub-select
    # would fail on attribute access instead of raising a parse error.
    if isinstance(metrics_ref, (ast.Alias, ast.Select)):
        raise DJParseException(single_source_message)

    metrics_ref_name = make_name(metrics_ref.namespace, metrics_ref.name.name)
    if metrics_ref_name != "metrics":
        raise DJParseException(
            "The name of the table for a Metric select must be 'metrics'."
        )


def _expand_metric(session, metric_node):
    """
    Turn a metric node into an aliased sub-select plus the tables to add to FROM.

    Returns:
        A pair ``(metric_table_expression, froms)``. ``froms`` holds a copy of
        every table used by the metric query that resolves to a SOURCE,
        TRANSFORM or DIMENSION node, followed by the aliased metric
        sub-select itself.
    """
    metric_name = amenable_name(metric_node.name)
    metric_select = parse(metric_node.query)._to_select()

    froms = []
    tables = metric_select.from_.tables + [
        join.table for join in metric_select.from_.joins
    ]
    for table in tables:
        # Sub-selects (aliased or not) are skipped; only named tables are
        # looked up as DJ nodes.
        if isinstance(table, ast.Select):
            continue
        if isinstance(table, ast.Alias):
            if isinstance(table.child, ast.Select):
                continue
            table = table.child
        table_name = make_name(table.namespace, table.name.name)
        table_node = get_dj_node(
            session,
            table_name,
            {NodeType.SOURCE, NodeType.TRANSFORM, NodeType.DIMENSION},
            raise_=False,
        )
        if table_node:
            # Expose every column of the underlying node through the metric
            # sub-select's projection.
            metric_select.projection += [
                ast.Column(ast.Name(node_col.name)) for node_col in table_node.columns
            ]
            froms.append(table.copy())

    metric_table_expression = ast.Alias(ast.Name(metric_name), None, metric_select)
    froms.append(metric_table_expression)
    return metric_table_expression, froms

<IPython.core.display.Javascript object>

In [8]:
# build_dj_query is a coroutine function, so it is awaited directly at cell
# level; the two returned values are the built query AST and the database
# object used to run it further down.
query_ast, db = await build_dj_query(session, query)

  results = super().execute(
  results = super().execute(
  results = super().execute(


<IPython.core.display.Javascript object>

In [9]:
# Printing the AST renders it as SQL text, which shows how the metric and
# dimension references were rewritten into sub-selects and joins.
print(query_ast)

SELECT  Avg(n_comments),
	age 
 FROM (SELECT  basic_DOT_num_comments.cnt AS n_comments,
	basic_DOT_dimension_DOT_users.country AS country,
	basic_DOT_dimension_DOT_users.age AS age 
 FROM comments, (SELECT  COUNT(1) AS cnt,
	comments.id,
	comments.user_id,
	comments.timestamp,
	comments.text 
 FROM comments
 
) AS basic_DOT_num_comments
LEFT JOIN (SELECT  users.id,
	users.full_name,
	users.age,
	users.country,
	users.gender,
	users.preferred_language,
	users.secret_number 
 FROM users
 
) AS basic_DOT_dimension_DOT_users
        ON comments.user_id = basic_DOT_dimension_DOT_users.id AND comments.user_id = basic_DOT_dimension_DOT_users.id 
 GROUP BY  basic_DOT_dimension_DOT_users.country)
 
 GROUP BY  age


<IPython.core.display.Javascript object>

In [10]:
from sqlalchemy.sql import text

<IPython.core.display.Javascript object>

In [11]:
import pandas as pd

<IPython.core.display.Javascript object>

In [12]:
# Run the generated SQL: the AST is rendered to a string, wrapped in
# sqlalchemy's text(), and executed on a connection obtained from the
# database's engine. The with-block scopes the connection to this read.
with db.engine.connect() as conn:
    df = pd.read_sql(
        text(str(query_ast)),
        conn,
    )

<IPython.core.display.Javascript object>

In [13]:
# Show the query result as the cell's output.
df

Unnamed: 0,Avg(n_comments),age
0,21.0,10.0
1,21.0,15.0
2,21.0,20.0
3,21.0,25.0
4,21.0,27.0
5,21.0,29.0


<IPython.core.display.Javascript object>