diff --git a/datajunction-server/datajunction_server/api/graphql/dataloaders.py b/datajunction-server/datajunction_server/api/graphql/dataloaders.py index 223b5fd56..60ec9a566 100644 --- a/datajunction-server/datajunction_server/api/graphql/dataloaders.py +++ b/datajunction-server/datajunction_server/api/graphql/dataloaders.py @@ -3,23 +3,35 @@ """ import json -from typing import Any, Optional +import logging +from typing import Any, Optional, Tuple from sqlalchemy import select -from sqlalchemy.orm import selectinload +from sqlalchemy.orm import joinedload, load_only, noload, selectinload from strawberry.dataloader import DataLoader from starlette.requests import Request from datajunction_server.api.graphql.resolvers.nodes import load_node_options +from datajunction_server.construction.build_v3.loaders import find_upstream_node_names from datajunction_server.database.collection import Collection as DBCollection from datajunction_server.database.namespace import NodeNamespace -from datajunction_server.database.node import Node as DBNode +from datajunction_server.database.node import ( + Node as DBNode, + NodeRevision as DBNodeRevision, +) from datajunction_server.internal.namespaces import ( get_parent_namespaces, resolve_git_info_from_map, ) +from datajunction_server.sql.decompose import ( + MetricComponent, + MetricComponentExtractor, +) +from datajunction_server.sql.parsing import ast from datajunction_server.utils import session_context +logger = logging.getLogger(__name__) + async def batch_load_nodes( keys: list[tuple[str, dict[str, Any] | None]], @@ -236,3 +248,130 @@ def create_git_info_loader( return DataLoader( load_fn=lambda keys: batch_load_git_info(keys, request), ) + + +async def batch_load_extracted_measures( + node_revision_ids: list[int], + request: Request, +) -> list[Optional[Tuple[list[MetricComponent], "ast.Query"]]]: + """ + Batch-extract metric components for multiple metric node revisions. + + Opens ONE reader session and pre-populates the caches that + MetricComponentExtractor.extract() accepts: + + - ``nodes_cache``: name -> Node (with current.query loaded), covering + every metric in the batch plus their full upstream metric chain. + This replaces the per-metric ``_load_metric_data`` that runs 2 DB + queries per call. + - ``parent_map``: child_name -> [parent_names], built from one recursive + CTE. Replaces the per-base-metric "is this parent derived?" check + inside the extractor's recursion. + + With both caches populated, ``extract()`` takes the cached path at + ``_build_metric_data_from_cache`` and fires zero DB queries (even for + derived metrics that recurse into parents). Nets hundreds to thousands + of queries down to three per request: + + 1. map nr_id -> node name + 2. upstream CTE (all ancestor metric names) + 3. bulk-load Node objects for the upstream set + """ + async with session_context( + request, + session_label="graphql extracted_measures batch", + ) as session: + # 1) nr_id -> name, so we can kick off the upstream walk and later + # look up each batch entry's metric_node. + nr_stmt = ( + select(DBNodeRevision.id, DBNode.name) + .join(DBNode, DBNodeRevision.node_id == DBNode.id) + .where(DBNodeRevision.id.in_(node_revision_ids)) + ) + nr_rows = (await session.execute(nr_stmt)).all() + nr_to_name: dict[int, str] = {row.id: row.name for row in nr_rows} + + # 2) Walk the upstream graph once; pulls every ancestor that's a + # parent of any batch metric, and produces parent_map as a side effect. 
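+        # Illustrative shapes (hypothetical names) for a batch containing a
+        # derived metric m.derived built on a base metric m.base:
+        #   all_names  = {"m.derived", "m.base", ...full upstream chain}
+        #   parent_map = {"m.derived": ["m.base"], ...}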
+ all_names, parent_map = await find_upstream_node_names( + session, + list(nr_to_name.values()), + ) + + # 3) Bulk-load Node objects for the full ancestor set. We only need + # name, type, and current.query — everything else is noloaded so this + # is a narrow query. noload(Node.created_by/Node.tags) are safe because + # MetricComponentExtractor never reads them. + nodes_cache: dict[str, DBNode] = {} + if all_names: + node_stmt = ( + select(DBNode) + .where(DBNode.name.in_(all_names)) + .options( + load_only( + DBNode.name, + DBNode.type, + DBNode.current_version, + ), + noload(DBNode.created_by), + noload(DBNode.tags), + joinedload(DBNode.current).options( + noload(DBNodeRevision.created_by), + load_only( + DBNodeRevision.id, + DBNodeRevision.name, + DBNodeRevision.query, + ), + ), + ) + ) + nodes_cache = { + n.name: n + for n in (await session.execute(node_stmt)).unique().scalars().all() + } + + # 4) Per nr_id: invoke extract() with the cache trio so the extractor + # takes the zero-DB-query path through _build_metric_data_from_cache, + # including its recursion. + results: list[Optional[Tuple[list[MetricComponent], "ast.Query"]]] = [] + for nr_id in node_revision_ids: + name = nr_to_name.get(nr_id) + metric_node = nodes_cache.get(name) if name else None + if metric_node is None or metric_node.current is None: + results.append(None) + continue + try: + extractor = MetricComponentExtractor(nr_id) + components, derived_ast = await extractor.extract( + session, + nodes_cache=nodes_cache, + parent_map=parent_map, + metric_node=metric_node, + ) + results.append((components, derived_ast)) + except Exception as exc: # pragma: no cover + logger.warning( + "extracted_measures extraction failed for nr_id=%s: %s", + nr_id, + exc, + ) + results.append(None) + return results + + +def create_extracted_measures_loader( + request: Request, +) -> DataLoader[int, Optional[Tuple[list[MetricComponent], "ast.Query"]]]: + """ + Create a DataLoader that batches MetricComponentExtractor.extract() calls + across a GraphQL request. Keys are NodeRevision ids. + + For a query like `findNodes(nodeTypes: [METRIC]) { current { extractedMeasures {...} } }` + returning N metrics, Strawberry batches all per-node loader.load(nr_id) + calls within a single event-loop tick into one + batch_load_extracted_measures(ids) — collapsing N sessions + N extractions + into 1 session + N extractions sharing parsed_query_cache. 
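+
+    Resolver-side usage (sketch; the context key matches what get_context
+    registers in main.py):
+
+        loader = info.context["extracted_measures_loader"]
+        result = await loader.load(node_revision.id)  # batched per tick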
+ """ + return DataLoader( + load_fn=lambda keys: batch_load_extracted_measures(keys, request), + ) diff --git a/datajunction-server/datajunction_server/api/graphql/main.py b/datajunction-server/datajunction_server/api/graphql/main.py index 823f9e5f3..ed7ae720e 100644 --- a/datajunction-server/datajunction_server/api/graphql/main.py +++ b/datajunction-server/datajunction_server/api/graphql/main.py @@ -15,6 +15,7 @@ from datajunction_server.internal.access.authentication.http import DJHTTPBearer from datajunction_server.api.graphql.dataloaders import ( create_collection_nodes_loader, + create_extracted_measures_loader, create_git_info_loader, create_node_by_name_loader, ) @@ -126,6 +127,7 @@ async def get_context( "node_loader": create_node_by_name_loader(request), "collection_nodes_loader": create_collection_nodes_loader(request), "git_info_loader": create_git_info_loader(request), + "extracted_measures_loader": create_extracted_measures_loader(request), "settings": get_settings(), "request": request, "background_tasks": background_tasks, diff --git a/datajunction-server/datajunction_server/api/graphql/resolvers/nodes.py b/datajunction-server/datajunction_server/api/graphql/resolvers/nodes.py index e56751081..9cbce2ba8 100644 --- a/datajunction-server/datajunction_server/api/graphql/resolvers/nodes.py +++ b/datajunction-server/datajunction_server/api/graphql/resolvers/nodes.py @@ -421,16 +421,19 @@ def load_node_revision_options(node_revision_fields): # Handle columns if "columns" in node_revision_fields or "primary_key" in node_revision_fields: - # Full columns with all relationships needed for columns/primary_key queries - options.append( - selectinload(DBNodeRevision.columns).options( - joinedload(Column.attributes).joinedload( + columns_sub = node_revision_fields.get("columns") + requested_col_sub: dict = columns_sub if isinstance(columns_sub, dict) else {} + needs_attributes = ( + "attributes" in requested_col_sub or "primary_key" in node_revision_fields + ) + col_opts: list = [] + if needs_attributes: + col_opts.append( + selectinload(Column.attributes).selectinload( ColumnAttribute.attribute_type, ), - joinedload(Column.dimension), - joinedload(Column.partition), - ), - ) + ) + options.append(selectinload(DBNodeRevision.columns).options(*col_opts)) elif is_cube_request and not all_name_only: # Minimal ORM columns for cube element resolution options.append( diff --git a/datajunction-server/datajunction_server/api/graphql/scalars/node.py b/datajunction-server/datajunction_server/api/graphql/scalars/node.py index 938814887..acd2d332b 100644 --- a/datajunction-server/datajunction_server/api/graphql/scalars/node.py +++ b/datajunction-server/datajunction_server/api/graphql/scalars/node.py @@ -10,6 +10,7 @@ from sqlalchemy.orm.attributes import InstrumentedAttribute, set_committed_value from datajunction_server.api.graphql.scalars import BigInt +from datajunction_server.api.graphql.utils import extract_fields from datajunction_server.api.graphql.scalars.availabilitystate import ( AvailabilityState, PartitionAvailability, @@ -45,7 +46,6 @@ from datajunction_server.models.node import ( GitRepositoryInfo as PydanticGitRepositoryInfo, ) -from datajunction_server.sql.decompose import MetricComponentExtractor from datajunction_server.sql.parsing.backends.antlr4 import ast, parse NodeType = strawberry.enum(NodeType_) @@ -325,15 +325,35 @@ def primary_key(self, root: "DBNodeRevision") -> list[str]: return [col.name for col in root.primary_key()] @strawberry.field - def metric_metadata(self, root: 
"DBNodeRevision") -> MetricMetadata | None: + def metric_metadata( + self, + root: "DBNodeRevision", + info: Info, + ) -> MetricMetadata | None: """ Metric metadata """ if root.type != NodeType.METRIC: return None - query_ast = parse(root.query) - functions = [func.function() for func in query_ast.find_all(ast.Function)] + # Parsing the metric SQL + walking its AST for `expression` and + # `incompatible_druid_functions` is ANTLR-heavy. Skip it entirely + # when the client didn't request either sub-field. + requested = extract_fields(info) + needs_ast = ( + "expression" in requested or "incompatible_druid_functions" in requested + ) + expression: str | None = None + incompatible: set[str] = set() + if needs_ast: + query_ast = parse(root.query) + functions = [func.function() for func in query_ast.find_all(ast.Function)] + expression = str(query_ast.select.projection[0]) + incompatible = { + func.__name__.upper() + for func in functions + if Dialect.DRUID not in func.dialects + } return MetricMetadata( # type: ignore direction=root.metric_metadata.direction if root.metric_metadata else None, unit=root.metric_metadata.unit.value @@ -348,12 +368,8 @@ def metric_metadata(self, root: "DBNodeRevision") -> MetricMetadata | None: max_decimal_exponent=root.metric_metadata.max_decimal_exponent if root.metric_metadata else None, - expression=str(query_ast.select.projection[0]), - incompatible_druid_functions={ - func.__name__.upper() - for func in functions - if Dialect.DRUID not in func.dialects - }, + expression=expression or "", + incompatible_druid_functions=incompatible, ) @strawberry.field @@ -373,23 +389,51 @@ async def extracted_measures( info: Info, ) -> DecomposedMetric | None: """ - A list of metric components for a metric node + A list of metric components for a metric node. + + Uses the request-scoped extracted_measures_loader so that a GraphQL + query returning N metrics batches all extractions into a single + session + shared parsed_query_cache (instead of opening N independent + sessions via resolver_session). """ if root.type != NodeType.METRIC: return None - from datajunction_server.api.graphql.utils import resolver_session - async with resolver_session(info) as session: - extractor = MetricComponentExtractor(root.id) - components, derived_ast = await extractor.extract(session) - # The derived_expression is the combiner (how to combine merged components) - combiner_expr = str(derived_ast.select.projection[0]) + # Fast path: derive_frozen_measures (internal/nodes.py:291, background + # task on node create/update) persists `str(derived_sql)` to the + # NodeRevision.derived_expression column. When the fragment only reads + # `derivedQuery`, we return the cached scalar and skip extract(). The + # GQL `derivedExpression` field is an alias for `combiner`, not for + # the column of the same name — so we have to fall through to the + # full path whenever it's requested. + requested = extract_fields(info) + needs_full_extract = ( + "components" in requested + or "combiner" in requested + or "derived_expression" in requested + ) + + if not needs_full_extract and root.derived_expression: return DecomposedMetric( # type: ignore - components=components, - combiner=combiner_expr, - derived_query=str(derived_ast), + components=[], + combiner="", + derived_query=root.derived_expression, ) + # Full path: DataLoader batch + extract(). 
+ loader = info.context["extracted_measures_loader"] + result = await loader.load(root.id) + if result is None: + return None + components, derived_ast = result + # The derived_expression is the combiner (how to combine merged components) + combiner_expr = str(derived_ast.select.projection[0]) + return DecomposedMetric( # type: ignore + components=components, + combiner=combiner_expr, + derived_query=str(derived_ast), + ) + # Only cubes will have these fields @strawberry.field def cube_metrics(self, root: "DBNodeRevision", info: Info) -> List["NodeRevision"]: # type: ignore[return-value] diff --git a/datajunction-server/datajunction_server/database/node.py b/datajunction-server/datajunction_server/database/node.py index e4ff4b7cc..454b5fbae 100644 --- a/datajunction-server/datajunction_server/database/node.py +++ b/datajunction-server/datajunction_server/database/node.py @@ -691,7 +691,20 @@ async def get_cube_by_name( joinedload(Node.tags), ] - statement = select(Node).where(Node.name == name).options(*options) + # Force loader options to re-apply to NodeRevisions already in the + # identity map. Without this, a prior query that hydrated a metric's + # NodeRevision without eager-loading `.node` leaves the cached + # instance sticky, and the nested selectinload(Column.node_revision) + # .selectinload(NodeRevision.node) silently skips it — producing a + # sync lazy-load when `cube_node_metrics` touches `.node` from a + # hybrid_property (MissingGreenlet on 3.11, where fixture ordering + # happens to trigger this). + statement = ( + select(Node) + .where(Node.name == name) + .options(*options) + .execution_options(populate_existing=True) + ) result = await session.execute(statement) node = result.unique().scalar_one_or_none() return node diff --git a/datajunction-server/datajunction_server/internal/nodes.py b/datajunction-server/datajunction_server/internal/nodes.py index 6a360d7bc..6737581f0 100644 --- a/datajunction-server/datajunction_server/internal/nodes.py +++ b/datajunction-server/datajunction_server/internal/nodes.py @@ -37,7 +37,12 @@ from datajunction_server.database.dimensionlink import DimensionLink from datajunction_server.database.history import History from datajunction_server.database.metricmetadata import MetricMetadata -from datajunction_server.database.node import MissingParent, Node, NodeRevision +from datajunction_server.database.node import ( + MissingParent, + Node, + NodeRelationship, + NodeRevision, +) from datajunction_server.database.partition import Partition from datajunction_server.database.user import User from datajunction_server.database.measure import FrozenMeasure @@ -2984,11 +2989,20 @@ async def activate_node( missing_parent and missing_parent in downstream.current.missing_parents ): downstream.current.missing_parents.remove(missing_parent) - # Compare by id, not Python identity: the cached parent collection may - # contain a different ORM object instance for the same node, and an - # identity-based `not in` check would wrongly append, producing a - # duplicate NodeRelationship insert. - if node.id not in {p.id for p in downstream.current.parents}: + # Query NodeRelationship directly rather than trusting the in-memory + # `downstream.current.parents` collection. On Python 3.11 we've seen + # the selectinload'd collection miss the row that already exists in + # the DB (autoflush ordering across the prior deactivate + this + # activate), which caused duplicate-key violations at flush time. 
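+        # Roughly the emitted SQL (table and column names assumed from the
+        # ORM mapping):
+        #   SELECT parent_id FROM noderelationship
+        #   WHERE parent_id = :node_id AND child_id = :child_id
+        #   LIMIT 1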
+ existing = await session.execute( + select(NodeRelationship.parent_id) + .where( + NodeRelationship.parent_id == node.id, + NodeRelationship.child_id == downstream.current.id, + ) + .limit(1), + ) + if existing.first() is None: downstream.current.parents.append(node) _logger.info( diff --git a/datajunction-server/scripts/backfill_derived_expression.py b/datajunction-server/scripts/backfill_derived_expression.py new file mode 100644 index 000000000..9752a5ef5 --- /dev/null +++ b/datajunction-server/scripts/backfill_derived_expression.py @@ -0,0 +1,115 @@ +""" +One-off backfill for NodeRevision.derived_expression. + +Most GraphQL queries that request metric's extractedMeasures can take a fast +path that reads NodeRevision.derived_expression directly instead of re-running +MetricComponentExtractor.extract(). New metrics get this populated via the +`derive_frozen_measures` background task triggered on node create/update +(internal/nodes.py:291), but metrics created before that was added have +derived_expression = NULL and force the slow extract() path. + +Run this once after deploying the fast-path resolver to populate the column +for all existing metrics. + + python scripts/backfill_derived_expression.py + # or: python scripts/backfill_derived_expression.py --batch-size 50 --dry-run + +Idempotent: skips rows that already have derived_expression set. Safe to +re-run. Failures on individual metrics are logged and skipped so one bad +metric doesn't block the rest. +""" + +from __future__ import annotations + +import argparse +import asyncio +import logging +from sqlalchemy import select + +from datajunction_server.database.node import Node, NodeRevision +from datajunction_server.internal.nodes import derive_frozen_measures +from datajunction_server.models.node_type import NodeType +from datajunction_server.utils import session_context + +logger = logging.getLogger(__name__) + + +async def backfill(batch_size: int = 100, dry_run: bool = False) -> None: + async with session_context() as session: + stmt = ( + select(NodeRevision.id, Node.name) + .join(Node, NodeRevision.node_id == Node.id) + .where( + Node.type == NodeType.METRIC, + Node.current_version == NodeRevision.version, + Node.deactivated_at.is_(None), + NodeRevision.derived_expression.is_(None), + ) + ) + targets = [(row.id, row.name) for row in (await session.execute(stmt)).all()] + + logger.info( + "Found %d metric revisions missing derived_expression", + len(targets), + ) + if dry_run: + for nr_id, name in targets[:20]: + logger.info(" would backfill: nr_id=%d name=%s", nr_id, name) + if len(targets) > 20: + logger.info(" ... and %d more", len(targets) - 20) + return + + total = len(targets) + done = 0 + failed = 0 + for i in range(0, total, batch_size): + batch = targets[i : i + batch_size] + for nr_id, name in batch: + try: + # derive_frozen_measures opens its own session_context and + # commits inside; safe to call in a loop. 
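+                    # Illustrative effect (hypothetical metric): for a query
+                    # like `SUM(a) / SUM(b)`, the persisted derived query is
+                    # roughly `SELECT SUM(a_sum) / SUM(b_sum) ...` over the
+                    # extracted components (component names are illustrative).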
+ await derive_frozen_measures(nr_id) + done += 1 + except Exception as exc: + failed += 1 + logger.warning( + "backfill failed for nr_id=%d name=%s: %s", + nr_id, + name, + exc, + ) + logger.info( + "progress: %d/%d (done=%d failed=%d)", + i + len(batch), + total, + done, + failed, + ) + + logger.info("backfill complete: done=%d failed=%d", done, failed) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--batch-size", + type=int, + default=100, + help="Progress log interval (default: 100)", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Count target rows without writing", + ) + args = parser.parse_args() + + logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] %(levelname)s %(name)s: %(message)s", + ) + asyncio.run(backfill(batch_size=args.batch_size, dry_run=args.dry_run)) + + +if __name__ == "__main__": + main() diff --git a/datajunction-server/tests/api/graphql/find_nodes_test.py b/datajunction-server/tests/api/graphql/find_nodes_test.py index 758acd421..2c1a2a8aa 100644 --- a/datajunction-server/tests/api/graphql/find_nodes_test.py +++ b/datajunction-server/tests/api/graphql/find_nodes_test.py @@ -733,6 +733,93 @@ async def test_find_metric( }, ] + # Fast path: when the fragment only requests `derivedQuery`, the resolver + # should read `derived_expression` directly off the row and skip the + # DataLoader + extract() entirely. Patching the batch loader to raise + # guarantees the fast path was taken. `derivedExpression` is an alias for + # `combiner` and would force the full path, so it is not requested here. + fast_path_query = """ + { + findNodes(names: ["default.regional_repair_efficiency"]) { + current { + extractedMeasures { + derivedQuery + } + } + } + } + """ + with mock.patch( + "datajunction_server.api.graphql.dataloaders.batch_load_extracted_measures", + side_effect=AssertionError("fast path should skip the DataLoader"), + ): + fast_response = await client_with_roads.post( + "/graphql", + json={"query": fast_path_query}, + ) + assert fast_response.status_code == 200 + fast_extracted = fast_response.json()["data"]["findNodes"][0]["current"][ + "extractedMeasures" + ] + full_extracted = data["data"]["findNodes"][0]["current"]["extractedMeasures"] + assert fast_extracted["derivedQuery"] == full_extracted["derivedQuery"] + + # AST skip for metricMetadata: when neither `expression` nor + # `incompatibleDruidFunctions` is requested, parse(root.query) must not run. + metadata_only_query = """ + { + findNodes(names: ["default.regional_repair_efficiency"]) { + current { + metricMetadata { + direction + unit { name } + } + } + } + } + """ + with mock.patch( + "datajunction_server.api.graphql.scalars.node.parse", + side_effect=AssertionError("AST parse should be skipped"), + ): + metadata_response = await client_with_roads.post( + "/graphql", + json={"query": metadata_only_query}, + ) + assert metadata_response.status_code == 200 + metadata = metadata_response.json()["data"]["findNodes"][0]["current"][ + "metricMetadata" + ] + assert metadata == {"direction": None, "unit": None} + + # None path: when the DataLoader can't produce a result for an nr_id + # (e.g. extract() raised inside the batch loader), the resolver returns + # null for `extractedMeasures`. 
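+    # DataLoader contract: the batch function returns one result per key,
+    # positionally aligned, so a patched return of [None] maps the single
+    # requested nr_id to a failed extraction.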
+ full_query_with_components = """ + { + findNodes(names: ["default.regional_repair_efficiency"]) { + current { + extractedMeasures { + components { name } + } + } + } + } + """ + with mock.patch( + "datajunction_server.api.graphql.dataloaders.batch_load_extracted_measures", + return_value=[None], + ): + none_response = await client_with_roads.post( + "/graphql", + json={"query": full_query_with_components}, + ) + assert none_response.status_code == 200 + assert ( + none_response.json()["data"]["findNodes"][0]["current"]["extractedMeasures"] + is None + ) + @pytest.mark.asyncio async def test_find_cubes( diff --git a/datajunction-server/tests/api/graphql/test_dataloaders.py b/datajunction-server/tests/api/graphql/test_dataloaders.py index 801f07a96..63e33cd3e 100644 --- a/datajunction-server/tests/api/graphql/test_dataloaders.py +++ b/datajunction-server/tests/api/graphql/test_dataloaders.py @@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, MagicMock, patch from datajunction_server.api.graphql.dataloaders import ( + batch_load_extracted_measures, batch_load_git_info, batch_load_nodes, batch_load_nodes_by_name_only, @@ -410,3 +411,35 @@ def test_create_git_info_loader(mock_dataloader): assert "load_fn" in call_kwargs assert callable(call_kwargs["load_fn"]) assert result == mock_loader_instance + + +@pytest.mark.asyncio +@patch("datajunction_server.api.graphql.dataloaders.find_upstream_node_names") +@patch("datajunction_server.api.graphql.dataloaders.session_context") +async def test_batch_load_extracted_measures_missing_ids( + mock_session_context, + mock_find_upstream, +): + """Batch loader returns None for nr_ids that don't resolve to a metric node. + + Exercises the `if all_names:` False branch (no ancestor set, so the bulk + Node load is skipped) and the `metric_node is None` early-continue that + appends None to the results — both defensive paths for nr_ids that + vanished between GraphQL resolution and the batch query. + """ + mock_session = AsyncMock() + mock_session_context.return_value.__aenter__.return_value = mock_session + + # nr_stmt returns no rows -> nr_to_name is empty + nr_result = MagicMock() + nr_result.all.return_value = [] + mock_session.execute.return_value = nr_result + + # find_upstream_node_names returns empty when called with [] + mock_find_upstream.return_value = (set(), {}) + + result = await batch_load_extracted_measures([999999], MagicMock()) + + assert result == [None] + # Only the nr_stmt query ran; the bulk Node load was skipped. + assert mock_session.execute.call_count == 1
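+    # The upstream walk still ran (with the empty name list) before the
+    # `if all_names:` guard skipped the bulk Node load.
+    mock_find_upstream.assert_called_once_with(mock_session, [])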