@@ -54,6 +54,7 @@
from datajunction_server.internal.history import EntityType
from datajunction_server.sql.dag import get_metric_parents_map
from datajunction_server.internal.nodes import (
derive_frozen_measures_bulk,
hard_delete_node,
)
from datajunction_server.models.base import labelize
@@ -1189,6 +1190,14 @@ async def _execute_deployment_plan(self, plan: DeploymentPlan) -> list:
self.deployed_results.extend(deployed_cubes)
await self._update_deployment_status()

# Derive frozen measures for deployed metrics inline (not
# background) so derived_expression and FrozenMeasure rows
# are atomic with the rest of the deployment: committed or
# rolled back together, and dry-run exercises derivation.
with timer.phase("derive measures") as p:
derived = await self._derive_measures_for_deployed_metrics()
p.append(f"{derived} metrics")

# Run impact propagation before deletions so deleted nodes'
# children are still reachable via NodeRelationship.
changed_names = {
@@ -1232,6 +1241,58 @@ async def _execute_deployment_plan(self, plan: DeploymentPlan) -> list:
await self.session.commit()
return downstream

async def _derive_measures_for_deployed_metrics(self) -> int:
"""Run ``derive_frozen_measures`` inline for every metric that was
created or updated in this deployment.

Single-node create/update schedules this as a FastAPI background task;
bulk deployment has no such background-task machinery available for
a durable post-commit gap (timeouts, restarts, disconnected clients
drop the task). Running inline inside the orchestrator's SAVEPOINT
makes derivation atomic with the rest of the deployment:

* wet-run: committed with the deployment
* dry-run: rolled back along with the SAVEPOINT

Order matters: a derived metric's extractor reads its base metrics'
        already-derived measures. ``self.deployed_results`` is appended to in
        the order in which ``_deploy_nodes`` processes topological levels, so
        iterating it in append order yields base-before-derived ordering
within a single deployment.
"""
metric_spec_names = {
spec.rendered_name
for spec in self.deployment_spec.nodes
if spec.node_type == NodeType.METRIC
}
touched_metric_names = [
r.name
for r in self.deployed_results
if r.deploy_type == DeploymentResult.Type.NODE
and r.name in metric_spec_names
and r.status != DeploymentResult.Status.SKIPPED
and r.operation
in (
DeploymentResult.Operation.CREATE,
DeploymentResult.Operation.UPDATE,
)
]
if not touched_metric_names:
return 0

nodes = await Node.get_by_names(
self.session,
touched_metric_names,
options=[joinedload(Node.current)],
)
# Iterate in the deployed order (base metrics before derived metrics)
# so a derived metric's extractor sees its upstream base metrics'
# already-derived measures in-session.
name_to_node = {n.name: n for n in nodes}
revision_ids = [name_to_node[name].current.id for name in touched_metric_names]
await derive_frozen_measures_bulk(self.session, revision_ids)
return len(touched_metric_names)

async def _deploy_nodes(
self,
plan: DeploymentPlan,
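The inline-derivation comments and the `_derive_measures_for_deployed_metrics` docstring above both lean on nested-transaction (SAVEPOINT) semantics: everything the deployment does, including derivation, either commits together or rolls back together. A minimal sketch of that pattern with SQLAlchemy's asyncio session, where `deploy_atomically` and `DryRunRollback` are illustrative names rather than code from this PR:

```python
from sqlalchemy.ext.asyncio import AsyncSession

from datajunction_server.internal.nodes import derive_frozen_measures_bulk


class DryRunRollback(Exception):
    """Raised only to unwind the SAVEPOINT on a dry-run."""


async def deploy_atomically(
    session: AsyncSession,
    revision_ids: list[int],
    *,
    dry_run: bool,
) -> None:
    try:
        async with session.begin_nested():  # SAVEPOINT
            # ... create/update the deployment's nodes here ...
            await derive_frozen_measures_bulk(session, revision_ids)
            if dry_run:
                # Raising inside the block rolls the SAVEPOINT back, so the
                # derived_expression updates and FrozenMeasure rows never persist.
                raise DryRunRollback()
    except DryRunRollback:
        return
    # Wet-run: derivation commits together with the rest of the deployment.
    await session.commit()
```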
210 changes: 176 additions & 34 deletions datajunction-server/datajunction_server/internal/nodes.py
@@ -40,6 +40,7 @@
from datajunction_server.database.node import (
MissingParent,
Node,
NodeMissingParents,
NodeRelationship,
NodeRevision,
)
@@ -663,42 +664,188 @@ async def derive_frozen_measures(node_revision_id: int) -> list[FrozenMeasure]:

For base metrics: extracts aggregation components from the metric query.
For derived metrics: collects components from referenced base metrics.

Used by the background-task call sites (``create_a_node_revision``,
``revalidate_node``) where the HTTP response has already returned and
the derivation runs asynchronously. Opens its own session and commits
on completion. The deployment path uses ``derive_frozen_measures_bulk``
instead to batch DB work across all metrics in a single transaction.
"""
async with session_context() as session:
result = await _derive_frozen_measures_impl(node_revision_id, session)
await session.commit()
return result

async def _derive_frozen_measures_impl(
node_revision_id: int,
session: AsyncSession,
) -> list[FrozenMeasure]:
"""Core derivation logic. Does not commit — caller is responsible."""
node_revision = cast(
NodeRevision,
await NodeRevision.get_by_id(
session=session,
node_revision_id=node_revision_id,
options=[
joinedload(NodeRevision.parents).joinedload(Node.current),
],
),
)
if not node_revision:
return [] # pragma: no cover

node_revision.derived_expression = str(derived_sql)
frozen_measures: list[FrozenMeasure] = []
if not node_revision.parents:
return frozen_measures # pragma: no cover

# Use the first direct parent for the frozen measure upstream_revision_id
await session.refresh(node_revision.parents[0], ["current"])
upstream_revision_id = node_revision.parents[0].current.id
# Extract components using the node revision ID.
# The extractor auto-detects base vs derived metrics.
extractor = MetricComponentExtractor(node_revision.id)
measures, derived_sql = await extractor.extract(session)

node_revision.derived_expression = str(derived_sql)

# Use the first direct parent for the frozen measure upstream_revision_id
await session.refresh(node_revision.parents[0], ["current"])
upstream_revision_id = node_revision.parents[0].current.id

for measure in measures:
frozen_measure = await FrozenMeasure.get_by_name(
session=session,
name=measure.name,
)
if not frozen_measure and measure.aggregation:
frozen_measure = FrozenMeasure(
name=measure.name,
upstream_revision_id=upstream_revision_id,
expression=measure.expression,
aggregation=measure.aggregation,
rule=measure.rule,
used_by_node_revisions=[],
)
            session.add(frozen_measure)
if frozen_measure:
frozen_measure.used_by_node_revisions.append(node_revision)
frozen_measures.append(frozen_measure)
return frozen_measures


async def derive_frozen_measures_bulk(
session: AsyncSession,
node_revision_ids: list[int],
) -> None:
"""Batched equivalent of `derive_frozen_measures` for many metric revisions.

The per-revision entry point issues N × (2 + avg_measures) DB queries for a
deployment of N metrics: one `NodeRevision.get_by_id`, one
`_load_metric_data`, and one `FrozenMeasure.get_by_name` per measure. This
bulk path replaces all of that with:

* one query to load all target revisions + their parent chain eagerly,
* zero-DB `MetricComponentExtractor.extract` calls (via the extractor's
`nodes_cache` / `parent_map` / `parsed_query_cache` params), and
* one batch `SELECT` against FrozenMeasure.name IN (...).

Caller owns the transaction and commit; used by the deployment
orchestrator inside its SAVEPOINT so derivation is atomic with the rest
of the deployment (and rolled back naturally on dry-run).
"""
if not node_revision_ids:
return

# 1. Load target revisions with parents + parent.current + grandparent chain.
# Two selectinload levels cover the common metric-parent chain depths
# (base metrics: depth 0; single-level derived: depth 1). Deeper chains
# are expanded iteratively below.
revisions = (
(
await session.execute(
select(NodeRevision)
.where(NodeRevision.id.in_(node_revision_ids))
.options(
joinedload(NodeRevision.node),
selectinload(NodeRevision.parents)
.joinedload(Node.current)
.options(
selectinload(NodeRevision.parents).joinedload(Node.current),
),
),
)
)
.unique()
.scalars()
.all()
)
if not revisions: # pragma: no cover
return

# 2. Build the caches MetricComponentExtractor.extract expects. The
# two-level selectinload above is enough for the common depth-2
# metric chain (deployed derived metric → base metric → source);
# deeper chains rely on SQLAlchemy's session identity map to
# populate any already-known intermediates.
nodes_cache: dict[str, Node] = {}
parent_map: dict[str, list[str]] = {}
for rev in revisions:
nodes_cache[rev.node.name] = rev.node
parent_map[rev.node.name] = [p.name for p in rev.parents]
for parent in rev.parents:
nodes_cache.setdefault(parent.name, parent)
if (
parent.current
and parent.type == NodeType.METRIC
and parent.name not in parent_map
):
parent_map[parent.name] = [gp.name for gp in parent.current.parents]
for grandparent in parent.current.parents:
nodes_cache.setdefault(grandparent.name, grandparent)

parsed_query_cache: dict[str, ast.Query] = {}

# 3. Per-metric extract with caches — zero DB calls in this loop.
extraction_results: list[tuple[NodeRevision, list]] = []
for rev in revisions:
extractor = MetricComponentExtractor(rev.id)
measures, derived_sql = await extractor.extract(
session,
nodes_cache=nodes_cache,
parent_map=parent_map,
metric_node=rev.node,
parsed_query_cache=parsed_query_cache,
)
rev.derived_expression = str(derived_sql)
extraction_results.append((rev, measures))

# 4. Batch-fetch all existing FrozenMeasures matching any extracted name.
all_measure_names = {m.name for _, measures in extraction_results for m in measures}
fm_by_name: dict[str, FrozenMeasure] = {}
if all_measure_names:
existing_fms = (
(
await session.execute(
select(FrozenMeasure).where(
FrozenMeasure.name.in_(list(all_measure_names)),
),
)
)
.scalars()
.all()
)
fm_by_name = {fm.name: fm for fm in existing_fms}

    # 5. Link / create in-memory. New FrozenMeasures are added to fm_by_name so
    # a later metric in the same deployment that references the same measure
    # reuses the freshly added row instead of creating a duplicate.
for rev, measures in extraction_results:
if not rev.parents:
continue # pragma: no cover
upstream_revision_id = rev.parents[0].current.id
for measure in measures:
frozen_measure = fm_by_name.get(measure.name)
if frozen_measure is None:
if not measure.aggregation:
continue
frozen_measure = FrozenMeasure(
name=measure.name,
upstream_revision_id=upstream_revision_id,
@@ -708,11 +855,8 @@ async def derive_frozen_measures(node_revision_id: int) -> list[FrozenMeasure]:
used_by_node_revisions=[],
)
session.add(frozen_measure)
                fm_by_name[measure.name] = frozen_measure
            frozen_measure.used_by_node_revisions.append(rev)


async def save_node(
@@ -2759,8 +2903,6 @@ async def delete_orphaned_missing_parents(session: AsyncSession) -> None:
This should be called after operations that remove node references
(like hard delete or deactivate) to clean up orphaned MissingParents.
"""
from datajunction_server.database.node import NodeMissingParents

orphaned_missing_parents = (
(
await session.execute(
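The updated `derive_frozen_measures` docstring contrasts the two entry points: single-node create/revalidate hands derivation to FastAPI's background-task machinery after the response is sent, while deployment calls the bulk helper inline. A rough sketch of the background-task side, where the route and the placeholder revision id are illustrative only (`BackgroundTasks.add_task` is real FastAPI API; nothing else here is copied from this PR):

```python
from fastapi import APIRouter, BackgroundTasks

from datajunction_server.internal.nodes import derive_frozen_measures

router = APIRouter()


@router.post("/nodes/{name}/refresh")  # illustrative route, not the real endpoint
async def refresh_metric(name: str, background_tasks: BackgroundTasks) -> dict:
    # ... create or revalidate the node revision here ...
    node_revision_id = 42  # placeholder for the freshly created revision's id
    # Fire-and-forget after the HTTP response: the task opens its own session
    # and commits on completion, so a timeout, restart, or dropped client can
    # lose it. The deployment path closes that gap by deriving inline.
    background_tasks.add_task(derive_frozen_measures, node_revision_id)
    return {"node_revision_id": node_revision_id}
```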
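The query-count claim in `derive_frozen_measures_bulk`'s docstring boils down to replacing one `FrozenMeasure.get_by_name` round trip per measure with a single `IN (...)` select. That step in isolation, as a sketch (the `FrozenMeasure` import path below is assumed, not taken from this diff):

```python
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from datajunction_server.database import FrozenMeasure  # import path assumed


async def load_frozen_measures_by_name(
    session: AsyncSession,
    measure_names: set[str],
) -> dict[str, FrozenMeasure]:
    """One round trip for all names instead of one SELECT per measure."""
    if not measure_names:
        return {}
    rows = (
        (
            await session.execute(
                select(FrozenMeasure).where(
                    FrozenMeasure.name.in_(list(measure_names)),
                ),
            )
        )
        .scalars()
        .all()
    )
    return {fm.name: fm for fm in rows}
```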
59 changes: 59 additions & 0 deletions datajunction-server/tests/api/deployments_test.py
@@ -1777,6 +1777,65 @@ async def test_deploy_metric_with_update(
"status": "success",
}

@pytest.mark.asyncio
async def test_deploy_populates_derived_expression_and_measures(
self,
session,
client,
default_hard_hats,
default_hard_hat,
default_us_states,
default_us_state,
default_avg_length_of_employment,
):
"""
Deploying a metric must populate NodeRevision.derived_expression and
the associated FrozenMeasure rows inline (not via background task) so
the deployment result is atomically consistent: when the deployment
reports success, measure derivation has already happened.

Regression test for the pre-cutover gap where single-node create used
a FastAPI BackgroundTask but bulk deployment had no equivalent path,
leaving metrics with derived_expression = NULL after deployment.
"""
from sqlalchemy.orm import joinedload, selectinload
from datajunction_server.database.node import Node, NodeRevision

namespace = "derive_measures"
data = await deploy_and_wait(
client,
DeploymentSpec(
namespace=namespace,
nodes=[
default_hard_hats,
default_hard_hat,
default_us_states,
default_us_state,
default_avg_length_of_employment,
],
),
)
assert data["status"] == "success"

metric = await Node.get_by_name(
session,
f"{namespace}.default.avg_length_of_employment",
options=[
joinedload(Node.current).options(
selectinload(NodeRevision.frozen_measures),
),
],
)
assert metric is not None
assert metric.current is not None
assert metric.current.derived_expression is not None, (
"derived_expression should be populated inline by the orchestrator — "
"a NULL value means the bulk-deployment path is skipping derivation."
)
assert len(metric.current.frozen_measures) > 0, (
"at least one FrozenMeasure row should be linked after deployment"
)

@pytest.mark.asyncio
async def test_deploy_cube_with_update(
self,