Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify sampler interface #3026

Merged
merged 2 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions network/ip_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -400,8 +400,8 @@ func (i *ipTracker) GetGossipableIPs(

uniform.Initialize(uint64(len(i.gossipableIPs)))
for len(ips) < maxNumIPs {
index, err := uniform.Next()
if err != nil {
index, hasNext := uniform.Next()
if !hasNext {
return ips
}

Expand Down
4 changes: 2 additions & 2 deletions network/p2p/validators.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ func (v *Validators) Sample(ctx context.Context, limit int) []ids.NodeID {

uniform.Initialize(uint64(len(v.validatorList)))
for len(sampled) < limit {
i, err := uniform.Next()
if err != nil {
i, hasNext := uniform.Next()
if !hasNext {
break
}

Expand Down
4 changes: 2 additions & 2 deletions network/peer/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ func (s *peerSet) Sample(n int, precondition func(Peer) bool) []Peer {

peers := make([]Peer, 0, n)
for len(peers) < n {
index, err := sampler.Next()
if err != nil {
index, hasNext := sampler.Next()
if !hasNext {
// We have run out of peers to attempt to sample.
break
}
Expand Down
10 changes: 7 additions & 3 deletions snow/consensus/snowman/bootstrapper/sampler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@
package bootstrapper

import (
"errors"

"github.com/ava-labs/avalanchego/utils/math"
"github.com/ava-labs/avalanchego/utils/sampler"
"github.com/ava-labs/avalanchego/utils/set"
)

var errUnexpectedSamplerFailure = errors.New("unexpected sampler failure")

// Sample keys from [elements] uniformly by weight without replacement. The
// returned set will have size less than or equal to [maxSize]. This function
// will error if the sum of all weights overflows.
Expand Down Expand Up @@ -36,9 +40,9 @@ func Sample[T comparable](elements map[T]uint64, maxSize int) (set.Set[T], error
}

maxSize = int(min(uint64(maxSize), totalWeight))
indices, err := sampler.Sample(maxSize)
if err != nil {
return nil, err
indices, ok := sampler.Sample(maxSize)
if !ok {
return nil, errUnexpectedSamplerFailure
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should never happen because of the maxSize calculation above.

}

sampledElements := set.NewSet[T](maxSize)
Expand Down
3 changes: 1 addition & 2 deletions snow/validators/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (

"github.com/ava-labs/avalanchego/ids"
"github.com/ava-labs/avalanchego/utils/crypto/bls"
"github.com/ava-labs/avalanchego/utils/sampler"
"github.com/ava-labs/avalanchego/utils/set"

safemath "github.com/ava-labs/avalanchego/utils/math"
Expand Down Expand Up @@ -396,7 +395,7 @@ func TestSample(t *testing.T) {
require.Equal([]ids.NodeID{nodeID0}, sampled)

_, err = m.Sample(subnetID, 2)
require.ErrorIs(err, sampler.ErrOutOfRange)
require.ErrorIs(err, errInsufficientWeight)

nodeID1 := ids.GenerateTestNodeID()
require.NoError(m.AddStaker(subnetID, nodeID1, nil, ids.Empty, math.MaxInt64-1))
Expand Down
7 changes: 4 additions & 3 deletions snow/validators/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ var (
errDuplicateValidator = errors.New("duplicate validator")
errMissingValidator = errors.New("missing validator")
errTotalWeightNotUint64 = errors.New("total weight is not a uint64")
errInsufficientWeight = errors.New("insufficient weight")
)

// newSet returns a new, empty set of validators.
Expand Down Expand Up @@ -257,9 +258,9 @@ func (s *vdrSet) sample(size int) ([]ids.NodeID, error) {
s.samplerInitialized = true
}

indices, err := s.sampler.Sample(size)
if err != nil {
return nil, err
indices, ok := s.sampler.Sample(size)
if !ok {
return nil, errInsufficientWeight
}

list := make([]ids.NodeID, size)
Expand Down
3 changes: 1 addition & 2 deletions snow/validators/set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (

"github.com/ava-labs/avalanchego/ids"
"github.com/ava-labs/avalanchego/utils/crypto/bls"
"github.com/ava-labs/avalanchego/utils/sampler"
"github.com/ava-labs/avalanchego/utils/set"

safemath "github.com/ava-labs/avalanchego/utils/math"
Expand Down Expand Up @@ -343,7 +342,7 @@ func TestSetSample(t *testing.T) {
require.Equal([]ids.NodeID{nodeID0}, sampled)

_, err = s.Sample(2)
require.ErrorIs(err, sampler.ErrOutOfRange)
require.ErrorIs(err, errInsufficientWeight)

nodeID1 := ids.GenerateTestNodeID()
require.NoError(s.Add(nodeID1, nil, ids.Empty, math.MaxInt64-1))
Expand Down
6 changes: 3 additions & 3 deletions utils/sampler/uniform.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ package sampler
type Uniform interface {
Initialize(sampleRange uint64)
// Sample returns length numbers in the range [0,sampleRange). If there
// aren't enough numbers in the range, an error is returned. If length is
// aren't enough numbers in the range, false returned. If length is
// negative the implementation may panic.
Sample(length int) ([]uint64, error)
Sample(length int) ([]uint64, bool)

Next() (uint64, bool)
Reset()
Next() (uint64, error)
}

// NewUniform returns a new sampler
Expand Down
2 changes: 1 addition & 1 deletion utils/sampler/uniform_best.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ samplerLoop:

start := s.clock.Time()
for i := 0; i < s.benchmarkIterations; i++ {
if _, err := sampler.Sample(sampleSize); err != nil {
if _, ok := sampler.Sample(sampleSize); !ok {
continue samplerLoop
}
}
Expand Down
16 changes: 8 additions & 8 deletions utils/sampler/uniform_replacer.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,34 +36,34 @@ func (s *uniformReplacer) Initialize(length uint64) {
s.drawsCount = 0
}

func (s *uniformReplacer) Sample(count int) ([]uint64, error) {
func (s *uniformReplacer) Sample(count int) ([]uint64, bool) {
s.Reset()

results := make([]uint64, count)
for i := 0; i < count; i++ {
ret, err := s.Next()
if err != nil {
return nil, err
ret, hasNext := s.Next()
if !hasNext {
return nil, false
}
results[i] = ret
}
return results, nil
return results, true
}

func (s *uniformReplacer) Reset() {
clear(s.drawn)
s.drawsCount = 0
}

func (s *uniformReplacer) Next() (uint64, error) {
func (s *uniformReplacer) Next() (uint64, bool) {
if s.drawsCount >= s.length {
return 0, ErrOutOfRange
return 0, false
}

draw := s.rng.Uint64Inclusive(s.length-1-s.drawsCount) + s.drawsCount
ret := s.drawn.get(draw, draw)
s.drawn[draw] = s.drawn.get(s.drawsCount, s.drawsCount)
s.drawsCount++

return ret, nil
return ret, true
}
16 changes: 8 additions & 8 deletions utils/sampler/uniform_resample.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,28 @@ func (s *uniformResample) Initialize(length uint64) {
s.drawn = make(map[uint64]struct{})
}

func (s *uniformResample) Sample(count int) ([]uint64, error) {
func (s *uniformResample) Sample(count int) ([]uint64, bool) {
s.Reset()

results := make([]uint64, count)
for i := 0; i < count; i++ {
ret, err := s.Next()
if err != nil {
return nil, err
ret, hasNext := s.Next()
if !hasNext {
return nil, false
}
results[i] = ret
}
return results, nil
return results, true
}

func (s *uniformResample) Reset() {
clear(s.drawn)
}

func (s *uniformResample) Next() (uint64, error) {
func (s *uniformResample) Next() (uint64, bool) {
i := uint64(len(s.drawn))
if i >= s.length {
return 0, ErrOutOfRange
return 0, false
}

for {
Expand All @@ -53,6 +53,6 @@ func (s *uniformResample) Next() (uint64, error) {
continue
}
s.drawn[draw] = struct{}{}
return draw, nil
return draw, true
}
}
32 changes: 16 additions & 16 deletions utils/sampler/uniform_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ func UniformInitializeMaxUint64Test(t *testing.T, s Uniform) {
s.Initialize(math.MaxUint64)

for {
val, err := s.Next()
require.NoError(t, err)
val, hasNext := s.Next()
require.True(t, hasNext)

if val > math.MaxInt64 {
break
Expand All @@ -95,17 +95,17 @@ func UniformInitializeMaxUint64Test(t *testing.T, s Uniform) {
func UniformOutOfRangeTest(t *testing.T, s Uniform) {
s.Initialize(0)

_, err := s.Sample(1)
require.ErrorIs(t, err, ErrOutOfRange)
_, ok := s.Sample(1)
require.False(t, ok)
}

func UniformEmptyTest(t *testing.T, s Uniform) {
require := require.New(t)

s.Initialize(1)

val, err := s.Sample(0)
require.NoError(err)
val, ok := s.Sample(0)
require.True(ok)
require.Empty(val)
}

Expand All @@ -114,8 +114,8 @@ func UniformSingletonTest(t *testing.T, s Uniform) {

s.Initialize(1)

val, err := s.Sample(1)
require.NoError(err)
val, ok := s.Sample(1)
require.True(ok)
require.Equal([]uint64{0}, val)
}

Expand All @@ -124,8 +124,8 @@ func UniformDistributionTest(t *testing.T, s Uniform) {

s.Initialize(3)

val, err := s.Sample(3)
require.NoError(err)
val, ok := s.Sample(3)
require.True(ok)

slices.Sort(val)
require.Equal([]uint64{0, 1, 2}, val)
Expand All @@ -134,8 +134,8 @@ func UniformDistributionTest(t *testing.T, s Uniform) {
func UniformOverSampleTest(t *testing.T, s Uniform) {
s.Initialize(3)

_, err := s.Sample(4)
require.ErrorIs(t, err, ErrOutOfRange)
_, ok := s.Sample(4)
require.False(t, ok)
}

func UniformLazilySample(t *testing.T, s Uniform) {
Expand All @@ -146,15 +146,15 @@ func UniformLazilySample(t *testing.T, s Uniform) {
for j := 0; j < 2; j++ {
sampled := map[uint64]bool{}
for i := 0; i < 3; i++ {
val, err := s.Next()
require.NoError(err)
val, hasNext := s.Next()
require.True(hasNext)
require.False(sampled[val])

sampled[val] = true
}

_, err := s.Next()
require.ErrorIs(err, ErrOutOfRange)
_, hasNext := s.Next()
require.False(hasNext)

s.Reset()
}
Expand Down
6 changes: 1 addition & 5 deletions utils/sampler/weighted.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,11 @@

package sampler

import "errors"

var ErrOutOfRange = errors.New("out of range")

// Weighted defines how to sample a specified valued based on a provided
// weighted distribution
type Weighted interface {
Initialize(weights []uint64) error
Sample(sampleValue uint64) (int, error)
Sample(sampleValue uint64) (int, bool)
}

// NewWeighted returns a new sampler
Expand Down
6 changes: 3 additions & 3 deletions utils/sampler/weighted_array.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ func (s *weightedArray) Initialize(weights []uint64) error {
return nil
}

func (s *weightedArray) Sample(value uint64) (int, error) {
func (s *weightedArray) Sample(value uint64) (int, bool) {
if len(s.arr) == 0 || s.arr[len(s.arr)-1].cumulativeWeight <= value {
return 0, ErrOutOfRange
return 0, false
}
minIndex := 0
maxIndex := len(s.arr) - 1
Expand All @@ -98,7 +98,7 @@ func (s *weightedArray) Sample(value uint64) (int, error) {
currentElem := s.arr[index]
currentWeight := currentElem.cumulativeWeight
if previousWeight <= value && value < currentWeight {
return currentElem.index, nil
return currentElem.index, true
}

if value < previousWeight {
Expand Down
2 changes: 1 addition & 1 deletion utils/sampler/weighted_best.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ samplerLoop:

start := s.clock.Time()
for _, sample := range samples {
if _, err := sampler.Sample(sample); err != nil {
if _, ok := sampler.Sample(sample); !ok {
continue samplerLoop
}
}
Expand Down
6 changes: 3 additions & 3 deletions utils/sampler/weighted_heap.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,17 @@ func (s *weightedHeap) Initialize(weights []uint64) error {
return nil
}

func (s *weightedHeap) Sample(value uint64) (int, error) {
func (s *weightedHeap) Sample(value uint64) (int, bool) {
if len(s.heap) == 0 || s.heap[0].cumulativeWeight <= value {
return 0, ErrOutOfRange
return 0, false
}

index := 0
for {
currentElement := s.heap[index]
currentWeight := currentElement.weight
if value < currentWeight {
return currentElement.index, nil
return currentElement.index, true
}
value -= currentWeight

Expand Down
Loading
Loading