Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(event_handler): security scheme unhashable list when working with router #4421

Merged
6 changes: 4 additions & 2 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
validation_error_definition,
validation_error_response_definition,
)
from aws_lambda_powertools.event_handler.util import _FrozenDict, extract_origin_header
from aws_lambda_powertools.event_handler.util import _FrozenDict, _FrozenListDict, extract_origin_header
from aws_lambda_powertools.shared.cookies import Cookie
from aws_lambda_powertools.shared.functions import powertools_dev_is_set
from aws_lambda_powertools.shared.json_encoder import Encoder
Expand Down Expand Up @@ -703,6 +703,7 @@ def _openapi_operation_parameters(
from aws_lambda_powertools.event_handler.openapi.params import Param

parameters = []
parameter: Dict[str, Any]
for param in all_route_params:
field_info = param.field_info
field_info = cast(Param, field_info)
Expand Down Expand Up @@ -2386,6 +2387,7 @@ def register_route(func: Callable):
methods = (method,) if isinstance(method, str) else tuple(method)
frozen_responses = _FrozenDict(responses) if responses else None
frozen_tags = frozenset(tags) if tags else None
frozen_security = _FrozenListDict(security) if security else None
heitorlessa marked this conversation as resolved.
Show resolved Hide resolved

route_key = (
rule,
Expand All @@ -2400,7 +2402,7 @@ def register_route(func: Callable):
frozen_tags,
operation_id,
include_in_schema,
security,
frozen_security,
)

# Collate Middleware for routes
Expand Down
23 changes: 22 additions & 1 deletion aws_lambda_powertools/event_handler/util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, List

from aws_lambda_powertools.utilities.data_classes.shared_functions import get_header_value

Expand All @@ -18,6 +18,27 @@ def __hash__(self):
return hash(frozenset(self.keys()))


class _FrozenListDict(List[Dict[str, List[str]]]):
"""
Freezes a list of dictionaries containing lists of strings.

This function takes a list of dictionaries where the values are lists of strings and converts it into
a frozen set of frozen sets of frozen dictionaries. This is done by iterating over the input list,
converting each dictionary's values (lists of strings) into frozen sets of strings, and then
converting the resulting dictionary into a frozen dictionary. Finally, all these frozen dictionaries
are collected into a frozen set of frozen sets.

This operation is useful when you want to ensure the immutability of the data structure and make it
hashable, which is required for certain operations like using it as a key in a dictionary or as an
element in a set.

Example: [{"TestAuth": ["test", "test1"]}]
"""

def __hash__(self):
return hash(frozenset({_FrozenDict({key: frozenset(self) for key, self in item.items()}) for item in self}))
heitorlessa marked this conversation as resolved.
Show resolved Hide resolved


def extract_origin_header(resolver_headers: Dict[str, Any]):
"""
Extracts the 'origin' or 'Origin' header from the provided resolver headers.
Expand Down
44 changes: 37 additions & 7 deletions tests/functional/event_handler/test_openapi_security.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
import pytest

from aws_lambda_powertools.event_handler import APIGatewayRestResolver
from aws_lambda_powertools.event_handler.api_gateway import Router
from aws_lambda_powertools.event_handler.openapi.models import APIKey, APIKeyIn


def test_openapi_top_level_security():
# GIVEN an APIGatewayRestResolver instance
app = APIGatewayRestResolver()

@app.get("/")
def handler():
raise NotImplementedError()

# WHEN the get_openapi_schema method is called with a security scheme
schema = app.get_openapi_schema(
security_schemes={
"apiKey": APIKey(name="X-API-KEY", description="API Key", in_=APIKeyIn.header),
},
security=[{"apiKey": []}],
heitorlessa marked this conversation as resolved.
Show resolved Hide resolved
)

# THEN the resulting schema should have security defined at the top level
security = schema.security
assert security is not None

Expand All @@ -26,37 +30,63 @@ def handler():


def test_openapi_top_level_security_missing():
# GIVEN an APIGatewayRestResolver instance
app = APIGatewayRestResolver()

@app.get("/")
def handler():
raise NotImplementedError()

# WHEN the get_openapi_schema method is called with security defined without security schemes
# THEN a ValueError should be raised
with pytest.raises(ValueError):
app.get_openapi_schema(
security=[{"apiKey": []}],
)


def test_openapi_operation_security():
# GIVEN an APIGatewayRestResolver instance
app = APIGatewayRestResolver()

@app.get("/", security=[{"apiKey": []}])
def handler():
raise NotImplementedError()

# WHEN the get_openapi_schema method is called with security defined at the operation level
schema = app.get_openapi_schema(
security_schemes={
"apiKey": APIKey(name="X-API-KEY", description="API Key", in_=APIKeyIn.header),
},
heitorlessa marked this conversation as resolved.
Show resolved Hide resolved
)

security = schema.security
assert security is None
# THEN the resulting schema should have security defined at the operation level, not the top level
top_level_security = schema.security
path_level_security = schema.paths["/"].get.security
assert top_level_security is None
assert path_level_security[0] == {"apiKey": []}

operation = schema.paths["/"].get
security = operation.security
assert security is not None

assert len(security) == 1
assert security[0] == {"apiKey": []}
def test_openapi_operation_security_with_router():
# GIVEN an APIGatewayRestResolver instance with a Router
app = APIGatewayRestResolver()
router = Router()

@router.get("/", security=[{"apiKey": []}])
def handler():
raise NotImplementedError()

app.include_router(router)

# WHEN the get_openapi_schema method is called with security defined at the operation level in the Router
schema = app.get_openapi_schema(
security_schemes={
"apiKey": APIKey(name="X-API-KEY", description="API Key", in_=APIKeyIn.header),
heitorlessa marked this conversation as resolved.
Show resolved Hide resolved
},
)

# THEN the resulting schema should have security defined at the operation level
top_level_security = schema.security
path_level_security = schema.paths["/"].get.security
assert top_level_security is None
assert path_level_security[0] == {"apiKey": []}