From 17591b3baef5093924edfddb164665a83c62d5e8 Mon Sep 17 00:00:00 2001 From: Christopher Radek Date: Wed, 3 Apr 2019 10:43:11 -0700 Subject: [PATCH] bugfix: close inprogress connect subsegments (#102) * fixes race condition in centralized sampling test * connect subsegment closing more robust * add non-empty subsegment assertion to failing test * update aws_test for travis * adds additional tests --- strategy/sampling/centralized_test.go | 4 +- xray/aws.go | 14 ++- xray/aws_test.go | 117 ++++++++++++++++++++++++++ xray/client_test.go | 50 +++++++++++ xray/httptrace.go | 22 ++++- xray/segment.go | 2 +- 6 files changed, 202 insertions(+), 7 deletions(-) create mode 100644 xray/aws_test.go diff --git a/strategy/sampling/centralized_test.go b/strategy/sampling/centralized_test.go index f295d678..6639b534 100644 --- a/strategy/sampling/centralized_test.go +++ b/strategy/sampling/centralized_test.go @@ -2513,12 +2513,14 @@ A: break A default: // Assert that rule was added to manifest and the timestamp refreshed + ss.manifest.Lock() if len(ss.manifest.Rules) == 1 && len(ss.manifest.Index) == 1 && ss.manifest.refreshedAt == 1500000000 { - + ss.manifest.Unlock() break A } + ss.manifest.Unlock() } } } diff --git a/xray/aws.go b/xray/aws.go index 948923dd..939cc958 100644 --- a/xray/aws.go +++ b/xray/aws.go @@ -103,7 +103,18 @@ var xRayBeforeSendHandler = request.NamedHandler{ var xRayAfterSendHandler = request.NamedHandler{ Name: "XRayAfterSendHandler", Fn: func(r *request.Request) { - endSubsegment(r) + curseg := GetSegment(r.HTTPRequest.Context()) + + if curseg.Name == "attempt" { + // An error could have prevented the connect subsegment from closing, + // so clean it up here. 
+ for _, subsegment := range curseg.rawSubsegments { + if subsegment.Name == "connect" && subsegment.safeInProgress() { + subsegment.Close(nil) + return + } + } + } }, } @@ -143,6 +154,7 @@ func pushHandlers(c *client.Client) { c.Handlers.Validate.PushFrontNamed(xRayBeforeValidateHandler) c.Handlers.Build.PushBackNamed(xRayAfterBuildHandler) c.Handlers.Sign.PushFrontNamed(xRayBeforeSignHandler) + c.Handlers.Send.PushBackNamed(xRayAfterSendHandler) c.Handlers.Unmarshal.PushFrontNamed(xRayBeforeUnmarshalHandler) c.Handlers.Unmarshal.PushBackNamed(xRayAfterUnmarshalHandler) c.Handlers.Retry.PushFrontNamed(xRayBeforeRetryHandler) diff --git a/xray/aws_test.go b/xray/aws_test.go new file mode 100644 index 00000000..fec19ee9 --- /dev/null +++ b/xray/aws_test.go @@ -0,0 +1,117 @@ +package xray + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/lambda" + "github.com/stretchr/testify/assert" +) + +func TestClientSuccessfulConnection(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b := []byte(`{}`) + w.WriteHeader(http.StatusOK) + w.Write(b) + })) + + svc := lambda.New(session.Must(session.NewSession(&aws.Config{ + Endpoint: aws.String(ts.URL), + Region: aws.String("fake-moon-1"), + Credentials: credentials.NewStaticCredentials("akid", "secret", "noop")}))) + + ctx, root := BeginSegment(context.Background(), "Test") + + AWS(svc.Client) + + _, err := svc.ListFunctionsWithContext(ctx, &lambda.ListFunctionsInput{}) + root.Close(nil) + assert.NoError(t, err) + + s, e := TestDaemon.Recv() + assert.NoError(t, e) + + subseg := &Segment{} + assert.NotEmpty(t, s.Subsegments) + assert.NoError(t, json.Unmarshal(s.Subsegments[0], &subseg)) + assert.False(t, subseg.Fault) + assert.NotEmpty(t, subseg.Subsegments) + + attemptSubseg := 
&Segment{} + for _, sub := range subseg.Subsegments { + tempSeg := &Segment{} + assert.NoError(t, json.Unmarshal(sub, &tempSeg)) + if tempSeg.Name == "attempt" { + attemptSubseg = tempSeg + break + } + } + + assert.Equal(t, "attempt", attemptSubseg.Name) + assert.Zero(t, attemptSubseg.openSegments) + + // Connect subsegment will contain multiple child subsegments. + // The connection to the local test server succeeds, so the + // subsegment should be closed and should not be InProgress. + connectSubseg := &Segment{} + assert.NotEmpty(t, attemptSubseg.Subsegments) + assert.NoError(t, json.Unmarshal(attemptSubseg.Subsegments[0], &connectSubseg)) + assert.Equal(t, "connect", connectSubseg.Name) + assert.False(t, connectSubseg.InProgress) + assert.NotZero(t, connectSubseg.EndTime) + assert.NotEmpty(t, connectSubseg.Subsegments) + + // Ensure that the 'connect' subsegments are completed. + for _, sub := range connectSubseg.Subsegments { + tempSeg := &Segment{} + assert.NoError(t, json.Unmarshal(sub, &tempSeg)) + assert.False(t, tempSeg.InProgress) + assert.NotZero(t, tempSeg.EndTime) + } +} + +func TestClientFailedConnection(t *testing.T) { + svc := lambda.New(session.Must(session.NewSession(&aws.Config{ + Region: aws.String("fake-moon-1"), + Credentials: credentials.NewStaticCredentials("akid", "secret", "noop")}))) + + ctx, root := BeginSegment(context.Background(), "Test") + + AWS(svc.Client) + + _, err := svc.ListFunctionsWithContext(ctx, &lambda.ListFunctionsInput{}) + root.Close(nil) + assert.Error(t, err) + + s, e := TestDaemon.Recv() + assert.NoError(t, e) + + subseg := &Segment{} + assert.NotEmpty(t, s.Subsegments) + assert.NoError(t, json.Unmarshal(s.Subsegments[0], &subseg)) + assert.True(t, subseg.Fault) + // Should contain 'marshal' and 'attempt' subsegments only. 
+ assert.Len(t, subseg.Subsegments, 2) + + attemptSubseg := &Segment{} + assert.NoError(t, json.Unmarshal(subseg.Subsegments[1], &attemptSubseg)) + assert.Equal(t, "attempt", attemptSubseg.Name) + assert.Zero(t, attemptSubseg.openSegments) + + // Connect subsegment will contain multiple child subsegments. + // The subsegment should fail since the endpoint is not valid, + // and should not be InProgress. + connectSubseg := &Segment{} + assert.NotEmpty(t, attemptSubseg.Subsegments) + assert.NoError(t, json.Unmarshal(attemptSubseg.Subsegments[0], &connectSubseg)) + assert.Equal(t, "connect", connectSubseg.Name) + assert.False(t, connectSubseg.InProgress) + assert.NotZero(t, connectSubseg.EndTime) + assert.NotEmpty(t, connectSubseg.Subsegments) +} diff --git a/xray/client_test.go b/xray/client_test.go index 1c2878fd..58be9d28 100644 --- a/xray/client_test.go +++ b/xray/client_test.go @@ -77,6 +77,30 @@ func TestRoundTrip(t *testing.T) { assert.Equal(t, 200, subseg.HTTP.Response.Status) assert.Equal(t, responseContentLength, subseg.HTTP.Response.ContentLength) assert.Equal(t, headers.RootTraceID, s.TraceID) + + connectSeg := &Segment{} + for _, sub := range subseg.Subsegments { + tempSeg := &Segment{} + assert.NoError(t, json.Unmarshal(sub, &tempSeg)) + if tempSeg.Name == "connect" { + connectSeg = tempSeg + break + } + } + + // Ensure that a 'connect' subsegment was created and closed + assert.Equal(t, "connect", connectSeg.Name) + assert.False(t, connectSeg.InProgress) + assert.NotZero(t, connectSeg.EndTime) + assert.NotEmpty(t, connectSeg.Subsegments) + + // Ensure that the 'connect' subsegments are completed. 
+ for _, sub := range connectSeg.Subsegments { + tempSeg := &Segment{} + assert.NoError(t, json.Unmarshal(sub, &tempSeg)) + assert.False(t, tempSeg.InProgress) + assert.NotZero(t, tempSeg.EndTime) + } } func TestRoundTripWithError(t *testing.T) { @@ -196,6 +220,32 @@ func TestBadRoundTrip(t *testing.T) { assert.Equal(t, fmt.Sprintf("%v", err), subseg.Cause.Exceptions[0].Message) } +func TestBadRoundTripDial(t *testing.T) { + ctx, root := BeginSegment(context.Background(), "Test") + reader := strings.NewReader("") + // Make a request against an unreachable endpoint. + req := httptest.NewRequest("GET", "https://0.0.0.0:0", reader) + req = req.WithContext(ctx) + _, err := rt.RoundTrip(req) + root.Close(nil) + assert.Error(t, err) + + s, e := TestDaemon.Recv() + assert.NoError(t, e) + subseg := &Segment{} + assert.NoError(t, json.Unmarshal(s.Subsegments[0], &subseg)) + assert.Equal(t, fmt.Sprintf("%v", err), subseg.Cause.Exceptions[0].Message) + + // Also ensure that the 'connect' subsegment is closed and showing fault + connectSeg := &Segment{} + assert.NoError(t, json.Unmarshal(subseg.Subsegments[0], &connectSeg)) + assert.Equal(t, "connect", connectSeg.Name) + assert.NotZero(t, connectSeg.EndTime) + assert.False(t, connectSeg.InProgress) + assert.True(t, connectSeg.Fault) + assert.NotEmpty(t, connectSeg.Subsegments) +} + func TestRoundTripReuseDatarace(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { b := []byte(`200 - Nothing to see`) diff --git a/xray/httptrace.go b/xray/httptrace.go index 7609afba..6126f093 100644 --- a/xray/httptrace.go +++ b/xray/httptrace.go @@ -39,7 +39,7 @@ func NewHTTPSubsegments(opCtx context.Context) *HTTPSubsegments { // GetConn begins a connect subsegment if the HTTP operation // subsegment is still in progress. 
func (xt *HTTPSubsegments) GetConn(hostPort string) { - if GetSegment(xt.opCtx).InProgress { + if GetSegment(xt.opCtx).safeInProgress() { xt.connCtx, _ = BeginSubsegment(xt.opCtx, "connect") } } @@ -47,7 +47,7 @@ func (xt *HTTPSubsegments) GetConn(hostPort string) { // DNSStart begins a dns subsegment if the HTTP operation // subsegment is still in progress. func (xt *HTTPSubsegments) DNSStart(info httptrace.DNSStartInfo) { - if GetSegment(xt.opCtx).safeInProgress() { + if GetSegment(xt.opCtx).safeInProgress() && xt.connCtx != nil { xt.dnsCtx, _ = BeginSubsegment(xt.connCtx, "dns") } } @@ -71,7 +71,7 @@ func (xt *HTTPSubsegments) DNSDone(info httptrace.DNSDoneInfo) { // ConnectStart begins a dial subsegment if the HTTP operation // subsegment is still in progress. func (xt *HTTPSubsegments) ConnectStart(network, addr string) { - if GetSegment(xt.opCtx).safeInProgress() { + if GetSegment(xt.opCtx).safeInProgress() && xt.connCtx != nil { xt.connectCtx, _ = BeginSubsegment(xt.connCtx, "dial") } } @@ -121,10 +121,14 @@ func (xt *HTTPSubsegments) TLSHandshakeDone(connState tls.ConnectionState, err e // metadata to the subsegment. If the connection is marked as reused, // the connect subsegment is deleted. func (xt *HTTPSubsegments) GotConn(info *httptrace.GotConnInfo, err error) { - if xt.connCtx != nil && GetSegment(xt.opCtx).InProgress { // GetConn may not have been called (client_test.TestBadRoundTrip) + if xt.connCtx != nil && GetSegment(xt.opCtx).safeInProgress() { // GetConn may not have been called (client_test.TestBadRoundTrip) if info != nil { if info.Reused { GetSegment(xt.opCtx).RemoveSubsegment(GetSegment(xt.connCtx)) + xt.mu.Lock() + // Remove the connCtx context since it is no longer needed. 
+ xt.connCtx = nil + xt.mu.Unlock() } else { metadata := make(map[string]interface{}) metadata["reused"] = info.Reused @@ -136,6 +140,8 @@ func (xt *HTTPSubsegments) GotConn(info *httptrace.GotConnInfo, err error) { AddMetadataToNamespace(xt.connCtx, "http", "connection", metadata) GetSegment(xt.connCtx).Close(err) } + } else if xt.connCtx != nil && GetSegment(xt.connCtx).safeInProgress() { + GetSegment(xt.connCtx).Close(err) } if err == nil { @@ -156,6 +162,14 @@ func (xt *HTTPSubsegments) WroteRequest(info httptrace.WroteRequestInfo) { xt.responseCtx = resCtx xt.mu.Unlock() } + + // In case the GotConn http trace handler wasn't called, + // we close the connection subsegment since a connection + // had to have been acquired before attempting to write + // the request. + if xt.connCtx != nil && GetSegment(xt.connCtx).safeInProgress() { + GetSegment(xt.connCtx).Close(nil) + } } // GotFirstResponseByte closes the response subsegment if the HTTP diff --git a/xray/segment.go b/xray/segment.go index e2a30f6e..ac7e43af 100644 --- a/xray/segment.go +++ b/xray/segment.go @@ -285,7 +285,7 @@ func (seg *Segment) RemoveSubsegment(remove *Segment) bool { seg.rawSubsegments[len(seg.rawSubsegments)-1] = nil seg.rawSubsegments = seg.rawSubsegments[:len(seg.rawSubsegments)-1] - seg.totalSubSegments-- + seg.ParentSegment.totalSubSegments-- seg.openSegments-- return true }