From 3e1bd6b470eb59b68d0a6c0f88f4b38c9bd66660 Mon Sep 17 00:00:00 2001 From: Erik Tuerke Date: Wed, 7 Jul 2021 20:07:48 +0200 Subject: [PATCH] do not use eval to retrieve path param types --- restit/_path_parameter.py | 9 +++++-- restit/open_api/open_api_documentation.py | 30 ++++++++++++++++++++--- restit/resource.py | 22 +++++++++-------- 3 files changed, 46 insertions(+), 15 deletions(-) diff --git a/restit/_path_parameter.py b/restit/_path_parameter.py index d752e9b..55699e2 100644 --- a/restit/_path_parameter.py +++ b/restit/_path_parameter.py @@ -1,3 +1,5 @@ +from typing import Union + from marshmallow import fields from marshmallow.fields import Field @@ -16,12 +18,15 @@ class PathParameter: _PYTHON_TYPE_FIELD_MAPPING = { int: fields.Integer(), - str: fields.String() + str: fields.String(), + float: fields.Float(), + bool: fields.Boolean() } - def __init__(self, name: str, description: str, field_type: Field): + def __init__(self, name: str, description: str, field_type: Union[Field, type]): self.name = name self.description = description + # noinspection PyTypeChecker self.field_type = \ field_type if isinstance(field_type, Field) else PathParameter._PYTHON_TYPE_FIELD_MAPPING[field_type] diff --git a/restit/open_api/open_api_documentation.py b/restit/open_api/open_api_documentation.py index 2a9c5a0..c492d0e 100644 --- a/restit/open_api/open_api_documentation.py +++ b/restit/open_api/open_api_documentation.py @@ -88,8 +88,7 @@ def generate_spec(self) -> dict: def _generate_paths(self, root_spec: dict): paths = root_spec["paths"] for resource in self._resources: - if resource.__request_mapping__: - self._add_resource(paths, resource, root_spec) + if resource.__request_mapping__: self._add_resource(paths, resource, root_spec) def _add_resource(self, paths: dict, resource: Resource, root_spec: dict): path, inferred_path_parameters = \ @@ -186,7 +185,11 @@ def _infer_path_params_and_open_api_path_syntax(path: str) -> Tuple[str, List[Pa def _handle_path_parameter(match: Match) -> str: path_parameter_list.append( - PathParameter(match.group(1), "", eval(match.group(2)) if match.group(2) else str) + PathParameter( + name=match.group(1), + description="", + field_type=OpenApiDocumentation._get_path_param_type(match.group(2)) + ) ) return "{%s}" % match.group(1) @@ -225,3 +228,24 @@ def _generate_root_spec(self) -> dict: "schemas": {} } } + + @staticmethod + @lru_cache() + def _get_path_param_type(type_name: str) -> type: + path_param_types = { + "str": str, + "string": str, + "int": int, + "integer": int, + "float": float, + "bool": bool, + "boolean": bool, + None: str + } + try: + return path_param_types[type_name] + except KeyError: + raise OpenApiDocumentation.UnknownPathParamType(type_name) + + class UnknownPathParamType(Exception): + pass diff --git a/restit/resource.py b/restit/resource.py index 6aa13f1..97e9fd5 100644 --- a/restit/resource.py +++ b/restit/resource.py @@ -2,7 +2,7 @@ import inspect import logging import re -from typing import Tuple, AnyStr, Dict, Union, List, Callable +from typing import Tuple, AnyStr, Dict, Union, List, Callable, Optional from marshmallow import ValidationError @@ -119,16 +119,16 @@ def _execute_request_with_exception_mapping(method_object: Callable, request: Re type(exception), target_exception_tuple_or_class[0], target_exception_tuple_or_class[1] ) raise target_exception_tuple_or_class[0](target_exception_tuple_or_class[1]) - else: - LOGGER.debug( - "Mapping exception class %s to %s", type(exception), target_exception_tuple_or_class - ) - raise target_exception_tuple_or_class(str(exception)) + + LOGGER.debug( + "Mapping exception class %s to %s", type(exception), target_exception_tuple_or_class + ) + raise target_exception_tuple_or_class(str(exception)) raise exception @staticmethod - def _find_response_schema_by_status(status: int, method_object: object) -> Union[None, ResponseStatusParameter]: + def _find_response_schema_by_status(status: int, method_object: object) -> Optional[ResponseStatusParameter]: response_status_parameters = get_response_status_parameters_for_method(method_object) if response_status_parameters: for response_status_parameter in response_status_parameters: # type: ResponseStatusParameter @@ -137,6 +137,8 @@ def _find_response_schema_by_status(status: int, method_object: object) -> Union LOGGER.warning("Response status code %d is not expected for %s", status, method_object) + return None + @staticmethod def _validate_request_body(method_object: object, request: Request) -> Request: request_body_properties: RequestBodyProperties = \ @@ -165,10 +167,10 @@ def _collect_and_convert_path_parameters(self, path_params: dict) -> dict: for path_parameter in getattr(self, "__path_parameters__", []): # type: PathParameter try: path_parameter_value = path_params[path_parameter.name] - except KeyError: + except KeyError as error: raise Resource.PathParameterNotFoundException( f"Unable to find {path_parameter} in incoming path parameters {path_params}" - ) + ) from error try: path_params[path_parameter.name] = \ SchemaOrFieldDeserializer.deserialize(path_parameter_value, path_parameter.field_type) @@ -176,7 +178,7 @@ def _collect_and_convert_path_parameters(self, path_params: dict) -> dict: raise BadRequest( f"Path parameter value '{path_parameter_value}' is not matching '{path_parameter}' " f"({str(error)})" - ) + ) from error return path_params