Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(dot/network): implement a handshake timeout #1615

Merged
merged 14 commits into from
Jun 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 59 additions & 25 deletions dot/network/notifications.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package network
import (
"errors"
"sync"
"time"
"unsafe"

libp2pnetwork "github.com/libp2p/go-libp2p-core/network"
Expand All @@ -29,6 +30,7 @@ import (
var errCannotValidateHandshake = errors.New("failed to validate handshake")

const maxHandshakeSize = unsafe.Sizeof(BlockAnnounceHandshake{}) //nolint
const handshakeTimeout = time.Second * 10

// Handshake is the interface all handshakes for notifications protocols must implement
type Handshake interface {
Expand All @@ -53,6 +55,11 @@ type (
NotificationsMessageHandler = func(peer peer.ID, msg NotificationsMessage) (propagate bool, err error)
)

type handshakeReader struct {
hs Handshake
err error
}

type notificationsProtocol struct {
protocolID protocol.ID
getHandshake HandshakeGetter
Expand All @@ -63,16 +70,17 @@ type notificationsProtocol struct {
}

func (n *notificationsProtocol) getHandshakeData(pid peer.ID, inbound bool) (handshakeData, bool) {
if inbound {
data, has := n.inboundHandshakeData.Load(pid)
if !has {
return handshakeData{}, false
}
var (
data interface{}
has bool
)

return data.(handshakeData), true
if inbound {
data, has = n.inboundHandshakeData.Load(pid)
} else {
data, has = n.outboundHandshakeData.Load(pid)
}

data, has := n.outboundHandshakeData.Load(pid)
if !has {
return handshakeData{}, false
}
Expand Down Expand Up @@ -174,7 +182,7 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
return nil
}

logger.Debug("received message on notifications sub-protocol", "protocol", info.protocolID,
logger.Trace("received message on notifications sub-protocol", "protocol", info.protocolID,
"message", msg,
"peer", stream.Conn().RemotePeer(),
)
Expand Down Expand Up @@ -226,14 +234,29 @@ func (s *Service) sendData(peer peer.ID, hs Handshake, info *notificationsProtoc
return
}

hs, err := s.readHandshake(stream, decodeBlockAnnounceHandshake)
if err != nil {
logger.Trace("failed to read handshake", "protocol", info.protocolID, "peer", peer, "error", err)
hsTimer := time.NewTimer(handshakeTimeout)

var hs Handshake
select {
case <-hsTimer.C:
logger.Trace("handshake timeout reached", "protocol", info.protocolID, "peer", peer)
_ = stream.Close()
info.outboundHandshakeData.Delete(peer)
return
}

hsData.received = true
case hsResponse := <-s.readHandshake(stream, decodeBlockAnnounceHandshake):
hsTimer.Stop()
if hsResponse.err != nil {
logger.Trace("failed to read handshake", "protocol", info.protocolID, "peer", peer, "error", err)
_ = stream.Close()

info.outboundHandshakeData.Delete(peer)
return
}

hs = hsResponse.hs
hsData.received = true
}

err = info.handshakeValidator(peer, hs)
if err != nil {
Expand Down Expand Up @@ -294,19 +317,30 @@ func (s *Service) broadcastExcluding(info *notificationsProtocol, excluding peer
}
}

func (s *Service) readHandshake(stream libp2pnetwork.Stream, decoder HandshakeDecoder) (Handshake, error) {
msgBytes := s.bufPool.get()
defer s.bufPool.put(&msgBytes)
func (s *Service) readHandshake(stream libp2pnetwork.Stream, decoder HandshakeDecoder) <-chan *handshakeReader {
hsC := make(chan *handshakeReader)

tot, err := readStream(stream, msgBytes[:])
if err != nil {
return nil, err
}
go func() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we close hsC on errors?

msgBytes := s.bufPool.get()
defer func() {
s.bufPool.put(&msgBytes)
close(hsC)
}()

hs, err := decoder(msgBytes[:tot])
if err != nil {
return nil, err
}
tot, err := readStream(stream, msgBytes[:])
if err != nil {
hsC <- &handshakeReader{hs: nil, err: err}
return
}

hs, err := decoder(msgBytes[:tot])
if err != nil {
hsC <- &handshakeReader{hs: nil, err: err}
return
}

hsC <- &handshakeReader{hs: hs, err: nil}
}()

return hs, nil
return hsC
}
82 changes: 82 additions & 0 deletions dot/network/notifications_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package network

import (
"context"
"fmt"
"math/big"
"sync"
"testing"
Expand All @@ -25,7 +27,10 @@ import (
"github.com/ChainSafe/gossamer/dot/types"
"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/lib/utils"
ma "github.com/multiformats/go-multiaddr"

"github.com/libp2p/go-libp2p"
libp2pnetwork "github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -240,3 +245,80 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T)
require.True(t, data.received)
require.True(t, data.validated)
}

func Test_HandshakeTimeout(t *testing.T) {
// create service A
config := &Config{
BasePath: utils.NewTestBasePath(t, "nodeA"),
Port: 7001,
RandSeed: 1,
NoBootstrap: true,
NoMDNS: true,
}
ha := createTestService(t, config)

// create info and handler
info := &notificationsProtocol{
protocolID: ha.host.protocolID + blockAnnounceID,
getHandshake: ha.getBlockAnnounceHandshake,
handshakeValidator: ha.validateBlockAnnounceHandshake,
inboundHandshakeData: new(sync.Map),
outboundHandshakeData: new(sync.Map),
}

// creating host b with will never respond to a handshake
addrB, err := ma.NewMultiaddr(fmt.Sprintf("/ip4/0.0.0.0/tcp/%d", 7002))
require.NoError(t, err)

hb, err := libp2p.New(
context.Background(), libp2p.ListenAddrs(addrB),
)
require.NoError(t, err)

testHandshakeMsg := &BlockAnnounceHandshake{
Roles: 4,
BestBlockNumber: 77,
BestBlockHash: common.Hash{1},
GenesisHash: common.Hash{2},
}

hb.SetStreamHandler(info.protocolID, func(stream libp2pnetwork.Stream) {
fmt.Println("never respond a handshake message")
})

addrBInfo := peer.AddrInfo{
ID: hb.ID(),
Addrs: hb.Addrs(),
}

err = ha.host.connect(addrBInfo)
if failedToDial(err) {
time.Sleep(TestBackoffTimeout)
err = ha.host.connect(addrBInfo)
}
require.NoError(t, err)

go ha.sendData(hb.ID(), testHandshakeMsg, info, nil)

time.Sleep(handshakeTimeout / 2)
// peer should be stored in handshake data until timeout
_, ok := info.outboundHandshakeData.Load(hb.ID())
require.True(t, ok)

// a stream should be open until timeout
connAToB := ha.host.h.Network().ConnsToPeer(hb.ID())
require.Len(t, connAToB, 1)
require.Len(t, connAToB[0].GetStreams(), 1)

// after the timeout
time.Sleep(handshakeTimeout)

// handshake data should be removed
_, ok = info.outboundHandshakeData.Load(hb.ID())
require.False(t, ok)

// stream should be closed
connAToB = ha.host.h.Network().ConnsToPeer(hb.ID())
require.Len(t, connAToB, 1)
require.Len(t, connAToB[0].GetStreams(), 0)
}
2 changes: 1 addition & 1 deletion dot/network/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ func readStream(stream libp2pnetwork.Stream, buf []byte) (int, error) {
}

if tot != int(length) {
return tot, fmt.Errorf("failed to read entire message: expected %d bytes", length)
return tot, fmt.Errorf("failed to read entire message: expected %d bytes, received %d bytes", length, tot)
}

return tot, nil
Expand Down