diff --git a/datajunction-server/datajunction_server/api/graphql/main.py b/datajunction-server/datajunction_server/api/graphql/main.py index ed7ae720e..cce89d632 100644 --- a/datajunction-server/datajunction_server/api/graphql/main.py +++ b/datajunction-server/datajunction_server/api/graphql/main.py @@ -109,19 +109,21 @@ async def wrapper(*args, **kwargs): async def get_context( request: Request, background_tasks: BackgroundTasks, - db_session=Depends(get_session), cache=Depends(get_cache), _auth=Depends(DJHTTPBearer(auto_error=False)), ): """ Provides the context for graphql requests """ - # Attach test session to request.state so DataLoaders (which use - # session_context()) can reuse it in tests. resolver_session() does NOT - # use this — it checks dependency_overrides instead, so it always - # creates independent sessions in production. - if not hasattr(request.state, "test_session"): # pragma: no branch - request.state.test_session = db_session + # In tests, get_session is overridden via dependency_overrides to return a + # shared session. Attach it to request.state so DataLoaders (which use + # session_context()) reuse the same transaction. In production we must + # NOT attach a shared session — concurrent DataLoaders on one AsyncSession + # raise "concurrent operations are not permitted" — so each DataLoader + # opens its own session via session_context(). + override = request.app.dependency_overrides.get(get_session) + if override is not None: + request.state.test_session = override() return { "node_loader": create_node_by_name_loader(request), diff --git a/datajunction-server/tests/api/graphql/test_main.py b/datajunction-server/tests/api/graphql/test_main.py index d22262d4e..bb9d20d4a 100644 --- a/datajunction-server/tests/api/graphql/test_main.py +++ b/datajunction-server/tests/api/graphql/test_main.py @@ -9,13 +9,13 @@ @pytest.mark.asyncio @patch("datajunction_server.api.graphql.main.create_node_by_name_loader") @patch("datajunction_server.api.graphql.main.get_settings") -async def test_get_context(mock_get_settings, mock_create_loader): - """Test get_context returns expected keys without a shared session.""" +async def test_get_context_production(mock_get_settings, mock_create_loader): + """Test get_context in production: no test_session attached.""" mock_request = MagicMock() mock_request.state = MagicMock(spec=[]) + mock_request.app.dependency_overrides = {} mock_background_tasks = MagicMock() - mock_db_session = AsyncMock() mock_cache = MagicMock() mock_settings = MagicMock() mock_loader = MagicMock() @@ -26,15 +26,41 @@ async def test_get_context(mock_get_settings, mock_create_loader): context = await get_context( request=mock_request, background_tasks=mock_background_tasks, - db_session=mock_db_session, cache=mock_cache, ) - # No shared "session" in context — resolvers create their own via - # resolver_session() to avoid concurrent-session crashes. + # No shared "session" in context — DataLoaders open their own via + # session_context() to avoid concurrent-session crashes. assert "session" not in context + assert not hasattr(mock_request.state, "test_session") assert context["node_loader"] == mock_loader assert context["settings"] == mock_settings assert context["request"] == mock_request assert context["background_tasks"] == mock_background_tasks assert context["cache"] == mock_cache + + +@pytest.mark.asyncio +@patch("datajunction_server.api.graphql.main.create_node_by_name_loader") +@patch("datajunction_server.api.graphql.main.get_settings") +async def test_get_context_test_override(mock_get_settings, mock_create_loader): + """Test get_context in tests: shared session is attached from override.""" + from datajunction_server.utils import get_session + + mock_request = MagicMock() + mock_request.state = MagicMock(spec=[]) + mock_test_session = AsyncMock() + mock_request.app.dependency_overrides = {get_session: lambda: mock_test_session} + + mock_background_tasks = MagicMock() + mock_cache = MagicMock() + mock_get_settings.return_value = MagicMock() + mock_create_loader.return_value = MagicMock() + + await get_context( + request=mock_request, + background_tasks=mock_background_tasks, + cache=mock_cache, + ) + + assert mock_request.state.test_session is mock_test_session