Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(dot/network): implement a handshake timeout #1615

Merged
merged 14 commits into from
Jun 3, 2021
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 62 additions & 25 deletions dot/network/notifications.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package network
import (
"errors"
"sync"
"time"
"unsafe"

libp2pnetwork "github.com/libp2p/go-libp2p-core/network"
Expand All @@ -29,6 +30,7 @@ import (
var errCannotValidateHandshake = errors.New("failed to validate handshake")

const maxHandshakeSize = unsafe.Sizeof(BlockAnnounceHandshake{}) //nolint
const handshakeTimeout = time.Second * 10

// Handshake is the interface all handshakes for notifications protocols must implement
type Handshake interface {
Expand All @@ -53,6 +55,11 @@ type (
NotificationsMessageHandler = func(peer peer.ID, msg NotificationsMessage) (propagate bool, err error)
)

type handshakeReader struct {
hs Handshake
err error
}

type notificationsProtocol struct {
protocolID protocol.ID
getHandshake HandshakeGetter
Expand All @@ -63,16 +70,17 @@ type notificationsProtocol struct {
}

func (n *notificationsProtocol) getHandshakeData(pid peer.ID, inbound bool) (handshakeData, bool) {
if inbound {
data, has := n.inboundHandshakeData.Load(pid)
if !has {
return handshakeData{}, false
}
var (
data interface{}
has bool
)

return data.(handshakeData), true
if inbound {
data, has = n.inboundHandshakeData.Load(pid)
} else {
data, has = n.outboundHandshakeData.Load(pid)
}

data, has := n.outboundHandshakeData.Load(pid)
if !has {
return handshakeData{}, false
}
Expand Down Expand Up @@ -174,7 +182,7 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
return nil
}

logger.Debug("received message on notifications sub-protocol", "protocol", info.protocolID,
logger.Trace("received message on notifications sub-protocol", "protocol", info.protocolID,
"message", msg,
"peer", stream.Conn().RemotePeer(),
)
Expand Down Expand Up @@ -226,14 +234,32 @@ func (s *Service) sendData(peer peer.ID, hs Handshake, info *notificationsProtoc
return
}

hs, err := s.readHandshake(stream, decodeBlockAnnounceHandshake)
if err != nil {
logger.Trace("failed to read handshake", "protocol", info.protocolID, "peer", peer, "error", err)
hsTicker := time.NewTicker(handshakeTimeout)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

User timer instead of ticker.

time.NewTimer(time.Second)

Timers are for when you want to do something once in the future - tickers are for when you want to do something repeatedly at regular intervals

var hs Handshake
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can also do

defer hsTimer.Stop()

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!


select {
case <-hsTicker.C:
hsTicker.Stop()

logger.Warn("handshake timeout reached", "protocol", info.protocolID, "peer", peer)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change to trace

_ = stream.Close()
info.outboundHandshakeData.Delete(peer)
return
}

hsData.received = true
case hsResponse := <-s.readHandshake(stream, decodeBlockAnnounceHandshake):
hsTicker.Stop()

if hsResponse.err != nil {
logger.Trace("failed to read handshake", "protocol", info.protocolID, "peer", peer, "error", err)
_ = stream.Close()

info.outboundHandshakeData.Delete(peer)
return
}

hs = hsResponse.hs
hsData.received = true
}

err = info.handshakeValidator(peer, hs)
if err != nil {
Expand Down Expand Up @@ -294,19 +320,30 @@ func (s *Service) broadcastExcluding(info *notificationsProtocol, excluding peer
}
}

func (s *Service) readHandshake(stream libp2pnetwork.Stream, decoder HandshakeDecoder) (Handshake, error) {
msgBytes := s.bufPool.get()
defer s.bufPool.put(&msgBytes)
func (s *Service) readHandshake(stream libp2pnetwork.Stream, decoder HandshakeDecoder) <-chan *handshakeReader {
hsC := make(chan *handshakeReader, 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious why is this buffered channel? The code will still work if it's an unbuffered channel.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As the readHandshake function return the bytes of the handshake response at once I just fixed this in the channel capacity, but yes it works with an unbuffered channel (a simple example with an unbuffered channel: https://play.golang.org/p/Y19oZxihFxD)


tot, err := readStream(stream, msgBytes[:])
if err != nil {
return nil, err
}
go func() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we close hsC on errors?

msgBytes := s.bufPool.get()
defer func() {
s.bufPool.put(&msgBytes)
close(hsC)
}()

hs, err := decoder(msgBytes[:tot])
if err != nil {
return nil, err
}
tot, err := readStream(stream, msgBytes[:])
if err != nil {
hsC <- &handshakeReader{hs: nil, err: err}
return
}

hs, err := decoder(msgBytes[:tot])
if err != nil {
hsC <- &handshakeReader{hs: nil, err: err}
return
}

hsC <- &handshakeReader{hs: hs, err: nil}
}()

return hs, nil
return hsC
}
82 changes: 82 additions & 0 deletions dot/network/notifications_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package network

import (
"context"
"fmt"
"math/big"
"sync"
"testing"
Expand All @@ -25,7 +27,10 @@ import (
"github.com/ChainSafe/gossamer/dot/types"
"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/lib/utils"
ma "github.com/multiformats/go-multiaddr"

"github.com/libp2p/go-libp2p"
libp2pnetwork "github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -240,3 +245,80 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T)
require.True(t, data.received)
require.True(t, data.validated)
}

func Test_HandshakeTimeout(t *testing.T) {
// create service A
config := &Config{
BasePath: utils.NewTestBasePath(t, "nodeA"),
Port: 7001,
RandSeed: 1,
NoBootstrap: true,
NoMDNS: true,
}
ha := createTestService(t, config)

// create info and handler
info := &notificationsProtocol{
protocolID: ha.host.protocolID + blockAnnounceID,
getHandshake: ha.getBlockAnnounceHandshake,
handshakeValidator: ha.validateBlockAnnounceHandshake,
inboundHandshakeData: new(sync.Map),
outboundHandshakeData: new(sync.Map),
}

// creating host b with will never respond to a handshake
addrB, err := ma.NewMultiaddr(fmt.Sprintf("/ip4/0.0.0.0/tcp/%d", 7002))
require.NoError(t, err)

hb, err := libp2p.New(
context.Background(), libp2p.ListenAddrs(addrB),
)
require.NoError(t, err)

testHandshakeMsg := &BlockAnnounceHandshake{
Roles: 4,
BestBlockNumber: 77,
BestBlockHash: common.Hash{1},
GenesisHash: common.Hash{2},
}

hb.SetStreamHandler(info.protocolID, func(stream libp2pnetwork.Stream) {
fmt.Println("never respond a handshake message")
})

addrBInfo := peer.AddrInfo{
ID: hb.ID(),
Addrs: hb.Addrs(),
}

err = ha.host.connect(addrBInfo)
if failedToDial(err) {
time.Sleep(TestBackoffTimeout)
err = ha.host.connect(addrBInfo)
}
require.NoError(t, err)

go ha.sendData(hb.ID(), testHandshakeMsg, info, nil)

time.Sleep(handshakeTimeout / 2)
// peer should be stored in handshake data until timeout
_, ok := info.outboundHandshakeData.Load(hb.ID())
require.True(t, ok)

// a stream should be open until timeout
connAToB := ha.host.h.Network().ConnsToPeer(hb.ID())
require.Len(t, connAToB, 1)
require.Len(t, connAToB[0].GetStreams(), 1)

// after the timeout
time.Sleep(handshakeTimeout)

// handshake data should be removed
_, ok = info.outboundHandshakeData.Load(hb.ID())
require.False(t, ok)

// stream should be closed
connAToB = ha.host.h.Network().ConnsToPeer(hb.ID())
require.Len(t, connAToB, 1)
require.Len(t, connAToB[0].GetStreams(), 0)
}
4 changes: 2 additions & 2 deletions dot/network/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func readStream(stream libp2pnetwork.Stream, buf []byte) (int, error) {
if err == io.EOF {
return 0, err
} else if err != nil {
return 0, err // TODO: return bytes read from readLEB128ToUint64
return int(length), err
}

if length == 0 {
Expand Down Expand Up @@ -220,7 +220,7 @@ func readStream(stream libp2pnetwork.Stream, buf []byte) (int, error) {
}

if tot != int(length) {
return tot, fmt.Errorf("failed to read entire message: expected %d bytes", length)
return tot, fmt.Errorf("failed to read entire message: expected %d bytes, received %d bytes", length, tot)
}

return tot, nil
Expand Down