Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 39 additions & 3 deletions api/auth/user_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
from pydantic import BaseModel
from api.extensions import db


class DatabaseUnavailableError(Exception):
"""Raised when the database backend cannot be reached during auth operations."""

# Get secret key for sessions
SECRET_KEY = os.getenv("FASTAPI_SECRET_KEY")
if not SECRET_KEY:
Expand Down Expand Up @@ -70,7 +74,7 @@ async def _get_user_info(api_token: str) -> Optional[Dict[str, Any]]:
return None
except Exception as e: # pylint: disable=broad-exception-caught
logging.error("Error fetching user info: %s", e)
return None
raise DatabaseUnavailableError(str(e)) from e


async def delete_user_token(api_token: str):
Expand Down Expand Up @@ -226,19 +230,51 @@ def get_token(request: Request) -> Optional[str]:
return None


def _validate_from_session(request: Request, api_token: str) -> Tuple[Optional[Dict[str, Any]], bool]:
"""
Validate user from session data when the database is unavailable.

Only succeeds when the session contains user info AND the session's stored
api_token exactly matches the provided *api_token* argument. The token
comparison prevents an attacker from re-using a stale session after a
different token has been issued (e.g. after logout and re-login).
Returns (None, False) if the session token does not match or is absent,
ensuring no unauthorised access is granted via the session fallback path.
"""
try:
session = getattr(request, "session", None)
if not session:
return None, False
session_user = session.get("user_info")
session_token = session.get("api_token")
if session_user and session_token and session_token == api_token:
logging.info("Session fallback authentication succeeded")
return session_user, True
return None, False
except Exception as e: # pylint: disable=broad-exception-caught
logging.error("Error during session fallback authentication: %s", e)
return None, False


async def validate_user(request: Request) -> Tuple[Optional[Dict[str, Any]], bool]:
"""
Helper function to validate token.
Returns (user_info, is_authenticated).
Includes refresh handling for Google.
Falls back to session-based authentication when the database is unavailable.
"""
try:
api_token = get_token(request)

if not api_token:
return None, False

db_info = await _get_user_info(api_token)
try:
db_info = await _get_user_info(api_token)
except DatabaseUnavailableError:
logging.warning(
"Database unavailable during user validation, falling back to session authentication"
)
return _validate_from_session(request, api_token)

if db_info:
return db_info, True
Expand Down
33 changes: 33 additions & 0 deletions api/routes/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,11 @@ async def email_login(request: Request, login_data: EmailLoginRequest) -> JSONRe

# Call the registered handler (await if async)
await handler('email', user_data, api_token)

# Store user info in session as a backup so login survives transient
# DB unavailability; validated against api_token on each request.
_store_session_backup(request, user_data, api_token)

response = JSONResponse({"success": True}, status_code=200)

