Skip to content

Commit

Permalink
Support handlers that return io.Reader (#472)
Browse files Browse the repository at this point in the history
  • Loading branch information
bmoffatt committed Dec 23, 2022
1 parent ad74310 commit c80f8ac
Show file tree
Hide file tree
Showing 8 changed files with 334 additions and 58 deletions.
3 changes: 3 additions & 0 deletions lambda/entry.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ import (
//
// Where "TIn" and "TOut" are types compatible with the "encoding/json" standard library.
// See https://golang.org/pkg/encoding/json/#Unmarshal for how deserialization behaves
//
// "TOut" may also implement the io.Reader interface.
// If "TOut" is both json serializable and implements io.Reader, then the json serialization is used.
func Start(handler interface{}) {
StartWithOptions(handler)
}
Expand Down
94 changes: 73 additions & 21 deletions lambda/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil" // nolint:staticcheck
"reflect"
"strings"

"github.com/aws/aws-lambda-go/lambda/handlertrace"
)
Expand All @@ -18,7 +21,7 @@ type Handler interface {
}

type handlerOptions struct {
Handler
handlerFunc
baseContext context.Context
jsonResponseEscapeHTML bool
jsonResponseIndentPrefix string
Expand Down Expand Up @@ -184,32 +187,68 @@ func newHandler(handlerFunc interface{}, options ...Option) *handlerOptions {
if h.enableSIGTERM {
enableSIGTERM(h.sigtermCallbacks)
}
h.Handler = reflectHandler(handlerFunc, h)
h.handlerFunc = reflectHandler(handlerFunc, h)
return h
}

type bytesHandlerFunc func(context.Context, []byte) ([]byte, error)
type handlerFunc func(context.Context, []byte) (io.Reader, error)

func (h bytesHandlerFunc) Invoke(ctx context.Context, payload []byte) ([]byte, error) {
return h(ctx, payload)
// back-compat for the rpc mode
func (h handlerFunc) Invoke(ctx context.Context, payload []byte) ([]byte, error) {
response, err := h(ctx, payload)
if err != nil {
return nil, err
}
// if the response needs to be closed (ex: net.Conn, os.File), ensure it's closed before the next invoke to prevent a resource leak
if response, ok := response.(io.Closer); ok {
defer response.Close()
}
// optimization: if the response is a *bytes.Buffer, a copy can be eliminated
switch response := response.(type) {
case *jsonOutBuffer:
return response.Bytes(), nil
case *bytes.Buffer:
return response.Bytes(), nil
}
b, err := ioutil.ReadAll(response)
if err != nil {
return nil, err
}
return b, nil
}
func errorHandler(err error) Handler {
return bytesHandlerFunc(func(_ context.Context, _ []byte) ([]byte, error) {

func errorHandler(err error) handlerFunc {
return func(_ context.Context, _ []byte) (io.Reader, error) {
return nil, err
})
}
}

type jsonOutBuffer struct {
*bytes.Buffer
}

func reflectHandler(handlerFunc interface{}, h *handlerOptions) Handler {
if handlerFunc == nil {
func (j *jsonOutBuffer) ContentType() string {
return contentTypeJSON
}

func reflectHandler(f interface{}, h *handlerOptions) handlerFunc {
if f == nil {
return errorHandler(errors.New("handler is nil"))
}

if handler, ok := handlerFunc.(Handler); ok {
return handler
// back-compat: types with reciever `Invoke(context.Context, []byte) ([]byte, error)` need the return bytes wrapped
if handler, ok := f.(Handler); ok {
return func(ctx context.Context, payload []byte) (io.Reader, error) {
b, err := handler.Invoke(ctx, payload)
if err != nil {
return nil, err
}
return bytes.NewBuffer(b), nil
}
}

handler := reflect.ValueOf(handlerFunc)
handlerType := reflect.TypeOf(handlerFunc)
handler := reflect.ValueOf(f)
handlerType := reflect.TypeOf(f)
if handlerType.Kind() != reflect.Func {
return errorHandler(fmt.Errorf("handler kind %s is not %s", handlerType.Kind(), reflect.Func))
}
Expand All @@ -223,9 +262,10 @@ func reflectHandler(handlerFunc interface{}, h *handlerOptions) Handler {
return errorHandler(err)
}

return bytesHandlerFunc(func(ctx context.Context, payload []byte) ([]byte, error) {
out := &jsonOutBuffer{bytes.NewBuffer(nil)}
return func(ctx context.Context, payload []byte) (io.Reader, error) {
out.Reset()
in := bytes.NewBuffer(payload)
out := bytes.NewBuffer(nil)
decoder := json.NewDecoder(in)
encoder := json.NewEncoder(out)
encoder.SetEscapeHTML(h.jsonResponseEscapeHTML)
Expand Down Expand Up @@ -266,16 +306,28 @@ func reflectHandler(handlerFunc interface{}, h *handlerOptions) Handler {
trace.ResponseEvent(ctx, val)
}
}

// encode to JSON
if err := encoder.Encode(val); err != nil {
// if response is not JSON serializable, but the response type is a reader, return it as-is
if reader, ok := val.(io.Reader); ok {
return reader, nil
}
return nil, err
}

responseBytes := out.Bytes()
// if response value is an io.Reader, return it as-is
if reader, ok := val.(io.Reader); ok {
// back-compat, don't return the reader if the value serialized to a non-empty json
if strings.HasPrefix(out.String(), "{}") {
return reader, nil
}
}

// back-compat, strip the encoder's trailing newline unless WithSetIndent was used
if h.jsonResponseIndentValue == "" && h.jsonResponseIndentPrefix == "" {
return responseBytes[:len(responseBytes)-1], nil
out.Truncate(out.Len() - 1)
}

return responseBytes, nil
})
return out, nil
}
}
118 changes: 104 additions & 14 deletions lambda/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,20 @@
package lambda

import (
"bytes"
"context"
"errors"
"fmt"
"io"
"io/ioutil" //nolint: staticcheck
"strings"
"testing"
"time"

"github.com/aws/aws-lambda-go/lambda/handlertrace"
"github.com/aws/aws-lambda-go/lambda/messages"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestInvalidHandlers(t *testing.T) {
Expand Down Expand Up @@ -145,6 +150,23 @@ func TestInvalidHandlers(t *testing.T) {
}
}

type arbitraryJSON struct {
json []byte
err error
}

func (a arbitraryJSON) MarshalJSON() ([]byte, error) {
return a.json, a.err
}

type staticHandler struct {
body []byte
}

func (h *staticHandler) Invoke(_ context.Context, _ []byte) ([]byte, error) {
return h.body, nil
}

type expected struct {
val string
err error
Expand All @@ -168,10 +190,8 @@ func TestInvokes(t *testing.T) {
}{
{
input: `"Lambda"`,
expected: expected{`"Hello Lambda!"`, nil},
handler: func(name string) (string, error) {
return hello(name), nil
},
expected: expected{`null`, nil},
handler: func(_ string) {},
},
{
input: `"Lambda"`,
Expand All @@ -180,6 +200,12 @@ func TestInvokes(t *testing.T) {
return hello(name), nil
},
},
{
expected: expected{`"Hello No Value!"`, nil},
handler: func(ctx context.Context) (string, error) {
return hello("No Value"), nil
},
},
{
input: `"Lambda"`,
expected: expected{`"Hello Lambda!"`, nil},
Expand Down Expand Up @@ -294,22 +320,86 @@ func TestInvokes(t *testing.T) {
{
name: "Handler interface implementations are passthrough",
expected: expected{`<xml>hello</xml>`, nil},
handler: bytesHandlerFunc(func(_ context.Context, _ []byte) ([]byte, error) {
return []byte(`<xml>hello</xml>`), nil
}),
handler: &staticHandler{body: []byte(`<xml>hello</xml>`)},
},
{
name: "io.Reader responses are passthrough",
expected: expected{`<yolo>yolo</yolo>`, nil},
handler: func() (io.Reader, error) {
return strings.NewReader(`<yolo>yolo</yolo>`), nil
},
},
{
name: "io.Reader responses that are byte buffers are passthrough",
expected: expected{`<yolo>yolo</yolo>`, nil},
handler: func() (*bytes.Buffer, error) {
return bytes.NewBuffer([]byte(`<yolo>yolo</yolo>`)), nil
},
},
{
name: "io.Reader responses that are also json serializable, handler returns the json, ignoring the reader",
expected: expected{`{"Yolo":"yolo"}`, nil},
handler: func() (io.Reader, error) {
return struct {
io.Reader `json:"-"`
Yolo string
}{
Reader: strings.NewReader(`<yolo>yolo</yolo>`),
Yolo: "yolo",
}, nil
},
},
{
name: "types that are not json serializable result in an error",
expected: expected{``, errors.New("json: error calling MarshalJSON for type struct { lambda.arbitraryJSON }: barf")},
handler: func() (interface{}, error) {
return struct {
arbitraryJSON
}{
arbitraryJSON{nil, errors.New("barf")},
}, nil
},
},
{
name: "io.Reader responses that not json serializable remain passthrough",
expected: expected{`wat`, nil},
handler: func() (io.Reader, error) {
return struct {
arbitraryJSON
io.Reader
}{
arbitraryJSON{nil, errors.New("barf")},
strings.NewReader("wat"),
}, nil
},
},
}
for i, testCase := range testCases {
testCase := testCase
t.Run(fmt.Sprintf("testCase[%d] %s", i, testCase.name), func(t *testing.T) {
lambdaHandler := newHandler(testCase.handler, testCase.options...)
response, err := lambdaHandler.Invoke(context.TODO(), []byte(testCase.input))
if testCase.expected.err != nil {
assert.Equal(t, testCase.expected.err, err)
} else {
assert.NoError(t, err)
assert.Equal(t, testCase.expected.val, string(response))
}
t.Run("via Handler.Invoke", func(t *testing.T) {
response, err := lambdaHandler.Invoke(context.TODO(), []byte(testCase.input))
if testCase.expected.err != nil {
assert.EqualError(t, err, testCase.expected.err.Error())
} else {
assert.NoError(t, err)
assert.Equal(t, testCase.expected.val, string(response))
}
})
t.Run("via handlerOptions.handlerFunc", func(t *testing.T) {
response, err := lambdaHandler.handlerFunc(context.TODO(), []byte(testCase.input))
if testCase.expected.err != nil {
assert.EqualError(t, err, testCase.expected.err.Error())
} else {
assert.NoError(t, err)
require.NotNil(t, response)
responseBytes, err := ioutil.ReadAll(response)
assert.NoError(t, err)
assert.Equal(t, testCase.expected.val, string(responseBytes))
}
})

})
}
}
Expand Down
22 changes: 18 additions & 4 deletions lambda/invoke_loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
package lambda

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"os"
"strconv"
Expand Down Expand Up @@ -70,7 +72,7 @@ func handleInvoke(invoke *invoke, handler *handlerOptions) error {
ctx = context.WithValue(ctx, "x-amzn-trace-id", traceID)

// call the handler, marshal any returned error
response, invokeErr := callBytesHandlerFunc(ctx, invoke.payload, handler.Handler.Invoke)
response, invokeErr := callBytesHandlerFunc(ctx, invoke.payload, handler.handlerFunc)
if invokeErr != nil {
if err := reportFailure(invoke, invokeErr); err != nil {
return err
Expand All @@ -80,7 +82,19 @@ func handleInvoke(invoke *invoke, handler *handlerOptions) error {
}
return nil
}
if err := invoke.success(response, contentTypeJSON); err != nil {
// if the response needs to be closed (ex: net.Conn, os.File), ensure it's closed before the next invoke to prevent a resource leak
if response, ok := response.(io.Closer); ok {
defer response.Close()
}

// if the response defines a content-type, plumb it through
contentType := contentTypeBytes
type ContentType interface{ ContentType() string }
if response, ok := response.(ContentType); ok {
contentType = response.ContentType()
}

if err := invoke.success(response, contentType); err != nil {
return fmt.Errorf("unexpected error occurred when sending the function functionResponse to the API: %v", err)
}

Expand All @@ -90,13 +104,13 @@ func handleInvoke(invoke *invoke, handler *handlerOptions) error {
func reportFailure(invoke *invoke, invokeErr *messages.InvokeResponse_Error) error {
errorPayload := safeMarshal(invokeErr)
log.Printf("%s", errorPayload)
if err := invoke.failure(errorPayload, contentTypeJSON); err != nil {
if err := invoke.failure(bytes.NewReader(errorPayload), contentTypeJSON); err != nil {
return fmt.Errorf("unexpected error occurred when sending the function error to the API: %v", err)
}
return nil
}

func callBytesHandlerFunc(ctx context.Context, payload []byte, handler bytesHandlerFunc) (response []byte, invokeErr *messages.InvokeResponse_Error) {
func callBytesHandlerFunc(ctx context.Context, payload []byte, handler handlerFunc) (response io.Reader, invokeErr *messages.InvokeResponse_Error) {
defer func() {
if err := recover(); err != nil {
invokeErr = lambdaPanicResponse(err)
Expand Down
Loading

0 comments on commit c80f8ac

Please sign in to comment.