diff --git a/samcli/local/apigw/service.py b/samcli/local/apigw/service.py index e21312187c..a1ec1b67d7 100644 --- a/samcli/local/apigw/service.py +++ b/samcli/local/apigw/service.py @@ -14,6 +14,23 @@ LOG = logging.getLogger(__name__) +class CaseInsensitiveDict(dict): + """ + Implement a simple case insensitive dictionary for storing headers. To preserve the original + case of the given Header (e.g. X-FooBar-Fizz) this only touches the get and contains magic + methods rather than implementing a __setitem__ where we normalize the case of the headers. + """ + + def __getitem__(self, key): + matches = [v for k, v in self.items() if k.lower() == key.lower()] + if not matches: + raise KeyError(key) + return matches[0] + + def __contains__(self, key): + return key.lower() in [k.lower() for k in self.keys()] + + class Route(object): def __init__(self, methods, function_name, path, binary_types=None): @@ -270,7 +287,7 @@ def _parse_lambda_output(lambda_output, binary_types, flask_request): raise TypeError("Lambda returned %{s} instead of dict", type(json_output)) status_code = json_output.get("statusCode") or 200 - headers = json_output.get("headers") or {} + headers = CaseInsensitiveDict(json_output.get("headers") or {}) body = json_output.get("body") or "no data" is_base_64_encoded = json_output.get("isBase64Encoded") or False diff --git a/tests/functional/function_code.py b/tests/functional/function_code.py index 1a88c6b269..80d8873826 100644 --- a/tests/functional/function_code.py +++ b/tests/functional/function_code.py @@ -63,6 +63,15 @@ } """ +API_GATEWAY_CONTENT_TYPE_LOWER = """ +exports.handler = function(event, context, callback){ + body = JSON.stringify("hello") + + response = {"statusCode":200,"headers":{"content-type":"text/plain"},"body":body,"isBase64Encoded":false} + context.done(null, response); +} +""" + API_GATEWAY_ECHO_BASE64_EVENT = """ exports.base54request = function(event, context, callback){ diff --git a/tests/functional/local/apigw/test_service.py b/tests/functional/local/apigw/test_service.py index ae6348b1f9..c5962eb859 100644 --- a/tests/functional/local/apigw/test_service.py +++ b/tests/functional/local/apigw/test_service.py @@ -10,7 +10,7 @@ from mock import Mock from samcli.local.apigw.service import Route, Service -from tests.functional.function_code import nodejs_lambda, API_GATEWAY_ECHO_EVENT, API_GATEWAY_BAD_PROXY_RESPONSE, API_GATEWAY_ECHO_BASE64_EVENT +from tests.functional.function_code import nodejs_lambda, API_GATEWAY_ECHO_EVENT, API_GATEWAY_BAD_PROXY_RESPONSE, API_GATEWAY_ECHO_BASE64_EVENT, API_GATEWAY_CONTENT_TYPE_LOWER from samcli.commands.local.lib import provider from samcli.local.lambdafn.runtime import LambdaRuntime from samcli.commands.local.lib.local_lambda import LocalLambdaRunner @@ -63,6 +63,58 @@ def test_non_proxy_response(self): self.assertEquals(response.headers.get('Content-Type'), "application/json") +class TestService_ContentType(TestCase): + @classmethod + def setUpClass(cls): + cls.code_abs_path = nodejs_lambda(API_GATEWAY_CONTENT_TYPE_LOWER) + + # Let's convert this absolute path to relative path. Let the parent be the CWD, and codeuri be the folder + cls.cwd = os.path.dirname(cls.code_abs_path) + cls.code_uri = os.path.relpath(cls.code_abs_path, cls.cwd) # Get relative path with respect to CWD + + cls.function_name = "name" + + cls.function = provider.Function(name=cls.function_name, runtime="nodejs4.3", memory=256, timeout=5, + handler="index.handler", codeuri=cls.code_uri, environment=None, + rolearn=None) + + cls.base64_response_function = provider.Function(name=cls.function_name, runtime="nodejs4.3", memory=256, timeout=5, + handler="index.handler", codeuri=cls.code_uri, environment=None, + rolearn=None) + + cls.mock_function_provider = Mock() + cls.mock_function_provider.get.return_value = cls.function + + list_of_routes = [ + Route(['GET'], cls.function_name, '/'), + ] + + cls.service, cls.port, cls.url, cls.scheme = make_service(list_of_routes, cls.mock_function_provider, cls.cwd) + cls.service.create() + t = threading.Thread(name='thread', target=cls.service.run, args=()) + t.setDaemon(True) + t.start() + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.code_abs_path) + + def setUp(self): + # Print full diff when comparing large dictionaries + self.maxDiff = None + + def test_calling_service_root(self): + expected = "hello" + + response = requests.get(self.url) + + actual = response.json() + + self.assertEquals(actual, expected) + self.assertEquals(response.status_code, 200) + self.assertEquals(response.headers.get('content-type'), "text/plain") + + class TestService_EventSerialization(TestCase): @classmethod def setUpClass(cls): diff --git a/tests/unit/local/apigw/test_service.py b/tests/unit/local/apigw/test_service.py index b898cec383..7a7328b983 100644 --- a/tests/unit/local/apigw/test_service.py +++ b/tests/unit/local/apigw/test_service.py @@ -5,7 +5,7 @@ from parameterized import parameterized, param -from samcli.local.apigw.service import Service, Route +from samcli.local.apigw.service import Service, Route, CaseInsensitiveDict from samcli.local.lambdafn.exceptions import FunctionNotFound @@ -525,3 +525,36 @@ def test_should_base64_encode_returns_true(self, test_case_name, binary_types, m ]) def test_should_base64_encode_returns_false(self, test_case_name, binary_types, mimetype): self.assertFalse(Service._should_base64_encode(binary_types, mimetype)) + + +class TestService_CaseInsensiveDict(TestCase): + + def setUp(self): + self.data = CaseInsensitiveDict({ + 'Content-Type': 'text/html', + 'Browser': 'APIGW', + }) + + def test_contains_lower(self): + self.assertTrue('content-type' in self.data) + + def test_contains_title(self): + self.assertTrue('Content-Type' in self.data) + + def test_contains_upper(self): + self.assertTrue('CONTENT-TYPE' in self.data) + + def test_contains_browser_key(self): + self.assertTrue('Browser' in self.data) + + def test_contains_not_in(self): + self.assertTrue('Dog-Food' not in self.data) + + def test_setitem_found(self): + self.data['Browser'] = 'APIGW' + + self.assertTrue(self.data['browser']) + + def test_keyerror(self): + with self.assertRaises(KeyError): + self.data['does-not-exist']