Skip to content

Commit

Permalink
azcore cleanup (#12644)
Browse files Browse the repository at this point in the history
Refactored various policy options to conform with our options pattern.
Removed some content that was specific to storage.
Made some content internal as it's not needed for consumers.
  • Loading branch information
jhendrixMSFT committed Oct 8, 2020
1 parent 81b71aa commit 7f2ad15
Show file tree
Hide file tree
Showing 12 changed files with 69 additions and 165 deletions.
8 changes: 5 additions & 3 deletions sdk/azcore/policy_http_header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ func TestAddCustomHTTPHeaderSuccess(t *testing.T) {
func TestAddCustomHTTPHeaderFail(t *testing.T) {
srv, close := mock.NewServer()
defer close()
const customHeader = "custom-header"
const customValue = "custom-value"
srv.AppendResponse(mock.WithPredicate(func(r *http.Request) bool {
return r.Header.Get(xMsClientRequestID) == customValue
return r.Header.Get(customHeader) == customValue
}), mock.WithStatusCode(http.StatusOK))
srv.AppendResponse(mock.WithStatusCode(http.StatusBadRequest))
// HTTP header policy is automatically added during pipeline construction
Expand All @@ -70,16 +71,17 @@ func TestAddCustomHTTPHeaderFail(t *testing.T) {
func TestAddCustomHTTPHeaderOverwrite(t *testing.T) {
srv, close := mock.NewServer()
defer close()
const customHeader = "custom-header"
const customValue = "custom-value"
srv.AppendResponse(mock.WithPredicate(func(r *http.Request) bool {
return r.Header.Get(xMsClientRequestID) == customValue
return r.Header.Get(customHeader) == customValue
}), mock.WithStatusCode(http.StatusOK))
srv.AppendResponse(mock.WithStatusCode(http.StatusBadRequest))
// HTTP header policy is automatically added during pipeline construction
pl := NewPipeline(srv)
// overwrite the request ID with our own value
req, err := NewRequest(WithHTTPHeader(context.Background(), http.Header{
xMsClientRequestID: []string{customValue},
customHeader: []string{customValue},
}), http.MethodGet, srv.URL())
if err != nil {
t.Fatalf("unexpected error: %v", err)
Expand Down
63 changes: 21 additions & 42 deletions sdk/azcore/policy_logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,34 @@ package azcore
import (
"bytes"
"fmt"
"net/url"
"strings"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/internal/runtime"
)

// RequestLogOptions configures the logging policy's behavior.
type RequestLogOptions struct {
// LogOptions configures the logging policy's behavior.
type LogOptions struct {
// placeholder for future configuration options
}

type requestLogPolicy struct {
options RequestLogOptions
// DefaultLogOptions returns an instance of LogOptions initialized with default values.
func DefaultLogOptions() LogOptions {
return LogOptions{}
}

// NewRequestLogPolicy creates a RequestLogPolicy object configured using the specified options.
func NewRequestLogPolicy(o *RequestLogOptions) Policy {
return &requestLogPolicy{}
type logPolicy struct {
options LogOptions
}

// NewLogPolicy creates a RequestLogPolicy object configured using the specified options.
// Pass nil to accept the default values; this is the same as passing the result
// from a call to DefaultLogOptions().
func NewLogPolicy(o *LogOptions) Policy {
if o == nil {
def := DefaultLogOptions()
o = &def
}
return &logPolicy{options: *o}
}

// logPolicyOpValues is the struct containing the per-operation values
Expand All @@ -35,7 +44,7 @@ type logPolicyOpValues struct {
start time.Time
}

func (p *requestLogPolicy) Do(req *Request) (*Response, error) {
func (p *logPolicy) Do(req *Request) (*Response, error) {
// Get the per-operation values. These are saved in the Message's map so that they persist across each retry calling into this policy object.
var opValues logPolicyOpValues
if req.OperationValue(&opValues); opValues.start.IsZero() {
Expand All @@ -48,7 +57,7 @@ func (p *requestLogPolicy) Do(req *Request) (*Response, error) {
if Log().Should(LogRequest) {
b := &bytes.Buffer{}
fmt.Fprintf(b, "==> OUTGOING REQUEST (Try=%d)\n", opValues.try)
WriteRequestWithResponse(b, prepareRequestForLogging(req), nil, nil)
writeRequestWithResponse(b, req, nil, nil)
Log().Write(LogRequest, b.String())
}

Expand All @@ -69,7 +78,7 @@ func (p *requestLogPolicy) Do(req *Request) (*Response, error) {
fmt.Fprint(b, "RESPONSE RECEIVED\n")
}

WriteRequestWithResponse(b, prepareRequestForLogging(req), response, err)
writeRequestWithResponse(b, req, response, err)
if err != nil {
// skip frames runtime.Callers() and runtime.StackTrace()
b.WriteString(runtime.StackTrace(2, StackFrameCount))
Expand All @@ -78,33 +87,3 @@ func (p *requestLogPolicy) Do(req *Request) (*Response, error) {
}
return response, err
}

// RedactSigQueryParam redacts the 'sig' query parameter in URL's raw query to protect secret.
func RedactSigQueryParam(rawQuery string) (bool, string) {
rawQuery = strings.ToLower(rawQuery) // lowercase the string so we can look for ?sig= and &sig=
sigFound := strings.Contains(rawQuery, "?sig=")
if !sigFound {
sigFound = strings.Contains(rawQuery, "&sig=")
if !sigFound {
return sigFound, rawQuery // [?|&]sig= not found; return same rawQuery passed in (no memory allocation)
}
}
// [?|&]sig= found, redact its value
values, _ := url.ParseQuery(rawQuery)
for name := range values {
if strings.EqualFold(name, "sig") {
values[name] = []string{"REDACTED"}
}
}
return sigFound, values.Encode()
}

func prepareRequestForLogging(req *Request) *Request {
request := req
if sigFound, rawQuery := RedactSigQueryParam(request.URL.RawQuery); sigFound {
// Make copy so we don't destroy the query parameters we actually need to send in the request
request = req.copy()
request.URL.RawQuery = rawQuery
}
return request
}
7 changes: 2 additions & 5 deletions sdk/azcore/policy_logging_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func TestPolicyLoggingSuccess(t *testing.T) {
srv, close := mock.NewServer()
defer close()
srv.SetResponse()
pl := NewPipeline(srv, NewRequestLogPolicy(nil))
pl := NewPipeline(srv, NewLogPolicy(nil))
req, err := NewRequest(context.Background(), http.MethodGet, srv.URL())
if err != nil {
t.Fatalf("unexpected error: %v", err)
Expand All @@ -43,9 +43,6 @@ func TestPolicyLoggingSuccess(t *testing.T) {
// Request ==> OUTGOING REQUEST (Try=1)
// GET http://127.0.0.1:49475?one=fish&sig=REDACTED
// (no headers)
if !strings.Contains(logReq, "sig=REDACTED") {
t.Fatal("missing redacted sig query param")
}
if !strings.Contains(logReq, "(no headers)") {
t.Fatal("missing (no headers)")
}
Expand Down Expand Up @@ -76,7 +73,7 @@ func TestPolicyLoggingError(t *testing.T) {
srv, close := mock.NewServer()
defer close()
srv.SetError(errors.New("bogus error"))
pl := NewPipeline(srv, NewRequestLogPolicy(nil))
pl := NewPipeline(srv, NewLogPolicy(nil))
req, err := NewRequest(context.Background(), http.MethodGet, srv.URL())
if err != nil {
t.Fatalf("unexpected error: %v", err)
Expand Down
13 changes: 12 additions & 1 deletion sdk/azcore/policy_telemetry.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,24 @@ type TelemetryOptions struct {
Disabled bool
}

// DefaultTelemetryOptions returns an instance of TelemetryOptions initialized with default values.
func DefaultTelemetryOptions() TelemetryOptions {
return TelemetryOptions{}
}

type telemetryPolicy struct {
telemetryValue string
}

// NewTelemetryPolicy creates a telemetry policy object that adds telemetry information to outgoing HTTP requests.
// The format is [<application_id> ]azsdk-<sdk_language>-<package_name>/<package_version> <platform_info> [<custom>].
func NewTelemetryPolicy(o TelemetryOptions) Policy {
// Pass nil to accept the default values; this is the same as passing the result
// from a call to DefaultTelemetryOptions().
func NewTelemetryPolicy(o *TelemetryOptions) Policy {
if o == nil {
def := DefaultTelemetryOptions()
o = &def
}
tp := telemetryPolicy{}
if o.Disabled {
return &tp
Expand Down
25 changes: 18 additions & 7 deletions sdk/azcore/policy_telemetry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func TestPolicyTelemetryDefault(t *testing.T) {
srv, close := mock.NewServer()
defer close()
srv.SetResponse()
pl := NewPipeline(srv, NewTelemetryPolicy(TelemetryOptions{}))
pl := NewPipeline(srv, NewTelemetryPolicy(nil))
req, err := NewRequest(context.Background(), http.MethodGet, srv.URL())
if err != nil {
t.Fatalf("unexpected error: %v", err)
Expand All @@ -37,7 +37,9 @@ func TestPolicyTelemetryWithCustomInfo(t *testing.T) {
defer close()
srv.SetResponse()
const testValue = "azcore_test"
pl := NewPipeline(srv, NewTelemetryPolicy(TelemetryOptions{Value: testValue}))
o := DefaultTelemetryOptions()
o.Value = testValue
pl := NewPipeline(srv, NewTelemetryPolicy(&o))
req, err := NewRequest(context.Background(), http.MethodGet, srv.URL())
if err != nil {
t.Fatalf("unexpected error: %v", err)
Expand All @@ -55,7 +57,7 @@ func TestPolicyTelemetryPreserveExisting(t *testing.T) {
srv, close := mock.NewServer()
defer close()
srv.SetResponse()
pl := NewPipeline(srv, NewTelemetryPolicy(TelemetryOptions{}))
pl := NewPipeline(srv, NewTelemetryPolicy(nil))
req, err := NewRequest(context.Background(), http.MethodGet, srv.URL())
if err != nil {
t.Fatalf("unexpected error: %v", err)
Expand All @@ -76,7 +78,9 @@ func TestPolicyTelemetryWithAppID(t *testing.T) {
defer close()
srv.SetResponse()
const appID = "my_application"
pl := NewPipeline(srv, NewTelemetryPolicy(TelemetryOptions{ApplicationID: appID}))
o := DefaultTelemetryOptions()
o.ApplicationID = appID
pl := NewPipeline(srv, NewTelemetryPolicy(&o))
req, err := NewRequest(context.Background(), http.MethodGet, srv.URL())
if err != nil {
t.Fatalf("unexpected error: %v", err)
Expand All @@ -95,7 +99,9 @@ func TestPolicyTelemetryWithAppIDSanitized(t *testing.T) {
defer close()
srv.SetResponse()
const appID = "This will get the spaces removed and truncated."
pl := NewPipeline(srv, NewTelemetryPolicy(TelemetryOptions{ApplicationID: appID}))
o := DefaultTelemetryOptions()
o.ApplicationID = appID
pl := NewPipeline(srv, NewTelemetryPolicy(&o))
req, err := NewRequest(context.Background(), http.MethodGet, srv.URL())
if err != nil {
t.Fatalf("unexpected error: %v", err)
Expand All @@ -115,7 +121,9 @@ func TestPolicyTelemetryPreserveExistingWithAppID(t *testing.T) {
defer close()
srv.SetResponse()
const appID = "my_application"
pl := NewPipeline(srv, NewTelemetryPolicy(TelemetryOptions{ApplicationID: appID}))
o := DefaultTelemetryOptions()
o.ApplicationID = appID
pl := NewPipeline(srv, NewTelemetryPolicy(&o))
req, err := NewRequest(context.Background(), http.MethodGet, srv.URL())
if err != nil {
t.Fatalf("unexpected error: %v", err)
Expand All @@ -136,7 +144,10 @@ func TestPolicyTelemetryDisabled(t *testing.T) {
defer close()
srv.SetResponse()
const appID = "my_application"
pl := NewPipeline(srv, NewTelemetryPolicy(TelemetryOptions{ApplicationID: appID, Disabled: true}))
o := DefaultTelemetryOptions()
o.ApplicationID = appID
o.Disabled = true
pl := NewPipeline(srv, NewTelemetryPolicy(&o))
req, err := NewRequest(context.Background(), http.MethodGet, srv.URL())
if err != nil {
t.Fatalf("unexpected error: %v", err)
Expand Down
24 changes: 0 additions & 24 deletions sdk/azcore/policy_unique_request_id.go

This file was deleted.

52 changes: 0 additions & 52 deletions sdk/azcore/policy_unique_request_id_test.go

This file was deleted.

2 changes: 1 addition & 1 deletion sdk/azcore/progress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func TestProgressReporting(t *testing.T) {
srv, close := mock.NewServer()
defer close()
srv.SetResponse(mock.WithBody(content))
pl := NewPipeline(srv, NewTelemetryPolicy(TelemetryOptions{}))
pl := NewPipeline(srv)
req, err := NewRequest(context.Background(), http.MethodGet, srv.URL())
if err != nil {
t.Fatalf("unexpected error: %v", err)
Expand Down
20 changes: 0 additions & 20 deletions sdk/azcore/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,26 +211,6 @@ func (req *Request) Close() error {
return req.Body.Close()
}

// copy returns a shallow copy of the request
func (req *Request) copy() *Request {
clonedURL := *req.URL
// Copy the values and immutable references
return &Request{
Request: &http.Request{
Method: req.Method,
URL: &clonedURL,
Proto: req.Proto,
ProtoMajor: req.ProtoMajor,
ProtoMinor: req.ProtoMinor,
Header: req.Header.Clone(),
Host: req.URL.Host,
Body: req.Body, // shallow copy
ContentLength: req.ContentLength,
GetBody: req.GetBody,
},
}
}

// clone returns a deep copy of the request with its context changed to ctx
func (req *Request) clone(ctx context.Context) *Request {
r2 := Request{}
Expand Down
4 changes: 2 additions & 2 deletions sdk/azcore/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,9 @@ func RetryAfter(resp *http.Response) time.Duration {
return 0
}

// WriteRequestWithResponse appends a formatted HTTP request into a Buffer. If request and/or err are
// writeRequestWithResponse appends a formatted HTTP request into a Buffer. If request and/or err are
// not nil, then these are also written into the Buffer.
func WriteRequestWithResponse(b *bytes.Buffer, request *Request, response *Response, err error) {
func writeRequestWithResponse(b *bytes.Buffer, request *Request, response *Response, err error) {
// Write the request into the buffer.
fmt.Fprint(b, " "+request.Method+" "+request.URL.String()+"\n")
writeHeader(b, request.Header)
Expand Down

0 comments on commit 7f2ad15

Please sign in to comment.