Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(dot/network): update notificationsProtocol handshakeData to sync.Map #1492

Merged
merged 4 commits into from Mar 25, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 4 additions & 4 deletions dot/network/block_announce.go
Expand Up @@ -228,13 +228,13 @@ func (s *Service) validateBlockAnnounceHandshake(peer peer.ID, hs Handshake) err

// don't need to lock here, since function is always called inside the func returned by
// `createNotificationsMessageHandler` which locks the map beforehand.
data, ok := np.handshakeData[peer]
data, ok := np.getHandshakeData(peer)
if !ok {
np.handshakeData[peer] = &handshakeData{
np.handshakeData.Store(peer, &handshakeData{
received: true,
validated: true,
}
data = np.handshakeData[peer]
})
data, _ = np.getHandshakeData(peer)
}

data.handshake = hs
Expand Down
5 changes: 3 additions & 2 deletions dot/network/block_announce_test.go
Expand Up @@ -18,6 +18,7 @@ package network

import (
"math/big"
"sync"
"testing"

"github.com/ChainSafe/gossamer/dot/types"
Expand Down Expand Up @@ -116,10 +117,10 @@ func TestValidateBlockAnnounceHandshake(t *testing.T) {
nodeA := createTestService(t, configA)
nodeA.noGossip = true
nodeA.notificationsProtocols[BlockAnnounceMsgType] = &notificationsProtocol{
handshakeData: make(map[peer.ID]*handshakeData),
handshakeData: new(sync.Map),
}
testPeerID := peer.ID("noot")
nodeA.notificationsProtocols[BlockAnnounceMsgType].handshakeData[testPeerID] = &handshakeData{}
nodeA.notificationsProtocols[BlockAnnounceMsgType].handshakeData.Store(testPeerID, &handshakeData{})

err := nodeA.validateBlockAnnounceHandshake(testPeerID, &BlockAnnounceHandshake{
BestBlockNumber: 100,
Expand Down
8 changes: 4 additions & 4 deletions dot/network/host_test.go
Expand Up @@ -363,13 +363,13 @@ func TestStreamCloseMetadataCleanup(t *testing.T) {
info := nodeA.notificationsProtocols[BlockAnnounceMsgType]

// Set handshake data to received
info.handshakeData[nodeB.host.id()] = &handshakeData{
info.handshakeData.Store(nodeB.host.id(), &handshakeData{
received: true,
validated: true,
}
})

// Verify that handshake data exists.
_, ok := info.handshakeData[nodeB.host.id()]
_, ok := info.getHandshakeData(nodeB.host.id())
require.True(t, ok)

time.Sleep(time.Second)
Expand All @@ -379,7 +379,7 @@ func TestStreamCloseMetadataCleanup(t *testing.T) {
time.Sleep(time.Second)

// Verify that handshake data is cleared.
_, ok = info.handshakeData[nodeB.host.id()]
_, ok = info.getHandshakeData(nodeB.host.id())
require.False(t, ok)
}

Expand Down
38 changes: 24 additions & 14 deletions dot/network/notifications.go
Expand Up @@ -54,10 +54,19 @@ type (
type notificationsProtocol struct {
protocolID protocol.ID
getHandshake HandshakeGetter
handshakeData map[peer.ID]*handshakeData
handshakeData *sync.Map //map[peer.ID]*handshakeData
mapMu sync.RWMutex
}

func (n *notificationsProtocol) getHandshakeData(pid peer.ID) (*handshakeData, bool) {
data, has := n.handshakeData.Load(pid)
if !has {
return nil, false
}

return data.(*handshakeData), true
}

type handshakeData struct {
received bool
validated bool
Expand All @@ -72,7 +81,7 @@ func createDecoder(info *notificationsProtocol, handshakeDecoder HandshakeDecode
info.mapMu.RLock()
defer info.mapMu.RUnlock()

if hsData, has := info.handshakeData[peer]; !has || !hsData.received {
if hsData, has := info.getHandshakeData(peer); !has || !hsData.received {
return handshakeDecoder(in)
}

Expand Down Expand Up @@ -112,12 +121,12 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
defer info.mapMu.Unlock()

// if we are the receiver and haven't received the handshake already, validate it
if _, has := info.handshakeData[peer]; !has {
if _, has := info.getHandshakeData(peer); !has {
logger.Trace("receiver: validating handshake", "protocol", info.protocolID)
info.handshakeData[peer] = &handshakeData{
info.handshakeData.Store(peer, &handshakeData{
validated: false,
received: true,
}
})

err := handshakeValidator(peer, hs)
if err != nil {
Expand All @@ -126,7 +135,8 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
return errCannotValidateHandshake
}

info.handshakeData[peer].validated = true
data, _ := info.getHandshakeData(peer)
data.validated = true

// once validated, send back a handshake
resp, err := info.getHandshake()
Expand All @@ -145,25 +155,25 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
}

// if we are the initiator and haven't received the handshake already, validate it
if hsData, has := info.handshakeData[peer]; has && !hsData.validated {
if hsData, has := info.getHandshakeData(peer); has && !hsData.validated {
logger.Trace("sender: validating handshake")
err := handshakeValidator(peer, hs)
if err != nil {
logger.Trace("failed to validate handshake", "protocol", info.protocolID, "peer", peer, "error", err)
info.handshakeData[peer].validated = false
hsData.validated = false
_ = stream.Conn().Close()
return errCannotValidateHandshake
}

info.handshakeData[peer].validated = true
info.handshakeData[peer].received = true
hsData.validated = true
hsData.received = true
logger.Trace("sender: validated handshake", "protocol", info.protocolID, "peer", peer)
} else if hsData.received {
return nil
}

// if we are the initiator, send the message
if hsData, has := info.handshakeData[peer]; has && hsData.validated && hsData.received && hsData.outboundMsg != nil {
if hsData, has := info.getHandshakeData(peer); has && hsData.validated && hsData.received && hsData.outboundMsg != nil {
logger.Trace("sender: sending message", "protocol", info.protocolID)
err := s.host.send(peer, info.protocolID, hsData.outboundMsg)
if err != nil {
Expand Down Expand Up @@ -223,11 +233,11 @@ func (s *Service) broadcastExcluding(info *notificationsProtocol, excluding peer
info.mapMu.RLock()
defer info.mapMu.RUnlock()

if hsData, has := info.handshakeData[peer]; !has || !hsData.received {
info.handshakeData[peer] = &handshakeData{
if hsData, has := info.getHandshakeData(peer); !has || !hsData.received {
info.handshakeData.Store(peer, &handshakeData{
validated: false,
outboundMsg: msg,
}
})

logger.Trace("sending handshake", "protocol", info.protocolID, "peer", peer, "message", hs)
err = s.host.send(peer, info.protocolID, hs)
Expand Down
30 changes: 18 additions & 12 deletions dot/network/notifications_test.go
Expand Up @@ -18,6 +18,7 @@ package network

import (
"math/big"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -46,15 +47,15 @@ func TestCreateDecoder_BlockAnnounce(t *testing.T) {
info := &notificationsProtocol{
protocolID: s.host.protocolID + blockAnnounceID,
getHandshake: s.getBlockAnnounceHandshake,
handshakeData: make(map[peer.ID]*handshakeData),
handshakeData: new(sync.Map),
}
decoder := createDecoder(info, decodeBlockAnnounceHandshake, decodeBlockAnnounceMessage)

// haven't received handshake from peer
testPeerID := peer.ID("QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ")
info.handshakeData[testPeerID] = &handshakeData{
info.handshakeData.Store(testPeerID, &handshakeData{
received: false,
}
})

testHandshake := &BlockAnnounceHandshake{
Roles: 4,
Expand Down Expand Up @@ -82,7 +83,8 @@ func TestCreateDecoder_BlockAnnounce(t *testing.T) {
require.NoError(t, err)

// set handshake data to received
info.handshakeData[testPeerID].received = true
hsData, _ := info.getHandshakeData(testPeerID)
hsData.received = true
msg, err = decoder(enc, testPeerID)
require.NoError(t, err)
require.Equal(t, testBlockAnnounce, msg)
Expand Down Expand Up @@ -132,15 +134,15 @@ func TestCreateNotificationsMessageHandler_BlockAnnounce(t *testing.T) {
info := &notificationsProtocol{
protocolID: s.host.protocolID + blockAnnounceID,
getHandshake: s.getBlockAnnounceHandshake,
handshakeData: make(map[peer.ID]*handshakeData),
handshakeData: new(sync.Map),
}
handler := s.createNotificationsMessageHandler(info, s.validateBlockAnnounceHandshake, s.handleBlockAnnounceMessage)

// set handshake data to received
info.handshakeData[testPeerID] = &handshakeData{
info.handshakeData.Store(testPeerID, &handshakeData{
received: true,
validated: true,
}
})
msg := &BlockAnnounceMessage{
Number: big.NewInt(10),
}
Expand All @@ -164,7 +166,7 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T)
info := &notificationsProtocol{
protocolID: s.host.protocolID + blockAnnounceID,
getHandshake: s.getBlockAnnounceHandshake,
handshakeData: make(map[peer.ID]*handshakeData),
handshakeData: new(sync.Map),
}
handler := s.createNotificationsMessageHandler(info, s.validateBlockAnnounceHandshake, s.handleBlockAnnounceMessage)

Expand Down Expand Up @@ -205,8 +207,10 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T)

err = handler(stream, testHandshake)
require.Equal(t, errCannotValidateHandshake, err)
require.True(t, info.handshakeData[testPeerID].received)
require.False(t, info.handshakeData[testPeerID].validated)
data, has := info.getHandshakeData(testPeerID)
require.True(t, has)
require.True(t, data.received)
require.False(t, data.validated)

// try valid handshake
testHandshake = &BlockAnnounceHandshake{
Expand All @@ -218,6 +222,8 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T)

err = handler(stream, testHandshake)
require.NoError(t, err)
require.True(t, info.handshakeData[testPeerID].received)
require.True(t, info.handshakeData[testPeerID].validated)
data, has = info.getHandshakeData(testPeerID)
require.True(t, has)
require.True(t, data.received)
require.True(t, data.validated)
}
27 changes: 12 additions & 15 deletions dot/network/service.go
Expand Up @@ -304,10 +304,10 @@ func (s *Service) handleConn(conn libp2pnetwork.Conn) {
defer info.mapMu.RUnlock()

peer := conn.RemotePeer()
if hsData, has := info.handshakeData[peer]; !has || !hsData.received {
info.handshakeData[peer] = &handshakeData{
if hsData, has := info.getHandshakeData(peer); !has || !hsData.received {
info.handshakeData.Store(peer, &handshakeData{
validated: false,
}
})

logger.Trace("sending handshake", "protocol", info.protocolID, "peer", peer, "message", hs)
err = s.host.send(peer, info.protocolID, hs)
Expand Down Expand Up @@ -407,7 +407,7 @@ func (s *Service) RegisterNotificationsProtocol(sub protocol.ID,
np := &notificationsProtocol{
protocolID: protocolID,
getHandshake: handshakeGetter,
handshakeData: make(map[peer.ID]*handshakeData),
handshakeData: new(sync.Map),
}
s.notificationsProtocols[messageID] = np

Expand All @@ -416,13 +416,13 @@ func (s *Service) RegisterNotificationsProtocol(sub protocol.ID,
np.mapMu.Lock()
defer np.mapMu.Unlock()

if _, ok := np.handshakeData[peerID]; ok {
if _, ok := np.getHandshakeData(peerID); ok {
logger.Trace(
"Cleaning up handshake data",
"peer", peerID,
"protocol", protocolID,
)
delete(np.handshakeData, peerID)
np.handshakeData.Delete(peerID)
}
})

Expand Down Expand Up @@ -625,31 +625,28 @@ func (s *Service) Peers() []common.PeerInfo {
peers := []common.PeerInfo{}

s.notificationsMu.RLock()
defer s.notificationsMu.RUnlock()
np := s.notificationsProtocols[BlockAnnounceMsgType]
s.notificationsMu.RUnlock()

for _, p := range s.host.peers() {
if s.notificationsProtocols[BlockAnnounceMsgType].handshakeData[p] == nil {
data, has := np.getHandshakeData(p)
if !has || data.handshake == nil {
peers = append(peers, common.PeerInfo{
PeerID: p.String(),
})

continue
}
peerHandshakeMessage := s.notificationsProtocols[BlockAnnounceMsgType].handshakeData[p].handshake
if peerHandshakeMessage == nil {
peers = append(peers, common.PeerInfo{
PeerID: p.String(),
})
continue
}

peerHandshakeMessage := data.handshake
peers = append(peers, common.PeerInfo{
PeerID: p.String(),
Roles: peerHandshakeMessage.(*BlockAnnounceHandshake).Roles,
BestHash: peerHandshakeMessage.(*BlockAnnounceHandshake).BestBlockHash,
BestNumber: uint64(peerHandshakeMessage.(*BlockAnnounceHandshake).BestBlockNumber),
})
}

return peers
}

Expand Down
4 changes: 2 additions & 2 deletions dot/network/service_test.go
Expand Up @@ -371,13 +371,13 @@ func TestHandleConn(t *testing.T) {
require.Equal(t, 1, aScore)

infoA := nodeA.notificationsProtocols[BlockAnnounceMsgType]
hsDataB, has := infoA.handshakeData[nodeB.host.id()]
hsDataB, has := infoA.getHandshakeData(nodeB.host.id())
require.True(t, has)
require.True(t, hsDataB.received)
require.True(t, hsDataB.validated)

infoB := nodeB.notificationsProtocols[BlockAnnounceMsgType]
hsDataA, has := infoB.handshakeData[nodeA.host.id()]
hsDataA, has := infoB.getHandshakeData(nodeA.host.id())
require.True(t, has)
require.True(t, hsDataA.received)
require.True(t, hsDataA.validated)
Expand Down