Skip to content

Commit

Permalink
Add xray.AWSSession to install handlers on session
Browse files Browse the repository at this point in the history
An application has to call xray.AWS for each AWS client it constructs.
This creates opportunities for blind spots if someone forgets to
configure a new client.

The xray.AWSSession installs the same handlers at the Session level.
Clients inherit handlers from the session they're created with. As long
as the application systematically reuses the same session to create
clients, it only needs to install X-Ray handlers once.
  • Loading branch information
logan committed Jul 2, 2019
1 parent bf07cf7 commit ebf643a
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 48 deletions.
40 changes: 27 additions & 13 deletions xray/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-xray-sdk-go/internal/logger"
"github.com/aws/aws-xray-sdk-go/resources"
)
Expand Down Expand Up @@ -137,33 +138,46 @@ var xRayAfterRetryHandler = request.NamedHandler{
},
}

func pushHandlers(c *client.Client) {
c.Handlers.Validate.PushFrontNamed(xRayBeforeValidateHandler)
c.Handlers.Build.PushBackNamed(xRayAfterBuildHandler)
c.Handlers.Sign.PushFrontNamed(xRayBeforeSignHandler)
c.Handlers.Send.PushBackNamed(xRayAfterSendHandler)
c.Handlers.Unmarshal.PushFrontNamed(xRayBeforeUnmarshalHandler)
c.Handlers.Unmarshal.PushBackNamed(xRayAfterUnmarshalHandler)
c.Handlers.Retry.PushFrontNamed(xRayBeforeRetryHandler)
c.Handlers.AfterRetry.PushBackNamed(xRayAfterRetryHandler)
func pushHandlers(handlers *request.Handlers, completionWhitelistFilename string) {
handlers.Validate.PushFrontNamed(xRayBeforeValidateHandler)
handlers.Build.PushBackNamed(xRayAfterBuildHandler)
handlers.Sign.PushFrontNamed(xRayBeforeSignHandler)
handlers.Send.PushBackNamed(xRayAfterSendHandler)
handlers.Unmarshal.PushFrontNamed(xRayBeforeUnmarshalHandler)
handlers.Unmarshal.PushBackNamed(xRayAfterUnmarshalHandler)
handlers.Retry.PushFrontNamed(xRayBeforeRetryHandler)
handlers.AfterRetry.PushBackNamed(xRayAfterRetryHandler)
handlers.Complete.PushFrontNamed(xrayCompleteHandler(completionWhitelistFilename))
}

// AWS adds X-Ray tracing to an AWS client.
func AWS(c *client.Client) {
if c == nil {
panic("Please initialize the provided AWS client before passing to the AWS() method.")
}
pushHandlers(c)
c.Handlers.Complete.PushFrontNamed(xrayCompleteHandler(""))
pushHandlers(&c.Handlers, "")
}

// AWSWithWhitelist allows a custom parameter whitelist JSON file to be defined.
func AWSWithWhitelist(c *client.Client, filename string) {
if c == nil {
panic("Please initialize the provided AWS client before passing to the AWSWithWhitelist() method.")
}
pushHandlers(c)
c.Handlers.Complete.PushFrontNamed(xrayCompleteHandler(filename))
pushHandlers(&c.Handlers, filename)
}

// AWSSession adds X-Ray tracing to an AWS session. Clients created under this
// session will inherit X-Ray tracing.
func AWSSession(s *session.Session) *session.Session {
pushHandlers(&s.Handlers, "")
return s
}

// AWSSessionWithWhitelist allows a custom parameter whitelist JSON file to be
// defined.
func AWSSessionWithWhitelist(s *session.Session, filename string) *session.Session {
pushHandlers(&s.Handlers, filename)
return s
}

func xrayCompleteHandler(filename string) request.NamedHandler {
Expand Down
114 changes: 79 additions & 35 deletions xray/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,86 @@ import (
"github.com/stretchr/testify/assert"
)

func TestClientSuccessfulConnection(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b := []byte(`{}`)
w.WriteHeader(http.StatusOK)
w.Write(b)
}))

