Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 142 additions & 3 deletions datajunction-server/datajunction_server/api/graphql/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,35 @@
"""

import json
from typing import Any, Optional
import logging
from typing import Any, Optional, Tuple

from sqlalchemy import select
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import joinedload, load_only, noload, selectinload
from strawberry.dataloader import DataLoader
from starlette.requests import Request

from datajunction_server.api.graphql.resolvers.nodes import load_node_options
from datajunction_server.construction.build_v3.loaders import find_upstream_node_names
from datajunction_server.database.collection import Collection as DBCollection
from datajunction_server.database.namespace import NodeNamespace
from datajunction_server.database.node import Node as DBNode
from datajunction_server.database.node import (
Node as DBNode,
NodeRevision as DBNodeRevision,
)
from datajunction_server.internal.namespaces import (
get_parent_namespaces,
resolve_git_info_from_map,
)
from datajunction_server.sql.decompose import (
MetricComponent,
MetricComponentExtractor,
)
from datajunction_server.sql.parsing import ast
from datajunction_server.utils import session_context

logger = logging.getLogger(__name__)


async def batch_load_nodes(
keys: list[tuple[str, dict[str, Any] | None]],
Expand Down Expand Up @@ -236,3 +248,130 @@ def create_git_info_loader(
return DataLoader(
load_fn=lambda keys: batch_load_git_info(keys, request),
)


async def batch_load_extracted_measures(
    node_revision_ids: list[int],
    request: Request,
) -> list[Optional[Tuple[list[MetricComponent], "ast.Query"]]]:
    """
    Extract metric components for a whole batch of metric node revisions
    using a single reader session.

    The extractor caches are pre-populated up front so every call to
    ``MetricComponentExtractor.extract()`` takes the zero-DB-query path
    through ``_build_metric_data_from_cache``:

    - ``nodes_cache`` maps node name -> Node (with ``current.query`` eager
      loaded) for every batch metric and its entire upstream metric chain,
      replacing the two per-metric queries done by ``_load_metric_data``.
    - ``parent_map`` maps child name -> parent names, produced by one
      recursive CTE, replacing the per-base-metric "is this parent derived?"
      lookup inside the extractor's recursion.

    Net effect: three queries per request regardless of batch size —
    (1) nr_id -> node name, (2) the upstream CTE, (3) the bulk Node load.
    Entries that cannot be resolved or whose extraction fails come back as
    ``None`` (failures are logged, not raised).
    """
    async with session_context(
        request,
        session_label="graphql extracted_measures batch",
    ) as session:
        # Map each requested NodeRevision id to its node name so we can seed
        # the upstream walk and later pair results back to batch entries.
        name_by_revision: dict[int, str] = {
            row.id: row.name
            for row in (
                await session.execute(
                    select(DBNodeRevision.id, DBNode.name)
                    .join(DBNode, DBNodeRevision.node_id == DBNode.id)
                    .where(DBNodeRevision.id.in_(node_revision_ids)),
                )
            ).all()
        }

        # One recursive traversal of the upstream graph: yields every
        # ancestor of any batch metric plus the child -> parents adjacency.
        ancestor_names, parent_map = await find_upstream_node_names(
            session,
            list(name_by_revision.values()),
        )

        # Narrow bulk load of the ancestor set: only name, type, and
        # current.query are needed; relationships the extractor never reads
        # (created_by, tags) are explicitly noloaded to keep the query small.
        node_by_name: dict[str, DBNode] = {}
        if ancestor_names:
            loaded_nodes = (
                (
                    await session.execute(
                        select(DBNode)
                        .where(DBNode.name.in_(ancestor_names))
                        .options(
                            load_only(
                                DBNode.name,
                                DBNode.type,
                                DBNode.current_version,
                            ),
                            noload(DBNode.created_by),
                            noload(DBNode.tags),
                            joinedload(DBNode.current).options(
                                noload(DBNodeRevision.created_by),
                                load_only(
                                    DBNodeRevision.id,
                                    DBNodeRevision.name,
                                    DBNodeRevision.query,
                                ),
                            ),
                        ),
                    )
                )
                .unique()
                .scalars()
                .all()
            )
            node_by_name = {node.name: node for node in loaded_nodes}

        # Run the extractor once per revision id; with both caches supplied
        # it never touches the database, even while recursing into parents
        # of derived metrics.
        results: list[Optional[Tuple[list[MetricComponent], "ast.Query"]]] = []
        for revision_id in node_revision_ids:
            node_name = name_by_revision.get(revision_id)
            metric_node = node_by_name.get(node_name) if node_name else None
            if metric_node is None or metric_node.current is None:
                results.append(None)
                continue
            try:
                components, derived_ast = await MetricComponentExtractor(
                    revision_id,
                ).extract(
                    session,
                    nodes_cache=node_by_name,
                    parent_map=parent_map,
                    metric_node=metric_node,
                )
            except Exception as exc:  # pragma: no cover
                # Best-effort per-entry: one bad metric must not sink the
                # whole batch, so log and yield None for that slot.
                logger.warning(
                    "extracted_measures extraction failed for nr_id=%s: %s",
                    revision_id,
                    exc,
                )
                results.append(None)
            else:
                results.append((components, derived_ast))
        return results


def create_extracted_measures_loader(
    request: Request,
) -> DataLoader[int, Optional[Tuple[list[MetricComponent], "ast.Query"]]]:
    """
    Build a request-scoped DataLoader over NodeRevision ids that batches
    ``MetricComponentExtractor.extract()`` calls.

    For a query such as
    `findNodes(nodeTypes: [METRIC]) { current { extractedMeasures {...} } }`
    returning N metrics, Strawberry coalesces every per-node
    ``loader.load(nr_id)`` issued within one event-loop tick into a single
    ``batch_load_extracted_measures(ids)`` call — turning N sessions +
    N extractions into 1 session + N extractions that share a
    parsed_query_cache.
    """

    async def _load_batch(
        keys: list[int],
    ) -> list[Optional[Tuple[list[MetricComponent], "ast.Query"]]]:
        # Close over the request so the batch function can open its own
        # session via session_context.
        return await batch_load_extracted_measures(keys, request)

    return DataLoader(load_fn=_load_batch)
2 changes: 2 additions & 0 deletions datajunction-server/datajunction_server/api/graphql/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from datajunction_server.internal.access.authentication.http import DJHTTPBearer
from datajunction_server.api.graphql.dataloaders import (
create_collection_nodes_loader,
create_extracted_measures_loader,
create_git_info_loader,
create_node_by_name_loader,
)
Expand Down Expand Up @@ -126,6 +127,7 @@ async def get_context(
"node_loader": create_node_by_name_loader(request),
"collection_nodes_loader": create_collection_nodes_loader(request),
"git_info_loader": create_git_info_loader(request),
"extracted_measures_loader": create_extracted_measures_loader(request),
"settings": get_settings(),
"request": request,
"background_tasks": background_tasks,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,16 +421,19 @@ def load_node_revision_options(node_revision_fields):

# Handle columns
if "columns" in node_revision_fields or "primary_key" in node_revision_fields:
# Full columns with all relationships needed for columns/primary_key queries
options.append(
selectinload(DBNodeRevision.columns).options(
joinedload(Column.attributes).joinedload(
columns_sub = node_revision_fields.get("columns")
requested_col_sub: dict = columns_sub if isinstance(columns_sub, dict) else {}
needs_attributes = (
"attributes" in requested_col_sub or "primary_key" in node_revision_fields
)
col_opts: list = []
if needs_attributes:
col_opts.append(
selectinload(Column.attributes).selectinload(
ColumnAttribute.attribute_type,
),
joinedload(Column.dimension),
joinedload(Column.partition),
),
)
)
options.append(selectinload(DBNodeRevision.columns).options(*col_opts))
elif is_cube_request and not all_name_only:
# Minimal ORM columns for cube element resolution
options.append(
Expand Down
84 changes: 64 additions & 20 deletions datajunction-server/datajunction_server/api/graphql/scalars/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sqlalchemy.orm.attributes import InstrumentedAttribute, set_committed_value

from datajunction_server.api.graphql.scalars import BigInt
from datajunction_server.api.graphql.utils import extract_fields
from datajunction_server.api.graphql.scalars.availabilitystate import (
AvailabilityState,
PartitionAvailability,
Expand Down Expand Up @@ -45,7 +46,6 @@
from datajunction_server.models.node import (
GitRepositoryInfo as PydanticGitRepositoryInfo,
)
from datajunction_server.sql.decompose import MetricComponentExtractor
from datajunction_server.sql.parsing.backends.antlr4 import ast, parse

NodeType = strawberry.enum(NodeType_)
Expand Down Expand Up @@ -325,15 +325,35 @@ def primary_key(self, root: "DBNodeRevision") -> list[str]:
return [col.name for col in root.primary_key()]

@strawberry.field
def metric_metadata(self, root: "DBNodeRevision") -> MetricMetadata | None:
def metric_metadata(
self,
root: "DBNodeRevision",
info: Info,
) -> MetricMetadata | None:
"""
Metric metadata
"""
if root.type != NodeType.METRIC:
return None

query_ast = parse(root.query)
functions = [func.function() for func in query_ast.find_all(ast.Function)]
# Parsing the metric SQL + walking its AST for `expression` and
# `incompatible_druid_functions` is ANTLR-heavy. Skip it entirely
# when the client didn't request either sub-field.
requested = extract_fields(info)
needs_ast = (
"expression" in requested or "incompatible_druid_functions" in requested
)
expression: str | None = None
incompatible: set[str] = set()
if needs_ast:
query_ast = parse(root.query)
functions = [func.function() for func in query_ast.find_all(ast.Function)]
expression = str(query_ast.select.projection[0])
incompatible = {
func.__name__.upper()
for func in functions
if Dialect.DRUID not in func.dialects
}
return MetricMetadata( # type: ignore
direction=root.metric_metadata.direction if root.metric_metadata else None,
unit=root.metric_metadata.unit.value
Expand All @@ -348,12 +368,8 @@ def metric_metadata(self, root: "DBNodeRevision") -> MetricMetadata | None:
max_decimal_exponent=root.metric_metadata.max_decimal_exponent
if root.metric_metadata
else None,
expression=str(query_ast.select.projection[0]),
incompatible_druid_functions={
func.__name__.upper()
for func in functions
if Dialect.DRUID not in func.dialects
},
expression=expression or "",
incompatible_druid_functions=incompatible,
)

@strawberry.field
Expand All @@ -373,23 +389,51 @@ async def extracted_measures(
info: Info,
) -> DecomposedMetric | None:
"""
A list of metric components for a metric node
A list of metric components for a metric node.

