diff --git a/dkg/dkg.go b/dkg/dkg.go
index 1bc31bff9..73d1416fd 100644
--- a/dkg/dkg.go
+++ b/dkg/dkg.go
@@ -17,22 +17,24 @@ package dkg

 import (
 	"context"
+	"crypto/ecdsa"
 	crand "crypto/rand"
 	"encoding/base64"
 	"fmt"
-	"time"

 	"github.com/coinbase/kryptology/pkg/signatures/bls/bls_sig"
 	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/crypto"
+	libp2pcrypto "github.com/libp2p/go-libp2p-core/crypto"
 	"github.com/libp2p/go-libp2p-core/host"
 	"github.com/libp2p/go-libp2p-core/peer"
-	"github.com/libp2p/go-libp2p/p2p/protocol/ping"

 	"github.com/obolnetwork/charon/app/errors"
 	"github.com/obolnetwork/charon/app/log"
 	"github.com/obolnetwork/charon/app/z"
 	"github.com/obolnetwork/charon/cluster"
 	"github.com/obolnetwork/charon/core"
+	"github.com/obolnetwork/charon/dkg/sync"
 	"github.com/obolnetwork/charon/eth2util/deposit"
 	"github.com/obolnetwork/charon/p2p"
 	"github.com/obolnetwork/charon/tbls"
@@ -84,7 +86,12 @@ func Run(ctx context.Context, conf Config) (err error) {
 		return err
 	}

-	tcpNode, shutdown, err := setupP2P(ctx, conf.DataDir, conf.P2P, peers)
+	key, err := p2p.LoadPrivKey(conf.DataDir)
+	if err != nil {
+		return err
+	}
+
+	tcpNode, shutdown, err := setupP2P(ctx, key, conf.P2P, peers)
 	if err != nil {
 		return err
 	}
@@ -108,6 +115,26 @@ func Run(ctx context.Context, conf Config) (err error) {

 	ex := newExchanger(tcpNode, nodeIdx.PeerIdx, peerIds, def.NumValidators)

+	// Register Frost libp2p handlers
+	peerMap := make(map[uint32]peer.ID)
+	for _, p := range peers {
+		nodeIdx, err := def.NodeIdx(p.ID)
+		if err != nil {
+			return err
+		}
+		peerMap[uint32(nodeIdx.ShareIdx)] = p.ID
+	}
+	tp := newFrostP2P(ctx, tcpNode, peerMap, clusterID)
+
+	log.Info(ctx, "Connecting to peers...", z.Str("definition_hash", clusterID))
+
+	stopSync, err := startSyncProtocol(ctx, tcpNode, key, defHash, peerIds, cancel)
+	if err != nil {
+		return err
+	}
+
+	log.Info(ctx, "Starting DKG ceremony")
+
 	var shares []share
 	switch def.DKGAlgorithm {
 	case "default", "keycast":
@@ -122,28 +149,6 @@ func Run(ctx context.Context, conf Config) (err error) {
 			return err
 		}
 	case "frost":
-		// Construct peer map
-		peerMap := make(map[uint32]peer.ID)
-		for _, p := range peers {
-			nodeIdx, err := def.NodeIdx(p.ID)
-			if err != nil {
-				return err
-			}
-			peerMap[uint32(nodeIdx.ShareIdx)] = p.ID
-		}
-
-		tp := newFrostP2P(ctx, tcpNode, peerMap, clusterID)
-
-		log.Info(ctx, "Connecting to peers...", z.Str("definition_hash", clusterID))
-
-		ctx, cancel, err = waitPeers(ctx, tcpNode, peers)
-		if err != nil {
-			return err
-		}
-		defer cancel()
-
-		log.Info(ctx, "Starting Frost DKG ceremony")
-
 		shares, err = runFrostParallel(ctx, tp, uint32(def.NumValidators), uint32(len(peerMap)), uint32(def.Threshold),
 			uint32(nodeIdx.ShareIdx), clusterID)
 		if err != nil {
@@ -167,6 +172,10 @@ func Run(ctx context.Context, conf Config) (err error) {
 	}
 	log.Debug(ctx, "Aggregated deposit data signatures")

+	if err = stopSync(ctx); err != nil {
+		return errors.Wrap(err, "stop sync")
+	}
+
 	// Write keystores, deposit data and cluster lock files after exchange of partial signatures in order
 	// to prevent partial data writes in case of peer connection lost

@@ -191,12 +200,7 @@ func Run(ctx context.Context, conf Config) (err error) {
 }

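The next hunk narrows setupP2P to accept the ECDSA key directly instead of a data directory, so Run loads the key once and reuses it for the sync-protocol handshake added above. For reference, the key doubles as the libp2p host identity; converting it takes the same two calls startSyncProtocol makes (the helper name below is illustrative, not part of this change):

```go
package sketch

import (
	"crypto/ecdsa"

	ethcrypto "github.com/ethereum/go-ethereum/crypto"
	libp2pcrypto "github.com/libp2p/go-libp2p-core/crypto"
)

// libp2pIdentity converts charon's ENR ECDSA key into the libp2p secp256k1
// private key that identifies the host and signs the definition hash.
func libp2pIdentity(key *ecdsa.PrivateKey) (libp2pcrypto.PrivKey, error) {
	return libp2pcrypto.UnmarshalSecp256k1PrivateKey(ethcrypto.FromECDSA(key))
}
```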
 // setupP2P returns a started libp2p tcp node and a shutdown function.
-func setupP2P(ctx context.Context, datadir string, p2pConf p2p.Config, peers []p2p.Peer) (host.Host, func(), error) {
-	key, err := p2p.LoadPrivKey(datadir)
-	if err != nil {
-		return nil, nil, err
-	}
-
+func setupP2P(ctx context.Context, key *ecdsa.PrivateKey, p2pConf p2p.Config, peers []p2p.Peer) (host.Host, func(), error) {
 	localEnode, db, err := p2p.NewLocalEnode(p2pConf, key)
 	if err != nil {
 		return nil, nil, errors.Wrap(err, "failed to open enode")
@@ -231,9 +235,6 @@ func setupP2P(ctx context.Context, datadir string, p2pConf p2p.Config, peers []p
 		}(relay)
 	}

-	// Register ping service handler
-	_ = ping.NewPingService(tcpNode)
-
 	return tcpNode, func() {
 		db.Close()
 		udpNode.Close()
@@ -241,6 +242,69 @@ func setupP2P(ctx context.Context, datadir string, p2pConf p2p.Config, peers []p
 	}, nil
 }

+// startSyncProtocol starts a sync protocol server and a client for each peer. It blocks until all peers
+// are connected and then returns a shutdown function.
+func startSyncProtocol(ctx context.Context, tcpNode host.Host, key *ecdsa.PrivateKey, defHash [32]byte, peerIDs []peer.ID,
+	onFailure func(),
+) (func(context.Context) error, error) {
+	// Sign definition hash with charon-enr-private-key
+	priv, err := libp2pcrypto.UnmarshalSecp256k1PrivateKey(crypto.FromECDSA(key))
+	if err != nil {
+		return nil, errors.Wrap(err, "convert key")
+	}
+
+	hashSig, err := priv.Sign(defHash[:])
+	if err != nil {
+		return nil, errors.Wrap(err, "sign definition hash")
+	}
+
+	server := sync.NewServer(tcpNode, len(peerIDs)-1, defHash[:])
+	server.Start(ctx)
+
+	var clients []*sync.Client
+	for _, pID := range peerIDs {
+		if tcpNode.ID() == pID {
+			continue
+		}
+
+		ctx := log.WithCtx(ctx, z.Str("peer", p2p.PeerName(pID)))
+		client := sync.NewClient(tcpNode, pID, hashSig)
+		clients = append(clients, client)
+
+		go func() {
+			err := client.Run(ctx)
+			if err != nil {
+				log.Error(ctx, "Sync failed to peer", err)
+				onFailure()
+			}
+		}()
+	}
+
+	for _, client := range clients {
+		err := client.AwaitConnected(ctx)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	err = server.AwaitAllConnected(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	// The shutdown function stops all clients and awaits server shutdown.
+	return func(ctx context.Context) error {
+		for _, client := range clients {
+			err := client.Shutdown(ctx)
+			if err != nil {
+				return err
+			}
+		}
+
+		return server.AwaitAllShutdown(ctx)
+	}, nil
+}
+
 // signAndAggLockHash returns cluster lock file with aggregated signature after signing, exchange and aggregation of partial signatures.
 func signAndAggLockHash(ctx context.Context, shares []share, def cluster.Definition, nodeIdx cluster.NodeIdx, ex *exchanger) (cluster.Lock, error) {
 	dvs, err := dvsFromShares(shares)
@@ -504,93 +568,6 @@ func dvsFromShares(shares []share) ([]cluster.DistValidator, error) {
 	return dvs, nil
 }

-// waitPeers blocks until all peers are connected and returns a context that is cancelled when
-// any connection is lost afterwards or when the parent context is cancelled.
-func waitPeers(ctx context.Context, tcpNode host.Host, peers []p2p.Peer) (context.Context, context.CancelFunc, error) {
-	ctx, cancel := context.WithCancel(ctx)
-
-	type tuple struct {
-		Peer peer.ID
-		RTT  time.Duration
-	}
-
-	var (
-		tuples = make(chan tuple, len(peers))
-		total  int
-	)
-	for _, p := range peers {
-		if tcpNode.ID() == p.ID {
-			continue // Do not connect to self.
-		}
-		total++
-		go func(pID peer.ID) {
-			for {
-				results, rtt, ok := waitConnect(ctx, tcpNode, pID)
-				if ctx.Err() != nil {
-					return
-				} else if !ok {
-					continue
-				}
-
-				// We are connected
-				tuples <- tuple{Peer: pID, RTT: rtt}
-
-				// Wait for disconnect and cancel the context.
-				var err error
-				for result := range results {
-					if result.Error != nil {
-						err = result.Error
-						break
-					}
-				}
-
-				if ctx.Err() == nil {
-					log.Error(ctx, "Peer connection lost", err, z.Str("peer", p2p.PeerName(pID)))
-					cancel()
-				}
-
-				return
-			}
-		}(p.ID)
-	}
-
-	var i int
-	for {
-		select {
-		case <-ctx.Done():
-			return ctx, cancel, ctx.Err()
-		case <-time.After(time.Second * 30):
-			log.Info(ctx, fmt.Sprintf("Connected to %d of %d peers", i, total))
-		case tuple := <-tuples:
-			i++
-			log.Info(ctx, fmt.Sprintf("Connected to peer %d of %d", i, total),
-				z.Str("peer", p2p.PeerName(tuple.Peer)),
-				z.Str("rtt", tuple.RTT.String()),
-			)
-			if i == total {
-				return ctx, cancel, nil
-			}
-		}
-	}
-}
-
-// waitConnect blocks until a libp2p connection (ping) is established returning the ping result chan, with the peer or the context is cancelled.
-func waitConnect(ctx context.Context, tcpNode host.Host, p peer.ID) (<-chan ping.Result, time.Duration, bool) {
-	resp := ping.Ping(ctx, tcpNode, p)
-	for result := range resp {
-		if result.Error == nil {
-			return resp, result.RTT, true
-		} else if ctx.Err() != nil {
-			return nil, 0, false
-		}
-
-		log.Debug(ctx, "Failed connecting to peer (will retry)", z.Str("peer", p2p.PeerName(p)), z.Err(result.Error))
-		time.Sleep(time.Second * 5) // TODO(corver): Improve backoff.
-	}
-
-	return nil, 0, false
-}
-
 func forkVersionToNetwork(forkVersion string) (string, error) {
 	switch forkVersion {
 	case "0x00001020":
diff --git a/dkg/sync/client.go b/dkg/sync/client.go
index f37390128..bfc80db36 100644
--- a/dkg/sync/client.go
+++ b/dkg/sync/client.go
@@ -16,13 +16,13 @@
 package sync

 import (
-	"bufio"
 	"context"
+	"sync"
 	"time"

 	"github.com/libp2p/go-libp2p-core/host"
 	"github.com/libp2p/go-libp2p-core/network"
-	"google.golang.org/protobuf/proto"
+	"github.com/libp2p/go-libp2p-core/peer"
 	"google.golang.org/protobuf/types/known/timestamppb"

 	"github.com/obolnetwork/charon/app/errors"
@@ -32,168 +32,193 @@ import (
 	"github.com/obolnetwork/charon/p2p"
 )

-type result struct {
-	rtt       time.Duration
-	timestamp string
-	shutdown  bool
-	error     error
+// NewClient returns a new Client instance.
+func NewClient(tcpNode host.Host, peer peer.ID, hashSig []byte) *Client {
+	return &Client{
+		tcpNode:  tcpNode,
+		peer:     peer,
+		hashSig:  hashSig,
+		shutdown: make(chan struct{}),
+		done:     make(chan struct{}),
+	}
 }

+// Client is the client side of the sync protocol. It retries establishing a connection to the sync server,
+// sends periodic pings (including the definition hash signature),
+// reestablishes the connection when relay circuits are recycled, and supports soft shutdown.
 type Client struct {
-	ctx       context.Context
-	onFailure func()
+	mu        sync.Mutex
+	connected bool
+	shutdown  chan struct{}
+	done      chan struct{}
+	hashSig   []byte
 	tcpNode   host.Host
-	server    p2p.Peer
-	results   chan result
-	stream    network.Stream
 }

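The hashSig passed to NewClient above is the node's signature over the cluster definition hash, and the server verifies it against the public key recovered from the client's peer ID; this is what ties a sync connection to one specific cluster definition. A sketch of both halves using go-libp2p-core types (function names illustrative):

```go
package sketch

import (
	"errors"

	libp2pcrypto "github.com/libp2p/go-libp2p-core/crypto"
	"github.com/libp2p/go-libp2p-core/peer"
)

// signDefHash is the client half: sign the 32-byte definition hash.
func signDefHash(priv libp2pcrypto.PrivKey, defHash []byte) ([]byte, error) {
	return priv.Sign(defHash)
}

// verifyDefHash is the server half: recover the peer's public key from its ID
// and check the signature over the definition hash.
func verifyDefHash(pID peer.ID, defHash, hashSig []byte) error {
	pubkey, err := pID.ExtractPublicKey()
	if err != nil {
		return err
	}

	ok, err := pubkey.Verify(defHash, hashSig)
	if err != nil {
		return err
	} else if !ok {
		return errors.New("invalid signature") // Mismatching cluster definition hash.
	}

	return nil
}
```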
-// AwaitConnected blocks until the connection with the server has been established or returns an error.
-func (c *Client) AwaitConnected() error {
-	for res := range c.results {
-		if errors.Is(res.error, errors.New(InvalidSig)) {
-			return errors.New("invalid cluster definition")
-		} else if res.error == nil {
-			// We are connected
-			break
-		}
-	}
+// Run blocks while running the client-side sync protocol. It returns an error if the context is closed
+// or if an established connection is dropped. It returns nil after successful Shutdown.
+func (c *Client) Run(ctx context.Context) error {
+	defer close(c.done)
+	defer c.clearConnected()

-	log.Info(c.ctx, "Client connected to Server 🎉", z.Any("client", p2p.PeerName(c.tcpNode.ID())))
+	ctx = log.WithCtx(ctx, z.Str("peer", p2p.PeerName(c.peer)))

-	return nil
-}
+	for {
+		retry := !c.isConnected() // Retry connecting if never connected.

-// Shutdown sends a shutdown message to the server indicating it has successfully completed.
-// It closes the connection and returns after receiving the subsequent MsgSyncResponse.
-// It may only be called after AwaitConnected.
-func (c *Client) Shutdown() error {
-	msg := &pb.MsgSync{
-		Timestamp: timestamppb.Now(),
-		Shutdown:  true,
-	}
+		stream, err := c.connect(ctx, retry)
+		if err != nil {
+			return err
+		}

-	_, err := c.send(msg)
-	if err != nil {
-		return err
-	}
+		log.Info(ctx, "Connected to peer (outbound)")
+		c.setConnected()

-	log.Info(c.ctx, "Closing stream with peer", z.Any("peer", p2p.PeerName(c.server.ID)))
+		reconnect, err := c.sendMsgs(ctx, stream)
+		if err != nil {
+			return err
+		} else if reconnect {
+			log.Debug(ctx, "Outgoing sync dropped, reconnecting")
+			continue
+		}

-	return c.stream.Close()
+		return nil
+	}
 }

-// sendHashSignature sends MsgSync with signature of definition to server and receives response from server.
-func (c *Client) sendHashSignature(hashSig []byte) result {
-	before := time.Now()
-	msg := &pb.MsgSync{
-		Timestamp:     timestamppb.Now(),
-		HashSignature: hashSig,
-		Shutdown:      false,
+// AwaitConnected blocks until the connection with the server has been established or returns a context error.
+func (c *Client) AwaitConnected(ctx context.Context) error {
+	timer := time.NewTicker(time.Millisecond)
+	defer timer.Stop()
+
+	for {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-timer.C:
+			if c.isConnected() {
+				return nil
+			}
+		}
 	}
+}

-	resp, err := c.send(msg)
-	if err != nil {
-		return result{error: err}
-	}
+// Shutdown triggers the Run goroutine to shut down gracefully and returns nil after it has returned.
+// It should be called after AwaitConnected and may only be called once.
+func (c *Client) Shutdown(ctx context.Context) error {
+	close(c.shutdown)

-	return result{
-		rtt:       time.Since(before),
-		timestamp: resp.SyncTimestamp.String(),
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-c.done:
+		return nil
 	}
 }

-func (c *Client) send(msg *pb.MsgSync) (*pb.MsgSyncResponse, error) {
-	wb, err := proto.Marshal(msg)
-	if err != nil {
-		return nil, errors.Wrap(err, "marshal msg")
-	}
+// setConnected sets the shared connected state.
+func (c *Client) setConnected() {
+	c.mu.Lock()
+	defer c.mu.Unlock()

-	if _, err = c.stream.Write(wb); err != nil {
-		return nil, errors.Wrap(err, "write msg to stream")
-	}
+	c.connected = true
+}

-	buf := bufio.NewReader(c.stream)
-	rb := make([]byte, MsgSize)
-	// n is the number of bytes read from buffer, if n < MsgSize the other bytes will be 0
-	n, err := buf.Read(rb)
-	if err != nil {
-		return nil, errors.Wrap(err, "read server response")
-	}
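Run, AwaitConnected and Shutdown together define the client lifecycle that startSyncProtocol in dkg/dkg.go drives: Run in a goroutine, a blocking wait for the first connection, and a graceful shutdown once partial signatures are exchanged. Condensed into one hypothetical helper:

```go
package sketch

import (
	"context"

	"github.com/libp2p/go-libp2p-core/host"
	"github.com/libp2p/go-libp2p-core/peer"

	"github.com/obolnetwork/charon/dkg/sync"
)

// runSyncClient mirrors how startSyncProtocol drives one client:
// run it in a goroutine, await the first connection, later trigger shutdown.
func runSyncClient(ctx context.Context, tcpNode host.Host, pID peer.ID, hashSig []byte,
	ceremony func() error, onFailure func(),
) error {
	client := sync.NewClient(tcpNode, pID, hashSig)

	go func() {
		if err := client.Run(ctx); err != nil { // Returns nil after Shutdown.
			onFailure()
		}
	}()

	if err := client.AwaitConnected(ctx); err != nil {
		return err
	}

	if err := ceremony(); err != nil { // The DKG ceremony runs in between.
		return err
	}

	return client.Shutdown(ctx) // Sends a final msg with Shutdown=true.
}
```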
+// clearConnected clears the shared connected state.
+func (c *Client) clearConnected() {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
-	// The first `n` bytes that are read are the most important
-	rb = rb[:n]
+	c.connected = false
+}

-	resp := new(pb.MsgSyncResponse)
-	if err = proto.Unmarshal(rb, resp); err != nil {
-		return nil, errors.Wrap(err, "unmarshal server response")
-	} else if resp.Error != "" {
-		return nil, errors.New(resp.Error)
-	}
+// isConnected returns the shared connected state.
+func (c *Client) isConnected() bool {
+	c.mu.Lock()
+	defer c.mu.Unlock()

-	return resp, nil
+	return c.connected
 }

-// NewClient starts a goroutine that establishes a long lived connection to a p2p server and returns a new Client instance.
-// TODO(dhruv): call onFailure on permanent failure.
-func NewClient(ctx context.Context, tcpNode host.Host, server p2p.Peer, hashSig []byte, onFailure func()) *Client {
-	s, err := tcpNode.NewStream(ctx, server.ID, syncProtoID)
-	if err != nil {
-		log.Error(ctx, "Open new stream with server", err)
-		ch := make(chan result, 1)
-		ch <- result{error: err}
-		close(ch)
-
-		return &Client{
-			ctx:       ctx,
-			onFailure: onFailure,
-			tcpNode:   tcpNode,
-			server:    server,
-			results:   ch,
+// sendMsgs sends periodic sync protocol messages on the stream until an error occurs or shutdown is triggered.
+func (c *Client) sendMsgs(ctx context.Context, stream network.Stream) (bool, error) {
+	timer := time.NewTicker(time.Second)
+	defer timer.Stop()
+
+	first := make(chan struct{}, 1)
+	first <- struct{}{}
+
+	var shutdown bool
+
+	for {
+		select {
+		case <-ctx.Done():
+			return false, ctx.Err()
+		case <-c.shutdown:
+			shutdown = true
+		case <-first:
+		case <-timer.C:
 		}
+
+		resp, err := c.sendMsg(stream, shutdown)
+		if isRelayError(err) {
+			return true, nil // Reconnect on relay errors
+		} else if err != nil {
+			return false, err
+		} else if shutdown {
+			return false, nil
+		} else if resp.Error == errInvalidSig {
+			return false, errors.New("mismatching cluster definition hash with peer")
+		} else if resp.Error != "" {
+			return false, errors.New("peer responded with error", z.Str("error_message", resp.Error))
+		}
+
+		rtt := time.Since(resp.SyncTimestamp.AsTime())
+		c.tcpNode.Peerstore().RecordLatency(c.peer, rtt)
 	}
+}

-	ctx, cancel := context.WithCancel(ctx)
-	out := make(chan result)
+// sendMsg sends a sync message and returns the response.
+func (c *Client) sendMsg(stream network.Stream, shutdown bool) (*pb.MsgSyncResponse, error) {
+	msg := &pb.MsgSync{
+		Timestamp:     timestamppb.Now(),
+		HashSignature: c.hashSig,
+		Shutdown:      shutdown,
+	}

-	client := &Client{
-		ctx:       ctx,
-		onFailure: onFailure,
-		tcpNode:   tcpNode,
-		server:    server,
-		results:   out,
-		stream:    s,
+	if err := writeSizedProto(stream, msg); err != nil {
+		return nil, err
 	}

-	go func() {
-		defer close(out)
-		defer cancel()
+	resp := new(pb.MsgSyncResponse)
+	if err := readSizedProto(stream, resp); err != nil {
+		return nil, err
+	}

-		for ctx.Err() == nil {
-			res := client.sendHashSignature(hashSig)
+	return resp, nil
+}

-			if ctx.Err() != nil {
-				return
+// connect opens a new libp2p stream to the server peer, retrying if instructed.
+func (c *Client) connect(ctx context.Context, retry bool) (network.Stream, error) {
+	for {
+		s, err := c.tcpNode.NewStream(network.WithUseTransient(ctx, "sync"), c.peer, protocolID)
+		if ctx.Err() != nil {
+			return nil, ctx.Err()
+		} else if err != nil {
+			if retry {
+				continue
 			}

-			if res.error == nil {
-				tcpNode.Peerstore().RecordLatency(server.ID, res.rtt)
-			}
+			return nil, errors.Wrap(err, "open connection")
+		}

-			log.Debug(ctx, "Server response", z.Any("response", res.timestamp))
+		return s, nil
+	}
+}

-			select {
-			case out <- res:
-			case <-ctx.Done():
-				return
-			}
-		}
-	}()
-	go func() {
-		<-ctx.Done()
-		//nolint:errcheck
-		s.Reset()
-	}()
-
-	return client
+// isRelayError returns true if the error is due to temporary relay circuit recycling.
+func isRelayError(_ error) bool {
+	// TODO(corver): Detect circuit relay connection errors
+	return false
 }
diff --git a/dkg/sync/server.go b/dkg/sync/server.go
index fb655fe31..bd3ef6136 100644
--- a/dkg/sync/server.go
+++ b/dkg/sync/server.go
@@ -18,8 +18,9 @@
 package sync

 import (
-	"bufio"
 	"context"
+	"encoding/binary"
+	"io"
 	"sync"
 	"time"
@@ -28,6 +29,7 @@ import (
 	"github.com/libp2p/go-libp2p-core/peer"
 	"google.golang.org/protobuf/proto"

+	"github.com/obolnetwork/charon/app/errors"
 	"github.com/obolnetwork/charon/app/log"
 	"github.com/obolnetwork/charon/app/z"
 	pb "github.com/obolnetwork/charon/dkg/dkgpb/v1"
@@ -35,184 +37,222 @@
 )

 const (
-	syncProtoID = "dkg_sync_v1.0"
-	MsgSize     = 128
-	InvalidSig  = "Invalid Signature"
+	protocolID    = "/charon/dkg/sync/1.0.0/"
+	errInvalidSig = "invalid signature"
 )

+// NewServer returns a new Server instance.
+func NewServer(tcpNode host.Host, allCount int, defHash []byte) *Server {
+	return &Server{
+		defHash:   defHash,
+		tcpNode:   tcpNode,
+		allCount:  allCount,
+		shutdown:  make(map[peer.ID]struct{}),
+		connected: make(map[peer.ID]struct{}),
+	}
+}
+
+// Server implements the server side of the sync protocol. It accepts connections from clients, verifies
+// definition hash signatures, and supports waiting for all clients to shut down.
 type Server struct {
-	mu              sync.Mutex
-	ctx             context.Context
-	onFailure       func()
-	tcpNode         host.Host
-	peers           []p2p.Peer
-	dedupResponse   map[peer.ID]bool
-	receiveChan     chan result
-	receiveShutdown chan result
+	mu        sync.Mutex
+	shutdown  map[peer.ID]struct{}
+	connected map[peer.ID]struct{}
+	defHash   []byte
+	allCount  int // Excluding self
+	tcpNode   host.Host
 }

 // AwaitAllConnected blocks until all peers have established a connection with this server or returns an error.
-func (s *Server) AwaitAllConnected() error {
-	var msgs []result
-	for len(msgs) < len(s.peers) {
+func (s *Server) AwaitAllConnected(ctx context.Context) error {
+	timer := time.NewTicker(time.Millisecond)
+	defer timer.Stop()
+
+	for {
 		select {
-		case <-s.ctx.Done():
-			return s.ctx.Err()
-		case msg := <-s.receiveChan:
-			msgs = append(msgs, msg)
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-timer.C:
+			if s.isAllConnected() {
+				return nil
+			}
 		}
 	}
-
-	log.Info(s.ctx, "All Clients Connected 🎉", z.Any("clients", len(msgs)))
-
-	return nil
 }

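The server half is symmetric: it is constructed with the number of expected peers (excluding self), Start registers the stream handler, AwaitAllConnected gates the start of the ceremony and AwaitAllShutdown gates the exit path. A condensed sketch of that calling pattern, mirroring startSyncProtocol and the tests (helper name illustrative):

```go
package sketch

import (
	"context"

	"github.com/libp2p/go-libp2p-core/host"

	"github.com/obolnetwork/charon/dkg/sync"
)

// runSyncServer mirrors how startSyncProtocol drives the server side.
func runSyncServer(ctx context.Context, tcpNode host.Host, numPeers int, defHash []byte) (*sync.Server, error) {
	server := sync.NewServer(tcpNode, numPeers-1, defHash) // Peer count excludes self.
	server.Start(ctx)                                      // Registers the stream handler.

	if err := server.AwaitAllConnected(ctx); err != nil { // All peers dialled in.
		return nil, err
	}

	return server, nil // Run the ceremony, then call server.AwaitAllShutdown(ctx).
}
```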
 // AwaitAllShutdown blocks until all peers have successfully shutdown or returns an error.
 // It may only be called after AwaitAllConnected.
-func (s *Server) AwaitAllShutdown() error {
-	var msgs []result
-	for len(msgs) < len(s.peers) {
+func (s *Server) AwaitAllShutdown(ctx context.Context) error {
+	timer := time.NewTicker(time.Millisecond)
+	defer timer.Stop()
+
+	for {
 		select {
-		case <-s.ctx.Done():
-			return s.ctx.Err()
-		case msg := <-s.receiveShutdown:
-			msgs = append(msgs, msg)
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-timer.C:
+			if s.isAllShutdown() {
+				return nil
+			}
 		}
 	}
+}

-	log.Info(s.ctx, "All clients shutdown successfully 🎉", z.Any("clients", len(msgs)))
+// isConnected returns the shared connected state for the peer.
+func (s *Server) isConnected(pID peer.ID) bool {
+	s.mu.Lock()
+	defer s.mu.Unlock()

-	return nil
-}
+	_, ok := s.connected[pID]

-// NewServer registers a Stream Handler and returns a new Server instance.
-// TODO(dhruv): remove this nolint once we have everything in place
-// nolint:gocognit
-func NewServer(ctx context.Context, tcpNode host.Host, peers []p2p.Peer, defHash []byte, onFailure func()) *Server {
-	server := &Server{
-		ctx:             ctx,
-		tcpNode:         tcpNode,
-		peers:           peers,
-		onFailure:       onFailure,
-		dedupResponse:   make(map[peer.ID]bool),
-		receiveChan:     make(chan result, len(peers)),
-		receiveShutdown: make(chan result, len(peers)),
-	}
+	return ok
+}

-	knownPeers := make(map[peer.ID]bool)
-	for _, peer := range peers {
-		knownPeers[peer.ID] = true
-	}
+// setConnected sets the shared connected state for the peer.
+func (s *Server) setConnected(pID peer.ID) {
+	s.mu.Lock()
+	defer s.mu.Unlock()

-	server.tcpNode.SetStreamHandler(syncProtoID, func(s network.Stream) {
-		defer s.Close()
-
-		// TODO(dhruv): introduce timeout to break the loop
-		for {
-			before := time.Now()
-			pID := s.Conn().RemotePeer()
-			if !knownPeers[pID] {
-				// Ignoring unknown peer
-				log.Warn(ctx, "Ignoring unknown client", nil, z.Any("client", p2p.PeerName(pID)))
-				return
-			}
+	s.connected[pID] = struct{}{}
+}

-			buf := bufio.NewReader(s)
-			b := make([]byte, MsgSize)
-			// n is the number of bytes read from buffer, if n < MsgSize the other bytes will be 0
-			n, err := buf.Read(b)
-			if err != nil {
-				log.Error(ctx, "Read client msg from stream", err, z.Any("client", p2p.PeerName(pID)))
-				return
-			}
+// clearConnected clears the shared connected state for the peer.
+func (s *Server) clearConnected(pID peer.ID) {
+	s.mu.Lock()
+	defer s.mu.Unlock()

-			// The first `n` bytes that are read are the most important
-			b = b[:n]
+	delete(s.connected, pID)
+}

-			msg := new(pb.MsgSync)
-			if err := proto.Unmarshal(b, msg); err != nil {
-				log.Error(ctx, "Unmarshal client msg", err)
-				return
-			}
+// setShutdown sets the shared shutdown state for the peer.
+func (s *Server) setShutdown(pID peer.ID) {
+	s.mu.Lock()
+	defer s.mu.Unlock()

-			log.Debug(ctx, "Message received from client", z.Any("client", p2p.PeerName(pID)))
+	s.shutdown[pID] = struct{}{}
+}

-			if msg.Shutdown {
-				resp := &pb.MsgSyncResponse{
-					SyncTimestamp: msg.Timestamp,
-					Error:         "",
-				}
+// isAllConnected returns true if all expected peers are connected.
+func (s *Server) isAllConnected() bool {
+	s.mu.Lock()
+	defer s.mu.Unlock()

-				resBytes, err := proto.Marshal(resp)
-				if err != nil {
-					log.Error(ctx, "Marshal server response", err)
-					return
-				}
+	return len(s.connected) == s.allCount
+}

-				_, err = s.Write(resBytes)
-				if err != nil {
-					log.Error(ctx, "Send response to client", err, z.Any("client", p2p.PeerName(pID)))
-					return
-				}
-				server.receiveShutdown <- result{shutdown: true}
+// isAllShutdown returns true if all expected peers have shut down.
+func (s *Server) isAllShutdown() bool {
+	s.mu.Lock()
+	defer s.mu.Unlock()

-				continue
-			}
+	return len(s.shutdown) == s.allCount
+}

-			pubkey, err := pID.ExtractPublicKey()
-			if err != nil {
-				log.Error(ctx, "Get client public key", err)
-				return
-			}
+// handleStream serves a new long-lived client connection.
+func (s *Server) handleStream(ctx context.Context, stream network.Stream) error {
+	defer stream.Close()

-			ok, err := pubkey.Verify(defHash, msg.HashSignature)
-			if err != nil {
-				log.Error(ctx, "Verify defHash signature", err)
-				return
-			}
+	pID := stream.Conn().RemotePeer()
+	pubkey, err := pID.ExtractPublicKey()
+	if err != nil {
+		return errors.Wrap(err, "extract pubkey")
+	}

-			resp := &pb.MsgSyncResponse{
-				SyncTimestamp: msg.Timestamp,
-				Error:         "",
-			}
+	defer s.clearConnected(pID)

-			if !ok {
-				resp.Error = InvalidSig
-			}
+	for {
+		// Read next sync message
+		msg := new(pb.MsgSync)
+		if err := readSizedProto(stream, msg); err != nil {
+			return err
+		}

-			resBytes, err := proto.Marshal(resp)
-			if err != nil {
-				log.Error(ctx, "Marshal server response", err)
-				return
-			}
+		if msg.Shutdown {
+			s.setShutdown(pID)
+		}

-			_, err = s.Write(resBytes)
-			if err != nil {
-				log.Error(ctx, "Send response to client", err, z.Any("client", p2p.PeerName(pID)))
-				return
-			}
+		// Prep response
+		resp := &pb.MsgSyncResponse{
+			SyncTimestamp: msg.Timestamp,
+		}

-			if server.dedupResponse[pID] {
-				log.Debug(ctx, "Ignoring duplicate message", z.Any("client", p2p.PeerName(pID)))
-				continue
-			}
+		// Verify definition hash
+		ok, err := pubkey.Verify(s.defHash, msg.HashSignature)
+		if err != nil {
+			return errors.Wrap(err, "verify sig hash")
+		} else if !ok {
+			resp.Error = errInvalidSig
+			log.Error(ctx, "Received mismatching cluster definition hash from peer", nil)
+		} else if !s.isConnected(pID) {
+			log.Info(ctx, "Connected to peer (inbound)")
+			s.setConnected(pID)
+		}

-			if resp.Error == "" && !server.dedupResponse[pID] {
-				// TODO(dhruv): This is temporary solution to avoid race condition of concurrent writes to map, figure out something permanent.
-				server.mu.Lock()
-				server.dedupResponse[pID] = true
-				server.mu.Unlock()
+		// Write response message
+		if err := writeSizedProto(stream, resp); err != nil {
+			return err
+		}

-				server.receiveChan <- result{
-					rtt:       time.Since(before),
-					timestamp: msg.Timestamp.String(),
-				}
-			}
+		if msg.Shutdown {
+			return nil
+		}
+	}
+}

-			log.Debug(ctx, "Send response to client", z.Any("client", p2p.PeerName(pID)))
+// Start registers the sync protocol handler with the libp2p host.
+func (s *Server) Start(ctx context.Context) {
+	s.tcpNode.SetStreamHandler(protocolID, func(stream network.Stream) {
+		ctx := log.WithCtx(ctx, z.Str("peer", p2p.PeerName(stream.Conn().RemotePeer())))
+		err := s.handleStream(ctx, stream)
+		if err != nil {
+			log.Warn(ctx, "Error serving sync protocol", err)
 		}
 	})
+}
+
+// writeSizedProto writes a size prefixed proto message.
+func writeSizedProto(writer io.Writer, msg proto.Message) error {
+	buf, err := proto.Marshal(msg)
+	if err != nil {
+		return errors.Wrap(err, "marshal proto")
+	}

-	return server
+	size := int64(len(buf))
+	err = binary.Write(writer, binary.LittleEndian, size)
+	if err != nil {
+		return errors.Wrap(err, "write size")
+	}
+
+	n, err := writer.Write(buf)
+	if err != nil {
+		return errors.Wrap(err, "write message")
+	} else if int64(n) != size {
+		return errors.New("unexpected message length")
+	}
+
+	return nil
+}
+
+// readSizedProto reads a size prefixed proto message.
+func readSizedProto(reader io.Reader, msg proto.Message) error {
+	var size int64
+	err := binary.Read(reader, binary.LittleEndian, &size)
+	if err != nil {
+		return errors.Wrap(err, "read size")
+	}
+
+	buf := make([]byte, size)
+	n, err := io.ReadFull(reader, buf) // ReadFull tolerates fragmented stream reads.
+	if err != nil {
+		return errors.Wrap(err, "read buffer")
+	} else if int64(n) != size {
+		return errors.New("unexpected message length")
+	}
+
+	err = proto.Unmarshal(buf, msg)
+	if err != nil {
+		return errors.Wrap(err, "unmarshal proto")
+	}
+
+	return nil
+}
diff --git a/dkg/sync/sync_internal_test.go b/dkg/sync/sync_internal_test.go
deleted file mode 100644
index 605dd41a6..000000000
--- a/dkg/sync/sync_internal_test.go
+++ /dev/null
@@ -1,90 +0,0 @@
-// Copyright © 2022 Obol Labs Inc.
-//
-// This program is free software: you can redistribute it and/or modify it
-// under the terms of the GNU General Public License as published by the Free
-// Software Foundation, either version 3 of the License, or (at your option)
-// any later version.
-//
-// This program is distributed in the hope that it will be useful, but WITHOUT
-// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
-// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
-// more details.
-//
-// You should have received a copy of the GNU General Public License along with
-// this program. If not, see <http://www.gnu.org/licenses/>.
-
-package sync
-
-import (
-	"context"
-	"crypto/ecdsa"
-	"fmt"
-	"math/rand"
-	"testing"
-
-	"github.com/ethereum/go-ethereum/crypto"
-	"github.com/libp2p/go-libp2p"
-	libp2pcrypto "github.com/libp2p/go-libp2p-core/crypto"
-	"github.com/libp2p/go-libp2p-core/host"
-	"github.com/libp2p/go-libp2p-core/peer"
-	"github.com/multiformats/go-multiaddr"
-	"github.com/stretchr/testify/require"
-
-	"github.com/obolnetwork/charon/app/log"
-	"github.com/obolnetwork/charon/p2p"
-	"github.com/obolnetwork/charon/testutil"
-)
-
-func TestNaiveServerClient(t *testing.T) {
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-
-	// Start Server
-	serverHost, _ := newSyncHost(t, 0)
-
-	// Start Client
-	clientHost, key := newSyncHost(t, 1)
-	require.NotEqual(t, clientHost.ID().String(), serverHost.ID().String())
-
-	err := serverHost.Connect(ctx, peer.AddrInfo{
-		ID:    clientHost.ID(),
-		Addrs: clientHost.Addrs(),
-	})
-	require.NoError(t, err)
-
-	hash := testutil.RandomBytes32()
-	hashSig, err := key.Sign(hash)
-	require.NoError(t, err)
-
-	serverCtx := log.WithTopic(ctx, "server")
-	_ = NewServer(serverCtx, serverHost, []p2p.Peer{{ID: clientHost.ID()}}, hash, nil)
-
-	clientCtx := log.WithTopic(ctx, "client")
-	client := NewClient(clientCtx, clientHost, p2p.Peer{ID: serverHost.ID()}, hashSig, nil)
-
-	for i := 0; i < 5; i++ {
-		actual, ok := <-client.results
-		require.True(t, ok)
-		require.NoError(t, actual.error)
-		t.Log("rtt is: ", actual.rtt)
-	}
-}
-
-func newSyncHost(t *testing.T, seed int64) (host.Host, libp2pcrypto.PrivKey) {
-	t.Helper()
-
-	key, err := ecdsa.GenerateKey(crypto.S256(), rand.New(rand.NewSource(seed)))
-	require.NoError(t, err)
-
-	priv, err := libp2pcrypto.UnmarshalSecp256k1PrivateKey(crypto.FromECDSA(key))
-	require.NoError(t, err)
-
-	addr := testutil.AvailableAddr(t)
-	multiAddr, err := multiaddr.NewMultiaddr(fmt.Sprintf("/ip4/%s/tcp/%d", addr.IP, addr.Port))
-	require.NoError(t, err)
-
-	host, err := libp2p.New(libp2p.ListenAddrs(multiAddr), libp2p.Identity(priv))
-	require.NoError(t, err)
-
-	return host, priv
-}
diff --git a/dkg/sync/sync_test.go b/dkg/sync/sync_test.go
index 9c5b5de74..8c6fb6ec2 100644
--- a/dkg/sync/sync_test.go
+++ b/dkg/sync/sync_test.go
@@ -21,6 +21,7 @@ import (
 	"fmt"
 	"math/rand"
 	"testing"
+	"time"

 	"github.com/ethereum/go-ethereum/crypto"
 	"github.com/libp2p/go-libp2p"
@@ -32,126 +33,120 @@ import (

 	"github.com/obolnetwork/charon/app/log"
 	"github.com/obolnetwork/charon/dkg/sync"
-	"github.com/obolnetwork/charon/p2p"
 	"github.com/obolnetwork/charon/testutil"
 )

-//go:generate go test . -run=TestAwaitAllConnected -race
-
-func TestAwaitConnected(t *testing.T) {
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-
-	// Start Server
-	serverHost, _ := newSyncHost(t, 0)
-
-	// Start Client
-	clientHost, key := newSyncHost(t, 1)
-	require.NotEqual(t, clientHost.ID().String(), serverHost.ID().String())
-
-	err := serverHost.Connect(ctx, peer.AddrInfo{
-		ID:    clientHost.ID(),
-		Addrs: clientHost.Addrs(),
+func TestSyncProtocol(t *testing.T) {
+	t.Run("2", func(t *testing.T) {
+		testCluster(t, 2)
 	})
-	require.NoError(t, err)
-
-	hash := testutil.RandomBytes32()
-	hashSig, err := key.Sign(hash)
-	require.NoError(t, err)
-
-	serverCtx := log.WithTopic(ctx, "server")
-	_ = sync.NewServer(serverCtx, serverHost, []p2p.Peer{{ID: clientHost.ID()}}, hash, nil)

-	clientCtx := log.WithTopic(context.Background(), "client")
-	client := sync.NewClient(clientCtx, clientHost, p2p.Peer{ID: serverHost.ID()}, hashSig, nil)
+	t.Run("3", func(t *testing.T) {
+		testCluster(t, 3)
+	})

-	require.NoError(t, client.AwaitConnected())
+	t.Run("5", func(t *testing.T) {
+		testCluster(t, 5)
+	})
 }

-func TestAwaitAllConnected(t *testing.T) {
+func testCluster(t *testing.T, n int) {
+	t.Helper()
+
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()

-	const numClients = 3
-	server, clients := testGetServerAndClients(t, ctx, numClients)
-
-	for i := 0; i < numClients; i++ {
-		require.NoError(t, clients[i].AwaitConnected())
-	}
-
-	require.NoError(t, server.AwaitAllConnected())
-}
+	hash := testutil.RandomBytes32()

-func TestAwaitAllShutdown(t *testing.T) {
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
+	var (
+		tcpNodes []host.Host
+		servers  []*sync.Server
+		clients  []*sync.Client
+		keys     []libp2pcrypto.PrivKey
+	)
+	for i := 0; i < n; i++ {
+		tcpNode, key := newTCPNode(t, int64(i))
+		tcpNodes = append(tcpNodes, tcpNode)
+		keys = append(keys, key)

-	const numClients = 3
-	server, clients := testGetServerAndClients(t, ctx, numClients)
+		server := sync.NewServer(tcpNode, n-1, hash)
+		servers = append(servers, server)
+	}

-	for i := 0; i < numClients; i++ {
-		require.NoError(t, clients[i].Shutdown())
+	for i := 0; i < n; i++ {
+		for j := 0; j < n; j++ {
+			if i == j {
+				continue
+			}
+			err := tcpNodes[i].Connect(ctx, peer.AddrInfo{
+				ID:    tcpNodes[j].ID(),
+				Addrs: tcpNodes[j].Addrs(),
+			})
+			require.NoError(t, err)
+
+			hashSig, err := keys[i].Sign(hash)
+			require.NoError(t, err)
+
+			client := sync.NewClient(tcpNodes[i], tcpNodes[j].ID(), hashSig)
+			clients = append(clients, client)
+
+			ctx := log.WithTopic(ctx, fmt.Sprintf("client%d_%d", i, j))
+			go func() {
+				err := client.Run(ctx)
+				require.NoError(t, err)
+			}()
+		}
 	}

-	require.NoError(t, server.AwaitAllShutdown())
-}
+	time.Sleep(time.Millisecond) // Wait a bit before starting servers

-func testGetServerAndClients(t *testing.T, ctx context.Context, num int) (*sync.Server, []*sync.Client) {
-	t.Helper()
+	for i, server := range servers {
+		server.Start(log.WithTopic(ctx, fmt.Sprintf("server%d", i)))
+	}

-	seed := 0
-	serverHost, _ := newSyncHost(t, int64(seed))
-	var (
-		peers       []p2p.Peer
-		keys        []libp2pcrypto.PrivKey
-		clientHosts []host.Host
-	)
-	for i := 0; i < num; i++ {
-		seed++
-		clientHost, key := newSyncHost(t, int64(seed))
-		require.NotEqual(t, clientHost.ID().String(), serverHost.ID().String())
-
-		err := serverHost.Connect(ctx, peer.AddrInfo{
-			ID:    clientHost.ID(),
-			Addrs: clientHost.Addrs(),
-		})
+	t.Log("client.AwaitConnected")
+	for _, client := range clients {
+		err := client.AwaitConnected(ctx)
 		require.NoError(t, err)
-
-		clientHosts = append(clientHosts, clientHost)
-		keys = append(keys, key)
-		peers = append(peers, p2p.Peer{ID: clientHost.ID()})
 	}

-	hash := testutil.RandomBytes32()
-	server := sync.NewServer(log.WithTopic(ctx, "server"), serverHost, peers, hash, nil)
-
-	var clients []*sync.Client
-	for i := 0; i < num; i++ {
-		hashSig, err := keys[i].Sign(hash)
+	t.Log("server.AwaitAllConnected")
+	for _, server := range servers {
+		err := server.AwaitAllConnected(ctx)
 		require.NoError(t, err)
-
-		client := sync.NewClient(log.WithTopic(context.Background(), "client"), clientHosts[i], p2p.Peer{ID: serverHost.ID()}, hashSig, nil)
-		clients = append(clients, client)
 	}

-	return server, clients
+	go func() {
+		t.Log("client.Shutdown")
+		for _, client := range clients {
+			err := client.Shutdown(ctx)
+			require.NoError(t, err)
+		}
+	}()
+
+	t.Log("server.AwaitAllShutdown")
+	for _, server := range servers {
+		err := server.AwaitAllShutdown(ctx)
+		require.NoError(t, err)
+	}
 }

-func newSyncHost(t *testing.T, seed int64) (host.Host, libp2pcrypto.PrivKey) {
+func newTCPNode(t *testing.T, seed int64) (host.Host, libp2pcrypto.PrivKey) {
 	t.Helper()

 	key, err := ecdsa.GenerateKey(crypto.S256(), rand.New(rand.NewSource(seed)))
 	require.NoError(t, err)

-	priv, err := libp2pcrypto.UnmarshalSecp256k1PrivateKey(crypto.FromECDSA(key))
-	require.NoError(t, err)
-
 	addr := testutil.AvailableAddr(t)
 	multiAddr, err := multiaddr.NewMultiaddr(fmt.Sprintf("/ip4/%s/tcp/%d", addr.IP, addr.Port))
 	require.NoError(t, err)

-	host, err := libp2p.New(libp2p.ListenAddrs(multiAddr), libp2p.Identity(priv))
+	priv, err := libp2pcrypto.UnmarshalSecp256k1PrivateKey(crypto.FromECDSA(key))
+	require.NoError(t, err)
+
+	tcpNode, err := libp2p.New(libp2p.ListenAddrs(multiAddr), libp2p.Identity(priv))
+	testutil.SkipIfBindErr(t, err)
 	require.NoError(t, err)

-	return host, priv
+	return tcpNode, priv
 }
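The length-prefixed framing that replaces the old fixed 128-byte MsgSize buffers can be exercised in isolation: an 8-byte little-endian size followed by the marshalled proto bytes. The sketch below round-trips a well-known proto message through that format; the helper names are illustrative and mirror writeSizedProto/readSizedProto rather than calling the unexported functions themselves:

```go
package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"

	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/known/timestamppb"
)

// writeSized writes an 8-byte little-endian length prefix followed by the proto bytes.
func writeSized(w io.Writer, msg proto.Message) error {
	buf, err := proto.Marshal(msg)
	if err != nil {
		return err
	}

	if err := binary.Write(w, binary.LittleEndian, int64(len(buf))); err != nil {
		return err
	}

	_, err = w.Write(buf)

	return err
}

// readSized reads a length prefix and then exactly that many proto bytes.
func readSized(r io.Reader, msg proto.Message) error {
	var size int64
	if err := binary.Read(r, binary.LittleEndian, &size); err != nil {
		return err
	}

	buf := make([]byte, size)
	if _, err := io.ReadFull(r, buf); err != nil { // ReadFull tolerates fragmented stream reads.
		return err
	}

	return proto.Unmarshal(buf, msg)
}

func main() {
	var stream bytes.Buffer // Stand-in for the libp2p stream.

	if err := writeSized(&stream, timestamppb.Now()); err != nil {
		panic(err)
	}

	got := new(timestamppb.Timestamp)
	if err := readSized(&stream, got); err != nil {
		panic(err)
	}

	fmt.Println(got.AsTime()) // Round-tripped timestamp.
}
```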