diff --git a/datajunction-server/datajunction_server/api/graphql/main.py b/datajunction-server/datajunction_server/api/graphql/main.py
index cea00179c..eb68e5331 100644
--- a/datajunction-server/datajunction_server/api/graphql/main.py
+++ b/datajunction-server/datajunction_server/api/graphql/main.py
@@ -121,7 +121,6 @@ async def get_context(
         request.state.test_session = db_session
 
     return {
-        "session": db_session,  # Keep for backward compatibility with existing code
        "node_loader": create_node_by_name_loader(request),
        "collection_nodes_loader": create_collection_nodes_loader(request),
        "git_info_loader": create_git_info_loader(request),
diff --git a/datajunction-server/datajunction_server/api/graphql/queries/catalogs.py b/datajunction-server/datajunction_server/api/graphql/queries/catalogs.py
index 65e49fb04..236a14d91 100644
--- a/datajunction-server/datajunction_server/api/graphql/queries/catalogs.py
+++ b/datajunction-server/datajunction_server/api/graphql/queries/catalogs.py
@@ -8,6 +8,7 @@
 from strawberry.types import Info
 
 from datajunction_server.api.graphql.scalars.catalog_engine import Catalog
+from datajunction_server.api.graphql.utils import resolver_session
 from datajunction_server.database.catalog import Catalog as DBCatalog
 
 
@@ -18,8 +19,8 @@ async def list_catalogs(
     """
     List all available catalogs
     """
-    session = info.context["session"]  # type: ignore
-    return [
-        Catalog.from_pydantic(catalog)  # type: ignore
-        for catalog in (await session.execute(select(DBCatalog))).scalars().all()
-    ]
+    async with resolver_session(info) as session:
+        return [
+            Catalog.from_pydantic(catalog)  # type: ignore
+            for catalog in (await session.execute(select(DBCatalog))).scalars().all()
+        ]
diff --git a/datajunction-server/datajunction_server/api/graphql/queries/collections.py b/datajunction-server/datajunction_server/api/graphql/queries/collections.py
index c2c3f2cdf..b5865c0a5 100644
--- a/datajunction-server/datajunction_server/api/graphql/queries/collections.py
+++ b/datajunction-server/datajunction_server/api/graphql/queries/collections.py
@@ -11,6 +11,7 @@
 from strawberry.types import Info
 
 from datajunction_server.api.graphql.scalars.collection import Collection
+from datajunction_server.api.graphql.utils import resolver_session
 from datajunction_server.database.collection import Collection as DBCollection
 from datajunction_server.database.collection import CollectionNodes
 from datajunction_server.database.user import User
@@ -39,51 +40,50 @@ async def list_collections(
     """
     List collections, optionally filtered by fragment, creator, or limit.
""" - session = info.context["session"] - - # Subquery to count nodes per collection - node_count_subquery = ( - select( - CollectionNodes.collection_id, - func.count(CollectionNodes.node_id).label("node_count"), + async with resolver_session(info) as session: + # Subquery to count nodes per collection + node_count_subquery = ( + select( + CollectionNodes.collection_id, + func.count(CollectionNodes.node_id).label("node_count"), + ) + .group_by(CollectionNodes.collection_id) + .subquery() ) - .group_by(CollectionNodes.collection_id) - .subquery() - ) - statement = ( - select(DBCollection, node_count_subquery.c.node_count) - .outerjoin( - node_count_subquery, - DBCollection.id == node_count_subquery.c.collection_id, - ) - .where(is_(DBCollection.deactivated_at, None)) - .options( - joinedload(DBCollection.created_by), # Eager load creator + statement = ( + select(DBCollection, node_count_subquery.c.node_count) + .outerjoin( + node_count_subquery, + DBCollection.id == node_count_subquery.c.collection_id, + ) + .where(is_(DBCollection.deactivated_at, None)) + .options( + joinedload(DBCollection.created_by), # Eager load creator + ) ) - ) - # Filter by fragment (search in name or description) - if fragment: - statement = statement.where( - (DBCollection.name.ilike(f"%{fragment}%")) - | (DBCollection.description.ilike(f"%{fragment}%")), - ) + # Filter by fragment (search in name or description) + if fragment: + statement = statement.where( + (DBCollection.name.ilike(f"%{fragment}%")) + | (DBCollection.description.ilike(f"%{fragment}%")), + ) - if created_by: - statement = statement.join( - User, - DBCollection.created_by_id == User.id, - ).where(User.username == created_by) + if created_by: + statement = statement.join( + User, + DBCollection.created_by_id == User.id, + ).where(User.username == created_by) - statement = statement.order_by(DBCollection.created_at.desc()) + statement = statement.order_by(DBCollection.created_at.desc()) - if limit and limit > 0: # pragma: no branch - statement = statement.limit(limit) + if limit and limit > 0: # pragma: no branch + statement = statement.limit(limit) - result = await session.execute(statement) + result = await session.execute(statement) - return [ - Collection.from_db_collection(collection, node_count or 0) - for collection, node_count in result.unique().all() - ] + return [ + Collection.from_db_collection(collection, node_count or 0) + for collection, node_count in result.unique().all() + ] diff --git a/datajunction-server/datajunction_server/api/graphql/queries/dag.py b/datajunction-server/datajunction_server/api/graphql/queries/dag.py index a609b3dc4..709ef15b3 100644 --- a/datajunction-server/datajunction_server/api/graphql/queries/dag.py +++ b/datajunction-server/datajunction_server/api/graphql/queries/dag.py @@ -10,14 +10,14 @@ from datajunction_server.database.node import Node from datajunction_server.api.graphql.resolvers.nodes import load_node_options from datajunction_server.api.graphql.scalars.node import DimensionAttribute -from datajunction_server.api.graphql.utils import extract_fields +from datajunction_server.api.graphql.utils import extract_fields, resolver_session from datajunction_server.sql.dag import ( get_common_dimensions, get_downstream_nodes, get_upstream_nodes, ) from datajunction_server.models.node_type import NodeType -from datajunction_server.utils import SEPARATOR, session_context +from datajunction_server.utils import SEPARATOR from sqlalchemy.orm import joinedload @@ -34,12 +34,7 @@ async def common_dimensions( """ 
    Return a list of common dimensions for a set of nodes.
    """
-    request = info.context["request"]
-
-    # Use a fresh independent session for all operations to avoid concurrent
-    # session conflicts when this resolver runs alongside other resolvers
-    # (e.g., findNode + commonDimensions in the same GraphQL query)
-    async with session_context(request) as dims_session:
+    async with resolver_session(info) as dims_session:
         # Load nodes in the independent session
         dims_nodes = await Node.find_by(
             dims_session,
@@ -116,25 +111,24 @@ async def downstream_nodes(
     Note: Unlike upstreams, downstreams uses per-node queries because the fanout
     threshold check and BFS fallback work better with single nodes.
     """
-    session = info.context["session"]
-
-    # Build load options based on requested GraphQL fields
-    fields = extract_fields(info)
-    options = load_node_options(fields)
-
-    all_downstreams: dict[int, Node] = {}
-    for node_name in node_names:
-        downstreams = await get_downstream_nodes(
-            session,
-            node_name=node_name,
-            node_type=node_type,
-            include_deactivated=include_deactivated,
-            options=options,
-        )
-        for node in downstreams:
-            if node.id not in all_downstreams:  # pragma: no cover
-                all_downstreams[node.id] = node
-    return list(all_downstreams.values())
+    async with resolver_session(info) as session:
+        # Build load options based on requested GraphQL fields
+        fields = extract_fields(info)
+        options = load_node_options(fields)
+
+        all_downstreams: dict[int, Node] = {}
+        for node_name in node_names:
+            downstreams = await get_downstream_nodes(
+                session,
+                node_name=node_name,
+                node_type=node_type,
+                include_deactivated=include_deactivated,
+                options=options,
+            )
+            for node in downstreams:
+                if node.id not in all_downstreams:  # pragma: no cover
+                    all_downstreams[node.id] = node
+        return list(all_downstreams.values())
 
 
 async def upstream_nodes(
@@ -163,16 +157,15 @@ async def upstream_nodes(
     Return a list of upstream nodes for one or more nodes.
     Results are deduplicated by node ID.
""" - session = info.context["session"] - - # Build load options based on requested GraphQL fields - fields = extract_fields(info) - options = load_node_options(fields) - - return await get_upstream_nodes( # type: ignore - session, - node_name=node_names, - node_type=node_type, - include_deactivated=include_deactivated, - options=options, - ) + async with resolver_session(info) as session: + # Build load options based on requested GraphQL fields + fields = extract_fields(info) + options = load_node_options(fields) + + return await get_upstream_nodes( # type: ignore + session, + node_name=node_names, + node_type=node_type, + include_deactivated=include_deactivated, + options=options, + ) diff --git a/datajunction-server/datajunction_server/api/graphql/queries/engines.py b/datajunction-server/datajunction_server/api/graphql/queries/engines.py index a673b9726..a0620581b 100644 --- a/datajunction-server/datajunction_server/api/graphql/queries/engines.py +++ b/datajunction-server/datajunction_server/api/graphql/queries/engines.py @@ -9,6 +9,7 @@ from datajunction_server.models.dialect import DialectRegistry from datajunction_server.api.graphql.scalars.catalog_engine import Engine, DialectInfo +from datajunction_server.api.graphql.utils import resolver_session from datajunction_server.database.engine import Engine as DBEngine @@ -19,11 +20,11 @@ async def list_engines( """ List all available engines """ - session = info.context["session"] # type: ignore - return [ - Engine.from_pydantic(engine) # type: ignore #pylint: disable=E1101 - for engine in (await session.execute(select(DBEngine))).scalars().all() - ] + async with resolver_session(info) as session: + return [ + Engine.from_pydantic(engine) # type: ignore #pylint: disable=E1101 + for engine in (await session.execute(select(DBEngine))).scalars().all() + ] async def list_dialects( diff --git a/datajunction-server/datajunction_server/api/graphql/queries/namespaces.py b/datajunction-server/datajunction_server/api/graphql/queries/namespaces.py index 0c7839247..b960ff4b3 100644 --- a/datajunction-server/datajunction_server/api/graphql/queries/namespaces.py +++ b/datajunction-server/datajunction_server/api/graphql/queries/namespaces.py @@ -12,6 +12,7 @@ GitRootConfig, Namespace, ) +from datajunction_server.api.graphql.utils import resolver_session from datajunction_server.database.namespace import NodeNamespace from datajunction_server.database.node import Node @@ -27,46 +28,50 @@ async def list_namespaces( For branch namespaces, git is a GitBranchConfig with the root config embedded. For non-git namespaces, git is null. 
""" - session = info.context["session"] # type: ignore - statement = ( - select(NodeNamespace, func.count(Node.id).label("num_nodes")) - .join(Node, onclause=NodeNamespace.namespace == Node.namespace, isouter=True) - .where(NodeNamespace.deactivated_at.is_(None)) - .group_by(NodeNamespace.namespace) - ) - result = await session.execute(statement) - rows = result.all() + async with resolver_session(info) as session: + statement = ( + select(NodeNamespace, func.count(Node.id).label("num_nodes")) + .join( + Node, + onclause=NodeNamespace.namespace == Node.namespace, + isouter=True, + ) + .where(NodeNamespace.deactivated_at.is_(None)) + .group_by(NodeNamespace.namespace) + ) + result = await session.execute(statement) + rows = result.all() - # Build a map so branch namespaces can resolve their root config inline - ns_map = {ns.namespace: ns for ns, _ in rows} + # Build a map so branch namespaces can resolve their root config inline + ns_map = {ns.namespace: ns for ns, _ in rows} - namespaces = [] - for ns, num_nodes in rows: - git: Optional[Union[GitRootConfig, GitBranchConfig]] = None - if ns.github_repo_path: - git = GitRootConfig( # type: ignore - repo=ns.github_repo_path, - path=ns.git_path, - default_branch=ns.default_branch, - ) - elif ns.git_branch and ns.parent_namespace: - parent = ns_map.get(ns.parent_namespace) - if parent and parent.github_repo_path: # pragma: no branch - git = GitBranchConfig( # type: ignore - branch=ns.git_branch, - git_only=ns.git_only, - parent_namespace=ns.parent_namespace, - root=GitRootConfig( # type: ignore - repo=parent.github_repo_path, - path=parent.git_path, - default_branch=parent.default_branch, - ), + namespaces = [] + for ns, num_nodes in rows: + git: Optional[Union[GitRootConfig, GitBranchConfig]] = None + if ns.github_repo_path: + git = GitRootConfig( # type: ignore + repo=ns.github_repo_path, + path=ns.git_path, + default_branch=ns.default_branch, ) - namespaces.append( - Namespace( # type: ignore - namespace=ns.namespace, - num_nodes=num_nodes or 0, - git=git, - ), - ) - return namespaces + elif ns.git_branch and ns.parent_namespace: + parent = ns_map.get(ns.parent_namespace) + if parent and parent.github_repo_path: # pragma: no branch + git = GitBranchConfig( # type: ignore + branch=ns.git_branch, + git_only=ns.git_only, + parent_namespace=ns.parent_namespace, + root=GitRootConfig( # type: ignore + repo=parent.github_repo_path, + path=parent.git_path, + default_branch=parent.default_branch, + ), + ) + namespaces.append( + Namespace( # type: ignore + namespace=ns.namespace, + num_nodes=num_nodes or 0, + git=git, + ), + ) + return namespaces diff --git a/datajunction-server/datajunction_server/api/graphql/queries/sql.py b/datajunction-server/datajunction_server/api/graphql/queries/sql.py index 11180f3a7..78f1ff6b3 100644 --- a/datajunction-server/datajunction_server/api/graphql/queries/sql.py +++ b/datajunction-server/datajunction_server/api/graphql/queries/sql.py @@ -16,6 +16,7 @@ resolve_metrics_and_dimensions, find_nodes_by, ) +from datajunction_server.api.graphql.utils import resolver_session from datajunction_server.utils import SEPARATOR from datajunction_server.sql.parsing.backends.antlr4 import parse, ast from datajunction_server.models.cube_materialization import Aggregability @@ -66,32 +67,32 @@ async def measures_sql( """ Get measures SQL for a set of metrics with dimensions and filters """ - session = info.context["session"] - metrics, dimensions = await resolve_metrics_and_dimensions(session, cube) - query_cache_manager = 
QueryCacheManager( - cache=info.context["cache"], - query_type=QueryBuildType.MEASURES, - ) - queries = await query_cache_manager.get_or_load( - info.context["background_tasks"], - info.context["request"], - QueryRequestParams( - nodes=metrics, - dimensions=dimensions, - filters=cube.filters, - engine_name=engine.name if engine else None, - engine_version=engine.version if engine else None, - orderby=cube.orderby, - query_params=query_parameters, - include_all_columns=include_all_columns, - preaggregate=preaggregate, - use_materialized=use_materialized, - ), - ) - return [ - await GeneratedSQL.from_pydantic(info, measures_query) - for measures_query in queries - ] + async with resolver_session(info) as session: + metrics, dimensions = await resolve_metrics_and_dimensions(session, cube) + query_cache_manager = QueryCacheManager( + cache=info.context["cache"], + query_type=QueryBuildType.MEASURES, + ) + queries = await query_cache_manager.get_or_load( + info.context["background_tasks"], + info.context["request"], + QueryRequestParams( + nodes=metrics, + dimensions=dimensions, + filters=cube.filters, + engine_name=engine.name if engine else None, + engine_version=engine.version if engine else None, + orderby=cube.orderby, + query_params=query_parameters, + include_all_columns=include_all_columns, + preaggregate=preaggregate, + use_materialized=use_materialized, + ), + ) + return [ + await GeneratedSQL.from_pydantic(info, measures_query) + for measures_query in queries + ] async def materialization_plan( @@ -103,89 +104,95 @@ async def materialization_plan( This constructs a `MaterializationPlan` by computing all the versioned entities (metrics, measures, dimensions, filters) required to materialize the cube. """ - session = info.context["session"] - metrics, dimensions = await resolve_metrics_and_dimensions(session, cube) - - metric_nodes = await get_metrics(session, metrics=metrics) - - # Extract dimension references from filters - filter_refs = [ - filter_dim.identifier() - for filter_expr in cube.filters or [] - for filter_dim in parse(f"SELECT 1 WHERE {filter_expr}").find_all(ast.Column) - ] - - # Resolve nodes for dimensions and filter references - all_ref_nodes = {dim.rsplit(SEPARATOR, 1)[0] for dim in dimensions + filter_refs} - nodes_lookup = { - node.name: node for node in await find_nodes_by(info, list(all_ref_nodes)) - } - - # Group the metrics by upstream node - grouped_metrics = await group_metrics_by_parent(session, metric_nodes) - units = [] - for upstream_node, metrics_in_group in grouped_metrics.items(): - # Ensure frozen measures are loaded - for metric in metrics_in_group: - await session.refresh(metric, ["frozen_measures"]) - - # Deduplicate and collect all frozen measures - measures = { - fm.name: MetricComponent( # type: ignore - name=fm.name, - expression=fm.expression, - rule=fm.rule, - aggregation=fm.aggregation, + async with resolver_session(info) as session: + metrics, dimensions = await resolve_metrics_and_dimensions(session, cube) + + metric_nodes = await get_metrics(session, metrics=metrics) + + # Extract dimension references from filters + filter_refs = [ + filter_dim.identifier() + for filter_expr in cube.filters or [] + for filter_dim in parse(f"SELECT 1 WHERE {filter_expr}").find_all( + ast.Column, ) - for metric in metrics_in_group - for fm in metric.frozen_measures - }.values() - - # Determine grain dimensions based on aggregability - limited_agg_measures = [ - m - for m in measures - if m.rule.type == Aggregability.LIMITED # type: ignore - ] - 
non_agg_measures = [ - m - for m in measures - if m.rule.type == Aggregability.NONE # type: ignore ] - if non_agg_measures: - grain_dimensions = [] # pragma: no cover - else: - grain_from_rules = [ - dim - for m in limited_agg_measures - for dim in m.rule.level # type: ignore + # Resolve nodes for dimensions and filter references + all_ref_nodes = { + dim.rsplit(SEPARATOR, 1)[0] for dim in dimensions + filter_refs + } + nodes_lookup = { + node.name: node for node in await find_nodes_by(info, list(all_ref_nodes)) + } + + # Group the metrics by upstream node + grouped_metrics = await group_metrics_by_parent(session, metric_nodes) + units = [] + for upstream_node, metrics_in_group in grouped_metrics.items(): + # Ensure frozen measures are loaded + for metric in metrics_in_group: + await session.refresh(metric, ["frozen_measures"]) + + # Deduplicate and collect all frozen measures + measures = { + fm.name: MetricComponent( # type: ignore + name=fm.name, + expression=fm.expression, + rule=fm.rule, + aggregation=fm.aggregation, + ) + for metric in metrics_in_group + for fm in metric.frozen_measures + }.values() + + # Determine grain dimensions based on aggregability + limited_agg_measures = [ + m + for m in measures + if m.rule.type == Aggregability.LIMITED # type: ignore ] - grain_from_dims = [ - nodes_lookup[dim.rsplit(SEPARATOR, 1)[0]] for dim in dimensions + non_agg_measures = [ + m + for m in measures + if m.rule.type == Aggregability.NONE # type: ignore ] - grain_dimensions = grain_from_rules + grain_from_dims - # Construct materialization unit - unit = MaterializationUnit( # type: ignore - upstream=VersionedRef( # type: ignore - name=upstream_node.name, - version=upstream_node.current_version, - ), - grain_dimensions=[ - VersionedRef(name=dim.name, version=dim.current_version) # type: ignore - for dim in grain_dimensions - ], - measures=list(measures), - filter_refs=[ - VersionedRef( # type: ignore - name=ref, - version=nodes_lookup[ref.rsplit(SEPARATOR, 1)[0]].current_version, - ) - for ref in filter_refs - ], - filters=cube.filters, - ) - units.append(unit) + if non_agg_measures: + grain_dimensions = [] # pragma: no cover + else: + grain_from_rules = [ + dim + for m in limited_agg_measures + for dim in m.rule.level # type: ignore + ] + grain_from_dims = [ + nodes_lookup[dim.rsplit(SEPARATOR, 1)[0]] for dim in dimensions + ] + grain_dimensions = grain_from_rules + grain_from_dims + + # Construct materialization unit + unit = MaterializationUnit( # type: ignore + upstream=VersionedRef( # type: ignore + name=upstream_node.name, + version=upstream_node.current_version, + ), + grain_dimensions=[ + VersionedRef(name=dim.name, version=dim.current_version) # type: ignore + for dim in grain_dimensions + ], + measures=list(measures), + filter_refs=[ + VersionedRef( # type: ignore + name=ref, + version=nodes_lookup[ + ref.rsplit(SEPARATOR, 1)[0] + ].current_version, + ) + for ref in filter_refs + ], + filters=cube.filters, + ) + units.append(unit) - return MaterializationPlan(units=units) # type: ignore + return MaterializationPlan(units=units) # type: ignore diff --git a/datajunction-server/datajunction_server/api/graphql/queries/tags.py b/datajunction-server/datajunction_server/api/graphql/queries/tags.py index e57ad9b7a..c73c03b0b 100644 --- a/datajunction-server/datajunction_server/api/graphql/queries/tags.py +++ b/datajunction-server/datajunction_server/api/graphql/queries/tags.py @@ -8,6 +8,7 @@ from strawberry.types import Info from datajunction_server.api.graphql.scalars.tag import Tag 
+from datajunction_server.api.graphql.utils import resolver_session
 from datajunction_server.database.tag import Tag as DBTag
 
 
@@ -30,9 +31,9 @@ async def list_tags(
     """
     Find available tags by the search parameters
     """
-    session = info.context["session"]  # type: ignore
-    db_tags = await DBTag.find_tags(session, tag_names, tag_types)
-    return [_to_graphql_tag(db_tag) for db_tag in db_tags]
+    async with resolver_session(info) as session:
+        db_tags = await DBTag.find_tags(session, tag_names, tag_types)
+        return [_to_graphql_tag(db_tag) for db_tag in db_tags]
 
 
 async def search_tags(
@@ -54,9 +55,9 @@ async def search_tags(
     """
     if not search or not search.strip():
         return []
-    session = info.context["session"]  # type: ignore
-    db_tags = await DBTag.search_tags(session, search.strip(), limit=limit)
-    return [_to_graphql_tag(db_tag) for db_tag in db_tags]
+    async with resolver_session(info) as session:
+        db_tags = await DBTag.search_tags(session, search.strip(), limit=limit)
+        return [_to_graphql_tag(db_tag) for db_tag in db_tags]
 
 
 async def list_tag_types(
@@ -66,5 +67,5 @@ async def list_tag_types(
     """
     List all tag types
     """
-    session = info.context["session"]  # type: ignore
-    return await DBTag.get_tag_types(session)
+    async with resolver_session(info) as session:
+        return await DBTag.get_tag_types(session)
diff --git a/datajunction-server/datajunction_server/api/graphql/resolvers/nodes.py b/datajunction-server/datajunction_server/api/graphql/resolvers/nodes.py
index 334e3cf1c..19229b897 100644
--- a/datajunction-server/datajunction_server/api/graphql/resolvers/nodes.py
+++ b/datajunction-server/datajunction_server/api/graphql/resolvers/nodes.py
@@ -14,7 +14,11 @@
 from datajunction_server.errors import DJNodeNotFound
 from datajunction_server.api.graphql.scalars.node import NodeName, NodeSortField
 from datajunction_server.api.graphql.scalars.sql import CubeDefinition
-from datajunction_server.api.graphql.utils import dedupe_append, extract_fields
+from datajunction_server.api.graphql.utils import (
+    dedupe_append,
+    extract_fields,
+    resolver_session,
+)
 from datajunction_server.database.dimensionlink import DimensionLink
 from datajunction_server.database.node import Column, ColumnAttribute, CubeRelationship
 
 
@@ -143,66 +147,66 @@ async def find_nodes_by(
     Finds nodes based on the search parameters. This function also tries to optimize
     the database query by only retrieving joined-in fields if they were requested.
     """
-    session = info.context["session"]  # type: ignore
-    fields = extract_fields(info)
-    node_fields = (
-        fields["nodes"]
-        if "nodes" in fields
-        else fields["edges"]["node"]
-        if "edges" in fields
-        else fields
-    )
-    options = load_node_options(node_fields)
-
-    # Signal to cube resolvers whether the name-only fast path is active
-    current_fields = node_fields.get("current") or {}
-    is_cube_name_only = _is_cube_name_only_request(current_fields)
-    info.context["cube_name_only"] = is_cube_name_only  # type: ignore
-
-    # When include_team is set with an ownedBy filter, expand to the user's
-    # groups so nodes owned directly by the user OR by any of their groups
-    # are returned. No-op when ownedBy is not set.
-    owned_by_list: Optional[List[str]] = None
-    if owned_by:
-        owned_by_list = [owned_by]
-        if include_team:
-            groups = await get_group_membership_service().get_user_groups(
-                session,
-                owned_by,
-            )
-            owned_by_list = list({owned_by, *groups})
-
-    result = await DBNode.find_by(
-        session,
-        names,
-        fragment,
-        node_types,
-        tags,
-        edited_by,
-        namespace,
-        limit,
-        before,
-        after,
-        order_by=order_by.column,
-        ascending=ascending,
-        options=options,
-        mode=mode,
-        owned_by=owned_by_list,
-        missing_description=missing_description,
-        missing_owner=missing_owner,
-        statuses=statuses,
-        has_materialization=has_materialization,
-        orphaned_dimension=orphaned_dimension,
-        dimensions=dimensions,
-        search=search,
-    )
+    async with resolver_session(info) as session:
+        fields = extract_fields(info)
+        node_fields = (
+            fields["nodes"]
+            if "nodes" in fields
+            else fields["edges"]["node"]
+            if "edges" in fields
+            else fields
+        )
+        options = load_node_options(node_fields)
+
+        # Signal to cube resolvers whether the name-only fast path is active
+        current_fields = node_fields.get("current") or {}
+        is_cube_name_only = _is_cube_name_only_request(current_fields)
+        info.context["cube_name_only"] = is_cube_name_only  # type: ignore
+
+        # When include_team is set with an ownedBy filter, expand to the user's
+        # groups so nodes owned directly by the user OR by any of their groups
+        # are returned. No-op when ownedBy is not set.
+        owned_by_list: Optional[List[str]] = None
+        if owned_by:
+            owned_by_list = [owned_by]
+            if include_team:
+                groups = await get_group_membership_service().get_user_groups(
+                    session,
+                    owned_by,
+                )
+                owned_by_list = list({owned_by, *groups})
+
+        result = await DBNode.find_by(
+            session,
+            names,
+            fragment,
+            node_types,
+            tags,
+            edited_by,
+            namespace,
+            limit,
+            before,
+            after,
+            order_by=order_by.column,
+            ascending=ascending,
+            options=options,
+            mode=mode,
+            owned_by=owned_by_list,
+            missing_description=missing_description,
+            missing_owner=missing_owner,
+            statuses=statuses,
+            has_materialization=has_materialization,
+            orphaned_dimension=orphaned_dimension,
+            dimensions=dimensions,
+            search=search,
+        )
 
-    # For the name-only cube path, fetch column data as raw tuples instead
-    # of ORM objects. This avoids hydrating ~20k Column instances.
-    if is_cube_name_only and result:
-        await _attach_raw_columns(session, result)
+        # For the name-only cube path, fetch column data as raw tuples instead
+        # of ORM objects. This avoids hydrating ~20k Column instances.
+        if is_cube_name_only and result:
+            await _attach_raw_columns(session, result)
 
-    return result
+        return result
 
 
 async def get_node_by_name(
diff --git a/datajunction-server/datajunction_server/api/graphql/scalars/node.py b/datajunction-server/datajunction_server/api/graphql/scalars/node.py
index 774af652b..0039de2ff 100644
--- a/datajunction-server/datajunction_server/api/graphql/scalars/node.py
+++ b/datajunction-server/datajunction_server/api/graphql/scalars/node.py
@@ -367,16 +367,18 @@ async def extracted_measures(
         """
         if root.type != NodeType.METRIC:
             return None
-        session = info.context["session"]  # type: ignore
-        extractor = MetricComponentExtractor(root.id)
-        components, derived_ast = await extractor.extract(session)
-        # The derived_expression is the combiner (how to combine merged components)
-        combiner_expr = str(derived_ast.select.projection[0])
-        return DecomposedMetric(  # type: ignore
-            components=components,
-            combiner=combiner_expr,
-            derived_query=str(derived_ast),
-        )
+        from datajunction_server.api.graphql.utils import resolver_session
+
+        async with resolver_session(info) as session:
+            extractor = MetricComponentExtractor(root.id)
+            components, derived_ast = await extractor.extract(session)
+            # The derived_expression is the combiner (how to combine merged components)
+            combiner_expr = str(derived_ast.select.projection[0])
+            return DecomposedMetric(  # type: ignore
+                components=components,
+                combiner=combiner_expr,
+                derived_query=str(derived_ast),
+            )
 
     # Only cubes will have these fields
     @strawberry.field
diff --git a/datajunction-server/datajunction_server/api/graphql/scalars/sql.py b/datajunction-server/datajunction_server/api/graphql/scalars/sql.py
index 22c0ae8ef..37b0d6515 100644
--- a/datajunction-server/datajunction_server/api/graphql/scalars/sql.py
+++ b/datajunction-server/datajunction_server/api/graphql/scalars/sql.py
@@ -84,28 +84,30 @@ async def from_pydantic(cls, info: Info, obj: GeneratedSQL_):
         Loads a strawberry GeneratedSQL from the original pydantic model.
""" from datajunction_server.api.graphql.resolvers.nodes import get_node_by_name - - fields = extract_fields(info) - return GeneratedSQL( # type: ignore - node=await get_node_by_name( - session=info.context["session"], - fields=fields.get("node"), - name=obj.node.name, - ), - sql=obj.sql, - columns=[ - ColumnMetadata( # type: ignore - name=col.name, - type=col.type, - semantic_entity=SemanticEntity(name=col.semantic_entity), # type: ignore - semantic_type=SemanticType(col.semantic_type), - ) - for col in obj.columns # type: ignore - ], - dialect=obj.dialect, - upstream_tables=obj.upstream_tables, - errors=obj.errors, - ) + from datajunction_server.api.graphql.utils import resolver_session + + async with resolver_session(info) as session: + fields = extract_fields(info) + return GeneratedSQL( # type: ignore + node=await get_node_by_name( + session=session, + fields=fields.get("node"), + name=obj.node.name, + ), + sql=obj.sql, + columns=[ + ColumnMetadata( # type: ignore + name=col.name, + type=col.type, + semantic_entity=SemanticEntity(name=col.semantic_entity), # type: ignore + semantic_type=SemanticType(col.semantic_type), + ) + for col in obj.columns # type: ignore + ], + dialect=obj.dialect, + upstream_tables=obj.upstream_tables, + errors=obj.errors, + ) @strawberry.input diff --git a/datajunction-server/datajunction_server/api/graphql/scalars/tag.py b/datajunction-server/datajunction_server/api/graphql/scalars/tag.py index 728c09ed5..b4337cd61 100644 --- a/datajunction-server/datajunction_server/api/graphql/scalars/tag.py +++ b/datajunction-server/datajunction_server/api/graphql/scalars/tag.py @@ -5,7 +5,7 @@ from datajunction_server.api.graphql.resolvers.tags import get_nodes_by_tag from datajunction_server.api.graphql.scalars.node import Node, TagBase -from datajunction_server.api.graphql.utils import extract_fields +from datajunction_server.api.graphql.utils import extract_fields, resolver_session @strawberry.type @@ -19,9 +19,10 @@ async def nodes(self, info: Info) -> list[Node]: """ Lazy load the nodes with this tag. """ - fields = extract_fields(info) - return await get_nodes_by_tag( # type: ignore - session=info.context["session"], - fields=fields, - tag_name=self.name, - ) + async with resolver_session(info) as session: + fields = extract_fields(info) + return await get_nodes_by_tag( # type: ignore + session=session, + fields=fields, + tag_name=self.name, + ) diff --git a/datajunction-server/datajunction_server/api/graphql/utils.py b/datajunction-server/datajunction_server/api/graphql/utils.py index 898510193..6a4023039 100644 --- a/datajunction-server/datajunction_server/api/graphql/utils.py +++ b/datajunction-server/datajunction_server/api/graphql/utils.py @@ -1,13 +1,37 @@ """Utils for handling GraphQL queries.""" import re -from typing import Any, Dict, TypeVar +from contextlib import asynccontextmanager +from typing import Any, AsyncIterator, Dict, TypeVar + +from sqlalchemy.ext.asyncio import AsyncSession +from strawberry.types import Info + +from datajunction_server.utils import session_context CURSOR_SEPARATOR = "-" T = TypeVar("T") +@asynccontextmanager +async def resolver_session(info: Info) -> AsyncIterator[AsyncSession]: + """Create an independent database session for a resolver. + + Each GraphQL resolver must use its own session because strawberry resolves + top-level fields concurrently. Sharing a single AsyncSession across + concurrent resolvers causes ``InvalidCachedStatementError`` / + ``isce`` errors. 
+
+    Usage::
+
+        async with resolver_session(info) as session:
+            ...
+    """
+    async with session_context(info.context["request"]) as session:
+        yield session
+
+
 def convert_camel_case(name):
     """
     Convert from camel case to snake case
diff --git a/datajunction-server/tests/api/graphql/test_main.py b/datajunction-server/tests/api/graphql/test_main.py
index f2bbb1e95..0b38d23d6 100644
--- a/datajunction-server/tests/api/graphql/test_main.py
+++ b/datajunction-server/tests/api/graphql/test_main.py
@@ -36,8 +36,9 @@ async def test_get_context_without_test_session(mock_get_settings, mock_create_l
     assert hasattr(mock_request.state, "test_session")
     assert mock_request.state.test_session == mock_db_session
 
-    # Verify context contains expected keys
-    assert context["session"] == mock_db_session
+    # Verify context contains expected keys (no shared "session" — resolvers
+    # create their own via resolver_session())
+    assert "session" not in context
     assert context["node_loader"] == mock_loader
     assert context["settings"] == mock_settings
     assert context["request"] == mock_request
@@ -78,8 +79,8 @@ async def test_get_context_with_existing_test_session(
     # Verify test_session was NOT overwritten - should still be the existing one
     assert mock_request.state.test_session == existing_test_session
 
-    # Verify context still uses the db_session passed in (not test_session)
-    assert context["session"] == mock_db_session
+    # Verify context keys (no shared "session")
+    assert "session" not in context
     assert context["node_loader"] == mock_loader
     assert context["settings"] == mock_settings
     assert context["request"] == mock_request
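
Reviewer note, not part of the patch: a minimal sketch of the resolver pattern this diff establishes, for anyone adding new GraphQL fields. `list_tag_names` is a hypothetical resolver invented for illustration, and it assumes the `Tag` model exposes a `name` column; `resolver_session`, `Info`, and `select` all appear in the diff above.

    from sqlalchemy import select
    from strawberry.types import Info

    from datajunction_server.api.graphql.utils import resolver_session
    from datajunction_server.database.tag import Tag as DBTag


    async def list_tag_names(info: Info) -> list[str]:
        """Hypothetical resolver: it opens its own session via resolver_session()
        instead of reading a shared one from info.context, so sibling fields
        resolved concurrently never contend for the same AsyncSession."""
        async with resolver_session(info) as session:
            result = await session.execute(select(DBTag.name))
            return list(result.scalars().all())

Since strawberry resolves sibling async fields concurrently, two resolvers written this way each get an independent session from session_context(), which is what removes the shared-session errors described in the resolver_session docstring.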