From 689b7d7d5a6835366dc9d9e238c7639148a444f7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 16 Mar 2026 19:12:41 +0000 Subject: [PATCH 1/2] Initial plan From c26b2c4489bc13c95c8645b3db39e9f0c9c1dfb0 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 16 Mar 2026 19:25:44 +0000 Subject: [PATCH 2/2] Fix login blocked by DB unavailability using session fallback auth Co-authored-by: gkorland <753206+gkorland@users.noreply.github.com> --- api/auth/user_management.py | 42 +++++- api/routes/auth.py | 33 +++++ tests/test_session_auth_fallback.py | 198 ++++++++++++++++++++++++++++ 3 files changed, 270 insertions(+), 3 deletions(-) create mode 100644 tests/test_session_auth_fallback.py diff --git a/api/auth/user_management.py b/api/auth/user_management.py index 3b5cd00f..a9b21885 100644 --- a/api/auth/user_management.py +++ b/api/auth/user_management.py @@ -11,6 +11,10 @@ from pydantic import BaseModel from api.extensions import db + +class DatabaseUnavailableError(Exception): + """Raised when the database backend cannot be reached during auth operations.""" + # Get secret key for sessions SECRET_KEY = os.getenv("FASTAPI_SECRET_KEY") if not SECRET_KEY: @@ -70,7 +74,7 @@ async def _get_user_info(api_token: str) -> Optional[Dict[str, Any]]: return None except Exception as e: # pylint: disable=broad-exception-caught logging.error("Error fetching user info: %s", e) - return None + raise DatabaseUnavailableError(str(e)) from e async def delete_user_token(api_token: str): @@ -226,11 +230,37 @@ def get_token(request: Request) -> Optional[str]: return None +def _validate_from_session(request: Request, api_token: str) -> Tuple[Optional[Dict[str, Any]], bool]: + """ + Validate user from session data when the database is unavailable. + + Only succeeds when the session contains user info AND the session's stored + api_token exactly matches the provided *api_token* argument. The token + comparison prevents an attacker from re-using a stale session after a + different token has been issued (e.g. after logout and re-login). + Returns (None, False) if the session token does not match or is absent, + ensuring no unauthorised access is granted via the session fallback path. + """ + try: + session = getattr(request, "session", None) + if not session: + return None, False + session_user = session.get("user_info") + session_token = session.get("api_token") + if session_user and session_token and session_token == api_token: + logging.info("Session fallback authentication succeeded") + return session_user, True + return None, False + except Exception as e: # pylint: disable=broad-exception-caught + logging.error("Error during session fallback authentication: %s", e) + return None, False + + async def validate_user(request: Request) -> Tuple[Optional[Dict[str, Any]], bool]: """ Helper function to validate token. Returns (user_info, is_authenticated). - Includes refresh handling for Google. + Falls back to session-based authentication when the database is unavailable. """ try: api_token = get_token(request) @@ -238,7 +268,13 @@ async def validate_user(request: Request) -> Tuple[Optional[Dict[str, Any]], boo if not api_token: return None, False - db_info = await _get_user_info(api_token) + try: + db_info = await _get_user_info(api_token) + except DatabaseUnavailableError: + logging.warning( + "Database unavailable during user validation, falling back to session authentication" + ) + return _validate_from_session(request, api_token) if db_info: return db_info, True diff --git a/api/routes/auth.py b/api/routes/auth.py index 01613eb7..f3b1b0b8 100644 --- a/api/routes/auth.py +++ b/api/routes/auth.py @@ -349,6 +349,11 @@ async def email_login(request: Request, login_data: EmailLoginRequest) -> JSONRe # Call the registered handler (await if async) await handler('email', user_data, api_token) + + # Store user info in session as a backup so login survives transient + # DB unavailability; validated against api_token on each request. + _store_session_backup(request, user_data, api_token) + response = JSONResponse({"success": True}, status_code=200) response.set_cookie( @@ -382,6 +387,22 @@ def _get_provider_client(request: Request, provider: str): raise HTTPException(status_code=500, detail=f"OAuth provider {provider} not configured") return client + +def _store_session_backup(request: Request, user_data: dict, api_token: str) -> None: + """Store user info and api_token in the session as a backup. + + This allows auth validation to succeed even when FalkorDB is temporarily + unavailable, by falling back to the signed session cookie. The api_token + is stored alongside the user info so that the fallback path can verify the + token has not changed (e.g. after an explicit logout). + """ + request.session["user_info"] = { + "email": user_data.get("email"), + "name": user_data.get("name"), + "picture": user_data.get("picture"), + } + request.session["api_token"] = api_token + def _build_callback_url(request: Request, path: str) -> str: """Build absolute callback URL, honoring OAUTH_BASE_URL if provided.""" base_override = os.getenv("OAUTH_BASE_URL") @@ -499,6 +520,10 @@ async def google_authorized(request: Request) -> RedirectResponse: # Call the registered handler (await if async) await handler('google', user_data, api_token) + # Store user info in session as a backup so login survives transient + # DB unavailability; validated against api_token on each request. + _store_session_backup(request, user_data, api_token) + redirect = RedirectResponse(url="/", status_code=302) redirect.set_cookie( key="api_token", @@ -603,6 +628,10 @@ async def github_authorized(request: Request) -> RedirectResponse: # Call the registered handler (await if async) await handler('github', user_data, api_token) + # Store user info in session as a backup so login survives transient + # DB unavailability; validated against api_token on each request. + _store_session_backup(request, user_data, api_token) + redirect = RedirectResponse(url="/", status_code=302) redirect.set_cookie( key="api_token", @@ -674,6 +703,10 @@ async def logout(request: Request): - GET: For direct navigation (bookmarks, links, old clients) - POST: For programmatic logout from the app """ + # Clear session-based auth backup on every logout path + request.session.pop("user_info", None) + request.session.pop("api_token", None) + # For GET requests, redirect to home page if request.method == "GET": response = RedirectResponse(url="/", status_code=302) diff --git a/tests/test_session_auth_fallback.py b/tests/test_session_auth_fallback.py new file mode 100644 index 00000000..650697b3 --- /dev/null +++ b/tests/test_session_auth_fallback.py @@ -0,0 +1,198 @@ +""" +Tests for session-based authentication fallback. + +These tests verify that login is not blocked when FalkorDB is unavailable, +and that the session is used as a backup for auth validation. +""" + +import sys +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# --------------------------------------------------------------------------- +# Mock FalkorDB at module level so api.extensions does not attempt a real +# Redis connection when the test suite is collected. +# --------------------------------------------------------------------------- +_extensions_mock = MagicMock() +sys.modules.setdefault("api.extensions", _extensions_mock) + +# pylint: disable=wrong-import-position +from api.auth.user_management import ( # noqa: E402 + DatabaseUnavailableError, + _validate_from_session, + validate_user, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_request(session_data: dict | None = None, api_token_cookie: str | None = None): + """Build a minimal mock Request with session and cookies.""" + request = MagicMock() + request.session = session_data or {} + request.cookies = {"api_token": api_token_cookie} if api_token_cookie else {} + # Simulate no Authorization header and no query param so get_token only + # looks at the cookie. + request.headers.get = lambda key, default=None: default + request.query_params.get = lambda key, default=None: default + return request + + +# --------------------------------------------------------------------------- +# _validate_from_session +# --------------------------------------------------------------------------- + +class TestValidateFromSession: + """Unit tests for the _validate_from_session helper.""" + + def test_returns_user_when_token_matches(self): + """User info is returned when the session token matches the cookie token.""" + user_info = {"email": "user@example.com", "name": "Test User", "picture": None} + request = _make_request( + session_data={"user_info": user_info, "api_token": "tok123"}, + ) + result_user, authenticated = _validate_from_session(request, "tok123") + assert authenticated is True + assert result_user == user_info + + def test_returns_none_when_token_mismatches(self): + """No user is returned when the session token does NOT match the request token.""" + user_info = {"email": "user@example.com", "name": "Test User", "picture": None} + request = _make_request( + session_data={"user_info": user_info, "api_token": "different-token"}, + ) + result_user, authenticated = _validate_from_session(request, "tok123") + assert authenticated is False + assert result_user is None + + def test_returns_none_when_session_is_empty(self): + """No user is returned when the session contains no auth data.""" + request = _make_request(session_data={}) + result_user, authenticated = _validate_from_session(request, "tok123") + assert authenticated is False + assert result_user is None + + def test_returns_none_when_session_missing_api_token(self): + """No user is returned when the session has user_info but no api_token.""" + user_info = {"email": "user@example.com", "name": "Test User", "picture": None} + request = _make_request(session_data={"user_info": user_info}) + result_user, authenticated = _validate_from_session(request, "tok123") + assert authenticated is False + assert result_user is None + + def test_returns_none_when_no_session_attribute(self): + """No user is returned when the request object has no session.""" + request = MagicMock(spec=[]) # No attributes at all + result_user, authenticated = _validate_from_session(request, "tok123") + assert authenticated is False + assert result_user is None + + +# --------------------------------------------------------------------------- +# validate_user – DB unavailable scenario +# --------------------------------------------------------------------------- + +class TestValidateUserSessionFallback: + """Tests for validate_user falling back to session when FalkorDB is down.""" + + @pytest.mark.asyncio + async def test_falls_back_to_session_when_db_unavailable(self): + """validate_user returns session user when FalkorDB raises DatabaseUnavailableError.""" + user_info = {"email": "user@example.com", "name": "Test User", "picture": None} + request = _make_request( + session_data={"user_info": user_info, "api_token": "good-token"}, + api_token_cookie="good-token", + ) + + with patch( + "api.auth.user_management._get_user_info", + new_callable=AsyncMock, + side_effect=DatabaseUnavailableError("DB is down"), + ): + result_user, authenticated = await validate_user(request) + + assert authenticated is True + assert result_user == user_info + + @pytest.mark.asyncio + async def test_db_success_returns_db_user(self): + """validate_user returns DB user when FalkorDB is reachable.""" + db_user = {"email": "db@example.com", "name": "DB User", "picture": None} + request = _make_request(api_token_cookie="some-token") + + with patch( + "api.auth.user_management._get_user_info", + new_callable=AsyncMock, + return_value=db_user, + ): + result_user, authenticated = await validate_user(request) + + assert authenticated is True + assert result_user == db_user + + @pytest.mark.asyncio + async def test_not_authenticated_when_db_down_and_no_session(self): + """validate_user returns not-authenticated when DB is down and no session backup.""" + request = _make_request( + session_data={}, # No session backup + api_token_cookie="some-token", + ) + + with patch( + "api.auth.user_management._get_user_info", + new_callable=AsyncMock, + side_effect=DatabaseUnavailableError("DB is down"), + ): + result_user, authenticated = await validate_user(request) + + assert authenticated is False + assert result_user is None + + @pytest.mark.asyncio + async def test_not_authenticated_when_db_down_and_token_mismatch(self): + """Session fallback does not authenticate when the session token differs from cookie.""" + user_info = {"email": "user@example.com", "name": "Test User", "picture": None} + request = _make_request( + session_data={"user_info": user_info, "api_token": "old-token"}, + api_token_cookie="new-token", # Mismatch + ) + + with patch( + "api.auth.user_management._get_user_info", + new_callable=AsyncMock, + side_effect=DatabaseUnavailableError("DB is down"), + ): + result_user, authenticated = await validate_user(request) + + assert authenticated is False + assert result_user is None + + @pytest.mark.asyncio + async def test_not_authenticated_when_no_token(self): + """validate_user returns not-authenticated when no token is present at all.""" + request = _make_request() # No cookies, no headers + + result_user, authenticated = await validate_user(request) + + assert authenticated is False + assert result_user is None + + +# --------------------------------------------------------------------------- +# DatabaseUnavailableError is exported correctly +# --------------------------------------------------------------------------- + +class TestDatabaseUnavailableError: + """Ensure DatabaseUnavailableError can be imported and is an Exception subclass.""" + + def test_is_exception(self): + """DatabaseUnavailableError must derive from Exception.""" + assert issubclass(DatabaseUnavailableError, Exception) + + def test_can_be_raised_and_caught(self): + """DatabaseUnavailableError can be raised and caught.""" + with pytest.raises(DatabaseUnavailableError, match="db error"): + raise DatabaseUnavailableError("db error")