Skip to content

Commit

Permalink
Merge eb045e9 into d7d9793
Browse files Browse the repository at this point in the history
  • Loading branch information
lafrech committed Feb 24, 2019
2 parents d7d9793 + eb045e9 commit 23ee9ab
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 17 deletions.
9 changes: 6 additions & 3 deletions flask_rest_api/etag.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,12 @@ def wrapper(*args, **kwargs):
# Pass data to use as ETag data if set_etag was not called
# If etag_schema is provided, pass raw result rather than
# dump, as the dump needs to be done using etag_schema
etag_data = get_appcontext()[
'result_dump' if etag_schema is None else 'result_raw'
]
# If 'result_dump'/'result_raw' is not in appcontext,
# the Etag must have been set manually. Just pass None.
etag_data = get_appcontext().get(
'result_dump' if etag_schema is None else 'result_raw',
None
)
self._set_etag_in_response(resp, etag_data, etag_schema)

return resp
Expand Down
7 changes: 4 additions & 3 deletions flask_rest_api/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import marshmallow as ma
from webargs.flaskparser import FlaskParser

from .utils import get_appcontext
from .utils import get_appcontext, unpack_tuple_response
from .compat import MARSHMALLOW_VERSION_MAJOR


Expand Down Expand Up @@ -166,7 +166,8 @@ def wrapper(*args, **kwargs):
kwargs['pagination_parameters'] = page_params

# Execute decorated function
result = func(*args, **kwargs)
result, status, headers = unpack_tuple_response(
func(*args, **kwargs))

# Post pagination: use pager class to paginate the result
if pager is not None:
Expand All @@ -185,7 +186,7 @@ def wrapper(*args, **kwargs):
get_appcontext()['headers'][
self.PAGINATION_HEADER_FIELD_NAME] = page_header

return result
return result, status, headers

return wrapper

Expand Down
31 changes: 23 additions & 8 deletions flask_rest_api/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

from functools import wraps

from flask import jsonify
from flask import jsonify, Response

from .utils import deepupdate, get_appcontext
from .utils import (
deepupdate, get_appcontext,
unpack_tuple_response, set_status_and_headers_in_response
)
from .compat import MARSHMALLOW_VERSION_MAJOR


Expand All @@ -16,7 +19,8 @@ def response(self, schema=None, *, code=200, description=''):
:param schema: :class:`Schema <marshmallow.Schema>` class or instance.
If not None, will be used to serialize response data.
:param int code: HTTP status code (default: 200).
:param int code: HTTP status code (default: 200). Used if none is
returned from the view function.
:param str descripton: Description of the response.
See :doc:`Response <response>`.
Expand All @@ -36,8 +40,17 @@ def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):

appcontext = get_appcontext()

# Execute decorated function
result_raw = func(*args, **kwargs)
result_raw, status, headers = unpack_tuple_response(
func(*args, **kwargs))

# If return value is a flask Response, return it
if isinstance(result_raw, Response):
set_status_and_headers_in_response(
result_raw, status, headers)
return result_raw

# Dump result with schema if specified
if schema is None:
Expand All @@ -48,13 +61,15 @@ def wrapper(*args, **kwargs):
result_dump = result_dump[0]

# Store result in appcontext (may be used for ETag computation)
get_appcontext()['result_raw'] = result_raw
get_appcontext()['result_dump'] = result_dump
appcontext['result_raw'] = result_raw
appcontext['result_dump'] = result_dump

# Build response
resp = jsonify(self._prepare_response_content(result_dump))
resp.headers.extend(get_appcontext()['headers'])
resp.status_code = code
resp.headers.extend(appcontext['headers'])
set_status_and_headers_in_response(resp, status, headers)
if status is None:
resp.status_code = code

return resp

Expand Down
42 changes: 42 additions & 0 deletions flask_rest_api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections import defaultdict
from collections.abc import Mapping

from werkzeug.datastructures import Headers
from flask import _app_ctx_stack
from apispec.utils import trim_docstring, dedent

Expand Down Expand Up @@ -62,3 +63,44 @@ def load_info_from_docstring(docstring):
if description_lines:
info['description'] = dedent('\n'.join(description_lines))
return info


# Copied from flask
def unpack_tuple_response(rv):
"""Unpack a flask Response tuple"""

status = headers = None

# unpack tuple returns
if isinstance(rv, tuple):
len_rv = len(rv)

# a 3-tuple is unpacked directly
if len_rv == 3:
rv, status, headers = rv
# decide if a 2-tuple has status or headers
elif len_rv == 2:
if isinstance(rv[1], (Headers, dict, tuple, list)):
rv, headers = rv
else:
rv, status = rv
# other sized tuples are not allowed
else:
raise TypeError(
'The view function did not return a valid response tuple.'
' The tuple must have the form (body, status, headers),'
' (body, status), or (body, headers).'
)

return rv, status, headers


def set_status_and_headers_in_response(response, status, headers):
"""Set status and headers in flask Reponse object"""
if headers:
response.headers.extend(headers)
if status is not None:
if isinstance(status, int):
response.status_code = status
else:
response.status = status
146 changes: 144 additions & 2 deletions tests/test_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from flask import jsonify
from flask.views import MethodView

from flask_rest_api import Api
from flask_rest_api.blueprint import Blueprint
from flask_rest_api import Api, Blueprint, Page
from flask_rest_api.exceptions import InvalidLocationError


