diff --git a/aws_lambda_powertools/event_handler/openapi/compat.py b/aws_lambda_powertools/event_handler/openapi/compat.py index 6b2c691442f..df8adfe303a 100644 --- a/aws_lambda_powertools/event_handler/openapi/compat.py +++ b/aws_lambda_powertools/event_handler/openapi/compat.py @@ -14,6 +14,7 @@ # We use this for forward reference, as it allows us to handle forward references in type annotations. from pydantic._internal._typing_extra import eval_type_lenient from pydantic._internal._utils import lenient_issubclass +from pydantic.fields import FieldInfo as PydanticFieldInfo from pydantic_core import PydanticUndefined, PydanticUndefinedType from typing_extensions import Annotated, Literal, get_args, get_origin @@ -186,8 +187,36 @@ def model_rebuild(model: type[BaseModel]) -> None: def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo: # Create a shallow copy of the field_info to preserve its type and all attributes new_field = copy(field_info) - # Update only the annotation to the new one - new_field.annotation = annotation + + # Recursively extract all metadata from nested Annotated types + def extract_metadata(ann: Any) -> tuple[Any, list[Any]]: + """Extract base type and all non-FieldInfo metadata from potentially nested Annotated types.""" + if get_origin(ann) is not Annotated: + return ann, [] + + args = get_args(ann) + base_type = args[0] + metadata = list(args[1:]) + + # If base type is also Annotated, recursively extract its metadata + if get_origin(base_type) is Annotated: + inner_base, inner_metadata = extract_metadata(base_type) + all_metadata = [m for m in inner_metadata + metadata if not isinstance(m, PydanticFieldInfo)] + return inner_base, all_metadata + else: + constraint_metadata = [m for m in metadata if not isinstance(m, PydanticFieldInfo)] + return base_type, constraint_metadata + + # Extract base type and constraints + base_type, constraints = extract_metadata(annotation) + + # Set the annotation with base type and all constraint metadata + # Use tuple unpacking for Python 3.9+ compatibility + if constraints: + new_field.annotation = Annotated[(base_type, *constraints)] + else: + new_field.annotation = base_type + return new_field diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index b97bf690109..0439b1c2fc1 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -1110,6 +1110,10 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup type_annotation = annotated_args[0] powertools_annotations = [arg for arg in annotated_args[1:] if isinstance(arg, FieldInfo)] + # Preserve non-FieldInfo metadata (like annotated_types constraints) + # This is important for constraints like Interval, Gt, Lt, etc. + other_metadata = [arg for arg in annotated_args[1:] if not isinstance(arg, FieldInfo)] + # Determine which annotation to use powertools_annotation: FieldInfo | None = None has_discriminator_with_param = False @@ -1124,6 +1128,11 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup else: powertools_annotation = next(iter(powertools_annotations), None) + # Reconstruct type_annotation with non-FieldInfo metadata if present + # This ensures constraints like Interval are preserved + if other_metadata and not has_discriminator_with_param: + type_annotation = Annotated[(type_annotation, *other_metadata)] + # Process the annotation if it exists field_info: FieldInfo | None = None if isinstance(powertools_annotation, FieldInfo): # pragma: no cover diff --git a/tests/functional/event_handler/_pydantic/test_openapi_params.py b/tests/functional/event_handler/_pydantic/test_openapi_params.py index f797cd541a5..4c9087fff13 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_params.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_params.py @@ -1,7 +1,9 @@ +import json from dataclasses import dataclass from datetime import datetime from typing import List, Optional, Tuple +import pytest from pydantic import BaseModel, Field from typing_extensions import Annotated @@ -1044,3 +1046,185 @@ def complex_handler(params: Annotated[QueryParams, Query()]): assert type_mapping["int_field"] == "integer" assert type_mapping["float_field"] == "number" assert type_mapping["bool_field"] == "boolean" + + +@pytest.mark.parametrize( + "body_value,expected_value", + [ + ("50", 50), # Valid: within range + ("0", 0), # Valid: at lower bound + ("100", 100), # Valid: at upper bound + ], +) +def test_annotated_types_interval_constraints_in_body_params(body_value, expected_value): + """ + Test for issue #7600: Validate that annotated_types.Interval constraints + are properly enforced in Body parameters with valid values. + """ + from annotated_types import Interval + + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # AND a constrained type using annotated_types.Interval + ConstrainedInt = Annotated[int, Interval(ge=0, le=100)] + + @app.post("/items") + def create_item(value: Annotated[ConstrainedInt, Body()]): + return {"value": value} + + # WHEN sending a request with a valid value + event = { + "resource": "/items", + "path": "/items", + "httpMethod": "POST", + "body": body_value, + "isBase64Encoded": False, + } + + # THEN the request should succeed + result = app(event, {}) + assert result["statusCode"] == 200 + body = json.loads(result["body"]) + assert body["value"] == expected_value + + +@pytest.mark.parametrize( + "body_value", + [ + "-1", # Invalid: below range + "101", # Invalid: above range + ], +) +def test_annotated_types_interval_constraints_in_body_params_invalid(body_value): + """ + Test for issue #7600: Validate that annotated_types.Interval constraints + reject invalid values in Body parameters. + """ + from annotated_types import Interval + + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # AND a constrained type using annotated_types.Interval + constrained_int = Annotated[int, Interval(ge=0, le=100)] + + @app.post("/items") + def create_item(value: Annotated[constrained_int, Body()]): + return {"value": value} + + # WHEN sending a request with an invalid value + event = { + "resource": "/items", + "path": "/items", + "httpMethod": "POST", + "body": body_value, + "isBase64Encoded": False, + } + + # THEN validation should fail + result = app(event, {}) + assert result["statusCode"] == 422 + + +@pytest.mark.parametrize( + "query_value,expected_value", + [ + ("50", 50), # Valid: within range + ("0", 0), # Valid: at lower bound + ("100", 100), # Valid: at upper bound + ], +) +def test_annotated_types_interval_constraints_in_query_params(query_value, expected_value): + """ + Test for issue #7600: Validate that annotated_types.Interval constraints + are properly enforced in Query parameters with valid values. + """ + from annotated_types import Interval + + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # AND a constrained type using annotated_types.Interval + constrained_int = Annotated[int, Interval(ge=0, le=100)] + + @app.get("/items") + def list_items(limit: Annotated[constrained_int, Query()]): + return {"limit": limit} + + # WHEN sending a request with a valid value + event = { + "resource": "/items", + "path": "/items", + "httpMethod": "GET", + "queryStringParameters": {"limit": query_value}, + "isBase64Encoded": False, + } + + # THEN the request should succeed + result = app(event, {}) + assert result["statusCode"] == 200 + body = json.loads(result["body"]) + assert body["limit"] == expected_value + + +@pytest.mark.parametrize( + "query_value", + [ + "-1", # Invalid: below range + "101", # Invalid: above range + ], +) +def test_annotated_types_interval_constraints_in_query_params_invalid(query_value): + """ + Test for issue #7600: Validate that annotated_types.Interval constraints + reject invalid values in Query parameters. + """ + from annotated_types import Interval + + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # AND a constrained type using annotated_types.Interval + constrained_int = Annotated[int, Interval(ge=0, le=100)] + + @app.get("/items") + def list_items(limit: Annotated[constrained_int, Query()]): + return {"limit": limit} + + # WHEN sending a request with an invalid value + event = { + "resource": "/items", + "path": "/items", + "httpMethod": "GET", + "queryStringParameters": {"limit": query_value}, + "isBase64Encoded": False, + } + + # THEN validation should fail + result = app(event, {}) + assert result["statusCode"] == 422 + + +def test_annotated_types_interval_in_openapi_schema(): + """ + Test that annotated_types.Interval constraints are reflected in the OpenAPI schema. + """ + from annotated_types import Interval + + app = APIGatewayRestResolver() + constrained_int = Annotated[int, Interval(ge=0, le=100)] + + @app.get("/items") + def list_items(limit: Annotated[constrained_int, Query()] = 10): + return {"limit": limit} + + schema = app.get_openapi_schema() + + # Verify the Query parameter schema includes constraints + get_operation = schema.paths["/items"].get + limit_param = next(p for p in get_operation.parameters if p.name == "limit") + + assert limit_param.schema_.type == "integer" + assert limit_param.schema_.default == 10 + assert limit_param.required is False