Skip to content

Commit

Permalink
allow multiple cors origins
Browse files Browse the repository at this point in the history
  • Loading branch information
MarwanDebbiche committed Jun 13, 2023
1 parent 6db1084 commit 0ca4561
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 10 deletions.
27 changes: 21 additions & 6 deletions pylambdarest/applications.py
Expand Up @@ -55,13 +55,16 @@ def _validate_request(
return None

def _format_response(
self, code: int, body: Any = None, headers: Optional[dict] = None
self,
request: Request,
code: int,
body: Any = None,
headers: Optional[dict] = None,
) -> Dict[str, Any]:
"""
Format the handler's response to the expected Lambda response format.
"""
if not isinstance(code, int):
print(code, body, headers)
raise TypeError(f"Invalid status code. {type(code)} is not int.")
if type(headers) not in [type(None), dict]:
raise TypeError(
Expand All @@ -73,8 +76,16 @@ def _format_response(
response["body"] = json.dumps(body)

if self.config.ALLOW_CORS:
if isinstance(self.config.CORS_ORIGIN, list):
origin = request.headers.get("Origin") or request.headers.get(
"origin"
)
if origin not in self.config.CORS_ORIGIN:
origin = self.config.CORS_ORIGIN[0]
else:
origin = self.config.CORS_ORIGIN
response["headers"] = {
"Access-Control-Allow-Origin": self.config.CORS_ORIGIN,
"Access-Control-Allow-Origin": origin,
"Access-Control-Allow-Credentials": self.config.CORS_ALLOW_CREDENTIALS,
}

Expand Down Expand Up @@ -210,7 +221,9 @@ def wrapper(event, context):
try:
jwt_payload = self._check_jwt_bearer(request)
except AuthError:
return self._format_response(401, "Unauthorized")
return self._format_response(
request, 401, "Unauthorized"
)

validation_error = App._validate_request(
request,
Expand All @@ -219,7 +232,9 @@ def wrapper(event, context):
)
if validation_error is not None:
return self._format_response(
400, str(validation_error).split("\n", maxsplit=1)[0]
request,
400,
str(validation_error).split("\n", maxsplit=1)[0],
)

for arg in handler_args:
Expand Down Expand Up @@ -250,7 +265,7 @@ def wrapper(event, context):
if not isinstance(res, tuple):
res = (res,)

return self._format_response(*res)
return self._format_response(request, *res)

return wrapper

Expand Down
6 changes: 3 additions & 3 deletions pylambdarest/config.py
Expand Up @@ -8,7 +8,7 @@


from enum import Enum
from typing import Optional
from typing import List, Optional, Union

from pylambdarest.exceptions import ConfigError

Expand All @@ -33,15 +33,15 @@ def __init__( # pylint: disable=R0913
JWT_SECRET: Optional[str] = None,
JWT_ALGORITHM: Optional[str] = None,
ALLOW_CORS: bool = False,
CORS_ORIGIN: Optional[str] = None,
CORS_ORIGIN: Optional[Union[str, List[str]]] = None,
CORS_ALLOW_CREDENTIALS: Optional[bool] = None,
has_jwt: Optional[bool] = None,
) -> None:
self.AUTH_SCHEME: Optional[AuthSchemeEnum] = AUTH_SCHEME
self.JWT_SECRET: Optional[str] = JWT_SECRET
self.JWT_ALGORITHM: Optional[str] = JWT_ALGORITHM
self.ALLOW_CORS: bool = ALLOW_CORS
self.CORS_ORIGIN: Optional[str] = CORS_ORIGIN
self.CORS_ORIGIN: Optional[Union[str, List[str]]] = CORS_ORIGIN
self.CORS_ALLOW_CREDENTIALS: Optional[bool] = CORS_ALLOW_CREDENTIALS

if (self.AUTH_SCHEME == "JWT_BEARER") and (self.JWT_SECRET is None):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pylambdarest"
version = "0.3.0rc3"
version = "0.3.0rc4"
license = "MIT"
description = "Lightweight framework for building REST API using AWS Lambda + API Gateway"
authors = ["Marwan Debbiche (Macbook Pro) <marwan.debbiche@gmail.com>"]
Expand Down

0 comments on commit 0ca4561

Please sign in to comment.