Skip to content

Commit

Permalink
Fix flakey leak test for ACS Handler
Browse files Browse the repository at this point in the history
The multi-channel multi-goroutine implementation of heartbeatHandling
created a scenario where goroutines would be routinely leaked in the
test which was testing for it.

The reality of handling Heartbeat messages is that there was not a
need to do multi-thread synchronization with channels at all.
  • Loading branch information
aws-gibbskt committed Mar 6, 2023
1 parent 5cf4103 commit b02f15a
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 180 deletions.
20 changes: 4 additions & 16 deletions agent/acs/handler/acs_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ const (
acsProtocolVersion = 2
// numOfHandlersSendingAcks is the number of handlers that send acks back to ACS and that are not saved across
// sessions. We use this to send pending acks, before agent initiates a disconnect to ACS.
// they are: refreshCredentialsHandler, taskManifestHandler, payloadHandler and heartbeatHandler
numOfHandlersSendingAcks = 4
// they are: refreshCredentialsHandler, taskManifestHandler, and payloadHandler
numOfHandlersSendingAcks = 3
)

// Session defines an interface for handler's long-lived connection with ACS.
Expand Down Expand Up @@ -358,12 +358,7 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {

client.AddRequestHandler(payloadHandler.handlerFunc())

heartbeatHandler := newHeartbeatHandler(acsSession.ctx, client, acsSession.doctor)
defer heartbeatHandler.clearAcks()
heartbeatHandler.start()
defer heartbeatHandler.stop()

client.AddRequestHandler(heartbeatHandler.handlerFunc())
client.AddRequestHandler(HeartbeatHandlerFunc(client, acsSession.doctor))

updater.AddAgentUpdateHandlers(client, cfg, acsSession.state, acsSession.dataClient, acsSession.taskEngine)

Expand All @@ -377,7 +372,7 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {
// Start a connection timer; agent will send pending acks and close its ACS websocket connection
// after this timer expires
connectionTimer := newConnectionTimer(client, acsSession.connectionTime, acsSession.connectionJitter,
&refreshCredsHandler, &taskManifestHandler, &payloadHandler, &heartbeatHandler)
&refreshCredsHandler, &taskManifestHandler, &payloadHandler)
defer connectionTimer.Stop()

// Start a heartbeat timer for closing the connection
Expand Down Expand Up @@ -521,7 +516,6 @@ func newConnectionTimer(
refreshCredsHandler *refreshCredentialsHandler,
taskManifestHandler *taskManifestHandler,
payloadHandler *payloadRequestHandler,
heartbeatHandler *heartbeatHandler,
) ttime.Timer {
expiresAt := retry.AddJitter(connectionTime, connectionJitter)
timer := time.AfterFunc(expiresAt, func() {
Expand Down Expand Up @@ -549,12 +543,6 @@ func newConnectionTimer(
wg.Done()
}()

// send pending heartbeat acks to ACS
go func() {
heartbeatHandler.sendPendingHeartbeatAck()
wg.Done()
}()

// wait for acks from all the handlers above to be sent to ACS before closing the websocket connection.
// the methods used to read pending acks are non-blocking, so it is safe to wait here.
wg.Wait()
Expand Down
3 changes: 0 additions & 3 deletions agent/acs/handler/acs_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -957,9 +957,6 @@ func TestHandlerDoesntLeakGoroutines(t *testing.T) {
cancel()
<-ended

// The number of goroutines finishing in the MockACSServer will affect
// the result unless we wait here.
time.Sleep(2 * time.Second)
afterGoroutines := runtime.NumGoroutine()

t.Logf("Goroutines after 1 and after %v acs messages: %v and %v", timesConnected, beforeGoroutines, afterGoroutines)
Expand Down
119 changes: 15 additions & 104 deletions agent/acs/handler/heartbeat_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,126 +14,37 @@
package handler

import (
"context"

"github.com/aws/amazon-ecs-agent/agent/acs/model/ecsacs"
"github.com/aws/amazon-ecs-agent/agent/doctor"
"github.com/aws/amazon-ecs-agent/agent/wsclient"
"github.com/aws/aws-sdk-go/aws"
"github.com/cihub/seelog"
)

// heartbeatHandler handles heartbeat messages from ACS
type heartbeatHandler struct {
heartbeatMessageBuffer chan *ecsacs.HeartbeatMessage
heartbeatAckMessageBuffer chan *ecsacs.HeartbeatAckRequest
ctx context.Context
cancel context.CancelFunc
acsClient wsclient.ClientServer
doctor *doctor.Doctor
}

// newHeartbeatHandler returns an instance of the heartbeatHandler struct
func newHeartbeatHandler(ctx context.Context, acsClient wsclient.ClientServer, heartbeatDoctor *doctor.Doctor) heartbeatHandler {
// Create a cancelable context from the parent context
derivedContext, cancel := context.WithCancel(ctx)
return heartbeatHandler{
heartbeatMessageBuffer: make(chan *ecsacs.HeartbeatMessage),
heartbeatAckMessageBuffer: make(chan *ecsacs.HeartbeatAckRequest),
ctx: derivedContext,
cancel: cancel,
acsClient: acsClient,
doctor: heartbeatDoctor,
}
}

// handlerFunc returns a function to enqueue requests onto the buffer
func (heartbeatHandler *heartbeatHandler) handlerFunc() func(message *ecsacs.HeartbeatMessage) {
func HeartbeatHandlerFunc(acsClient wsclient.ClientServer, doctor *doctor.Doctor) func(message *ecsacs.HeartbeatMessage) {
return func(message *ecsacs.HeartbeatMessage) {
heartbeatHandler.heartbeatMessageBuffer <- message
}
}

// start() invokes go routines to handle receive and respond to heartbeats
func (heartbeatHandler *heartbeatHandler) start() {
go heartbeatHandler.handleHeartbeatMessage()
go heartbeatHandler.sendHeartbeatAck()
}

func (heartbeatHandler *heartbeatHandler) handleHeartbeatMessage() {
for {
select {
case message := <-heartbeatHandler.heartbeatMessageBuffer:
if err := heartbeatHandler.handleSingleHeartbeatMessage(message); err != nil {
seelog.Warnf("Unable to handle heartbeat message [%s]: %s", message.String(), err)
}
case <-heartbeatHandler.ctx.Done():
return
}
handleSingleHeartbeatMessage(acsClient, doctor, message)
}
}

func (heartbeatHandler *heartbeatHandler) handleSingleHeartbeatMessage(message *ecsacs.HeartbeatMessage) error {
// TestHandlerDoesntLeakGoroutines unit test is failing because of this section
// To handle a Heartbeat Message the doctor health checks need to be run and
// an ACK needs to be sent back to ACS.

// This function is meant to be called from the ACS dispatcher and as such
// should not block in any way to prevent starvation of the message handler
func handleSingleHeartbeatMessage(acsClient wsclient.ClientServer, doctor *doctor.Doctor, message *ecsacs.HeartbeatMessage) {
// Agent will run healthchecks triggered by ACS heartbeat
// healthcheck results will be sent on to TACS, but for now just to debug logs.
go func() {
heartbeatHandler.doctor.RunHealthchecks()
}()
go doctor.RunHealthchecks()

// Agent will send simple ack to the heartbeatAckMessageBuffer
// Agent will send simple ack
ack := &ecsacs.HeartbeatAckRequest{
MessageId: message.MessageId,
}
go func() {
response := &ecsacs.HeartbeatAckRequest{
MessageId: message.MessageId,
err := acsClient.MakeRequest(ack)
if err != nil {
seelog.Warnf("Error acknowledging server heartbeat, message id: %s, error: %s", aws.StringValue(ack.MessageId), err)
}
heartbeatHandler.heartbeatAckMessageBuffer <- response
}()
return nil
}

func (heartbeatHandler *heartbeatHandler) sendHeartbeatAck() {
for {
select {
case ack := <-heartbeatHandler.heartbeatAckMessageBuffer:
heartbeatHandler.sendSingleHeartbeatAck(ack)
case <-heartbeatHandler.ctx.Done():
return
}
}
}

// sendPendingHeartbeatAck sends all pending heartbeat acks to ACS before closing the connection
func (heartbeatHandler *heartbeatHandler) sendPendingHeartbeatAck() {
for {
select {
case ack := <-heartbeatHandler.heartbeatAckMessageBuffer:
heartbeatHandler.sendSingleHeartbeatAck(ack)
default:
return
}
}
}

func (heartbeatHandler *heartbeatHandler) sendSingleHeartbeatAck(ack *ecsacs.HeartbeatAckRequest) {
err := heartbeatHandler.acsClient.MakeRequest(ack)
if err != nil {
seelog.Warnf("Error acknowledging server heartbeat, message id: %s, error: %s", aws.StringValue(ack.MessageId), err)
}
}

// stop() cancels the context being used by this handler, which stops the go routines started by 'start()'
func (heartbeatHandler *heartbeatHandler) stop() {
heartbeatHandler.cancel()
}

// clearAcks drains the ack request channel
func (heartbeatHandler *heartbeatHandler) clearAcks() {
for {
select {
case <-heartbeatHandler.heartbeatAckMessageBuffer:
default:
return
}
}
}
63 changes: 6 additions & 57 deletions agent/acs/handler/heartbeat_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,14 @@
package handler

import (
"context"
"sync"
"testing"
"time"

"github.com/aws/amazon-ecs-agent/agent/acs/model/ecsacs"
mock_dockerapi "github.com/aws/amazon-ecs-agent/agent/dockerclient/dockerapi/mocks"
"github.com/aws/amazon-ecs-agent/agent/doctor"
mock_wsclient "github.com/aws/amazon-ecs-agent/agent/wsclient/mock"

"github.com/aws/aws-sdk-go/aws"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -90,67 +85,21 @@ func validateHeartbeatAck(t *testing.T, heartbeatReceived *ecsacs.HeartbeatMessa
ctrl := gomock.NewController(t)
defer ctrl.Finish()

ctx, cancel := context.WithCancel(context.Background())
var heartbeatAckSent *ecsacs.HeartbeatAckRequest
ackSent := make(chan *ecsacs.HeartbeatAckRequest)

mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(message *ecsacs.HeartbeatAckRequest) {
heartbeatAckSent = message
cancel()
ackSent <- message
close(ackSent)
}).Times(1)

dockerClient := mock_dockerapi.NewMockDockerClient(ctrl)
dockerClient.EXPECT().SystemPing(gomock.Any(), gomock.Any()).AnyTimes()

emptyHealthchecksList := []doctor.Healthcheck{}
emptyDoctor, _ := doctor.NewDoctor(emptyHealthchecksList, "testCluster", "this:is:an:instance:arn")

handler := newHeartbeatHandler(ctx, mockWsClient, emptyDoctor)

go handler.sendHeartbeatAck()
handleSingleHeartbeatMessage(mockWsClient, emptyDoctor, heartbeatReceived)

handler.handleSingleHeartbeatMessage(heartbeatReceived)

// wait till we get an ack from heartbeatAckMessageBuffer
<-ctx.Done()
// wait till we send an
heartbeatAckSent := <-ackSent

require.Equal(t, heartbeatAckExpected, heartbeatAckSent)
}

func TestHeartbeatHandler(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

ctx := context.TODO()
emptyHealthCheckList := []doctor.Healthcheck{}
emptyDoctor, _ := doctor.NewDoctor(emptyHealthCheckList, "testCluster",
"this:is:an:instance:arn")
mockWSClient := mock_wsclient.NewMockClientServer(ctrl)
mockWSClient.EXPECT().MakeRequest(gomock.Any()).Return(nil).Times(1)
handler := newHeartbeatHandler(ctx, mockWSClient, emptyDoctor)

wg := sync.WaitGroup{}
wg.Add(2)

// write a dummy ack into the heartbeatAckMessageBuffer
go func() {
handler.heartbeatAckMessageBuffer <- &ecsacs.HeartbeatAckRequest{}
wg.Done()
}()

// sleep here to ensure that the sending go routine executes before the receiving one below. if not, then the
// receiving go routine will finish without receiving the ack since sendPendingHeartbeatAck() is non-blocking.
time.Sleep(1 * time.Second)

go func() {
handler.sendPendingHeartbeatAck()
wg.Done()
}()

// wait for both go routines above to finish before we verify that ack channel is empty and exit the test.
// this also ensures that the mock MakeRequest call happened as expected.
wg.Wait()

// verify that the heartbeatAckMessageBuffer channel is empty
assert.Equal(t, 0, len(handler.heartbeatAckMessageBuffer))
}

0 comments on commit b02f15a

Please sign in to comment.