/
maxcover.go
300 lines (271 loc) · 9.58 KB
/
maxcover.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
package attestations
import (
"sort"
"github.com/pkg/errors"
"github.com/prysmaticlabs/go-bitfield"
"github.com/prysmaticlabs/prysm/v4/crypto/bls"
ethpb "github.com/prysmaticlabs/prysm/v4/proto/prysm/v1alpha1"
"github.com/prysmaticlabs/prysm/v4/proto/prysm/v1alpha1/attestation/aggregation"
)
// MaxCoverAttestationAggregation relies on Maximum Coverage greedy algorithm for aggregation.
// Aggregation occurs in many rounds, up until no more aggregation is possible (all attestations
// are overlapping).
// See https://hackmd.io/@farazdagi/in-place-attagg for design and rationale.
//
// The input slice is used as scratch space: its elements are reordered and overwritten in place,
// and the returned slice shares the backing array of atts, so callers must not rely on the
// contents of atts afterwards. Inputs with fewer than two attestations are returned as-is,
// without validation.
//
// Each round runs aggregation.MaxCover (overlaps disallowed) over the still-unaggregated
// attestations. When it selects at least two of them, they are either combined into a single
// aggregate placed in the front (aggregated) section of atts, or, if their combined coverage
// does not differ from what earlier aggregates already cover, discarded. Either way the selected
// originals are removed. After the rounds, the leftover unaggregated attestations are passed
// through filterContained and appended after the aggregates.
//
// Errors: if MaxCover or aggregateAttestations fails, the aggregates built so far plus the
// still-unaggregated attestations are returned together with the error; any other error
// returns a nil slice.
func MaxCoverAttestationAggregation(atts []*ethpb.Attestation) ([]*ethpb.Attestation, error) {
if len(atts) < 2 {
return atts, nil
}
if err := attList(atts).validate(); err != nil {
return nil, err
}
// In the future this conversion will be redundant, as attestation bitlist will be of a Bitlist64
// type, so incoming `atts` parameters can be used as candidates list directly.
// Invariant: candidates[i] always corresponds to atts[i]; every swap below is applied to both.
candidates := make([]*bitfield.Bitlist64, len(atts))
for i := 0; i < len(atts); i++ {
var err error
candidates[i], err = atts[i].AggregationBits.ToBitlist64()
if err != nil {
return nil, err
}
}
// Accumulates (via NoAllocOr below) the bits covered by every aggregate created so far.
// Sized from the first bitlist.
coveredBitsSoFar := bitfield.NewBitlist64(candidates[0].Len())
// In order not to re-allocate anything we rely on the very same underlying array, which
// can only shrink (while the `aggregated` slice length can increase).
// The `aggregated` slice grows by combining individual attestations and appending to that slice.
// Both aggregated and non-aggregated slices operate on the very same underlying array.
// Layout of atts at any point: [ aggregated | unaggregated | processed (nil) tail ].
aggregated := atts[:0]
unaggregated := atts
// Aggregation over n/2 rounds is enough to find all aggregatable items (exits earlier if there
// are many items that can be aggregated).
for i := 0; i < len(atts)/2; i++ {
if len(unaggregated) < 2 {
break
}
// Find maximum non-overlapping coverage for subset of still non-processed candidates.
// roundCandidates is the window of candidates mirroring the `unaggregated` section.
roundCandidates := candidates[len(aggregated) : len(aggregated)+len(unaggregated)]
selectedKeys, coverage, err := aggregation.MaxCover(
roundCandidates, len(roundCandidates), false /* allowOverlaps */)
if err != nil {
// Return aggregated attestations, and attestations that couldn't be aggregated.
return append(aggregated, unaggregated...), err
}
// Exit earlier, if possible cover does not allow aggregation (less than two items).
if selectedKeys.Count() < 2 {
break
}
// Pad selected key indexes, as `roundCandidates` is a subset of `candidates`.
// After padding, keys index directly into atts/candidates.
keys := padSelectedKeys(selectedKeys.BitIndices(), len(aggregated))
// Create aggregated attestation and update solution lists. Process aggregates only if they
// feature at least one unknown bit i.e. can increase the overall coverage.
// NOTE(review): XorCount > 0 also holds when coveredBitsSoFar has bits that coverage lacks,
// so this is really a "coverage differs from coveredBitsSoFar" test — confirm intent.
xc, err := coveredBitsSoFar.XorCount(coverage)
if err != nil {
// NOTE(review): unlike the error paths above, this one discards partial results.
return nil, err
}
if xc > 0 {
// The aggregate is written at aggIdx == keys[0].
aggIdx, err := aggregateAttestations(atts, keys, coverage)
if err != nil {
return append(aggregated, unaggregated...), err
}
// Unless we are already at the right position, swap aggregation and the first non-aggregated item.
idx0 := len(aggregated)
if idx0 < aggIdx {
atts[idx0], atts[aggIdx] = atts[aggIdx], atts[idx0]
candidates[idx0], candidates[aggIdx] = candidates[aggIdx], candidates[idx0]
}
// Expand to the newly created aggregate.
aggregated = atts[:idx0+1]
// Shift the starting point of the slice to the right.
unaggregated = unaggregated[1:]
// Update covered bits map.
if err := coveredBitsSoFar.NoAllocOr(coverage, coveredBitsSoFar); err != nil {
return nil, err
}
// Slot keys[0] (== aggIdx) now holds either the aggregate itself (no swap) or the
// unselected attestation swapped out of idx0, so it must not be removed below.
keys = keys[1:]
}
// Remove processed attestations.
// When xc > 0 these are the originals merged into the aggregate; when xc == 0 every
// selected attestation is dropped, as its combined coverage adds nothing over coveredBitsSoFar.
rearrangeProcessedAttestations(atts, candidates, keys)
unaggregated = unaggregated[:len(unaggregated)-len(keys)]
}
filtered, err := attList(unaggregated).filterContained()
if err != nil {
return nil, err
}
return append(aggregated, filtered...), nil
}
// NewMaxCover returns initialized Maximum Coverage problem for attestations aggregation.
// Each attestation becomes one candidate, keyed by its position in atts and referring to
// that attestation's AggregationBits field by pointer (the bits are not copied).
func NewMaxCover(atts []*ethpb.Attestation) *aggregation.MaxCoverProblem {
	problem := &aggregation.MaxCoverProblem{
		Candidates: make([]*aggregation.MaxCoverCandidate, len(atts)),
	}
	for idx, att := range atts {
		problem.Candidates[idx] = aggregation.NewMaxCoverCandidate(idx, &att.AggregationBits)
	}
	return problem
}
// aggregate returns list as an aggregated attestation.
//
// The result carries coverage as its aggregation bits, a copy of the first attestation's
// data, and the aggregate of every member's signature. It returns a wrapped
// ErrInvalidAttestationCount for lists with fewer than two attestations, and the first
// signature-decoding error encountered, if any.
func (al attList) aggregate(coverage bitfield.Bitlist) (*ethpb.Attestation, error) {
	if len(al) < 2 {
		return nil, errors.Wrap(ErrInvalidAttestationCount, "cannot aggregate")
	}
	signs := make([]bls.Signature, len(al))
	for i, att := range al {
		sig, err := signatureFromBytes(att.Signature)
		if err != nil {
			return nil, err
		}
		signs[i] = sig
	}
	// Previously written as `ðpb.Attestation{` (mangled `&ethpb`), which did not compile.
	return &ethpb.Attestation{
		AggregationBits: coverage,
		Data:            ethpb.CopyAttestationData(al[0].Data),
		Signature:       aggregateSignatures(signs).Marshal(),
	}, nil
}
// padSelectedKeys shifts every key by pad. The slice is modified in place and the same
// slice is returned for convenience.
func padSelectedKeys(keys []int, pad int) []int {
	for i := range keys {
		keys[i] += pad
	}
	return keys
}
// aggregateAttestations combines signatures of selected attestations into a single aggregate attestation, and
// pushes that aggregated attestation into the position of the first of selected attestations.
//
// keys are indexes into atts; keys[0] is the target slot. The aggregate uses coverage (converted
// to a Bitlist) as its aggregation bits and a copy of atts[keys[0]].Data as its data. On success
// it returns keys[0]. On error, atts is left unmodified and the returned index is 0 and must be
// ignored. Errors are returned for fewer than two keys/attestations, a nil or empty coverage,
// keys that are out of range or point at nil attestations, and undecodable signatures.
func aggregateAttestations(atts []*ethpb.Attestation, keys []int, coverage *bitfield.Bitlist64) (targetIdx int, err error) {
	if len(keys) < 2 || atts == nil || len(atts) < 2 {
		return 0, errors.Wrap(ErrInvalidAttestationCount, "cannot aggregate")
	}
	if coverage == nil || coverage.Count() == 0 {
		return 0, errors.New("invalid or empty coverage")
	}
	signs := make([]bls.Signature, 0, len(keys))
	for _, idx := range keys {
		// Reject bad keys with an error rather than panicking on atts[idx].
		if idx < 0 || idx >= len(atts) || atts[idx] == nil {
			return 0, errors.New("invalid attestation key")
		}
		sig, sigErr := signatureFromBytes(atts[idx].Signature)
		if sigErr != nil {
			return 0, sigErr
		}
		signs = append(signs, sig)
	}
	// Put aggregated attestation at a position of the first selected attestation. The RHS
	// (including the data copy) is fully evaluated before atts[targetIdx] is overwritten.
	targetIdx = keys[0]
	atts[targetIdx] = &ethpb.Attestation{
		// Append size byte, which will be unnecessary on switch to Bitlist64.
		AggregationBits: coverage.ToBitlist(),
		Data:            ethpb.CopyAttestationData(atts[targetIdx].Data),
		Signature:       aggregateSignatures(signs).Marshal(),
	}
	return targetIdx, nil
}
// rearrangeProcessedAttestations pushes processed attestations (and their matching candidates)
// to the end of the slices, so that the caller can cut the slices and allow processed items to
// be garbage collected. It returns nothing; the caller trims len(processedKeys) items off the end
// of its own view.
//
// Side effects:
//   - every atts[k] and candidates[k] for k in processedKeys is set to nil;
//   - non-nil items from the tail are swapped into the vacated positions, so the relative order
//     of the remaining items is not preserved;
//   - processedKeys itself is sorted in place.
//
// Does nothing if any argument is nil.
func rearrangeProcessedAttestations(atts []*ethpb.Attestation, candidates []*bitfield.Bitlist64, processedKeys []int) {
if atts == nil || candidates == nil || processedKeys == nil {
return
}
// Set all selected keys to nil.
for _, idx := range processedKeys {
atts[idx] = nil
candidates[idx] = nil
}
// Re-arrange nil items, move them to end of slice.
// Keys are visited in ascending order while lastIdx scans backwards; everything after lastIdx
// is nil at all times.
sort.Ints(processedKeys)
lastIdx := len(atts) - 1
for _, idx0 := range processedKeys {
// Make sure that nil items are swapped for non-nil items only.
for lastIdx > idx0 && atts[lastIdx] == nil {
lastIdx--
}
// Everything from idx0 onwards is already nil, so the remaining (larger) keys are already
// in the tail.
if idx0 == lastIdx {
break
}
atts[idx0], atts[lastIdx] = atts[lastIdx], atts[idx0]
candidates[idx0], candidates[lastIdx] = candidates[lastIdx], candidates[idx0]
}
}
// merge combines two attestation lists into one: the items of al followed by the items
// of al1. Like the built-in append it is built on, it may reuse (and write into) the
// backing array of al when that array has spare capacity.
func (al attList) merge(al1 attList) attList {
	combined := append(al, al1...)
	return combined
}
// selectUsingKeys returns a newly allocated list holding al[key] for each key, in the
// order the keys are given. Every key must be a valid index into al.
func (al attList) selectUsingKeys(keys []int) attList {
	selected := make(attList, 0, len(keys))
	for _, key := range keys {
		selected = append(selected, al[key])
	}
	return selected
}
// selectComplementUsingKeys returns only items with keys that are NOT specified.
//
// Relative order of the kept items is preserved, and keys outside [0, len(al)) are ignored.
// Filtering happens in place: the result shares al's backing array, whose contents are
// overwritten. The keys slice is only read; it is neither reordered nor modified.
func (al attList) selectComplementUsingKeys(keys []int) attList {
	// Mark excluded positions once. This is O(len(al)+len(keys)), and unlike the former
	// swap-remove closure it never reorders the caller's keys backing array.
	excluded := make([]bool, len(al))
	for _, key := range keys {
		if key >= 0 && key < len(al) {
			excluded[key] = true
		}
	}
	filtered := al[:0]
	for i, att := range al {
		if !excluded[i] {
			filtered = append(filtered, att)
		}
	}
	return filtered
}
// hasCoverage reports whether some attestation in the list has aggregation bits identical
// to coverage, i.e. their XOR has no set bits. The first error returned by Xor aborts
// the search and is passed through.
func (al attList) hasCoverage(coverage bitfield.Bitlist) (bool, error) {
	for i := range al {
		diff, err := al[i].AggregationBits.Xor(coverage)
		switch {
		case err != nil:
			return false, err
		case diff.Count() == 0:
			return true, nil
		}
	}
	return false, nil
}
// filterContained removes attestations that are contained within other attestations
// (including exact duplicates of another attestation).
//
// The list is sorted in place by descending number of set bits and filtered in place: the
// result shares al's backing array, and al is reordered and overwritten. Errors from
// Bitlist.Contains are returned with a nil list.
func (al attList) filterContained() (attList, error) {
	if len(al) < 2 {
		return al, nil
	}
	// An attestation can only be contained in one with at least as many set bits, so after
	// this sort every potential container of al[i] is examined before al[i].
	sort.Slice(al, func(i, j int) bool {
		return al[i].AggregationBits.Count() > al[j].AggregationBits.Count()
	})
	filtered := al[:1]
	for i := 1; i < len(al); i++ {
		candidate := al[i]
		// Compare against every kept attestation, not only the last one: with 1100, 0011, 0100
		// the last item is contained in 1100 even though it is not contained in 0011.
		isContained := false
		for _, kept := range filtered {
			c, err := kept.AggregationBits.Contains(candidate.AggregationBits)
			if err != nil {
				return nil, err
			}
			if c {
				isContained = true
				break
			}
		}
		if !isContained {
			// Writes al[len(filtered)], an index <= i, so unread items are never clobbered.
			filtered = append(filtered, candidate)
		}
	}
	return filtered, nil
}
// validate checks attestation list for validity: the list must be non-nil and non-empty,
// and every attestation must be non-nil and carry a non-nil, non-empty aggregation bitlist.
// Bitlist lengths are not compared against each other here.
func (al attList) validate() error {
	if al == nil {
		return errors.New("nil list")
	}
	if len(al) == 0 {
		return errors.Wrap(aggregation.ErrInvalidMaxCoverProblem, "empty list")
	}
	for _, att := range al {
		// A nil element would otherwise panic on the field access below.
		if att == nil {
			return errors.Wrap(aggregation.ErrInvalidMaxCoverProblem, "attestation cannot be nil")
		}
		if att.AggregationBits == nil || att.AggregationBits.Len() == 0 {
			return errors.Wrap(aggregation.ErrInvalidMaxCoverProblem, "bitlist cannot be nil or empty")
		}
	}
	return nil
}