Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

aws/request: Fix bug where the host header would not reflect changes to the endpoint URL. #3102

Merged
merged 4 commits into from
Jan 30, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions aws/request/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,6 @@ func New(cfg aws.Config, clientInfo metadata.ClientInfo, handlers Handlers,
err = awserr.New("InvalidEndpointURL", "invalid endpoint uri", err)
}

SanitizeHostForHeader(httpReq)

r := &Request{
Config: cfg,
ClientInfo: clientInfo,
Expand Down Expand Up @@ -426,6 +424,8 @@ func (r *Request) Sign() error {
return r.Error
}

SanitizeHostForHeader(r.HTTPRequest)

r.Handlers.Sign.Run(r)
return r.Error
}
Expand Down
148 changes: 126 additions & 22 deletions aws/request/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,8 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/awstesting/unit"
Expand Down Expand Up @@ -938,25 +935,6 @@ func TestRequest_Presign(t *testing.T) {
}
}

func TestNew_EndpointWithDefaultPort(t *testing.T) {
endpoint := "https://estest.us-east-1.es.amazonaws.com:443"
expectedRequestHost := "estest.us-east-1.es.amazonaws.com"

r := request.New(
aws.Config{},
metadata.ClientInfo{Endpoint: endpoint},
defaults.Handlers(),
client.DefaultRetryer{},
&request.Operation{},
nil,
nil,
)

if h := r.HTTPRequest.Host; h != expectedRequestHost {
t.Errorf("expect %v host, got %q", expectedRequestHost, h)
}
}

