Skip to content

Commit

Permalink
improvement of validating arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
shogo82148 committed Mar 7, 2021
1 parent 6c2af88 commit 96bd7dc
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 9 deletions.
31 changes: 23 additions & 8 deletions lambda/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package lambda
import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"

Expand Down Expand Up @@ -41,19 +42,33 @@ func errorHandler(e error) lambdaHandler {
}

func validateArguments(handler reflect.Type) (bool, error) {
handlerTakesContext := false
if handler.NumIn() > 2 {
return false, fmt.Errorf("handlers may not take more than two arguments, but handler takes %d", handler.NumIn())
} else if handler.NumIn() > 0 {
switch handler.NumIn() {
case 0:
return false, nil
case 1:
contextType := reflect.TypeOf((*context.Context)(nil)).Elem()
argumentType := handler.In(0)
handlerTakesContext = argumentType.Implements(contextType)
if handler.NumIn() > 1 && !handlerTakesContext {
if argumentType.Kind() != reflect.Interface {
return false, nil
}
if !contextType.Implements(argumentType) {
return false, errors.New("the first argument is an interface, but it is not a Context")
}
if argumentType.NumMethod() == 0 {
// the first argument might be TIn or context.Context.
// we choose TIn for backward compatibility.
return false, nil
}
return true, nil
case 2:
contextType := reflect.TypeOf((*context.Context)(nil)).Elem()
argumentType := handler.In(0)
if argumentType.Kind() != reflect.Interface || !contextType.Implements(argumentType) {
return false, fmt.Errorf("handler takes two arguments, but the first is not Context. got %s", argumentType.Kind())
}
return true, nil
}

return handlerTakesContext, nil
return false, fmt.Errorf("handlers may not take more than two arguments, but handler takes %d", handler.NumIn())
}

func validateReturns(handler reflect.Type) error {
Expand Down
45 changes: 44 additions & 1 deletion lambda/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ import (

func TestInvalidHandlers(t *testing.T) {

type valuer interface {
Value(key interface{}) interface{}
}

type customContext interface {
context.Context
MyCustomMethod()
}

testCases := []struct {
name string
handler interface{}
Expand Down Expand Up @@ -71,12 +80,46 @@ func TestInvalidHandlers(t *testing.T) {
handler: func() {
},
},
{
name: "the handler takes the empty interface",
expected: nil,
handler: func(v interface{}) error {
if _, ok := v.(context.Context); ok {
return errors.New("v should not be a Context")
}
return nil
},
},
{
name: "the handler takes a subset of context.Context",
expected: nil,
handler: func(ctx valuer) {
},
},
{
name: "the handler takes a superset of context.Context",
expected: errors.New("the first argument is an interface, but it is not a Context"),
handler: func(ctx customContext) {
},
},
{
name: "the handler takes two arguments and first argument is a subset of context.Context",
expected: nil,
handler: func(ctx valuer, v interface{}) {
},
},
{
name: "the handler takes two arguments and first argument is a superset of context.Context",
expected: errors.New("handler takes two arguments, but the first is not Context. got interface"),
handler: func(ctx customContext, v interface{}) {
},
},
}
for i, testCase := range testCases {
testCase := testCase
t.Run(fmt.Sprintf("testCase[%d] %s", i, testCase.name), func(t *testing.T) {
lambdaHandler := NewHandler(testCase.handler)
_, err := lambdaHandler.Invoke(context.TODO(), make([]byte, 0))
_, err := lambdaHandler.Invoke(context.TODO(), []byte("{}"))
assert.Equal(t, testCase.expected, err)
})
}
Expand Down

0 comments on commit 96bd7dc

Please sign in to comment.