Skip to content

Commit

Permalink
fix(event_handler): security scheme unhashable list when working with…
Browse files Browse the repository at this point in the history
… router (#4421)
  • Loading branch information
leandrodamascena committed Jun 7, 2024
1 parent 88c8e91 commit e4c236b
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 38 deletions.
35 changes: 26 additions & 9 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from aws_lambda_powertools.event_handler import content_types
from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError
from aws_lambda_powertools.event_handler.openapi.constants import DEFAULT_API_VERSION, DEFAULT_OPENAPI_VERSION
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError, SchemaValidationError
from aws_lambda_powertools.event_handler.openapi.types import (
COMPONENT_REF_PREFIX,
METHODS_WITH_BODY,
Expand All @@ -43,7 +43,12 @@
validation_error_definition,
validation_error_response_definition,
)
from aws_lambda_powertools.event_handler.util import _FrozenDict, extract_origin_header
from aws_lambda_powertools.event_handler.util import (
_FrozenDict,
_FrozenListDict,
_validate_openapi_security_parameters,
extract_origin_header,
)
from aws_lambda_powertools.shared.cookies import Cookie
from aws_lambda_powertools.shared.functions import powertools_dev_is_set
from aws_lambda_powertools.shared.json_encoder import Encoder
Expand Down Expand Up @@ -703,6 +708,7 @@ def _openapi_operation_parameters(
from aws_lambda_powertools.event_handler.openapi.params import Param

parameters = []
parameter: Dict[str, Any]
for param in all_route_params:
field_info = param.field_info
field_info = cast(Param, field_info)
Expand Down Expand Up @@ -1588,6 +1594,16 @@ def get_openapi_schema(

# Add routes to the OpenAPI schema
for route in all_routes:

if route.security and not _validate_openapi_security_parameters(
security=route.security,
security_schemes=security_schemes,
):
raise SchemaValidationError(
"Security configuration was not found in security_schemas or security_schema was not defined. "
"See: https://docs.powertools.aws.dev/lambda/python/latest/core/event_handler/api_gateway/#security-schemes",
)

if not route.include_in_schema:
continue

Expand Down Expand Up @@ -1630,15 +1646,15 @@ def _get_openapi_security(
security: Optional[List[Dict[str, List[str]]]],
security_schemes: Optional[Dict[str, "SecurityScheme"]],
) -> Optional[List[Dict[str, List[str]]]]:

if not security:
return None

if not security_schemes:
raise ValueError("security_schemes must be provided if security is provided")

# Check if all keys in security are present in the security_schemes
if any(key not in security_schemes for sec in security for key in sec):
raise ValueError("Some security schemes not found in security_schemes")
if not _validate_openapi_security_parameters(security=security, security_schemes=security_schemes):
raise SchemaValidationError(
"Security configuration was not found in security_schemas or security_schema was not defined. "
"See: https://docs.powertools.aws.dev/lambda/python/latest/core/event_handler/api_gateway/#security-schemes",
)

return security

Expand Down Expand Up @@ -2386,6 +2402,7 @@ def register_route(func: Callable):
methods = (method,) if isinstance(method, str) else tuple(method)
frozen_responses = _FrozenDict(responses) if responses else None
frozen_tags = frozenset(tags) if tags else None
frozen_security = _FrozenListDict(security) if security else None

route_key = (
rule,
Expand All @@ -2400,7 +2417,7 @@ def register_route(func: Callable):
frozen_tags,
operation_id,
include_in_schema,
security,
frozen_security,
)

# Collate Middleware for routes
Expand Down
6 changes: 6 additions & 0 deletions aws_lambda_powertools/event_handler/openapi/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,9 @@ class RequestValidationError(ValidationException):
def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
super().__init__(errors)
self.body = body


class SchemaValidationError(ValidationException):
"""
Raised when the OpenAPI schema validation fails
"""
68 changes: 63 additions & 5 deletions aws_lambda_powertools/event_handler/util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, Dict
from typing import Any, Dict, List, Optional

from aws_lambda_powertools.event_handler.openapi.models import SecurityScheme
from aws_lambda_powertools.utilities.data_classes.shared_functions import get_header_value


Expand All @@ -18,17 +19,45 @@ def __hash__(self):
return hash(frozenset(self.keys()))


class _FrozenListDict(List[Dict[str, List[str]]]):
"""
Freezes a list of dictionaries containing lists of strings.
This function takes a list of dictionaries where the values are lists of strings and converts it into
a frozen set of frozen sets of frozen dictionaries. This is done by iterating over the input list,
converting each dictionary's values (lists of strings) into frozen sets of strings, and then
converting the resulting dictionary into a frozen dictionary. Finally, all these frozen dictionaries
are collected into a frozen set of frozen sets.
This operation is useful when you want to ensure the immutability of the data structure and make it
hashable, which is required for certain operations like using it as a key in a dictionary or as an
element in a set.
Example: [{"TestAuth": ["test", "test1"]}]
"""

def __hash__(self):
hashable_items = []
for item in self:
hashable_items.extend((key, frozenset(value)) for key, value in item.items())
return hash(frozenset(hashable_items))


def extract_origin_header(resolver_headers: Dict[str, Any]):
"""
Extracts the 'origin' or 'Origin' header from the provided resolver headers.
The 'origin' or 'Origin' header can be either a single header or a multi-header.
Args:
resolver_headers (Dict): A dictionary containing the headers.
Parameters
----------
resolver_headers: Dict
A dictionary containing the headers.
Returns:
Optional[str]: The value(s) of the origin header or None.
Returns
-------
Optional[str]
The value(s) of the origin header or None.
"""
resolved_header = get_header_value(
headers=resolver_headers,
Expand All @@ -40,3 +69,32 @@ def extract_origin_header(resolver_headers: Dict[str, Any]):
return resolved_header[0]

return resolved_header


def _validate_openapi_security_parameters(
security: List[Dict[str, List[str]]],
security_schemes: Optional[Dict[str, "SecurityScheme"]] = None,
) -> bool:
"""
This function checks if all security requirements listed in the 'security'
parameter are defined in the 'security_schemes' dictionary, as specified
in the OpenAPI schema.
Parameters
----------
security: List[Dict[str, List[str]]]
A list of security requirements
security_schemes: Optional[Dict[str, "SecurityScheme"]]
A dictionary mapping security scheme names to their corresponding security scheme objects.
Returns
-------
bool
Whether list of security schemes match allowed security_schemes.
"""

security_schemes = security_schemes or {}

security_schema_match = all(key in security_schemes for sec in security for key in sec)

return bool(security_schema_match and security_schemes)
3 changes: 1 addition & 2 deletions docs/core/event_handler/api_gateway.md
Original file line number Diff line number Diff line change
Expand Up @@ -1032,8 +1032,7 @@ Below is an example configuration for serving Swagger UI from a custom path or C
???-info "Does Powertools implement any of the security schemes?"
No. Powertools adds support for generating OpenAPI documentation with [security schemes](https://swagger.io/docs/specification/authentication/), but it doesn't implement any of the security schemes itself, so you must implement the security mechanisms separately.

OpenAPI uses the term security scheme for [authentication and authorization schemes](https://swagger.io/docs/specification/authentication/){target="_blank"}.
When you're describing your API, declare security schemes at the top level, and reference them globally or per operation.
Security schemes are declared at the top-level first. You can reference them globally or on a per path _(operation)_ level. **However**, if you reference security schemes that are not defined at the top-level it will lead to a `SchemaValidationError` _(invalid OpenAPI spec)_.

=== "Global OpenAPI security schemes"

Expand Down
6 changes: 6 additions & 0 deletions tests/functional/event_handler/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import fastjsonschema
import pytest

from aws_lambda_powertools.event_handler.openapi.models import APIKey, APIKeyIn
from tests.functional.utils import load_event


Expand Down Expand Up @@ -114,3 +115,8 @@ def openapi31_schema():
data,
use_formats=False,
)


@pytest.fixture
def security_scheme():
return {"apiKey": APIKey(name="X-API-KEY", description="API Key", in_=APIKeyIn.header)}
111 changes: 89 additions & 22 deletions tests/functional/event_handler/test_openapi_security.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
import pytest

from aws_lambda_powertools.event_handler import APIGatewayRestResolver
from aws_lambda_powertools.event_handler.openapi.models import APIKey, APIKeyIn
from aws_lambda_powertools.event_handler.api_gateway import Router
from aws_lambda_powertools.event_handler.openapi.exceptions import SchemaValidationError


def test_openapi_top_level_security():
def test_openapi_top_level_security(security_scheme):
# GIVEN an APIGatewayRestResolver instance
app = APIGatewayRestResolver()

@app.get("/")
def handler():
raise NotImplementedError()

schema = app.get_openapi_schema(
security_schemes={
"apiKey": APIKey(name="X-API-KEY", description="API Key", in_=APIKeyIn.header),
},
security=[{"apiKey": []}],
)
# WHEN the get_openapi_schema method is called with a security scheme
schema = app.get_openapi_schema(security_schemes=security_scheme, security=[{"apiKey": []}])

# THEN the resulting schema should have security defined at the top level
security = schema.security
assert security is not None

Expand All @@ -26,37 +25,105 @@ def handler():


def test_openapi_top_level_security_missing():
# GIVEN an APIGatewayRestResolver instance
app = APIGatewayRestResolver()

@app.get("/")
def handler():
raise NotImplementedError()

with pytest.raises(ValueError):
# WHEN the get_openapi_schema method is called with security defined without security schemes
# THEN a SchemaValidationError should be raised
with pytest.raises(SchemaValidationError):
app.get_openapi_schema(
security=[{"apiKey": []}],
)


def test_openapi_operation_security():
def test_openapi_top_level_security_mismatch(security_scheme):
# GIVEN an APIGatewayRestResolver instance
app = APIGatewayRestResolver()

@app.get("/")
def handler():
raise NotImplementedError()

# WHEN the get_openapi_schema method is called with security defined security schemes as APIKey
# AND top level security is defined as HTTPBearer
# THEN a SchemaValidationError should be raised
with pytest.raises(SchemaValidationError):
app.get_openapi_schema(
security_schemes=security_scheme,
security=[{"HTTPBearer": []}],
)


def test_openapi_operation_level_security(security_scheme):
# GIVEN an APIGatewayRestResolver instance
app = APIGatewayRestResolver()

@app.get("/", security=[{"apiKey": []}])
def handler():
raise NotImplementedError()

schema = app.get_openapi_schema(
security_schemes={
"apiKey": APIKey(name="X-API-KEY", description="API Key", in_=APIKeyIn.header),
},
)
# WHEN the get_openapi_schema method is called with security defined at the operation level
schema = app.get_openapi_schema(security_schemes=security_scheme)

security = schema.security
assert security is None
# THEN the resulting schema should have security defined at the operation level, not the top level
top_level_security = schema.security
path_level_security = schema.paths["/"].get.security
assert top_level_security is None
assert path_level_security[0] == {"apiKey": []}

operation = schema.paths["/"].get
security = operation.security
assert security is not None

assert len(security) == 1
assert security[0] == {"apiKey": []}
def test_openapi_operation_level_security_missing():
# GIVEN an APIGatewayRestResolver instance
app = APIGatewayRestResolver()

# AND a route with a security scheme defined
@app.get("/", security=[{"apiKey": []}])
def handler():
raise NotImplementedError()

# WHEN the get_openapi_schema method is called without security schemes defined
# THEN a SchemaValidationError should be raised
with pytest.raises(SchemaValidationError):
app.get_openapi_schema()


def test_openapi_operation_level_security_mismatch(security_scheme):
# GIVEN an APIGatewayRestResolver instance
app = APIGatewayRestResolver()

# AND a route with a security scheme using HTTPBearer
@app.get("/", security=[{"HTTPBearer": []}])
def handler():
raise NotImplementedError()

# WHEN the get_openapi_schema method is called with security defined security schemes as APIKey
# THEN a SchemaValidationError should be raised
with pytest.raises(SchemaValidationError):
app.get_openapi_schema(
security_schemes=security_scheme,
)


def test_openapi_operation_level_security_with_router(security_scheme):
# GIVEN an APIGatewayRestResolver instance with a Router
app = APIGatewayRestResolver()
router = Router()

@router.get("/", security=[{"apiKey": []}])
def handler():
raise NotImplementedError()

app.include_router(router)

# WHEN the get_openapi_schema method is called with security defined at the operation level in the Router
schema = app.get_openapi_schema(security_schemes=security_scheme)

# THEN the resulting schema should have security defined at the operation level
top_level_security = schema.security
path_level_security = schema.paths["/"].get.security
assert top_level_security is None
assert path_level_security[0] == {"apiKey": []}

0 comments on commit e4c236b

Please sign in to comment.