In [1]:
from dj.construction.extract import *

In [2]:
from dj.utils import get_session

In [3]:
# Keep a reference to the generator returned by get_session(). If it is
# discarded right after next(), CPython finalizes it immediately; that raises
# GeneratorExit at its `yield` and runs any cleanup written after it (e.g.
# closing the session). NOTE(review): assumes get_session has such cleanup — confirm.
session_gen = get_session()
session = next(session_gen)

In [4]:
from typing import Optional

# pylint: disable=too-many-lines
import pytest
from sqlmodel import Session, select

from dj.construction.extract import (
    ColumnDependencies,
    CompoundBuildException,
    extract_dependencies,
    extract_dependencies_from_query,
    get_dj_node,
    make_name,
)
from dj.errors import DJError, DJException, ErrorCode
from dj.models import Column
from dj.models.node import Node, NodeType
from dj.sql.parsing.ast import Alias, BinaryOp, BinaryOpKind
from dj.sql.parsing.ast import Column as ASTColumn
from dj.sql.parsing.ast import (
    From,
    Join,
    JoinKind,
    Name,
    Namespace,
    Select,
    String,
    Table,
)
from dj.sql.parsing.backends.sqloxide import parse
from dj.typing import ColumnType

In [5]:
from dj.sql.parsing import ast

In [6]:
from dj.typing import ColumnType

In [7]:
from functools import singledispatch

In [8]:
from dj.sql.functions import function_registry

In [9]:
from dj.sql.parsing.backends.exceptions import DJParseException

In [10]:
@singledispatch
def get_type_of_expression(expression: ast.Expression) -> ColumnType:
    """Infer the DJ ``ColumnType`` of a parsed SQL expression.

    Dispatches on the concrete AST class of ``expression``; handlers for
    specific classes are attached with ``@get_type_of_expression.register``.

    Raises:
        NotImplementedError: if no handler is registered for the
            expression's class.
    """
    raise NotImplementedError(f"Cannot get type of expression {expression}")

@get_type_of_expression.register
def _(expression: ast.Alias):
    """An alias has exactly the type of the expression it names."""
    aliased = expression.child
    return get_type_of_expression(aliased)
    
@get_type_of_expression.register
def _(expression: ast.Column):
    """Resolve the type of a column reference.

    Sources are tried in order:

    1. a type already set on the column;
    2. the type of the expression the column was derived from;
    3. the same-named column of the DJ node behind the column's table
       (``ast.Alias``-wrapped tables are unwrapped first). The type found
       is stored on the column with ``set_type`` before being returned.

    Raises:
        DJParseException: if no type can be resolved. This includes a table
            that is not an ``ast.Table`` (e.g. a subquery), a table without a
            DJ node, and a DJ node with no column of this name.
    """
    # column has already determined/stated its type
    if expression.type:
        return expression.type

    # column was derived from some other expression we can get the type of
    if expression.expression:
        return get_type_of_expression(expression.expression)

    # column is from a table expression we can look through
    if table_pos_alias := expression.table:
        # Unwrap an aliased table to the expression it names.
        table = (
            table_pos_alias.child
            if isinstance(table_pos_alias, ast.Alias)
            else table_pos_alias
        )
        if not isinstance(table, ast.Table):
            # Subqueries are not traversed for type information; extraction
            # is expected to have set the types needed in the crucial cases.
            raise DJParseException(
                f"Cannot resolve type of column {expression}. "
                "DJ does not currently traverse subqueries for type information. "
                "Consider extraction first."
            )
        if not table.dj_node:
            raise DJParseException(
                f"Cannot resolve type of column {expression}. "
                "column's table does not have a DJ Node."
            )
        for col in table.dj_node.columns:
            if col.name == expression.name.name:
                expression.set_type(col.type)
                return col.type
    raise DJParseException(f"Cannot resolve type of column {expression}.")
    
@get_type_of_expression.register
def _(expression: ast.String):
    """String literals are always ``STR``."""
    return ColumnType.STR

@get_type_of_expression.register
def _(expression: ast.Number):
    """Numbers whose value is an ``int`` are ``INT``; anything else is ``FLOAT``."""
    return ColumnType.INT if isinstance(expression.value, int) else ColumnType.FLOAT

@get_type_of_expression.register
def _(expression: ast.Boolean):
    """Boolean literals are always ``BOOL``."""
    return ColumnType.BOOL

@get_type_of_expression.register
def _(expression: ast.Wildcard):
    """``*`` gets the dedicated ``WILDCARD`` type."""
    return ColumnType.WILDCARD

