From fefc9319d891040858312ba08787c8edca647b91 Mon Sep 17 00:00:00 2001 From: "Tobias.Mikula" Date: Wed, 12 Nov 2025 11:32:52 +0100 Subject: [PATCH 1/7] Token verification with previous public token --- README.md | 4 +- conf/config.json | 2 +- src/event_gate_lambda.py | 47 +++++++++++++++++---- tests/{utils => }/test_conf_validation.py | 4 +- tests/test_event_gate_lambda.py | 51 +++++++++++++++++++++++ 5 files changed, 95 insertions(+), 13 deletions(-) rename tests/{utils => }/test_conf_validation.py (98%) diff --git a/README.md b/README.md index e0cb1f4..418a5c7 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,7 @@ Example (sanitized): { "access_config": "s3:///access.json", "token_provider_url": "https://", - "token_public_key_url": "https:///public-key", + "token_public_keys_url": "https:///token/public-keys", "kafka_bootstrap_server": "broker1:9092,broker2:9092", "event_bus_arn": "arn:aws:events:region:acct:event-bus/your-bus" } @@ -137,7 +137,7 @@ Use when Kafka access needs Kerberos / SASL_SSL or custom `librdkafka` build. | Code coverage | [Code Coverage](./DEVELOPER.md#code-coverage) | ## Security & Authorization -- JWT tokens must be RS256 signed; the public key is fetched at cold start from `token_public_key_url` (DER base64 inside JSON `{ "key": "..." }`). +- JWT tokens must be RS256 signed; current and previous public keys are fetched at cold start from `token_public_keys_url` as DER base64 values (list `keys[*].key`, with single-key fallback `{ "key": "..." }`). - Subject claim (`sub`) is matched against `ACCESS[topicName]`. - Authorization header forms accepted: - `Authorization: Bearer ` (preferred) diff --git a/conf/config.json b/conf/config.json index 84ab9f8..c35e61c 100644 --- a/conf/config.json +++ b/conf/config.json @@ -1,7 +1,7 @@ { "access_config": "s3:///access.json", "token_provider_url": "https://", - "token_public_key_url": "https://", + "token_public_keys_url": "https://", "kafka_bootstrap_server": "localhost:9092", "event_bus_arn": "arn:aws:events:" } \ No newline at end of file diff --git a/src/event_gate_lambda.py b/src/event_gate_lambda.py index 34dcb81..3480fff 100644 --- a/src/event_gate_lambda.py +++ b/src/event_gate_lambda.py @@ -28,6 +28,7 @@ import urllib3 from cryptography.exceptions import UnsupportedAlgorithm from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey from jsonschema import validate from jsonschema.exceptions import ValidationError @@ -80,17 +81,33 @@ ACCESS = json.load(file) logger.debug("Loaded ACCESS definitions") -TOKEN_PROVIDER_URL = CONFIG["token_provider_url"] -# Add timeout to avoid hanging requests; wrap in robust error handling so failures are explicit +# Initialize token public keys +TOKEN_PROVIDER_URL = CONFIG.get("token_provider_url") +TOKEN_PUBLIC_KEYS_URL = CONFIG.get("token_public_keys_url") or CONFIG.get("token_public_key_url") + try: - response_json = requests.get(CONFIG["token_public_key_url"], verify=False, timeout=5).json() # nosec external - token_public_key_encoded = response_json["key"] - TOKEN_PUBLIC_KEY: Any = serialization.load_der_public_key(base64.b64decode(token_public_key_encoded)) - logger.debug("Loaded TOKEN_PUBLIC_KEY") + response_json = requests.get(TOKEN_PUBLIC_KEYS_URL, verify=False, timeout=5).json() + raw_keys: list[str] = [] + if isinstance(response_json, dict): + if "keys" in response_json and isinstance(response_json["keys"], list): + for item in response_json["keys"]: + if "key" in item: + 
raw_keys.append(item["key"].strip()) + elif "key" in response_json: + raw_keys.append(response_json["key"].strip()) + + if not raw_keys: + raise KeyError(f"No public keys found in {TOKEN_PUBLIC_KEYS_URL} endpoint response") + + TOKEN_PUBLIC_KEYS: list[RSAPublicKey] = [ + serialization.load_der_public_key(base64.b64decode(raw_key)) for raw_key in raw_keys + ] + logger.debug("Loaded %d TOKEN_PUBLIC_KEYS", len(TOKEN_PUBLIC_KEYS)) except (requests.RequestException, ValueError, KeyError, UnsupportedAlgorithm) as exc: - logger.exception("Failed to fetch or deserialize token public key from %s", CONFIG.get("token_public_key_url")) + logger.exception("Failed to fetch or deserialize token public key from %s", TOKEN_PUBLIC_KEYS_URL) raise RuntimeError("Token public key initialization failed") from exc +# Initialize EventGate writers writer_eventbridge.init(logger, CONFIG) writer_kafka.init(logger, CONFIG) writer_postgres.init(logger) @@ -163,7 +180,7 @@ def post_topic_message(topic_name: str, topic_message: Dict[str, Any], token_enc """ logger.debug("Handling POST %s", topic_name) try: - token = jwt.decode(token_encoded, TOKEN_PUBLIC_KEY, algorithms=["RS256"]) # type: ignore[arg-type] + token = decode_jwt_all(token_encoded) except jwt.PyJWTError: # type: ignore[attr-defined] return _error_response(401, "auth", "Invalid or missing token") @@ -205,6 +222,20 @@ def post_topic_message(topic_name: str, topic_message: Dict[str, Any], token_enc } +def decode_jwt_all(token_encoded: str) -> Dict[str, Any]: + """Decode JWT using any of the loaded public keys. + + Args: + token_encoded: Encoded bearer JWT token string. + """ + for public_key in TOKEN_PUBLIC_KEYS: + try: + return jwt.decode(token_encoded, public_key, algorithms=["RS256"]) + except jwt.PyJWTError: + continue + raise jwt.PyJWTError("Verification failed for all public keys") + + def extract_token(event_headers: Dict[str, str]) -> str: """Extract bearer token from headers (case-insensitive). 
diff --git a/tests/utils/test_conf_validation.py b/tests/test_conf_validation.py similarity index 98% rename from tests/utils/test_conf_validation.py rename to tests/test_conf_validation.py index 7ed62ec..ef4ec1c 100644 --- a/tests/utils/test_conf_validation.py +++ b/tests/test_conf_validation.py @@ -19,12 +19,12 @@ from glob import glob import pytest -CONF_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "..", "conf") +CONF_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "conf") REQUIRED_CONFIG_KEYS = { "access_config", "token_provider_url", - "token_public_key_url", + "token_public_keys_url", "kafka_bootstrap_server", "event_bus_arn", } diff --git a/tests/test_event_gate_lambda.py b/tests/test_event_gate_lambda.py index 13ee3a1..28022ad 100644 --- a/tests/test_event_gate_lambda.py +++ b/tests/test_event_gate_lambda.py @@ -352,3 +352,54 @@ def test_post_invalid_json_body(event_gate_module, make_event): assert resp["statusCode"] == 500 body = json.loads(resp["body"]) assert any(e["type"] == "internal" for e in body["errors"]) # internal error path + + +def test_post_expired_token(event_gate_module, make_event, valid_payload): + """Expired JWT should yield 401 auth error.""" + with patch.object( + event_gate_module.jwt, + "decode", + side_effect=event_gate_module.jwt.ExpiredSignatureError("expired"), + create=True, + ): + event = make_event( + "/topics/{topic_name}", + method="POST", + topic="public.cps.za.test", + body=valid_payload, + headers={"Authorization": "Bearer expiredtoken"}, + ) + resp = event_gate_module.lambda_handler(event, None) + assert resp["statusCode"] == 401 + body = json.loads(resp["body"]) + assert any(e["type"] == "auth" for e in body["errors"]) + + +def test_decode_jwt_all_second_key_succeeds(event_gate_module): + """First key fails signature, second key succeeds; claims returned from second key.""" + first_key = object() + second_key = object() + event_gate_module.TOKEN_PUBLIC_KEYS = [first_key, second_key] + + def decode_side_effect(token, key, algorithms): + if key is first_key: + raise event_gate_module.jwt.PyJWTError("signature mismatch") + return {"sub": "TestUser"} + + with patch.object(event_gate_module.jwt, "decode", side_effect=decode_side_effect, create=True): + claims = event_gate_module.decode_jwt_all("dummy-token") + assert claims["sub"] == "TestUser" + + +def test_decode_jwt_all_all_keys_fail(event_gate_module): + """All keys fail; final PyJWTError with aggregate message is raised.""" + bad_keys = [object(), object()] + event_gate_module.TOKEN_PUBLIC_KEYS = bad_keys + + def always_fail(token, key, algorithms): + raise event_gate_module.jwt.PyJWTError("bad signature") + + with patch.object(event_gate_module.jwt, "decode", side_effect=always_fail, create=True): + with pytest.raises(event_gate_module.jwt.PyJWTError) as exc: + event_gate_module.decode_jwt_all("dummy-token") + assert "Verification failed for all public keys" in str(exc.value) From e00a6f70808ed90c014bc8042e815e94e7377d35 Mon Sep 17 00:00:00 2001 From: "Tobias.Mikula" Date: Wed, 12 Nov 2025 11:55:35 +0100 Subject: [PATCH 2/7] Mypy check fix. 
--- src/event_gate_lambda.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/event_gate_lambda.py b/src/event_gate_lambda.py index 3480fff..6fd8f55 100644 --- a/src/event_gate_lambda.py +++ b/src/event_gate_lambda.py @@ -20,7 +20,7 @@ import logging import os import sys -from typing import Any, Dict +from typing import Any, Dict, cast import boto3 import jwt @@ -100,7 +100,7 @@ raise KeyError(f"No public keys found in {TOKEN_PUBLIC_KEYS_URL} endpoint response") TOKEN_PUBLIC_KEYS: list[RSAPublicKey] = [ - serialization.load_der_public_key(base64.b64decode(raw_key)) for raw_key in raw_keys + cast(RSAPublicKey, serialization.load_der_public_key(base64.b64decode(raw_key))) for raw_key in raw_keys ] logger.debug("Loaded %d TOKEN_PUBLIC_KEYS", len(TOKEN_PUBLIC_KEYS)) except (requests.RequestException, ValueError, KeyError, UnsupportedAlgorithm) as exc: From d23819c05c3f66a699427cc7df69f2c7b510ed83 Mon Sep 17 00:00:00 2001 From: "Tobias.Mikula" Date: Tue, 25 Nov 2025 10:09:19 +0100 Subject: [PATCH 3/7] HandlerToken class created to manage token related operations. --- src/event_gate_lambda.py | 108 ++--------- src/handlers/__init__.py | 15 ++ src/handlers/handler_token.py | 142 +++++++++++++++ src/utils/constants.py | 24 +++ tests/conftest.py | 128 ++++++++++++++ tests/handlers/__init__.py | 15 ++ tests/handlers/test_handler_token.py | 91 ++++++++++ tests/test_event_gate_lambda.py | 177 +------------------ tests/test_event_gate_lambda_local_access.py | 2 +- tests/utils/test_extract_token.py | 72 -------- 10 files changed, 431 insertions(+), 343 deletions(-) create mode 100644 src/handlers/__init__.py create mode 100644 src/handlers/handler_token.py create mode 100644 src/utils/constants.py create mode 100644 tests/handlers/__init__.py create mode 100644 tests/handlers/test_handler_token.py delete mode 100644 tests/utils/test_extract_token.py diff --git a/src/event_gate_lambda.py b/src/event_gate_lambda.py index 6fd8f55..6298d46 100644 --- a/src/event_gate_lambda.py +++ b/src/event_gate_lambda.py @@ -15,23 +15,19 @@ # """Event Gate Lambda function implementation.""" -import base64 import json import logging import os import sys -from typing import Any, Dict, cast +from typing import Any, Dict import boto3 import jwt -import requests import urllib3 -from cryptography.exceptions import UnsupportedAlgorithm -from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey from jsonschema import validate from jsonschema.exceptions import ValidationError +from src.handlers.handler_token import HandlerToken from src.writers import writer_eventbridge, writer_kafka, writer_postgres from src.utils.conf_path import CONF_DIR, INVALID_CONF_ENV @@ -65,51 +61,28 @@ logger.debug("Loaded TOPICS") with open(os.path.join(_CONF_DIR, "config.json"), "r", encoding="utf-8") as file: - CONFIG = json.load(file) + config = json.load(file) logger.debug("Loaded main CONFIG") aws_s3 = boto3.Session().resource("s3", verify=False) # nosec Boto verify disabled intentionally logger.debug("Initialized AWS S3 Client") -if CONFIG["access_config"].startswith("s3://"): - name_parts = CONFIG["access_config"].split("/") +if config["access_config"].startswith("s3://"): + name_parts = config["access_config"].split("/") BUCKET_NAME = name_parts[2] BUCKET_OBJECT_KEY = "/".join(name_parts[3:]) ACCESS = json.loads(aws_s3.Bucket(BUCKET_NAME).Object(BUCKET_OBJECT_KEY).get()["Body"].read().decode("utf-8")) else: - with open(CONFIG["access_config"], "r", 
encoding="utf-8") as file: + with open(config["access_config"], "r", encoding="utf-8") as file: ACCESS = json.load(file) logger.debug("Loaded ACCESS definitions") -# Initialize token public keys -TOKEN_PROVIDER_URL = CONFIG.get("token_provider_url") -TOKEN_PUBLIC_KEYS_URL = CONFIG.get("token_public_keys_url") or CONFIG.get("token_public_key_url") - -try: - response_json = requests.get(TOKEN_PUBLIC_KEYS_URL, verify=False, timeout=5).json() - raw_keys: list[str] = [] - if isinstance(response_json, dict): - if "keys" in response_json and isinstance(response_json["keys"], list): - for item in response_json["keys"]: - if "key" in item: - raw_keys.append(item["key"].strip()) - elif "key" in response_json: - raw_keys.append(response_json["key"].strip()) - - if not raw_keys: - raise KeyError(f"No public keys found in {TOKEN_PUBLIC_KEYS_URL} endpoint response") - - TOKEN_PUBLIC_KEYS: list[RSAPublicKey] = [ - cast(RSAPublicKey, serialization.load_der_public_key(base64.b64decode(raw_key))) for raw_key in raw_keys - ] - logger.debug("Loaded %d TOKEN_PUBLIC_KEYS", len(TOKEN_PUBLIC_KEYS)) -except (requests.RequestException, ValueError, KeyError, UnsupportedAlgorithm) as exc: - logger.exception("Failed to fetch or deserialize token public key from %s", TOKEN_PUBLIC_KEYS_URL) - raise RuntimeError("Token public key initialization failed") from exc +# Initialize token handler and load token public keys +handler_token = HandlerToken(config).load_public_keys() # Initialize EventGate writers -writer_eventbridge.init(logger, CONFIG) -writer_kafka.init(logger, CONFIG) +writer_eventbridge.init(logger, config) +writer_kafka.init(logger, config) writer_postgres.init(logger) @@ -141,12 +114,6 @@ def get_api() -> Dict[str, Any]: return {"statusCode": 200, "body": API} -def get_token() -> Dict[str, Any]: - """Return 303 redirect to token provider endpoint.""" - logger.debug("Handling GET Token") - return {"statusCode": 303, "headers": {"Location": TOKEN_PROVIDER_URL}} - - def get_topics() -> Dict[str, Any]: """Return list of available topic names.""" logger.debug("Handling GET Topics") @@ -180,7 +147,7 @@ def post_topic_message(topic_name: str, topic_message: Dict[str, Any], token_enc """ logger.debug("Handling POST %s", topic_name) try: - token = decode_jwt_all(token_encoded) + token: Dict[str, Any] = handler_token.decode_jwt(token_encoded) except jwt.PyJWTError: # type: ignore[attr-defined] return _error_response(401, "auth", "Invalid or missing token") @@ -222,55 +189,6 @@ def post_topic_message(topic_name: str, topic_message: Dict[str, Any], token_enc } -def decode_jwt_all(token_encoded: str) -> Dict[str, Any]: - """Decode JWT using any of the loaded public keys. - - Args: - token_encoded: Encoded bearer JWT token string. - """ - for public_key in TOKEN_PUBLIC_KEYS: - try: - return jwt.decode(token_encoded, public_key, algorithms=["RS256"]) - except jwt.PyJWTError: - continue - raise jwt.PyJWTError("Verification failed for all public keys") - - -def extract_token(event_headers: Dict[str, str]) -> str: - """Extract bearer token from headers (case-insensitive). - - Supports: - - Custom 'bearer' header (any casing) whose value is the raw token - - Standard 'Authorization: Bearer ' header (case-insensitive scheme & key) - Returns empty string if token not found or malformed. 
- """ - if not event_headers: - return "" - - # Normalize keys to lowercase for case-insensitive lookup - lowered = {str(k).lower(): v for k, v in event_headers.items()} - - # Direct bearer header (raw token) - if "bearer" in lowered and isinstance(lowered["bearer"], str): - token_candidate = lowered["bearer"].strip() - if token_candidate: - return token_candidate - - # Authorization header with Bearer scheme - auth_val = lowered.get("authorization", "") - if not isinstance(auth_val, str): # defensive - return "" - auth_val = auth_val.strip() - if not auth_val: - return "" - - # Case-insensitive match for 'Bearer ' prefix - if not auth_val.lower().startswith("bearer "): - return "" - token_part = auth_val[7:].strip() # len('Bearer ')==7 - return token_part - - def lambda_handler(event: Dict[str, Any], context: Any): # pylint: disable=unused-argument,too-many-return-statements """AWS Lambda entry point. @@ -281,7 +199,7 @@ def lambda_handler(event: Dict[str, Any], context: Any): # pylint: disable=unus if resource == "/api": return get_api() if resource == "/token": - return get_token() + return handler_token.get_token() if resource == "/topics": return get_topics() if resource == "/topics/{topic_name}": @@ -292,7 +210,7 @@ def lambda_handler(event: Dict[str, Any], context: Any): # pylint: disable=unus return post_topic_message( event["pathParameters"]["topic_name"].lower(), json.loads(event["body"]), - extract_token(event.get("headers", {})), + handler_token.extract_token(event.get("headers", {})), ) if resource == "/terminate": sys.exit("TERMINATING") # pragma: no cover - deliberate termination path diff --git a/src/handlers/__init__.py b/src/handlers/__init__.py new file mode 100644 index 0000000..f7115cb --- /dev/null +++ b/src/handlers/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2025 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/src/handlers/handler_token.py b/src/handlers/handler_token.py new file mode 100644 index 0000000..0bdcdc1 --- /dev/null +++ b/src/handlers/handler_token.py @@ -0,0 +1,142 @@ +# +# Copyright 2025 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +This module provides the HandlerToken class for managing the token related operations. 
+""" + +import base64 +import logging +import os +from typing import Dict, Any, cast + +import jwt +import requests +from cryptography.exceptions import UnsupportedAlgorithm +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey + +from src.utils.constants import TOKEN_PROVIDER_URL, TOKEN_PUBLIC_KEYS_URL, TOKEN_PUBLIC_KEY_URL + +logger = logging.getLogger(__name__) +log_level = os.environ.get("LOG_LEVEL", "INFO") +logger.setLevel(log_level) + + +class HandlerToken: + """ + HandlerToken manages token provider URL and public keys for JWT verification. + """ + + def __init__(self, config): + self.provider_url: str = config.get(TOKEN_PROVIDER_URL, "") + self.public_keys_url: str = config.get(TOKEN_PUBLIC_KEYS_URL) or config.get(TOKEN_PUBLIC_KEY_URL) + self.public_keys: list[RSAPublicKey] = [] + + def load_public_keys(self) -> "HandlerToken": + """ + Load token public keys from the configured URL. + Returns: + HandlerToken: The current instance with loaded public keys. + Raises: + RuntimeError: If fetching or deserializing the public keys fails. + """ + logger.debug("Loading token public keys from %s", self.public_keys_url) + + try: + response_json = requests.get(self.public_keys_url, verify=False, timeout=5).json() + raw_keys: list[str] = [] + + if isinstance(response_json, dict): + if "keys" in response_json and isinstance(response_json["keys"], list): + for item in response_json["keys"]: + if "key" in item: + raw_keys.append(item["key"].strip()) + elif "key" in response_json: + raw_keys.append(response_json["key"].strip()) + + if not raw_keys: + raise KeyError(f"No public keys found in {self.public_keys_url} endpoint response") + + self.public_keys = [ + cast(RSAPublicKey, serialization.load_der_public_key(base64.b64decode(raw_key))) for raw_key in raw_keys + ] + logger.debug("Loaded %d token public keys", len(self.public_keys)) + + return self + except (requests.RequestException, ValueError, KeyError, UnsupportedAlgorithm) as exc: + logger.exception("Failed to fetch or deserialize token public key from %s", self.public_keys_url) + raise RuntimeError("Token public key initialization failed") from exc + + def decode_jwt(self, token_encoded: str) -> Dict[str, Any]: + """ + Decode and verify a JWT using the loaded public keys. + Args: + token_encoded (str): The encoded JWT token. + Returns: + Dict[str, Any]: The decoded JWT payload. + Raises: + jwt.PyJWTError: If verification fails for all public keys. + """ + logger.debug("Decoding JWT") + for public_key in self.public_keys: + try: + return jwt.decode(token_encoded, public_key, algorithms=["RS256"]) + except jwt.PyJWTError: + continue + raise jwt.PyJWTError("Verification failed for all public keys") + + def get_token(self) -> Dict[str, Any]: + """ + Returns: A 303 redirect response to the token provider URL. + """ + logger.debug("Handling GET Token") + return {"statusCode": 303, "headers": {"Location": self.provider_url}} + + @staticmethod + def extract_token(event_headers: Dict[str, str]) -> str: + """ + Extracts the bearer (custom/standard) token from event headers. + Args: + event_headers (Dict[str, str]): The event headers. + Returns: + str: The extracted bearer token, or an empty string if not found. 
+ """ + if not event_headers: + return "" + + # Normalize keys to lowercase for case-insensitive lookup + lowered = {str(k).lower(): v for k, v in event_headers.items()} + + # Direct bearer header (raw token) + if "bearer" in lowered and isinstance(lowered["bearer"], str): + token_candidate = lowered["bearer"].strip() + if token_candidate: + return token_candidate + + # Authorization header with Bearer scheme + auth_val = lowered.get("authorization", "") + if not isinstance(auth_val, str): # defensive + return "" + auth_val = auth_val.strip() + if not auth_val: + return "" + + # Case-insensitive match for 'Bearer ' prefix + if not auth_val.lower().startswith("bearer "): + return "" + token_part = auth_val[7:].strip() # len('Bearer ')==7 + return token_part diff --git a/src/utils/constants.py b/src/utils/constants.py new file mode 100644 index 0000000..1affe7c --- /dev/null +++ b/src/utils/constants.py @@ -0,0 +1,24 @@ +# +# Copyright 2025 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +This module contains all constants and enums used across the project. +""" + +# Token related constants +TOKEN_PROVIDER_URL = "token_provider_url" +TOKEN_PUBLIC_KEY_URL = "token_public_key_url" +TOKEN_PUBLIC_KEYS_URL = "token_public_keys_url" diff --git a/tests/conftest.py b/tests/conftest.py index 3ee2e24..52e5cac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,8 +14,136 @@ # limitations under the License. # import os, sys +import base64 +import importlib +import json +import types +from contextlib import ExitStack +from unittest.mock import MagicMock, patch + +import pytest + # Ensure project root is on sys.path so 'src' package is importable during tests PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) if PROJECT_ROOT not in sys.path: sys.path.insert(0, PROJECT_ROOT) + + +@pytest.fixture(scope="module") +def event_gate_module(): + """Import `src.event_gate_lambda` with external deps patched/mocked. + + This fixture centralises the heavy environment setup shared by + multiple test modules. 
+ """ + started_patches = [] + exit_stack = ExitStack() + + def start_patch(target: str): + p = patch(target) + started_patches.append(p) + return p.start() + + # Local, temporary dummy modules only if truly missing + if importlib.util.find_spec("confluent_kafka") is None: # pragma: no cover - environment dependent + dummy_ck = types.ModuleType("confluent_kafka") + + class DummyProducer: # minimal interface + def __init__(self, *_, **__): + pass + + def produce(self, *_, **kwargs): + cb = kwargs.get("callback") + if cb: + cb(None, None) + + def flush(self): # noqa: D401 - simple stub + return None + + dummy_ck.Producer = DummyProducer # type: ignore[attr-defined] + + class DummyKafkaException(Exception): + pass + + dummy_ck.KafkaException = DummyKafkaException # type: ignore[attr-defined] + exit_stack.enter_context(patch.dict(sys.modules, {"confluent_kafka": dummy_ck})) + + if importlib.util.find_spec("psycopg2") is None: # pragma: no cover - environment dependent + dummy_pg = types.ModuleType("psycopg2") + exit_stack.enter_context(patch.dict(sys.modules, {"psycopg2": dummy_pg})) + + mock_requests_get = start_patch("requests.get") + mock_requests_get.return_value.json.return_value = {"key": base64.b64encode(b"dummy_der").decode("utf-8")} + + mock_load_key = start_patch("cryptography.hazmat.primitives.serialization.load_der_public_key") + mock_load_key.return_value = object() + + class MockS3ObjectBody: + def read(self): + return json.dumps( + { + "public.cps.za.runs": ["FooBarUser"], + "public.cps.za.dlchange": ["FooUser", "BarUser"], + "public.cps.za.test": ["TestUser"], + } + ).encode("utf-8") + + class MockS3Object: + def get(self): + return {"Body": MockS3ObjectBody()} + + class MockS3Bucket: + def Object(self, _key): # noqa: D401 - simple proxy + return MockS3Object() + + class MockS3Resource: + def Bucket(self, _name): # noqa: D401 - simple proxy + return MockS3Bucket() + + mock_session = start_patch("boto3.Session") + mock_session.return_value.resource.return_value = MockS3Resource() + + mock_boto_client = start_patch("boto3.client") + mock_events_client = MagicMock() + mock_events_client.put_events.return_value = {"FailedEntryCount": 0} + mock_boto_client.return_value = mock_events_client + + # Allow kafka producer patching (already stubbed) but still patch to inspect if needed + start_patch("confluent_kafka.Producer") + + module = importlib.import_module("src.event_gate_lambda") + + yield module + + for p in started_patches: + p.stop() + exit_stack.close() + + +@pytest.fixture +def make_event(): + """Build a minimal API Gateway-style event dict for tests.""" + + def _make(resource, method="GET", body=None, topic=None, headers=None): + return { + "resource": resource, + "httpMethod": method, + "headers": headers or {}, + "pathParameters": {"topic_name": topic} if topic else {}, + "body": json.dumps(body) if isinstance(body, dict) else body, + } + + return _make + + +@pytest.fixture +def valid_payload(): + """A canonical valid payload used across tests.""" + return { + "event_id": "e1", + "tenant_id": "t1", + "source_app": "app", + "environment": "dev", + "timestamp": 123, + } diff --git a/tests/handlers/__init__.py b/tests/handlers/__init__.py new file mode 100644 index 0000000..f7115cb --- /dev/null +++ b/tests/handlers/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2025 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/handlers/test_handler_token.py b/tests/handlers/test_handler_token.py new file mode 100644 index 0000000..801a20b --- /dev/null +++ b/tests/handlers/test_handler_token.py @@ -0,0 +1,91 @@ +# +# Copyright 2025 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json +from unittest.mock import patch + +import pytest + +from src.handlers.handler_token import HandlerToken + + +def test_get_token_endpoint(event_gate_module, make_event): + event = make_event("/token") + resp = event_gate_module.lambda_handler(event, None) + assert resp["statusCode"] == 303 + assert "Location" in resp["headers"] + + +def test_post_expired_token(event_gate_module, make_event, valid_payload): + """Expired JWT should yield 401 auth error.""" + + with patch.object( + event_gate_module.jwt, + "decode", + side_effect=event_gate_module.jwt.ExpiredSignatureError("expired"), + create=True, + ): + event = make_event( + "/topics/{topic_name}", + method="POST", + topic="public.cps.za.test", + body=valid_payload, + headers={"Authorization": "Bearer expiredtoken"}, + ) + resp = event_gate_module.lambda_handler(event, None) + assert resp["statusCode"] == 401 + body = json.loads(resp["body"]) + assert any(e["type"] == "auth" for e in body["errors"]) + + +def test_decode_jwt_all_second_key_succeeds(event_gate_module): + """First key fails signature, second key succeeds; claims returned from second key.""" + # Arrange: two dummy public keys + first_key = object() + second_key = object() + event_gate_module.handler_token.public_keys = [first_key, second_key] + + def decode_side_effect(token, key, algorithms): # noqa: D401 - test stub + if key is first_key: + raise event_gate_module.jwt.PyJWTError("signature mismatch") + return {"sub": "TestUser"} + + with patch.object(event_gate_module.jwt, "decode", side_effect=decode_side_effect, create=True): + claims = event_gate_module.handler_token.decode_jwt("dummy-token") + assert claims["sub"] == "TestUser" + + +def test_decode_jwt_all_all_keys_fail(event_gate_module): + """All keys fail; final PyJWTError with aggregate message is raised.""" + bad_keys = [object(), object()] + event_gate_module.handler_token.public_keys = bad_keys + + def always_fail(token, key, algorithms): # noqa: D401 - test stub + raise event_gate_module.jwt.PyJWTError("bad signature") + + with patch.object(event_gate_module.jwt, "decode", side_effect=always_fail, create=True): + with pytest.raises(event_gate_module.jwt.PyJWTError) as exc: + event_gate_module.handler_token.decode_jwt("dummy-token") + assert "Verification failed for all public keys" in str(exc.value) 
+ + +def test_extract_token_empty(): + assert HandlerToken.extract_token({}) == "" + + +def test_extract_token_direct_bearer_header(): + token = HandlerToken.extract_token({"Bearer": " tok123 "}) + assert token == "tok123" diff --git a/tests/test_event_gate_lambda.py b/tests/test_event_gate_lambda.py index 28022ad..e6683b6 100644 --- a/tests/test_event_gate_lambda.py +++ b/tests/test_event_gate_lambda.py @@ -13,124 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import base64 -import json -import importlib -import sys -import types -import pytest -from unittest.mock import patch, MagicMock -import importlib.util -from contextlib import ExitStack - - -@pytest.fixture(scope="module") -def event_gate_module(): - started_patches = [] - exit_stack = ExitStack() - - def start_patch(target): - p = patch(target) - started_patches.append(p) - return p.start() - - # Local, temporary dummy modules only if truly missing - # confluent_kafka - if importlib.util.find_spec("confluent_kafka") is None: # pragma: no cover - environment dependent - dummy_ck = types.ModuleType("confluent_kafka") - - class DummyProducer: # minimal interface - def __init__(self, *a, **kw): - pass - - def produce(self, *a, **kw): - cb = kw.get("callback") - if cb: - cb(None, None) - - def flush(self): # noqa: D401 - simple stub - return None - - dummy_ck.Producer = DummyProducer # type: ignore[attr-defined] - - class DummyKafkaException(Exception): - pass - - dummy_ck.KafkaException = DummyKafkaException # type: ignore[attr-defined] - exit_stack.enter_context(patch.dict(sys.modules, {"confluent_kafka": dummy_ck})) - - # psycopg2 optional dependency - if importlib.util.find_spec("psycopg2") is None: # pragma: no cover - environment dependent - dummy_pg = types.ModuleType("psycopg2") - exit_stack.enter_context(patch.dict(sys.modules, {"psycopg2": dummy_pg})) - - mock_requests_get = start_patch("requests.get") - mock_requests_get.return_value.json.return_value = {"key": base64.b64encode(b"dummy_der").decode("utf-8")} - - mock_load_key = start_patch("cryptography.hazmat.primitives.serialization.load_der_public_key") - mock_load_key.return_value = object() - - # Mock S3 access_config retrieval - class MockS3ObjectBody: - def read(self): - return json.dumps( - { - "public.cps.za.runs": ["FooBarUser"], - "public.cps.za.dlchange": ["FooUser", "BarUser"], - "public.cps.za.test": ["TestUser"], - } - ).encode("utf-8") - - class MockS3Object: - def get(self): - return {"Body": MockS3ObjectBody()} - - class MockS3Bucket: - def Object(self, key): # noqa: D401 - simple proxy - return MockS3Object() - - class MockS3Resource: - def Bucket(self, name): # noqa: D401 - simple proxy - return MockS3Bucket() - - mock_session = start_patch("boto3.Session") - mock_session.return_value.resource.return_value = MockS3Resource() - - # Mock EventBridge client - mock_boto_client = start_patch("boto3.client") - mock_events_client = MagicMock() - mock_events_client.put_events.return_value = {"FailedEntryCount": 0} - mock_boto_client.return_value = mock_events_client - - # Allow kafka producer patching (already stubbed) but still patch to inspect if needed - start_patch("confluent_kafka.Producer") - - module = importlib.import_module("src.event_gate_lambda") - - yield module - - for p in started_patches: - p.stop() - exit_stack.close() - - -@pytest.fixture -def make_event(): - def _make(resource, method="GET", body=None, topic=None, headers=None): - return { - "resource": resource, - "httpMethod": 
method, - "headers": headers or {}, - "pathParameters": {"topic_name": topic} if topic else {}, - "body": json.dumps(body) if isinstance(body, dict) else body, - } - - return _make - -@pytest.fixture -def valid_payload(): - return {"event_id": "e1", "tenant_id": "t1", "source_app": "app", "environment": "dev", "timestamp": 123} +import json +from unittest.mock import patch # --- GET flows --- @@ -323,13 +208,6 @@ def test_get_api_endpoint(event_gate_module, make_event): assert "openapi" in resp["body"].lower() -def test_get_token_endpoint(event_gate_module, make_event): - event = make_event("/token") - resp = event_gate_module.lambda_handler(event, None) - assert resp["statusCode"] == 303 - assert "Location" in resp["headers"] - - def test_internal_error_path(event_gate_module, make_event): with patch("src.event_gate_lambda.get_topics", side_effect=RuntimeError("boom")): event = make_event("/topics") @@ -352,54 +230,3 @@ def test_post_invalid_json_body(event_gate_module, make_event): assert resp["statusCode"] == 500 body = json.loads(resp["body"]) assert any(e["type"] == "internal" for e in body["errors"]) # internal error path - - -def test_post_expired_token(event_gate_module, make_event, valid_payload): - """Expired JWT should yield 401 auth error.""" - with patch.object( - event_gate_module.jwt, - "decode", - side_effect=event_gate_module.jwt.ExpiredSignatureError("expired"), - create=True, - ): - event = make_event( - "/topics/{topic_name}", - method="POST", - topic="public.cps.za.test", - body=valid_payload, - headers={"Authorization": "Bearer expiredtoken"}, - ) - resp = event_gate_module.lambda_handler(event, None) - assert resp["statusCode"] == 401 - body = json.loads(resp["body"]) - assert any(e["type"] == "auth" for e in body["errors"]) - - -def test_decode_jwt_all_second_key_succeeds(event_gate_module): - """First key fails signature, second key succeeds; claims returned from second key.""" - first_key = object() - second_key = object() - event_gate_module.TOKEN_PUBLIC_KEYS = [first_key, second_key] - - def decode_side_effect(token, key, algorithms): - if key is first_key: - raise event_gate_module.jwt.PyJWTError("signature mismatch") - return {"sub": "TestUser"} - - with patch.object(event_gate_module.jwt, "decode", side_effect=decode_side_effect, create=True): - claims = event_gate_module.decode_jwt_all("dummy-token") - assert claims["sub"] == "TestUser" - - -def test_decode_jwt_all_all_keys_fail(event_gate_module): - """All keys fail; final PyJWTError with aggregate message is raised.""" - bad_keys = [object(), object()] - event_gate_module.TOKEN_PUBLIC_KEYS = bad_keys - - def always_fail(token, key, algorithms): - raise event_gate_module.jwt.PyJWTError("bad signature") - - with patch.object(event_gate_module.jwt, "decode", side_effect=always_fail, create=True): - with pytest.raises(event_gate_module.jwt.PyJWTError) as exc: - event_gate_module.decode_jwt_all("dummy-token") - assert "Verification failed for all public keys" in str(exc.value) diff --git a/tests/test_event_gate_lambda_local_access.py b/tests/test_event_gate_lambda_local_access.py index 525cdcb..714e4a8 100644 --- a/tests/test_event_gate_lambda_local_access.py +++ b/tests/test_event_gate_lambda_local_access.py @@ -54,5 +54,5 @@ def Bucket(self, name): # noqa: D401 # Force reload so import-level logic re-executes with patched open egl_reloaded = importlib.reload(egl) - assert not egl_reloaded.CONFIG["access_config"].startswith("s3://") # type: ignore[attr-defined] + assert not 
egl_reloaded.config["access_config"].startswith("s3://") # type: ignore[attr-defined] assert egl_reloaded.ACCESS["public.cps.za.test"] == ["User"] # type: ignore[attr-defined] diff --git a/tests/utils/test_extract_token.py b/tests/utils/test_extract_token.py deleted file mode 100644 index 5a896a7..0000000 --- a/tests/utils/test_extract_token.py +++ /dev/null @@ -1,72 +0,0 @@ -import base64 -import json -import importlib -from unittest.mock import patch, MagicMock -import pytest - - -@pytest.fixture(scope="module") -def egl_mod(): - patches = [] - - def start(target): - p = patch(target) - patches.append(p) - return p.start() - - # Patch requests.get for public key fetch - mock_get = start("requests.get") - mock_get.return_value.json.return_value = {"key": base64.b64encode(b"dummy_der").decode("utf-8")} - - # Patch crypto key loader - start_loader = start("cryptography.hazmat.primitives.serialization.load_der_public_key") - start_loader.return_value = object() - - # Patch boto3.Session resource for S3 to avoid bucket validation/network - class MockBody: - def read(self): - return json.dumps( - { - "public.cps.za.runs": ["User"], - "public.cps.za.dlchange": ["User"], - "public.cps.za.test": ["User"], - } - ).encode("utf-8") - - class MockObject: - def get(self): - return {"Body": MockBody()} - - class MockBucket: - def Object(self, key): # noqa: D401 - return MockObject() - - class MockS3: - def Bucket(self, name): # noqa: D401 - return MockBucket() - - mock_session = start("boto3.Session") - mock_session.return_value.resource.return_value = MockS3() - - # Patch boto3.client for EventBridge - mock_client = start("boto3.client") - mock_events = MagicMock() - mock_events.put_events.return_value = {"FailedEntryCount": 0} - mock_client.return_value = mock_events - - # Patch Kafka Producer - start("confluent_kafka.Producer") - - module = importlib.import_module("src.event_gate_lambda") - yield module - - for p in patches: - p.stop() - - -def test_extract_token_empty(egl_mod): - assert egl_mod.extract_token({}) == "" - - -def test_extract_token_direct_bearer_header(egl_mod): - token = egl_mod.extract_token({"Bearer": " tok123 "}) From a0295d45d0f299106ca2105c62ef31a420549793 Mon Sep 17 00:00:00 2001 From: Tobias Mikula <72911271+tmikula-dev@users.noreply.github.com> Date: Mon, 1 Dec 2025 14:22:25 +0100 Subject: [PATCH 4/7] Periodically refresh JWT public key set (#88) * Periodic JWT Public Key Refresh Implementation --- src/handlers/handler_token.py | 25 +++++++++++++ tests/handlers/test_handler_token.py | 55 +++++++++++++++++++++++++++- 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/src/handlers/handler_token.py b/src/handlers/handler_token.py index 0bdcdc1..f0a9fe0 100644 --- a/src/handlers/handler_token.py +++ b/src/handlers/handler_token.py @@ -21,6 +21,7 @@ import base64 import logging import os +from datetime import datetime, timedelta, timezone from typing import Dict, Any, cast import jwt @@ -41,10 +42,31 @@ class HandlerToken: HandlerToken manages token provider URL and public keys for JWT verification. """ + _REFRESH_INTERVAL = timedelta(minutes=28) + def __init__(self, config): self.provider_url: str = config.get(TOKEN_PROVIDER_URL, "") self.public_keys_url: str = config.get(TOKEN_PUBLIC_KEYS_URL) or config.get(TOKEN_PUBLIC_KEY_URL) self.public_keys: list[RSAPublicKey] = [] + self._last_loaded_at: datetime | None = None + + def _refresh_keys_if_needed(self) -> None: + """ + Refresh the public keys if the refresh interval has passed. 
+ """ + logger.debug("Checking if the token public keys need refresh") + + if self._last_loaded_at is None: + return + now = datetime.now(timezone.utc) + if now - self._last_loaded_at < self._REFRESH_INTERVAL: + logger.debug("Token public keys are up to date, no refresh needed") + return + try: + logger.debug("Token public keys are stale, refreshing now") + self.load_public_keys() + except RuntimeError: + logger.warning("Token public key refresh failed, using existing keys") def load_public_keys(self) -> "HandlerToken": """ @@ -75,6 +97,7 @@ def load_public_keys(self) -> "HandlerToken": cast(RSAPublicKey, serialization.load_der_public_key(base64.b64decode(raw_key))) for raw_key in raw_keys ] logger.debug("Loaded %d token public keys", len(self.public_keys)) + self._last_loaded_at = datetime.now(timezone.utc) return self except (requests.RequestException, ValueError, KeyError, UnsupportedAlgorithm) as exc: @@ -91,6 +114,8 @@ def decode_jwt(self, token_encoded: str) -> Dict[str, Any]: Raises: jwt.PyJWTError: If verification fails for all public keys. """ + self._refresh_keys_if_needed() + logger.debug("Decoding JWT") for public_key in self.public_keys: try: diff --git a/tests/handlers/test_handler_token.py b/tests/handlers/test_handler_token.py index 801a20b..d477636 100644 --- a/tests/handlers/test_handler_token.py +++ b/tests/handlers/test_handler_token.py @@ -15,13 +15,22 @@ # import json -from unittest.mock import patch +from datetime import datetime, timedelta, timezone +from unittest.mock import patch, Mock import pytest +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey from src.handlers.handler_token import HandlerToken +@pytest.fixture +def token_handler(): + """Create a HandlerToken instance for testing.""" + config = {"token_public_keys_url": "https://example.com/keys"} + return HandlerToken(config) + + def test_get_token_endpoint(event_gate_module, make_event): event = make_event("/token") resp = event_gate_module.lambda_handler(event, None) @@ -89,3 +98,47 @@ def test_extract_token_empty(): def test_extract_token_direct_bearer_header(): token = HandlerToken.extract_token({"Bearer": " tok123 "}) assert token == "tok123" + + +## Checking the freshness of public keys +def test_refresh_keys_not_needed_when_keys_fresh(token_handler): + """Keys loaded less than 30 minutes ago should not trigger refresh.""" + token_handler._last_loaded_at = datetime.now(timezone.utc) - timedelta(minutes=10) + token_handler.public_keys = [Mock(spec=RSAPublicKey)] + + with patch.object(token_handler, "load_public_keys") as mock_load: + token_handler._refresh_keys_if_needed() + mock_load.assert_not_called() + + +def test_refresh_keys_triggered_when_keys_stale(token_handler): + """Keys loaded more than 30 minutes ago should trigger refresh.""" + token_handler._last_loaded_at = datetime.now(timezone.utc) - timedelta(minutes=29) + token_handler.public_keys = [Mock(spec=RSAPublicKey)] + + with patch.object(token_handler, "load_public_keys") as mock_load: + token_handler._refresh_keys_if_needed() + mock_load.assert_called_once() + + +def test_refresh_keys_handles_load_failure_gracefully(token_handler): + """If key refresh fails, should log warning and continue with existing keys.""" + old_key = Mock(spec=RSAPublicKey) + token_handler.public_keys = [old_key] + token_handler._last_loaded_at = datetime.now(timezone.utc) - timedelta(minutes=29) + + with patch.object(token_handler, "load_public_keys", side_effect=RuntimeError("Network error")): + token_handler._refresh_keys_if_needed() + assert 
token_handler.public_keys == [old_key] + + +def test_decode_jwt_triggers_refresh_check(token_handler): + """Decoding JWT should check if keys need refresh before decoding.""" + dummy_key = Mock(spec=RSAPublicKey) + token_handler.public_keys = [dummy_key] + token_handler._last_loaded_at = datetime.now(timezone.utc) - timedelta(minutes=10) + + with patch.object(token_handler, "_refresh_keys_if_needed") as mock_refresh: + with patch("jwt.decode", return_value={"sub": "TestUser"}): + token_handler.decode_jwt("dummy-token") + mock_refresh.assert_called_once() From 788a90aa28f53cf9abe5003deeaee82cb5bf8e4a Mon Sep 17 00:00:00 2001 From: "Tobias.Mikula" Date: Mon, 1 Dec 2025 14:57:45 +0100 Subject: [PATCH 5/7] Recommended changes implementation --- src/event_gate_lambda.py | 2 +- src/handlers/handler_token.py | 8 ++++---- src/utils/constants.py | 8 ++++---- tests/conftest.py | 4 ++-- tests/handlers/test_handler_token.py | 4 ++-- tests/test_event_gate_lambda_local_access.py | 3 ++- 6 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/event_gate_lambda.py b/src/event_gate_lambda.py index 6298d46..33d4203 100644 --- a/src/event_gate_lambda.py +++ b/src/event_gate_lambda.py @@ -199,7 +199,7 @@ def lambda_handler(event: Dict[str, Any], context: Any): # pylint: disable=unus if resource == "/api": return get_api() if resource == "/token": - return handler_token.get_token() + return handler_token.get_token_provider_info() if resource == "/topics": return get_topics() if resource == "/topics/{topic_name}": diff --git a/src/handlers/handler_token.py b/src/handlers/handler_token.py index f0a9fe0..44579f1 100644 --- a/src/handlers/handler_token.py +++ b/src/handlers/handler_token.py @@ -30,7 +30,7 @@ from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey -from src.utils.constants import TOKEN_PROVIDER_URL, TOKEN_PUBLIC_KEYS_URL, TOKEN_PUBLIC_KEY_URL +from src.utils.constants import TOKEN_PROVIDER_URL_KEY, TOKEN_PUBLIC_KEYS_URL_KEY, TOKEN_PUBLIC_KEY_URL_KEY logger = logging.getLogger(__name__) log_level = os.environ.get("LOG_LEVEL", "INFO") @@ -45,8 +45,8 @@ class HandlerToken: _REFRESH_INTERVAL = timedelta(minutes=28) def __init__(self, config): - self.provider_url: str = config.get(TOKEN_PROVIDER_URL, "") - self.public_keys_url: str = config.get(TOKEN_PUBLIC_KEYS_URL) or config.get(TOKEN_PUBLIC_KEY_URL) + self.provider_url: str = config.get(TOKEN_PROVIDER_URL_KEY, "") + self.public_keys_url: str = config.get(TOKEN_PUBLIC_KEYS_URL_KEY) or config.get(TOKEN_PUBLIC_KEY_URL_KEY) self.public_keys: list[RSAPublicKey] = [] self._last_loaded_at: datetime | None = None @@ -124,7 +124,7 @@ def decode_jwt(self, token_encoded: str) -> Dict[str, Any]: continue raise jwt.PyJWTError("Verification failed for all public keys") - def get_token(self) -> Dict[str, Any]: + def get_token_provider_info(self) -> Dict[str, Any]: """ Returns: A 303 redirect response to the token provider URL. """ diff --git a/src/utils/constants.py b/src/utils/constants.py index 1affe7c..0d1eddb 100644 --- a/src/utils/constants.py +++ b/src/utils/constants.py @@ -18,7 +18,7 @@ This module contains all constants and enums used across the project. 
""" -# Token related constants -TOKEN_PROVIDER_URL = "token_provider_url" -TOKEN_PUBLIC_KEY_URL = "token_public_key_url" -TOKEN_PUBLIC_KEYS_URL = "token_public_keys_url" +# Token related configuration keys +TOKEN_PROVIDER_URL_KEY = "token_provider_url" +TOKEN_PUBLIC_KEY_URL_KEY = "token_public_key_url" +TOKEN_PUBLIC_KEYS_URL_KEY = "token_public_keys_url" diff --git a/tests/conftest.py b/tests/conftest.py index 52e5cac..e72da94 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -109,8 +109,8 @@ def Bucket(self, _name): # noqa: D401 - simple proxy mock_events_client.put_events.return_value = {"FailedEntryCount": 0} mock_boto_client.return_value = mock_events_client - # Allow kafka producer patching (already stubbed) but still patch to inspect if needed - start_patch("confluent_kafka.Producer") + mock_kafka_producer = start_patch("confluent_kafka.Producer") + mock_kafka_producer.return_value = MagicMock() module = importlib.import_module("src.event_gate_lambda") diff --git a/tests/handlers/test_handler_token.py b/tests/handlers/test_handler_token.py index d477636..b519be4 100644 --- a/tests/handlers/test_handler_token.py +++ b/tests/handlers/test_handler_token.py @@ -102,7 +102,7 @@ def test_extract_token_direct_bearer_header(): ## Checking the freshness of public keys def test_refresh_keys_not_needed_when_keys_fresh(token_handler): - """Keys loaded less than 30 minutes ago should not trigger refresh.""" + """Keys loaded less than 28 minutes ago should not trigger refresh.""" token_handler._last_loaded_at = datetime.now(timezone.utc) - timedelta(minutes=10) token_handler.public_keys = [Mock(spec=RSAPublicKey)] @@ -112,7 +112,7 @@ def test_refresh_keys_not_needed_when_keys_fresh(token_handler): def test_refresh_keys_triggered_when_keys_stale(token_handler): - """Keys loaded more than 30 minutes ago should trigger refresh.""" + """Keys loaded more than 28 minutes ago should trigger refresh.""" token_handler._last_loaded_at = datetime.now(timezone.utc) - timedelta(minutes=29) token_handler.public_keys = [Mock(spec=RSAPublicKey)] diff --git a/tests/test_event_gate_lambda_local_access.py b/tests/test_event_gate_lambda_local_access.py index 714e4a8..1f0023b 100644 --- a/tests/test_event_gate_lambda_local_access.py +++ b/tests/test_event_gate_lambda_local_access.py @@ -36,11 +36,12 @@ def open_side_effect(path, *args, **kwargs): # noqa: D401 patch("cryptography.hazmat.primitives.serialization.load_der_public_key") as mock_load_key, patch("boto3.Session") as mock_session, patch("boto3.client") as mock_boto_client, - patch("confluent_kafka.Producer"), + patch("confluent_kafka.Producer") as mock_kafka_producer, patch("builtins.open", side_effect=open_side_effect), ): mock_get.return_value.json.return_value = {"key": "ZHVtbXk="} # base64 for 'dummy' mock_load_key.return_value = object() + mock_kafka_producer.return_value = MagicMock() class MockS3: def Bucket(self, name): # noqa: D401 From 692b30907a1f9f664c8da3e4a683f709137186cd Mon Sep 17 00:00:00 2001 From: "Tobias.Mikula" Date: Mon, 1 Dec 2025 15:09:11 +0100 Subject: [PATCH 6/7] Deleting not used tags. 
--- tests/conftest.py | 16 ++++++++-------- tests/handlers/test_handler_token.py | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index e72da94..7672f1e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -46,10 +46,10 @@ def start_patch(target: str): return p.start() # Local, temporary dummy modules only if truly missing - if importlib.util.find_spec("confluent_kafka") is None: # pragma: no cover - environment dependent + if importlib.util.find_spec("confluent_kafka") is None: dummy_ck = types.ModuleType("confluent_kafka") - class DummyProducer: # minimal interface + class DummyProducer: def __init__(self, *_, **__): pass @@ -58,18 +58,18 @@ def produce(self, *_, **kwargs): if cb: cb(None, None) - def flush(self): # noqa: D401 - simple stub + def flush(self): return None - dummy_ck.Producer = DummyProducer # type: ignore[attr-defined] + dummy_ck.Producer = DummyProducer class DummyKafkaException(Exception): pass - dummy_ck.KafkaException = DummyKafkaException # type: ignore[attr-defined] + dummy_ck.KafkaException = DummyKafkaException exit_stack.enter_context(patch.dict(sys.modules, {"confluent_kafka": dummy_ck})) - if importlib.util.find_spec("psycopg2") is None: # pragma: no cover - environment dependent + if importlib.util.find_spec("psycopg2") is None: dummy_pg = types.ModuleType("psycopg2") exit_stack.enter_context(patch.dict(sys.modules, {"psycopg2": dummy_pg})) @@ -94,11 +94,11 @@ def get(self): return {"Body": MockS3ObjectBody()} class MockS3Bucket: - def Object(self, _key): # noqa: D401 - simple proxy + def Object(self, _key): return MockS3Object() class MockS3Resource: - def Bucket(self, _name): # noqa: D401 - simple proxy + def Bucket(self, _name): return MockS3Bucket() mock_session = start_patch("boto3.Session") diff --git a/tests/handlers/test_handler_token.py b/tests/handlers/test_handler_token.py index b519be4..c191c4e 100644 --- a/tests/handlers/test_handler_token.py +++ b/tests/handlers/test_handler_token.py @@ -67,7 +67,7 @@ def test_decode_jwt_all_second_key_succeeds(event_gate_module): second_key = object() event_gate_module.handler_token.public_keys = [first_key, second_key] - def decode_side_effect(token, key, algorithms): # noqa: D401 - test stub + def decode_side_effect(token, key, algorithms): if key is first_key: raise event_gate_module.jwt.PyJWTError("signature mismatch") return {"sub": "TestUser"} @@ -82,7 +82,7 @@ def test_decode_jwt_all_all_keys_fail(event_gate_module): bad_keys = [object(), object()] event_gate_module.handler_token.public_keys = bad_keys - def always_fail(token, key, algorithms): # noqa: D401 - test stub + def always_fail(token, key, algorithms): raise event_gate_module.jwt.PyJWTError("bad signature") with patch.object(event_gate_module.jwt, "decode", side_effect=always_fail, create=True): From ae60d1eceb3db662d8e37e95e0f114396f1a6378 Mon Sep 17 00:00:00 2001 From: "Tobias.Mikula" Date: Mon, 1 Dec 2025 15:22:21 +0100 Subject: [PATCH 7/7] Import critical update made --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 7672f1e..ef4e834 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,7 +15,7 @@ # import os, sys import base64 -import importlib +import importlib.util import json import types from contextlib import ExitStack