diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 22d7ba91bcc..84cfa4ea503 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -92,7 +92,7 @@ Server, Tag, ) - from aws_lambda_powertools.event_handler.openapi.params import Dependant + from aws_lambda_powertools.event_handler.openapi.params import Dependant, Param from aws_lambda_powertools.event_handler.openapi.swagger_ui.oauth2 import ( OAuth2Config, ) @@ -818,46 +818,123 @@ def _openapi_operation_parameters( """ Returns the OpenAPI operation parameters. """ - from aws_lambda_powertools.event_handler.openapi.compat import ( - get_schema_from_model_field, - ) from aws_lambda_powertools.event_handler.openapi.params import Param - parameters = [] - parameter: dict[str, Any] = {} + parameters: list[dict[str, Any]] = [] for param in all_route_params: - field_info = param.field_info - field_info = cast(Param, field_info) + field_info = cast(Param, param.field_info) if not field_info.include_in_schema: continue - param_schema = get_schema_from_model_field( - field=param, - model_name_map=model_name_map, - field_mapping=field_mapping, - ) + # Check if this is a Pydantic model that should be expanded + if Route._is_pydantic_model_param(field_info): + parameters.extend(Route._expand_pydantic_model_parameters(field_info)) + else: + parameters.append(Route._create_regular_parameter(param, model_name_map, field_mapping)) - parameter = { - "name": param.alias, - "in": field_info.in_.value, - "required": param.required, - "schema": param_schema, - } + return parameters - if field_info.description: - parameter["description"] = field_info.description + @staticmethod + def _is_pydantic_model_param(field_info: Param) -> bool: + """Check if the field info represents a Pydantic model parameter.""" + from pydantic import BaseModel - if field_info.openapi_examples: - parameter["examples"] = 
field_info.openapi_examples + from aws_lambda_powertools.event_handler.openapi.compat import lenient_issubclass - if field_info.deprecated: - parameter["deprecated"] = field_info.deprecated + return lenient_issubclass(field_info.annotation, BaseModel) - parameters.append(parameter) + @staticmethod + def _expand_pydantic_model_parameters(field_info: Param) -> list[dict[str, Any]]: + """Expand a Pydantic model into individual OpenAPI parameters.""" + from pydantic import BaseModel + + model_class = cast(type[BaseModel], field_info.annotation) + parameters: list[dict[str, Any]] = [] + + for field_name, field_def in model_class.model_fields.items(): + param_name = field_def.alias or field_name + individual_param = Route._create_pydantic_field_parameter( + param_name=param_name, + field_def=field_def, + param_location=field_info.in_.value, + ) + parameters.append(individual_param) return parameters + @staticmethod + def _create_pydantic_field_parameter( + param_name: str, + field_def: Any, + param_location: str, + ) -> dict[str, Any]: + """Create an OpenAPI parameter from a Pydantic field definition.""" + individual_param: dict[str, Any] = { + "name": param_name, + "in": param_location, + "required": field_def.is_required() if hasattr(field_def, "is_required") else field_def.default is ..., + "schema": Route._get_basic_type_schema(field_def.annotation or type(None)), + } + + if field_def.description: + individual_param["description"] = field_def.description + + return individual_param + + @staticmethod + def _create_regular_parameter( + param: ModelField, + model_name_map: dict[TypeModelOrEnum, str], + field_mapping: dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue], + ) -> dict[str, Any]: + """Create an OpenAPI parameter from a regular ModelField.""" + from aws_lambda_powertools.event_handler.openapi.compat import get_schema_from_model_field + from aws_lambda_powertools.event_handler.openapi.params import Param + + field_info = 
cast(Param, param.field_info) + param_schema = get_schema_from_model_field( + field=param, + model_name_map=model_name_map, + field_mapping=field_mapping, + ) + + parameter: dict[str, Any] = { + "name": param.alias, + "in": field_info.in_.value, + "required": param.required, + "schema": param_schema, + } + + # Add optional attributes if present + if field_info.description: + parameter["description"] = field_info.description + if field_info.openapi_examples: + parameter["examples"] = field_info.openapi_examples + if field_info.deprecated: + parameter["deprecated"] = field_info.deprecated + + return parameter + + @staticmethod + def _get_basic_type_schema(param_type: type) -> dict[str, str]: + """ + Get basic OpenAPI schema for simple types + """ + try: + # Check bool before int, since bool is a subclass of int in Python + if issubclass(param_type, bool): + return {"type": "boolean"} + elif issubclass(param_type, int): + return {"type": "integer"} + elif issubclass(param_type, float): + return {"type": "number"} + else: + return {"type": "string"} + except TypeError: + # param_type may not be a type (e.g., typing.Optional[int]), fallback to string + return {"type": "string"} + @staticmethod def _openapi_operation_return( *, diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 6b1f37ae8a4..db9c73d7b39 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -3,7 +3,7 @@ import dataclasses import json import logging -from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence +from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence, cast from urllib.parse import parse_qs from pydantic import BaseModel @@ -13,8 +13,9 @@ _model_dump, _normalize_errors, _regenerate_error_with_loc, + field_annotation_is_sequence, 
get_missing_field_error, - is_sequence_field, + lenient_issubclass, ) from aws_lambda_powertools.event_handler.openapi.dependant import is_scalar_field from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder @@ -22,6 +23,8 @@ from aws_lambda_powertools.event_handler.openapi.params import Param if TYPE_CHECKING: + from pydantic.fields import FieldInfo + from aws_lambda_powertools.event_handler import Response from aws_lambda_powertools.event_handler.api_gateway import Route from aws_lambda_powertools.event_handler.middlewares import NextMiddleware @@ -64,7 +67,7 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> ) # Normalize query values before validate this - query_string = _normalize_multi_query_string_with_param( + query_string = _normalize_multi_params( app.current_event.resolved_query_string_parameters, route.dependant.query_params, ) @@ -76,7 +79,7 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> ) # Normalize header values before validate this - headers = _normalize_multi_header_values_with_param( + headers = _normalize_multi_params( app.current_event.resolved_headers_field, route.dependant.header_params, ) @@ -366,7 +369,7 @@ def _request_body_to_args( _handle_missing_field_value(field, values, errors, loc) continue - value = _normalize_field_value(field, value) + value = _normalize_field_value(value=value, field_info=field.field_info) values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors) return values, errors @@ -409,10 +412,13 @@ def _handle_missing_field_value( values[field.name] = field.get_default() -def _normalize_field_value(field: ModelField, value: Any) -> Any: +def _normalize_field_value(value: Any, field_info: FieldInfo) -> Any: """Normalize field value, converting lists to single values for non-sequence fields.""" - if isinstance(value, list) and not is_sequence_field(field): + if 
field_annotation_is_sequence(field_info.annotation): + return value + elif isinstance(value, list) and value: return value[0] + return value @@ -454,57 +460,70 @@ def _get_embed_body( return received_body, field_alias_omitted -def _normalize_multi_query_string_with_param( - query_string: dict[str, list[str]], +def _normalize_multi_params( + input_dict: MutableMapping[str, Any], params: Sequence[ModelField], -) -> dict[str, Any]: +) -> MutableMapping[str, Any]: """ - Extract and normalize resolved_query_string_parameters + Extract and normalize query string or header parameters with Pydantic model support. Parameters ---------- - query_string: dict - A dictionary containing the initial query string parameters. + input_dict: MutableMapping[str, Any] + A dictionary containing the initial query string or header parameters. params: Sequence[ModelField] A sequence of ModelField objects representing parameters. Returns ------- - A dictionary containing the processed multi_query_string_parameters. + MutableMapping[str, Any] + A dictionary containing the processed parameters with normalized values. 
""" - resolved_query_string: dict[str, Any] = query_string - for param in filter(is_scalar_field, params): - try: - # if the target parameter is a scalar, we keep the first value of the query string - # regardless if there are more in the payload - resolved_query_string[param.alias] = query_string[param.alias][0] - except KeyError: - pass - return resolved_query_string + for param in params: + if is_scalar_field(param): + _process_scalar_param(input_dict, param) + elif lenient_issubclass(param.field_info.annotation, BaseModel): + _process_model_param(input_dict, param) + return input_dict -def _normalize_multi_header_values_with_param(headers: MutableMapping[str, Any], params: Sequence[ModelField]): - """ - Extract and normalize resolved_headers_field +def _process_scalar_param(input_dict: MutableMapping[str, Any], param: ModelField) -> None: + """Process a scalar parameter by normalizing single-item lists.""" + try: + value = input_dict[param.alias] + if isinstance(value, list) and len(value) == 1: + input_dict[param.alias] = value[0] + except KeyError: + pass - Parameters - ---------- - headers: MutableMapping[str, Any] - A dictionary containing the initial header parameters. - params: Sequence[ModelField] - A sequence of ModelField objects representing parameters. - Returns - ------- - A dictionary containing the processed headers. 
- """ - if headers: - for param in filter(is_scalar_field, params): - try: - if len(headers[param.alias]) == 1: - # if the target parameter is a scalar and the list contains only 1 element - # we keep the first value of the headers regardless if there are more in the payload - headers[param.alias] = headers[param.alias][0] - except KeyError: - pass - return headers +def _process_model_param(input_dict: MutableMapping[str, Any], param: ModelField) -> None: + """Process a Pydantic model parameter by extracting model fields.""" + model_class = cast(type[BaseModel], param.field_info.annotation) + + model_data = {} + for field_name, field_info in model_class.model_fields.items(): + field_alias = field_info.alias or field_name + value = _get_param_value(input_dict, field_alias, field_name, model_class) + + if value is not None: + model_data[field_alias] = _normalize_field_value(value=value, field_info=field_info) + + input_dict[param.alias] = model_data + + +def _get_param_value( + input_dict: MutableMapping[str, Any], + field_alias: str, + field_name: str, + model_class: type[BaseModel], +) -> Any: + """Get parameter value, checking both alias and field name if needed.""" + value = input_dict.get(field_alias) + if value is not None: + return value + + if model_class.model_config.get("validate_by_name") or model_class.model_config.get("populate_by_name"): + value = input_dict.get(field_name) + + return value diff --git a/aws_lambda_powertools/event_handler/openapi/dependant.py b/aws_lambda_powertools/event_handler/openapi/dependant.py index 98a8740a74f..310cab68e66 100644 --- a/aws_lambda_powertools/event_handler/openapi/dependant.py +++ b/aws_lambda_powertools/event_handler/openapi/dependant.py @@ -9,16 +9,13 @@ create_body_model, evaluate_forwardref, is_scalar_field, - is_scalar_sequence_field, ) from aws_lambda_powertools.event_handler.openapi.params import ( Body, Dependant, Form, - Header, Param, ParamTypes, - Query, _File, analyze_param, create_response_field, @@ 
-275,7 +272,7 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool: return False elif is_scalar_field(field=param_field): return False - elif isinstance(param_field.field_info, (Query, Header)) and is_scalar_sequence_field(param_field): + elif isinstance(param_field.field_info, Param): return False else: if not isinstance(param_field.field_info, Body): diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 7b700e0b948..b97bf690109 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -4,7 +4,7 @@ from enum import Enum from typing import TYPE_CHECKING, Any, Literal -from pydantic import BaseConfig +from pydantic import BaseConfig, BaseModel, create_model from pydantic.fields import FieldInfo from typing_extensions import Annotated, get_args, get_origin @@ -17,6 +17,7 @@ copy_field_info, field_annotation_is_scalar, get_annotation_from_field_info, + lenient_issubclass, ) if TYPE_CHECKING: @@ -1160,6 +1161,42 @@ def create_response_field( return ModelField(**kwargs) # type: ignore[arg-type] +def _apply_header_underscore_conversion( + field_info: FieldInfo, + type_annotation: Any, + param_name: str, +) -> tuple[FieldInfo, Any]: + """ + Apply underscore-to-dash conversion for Header parameters. + + For BaseModel: Creates new model with underscore-to-dash alias generator. + Note: If the BaseModel already has an alias generator, it will be replaced + with dash-case conversion since HTTP headers should use dash-case. 
+ For all Header fields: Sets the parameter alias if convert_underscores is True + """ + if not isinstance(field_info, Header) or not field_info.convert_underscores: + return field_info, type_annotation + + # Always set the parameter alias for Header fields (if not already set) + if not field_info.alias: + field_info.alias = param_name.replace("_", "-") + + # Handle BaseModel case - create new model with dash-case alias generator + if lenient_issubclass(type_annotation, BaseModel): + # For HTTP headers, we should use dash-case regardless of existing alias generator + # This ensures consistent header naming conventions + header_aliased_model = create_model( + f"{type_annotation.__name__}WithHeaderAliases", + __base__=type_annotation, + __config__={"alias_generator": lambda name: name.replace("_", "-")}, + ) + + type_annotation = header_aliased_model + field_info.annotation = type_annotation + + return field_info, type_annotation + + def _create_model_field( field_info: FieldInfo | None, type_annotation: Any, @@ -1178,21 +1215,17 @@ def _create_model_field( elif isinstance(field_info, Param) and getattr(field_info, "in_", None) is None: field_info.in_ = ParamTypes.query + # Apply header underscore conversion + field_info, type_annotation = _apply_header_underscore_conversion(field_info, type_annotation, param_name) + # If the field_info is a Param, we use the `in_` attribute to determine the type annotation use_annotation = get_annotation_from_field_info(type_annotation, field_info, param_name) - # If the field doesn't have a defined alias, we use the param name - if not field_info.alias and getattr(field_info, "convert_underscores", None): - alias = param_name.replace("_", "-") - else: - alias = field_info.alias or param_name - field_info.alias = alias - return create_response_field( name=param_name, type_=use_annotation, default=field_info.default, - alias=alias, + alias=field_info.alias, required=field_info.default in (Required, Undefined), field_info=field_info, 
) diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index 5b0e1d7b4a4..f0a09994b35 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -11,6 +11,7 @@ Event handler for Amazon API Gateway REST and HTTP APIs, Application Load Balanc * Support for CORS, binary and Gzip compression, Decimals JSON encoding and bring your own JSON serializer * Built-in integration with [Event Source Data Classes utilities](../../utilities/data_classes.md){target="_blank"} for self-documented event schema * Works with micro function (one or a few routes) and monolithic functions (all routes) +* Support for Middleware * Support for OpenAPI schema generation * Support data validation for requests/responses @@ -303,7 +304,7 @@ If you need to accept multiple HTTP methods in a single function, or support a H For brevity, we'll focus on Pydantic only. -All resolvers can optionally coerce and validate incoming requests by setting `enable_validation=True`. +All resolvers can optionally coerce and validate incoming requests by setting `enable_validation=True`. You can use Pydantic models to validate request bodies, query strings, headers, and form parameters. With this feature, we can now express how we expect our incoming data and response to look like. This moves data validation responsibilities to Event Handler resolvers, reducing a ton of boilerplate code. @@ -519,6 +520,16 @@ In the following example, we use a new `Query` OpenAPI type to add [one out of m 1. `example_multi_value_param` is a list containing values from the `ExampleEnum` enumeration. +=== "validating_query_string_with_pydantic.py" + + You can use Pydantic models to define your query string parameters. + + ```python hl_lines="17-21 26" + --8<-- "examples/event_handler_rest/src/validating_query_string_with_pydantic.py" + ``` + + 1. `todo` is a Pydantic model. 
+ #### Validating path parameters @@ -567,6 +578,16 @@ In the following example, we use a new `Header` OpenAPI type to add [one out of 1. `cloudfront_viewer_country` is a list that must contain values from the `CountriesAllowed` enumeration. +=== "validating_headers_with_pydantic.py" + + You can use Pydantic models to define your header parameters. + + ```python hl_lines="17-21 26" + --8<-- "examples/event_handler_rest/src/validating_headers_with_pydantic.py" + ``` + + 1. `todo` is a Pydantic model. + #### Handling form data !!! info "You must set `enable_validation=True` to handle file uploads and form data via type annotation." diff --git a/examples/event_handler_rest/src/validating_headers_with_pydantic.py b/examples/event_handler_rest/src/validating_headers_with_pydantic.py new file mode 100644 index 00000000000..03f6c56e9fd --- /dev/null +++ b/examples/event_handler_rest/src/validating_headers_with_pydantic.py @@ -0,0 +1,33 @@ +from typing import Any, Dict, Optional + +from pydantic import BaseModel, Field +from typing_extensions import Annotated + +from aws_lambda_powertools import Logger, Tracer +from aws_lambda_powertools.event_handler import APIGatewayRestResolver +from aws_lambda_powertools.event_handler.openapi.params import Header +from aws_lambda_powertools.logging import correlation_paths +from aws_lambda_powertools.utilities.typing import LambdaContext + +tracer = Tracer() +logger = Logger() +app = APIGatewayRestResolver(enable_validation=True) + + +class Todo(BaseModel): + userId: int + id_: Optional[int] = Field(alias="id", default=None) + title: str + completed: bool + + +@app.get("/todos") +@tracer.capture_method +def get_todos(todo: Annotated[Todo, Header()]) -> Dict[str, Any]: # (1)! 
+ return todo.model_dump() + + +@logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_HTTP) +@tracer.capture_lambda_handler +def lambda_handler(event: dict, context: LambdaContext) -> dict: + return app.resolve(event, context) diff --git a/examples/event_handler_rest/src/validating_query_string_with_pydantic.py b/examples/event_handler_rest/src/validating_query_string_with_pydantic.py new file mode 100644 index 00000000000..75f6212e1a6 --- /dev/null +++ b/examples/event_handler_rest/src/validating_query_string_with_pydantic.py @@ -0,0 +1,33 @@ +from typing import Any, Dict, Optional + +from pydantic import BaseModel, Field +from typing_extensions import Annotated + +from aws_lambda_powertools import Logger, Tracer +from aws_lambda_powertools.event_handler import APIGatewayRestResolver +from aws_lambda_powertools.event_handler.openapi.params import Query +from aws_lambda_powertools.logging import correlation_paths +from aws_lambda_powertools.utilities.typing import LambdaContext + +tracer = Tracer() +logger = Logger() +app = APIGatewayRestResolver(enable_validation=True) + + +class Todo(BaseModel): + userId: int + id_: Optional[int] = Field(alias="id", default=None) + title: str + completed: bool + + +@app.get("/todos") +@tracer.capture_method +def get_todos(todo: Annotated[Todo, Query()]) -> Dict[str, Any]: # (1)! 
+ return todo.model_dump() + + +@logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_HTTP) +@tracer.capture_lambda_handler +def lambda_handler(event: dict, context: LambdaContext) -> dict: + return app.resolve(event, context) diff --git a/tests/functional/event_handler/_pydantic/test_openapi_params.py b/tests/functional/event_handler/_pydantic/test_openapi_params.py index 19b5287d66a..f797cd541a5 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_params.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_params.py @@ -25,6 +25,145 @@ JSON_CONTENT_TYPE = "application/json" +def test_openapi_pydantic_query_params(): + """Test that Pydantic models in Query parameters are expanded into individual fields in OpenAPI schema""" + app = APIGatewayRestResolver() + + class QueryParams(BaseModel): + limit: int = Field(default=10, ge=1, le=100, description="Number of items to return") + offset: int = Field(default=0, ge=0, description="Number of items to skip") + search: Optional[str] = Field(default=None, description="Search term") + + @app.get("/search") + def search_handler(params: Annotated[QueryParams, Query()]): + return {"message": "success"} + + schema = app.get_openapi_schema() + + # Check that the path exists + assert "/search" in schema.paths + path = schema.paths["/search"] + assert path.get is not None + + # Check that parameters are expanded + get_operation = path.get + assert get_operation.parameters is not None + assert len(get_operation.parameters) == 3 + + # Check individual parameters + param_names = [param.name for param in get_operation.parameters] + assert "limit" in param_names + assert "offset" in param_names + assert "search" in param_names + + # Check parameter details + for param in get_operation.parameters: + assert param.in_ == ParameterInType.query + if param.name == "limit": + assert param.required is False # Has default value + assert param.description == "Number of items to return" + assert 
param.schema_.type == "integer" + elif param.name == "offset": + assert param.required is False # Has default value + assert param.description == "Number of items to skip" + assert param.schema_.type == "integer" + elif param.name == "search": + assert param.required is False # Optional field + assert param.description == "Search term" + assert param.schema_.type == "string" + + +def test_openapi_pydantic_header_params(): + """Test that Pydantic models in Header parameters are expanded into individual fields in OpenAPI schema""" + app = APIGatewayRestResolver() + + class HeaderParams(BaseModel): + authorization: str = Field(description="Authorization token") + user_agent: str = Field(default="PowerTools/1.0", description="User agent") + language: Optional[str] = Field(default=None, alias="accept-language", description="Language preference") + + @app.get("/protected") + def protected_handler(headers: Annotated[HeaderParams, Header()]): + return {"message": "success"} + + schema = app.get_openapi_schema() + + # Check that the path exists + assert "/protected" in schema.paths + path = schema.paths["/protected"] + assert path.get is not None + + # Check that parameters are expanded + get_operation = path.get + assert get_operation.parameters is not None + assert len(get_operation.parameters) == 3 + + # Check individual parameters + param_names = [param.name for param in get_operation.parameters] + assert "authorization" in param_names + assert "user-agent" in param_names # headers are always spinal-case + assert "accept-language" in param_names # Should use alias + + # Check parameter details + for param in get_operation.parameters: + assert param.in_ == ParameterInType.header + if param.name == "authorization": + assert param.required is True # No default value + assert param.description == "Authorization token" + assert param.schema_.type == "string" + elif param.name == "user-agent": + assert param.required is False # Has default value + assert param.description == 
"User agent" + assert param.schema_.type == "string" + elif param.name == "accept-language": + assert param.required is False # Optional field + assert param.description == "Language preference" + assert param.schema_.type == "string" + + +def test_openapi_pydantic_mixed_params(): + """Test that mixed Pydantic models (Query + Header) work together""" + app = APIGatewayRestResolver() + + class QueryParams(BaseModel): + q: str = Field(description="Search query") + limit: int = Field(default=10, description="Number of results") + + class HeaderParams(BaseModel): + authorization: str = Field(description="Bearer token") + + @app.get("/mixed") + def mixed_handler(query: Annotated[QueryParams, Query()], headers: Annotated[HeaderParams, Header()]): + return {"message": "success"} + + schema = app.get_openapi_schema() + + # Check that the path exists + assert "/mixed" in schema.paths + path = schema.paths["/mixed"] + assert path.get is not None + + # Check that all parameters are expanded + get_operation = path.get + assert get_operation.parameters is not None + assert len(get_operation.parameters) == 3 # 2 query + 1 header + + # Check parameter types + query_params = [p for p in get_operation.parameters if p.in_ == ParameterInType.query] + header_params = [p for p in get_operation.parameters if p.in_ == ParameterInType.header] + + assert len(query_params) == 2 + assert len(header_params) == 1 + + # Check specific parameters + query_names = [p.name for p in query_params] + assert "q" in query_names + assert "limit" in query_names + + header_names = [p.name for p in header_params] + assert "authorization" in header_names + + def test_openapi_no_params(): app = APIGatewayRestResolver() @@ -776,3 +915,132 @@ def form_edge_cases( assert "required_field" in component_schema.required assert "optional_field" not in component_schema.required # Optional assert "field_with_default" not in component_schema.required # Has default + + +def test_openapi_pydantic_query_with_constraints(): 
+ """Test that Pydantic field constraints are preserved in OpenAPI schema""" + app = APIGatewayRestResolver() + + class QueryParams(BaseModel): + limit: int = Field(ge=1, le=100, description="Number of items") + name: str = Field(min_length=1, max_length=50, description="Name filter") + + @app.get("/items") + def get_items(params: Annotated[QueryParams, Query()]): + return {"message": "success"} + + schema = app.get_openapi_schema() + path = schema.paths["/items"] + get_operation = path.get + + # Find the limit parameter + limit_param = next(p for p in get_operation.parameters if p.name == "limit") + assert limit_param.schema_.type == "integer" + assert limit_param.description == "Number of items" + + # Find the name parameter + name_param = next(p for p in get_operation.parameters if p.name == "name") + assert name_param.schema_.type == "string" + assert name_param.description == "Name filter" + + +def test_openapi_pydantic_header_with_alias(): + """Test that Pydantic field aliases work correctly in Header parameters""" + app = APIGatewayRestResolver() + + class HeaderParams(BaseModel): + content_type: str = Field(alias="content-type", description="Content type") + user_agent: str = Field(alias="user-agent", description="User agent") + + @app.get("/test") + def test_handler(headers: Annotated[HeaderParams, Header()]): + return {"message": "success"} + + schema = app.get_openapi_schema() + path = schema.paths["/test"] + get_operation = path.get + + # Check that aliases are used as parameter names + param_names = [param.name for param in get_operation.parameters] + assert "content-type" in param_names + assert "user-agent" in param_names + assert "content_type" not in param_names # Original field name should not be used + assert "user_agent" not in param_names + + +def test_openapi_pydantic_required_vs_optional(): + """Test that required vs optional fields are correctly identified""" + app = APIGatewayRestResolver() + + class QueryParams(BaseModel): + 
required_field: str = Field(description="Required field") + optional_with_default: str = Field(default="default", description="Optional with default") + optional_nullable: Optional[str] = Field(default=None, description="Optional nullable") + + @app.get("/test") + def test_handler(params: Annotated[QueryParams, Query()]): + return {"message": "success"} + + schema = app.get_openapi_schema() + path = schema.paths["/test"] + get_operation = path.get + + for param in get_operation.parameters: + if param.name == "required_field": + assert param.required is True + elif param.name == "optional_with_default": + assert param.required is False + elif param.name == "optional_nullable": + assert param.required is False + + +def test_openapi_pydantic_backward_compatibility(): + """Test that existing Body parameter behavior is unchanged""" + app = APIGatewayRestResolver() + + class BodyModel(BaseModel): + name: str = Field(description="Name") + age: int = Field(description="Age") + + @app.post("/users") + def create_user(user: BodyModel): # No annotation - should work as Body + return {"message": "success"} + + schema = app.get_openapi_schema() + path = schema.paths["/users"] + post_operation = path.post + + # Should have no parameters (body is handled separately) + assert post_operation.parameters is None or len(post_operation.parameters) == 0 + + # Should have request body + assert post_operation.requestBody is not None + assert "application/json" in post_operation.requestBody.content + + +def test_openapi_pydantic_complex_types(): + """Test that complex types are handled correctly""" + app = APIGatewayRestResolver() + + class QueryParams(BaseModel): + string_field: str = Field(description="String field") + int_field: int = Field(description="Integer field") + float_field: float = Field(description="Float field") + bool_field: bool = Field(description="Boolean field") + + @app.get("/complex") + def complex_handler(params: Annotated[QueryParams, Query()]): + return {"message": 
"success"} + + schema = app.get_openapi_schema() + path = schema.paths["/complex"] + get_operation = path.get + + type_mapping = {} + for param in get_operation.parameters: + type_mapping[param.name] = param.schema_.type + + assert type_mapping["string_field"] == "string" + assert type_mapping["int_field"] == "integer" + assert type_mapping["float_field"] == "number" + assert type_mapping["bool_field"] == "boolean" diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index 58a3f19e504..24e6de0db43 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -3,10 +3,10 @@ from dataclasses import dataclass from enum import Enum from pathlib import PurePath -from typing import List, Literal, Optional, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Tuple, Union import pytest -from pydantic import BaseModel, Field +from pydantic import AfterValidator, Base64UrlStr, BaseModel, ConfigDict, Field, StringConstraints, alias_generators from typing_extensions import Annotated from aws_lambda_powertools.event_handler import ( @@ -47,6 +47,407 @@ def handler(user_id: int): assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) +def test_validate_pydantic_query_params(gw_event): + """Test that Pydantic models in Query parameters are validated correctly""" + + app = APIGatewayRestResolver(enable_validation=True) + del gw_event["multiValueQueryStringParameters"] + + class QueryParams(BaseModel): + limit: int = Field(default=10, ge=1, le=100, description="Number of items") + search: Optional[str] = Field(default=None, description="Search term") + + @app.get("/search") + def search_handler(params: Annotated[QueryParams, Query()]): + return { + "limit": params.limit, + "search": params.search, + } 
+ + # Test valid request + gw_event["path"] = "/search" + gw_event["queryStringParameters"] = {"limit": "25", "search": "python"} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["limit"] == 25 + assert body["search"] == "python" + + # Test with default values + gw_event["queryStringParameters"] = {} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["limit"] == 10 # Default value + assert body["search"] is None # Default value + + # Test validation error (limit too high) + gw_event["queryStringParameters"] = {"limit": "150"} + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + assert any("limit" in str(error) for error in body["detail"]) + + +def test_validate_multi_value_query_params(gw_event): + """Test that multi-value query parameters are validated correctly""" + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/users") + def users_handler(ids: Annotated[List[int], Query()]): + return {"ids": ids} + + # Test valid request + gw_event["path"] = "/users" + gw_event["multiValueQueryStringParameters"] = {"ids": ["1", "2", "3"]} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["ids"] == [1, 2, 3] + + # Test with invalid value + gw_event["multiValueQueryStringParameters"] = {"ids": ["1", "abc", "3"]} + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + assert any("ids" in str(error) for error in body["detail"]) + + +def test_validate_pydantic_multi_value_query_params(gw_event): + """Test that Pydantic models in Multi-Value Query parameters are validated correctly""" + app = APIGatewayRestResolver(enable_validation=True) + + class QueryParams(BaseModel): + ids: List[int] = Field(..., 
description="List of user IDs") + + @app.get("/users") + def users_handler(params: Annotated[QueryParams, Query()]): + return {"ids": params.ids} + + # Test valid request + gw_event["path"] = "/users" + gw_event["multiValueQueryStringParameters"] = {"ids": ["1", "2", "3"]} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["ids"] == [1, 2, 3] + + # Test with invalid value + gw_event["multiValueQueryStringParameters"] = {"ids": ["1", "abc", "3"]} + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + assert any("ids" in str(error) for error in body["detail"]) + + +def test_validate_pydantic_query_params_detailed_errors(gw_event): + """Test that Pydantic validation errors include detailed field-level information""" + app = APIGatewayRestResolver(enable_validation=True) + + del gw_event["multiValueHeaders"] + del gw_event["multiValueQueryStringParameters"] + + class QueryParams(BaseModel): + full_name: str = Field(..., min_length=5, description="Full name with minimum 5 characters") + age: int = Field(..., ge=18, le=100, description="Age between 18 and 100") + + @app.get("/query-model") + def query_model(params: Annotated[QueryParams, Query()]): + return {"full_name": params.full_name, "age": params.age} + + # Test validation error with detailed field information + gw_event["path"] = "/query-model" + gw_event["queryStringParameters"] = {"full_name": "Jo", "age": "15"} # Both invalid + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + + # Check that we get detailed field-level errors + errors = body["detail"] + + # Should have errors for both fields + full_name_error = next((e for e in errors if "full_name" in e["loc"]), None) + age_error = next((e for e in errors if "age" in e["loc"]), None) + + assert full_name_error is not None, "Should 
have error for full_name field" + assert age_error is not None, "Should have error for age field" + + # Check error details for full_name + assert full_name_error["loc"] == ["query", "params", "full_name"] + assert full_name_error["type"] == "string_too_short" + + # Check error details for age + assert age_error["loc"] == ["query", "params", "age"] + assert age_error["type"] == "greater_than_equal" + + +def test_validate_pydantic_header_params(gw_event): + """Test that Pydantic models in Header parameters are validated correctly""" + app = APIGatewayRestResolver(enable_validation=True) + + del gw_event["multiValueHeaders"] + + class HeaderParams(BaseModel): + authorization: str = Field(description="Authorization token") + user_agent: str = Field(default="PowerTools/1.0", description="User agent") + + @app.get("/protected") + def protected_handler(my_headers: Annotated[HeaderParams, Header()]): + return { + "authorization": my_headers.authorization, + "user_agent": my_headers.user_agent, + } + + # Test valid request + gw_event["path"] = "/protected" + gw_event["headers"] = {"authorization": "Bearer token123", "user-agent": "TestClient/1.0"} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["authorization"] == "Bearer token123" + assert body["user_agent"] == "TestClient/1.0" + + # Test with default value + gw_event["headers"] = {"authorization": "Bearer token123"} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["authorization"] == "Bearer token123" + assert body["user_agent"] == "PowerTools/1.0" # Default value + + # Test missing required header + gw_event["headers"] = {} + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + assert any("my-headers" in str(error) for error in body["detail"]) + + +def test_validate_multi_value_header_params(gw_event): 
+ """Test that multi-value headers are validated correctly without Pydantic""" + app = APIGatewayRestResolver(enable_validation=True) + + del gw_event["multiValueHeaders"] + + @app.get("/multi-value-headers") + def multi_value_handler(my_headers: Annotated[List[str], Header()]): + return {"items": my_headers} + + # Test valid request + gw_event["path"] = "/multi-value-headers" + gw_event["multiValueHeaders"] = {"my-headers": ["item1", "item2"]} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["items"] == ["item1", "item2"] + + # Test invalid request + gw_event["multiValueHeaders"] = {"items": "invalid"} + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + assert any("my-headers" in str(error) for error in body["detail"]) + + +def test_validate_pydantic_multi_value_header_params(gw_event): + """Test that multi-value headers are validated correctly""" + app = APIGatewayRestResolver(enable_validation=True) + + del gw_event["multiValueHeaders"] + + class MultiValueHeaderParams(BaseModel): + list_items: List[str] = Field(description="List of items") + + @app.get("/multi-value-headers") + def multi_value_handler(my_headers: Annotated[MultiValueHeaderParams, Header()]): + return {"items": my_headers.list_items} + + # Test valid request + gw_event["path"] = "/multi-value-headers" + gw_event["multiValueHeaders"] = {"list-items": ["item1", "item2"]} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["items"] == ["item1", "item2"] + + # Test invalid request + gw_event["multiValueHeaders"] = {"items": "invalid"} + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + assert any("my-headers" in str(error) for error in body["detail"]) + + +def 
test_validate_pydantic_header_snake_case_to_kebab_case_schema(gw_event): + """Test that snake_case header fields are converted to kebab-case in OpenAPI schema and validation""" + app = APIGatewayRestResolver(enable_validation=True) + app.enable_swagger() + + class HeaderParams(BaseModel): + correlation_id: str = Field(description="Correlation ID header") + user_agent: str = Field(default="PowerTools/1.0", description="User agent header") + accept: str = Field(default="application/json") # omit description to test optional description + + @app.get("/kebab-headers") + def kebab_handler(my_headers: Annotated[HeaderParams, Header()]): + return { + "correlation_id": my_headers.correlation_id, + "user_agent": my_headers.user_agent, + } + + # Test that OpenAPI schema uses kebab-case for headers + openapi_schema = app.get_openapi_schema() + operation = openapi_schema.paths["/kebab-headers"].get + parameters = operation.parameters + + # Find the correlation_id parameter + correlation_param = next((p for p in parameters if p.name == "correlation-id"), None) + assert correlation_param is not None, "Should have correlation-id parameter in kebab-case" + + # Find the user_agent parameter + user_agent_param = next((p for p in parameters if p.name == "user-agent"), None) + assert user_agent_param is not None, "Should have user-agent parameter in kebab-case" + + # Test validation with kebab-case headers + gw_event["path"] = "/kebab-headers" + gw_event["multiValueHeaders"] = {"correlation-id": "test-123", "user-agent": "TestClient/1.0"} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["correlation_id"] == "test-123" + assert body["user_agent"] == "TestClient/1.0" + + +def test_validate_pydantic_mixed_params(gw_event): + """Test that mixed Pydantic models (Query + Header) are validated correctly""" + app = APIGatewayRestResolver(enable_validation=True) + + del gw_event["multiValueHeaders"] + del 
gw_event["multiValueQueryStringParameters"] + + class QueryParams(BaseModel): + q: str = Field(description="Search query") + limit: int = Field(default=10, description="Number of results") + + class HeaderParams(BaseModel): + authorization: str = Field(description="Bearer token") + + @app.get("/mixed") + def mixed_handler(query: Annotated[QueryParams, Query()], headers: Annotated[HeaderParams, Header()]): + return { + "query": {"q": query.q, "limit": query.limit}, + "headers": {"authorization": headers.authorization}, + } + + # Test valid request + gw_event["path"] = "/mixed" + gw_event["queryStringParameters"] = {"q": "python", "limit": "25"} + gw_event["headers"] = {"authorization": "Bearer token123"} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["query"]["q"] == "python" + assert body["query"]["limit"] == 25 + assert body["headers"]["authorization"] == "Bearer token123" + + # Test missing required query parameter + gw_event["queryStringParameters"] = {"limit": "25"} # Missing 'q' + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + assert any("q" in str(error) for error in body["detail"]) + + # Test missing required header + gw_event["queryStringParameters"] = {"q": "python"} + gw_event["headers"] = {} # Missing 'authorization' + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + assert any("headers" in str(error) for error in body["detail"]) + + +def test_validate_pydantic_with_alias(gw_event): + """Test that Pydantic models with field aliases work correctly""" + app = APIGatewayRestResolver(enable_validation=True) + + del gw_event["multiValueHeaders"] + del gw_event["multiValueQueryStringParameters"] + + class HeaderParams(BaseModel): + accept_language: str = Field(alias="accept-language", description="Language preference") + + 
@app.get("/alias") + def alias_handler(headers: Annotated[HeaderParams, Header()]): + return {"accept_language": headers.accept_language} + + # Test with alias in request + gw_event["path"] = "/alias" + gw_event["headers"] = {"accept-language": "en-US"} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["accept_language"] == "en-US" + + # Test missing aliased field + gw_event["headers"] = {} + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + assert any("headers" in str(error) for error in body["detail"]) + + def test_validate_scalars_with_default(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -818,8 +1219,8 @@ def handler2(header2: Annotated[List[int], Header()], header1: Annotated[str, He @app.get("/users") def handler3( - header2: Annotated[List[str], Header(name="Header2")], - header1: Annotated[str, Header(name="Header1")], + header2: Annotated[List[str], Header(alias="Header2")], + header1: Annotated[str, Header(alias="Header1")], ): print(header2) @@ -884,8 +1285,8 @@ def handler2(header2: Annotated[List[int], Header()], header1: Annotated[str, He @app.get("/users") def handler3( - header2: Annotated[List[str], Header(name="Header2")], - header1: Annotated[str, Header(name="Header1")], + header2: Annotated[List[str], Header(alias="Header2")], + header1: Annotated[str, Header(alias="Header1")], ): print(header2) @@ -947,8 +1348,8 @@ def handler2(header2: Annotated[List[int], Header()], header1: Annotated[str, He @app.get("/users") def handler3( - header2: Annotated[List[str], Header(name="Header2")], - header1: Annotated[str, Header(name="Header1")], + header2: Annotated[List[str], Header(alias="Header2")], + header1: Annotated[str, Header(alias="Header1")], ): print(header2) @@ -1012,8 +1413,8 @@ def handler2(header2: 
Annotated[List[int], Header()], header1: Annotated[str, He @app.get("/users") def handler3( - header2: Annotated[List[str], Header(name="Header2")], - header1: Annotated[str, Header(name="Header1")], + header2: Annotated[List[str], Header(alias="Header2")], + header1: Annotated[str, Header(alias="Header1")], ): print(header2) @@ -1076,8 +1477,8 @@ def handler2(header2: Annotated[List[int], Header()], header1: Annotated[str, He @app.get("/users") def handler3( - header2: Annotated[List[str], Header(name="Header2")], - header1: Annotated[str, Header(name="Header1")], + header2: Annotated[List[str], Header(alias="Header2")], + header1: Annotated[str, Header(alias="Header1")], ): print(header2) @@ -1140,8 +1541,8 @@ def handler1(header2: Annotated[List[int], Header()], header1: Annotated[str, He @app.get("/users") def handler3( - header2: Annotated[List[str], Header(name="Header2")], - header1: Annotated[str, Header(name="Header1")], + header2: Annotated[List[str], Header(alias="Header2")], + header1: Annotated[str, Header(alias="Header1")], ): print(header2) @@ -2137,3 +2538,137 @@ def create_action(action: Annotated[action_type, Body()]): result = app(gw_event, {}) assert result["statusCode"] == 422 + + +def test_validate_pydantic_query_params_with_config_dict_and_validators(gw_event): + """Test that Pydantic models with ConfigDict, aliases, and validators work correctly""" + + del gw_event["multiValueHeaders"] + del gw_event["multiValueQueryStringParameters"] + + app = APIGatewayRestResolver(enable_validation=True) + + def _validate_powertools(value: str) -> str: + if not value.startswith("Powertools"): + raise ValueError("Full name must start with 'Powertools'") + return value + + class QuerySimple(BaseModel): + full_name: Annotated[str, StringConstraints(min_length=5), AfterValidator(_validate_powertools)] + next_token: Base64UrlStr + search_id: str + + @app.get("/query-model-simple") + def query_model(params: Annotated[QuerySimple, Query()]) -> Dict[str, Any]: + 
return { + "fullName": params.full_name, + "nextToken": params.next_token, + "searchId": params.search_id, + } + + class QueryAdvanced(BaseModel): + full_name: Annotated[str, StringConstraints(min_length=5)] + next_token: str + search_id: Annotated[str, Field(alias="id")] # Using str instead of UUID4 for simpler testing + + model_config = ConfigDict( + alias_generator=alias_generators.to_camel, + validate_by_alias=True, + validate_by_name=True, + serialize_by_alias=True, + ) + + @app.get("/query-model-advanced") + def query_model_advanced(params: Annotated[QueryAdvanced, Query()]) -> Dict[str, Any]: + return params.model_dump() + + # Test QuerySimple with validators + gw_event["path"] = "/query-model-simple" + gw_event["queryStringParameters"] = { + "full_name": "Powertools Lambda", + "next_token": "dGVzdA==", # base64url encoded "test" + "search_id": "search-123", + } + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["fullName"] == "Powertools Lambda" + assert body["nextToken"] == "test" + assert body["searchId"] == "search-123" + + # Test QuerySimple validation error (name doesn't start with "Powertools") + gw_event["queryStringParameters"] = { + "full_name": "Lambda Powertools", + "next_token": "dGVzdA==", + "search_id": "search-123", + } + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + errors = body["detail"] + + # Should have validation error for full_name with proper location + full_name_error = next((e for e in errors if "full_name" in e["loc"]), None) + + assert full_name_error is not None, "Should have error for full_name field" + + # Check error details for full_name + assert full_name_error["loc"] == ["query", "params", "full_name"] + assert full_name_error["type"] == "value_error" + + # Test QueryAdvanced with ConfigDict and alias_generator + gw_event["path"] = "/query-model-advanced" + 
gw_event["queryStringParameters"] = { + "fullName": "Advanced Test", # camelCase from alias_generator + "nextToken": "dGVzdA==", # camelCase from alias_generator + "id": "search-456", # explicit alias + } + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + # Should return with camelCase keys due to serialize_by_alias=True + assert body["fullName"] == "Advanced Test" + assert body["nextToken"] == "dGVzdA==" + assert body["id"] == "search-456" + + # Test QueryAdvanced with snake_case field names due to validate_by_name=True + gw_event["queryStringParameters"] = { + "full_name": "Snake Case Test", # snake_case field name + "next_token": "dGVzdA==", # snake_case field name + "search_id": "search-789", # snake_case field name + } + + gw_event["path"] = "/query-model-advanced" + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["fullName"] == "Snake Case Test" + assert body["nextToken"] == "dGVzdA==" + assert body["id"] == "search-789" + + # Test QueryAdvanced validation error (full_name too short) + gw_event["queryStringParameters"] = { + "fullName": "Bad", # Too short (min_length=5) + "nextToken": "dGVzdA==", + "id": "search-456", + } + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + errors = body["detail"] + + # Should have validation error for full_name with proper location + full_name_error = next((e for e in errors if "full_name" in e["loc"] or "fullName" in e["loc"]), None) + assert full_name_error is not None + assert full_name_error["type"] == "string_too_short"