Skip to content

Commit

Permalink
fix: Allow for case insensitive header names
Browse files Browse the repository at this point in the history
  • Loading branch information
jfuss committed May 21, 2018
2 parents c54ed8c + ea3716e commit 0d39e43
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 3 deletions.
19 changes: 18 additions & 1 deletion samcli/local/apigw/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,23 @@
LOG = logging.getLogger(__name__)


class CaseInsensitiveDict(dict):
"""
Implement a simple case insensitive dictionary for storing headers. To preserve the original
case of the given Header (e.g. X-FooBar-Fizz) this only touches the get and contains magic
methods rather than implementing a __setitem__ where we normalize the case of the headers.
"""

def __getitem__(self, key):
matches = [v for k, v in self.items() if k.lower() == key.lower()]
if not matches:
raise KeyError(key)
return matches[0]

def __contains__(self, key):
return key.lower() in [k.lower() for k in self.keys()]


class Route(object):

def __init__(self, methods, function_name, path, binary_types=None):
Expand Down Expand Up @@ -270,7 +287,7 @@ def _parse_lambda_output(lambda_output, binary_types, flask_request):
raise TypeError("Lambda returned %{s} instead of dict", type(json_output))

status_code = json_output.get("statusCode") or 200
headers = json_output.get("headers") or {}
headers = CaseInsensitiveDict(json_output.get("headers") or {})
body = json_output.get("body") or "no data"
is_base_64_encoded = json_output.get("isBase64Encoded") or False

Expand Down
9 changes: 9 additions & 0 deletions tests/functional/function_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,15 @@
}
"""

API_GATEWAY_CONTENT_TYPE_LOWER = """
exports.handler = function(event, context, callback){
body = JSON.stringify("hello")
response = {"statusCode":200,"headers":{"content-type":"text/plain"},"body":body,"isBase64Encoded":false}
context.done(null, response);
}
"""

API_GATEWAY_ECHO_BASE64_EVENT = """
exports.base54request = function(event, context, callback){
Expand Down
54 changes: 53 additions & 1 deletion tests/functional/local/apigw/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from mock import Mock

from samcli.local.apigw.service import Route, Service
from tests.functional.function_code import nodejs_lambda, API_GATEWAY_ECHO_EVENT, API_GATEWAY_BAD_PROXY_RESPONSE, API_GATEWAY_ECHO_BASE64_EVENT
from tests.functional.function_code import nodejs_lambda, API_GATEWAY_ECHO_EVENT, API_GATEWAY_BAD_PROXY_RESPONSE, API_GATEWAY_ECHO_BASE64_EVENT, API_GATEWAY_CONTENT_TYPE_LOWER
from samcli.commands.local.lib import provider
from samcli.local.lambdafn.runtime import LambdaRuntime
from samcli.commands.local.lib.local_lambda import LocalLambdaRunner
Expand Down Expand Up @@ -63,6 +63,58 @@ def test_non_proxy_response(self):
self.assertEquals(response.headers.get('Content-Type'), "application/json")


class TestService_ContentType(TestCase):
@classmethod
def setUpClass(cls):
cls.code_abs_path = nodejs_lambda(API_GATEWAY_CONTENT_TYPE_LOWER)

# Let's convert this absolute path to relative path. Let the parent be the CWD, and codeuri be the folder
cls.cwd = os.path.dirname(cls.code_abs_path)
cls.code_uri = os.path.relpath(cls.code_abs_path, cls.cwd) # Get relative path with respect to CWD

cls.function_name = "name"

cls.function = provider.Function(name=cls.function_name, runtime="nodejs4.3", memory=256, timeout=5,
handler="index.handler", codeuri=cls.code_uri, environment=None,
rolearn=None)

cls.base64_response_function = provider.Function(name=cls.function_name, runtime="nodejs4.3", memory=256, timeout=5,
handler="index.handler", codeuri=cls.code_uri, environment=None,
rolearn=None)

cls.mock_function_provider = Mock()
cls.mock_function_provider.get.return_value = cls.function

list_of_routes = [
Route(['GET'], cls.function_name, '/'),
]

cls.service, cls.port, cls.url, cls.scheme = make_service(list_of_routes, cls.mock_function_provider, cls.cwd)
cls.service.create()
t = threading.Thread(name='thread', target=cls.service.run, args=())
t.setDaemon(True)
t.start()

@classmethod
def tearDownClass(cls):
shutil.rmtree(cls.code_abs_path)

def setUp(self):
# Print full diff when comparing large dictionaries
self.maxDiff = None

def test_calling_service_root(self):
expected = "hello"

response = requests.get(self.url)

actual = response.json()

self.assertEquals(actual, expected)
self.assertEquals(response.status_code, 200)
self.assertEquals(response.headers.get('content-type'), "text/plain")


class TestService_EventSerialization(TestCase):
@classmethod
def setUpClass(cls):
Expand Down
35 changes: 34 additions & 1 deletion tests/unit/local/apigw/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from parameterized import parameterized, param

from samcli.local.apigw.service import Service, Route
from samcli.local.apigw.service import Service, Route, CaseInsensitiveDict
from samcli.local.lambdafn.exceptions import FunctionNotFound


Expand Down Expand Up @@ -525,3 +525,36 @@ def test_should_base64_encode_returns_true(self, test_case_name, binary_types, m
])
def test_should_base64_encode_returns_false(self, test_case_name, binary_types, mimetype):
self.assertFalse(Service._should_base64_encode(binary_types, mimetype))


class TestService_CaseInsensiveDict(TestCase):

def setUp(self):
self.data = CaseInsensitiveDict({
'Content-Type': 'text/html',
'Browser': 'APIGW',
})

def test_contains_lower(self):
self.assertTrue('content-type' in self.data)

def test_contains_title(self):
self.assertTrue('Content-Type' in self.data)

def test_contains_upper(self):
self.assertTrue('CONTENT-TYPE' in self.data)

def test_contains_browser_key(self):
self.assertTrue('Browser' in self.data)

def test_contains_not_in(self):
self.assertTrue('Dog-Food' not in self.data)

def test_setitem_found(self):
self.data['Browser'] = 'APIGW'

self.assertTrue(self.data['browser'])

def test_keyerror(self):
with self.assertRaises(KeyError):
self.data['does-not-exist']

0 comments on commit 0d39e43

Please sign in to comment.