Skip to content

Commit

Permalink
aws: Fix BuildableHTTPClient datarace bug (#504)
Browse files Browse the repository at this point in the history
Fixes a broken unit test missed when implementing #487.
  • Loading branch information
jasdel committed Mar 17, 2020
1 parent 0f1fe1b commit 7e732f1
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 57 deletions.
90 changes: 35 additions & 55 deletions aws/http_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import (
"reflect"
"sync"
"time"

"golang.org/x/net/http2"
)

// Defaults for the HTTPTransportBuilder.
Expand Down Expand Up @@ -43,7 +41,7 @@ type BuildableHTTPClient struct {
transport *http.Transport
dialer *net.Dialer

initOnce *sync.Once
initOnce sync.Once

clientTimeout time.Duration
client *http.Client
Expand All @@ -52,9 +50,7 @@ type BuildableHTTPClient struct {
// NewBuildableHTTPClient returns an initialized client for invoking HTTP
// requests.
func NewBuildableHTTPClient() *BuildableHTTPClient {
return &BuildableHTTPClient{
initOnce: new(sync.Once),
}
return &BuildableHTTPClient{}
}

// Do implements the HTTPClient interface's Do method to invoke a HTTP request,
Expand All @@ -68,40 +64,25 @@ func NewBuildableHTTPClient() *BuildableHTTPClient {
// Redirect (3xx) responses will not be followed, the HTTP response received
// will returned instead.
func (b *BuildableHTTPClient) Do(req *http.Request) (*http.Response, error) {
b.initOnce.Do(b.initClient)
b.initOnce.Do(b.build)

return b.client.Do(req)
}

func (b *BuildableHTTPClient) initClient() {
b.client = b.build()
}

// BuildHTTPClient returns an initialized HTTPClient built from the options of
// the builder.
func (b BuildableHTTPClient) build() *http.Client {
var tr *http.Transport
if b.transport != nil {
tr = shallowCopyStruct(b.transport).(*http.Transport)
} else {
tr = defaultHTTPTransport()
}

// TODO Any way to ensure HTTP 2 is supported without depending on
// an unversioned experimental package?
// Maybe only clients that depend on HTTP/2 should call this?
http2.ConfigureTransport(tr)

return wrapWithoutRedirect(&http.Client{
func (b *BuildableHTTPClient) build() {
b.client = wrapWithoutRedirect(&http.Client{
Timeout: b.clientTimeout,
Transport: tr,
Transport: b.GetTransport(),
})
}

func (b BuildableHTTPClient) initReset() BuildableHTTPClient {
b.initOnce = new(sync.Once)
b.client = nil
return b
func (b *BuildableHTTPClient) clone() *BuildableHTTPClient {
cpy := NewBuildableHTTPClient()
cpy.transport = b.GetTransport()
cpy.dialer = b.GetDialer()
cpy.clientTimeout = b.clientTimeout

return cpy
}

// WithTransportOptions copies the BuildableHTTPClient and returns it with the
Expand All @@ -110,51 +91,49 @@ func (b BuildableHTTPClient) initReset() BuildableHTTPClient {
// If a non (*http.Transport) was set as the round tripper, the round tripper
// will be replaced with a default Transport value before invoking the option
// functions.
func (b BuildableHTTPClient) WithTransportOptions(opts ...func(*http.Transport)) HTTPClient {
b = b.initReset()
func (b *BuildableHTTPClient) WithTransportOptions(opts ...func(*http.Transport)) HTTPClient {
cpy := b.clone()

tr := b.GetTransport()
tr := cpy.GetTransport()
for _, opt := range opts {
opt(tr)
}
b.transport = tr
cpy.transport = tr

return &b
return cpy
}

// WithDialerOptions copies the BuildableHTTPClient and returns it with the
// net.Dialer options applied. Will set the client's http.Transport DialContext
// member.
func (b BuildableHTTPClient) WithDialerOptions(opts ...func(*net.Dialer)) HTTPClient {
b = b.initReset()
func (b *BuildableHTTPClient) WithDialerOptions(opts ...func(*net.Dialer)) HTTPClient {
cpy := b.clone()

dialer := b.GetDialer()
dialer := cpy.GetDialer()
for _, opt := range opts {
opt(dialer)
}
b.dialer = dialer
cpy.dialer = dialer

tr := b.GetTransport()
tr.DialContext = b.dialer.DialContext
b.transport = tr
tr := cpy.GetTransport()
tr.DialContext = cpy.dialer.DialContext
cpy.transport = tr

return &b
return cpy
}

// WithTimeout Sets the timeout used by the client for all requests.
func (b BuildableHTTPClient) WithTimeout(timeout time.Duration) HTTPClient {
b = b.initReset()

b.clientTimeout = timeout

return &b
func (b *BuildableHTTPClient) WithTimeout(timeout time.Duration) HTTPClient {
cpy := b.clone()
cpy.clientTimeout = timeout
return cpy
}

// GetTransport returns a copy of the client's HTTP Transport.
func (b BuildableHTTPClient) GetTransport() *http.Transport {
func (b *BuildableHTTPClient) GetTransport() *http.Transport {
var tr *http.Transport
if b.transport != nil {
tr = shallowCopyStruct(b.transport).(*http.Transport)
tr = b.transport.Clone()
} else {
tr = defaultHTTPTransport()
}
Expand All @@ -163,7 +142,7 @@ func (b BuildableHTTPClient) GetTransport() *http.Transport {
}

// GetDialer returns a copy of the client's network dialer.
func (b BuildableHTTPClient) GetDialer() *net.Dialer {
func (b *BuildableHTTPClient) GetDialer() *net.Dialer {
var dialer *net.Dialer
if b.dialer != nil {
dialer = shallowCopyStruct(b.dialer).(*net.Dialer)
Expand All @@ -175,7 +154,7 @@ func (b BuildableHTTPClient) GetDialer() *net.Dialer {
}

// GetTimeout returns a copy of the client's timeout to cancel requests with.
func (b BuildableHTTPClient) GetTimeout() time.Duration {
func (b *BuildableHTTPClient) GetTimeout() time.Duration {
return b.clientTimeout
}

Expand All @@ -198,6 +177,7 @@ func defaultHTTPTransport() *http.Transport {
MaxIdleConnsPerHost: DefaultHTTPTransportMaxIdleConnsPerHost,
IdleConnTimeout: DefaultHTTPTransportIdleConnTimeout,
ExpectContinueTimeout: DefaultHTTPTransportExpectContinueTimeout,
ForceAttemptHTTP2: true,
}

return tr
Expand Down
41 changes: 41 additions & 0 deletions aws/http_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package aws_test
import (
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -43,3 +44,43 @@ func TestBuildableHTTPClient_WithTimeout(t *testing.T) {
t.Errorf("expect %v timeout, got %v", e, a)
}
}

func TestBuildableHTTPClient_concurrent(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}))
defer server.Close()

var client aws.HTTPClient = aws.NewBuildableHTTPClient()

atOnce := 100
var wg sync.WaitGroup
wg.Add(atOnce)
for i := 0; i < atOnce; i++ {
go func(i int, client aws.HTTPClient) {
defer wg.Done()

if v, ok := client.(interface{ GetTimeout() time.Duration }); ok {
v.GetTimeout()
}

if i%3 == 0 {
if v, ok := client.(interface {
WithTransportOptions(opts ...func(*http.Transport)) aws.HTTPClient
}); ok {
client = v.WithTransportOptions()
}
}

req, _ := http.NewRequest("GET", server.URL, nil)
resp, err := client.Do(req)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
resp.Body.Close()
}(i, client)
}

wg.Wait()
}
3 changes: 1 addition & 2 deletions service/s3/s3manager/upload_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,9 @@ func TestUploadByteSlicePool_Failures(t *testing.T) {
}

if r.Operation.Name == operation {
r.Retryable = aws.Bool(false)
r.Error = fmt.Errorf("request error")
r.HTTPResponse = &http.Response{
StatusCode: 500,
StatusCode: 400,
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
return
Expand Down

0 comments on commit 7e732f1

Please sign in to comment.