diff --git a/telemetry/cnstelemetry.go b/telemetry/cnstelemetry.go index b5e15eff04..a554def7f7 100644 --- a/telemetry/cnstelemetry.go +++ b/telemetry/cnstelemetry.go @@ -80,6 +80,7 @@ CONNECT: if err == nil { // If write fails, try to re-establish connections as server/client if _, err = telemetryBuffer.Write(report); err != nil { + log.Printf("[CNS-Telemetry] Telemetry write failed: %v", err) telemetryBuffer.Cancel() goto CONNECT } diff --git a/telemetry/telemetry_test.go b/telemetry/telemetry_test.go index 0daf668f1e..a12703d107 100644 --- a/telemetry/telemetry_test.go +++ b/telemetry/telemetry_test.go @@ -175,6 +175,63 @@ func TestCloseTelemetryConnection(t *testing.T) { } } +func TestServerCloseTelemetryConnection(t *testing.T) { + // create server telemetrybuffer and start server + tb = NewTelemetryBuffer(hostAgentUrl) + err := tb.StartServer() + if err == nil { + go tb.BufferAndPushData(0) + } + + // create client telemetrybuffer and connect to server + tb1 := NewTelemetryBuffer(hostAgentUrl) + if err := tb1.Connect(); err != nil { + t.Errorf("connection to telemetry server failed %v", err) + } + + // Exit server thread and close server connection + tb.Cancel() + time.Sleep(300 * time.Millisecond) + + b := []byte("tamil") + if _, err := tb1.Write(b); err == nil { + t.Errorf("Client couldn't recognise server close") + } + + if len(tb.connections) != 0 { + t.Errorf("All connections not closed as expected") + } + + // Close client connection + tb1.Close() +} + +func TestClientCloseTelemetryConnection(t *testing.T) { + // create server telemetrybuffer and start server + tb = NewTelemetryBuffer(hostAgentUrl) + err := tb.StartServer() + if err == nil { + go tb.BufferAndPushData(0) + } + + // create client telemetrybuffer and connect to server + tb1 := NewTelemetryBuffer(hostAgentUrl) + if err := tb1.Connect(); err != nil { + t.Errorf("connection to telemetry server failed %v", err) + } + + // Close client connection + tb1.Close() + time.Sleep(300 * time.Millisecond) + + if len(tb.connections) != 0 { + t.Errorf("All connections not closed as expected") + } + + // Exit server thread and close server connection + tb.Cancel() +} + func TestSetReportState(t *testing.T) { err := reportManager.SetReportState("a.json") if err != nil { diff --git a/telemetry/telemetrybuffer.go b/telemetry/telemetrybuffer.go index cb5436a984..9ee8a6b948 100644 --- a/telemetry/telemetrybuffer.go +++ b/telemetry/telemetrybuffer.go @@ -12,6 +12,7 @@ import ( "net" "net/http" "strings" + "sync" "time" "github.com/Azure/azure-container-networking/common" @@ -52,6 +53,7 @@ type TelemetryBuffer struct { Connected bool data chan interface{} cancel chan bool + mutex sync.Mutex } // Payload object holds the different types of reports @@ -87,8 +89,13 @@ func NewTelemetryBuffer(hostReportURL string) *TelemetryBuffer { } func remove(s []net.Conn, i int) []net.Conn { - s[i] = s[len(s)-1] - return s[:len(s)-1] + if len(s) > 0 && i < len(s) { + s[i] = s[len(s)-1] + return s[:len(s)-1] + } + + telemetryLogger.Printf("tb connections remove failed index %v len %v", i, len(s)) + return s } // Starts Telemetry server listening on unix domain socket @@ -107,7 +114,9 @@ func (tb *TelemetryBuffer) StartServer() error { // Spawn worker goroutines to communicate with client conn, err := tb.listener.Accept() if err == nil { + tb.mutex.Lock() tb.connections = append(tb.connections, conn) + tb.mutex.Unlock() go func() { for { reportStr, err := read(conn) @@ -132,18 +141,32 @@ func (tb *TelemetryBuffer) StartServer() error { tb.data <- cnsReport } } else { - telemetryLogger.Printf("Server closing client connection") - for index, value := range tb.connections { + var index int + var value net.Conn + var found bool + + tb.mutex.Lock() + defer tb.mutex.Unlock() + + for index, value = range tb.connections { if value == conn { + telemetryLogger.Printf("Server closing client connection") conn.Close() - tb.connections = remove(tb.connections, index) - return + found = true + break } } + + if found { + tb.connections = remove(tb.connections, index) + } + + return } } }() } else { + telemetryLogger.Printf("Telemetry Server accept error %v", err) return } } @@ -239,9 +262,12 @@ func (tb *TelemetryBuffer) Close() { tb.listener = nil } + tb.mutex.Lock() + defer tb.mutex.Unlock() + for _, conn := range tb.connections { if conn != nil { - telemetryLogger.Printf("connection close") + telemetryLogger.Printf("connection close as server closed") conn.Close() } }