From c38791170331006961ddb8e1b506211d441297db Mon Sep 17 00:00:00 2001 From: Zixuan James Li Date: Wed, 16 Aug 2023 18:35:10 -0400 Subject: [PATCH] api: Avoid programming errors due to nested Annotated types. We want to reject ambiguous type annotations that set ApiParamConfig inside a Union. If a parameter is Optional and has a default of None, we prefer Annotated[Optional[T], ...] over Optional[Annotated[T, ...]]. This implements a check that detects Optional[Annotated[T, ...]] and raises an assertion error if ApiParamConfig is in the annotation. It also checks if the type annotation contains any ApiParamConfig objects that are ignored, which can happen if the Annotated type is nested inside another type like List, Union, etc. Note that because param: Annotated[Optional[T], ...] = None and param: Optional[Annotated[Optional[T], ...]] = None are equivalent at runtime prior to Python 3.11, there is no way for us to distinguish the two. So we cannot detect that at runtime. See also: https://github.com/python/cpython/issues/90353 --- tools/semgrep.yml | 23 ++++++++++ zerver/lib/typed_endpoint.py | 69 ++++++++++++++++++++++++++--- zerver/tests/test_typed_endpoint.py | 57 +++++++++++++++++++++++- 3 files changed, 140 insertions(+), 9 deletions(-) diff --git a/tools/semgrep.yml b/tools/semgrep.yml index 8b4b32a545d7c8..6a5e12983eec38 100644 --- a/tools/semgrep.yml +++ b/tools/semgrep.yml @@ -186,3 +186,26 @@ rules: or use @typed_endpoint_without_parameters instead. languages: [python] severity: ERROR + + - id: dont-nest-annotated-types-with-param-config + patterns: + - pattern-not: | + def $F(..., invalid_param: typing.Optional[<... zerver.lib.typed_endpoint.ApiParamConfig(...) ...>], ...) -> ...: + ... + - pattern-not: | + def $F(..., $A: typing_extensions.Annotated[<... zerver.lib.typed_endpoint.ApiParamConfig(...) ...>], ...) -> ...: + ... + - pattern-not: | + def $F(..., $A: typing_extensions.Annotated[<... zerver.lib.typed_endpoint.ApiParamConfig(...) 
...>] = ..., ...) -> ...: + ... + - pattern-either: + - pattern: | + def $F(..., $A: $B[<... zerver.lib.typed_endpoint.ApiParamConfig(...) ...>], ...) -> ...: + ... + - pattern: | + def $F(..., $A: $B[<... zerver.lib.typed_endpoint.ApiParamConfig(...) ...>] = ..., ...) -> ...: + ... + message: | + Annotated types containing zerver.lib.typed_endpoint.ApiParamConfig should not be nested inside Optional. Use Annotated[Optional[...], zerver.lib.typed_endpoint.ApiParamConfig(...)] instead. + languages: [python] + severity: ERROR diff --git a/zerver/lib/typed_endpoint.py b/zerver/lib/typed_endpoint.py index f36cb66d6457ae..414a2b6b6e0077 100644 --- a/zerver/lib/typed_endpoint.py +++ b/zerver/lib/typed_endpoint.py @@ -1,5 +1,6 @@ import inspect import json +import sys from dataclasses import dataclass from enum import Enum, auto from functools import wraps @@ -157,6 +158,41 @@ def is_annotated(type_annotation: Type[object]) -> bool: return origin is Annotated +def is_optional(type_annotation: Type[object]) -> bool: + origin = get_origin(type_annotation) + type_args = get_args(type_annotation) + return origin is Union and type(None) in type_args and len(type_args) == 2 + + +API_PARAM_CONFIG_USAGE_HINT = f""" + Detected incorrect usage of Annotated types for parameter {{param_name}}! + Check the placement of the {ApiParamConfig.__name__} object in the type annotation: + + {{param_name}}: {{param_type}} + + The Annotated[T, ...] type annotation containing the + {ApiParamConfig.__name__} object should not be nested inside another type. 
+ + Correct examples: + + # Using Optional inside Annotated + param: Annotated[Optional[int], ApiParamConfig(...)] + param: Annotated[Optional[int], ApiParamConfig(...)] = None + + # Not using Optional when the default is not None + param: Annotated[int, ApiParamConfig(...)] + + Incorrect examples: + + # Nesting Annotated inside Optional + param: Optional[Annotated[int, ApiParamConfig(...)]] + param: Optional[Annotated[int, ApiParamConfig(...)]] = None + + # Nesting the Annotated type carrying ApiParamConfig inside other types like Union + param: Union[str, Annotated[int, ApiParamConfig(...)]] +""" + + def parse_single_parameter( param_name: str, param_type: Type[T], parameter: inspect.Parameter ) -> FuncParam[T]: @@ -171,13 +207,24 @@ def parse_single_parameter( # otherwise causes undesired behaviors that the annotated metadata gets # lost. This is fixed in Python 3.11: # https://github.com/python/cpython/issues/90353 - if param_default is None: - origin = get_origin(param_type) + if ( + sys.version_info < (3, 11) and param_default is None + ): # nocoverage # We lose coverage of this with Python 3.11+ only type_args = get_args(param_type) - if origin is Union and type(None) in type_args and len(type_args) == 2: - inner_type = type_args[0] if type_args[1] is type(None) else type_args[1] - if is_annotated(inner_type): - param_type = inner_type + assert is_optional(param_type) + inner_type = type_args[0] if type_args[1] is type(None) else type_args[1] + if is_annotated(inner_type): + annotated_type, *annotations = get_args(inner_type) + has_api_param_config = any( + isinstance(annotation, ApiParamConfig) for annotation in annotations + ) + # This prohibits the use of `Optional[Annotated[T, ApiParamConfig(...)]] = None` + # and encourages `Annotated[Optional[T], ApiParamConfig(...)] = None` + # to avoid confusion when the parameter metadata is unintentionally nested. 
+ assert not has_api_param_config or is_optional( + annotated_type + ), API_PARAM_CONFIG_USAGE_HINT.format(param_name=param_name, param_type=param_type) + param_type = inner_type param_config: Optional[ApiParamConfig] = None if is_annotated(param_type): @@ -185,12 +232,20 @@ def parse_single_parameter( # metadata attached to Annotated. Note that we do not transform # param_type to its underlying type because the Annotated metadata might # still be needed by other parties like Pydantic. - _, *annotations = get_args(param_type) + ignored_type, *annotations = get_args(param_type) for annotation in annotations: if not isinstance(annotation, ApiParamConfig): continue assert param_config is None, "ApiParamConfig can only be defined once per parameter" param_config = annotation + else: + # When no parameter configuration is found, assert that there is none + # nested somewhere in a Union type to avoid silently ignoring it. If it + # is present in the stringified parameter type, it is very likely a + # programming error. + assert ApiParamConfig.__name__ not in str(param_type), API_PARAM_CONFIG_USAGE_HINT.format( + param_name=param_name, param_type=param_type + ) # Set param_config to a default early to avoid additional None-checks. 
if param_config is None: param_config = ApiParamConfig() diff --git a/zerver/tests/test_typed_endpoint.py b/zerver/tests/test_typed_endpoint.py index 14e11d9037a8d3..c0ce06181a01c2 100644 --- a/zerver/tests/test_typed_endpoint.py +++ b/zerver/tests/test_typed_endpoint.py @@ -1,9 +1,9 @@ -from typing import Any, Callable, Dict, List, Literal, Optional, TypeVar, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Type, TypeVar, Union, cast import orjson from django.core.exceptions import ValidationError as DjangoValidationError from django.http import HttpRequest, HttpResponse -from pydantic import BaseModel, ConfigDict, Json, ValidationInfo, WrapValidator +from pydantic import BaseModel, ConfigDict, Json, StringConstraints, ValidationInfo, WrapValidator from pydantic.dataclasses import dataclass from pydantic.functional_validators import ModelWrapValidatorHandler from typing_extensions import Annotated @@ -19,6 +19,7 @@ PathOnly, RequiredStringConstraint, WebhookPayload, + is_optional, typed_endpoint, typed_endpoint_without_parameters, ) @@ -37,6 +38,16 @@ def call_endpoint( class TestEndpoint(ZulipTestCase): + def test_is_optional(self) -> None: + """This test is only needed because we don't + have coverage of is_optional in Python 3.11. 
+ """ + type = cast(Type[Optional[str]], Optional[str]) + self.assertTrue(is_optional(type)) + + type = str + self.assertFalse(is_optional(str)) + def test_coerce(self) -> None: @typed_endpoint def view(request: HttpRequest, *, strict_int: int) -> None: @@ -415,6 +426,48 @@ def view3( ) self.assertFalse(result) + # Not nesting the Annotated type with the ApiParamConfig inside Optional is fine + @typed_endpoint + def no_nesting( + request: HttpRequest, + *, + bar: Annotated[ + Optional[str], + StringConstraints(strip_whitespace=True, max_length=3), + ApiParamConfig("test"), + ] = None, + ) -> None: + raise AssertionError + + with self.assertRaisesMessage(ApiParamValidationError, "test is too long"): + call_endpoint(no_nesting, HostRequestMock({"test": "long"})) + + # Nesting Annotated with ApiParamConfig inside Optional is not fine + def nesting_with_config( + request: HttpRequest, + *, + invalid_param: Optional[Annotated[str, ApiParamConfig("test")]] = None, + ) -> None: + raise AssertionError + + with self.assertRaisesRegex( + AssertionError, + "Detected incorrect usage of Annotated types for parameter invalid_param!", + ): + typed_endpoint(nesting_with_config) + + # Nesting Annotated inside Optional, when ApiParamConfig is not also nested is fine + @typed_endpoint + def nesting_without_config( + request: HttpRequest, + *, + bar: Optional[Annotated[str, StringConstraints(max_length=3)]] = None, + ) -> None: + raise AssertionError + + with self.assertRaisesMessage(ApiParamValidationError, "bar is too long"): + call_endpoint(nesting_without_config, HostRequestMock({"bar": "long"})) + def test_aliases(self) -> None: @typed_endpoint def foo(