Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix flakey leak test for ACS Handler #3232

Merged
merged 1 commit into from
Mar 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 4 additions & 16 deletions agent/acs/handler/acs_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ const (
acsProtocolVersion = 2
// numOfHandlersSendingAcks is the number of handlers that send acks back to ACS and that are not saved across
// sessions. We use this to send pending acks, before agent initiates a disconnect to ACS.
// they are: refreshCredentialsHandler, taskManifestHandler, payloadHandler and heartbeatHandler
numOfHandlersSendingAcks = 4
// they are: refreshCredentialsHandler, taskManifestHandler, and payloadHandler
numOfHandlersSendingAcks = 3
)

// Session defines an interface for handler's long-lived connection with ACS.
Expand Down Expand Up @@ -358,12 +358,7 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {

client.AddRequestHandler(payloadHandler.handlerFunc())

heartbeatHandler := newHeartbeatHandler(acsSession.ctx, client, acsSession.doctor)
defer heartbeatHandler.clearAcks()
heartbeatHandler.start()
defer heartbeatHandler.stop()

client.AddRequestHandler(heartbeatHandler.handlerFunc())
client.AddRequestHandler(HeartbeatHandlerFunc(client, acsSession.doctor))

updater.AddAgentUpdateHandlers(client, cfg, acsSession.state, acsSession.dataClient, acsSession.taskEngine)

Expand All @@ -377,7 +372,7 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {
// Start a connection timer; agent will send pending acks and close its ACS websocket connection
// after this timer expires
connectionTimer := newConnectionTimer(client, acsSession.connectionTime, acsSession.connectionJitter,
&refreshCredsHandler, &taskManifestHandler, &payloadHandler, &heartbeatHandler)
&refreshCredsHandler, &taskManifestHandler, &payloadHandler)
defer connectionTimer.Stop()

// Start a heartbeat timer for closing the connection
Expand Down Expand Up @@ -521,7 +516,6 @@ func newConnectionTimer(
refreshCredsHandler *refreshCredentialsHandler,
taskManifestHandler *taskManifestHandler,
payloadHandler *payloadRequestHandler,
heartbeatHandler *heartbeatHandler,
) ttime.Timer {
expiresAt := retry.AddJitter(connectionTime, connectionJitter)
timer := time.AfterFunc(expiresAt, func() {
Expand Down Expand Up @@ -549,12 +543,6 @@ func newConnectionTimer(
wg.Done()
}()

// send pending heartbeat acks to ACS
go func() {
heartbeatHandler.sendPendingHeartbeatAck()
wg.Done()
}()

// wait for acks from all the handlers above to be sent to ACS before closing the websocket connection.
// the methods used to read pending acks are non-blocking, so it is safe to wait here.
wg.Wait()
Expand Down
3 changes: 0 additions & 3 deletions agent/acs/handler/acs_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -957,9 +957,6 @@ func TestHandlerDoesntLeakGoroutines(t *testing.T) {
cancel()
<-ended

// The number of goroutines finishing in the MockACSServer will affect
// the result unless we wait here.
time.Sleep(2 * time.Second)
afterGoroutines := runtime.NumGoroutine()

t.Logf("Goroutines after 1 and after %v acs messages: %v and %v", timesConnected, beforeGoroutines, afterGoroutines)
Expand Down
119 changes: 15 additions & 104 deletions agent/acs/handler/heartbeat_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,126 +14,37 @@
package handler

import (
"context"

"github.com/aws/amazon-ecs-agent/agent/acs/model/ecsacs"
"github.com/aws/amazon-ecs-agent/agent/doctor"
"github.com/aws/amazon-ecs-agent/agent/wsclient"
"github.com/aws/aws-sdk-go/aws"
"github.com/cihub/seelog"
)

// heartbeatHandler handles heartbeat messages from ACS
type heartbeatHandler struct {
heartbeatMessageBuffer chan *ecsacs.HeartbeatMessage
heartbeatAckMessageBuffer chan *ecsacs.HeartbeatAckRequest
ctx context.Context
cancel context.CancelFunc
acsClient wsclient.ClientServer
doctor *doctor.Doctor
}

// newHeartbeatHandler returns an instance of the heartbeatHandler struct
func newHeartbeatHandler(ctx context.Context, acsClient wsclient.ClientServer, heartbeatDoctor *doctor.Doctor) heartbeatHandler {
// Create a cancelable context from the parent context
derivedContext, cancel := context.WithCancel(ctx)
return heartbeatHandler{
heartbeatMessageBuffer: make(chan *ecsacs.HeartbeatMessage),
heartbeatAckMessageBuffer: make(chan *ecsacs.HeartbeatAckRequest),
ctx: derivedContext,
cancel: cancel,
acsClient: acsClient,
doctor: heartbeatDoctor,
}
}

// handlerFunc returns a function to enqueue requests onto the buffer
func (heartbeatHandler *heartbeatHandler) handlerFunc() func(message *ecsacs.HeartbeatMessage) {
func HeartbeatHandlerFunc(acsClient wsclient.ClientServer, doctor *doctor.Doctor) func(message *ecsacs.HeartbeatMessage) {
return func(message *ecsacs.HeartbeatMessage) {
heartbeatHandler.heartbeatMessageBuffer <- message
}
}

// start() invokes go routines to handle receive and respond to heartbeats
func (heartbeatHandler *heartbeatHandler) start() {
go heartbeatHandler.handleHeartbeatMessage()
go heartbeatHandler.sendHeartbeatAck()
}

func (heartbeatHandler *heartbeatHandler) handleHeartbeatMessage() {
for {
select {
case message := <-heartbeatHandler.heartbeatMessageBuffer:
if err := heartbeatHandler.handleSingleHeartbeatMessage(message); err != nil {
seelog.Warnf("Unable to handle heartbeat message [%s]: %s", message.String(), err)
}
case <-heartbeatHandler.ctx.Done():
return
}
handleSingleHeartbeatMessage(acsClient, doctor, message)
}
}

func (heartbeatHandler *heartbeatHandler) handleSingleHeartbeatMessage(message *ecsacs.HeartbeatMessage) error {
// TestHandlerDoesntLeakGoroutines unit test is failing because of this section
// To handle a Heartbeat Message the doctor health checks need to be run and
// an ACK needs to be sent back to ACS.

// This function is meant to be called from the ACS dispatcher and as such
// should not block in any way to prevent starvation of the message handler
func handleSingleHeartbeatMessage(acsClient wsclient.ClientServer, doctor *doctor.Doctor, message *ecsacs.HeartbeatMessage) {
// Agent will run healthchecks triggered by ACS heartbeat
// healthcheck results will be sent on to TACS, but for now just to debug logs.
go func() {
heartbeatHandler.doctor.RunHealthchecks()
}()
go doctor.RunHealthchecks()

// Agent will send simple ack to the heartbeatAckMessageBuffer
// Agent will send simple ack
ack := &ecsacs.HeartbeatAckRequest{
MessageId: message.MessageId,
}
go func() {
response := &ecsacs.HeartbeatAckRequest{
MessageId: message.MessageId,
err := acsClient.MakeRequest(ack)
if err != nil {
seelog.Warnf("Error acknowledging server heartbeat, message id: %s, error: %s", aws.StringValue(ack.MessageId), err)
}
heartbeatHandler.heartbeatAckMessageBuffer <- response
}()
return nil
}

func (heartbeatHandler *heartbeatHandler) sendHeartbeatAck() {
for {
select {
case ack := <-heartbeatHandler.heartbeatAckMessageBuffer:
heartbeatHandler.sendSingleHeartbeatAck(ack)
case <-heartbeatHandler.ctx.Done():
return
}
}
}

// sendPendingHeartbeatAck sends all pending heartbeat acks to ACS before closing the connection
func (heartbeatHandler *heartbeatHandler) sendPendingHeartbeatAck() {
for {
select {
case ack := <-heartbeatHandler.heartbeatAckMessageBuffer:
heartbeatHandler.sendSingleHeartbeatAck(ack)
default:
return
}
}
}

func (heartbeatHandler *heartbeatHandler) sendSingleHeartbeatAck(ack *ecsacs.HeartbeatAckRequest) {
err := heartbeatHandler.acsClient.MakeRequest(ack)
if err != nil {
seelog.Warnf("Error acknowledging server heartbeat, message id: %s, error: %s", aws.StringValue(ack.MessageId), err)
}
}

// stop() cancels the context being used by this handler, which stops the go routines started by 'start()'
func (heartbeatHandler *heartbeatHandler) stop() {
heartbeatHandler.cancel()
}

// clearAcks drains the ack request channel
func (heartbeatHandler *heartbeatHandler) clearAcks() {
for {
select {
case <-heartbeatHandler.heartbeatAckMessageBuffer:
default:
return
}
}
}
63 changes: 6 additions & 57 deletions agent/acs/handler/heartbeat_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,14 @@
package handler

import (
"context"
"sync"
"testing"
"time"

"github.com/aws/amazon-ecs-agent/agent/acs/model/ecsacs"
mock_dockerapi "github.com/aws/amazon-ecs-agent/agent/dockerclient/dockerapi/mocks"
"github.com/aws/amazon-ecs-agent/agent/doctor"
mock_wsclient "github.com/aws/amazon-ecs-agent/agent/wsclient/mock"

"github.com/aws/aws-sdk-go/aws"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -90,67 +85,21 @@ func validateHeartbeatAck(t *testing.T, heartbeatReceived *ecsacs.HeartbeatMessa
ctrl := gomock.NewController(t)
defer ctrl.Finish()

ctx, cancel := context.WithCancel(context.Background())
var heartbeatAckSent *ecsacs.HeartbeatAckRequest
ackSent := make(chan *ecsacs.HeartbeatAckRequest)

mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(message *ecsacs.HeartbeatAckRequest) {
heartbeatAckSent = message
cancel()
ackSent <- message
close(ackSent)
}).Times(1)

dockerClient := mock_dockerapi.NewMockDockerClient(ctrl)
dockerClient.EXPECT().SystemPing(gomock.Any(), gomock.Any()).AnyTimes()

emptyHealthchecksList := []doctor.Healthcheck{}
emptyDoctor, _ := doctor.NewDoctor(emptyHealthchecksList, "testCluster", "this:is:an:instance:arn")

handler := newHeartbeatHandler(ctx, mockWsClient, emptyDoctor)

go handler.sendHeartbeatAck()
handleSingleHeartbeatMessage(mockWsClient, emptyDoctor, heartbeatReceived)

handler.handleSingleHeartbeatMessage(heartbeatReceived)

// wait till we get an ack from heartbeatAckMessageBuffer
<-ctx.Done()
// wait till we send an
heartbeatAckSent := <-ackSent

require.Equal(t, heartbeatAckExpected, heartbeatAckSent)
}

func TestHeartbeatHandler(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

ctx := context.TODO()
emptyHealthCheckList := []doctor.Healthcheck{}
emptyDoctor, _ := doctor.NewDoctor(emptyHealthCheckList, "testCluster",
"this:is:an:instance:arn")
mockWSClient := mock_wsclient.NewMockClientServer(ctrl)
mockWSClient.EXPECT().MakeRequest(gomock.Any()).Return(nil).Times(1)
handler := newHeartbeatHandler(ctx, mockWSClient, emptyDoctor)

wg := sync.WaitGroup{}
wg.Add(2)

// write a dummy ack into the heartbeatAckMessageBuffer
go func() {
handler.heartbeatAckMessageBuffer <- &ecsacs.HeartbeatAckRequest{}
wg.Done()
}()

// sleep here to ensure that the sending go routine executes before the receiving one below. if not, then the
// receiving go routine will finish without receiving the ack since sendPendingHeartbeatAck() is non-blocking.
time.Sleep(1 * time.Second)

go func() {
handler.sendPendingHeartbeatAck()
wg.Done()
}()

// wait for both go routines above to finish before we verify that ack channel is empty and exit the test.
// this also ensures that the mock MakeRequest call happened as expected.
wg.Wait()

// verify that the heartbeatAckMessageBuffer channel is empty
assert.Equal(t, 0, len(handler.heartbeatAckMessageBuffer))
}