Skip to content

Commit

Permalink
Add API GW Proxy context and Lambda Context to http.Request context (#33
Browse files Browse the repository at this point in the history
)

This looks good. Thanks for the contribution. I'll update the readme to document the new `ProxyWithContext` method.

* Pass context.Context from Lambda runtime to http.Request.

Fixes #27

* Move API GW context to context.Context

* Remove api GW context header

* Pass APIGatewayContext in Request.Context

* Refactor to way better

* Cleanup

* Add comments

* Add tests

* PR comment fixes

* Inbtroduce two paths

* Refactor adapters

* Rename for consistency
  • Loading branch information
nsarychev authored and sapessi committed May 6, 2019
1 parent eb5a49d commit 4e8766f
Show file tree
Hide file tree
Showing 15 changed files with 272 additions and 42 deletions.
15 changes: 14 additions & 1 deletion chi/adapter.go
Expand Up @@ -4,6 +4,7 @@
package chiadapter

import (
"context"
"net/http"

"github.com/aws/aws-lambda-go/events"
Expand All @@ -29,9 +30,21 @@ func New(chi *chi.Mux) *ChiLambda {

// Proxy receives an API Gateway proxy event, transforms it into an http.Request
// object, and sends it to the chi.Mux for routing.
// It returns a proxy response object gneerated from the http.ResponseWriter.
// It returns a proxy response object generated from the http.ResponseWriter.
func (g *ChiLambda) Proxy(req events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) {
chiRequest, err := g.ProxyEventToHTTPRequest(req)
return g.proxyInternal(chiRequest, err)
}

// ProxyWithContext receives context and an API Gateway proxy event,
// transforms them into an http.Request object, and sends it to the chi.Mux for routing.
// It returns a proxy response object generated from the http.ResponseWriter.
func (g *ChiLambda) ProxyWithContext(ctx context.Context, req events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) {
chiRequest, err := g.EventToRequestWithContext(ctx, req)
return g.proxyInternal(chiRequest, err)
}

func (g *ChiLambda) proxyInternal(chiRequest *http.Request, err error) (events.APIGatewayProxyResponse, error) {

if err != nil {
return core.GatewayTimeout(), core.NewLoggedError("Could not convert proxy event to request: %v", err)
Expand Down
9 changes: 7 additions & 2 deletions chi/chilambda_test.go
@@ -1,11 +1,12 @@
package chiadapter_test

import (
"context"
"log"
"net/http"

"github.com/aws/aws-lambda-go/events"
"github.com/awslabs/aws-lambda-go-api-proxy/chi"
chiadapter "github.com/awslabs/aws-lambda-go-api-proxy/chi"
"github.com/go-chi/chi"

. "github.com/onsi/ginkgo"
Expand All @@ -29,10 +30,14 @@ var _ = Describe("ChiLambda tests", func() {
HTTPMethod: "GET",
}

resp, err := adapter.Proxy(req)
resp, err := adapter.ProxyWithContext(context.Background(), req)

Expect(err).To(BeNil())
Expect(resp.StatusCode).To(Equal(200))

resp, err = adapter.Proxy(req)
Expect(err).To(BeNil())
Expect(resp.StatusCode).To(Equal(200))
})
})
})
86 changes: 71 additions & 15 deletions core/request.go
Expand Up @@ -4,6 +4,7 @@ package core

import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
Expand All @@ -15,6 +16,7 @@ import (
"strings"

"github.com/aws/aws-lambda-go/events"
"github.com/aws/aws-lambda-go/lambdacontext"
)

// CustomHostVariable is the name of the environment variable that contains
Expand Down Expand Up @@ -102,13 +104,33 @@ func (r *RequestAccessor) StripBasePath(basePath string) string {
return newBasePath
}

// ProxyEventToHTTPRequest converts an API Gateway proxy event into an
// http.Request object.
// Returns the populated request with an additional two custom headers for the
// stage variables and API Gateway context. To access these properties use
// the GetAPIGatewayStageVars and GetAPIGatewayContext method of the RequestAccessor
// object.
// ProxyEventToHTTPRequest converts an API Gateway proxy event into a http.Request object.
// Returns the populated http request with additional two custom headers for the stage variables and API Gateway context.
// To access these properties use the GetAPIGatewayStageVars and GetAPIGatewayContext method of the RequestAccessor object.
func (r *RequestAccessor) ProxyEventToHTTPRequest(req events.APIGatewayProxyRequest) (*http.Request, error) {
httpRequest, err := r.EventToRequest(req)
if err != nil {
log.Println(err)
return nil, err
}
return addToHeader(httpRequest, req)
}

