Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 160 additions & 72 deletions awslambda/serverless_wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,22 @@
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
"""
This module converts an AWS API Gateway proxied request to a WSGI request.

Inspired by: https://github.com/miserlou/zappa

Author: Logan Raarup <logan@logan.dk>
"""
import base64
import io
import json
import os
import sys
from io import BytesIO
from urllib.parse import urlencode, unquote, unquote_plus

from werkzeug.datastructures import Headers, MultiDict, iter_multi_items
from werkzeug.datastructures import Headers, iter_multi_items
from werkzeug.http import HTTP_STATUS_CODES
from werkzeug.urls import url_encode, url_unquote, url_unquote_plus
from werkzeug.wrappers import Response

# List of MIME types that should not be base64 encoded. MIME types within `text/*`
Expand Down Expand Up @@ -86,21 +94,32 @@ def is_alb_event(event):


def encode_query_string(event):
    """Build the URL-encoded query string for a proxied request event.

    The parameters are taken from the first non-empty of the event keys
    ``multiValueQueryStringParameters``, ``queryStringParameters`` and
    ``query``. When none of them holds a value the result is an empty string.

    Args:
        event: The event mapping handed to the Lambda handler.

    Returns:
        The query string produced by ``urlencode(..., doseq=True)``; with
        ``doseq`` a list value is emitted as one ``key=value`` pair per item.
    """
    params = (
        event.get("multiValueQueryStringParameters")
        or event.get("queryStringParameters")
        or event.get("query")
        or ""
    )
    if is_alb_event(event):
        # Keys and values of ALB events are unquoted before being encoded
        # again below, so they are not encoded twice.
        # NOTE(review): presumably ALB delivers them still URL-encoded - confirm.
        params = [
            (unquote_plus(key), unquote_plus(value))
            for key, value in iter_multi_items(params)
        ]
    return urlencode(params, doseq=True)


def get_script_name(headers, request_context):
    """Return the WSGI ``SCRIPT_NAME`` for a request.

    The result is ``"/<stage>"`` when the ``Host`` header contains
    ``amazonaws.com`` and the ``STRIP_STAGE_PATH`` environment variable is not
    set to one of ``yes``, ``y``, ``true``, ``t`` or ``1`` (compared
    case-insensitively, surrounding whitespace ignored). In every other case
    the result is an empty string.

    Args:
        headers: Mapping-like object of request headers (only ``get`` is used).
        request_context: Mapping from which the ``stage`` key is read.

    Returns:
        The script name as a string.
    """
    flag = os.environ.get("STRIP_STAGE_PATH", "").lower().strip()
    strip_stage_path = flag in ("yes", "y", "true", "t", "1")

    if "amazonaws.com" in headers.get("Host", "") and not strip_stage_path:
        return "/{}".format(request_context.get("stage", ""))
    return ""
Expand All @@ -110,7 +129,7 @@ def get_body_bytes(event, body):
if event.get("isBase64Encoded", False):
body = base64.b64decode(body)
if isinstance(body, str):
body = body.encode(encoding="utf-8", errors="strict")
body = body.encode("utf-8")
return body


Expand All @@ -127,96 +146,112 @@ def setup_environ_items(environ, headers):


def generate_response(response, event):
    """Convert a WSGI response object into the dict returned to the caller.

    Args:
        response: Response object exposing ``status_code``, ``headers``,
            ``data``, ``mimetype`` and ``get_data``.
        event: The event mapping the request was built from. It decides
            between the multi-value and single-value header format and
            whether a status description is added.

    Returns:
        A dict with ``statusCode``, either ``multiValueHeaders`` or
        ``headers``, ``statusDescription`` for ALB events, and, when the
        response carries data, ``body`` together with ``isBase64Encoded``.
    """
    status_code = response.status_code
    returndict = {"statusCode": status_code}

    # Answer in the multi-value format only when the request used it.
    if "multiValueHeaders" in event and event["multiValueHeaders"]:
        returndict["multiValueHeaders"] = group_headers(response.headers)
    else:
        returndict["headers"] = split_headers(response.headers)

    if is_alb_event(event):
        # If the request comes from ALB we need to add a status description.
        # Status codes missing from HTTP_STATUS_CODES fall back to "UNKNOWN"
        # rather than raising KeyError.
        returndict["statusDescription"] = "%d %s" % (
            status_code,
            HTTP_STATUS_CODES.get(status_code, "UNKNOWN"),
        )

    if response.data:
        mimetype = response.mimetype or "text/plain"
        is_text = mimetype.startswith("text/") or mimetype in TEXT_MIME_TYPES
        if is_text and not response.headers.get("Content-Encoding", ""):
            returndict["body"] = response.get_data(as_text=True)
            returndict["isBase64Encoded"] = False
        else:
            # Non-text payloads, and any payload that carries a
            # Content-Encoding header, are passed on base64 encoded.
            returndict["body"] = base64.b64encode(response.data).decode("utf-8")
            returndict["isBase64Encoded"] = True

    return returndict


