Skip to content

Commit

Permalink
fix(dot/network): update notificationsProtocol handshakeData to sync.…
Browse files Browse the repository at this point in the history
…Map (#1492)
  • Loading branch information
noot committed Mar 25, 2021
1 parent 3b2ad8d commit 22f7269
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 53 deletions.
8 changes: 4 additions & 4 deletions dot/network/block_announce.go
Expand Up @@ -228,13 +228,13 @@ func (s *Service) validateBlockAnnounceHandshake(peer peer.ID, hs Handshake) err

// don't need to lock here, since function is always called inside the func returned by
// `createNotificationsMessageHandler` which locks the map beforehand.
data, ok := np.handshakeData[peer]
data, ok := np.getHandshakeData(peer)
if !ok {
np.handshakeData[peer] = &handshakeData{
np.handshakeData.Store(peer, &handshakeData{
received: true,
validated: true,
}
data = np.handshakeData[peer]
})
data, _ = np.getHandshakeData(peer)
}

data.handshake = hs
Expand Down
5 changes: 3 additions & 2 deletions dot/network/block_announce_test.go
Expand Up @@ -18,6 +18,7 @@ package network

import (
"math/big"
"sync"
"testing"

"github.com/ChainSafe/gossamer/dot/types"
Expand Down Expand Up @@ -116,10 +117,10 @@ func TestValidateBlockAnnounceHandshake(t *testing.T) {
nodeA := createTestService(t, configA)
nodeA.noGossip = true
nodeA.notificationsProtocols[BlockAnnounceMsgType] = &notificationsProtocol{
handshakeData: make(map[peer.ID]*handshakeData),
handshakeData: new(sync.Map),
}
testPeerID := peer.ID("noot")
nodeA.notificationsProtocols[BlockAnnounceMsgType].handshakeData[testPeerID] = &handshakeData{}
nodeA.notificationsProtocols[BlockAnnounceMsgType].handshakeData.Store(testPeerID, &handshakeData{})

err := nodeA.validateBlockAnnounceHandshake(testPeerID, &BlockAnnounceHandshake{
BestBlockNumber: 100,
Expand Down
8 changes: 4 additions & 4 deletions dot/network/host_test.go
Expand Up @@ -363,13 +363,13 @@ func TestStreamCloseMetadataCleanup(t *testing.T) {
info := nodeA.notificationsProtocols[BlockAnnounceMsgType]

// Set handshake data to received
info.handshakeData[nodeB.host.id()] = &handshakeData{
info.handshakeData.Store(nodeB.host.id(), &handshakeData{
received: true,
validated: true,
}
})

// Verify that handshake data exists.
_, ok := info.handshakeData[nodeB.host.id()]
_, ok := info.getHandshakeData(nodeB.host.id())
require.True(t, ok)

time.Sleep(time.Second)
Expand All @@ -379,7 +379,7 @@ func TestStreamCloseMetadataCleanup(t *testing.T) {
time.Sleep(time.Second)

// Verify that handshake data is cleared.
_, ok = info.handshakeData[nodeB.host.id()]
_, ok = info.getHandshakeData(nodeB.host.id())
require.False(t, ok)
}

Expand Down
38 changes: 24 additions & 14 deletions dot/network/notifications.go
Expand Up @@ -54,10 +54,19 @@ type (
type notificationsProtocol struct {
protocolID protocol.ID
getHandshake HandshakeGetter
handshakeData map[peer.ID]*handshakeData
handshakeData *sync.Map //map[peer.ID]*handshakeData
mapMu sync.RWMutex
}

func (n *notificationsProtocol) getHandshakeData(pid peer.ID) (*handshakeData, bool) {
data, has := n.handshakeData.Load(pid)
if !has {
return nil, false
}

return data.(*handshakeData), true
}

type handshakeData struct {
received bool
validated bool
Expand All @@ -72,7 +81,7 @@ func createDecoder(info *notificationsProtocol, handshakeDecoder HandshakeDecode
info.mapMu.RLock()
defer info.mapMu.RUnlock()

if hsData, has := info.handshakeData[peer]; !has || !hsData.received {
if hsData, has := info.getHandshakeData(peer); !has || !hsData.received {
return handshakeDecoder(in)
}

Expand Down Expand Up @@ -112,12 +121,12 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
defer info.mapMu.Unlock()

// if we are the receiver and haven't received the handshake already, validate it
if _, has := info.handshakeData[peer]; !has {
if _, has := info.getHandshakeData(peer); !has {
logger.Trace("receiver: validating handshake", "protocol", info.protocolID)
info.handshakeData[peer] = &handshakeData{
info.handshakeData.Store(peer, &handshakeData{
validated: false,
received: true,
}
})

err := handshakeValidator(peer, hs)
if err != nil {
Expand All @@ -126,7 +135,8 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
return errCannotValidateHandshake
}

info.handshakeData[peer].validated = true
data, _ := info.getHandshakeData(peer)
data.validated = true

// once validated, send back a handshake
resp, err := info.getHandshake()
Expand All @@ -145,25 +155,25 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
}

// if we are the initiator and haven't received the handshake already, validate it
if hsData, has := info.handshakeData[peer]; has && !hsData.validated {
if hsData, has := info.getHandshakeData(peer); has && !hsData.validated {
logger.Trace("sender: validating handshake")
err := handshakeValidator(peer, hs)
if err != nil {
logger.Trace("failed to validate handshake", "protocol", info.protocolID, "peer", peer, "error", err)
info.handshakeData[peer].validated = false
hsData.validated = false
_ = stream.Conn().Close()
return errCannotValidateHandshake
}

info.handshakeData[peer].validated = true
info.handshakeData[peer].received = true
hsData.validated = true
hsData.received = true
logger.Trace("sender: validated handshake", "protocol", info.protocolID, "peer", peer)
} else if hsData.received {
return nil
}

// if we are the initiator, send the message
if hsData, has := info.handshakeData[peer]; has && hsData.validated && hsData.received && hsData.outboundMsg != nil {
if hsData, has := info.getHandshakeData(peer); has && hsData.validated && hsData.received && hsData.outboundMsg != nil {
logger.Trace("sender: sending message", "protocol", info.protocolID)
err := s.host.send(peer, info.protocolID, hsData.outboundMsg)
if err != nil {
Expand Down Expand Up @@ -223,11 +233,11 @@ func (s *Service) broadcastExcluding(info *notificationsProtocol, excluding peer
info.mapMu.RLock()
defer info.mapMu.RUnlock()

if hsData, has := info.handshakeData[peer]; !has || !hsData.received {
info.handshakeData[peer] = &handshakeData{
if hsData, has := info.getHandshakeData(peer); !has || !hsData.received {
info.handshakeData.Store(peer, &handshakeData{
validated: false,
outboundMsg: msg,
}
})

logger.Trace("sending handshake", "protocol", info.protocolID, "peer", peer, "message", hs)
err = s.host.send(peer, info.protocolID, hs)
Expand Down
30 changes: 18 additions & 12 deletions dot/network/notifications_test.go
Expand Up @@ -18,6 +18,7 @@ package network

import (
"math/big"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -46,15 +47,15 @@ func TestCreateDecoder_BlockAnnounce(t *testing.T) {
info := &notificationsProtocol{
protocolID: s.host.protocolID + blockAnnounceID,
getHandshake: s.getBlockAnnounceHandshake,
handshakeData: make(map[peer.ID]*handshakeData),
handshakeData: new(sync.Map),
}
decoder := createDecoder(info, decodeBlockAnnounceHandshake, decodeBlockAnnounceMessage)

// haven't received handshake from peer
testPeerID := peer.ID("QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ")
info.handshakeData[testPeerID] = &handshakeData{
info.handshakeData.Store(testPeerID, &handshakeData{
received: false,
}
})

testHandshake := &BlockAnnounceHandshake{
Roles: 4,
Expand Down Expand Up @@ -82,7 +83,8 @@ func TestCreateDecoder_BlockAnnounce(t *testing.T) {
require.NoError(t, err)

// set handshake data to received
info.handshakeData[testPeerID].received = true
hsData, _ := info.getHandshakeData(testPeerID)
hsData.received = true
msg, err = decoder(enc, testPeerID)
require.NoError(t, err)
require.Equal(t, testBlockAnnounce, msg)
Expand Down Expand Up @@ -132,15 +134,15 @@ func TestCreateNotificationsMessageHandler_BlockAnnounce(t *testing.T) {
info := &notificationsProtocol{
protocolID: s.host.protocolID + blockAnnounceID,
getHandshake: s.getBlockAnnounceHandshake,
handshakeData: make(map[peer.ID]*handshakeData),
handshakeData: new(sync.Map),
}
handler := s.createNotificationsMessageHandler(info, s.validateBlockAnnounceHandshake, s.handleBlockAnnounceMessage)

// set handshake data to received
info.handshakeData[testPeerID] = &handshakeData{
info.handshakeData.Store(testPeerID, &handshakeData{
received: true,
validated: true,
}
})
msg := &BlockAnnounceMessage{
Number: big.NewInt(10),
}
Expand All @@ -164,7 +166,7 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T)
info := &notificationsProtocol{
protocolID: s.host.protocolID + blockAnnounceID,
getHandshake: s.getBlockAnnounceHandshake,
handshakeData: make(map[peer.ID]*handshakeData),
handshakeData: new(sync.Map),
}
handler := s.createNotificationsMessageHandler(info, s.validateBlockAnnounceHandshake, s.handleBlockAnnounceMessage)

Expand Down Expand Up @@ -205,8 +207,10 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T)

err = handler(stream, testHandshake)
require.Equal(t, errCannotValidateHandshake, err)
require.True(t, info.handshakeData[testPeerID].received)
require.False(t, info.handshakeData[testPeerID].validated)
data, has := info.getHandshakeData(testPeerID)
require.True(t, has)
require.True(t, data.received)
require.False(t, data.validated)

// try valid handshake
testHandshake = &BlockAnnounceHandshake{
Expand All @@ -218,6 +222,8 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T)

err = handler(stream, testHandshake)
require.NoError(t, err)
require.True(t, info.handshakeData[testPeerID].received)
require.True(t, info.handshakeData[testPeerID].validated)
data, has = info.getHandshakeData(testPeerID)
require.True(t, has)
require.True(t, data.received)
require.True(t, data.validated)
}
27 changes: 12 additions & 15 deletions dot/network/service.go
Expand Up @@ -304,10 +304,10 @@ func (s *Service) handleConn(conn libp2pnetwork.Conn) {
defer info.mapMu.RUnlock()

peer := conn.RemotePeer()
if hsData, has := info.handshakeData[peer]; !has || !hsData.received {
info.handshakeData[peer] = &handshakeData{
if hsData, has := info.getHandshakeData(peer); !has || !hsData.received {
info.handshakeData.Store(peer, &handshakeData{
validated: false,
}
})

logger.Trace("sending handshake", "protocol", info.protocolID, "peer", peer, "message", hs)
err = s.host.send(peer, info.protocolID, hs)
Expand Down Expand Up @@ -407,7 +407,7 @@ func (s *Service) RegisterNotificationsProtocol(sub protocol.ID,
np := &notificationsProtocol{
protocolID: protocolID,
getHandshake: handshakeGetter,
handshakeData: make(map[peer.ID]*handshakeData),
handshakeData: new(sync.Map),
}
s.notificationsProtocols[messageID] = np

Expand All @@ -416,13 +416,13 @@ func (s *Service) RegisterNotificationsProtocol(sub protocol.ID,
np.mapMu.Lock()
defer np.mapMu.Unlock()

if _, ok := np.handshakeData[peerID]; ok {
if _, ok := np.getHandshakeData(peerID); ok {
logger.Trace(
"Cleaning up handshake data",
"peer", peerID,
"protocol", protocolID,
)
delete(np.handshakeData, peerID)
np.handshakeData.Delete(peerID)
}
})

Expand Down Expand Up @@ -625,31 +625,28 @@ func (s *Service) Peers() []common.PeerInfo {
peers := []common.PeerInfo{}

s.notificationsMu.RLock()
defer s.notificationsMu.RUnlock()
np := s.notificationsProtocols[BlockAnnounceMsgType]
s.notificationsMu.RUnlock()

for _, p := range s.host.peers() {
if s.notificationsProtocols[BlockAnnounceMsgType].handshakeData[p] == nil {
data, has := np.getHandshakeData(p)
if !has || data.handshake == nil {
peers = append(peers, common.PeerInfo{
PeerID: p.String(),
})

continue
}
peerHandshakeMessage := s.notificationsProtocols[BlockAnnounceMsgType].handshakeData[p].handshake
if peerHandshakeMessage == nil {
peers = append(peers, common.PeerInfo{
PeerID: p.String(),
})
continue
}

peerHandshakeMessage := data.handshake
peers = append(peers, common.PeerInfo{
PeerID: p.String(),
Roles: peerHandshakeMessage.(*BlockAnnounceHandshake).Roles,
BestHash: peerHandshakeMessage.(*BlockAnnounceHandshake).BestBlockHash,
BestNumber: uint64(peerHandshakeMessage.(*BlockAnnounceHandshake).BestBlockNumber),
})
}

return peers
}

Expand Down
4 changes: 2 additions & 2 deletions dot/network/service_test.go
Expand Up @@ -371,13 +371,13 @@ func TestHandleConn(t *testing.T) {
require.Equal(t, 1, aScore)

infoA := nodeA.notificationsProtocols[BlockAnnounceMsgType]
hsDataB, has := infoA.handshakeData[nodeB.host.id()]
hsDataB, has := infoA.getHandshakeData(nodeB.host.id())
require.True(t, has)
require.True(t, hsDataB.received)
require.True(t, hsDataB.validated)

infoB := nodeB.notificationsProtocols[BlockAnnounceMsgType]
hsDataA, has := infoB.handshakeData[nodeA.host.id()]
hsDataA, has := infoB.getHandshakeData(nodeA.host.id())
require.True(t, has)
require.True(t, hsDataA.received)
require.True(t, hsDataA.validated)
Expand Down

0 comments on commit 22f7269

Please sign in to comment.