diff --git a/dot/network/block_announce.go b/dot/network/block_announce.go index 541e7dbf3a..ab31795702 100644 --- a/dot/network/block_announce.go +++ b/dot/network/block_announce.go @@ -228,13 +228,13 @@ func (s *Service) validateBlockAnnounceHandshake(peer peer.ID, hs Handshake) err // don't need to lock here, since function is always called inside the func returned by // `createNotificationsMessageHandler` which locks the map beforehand. - data, ok := np.handshakeData[peer] + data, ok := np.getHandshakeData(peer) if !ok { - np.handshakeData[peer] = &handshakeData{ + np.handshakeData.Store(peer, &handshakeData{ received: true, validated: true, - } - data = np.handshakeData[peer] + }) + data, _ = np.getHandshakeData(peer) } data.handshake = hs diff --git a/dot/network/block_announce_test.go b/dot/network/block_announce_test.go index 8e01dca514..ea381dad6d 100644 --- a/dot/network/block_announce_test.go +++ b/dot/network/block_announce_test.go @@ -18,6 +18,7 @@ package network import ( "math/big" + "sync" "testing" "github.com/ChainSafe/gossamer/dot/types" @@ -116,10 +117,10 @@ func TestValidateBlockAnnounceHandshake(t *testing.T) { nodeA := createTestService(t, configA) nodeA.noGossip = true nodeA.notificationsProtocols[BlockAnnounceMsgType] = ¬ificationsProtocol{ - handshakeData: make(map[peer.ID]*handshakeData), + handshakeData: new(sync.Map), } testPeerID := peer.ID("noot") - nodeA.notificationsProtocols[BlockAnnounceMsgType].handshakeData[testPeerID] = &handshakeData{} + nodeA.notificationsProtocols[BlockAnnounceMsgType].handshakeData.Store(testPeerID, &handshakeData{}) err := nodeA.validateBlockAnnounceHandshake(testPeerID, &BlockAnnounceHandshake{ BestBlockNumber: 100, diff --git a/dot/network/host_test.go b/dot/network/host_test.go index 0a9f010b57..56430d94cb 100644 --- a/dot/network/host_test.go +++ b/dot/network/host_test.go @@ -363,13 +363,13 @@ func TestStreamCloseMetadataCleanup(t *testing.T) { info := nodeA.notificationsProtocols[BlockAnnounceMsgType] // Set handshake data to received - info.handshakeData[nodeB.host.id()] = &handshakeData{ + info.handshakeData.Store(nodeB.host.id(), &handshakeData{ received: true, validated: true, - } + }) // Verify that handshake data exists. - _, ok := info.handshakeData[nodeB.host.id()] + _, ok := info.getHandshakeData(nodeB.host.id()) require.True(t, ok) time.Sleep(time.Second) @@ -379,7 +379,7 @@ func TestStreamCloseMetadataCleanup(t *testing.T) { time.Sleep(time.Second) // Verify that handshake data is cleared. - _, ok = info.handshakeData[nodeB.host.id()] + _, ok = info.getHandshakeData(nodeB.host.id()) require.False(t, ok) } diff --git a/dot/network/notifications.go b/dot/network/notifications.go index bfdff68df9..a733d06b78 100644 --- a/dot/network/notifications.go +++ b/dot/network/notifications.go @@ -54,10 +54,19 @@ type ( type notificationsProtocol struct { protocolID protocol.ID getHandshake HandshakeGetter - handshakeData map[peer.ID]*handshakeData + handshakeData *sync.Map //map[peer.ID]*handshakeData mapMu sync.RWMutex } +func (n *notificationsProtocol) getHandshakeData(pid peer.ID) (*handshakeData, bool) { + data, has := n.handshakeData.Load(pid) + if !has { + return nil, false + } + + return data.(*handshakeData), true +} + type handshakeData struct { received bool validated bool @@ -72,7 +81,7 @@ func createDecoder(info *notificationsProtocol, handshakeDecoder HandshakeDecode info.mapMu.RLock() defer info.mapMu.RUnlock() - if hsData, has := info.handshakeData[peer]; !has || !hsData.received { + if hsData, has := info.getHandshakeData(peer); !has || !hsData.received { return handshakeDecoder(in) } @@ -112,12 +121,12 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol, defer info.mapMu.Unlock() // if we are the receiver and haven't received the handshake already, validate it - if _, has := info.handshakeData[peer]; !has { + if _, has := info.getHandshakeData(peer); !has { logger.Trace("receiver: validating handshake", "protocol", info.protocolID) - info.handshakeData[peer] = &handshakeData{ + info.handshakeData.Store(peer, &handshakeData{ validated: false, received: true, - } + }) err := handshakeValidator(peer, hs) if err != nil { @@ -126,7 +135,8 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol, return errCannotValidateHandshake } - info.handshakeData[peer].validated = true + data, _ := info.getHandshakeData(peer) + data.validated = true // once validated, send back a handshake resp, err := info.getHandshake() @@ -145,25 +155,25 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol, } // if we are the initiator and haven't received the handshake already, validate it - if hsData, has := info.handshakeData[peer]; has && !hsData.validated { + if hsData, has := info.getHandshakeData(peer); has && !hsData.validated { logger.Trace("sender: validating handshake") err := handshakeValidator(peer, hs) if err != nil { logger.Trace("failed to validate handshake", "protocol", info.protocolID, "peer", peer, "error", err) - info.handshakeData[peer].validated = false + hsData.validated = false _ = stream.Conn().Close() return errCannotValidateHandshake } - info.handshakeData[peer].validated = true - info.handshakeData[peer].received = true + hsData.validated = true + hsData.received = true logger.Trace("sender: validated handshake", "protocol", info.protocolID, "peer", peer) } else if hsData.received { return nil } // if we are the initiator, send the message - if hsData, has := info.handshakeData[peer]; has && hsData.validated && hsData.received && hsData.outboundMsg != nil { + if hsData, has := info.getHandshakeData(peer); has && hsData.validated && hsData.received && hsData.outboundMsg != nil { logger.Trace("sender: sending message", "protocol", info.protocolID) err := s.host.send(peer, info.protocolID, hsData.outboundMsg) if err != nil { @@ -223,11 +233,11 @@ func (s *Service) broadcastExcluding(info *notificationsProtocol, excluding peer info.mapMu.RLock() defer info.mapMu.RUnlock() - if hsData, has := info.handshakeData[peer]; !has || !hsData.received { - info.handshakeData[peer] = &handshakeData{ + if hsData, has := info.getHandshakeData(peer); !has || !hsData.received { + info.handshakeData.Store(peer, &handshakeData{ validated: false, outboundMsg: msg, - } + }) logger.Trace("sending handshake", "protocol", info.protocolID, "peer", peer, "message", hs) err = s.host.send(peer, info.protocolID, hs) diff --git a/dot/network/notifications_test.go b/dot/network/notifications_test.go index 2eeeadcb6f..f0f65f793e 100644 --- a/dot/network/notifications_test.go +++ b/dot/network/notifications_test.go @@ -18,6 +18,7 @@ package network import ( "math/big" + "sync" "testing" "time" @@ -46,15 +47,15 @@ func TestCreateDecoder_BlockAnnounce(t *testing.T) { info := ¬ificationsProtocol{ protocolID: s.host.protocolID + blockAnnounceID, getHandshake: s.getBlockAnnounceHandshake, - handshakeData: make(map[peer.ID]*handshakeData), + handshakeData: new(sync.Map), } decoder := createDecoder(info, decodeBlockAnnounceHandshake, decodeBlockAnnounceMessage) // haven't received handshake from peer testPeerID := peer.ID("QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ") - info.handshakeData[testPeerID] = &handshakeData{ + info.handshakeData.Store(testPeerID, &handshakeData{ received: false, - } + }) testHandshake := &BlockAnnounceHandshake{ Roles: 4, @@ -82,7 +83,8 @@ func TestCreateDecoder_BlockAnnounce(t *testing.T) { require.NoError(t, err) // set handshake data to received - info.handshakeData[testPeerID].received = true + hsData, _ := info.getHandshakeData(testPeerID) + hsData.received = true msg, err = decoder(enc, testPeerID) require.NoError(t, err) require.Equal(t, testBlockAnnounce, msg) @@ -132,15 +134,15 @@ func TestCreateNotificationsMessageHandler_BlockAnnounce(t *testing.T) { info := ¬ificationsProtocol{ protocolID: s.host.protocolID + blockAnnounceID, getHandshake: s.getBlockAnnounceHandshake, - handshakeData: make(map[peer.ID]*handshakeData), + handshakeData: new(sync.Map), } handler := s.createNotificationsMessageHandler(info, s.validateBlockAnnounceHandshake, s.handleBlockAnnounceMessage) // set handshake data to received - info.handshakeData[testPeerID] = &handshakeData{ + info.handshakeData.Store(testPeerID, &handshakeData{ received: true, validated: true, - } + }) msg := &BlockAnnounceMessage{ Number: big.NewInt(10), } @@ -164,7 +166,7 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T) info := ¬ificationsProtocol{ protocolID: s.host.protocolID + blockAnnounceID, getHandshake: s.getBlockAnnounceHandshake, - handshakeData: make(map[peer.ID]*handshakeData), + handshakeData: new(sync.Map), } handler := s.createNotificationsMessageHandler(info, s.validateBlockAnnounceHandshake, s.handleBlockAnnounceMessage) @@ -205,8 +207,10 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T) err = handler(stream, testHandshake) require.Equal(t, errCannotValidateHandshake, err) - require.True(t, info.handshakeData[testPeerID].received) - require.False(t, info.handshakeData[testPeerID].validated) + data, has := info.getHandshakeData(testPeerID) + require.True(t, has) + require.True(t, data.received) + require.False(t, data.validated) // try valid handshake testHandshake = &BlockAnnounceHandshake{ @@ -218,6 +222,8 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T) err = handler(stream, testHandshake) require.NoError(t, err) - require.True(t, info.handshakeData[testPeerID].received) - require.True(t, info.handshakeData[testPeerID].validated) + data, has = info.getHandshakeData(testPeerID) + require.True(t, has) + require.True(t, data.received) + require.True(t, data.validated) } diff --git a/dot/network/service.go b/dot/network/service.go index 6ce82d8be0..a970812985 100644 --- a/dot/network/service.go +++ b/dot/network/service.go @@ -304,10 +304,10 @@ func (s *Service) handleConn(conn libp2pnetwork.Conn) { defer info.mapMu.RUnlock() peer := conn.RemotePeer() - if hsData, has := info.handshakeData[peer]; !has || !hsData.received { - info.handshakeData[peer] = &handshakeData{ + if hsData, has := info.getHandshakeData(peer); !has || !hsData.received { + info.handshakeData.Store(peer, &handshakeData{ validated: false, - } + }) logger.Trace("sending handshake", "protocol", info.protocolID, "peer", peer, "message", hs) err = s.host.send(peer, info.protocolID, hs) @@ -407,7 +407,7 @@ func (s *Service) RegisterNotificationsProtocol(sub protocol.ID, np := ¬ificationsProtocol{ protocolID: protocolID, getHandshake: handshakeGetter, - handshakeData: make(map[peer.ID]*handshakeData), + handshakeData: new(sync.Map), } s.notificationsProtocols[messageID] = np @@ -416,13 +416,13 @@ func (s *Service) RegisterNotificationsProtocol(sub protocol.ID, np.mapMu.Lock() defer np.mapMu.Unlock() - if _, ok := np.handshakeData[peerID]; ok { + if _, ok := np.getHandshakeData(peerID); ok { logger.Trace( "Cleaning up handshake data", "peer", peerID, "protocol", protocolID, ) - delete(np.handshakeData, peerID) + np.handshakeData.Delete(peerID) } }) @@ -625,24 +625,20 @@ func (s *Service) Peers() []common.PeerInfo { peers := []common.PeerInfo{} s.notificationsMu.RLock() - defer s.notificationsMu.RUnlock() + np := s.notificationsProtocols[BlockAnnounceMsgType] + s.notificationsMu.RUnlock() for _, p := range s.host.peers() { - if s.notificationsProtocols[BlockAnnounceMsgType].handshakeData[p] == nil { + data, has := np.getHandshakeData(p) + if !has || data.handshake == nil { peers = append(peers, common.PeerInfo{ PeerID: p.String(), }) continue } - peerHandshakeMessage := s.notificationsProtocols[BlockAnnounceMsgType].handshakeData[p].handshake - if peerHandshakeMessage == nil { - peers = append(peers, common.PeerInfo{ - PeerID: p.String(), - }) - continue - } + peerHandshakeMessage := data.handshake peers = append(peers, common.PeerInfo{ PeerID: p.String(), Roles: peerHandshakeMessage.(*BlockAnnounceHandshake).Roles, @@ -650,6 +646,7 @@ func (s *Service) Peers() []common.PeerInfo { BestNumber: uint64(peerHandshakeMessage.(*BlockAnnounceHandshake).BestBlockNumber), }) } + return peers } diff --git a/dot/network/service_test.go b/dot/network/service_test.go index dab4ccb26c..4f9743fc1f 100644 --- a/dot/network/service_test.go +++ b/dot/network/service_test.go @@ -371,13 +371,13 @@ func TestHandleConn(t *testing.T) { require.Equal(t, 1, aScore) infoA := nodeA.notificationsProtocols[BlockAnnounceMsgType] - hsDataB, has := infoA.handshakeData[nodeB.host.id()] + hsDataB, has := infoA.getHandshakeData(nodeB.host.id()) require.True(t, has) require.True(t, hsDataB.received) require.True(t, hsDataB.validated) infoB := nodeB.notificationsProtocols[BlockAnnounceMsgType] - hsDataA, has := infoB.handshakeData[nodeA.host.id()] + hsDataA, has := infoB.getHandshakeData(nodeA.host.id()) require.True(t, has) require.True(t, hsDataA.received) require.True(t, hsDataA.validated)