Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 131 additions & 33 deletions datajunction-server/datajunction_server/construction/build_v3/cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from __future__ import annotations

import re
from copy import deepcopy
from typing import Optional

Expand Down Expand Up @@ -951,20 +950,65 @@ def _resolve_pushdown_filters_for_cte(
)
if rewritten is None:
continue
results.append(parse_filter(rewritten))
results.append(rewritten)
return results


def _cte_has_set_operation(cte_query: ast.Query) -> bool:
    """Check whether the CTE body is a set operation (UNION / INTERSECT / EXCEPT).

    Projection inspection only sees the first arm of a set operation, so a
    set operation with asymmetric arms can't be safely pushed into. Callers
    use this as a guard and skip pushdown when it returns True.
    """
    select = cte_query.select
    if not select:
        return False
    return bool(select.set_op)


def _build_name(parts: list[str]) -> ast.Name:
    """Construct a nested ``ast.Name`` from the parts of a dotted reference.

    For example, ``["o", "order_date"]`` yields
    ``Name("order_date", namespace=Name("o"))``, which serializes back to
    ``o.order_date``.
    """
    # Fold left-to-right: each later part wraps the accumulated namespace.
    name = ast.Name(parts[0])
    for part in parts[1:]:
        name = ast.Name(part, namespace=name)
    return name


def _column_from_qualified(qualified: str) -> ast.Column:
    """Turn a dotted reference such as ``o.order_date`` into an ``ast.Column``."""
    name_parts = qualified.split(SEPARATOR)
    return ast.Column(name=_build_name(name_parts))


def _resolve_pushdown_form(
output_col: str,
cte_output_cols: set[str],
projection_map: dict[str, str | None],
) -> str | None:
"""Determine the WHERE-safe form for a filter column in a specific CTE.

Returns the WHERE-safe reference as a string (qualified or bare), or
``None`` when the filter cannot be safely pushed into this CTE — either
because the column isn't exposed at all, or because the projection is a
non-column expression that can't be inlined into WHERE.
"""
if output_col not in cte_output_cols:
return None
if output_col in projection_map:
return projection_map[output_col] # may be None: unsafe projection
return output_col # pruned from CTE SELECT; falls through to bare name


def _rewrite_filter_for_cte(
filter_str: str,
filter_column_aliases: dict[str, str],
cte_output_cols: set[str],
cte_query: ast.Query,
) -> str | None:
) -> ast.Expression | None:
"""Rewrite a dimension filter for injection into a specific CTE.

Resolves each dimension reference (e.g., ``v3.product.category``) to the
form that's safe in the CTE's WHERE clause. Three cases:
form that's safe in the CTE's WHERE clause. Three projection cases:

1. CTE projects the column as a simple (possibly aliased) column: replace
with the underlying qualified form (e.g., ``p.category``). This is the
Expand All @@ -976,40 +1020,94 @@ def _rewrite_filter_for_cte(
3. CTE projects the column via a non-column expression (e.g.
``SUM(x) AS y``): skip — inlining is unsafe.

Returns the rewritten filter string, or None if no dim_ref applies.
Multi-predicate handling: a single filter may reference several dim refs
(``a.x = 1 OR b.y = 2``). All matching refs are rewritten, but if ANY
ref's column isn't exposed by this CTE, the whole filter is skipped —
pushing a partial OR-predicate into the wrong CTE produces invalid SQL.

Returns the rewritten filter AST, or None when the filter can't be
safely pushed into this CTE.
"""
# Set-operation CTEs can't be safely pushed into via the first arm alone.
if _cte_has_set_operation(cte_query):
return None

projection_map = _build_cte_projection_map(cte_query)
filter_ast = parse_filter(filter_str)

# First pass: plan the rewrites by walking the AST. Role-qualified refs
# appear as Subscript(Column(base), Column/Lambda(role)) and are handled
# whole; plain Column refs are handled individually. Collect rewrites
# into buffers so we can bail out atomically if any ref can't be pushed.
subscript_rewrites: list[tuple[ast.Subscript, ast.Column]] = []
column_rewrites: list[tuple[ast.Column, ast.Column]] = []
# Columns that are children of a rewritten Subscript — exclude from the
# plain-Column pass so we don't double-process them.
handled_col_ids: set[int] = set()

# Users can write filters like v3.date.date_id[order] >= 20240101 where
# [order] is a role, a disambiguator when the same dimension is linked to
# the fact multiple times. A filter like that would be stored in the AST as:
# BinaryOp(
# >=,
# Subscript(
# expr=Column("v3.date.date_id"),
# index=Column("order"),
# ),
# Literal(20240101),
# )
for subscript in filter_ast.find_all(ast.Subscript):
# Skip ones whose target isn't a Column as these are real SQL array subscripts, not role refs
if not isinstance(subscript.expr, ast.Column):
continue # pragma: no cover
# Reconstruct the original role-qualified form: base = "v3.date.date_id", role = "order"
base = get_column_full_name(subscript.expr)
role = extract_subscript_role(subscript)
if not role:
continue # pragma: no cover
full_name = f"{base}[{role}]"

rewritten = filter_str
matched = False
for dim_ref, output_col in sorted(
filter_column_aliases.items(),
key=lambda x: -len(x[0]),
):
if dim_ref not in rewritten:
# Look up in the filter alias map. Prefers the role-specific key over the fallback.
output_col = filter_column_aliases.get(
full_name,
) or filter_column_aliases.get(base)
if output_col is None:
continue # pragma: no cover
form = _resolve_pushdown_form(output_col, cte_output_cols, projection_map)
if form is None:
return None
replacement = _column_from_qualified(form)
subscript_rewrites.append((subscript, replacement))

