From 18e765fb04014787920271aa19412215eba0e779 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Mon, 20 Apr 2026 06:48:12 -0700 Subject: [PATCH 1/5] Unify validation between bulk and single node validators --- .../internal/deployment/orchestrator.py | 38 +- .../internal/deployment/type_inference.py | 273 +++++++++-- .../internal/deployment/utils.py | 94 ++-- .../datajunction_server/internal/nodes.py | 13 +- .../internal/validation.py | 349 +++++++++++++++ .../internal/deployment/orchestration_test.py | 117 +++++ .../deployment/test_type_inference.py | 422 +++++++++++++++++- .../tests/internal/node_validation_test.py | 111 +++++ 8 files changed, 1337 insertions(+), 80 deletions(-) diff --git a/datajunction-server/datajunction_server/internal/deployment/orchestrator.py b/datajunction-server/datajunction_server/internal/deployment/orchestrator.py index 59e1cd768..ead9935af 100644 --- a/datajunction-server/datajunction_server/internal/deployment/orchestrator.py +++ b/datajunction-server/datajunction_server/internal/deployment/orchestrator.py @@ -24,7 +24,7 @@ from datajunction_server.database.history import History from datajunction_server.database.metricmetadata import MetricMetadata from datajunction_server.database.namespace import NodeNamespace -from datajunction_server.database.node import NodeRelationship +from datajunction_server.database.node import MissingParent, NodeRelationship from datajunction_server.database.partition import Partition from datajunction_server.database.tag import Tag from datajunction_server.database.user import User, OAuthProvider @@ -37,6 +37,7 @@ ErrorCode, ) from datajunction_server.internal.deployment.utils import ( + classify_parents, extract_node_graph, topological_levels, DeploymentContext, @@ -3122,6 +3123,24 @@ def _create_or_update_node( new_node.tags = tags # type: ignore return new_node + @staticmethod + def _classify_parents( + spec: NodeSpec, + dep_names: list[str], + dependency_nodes: dict[str, Node], + ) -> tuple[list[Node], 
list[str]]: + """Split a spec's dependency names into resolved parents and missing names. + + Thin wrapper around the shared utils.classify_parents helper — derives + the is_derived_metric flag from the spec so callers don't have to. + """ + is_derived_metric = ( + spec.node_type == NodeType.METRIC + and spec.query_ast is not None + and spec.query_ast.select.from_ is None + ) + return classify_parents(is_derived_metric, dep_names, dependency_nodes) + async def _create_node_revision( self, new_node: Node, @@ -3132,11 +3151,11 @@ async def _create_node_revision( """Create node revision with inferred columns and dependencies""" existing = self.registry.nodes.get(result.spec.rendered_name) old_node_revision = existing.current if existing else None - parents = [ - dependency_nodes.get(parent) - for parent in node_graph.get(result.spec.rendered_name, []) - if parent in dependency_nodes - ] + parents, missing_parent_names = self._classify_parents( + result.spec, + node_graph.get(result.spec.rendered_name, []), + dependency_nodes, + ) if result.spec.node_type != NodeType.SOURCE: # Pick the first parent with a non-virtual catalog to assign as the # catalog inherited from source parents. 
@@ -3175,11 +3194,8 @@ async def _create_node_revision( node=new_node, catalog=catalog, status=result.status, - parents=[ - dependency_nodes.get(parent) - for parent in node_graph.get(result.spec.rendered_name, []) - if parent in dependency_nodes - ], + parents=parents, + missing_parents=[MissingParent(name=n) for n in missing_parent_names], created_by_id=self.context.current_user.id, custom_metadata=result.spec.custom_metadata, # Initialize to empty so _deploy_links can append without triggering a diff --git a/datajunction-server/datajunction_server/internal/deployment/type_inference.py b/datajunction-server/datajunction_server/internal/deployment/type_inference.py index 36b55bde4..ebc6c4af4 100644 --- a/datajunction-server/datajunction_server/internal/deployment/type_inference.py +++ b/datajunction-server/datajunction_server/internal/deployment/type_inference.py @@ -202,7 +202,10 @@ def _resolve_query( def _resolve_inline_table(query: ast.Query) -> OutputColumns: """Resolve columns for a VALUES expression with explicit column aliases. - Infers types from the first row of values when available. + For each column, scans all rows for the first non-null literal and uses + its type. If every row has NULL at that position, falls back to NullType + rather than UnknownType (the column is genuinely nullable-only, not just + unresolvable). 
""" select = query.select if not isinstance(select, ast.InlineTable): # pragma: no cover @@ -212,21 +215,33 @@ def _resolve_inline_table(query: ast.Query) -> OutputColumns: [col.alias_or_name.name for col in select._columns] if select._columns else [] ) - first_row = select.values[0] if select.values else [] + def _literal_type(val: ast.Node) -> ColumnType | None: + if isinstance(val, ast.Null): + return None # skip — check later rows for a typed value + if isinstance(val, ast.Number): + return val.type + if isinstance(val, ast.String): + return StringType() + if isinstance(val, ast.Boolean): + return BooleanType() + return UnknownType() + result: OutputColumns = [] for i, name in enumerate(col_names): - if i < len(first_row): - val = first_row[i] - if isinstance(val, ast.Number): - result.append((name, val.type)) - elif isinstance(val, ast.String): - result.append((name, StringType())) - elif isinstance(val, ast.Boolean): - result.append((name, BooleanType())) - else: - result.append((name, UnknownType())) - else: - result.append((name, UnknownType())) + col_type: ColumnType = UnknownType() + saw_only_null = True + for row in select.values: + if i >= len(row): + continue + inferred = _literal_type(row[i]) + if inferred is None: + continue # NULL — try next row + saw_only_null = False + col_type = inferred + break + if saw_only_null and select.values: + col_type = NullType() + result.append((name, col_type)) return result @@ -243,22 +258,29 @@ def _build_table_scope( """ Build a mapping of table alias/name → {column_name: column_type} for all tables available in this query's scope. Returns (scope, errors). + + Derived-metric detection triggers only when there is neither a FROM clause + nor any lateral views — otherwise we'd miss table-generating LATERAL VIEW + EXPLODE(sequence(...)) constructs used as FROMless series generators. 
""" - if select.from_ is None: + if select.from_ is None and not select.lateral_views: return {"__derived__": _build_derived_scope(parent_columns_map)}, [] scope: TableScope = {} errors: list[str] = [] - for relation in select.from_.relations: - tables, errs = _collect_tables_from_relation( - relation, - parent_columns_map, - cte_registry, + if select.from_ is not None: + for relation in select.from_.relations: + tables, errs = _collect_tables_from_relation( + relation, + parent_columns_map, + cte_registry, + ) + scope.update(tables) + errors.extend(errs) + for idx, view in enumerate(select.lateral_views): + scope.update( + _collect_lateral_view_columns(view, scope, idx=idx, errors=errors), ) - scope.update(tables) - errors.extend(errs) - for view in select.lateral_views: - scope.update(_collect_lateral_view_columns(view, scope)) return scope, errors @@ -281,8 +303,14 @@ def _collect_tables_from_relation( node: ast.Node, parent_columns_map: ParentColumnsMap, cte_registry: dict[str, OutputColumns], + outer_scope: TableScope | None = None, ) -> tuple[TableScope, list[str]]: - """Collect table scopes from a FROM relation. Returns (tables, errors).""" + """Collect table scopes from a FROM relation. Returns (tables, errors). + + ``outer_scope`` carries columns already in scope from earlier parts of the + FROM clause. Needed so JOIN right sides like ``CROSS JOIN UNNEST(t.arr)`` + can see the left table's columns when resolving their argument. + """ result: TableScope = {} errors: list[str] = [] @@ -291,15 +319,19 @@ def _collect_tables_from_relation( node.primary, parent_columns_map, cte_registry, + outer_scope=outer_scope, ) result.update(tables) errors.extend(errs) for ext in node.extensions: if isinstance(ext, ast.Join): # pragma: no branch + # Combine accumulated scope + any outer scope for visibility. 
+ combined_outer = {**(outer_scope or {}), **result} tables, errs = _collect_tables_from_relation( ext.right, parent_columns_map, cte_registry, + outer_scope=combined_outer, ) result.update(tables) errors.extend(errs) @@ -325,13 +357,54 @@ def _collect_tables_from_relation( errors.extend(sub_errors) elif isinstance(node, ast.InlineTable): - alias = node.alias.name if node.alias else "__inline__" + # `VALUES (…) tab(c1, c2)` stores the `tab` alias on node.name, not + # node.alias. Accept both so the outer query can reference `tab.c1`. + if node.alias is not None: + alias = node.alias.name + elif node.name is not None and node.name.name: + alias = node.name.name + else: + alias = "__inline__" # pragma: no cover inline_columns = _resolve_inline_table(ast.Query(select=node)) # type: ignore[arg-type] result[alias] = {name: typ for name, typ in inline_columns} elif isinstance(node, ast.FunctionTableExpression): # pragma: no branch + # CROSS JOIN UNNEST(arr) AS t(x) / LATERAL table-function forms. + # Resolve element types from the function's argument so downstream + # refs to t.x get a concrete type instead of UnknownType. Merge the + # outer scope (i.e., sibling tables already in the FROM chain) so + # `UNNEST(sibling.arr)` can reach back into the previous table. alias = node.alias.name if node.alias else "__func_table__" - func_cols = {col.name.name: UnknownType() for col in (node.column_list or [])} + col_list = node.column_list or [] + resolution_scope = {**(outer_scope or {}), **result} + element_types = _resolve_lateral_element_types( + node, + resolution_scope, + errors=errors, + ) + func_name = ( + node.name.name.upper() if hasattr(node, "name") and node.name else "" + ) + is_posexplode = "POS" in func_name + # Struct-unpacking: UNNEST(array<struct<a, b>>) AS t(c1, c2) positions + # each alias against a struct field.
+ if ( + not is_posexplode + and len(col_list) > 1 + and len(element_types) == 1 + and isinstance(element_types[0], StructType) + ): + element_types = [f.type for f in element_types[0].fields] + func_cols: dict[str, ColumnType] = {} + for i, col in enumerate(col_list): + if is_posexplode and i == 0: + func_cols[col.name.name] = IntegerType() + elif is_posexplode and i == 1 and element_types: + func_cols[col.name.name] = element_types[0] + elif not is_posexplode and i < len(element_types): + func_cols[col.name.name] = element_types[i] + else: + func_cols[col.name.name] = UnknownType() if func_cols: # pragma: no branch result[alias] = func_cols @@ -341,22 +414,41 @@ def _collect_tables_from_relation( def _collect_lateral_view_columns( view: ast.LateralView, from_scope: TableScope, + idx: int = 0, + errors: list[str] | None = None, ) -> TableScope: """Collect columns from a LATERAL VIEW (e.g., EXPLODE) expression. Resolves element types from the source column's ListType/MapType when possible, falls back to UnknownType otherwise. + + ``idx`` distinguishes multiple anonymous lateral views in the same SELECT + (each default-aliased to ``__lateral__`` would otherwise collide and + overwrite prior columns). + + ``errors`` receives any element-type-resolution errors (e.g., the EXPLODE + argument references a nonexistent column) so callers can surface them. """ func = view.func - alias = func.alias.name if func.alias else "__lateral__" + alias = func.alias.name if func.alias else f"__lateral_{idx}__" col_list = func.column_list or [] if not col_list: return {} - element_types = _resolve_lateral_element_types(func, from_scope) + element_types = _resolve_lateral_element_types(func, from_scope, errors=errors) func_name = func.name.name.upper() if hasattr(func, "name") and func.name else "" is_posexplode = "POS" in func_name + # Struct-unpacking: EXPLODE(array<struct<a, b>>) AS (c1, c2) aliases the + # struct fields positionally.
+ if ( + not is_posexplode + and len(col_list) > 1 + and len(element_types) == 1 + and isinstance(element_types[0], StructType) + ): + element_types = [f.type for f in element_types[0].fields] + lateral_cols: dict[str, ColumnType] = {} for i, col in enumerate(col_list): if is_posexplode and i == 0: @@ -374,8 +466,15 @@ def _collect_lateral_view_columns( def _resolve_lateral_element_types( func: ast.FunctionTableExpression, from_scope: TableScope, + errors: list[str] | None = None, ) -> list[ColumnType]: - """Resolve element types for an EXPLODE/UNNEST function argument.""" + """Resolve element types for an EXPLODE/UNNEST function argument. + + When ``errors`` is provided and type resolution of the argument fails + (e.g., the referenced column doesn't exist), the specific error message is + appended so callers can surface the real cause rather than a downstream + "Unable to infer type" coming from the resulting UnknownType columns. + """ if not func.args: return [] # pragma: no cover @@ -389,9 +488,18 @@ def _resolve_lateral_element_types( scope = TypeScope(tables=from_scope, parent_map={}) try: col_type = _resolve_expr_type(arg, scope) - except (TypeResolutionError, Exception): + except TypeResolutionError as exc: + if errors is not None: + errors.append(str(exc)) + return [] + except Exception: # pragma: no cover return [] + if errors is not None: + # Also surface any errors that accumulated inside the throwaway scope + # (e.g., unresolved sub-refs) rather than silently dropping them. + errors.extend(scope.errors) + if isinstance(col_type, ListType): return [col_type.element.type] if isinstance(col_type, MapType): @@ -421,6 +529,14 @@ def _resolve_projection_expr( ) return _resolve_wildcard(table_alias, scope.tables) + # Projection-based table-generating function: `explode(x) AS (c1, c2)` in the + # SELECT list. Expands into one output column per name in column_list, typed + # from the exploded element's ListType/MapType. 
Without this, the alias list + # would silently disappear and downstream refs to c1/c2 would resolve against + # nothing. + if isinstance(expr, ast.FunctionTableExpression): + return _resolve_projection_function_table(expr, scope) + # Unwrap Alias(child=..., alias="name") → resolve the child, use the alias as name if isinstance(expr, ast.Alias): output_name = expr.alias.name if expr.alias else _get_output_name(expr.child) @@ -450,6 +566,64 @@ def _resolve_projection_expr( return [(output_name, col_type)] +def _resolve_projection_function_table( + func: ast.FunctionTableExpression, + scope: TypeScope, +) -> OutputColumns: + """Expand a table-generating function used in a SELECT projection. + + Supports Spark's ``explode(x) AS (c1, c2)`` / ``posexplode(x) AS (i, c)`` + forms. Element types come from the input's ListType/MapType; names come + from ``func.column_list``. Mirrors the LATERAL VIEW handler so projection + and lateral-view EXPLODE behave the same. + + Struct-unpacking: ``EXPLODE(array<struct<a, b>>) AS (c1, c2)`` aliases the + struct fields positionally — c1 gets field a's type, c2 gets field b's. + """ + # `EXPLODE(x) AS (c1, c2)` (parenthesized alias list) populates column_list. + # `EXPLODE(x) AS c` (single-name form) leaves column_list empty and puts + # the name on func.alias. Handle both. + col_names: list[str] + if func.column_list: + col_names = [c.name.name for c in func.column_list] + elif func.alias is not None: + col_names = [func.alias.name] + else: + return [] # pragma: no cover + + element_types = _resolve_lateral_element_types( + func, + scope.tables, + errors=scope.errors, + ) + func_name = func.name.name.upper() if hasattr(func, "name") and func.name else "" + is_posexplode = "POS" in func_name + + # If the element is a single struct and the user aliased N columns, unpack + # the struct fields positionally. Matches legacy compile behavior for + # EXPLODE(array<struct<a, b>>) AS (c1, c2, ...) in a projection.
+ if ( + not is_posexplode + and len(col_names) > 1 + and len(element_types) == 1 + and isinstance(element_types[0], StructType) + ): + struct_type = element_types[0] + element_types = [f.type for f in struct_type.fields] + + output: OutputColumns = [] + for i, out_name in enumerate(col_names): + if is_posexplode and i == 0: + output.append((out_name, IntegerType())) + elif is_posexplode and i == 1 and element_types: + output.append((out_name, element_types[0])) + elif not is_posexplode and i < len(element_types): + output.append((out_name, element_types[i])) + else: + output.append((out_name, UnknownType())) + return output + + def _resolve_wildcard( table_alias: Optional[str], tables: TableScope, @@ -644,10 +818,18 @@ def _resolve_column_type( full_id = col.identifier() if full_id in derived: return derived[full_id] - if col.namespace: # pragma: no branch + if col.namespace: result = _resolve_dj_node_column(col, scope.parent_map) if result is not None: return result # pragma: no cover + # Multi-part namespaced refs in a derived metric that don't match + # a parent metric or attribute are likely dim-attribute references + # resolved via dim links at query build time (e.g., + # `common.dimensions.time.date.dateint` used inside a window + # function OVER (ORDER BY ...)). Fall back to UnknownType rather + # than raising so legitimate dim-link refs don't get falsely + # rejected — deployment enforces strictly via its own call. + return UnknownType() raise TypeResolutionError( f"Column `{col}` not found in derived metric scope.", ) @@ -700,8 +882,27 @@ def _resolve_column_type( if result is not None: return result - # Not resolvable - return UnknownType since it may be a dimension - # ref that gets resolved via dimension links at query time. + # Nothing matched. Fall through to UnknownType — the ref may be a + # legit dim-attribute reference resolved at query build time. + # + # Only surface a specific namespace-not-found error for SINGLE-segment + # namespaces. 
Those are the classic typo/wrong-alias shapes + # (e.g., `tenure.tenure` in an outer query where `tenure` was a + # subquery alias). Multi-segment namespaces like + # `common.dimensions.xp.allocation_day.days_since_allocation` are + # almost always real DJ dim-node paths. + if len(col.namespace) == 1: + all_parts = [n.name for n in col.namespace] + [col.name.name] + any_prefix_is_parent = any( + ".".join(all_parts[:i]) in scope.parent_map + for i in range(1, len(all_parts)) + ) + if not any_prefix_is_parent: + scope.errors.append( + f"Column `{col}` references namespace `{table_alias}` " + f"which is not a table alias in this scope, a struct " + f"column, or a known DJ node.", + ) return UnknownType() # Unqualified column - search all tables @@ -851,8 +1052,10 @@ def _resolve_expr_type( ) if isinstance(expr, ast.Subscript): - # Subscript: col['key'] (map or struct access) or col[0] (array access). + # Subscript: col['key'] (map/struct access) or col[0] (array access). base_type = _resolve_expr_type(expr.expr, scope) + if isinstance(base_type, ListType): + return base_type.element.type if isinstance(base_type, MapType): return base_type.value.type if isinstance(base_type, StructType) and isinstance(expr.index, ast.String): diff --git a/datajunction-server/datajunction_server/internal/deployment/utils.py b/datajunction-server/datajunction_server/internal/deployment/utils.py index 43d2319e9..ae48dbb94 100644 --- a/datajunction-server/datajunction_server/internal/deployment/utils.py +++ b/datajunction-server/datajunction_server/internal/deployment/utils.py @@ -2,10 +2,11 @@ from collections import defaultdict from dataclasses import dataclass +from typing import Iterable from datajunction_server.internal.caching.interface import Cache from datajunction_server.service_clients import QueryServiceClient from datajunction_server.database.user import User -from datajunction_server.database.node import NodeRevision +from datajunction_server.database.node import Node, 
NodeRevision from datajunction_server.models.deployment import ( NodeSpec, CubeSpec, @@ -13,6 +14,7 @@ MetricSpec, TransformSpec, ) +from datajunction_server.models.node_type import NodeType from datajunction_server.utils import SEPARATOR from datajunction_server.sql.parsing import ast from datajunction_server.sql.parsing.ast import fast_parse_mode @@ -23,6 +25,66 @@ logger = logging.getLogger(__name__) +def extract_upstream_candidates( + query_ast: ast.Query, + is_metric: bool, +) -> set[str]: + """Scan a parsed query for upstream node name candidates. + + Returns ast.Table references (excluding CTE names). For derived metrics + (MetricSpec/MetricRevision with no FROM clause), also returns namespaced + column identifiers plus their parent paths — these are speculative prefix + candidates that callers must resolve against the DB to distinguish real + parents from dim-attribute refs that aren't SQL parents. + + Shared between the deployment path (extract_node_graph) and the + single-node path (validate_node_data) so both see identical candidates. + """ + cte_names = {cte.alias_or_name.identifier() for cte in query_ast.ctes} + tables = { + t.name.identifier() + for t in query_ast.find_all(ast.Table) + if t.name.identifier() not in cte_names + } + + if is_metric and not tables and query_ast.select.from_ is None: + for col in query_ast.find_all(ast.Column): + col_identifier = col.identifier() + if SEPARATOR in col_identifier: + tables.add(col_identifier) + parent_path = col_identifier.rsplit(SEPARATOR, 1)[0] + if SEPARATOR in parent_path: + tables.add(parent_path) + + return tables + + +def classify_parents( + is_derived_metric: bool, + dep_names: Iterable[str], + dependency_nodes: dict[str, Node], +) -> tuple[list[Node], list[str]]: + """Split a node's dependency names into resolved parents and missing names. 
+ + For derived metrics the dep name set includes speculative namespace-prefix + candidates for dim-attribute refs like `ns.dim.col`, which aren't real + parents. Keep only METRIC-typed resolutions, and skip MissingParent + emission since unresolved prefixes aren't genuine missing references. + """ + resolved: list[Node] = [] + missing: list[str] = [] + for name in dep_names: + node = dependency_nodes.get(name) + if node is None: + if not is_derived_metric: + missing.append(name) + continue + if is_derived_metric and node.type != NodeType.METRIC: + continue + resolved.append(node) + return resolved, missing + + def extract_node_graph(nodes: list[NodeSpec]) -> dict[str, list[str]]: """ Extract the node graph from a list of nodes. @@ -63,31 +125,11 @@ def _find_upstreams_for_node(node: NodeSpec) -> tuple[str, list[str], ast.Query node.rendered_name, ) query_ast = parse(query_str) - cte_names = [cte.alias_or_name.identifier() for cte in query_ast.ctes] - tables = { - t.name.identifier() - for t in query_ast.find_all(ast.Table) - if t.name.identifier() not in cte_names - } - - # For derived metrics (no FROM clause), look for metric references - # in Column nodes. 
E.g., SELECT default.metric_a / default.metric_b - if ( - isinstance(node, MetricSpec) - and not tables - and query_ast.select.from_ is None - ): - for col in query_ast.find_all(ast.Column): - col_identifier = col.identifier() - if SEPARATOR in col_identifier: # pragma: no branch - # Add full identifier (might be a metric node) - tables.add(col_identifier) - # Also add parent path (might be dimension.column) - parent_path = col_identifier.rsplit(SEPARATOR, 1)[0] - if SEPARATOR in parent_path: # Only if there's still a namespace - tables.add(parent_path) - - return node.rendered_name, sorted(list(tables)), query_ast + candidates = extract_upstream_candidates( + query_ast, + is_metric=isinstance(node, MetricSpec), + ) + return node.rendered_name, sorted(candidates), query_ast if isinstance(node, CubeSpec): dimension_nodes = [dim.rsplit(".", 1)[0] for dim in node.rendered_dimensions] return node.rendered_name, node.rendered_metrics + dimension_nodes, None diff --git a/datajunction-server/datajunction_server/internal/nodes.py b/datajunction-server/datajunction_server/internal/nodes.py index 2258c249f..120b5afe4 100644 --- a/datajunction-server/datajunction_server/internal/nodes.py +++ b/datajunction-server/datajunction_server/internal/nodes.py @@ -56,7 +56,11 @@ schedule_materialization_jobs_bg, ) from datajunction_server.internal.history import ActivityType, EntityType -from datajunction_server.internal.validation import NodeValidator, validate_node_data +from datajunction_server.internal.validation import ( + NodeValidator, + validate_node_data, + validate_node_data_v2, +) from datajunction_server.models.attribute import ( AttributeTypeIdentifier, ColumnAttributes, @@ -3111,8 +3115,11 @@ async def revalidate_node( errors=errors, ) - # Revalidate all other node types - node_validator = await validate_node_data(current_node_revision, session) + # Revalidate all other node types. 
+ # NOTE: uses validate_node_data_v2 so POST /nodes/{name}/validate/ exercises + # the new validator. Every other call site (create flows, propagation, + # downstream reference resolution) stays on legacy until cutover. + node_validator = await validate_node_data_v2(current_node_revision, session) # Compile and save query AST if update_query_ast and background_tasks: diff --git a/datajunction-server/datajunction_server/internal/validation.py b/datajunction-server/datajunction_server/internal/validation.py index c5d44f306..76d509951 100644 --- a/datajunction-server/datajunction_server/internal/validation.py +++ b/datajunction-server/datajunction_server/internal/validation.py @@ -15,6 +15,11 @@ DJInvalidMetricQueryException, ErrorCode, ) +from datajunction_server.internal.deployment.type_inference import validate_node_query +from datajunction_server.internal.deployment.utils import ( + classify_parents, + extract_upstream_candidates, +) from datajunction_server.models.base import labelize from datajunction_server.models.node import NodeRevisionBase, NodeStatus from datajunction_server.models.node_type import NodeType @@ -416,6 +421,350 @@ async def validate_node_data( return node_validator +# --------------------------------------------------------------------------- +# New validate_node_data — shares primitives with the deployment path. 
+# --------------------------------------------------------------------------- + + +def _format_query_for_validation(validated_node: NodeRevision) -> str: + """Apply metric-aliasing when needed so the parsed AST matches the shape + validate_node_query expects (mirrors the legacy path).""" + if validated_node.type == NodeType.METRIC: + return NodeRevision.format_metric_alias( + validated_node.query, # type: ignore + validated_node.name, + ) + return validated_node.query # type: ignore + + +# _map_validation_error intentionally removed — see internal/deployment/validation.py +# bulk_validate_node_data which wraps every validate_node_query error with +# ErrorCode.TYPE_INFERENCE and does no further filtering. The single-node path +# does the same below so its error surface matches deployment exactly. + + +def _build_columns_from_output( + output_columns: list, + query_ast: ast.Query, + validated_node: NodeRevision, +) -> tuple[list[Column], dict[str, str]]: + """Build Column objects for every AST projection item. + + validate_node_query drops columns whose type resolution raised (e.g., unresolved + dim-attribute refs) — legacy compile() produced all of them via DB lookups. + To preserve the legacy column list we walk the AST projection directly for + ordering and names, then look up types from output_columns where available. + Missing types fall back to UnknownType so downstream consumers still see the + column entry. 
+ """ + from datajunction_server.sql.parsing.types import UnknownType + + try: + column_mapping = {col.name: col for col in validated_node.columns} + except MissingGreenlet: # pragma: no cover + column_mapping = {} # pragma: no cover + + types_by_name = {name: col_type for name, col_type in output_columns} + + columns: list[Column] = [] + type_inference_failures: dict[str, str] = {} + for idx, expr in enumerate(query_ast.select.projection): # type: ignore[attr-defined] + col_name = expr.alias_or_name.name # type: ignore[union-attr] + col_type = types_by_name.get(col_name, UnknownType()) + existing = column_mapping.get(col_name) + columns.append( + Column( + name=col_name.lower() + if validated_node.type != NodeType.METRIC + else col_name, + display_name=existing.display_name + if existing and existing.display_name + else labelize(col_name), + type=col_type, + attributes=[ + ColumnAttribute( + attribute_type_id=attr.attribute_type_id, + attribute_type=attr.attribute_type, + ) + for attr in existing.attributes + ] + if existing + else [], + dimension=existing.dimension if existing else None, + order=idx, + ), + ) + return columns, type_inference_failures + + +@timed( + "dj.node_validation.v2.ms", + lambda data, session: {"node_type": str(data.type)}, +) +async def validate_node_data_v2( + data: Union[NodeRevisionBase, NodeRevision], + session: AsyncSession, +) -> NodeValidator: + """ + New node validator — shares primitives (extract_upstream_candidates, + classify_parents, validate_node_query) with the deployment path so behavior + matches bulk deployment. + + Currently called only by revalidate_node (POST /nodes/{name}/validate/). + Legacy validate_node_data stays for all other callers and for shadow-mode + diff comparison; once divergences are fixed, swap names and delete it.
+ """ + node_validator = NodeValidator() + + # Wrap NodeRevisionBase into a NodeRevision for internal consistency + if isinstance(data, NodeRevision): + validated_node = data + else: + node = Node(name=data.name, type=data.type) + validated_node = NodeRevision(**data.model_dump()) + validated_node.node = node + + # --- Step 1: parse (metric-aliased if needed) --- + try: + formatted_query = _format_query_for_validation(validated_node) + query_ast = parse(formatted_query) # type: ignore + except (DJParseException, ValueError, SqlSyntaxError) as exc: + node_validator.status = NodeStatus.INVALID + node_validator.errors.append( + DJError(code=ErrorCode.INVALID_SQL_QUERY, message=str(exc)), + ) + return node_validator + + # Stable col{n} names for unnamed projections (matches legacy behavior) + query_ast.select.add_aliases_to_unnamed_columns() + + # --- Step 2: extract upstream candidates (SHARED with deployment) --- + is_metric = validated_node.type == NodeType.METRIC + candidates = extract_upstream_candidates(query_ast, is_metric=is_metric) + + # --- Step 3: bulk load parents (ONE DB query instead of per-table compile) --- + # Use the default load options so dependencies can be serialized by API + # response models that access .availability, .materializations, .parents, + # .dimension_links (legacy compile implicitly traversed these). 
+ dep_nodes: Dict[str, Node] = {} + if candidates: + loaded = await Node.get_by_names(session, sorted(candidates)) + dep_nodes = {n.name: n for n in loaded} + + # --- Step 4: classify parents (SHARED with deployment) --- + is_derived_metric = is_metric and query_ast.select.from_ is None + parents, missing = classify_parents( + is_derived_metric, + candidates, + dep_nodes, + ) + + node_validator.dependencies_map = { + parent.current: [] for parent in parents if parent.current + } + node_validator.missing_parents_map = {name: [] for name in missing} + + # --- Step 5: re-parse exotic column types before type inference --- + _reparse_parent_column_types(node_validator.dependencies_map) + + # --- Step 6: SQL validation + type inference --- + parent_columns_map: dict[str, dict] = {} + for parent in parents: + if parent.current and parent.current.columns: + parent_columns_map[parent.name] = { + col.name: col.type for col in parent.current.columns + } + # Include missing parents with empty columns so validate_node_query doesn't + # choke on unknown-table errors for refs we've already tracked. + for name in missing: + parent_columns_map.setdefault(name, {}) + + validation = validate_node_query( + formatted_query, + parent_columns_map, + pre_parsed=query_ast, + ) + + # --- Step 7: surface every validate_node_query error directly. Same code + # and no filtering as bulk_validate_node_data does for deployment — this is + # the parity point. If tightening or codes need to diverge later, revisit. 
+ if validation.errors: + for msg in validation.errors: + node_validator.errors.append( + DJError(code=ErrorCode.TYPE_INFERENCE, message=msg), + ) + node_validator.status = NodeStatus.INVALID + + # --- Step 8: build columns from output_columns --- + columns, type_inference_failures = _build_columns_from_output( + validation.output_columns, + query_ast, + validated_node, + ) + node_validator.columns = columns + if type_inference_failures: + node_validator.status = NodeStatus.INVALID + + # --- Step 9: metric-specific checks (cross-fact shared-dim + MISSING_COLUMNS) --- + if is_metric and node_validator.dependencies_map: + metric_parents = [ + parent + for parent in node_validator.dependencies_map.keys() + if parent.type == NodeType.METRIC + ] + non_metric_parents = [ + parent + for parent in node_validator.dependencies_map.keys() + if parent.type != NodeType.METRIC + ] + + if metric_parents and len(metric_parents) > 1: + # Cross-fact derived metric: all base metrics must share >=1 dimension + from datajunction_server.sql.dag import get_dimensions + + all_dimension_sets: List[Set[str]] = [] + for base_metric in metric_parents: + dims = await get_dimensions( + session, + base_metric.node, + with_attributes=True, + ) + all_dimension_sets.append({d.name for d in dims}) + + if all_dimension_sets: # pragma: no branch + shared = all_dimension_sets[0] + for ds in all_dimension_sets[1:]: + shared = shared & ds + if not shared: + names = [m.name for m in metric_parents] + node_validator.status = NodeStatus.INVALID + node_validator.errors.append( + DJError( + code=ErrorCode.INVALID_PARENT, + message=( + f"Cannot create derived metric from base metrics with no shared " + f"dimensions. The following metrics have no dimensions in common: " + f"{', '.join(names)}. Cross-fact derived metrics require " + f"at least one shared dimension for joining." 
+ ), + ), + ) + elif not metric_parents: + # Standard metric: SELECT cols must exist on non-metric parents + all_available_columns = { + col.name + for upstream_node in non_metric_parents + for col in upstream_node.columns + } + metric_expression = query_ast.select.projection[0] + referenced_columns = metric_expression.find_all(ast.Column) + missing_columns = [ + col.alias_or_name.name + for col in referenced_columns + if not col.namespace + and col.alias_or_name.name not in all_available_columns + ] + if missing_columns: + node_validator.status = NodeStatus.INVALID + node_validator.errors.append( + DJError( + code=ErrorCode.MISSING_COLUMNS, + message=( + f"Metric definition references missing columns: " + f"{', '.join(missing_columns)}" + ), + ), + ) + + # --- Step 10: invalid-parent check --- + invalid_parents = { + parent.name + for parent in node_validator.dependencies_map + if parent.type != NodeType.SOURCE and parent.status == NodeStatus.INVALID + } + if invalid_parents: + node_validator.errors.append( + DJError( + code=ErrorCode.INVALID_PARENT, + message=f"References invalid parent node(s) {','.join(invalid_parents)}", + ), + ) + node_validator.status = NodeStatus.INVALID + + # --- Step 11: required dimensions --- + try: + parent_columns = [ + col + for parent in node_validator.dependencies_map.keys() + for col in parent.columns + ] + required_dim_strings = [ + col.full_name() if isinstance(col, Column) else col + for col in validated_node.required_dimensions + ] + ( + invalid_required_dimensions, + matched_bound_columns, + ) = await find_required_dimensions( + session, + required_dim_strings, + parent_columns, + ) + node_validator.required_dimensions = matched_bound_columns + except MissingGreenlet: # pragma: no cover + invalid_required_dimensions = set() + node_validator.required_dimensions = [] + + # --- Step 12: final error assembly for missing parents, type-inference, and + # invalid required dims (matches legacy code shapes) --- + if ( + 
node_validator.missing_parents_map + or type_inference_failures + or invalid_required_dimensions + ): + node_validator.status = NodeStatus.INVALID + if node_validator.missing_parents_map: + node_validator.errors.append( + DJError( + code=ErrorCode.MISSING_PARENT, + message=( + f"Node definition contains references to nodes that do not " + f"exist: {','.join(node_validator.missing_parents_map.keys())}" + ), + debug={ + "missing_parents": list( + node_validator.missing_parents_map.keys(), + ), + }, + ), + ) + for column, message in type_inference_failures.items(): + node_validator.errors.append( + DJError( + code=ErrorCode.TYPE_INFERENCE, + message=message, + debug={"columns": [column]}, + ), + ) + if invalid_required_dimensions: + node_validator.errors.append( + DJError( + code=ErrorCode.INVALID_COLUMN, + message=( + "Node definition contains references to columns as " + "required dimensions that are not on parent nodes." + ), + debug={ + "invalid_required_dimensions": list( + invalid_required_dimensions, + ), + }, + ), + ) + + return node_validator + + def validate_metric_query(query_ast: ast.Query, name: str) -> None: """ Validate a metric query. diff --git a/datajunction-server/tests/internal/deployment/orchestration_test.py b/datajunction-server/tests/internal/deployment/orchestration_test.py index e4dd78535..3da1a0089 100644 --- a/datajunction-server/tests/internal/deployment/orchestration_test.py +++ b/datajunction-server/tests/internal/deployment/orchestration_test.py @@ -2137,6 +2137,123 @@ async def test_execute_deployment_plan_dry_run_savepoint_rollback( assert downstream == [] +class TestClassifyParents: + """Unit tests for DeploymentOrchestrator._classify_parents. 
+ + Covers parity gaps with the single-node validation path: + - derived metrics drop non-metric resolutions and never emit MissingParent + - regular queries emit MissingParent for unresolved ast.Table names + - regular queries keep all resolved Node objects as parents regardless of type + """ + + @staticmethod + def _make_node(name: str, node_type): + from datajunction_server.models.node_type import NodeType as _NT + + node = MagicMock() + node.name = name + node.type = node_type if isinstance(node_type, _NT) else _NT(node_type) + return node + + def test_regular_query_unresolved_names_become_missing_parents(self): + """Transform whose ast.Table ref is absent from dependency_nodes emits a + MissingParent entry, preserving the single-node path's behavior.""" + spec = TransformSpec( + name="t", + namespace="default", + query="SELECT a FROM default.missing_src", + ) + resolved, missing = DeploymentOrchestrator._classify_parents( + spec, + dep_names=["default.missing_src"], + dependency_nodes={}, + ) + assert resolved == [] + assert missing == ["default.missing_src"] + + def test_regular_query_resolved_parents_kept_regardless_of_type(self): + """A transform that references a source returns the source Node as a parent.""" + from datajunction_server.models.node_type import NodeType as _NT + + src = self._make_node("default.src", _NT.SOURCE) + spec = TransformSpec( + name="t", + namespace="default", + query="SELECT a FROM default.src", + ) + resolved, missing = DeploymentOrchestrator._classify_parents( + spec, + dep_names=["default.src"], + dependency_nodes={"default.src": src}, + ) + assert resolved == [src] + assert missing == [] + + def test_derived_metric_drops_non_metric_resolutions(self): + """Derived metric that references a dim-attribute like `ns.dim.col` must + not store the dim as a parent — dim-attribute refs aren't SQL parents.""" + from datajunction_server.models.node_type import NodeType as _NT + + metric_a = self._make_node("default.metric_a", _NT.METRIC) 
+ dim_x = self._make_node("default.dim_x", _NT.DIMENSION) + spec = MetricSpec( + name="derived", + namespace="default", + query="SELECT default.metric_a * 2 + default.dim_x.year", + ) + # extract_node_graph emits both full id and parent-path candidates + dep_names = [ + "default.metric_a", + "default.dim_x.year", + "default.dim_x", + ] + dependency_nodes = { + "default.metric_a": metric_a, + "default.dim_x": dim_x, + } + resolved, missing = DeploymentOrchestrator._classify_parents( + spec, + dep_names, + dependency_nodes, + ) + assert resolved == [metric_a] + # Derived metrics never emit MissingParent rows — speculative prefix + # candidates aren't genuine references. + assert missing == [] + + def test_derived_metric_never_emits_missing_parents(self): + """Even when a derived metric's candidate resolves to nothing, no + MissingParent is recorded (candidates are speculative prefix expansions).""" + spec = MetricSpec( + name="derived", + namespace="default", + query="SELECT default.metric_a / default.metric_b", + ) + resolved, missing = DeploymentOrchestrator._classify_parents( + spec, + dep_names=["default.metric_a", "default.metric_b"], + dependency_nodes={}, + ) + assert resolved == [] + assert missing == [] + + def test_metric_with_from_clause_treated_as_regular(self): + """A metric whose query has a FROM clause is NOT a derived metric and + follows the regular-query path (emit MissingParent for unresolved refs).""" + spec = MetricSpec( + name="std_metric", + namespace="default", + query="SELECT SUM(amount) FROM default.missing_fact", + ) + resolved, missing = DeploymentOrchestrator._classify_parents( + spec, + dep_names=["default.missing_fact"], + dependency_nodes={}, + ) + assert resolved == [] + assert missing == ["default.missing_fact"] + + class TestCreateOrUpdateDimensionJoinLink: """Tests for create_or_update_dimension_join_link idempotency.""" diff --git a/datajunction-server/tests/internal/deployment/test_type_inference.py 
b/datajunction-server/tests/internal/deployment/test_type_inference.py index 6567ce184..c612ca679 100644 --- a/datajunction-server/tests/internal/deployment/test_type_inference.py +++ b/datajunction-server/tests/internal/deployment/test_type_inference.py @@ -1588,14 +1588,16 @@ def test_values_in_subquery(self): assert isinstance(result.output_columns[0][1], IntegerType) def test_values_with_null(self): + """NULL in the first row for a column doesn't poison its type — + _resolve_inline_table scans later rows for a typed literal.""" result = validate_node_query( "SELECT id, val FROM (VALUES (1, NULL), (2, 'b')) AS t(id, val)", {}, ) assert result.output_columns[0][0] == "id" assert isinstance(result.output_columns[0][1], IntegerType) - assert result.output_columns[1] == ("val", UnknownType()) - assert any("Unable to infer type for column `val`" in e for e in result.errors) + assert result.output_columns[1] == ("val", StringType()) + assert not result.errors, result.errors def test_values_with_boolean(self): result = validate_node_query( @@ -1792,9 +1794,14 @@ def test_wrong_table_alias(self): ) assert isinstance(result.output_columns[0][1], UnknownType) - def test_derived_metric_dim_not_in_map(self): + def test_derived_metric_dim_not_in_map_is_permissive(self): + """Derived metrics can reference dim attributes that aren't parents — + they resolve at query build time via dim links. The in-memory + inferrer can't know that, so it types them as UnknownType rather than + raising. Matches the regular-query branch's dim-link permissiveness. + """ result = validate_node_query( - "SELECT default.nonexistent_dim.col", + "SELECT default.some_dim.col", _col_map( ( "default.total_revenue", @@ -1802,7 +1809,10 @@ def test_derived_metric_dim_not_in_map(self): ), ), ) - assert any("not found" in e for e in result.errors) + # No "not found" error — the ref is allowed to stay untyped. 
+ assert not any( + "not found in derived metric scope" in e for e in result.errors + ), result.errors def test_multiple_errors_collected(self): """Multiple bad columns → all errors collected, not just the first.""" @@ -2005,3 +2015,405 @@ def test_unknown_type_always_changed(self): ) is True ) + + +# --------------------------------------------------------------------------- +# EXPLODE struct unpacking, FROMless lateral views, unresolved-namespace hints +# --------------------------------------------------------------------------- + + +class TestExplodeStructUnpacking: + """EXPLODE(array>) AS (c1, c2) should alias the struct fields + positionally, matching Spark behavior.""" + + @staticmethod + def _cells_type(): + from datajunction_server.sql.parsing.types import ( + ListType, + NestedField, + StructType, + ) + + return ListType( + element_type=StructType( + NestedField(name="cell_id", field_type=StringType()), + NestedField(name="cell_name", field_type=StringType()), + ), + ) + + def test_projection_explode_struct_array_unpacks_fields(self): + """SELECT test_id, EXPLODE(cells) AS (cell_id, cell_name) FROM t + where cells is array> + → (test_id: bigint, cell_id: string, cell_name: string).""" + result = validate_node_query( + "SELECT test_id, EXPLODE(cells) AS (cell_id, cell_name) FROM t.src", + _col_map( + ( + "t.src", + [ + ("test_id", BigIntType()), + ("cells", self._cells_type()), + ], + ), + ), + ) + assert not result.errors, result.errors + assert result.output_columns == [ + ("test_id", BigIntType()), + ("cell_id", StringType()), + ("cell_name", StringType()), + ] + + def test_projection_explode_scalar_array_single_column(self): + """Regression: EXPLODE(array) AS x produces one int column. 
+ The struct-unpacking branch must NOT kick in for a scalar element type.""" + from datajunction_server.sql.parsing.types import ListType + + result = validate_node_query( + "SELECT id, EXPLODE(nums) AS n FROM t.src", + _col_map( + ( + "t.src", + [ + ("id", IntegerType()), + ("nums", ListType(element_type=IntegerType())), + ], + ), + ), + ) + assert not result.errors, result.errors + assert result.output_columns == [ + ("id", IntegerType()), + ("n", IntegerType()), + ] + + def test_lateral_view_posexplode_struct_keeps_struct_intact(self): + """POSEXPLODE does NOT struct-unpack — first alias is pos (int), second + is the struct element as-is. (LATERAL VIEW form; the projection form + for POSEXPLODE with a parenthesized alias list isn't parseable.)""" + result = validate_node_query( + "SELECT pos, tag FROM t.src LATERAL VIEW POSEXPLODE(cells) v AS pos, tag", + _col_map(("t.src", [("cells", self._cells_type())])), + ) + assert not result.errors, result.errors + assert result.output_columns[0] == ("pos", IntegerType()) + # Second col is the struct element, not one of its fields + assert result.output_columns[1][0] == "tag" + from datajunction_server.sql.parsing.types import StructType + + assert isinstance(result.output_columns[1][1], StructType) + + def test_lateral_view_explode_struct_array_unpacks_fields(self): + """Same struct-unpacking for LATERAL VIEW form.""" + result = validate_node_query( + "SELECT cell_id, cell_name FROM t.src " + "LATERAL VIEW EXPLODE(cells) v AS cell_id, cell_name", + _col_map(("t.src", [("cells", self._cells_type())])), + ) + assert not result.errors, result.errors + assert result.output_columns == [ + ("cell_id", StringType()), + ("cell_name", StringType()), + ] + + def test_projection_explode_struct_in_anonymous_subquery(self): + """Composition: anonymous subquery with projection-EXPLODE whose + struct-unpacked columns are referenced from the outer projection.""" + result = validate_node_query( + "SELECT CAST(test_id AS BIGINT) AS 
test_id, cell_id, cell_name " + "FROM ( SELECT test_id, EXPLODE(cells) AS (cell_id, cell_name) " + " FROM t.src )", + _col_map( + ( + "t.src", + [ + ("test_id", IntegerType()), + ("cells", self._cells_type()), + ], + ), + ), + ) + assert not result.errors, result.errors + assert [n for n, _ in result.output_columns] == [ + "test_id", + "cell_id", + "cell_name", + ] + # CAST gives bigint; struct-unpacked fields keep their string type + assert result.output_columns[0] == ("test_id", BigIntType()) + assert result.output_columns[1] == ("cell_id", StringType()) + assert result.output_columns[2] == ("cell_name", StringType()) + + +class TestMultipleAnonymousLateralViews: + def test_two_anonymous_lateral_views_do_not_collide(self): + """Two LATERAL VIEW EXPLODE(sequence(...)) in the same SELECT each + default to __lateral__; the new __lateral_{idx}__ fallback keeps both + sets of exploded columns reachable from the outer projection.""" + result = validate_node_query( + "SELECT window_start, window_end FROM (SELECT 1 AS d) t " + "LATERAL VIEW EXPLODE(SEQUENCE(1, 97)) AS window_start " + "LATERAL VIEW EXPLODE(SEQUENCE(2, 98)) AS window_end", + {}, + ) + assert not result.errors, result.errors + assert result.output_columns == [ + ("window_start", IntegerType()), + ("window_end", IntegerType()), + ] + + +class TestFromlessLateralView: + def test_fromless_query_with_lateral_view_resolves_exploded_columns(self): + """Query with LATERAL VIEW but no FROM clause — used as a series + generator in Spark — should still be processed, not routed to the + __derived__ scope.""" + result = validate_node_query( + "SELECT CAST(CONCAT(window_start, '-', window_end) AS string) AS obs_window, " + "window_start AS obs_window_start, " + "window_end AS obs_window_end " + "LATERAL VIEW EXPLODE(SEQUENCE(1, 97)) AS window_start " + "LATERAL VIEW EXPLODE(SEQUENCE(2, 98)) AS window_end " + "WHERE window_start < window_end AND window_start = 1", + {}, + ) + assert not result.errors, result.errors + 
assert result.output_columns == [ + ("obs_window", StringType()), + ("obs_window_start", IntegerType()), + ("obs_window_end", IntegerType()), + ] + + def test_fromless_query_without_lateral_view_stays_derived_metric(self): + """Regression: a true derived metric (no FROM, no lateral views) must + still go through the __derived__ scope so bare metric-name refs + resolve against parent_map.""" + result = validate_node_query( + "SELECT default.metric_a / default.metric_b AS ratio", + _col_map( + ("default.metric_a", [("metric_a", DoubleType())]), + ("default.metric_b", [("metric_b", DoubleType())]), + ), + ) + assert not result.errors, result.errors + assert [n for n, _ in result.output_columns] == ["ratio"] + + def test_inline_table_with_column_aliases_uses_explicit_alias(self): + """CROSS JOIN VALUES (...) tab(c1, c2) — the `tab` alias lands on + InlineTable.name rather than InlineTable.alias. v2 must honor it so + outer refs like `tab.c2` resolve instead of hitting the `__inline__` + fallback.""" + result = validate_node_query( + "SELECT tab.window_end AS tenure FROM source.s a " + "CROSS JOIN VALUES ('1-7', 1, 7), ('1-14', 1, 14) " + "tab(label, window_start, window_end)", + _col_map(("source.s", [("id", IntegerType())])), + ) + assert not result.errors, result.errors + assert result.output_columns == [("tenure", IntegerType())] + + def test_inline_table_scans_past_null_for_column_type(self): + """VALUES(NULL, 'a'), ('x', 'b') — first-row NULL for col0 must not + poison the column type. 
Scan rows until a typed literal is found.""" + result = validate_node_query( + "SELECT app_start_type, label FROM VALUES " + "(NULL, 'From Nflx'), ('COLD', 'COLD'), ('WARM', 'WARM') " + "AS t (app_start_type, label)", + {}, + ) + assert not result.errors, result.errors + assert result.output_columns == [ + ("app_start_type", StringType()), + ("label", StringType()), + ] + + def test_inline_table_all_null_column_is_null_type(self): + """When every row has NULL for a given column, the type is NullType + (not UnknownType — genuinely nullable-only).""" + result = validate_node_query( + "SELECT x, y FROM VALUES (NULL, 1), (NULL, 2) AS t(x, y)", + {}, + ) + assert not result.errors, result.errors + assert result.output_columns == [ + ("x", NullType()), + ("y", IntegerType()), + ] + + def test_lateral_view_explode_from_json_map_resolves_key_value(self): + """LATERAL VIEW EXPLODE(from_json(string_col, 'MAP')) — the + function dispatch for from_json returns a MapType, EXPLODE of a map gives + (key, value) pairs, and the AS list aliases them. All columns resolve + cleanly.""" + result = validate_node_query( + "SELECT ntl.test_id, ntl.evidence_map FROM source.t F " + "LATERAL VIEW EXPLODE(from_json(xpEvidenceMap, 'MAP')) " + "ntl AS test_id, evidence_map", + _col_map(("source.t", [("xpEvidenceMap", StringType())])), + ) + assert not result.errors, result.errors + assert result.output_columns == [ + ("test_id", StringType()), + ("evidence_map", StringType()), + ] + + def test_cross_join_unnest_resolves_element_type_from_sibling_table(self): + """`FROM t CROSS JOIN UNNEST(t.arr) AS u(x)` — UNNEST is the right side + of a JOIN and references a column on the left-side table. 
The + resolution must see the left table in scope (outer-scope propagation), + and the element type of the array must flow into u.x.""" + from datajunction_server.sql.parsing.types import ListType + + result = validate_node_query( + "SELECT id, x FROM t.src CROSS JOIN UNNEST(vals) AS u(x)", + _col_map( + ( + "t.src", + [ + ("id", IntegerType()), + ("vals", ListType(element_type=IntegerType())), + ], + ), + ), + ) + assert not result.errors, result.errors + assert result.output_columns == [ + ("id", IntegerType()), + ("x", IntegerType()), + ] + + def test_cross_join_unnest_map_subscript_resolves_to_element_type(self): + """UNNEST(map_col['key']) where map_col: map>. + The subscript returns array, UNNEST returns rows of int.""" + result = validate_node_query( + "SELECT AVG(x) AS avg_x FROM t.src " + "CROSS JOIN UNNEST(m['home']) AS u(x) " + "GROUP BY id", + _col_map( + ( + "t.src", + [ + ("id", IntegerType()), + # map> + ( + "m", + __import__( + "datajunction_server.sql.parsing.backends.antlr4", + fromlist=["parse_rule"], + ).parse_rule("map>", "dataType"), + ), + ], + ), + ), + ) + assert not result.errors, result.errors + assert result.output_columns[0][0] == "avg_x" + # AVG of int is double + from datajunction_server.sql.parsing.types import DoubleType + + assert isinstance(result.output_columns[0][1], DoubleType) + + def test_lateral_view_explode_missing_source_column_surfaces_real_error(self): + """When EXPLODE's argument references a nonexistent column, the + 'Column X not found' error should be surfaced alongside (and before + the user sees) the downstream 'Unable to infer type' noise.""" + result = validate_node_query( + "SELECT ntl.test_id FROM source.t F " + "LATERAL VIEW EXPLODE(from_json(xpEvidenceMap, 'MAP')) " + "ntl AS test_id, evidence_map", + _col_map(("source.t", [("xpEvidence", StringType())])), # renamed col + ) + assert any( + "Column `xpEvidenceMap` not found in any table" in e for e in result.errors + ), result.errors + + def 
test_get_json_object_on_exploded_value_resolves_to_string(self): + """get_json_object(map_value_col, '$.path.x') should resolve cleanly to + string when applied to a string (the exploded map value).""" + result = validate_node_query( + "SELECT get_json_object(ntl.evidence_map, '$.67291.alloc_cell') AS cell " + "FROM source.t F " + "LATERAL VIEW EXPLODE(from_json(xpEvidenceMap, 'MAP')) " + "ntl AS test_id, evidence_map", + _col_map(("source.t", [("xpEvidenceMap", StringType())])), + ) + assert not result.errors, result.errors + assert result.output_columns == [("cell", StringType())] + + def test_subscript_on_list_type_returns_element_type(self): + """arr[1] where arr: array should resolve to string. Previously + the Subscript handler only covered map and struct, falling through to + UnknownType for arrays.""" + from datajunction_server.sql.parsing.types import ListType + + result = validate_node_query( + "SELECT arr[1] AS first FROM t.src", + _col_map(("t.src", [("arr", ListType(element_type=StringType()))])), + ) + assert not result.errors, result.errors + assert result.output_columns == [("first", StringType())] + + def test_derived_metric_dim_attribute_window_ref_is_permissive(self): + """Derived metrics often reference dim attributes inside window + functions (`OVER (ORDER BY common.dimensions.time.date.dateint)`). + These aren't parents — they resolve via dim links at query build time. 
+ v2 should type them as UnknownType and not reject.""" + result = validate_node_query( + "SELECT AVG(demo.metrics.main.avg_dl) OVER (" + " ORDER BY common.dimensions.time.date.dateint " + " ROWS BETWEEN 6 PRECEDING AND CURRENT ROW" + ") AS trailing_avg", + _col_map( + ("demo.metrics.main.avg_dl", [("avg_dl", DoubleType())]), + ), + ) + assert not result.errors, result.errors + assert [n for n, _ in result.output_columns] == ["trailing_avg"] + + +class TestUnresolvedNamespaceDiagnostic: + def test_bogus_namespace_emits_specific_error(self): + """A namespace that matches no table alias, struct column, or known + parent should surface a descriptive error rather than only the generic + 'Unable to infer type'.""" + result = validate_node_query( + "SELECT a.id, tenure.tenure FROM (SELECT 1 AS id) AS a", + {}, + ) + # Must produce the namespace-specific message (not only the generic + # Unable-to-infer-type that used to be the only signal). + assert any( + "namespace `tenure`" in msg and "not a table alias" in msg + for msg in result.errors + ), result.errors + + def test_known_parent_prefix_does_not_trigger_namespace_error(self): + """When some prefix of the column path matches a known parent in + parent_map, the namespace-error heuristic should stay quiet — the + reference may be a legit dim-link that resolves at query build time.""" + result = validate_node_query( + "SELECT src.orders.country.name FROM src.orders o", + _col_map(("src.orders", [("id", IntegerType())])), + ) + # No namespace-specific error — src.orders IS a known parent. + assert not any( + "not a table alias in this scope" in msg for msg in result.errors + ), result.errors + + def test_multi_segment_dim_attribute_ref_does_not_trigger_namespace_error(self): + """Long dim-attribute paths like + `common.dimensions.xp.allocation_day.days_since_allocation` inside a + metric's CASE expression aren't parents, and their prefixes don't + match parent_map either. But they're real DJ dim-node paths, not + typos. 
Only single-segment namespaces (`tenure.tenure`-style typos) + should be flagged.""" + result = validate_node_query( + "SELECT COUNT(DISTINCT CASE WHEN view_secs >= 360 " + "THEN common.dimensions.xp.allocation_day.days_since_allocation " + "ELSE NULL END) AS m " + "FROM users.foo.playback", + _col_map(("users.foo.playback", [("view_secs", IntegerType())])), + ) + assert not any("references namespace" in msg for msg in result.errors), ( + result.errors + ) diff --git a/datajunction-server/tests/internal/node_validation_test.py b/datajunction-server/tests/internal/node_validation_test.py index 64f65773e..1f18244c3 100644 --- a/datajunction-server/tests/internal/node_validation_test.py +++ b/datajunction-server/tests/internal/node_validation_test.py @@ -748,3 +748,114 @@ def test_already_parsed_type_is_skipped(self): def test_empty_map_is_noop(self): """An empty dependencies_map doesn't raise.""" _reparse_parent_column_types({}) + + +# --------------------------------------------------------------------------- +# validate_node_data_v2 smoke tests +# --------------------------------------------------------------------------- + + +@pytest_asyncio.fixture +async def v2_parent_source(session: AsyncSession, user: User) -> Node: + """Source node with typed columns for use as a parent in v2 tests.""" + node = Node( + name="test.v2_parent", + type=NodeType.SOURCE, + created_by_id=user.id, + current_version="v1.0", + ) + revision = NodeRevision( + name="test.v2_parent", + display_name="v2 parent", + type=NodeType.SOURCE, + query=None, + status=NodeStatus.VALID, + version="v1.0", + node=node, + columns=[ + Column(name="id", type=ct.BigIntType(), order=0), + Column(name="is_winning_bid", type=ct.BooleanType(), order=1), + ], + created_by_id=user.id, + ) + session.add(node) + session.add(revision) + await session.commit() + return node + + +@pytest.mark.asyncio +async def test_validate_node_data_v2_returns_valid_for_simple_transform( + session: AsyncSession, + user: User, + 
v2_parent_source: Node, +): + """Happy path: simple transform selecting a typed column resolves cleanly.""" + from datajunction_server.internal.validation import validate_node_data_v2 + + data = NodeRevisionBase( + name="test.v2_child_valid", + display_name="v2 child", + type=NodeType.TRANSFORM, + query="SELECT id FROM test.v2_parent", + mode="published", + ) + validator = await validate_node_data_v2(data, session) + + assert validator.status == NodeStatus.VALID, validator.errors + assert [c.name for c in validator.columns] == ["id"] + assert not validator.missing_parents_map + + +@pytest.mark.asyncio +async def test_validate_node_data_v2_flags_missing_parent( + session: AsyncSession, + user: User, +): + """A query referencing a non-existent parent surfaces the missing parent + via both `missing_parents_map` and an error on the validator.""" + from datajunction_server.errors import ErrorCode + from datajunction_server.internal.validation import validate_node_data_v2 + + data = NodeRevisionBase( + name="test.v2_missing_parent_child", + display_name="v2 missing parent child", + type=NodeType.TRANSFORM, + query="SELECT id FROM test.does_not_exist", + mode="published", + ) + validator = await validate_node_data_v2(data, session) + + assert validator.status == NodeStatus.INVALID + assert "test.does_not_exist" in validator.missing_parents_map + assert any(err.code == ErrorCode.MISSING_PARENT for err in validator.errors), [ + (e.code, e.message) for e in validator.errors + ] + + +@pytest.mark.asyncio +async def test_validate_node_data_v2_flags_sum_boolean( + session: AsyncSession, + user: User, + v2_parent_source: Node, +): + """SUM(boolean) is not a valid Spark aggregation — v2 should surface a + TYPE_INFERENCE error. 
Locks in the un-suppression of + 'Unable to infer type' error strings at the single-node boundary.""" + from datajunction_server.errors import ErrorCode + from datajunction_server.internal.validation import validate_node_data_v2 + + data = NodeRevisionBase( + name="test.v2_sum_boolean", + display_name="sum boolean metric", + type=NodeType.METRIC, + query="SELECT SUM(is_winning_bid) FROM test.v2_parent", + mode="published", + ) + validator = await validate_node_data_v2(data, session) + + assert validator.status == NodeStatus.INVALID + assert any( + err.code == ErrorCode.TYPE_INFERENCE and "Unable to infer type" in err.message + for err in validator.errors + ), [(e.code, e.message) for e in validator.errors] From aa8aea46026924e81067800ee81a673812b4f5ae Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Tue, 21 Apr 2026 02:22:36 -0700 Subject: [PATCH 2/5] Fix invalid sql in fixtures --- datajunction-clients/python/tests/examples.py | 2 +- datajunction-server/tests/api/nodes_test.py | 8 ++++++-- datajunction-server/tests/examples.py | 8 ++++---- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/datajunction-clients/python/tests/examples.py b/datajunction-clients/python/tests/examples.py index 149ba90d9..682e59221 100644 --- a/datajunction-clients/python/tests/examples.py +++ b/datajunction-clients/python/tests/examples.py @@ -372,7 +372,7 @@ FROM default.hard_hats hh LEFT JOIN default.hard_hat_state hhs ON hh.hard_hat_id = hhs.hard_hat_id - WHERE hh.state_id = 'NY' + WHERE hhs.state_id = 'NY' """, "mode": "published", "name": "default.local_hard_hats", diff --git a/datajunction-server/tests/api/nodes_test.py b/datajunction-server/tests/api/nodes_test.py index 011489f19..ac1414010 100644 --- a/datajunction-server/tests/api/nodes_test.py +++ b/datajunction-server/tests/api/nodes_test.py @@ -4899,11 +4899,15 @@ async def test_revalidating_existing_nodes(self, client_with_roads: AsyncClient) ) for node in (await client_with_roads.get("/nodes/")).json(): if 
node.startswith("default."): - status = ( + response = ( await client_with_roads.post( f"/nodes/{node}/validate/", ) - ).json()["status"] + ).json() + status = response["status"] + print("node", node, "status", response) + assert status == "valid" + print("node", node, "status", status) assert status == "valid" # Confirm that they still show as valid server-side for node in (await client_with_roads.get("/nodes/")).json(): diff --git a/datajunction-server/tests/examples.py b/datajunction-server/tests/examples.py index c519061ad..ed70bb84c 100644 --- a/datajunction-server/tests/examples.py +++ b/datajunction-server/tests/examples.py @@ -446,7 +446,7 @@ FROM default.hard_hats hh LEFT JOIN default.hard_hat_state hhs ON hh.hard_hat_id = hhs.hard_hat_id - WHERE hh.state_id = 'NY' + WHERE hhs.state_id = 'NY' """, "mode": "published", "name": "default.local_hard_hats", @@ -1224,7 +1224,7 @@ FROM foo.bar.hard_hats hh LEFT JOIN foo.bar.hard_hat_state hhs ON hh.hard_hat_id = hhs.hard_hat_id - WHERE hh.state_id = 'NY' + WHERE hhs.state_id = 'NY' """, "mode": "published", "name": "foo.bar.local_hard_hats", @@ -1544,7 +1544,7 @@ "query": ( "SELECT payment_id, payment_amount, customer_id, account_type " "FROM default.revenue WHERE " - "large_revenue_payments_and_business_only > 1000000 " + "payment_amount > 1000000 " "AND account_type='BUSINESS'" ), "description": "Only large revenue payments from business accounts", @@ -1558,7 +1558,7 @@ "query": ( "SELECT payment_id, payment_amount, customer_id, account_type " "FROM default.revenue WHERE " - "large_revenue_payments_and_business_only > 1000000 " + "payment_amount > 1000000 " "AND account_type='BUSINESS'" ), "description": "Only large revenue payments from business accounts 1", From 688ca1f545ec3dbd5a34025b3ad83f4892a05c39 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Tue, 21 Apr 2026 03:42:00 -0700 Subject: [PATCH 3/5] Fix --- .../internal/validation.py | 80 ++++----- .../deployment/test_type_inference.py | 96 +++++++++++ 
.../tests/internal/node_validation_test.py | 158 ++++++++++++++++++ 3 files changed, 284 insertions(+), 50 deletions(-) diff --git a/datajunction-server/datajunction_server/internal/validation.py b/datajunction-server/datajunction_server/internal/validation.py index 76d509951..c2ad0ef7a 100644 --- a/datajunction-server/datajunction_server/internal/validation.py +++ b/datajunction-server/datajunction_server/internal/validation.py @@ -447,7 +447,7 @@ def _build_columns_from_output( output_columns: list, query_ast: ast.Query, validated_node: NodeRevision, -) -> tuple[list[Column], dict[str, str]]: +) -> list[Column]: """Build Column objects for every AST projection item. validate_node_query drops columns whose type resolution raised (e.g., unresolved @@ -459,15 +459,10 @@ def _build_columns_from_output( """ from datajunction_server.sql.parsing.types import UnknownType - try: - column_mapping = {col.name: col for col in validated_node.columns} - except MissingGreenlet: # pragma: no cover - column_mapping = {} # pragma: no cover - + column_mapping = {col.name: col for col in validated_node.columns} types_by_name = {name: col_type for name, col_type in output_columns} columns: list[Column] = [] - type_inference_failures: dict[str, str] = {} for idx, expr in enumerate(query_ast.select.projection): # type: ignore[attr-defined] col_name = expr.alias_or_name.name # type: ignore[union-attr] col_type = types_by_name.get(col_name, UnknownType()) @@ -494,7 +489,7 @@ def _build_columns_from_output( order=idx, ), ) - return columns, type_inference_failures + return columns @timed( @@ -596,14 +591,11 @@ async def validate_node_data_v2( node_validator.status = NodeStatus.INVALID # --- Step 8: build columns from output_columns --- - columns, type_inference_failures = _build_columns_from_output( + node_validator.columns = _build_columns_from_output( validation.output_columns, query_ast, validated_node, ) - node_validator.columns = columns - if type_inference_failures: - 
node_validator.status = NodeStatus.INVALID # --- Step 9: metric-specific checks (cross-fact shared-dim + MISSING_COLUMNS) --- if is_metric and node_validator.dependencies_map: @@ -691,37 +683,33 @@ async def validate_node_data_v2( ) node_validator.status = NodeStatus.INVALID - # --- Step 11: required dimensions --- - try: - parent_columns = [ - col - for parent in node_validator.dependencies_map.keys() - for col in parent.columns - ] - required_dim_strings = [ - col.full_name() if isinstance(col, Column) else col - for col in validated_node.required_dimensions - ] - ( - invalid_required_dimensions, - matched_bound_columns, - ) = await find_required_dimensions( - session, - required_dim_strings, - parent_columns, - ) - node_validator.required_dimensions = matched_bound_columns - except MissingGreenlet: # pragma: no cover - invalid_required_dimensions = set() - node_validator.required_dimensions = [] + # --- Step 11: required dimensions. Parents come from Node.get_by_names's + # default_load_options which eagerly loads NodeRevision.columns, so the + # `parent.columns` access here doesn't trigger a lazy load. Let any real + # MissingGreenlet propagate rather than silently swallowing it — that + # would hide a genuine eager-load regression. 
+ parent_columns = [ + col + for parent in node_validator.dependencies_map.keys() + for col in parent.columns + ] + required_dim_strings = [ + col.full_name() if isinstance(col, Column) else col + for col in validated_node.required_dimensions + ] + ( + invalid_required_dimensions, + matched_bound_columns, + ) = await find_required_dimensions( + session, + required_dim_strings, + parent_columns, + ) + node_validator.required_dimensions = matched_bound_columns - # --- Step 12: final error assembly for missing parents, type-inference, and - # invalid required dims (matches legacy code shapes) --- - if ( - node_validator.missing_parents_map - or type_inference_failures - or invalid_required_dimensions - ): + # --- Step 12: final error assembly for missing parents + invalid required + # dims (matches legacy code shapes). + if node_validator.missing_parents_map or invalid_required_dimensions: node_validator.status = NodeStatus.INVALID if node_validator.missing_parents_map: node_validator.errors.append( @@ -738,14 +726,6 @@ async def validate_node_data_v2( }, ), ) - for column, message in type_inference_failures.items(): - node_validator.errors.append( - DJError( - code=ErrorCode.TYPE_INFERENCE, - message=message, - debug={"columns": [column]}, - ), - ) if invalid_required_dimensions: node_validator.errors.append( DJError( diff --git a/datajunction-server/tests/internal/deployment/test_type_inference.py b/datajunction-server/tests/internal/deployment/test_type_inference.py index c612ca679..b2d752dea 100644 --- a/datajunction-server/tests/internal/deployment/test_type_inference.py +++ b/datajunction-server/tests/internal/deployment/test_type_inference.py @@ -2417,3 +2417,99 @@ def test_multi_segment_dim_attribute_ref_does_not_trigger_namespace_error(self): assert not any("references namespace" in msg for msg in result.errors), ( result.errors ) + + +class TestCoverageGaps: + """Tests targeting specific uncovered branches in type_inference.py.""" + + def 
test_inline_table_explicit_alias_on_node_alias(self): + """VALUES (…) AS t(...) form can park the alias on node.alias + rather than node.name. Exercises the `node.alias is not None` + branch of the InlineTable handler (line 363).""" + result = validate_node_query( + "SELECT t.x FROM (VALUES (1), (2)) AS t(x)", + {}, + ) + assert not result.errors, result.errors + assert result.output_columns == [("x", IntegerType())] + + def test_cross_join_unnest_struct_array_unpacks_fields_in_from(self): + """Struct-unpacking in the FROM-clause FunctionTableExpression branch + (line 397): `CROSS JOIN UNNEST(array>) AS t(c1, c2)` + expands positionally to c1: a_type, c2: b_type.""" + from datajunction_server.sql.parsing.backends.antlr4 import parse_rule + + cells = parse_rule( + "array>", + "dataType", + ) + result = validate_node_query( + "SELECT t.cell_id, t.cell_name FROM src.s " + "CROSS JOIN UNNEST(cells) AS t(cell_id, cell_name)", + _col_map(("src.s", [("cells", cells)])), + ) + assert not result.errors, result.errors + assert result.output_columns == [ + ("cell_id", StringType()), + ("cell_name", StringType()), + ] + + def test_lateral_view_explode_with_too_many_aliases_fills_unknown(self): + """LATERAL VIEW EXPLODE(scalar_array) AS a, b, c — 3 aliases but + element is a single scalar. 
The extras (line 623) fall back to + UnknownType.""" + from datajunction_server.sql.parsing.types import ListType + + result = validate_node_query( + "SELECT v.a, v.b, v.c FROM src.s LATERAL VIEW EXPLODE(nums) v AS a, b, c", + _col_map( + ("src.s", [("nums", ListType(element_type=IntegerType()))]), + ), + ) + assert result.output_columns[0] == ("a", IntegerType()) + assert isinstance(result.output_columns[1][1], UnknownType) + assert isinstance(result.output_columns[2][1], UnknownType) + + def test_projection_explode_scalar_array_with_extra_aliases_unknown(self): + """Projection `EXPLODE(arr) AS (a, b, c)` on a scalar-array column — + only one element type, extras (line 623) fall back to UnknownType.""" + from datajunction_server.sql.parsing.types import ListType + + result = validate_node_query( + "SELECT EXPLODE(nums) AS (a, b, c) FROM src.s", + _col_map( + ("src.s", [("nums", ListType(element_type=IntegerType()))]), + ), + ) + assert result.output_columns[0] == ("a", IntegerType()) + assert isinstance(result.output_columns[1][1], UnknownType) + assert isinstance(result.output_columns[2][1], UnknownType) + + def test_lateral_element_types_propagates_typeresolution_error(self): + """When EXPLODE's argument references a nonexistent column, the + TypeResolutionError message flows into `errors` (lines 492→494) + so the real root cause surfaces, not just 'Unable to infer type'.""" + result = validate_node_query( + "SELECT v.x FROM src.s LATERAL VIEW EXPLODE(missing_col) v AS x", + _col_map(("src.s", [("id", IntegerType())])), + ) + assert any( + "Column `missing_col` not found in any table" in msg + for msg in result.errors + ), result.errors + + def test_single_segment_namespace_matching_parent_is_silent(self): + """Single-segment namespace whose name matches a parent_map key (not + a FROM table) should NOT trigger the 'references namespace' error — + the `if not any_prefix_is_parent` guard (line 900) short-circuits and + the fallback just returns UnknownType (line 
906).""" + result = validate_node_query( + "SELECT src.foo_missing FROM other.t", + _col_map( + ("other.t", [("id", IntegerType())]), + ("src", [("real_col", IntegerType())]), + ), + ) + assert not any("references namespace" in msg for msg in result.errors), ( + result.errors + ) diff --git a/datajunction-server/tests/internal/node_validation_test.py b/datajunction-server/tests/internal/node_validation_test.py index 1f18244c3..b22cdf47a 100644 --- a/datajunction-server/tests/internal/node_validation_test.py +++ b/datajunction-server/tests/internal/node_validation_test.py @@ -859,3 +859,161 @@ async def test_validate_node_data_v2_flags_sum_boolean( err.code == ErrorCode.TYPE_INFERENCE and "Unable to infer type" in err.message for err in validator.errors ), [(e.code, e.message) for e in validator.errors] + + +@pytest.mark.asyncio +async def test_validate_node_data_v2_surfaces_parse_errors( + session: AsyncSession, + user: User, +): + """Malformed SQL returns INVALID_SQL_QUERY and doesn't crash validation.""" + from datajunction_server.errors import ErrorCode + from datajunction_server.internal.validation import validate_node_data_v2 + + data = NodeRevisionBase( + name="test.v2_bad_sql", + display_name="bad sql", + type=NodeType.TRANSFORM, + query="SELECT ))", # guaranteed parser failure + mode="published", + ) + validator = await validate_node_data_v2(data, session) + assert validator.status == NodeStatus.INVALID + assert any(err.code == ErrorCode.INVALID_SQL_QUERY for err in validator.errors), [ + (e.code, e.message) for e in validator.errors + ] + + +@pytest.mark.asyncio +async def test_validate_node_data_v2_no_candidates_skips_db_load( + session: AsyncSession, + user: User, +): + """A derived metric that references no namespaced candidates (pure literal + arithmetic) should skip the bulk Node.get_by_names load entirely.""" + from datajunction_server.internal.validation import validate_node_data_v2 + + data = NodeRevisionBase( + name="test.v2_literal_metric", + 
display_name="literal metric", + type=NodeType.METRIC, + query="SELECT 1 + 1", + mode="published", + ) + validator = await validate_node_data_v2(data, session) + # Status may be valid or invalid depending on metric-query constraints; + # the important coverage is that validation completed without a DB load. + assert validator.missing_parents_map == {} + + +@pytest.mark.asyncio +async def test_validate_node_data_v2_skips_parents_without_columns( + session: AsyncSession, + user: User, +): + """Parent nodes whose current revision has no columns are skipped when + building parent_columns_map (line 573→572).""" + from datajunction_server.internal.validation import validate_node_data_v2 + + parent = Node( + name="test.v2_empty_parent", + type=NodeType.SOURCE, + created_by_id=user.id, + current_version="v1.0", + ) + revision = NodeRevision( + name="test.v2_empty_parent", + display_name="empty parent", + type=NodeType.SOURCE, + query=None, + status=NodeStatus.VALID, + version="v1.0", + node=parent, + columns=[], # intentionally empty + created_by_id=user.id, + ) + session.add(parent) + session.add(revision) + await session.commit() + + data = NodeRevisionBase( + name="test.v2_empty_parent_child", + display_name="child of empty parent", + type=NodeType.TRANSFORM, + query="SELECT * FROM test.v2_empty_parent", + mode="published", + ) + validator = await validate_node_data_v2(data, session) + # Validation should not crash; the empty-columns parent is just skipped. 
+ assert validator.status in (NodeStatus.VALID, NodeStatus.INVALID) + + +@pytest.mark.asyncio +async def test_validate_node_data_v2_flags_invalid_required_dimensions( + session: AsyncSession, + user: User, +): + """required_dimensions pointing at a column that doesn't exist on the + referenced dim node surfaces an INVALID_COLUMN error.""" + from datajunction_server.errors import ErrorCode + from datajunction_server.internal.validation import validate_node_data_v2 + + source = Node( + name="test.v2_req_dim_parent", + type=NodeType.SOURCE, + created_by_id=user.id, + current_version="v1.0", + ) + source_rev = NodeRevision( + name="test.v2_req_dim_parent", + display_name="req dim parent", + type=NodeType.SOURCE, + query=None, + status=NodeStatus.VALID, + version="v1.0", + node=source, + columns=[Column(name="id", type=ct.BigIntType(), order=0)], + created_by_id=user.id, + ) + dim = Node( + name="test.v2_dim_tiny", + type=NodeType.DIMENSION, + created_by_id=user.id, + current_version="v1.0", + ) + dim_rev = NodeRevision( + name="test.v2_dim_tiny", + display_name="tiny dim", + type=NodeType.DIMENSION, + query="SELECT 1 AS id", + status=NodeStatus.VALID, + version="v1.0", + node=dim, + columns=[Column(name="id", type=ct.BigIntType(), order=0)], + created_by_id=user.id, + ) + session.add_all([source, source_rev, dim, dim_rev]) + await session.commit() + + # Construct a transient NodeRevision carrying the string-form required + # dimension (NodeRevisionBase doesn't have the field — it's on + # MetricNodeFields via CreateNode). v2 resolves these strings against + # the DB and flags any that don't match a real column. 
+ child = NodeRevision( + name="test.v2_req_dim_child", + display_name="req dim child", + type=NodeType.METRIC, + query="SELECT COUNT(id) FROM test.v2_req_dim_parent", + status=NodeStatus.VALID, + required_dimensions=["test.v2_dim_tiny.ghost_col"], # type: ignore[list-item] + ) + validator = await validate_node_data_v2(child, session) + assert validator.status == NodeStatus.INVALID + assert any( + err.code == ErrorCode.INVALID_COLUMN and "required dimensions" in err.message + for err in validator.errors + ), [(e.code, e.message) for e in validator.errors] + assert any( + err.code == ErrorCode.INVALID_COLUMN and "required dimensions" in err.message + for err in validator.errors + ), [(e.code, e.message) for e in validator.errors] From dc827d9f02cf6baf5cb5e1f178656015aedeb986 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Tue, 21 Apr 2026 04:46:02 -0700 Subject: [PATCH 4/5] Fix tests --- .../internal/deployment/type_inference.py | 28 +++-- .../tests/api/nodes_update_test.py | 8 +- .../deployment/test_type_inference.py | 103 +++++++++++++++--- .../tests/internal/node_validation_test.py | 2 +- 4 files changed, 110 insertions(+), 31 deletions(-) diff --git a/datajunction-server/datajunction_server/internal/deployment/type_inference.py b/datajunction-server/datajunction_server/internal/deployment/type_inference.py index ebc6c4af4..3f02b9b96 100644 --- a/datajunction-server/datajunction_server/internal/deployment/type_inference.py +++ b/datajunction-server/datajunction_server/internal/deployment/type_inference.py @@ -357,14 +357,13 @@ def _collect_tables_from_relation( errors.extend(sub_errors) elif isinstance(node, ast.InlineTable): - # `VALUES (…) tab(c1, c2)` stores the `tab` alias on node.name, not - # node.alias. Accept both so the outer query can reference `tab.c1`. 
- if node.alias is not None: - alias = node.alias.name - elif node.name is not None and node.name.name: - alias = node.name.name - else: - alias = "__inline__" # pragma: no cover + # DJ's parser reaches this branch only for the non-parenthesized + # `VALUES (…) tab(c1, c2)` form, where the `tab` alias is parked on + # node.name. Parenthesized `(VALUES (…)) AS tab(c1, c2)` comes in as + # an ast.Query wrapping an InlineTable and is handled above. If that + # invariant ever changes, the name-access below raises rather than + # silently mis-aliasing — easier to diagnose than a defensive default. + alias = node.name.name inline_columns = _resolve_inline_table(ast.Query(select=node)) # type: ignore[arg-type] result[alias] = {name: typ for name, typ in inline_columns} @@ -613,10 +612,15 @@ def _resolve_projection_function_table( output: OutputColumns = [] for i, out_name in enumerate(col_names): - if is_posexplode and i == 0: - output.append((out_name, IntegerType())) - elif is_posexplode and i == 1 and element_types: - output.append((out_name, element_types[0])) + # Projection-form POSEXPLODE with a parenthesized alias list + # (`POSEXPLODE(arr) AS (pos, val)`) isn't accepted by DJ's grammar — + # the parser rejects the identifier-list alias on a non-Table + # expression. These two branches are kept for symmetry with the + # LATERAL VIEW form but aren't reachable in practice. 
+ if is_posexplode and i == 0: # pragma: no cover + output.append((out_name, IntegerType())) # pragma: no cover + elif is_posexplode and i == 1 and element_types: # pragma: no cover + output.append((out_name, element_types[0])) # pragma: no cover elif not is_posexplode and i < len(element_types): output.append((out_name, element_types[i])) else: diff --git a/datajunction-server/tests/api/nodes_update_test.py b/datajunction-server/tests/api/nodes_update_test.py index 42b223591..44d721353 100644 --- a/datajunction-server/tests/api/nodes_update_test.py +++ b/datajunction-server/tests/api/nodes_update_test.py @@ -62,7 +62,7 @@ async def test_update_source_node( "entity_type": "node", "id": mock.ANY, "node": "default.national_level_agg", - "post": {"status": "invalid", "version": "v1.0"}, + "post": {"status": "invalid", "version": "v2.0"}, "pre": {"status": "valid", "version": "v1.0"}, "user": mock.ANY, }, @@ -85,7 +85,7 @@ async def test_update_source_node( "entity_type": "node", "id": mock.ANY, "node": "default.regional_level_agg", - "post": {"status": "invalid", "version": "v1.0"}, + "post": {"status": "invalid", "version": "v2.0"}, "pre": {"status": "valid", "version": "v1.0"}, "user": mock.ANY, }, @@ -107,7 +107,7 @@ async def test_update_source_node( "entity_type": "node", "id": mock.ANY, "node": "default.avg_repair_price", - "post": {"status": "invalid", "version": "v1.0"}, + "post": {"status": "invalid", "version": "v2.0"}, "pre": {"status": "valid", "version": "v1.0"}, "user": mock.ANY, }, @@ -131,7 +131,7 @@ async def test_update_source_node( "entity_type": "node", "id": mock.ANY, "node": "default.regional_repair_efficiency", - "post": {"status": "invalid", "version": "v1.0"}, + "post": {"status": "invalid", "version": "v2.0"}, "pre": {"status": "valid", "version": "v1.0"}, "user": mock.ANY, }, diff --git a/datajunction-server/tests/internal/deployment/test_type_inference.py b/datajunction-server/tests/internal/deployment/test_type_inference.py index 
b2d752dea..268cb0ba5 100644 --- a/datajunction-server/tests/internal/deployment/test_type_inference.py +++ b/datajunction-server/tests/internal/deployment/test_type_inference.py @@ -5,6 +5,8 @@ and returns the output column names + types without any DB calls. """ +import pytest + from datajunction_server.internal.deployment.type_inference import ( columns_signature_changed, validate_node_query, @@ -2422,10 +2424,12 @@ def test_multi_segment_dim_attribute_ref_does_not_trigger_namespace_error(self): class TestCoverageGaps: """Tests targeting specific uncovered branches in type_inference.py.""" - def test_inline_table_explicit_alias_on_node_alias(self): - """VALUES (…) AS t(...) form can park the alias on node.alias - rather than node.name. Exercises the `node.alias is not None` - branch of the InlineTable handler (line 363).""" + def test_inline_table_parens_with_alias_list_resolves_columns(self): + """(VALUES (1), (2)) AS t(x) — parenthesized-VALUES + alias list + currently parses with both node.alias and node.name None, so the + handler falls into the __inline__ default-alias branch. The test + still covers the end-to-end "outer query sees the aliased column" + behavior for this common shape.""" result = validate_node_query( "SELECT t.x FROM (VALUES (1), (2)) AS t(x)", {}, @@ -2433,10 +2437,82 @@ def test_inline_table_explicit_alias_on_node_alias(self): assert not result.errors, result.errors assert result.output_columns == [("x", IntegerType())] + @pytest.mark.xfail( + reason=( + "Spark accepts `SELECT POSEXPLODE(arr) AS (pos, val) FROM t` as " + "an inline generator function, but DJ's grammar rejects the " + "identifier-list alias on a non-Table expression. When the " + "grammar is relaxed for known table-valued functions, the " + "projection-form POSEXPLODE branches in " + "_resolve_projection_function_table become reachable and this " + "test flips to pass." 
+ ), + strict=True, + ) + def test_projection_posexplode_alias_list_parses(self): + """Spark inline-generator POSEXPLODE with explicit (pos, val) alias + list in the SELECT projection.""" + from datajunction_server.sql.parsing.types import ListType + + result = validate_node_query( + "SELECT POSEXPLODE(arr) AS (pos, val) FROM src.s", + _col_map( + ("src.s", [("arr", ListType(element_type=IntegerType()))]), + ), + ) + assert not result.errors, result.errors + assert result.output_columns == [ + ("pos", IntegerType()), + ("val", IntegerType()), + ] + + def test_lateral_element_types_surfaces_nested_scope_errors(self): + """When EXPLODE's argument is a Function wrapping an unresolved + sub-reference (e.g., IF(missing_col > 0, arr, arr)), the outer + function still types successfully but the inner TypeResolutionError + lands in scope.errors. Those are surfaced to the caller's errors + list so the real cause doesn't get swallowed.""" + from datajunction_server.sql.parsing.types import ListType + + result = validate_node_query( + "SELECT v.x FROM src.s " + "LATERAL VIEW EXPLODE(IF(missing_col > 0, arr, arr)) v AS x", + _col_map( + ("src.s", [("arr", ListType(element_type=IntegerType()))]), + ), + ) + assert any("missing_col" in msg for msg in result.errors), result.errors + + def test_cross_join_posexplode_yields_pos_and_element(self): + """`CROSS JOIN POSEXPLODE(arr) AS u(p, v)` — the FROM-clause + POSEXPLODE form. 
First alias is the integer position; second is + the element type.""" + from datajunction_server.sql.parsing.types import ListType + + result = validate_node_query( + "SELECT t.id, u.p, u.v FROM src.s t " + "CROSS JOIN POSEXPLODE(t.arr) AS u(p, v)", + _col_map( + ( + "src.s", + [ + ("id", IntegerType()), + ("arr", ListType(element_type=IntegerType())), + ], + ), + ), + ) + assert not result.errors, result.errors + assert result.output_columns == [ + ("id", IntegerType()), + ("p", IntegerType()), + ("v", IntegerType()), + ] + def test_cross_join_unnest_struct_array_unpacks_fields_in_from(self): - """Struct-unpacking in the FROM-clause FunctionTableExpression branch - (line 397): `CROSS JOIN UNNEST(array>) AS t(c1, c2)` - expands positionally to c1: a_type, c2: b_type.""" + """Struct-unpacking in the FROM-clause FunctionTableExpression handler: + `CROSS JOIN UNNEST(array>) AS t(c1, c2)` expands + positionally to c1: a_type, c2: b_type.""" from datajunction_server.sql.parsing.backends.antlr4 import parse_rule cells = parse_rule( @@ -2456,8 +2532,7 @@ def test_cross_join_unnest_struct_array_unpacks_fields_in_from(self): def test_lateral_view_explode_with_too_many_aliases_fills_unknown(self): """LATERAL VIEW EXPLODE(scalar_array) AS a, b, c — 3 aliases but - element is a single scalar. The extras (line 623) fall back to - UnknownType.""" + element is a single scalar. 
The extras fall back to UnknownType.""" from datajunction_server.sql.parsing.types import ListType result = validate_node_query( @@ -2472,7 +2547,7 @@ def test_lateral_view_explode_with_too_many_aliases_fills_unknown(self): def test_projection_explode_scalar_array_with_extra_aliases_unknown(self): """Projection `EXPLODE(arr) AS (a, b, c)` on a scalar-array column — - only one element type, extras (line 623) fall back to UnknownType.""" + only one element type, extras fall back to UnknownType.""" from datajunction_server.sql.parsing.types import ListType result = validate_node_query( @@ -2487,8 +2562,8 @@ def test_projection_explode_scalar_array_with_extra_aliases_unknown(self): def test_lateral_element_types_propagates_typeresolution_error(self): """When EXPLODE's argument references a nonexistent column, the - TypeResolutionError message flows into `errors` (lines 492→494) - so the real root cause surfaces, not just 'Unable to infer type'.""" + TypeResolutionError message flows into the caller's errors list so + the real root cause surfaces, not just 'Unable to infer type'.""" result = validate_node_query( "SELECT v.x FROM src.s LATERAL VIEW EXPLODE(missing_col) v AS x", _col_map(("src.s", [("id", IntegerType())])), @@ -2501,8 +2576,8 @@ def test_lateral_element_types_propagates_typeresolution_error(self): def test_single_segment_namespace_matching_parent_is_silent(self): """Single-segment namespace whose name matches a parent_map key (not a FROM table) should NOT trigger the 'references namespace' error — - the `if not any_prefix_is_parent` guard (line 900) short-circuits and - the fallback just returns UnknownType (line 906).""" + the prefix-is-parent guard short-circuits and we fall through to + UnknownType.""" result = validate_node_query( "SELECT src.foo_missing FROM other.t", _col_map( diff --git a/datajunction-server/tests/internal/node_validation_test.py b/datajunction-server/tests/internal/node_validation_test.py index b22cdf47a..df9d3f740 100644 --- 
a/datajunction-server/tests/internal/node_validation_test.py +++ b/datajunction-server/tests/internal/node_validation_test.py @@ -912,7 +912,7 @@ async def test_validate_node_data_v2_skips_parents_without_columns( user: User, ): """Parent nodes whose current revision has no columns are skipped when - building parent_columns_map (line 573→572).""" + building the parent_columns_map rather than creating empty entries.""" from datajunction_server.internal.validation import validate_node_data_v2 parent = Node( From 539363b242fd5f35a142a3f46645a240444f79a9 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Tue, 21 Apr 2026 12:38:01 -0700 Subject: [PATCH 5/5] Fix tests --- .../datajunction_server/database/node.py | 7 +- .../internal/deployment/type_inference.py | 30 ++--- .../tests/internal/node_validation_test.py | 113 ++++++++++++++++++ 3 files changed, 132 insertions(+), 18 deletions(-) diff --git a/datajunction-server/datajunction_server/database/node.py b/datajunction-server/datajunction_server/database/node.py index 9553e0869..0d503308e 100644 --- a/datajunction-server/datajunction_server/database/node.py +++ b/datajunction-server/datajunction_server/database/node.py @@ -1320,7 +1320,12 @@ def default_load_options(cls): ), joinedload(DimensionLink.node_revision), ), - selectinload(NodeRevision.required_dimensions), + selectinload(NodeRevision.required_dimensions).options( + # Column.node_revision back-ref is read by Column.full_name() + # during required-dimensions resolution; preload it so + # accessing it in async context doesn't trip MissingGreenlet. 
+ joinedload(Column.node_revision).load_only(NodeRevision.name), + ), selectinload(NodeRevision.availability), # Load created_by for API responses (but noload in /sql/ endpoint's custom options) selectinload(NodeRevision.created_by), diff --git a/datajunction-server/datajunction_server/internal/deployment/type_inference.py b/datajunction-server/datajunction_server/internal/deployment/type_inference.py index 3f02b9b96..59e69b919 100644 --- a/datajunction-server/datajunction_server/internal/deployment/type_inference.py +++ b/datajunction-server/datajunction_server/internal/deployment/type_inference.py @@ -279,7 +279,7 @@ def _build_table_scope( errors.extend(errs) for idx, view in enumerate(select.lateral_views): scope.update( - _collect_lateral_view_columns(view, scope, idx=idx, errors=errors), + _collect_lateral_view_columns(view, scope, errors, idx=idx), ) return scope, errors @@ -413,20 +413,20 @@ def _collect_tables_from_relation( def _collect_lateral_view_columns( view: ast.LateralView, from_scope: TableScope, + errors: list[str], idx: int = 0, - errors: list[str] | None = None, ) -> TableScope: """Collect columns from a LATERAL VIEW (e.g., EXPLODE) expression. Resolves element types from the source column's ListType/MapType when possible, falls back to UnknownType otherwise. + ``errors`` receives any element-type-resolution errors (e.g., the EXPLODE + argument references a nonexistent column) so callers can surface them. + ``idx`` distinguishes multiple anonymous lateral views in the same SELECT (each default-aliased to ``__lateral__`` would otherwise collide and overwrite prior columns). - - ``errors`` receives any element-type-resolution errors (e.g., the EXPLODE - argument references a nonexistent column) so callers can surface them. 
""" func = view.func alias = func.alias.name if func.alias else f"__lateral_{idx}__" @@ -434,7 +434,7 @@ def _collect_lateral_view_columns( if not col_list: return {} - element_types = _resolve_lateral_element_types(func, from_scope, errors=errors) + element_types = _resolve_lateral_element_types(func, from_scope, errors) func_name = func.name.name.upper() if hasattr(func, "name") and func.name else "" is_posexplode = "POS" in func_name @@ -465,14 +465,14 @@ def _collect_lateral_view_columns( def _resolve_lateral_element_types( func: ast.FunctionTableExpression, from_scope: TableScope, - errors: list[str] | None = None, + errors: list[str], ) -> list[ColumnType]: """Resolve element types for an EXPLODE/UNNEST function argument. - When ``errors`` is provided and type resolution of the argument fails - (e.g., the referenced column doesn't exist), the specific error message is - appended so callers can surface the real cause rather than a downstream - "Unable to infer type" coming from the resulting UnknownType columns. + The ``errors`` list receives any resolution failures — both the specific + TypeResolutionError on the argument itself and any errors accumulated + inside the throwaway scope (e.g., unresolved sub-refs). All current + callers thread their own error list through, so it's required. """ if not func.args: return [] # pragma: no cover @@ -488,16 +488,12 @@ def _resolve_lateral_element_types( try: col_type = _resolve_expr_type(arg, scope) except TypeResolutionError as exc: - if errors is not None: - errors.append(str(exc)) + errors.append(str(exc)) return [] except Exception: # pragma: no cover return [] - if errors is not None: - # Also surface any errors that accumulated inside the throwaway scope - # (e.g., unresolved sub-refs) rather than silently dropping them. 
- errors.extend(scope.errors) + errors.extend(scope.errors) if isinstance(col_type, ListType): return [col_type.element.type] diff --git a/datajunction-server/tests/internal/node_validation_test.py b/datajunction-server/tests/internal/node_validation_test.py index df9d3f740..5711a6ea7 100644 --- a/datajunction-server/tests/internal/node_validation_test.py +++ b/datajunction-server/tests/internal/node_validation_test.py @@ -1017,3 +1017,116 @@ async def test_validate_node_data_v2_flags_invalid_required_dimensions( err.code == ErrorCode.INVALID_COLUMN and "required dimensions" in err.message for err in validator.errors ), [(e.code, e.message) for e in validator.errors] + + +@pytest.mark.asyncio +async def test_validate_node_data_v2_cross_fact_metrics_no_shared_dims( + session: AsyncSession, + user: User, +): + """A derived metric summing two base metrics whose underlying sources + share no dimensions surfaces an INVALID_PARENT error. Covers the + cross-fact safety check in validate_node_data_v2.""" + from datajunction_server.errors import ErrorCode + from datajunction_server.internal.validation import validate_node_data_v2 + + # Two independent source nodes — no shared dim. + src_a = Node( + name="test.v2_src_a", + type=NodeType.SOURCE, + created_by_id=user.id, + current_version="v1.0", + ) + src_a_rev = NodeRevision( + name="test.v2_src_a", + display_name="src a", + type=NodeType.SOURCE, + query=None, + status=NodeStatus.VALID, + version="v1.0", + node=src_a, + columns=[Column(name="amount", type=ct.DoubleType(), order=0)], + created_by_id=user.id, + ) + src_b = Node( + name="test.v2_src_b", + type=NodeType.SOURCE, + created_by_id=user.id, + current_version="v1.0", + ) + src_b_rev = NodeRevision( + name="test.v2_src_b", + display_name="src b", + type=NodeType.SOURCE, + query=None, + status=NodeStatus.VALID, + version="v1.0", + node=src_b, + columns=[Column(name="cost", type=ct.DoubleType(), order=0)], + created_by_id=user.id, + ) + # Two metrics, one per source. 
+ metric_a = Node( + name="test.v2_metric_a", + type=NodeType.METRIC, + created_by_id=user.id, + current_version="v1.0", + ) + metric_a_rev = NodeRevision( + name="test.v2_metric_a", + display_name="metric a", + type=NodeType.METRIC, + query="SELECT SUM(amount) FROM test.v2_src_a", + status=NodeStatus.VALID, + version="v1.0", + node=metric_a, + columns=[Column(name="test_DOT_v2_metric_a", type=ct.DoubleType(), order=0)], + parents=[src_a], + created_by_id=user.id, + ) + metric_b = Node( + name="test.v2_metric_b", + type=NodeType.METRIC, + created_by_id=user.id, + current_version="v1.0", + ) + metric_b_rev = NodeRevision( + name="test.v2_metric_b", + display_name="metric b", + type=NodeType.METRIC, + query="SELECT SUM(cost) FROM test.v2_src_b", + status=NodeStatus.VALID, + version="v1.0", + node=metric_b, + columns=[Column(name="test_DOT_v2_metric_b", type=ct.DoubleType(), order=0)], + parents=[src_b], + created_by_id=user.id, + ) + session.add_all( + [ + src_a, + src_a_rev, + src_b, + src_b_rev, + metric_a, + metric_a_rev, + metric_b, + metric_b_rev, + ], + ) + await session.commit() + + # Derived metric referencing both base metrics — no shared dim. + data = NodeRevisionBase( + name="test.v2_derived_cross_fact", + display_name="derived cross fact", + type=NodeType.METRIC, + query="SELECT test.v2_metric_a + test.v2_metric_b", + mode="published", + ) + validator = await validate_node_data_v2(data, session) + assert validator.status == NodeStatus.INVALID + assert any( + err.code == ErrorCode.INVALID_PARENT and "no shared" in err.message.lower() + for err in validator.errors + ), [(e.code, e.message) for e in validator.errors]