Skip to content

Commit

Permalink
Merge pull request #27 from axnsan12/fix/callable-default
Browse files Browse the repository at this point in the history
Fix callable default handling
  • Loading branch information
axnsan12 committed Dec 23, 2017
2 parents 7683a28 + 43034dd commit 443d74b
Show file tree
Hide file tree
Showing 18 changed files with 300 additions and 122 deletions.
45 changes: 14 additions & 31 deletions src/drf_yasg/generators.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import re
from collections import defaultdict, OrderedDict

import django.db.models
import uritemplate
from coreapi.compat import force_text
from rest_framework.schemas.generators import SchemaGenerator, EndpointEnumerator as _EndpointEnumerator
from rest_framework.schemas.inspectors import get_pk_description

from . import openapi
from .inspectors import SwaggerAutoSchema
from .openapi import ReferenceResolver
from .utils import inspect_model_field, get_model_field

PATH_PARAMETER_RE = re.compile(r'{(?P<parameter>\w+)}')

Expand Down Expand Up @@ -82,9 +80,9 @@ def get_schema(self, request=None, public=False):
:return: the generated Swagger specification
:rtype: openapi.Swagger
"""
endpoints = self.get_endpoints(None if public else request)
endpoints = self.get_endpoints(request)
components = ReferenceResolver(openapi.SCHEMA_DEFINITIONS)
paths = self.get_paths(endpoints, components)
paths = self.get_paths(endpoints, components, public)

url = self._gen.url
if not url and request is not None:
Expand Down Expand Up @@ -114,9 +112,9 @@ def create_view(self, callback, method, request=None):
return view

def get_endpoints(self, request=None):
"""Iterate over all the registered endpoints in the API.
"""Iterate over all the registered endpoints in the API and return a fake view with the right parameters.
:param rest_framework.request.Request request: used for returning only endpoints available to the given request
:param rest_framework.request.Request request: request to bind to the endpoint views
:return: {path: (view_class, list[(http_method, view_instance)])
:rtype: dict
"""
Expand Down Expand Up @@ -151,11 +149,12 @@ def get_operation_keys(self, subpath, method, view):
"""
return self._gen.get_keys(subpath, method, view)

def get_paths(self, endpoints, components):
def get_paths(self, endpoints, components, public):
"""Generate the Swagger Paths for the API from the given endpoints.
:param dict endpoints: endpoints as returned by get_endpoints
:param ReferenceResolver components: resolver/container for Swagger References
:param bool public: if True, all endpoints are included regardless of access through `request`
:rtype: openapi.Paths
"""
if not endpoints:
Expand All @@ -169,7 +168,7 @@ def get_paths(self, endpoints, components):
path_parameters = self.get_path_parameters(path, view_cls)
operations = {}
for method, view in methods:
if not self._gen.has_view_permissions(path, method, view):
if not public and not self._gen.has_view_permissions(path, method, view):
continue

operation_keys = self.get_operation_keys(path[len(prefix):], method, view)
Expand Down Expand Up @@ -209,36 +208,20 @@ def get_path_parameters(self, path, view_cls):
:rtype: list[openapi.Parameter]
"""
parameters = []
queryset = getattr(view_cls, 'queryset', None)
model = getattr(getattr(view_cls, 'queryset', None), 'model', None)

for variable in uritemplate.variables(path):
pattern = None
type = openapi.TYPE_STRING
description = None
if model is not None:
# Attempt to infer a field description if possible.
try:
model_field = model._meta.get_field(variable)
except Exception: # pragma: no cover
model_field = None

if model_field is not None and model_field.help_text:
description = force_text(model_field.help_text)
elif model_field is not None and model_field.primary_key:
description = get_pk_description(model, model_field)

if hasattr(view_cls, 'lookup_value_regex') and getattr(view_cls, 'lookup_field', None) == variable:
pattern = view_cls.lookup_value_regex
elif isinstance(model_field, django.db.models.AutoField):
type = openapi.TYPE_INTEGER
model, model_field = get_model_field(queryset, variable)
attrs = inspect_model_field(model, model_field)
if hasattr(view_cls, 'lookup_value_regex') and getattr(view_cls, 'lookup_field', None) == variable:
attrs['pattern'] = view_cls.lookup_value_regex

field = openapi.Parameter(
name=variable,
required=True,
in_=openapi.IN_PATH,
type=type,
pattern=pattern,
description=description,
**attrs
)
parameters.append(field)

