Skip to content

Commit

Permalink
Merge 9bc1bdd into 72fb923
Browse files Browse the repository at this point in the history
  • Loading branch information
lafrech committed Sep 18, 2019
2 parents 72fb923 + 9bc1bdd commit 8c65a30
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 3 deletions.
10 changes: 10 additions & 0 deletions flask_rest_api/fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Custom fields"""

import marshmallow as ma


class Upload(ma.fields.Field):
"""File upload field"""
def __init__(self, format='binary', **kwargs):
self.format = format
super().__init__(**kwargs)
6 changes: 4 additions & 2 deletions flask_rest_api/spec/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""API specification using Open API"""

"""API specification using OpenAPI"""
import json

import flask
Expand All @@ -9,6 +8,7 @@

from flask_rest_api.exceptions import OpenAPIVersionNotSpecified
from .plugins import FlaskPlugin
from .field_converters import uploadfield2properties


def _add_leading_slash(string):
Expand Down Expand Up @@ -197,6 +197,8 @@ def _init_spec(
# Register custom converters in spec
for args in self._converters:
self._register_converter(*args)
# Register Upload field properties function
self.ma_plugin.converter.add_attribute_function(uploadfield2properties)

def register_converter(self, converter, conv_type, conv_format=None):
"""Register custom path parameter converter
Expand Down
14 changes: 14 additions & 0 deletions flask_rest_api/spec/field_converters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Custom field properties functions"""
from flask_rest_api.fields import Upload


def uploadfield2properties(self, field, **kwargs):
"""Document Upload field"""
ret = {}
if isinstance(field, Upload):
if self.openapi_version.major < 3:
ret['type'] = 'file'
else:
ret['type'] = 'string'
ret['format'] = field.format
return ret
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@
'flask>=1.1.0',
'marshmallow>=2.15.2',
'webargs>=1.5.2',
'apispec>=2.0.0',
'apispec>=3.0.0',
],
)
58 changes: 58 additions & 0 deletions tests/test_blueprint.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Test Blueprint extra features"""

import io
import json
import http

import pytest

import marshmallow as ma
Expand All @@ -10,6 +12,7 @@
from flask.views import MethodView

from flask_rest_api import Api, Blueprint, Page
from flask_rest_api.fields import Upload

from .utils import build_ref

Expand Down Expand Up @@ -281,6 +284,61 @@ def func(document, query_args):
'query_args': {'arg1': 'test'},
}

@pytest.mark.parametrize('openapi_version', ('2.0', '3.0.2'))
def test_blueprint_arguments_files_multipart(
self, app, schemas, openapi_version):
app.config['OPENAPI_VERSION'] = openapi_version
api = Api(app)
blp = Blueprint('test', __name__, url_prefix='/test')
client = app.test_client()

class MultipartSchema(ma.Schema):
file_1 = Upload()
file_2 = Upload()

@blp.route('/', methods=['POST'])
@blp.arguments(MultipartSchema, location='files')
def func(files):
return jsonify(
files['file_1'].read().decode(),
files['file_2'].read().decode(),
)

api.register_blueprint(blp)
spec = api.spec.to_dict()

files = {
'file_1': (io.BytesIO('Test 1'.encode()), 'file_1.txt'),
'file_2': (io.BytesIO('Test 2'.encode()), 'file_2.txt'),
}

response = client.post('/test/', data=files)
assert response.json == ['Test 1', 'Test 2']

if openapi_version == '2.0':
for param in spec['paths']['/test/']['post']['parameters']:
assert param['in'] == 'formData'
assert param['type'] == 'file'
else:
assert (
spec['paths']['/test/']['post']['requestBody']['content'] ==
{
'multipart/form-data': {
'schema': {'$ref': '#/components/schemas/Multipart'}
}
}
)
assert (
spec['components']['schemas']['Multipart'] ==
{
'type': 'object',
'properties': {
'file_1': {'type': 'string', 'format': 'binary'},
'file_2': {'type': 'string', 'format': 'binary'},
}
}
)

# This is only relevant to OAS3.
@pytest.mark.parametrize('openapi_version', ('3.0.2', ))
def test_blueprint_arguments_examples(self, app, schemas, openapi_version):
Expand Down

0 comments on commit 8c65a30

Please sign in to comment.