@get_type_of_expression.register
def _(expression: ast.Function):
    """Type a function call via the DJ function registry.

    The registry entry is looked up by the upper-cased function name. Its
    ``infer_type`` receives the inferred type of each argument, in order.
    NOTE(review): an unknown name raises whatever ``function_registry[...]``
    raises (KeyError if it is a plain dict) — confirm.
    """
    dj_func = function_registry[expression.name.name.upper()]
    arg_types = map(get_type_of_expression, expression.args)
    return dj_func.infer_type(*arg_types)

In [11]:
# COUNT(*): a function call whose single argument is the wildcard.
count_star = ast.Function(Name("Count"), args=[ast.Wildcard()])

get_type_of_expression(count_star)

<ColumnType.INT: 'INT'>

In [12]:
# Plain column references, one of them aliased.
query_sql = """
    select country as cntry_alias, revenue from purchases_over_a_grand where revenue>1000.0
    """
tree = parse(query_sql)
extract_dependencies_from_select(session, tree.select)

SelectDependencies(tables=[(Table(name=Name(name='purchases_over_a_grand', quote_style=''), namespace=None), Node(name='purchases_over_a_grand', type=<NodeType.TRANSFORM: 'transform'>, id=23, query='select country, revenue from revenue where revenue>1000.0', created_at=datetime.datetime(2023, 1, 12, 14, 19, 12, 301204), updated_at=datetime.datetime(2023, 1, 12, 14, 19, 12, 301210), description='', columns=[Column(id=129, name='country', dimension_column=None, dimension_id=None, type=<ColumnType.STR: 'STR'>), Column(id=130, name='revenue', dimension_column=None, dimension_id=None, type=<ColumnType.FLOAT: 'FLOAT'>)]))], columns=ColumnDependencies(projection=[(Column(name=Name(name='country', quote_style=''), namespace=None), Table(name=Name(name='purchases_over_a_grand', quote_style=''), namespace=None)), (Column(name=Name(name='revenue', quote_style=''), namespace=None), Table(name=Name(name='purchases_over_a_grand', quote_style=''), namespace=None))], group_by=[], filters=[(Column(name

In [13]:
# Column types as currently set on the AST column nodes.
[(column, column.type) for column in tree.find_all(ast.Column)]

[(Column(name=Name(name='country', quote_style=''), namespace=None),
  <ColumnType.STR: 'STR'>),
 (Column(name=Name(name='revenue', quote_style=''), namespace=None),
  <ColumnType.FLOAT: 'FLOAT'>),
 (Column(name=Name(name='revenue', quote_style=''), namespace=None),
  <ColumnType.FLOAT: 'FLOAT'>)]

In [14]:
# For each column, the expression it was derived from (None if not derived).
[(column, column.expression) for column in tree.find_all(ast.Column)]

[(Column(name=Name(name='country', quote_style=''), namespace=None), None),
 (Column(name=Name(name='revenue', quote_style=''), namespace=None), None),
 (Column(name=Name(name='revenue', quote_style=''), namespace=None), None)]

In [15]:
# Inferred type of each projected expression; an alias takes the type of what it wraps.
[(item, get_type_of_expression(item)) for item in tree.select.projection]

[(Alias(name=Name(name='cntry_alias', quote_style=''), namespace=None, child=Column(name=Name(name='country', quote_style=''), namespace=None)),
  <ColumnType.STR: 'STR'>),
 (Column(name=Name(name='revenue', quote_style=''), namespace=None),
  <ColumnType.FLOAT: 'FLOAT'>)]

In [16]:
# An aggregate (Sum) in the projection, grouped by a column that is not projected.
query_sql = """
    select Sum(revenue) from purchases_over_a_grand where revenue>1000.0 group by country
    """
tree = parse(query_sql)
extract_dependencies_from_select(session, tree.select)

SelectDependencies(tables=[(Table(name=Name(name='purchases_over_a_grand', quote_style=''), namespace=None), Node(name='purchases_over_a_grand', type=<NodeType.TRANSFORM: 'transform'>, id=23, query='select country, revenue from revenue where revenue>1000.0', created_at=datetime.datetime(2023, 1, 12, 14, 19, 12, 301204), updated_at=datetime.datetime(2023, 1, 12, 14, 19, 12, 301210), description='', columns=[Column(id=129, name='country', dimension_column=None, dimension_id=None, type=<ColumnType.STR: 'STR'>), Column(id=130, name='revenue', dimension_column=None, dimension_id=None, type=<ColumnType.FLOAT: 'FLOAT'>)]))], columns=ColumnDependencies(projection=[(Column(name=Name(name='revenue', quote_style=''), namespace=None), Table(name=Name(name='purchases_over_a_grand', quote_style=''), namespace=None))], group_by=[(Column(name=Name(name='country', quote_style=''), namespace=None), Table(name=Name(name='purchases_over_a_grand', quote_style=''), namespace=None))], filters=[(Column(name=N

In [17]:
# Column types as currently set on the AST column nodes.
[(column, column.type) for column in tree.find_all(ast.Column)]

[(Column(name=Name(name='country', quote_style=''), namespace=None),
  <ColumnType.STR: 'STR'>),
 (Column(name=Name(name='revenue', quote_style=''), namespace=None),
  <ColumnType.FLOAT: 'FLOAT'>),
 (Column(name=Name(name='revenue', quote_style=''), namespace=None),
  <ColumnType.FLOAT: 'FLOAT'>)]

In [18]:
# For each column, the expression it was derived from (None if not derived).
[(column, column.expression) for column in tree.find_all(ast.Column)]

[(Column(name=Name(name='country', quote_style=''), namespace=None), None),
 (Column(name=Name(name='revenue', quote_style=''), namespace=None), None),
 (Column(name=Name(name='revenue', quote_style=''), namespace=None), None)]

In [19]:
# Inferred type of each projected expression (here a Sum(...) call).
[(item, get_type_of_expression(item)) for item in tree.select.projection]

[(Function(name=Name(name='Sum', quote_style=''), namespace=Namespace(names=[]), args=[Column(name=Name(name='revenue', quote_style=''), namespace=None)]),
  <ColumnType.FLOAT: 'FLOAT'>)]

In [20]:
# COUNT(*): the function's only argument is the wildcard.
query_sql = """
    select Count(*) from purchases_over_a_grand where revenue>1000.0 group by country
    """
tree = parse(query_sql)
extract_dependencies_from_select(session, tree.select)

SelectDependencies(tables=[(Table(name=Name(name='purchases_over_a_grand', quote_style=''), namespace=None), Node(name='purchases_over_a_grand', type=<NodeType.TRANSFORM: 'transform'>, id=23, query='select country, revenue from revenue where revenue>1000.0', created_at=datetime.datetime(2023, 1, 12, 14, 19, 12, 301204), updated_at=datetime.datetime(2023, 1, 12, 14, 19, 12, 301210), description='', columns=[Column(id=129, name='country', dimension_column=None, dimension_id=None, type=<ColumnType.STR: 'STR'>), Column(id=130, name='revenue', dimension_column=None, dimension_id=None, type=<ColumnType.FLOAT: 'FLOAT'>)]))], columns=ColumnDependencies(projection=[], group_by=[(Column(name=Name(name='country', quote_style=''), namespace=None), Table(name=Name(name='purchases_over_a_grand', quote_style=''), namespace=None))], filters=[(Column(name=Name(name='revenue', quote_style=''), namespace=None), Table(name=Name(name='purchases_over_a_grand', quote_style=''), namespace=None))], ons=[]), su

In [21]:
# Inferred type of each projected expression (here COUNT(*)).
[(item, get_type_of_expression(item)) for item in tree.select.projection]

[(Function(name=Name(name='Count', quote_style=''), namespace=Namespace(names=[]), args=[Wildcard()]),
  <ColumnType.INT: 'INT'>)]

In [23]:
# The outer query selects `sum_revenue`, which is defined inside an unaliased subquery.
query_sql = """
    select sum_revenue from (select Sum(revenue) as sum_revenue from purchases_over_a_grand where revenue>1000.0 group by country)
    """
tree = parse(query_sql)
extract_dependencies_from_select(session, tree.select)

SelectDependencies(tables=[], columns=ColumnDependencies(projection=[(Column(name=Name(name='sum_revenue', quote_style=''), namespace=None), Select(from_=From(tables=[Table(name=Name(name='purchases_over_a_grand', quote_style=''), namespace=None)], joins=[]), group_by=[Column(name=Name(name='country', quote_style=''), namespace=None)], having=None, projection=[Alias(name=Name(name='sum_revenue', quote_style=''), namespace=None, child=Function(name=Name(name='Sum', quote_style=''), namespace=Namespace(names=[]), args=[Column(name=Name(name='revenue', quote_style=''), namespace=None)]))], where=BinaryOp(op=BinaryOpKind.Gt, left=Column(name=Name(name='revenue', quote_style=''), namespace=None), right=Number(value=1000.0)), limit=None, distinct=False))], group_by=[], filters=[], ons=[]), subqueries=[(Select(from_=From(tables=[Table(name=Name(name='purchases_over_a_grand', quote_style=''), namespace=None)], joins=[]), group_by=[Column(name=Name(name='country', quote_style=''), namespace=Non

In [24]:
# The outer `sum_revenue` column keeps a reference to the expression that
# defines it inside the subquery (the `Sum(revenue) as sum_revenue` alias).
tree.select.projection[0].expression

Alias(name=Name(name='sum_revenue', quote_style=''), namespace=None, child=Function(name=Name(name='Sum', quote_style=''), namespace=Namespace(names=[]), args=[Column(name=Name(name='revenue', quote_style=''), namespace=None)]))

In [25]:
# Types of all expressions in the top-level projection. `sum_revenue` is typed
# through the subquery expression it was derived from.
[(exp, get_type_of_expression(exp)) for exp in tree.select.projection]

[(Column(name=Name(name='sum_revenue', quote_style=''), namespace=None),
  <ColumnType.FLOAT: 'FLOAT'>)]

In [31]:
# Any Table with a dj_node is a dependency; any Table without one is a dangler.
# Materialize the tables once: both comprehensions below iterate them, and a
# generator from find_all would be exhausted by the first, leaving danglers empty.
tables = list(tree.find_all(ast.Table))
deps = {t.dj_node for t in tables if t.dj_node}
danglers = {make_name(t.namespace, t.name.name) for t in tables if not t.dj_node}

In [32]:
# Nodes this query depends on: every ast.Table that has a DJ node attached.
deps

{Node(name='purchases_over_a_grand', type=<NodeType.TRANSFORM: 'transform'>, id=23, query='select country, revenue from revenue where revenue>1000.0', created_at=datetime.datetime(2023, 1, 12, 14, 19, 12, 301204), updated_at=datetime.datetime(2023, 1, 12, 14, 19, 12, 301210), description='', columns=[Column(id=129, name='country', dimension_column=None, dimension_id=None, type=<ColumnType.STR: 'STR'>), Column(id=130, name='revenue', dimension_column=None, dimension_id=None, type=<ColumnType.FLOAT: 'FLOAT'>)])}

In [33]:
# Names of ast.Table references that have no DJ node attached.
danglers

set()

In [38]:
# Clear previously collected build errors and turn raising off, so the next
# extraction presumably records problems instead of raising on them.
# NOTE(review): two separate instantiations only make sense if
# CompoundBuildException() returns shared (singleton) state — confirm.
CompoundBuildException().reset()
CompoundBuildException().set_raise(False)

In [39]:
# The same subquery, now aliased, cross-joined with `oops`, which is not a DJ node.
query_sql = """
    select sum_revenue from (select Sum(revenue) as sum_revenue from purchases_over_a_grand where revenue>1000.0 group by country) as a, oops
    """
tree = parse(query_sql)
extract_dependencies_from_select(session, tree.select)

SelectDependencies(tables=[], columns=ColumnDependencies(projection=[(Column(name=Name(name='sum_revenue', quote_style=''), namespace=None), Select(from_=From(tables=[Table(name=Name(name='purchases_over_a_grand', quote_style=''), namespace=None)], joins=[]), group_by=[Column(name=Name(name='country', quote_style=''), namespace=None)], having=None, projection=[Alias(name=Name(name='sum_revenue', quote_style=''), namespace=None, child=Function(name=Name(name='Sum', quote_style=''), namespace=Namespace(names=[]), args=[Column(name=Name(name='revenue', quote_style=''), namespace=None)]))], where=BinaryOp(op=BinaryOpKind.Gt, left=Column(name=Name(name='revenue', quote_style=''), namespace=None), right=Number(value=1000.0)), limit=None, distinct=False))], group_by=[], filters=[], ons=[]), subqueries=[(Select(from_=From(tables=[Table(name=Name(name='purchases_over_a_grand', quote_style=''), namespace=None)], joins=[]), group_by=[Column(name=Name(name='country', quote_style=''), namespace=Non

In [46]:
# Partition the query's tables in a single pass: tables with a dj_node are
# dependencies, the rest are danglers (reported by qualified name).
deps, danglers = set(), set()
for table_ref in tree.find_all(ast.Table):
    if table_ref.dj_node:
        deps.add(table_ref.dj_node)
    else:
        danglers.add(make_name(table_ref.namespace, table_ref.name.name))

In [47]:
# Only the subquery's table resolves to a DJ node.
deps

{Node(name='purchases_over_a_grand', type=<NodeType.TRANSFORM: 'transform'>, id=23, query='select country, revenue from revenue where revenue>1000.0', created_at=datetime.datetime(2023, 1, 12, 14, 19, 12, 301204), updated_at=datetime.datetime(2023, 1, 12, 14, 19, 12, 301210), description='', columns=[Column(id=129, name='country', dimension_column=None, dimension_id=None, type=<ColumnType.STR: 'STR'>), Column(id=130, name='revenue', dimension_column=None, dimension_id=None, type=<ColumnType.FLOAT: 'FLOAT'>)])}

In [48]:
# `oops` has no DJ node, so it is reported as a dangler.
danglers

{'oops'}