Expand Down Expand Up @@ -481,3 +480,146 @@ def func():

assert 'get' in paths['/test/route_1']
assert 'get' in paths['/test/route_2']

def test_blueprint_response_tuple(self, app):
api = Api(app)
blp = Blueprint('test', __name__, url_prefix='/test')
client = app.test_client()

@blp.route('/response')
@blp.response()
def func_response():
return {}

@blp.route('/response_code_int')
@blp.response()
def func_response_code_int():
return {}, 201

@blp.route('/response_code_str')
@blp.response()
def func_response_code_str():
return {}, '201 CREATED'

@blp.route('/response_headers')
@blp.response()
def func_response_headers():
return {}, {'X-header': 'test'}

@blp.route('/response_code_int_headers')
@blp.response()
def func_response_code_int_headers():
return {}, 201, {'X-header': 'test'}

@blp.route('/response_code_str_headers')
@blp.response()
def func_response_code_str_headers():
return {}, '201 CREATED', {'X-header': 'test'}

@blp.route('/response_wrong_tuple')
@blp.response()
def func_response_wrong_tuple():
return {}, 201, {'X-header': 'test'}, 'extra'

api.register_blueprint(blp)

response = client.get('/test/response')
assert response.status_code == 200
assert response.json == {}
response = client.get('/test/response_code_int')
assert response.status_code == 201
assert response.status == '201 CREATED'
assert response.json == {}
response = client.get('/test/response_code_str')
assert response.status_code == 201
assert response.status == '201 CREATED'
assert response.json == {}
response = client.get('/test/response_headers')
assert response.status_code == 200
assert response.json == {}
assert response.headers['X-header'] == 'test'
response = client.get('/test/response_code_int_headers')
assert response.status_code == 201
assert response.status == '201 CREATED'
assert response.json == {}
assert response.headers['X-header'] == 'test'
response = client.get('/test/response_code_str_headers')
assert response.status_code == 201
assert response.status == '201 CREATED'
assert response.json == {}
assert response.headers['X-header'] == 'test'
response = client.get('/test/response_wrong_tuple')
assert response.status_code == 500

def test_blueprint_pagination_response_tuple(self, app):
api = Api(app)
blp = Blueprint('test', __name__, url_prefix='/test')
client = app.test_client()

@blp.route('/response')
@blp.response()
@blp.paginate(Page)
def func_response():
return [1, 2]

@blp.route('/response_code')
@blp.response()
@blp.paginate(Page)
def func_response_code():
return [1, 2], 201

@blp.route('/response_headers')
@blp.response()
@blp.paginate(Page)
def func_response_headers():
return [1, 2], {'X-header': 'test'}

@blp.route('/response_code_headers')
@blp.response()
@blp.paginate(Page)
def func_response_code_headers():
return [1, 2], 201, {'X-header': 'test'}

@blp.route('/response_wrong_tuple')
@blp.response()
@blp.paginate(Page)
def func_response_wrong_tuple():
return [1, 2], 201, {'X-header': 'test'}, 'extra'

api.register_blueprint(blp)

response = client.get('/test/response')
assert response.status_code == 200
assert response.json == [1, 2]
response = client.get('/test/response_code')
assert response.status_code == 201
assert response.json == [1, 2]
response = client.get('/test/response_headers')
assert response.status_code == 200
assert response.json == [1, 2]
assert response.headers['X-header'] == 'test'
response = client.get('/test/response_code_headers')
assert response.status_code == 201
assert response.json == [1, 2]
assert response.headers['X-header'] == 'test'
response = client.get('/test/response_wrong_tuple')
assert response.status_code == 500

def test_blueprint_response_response_object(self, app, schemas):
api = Api(app)
blp = Blueprint('test', __name__, url_prefix='/test')
client = app.test_client()

@blp.route('/response')
# Schema is ignored when response object is returned
@blp.response(schemas.DocSchema, code=200)
def func_response():
return jsonify({}), 201, {'X-header': 'test'}

api.register_blueprint(blp)

response = client.get('/test/response')
assert response.status_code == 201
assert response.status == '201 CREATED'
assert response.json == {}
assert response.headers['X-header'] == 'test'
22 changes: 21 additions & 1 deletion tests/test_etag.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pytest

from flask import Response
from flask import jsonify, Response
from flask.views import MethodView

from flask_rest_api import Api, Blueprint, abort
Expand Down Expand Up @@ -358,6 +358,26 @@ def test_etag_set_etag_in_response(self, app, schemas, paginate):
blp._set_etag_in_response(resp, item, etag_schema)
assert resp.get_etag() == (etag_with_schema, False)

def test_etag_response_object(self, app):
api = Api(app)
blp = Blueprint('test', __name__, url_prefix='/test')
client = app.test_client()

@blp.route('/')
@blp.etag
@blp.response()
def func_response_etag():
# When the view function returns a Response object,
# the ETag must be specified manually
blp.set_etag('test')
return jsonify({})

api.register_blueprint(blp)

response = client.get('/test/')
assert response.json == {}
assert response.get_etag() == (blp._generate_etag('test'), False)

def test_etag_operations_etag_enabled(self, app_with_etag):

client = app_with_etag.test_client()
Expand Down

0 comments on commit 23ee9ab

Please sign in to comment.