Skip to content

Commit

Permalink
Add support for marshalling/unmarshalling JSON (#6969)
Browse files Browse the repository at this point in the history
* Add support for marshalling/unmarshalling JSON

Removed Response.Payload field, replacing it with an internal
implementation nopClosingBytesReader.

* exit early when unmarshalling if there's no payload
  • Loading branch information
jhendrixMSFT committed Jan 16, 2020
1 parent d154b39 commit dfc6a17
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 16 deletions.
37 changes: 36 additions & 1 deletion sdk/azcore/policy_body_download.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package azcore
import (
"context"
"fmt"
"io"
"io/ioutil"
)

Expand All @@ -22,11 +23,12 @@ func newBodyDownloadPolicy() Policy {
if req.OperationValue(&opValues); !opValues.skip && resp.Body != nil {
// Either bodyDownloadPolicyOpValues was not specified (so skip is false)
// or it was specified and skip is false: don't skip downloading the body
resp.Payload, err = ioutil.ReadAll(resp.Body)
b, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
err = fmt.Errorf("body download policy: %w", err)
}
resp.Body = &nopClosingBytesReader{s: b}
}
return resp, err
})
Expand All @@ -36,3 +38,36 @@ func newBodyDownloadPolicy() Policy {
type bodyDownloadPolicyOpValues struct {
skip bool
}

// nopClosingBytesReader is an io.ReadCloser around a byte slice.
// It also provides direct access to the byte slice.
type nopClosingBytesReader struct {
s []byte
i int64
}

// Bytes returns the underlying byte slice.
func (r *nopClosingBytesReader) Bytes() []byte {
return r.s
}

// Close implements the io.Closer interface.
func (*nopClosingBytesReader) Close() error {
return nil
}

// Read implements the io.Reader interface.
func (r *nopClosingBytesReader) Read(b []byte) (n int, err error) {
if r.i >= int64(len(r.s)) {
return 0, io.EOF
}
n = copy(b, r.s[r.i:])
r.i += int64(n)
return
}

// Set replaces the existing byte slice with the specified byte slice and resets the reader.
func (r *nopClosingBytesReader) Set(b []byte) {
r.s = b
r.i = 0
}
10 changes: 5 additions & 5 deletions sdk/azcore/policy_body_download_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ func TestDownloadBody(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(resp.Payload) == 0 {
if len(resp.payload()) == 0 {
t.Fatal("missing payload")
}
if string(resp.Payload) != message {
t.Fatalf("unexpected response: %s", string(resp.Payload))
if string(resp.payload()) != message {
t.Fatalf("unexpected response: %s", string(resp.payload()))
}
}

Expand All @@ -45,7 +45,7 @@ func TestSkipBodyDownload(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(resp.Payload) > 0 {
t.Fatalf("unexpected download: %s", string(resp.Payload))
if len(resp.payload()) > 0 {
t.Fatalf("unexpected download: %s", string(resp.payload()))
}
}
15 changes: 14 additions & 1 deletion sdk/azcore/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package azcore
import (
"bytes"
"context"
"encoding/json"
"encoding/xml"
"fmt"
"io"
Expand All @@ -18,7 +19,8 @@ import (
)

const (
contentTypeAppXML = "application/xml"
contentTypeAppJSON = "application/json"
contentTypeAppXML = "application/xml"
)

// Request is an abstraction over the creation of an HTTP request as it passes through the pipeline.
Expand Down Expand Up @@ -85,6 +87,17 @@ func (req *Request) Next(ctx context.Context) (*Response, error) {
return nextPolicy.Do(ctx, &nextReq)
}

// MarshalAsJSON calls json.Marshal() to get the JSON encoding of v then calls SetBody.
// If json.Marshal fails a MarshalError is returned. Any error from SetBody is returned.
func (req *Request) MarshalAsJSON(v interface{}) error {
b, err := json.Marshal(v)
if err != nil {
return fmt.Errorf("error marshalling type %s: %w", reflect.TypeOf(v).Name(), err)
}
req.Header.Set(HeaderContentType, contentTypeAppJSON)
return req.SetBody(NopCloser(bytes.NewReader(b)))
}

// MarshalAsXML calls xml.Marshal() to get the XML encoding of v then calls SetBody.
// If xml.Marshal fails a MarshalError is returned. Any error from SetBody is returned.
func (req *Request) MarshalAsXML(v interface{}) error {
Expand Down
26 changes: 26 additions & 0 deletions sdk/azcore/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ import (
"testing"
)

type testJSON struct {
SomeInt int
SomeString string
}

type testXML struct {
SomeInt int
SomeString string
Expand Down Expand Up @@ -53,3 +58,24 @@ func TestRequestEmptyPipeline(t *testing.T) {
t.Fatalf("expected ErrNoMorePolicies, got %v", err)
}
}

func TestRequestMarshalJSON(t *testing.T) {
u, err := url.Parse("https://contoso.com")
if err != nil {
panic(err)
}
req := NewRequest(http.MethodPost, *u)
err = req.MarshalAsJSON(testJSON{SomeInt: 1, SomeString: "s"})
if err != nil {
t.Fatalf("marshal failure: %v", err)
}
if ct := req.Header.Get(HeaderContentType); ct != contentTypeAppJSON {
t.Fatalf("unexpected content type, got %s wanted %s", ct, contentTypeAppJSON)
}
if req.Body == nil {
t.Fatal("unexpected nil request body")
}
if req.ContentLength == 0 {
t.Fatal("unexpected zero content length")
}
}
44 changes: 36 additions & 8 deletions sdk/azcore/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package azcore

import (
"bytes"
"encoding/json"
"encoding/xml"
"fmt"
"io"
Expand All @@ -22,17 +23,25 @@ import (
// Response represents the response from an HTTP request.
type Response struct {
*http.Response
}

// Payload contains the contents of the HTTP response body if available.
Payload []byte
func (r *Response) payload() []byte {
if r.Body == nil {
return nil
}
// r.Body won't be a nopClosingBytesReader if downloading was skipped
if buf, ok := r.Body.(*nopClosingBytesReader); ok {
return buf.Bytes()
}
return nil
}

// CheckStatusCode returns a RequestError if the Response's status code isn't one of the specified values.
func (r *Response) CheckStatusCode(statusCodes ...int) error {
if !r.HasStatusCode(statusCodes...) {
msg := r.Status
if len(r.Payload) > 0 {
msg = string(r.Payload)
if len(r.payload()) > 0 {
msg = string(r.payload())
}
return newRequestError(msg, r)
}
Expand All @@ -52,14 +61,30 @@ func (r *Response) HasStatusCode(statusCodes ...int) bool {
return false
}

// UnmarshalAsJSON calls json.Unmarshal() to unmarshal the received payload into the value pointed to by v.
// If no payload was received a RequestError is returned. If json.Unmarshal fails a UnmarshalError is returned.
func (r *Response) UnmarshalAsJSON(v interface{}) error {
// TODO: verify early exit is correct
if len(r.payload()) == 0 {
return nil
}
r.removeBOM()
err := json.Unmarshal(r.payload(), v)
if err != nil {
err = fmt.Errorf("unmarshalling type %s: %w", reflect.TypeOf(v).Elem().Name(), err)
}
return err
}

// UnmarshalAsXML calls xml.Unmarshal() to unmarshal the received payload into the value pointed to by v.
// If no payload was received a RequestError is returned. If xml.Unmarshal fails a UnmarshalError is returned.
func (r *Response) UnmarshalAsXML(v interface{}) error {
if len(r.Payload) == 0 {
return newRequestError("missing payload", r)
// TODO: verify early exit is correct
if len(r.payload()) == 0 {
return nil
}
r.removeBOM()
err := xml.Unmarshal(r.Payload, v)
err := xml.Unmarshal(r.payload(), v)
if err != nil {
err = fmt.Errorf("unmarshalling type %s: %w", reflect.TypeOf(v).Elem().Name(), err)
}
Expand All @@ -77,7 +102,10 @@ func (r *Response) Drain() {
// removeBOM removes any byte-order mark prefix from the payload if present.
func (r *Response) removeBOM() {
// UTF8
r.Payload = bytes.TrimPrefix(r.Payload, []byte("\xef\xbb\xbf"))
trimmed := bytes.TrimPrefix(r.payload(), []byte("\xef\xbb\xbf"))
if len(trimmed) < len(r.payload()) {
r.Body.(*nopClosingBytesReader).Set(trimmed)
}
}

// RetryAfter returns (non-zero, true) if the response contains a Retry-After header value
Expand Down
58 changes: 57 additions & 1 deletion sdk/azcore/response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ import (
func TestResponseUnmarshalXML(t *testing.T) {
srv, close := mock.NewServer()
defer close()
srv.SetResponse(mock.WithBody([]byte("<testXML><SomeInt>1</SomeInt><SomeString>s</SomeString></testXML>")))
// include UTF8 BOM
srv.SetResponse(mock.WithBody([]byte("\xef\xbb\xbf<testXML><SomeInt>1</SomeInt><SomeString>s</SomeString></testXML>")))
pl := NewPipeline(srv, NewTelemetryPolicy(TelemetryOptions{}))
resp, err := pl.Do(context.Background(), NewRequest(http.MethodGet, srv.URL()))
if err != nil {
Expand Down Expand Up @@ -54,3 +55,58 @@ func TestResponseFailureStatusCode(t *testing.T) {
t.Fatal("unexpected response")
}
}

func TestResponseUnmarshalJSON(t *testing.T) {
srv, close := mock.NewServer()
defer close()
srv.SetResponse(mock.WithBody([]byte(`{ "someInt": 1, "someString": "s" }`)))
pl := NewPipeline(srv, NewTelemetryPolicy(TelemetryOptions{}))
resp, err := pl.Do(context.Background(), NewRequest(http.MethodGet, srv.URL()))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if err := resp.CheckStatusCode(http.StatusOK); err != nil {
t.Fatalf("unexpected status code error: %v", err)
}
var tx testJSON
if err := resp.UnmarshalAsJSON(&tx); err != nil {
t.Fatalf("unexpected error unmarshalling: %v", err)
}
if tx.SomeInt != 1 || tx.SomeString != "s" {
t.Fatal("unexpected value")
}
}

func TestResponseUnmarshalJSONNoBody(t *testing.T) {
srv, close := mock.NewServer()
defer close()
srv.SetResponse(mock.WithBody([]byte{}))
pl := NewPipeline(srv, NewTelemetryPolicy(TelemetryOptions{}))
resp, err := pl.Do(context.Background(), NewRequest(http.MethodGet, srv.URL()))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if err := resp.CheckStatusCode(http.StatusOK); err != nil {
t.Fatalf("unexpected status code error: %v", err)
}
if err := resp.UnmarshalAsJSON(nil); err != nil {
t.Fatalf("unexpected error unmarshalling: %v", err)
}
}

func TestResponseUnmarshalXMLNoBody(t *testing.T) {
srv, close := mock.NewServer()
defer close()
srv.SetResponse(mock.WithBody([]byte{}))
pl := NewPipeline(srv, NewTelemetryPolicy(TelemetryOptions{}))
resp, err := pl.Do(context.Background(), NewRequest(http.MethodGet, srv.URL()))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if err := resp.CheckStatusCode(http.StatusOK); err != nil {
t.Fatalf("unexpected status code error: %v", err)
}
if err := resp.UnmarshalAsXML(nil); err != nil {
t.Fatalf("unexpected error unmarshalling: %v", err)
}
}

0 comments on commit dfc6a17

Please sign in to comment.