def strip_express_gateway_query_params(path):
    """Return ``path`` without any query string appended to it.

    Express Gateway
    (https://github.com/ExpressGateway/express-gateway-plugin-lambda),
    unlike regular AWS Lambda HTTP events, leaves the query parameters on
    the path, so everything from the first ``?`` onwards is dropped here.
    """
    if "?" not in path:
        return path
    before_query, _, _ = path.partition("?")
    return before_query


def handle_request(app, event, context):
    """Dispatch an incoming event to the handler matching its format.

    Warming events (``source`` of ``aws.events`` or
    ``serverless-plugin-warmup``) are answered with an empty dict without
    calling the application. Otherwise the event goes to the lambda
    integration handler, the payload v2 handler (``version`` equal to
    ``"2.0"``) or, by default, the payload v1 handler.

    Args:
        app: The WSGI application to call.
        event: The event mapping received by the Lambda handler.
        context: The context object received by the Lambda handler.

    Returns:
        Whatever the selected handler returns, or ``{}`` for warming events.
    """
    warming_sources = ("aws.events", "serverless-plugin-warmup")
    if event.get("source") in warming_sources:
        # Lambda warming event received, skipping handler
        print("Lambda warming event received, skipping handler")
        return {}

    version = event.get("version")

    # Neither a version nor an isBase64Encoded flag, but a requestPath, and
    # not an ALB event: dispatched to the lambda integration handler.
    is_lambda_integration = (
        version is None
        and event.get("isBase64Encoded") is None
        and event.get("requestPath") is not None
        and not is_alb_event(event)
    )
    if is_lambda_integration:
        return handle_lambda_integration(app, event, context)

    if version == "2.0":
        return handle_payload_v2(app, event, context)

    return handle_payload_v1(app, event, context)


def handle_payload_v1(app, event, context):
if u"multiValueHeaders" in event:
headers = Headers(event[u"multiValueHeaders"])
if "multiValueHeaders" in event and event["multiValueHeaders"]:
headers = Headers(event["multiValueHeaders"])
else:
headers = Headers(event[u"headers"])
headers = Headers(event["headers"])

script_name = get_script_name(headers, event.get("requestContext", {}))

# If a user is using a custom domain on API Gateway, they may have a base
# path in their URL. This allows us to strip it out via an optional
# environment variable.
path_info = event[u"path"]
path_info = strip_express_gateway_query_params(event["path"])
base_path = os.environ.get("API_GATEWAY_BASE_PATH")
if base_path:
script_name = "/" + base_path

if path_info.startswith(script_name):
path_info = path_info[len(script_name) :] # noqa: E203
path_info = path_info[len(script_name):]

body = event[u"body"] or ""
body = event.get("body") or ""
body = get_body_bytes(event, body)

