Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions datajunction-server/datajunction_server/api/graphql/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,10 @@ async def get_context(
"""
Provides the context for graphql requests
"""
# Attach test session to request.state so DataLoaders can use it
# This ensures DataLoaders use the same test session in tests
# Attach test session to request.state so DataLoaders (which use
# session_context()) can reuse it in tests. resolver_session() does NOT
# use this — it checks dependency_overrides instead, so it always
# creates independent sessions in production.
if not hasattr(request.state, "test_session"):
request.state.test_session = db_session

Expand Down
27 changes: 25 additions & 2 deletions datajunction-server/datajunction_server/api/graphql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sqlalchemy.ext.asyncio import AsyncSession
from strawberry.types import Info

from datajunction_server.utils import session_context
from datajunction_server.utils import get_session

CURSOR_SEPARATOR = "-"

Expand All @@ -23,13 +23,36 @@ async def resolver_session(info: Info) -> AsyncIterator[AsyncSession]:
concurrent resolvers causes ``InvalidCachedStatementError`` /
``isce`` errors.

In tests, ``request.state.test_session`` holds a shared session that
all resolvers must use (so they see the same transaction's data).
Tests run resolvers sequentially so the shared session is safe.

In production, ``test_session`` is not set, so we create a fresh
session per resolver via ``get_session()``.

Usage::

async with resolver_session(info) as session:
...
"""
async with session_context(info.context["request"]) as session:
request = info.context["request"]

# In tests, get_session is overridden via dependency_overrides to return a
# shared test session. We call the override directly so resolvers see the
# same transaction. The override is a sync function returning the session.
app = request.app
override = app.dependency_overrides.get(get_session)
if override is not None:
yield override()
return

# Production: create a genuinely independent session per resolver.
gen = get_session(request, session_label="graphql_resolver")
session = await gen.__anext__()
try:
yield session
finally:
await gen.aclose() # type: ignore


def convert_camel_case(name):
Expand Down
58 changes: 5 additions & 53 deletions datajunction-server/tests/api/graphql/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
@pytest.mark.asyncio
@patch("datajunction_server.api.graphql.main.create_node_by_name_loader")
@patch("datajunction_server.api.graphql.main.get_settings")
async def test_get_context_without_test_session(mock_get_settings, mock_create_loader):
"""Test get_context when request.state doesn't have test_session attribute"""
# Create a mock request without test_session
async def test_get_context(mock_get_settings, mock_create_loader):
"""Test get_context returns expected keys without a shared session."""
mock_request = MagicMock()
mock_request.state = MagicMock(spec=[]) # Empty spec means no attributes initially
mock_request.state = MagicMock(spec=[])

mock_background_tasks = MagicMock()
mock_db_session = AsyncMock()
Expand All @@ -24,62 +23,15 @@ async def test_get_context_without_test_session(mock_get_settings, mock_create_l
mock_get_settings.return_value = mock_settings
mock_create_loader.return_value = mock_loader

# Call get_context
context = await get_context(
request=mock_request,
background_tasks=mock_background_tasks,
db_session=mock_db_session,
cache=mock_cache,
)

# Verify test_session was set on request.state
assert hasattr(mock_request.state, "test_session")
assert mock_request.state.test_session == mock_db_session

# Verify context contains expected keys (no shared "session" — resolvers
# create their own via resolver_session())
assert "session" not in context
assert context["node_loader"] == mock_loader
assert context["settings"] == mock_settings
assert context["request"] == mock_request
assert context["background_tasks"] == mock_background_tasks
assert context["cache"] == mock_cache


@pytest.mark.asyncio
@patch("datajunction_server.api.graphql.main.create_node_by_name_loader")
@patch("datajunction_server.api.graphql.main.get_settings")
async def test_get_context_with_existing_test_session(
mock_get_settings,
mock_create_loader,
):
"""Test get_context when request.state already has test_session attribute"""
# Create a mock request WITH test_session already set
mock_request = MagicMock()
existing_test_session = AsyncMock()
mock_request.state.test_session = existing_test_session

mock_background_tasks = MagicMock()
mock_db_session = AsyncMock()
mock_cache = MagicMock()
mock_settings = MagicMock()
mock_loader = MagicMock()

mock_get_settings.return_value = mock_settings
mock_create_loader.return_value = mock_loader

# Call get_context
context = await get_context(
request=mock_request,
background_tasks=mock_background_tasks,
db_session=mock_db_session,
cache=mock_cache,
)

# Verify test_session was NOT overwritten - should still be the existing one
assert mock_request.state.test_session == existing_test_session

# Verify context keys (no shared "session")
# No shared "session" in context — resolvers create their own via
# resolver_session() to avoid concurrent-session crashes.
assert "session" not in context
assert context["node_loader"] == mock_loader
assert context["settings"] == mock_settings
Expand Down
Loading