From 85e31ea17cdd502983cfe74bf71e29c6e2027fc0 Mon Sep 17 00:00:00 2001 From: Ian Suvak Date: Thu, 6 Apr 2023 15:53:39 -0400 Subject: [PATCH 1/8] Add maximum accepted header size on outgoing websocket connection --- go.mod | 4 +- go.sum | 4 +- network/wsNetwork.go | 9 +++ network/wsNetwork_test.go | 150 ++++++++++++++++++++++++++++++++++++-- 4 files changed, 158 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index a63e13e5c2..47e6bcd8f2 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/algorand/graphtrace v0.1.0 github.com/algorand/msgp v1.1.53 github.com/algorand/oapi-codegen v1.12.0-algorand.0 - github.com/algorand/websocket v1.4.5 + github.com/algorand/websocket v1.4.6 github.com/aws/aws-sdk-go v1.33.0 github.com/consensys/gnark-crypto v0.7.0 github.com/davidlazar/go-crypto v0.0.0-20170701192655-dcfb0a7ac018 @@ -36,6 +36,7 @@ require ( golang.org/x/sys v0.1.0 golang.org/x/text v0.4.0 gopkg.in/sohlich/elogrus.v3 v3.0.0-20180410122755-1fa29e2f2009 + gopkg.in/yaml.v3 v3.0.1 ) require ( @@ -71,5 +72,4 @@ require ( golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 // indirect gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect gopkg.in/yaml.v2 v2.4.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 0bf38663bd..21406525ea 100644 --- a/go.sum +++ b/go.sum @@ -18,8 +18,8 @@ github.com/algorand/msgp v1.1.53 h1:D6HKLyvLE6ltfsf8Apsrc+kqYb/CcOZEAfh1DpkPrNg= github.com/algorand/msgp v1.1.53/go.mod h1:5K3d58/poT5fPmtiwuQft6GjgSrVEM46KoXdLrID8ZU= github.com/algorand/oapi-codegen v1.12.0-algorand.0 h1:W9PvED+wAJc+9EeXPONnA+0zE9UhynEqoDs4OgAxKhk= github.com/algorand/oapi-codegen v1.12.0-algorand.0/go.mod h1:tIWJ9K/qrLDVDt5A1p82UmxZIEGxv2X+uoujdhEAL48= -github.com/algorand/websocket v1.4.5 h1:Cs6UTaCReAl02evYxmN8k57cNHmBILRcspfSxYg4AJE= -github.com/algorand/websocket v1.4.5/go.mod h1:79n6FSZY08yQagHzE/YWZqTPBYfY5wc3IS+UTZe1W5c= +github.com/algorand/websocket v1.4.6 h1:I0kV4EYwatuUrKtNiwzYYgojgwh6pksDmlqntKG2Woc= +github.com/algorand/websocket v1.4.6/go.mod h1:HJmdGzFtnlUQ4nTzZP6WrT29oGYf1t6Ybi64vROcT+M= github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ= github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk= github.com/aws/aws-sdk-go v1.33.0 h1:Bq5Y6VTLbfnJp1IV8EL/qUU5qO1DYHda/zis/sqevkY= diff --git a/network/wsNetwork.go b/network/wsNetwork.go index 7339bbde61..824b1f9660 100644 --- a/network/wsNetwork.go +++ b/network/wsNetwork.go @@ -105,6 +105,9 @@ const unprintableCharacterGlyph = "▯" // PublicAddress (which will match HTTP Listener's Address) in tests only. const testingPublicAddress = "testing" +// Maximum number of bytes to read from a header when trying to establish a websocket connection. +const wsMaxHeaderBytes = 4096 + var networkIncomingConnections = metrics.MakeGauge(metrics.NetworkIncomingConnections) var networkOutgoingConnections = metrics.MakeGauge(metrics.NetworkOutgoingConnections) @@ -396,6 +399,9 @@ type WebsocketNetwork struct { // outgoingMessagesBufferSize is the size used for outgoing messages. outgoingMessagesBufferSize int + // maxHeaderSize is the maximum accepted size of the header prior to upgrading to websocket connection. + wsMaxHeaderBytes int64 + // slowWritingPeerMonitorInterval defines the interval between two consecutive tests for slow peer writing slowWritingPeerMonitorInterval time.Duration @@ -758,6 +764,8 @@ func (wn *WebsocketNetwork) setup() { config.Consensus[protocol.ConsensusCurrentVersion].DownCommitteeSize), ) + wn.wsMaxHeaderBytes = wsMaxHeaderBytes + wn.identityTracker = NewIdentityTracker() wn.broadcastQueueHighPrio = make(chan broadcastRequest, wn.outgoingMessagesBufferSize) @@ -2193,6 +2201,7 @@ func (wn *WebsocketNetwork) tryConnect(addr, gossipAddr string) { EnableCompression: false, NetDialContext: wn.dialer.DialContext, NetDial: wn.dialer.Dial, + MaxHeaderSize: wn.wsMaxHeaderBytes, } conn, response, err := websocketDialer.DialContext(wn.ctx, gossipAddr, requestHeader) diff --git a/network/wsNetwork_test.go b/network/wsNetwork_test.go index 4f7d01c95b..324ee2b601 100644 --- a/network/wsNetwork_test.go +++ b/network/wsNetwork_test.go @@ -27,6 +27,7 @@ import ( "math/rand" "net" "net/http" + "net/http/httptest" "net/url" "os" "runtime" @@ -41,6 +42,7 @@ import ( "github.com/stretchr/testify/require" "github.com/algorand/go-deadlock" + "github.com/algorand/websocket" "github.com/algorand/go-algorand/config" "github.com/algorand/go-algorand/crypto" @@ -54,6 +56,8 @@ import ( const sendBufferLength = 1000 +const genesisID = "go-test-network-genesis" + func init() { // this allows test code to use out-of-protocol message tags and have them go through allowCustomTags = true @@ -127,7 +131,7 @@ func makeTestWebsocketNodeWithConfig(t testing.TB, conf config.Local, opts ...te log: log, config: conf, phonebook: MakePhonebook(1, 1*time.Millisecond), - GenesisID: "go-test-network-genesis", + GenesisID: genesisID, NetworkID: config.Devtestnet, } // apply options to newly-created WebsocketNetwork, if provided @@ -990,7 +994,7 @@ func makeTestFilterWebsocketNode(t *testing.T, nodename string) *WebsocketNetwor log: logging.TestingLog(t).With("node", nodename), config: dc, phonebook: MakePhonebook(1, 1*time.Millisecond), - GenesisID: "go-test-network-genesis", + GenesisID: genesisID, NetworkID: config.Devtestnet, } require.True(t, wn.config.EnableIncomingMessageFilter) @@ -2462,7 +2466,7 @@ func TestSlowPeerDisconnection(t *testing.T) { log: log, config: defaultConfig, phonebook: MakePhonebook(1, 1*time.Millisecond), - GenesisID: "go-test-network-genesis", + GenesisID: genesisID, NetworkID: config.Devtestnet, slowWritingPeerMonitorInterval: time.Millisecond * 50, } @@ -2537,7 +2541,7 @@ func TestForceMessageRelaying(t *testing.T) { log: log, config: defaultConfig, phonebook: MakePhonebook(1, 1*time.Millisecond), - GenesisID: "go-test-network-genesis", + GenesisID: genesisID, NetworkID: config.Devtestnet, } wn.setup() @@ -2631,7 +2635,7 @@ func TestCheckProtocolVersionMatch(t *testing.T) { log: log, config: defaultConfig, phonebook: MakePhonebook(1, 1*time.Millisecond), - GenesisID: "go-test-network-genesis", + GenesisID: genesisID, NetworkID: config.Devtestnet, } wn.setup() @@ -3757,3 +3761,139 @@ func TestWebsocketNetworkTelemetryTCP(t *testing.T) { t.Log("closed detailsA", string(pcdA)) t.Log("closed detailsB", string(pcdB)) } + +type mockServer struct { + *httptest.Server + URL string + t *testing.T + + waitForClientClose bool + sendClose bool + sendCloseWC bool + + gotClientClose chan struct{} +} + +type mockHandler struct { + *testing.T + s *mockServer +} + +var mockUpgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + EnableCompression: true, + Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) { + http.Error(w, reason.Error(), status) + }, +} + +func (t mockHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Set the required headers to successfully establish a connection + responseHeader := http.Header{} + responseHeader.Add(ProtocolVersionHeader, ProtocolVersion) + responseHeader.Add(GenesisHeader, genesisID) + responseHeader.Add(NodeRandomHeader, "randomHeader") + ws, err := mockUpgrader.Upgrade(w, r, responseHeader) + if err != nil { + t.Logf("Upgrade: %v", err) + return + } + defer ws.Close() + + for true { + // echo a message back to the client + op, rd, err := ws.NextReader() + if err != nil { + if _, ok := err.(*websocket.CloseError); ok && t.s.waitForClientClose { + t.Log("got client close") + close(t.s.gotClientClose) + return + } + t.Logf("NextReader: %v", err) + return + } + wr, err := ws.NextWriter(op) + if err != nil { + t.Logf("NextWriter: %v", err) + return + } + if _, err = io.Copy(wr, rd); err != nil { + t.Logf("NextWriter: %v", err) + return + } + if err := wr.Close(); err != nil { + t.Logf("Close: %v", err) + return + } + t.Log("sent message") + if !t.s.waitForClientClose { + break + } + } + if t.s.sendClose { + err = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + if err != nil { + t.Logf("WriteMessage(CloseMessage): %v", err) + return + } + t.Log("sent close") + } else if t.s.sendCloseWC { + err = ws.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(5*time.Second)) + if err != nil { + t.Logf("WriteControl(CloseMessage): %v", err) + return + } + t.Log("sent close") + } +} + +func makeWsProto(s string) string { + return "ws" + strings.TrimPrefix(s, "http") +} + +func newServer(t *testing.T) *mockServer { + var s mockServer + s.Server = httptest.NewServer(mockHandler{t, &s}) + s.Server.URL += "" + s.URL = makeWsProto(s.Server.URL) + return &s +} + +func TestMaxHeaderSize(t *testing.T) { + partitiontest.PartitionTest(t) + + netA := makeTestWebsocketNode(t, testWebsocketLogNameOption{"netA"}) + netA.config.GossipFanout = 1 + + s := newServer(t) + s.waitForClientClose = true + defer s.Close() + + netA.Start() + defer netA.Stop() + + // First make sure that the regular connection with default max header size works + netA.wsMaxHeaderBytes = wsMaxHeaderBytes + netA.wg.Add(1) + netA.tryConnect(s.URL, s.URL) + time.Sleep(250 * time.Millisecond) + assert.Equal(t, 1, len(netA.peers)) + + netA.removePeer(netA.peers[0], disconnectReasonNone) + assert.Zero(t, len(netA.peers)) + + // Now try to connect with a max header size that is too small + netA.wsMaxHeaderBytes = 64 + netA.wg.Add(1) + netA.tryConnect(s.URL, s.URL) + time.Sleep(250 * time.Millisecond) + assert.Zero(t, len(netA.peers)) + + // Test that setting 0 disables the max header size check + netA.wsMaxHeaderBytes = 0 + netA.wg.Add(1) + netA.tryConnect(s.URL, s.URL) + time.Sleep(250 * time.Millisecond) + assert.Equal(t, 1, len(netA.peers)) +} From 660d6825184e4f9ebe1550a677553300831f8f8d Mon Sep 17 00:00:00 2001 From: Ian Suvak Date: Thu, 6 Apr 2023 16:08:24 -0400 Subject: [PATCH 2/8] remove need for manually mocking websocket server --- network/wsNetwork_test.go | 120 +++++--------------------------------- 1 file changed, 13 insertions(+), 107 deletions(-) diff --git a/network/wsNetwork_test.go b/network/wsNetwork_test.go index 324ee2b601..9274790293 100644 --- a/network/wsNetwork_test.go +++ b/network/wsNetwork_test.go @@ -27,7 +27,6 @@ import ( "math/rand" "net" "net/http" - "net/http/httptest" "net/url" "os" "runtime" @@ -42,7 +41,6 @@ import ( "github.com/stretchr/testify/require" "github.com/algorand/go-deadlock" - "github.com/algorand/websocket" "github.com/algorand/go-algorand/config" "github.com/algorand/go-algorand/crypto" @@ -3762,121 +3760,29 @@ func TestWebsocketNetworkTelemetryTCP(t *testing.T) { t.Log("closed detailsB", string(pcdB)) } -type mockServer struct { - *httptest.Server - URL string - t *testing.T - - waitForClientClose bool - sendClose bool - sendCloseWC bool - - gotClientClose chan struct{} -} - -type mockHandler struct { - *testing.T - s *mockServer -} - -var mockUpgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - EnableCompression: true, - Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) { - http.Error(w, reason.Error(), status) - }, -} - -func (t mockHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Set the required headers to successfully establish a connection - responseHeader := http.Header{} - responseHeader.Add(ProtocolVersionHeader, ProtocolVersion) - responseHeader.Add(GenesisHeader, genesisID) - responseHeader.Add(NodeRandomHeader, "randomHeader") - ws, err := mockUpgrader.Upgrade(w, r, responseHeader) - if err != nil { - t.Logf("Upgrade: %v", err) - return - } - defer ws.Close() - - for true { - // echo a message back to the client - op, rd, err := ws.NextReader() - if err != nil { - if _, ok := err.(*websocket.CloseError); ok && t.s.waitForClientClose { - t.Log("got client close") - close(t.s.gotClientClose) - return - } - t.Logf("NextReader: %v", err) - return - } - wr, err := ws.NextWriter(op) - if err != nil { - t.Logf("NextWriter: %v", err) - return - } - if _, err = io.Copy(wr, rd); err != nil { - t.Logf("NextWriter: %v", err) - return - } - if err := wr.Close(); err != nil { - t.Logf("Close: %v", err) - return - } - t.Log("sent message") - if !t.s.waitForClientClose { - break - } - } - if t.s.sendClose { - err = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) - if err != nil { - t.Logf("WriteMessage(CloseMessage): %v", err) - return - } - t.Log("sent close") - } else if t.s.sendCloseWC { - err = ws.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(5*time.Second)) - if err != nil { - t.Logf("WriteControl(CloseMessage): %v", err) - return - } - t.Log("sent close") - } -} - -func makeWsProto(s string) string { - return "ws" + strings.TrimPrefix(s, "http") -} - -func newServer(t *testing.T) *mockServer { - var s mockServer - s.Server = httptest.NewServer(mockHandler{t, &s}) - s.Server.URL += "" - s.URL = makeWsProto(s.Server.URL) - return &s -} - func TestMaxHeaderSize(t *testing.T) { partitiontest.PartitionTest(t) netA := makeTestWebsocketNode(t, testWebsocketLogNameOption{"netA"}) netA.config.GossipFanout = 1 - s := newServer(t) - s.waitForClientClose = true - defer s.Close() + netB := makeTestWebsocketNode(t, testWebsocketLogNameOption{"netB"}) + netB.config.GossipFanout = 1 netA.Start() defer netA.Stop() + netB.Start() + defer netB.Stop() + + addrB, ok := netB.Address() + require.True(t, ok) + gossipB, err := netB.addrToGossipAddr(addrB) + require.NoError(t, err) // First make sure that the regular connection with default max header size works netA.wsMaxHeaderBytes = wsMaxHeaderBytes netA.wg.Add(1) - netA.tryConnect(s.URL, s.URL) + netA.tryConnect(addrB, gossipB) time.Sleep(250 * time.Millisecond) assert.Equal(t, 1, len(netA.peers)) @@ -3884,16 +3790,16 @@ func TestMaxHeaderSize(t *testing.T) { assert.Zero(t, len(netA.peers)) // Now try to connect with a max header size that is too small - netA.wsMaxHeaderBytes = 64 + netA.wsMaxHeaderBytes = 128 netA.wg.Add(1) - netA.tryConnect(s.URL, s.URL) + netA.tryConnect(addrB, gossipB) time.Sleep(250 * time.Millisecond) assert.Zero(t, len(netA.peers)) // Test that setting 0 disables the max header size check netA.wsMaxHeaderBytes = 0 netA.wg.Add(1) - netA.tryConnect(s.URL, s.URL) + netA.tryConnect(addrB, gossipB) time.Sleep(250 * time.Millisecond) assert.Equal(t, 1, len(netA.peers)) } From 67da0801ec086623646d6753750ef1f95d7ec19c Mon Sep 17 00:00:00 2001 From: Ian Suvak Date: Fri, 7 Apr 2023 09:57:37 -0400 Subject: [PATCH 3/8] Revert "remove need for manually mocking websocket server" This reverts commit 660d6825184e4f9ebe1550a677553300831f8f8d. --- network/wsNetwork_test.go | 120 +++++++++++++++++++++++++++++++++----- 1 file changed, 107 insertions(+), 13 deletions(-) diff --git a/network/wsNetwork_test.go b/network/wsNetwork_test.go index 9274790293..324ee2b601 100644 --- a/network/wsNetwork_test.go +++ b/network/wsNetwork_test.go @@ -27,6 +27,7 @@ import ( "math/rand" "net" "net/http" + "net/http/httptest" "net/url" "os" "runtime" @@ -41,6 +42,7 @@ import ( "github.com/stretchr/testify/require" "github.com/algorand/go-deadlock" + "github.com/algorand/websocket" "github.com/algorand/go-algorand/config" "github.com/algorand/go-algorand/crypto" @@ -3760,29 +3762,121 @@ func TestWebsocketNetworkTelemetryTCP(t *testing.T) { t.Log("closed detailsB", string(pcdB)) } +type mockServer struct { + *httptest.Server + URL string + t *testing.T + + waitForClientClose bool + sendClose bool + sendCloseWC bool + + gotClientClose chan struct{} +} + +type mockHandler struct { + *testing.T + s *mockServer +} + +var mockUpgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + EnableCompression: true, + Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) { + http.Error(w, reason.Error(), status) + }, +} + +func (t mockHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Set the required headers to successfully establish a connection + responseHeader := http.Header{} + responseHeader.Add(ProtocolVersionHeader, ProtocolVersion) + responseHeader.Add(GenesisHeader, genesisID) + responseHeader.Add(NodeRandomHeader, "randomHeader") + ws, err := mockUpgrader.Upgrade(w, r, responseHeader) + if err != nil { + t.Logf("Upgrade: %v", err) + return + } + defer ws.Close() + + for true { + // echo a message back to the client + op, rd, err := ws.NextReader() + if err != nil { + if _, ok := err.(*websocket.CloseError); ok && t.s.waitForClientClose { + t.Log("got client close") + close(t.s.gotClientClose) + return + } + t.Logf("NextReader: %v", err) + return + } + wr, err := ws.NextWriter(op) + if err != nil { + t.Logf("NextWriter: %v", err) + return + } + if _, err = io.Copy(wr, rd); err != nil { + t.Logf("NextWriter: %v", err) + return + } + if err := wr.Close(); err != nil { + t.Logf("Close: %v", err) + return + } + t.Log("sent message") + if !t.s.waitForClientClose { + break + } + } + if t.s.sendClose { + err = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + if err != nil { + t.Logf("WriteMessage(CloseMessage): %v", err) + return + } + t.Log("sent close") + } else if t.s.sendCloseWC { + err = ws.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(5*time.Second)) + if err != nil { + t.Logf("WriteControl(CloseMessage): %v", err) + return + } + t.Log("sent close") + } +} + +func makeWsProto(s string) string { + return "ws" + strings.TrimPrefix(s, "http") +} + +func newServer(t *testing.T) *mockServer { + var s mockServer + s.Server = httptest.NewServer(mockHandler{t, &s}) + s.Server.URL += "" + s.URL = makeWsProto(s.Server.URL) + return &s +} + func TestMaxHeaderSize(t *testing.T) { partitiontest.PartitionTest(t) netA := makeTestWebsocketNode(t, testWebsocketLogNameOption{"netA"}) netA.config.GossipFanout = 1 - netB := makeTestWebsocketNode(t, testWebsocketLogNameOption{"netB"}) - netB.config.GossipFanout = 1 + s := newServer(t) + s.waitForClientClose = true + defer s.Close() netA.Start() defer netA.Stop() - netB.Start() - defer netB.Stop() - - addrB, ok := netB.Address() - require.True(t, ok) - gossipB, err := netB.addrToGossipAddr(addrB) - require.NoError(t, err) // First make sure that the regular connection with default max header size works netA.wsMaxHeaderBytes = wsMaxHeaderBytes netA.wg.Add(1) - netA.tryConnect(addrB, gossipB) + netA.tryConnect(s.URL, s.URL) time.Sleep(250 * time.Millisecond) assert.Equal(t, 1, len(netA.peers)) @@ -3790,16 +3884,16 @@ func TestMaxHeaderSize(t *testing.T) { assert.Zero(t, len(netA.peers)) // Now try to connect with a max header size that is too small - netA.wsMaxHeaderBytes = 128 + netA.wsMaxHeaderBytes = 64 netA.wg.Add(1) - netA.tryConnect(addrB, gossipB) + netA.tryConnect(s.URL, s.URL) time.Sleep(250 * time.Millisecond) assert.Zero(t, len(netA.peers)) // Test that setting 0 disables the max header size check netA.wsMaxHeaderBytes = 0 netA.wg.Add(1) - netA.tryConnect(addrB, gossipB) + netA.tryConnect(s.URL, s.URL) time.Sleep(250 * time.Millisecond) assert.Equal(t, 1, len(netA.peers)) } From 83e3b87c4ce06b9e274817432f6a97945c9f15a6 Mon Sep 17 00:00:00 2001 From: Ian Suvak Date: Mon, 10 Apr 2023 15:20:06 -0400 Subject: [PATCH 4/8] Add additional requested tests --- network/wsNetwork.go | 1 + network/wsNetwork_test.go | 134 +++++++++++++++++++++++--------------- 2 files changed, 84 insertions(+), 51 deletions(-) diff --git a/network/wsNetwork.go b/network/wsNetwork.go index 824b1f9660..1aa6265798 100644 --- a/network/wsNetwork.go +++ b/network/wsNetwork.go @@ -2205,6 +2205,7 @@ func (wn *WebsocketNetwork) tryConnect(addr, gossipAddr string) { } conn, response, err := websocketDialer.DialContext(wn.ctx, gossipAddr, requestHeader) + if err != nil { if err == websocket.ErrBadHandshake { // reading here from ioutil is safe only because it came from DialContext above, which already finished reading all the data from the network diff --git a/network/wsNetwork_test.go b/network/wsNetwork_test.go index 324ee2b601..0c7599cb49 100644 --- a/network/wsNetwork_test.go +++ b/network/wsNetwork_test.go @@ -3768,10 +3768,6 @@ type mockServer struct { t *testing.T waitForClientClose bool - sendClose bool - sendCloseWC bool - - gotClientClose chan struct{} } type mockHandler struct { @@ -3788,63 +3784,48 @@ var mockUpgrader = websocket.Upgrader{ }, } +func buildWsResponseHeader() http.Header { + h := http.Header{} + h.Add(ProtocolVersionHeader, ProtocolVersion) + h.Add(GenesisHeader, genesisID) + h.Add(NodeRandomHeader, "randomHeader") + return h +} + func (t mockHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Set the required headers to successfully establish a connection - responseHeader := http.Header{} - responseHeader.Add(ProtocolVersionHeader, ProtocolVersion) - responseHeader.Add(GenesisHeader, genesisID) - responseHeader.Add(NodeRandomHeader, "randomHeader") - ws, err := mockUpgrader.Upgrade(w, r, responseHeader) + ws, err := mockUpgrader.Upgrade(w, r, buildWsResponseHeader()) if err != nil { t.Logf("Upgrade: %v", err) return } defer ws.Close() + // Send a message of interest immediately after the connection is established + wr, err := ws.NextWriter(websocket.BinaryMessage) + if err != nil { + t.Logf("NextWriter: %v", err) + return + } + + bytes := MarshallMessageOfInterest([]protocol.Tag{protocol.AgreementVoteTag}) + msgBytes := append([]byte(protocol.MsgOfInterestTag), bytes...) + _, err = wr.Write(msgBytes) + if err != nil { + t.Logf("Error writing MessageOfInterest: %v", err) + return + } + wr.Close() for true { // echo a message back to the client - op, rd, err := ws.NextReader() + _, _, err := ws.NextReader() if err != nil { if _, ok := err.(*websocket.CloseError); ok && t.s.waitForClientClose { t.Log("got client close") - close(t.s.gotClientClose) return } - t.Logf("NextReader: %v", err) - return - } - wr, err := ws.NextWriter(op) - if err != nil { - t.Logf("NextWriter: %v", err) - return - } - if _, err = io.Copy(wr, rd); err != nil { - t.Logf("NextWriter: %v", err) - return - } - if err := wr.Close(); err != nil { - t.Logf("Close: %v", err) return } - t.Log("sent message") - if !t.s.waitForClientClose { - break - } - } - if t.s.sendClose { - err = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) - if err != nil { - t.Logf("WriteMessage(CloseMessage): %v", err) - return - } - t.Log("sent close") - } else if t.s.sendCloseWC { - err = ws.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(5*time.Second)) - if err != nil { - t.Logf("WriteControl(CloseMessage): %v", err) - return - } - t.Log("sent close") } } @@ -3866,17 +3847,23 @@ func TestMaxHeaderSize(t *testing.T) { netA := makeTestWebsocketNode(t, testWebsocketLogNameOption{"netA"}) netA.config.GossipFanout = 1 - s := newServer(t) - s.waitForClientClose = true - defer s.Close() + netB := makeTestWebsocketNode(t, testWebsocketLogNameOption{"netB"}) + netB.config.GossipFanout = 1 netA.Start() defer netA.Stop() + netB.Start() + defer netB.Stop() + + addrB, ok := netB.Address() + require.True(t, ok) + gossipB, err := netB.addrToGossipAddr(addrB) + require.NoError(t, err) // First make sure that the regular connection with default max header size works netA.wsMaxHeaderBytes = wsMaxHeaderBytes netA.wg.Add(1) - netA.tryConnect(s.URL, s.URL) + netA.tryConnect(addrB, gossipB) time.Sleep(250 * time.Millisecond) assert.Equal(t, 1, len(netA.peers)) @@ -3884,16 +3871,61 @@ func TestMaxHeaderSize(t *testing.T) { assert.Zero(t, len(netA.peers)) // Now try to connect with a max header size that is too small - netA.wsMaxHeaderBytes = 64 + netA.wsMaxHeaderBytes = 128 netA.wg.Add(1) - netA.tryConnect(s.URL, s.URL) + netA.tryConnect(addrB, gossipB) time.Sleep(250 * time.Millisecond) assert.Zero(t, len(netA.peers)) // Test that setting 0 disables the max header size check netA.wsMaxHeaderBytes = 0 netA.wg.Add(1) - netA.tryConnect(s.URL, s.URL) + netA.tryConnect(addrB, gossipB) time.Sleep(250 * time.Millisecond) assert.Equal(t, 1, len(netA.peers)) } + +func TestTryConnectEarlyWrite(t *testing.T) { + partitiontest.PartitionTest(t) + + netA := makeTestWebsocketNode(t, testWebsocketLogNameOption{"netA"}) + netA.config.GossipFanout = 1 + + s := newServer(t) + s.waitForClientClose = true + defer s.Close() + + netA.Start() + defer netA.Stop() + + dialer := websocket.Dialer{} + mconn, resp, _ := dialer.Dial(s.URL, nil) + expectedHeader := buildWsResponseHeader() + for k, v := range expectedHeader { + assert.Equal(t, v[0], resp.Header.Get(k)) + } + + headerSize := 36 // Fixed overhead of the full status line "HTTP/1.1 101 Switching Protocols" + 4 + for k, v := range resp.Header { + headerSize += len(k) + len(v[0]) + 4 + } + mconn.Close() + + // Setting the max header size to 1 byte less than the minimum header size should fail + netA.wsMaxHeaderBytes = int64(headerSize) - 1 + netA.wg.Add(1) + netA.tryConnect(s.URL, s.URL) + time.Sleep(250 * time.Millisecond) + assert.Len(t, netA.peers, 0) + + // Now set the max header size to the minimum header size and it should succeed + netA.wsMaxHeaderBytes = int64(headerSize) + netA.wg.Add(1) + netA.tryConnect(s.URL, s.URL) + time.Sleep(250 * time.Millisecond) + + // Confirm that we successfuly received a message of interest + assert.Len(t, netA.peers, 1) + fmt.Printf("MI Message Count: %v\n", netA.peers[0].miMessageCount) + assert.Equal(t, uint64(1), netA.peers[0].miMessageCount) +} From b58155260d7c16144f5cf9435aaf3b5ecd4e36e2 Mon Sep 17 00:00:00 2001 From: Ian Suvak Date: Thu, 13 Apr 2023 15:07:55 -0400 Subject: [PATCH 5/8] Fix race in the test --- network/wsNetwork_test.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/network/wsNetwork_test.go b/network/wsNetwork_test.go index 0c7599cb49..c9b212ea2b 100644 --- a/network/wsNetwork_test.go +++ b/network/wsNetwork_test.go @@ -3922,7 +3922,15 @@ func TestTryConnectEarlyWrite(t *testing.T) { netA.wsMaxHeaderBytes = int64(headerSize) netA.wg.Add(1) netA.tryConnect(s.URL, s.URL) - time.Sleep(250 * time.Millisecond) + p := netA.peers[0] + var messageCount uint64 + for x := 0; x < 1000; x++ { + messageCount = atomic.LoadUint64(&p.miMessageCount) + if messageCount == 1 { + break + } + time.Sleep(2 * time.Millisecond) + } // Confirm that we successfuly received a message of interest assert.Len(t, netA.peers, 1) From 23f4c77f0b86428da04b2efb81e01766f6db5227 Mon Sep 17 00:00:00 2001 From: Ian Suvak Date: Thu, 13 Apr 2023 15:08:39 -0400 Subject: [PATCH 6/8] Update network/wsNetwork.go Co-authored-by: Shant Karakashian <55754073+algonautshant@users.noreply.github.com> --- network/wsNetwork.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/network/wsNetwork.go b/network/wsNetwork.go index 1aa6265798..9196553524 100644 --- a/network/wsNetwork.go +++ b/network/wsNetwork.go @@ -399,7 +399,7 @@ type WebsocketNetwork struct { // outgoingMessagesBufferSize is the size used for outgoing messages. outgoingMessagesBufferSize int - // maxHeaderSize is the maximum accepted size of the header prior to upgrading to websocket connection. + // wsMaxHeaderBytes is the maximum accepted size of the header prior to upgrading to websocket connection. wsMaxHeaderBytes int64 // slowWritingPeerMonitorInterval defines the interval between two consecutive tests for slow peer writing From b855fd6f8f88f4909aacd0523a83f3c01778f921 Mon Sep 17 00:00:00 2001 From: Ian Suvak Date: Tue, 18 Apr 2023 09:57:13 -0400 Subject: [PATCH 7/8] address review feedback --- network/wsNetwork_test.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/network/wsNetwork_test.go b/network/wsNetwork_test.go index c9b212ea2b..5edde45cd9 100644 --- a/network/wsNetwork_test.go +++ b/network/wsNetwork_test.go @@ -3871,10 +3871,16 @@ func TestMaxHeaderSize(t *testing.T) { assert.Zero(t, len(netA.peers)) // Now try to connect with a max header size that is too small + logBuffer := bytes.NewBuffer(nil) + netA.log.SetOutput(logBuffer) + netA.wsMaxHeaderBytes = 128 netA.wg.Add(1) netA.tryConnect(addrB, gossipB) + lg := logBuffer.String() + logBuffer.Reset() time.Sleep(250 * time.Millisecond) + assert.Contains(t, lg, fmt.Sprintf("ws connect(%s) fail:", gossipB)) assert.Zero(t, len(netA.peers)) // Test that setting 0 disables the max header size check @@ -3905,21 +3911,21 @@ func TestTryConnectEarlyWrite(t *testing.T) { assert.Equal(t, v[0], resp.Header.Get(k)) } - headerSize := 36 // Fixed overhead of the full status line "HTTP/1.1 101 Switching Protocols" + 4 + minValidHeaderSize := 36 // Fixed overhead of the full status line "HTTP/1.1 101 Switching Protocols" + 4 for k, v := range resp.Header { - headerSize += len(k) + len(v[0]) + 4 + minValidHeaderSize += len(k) + len(v[0]) + 4 } mconn.Close() // Setting the max header size to 1 byte less than the minimum header size should fail - netA.wsMaxHeaderBytes = int64(headerSize) - 1 + netA.wsMaxHeaderBytes = int64(minValidHeaderSize) - 1 netA.wg.Add(1) netA.tryConnect(s.URL, s.URL) time.Sleep(250 * time.Millisecond) assert.Len(t, netA.peers, 0) // Now set the max header size to the minimum header size and it should succeed - netA.wsMaxHeaderBytes = int64(headerSize) + netA.wsMaxHeaderBytes = int64(minValidHeaderSize) netA.wg.Add(1) netA.tryConnect(s.URL, s.URL) p := netA.peers[0] From 157c7debe0aa518e9dfb46315466202b490c4fa7 Mon Sep 17 00:00:00 2001 From: Ian Suvak Date: Tue, 18 Apr 2023 12:58:07 -0400 Subject: [PATCH 8/8] amend comments to explain headersize calculation --- network/wsNetwork_test.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/network/wsNetwork_test.go b/network/wsNetwork_test.go index 5edde45cd9..3229e2ae20 100644 --- a/network/wsNetwork_test.go +++ b/network/wsNetwork_test.go @@ -3911,9 +3911,11 @@ func TestTryConnectEarlyWrite(t *testing.T) { assert.Equal(t, v[0], resp.Header.Get(k)) } - minValidHeaderSize := 36 // Fixed overhead of the full status line "HTTP/1.1 101 Switching Protocols" + 4 + // Fixed overhead of the full status line "HTTP/1.1 101 Switching Protocols" (32) + 4 bytes for two instance of CRLF + // one after the status line and one to separate headers from the body + minValidHeaderSize := 36 for k, v := range resp.Header { - minValidHeaderSize += len(k) + len(v[0]) + 4 + minValidHeaderSize += len(k) + len(v[0]) + 4 // + 4 is for the ": " and CRLF } mconn.Close()