diff --git a/node/overridden_manager.go b/node/overridden_manager.go index d34860c2045c..4d4b881dcf93 100644 --- a/node/overridden_manager.go +++ b/node/overridden_manager.go @@ -75,6 +75,10 @@ func (o *overriddenManager) Sample(_ ids.ID, size int) ([]ids.NodeID, error) { return o.manager.Sample(o.subnetID, size) } +func (o *overriddenManager) GetAllMaps() map[ids.ID]map[ids.NodeID]*validators.GetValidatorOutput { + return o.manager.GetAllMaps() +} + func (o *overriddenManager) GetMap(ids.ID) map[ids.NodeID]*validators.GetValidatorOutput { return o.manager.GetMap(o.subnetID) } diff --git a/snow/validators/manager.go b/snow/validators/manager.go index 8616f850817d..3d9ec6a61515 100644 --- a/snow/validators/manager.go +++ b/snow/validators/manager.go @@ -94,6 +94,9 @@ type Manager interface { // If sampling the requested size isn't possible, an error will be returned. Sample(subnetID ids.ID, size int) ([]ids.NodeID, error) + // GetAllMaps returns a copy of all validators of all subnets + GetAllMaps() map[ids.ID]map[ids.NodeID]*GetValidatorOutput + // Map of the validators in this subnet GetMap(subnetID ids.ID) map[ids.NodeID]*GetValidatorOutput @@ -257,6 +260,17 @@ func (m *manager) Sample(subnetID ids.ID, size int) ([]ids.NodeID, error) { return set.Sample(size) } +func (m *manager) GetAllMaps() map[ids.ID]map[ids.NodeID]*GetValidatorOutput { + m.lock.RLock() + defer m.lock.RUnlock() + + set := make(map[ids.ID]map[ids.NodeID]*GetValidatorOutput, len(m.subnetToVdrs)) + for subnetID, vdrs := range m.subnetToVdrs { + set[subnetID] = vdrs.Map() + } + return set +} + func (m *manager) GetMap(subnetID ids.ID) map[ids.NodeID]*GetValidatorOutput { m.lock.RLock() set, exists := m.subnetToVdrs[subnetID] diff --git a/snow/validators/manager_test.go b/snow/validators/manager_test.go index 65dcceeb65a3..778c9bcd9658 100644 --- a/snow/validators/manager_test.go +++ b/snow/validators/manager_test.go @@ -375,6 +375,100 @@ func TestGetMap(t *testing.T) { require.Empty(m.GetMap(subnetID)) } +func TestGetAllMaps(t *testing.T) { + require := require.New(t) + + m := NewManager() + subnetID0 := ids.GenerateTestID() + subnetID1 := ids.GenerateTestID() + + maps := m.GetAllMaps() + require.Empty(maps) + + sk, err := localsigner.New() + require.NoError(err) + + pk := sk.PublicKey() + nodeID0 := ids.GenerateTestNodeID() + require.NoError(m.AddStaker(subnetID0, nodeID0, pk, ids.Empty, 2)) + + maps = m.GetAllMaps() + require.Len(maps, 1) + map0, ok := maps[subnetID0] + require.True(ok) + require.Len(map0, 1) + require.Contains(map0, nodeID0) + + node0 := map0[nodeID0] + require.Equal(nodeID0, node0.NodeID) + require.Equal(pk, node0.PublicKey) + require.Equal(uint64(2), node0.Weight) + + nodeID1 := ids.GenerateTestNodeID() + require.NoError(m.AddStaker(subnetID1, nodeID1, nil, ids.Empty, 1)) + + maps = m.GetAllMaps() + require.Len(maps, 2) + map0, ok = maps[subnetID0] + require.True(ok) + require.Contains(map0, nodeID0) + map1, ok := maps[subnetID1] + require.True(ok) + require.Contains(map1, nodeID1) + + node0 = map0[nodeID0] + require.Equal(nodeID0, node0.NodeID) + require.Equal(pk, node0.PublicKey) + require.Equal(uint64(2), node0.Weight) + + node1 := map1[nodeID1] + require.Equal(nodeID1, node1.NodeID) + require.Nil(node1.PublicKey) + require.Equal(uint64(1), node1.Weight) + + require.NoError(m.RemoveWeight(subnetID0, nodeID0, 1)) + require.Equal(nodeID0, node0.NodeID) + require.Equal(pk, node0.PublicKey) + require.Equal(uint64(2), node0.Weight) + + maps = m.GetAllMaps() + require.Len(maps, 2) + map0, ok = maps[subnetID0] + require.True(ok) + map1, ok = maps[subnetID1] + require.True(ok) + require.Contains(map0, nodeID0) + require.Contains(map1, nodeID1) + + node0 = map0[nodeID0] + require.Equal(nodeID0, node0.NodeID) + require.Equal(pk, node0.PublicKey) + require.Equal(uint64(1), node0.Weight) + + node1 = map1[nodeID1] + require.Equal(nodeID1, node1.NodeID) + require.Nil(node1.PublicKey) + require.Equal(uint64(1), node1.Weight) + + require.NoError(m.RemoveWeight(subnetID0, nodeID0, 1)) + + maps = m.GetAllMaps() + require.Len(maps, 1) + require.NotContains(maps, subnetID0) + map1, ok = maps[subnetID1] + require.True(ok) + require.Contains(map1, nodeID1) + + node1 = map1[nodeID1] + require.Equal(nodeID1, node1.NodeID) + require.Nil(node1.PublicKey) + require.Equal(uint64(1), node1.Weight) + + require.NoError(m.RemoveWeight(subnetID1, nodeID1, 1)) + + require.Empty(m.GetAllMaps()) +} + func TestWeight(t *testing.T) { require := require.New(t) diff --git a/vms/platformvm/state/disk_staker_diff_iterator.go b/vms/platformvm/state/disk_staker_diff_iterator.go index dc55806b4c04..2bb17e28315e 100644 --- a/vms/platformvm/state/disk_staker_diff_iterator.go +++ b/vms/platformvm/state/disk_staker_diff_iterator.go @@ -28,18 +28,28 @@ var ( errUnexpectedWeightValueLength = fmt.Errorf("expected weight value length %d", weightValueLength) ) -// marshalStartDiffKey is used to determine the starting key when iterating. +// marshalStartDiffKeyBySubnetID is used to determine the starting key when iterating. // -// Invariant: the result is a prefix of [marshalDiffKey] when called with the +// Invariant: the result is a prefix of [marshalDiffKeyBySubnetID] when called with the // same arguments. -func marshalStartDiffKey(subnetID ids.ID, height uint64) []byte { +func marshalStartDiffKeyBySubnetID(subnetID ids.ID, height uint64) []byte { key := make([]byte, startDiffKeyLength) copy(key, subnetID[:]) packIterableHeight(key[ids.IDLen:], height) return key } -func marshalDiffKey(subnetID ids.ID, height uint64, nodeID ids.NodeID) []byte { +// marshalStartDiffKeyByHeight is used to determine the starting key when iterating. +// +// Invariant: the result is a prefix of [marshalDiffKeyByHeight] when called with the +// same arguments. +func marshalStartDiffKeyByHeight(height uint64) []byte { + key := make([]byte, database.Uint64Size) + packIterableHeight(key, height) + return key +} + +func marshalDiffKeyBySubnetID(subnetID ids.ID, height uint64, nodeID ids.NodeID) []byte { key := make([]byte, diffKeyLength) copy(key, subnetID[:]) packIterableHeight(key[ids.IDLen:], height) @@ -47,7 +57,15 @@ func marshalDiffKey(subnetID ids.ID, height uint64, nodeID ids.NodeID) []byte { return key } -func unmarshalDiffKey(key []byte) (ids.ID, uint64, ids.NodeID, error) { +func marshalDiffKeyByHeight(height uint64, subnetID ids.ID, nodeID ids.NodeID) []byte { + key := make([]byte, diffKeyLength) + packIterableHeight(key, height) + copy(key[database.Uint64Size:], subnetID[:]) + copy(key[diffKeyNodeIDOffset:], nodeID.Bytes()) + return key +} + +func unmarshalDiffKeyBySubnetID(key []byte) (ids.ID, uint64, ids.NodeID, error) { if len(key) != diffKeyLength { return ids.Empty, 0, ids.EmptyNodeID, errUnexpectedDiffKeyLength } @@ -61,6 +79,20 @@ func unmarshalDiffKey(key []byte) (ids.ID, uint64, ids.NodeID, error) { return subnetID, height, nodeID, nil } +func unmarshalDiffKeyByHeight(key []byte) (uint64, ids.ID, ids.NodeID, error) { + if len(key) != diffKeyLength { + return 0, ids.Empty, ids.EmptyNodeID, errUnexpectedDiffKeyLength + } + var ( + subnetID ids.ID + nodeID ids.NodeID + ) + height := unpackIterableHeight(key) + copy(subnetID[:], key[database.Uint64Size:]) + copy(nodeID[:], key[diffKeyNodeIDOffset:]) + return height, subnetID, nodeID, nil +} + func marshalWeightDiff(diff *ValidatorWeightDiff) []byte { value := make([]byte, weightValueLength) if diff.Decrease { diff --git a/vms/platformvm/state/disk_staker_diff_iterator_test.go b/vms/platformvm/state/disk_staker_diff_iterator_test.go index e47932d6c0ad..f4c68b39e224 100644 --- a/vms/platformvm/state/disk_staker_diff_iterator_test.go +++ b/vms/platformvm/state/disk_staker_diff_iterator_test.go @@ -13,7 +13,7 @@ import ( "github.com/ava-labs/avalanchego/ids" ) -func FuzzMarshalDiffKey(f *testing.F) { +func FuzzMarshalDiffKeyBySubnetID(f *testing.F) { f.Fuzz(func(t *testing.T, data []byte) { require := require.New(t) @@ -25,8 +25,8 @@ func FuzzMarshalDiffKey(f *testing.F) { fz := fuzzer.NewFuzzer(data) fz.Fill(&subnetID, &height, &nodeID) - key := marshalDiffKey(subnetID, height, nodeID) - parsedSubnetID, parsedHeight, parsedNodeID, err := unmarshalDiffKey(key) + key := marshalDiffKeyBySubnetID(subnetID, height, nodeID) + parsedSubnetID, parsedHeight, parsedNodeID, err := unmarshalDiffKeyBySubnetID(key) require.NoError(err) require.Equal(subnetID, parsedSubnetID) require.Equal(height, parsedHeight) @@ -34,22 +34,58 @@ func FuzzMarshalDiffKey(f *testing.F) { }) } -func FuzzUnmarshalDiffKey(f *testing.F) { +func FuzzMarshalDiffKeyByHeight(f *testing.F) { + f.Fuzz(func(t *testing.T, data []byte) { + require := require.New(t) + + var ( + subnetID ids.ID + height uint64 + nodeID ids.NodeID + ) + fz := fuzzer.NewFuzzer(data) + fz.Fill(&height, &subnetID, &nodeID) + + key := marshalDiffKeyByHeight(height, subnetID, nodeID) + parsedHeight, parsedSubnetID, parsedNodeID, err := unmarshalDiffKeyByHeight(key) + require.NoError(err) + require.Equal(subnetID, parsedSubnetID) + require.Equal(height, parsedHeight) + require.Equal(nodeID, parsedNodeID) + }) +} + +func FuzzUnmarshalDiffKeyBySubnetID(f *testing.F) { f.Fuzz(func(t *testing.T, key []byte) { require := require.New(t) - subnetID, height, nodeID, err := unmarshalDiffKey(key) + subnetID, height, nodeID, err := unmarshalDiffKeyBySubnetID(key) if err != nil { require.ErrorIs(err, errUnexpectedDiffKeyLength) return } - formattedKey := marshalDiffKey(subnetID, height, nodeID) + formattedKey := marshalDiffKeyBySubnetID(subnetID, height, nodeID) require.Equal(key, formattedKey) }) } -func TestDiffIteration(t *testing.T) { +func FuzzUnmarshalDiffKeyByHeight(f *testing.F) { + f.Fuzz(func(t *testing.T, key []byte) { + require := require.New(t) + + subnetID, height, nodeID, err := unmarshalDiffKeyByHeight(key) + if err != nil { + require.ErrorIs(err, errUnexpectedDiffKeyLength) + return + } + + formattedKey := marshalDiffKeyByHeight(subnetID, height, nodeID) + require.Equal(key, formattedKey) + }) +} + +func TestDiffIterationBySubnetID(t *testing.T) { require := require.New(t) db := memdb.New() @@ -60,13 +96,13 @@ func TestDiffIteration(t *testing.T) { nodeID0 := ids.BuildTestNodeID([]byte{0x00}) nodeID1 := ids.BuildTestNodeID([]byte{0x01}) - subnetID0Height0NodeID0 := marshalDiffKey(subnetID0, 0, nodeID0) - subnetID0Height1NodeID0 := marshalDiffKey(subnetID0, 1, nodeID0) - subnetID0Height1NodeID1 := marshalDiffKey(subnetID0, 1, nodeID1) + subnetID0Height0NodeID0 := marshalDiffKeyBySubnetID(subnetID0, 0, nodeID0) + subnetID0Height1NodeID0 := marshalDiffKeyBySubnetID(subnetID0, 1, nodeID0) + subnetID0Height1NodeID1 := marshalDiffKeyBySubnetID(subnetID0, 1, nodeID1) - subnetID1Height0NodeID0 := marshalDiffKey(subnetID1, 0, nodeID0) - subnetID1Height1NodeID0 := marshalDiffKey(subnetID1, 1, nodeID0) - subnetID1Height1NodeID1 := marshalDiffKey(subnetID1, 1, nodeID1) + subnetID1Height0NodeID0 := marshalDiffKeyBySubnetID(subnetID1, 0, nodeID0) + subnetID1Height1NodeID0 := marshalDiffKeyBySubnetID(subnetID1, 1, nodeID0) + subnetID1Height1NodeID1 := marshalDiffKeyBySubnetID(subnetID1, 1, nodeID1) require.NoError(db.Put(subnetID0Height0NodeID0, nil)) require.NoError(db.Put(subnetID0Height1NodeID0, nil)) @@ -76,7 +112,7 @@ func TestDiffIteration(t *testing.T) { require.NoError(db.Put(subnetID1Height1NodeID1, nil)) { - it := db.NewIteratorWithStartAndPrefix(marshalStartDiffKey(subnetID0, 0), subnetID0[:]) + it := db.NewIteratorWithStartAndPrefix(marshalStartDiffKeyBySubnetID(subnetID0, 0), subnetID0[:]) defer it.Release() expectedKeys := [][]byte{ @@ -91,7 +127,7 @@ func TestDiffIteration(t *testing.T) { } { - it := db.NewIteratorWithStartAndPrefix(marshalStartDiffKey(subnetID0, 1), subnetID0[:]) + it := db.NewIteratorWithStartAndPrefix(marshalStartDiffKeyBySubnetID(subnetID0, 1), subnetID0[:]) defer it.Release() expectedKeys := [][]byte{ @@ -107,3 +143,66 @@ func TestDiffIteration(t *testing.T) { require.NoError(it.Error()) } } + +func TestDiffIterationByHeight(t *testing.T) { + require := require.New(t) + + db := memdb.New() + + subnetID0 := ids.ID{0x00} + subnetID1 := ids.ID{0x01} + + nodeID0 := ids.BuildTestNodeID([]byte{0x00}) + nodeID1 := ids.BuildTestNodeID([]byte{0x01}) + + height0SubnetID0NodeID0 := marshalDiffKeyByHeight(0, subnetID0, nodeID0) + height1SubnetID0NodeID0 := marshalDiffKeyByHeight(1, subnetID0, nodeID0) + height1SubnetID0NodeID1 := marshalDiffKeyByHeight(1, subnetID0, nodeID1) + + height0SubnetID1NodeID0 := marshalDiffKeyByHeight(0, subnetID1, nodeID0) + height1SubnetID1NodeID0 := marshalDiffKeyByHeight(1, subnetID1, nodeID0) + height1SubnetID1NodeID1 := marshalDiffKeyByHeight(1, subnetID1, nodeID1) + + require.NoError(db.Put(height0SubnetID0NodeID0, nil)) + require.NoError(db.Put(height1SubnetID0NodeID0, nil)) + require.NoError(db.Put(height1SubnetID0NodeID1, nil)) + require.NoError(db.Put(height0SubnetID1NodeID0, nil)) + require.NoError(db.Put(height1SubnetID1NodeID0, nil)) + require.NoError(db.Put(height1SubnetID1NodeID1, nil)) + + { + it := db.NewIteratorWithStart(marshalStartDiffKeyByHeight(0)) + defer it.Release() + + expectedKeys := [][]byte{ + height0SubnetID0NodeID0, + height0SubnetID1NodeID0, + } + for _, expectedKey := range expectedKeys { + require.True(it.Next()) + require.Equal(expectedKey, it.Key()) + } + require.False(it.Next()) + require.NoError(it.Error()) + } + + { + it := db.NewIteratorWithStart(marshalStartDiffKeyByHeight(1)) + defer it.Release() + + expectedKeys := [][]byte{ + height1SubnetID0NodeID0, + height1SubnetID0NodeID1, + height1SubnetID1NodeID0, + height1SubnetID1NodeID1, + height0SubnetID0NodeID0, + height0SubnetID1NodeID0, + } + for _, expectedKey := range expectedKeys { + require.True(it.Next()) + require.Equal(expectedKey, it.Key()) + } + require.False(it.Next()) + require.NoError(it.Error()) + } +} diff --git a/vms/platformvm/state/mock_state.go b/vms/platformvm/state/mock_state.go index 1dd577d1ebdf..ed7a447d63c9 100644 --- a/vms/platformvm/state/mock_state.go +++ b/vms/platformvm/state/mock_state.go @@ -149,6 +149,34 @@ func (mr *MockStateMockRecorder) AddUTXO(utxo any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUTXO", reflect.TypeOf((*MockState)(nil).AddUTXO), utxo) } +// ApplyAllValidatorPublicKeyDiffs mocks base method. +func (m *MockState) ApplyAllValidatorPublicKeyDiffs(ctx context.Context, validators map[ids.ID]map[ids.NodeID]*validators.GetValidatorOutput, startHeight, endHeight uint64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ApplyAllValidatorPublicKeyDiffs", ctx, validators, startHeight, endHeight) + ret0, _ := ret[0].(error) + return ret0 +} + +// ApplyAllValidatorPublicKeyDiffs indicates an expected call of ApplyAllValidatorPublicKeyDiffs. +func (mr *MockStateMockRecorder) ApplyAllValidatorPublicKeyDiffs(ctx, validators, startHeight, endHeight any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplyAllValidatorPublicKeyDiffs", reflect.TypeOf((*MockState)(nil).ApplyAllValidatorPublicKeyDiffs), ctx, validators, startHeight, endHeight) +} + +// ApplyAllValidatorWeightDiffs mocks base method. +func (m *MockState) ApplyAllValidatorWeightDiffs(ctx context.Context, validators map[ids.ID]map[ids.NodeID]*validators.GetValidatorOutput, startHeight, endHeight uint64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ApplyAllValidatorWeightDiffs", ctx, validators, startHeight, endHeight) + ret0, _ := ret[0].(error) + return ret0 +} + +// ApplyAllValidatorWeightDiffs indicates an expected call of ApplyAllValidatorWeightDiffs. +func (mr *MockStateMockRecorder) ApplyAllValidatorWeightDiffs(ctx, validators, startHeight, endHeight any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplyAllValidatorWeightDiffs", reflect.TypeOf((*MockState)(nil).ApplyAllValidatorWeightDiffs), ctx, validators, startHeight, endHeight) +} + // ApplyValidatorPublicKeyDiffs mocks base method. func (m *MockState) ApplyValidatorPublicKeyDiffs(ctx context.Context, validators map[ids.NodeID]*validators.GetValidatorOutput, startHeight, endHeight uint64, subnetID ids.ID) error { m.ctrl.T.Helper() diff --git a/vms/platformvm/state/state.go b/vms/platformvm/state/state.go index e05d35fae42e..a12d472b6c47 100644 --- a/vms/platformvm/state/state.go +++ b/vms/platformvm/state/state.go @@ -67,33 +67,35 @@ var ( errIsNotSubnet = errors.New("is not a subnet") errMissingPrimaryNetworkValidator = errors.New("missing primary network validator") - BlockIDPrefix = []byte("blockID") - BlockPrefix = []byte("block") - ValidatorsPrefix = []byte("validators") - CurrentPrefix = []byte("current") - PendingPrefix = []byte("pending") - ValidatorPrefix = []byte("validator") - DelegatorPrefix = []byte("delegator") - SubnetValidatorPrefix = []byte("subnetValidator") - SubnetDelegatorPrefix = []byte("subnetDelegator") - ValidatorWeightDiffsPrefix = []byte("flatValidatorDiffs") - ValidatorPublicKeyDiffsPrefix = []byte("flatPublicKeyDiffs") - TxPrefix = []byte("tx") - RewardUTXOsPrefix = []byte("rewardUTXOs") - UTXOPrefix = []byte("utxo") - SubnetPrefix = []byte("subnet") - SubnetOwnerPrefix = []byte("subnetOwner") - SubnetToL1ConversionPrefix = []byte("subnetToL1Conversion") - TransformedSubnetPrefix = []byte("transformedSubnet") - SupplyPrefix = []byte("supply") - ChainPrefix = []byte("chain") - ExpiryReplayProtectionPrefix = []byte("expiryReplayProtection") - L1Prefix = []byte("l1") - WeightsPrefix = []byte("weights") - SubnetIDNodeIDPrefix = []byte("subnetIDNodeID") - ActivePrefix = []byte("active") - InactivePrefix = []byte("inactive") - SingletonPrefix = []byte("singleton") + BlockIDPrefix = []byte("blockID") + BlockPrefix = []byte("block") + ValidatorsPrefix = []byte("validators") + CurrentPrefix = []byte("current") + PendingPrefix = []byte("pending") + ValidatorPrefix = []byte("validator") + DelegatorPrefix = []byte("delegator") + SubnetValidatorPrefix = []byte("subnetValidator") + SubnetDelegatorPrefix = []byte("subnetDelegator") + ValidatorWeightDiffsBySubnetIDPrefix = []byte("flatValidatorDiffs") + ValidatorWeightDiffsByHeightPrefix = []byte("flatValidatorDiffsByHeight") + ValidatorPublicKeyDiffsBySubnetIDPrefix = []byte("flatPublicKeyDiffs") + ValidatorPublicKeyDiffsByHeightPrefix = []byte("flatPublicKeyDiffsByHeight") + TxPrefix = []byte("tx") + RewardUTXOsPrefix = []byte("rewardUTXOs") + UTXOPrefix = []byte("utxo") + SubnetPrefix = []byte("subnet") + SubnetOwnerPrefix = []byte("subnetOwner") + SubnetToL1ConversionPrefix = []byte("subnetToL1Conversion") + TransformedSubnetPrefix = []byte("transformedSubnet") + SupplyPrefix = []byte("supply") + ChainPrefix = []byte("chain") + ExpiryReplayProtectionPrefix = []byte("expiryReplayProtection") + L1Prefix = []byte("l1") + WeightsPrefix = []byte("weights") + SubnetIDNodeIDPrefix = []byte("subnetIDNodeID") + ActivePrefix = []byte("active") + InactivePrefix = []byte("inactive") + SingletonPrefix = []byte("singleton") TimestampKey = []byte("timestamp") FeeStateKey = []byte("fee state") @@ -209,6 +211,42 @@ type State interface { subnetID ids.ID, ) error + // ApplyAllValidatorWeightDiffs iterates from [startHeight] towards the genesis + // block until it has applied all of the diffs up to and including + // [endHeight]. Applying the diffs modifies [validators]. + // + // Invariant: If attempting to generate the validator set for + // [endHeight - 1], [validators] must initially contain the validator + // weights for [startHeight]. + // + // Note: Because this function iterates towards the genesis, [startHeight] + // will typically be greater than or equal to [endHeight]. If [startHeight] + // is less than [endHeight], no diffs will be applied. + ApplyAllValidatorWeightDiffs( + ctx context.Context, + validators map[ids.ID]map[ids.NodeID]*validators.GetValidatorOutput, + startHeight uint64, + endHeight uint64, + ) error + + // ApplyAllValidatorPublicKeyDiffs iterates from [startHeight] towards the + // genesis block until it has applied all of the diffs up to and including + // [endHeight]. Applying the diffs modifies [validators]. + // + // Invariant: If attempting to generate the validator set for + // [endHeight - 1], [validators] must initially contain the validator + // weights for [startHeight]. + // + // Note: Because this function iterates towards the genesis, [startHeight] + // will typically be greater than or equal to [endHeight]. If [startHeight] + // is less than [endHeight], no diffs will be applied. + ApplyAllValidatorPublicKeyDiffs( + ctx context.Context, + validators map[ids.ID]map[ids.NodeID]*validators.GetValidatorOutput, + startHeight uint64, + endHeight uint64, + ) error + SetHeight(height uint64) // GetCurrentValidators returns subnet and L1 validators for the given @@ -290,10 +328,14 @@ type stateBlk struct { * | | | '-- validationID -> l1Validator * | | '-. inactive * | | '-- validationID -> l1Validator - * | |-. weight diffs + * | |-. weight diffs by subnet ID * | | '-- subnet+height+nodeID -> weightChange - * | '-. pub key diffs + * | |-. weight diffs by height + * | | '-- height+subnet+nodeID -> weightChange + * | '-. pub key diffs by subnet ID * | '-- subnet+height+nodeID -> uncompressed public key or nil + * | '-. pub key diffs by height + * | '-- height+subnet+nodeID -> uncompressed public key or nil * |-. blockIDs * | '-- height -> blockID * |-. blocks @@ -389,8 +431,10 @@ type state struct { pendingSubnetDelegatorBaseDB database.Database pendingSubnetDelegatorList linkeddb.LinkedDB - validatorWeightDiffsDB database.Database - validatorPublicKeyDiffsDB database.Database + validatorWeightDiffsBySubnetIDDB database.Database + validatorWeightDiffsByHeightDB database.Database + validatorPublicKeyDiffsBySubnetIDDB database.Database + validatorPublicKeyDiffsByHeightDB database.Database addedTxs map[ids.ID]*txAndStatus // map of txID -> {*txs.Tx, Status} txCache cache.Cacher[ids.ID, *txAndStatus] // txID -> {*txs.Tx, Status}; if the entry is nil, it is not in the database @@ -571,8 +615,10 @@ func New( l1ValidatorsDB := prefixdb.New(L1Prefix, validatorsDB) - validatorWeightDiffsDB := prefixdb.New(ValidatorWeightDiffsPrefix, validatorsDB) - validatorPublicKeyDiffsDB := prefixdb.New(ValidatorPublicKeyDiffsPrefix, validatorsDB) + validatorWeightDiffsBySubnetIDDB := prefixdb.New(ValidatorWeightDiffsBySubnetIDPrefix, validatorsDB) + validatorWeightDiffsByHeightDB := prefixdb.New(ValidatorWeightDiffsByHeightPrefix, validatorsDB) + validatorPublicKeyDiffsBySubnetIDDB := prefixdb.New(ValidatorPublicKeyDiffsBySubnetIDPrefix, validatorsDB) + validatorPublicKeyDiffsByHeightDB := prefixdb.New(ValidatorPublicKeyDiffsByHeightPrefix, validatorsDB) weightsCache, err := metercacher.New( "l1_validator_weights_cache", @@ -743,27 +789,29 @@ func New( currentStakers: newBaseStakers(), pendingStakers: newBaseStakers(), - validatorsDB: validatorsDB, - currentValidatorsDB: currentValidatorsDB, - currentValidatorBaseDB: currentValidatorBaseDB, - currentValidatorList: linkeddb.NewDefault(currentValidatorBaseDB), - currentDelegatorBaseDB: currentDelegatorBaseDB, - currentDelegatorList: linkeddb.NewDefault(currentDelegatorBaseDB), - currentSubnetValidatorBaseDB: currentSubnetValidatorBaseDB, - currentSubnetValidatorList: linkeddb.NewDefault(currentSubnetValidatorBaseDB), - currentSubnetDelegatorBaseDB: currentSubnetDelegatorBaseDB, - currentSubnetDelegatorList: linkeddb.NewDefault(currentSubnetDelegatorBaseDB), - pendingValidatorsDB: pendingValidatorsDB, - pendingValidatorBaseDB: pendingValidatorBaseDB, - pendingValidatorList: linkeddb.NewDefault(pendingValidatorBaseDB), - pendingDelegatorBaseDB: pendingDelegatorBaseDB, - pendingDelegatorList: linkeddb.NewDefault(pendingDelegatorBaseDB), - pendingSubnetValidatorBaseDB: pendingSubnetValidatorBaseDB, - pendingSubnetValidatorList: linkeddb.NewDefault(pendingSubnetValidatorBaseDB), - pendingSubnetDelegatorBaseDB: pendingSubnetDelegatorBaseDB, - pendingSubnetDelegatorList: linkeddb.NewDefault(pendingSubnetDelegatorBaseDB), - validatorWeightDiffsDB: validatorWeightDiffsDB, - validatorPublicKeyDiffsDB: validatorPublicKeyDiffsDB, + validatorsDB: validatorsDB, + currentValidatorsDB: currentValidatorsDB, + currentValidatorBaseDB: currentValidatorBaseDB, + currentValidatorList: linkeddb.NewDefault(currentValidatorBaseDB), + currentDelegatorBaseDB: currentDelegatorBaseDB, + currentDelegatorList: linkeddb.NewDefault(currentDelegatorBaseDB), + currentSubnetValidatorBaseDB: currentSubnetValidatorBaseDB, + currentSubnetValidatorList: linkeddb.NewDefault(currentSubnetValidatorBaseDB), + currentSubnetDelegatorBaseDB: currentSubnetDelegatorBaseDB, + currentSubnetDelegatorList: linkeddb.NewDefault(currentSubnetDelegatorBaseDB), + pendingValidatorsDB: pendingValidatorsDB, + pendingValidatorBaseDB: pendingValidatorBaseDB, + pendingValidatorList: linkeddb.NewDefault(pendingValidatorBaseDB), + pendingDelegatorBaseDB: pendingDelegatorBaseDB, + pendingDelegatorList: linkeddb.NewDefault(pendingDelegatorBaseDB), + pendingSubnetValidatorBaseDB: pendingSubnetValidatorBaseDB, + pendingSubnetValidatorList: linkeddb.NewDefault(pendingSubnetValidatorBaseDB), + pendingSubnetDelegatorBaseDB: pendingSubnetDelegatorBaseDB, + pendingSubnetDelegatorList: linkeddb.NewDefault(pendingSubnetDelegatorBaseDB), + validatorWeightDiffsBySubnetIDDB: validatorWeightDiffsBySubnetIDDB, + validatorWeightDiffsByHeightDB: validatorWeightDiffsByHeightDB, + validatorPublicKeyDiffsBySubnetIDDB: validatorPublicKeyDiffsBySubnetIDDB, + validatorPublicKeyDiffsByHeightDB: validatorPublicKeyDiffsByHeightDB, addedTxs: make(map[ids.ID]*txAndStatus), txDB: prefixdb.New(TxPrefix, baseDB), @@ -1385,6 +1433,71 @@ func (s *state) SetCurrentSupply(subnetID ids.ID, cs uint64) { } } +func (s *state) ApplyAllValidatorWeightDiffs( + ctx context.Context, + allValidators map[ids.ID]map[ids.NodeID]*validators.GetValidatorOutput, + startHeight uint64, + endHeight uint64, +) error { + diffIter := s.validatorWeightDiffsByHeightDB.NewIteratorWithStart( + marshalStartDiffKeyByHeight(startHeight), + ) + defer diffIter.Release() + + prevHeight := startHeight + 1 + for diffIter.Next() { + if err := ctx.Err(); err != nil { + return err + } + + parsedHeight, subnetID, nodeID, err := unmarshalDiffKeyByHeight(diffIter.Key()) + if err != nil { + return err + } + + if parsedHeight > prevHeight { + s.ctx.Log.Error("unexpected parsed height", + zap.Stringer("subnetID", subnetID), + zap.Uint64("parsedHeight", parsedHeight), + zap.Stringer("nodeID", nodeID), + zap.Uint64("prevHeight", prevHeight), + zap.Uint64("startHeight", startHeight), + zap.Uint64("endHeight", endHeight), + ) + } + + // If the parsedHeight is less than our target endHeight, then we have + // fully processed the diffs from startHeight through endHeight. + if parsedHeight < endHeight { + return diffIter.Error() + } + + prevHeight = parsedHeight + + weightDiff, err := unmarshalWeightDiff(diffIter.Value()) + if err != nil { + return err + } + + vdrs, ok := allValidators[subnetID] + if !ok { + // If this subnet previously had no validators, add the map back + vdrs = make(map[ids.NodeID]*validators.GetValidatorOutput) + allValidators[subnetID] = vdrs + } + + if err := applyWeightDiff(vdrs, nodeID, weightDiff); err != nil { + return err + } + + if len(vdrs) == 0 { + // If the subnet has no validators, delete from the map + delete(allValidators, subnetID) + } + } + return diffIter.Error() +} + func (s *state) ApplyValidatorWeightDiffs( ctx context.Context, validators map[ids.NodeID]*validators.GetValidatorOutput, @@ -1392,8 +1505,8 @@ func (s *state) ApplyValidatorWeightDiffs( endHeight uint64, subnetID ids.ID, ) error { - diffIter := s.validatorWeightDiffsDB.NewIteratorWithStartAndPrefix( - marshalStartDiffKey(subnetID, startHeight), + diffIter := s.validatorWeightDiffsBySubnetIDDB.NewIteratorWithStartAndPrefix( + marshalStartDiffKeyBySubnetID(subnetID, startHeight), subnetID[:], ) defer diffIter.Release() @@ -1404,7 +1517,7 @@ func (s *state) ApplyValidatorWeightDiffs( return err } - _, parsedHeight, nodeID, err := unmarshalDiffKey(diffIter.Key()) + _, parsedHeight, nodeID, err := unmarshalDiffKeyBySubnetID(diffIter.Key()) if err != nil { return err } @@ -1477,6 +1590,51 @@ func applyWeightDiff( return nil } +func (s *state) ApplyAllValidatorPublicKeyDiffs( + ctx context.Context, + allValidators map[ids.ID]map[ids.NodeID]*validators.GetValidatorOutput, + startHeight uint64, + endHeight uint64, +) error { + diffIter := s.validatorPublicKeyDiffsByHeightDB.NewIteratorWithStart( + marshalStartDiffKeyByHeight(startHeight), + ) + defer diffIter.Release() + + for diffIter.Next() { + if err := ctx.Err(); err != nil { + return err + } + + parsedHeight, subnetID, nodeID, err := unmarshalDiffKeyByHeight(diffIter.Key()) + if err != nil { + return err + } + // If the parsedHeight is less than our target endHeight, then we have + // fully processed the diffs from startHeight through endHeight. + if parsedHeight < endHeight { + break + } + + vdr, ok := allValidators[subnetID][nodeID] + if !ok { + // A validator that is eventually removed may have a key diff before it was removed + continue + } + + pkBytes := diffIter.Value() + if len(pkBytes) == 0 { + vdr.PublicKey = nil + } else { + vdr.PublicKey = bls.PublicKeyFromValidUncompressedBytes(pkBytes) + } + } + + // Nodes may see inconsistent public keys for heights before the new public + // key index was populated. + return diffIter.Error() +} + func (s *state) ApplyValidatorPublicKeyDiffs( ctx context.Context, validators map[ids.NodeID]*validators.GetValidatorOutput, @@ -1484,8 +1642,8 @@ func (s *state) ApplyValidatorPublicKeyDiffs( endHeight uint64, subnetID ids.ID, ) error { - diffIter := s.validatorPublicKeyDiffsDB.NewIteratorWithStartAndPrefix( - marshalStartDiffKey(subnetID, startHeight), + diffIter := s.validatorPublicKeyDiffsBySubnetIDDB.NewIteratorWithStartAndPrefix( + marshalStartDiffKeyBySubnetID(subnetID, startHeight), subnetID[:], ) defer diffIter.Release() @@ -1495,7 +1653,7 @@ func (s *state) ApplyValidatorPublicKeyDiffs( return err } - _, parsedHeight, nodeID, err := unmarshalDiffKey(diffIter.Key()) + _, parsedHeight, nodeID, err := unmarshalDiffKeyBySubnetID(diffIter.Key()) if err != nil { return err } @@ -2586,19 +2744,35 @@ func (s *state) writeValidatorDiffs(height uint64) error { // Write the changes to the database for subnetIDNodeID, diff := range changes { - diffKey := marshalDiffKey(subnetIDNodeID.subnetID, height, subnetIDNodeID.nodeID) + diffKeyBySubnetID := marshalDiffKeyBySubnetID(subnetIDNodeID.subnetID, height, subnetIDNodeID.nodeID) + diffKeyByHeight := marshalDiffKeyByHeight(height, subnetIDNodeID.subnetID, subnetIDNodeID.nodeID) if diff.weightDiff.Amount != 0 { - err := s.validatorWeightDiffsDB.Put( - diffKey, - marshalWeightDiff(&diff.weightDiff), + weightDiff := marshalWeightDiff(&diff.weightDiff) + err := s.validatorWeightDiffsBySubnetIDDB.Put( + diffKeyBySubnetID, + weightDiff, + ) + if err != nil { + return err + } + err = s.validatorWeightDiffsByHeightDB.Put( + diffKeyByHeight, + weightDiff, ) if err != nil { return err } } if !bytes.Equal(diff.prevPublicKey, diff.newPublicKey) { - err := s.validatorPublicKeyDiffsDB.Put( - diffKey, + err := s.validatorPublicKeyDiffsBySubnetIDDB.Put( + diffKeyBySubnetID, + diff.prevPublicKey, + ) + if err != nil { + return err + } + err = s.validatorPublicKeyDiffsByHeightDB.Put( + diffKeyByHeight, diff.prevPublicKey, ) if err != nil { diff --git a/vms/platformvm/state/state_test.go b/vms/platformvm/state/state_test.go index 3bf069a93dc6..bd77f4266390 100644 --- a/vms/platformvm/state/state_test.go +++ b/vms/platformvm/state/state_test.go @@ -514,26 +514,43 @@ func TestState_writeStakers(t *testing.T) { ) for subnetIDNodeID, expectedDiff := range test.expectedValidatorDiffs { - diffKey := marshalDiffKey(subnetIDNodeID.subnetID, 1, subnetIDNodeID.nodeID) - weightDiffBytes, err := state.validatorWeightDiffsDB.Get(diffKey) - if expectedDiff.weightDiff.Amount == 0 { - require.ErrorIs(err, database.ErrNotFound) - } else { - require.NoError(err) - - weightDiff, err := unmarshalWeightDiff(weightDiffBytes) - require.NoError(err) - require.Equal(&expectedDiff.weightDiff, weightDiff) - } + requireValidDiff := func( + diffKey []byte, + weightDiffs database.Database, + publicKeyDiffs database.Database, + ) { + t.Helper() + + weightDiffBytes, err := weightDiffs.Get(diffKey) + if expectedDiff.weightDiff.Amount == 0 { + require.ErrorIs(err, database.ErrNotFound) + } else { + require.NoError(err) + + weightDiff, err := unmarshalWeightDiff(weightDiffBytes) + require.NoError(err) + require.Equal(&expectedDiff.weightDiff, weightDiff) + } - publicKeyDiffBytes, err := state.validatorPublicKeyDiffsDB.Get(diffKey) - if bytes.Equal(expectedDiff.prevPublicKey, expectedDiff.newPublicKey) { - require.ErrorIs(err, database.ErrNotFound) - } else { - require.NoError(err) + publicKeyDiffBytes, err := publicKeyDiffs.Get(diffKey) + if bytes.Equal(expectedDiff.prevPublicKey, expectedDiff.newPublicKey) { + require.ErrorIs(err, database.ErrNotFound) + } else { + require.NoError(err) - require.Equal(expectedDiff.prevPublicKey, publicKeyDiffBytes) + require.Equal(expectedDiff.prevPublicKey, publicKeyDiffBytes) + } } + requireValidDiff( + marshalDiffKeyBySubnetID(subnetIDNodeID.subnetID, 1, subnetIDNodeID.nodeID), + state.validatorWeightDiffsBySubnetIDDB, + state.validatorPublicKeyDiffsBySubnetIDDB, + ) + requireValidDiff( + marshalDiffKeyByHeight(1, subnetIDNodeID.subnetID, subnetIDNodeID.nodeID), + state.validatorWeightDiffsByHeightDB, + state.validatorPublicKeyDiffsByHeightDB, + ) } // re-load the state from disk for the second iteration @@ -1086,6 +1103,38 @@ func TestState_ApplyValidatorDiffs(t *testing.T) { )) require.Equal(prevDiff.expectedSubnetValidatorSet, subnetValidatorSet) } + + // Checks applying diffs to all validator sets using height-based indices + { + allValidatorSets := make(map[ids.ID]map[ids.NodeID]*validators.GetValidatorOutput) + if len(diff.expectedPrimaryValidatorSet) != 0 { + allValidatorSets[constants.PrimaryNetworkID] = copyValidatorSet(diff.expectedPrimaryValidatorSet) + } + if len(diff.expectedSubnetValidatorSet) != 0 { + allValidatorSets[subnetID] = copyValidatorSet(diff.expectedSubnetValidatorSet) + } + require.NoError(state.ApplyAllValidatorWeightDiffs( + context.Background(), + allValidatorSets, + currentHeight, + prevHeight+1, + )) + require.NoError(state.ApplyAllValidatorPublicKeyDiffs( + context.Background(), + allValidatorSets, + currentHeight, + prevHeight+1, + )) + + expectedAllValidatorSets := make(map[ids.ID]map[ids.NodeID]*validators.GetValidatorOutput) + if len(prevDiff.expectedPrimaryValidatorSet) != 0 { + expectedAllValidatorSets[constants.PrimaryNetworkID] = prevDiff.expectedPrimaryValidatorSet + } + if len(prevDiff.expectedSubnetValidatorSet) != 0 { + expectedAllValidatorSets[subnetID] = prevDiff.expectedSubnetValidatorSet + } + require.Equal(expectedAllValidatorSets, allValidatorSets) + } } } } diff --git a/vms/platformvm/validators/manager.go b/vms/platformvm/validators/manager.go index ddb63e7abe08..041eeabc723c 100644 --- a/vms/platformvm/validators/manager.go +++ b/vms/platformvm/validators/manager.go @@ -81,6 +81,23 @@ type State interface { subnetID ids.ID, ) error + // ApplyAllValidatorWeightDiffs iterates from [startHeight] towards the genesis + // block until it has applied all of the diffs up to and including + // [endHeight]. Applying the diffs modifies [validators]. + // + // Invariant: If attempting to generate the validator set for + // [endHeight - 1], [validators] must initially contain the validator + // weights for [startHeight]. + // + // Note: Because this function iterates towards the genesis, [startHeight] + // should normally be greater than or equal to [endHeight]. + ApplyAllValidatorWeightDiffs( + ctx context.Context, + validators map[ids.ID]map[ids.NodeID]*validators.GetValidatorOutput, + startHeight uint64, + endHeight uint64, + ) error + // ApplyValidatorPublicKeyDiffs iterates from [startHeight] towards the // genesis block until it has applied all of the diffs up to and including // [endHeight]. Applying the diffs modifies [validators]. @@ -99,6 +116,23 @@ type State interface { subnetID ids.ID, ) error + // ApplyAllValidatorPublicKeyDiffs iterates from [startHeight] towards the + // genesis block until it has applied all of the diffs up to and including + // [endHeight]. Applying the diffs modifies [validators]. + // + // Invariant: If attempting to generate the validator set for + // [endHeight - 1], [validators] must initially contain the validator + // weights for [startHeight]. + // + // Note: Because this function iterates towards the genesis, [startHeight] + // should normally be greater than or equal to [endHeight]. + ApplyAllValidatorPublicKeyDiffs( + ctx context.Context, + validators map[ids.ID]map[ids.NodeID]*validators.GetValidatorOutput, + startHeight uint64, + endHeight uint64, + ) error + GetCurrentValidators(ctx context.Context, subnetID ids.ID) ([]*state.Staker, []state.L1Validator, uint64, error) } @@ -196,6 +230,14 @@ func (m *manager) getCurrentHeight(context.Context) (uint64, error) { return lastAccepted.Height(), nil } +func (m *manager) GetAllValidatorSets( + ctx context.Context, + targetHeight uint64, +) (map[ids.ID]map[ids.NodeID]*validators.GetValidatorOutput, error) { + // TODO: cache all validator sets + return m.makeAllValidatorSets(ctx, targetHeight) +} + func (m *manager) GetValidatorSet( ctx context.Context, targetHeight uint64, @@ -242,6 +284,48 @@ func (m *manager) getValidatorSetCache(subnetID ids.ID) cache.Cacher[uint64, map return validatorSetsCache } +func (m *manager) makeAllValidatorSets( + ctx context.Context, + targetHeight uint64, +) (map[ids.ID]map[ids.NodeID]*validators.GetValidatorOutput, error) { + allValidators, currentHeight, err := m.getAllCurrentValidatorSets(ctx) + if err != nil { + return nil, err + } + if currentHeight < targetHeight { + return nil, fmt.Errorf("%w: current P-chain height (%d) < requested P-Chain height (%d)", + errUnfinalizedHeight, + currentHeight, + targetHeight, + ) + } + + // Rebuild subnet validators at [targetHeight] + // + // Note: Since we are attempting to generate the validator set at + // [targetHeight], we want to apply the diffs from + // (targetHeight, currentHeight]. Because the state interface is implemented + // to be inclusive, we apply diffs in [targetHeight + 1, currentHeight]. + lastDiffHeight := targetHeight + 1 + err = m.state.ApplyAllValidatorWeightDiffs( + ctx, + allValidators, + currentHeight, + lastDiffHeight, + ) + if err != nil { + return nil, err + } + + err = m.state.ApplyAllValidatorPublicKeyDiffs( + ctx, + allValidators, + currentHeight, + lastDiffHeight, + ) + return allValidators, err +} + func (m *manager) makeValidatorSet( ctx context.Context, targetHeight uint64, @@ -288,6 +372,14 @@ func (m *manager) makeValidatorSet( return validatorSet, currentHeight, err } +func (m *manager) getAllCurrentValidatorSets( + ctx context.Context, +) (map[ids.ID]map[ids.NodeID]*validators.GetValidatorOutput, uint64, error) { + subnetsMap := m.cfg.Validators.GetAllMaps() + currentHeight, err := m.getCurrentHeight(ctx) + return subnetsMap, currentHeight, err +} + func (m *manager) getCurrentValidatorSet( ctx context.Context, subnetID ids.ID,