From c8dbfc63a892f8e90c2374cc905742cf77378eb8 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Thu, 23 Apr 2026 16:45:17 -0700 Subject: [PATCH 1/2] Make sure that deployments also populate the frozen measures and derived query fields --- .../internal/deployment/orchestrator.py | 61 +++++ .../datajunction_server/internal/nodes.py | 210 ++++++++++++--- .../tests/api/deployments_test.py | 59 ++++ .../nodes/derive_frozen_measures_test.py | 253 ++++++++++++++++++ 4 files changed, 549 insertions(+), 34 deletions(-) create mode 100644 datajunction-server/tests/internal/nodes/derive_frozen_measures_test.py diff --git a/datajunction-server/datajunction_server/internal/deployment/orchestrator.py b/datajunction-server/datajunction_server/internal/deployment/orchestrator.py index 131a97806..b61ee9ce0 100644 --- a/datajunction-server/datajunction_server/internal/deployment/orchestrator.py +++ b/datajunction-server/datajunction_server/internal/deployment/orchestrator.py @@ -54,6 +54,7 @@ from datajunction_server.internal.history import EntityType from datajunction_server.sql.dag import get_metric_parents_map from datajunction_server.internal.nodes import ( + derive_frozen_measures_bulk, hard_delete_node, ) from datajunction_server.models.base import labelize @@ -1189,6 +1190,14 @@ async def _execute_deployment_plan(self, plan: DeploymentPlan) -> list: self.deployed_results.extend(deployed_cubes) await self._update_deployment_status() + # Derive frozen measures for deployed metrics inline (not + # background) so derived_expression and FrozenMeasure rows + # are atomic with the rest of the deployment: committed or + # rolled back together, and dry-run exercises derivation. + with timer.phase("derive measures") as p: + derived = await self._derive_measures_for_deployed_metrics() + p.append(f"{derived} metrics") + # Run impact propagation before deletions so deleted nodes' # children are still reachable via NodeRelationship. 
changed_names = { @@ -1232,6 +1241,58 @@ async def _execute_deployment_plan(self, plan: DeploymentPlan) -> list: await self.session.commit() return downstream + async def _derive_measures_for_deployed_metrics(self) -> int: + """Run ``derive_frozen_measures`` inline for every metric that was + created or updated in this deployment. + + Single-node create/update schedules this as a FastAPI background task; + bulk deployment has no such background-task machinery available for + a durable post-commit gap (timeouts, restarts, disconnected clients + drop the task). Running inline inside the orchestrator's SAVEPOINT + makes derivation atomic with the rest of the deployment: + + * wet-run: committed with the deployment + * dry-run: rolled back along with the SAVEPOINT + + Order matters: a derived metric's extractor reads its base metrics' + already-derived measures. ``self.deployed_results`` is appended in the + order ``_deploy_nodes`` processes topological levels, so iterating + it in append order produces correct base-before-derived ordering + within a single deployment. + """ + metric_spec_names = { + spec.rendered_name + for spec in self.deployment_spec.nodes + if spec.node_type == NodeType.METRIC + } + touched_metric_names = [ + r.name + for r in self.deployed_results + if r.deploy_type == DeploymentResult.Type.NODE + and r.name in metric_spec_names + and r.status != DeploymentResult.Status.SKIPPED + and r.operation + in ( + DeploymentResult.Operation.CREATE, + DeploymentResult.Operation.UPDATE, + ) + ] + if not touched_metric_names: + return 0 + + nodes = await Node.get_by_names( + self.session, + touched_metric_names, + options=[joinedload(Node.current)], + ) + # Iterate in the deployed order (base metrics before derived metrics) + # so a derived metric's extractor sees its upstream base metrics' + # already-derived measures in-session. 
+ name_to_node = {n.name: n for n in nodes} + revision_ids = [name_to_node[name].current.id for name in touched_metric_names] + await derive_frozen_measures_bulk(self.session, revision_ids) + return len(touched_metric_names) + async def _deploy_nodes( self, plan: DeploymentPlan, diff --git a/datajunction-server/datajunction_server/internal/nodes.py b/datajunction-server/datajunction_server/internal/nodes.py index 6737581f0..9f05d7c9a 100644 --- a/datajunction-server/datajunction_server/internal/nodes.py +++ b/datajunction-server/datajunction_server/internal/nodes.py @@ -40,6 +40,7 @@ from datajunction_server.database.node import ( MissingParent, Node, + NodeMissingParents, NodeRelationship, NodeRevision, ) @@ -663,42 +664,188 @@ async def derive_frozen_measures(node_revision_id: int) -> list[FrozenMeasure]: For base metrics: extracts aggregation components from the metric query. For derived metrics: collects components from referenced base metrics. + + Used by the background-task call sites (``create_a_node_revision``, + ``revalidate_node``) where the HTTP response has already returned and + the derivation runs asynchronously. Opens its own session and commits + on completion. The deployment path uses ``derive_frozen_measures_bulk`` + instead to batch DB work across all metrics in a single transaction. 
""" async with session_context() as session: - node_revision = cast( - NodeRevision, - await NodeRevision.get_by_id( - session=session, - node_revision_id=node_revision_id, - options=[ - joinedload(NodeRevision.parents).joinedload(Node.current), - ], - ), - ) - if not node_revision: - return [] # pragma: no cover + result = await _derive_frozen_measures_impl(node_revision_id, session) + await session.commit() + return result - frozen_measures: list[FrozenMeasure] = [] - if not node_revision.parents: - return frozen_measures # pragma: no cover - # Extract components using the node revision ID - # The extractor will automatically detect base vs derived metrics - extractor = MetricComponentExtractor(node_revision.id) - measures, derived_sql = await extractor.extract(session) +async def _derive_frozen_measures_impl( + node_revision_id: int, + session: AsyncSession, +) -> list[FrozenMeasure]: + """Core derivation logic. Does not commit — caller is responsible.""" + node_revision = cast( + NodeRevision, + await NodeRevision.get_by_id( + session=session, + node_revision_id=node_revision_id, + options=[ + joinedload(NodeRevision.parents).joinedload(Node.current), + ], + ), + ) + if not node_revision: + return [] # pragma: no cover - node_revision.derived_expression = str(derived_sql) + frozen_measures: list[FrozenMeasure] = [] + if not node_revision.parents: + return frozen_measures # pragma: no cover - # Use the first direct parent for the frozen measure upstream_revision_id - await session.refresh(node_revision.parents[0], ["current"]) - upstream_revision_id = node_revision.parents[0].current.id + # Extract components using the node revision ID. + # The extractor auto-detects base vs derived metrics. 
+ extractor = MetricComponentExtractor(node_revision.id) + measures, derived_sql = await extractor.extract(session) - for measure in measures: - frozen_measure = await FrozenMeasure.get_by_name( - session=session, + node_revision.derived_expression = str(derived_sql) + + # Use the first direct parent for the frozen measure upstream_revision_id + await session.refresh(node_revision.parents[0], ["current"]) + upstream_revision_id = node_revision.parents[0].current.id + + for measure in measures: + frozen_measure = await FrozenMeasure.get_by_name( + session=session, + name=measure.name, + ) + if not frozen_measure and measure.aggregation: + frozen_measure = FrozenMeasure( name=measure.name, + upstream_revision_id=upstream_revision_id, + expression=measure.expression, + aggregation=measure.aggregation, + rule=measure.rule, + used_by_node_revisions=[], ) - if not frozen_measure and measure.aggregation: + session.add(frozen_measure) + if frozen_measure: + frozen_measure.used_by_node_revisions.append(node_revision) + frozen_measures.append(frozen_measure) + return frozen_measures + + +async def derive_frozen_measures_bulk( + session: AsyncSession, + node_revision_ids: list[int], +) -> None: + """Batched equivalent of `derive_frozen_measures` for many metric revisions. + + The per-revision entry point issues N × (2 + avg_measures) DB queries for a + deployment of N metrics: one `NodeRevision.get_by_id`, one + `_load_metric_data`, and one `FrozenMeasure.get_by_name` per measure. This + bulk path replaces all of that with: + + * one query to load all target revisions + their parent chain eagerly, + * zero-DB `MetricComponentExtractor.extract` calls (via the extractor's + `nodes_cache` / `parent_map` / `parsed_query_cache` params), and + * one batch `SELECT` against FrozenMeasure.name IN (...). 
+ + Caller owns the transaction and commit; used by the deployment + orchestrator inside its SAVEPOINT so derivation is atomic with the rest + of the deployment (and rolled back naturally on dry-run). + """ + if not node_revision_ids: + return + + # 1. Load target revisions with parents + parent.current + grandparent chain. + # Two selectinload levels cover the common metric-parent chain depths + # (base metrics: depth 0; single-level derived: depth 1). Deeper chains + # are not expanded iteratively; see step 2 below. + revisions = ( + ( + await session.execute( + select(NodeRevision) + .where(NodeRevision.id.in_(node_revision_ids)) + .options( + joinedload(NodeRevision.node), + selectinload(NodeRevision.parents) + .joinedload(Node.current) + .options( + selectinload(NodeRevision.parents).joinedload(Node.current), + ), + ), + ) + ) + .unique() + .scalars() + .all() + ) + if not revisions: # pragma: no cover + return + + # 2. Build the caches MetricComponentExtractor.extract expects. The + # two-level selectinload above is enough for the common depth-2 + # metric chain (deployed derived metric → base metric → source); + # deeper chains rely on SQLAlchemy's session identity map to + # populate any already-known intermediates. + nodes_cache: dict[str, Node] = {} + parent_map: dict[str, list[str]] = {} + for rev in revisions: + nodes_cache[rev.node.name] = rev.node + parent_map[rev.node.name] = [p.name for p in rev.parents] + for parent in rev.parents: + nodes_cache.setdefault(parent.name, parent) + if ( + parent.current + and parent.type == NodeType.METRIC + and parent.name not in parent_map + ): + parent_map[parent.name] = [gp.name for gp in parent.current.parents] + for grandparent in parent.current.parents: + nodes_cache.setdefault(grandparent.name, grandparent) + + parsed_query_cache: dict[str, ast.Query] = {} + + # 3. Per-metric extract with caches — zero DB calls in this loop. 
+ extraction_results: list[tuple[NodeRevision, list]] = [] + for rev in revisions: + extractor = MetricComponentExtractor(rev.id) + measures, derived_sql = await extractor.extract( + session, + nodes_cache=nodes_cache, + parent_map=parent_map, + metric_node=rev.node, + parsed_query_cache=parsed_query_cache, + ) + rev.derived_expression = str(derived_sql) + extraction_results.append((rev, measures)) + + # 4. Batch-fetch all existing FrozenMeasures matching any extracted name. + all_measure_names = {m.name for _, measures in extraction_results for m in measures} + fm_by_name: dict[str, FrozenMeasure] = {} + if all_measure_names: + existing_fms = ( + ( + await session.execute( + select(FrozenMeasure).where( + FrozenMeasure.name.in_(list(all_measure_names)), + ), + ) + ) + .scalars() + .all() + ) + fm_by_name = {fm.name: fm for fm in existing_fms} + + # 5. Link / create in-memory. New FrozenMeasures added to fm_by_name so a + # later metric in the same deployment referencing the same measure + # reuses the freshly-added row instead of creating a duplicate. 
+ for rev, measures in extraction_results: + if not rev.parents: + continue # pragma: no cover + upstream_revision_id = rev.parents[0].current.id + for measure in measures: + frozen_measure = fm_by_name.get(measure.name) + if frozen_measure is None: + if not measure.aggregation: + continue frozen_measure = FrozenMeasure( name=measure.name, upstream_revision_id=upstream_revision_id, @@ -708,11 +855,8 @@ async def derive_frozen_measures(node_revision_id: int) -> list[FrozenMeasure]: used_by_node_revisions=[], ) session.add(frozen_measure) - if frozen_measure: - frozen_measure.used_by_node_revisions.append(node_revision) - frozen_measures.append(frozen_measure) - await session.commit() - return frozen_measures + fm_by_name[measure.name] = frozen_measure + frozen_measure.used_by_node_revisions.append(rev) async def save_node( @@ -2759,8 +2903,6 @@ async def delete_orphaned_missing_parents(session: AsyncSession) -> None: This should be called after operations that remove node references (like hard delete or deactivate) to clean up orphaned MissingParents. 
""" - from datajunction_server.database.node import NodeMissingParents - orphaned_missing_parents = ( ( await session.execute( diff --git a/datajunction-server/tests/api/deployments_test.py b/datajunction-server/tests/api/deployments_test.py index f3ddade8e..bcdf68668 100644 --- a/datajunction-server/tests/api/deployments_test.py +++ b/datajunction-server/tests/api/deployments_test.py @@ -1777,6 +1777,65 @@ async def test_deploy_metric_with_update( "status": "success", } + @pytest.mark.asyncio + async def test_deploy_populates_derived_expression_and_measures( + self, + session, + client, + default_hard_hats, + default_hard_hat, + default_us_states, + default_us_state, + default_avg_length_of_employment, + ): + """ + Deploying a metric must populate NodeRevision.derived_expression and + the associated FrozenMeasure rows inline (not via background task) so + the deployment result is atomically consistent: when the deployment + reports success, measure derivation has already happened. + + Regression test for the pre-cutover gap where single-node create used + a FastAPI BackgroundTask but bulk deployment had no equivalent path, + leaving metrics with derived_expression = NULL after deployment. 
+ """ + from sqlalchemy.orm import joinedload, selectinload + from datajunction_server.database.node import Node, NodeRevision + + namespace = "derive_measures" + data = await deploy_and_wait( + client, + DeploymentSpec( + namespace=namespace, + nodes=[ + default_hard_hats, + default_hard_hat, + default_us_states, + default_us_state, + default_avg_length_of_employment, + ], + ), + ) + assert data["status"] == "success" + + metric = await Node.get_by_name( + session, + f"{namespace}.default.avg_length_of_employment", + options=[ + joinedload(Node.current).options( + selectinload(NodeRevision.frozen_measures), + ), + ], + ) + assert metric is not None + assert metric.current is not None + assert metric.current.derived_expression is not None, ( + "derived_expression should be populated inline by the orchestrator — " + "a NULL value means the bulk-deployment path is skipping derivation." + ) + assert len(metric.current.frozen_measures) > 0, ( + "at least one FrozenMeasure row should be linked after deployment" + ) + @pytest.mark.asyncio async def test_deploy_cube_with_update( self, diff --git a/datajunction-server/tests/internal/nodes/derive_frozen_measures_test.py b/datajunction-server/tests/internal/nodes/derive_frozen_measures_test.py new file mode 100644 index 000000000..22a64784c --- /dev/null +++ b/datajunction-server/tests/internal/nodes/derive_frozen_measures_test.py @@ -0,0 +1,253 @@ +""" +Unit tests for ``derive_frozen_measures_bulk`` — the batched derivation path +used by the deployment orchestrator. Exercises the cache-construction branches +(derived-metric expansion, deep-chain resolution) and edge cases +(empty list, shared measures across metrics) that aren't reachable through +the per-metric ``derive_frozen_measures`` entry point. 
+""" + +import pytest +import pytest_asyncio +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +import datajunction_server.sql.parsing.types as ct +from datajunction_server.database.column import Column +from datajunction_server.database.measure import FrozenMeasure +from datajunction_server.database.node import Node, NodeRevision +from datajunction_server.database.user import OAuthProvider, User +from datajunction_server.internal.nodes import derive_frozen_measures_bulk +from datajunction_server.models.node import NodeStatus +from datajunction_server.models.node_type import NodeType + + +@pytest_asyncio.fixture +async def user(session: AsyncSession) -> User: + u = User(username="derive_fm_test_user", oauth_provider=OAuthProvider.BASIC) + session.add(u) + await session.commit() + return u + + +async def _make_source( + session: AsyncSession, + user: User, + name: str, + columns: list[Column], +) -> Node: + node = Node( + name=name, + type=NodeType.SOURCE, + current_version="v1.0", + created_by_id=user.id, + ) + rev = NodeRevision( + node=node, + name=name, + type=NodeType.SOURCE, + version="v1.0", + query=None, + status=NodeStatus.VALID, + columns=columns, + created_by_id=user.id, + ) + session.add_all([node, rev]) + await session.commit() + return node + + +async def _make_metric( + session: AsyncSession, + user: User, + name: str, + query: str, + parents: list[Node], +) -> Node: + node = Node( + name=name, + type=NodeType.METRIC, + current_version="v1.0", + created_by_id=user.id, + ) + rev = NodeRevision( + node=node, + name=name, + type=NodeType.METRIC, + version="v1.0", + query=query, + status=NodeStatus.VALID, + parents=parents, + created_by_id=user.id, + ) + session.add_all([node, rev]) + await session.commit() + return node + + +@pytest.mark.asyncio +async def test_empty_list_is_noop(session: AsyncSession): + """An empty input returns immediately without touching the session.""" + before = (await 
session.execute(select(FrozenMeasure))).scalars().all() + await derive_frozen_measures_bulk(session, []) + after = (await session.execute(select(FrozenMeasure))).scalars().all() + assert len(after) == len(before) + + +@pytest.mark.asyncio +async def test_base_metric_populates_derived_expression_and_measure( + session: AsyncSession, + user: User, +): + """Single base metric → one FrozenMeasure keyed on its aggregation, plus + a derived_expression on the metric revision.""" + src = await _make_source( + session, + user, + "src", + [Column(name="amount", type=ct.DoubleType(), order=0)], + ) + metric = await _make_metric( + session, + user, + "m.total_amount", + "SELECT SUM(amount) FROM src", + [src], + ) + await session.refresh(metric, ["current"]) + await derive_frozen_measures_bulk(session, [metric.current.id]) + await session.commit() + + await session.refresh(metric.current, ["frozen_measures"]) + assert metric.current.derived_expression is not None + assert len(metric.current.frozen_measures) >= 1 + assert any(fm.aggregation == "SUM" for fm in metric.current.frozen_measures) + + +@pytest.mark.asyncio +async def test_derived_metric_expands_parent_cache( + session: AsyncSession, + user: User, +): + """A derived metric deployed on its own (parent metric pre-existing) + exercises the first-loop cache expansion that walks from the deployed + revision into its metric parents and their grandparent sources — the + lines that populate parent_map entries for parent metrics and seed + nodes_cache with the grandparent sources.""" + src = await _make_source( + session, + user, + "src2", + [Column(name="amount", type=ct.DoubleType(), order=0)], + ) + base = await _make_metric( + session, + user, + "m.base_total", + "SELECT SUM(amount) FROM src2", + [src], + ) + derived = await _make_metric( + session, + user, + "m.derived_double", + "SELECT m.base_total * 2", + [base], + ) + await session.refresh(base, ["current"]) + await session.refresh(derived, ["current"]) + # Pre-derive 
the base so the derived metric's extract can resolve it. + await derive_frozen_measures_bulk(session, [base.current.id]) + # Now derive the derived metric in isolation — its parent chain must be + # reconstructed from the eager load + cache-build loop. + await derive_frozen_measures_bulk(session, [derived.current.id]) + await session.commit() + + await session.refresh(derived.current, ["frozen_measures"]) + assert derived.current.derived_expression is not None + # The derived metric inherits its base's components. + assert len(derived.current.frozen_measures) >= 1 + + +@pytest.mark.asyncio +async def test_deep_derived_chain_resolves( + session: AsyncSession, + user: User, +): + """A 3-level derived metric chain (C derived from B derived from A + derived from source S). The deployed metric C's parents are fully + resolved through the 2-level eager load plus SQLAlchemy's session + identity map (which already has A cached from the earlier bulk + derive). Confirms depth-3 derivation works end-to-end.""" + src = await _make_source( + session, + user, + "src3", + [Column(name="amount", type=ct.DoubleType(), order=0)], + ) + a = await _make_metric( + session, + user, + "m.a_total", + "SELECT SUM(amount) FROM src3", + [src], + ) + b = await _make_metric(session, user, "m.b_scaled", "SELECT m.a_total * 2", [a]) + c = await _make_metric(session, user, "m.c_offset", "SELECT m.b_scaled + 1", [b]) + await session.refresh(a, ["current"]) + await session.refresh(b, ["current"]) + await session.refresh(c, ["current"]) + # Derive A and B first so the chain has resolved measures for C's extractor. + await derive_frozen_measures_bulk(session, [a.current.id]) + await derive_frozen_measures_bulk(session, [b.current.id]) + # Deriving C on its own must resolve the full three-level parent chain. 
+ await derive_frozen_measures_bulk(session, [c.current.id]) + await session.commit() + + await session.refresh(c.current, ["frozen_measures"]) + assert c.current.derived_expression is not None + assert len(c.current.frozen_measures) >= 1 + + +@pytest.mark.asyncio +async def test_two_metrics_in_one_batch_both_get_measures( + session: AsyncSession, + user: User, +): + """Batch derivation of multiple metrics in a single call wires each + metric's FrozenMeasure links independently — both get populated, and + the batch-fetch/link logic iterates each (rev, measures) pair.""" + src = await _make_source( + session, + user, + "src4", + [ + Column(name="amount", type=ct.DoubleType(), order=0), + Column(name="cost", type=ct.DoubleType(), order=1), + ], + ) + m1 = await _make_metric( + session, + user, + "m.sum_a", + "SELECT SUM(amount) FROM src4", + [src], + ) + m2 = await _make_metric( + session, + user, + "m.sum_c", + "SELECT SUM(cost) FROM src4", + [src], + ) + await session.refresh(m1, ["current"]) + await session.refresh(m2, ["current"]) + + await derive_frozen_measures_bulk(session, [m1.current.id, m2.current.id]) + await session.commit() + + await session.refresh(m1.current, ["frozen_measures"]) + await session.refresh(m2.current, ["frozen_measures"]) + assert m1.current.derived_expression is not None + assert m2.current.derived_expression is not None + assert len(m1.current.frozen_measures) >= 1 + assert len(m2.current.frozen_measures) >= 1 From 0a58501746491a8121269a7a08d5b1a6f55d9cac Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Fri, 24 Apr 2026 03:26:17 -0700 Subject: [PATCH 2/2] Fix --- .../nodes/derive_frozen_measures_test.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/datajunction-server/tests/internal/nodes/derive_frozen_measures_test.py b/datajunction-server/tests/internal/nodes/derive_frozen_measures_test.py index 22a64784c..292be4521 100644 --- a/datajunction-server/tests/internal/nodes/derive_frozen_measures_test.py +++ 
b/datajunction-server/tests/internal/nodes/derive_frozen_measures_test.py @@ -208,6 +208,51 @@ async def test_deep_derived_chain_resolves( assert len(c.current.frozen_measures) >= 1 +@pytest.mark.asyncio +async def test_distinct_aggregation_component_is_skipped( + session: AsyncSession, + user: User, +): + """``MetricComponentExtractor`` emits ``aggregation=None`` for DISTINCT + aggregations because they can't be pre-aggregated (see + ``sql/decompose.py``). The link/create loop must not create a + FrozenMeasure row in that case — it should ``continue`` past the + component. This test deploys a ``COUNT(DISTINCT ...)`` metric and + confirms derivation completes without crashing and without creating + rogue rows for the distinct component.""" + src = await _make_source( + session, + user, + "src_distinct", + [Column(name="user_id", type=ct.IntegerType(), order=0)], + ) + metric = await _make_metric( + session, + user, + "m.unique_users", + "SELECT COUNT(DISTINCT user_id) FROM src_distinct", + [src], + ) + await session.refresh(metric, ["current"]) + before_ids = { + fm.id for fm in (await session.execute(select(FrozenMeasure))).scalars().all() + } + + await derive_frozen_measures_bulk(session, [metric.current.id]) + await session.commit() + + after = (await session.execute(select(FrozenMeasure))).scalars().all() + new_fms = [fm for fm in after if fm.id not in before_ids] + # Every new FrozenMeasure row must have an aggregation set — the + # DISTINCT component should have hit the `continue` path. + assert all(fm.aggregation for fm in new_fms), ( + f"non-aggregation FrozenMeasure leaked through: " + f"{[(fm.name, fm.aggregation) for fm in new_fms if not fm.aggregation]}" + ) + # Derivation still completed for the metric itself. + assert metric.current.derived_expression is not None + + @pytest.mark.asyncio async def test_two_metrics_in_one_batch_both_get_measures( session: AsyncSession,