diff --git a/samcli/local/apigw/local_apigw_service.py b/samcli/local/apigw/local_apigw_service.py index 91347743da..755502670f 100644 --- a/samcli/local/apigw/local_apigw_service.py +++ b/samcli/local/apigw/local_apigw_service.py @@ -5,8 +5,9 @@ import base64 from flask import Flask, request +from werkzeug.datastructures import Headers -from samcli.local.services.base_local_service import BaseLocalService, LambdaOutputParser, CaseInsensitiveDict +from samcli.local.services.base_local_service import BaseLocalService, LambdaOutputParser from samcli.lib.utils.stream_writer import StreamWriter from samcli.local.lambdafn.exceptions import FunctionNotFound from samcli.local.events.api_event import ContextIdentity, RequestContext, ApiGatewayLambdaEvent @@ -165,7 +166,7 @@ def _request_handler(self, **kwargs): route.binary_types, request) except (KeyError, TypeError, ValueError): - LOG.error("Function returned an invalid response (must include one of: body, headers or " + LOG.error("Function returned an invalid response (must include one of: body, headers, multiValueHeaders or " "statusCode in the response object). Response received: %s", lambda_response) return ServiceErrorResponses.lambda_failure_response() @@ -207,7 +208,8 @@ def _parse_lambda_output(lambda_output, binary_types, flask_request): raise TypeError("Lambda returned %{s} instead of dict", type(json_output)) status_code = json_output.get("statusCode") or 200 - headers = CaseInsensitiveDict(json_output.get("headers") or {}) + headers = LocalApigwService._merge_response_headers(json_output.get("headers") or {}, + json_output.get("multiValueHeaders") or {}) body = json_output.get("body") or "no data" is_base_64_encoded = json_output.get("isBase64Encoded") or False @@ -244,6 +246,7 @@ def _invalid_apig_response_keys(output): "statusCode", "body", "headers", + "multiValueHeaders", "isBase64Encoded" } # In Python 2.7, need to explicitly make the Dictionary keys into a set @@ -261,7 +264,7 @@ def _should_base64_decode_body(binary_types, flask_request, lamba_response_heade Corresponds to self.binary_types (aka. what is parsed from SAM Template flask_request flask.request Flask request - lamba_response_headers dict + lamba_response_headers werkzeug.datastructures.Headers Headers Lambda returns is_base_64_encoded bool True if the body is Base64 encoded @@ -271,11 +274,44 @@ def _should_base64_decode_body(binary_types, flask_request, lamba_response_heade True if the body from the request should be converted to binary, otherwise false """ - best_match_mimetype = flask_request.accept_mimetypes.best_match([lamba_response_headers["Content-Type"]]) + best_match_mimetype = flask_request.accept_mimetypes.best_match(lamba_response_headers.get_all("Content-Type")) is_best_match_in_binary_types = best_match_mimetype in binary_types or '*/*' in binary_types return best_match_mimetype and is_best_match_in_binary_types and is_base_64_encoded + @staticmethod + def _merge_response_headers(headers, multi_headers): + """ + Merge multiValueHeaders headers with headers + + * If you specify values for both headers and multiValueHeaders, API Gateway merges them into a single list. + * If the same key-value pair is specified in both, the value will only appear once. + + Parameters + ---------- + headers dict + Headers map from the lambda_response_headers + multi_headers dict + multiValueHeaders map from the lambda_response_headers + + Returns + ------- + Merged list in accordance to the AWS documentation within a Flask Headers object + + """ + + processed_headers = Headers(multi_headers) + + for header in headers: + # Prevent duplication of values when the key-value pair exists in both + # headers and multi_headers, but preserve order from multi_headers + if header in multi_headers and headers[header] in multi_headers[header]: + continue + + processed_headers.add(header, headers[header]) + + return processed_headers + @staticmethod def _construct_event(flask_request, port, binary_types): """ diff --git a/samcli/local/lambda_service/local_lambda_invoke_service.py b/samcli/local/lambda_service/local_lambda_invoke_service.py index 7c684f9159..976a468da8 100644 --- a/samcli/local/lambda_service/local_lambda_invoke_service.py +++ b/samcli/local/lambda_service/local_lambda_invoke_service.py @@ -7,7 +7,7 @@ from flask import Flask, request from samcli.lib.utils.stream_writer import StreamWriter -from samcli.local.services.base_local_service import BaseLocalService, LambdaOutputParser, CaseInsensitiveDict +from samcli.local.services.base_local_service import BaseLocalService, LambdaOutputParser from samcli.local.lambdafn.exceptions import FunctionNotFound from .lambda_error_responses import LambdaErrorResponses @@ -92,7 +92,7 @@ def validate_request(): LOG.debug("Query parameters are in the request but not supported") return LambdaErrorResponses.invalid_request_content("Query Parameters are not supported") - request_headers = CaseInsensitiveDict(flask_request.headers) + request_headers = flask_request.headers log_type = request_headers.get('X-Amz-Log-Type', 'None') if log_type != 'None': diff --git a/samcli/local/services/base_local_service.py b/samcli/local/services/base_local_service.py index e94ae1ffac..0d54357e86 100644 --- a/samcli/local/services/base_local_service.py +++ b/samcli/local/services/base_local_service.py @@ -9,23 +9,6 @@ LOG = logging.getLogger(__name__) -class CaseInsensitiveDict(dict): - """ - Implement a simple case insensitive dictionary for storing headers. To preserve the original - case of the given Header (e.g. X-FooBar-Fizz) this only touches the get and contains magic - methods rather than implementing a __setitem__ where we normalize the case of the headers. - """ - - def __getitem__(self, key): - matches = [v for k, v in self.items() if k.lower() == key.lower()] - if not matches: - raise KeyError(key) - return matches[0] - - def __contains__(self, key): - return key.lower() in [k.lower() for k in self.keys()] - - class BaseLocalService(object): def __init__(self, is_debugging, port, host): @@ -86,7 +69,7 @@ def service_response(body, headers, status_code): Constructs a Flask Response from the body, headers, and status_code. :param str body: Response body as a string - :param dict headers: headers for the response + :param werkzeug.datastructures.Headers headers: headers for the response :param int status_code: status_code for response :return: Flask Response """ diff --git a/tests/integration/local/start_api/test_start_api.py b/tests/integration/local/start_api/test_start_api.py index dc8971cddb..168eadfd5c 100644 --- a/tests/integration/local/start_api/test_start_api.py +++ b/tests/integration/local/start_api/test_start_api.py @@ -298,6 +298,20 @@ class TestServiceResponses(StartApiIntegBaseClass): def setUp(self): self.url = "http://127.0.0.1:{}".format(self.port) + def test_multiple_headers_response(self): + response = requests.get(self.url + "/multipleheaders") + + self.assertEquals(response.status_code, 200) + self.assertEquals(response.headers.get("Content-Type"), "text/plain") + self.assertEquals(response.headers.get("MyCustomHeader"), 'Value1, Value2') + + def test_multiple_headers_overrides_headers_response(self): + response = requests.get(self.url + "/multipleheadersoverridesheaders") + + self.assertEquals(response.status_code, 200) + self.assertEquals(response.headers.get("Content-Type"), "text/plain") + self.assertEquals(response.headers.get("MyCustomHeader"), 'Value1, Value2, Custom') + def test_binary_response(self): """ Binary data is returned correctly diff --git a/tests/integration/testdata/start_api/main.py b/tests/integration/testdata/start_api/main.py index 127e758981..7e3875cdfc 100644 --- a/tests/integration/testdata/start_api/main.py +++ b/tests/integration/testdata/start_api/main.py @@ -89,3 +89,21 @@ def echo_base64_event_body(event, context): }, "isBase64Encoded": event["isBase64Encoded"] } + + +def multiple_headers(event, context): + return { + "statusCode": 200, + "body": "hello", + "headers": {"Content-Type": "text/plain"}, + "multiValueHeaders": {"MyCustomHeader": ['Value1', 'Value2']} + } + + +def multiple_headers_overrides_headers(event, context): + return { + "statusCode": 200, + "body": "hello", + "headers": {"Content-Type": "text/plain", "MyCustomHeader": 'Custom'}, + "multiValueHeaders": {"MyCustomHeader": ['Value1', 'Value2']} + } diff --git a/tests/integration/testdata/start_api/template.yaml b/tests/integration/testdata/start_api/template.yaml index 8820753c47..c18cee2079 100644 --- a/tests/integration/testdata/start_api/template.yaml +++ b/tests/integration/testdata/start_api/template.yaml @@ -229,3 +229,29 @@ Resources: Properties: Method: POST Path: /echobase64eventbody + + MultipleHeadersResponseFunction: + Type: AWS::Serverless::Function + Properties: + Handler: main.multiple_headers + Runtime: python3.6 + CodeUri: . + Events: + IdBasePath: + Type: Api + Properties: + Method: GET + Path: /multipleheaders + + MultipleHeadersOverridesHeadersResponseFunction: + Type: AWS::Serverless::Function + Properties: + Handler: main.multiple_headers_overrides_headers + Runtime: python3.6 + CodeUri: . + Events: + IdBasePath: + Type: Api + Properties: + Method: GET + Path: /multipleheadersoverridesheaders diff --git a/tests/unit/local/apigw/test_local_apigw_service.py b/tests/unit/local/apigw/test_local_apigw_service.py index 68ee809982..50702c8bf2 100644 --- a/tests/unit/local/apigw/test_local_apigw_service.py +++ b/tests/unit/local/apigw/test_local_apigw_service.py @@ -4,6 +4,7 @@ import base64 from parameterized import parameterized, param +from werkzeug.datastructures import Headers from samcli.local.apigw.local_apigw_service import LocalApigwService, Route from samcli.local.lambdafn.exceptions import FunctionNotFound @@ -260,6 +261,39 @@ def test_class_initialization(self): self.assertEquals(self.api_gateway.path, '/') +class TestLambdaHeaderDictionaryMerge(TestCase): + + def test_empty_dictionaries_produce_empty_result(self): + headers = {} + multi_value_headers = {} + + result = LocalApigwService._merge_response_headers(headers, multi_value_headers) + + self.assertEquals(result, Headers({})) + + def test_headers_are_merged(self): + headers = {"h1": "value1", "h2": "value2", "h3": "value3"} + multi_value_headers = {"h3": ["value4"]} + + result = LocalApigwService._merge_response_headers(headers, multi_value_headers) + + self.assertIn("h1", result) + self.assertIn("h2", result) + self.assertIn("h3", result) + self.assertEquals(result["h1"], "value1") + self.assertEquals(result["h2"], "value2") + self.assertEquals(result.get_all("h3"), ["value4", "value3"]) + + def test_merge_does_not_duplicate_values(self): + headers = {"h1": "ValueB"} + multi_value_headers = {"h1": ["ValueA", "ValueB", "ValueC"]} + + result = LocalApigwService._merge_response_headers(headers, multi_value_headers) + + self.assertIn("h1", result) + self.assertEquals(result.get_all("h1"), ["ValueA", "ValueB", "ValueC"]) + + class TestServiceParsingLambdaOutput(TestCase): def test_default_content_type_header_added_with_no_headers(self): @@ -289,6 +323,33 @@ def test_custom_content_type_header_is_not_modified(self): self.assertIn("Content-Type", headers) self.assertEquals(headers["Content-Type"], "text/xml") + def test_custom_content_type_multivalue_header_is_not_modified(self): + lambda_output = '{"statusCode": 200, "multiValueHeaders":{"Content-Type": ["text/xml"]}, "body": "{}", ' \ + '"isBase64Encoded": false}' + + (_, headers, _) = LocalApigwService._parse_lambda_output(lambda_output, binary_types=[], flask_request=Mock()) + + self.assertIn("Content-Type", headers) + self.assertEquals(headers["Content-Type"], "text/xml") + + def test_multivalue_headers(self): + lambda_output = '{"statusCode": 200, "multiValueHeaders":{"X-Foo": ["bar", "42"]}, ' \ + '"body": "{\\"message\\":\\"Hello from Lambda\\"}", "isBase64Encoded": false}' + + (_, headers, _) = LocalApigwService._parse_lambda_output(lambda_output, binary_types=[], flask_request=Mock()) + + self.assertEquals(headers, Headers({"Content-Type": "application/json", "X-Foo": ["bar", "42"]})) + + def test_single_and_multivalue_headers(self): + lambda_output = '{"statusCode": 200, "headers":{"X-Foo": "foo", "X-Bar": "bar"}, ' \ + '"multiValueHeaders":{"X-Foo": ["bar", "42"]}, ' \ + '"body": "{\\"message\\":\\"Hello from Lambda\\"}", "isBase64Encoded": false}' + + (_, headers, _) = LocalApigwService._parse_lambda_output(lambda_output, binary_types=[], flask_request=Mock()) + + self.assertEquals( + headers, Headers({"Content-Type": "application/json", "X-Bar": "bar", "X-Foo": ["bar", "42", "foo"]})) + def test_extra_values_raise(self): lambda_output = '{"statusCode": 200, "headers": {}, "body": "{\\"message\\":\\"Hello from Lambda\\"}", ' \ '"isBase64Encoded": false, "another_key": "some value"}' @@ -307,7 +368,7 @@ def test_parse_returns_correct_tuple(self): flask_request=Mock()) self.assertEquals(status_code, 200) - self.assertEquals(headers, {"Content-Type": "application/json"}) + self.assertEquals(headers, Headers({"Content-Type": "application/json"})) self.assertEquals(body, '{"message":"Hello from Lambda"}') @patch('samcli.local.apigw.local_apigw_service.LocalApigwService._should_base64_decode_body') @@ -326,7 +387,7 @@ def test_parse_returns_decodes_base64_to_binary(self, should_decode_body_patch): flask_request=Mock()) self.assertEquals(status_code, 200) - self.assertEquals(headers, {"Content-Type": "application/octet-stream"}) + self.assertEquals(headers, Headers({"Content-Type": "application/octet-stream"})) self.assertEquals(body, binary_body) def test_status_code_not_int(self): @@ -388,7 +449,7 @@ def test_properties_are_null(self): flask_request=Mock()) self.assertEquals(status_code, 200) - self.assertEquals(headers, {"Content-Type": "application/json"}) + self.assertEquals(headers, Headers({"Content-Type": "application/json"})) self.assertEquals(body, "no data") diff --git a/tests/unit/local/services/test_base_local_service.py b/tests/unit/local/services/test_base_local_service.py index 00e8e3efe3..a37c5e21fb 100644 --- a/tests/unit/local/services/test_base_local_service.py +++ b/tests/unit/local/services/test_base_local_service.py @@ -3,7 +3,7 @@ from parameterized import parameterized, param -from samcli.local.services.base_local_service import BaseLocalService, LambdaOutputParser, CaseInsensitiveDict +from samcli.local.services.base_local_service import BaseLocalService, LambdaOutputParser class TestLocalHostRunner(TestCase): @@ -128,36 +128,3 @@ def test_get_lambda_output_extracts_response(self, test_case_name, stdout_data, ]) def test_is_lambda_error_response(self, input, exected_result): self.assertEquals(LambdaOutputParser.is_lambda_error_response(input), exected_result) - - -class CaseInsensiveDict(TestCase): - - def setUp(self): - self.data = CaseInsensitiveDict({ - 'Content-Type': 'text/html', - 'Browser': 'APIGW', - }) - - def test_contains_lower(self): - self.assertTrue('content-type' in self.data) - - def test_contains_title(self): - self.assertTrue('Content-Type' in self.data) - - def test_contains_upper(self): - self.assertTrue('CONTENT-TYPE' in self.data) - - def test_contains_browser_key(self): - self.assertTrue('Browser' in self.data) - - def test_contains_not_in(self): - self.assertTrue('Dog-Food' not in self.data) - - def test_setitem_found(self): - self.data['Browser'] = 'APIGW' - - self.assertTrue(self.data['browser']) - - def test_keyerror(self): - with self.assertRaises(KeyError): - self.data['does-not-exist']