Expand Down
2 changes: 2 additions & 0 deletions src/drf_yasg/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def __init__(self, type=None, format=None, enum=None, pattern=None, items=None,
:param .Items items: only valid if `type` is ``array``
"""
super(Items, self).__init__(**extra)
assert type is not None, "type is required!"
self.type = type
self.format = format
self.enum = enum
Expand Down Expand Up @@ -372,6 +373,7 @@ def __init__(self, description=None, required=None, type=None, properties=None,
# common error
raise AssertionError(
"the `requires` attribute of schema must be an array of required properties, not a boolean!")
assert type is not None, "type is required!"
self.description = description
self.required = required
self.type = type
Expand Down
9 changes: 8 additions & 1 deletion src/drf_yasg/templates/drf-yasg/swagger-ui.html
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,14 @@
plugins: [
SwaggerUIBundle.plugins.DownloadUrl
],
layout: "StandaloneLayout"
layout: "StandaloneLayout",
filter: true,
requestInterceptor: function(request) {
console.log(request);
var headers = request.headers || {};
headers["X-CSRFToken"] = document.querySelector("[name=csrfmiddlewaretoken]").value;
return request;
}
};

var swaggerSettings = {};
Expand Down
155 changes: 150 additions & 5 deletions src/drf_yasg/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import logging
from collections import OrderedDict

from django.core.validators import RegexValidator
from django.db import models
from django.utils.encoding import force_text
from rest_framework import serializers
from rest_framework.mixins import RetrieveModelMixin, DestroyModelMixin, UpdateModelMixin
from rest_framework.schemas.inspectors import get_pk_description
from rest_framework.settings import api_settings
from rest_framework.utils import json, encoders

from . import openapi
from .errors import SwaggerGenerationError

logger = logging.getLogger(__name__)

#: used to forcibly remove the body of a request via :func:`.swagger_auto_schema`
no_body = object()

Expand Down Expand Up @@ -157,6 +163,87 @@ def decorator(view_method):
return decorator


def get_model_field(queryset, field_name):
"""Try to get information about a model and model field from a queryset.
:param queryset: the queryset
:param field_name: the target field name
:returns: the model and target field from the queryset as a 2-tuple; both elements can be ``None``
:rtype: tuple
"""
model = getattr(queryset, 'model', None)
try:
model_field = model._meta.get_field(field_name)
except Exception: # pragma: no cover
model_field = None

return model, model_field


model_field_to_swagger_type = [
(models.AutoField, (openapi.TYPE_INTEGER, None)),
(models.BinaryField, (openapi.TYPE_STRING, openapi.FORMAT_BINARY)),
(models.BooleanField, (openapi.TYPE_BOOLEAN, None)),
(models.NullBooleanField, (openapi.TYPE_BOOLEAN, None)),
(models.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)),
(models.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)),
(models.DecimalField, (openapi.TYPE_NUMBER, None)),
(models.DurationField, (openapi.TYPE_INTEGER, None)),
(models.FloatField, (openapi.TYPE_NUMBER, None)),
(models.IntegerField, (openapi.TYPE_INTEGER, None)),
(models.IPAddressField, (openapi.TYPE_STRING, openapi.FORMAT_IPV4)),
(models.GenericIPAddressField, (openapi.TYPE_STRING, openapi.FORMAT_IPV6)),
(models.SlugField, (openapi.TYPE_STRING, openapi.FORMAT_SLUG)),
(models.TextField, (openapi.TYPE_STRING, None)),
(models.TimeField, (openapi.TYPE_STRING, None)),
(models.UUIDField, (openapi.TYPE_STRING, openapi.FORMAT_UUID)),
(models.CharField, (openapi.TYPE_STRING, None)),
]


def inspect_model_field(model, model_field):
"""Extract information from a django model field instance.
:param model: the django model
:param model_field: a field on the model
:return: description, type, format and pattern extracted from the model field
:rtype: OrderedDict
"""
if model is not None and model_field is not None:
for model_field_class, tf in model_field_to_swagger_type:
if isinstance(model_field, model_field_class):
swagger_type, format = tf
break
else:
swagger_type, format = None, None

if format is None or format == openapi.FORMAT_SLUG:
pattern = find_regex(model_field)
else:
pattern = None

if model_field.help_text:
description = force_text(model_field.help_text)
elif model_field.primary_key:
description = get_pk_description(model, model_field)
else:
description = None
else:
description = None
swagger_type = None
format = None
pattern = None

result = OrderedDict([
('description', description),
('type', swagger_type or openapi.TYPE_STRING),
('format', format),
('pattern', pattern)
])
# TODO: filter none
return result


def serializer_field_to_swagger(field, swagger_object_type, definitions=None, **kwargs):
"""Convert a drf Serializer or Field instance into a Swagger object.
Expand All @@ -176,17 +263,50 @@ def serializer_field_to_swagger(field, swagger_object_type, definitions=None, **
description = force_text(field.help_text) if field.help_text else None
description = description if swagger_object_type != openapi.Items else None # Items has no description either

def SwaggerType(**instance_kwargs):
def SwaggerType(existing_object=None, **instance_kwargs):
if swagger_object_type == openapi.Parameter and 'required' not in instance_kwargs:
instance_kwargs['required'] = field.required
if swagger_object_type != openapi.Items and 'default' not in instance_kwargs:
default = getattr(field, 'default', serializers.empty)
if default is not serializers.empty:
instance_kwargs['default'] = default
if callable(default):
try:
if hasattr(default, 'set_context'):
default.set_context(field)
default = default()
except Exception as e:
logger.warning("default for %s is callable but it raised an exception when "
"called; 'default' field will not be added to schema", field, exc_info=True)
default = None

if default is not None:
try:
default = field.to_representation(default)
# JSON roundtrip ensures that the value is valid JSON;
# for example, sets get transformed into lists
default = json.loads(json.dumps(default, cls=encoders.JSONEncoder))
except Exception:
logger.warning("'default' on schema for %s will not be set because "
"to_representation raised an exception", field, exc_info=True)
default = None

if default is not None:
instance_kwargs['default'] = default

if swagger_object_type == openapi.Schema and 'read_only' not in instance_kwargs:
if field.read_only:
instance_kwargs['read_only'] = True
instance_kwargs.update(kwargs)
instance_kwargs.pop('title', None)
instance_kwargs.pop('description', None)

if existing_object is not None:
existing_object.title = title
existing_object.description = description
for attr, val in instance_kwargs.items():
setattr(existing_object, attr, val)
return existing_object

return swagger_object_type(title=title, description=description, **instance_kwargs)

# arrays in Schema have Schema elements, arrays in Parameter and Items have Items elements
Expand Down Expand Up @@ -238,8 +358,29 @@ def make_schema_definition():
items=child_schema,
unique_items=True, # is this OK?
)
elif isinstance(field, serializers.PrimaryKeyRelatedField):
if field.pk_field:
result = serializer_field_to_swagger(field.pk_field, swagger_object_type, definitions, **kwargs)
return SwaggerType(existing_object=result)

attrs = {'type': openapi.TYPE_STRING}
try:
model = field.queryset.model
pk_field = model._meta.pk
except Exception:
logger.warning("an exception was raised when attempting to extract the primary key related to %s; "
"falling back to plain string" % field, exc_info=True)
else:
attrs.update(inspect_model_field(model, pk_field))

return SwaggerType(**attrs)
elif isinstance(field, serializers.HyperlinkedRelatedField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI)
elif isinstance(field, serializers.SlugRelatedField):
model, model_field = get_model_field(field.queryset, field.slug_field)
attrs = inspect_model_field(model, model_field)
return SwaggerType(**attrs)
elif isinstance(field, serializers.RelatedField):
# TODO: infer type for PrimaryKeyRelatedField?
return SwaggerType(type=openapi.TYPE_STRING)
# ------ CHOICES
elif isinstance(field, serializers.MultipleChoiceField):
Expand All @@ -253,7 +394,7 @@ def make_schema_definition():
elif isinstance(field, serializers.ChoiceField):
return SwaggerType(type=openapi.TYPE_STRING, enum=list(field.choices.keys()))
# ------ BOOL
elif isinstance(field, serializers.BooleanField):
elif isinstance(field, (serializers.BooleanField, serializers.NullBooleanField)):
return SwaggerType(type=openapi.TYPE_BOOLEAN)
# ------ NUMERIC
elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
Expand All @@ -262,6 +403,8 @@ def make_schema_definition():
elif isinstance(field, serializers.IntegerField):
# TODO: min_value max_value
return SwaggerType(type=openapi.TYPE_INTEGER)
elif isinstance(field, serializers.DurationField):
return SwaggerType(type=openapi.TYPE_INTEGER)
# ------ STRING
elif isinstance(field, serializers.EmailField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_EMAIL)
Expand Down Expand Up @@ -308,8 +451,10 @@ def make_schema_definition():
type=openapi.TYPE_OBJECT,
additional_properties=child_schema
)
elif isinstance(field, serializers.ModelField):
return SwaggerType(type=openapi.TYPE_STRING)

