Skip to content

Commit

Permalink
Allow multiple content types in arguments decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
lafrech committed Sep 2, 2019
1 parent b6bb2b8 commit 9345783
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 46 deletions.
11 changes: 6 additions & 5 deletions flask_rest_api/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,20 @@ class ArgumentsMixin:
ARGUMENTS_PARSER = FlaskParser()

def arguments(
self, schema, *, location='json', content_type=None, required=True,
example=None, examples=None, **kwargs
self, schema, *, location='json', content_types=None,
required=True, example=None, examples=None, **kwargs
):
"""Decorator specifying the schema used to deserialize parameters
:param type|Schema schema: Marshmallow ``Schema`` class or instance
used to deserialize and validate the argument.
:param str location: Location of the argument.
:param str content_type: Content type of the argument.
:param str content_types: Allowed content types for the argument.
Should only be used in conjunction with ``json``, ``form`` or
``files`` location.
The default value depends on the location and is set in
``Blueprint.DEFAULT_LOCATION_CONTENT_TYPE_MAPPING``.
This is only used for documentation purpose.
:param bool required: Whether argument is required (default: True).
This only affects `body` arguments as, in this case, the docs
expose the whole schema as a `required` parameter.
Expand All @@ -49,8 +50,8 @@ def arguments(
'required': required,
'schema': schema,
}
if content_type is not None:
parameters['content_type'] = content_type
if content_types is not None:
parameters['content_types'] = content_types
if example is not None:
parameters['example'] = example
if examples is not None:
Expand Down
54 changes: 29 additions & 25 deletions flask_rest_api/blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ class Blueprint(
# Order in which the methods are presented in the spec
HTTP_METHODS = ['OPTIONS', 'HEAD', 'GET', 'POST', 'PUT', 'PATCH', 'DELETE']

DEFAULT_LOCATION_CONTENT_TYPE_MAPPING = {
"json": "application/json",
"form": "application/x-www-form-urlencoded",
"files": "multipart/form-data",
DEFAULT_LOCATION_CONTENT_TYPES_MAPPING = {
"json": ["application/json"],
"form": ["application/x-www-form-urlencoded"],
"files": ["multipart/form-data"],
}

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -214,15 +214,18 @@ def _prepare_doc(self, operation, openapi_version):
if 'parameters' in operation:
for param in operation['parameters']:
if param['in'] in (
self.DEFAULT_LOCATION_CONTENT_TYPE_MAPPING
self.DEFAULT_LOCATION_CONTENT_TYPES_MAPPING
):
content_type = (
param.pop('content_type', None) or
self.DEFAULT_LOCATION_CONTENT_TYPE_MAPPING[
content_types = (
param.pop('content_types', None) or
self.DEFAULT_LOCATION_CONTENT_TYPES_MAPPING[
param['in']]
)
if content_type != DEFAULT_REQUEST_BODY_CONTENT_TYPE:
operation['consumes'] = [content_type, ]
if (
set(content_types) !=
{DEFAULT_REQUEST_BODY_CONTENT_TYPE}
):
operation['consumes'] = content_types
# body and formData are mutually exclusive
break
# OAS 3
Expand All @@ -240,25 +243,26 @@ def _prepare_doc(self, operation, openapi_version):
if 'parameters' in operation:
for param in operation['parameters']:
if param['in'] in (
self.DEFAULT_LOCATION_CONTENT_TYPE_MAPPING
self.DEFAULT_LOCATION_CONTENT_TYPES_MAPPING
):
content_type = (
param.pop('content_type', None) or
self.DEFAULT_LOCATION_CONTENT_TYPE_MAPPING[
param['in']]
)
request_body = {
x: param[x] for x in ('description', 'required')
x: param[x]
for x in ('description', 'required')
if x in param
}
fields = {
x: param.pop(x)
for x in ('schema', 'example', 'examples')
if x in param
}
for field in ('schema', 'example', 'examples'):
if field in param:
(
request_body
.setdefault('content', {})
.setdefault(content_type, {})
[field]
) = param.pop(field)
content_types = (
param.pop('content_types', None) or
self.DEFAULT_LOCATION_CONTENT_TYPES_MAPPING[
param['in']]
)
for content_type in content_types:
request_body.setdefault('content', {}).setdefault(
content_type, fields)
operation['requestBody'] = request_body
# There can be only one requestBody
operation['parameters'].remove(param)
Expand Down
41 changes: 25 additions & 16 deletions tests/test_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
)

REQUEST_BODY_CONTENT_TYPE = {
"json": "application/json",
"form": "application/x-www-form-urlencoded",
"files": "multipart/form-data",
"json": ["application/json"],
"form": ["application/x-www-form-urlencoded"],
"files": ["multipart/form-data"],
}


Expand Down Expand Up @@ -65,43 +65,52 @@ def func():
):
assert 'parameters' not in get
assert 'requestBody' in get
assert len(get['requestBody']['content']) == 1
assert REQUEST_BODY_CONTENT_TYPE[location] in get[
'requestBody']['content']
assert (
set(get['requestBody']['content'].keys()) ==
set(REQUEST_BODY_CONTENT_TYPE[location])
)
else:
loc = get['parameters'][0]['in']
assert loc == openapi_location
assert 'requestBody' not in get
if location in REQUEST_BODY_CONTENT_TYPE and location != 'json':
assert get['consumes'] == [REQUEST_BODY_CONTENT_TYPE[location]]
assert get['consumes'] == REQUEST_BODY_CONTENT_TYPE[location]
else:
assert 'consumes' not in get

@pytest.mark.parametrize('openapi_version', ('2.0', '3.0.2'))
@pytest.mark.parametrize('location', REQUEST_BODY_CONTENT_TYPE.keys())
@pytest.mark.parametrize('content_type', ('application/x-custom', None))
def test_blueprint_arguments_content_type(
self, app, schemas, location, content_type, openapi_version):
@pytest.mark.parametrize(
'content_types',
(['application/x-custom-1', 'application/x-custom-2'], None)
)
def test_blueprint_arguments_content_types(
self, app, schemas, location, content_types, openapi_version):
app.config['OPENAPI_VERSION'] = openapi_version
api = Api(app)
blp = Blueprint('test', __name__, url_prefix='/test')
content_type = content_type or REQUEST_BODY_CONTENT_TYPE[location]
content_types = content_types or REQUEST_BODY_CONTENT_TYPE[location]

@blp.route('/')
@blp.arguments(
schemas.DocSchema, location=location, content_type=content_type)
schemas.DocSchema, location=location, content_types=content_types)
def func():
"""Dummy view func"""

api.register_blueprint(blp)
spec = api.spec.to_dict()
get = spec['paths']['/test/']['get']
if openapi_version == '3.0.2':
assert len(get['requestBody']['content']) == 1
assert content_type in get['requestBody']['content']
assert (
set(get['requestBody']['content'].keys()) ==
set(content_types)
)
for content in get['requestBody']['content'].values():
assert content == {
'schema': {'$ref': '#/components/schemas/Doc'}}
else:
if content_type != 'application/json':
assert get['consumes'] == [content_type]
if content_types != ['application/json']:
assert get['consumes'] == content_types
else:
assert 'consumes' not in get

Expand Down

0 comments on commit 9345783

Please sign in to comment.