From 2f79764f524bfecc8ef3ba9b0208eea0021e861c Mon Sep 17 00:00:00 2001 From: Henry Chen Date: Sun, 8 Mar 2026 22:37:28 +0800 Subject: [PATCH 1/2] Fix: Ensure JWTValidator handles GUESS algorithm with JWKS - Updated `avalidated_claims` to read the signing algorithm (`alg`) from the token header when `jwt_algorithm` is set to "GUESS". - Passed the raw key (`key.key`) instead of the `PyJWK` object to prevent pyjwt from overriding the algorithm with `PyJWK.algorithm_name`. --- .../src/airflow/api_fastapi/auth/tokens.py | 21 +++++++------ .../unit/api_fastapi/auth/test_tokens.py | 31 +++++++++++++++++++ 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/auth/tokens.py b/airflow-core/src/airflow/api_fastapi/auth/tokens.py index 4732164be7115..a57e85eb40c03 100644 --- a/airflow-core/src/airflow/api_fastapi/auth/tokens.py +++ b/airflow-core/src/airflow/api_fastapi/auth/tokens.py @@ -291,14 +291,8 @@ def __attrs_post_init__(self): raise ValueError("Exactly one of private_key and secret_key must be specified") if self.algorithm == ["GUESS"]: - if self.jwks: - # TODO: We could probably populate this from the jwks document, but we don't have that at - # construction time. - raise ValueError( - "Cannot guess the algorithm when using JWKS - please specify it in the config option " - "[api_auth] jwt_algorithm" - ) - self.algorithm = ["HS512"] + if not self.jwks: + self.algorithm = ["HS512"] def _get_kid_from_header(self, unvalidated: str) -> str: header = jwt.get_unverified_header(unvalidated) @@ -326,13 +320,20 @@ async def avalidated_claims( ) -> dict[str, Any]: """Decode the JWT token, returning the validated claims or raising an exception.""" key = await self._get_validation_key(unvalidated) + algorithms = self.algorithm + validation_key: str | jwt.PyJWK | Any = key + if algorithms == ["GUESS"] and isinstance(key, jwt.PyJWK): + header = jwt.get_unverified_header(unvalidated) + algorithms = [header.get("alg") or key.algorithm_name] + validation_key = key.key + claims = jwt.decode( unvalidated, - key, + validation_key, audience=self.audience, issuer=self.issuer, options={"require": list(self.required_claims)}, - algorithms=self.algorithm, + algorithms=algorithms, leeway=self.leeway, ) diff --git a/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py b/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py index e477c42af4f5a..6b848f723a004 100644 --- a/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py +++ b/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py @@ -264,6 +264,37 @@ async def test_jwt_generate_validate_roundtrip_with_jwks(private_key, algorithm, assert await validator.avalidated_claims(token) +@pytest.mark.parametrize("private_key", ["rsa_private_key", "ed25519_private_key"], indirect=True) +async def test_jwt_validate_roundtrip_with_jwks_and_guess_algorithm(private_key, tmp_path: pathlib.Path): + jwk_content = json.dumps({"keys": [key_to_jwk_dict(private_key, "custom-kid")]}) + + jwks = tmp_path.joinpath("jwks.json") + await anyio.Path(jwks).write_text(jwk_content) + + priv_key = tmp_path.joinpath("key.pem") + await anyio.Path(priv_key).write_bytes(key_to_pem(private_key)) + + with conf_vars( + { + ("api_auth", "trusted_jwks_url"): str(jwks), + ("api_auth", "jwt_kid"): "custom-kid", + ("api_auth", "jwt_issuer"): "http://my-issuer.localdomain", + ("api_auth", "jwt_private_key_path"): str(priv_key), + ("api_auth", "jwt_algorithm"): "GUESS", + ("api_auth", "jwt_secret"): "", + } + ): + gen = JWTGenerator(audience="airflow1", valid_for=300) + token = gen.generate({"sub": "test"}) + + validator = JWTValidator( + audience="airflow1", + leeway=0, + **get_sig_validation_args(make_secret_key_if_needed=False), + ) + assert await validator.avalidated_claims(token) + + class TestRevokeToken: pytestmark = [pytest.mark.db_test] From 2d3672cc47e4fd3ffa8fe1cc5908c084c69d8d77 Mon Sep 17 00:00:00 2001 From: Henry Chen Date: Mon, 9 Mar 2026 18:46:55 +0800 Subject: [PATCH 2/2] Use algorithm_name instead of the token header --- airflow-core/src/airflow/api_fastapi/auth/tokens.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/auth/tokens.py b/airflow-core/src/airflow/api_fastapi/auth/tokens.py index a57e85eb40c03..3375853a29a5d 100644 --- a/airflow-core/src/airflow/api_fastapi/auth/tokens.py +++ b/airflow-core/src/airflow/api_fastapi/auth/tokens.py @@ -93,7 +93,7 @@ def _guess_best_algorithm(key: AllowedPrivateKeys): from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey if isinstance(key, RSAPrivateKey): - return "RS512" + return "RS256" if isinstance(key, Ed25519PrivateKey): return "EdDSA" raise ValueError(f"Unknown key object {type(key)}") @@ -323,8 +323,9 @@ async def avalidated_claims( algorithms = self.algorithm validation_key: str | jwt.PyJWK | Any = key if algorithms == ["GUESS"] and isinstance(key, jwt.PyJWK): - header = jwt.get_unverified_header(unvalidated) - algorithms = [header.get("alg") or key.algorithm_name] + if not key.algorithm_name: + raise jwt.InvalidTokenError("Missing algorithm in JWK") + algorithms = [key.algorithm_name] validation_key = key.key claims = jwt.decode(