func TestSanitizeHostForHeader(t *testing.T) {
cases := []struct {
url string
Expand Down Expand Up @@ -1150,6 +1128,132 @@ func TestRequestBodySeekFails(t *testing.T) {

}

func TestRequestEndpointWithDefaultPort(t *testing.T) {
s := awstesting.NewClient(&aws.Config{
Endpoint: aws.String("https://example.test:443"),
})
r := s.NewRequest(&request.Operation{
Name: "FooBar",
HTTPMethod: "GET",
HTTPPath: "/",
}, nil, nil)
r.Handlers.Validate.Clear()
r.Handlers.ValidateResponse.Clear()
r.Handlers.Send.Clear()
r.Handlers.Send.PushFront(func(r *request.Request) {
req := r.HTTPRequest

if e, a := "example.test", req.Host; e != a {
t.Errorf("expected %v, got %v", e, a)
}

if e, a := "https://example.test:443/", req.URL.String(); e != a {
t.Errorf("expected %v, got %v", e, a)
}
})
err := r.Send()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
}

func TestRequestEndpointWithNonDefaultPort(t *testing.T) {
s := awstesting.NewClient(&aws.Config{
Endpoint: aws.String("https://example.test:8443"),
})
r := s.NewRequest(&request.Operation{
Name: "FooBar",
HTTPMethod: "GET",
HTTPPath: "/",
}, nil, nil)
r.Handlers.Validate.Clear()
r.Handlers.ValidateResponse.Clear()
r.Handlers.Send.Clear()
r.Handlers.Send.PushFront(func(r *request.Request) {
req := r.HTTPRequest

// http.Request.Host should not be set for non-default ports
if e, a := "", req.Host; e != a {
t.Errorf("expected %v, got %v", e, a)
}

if e, a := "https://example.test:8443/", req.URL.String(); e != a {
t.Errorf("expected %v, got %v", e, a)
}
})
err := r.Send()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
}

func TestRequestMarshaledEndpointWithDefaultPort(t *testing.T) {
s := awstesting.NewClient(&aws.Config{
Endpoint: aws.String("https://example.test:443"),
})
r := s.NewRequest(&request.Operation{
Name: "FooBar",
HTTPMethod: "GET",
HTTPPath: "/",
}, nil, nil)
r.Handlers.Validate.Clear()
r.Handlers.ValidateResponse.Clear()
r.Handlers.Build.PushBack(func(r *request.Request) {
req := r.HTTPRequest
req.URL.Host = "foo." + req.URL.Host
})
r.Handlers.Send.Clear()
r.Handlers.Send.PushFront(func(r *request.Request) {
req := r.HTTPRequest

if e, a := "foo.example.test", req.Host; e != a {
t.Errorf("expected %v, got %v", e, a)
}

if e, a := "https://foo.example.test:443/", req.URL.String(); e != a {
t.Errorf("expected %v, got %v", e, a)
}
})
err := r.Send()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
}

func TestRequestMarshaledEndpointWithNonDefaultPort(t *testing.T) {
s := awstesting.NewClient(&aws.Config{
Endpoint: aws.String("https://example.test:8443"),
})
r := s.NewRequest(&request.Operation{
Name: "FooBar",
HTTPMethod: "GET",
HTTPPath: "/",
}, nil, nil)
r.Handlers.Validate.Clear()
r.Handlers.ValidateResponse.Clear()
r.Handlers.Build.PushBack(func(r *request.Request) {
req := r.HTTPRequest
req.URL.Host = "foo." + req.URL.Host
})
r.Handlers.Send.Clear()
r.Handlers.Send.PushFront(func(r *request.Request) {
req := r.HTTPRequest

// http.Request.Host should not be set for non-default ports
if e, a := "", req.Host; e != a {
t.Errorf("expected %v, got %v", e, a)
}

if e, a := "https://foo.example.test:8443/", req.URL.String(); e != a {
t.Errorf("expected %v, got %v", e, a)
}
})
err := r.Send()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
}

type stubSeekFail struct {
Err error
}
Expand Down
5 changes: 4 additions & 1 deletion aws/request/timeout_read_closer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"io"
"io/ioutil"
"net/http"
"net/url"
"testing"
"time"

Expand Down Expand Up @@ -61,7 +62,9 @@ func TestTimeoutReadCloserSameDuration(t *testing.T) {

func TestWithResponseReadTimeout(t *testing.T) {
r := Request{
HTTPRequest: &http.Request{},
HTTPRequest: &http.Request{
URL: &url.URL{},
},
HTTPResponse: &http.Response{
Body: ioutil.NopCloser(bytes.NewReader(nil)),
},
Expand Down
64 changes: 62 additions & 2 deletions service/s3/endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
"github.com/aws/aws-sdk-go/awstesting/unit"
)

func TestEndpointARN(t *testing.T) {
func TestEndpoint(t *testing.T) {
cases := map[string]struct {
bucket string
config *aws.Config
Expand Down Expand Up @@ -193,6 +193,65 @@ func TestEndpointARN(t *testing.T) {
},
expectedErr: "client partition does not match provided ARN partition",
},
"bucket host-style": {
bucket: "mock-bucket",
config: &aws.Config{Region: aws.String("us-west-2")},
expectedEndpoint: "https://mock-bucket.s3.us-west-2.amazonaws.com",
expectedSigningName: "s3",
expectedSigningRegion: "us-west-2",
},
"bucket path-style": {
bucket: "mock-bucket",
config: &aws.Config{
Region: aws.String("us-west-2"),
S3ForcePathStyle: aws.Bool(true),
},
expectedEndpoint: "https://s3.us-west-2.amazonaws.com",
expectedSigningName: "s3",
expectedSigningRegion: "us-west-2",
},
"bucket host-style endpoint with default port": {
bucket: "mock-bucket",
config: &aws.Config{
Region: aws.String("us-west-2"),
Endpoint: aws.String("https://s3.us-west-2.amazonaws.com:443"),
},
expectedEndpoint: "https://mock-bucket.s3.us-west-2.amazonaws.com:443",
expectedSigningName: "s3",
expectedSigningRegion: "us-west-2",
},
"bucket host-style endpoint with non-default port": {
bucket: "mock-bucket",
config: &aws.Config{
Region: aws.String("us-west-2"),
Endpoint: aws.String("https://s3.us-west-2.amazonaws.com:8443"),
},
expectedEndpoint: "https://mock-bucket.s3.us-west-2.amazonaws.com:8443",
expectedSigningName: "s3",
expectedSigningRegion: "us-west-2",
},
"bucket path-style endpoint with default port": {
bucket: "mock-bucket",
config: &aws.Config{
Region: aws.String("us-west-2"),
Endpoint: aws.String("https://s3.us-west-2.amazonaws.com:443"),
S3ForcePathStyle: aws.Bool(true),
},
expectedEndpoint: "https://s3.us-west-2.amazonaws.com:443",
expectedSigningName: "s3",
expectedSigningRegion: "us-west-2",
},
"bucket path-style endpoint with non-default port": {
bucket: "mock-bucket",
config: &aws.Config{
Region: aws.String("us-west-2"),
Endpoint: aws.String("https://s3.us-west-2.amazonaws.com:8443"),
S3ForcePathStyle: aws.Bool(true),
},
expectedEndpoint: "https://s3.us-west-2.amazonaws.com:8443",
expectedSigningName: "s3",
expectedSigningRegion: "us-west-2",
},
}

for name, c := range cases {
Expand Down Expand Up @@ -221,7 +280,8 @@ func TestEndpointARN(t *testing.T) {
if e, a := c.expectedEndpoint, endpoint; e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := c.expectedSigningName, r.ClientInfo.SigningName; e != a {

if e, a := c.expectedSigningName, r.ClientInfo.SigningName; c.config.Endpoint == nil && e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := c.expectedSigningRegion, r.ClientInfo.SigningRegion; e != a {
Expand Down