diff --git a/main.py b/main.py index 84d6c30..be2b73d 100644 --- a/main.py +++ b/main.py @@ -3,14 +3,15 @@ from contextlib import asynccontextmanager from dotenv import load_dotenv from fastapi import FastAPI, Request, Depends, status -from fastapi.responses import RedirectResponse +from fastapi.responses import RedirectResponse, Response from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates from fastapi.exceptions import RequestValidationError from starlette.exceptions import HTTPException as StarletteHTTPException from routers.core import account, dashboard, organization, role, user, static_pages, invitation from utils.core.dependencies import ( - get_optional_user + get_optional_user, + get_user_from_request ) from exceptions.http_exceptions import ( AuthenticationError, @@ -91,13 +92,18 @@ async def needs_new_tokens_handler(request: Request, exc: NeedsNewTokens): # Handle PasswordValidationError by rendering the validation_error page @app.exception_handler(PasswordValidationError) -async def password_validation_exception_handler(request: Request, exc: PasswordValidationError): +async def password_validation_exception_handler( + request: Request, + exc: PasswordValidationError +) -> Response: + user = await get_user_from_request(request) return templates.TemplateResponse( request, "errors/validation_error.html", { "status_code": 422, - "errors": {"error": exc.detail} + "errors": {"error": exc.detail}, + "user": user }, status_code=422, ) @@ -105,7 +111,10 @@ async def password_validation_exception_handler(request: Request, exc: PasswordV # Handle RequestValidationError by rendering the validation_error page @app.exception_handler(RequestValidationError) -async def validation_exception_handler(request: Request, exc: RequestValidationError): +async def validation_exception_handler( + request: Request, + exc: RequestValidationError +): errors = {} # Map error types to user-friendly message templates @@ -129,26 +138,28 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE # For JSON body, it might be (body, field_name) # For array items, it might be (field_name, array_index) field_name = location[-2] if isinstance(location[-1], int) else location[-1] - + # Format the field name to be more user-friendly display_name = field_name.replace("_", " ").title() - + # Use mapped message if available, otherwise use FastAPI's message error_type = error.get("type", "") message_template = error_templates.get(error_type, error["msg"]) - + # For array items, append the index to the message if isinstance(location[-1], int): message_template = f"Item {location[-1] + 1}: {message_template}" - + errors[display_name] = message_template + user = await get_user_from_request(request) return templates.TemplateResponse( request, "errors/validation_error.html", { "status_code": 422, - "errors": errors + "errors": errors, + "user": user }, status_code=422, ) @@ -157,10 +168,11 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE # Handle StarletteHTTPException (including 404, 405, etc.) by rendering the error page @app.exception_handler(StarletteHTTPException) async def http_exception_handler(request: Request, exc: StarletteHTTPException): + user = await get_user_from_request(request) return templates.TemplateResponse( request, "errors/error.html", - {"status_code": exc.status_code, "detail": exc.detail}, + {"status_code": exc.status_code, "detail": exc.detail, "user": user}, status_code=exc.status_code, ) @@ -170,13 +182,15 @@ async def http_exception_handler(request: Request, exc: StarletteHTTPException): async def general_exception_handler(request: Request, exc: Exception): # Log the error for debugging logger.error(f"Unhandled exception: {exc}", exc_info=True) + user = await get_user_from_request(request) return templates.TemplateResponse( request, "errors/error.html", { "status_code": 500, - "detail": "Internal Server Error" + "detail": "Internal Server Error", + "user": user }, status_code=500, ) diff --git a/utils/core/dependencies.py b/utils/core/dependencies.py index 3b1613e..ce6d456 100644 --- a/utils/core/dependencies.py +++ b/utils/core/dependencies.py @@ -1,4 +1,4 @@ -from fastapi import Depends, Form +from fastapi import Depends, Form, Request from pydantic import EmailStr from sqlmodel import Session, select from sqlalchemy.orm import selectinload @@ -321,4 +321,29 @@ def get_user_with_relations( ) ).one() - return eager_user \ No newline at end of file + return eager_user + + +async def get_user_from_request(request: Request) -> Optional[User]: + """ + Helper function to get user from request cookies in exception handlers. + Exception handlers can't use Depends(), so we manually extract tokens and get the user. + """ + access_token = request.cookies.get("access_token") + refresh_token = request.cookies.get("refresh_token") + tokens = (access_token, refresh_token) + + # Get a database session + engine = create_engine(get_connection_url()) + with Session(engine) as session: + user, new_access_token, new_refresh_token = get_user_from_tokens(tokens, session) + + # If we got new tokens, we'd normally raise NeedsNewTokens, but in an exception + # handler we can't do that easily. For now, just return the user. + # The tokens will be refreshed on the next request. + if user and new_access_token and new_refresh_token: + # Note: We can't easily set cookies here since we're in an exception handler. + # The user will need to make another request to get new tokens. + pass + + return user \ No newline at end of file