Skip to content

Make sure logout button renders in middleware exception handlers #141

@chriscarrollsmith

Description

@chriscarrollsmith

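Error pages rendered by the app's exception handlers are not given the current user, so the logout button (and other logged-in UI) disappears on those pages. Exception handlers can't use `Depends()`, so the fix adds a helper that reads the auth cookies directly and passes the resulting user to each error template. Here's how I fixed it in a fork of this template: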

2 files changed (46 additions, 11 deletions)
Search within code

main.py (22 additions, 8 deletions)
Original file line number Original file line Diff line number Diff line change
@@ -2,7 +2,7 @@
from typing import Optional
from typing import Optional
from contextlib import asynccontextmanager
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, Depends, status
from fastapi import FastAPI, Request, Depends, status
from fastapi.responses import RedirectResponse
from fastapi.responses import RedirectResponse, Response
from fastapi.staticfiles import StaticFiles
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.templating import Jinja2Templates
from fastapi.exceptions import RequestValidationError
from fastapi.exceptions import RequestValidationError
@@ -17,7 +17,8 @@
NeedsNewTokens
NeedsNewTokens
)
)
from utils.core.dependencies import (
from utils.core.dependencies import (
get_optional_user
get_optional_user,
get_user_from_request
)
)
from utils.core.db import set_up_db
from utils.core.db import set_up_db
from utils.core.models import User
from utils.core.models import User
@@ -94,21 +95,29 @@ async def needs_new_tokens_handler(request: Request, exc: NeedsNewTokens):

# Handle PasswordValidationError by rendering the validation_error page
@app.exception_handler(PasswordValidationError)
async def password_validation_exception_handler(
    request: Request,
    exc: PasswordValidationError
) -> Response:
    """Render ``errors/validation_error.html`` with status 422 for a password error.

    The current user (``None`` for anonymous visitors) is added to the template
    context so the page can render logged-in UI such as the logout button.

    Args:
        request: The request that triggered the error.
        exc: The raised error; its ``detail`` is shown as the validation message.

    Returns:
        The rendered template response with status code 422.
    """
    user = await get_user_from_request(request)
    return templates.TemplateResponse(
        request,
        "errors/validation_error.html",
        {
            "status_code": 422,
            "errors": {"error": exc.detail},
            "user": user,
        },
        status_code=422,
    )

Handle RequestValidationError by rendering the validation_error page

@app.exception_handler(RequestValidationError)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
async def validation_exception_handler(
request: Request,
exc: RequestValidationError
):
errors = {}
errors = {}

# Map error types to user-friendly message templates
# Map error types to user-friendly message templates

@@ -146,12 +155,14 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE

    errors[display_name] = message_template
    errors[display_name] = message_template


user = await get_user_from_request(request)
return templates.TemplateResponse(
return templates.TemplateResponse(
    request,
    request,
    "errors/validation_error.html",
    "errors/validation_error.html",
    {
    {
        "status_code": 422,
        "status_code": 422,
        "errors": errors
        "errors": errors,
        "user": user
    },
    },
    status_code=422,
    status_code=422,
)
)

@@ -160,10 +171,11 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE

# Handle StarletteHTTPException (including 404, 405, etc.) by rendering the error page
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(
    request: Request,
    exc: StarletteHTTPException
) -> Response:
    """Render ``errors/error.html`` for an HTTP exception, keeping its status code.

    The current user (``None`` for anonymous visitors) is added to the template
    context so the page can render logged-in UI such as the logout button.

    Args:
        request: The request that triggered the error.
        exc: The raised HTTP exception; its ``status_code`` and ``detail`` are
            shown on the page.

    Returns:
        The rendered template response with ``exc.status_code`` and any
        headers the exception carried.
    """
    user = await get_user_from_request(request)
    return templates.TemplateResponse(
        request,
        "errors/error.html",
        {"status_code": exc.status_code, "detail": exc.detail, "user": user},
        status_code=exc.status_code,
        # Forward headers attached to the exception (e.g. ``Allow`` on 405,
        # ``WWW-Authenticate`` on 401) as Starlette's default handler does;
        # dropping them breaks clients that rely on them.
        headers=getattr(exc, "headers", None),
    )

@@ -173,13 +185,15 @@ async def http_exception_handler(request: Request, exc: StarletteHTTPException):
async def general_exception_handler(request: Request, exc: Exception) -> Response:
    """Log an unhandled exception and render a generic 500 error page.

    The current user (``None`` for anonymous visitors) is added to the template
    context so the page can render logged-in UI such as the logout button.

    Args:
        request: The request that triggered the error.
        exc: The unhandled exception.

    Returns:
        The rendered ``errors/error.html`` response with status code 500. The
        exception's details are logged but never exposed to the client.
    """
    # Log the error for debugging. Passing the exception itself as exc_info
    # attaches its traceback even if sys.exc_info() is no longer populated.
    logger.error("Unhandled exception: %s", exc, exc_info=exc)
    user = await get_user_from_request(request)

    return templates.TemplateResponse(
        request,
        "errors/error.html",
        {
            "status_code": 500,
            "detail": "Internal Server Error",
            "user": user,
        },
        status_code=500,
    )

utils/core/dependencies.py (24 additions, 3 deletions)
Original file line number Original file line Diff line number Diff line change
@@ -1,4 +1,4 @@
from fastapi import Depends, Form
from fastapi import Depends, Form, Request
from pydantic import EmailStr
from pydantic import EmailStr
from sqlmodel import Session, select
from sqlmodel import Session, select
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import selectinload
@@ -324,5 +324,26 @@ def get_user_with_relations(
return eager_user
return eager_user

# Lazily created engine shared by get_user_from_request. Creating a new engine
# (and therefore a new connection pool) on every call leaks pools under load.
# NOTE(review): if utils.core.db already exposes a shared engine, reuse it here.
# NOTE(review): confirm create_engine and get_connection_url are imported in
# this module; the visible import hunk only adds Request.
_exception_handler_engine = None


def _get_exception_handler_engine():
    """Return the engine used for exception-handler user lookups, creating it once."""
    global _exception_handler_engine
    if _exception_handler_engine is None:
        _exception_handler_engine = create_engine(get_connection_url())
    return _exception_handler_engine


async def get_user_from_request(request: Request) -> Optional[User]:
    """
    Best-effort lookup of the current user from request cookies, for exception handlers.

    Exception handlers can't use ``Depends()``, so this reads the
    ``access_token`` and ``refresh_token`` cookies directly and resolves them
    with ``get_user_from_tokens``.

    Args:
        request: The request whose error page is being rendered.

    Returns:
        The user resolved by ``get_user_from_tokens``, or ``None`` when the
        request carries no auth cookies or the lookup fails. Failures are
        logged and swallowed deliberately: this runs while an error page is
        being rendered, and raising here would replace that page with a
        second, unhandled exception.
    """
    tokens = (
        request.cookies.get("access_token"),
        request.cookies.get("refresh_token"),
    )
    if not any(tokens):
        # No auth cookies at all: there is no user to look up, so skip the DB.
        return None

    # NOTE(review): this does synchronous DB I/O inside an async function,
    # blocking the event loop while it runs; consider running it in a
    # threadpool if error pages are hit frequently.
    try:
        with Session(_get_exception_handler_engine()) as session:
            # If the tokens were refreshed, the new pair is discarded: cookies
            # can't easily be set from inside an exception handler, so they
            # will be refreshed on the user's next normal request.
            user, _new_access_token, _new_refresh_token = get_user_from_tokens(
                tokens, session
            )
    except Exception:
        import logging

        logging.getLogger(__name__).warning(
            "Could not resolve user for error page", exc_info=True
        )
        return None
    return user

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions