Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tbls: revert changes to TSS struct #149

Merged
merged 2 commits into from
Mar 1, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions scheduler/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

eth2client "github.com/attestantio/go-eth2-client"
eth2p0 "github.com/attestantio/go-eth2-client/spec/phase0"
"github.com/coinbase/kryptology/pkg/signatures/bls/bls_sig"

"github.com/obolnetwork/charon/app/errors"
"github.com/obolnetwork/charon/app/log"
Expand All @@ -39,24 +40,24 @@ type eth2Provider interface {
eth2client.ProposerDutiesProvider
}

func New(manifest types.Manifest, eth2Svc eth2client.Service) (*Scheduler, error) {
func New(pubkeys []bls_sig.PublicKey, eth2Svc eth2client.Service) (*Scheduler, error) {
eth2Cl, ok := eth2Svc.(eth2Provider)
if !ok {
return nil, errors.New("invalid eth2 client service")
}

return &Scheduler{
eth2Cl: eth2Cl,
manifest: manifest,
quit: make(chan struct{}),
duties: make(map[types.Duty]types.DutyArgSet),
eth2Cl: eth2Cl,
pubkeys: pubkeys,
quit: make(chan struct{}),
duties: make(map[types.Duty]types.DutyArgSet),
}, nil
}

type Scheduler struct {
eth2Cl eth2Provider
manifest types.Manifest
quit chan struct{}
eth2Cl eth2Provider
pubkeys []bls_sig.PublicKey
quit chan struct{}

duties map[types.Duty]types.DutyArgSet
subs []func(context.Context, types.Duty, types.DutyArgSet) error
Expand Down Expand Up @@ -143,16 +144,13 @@ func (s *Scheduler) scheduleSlot(ctx context.Context, slot slot) error {
}

func (s *Scheduler) resolveDuties(ctx context.Context, slot slot) error {
// Overwrite slot, since we normally fetch for a future slot
ctx = log.WithCtx(ctx, z.I64("slot", slot.Slot))

dvs, indexes, err := resolveActiveDVs(ctx, s.eth2Cl, s.manifest, slot.Slot)
dvs, indexes, err := resolveActiveDVs(ctx, s.eth2Cl, s.pubkeys, slot.Slot)
if err != nil {
return err
}

if len(dvs) == 0 {
log.Debug(ctx, "No active DVs for slot")
log.Debug(ctx, "No active DVs for slot", z.I64("slot", slot.Slot))
return nil
}

Expand Down Expand Up @@ -180,12 +178,14 @@ func (s *Scheduler) resolveDuties(ctx context.Context, slot slot) error {
if !ok {
argSet = make(types.DutyArgSet)
}

argSet[types.VIdx(attDuty.ValidatorIndex)] = b
s.duties[duty] = argSet

log.Debug(ctx, "Resolved attester duty",
z.U64("epoch", uint64(slot.Epoch())),
z.U64("vidx", uint64(attDuty.ValidatorIndex)),
z.U64("slot", uint64(attDuty.Slot)),
z.U64("commidx", uint64(attDuty.CommitteeIndex)))
}
}
Expand Down Expand Up @@ -266,11 +266,11 @@ func newSlotTicker(ctx context.Context, eth2Cl eth2Provider) (<-chan slot, error

// resolveActiveDVs returns the active validators for the slot (in two different formats).
func resolveActiveDVs(ctx context.Context, eth2Cl eth2Provider,
manifest types.Manifest, slot int64,
pubkeys []bls_sig.PublicKey, slot int64,
) ([]types.VIdx, []eth2p0.ValidatorIndex, error) {
var pubkeys []eth2p0.BLSPubKey
for _, dv := range manifest.DVs {
b, err := dv.PublicKey.MarshalBinary()
var e2pks []eth2p0.BLSPubKey
for _, pubkey := range pubkeys {
b, err := pubkey.MarshalBinary()
if err != nil {
return nil, nil, errors.Wrap(err, "marshal pubkey")
}
Expand All @@ -281,15 +281,15 @@ func resolveActiveDVs(ctx context.Context, eth2Cl eth2Provider,
return nil, nil, errors.New("invalid pubkey")
}

pubkeys = append(pubkeys, e2pk)
e2pks = append(e2pks, e2pk)
}

state := fmt.Sprint(slot)
if slot == 0 {
state = "head"
}

vals, err := eth2Cl.ValidatorsByPubKey(ctx, state, pubkeys)
vals, err := eth2Cl.ValidatorsByPubKey(ctx, state, e2pks)
if err != nil {
return nil, nil, err
}
Expand Down
27 changes: 11 additions & 16 deletions scheduler/scheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import (
"github.com/stretchr/testify/require"

"github.com/obolnetwork/charon/scheduler"
"github.com/obolnetwork/charon/tbls"
"github.com/obolnetwork/charon/types"
)

Expand Down Expand Up @@ -57,17 +56,15 @@ func TestIntegration(t *testing.T) {
require.NoError(t, err)

// Use random actual mainnet validators
manifest := types.Manifest{
DVs: []tbls.TSS{
pkFromHex(t, "0x914cff835a769156ba43ad50b931083c2dadd94e8359ce394bc7a3e06424d0214922ddf15f81640530b9c25c0bc0d490"),
pkFromHex(t, "0x8dae41352b69f2b3a1c0b05330c1bf65f03730c520273028864b11fcb94d8ce8f26d64f979a0ee3025467f45fd2241ea"),
pkFromHex(t, "0x8ee91545183c8c2db86633626f5074fd8ef93c4c9b7a2879ad1768f600c5b5906c3af20d47de42c3b032956fa8db1a76"),
pkFromHex(t, "0xa8785ecbb5c030e5da6cbbacc3e6cad39dffbc7bcf7f223a12844db8c1182603df99f673157f0d27912a53546e0f64fe"),
pkFromHex(t, "0xb790b322e1cce41c48e3c344cf8d752bdc3cfd51e8eeef44a4bdaac081bc92b53b73e823a9878b5d7a532eb9d9dce1e3"),
},
pubkeys := []bls_sig.PublicKey{
pkFromHex(t, "0x914cff835a769156ba43ad50b931083c2dadd94e8359ce394bc7a3e06424d0214922ddf15f81640530b9c25c0bc0d490"),
pkFromHex(t, "0x8dae41352b69f2b3a1c0b05330c1bf65f03730c520273028864b11fcb94d8ce8f26d64f979a0ee3025467f45fd2241ea"),
pkFromHex(t, "0x8ee91545183c8c2db86633626f5074fd8ef93c4c9b7a2879ad1768f600c5b5906c3af20d47de42c3b032956fa8db1a76"),
pkFromHex(t, "0xa8785ecbb5c030e5da6cbbacc3e6cad39dffbc7bcf7f223a12844db8c1182603df99f673157f0d27912a53546e0f64fe"),
pkFromHex(t, "0xb790b322e1cce41c48e3c344cf8d752bdc3cfd51e8eeef44a4bdaac081bc92b53b73e823a9878b5d7a532eb9d9dce1e3"),
}

s, err := scheduler.New(manifest, eth2Cl)
s, err := scheduler.New(pubkeys, eth2Cl)
require.NoError(t, err)

count := 10
Expand All @@ -91,19 +88,17 @@ func TestIntegration(t *testing.T) {
require.NoError(t, s.Run())
}

func pkFromHex(t *testing.T, pk string) tbls.TSS {
func pkFromHex(t *testing.T, pk string) bls_sig.PublicKey {
t.Helper()

pk = strings.TrimPrefix(pk, "0x")

b, err := hex.DecodeString(pk)
require.NoError(t, err)

pubkey := new(bls_sig.PublicKey)
err = pubkey.UnmarshalBinary(b)
var pubkey bls_sig.PublicKey
err = (&pubkey).UnmarshalBinary(b)
require.NoError(t, err)

return tbls.TSS{
PublicKey: pubkey,
}
return pubkey
}
41 changes: 29 additions & 12 deletions tbls/tss.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,31 @@ type PubShare struct {
// TSS (threshold signing scheme) wraps PubKey (PublicKey), Verifiers (the public shares corresponding to each secret share)
// and threshold (number of shares).
type TSS struct {
Verifier *share.FeldmanVerifier
NumShares int
verifier *share.FeldmanVerifier
numShares int

// PublicKey and Threshold are inferred from verifier commitments in NewTSS.
// publicKey inferred from verifier commitments in NewTSS.
publicKey *bls_sig.PublicKey
}
Comment on lines 40 to +46
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

refactoring to private fields, with getter methods. To make it explicit that you should use NewTSS for constructing.


// Verifier returns the feldman verifier containing the public shares of the threshold signature scheme.
func (t TSS) Verifier() *share.FeldmanVerifier {
return t.verifier
}

// NumShares returns the number of shares in the threshold signature scheme.
func (t TSS) NumShares() int {
return t.numShares
}

// PublicKey returns the threshold signature scheme's root public key.
func (t TSS) PublicKey() *bls_sig.PublicKey {
return t.publicKey
}

PublicKey *bls_sig.PublicKey
Threshold int
// Threshold returns the minimum number of partial signatures required to aggregate the threshold signature.
func (t TSS) Threshold() int {
return len(t.verifier.Commitments)
}

func NewTSS(verifier *share.FeldmanVerifier, numShares int) (TSS, error) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is good to still calculate the public key on construction:

  • Then we fail fast if we cannot calculate it, instead of later.
  • Reduces error handling in general.

Expand All @@ -55,10 +73,9 @@ func NewTSS(verifier *share.FeldmanVerifier, numShares int) (TSS, error) {
}

return TSS{
Verifier: verifier,
PublicKey: pk,
NumShares: numShares,
Threshold: len(verifier.Commitments),
verifier: verifier,
publicKey: pk,
numShares: numShares,
}, nil
}

Expand Down Expand Up @@ -88,7 +105,7 @@ func GenerateTSS(t, n int, reader io.Reader) (TSS, []*bls_sig.SecretKeyShare, er
// AggregateSignatures aggregates partial signatures over the given message.
// Returns aggregated signatures and slice of signers identifiers that had valid partial signatures.
func AggregateSignatures(tss TSS, partialSigs []*bls_sig.PartialSignature, msg []byte) (*bls_sig.Signature, []byte, error) {
if len(partialSigs) < tss.Threshold {
if len(partialSigs) < tss.Threshold() {
return nil, nil, errors.New("insufficient signatures")
}

Expand All @@ -99,7 +116,7 @@ func AggregateSignatures(tss TSS, partialSigs []*bls_sig.PartialSignature, msg [

for _, psig := range partialSigs {
// TODO(dhruv): add break condition if valid shares >= threshold
pubShare, err := getPubShare(uint32(psig.Identifier), tss.Verifier)
pubShare, err := getPubShare(uint32(psig.Identifier), tss.Verifier())
if err != nil {
return nil, nil, errors.Wrap(err, "get Public Share")
}
Expand All @@ -113,7 +130,7 @@ func AggregateSignatures(tss TSS, partialSigs []*bls_sig.PartialSignature, msg [
signers = append(signers, psig.Identifier)
}

if len(validShares) < tss.Threshold {
if len(validShares) < tss.Threshold() {
return nil, nil, errors.New("insufficient valid signatures")
}

Expand Down
6 changes: 3 additions & 3 deletions tbls/tss_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ func TestGenerateTSS(t *testing.T) {
require.NotNil(t, tss)
require.NotNil(t, secrets)

require.Equal(t, threshold, tss.Threshold)
require.Equal(t, shares, tss.NumShares)
require.Equal(t, threshold, tss.Threshold())
require.Equal(t, shares, tss.NumShares())
}

func TestAggregateSignatures(t *testing.T) {
Expand All @@ -58,7 +58,7 @@ func TestAggregateSignatures(t *testing.T) {
sig, _, err := tbls.AggregateSignatures(tss, partialSigs, msg)
require.NoError(t, err)

result, err := tbls.Verify(tss.PublicKey, msg, sig)
result, err := tbls.Verify(tss.PublicKey(), msg, sig)
require.NoError(t, err)
require.Equal(t, true, result)
}
20 changes: 16 additions & 4 deletions types/manifest.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (

"github.com/coinbase/kryptology/pkg/core/curves"
"github.com/coinbase/kryptology/pkg/sharing"
"github.com/coinbase/kryptology/pkg/signatures/bls/bls_sig"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/rlp"
Expand Down Expand Up @@ -100,6 +101,17 @@ func (m Manifest) PeerIDs() []peer.ID {
return res
}

// PublicKeys is a convenience function that returns the DV root public keys.
func (m Manifest) PublicKeys() []*bls_sig.PublicKey {
res := make([]*bls_sig.PublicKey, 0, len(m.DVs))

for _, tss := range m.DVs {
res = append(res, tss.PublicKey())
}

return res
}

func (m Manifest) MarshalJSON() ([]byte, error) {
var enrs []string
for _, p := range m.Peers {
Expand All @@ -113,16 +125,16 @@ func (m Manifest) MarshalJSON() ([]byte, error) {

var dvs []dvJSON
for _, tss := range m.DVs {
if len(m.Peers) != tss.NumShares {
if len(m.Peers) != tss.NumShares() {
return nil, errors.New("dv shares and peers mismatch")
}

var verifiers [][]byte
for _, c := range tss.Verifier.Commitments {
for _, c := range tss.Verifier().Commitments {
verifiers = append(verifiers, c.ToAffineCompressed())
}

rawPK, err := tss.PublicKey.MarshalBinary()
rawPK, err := tss.PublicKey().MarshalBinary()
if err != nil {
return nil, errors.Wrap(err, "marshal pubkey")
}
Expand Down Expand Up @@ -257,7 +269,7 @@ func getDescription(m Manifest) string {

var threshold int
if dv > 0 {
threshold = m.DVs[0].Threshold
threshold = m.DVs[0].Threshold()
}

return fmt.Sprintf("dv/%d/threshold/%d/peer/%d", dv, threshold, peers)
Expand Down
6 changes: 3 additions & 3 deletions types/manifest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ func TestManifestJSON(t *testing.T) {
for i := 0; i < len(manifest.DVs); i++ {
tss1 := manifest.DVs[i]
tss2 := manifest2.DVs[i]
require.Equal(t, tss1.NumShares, tss2.NumShares)
require.Equal(t, tss1.Verifier, tss2.Verifier)
require.Equal(t, tss1.PublicKey, tss2.PublicKey)
require.Equal(t, tss1.NumShares(), tss2.NumShares())
require.Equal(t, tss1.Verifier(), tss2.Verifier())
require.Equal(t, tss1.PublicKey(), tss2.PublicKey())
}
}
}
Expand Down