Skip to content

Commit

Permalink
fix datarace for http2 (#72)
Browse files Browse the repository at this point in the history
* fix datarace for http2

* fix test name
  • Loading branch information
t-yuki authored and luluzhao committed Nov 19, 2018
1 parent 22b714b commit 996d3b9
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 7 deletions.
50 changes: 46 additions & 4 deletions xray/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ package xray

import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"io/ioutil"
Expand All @@ -20,6 +22,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"golang.org/x/net/http2"
)

var rt *roundtripper
Expand Down Expand Up @@ -58,7 +61,7 @@ func TestRoundTrip(t *testing.T) {

reader := strings.NewReader("")
ctx, root := BeginSegment(context.Background(), "Test")
req := httptest.NewRequest("GET", ts.URL, reader)
req, _ := http.NewRequest("GET", ts.URL, reader)
req = req.WithContext(ctx)
_, err := rt.RoundTrip(req)
root.Close(nil)
Expand Down Expand Up @@ -91,7 +94,7 @@ func TestRoundTripWithError(t *testing.T) {

reader := strings.NewReader("")
ctx, root := BeginSegment(context.Background(), "Test")
req := httptest.NewRequest("GET", ts.URL, reader)
req, _ := http.NewRequest("GET", ts.URL, reader)
req = req.WithContext(ctx)
_, err := rt.RoundTrip(req)
root.Close(nil)
Expand Down Expand Up @@ -193,7 +196,7 @@ func TestBadRoundTrip(t *testing.T) {
assert.Equal(t, fmt.Sprintf("%v", err), subseg.Cause.Exceptions[0].Message)
}

func TestRoundTrip_reuse_datarace(t *testing.T) {
func TestRoundTripReuseDatarace(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b := []byte(`200 - Nothing to see`)
w.WriteHeader(http.StatusOK)
Expand All @@ -210,7 +213,7 @@ func TestRoundTrip_reuse_datarace(t *testing.T) {
defer wg.Done()
reader := strings.NewReader("")
ctx, root := BeginSegment(context.Background(), "Test")
req := httptest.NewRequest("GET", strings.Replace(ts.URL, "127.0.0.1", "localhost", -1), reader)
req, _ := http.NewRequest("GET", strings.Replace(ts.URL, "127.0.0.1", "localhost", -1), reader)
req = req.WithContext(ctx)
res, err := rt.RoundTrip(req)
ioutil.ReadAll(res.Body)
Expand All @@ -225,3 +228,42 @@ func TestRoundTrip_reuse_datarace(t *testing.T) {
}
wg.Wait()
}

func TestRoundTripHttp2Datarace(t *testing.T) {
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b := []byte(`200 - Nothing to see`)
w.WriteHeader(http.StatusOK)
w.Write(b)
}))
err := http2.ConfigureServer(ts.Config, nil)
assert.NoError(t, err)
ts.TLS = ts.Config.TLSConfig
ts.StartTLS()

defer ts.Close()

certpool := x509.NewCertPool()
certpool.AddCert(ts.Certificate())
tr := &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: certpool,
},
}
http2.ConfigureTransport(tr)
rt := &roundtripper{
Base: tr,
}

reader := strings.NewReader("")
ctx, root := BeginSegment(context.Background(), "Test")
req, _ := http.NewRequest("GET", ts.URL, reader)
req = req.WithContext(ctx)
res, err := rt.RoundTrip(req)
assert.NoError(t, err)
ioutil.ReadAll(res.Body)
res.Body.Close()
root.Close(nil)

_, e := TestDaemon.Recv()
assert.NoError(t, e)
}
14 changes: 11 additions & 3 deletions xray/httptrace.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"crypto/tls"
"errors"
"net/http/httptrace"
"sync"
)

// HTTPSubsegments is a set of context in different HTTP operation.
Expand All @@ -26,6 +27,7 @@ type HTTPSubsegments struct {
tlsCtx context.Context
reqCtx context.Context
responseCtx context.Context
mu sync.Mutex
}

// NewHTTPSubsegments creates a new HTTPSubsegments to use in
Expand Down Expand Up @@ -149,15 +151,21 @@ func (xt *HTTPSubsegments) GotConn(info *httptrace.GotConnInfo, err error) {
func (xt *HTTPSubsegments) WroteRequest(info httptrace.WroteRequestInfo) {
if xt.reqCtx != nil && GetSegment(xt.opCtx).InProgress {
GetSegment(xt.reqCtx).Close(info.Err)
xt.responseCtx, _ = BeginSubsegment(xt.opCtx, "response")
resCtx, _ := BeginSubsegment(xt.opCtx, "response")
xt.mu.Lock()
xt.responseCtx = resCtx
xt.mu.Unlock()
}
}

// GotFirstResponseByte closes the response subsegment if the HTTP
// operation subsegment is still in progress.
func (xt *HTTPSubsegments) GotFirstResponseByte() {
if xt.responseCtx != nil && GetSegment(xt.opCtx).InProgress {
GetSegment(xt.responseCtx).Close(nil)
xt.mu.Lock()
resCtx := xt.responseCtx
xt.mu.Unlock()
if resCtx != nil && GetSegment(xt.opCtx).InProgress {
GetSegment(resCtx).Close(nil)
}
}

Expand Down

0 comments on commit 996d3b9

Please sign in to comment.