Uses the request-scoped extracted_measures_loader so that a GraphQL
query returning N metrics batches all extractions into a single
session + shared parsed_query_cache (instead of opening N independent
sessions via resolver_session).
"""
if root.type != NodeType.METRIC:
return None
from datajunction_server.api.graphql.utils import resolver_session

async with resolver_session(info) as session:
extractor = MetricComponentExtractor(root.id)
components, derived_ast = await extractor.extract(session)
# The derived_expression is the combiner (how to combine merged components)
combiner_expr = str(derived_ast.select.projection[0])
# Fast path: derive_frozen_measures (internal/nodes.py:291, background
# task on node create/update) persists `str(derived_sql)` to the
# NodeRevision.derived_expression column. When the fragment only reads
# `derivedQuery`, we return the cached scalar and skip extract(). The
# GQL `derivedExpression` field is an alias for `combiner`, not for
# the column of the same name — so we have to fall through to the
# full path whenever it's requested.
requested = extract_fields(info)
needs_full_extract = (
"components" in requested
or "combiner" in requested
or "derived_expression" in requested
)

if not needs_full_extract and root.derived_expression:
return DecomposedMetric( # type: ignore
components=components,
combiner=combiner_expr,
derived_query=str(derived_ast),
components=[],
combiner="",
derived_query=root.derived_expression,
)

# Full path: DataLoader batch + extract().
loader = info.context["extracted_measures_loader"]
result = await loader.load(root.id)
if result is None:
return None
components, derived_ast = result
# The derived_expression is the combiner (how to combine merged components)
combiner_expr = str(derived_ast.select.projection[0])
return DecomposedMetric( # type: ignore
components=components,
combiner=combiner_expr,
derived_query=str(derived_ast),
)

# Only cubes will have these fields
@strawberry.field
def cube_metrics(self, root: "DBNodeRevision", info: Info) -> List["NodeRevision"]: # type: ignore[return-value]
Expand Down
15 changes: 14 additions & 1 deletion datajunction-server/datajunction_server/database/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,20 @@ async def get_cube_by_name(
joinedload(Node.tags),
]

statement = select(Node).where(Node.name == name).options(*options)
# Force loader options to re-apply to NodeRevisions already in the
# identity map. Without this, a prior query that hydrated a metric's
# NodeRevision without eager-loading `.node` leaves the cached
# instance sticky, and the nested selectinload(Column.node_revision)
# .selectinload(NodeRevision.node) silently skips it — producing a
# sync lazy-load when `cube_node_metrics` touches `.node` from a
# hybrid_property (MissingGreenlet on 3.11, where fixture ordering
# happens to trigger this).
statement = (
select(Node)
.where(Node.name == name)
.options(*options)
.execution_options(populate_existing=True)
)
result = await session.execute(statement)
node = result.unique().scalar_one_or_none()
return node
Expand Down
Loading
Loading