# Safety checks and queue for subscript to column swap
handled_col_ids.add(id(subscript.expr))
if isinstance(subscript.index, ast.Column):
handled_col_ids.add(id(subscript.index))

# Handle dim refs that don't have a role qualifier
for col in filter_ast.find_all(ast.Column):
# Skip columns accounted for in the subscript pass
if id(col) in handled_col_ids:
continue
if output_col not in cte_output_cols:
full_name = get_column_full_name(col)
if not full_name or full_name not in filter_column_aliases:
continue
if output_col in projection_map:
qualified = projection_map[output_col]
if qualified is None:
# Non-column expression under this alias — unsafe to inline.
continue
else:
# Column pruned from the CTE's SELECT; fall through to the bare
# name, which the underlying source still exposes.
qualified = output_col
# Word-boundary-safe replacement so a shorter dim_ref doesn't clobber
# a longer similarly-prefixed one (e.g. ``fact.orders.order_date`` must
# not match ``fact.orders.order_date_extended``). SQL identifier chars
# are ASCII word chars plus underscore; dots terminate identifiers so
# we only guard against alnum/underscore on either side.
pattern = r"(?<![A-Za-z0-9_])" + re.escape(dim_ref) + r"(?![A-Za-z0-9_])"
rewritten = re.sub(pattern, qualified, rewritten)
matched = True
break # One dim ref per filter (filters are single predicates)

return rewritten if matched else None
output_col = filter_column_aliases[full_name]
form = _resolve_pushdown_form(output_col, cte_output_cols, projection_map)
if form is None:
return None
column_rewrites.append((col, _column_from_qualified(form)))

if not subscript_rewrites and not column_rewrites:
return None

# Second pass: apply. Safe to mutate now that every ref has been validated.
for subscript, replacement in subscript_rewrites:
subscript.swap(replacement)
for col, replacement in column_rewrites:
col.swap(replacement)

return filter_ast


def _build_cte_projection_map(cte_query: ast.Query) -> dict[str, str | None]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,13 @@ def parse_dimension_ref(dim_ref: str) -> DimensionRef:
- "v3.customer.name" -> node=v3.customer, col=name, role=None
- "v3.customer.name[order]" -> node=v3.customer, col=name, role=order
- "v3.date.month[customer->registration]" -> node=v3.date, col=month, role=customer->registration

A reference without an extractable node (e.g. a bare ``status``) is
rejected — DJ can't route the reference to a CTE without knowing the
owning node.
"""
from datajunction_server.errors import DJInvalidInputException

# Extract role if present
role = None
if "[" in dim_ref:
Expand All @@ -50,12 +56,13 @@ def parse_dimension_ref(dim_ref: str) -> DimensionRef:

# Split into node and column
parts = dim_part.rsplit(SEPARATOR, 1)
if len(parts) == 2:
node_name, column_name = parts
else: # pragma: no cover
# Assume single part is column name on current node
node_name = ""
column_name = parts[0]
if len(parts) != 2:
raise DJInvalidInputException(
f"Reference `{dim_ref}` is not fully qualified. Use the "
f"`node.column` form (e.g., `v3.order_details.status`) so DJ "
f"can route the reference to the correct node.",
)
node_name, column_name = parts

return DimensionRef(node_name=node_name, column_name=column_name, role=role)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,12 @@ def add_dimensions_from_filters(ctx: "BuildContext") -> None:
# so we don't also add the role-less version in the Column pass below.
subscript_handled_refs: set[str] = set()

# Role markers inside subscripts are parsed as Column nodes (e.g. the
# `to` in `v3.location.country[to]`) but are not real filter column
# references — identify them so the second pass doesn't treat them
# as bare column refs and reject them.
role_marker_ids: set[int] = set()

# First pass: handle Subscript nodes for role-qualified dimension refs.
# SQL like "v3.location.country[customer->home]" is parsed as
# Subscript(Column(v3.location.country), Lambda(customer->home)).
Expand All @@ -352,6 +358,17 @@ def add_dimensions_from_filters(ctx: "BuildContext") -> None:
# Mark this base ref as handled so the Column pass skips it
subscript_handled_refs.add(base_col_ref)

# Mark any Column nodes used as the role marker so we don't raise
# on them as bare refs.
if isinstance(subscript.index, ast.Column):
role_marker_ids.add(id(subscript.index))
for inner_col in (
subscript.index.find_all(ast.Column)
if hasattr(subscript.index, "find_all")
else []
):
role_marker_ids.add(id(inner_col))

if full_name in existing_dims:
continue

Expand Down Expand Up @@ -382,10 +399,13 @@ def add_dimensions_from_filters(ctx: "BuildContext") -> None:
# Second pass: handle regular Column references.
# Skip columns that were already added via the subscript pass above.
for col in filter_ast.find_all(ast.Column):
full_name = get_column_full_name(col)
if not full_name or SEPARATOR not in full_name:
# Simple column name (e.g., "status") - will be resolved from parent node
# Role markers inside subscripts (e.g. `to` in `dim.col[to]`) are
# parsed as Columns but don't refer to real data columns.
if id(col) in role_marker_ids:
continue
full_name = get_column_full_name(col)
if not full_name:
continue # pragma: no cover

# Skip if already handled as a role-qualified subscript ref
if full_name in subscript_handled_refs:
Expand Down
Loading
Loading