diff --git a/agent/acs/handler/acs_handler.go b/agent/acs/handler/acs_handler.go index 0c090474c0..279a42a1c3 100644 --- a/agent/acs/handler/acs_handler.go +++ b/agent/acs/handler/acs_handler.go @@ -358,12 +358,7 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error { client.AddRequestHandler(payloadHandler.handlerFunc()) - heartbeatHandler := newHeartbeatHandler(acsSession.ctx, client, acsSession.doctor) - defer heartbeatHandler.clearAcks() - heartbeatHandler.start() - defer heartbeatHandler.stop() - - client.AddRequestHandler(heartbeatHandler.handlerFunc()) + client.AddRequestHandler(HeartbeatHandlerFunc(client, acsSession.doctor)) updater.AddAgentUpdateHandlers(client, cfg, acsSession.state, acsSession.dataClient, acsSession.taskEngine) diff --git a/agent/acs/handler/acs_handler_test.go b/agent/acs/handler/acs_handler_test.go index 6a1ceab6f8..07e834122b 100644 --- a/agent/acs/handler/acs_handler_test.go +++ b/agent/acs/handler/acs_handler_test.go @@ -957,9 +957,6 @@ func TestHandlerDoesntLeakGoroutines(t *testing.T) { cancel() <-ended - // The number of goroutines finishing in the MockACSServer will affect - // the result unless we wait here. - time.Sleep(2 * time.Second) afterGoroutines := runtime.NumGoroutine() t.Logf("Goroutines after 1 and after %v acs messages: %v and %v", timesConnected, beforeGoroutines, afterGoroutines) diff --git a/agent/acs/handler/heartbeat_handler.go b/agent/acs/handler/heartbeat_handler.go index 43b16c477b..c248de129d 100644 --- a/agent/acs/handler/heartbeat_handler.go +++ b/agent/acs/handler/heartbeat_handler.go @@ -14,8 +14,6 @@ package handler import ( - "context" - "github.com/aws/amazon-ecs-agent/agent/acs/model/ecsacs" "github.com/aws/amazon-ecs-agent/agent/doctor" "github.com/aws/amazon-ecs-agent/agent/wsclient" @@ -23,117 +21,30 @@ import ( "github.com/cihub/seelog" ) -// heartbeatHandler handles heartbeat messages from ACS -type heartbeatHandler struct { - heartbeatMessageBuffer chan *ecsacs.HeartbeatMessage - heartbeatAckMessageBuffer chan *ecsacs.HeartbeatAckRequest - ctx context.Context - cancel context.CancelFunc - acsClient wsclient.ClientServer - doctor *doctor.Doctor -} - -// newHeartbeatHandler returns an instance of the heartbeatHandler struct -func newHeartbeatHandler(ctx context.Context, acsClient wsclient.ClientServer, heartbeatDoctor *doctor.Doctor) heartbeatHandler { - // Create a cancelable context from the parent context - derivedContext, cancel := context.WithCancel(ctx) - return heartbeatHandler{ - heartbeatMessageBuffer: make(chan *ecsacs.HeartbeatMessage), - heartbeatAckMessageBuffer: make(chan *ecsacs.HeartbeatAckRequest), - ctx: derivedContext, - cancel: cancel, - acsClient: acsClient, - doctor: heartbeatDoctor, - } -} - -// handlerFunc returns a function to enqueue requests onto the buffer -func (heartbeatHandler *heartbeatHandler) handlerFunc() func(message *ecsacs.HeartbeatMessage) { +func HeartbeatHandlerFunc(acsClient wsclient.ClientServer, doctor *doctor.Doctor) func(message *ecsacs.HeartbeatMessage) { return func(message *ecsacs.HeartbeatMessage) { - heartbeatHandler.heartbeatMessageBuffer <- message - } -} - -// start() invokes go routines to handle receive and respond to heartbeats -func (heartbeatHandler *heartbeatHandler) start() { - go heartbeatHandler.handleHeartbeatMessage() - go heartbeatHandler.sendHeartbeatAck() -} - -func (heartbeatHandler *heartbeatHandler) handleHeartbeatMessage() { - for { - select { - case message := <-heartbeatHandler.heartbeatMessageBuffer: - if err := heartbeatHandler.handleSingleHeartbeatMessage(message); err != nil { - seelog.Warnf("Unable to handle heartbeat message [%s]: %s", message.String(), err) - } - case <-heartbeatHandler.ctx.Done(): - return - } + handleSingleHeartbeatMessage(acsClient, doctor, message) } } -func (heartbeatHandler *heartbeatHandler) handleSingleHeartbeatMessage(message *ecsacs.HeartbeatMessage) error { - // TestHandlerDoesntLeakGoroutines unit test is failing because of this section +// To handle a Heartbeat Message the doctor health checks need to be run and +// an ACK needs to be sent back to ACS. +// This function is meant to be called from the ACS dispatcher and as such +// should not block in any way to prevent starvation of the message handler +func handleSingleHeartbeatMessage(acsClient wsclient.ClientServer, doctor *doctor.Doctor, message *ecsacs.HeartbeatMessage) { // Agent will run healthchecks triggered by ACS heartbeat // healthcheck results will be sent on to TACS, but for now just to debug logs. - go func() { - heartbeatHandler.doctor.RunHealthchecks() - }() + go doctor.RunHealthchecks() - // Agent will send simple ack to the heartbeatAckMessageBuffer + // Agent will send simple ack + ack := &ecsacs.HeartbeatAckRequest{ + MessageId: message.MessageId, + } go func() { - response := &ecsacs.HeartbeatAckRequest{ - MessageId: message.MessageId, + err := acsClient.MakeRequest(ack) + if err != nil { + seelog.Warnf("Error acknowledging server heartbeat, message id: %s, error: %s", aws.StringValue(ack.MessageId), err) } - heartbeatHandler.heartbeatAckMessageBuffer <- response }() - return nil -} - -func (heartbeatHandler *heartbeatHandler) sendHeartbeatAck() { - for { - select { - case ack := <-heartbeatHandler.heartbeatAckMessageBuffer: - heartbeatHandler.sendSingleHeartbeatAck(ack) - case <-heartbeatHandler.ctx.Done(): - return - } - } -} - -// sendPendingHeartbeatAck sends all pending heartbeat acks to ACS before closing the connection -func (heartbeatHandler *heartbeatHandler) sendPendingHeartbeatAck() { - for { - select { - case ack := <-heartbeatHandler.heartbeatAckMessageBuffer: - heartbeatHandler.sendSingleHeartbeatAck(ack) - default: - return - } - } -} - -func (heartbeatHandler *heartbeatHandler) sendSingleHeartbeatAck(ack *ecsacs.HeartbeatAckRequest) { - err := heartbeatHandler.acsClient.MakeRequest(ack) - if err != nil { - seelog.Warnf("Error acknowledging server heartbeat, message id: %s, error: %s", aws.StringValue(ack.MessageId), err) - } -} - -// stop() cancels the context being used by this handler, which stops the go routines started by 'start()' -func (heartbeatHandler *heartbeatHandler) stop() { - heartbeatHandler.cancel() -} - -// clearAcks drains the ack request channel -func (heartbeatHandler *heartbeatHandler) clearAcks() { - for { - select { - case <-heartbeatHandler.heartbeatAckMessageBuffer: - default: - return - } - } } diff --git a/agent/acs/handler/heartbeat_handler_test.go b/agent/acs/handler/heartbeat_handler_test.go index fb382e1f25..deba7762bd 100644 --- a/agent/acs/handler/heartbeat_handler_test.go +++ b/agent/acs/handler/heartbeat_handler_test.go @@ -17,13 +17,10 @@ package handler import ( - "context" - "sync" "testing" "time" "github.com/aws/amazon-ecs-agent/agent/acs/model/ecsacs" - mock_dockerapi "github.com/aws/amazon-ecs-agent/agent/dockerclient/dockerapi/mocks" "github.com/aws/amazon-ecs-agent/agent/doctor" mock_wsclient "github.com/aws/amazon-ecs-agent/agent/wsclient/mock" @@ -90,29 +87,21 @@ func validateHeartbeatAck(t *testing.T, heartbeatReceived *ecsacs.HeartbeatMessa ctrl := gomock.NewController(t) defer ctrl.Finish() - ctx, cancel := context.WithCancel(context.Background()) - var heartbeatAckSent *ecsacs.HeartbeatAckRequest + ackSent := make(chan *ecsacs.HeartbeatAckRequest) mockWsClient := mock_wsclient.NewMockClientServer(ctrl) mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(message *ecsacs.HeartbeatAckRequest) { - heartbeatAckSent = message - cancel() + ackSent <- message + close(ackSent) }).Times(1) - dockerClient := mock_dockerapi.NewMockDockerClient(ctrl) - dockerClient.EXPECT().SystemPing(gomock.Any(), gomock.Any()).AnyTimes() - emptyHealthchecksList := []doctor.Healthcheck{} emptyDoctor, _ := doctor.NewDoctor(emptyHealthchecksList, "testCluster", "this:is:an:instance:arn") - handler := newHeartbeatHandler(ctx, mockWsClient, emptyDoctor) - - go handler.sendHeartbeatAck() - - handler.handleSingleHeartbeatMessage(heartbeatReceived) + handleSingleHeartbeatMessage(mockWsClient, emptyDoctor, heartbeatReceived) - // wait till we get an ack from heartbeatAckMessageBuffer - <-ctx.Done() + // wait till we send an + heartbeatAckSent := <-ackSent require.Equal(t, heartbeatAckExpected, heartbeatAckSent) }