diff --git a/CHANGELOG_PENDING.md b/CHANGELOG_PENDING.md index 8a1927a39c..d39ebb6664 100644 --- a/CHANGELOG_PENDING.md +++ b/CHANGELOG_PENDING.md @@ -1,5 +1,6 @@ ### SDK Features ### SDK Enhancements +`aws/client`: Adding status code 429 to throttlable status codes in default retryer (#1621) ### SDK Bugs diff --git a/aws/client/default_retryer.go b/aws/client/default_retryer.go index e25a460fba..c31cb395b0 100644 --- a/aws/client/default_retryer.go +++ b/aws/client/default_retryer.go @@ -2,6 +2,7 @@ package client import ( "math/rand" + "strconv" "sync" "time" @@ -38,6 +39,10 @@ func (d DefaultRetryer) RetryRules(r *request.Request) time.Duration { minTime := 30 throttle := d.shouldThrottle(r) if throttle { + if delay, ok := getRetryDelay(r); ok { + return delay + } + minTime = 500 } @@ -68,12 +73,49 @@ func (d DefaultRetryer) ShouldRetry(r *request.Request) bool { // ShouldThrottle returns true if the request should be throttled. func (d DefaultRetryer) shouldThrottle(r *request.Request) bool { - if r.HTTPResponse.StatusCode == 502 || - r.HTTPResponse.StatusCode == 503 || - r.HTTPResponse.StatusCode == 504 { - return true + switch r.HTTPResponse.StatusCode { + case 429: + case 502: + case 503: + case 504: + default: + return r.IsErrorThrottle() + } + + return true +} + +// This will look in the Retry-After header, RFC 7231, for how long +// it will wait before attempting another request +func getRetryDelay(r *request.Request) (time.Duration, bool) { + if !canUseRetryAfterHeader(r) { + return 0, false + } + + delayStr := r.HTTPResponse.Header.Get("Retry-After") + if len(delayStr) == 0 { + return 0, false } - return r.IsErrorThrottle() + + delay, err := strconv.Atoi(delayStr) + if err != nil { + return 0, false + } + + return time.Duration(delay) * time.Second, true +} + +// Will look at the status code to see if the retry header pertains to +// the status code. +func canUseRetryAfterHeader(r *request.Request) bool { + switch r.HTTPResponse.StatusCode { + case 429: + case 503: + default: + return false + } + + return true } // lockedSource is a thread-safe implementation of rand.Source diff --git a/aws/client/default_retryer_test.go b/aws/client/default_retryer_test.go new file mode 100644 index 0000000000..4868088673 --- /dev/null +++ b/aws/client/default_retryer_test.go @@ -0,0 +1,166 @@ +package client + +import ( + "net/http" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws/request" +) + +func TestRetryThrottleStatusCodes(t *testing.T) { + cases := []struct { + expectThrottle bool + expectRetry bool + r request.Request + }{ + { + false, + false, + request.Request{ + HTTPResponse: &http.Response{StatusCode: 200}, + }, + }, + { + true, + true, + request.Request{ + HTTPResponse: &http.Response{StatusCode: 429}, + }, + }, + { + true, + true, + request.Request{ + HTTPResponse: &http.Response{StatusCode: 502}, + }, + }, + { + true, + true, + request.Request{ + HTTPResponse: &http.Response{StatusCode: 503}, + }, + }, + { + true, + true, + request.Request{ + HTTPResponse: &http.Response{StatusCode: 504}, + }, + }, + { + false, + true, + request.Request{ + HTTPResponse: &http.Response{StatusCode: 500}, + }, + }, + } + + d := DefaultRetryer{NumMaxRetries: 10} + for i, c := range cases { + throttle := d.shouldThrottle(&c.r) + retry := d.ShouldRetry(&c.r) + + if e, a := c.expectThrottle, throttle; e != a { + t.Errorf("%d: expected %v, but received %v", i, e, a) + } + + if e, a := c.expectRetry, retry; e != a { + t.Errorf("%d: expected %v, but received %v", i, e, a) + } + } +} + +func TestCanUseRetryAfter(t *testing.T) { + cases := []struct { + r request.Request + e bool + }{ + { + request.Request{ + HTTPResponse: &http.Response{StatusCode: 200}, + }, + false, + }, + { + request.Request{ + HTTPResponse: &http.Response{StatusCode: 500}, + }, + false, + }, + { + request.Request{ + HTTPResponse: &http.Response{StatusCode: 429}, + }, + true, + }, + { + request.Request{ + HTTPResponse: &http.Response{StatusCode: 503}, + }, + true, + }, + } + + for i, c := range cases { + a := canUseRetryAfterHeader(&c.r) + if c.e != a { + t.Errorf("%d: expected %v, but received %v", i, c.e, a) + } + } +} + +func TestGetRetryDelay(t *testing.T) { + cases := []struct { + r request.Request + e time.Duration + equal bool + ok bool + }{ + { + request.Request{ + HTTPResponse: &http.Response{StatusCode: 429, Header: http.Header{"Retry-After": []string{"3600"}}}, + }, + 3600 * time.Second, + true, + true, + }, + { + request.Request{ + HTTPResponse: &http.Response{StatusCode: 503, Header: http.Header{"Retry-After": []string{"120"}}}, + }, + 120 * time.Second, + true, + true, + }, + { + request.Request{ + HTTPResponse: &http.Response{StatusCode: 503, Header: http.Header{"Retry-After": []string{"120"}}}, + }, + 1 * time.Second, + false, + true, + }, + { + request.Request{ + HTTPResponse: &http.Response{StatusCode: 503, Header: http.Header{"Retry-After": []string{""}}}, + }, + 0 * time.Second, + true, + false, + }, + } + + for i, c := range cases { + a, ok := getRetryDelay(&c.r) + if c.ok != ok { + t.Errorf("%d: expected %v, but received %v", i, c.ok, ok) + } + + if (c.e != a) == c.equal { + t.Errorf("%d: expected %v, but received %v", i, c.e, a) + } + } +} diff --git a/aws/request/request_test.go b/aws/request/request_test.go index 4d9258c0a6..e406255ee7 100644 --- a/aws/request/request_test.go +++ b/aws/request/request_test.go @@ -112,7 +112,8 @@ func TestRequestRecoverRetry4xxRetryable(t *testing.T) { reqNum := 0 reqs := []http.Response{ {StatusCode: 400, Body: body(`{"__type":"Throttling","message":"Rate exceeded."}`)}, - {StatusCode: 429, Body: body(`{"__type":"ProvisionedThroughputExceededException","message":"Rate exceeded."}`)}, + {StatusCode: 400, Body: body(`{"__type":"ProvisionedThroughputExceededException","message":"Rate exceeded."}`)}, + {StatusCode: 429, Body: body(`{"__type":"FooException","message":"Rate exceeded."}`)}, {StatusCode: 200, Body: body(`{"data":"valid"}`)}, } @@ -131,7 +132,7 @@ func TestRequestRecoverRetry4xxRetryable(t *testing.T) { if err != nil { t.Fatalf("expect no error, but got %v", err) } - if e, a := 2, int(r.RetryCount); e != a { + if e, a := 3, int(r.RetryCount); e != a { t.Errorf("expect %d retry count, got %d", e, a) } if e, a := "valid", out.Data; e != a {