Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ async def get_context(
request.state.test_session = db_session

return {
"session": db_session, # Keep for backward compatibility with existing code
"node_loader": create_node_by_name_loader(request),
"collection_nodes_loader": create_collection_nodes_loader(request),
"git_info_loader": create_git_info_loader(request),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from strawberry.types import Info

from datajunction_server.api.graphql.scalars.catalog_engine import Catalog
from datajunction_server.api.graphql.utils import resolver_session
from datajunction_server.database.catalog import Catalog as DBCatalog


Expand All @@ -18,8 +19,8 @@ async def list_catalogs(
"""
List all available catalogs
"""
session = info.context["session"] # type: ignore
return [
Catalog.from_pydantic(catalog) # type: ignore
for catalog in (await session.execute(select(DBCatalog))).scalars().all()
]
async with resolver_session(info) as session:
return [
Catalog.from_pydantic(catalog) # type: ignore
for catalog in (await session.execute(select(DBCatalog))).scalars().all()
]
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from strawberry.types import Info

from datajunction_server.api.graphql.scalars.collection import Collection
from datajunction_server.api.graphql.utils import resolver_session
from datajunction_server.database.collection import Collection as DBCollection
from datajunction_server.database.collection import CollectionNodes
from datajunction_server.database.user import User
Expand Down Expand Up @@ -39,51 +40,50 @@ async def list_collections(
"""
List collections, optionally filtered by fragment, creator, or limit.
"""
session = info.context["session"]

# Subquery to count nodes per collection
node_count_subquery = (
select(
CollectionNodes.collection_id,
func.count(CollectionNodes.node_id).label("node_count"),
async with resolver_session(info) as session:
# Subquery to count nodes per collection
node_count_subquery = (
select(
CollectionNodes.collection_id,
func.count(CollectionNodes.node_id).label("node_count"),
)
.group_by(CollectionNodes.collection_id)
.subquery()
)
.group_by(CollectionNodes.collection_id)
.subquery()
)

statement = (
select(DBCollection, node_count_subquery.c.node_count)
.outerjoin(
node_count_subquery,
DBCollection.id == node_count_subquery.c.collection_id,
)
.where(is_(DBCollection.deactivated_at, None))
.options(
joinedload(DBCollection.created_by), # Eager load creator
statement = (
select(DBCollection, node_count_subquery.c.node_count)
.outerjoin(
node_count_subquery,
DBCollection.id == node_count_subquery.c.collection_id,
)
.where(is_(DBCollection.deactivated_at, None))
.options(
joinedload(DBCollection.created_by), # Eager load creator
)
)
)

# Filter by fragment (search in name or description)
if fragment:
statement = statement.where(
(DBCollection.name.ilike(f"%{fragment}%"))
| (DBCollection.description.ilike(f"%{fragment}%")),
)
# Filter by fragment (search in name or description)
if fragment:
statement = statement.where(
(DBCollection.name.ilike(f"%{fragment}%"))
| (DBCollection.description.ilike(f"%{fragment}%")),
)

if created_by:
statement = statement.join(
User,
DBCollection.created_by_id == User.id,
).where(User.username == created_by)
if created_by:
statement = statement.join(
User,
DBCollection.created_by_id == User.id,
).where(User.username == created_by)

statement = statement.order_by(DBCollection.created_at.desc())
statement = statement.order_by(DBCollection.created_at.desc())

if limit and limit > 0: # pragma: no branch
statement = statement.limit(limit)
if limit and limit > 0: # pragma: no branch
statement = statement.limit(limit)

result = await session.execute(statement)
result = await session.execute(statement)

return [
Collection.from_db_collection(collection, node_count or 0)
for collection, node_count in result.unique().all()
]
return [
Collection.from_db_collection(collection, node_count or 0)
for collection, node_count in result.unique().all()
]
73 changes: 33 additions & 40 deletions datajunction-server/datajunction_server/api/graphql/queries/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
from datajunction_server.database.node import Node
from datajunction_server.api.graphql.resolvers.nodes import load_node_options
from datajunction_server.api.graphql.scalars.node import DimensionAttribute
from datajunction_server.api.graphql.utils import extract_fields
from datajunction_server.api.graphql.utils import extract_fields, resolver_session
from datajunction_server.sql.dag import (
get_common_dimensions,
get_downstream_nodes,
get_upstream_nodes,
)
from datajunction_server.models.node_type import NodeType
from datajunction_server.utils import SEPARATOR, session_context
from datajunction_server.utils import SEPARATOR
from sqlalchemy.orm import joinedload


Expand All @@ -34,12 +34,7 @@ async def common_dimensions(
"""
Return a list of common dimensions for a set of nodes.
"""
request = info.context["request"]

# Use a fresh independent session for all operations to avoid concurrent
# session conflicts when this resolver runs alongside other resolvers
# (e.g., findNode + commonDimensions in the same GraphQL query)
async with session_context(request) as dims_session:
async with resolver_session(info) as dims_session:
# Load nodes in the independent session
dims_nodes = await Node.find_by(
dims_session,
Expand Down Expand Up @@ -116,25 +111,24 @@ async def downstream_nodes(
Note: Unlike upstreams, downstreams uses per-node queries because the
fanout threshold check and BFS fallback work better with single nodes.
"""
session = info.context["session"]

# Build load options based on requested GraphQL fields
fields = extract_fields(info)
options = load_node_options(fields)

all_downstreams: dict[int, Node] = {}
for node_name in node_names:
downstreams = await get_downstream_nodes(
session,
node_name=node_name,
node_type=node_type,
include_deactivated=include_deactivated,
options=options,
)
for node in downstreams:
if node.id not in all_downstreams: # pragma: no cover
all_downstreams[node.id] = node
return list(all_downstreams.values())
async with resolver_session(info) as session:
# Build load options based on requested GraphQL fields
fields = extract_fields(info)
options = load_node_options(fields)

all_downstreams: dict[int, Node] = {}
for node_name in node_names:
downstreams = await get_downstream_nodes(
session,
node_name=node_name,
node_type=node_type,
include_deactivated=include_deactivated,
options=options,
)
for node in downstreams:
if node.id not in all_downstreams: # pragma: no cover
all_downstreams[node.id] = node
return list(all_downstreams.values())


async def upstream_nodes(
Expand Down Expand Up @@ -163,16 +157,15 @@ async def upstream_nodes(
Return a list of upstream nodes for one or more nodes.
Results are deduplicated by node ID.
"""
session = info.context["session"]

# Build load options based on requested GraphQL fields
fields = extract_fields(info)
options = load_node_options(fields)

return await get_upstream_nodes( # type: ignore
session,
node_name=node_names,
node_type=node_type,
include_deactivated=include_deactivated,
options=options,
)
async with resolver_session(info) as session:
# Build load options based on requested GraphQL fields
fields = extract_fields(info)
options = load_node_options(fields)

return await get_upstream_nodes( # type: ignore
session,
node_name=node_names,
node_type=node_type,
include_deactivated=include_deactivated,
options=options,
)
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from datajunction_server.models.dialect import DialectRegistry
from datajunction_server.api.graphql.scalars.catalog_engine import Engine, DialectInfo
from datajunction_server.api.graphql.utils import resolver_session
from datajunction_server.database.engine import Engine as DBEngine


Expand All @@ -19,11 +20,11 @@ async def list_engines(
"""
List all available engines
"""
session = info.context["session"] # type: ignore
return [
Engine.from_pydantic(engine) # type: ignore #pylint: disable=E1101
for engine in (await session.execute(select(DBEngine))).scalars().all()
]
async with resolver_session(info) as session:
return [
Engine.from_pydantic(engine) # type: ignore #pylint: disable=E1101
for engine in (await session.execute(select(DBEngine))).scalars().all()
]


async def list_dialects(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
GitRootConfig,
Namespace,
)
from datajunction_server.api.graphql.utils import resolver_session
from datajunction_server.database.namespace import NodeNamespace
from datajunction_server.database.node import Node

Expand All @@ -27,46 +28,50 @@ async def list_namespaces(
For branch namespaces, git is a GitBranchConfig with the root config embedded.
For non-git namespaces, git is null.
"""
session = info.context["session"] # type: ignore
statement = (
select(NodeNamespace, func.count(Node.id).label("num_nodes"))
.join(Node, onclause=NodeNamespace.namespace == Node.namespace, isouter=True)
.where(NodeNamespace.deactivated_at.is_(None))
.group_by(NodeNamespace.namespace)
)
result = await session.execute(statement)
rows = result.all()
async with resolver_session(info) as session:
statement = (
select(NodeNamespace, func.count(Node.id).label("num_nodes"))
.join(
Node,
onclause=NodeNamespace.namespace == Node.namespace,
isouter=True,
)
.where(NodeNamespace.deactivated_at.is_(None))
.group_by(NodeNamespace.namespace)
)
result = await session.execute(statement)
rows = result.all()

# Build a map so branch namespaces can resolve their root config inline
ns_map = {ns.namespace: ns for ns, _ in rows}
# Build a map so branch namespaces can resolve their root config inline
ns_map = {ns.namespace: ns for ns, _ in rows}

namespaces = []
for ns, num_nodes in rows:
git: Optional[Union[GitRootConfig, GitBranchConfig]] = None
if ns.github_repo_path:
git = GitRootConfig( # type: ignore
repo=ns.github_repo_path,
path=ns.git_path,
default_branch=ns.default_branch,
)
elif ns.git_branch and ns.parent_namespace:
parent = ns_map.get(ns.parent_namespace)
if parent and parent.github_repo_path: # pragma: no branch
git = GitBranchConfig( # type: ignore
branch=ns.git_branch,
git_only=ns.git_only,
parent_namespace=ns.parent_namespace,
root=GitRootConfig( # type: ignore
repo=parent.github_repo_path,
path=parent.git_path,
default_branch=parent.default_branch,
),
namespaces = []
for ns, num_nodes in rows:
git: Optional[Union[GitRootConfig, GitBranchConfig]] = None
if ns.github_repo_path:
git = GitRootConfig( # type: ignore
repo=ns.github_repo_path,
path=ns.git_path,
default_branch=ns.default_branch,
)
namespaces.append(
Namespace( # type: ignore
namespace=ns.namespace,
num_nodes=num_nodes or 0,
git=git,
),
)
return namespaces
elif ns.git_branch and ns.parent_namespace:
parent = ns_map.get(ns.parent_namespace)
if parent and parent.github_repo_path: # pragma: no branch
git = GitBranchConfig( # type: ignore
branch=ns.git_branch,
git_only=ns.git_only,
parent_namespace=ns.parent_namespace,
root=GitRootConfig( # type: ignore
repo=parent.github_repo_path,
path=parent.git_path,
default_branch=parent.default_branch,
),
)
namespaces.append(
Namespace( # type: ignore
namespace=ns.namespace,
num_nodes=num_nodes or 0,
git=git,
),
)
return namespaces
Loading
Loading