Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Local support for response multiValueHeaders #1166

Merged
merged 5 commits into from
May 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions samcli/local/apigw/local_apigw_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import base64

from flask import Flask, request
from werkzeug.datastructures import Headers

from samcli.local.services.base_local_service import BaseLocalService, LambdaOutputParser, CaseInsensitiveDict
from samcli.local.services.base_local_service import BaseLocalService, LambdaOutputParser
from samcli.lib.utils.stream_writer import StreamWriter
from samcli.local.lambdafn.exceptions import FunctionNotFound
from samcli.local.events.api_event import ContextIdentity, RequestContext, ApiGatewayLambdaEvent
Expand Down Expand Up @@ -165,7 +166,7 @@ def _request_handler(self, **kwargs):
route.binary_types,
request)
except (KeyError, TypeError, ValueError):
LOG.error("Function returned an invalid response (must include one of: body, headers or "
LOG.error("Function returned an invalid response (must include one of: body, headers, multiValueHeaders or "
"statusCode in the response object). Response received: %s", lambda_response)
return ServiceErrorResponses.lambda_failure_response()

Expand Down Expand Up @@ -207,7 +208,8 @@ def _parse_lambda_output(lambda_output, binary_types, flask_request):
raise TypeError("Lambda returned %{s} instead of dict", type(json_output))

status_code = json_output.get("statusCode") or 200
headers = CaseInsensitiveDict(json_output.get("headers") or {})
headers = LocalApigwService._merge_response_headers(json_output.get("headers") or {},
json_output.get("multiValueHeaders") or {})
body = json_output.get("body") or "no data"
is_base_64_encoded = json_output.get("isBase64Encoded") or False

Expand Down Expand Up @@ -244,6 +246,7 @@ def _invalid_apig_response_keys(output):
"statusCode",
"body",
"headers",
"multiValueHeaders",
"isBase64Encoded"
}
# In Python 2.7, need to explicitly make the Dictionary keys into a set
Expand All @@ -261,7 +264,7 @@ def _should_base64_decode_body(binary_types, flask_request, lamba_response_heade
Corresponds to self.binary_types (aka. what is parsed from SAM Template
flask_request flask.request
Flask request
lamba_response_headers dict
lamba_response_headers werkzeug.datastructures.Headers
Headers Lambda returns
is_base_64_encoded bool
True if the body is Base64 encoded
Expand All @@ -271,11 +274,44 @@ def _should_base64_decode_body(binary_types, flask_request, lamba_response_heade
True if the body from the request should be converted to binary, otherwise false

"""
best_match_mimetype = flask_request.accept_mimetypes.best_match([lamba_response_headers["Content-Type"]])
best_match_mimetype = flask_request.accept_mimetypes.best_match(lamba_response_headers.get_all("Content-Type"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know if this is the right behavior that matches API Gateway? Seems reasonable just want to make sure we emulate the service as closely as possible.

is_best_match_in_binary_types = best_match_mimetype in binary_types or '*/*' in binary_types

return best_match_mimetype and is_best_match_in_binary_types and is_base_64_encoded

@staticmethod
def _merge_response_headers(headers, multi_headers):
"""
Merge multiValueHeaders headers with headers

* If you specify values for both headers and multiValueHeaders, API Gateway merges them into a single list.
* If the same key-value pair is specified in both, the value will only appear once.

Parameters
----------
headers dict
Headers map from the lambda_response_headers
multi_headers dict
multiValueHeaders map from the lambda_response_headers

Returns
-------
Merged list in accordance to the AWS documentation within a Flask Headers object

"""

processed_headers = Headers(multi_headers)

for header in headers:
# Prevent duplication of values when the key-value pair exists in both
# headers and multi_headers, but preserve order from multi_headers
if header in multi_headers and headers[header] in multi_headers[header]:
continue

processed_headers.add(header, headers[header])

return processed_headers

@staticmethod
def _construct_event(flask_request, port, binary_types):
"""
Expand Down
4 changes: 2 additions & 2 deletions samcli/local/lambda_service/local_lambda_invoke_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from flask import Flask, request

from samcli.lib.utils.stream_writer import StreamWriter
from samcli.local.services.base_local_service import BaseLocalService, LambdaOutputParser, CaseInsensitiveDict
from samcli.local.services.base_local_service import BaseLocalService, LambdaOutputParser
from samcli.local.lambdafn.exceptions import FunctionNotFound
from .lambda_error_responses import LambdaErrorResponses

Expand Down Expand Up @@ -92,7 +92,7 @@ def validate_request():
LOG.debug("Query parameters are in the request but not supported")
return LambdaErrorResponses.invalid_request_content("Query Parameters are not supported")

request_headers = CaseInsensitiveDict(flask_request.headers)
request_headers = flask_request.headers

log_type = request_headers.get('X-Amz-Log-Type', 'None')
if log_type != 'None':
Expand Down
19 changes: 1 addition & 18 deletions samcli/local/services/base_local_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,6 @@
LOG = logging.getLogger(__name__)


class CaseInsensitiveDict(dict):
"""
Implement a simple case insensitive dictionary for storing headers. To preserve the original
case of the given Header (e.g. X-FooBar-Fizz) this only touches the get and contains magic
methods rather than implementing a __setitem__ where we normalize the case of the headers.
"""

def __getitem__(self, key):
matches = [v for k, v in self.items() if k.lower() == key.lower()]
if not matches:
raise KeyError(key)
return matches[0]

def __contains__(self, key):
return key.lower() in [k.lower() for k in self.keys()]


class BaseLocalService(object):

def __init__(self, is_debugging, port, host):
Expand Down Expand Up @@ -86,7 +69,7 @@ def service_response(body, headers, status_code):
Constructs a Flask Response from the body, headers, and status_code.

:param str body: Response body as a string
:param dict headers: headers for the response
:param werkzeug.datastructures.Headers headers: headers for the response
:param int status_code: status_code for response
:return: Flask Response
"""
Expand Down
14 changes: 14 additions & 0 deletions tests/integration/local/start_api/test_start_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,20 @@ class TestServiceResponses(StartApiIntegBaseClass):
def setUp(self):
self.url = "http://127.0.0.1:{}".format(self.port)

def test_multiple_headers_response(self):
response = requests.get(self.url + "/multipleheaders")

self.assertEquals(response.status_code, 200)
self.assertEquals(response.headers.get("Content-Type"), "text/plain")
self.assertEquals(response.headers.get("MyCustomHeader"), 'Value1, Value2')

def test_multiple_headers_overrides_headers_response(self):
response = requests.get(self.url + "/multipleheadersoverridesheaders")

self.assertEquals(response.status_code, 200)
self.assertEquals(response.headers.get("Content-Type"), "text/plain")
self.assertEquals(response.headers.get("MyCustomHeader"), 'Value1, Value2, Custom')

def test_binary_response(self):
"""
Binary data is returned correctly
Expand Down
18 changes: 18 additions & 0 deletions tests/integration/testdata/start_api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,21 @@ def echo_base64_event_body(event, context):
},
"isBase64Encoded": event["isBase64Encoded"]
}


def multiple_headers(event, context):
return {
"statusCode": 200,
"body": "hello",
"headers": {"Content-Type": "text/plain"},
"multiValueHeaders": {"MyCustomHeader": ['Value1', 'Value2']}
}


def multiple_headers_overrides_headers(event, context):
return {
"statusCode": 200,
"body": "hello",
"headers": {"Content-Type": "text/plain", "MyCustomHeader": 'Custom'},
"multiValueHeaders": {"MyCustomHeader": ['Value1', 'Value2']}
}
26 changes: 26 additions & 0 deletions tests/integration/testdata/start_api/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,29 @@ Resources:
Properties:
Method: POST
Path: /echobase64eventbody

MultipleHeadersResponseFunction:
Type: AWS::Serverless::Function
Properties:
Handler: main.multiple_headers
Runtime: python3.6
CodeUri: .
Events:
IdBasePath:
Type: Api
Properties:
Method: GET
Path: /multipleheaders

MultipleHeadersOverridesHeadersResponseFunction:
Type: AWS::Serverless::Function
Properties:
Handler: main.multiple_headers_overrides_headers
Runtime: python3.6
CodeUri: .
Events:
IdBasePath:
Type: Api
Properties:
Method: GET
Path: /multipleheadersoverridesheaders
67 changes: 64 additions & 3 deletions tests/unit/local/apigw/test_local_apigw_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import base64

from parameterized import parameterized, param
from werkzeug.datastructures import Headers

from samcli.local.apigw.local_apigw_service import LocalApigwService, Route
from samcli.local.lambdafn.exceptions import FunctionNotFound
Expand Down Expand Up @@ -260,6 +261,39 @@ def test_class_initialization(self):
self.assertEquals(self.api_gateway.path, '/')


class TestLambdaHeaderDictionaryMerge(TestCase):

def test_empty_dictionaries_produce_empty_result(self):
headers = {}
multi_value_headers = {}

result = LocalApigwService._merge_response_headers(headers, multi_value_headers)

self.assertEquals(result, Headers({}))

def test_headers_are_merged(self):
headers = {"h1": "value1", "h2": "value2", "h3": "value3"}
multi_value_headers = {"h3": ["value4"]}

result = LocalApigwService._merge_response_headers(headers, multi_value_headers)

self.assertIn("h1", result)
self.assertIn("h2", result)
self.assertIn("h3", result)
self.assertEquals(result["h1"], "value1")
self.assertEquals(result["h2"], "value2")
self.assertEquals(result.get_all("h3"), ["value4", "value3"])

def test_merge_does_not_duplicate_values(self):
headers = {"h1": "ValueB"}
multi_value_headers = {"h1": ["ValueA", "ValueB", "ValueC"]}

result = LocalApigwService._merge_response_headers(headers, multi_value_headers)

self.assertIn("h1", result)
self.assertEquals(result.get_all("h1"), ["ValueA", "ValueB", "ValueC"])


class TestServiceParsingLambdaOutput(TestCase):

def test_default_content_type_header_added_with_no_headers(self):
Expand Down Expand Up @@ -289,6 +323,33 @@ def test_custom_content_type_header_is_not_modified(self):
self.assertIn("Content-Type", headers)
self.assertEquals(headers["Content-Type"], "text/xml")

def test_custom_content_type_multivalue_header_is_not_modified(self):
lambda_output = '{"statusCode": 200, "multiValueHeaders":{"Content-Type": ["text/xml"]}, "body": "{}", ' \
'"isBase64Encoded": false}'

(_, headers, _) = LocalApigwService._parse_lambda_output(lambda_output, binary_types=[], flask_request=Mock())

self.assertIn("Content-Type", headers)
self.assertEquals(headers["Content-Type"], "text/xml")

def test_multivalue_headers(self):
lambda_output = '{"statusCode": 200, "multiValueHeaders":{"X-Foo": ["bar", "42"]}, ' \
'"body": "{\\"message\\":\\"Hello from Lambda\\"}", "isBase64Encoded": false}'

(_, headers, _) = LocalApigwService._parse_lambda_output(lambda_output, binary_types=[], flask_request=Mock())

self.assertEquals(headers, Headers({"Content-Type": "application/json", "X-Foo": ["bar", "42"]}))

def test_single_and_multivalue_headers(self):
lambda_output = '{"statusCode": 200, "headers":{"X-Foo": "foo", "X-Bar": "bar"}, ' \
'"multiValueHeaders":{"X-Foo": ["bar", "42"]}, ' \
'"body": "{\\"message\\":\\"Hello from Lambda\\"}", "isBase64Encoded": false}'

(_, headers, _) = LocalApigwService._parse_lambda_output(lambda_output, binary_types=[], flask_request=Mock())

self.assertEquals(
headers, Headers({"Content-Type": "application/json", "X-Bar": "bar", "X-Foo": ["bar", "42", "foo"]}))

def test_extra_values_raise(self):
lambda_output = '{"statusCode": 200, "headers": {}, "body": "{\\"message\\":\\"Hello from Lambda\\"}", ' \
'"isBase64Encoded": false, "another_key": "some value"}'
Expand All @@ -307,7 +368,7 @@ def test_parse_returns_correct_tuple(self):
flask_request=Mock())

