-
Notifications
You must be signed in to change notification settings - Fork 669
/
validator.go
148 lines (128 loc) · 3.87 KB
/
validator.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.
package warp
import (
"bytes"
"context"
"errors"
"fmt"
"golang.org/x/exp/maps"
"github.com/ava-labs/avalanchego/ids"
"github.com/ava-labs/avalanchego/snow/validators"
"github.com/ava-labs/avalanchego/utils"
"github.com/ava-labs/avalanchego/utils/crypto/bls"
"github.com/ava-labs/avalanchego/utils/math"
"github.com/ava-labs/avalanchego/utils/set"
)
// Compile-time check that *Validator can be sorted with utils.Sort.
var _ utils.Sortable[*Validator] = (*Validator)(nil)

var (
	// ErrUnknownValidator is returned when a signer bitset references a
	// validator index outside of the provided validator set.
	ErrUnknownValidator = errors.New("unknown validator")
	// ErrWeightOverflow is returned when summing validator weights overflows
	// a uint64.
	ErrWeightOverflow = errors.New("weight overflowed")
)
// ValidatorState is the source of P-Chain validator sets consulted when
// building the canonical validator set used for warp message validation.
type ValidatorState interface {
	// GetValidatorSet returns the validators of [subnetID] at P-Chain
	// [height], keyed by node ID.
	GetValidatorSet(
		ctx context.Context,
		height uint64,
		subnetID ids.ID,
	) (map[ids.NodeID]*validators.GetValidatorOutput, error)
}
// Validator is a single entry of a canonical warp validator set. As built by
// GetCanonicalValidatorSet, it represents one unique BLS public key together
// with the combined weight and node IDs of every validator using that key.
type Validator struct {
	PublicKey *bls.PublicKey
	// PublicKeyBytes is the serialized PublicKey; it defines the sort order.
	PublicKeyBytes []byte
	Weight         uint64
	NodeIDs        []ids.NodeID
}

// Less reports whether v sorts strictly before other, comparing the
// serialized public keys lexicographically.
func (v *Validator) Less(other *Validator) bool {
	return bytes.Compare(v.PublicKeyBytes, other.PublicKeyBytes) == -1
}
// GetCanonicalValidatorSet returns the validator set of [subnetID] at
// [pChainHeight] in a canonical ordering. Also returns the total weight on
// [subnetID].
//
// Validators that share a BLS public key are merged into a single *Validator.
// Its Weight is the sum of their weights, and its NodeIDs lists every one of
// them. Validators without a BLS public key are omitted from the returned
// slice, but their weight is still included in the returned total weight.
//
// The returned slice is sorted by serialized public key. The order of NodeIDs
// within one *Validator follows map iteration order over the fetched set and
// is therefore not deterministic.
//
// Returns an error if the validator set cannot be fetched, or an error
// wrapping ErrWeightOverflow if the total weight overflows a uint64.
func GetCanonicalValidatorSet(
	ctx context.Context,
	pChainState ValidatorState,
	pChainHeight uint64,
	subnetID ids.ID,
) ([]*Validator, uint64, error) {
	// Get the validator set at the given height.
	vdrSet, err := pChainState.GetValidatorSet(ctx, pChainHeight, subnetID)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to fetch validator set (P-Chain Height: %d, SubnetID: %s): %w", pChainHeight, subnetID, err)
	}

	var (
		// Keyed by serialized public key so validators sharing a key merge.
		vdrs        = make(map[string]*Validator, len(vdrSet))
		totalWeight uint64
	)
	for _, vdr := range vdrSet {
		// Every validator counts toward the total, even one without a BLS key.
		totalWeight, err = math.Add64(totalWeight, vdr.Weight)
		if err != nil {
			return nil, 0, fmt.Errorf("%w: %w", ErrWeightOverflow, err)
		}

		if vdr.PublicKey == nil {
			continue
		}

		pkBytes := bls.SerializePublicKey(vdr.PublicKey)
		uniqueVdr, ok := vdrs[string(pkBytes)]
		if !ok {
			uniqueVdr = &Validator{
				PublicKey:      vdr.PublicKey,
				PublicKeyBytes: pkBytes,
			}
			vdrs[string(pkBytes)] = uniqueVdr
		}

		// Cannot overflow: uniqueVdr.Weight is a partial sum of weights that
		// were already added to totalWeight above without overflowing.
		uniqueVdr.Weight += vdr.Weight
		uniqueVdr.NodeIDs = append(uniqueVdr.NodeIDs, vdr.NodeID)
	}

	// Sort validators by public key so the ordering is canonical.
	vdrList := maps.Values(vdrs)
	utils.Sort(vdrList)
	return vdrList, totalWeight, nil
}
// FilterValidators returns the validators in [vdrs] whose bit is set to 1 in
// [indices]. The result keeps the relative order the validators have in
// [vdrs].
//
// Returns an error wrapping ErrUnknownValidator if [indices] references an
// index that has no corresponding validator in [vdrs].
func FilterValidators(
	indices set.Bits,
	vdrs []*Validator,
) ([]*Validator, error) {
	numVdrs := len(vdrs)

	// BitLen is a length, so any value above numVdrs means the highest set
	// index (BitLen-1) points past the end of vdrs.
	if bitLen := indices.BitLen(); bitLen > numVdrs {
		return nil, fmt.Errorf(
			"%w: NumIndices (%d) >= NumFilteredValidators (%d)",
			ErrUnknownValidator,
			bitLen-1, // highest referenced index
			numVdrs,
		)
	}

	selected := make([]*Validator, 0, numVdrs)
	for idx := 0; idx < numVdrs; idx++ {
		if indices.Contains(idx) {
			selected = append(selected, vdrs[idx])
		}
	}
	return selected, nil
}
// SumWeight returns the total weight of the provided validators.
//
// Returns an error wrapping ErrWeightOverflow if the sum does not fit in a
// uint64.
func SumWeight(vdrs []*Validator) (uint64, error) {
	var total uint64
	for i := range vdrs {
		next, err := math.Add64(total, vdrs[i].Weight)
		if err != nil {
			return 0, fmt.Errorf("%w: %w", ErrWeightOverflow, err)
		}
		total = next
	}
	return total, nil
}
// AggregatePublicKeys returns the aggregate BLS public key of the provided
// validators, as computed by bls.AggregatePublicKeys over their keys in order.
//
// Invariant: All of the public keys in [vdrs] are valid.
func AggregatePublicKeys(vdrs []*Validator) (*bls.PublicKey, error) {
	keys := make([]*bls.PublicKey, 0, len(vdrs))
	for _, vdr := range vdrs {
		keys = append(keys, vdr.PublicKey)
	}
	return bls.AggregatePublicKeys(keys)
}