diff --git a/README.md b/README.md index e0cb1f4..418a5c7 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,7 @@ Example (sanitized): { "access_config": "s3:///access.json", "token_provider_url": "https://", - "token_public_key_url": "https:///public-key", + "token_public_keys_url": "https:///token/public-keys", "kafka_bootstrap_server": "broker1:9092,broker2:9092", "event_bus_arn": "arn:aws:events:region:acct:event-bus/your-bus" } @@ -137,7 +137,7 @@ Use when Kafka access needs Kerberos / SASL_SSL or custom `librdkafka` build. | Code coverage | [Code Coverage](./DEVELOPER.md#code-coverage) | ## Security & Authorization -- JWT tokens must be RS256 signed; the public key is fetched at cold start from `token_public_key_url` (DER base64 inside JSON `{ "key": "..." }`). +- JWT tokens must be RS256 signed; current and previous public keys are fetched at cold start from `token_public_keys_url` as DER base64 values (list `keys[*].key`, with single-key fallback `{ "key": "..." }`). - Subject claim (`sub`) is matched against `ACCESS[topicName]`. - Authorization header forms accepted: - `Authorization: Bearer ` (preferred) diff --git a/conf/config.json b/conf/config.json index 84ab9f8..c35e61c 100644 --- a/conf/config.json +++ b/conf/config.json @@ -1,7 +1,7 @@ { "access_config": "s3:///access.json", "token_provider_url": "https://", - "token_public_key_url": "https://", + "token_public_keys_url": "https://", "kafka_bootstrap_server": "localhost:9092", "event_bus_arn": "arn:aws:events:" } \ No newline at end of file diff --git a/src/event_gate_lambda.py b/src/event_gate_lambda.py index 34dcb81..6298d46 100644 --- a/src/event_gate_lambda.py +++ b/src/event_gate_lambda.py @@ -15,7 +15,6 @@ # """Event Gate Lambda function implementation.""" -import base64 import json import logging import os @@ -24,13 +23,11 @@ import boto3 import jwt -import requests import urllib3 -from cryptography.exceptions import UnsupportedAlgorithm -from cryptography.hazmat.primitives import serialization from jsonschema import validate from jsonschema.exceptions import ValidationError +from src.handlers.handler_token import HandlerToken from src.writers import writer_eventbridge, writer_kafka, writer_postgres from src.utils.conf_path import CONF_DIR, INVALID_CONF_ENV @@ -64,35 +61,28 @@ logger.debug("Loaded TOPICS") with open(os.path.join(_CONF_DIR, "config.json"), "r", encoding="utf-8") as file: - CONFIG = json.load(file) + config = json.load(file) logger.debug("Loaded main CONFIG") aws_s3 = boto3.Session().resource("s3", verify=False) # nosec Boto verify disabled intentionally logger.debug("Initialized AWS S3 Client") -if CONFIG["access_config"].startswith("s3://"): - name_parts = CONFIG["access_config"].split("/") +if config["access_config"].startswith("s3://"): + name_parts = config["access_config"].split("/") BUCKET_NAME = name_parts[2] BUCKET_OBJECT_KEY = "/".join(name_parts[3:]) ACCESS = json.loads(aws_s3.Bucket(BUCKET_NAME).Object(BUCKET_OBJECT_KEY).get()["Body"].read().decode("utf-8")) else: - with open(CONFIG["access_config"], "r", encoding="utf-8") as file: + with open(config["access_config"], "r", encoding="utf-8") as file: ACCESS = json.load(file) logger.debug("Loaded ACCESS definitions") -TOKEN_PROVIDER_URL = CONFIG["token_provider_url"] -# Add timeout to avoid hanging requests; wrap in robust error handling so failures are explicit -try: - response_json = requests.get(CONFIG["token_public_key_url"], verify=False, timeout=5).json() # nosec external - token_public_key_encoded = response_json["key"] - TOKEN_PUBLIC_KEY: Any = serialization.load_der_public_key(base64.b64decode(token_public_key_encoded)) - logger.debug("Loaded TOKEN_PUBLIC_KEY") -except (requests.RequestException, ValueError, KeyError, UnsupportedAlgorithm) as exc: - logger.exception("Failed to fetch or deserialize token public key from %s", CONFIG.get("token_public_key_url")) - raise RuntimeError("Token public key initialization failed") from exc - -writer_eventbridge.init(logger, CONFIG) -writer_kafka.init(logger, CONFIG) +# Initialize token handler and load token public keys +handler_token = HandlerToken(config).load_public_keys() + +# Initialize EventGate writers +writer_eventbridge.init(logger, config) +writer_kafka.init(logger, config) writer_postgres.init(logger) @@ -124,12 +114,6 @@ def get_api() -> Dict[str, Any]: return {"statusCode": 200, "body": API} -def get_token() -> Dict[str, Any]: - """Return 303 redirect to token provider endpoint.""" - logger.debug("Handling GET Token") - return {"statusCode": 303, "headers": {"Location": TOKEN_PROVIDER_URL}} - - def get_topics() -> Dict[str, Any]: """Return list of available topic names.""" logger.debug("Handling GET Topics") @@ -163,7 +147,7 @@ def post_topic_message(topic_name: str, topic_message: Dict[str, Any], token_enc """ logger.debug("Handling POST %s", topic_name) try: - token = jwt.decode(token_encoded, TOKEN_PUBLIC_KEY, algorithms=["RS256"]) # type: ignore[arg-type] + token: Dict[str, Any] = handler_token.decode_jwt(token_encoded) except jwt.PyJWTError: # type: ignore[attr-defined] return _error_response(401, "auth", "Invalid or missing token") @@ -205,41 +189,6 @@ def post_topic_message(topic_name: str, topic_message: Dict[str, Any], token_enc } -def extract_token(event_headers: Dict[str, str]) -> str: - """Extract bearer token from headers (case-insensitive). - - Supports: - - Custom 'bearer' header (any casing) whose value is the raw token - - Standard 'Authorization: Bearer ' header (case-insensitive scheme & key) - Returns empty string if token not found or malformed. - """ - if not event_headers: - return "" - - # Normalize keys to lowercase for case-insensitive lookup - lowered = {str(k).lower(): v for k, v in event_headers.items()} - - # Direct bearer header (raw token) - if "bearer" in lowered and isinstance(lowered["bearer"], str): - token_candidate = lowered["bearer"].strip() - if token_candidate: - return token_candidate - - # Authorization header with Bearer scheme - auth_val = lowered.get("authorization", "") - if not isinstance(auth_val, str): # defensive - return "" - auth_val = auth_val.strip() - if not auth_val: - return "" - - # Case-insensitive match for 'Bearer ' prefix - if not auth_val.lower().startswith("bearer "): - return "" - token_part = auth_val[7:].strip() # len('Bearer ')==7 - return token_part - - def lambda_handler(event: Dict[str, Any], context: Any): # pylint: disable=unused-argument,too-many-return-statements """AWS Lambda entry point. @@ -250,7 +199,7 @@ def lambda_handler(event: Dict[str, Any], context: Any): # pylint: disable=unus if resource == "/api": return get_api() if resource == "/token": - return get_token() + return handler_token.get_token() if resource == "/topics": return get_topics() if resource == "/topics/{topic_name}": @@ -261,7 +210,7 @@ def lambda_handler(event: Dict[str, Any], context: Any): # pylint: disable=unus return post_topic_message( event["pathParameters"]["topic_name"].lower(), json.loads(event["body"]), - extract_token(event.get("headers", {})), + handler_token.extract_token(event.get("headers", {})), ) if resource == "/terminate": sys.exit("TERMINATING") # pragma: no cover - deliberate termination path diff --git a/src/handlers/__init__.py b/src/handlers/__init__.py new file mode 100644 index 0000000..f7115cb --- /dev/null +++ b/src/handlers/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2025 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/src/handlers/handler_token.py b/src/handlers/handler_token.py new file mode 100644 index 0000000..0bdcdc1 --- /dev/null +++ b/src/handlers/handler_token.py @@ -0,0 +1,142 @@ +# +# Copyright 2025 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +This module provides the HandlerToken class for managing the token related operations. +""" + +import base64 +import logging +import os +from typing import Dict, Any, cast + +import jwt +import requests +from cryptography.exceptions import UnsupportedAlgorithm +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey + +from src.utils.constants import TOKEN_PROVIDER_URL, TOKEN_PUBLIC_KEYS_URL, TOKEN_PUBLIC_KEY_URL + +logger = logging.getLogger(__name__) +log_level = os.environ.get("LOG_LEVEL", "INFO") +logger.setLevel(log_level) + + +class HandlerToken: + """ + HandlerToken manages token provider URL and public keys for JWT verification. + """ + + def __init__(self, config): + self.provider_url: str = config.get(TOKEN_PROVIDER_URL, "") + self.public_keys_url: str = config.get(TOKEN_PUBLIC_KEYS_URL) or config.get(TOKEN_PUBLIC_KEY_URL) + self.public_keys: list[RSAPublicKey] = [] + + def load_public_keys(self) -> "HandlerToken": + """ + Load token public keys from the configured URL. + Returns: + HandlerToken: The current instance with loaded public keys. + Raises: + RuntimeError: If fetching or deserializing the public keys fails. + """ + logger.debug("Loading token public keys from %s", self.public_keys_url) + + try: + response_json = requests.get(self.public_keys_url, verify=False, timeout=5).json() + raw_keys: list[str] = [] + + if isinstance(response_json, dict): + if "keys" in response_json and isinstance(response_json["keys"], list): + for item in response_json["keys"]: + if "key" in item: + raw_keys.append(item["key"].strip()) + elif "key" in response_json: + raw_keys.append(response_json["key"].strip()) + + if not raw_keys: + raise KeyError(f"No public keys found in {self.public_keys_url} endpoint response") + + self.public_keys = [ + cast(RSAPublicKey, serialization.load_der_public_key(base64.b64decode(raw_key))) for raw_key in raw_keys + ] + logger.debug("Loaded %d token public keys", len(self.public_keys)) + + return self + except (requests.RequestException, ValueError, KeyError, UnsupportedAlgorithm) as exc: + logger.exception("Failed to fetch or deserialize token public key from %s", self.public_keys_url) + raise RuntimeError("Token public key initialization failed") from exc + + def decode_jwt(self, token_encoded: str) -> Dict[str, Any]: + """ + Decode and verify a JWT using the loaded public keys. + Args: + token_encoded (str): The encoded JWT token. + Returns: + Dict[str, Any]: The decoded JWT payload. + Raises: + jwt.PyJWTError: If verification fails for all public keys. + """ + logger.debug("Decoding JWT") + for public_key in self.public_keys: + try: + return jwt.decode(token_encoded, public_key, algorithms=["RS256"]) + except jwt.PyJWTError: + continue + raise jwt.PyJWTError("Verification failed for all public keys") + + def get_token(self) -> Dict[str, Any]: + """ + Returns: A 303 redirect response to the token provider URL. + """ + logger.debug("Handling GET Token") + return {"statusCode": 303, "headers": {"Location": self.provider_url}} + + @staticmethod + def extract_token(event_headers: Dict[str, str]) -> str: + """ + Extracts the bearer (custom/standard) token from event headers. + Args: + event_headers (Dict[str, str]): The event headers. + Returns: + str: The extracted bearer token, or an empty string if not found. + """ + if not event_headers: + return "" + + # Normalize keys to lowercase for case-insensitive lookup + lowered = {str(k).lower(): v for k, v in event_headers.items()} + + # Direct bearer header (raw token) + if "bearer" in lowered and isinstance(lowered["bearer"], str): + token_candidate = lowered["bearer"].strip() + if token_candidate: + return token_candidate + + # Authorization header with Bearer scheme + auth_val = lowered.get("authorization", "") + if not isinstance(auth_val, str): # defensive + return "" + auth_val = auth_val.strip() + if not auth_val: + return "" + + # Case-insensitive match for 'Bearer ' prefix + if not auth_val.lower().startswith("bearer "): + return "" + token_part = auth_val[7:].strip() # len('Bearer ')==7 + return token_part diff --git a/src/utils/constants.py b/src/utils/constants.py new file mode 100644 index 0000000..1affe7c --- /dev/null +++ b/src/utils/constants.py @@ -0,0 +1,24 @@ +# +# Copyright 2025 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +This module contains all constants and enums used across the project. +""" + +# Token related constants +TOKEN_PROVIDER_URL = "token_provider_url" +TOKEN_PUBLIC_KEY_URL = "token_public_key_url" +TOKEN_PUBLIC_KEYS_URL = "token_public_keys_url" diff --git a/tests/conftest.py b/tests/conftest.py index 3ee2e24..52e5cac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,8 +14,136 @@ # limitations under the License. # import os, sys +import base64 +import importlib +import json +import types +from contextlib import ExitStack +from unittest.mock import MagicMock, patch + +import pytest + # Ensure project root is on sys.path so 'src' package is importable during tests PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) if PROJECT_ROOT not in sys.path: sys.path.insert(0, PROJECT_ROOT) + + +@pytest.fixture(scope="module") +def event_gate_module(): + """Import `src.event_gate_lambda` with external deps patched/mocked. + + This fixture centralises the heavy environment setup shared by + multiple test modules. + """ + started_patches = [] + exit_stack = ExitStack() + + def start_patch(target: str): + p = patch(target) + started_patches.append(p) + return p.start() + + # Local, temporary dummy modules only if truly missing + if importlib.util.find_spec("confluent_kafka") is None: # pragma: no cover - environment dependent + dummy_ck = types.ModuleType("confluent_kafka") + + class DummyProducer: # minimal interface + def __init__(self, *_, **__): + pass + + def produce(self, *_, **kwargs): + cb = kwargs.get("callback") + if cb: + cb(None, None) + + def flush(self): # noqa: D401 - simple stub + return None + + dummy_ck.Producer = DummyProducer # type: ignore[attr-defined] + + class DummyKafkaException(Exception): + pass + + dummy_ck.KafkaException = DummyKafkaException # type: ignore[attr-defined] + exit_stack.enter_context(patch.dict(sys.modules, {"confluent_kafka": dummy_ck})) + + if importlib.util.find_spec("psycopg2") is None: # pragma: no cover - environment dependent + dummy_pg = types.ModuleType("psycopg2") + exit_stack.enter_context(patch.dict(sys.modules, {"psycopg2": dummy_pg})) + + mock_requests_get = start_patch("requests.get") + mock_requests_get.return_value.json.return_value = {"key": base64.b64encode(b"dummy_der").decode("utf-8")} + + mock_load_key = start_patch("cryptography.hazmat.primitives.serialization.load_der_public_key") + mock_load_key.return_value = object() + + class MockS3ObjectBody: + def read(self): + return json.dumps( + { + "public.cps.za.runs": ["FooBarUser"], + "public.cps.za.dlchange": ["FooUser", "BarUser"], + "public.cps.za.test": ["TestUser"], + } + ).encode("utf-8") + + class MockS3Object: + def get(self): + return {"Body": MockS3ObjectBody()} + + class MockS3Bucket: + def Object(self, _key): # noqa: D401 - simple proxy + return MockS3Object() + + class MockS3Resource: + def Bucket(self, _name): # noqa: D401 - simple proxy + return MockS3Bucket() + + mock_session = start_patch("boto3.Session") + mock_session.return_value.resource.return_value = MockS3Resource() + + mock_boto_client = start_patch("boto3.client") + mock_events_client = MagicMock() + mock_events_client.put_events.return_value = {"FailedEntryCount": 0} + mock_boto_client.return_value = mock_events_client + + # Allow kafka producer patching (already stubbed) but still patch to inspect if needed + start_patch("confluent_kafka.Producer") + + module = importlib.import_module("src.event_gate_lambda") + + yield module + + for p in started_patches: + p.stop() + exit_stack.close() + + +@pytest.fixture +def make_event(): + """Build a minimal API Gateway-style event dict for tests.""" + + def _make(resource, method="GET", body=None, topic=None, headers=None): + return { + "resource": resource, + "httpMethod": method, + "headers": headers or {}, + "pathParameters": {"topic_name": topic} if topic else {}, + "body": json.dumps(body) if isinstance(body, dict) else body, + } + + return _make + + +@pytest.fixture +def valid_payload(): + """A canonical valid payload used across tests.""" + return { + "event_id": "e1", + "tenant_id": "t1", + "source_app": "app", + "environment": "dev", + "timestamp": 123, + } diff --git a/tests/handlers/__init__.py b/tests/handlers/__init__.py new file mode 100644 index 0000000..f7115cb --- /dev/null +++ b/tests/handlers/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2025 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/handlers/test_handler_token.py b/tests/handlers/test_handler_token.py new file mode 100644 index 0000000..801a20b --- /dev/null +++ b/tests/handlers/test_handler_token.py @@ -0,0 +1,91 @@ +# +# Copyright 2025 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json +from unittest.mock import patch + +import pytest + +from src.handlers.handler_token import HandlerToken + + +def test_get_token_endpoint(event_gate_module, make_event): + event = make_event("/token") + resp = event_gate_module.lambda_handler(event, None) + assert resp["statusCode"] == 303 + assert "Location" in resp["headers"] + + +def test_post_expired_token(event_gate_module, make_event, valid_payload): + """Expired JWT should yield 401 auth error.""" + + with patch.object( + event_gate_module.jwt, + "decode", + side_effect=event_gate_module.jwt.ExpiredSignatureError("expired"), + create=True, + ): + event = make_event( + "/topics/{topic_name}", + method="POST", + topic="public.cps.za.test", + body=valid_payload, + headers={"Authorization": "Bearer expiredtoken"}, + ) + resp = event_gate_module.lambda_handler(event, None) + assert resp["statusCode"] == 401 + body = json.loads(resp["body"]) + assert any(e["type"] == "auth" for e in body["errors"]) + + +def test_decode_jwt_all_second_key_succeeds(event_gate_module): + """First key fails signature, second key succeeds; claims returned from second key.""" + # Arrange: two dummy public keys + first_key = object() + second_key = object() + event_gate_module.handler_token.public_keys = [first_key, second_key] + + def decode_side_effect(token, key, algorithms): # noqa: D401 - test stub + if key is first_key: + raise event_gate_module.jwt.PyJWTError("signature mismatch") + return {"sub": "TestUser"} + + with patch.object(event_gate_module.jwt, "decode", side_effect=decode_side_effect, create=True): + claims = event_gate_module.handler_token.decode_jwt("dummy-token") + assert claims["sub"] == "TestUser" + + +def test_decode_jwt_all_all_keys_fail(event_gate_module): + """All keys fail; final PyJWTError with aggregate message is raised.""" + bad_keys = [object(), object()] + event_gate_module.handler_token.public_keys = bad_keys + + def always_fail(token, key, algorithms): # noqa: D401 - test stub + raise event_gate_module.jwt.PyJWTError("bad signature") + + with patch.object(event_gate_module.jwt, "decode", side_effect=always_fail, create=True): + with pytest.raises(event_gate_module.jwt.PyJWTError) as exc: + event_gate_module.handler_token.decode_jwt("dummy-token") + assert "Verification failed for all public keys" in str(exc.value) + + +def test_extract_token_empty(): + assert HandlerToken.extract_token({}) == "" + + +def test_extract_token_direct_bearer_header(): + token = HandlerToken.extract_token({"Bearer": " tok123 "}) + assert token == "tok123" diff --git a/tests/utils/test_conf_validation.py b/tests/test_conf_validation.py similarity index 98% rename from tests/utils/test_conf_validation.py rename to tests/test_conf_validation.py index 7ed62ec..ef4ec1c 100644 --- a/tests/utils/test_conf_validation.py +++ b/tests/test_conf_validation.py @@ -19,12 +19,12 @@ from glob import glob import pytest -CONF_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "..", "conf") +CONF_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "conf") REQUIRED_CONFIG_KEYS = { "access_config", "token_provider_url", - "token_public_key_url", + "token_public_keys_url", "kafka_bootstrap_server", "event_bus_arn", } diff --git a/tests/test_event_gate_lambda.py b/tests/test_event_gate_lambda.py index 13ee3a1..e6683b6 100644 --- a/tests/test_event_gate_lambda.py +++ b/tests/test_event_gate_lambda.py @@ -13,124 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import base64 -import json -import importlib -import sys -import types -import pytest -from unittest.mock import patch, MagicMock -import importlib.util -from contextlib import ExitStack - - -@pytest.fixture(scope="module") -def event_gate_module(): - started_patches = [] - exit_stack = ExitStack() - - def start_patch(target): - p = patch(target) - started_patches.append(p) - return p.start() - - # Local, temporary dummy modules only if truly missing - # confluent_kafka - if importlib.util.find_spec("confluent_kafka") is None: # pragma: no cover - environment dependent - dummy_ck = types.ModuleType("confluent_kafka") - - class DummyProducer: # minimal interface - def __init__(self, *a, **kw): - pass - - def produce(self, *a, **kw): - cb = kw.get("callback") - if cb: - cb(None, None) - - def flush(self): # noqa: D401 - simple stub - return None - - dummy_ck.Producer = DummyProducer # type: ignore[attr-defined] - - class DummyKafkaException(Exception): - pass - - dummy_ck.KafkaException = DummyKafkaException # type: ignore[attr-defined] - exit_stack.enter_context(patch.dict(sys.modules, {"confluent_kafka": dummy_ck})) - - # psycopg2 optional dependency - if importlib.util.find_spec("psycopg2") is None: # pragma: no cover - environment dependent - dummy_pg = types.ModuleType("psycopg2") - exit_stack.enter_context(patch.dict(sys.modules, {"psycopg2": dummy_pg})) - - mock_requests_get = start_patch("requests.get") - mock_requests_get.return_value.json.return_value = {"key": base64.b64encode(b"dummy_der").decode("utf-8")} - - mock_load_key = start_patch("cryptography.hazmat.primitives.serialization.load_der_public_key") - mock_load_key.return_value = object() - - # Mock S3 access_config retrieval - class MockS3ObjectBody: - def read(self): - return json.dumps( - { - "public.cps.za.runs": ["FooBarUser"], - "public.cps.za.dlchange": ["FooUser", "BarUser"], - "public.cps.za.test": ["TestUser"], - } - ).encode("utf-8") - - class MockS3Object: - def get(self): - return {"Body": MockS3ObjectBody()} - - class MockS3Bucket: - def Object(self, key): # noqa: D401 - simple proxy - return MockS3Object() - class MockS3Resource: - def Bucket(self, name): # noqa: D401 - simple proxy - return MockS3Bucket() - - mock_session = start_patch("boto3.Session") - mock_session.return_value.resource.return_value = MockS3Resource() - - # Mock EventBridge client - mock_boto_client = start_patch("boto3.client") - mock_events_client = MagicMock() - mock_events_client.put_events.return_value = {"FailedEntryCount": 0} - mock_boto_client.return_value = mock_events_client - - # Allow kafka producer patching (already stubbed) but still patch to inspect if needed - start_patch("confluent_kafka.Producer") - - module = importlib.import_module("src.event_gate_lambda") - - yield module - - for p in started_patches: - p.stop() - exit_stack.close() - - -@pytest.fixture -def make_event(): - def _make(resource, method="GET", body=None, topic=None, headers=None): - return { - "resource": resource, - "httpMethod": method, - "headers": headers or {}, - "pathParameters": {"topic_name": topic} if topic else {}, - "body": json.dumps(body) if isinstance(body, dict) else body, - } - - return _make - - -@pytest.fixture -def valid_payload(): - return {"event_id": "e1", "tenant_id": "t1", "source_app": "app", "environment": "dev", "timestamp": 123} +import json +from unittest.mock import patch # --- GET flows --- @@ -323,13 +208,6 @@ def test_get_api_endpoint(event_gate_module, make_event): assert "openapi" in resp["body"].lower() -def test_get_token_endpoint(event_gate_module, make_event): - event = make_event("/token") - resp = event_gate_module.lambda_handler(event, None) - assert resp["statusCode"] == 303 - assert "Location" in resp["headers"] - - def test_internal_error_path(event_gate_module, make_event): with patch("src.event_gate_lambda.get_topics", side_effect=RuntimeError("boom")): event = make_event("/topics") diff --git a/tests/test_event_gate_lambda_local_access.py b/tests/test_event_gate_lambda_local_access.py index 525cdcb..714e4a8 100644 --- a/tests/test_event_gate_lambda_local_access.py +++ b/tests/test_event_gate_lambda_local_access.py @@ -54,5 +54,5 @@ def Bucket(self, name): # noqa: D401 # Force reload so import-level logic re-executes with patched open egl_reloaded = importlib.reload(egl) - assert not egl_reloaded.CONFIG["access_config"].startswith("s3://") # type: ignore[attr-defined] + assert not egl_reloaded.config["access_config"].startswith("s3://") # type: ignore[attr-defined] assert egl_reloaded.ACCESS["public.cps.za.test"] == ["User"] # type: ignore[attr-defined] diff --git a/tests/utils/test_extract_token.py b/tests/utils/test_extract_token.py deleted file mode 100644 index 5a896a7..0000000 --- a/tests/utils/test_extract_token.py +++ /dev/null @@ -1,72 +0,0 @@ -import base64 -import json -import importlib -from unittest.mock import patch, MagicMock -import pytest - - -@pytest.fixture(scope="module") -def egl_mod(): - patches = [] - - def start(target): - p = patch(target) - patches.append(p) - return p.start() - - # Patch requests.get for public key fetch - mock_get = start("requests.get") - mock_get.return_value.json.return_value = {"key": base64.b64encode(b"dummy_der").decode("utf-8")} - - # Patch crypto key loader - start_loader = start("cryptography.hazmat.primitives.serialization.load_der_public_key") - start_loader.return_value = object() - - # Patch boto3.Session resource for S3 to avoid bucket validation/network - class MockBody: - def read(self): - return json.dumps( - { - "public.cps.za.runs": ["User"], - "public.cps.za.dlchange": ["User"], - "public.cps.za.test": ["User"], - } - ).encode("utf-8") - - class MockObject: - def get(self): - return {"Body": MockBody()} - - class MockBucket: - def Object(self, key): # noqa: D401 - return MockObject() - - class MockS3: - def Bucket(self, name): # noqa: D401 - return MockBucket() - - mock_session = start("boto3.Session") - mock_session.return_value.resource.return_value = MockS3() - - # Patch boto3.client for EventBridge - mock_client = start("boto3.client") - mock_events = MagicMock() - mock_events.put_events.return_value = {"FailedEntryCount": 0} - mock_client.return_value = mock_events - - # Patch Kafka Producer - start("confluent_kafka.Producer") - - module = importlib.import_module("src.event_gate_lambda") - yield module - - for p in patches: - p.stop() - - -def test_extract_token_empty(egl_mod): - assert egl_mod.extract_token({}) == "" - - -def test_extract_token_direct_bearer_header(egl_mod): - token = egl_mod.extract_token({"Bearer": " tok123 "})