Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for EC2 IMDS endpoint from environment variable #3504

Merged
merged 3 commits into from
Aug 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG_PENDING.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@
### SDK Enhancements

### SDK Bugs
* `aws/ec2metadata`: Add support for EC2 IMDS endpoint from environment variable ([#3504](https://github.com/aws/aws-sdk-go/pull/3504))
* Adds support for specifying a custom EC2 IMDS endpoint from the environment variable, `AWS_EC2_METADATA_SERVICE_ENDPOINT`.
* The `aws/session#Options` struct also has a new field, `EC2IMDSEndpoint`. This field can be used to configure the custom endpoint of the EC2 IMDS client. The option only applies to EC2 IMDS clients created after the Session with `EC2IMDSEndpoint` is specified.
2 changes: 1 addition & 1 deletion aws/defaults/defaults_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func TestDefaultEC2RoleProvider(t *testing.T) {
if ec2Provider == nil {
t.Fatalf("expect provider not to be nil, but was")
}
if e, a := "http://169.254.169.254/latest", ec2Provider.Client.Endpoint; e != a {
if e, a := "http://169.254.169.254", ec2Provider.Client.Endpoint; e != a {
t.Errorf("expect %q endpoint, got %q", e, a)
}
}
8 changes: 4 additions & 4 deletions aws/ec2metadata/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func (c *EC2Metadata) getToken(ctx aws.Context, duration time.Duration) (tokenOu
op := &request.Operation{
Name: "GetToken",
HTTPMethod: "PUT",
HTTPPath: "/api/token",
HTTPPath: "/latest/api/token",
}

var output tokenOutput
Expand Down Expand Up @@ -62,7 +62,7 @@ func (c *EC2Metadata) GetMetadataWithContext(ctx aws.Context, p string) (string,
op := &request.Operation{
Name: "GetMetadata",
HTTPMethod: "GET",
HTTPPath: sdkuri.PathJoin("/meta-data", p),
HTTPPath: sdkuri.PathJoin("/latest/meta-data", p),
}
output := &metadataOutput{}

Expand All @@ -88,7 +88,7 @@ func (c *EC2Metadata) GetUserDataWithContext(ctx aws.Context) (string, error) {
op := &request.Operation{
Name: "GetUserData",
HTTPMethod: "GET",
HTTPPath: "/user-data",
HTTPPath: "/latest/user-data",
}

output := &metadataOutput{}
Expand All @@ -113,7 +113,7 @@ func (c *EC2Metadata) GetDynamicDataWithContext(ctx aws.Context, p string) (stri
op := &request.Operation{
Name: "GetDynamicData",
HTTPMethod: "GET",
HTTPPath: sdkuri.PathJoin("/dynamic", p),
HTTPPath: sdkuri.PathJoin("/latest/dynamic", p),
}

output := &metadataOutput{}
Expand Down
79 changes: 54 additions & 25 deletions aws/ec2metadata/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/internal/sdktesting"
)

const instanceIdentityDocument = `{
Expand Down Expand Up @@ -106,22 +107,22 @@ func newTestServer(t *testing.T, testType testType, testServer *testServer) *htt
switch testType {
case SecureTestType:
mux.HandleFunc("/latest/api/token", getTokenRequiredParams(t, testServer.secureGetTokenHandler))
mux.HandleFunc("/latest/", testServer.secureGetLatestHandler)
mux.HandleFunc("/", testServer.secureGetLatestHandler)
case InsecureTestType:
mux.HandleFunc("/latest/api/token", testServer.insecureGetTokenHandler)
mux.HandleFunc("/latest/", testServer.insecureGetLatestHandler)
mux.HandleFunc("/", testServer.insecureGetLatestHandler)
case BadRequestTestType:
mux.HandleFunc("/latest/api/token", getTokenRequiredParams(t, testServer.badRequestGetTokenHandler))
mux.HandleFunc("/latest/", testServer.badRequestGetLatestHandler)
mux.HandleFunc("/", testServer.badRequestGetLatestHandler)
case ServerErrorForTokenTestType:
mux.HandleFunc("/latest/api/token", getTokenRequiredParams(t, testServer.serverErrorGetTokenHandler))
mux.HandleFunc("/latest/", testServer.insecureGetLatestHandler)
mux.HandleFunc("/", testServer.insecureGetLatestHandler)
case pageNotFoundForTokenTestType:
mux.HandleFunc("/latest/api/token", getTokenRequiredParams(t, testServer.pageNotFoundGetTokenHandler))
mux.HandleFunc("/latest/", testServer.insecureGetLatestHandler)
mux.HandleFunc("/", testServer.insecureGetLatestHandler)
case pageNotFoundWith401TestType:
mux.HandleFunc("/latest/api/token", getTokenRequiredParams(t, testServer.pageNotFoundGetTokenHandler))
mux.HandleFunc("/latest/", testServer.unauthorizedGetLatestHandler)
mux.HandleFunc("/", testServer.unauthorizedGetLatestHandler)

}

Expand Down Expand Up @@ -204,17 +205,17 @@ func (opListProvider *operationListProvider) addToOperationPerformedList(r *requ
}

func TestEndpoint(t *testing.T) {
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()

c := ec2metadata.New(unit.Session)
op := &request.Operation{
Name: "GetMetadata",
HTTPMethod: "GET",
HTTPPath: path.Join("/", "meta-data", "testpath"),
HTTPPath: path.Join("/latest", "meta-data", "testpath"),
}

req := c.NewRequest(op, nil, nil)
if e, a := "http://169.254.169.254/latest", req.ClientInfo.Endpoint; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "http://169.254.169.254/latest/meta-data/testpath", req.HTTPRequest.URL.String(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
Expand Down Expand Up @@ -289,7 +290,9 @@ func TestGetMetadata(t *testing.T) {

op := &operationListProvider{}

c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
c := ec2metadata.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
})
c.Handlers.Complete.PushBack(op.addToOperationPerformedList)

resp, err := c.GetMetadata("some/path")
Expand Down Expand Up @@ -340,7 +343,9 @@ func TestGetUserData_Error(t *testing.T) {
}))

defer server.Close()
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
c := ec2metadata.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
})

resp, err := c.GetUserData()
if err == nil {
Expand Down Expand Up @@ -425,7 +430,9 @@ func TestGetRegion(t *testing.T) {

op := &operationListProvider{}

c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
c := ec2metadata.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
})
c.Handlers.Complete.PushBack(op.addToOperationPerformedList)

resp, err := c.Region()
Expand Down Expand Up @@ -494,7 +501,9 @@ func TestMetadataIAMInfo_success(t *testing.T) {

op := &operationListProvider{}

c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
c := ec2metadata.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
})
c.Handlers.Complete.PushBack(op.addToOperationPerformedList)

iamInfo, err := c.IAMInfo()
Expand Down Expand Up @@ -570,7 +579,9 @@ func TestMetadataIAMInfo_failure(t *testing.T) {

op := &operationListProvider{}

c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
c := ec2metadata.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
})
c.Handlers.Complete.PushBack(op.addToOperationPerformedList)

iamInfo, err := c.IAMInfo()
Expand Down Expand Up @@ -675,7 +686,9 @@ func TestEC2RoleProviderInstanceIdentity(t *testing.T) {

op := &operationListProvider{}

c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
c := ec2metadata.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
})
c.Handlers.Complete.PushBack(op.addToOperationPerformedList)
doc, err := c.GetInstanceIdentityDocument()

Expand Down Expand Up @@ -719,7 +732,9 @@ func TestEC2MetadataRetryFailure(t *testing.T) {
server := httptest.NewServer(mux)
defer server.Close()

c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
c := ec2metadata.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
})

c.Handlers.AfterRetry.PushBack(func(i *request.Request) {
t.Logf("%v received, retrying operation %v", i.HTTPResponse.StatusCode, i.Operation.Name)
Expand Down Expand Up @@ -774,7 +789,9 @@ func TestEC2MetadataRetryOnce(t *testing.T) {

server := httptest.NewServer(mux)
defer server.Close()
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
c := ec2metadata.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
})

// Handler on client that logs if retried
c.Handlers.AfterRetry.PushBack(func(i *request.Request) {
Expand Down Expand Up @@ -807,7 +824,9 @@ func TestEC2Metadata_Concurrency(t *testing.T) {
server := newTestServer(t, SecureTestType, ts)
defer server.Close()

c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
c := ec2metadata.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
})

var wg sync.WaitGroup
wg.Add(10)
Expand Down Expand Up @@ -838,11 +857,13 @@ func TestRequestOnMetadata(t *testing.T) {
server := newTestServer(t, SecureTestType, ts)
defer server.Close()

c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
c := ec2metadata.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
})
req := c.NewRequest(&request.Operation{
Name: "Ec2Metadata request",
HTTPMethod: "GET",
HTTPPath: "/latest",
HTTPPath: "/latest/foo",
Paginator: nil,
BeforePresignFn: nil,
}, nil, nil)
Expand Down Expand Up @@ -878,7 +899,9 @@ func TestExhaustiveRetryToFetchToken(t *testing.T) {

op := &operationListProvider{}

c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
c := ec2metadata.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
})
c.Handlers.Complete.PushBack(op.addToOperationPerformedList)

resp, err := c.GetMetadata("/some/path")
Expand Down Expand Up @@ -930,7 +953,9 @@ func TestExhaustiveRetryWith401(t *testing.T) {

op := &operationListProvider{}

c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
c := ec2metadata.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
})
c.Handlers.Complete.PushBack(op.addToOperationPerformedList)

resp, err := c.GetMetadata("/some/path")
Expand Down Expand Up @@ -991,7 +1016,9 @@ func TestRequestTimeOut(t *testing.T) {

op := &operationListProvider{}

c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
c := ec2metadata.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
})
// for test, change the timeout to 100 ms
c.Config.HTTPClient.Timeout = 100 * time.Millisecond

Expand Down Expand Up @@ -1068,7 +1095,9 @@ func TestTokenExpiredBehavior(t *testing.T) {

op := &operationListProvider{}

c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
c := ec2metadata.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
})
c.Handlers.Complete.PushBack(op.addToOperationPerformedList)

resp, err := c.GetMetadata("/some/path")
Expand Down
17 changes: 17 additions & 0 deletions aws/ec2metadata/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,18 @@
// variable "AWS_EC2_METADATA_DISABLED=true". This environment variable set to
// true instructs the SDK to disable the EC2 Metadata client. The client cannot
// be used while the environment variable is set to true, (case insensitive).
//
// The endpoint of the EC2 IMDS client can be configured via the environment
// variable, AWS_EC2_METADATA_SERVICE_ENDPOINT when creating the client with a
// Session. See aws/session#Options.EC2IMDSEndpoint for more details.
package ec2metadata

import (
"bytes"
"errors"
"io"
"net/http"
"net/url"
"os"
"strconv"
"strings"
Expand Down Expand Up @@ -69,6 +74,9 @@ func New(p client.ConfigProvider, cfgs ...*aws.Config) *EC2Metadata {
// a client when not using a session. Generally using just New with a session
// is preferred.
//
// Will remove the URL path from the endpoint provided to ensure the EC2 IMDS
// client is able to communicate with the EC2 IMDS API.
//
// If an unmodified HTTP client is provided from the stdlib default, or no client
// the EC2RoleProvider's EC2Metadata HTTP client's timeout will be shortened.
// To disable this set Config.EC2MetadataDisableTimeoutOverride to false. Enabled by default.
Expand All @@ -86,6 +94,15 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
cfg.MaxRetries = aws.Int(2)
}

if u, err := url.Parse(endpoint); err == nil {
// Remove path from the endpoint since it will be added by requests.
// This is an artifact of the SDK adding `/latest` to the endpoint for
// EC2 IMDS, but this is now moved to the operation definition.
u.Path = ""
u.RawPath = ""
endpoint = u.String()
}

svc := &EC2Metadata{
Client: client.New(
cfg,
Expand Down
37 changes: 34 additions & 3 deletions aws/ec2metadata/service_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
// +build go1.7

package ec2metadata_test

import (
Expand Down Expand Up @@ -89,9 +91,7 @@ func TestClientDisableIMDS(t *testing.T) {

os.Setenv("AWS_EC2_METADATA_DISABLED", "true")

svc := ec2metadata.New(unit.Session, &aws.Config{
LogLevel: aws.LogLevel(aws.LogDebugWithHTTPBody),
})
svc := ec2metadata.New(unit.Session)
resp, err := svc.GetUserData()
if err == nil {
t.Fatalf("expect error, got none")
Expand All @@ -109,6 +109,37 @@ func TestClientDisableIMDS(t *testing.T) {
}
}

func TestClientStripPath(t *testing.T) {
cases := map[string]struct {
Endpoint string
Expect string
}{
"no change": {
Endpoint: "http://example.aws",
Expect: "http://example.aws",
},
"strip path": {
Endpoint: "http://example.aws/foo",
Expect: "http://example.aws",
},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()

svc := ec2metadata.New(unit.Session, &aws.Config{
Endpoint: aws.String(c.Endpoint),
})

if e, a := c.Expect, svc.ClientInfo.Endpoint; e != a {
t.Errorf("expect %v endpoint, got %v", e, a)
}
})
}
}

func runEC2MetadataClients(t *testing.T, cfg *aws.Config, atOnce int) {
var wg sync.WaitGroup
wg.Add(atOnce)
Expand Down