// EventToRequestWithContext converts an API Gateway proxy event and context into an http.Request object.
// Returns the populated http request with lambda context, stage variables and APIGatewayProxyRequestContext as part of its context.
// Access those using GetAPIGatewayContextFromContext, GetStageVarsFromContext and GetRuntimeContextFromContext functions in this package.
func (r *RequestAccessor) EventToRequestWithContext(ctx context.Context, req events.APIGatewayProxyRequest) (*http.Request, error) {
httpRequest, err := r.EventToRequest(req)
if err != nil {
log.Println(err)
return nil, err
}
return addToContext(ctx, httpRequest, req), nil
}

// EventToRequest converts an API Gateway proxy event into an http.Request object.
// Returns the populated request maintaining headers
func (r *RequestAccessor) EventToRequest(req events.APIGatewayProxyRequest) (*http.Request, error) {
decodedBody := []byte(req.Body)
if req.IsBase64Encoded {
base64Body, err := base64.StdEncoding.DecodeString(req.Body)
Expand Down Expand Up @@ -157,23 +179,57 @@ func (r *RequestAccessor) ProxyEventToHTTPRequest(req events.APIGatewayProxyRequ
log.Println(err)
return nil, err
}

for h := range req.Headers {
httpRequest.Header.Add(h, req.Headers[h])
}
return httpRequest, nil
}

apiGwContext, err := json.Marshal(req.RequestContext)
func addToHeader(req *http.Request, apiGwRequest events.APIGatewayProxyRequest) (*http.Request, error) {
stageVars, err := json.Marshal(apiGwRequest.StageVariables)
if err != nil {
log.Println("Could not Marshal API GW context for custom header")
log.Println("Could not marshal stage variables for custom header")
return nil, err
}
stageVars, err := json.Marshal(req.StageVariables)
req.Header.Add(APIGwStageVarsHeader, string(stageVars))
apiGwContext, err := json.Marshal(apiGwRequest.RequestContext)
if err != nil {
log.Println("Could not marshal stage variables for custom header")
return nil, err
log.Println("Could not Marshal API GW context for custom header")
return req, err
}
httpRequest.Header.Add(APIGwContextHeader, string(apiGwContext))
httpRequest.Header.Add(APIGwStageVarsHeader, string(stageVars))
req.Header.Add(APIGwContextHeader, string(apiGwContext))
return req, nil
}

return httpRequest, nil
func addToContext(ctx context.Context, req *http.Request, apiGwRequest events.APIGatewayProxyRequest) *http.Request {
lc, _ := lambdacontext.FromContext(ctx)
rc := requestContext{lambdaContext: lc, gatewayProxyContext: apiGwRequest.RequestContext, stageVars: apiGwRequest.StageVariables}
ctx = context.WithValue(req.Context(), ctxKey{}, rc)
return req.WithContext(ctx)
}

// GetAPIGatewayContextFromContext retrieve APIGatewayProxyRequestContext from context.Context
func GetAPIGatewayContextFromContext(ctx context.Context) (events.APIGatewayProxyRequestContext, bool) {
v, ok := ctx.Value(ctxKey{}).(requestContext)
return v.gatewayProxyContext, ok
}

// GetRuntimeContextFromContext retrieve Lambda Runtime Context from context.Context
func GetRuntimeContextFromContext(ctx context.Context) (*lambdacontext.LambdaContext, bool) {
v, ok := ctx.Value(ctxKey{}).(requestContext)
return v.lambdaContext, ok
}

// GetStageVarsFromContext retrieve stage variables from context
func GetStageVarsFromContext(ctx context.Context) (map[string]string, bool) {
v, ok := ctx.Value(ctxKey{}).(requestContext)
return v.stageVars, ok
}

type ctxKey struct{}

type requestContext struct {
lambdaContext *lambdacontext.LambdaContext
gatewayProxyContext events.APIGatewayProxyRequestContext
stageVars map[string]string
}
82 changes: 70 additions & 12 deletions core/request_test.go
@@ -1,12 +1,14 @@
package core_test

import (
"context"
"encoding/base64"
"io/ioutil"
"math/rand"
"os"

"github.com/aws/aws-lambda-go/events"
"github.com/aws/aws-lambda-go/lambdacontext"
"github.com/awslabs/aws-lambda-go-api-proxy/core"

. "github.com/onsi/ginkgo"
Expand All @@ -18,14 +20,15 @@ var _ = Describe("RequestAccessor tests", func() {
accessor := core.RequestAccessor{}
basicRequest := getProxyRequest("/hello", "GET")
It("Correctly converts a basic event", func() {
httpReq, err := accessor.ProxyEventToHTTPRequest(basicRequest)
httpReq, err := accessor.EventToRequestWithContext(context.Background(), basicRequest)
Expect(err).To(BeNil())
Expect("/hello").To(Equal(httpReq.URL.Path))
Expect("GET").To(Equal(httpReq.Method))
})

basicRequest = getProxyRequest("/hello", "get")
It("Converts method to uppercase", func() {
// calling old method to verify reverse compatibility
httpReq, err := accessor.ProxyEventToHTTPRequest(basicRequest)
Expect(err).To(BeNil())
Expect("/hello").To(Equal(httpReq.URL.Path))
Expand All @@ -45,7 +48,7 @@ var _ = Describe("RequestAccessor tests", func() {
binaryRequest.IsBase64Encoded = true

It("Decodes a base64 encoded body", func() {
httpReq, err := accessor.ProxyEventToHTTPRequest(binaryRequest)
httpReq, err := accessor.EventToRequestWithContext(context.Background(), binaryRequest)
Expect(err).To(BeNil())
Expect("/hello").To(Equal(httpReq.URL.Path))
Expect("POST").To(Equal(httpReq.Method))
Expand All @@ -63,7 +66,7 @@ var _ = Describe("RequestAccessor tests", func() {
"world": {"2", "3"},
}
It("Populates query string correctly", func() {
httpReq, err := accessor.ProxyEventToHTTPRequest(qsRequest)
httpReq, err := accessor.EventToRequestWithContext(context.Background(), qsRequest)
Expect(err).To(BeNil())
Expect("/hello").To(Equal(httpReq.URL.Path))
Expect("GET").To(Equal(httpReq.Method))
Expand All @@ -83,7 +86,8 @@ var _ = Describe("RequestAccessor tests", func() {

It("Stips the base path correct", func() {
accessor.StripBasePath("app1")
httpReq, err := accessor.ProxyEventToHTTPRequest(basePathRequest)
httpReq, err := accessor.EventToRequestWithContext(context.Background(), basePathRequest)

Expect(err).To(BeNil())
Expect("/orders").To(Equal(httpReq.URL.Path))
})
Expand All @@ -92,6 +96,7 @@ var _ = Describe("RequestAccessor tests", func() {
contextRequest.RequestContext = getRequestContext()

It("Populates context header correctly", func() {
// calling old method to verify reverse compatibility
httpReq, err := accessor.ProxyEventToHTTPRequest(contextRequest)
Expect(err).To(BeNil())
Expect(2).To(Equal(len(httpReq.Header)))
Expand Down Expand Up @@ -123,16 +128,49 @@ var _ = Describe("RequestAccessor tests", func() {
contextRequest.RequestContext = getRequestContext()

accessor := core.RequestAccessor{}
// calling old method to verify reverse compatibility
httpReq, err := accessor.ProxyEventToHTTPRequest(contextRequest)
Expect(err).To(BeNil())

context, err := accessor.GetAPIGatewayContext(httpReq)
headerContext, err := accessor.GetAPIGatewayContext(httpReq)
Expect(err).To(BeNil())
Expect(headerContext).ToNot(BeNil())
Expect("x").To(Equal(headerContext.AccountID))
Expect("x").To(Equal(headerContext.RequestID))
Expect("x").To(Equal(headerContext.APIID))
proxyContext, ok := core.GetAPIGatewayContextFromContext(httpReq.Context())
// should fail because using header proxy method
Expect(ok).To(BeFalse())

httpReq, err = accessor.EventToRequestWithContext(context.Background(), contextRequest)
Expect(err).To(BeNil())
Expect(context).ToNot(BeNil())
Expect("x").To(Equal(context.AccountID))
Expect("x").To(Equal(context.RequestID))
Expect("x").To(Equal(context.APIID))
Expect("prod").To(Equal(context.Stage))
proxyContext, ok = core.GetAPIGatewayContextFromContext(httpReq.Context())
Expect(ok).To(BeTrue())
Expect("x").To(Equal(proxyContext.APIID))
Expect("x").To(Equal(proxyContext.RequestID))
Expect("x").To(Equal(proxyContext.APIID))
Expect("prod").To(Equal(proxyContext.Stage))
runtimeContext, ok := core.GetRuntimeContextFromContext(httpReq.Context())
Expect(ok).To(BeTrue())
Expect(runtimeContext).To(BeNil())

lambdaContext := lambdacontext.NewContext(context.Background(), &lambdacontext.LambdaContext{AwsRequestID: "abc123"})
httpReq, err = accessor.EventToRequestWithContext(lambdaContext, contextRequest)
Expect(err).To(BeNil())

headerContext, err = accessor.GetAPIGatewayContext(httpReq)
// should fail as new context method doesn't populate headers
Expect(err).ToNot(BeNil())
proxyContext, ok = core.GetAPIGatewayContextFromContext(httpReq.Context())
Expect(ok).To(BeTrue())
Expect("x").To(Equal(proxyContext.APIID))
Expect("x").To(Equal(proxyContext.RequestID))
Expect("x").To(Equal(proxyContext.APIID))
Expect("prod").To(Equal(proxyContext.Stage))
runtimeContext, ok = core.GetRuntimeContextFromContext(httpReq.Context())
Expect(ok).To(BeTrue())
Expect(runtimeContext).ToNot(BeNil())
Expect("abc123").To(Equal(runtimeContext.AwsRequestID))
})

It("Populates stage variables correctly", func() {
Expand All @@ -150,9 +188,29 @@ var _ = Describe("RequestAccessor tests", func() {
Expect(stageVars["var2"]).ToNot(BeNil())
Expect("value1").To(Equal(stageVars["var1"]))
Expect("value2").To(Equal(stageVars["var2"]))

stageVars, ok := core.GetStageVarsFromContext(httpReq.Context())
// not present in context
Expect(ok).To(BeFalse())

httpReq, err = accessor.EventToRequestWithContext(context.Background(), varsRequest)
Expect(err).To(BeNil())

stageVars, err = accessor.GetAPIGatewayStageVars(httpReq)
// should not be in headers
Expect(err).ToNot(BeNil())

stageVars, ok = core.GetStageVarsFromContext(httpReq.Context())
Expect(ok).To(BeTrue())
Expect(2).To(Equal(len(stageVars)))
Expect(stageVars["var1"]).ToNot(BeNil())
Expect(stageVars["var2"]).ToNot(BeNil())
Expect("value1").To(Equal(stageVars["var1"]))
Expect("value2").To(Equal(stageVars["var2"]))
})

It("Populates the default hostname correctly", func() {

basicRequest := getProxyRequest("orders", "GET")
accessor := core.RequestAccessor{}
httpReq, err := accessor.ProxyEventToHTTPRequest(basicRequest)
Expand All @@ -167,7 +225,7 @@ var _ = Describe("RequestAccessor tests", func() {
os.Setenv(core.CustomHostVariable, myCustomHost)
basicRequest := getProxyRequest("orders", "GET")
accessor := core.RequestAccessor{}
httpReq, err := accessor.ProxyEventToHTTPRequest(basicRequest)
httpReq, err := accessor.EventToRequestWithContext(context.Background(), basicRequest)
Expect(err).To(BeNil())

Expect(myCustomHost).To(Equal("http://" + httpReq.Host))
Expand All @@ -180,7 +238,7 @@ var _ = Describe("RequestAccessor tests", func() {
os.Setenv(core.CustomHostVariable, myCustomHost+"/")
basicRequest := getProxyRequest("orders", "GET")
accessor := core.RequestAccessor{}
httpReq, err := accessor.ProxyEventToHTTPRequest(basicRequest)
httpReq, err := accessor.EventToRequestWithContext(context.Background(), basicRequest)
Expect(err).To(BeNil())

Expect(myCustomHost).To(Equal("http://" + httpReq.Host))
Expand Down
17 changes: 15 additions & 2 deletions gin/adapter.go
Expand Up @@ -4,6 +4,7 @@
package ginadapter

import (
"context"
"net/http"

"github.com/aws/aws-lambda-go/events"
Expand All @@ -29,16 +30,28 @@ func New(gin *gin.Engine) *GinLambda {

// Proxy receives an API Gateway proxy event, transforms it into an http.Request
// object, and sends it to the gin.Engine for routing.
// It returns a proxy response object gneerated from the http.ResponseWriter.
// It returns a proxy response object generated from the http.ResponseWriter.
func (g *GinLambda) Proxy(req events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) {
ginRequest, err := g.ProxyEventToHTTPRequest(req)
return g.proxyInternal(ginRequest, err)
}

// ProxyWithContext receives context and an API Gateway proxy event,
// transforms them into an http.Request object, and sends it to the gin.Engine for routing.
// It returns a proxy response object generated from the http.ResponseWriter.
func (g *GinLambda) ProxyWithContext(ctx context.Context, req events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) {
ginRequest, err := g.EventToRequestWithContext(ctx, req)
return g.proxyInternal(ginRequest, err)
}

func (g *GinLambda) proxyInternal(req *http.Request, err error) (events.APIGatewayProxyResponse, error) {

if err != nil {
return core.GatewayTimeout(), core.NewLoggedError("Could not convert proxy event to request: %v", err)
}

respWriter := core.NewProxyResponseWriter()
g.ginEngine.ServeHTTP(http.ResponseWriter(respWriter), ginRequest)
g.ginEngine.ServeHTTP(http.ResponseWriter(respWriter), req)

proxyResponse, err := respWriter.GetProxyResponse()
if err != nil {
Expand Down

0 comments on commit 4e8766f

Please sign in to comment.