Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions airflow-core/src/airflow/api_fastapi/auth/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _guess_best_algorithm(key: AllowedPrivateKeys):
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey

if isinstance(key, RSAPrivateKey):
return "RS512"
return "RS256"
if isinstance(key, Ed25519PrivateKey):
return "EdDSA"
raise ValueError(f"Unknown key object {type(key)}")
Expand Down Expand Up @@ -291,14 +291,8 @@ def __attrs_post_init__(self):
raise ValueError("Exactly one of private_key and secret_key must be specified")

if self.algorithm == ["GUESS"]:
if self.jwks:
# TODO: We could probably populate this from the jwks document, but we don't have that at
# construction time.
raise ValueError(
"Cannot guess the algorithm when using JWKS - please specify it in the config option "
"[api_auth] jwt_algorithm"
)
self.algorithm = ["HS512"]
if not self.jwks:
self.algorithm = ["HS512"]

def _get_kid_from_header(self, unvalidated: str) -> str:
header = jwt.get_unverified_header(unvalidated)
Expand Down Expand Up @@ -326,13 +320,21 @@ async def avalidated_claims(
) -> dict[str, Any]:
"""Decode the JWT token, returning the validated claims or raising an exception."""
key = await self._get_validation_key(unvalidated)
algorithms = self.algorithm
validation_key: str | jwt.PyJWK | Any = key
if algorithms == ["GUESS"] and isinstance(key, jwt.PyJWK):
if not key.algorithm_name:
raise jwt.InvalidTokenError("Missing algorithm in JWK")
algorithms = [key.algorithm_name]
validation_key = key.key

claims = jwt.decode(
unvalidated,
key,
validation_key,
audience=self.audience,
issuer=self.issuer,
options={"require": list(self.required_claims)},
algorithms=self.algorithm,
algorithms=algorithms,
leeway=self.leeway,
)

Expand Down
31 changes: 31 additions & 0 deletions airflow-core/tests/unit/api_fastapi/auth/test_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,37 @@ async def test_jwt_generate_validate_roundtrip_with_jwks(private_key, algorithm,
assert await validator.avalidated_claims(token)


@pytest.mark.parametrize("private_key", ["rsa_private_key", "ed25519_private_key"], indirect=True)
async def test_jwt_validate_roundtrip_with_jwks_and_guess_algorithm(private_key, tmp_path: pathlib.Path):
jwk_content = json.dumps({"keys": [key_to_jwk_dict(private_key, "custom-kid")]})

jwks = tmp_path.joinpath("jwks.json")
await anyio.Path(jwks).write_text(jwk_content)

priv_key = tmp_path.joinpath("key.pem")
await anyio.Path(priv_key).write_bytes(key_to_pem(private_key))

with conf_vars(
{
("api_auth", "trusted_jwks_url"): str(jwks),
("api_auth", "jwt_kid"): "custom-kid",
("api_auth", "jwt_issuer"): "http://my-issuer.localdomain",
("api_auth", "jwt_private_key_path"): str(priv_key),
("api_auth", "jwt_algorithm"): "GUESS",
("api_auth", "jwt_secret"): "",
}
):
gen = JWTGenerator(audience="airflow1", valid_for=300)
token = gen.generate({"sub": "test"})

validator = JWTValidator(
audience="airflow1",
leeway=0,
**get_sig_validation_args(make_secret_key_if_needed=False),
)
assert await validator.avalidated_claims(token)


class TestRevokeToken:
pytestmark = [pytest.mark.db_test]

Expand Down
Loading