Skip to content

Commit

Permalink
Add support for custom and collection type hint classes (#272)
Browse files Browse the repository at this point in the history
  • Loading branch information
mofr authored and axnsan12 committed Jan 29, 2019
1 parent 58e6dae commit 3806d6e
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .idea/drf-yasg.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

59 changes: 39 additions & 20 deletions src/drf_yasg/inspectors/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,18 +455,39 @@ def decimal_return_type():
return openapi.TYPE_STRING if rest_framework_settings.COERCE_DECIMAL_TO_STRING else openapi.TYPE_NUMBER


raw_type_info = [
def get_origin_type(hint_class):
return getattr(hint_class, '__origin__', None) or hint_class


def is_origin_type_subclasses(hint_class, check_class):
origin_type = get_origin_type(hint_class)
return inspect.isclass(origin_type) and issubclass(origin_type, check_class)


hinting_type_info = [
(bool, (openapi.TYPE_BOOLEAN, None)),
(int, (openapi.TYPE_INTEGER, None)),
(str, (openapi.TYPE_STRING, None)),
(float, (openapi.TYPE_NUMBER, None)),
(dict, (openapi.TYPE_OBJECT, None)),
(Decimal, (decimal_return_type, openapi.FORMAT_DECIMAL)),
(uuid.UUID, (openapi.TYPE_STRING, openapi.FORMAT_UUID)),
(datetime.datetime, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)),
(datetime.date, (openapi.TYPE_STRING, openapi.FORMAT_DATE)),
# TODO - support typing.List etc
]

hinting_type_info = raw_type_info
if typing:
def inspect_collection_hint_class(hint_class):
args = hint_class.__args__
child_class = args[0] if args else str
child_type_info = get_basic_type_info_from_hint(child_class)
if not child_type_info:
child_type_info = {'type': openapi.TYPE_STRING}
return OrderedDict([
('type', openapi.TYPE_ARRAY),
('items', openapi.Items(**child_type_info)),
])
hinting_type_info.append(((typing.Sequence, typing.AbstractSet), inspect_collection_hint_class))


def get_basic_type_info_from_hint(hint_class):
Expand All @@ -478,27 +499,25 @@ def get_basic_type_info_from_hint(hint_class):
:return: the extracted attributes as a dictionary, or ``None`` if the field type is not known
:rtype: OrderedDict
"""
if typing and get_origin_type(hint_class) == typing.Union:
if len(hint_class.__args__) == 2 and hint_class.__args__[1] == type(None):
child_type = hint_class.__args__[0]
return get_basic_type_info_from_hint(child_type)
return None

for check_class, type_format in hinting_type_info:
if issubclass(hint_class, check_class):
swagger_type, format = type_format
for check_class, info in hinting_type_info:
if is_origin_type_subclasses(hint_class, check_class):
if callable(info):
return info(hint_class)
swagger_type, format = info
if callable(swagger_type):
swagger_type = swagger_type()
# if callable(format):
# format = format(klass)
break
else: # pragma: no cover
return None
return OrderedDict([
('type', swagger_type),
('format', format),
])

pattern = None

result = OrderedDict([
('type', swagger_type),
('format', format),
('pattern', pattern)
])

return result
return None


class SerializerMethodFieldInspector(FieldInspector):
Expand Down
36 changes: 36 additions & 0 deletions tests/test_get_basic_type_info_from_hint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import uuid

import pytest

from drf_yasg import openapi
from drf_yasg.inspectors.field import get_basic_type_info_from_hint

try:
import typing
from typing import Dict, List, Union, Optional, Set
except ImportError:
typing = None


if typing:
@pytest.mark.parametrize('hint_class, expected_swagger_type_info', [
(int, {'type': openapi.TYPE_INTEGER, 'format': None}),
(str, {'type': openapi.TYPE_STRING, 'format': None}),
(bool, {'type': openapi.TYPE_BOOLEAN, 'format': None}),
(dict, {'type': openapi.TYPE_OBJECT, 'format': None}),
(Dict[int, int], {'type': openapi.TYPE_OBJECT, 'format': None}),
(uuid.UUID, {'type': openapi.TYPE_STRING, 'format': openapi.FORMAT_UUID}),
(List[int], {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_INTEGER)}),
(List[str], {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_STRING)}),
(List[bool], {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_BOOLEAN)}),
(Set[int], {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_INTEGER)}),
(Optional[bool], {'type': openapi.TYPE_BOOLEAN, 'format': None}),
(Optional[List[int]], {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_INTEGER)}),
(Union[List[int], type(None)], {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_INTEGER)}),
# Following cases are not 100% correct, but it should work somehow and not crash.
(Union[int, float], None),
(List, {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_STRING)}),
])
def test_get_basic_type_info_from_hint(hint_class, expected_swagger_type_info):
type_info = get_basic_type_info_from_hint(hint_class)
assert type_info == expected_swagger_type_info

0 comments on commit 3806d6e

Please sign in to comment.