diff --git a/samcli/commands/local/lib/provider.py b/samcli/commands/local/lib/provider.py index b2b2eab1b3..685485263b 100644 --- a/samcli/commands/local/lib/provider.py +++ b/samcli/commands/local/lib/provider.py @@ -220,13 +220,14 @@ def binary_media_types(self): return list(self.binary_media_types_set) -_CorsTuple = namedtuple("Cors", ["allow_origin", "allow_methods", "allow_headers", "max_age"]) +_CorsTuple = namedtuple("Cors", ["allow_origin", "allow_methods", "allow_headers", "allow_credentials", "max_age"]) _CorsTuple.__new__.__defaults__ = ( None, # Allow Origin defaults to None None, # Allow Methods is optional and defaults to empty None, # Allow Headers is optional and defaults to empty + None, # Allow Credentials is optional and defaults to empty None, # MaxAge is optional and defaults to empty ) @@ -250,6 +251,7 @@ def cors_to_headers(cors): "Access-Control-Allow-Origin": cors.allow_origin, "Access-Control-Allow-Methods": cors.allow_methods, "Access-Control-Allow-Headers": cors.allow_headers, + "Access-Control-Allow-Credentials": cors.allow_credentials, "Access-Control-Max-Age": cors.max_age, } # Filters out items in the headers dictionary that isn't empty. diff --git a/samcli/commands/local/lib/sam_api_provider.py b/samcli/commands/local/lib/sam_api_provider.py index 01554e8828..c57ce1c6df 100644 --- a/samcli/commands/local/lib/sam_api_provider.py +++ b/samcli/commands/local/lib/sam_api_provider.py @@ -111,10 +111,15 @@ def extract_cors(self, cors_prop): allow_origin = self._get_cors_prop(cors_prop, "AllowOrigin") allow_headers = self._get_cors_prop(cors_prop, "AllowHeaders") + allow_credentials = self._get_cors_prop(cors_prop, "AllowCredentials", is_string=False) max_age = self._get_cors_prop(cors_prop, "MaxAge") cors = Cors( - allow_origin=allow_origin, allow_methods=allow_methods, allow_headers=allow_headers, max_age=max_age + allow_origin=allow_origin, + allow_methods=allow_methods, + allow_headers=allow_headers, + allow_credentials=allow_credentials, + max_age=max_age, ) elif cors_prop and isinstance(cors_prop, string_types): allow_origin = cors_prop @@ -128,12 +133,13 @@ def extract_cors(self, cors_prop): allow_origin=allow_origin, allow_methods=",".join(sorted(Route.ANY_HTTP_METHODS)), allow_headers=None, + allow_credentials=None, max_age=None, ) return cors @staticmethod - def _get_cors_prop(cors_dict, prop_name): + def _get_cors_prop(cors_dict, prop_name, is_string=True): """ Extract cors properties from dictionary and remove extra quotes. @@ -147,7 +153,7 @@ def _get_cors_prop(cors_dict, prop_name): A string with the extra quotes removed """ prop = cors_dict.get(prop_name) - if prop: + if prop and is_string: if (not isinstance(prop, string_types)) or (not (prop.startswith("'") and prop.endswith("'"))): raise InvalidSamDocumentException( "{} must be a quoted string " '(i.e. "\'value\'" is correct, but "value" is not).'.format(prop_name) diff --git a/tests/unit/commands/local/lib/test_sam_api_provider.py b/tests/unit/commands/local/lib/test_sam_api_provider.py index 612724d637..30efb40b28 100644 --- a/tests/unit/commands/local/lib/test_sam_api_provider.py +++ b/tests/unit/commands/local/lib/test_sam_api_provider.py @@ -875,6 +875,7 @@ def test_provider_parse_cors_dict(self): "AllowMethods": "'POST, GET'", "AllowOrigin": "'*'", "AllowHeaders": "'Upgrade-Insecure-Requests'", + "AllowCredentials": True, "MaxAge": "'600'", }, "DefinitionBody": { @@ -917,6 +918,7 @@ def test_provider_parse_cors_dict(self): allow_origin="*", allow_methods=",".join(sorted(["POST", "GET", "OPTIONS"])), allow_headers="Upgrade-Insecure-Requests", + allow_credentials=True, max_age="600", ) route1 = Route(path="/path2", methods=["POST", "OPTIONS"], function_name="NoApiEventFunction")