diff --git a/contrib/net/http/roundtripper.go b/contrib/net/http/roundtripper.go index fc9ba64330..6137c08474 100644 --- a/contrib/net/http/roundtripper.go +++ b/contrib/net/http/roundtripper.go @@ -49,13 +49,14 @@ func (rt *roundTripper) RoundTrip(req *http.Request) (res *http.Response, err er if rt.cfg.before != nil { rt.cfg.before(req, span) } - // inject the span context into the http request - err = tracer.Inject(span.Context(), tracer.HTTPHeadersCarrier(req.Header)) + r2 := req.Clone(ctx) + // inject the span context into the http request copy + err = tracer.Inject(span.Context(), tracer.HTTPHeadersCarrier(r2.Header)) if err != nil { // this should never happen fmt.Fprintf(os.Stderr, "contrib/net/http.Roundtrip: failed to inject http headers: %v\n", err) } - res, err = rt.base.RoundTrip(req.WithContext(ctx)) + res, err = rt.base.RoundTrip(r2) if err != nil { span.SetTag("http.errors", err.Error()) span.SetTag(ext.Error, err) diff --git a/contrib/net/http/roundtripper_test.go b/contrib/net/http/roundtripper_test.go index f0d74fca3c..d71d8d30be 100644 --- a/contrib/net/http/roundtripper_test.go +++ b/contrib/net/http/roundtripper_test.go @@ -230,6 +230,31 @@ func TestRoundTripperAnalyticsSettings(t *testing.T) { }) } +// TestRoundTripperCopy is a regression test ensuring that RoundTrip +// does not modify the request per the RoundTripper contract. See: +// https://cs.opensource.google/go/go/+/refs/tags/go1.18.1:src/net/http/client.go;l=129-133 +func TestRoundTripperCopy(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := tracer.Extract(tracer.HTTPHeadersCarrier(r.Header)) + assert.NoError(t, err) + w.Write([]byte("Hello World")) + })) + defer s.Close() + + initialReq, err := http.NewRequest("GET", s.URL+"/hello/world", nil) + assert.NoError(t, err) + req, err := http.NewRequest("GET", s.URL+"/hello/world", nil) + assert.NoError(t, err) + rt := WrapRoundTripper(http.DefaultTransport).(*roundTripper) + _, err = rt.RoundTrip(req) + assert.NoError(t, err) + assert.Len(t, req.Header, 0) + assert.Equal(t, initialReq, req) +} + func TestServiceName(t *testing.T) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("Hello World"))