diff --git a/pylambdarest/applications.py b/pylambdarest/applications.py index 8af9146..c39403c 100644 --- a/pylambdarest/applications.py +++ b/pylambdarest/applications.py @@ -55,13 +55,16 @@ def _validate_request( return None def _format_response( - self, code: int, body: Any = None, headers: Optional[dict] = None + self, + request: Request, + code: int, + body: Any = None, + headers: Optional[dict] = None, ) -> Dict[str, Any]: """ Format the handler's response to the expected Lambda response format. """ if not isinstance(code, int): - print(code, body, headers) raise TypeError(f"Invalid status code. {type(code)} is not int.") if type(headers) not in [type(None), dict]: raise TypeError( @@ -73,8 +76,16 @@ def _format_response( response["body"] = json.dumps(body) if self.config.ALLOW_CORS: + if isinstance(self.config.CORS_ORIGIN, list): + origin = request.headers.get("Origin") or request.headers.get( + "origin" + ) + if origin not in self.config.CORS_ORIGIN: + origin = self.config.CORS_ORIGIN[0] + else: + origin = self.config.CORS_ORIGIN response["headers"] = { - "Access-Control-Allow-Origin": self.config.CORS_ORIGIN, + "Access-Control-Allow-Origin": origin, "Access-Control-Allow-Credentials": self.config.CORS_ALLOW_CREDENTIALS, } @@ -210,7 +221,9 @@ def wrapper(event, context): try: jwt_payload = self._check_jwt_bearer(request) except AuthError: - return self._format_response(401, "Unauthorized") + return self._format_response( + request, 401, "Unauthorized" + ) validation_error = App._validate_request( request, @@ -219,7 +232,9 @@ def wrapper(event, context): ) if validation_error is not None: return self._format_response( - 400, str(validation_error).split("\n", maxsplit=1)[0] + request, + 400, + str(validation_error).split("\n", maxsplit=1)[0], ) for arg in handler_args: @@ -250,7 +265,7 @@ def wrapper(event, context): if not isinstance(res, tuple): res = (res,) - return self._format_response(*res) + return self._format_response(request, *res) return wrapper diff --git a/pylambdarest/config.py b/pylambdarest/config.py index fb744ee..fe26c4f 100644 --- a/pylambdarest/config.py +++ b/pylambdarest/config.py @@ -8,7 +8,7 @@ from enum import Enum -from typing import Optional +from typing import List, Optional, Union from pylambdarest.exceptions import ConfigError @@ -33,7 +33,7 @@ def __init__( # pylint: disable=R0913 JWT_SECRET: Optional[str] = None, JWT_ALGORITHM: Optional[str] = None, ALLOW_CORS: bool = False, - CORS_ORIGIN: Optional[str] = None, + CORS_ORIGIN: Optional[Union[str, List[str]]] = None, CORS_ALLOW_CREDENTIALS: Optional[bool] = None, has_jwt: Optional[bool] = None, ) -> None: @@ -41,7 +41,7 @@ def __init__( # pylint: disable=R0913 self.JWT_SECRET: Optional[str] = JWT_SECRET self.JWT_ALGORITHM: Optional[str] = JWT_ALGORITHM self.ALLOW_CORS: bool = ALLOW_CORS - self.CORS_ORIGIN: Optional[str] = CORS_ORIGIN + self.CORS_ORIGIN: Optional[Union[str, List[str]]] = CORS_ORIGIN self.CORS_ALLOW_CREDENTIALS: Optional[bool] = CORS_ALLOW_CREDENTIALS if (self.AUTH_SCHEME == "JWT_BEARER") and (self.JWT_SECRET is None): diff --git a/pyproject.toml b/pyproject.toml index 01b93f2..2d39690 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pylambdarest" -version = "0.3.0rc3" +version = "0.3.0rc4" license = "MIT" description = "Lightweight framework for building REST API using AWS Lambda + API Gateway" authors = ["Marwan Debbiche (Macbook Pro) "]