Skip to content

Commit

Permalink
fix(dot/network): Fix notification handshake and reuse stream. (#1545)
Browse files Browse the repository at this point in the history
  • Loading branch information
arijitAD committed Apr 30, 2021
1 parent 6fd2501 commit a632dc4
Show file tree
Hide file tree
Showing 12 changed files with 158 additions and 117 deletions.
3 changes: 2 additions & 1 deletion dot/network/block_announce.go
Expand Up @@ -222,14 +222,15 @@ func (s *Service) validateBlockAnnounceHandshake(peer peer.ID, hs Handshake) err
// `createNotificationsMessageHandler` which locks the map beforehand.
data, ok := np.getHandshakeData(peer)
if !ok {
np.handshakeData.Store(peer, &handshakeData{
np.handshakeData.Store(peer, handshakeData{
received: true,
validated: true,
})
data, _ = np.getHandshakeData(peer)
}

data.handshake = hs
np.handshakeData.Store(peer, data)

// if peer has higher best block than us, begin syncing
latestHeader, err := s.blockState.BestBlockHeader()
Expand Down
2 changes: 1 addition & 1 deletion dot/network/block_announce_test.go
Expand Up @@ -120,7 +120,7 @@ func TestValidateBlockAnnounceHandshake(t *testing.T) {
handshakeData: new(sync.Map),
}
testPeerID := peer.ID("noot")
nodeA.notificationsProtocols[BlockAnnounceMsgType].handshakeData.Store(testPeerID, &handshakeData{})
nodeA.notificationsProtocols[BlockAnnounceMsgType].handshakeData.Store(testPeerID, handshakeData{})

err := nodeA.validateBlockAnnounceHandshake(testPeerID, &BlockAnnounceHandshake{
BestBlockNumber: 100,
Expand Down
2 changes: 1 addition & 1 deletion dot/network/gossip_test.go
Expand Up @@ -101,7 +101,7 @@ func TestGossip(t *testing.T) {
}
require.NoError(t, err)

err = nodeA.host.send(addrInfosB[0].ID, "", testBlockAnnounceMessage)
_, err = nodeA.host.send(addrInfosB[0].ID, "", testBlockAnnounceMessage)
require.NoError(t, err)

time.Sleep(TestMessageTimeout)
Expand Down
42 changes: 18 additions & 24 deletions dot/network/host.go
Expand Up @@ -262,32 +262,26 @@ func (h *host) bootstrap() {
}
}

// send writes the given message to the outbound message stream for the given
// peer (gets the already opened outbound message stream or opens a new one).
func (h *host) send(p peer.ID, pid protocol.ID, msg Message) (err error) {
// get outbound stream for given peer
s := h.getOutboundStream(p, pid)

// check if stream needs to be opened
if s == nil {
// open outbound stream with host protocol id
s, err = h.h.NewStream(h.ctx, p, pid)
if err != nil {
logger.Trace("failed to open new stream with peer", "peer", p, "protocol", pid, "error", err)
return err
}

logger.Trace(
"Opened stream",
"host", h.id(),
"peer", p,
"protocol", pid,
)
// send creates a new outbound stream with the given peer and writes the message. It also returns
// the newly created stream.
func (h *host) send(p peer.ID, pid protocol.ID, msg Message) (libp2pnetwork.Stream, error) {
// open outbound stream with host protocol id
stream, err := h.h.NewStream(h.ctx, p, pid)
if err != nil {
logger.Trace("failed to open new stream with peer", "peer", p, "protocol", pid, "error", err)
return nil, err
}

err = h.writeToStream(s, msg)
logger.Trace(
"Opened stream",
"host", h.id(),
"peer", p,
"protocol", pid,
)

err = h.writeToStream(stream, msg)
if err != nil {
return err
return nil, err
}

logger.Trace(
Expand All @@ -298,7 +292,7 @@ func (h *host) send(p peer.ID, pid protocol.ID, msg Message) (err error) {
"message", msg.String(),
)

return nil
return stream, nil
}

func (h *host) writeToStream(s libp2pnetwork.Stream, msg Message) error {
Expand Down
29 changes: 7 additions & 22 deletions dot/network/host_test.go
Expand Up @@ -218,7 +218,7 @@ func TestSend(t *testing.T) {
}
require.NoError(t, err)

err = nodeA.host.send(addrInfosB[0].ID, nodeB.host.protocolID, testBlockRequestMessage)
_, err = nodeA.host.send(addrInfosB[0].ID, nodeB.host.protocolID, testBlockRequestMessage)
require.NoError(t, err)

time.Sleep(TestMessageTimeout)
Expand Down Expand Up @@ -273,44 +273,29 @@ func TestExistingStream(t *testing.T) {
}
require.NoError(t, err)

stream := nodeA.host.getOutboundStream(nodeB.host.id(), nodeB.host.protocolID)
require.Nil(t, stream, "node A should not have an outbound stream")

// node A opens the stream to send the first message
err = nodeA.host.send(addrInfosB[0].ID, nodeB.host.protocolID, testBlockRequestMessage)
stream, err := nodeA.host.send(addrInfosB[0].ID, nodeB.host.protocolID, testBlockRequestMessage)
require.NoError(t, err)

time.Sleep(TestMessageTimeout)
require.NotNil(t, handlerB.messages[nodeA.host.id()], "node B timeout waiting for message from node A")

stream = nodeA.host.getOutboundStream(nodeB.host.id(), nodeB.host.protocolID)
require.NotNil(t, stream, "node A should have an outbound stream")

// node A uses the stream to send a second message
err = nodeA.host.send(addrInfosB[0].ID, nodeB.host.protocolID, testBlockRequestMessage)
err = nodeA.host.writeToStream(stream, testBlockRequestMessage)
require.NoError(t, err)
require.NotNil(t, handlerB.messages[nodeA.host.id()], "node B timeout waiting for message from node A")

stream = nodeA.host.getOutboundStream(nodeB.host.id(), nodeB.host.protocolID)
require.NotNil(t, stream, "node B should have an outbound stream")

// node B opens the stream to send the first message
err = nodeB.host.send(addrInfosA[0].ID, nodeB.host.protocolID, testBlockRequestMessage)
stream, err = nodeB.host.send(addrInfosA[0].ID, nodeB.host.protocolID, testBlockRequestMessage)
require.NoError(t, err)

time.Sleep(TestMessageTimeout)
require.NotNil(t, handlerA.messages[nodeB.host.id()], "node A timeout waiting for message from node B")

stream = nodeB.host.getOutboundStream(nodeA.host.id(), nodeB.host.protocolID)
require.NotNil(t, stream, "node B should have an outbound stream")

// node B uses the stream to send a second message
err = nodeB.host.send(addrInfosA[0].ID, nodeB.host.protocolID, testBlockRequestMessage)
err = nodeB.host.writeToStream(stream, testBlockRequestMessage)
require.NoError(t, err)
require.NotNil(t, handlerA.messages[nodeB.host.id()], "node A timeout waiting for message from node B")

stream = nodeB.host.getOutboundStream(nodeA.host.id(), nodeB.host.protocolID)
require.NotNil(t, stream, "node B should have an outbound stream")
}

func TestStreamCloseMetadataCleanup(t *testing.T) {
Expand Down Expand Up @@ -361,13 +346,13 @@ func TestStreamCloseMetadataCleanup(t *testing.T) {
}

// node A opens the stream to send the first message
err = nodeA.host.send(nodeB.host.id(), nodeB.host.protocolID+blockAnnounceID, testHandshake)
_, err = nodeA.host.send(nodeB.host.id(), nodeB.host.protocolID+blockAnnounceID, testHandshake)
require.NoError(t, err)

info := nodeA.notificationsProtocols[BlockAnnounceMsgType]

// Set handshake data to received
info.handshakeData.Store(nodeB.host.id(), &handshakeData{
info.handshakeData.Store(nodeB.host.id(), handshakeData{
received: true,
validated: true,
})
Expand Down
115 changes: 74 additions & 41 deletions dot/network/notifications.go
Expand Up @@ -49,29 +49,33 @@ type (

// NotificationsMessageHandler is called when a (non-handshake) message is received over a notifications stream.
NotificationsMessageHandler = func(peer peer.ID, msg NotificationsMessage) error

streamHandler = func(libp2pnetwork.Stream, peer.ID)
)

type notificationsProtocol struct {
protocolID protocol.ID
getHandshake HandshakeGetter
handshakeData *sync.Map //map[peer.ID]*handshakeData
streamHandler streamHandler
mapMu sync.RWMutex
}

func (n *notificationsProtocol) getHandshakeData(pid peer.ID) (*handshakeData, bool) {
func (n *notificationsProtocol) getHandshakeData(pid peer.ID) (handshakeData, bool) {
data, has := n.handshakeData.Load(pid)
if !has {
return nil, false
return handshakeData{}, false
}

return data.(*handshakeData), true
return data.(handshakeData), true
}

type handshakeData struct {
received bool
validated bool
handshake Handshake
outboundMsg NotificationsMessage
stream libp2pnetwork.Stream
}

func createDecoder(info *notificationsProtocol, handshakeDecoder HandshakeDecoder, messageDecoder MessageDecoder) messageDecoder {
Expand Down Expand Up @@ -123,19 +127,21 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
// if we are the receiver and haven't received the handshake already, validate it
if _, has := info.getHandshakeData(peer); !has {
logger.Trace("receiver: validating handshake", "protocol", info.protocolID)
info.handshakeData.Store(peer, &handshakeData{
hsData := handshakeData{
validated: false,
received: true,
})
stream: stream,
}
info.handshakeData.Store(peer, hsData)

err := handshakeValidator(peer, hs)
if err != nil {
logger.Trace("failed to validate handshake", "protocol", info.protocolID, "peer", peer, "error", err)
return errCannotValidateHandshake
}

data, _ := info.getHandshakeData(peer)
data.validated = true
hsData.validated = true
info.handshakeData.Store(peer, hsData)

// once validated, send back a handshake
resp, err := info.getHandshake()
Expand All @@ -144,7 +150,7 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
return err
}

err = s.host.writeToStream(stream, resp)
err = s.host.writeToStream(hsData.stream, resp)
if err != nil {
logger.Trace("failed to send handshake", "protocol", info.protocolID, "peer", peer, "error", err)
return err
Expand All @@ -160,20 +166,21 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
if err != nil {
logger.Trace("failed to validate handshake", "protocol", info.protocolID, "peer", peer, "error", err)
hsData.validated = false
info.handshakeData.Store(peer, hsData)
return errCannotValidateHandshake
}

hsData.validated = true
hsData.received = true
info.handshakeData.Store(peer, hsData)

logger.Trace("sender: validated handshake", "protocol", info.protocolID, "peer", peer)
} else if hsData.received {
return nil
}

// if we are the initiator, send the message
if hsData, has := info.getHandshakeData(peer); has && hsData.validated && hsData.received && hsData.outboundMsg != nil {
logger.Trace("sender: sending message", "protocol", info.protocolID)
err := s.host.writeToStream(stream, hsData.outboundMsg)
err := s.host.writeToStream(hsData.stream, hsData.outboundMsg)
if err != nil {
logger.Debug("failed to send message", "protocol", info.protocolID, "peer", peer, "error", err)
return err
Expand Down Expand Up @@ -209,6 +216,61 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
}
}

func (s *Service) sendData(peer peer.ID, hs Handshake, info *notificationsProtocol, msg NotificationsMessage) {
hsData, has := info.getHandshakeData(peer)
if !has || !hsData.received {
hsData = handshakeData{
validated: false,
received: false,
outboundMsg: msg,
}

info.handshakeData.Store(peer, hsData)
logger.Trace("sending handshake", "protocol", info.protocolID, "peer", peer, "message", hs)

stream, err := s.host.send(peer, info.protocolID, hs)
if err != nil {
logger.Trace("failed to send message to peer", "peer", peer, "error", err)
return
}

hsData.stream = stream
info.handshakeData.Store(peer, hsData)

if info.streamHandler == nil {
return
}

go info.streamHandler(stream, peer)
return
}

if s.host.messageCache != nil {
added, err := s.host.messageCache.put(peer, msg)
if err != nil {
logger.Error("failed to add message to cache", "peer", peer, "error", err)
return
}

if !added {
return
}
}

if hsData.stream == nil {
logger.Error("trying to send data through empty stream", "protocol", info.protocolID, "peer", peer, "message", msg)
return
}

// we've already completed the handshake with the peer, send message directly
logger.Trace("sending message", "protocol", info.protocolID, "peer", peer, "message", msg)

err := s.host.writeToStream(hsData.stream, msg)
if err != nil {
logger.Trace("failed to send message to peer", "peer", peer, "error", err)
}
}

// gossipExcluding sends a message to each connected peer except the given peer
// Used for notifications sub-protocols to gossip a message
func (s *Service) broadcastExcluding(info *notificationsProtocol, excluding peer.ID, msg NotificationsMessage) {
Expand All @@ -234,35 +296,6 @@ func (s *Service) broadcastExcluding(info *notificationsProtocol, excluding peer
continue
}

if hsData, has := info.getHandshakeData(peer); !has || !hsData.received {
info.handshakeData.Store(peer, &handshakeData{
validated: false,
outboundMsg: msg,
})

logger.Trace("sending handshake", "protocol", info.protocolID, "peer", peer, "message", hs)
err = s.host.send(peer, info.protocolID, hs)
} else {
if s.host.messageCache != nil {
var added bool
added, err = s.host.messageCache.put(peer, msg)
if err != nil {
logger.Error("failed to add message to cache", "peer", peer, "error", err)
continue
}

if !added {
continue
}
}

// we've already completed the handshake with the peer, send message directly
logger.Trace("sending message", "protocol", info.protocolID, "peer", peer, "message", msg)
err = s.host.send(peer, info.protocolID, msg)
}

if err != nil {
logger.Debug("failed to send message to peer", "peer", peer, "error", err)
}
go s.sendData(peer, hs, info, msg)
}
}
5 changes: 3 additions & 2 deletions dot/network/notifications_test.go
Expand Up @@ -53,7 +53,7 @@ func TestCreateDecoder_BlockAnnounce(t *testing.T) {

// haven't received handshake from peer
testPeerID := peer.ID("QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ")
info.handshakeData.Store(testPeerID, &handshakeData{
info.handshakeData.Store(testPeerID, handshakeData{
received: false,
})

Expand Down Expand Up @@ -85,6 +85,7 @@ func TestCreateDecoder_BlockAnnounce(t *testing.T) {
// set handshake data to received
hsData, _ := info.getHandshakeData(testPeerID)
hsData.received = true
info.handshakeData.Store(testPeerID, hsData)
msg, err = decoder(enc, testPeerID)
require.NoError(t, err)
require.Equal(t, testBlockAnnounce, msg)
Expand Down Expand Up @@ -139,7 +140,7 @@ func TestCreateNotificationsMessageHandler_BlockAnnounce(t *testing.T) {
handler := s.createNotificationsMessageHandler(info, s.validateBlockAnnounceHandshake, s.handleBlockAnnounceMessage)

// set handshake data to received
info.handshakeData.Store(testPeerID, &handshakeData{
info.handshakeData.Store(testPeerID, handshakeData{
received: true,
validated: true,
})
Expand Down

0 comments on commit a632dc4

Please sign in to comment.