Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def create_extracted_measures_loader(
returning N metrics, Strawberry batches all per-node loader.load(nr_id)
calls within a single event-loop tick into one
batch_load_extracted_measures(ids) — collapsing N sessions + N extractions
into 1 session + N extractions sharing parsed_query_cache.
into 1 session + N extractions sharing nodes_cache and parent_map.
"""
return DataLoader(
load_fn=lambda keys: batch_load_extracted_measures(keys, request),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -393,8 +393,8 @@ async def extracted_measures(

Uses the request-scoped extracted_measures_loader so that a GraphQL
query returning N metrics batches all extractions into a single
session + shared parsed_query_cache (instead of opening N independent
sessions via resolver_session).
session + shared nodes_cache/parent_map (instead of opening N
independent sessions via resolver_session).
"""
if root.type != NodeType.METRIC:
return None
Expand Down
5 changes: 1 addition & 4 deletions datajunction-server/datajunction_server/internal/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,7 @@ async def derive_frozen_measures_bulk(

* one query to load all target revisions + their parent chain eagerly,
* zero-DB `MetricComponentExtractor.extract` calls (via the extractor's
`nodes_cache` / `parent_map` / `parsed_query_cache` params), and
`nodes_cache` / `parent_map` params), and
* one batch `SELECT` against FrozenMeasure.name IN (...).

Caller owns the transaction and commit; used by the deployment
Expand Down Expand Up @@ -807,8 +807,6 @@ async def derive_frozen_measures_bulk(
for grandparent in parent.current.parents:
nodes_cache.setdefault(grandparent.name, grandparent)

parsed_query_cache: dict[str, ast.Query] = {}

# 3. Per-metric extract with caches — zero DB calls in this loop.
extraction_results: list[tuple[NodeRevision, list]] = []
for rev in revisions:
Expand All @@ -818,7 +816,6 @@ async def derive_frozen_measures_bulk(
nodes_cache=nodes_cache,
parent_map=parent_map,
metric_node=rev.node,
parsed_query_cache=parsed_query_cache,
)
rev.derived_expression = str(derived_sql)
extraction_results.append((rev, measures))
Expand Down
20 changes: 2 additions & 18 deletions datajunction-server/datajunction_server/sql/decompose.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import hashlib
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from typing import cast

Expand Down Expand Up @@ -696,7 +695,6 @@ async def extract(
nodes_cache: dict[str, "Node"] | None = None,
parent_map: dict[str, list[str]] | None = None,
metric_node: "Node | None" = None,
parsed_query_cache: dict[str, ast.Query] | None = None,
_visited: set[str] | None = None,
) -> tuple[list[MetricComponent], ast.Query]:
"""
Expand All @@ -714,8 +712,6 @@ async def extract(
Required if nodes_cache is provided.
metric_node: Optional metric Node object.
Required if nodes_cache is provided.
parsed_query_cache: Optional dict of query_string -> parsed AST.
Used to avoid re-parsing the same query multiple times.
"""
# Use cache if available, otherwise query DB
if (
Expand All @@ -731,18 +727,7 @@ async def extract(
else:
metric_data = await self._load_metric_data(session)

# Helper to parse with cache
def cached_parse(query: str) -> ast.Query:
if parsed_query_cache is not None:
if query not in parsed_query_cache: # pragma: no cover
parsed_query_cache[query] = parse(query)

# Return a deep copy to avoid AST mutation issues
return deepcopy(parsed_query_cache[query]) # pragma: no cover
return parse(query)

# Parse queries (pure computation, no DB)
query_ast = cached_parse(metric_data.query)
query_ast = parse(metric_data.query)

# Initialize visited set for cycle detection
if _visited is None:
Expand Down Expand Up @@ -820,12 +805,11 @@ def cached_parse(query: str) -> ast.Query:
nodes_cache=nodes_cache,
parent_map=parent_map,
metric_node=parent_node,
parsed_query_cache=parsed_query_cache,
_visited=_visited,
)
else:
# True base metric - decompose aggregations
base_ast = cached_parse(base_metric.query)
base_ast = parse(base_metric.query)
base_components, derived_ast = self._extract_base(base_ast)

for comp in base_components:
Expand Down
Loading