Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions datajunction-server/datajunction_server/api/graphql/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,19 +109,21 @@ async def wrapper(*args, **kwargs):
async def get_context(
request: Request,
background_tasks: BackgroundTasks,
db_session=Depends(get_session),
cache=Depends(get_cache),
_auth=Depends(DJHTTPBearer(auto_error=False)),
):
"""
Provides the context for graphql requests
"""
# Attach test session to request.state so DataLoaders (which use
# session_context()) can reuse it in tests. resolver_session() does NOT
# use this — it checks dependency_overrides instead, so it always
# creates independent sessions in production.
if not hasattr(request.state, "test_session"): # pragma: no branch
request.state.test_session = db_session
# In tests, get_session is overridden via dependency_overrides to return a
# shared session. Attach it to request.state so DataLoaders (which use
# session_context()) reuse the same transaction. In production we must
# NOT attach a shared session — concurrent DataLoaders on one AsyncSession
# raise "concurrent operations are not permitted" — so each DataLoader
# opens its own session via session_context().
override = request.app.dependency_overrides.get(get_session)
if override is not None:
request.state.test_session = override()

return {
"node_loader": create_node_by_name_loader(request),
Expand Down
38 changes: 32 additions & 6 deletions datajunction-server/tests/api/graphql/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
@pytest.mark.asyncio
@patch("datajunction_server.api.graphql.main.create_node_by_name_loader")
@patch("datajunction_server.api.graphql.main.get_settings")
async def test_get_context(mock_get_settings, mock_create_loader):
"""Test get_context returns expected keys without a shared session."""
async def test_get_context_production(mock_get_settings, mock_create_loader):
"""Test get_context in production: no test_session attached."""
mock_request = MagicMock()
mock_request.state = MagicMock(spec=[])
mock_request.app.dependency_overrides = {}

mock_background_tasks = MagicMock()
mock_db_session = AsyncMock()
mock_cache = MagicMock()
mock_settings = MagicMock()
mock_loader = MagicMock()
Expand All @@ -26,15 +26,41 @@ async def test_get_context(mock_get_settings, mock_create_loader):
context = await get_context(
request=mock_request,
background_tasks=mock_background_tasks,
db_session=mock_db_session,
cache=mock_cache,
)

# No shared "session" in context — resolvers create their own via
# resolver_session() to avoid concurrent-session crashes.
# No shared "session" in context — DataLoaders open their own via
# session_context() to avoid concurrent-session crashes.
assert "session" not in context
assert not hasattr(mock_request.state, "test_session")
assert context["node_loader"] == mock_loader
assert context["settings"] == mock_settings
assert context["request"] == mock_request
assert context["background_tasks"] == mock_background_tasks
assert context["cache"] == mock_cache


@pytest.mark.asyncio
@patch("datajunction_server.api.graphql.main.create_node_by_name_loader")
@patch("datajunction_server.api.graphql.main.get_settings")
async def test_get_context_test_override(mock_get_settings, mock_create_loader):
"""Test get_context in tests: shared session is attached from override."""
from datajunction_server.utils import get_session

mock_request = MagicMock()
mock_request.state = MagicMock(spec=[])
mock_test_session = AsyncMock()
mock_request.app.dependency_overrides = {get_session: lambda: mock_test_session}

mock_background_tasks = MagicMock()
mock_cache = MagicMock()
mock_get_settings.return_value = MagicMock()
mock_create_loader.return_value = MagicMock()

await get_context(
request=mock_request,
background_tasks=mock_background_tasks,
cache=mock_cache,
)

assert mock_request.state.test_session is mock_test_session
Loading