Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ Example (sanitized):
{
"access_config": "s3://<bucket>/access.json",
"token_provider_url": "https://<token-ui.example>",
"token_public_key_url": "https://<token-api.example>/public-key",
"token_public_keys_url": "https://<token-api.example>/token/public-keys",
"kafka_bootstrap_server": "broker1:9092,broker2:9092",
"event_bus_arn": "arn:aws:events:region:acct:event-bus/your-bus"
}
Expand Down Expand Up @@ -137,7 +137,7 @@ Use when Kafka access needs Kerberos / SASL_SSL or custom `librdkafka` build.
| Code coverage | [Code Coverage](./DEVELOPER.md#code-coverage) |

## Security & Authorization
- JWT tokens must be RS256 signed; the public key is fetched at cold start from `token_public_key_url` (DER base64 inside JSON `{ "key": "..." }`).
- JWT tokens must be RS256 signed; current and previous public keys are fetched at cold start from `token_public_keys_url` as DER base64 values (list `keys[*].key`, with single-key fallback `{ "key": "..." }`).
- Subject claim (`sub`) is matched against `ACCESS[topicName]`.
- Authorization header forms accepted:
- `Authorization: Bearer <token>` (preferred)
Expand Down
2 changes: 1 addition & 1 deletion conf/config.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"access_config": "s3://<redacted>/access.json",
"token_provider_url": "https://<redacted>",
"token_public_key_url": "https://<redacted>",
"token_public_keys_url": "https://<redacted>",
"kafka_bootstrap_server": "localhost:9092",
"event_bus_arn": "arn:aws:events:<redacted>"
}
79 changes: 14 additions & 65 deletions src/event_gate_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#

"""Event Gate Lambda function implementation."""
import base64
import json
import logging
import os
Expand All @@ -24,13 +23,11 @@

import boto3
import jwt
import requests
import urllib3
from cryptography.exceptions import UnsupportedAlgorithm
from cryptography.hazmat.primitives import serialization
from jsonschema import validate
from jsonschema.exceptions import ValidationError

from src.handlers.handler_token import HandlerToken
from src.writers import writer_eventbridge, writer_kafka, writer_postgres
from src.utils.conf_path import CONF_DIR, INVALID_CONF_ENV

Expand Down Expand Up @@ -64,35 +61,28 @@
logger.debug("Loaded TOPICS")

with open(os.path.join(_CONF_DIR, "config.json"), "r", encoding="utf-8") as file:
CONFIG = json.load(file)
config = json.load(file)
logger.debug("Loaded main CONFIG")

aws_s3 = boto3.Session().resource("s3", verify=False) # nosec Boto verify disabled intentionally
logger.debug("Initialized AWS S3 Client")

if CONFIG["access_config"].startswith("s3://"):
name_parts = CONFIG["access_config"].split("/")
if config["access_config"].startswith("s3://"):
name_parts = config["access_config"].split("/")
BUCKET_NAME = name_parts[2]
BUCKET_OBJECT_KEY = "/".join(name_parts[3:])
ACCESS = json.loads(aws_s3.Bucket(BUCKET_NAME).Object(BUCKET_OBJECT_KEY).get()["Body"].read().decode("utf-8"))
else:
with open(CONFIG["access_config"], "r", encoding="utf-8") as file:
with open(config["access_config"], "r", encoding="utf-8") as file:
ACCESS = json.load(file)
logger.debug("Loaded ACCESS definitions")

TOKEN_PROVIDER_URL = CONFIG["token_provider_url"]
# Add timeout to avoid hanging requests; wrap in robust error handling so failures are explicit
try:
response_json = requests.get(CONFIG["token_public_key_url"], verify=False, timeout=5).json() # nosec external
token_public_key_encoded = response_json["key"]
TOKEN_PUBLIC_KEY: Any = serialization.load_der_public_key(base64.b64decode(token_public_key_encoded))
logger.debug("Loaded TOKEN_PUBLIC_KEY")
except (requests.RequestException, ValueError, KeyError, UnsupportedAlgorithm) as exc:
logger.exception("Failed to fetch or deserialize token public key from %s", CONFIG.get("token_public_key_url"))
raise RuntimeError("Token public key initialization failed") from exc

writer_eventbridge.init(logger, CONFIG)
writer_kafka.init(logger, CONFIG)
# Initialize token handler and load token public keys
handler_token = HandlerToken(config).load_public_keys()

# Initialize EventGate writers
writer_eventbridge.init(logger, config)
writer_kafka.init(logger, config)
writer_postgres.init(logger)


Expand Down Expand Up @@ -124,12 +114,6 @@ def get_api() -> Dict[str, Any]:
return {"statusCode": 200, "body": API}


