diff --git a/dkg/sync/client.go b/dkg/sync/client.go index faed0c6fe..5a425cd79 100644 --- a/dkg/sync/client.go +++ b/dkg/sync/client.go @@ -51,7 +51,7 @@ func (*Client) Shutdown() error { } // NewClient starts a goroutine that establishes a long lived connection to a p2p server and returns a new Client instance. -func NewClient(ctx context.Context, tcpNode host.Host, server p2p.Peer, hash []byte, onFailure func(), ch chan *pb.MsgSyncResponse) Client { +func NewClient(ctx context.Context, tcpNode host.Host, server p2p.Peer, hashSig []byte, onFailure func(), ch chan *pb.MsgSyncResponse) Client { go func() { s, err := tcpNode.NewStream(ctx, server.ID, syncProtoID) if err != nil { @@ -61,7 +61,7 @@ func NewClient(ctx context.Context, tcpNode host.Host, server p2p.Peer, hash []b msg := &pb.MsgSync{ Timestamp: timestamppb.Now(), - HashSignature: hash, + HashSignature: hashSig, Shutdown: false, } diff --git a/dkg/sync/server.go b/dkg/sync/server.go index 4b96d9854..df426d56c 100644 --- a/dkg/sync/server.go +++ b/dkg/sync/server.go @@ -17,7 +17,6 @@ package sync import ( "bufio" - "bytes" "context" "github.com/libp2p/go-libp2p-core/host" @@ -25,6 +24,7 @@ import ( "google.golang.org/protobuf/proto" "github.com/obolnetwork/charon/app/log" + "github.com/obolnetwork/charon/app/z" pb "github.com/obolnetwork/charon/dkg/dkgpb/v1" "github.com/obolnetwork/charon/p2p" ) @@ -54,7 +54,7 @@ func (*Server) AwaitAllShutdown() error { } // NewServer registers a Stream Handler and returns a new Server instance. -func NewServer(ctx context.Context, tcpNode host.Host, peers []p2p.Peer, hash []byte, onFailure func()) *Server { +func NewServer(ctx context.Context, tcpNode host.Host, peers []p2p.Peer, defHash []byte, onFailure func()) *Server { server := &Server{ ctx: ctx, tcpNode: tcpNode, @@ -90,12 +90,29 @@ func NewServer(ctx context.Context, tcpNode host.Host, peers []p2p.Peer, hash [] return } + pID := s.Conn().RemotePeer() + log.Debug(ctx, "Message received from client", z.Any("peer", p2p.PeerName(pID))) + + pubkey, err := pID.ExtractPublicKey() + if err != nil { + log.Error(ctx, "Get client public key", err) + err = s.Reset() + log.Error(ctx, "Stream reset", err) + } + + ok, err := pubkey.Verify(defHash, msg.HashSignature) + if err != nil { + log.Error(ctx, "Verify defHash signature", err) + err = s.Reset() + log.Error(ctx, "Stream reset", err) + } + resp := &pb.MsgSyncResponse{ SyncTimestamp: msg.Timestamp, Error: "", } - if !bytes.Equal(msg.HashSignature, hash) { + if !ok { resp.Error = "Invalid Signature" } diff --git a/dkg/sync/sync_internal_test.go b/dkg/sync/sync_internal_test.go index 7e81c3b8f..61b963ba1 100644 --- a/dkg/sync/sync_internal_test.go +++ b/dkg/sync/sync_internal_test.go @@ -40,10 +40,10 @@ func TestNaiveServerClient(t *testing.T) { ctx := context.Background() // Start Server - serverHost := newSyncHost(t, 0) + serverHost, _ := newSyncHost(t, 0) // Start Client - clientHost := newSyncHost(t, 1) + clientHost, key := newSyncHost(t, 1) require.NotEqual(t, clientHost.ID().String(), serverHost.ID().String()) err := serverHost.Connect(ctx, peer.AddrInfo{ @@ -52,18 +52,22 @@ func TestNaiveServerClient(t *testing.T) { }) require.NoError(t, err) + hash := testutil.RandomBytes32() + hashSig, err := key.Sign(hash) + require.NoError(t, err) + serverCtx := log.WithTopic(ctx, "server") - hash := testutil.RandomCoreSignature() - ch := make(chan *pb.MsgSyncResponse) _ = NewServer(serverCtx, serverHost, nil, hash, nil) clientCtx := log.WithTopic(ctx, "client") - _ = NewClient(clientCtx, clientHost, p2p.Peer{ID: serverHost.ID()}, hash, nil, ch) + ch := make(chan *pb.MsgSyncResponse) + _ = NewClient(clientCtx, clientHost, p2p.Peer{ID: serverHost.ID()}, hashSig, nil, ch) + actual := <-ch require.Equal(t, "", actual.Error) } -func newSyncHost(t *testing.T, seed int64) host.Host { +func newSyncHost(t *testing.T, seed int64) (host.Host, libp2pcrypto.PrivKey) { t.Helper() key, err := ecdsa.GenerateKey(crypto.S256(), rand.New(rand.NewSource(seed))) @@ -79,5 +83,5 @@ func newSyncHost(t *testing.T, seed int64) host.Host { host, err := libp2p.New(libp2p.ListenAddrs(multiAddr), libp2p.Identity(priv)) require.NoError(t, err) - return host + return host, priv }