Skip to content

Commit

Permalink
Merge pull request #4817 from aws/fix-restjson-errors
Browse files Browse the repository at this point in the history
fix: restjson error deserialization when no body is present
  • Loading branch information
lucix-aws committed May 8, 2023
2 parents 6eef808 + 35072aa commit fc3c2d6
Show file tree
Hide file tree
Showing 3 changed files with 286 additions and 156 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG_PENDING.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
### SDK Enhancements

### SDK Bugs
* `restjson`: Correct failure to deserialize errors.
* Deserialize generic error information when no response body is present.
135 changes: 78 additions & 57 deletions private/protocol/restjson/unmarshal_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package restjson

import (
"bytes"
"encoding/json"
"io"
"io/ioutil"
"net/http"
Expand Down Expand Up @@ -40,54 +41,30 @@ func (u *UnmarshalTypedError) UnmarshalError(
resp *http.Response,
respMeta protocol.ResponseMetadata,
) (error, error) {

code := resp.Header.Get(errorTypeHeader)
msg := resp.Header.Get(errorMessageHeader)

body := resp.Body
if len(code) == 0 || len(msg) == 0 {
// If unable to get code from HTTP headers have to parse JSON message
// to determine what kind of exception this will be.
var buf bytes.Buffer
var jsonErr jsonErrorResponse
teeReader := io.TeeReader(resp.Body, &buf)
err := jsonutil.UnmarshalJSONError(&jsonErr, teeReader)
if err != nil {
return nil, err
}

body = ioutil.NopCloser(&buf)
if len(code) == 0 {
code = jsonErr.Code
}
msg = jsonErr.Message
code, msg, err := unmarshalErrorInfo(resp)
if err != nil {
return nil, err
}

// If code has colon separators remove them so can compare against modeled
// exception names.
code = strings.SplitN(code, ":", 2)[0]

if fn, ok := u.exceptions[code]; ok {
// If exception code is know, use associated constructor to get a value
// for the exception that the JSON body can be unmarshaled into.
v := fn(respMeta)
if err := jsonutil.UnmarshalJSONCaseInsensitive(v, body); err != nil {
return nil, err
}
fn, ok := u.exceptions[code]
if !ok {
return awserr.NewRequestFailure(
awserr.New(code, msg, nil),
respMeta.StatusCode,
respMeta.RequestID,
), nil
}

if err := rest.UnmarshalResponse(resp, v, true); err != nil {
return nil, err
}
v := fn(respMeta)
if err := jsonutil.UnmarshalJSONCaseInsensitive(v, resp.Body); err != nil {
return nil, err
}

return v, nil
if err := rest.UnmarshalResponse(resp, v, true); err != nil {
return nil, err
}

// fallback to unmodeled generic exceptions
return awserr.NewRequestFailure(
awserr.New(code, msg, nil),
respMeta.StatusCode,
respMeta.RequestID,
), nil
return v, nil
}

// UnmarshalErrorHandler is a named request handler for unmarshaling restjson
Expand All @@ -101,36 +78,80 @@ var UnmarshalErrorHandler = request.NamedHandler{
func UnmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close()

var jsonErr jsonErrorResponse
err := jsonutil.UnmarshalJSONError(&jsonErr, r.HTTPResponse.Body)
code, msg, err := unmarshalErrorInfo(r.HTTPResponse)
if err != nil {
r.Error = awserr.NewRequestFailure(
awserr.New(request.ErrCodeSerialization,
"failed to unmarshal response error", err),
awserr.New(request.ErrCodeSerialization, "failed to unmarshal response error", err),
r.HTTPResponse.StatusCode,
r.RequestID,
)
return
}

code := r.HTTPResponse.Header.Get(errorTypeHeader)
if code == "" {
code = jsonErr.Code
}
msg := r.HTTPResponse.Header.Get(errorMessageHeader)
if msg == "" {
msg = jsonErr.Message
}

code = strings.SplitN(code, ":", 2)[0]
r.Error = awserr.NewRequestFailure(
awserr.New(code, jsonErr.Message, nil),
awserr.New(code, msg, nil),
r.HTTPResponse.StatusCode,
r.RequestID,
)
}

type jsonErrorResponse struct {
Type string `json:"__type"`
Code string `json:"code"`
Message string `json:"message"`
}

func (j *jsonErrorResponse) SanitizedCode() string {
code := j.Code
if len(j.Type) > 0 {
code = j.Type
}
return sanitizeCode(code)
}

// Remove superfluous components from a restJson error code.
// - If a : character is present, then take only the contents before the
// first : character in the value.
// - If a # character is present, then take only the contents after the first
// # character in the value.
//
// All of the following error values resolve to FooError:
// - FooError
// - FooError:http://internal.amazon.com/coral/com.amazon.coral.validate/
// - aws.protocoltests.restjson#FooError
// - aws.protocoltests.restjson#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate/
func sanitizeCode(code string) string {
noColon := strings.SplitN(code, ":", 2)[0]
hashSplit := strings.SplitN(noColon, "#", 2)
return hashSplit[len(hashSplit)-1]
}

// attempt to garner error details from the response, preferring header values
// when present
func unmarshalErrorInfo(resp *http.Response) (code string, msg string, err error) {
code = sanitizeCode(resp.Header.Get(errorTypeHeader))
msg = resp.Header.Get(errorMessageHeader)
if len(code) > 0 && len(msg) > 0 {
return
}

// a modeled error will have to be re-deserialized later, so the body must
// be preserved
var buf bytes.Buffer
tee := io.TeeReader(resp.Body, &buf)
defer func() { resp.Body = ioutil.NopCloser(&buf) }()

var jsonErr jsonErrorResponse
if decodeErr := json.NewDecoder(tee).Decode(&jsonErr); decodeErr != nil && decodeErr != io.EOF {
err = awserr.NewUnmarshalError(decodeErr, "failed to decode response body", buf.Bytes())
return
}

if len(code) == 0 {
code = jsonErr.SanitizedCode()
}
if len(msg) == 0 {
msg = jsonErr.Message
}
return
}
Loading

0 comments on commit fc3c2d6

Please sign in to comment.