environ = {
"CONTENT_LENGTH": str(len(body)),
"CONTENT_TYPE": headers.get(u"Content-Type", ""),
"PATH_INFO": url_unquote(path_info),
"CONTENT_TYPE": headers.get("Content-Type", ""),
"PATH_INFO": unquote(path_info),
"QUERY_STRING": encode_query_string(event),
"REMOTE_ADDR": event.get(u"requestContext", {}).get(u"identity", {}).get(u"sourceIp", ""),
"REMOTE_USER": event.get(u"requestContext", {}).get(u"authorizer", {}).get(u"principalId", ""),
"REQUEST_METHOD": event.get(u"httpMethod", {}),
"REMOTE_ADDR": event.get("requestContext", {})
.get("identity", {})
.get("sourceIp", ""),
"REMOTE_USER": (event.get("requestContext", {})
.get("authorizer") or {})
.get("principalId", ""),
"REQUEST_METHOD": event.get("httpMethod", {}),
"SCRIPT_NAME": script_name,
"SERVER_NAME": headers.get(u"Host", "lambda"),
"SERVER_PORT": headers.get(u"X-Forwarded-Port", "80"),
"SERVER_NAME": headers.get("Host", "lambda"),
"SERVER_PORT": headers.get("X-Forwarded-Port", "443"),
"SERVER_PROTOCOL": "HTTP/1.1",
"wsgi.errors": sys.stderr,
"wsgi.input": BytesIO(body),
"wsgi.input": io.BytesIO(body),
"wsgi.multiprocess": False,
"wsgi.multithread": False,
"wsgi.run_once": False,
"wsgi.url_scheme": headers.get(u"X-Forwarded-Proto", "http"),
"wsgi.url_scheme": headers.get("X-Forwarded-Proto", "https"),
"wsgi.version": (1, 0),
"serverless.authorizer": event.get(u"requestContext", {}).get(u"authorizer"),
"serverless.authorizer": event.get("requestContext", {}).get("authorizer"),
"serverless.event": event,
"serverless.context": context,
# TODO: Deprecate the following entries, as they do not comply with the WSGI
# spec. For custom variables, the spec says:
#
# Finally, the environ dictionary may also contain server-defined variables.
# These variables should be named using only lower-case letters, numbers, dots,
# and underscores, and should be prefixed with a name that is unique to the
# defining server or gateway.
"API_GATEWAY_AUTHORIZER": event.get(u"requestContext", {}).get(u"authorizer"),
"event": event,
"context": context,
}

environ = setup_environ_items(environ, headers)
Expand All @@ -228,49 +263,51 @@ def handle_payload_v1(app, event, context):


def handle_payload_v2(app, event, context):
headers = Headers(event[u"headers"])
headers = Headers(event["headers"])

script_name = get_script_name(headers, event.get("requestContext", {}))

path_info = event[u"rawPath"]
path_info = strip_express_gateway_query_params(event["rawPath"])
base_path = os.environ.get("API_GATEWAY_BASE_PATH")
if base_path:
script_name = "/" + base_path

if path_info.startswith(script_name):
path_info = path_info[len(script_name):]

body = event.get("body", "")
body = get_body_bytes(event, body)

headers["Cookie"] = "; ".join(event.get("cookies", []))

environ = {
"CONTENT_LENGTH": str(len(body)),
"CONTENT_TYPE": headers.get(u"Content-Type", ""),
"PATH_INFO": url_unquote(path_info),
"QUERY_STRING": url_encode(event.get(u"queryStringParameters", {})),
"REMOTE_ADDR": event.get("requestContext", {}).get(u"http", {}).get(u"sourceIp", ""),
"REMOTE_USER": event.get("requestContext", {}).get(u"authorizer", {}).get(u"principalId", ""),
"REQUEST_METHOD": event.get("requestContext", {}).get("http", {}).get("method", ""),
"CONTENT_LENGTH": str(len(body or "")),
"CONTENT_TYPE": headers.get("Content-Type", ""),
"PATH_INFO": unquote(path_info),
"QUERY_STRING": event.get("rawQueryString", ""),
"REMOTE_ADDR": event.get("requestContext", {})
.get("http", {})
.get("sourceIp", ""),
"REMOTE_USER": event.get("requestContext", {})
.get("authorizer", {})
.get("principalId", ""),
"REQUEST_METHOD": event.get("requestContext", {})
.get("http", {})
.get("method", ""),
"SCRIPT_NAME": script_name,
"SERVER_NAME": headers.get(u"Host", "lambda"),
"SERVER_PORT": headers.get(u"X-Forwarded-Port", "80"),
"SERVER_NAME": headers.get("Host", "lambda"),
"SERVER_PORT": headers.get("X-Forwarded-Port", "443"),
"SERVER_PROTOCOL": "HTTP/1.1",
"wsgi.errors": sys.stderr,
"wsgi.input": BytesIO(body),
"wsgi.input": io.BytesIO(body),
"wsgi.multiprocess": False,
"wsgi.multithread": False,
"wsgi.run_once": False,
"wsgi.url_scheme": headers.get(u"X-Forwarded-Proto", "http"),
"wsgi.url_scheme": headers.get("X-Forwarded-Proto", "https"),
"wsgi.version": (1, 0),
"serverless.authorizer": event.get("requestContext", {}).get(u"authorizer"),
"serverless.authorizer": event.get("requestContext", {}).get("authorizer"),
"serverless.event": event,
"serverless.context": context,
# TODO: Deprecate the following entries, as they do not comply with the WSGI
# spec. For custom variables, the spec says:
#
# Finally, the environ dictionary may also contain server-defined variables.
# These variables should be named using only lower-case letters, numbers, dots,
# and underscores, and should be prefixed with a name that is unique to the
# defining server or gateway.
"API_GATEWAY_AUTHORIZER": event.get("requestContext", {}).get(u"authorizer"),
"event": event,
"context": context,
}

