Skip to content

Commit

Permalink
Refactor ACS refresh credentials message handling
Browse files Browse the repository at this point in the history
  • Loading branch information
danehlim committed Jul 28, 2023
1 parent 6ba7f1e commit 86db0d4
Show file tree
Hide file tree
Showing 22 changed files with 1,083 additions and 753 deletions.
27 changes: 9 additions & 18 deletions agent/acs/handler/acs_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
rolecredentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials"
"github.com/aws/amazon-ecs-agent/ecs-agent/doctor"
"github.com/aws/amazon-ecs-agent/ecs-agent/eventstream"
"github.com/aws/amazon-ecs-agent/ecs-agent/metrics"
"github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry"
"github.com/aws/amazon-ecs-agent/ecs-agent/utils/ttime"
"github.com/aws/amazon-ecs-agent/ecs-agent/wsclient"
Expand Down Expand Up @@ -78,8 +79,8 @@ const (
acsProtocolVersion = 2
// numOfHandlersSendingAcks is the number of handlers that send acks back to ACS and that are not saved across
// sessions. We use this to send pending acks, before agent initiates a disconnect to ACS.
// they are: refreshCredentialsHandler, taskManifestHandler, and payloadHandler
numOfHandlersSendingAcks = 3
// they are: taskManifestHandler, and payloadHandler
numOfHandlersSendingAcks = 2
)

// Session defines an interface for handler's long-lived connection with ACS.
Expand Down Expand Up @@ -250,13 +251,7 @@ func (acsSession *session) startSessionOnce() error {
func (acsSession *session) startACSSession(client wsclient.ClientServer) error {
cfg := acsSession.agentConfig

refreshCredsHandler := newRefreshCredentialsHandler(acsSession.ctx, cfg.Cluster, acsSession.containerInstanceARN,
client, acsSession.credentialsManager, acsSession.taskEngine)
defer refreshCredsHandler.clearAcks()
refreshCredsHandler.start()
defer refreshCredsHandler.stop()

client.AddRequestHandler(refreshCredsHandler.handlerFunc())
credsMetadataSetter := &credentialsMetadataSetter{taskEngine: acsSession.taskEngine}

eniHandler := &eniHandler{
state: acsSession.state,
Expand All @@ -265,6 +260,8 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {

manifestMessageIDAccessor := &manifestMessageIDAccessor{}

metricsFactory := metrics.NewNopEntryFactory()

// Add TaskManifestHandler
taskManifestHandler := newTaskManifestHandler(acsSession.ctx, cfg.Cluster, acsSession.containerInstanceARN,
client, acsSession.dataClient, acsSession.taskEngine, acsSession.latestSeqNumTaskManifest,
Expand All @@ -286,7 +283,6 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {
acsSession.containerInstanceARN,
client,
acsSession.dataClient,
refreshCredsHandler,
acsSession.credentialsManager,
acsSession.taskHandler, acsSession.latestSeqNumTaskManifest)
// Clear the acks channel on return because acks of messageids don't have any value across sessions
Expand All @@ -300,6 +296,8 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {
return client.MakeRequest(response)
}
responders := []wsclient.RequestResponder{
acssession.NewRefreshCredentialsResponder(acsSession.credentialsManager, credsMetadataSetter, metricsFactory,
responseSender),
acssession.NewAttachTaskENIResponder(eniHandler, responseSender),
acssession.NewAttachInstanceENIResponder(eniHandler, responseSender),
acssession.NewHeartbeatResponder(acsSession.doctor, responseSender),
Expand All @@ -320,7 +318,7 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {
// Start a connection timer; agent will send pending acks and close its ACS websocket connection
// after this timer expires
connectionTimer := newConnectionTimer(client, acsSession.connectionTime, acsSession.connectionJitter,
&refreshCredsHandler, &taskManifestHandler, &payloadHandler)
&taskManifestHandler, &payloadHandler)
defer connectionTimer.Stop()

// Start a heartbeat timer for closing the connection
Expand Down Expand Up @@ -416,7 +414,6 @@ func newConnectionTimer(
client wsclient.ClientServer,
connectionTime time.Duration,
connectionJitter time.Duration,
refreshCredsHandler *refreshCredentialsHandler,
taskManifestHandler *taskManifestHandler,
payloadHandler *payloadRequestHandler,
) ttime.Timer {
Expand All @@ -427,12 +424,6 @@ func newConnectionTimer(
wg := sync.WaitGroup{}
wg.Add(numOfHandlersSendingAcks)

// send pending creds refresh acks to ACS
go func() {
refreshCredsHandler.sendPendingAcks()
wg.Done()
}()

// send pending task manifest acks and task stop verification acks to ACS
go func() {
taskManifestHandler.sendPendingTaskManifestMessageAck()
Expand Down
4 changes: 2 additions & 2 deletions agent/acs/handler/acs_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1162,8 +1162,6 @@ func TestStartSessionHandlesRefreshCredentialsMessages(t *testing.T) {
// Ensure that credentials manager interface methods are invoked in the
// correct order, with expected arguments
gomock.InOrder(
// Return a task from the engine for GetTaskByArn
taskEngine.EXPECT().GetTaskByArn("t1").Return(taskFromEngine, true),
// The last invocation of SetCredentials is to update
// credentials when a refresh message is received by the handler
credentialsManager.EXPECT().SetTaskCredentials(gomock.Any()).Do(func(creds *rolecredentials.TaskIAMRoleCredentials) {
Expand All @@ -1185,6 +1183,8 @@ func TestStartSessionHandlesRefreshCredentialsMessages(t *testing.T) {
t.Errorf("Mismatch between expected and credentials expected: %v, added: %v", expectedCreds, updatedCredentials)
}
}).Return(nil),
// Return a task from the engine for GetTaskByArn
taskEngine.EXPECT().GetTaskByArn("t1").Return(taskFromEngine, true),
)
serverIn <- sampleRefreshCredentialsMessage

Expand Down
8 changes: 4 additions & 4 deletions agent/acs/handler/attach_eni_handler_common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func testENIAckTimeout(t *testing.T, attachmentType string) {
expiresAt := time.Now().Add(time.Millisecond * testconst.WaitTimeoutMillis)
eniAttachment := &apieni.ENIAttachment{
AttachmentInfo: attachmentinfo.AttachmentInfo{
TaskARN: taskArn,
TaskARN: testconst.TaskARN,
AttachmentARN: attachmentArn,
ExpiresAt: expiresAt,
AttachStatusSent: false,
Expand Down Expand Up @@ -103,7 +103,7 @@ func testENIAckWithinTimeout(t *testing.T, attachmentType string) {
expiresAt := time.Now().Add(time.Millisecond * testconst.WaitTimeoutMillis)
eniAttachment := &apieni.ENIAttachment{
AttachmentInfo: attachmentinfo.AttachmentInfo{
TaskARN: taskArn,
TaskARN: testconst.TaskARN,
AttachmentARN: attachmentArn,
ExpiresAt: expiresAt,
AttachStatusSent: false,
Expand All @@ -130,7 +130,7 @@ func testENIAckWithinTimeout(t *testing.T, attachmentType string) {

// TestHandleENIAttachmentTaskENI tests handling a new task eni
func TestHandleENIAttachmentTaskENI(t *testing.T) {
testHandleENIAttachment(t, apieni.ENIAttachmentTypeTaskENI, taskArn)
testHandleENIAttachment(t, apieni.ENIAttachmentTypeTaskENI, testconst.TaskARN)
}

// TestHandleENIAttachmentInstanceENI tests handling a new instance eni
Expand Down Expand Up @@ -178,7 +178,7 @@ func testHandleENIAttachment(t *testing.T, attachmentType, taskArn string) {

// TestHandleExpiredENIAttachmentTaskENI tests handling an expired task eni
func TestHandleExpiredENIAttachmentTaskENI(t *testing.T) {
testHandleExpiredENIAttachment(t, apieni.ENIAttachmentTypeTaskENI, taskArn)
testHandleExpiredENIAttachment(t, apieni.ENIAttachmentTypeTaskENI, testconst.TaskARN)
}

// TestHandleExpiredENIAttachmentInstanceENI tests handling an expired instance eni
Expand Down
15 changes: 11 additions & 4 deletions agent/acs/handler/payload_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ type payloadRequestHandler struct {
cluster string
containerInstanceArn string
acsClient wsclient.ClientServer
refreshHandler refreshCredentialsHandler
credentialsManager credentials.Manager
latestSeqNumberTaskManifest *int64
}
Expand All @@ -65,7 +64,6 @@ func newPayloadRequestHandler(
containerInstanceArn string,
acsClient wsclient.ClientServer,
dataClient data.Client,
refreshHandler refreshCredentialsHandler,
credentialsManager credentials.Manager,
taskHandler *eventhandler.TaskHandler, seqNumTaskManifest *int64) payloadRequestHandler {
// Create a cancelable context from the parent context
Expand All @@ -82,7 +80,6 @@ func newPayloadRequestHandler(
cluster: cluster,
containerInstanceArn: containerInstanceArn,
acsClient: acsClient,
refreshHandler: refreshHandler,
credentialsManager: credentialsManager,
latestSeqNumberTaskManifest: seqNumTaskManifest,
}
Expand Down Expand Up @@ -187,14 +184,24 @@ func (payloadHandler *payloadRequestHandler) handleSingleMessage(payload *ecsacs
go func() {
// Throw the ack in async; it doesn't really matter all that much and this is blocking handling more tasks.
for _, credentialsAck := range credentialsAcks {
payloadHandler.refreshHandler.ackMessage(credentialsAck)
payloadHandler.makeCredentialsAckRequest(credentialsAck)
}
payloadHandler.ackRequest <- *payload.MessageId
}()

return nil
}

// makeCredentialsAckRequest sends an IAMRoleCredentialsAckRequest to the backend
func (payloadHandler *payloadRequestHandler) makeCredentialsAckRequest(ack *ecsacs.IAMRoleCredentialsAckRequest) {
seelog.Debugf("ACKing credentials associated with ACS payload message: %s", ack.String())
err := payloadHandler.acsClient.MakeRequest(ack)
if err != nil {
seelog.Warnf("Error ACKing credentials with credentialsID '%s' associated with ACS payload message, error: %v",
aws.StringValue(ack.CredentialsId), err)
}
}

// addPayloadTasks does validation on each task and, for all valid ones, adds
// it to the task engine. It returns a bool indicating if it could add every
// task to the taskEngine and a slice of credential ack requests
Expand Down
23 changes: 4 additions & 19 deletions agent/acs/handler/payload_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ func setup(t *testing.T) *testHelper {
testconst.ContainerInstanceARN,
mockWsClient,
data.NewNoopClient(),
refreshCredentialsHandler{},
credentialsManager,
taskHandler, &latestSeqNumberTaskManifest)

Expand Down Expand Up @@ -307,11 +306,6 @@ func TestHandlePayloadMessageCredentialsAckedWhenTaskAdded(t *testing.T) {
}),
)

refreshCredsHandler := newRefreshCredentialsHandler(tester.ctx, testconst.ClusterName, testconst.ContainerInstanceARN, tester.mockWsClient, tester.credentialsManager, tester.mockTaskEngine)
defer refreshCredsHandler.clearAcks()
refreshCredsHandler.start()
tester.payloadHandler.refreshHandler = refreshCredsHandler

go tester.payloadHandler.start()

taskArn := "t1"
Expand All @@ -338,8 +332,8 @@ func TestHandlePayloadMessageCredentialsAckedWhenTaskAdded(t *testing.T) {
},
},
MessageId: aws.String(payloadMessageId),
ClusterArn: aws.String(cluster),
ContainerInstanceArn: aws.String(containerInstance),
ClusterArn: aws.String(testconst.ClusterName),
ContainerInstanceArn: aws.String(testconst.ContainerInstanceARN),
}
err := tester.payloadHandler.handleSingleMessage(payloadMessage)
assert.NoError(t, err, "error handling payload message")
Expand Down Expand Up @@ -496,11 +490,6 @@ func TestPayloadBufferHandlerWithCredentials(t *testing.T) {
}),
)

refreshCredsHandler := newRefreshCredentialsHandler(tester.ctx, testconst.ClusterName, testconst.ContainerInstanceARN, tester.mockWsClient, tester.credentialsManager, tester.mockTaskEngine)
defer refreshCredsHandler.clearAcks()
refreshCredsHandler.start()
tester.payloadHandler.refreshHandler = refreshCredsHandler

go tester.payloadHandler.start()

firstTaskArn := "t1"
Expand Down Expand Up @@ -546,8 +535,8 @@ func TestPayloadBufferHandlerWithCredentials(t *testing.T) {
},
},
MessageId: aws.String(payloadMessageId),
ClusterArn: aws.String(cluster),
ContainerInstanceArn: aws.String(containerInstance),
ClusterArn: aws.String(testconst.ClusterName),
ContainerInstanceArn: aws.String(testconst.ContainerInstanceARN),
}

// Wait till we get an ack
Expand Down Expand Up @@ -618,11 +607,7 @@ func TestAddPayloadTaskAddsExecutionRoles(t *testing.T) {
tester.cancel()
}),
)
refreshCredsHandler := newRefreshCredentialsHandler(tester.ctx, testconst.ClusterName, testconst.ContainerInstanceARN, tester.mockWsClient, tester.credentialsManager, tester.mockTaskEngine)
defer refreshCredsHandler.clearAcks()
refreshCredsHandler.start()

tester.payloadHandler.refreshHandler = refreshCredsHandler
go tester.payloadHandler.start()
taskArn := "t1"
credentialsExpiration := "expiration"
Expand Down

0 comments on commit 86db0d4

Please sign in to comment.