Skip to content

Commit

Permalink
aws/request: Fix bug where the host header would not reflect changes …
Browse files Browse the repository at this point in the history
…to the endpoint URL (#3102)
  • Loading branch information
skmcgrail committed Jan 30, 2020
1 parent 3acad12 commit 1edb62f
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 27 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG_PENDING.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@
* Adds support for the EC2ThrottledException throttling exception code. The SDK will now treat this error code as throttling.

### SDK Bugs
* `aws/request`: Fixes an issue where the HTTP host header did not reflect changes to the endpoint URL ([#3102](https://github.com/aws/aws-sdk-go/pull/3102))
* Fixes [#3093](https://github.com/aws/aws-sdk-go/issues/3093)
3 changes: 3 additions & 0 deletions aws/corehandlers/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -290,6 +291,8 @@ func TestValidateReqSigHandler(t *testing.T) {
}

for i, c := range cases {
c.Req.HTTPRequest = &http.Request{URL: &url.URL{}}

resigned := false
c.Req.Handlers.Sign.PushBack(func(r *request.Request) {
resigned = true
Expand Down
4 changes: 2 additions & 2 deletions aws/request/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,6 @@ func New(cfg aws.Config, clientInfo metadata.ClientInfo, handlers Handlers,
err = awserr.New("InvalidEndpointURL", "invalid endpoint uri", err)
}

SanitizeHostForHeader(httpReq)

r := &Request{
Config: cfg,
ClientInfo: clientInfo,
Expand Down Expand Up @@ -426,6 +424,8 @@ func (r *Request) Sign() error {
return r.Error
}

SanitizeHostForHeader(r.HTTPRequest)

r.Handlers.Sign.Run(r)
return r.Error
}
Expand Down
148 changes: 126 additions & 22 deletions aws/request/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,8 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/awstesting/unit"
Expand Down Expand Up @@ -938,25 +935,6 @@ func TestRequest_Presign(t *testing.T) {
}
}

func TestNew_EndpointWithDefaultPort(t *testing.T) {
endpoint := "https://estest.us-east-1.es.amazonaws.com:443"
expectedRequestHost := "estest.us-east-1.es.amazonaws.com"

r := request.New(
aws.Config{},
metadata.ClientInfo{Endpoint: endpoint},
defaults.Handlers(),
client.DefaultRetryer{},
&request.Operation{},
nil,
nil,
)

if h := r.HTTPRequest.Host; h != expectedRequestHost {
t.Errorf("expect %v host, got %q", expectedRequestHost, h)
}
}

func TestSanitizeHostForHeader(t *testing.T) {
cases := []struct {
url string
Expand Down Expand Up @@ -1150,6 +1128,132 @@ func TestRequestBodySeekFails(t *testing.T) {

}

func TestRequestEndpointWithDefaultPort(t *testing.T) {
s := awstesting.NewClient(&aws.Config{
Endpoint: aws.String("https://example.test:443"),
})
r := s.NewRequest(&request.Operation{
Name: "FooBar",
HTTPMethod: "GET",
HTTPPath: "/",
}, nil, nil)
r.Handlers.Validate.Clear()
r.Handlers.ValidateResponse.Clear()
r.Handlers.Send.Clear()
r.Handlers.Send.PushFront(func(r *request.Request) {
req := r.HTTPRequest

if e, a := "example.test", req.Host; e != a {
t.Errorf("expected %v, got %v", e, a)
}

if e, a := "https://example.test:443/", req.URL.String(); e != a {
t.Errorf("expected %v, got %v", e, a)
}
})
err := r.Send()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
}

func TestRequestEndpointWithNonDefaultPort(t *testing.T) {
s := awstesting.NewClient(&aws.Config{
Endpoint: aws.String("https://example.test:8443"),
})
r := s.NewRequest(&request.Operation{
Name: "FooBar",
HTTPMethod: "GET",
HTTPPath: "/",
}, nil, nil)
r.Handlers.Validate.Clear()
r.Handlers.ValidateResponse.Clear()
r.Handlers.Send.Clear()
r.Handlers.Send.PushFront(func(r *request.Request) {
req := r.HTTPRequest

// http.Request.Host should not be set for non-default ports
if e, a := "", req.Host; e != a {
t.Errorf("expected %v, got %v", e, a)
}

if e, a := "https://example.test:8443/", req.URL.String(); e != a {
t.Errorf("expected %v, got %v", e, a)
}
})
err := r.Send()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
}

func TestRequestMarshaledEndpointWithDefaultPort(t *testing.T) {
s := awstesting.NewClient(&aws.Config{
Endpoint: aws.String("https://example.test:443"),
})
r := s.NewRequest(&request.Operation{
Name: "FooBar",
HTTPMethod: "GET",
HTTPPath: "/",
}, nil, nil)
r.Handlers.Validate.Clear()
r.Handlers.ValidateResponse.Clear()
r.Handlers.Build.PushBack(func(r *request.Request) {
req := r.HTTPRequest
req.URL.Host = "foo." + req.URL.Host
})
r.Handlers.Send.Clear()
r.Handlers.Send.PushFront(func(r *request.Request) {
req := r.HTTPRequest

if e, a := "foo.example.test", req.Host; e != a {
t.Errorf("expected %v, got %v", e, a)
}

if e, a := "https://foo.example.test:443/", req.URL.String(); e != a {
t.Errorf("expected %v, got %v", e, a)
}
})
err := r.Send()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
}

func TestRequestMarshaledEndpointWithNonDefaultPort(t *testing.T) {
s := awstesting.NewClient(&aws.Config{
Endpoint: aws.String("https://example.test:8443"),
})
r := s.NewRequest(&request.Operation{
Name: "FooBar",
HTTPMethod: "GET",
HTTPPath: "/",
}, nil, nil)
r.Handlers.Validate.Clear()
r.Handlers.ValidateResponse.Clear()
r.Handlers.Build.PushBack(func(r *request.Request) {
req := r.HTTPRequest
req.URL.Host = "foo." + req.URL.Host
})
r.Handlers.Send.Clear()
r.Handlers.Send.PushFront(func(r *request.Request) {
req := r.HTTPRequest

// http.Request.Host should not be set for non-default ports
if e, a := "", req.Host; e != a {
t.Errorf("expected %v, got %v", e, a)
}

if e, a := "https://foo.example.test:8443/", req.URL.String(); e != a {
t.Errorf("expected %v, got %v", e, a)
}
})
err := r.Send()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
}

type stubSeekFail struct {
Err error
}
Expand Down
5 changes: 4 additions & 1 deletion aws/request/timeout_read_closer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"io"
"io/ioutil"
"net/http"
"net/url"
"testing"
"time"

Expand Down Expand Up @@ -61,7 +62,9 @@ func TestTimeoutReadCloserSameDuration(t *testing.T) {

func TestWithResponseReadTimeout(t *testing.T) {
r := Request{
HTTPRequest: &http.Request{},
HTTPRequest: &http.Request{
URL: &url.URL{},
},
HTTPResponse: &http.Response{
Body: ioutil.NopCloser(bytes.NewReader(nil)),
},
Expand Down
64 changes: 62 additions & 2 deletions service/s3/endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
"github.com/aws/aws-sdk-go/awstesting/unit"
)

func TestEndpointARN(t *testing.T) {
func TestEndpoint(t *testing.T) {
cases := map[string]struct {
bucket string
config *aws.Config
Expand Down Expand Up @@ -193,6 +193,65 @@ func TestEndpointARN(t *testing.T) {
},
expectedErr: "client partition does not match provided ARN partition",
},
"bucket host-style": {
bucket: "mock-bucket",
config: &aws.Config{Region: aws.String("us-west-2")},
expectedEndpoint: "https://mock-bucket.s3.us-west-2.amazonaws.com",
expectedSigningName: "s3",
expectedSigningRegion: "us-west-2",
},
"bucket path-style": {
bucket: "mock-bucket",
config: &aws.Config{
Region: aws.String("us-west-2"),
S3ForcePathStyle: aws.Bool(true),
},
expectedEndpoint: "https://s3.us-west-2.amazonaws.com",
expectedSigningName: "s3",
expectedSigningRegion: "us-west-2",
},
"bucket host-style endpoint with default port": {
bucket: "mock-bucket",
config: &aws.Config{
Region: aws.String("us-west-2"),
Endpoint: aws.String("https://s3.us-west-2.amazonaws.com:443"),
},
expectedEndpoint: "https://mock-bucket.s3.us-west-2.amazonaws.com:443",
expectedSigningName: "s3",
expectedSigningRegion: "us-west-2",
},
"bucket host-style endpoint with non-default port": {
bucket: "mock-bucket",
config: &aws.Config{
Region: aws.String("us-west-2"),
Endpoint: aws.String("https://s3.us-west-2.amazonaws.com:8443"),
},
expectedEndpoint: "https://mock-bucket.s3.us-west-2.amazonaws.com:8443",
expectedSigningName: "s3",
expectedSigningRegion: "us-west-2",
},
"bucket path-style endpoint with default port": {
bucket: "mock-bucket",
config: &aws.Config{
Region: aws.String("us-west-2"),
Endpoint: aws.String("https://s3.us-west-2.amazonaws.com:443"),
S3ForcePathStyle: aws.Bool(true),
},
expectedEndpoint: "https://s3.us-west-2.amazonaws.com:443",
expectedSigningName: "s3",
expectedSigningRegion: "us-west-2",
},
"bucket path-style endpoint with non-default port": {
bucket: "mock-bucket",
config: &aws.Config{
Region: aws.String("us-west-2"),
Endpoint: aws.String("https://s3.us-west-2.amazonaws.com:8443"),
S3ForcePathStyle: aws.Bool(true),
},
expectedEndpoint: "https://s3.us-west-2.amazonaws.com:8443",
expectedSigningName: "s3",
expectedSigningRegion: "us-west-2",
},
}

for name, c := range cases {
Expand Down Expand Up @@ -221,7 +280,8 @@ func TestEndpointARN(t *testing.T) {
if e, a := c.expectedEndpoint, endpoint; e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := c.expectedSigningName, r.ClientInfo.SigningName; e != a {

if e, a := c.expectedSigningName, r.ClientInfo.SigningName; c.config.Endpoint == nil && e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := c.expectedSigningRegion, r.ClientInfo.SigningRegion; e != a {
Expand Down

0 comments on commit 1edb62f

Please sign in to comment.