diff --git a/agent/acs/handler/acs_handler.go b/agent/acs/handler/acs_handler.go index 869701b9d6..ec3882d8a5 100644 --- a/agent/acs/handler/acs_handler.go +++ b/agent/acs/handler/acs_handler.go @@ -37,6 +37,7 @@ import ( rolecredentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" "github.com/aws/amazon-ecs-agent/ecs-agent/eventstream" + "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" "github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry" "github.com/aws/amazon-ecs-agent/ecs-agent/utils/ttime" "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" @@ -78,8 +79,8 @@ const ( acsProtocolVersion = 2 // numOfHandlersSendingAcks is the number of handlers that send acks back to ACS and that are not saved across // sessions. We use this to send pending acks, before agent initiates a disconnect to ACS. - // they are: refreshCredentialsHandler, taskManifestHandler, and payloadHandler - numOfHandlersSendingAcks = 3 + // they are: taskManifestHandler, and payloadHandler + numOfHandlersSendingAcks = 2 ) // Session defines an interface for handler's long-lived connection with ACS. @@ -250,13 +251,7 @@ func (acsSession *session) startSessionOnce() error { func (acsSession *session) startACSSession(client wsclient.ClientServer) error { cfg := acsSession.agentConfig - refreshCredsHandler := newRefreshCredentialsHandler(acsSession.ctx, cfg.Cluster, acsSession.containerInstanceARN, - client, acsSession.credentialsManager, acsSession.taskEngine) - defer refreshCredsHandler.clearAcks() - refreshCredsHandler.start() - defer refreshCredsHandler.stop() - - client.AddRequestHandler(refreshCredsHandler.handlerFunc()) + credsMetadataSetter := &credentialsMetadataSetter{taskEngine: acsSession.taskEngine} eniHandler := &eniHandler{ state: acsSession.state, @@ -265,6 +260,8 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error { manifestMessageIDAccessor := &manifestMessageIDAccessor{} + metricsFactory := metrics.NewNopEntryFactory() + // Add TaskManifestHandler taskManifestHandler := newTaskManifestHandler(acsSession.ctx, cfg.Cluster, acsSession.containerInstanceARN, client, acsSession.dataClient, acsSession.taskEngine, acsSession.latestSeqNumTaskManifest, @@ -286,7 +283,6 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error { acsSession.containerInstanceARN, client, acsSession.dataClient, - refreshCredsHandler, acsSession.credentialsManager, acsSession.taskHandler, acsSession.latestSeqNumTaskManifest) // Clear the acks channel on return because acks of messageids don't have any value across sessions @@ -300,6 +296,8 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error { return client.MakeRequest(response) } responders := []wsclient.RequestResponder{ + acssession.NewRefreshCredentialsResponder(acsSession.credentialsManager, credsMetadataSetter, metricsFactory, + responseSender), acssession.NewAttachTaskENIResponder(eniHandler, responseSender), acssession.NewAttachInstanceENIResponder(eniHandler, responseSender), acssession.NewHeartbeatResponder(acsSession.doctor, responseSender), @@ -320,7 +318,7 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error { // Start a connection timer; agent will send pending acks and close its ACS websocket connection // after this timer expires connectionTimer := newConnectionTimer(client, acsSession.connectionTime, acsSession.connectionJitter, - &refreshCredsHandler, &taskManifestHandler, &payloadHandler) + &taskManifestHandler, &payloadHandler) defer connectionTimer.Stop() // Start a heartbeat timer for closing the connection @@ -416,7 +414,6 @@ func newConnectionTimer( client wsclient.ClientServer, connectionTime time.Duration, connectionJitter time.Duration, - refreshCredsHandler *refreshCredentialsHandler, taskManifestHandler *taskManifestHandler, payloadHandler *payloadRequestHandler, ) ttime.Timer { @@ -427,12 +424,6 @@ func newConnectionTimer( wg := sync.WaitGroup{} wg.Add(numOfHandlersSendingAcks) - // send pending creds refresh acks to ACS - go func() { - refreshCredsHandler.sendPendingAcks() - wg.Done() - }() - // send pending task manifest acks and task stop verification acks to ACS go func() { taskManifestHandler.sendPendingTaskManifestMessageAck() diff --git a/agent/acs/handler/acs_handler_test.go b/agent/acs/handler/acs_handler_test.go index cd9f35f924..96300a8416 100644 --- a/agent/acs/handler/acs_handler_test.go +++ b/agent/acs/handler/acs_handler_test.go @@ -1162,8 +1162,6 @@ func TestStartSessionHandlesRefreshCredentialsMessages(t *testing.T) { // Ensure that credentials manager interface methods are invoked in the // correct order, with expected arguments gomock.InOrder( - // Return a task from the engine for GetTaskByArn - taskEngine.EXPECT().GetTaskByArn("t1").Return(taskFromEngine, true), // The last invocation of SetCredentials is to update // credentials when a refresh message is received by the handler credentialsManager.EXPECT().SetTaskCredentials(gomock.Any()).Do(func(creds *rolecredentials.TaskIAMRoleCredentials) { @@ -1185,6 +1183,8 @@ func TestStartSessionHandlesRefreshCredentialsMessages(t *testing.T) { t.Errorf("Mismatch between expected and credentials expected: %v, added: %v", expectedCreds, updatedCredentials) } }).Return(nil), + // Return a task from the engine for GetTaskByArn + taskEngine.EXPECT().GetTaskByArn("t1").Return(taskFromEngine, true), ) serverIn <- sampleRefreshCredentialsMessage diff --git a/agent/acs/handler/attach_eni_handler_common_test.go b/agent/acs/handler/attach_eni_handler_common_test.go index 8694aae22f..6ac9fbaefc 100644 --- a/agent/acs/handler/attach_eni_handler_common_test.go +++ b/agent/acs/handler/attach_eni_handler_common_test.go @@ -54,7 +54,7 @@ func testENIAckTimeout(t *testing.T, attachmentType string) { expiresAt := time.Now().Add(time.Millisecond * testconst.WaitTimeoutMillis) eniAttachment := &apieni.ENIAttachment{ AttachmentInfo: attachmentinfo.AttachmentInfo{ - TaskARN: taskArn, + TaskARN: testconst.TaskARN, AttachmentARN: attachmentArn, ExpiresAt: expiresAt, AttachStatusSent: false, @@ -103,7 +103,7 @@ func testENIAckWithinTimeout(t *testing.T, attachmentType string) { expiresAt := time.Now().Add(time.Millisecond * testconst.WaitTimeoutMillis) eniAttachment := &apieni.ENIAttachment{ AttachmentInfo: attachmentinfo.AttachmentInfo{ - TaskARN: taskArn, + TaskARN: testconst.TaskARN, AttachmentARN: attachmentArn, ExpiresAt: expiresAt, AttachStatusSent: false, @@ -130,7 +130,7 @@ func testENIAckWithinTimeout(t *testing.T, attachmentType string) { // TestHandleENIAttachmentTaskENI tests handling a new task eni func TestHandleENIAttachmentTaskENI(t *testing.T) { - testHandleENIAttachment(t, apieni.ENIAttachmentTypeTaskENI, taskArn) + testHandleENIAttachment(t, apieni.ENIAttachmentTypeTaskENI, testconst.TaskARN) } // TestHandleENIAttachmentInstanceENI tests handling a new instance eni @@ -178,7 +178,7 @@ func testHandleENIAttachment(t *testing.T, attachmentType, taskArn string) { // TestHandleExpiredENIAttachmentTaskENI tests handling an expired task eni func TestHandleExpiredENIAttachmentTaskENI(t *testing.T) { - testHandleExpiredENIAttachment(t, apieni.ENIAttachmentTypeTaskENI, taskArn) + testHandleExpiredENIAttachment(t, apieni.ENIAttachmentTypeTaskENI, testconst.TaskARN) } // TestHandleExpiredENIAttachmentInstanceENI tests handling an expired instance eni diff --git a/agent/acs/handler/payload_handler.go b/agent/acs/handler/payload_handler.go index d606a5c71d..9afcfe58cb 100644 --- a/agent/acs/handler/payload_handler.go +++ b/agent/acs/handler/payload_handler.go @@ -51,7 +51,6 @@ type payloadRequestHandler struct { cluster string containerInstanceArn string acsClient wsclient.ClientServer - refreshHandler refreshCredentialsHandler credentialsManager credentials.Manager latestSeqNumberTaskManifest *int64 } @@ -65,7 +64,6 @@ func newPayloadRequestHandler( containerInstanceArn string, acsClient wsclient.ClientServer, dataClient data.Client, - refreshHandler refreshCredentialsHandler, credentialsManager credentials.Manager, taskHandler *eventhandler.TaskHandler, seqNumTaskManifest *int64) payloadRequestHandler { // Create a cancelable context from the parent context @@ -82,7 +80,6 @@ func newPayloadRequestHandler( cluster: cluster, containerInstanceArn: containerInstanceArn, acsClient: acsClient, - refreshHandler: refreshHandler, credentialsManager: credentialsManager, latestSeqNumberTaskManifest: seqNumTaskManifest, } @@ -187,7 +184,7 @@ func (payloadHandler *payloadRequestHandler) handleSingleMessage(payload *ecsacs go func() { // Throw the ack in async; it doesn't really matter all that much and this is blocking handling more tasks. for _, credentialsAck := range credentialsAcks { - payloadHandler.refreshHandler.ackMessage(credentialsAck) + payloadHandler.makeCredentialsAckRequest(credentialsAck) } payloadHandler.ackRequest <- *payload.MessageId }() @@ -195,6 +192,16 @@ func (payloadHandler *payloadRequestHandler) handleSingleMessage(payload *ecsacs return nil } +// makeCredentialsAckRequest sends an IAMRoleCredentialsAckRequest to the backend +func (payloadHandler *payloadRequestHandler) makeCredentialsAckRequest(ack *ecsacs.IAMRoleCredentialsAckRequest) { + seelog.Debugf("ACKing credentials associated with ACS payload message: %s", ack.String()) + err := payloadHandler.acsClient.MakeRequest(ack) + if err != nil { + seelog.Warnf("Error ACKing credentials with credentialsID '%s' associated with ACS payload message, error: %v", + aws.StringValue(ack.CredentialsId), err) + } +} + // addPayloadTasks does validation on each task and, for all valid ones, adds // it to the task engine. It returns a bool indicating if it could add every // task to the taskEngine and a slice of credential ack requests diff --git a/agent/acs/handler/payload_handler_test.go b/agent/acs/handler/payload_handler_test.go index ff9d69240a..ad5b88031c 100644 --- a/agent/acs/handler/payload_handler_test.go +++ b/agent/acs/handler/payload_handler_test.go @@ -80,7 +80,6 @@ func setup(t *testing.T) *testHelper { testconst.ContainerInstanceARN, mockWsClient, data.NewNoopClient(), - refreshCredentialsHandler{}, credentialsManager, taskHandler, &latestSeqNumberTaskManifest) @@ -307,11 +306,6 @@ func TestHandlePayloadMessageCredentialsAckedWhenTaskAdded(t *testing.T) { }), ) - refreshCredsHandler := newRefreshCredentialsHandler(tester.ctx, testconst.ClusterName, testconst.ContainerInstanceARN, tester.mockWsClient, tester.credentialsManager, tester.mockTaskEngine) - defer refreshCredsHandler.clearAcks() - refreshCredsHandler.start() - tester.payloadHandler.refreshHandler = refreshCredsHandler - go tester.payloadHandler.start() taskArn := "t1" @@ -338,8 +332,8 @@ func TestHandlePayloadMessageCredentialsAckedWhenTaskAdded(t *testing.T) { }, }, MessageId: aws.String(payloadMessageId), - ClusterArn: aws.String(cluster), - ContainerInstanceArn: aws.String(containerInstance), + ClusterArn: aws.String(testconst.ClusterName), + ContainerInstanceArn: aws.String(testconst.ContainerInstanceARN), } err := tester.payloadHandler.handleSingleMessage(payloadMessage) assert.NoError(t, err, "error handling payload message") @@ -496,11 +490,6 @@ func TestPayloadBufferHandlerWithCredentials(t *testing.T) { }), ) - refreshCredsHandler := newRefreshCredentialsHandler(tester.ctx, testconst.ClusterName, testconst.ContainerInstanceARN, tester.mockWsClient, tester.credentialsManager, tester.mockTaskEngine) - defer refreshCredsHandler.clearAcks() - refreshCredsHandler.start() - tester.payloadHandler.refreshHandler = refreshCredsHandler - go tester.payloadHandler.start() firstTaskArn := "t1" @@ -546,8 +535,8 @@ func TestPayloadBufferHandlerWithCredentials(t *testing.T) { }, }, MessageId: aws.String(payloadMessageId), - ClusterArn: aws.String(cluster), - ContainerInstanceArn: aws.String(containerInstance), + ClusterArn: aws.String(testconst.ClusterName), + ContainerInstanceArn: aws.String(testconst.ContainerInstanceARN), } // Wait till we get an ack @@ -618,11 +607,7 @@ func TestAddPayloadTaskAddsExecutionRoles(t *testing.T) { tester.cancel() }), ) - refreshCredsHandler := newRefreshCredentialsHandler(tester.ctx, testconst.ClusterName, testconst.ContainerInstanceARN, tester.mockWsClient, tester.credentialsManager, tester.mockTaskEngine) - defer refreshCredsHandler.clearAcks() - refreshCredsHandler.start() - tester.payloadHandler.refreshHandler = refreshCredsHandler go tester.payloadHandler.start() taskArn := "t1" credentialsExpiration := "expiration" diff --git a/agent/acs/handler/refresh_credentials_handler.go b/agent/acs/handler/refresh_credentials_handler.go deleted file mode 100644 index 8236593655..0000000000 --- a/agent/acs/handler/refresh_credentials_handler.go +++ /dev/null @@ -1,240 +0,0 @@ -// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may -// not use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -// express or implied. See the License for the specific language governing -// permissions and limitations under the License. -package handler - -import ( - "context" - "fmt" - - "github.com/aws/amazon-ecs-agent/agent/engine" - "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" - "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" - "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" - "github.com/aws/aws-sdk-go/aws" - "github.com/cihub/seelog" - - "github.com/pkg/errors" -) - -var ( - // For ease of unit testing - checkAndSetDomainlessGMSATaskExecutionRoleCredentialsImpl = checkAndSetDomainlessGMSATaskExecutionRoleCredentials -) - -// refreshCredentialsHandler represents the refresh credentials operation for the ACS client -type refreshCredentialsHandler struct { - // messageBuffer is used to process IAMRoleCredentialsMessages received from the server - messageBuffer chan *ecsacs.IAMRoleCredentialsMessage - // ackRequest is used to send acks to the backend - ackRequest chan *ecsacs.IAMRoleCredentialsAckRequest - ctx context.Context - // cancel is used to stop go routines started by start() method - cancel context.CancelFunc - cluster *string - containerInstance *string - acsClient wsclient.ClientServer - credentialsManager credentials.Manager - taskEngine engine.TaskEngine -} - -// newRefreshCredentialsHandler returns a new refreshCredentialsHandler object -func newRefreshCredentialsHandler(ctx context.Context, cluster string, containerInstanceArn string, acsClient wsclient.ClientServer, credentialsManager credentials.Manager, taskEngine engine.TaskEngine) refreshCredentialsHandler { - // Create a cancelable context from the parent context - derivedContext, cancel := context.WithCancel(ctx) - return refreshCredentialsHandler{ - messageBuffer: make(chan *ecsacs.IAMRoleCredentialsMessage), - ackRequest: make(chan *ecsacs.IAMRoleCredentialsAckRequest), - ctx: derivedContext, - cancel: cancel, - cluster: aws.String(cluster), - containerInstance: aws.String(containerInstanceArn), - acsClient: acsClient, - credentialsManager: credentialsManager, - taskEngine: taskEngine, - } -} - -// handlerFunc returns the request handler function for the ecsacs.IAMRoleCredentialsMessage -func (refreshHandler *refreshCredentialsHandler) handlerFunc() func(message *ecsacs.IAMRoleCredentialsMessage) { - // return a function that just enqueues IAMRoleCredentials messages into the message buffer - return func(message *ecsacs.IAMRoleCredentialsMessage) { - refreshHandler.messageBuffer <- message - } -} - -// start invokes go routines to: -// 1. handle messages in the refresh credentials message buffer -// 2. handle ack requests to be sent to ACS -func (refreshHandler *refreshCredentialsHandler) start() { - go refreshHandler.handleMessages() - go refreshHandler.sendAcks() -} - -// stop cancels the context being used by the refresh credentials handler. This is used -// to stop the go routines started by 'start()' -func (refreshHandler *refreshCredentialsHandler) stop() { - refreshHandler.cancel() -} - -// sendAcks sends ack requests to ACS -func (refreshHandler *refreshCredentialsHandler) sendAcks() { - for { - select { - case ack := <-refreshHandler.ackRequest: - refreshHandler.ackMessage(ack) - case <-refreshHandler.ctx.Done(): - return - } - } -} - -// sendPendingAcks sends pending acks to ACS before closing the connection -func (refreshHandler *refreshCredentialsHandler) sendPendingAcks() { - for { - select { - case ack := <-refreshHandler.ackRequest: - refreshHandler.ackMessage(ack) - default: - return - } - } -} - -// ackMessageId sends an IAMRoleCredentialsAckRequest to the backend -func (refreshHandler *refreshCredentialsHandler) ackMessage(ack *ecsacs.IAMRoleCredentialsAckRequest) { - err := refreshHandler.acsClient.MakeRequest(ack) - if err != nil { - seelog.Warnf("Error 'ack'ing request with messageID: %s, error: %v", aws.StringValue(ack.MessageId), err) - } - seelog.Debugf("Acking credentials message: %s", ack.String()) -} - -// handleMessages processes refresh credentials messages in the buffer in-order -func (refreshHandler *refreshCredentialsHandler) handleMessages() { - for { - select { - case message := <-refreshHandler.messageBuffer: - refreshHandler.handleSingleMessage(message) - case <-refreshHandler.ctx.Done(): - return - } - } -} - -// handleSingleMessage processes a single refresh credentials message. -func (refreshHandler *refreshCredentialsHandler) handleSingleMessage(message *ecsacs.IAMRoleCredentialsMessage) error { - // Validate fields in the message - err := validateIAMRoleCredentialsMessage(message) - if err != nil { - seelog.Errorf("Error validating credentials message: %v", err) - return err - } - taskArn := aws.StringValue(message.TaskArn) - messageId := aws.StringValue(message.MessageId) - task, ok := refreshHandler.taskEngine.GetTaskByArn(taskArn) - if !ok { - seelog.Errorf("Task not found in the engine for the arn in credentials message, arn: %s, messageId: %s", taskArn, messageId) - return fmt.Errorf("task not found in the engine for the arn in credentials message, arn: %s", taskArn) - } - - roleType := aws.StringValue(message.RoleType) - if !validRoleType(roleType) { - seelog.Errorf("Unknown RoleType for task in credentials message, roleType: %s arn: %s, messageId: %s", roleType, taskArn, messageId) - } else { - iamRoleCredentials := credentials.IAMRoleCredentialsFromACS(message.RoleCredentials, roleType) - err = refreshHandler.credentialsManager.SetTaskCredentials( - &(credentials.TaskIAMRoleCredentials{ - ARN: taskArn, - IAMRoleCredentials: iamRoleCredentials, - })) - if err != nil { - seelog.Errorf("Unable to update credentials for task, err: %v messageId: %s", err, messageId) - return fmt.Errorf("unable to update credentials %v", err) - } - - if roleType == credentials.ApplicationRoleType { - task.SetCredentialsID(aws.StringValue(message.RoleCredentials.CredentialsId)) - } - if roleType == credentials.ExecutionRoleType { - task.SetExecutionRoleCredentialsID(aws.StringValue(message.RoleCredentials.CredentialsId)) - // Refresh domainless gMSA plugin credentials if needed - err = checkAndSetDomainlessGMSATaskExecutionRoleCredentialsImpl(iamRoleCredentials, task) - if err != nil { - seelog.Errorf("Unable to SetDomainlessGMSATaskExecutionRoleCredentials for task %s, err: %v messageId: %s", taskArn, err, messageId) - return errors.Wrap(err, "unable to SetDomainlessGMSATaskExecutionRoleCredentials") - } - } - } - - go func() { - response := &ecsacs.IAMRoleCredentialsAckRequest{ - Expiration: message.RoleCredentials.Expiration, - MessageId: message.MessageId, - CredentialsId: message.RoleCredentials.CredentialsId, - } - refreshHandler.ackRequest <- response - }() - return nil -} - -// validateIAMRoleCredentialsMessage validates fields in the IAMRoleCredentialsMessage -// It returns an error if any of the following fields are not set in the message: -// messageId, taskArn, roleCredentials -func validateIAMRoleCredentialsMessage(message *ecsacs.IAMRoleCredentialsMessage) error { - if message == nil { - return fmt.Errorf("empty credentials message") - } - - messageId := aws.StringValue(message.MessageId) - if messageId == "" { - return fmt.Errorf("message id not set in credentials message") - } - - if aws.StringValue(message.TaskArn) == "" { - return fmt.Errorf("task Arn not set in credentials message") - } - - if message.RoleCredentials == nil { - return fmt.Errorf("role Credentials not set in credentials message: messageId: %s", messageId) - } - - if aws.StringValue(message.RoleCredentials.CredentialsId) == "" { - return fmt.Errorf("role Credentials ID not set in credentials message: messageId: %s", messageId) - } - - return nil -} - -// clearAcks drains the ack request channel -func (refreshHandler *refreshCredentialsHandler) clearAcks() { - for { - select { - case <-refreshHandler.ackRequest: - default: - return - } - } -} - -// validRoleType returns false if the RoleType in the acs refresh payload is not -// one of the expected types. TaskApplication, TaskExecution -func validRoleType(roleType string) bool { - switch roleType { - case credentials.ApplicationRoleType: - return true - case credentials.ExecutionRoleType: - return true - default: - return false - } -} diff --git a/agent/acs/handler/refresh_credentials_handler_test.go b/agent/acs/handler/refresh_credentials_handler_test.go deleted file mode 100644 index 375bce36fb..0000000000 --- a/agent/acs/handler/refresh_credentials_handler_test.go +++ /dev/null @@ -1,460 +0,0 @@ -//go:build unit -// +build unit - -// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may -// not use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -// express or implied. See the License for the specific language governing -// permissions and limitations under the License. -package handler - -import ( - "context" - "fmt" - "reflect" - "sync" - "testing" - "time" - - apicontainer "github.com/aws/amazon-ecs-agent/agent/api/container" - apitask "github.com/aws/amazon-ecs-agent/agent/api/task" - mock_engine "github.com/aws/amazon-ecs-agent/agent/engine/mocks" - "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" - "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/testconst" - "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" - mock_wsclient "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/mock" - - "github.com/aws/aws-sdk-go/aws" - "github.com/golang/mock/gomock" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" -) - -const ( - messageId = "message1" - taskArn = "task1" - cluster = "default" - containerInstance = "instance" - expiration = "soon" - roleArn = "taskrole1" - accessKey = "akid" - secretKey = "secret" - sessionToken = "token" - credentialsId = "credsid" - roleType = "TaskExecution" -) - -var expectedAck = &ecsacs.IAMRoleCredentialsAckRequest{ - Expiration: aws.String(expiration), - MessageId: aws.String(messageId), - CredentialsId: aws.String(credentialsId), -} - -var expectedCredentials = credentials.TaskIAMRoleCredentials{ - ARN: taskArn, - IAMRoleCredentials: credentials.IAMRoleCredentials{ - RoleArn: roleArn, - AccessKeyID: accessKey, - SecretAccessKey: secretKey, - SessionToken: sessionToken, - Expiration: expiration, - CredentialsID: credentialsId, - RoleType: roleType, - }, -} - -var message = &ecsacs.IAMRoleCredentialsMessage{ - MessageId: aws.String(messageId), - TaskArn: aws.String(taskArn), - RoleType: aws.String(roleType), - RoleCredentials: &ecsacs.IAMRoleCredentials{ - RoleArn: aws.String(roleArn), - Expiration: aws.String(expiration), - AccessKeyId: aws.String(accessKey), - SecretAccessKey: aws.String(secretKey), - SessionToken: aws.String(sessionToken), - CredentialsId: aws.String(credentialsId), - }, -} - -// TestValidateRefreshMessageWithNilMessage tests if a validation error -// is returned while validating an empty credentials message -func TestValidateRefreshMessageWithNilMessage(t *testing.T) { - err := validateIAMRoleCredentialsMessage(nil) - if err == nil { - t.Error("Expected validation error validating an empty message") - } -} - -// TestValidateRefreshMessageWithNoMessageId tests if a validation error -// is returned while validating a credentials message with no message id -func TestValidateRefreshMessageWithNoMessageId(t *testing.T) { - message := &ecsacs.IAMRoleCredentialsMessage{} - err := validateIAMRoleCredentialsMessage(message) - if err == nil { - t.Error("Expected validation error validating a message with no message id") - } - message.MessageId = aws.String("") - err = validateIAMRoleCredentialsMessage(message) - if err == nil { - t.Error("Expected validation error validating a message with empty message id") - } -} - -// TestValidateRefreshMessageWithNoRoleCredentials tests if a validation error -// is returned while validating a credentials message with no role credentials -func TestValidateRefreshMessageWithNoRoleCredentials(t *testing.T) { - message := &ecsacs.IAMRoleCredentialsMessage{ - MessageId: aws.String(messageId), - } - err := validateIAMRoleCredentialsMessage(message) - if err == nil { - t.Error("Expected validation error validating a message with no role credentials") - } -} - -// TestValidateRefreshMessageWithNoCredentialsId tests if a valid error -// is returned while validating a credentials message with no credentials id -func TestValidateRefreshMessageWithNoCredentialsId(t *testing.T) { - message := &ecsacs.IAMRoleCredentialsMessage{ - MessageId: aws.String(messageId), - RoleCredentials: &ecsacs.IAMRoleCredentials{}, - } - err := validateIAMRoleCredentialsMessage(message) - if err == nil { - t.Error("Expected validation error validating a message with no credentials id") - } - message.RoleCredentials.CredentialsId = aws.String("") - err = validateIAMRoleCredentialsMessage(message) - if err == nil { - t.Error("Expected validation error validating a message with empty credentials id") - } -} - -// TestValidateRefreshMessageWithNoTaskArn tests if a validation error -// is returned while validating a credentials message with no task arn -func TestValidateRefreshMessageWithNoTaskArn(t *testing.T) { - message := &ecsacs.IAMRoleCredentialsMessage{ - MessageId: aws.String(messageId), - RoleCredentials: &ecsacs.IAMRoleCredentials{ - CredentialsId: aws.String("id"), - }, - } - err := validateIAMRoleCredentialsMessage(message) - if err == nil { - t.Error("Expected validation error validating a message with no task arn") - } -} - -// TestValidateRefreshMessageSuccess tests if a valid credentials message -// is validated without any errors -func TestValidateRefreshMessageSuccess(t *testing.T) { - message := &ecsacs.IAMRoleCredentialsMessage{ - MessageId: aws.String(messageId), - RoleCredentials: &ecsacs.IAMRoleCredentials{ - CredentialsId: aws.String("id"), - }, - TaskArn: aws.String(taskArn), - } - err := validateIAMRoleCredentialsMessage(message) - if err != nil { - t.Errorf("Error validating credentials message: %v", err) - } -} - -// TestInvalidCredentialsMessageNotAcked tests if invalid credential messages -// are not acked -func TestInvalidCredentialsMessageNotAcked(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - credentialsManager := credentials.NewManager() - - ctx, cancel := context.WithCancel(context.Background()) - handler := newRefreshCredentialsHandler(ctx, cluster, containerInstance, nil, credentialsManager, nil) - - // Start a goroutine to listen for acks. Cancelling the context stops the goroutine - go func() { - for { - select { - // We never expect the message to be acked - case <-handler.ackRequest: - t.Fatalf("Received ack when none expected") - case <-ctx.Done(): - return - } - } - }() - - // test adding a credentials message without the MessageId field - message := &ecsacs.IAMRoleCredentialsMessage{} - err := handler.handleSingleMessage(message) - if err == nil { - t.Error("Expected error updating credentials when the message contains no message id") - } - cancel() -} - -// TestCredentialsMessageNotAckedWhenTaskNotFound tests if credential messages -// are not acked when the task arn in the message is not found in the task -// engine -func TestCredentialsMessageNotAckedWhenTaskNotFound(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - credentialsManager := credentials.NewManager() - - taskEngine := mock_engine.NewMockTaskEngine(ctrl) - // Return task not found from the engine for GetTaskByArn - taskEngine.EXPECT().GetTaskByArn(taskArn).Return(nil, false) - - ctx, cancel := context.WithCancel(context.Background()) - handler := newRefreshCredentialsHandler(ctx, cluster, containerInstance, nil, credentialsManager, taskEngine) - - // Start a goroutine to listen for acks. Cancelling the context stops the goroutine - go func() { - for { - select { - // We never expect the message to be acked - case <-handler.ackRequest: - t.Fatalf("Received ack when none expected") - case <-ctx.Done(): - return - } - } - }() - - // Test adding a credentials message without the MessageId field - err := handler.handleSingleMessage(message) - if err == nil { - t.Error("Expected error updating credentials when the message contains unexpected task arn") - } - cancel() -} - -// TestHandleRefreshMessageAckedWhenCredentialsUpdated tests that a credential message -// is ackd when the credentials are updated successfully and the domainless gMSA plugin credentials are updated successfully -func TestHandleRefreshMessageAckedWhenCredentialsUpdated(t *testing.T) { - testCases := []struct { - name string - taskArn string - domainlessGMSATaskExpectedInput bool - containers []*apicontainer.Container - }{ - { - name: "EmptyTaskSucceeds", - taskArn: taskArn, - containers: []*apicontainer.Container{}, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - credentialsManager := credentials.NewManager() - - ctx, cancel := context.WithCancel(context.Background()) - var ackRequested *ecsacs.IAMRoleCredentialsAckRequest - - mockWsClient := mock_wsclient.NewMockClientServer(ctrl) - mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.IAMRoleCredentialsAckRequest) { - ackRequested = ackRequest - cancel() - }).Times(1) - - taskEngine := mock_engine.NewMockTaskEngine(ctrl) - // Return a task from the engine for GetTaskByArn - taskEngine.EXPECT().GetTaskByArn(tc.taskArn).Return(&apitask.Task{Arn: tc.taskArn, Containers: tc.containers}, true) - - checkAndSetDomainlessGMSATaskExecutionRoleCredentialsImpl = func(iamRoleCredentials credentials.IAMRoleCredentials, task *apitask.Task) error { - if tc.taskArn != task.Arn { - return errors.New(fmt.Sprintf("Expected taskArnInput to be %s, instead got %s", tc.taskArn, task.Arn)) - } - - return nil - } - - defer func() { - checkAndSetDomainlessGMSATaskExecutionRoleCredentialsImpl = checkAndSetDomainlessGMSATaskExecutionRoleCredentials - }() - - handler := newRefreshCredentialsHandler(ctx, testconst.ClusterName, testconst.ContainerInstanceARN, mockWsClient, credentialsManager, taskEngine) - go handler.sendAcks() - - // test adding a credentials message without the MessageId field - err := handler.handleSingleMessage(message) - if err != nil { - t.Errorf("Error updating credentials: %v", err) - } - - // Wait till we get an ack from the ackBuffer - select { - case <-ctx.Done(): - } - - if !reflect.DeepEqual(ackRequested, expectedAck) { - t.Errorf("Message between expected and requested ack. Expected: %v, Requested: %v", expectedAck, ackRequested) - } - - creds, exist := credentialsManager.GetTaskCredentials(credentialsId) - if !exist { - t.Errorf("Expected credentials to exist for the task") - } - if !reflect.DeepEqual(creds, expectedCredentials) { - t.Errorf("Mismatch between expected credentials and credentials for task. Expected: %v, got: %v", expectedCredentials, creds) - } - }) - } -} - -// TestCredentialsMessageNotAckedWhenDomainlessGMSACredentialsNotSet tests if credential messages -// are not acked when setting the domainlessGMSA Credentials fails -func TestCredentialsMessageNotAckedWhenDomainlessGMSACredentialsError(t *testing.T) { - testCases := []struct { - name string - taskArn string - containers []*apicontainer.Container - domainlessGMSATaskExpectedInput bool - setDomainlessGMSATaskExecutionRoleCredentialsImplError error - expectedErrorString string - }{ - { - name: "ErrDomainlessTask", - taskArn: taskArn, - containers: []*apicontainer.Container{{CredentialSpecs: []string{"credentialspecdomainless:file://gmsa_gmsa-acct.json"}}}, - domainlessGMSATaskExpectedInput: true, - setDomainlessGMSATaskExecutionRoleCredentialsImplError: errors.New("mock setDomainlessGMSATaskExecutionRoleCredentialsImplError"), - expectedErrorString: "unable to SetDomainlessGMSATaskExecutionRoleCredentials: mock setDomainlessGMSATaskExecutionRoleCredentialsImplError", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - credentialsManager := credentials.NewManager() - - taskEngine := mock_engine.NewMockTaskEngine(ctrl) - // Return a task from the engine for GetTaskByArn - taskEngine.EXPECT().GetTaskByArn(tc.taskArn).Return(&apitask.Task{Arn: tc.taskArn, Containers: tc.containers}, true) - - checkAndSetDomainlessGMSATaskExecutionRoleCredentialsImpl = func(iamRoleCredentials credentials.IAMRoleCredentials, task *apitask.Task) error { - if tc.taskArn != task.Arn { - return errors.New(fmt.Sprintf("Expected taskArnInput to be %s, instead got %s", tc.taskArn, task.Arn)) - } - - return tc.setDomainlessGMSATaskExecutionRoleCredentialsImplError - } - - defer func() { - checkAndSetDomainlessGMSATaskExecutionRoleCredentialsImpl = checkAndSetDomainlessGMSATaskExecutionRoleCredentials - }() - - ctx, cancel := context.WithCancel(context.Background()) - handler := newRefreshCredentialsHandler(ctx, cluster, containerInstance, nil, credentialsManager, taskEngine) - - // Start a goroutine to listen for acks. Cancelling the context stops the goroutine - go func() { - for { - select { - // We never expect the message to be acked - case <-handler.ackRequest: - t.Fatalf("Received ack when none expected") - case <-ctx.Done(): - return - } - } - }() - - err := handler.handleSingleMessage(message) - assert.EqualError(t, err, tc.expectedErrorString) - cancel() - }) - } -} - -func TestRefreshCredentialsHandlerSendPendingAcks(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - ctx := context.TODO() - credentialsManager := credentials.NewManager() - taskEngine := mock_engine.NewMockTaskEngine(ctrl) - mockWSClient := mock_wsclient.NewMockClientServer(ctrl) - mockWSClient.EXPECT().MakeRequest(gomock.Any()).Return(nil).Times(1) - - handler := newRefreshCredentialsHandler(ctx, testconst.ClusterName, testconst.ContainerInstanceARN, mockWSClient, - credentialsManager, taskEngine) - - wg := sync.WaitGroup{} - wg.Add(2) - - // write a dummy ack into the ackRequest - go func() { - handler.ackRequest <- expectedAck - wg.Done() - }() - - // sleep here to ensure that the sending go routine above executes before the receiving one below. if not, then the - // receiving go routine will finish without receiving the ack msg since sendPendingAcks() is non-blocking. - time.Sleep(1 * time.Second) - - go func() { - handler.sendPendingAcks() - wg.Done() - }() - - // wait for both go routines above to finish before we verify that ack channel is empty and exit the test. - // this also ensures that the mock MakeRequest call happened as expected. - wg.Wait() - - // verify that the ackRequest channel is empty - assert.Equal(t, 0, len(handler.ackRequest)) -} - -// TestRefreshCredentialsHandler tests if a credential message is acked when -// the message is sent to the messageBuffer channel -func TestRefreshCredentialsHandler(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - credentialsManager := credentials.NewManager() - - ctx, cancel := context.WithCancel(context.Background()) - mockWsClient := mock_wsclient.NewMockClientServer(ctrl) - var ackRequested *ecsacs.IAMRoleCredentialsAckRequest - mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.IAMRoleCredentialsAckRequest) { - ackRequested = ackRequest - cancel() - }).Times(1) - - taskEngine := mock_engine.NewMockTaskEngine(ctrl) - // Return a task from the engine for GetTaskByArn - taskEngine.EXPECT().GetTaskByArn(taskArn).Return(&apitask.Task{}, true) - - handler := newRefreshCredentialsHandler(ctx, testconst.ClusterName, testconst.ContainerInstanceARN, mockWsClient, credentialsManager, taskEngine) - go handler.start() - - handler.messageBuffer <- message - // Wait till we get an ack - select { - case <-ctx.Done(): - } - - if !reflect.DeepEqual(ackRequested, expectedAck) { - t.Errorf("Message between expected and requested ack. Expected: %v, Requested: %v", expectedAck, ackRequested) - } - - creds, exist := credentialsManager.GetTaskCredentials(credentialsId) - if !exist { - t.Errorf("Expected credentials to exist for the task") - } - if !reflect.DeepEqual(creds, expectedCredentials) { - t.Errorf("Mismatch between expected credentials and credentials for task. Expected: %v, got: %v", expectedCredentials, creds) - } -} diff --git a/agent/acs/handler/refresh_credentials_responder.go b/agent/acs/handler/refresh_credentials_responder.go new file mode 100644 index 0000000000..d46978acd7 --- /dev/null +++ b/agent/acs/handler/refresh_credentials_responder.go @@ -0,0 +1,65 @@ +package handler + +import ( + "fmt" + + "github.com/aws/aws-sdk-go/aws" + "github.com/pkg/errors" + + apitask "github.com/aws/amazon-ecs-agent/agent/api/task" + "github.com/aws/amazon-ecs-agent/agent/engine" + "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" + "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" +) + +var ( + // For ease of unit testing + checkAndSetDomainlessGMSATaskExecutionRoleCredentialsImpl = checkAndSetDomainlessGMSATaskExecutionRoleCredentials +) + +// credentialsMetadataSetter implements CredentialsMetadataSetter +type credentialsMetadataSetter struct { + taskEngine engine.TaskEngine +} + +func (cmSetter *credentialsMetadataSetter) SetTaskRoleMetadata( + message *ecsacs.IAMRoleCredentialsMessage) error { + task, err := cmSetter.getCredentialsMessageTask(message) + if err != nil { + return err + } + task.SetCredentialsID(aws.StringValue(message.RoleCredentials.CredentialsId)) + return nil +} + +func (cmSetter *credentialsMetadataSetter) SetExecRoleMetadata( + message *ecsacs.IAMRoleCredentialsMessage) error { + task, err := cmSetter.getCredentialsMessageTask(message) + if err != nil { + return errors.Wrap(err, "unable to get credentials message's task") + } + task.SetExecutionRoleCredentialsID(aws.StringValue(message.RoleCredentials.CredentialsId)) + + // Refresh domainless gMSA plugin credentials if needed. + err = checkAndSetDomainlessGMSATaskExecutionRoleCredentialsImpl(credentials.IAMRoleCredentialsFromACS( + message.RoleCredentials, aws.StringValue(message.RoleType)), task) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("unable to set %s for task with ARN %s", + "DomainlessGMSATaskExecutionRoleCredentials", aws.StringValue(message.TaskArn))) + } + + return nil +} + +func (cmSetter *credentialsMetadataSetter) getCredentialsMessageTask( + message *ecsacs.IAMRoleCredentialsMessage) (*apitask.Task, error) { + taskARN := aws.StringValue(message.TaskArn) + messageID := aws.StringValue(message.MessageId) + task, ok := cmSetter.taskEngine.GetTaskByArn(taskARN) + if !ok { + return nil, errors.Errorf( + "Task not found in the task engine for task ARN %s from credentials message with message ID %s", + taskARN, messageID) + } + return task, nil +} diff --git a/agent/acs/handler/refresh_credentials_handler_linux.go b/agent/acs/handler/refresh_credentials_responder_linux.go similarity index 100% rename from agent/acs/handler/refresh_credentials_handler_linux.go rename to agent/acs/handler/refresh_credentials_responder_linux.go diff --git a/agent/acs/handler/refresh_credentials_responder_test.go b/agent/acs/handler/refresh_credentials_responder_test.go new file mode 100644 index 0000000000..622001dc3e --- /dev/null +++ b/agent/acs/handler/refresh_credentials_responder_test.go @@ -0,0 +1,275 @@ +//go:build unit +// +build unit + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. +package handler + +import ( + "fmt" + "testing" + + apicontainer "github.com/aws/amazon-ecs-agent/agent/api/container" + apitask "github.com/aws/amazon-ecs-agent/agent/api/task" + mock_engine "github.com/aws/amazon-ecs-agent/agent/engine/mocks" + "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" + acssession "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session" + "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/testconst" + "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" + "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" + "github.com/aws/aws-sdk-go/aws" + "github.com/golang/mock/gomock" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +const ( + expiration = "soon" + roleArn = "taskrole1" + accessKey = "akid" + secretKey = "secret" + sessionToken = "token" + roleType = "TaskExecution" +) + +var expectedCredentialsAck = &ecsacs.IAMRoleCredentialsAckRequest{ + Expiration: aws.String(expiration), + MessageId: aws.String(testconst.MessageID), + CredentialsId: aws.String(testconst.CredentialsID), +} + +var expectedCredentials = credentials.TaskIAMRoleCredentials{ + ARN: testconst.TaskARN, + IAMRoleCredentials: credentials.IAMRoleCredentials{ + RoleArn: roleArn, + AccessKeyID: accessKey, + SecretAccessKey: secretKey, + SessionToken: sessionToken, + Expiration: expiration, + CredentialsID: testconst.CredentialsID, + RoleType: roleType, + }, +} + +var testRefreshCredentialsMessage = &ecsacs.IAMRoleCredentialsMessage{ + MessageId: aws.String(testconst.MessageID), + TaskArn: aws.String(testconst.TaskARN), + RoleType: aws.String(roleType), + RoleCredentials: &ecsacs.IAMRoleCredentials{ + RoleArn: aws.String(roleArn), + Expiration: aws.String(expiration), + AccessKeyId: aws.String(accessKey), + SecretAccessKey: aws.String(secretKey), + SessionToken: aws.String(sessionToken), + CredentialsId: aws.String(testconst.CredentialsID), + }, +} + +// TestInvalidCredentialsMessageNotAcked tests that invalid credential message +// is not ACKed. +func TestInvalidCredentialsMessageNotAcked(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ackSent := false + testResponseSender := func(response interface{}) error { + ackSent = true + return nil + } + testRefreshCredentialsResponder := acssession.NewRefreshCredentialsResponder(credentials.NewManager(), + &credentialsMetadataSetter{ + taskEngine: nil, + }, + metrics.NewNopEntryFactory(), + testResponseSender) + + handleCredentialsMessage := testRefreshCredentialsResponder.HandlerFunc().(func(*ecsacs.IAMRoleCredentialsMessage)) + + // Test handling a credentials message without any fields set. + message := &ecsacs.IAMRoleCredentialsMessage{} + handleCredentialsMessage(message) + assert.False(t, ackSent, + "Expected no ACK of invalid refresh credentials message when it is invalid") +} + +// TestCredentialsMessageNotAckedWhenTaskNotFound tests if credential messages +// are not ACKed when the task ARN in the message is not found in the task +// engine. +func TestCredentialsMessageNotAckedWhenTaskNotFound(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ackSent := false + mockTaskEngine := mock_engine.NewMockTaskEngine(ctrl) + testResponseSender := func(response interface{}) error { + ackSent = true + return nil + } + testRefreshCredentialsResponder := acssession.NewRefreshCredentialsResponder(credentials.NewManager(), + &credentialsMetadataSetter{ + taskEngine: mockTaskEngine, + }, + metrics.NewNopEntryFactory(), + testResponseSender) + + handleCredentialsMessage := testRefreshCredentialsResponder.HandlerFunc().(func(*ecsacs.IAMRoleCredentialsMessage)) + + // Test handling a credentials message with a task ARN that is not in the task engine. + mockTaskEngine.EXPECT().GetTaskByArn(testconst.TaskARN).Return(nil, false) + handleCredentialsMessage(testRefreshCredentialsMessage) + assert.False(t, ackSent, + "Expected no ACK of invalid refresh credentials message when its task ARN is not in task engine") +} + +// TestHandleRefreshMessageAckedWhenCredentialsUpdated tests that a credential message +// is ACKed when the credentials are updated successfully and the domainless gMSA plugin credentials +// are updated successfully. +func TestHandleRefreshMessageAckedWhenCredentialsUpdated(t *testing.T) { + testCases := []struct { + name string + taskArn string + containers []*apicontainer.Container + }{ + { + name: "EmptyTaskSucceeds", + taskArn: testconst.TaskARN, + containers: []*apicontainer.Container{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ackSent := make(chan *ecsacs.IAMRoleCredentialsAckRequest) + credentialsManager := credentials.NewManager() + mockTaskEngine := mock_engine.NewMockTaskEngine(ctrl) + + testResponseSender := func(response interface{}) error { + resp := response.(*ecsacs.IAMRoleCredentialsAckRequest) + ackSent <- resp + return nil + } + testRefreshCredentialsResponder := acssession.NewRefreshCredentialsResponder(credentialsManager, + &credentialsMetadataSetter{ + taskEngine: mockTaskEngine, + }, + metrics.NewNopEntryFactory(), + testResponseSender) + + handleCredentialsMessage := + testRefreshCredentialsResponder.HandlerFunc().(func(*ecsacs.IAMRoleCredentialsMessage)) + + checkAndSetDomainlessGMSATaskExecutionRoleCredentialsImpl = func( + iamRoleCredentials credentials.IAMRoleCredentials, task *apitask.Task) error { + if tc.taskArn != task.Arn { + return errors.New(fmt.Sprintf("Expected taskArnInput to be %s, instead got %s", tc.taskArn, + task.Arn)) + } + + return nil + } + + defer func() { + checkAndSetDomainlessGMSATaskExecutionRoleCredentialsImpl = + checkAndSetDomainlessGMSATaskExecutionRoleCredentials + }() + + // Return a task from the engine for GetTaskByArn. + mockTaskEngine.EXPECT().GetTaskByArn(tc.taskArn).Return( + &apitask.Task{Arn: tc.taskArn, Containers: tc.containers}, true) + + go handleCredentialsMessage(testRefreshCredentialsMessage) + + refreshCredentialsAckSent := <-ackSent + assert.Equal(t, expectedCredentialsAck, refreshCredentialsAckSent) + + creds, exist := credentialsManager.GetTaskCredentials(testconst.CredentialsID) + assert.True(t, exist, "Expected credentials to exist for the task") + assert.Equal(t, expectedCredentials, creds) + }) + } +} + +// TestCredentialsMessageNotAckedWhenDomainlessGMSACredentialsNotSet tests that credential messages +// are not ACKed when setting the domainless GMSA Credentials fails. +func TestCredentialsMessageNotAckedWhenDomainlessGMSACredentialsError(t *testing.T) { + testCases := []struct { + name string + taskArn string + containers []*apicontainer.Container + setDomainlessGMSATaskExecutionRoleCredentialsImplError error + }{ + { + name: "ErrDomainlessTask", + taskArn: testconst.TaskARN, + containers: []*apicontainer.Container{ + { + CredentialSpecs: []string{"credentialspecdomainless:file://gmsa_gmsa-acct.json"}, + }, + }, + setDomainlessGMSATaskExecutionRoleCredentialsImplError: errors.New( + "mock setDomainlessGMSATaskExecutionRoleCredentialsImplError"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + var ackSent, errorOnSetDomainlessGMSACreds bool + credentialsManager := credentials.NewManager() + mockTaskEngine := mock_engine.NewMockTaskEngine(ctrl) + + testResponseSender := func(response interface{}) error { + ackSent = true + return nil + } + testRefreshCredentialsResponder := acssession.NewRefreshCredentialsResponder(credentialsManager, + &credentialsMetadataSetter{ + taskEngine: mockTaskEngine, + }, + metrics.NewNopEntryFactory(), + testResponseSender) + + handleCredentialsMessage := + testRefreshCredentialsResponder.HandlerFunc().(func(*ecsacs.IAMRoleCredentialsMessage)) + + checkAndSetDomainlessGMSATaskExecutionRoleCredentialsImpl = func( + iamRoleCredentials credentials.IAMRoleCredentials, task *apitask.Task) error { + if tc.taskArn != task.Arn { + return errors.New(fmt.Sprintf("Expected taskArnInput to be %s, instead got %s", tc.taskArn, task.Arn)) + } + errorOnSetDomainlessGMSACreds = true + return tc.setDomainlessGMSATaskExecutionRoleCredentialsImplError + } + + defer func() { + checkAndSetDomainlessGMSATaskExecutionRoleCredentialsImpl = + checkAndSetDomainlessGMSATaskExecutionRoleCredentials + }() + + // Return a task from the engine for GetTaskByArn. + mockTaskEngine.EXPECT().GetTaskByArn(tc.taskArn).Return( + &apitask.Task{Arn: tc.taskArn, Containers: tc.containers}, true) + + handleCredentialsMessage(testRefreshCredentialsMessage) + assert.True(t, errorOnSetDomainlessGMSACreds, + "Expected error when setting the domainless GMSA Credentials") + assert.False(t, ackSent, + "Expected no ACK of refresh credentials message when setting domainless GMSA Credentials fails") + }) + } +} diff --git a/agent/acs/handler/refresh_credentials_handler_windows.go b/agent/acs/handler/refresh_credentials_responder_windows.go similarity index 100% rename from agent/acs/handler/refresh_credentials_handler_windows.go rename to agent/acs/handler/refresh_credentials_responder_windows.go diff --git a/agent/acs/handler/task_manifest_handler_test.go b/agent/acs/handler/task_manifest_handler_test.go index 08fd171e71..b59a019048 100644 --- a/agent/acs/handler/task_manifest_handler_test.go +++ b/agent/acs/handler/task_manifest_handler_test.go @@ -467,7 +467,7 @@ func TestManifestHandlerSequenceNumbers(t *testing.T) { mockWSClient := mock_wsclient.NewMockClientServer(ctrl) manifestMessageIDAccessor := &manifestMessageIDAccessor{} - newTaskManifest := newTaskManifestHandler(ctx, cluster, testconst.ContainerInstanceARN, mockWSClient, + newTaskManifest := newTaskManifestHandler(ctx, testconst.ClusterName, testconst.ContainerInstanceARN, mockWSClient, data.NewNoopClient(), taskEngine, aws.Int64(tc.inputSequenceNumber), manifestMessageIDAccessor) taskList := []*task.Task{ @@ -562,7 +562,7 @@ func TestTaskManifestHandlerSendPendingTaskManifestMessageAck(t *testing.T) { mockWSClient := mock_wsclient.NewMockClientServer(ctrl) mockWSClient.EXPECT().MakeRequest(gomock.Any()).Return(nil).Times(1) manifestMessageIDAccessor := &manifestMessageIDAccessor{} - handler := newTaskManifestHandler(ctx, cluster, testconst.ContainerInstanceARN, mockWSClient, + handler := newTaskManifestHandler(ctx, testconst.ClusterName, testconst.ContainerInstanceARN, mockWSClient, data.NewNoopClient(), taskEngine, aws.Int64(testSeqNum), manifestMessageIDAccessor) wg := sync.WaitGroup{} @@ -599,7 +599,7 @@ func TestTaskManifestHandlerHandlePendingTaskStopVerificationAck(t *testing.T) { taskEngine := mock_engine.NewMockTaskEngine(ctrl) mockWSClient := mock_wsclient.NewMockClientServer(ctrl) manifestMessageIDAccessor := &manifestMessageIDAccessor{} - handler := newTaskManifestHandler(ctx, cluster, testconst.ContainerInstanceARN, mockWSClient, + handler := newTaskManifestHandler(ctx, testconst.ClusterName, testconst.ContainerInstanceARN, mockWSClient, data.NewNoopClient(), taskEngine, aws.Int64(testSeqNum), manifestMessageIDAccessor) wg := sync.WaitGroup{} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/generate_mocks.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/generate_mocks.go index eecfbf1591..661e9d8107 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/generate_mocks.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/generate_mocks.go @@ -11,6 +11,6 @@ // express or implied. See the License for the specific language governing // permissions and limitations under the License. -//go:generate mockgen -destination=mocks/session_mock.go -copyright_file=../../../scripts/copyright_file . ENIHandler,ResourceHandler +//go:generate mockgen -destination=mocks/session_mock.go -copyright_file=../../../scripts/copyright_file . ENIHandler,ResourceHandler,CredentialsMetadataSetter package session diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/refresh_credentials_responder.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/refresh_credentials_responder.go new file mode 100644 index 0000000000..df0c1f5cc4 --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/refresh_credentials_responder.go @@ -0,0 +1,199 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package session + +import ( + "fmt" + + "github.com/aws/aws-sdk-go/aws" + "github.com/pkg/errors" + + "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" + "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" + "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" + "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" +) + +const ( + RefreshCredentialsMessageName = "IAMRoleCredentialsMessage" +) + +type CredentialsMetadataSetter interface { + SetTaskRoleMetadata(message *ecsacs.IAMRoleCredentialsMessage) error + SetExecRoleMetadata(message *ecsacs.IAMRoleCredentialsMessage) error +} + +// refreshCredentialsResponder implements the wsclient.RequestResponder interface for responding +// to ecsacs.IAMRoleCredentialsMessage messages sent by ACS. +type refreshCredentialsResponder struct { + credentialsManager credentials.Manager + credsMetadataSetter CredentialsMetadataSetter + metricsFactory metrics.EntryFactory + respond wsclient.RespondFunc +} + +// NewRefreshCredentialsResponder returns an instance of the refreshCredentialsResponder struct. +func NewRefreshCredentialsResponder(credentialsManager credentials.Manager, + credsMetadataSetter CredentialsMetadataSetter, + metricsFactory metrics.EntryFactory, + responseSender wsclient.RespondFunc) wsclient.RequestResponder { + r := &refreshCredentialsResponder{ + credentialsManager: credentialsManager, + credsMetadataSetter: credsMetadataSetter, + metricsFactory: metricsFactory, + } + r.respond = ResponseToACSSender(r.Name(), responseSender) + return r +} + +func (*refreshCredentialsResponder) Name() string { return "refresh credentials responder" } + +func (r *refreshCredentialsResponder) HandlerFunc() wsclient.RequestHandler { + return r.handleCredentialsMessage +} + +func (r *refreshCredentialsResponder) handleCredentialsMessage(message *ecsacs.IAMRoleCredentialsMessage) { + logger.Debug(fmt.Sprintf("Handling %s", RefreshCredentialsMessageName)) + messageID := aws.StringValue(message.MessageId) + taskARN := aws.StringValue(message.TaskArn) + metricFields := logger.Fields{ + field.MessageID: messageID, + field.TaskARN: taskARN, + } + + // Validate fields in the message. + err := validateIAMRoleCredentialsMessage(message) + if err != nil { + logger.Error(fmt.Sprintf("Error validating %s received from ECS", RefreshCredentialsMessageName), + logger.Fields{ + field.Error: err, + }) + err = errors.Wrap(err, "ACS refresh credentials message validation failed") + r.metricsFactory.New(metrics.CredentialsRefreshFailure).WithFields(metricFields).Done(err) + return + } + + // Handle credentials refresh. + err = r.credentialsManager.SetTaskCredentials(&credentials.TaskIAMRoleCredentials{ + ARN: taskARN, + IAMRoleCredentials: credentials.IAMRoleCredentialsFromACS(message.RoleCredentials, + aws.StringValue(message.RoleType)), + }) + if err != nil { + logger.Error(fmt.Sprintf("Unable to handle %s due to error in setting credentials", + RefreshCredentialsMessageName), logger.Fields{ + field.MessageID: messageID, + field.Error: err, + }) + err = errors.Wrap(err, "unable to set credentials in the credentials manager") + r.metricsFactory.New(metrics.CredentialsRefreshFailure).WithFields(metricFields).Done(err) + return + } + err = r.setCredentialsMetadata(message) + if err != nil { + logger.Error(fmt.Sprintf("Unable to handle %s due to error in setting credentials metadata", + RefreshCredentialsMessageName), logger.Fields{ + field.MessageID: messageID, + field.Error: err, + }) + r.metricsFactory.New(metrics.CredentialsRefreshFailure).WithFields(metricFields).Done(err) + return + } + + // Send ACK. + err = r.respond(&ecsacs.IAMRoleCredentialsAckRequest{ + Expiration: message.RoleCredentials.Expiration, + MessageId: message.MessageId, + CredentialsId: message.RoleCredentials.CredentialsId, + }) + if err != nil { + logger.Warn(fmt.Sprintf("Error acknowledging %s", RefreshCredentialsMessageName), logger.Fields{ + field.MessageID: messageID, + field.Error: err, + }) + err = errors.Wrapf(err, "unable to ACK task credentials for task with ARN %s", taskARN) + r.metricsFactory.New(metrics.CredentialsRefreshFailure).WithFields(metricFields).Done(err) + return + } + + r.metricsFactory.New(metrics.CredentialsRefreshSuccess).WithFields(metricFields).Done(nil) +} + +func (r *refreshCredentialsResponder) setCredentialsMetadata(message *ecsacs.IAMRoleCredentialsMessage) error { + roleType := aws.StringValue(message.RoleType) + switch roleType { + case credentials.ApplicationRoleType: + err := r.credsMetadataSetter.SetTaskRoleMetadata(message) + if err != nil { + return errors.Wrap(err, "failed to set task role metadata") + } + case credentials.ExecutionRoleType: + err := r.credsMetadataSetter.SetExecRoleMetadata(message) + if err != nil { + return errors.Wrap(err, "failed to set execution role metadata") + } + default: + return errors.Errorf("received credentials for unexpected roleType \"%s\"", roleType) + } + return nil +} + +// validateIAMRoleCredentialsMessage performs validation checks on the +// IAMRoleCredentialsMessage. +func validateIAMRoleCredentialsMessage(message *ecsacs.IAMRoleCredentialsMessage) error { + if message == nil { + return errors.Errorf("Message is empty") + } + + messageID := aws.StringValue(message.MessageId) + if messageID == "" { + return errors.Errorf("Message ID is not set") + } + + taskArn := aws.StringValue(message.TaskArn) + if taskArn == "" { + return errors.Errorf("taskArn is not set for message ID %s", messageID) + } + + if message.RoleCredentials == nil { + return errors.Errorf("roleCredentials is not set for message ID %s", messageID) + } + + if aws.StringValue(message.RoleCredentials.CredentialsId) == "" { + return errors.Errorf("roleCredentials ID not set for message ID %s", messageID) + } + + roleType := aws.StringValue(message.RoleType) + if !validRoleType(roleType) { + return errors.Errorf("roleType \"%s\" is invalid for message ID %s with taskArn %s", roleType, messageID, + taskArn) + } + + return nil +} + +// validRoleType returns false if the RoleType in the ACS refresh credentials message is not +// one of the expected types. Expected types: TaskApplication, TaskExecution +func validRoleType(roleType string) bool { + switch roleType { + case credentials.ApplicationRoleType: + return true + case credentials.ExecutionRoleType: + return true + default: + return false + } +} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/testconst/test_const.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/testconst/test_const.go index ec7894e2d9..310224e713 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/testconst/test_const.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/testconst/test_const.go @@ -25,4 +25,5 @@ const ( InterfaceProtocol = "default" GatewayIPv4 = "192.168.1.1/24" IPv4Address = "ipv4" + CredentialsID = "credsid" ) diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/metrics/constants.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/metrics/constants.go index 4979940c33..fed662215c 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/metrics/constants.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/metrics/constants.go @@ -10,6 +10,7 @@ // on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either // express or implied. See the License for the specific language governing // permissions and limitations under the License. + package metrics const ( @@ -34,4 +35,9 @@ const ( // TaskStopVerificationACKResponder taskStopVerificationACKResponderNamespace = "TaskStopVeificationACKResponder" TaskStoppedMetricName = taskStopVerificationACKResponderNamespace + ".TaskStopped" + + // Credentials Refresh + credsRefreshNamespace = "CredentialsRefresh" + CredentialsRefreshFailure = credsRefreshNamespace + ".Failure" + CredentialsRefreshSuccess = credsRefreshNamespace + ".Success" ) diff --git a/ecs-agent/acs/session/generate_mocks.go b/ecs-agent/acs/session/generate_mocks.go index eecfbf1591..661e9d8107 100644 --- a/ecs-agent/acs/session/generate_mocks.go +++ b/ecs-agent/acs/session/generate_mocks.go @@ -11,6 +11,6 @@ // express or implied. See the License for the specific language governing // permissions and limitations under the License. -//go:generate mockgen -destination=mocks/session_mock.go -copyright_file=../../../scripts/copyright_file . ENIHandler,ResourceHandler +//go:generate mockgen -destination=mocks/session_mock.go -copyright_file=../../../scripts/copyright_file . ENIHandler,ResourceHandler,CredentialsMetadataSetter package session diff --git a/ecs-agent/acs/session/mocks/session_mock.go b/ecs-agent/acs/session/mocks/session_mock.go index 9361679149..53cc54614a 100644 --- a/ecs-agent/acs/session/mocks/session_mock.go +++ b/ecs-agent/acs/session/mocks/session_mock.go @@ -13,7 +13,7 @@ // // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/aws/amazon-ecs-agent/ecs-agent/acs/session (interfaces: ENIHandler,ResourceHandler) +// Source: github.com/aws/amazon-ecs-agent/ecs-agent/acs/session (interfaces: ENIHandler,ResourceHandler,CredentialsMetadataSetter) // Package mock_session is a generated GoMock package. package mock_session @@ -21,6 +21,7 @@ package mock_session import ( reflect "reflect" + ecsacs "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" eni "github.com/aws/amazon-ecs-agent/ecs-agent/api/eni" resource "github.com/aws/amazon-ecs-agent/ecs-agent/api/resource" gomock "github.com/golang/mock/gomock" @@ -97,3 +98,54 @@ func (mr *MockResourceHandlerMockRecorder) HandleResourceAttachment(arg0 interfa mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleResourceAttachment", reflect.TypeOf((*MockResourceHandler)(nil).HandleResourceAttachment), arg0) } + +// MockCredentialsMetadataSetter is a mock of CredentialsMetadataSetter interface. +type MockCredentialsMetadataSetter struct { + ctrl *gomock.Controller + recorder *MockCredentialsMetadataSetterMockRecorder +} + +// MockCredentialsMetadataSetterMockRecorder is the mock recorder for MockCredentialsMetadataSetter. +type MockCredentialsMetadataSetterMockRecorder struct { + mock *MockCredentialsMetadataSetter +} + +// NewMockCredentialsMetadataSetter creates a new mock instance. +func NewMockCredentialsMetadataSetter(ctrl *gomock.Controller) *MockCredentialsMetadataSetter { + mock := &MockCredentialsMetadataSetter{ctrl: ctrl} + mock.recorder = &MockCredentialsMetadataSetterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockCredentialsMetadataSetter) EXPECT() *MockCredentialsMetadataSetterMockRecorder { + return m.recorder +} + +// SetExecRoleMetadata mocks base method. +func (m *MockCredentialsMetadataSetter) SetExecRoleMetadata(arg0 *ecsacs.IAMRoleCredentialsMessage) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetExecRoleMetadata", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetExecRoleMetadata indicates an expected call of SetExecRoleMetadata. +func (mr *MockCredentialsMetadataSetterMockRecorder) SetExecRoleMetadata(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetExecRoleMetadata", reflect.TypeOf((*MockCredentialsMetadataSetter)(nil).SetExecRoleMetadata), arg0) +} + +// SetTaskRoleMetadata mocks base method. +func (m *MockCredentialsMetadataSetter) SetTaskRoleMetadata(arg0 *ecsacs.IAMRoleCredentialsMessage) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetTaskRoleMetadata", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetTaskRoleMetadata indicates an expected call of SetTaskRoleMetadata. +func (mr *MockCredentialsMetadataSetterMockRecorder) SetTaskRoleMetadata(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTaskRoleMetadata", reflect.TypeOf((*MockCredentialsMetadataSetter)(nil).SetTaskRoleMetadata), arg0) +} diff --git a/ecs-agent/acs/session/refresh_credentials_responder.go b/ecs-agent/acs/session/refresh_credentials_responder.go new file mode 100644 index 0000000000..df0c1f5cc4 --- /dev/null +++ b/ecs-agent/acs/session/refresh_credentials_responder.go @@ -0,0 +1,199 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package session + +import ( + "fmt" + + "github.com/aws/aws-sdk-go/aws" + "github.com/pkg/errors" + + "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" + "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" + "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" + "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" +) + +const ( + RefreshCredentialsMessageName = "IAMRoleCredentialsMessage" +) + +type CredentialsMetadataSetter interface { + SetTaskRoleMetadata(message *ecsacs.IAMRoleCredentialsMessage) error + SetExecRoleMetadata(message *ecsacs.IAMRoleCredentialsMessage) error +} + +// refreshCredentialsResponder implements the wsclient.RequestResponder interface for responding +// to ecsacs.IAMRoleCredentialsMessage messages sent by ACS. +type refreshCredentialsResponder struct { + credentialsManager credentials.Manager + credsMetadataSetter CredentialsMetadataSetter + metricsFactory metrics.EntryFactory + respond wsclient.RespondFunc +} + +// NewRefreshCredentialsResponder returns an instance of the refreshCredentialsResponder struct. +func NewRefreshCredentialsResponder(credentialsManager credentials.Manager, + credsMetadataSetter CredentialsMetadataSetter, + metricsFactory metrics.EntryFactory, + responseSender wsclient.RespondFunc) wsclient.RequestResponder { + r := &refreshCredentialsResponder{ + credentialsManager: credentialsManager, + credsMetadataSetter: credsMetadataSetter, + metricsFactory: metricsFactory, + } + r.respond = ResponseToACSSender(r.Name(), responseSender) + return r +} + +func (*refreshCredentialsResponder) Name() string { return "refresh credentials responder" } + +func (r *refreshCredentialsResponder) HandlerFunc() wsclient.RequestHandler { + return r.handleCredentialsMessage +} + +func (r *refreshCredentialsResponder) handleCredentialsMessage(message *ecsacs.IAMRoleCredentialsMessage) { + logger.Debug(fmt.Sprintf("Handling %s", RefreshCredentialsMessageName)) + messageID := aws.StringValue(message.MessageId) + taskARN := aws.StringValue(message.TaskArn) + metricFields := logger.Fields{ + field.MessageID: messageID, + field.TaskARN: taskARN, + } + + // Validate fields in the message. + err := validateIAMRoleCredentialsMessage(message) + if err != nil { + logger.Error(fmt.Sprintf("Error validating %s received from ECS", RefreshCredentialsMessageName), + logger.Fields{ + field.Error: err, + }) + err = errors.Wrap(err, "ACS refresh credentials message validation failed") + r.metricsFactory.New(metrics.CredentialsRefreshFailure).WithFields(metricFields).Done(err) + return + } + + // Handle credentials refresh. + err = r.credentialsManager.SetTaskCredentials(&credentials.TaskIAMRoleCredentials{ + ARN: taskARN, + IAMRoleCredentials: credentials.IAMRoleCredentialsFromACS(message.RoleCredentials, + aws.StringValue(message.RoleType)), + }) + if err != nil { + logger.Error(fmt.Sprintf("Unable to handle %s due to error in setting credentials", + RefreshCredentialsMessageName), logger.Fields{ + field.MessageID: messageID, + field.Error: err, + }) + err = errors.Wrap(err, "unable to set credentials in the credentials manager") + r.metricsFactory.New(metrics.CredentialsRefreshFailure).WithFields(metricFields).Done(err) + return + } + err = r.setCredentialsMetadata(message) + if err != nil { + logger.Error(fmt.Sprintf("Unable to handle %s due to error in setting credentials metadata", + RefreshCredentialsMessageName), logger.Fields{ + field.MessageID: messageID, + field.Error: err, + }) + r.metricsFactory.New(metrics.CredentialsRefreshFailure).WithFields(metricFields).Done(err) + return + } + + // Send ACK. + err = r.respond(&ecsacs.IAMRoleCredentialsAckRequest{ + Expiration: message.RoleCredentials.Expiration, + MessageId: message.MessageId, + CredentialsId: message.RoleCredentials.CredentialsId, + }) + if err != nil { + logger.Warn(fmt.Sprintf("Error acknowledging %s", RefreshCredentialsMessageName), logger.Fields{ + field.MessageID: messageID, + field.Error: err, + }) + err = errors.Wrapf(err, "unable to ACK task credentials for task with ARN %s", taskARN) + r.metricsFactory.New(metrics.CredentialsRefreshFailure).WithFields(metricFields).Done(err) + return + } + + r.metricsFactory.New(metrics.CredentialsRefreshSuccess).WithFields(metricFields).Done(nil) +} + +func (r *refreshCredentialsResponder) setCredentialsMetadata(message *ecsacs.IAMRoleCredentialsMessage) error { + roleType := aws.StringValue(message.RoleType) + switch roleType { + case credentials.ApplicationRoleType: + err := r.credsMetadataSetter.SetTaskRoleMetadata(message) + if err != nil { + return errors.Wrap(err, "failed to set task role metadata") + } + case credentials.ExecutionRoleType: + err := r.credsMetadataSetter.SetExecRoleMetadata(message) + if err != nil { + return errors.Wrap(err, "failed to set execution role metadata") + } + default: + return errors.Errorf("received credentials for unexpected roleType \"%s\"", roleType) + } + return nil +} + +// validateIAMRoleCredentialsMessage performs validation checks on the +// IAMRoleCredentialsMessage. +func validateIAMRoleCredentialsMessage(message *ecsacs.IAMRoleCredentialsMessage) error { + if message == nil { + return errors.Errorf("Message is empty") + } + + messageID := aws.StringValue(message.MessageId) + if messageID == "" { + return errors.Errorf("Message ID is not set") + } + + taskArn := aws.StringValue(message.TaskArn) + if taskArn == "" { + return errors.Errorf("taskArn is not set for message ID %s", messageID) + } + + if message.RoleCredentials == nil { + return errors.Errorf("roleCredentials is not set for message ID %s", messageID) + } + + if aws.StringValue(message.RoleCredentials.CredentialsId) == "" { + return errors.Errorf("roleCredentials ID not set for message ID %s", messageID) + } + + roleType := aws.StringValue(message.RoleType) + if !validRoleType(roleType) { + return errors.Errorf("roleType \"%s\" is invalid for message ID %s with taskArn %s", roleType, messageID, + taskArn) + } + + return nil +} + +// validRoleType returns false if the RoleType in the ACS refresh credentials message is not +// one of the expected types. Expected types: TaskApplication, TaskExecution +func validRoleType(roleType string) bool { + switch roleType { + case credentials.ApplicationRoleType: + return true + case credentials.ExecutionRoleType: + return true + default: + return false + } +} diff --git a/ecs-agent/acs/session/refresh_credentials_responder_test.go b/ecs-agent/acs/session/refresh_credentials_responder_test.go new file mode 100644 index 0000000000..af095cdce8 --- /dev/null +++ b/ecs-agent/acs/session/refresh_credentials_responder_test.go @@ -0,0 +1,246 @@ +//go:build unit +// +build unit + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package session + +import ( + "sync" + "testing" + + "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" + mock_session "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/mocks" + "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/testconst" + "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" + "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" + mock_metrics "github.com/aws/amazon-ecs-agent/ecs-agent/metrics/mocks" + "github.com/aws/aws-sdk-go/aws" + "github.com/golang/mock/gomock" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +var testRefreshCredentialsMessage = &ecsacs.IAMRoleCredentialsMessage{ + MessageId: aws.String(testconst.MessageID), + TaskArn: aws.String(testconst.TaskARN), + RoleCredentials: &ecsacs.IAMRoleCredentials{ + CredentialsId: aws.String(testconst.CredentialsID), + }, + RoleType: aws.String(credentials.ApplicationRoleType), +} + +// TestValidateRefreshMessageWithNilMessage tests if a validation error +// is returned while validating an empty credentials message. +func TestValidateRefreshMessageWithNilMessage(t *testing.T) { + err := validateIAMRoleCredentialsMessage(nil) + assert.Error(t, err, "Expected validation error validating an empty message") +} + +// TestValidateRefreshMessageWithNoMessageId tests if a validation error +// is returned while validating a credentials message with no message ID. +func TestValidateRefreshMessageWithNoMessageId(t *testing.T) { + tempMessageId := testRefreshCredentialsMessage.MessageId + + testRefreshCredentialsMessage.MessageId = nil + err := validateIAMRoleCredentialsMessage(testRefreshCredentialsMessage) + assert.Error(t, err, "Expected validation error validating a message with no message ID") + + testRefreshCredentialsMessage.MessageId = aws.String("") + err = validateIAMRoleCredentialsMessage(testRefreshCredentialsMessage) + assert.Error(t, err, "Expected validation error validating a message with empty message ID") + + testRefreshCredentialsMessage.MessageId = tempMessageId +} + +// TestValidateRefreshMessageWithNoTaskArn tests if a validation error +// is returned while validating a credentials message with no task ARN. +func TestValidateRefreshMessageWithNoTaskArn(t *testing.T) { + tempTaskArn := testRefreshCredentialsMessage.TaskArn + + testRefreshCredentialsMessage.TaskArn = nil + err := validateIAMRoleCredentialsMessage(testRefreshCredentialsMessage) + assert.Error(t, err, "Expected validation error validating a message with no task ARN") + + testRefreshCredentialsMessage.TaskArn = aws.String("") + err = validateIAMRoleCredentialsMessage(testRefreshCredentialsMessage) + assert.Error(t, err, "Expected validation error validating a message with empty task ARN") + + testRefreshCredentialsMessage.TaskArn = tempTaskArn +} + +// TestValidateRefreshMessageWithNoRoleCredentials tests if a validation error +// is returned while validating a credentials message with no role credentials. +func TestValidateRefreshMessageWithNoRoleCredentials(t *testing.T) { + tempRoleCredentials := testRefreshCredentialsMessage.RoleCredentials + + testRefreshCredentialsMessage.RoleCredentials = nil + err := validateIAMRoleCredentialsMessage(testRefreshCredentialsMessage) + assert.Error(t, err, "Expected validation error validating a message with no role credentials") + + testRefreshCredentialsMessage.RoleCredentials = tempRoleCredentials +} + +// TestValidateRefreshMessageWithInvalidRoleType tests if a valid error +// is returned while validating a credentials message with no credentials ID. +func TestValidateRefreshMessageWithNoCredentialsId(t *testing.T) { + tempRoleCredentials := testRefreshCredentialsMessage.RoleCredentials + + testRefreshCredentialsMessage.RoleCredentials = &ecsacs.IAMRoleCredentials{} + err := validateIAMRoleCredentialsMessage(testRefreshCredentialsMessage) + assert.Error(t, err, "Expected validation error validating a message with no credentials ID") + + testRefreshCredentialsMessage.RoleCredentials = &ecsacs.IAMRoleCredentials{CredentialsId: aws.String("")} + err = validateIAMRoleCredentialsMessage(testRefreshCredentialsMessage) + assert.Error(t, err, "Expected validation error validating a message with empty credentials ID") + + testRefreshCredentialsMessage.RoleCredentials = tempRoleCredentials +} + +// TestValidateRefreshMessageWithInvalidRoleType tests if a valid error +// is returned while validating a credentials message with an invalid role type. +func TestValidateRefreshMessageWithInvalidRoleType(t *testing.T) { + tempRoleType := testRefreshCredentialsMessage.RoleType + + testRefreshCredentialsMessage.RoleType = aws.String("not a valid role type") + err := validateIAMRoleCredentialsMessage(testRefreshCredentialsMessage) + assert.Error(t, err, "Expected validation error validating a message with an invalid role type") + + testRefreshCredentialsMessage.RoleType = tempRoleType +} + +// TestValidateRefreshMessageSuccess tests if a valid credentials message +// is validated without any errors. +func TestValidateRefreshMessageSuccess(t *testing.T) { + err := validateIAMRoleCredentialsMessage(testRefreshCredentialsMessage) + assert.NoError(t, err, "Error validating credentials message: %w", err) +} + +// TestRefreshCredentialsAckHappyPath tests the happy path for a typical IAMRoleCredentialsMessage and confirms expected +// ACK request is made. +func TestRefreshCredentialsAckHappyPath(t *testing.T) { + testCases := []struct { + name string + roleType string + }{ + { + name: "task role type", + roleType: credentials.ApplicationRoleType, + }, + { + name: "execution role type", + roleType: credentials.ExecutionRoleType, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // WaitGroup is necessary to wait for function to be called in separate goroutine before exiting the test. + wg := sync.WaitGroup{} + wg.Add(1) + + iamRoleCredentialsMessageCopy := *testRefreshCredentialsMessage + ackSent := make(chan *ecsacs.IAMRoleCredentialsAckRequest) + credentialsManager := credentials.NewManager() + mockCredsMetadataSetter := mock_session.NewMockCredentialsMetadataSetter(ctrl) + switch tc.roleType { + case credentials.ApplicationRoleType: + iamRoleCredentialsMessageCopy.RoleType = aws.String(credentials.ApplicationRoleType) + mockCredsMetadataSetter.EXPECT(). + SetTaskRoleMetadata(gomock.Any()). + Return(nil) + case credentials.ExecutionRoleType: + iamRoleCredentialsMessageCopy.RoleType = aws.String(credentials.ExecutionRoleType) + mockCredsMetadataSetter.EXPECT(). + SetExecRoleMetadata(gomock.Any()). + Return(nil) + default: + t.Fatal("invalid role type used in happy path test, role type should be valid for happy path") + return + } + mockMetricsFactory := mock_metrics.NewMockEntryFactory(ctrl) + mockEntry := mock_metrics.NewMockEntry(ctrl) + mockEntry.EXPECT().WithFields(gomock.Any()).Return(mockEntry) + mockEntry.EXPECT().Done(nil) + mockMetricsFactory.EXPECT().New(metrics.CredentialsRefreshSuccess). + Do(func(arg0 interface{}) { + defer wg.Done() // decrement WaitGroup counter now that HandleResourceAttachment function has been called + }). + Return(mockEntry) + + testResponseSender := func(response interface{}) error { + resp := response.(*ecsacs.IAMRoleCredentialsAckRequest) + ackSent <- resp + return nil + } + testRefreshCredentialsResponder := NewRefreshCredentialsResponder(credentialsManager, + mockCredsMetadataSetter, + mockMetricsFactory, + testResponseSender) + + handleCredentialsMessage := + testRefreshCredentialsResponder.HandlerFunc().(func(*ecsacs.IAMRoleCredentialsMessage)) + + go handleCredentialsMessage(&iamRoleCredentialsMessageCopy) + + refreshCredentialsAckSent := <-ackSent + wg.Wait() + assert.Equal(t, aws.StringValue(iamRoleCredentialsMessageCopy.MessageId), + aws.StringValue(refreshCredentialsAckSent.MessageId)) + + creds, exist := credentialsManager.GetTaskCredentials(testconst.CredentialsID) + assert.True(t, exist, "Expected credentials to exist for the task") + assert.Equal(t, aws.StringValue(iamRoleCredentialsMessageCopy.RoleCredentials.CredentialsId), + creds.IAMRoleCredentials.CredentialsID) + }) + } +} + +// TestRefreshCredentialsWhenUnableToSetCredentialsMetadata tests the error case where the responder is not able to +// successfully set credentials metadata. +func TestRefreshCredentialsWhenUnableToSetCredentialsMetadata(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ackSent := false + credentialsManager := credentials.NewManager() + mockCredsMetadataSetter := mock_session.NewMockCredentialsMetadataSetter(ctrl) + mockCredsMetadataSetter.EXPECT(). + SetTaskRoleMetadata(gomock.Any()). + Return(errors.Errorf("unable to set credentials metadata")) + mockMetricsFactory := mock_metrics.NewMockEntryFactory(ctrl) + mockEntry := mock_metrics.NewMockEntry(ctrl) + mockEntry.EXPECT().WithFields(gomock.Any()).Return(mockEntry) + mockEntry.EXPECT().Done(gomock.Any()) + mockMetricsFactory.EXPECT().New(metrics.CredentialsRefreshFailure).Return(mockEntry) + + testResponseSender := func(response interface{}) error { + ackSent = true + return nil + } + testRefreshCredentialsResponder := NewRefreshCredentialsResponder(credentialsManager, + mockCredsMetadataSetter, + mockMetricsFactory, + testResponseSender) + + handleCredentialsMessage := + testRefreshCredentialsResponder.HandlerFunc().(func(*ecsacs.IAMRoleCredentialsMessage)) + + handleCredentialsMessage(testRefreshCredentialsMessage) + assert.False(t, ackSent, + "Expected no ACK of refresh credentials message when unable to successfully set credentials metadata") +} diff --git a/ecs-agent/acs/session/testconst/test_const.go b/ecs-agent/acs/session/testconst/test_const.go index ec7894e2d9..310224e713 100644 --- a/ecs-agent/acs/session/testconst/test_const.go +++ b/ecs-agent/acs/session/testconst/test_const.go @@ -25,4 +25,5 @@ const ( InterfaceProtocol = "default" GatewayIPv4 = "192.168.1.1/24" IPv4Address = "ipv4" + CredentialsID = "credsid" ) diff --git a/ecs-agent/metrics/constants.go b/ecs-agent/metrics/constants.go index 4979940c33..fed662215c 100644 --- a/ecs-agent/metrics/constants.go +++ b/ecs-agent/metrics/constants.go @@ -10,6 +10,7 @@ // on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either // express or implied. See the License for the specific language governing // permissions and limitations under the License. + package metrics const ( @@ -34,4 +35,9 @@ const ( // TaskStopVerificationACKResponder taskStopVerificationACKResponderNamespace = "TaskStopVeificationACKResponder" TaskStoppedMetricName = taskStopVerificationACKResponderNamespace + ".TaskStopped" + + // Credentials Refresh + credsRefreshNamespace = "CredentialsRefresh" + CredentialsRefreshFailure = credsRefreshNamespace + ".Failure" + CredentialsRefreshSuccess = credsRefreshNamespace + ".Success" )