environ = setup_environ_items(environ, headers)
Expand All @@ -280,3 +317,54 @@ def handle_payload_v2(app, event, context):
returndict = generate_response(response, event)

return returndict


def handle_lambda_integration(app, event, context):
    """Serve an event that uses the lambda-integration event layout.

    This layout keeps the request data at the top level of the event
    (``requestPath``, ``path``, ``query``, ``method``, ``identity``,
    ``principalId``, ``enhancedAuthContext``) instead of under
    ``requestContext``.

    Args:
        app: The WSGI application to call.
        event: The event mapping; ``headers`` and ``requestPath`` are required.
        context: The Lambda context, exposed as ``serverless.context``.

    Returns:
        The response dict built by ``generate_response``.

    Raises:
        RuntimeError: When the response status code is 300 or higher. The
            message is the JSON-serialised response dict.
    """
    headers = Headers(event["headers"])

    # The event itself is handed over as the request context, so "stage" is
    # looked up directly on the event.
    script_name = get_script_name(headers, event)

    path_info = strip_express_gateway_query_params(event["requestPath"])

    # Fill the "{name}" and "{name+}" placeholders of the request path with
    # the values from event["path"]. `or {}` tolerates a null entry.
    for key, value in (event.get("path") or {}).items():
        path_info = path_info.replace("{%s}" % key, value)
        path_info = path_info.replace("{%s+}" % key, value)

    # A non-empty body is serialised to JSON text before it is turned
    # into bytes.
    body = event.get("body", {})
    body = json.dumps(body) if body else ""
    body = get_body_bytes(event, body)

    environ = {
        "CONTENT_LENGTH": str(len(body or "")),
        "CONTENT_TYPE": headers.get("Content-Type", ""),
        "PATH_INFO": unquote(path_info),
        "QUERY_STRING": urlencode(event.get("query") or {}, doseq=True),
        # `or {}` keeps a null "identity" from raising AttributeError.
        "REMOTE_ADDR": (event.get("identity") or {}).get("sourceIp", ""),
        "REMOTE_USER": event.get("principalId", ""),
        "REQUEST_METHOD": event.get("method", ""),
        "SCRIPT_NAME": script_name,
        "SERVER_NAME": headers.get("Host", "lambda"),
        "SERVER_PORT": headers.get("X-Forwarded-Port", "443"),
        "SERVER_PROTOCOL": "HTTP/1.1",
        "wsgi.errors": sys.stderr,
        "wsgi.input": io.BytesIO(body),
        "wsgi.multiprocess": False,
        "wsgi.multithread": False,
        "wsgi.run_once": False,
        "wsgi.url_scheme": headers.get("X-Forwarded-Proto", "https"),
        "wsgi.version": (1, 0),
        "serverless.authorizer": event.get("enhancedAuthContext"),
        "serverless.event": event,
        "serverless.context": context,
    }

    environ = setup_environ_items(environ, headers)

    response = Response.from_app(app, environ)

    returndict = generate_response(response, event)

    # Responses with a status of 300 or above are reported by raising,
    # with the serialised response as the error message.
    if response.status_code >= 300:
        raise RuntimeError(json.dumps(returndict))

    return returndict