diff --git a/datajunction-server/datajunction_server/api/graphql/dataloaders.py b/datajunction-server/datajunction_server/api/graphql/dataloaders.py index 60ec9a566..0d7a2df75 100644 --- a/datajunction-server/datajunction_server/api/graphql/dataloaders.py +++ b/datajunction-server/datajunction_server/api/graphql/dataloaders.py @@ -370,7 +370,7 @@ def create_extracted_measures_loader( returning N metrics, Strawberry batches all per-node loader.load(nr_id) calls within a single event-loop tick into one batch_load_extracted_measures(ids) — collapsing N sessions + N extractions - into 1 session + N extractions sharing parsed_query_cache. + into 1 session + N extractions sharing nodes_cache and parent_map. """ return DataLoader( load_fn=lambda keys: batch_load_extracted_measures(keys, request), diff --git a/datajunction-server/datajunction_server/api/graphql/scalars/node.py b/datajunction-server/datajunction_server/api/graphql/scalars/node.py index acd2d332b..59b53ec6c 100644 --- a/datajunction-server/datajunction_server/api/graphql/scalars/node.py +++ b/datajunction-server/datajunction_server/api/graphql/scalars/node.py @@ -393,8 +393,8 @@ async def extracted_measures( Uses the request-scoped extracted_measures_loader so that a GraphQL query returning N metrics batches all extractions into a single - session + shared parsed_query_cache (instead of opening N independent - sessions via resolver_session). + session + shared nodes_cache/parent_map (instead of opening N + independent sessions via resolver_session). """ if root.type != NodeType.METRIC: return None diff --git a/datajunction-server/datajunction_server/internal/nodes.py b/datajunction-server/datajunction_server/internal/nodes.py index abed7d771..0107a3387 100644 --- a/datajunction-server/datajunction_server/internal/nodes.py +++ b/datajunction-server/datajunction_server/internal/nodes.py @@ -750,7 +750,7 @@ async def derive_frozen_measures_bulk( * one query to load all target revisions + their parent chain eagerly, * zero-DB `MetricComponentExtractor.extract` calls (via the extractor's - `nodes_cache` / `parent_map` / `parsed_query_cache` params), and + `nodes_cache` / `parent_map` params), and * one batch `SELECT` against FrozenMeasure.name IN (...). Caller owns the transaction and commit; used by the deployment @@ -807,8 +807,6 @@ async def derive_frozen_measures_bulk( for grandparent in parent.current.parents: nodes_cache.setdefault(grandparent.name, grandparent) - parsed_query_cache: dict[str, ast.Query] = {} - # 3. Per-metric extract with caches — zero DB calls in this loop. extraction_results: list[tuple[NodeRevision, list]] = [] for rev in revisions: @@ -818,7 +816,6 @@ async def derive_frozen_measures_bulk( nodes_cache=nodes_cache, parent_map=parent_map, metric_node=rev.node, - parsed_query_cache=parsed_query_cache, ) rev.derived_expression = str(derived_sql) extraction_results.append((rev, measures)) diff --git a/datajunction-server/datajunction_server/sql/decompose.py b/datajunction-server/datajunction_server/sql/decompose.py index c83e10060..57012c5a3 100644 --- a/datajunction-server/datajunction_server/sql/decompose.py +++ b/datajunction-server/datajunction_server/sql/decompose.py @@ -2,7 +2,6 @@ import hashlib from abc import ABC, abstractmethod -from copy import deepcopy from dataclasses import dataclass from typing import cast @@ -696,7 +695,6 @@ async def extract( nodes_cache: dict[str, "Node"] | None = None, parent_map: dict[str, list[str]] | None = None, metric_node: "Node | None" = None, - parsed_query_cache: dict[str, ast.Query] | None = None, _visited: set[str] | None = None, ) -> tuple[list[MetricComponent], ast.Query]: """ @@ -714,8 +712,6 @@ async def extract( Required if nodes_cache is provided. metric_node: Optional metric Node object. Required if nodes_cache is provided. - parsed_query_cache: Optional dict of query_string -> parsed AST. - Used to avoid re-parsing the same query multiple times. """ # Use cache if available, otherwise query DB if ( @@ -731,18 +727,7 @@ async def extract( else: metric_data = await self._load_metric_data(session) - # Helper to parse with cache - def cached_parse(query: str) -> ast.Query: - if parsed_query_cache is not None: - if query not in parsed_query_cache: # pragma: no cover - parsed_query_cache[query] = parse(query) - - # Return a deep copy to avoid AST mutation issues - return deepcopy(parsed_query_cache[query]) # pragma: no cover - return parse(query) - - # Parse queries (pure computation, no DB) - query_ast = cached_parse(metric_data.query) + query_ast = parse(metric_data.query) # Initialize visited set for cycle detection if _visited is None: @@ -820,12 +805,11 @@ def cached_parse(query: str) -> ast.Query: nodes_cache=nodes_cache, parent_map=parent_map, metric_node=parent_node, - parsed_query_cache=parsed_query_cache, _visited=_visited, ) else: # True base metric - decompose aggregations - base_ast = cached_parse(base_metric.query) + base_ast = parse(base_metric.query) base_components, derived_ast = self._extract_base(base_ast) for comp in base_components: