diff --git a/agent/acs/handler/acs_handler.go b/agent/acs/handler/acs_handler.go index 0c090474c0..f601682ac2 100644 --- a/agent/acs/handler/acs_handler.go +++ b/agent/acs/handler/acs_handler.go @@ -78,8 +78,8 @@ const ( acsProtocolVersion = 2 // numOfHandlersSendingAcks is the number of handlers that send acks back to ACS and that are not saved across // sessions. We use this to send pending acks, before agent initiates a disconnect to ACS. - // they are: refreshCredentialsHandler, taskManifestHandler, payloadHandler and heartbeatHandler - numOfHandlersSendingAcks = 4 + // they are: refreshCredentialsHandler, taskManifestHandler, and payloadHandler + numOfHandlersSendingAcks = 3 ) // Session defines an interface for handler's long-lived connection with ACS. @@ -358,12 +358,7 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error { client.AddRequestHandler(payloadHandler.handlerFunc()) - heartbeatHandler := newHeartbeatHandler(acsSession.ctx, client, acsSession.doctor) - defer heartbeatHandler.clearAcks() - heartbeatHandler.start() - defer heartbeatHandler.stop() - - client.AddRequestHandler(heartbeatHandler.handlerFunc()) + client.AddRequestHandler(HeartbeatHandlerFunc(client, acsSession.doctor)) updater.AddAgentUpdateHandlers(client, cfg, acsSession.state, acsSession.dataClient, acsSession.taskEngine) @@ -377,7 +372,7 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error { // Start a connection timer; agent will send pending acks and close its ACS websocket connection // after this timer expires connectionTimer := newConnectionTimer(client, acsSession.connectionTime, acsSession.connectionJitter, - &refreshCredsHandler, &taskManifestHandler, &payloadHandler, &heartbeatHandler) + &refreshCredsHandler, &taskManifestHandler, &payloadHandler) defer connectionTimer.Stop() // Start a heartbeat timer for closing the connection @@ -521,7 +516,6 @@ func newConnectionTimer( refreshCredsHandler 
*refreshCredentialsHandler, taskManifestHandler *taskManifestHandler, payloadHandler *payloadRequestHandler, - heartbeatHandler *heartbeatHandler, ) ttime.Timer { expiresAt := retry.AddJitter(connectionTime, connectionJitter) timer := time.AfterFunc(expiresAt, func() { @@ -549,12 +543,6 @@ func newConnectionTimer( wg.Done() }() - // send pending heartbeat acks to ACS - go func() { - heartbeatHandler.sendPendingHeartbeatAck() - wg.Done() - }() - // wait for acks from all the handlers above to be sent to ACS before closing the websocket connection. // the methods used to read pending acks are non-blocking, so it is safe to wait here. wg.Wait() diff --git a/agent/acs/handler/acs_handler_test.go b/agent/acs/handler/acs_handler_test.go index 6a1ceab6f8..07e834122b 100644 --- a/agent/acs/handler/acs_handler_test.go +++ b/agent/acs/handler/acs_handler_test.go @@ -957,9 +957,6 @@ func TestHandlerDoesntLeakGoroutines(t *testing.T) { cancel() <-ended - // The number of goroutines finishing in the MockACSServer will affect - // the result unless we wait here. 
- time.Sleep(2 * time.Second) afterGoroutines := runtime.NumGoroutine() t.Logf("Goroutines after 1 and after %v acs messages: %v and %v", timesConnected, beforeGoroutines, afterGoroutines) diff --git a/agent/acs/handler/heartbeat_handler.go b/agent/acs/handler/heartbeat_handler.go index 43b16c477b..c248de129d 100644 --- a/agent/acs/handler/heartbeat_handler.go +++ b/agent/acs/handler/heartbeat_handler.go @@ -14,8 +14,6 @@ package handler import ( - "context" - "github.com/aws/amazon-ecs-agent/agent/acs/model/ecsacs" "github.com/aws/amazon-ecs-agent/agent/doctor" "github.com/aws/amazon-ecs-agent/agent/wsclient" @@ -23,117 +21,30 @@ import ( "github.com/cihub/seelog" ) -// heartbeatHandler handles heartbeat messages from ACS -type heartbeatHandler struct { - heartbeatMessageBuffer chan *ecsacs.HeartbeatMessage - heartbeatAckMessageBuffer chan *ecsacs.HeartbeatAckRequest - ctx context.Context - cancel context.CancelFunc - acsClient wsclient.ClientServer - doctor *doctor.Doctor -} - -// newHeartbeatHandler returns an instance of the heartbeatHandler struct -func newHeartbeatHandler(ctx context.Context, acsClient wsclient.ClientServer, heartbeatDoctor *doctor.Doctor) heartbeatHandler { - // Create a cancelable context from the parent context - derivedContext, cancel := context.WithCancel(ctx) - return heartbeatHandler{ - heartbeatMessageBuffer: make(chan *ecsacs.HeartbeatMessage), - heartbeatAckMessageBuffer: make(chan *ecsacs.HeartbeatAckRequest), - ctx: derivedContext, - cancel: cancel, - acsClient: acsClient, - doctor: heartbeatDoctor, - } -} - -// handlerFunc returns a function to enqueue requests onto the buffer -func (heartbeatHandler *heartbeatHandler) handlerFunc() func(message *ecsacs.HeartbeatMessage) { +func HeartbeatHandlerFunc(acsClient wsclient.ClientServer, doctor *doctor.Doctor) func(message *ecsacs.HeartbeatMessage) { return func(message *ecsacs.HeartbeatMessage) { - heartbeatHandler.heartbeatMessageBuffer <- message - } -} - -// start() invokes go 
routines to handle receive and respond to heartbeats -func (heartbeatHandler *heartbeatHandler) start() { - go heartbeatHandler.handleHeartbeatMessage() - go heartbeatHandler.sendHeartbeatAck() -} - -func (heartbeatHandler *heartbeatHandler) handleHeartbeatMessage() { - for { - select { - case message := <-heartbeatHandler.heartbeatMessageBuffer: - if err := heartbeatHandler.handleSingleHeartbeatMessage(message); err != nil { - seelog.Warnf("Unable to handle heartbeat message [%s]: %s", message.String(), err) - } - case <-heartbeatHandler.ctx.Done(): - return - } + handleSingleHeartbeatMessage(acsClient, doctor, message) } } -func (heartbeatHandler *heartbeatHandler) handleSingleHeartbeatMessage(message *ecsacs.HeartbeatMessage) error { - // TestHandlerDoesntLeakGoroutines unit test is failing because of this section +// handleSingleHeartbeatMessage handles a heartbeat message from ACS: it runs +// the doctor health checks and sends an ACK back to ACS. +// This function is meant to be called from the ACS dispatcher and as such +// must never block; blocking here would starve the message handler. +func handleSingleHeartbeatMessage(acsClient wsclient.ClientServer, doctor *doctor.Doctor, message *ecsacs.HeartbeatMessage) { // Agent will run healthchecks triggered by ACS heartbeat // healthcheck results will be sent on to TACS, but for now just to debug logs. 
- go func() { - heartbeatHandler.doctor.RunHealthchecks() - }() + go doctor.RunHealthchecks() - // Agent will send simple ack to the heartbeatAckMessageBuffer + // Agent will send simple ack + ack := &ecsacs.HeartbeatAckRequest{ + MessageId: message.MessageId, + } go func() { - response := &ecsacs.HeartbeatAckRequest{ - MessageId: message.MessageId, + err := acsClient.MakeRequest(ack) + if err != nil { + seelog.Warnf("Error acknowledging server heartbeat, message id: %s, error: %s", aws.StringValue(ack.MessageId), err) } - heartbeatHandler.heartbeatAckMessageBuffer <- response }() - return nil -} - -func (heartbeatHandler *heartbeatHandler) sendHeartbeatAck() { - for { - select { - case ack := <-heartbeatHandler.heartbeatAckMessageBuffer: - heartbeatHandler.sendSingleHeartbeatAck(ack) - case <-heartbeatHandler.ctx.Done(): - return - } - } -} - -// sendPendingHeartbeatAck sends all pending heartbeat acks to ACS before closing the connection -func (heartbeatHandler *heartbeatHandler) sendPendingHeartbeatAck() { - for { - select { - case ack := <-heartbeatHandler.heartbeatAckMessageBuffer: - heartbeatHandler.sendSingleHeartbeatAck(ack) - default: - return - } - } -} - -func (heartbeatHandler *heartbeatHandler) sendSingleHeartbeatAck(ack *ecsacs.HeartbeatAckRequest) { - err := heartbeatHandler.acsClient.MakeRequest(ack) - if err != nil { - seelog.Warnf("Error acknowledging server heartbeat, message id: %s, error: %s", aws.StringValue(ack.MessageId), err) - } -} - -// stop() cancels the context being used by this handler, which stops the go routines started by 'start()' -func (heartbeatHandler *heartbeatHandler) stop() { - heartbeatHandler.cancel() -} - -// clearAcks drains the ack request channel -func (heartbeatHandler *heartbeatHandler) clearAcks() { - for { - select { - case <-heartbeatHandler.heartbeatAckMessageBuffer: - default: - return - } - } } diff --git a/agent/acs/handler/heartbeat_handler_test.go b/agent/acs/handler/heartbeat_handler_test.go index 
fb382e1f25..84316760c7 100644 --- a/agent/acs/handler/heartbeat_handler_test.go +++ b/agent/acs/handler/heartbeat_handler_test.go @@ -17,19 +17,14 @@ package handler import ( - "context" - "sync" "testing" - "time" "github.com/aws/amazon-ecs-agent/agent/acs/model/ecsacs" - mock_dockerapi "github.com/aws/amazon-ecs-agent/agent/dockerclient/dockerapi/mocks" "github.com/aws/amazon-ecs-agent/agent/doctor" mock_wsclient "github.com/aws/amazon-ecs-agent/agent/wsclient/mock" "github.com/aws/aws-sdk-go/aws" "github.com/golang/mock/gomock" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -90,67 +85,21 @@ func validateHeartbeatAck(t *testing.T, heartbeatReceived *ecsacs.HeartbeatMessa ctrl := gomock.NewController(t) defer ctrl.Finish() - ctx, cancel := context.WithCancel(context.Background()) - var heartbeatAckSent *ecsacs.HeartbeatAckRequest + ackSent := make(chan *ecsacs.HeartbeatAckRequest) mockWsClient := mock_wsclient.NewMockClientServer(ctrl) mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(message *ecsacs.HeartbeatAckRequest) { - heartbeatAckSent = message - cancel() + ackSent <- message + close(ackSent) }).Times(1) - dockerClient := mock_dockerapi.NewMockDockerClient(ctrl) - dockerClient.EXPECT().SystemPing(gomock.Any(), gomock.Any()).AnyTimes() - emptyHealthchecksList := []doctor.Healthcheck{} emptyDoctor, _ := doctor.NewDoctor(emptyHealthchecksList, "testCluster", "this:is:an:instance:arn") - handler := newHeartbeatHandler(ctx, mockWsClient, emptyDoctor) - - go handler.sendHeartbeatAck() + handleSingleHeartbeatMessage(mockWsClient, emptyDoctor, heartbeatReceived) - handler.handleSingleHeartbeatMessage(heartbeatReceived) - - // wait till we get an ack from heartbeatAckMessageBuffer - <-ctx.Done() + // wait till we send an ack + heartbeatAckSent := <-ackSent require.Equal(t, heartbeatAckExpected, heartbeatAckSent) } - -func TestHeartbeatHandler(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - ctx := 
context.TODO() - emptyHealthCheckList := []doctor.Healthcheck{} - emptyDoctor, _ := doctor.NewDoctor(emptyHealthCheckList, "testCluster", - "this:is:an:instance:arn") - mockWSClient := mock_wsclient.NewMockClientServer(ctrl) - mockWSClient.EXPECT().MakeRequest(gomock.Any()).Return(nil).Times(1) - handler := newHeartbeatHandler(ctx, mockWSClient, emptyDoctor) - - wg := sync.WaitGroup{} - wg.Add(2) - - // write a dummy ack into the heartbeatAckMessageBuffer - go func() { - handler.heartbeatAckMessageBuffer <- &ecsacs.HeartbeatAckRequest{} - wg.Done() - }() - - // sleep here to ensure that the sending go routine executes before the receiving one below. if not, then the - // receiving go routine will finish without receiving the ack since sendPendingHeartbeatAck() is non-blocking. - time.Sleep(1 * time.Second) - - go func() { - handler.sendPendingHeartbeatAck() - wg.Done() - }() - - // wait for both go routines above to finish before we verify that ack channel is empty and exit the test. - // this also ensures that the mock MakeRequest call happened as expected. - wg.Wait() - - // verify that the heartbeatAckMessageBuffer channel is empty - assert.Equal(t, 0, len(handler.heartbeatAckMessageBuffer)) -}