Skip to content

Commit

Permalink
feat: add flag to disable IMDSv1 fallback (#4748)
Browse files Browse the repository at this point in the history
  • Loading branch information
aajtodd committed Mar 13, 2023
1 parent f9c4a37 commit 6c34cf0
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 42 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG_PENDING.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,7 @@

### SDK Enhancements

* `aws/ec2metadata`: Added an option to disable fallback to IMDSv1.
* When set the SDK will no longer fallback to IMDSv1 when fetching a token fails. Use `aws.WithEC2MetadataDisableFallback` to enable.

### SDK Bugs
64 changes: 46 additions & 18 deletions aws/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ type RequestRetryer interface{}
// A Config provides service configuration for service clients. By default,
// all clients will use the defaults.DefaultConfig structure.
//
// // Create Session with MaxRetries configuration to be shared by multiple
// // service clients.
// sess := session.Must(session.NewSession(&aws.Config{
// MaxRetries: aws.Int(3),
// }))
// // Create Session with MaxRetries configuration to be shared by multiple
// // service clients.
// sess := session.Must(session.NewSession(&aws.Config{
// MaxRetries: aws.Int(3),
// }))
//
// // Create S3 service client with a specific Region.
// svc := s3.New(sess, &aws.Config{
// Region: aws.String("us-west-2"),
// })
// // Create S3 service client with a specific Region.
// svc := s3.New(sess, &aws.Config{
// Region: aws.String("us-west-2"),
// })
type Config struct {
// Enables verbose error printing of all credential chain errors.
// Should be used when wanting to see all errors while attempting to
Expand Down Expand Up @@ -192,6 +192,23 @@ type Config struct {
//
EC2MetadataDisableTimeoutOverride *bool

// Set this to `false` to disable EC2Metadata client from falling back to IMDSv1.
// By default, EC2 role credentials will fall back to IMDSv1 as needed for backwards compatibility.
// You can disable this behavior by explicitly setting this flag to `false`. When false, the EC2Metadata
// client will return any errors encountered from attempting to fetch a token instead of silently
// using the insecure data flow of IMDSv1.
//
// Example:
// sess := session.Must(session.NewSession(aws.NewConfig()
// .WithEC2MetadataEnableFallback(false)))
//
// svc := s3.New(sess)
//
// See [configuring IMDS] for more information.
//
// [configuring IMDS]: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html
EC2MetadataEnableFallback *bool

// Instructs the endpoint to be generated for a service client to
// be the dual stack endpoint. The dual stack endpoint will support
// both IPv4 and IPv6 addressing.
Expand Down Expand Up @@ -283,16 +300,16 @@ type Config struct {
// NewConfig returns a new Config pointer that can be chained with builder
// methods to set multiple configuration values inline without using pointers.
//
// // Create Session with MaxRetries configuration to be shared by multiple
// // service clients.
// sess := session.Must(session.NewSession(aws.NewConfig().
// WithMaxRetries(3),
// ))
// // Create Session with MaxRetries configuration to be shared by multiple
// // service clients.
// sess := session.Must(session.NewSession(aws.NewConfig().
// WithMaxRetries(3),
// ))
//
// // Create S3 service client with a specific Region.
// svc := s3.New(sess, aws.NewConfig().
// WithRegion("us-west-2"),
// )
// // Create S3 service client with a specific Region.
// svc := s3.New(sess, aws.NewConfig().
// WithRegion("us-west-2"),
// )
func NewConfig() *Config {
return &Config{}
}
Expand Down Expand Up @@ -432,6 +449,13 @@ func (c *Config) WithEC2MetadataDisableTimeoutOverride(enable bool) *Config {
return c
}

// WithEC2MetadataEnableFallback sets a config EC2MetadataEnableFallback value
// returning a Config pointer for chaining.
func (c *Config) WithEC2MetadataEnableFallback(v bool) *Config {
c.EC2MetadataEnableFallback = &v
return c
}

// WithSleepDelay overrides the function used to sleep while waiting for the
// next retry. Defaults to time.Sleep.
func (c *Config) WithSleepDelay(fn func(time.Duration)) *Config {
Expand Down Expand Up @@ -576,6 +600,10 @@ func mergeInConfig(dst *Config, other *Config) {
dst.EC2MetadataDisableTimeoutOverride = other.EC2MetadataDisableTimeoutOverride
}

if other.EC2MetadataEnableFallback != nil {
dst.EC2MetadataEnableFallback = other.EC2MetadataEnableFallback
}

if other.SleepDelay != nil {
dst.SleepDelay = other.SleepDelay
}
Expand Down
42 changes: 34 additions & 8 deletions aws/ec2metadata/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,9 @@ const (
NotFoundRequestTestType
InvalidTokenRequestTestType
ServerErrorForTokenTestType
pageNotFoundForTokenTestType
pageNotFoundWith401TestType
PageNotFoundForTokenTestType
PageNotFoundWith401TestType
ThrottleErrorForTokenNoFallbackTestType
)

type testServer struct {
Expand Down Expand Up @@ -126,12 +127,15 @@ func newTestServer(t *testing.T, testType testType, testServer *testServer) *htt
case ServerErrorForTokenTestType:
mux.HandleFunc("/latest/api/token", getTokenRequiredParams(t, testServer.serverErrorGetTokenHandler))
mux.HandleFunc("/", testServer.insecureGetLatestHandler)
case pageNotFoundForTokenTestType:
case PageNotFoundForTokenTestType:
mux.HandleFunc("/latest/api/token", getTokenRequiredParams(t, testServer.pageNotFoundGetTokenHandler))
mux.HandleFunc("/", testServer.insecureGetLatestHandler)
case pageNotFoundWith401TestType:
case PageNotFoundWith401TestType:
mux.HandleFunc("/latest/api/token", getTokenRequiredParams(t, testServer.pageNotFoundGetTokenHandler))
mux.HandleFunc("/", testServer.unauthorizedGetLatestHandler)
case ThrottleErrorForTokenNoFallbackTestType:
mux.HandleFunc("/latest/api/token", getTokenRequiredParams(t, testServer.throtleErrorGetTokenHandler))
mux.HandleFunc("/", testServer.unauthorizedGetLatestHandler)

}

Expand Down Expand Up @@ -213,6 +217,10 @@ func (s *testServer) unauthorizedGetLatestHandler(w http.ResponseWriter, r *http
http.Error(w, "", 401)
}

func (s *testServer) throtleErrorGetTokenHandler(w http.ResponseWriter, r *http.Request) {
http.Error(w, "", 429)
}

func (opListProvider *operationListProvider) addToOperationPerformedList(r *request.Request) {
opListProvider.operationsPerformed = append(opListProvider.operationsPerformed, r.Operation.Name)
}
Expand Down Expand Up @@ -241,6 +249,7 @@ func TestGetMetadata(t *testing.T) {
expectedData string
expectedError string
expectedOperationsAttempted []string
enableImdsFallback *bool
}{
"Insecure server success case": {
NewServer: func(t *testing.T, tokens []string) *httptest.Server {
Expand Down Expand Up @@ -325,6 +334,21 @@ func TestGetMetadata(t *testing.T) {
expectedData: "IMDSProfileForGoSDK",
expectedOperationsAttempted: []string{"GetToken", "GetMetadata", "GetMetadata"},
},
"No fallback to IMDSv1": {
NewServer: func(t *testing.T, tokens []string) *httptest.Server {
testType := ThrottleErrorForTokenNoFallbackTestType
Ts := &testServer{
t: t,
tokens: []string{},
data: "IMDSProfileForGoSDK",
}
return newTestServer(t, testType, Ts)
},
expectedError: "failed to get IMDSv2 token and fallback to IMDSv1 is disabled",
// 2 attempts + 2 retries per/attempt
expectedOperationsAttempted: []string{"GetToken", "GetToken", "GetToken", "GetToken", "GetToken", "GetToken"},
enableImdsFallback: aws.Bool(false),
},
}

for name, x := range cases {
Expand All @@ -336,8 +360,10 @@ func TestGetMetadata(t *testing.T) {
op := &operationListProvider{}

c := ec2metadata.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
Endpoint: aws.String(server.URL),
EC2MetadataEnableFallback: x.enableImdsFallback,
})

c.Handlers.CompleteAttempt.PushBack(op.addToOperationPerformedList)

tokenCounter := -1
Expand Down Expand Up @@ -953,7 +979,7 @@ func TestExhaustiveRetryToFetchToken(t *testing.T) {
data: "IMDSProfileForSDKGo",
}

server := newTestServer(t, pageNotFoundForTokenTestType, ts)
server := newTestServer(t, PageNotFoundForTokenTestType, ts)
defer server.Close()

op := &operationListProvider{}
Expand Down Expand Up @@ -1007,7 +1033,7 @@ func TestExhaustiveRetryWith401(t *testing.T) {
data: "IMDSProfileForSDKGo",
}

server := newTestServer(t, pageNotFoundWith401TestType, ts)
server := newTestServer(t, PageNotFoundWith401TestType, ts)
defer server.Close()

op := &operationListProvider{}
Expand Down Expand Up @@ -1117,7 +1143,7 @@ func TestRequestTimeOut(t *testing.T) {
t.Fatalf("Expected no error, got %v", err)
}

expectedOperationsPerformed = []string{"GetToken", "GetMetadata", "GetMetadata"}
expectedOperationsPerformed = []string{"GetToken", "GetMetadata", "GetToken", "GetMetadata"}
if e, a := expectedOperationsPerformed, op.operationsPerformed; !reflect.DeepEqual(e, a) {
t.Fatalf("expect %v operations, got %v", e, a)
}
Expand Down
10 changes: 5 additions & 5 deletions aws/ec2metadata/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,13 @@ type EC2Metadata struct {
// New creates a new instance of the EC2Metadata client with a session.
// This client is safe to use across multiple goroutines.
//
//
// Example:
// // Create a EC2Metadata client from just a session.
// svc := ec2metadata.New(mySession)
//
// // Create a EC2Metadata client with additional configuration
// svc := ec2metadata.New(mySession, aws.NewConfig().WithLogLevel(aws.LogDebugHTTPBody))
// // Create a EC2Metadata client from just a session.
// svc := ec2metadata.New(mySession)
//
// // Create a EC2Metadata client with additional configuration
// svc := ec2metadata.New(mySession, aws.NewConfig().WithLogLevel(aws.LogDebugHTTPBody))
func New(p client.ConfigProvider, cfgs ...*aws.Config) *EC2Metadata {
c := p.ClientConfig(ServiceName, cfgs...)
return NewClient(*c.Config, c.Handlers, c.Endpoint, c.SigningRegion)
Expand Down
25 changes: 14 additions & 11 deletions aws/ec2metadata/token_provider.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ec2metadata

import (
"fmt"
"net/http"
"sync/atomic"
"time"
Expand Down Expand Up @@ -33,11 +34,15 @@ func newTokenProvider(c *EC2Metadata, duration time.Duration) *tokenProvider {
return &tokenProvider{client: c, configuredTTL: duration}
}

// check if fallback is enabled
func (t *tokenProvider) fallbackEnabled() bool {
return t.client.Config.EC2MetadataEnableFallback == nil || *t.client.Config.EC2MetadataEnableFallback
}

// fetchTokenHandler fetches token for EC2Metadata service client by default.
func (t *tokenProvider) fetchTokenHandler(r *request.Request) {

// short-circuits to insecure data flow if tokenProvider is disabled.
if v := atomic.LoadUint32(&t.disabled); v == 1 {
if v := atomic.LoadUint32(&t.disabled); v == 1 && t.fallbackEnabled() {
return
}

Expand All @@ -49,23 +54,21 @@ func (t *tokenProvider) fetchTokenHandler(r *request.Request) {
output, err := t.client.getToken(r.Context(), t.configuredTTL)

if err != nil {
// only attempt fallback to insecure data flow if IMDSv1 is enabled
if !t.fallbackEnabled() {
r.Error = awserr.New("EC2MetadataError", "failed to get IMDSv2 token and fallback to IMDSv1 is disabled", err)
return
}

// change the disabled flag on token provider to true,
// when error is request timeout error.
// change the disabled flag on token provider to true and fallback
if requestFailureError, ok := err.(awserr.RequestFailure); ok {
switch requestFailureError.StatusCode() {
case http.StatusForbidden, http.StatusNotFound, http.StatusMethodNotAllowed:
atomic.StoreUint32(&t.disabled, 1)
t.client.Config.Logger.Log(fmt.Sprintf("WARN: failed to get session token, falling back to IMDSv1: %v", requestFailureError))
case http.StatusBadRequest:
r.Error = requestFailureError
}

// Check if request timed out while waiting for response
if e, ok := requestFailureError.OrigErr().(awserr.Error); ok {
if e.Code() == request.ErrCodeRequestError {
atomic.StoreUint32(&t.disabled, 1)
}
}
}
return
}
Expand Down

0 comments on commit 6c34cf0

Please sign in to comment.