self.assertEquals(status_code, 200)
self.assertEquals(headers, {"Content-Type": "application/json"})
self.assertEquals(headers, Headers({"Content-Type": "application/json"}))
self.assertEquals(body, '{"message":"Hello from Lambda"}')

@patch('samcli.local.apigw.local_apigw_service.LocalApigwService._should_base64_decode_body')
Expand All @@ -326,7 +387,7 @@ def test_parse_returns_decodes_base64_to_binary(self, should_decode_body_patch):
flask_request=Mock())

self.assertEquals(status_code, 200)
self.assertEquals(headers, {"Content-Type": "application/octet-stream"})
self.assertEquals(headers, Headers({"Content-Type": "application/octet-stream"}))
self.assertEquals(body, binary_body)

def test_status_code_not_int(self):
Expand Down Expand Up @@ -388,7 +449,7 @@ def test_properties_are_null(self):
flask_request=Mock())

self.assertEquals(status_code, 200)
self.assertEquals(headers, {"Content-Type": "application/json"})
self.assertEquals(headers, Headers({"Content-Type": "application/json"}))
self.assertEquals(body, "no data")


Expand Down
35 changes: 1 addition & 34 deletions tests/unit/local/services/test_base_local_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from parameterized import parameterized, param

from samcli.local.services.base_local_service import BaseLocalService, LambdaOutputParser, CaseInsensitiveDict
from samcli.local.services.base_local_service import BaseLocalService, LambdaOutputParser


class TestLocalHostRunner(TestCase):
Expand Down Expand Up @@ -128,36 +128,3 @@ def test_get_lambda_output_extracts_response(self, test_case_name, stdout_data,
])
def test_is_lambda_error_response(self, input, exected_result):
self.assertEquals(LambdaOutputParser.is_lambda_error_response(input), exected_result)


class CaseInsensiveDict(TestCase):

def setUp(self):
self.data = CaseInsensitiveDict({
'Content-Type': 'text/html',
'Browser': 'APIGW',
})

def test_contains_lower(self):
self.assertTrue('content-type' in self.data)

def test_contains_title(self):
self.assertTrue('Content-Type' in self.data)

def test_contains_upper(self):
self.assertTrue('CONTENT-TYPE' in self.data)

def test_contains_browser_key(self):
self.assertTrue('Browser' in self.data)

def test_contains_not_in(self):
self.assertTrue('Dog-Food' not in self.data)

def test_setitem_found(self):
self.data['Browser'] = 'APIGW'

self.assertTrue(self.data['browser'])

def test_keyerror(self):
with self.assertRaises(KeyError):
self.data['does-not-exist']