diff --git a/dot/network/block_announce.go b/dot/network/block_announce.go index c18be68b2a..63d02e4479 100644 --- a/dot/network/block_announce.go +++ b/dot/network/block_announce.go @@ -220,18 +220,14 @@ func (s *Service) validateBlockAnnounceHandshake(peer peer.ID, hs Handshake) err // don't need to lock here, since function is always called inside the func returned by // `createNotificationsMessageHandler` which locks the map beforehand. - data, ok := np.getHandshakeData(peer) - if !ok { - np.handshakeData.Store(peer, handshakeData{ - received: true, - validated: true, - }) - data, _ = np.getHandshakeData(peer) + data, ok := np.getHandshakeData(peer, true) + if ok { + data.handshake = hs + // TODO: since this is used only for rpc system_peers only, + // we can just set the inbound handshake and use that in Peers() + np.inboundHandshakeData.Store(peer, data) } - data.handshake = hs - np.handshakeData.Store(peer, data) - // if peer has higher best block than us, begin syncing latestHeader, err := s.blockState.BestBlockHeader() if err != nil { diff --git a/dot/network/block_announce_test.go b/dot/network/block_announce_test.go index 79c01f2b4c..0eac5903cf 100644 --- a/dot/network/block_announce_test.go +++ b/dot/network/block_announce_test.go @@ -117,10 +117,10 @@ func TestValidateBlockAnnounceHandshake(t *testing.T) { nodeA := createTestService(t, configA) nodeA.noGossip = true nodeA.notificationsProtocols[BlockAnnounceMsgType] = ¬ificationsProtocol{ - handshakeData: new(sync.Map), + inboundHandshakeData: new(sync.Map), } testPeerID := peer.ID("noot") - nodeA.notificationsProtocols[BlockAnnounceMsgType].handshakeData.Store(testPeerID, handshakeData{}) + nodeA.notificationsProtocols[BlockAnnounceMsgType].inboundHandshakeData.Store(testPeerID, handshakeData{}) err := nodeA.validateBlockAnnounceHandshake(testPeerID, &BlockAnnounceHandshake{ BestBlockNumber: 100, diff --git a/dot/network/host_test.go b/dot/network/host_test.go index fd32789817..26090f8cb0 100644 --- a/dot/network/host_test.go +++ b/dot/network/host_test.go @@ -352,13 +352,13 @@ func TestStreamCloseMetadataCleanup(t *testing.T) { info := nodeA.notificationsProtocols[BlockAnnounceMsgType] // Set handshake data to received - info.handshakeData.Store(nodeB.host.id(), handshakeData{ + info.inboundHandshakeData.Store(nodeB.host.id(), handshakeData{ received: true, validated: true, }) // Verify that handshake data exists. - _, ok := info.getHandshakeData(nodeB.host.id()) + _, ok := info.getHandshakeData(nodeB.host.id(), true) require.True(t, ok) time.Sleep(time.Second) @@ -368,7 +368,7 @@ func TestStreamCloseMetadataCleanup(t *testing.T) { time.Sleep(time.Second) // Verify that handshake data is cleared. - _, ok = info.getHandshakeData(nodeB.host.id()) + _, ok = info.getHandshakeData(nodeB.host.id(), true) require.False(t, ok) } diff --git a/dot/network/light_test.go b/dot/network/light_test.go index 50a5572aa2..bc9b2fd862 100644 --- a/dot/network/light_test.go +++ b/dot/network/light_test.go @@ -22,7 +22,7 @@ func TestDecodeLightMessage(t *testing.T) { reqEnc, err := testLightRequest.Encode() require.NoError(t, err) - msg, err := s.decodeLightMessage(reqEnc, testPeer) + msg, err := s.decodeLightMessage(reqEnc, testPeer, true) require.NoError(t, err) req, ok := msg.(*LightRequest) @@ -36,7 +36,7 @@ func TestDecodeLightMessage(t *testing.T) { respEnc, err := testLightResponse.Encode() require.NoError(t, err) - msg, err = s.decodeLightMessage(respEnc, testPeer) + msg, err = s.decodeLightMessage(respEnc, testPeer, true) require.NoError(t, err) resp, ok := msg.(*LightResponse) require.True(t, ok) diff --git a/dot/network/notifications.go b/dot/network/notifications.go index 1ed7654b83..a6c2fcd2b7 100644 --- a/dot/network/notifications.go +++ b/dot/network/notifications.go @@ -18,8 +18,8 @@ package network import ( "errors" - "math/rand" "sync" + "unsafe" libp2pnetwork "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" @@ -28,6 +28,8 @@ import ( var errCannotValidateHandshake = errors.New("failed to validate handshake") +var maxHandshakeSize = unsafe.Sizeof(BlockAnnounceHandshake{}) //nolint + // Handshake is the interface all handshakes for notifications protocols must implement type Handshake interface { NotificationsMessage @@ -49,20 +51,28 @@ type ( // NotificationsMessageHandler is called when a (non-handshake) message is received over a notifications stream. NotificationsMessageHandler = func(peer peer.ID, msg NotificationsMessage) error - - streamHandler = func(libp2pnetwork.Stream, peer.ID) ) type notificationsProtocol struct { - protocolID protocol.ID - getHandshake HandshakeGetter - handshakeData *sync.Map //map[peer.ID]*handshakeData - streamHandler streamHandler - mapMu sync.RWMutex + protocolID protocol.ID + getHandshake HandshakeGetter + handshakeValidator HandshakeValidator + + inboundHandshakeData *sync.Map //map[peer.ID]*handshakeData + outboundHandshakeData *sync.Map //map[peer.ID]*handshakeData } -func (n *notificationsProtocol) getHandshakeData(pid peer.ID) (handshakeData, bool) { - data, has := n.handshakeData.Load(pid) +func (n *notificationsProtocol) getHandshakeData(pid peer.ID, inbound bool) (handshakeData, bool) { + if inbound { + data, has := n.inboundHandshakeData.Load(pid) + if !has { + return handshakeData{}, false + } + + return data.(handshakeData), true + } + + data, has := n.outboundHandshakeData.Load(pid) if !has { return handshakeData{}, false } @@ -71,21 +81,27 @@ func (n *notificationsProtocol) getHandshakeData(pid peer.ID) (handshakeData, bo } type handshakeData struct { - received bool - validated bool - handshake Handshake - outboundMsg NotificationsMessage - stream libp2pnetwork.Stream + received bool + validated bool + handshake Handshake + stream libp2pnetwork.Stream + *sync.Mutex +} + +func newHandshakeData(received, validated bool, stream libp2pnetwork.Stream) handshakeData { + return handshakeData{ + received: received, + validated: validated, + stream: stream, + Mutex: new(sync.Mutex), + } } func createDecoder(info *notificationsProtocol, handshakeDecoder HandshakeDecoder, messageDecoder MessageDecoder) messageDecoder { - return func(in []byte, peer peer.ID) (Message, error) { + return func(in []byte, peer peer.ID, inbound bool) (Message, error) { // if we don't have handshake data on this peer, or we haven't received the handshake from them already, // assume we are receiving the handshake - info.mapMu.RLock() - defer info.mapMu.RUnlock() - - if hsData, has := info.getHandshakeData(peer); !has || !hsData.received { + if hsData, has := info.getHandshakeData(peer, inbound); !has || !hsData.received { return handshakeDecoder(in) } @@ -94,9 +110,9 @@ func createDecoder(info *notificationsProtocol, handshakeDecoder HandshakeDecode } } -func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol, handshakeValidator HandshakeValidator, messageHandler NotificationsMessageHandler) messageHandler { +func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol, messageHandler NotificationsMessageHandler) messageHandler { return func(stream libp2pnetwork.Stream, m Message) error { - if m == nil || info == nil || handshakeValidator == nil || messageHandler == nil { + if m == nil || info == nil || info.handshakeValidator == nil || messageHandler == nil { return nil } @@ -121,27 +137,24 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol, return errors.New("failed to convert message to Handshake") } - info.mapMu.Lock() - defer info.mapMu.Unlock() - // if we are the receiver and haven't received the handshake already, validate it - if _, has := info.getHandshakeData(peer); !has { + // note: if this function is being called, it's being called via SetStreamHandler, + // ie it is an inbound stream and we only send the handshake over it. + // we do not send any other data over this stream, we would need to open a new outbound stream. + if _, has := info.getHandshakeData(peer, true); !has { logger.Trace("receiver: validating handshake", "protocol", info.protocolID) - hsData := handshakeData{ - validated: false, - received: true, - stream: stream, - } - info.handshakeData.Store(peer, hsData) - err := handshakeValidator(peer, hs) + hsData := newHandshakeData(true, false, stream) + info.inboundHandshakeData.Store(peer, hsData) + + err := info.handshakeValidator(peer, hs) if err != nil { logger.Trace("failed to validate handshake", "protocol", info.protocolID, "peer", peer, "error", err) return errCannotValidateHandshake } hsData.validated = true - info.handshakeData.Store(peer, hsData) + info.inboundHandshakeData.Store(peer, hsData) // once validated, send back a handshake resp, err := info.getHandshake() @@ -150,7 +163,7 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol, return err } - err = s.host.writeToStream(hsData.stream, resp) + err = s.host.writeToStream(stream, resp) if err != nil { logger.Trace("failed to send handshake", "protocol", info.protocolID, "peer", peer, "error", err) return err @@ -159,35 +172,6 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol, return nil } - // if we are the initiator and haven't received the handshake already, validate it - if hsData, has := info.getHandshakeData(peer); has && !hsData.validated { - logger.Trace("sender: validating handshake") - err := handshakeValidator(peer, hs) - if err != nil { - logger.Trace("failed to validate handshake", "protocol", info.protocolID, "peer", peer, "error", err) - hsData.validated = false - info.handshakeData.Store(peer, hsData) - return errCannotValidateHandshake - } - - hsData.validated = true - hsData.received = true - info.handshakeData.Store(peer, hsData) - - logger.Trace("sender: validated handshake", "protocol", info.protocolID, "peer", peer) - } - - // if we are the initiator, send the message - if hsData, has := info.getHandshakeData(peer); has && hsData.validated && hsData.received && hsData.outboundMsg != nil { - logger.Trace("sender: sending message", "protocol", info.protocolID) - err := s.host.writeToStream(hsData.stream, hsData.outboundMsg) - if err != nil { - logger.Debug("failed to send message", "protocol", info.protocolID, "peer", peer, "error", err) - return err - } - return nil - } - return nil } @@ -201,14 +185,17 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol, return err } - // TODO: improve this by keeping track of who you've received/sent messages from if s.noGossip { return nil } + // TODO: we don't want to rebroadcast neighbour messages, so ignore all consensus messages for now + if _, isConsensus := msg.(*ConsensusMessage); isConsensus { + return nil + } + seen := s.gossip.hasSeen(msg) if !seen { - // TODO: update this to write to stream w/ handshake established s.broadcastExcluding(info, peer, msg) } @@ -217,17 +204,21 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol, } func (s *Service) sendData(peer peer.ID, hs Handshake, info *notificationsProtocol, msg NotificationsMessage) { - hsData, has := info.getHandshakeData(peer) - if !has || !hsData.received { - hsData = handshakeData{ - validated: false, - received: false, - outboundMsg: msg, + hsData, has := info.getHandshakeData(peer, false) + if has && !hsData.validated { + // peer has sent us an invalid handshake in the past, ignore + return + } + + if !has || !hsData.received || hsData.stream == nil { + if !has { + hsData = newHandshakeData(false, false, nil) } - info.handshakeData.Store(peer, hsData) - logger.Trace("sending handshake", "protocol", info.protocolID, "peer", peer, "message", hs) + hsData.Lock() + defer hsData.Unlock() + logger.Trace("sending outbound handshake", "protocol", info.protocolID, "peer", peer, "message", hs) stream, err := s.host.send(peer, info.protocolID, hs) if err != nil { logger.Trace("failed to send message to peer", "peer", peer, "error", err) @@ -235,14 +226,32 @@ func (s *Service) sendData(peer peer.ID, hs Handshake, info *notificationsProtoc } hsData.stream = stream - info.handshakeData.Store(peer, hsData) + info.outboundHandshakeData.Store(peer, hsData) - if info.streamHandler == nil { + if info.handshakeValidator == nil { return } - go info.streamHandler(stream, peer) - return + hs, err := readHandshake(stream, decodeBlockAnnounceHandshake) + if err != nil { + logger.Trace("failed to read handshake", "protocol", info.protocolID, "peer", peer, "error", err) + _ = stream.Close() + return + } + + hsData.received = true + + err = info.handshakeValidator(peer, hs) + if err != nil { + logger.Trace("failed to validate handshake", "protocol", info.protocolID, "peer", peer, "error", err) + hsData.validated = false + info.outboundHandshakeData.Store(peer, hsData) + return + } + + hsData.validated = true + info.outboundHandshakeData.Store(peer, hsData) + logger.Trace("sender: validated handshake", "protocol", info.protocolID, "peer", peer) } if s.host.messageCache != nil { @@ -257,12 +266,7 @@ func (s *Service) sendData(peer peer.ID, hs Handshake, info *notificationsProtoc } } - if hsData.stream == nil { - logger.Error("trying to send data through empty stream", "protocol", info.protocolID, "peer", peer, "message", msg) - return - } - - // we've already completed the handshake with the peer, send message directly + // we've completed the handshake with the peer, send message directly logger.Trace("sending message", "protocol", info.protocolID, "peer", peer, "message", msg) err := s.host.writeToStream(hsData.stream, msg) @@ -271,8 +275,9 @@ func (s *Service) sendData(peer peer.ID, hs Handshake, info *notificationsProtoc } } -// gossipExcluding sends a message to each connected peer except the given peer -// Used for notifications sub-protocols to gossip a message +// broadcastExcluding sends a message to each connected peer except the given peer, +// and peers that have previously sent us the message or who we have already sent the message to. +// used for notifications sub-protocols to gossip a message func (s *Service) broadcastExcluding(info *notificationsProtocol, excluding peer.ID, msg NotificationsMessage) { logger.Trace( "broadcasting message from notifications sub-protocol", @@ -286,12 +291,7 @@ func (s *Service) broadcastExcluding(info *notificationsProtocol, excluding peer } peers := s.host.peers() - rand.Shuffle(len(peers), func(i, j int) { peers[i], peers[j] = peers[j], peers[i] }) - - info.mapMu.RLock() - defer info.mapMu.RUnlock() - - for _, peer := range peers { // TODO: check if stream is open, if not, open and send handshake + for _, peer := range peers { if peer == excluding { continue } @@ -299,3 +299,18 @@ func (s *Service) broadcastExcluding(info *notificationsProtocol, excluding peer go s.sendData(peer, hs, info, msg) } } + +func readHandshake(stream libp2pnetwork.Stream, decoder HandshakeDecoder) (Handshake, error) { + msgBytes := make([]byte, maxHandshakeSize) + tot, err := readStream(stream, msgBytes) + if err != nil { + return nil, err + } + + hs, err := decoder(msgBytes[:tot]) + if err != nil { + return nil, err + } + + return hs, nil +} diff --git a/dot/network/notifications_test.go b/dot/network/notifications_test.go index 0a4b8c1dc4..83926edaef 100644 --- a/dot/network/notifications_test.go +++ b/dot/network/notifications_test.go @@ -30,6 +30,10 @@ import ( "github.com/stretchr/testify/require" ) +func TestHandshake_SizeOf(t *testing.T) { + require.Equal(t, uint32(maxHandshakeSize), uint32(72)) +} + func TestCreateDecoder_BlockAnnounce(t *testing.T) { basePath := utils.NewTestBasePath(t, "nodeA") @@ -45,15 +49,17 @@ func TestCreateDecoder_BlockAnnounce(t *testing.T) { // create info and decoder info := ¬ificationsProtocol{ - protocolID: s.host.protocolID + blockAnnounceID, - getHandshake: s.getBlockAnnounceHandshake, - handshakeData: new(sync.Map), + protocolID: s.host.protocolID + blockAnnounceID, + getHandshake: s.getBlockAnnounceHandshake, + handshakeValidator: s.validateBlockAnnounceHandshake, + inboundHandshakeData: new(sync.Map), + outboundHandshakeData: new(sync.Map), } decoder := createDecoder(info, decodeBlockAnnounceHandshake, decodeBlockAnnounceMessage) // haven't received handshake from peer testPeerID := peer.ID("QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ") - info.handshakeData.Store(testPeerID, handshakeData{ + info.inboundHandshakeData.Store(testPeerID, handshakeData{ received: false, }) @@ -67,7 +73,7 @@ func TestCreateDecoder_BlockAnnounce(t *testing.T) { enc, err := testHandshake.Encode() require.NoError(t, err) - msg, err := decoder(enc, testPeerID) + msg, err := decoder(enc, testPeerID, true) require.NoError(t, err) require.Equal(t, testHandshake, msg) @@ -83,10 +89,10 @@ func TestCreateDecoder_BlockAnnounce(t *testing.T) { require.NoError(t, err) // set handshake data to received - hsData, _ := info.getHandshakeData(testPeerID) + hsData, _ := info.getHandshakeData(testPeerID, true) hsData.received = true - info.handshakeData.Store(testPeerID, hsData) - msg, err = decoder(enc, testPeerID) + info.inboundHandshakeData.Store(testPeerID, hsData) + msg, err = decoder(enc, testPeerID, true) require.NoError(t, err) require.Equal(t, testBlockAnnounce, msg) } @@ -133,14 +139,16 @@ func TestCreateNotificationsMessageHandler_BlockAnnounce(t *testing.T) { // create info and handler info := ¬ificationsProtocol{ - protocolID: s.host.protocolID + blockAnnounceID, - getHandshake: s.getBlockAnnounceHandshake, - handshakeData: new(sync.Map), + protocolID: s.host.protocolID + blockAnnounceID, + getHandshake: s.getBlockAnnounceHandshake, + handshakeValidator: s.validateBlockAnnounceHandshake, + inboundHandshakeData: new(sync.Map), + outboundHandshakeData: new(sync.Map), } - handler := s.createNotificationsMessageHandler(info, s.validateBlockAnnounceHandshake, s.handleBlockAnnounceMessage) + handler := s.createNotificationsMessageHandler(info, s.handleBlockAnnounceMessage) // set handshake data to received - info.handshakeData.Store(testPeerID, handshakeData{ + info.inboundHandshakeData.Store(testPeerID, handshakeData{ received: true, validated: true, }) @@ -165,11 +173,13 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T) // create info and handler info := ¬ificationsProtocol{ - protocolID: s.host.protocolID + blockAnnounceID, - getHandshake: s.getBlockAnnounceHandshake, - handshakeData: new(sync.Map), + protocolID: s.host.protocolID + blockAnnounceID, + getHandshake: s.getBlockAnnounceHandshake, + handshakeValidator: s.validateBlockAnnounceHandshake, + inboundHandshakeData: new(sync.Map), + outboundHandshakeData: new(sync.Map), } - handler := s.createNotificationsMessageHandler(info, s.validateBlockAnnounceHandshake, s.handleBlockAnnounceMessage) + handler := s.createNotificationsMessageHandler(info, s.handleBlockAnnounceMessage) configB := &Config{ BasePath: utils.NewTestBasePath(t, "nodeB"), @@ -208,7 +218,7 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T) err = handler(stream, testHandshake) require.Equal(t, errCannotValidateHandshake, err) - data, has := info.getHandshakeData(testPeerID) + data, has := info.getHandshakeData(testPeerID, true) require.True(t, has) require.True(t, data.received) require.False(t, data.validated) @@ -221,9 +231,11 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T) GenesisHash: s.blockState.GenesisHash(), } + info.inboundHandshakeData.Delete(testPeerID) + err = handler(stream, testHandshake) require.NoError(t, err) - data, has = info.getHandshakeData(testPeerID) + data, has = info.getHandshakeData(testPeerID, true) require.True(t, has) require.True(t, data.received) require.True(t, data.validated) diff --git a/dot/network/service.go b/dot/network/service.go index 747193bace..d048e237db 100644 --- a/dot/network/service.go +++ b/dot/network/service.go @@ -57,7 +57,7 @@ var ( type ( // messageDecoder is passed on readStream to decode the data from the stream into a message. // since messages are decoded based on context, this is different for every sub-protocol. - messageDecoder = func([]byte, peer.ID) (Message, error) + messageDecoder = func([]byte, peer.ID, bool) (Message, error) // messageHandler is passed on readStream to handle the resulting message. it should return an error only if the stream is to be closed messageHandler = func(stream libp2pnetwork.Stream, msg Message) error ) @@ -145,7 +145,6 @@ func NewService(cfg *Config) (*Service, error) { } network.syncQueue = newSyncQueue(network) - network.noGossip = true // TODO: remove once duplicate message sending is merged return network, err } @@ -373,35 +372,39 @@ func (s *Service) RegisterNotificationsProtocol(sub protocol.ID, } np := ¬ificationsProtocol{ - protocolID: protocolID, - getHandshake: handshakeGetter, - handshakeData: new(sync.Map), + protocolID: protocolID, + getHandshake: handshakeGetter, + handshakeValidator: handshakeValidator, + inboundHandshakeData: new(sync.Map), + outboundHandshakeData: new(sync.Map), } s.notificationsProtocols[messageID] = np connMgr := s.host.h.ConnManager().(*ConnManager) connMgr.registerCloseHandler(protocolID, func(peerID peer.ID) { - np.mapMu.Lock() - defer np.mapMu.Unlock() + if _, ok := np.getHandshakeData(peerID, true); ok { + logger.Trace( + "Cleaning up inbound handshake data", + "peer", peerID, + "protocol", protocolID, + ) + np.inboundHandshakeData.Delete(peerID) + } - if _, ok := np.getHandshakeData(peerID); ok { + if _, ok := np.getHandshakeData(peerID, false); ok { logger.Trace( - "Cleaning up handshake data", + "Cleaning up outbound handshake data", "peer", peerID, "protocol", protocolID, ) - np.handshakeData.Delete(peerID) + np.outboundHandshakeData.Delete(peerID) } }) info := s.notificationsProtocols[messageID] decoder := createDecoder(info, handshakeDecoder, messageDecoder) - handlerWithValidate := s.createNotificationsMessageHandler(info, handshakeValidator, messageHandler) - streamHandler := func(stream libp2pnetwork.Stream, peerID peer.ID) { - s.readStream(stream, peerID, decoder, handlerWithValidate) - } - np.streamHandler = streamHandler + handlerWithValidate := s.createNotificationsMessageHandler(info, messageHandler) s.host.registerStreamHandlerWithOverwrite(sub, overwriteProtocol, func(stream libp2pnetwork.Stream) { logger.Trace("received stream", "sub-protocol", sub) @@ -411,8 +414,7 @@ func (s *Service) RegisterNotificationsProtocol(sub protocol.ID, return } - p := conn.RemotePeer() - streamHandler(stream, p) + s.readStream(stream, decoder, handlerWithValidate) }) logger.Info("registered notifications sub-protocol", "protocol", protocolID) @@ -460,18 +462,10 @@ func (s *Service) SendMessage(msg NotificationsMessage) { // handleLightStream handles streams with the /light/2 protocol ID func (s *Service) handleLightStream(stream libp2pnetwork.Stream) { - conn := stream.Conn() - if conn == nil { - logger.Error("Failed to get connection from stream") - _ = stream.Close() - return - } - - peer := conn.RemotePeer() - s.readStream(stream, peer, s.decodeLightMessage, s.handleLightMsg) + s.readStream(stream, s.decodeLightMessage, s.handleLightMsg) } -func (s *Service) decodeLightMessage(in []byte, peer peer.ID) (Message, error) { +func (s *Service) decodeLightMessage(in []byte, peer peer.ID, _ bool) (Message, error) { s.lightRequestMu.RLock() defer s.lightRequestMu.RUnlock() @@ -489,10 +483,15 @@ func (s *Service) decodeLightMessage(in []byte, peer peer.ID) (Message, error) { return msg, err } -func (s *Service) readStream(stream libp2pnetwork.Stream, peer peer.ID, decoder messageDecoder, handler messageHandler) { +func isInbound(stream libp2pnetwork.Stream) bool { + return stream.Stat().Direction == libp2pnetwork.DirInbound +} + +func (s *Service) readStream(stream libp2pnetwork.Stream, decoder messageDecoder, handler messageHandler) { var ( maxMessageSize uint64 = maxBlockResponseSize // TODO: determine actual max message size msgBytes = make([]byte, maxMessageSize) + peer = stream.Conn().RemotePeer() ) for { @@ -506,7 +505,7 @@ func (s *Service) readStream(stream libp2pnetwork.Stream, peer peer.ID, decoder } // decode message based on message type - msg, err := decoder(msgBytes[:tot], peer) + msg, err := decoder(msgBytes[:tot], peer, isInbound(stream)) if err != nil { logger.Trace("failed to decode message from peer", "protocol", stream.Protocol(), "err", err) continue @@ -597,7 +596,7 @@ func (s *Service) Peers() []common.PeerInfo { s.notificationsMu.RUnlock() for _, p := range s.host.peers() { - data, has := np.getHandshakeData(p) + data, has := np.getHandshakeData(p, true) if !has || data.handshake == nil { peers = append(peers, common.PeerInfo{ PeerID: p.String(), diff --git a/dot/network/service_test.go b/dot/network/service_test.go index 31e88f37a1..20a23e1a2a 100644 --- a/dot/network/service_test.go +++ b/dot/network/service_test.go @@ -210,7 +210,7 @@ func TestBroadcastDuplicateMessage(t *testing.T) { require.NotNil(t, stream) protocol := nodeA.notificationsProtocols[BlockAnnounceMsgType] - protocol.handshakeData.Store(nodeB.host.id(), handshakeData{ + protocol.outboundHandshakeData.Store(nodeB.host.id(), handshakeData{ received: true, validated: true, stream: stream, diff --git a/dot/network/sync.go b/dot/network/sync.go index d50f1f5614..350843e061 100644 --- a/dot/network/sync.go +++ b/dot/network/sync.go @@ -42,18 +42,10 @@ func (s *Service) handleSyncStream(stream libp2pnetwork.Stream) { return } - conn := stream.Conn() - if conn == nil { - logger.Error("Failed to get connection from stream") - _ = stream.Close() - return - } - - peer := conn.RemotePeer() - s.readStream(stream, peer, s.decodeSyncMessage, s.handleSyncMessage) + s.readStream(stream, s.decodeSyncMessage, s.handleSyncMessage) } -func (s *Service) decodeSyncMessage(in []byte, peer peer.ID) (Message, error) { +func (s *Service) decodeSyncMessage(in []byte, peer peer.ID, inbound bool) (Message, error) { msg := new(BlockRequestMessage) err := msg.Decode(in) return msg, err @@ -180,6 +172,7 @@ func (q *syncQueue) syncAtHead() { } q.s.syncer.SetSyncing(true) + q.s.noGossip = true // don't gossip messages until we're at the head for { select { @@ -197,10 +190,12 @@ func (q *syncQueue) syncAtHead() { // we aren't at the head yet, sleep if curr.Number.Int64() < q.goal && curr.Number.Cmp(prev.Number) > 0 { prev = curr + q.s.noGossip = true continue } q.s.syncer.SetSyncing(false) + q.s.noGossip = false // we have received new blocks since the last check, sleep if prev.Number.Int64() < curr.Number.Int64() { diff --git a/dot/network/sync_test.go b/dot/network/sync_test.go index fba9b11e05..812ec04ca0 100644 --- a/dot/network/sync_test.go +++ b/dot/network/sync_test.go @@ -68,7 +68,7 @@ func TestDecodeSyncMessage(t *testing.T) { reqEnc, err := testBlockRequestMessage.Encode() require.NoError(t, err) - msg, err := s.decodeSyncMessage(reqEnc, testPeer) + msg, err := s.decodeSyncMessage(reqEnc, testPeer, true) require.NoError(t, err) req, ok := msg.(*BlockRequestMessage) diff --git a/dot/network/test_helpers.go b/dot/network/test_helpers.go index d185ab7dd8..65d1855997 100644 --- a/dot/network/test_helpers.go +++ b/dot/network/test_helpers.go @@ -133,7 +133,7 @@ func (s *testStreamHandler) readStream(stream libp2pnetwork.Stream, peer peer.ID } // decode message based on message type - msg, err := decoder(msgBytes[:tot], peer) + msg, err := decoder(msgBytes[:tot], peer, isInbound(stream)) if err != nil { logger.Error("Failed to decode message from peer", "peer", peer, "err", err) continue @@ -159,7 +159,7 @@ var testBlockRequestMessage = &BlockRequestMessage{ Max: optional.NewUint32(true, 1), } -func testBlockRequestMessageDecoder(in []byte, _ peer.ID) (Message, error) { +func testBlockRequestMessageDecoder(in []byte, _ peer.ID, _ bool) (Message, error) { msg := new(BlockRequestMessage) err := msg.Decode(in) return msg, err @@ -173,13 +173,13 @@ var testBlockAnnounceHandshake = &BlockAnnounceHandshake{ BestBlockNumber: 0, } -func testBlockAnnounceMessageDecoder(in []byte, _ peer.ID) (Message, error) { +func testBlockAnnounceMessageDecoder(in []byte, _ peer.ID, _ bool) (Message, error) { msg := new(BlockAnnounceMessage) err := msg.Decode(in) return msg, err } -func testBlockAnnounceHandshakeDecoder(in []byte, _ peer.ID) (Message, error) { +func testBlockAnnounceHandshakeDecoder(in []byte, _ peer.ID, _ bool) (Message, error) { msg := new(BlockAnnounceHandshake) err := msg.Decode(in) return msg, err