From b0869444b859dbc984541fa9712a11912518a198 Mon Sep 17 00:00:00 2001 From: David Pait Date: Thu, 20 Aug 2020 10:45:23 -0400 Subject: [PATCH] move IMDSv2 401 retries to Request function. move Panic to main function. add tests for 401 retries. --- cmd/node-termination-handler.go | 19 ++++++++++-- pkg/ec2metadata/ec2metadata.go | 46 ++++++++++++++++++----------- pkg/ec2metadata/ec2metadata_test.go | 31 +++++++++++++++++++ 3 files changed, 75 insertions(+), 21 deletions(-) diff --git a/cmd/node-termination-handler.go b/cmd/node-termination-handler.go index d8e59b36..5b4582bd 100644 --- a/cmd/node-termination-handler.go +++ b/cmd/node-termination-handler.go @@ -34,9 +34,10 @@ import ( ) const ( - scheduledMaintenance = "Scheduled Maintenance" - spotITN = "Spot ITN" - timeFormat = "2006/01/02 15:04:05" + scheduledMaintenance = "Scheduled Maintenance" + spotITN = "Spot ITN" + timeFormat = "2006/01/02 15:04:05" + duplicateErrThreshold = 3 ) func main() { @@ -100,11 +101,23 @@ func main() { for _, fn := range monitoringFns { go func(monitor monitor.Monitor) { log.Log().Msgf("Started monitoring for %s events", monitor.Kind()) + var previousErr error + var duplicateErrCount int for range time.Tick(time.Second * 2) { err := monitor.Monitor() if err != nil { log.Log().Msgf("There was a problem monitoring for %s events: %v", monitor.Kind(), err) metrics.ErrorEventsInc(monitor.Kind()) + if err == previousErr { + duplicateErrCount++ + } else { + duplicateErrCount = 0 + previousErr = err + } + if duplicateErrCount > duplicateErrThreshold { + log.Log().Msg("Stopping NITH - Duplicate Error Threshold hit.") + panic(fmt.Sprintf("%v",err)) + } } } }(fn) diff --git a/pkg/ec2metadata/ec2metadata.go b/pkg/ec2metadata/ec2metadata.go index 38a991b1..935fafe7 100644 --- a/pkg/ec2metadata/ec2metadata.go +++ b/pkg/ec2metadata/ec2metadata.go @@ -52,6 +52,7 @@ const ( tokenRequestHeader = "X-aws-ec2-metadata-token" tokenTTL = 3600 // 1 hour secondsBeforeTTLRefresh = 15 + tokenRetryAttempts = 2 ) // Service is used to query the EC2 instance metadata service v1 and v2 @@ -181,28 +182,37 @@ func (e *Service) Request(contextPath string) (*http.Response, error) { if err != nil { return nil, fmt.Errorf("Unable to construct an http get request to IDMS for %s: %w", e.metadataURL+contextPath, err) } - if e.v2Token == "" || e.tokenTTL <= secondsBeforeTTLRefresh { - e.Lock() - token, ttl, err := e.getV2Token() + var resp *http.Response + for i := 0; i < tokenRetryAttempts; i++ { + if e.v2Token == "" || e.tokenTTL <= secondsBeforeTTLRefresh { + e.Lock() + token, ttl, err := e.getV2Token() + if err != nil { + e.v2Token = "" + e.tokenTTL = -1 + log.Log().Msgf("Unable to retrieve an IMDSv2 token, continuing with IMDSv1: %v", err) + } else { + e.v2Token = token + e.tokenTTL = ttl + } + e.Unlock() + } + if e.v2Token != "" { + req.Header.Add(tokenRequestHeader, e.v2Token) + } + httpReq := func() (*http.Response, error) { + return e.httpClient.Do(req) + } + resp, err = retry(e.tries, 2*time.Second, httpReq) if err != nil { + return nil, fmt.Errorf("Unable to get a response from IMDS: %w", err) + } + if resp != nil && resp.StatusCode == 401 { e.v2Token = "" - e.tokenTTL = -1 - log.Log().Msgf("Unable to retrieve an IMDSv2 token, continuing with IMDSv1: %v", err) + e.tokenTTL = 0 } else { - e.v2Token = token - e.tokenTTL = ttl + break } - e.Unlock() - } - if e.v2Token != "" { - req.Header.Add(tokenRequestHeader, e.v2Token) - } - httpReq := func() (*http.Response, error) { - return e.httpClient.Do(req) - } - resp, err := retry(e.tries, 2*time.Second, httpReq) - if err != nil { - return nil, fmt.Errorf("Unable to get a response from IMDS: %w", err) } ttl, err := ttlHeaderToInt(resp) if err == nil { diff --git a/pkg/ec2metadata/ec2metadata_test.go b/pkg/ec2metadata/ec2metadata_test.go index 62c930d6..988811db 100644 --- a/pkg/ec2metadata/ec2metadata_test.go +++ b/pkg/ec2metadata/ec2metadata_test.go @@ -114,6 +114,37 @@ func TestRequest500(t *testing.T) { h.Equals(t, 500, resp.StatusCode) } +func TestRequest401(t *testing.T) { + var requestPath string = "/some/path" + + tokenGenerationCounter := 0 + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Add("X-aws-ec2-metadata-token-ttl-seconds", "100") + if req.URL.String() == "/latest/api/token" { + rw.WriteHeader(200) + rw.Write([]byte(`token`)) + return + } + h.Equals(t, req.URL.String(), requestPath) + if tokenGenerationCounter < 1 { + rw.WriteHeader(401) + tokenGenerationCounter++ + } else { + rw.WriteHeader(200) + } + + })) + defer server.Close() + + // Use URL from our local test server + imds := ec2metadata.New(server.URL, 1) + + resp, err := imds.Request(requestPath) + h.Ok(t, err) + h.Equals(t, 200, resp.StatusCode) + h.Equals(t, 1, tokenGenerationCounter) +} + func TestRequestConstructFail(t *testing.T) { imds := ec2metadata.New("test", 0)