Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions airflow-core/src/airflow/api_fastapi/auth/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,12 +311,12 @@ async def _get_validation_key(self, unvalidated: str) -> str | jwt.PyJWK:
return await self.jwks.get_key(kid)

def validated_claims(
self, unvalidated: str, required_claims: dict[str, Any] | None = None
self, unvalidated: str, required_claims: dict[str, Any] | None = None, *, extra_leeway: float = 0
) -> dict[str, Any]:
return async_to_sync(self.avalidated_claims)(unvalidated, required_claims)
return async_to_sync(self.avalidated_claims)(unvalidated, required_claims, extra_leeway=extra_leeway)

async def avalidated_claims(
self, unvalidated: str, required_claims: dict[str, Any] | None = None
self, unvalidated: str, required_claims: dict[str, Any] | None = None, *, extra_leeway: float = 0
) -> dict[str, Any]:
"""Decode the JWT token, returning the validated claims or raising an exception."""
try:
Expand All @@ -338,7 +338,7 @@ async def avalidated_claims(
issuer=self.issuer,
options={"require": list(self.required_claims)},
algorithms=algorithms,
leeway=self.leeway,
leeway=self.leeway + extra_leeway,
)

# Validate additional claims if provided
Expand Down
11 changes: 7 additions & 4 deletions airflow-core/src/airflow/api_fastapi/execution_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
get_sig_validation_args,
get_signing_args,
)
from airflow.api_fastapi.execution_api.security import _jwt_bearer

if TYPE_CHECKING:
import httpx
Expand Down Expand Up @@ -137,14 +138,16 @@ async def dispatch(self, request: Request, call_next):
refreshed_token: str | None = None
auth_header = request.headers.get("authorization")
if auth_header and auth_header.lower().startswith("bearer "):
token = auth_header.split(" ", 1)[1]
try:
async with svcs.Container(request.app.state.svcs_registry) as services:
validator: JWTValidator = await services.aget(JWTValidator)
claims = await validator.avalidated_claims(token, {})
validated_token = await _jwt_bearer(request, services)
if validated_token is None:
return response
claims = validated_token.claims.model_dump()
claims["sub"] = str(validated_token.id)

# Workload tokens are long-lived and meant to survive queue
# wait times so avoid refreshing them. If avalidated_claims
# wait times so avoid refreshing them. If JWTBearer
# raises for a workload token, the outer except handles it.
if claims.get("scope") == "workload":
return response
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from unittest.mock import AsyncMock

import jwt
import pytest
from fastapi import FastAPI

Expand Down Expand Up @@ -67,3 +68,34 @@ def test_expiring_token_is_reissued(
assert "Refreshed-API-Token" in response.headers
else:
assert "Refreshed-API-Token" not in response.headers


@pytest.mark.db_test
def test_reissue_reuses_cached_claims_from_jwt_bearer(client, exec_app: FastAPI, time_machine):
"""Ensure reissue uses claims already validated by JWTBearer.

If middleware validates the JWT again, this test would fail because the mocked
second validation call raises ExpiredSignatureError.
"""
moment = 1743451846 # A "random" unix epoch timestamp.
validity = 60
claims = {
"sub": "edb09971-4e0e-4221-ad3f-800852d38085",
"iat": moment,
"exp": moment + validity,
}

auth = AsyncMock(spec=JWTValidator)
auth.avalidated_claims.side_effect = [
claims,
jwt.ExpiredSignatureError("Signature has expired"),
]

time_machine.move_to(moment + 31, tick=False)

lifespan.registry.register_value(JWTValidator, auth)

response = client.get("/execution/variables/key1", headers={"Authorization": "Bearer dummy"})

assert "Refreshed-API-Token" in response.headers
assert auth.avalidated_claims.await_count == 1
Loading