diff --git a/airflow-core/src/airflow/api_fastapi/auth/tokens.py b/airflow-core/src/airflow/api_fastapi/auth/tokens.py index 4732164be7115..3375853a29a5d 100644 --- a/airflow-core/src/airflow/api_fastapi/auth/tokens.py +++ b/airflow-core/src/airflow/api_fastapi/auth/tokens.py @@ -93,7 +93,7 @@ def _guess_best_algorithm(key: AllowedPrivateKeys): from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey if isinstance(key, RSAPrivateKey): - return "RS512" + return "RS256" if isinstance(key, Ed25519PrivateKey): return "EdDSA" raise ValueError(f"Unknown key object {type(key)}") @@ -291,14 +291,8 @@ def __attrs_post_init__(self): raise ValueError("Exactly one of private_key and secret_key must be specified") if self.algorithm == ["GUESS"]: - if self.jwks: - # TODO: We could probably populate this from the jwks document, but we don't have that at - # construction time. - raise ValueError( - "Cannot guess the algorithm when using JWKS - please specify it in the config option " - "[api_auth] jwt_algorithm" - ) - self.algorithm = ["HS512"] + if not self.jwks: + self.algorithm = ["HS512"] def _get_kid_from_header(self, unvalidated: str) -> str: header = jwt.get_unverified_header(unvalidated) @@ -326,13 +320,21 @@ async def avalidated_claims( ) -> dict[str, Any]: """Decode the JWT token, returning the validated claims or raising an exception.""" key = await self._get_validation_key(unvalidated) + algorithms = self.algorithm + validation_key: str | jwt.PyJWK | Any = key + if algorithms == ["GUESS"] and isinstance(key, jwt.PyJWK): + if not key.algorithm_name: + raise jwt.InvalidTokenError("Missing algorithm in JWK") + algorithms = [key.algorithm_name] + validation_key = key.key + claims = jwt.decode( unvalidated, - key, + validation_key, audience=self.audience, issuer=self.issuer, options={"require": list(self.required_claims)}, - algorithms=self.algorithm, + algorithms=algorithms, leeway=self.leeway, ) diff --git a/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py b/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py index e477c42af4f5a..6b848f723a004 100644 --- a/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py +++ b/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py @@ -264,6 +264,37 @@ async def test_jwt_generate_validate_roundtrip_with_jwks(private_key, algorithm, assert await validator.avalidated_claims(token) +@pytest.mark.parametrize("private_key", ["rsa_private_key", "ed25519_private_key"], indirect=True) +async def test_jwt_validate_roundtrip_with_jwks_and_guess_algorithm(private_key, tmp_path: pathlib.Path): + jwk_content = json.dumps({"keys": [key_to_jwk_dict(private_key, "custom-kid")]}) + + jwks = tmp_path.joinpath("jwks.json") + await anyio.Path(jwks).write_text(jwk_content) + + priv_key = tmp_path.joinpath("key.pem") + await anyio.Path(priv_key).write_bytes(key_to_pem(private_key)) + + with conf_vars( + { + ("api_auth", "trusted_jwks_url"): str(jwks), + ("api_auth", "jwt_kid"): "custom-kid", + ("api_auth", "jwt_issuer"): "http://my-issuer.localdomain", + ("api_auth", "jwt_private_key_path"): str(priv_key), + ("api_auth", "jwt_algorithm"): "GUESS", + ("api_auth", "jwt_secret"): "", + } + ): + gen = JWTGenerator(audience="airflow1", valid_for=300) + token = gen.generate({"sub": "test"}) + + validator = JWTValidator( + audience="airflow1", + leeway=0, + **get_sig_validation_args(make_secret_key_if_needed=False), + ) + assert await validator.avalidated_claims(token) + + class TestRevokeToken: pytestmark = [pytest.mark.db_test]