From 7a7120ff8a2ed08856fd7d256e7b09a72c1cf2e0 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Thu, 16 Jan 2020 14:33:01 -0800 Subject: [PATCH] Some refactoring in azcore (#6982) Added DefaultRetryOptions() to create initialized default options. Removed Response.CheckStatusCode() as it can't create custom errors. --- sdk/azcore/policy_retry.go | 44 ++++++++++++--------------------- sdk/azcore/policy_retry_test.go | 22 +++++++++-------- sdk/azcore/response.go | 12 --------- sdk/azcore/response_test.go | 27 ++++++++------------ 4 files changed, 38 insertions(+), 67 deletions(-) diff --git a/sdk/azcore/policy_retry.go b/sdk/azcore/policy_retry.go index 2a9531aef5ba..7d48f98a42b4 100644 --- a/sdk/azcore/policy_retry.go +++ b/sdk/azcore/policy_retry.go @@ -59,33 +59,15 @@ var ( } ) -func (o RetryOptions) defaults() RetryOptions { - // We assume the following: - // 1. o.MaxTries >= 0 - // 2. o.TryTimeout, o.RetryDelay, and o.MaxRetryDelay >=0 - // 3. o.RetryDelay <= o.MaxRetryDelay - // 4. Both o.RetryDelay and o.MaxRetryDelay must be 0 or neither can be 0 - - if len(o.StatusCodes) == 0 { - o.StatusCodes = StatusCodesForRetry[:] +// DefaultRetryOptions returns an instance of RetryOptions initialized with default values. +func DefaultRetryOptions() RetryOptions { + return RetryOptions{ + StatusCodes: StatusCodesForRetry[:], + MaxTries: defaultMaxTries, + TryTimeout: 1 * time.Minute, + RetryDelay: 4 * time.Second, + MaxRetryDelay: 120 * time.Second, } - - IfDefault := func(current *time.Duration, desired time.Duration) { - if *current == time.Duration(0) { - *current = desired - } - } - - // Set defaults if unspecified - if o.MaxTries == 0 { - o.MaxTries = defaultMaxTries - } - - IfDefault(&o.TryTimeout, 1*time.Minute) - IfDefault(&o.RetryDelay, 4*time.Second) - IfDefault(&o.MaxRetryDelay, 120*time.Second) - - return o } func (o RetryOptions) calcDelay(try int32) time.Duration { // try is >=1; never 0 @@ -108,8 +90,14 @@ func (o RetryOptions) calcDelay(try int32) time.Duration { // try is >=1; never } // NewRetryPolicy creates a policy object configured using the specified options. -func NewRetryPolicy(o RetryOptions) Policy { - return &retryPolicy{options: o.defaults()} // Force defaults to be calculated +// Pass nil to accept the default values; this is the same as passing the result +// from a call to DefaultRetryOptions(). +func NewRetryPolicy(o *RetryOptions) Policy { + if o == nil { + def := DefaultRetryOptions() + o = &def + } + return &retryPolicy{options: *o} } type retryPolicy struct { diff --git a/sdk/azcore/policy_retry_test.go b/sdk/azcore/policy_retry_test.go index d096b12bd31b..6a3dc79197cf 100644 --- a/sdk/azcore/policy_retry_test.go +++ b/sdk/azcore/policy_retry_test.go @@ -17,13 +17,17 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) -const retryDelay = 20 * time.Millisecond +func testRetryOptions() *RetryOptions { + def := DefaultRetryOptions() + def.RetryDelay = 20 * time.Millisecond + return &def +} func TestRetryPolicySuccess(t *testing.T) { srv, close := mock.NewServer() defer close() srv.SetResponse(mock.WithStatusCode(http.StatusOK)) - pl := NewPipeline(srv, NewRetryPolicy(RetryOptions{})) + pl := NewPipeline(srv, NewRetryPolicy(nil)) req := NewRequest(http.MethodGet, srv.URL()) body := newRewindTrackingBody("stuff") req.SetBody(body) @@ -46,7 +50,7 @@ func TestRetryPolicyFailOnStatusCode(t *testing.T) { srv, close := mock.NewServer() defer close() srv.SetResponse(mock.WithStatusCode(http.StatusInternalServerError)) - pl := NewPipeline(srv, NewRetryPolicy(RetryOptions{RetryDelay: retryDelay})) + pl := NewPipeline(srv, NewRetryPolicy(testRetryOptions())) req := NewRequest(http.MethodGet, srv.URL()) body := newRewindTrackingBody("stuff") req.SetBody(body) @@ -74,7 +78,7 @@ func TestRetryPolicySuccessWithRetry(t *testing.T) { srv.AppendResponse(mock.WithStatusCode(http.StatusRequestTimeout)) srv.AppendResponse(mock.WithStatusCode(http.StatusInternalServerError)) srv.AppendResponse() - pl := NewPipeline(srv, NewRetryPolicy(RetryOptions{RetryDelay: retryDelay})) + pl := NewPipeline(srv, NewRetryPolicy(testRetryOptions())) req := NewRequest(http.MethodGet, srv.URL()) body := newRewindTrackingBody("stuff") req.SetBody(body) @@ -101,7 +105,7 @@ func TestRetryPolicyFailOnError(t *testing.T) { defer close() fakeErr := errors.New("bogus error") srv.SetError(fakeErr) - pl := NewPipeline(srv, NewRetryPolicy(RetryOptions{RetryDelay: retryDelay})) + pl := NewPipeline(srv, NewRetryPolicy(testRetryOptions())) req := NewRequest(http.MethodPost, srv.URL()) body := newRewindTrackingBody("stuff") req.SetBody(body) @@ -130,7 +134,7 @@ func TestRetryPolicySuccessWithRetryComplex(t *testing.T) { srv.AppendError(errors.New("bogus error")) srv.AppendResponse(mock.WithStatusCode(http.StatusInternalServerError)) srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - pl := NewPipeline(srv, NewRetryPolicy(RetryOptions{RetryDelay: retryDelay})) + pl := NewPipeline(srv, NewRetryPolicy(testRetryOptions())) req := NewRequest(http.MethodGet, srv.URL()) body := newRewindTrackingBody("stuff") req.SetBody(body) @@ -156,7 +160,7 @@ func TestRetryPolicyRequestTimedOut(t *testing.T) { srv, close := mock.NewServer() defer close() srv.SetError(errors.New("bogus error")) - pl := NewPipeline(srv, NewRetryPolicy(RetryOptions{})) + pl := NewPipeline(srv, NewRetryPolicy(nil)) req := NewRequest(http.MethodPost, srv.URL()) body := newRewindTrackingBody("stuff") req.SetBody(body) @@ -195,9 +199,7 @@ func TestRetryPolicyIsNotRetriable(t *testing.T) { defer close() srv.AppendResponse(mock.WithStatusCode(http.StatusRequestTimeout)) srv.AppendError(theErr) - pl := NewPipeline(srv, NewRetryPolicy(RetryOptions{ - RetryDelay: retryDelay, - })) + pl := NewPipeline(srv, NewRetryPolicy(testRetryOptions())) _, err := pl.Do(context.Background(), NewRequest(http.MethodGet, srv.URL())) if err == nil { t.Fatal("unexpected nil error") diff --git a/sdk/azcore/response.go b/sdk/azcore/response.go index 4776d29eeb0b..edb19ad244e3 100644 --- a/sdk/azcore/response.go +++ b/sdk/azcore/response.go @@ -36,18 +36,6 @@ func (r *Response) payload() []byte { return nil } -// CheckStatusCode returns a RequestError if the Response's status code isn't one of the specified values. -func (r *Response) CheckStatusCode(statusCodes ...int) error { - if !r.HasStatusCode(statusCodes...) { - msg := r.Status - if len(r.payload()) > 0 { - msg = string(r.payload()) - } - return newRequestError(msg, r) - } - return nil -} - // HasStatusCode returns true if the Response's status code is one of the specified values. func (r *Response) HasStatusCode(statusCodes ...int) bool { if r == nil { diff --git a/sdk/azcore/response_test.go b/sdk/azcore/response_test.go index e8fca4b4ce2f..263f0c3db674 100644 --- a/sdk/azcore/response_test.go +++ b/sdk/azcore/response_test.go @@ -23,8 +23,8 @@ func TestResponseUnmarshalXML(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if err := resp.CheckStatusCode(http.StatusOK); err != nil { - t.Fatalf("unexpected status code error: %v", err) + if !resp.HasStatusCode(http.StatusOK) { + t.Fatalf("unexpected status code: %d", resp.StatusCode) } var tx testXML if err := resp.UnmarshalAsXML(&tx); err != nil { @@ -44,15 +44,8 @@ func TestResponseFailureStatusCode(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if err = resp.CheckStatusCode(http.StatusOK); err == nil { - t.Fatal("unexpected nil status code error") - } - re, ok := err.(RequestError) - if !ok { - t.Fatal("expected RequestError type") - } - if re.Response().StatusCode != http.StatusForbidden { - t.Fatal("unexpected response") + if resp.HasStatusCode(http.StatusOK) { + t.Fatalf("unexpected status code: %d", resp.StatusCode) } } @@ -65,8 +58,8 @@ func TestResponseUnmarshalJSON(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if err := resp.CheckStatusCode(http.StatusOK); err != nil { - t.Fatalf("unexpected status code error: %v", err) + if !resp.HasStatusCode(http.StatusOK) { + t.Fatalf("unexpected status code: %d", resp.StatusCode) } var tx testJSON if err := resp.UnmarshalAsJSON(&tx); err != nil { @@ -86,8 +79,8 @@ func TestResponseUnmarshalJSONNoBody(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if err := resp.CheckStatusCode(http.StatusOK); err != nil { - t.Fatalf("unexpected status code error: %v", err) + if !resp.HasStatusCode(http.StatusOK) { + t.Fatalf("unexpected status code: %d", resp.StatusCode) } if err := resp.UnmarshalAsJSON(nil); err != nil { t.Fatalf("unexpected error unmarshalling: %v", err) @@ -103,8 +96,8 @@ func TestResponseUnmarshalXMLNoBody(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if err := resp.CheckStatusCode(http.StatusOK); err != nil { - t.Fatalf("unexpected status code error: %v", err) + if !resp.HasStatusCode(http.StatusOK) { + t.Fatalf("unexpected status code: %d", resp.StatusCode) } if err := resp.UnmarshalAsXML(nil); err != nil { t.Fatalf("unexpected error unmarshalling: %v", err)