response.set_cookie(
Expand Down Expand Up @@ -382,6 +387,22 @@ def _get_provider_client(request: Request, provider: str):
raise HTTPException(status_code=500, detail=f"OAuth provider {provider} not configured")
return client


def _store_session_backup(request: Request, user_data: dict, api_token: str) -> None:
"""Store user info and api_token in the session as a backup.

This allows auth validation to succeed even when FalkorDB is temporarily
unavailable, by falling back to the signed session cookie. The api_token
is stored alongside the user info so that the fallback path can verify the
token has not changed (e.g. after an explicit logout).
"""
request.session["user_info"] = {
"email": user_data.get("email"),
"name": user_data.get("name"),
"picture": user_data.get("picture"),
}
request.session["api_token"] = api_token

def _build_callback_url(request: Request, path: str) -> str:
"""Build absolute callback URL, honoring OAUTH_BASE_URL if provided."""
base_override = os.getenv("OAUTH_BASE_URL")
Expand Down Expand Up @@ -499,6 +520,10 @@ async def google_authorized(request: Request) -> RedirectResponse:
# Call the registered handler (await if async)
await handler('google', user_data, api_token)

# Store user info in session as a backup so login survives transient
# DB unavailability; validated against api_token on each request.
_store_session_backup(request, user_data, api_token)

redirect = RedirectResponse(url="/", status_code=302)
redirect.set_cookie(
key="api_token",
Expand Down Expand Up @@ -603,6 +628,10 @@ async def github_authorized(request: Request) -> RedirectResponse:
# Call the registered handler (await if async)
await handler('github', user_data, api_token)

# Store user info in session as a backup so login survives transient
# DB unavailability; validated against api_token on each request.
_store_session_backup(request, user_data, api_token)

redirect = RedirectResponse(url="/", status_code=302)
redirect.set_cookie(
key="api_token",
Expand Down Expand Up @@ -674,6 +703,10 @@ async def logout(request: Request):
- GET: For direct navigation (bookmarks, links, old clients)
- POST: For programmatic logout from the app
"""
# Clear session-based auth backup on every logout path
request.session.pop("user_info", None)
request.session.pop("api_token", None)

# For GET requests, redirect to home page
if request.method == "GET":
response = RedirectResponse(url="/", status_code=302)
Expand Down
198 changes: 198 additions & 0 deletions tests/test_session_auth_fallback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
"""
Tests for session-based authentication fallback.

These tests verify that login is not blocked when FalkorDB is unavailable,
and that the session is used as a backup for auth validation.
"""

import sys
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

# ---------------------------------------------------------------------------
# Mock FalkorDB at module level so api.extensions does not attempt a real
# Redis connection when the test suite is collected.
# ---------------------------------------------------------------------------
_extensions_mock = MagicMock()
sys.modules.setdefault("api.extensions", _extensions_mock)

# pylint: disable=wrong-import-position
from api.auth.user_management import ( # noqa: E402
DatabaseUnavailableError,
_validate_from_session,
validate_user,
)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _make_request(session_data: dict | None = None, api_token_cookie: str | None = None):
"""Build a minimal mock Request with session and cookies."""
request = MagicMock()
request.session = session_data or {}
request.cookies = {"api_token": api_token_cookie} if api_token_cookie else {}
# Simulate no Authorization header and no query param so get_token only
# looks at the cookie.
request.headers.get = lambda key, default=None: default
request.query_params.get = lambda key, default=None: default
return request


# ---------------------------------------------------------------------------
# _validate_from_session
# ---------------------------------------------------------------------------

class TestValidateFromSession:
"""Unit tests for the _validate_from_session helper."""

def test_returns_user_when_token_matches(self):
"""User info is returned when the session token matches the cookie token."""
user_info = {"email": "user@example.com", "name": "Test User", "picture": None}
request = _make_request(
session_data={"user_info": user_info, "api_token": "tok123"},
)
result_user, authenticated = _validate_from_session(request, "tok123")
assert authenticated is True
assert result_user == user_info

def test_returns_none_when_token_mismatches(self):
"""No user is returned when the session token does NOT match the request token."""
user_info = {"email": "user@example.com", "name": "Test User", "picture": None}
request = _make_request(
session_data={"user_info": user_info, "api_token": "different-token"},
)
result_user, authenticated = _validate_from_session(request, "tok123")
assert authenticated is False
assert result_user is None

def test_returns_none_when_session_is_empty(self):
"""No user is returned when the session contains no auth data."""
request = _make_request(session_data={})
result_user, authenticated = _validate_from_session(request, "tok123")
assert authenticated is False
assert result_user is None

def test_returns_none_when_session_missing_api_token(self):
"""No user is returned when the session has user_info but no api_token."""
user_info = {"email": "user@example.com", "name": "Test User", "picture": None}
request = _make_request(session_data={"user_info": user_info})
result_user, authenticated = _validate_from_session(request, "tok123")
assert authenticated is False
assert result_user is None

def test_returns_none_when_no_session_attribute(self):
"""No user is returned when the request object has no session."""
request = MagicMock(spec=[]) # No attributes at all
result_user, authenticated = _validate_from_session(request, "tok123")
assert authenticated is False
assert result_user is None


# ---------------------------------------------------------------------------
# validate_user – DB unavailable scenario
# ---------------------------------------------------------------------------

class TestValidateUserSessionFallback:
"""Tests for validate_user falling back to session when FalkorDB is down."""

@pytest.mark.asyncio
async def test_falls_back_to_session_when_db_unavailable(self):
"""validate_user returns session user when FalkorDB raises DatabaseUnavailableError."""
user_info = {"email": "user@example.com", "name": "Test User", "picture": None}
request = _make_request(
session_data={"user_info": user_info, "api_token": "good-token"},
api_token_cookie="good-token",
)

with patch(
"api.auth.user_management._get_user_info",
new_callable=AsyncMock,
side_effect=DatabaseUnavailableError("DB is down"),
):
result_user, authenticated = await validate_user(request)

assert authenticated is True
assert result_user == user_info

@pytest.mark.asyncio
async def test_db_success_returns_db_user(self):
"""validate_user returns DB user when FalkorDB is reachable."""
db_user = {"email": "db@example.com", "name": "DB User", "picture": None}
request = _make_request(api_token_cookie="some-token")

with patch(
"api.auth.user_management._get_user_info",
new_callable=AsyncMock,
return_value=db_user,
):
result_user, authenticated = await validate_user(request)

assert authenticated is True
assert result_user == db_user

@pytest.mark.asyncio
async def test_not_authenticated_when_db_down_and_no_session(self):
"""validate_user returns not-authenticated when DB is down and no session backup."""
request = _make_request(
session_data={}, # No session backup
api_token_cookie="some-token",
)

with patch(
"api.auth.user_management._get_user_info",
new_callable=AsyncMock,
side_effect=DatabaseUnavailableError("DB is down"),
):
result_user, authenticated = await validate_user(request)

assert authenticated is False
assert result_user is None

@pytest.mark.asyncio
async def test_not_authenticated_when_db_down_and_token_mismatch(self):
"""Session fallback does not authenticate when the session token differs from cookie."""
user_info = {"email": "user@example.com", "name": "Test User", "picture": None}
request = _make_request(
session_data={"user_info": user_info, "api_token": "old-token"},
api_token_cookie="new-token", # Mismatch
)

with patch(
"api.auth.user_management._get_user_info",
new_callable=AsyncMock,
side_effect=DatabaseUnavailableError("DB is down"),
):
result_user, authenticated = await validate_user(request)

assert authenticated is False
assert result_user is None

@pytest.mark.asyncio
async def test_not_authenticated_when_no_token(self):
"""validate_user returns not-authenticated when no token is present at all."""
request = _make_request() # No cookies, no headers

result_user, authenticated = await validate_user(request)

assert authenticated is False
assert result_user is None


# ---------------------------------------------------------------------------
# DatabaseUnavailableError is exported correctly
# ---------------------------------------------------------------------------

class TestDatabaseUnavailableError:
"""Ensure DatabaseUnavailableError can be imported and is an Exception subclass."""

def test_is_exception(self):
"""DatabaseUnavailableError must derive from Exception."""
assert issubclass(DatabaseUnavailableError, Exception)

def test_can_be_raised_and_caught(self):
"""DatabaseUnavailableError can be raised and caught."""
with pytest.raises(DatabaseUnavailableError, match="db error"):
raise DatabaseUnavailableError("db error")
Loading