diff --git a/tools/semgrep.yml b/tools/semgrep.yml index 4e3b778fea2123..35b398d5bb0d58 100644 --- a/tools/semgrep.yml +++ b/tools/semgrep.yml @@ -186,3 +186,26 @@ rules: or use @typed_endpoint_without_parameters instead. languages: [python] severity: ERROR + + - id: dont-nest-annotated-types-with-param-config + patterns: + - pattern-not: | + def $F(..., invalid_param: typing.Optional[<... zerver.lib.typed_endpoint.ApiParamConfig(...) ...>], ...) -> ...: + ... + - pattern-not: | + def $F(..., $A: typing_extensions.Annotated[<... zerver.lib.typed_endpoint.ApiParamConfig(...) ...>], ...) -> ...: + ... + - pattern-not: | + def $F(..., $A: typing_extensions.Annotated[<... zerver.lib.typed_endpoint.ApiParamConfig(...) ...>] = ..., ...) -> ...: + ... + - pattern-either: + - pattern: | + def $F(..., $A: $B[<... zerver.lib.typed_endpoint.ApiParamConfig(...) ...>], ...) -> ...: + ... + - pattern: | + def $F(..., $A: $B[<... zerver.lib.typed_endpoint.ApiParamConfig(...) ...>] = ..., ...) -> ...: + ... + message: | + Annotated types containing zerver.lib.typed_endpoint.ApiParamConfig should not be nested inside Optional. Use Annotated[Optional[...], zerver.lib.typed_endpoint.ApiParamConfig(...)] instead. 
+ languages: [python] + severity: ERROR diff --git a/zerver/lib/typed_endpoint.py b/zerver/lib/typed_endpoint.py index ce930d104de057..6307c0311b2932 100644 --- a/zerver/lib/typed_endpoint.py +++ b/zerver/lib/typed_endpoint.py @@ -1,5 +1,6 @@ import inspect import json +import sys from dataclasses import dataclass from enum import Enum, auto from functools import wraps @@ -167,6 +168,41 @@ def is_annotated(type_annotation: Type[object]) -> bool: return origin is Annotated +def is_optional(type_annotation: Type[object]) -> bool: + origin = get_origin(type_annotation) + type_args = get_args(type_annotation) + return origin is Union and type(None) in type_args and len(type_args) == 2 + + +API_PARAM_CONFIG_USAGE_HINT = f""" + Detected incorrect usage of Annotated types for parameter {{param_name}}! + Check the placement of the {ApiParamConfig.__name__} object in the type annotation: + + {{param_name}}: {{param_type}} + + The Annotated[T, ...] type annotation containing the + {ApiParamConfig.__name__} object should not be nested inside another type. + + Correct examples: + + # Using Optional inside Annotated + param: Annotated[Optional[int], ApiParamConfig(...)] + param: Annotated[Optional[int], ApiParamConfig(...)]] = None + + # Not using Optional when the default is not None + param: Annotated[int, ApiParamConfig(...)] + + Incorrect examples: + + # Nesting Annotated inside Optional + param: Optional[Annotated[int, ApiParamConfig(...)]] + param: Optional[Annotated[int, ApiParamConfig(...)]] = None + + # Nesting the Annotated type carrying ApiParamConfig inside other types like Union + param: Union[str, Annotated[int, ApiParamConfig(...)]] +""" + + def parse_single_parameter( param_name: str, param_type: Type[T], parameter: inspect.Parameter ) -> FuncParam[T]: @@ -181,13 +217,24 @@ def parse_single_parameter( # otherwise causes undesired behaviors that the annotated metadata gets # lost. 
This is fixed in Python 3.11: # https://github.com/python/cpython/issues/90353 - if param_default is None: - origin = get_origin(param_type) + if ( + sys.version_info < (3, 11) and param_default is None + ): # nocoverage # We lose coverage of this with Python 3.11+ only type_args = get_args(param_type) - if origin is Union and type(None) in type_args and len(type_args) == 2: - inner_type = type_args[0] if type_args[1] is type(None) else type_args[1] - if is_annotated(inner_type): - param_type = inner_type + assert is_optional(param_type) + inner_type = type_args[0] if type_args[1] is type(None) else type_args[1] + if is_annotated(inner_type): + annotated_type, *annotations = get_args(inner_type) + has_api_param_config = any( + isinstance(annotation, ApiParamConfig) for annotation in annotations + ) + # This prohibits the use of `Optional[Annotated[T, ApiParamConfig(...)]] = None` + # and encourages `Annotated[Optional[T], ApiParamConfig(...)] = None` + # to avoid confusion when the parameter metadata is unintentionally nested. + assert not has_api_param_config or is_optional( + annotated_type + ), API_PARAM_CONFIG_USAGE_HINT.format(param_name=param_name, param_type=param_type) + param_type = inner_type param_config: Optional[ApiParamConfig] = None json_wrapper = False @@ -196,7 +243,7 @@ def parse_single_parameter( # metadata attached to Annotated. Note that we do not transform # param_type to its underlying type because the Annotated metadata might # still be needed by other parties like Pydantic. 
- _, *annotations = get_args(param_type) + ignored_type, *annotations = get_args(param_type) for annotation in annotations: if type(annotation) is Json: json_wrapper = True @@ -204,6 +251,14 @@ def parse_single_parameter( continue assert param_config is None, "ApiParamConfig can only be defined once per parameter" param_config = annotation + else: + # When no parameter configuration is found, assert that there is none + # nested somewhere in a Union type to avoid silently ignoring it. If it + # is present in the stringified parameter type, it is very likely a + # programming error. + assert ApiParamConfig.__name__ not in str(param_type), API_PARAM_CONFIG_USAGE_HINT.format( + param_name=param_name, param_type=param_type + ) # Set param_config to a default early to avoid additional None-checks. if param_config is None: param_config = ApiParamConfig() diff --git a/zerver/tests/test_typed_endpoint.py b/zerver/tests/test_typed_endpoint.py index 9d7ad47e89c1a5..c1431793a0ae93 100644 --- a/zerver/tests/test_typed_endpoint.py +++ b/zerver/tests/test_typed_endpoint.py @@ -1,9 +1,9 @@ -from typing import Any, Callable, Dict, List, Literal, Optional, TypeVar, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Type, TypeVar, Union, cast import orjson from django.core.exceptions import ValidationError as DjangoValidationError from django.http import HttpRequest, HttpResponse -from pydantic import BaseModel, ConfigDict, Json +from pydantic import BaseModel, ConfigDict, Json, StringConstraints from pydantic.dataclasses import dataclass from typing_extensions import Annotated @@ -18,6 +18,7 @@ PathOnly, RequiredStringConstraint, WebhookPayload, + is_optional, typed_endpoint, typed_endpoint_without_parameters, ) @@ -36,6 +37,16 @@ def call_endpoint( class TestEndpoint(ZulipTestCase): + def test_is_optional(self) -> None: + """This test is only needed because we don't + have coverage of is_optional in Python 3.11. 
+ """ + type = cast(Type[Optional[str]], Optional[str]) + self.assertTrue(is_optional(type)) + + type = str + self.assertFalse(is_optional(str)) + def test_coerce(self) -> None: @typed_endpoint def view(request: HttpRequest, *, strict_int: int) -> int: @@ -409,6 +420,48 @@ def view3( ) self.assertFalse(result) + # Not nesting the Annotated type with the ApiParamConfig inside Optional is fine + @typed_endpoint + def no_nesting( + request: HttpRequest, + *, + bar: Annotated[ + Optional[str], + StringConstraints(strip_whitespace=True, max_length=3), + ApiParamConfig("test"), + ] = None, + ) -> None: + raise AssertionError + + with self.assertRaisesMessage(ApiParamValidationError, "test is too long"): + call_endpoint(no_nesting, HostRequestMock({"test": "long"})) + + # Nesting Annotated with ApiParamConfig inside Optional is not fine + def nesting_with_config( + request: HttpRequest, + *, + invalid_param: Optional[Annotated[str, ApiParamConfig("test")]] = None, + ) -> None: + raise AssertionError + + with self.assertRaisesRegex( + AssertionError, + "Detected incorrect usage of Annotated types for parameter invalid_param!", + ): + typed_endpoint(nesting_with_config) + + # Nesting Annotated inside Optional, when ApiParamConfig is not also nested is fine + @typed_endpoint + def nesting_without_config( + request: HttpRequest, + *, + bar: Optional[Annotated[str, StringConstraints(max_length=3)]] = None, + ) -> None: + raise AssertionError + + with self.assertRaisesMessage(ApiParamValidationError, "bar is too long"): + call_endpoint(nesting_without_config, HostRequestMock({"bar": "long"})) + def test_aliases(self) -> None: @typed_endpoint def foo(