From 62f90485584e1d3c831e03f64739d6d7cbb390a4 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Sun, 19 Apr 2026 02:37:20 -0700 Subject: [PATCH] Create a fresh session per resolver to avoid concurrent session conflicts --- .../datajunction_server/api/graphql/main.py | 6 +- .../datajunction_server/api/graphql/utils.py | 27 ++++++++- .../tests/api/graphql/test_main.py | 58 ++----------------- 3 files changed, 34 insertions(+), 57 deletions(-) diff --git a/datajunction-server/datajunction_server/api/graphql/main.py b/datajunction-server/datajunction_server/api/graphql/main.py index 158b08849..aa5c0ca32 100644 --- a/datajunction-server/datajunction_server/api/graphql/main.py +++ b/datajunction-server/datajunction_server/api/graphql/main.py @@ -115,8 +115,10 @@ async def get_context( """ Provides the context for graphql requests """ - # Attach test session to request.state so DataLoaders can use it - # This ensures DataLoaders use the same test session in tests + # Attach test session to request.state so DataLoaders (which use + # session_context()) can reuse it in tests. resolver_session() does NOT + # use this — it checks dependency_overrides instead, so it always + # creates independent sessions in production. if not hasattr(request.state, "test_session"): request.state.test_session = db_session diff --git a/datajunction-server/datajunction_server/api/graphql/utils.py b/datajunction-server/datajunction_server/api/graphql/utils.py index 6a4023039..f7642fdce 100644 --- a/datajunction-server/datajunction_server/api/graphql/utils.py +++ b/datajunction-server/datajunction_server/api/graphql/utils.py @@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from strawberry.types import Info -from datajunction_server.utils import session_context +from datajunction_server.utils import get_session CURSOR_SEPARATOR = "-" @@ -23,13 +23,36 @@ async def resolver_session(info: Info) -> AsyncIterator[AsyncSession]: concurrent resolvers causes ``InvalidCachedStatementError`` / ``isce`` errors. + In tests, ``request.state.test_session`` holds a shared session that + all resolvers must use (so they see the same transaction's data). + Tests run resolvers sequentially so the shared session is safe. + + In production, ``test_session`` is not set, so we create a fresh + session per resolver via ``get_session()``. + Usage:: async with resolver_session(info) as session: ... """ - async with session_context(info.context["request"]) as session: + request = info.context["request"] + + # In tests, get_session is overridden via dependency_overrides to return a + # shared test session. We call the override directly so resolvers see the + # same transaction. The override is a sync function returning the session. + app = request.app + override = app.dependency_overrides.get(get_session) + if override is not None: + yield override() + return + + # Production: create a genuinely independent session per resolver. + gen = get_session(request, session_label="graphql_resolver") + session = await gen.__anext__() + try: yield session + finally: + await gen.aclose() # type: ignore def convert_camel_case(name): diff --git a/datajunction-server/tests/api/graphql/test_main.py b/datajunction-server/tests/api/graphql/test_main.py index 0b38d23d6..d22262d4e 100644 --- a/datajunction-server/tests/api/graphql/test_main.py +++ b/datajunction-server/tests/api/graphql/test_main.py @@ -9,11 +9,10 @@ @pytest.mark.asyncio @patch("datajunction_server.api.graphql.main.create_node_by_name_loader") @patch("datajunction_server.api.graphql.main.get_settings") -async def test_get_context_without_test_session(mock_get_settings, mock_create_loader): - """Test get_context when request.state doesn't have test_session attribute""" - # Create a mock request without test_session +async def test_get_context(mock_get_settings, mock_create_loader): + """Test get_context returns expected keys without a shared session.""" mock_request = MagicMock() - mock_request.state = MagicMock(spec=[]) # Empty spec means no attributes initially + mock_request.state = MagicMock(spec=[]) mock_background_tasks = MagicMock() mock_db_session = AsyncMock() @@ -24,7 +23,6 @@ async def test_get_context_without_test_session(mock_get_settings, mock_create_l mock_get_settings.return_value = mock_settings mock_create_loader.return_value = mock_loader - # Call get_context context = await get_context( request=mock_request, background_tasks=mock_background_tasks, @@ -32,54 +30,8 @@ async def test_get_context_without_test_session(mock_get_settings, mock_create_l cache=mock_cache, ) - # Verify test_session was set on request.state - assert hasattr(mock_request.state, "test_session") - assert mock_request.state.test_session == mock_db_session - - # Verify context contains expected keys (no shared "session" — resolvers - # create their own via resolver_session()) - assert "session" not in context - assert context["node_loader"] == mock_loader - assert context["settings"] == mock_settings - assert context["request"] == mock_request - assert context["background_tasks"] == mock_background_tasks - assert context["cache"] == mock_cache - - -@pytest.mark.asyncio -@patch("datajunction_server.api.graphql.main.create_node_by_name_loader") -@patch("datajunction_server.api.graphql.main.get_settings") -async def test_get_context_with_existing_test_session( - mock_get_settings, - mock_create_loader, -): - """Test get_context when request.state already has test_session attribute""" - # Create a mock request WITH test_session already set - mock_request = MagicMock() - existing_test_session = AsyncMock() - mock_request.state.test_session = existing_test_session - - mock_background_tasks = MagicMock() - mock_db_session = AsyncMock() - mock_cache = MagicMock() - mock_settings = MagicMock() - mock_loader = MagicMock() - - mock_get_settings.return_value = mock_settings - mock_create_loader.return_value = mock_loader - - # Call get_context - context = await get_context( - request=mock_request, - background_tasks=mock_background_tasks, - db_session=mock_db_session, - cache=mock_cache, - ) - - # Verify test_session was NOT overwritten - should still be the existing one - assert mock_request.state.test_session == existing_test_session - - # Verify context keys (no shared "session") + # No shared "session" in context — resolvers create their own via + # resolver_session() to avoid concurrent-session crashes. assert "session" not in context assert context["node_loader"] == mock_loader assert context["settings"] == mock_settings