Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(event_handler): support Header parameter validation in OpenAPI schema #3687

Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,22 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
query_string,
)

# Normalize header values before validate this
headers = _normalize_multi_header_values_with_param(
app.current_event.resolved_headers_field,
route.dependant.header_params,
)

# Process header values
header_values, header_errors = _request_params_to_args(
route.dependant.header_params,
headers,
)

values.update(path_values)
values.update(query_values)
errors += path_errors + query_errors
values.update(header_values)
errors += path_errors + query_errors + header_errors

# Process the request body, if it exists
if route.dependant.body_params:
Expand Down Expand Up @@ -243,12 +256,14 @@ def _request_params_to_args(
errors = []

for field in required_params:
value = received_params.get(field.alias)

field_info = field.field_info

# To ensure early failure, we check if it's not an instance of Param.
if not isinstance(field_info, Param):
raise AssertionError(f"Expected Param field_info, got {field_info}")

value = received_params.get(field.alias)

loc = (field_info.in_.value, field.alias)

# If we don't have a value, see if it's required or has a default
Expand Down Expand Up @@ -377,3 +392,30 @@ def _normalize_multi_query_string_with_param(query_string: Optional[Dict[str, st
except KeyError:
pass
return query_string


def _normalize_multi_header_values_with_param(headers: Optional[Dict[str, str]], params: Sequence[ModelField]):
"""
Extract and normalize resolved_headers_field

Parameters
----------
headers: Dict
A dictionary containing the initial header parameters.
params: Sequence[ModelField]
A sequence of ModelField objects representing parameters.

Returns
-------
A dictionary containing the processed headers.
"""
if headers:
for param in filter(is_scalar_field, params):
try:
if len(headers[param.alias]) == 1:
# if the target parameter is a scalar and the list contains only 1 element
# we keep the first value of the headers regardless if there are more in the payload
headers[param.alias] = headers[param.alias][0]
except KeyError:
pass
return headers
27 changes: 16 additions & 11 deletions aws_lambda_powertools/event_handler/openapi/dependant.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
from aws_lambda_powertools.event_handler.openapi.params import (
Body,
Dependant,
Header,
Param,
ParamTypes,
Query,
_File,
_Form,
_Header,
analyze_param,
create_response_field,
get_flat_dependant,
Expand Down Expand Up @@ -59,16 +59,21 @@ def add_param_to_fields(

"""
field_info = cast(Param, field.field_info)
if field_info.in_ == ParamTypes.path:
dependant.path_params.append(field)
elif field_info.in_ == ParamTypes.query:
dependant.query_params.append(field)
elif field_info.in_ == ParamTypes.header:
dependant.header_params.append(field)

# Dictionary to map ParamTypes to their corresponding lists in dependant
param_type_map = {
ParamTypes.path: dependant.path_params,
ParamTypes.query: dependant.query_params,
ParamTypes.header: dependant.header_params,
ParamTypes.cookie: dependant.cookie_params,
}

# Check if field_info.in_ is a valid key in param_type_map and append the field to the corresponding list
# or raise an exception if it's not a valid key.
if field_info.in_ in param_type_map:
param_type_map[field_info.in_].append(field)
else:
if field_info.in_ != ParamTypes.cookie:
raise AssertionError(f"Unsupported param type: {field_info.in_}")
dependant.cookie_params.append(field)
raise AssertionError(f"Unsupported param type: {field_info.in_}")


def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
Expand Down Expand Up @@ -265,7 +270,7 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
return False
elif is_scalar_field(field=param_field):
return False
elif isinstance(param_field.field_info, (Query, _Header)) and is_scalar_sequence_field(param_field):
elif isinstance(param_field.field_info, (Query, Header)) and is_scalar_sequence_field(param_field):
return False
else:
if not isinstance(param_field.field_info, Body):
Expand Down
79 changes: 77 additions & 2 deletions aws_lambda_powertools/event_handler/openapi/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ def __init__(
)


class _Header(Param):
class Header(Param):
"""
A class used internally to represent a header parameter in a path operation.
"""
Expand Down Expand Up @@ -527,12 +527,75 @@ def __init__(
json_schema_extra: Union[Dict[str, Any], None] = None,
**extra: Any,
):
"""
Constructs a new Query param.

Parameters
----------
default: Any
The default value of the parameter
default_factory: Callable[[], Any], optional
Callable that will be called when a default value is needed for this field
annotation: Any, optional
The type annotation of the parameter
alias: str, optional
The public name of the field
alias_priority: int, optional
Priority of the alias. This affects whether an alias generator is used
validation_alias: str | AliasPath | AliasChoices | None, optional
Alias to be used for validation only
serialization_alias: str | AliasPath | AliasChoices | None, optional
Alias to be used for serialization only
convert_underscores: bool
If true convert "_" to "-"
See RFC: https://www.rfc-editor.org/rfc/rfc9110.html#name-field-name-registry
title: str, optional
The title of the parameter
description: str, optional
The description of the parameter
gt: float, optional
Only applies to numbers, required the field to be "greater than"
ge: float, optional
Only applies to numbers, required the field to be "greater than or equal"
lt: float, optional
Only applies to numbers, required the field to be "less than"
le: float, optional
Only applies to numbers, required the field to be "less than or equal"
min_length: int, optional
Only applies to strings, required the field to have a minimum length
max_length: int, optional
Only applies to strings, required the field to have a maximum length
pattern: str, optional
Only applies to strings, requires the field match against a regular expression pattern string
discriminator: str, optional
Parameter field name for discriminating the type in a tagged union
strict: bool, optional
Enables Pydantic's strict mode for the field
multiple_of: float, optional
Only applies to numbers, requires the field to be a multiple of the given value
allow_inf_nan: bool, optional
Only applies to numbers, requires the field to allow infinity and NaN values
max_digits: int, optional
Only applies to Decimals, requires the field to have a maxmium number of digits within the decimal.
decimal_places: int, optional
Only applies to Decimals, requires the field to have at most a number of decimal places
examples: List[Any], optional
A list of examples for the parameter
deprecated: bool, optional
If `True`, the parameter will be marked as deprecated
include_in_schema: bool, optional
If `False`, the parameter will be excluded from the generated OpenAPI schema
json_schema_extra: Dict[str, Any], optional
Extra values to include in the generated OpenAPI schema
"""
self.convert_underscores = convert_underscores
self._alias = alias

super().__init__(
default=default,
default_factory=default_factory,
annotation=annotation,
alias=alias,
alias=self._alias,
alias_priority=alias_priority,
validation_alias=validation_alias,
serialization_alias=serialization_alias,
Expand All @@ -558,6 +621,18 @@ def __init__(
**extra,
)

@property
def alias(self):
return self._alias

@alias.setter
def alias(self, value: Optional[str] = None):
if value is not None:
# Headers are case-insensitive according to RFC 7540 (HTTP/2), so we lower the parameter name
# This ensures that customers can access headers with any casing, as per the RFC guidelines.
# Reference: https://www.rfc-editor.org/rfc/rfc7540#section-8.1.2
self._alias = value.lower()


class Body(FieldInfo):
"""
Expand Down
11 changes: 11 additions & 0 deletions aws_lambda_powertools/utilities/data_classes/alb_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,17 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]:

return self.query_string_parameters

@property
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
headers: Dict[str, Any] = {}

if self.multi_value_headers:
headers = self.multi_value_headers
else:
headers = self.headers

return {key.lower(): value for key, value in headers.items()}

@property
def multi_value_headers(self) -> Optional[Dict[str, List[str]]]:
return self.get("multiValueHeaders")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,17 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]:

return self.query_string_parameters

@property
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
headers: Dict[str, Any] = {}

if self.multi_value_headers:
headers = self.multi_value_headers
else:
headers = self.headers

return {key.lower(): value for key, value in headers.items()}

@property
def request_context(self) -> APIGatewayEventRequestContext:
return APIGatewayEventRequestContext(self._data)
Expand Down Expand Up @@ -316,3 +327,11 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]:
return query_string

return {}

@property
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
if self.headers is not None:
headers = {key.lower(): value.split(",") if "," in value else value for key, value in self.headers.items()}
return headers

return {}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent, DictWrapper

Expand Down Expand Up @@ -112,3 +112,7 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]:
@property
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
return self.query_string_parameters

@property
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
return {}
15 changes: 15 additions & 0 deletions aws_lambda_powertools/utilities/data_classes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,21 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
"""
return self.query_string_parameters

@property
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
"""
This property determines the appropriate header to be used
as a trusted source for validating OpenAPI.

This is necessary because different resolvers use different formats to encode
headers parameters.

Headers are case-insensitive according to RFC 7540 (HTTP/2), so we lower the header name
This ensures that customers can access headers with any casing, as per the RFC guidelines.
Reference: https://www.rfc-editor.org/rfc/rfc7540#section-8.1.2
"""
return self.headers

@property
def is_base64_encoded(self) -> Optional[bool]:
return self.get("isBase64Encoded")
Expand Down
15 changes: 15 additions & 0 deletions aws_lambda_powertools/utilities/data_classes/vpc_lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,14 @@ def query_string_parameters(self) -> Dict[str, str]:
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
return self.query_string_parameters

@property
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
if self.headers is not None:
headers = {key.lower(): value.split(",") if "," in value else value for key, value in self.headers.items()}
return headers

return {}


class vpcLatticeEventV2Identity(DictWrapper):
@property
Expand Down Expand Up @@ -259,3 +267,10 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]:
@property
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
return self.query_string_parameters

@property
def resolved_headers_field(self) -> Optional[Dict[str, str]]:
if self.headers is not None:
return {key.lower(): value for key, value in self.headers.items()}

return {}
34 changes: 34 additions & 0 deletions docs/core/event_handler/api_gateway.md
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,40 @@ For example, we could validate that `<todo_id>` dynamic path should be no greate

1. `Path` is a special OpenAPI type that allows us to constrain todo_id to be less than 999.

#### Validating headers

We use the `Annotated` type to tell Event Handler that a particular parameter is a header that needs to be validated.
leandrodamascena marked this conversation as resolved.
Show resolved Hide resolved

!!! info "We adhere to HTTP RFC standards, which means we treat HTTP headers as case-insensitive."
leandrodamascena marked this conversation as resolved.
Show resolved Hide resolved

In the following example, we use a new `Header` OpenAPI type to add [one out of many possible constraints](#customizing-openapi-parameters), which should read as:

* `correlation_id` is a header that must be present in the request
* `correlation_id`, when set, should have 16 characters
leandrodamascena marked this conversation as resolved.
Show resolved Hide resolved
* Doesn't match? Event Handler will return a validation error response
leandrodamascena marked this conversation as resolved.
Show resolved Hide resolved

<!-- markdownlint-disable MD013 -->

=== "validating_headers.py"

```python hl_lines="8 10 29"
--8<-- "examples/event_handler_rest/src/validating_headers.py"
```

1. If you're not using Python 3.9 or higher, you can install and use [`typing_extensions`](https://pypi.org/project/typing-extensions/){target="_blank" rel="nofollow"} to the same effect
2. `Header` is a special OpenAPI type that can add constraints to a header well as document them
leandrodamascena marked this conversation as resolved.
Show resolved Hide resolved
3. **First time seeing the `Annotated`?** <br><br> This special type uses the first argument as the actual type, and subsequent arguments are metadata. <br><br> At runtime, static checkers will also see the first argument, but anyone receiving them could inspect them to fetch their metadata.
leandrodamascena marked this conversation as resolved.
Show resolved Hide resolved

=== "working_with_headers_multi_value.py"

If you need to handle multi-value for specific headers, you can create a list of the desired type.
leandrodamascena marked this conversation as resolved.
Show resolved Hide resolved

```python hl_lines="23"
--8<-- "examples/event_handler_rest/src/working_with_headers_multi_value.py"
```

1. `cloudfront_viewer_country` is a list that must contain values from the `CountriesAllowed` enumeration.

### Accessing request details

Event Handler integrates with [Event Source Data Classes utilities](../../utilities/data_classes.md){target="_blank"}, and it exposes their respective resolver request details and convenient methods under `app.current_event`.
Expand Down
Loading
Loading