Skip to content

Commit

Permalink
allow setting response decorator also for class
Browse files Browse the repository at this point in the history
  • Loading branch information
Tuerke Erik committed Mar 24, 2020
1 parent 3d274c5 commit af9498f
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 12 deletions.
9 changes: 9 additions & 0 deletions restit/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import os
import sys
from html import escape
from typing import List

from restit.internal.response_status_parameter import ResponseStatusParameter

_DEFAULT_ENCODING = sys.getdefaultencoding()

Expand Down Expand Up @@ -45,3 +48,9 @@ def guess_text_content_subtype_string(content: str) -> str:
return "text/html"

return "text/plain"


def get_response_status_parameters_for_method(method_object: object) -> List[ResponseStatusParameter]:
response_status_parameters = getattr(method_object, "__response_status_parameters__", [])
response_status_parameters.extend(getattr(method_object.__self__, "__response_status_parameters__", []))
return response_status_parameters
1 change: 1 addition & 0 deletions restit/decorator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .exception_mapping_decorator import exception_mapping
from .path_decorator import path
from .path_parameter_decorator import path_parameter
from .query_parameter_decorator import query_parameter
Expand Down
33 changes: 33 additions & 0 deletions restit/decorator/exception_mapping_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import logging
from typing import Type, Dict, Tuple, Union

from restit.exception import HttpError

LOGGER = logging.getLogger(__name__)


def exception_mapping(mapping: Dict[Type[Exception], Union[Tuple[Type[HttpError], str], Type[HttpError]]]):
def decorator(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as exception:
for source_exception_class, target_exception_tuple_or_class in mapping.items():
if isinstance(exception, source_exception_class):
if isinstance(target_exception_tuple_or_class, tuple):
LOGGER.debug(
"Mapping exception class %s to %s with description: %s",
type(exception), target_exception_tuple_or_class[0], target_exception_tuple_or_class[1]
)
raise target_exception_tuple_or_class[0](target_exception_tuple_or_class[1])
else:
LOGGER.debug(
"Mapping exception class %s to %s", type(exception), target_exception_tuple_or_class
)
raise target_exception_tuple_or_class(str(exception))

raise exception

return wrapper

return decorator
10 changes: 5 additions & 5 deletions restit/decorator/response_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@

# noinspection PyShadowingBuiltins
def response(status: Union[int, HTTPStatus], content_types: Dict[str, Union[Schema, Field]], description: str):
def decorator(func):
def decorator(func_or_class):
http_status_code = status if isinstance(status, int) or status is None else status.value
response_status_parameter = ResponseStatusParameter(http_status_code, description, content_types)

registered_response_status_parameters: List[ResponseStatusParameter] = \
getattr(func, "__response_status_parameters__", [])
getattr(func_or_class, "__response_status_parameters__", [])
LOGGER.debug(
"Registering response status parameter %s for %s", response_status_parameter, func.__name__
"Registering response status parameter %s for %s", response_status_parameter, func_or_class.__name__
)
registered_response_status_parameters.append(response_status_parameter)
setattr(func, "__response_status_parameters__", registered_response_status_parameters)
return func
setattr(func_or_class, "__response_status_parameters__", registered_response_status_parameters)
return func_or_class

return decorator
3 changes: 2 additions & 1 deletion restit/open_api/open_api_documentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import lru_cache
from typing import List, Tuple, Union, Match

from restit.common import get_response_status_parameters_for_method
from restit.internal.request_body_properties import RequestBodyProperties
from restit.internal.response_status_parameter import ResponseStatusParameter
from restit.open_api.info_object import InfoObject
Expand Down Expand Up @@ -58,7 +59,7 @@ def _add_resource(self, paths: dict, resource: Resource, root_spec: dict):

@staticmethod
def _add_responses(method_spec: dict, method_object: object, root_spec: dict):
response_status_parameters = getattr(method_object, "__response_status_parameters__", None)
response_status_parameters = get_response_status_parameters_for_method(method_object)
if response_status_parameters:
for response_status_parameter in response_status_parameters: # type: ResponseStatusParameter
method_spec["responses"][response_status_parameter.status or "default"] = {
Expand Down
12 changes: 6 additions & 6 deletions restit/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from restit._path_parameter import PathParameter
from restit._response import Response
from restit.common import get_response_status_parameters_for_method
from restit.exception import MethodNotAllowed
from restit.exception.client_errors_4xx import BadRequest
from restit.internal.query_parameter import QueryParameter
Expand Down Expand Up @@ -106,13 +107,12 @@ def handle_request(self, request_method: str, request: Request, path_params: Dic

@staticmethod
def _find_response_schema_by_status(status: int, method_object: object) -> Union[None, ResponseStatusParameter]:
response_status_parameters = getattr(method_object, "__response_status_parameters__", None)
if response_status_parameters:
for response_status_parameter in response_status_parameters: # type: ResponseStatusParameter
if response_status_parameter.status == status:
return response_status_parameter
response_status_parameters = get_response_status_parameters_for_method(method_object)
for response_status_parameter in response_status_parameters: # type: ResponseStatusParameter
if response_status_parameter.status == status:
return response_status_parameter

LOGGER.warning("Response status code %d is not expected for %s", status, method_object)
LOGGER.warning("Response status code %d is not expected for %s", status, method_object)

@staticmethod
def _validate_request_body(method_object: object, request: Request) -> Request:
Expand Down

0 comments on commit af9498f

Please sign in to comment.