Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dkg: add nil checks #2088

Merged
merged 3 commits into from
Apr 12, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion dkg/dkg.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,10 @@ func Run(ctx context.Context, conf Config) (err error) {
}
peerMap[p.ID] = nodeIdx
}
tp := newFrostP2P(tcpNode, peerMap, key, def.Threshold)
tp, err := newFrostP2P(tcpNode, peerMap, key, def.Threshold)
if err != nil {
return err
}

log.Info(ctx, "Waiting to connect to all peers...")

Expand Down
63 changes: 53 additions & 10 deletions dkg/frostp2p.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ var (
)

// newFrostP2P returns a p2p frost transport implementation.
func newFrostP2P(tcpNode host.Host, peers map[peer.ID]cluster.NodeIdx, secret *k1.PrivateKey, threshold int) *frostP2P {
func newFrostP2P(tcpNode host.Host, peers map[peer.ID]cluster.NodeIdx, secret *k1.PrivateKey, threshold int) (*frostP2P, error) {
if secret == nil {
return nil, errors.New("secret cannot be nil")
}
dB2510 marked this conversation as resolved.
Show resolved Hide resolved

var (
round1CastsRecv = make(chan *pb.FrostRound1Casts, len(peers))
round1P2PRecv = make(chan *pb.FrostRound1P2P, len(peers))
Expand Down Expand Up @@ -122,6 +126,10 @@ func newFrostP2P(tcpNode host.Host, peers map[peer.ID]cluster.NodeIdx, secret *k
p2p.RegisterHandler("frost", tcpNode, round1P2PID,
func() proto.Message { return new(pb.FrostRound1P2P) },
func(ctx context.Context, pID peer.ID, req proto.Message) (proto.Message, bool, error) {
if req == nil {
dB2510 marked this conversation as resolved.
Show resolved Hide resolved
return nil, false, errors.New("req proto message cannot be nil")
}

mu.Lock()
defer mu.Unlock()

Expand Down Expand Up @@ -158,7 +166,7 @@ func newFrostP2P(tcpNode host.Host, peers map[peer.ID]cluster.NodeIdx, secret *k
round1CastsRecv: round1CastsRecv,
round1P2PRecv: round1P2PRecv,
round2CastsRecv: round2CastsRecv,
}
}, nil
}

// frostP2P implements frost transport.
Expand Down Expand Up @@ -293,7 +301,11 @@ func makeRound1Response(casts []*pb.FrostRound1Casts, p2ps []*pb.FrostRound1P2P)

for _, msg := range p2ps {
for _, sharePB := range msg.Shares {
key, share := shamirShareFromProto(sharePB)
key, share, err := shamirShareFromProto(sharePB)
if err != nil {
return nil, nil, err
}

p2pMap[key] = share
}
}
Expand Down Expand Up @@ -325,11 +337,20 @@ func shamirShareToProto(key msgKey, shamir sharing.ShamirShare) *pb.FrostRound1S
}
}

func shamirShareFromProto(shamir *pb.FrostRound1ShamirShare) (msgKey, sharing.ShamirShare) {
return keyFromProto(shamir.Key), sharing.ShamirShare{
func shamirShareFromProto(shamir *pb.FrostRound1ShamirShare) (msgKey, sharing.ShamirShare, error) {
if shamir == nil {
return msgKey{}, sharing.ShamirShare{}, errors.New("round 1 shamir share proto cannot be nil")
}

protoKey, err := keyFromProto(shamir.Key)
if err != nil {
return msgKey{}, sharing.ShamirShare{}, err
}

return protoKey, sharing.ShamirShare{
Id: shamir.Id,
Value: shamir.Value,
}
}, nil
}

func round1CastToProto(key msgKey, cast frost.Round1Bcast) *pb.FrostRound1Cast {
Expand All @@ -347,6 +368,10 @@ func round1CastToProto(key msgKey, cast frost.Round1Bcast) *pb.FrostRound1Cast {
}

func round1CastFromProto(cast *pb.FrostRound1Cast) (msgKey, frost.Round1Bcast, error) {
if cast == nil {
return msgKey{}, frost.Round1Bcast{}, errors.New("round 1 cast cannot be nil")
}

wi, err := curve.Scalar.SetBytes(cast.Wi)
if err != nil {
return msgKey{}, frost.Round1Bcast{}, errors.Wrap(err, "decode wi scalar")
Expand All @@ -366,7 +391,12 @@ func round1CastFromProto(cast *pb.FrostRound1Cast) (msgKey, frost.Round1Bcast, e
comms = append(comms, c)
}

return keyFromProto(cast.Key), frost.Round1Bcast{
key, err := keyFromProto(cast.Key)
if err != nil {
return msgKey{}, frost.Round1Bcast{}, err
}

return key, frost.Round1Bcast{
Wi: wi,
Ci: ci,
Verifiers: &sharing.FeldmanVerifier{Commitments: comms},
Expand All @@ -382,6 +412,10 @@ func round2CastToProto(key msgKey, cast frost.Round2Bcast) *pb.FrostRound2Cast {
}

func round2CastFromProto(cast *pb.FrostRound2Cast) (msgKey, frost.Round2Bcast, error) {
if cast == nil {
return msgKey{}, frost.Round2Bcast{}, errors.New("round 2 cast cannot be nil")
}

verificationKey, err := curve.Point.FromAffineCompressed(cast.VerificationKey)
if err != nil {
return msgKey{}, frost.Round2Bcast{}, errors.Wrap(err, "decode verification key scalar")
Expand All @@ -391,7 +425,12 @@ func round2CastFromProto(cast *pb.FrostRound2Cast) (msgKey, frost.Round2Bcast, e
return msgKey{}, frost.Round2Bcast{}, errors.Wrap(err, "decode c1 scalar")
}

return keyFromProto(cast.Key), frost.Round2Bcast{
key, err := keyFromProto(cast.Key)
if err != nil {
return msgKey{}, frost.Round2Bcast{}, err
}

return key, frost.Round2Bcast{
VerificationKey: verificationKey,
VkShare: vkShare,
}, nil
Expand All @@ -405,12 +444,16 @@ func keyToProto(key msgKey) *pb.FrostMsgKey {
}
}

func keyFromProto(key *pb.FrostMsgKey) msgKey {
func keyFromProto(key *pb.FrostMsgKey) (msgKey, error) {
if key == nil {
return msgKey{}, errors.New("frost msg key cannot be nil")
}

return msgKey{
ValIdx: key.ValIdx,
SourceID: key.SourceId,
TargetID: key.TargetId,
}
}, nil
}

// frostProtocol returns the frost protocol ID including the provided suffixes.
Expand Down