diff --git a/xray/client_test.go b/xray/client_test.go index 74180f6b..1c2878fd 100644 --- a/xray/client_test.go +++ b/xray/client_test.go @@ -10,6 +10,8 @@ package xray import ( "context" + "crypto/tls" + "crypto/x509" "encoding/json" "fmt" "io/ioutil" @@ -20,6 +22,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "golang.org/x/net/http2" ) var rt *roundtripper @@ -58,7 +61,7 @@ func TestRoundTrip(t *testing.T) { reader := strings.NewReader("") ctx, root := BeginSegment(context.Background(), "Test") - req := httptest.NewRequest("GET", ts.URL, reader) + req, _ := http.NewRequest("GET", ts.URL, reader) req = req.WithContext(ctx) _, err := rt.RoundTrip(req) root.Close(nil) @@ -91,7 +94,7 @@ func TestRoundTripWithError(t *testing.T) { reader := strings.NewReader("") ctx, root := BeginSegment(context.Background(), "Test") - req := httptest.NewRequest("GET", ts.URL, reader) + req, _ := http.NewRequest("GET", ts.URL, reader) req = req.WithContext(ctx) _, err := rt.RoundTrip(req) root.Close(nil) @@ -193,7 +196,7 @@ func TestBadRoundTrip(t *testing.T) { assert.Equal(t, fmt.Sprintf("%v", err), subseg.Cause.Exceptions[0].Message) } -func TestRoundTrip_reuse_datarace(t *testing.T) { +func TestRoundTripReuseDatarace(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { b := []byte(`200 - Nothing to see`) w.WriteHeader(http.StatusOK) @@ -210,7 +213,7 @@ func TestRoundTrip_reuse_datarace(t *testing.T) { defer wg.Done() reader := strings.NewReader("") ctx, root := BeginSegment(context.Background(), "Test") - req := httptest.NewRequest("GET", strings.Replace(ts.URL, "127.0.0.1", "localhost", -1), reader) + req, _ := http.NewRequest("GET", strings.Replace(ts.URL, "127.0.0.1", "localhost", -1), reader) req = req.WithContext(ctx) res, err := rt.RoundTrip(req) ioutil.ReadAll(res.Body) @@ -225,3 +228,42 @@ func TestRoundTrip_reuse_datarace(t *testing.T) { } wg.Wait() } + +func TestRoundTripHttp2Datarace(t *testing.T) { + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b := []byte(`200 - Nothing to see`) + w.WriteHeader(http.StatusOK) + w.Write(b) + })) + err := http2.ConfigureServer(ts.Config, nil) + assert.NoError(t, err) + ts.TLS = ts.Config.TLSConfig + ts.StartTLS() + + defer ts.Close() + + certpool := x509.NewCertPool() + certpool.AddCert(ts.Certificate()) + tr := &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: certpool, + }, + } + http2.ConfigureTransport(tr) + rt := &roundtripper{ + Base: tr, + } + + reader := strings.NewReader("") + ctx, root := BeginSegment(context.Background(), "Test") + req, _ := http.NewRequest("GET", ts.URL, reader) + req = req.WithContext(ctx) + res, err := rt.RoundTrip(req) + assert.NoError(t, err) + ioutil.ReadAll(res.Body) + res.Body.Close() + root.Close(nil) + + _, e := TestDaemon.Recv() + assert.NoError(t, e) +} diff --git a/xray/httptrace.go b/xray/httptrace.go index ceb3e951..7609afba 100644 --- a/xray/httptrace.go +++ b/xray/httptrace.go @@ -15,6 +15,7 @@ import ( "crypto/tls" "errors" "net/http/httptrace" + "sync" ) // HTTPSubsegments is a set of context in different HTTP operation. @@ -26,6 +27,7 @@ type HTTPSubsegments struct { tlsCtx context.Context reqCtx context.Context responseCtx context.Context + mu sync.Mutex } // NewHTTPSubsegments creates a new HTTPSubsegments to use in @@ -149,15 +151,21 @@ func (xt *HTTPSubsegments) GotConn(info *httptrace.GotConnInfo, err error) { func (xt *HTTPSubsegments) WroteRequest(info httptrace.WroteRequestInfo) { if xt.reqCtx != nil && GetSegment(xt.opCtx).InProgress { GetSegment(xt.reqCtx).Close(info.Err) - xt.responseCtx, _ = BeginSubsegment(xt.opCtx, "response") + resCtx, _ := BeginSubsegment(xt.opCtx, "response") + xt.mu.Lock() + xt.responseCtx = resCtx + xt.mu.Unlock() } } // GotFirstResponseByte closes the response subsegment if the HTTP // operation subsegment is still in progress. func (xt *HTTPSubsegments) GotFirstResponseByte() { - if xt.responseCtx != nil && GetSegment(xt.opCtx).InProgress { - GetSegment(xt.responseCtx).Close(nil) + xt.mu.Lock() + resCtx := xt.responseCtx + xt.mu.Unlock() + if resCtx != nil && GetSegment(xt.opCtx).InProgress { + GetSegment(resCtx).Close(nil) } }