def get_token() -> Dict[str, Any]:
"""Return 303 redirect to token provider endpoint."""
logger.debug("Handling GET Token")
return {"statusCode": 303, "headers": {"Location": TOKEN_PROVIDER_URL}}


def get_topics() -> Dict[str, Any]:
"""Return list of available topic names."""
logger.debug("Handling GET Topics")
Expand Down Expand Up @@ -163,7 +147,7 @@ def post_topic_message(topic_name: str, topic_message: Dict[str, Any], token_enc
"""
logger.debug("Handling POST %s", topic_name)
try:
token = jwt.decode(token_encoded, TOKEN_PUBLIC_KEY, algorithms=["RS256"]) # type: ignore[arg-type]
token: Dict[str, Any] = handler_token.decode_jwt(token_encoded)
except jwt.PyJWTError: # type: ignore[attr-defined]
return _error_response(401, "auth", "Invalid or missing token")

Expand Down Expand Up @@ -205,41 +189,6 @@ def post_topic_message(topic_name: str, topic_message: Dict[str, Any], token_enc
}


def extract_token(event_headers: Dict[str, str]) -> str:
"""Extract bearer token from headers (case-insensitive).

Supports:
- Custom 'bearer' header (any casing) whose value is the raw token
- Standard 'Authorization: Bearer <token>' header (case-insensitive scheme & key)
Returns empty string if token not found or malformed.
"""
if not event_headers:
return ""

# Normalize keys to lowercase for case-insensitive lookup
lowered = {str(k).lower(): v for k, v in event_headers.items()}

# Direct bearer header (raw token)
if "bearer" in lowered and isinstance(lowered["bearer"], str):
token_candidate = lowered["bearer"].strip()
if token_candidate:
return token_candidate

# Authorization header with Bearer scheme
auth_val = lowered.get("authorization", "")
if not isinstance(auth_val, str): # defensive
return ""
auth_val = auth_val.strip()
if not auth_val:
return ""

# Case-insensitive match for 'Bearer ' prefix
if not auth_val.lower().startswith("bearer "):
return ""
token_part = auth_val[7:].strip() # len('Bearer ')==7
return token_part


def lambda_handler(event: Dict[str, Any], context: Any): # pylint: disable=unused-argument,too-many-return-statements
"""AWS Lambda entry point.

Expand All @@ -250,7 +199,7 @@ def lambda_handler(event: Dict[str, Any], context: Any): # pylint: disable=unus
if resource == "/api":
return get_api()
if resource == "/token":
return get_token()
return handler_token.get_token()
if resource == "/topics":
return get_topics()
if resource == "/topics/{topic_name}":
Expand All @@ -261,7 +210,7 @@ def lambda_handler(event: Dict[str, Any], context: Any): # pylint: disable=unus
return post_topic_message(
event["pathParameters"]["topic_name"].lower(),
json.loads(event["body"]),
extract_token(event.get("headers", {})),
handler_token.extract_token(event.get("headers", {})),
)
if resource == "/terminate":
sys.exit("TERMINATING") # pragma: no cover - deliberate termination path
Expand Down
15 changes: 15 additions & 0 deletions src/handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#
# Copyright 2025 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
142 changes: 142 additions & 0 deletions src/handlers/handler_token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
#
# Copyright 2025 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This module provides the HandlerToken class for managing the token related operations.
"""

import base64
import logging
import os
from typing import Dict, Any, cast

import jwt
import requests
from cryptography.exceptions import UnsupportedAlgorithm
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey

from src.utils.constants import TOKEN_PROVIDER_URL, TOKEN_PUBLIC_KEYS_URL, TOKEN_PUBLIC_KEY_URL

logger = logging.getLogger(__name__)
log_level = os.environ.get("LOG_LEVEL", "INFO")
logger.setLevel(log_level)


class HandlerToken:
"""
HandlerToken manages token provider URL and public keys for JWT verification.
"""

def __init__(self, config):
self.provider_url: str = config.get(TOKEN_PROVIDER_URL, "")
self.public_keys_url: str = config.get(TOKEN_PUBLIC_KEYS_URL) or config.get(TOKEN_PUBLIC_KEY_URL)
self.public_keys: list[RSAPublicKey] = []

def load_public_keys(self) -> "HandlerToken":
"""
Load token public keys from the configured URL.
Returns:
HandlerToken: The current instance with loaded public keys.
Raises:
RuntimeError: If fetching or deserializing the public keys fails.
"""
logger.debug("Loading token public keys from %s", self.public_keys_url)

try:
response_json = requests.get(self.public_keys_url, verify=False, timeout=5).json()
Comment on lines +59 to +60
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Security concern: SSL certificate verification is disabled.

Using verify=False disables SSL certificate verification, making this request vulnerable to man-in-the-middle attacks. This is particularly concerning when fetching cryptographic public keys, as an attacker could inject malicious keys.

Consider making SSL verification configurable or defaulting to verify=True:

-            response_json = requests.get(self.public_keys_url, verify=False, timeout=5).json()
+            response_json = requests.get(self.public_keys_url, timeout=5).json()

If there's a legitimate need to disable verification in certain environments (e.g., development with self-signed certs), consider making it configurable via an environment variable or config parameter rather than unconditionally disabling it.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
try:
response_json = requests.get(self.public_keys_url, verify=False, timeout=5).json()
try:
response_json = requests.get(self.public_keys_url, timeout=5).json()
🧰 Tools
🪛 Ruff (0.14.5)

60-60: Probable use of requests call with verify=False disabling SSL certificate checks

(S501)

🤖 Prompt for AI Agents
In src/handlers/handler_token.py around lines 59-60, the requests.get call
disables SSL certificate verification (verify=False) which is insecure; change
it to use a configurable verification flag (e.g., read from an environment
variable or config with default True) and pass that flag to requests.get
(verify=<config_flag>) instead of False, ensure the default is True, add a log
warning if verification is explicitly disabled, and do not leave verify=False
hard-coded in the repository.

raw_keys: list[str] = []

if isinstance(response_json, dict):
if "keys" in response_json and isinstance(response_json["keys"], list):
for item in response_json["keys"]:
if "key" in item:
raw_keys.append(item["key"].strip())
elif "key" in response_json:
raw_keys.append(response_json["key"].strip())

if not raw_keys:
raise KeyError(f"No public keys found in {self.public_keys_url} endpoint response")

self.public_keys = [
cast(RSAPublicKey, serialization.load_der_public_key(base64.b64decode(raw_key))) for raw_key in raw_keys
]
logger.debug("Loaded %d token public keys", len(self.public_keys))

return self
except (requests.RequestException, ValueError, KeyError, UnsupportedAlgorithm) as exc:
logger.exception("Failed to fetch or deserialize token public key from %s", self.public_keys_url)
raise RuntimeError("Token public key initialization failed") from exc

def decode_jwt(self, token_encoded: str) -> Dict[str, Any]:
"""
Decode and verify a JWT using the loaded public keys.
Args:
token_encoded (str): The encoded JWT token.
Returns:
Dict[str, Any]: The decoded JWT payload.
Raises:
jwt.PyJWTError: If verification fails for all public keys.
"""
logger.debug("Decoding JWT")
for public_key in self.public_keys:
try:
return jwt.decode(token_encoded, public_key, algorithms=["RS256"])
except jwt.PyJWTError:
continue
raise jwt.PyJWTError("Verification failed for all public keys")

def get_token(self) -> Dict[str, Any]:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_token_provider_info maybe, or something similar? get_token suggests actual token of some kind is returned, at least to me.

"""
Returns: A 303 redirect response to the token provider URL.
"""
logger.debug("Handling GET Token")
return {"statusCode": 303, "headers": {"Location": self.provider_url}}

@staticmethod
def extract_token(event_headers: Dict[str, str]) -> str:
"""
Extracts the bearer (custom/standard) token from event headers.
Args:
event_headers (Dict[str, str]): The event headers.
Returns:
str: The extracted bearer token, or an empty string if not found.
"""
if not event_headers:
return ""

# Normalize keys to lowercase for case-insensitive lookup
lowered = {str(k).lower(): v for k, v in event_headers.items()}

# Direct bearer header (raw token)
if "bearer" in lowered and isinstance(lowered["bearer"], str):
token_candidate = lowered["bearer"].strip()
if token_candidate:
return token_candidate

# Authorization header with Bearer scheme
auth_val = lowered.get("authorization", "")
if not isinstance(auth_val, str): # defensive
return ""
auth_val = auth_val.strip()
if not auth_val:
return ""

# Case-insensitive match for 'Bearer ' prefix
if not auth_val.lower().startswith("bearer "):
return ""
token_part = auth_val[7:].strip() # len('Bearer ')==7
return token_part
24 changes: 24 additions & 0 deletions src/utils/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#
# Copyright 2025 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This module contains all constants and enums used across the project.
"""

# Token related constants
TOKEN_PROVIDER_URL = "token_provider_url"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it could be more explicitly said that these are constants for config keys? On the first glance, it looks like the constant holds actual URL.

TOKEN_PUBLIC_KEY_URL = "token_public_key_url"
TOKEN_PUBLIC_KEYS_URL = "token_public_keys_url"
Loading