diff --git a/aws/http_client.go b/aws/http_client.go index ebb92a03dbc..a40d9f8d767 100644 --- a/aws/http_client.go +++ b/aws/http_client.go @@ -6,8 +6,6 @@ import ( "reflect" "sync" "time" - - "golang.org/x/net/http2" ) // Defaults for the HTTPTransportBuilder. @@ -43,7 +41,7 @@ type BuildableHTTPClient struct { transport *http.Transport dialer *net.Dialer - initOnce *sync.Once + initOnce sync.Once clientTimeout time.Duration client *http.Client @@ -52,9 +50,7 @@ type BuildableHTTPClient struct { // NewBuildableHTTPClient returns an initialized client for invoking HTTP // requests. func NewBuildableHTTPClient() *BuildableHTTPClient { - return &BuildableHTTPClient{ - initOnce: new(sync.Once), - } + return &BuildableHTTPClient{} } // Do implements the HTTPClient interface's Do method to invoke a HTTP request, @@ -68,40 +64,25 @@ func NewBuildableHTTPClient() *BuildableHTTPClient { // Redirect (3xx) responses will not be followed, the HTTP response received // will returned instead. func (b *BuildableHTTPClient) Do(req *http.Request) (*http.Response, error) { - b.initOnce.Do(b.initClient) + b.initOnce.Do(b.build) return b.client.Do(req) } -func (b *BuildableHTTPClient) initClient() { - b.client = b.build() -} - -// BuildHTTPClient returns an initialized HTTPClient built from the options of -// the builder. -func (b BuildableHTTPClient) build() *http.Client { - var tr *http.Transport - if b.transport != nil { - tr = shallowCopyStruct(b.transport).(*http.Transport) - } else { - tr = defaultHTTPTransport() - } - - // TODO Any way to ensure HTTP 2 is supported without depending on - // an unversioned experimental package? - // Maybe only clients that depend on HTTP/2 should call this? - http2.ConfigureTransport(tr) - - return wrapWithoutRedirect(&http.Client{ +func (b *BuildableHTTPClient) build() { + b.client = wrapWithoutRedirect(&http.Client{ Timeout: b.clientTimeout, - Transport: tr, + Transport: b.GetTransport(), }) } -func (b BuildableHTTPClient) initReset() BuildableHTTPClient { - b.initOnce = new(sync.Once) - b.client = nil - return b +func (b *BuildableHTTPClient) clone() *BuildableHTTPClient { + cpy := NewBuildableHTTPClient() + cpy.transport = b.GetTransport() + cpy.dialer = b.GetDialer() + cpy.clientTimeout = b.clientTimeout + + return cpy } // WithTransportOptions copies the BuildableHTTPClient and returns it with the @@ -110,51 +91,49 @@ func (b BuildableHTTPClient) initReset() BuildableHTTPClient { // If a non (*http.Transport) was set as the round tripper, the round tripper // will be replaced with a default Transport value before invoking the option // functions. -func (b BuildableHTTPClient) WithTransportOptions(opts ...func(*http.Transport)) HTTPClient { - b = b.initReset() +func (b *BuildableHTTPClient) WithTransportOptions(opts ...func(*http.Transport)) HTTPClient { + cpy := b.clone() - tr := b.GetTransport() + tr := cpy.GetTransport() for _, opt := range opts { opt(tr) } - b.transport = tr + cpy.transport = tr - return &b + return cpy } // WithDialerOptions copies the BuildableHTTPClient and returns it with the // net.Dialer options applied. Will set the client's http.Transport DialContext // member. -func (b BuildableHTTPClient) WithDialerOptions(opts ...func(*net.Dialer)) HTTPClient { - b = b.initReset() +func (b *BuildableHTTPClient) WithDialerOptions(opts ...func(*net.Dialer)) HTTPClient { + cpy := b.clone() - dialer := b.GetDialer() + dialer := cpy.GetDialer() for _, opt := range opts { opt(dialer) } - b.dialer = dialer + cpy.dialer = dialer - tr := b.GetTransport() - tr.DialContext = b.dialer.DialContext - b.transport = tr + tr := cpy.GetTransport() + tr.DialContext = cpy.dialer.DialContext + cpy.transport = tr - return &b + return cpy } // WithTimeout Sets the timeout used by the client for all requests. -func (b BuildableHTTPClient) WithTimeout(timeout time.Duration) HTTPClient { - b = b.initReset() - - b.clientTimeout = timeout - - return &b +func (b *BuildableHTTPClient) WithTimeout(timeout time.Duration) HTTPClient { + cpy := b.clone() + cpy.clientTimeout = timeout + return cpy } // GetTransport returns a copy of the client's HTTP Transport. -func (b BuildableHTTPClient) GetTransport() *http.Transport { +func (b *BuildableHTTPClient) GetTransport() *http.Transport { var tr *http.Transport if b.transport != nil { - tr = shallowCopyStruct(b.transport).(*http.Transport) + tr = b.transport.Clone() } else { tr = defaultHTTPTransport() } @@ -163,7 +142,7 @@ func (b BuildableHTTPClient) GetTransport() *http.Transport { } // GetDialer returns a copy of the client's network dialer. -func (b BuildableHTTPClient) GetDialer() *net.Dialer { +func (b *BuildableHTTPClient) GetDialer() *net.Dialer { var dialer *net.Dialer if b.dialer != nil { dialer = shallowCopyStruct(b.dialer).(*net.Dialer) @@ -175,7 +154,7 @@ func (b BuildableHTTPClient) GetDialer() *net.Dialer { } // GetTimeout returns a copy of the client's timeout to cancel requests with. -func (b BuildableHTTPClient) GetTimeout() time.Duration { +func (b *BuildableHTTPClient) GetTimeout() time.Duration { return b.clientTimeout } @@ -198,6 +177,7 @@ func defaultHTTPTransport() *http.Transport { MaxIdleConnsPerHost: DefaultHTTPTransportMaxIdleConnsPerHost, IdleConnTimeout: DefaultHTTPTransportIdleConnTimeout, ExpectContinueTimeout: DefaultHTTPTransportExpectContinueTimeout, + ForceAttemptHTTP2: true, } return tr diff --git a/aws/http_client_test.go b/aws/http_client_test.go index 872493315dd..5846e5df8d7 100644 --- a/aws/http_client_test.go +++ b/aws/http_client_test.go @@ -3,6 +3,7 @@ package aws_test import ( "net/http" "net/http/httptest" + "sync" "testing" "time" @@ -43,3 +44,43 @@ func TestBuildableHTTPClient_WithTimeout(t *testing.T) { t.Errorf("expect %v timeout, got %v", e, a) } } + +func TestBuildableHTTPClient_concurrent(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + defer server.Close() + + var client aws.HTTPClient = aws.NewBuildableHTTPClient() + + atOnce := 100 + var wg sync.WaitGroup + wg.Add(atOnce) + for i := 0; i < atOnce; i++ { + go func(i int, client aws.HTTPClient) { + defer wg.Done() + + if v, ok := client.(interface{ GetTimeout() time.Duration }); ok { + v.GetTimeout() + } + + if i%3 == 0 { + if v, ok := client.(interface { + WithTransportOptions(opts ...func(*http.Transport)) aws.HTTPClient + }); ok { + client = v.WithTransportOptions() + } + } + + req, _ := http.NewRequest("GET", server.URL, nil) + resp, err := client.Do(req) + if err != nil { + t.Errorf("expect no error, got %v", err) + } + resp.Body.Close() + }(i, client) + } + + wg.Wait() +} diff --git a/service/s3/s3manager/upload_internal_test.go b/service/s3/s3manager/upload_internal_test.go index 630e24aa134..fbbbc910816 100644 --- a/service/s3/s3manager/upload_internal_test.go +++ b/service/s3/s3manager/upload_internal_test.go @@ -174,10 +174,9 @@ func TestUploadByteSlicePool_Failures(t *testing.T) { } if r.Operation.Name == operation { - r.Retryable = aws.Bool(false) r.Error = fmt.Errorf("request error") r.HTTPResponse = &http.Response{ - StatusCode: 500, + StatusCode: 400, Body: ioutil.NopCloser(bytes.NewReader([]byte{})), } return