# TODO unhandled fields: TimeField DurationField HiddenField ModelField NullBooleanField? JSONField
# TODO unhandled fields: TimeField HiddenField JSONField

# everything else gets string by default
return SwaggerType(type=openapi.TYPE_STRING)
Expand Down
12 changes: 8 additions & 4 deletions testproj/articles/migrations/0001_initial.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,30 @@
# Generated by Django 2.0 on 2017-12-05 04:05
# Generated by Django 2.0 on 2017-12-23 09:07

from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion


class Migration(migrations.Migration):

initial = True

dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]

operations = [
migrations.CreateModel(
name='Article',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('title', models.CharField(help_text='Main article headline', max_length=255, unique=True)),
('body', models.TextField(help_text='Article content', max_length=5000)),
('slug', models.SlugField(blank=True, help_text='Unique URL slug identifying the article', unique=True)),
('title', models.CharField(help_text='title model help_text', max_length=255, unique=True)),
('body', models.TextField(help_text='article model help_text', max_length=5000)),
('slug', models.SlugField(blank=True, help_text='slug model help_text', unique=True)),
('date_created', models.DateTimeField(auto_now_add=True)),
('date_modified', models.DateTimeField(auto_now=True)),
('cover', models.ImageField(blank=True, upload_to='article/original/')),
('author', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='articles', to=settings.AUTH_USER_MODEL)),
],
),
]

0 comments on commit 443d74b

Please sign in to comment.