svc := lambda.New(session.Must(session.NewSession(&aws.Config{
Endpoint: aws.String(ts.URL),
Region: aws.String("fake-moon-1"),
Credentials: credentials.NewStaticCredentials("akid", "secret", "noop")})))
func TestAWS(t *testing.T) {
// Runs a suite of tests against two different methods of registering
// handlers on an AWS client.

type test func(*testing.T, *lambda.Lambda)
tests := []struct {
name string
test test
failConn bool
}{
{"failed connection", testClientFailedConnection, true},
{"successful connection", testClientSuccessfulConnection, false},
{"without segment", testClientWithoutSegment, false},
}

ctx, root := BeginSegment(context.Background(), "Test")
onClient := func(s *session.Session) *lambda.Lambda {
svc := lambda.New(s)
AWS(svc.Client)
return svc
}

AWS(svc.Client)
onSession := func(s *session.Session) *lambda.Lambda {
return lambda.New(AWSSession(s))
}

const whitelist = "../resources/AWSWhitelist.json"

onClientWithWhitelist := func(s *session.Session) *lambda.Lambda {
svc := lambda.New(s)
AWSWithWhitelist(svc.Client, whitelist)
return svc
}

onSessionWithWhitelist := func(s *session.Session) *lambda.Lambda {
return lambda.New(AWSSessionWithWhitelist(s, whitelist))
}

type constructor func(*session.Session) *lambda.Lambda
constructors := []struct {
name string
constructor constructor
}{
{"AWS()", onClient},
{"AWSSession()", onSession},
{"AWSWithWhitelist()", onClientWithWhitelist},
{"AWSSessionWithWhitelist()", onSessionWithWhitelist},
}

// Run all combinations of constructors + tests.
for _, cons := range constructors {
t.Run(cons.name, func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
test.test(t, cons.constructor(fakeSession(t, test.failConn)))
})
}
})
}
}

func fakeSession(t *testing.T, failConn bool) *session.Session {
cfg := &aws.Config{
Region: aws.String("fake-moon-1"),
Credentials: credentials.NewStaticCredentials("akid", "secret", "noop"),
}
if !failConn {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b := []byte(`{}`)
w.WriteHeader(http.StatusOK)
w.Write(b)
}))
cfg.Endpoint = aws.String(ts.URL)
}
s, err := session.NewSession(cfg)
assert.NoError(t, err)
return s
}

func testClientSuccessfulConnection(t *testing.T, svc *lambda.Lambda) {
ctx, root := BeginSegment(context.Background(), "Test")
_, err := svc.ListFunctionsWithContext(ctx, &lambda.ListFunctionsInput{})
root.Close(nil)
assert.NoError(t, err)
Expand Down Expand Up @@ -76,15 +140,8 @@ func TestClientSuccessfulConnection(t *testing.T) {
}
}

func TestClientFailedConnection(t *testing.T) {
svc := lambda.New(session.Must(session.NewSession(&aws.Config{
Region: aws.String("fake-moon-1"),
Credentials: credentials.NewStaticCredentials("akid", "secret", "noop")})))

func testClientFailedConnection(t *testing.T, svc *lambda.Lambda) {
ctx, root := BeginSegment(context.Background(), "Test")

AWS(svc.Client)

_, err := svc.ListFunctionsWithContext(ctx, &lambda.ListFunctionsInput{})
root.Close(nil)
assert.Error(t, err)
Expand Down Expand Up @@ -116,24 +173,11 @@ func TestClientFailedConnection(t *testing.T) {
assert.NotEmpty(t, connectSubseg.Subsegments)
}

func TestClientWithoutSegment(t *testing.T) {
func testClientWithoutSegment(t *testing.T, svc *lambda.Lambda) {
Configure(Config{ContextMissingStrategy: &TestContextMissingStrategy{}})
defer ResetConfig()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b := []byte(`{}`)
w.WriteHeader(http.StatusOK)
w.Write(b)
}))

svc := lambda.New(session.Must(session.NewSession(&aws.Config{
Endpoint: aws.String(ts.URL),
Region: aws.String("fake-moon-1"),
Credentials: credentials.NewStaticCredentials("akid", "secret", "noop")})))

ctx := context.Background()

AWS(svc.Client)

_, err := svc.ListFunctionsWithContext(ctx, &lambda.ListFunctionsInput{})
assert.NoError(t, err)
}

0 comments on commit ebf643a

Please sign in to comment.