diff --git a/.github/packer/roles/public-ami/defaults/main.yml b/.github/packer/roles/public-ami/defaults/main.yml
index 3de12954683e..1d6f3d938fdc 100644
--- a/.github/packer/roles/public-ami/defaults/main.yml
+++ b/.github/packer/roles/public-ami/defaults/main.yml
@@ -5,6 +5,5 @@ network: mainnet
 db_dir: /data/avalanchego
 log_dir: /var/log/avalanchego
 config_dir: /etc/avalanchego
-plugin_dir: /usr/local/lib/avalanchego/plugins
 repo_url: https://github.com/ava-labs/avalanchego
 repo_folder: /tmp/avalanchego
diff --git a/.github/packer/roles/public-ami/tasks/main.yml b/.github/packer/roles/public-ami/tasks/main.yml
index 6c3695f7ddbc..2388d1501f4a 100644
--- a/.github/packer/roles/public-ami/tasks/main.yml
+++ b/.github/packer/roles/public-ami/tasks/main.yml
@@ -67,14 +67,6 @@
     group: "{{ ava_group }}"
     state: directory

-- name: Create avalanche plugins directory
-  file:
-    path: "{{ plugin_dir }}"
-    owner: "{{ ava_user }}"
-    group: "{{ ava_group }}"
-    state: directory
-    recurse: yes
-
 - name: Build avalanchego
   command: ./scripts/build.sh
   args:
@@ -85,12 +77,6 @@
   args:
     chdir: "{{ repo_folder }}"

-- name: Copy evm binaries to the correct location
-  command: cp build/plugins/evm {{ plugin_dir }}
-  args:
-    chdir: "{{ repo_folder }}"
-
-
 - name: Configure avalanche
   template:
     src: templates/conf.json.j2
diff --git a/.github/packer/ubuntu-focal-x86_64-public-ami.json b/.github/packer/ubuntu-focal-x86_64-public-ami.json
index 2110760deaea..0c2a639c4286 100644
--- a/.github/packer/ubuntu-focal-x86_64-public-ami.json
+++ b/.github/packer/ubuntu-focal-x86_64-public-ami.json
@@ -45,6 +45,7 @@
       "type": "ansible",
       "playbook_file": ".github/packer/create_public_ami.yml",
       "roles_path": ".github/packer/roles/",
+      "use_proxy": false,
       "extra_arguments": ["-e", "component=public-ami build=packer os_release=focal tag={{user `tag`}}"]
     },
     {
diff --git a/.github/workflows/build-rpm-pkg.sh b/.github/workflows/build-rpm-pkg.sh
deleted file mode 100755
index d1850f3fe0b3..000000000000
--- a/.github/workflows/build-rpm-pkg.sh
+++ /dev/null
@@ -1,25 +0,0 @@
-PKG_ROOT=/tmp/avalanchego
-RPM_BASE_DIR=$PKG_ROOT/yum
-AVALANCHE_BUILD_BIN_DIR=$RPM_BASE_DIR/usr/local/bin
-AVALANCHE_LIB_DIR=$RPM_BASE_DIR/usr/local/lib/avalanchego
-
-mkdir -p $RPM_BASE_DIR
-mkdir -p $AVALANCHE_BUILD_BIN_DIR
-mkdir -p $AVALANCHE_LIB_DIR
-
-OK=`cp ./build/avalanchego $AVALANCHE_BUILD_BIN_DIR`
-if [[ $OK -ne 0 ]]; then
-  exit $OK;
-fi
-OK=`cp ./build/plugins/evm $AVALANCHE_LIB_DIR`
-if [[ $OK -ne 0 ]]; then
-  exit $OK;
-fi
-
-echo "Build rpm package..."
-VER=$(echo $TAG | gawk -F- '{print$1}' | tr -d 'v' )
-REL=$(echo $TAG | gawk -F- '{print$2}')
-[ -z "$REL" ] && REL=0
-echo "Tag: $VER"
-rpmbuild --bb --define "version $VER" --define "release $REL" --buildroot $RPM_BASE_DIR .github/workflows/yum/specfile/avalanchego.spec
-aws s3 cp ~/rpmbuild/RPMS/x86_64/avalanchego-*.rpm s3://$BUCKET/linux/rpm/
diff --git a/.github/workflows/build-tgz-pkg.sh b/.github/workflows/build-tgz-pkg.sh
index 756ab0637148..413d9fce802f 100755
--- a/.github/workflows/build-tgz-pkg.sh
+++ b/.github/workflows/build-tgz-pkg.sh
@@ -8,10 +8,6 @@ OK=`cp ./build/avalanchego $AVALANCHE_ROOT`
 if [[ $OK -ne 0 ]]; then
   exit $OK;
 fi
-OK=`cp -r ./build/plugins $AVALANCHE_ROOT`
-if [[ $OK -ne 0 ]]; then
-  exit $OK;
-fi

 echo "Build tgz package..."
diff --git a/.github/workflows/update-ami.py b/.github/workflows/update-ami.py
index 5f5bcb633c90..87434ff6d298 100755
--- a/.github/workflows/update-ami.py
+++ b/.github/workflows/update-ami.py
@@ -10,6 +10,7 @@
 file = '.github/workflows/amichange.json'
 packerfile = ".github/packer/ubuntu-focal-x86_64-public-ami.json"
+update_marketplace = True
 product_id = os.getenv('PRODUCT_ID')
 role_arn = os.getenv('ROLE_ARN')
 vtag = os.getenv('TAG')
@@ -18,14 +19,12 @@
 for var in variables:
   if var is None:
-    print("A Variable is not set correctly or this is not the right repo. Exiting.")
-    exit(0)
+    print("A Variable is not set correctly or this is not the right repo. Only validating packer.")
+    update_marketplace = False

 if 'rc' in tag:
-  print("This is a release candidate. Nothing to do.")
-  exit(0)
-
-client = boto3.client('marketplace-catalog',region_name='us-east-1')
+  print("This is a release candidate. Only validating packer.")
+  update_marketplace = False

 def packer_build(packerfile):
   p = packer.Packer(packerfile)
@@ -44,24 +43,28 @@ def parse_amichange(object):

 amiid=packer_build(packerfile)

-try:
-  response = client.start_change_set(
-    Catalog='AWSMarketplace',
-    ChangeSet=[
-      {
-        'ChangeType': 'AddDeliveryOptions',
-        'Entity': {
-          'Type': 'AmiProduct@1.0',
-          'Identifier': product_id
-        },
-        'Details': parse_amichange(file),
-        'ChangeName': 'Update'
-      },
-    ],
-    ChangeSetName='AvalancheGo Update ' + tag,
-    ClientRequestToken=uid
-  )
-  print(response)
-except client.exceptions.ResourceInUseException:
-  print("The product is currently blocked by Amazon. Please check the product site for more details")
+if update_marketplace:
+
+  client = boto3.client('marketplace-catalog',region_name='us-east-1')
+
+  try:
+    response = client.start_change_set(
+      Catalog='AWSMarketplace',
+      ChangeSet=[
+        {
+          'ChangeType': 'AddDeliveryOptions',
+          'Entity': {
+            'Type': 'AmiProduct@1.0',
+            'Identifier': product_id
+          },
+          'Details': parse_amichange(file),
+          'ChangeName': 'Update'
+        },
+      ],
+      ChangeSetName='AvalancheGo Update ' + tag,
+      ClientRequestToken=uid
+    )
+    print(response)
+  except client.exceptions.ResourceInUseException:
+    print("The product is currently blocked by Amazon. Please check the product site for more details")
diff --git a/.github/workflows/yum/specfile/avalanchego.spec b/.github/workflows/yum/specfile/avalanchego.spec
deleted file mode 100644
index e61ede93dbcf..000000000000
--- a/.github/workflows/yum/specfile/avalanchego.spec
+++ /dev/null
@@ -1,22 +0,0 @@
-%define _build_id_links none
-
-Name: avalanchego
-Version: %{version}
-Release: %{release}
-Summary: The Avalanche platform binaries
-URL: https://github.com/ava-labs/%{name}
-License: BSD-3
-AutoReqProv: no
-
-%description
-Avalanche is an incredibly lightweight protocol, so the minimum computer requirements are quite modest.
-
-%files
-/usr/local/bin/avalanchego
-/usr/local/lib/avalanchego
-/usr/local/lib/avalanchego/evm
-
-%changelog
-* Mon Oct 26 2020 Charlie Wyse
-- First creation of package
-
diff --git a/api/health/health_test.go b/api/health/health_test.go
index b506429e7d0f..8b737b1a5b04 100644
--- a/api/health/health_test.go
+++ b/api/health/health_test.go
@@ -192,9 +192,9 @@ func TestPassingChecks(t *testing.T) {
 func TestPassingThenFailingChecks(t *testing.T) {
     require := require.New(t)

-    var shouldCheckErr utils.AtomicBool
+    var shouldCheckErr utils.Atomic[bool]
     check := CheckerFunc(func(context.Context) (interface{}, error) {
-        if shouldCheckErr.GetValue() {
+        if shouldCheckErr.Get() {
             return errUnhealthy.Error(), errUnhealthy
         }
         return "", nil
@@ -228,7 +228,7 @@ func TestPassingThenFailingChecks(t *testing.T) {
         require.True(liveness)
     }

-    shouldCheckErr.SetValue(true)
+    shouldCheckErr.Set(true)

     awaitHealthy(h, false)
     awaitLiveness(h, false)
diff --git a/api/health/worker.go b/api/health/worker.go
index 38e4c9e42a51..4bcdb3a22acc 100644
--- a/api/health/worker.go
+++ b/api/health/worker.go
@@ -60,16 +60,16 @@ func (w *worker) RegisterCheck(name string, checker Checker) error {
 }

 func (w *worker) RegisterMonotonicCheck(name string, checker Checker) error {
-    var result utils.AtomicInterface
-    return w.RegisterCheck(name, CheckerFunc(func(ctx context.Context) (interface{}, error) {
-        details := result.GetValue()
+    var result utils.Atomic[any]
+    return w.RegisterCheck(name, CheckerFunc(func(ctx context.Context) (any, error) {
+        details := result.Get()
         if details != nil {
             return details, nil
         }

         details, err := checker.HealthCheck(ctx)
         if err == nil {
-            result.SetValue(details)
+            result.Set(details)
         }
         return details, err
     }))
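The health changes above are part of a repo-wide migration from `utils.AtomicBool`/`utils.AtomicInterface` to a generics-based `utils.Atomic[T]`. As a rough sketch of what such a wrapper can look like (the real `utils.Atomic` may differ in implementation detail, e.g. its locking strategy):

```go
package utils

import "sync"

// Atomic is a minimal sketch of a generic thread-safe value holder,
// standing in for the old AtomicBool/AtomicInterface pair.
type Atomic[T any] struct {
	lock  sync.RWMutex
	value T
}

// Get returns the current value.
func (a *Atomic[T]) Get() T {
	a.lock.RLock()
	defer a.lock.RUnlock()
	return a.value
}

// Set replaces the current value.
func (a *Atomic[T]) Set(value T) {
	a.lock.Lock()
	defer a.lock.Unlock()
	a.value = value
}
```

One type now covers `Atomic[bool]`, `Atomic[any]`, and `Atomic[int]` call sites without per-type wrappers or assertions.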
diff --git a/api/server/mock_server.go b/api/server/mock_server.go
index 64d915a4819d..9c2b9bf7d9f4 100644
--- a/api/server/mock_server.go
+++ b/api/server/mock_server.go
@@ -11,6 +11,7 @@ import (
     reflect "reflect"
     sync "sync"

+    snow "github.com/ava-labs/avalanchego/snow"
     common "github.com/ava-labs/avalanchego/snow/engine/common"
     gomock "github.com/golang/mock/gomock"
 )
@@ -133,15 +134,15 @@ func (mr *MockServerMockRecorder) DispatchTLS(arg0, arg1 interface{}) *gomock.Ca
 }

 // RegisterChain mocks base method.
-func (m *MockServer) RegisterChain(arg0 string, arg1 common.Engine) {
+func (m *MockServer) RegisterChain(arg0 string, arg1 *snow.ConsensusContext, arg2 common.VM) {
     m.ctrl.T.Helper()
-    m.ctrl.Call(m, "RegisterChain", arg0, arg1)
+    m.ctrl.Call(m, "RegisterChain", arg0, arg1, arg2)
 }

 // RegisterChain indicates an expected call of RegisterChain.
-func (mr *MockServerMockRecorder) RegisterChain(arg0, arg1 interface{}) *gomock.Call {
+func (mr *MockServerMockRecorder) RegisterChain(arg0, arg1, arg2 interface{}) *gomock.Call {
     mr.mock.ctrl.T.Helper()
-    return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterChain", reflect.TypeOf((*MockServer)(nil).RegisterChain), arg0, arg1)
+    return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterChain", reflect.TypeOf((*MockServer)(nil).RegisterChain), arg0, arg1, arg2)
 }

 // Shutdown mocks base method.
diff --git a/api/server/server.go b/api/server/server.go
index e908e41a747b..78fa1cdd2262 100644
--- a/api/server/server.go
+++ b/api/server/server.go
@@ -71,14 +71,10 @@ type Server interface {
     Dispatch() error
     // DispatchTLS starts the API server with the provided TLS certificate
     DispatchTLS(certBytes, keyBytes []byte) error
-    // RegisterChain registers the API endpoints associated with this chain. That is,
-    // add pairs to server so that API calls can be made to the VM.
-    // This method runs in a goroutine to avoid a deadlock in the event that the caller
-    // holds the engine's context lock. Namely, this could happen when the P-Chain is
-    // creating a new chain and holds the P-Chain's lock when this function is held,
-    // and at the same time the server's lock is held due to an API call and is trying
-    // to grab the P-Chain's lock.
-    RegisterChain(chainName string, engine common.Engine)
+    // RegisterChain registers the API endpoints associated with this chain.
+    // That is, add pairs to server so that API calls can be
+    // made to the VM.
+    RegisterChain(chainName string, ctx *snow.ConsensusContext, vm common.VM)
     // Shutdown this server
     Shutdown() error
 }
@@ -217,19 +213,14 @@ func (s *server) DispatchTLS(certBytes, keyBytes []byte) error {
     return s.srv.Serve(listener)
 }

-func (s *server) RegisterChain(chainName string, engine common.Engine) {
-    go s.registerChain(chainName, engine)
-}
-
-func (s *server) registerChain(chainName string, engine common.Engine) {
+func (s *server) RegisterChain(chainName string, ctx *snow.ConsensusContext, vm common.VM) {
     var (
         handlers map[string]*common.HTTPHandler
         err      error
     )

-    ctx := engine.Context()
     ctx.Lock.Lock()
-    handlers, err = engine.GetVM().CreateHandlers(context.TODO())
+    handlers, err = vm.CreateHandlers(context.TODO())
     ctx.Lock.Unlock()
     if err != nil {
         s.log.Error("failed to create handlers",
@@ -377,7 +368,7 @@ func lockMiddleware(
 // not done state-syncing/bootstrapping, writes back an error.
 func rejectMiddleware(handler http.Handler, ctx *snow.ConsensusContext) http.Handler {
     return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // If chain isn't done bootstrapping, ignore API calls
-        if ctx.GetState() != snow.NormalOp {
+        if ctx.State.Get().State != snow.NormalOp {
             w.WriteHeader(http.StatusServiceUnavailable)
             // Doesn't matter if there's an error while writing. They'll get the StatusServiceUnavailable code.
             _, _ = w.Write([]byte("API call rejected because chain is not done bootstrapping"))
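`RegisterChain` now receives the chain's `*snow.ConsensusContext` and `common.VM` directly instead of the consensus engine, so registrants no longer reach through `engine.Context()`/`engine.GetVM()`. A sketch of a registrant written against the new signature; the type and its logging are illustrative, only the signature and the lock-around-`CreateHandlers` pattern come from the diff:

```go
package example

import (
	"context"

	"go.uber.org/zap"

	"github.com/ava-labs/avalanchego/snow"
	"github.com/ava-labs/avalanchego/snow/engine/common"
	"github.com/ava-labs/avalanchego/utils/logging"
)

// logRegistrant is a hypothetical chains.Registrant that just logs the
// routes a new chain's VM exposes.
type logRegistrant struct {
	log logging.Logger
}

func (r *logRegistrant) RegisterChain(chainName string, ctx *snow.ConsensusContext, vm common.VM) {
	// Hold the chain's context lock only around handler creation,
	// mirroring what server.RegisterChain does above.
	ctx.Lock.Lock()
	handlers, err := vm.CreateHandlers(context.TODO())
	ctx.Lock.Unlock()
	if err != nil {
		r.log.Error("failed to create handlers", zap.Error(err))
		return
	}
	for route := range handlers {
		r.log.Info("VM exposes endpoint",
			zap.String("chainName", chainName),
			zap.String("route", route),
		)
	}
}
```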
diff --git a/cache/cache.go b/cache/cache.go
index f69dca5e2324..82ddfb70a8de 100644
--- a/cache/cache.go
+++ b/cache/cache.go
@@ -3,37 +3,34 @@

 package cache

-// Cacher acts as a best effort key value store. Keys must be comparable, as
-// defined by https://golang.org/ref/spec#Comparison_operators.
-type Cacher interface {
-    // Put inserts an element into the cache. If spaced is required, elements will
+// Cacher acts as a best effort key value store.
+type Cacher[K comparable, V any] interface {
+    // Put inserts an element into the cache. If space is required, elements will
     // be evicted.
-    Put(key, value interface{})
+    Put(key K, value V)

     // Get returns the entry in the cache with the key specified, if no value
     // exists, false is returned.
-    Get(key interface{}) (interface{}, bool)
+    Get(key K) (V, bool)

     // Evict removes the specified entry from the cache
-    Evict(key interface{})
+    Evict(key K)

     // Flush removes all entries from the cache
     Flush()
 }

 // Evictable allows the object to be notified when it is evicted
-type Evictable interface {
-    // Key must return a comparable value as defined by
-    // https://golang.org/ref/spec#Comparison_operators.
-    Key() interface{}
+type Evictable[K comparable] interface {
+    Key() K
     Evict()
 }

 // Deduplicator acts as a best effort deduplication service
-type Deduplicator interface {
+type Deduplicator[K comparable, V Evictable[K]] interface {
     // Deduplicate returns either the provided value, or a previously provided
     // value with the same ID that hasn't yet been evicted
-    Deduplicate(Evictable) Evictable
+    Deduplicate(V) V

     // Flush removes all entries from the cache
     Flush()
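With `Cacher` parameterized over key and value types, callers get compile-time typing instead of `interface{}` plus assertions. A small self-contained usage sketch (the cache size and stored values are arbitrary):

```go
package main

import (
	"fmt"

	"github.com/ava-labs/avalanchego/cache"
	"github.com/ava-labs/avalanchego/ids"
)

func main() {
	// A type-safe cache: keys are ids.ID, values are int.
	c := &cache.LRU[ids.ID, int]{Size: 10}

	id := ids.GenerateTestID()
	c.Put(id, 42)

	if v, ok := c.Get(id); ok {
		fmt.Println(v + 1) // v is already an int; no v.(int) assertion
	}
}
```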
diff --git a/cache/lru_cache.go b/cache/lru_cache.go
index b186df864405..dfe6cecab09b 100644
--- a/cache/lru_cache.go
+++ b/cache/lru_cache.go
@@ -6,58 +6,60 @@ package cache
 import (
     "container/list"
     "sync"
+
+    "github.com/ava-labs/avalanchego/utils"
 )

 const minCacheSize = 32

-var _ Cacher = (*LRU)(nil)
+var _ Cacher[struct{}, struct{}] = (*LRU[struct{}, struct{}])(nil)

-type entry struct {
-    Key   interface{}
-    Value interface{}
+type entry[K comparable, V any] struct {
+    Key   K
+    Value V
 }

 // LRU is a key value store with bounded size. If the size is attempted to be
 // exceeded, then an element is removed from the cache before the insertion is
 // done, based on evicting the least recently used value.
-type LRU struct {
+type LRU[K comparable, _ any] struct {
     lock      sync.Mutex
-    entryMap  map[interface{}]*list.Element
+    entryMap  map[K]*list.Element
     entryList *list.List
     Size      int
 }

-func (c *LRU) Put(key, value interface{}) {
+func (c *LRU[K, V]) Put(key K, value V) {
     c.lock.Lock()
     defer c.lock.Unlock()

     c.put(key, value)
 }

-func (c *LRU) Get(key interface{}) (interface{}, bool) {
+func (c *LRU[K, V]) Get(key K) (V, bool) {
     c.lock.Lock()
     defer c.lock.Unlock()

     return c.get(key)
 }

-func (c *LRU) Evict(key interface{}) {
+func (c *LRU[K, _]) Evict(key K) {
     c.lock.Lock()
     defer c.lock.Unlock()

     c.evict(key)
 }

-func (c *LRU) Flush() {
+func (c *LRU[_, _]) Flush() {
     c.lock.Lock()
     defer c.lock.Unlock()

     c.flush()
 }

-func (c *LRU) init() {
+func (c *LRU[K, _]) init() {
     if c.entryMap == nil {
-        c.entryMap = make(map[interface{}]*list.Element, minCacheSize)
+        c.entryMap = make(map[K]*list.Element, minCacheSize)
     }
     if c.entryList == nil {
         c.entryList = list.New()
@@ -67,17 +69,17 @@ func (c *LRU) init() {
     }
 }

-func (c *LRU) resize() {
+func (c *LRU[K, V]) resize() {
     for c.entryList.Len() > c.Size {
         e := c.entryList.Front()
         c.entryList.Remove(e)

-        val := e.Value.(*entry)
+        val := e.Value.(*entry[K, V])
         delete(c.entryMap, val.Key)
     }
 }

-func (c *LRU) put(key, value interface{}) {
+func (c *LRU[K, V]) put(key K, value V) {
     c.init()
     c.resize()

@@ -86,12 +88,12 @@ func (c *LRU) put(key, value interface{}) {
             e = c.entryList.Front()
             c.entryList.MoveToBack(e)

-            val := e.Value.(*entry)
+            val := e.Value.(*entry[K, V])
             delete(c.entryMap, val.Key)
             val.Key = key
             val.Value = value
         } else {
-            e = c.entryList.PushBack(&entry{
+            e = c.entryList.PushBack(&entry[K, V]{
                 Key:   key,
                 Value: value,
             })
@@ -100,25 +102,25 @@ func (c *LRU) put(key, value interface{}) {
     } else {
         c.entryList.MoveToBack(e)

-        val := e.Value.(*entry)
+        val := e.Value.(*entry[K, V])
         val.Value = value
     }
 }

-func (c *LRU) get(key interface{}) (interface{}, bool) {
+func (c *LRU[K, V]) get(key K) (V, bool) {
     c.init()
     c.resize()

     if e, ok := c.entryMap[key]; ok {
         c.entryList.MoveToBack(e)

-        val := e.Value.(*entry)
+        val := e.Value.(*entry[K, V])
         return val.Value, true
     }
-    return struct{}{}, false
+    return utils.Zero[V](), false
 }

-func (c *LRU) evict(key interface{}) {
+func (c *LRU[K, _]) evict(key K) {
     c.init()
     c.resize()

@@ -128,9 +130,9 @@ func (c *LRU) evict(key interface{}) {
     }
 }

-func (c *LRU) flush() {
+func (c *LRU[K, _]) flush() {
     c.init()

-    c.entryMap = make(map[interface{}]*list.Element, minCacheSize)
+    c.entryMap = make(map[K]*list.Element, minCacheSize)
     c.entryList = list.New()
 }
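The generic `get` above can no longer return `struct{}{}` for a miss, since the value type is now `V`; it returns `utils.Zero[V]()` instead. A helper like this is a one-liner in Go generics (a minimal sketch; the real `utils.Zero` should be equivalent):

```go
// Zero returns the zero value of any type T, e.g. 0 for int,
// "" for string, nil for pointer types.
func Zero[T any]() T {
	var zero T
	return zero
}
```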
diff --git a/cache/lru_cache_benchmark_test.go b/cache/lru_cache_benchmark_test.go
index 20c45e62d0b8..4d4a8c7b3030 100644
--- a/cache/lru_cache_benchmark_test.go
+++ b/cache/lru_cache_benchmark_test.go
@@ -12,7 +12,7 @@ import (
 func BenchmarkLRUCachePutSmall(b *testing.B) {
     smallLen := 5
-    cache := &LRU{Size: smallLen}
+    cache := &LRU[ids.ID, int]{Size: smallLen}
     for n := 0; n < b.N; n++ {
         for i := 0; i < smallLen; i++ {
             var id ids.ID
@@ -29,7 +29,7 @@ func BenchmarkLRUCachePutSmall(b *testing.B) {
 func BenchmarkLRUCachePutMedium(b *testing.B) {
     mediumLen := 250
-    cache := &LRU{Size: mediumLen}
+    cache := &LRU[ids.ID, int]{Size: mediumLen}
     for n := 0; n < b.N; n++ {
         for i := 0; i < mediumLen; i++ {
             var id ids.ID
@@ -46,7 +46,7 @@ func BenchmarkLRUCachePutMedium(b *testing.B) {
 func BenchmarkLRUCachePutLarge(b *testing.B) {
     largeLen := 10000
-    cache := &LRU{Size: largeLen}
+    cache := &LRU[ids.ID, int]{Size: largeLen}
     for n := 0; n < b.N; n++ {
         for i := 0; i < largeLen; i++ {
             var id ids.ID
diff --git a/cache/lru_cache_test.go b/cache/lru_cache_test.go
index 0228c42a4b5d..b7ca773d1beb 100644
--- a/cache/lru_cache_test.go
+++ b/cache/lru_cache_test.go
@@ -10,19 +10,19 @@ import (
 )

 func TestLRU(t *testing.T) {
-    cache := &LRU{Size: 1}
+    cache := &LRU[ids.ID, int]{Size: 1}

     TestBasic(t, cache)
 }

 func TestLRUEviction(t *testing.T) {
-    cache := &LRU{Size: 2}
+    cache := &LRU[ids.ID, int]{Size: 2}

     TestEviction(t, cache)
 }

 func TestLRUResize(t *testing.T) {
-    cache := LRU{Size: 2}
+    cache := LRU[ids.ID, int]{Size: 2}

     id1 := ids.ID{1}
     id2 := ids.ID{2}
diff --git a/cache/metercacher/cache.go b/cache/metercacher/cache.go
index 7a8b7106651d..b7b367ee613d 100644
--- a/cache/metercacher/cache.go
+++ b/cache/metercacher/cache.go
@@ -10,32 +10,32 @@ import (
     "github.com/ava-labs/avalanchego/utils/timer/mockable"
 )

-var _ cache.Cacher = (*Cache)(nil)
+var _ cache.Cacher[struct{}, struct{}] = (*Cache[struct{}, struct{}])(nil)

-type Cache struct {
+type Cache[K comparable, V any] struct {
     metrics
-    cache.Cacher
+    cache.Cacher[K, V]

     clock mockable.Clock
 }

-func New(
+func New[K comparable, V any](
     namespace string,
     registerer prometheus.Registerer,
-    cache cache.Cacher,
-) (cache.Cacher, error) {
-    meterCache := &Cache{Cacher: cache}
+    cache cache.Cacher[K, V],
+) (cache.Cacher[K, V], error) {
+    meterCache := &Cache[K, V]{Cacher: cache}
     return meterCache, meterCache.metrics.Initialize(namespace, registerer)
 }

-func (c *Cache) Put(key, value interface{}) {
+func (c *Cache[K, V]) Put(key K, value V) {
     start := c.clock.Time()
     c.Cacher.Put(key, value)
     end := c.clock.Time()
     c.put.Observe(float64(end.Sub(start)))
 }

-func (c *Cache) Get(key interface{}) (interface{}, bool) {
+func (c *Cache[K, V]) Get(key K) (V, bool) {
     start := c.clock.Time()
     value, has := c.Cacher.Get(key)
     end := c.clock.Time()
diff --git a/cache/metercacher/cache_test.go b/cache/metercacher/cache_test.go
index 64c547e16d96..46b64d673512 100644
--- a/cache/metercacher/cache_test.go
+++ b/cache/metercacher/cache_test.go
@@ -9,12 +9,13 @@ import (
     "github.com/prometheus/client_golang/prometheus"

     "github.com/ava-labs/avalanchego/cache"
+    "github.com/ava-labs/avalanchego/ids"
 )

 func TestInterface(t *testing.T) {
     for _, test := range cache.CacherTests {
-        cache := &cache.LRU{Size: test.Size}
-        c, err := New("", prometheus.NewRegistry(), cache)
+        cache := &cache.LRU[ids.ID, int]{Size: test.Size}
+        c, err := New[ids.ID, int]("", prometheus.NewRegistry(), cache)
         if err != nil {
             t.Fatal(err)
         }
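`metercacher` illustrates a decorator over the generic interface: it embeds the wrapped `cache.Cacher[K, V]` so `Evict`/`Flush` pass through untouched, and overrides `Put`/`Get` to record timings. A trimmed sketch of that pattern with a hypothetical histogram field (the real package records metrics via its `metrics` struct and a mockable clock):

```go
// meterCache shows the embedding-based decorator idea; names here are
// illustrative, not the metercacher package's API.
type meterCache[K comparable, V any] struct {
	cache.Cacher[K, V] // embedded: Evict/Flush forwarded automatically

	putDuration prometheus.Histogram
}

func (c *meterCache[K, V]) Put(key K, value V) {
	start := time.Now()
	c.Cacher.Put(key, value)
	c.putDuration.Observe(float64(time.Since(start)))
}
```

Because the wrapper is itself generic, one implementation meters every `Cacher[K, V]` instantiation.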
diff --git a/cache/test_cacher.go b/cache/test_cacher.go
index 32c8e7954c6d..5bd173c3fb1a 100644
--- a/cache/test_cacher.go
+++ b/cache/test_cacher.go
@@ -12,13 +12,13 @@ import (
 // CacherTests is a list of all Cacher tests
 var CacherTests = []struct {
     Size int
-    Func func(t *testing.T, c Cacher)
+    Func func(t *testing.T, c Cacher[ids.ID, int])
 }{
     {Size: 1, Func: TestBasic},
     {Size: 2, Func: TestEviction},
 }

-func TestBasic(t *testing.T, cache Cacher) {
+func TestBasic(t *testing.T, cache Cacher[ids.ID, int]) {
     id1 := ids.ID{1}
     if _, found := cache.Get(id1); found {
         t.Fatalf("Retrieved value when none exists")
@@ -60,7 +60,7 @@ func TestBasic(t *testing.T, cache Cacher) {
     }
 }

-func TestEviction(t *testing.T, cache Cacher) {
+func TestEviction(t *testing.T, cache Cacher[ids.ID, int]) {
     id1 := ids.ID{1}
     id2 := ids.ID{2}
     id3 := ids.ID{3}
diff --git a/cache/unique_cache.go b/cache/unique_cache.go
index a825a6e8f689..e9b2de00174b 100644
--- a/cache/unique_cache.go
+++ b/cache/unique_cache.go
@@ -8,33 +8,33 @@ import (
     "sync"
 )

-var _ Deduplicator = (*EvictableLRU)(nil)
+var _ Deduplicator[struct{}, Evictable[struct{}]] = (*EvictableLRU[struct{}, Evictable[struct{}]])(nil)

 // EvictableLRU is an LRU cache that notifies the objects when they are evicted.
-type EvictableLRU struct {
+type EvictableLRU[K comparable, _ Evictable[K]] struct {
     lock      sync.Mutex
-    entryMap  map[interface{}]*list.Element
+    entryMap  map[K]*list.Element
     entryList *list.List
     Size      int
 }

-func (c *EvictableLRU) Deduplicate(value Evictable) Evictable {
+func (c *EvictableLRU[_, V]) Deduplicate(value V) V {
     c.lock.Lock()
     defer c.lock.Unlock()

     return c.deduplicate(value)
 }

-func (c *EvictableLRU) Flush() {
+func (c *EvictableLRU[_, _]) Flush() {
     c.lock.Lock()
     defer c.lock.Unlock()

     c.flush()
 }

-func (c *EvictableLRU) init() {
+func (c *EvictableLRU[K, _]) init() {
     if c.entryMap == nil {
-        c.entryMap = make(map[interface{}]*list.Element)
+        c.entryMap = make(map[K]*list.Element)
     }
     if c.entryList == nil {
         c.entryList = list.New()
@@ -44,18 +44,18 @@ func (c *EvictableLRU) init() {
     }
 }

-func (c *EvictableLRU) resize() {
+func (c *EvictableLRU[_, V]) resize() {
     for c.entryList.Len() > c.Size {
         e := c.entryList.Front()
         c.entryList.Remove(e)

-        val := e.Value.(Evictable)
+        val := e.Value.(V)
         delete(c.entryMap, val.Key())
         val.Evict()
     }
 }

-func (c *EvictableLRU) deduplicate(value Evictable) Evictable {
+func (c *EvictableLRU[_, V]) deduplicate(value V) V {
     c.init()
     c.resize()

@@ -65,7 +65,7 @@ func (c *EvictableLRU) deduplicate(value Evictable) Evictable {
             e = c.entryList.Front()
             c.entryList.MoveToBack(e)

-            val := e.Value.(Evictable)
+            val := e.Value.(V)
             delete(c.entryMap, val.Key())
             val.Evict()

@@ -77,13 +77,13 @@ func (c *EvictableLRU) deduplicate(value Evictable) Evictable {
     } else {
         c.entryList.MoveToBack(e)

-        val := e.Value.(Evictable)
+        val := e.Value.(V)
         value = val
     }
     return value
 }

-func (c *EvictableLRU) flush() {
+func (c *EvictableLRU[_, _]) flush() {
     c.init()

     size := c.Size
diff --git a/cache/unique_cache_test.go b/cache/unique_cache_test.go
index fc409d3e07d7..0b610b30348c 100644
--- a/cache/unique_cache_test.go
+++ b/cache/unique_cache_test.go
@@ -9,35 +9,35 @@ import (
     "github.com/ava-labs/avalanchego/ids"
 )

-type evictable struct {
-    id      ids.ID
+type evictable[K comparable] struct {
+    id      K
     evicted int
 }

-func (e *evictable) Key() interface{} {
+func (e *evictable[K]) Key() K {
     return e.id
 }

-func (e *evictable) Evict() {
+func (e *evictable[_]) Evict() {
     e.evicted++
 }

 func TestEvictableLRU(t *testing.T) {
-    cache := EvictableLRU{}
+    cache := EvictableLRU[ids.ID, *evictable[ids.ID]]{}

-    expectedValue1 := &evictable{id: ids.ID{1}}
-    if returnedValue := cache.Deduplicate(expectedValue1).(*evictable); returnedValue != expectedValue1 {
+    expectedValue1 := &evictable[ids.ID]{id: ids.ID{1}}
+    if returnedValue := cache.Deduplicate(expectedValue1); returnedValue != expectedValue1 {
         t.Fatalf("Returned unknown value")
     } else if expectedValue1.evicted != 0 {
         t.Fatalf("Value was evicted unexpectedly")
-    } else if returnedValue := cache.Deduplicate(expectedValue1).(*evictable); returnedValue != expectedValue1 {
+    } else if returnedValue := cache.Deduplicate(expectedValue1); returnedValue != expectedValue1 {
         t.Fatalf("Returned unknown value")
     } else if expectedValue1.evicted != 0 {
         t.Fatalf("Value was evicted unexpectedly")
     }

-    expectedValue2 := &evictable{id: ids.ID{2}}
-    returnedValue := cache.Deduplicate(expectedValue2).(*evictable)
+    expectedValue2 := &evictable[ids.ID]{id: ids.ID{2}}
+    returnedValue := cache.Deduplicate(expectedValue2)
     switch {
     case returnedValue != expectedValue2:
         t.Fatalf("Returned unknown value")
@@ -49,8 +49,8 @@ func TestEvictableLRU(t *testing.T) {

     cache.Size = 2

-    expectedValue3 := &evictable{id: ids.ID{2}}
-    returnedValue = cache.Deduplicate(expectedValue3).(*evictable)
+    expectedValue3 := &evictable[ids.ID]{id: ids.ID{2}}
+    returnedValue = cache.Deduplicate(expectedValue3)
     switch {
     case returnedValue != expectedValue2:
         t.Fatalf("Returned unknown value")
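As the test above shows, `Deduplicate` now returns the caller's concrete type, so the `.(*evictable)` assertions disappear. A minimal sketch of implementing `Evictable[K]` outside the test suite (type names are illustrative):

```go
// myEntry satisfies cache.Evictable[ids.ID]: the cache calls Evict()
// when the entry is displaced, letting it release any held resources.
type myEntry struct {
	id ids.ID
}

func (e *myEntry) Key() ids.ID { return e.id }
func (e *myEntry) Evict()      { /* release resources here */ }

func demo() *myEntry {
	c := cache.EvictableLRU[ids.ID, *myEntry]{Size: 2}
	a := &myEntry{id: ids.ID{1}}
	// Returns *myEntry directly; no type assertion needed.
	return c.Deduplicate(a)
}
```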
diff --git a/chains/manager.go b/chains/manager.go
index d5ebd16affac..11cd0f1b4e42 100644
--- a/chains/manager.go
+++ b/chains/manager.go
@@ -139,7 +139,8 @@ type ChainParameters struct {

 type chain struct {
     Name    string
-    Engine  common.Engine
+    Context *snow.ConsensusContext
+    VM      common.VM
     Handler handler.Handler
     Beacons validators.Set
 }
@@ -373,7 +374,7 @@ func (m *manager) createChain(chainParams ChainParameters) {
     }

     // Notify those that registered to be notified when a new chain is created
-    m.notifyRegistrants(chain.Name, chain.Engine)
+    m.notifyRegistrants(chain.Name, chain.Context, chain.VM)

     // Allows messages to be routed to the new chain. If the handler hasn't been
     // started and a message is forwarded, then the message will block until the
@@ -454,14 +455,9 @@ func (m *manager) buildChain(chainParams ChainParameters, sb Subnet) (*chain, er
         ConsensusAcceptor: m.ConsensusAcceptorGroup,
         Registerer:        consensusMetrics,
     }
-    // We set the state to Initializing here because failing to set the state
-    // before it's first access would cause a panic.
-    ctx.SetState(snow.Initializing)

     if subnetConfig, ok := m.SubnetConfigs[chainParams.SubnetID]; ok {
-        if subnetConfig.ValidatorOnly {
-            ctx.SetValidatorOnly()
-        }
+        ctx.ValidatorOnly.Set(subnetConfig.ValidatorOnly)
     }

     // Get a factory for the vm we want to use on our chain
@@ -692,7 +688,6 @@ func (m *manager) createAvalancheChain(
         msgChan,
         sb.afterBootstrapped(),
         m.ConsensusGossipFrequency,
-        p2p.EngineType_ENGINE_TYPE_AVALANCHE,
         m.ResourceTracker,
         validators.UnhandledSubnetConnector, // avalanche chains don't use subnet connector
     )
@@ -789,7 +784,8 @@ func (m *manager) createAvalancheChain(

     return &chain{
         Name:    chainAlias,
-        Engine:  engine,
+        Context: ctx,
+        VM:      vm,
         Handler: handler,
     }, nil
 }
@@ -969,7 +965,6 @@ func (m *manager) createSnowmanChain(
         msgChan,
         sb.afterBootstrapped(),
         m.ConsensusGossipFrequency,
-        p2p.EngineType_ENGINE_TYPE_SNOWMAN,
         m.ResourceTracker,
         subnetConnector,
     )
@@ -1082,7 +1077,8 @@ func (m *manager) createSnowmanChain(

     return &chain{
         Name:    chainAlias,
-        Engine:  engine,
+        Context: ctx,
+        VM:      vm,
         Handler: handler,
     }, nil
 }
@@ -1095,7 +1091,7 @@ func (m *manager) IsBootstrapped(id ids.ID) bool {
         return false
     }

-    return chain.Context().GetState() == snow.NormalOp
+    return chain.Context().State.Get().State == snow.NormalOp
 }

 func (m *manager) subnetsNotBootstrapped() []ids.ID {
@@ -1188,9 +1184,9 @@ func (m *manager) LookupVM(alias string) (ids.ID, error) {

 // Notify registrants [those who want to know about the creation of chains]
 // that the specified chain has been created
-func (m *manager) notifyRegistrants(name string, engine common.Engine) {
+func (m *manager) notifyRegistrants(name string, ctx *snow.ConsensusContext, vm common.VM) {
     for _, registrant := range m.registrants {
-        registrant.RegisterChain(name, engine)
+        registrant.RegisterChain(name, ctx, vm)
     }
 }
diff --git a/chains/registrant.go b/chains/registrant.go
index 1c9fc0a5ba88..1d4290fe0336 100644
--- a/chains/registrant.go
+++ b/chains/registrant.go
@@ -4,13 +4,14 @@
 package chains

 import (
+    "github.com/ava-labs/avalanchego/snow"
     "github.com/ava-labs/avalanchego/snow/engine/common"
 )

 // Registrant can register the existence of a chain
 type Registrant interface {
-    // Called when the chain described by [engine] is created
+    // Called when a chain is created
     // This function is called before the chain starts processing messages
-    // [engine] should be an avalanche.Engine or snowman.Engine
-    RegisterChain(name string, engine common.Engine)
+    // [vm] should be a vertex.DAGVM or block.ChainVM
+    RegisterChain(chainName string, ctx *snow.ConsensusContext, vm common.VM)
 }
diff --git a/config/config.go b/config/config.go
index fcff610bddca..24d190d179c7 100644
--- a/config/config.go
+++ b/config/config.go
@@ -350,6 +350,9 @@ func getNetworkConfig(v *viper.Viper, stakingEnabled bool, halflife time.Duratio
             SendFailRateHalflife: halflife,
         },

+        ProxyEnabled:           v.GetBool(NetworkTCPProxyEnabledKey),
+        ProxyReadHeaderTimeout: v.GetDuration(NetworkTCPProxyReadTimeoutKey),
+
         DialerConfig: dialer.Config{
             ThrottleRps:       v.GetUint32(OutboundConnectionThrottlingRpsKey),
             ConnectionTimeout: v.GetDuration(OutboundConnectionTimeoutKey),
diff --git a/config/flags.go b/config/flags.go
index e7b726becb58..e51e29c7a15e 100644
--- a/config/flags.go
+++ b/config/flags.go
@@ -152,6 +152,14 @@ func addNodeFlags(fs *flag.FlagSet) {
     fs.Uint(NetworkPeerReadBufferSizeKey, 8*units.KiB, "Size, in bytes, of the buffer that we read peer messages into (there is one buffer per peer)")
     fs.Uint(NetworkPeerWriteBufferSizeKey, 8*units.KiB, "Size, in bytes, of the buffer that we write peer messages into (there is one buffer per peer)")
+    fs.Bool(NetworkTCPProxyEnabledKey, false, "Require all P2P connections to be initiated with a TCP proxy header")
+    // The PROXY protocol specification recommends setting this value to be at
+    // least 3 seconds to cover a TCP retransmit.
+    // Ref: https://www.haproxy.org/download/2.3/doc/proxy-protocol.txt
+    // Specifying a timeout of 0 will actually result in a timeout of 200ms, but
+    // a timeout of 0 should generally not be provided.
+    fs.Duration(NetworkTCPProxyReadTimeoutKey, 3*time.Second, "Maximum duration to wait for a TCP proxy header")
+
     fs.String(NetworkTLSKeyLogFileKey, "", "TLS key log file path. Should only be specified for debugging")

     // Benchlist
diff --git a/config/keys.go b/config/keys.go
index e74ee2dbb818..94455cfce385 100644
--- a/config/keys.go
+++ b/config/keys.go
@@ -100,6 +100,8 @@ const (
     NetworkRequireValidatorToConnectKey = "network-require-validator-to-connect"
     NetworkPeerReadBufferSizeKey        = "network-peer-read-buffer-size"
     NetworkPeerWriteBufferSizeKey       = "network-peer-write-buffer-size"
+    NetworkTCPProxyEnabledKey           = "network-tcp-proxy-enabled"
+    NetworkTCPProxyReadTimeoutKey       = "network-tcp-proxy-read-timeout"
     NetworkTLSKeyLogFileKey             = "network-tls-key-log-file-unsafe"
     BenchlistFailThresholdKey           = "benchlist-fail-threshold"
     BenchlistDurationKey                = "benchlist-duration"
diff --git a/database/leveldb/db.go b/database/leveldb/db.go
index 675130b80d57..7ea9e997616f 100644
--- a/database/leveldb/db.go
+++ b/database/leveldb/db.go
@@ -77,7 +77,7 @@ type Database struct {
     // metrics is only initialized and used when [MetricUpdateFrequency] is >= 0
     // in the config
     metrics   metrics
-    closed    utils.AtomicBool
+    closed    utils.Atomic[bool]
     closeOnce sync.Once
     // closeCh is closed when Close() is called.
     closeCh chan struct{}
@@ -351,7 +351,7 @@ func (db *Database) Compact(start []byte, limit []byte) error {
 }

 func (db *Database) Close() error {
-    db.closed.SetValue(true)
+    db.closed.Set(true)
     db.closeOnce.Do(func() {
         close(db.closeCh)
     })
@@ -360,7 +360,7 @@ func (db *Database) Close() error {
 }

 func (db *Database) HealthCheck(context.Context) (interface{}, error) {
-    if db.closed.GetValue() {
+    if db.closed.Get() {
         return nil, database.ErrClosed
     }
     return nil, nil
@@ -447,7 +447,7 @@ type iter struct {

 func (it *iter) Next() bool {
     // Short-circuit and set an error if the underlying database has been closed.
-    if it.db.closed.GetValue() {
+    if it.db.closed.Get() {
         it.key = nil
         it.val = nil
         it.err = database.ErrClosed
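The leveldb wrapper pairs `utils.Atomic[bool]` with `sync.Once` so `Close` is both cheap for readers to observe and safe to call more than once. The idiom, distilled into a standalone sketch (field names mirror the diff; everything else is illustrative):

```go
type closer struct {
	closed    utils.Atomic[bool]
	closeOnce sync.Once
	closeCh   chan struct{}
}

func (c *closer) Close() {
	// Readers (e.g. open iterators) poll this flag on every step.
	c.closed.Set(true)
	// closeOnce guarantees the channel is closed exactly once, even if
	// Close is called concurrently or repeatedly.
	c.closeOnce.Do(func() {
		close(c.closeCh)
	})
}
```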
@@ -64,7 +64,7 @@ type node struct {

 func New(db database.Database, cacheSize int) LinkedDB {
     return &linkedDB{
-        nodeCache:    &cache.LRU{Size: cacheSize},
+        nodeCache:    &cache.LRU[string, *node]{Size: cacheSize},
         updatedNodes: make(map[string]*node),
         db:           db,
         batch:        db.NewBatch(),
@@ -300,8 +300,7 @@ func (ldb *linkedDB) getNode(key []byte) (node, error) {
     defer ldb.cacheLock.Unlock()

     keyStr := string(key)
-    if nodeIntf, exists := ldb.nodeCache.Get(keyStr); exists {
-        n := nodeIntf.(*node)
+    if n, exists := ldb.nodeCache.Get(keyStr); exists {
         if n == nil {
             return node{}, database.ErrNotFound
         }
@@ -310,9 +309,7 @@ func (ldb *linkedDB) getNode(key []byte) (node, error) {

     nodeBytes, err := ldb.db.Get(nodeKey(key))
     if err == database.ErrNotFound {
-        // Passing [nil] without the pointer cast would result in a panic when
-        // performing the type assertion in the above cache check.
-        ldb.nodeCache.Put(keyStr, (*node)(nil))
+        ldb.nodeCache.Put(keyStr, nil)
         return node{}, err
     }
     if err != nil {
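linkeddb caches negative lookups ("this key is not in the database") alongside real entries. With a pointer-typed value parameter, a plain `nil` is a usable sentinel, where the `interface{}` version needed the `(*node)(nil)` cast to avoid a panicking type assertion. The pattern, as a sketch against an assumed `cache.Cacher[string, *node]`:

```go
// getCached distinguishes three cases with one typed cache lookup:
// cached hit (non-nil *node), cached miss (nil *node), and not cached.
func getCached(c cache.Cacher[string, *node], key string) (*node, error) {
	if n, ok := c.Get(key); ok {
		if n == nil {
			return nil, database.ErrNotFound // previously-cached miss
		}
		return n, nil // previously-cached hit
	}
	return nil, errNotCached // hypothetical sentinel: fall back to the DB
}

// ...and after a database miss, record it without any pointer cast:
// c.Put(key, nil)
```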
-// If no value is assigned to the corresponding member, the method returns an error or nil
-// If you
-type Database struct {
-    // Executed when Has is called
-    OnHas                           func([]byte) (bool, error)
-    OnGet                           func([]byte) ([]byte, error)
-    OnPut                           func([]byte, []byte) error
-    OnDelete                        func([]byte) error
-    OnNewBatch                      func() database.Batch
-    OnNewIterator                   func() database.Iterator
-    OnNewIteratorWithStart          func([]byte) database.Iterator
-    OnNewIteratorWithPrefix         func([]byte) database.Iterator
-    OnNewIteratorWithStartAndPrefix func([]byte, []byte) database.Iterator
-    OnCompact                       func([]byte, []byte) error
-    OnClose                         func() error
-    OnHealthCheck                   func(context.Context) (interface{}, error)
-}
-
-// New returns a new mock database
-func New() *Database {
-    return &Database{}
-}
-
-func (db *Database) Has(k []byte) (bool, error) {
-    if db.OnHas == nil {
-        return false, errNoFunction
-    }
-    return db.OnHas(k)
-}
-
-func (db *Database) Get(k []byte) ([]byte, error) {
-    if db.OnGet == nil {
-        return nil, errNoFunction
-    }
-    return db.OnGet(k)
-}
-
-func (db *Database) Put(k, v []byte) error {
-    if db.OnPut == nil {
-        return errNoFunction
-    }
-    return db.OnPut(k, v)
-}
-
-func (db *Database) Delete(k []byte) error {
-    if db.OnDelete == nil {
-        return errNoFunction
-    }
-    return db.OnDelete(k)
-}
-
-func (db *Database) NewBatch() database.Batch {
-    if db.OnNewBatch == nil {
-        return nil
-    }
-    return db.OnNewBatch()
-}
-
-func (db *Database) NewIterator() database.Iterator {
-    if db.OnNewIterator == nil {
-        return nil
-    }
-    return db.OnNewIterator()
-}
-
-func (db *Database) NewIteratorWithStart(start []byte) database.Iterator {
-    if db.OnNewIteratorWithStart == nil {
-        return nil
-    }
-    return db.OnNewIteratorWithStart(start)
-}
-
-func (db *Database) NewIteratorWithPrefix(prefix []byte) database.Iterator {
-    if db.OnNewIteratorWithPrefix == nil {
-        return nil
-    }
-    return db.OnNewIteratorWithPrefix(prefix)
-}
-
-func (db *Database) NewIteratorWithStartAndPrefix(start, prefix []byte) database.Iterator {
-    if db.OnNewIteratorWithStartAndPrefix == nil {
-        return nil
-    }
-    return db.OnNewIteratorWithStartAndPrefix(start, prefix)
-}
-
-func (db *Database) Compact(start []byte, limit []byte) error {
-    if db.OnCompact == nil {
-        return errNoFunction
-    }
-    return db.OnCompact(start, limit)
-}
-
-func (db *Database) Close() error {
-    if db.OnClose == nil {
-        return errNoFunction
-    }
-    return db.OnClose()
-}
-
-func (db *Database) HealthCheck(ctx context.Context) (interface{}, error) {
-    if db.OnHealthCheck == nil {
-        return nil, errNoFunction
-    }
-    return db.OnHealthCheck(ctx)
-}
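With the hand-rolled function-field mock deleted, tests would presumably move to generated mocks, consistent with the gomock usage elsewhere in this diff. A sketch of the equivalent expectation-based setup, assuming a gomock-generated database mock exists (`NewMockDatabase` is an illustrative name, not confirmed by this diff):

```go
// Hypothetical gomock-based replacement for the deleted mockdb test:
ctrl := gomock.NewController(t)
db := NewMockDatabase(ctrl) // assumed generated constructor

// Declare the behavior that OnHas used to provide imperatively.
db.EXPECT().Has([]byte{1, 2, 3}).Return(true, nil)

has, err := db.Has([]byte{1, 2, 3})
require.NoError(t, err)
require.True(t, has)
```

Unset expectations fail the test via the controller, replacing mockdb's `errNoFunction` default.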
diff --git a/database/mockdb/db_test.go b/database/mockdb/db_test.go
deleted file mode 100644
index 60c201b143e0..000000000000
--- a/database/mockdb/db_test.go
+++ /dev/null
@@ -1,78 +0,0 @@
-// Copyright (C) 2019-2022, Ava Labs, Inc. All rights reserved.
-// See the file LICENSE for licensing terms.
-
-package mockdb
-
-import (
-    "bytes"
-    "context"
-    "errors"
-    "testing"
-)
-
-var errTest = errors.New("non-nil error")
-
-// Assert that when no members are assigned values, every method returns nil/error
-func TestDefaultError(t *testing.T) {
-    db := New()
-
-    if err := db.Close(); err == nil {
-        t.Fatal("should have errored")
-    }
-    if _, err := db.Has([]byte{}); err == nil {
-        t.Fatal("should have errored")
-    }
-    if _, err := db.Get([]byte{}); err == nil {
-        t.Fatal("should have errored")
-    }
-    if err := db.Put([]byte{}, []byte{}); err == nil {
-        t.Fatal("should have errored")
-    }
-    if err := db.Delete([]byte{}); err == nil {
-        t.Fatal("should have errored")
-    }
-    if batch := db.NewBatch(); batch != nil {
-        t.Fatal("should have been nil")
-    }
-    if iterator := db.NewIterator(); iterator != nil {
-        t.Fatal("should have errored")
-    }
-    if iterator := db.NewIteratorWithPrefix([]byte{}); iterator != nil {
-        t.Fatal("should have errored")
-    }
-    if iterator := db.NewIteratorWithStart([]byte{}); iterator != nil {
-        t.Fatal("should have errored")
-    }
-    if iterator := db.NewIteratorWithStartAndPrefix([]byte{}, []byte{}); iterator != nil {
-        t.Fatal("should have errored")
-    }
-    if err := db.Compact([]byte{}, []byte{}); err == nil {
-        t.Fatal("should have errored")
-    }
-    if _, err := db.HealthCheck(context.Background()); err == nil {
-        t.Fatal("should have errored")
-    }
-}
-
-// Assert that mocking works for Get
-func TestGet(t *testing.T) {
-    db := New()
-
-    // Mock Has()
-    db.OnHas = func(b []byte) (bool, error) {
-        if bytes.Equal(b, []byte{1, 2, 3}) {
-            return true, nil
-        }
-        return false, errTest
-    }
-
-    if has, err := db.Has([]byte{1, 2, 3}); err != nil {
-        t.Fatal("should not have errored")
-    } else if !has {
-        t.Fatal("has should be true")
-    }
-
-    if _, err := db.Has([]byte{1, 2}); err == nil {
-        t.Fatal("should have have errored")
-    }
-}
diff --git a/database/rpcdb/db_client.go b/database/rpcdb/db_client.go
index 767e76f1ac88..1faf1eebdaf7 100644
--- a/database/rpcdb/db_client.go
+++ b/database/rpcdb/db_client.go
@@ -28,7 +28,7 @@ var (
 type DatabaseClient struct {
     client rpcdbpb.DatabaseClient

-    closed utils.AtomicBool
+    closed utils.Atomic[bool]
 }

 // NewClient returns a database instance connected to a remote database instance
@@ -127,7 +127,7 @@ func (db *DatabaseClient) Compact(start, limit []byte) error {

 // Close attempts to close the database
 func (db *DatabaseClient) Close() error {
-    db.closed.SetValue(true)
+    db.closed.Set(true)
     resp, err := db.client.Close(context.Background(), &rpcdbpb.CloseRequest{})
     if err != nil {
         return err
@@ -239,7 +239,7 @@ type iterator struct {
 // Next attempts to move the iterator to the next element and returns if this
 // succeeded
 func (it *iterator) Next() bool {
-    if it.db.closed.GetValue() {
+    if it.db.closed.Get() {
         it.data = nil
         it.errs.Add(database.ErrClosed)
         return false
diff --git a/go.mod b/go.mod
index cfe577429957..fe2e05afd05d 100644
--- a/go.mod
+++ b/go.mod
@@ -12,7 +12,7 @@ require (
     github.com/Microsoft/go-winio v0.5.2
     github.com/NYTimes/gziphandler v1.1.1
     github.com/ava-labs/avalanche-network-runner-sdk v0.3.0
-    github.com/ava-labs/coreth v0.11.6-rc.0
+    github.com/ava-labs/coreth v0.11.7-rc.0
     github.com/ava-labs/ledger-avalanche/go v0.0.0-20230105152938-00a24d05a8c7
     github.com/btcsuite/btcd/btcutil v1.1.3
    github.com/decred/dcrd/dcrec/secp256k1/v3 v3.0.0-20200627015759-01fd2de07837
@@ -33,6 +33,7 @@ require (
     github.com/nbutton23/zxcvbn-go v0.0.0-20180912185939-ae427f1e4c1d
     github.com/onsi/ginkgo/v2 v2.4.0
     github.com/onsi/gomega v1.24.0
+    github.com/pires/go-proxyproto v0.6.2
     github.com/prometheus/client_golang v1.13.0
     github.com/prometheus/client_model v0.2.0
     github.com/rs/cors v1.7.0
diff --git a/go.sum b/go.sum
index ec8254a78ee7..823944c9061a 100644
--- a/go.sum
+++ b/go.sum
@@ -57,8 +57,8 @@ github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156/go.mod h1:Cb/ax
 github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
 github.com/ava-labs/avalanche-network-runner-sdk v0.3.0 h1:TVi9JEdKNU/RevYZ9PyW4pULbEdS+KQDA9Ki2DUvuAs=
 github.com/ava-labs/avalanche-network-runner-sdk v0.3.0/go.mod h1:SgKJvtqvgo/Bl/c8fxEHCLaSxEbzimYfBopcfrajxQk=
-github.com/ava-labs/coreth v0.11.6-rc.0 h1:P8g/vqVx7nZBUHhM95oq9bcsY37P1Y7NNVb7RPe0mW8=
-github.com/ava-labs/coreth v0.11.6-rc.0/go.mod h1:xgjjJdl50zhHlWPP+3Ux5LxfvFcbSG60tGK6QUkFDhI=
+github.com/ava-labs/coreth v0.11.7-rc.0 h1:C+6vtAqBz3KrGyuSeZSwYeFTNalCKxxLdClWaFGAUIY=
+github.com/ava-labs/coreth v0.11.7-rc.0/go.mod h1:e7SuEq6g3+YWyNPiznJF6KnnAuc0HCXxiSshMNj52Sw=
 github.com/ava-labs/ledger-avalanche/go v0.0.0-20230105152938-00a24d05a8c7 h1:EdxD90j5sClfL5Ngpz2TlnbnkNYdFPDXa0jDOjam65c=
 github.com/ava-labs/ledger-avalanche/go v0.0.0-20230105152938-00a24d05a8c7/go.mod h1:XhiXSrh90sHUbkERzaxEftCmUz53eCijshDLZ4fByVM=
 github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
@@ -370,6 +370,8 @@ github.com/pelletier/go-toml v1.9.5 h1:4yBQzkHv+7BHq2PQUZF3Mx0IYxG7LsP222s7Agd3v
 github.com/pelletier/go-toml v1.9.5/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c=
 github.com/pelletier/go-toml/v2 v2.0.1 h1:8e3L2cCQzLFi2CR4g7vGFuFxX7Jl1kKX8gW+iV0GUKU=
 github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
+github.com/pires/go-proxyproto v0.6.2 h1:KAZ7UteSOt6urjme6ZldyFm4wDe/z0ZUP0Yv0Dos0d8=
+github.com/pires/go-proxyproto v0.6.2/go.mod h1:Odh9VFOZJCf9G8cLW5o435Xf1J95Jw9Gw5rnCjcwzAY=
 github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
diff --git a/indexer/indexer.go b/indexer/indexer.go
index 7a2f2f0ba53b..891f0017294e 100644
--- a/indexer/indexer.go
+++ b/indexer/indexer.go
@@ -21,9 +21,9 @@ import (
     "github.com/ava-labs/avalanchego/database/prefixdb"
     "github.com/ava-labs/avalanchego/ids"
     "github.com/ava-labs/avalanchego/snow"
-    "github.com/ava-labs/avalanchego/snow/engine/avalanche"
+    "github.com/ava-labs/avalanchego/snow/engine/avalanche/vertex"
     "github.com/ava-labs/avalanchego/snow/engine/common"
-    "github.com/ava-labs/avalanchego/snow/engine/snowman"
+    "github.com/ava-labs/avalanchego/snow/engine/snowman/block"
     "github.com/ava-labs/avalanchego/utils/constants"
     "github.com/ava-labs/avalanchego/utils/hashing"
     "github.com/ava-labs/avalanchego/utils/json"
@@ -145,22 +145,21 @@ type indexer struct {
     consensusAcceptorGroup snow.AcceptorGroup
 }

-// Assumes [engine]'s context lock is not held
-func (i *indexer) RegisterChain(name string, engine common.Engine) {
+// Assumes [ctx.Lock] is not held
+func (i *indexer) RegisterChain(chainName string, ctx *snow.ConsensusContext, vm common.VM) {
     i.lock.Lock()
     defer i.lock.Unlock()

-    ctx := engine.Context()
     if i.closed {
         i.log.Debug("not registering chain to indexer",
             zap.String("reason", "indexer is closed"),
-            zap.String("chainName", name),
+            zap.String("chainName", chainName),
         )
         return
     } else if ctx.SubnetID != constants.PrimaryNetworkID {
         i.log.Debug("not registering chain to indexer",
             zap.String("reason", "not in the primary network"),
-            zap.String("chainName", name),
+            zap.String("chainName", chainName),
         )
         return
     }
@@ -177,7 +176,7 @@ func (i *indexer) RegisterChain(name string, engine common.Engine) {
     isIncomplete, err := i.isIncomplete(chainID)
     if err != nil {
         i.log.Error("couldn't get whether chain is incomplete",
-            zap.String("chainName", name),
+            zap.String("chainName", chainName),
             zap.Error(err),
         )
         if err := i.close(); err != nil {
@@ -192,7 +191,7 @@ func (i *indexer) RegisterChain(name string, engine common.Engine) {
     previouslyIndexed, err := i.previouslyIndexed(chainID)
     if err != nil {
         i.log.Error("couldn't get whether chain was previously indexed",
-            zap.String("chainName", name),
+            zap.String("chainName", chainName),
             zap.Error(err),
         )
         if err := i.close(); err != nil {
@@ -208,7 +207,7 @@ func (i *indexer) RegisterChain(name string, engine common.Engine) {
         // We indexed this chain in a previous run but not in this run.
         // This would create an incomplete index, which is not allowed, so exit.
         i.log.Fatal("running would cause index to become incomplete but incomplete indices are disabled",
-            zap.String("chainName", name),
+            zap.String("chainName", chainName),
         )
         if err := i.close(); err != nil {
             i.log.Error("failed to close indexer",
@@ -224,7 +223,7 @@ func (i *indexer) RegisterChain(name string, engine common.Engine) {
             return
         }
         i.log.Fatal("couldn't mark chain as incomplete",
-            zap.String("chainName", name),
+            zap.String("chainName", chainName),
             zap.Error(err),
         )
         if err := i.close(); err != nil {
@@ -237,7 +236,7 @@ func (i *indexer) RegisterChain(name string, engine common.Engine) {

     if !i.allowIncompleteIndex && isIncomplete && (previouslyIndexed || i.hasRunBefore) {
         i.log.Fatal("index is incomplete but incomplete indices are disabled. Shutting down",
Shutting down", - zap.String("chainName", name), + zap.String("chainName", chainName), ) if err := i.close(); err != nil { i.log.Error("failed to close indexer", @@ -250,7 +249,7 @@ func (i *indexer) RegisterChain(name string, engine common.Engine) { // Mark that in this run, this chain was indexed if err := i.markPreviouslyIndexed(chainID); err != nil { i.log.Error("couldn't mark chain as indexed", - zap.String("chainName", name), + zap.String("chainName", chainName), zap.Error(err), ) if err := i.close(); err != nil { @@ -261,12 +260,13 @@ func (i *indexer) RegisterChain(name string, engine common.Engine) { return } - switch engine.(type) { - case snowman.Engine: - index, err := i.registerChainHelper(chainID, blockPrefix, name, "block", i.consensusAcceptorGroup) + switch vm.(type) { + case vertex.DAGVM: + vtxIndex, err := i.registerChainHelper(chainID, vtxPrefix, chainName, "vtx", i.consensusAcceptorGroup) if err != nil { - i.log.Fatal("failed to create block index", - zap.String("chainName", name), + i.log.Fatal("couldn't create index", + zap.String("chainName", chainName), + zap.String("endpoint", "vtx"), zap.Error(err), ) if err := i.close(); err != nil { @@ -276,12 +276,13 @@ func (i *indexer) RegisterChain(name string, engine common.Engine) { } return } - i.blockIndices[chainID] = index - case avalanche.Engine: - vtxIndex, err := i.registerChainHelper(chainID, vtxPrefix, name, "vtx", i.consensusAcceptorGroup) + i.vtxIndices[chainID] = vtxIndex + + txIndex, err := i.registerChainHelper(chainID, txPrefix, chainName, "tx", i.decisionAcceptorGroup) if err != nil { - i.log.Fatal("couldn't create vertex index", - zap.String("chainName", name), + i.log.Fatal("couldn't create index", + zap.String("chainName", chainName), + zap.String("endpoint", "tx"), zap.Error(err), ) if err := i.close(); err != nil { @@ -291,33 +292,33 @@ func (i *indexer) RegisterChain(name string, engine common.Engine) { } return } - i.vtxIndices[chainID] = vtxIndex - - txIndex, err := i.registerChainHelper(chainID, txPrefix, name, "tx", i.decisionAcceptorGroup) + i.txIndices[chainID] = txIndex + case block.ChainVM: + index, err := i.registerChainHelper(chainID, blockPrefix, chainName, "block", i.consensusAcceptorGroup) if err != nil { - i.log.Fatal("couldn't create tx index for", - zap.String("chainName", name), + i.log.Fatal("failed to create index", + zap.String("chainName", chainName), + zap.String("endpoint", "block"), zap.Error(err), ) if err := i.close(); err != nil { - i.log.Error("failed to close indexer:", + i.log.Error("failed to close indexer", zap.Error(err), ) } return } - i.txIndices[chainID] = txIndex + i.blockIndices[chainID] = index default: - engineType := fmt.Sprintf("%T", engine) - i.log.Error("got unexpected engine type", - zap.String("engineType", engineType), + vmType := fmt.Sprintf("%T", vm) + i.log.Error("got unexpected vm type", + zap.String("vmType", vmType), ) if err := i.close(); err != nil { i.log.Error("failed to close indexer", zap.Error(err), ) } - return } } diff --git a/indexer/indexer_test.go b/indexer/indexer_test.go index a6aa464888ec..2d8495b79250 100644 --- a/indexer/indexer_test.go +++ b/indexer/indexer_test.go @@ -19,16 +19,12 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow" "github.com/ava-labs/avalanchego/snow/choices" - "github.com/ava-labs/avalanchego/snow/consensus/avalanche" "github.com/ava-labs/avalanchego/snow/consensus/snowstorm" "github.com/ava-labs/avalanchego/snow/engine/avalanche/vertex" 
"github.com/ava-labs/avalanchego/snow/engine/common" - "github.com/ava-labs/avalanchego/snow/engine/snowman" + "github.com/ava-labs/avalanchego/snow/engine/snowman/block/mocks" "github.com/ava-labs/avalanchego/utils" "github.com/ava-labs/avalanchego/utils/logging" - - aveng "github.com/ava-labs/avalanchego/snow/engine/avalanche" - smblockmocks "github.com/ava-labs/avalanchego/snow/engine/snowman/block/mocks" ) var ( @@ -165,12 +161,8 @@ func TestIndexer(t *testing.T) { require.False(previouslyIndexed) // Register this chain, creating a new index - chainVM := smblockmocks.NewMockChainVM(ctrl) - chainEngine := snowman.NewMockEngine(ctrl) - chainEngine.EXPECT().Context().AnyTimes().Return(chain1Ctx) - chainEngine.EXPECT().GetVM().AnyTimes().Return(chainVM) - - idxr.RegisterChain("chain1", chainEngine) + chainVM := mocks.NewMockChainVM(ctrl) + idxr.RegisterChain("chain1", chain1Ctx, chainVM) isIncomplete, err = idxr.isIncomplete(chain1Ctx.ChainID) require.NoError(err) require.False(isIncomplete) @@ -255,7 +247,7 @@ func TestIndexer(t *testing.T) { require.False(isIncomplete) // Register the same chain as before - idxr.RegisterChain("chain1", chainEngine) + idxr.RegisterChain("chain1", chain1Ctx, chainVM) blkIdx = idxr.blockIndices[chain1Ctx.ChainID] require.NotNil(blkIdx) container, err = blkIdx.GetLastAccepted() @@ -272,10 +264,7 @@ func TestIndexer(t *testing.T) { require.NoError(err) require.False(previouslyIndexed) dagVM := vertex.NewMockDAGVM(ctrl) - dagEngine := aveng.NewMockEngine(ctrl) - dagEngine.EXPECT().Context().AnyTimes().Return(chain2Ctx) - dagEngine.EXPECT().GetVM().AnyTimes().Return(dagVM) - idxr.RegisterChain("chain2", dagEngine) + idxr.RegisterChain("chain2", chain2Ctx, dagVM) require.NoError(err) server = config.APIServer.(*apiServerMock) require.EqualValues(3, server.timesCalled) // block index, vtx index, tx index @@ -290,21 +279,11 @@ func TestIndexer(t *testing.T) { vtxID, vtxBytes := ids.GenerateTestID(), utils.RandomBytes(32) expectedVtx := Container{ ID: vtxID, - Bytes: blkBytes, + Bytes: vtxBytes, Timestamp: now.UnixNano(), } - // Mocked VM knows about this block now - dagEngine.EXPECT().GetVtx(gomock.Any(), vtxID).Return( - &avalanche.TestVertex{ - TestDecidable: choices.TestDecidable{ - StatusV: choices.Accepted, - IDV: vtxID, - }, - BytesV: vtxBytes, - }, nil, - ).AnyTimes() - require.NoError(config.ConsensusAcceptorGroup.Accept(chain2Ctx, vtxID, blkBytes)) + require.NoError(config.ConsensusAcceptorGroup.Accept(chain2Ctx, vtxID, vtxBytes)) vtxIdx := idxr.vtxIndices[chain2Ctx.ChainID] require.NotNil(vtxIdx) @@ -406,8 +385,8 @@ func TestIndexer(t *testing.T) { require.NoError(err) idxr, ok = idxrIntf.(*indexer) require.True(ok) - idxr.RegisterChain("chain1", chainEngine) - idxr.RegisterChain("chain2", dagEngine) + idxr.RegisterChain("chain1", chain1Ctx, chainVM) + idxr.RegisterChain("chain2", chain2Ctx, dagVM) // Verify state lastAcceptedTx, err = idxr.txIndices[chain2Ctx.ChainID].GetLastAccepted() @@ -454,9 +433,8 @@ func TestIncompleteIndex(t *testing.T) { previouslyIndexed, err := idxr.previouslyIndexed(chain1Ctx.ChainID) require.NoError(err) require.False(previouslyIndexed) - chainEngine := snowman.NewMockEngine(ctrl) - chainEngine.EXPECT().Context().AnyTimes().Return(chain1Ctx) - idxr.RegisterChain("chain1", chainEngine) + chainVM := mocks.NewMockChainVM(ctrl) + idxr.RegisterChain("chain1", chain1Ctx, chainVM) isIncomplete, err = idxr.isIncomplete(chain1Ctx.ChainID) require.NoError(err) require.True(isIncomplete) @@ -475,7 +453,7 @@ func TestIncompleteIndex(t 
     // Register the chain again. Should die due to incomplete index.
     require.NoError(config.DB.(*versiondb.Database).Commit())
-    idxr.RegisterChain("chain1", chainEngine)
+    idxr.RegisterChain("chain1", chain1Ctx, chainVM)
     require.True(idxr.closed)

     // Close and re-open the indexer, this time with indexing enabled
@@ -490,7 +468,7 @@ func TestIncompleteIndex(t *testing.T) {
     require.True(idxr.allowIncompleteIndex)

     // Register the chain again. Should be OK
-    idxr.RegisterChain("chain1", chainEngine)
+    idxr.RegisterChain("chain1", chain1Ctx, chainVM)
     require.False(idxr.closed)

     // Close the indexer and re-open with indexing disabled and
@@ -536,10 +514,7 @@ func TestIgnoreNonDefaultChains(t *testing.T) {
     chain1Ctx.SubnetID = ids.GenerateTestID()

     // RegisterChain should return without adding an index for this chain
-    chainVM := smblockmocks.NewMockChainVM(ctrl)
-    chainEngine := snowman.NewMockEngine(ctrl)
-    chainEngine.EXPECT().Context().AnyTimes().Return(chain1Ctx)
-    chainEngine.EXPECT().GetVM().AnyTimes().Return(chainVM)
-    idxr.RegisterChain("chain1", chainEngine)
+    chainVM := mocks.NewMockChainVM(ctrl)
+    idxr.RegisterChain("chain1", chain1Ctx, chainVM)
     require.Len(idxr.blockIndices, 0)
 }
diff --git a/network/config.go b/network/config.go
index d6e690eb344e..eaca1c2d7d3c 100644
--- a/network/config.go
+++ b/network/config.go
@@ -109,6 +109,9 @@ type Config struct {
     DelayConfig     `json:"delayConfig"`
     ThrottlerConfig ThrottlerConfig `json:"throttlerConfig"`

+    ProxyEnabled           bool          `json:"proxyEnabled"`
+    ProxyReadHeaderTimeout time.Duration `json:"proxyReadHeaderTimeout"`
+
     DialerConfig dialer.Config `json:"dialerConfig"`
     TLSConfig    *tls.Config   `json:"-"`
diff --git a/network/network.go b/network/network.go
index 0b4514af469b..08400a87f4b1 100644
--- a/network/network.go
+++ b/network/network.go
@@ -15,6 +15,8 @@ import (

     gomath "math"

+    "github.com/pires/go-proxyproto"
+
     "github.com/prometheus/client_golang/prometheus"

     "go.uber.org/zap"
@@ -56,6 +58,8 @@ var (
     errNotValidator          = errors.New("node is not a validator")
     errNotTracked            = errors.New("subnet is not tracked")
     errSubnetNotExist        = errors.New("subnet does not exist")
+    errExpectedProxy         = errors.New("expected proxy")
+    errExpectedTCPProtocol   = errors.New("expected TCP protocol")
 )

 // Network defines the functionality of the networking library.
@@ -189,6 +193,28 @@ func NewNetwork(
         return nil, errMissingPrimaryValidators
     }

+    if config.ProxyEnabled {
+        // Wrap the listener to process the proxy header.
+        listener = &proxyproto.Listener{
+            Listener: listener,
+            Policy: func(net.Addr) (proxyproto.Policy, error) {
+                // Do not perform any fuzzy matching, the header must be
+                // provided.
+                return proxyproto.REQUIRE, nil
+            },
+            ValidateHeader: func(h *proxyproto.Header) error {
+                if !h.Command.IsProxy() {
+                    return errExpectedProxy
+                }
+                if h.TransportProtocol != proxyproto.TCPv4 && h.TransportProtocol != proxyproto.TCPv6 {
+                    return errExpectedTCPProtocol
+                }
+                return nil
+            },
+            ReadHeaderTimeout: config.ProxyReadHeaderTimeout,
+        }
+    }
+
     inboundMsgThrottler, err := throttling.NewInboundMsgThrottler(
         log,
         config.Namespace,
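With `--network-tcp-proxy-enabled`, every inbound connection must lead with a PROXY protocol header, which lets the node recover the real client IP behind a TCP load balancer. The wrapping above is self-contained; a minimal standalone version of the same idea (validation callback omitted for brevity):

```go
package main

import (
	"net"
	"time"

	"github.com/pires/go-proxyproto"
)

// wrapWithProxyHeader returns a listener whose accepted connections
// report the address from the PROXY header as their RemoteAddr.
func wrapWithProxyHeader(l net.Listener, timeout time.Duration) net.Listener {
	return &proxyproto.Listener{
		Listener: l,
		// REQUIRE rejects connections that don't send a header at all.
		Policy: func(net.Addr) (proxyproto.Policy, error) {
			return proxyproto.REQUIRE, nil
		},
		// Bound how long reads wait for the header; the flag's default of
		// 3s follows the PROXY protocol spec's retransmit guidance.
		ReadHeaderTimeout: timeout,
	}
}
```

Note the interaction with the accept loop below: reading the header is deferred until `RemoteAddr` is first called, which is why that call moves into a per-connection goroutine.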
- // Specifically, this can occur when one of our existing - // peers attempts to connect to one our IP aliases (that they - // aren't yet aware is an alias). - remoteAddr := conn.RemoteAddr().String() - ip, err := ips.ToIPPort(remoteAddr) - if err != nil { - errs.Add(fmt.Errorf("unable to convert remote address %s to IP: %w", remoteAddr, err)) - break - } + // Note: listener.Accept is rate limited outside of this package, so a + // peer can not just arbitrarily spin up goroutines here. + go func() { + // We pessimistically drop an incoming connection if the remote + // address is found in connectedIPs, myIPs, or peerAliasIPs. This + // protects our node from spending CPU cycles on TLS handshakes to + // upgrade connections from existing peers. Specifically, this can + // occur when one of our existing peers attempts to connect to one + // our IP aliases (that they aren't yet aware is an alias). + // + // Note: Calling [RemoteAddr] with the Proxy protocol enabled may + // block for up to ProxyReadHeaderTimeout. Therefore, we ensure to + // call this function inside the go-routine, rather than the main + // accept loop. + remoteAddr := conn.RemoteAddr().String() + ip, err := ips.ToIPPort(remoteAddr) + if err != nil { + n.peerConfig.Log.Error("failed to parse remote address", + zap.String("peerIP", remoteAddr), + zap.Error(err), + ) + _ = conn.Close() + return + } - if !n.inboundConnUpgradeThrottler.ShouldUpgrade(ip) { - n.peerConfig.Log.Debug("failed to upgrade connection", - zap.String("reason", "rate-limiting"), + if !n.inboundConnUpgradeThrottler.ShouldUpgrade(ip) { + n.peerConfig.Log.Debug("failed to upgrade connection", + zap.String("reason", "rate-limiting"), + zap.Stringer("peerIP", ip), + ) + n.metrics.inboundConnRateLimited.Inc() + _ = conn.Close() + return + } + n.metrics.inboundConnAllowed.Inc() + + n.peerConfig.Log.Verbo("starting to upgrade connection", + zap.String("direction", "inbound"), zap.Stringer("peerIP", ip), ) - n.metrics.inboundConnRateLimited.Inc() - _ = conn.Close() - continue - } - n.metrics.inboundConnAllowed.Inc() - go func() { if err := n.upgrade(conn, n.serverUpgrader); err != nil { - n.peerConfig.Log.Verbo("failed to upgrade inbound connection", + n.peerConfig.Log.Verbo("failed to upgrade connection", + zap.String("direction", "inbound"), zap.Error(err), ) } @@ -1067,6 +1109,11 @@ func (n *network) dial(ctx context.Context, nodeID ids.NodeID, ip *trackedIP) { continue } + n.peerConfig.Log.Verbo("starting to upgrade connection", + zap.String("direction", "outbound"), + zap.Stringer("peerIP", ip.ip.IP), + ) + err = n.upgrade(conn, n.clientUpgrader) if err != nil { n.peerConfig.Log.Verbo( @@ -1135,9 +1182,9 @@ func (n *network) upgrade(conn net.Conn, upgrader peer.Upgrader) error { } n.peersLock.Lock() - defer n.peersLock.Unlock() - if n.closing { + n.peersLock.Unlock() + _ = tlsConn.Close() n.peerConfig.Log.Verbo( "dropping connection", @@ -1148,6 +1195,8 @@ func (n *network) upgrade(conn net.Conn, upgrader peer.Upgrader) error { } if _, connecting := n.connectingPeers.GetByID(nodeID); connecting { + n.peersLock.Unlock() + _ = tlsConn.Close() n.peerConfig.Log.Verbo( "dropping connection", @@ -1158,6 +1207,8 @@ func (n *network) upgrade(conn net.Conn, upgrader peer.Upgrader) error { } if _, connected := n.connectedPeers.GetByID(nodeID); connected { + n.peersLock.Unlock() + _ = tlsConn.Close() n.peerConfig.Log.Verbo( "dropping connection", @@ -1194,6 +1245,7 @@ func (n *network) upgrade(conn net.Conn, upgrader peer.Upgrader) error { ), ) 
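NOTE (reviewer aside, not part of the patch): the proxyproto wiring above is the heart of the new ProxyEnabled path. Below is a minimal, self-contained sketch of the same wrapping using github.com/pires/go-proxyproto, runnable outside avalanchego; the listen address and the 3-second timeout are illustrative values, not defaults taken from this PR.

// proxy_listener_sketch.go
package main

import (
	"errors"
	"fmt"
	"net"
	"time"

	"github.com/pires/go-proxyproto"
)

func main() {
	base, err := net.Listen("tcp", "127.0.0.1:0") // illustrative address
	if err != nil {
		panic(err)
	}

	ln := &proxyproto.Listener{
		Listener: base,
		// REQUIRE rejects connections that do not begin with a PROXY
		// protocol header, matching the Policy used in the diff.
		Policy: func(net.Addr) (proxyproto.Policy, error) {
			return proxyproto.REQUIRE, nil
		},
		// Accept only genuine PROXY commands over TCP, as above.
		ValidateHeader: func(h *proxyproto.Header) error {
			if !h.Command.IsProxy() {
				return errors.New("expected proxy")
			}
			if h.TransportProtocol != proxyproto.TCPv4 && h.TransportProtocol != proxyproto.TCPv6 {
				return errors.New("expected TCP protocol")
			}
			return nil
		},
		// Bounds how long reading the header may take.
		ReadHeaderTimeout: 3 * time.Second,
	}
	defer ln.Close()

	conn, err := ln.Accept()
	if err != nil {
		panic(err)
	}
	// Reports the client address carried in the header, not the proxy's.
	fmt.Println("remote:", conn.RemoteAddr())
	_ = conn.Close()
}

The design point mirrored here: the header is read lazily, so the first RemoteAddr() call on an accepted connection can block for up to the timeout, which is exactly why Dispatch above resolves the address inside the per-connection goroutine instead of the accept loop.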
n.connectingPeers.Add(peer) + n.peersLock.Unlock() return nil } diff --git a/network/peer/peer.go b/network/peer/peer.go index 49518321c69c..e3187e1e3e9c 100644 --- a/network/peer/peer.go +++ b/network/peer/peer.go @@ -137,14 +137,14 @@ type peer struct { // True if this peer has sent us a valid Version message and // is running a compatible version. // Only modified on the connection's reader routine. - gotVersion utils.AtomicBool + gotVersion utils.Atomic[bool] // True if the peer: // * Has sent us a Version message // * Has sent us a PeerList message // * Is running a compatible version // Only modified on the connection's reader routine. - finishedHandshake utils.AtomicBool + finishedHandshake utils.Atomic[bool] // onFinishHandshake is closed when the peer finishes the p2p handshake. onFinishHandshake chan struct{} @@ -226,7 +226,7 @@ func (p *peer) LastReceived() time.Time { } func (p *peer) Ready() bool { - return p.finishedHandshake.GetValue() + return p.finishedHandshake.Get() } func (p *peer) AwaitReady(ctx context.Context) error { @@ -638,7 +638,7 @@ func (p *peer) sendNetworkMessages() { return } - if p.finishedHandshake.GetValue() { + if p.finishedHandshake.Get() { if err := p.VersionCompatibility.Compatible(p.version); err != nil { p.Log.Debug("disconnecting from peer", zap.String("reason", "version not compatible"), @@ -689,7 +689,7 @@ func (p *peer) handle(msg message.InboundMessage) { msg.OnFinishedHandling() return } - if !p.finishedHandshake.GetValue() { + if !p.finishedHandshake.Get() { p.Log.Debug( "dropping message", zap.String("reason", "handshake isn't finished"), @@ -794,7 +794,7 @@ func (p *peer) observeUptime(subnetID ids.ID, uptime uint32) { } func (p *peer) handleVersion(msg *p2p.Version) { - if p.gotVersion.GetValue() { + if p.gotVersion.Get() { // TODO: this should never happen, should we close the connection here? p.Log.Verbo("dropping duplicated version message", zap.Stringer("nodeID", p.id), @@ -926,7 +926,7 @@ func (p *peer) handleVersion(msg *p2p.Version) { return } - p.gotVersion.SetValue(true) + p.gotVersion.Set(true) peerIPs, err := p.Network.Peers(p.id) if err != nil { @@ -957,13 +957,13 @@ func (p *peer) handleVersion(msg *p2p.Version) { } func (p *peer) handlePeerList(msg *p2p.PeerList) { - if !p.finishedHandshake.GetValue() { - if !p.gotVersion.GetValue() { + if !p.finishedHandshake.Get() { + if !p.gotVersion.Get() { return } p.Network.Connected(p.id) - p.finishedHandshake.SetValue(true) + p.finishedHandshake.Set(true) close(p.onFinishHandshake) } diff --git a/node/node.go b/node/node.go index 7aa9d20252f0..44ddd84457a8 100644 --- a/node/node.go +++ b/node/node.go @@ -163,10 +163,10 @@ type Node struct { shutdownOnce sync.Once // True if node is shutting down or is done shutting down - shuttingDown utils.AtomicBool + shuttingDown utils.Atomic[bool] // Sets the exit code - shuttingDownExitCode utils.AtomicInterface + shuttingDownExitCode utils.Atomic[int] // Incremented only once on initialization. // Decremented when node is done shutting down. @@ -285,7 +285,7 @@ func (n *Node) initNetworking(primaryNetVdrs validators.Set) error { // shutdown. timer := timer.NewTimer(func() { // If the timeout fires and we're already shutting down, nothing to do. 
- if !n.shuttingDown.GetValue() { + if !n.shuttingDown.Get() { n.Log.Warn("failed to connect to bootstrap nodes", zap.Stringer("beacons", n.beacons), zap.Duration("duration", n.Config.BootstrapBeaconConnectionTimeout), @@ -362,7 +362,7 @@ func (n *Node) Dispatch() error { // When [n].Shutdown() is called, [n.APIServer].Close() is called. // This causes [n.APIServer].Dispatch() to return an error. // If that happened, don't log/return an error here. - if !n.shuttingDown.GetValue() { + if !n.shuttingDown.Get() { n.Log.Fatal("API server dispatch failed", zap.Error(err), ) @@ -1338,10 +1338,10 @@ func (n *Node) Initialize( // Shutdown this node // May be called multiple times func (n *Node) Shutdown(exitCode int) { - if !n.shuttingDown.GetValue() { // only set the exit code once - n.shuttingDownExitCode.SetValue(exitCode) + if !n.shuttingDown.Get() { // only set the exit code once + n.shuttingDownExitCode.Set(exitCode) } - n.shuttingDown.SetValue(true) + n.shuttingDown.Set(true) n.shutdownOnce.Do(n.shutdown) } @@ -1425,8 +1425,5 @@ func (n *Node) shutdown() { } func (n *Node) ExitCode() int { - if exitCode, ok := n.shuttingDownExitCode.GetValue().(int); ok { - return exitCode - } - return 0 + return n.shuttingDownExitCode.Get() } diff --git a/scripts/mocks.mockgen.txt b/scripts/mocks.mockgen.txt index 2e4b4d116979..ee79dcebe320 100644 --- a/scripts/mocks.mockgen.txt +++ b/scripts/mocks.mockgen.txt @@ -6,12 +6,10 @@ github.com/ava-labs/avalanchego/message=OutboundMsgBuilder=message/mock_outbound github.com/ava-labs/avalanchego/network/peer=GossipTracker=network/peer/mock_gossip_tracker.go github.com/ava-labs/avalanchego/snow/consensus/snowman=Block=snow/consensus/snowman/mock_block.go github.com/ava-labs/avalanchego/snow/engine/avalanche/vertex=DAGVM=snow/engine/avalanche/vertex/mock_vm.go -github.com/ava-labs/avalanchego/snow/engine/avalanche=Engine=snow/engine/avalanche/mock_engine.go github.com/ava-labs/avalanchego/snow/engine/snowman/block=BuildBlockWithContextChainVM=snow/engine/snowman/block/mocks/build_block_with_context_vm.go github.com/ava-labs/avalanchego/snow/engine/snowman/block=ChainVM=snow/engine/snowman/block/mocks/chain_vm.go github.com/ava-labs/avalanchego/snow/engine/snowman/block=StateSyncableVM=snow/engine/snowman/block/mocks/state_syncable_vm.go github.com/ava-labs/avalanchego/snow/engine/snowman/block=WithVerifyContext=snow/engine/snowman/block/mocks/with_verify_context.go -github.com/ava-labs/avalanchego/snow/engine/snowman=Engine=snow/engine/snowman/mock_engine.go github.com/ava-labs/avalanchego/snow/networking/handler=Handler=snow/networking/handler/mock_handler.go github.com/ava-labs/avalanchego/snow/networking/timeout=Manager=snow/networking/timeout/mock_manager.go github.com/ava-labs/avalanchego/snow/networking/tracker=Targeter=snow/networking/tracker/mock_targeter.go diff --git a/snow/context.go b/snow/context.go index fc0f83d10d2f..5c57c0d865c2 100644 --- a/snow/context.go +++ b/snow/context.go @@ -75,55 +75,17 @@ type ConsensusContext struct { // accepted. ConsensusAcceptor Acceptor - // Non-zero iff this chain bootstrapped. - state utils.AtomicInterface + // State indicates the current state of this consensus instance. + State utils.Atomic[EngineState] // True iff this chain is executing transactions as part of bootstrapping. - executing utils.AtomicBool + Executing utils.Atomic[bool] // True iff this chain is currently state-syncing - stateSyncing utils.AtomicBool + StateSyncing utils.Atomic[bool] // Indicates this chain is available to only validators. 
-	validatorOnly utils.AtomicBool
-}
-
-func (ctx *ConsensusContext) SetState(newState State) {
-	ctx.state.SetValue(newState)
-}
-
-func (ctx *ConsensusContext) GetState() State {
-	stateInf := ctx.state.GetValue()
-	return stateInf.(State)
-}
-
-// IsExecuting returns true iff this chain is still executing transactions.
-func (ctx *ConsensusContext) IsExecuting() bool {
-	return ctx.executing.GetValue()
-}
-
-// Executing marks this chain as executing or not.
-// Set to "true" if there's an ongoing transaction.
-func (ctx *ConsensusContext) Executing(b bool) {
-	ctx.executing.SetValue(b)
-}
-
-func (ctx *ConsensusContext) IsRunningStateSync() bool {
-	return ctx.stateSyncing.GetValue()
-}
-
-func (ctx *ConsensusContext) RunningStateSync(b bool) {
-	ctx.stateSyncing.SetValue(b)
-}
-
-// IsValidatorOnly returns true iff this chain is available only to validators
-func (ctx *ConsensusContext) IsValidatorOnly() bool {
-	return ctx.validatorOnly.GetValue()
-}
-
-// SetValidatorOnly marks this chain as available only to validators
-func (ctx *ConsensusContext) SetValidatorOnly() {
-	ctx.validatorOnly.SetValue(true)
+	ValidatorOnly utils.Atomic[bool]
 }

 func DefaultContextTest() *Context {
diff --git a/snow/engine/avalanche/bootstrap/bootstrapper.go b/snow/engine/avalanche/bootstrap/bootstrapper.go
index f16aa2e0f74c..b55efe6792a8 100644
--- a/snow/engine/avalanche/bootstrap/bootstrapper.go
+++ b/snow/engine/avalanche/bootstrap/bootstrapper.go
@@ -14,6 +14,7 @@ import (
 	"github.com/ava-labs/avalanchego/cache"
 	"github.com/ava-labs/avalanchego/ids"
+	"github.com/ava-labs/avalanchego/proto/pb/p2p"
 	"github.com/ava-labs/avalanchego/snow"
 	"github.com/ava-labs/avalanchego/snow/choices"
 	"github.com/ava-labs/avalanchego/snow/consensus/avalanche"
@@ -52,7 +53,7 @@ func New(ctx context.Context, config Config, onFinished func(ctx context.Context
 		ChitsHandler: common.NewNoOpChitsHandler(config.Ctx.Log),
 		AppHandler:   common.NewNoOpAppHandler(config.Ctx.Log),
-		processedCache: &cache.LRU{Size: cacheSize},
+		processedCache: &cache.LRU[ids.ID, struct{}]{Size: cacheSize},
 		Fetcher: common.Fetcher{OnFinished: onFinished},
 		executedStateTransitions: math.MaxInt32,
 	}
@@ -106,7 +107,7 @@ type bootstrapper struct {
 	needToFetch set.Set[ids.ID]

 	// Contains IDs of vertices that have recently been processed
-	processedCache *cache.LRU
+	processedCache *cache.LRU[ids.ID, struct{}]
 	// number of state transitions executed
 	executedStateTransitions int
@@ -322,7 +323,10 @@ func (*bootstrapper) Notify(context.Context, common.Message) error {
 func (b *bootstrapper) Start(ctx context.Context, startReqID uint32) error {
 	b.Ctx.Log.Info("starting bootstrap")

-	b.Ctx.SetState(snow.Bootstrapping)
+	b.Ctx.State.Set(snow.EngineState{
+		Type:  p2p.EngineType_ENGINE_TYPE_AVALANCHE,
+		State: snow.Bootstrapping,
+	})

 	if err := b.VM.SetState(ctx, snow.Bootstrapping); err != nil {
 		return fmt.Errorf("failed to notify VM that bootstrapping has started: %w", err)
@@ -488,7 +492,7 @@ func (b *bootstrapper) process(ctx context.Context, vtxs ...avalanche.Vertex) er
 			return err
 		}
 		if height%stripeDistance < stripeWidth { // See comment for stripeDistance
-			b.processedCache.Put(vtxID, nil)
+			b.processedCache.Put(vtxID, struct{}{})
 		}
 		if height == prevHeight {
 			vtxHeightSet.Add(vtxID)
diff --git a/snow/engine/avalanche/bootstrap/bootstrapper_test.go b/snow/engine/avalanche/bootstrap/bootstrapper_test.go
index 34919276164f..6ff105d7c444 100644
--- a/snow/engine/avalanche/bootstrap/bootstrapper_test.go
+++ b/snow/engine/avalanche/bootstrap/bootstrapper_test.go
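NOTE (reviewer aside, not part of the patch): before the test hunks below, a compilable sketch of the two types this refactor leans on, inferred from their call sites in the diff (Get/Set on utils.Atomic[T]; State.Get().State on the consensus context). The real definitions live in the utils and snow packages and may differ in detail; the int fields are placeholders for p2p.EngineType and snow.State.

package main

import (
	"fmt"
	"sync"
)

// Atomic is a mutex-guarded value with the Get/Set API used above;
// a stand-in for utils.Atomic[T].
type Atomic[T any] struct {
	lock  sync.RWMutex
	value T
}

func (a *Atomic[T]) Get() T {
	a.lock.RLock()
	defer a.lock.RUnlock()
	return a.value
}

func (a *Atomic[T]) Set(value T) {
	a.lock.Lock()
	defer a.lock.Unlock()
	a.value = value
}

// EngineState pairs an engine type with a lifecycle state, which is why
// call sites now read ctx.State.Get().State rather than the old GetState().
type EngineState struct {
	Type  int // placeholder for p2p.EngineType
	State int // placeholder for snow.State
}

func main() {
	var state Atomic[EngineState]
	state.Set(EngineState{Type: 1, State: 2})
	fmt.Println(state.Get().State) // prints 2
}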
@@ -12,6 +12,7 @@ import ( "github.com/ava-labs/avalanchego/database/memdb" "github.com/ava-labs/avalanchego/database/prefixdb" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/proto/pb/p2p" "github.com/ava-labs/avalanchego/snow" "github.com/ava-labs/avalanchego/snow/choices" "github.com/ava-labs/avalanchego/snow/consensus/avalanche" @@ -146,7 +147,10 @@ func TestBootstrapperSingleFrontier(t *testing.T) { context.Background(), config, func(context.Context, uint32) error { - config.Ctx.SetState(snow.NormalOp) + config.Ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_AVALANCHE, + State: snow.NormalOp, + }) return nil }, ) @@ -193,7 +197,7 @@ func TestBootstrapperSingleFrontier(t *testing.T) { } switch { - case config.Ctx.GetState() != snow.NormalOp: + case config.Ctx.State.Get().State != snow.NormalOp: t.Fatalf("Bootstrapping should have finished") case vtx0.Status() != choices.Accepted: t.Fatalf("Vertex should be accepted") @@ -249,7 +253,10 @@ func TestBootstrapperByzantineResponses(t *testing.T) { context.Background(), config, func(context.Context, uint32) error { - config.Ctx.SetState(snow.NormalOp) + config.Ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_AVALANCHE, + State: snow.NormalOp, + }) return nil }, ) @@ -343,7 +350,7 @@ func TestBootstrapperByzantineResponses(t *testing.T) { switch { case *requestID != oldReqID: t.Fatal("should not have issued new request") - case config.Ctx.GetState() != snow.NormalOp: + case config.Ctx.State.Get().State != snow.NormalOp: t.Fatalf("Bootstrapping should have finished") case vtx0.Status() != choices.Accepted: t.Fatalf("Vertex should be accepted") @@ -427,7 +434,10 @@ func TestBootstrapperTxDependencies(t *testing.T) { context.Background(), config, func(context.Context, uint32) error { - config.Ctx.SetState(snow.NormalOp) + config.Ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_AVALANCHE, + State: snow.NormalOp, + }) return nil }, ) @@ -498,7 +508,7 @@ func TestBootstrapperTxDependencies(t *testing.T) { t.Fatal(err) } - if config.Ctx.GetState() != snow.NormalOp { + if config.Ctx.State.Get().State != snow.NormalOp { t.Fatalf("Should have finished bootstrapping") } if tx0.Status() != choices.Accepted { @@ -571,7 +581,10 @@ func TestBootstrapperMissingTxDependency(t *testing.T) { context.Background(), config, func(context.Context, uint32) error { - config.Ctx.SetState(snow.NormalOp) + config.Ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_AVALANCHE, + State: snow.NormalOp, + }) return nil }, ) @@ -631,7 +644,7 @@ func TestBootstrapperMissingTxDependency(t *testing.T) { t.Fatal(err) } - if config.Ctx.GetState() != snow.NormalOp { + if config.Ctx.State.Get().State != snow.NormalOp { t.Fatalf("Bootstrapping should have finished") } if tx0.Status() != choices.Unknown { // never saw this tx @@ -692,7 +705,10 @@ func TestBootstrapperIncompleteAncestors(t *testing.T) { context.Background(), config, func(context.Context, uint32) error { - config.Ctx.SetState(snow.NormalOp) + config.Ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_AVALANCHE, + State: snow.NormalOp, + }) return nil }, ) @@ -760,7 +776,7 @@ func TestBootstrapperIncompleteAncestors(t *testing.T) { switch { case err != nil: // Provide vtx1; should request vtx0 t.Fatal(err) - case bs.Context().GetState() == snow.NormalOp: + case bs.Context().State.Get().State == snow.NormalOp: t.Fatalf("should not have finished") case requested != vtxID0: t.Fatal("should hae requested vtx0") @@ 
-770,7 +786,7 @@ func TestBootstrapperIncompleteAncestors(t *testing.T) { switch { case err != nil: // Provide vtx0; can finish now t.Fatal(err) - case bs.Context().GetState() != snow.NormalOp: + case bs.Context().State.Get().State != snow.NormalOp: t.Fatal("should have finished") case vtx0.Status() != choices.Accepted: t.Fatal("should be accepted") @@ -812,7 +828,10 @@ func TestBootstrapperFinalized(t *testing.T) { context.Background(), config, func(context.Context, uint32) error { - config.Ctx.SetState(snow.NormalOp) + config.Ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_AVALANCHE, + State: snow.NormalOp, + }) return nil }, ) @@ -891,7 +910,7 @@ func TestBootstrapperFinalized(t *testing.T) { switch { case err != nil: t.Fatal(err) - case config.Ctx.GetState() != snow.NormalOp: + case config.Ctx.State.Get().State != snow.NormalOp: t.Fatalf("Bootstrapping should have finished") case vtx0.Status() != choices.Accepted: t.Fatalf("Vertex should be accepted") @@ -943,7 +962,10 @@ func TestBootstrapperAcceptsAncestorsParents(t *testing.T) { context.Background(), config, func(context.Context, uint32) error { - config.Ctx.SetState(snow.NormalOp) + config.Ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_AVALANCHE, + State: snow.NormalOp, + }) return nil }, ) @@ -1024,7 +1046,7 @@ func TestBootstrapperAcceptsAncestorsParents(t *testing.T) { } switch { - case config.Ctx.GetState() != snow.NormalOp: + case config.Ctx.State.Get().State != snow.NormalOp: t.Fatalf("Bootstrapping should have finished") case vtx0.Status() != choices.Accepted: t.Fatalf("Vertex should be accepted") @@ -1110,7 +1132,10 @@ func TestRestartBootstrapping(t *testing.T) { context.Background(), config, func(context.Context, uint32) error { - config.Ctx.SetState(snow.NormalOp) + config.Ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_AVALANCHE, + State: snow.NormalOp, + }) return nil }, ) @@ -1272,7 +1297,7 @@ func TestRestartBootstrapping(t *testing.T) { } switch { - case config.Ctx.GetState() != snow.NormalOp: + case config.Ctx.State.Get().State != snow.NormalOp: t.Fatalf("Bootstrapping should have finished") case vtx0.Status() != choices.Accepted: t.Fatalf("Vertex should be accepted") diff --git a/snow/engine/avalanche/mock_engine.go b/snow/engine/avalanche/mock_engine.go deleted file mode 100644 index 2f6d6695df89..000000000000 --- a/snow/engine/avalanche/mock_engine.go +++ /dev/null @@ -1,618 +0,0 @@ -// Copyright (C) 2019-2022, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/ava-labs/avalanchego/snow/engine/avalanche (interfaces: Engine) - -// Package avalanche is a generated GoMock package. -package avalanche - -import ( - context "context" - reflect "reflect" - time "time" - - ids "github.com/ava-labs/avalanchego/ids" - snow "github.com/ava-labs/avalanchego/snow" - avalanche "github.com/ava-labs/avalanchego/snow/consensus/avalanche" - common "github.com/ava-labs/avalanchego/snow/engine/common" - version "github.com/ava-labs/avalanchego/version" - gomock "github.com/golang/mock/gomock" -) - -// MockEngine is a mock of Engine interface. -type MockEngine struct { - ctrl *gomock.Controller - recorder *MockEngineMockRecorder -} - -// MockEngineMockRecorder is the mock recorder for MockEngine. -type MockEngineMockRecorder struct { - mock *MockEngine -} - -// NewMockEngine creates a new mock instance. 
-func NewMockEngine(ctrl *gomock.Controller) *MockEngine { - mock := &MockEngine{ctrl: ctrl} - mock.recorder = &MockEngineMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockEngine) EXPECT() *MockEngineMockRecorder { - return m.recorder -} - -// Accepted mocks base method. -func (m *MockEngine) Accepted(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 []ids.ID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Accepted", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// Accepted indicates an expected call of Accepted. -func (mr *MockEngineMockRecorder) Accepted(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Accepted", reflect.TypeOf((*MockEngine)(nil).Accepted), arg0, arg1, arg2, arg3) -} - -// AcceptedFrontier mocks base method. -func (m *MockEngine) AcceptedFrontier(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 []ids.ID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AcceptedFrontier", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// AcceptedFrontier indicates an expected call of AcceptedFrontier. -func (mr *MockEngineMockRecorder) AcceptedFrontier(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptedFrontier", reflect.TypeOf((*MockEngine)(nil).AcceptedFrontier), arg0, arg1, arg2, arg3) -} - -// AcceptedStateSummary mocks base method. -func (m *MockEngine) AcceptedStateSummary(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 []ids.ID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AcceptedStateSummary", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// AcceptedStateSummary indicates an expected call of AcceptedStateSummary. -func (mr *MockEngineMockRecorder) AcceptedStateSummary(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptedStateSummary", reflect.TypeOf((*MockEngine)(nil).AcceptedStateSummary), arg0, arg1, arg2, arg3) -} - -// Ancestors mocks base method. -func (m *MockEngine) Ancestors(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 [][]byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Ancestors", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// Ancestors indicates an expected call of Ancestors. -func (mr *MockEngineMockRecorder) Ancestors(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ancestors", reflect.TypeOf((*MockEngine)(nil).Ancestors), arg0, arg1, arg2, arg3) -} - -// AppGossip mocks base method. -func (m *MockEngine) AppGossip(arg0 context.Context, arg1 ids.NodeID, arg2 []byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AppGossip", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// AppGossip indicates an expected call of AppGossip. -func (mr *MockEngineMockRecorder) AppGossip(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppGossip", reflect.TypeOf((*MockEngine)(nil).AppGossip), arg0, arg1, arg2) -} - -// AppRequest mocks base method. 
-func (m *MockEngine) AppRequest(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 time.Time, arg4 []byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AppRequest", arg0, arg1, arg2, arg3, arg4) - ret0, _ := ret[0].(error) - return ret0 -} - -// AppRequest indicates an expected call of AppRequest. -func (mr *MockEngineMockRecorder) AppRequest(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppRequest", reflect.TypeOf((*MockEngine)(nil).AppRequest), arg0, arg1, arg2, arg3, arg4) -} - -// AppRequestFailed mocks base method. -func (m *MockEngine) AppRequestFailed(arg0 context.Context, arg1 ids.NodeID, arg2 uint32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AppRequestFailed", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// AppRequestFailed indicates an expected call of AppRequestFailed. -func (mr *MockEngineMockRecorder) AppRequestFailed(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppRequestFailed", reflect.TypeOf((*MockEngine)(nil).AppRequestFailed), arg0, arg1, arg2) -} - -// AppResponse mocks base method. -func (m *MockEngine) AppResponse(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 []byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AppResponse", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// AppResponse indicates an expected call of AppResponse. -func (mr *MockEngineMockRecorder) AppResponse(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppResponse", reflect.TypeOf((*MockEngine)(nil).AppResponse), arg0, arg1, arg2, arg3) -} - -// Chits mocks base method. -func (m *MockEngine) Chits(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3, arg4 []ids.ID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Chits", arg0, arg1, arg2, arg3, arg4) - ret0, _ := ret[0].(error) - return ret0 -} - -// Chits indicates an expected call of Chits. -func (mr *MockEngineMockRecorder) Chits(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Chits", reflect.TypeOf((*MockEngine)(nil).Chits), arg0, arg1, arg2, arg3, arg4) -} - -// Connected mocks base method. -func (m *MockEngine) Connected(arg0 context.Context, arg1 ids.NodeID, arg2 *version.Application) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Connected", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// Connected indicates an expected call of Connected. -func (mr *MockEngineMockRecorder) Connected(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connected", reflect.TypeOf((*MockEngine)(nil).Connected), arg0, arg1, arg2) -} - -// Context mocks base method. -func (m *MockEngine) Context() *snow.ConsensusContext { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Context") - ret0, _ := ret[0].(*snow.ConsensusContext) - return ret0 -} - -// Context indicates an expected call of Context. -func (mr *MockEngineMockRecorder) Context() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockEngine)(nil).Context)) -} - -// CrossChainAppRequest mocks base method. 
-func (m *MockEngine) CrossChainAppRequest(arg0 context.Context, arg1 ids.ID, arg2 uint32, arg3 time.Time, arg4 []byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CrossChainAppRequest", arg0, arg1, arg2, arg3, arg4) - ret0, _ := ret[0].(error) - return ret0 -} - -// CrossChainAppRequest indicates an expected call of CrossChainAppRequest. -func (mr *MockEngineMockRecorder) CrossChainAppRequest(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CrossChainAppRequest", reflect.TypeOf((*MockEngine)(nil).CrossChainAppRequest), arg0, arg1, arg2, arg3, arg4) -} - -// CrossChainAppRequestFailed mocks base method. -func (m *MockEngine) CrossChainAppRequestFailed(arg0 context.Context, arg1 ids.ID, arg2 uint32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CrossChainAppRequestFailed", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// CrossChainAppRequestFailed indicates an expected call of CrossChainAppRequestFailed. -func (mr *MockEngineMockRecorder) CrossChainAppRequestFailed(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CrossChainAppRequestFailed", reflect.TypeOf((*MockEngine)(nil).CrossChainAppRequestFailed), arg0, arg1, arg2) -} - -// CrossChainAppResponse mocks base method. -func (m *MockEngine) CrossChainAppResponse(arg0 context.Context, arg1 ids.ID, arg2 uint32, arg3 []byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CrossChainAppResponse", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// CrossChainAppResponse indicates an expected call of CrossChainAppResponse. -func (mr *MockEngineMockRecorder) CrossChainAppResponse(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CrossChainAppResponse", reflect.TypeOf((*MockEngine)(nil).CrossChainAppResponse), arg0, arg1, arg2, arg3) -} - -// Disconnected mocks base method. -func (m *MockEngine) Disconnected(arg0 context.Context, arg1 ids.NodeID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Disconnected", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// Disconnected indicates an expected call of Disconnected. -func (mr *MockEngineMockRecorder) Disconnected(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnected", reflect.TypeOf((*MockEngine)(nil).Disconnected), arg0, arg1) -} - -// Get mocks base method. -func (m *MockEngine) Get(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 ids.ID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// Get indicates an expected call of Get. -func (mr *MockEngineMockRecorder) Get(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockEngine)(nil).Get), arg0, arg1, arg2, arg3) -} - -// GetAccepted mocks base method. -func (m *MockEngine) GetAccepted(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 []ids.ID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAccepted", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// GetAccepted indicates an expected call of GetAccepted. 
-func (mr *MockEngineMockRecorder) GetAccepted(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccepted", reflect.TypeOf((*MockEngine)(nil).GetAccepted), arg0, arg1, arg2, arg3) -} - -// GetAcceptedFailed mocks base method. -func (m *MockEngine) GetAcceptedFailed(arg0 context.Context, arg1 ids.NodeID, arg2 uint32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAcceptedFailed", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// GetAcceptedFailed indicates an expected call of GetAcceptedFailed. -func (mr *MockEngineMockRecorder) GetAcceptedFailed(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAcceptedFailed", reflect.TypeOf((*MockEngine)(nil).GetAcceptedFailed), arg0, arg1, arg2) -} - -// GetAcceptedFrontier mocks base method. -func (m *MockEngine) GetAcceptedFrontier(arg0 context.Context, arg1 ids.NodeID, arg2 uint32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAcceptedFrontier", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// GetAcceptedFrontier indicates an expected call of GetAcceptedFrontier. -func (mr *MockEngineMockRecorder) GetAcceptedFrontier(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAcceptedFrontier", reflect.TypeOf((*MockEngine)(nil).GetAcceptedFrontier), arg0, arg1, arg2) -} - -// GetAcceptedFrontierFailed mocks base method. -func (m *MockEngine) GetAcceptedFrontierFailed(arg0 context.Context, arg1 ids.NodeID, arg2 uint32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAcceptedFrontierFailed", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// GetAcceptedFrontierFailed indicates an expected call of GetAcceptedFrontierFailed. -func (mr *MockEngineMockRecorder) GetAcceptedFrontierFailed(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAcceptedFrontierFailed", reflect.TypeOf((*MockEngine)(nil).GetAcceptedFrontierFailed), arg0, arg1, arg2) -} - -// GetAcceptedStateSummary mocks base method. -func (m *MockEngine) GetAcceptedStateSummary(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 []uint64) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAcceptedStateSummary", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// GetAcceptedStateSummary indicates an expected call of GetAcceptedStateSummary. -func (mr *MockEngineMockRecorder) GetAcceptedStateSummary(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAcceptedStateSummary", reflect.TypeOf((*MockEngine)(nil).GetAcceptedStateSummary), arg0, arg1, arg2, arg3) -} - -// GetAcceptedStateSummaryFailed mocks base method. -func (m *MockEngine) GetAcceptedStateSummaryFailed(arg0 context.Context, arg1 ids.NodeID, arg2 uint32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAcceptedStateSummaryFailed", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// GetAcceptedStateSummaryFailed indicates an expected call of GetAcceptedStateSummaryFailed. 
-func (mr *MockEngineMockRecorder) GetAcceptedStateSummaryFailed(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAcceptedStateSummaryFailed", reflect.TypeOf((*MockEngine)(nil).GetAcceptedStateSummaryFailed), arg0, arg1, arg2) -} - -// GetAncestors mocks base method. -func (m *MockEngine) GetAncestors(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 ids.ID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAncestors", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// GetAncestors indicates an expected call of GetAncestors. -func (mr *MockEngineMockRecorder) GetAncestors(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAncestors", reflect.TypeOf((*MockEngine)(nil).GetAncestors), arg0, arg1, arg2, arg3) -} - -// GetAncestorsFailed mocks base method. -func (m *MockEngine) GetAncestorsFailed(arg0 context.Context, arg1 ids.NodeID, arg2 uint32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAncestorsFailed", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// GetAncestorsFailed indicates an expected call of GetAncestorsFailed. -func (mr *MockEngineMockRecorder) GetAncestorsFailed(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAncestorsFailed", reflect.TypeOf((*MockEngine)(nil).GetAncestorsFailed), arg0, arg1, arg2) -} - -// GetFailed mocks base method. -func (m *MockEngine) GetFailed(arg0 context.Context, arg1 ids.NodeID, arg2 uint32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetFailed", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// GetFailed indicates an expected call of GetFailed. -func (mr *MockEngineMockRecorder) GetFailed(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFailed", reflect.TypeOf((*MockEngine)(nil).GetFailed), arg0, arg1, arg2) -} - -// GetStateSummaryFrontier mocks base method. -func (m *MockEngine) GetStateSummaryFrontier(arg0 context.Context, arg1 ids.NodeID, arg2 uint32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetStateSummaryFrontier", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// GetStateSummaryFrontier indicates an expected call of GetStateSummaryFrontier. -func (mr *MockEngineMockRecorder) GetStateSummaryFrontier(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStateSummaryFrontier", reflect.TypeOf((*MockEngine)(nil).GetStateSummaryFrontier), arg0, arg1, arg2) -} - -// GetStateSummaryFrontierFailed mocks base method. -func (m *MockEngine) GetStateSummaryFrontierFailed(arg0 context.Context, arg1 ids.NodeID, arg2 uint32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetStateSummaryFrontierFailed", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// GetStateSummaryFrontierFailed indicates an expected call of GetStateSummaryFrontierFailed. -func (mr *MockEngineMockRecorder) GetStateSummaryFrontierFailed(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStateSummaryFrontierFailed", reflect.TypeOf((*MockEngine)(nil).GetStateSummaryFrontierFailed), arg0, arg1, arg2) -} - -// GetVM mocks base method. 
-func (m *MockEngine) GetVM() common.VM { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetVM") - ret0, _ := ret[0].(common.VM) - return ret0 -} - -// GetVM indicates an expected call of GetVM. -func (mr *MockEngineMockRecorder) GetVM() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVM", reflect.TypeOf((*MockEngine)(nil).GetVM)) -} - -// GetVtx mocks base method. -func (m *MockEngine) GetVtx(arg0 context.Context, arg1 ids.ID) (avalanche.Vertex, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetVtx", arg0, arg1) - ret0, _ := ret[0].(avalanche.Vertex) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetVtx indicates an expected call of GetVtx. -func (mr *MockEngineMockRecorder) GetVtx(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVtx", reflect.TypeOf((*MockEngine)(nil).GetVtx), arg0, arg1) -} - -// Gossip mocks base method. -func (m *MockEngine) Gossip(arg0 context.Context) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Gossip", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// Gossip indicates an expected call of Gossip. -func (mr *MockEngineMockRecorder) Gossip(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Gossip", reflect.TypeOf((*MockEngine)(nil).Gossip), arg0) -} - -// Halt mocks base method. -func (m *MockEngine) Halt(arg0 context.Context) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Halt", arg0) -} - -// Halt indicates an expected call of Halt. -func (mr *MockEngineMockRecorder) Halt(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Halt", reflect.TypeOf((*MockEngine)(nil).Halt), arg0) -} - -// HealthCheck mocks base method. -func (m *MockEngine) HealthCheck(arg0 context.Context) (interface{}, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HealthCheck", arg0) - ret0, _ := ret[0].(interface{}) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// HealthCheck indicates an expected call of HealthCheck. -func (mr *MockEngineMockRecorder) HealthCheck(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HealthCheck", reflect.TypeOf((*MockEngine)(nil).HealthCheck), arg0) -} - -// Notify mocks base method. -func (m *MockEngine) Notify(arg0 context.Context, arg1 common.Message) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Notify", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// Notify indicates an expected call of Notify. -func (mr *MockEngineMockRecorder) Notify(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Notify", reflect.TypeOf((*MockEngine)(nil).Notify), arg0, arg1) -} - -// PullQuery mocks base method. -func (m *MockEngine) PullQuery(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 ids.ID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PullQuery", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// PullQuery indicates an expected call of PullQuery. -func (mr *MockEngineMockRecorder) PullQuery(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PullQuery", reflect.TypeOf((*MockEngine)(nil).PullQuery), arg0, arg1, arg2, arg3) -} - -// PushQuery mocks base method. 
-func (m *MockEngine) PushQuery(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 []byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PushQuery", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// PushQuery indicates an expected call of PushQuery. -func (mr *MockEngineMockRecorder) PushQuery(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PushQuery", reflect.TypeOf((*MockEngine)(nil).PushQuery), arg0, arg1, arg2, arg3) -} - -// Put mocks base method. -func (m *MockEngine) Put(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 []byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Put", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// Put indicates an expected call of Put. -func (mr *MockEngineMockRecorder) Put(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockEngine)(nil).Put), arg0, arg1, arg2, arg3) -} - -// QueryFailed mocks base method. -func (m *MockEngine) QueryFailed(arg0 context.Context, arg1 ids.NodeID, arg2 uint32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "QueryFailed", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// QueryFailed indicates an expected call of QueryFailed. -func (mr *MockEngineMockRecorder) QueryFailed(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryFailed", reflect.TypeOf((*MockEngine)(nil).QueryFailed), arg0, arg1, arg2) -} - -// Shutdown mocks base method. -func (m *MockEngine) Shutdown(arg0 context.Context) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Shutdown", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// Shutdown indicates an expected call of Shutdown. -func (mr *MockEngineMockRecorder) Shutdown(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Shutdown", reflect.TypeOf((*MockEngine)(nil).Shutdown), arg0) -} - -// Start mocks base method. -func (m *MockEngine) Start(arg0 context.Context, arg1 uint32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Start", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// Start indicates an expected call of Start. -func (mr *MockEngineMockRecorder) Start(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockEngine)(nil).Start), arg0, arg1) -} - -// StateSummaryFrontier mocks base method. -func (m *MockEngine) StateSummaryFrontier(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 []byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "StateSummaryFrontier", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// StateSummaryFrontier indicates an expected call of StateSummaryFrontier. -func (mr *MockEngineMockRecorder) StateSummaryFrontier(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StateSummaryFrontier", reflect.TypeOf((*MockEngine)(nil).StateSummaryFrontier), arg0, arg1, arg2, arg3) -} - -// Timeout mocks base method. 
-func (m *MockEngine) Timeout(arg0 context.Context) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Timeout", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// Timeout indicates an expected call of Timeout. -func (mr *MockEngineMockRecorder) Timeout(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Timeout", reflect.TypeOf((*MockEngine)(nil).Timeout), arg0) -} diff --git a/snow/engine/avalanche/state/prefixed_state.go b/snow/engine/avalanche/state/prefixed_state.go index 9cb87395ee3b..5b1511fdc22d 100644 --- a/snow/engine/avalanche/state/prefixed_state.go +++ b/snow/engine/avalanche/state/prefixed_state.go @@ -21,28 +21,29 @@ var uniqueEdgeID = ids.Empty.Prefix(edgeID) type prefixedState struct { state *state - vtx, status cache.Cacher - uniqueVtx cache.Deduplicator + vtx, status cache.Cacher[ids.ID, ids.ID] + uniqueVtx cache.Deduplicator[ids.ID, *uniqueVertex] } func newPrefixedState(state *state, idCacheSizes int) *prefixedState { return &prefixedState{ state: state, - vtx: &cache.LRU{Size: idCacheSizes}, - status: &cache.LRU{Size: idCacheSizes}, - uniqueVtx: &cache.EvictableLRU{Size: idCacheSizes}, + vtx: &cache.LRU[ids.ID, ids.ID]{Size: idCacheSizes}, + status: &cache.LRU[ids.ID, ids.ID]{Size: idCacheSizes}, + uniqueVtx: &cache.EvictableLRU[ids.ID, *uniqueVertex]{Size: idCacheSizes}, } } func (s *prefixedState) UniqueVertex(vtx *uniqueVertex) *uniqueVertex { - return s.uniqueVtx.Deduplicate(vtx).(*uniqueVertex) + return s.uniqueVtx.Deduplicate(vtx) } func (s *prefixedState) Vertex(id ids.ID) vertex.StatelessVertex { - var vID ids.ID - if cachedVtxIDIntf, found := s.vtx.Get(id); found { - vID = cachedVtxIDIntf.(ids.ID) - } else { + var ( + vID ids.ID + ok bool + ) + if vID, ok = s.vtx.Get(id); !ok { vID = id.Prefix(vtxID) s.vtx.Put(id, vID) } @@ -51,11 +52,12 @@ func (s *prefixedState) Vertex(id ids.ID) vertex.StatelessVertex { } func (s *prefixedState) SetVertex(vtx vertex.StatelessVertex) error { - rawVertexID := vtx.ID() - var vID ids.ID - if cachedVtxIDIntf, found := s.vtx.Get(rawVertexID); found { - vID = cachedVtxIDIntf.(ids.ID) - } else { + var ( + rawVertexID = vtx.ID() + vID ids.ID + ok bool + ) + if vID, ok = s.vtx.Get(rawVertexID); !ok { vID = rawVertexID.Prefix(vtxID) s.vtx.Put(rawVertexID, vID) } @@ -64,10 +66,11 @@ func (s *prefixedState) SetVertex(vtx vertex.StatelessVertex) error { } func (s *prefixedState) Status(id ids.ID) choices.Status { - var sID ids.ID - if cachedStatusIDIntf, found := s.status.Get(id); found { - sID = cachedStatusIDIntf.(ids.ID) - } else { + var ( + sID ids.ID + ok bool + ) + if sID, ok = s.status.Get(id); !ok { sID = id.Prefix(vtxStatusID) s.status.Put(id, sID) } @@ -76,10 +79,11 @@ func (s *prefixedState) Status(id ids.ID) choices.Status { } func (s *prefixedState) SetStatus(id ids.ID, status choices.Status) error { - var sID ids.ID - if cachedStatusIDIntf, found := s.status.Get(id); found { - sID = cachedStatusIDIntf.(ids.ID) - } else { + var ( + sID ids.ID + ok bool + ) + if sID, ok = s.status.Get(id); !ok { sID = id.Prefix(vtxStatusID) s.status.Put(id, sID) } diff --git a/snow/engine/avalanche/state/serializer.go b/snow/engine/avalanche/state/serializer.go index 0d4f3675df64..c7f81d79fb9c 100644 --- a/snow/engine/avalanche/state/serializer.go +++ b/snow/engine/avalanche/state/serializer.go @@ -53,7 +53,7 @@ type SerializerConfig struct { func NewSerializer(config SerializerConfig) vertex.Manager { versionDB := versiondb.New(config.DB) - dbCache := &cache.LRU{Size: 
dbCacheSize} + dbCache := &cache.LRU[ids.ID, any]{Size: dbCacheSize} s := Serializer{ SerializerConfig: config, versionDB: versionDB, diff --git a/snow/engine/avalanche/state/state.go b/snow/engine/avalanche/state/state.go index a808b3fea960..c1bac5b90bd7 100644 --- a/snow/engine/avalanche/state/state.go +++ b/snow/engine/avalanche/state/state.go @@ -20,7 +20,7 @@ type state struct { serializer *Serializer log logging.Logger - dbCache cache.Cacher + dbCache cache.Cacher[ids.ID, any] db database.Database } diff --git a/snow/engine/avalanche/state/unique_vertex.go b/snow/engine/avalanche/state/unique_vertex.go index d3a29b3c3918..ecefb0838332 100644 --- a/snow/engine/avalanche/state/unique_vertex.go +++ b/snow/engine/avalanche/state/unique_vertex.go @@ -22,8 +22,8 @@ import ( ) var ( - _ cache.Evictable = (*uniqueVertex)(nil) - _ avalanche.Vertex = (*uniqueVertex)(nil) + _ cache.Evictable[ids.ID] = (*uniqueVertex)(nil) + _ avalanche.Vertex = (*uniqueVertex)(nil) ) // uniqueVertex acts as a cache for vertices in the database. @@ -171,7 +171,7 @@ func (vtx *uniqueVertex) ID() ids.ID { return vtx.id } -func (vtx *uniqueVertex) Key() interface{} { +func (vtx *uniqueVertex) Key() ids.ID { return vtx.id } diff --git a/snow/engine/avalanche/transitive.go b/snow/engine/avalanche/transitive.go index 23d4b81bd9c5..c10a3b3381b9 100644 --- a/snow/engine/avalanche/transitive.go +++ b/snow/engine/avalanche/transitive.go @@ -11,6 +11,7 @@ import ( "go.uber.org/zap" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/proto/pb/p2p" "github.com/ava-labs/avalanchego/snow" "github.com/ava-labs/avalanchego/snow/consensus/avalanche" "github.com/ava-labs/avalanchego/snow/consensus/avalanche/poll" @@ -376,7 +377,10 @@ func (t *Transitive) Start(ctx context.Context, startReqID uint32) error { ) t.metrics.bootstrapFinished.Set(1) - t.Ctx.SetState(snow.NormalOp) + t.Ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_AVALANCHE, + State: snow.NormalOp, + }) if err := t.VM.SetState(ctx, snow.NormalOp); err != nil { return fmt.Errorf("failed to notify VM that consensus has started: %w", err) diff --git a/snow/engine/avalanche/vertex/mock_vm.go b/snow/engine/avalanche/vertex/mock_vm.go index 30af1c7b7b17..e790fd712b11 100644 --- a/snow/engine/avalanche/vertex/mock_vm.go +++ b/snow/engine/avalanche/vertex/mock_vm.go @@ -15,6 +15,7 @@ import ( manager "github.com/ava-labs/avalanchego/database/manager" ids "github.com/ava-labs/avalanchego/ids" snow "github.com/ava-labs/avalanchego/snow" + snowman "github.com/ava-labs/avalanchego/snow/consensus/snowman" snowstorm "github.com/ava-labs/avalanchego/snow/consensus/snowstorm" common "github.com/ava-labs/avalanchego/snow/engine/common" version "github.com/ava-labs/avalanchego/version" @@ -100,6 +101,21 @@ func (mr *MockDAGVMMockRecorder) AppResponse(arg0, arg1, arg2, arg3 interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppResponse", reflect.TypeOf((*MockDAGVM)(nil).AppResponse), arg0, arg1, arg2, arg3) } +// BuildBlock mocks base method. +func (m *MockDAGVM) BuildBlock(arg0 context.Context) (snowman.Block, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BuildBlock", arg0) + ret0, _ := ret[0].(snowman.Block) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BuildBlock indicates an expected call of BuildBlock. 
+func (mr *MockDAGVMMockRecorder) BuildBlock(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BuildBlock", reflect.TypeOf((*MockDAGVM)(nil).BuildBlock), arg0) +} + // Connected mocks base method. func (m *MockDAGVM) Connected(arg0 context.Context, arg1 ids.NodeID, arg2 *version.Application) error { m.ctrl.T.Helper() @@ -200,6 +216,21 @@ func (mr *MockDAGVMMockRecorder) Disconnected(arg0, arg1 interface{}) *gomock.Ca return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnected", reflect.TypeOf((*MockDAGVM)(nil).Disconnected), arg0, arg1) } +// GetBlock mocks base method. +func (m *MockDAGVM) GetBlock(arg0 context.Context, arg1 ids.ID) (snowman.Block, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetBlock", arg0, arg1) + ret0, _ := ret[0].(snowman.Block) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetBlock indicates an expected call of GetBlock. +func (mr *MockDAGVMMockRecorder) GetBlock(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBlock", reflect.TypeOf((*MockDAGVM)(nil).GetBlock), arg0, arg1) +} + // GetTx mocks base method. func (m *MockDAGVM) GetTx(arg0 context.Context, arg1 ids.ID) (snowstorm.Tx, error) { m.ctrl.T.Helper() @@ -244,6 +275,50 @@ func (mr *MockDAGVMMockRecorder) Initialize(arg0, arg1, arg2, arg3, arg4, arg5, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Initialize", reflect.TypeOf((*MockDAGVM)(nil).Initialize), arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8) } +// LastAccepted mocks base method. +func (m *MockDAGVM) LastAccepted(arg0 context.Context) (ids.ID, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LastAccepted", arg0) + ret0, _ := ret[0].(ids.ID) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LastAccepted indicates an expected call of LastAccepted. +func (mr *MockDAGVMMockRecorder) LastAccepted(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LastAccepted", reflect.TypeOf((*MockDAGVM)(nil).LastAccepted), arg0) +} + +// Linearize mocks base method. +func (m *MockDAGVM) Linearize(arg0 context.Context, arg1 ids.ID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Linearize", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Linearize indicates an expected call of Linearize. +func (mr *MockDAGVMMockRecorder) Linearize(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Linearize", reflect.TypeOf((*MockDAGVM)(nil).Linearize), arg0, arg1) +} + +// ParseBlock mocks base method. +func (m *MockDAGVM) ParseBlock(arg0 context.Context, arg1 []byte) (snowman.Block, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ParseBlock", arg0, arg1) + ret0, _ := ret[0].(snowman.Block) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ParseBlock indicates an expected call of ParseBlock. +func (mr *MockDAGVMMockRecorder) ParseBlock(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ParseBlock", reflect.TypeOf((*MockDAGVM)(nil).ParseBlock), arg0, arg1) +} + // ParseTx mocks base method. 
func (m *MockDAGVM) ParseTx(arg0 context.Context, arg1 []byte) (snowstorm.Tx, error) { m.ctrl.T.Helper() @@ -273,6 +348,20 @@ func (mr *MockDAGVMMockRecorder) PendingTxs(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PendingTxs", reflect.TypeOf((*MockDAGVM)(nil).PendingTxs), arg0) } +// SetPreference mocks base method. +func (m *MockDAGVM) SetPreference(arg0 context.Context, arg1 ids.ID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetPreference", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetPreference indicates an expected call of SetPreference. +func (mr *MockDAGVMMockRecorder) SetPreference(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetPreference", reflect.TypeOf((*MockDAGVM)(nil).SetPreference), arg0, arg1) +} + // SetState mocks base method. func (m *MockDAGVM) SetState(arg0 context.Context, arg1 snow.State) error { m.ctrl.T.Helper() diff --git a/snow/engine/common/config.go b/snow/engine/common/config.go index 9ed65c3c86a6..e83f0a1404be 100644 --- a/snow/engine/common/config.go +++ b/snow/engine/common/config.go @@ -52,7 +52,7 @@ func (c *Config) Context() *snow.ConsensusContext { // IsBootstrapped returns true iff this chain is done bootstrapping func (c *Config) IsBootstrapped() bool { - return c.Ctx.GetState() == snow.NormalOp + return c.Ctx.State.Get().State == snow.NormalOp } // Shared among common.bootstrapper and snowman/avalanche bootstrapper diff --git a/snow/engine/common/queue/jobs.go b/snow/engine/common/queue/jobs.go index d5b0e96d50cb..70ea16cacc32 100644 --- a/snow/engine/common/queue/jobs.go +++ b/snow/engine/common/queue/jobs.go @@ -117,8 +117,8 @@ func (j *Jobs) ExecuteAll( restarted bool, acceptors ...snow.Acceptor, ) (int, error) { - chainCtx.Executing(true) - defer chainCtx.Executing(false) + chainCtx.Executing.Set(true) + defer chainCtx.Executing.Set(false) numExecuted := 0 numToExecute := j.state.numJobs diff --git a/snow/engine/common/queue/state.go b/snow/engine/common/queue/state.go index 5d7486348239..39bbd8e399a5 100644 --- a/snow/engine/common/queue/state.go +++ b/snow/engine/common/queue/state.go @@ -37,7 +37,7 @@ type state struct { parser Parser runnableJobIDs linkeddb.LinkedDB cachingEnabled bool - jobsCache cache.Cacher + jobsCache cache.Cacher[ids.ID, Job] jobsDB database.Database // Should be prefixed with the jobID that we are attempting to find the // dependencies of. This prefixdb.Database should then be wrapped in a @@ -45,7 +45,7 @@ type state struct { dependenciesDB database.Database // This is a cache that tracks LinkedDB iterators that have recently been // made. - dependentsCache cache.Cacher + dependentsCache cache.Cacher[ids.ID, linkeddb.LinkedDB] missingJobIDs linkeddb.LinkedDB // This tracks the summary values of this state. 
Currently, this only // contains the last known checkpoint of how many jobs are currently in the @@ -62,7 +62,13 @@ func newState( metricsRegisterer prometheus.Registerer, ) (*state, error) { jobsCacheMetricsNamespace := fmt.Sprintf("%s_jobs_cache", metricsNamespace) - jobsCache, err := metercacher.New(jobsCacheMetricsNamespace, metricsRegisterer, &cache.LRU{Size: jobsCacheSize}) + jobsCache, err := metercacher.New[ids.ID, Job]( + jobsCacheMetricsNamespace, + metricsRegisterer, + &cache.LRU[ids.ID, Job]{ + Size: jobsCacheSize, + }, + ) if err != nil { return nil, fmt.Errorf("couldn't create metered cache: %w", err) } @@ -79,7 +85,7 @@ func newState( jobsCache: jobsCache, jobsDB: jobs, dependenciesDB: prefixdb.New(dependenciesPrefix, db), - dependentsCache: &cache.LRU{Size: dependentsCacheSize}, + dependentsCache: &cache.LRU[ids.ID, linkeddb.LinkedDB]{Size: dependentsCacheSize}, missingJobIDs: linkeddb.NewDefault(prefixdb.New(missingJobIDsPrefix, db)), metadataDB: metadataDB, numJobs: numJobs, @@ -228,7 +234,7 @@ func (s *state) HasJob(id ids.ID) (bool, error) { func (s *state) GetJob(ctx context.Context, id ids.ID) (Job, error) { if s.cachingEnabled { if job, exists := s.jobsCache.Get(id); exists { - return job.(Job), nil + return job, nil } } jobBytes, err := s.jobsDB.Get(id[:]) @@ -313,8 +319,8 @@ func (s *state) MissingJobIDs() ([]ids.ID, error) { func (s *state) getDependentsDB(dependency ids.ID) linkeddb.LinkedDB { if s.cachingEnabled { - if dependentsDBIntf, ok := s.dependentsCache.Get(dependency); ok { - return dependentsDBIntf.(linkeddb.LinkedDB) + if dependentsDB, ok := s.dependentsCache.Get(dependency); ok { + return dependentsDB } } dependencyDB := prefixdb.New(dependency[:], s.dependenciesDB) diff --git a/snow/engine/snowman/bootstrap/bootstrapper.go b/snow/engine/snowman/bootstrap/bootstrapper.go index fb60c9fddb6a..3c846b3c476d 100644 --- a/snow/engine/snowman/bootstrap/bootstrapper.go +++ b/snow/engine/snowman/bootstrap/bootstrapper.go @@ -14,6 +14,7 @@ import ( "go.uber.org/zap" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/proto/pb/p2p" "github.com/ava-labs/avalanchego/snow" "github.com/ava-labs/avalanchego/snow/choices" "github.com/ava-labs/avalanchego/snow/consensus/snowman" @@ -121,7 +122,10 @@ func New(ctx context.Context, config Config, onFinished func(ctx context.Context func (b *bootstrapper) Start(ctx context.Context, startReqID uint32) error { b.Ctx.Log.Info("starting bootstrapper") - b.Ctx.SetState(snow.Bootstrapping) + b.Ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_SNOWMAN, + State: snow.Bootstrapping, + }) if err := b.VM.SetState(ctx, snow.Bootstrapping); err != nil { return fmt.Errorf("failed to notify VM that bootstrapping has started: %w", err) @@ -300,7 +304,7 @@ func (b *bootstrapper) Notify(_ context.Context, msg common.Message) error { return nil } - b.Ctx.RunningStateSync(false) + b.Ctx.StateSyncing.Set(false) return nil } diff --git a/snow/engine/snowman/bootstrap/bootstrapper_test.go b/snow/engine/snowman/bootstrap/bootstrapper_test.go index 11ec6d890636..354716fd4191 100644 --- a/snow/engine/snowman/bootstrap/bootstrapper_test.go +++ b/snow/engine/snowman/bootstrap/bootstrapper_test.go @@ -16,6 +16,7 @@ import ( "github.com/ava-labs/avalanchego/database" "github.com/ava-labs/avalanchego/database/memdb" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/proto/pb/p2p" "github.com/ava-labs/avalanchego/snow" "github.com/ava-labs/avalanchego/snow/choices" 
"github.com/ava-labs/avalanchego/snow/consensus/snowman" @@ -168,7 +169,10 @@ func TestBootstrapperStartsOnlyIfEnoughStakeIsConnected(t *testing.T) { // create bootstrapper dummyCallback := func(context.Context, uint32) error { - cfg.Ctx.SetState(snow.NormalOp) + cfg.Ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_SNOWMAN, + State: snow.NormalOp, + }) return nil } bs, err := New(context.Background(), cfg, dummyCallback) @@ -246,7 +250,10 @@ func TestBootstrapperSingleFrontier(t *testing.T) { context.Background(), config, func(context.Context, uint32) error { - config.Ctx.SetState(snow.NormalOp) + config.Ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_SNOWMAN, + State: snow.NormalOp, + }) return nil }, ) @@ -287,7 +294,7 @@ func TestBootstrapperSingleFrontier(t *testing.T) { switch { case err != nil: // should finish t.Fatal(err) - case config.Ctx.GetState() != snow.NormalOp: + case config.Ctx.State.Get().State != snow.NormalOp: t.Fatalf("Bootstrapping should have finished") case blk1.Status() != choices.Accepted: t.Fatalf("Block should be accepted") @@ -350,7 +357,10 @@ func TestBootstrapperUnknownByzantineResponse(t *testing.T) { context.Background(), config, func(context.Context, uint32) error { - config.Ctx.SetState(snow.NormalOp) + config.Ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_SNOWMAN, + State: snow.NormalOp, + }) return nil }, ) @@ -437,7 +447,7 @@ func TestBootstrapperUnknownByzantineResponse(t *testing.T) { switch { case err != nil: // respond with right block t.Fatal(err) - case config.Ctx.GetState() != snow.NormalOp: + case config.Ctx.State.Get().State != snow.NormalOp: t.Fatalf("Bootstrapping should have finished") case blk0.Status() != choices.Accepted: t.Fatalf("Block should be accepted") @@ -511,7 +521,10 @@ func TestBootstrapperPartialFetch(t *testing.T) { context.Background(), config, func(context.Context, uint32) error { - config.Ctx.SetState(snow.NormalOp) + config.Ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_SNOWMAN, + State: snow.NormalOp, + }) return nil }, ) @@ -600,7 +613,7 @@ func TestBootstrapperPartialFetch(t *testing.T) { } switch { - case config.Ctx.GetState() != snow.NormalOp: + case config.Ctx.State.Get().State != snow.NormalOp: t.Fatalf("Bootstrapping should have finished") case blk0.Status() != choices.Accepted: t.Fatalf("Block should be accepted") @@ -675,7 +688,10 @@ func TestBootstrapperEmptyResponse(t *testing.T) { context.Background(), config, func(context.Context, uint32) error { - config.Ctx.SetState(snow.NormalOp) + config.Ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_SNOWMAN, + State: snow.NormalOp, + }) return nil }, ) @@ -783,7 +799,7 @@ func TestBootstrapperEmptyResponse(t *testing.T) { } switch { - case config.Ctx.GetState() != snow.NormalOp: + case config.Ctx.State.Get().State != snow.NormalOp: t.Fatalf("Bootstrapping should have finished") case blk0.Status() != choices.Accepted: t.Fatalf("Block should be accepted") @@ -861,7 +877,10 @@ func TestBootstrapperAncestors(t *testing.T) { context.Background(), config, func(context.Context, uint32) error { - config.Ctx.SetState(snow.NormalOp) + config.Ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_SNOWMAN, + State: snow.NormalOp, + }) return nil }, ) @@ -943,7 +962,7 @@ func TestBootstrapperAncestors(t *testing.T) { } switch { - case config.Ctx.GetState() != snow.NormalOp: + case config.Ctx.State.Get().State != snow.NormalOp: t.Fatalf("Bootstrapping should have finished") 
case blk0.Status() != choices.Accepted: t.Fatalf("Block should be accepted") @@ -1004,7 +1023,10 @@ func TestBootstrapperFinalized(t *testing.T) { context.Background(), config, func(context.Context, uint32) error { - config.Ctx.SetState(snow.NormalOp) + config.Ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_SNOWMAN, + State: snow.NormalOp, + }) return nil }, ) @@ -1077,7 +1099,7 @@ func TestBootstrapperFinalized(t *testing.T) { } switch { - case config.Ctx.GetState() != snow.NormalOp: + case config.Ctx.State.Get().State != snow.NormalOp: t.Fatalf("Bootstrapping should have finished") case blk0.Status() != choices.Accepted: t.Fatalf("Block should be accepted") @@ -1214,7 +1236,10 @@ func TestRestartBootstrapping(t *testing.T) { context.Background(), config, func(context.Context, uint32) error { - config.Ctx.SetState(snow.NormalOp) + config.Ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_SNOWMAN, + State: snow.NormalOp, + }) return nil }, ) @@ -1280,7 +1305,7 @@ func TestRestartBootstrapping(t *testing.T) { t.Fatal(err) } - if config.Ctx.GetState() == snow.NormalOp { + if config.Ctx.State.Get().State == snow.NormalOp { t.Fatal("Bootstrapping should not have finished with outstanding request for blk4") } @@ -1289,7 +1314,7 @@ func TestRestartBootstrapping(t *testing.T) { } switch { - case config.Ctx.GetState() != snow.NormalOp: + case config.Ctx.State.Get().State != snow.NormalOp: t.Fatalf("Bootstrapping should have finished") case blk0.Status() != choices.Accepted: t.Fatalf("Block should be accepted") @@ -1354,7 +1379,10 @@ func TestBootstrapOldBlockAfterStateSync(t *testing.T) { context.Background(), config, func(context.Context, uint32) error { - config.Ctx.SetState(snow.NormalOp) + config.Ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_SNOWMAN, + State: snow.NormalOp, + }) return nil }, ) @@ -1394,7 +1422,7 @@ func TestBootstrapOldBlockAfterStateSync(t *testing.T) { } switch { - case config.Ctx.GetState() != snow.NormalOp: + case config.Ctx.State.Get().State != snow.NormalOp: t.Fatalf("Bootstrapping should have finished") case blk0.Status() != choices.Processing: t.Fatalf("Block should be processing") @@ -1441,7 +1469,10 @@ func TestBootstrapContinueAfterHalt(t *testing.T) { context.Background(), config, func(context.Context, uint32) error { - config.Ctx.SetState(snow.NormalOp) + config.Ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_SNOWMAN, + State: snow.NormalOp, + }) return nil }, ) diff --git a/snow/engine/snowman/mock_engine.go b/snow/engine/snowman/mock_engine.go deleted file mode 100644 index 3eabca1627fb..000000000000 --- a/snow/engine/snowman/mock_engine.go +++ /dev/null @@ -1,618 +0,0 @@ -// Copyright (C) 2019-2022, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/ava-labs/avalanchego/snow/engine/snowman (interfaces: Engine) - -// Package snowman is a generated GoMock package. -package snowman - -import ( - context "context" - reflect "reflect" - time "time" - - ids "github.com/ava-labs/avalanchego/ids" - snow "github.com/ava-labs/avalanchego/snow" - snowman "github.com/ava-labs/avalanchego/snow/consensus/snowman" - common "github.com/ava-labs/avalanchego/snow/engine/common" - version "github.com/ava-labs/avalanchego/version" - gomock "github.com/golang/mock/gomock" -) - -// MockEngine is a mock of Engine interface. 
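
> Note: every bootstrapper test in this file swaps the old one-liner `cfg.Ctx.SetState(snow.NormalOp)` for the same three-field `snow.EngineState` literal (the struct is introduced in snow/state.go further down in this diff). Since the literal recurs at a dozen call sites, a helper of this shape would keep the tests terse; hypothetical only, the diff deliberately inlines it everywhere:

```go
// Hypothetical test helper, not part of this diff: the tests above inline
// this literal at every call site instead. Assumes the snow and p2p
// imports already present in these test files.
func setNormalOp(ctx *snow.ConsensusContext) {
	ctx.State.Set(snow.EngineState{
		Type:  p2p.EngineType_ENGINE_TYPE_SNOWMAN,
		State: snow.NormalOp,
	})
}
```
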
-type MockEngine struct { - ctrl *gomock.Controller - recorder *MockEngineMockRecorder -} - -// MockEngineMockRecorder is the mock recorder for MockEngine. -type MockEngineMockRecorder struct { - mock *MockEngine -} - -// NewMockEngine creates a new mock instance. -func NewMockEngine(ctrl *gomock.Controller) *MockEngine { - mock := &MockEngine{ctrl: ctrl} - mock.recorder = &MockEngineMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockEngine) EXPECT() *MockEngineMockRecorder { - return m.recorder -} - -// Accepted mocks base method. -func (m *MockEngine) Accepted(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 []ids.ID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Accepted", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// Accepted indicates an expected call of Accepted. -func (mr *MockEngineMockRecorder) Accepted(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Accepted", reflect.TypeOf((*MockEngine)(nil).Accepted), arg0, arg1, arg2, arg3) -} - -// AcceptedFrontier mocks base method. -func (m *MockEngine) AcceptedFrontier(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 []ids.ID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AcceptedFrontier", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// AcceptedFrontier indicates an expected call of AcceptedFrontier. -func (mr *MockEngineMockRecorder) AcceptedFrontier(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptedFrontier", reflect.TypeOf((*MockEngine)(nil).AcceptedFrontier), arg0, arg1, arg2, arg3) -} - -// AcceptedStateSummary mocks base method. -func (m *MockEngine) AcceptedStateSummary(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 []ids.ID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AcceptedStateSummary", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// AcceptedStateSummary indicates an expected call of AcceptedStateSummary. -func (mr *MockEngineMockRecorder) AcceptedStateSummary(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptedStateSummary", reflect.TypeOf((*MockEngine)(nil).AcceptedStateSummary), arg0, arg1, arg2, arg3) -} - -// Ancestors mocks base method. -func (m *MockEngine) Ancestors(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 [][]byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Ancestors", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// Ancestors indicates an expected call of Ancestors. -func (mr *MockEngineMockRecorder) Ancestors(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ancestors", reflect.TypeOf((*MockEngine)(nil).Ancestors), arg0, arg1, arg2, arg3) -} - -// AppGossip mocks base method. -func (m *MockEngine) AppGossip(arg0 context.Context, arg1 ids.NodeID, arg2 []byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AppGossip", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// AppGossip indicates an expected call of AppGossip. 
-func (mr *MockEngineMockRecorder) AppGossip(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppGossip", reflect.TypeOf((*MockEngine)(nil).AppGossip), arg0, arg1, arg2) -} - -// AppRequest mocks base method. -func (m *MockEngine) AppRequest(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 time.Time, arg4 []byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AppRequest", arg0, arg1, arg2, arg3, arg4) - ret0, _ := ret[0].(error) - return ret0 -} - -// AppRequest indicates an expected call of AppRequest. -func (mr *MockEngineMockRecorder) AppRequest(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppRequest", reflect.TypeOf((*MockEngine)(nil).AppRequest), arg0, arg1, arg2, arg3, arg4) -} - -// AppRequestFailed mocks base method. -func (m *MockEngine) AppRequestFailed(arg0 context.Context, arg1 ids.NodeID, arg2 uint32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AppRequestFailed", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// AppRequestFailed indicates an expected call of AppRequestFailed. -func (mr *MockEngineMockRecorder) AppRequestFailed(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppRequestFailed", reflect.TypeOf((*MockEngine)(nil).AppRequestFailed), arg0, arg1, arg2) -} - -// AppResponse mocks base method. -func (m *MockEngine) AppResponse(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 []byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AppResponse", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// AppResponse indicates an expected call of AppResponse. -func (mr *MockEngineMockRecorder) AppResponse(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppResponse", reflect.TypeOf((*MockEngine)(nil).AppResponse), arg0, arg1, arg2, arg3) -} - -// Chits mocks base method. -func (m *MockEngine) Chits(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3, arg4 []ids.ID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Chits", arg0, arg1, arg2, arg3, arg4) - ret0, _ := ret[0].(error) - return ret0 -} - -// Chits indicates an expected call of Chits. -func (mr *MockEngineMockRecorder) Chits(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Chits", reflect.TypeOf((*MockEngine)(nil).Chits), arg0, arg1, arg2, arg3, arg4) -} - -// Connected mocks base method. -func (m *MockEngine) Connected(arg0 context.Context, arg1 ids.NodeID, arg2 *version.Application) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Connected", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// Connected indicates an expected call of Connected. -func (mr *MockEngineMockRecorder) Connected(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connected", reflect.TypeOf((*MockEngine)(nil).Connected), arg0, arg1, arg2) -} - -// Context mocks base method. -func (m *MockEngine) Context() *snow.ConsensusContext { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Context") - ret0, _ := ret[0].(*snow.ConsensusContext) - return ret0 -} - -// Context indicates an expected call of Context. 
-func (mr *MockEngineMockRecorder) Context() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockEngine)(nil).Context)) -} - -// CrossChainAppRequest mocks base method. -func (m *MockEngine) CrossChainAppRequest(arg0 context.Context, arg1 ids.ID, arg2 uint32, arg3 time.Time, arg4 []byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CrossChainAppRequest", arg0, arg1, arg2, arg3, arg4) - ret0, _ := ret[0].(error) - return ret0 -} - -// CrossChainAppRequest indicates an expected call of CrossChainAppRequest. -func (mr *MockEngineMockRecorder) CrossChainAppRequest(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CrossChainAppRequest", reflect.TypeOf((*MockEngine)(nil).CrossChainAppRequest), arg0, arg1, arg2, arg3, arg4) -} - -// CrossChainAppRequestFailed mocks base method. -func (m *MockEngine) CrossChainAppRequestFailed(arg0 context.Context, arg1 ids.ID, arg2 uint32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CrossChainAppRequestFailed", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// CrossChainAppRequestFailed indicates an expected call of CrossChainAppRequestFailed. -func (mr *MockEngineMockRecorder) CrossChainAppRequestFailed(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CrossChainAppRequestFailed", reflect.TypeOf((*MockEngine)(nil).CrossChainAppRequestFailed), arg0, arg1, arg2) -} - -// CrossChainAppResponse mocks base method. -func (m *MockEngine) CrossChainAppResponse(arg0 context.Context, arg1 ids.ID, arg2 uint32, arg3 []byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CrossChainAppResponse", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// CrossChainAppResponse indicates an expected call of CrossChainAppResponse. -func (mr *MockEngineMockRecorder) CrossChainAppResponse(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CrossChainAppResponse", reflect.TypeOf((*MockEngine)(nil).CrossChainAppResponse), arg0, arg1, arg2, arg3) -} - -// Disconnected mocks base method. -func (m *MockEngine) Disconnected(arg0 context.Context, arg1 ids.NodeID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Disconnected", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// Disconnected indicates an expected call of Disconnected. -func (mr *MockEngineMockRecorder) Disconnected(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnected", reflect.TypeOf((*MockEngine)(nil).Disconnected), arg0, arg1) -} - -// Get mocks base method. -func (m *MockEngine) Get(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 ids.ID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// Get indicates an expected call of Get. -func (mr *MockEngineMockRecorder) Get(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockEngine)(nil).Get), arg0, arg1, arg2, arg3) -} - -// GetAccepted mocks base method. 
-func (m *MockEngine) GetAccepted(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 []ids.ID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAccepted", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// GetAccepted indicates an expected call of GetAccepted. -func (mr *MockEngineMockRecorder) GetAccepted(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccepted", reflect.TypeOf((*MockEngine)(nil).GetAccepted), arg0, arg1, arg2, arg3) -} - -// GetAcceptedFailed mocks base method. -func (m *MockEngine) GetAcceptedFailed(arg0 context.Context, arg1 ids.NodeID, arg2 uint32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAcceptedFailed", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// GetAcceptedFailed indicates an expected call of GetAcceptedFailed. -func (mr *MockEngineMockRecorder) GetAcceptedFailed(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAcceptedFailed", reflect.TypeOf((*MockEngine)(nil).GetAcceptedFailed), arg0, arg1, arg2) -} - -// GetAcceptedFrontier mocks base method. -func (m *MockEngine) GetAcceptedFrontier(arg0 context.Context, arg1 ids.NodeID, arg2 uint32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAcceptedFrontier", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// GetAcceptedFrontier indicates an expected call of GetAcceptedFrontier. -func (mr *MockEngineMockRecorder) GetAcceptedFrontier(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAcceptedFrontier", reflect.TypeOf((*MockEngine)(nil).GetAcceptedFrontier), arg0, arg1, arg2) -} - -// GetAcceptedFrontierFailed mocks base method. -func (m *MockEngine) GetAcceptedFrontierFailed(arg0 context.Context, arg1 ids.NodeID, arg2 uint32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAcceptedFrontierFailed", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// GetAcceptedFrontierFailed indicates an expected call of GetAcceptedFrontierFailed. -func (mr *MockEngineMockRecorder) GetAcceptedFrontierFailed(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAcceptedFrontierFailed", reflect.TypeOf((*MockEngine)(nil).GetAcceptedFrontierFailed), arg0, arg1, arg2) -} - -// GetAcceptedStateSummary mocks base method. -func (m *MockEngine) GetAcceptedStateSummary(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 []uint64) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAcceptedStateSummary", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// GetAcceptedStateSummary indicates an expected call of GetAcceptedStateSummary. -func (mr *MockEngineMockRecorder) GetAcceptedStateSummary(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAcceptedStateSummary", reflect.TypeOf((*MockEngine)(nil).GetAcceptedStateSummary), arg0, arg1, arg2, arg3) -} - -// GetAcceptedStateSummaryFailed mocks base method. 
-func (m *MockEngine) GetAcceptedStateSummaryFailed(arg0 context.Context, arg1 ids.NodeID, arg2 uint32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAcceptedStateSummaryFailed", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// GetAcceptedStateSummaryFailed indicates an expected call of GetAcceptedStateSummaryFailed. -func (mr *MockEngineMockRecorder) GetAcceptedStateSummaryFailed(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAcceptedStateSummaryFailed", reflect.TypeOf((*MockEngine)(nil).GetAcceptedStateSummaryFailed), arg0, arg1, arg2) -} - -// GetAncestors mocks base method. -func (m *MockEngine) GetAncestors(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 ids.ID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAncestors", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// GetAncestors indicates an expected call of GetAncestors. -func (mr *MockEngineMockRecorder) GetAncestors(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAncestors", reflect.TypeOf((*MockEngine)(nil).GetAncestors), arg0, arg1, arg2, arg3) -} - -// GetAncestorsFailed mocks base method. -func (m *MockEngine) GetAncestorsFailed(arg0 context.Context, arg1 ids.NodeID, arg2 uint32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAncestorsFailed", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// GetAncestorsFailed indicates an expected call of GetAncestorsFailed. -func (mr *MockEngineMockRecorder) GetAncestorsFailed(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAncestorsFailed", reflect.TypeOf((*MockEngine)(nil).GetAncestorsFailed), arg0, arg1, arg2) -} - -// GetBlock mocks base method. -func (m *MockEngine) GetBlock(arg0 context.Context, arg1 ids.ID) (snowman.Block, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetBlock", arg0, arg1) - ret0, _ := ret[0].(snowman.Block) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetBlock indicates an expected call of GetBlock. -func (mr *MockEngineMockRecorder) GetBlock(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBlock", reflect.TypeOf((*MockEngine)(nil).GetBlock), arg0, arg1) -} - -// GetFailed mocks base method. -func (m *MockEngine) GetFailed(arg0 context.Context, arg1 ids.NodeID, arg2 uint32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetFailed", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// GetFailed indicates an expected call of GetFailed. -func (mr *MockEngineMockRecorder) GetFailed(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFailed", reflect.TypeOf((*MockEngine)(nil).GetFailed), arg0, arg1, arg2) -} - -// GetStateSummaryFrontier mocks base method. -func (m *MockEngine) GetStateSummaryFrontier(arg0 context.Context, arg1 ids.NodeID, arg2 uint32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetStateSummaryFrontier", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// GetStateSummaryFrontier indicates an expected call of GetStateSummaryFrontier. 
-func (mr *MockEngineMockRecorder) GetStateSummaryFrontier(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStateSummaryFrontier", reflect.TypeOf((*MockEngine)(nil).GetStateSummaryFrontier), arg0, arg1, arg2) -} - -// GetStateSummaryFrontierFailed mocks base method. -func (m *MockEngine) GetStateSummaryFrontierFailed(arg0 context.Context, arg1 ids.NodeID, arg2 uint32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetStateSummaryFrontierFailed", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// GetStateSummaryFrontierFailed indicates an expected call of GetStateSummaryFrontierFailed. -func (mr *MockEngineMockRecorder) GetStateSummaryFrontierFailed(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStateSummaryFrontierFailed", reflect.TypeOf((*MockEngine)(nil).GetStateSummaryFrontierFailed), arg0, arg1, arg2) -} - -// GetVM mocks base method. -func (m *MockEngine) GetVM() common.VM { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetVM") - ret0, _ := ret[0].(common.VM) - return ret0 -} - -// GetVM indicates an expected call of GetVM. -func (mr *MockEngineMockRecorder) GetVM() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVM", reflect.TypeOf((*MockEngine)(nil).GetVM)) -} - -// Gossip mocks base method. -func (m *MockEngine) Gossip(arg0 context.Context) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Gossip", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// Gossip indicates an expected call of Gossip. -func (mr *MockEngineMockRecorder) Gossip(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Gossip", reflect.TypeOf((*MockEngine)(nil).Gossip), arg0) -} - -// Halt mocks base method. -func (m *MockEngine) Halt(arg0 context.Context) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Halt", arg0) -} - -// Halt indicates an expected call of Halt. -func (mr *MockEngineMockRecorder) Halt(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Halt", reflect.TypeOf((*MockEngine)(nil).Halt), arg0) -} - -// HealthCheck mocks base method. -func (m *MockEngine) HealthCheck(arg0 context.Context) (interface{}, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HealthCheck", arg0) - ret0, _ := ret[0].(interface{}) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// HealthCheck indicates an expected call of HealthCheck. -func (mr *MockEngineMockRecorder) HealthCheck(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HealthCheck", reflect.TypeOf((*MockEngine)(nil).HealthCheck), arg0) -} - -// Notify mocks base method. -func (m *MockEngine) Notify(arg0 context.Context, arg1 common.Message) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Notify", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// Notify indicates an expected call of Notify. -func (mr *MockEngineMockRecorder) Notify(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Notify", reflect.TypeOf((*MockEngine)(nil).Notify), arg0, arg1) -} - -// PullQuery mocks base method. 
-func (m *MockEngine) PullQuery(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 ids.ID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PullQuery", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// PullQuery indicates an expected call of PullQuery. -func (mr *MockEngineMockRecorder) PullQuery(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PullQuery", reflect.TypeOf((*MockEngine)(nil).PullQuery), arg0, arg1, arg2, arg3) -} - -// PushQuery mocks base method. -func (m *MockEngine) PushQuery(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 []byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PushQuery", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// PushQuery indicates an expected call of PushQuery. -func (mr *MockEngineMockRecorder) PushQuery(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PushQuery", reflect.TypeOf((*MockEngine)(nil).PushQuery), arg0, arg1, arg2, arg3) -} - -// Put mocks base method. -func (m *MockEngine) Put(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 []byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Put", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// Put indicates an expected call of Put. -func (mr *MockEngineMockRecorder) Put(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockEngine)(nil).Put), arg0, arg1, arg2, arg3) -} - -// QueryFailed mocks base method. -func (m *MockEngine) QueryFailed(arg0 context.Context, arg1 ids.NodeID, arg2 uint32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "QueryFailed", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// QueryFailed indicates an expected call of QueryFailed. -func (mr *MockEngineMockRecorder) QueryFailed(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryFailed", reflect.TypeOf((*MockEngine)(nil).QueryFailed), arg0, arg1, arg2) -} - -// Shutdown mocks base method. -func (m *MockEngine) Shutdown(arg0 context.Context) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Shutdown", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// Shutdown indicates an expected call of Shutdown. -func (mr *MockEngineMockRecorder) Shutdown(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Shutdown", reflect.TypeOf((*MockEngine)(nil).Shutdown), arg0) -} - -// Start mocks base method. -func (m *MockEngine) Start(arg0 context.Context, arg1 uint32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Start", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// Start indicates an expected call of Start. -func (mr *MockEngineMockRecorder) Start(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockEngine)(nil).Start), arg0, arg1) -} - -// StateSummaryFrontier mocks base method. 
-func (m *MockEngine) StateSummaryFrontier(arg0 context.Context, arg1 ids.NodeID, arg2 uint32, arg3 []byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "StateSummaryFrontier", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// StateSummaryFrontier indicates an expected call of StateSummaryFrontier. -func (mr *MockEngineMockRecorder) StateSummaryFrontier(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StateSummaryFrontier", reflect.TypeOf((*MockEngine)(nil).StateSummaryFrontier), arg0, arg1, arg2, arg3) -} - -// Timeout mocks base method. -func (m *MockEngine) Timeout(arg0 context.Context) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Timeout", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// Timeout indicates an expected call of Timeout. -func (mr *MockEngineMockRecorder) Timeout(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Timeout", reflect.TypeOf((*MockEngine)(nil).Timeout), arg0) -} diff --git a/snow/engine/snowman/syncer/state_syncer.go b/snow/engine/snowman/syncer/state_syncer.go index e8e231712fa4..10d234fa072d 100644 --- a/snow/engine/snowman/syncer/state_syncer.go +++ b/snow/engine/snowman/syncer/state_syncer.go @@ -14,6 +14,7 @@ import ( "github.com/ava-labs/avalanchego/database" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/proto/pb/p2p" "github.com/ava-labs/avalanchego/snow" "github.com/ava-labs/avalanchego/snow/engine/common" "github.com/ava-labs/avalanchego/snow/engine/snowman/block" @@ -322,13 +323,13 @@ func (ss *stateSyncer) AcceptedStateSummary(ctx context.Context, nodeID ids.Node case block.StateSyncStatic: // Summary was accepted and VM is state syncing. // Engine will wait for notification of state sync done. - ss.Ctx.RunningStateSync(true) + ss.Ctx.StateSyncing.Set(true) return nil case block.StateSyncDynamic: // Summary was accepted and VM is state syncing. // Engine will continue into bootstrapping and the VM will sync in the // background. 
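
> Note: the state syncer hunks here replace the removed `RunningStateSync` setter with the new `Ctx.StateSyncing` atomic, which the snowman transitive engine reads later in this diff when deciding which chits to send. A condensed view of the handshake, stitched together from the hunks in this change:

```go
// Producer side (state_syncer.go): mark the chain as actively syncing.
ss.Ctx.StateSyncing.Set(true)

// Consumer side (transitive.go): while syncing, chits carry only the
// last accepted block instead of a live preference.
lastAccepted := t.Consensus.LastAccepted()
if t.Ctx.StateSyncing.Get() {
	t.Sender.SendChits(ctx, nodeID, requestID,
		[]ids.ID{lastAccepted}, []ids.ID{lastAccepted})
}
```
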
- ss.Ctx.RunningStateSync(true) + ss.Ctx.StateSyncing.Set(true) return ss.onDoneStateSyncing(ctx, ss.requestID) default: ss.Ctx.Log.Warn("unhandled state summary mode, proceeding to bootstrap", @@ -384,7 +385,10 @@ func (ss *stateSyncer) GetAcceptedStateSummaryFailed(ctx context.Context, nodeID func (ss *stateSyncer) Start(ctx context.Context, startReqID uint32) error { ss.Ctx.Log.Info("starting state sync") - ss.Ctx.SetState(snow.StateSyncing) + ss.Ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_SNOWMAN, + State: snow.StateSyncing, + }) if err := ss.VM.SetState(ctx, snow.StateSyncing); err != nil { return fmt.Errorf("failed to notify VM that state syncing has started: %w", err) } @@ -542,7 +546,7 @@ func (ss *stateSyncer) Notify(ctx context.Context, msg common.Message) error { return nil } - ss.Ctx.RunningStateSync(false) + ss.Ctx.StateSyncing.Set(false) return ss.onDoneStateSyncing(ctx, ss.requestID) } diff --git a/snow/engine/snowman/transitive.go b/snow/engine/snowman/transitive.go index 81e883b8cd5a..174abb9723ce 100644 --- a/snow/engine/snowman/transitive.go +++ b/snow/engine/snowman/transitive.go @@ -13,6 +13,7 @@ import ( "github.com/ava-labs/avalanchego/cache" "github.com/ava-labs/avalanchego/cache/metercacher" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/proto/pb/p2p" "github.com/ava-labs/avalanchego/snow" "github.com/ava-labs/avalanchego/snow/choices" "github.com/ava-labs/avalanchego/snow/consensus/snowman" @@ -65,7 +66,7 @@ type Transitive struct { // A block is put into this cache if it was not able to be issued. A block // fails to be issued if verification on the block or one of its ancestors // occurs. - nonVerifiedCache cache.Cacher + nonVerifiedCache cache.Cacher[ids.ID, snowman.Block] // acceptedFrontiers of the other validators of this chain acceptedFrontiers tracker.Accepted @@ -85,10 +86,10 @@ type Transitive struct { func newTransitive(config Config) (*Transitive, error) { config.Ctx.Log.Info("initializing consensus engine") - nonVerifiedCache, err := metercacher.New( + nonVerifiedCache, err := metercacher.New[ids.ID, snowman.Block]( "non_verified_cache", config.Ctx.Registerer, - &cache.LRU{Size: nonVerifiedCacheSize}, + &cache.LRU[ids.ID, snowman.Block]{Size: nonVerifiedCacheSize}, ) if err != nil { return nil, err @@ -382,7 +383,7 @@ func (t *Transitive) Notify(ctx context.Context, msg common.Message) error { t.pendingBuildBlocks++ return t.buildBlocks(ctx) case common.StateSyncDone: - t.Ctx.RunningStateSync(false) + t.Ctx.StateSyncing.Set(false) return nil default: t.Ctx.Log.Warn("received an unexpected message from the VM", @@ -446,7 +447,10 @@ func (t *Transitive) Start(ctx context.Context, startReqID uint32) error { ) t.metrics.bootstrapFinished.Set(1) - t.Ctx.SetState(snow.NormalOp) + t.Ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_SNOWMAN, + State: snow.NormalOp, + }) if err := t.VM.SetState(ctx, snow.NormalOp); err != nil { return fmt.Errorf("failed to notify VM that consensus is starting: %w", err) @@ -479,7 +483,7 @@ func (t *Transitive) GetBlock(ctx context.Context, blkID ids.ID) (snowman.Block, return blk, nil } if blk, ok := t.nonVerifiedCache.Get(blkID); ok { - return blk.(snowman.Block), nil + return blk, nil } return t.VM.GetBlock(ctx, blkID) @@ -487,7 +491,7 @@ func (t *Transitive) GetBlock(ctx context.Context, blkID ids.ID) (snowman.Block, func (t *Transitive) sendChits(ctx context.Context, nodeID ids.NodeID, requestID uint32) { lastAccepted := t.Consensus.LastAccepted() - if 
t.Ctx.IsRunningStateSync() { + if t.Ctx.StateSyncing.Get() { t.Sender.SendChits(ctx, nodeID, requestID, []ids.ID{lastAccepted}, []ids.ID{lastAccepted}) } else { t.Sender.SendChits(ctx, nodeID, requestID, []ids.ID{t.Consensus.Preference()}, []ids.ID{lastAccepted}) diff --git a/snow/networking/handler/handler.go b/snow/networking/handler/handler.go index f8b25d23c61c..530417766fd2 100644 --- a/snow/networking/handler/handler.go +++ b/snow/networking/handler/handler.go @@ -78,10 +78,9 @@ type handler struct { preemptTimeouts chan struct{} gossipFrequency time.Duration - defaultEngine p2p.EngineType - stateSyncer common.StateSyncer - bootstrapper common.BootstrapableEngine - engine common.Engine + stateSyncer common.StateSyncer + bootstrapper common.BootstrapableEngine + engine common.Engine // onStopped is called in a goroutine when this handler finishes shutting // down. If it is nil then it is skipped. onStopped func() @@ -116,7 +115,6 @@ func New( msgFromVMChan <-chan common.Message, preemptTimeouts chan struct{}, gossipFrequency time.Duration, - defaultEngine p2p.EngineType, resourceTracker tracker.ResourceTracker, subnetConnector validators.SubnetConnector, ) (Handler, error) { @@ -126,7 +124,6 @@ func New( msgFromVMChan: msgFromVMChan, preemptTimeouts: preemptTimeouts, gossipFrequency: gossipFrequency, - defaultEngine: defaultEngine, asyncMessagePool: worker.NewPool(threadPoolSize), timeouts: make(chan struct{}, 1), closingChan: make(chan struct{}), @@ -158,7 +155,7 @@ func (h *handler) Context() *snow.ConsensusContext { } func (h *handler) IsValidator(nodeID ids.NodeID) bool { - return !h.ctx.IsValidatorOnly() || + return !h.ctx.ValidatorOnly.Get() || nodeID == h.ctx.NodeID || h.validators.Contains(nodeID) } @@ -448,7 +445,7 @@ func (h *handler) handleSyncMsg(ctx context.Context, msg message.InboundMessage) h.ctx.Log.Debug("finished handling sync message", zap.Stringer("messageOp", op), ) - if processingTime > syncProcessingTimeWarnLimit && h.ctx.GetState() == snow.NormalOp { + if processingTime > syncProcessingTimeWarnLimit && h.ctx.State.Get().State == snow.NormalOp { h.ctx.Log.Warn("handling sync message took longer than expected", zap.Duration("processingTime", processingTime), zap.Stringer("nodeID", nodeID), @@ -489,7 +486,6 @@ func (h *handler) handleSyncMsg(ctx context.Context, msg message.InboundMessage) zap.Stringer("messageOp", message.GetAcceptedStateSummaryOp), zap.Uint32("requestID", msg.RequestId), zap.String("field", "Heights"), - zap.Error(err), ) return engine.GetAcceptedStateSummaryFailed(ctx, nodeID, msg.RequestId) } @@ -835,7 +831,7 @@ func (h *handler) handleChanMsg(msg message.InboundMessage) error { } func (h *handler) getEngine() (common.Engine, error) { - state := h.ctx.GetState() + state := h.ctx.State.Get().State switch state { case snow.StateSyncing: return h.stateSyncer, nil diff --git a/snow/networking/handler/handler_test.go b/snow/networking/handler/handler_test.go index bf733fbeeaf7..72d4f319299c 100644 --- a/snow/networking/handler/handler_test.go +++ b/snow/networking/handler/handler_test.go @@ -52,7 +52,6 @@ func TestHandlerDropsTimedOutMessages(t *testing.T) { nil, nil, time.Second, - p2p.EngineType_ENGINE_TYPE_SNOWMAN, resourceTracker, validators.UnhandledSubnetConnector, ) @@ -80,7 +79,10 @@ func TestHandlerDropsTimedOutMessages(t *testing.T) { return nil } handler.SetBootstrapper(bootstrapper) - ctx.SetState(snow.Bootstrapping) // assumed bootstrapping is ongoing + ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_SNOWMAN, + 
State: snow.Bootstrapping, // assumed bootstrap is ongoing + }) pastTime := time.Now() handler.clock.Set(pastTime) @@ -134,7 +136,6 @@ func TestHandlerClosesOnError(t *testing.T) { nil, nil, time.Second, - p2p.EngineType_ENGINE_TYPE_SNOWMAN, resourceTracker, validators.UnhandledSubnetConnector, ) @@ -172,7 +173,10 @@ func TestHandlerClosesOnError(t *testing.T) { // assume bootstrapping is ongoing so that InboundGetAcceptedFrontier // should normally be handled - ctx.SetState(snow.Bootstrapping) + ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_SNOWMAN, + State: snow.Bootstrapping, + }) bootstrapper.StartF = func(context.Context, uint32) error { return nil @@ -214,7 +218,6 @@ func TestHandlerDropsGossipDuringBootstrapping(t *testing.T) { nil, nil, 1, - p2p.EngineType_ENGINE_TYPE_SNOWMAN, resourceTracker, validators.UnhandledSubnetConnector, ) @@ -240,7 +243,10 @@ func TestHandlerDropsGossipDuringBootstrapping(t *testing.T) { return nil } handler.SetBootstrapper(bootstrapper) - ctx.SetState(snow.Bootstrapping) // assumed bootstrapping is ongoing + ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_SNOWMAN, + State: snow.Bootstrapping, // assumed bootstrap is ongoing + }) bootstrapper.StartF = func(context.Context, uint32) error { return nil @@ -284,7 +290,6 @@ func TestHandlerDispatchInternal(t *testing.T) { msgFromVMChan, nil, time.Second, - p2p.EngineType_ENGINE_TYPE_SNOWMAN, resourceTracker, validators.UnhandledSubnetConnector, ) @@ -311,7 +316,10 @@ func TestHandlerDispatchInternal(t *testing.T) { return nil } handler.SetConsensus(engine) - ctx.SetState(snow.NormalOp) // assumed bootstrapping is done + ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_SNOWMAN, + State: snow.NormalOp, // assumed bootstrap is done + }) bootstrapper.StartF = func(context.Context, uint32) error { return nil @@ -353,7 +361,6 @@ func TestHandlerSubnetConnector(t *testing.T) { nil, nil, time.Second, - p2p.EngineType_ENGINE_TYPE_SNOWMAN, resourceTracker, connector, ) @@ -376,7 +383,10 @@ func TestHandlerSubnetConnector(t *testing.T) { return ctx } handler.SetConsensus(engine) - ctx.SetState(snow.NormalOp) // assumed bootstrapping is done + ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_SNOWMAN, + State: snow.NormalOp, // assumed bootstrap is done + }) bootstrapper.StartF = func(context.Context, uint32) error { return nil diff --git a/snow/networking/router/chain_router.go b/snow/networking/router/chain_router.go index 6890daaffaea..52e87586553d 100644 --- a/snow/networking/router/chain_router.go +++ b/snow/networking/router/chain_router.go @@ -259,7 +259,7 @@ func (cr *ChainRouter) HandleInbound(ctx context.Context, msg message.InboundMes // before the overflow may not be handled properly. 
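
> Note: the chain router hunk that follows drops unrequested messages whenever the chain context's `Executing` flag is set; that flag is now a plain `utils.Atomic[bool]` toggled around job execution in the queue/jobs.go hunk earlier in this diff. In sketch form, combining both sides of the change:

```go
// From queue/jobs.go earlier in this diff: execution brackets the flag.
chainCtx.Executing.Set(true)
defer chainCtx.Executing.Set(false)

// From chain_router.go below: unrequested traffic is shed while the
// chain is executing queued bootstrap jobs.
if chainCtx.Executing.Get() {
	cr.log.Debug("dropping message and skipping queue",
		zap.String("reason", "the chain is currently executing"),
		zap.Stringer("messageOp", op),
	)
	return
}
```
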
if notRequested := message.UnrequestedOps.Contains(op); notRequested || (op == message.PutOp && requestID == constants.GossipMsgRequestID) { - if chainCtx.IsExecuting() { + if chainCtx.Executing.Get() { cr.log.Debug("dropping message and skipping queue", zap.String("reason", "the chain is currently executing"), zap.Stringer("messageOp", op), @@ -290,7 +290,7 @@ func (cr *ChainRouter) HandleInbound(ctx context.Context, msg message.InboundMes return } - if chainCtx.IsExecuting() { + if chainCtx.Executing.Get() { cr.log.Debug("dropping message and skipping queue", zap.String("reason", "the chain is currently executing"), zap.Stringer("messageOp", op), diff --git a/snow/networking/router/chain_router_test.go b/snow/networking/router/chain_router_test.go index 1a5c392be2db..3f3b3fbdb0f8 100644 --- a/snow/networking/router/chain_router_test.go +++ b/snow/networking/router/chain_router_test.go @@ -89,7 +89,6 @@ func TestShutdown(t *testing.T) { nil, nil, time.Second, - engineType, resourceTracker, validators.UnhandledSubnetConnector, ) @@ -133,7 +132,10 @@ func TestShutdown(t *testing.T) { } engine.HaltF = func(context.Context) {} handler.SetConsensus(engine) - ctx.SetState(snow.NormalOp) // assumed bootstrap is done + ctx.State.Set(snow.EngineState{ + Type: engineType, + State: snow.NormalOp, // assumed bootstrapping is done + }) chainRouter.AddChain(context.Background(), handler) @@ -212,7 +214,6 @@ func TestShutdownTimesOut(t *testing.T) { nil, nil, time.Second, - engineType, resourceTracker, validators.UnhandledSubnetConnector, ) @@ -255,7 +256,10 @@ func TestShutdownTimesOut(t *testing.T) { return nil } handler.SetConsensus(engine) - ctx.SetState(snow.NormalOp) // assumed bootstrapping is done + ctx.State.Set(snow.EngineState{ + Type: engineType, + State: snow.NormalOp, // assumed bootstrapping is done + }) chainRouter.AddChain(context.Background(), handler) @@ -353,7 +357,6 @@ func TestRouterTimeout(t *testing.T) { nil, nil, time.Second, - engineType, resourceTracker, validators.UnhandledSubnetConnector, ) @@ -423,7 +426,10 @@ func TestRouterTimeout(t *testing.T) { return nil } handler.SetBootstrapper(bootstrapper) - ctx.SetState(snow.Bootstrapping) // assumed bootstrapping is ongoing + ctx.State.Set(snow.EngineState{ + Type: engineType, + State: snow.Bootstrapping, // assumed bootstrapping is ongoing + }) chainRouter.AddChain(context.Background(), handler) @@ -670,7 +676,6 @@ func TestRouterClearTimeouts(t *testing.T) { nil, nil, time.Second, - engineType, resourceTracker, validators.UnhandledSubnetConnector, ) @@ -696,7 +701,10 @@ func TestRouterClearTimeouts(t *testing.T) { return ctx } handler.SetConsensus(engine) - ctx.SetState(snow.NormalOp) // assumed bootstrapping is done + ctx.State.Set(snow.EngineState{ + Type: engineType, + State: snow.NormalOp, // assumed bootstrapping is done + }) chainRouter.AddChain(context.Background(), handler) @@ -926,7 +934,7 @@ func TestValidatorOnlyMessageDrops(t *testing.T) { wg := sync.WaitGroup{} ctx := snow.DefaultConsensusContextTest() - ctx.SetValidatorOnly() + ctx.ValidatorOnly.Set(true) vdrs := validators.NewSet() vID := ids.GenerateTestNodeID() err = vdrs.Add(vID, nil, ids.Empty, 1) @@ -944,7 +952,6 @@ func TestValidatorOnlyMessageDrops(t *testing.T) { nil, nil, time.Second, - engineType, resourceTracker, validators.UnhandledSubnetConnector, ) @@ -968,7 +975,10 @@ func TestValidatorOnlyMessageDrops(t *testing.T) { return nil } handler.SetBootstrapper(bootstrapper) - ctx.SetState(snow.Bootstrapping) // assumed bootstrapping is ongoing + 
ctx.State.Set(snow.EngineState{ + Type: engineType, + State: snow.Bootstrapping, // assumed bootstrapping is ongoing + }) engine := &common.EngineTest{T: t} engine.ContextF = func() *snow.ConsensusContext { @@ -1089,7 +1099,7 @@ func TestRouterCrossChainMessages(t *testing.T) { requester.ChainID = ids.GenerateTestID() requester.Registerer = prometheus.NewRegistry() requester.Metrics = metrics.NewOptionalGatherer() - requester.Executing(false) + requester.Executing.Set(false) resourceTracker, err := tracker.NewResourceTracker( prometheus.NewRegistry(), @@ -1105,7 +1115,6 @@ func TestRouterCrossChainMessages(t *testing.T) { nil, nil, time.Second, - engineType, resourceTracker, validators.UnhandledSubnetConnector, ) @@ -1115,7 +1124,7 @@ func TestRouterCrossChainMessages(t *testing.T) { responder.ChainID = ids.GenerateTestID() responder.Registerer = prometheus.NewRegistry() responder.Metrics = metrics.NewOptionalGatherer() - responder.Executing(false) + responder.Executing.Set(false) responderHandler, err := handler.New( responder, @@ -1123,15 +1132,20 @@ func TestRouterCrossChainMessages(t *testing.T) { nil, nil, time.Second, - engineType, resourceTracker, validators.UnhandledSubnetConnector, ) require.NoError(t, err) // assumed bootstrapping is done - responder.SetState(snow.NormalOp) - requester.SetState(snow.NormalOp) + responder.State.Set(snow.EngineState{ + Type: engineType, + State: snow.NormalOp, + }) + requester.State.Set(snow.EngineState{ + Type: engineType, + State: snow.NormalOp, + }) // router tracks two chains - one will send a message to the other chainRouter.AddChain(context.Background(), requesterHandler) @@ -1230,8 +1244,11 @@ func TestConnectedSubnet(t *testing.T) { platform.SubnetID = constants.PrimaryNetworkID platform.Registerer = prometheus.NewRegistry() platform.Metrics = metrics.NewOptionalGatherer() - platform.Executing(false) - platform.SetState(snow.NormalOp) + platform.Executing.Set(false) + platform.State.Set(snow.EngineState{ + Type: engineType, + State: snow.NormalOp, + }) myConnectedMsg := message.InternalConnected(myNodeID, version.CurrentApp) mySubnetConnectedMsg0 := message.InternalConnectedSubnet(myNodeID, subnetID0) diff --git a/snow/networking/sender/sender.go b/snow/networking/sender/sender.go index f569e7e9e25f..55e541a01eab 100644 --- a/snow/networking/sender/sender.go +++ b/snow/networking/sender/sender.go @@ -149,7 +149,7 @@ func (s *sender) SendGetStateSummaryFrontier(ctx context.Context, nodeIDs set.Se outMsg, nodeIDs, s.ctx.SubnetID, - s.ctx.IsValidatorOnly(), + s.ctx.ValidatorOnly.Get(), ) } else { s.ctx.Log.Error("failed to build message", @@ -212,7 +212,7 @@ func (s *sender) SendStateSummaryFrontier(ctx context.Context, nodeID ids.NodeID outMsg, nodeIDs, s.ctx.SubnetID, - s.ctx.IsValidatorOnly(), + s.ctx.ValidatorOnly.Get(), ) if sentTo.Len() == 0 { s.ctx.Log.Debug("failed to send message", @@ -289,7 +289,7 @@ func (s *sender) SendGetAcceptedStateSummary(ctx context.Context, nodeIDs set.Se outMsg, nodeIDs, s.ctx.SubnetID, - s.ctx.IsValidatorOnly(), + s.ctx.ValidatorOnly.Get(), ) } else { s.ctx.Log.Error("failed to build message", @@ -352,7 +352,7 @@ func (s *sender) SendAcceptedStateSummary(ctx context.Context, nodeID ids.NodeID outMsg, nodeIDs, s.ctx.SubnetID, - s.ctx.IsValidatorOnly(), + s.ctx.ValidatorOnly.Get(), ) if sentTo.Len() == 0 { s.ctx.Log.Debug("failed to send message", @@ -424,7 +424,7 @@ func (s *sender) SendGetAcceptedFrontier(ctx context.Context, nodeIDs set.Set[id outMsg, nodeIDs, s.ctx.SubnetID, - s.ctx.IsValidatorOnly(), + 
s.ctx.ValidatorOnly.Get(), ) } else { s.ctx.Log.Error("failed to build message", @@ -489,7 +489,7 @@ func (s *sender) SendAcceptedFrontier(ctx context.Context, nodeID ids.NodeID, re outMsg, nodeIDs, s.ctx.SubnetID, - s.ctx.IsValidatorOnly(), + s.ctx.ValidatorOnly.Get(), ) if sentTo.Len() == 0 { s.ctx.Log.Debug("failed to send message", @@ -563,7 +563,7 @@ func (s *sender) SendGetAccepted(ctx context.Context, nodeIDs set.Set[ids.NodeID outMsg, nodeIDs, s.ctx.SubnetID, - s.ctx.IsValidatorOnly(), + s.ctx.ValidatorOnly.Get(), ) } else { s.ctx.Log.Error("failed to build message", @@ -623,7 +623,7 @@ func (s *sender) SendAccepted(ctx context.Context, nodeID ids.NodeID, requestID outMsg, nodeIDs, s.ctx.SubnetID, - s.ctx.IsValidatorOnly(), + s.ctx.ValidatorOnly.Get(), ) if sentTo.Len() == 0 { s.ctx.Log.Debug("failed to send message", @@ -703,7 +703,7 @@ func (s *sender) SendGetAncestors(ctx context.Context, nodeID ids.NodeID, reques outMsg, nodeIDs, s.ctx.SubnetID, - s.ctx.IsValidatorOnly(), + s.ctx.ValidatorOnly.Get(), ) if sentTo.Len() == 0 { s.ctx.Log.Debug("failed to send message", @@ -743,7 +743,7 @@ func (s *sender) SendAncestors(_ context.Context, nodeID ids.NodeID, requestID u outMsg, nodeIDs, s.ctx.SubnetID, - s.ctx.IsValidatorOnly(), + s.ctx.ValidatorOnly.Get(), ) if sentTo.Len() == 0 { s.ctx.Log.Debug("failed to send message", @@ -817,7 +817,7 @@ func (s *sender) SendGet(ctx context.Context, nodeID ids.NodeID, requestID uint3 outMsg, nodeIDs, s.ctx.SubnetID, - s.ctx.IsValidatorOnly(), + s.ctx.ValidatorOnly.Get(), ) } else { s.ctx.Log.Error("failed to build message", @@ -869,7 +869,7 @@ func (s *sender) SendPut(_ context.Context, nodeID ids.NodeID, requestID uint32, outMsg, nodeIDs, s.ctx.SubnetID, - s.ctx.IsValidatorOnly(), + s.ctx.ValidatorOnly.Get(), ) if sentTo.Len() == 0 { s.ctx.Log.Debug("failed to send message", @@ -976,7 +976,7 @@ func (s *sender) SendPushQuery(ctx context.Context, nodeIDs set.Set[ids.NodeID], outMsg, nodeIDs, s.ctx.SubnetID, - s.ctx.IsValidatorOnly(), + s.ctx.ValidatorOnly.Get(), ) } else { s.ctx.Log.Error("failed to build message", @@ -1102,7 +1102,7 @@ func (s *sender) SendPullQuery(ctx context.Context, nodeIDs set.Set[ids.NodeID], outMsg, nodeIDs, s.ctx.SubnetID, - s.ctx.IsValidatorOnly(), + s.ctx.ValidatorOnly.Get(), ) } else { s.ctx.Log.Error("failed to build message", @@ -1177,7 +1177,7 @@ func (s *sender) SendChits(ctx context.Context, nodeID ids.NodeID, requestID uin outMsg, nodeIDs, s.ctx.SubnetID, - s.ctx.IsValidatorOnly(), + s.ctx.ValidatorOnly.Get(), ) if sentTo.Len() == 0 { s.ctx.Log.Debug("failed to send message", @@ -1318,7 +1318,7 @@ func (s *sender) SendAppRequest(ctx context.Context, nodeIDs set.Set[ids.NodeID] outMsg, nodeIDs, s.ctx.SubnetID, - s.ctx.IsValidatorOnly(), + s.ctx.ValidatorOnly.Get(), ) } else { s.ctx.Log.Error("failed to build message", @@ -1399,7 +1399,7 @@ func (s *sender) SendAppResponse(ctx context.Context, nodeID ids.NodeID, request outMsg, nodeIDs, s.ctx.SubnetID, - s.ctx.IsValidatorOnly(), + s.ctx.ValidatorOnly.Get(), ) if sentTo.Len() == 0 { s.ctx.Log.Debug("failed to send message", @@ -1437,7 +1437,7 @@ func (s *sender) SendAppGossipSpecific(_ context.Context, nodeIDs set.Set[ids.No outMsg, nodeIDs, s.ctx.SubnetID, - s.ctx.IsValidatorOnly(), + s.ctx.ValidatorOnly.Get(), ) if sentTo.Len() == 0 { for nodeID := range nodeIDs { @@ -1480,7 +1480,7 @@ func (s *sender) SendAppGossip(_ context.Context, appGossipBytes []byte) error { sentTo := s.sender.Gossip( outMsg, s.ctx.SubnetID, - s.ctx.IsValidatorOnly(), + 
s.ctx.ValidatorOnly.Get(), validatorSize, nonValidatorSize, peerSize, @@ -1521,7 +1521,7 @@ func (s *sender) SendGossip(_ context.Context, container []byte) { sentTo := s.sender.Gossip( outMsg, s.ctx.SubnetID, - s.ctx.IsValidatorOnly(), + s.ctx.ValidatorOnly.Get(), int(s.gossipConfig.AcceptedFrontierValidatorSize), int(s.gossipConfig.AcceptedFrontierNonValidatorSize), int(s.gossipConfig.AcceptedFrontierPeerSize), @@ -1541,7 +1541,7 @@ func (s *sender) SendGossip(_ context.Context, container []byte) { // Accept is called after every consensus decision func (s *sender) Accept(ctx *snow.ConsensusContext, _ ids.ID, container []byte) error { - if ctx.GetState() != snow.NormalOp { + if ctx.State.Get().State != snow.NormalOp { // don't gossip during bootstrapping return nil } @@ -1566,7 +1566,7 @@ func (s *sender) Accept(ctx *snow.ConsensusContext, _ ids.ID, container []byte) sentTo := s.sender.Gossip( outMsg, s.ctx.SubnetID, - s.ctx.IsValidatorOnly(), + s.ctx.ValidatorOnly.Get(), int(s.gossipConfig.OnAcceptValidatorSize), int(s.gossipConfig.OnAcceptNonValidatorSize), int(s.gossipConfig.OnAcceptPeerSize), diff --git a/snow/networking/sender/sender_test.go b/snow/networking/sender/sender_test.go index 84556e51fa01..640da6524763 100644 --- a/snow/networking/sender/sender_test.go +++ b/snow/networking/sender/sender_test.go @@ -118,7 +118,6 @@ func TestTimeout(t *testing.T) { nil, nil, time.Hour, - p2p.EngineType_ENGINE_TYPE_SNOWMAN, resourceTracker, validators.UnhandledSubnetConnector, ) @@ -141,7 +140,10 @@ func TestTimeout(t *testing.T) { return nil } handler.SetBootstrapper(bootstrapper) - ctx2.SetState(snow.Bootstrapping) // assumed bootstrap is ongoing + ctx2.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_SNOWMAN, + State: snow.Bootstrapping, // assumed bootstrap is ongoing + }) chainRouter.AddChain(context.Background(), handler) @@ -377,7 +379,6 @@ func TestReliableMessages(t *testing.T) { nil, nil, 1, - p2p.EngineType_ENGINE_TYPE_SNOWMAN, resourceTracker, validators.UnhandledSubnetConnector, ) @@ -410,7 +411,10 @@ func TestReliableMessages(t *testing.T) { } bootstrapper.CantGossip = false handler.SetBootstrapper(bootstrapper) - ctx2.SetState(snow.Bootstrapping) // assumed bootstrap is ongoing + ctx2.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_SNOWMAN, + State: snow.Bootstrapping, // assumed bootstrap is ongoing + }) chainRouter.AddChain(context.Background(), handler) @@ -503,7 +507,6 @@ func TestReliableMessagesToMyself(t *testing.T) { nil, nil, time.Second, - p2p.EngineType_ENGINE_TYPE_SNOWMAN, resourceTracker, validators.UnhandledSubnetConnector, ) @@ -535,7 +538,10 @@ func TestReliableMessagesToMyself(t *testing.T) { return nil } handler.SetBootstrapper(bootstrapper) - ctx2.SetState(snow.Bootstrapping) // assumed bootstrap is ongoing + ctx2.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_SNOWMAN, + State: snow.Bootstrapping, // assumed bootstrap is ongoing + }) chainRouter.AddChain(context.Background(), handler) @@ -625,7 +631,7 @@ func TestSender_Bootstrap_Requests(t *testing.T) { failedNodeID: struct{}{}, }, // Node IDs subnetID, // Subnet ID - snowCtx.IsValidatorOnly(), + snowCtx.ValidatorOnly.Get(), ).Return(set.Set[ids.NodeID]{ successNodeID: struct{}{}, }) @@ -672,7 +678,7 @@ func TestSender_Bootstrap_Requests(t *testing.T) { failedNodeID: struct{}{}, }, // Node IDs subnetID, // Subnet ID - snowCtx.IsValidatorOnly(), + snowCtx.ValidatorOnly.Get(), ).Return(set.Set[ids.NodeID]{ successNodeID: struct{}{}, }) @@ -716,7 +722,7 @@ func 
TestSender_Bootstrap_Requests(t *testing.T) { failedNodeID: struct{}{}, }, // Node IDs subnetID, // Subnet ID - snowCtx.IsValidatorOnly(), + snowCtx.ValidatorOnly.Get(), ).Return(set.Set[ids.NodeID]{ successNodeID: struct{}{}, }) @@ -761,7 +767,7 @@ func TestSender_Bootstrap_Requests(t *testing.T) { failedNodeID: struct{}{}, }, // Node IDs subnetID, // Subnet ID - snowCtx.IsValidatorOnly(), + snowCtx.ValidatorOnly.Get(), ).Return(set.Set[ids.NodeID]{ successNodeID: struct{}{}, }) @@ -899,7 +905,7 @@ func TestSender_Bootstrap_Responses(t *testing.T) { gomock.Any(), // Outbound message set.Set[ids.NodeID]{destinationNodeID: struct{}{}}, // Node IDs subnetID, // Subnet ID - snowCtx.IsValidatorOnly(), + snowCtx.ValidatorOnly.Get(), ).Return(nil) }, sendF: func(_ *require.Assertions, sender common.Sender, nodeID ids.NodeID) { @@ -929,7 +935,7 @@ func TestSender_Bootstrap_Responses(t *testing.T) { gomock.Any(), // Outbound message set.Set[ids.NodeID]{destinationNodeID: struct{}{}}, // Node IDs subnetID, // Subnet ID - snowCtx.IsValidatorOnly(), + snowCtx.ValidatorOnly.Get(), ).Return(nil) }, sendF: func(_ *require.Assertions, sender common.Sender, nodeID ids.NodeID) { @@ -961,7 +967,7 @@ func TestSender_Bootstrap_Responses(t *testing.T) { gomock.Any(), // Outbound message set.Set[ids.NodeID]{destinationNodeID: struct{}{}}, // Node IDs subnetID, // Subnet ID - snowCtx.IsValidatorOnly(), + snowCtx.ValidatorOnly.Get(), ).Return(nil) }, sendF: func(_ *require.Assertions, sender common.Sender, nodeID ids.NodeID) { @@ -993,7 +999,7 @@ func TestSender_Bootstrap_Responses(t *testing.T) { gomock.Any(), // Outbound message set.Set[ids.NodeID]{destinationNodeID: struct{}{}}, // Node IDs subnetID, // Subnet ID - snowCtx.IsValidatorOnly(), + snowCtx.ValidatorOnly.Get(), ).Return(nil) }, sendF: func(_ *require.Assertions, sender common.Sender, nodeID ids.NodeID) { @@ -1121,7 +1127,7 @@ func TestSender_Single_Request(t *testing.T) { gomock.Any(), // Outbound message set.Set[ids.NodeID]{destinationNodeID: struct{}{}}, // Node IDs subnetID, - snowCtx.IsValidatorOnly(), + snowCtx.ValidatorOnly.Get(), ).Return(sentTo) }, sendF: func(_ *require.Assertions, sender common.Sender, nodeID ids.NodeID) { @@ -1160,7 +1166,7 @@ func TestSender_Single_Request(t *testing.T) { gomock.Any(), // Outbound message set.Set[ids.NodeID]{destinationNodeID: struct{}{}}, // Node IDs subnetID, - snowCtx.IsValidatorOnly(), + snowCtx.ValidatorOnly.Get(), ).Return(sentTo) }, sendF: func(_ *require.Assertions, sender common.Sender, nodeID ids.NodeID) { diff --git a/snow/state.go b/snow/state.go index 845819d6ebc9..ce799212335f 100644 --- a/snow/state.go +++ b/snow/state.go @@ -3,7 +3,11 @@ package snow -import "errors" +import ( + "errors" + + "github.com/ava-labs/avalanchego/proto/pb/p2p" +) const ( Initializing = iota @@ -30,3 +34,8 @@ func (st State) String() string { return "Unknown state" } } + +type EngineState struct { + Type p2p.EngineType + State State +} diff --git a/snow/uptime/locked_calculator.go b/snow/uptime/locked_calculator.go index e567eb35d896..2e5248bdfa95 100644 --- a/snow/uptime/locked_calculator.go +++ b/snow/uptime/locked_calculator.go @@ -21,12 +21,12 @@ var ( type LockedCalculator interface { Calculator - SetCalculator(isBootstrapped *utils.AtomicBool, lock sync.Locker, newC Calculator) + SetCalculator(isBootstrapped *utils.Atomic[bool], lock sync.Locker, newC Calculator) } type lockedCalculator struct { lock sync.RWMutex - isBootstrapped *utils.AtomicBool + isBootstrapped *utils.Atomic[bool] calculatorLock 
sync.Locker c Calculator } @@ -39,7 +39,7 @@ func (c *lockedCalculator) CalculateUptime(nodeID ids.NodeID, subnetID ids.ID) ( c.lock.RLock() defer c.lock.RUnlock() - if c.isBootstrapped == nil || !c.isBootstrapped.GetValue() { + if c.isBootstrapped == nil || !c.isBootstrapped.Get() { return 0, time.Time{}, errNotReady } @@ -53,7 +53,7 @@ func (c *lockedCalculator) CalculateUptimePercent(nodeID ids.NodeID, subnetID id c.lock.RLock() defer c.lock.RUnlock() - if c.isBootstrapped == nil || !c.isBootstrapped.GetValue() { + if c.isBootstrapped == nil || !c.isBootstrapped.Get() { return 0, errNotReady } @@ -67,7 +67,7 @@ func (c *lockedCalculator) CalculateUptimePercentFrom(nodeID ids.NodeID, subnetI c.lock.RLock() defer c.lock.RUnlock() - if c.isBootstrapped == nil || !c.isBootstrapped.GetValue() { + if c.isBootstrapped == nil || !c.isBootstrapped.Get() { return 0, errNotReady } @@ -77,7 +77,7 @@ func (c *lockedCalculator) CalculateUptimePercentFrom(nodeID ids.NodeID, subnetI return c.c.CalculateUptimePercentFrom(nodeID, subnetID, startTime) } -func (c *lockedCalculator) SetCalculator(isBootstrapped *utils.AtomicBool, lock sync.Locker, newC Calculator) { +func (c *lockedCalculator) SetCalculator(isBootstrapped *utils.Atomic[bool], lock sync.Locker, newC Calculator) { c.lock.Lock() defer c.lock.Unlock() diff --git a/snow/uptime/locked_calculator_test.go b/snow/uptime/locked_calculator_test.go index 4f88cc908216..945c23316e8e 100644 --- a/snow/uptime/locked_calculator_test.go +++ b/snow/uptime/locked_calculator_test.go @@ -36,7 +36,7 @@ func TestLockedCalculator(t *testing.T) { _, err = lc.CalculateUptimePercentFrom(nodeID, subnetID, time.Now()) require.ErrorIs(err, errNotReady) - var isBootstrapped utils.AtomicBool + var isBootstrapped utils.Atomic[bool] mockCalc := NewMockCalculator(ctrl) // Should still error because ctx is not bootstrapped @@ -50,7 +50,7 @@ func TestLockedCalculator(t *testing.T) { _, err = lc.CalculateUptimePercentFrom(nodeID, subnetID, time.Now()) require.EqualValues(errNotReady, err) - isBootstrapped.SetValue(true) + isBootstrapped.Set(true) // Should return the value from the mocked inner calculator mockCalc.EXPECT().CalculateUptime(gomock.Any(), gomock.Any()).AnyTimes().Return(time.Duration(0), time.Time{}, errTest) diff --git a/snow/validators/manager.go b/snow/validators/manager.go index c9bdf18e910f..91391aaa2088 100644 --- a/snow/validators/manager.go +++ b/snow/validators/manager.go @@ -135,7 +135,7 @@ func RemoveWeight(m Manager, subnetID ids.ID, nodeID ids.NodeID, weight uint64) return vdrs.RemoveWeight(nodeID, weight) } -// AddWeight is a helper that fetches the validator set of [subnetID] from [m] +// Contains is a helper that fetches the validator set of [subnetID] from [m] // and returns if the validator set contains [nodeID]. If [m] does not contain a // validator set for [subnetID], false is returned. func Contains(m Manager, subnetID ids.ID, nodeID ids.NodeID) bool { diff --git a/utils/atomic.go b/utils/atomic.go new file mode 100644 index 000000000000..05db3281fbaa --- /dev/null +++ b/utils/atomic.go @@ -0,0 +1,27 @@ +// Copyright (C) 2019-2022, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
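The locked_calculator.go hunks above migrate the bootstrapped flag from utils.AtomicBool to the generic utils.Atomic[bool] introduced in the new file below; every uptime query still fails fast until the flag flips. A minimal, runnable sketch of that gate pattern, using a hypothetical Gate type and a local copy of the wrapper so it stands alone:

package main

import (
	"errors"
	"fmt"
	"sync"
)

// Atomic mirrors the generic wrapper added in utils/atomic.go.
type Atomic[T any] struct {
	lock  sync.RWMutex
	value T
}

func (a *Atomic[T]) Get() T {
	a.lock.RLock()
	defer a.lock.RUnlock()
	return a.value
}

func (a *Atomic[T]) Set(v T) {
	a.lock.Lock()
	defer a.lock.Unlock()
	a.value = v
}

var errNotReady = errors.New("should not be called until bootstrapping finishes")

// Gate is a hypothetical stand-in for lockedCalculator: reads error out
// until the shared bootstrapped flag is set.
type Gate struct {
	isBootstrapped *Atomic[bool]
}

func (g *Gate) Query() (string, error) {
	if g.isBootstrapped == nil || !g.isBootstrapped.Get() {
		return "", errNotReady
	}
	return "uptime: 99%", nil
}

func main() {
	var flag Atomic[bool]
	g := &Gate{isBootstrapped: &flag}

	if _, err := g.Query(); err != nil {
		fmt.Println(err) // gated: not bootstrapped yet
	}

	flag.Set(true) // bootstrapping finished
	out, _ := g.Query()
	fmt.Println(out)
}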
+ +package utils + +import ( + "sync" +) + +type Atomic[T any] struct { + lock sync.RWMutex + value T +} + +func (a *Atomic[T]) Get() T { + a.lock.RLock() + defer a.lock.RUnlock() + + return a.value +} + +func (a *Atomic[T]) Set(value T) { + a.lock.Lock() + defer a.lock.Unlock() + + a.value = value +} diff --git a/utils/atomic_bool.go b/utils/atomic_bool.go deleted file mode 100644 index c01008ac1bde..000000000000 --- a/utils/atomic_bool.go +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (C) 2019-2022, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package utils - -import "sync/atomic" - -type AtomicBool struct { - value uint32 -} - -func (a *AtomicBool) GetValue() bool { - return atomic.LoadUint32(&a.value) != 0 -} - -func (a *AtomicBool) SetValue(b bool) { - var value uint32 - if b { - value = 1 - } - atomic.StoreUint32(&a.value, value) -} diff --git a/utils/atomic_interface.go b/utils/atomic_interface.go deleted file mode 100644 index d3c239aad02c..000000000000 --- a/utils/atomic_interface.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (C) 2019-2022, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package utils - -import ( - "sync" -) - -type AtomicInterface struct { - value interface{} - lock sync.RWMutex -} - -func NewAtomicInterface(v interface{}) *AtomicInterface { - mutexInterface := AtomicInterface{} - mutexInterface.SetValue(v) - return &mutexInterface -} - -func (a *AtomicInterface) GetValue() interface{} { - a.lock.RLock() - defer a.lock.RUnlock() - return a.value -} - -func (a *AtomicInterface) SetValue(v interface{}) { - a.lock.Lock() - defer a.lock.Unlock() - a.value = v -} diff --git a/utils/atomic_interface_test.go b/utils/atomic_interface_test.go deleted file mode 100644 index 2897e4dde977..000000000000 --- a/utils/atomic_interface_test.go +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (C) 2019-2022, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package utils - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestAtomicInterface(t *testing.T) { - iface := NewAtomicInterface(nil) - require.Nil(t, iface.GetValue()) - iface.SetValue(nil) - require.Nil(t, iface.GetValue()) - val, ok := iface.GetValue().([]byte) - require.False(t, ok) - require.Nil(t, val) - iface.SetValue([]byte("test")) - require.Equal(t, []byte("test"), iface.GetValue().([]byte)) -} diff --git a/utils/atomic_test.go b/utils/atomic_test.go new file mode 100644 index 000000000000..34a946676bd4 --- /dev/null +++ b/utils/atomic_test.go @@ -0,0 +1,26 @@ +// Copyright (C) 2019-2022, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
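Because Atomic[T] is generic, it also subsumes the deleted AtomicInterface: callers store a concrete type and get it back without type assertions, which is how snow.ConsensusContext can hold its new EngineState value. A usage sketch, assuming the avalanchego module at a version that contains this file; EngineState here is a simplified local stand-in for snow.EngineState:

package main

import (
	"fmt"

	"github.com/ava-labs/avalanchego/utils"
)

// EngineState is a local stand-in for snow.EngineState.
type EngineState struct {
	Type  string
	State string
}

func main() {
	// The zero value is ready to use; no constructor needed.
	var st utils.Atomic[EngineState]

	st.Set(EngineState{Type: "snowman", State: "bootstrapping"})

	// Get returns the concrete type directly: no interface{}, no assertion.
	cur := st.Get()
	fmt.Println(cur.Type, cur.State)
}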
+ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAtomic(t *testing.T) { + require := require.New(t) + + var a Atomic[bool] + require.Zero(a.Get()) + + a.Set(false) + require.False(a.Get()) + + a.Set(true) + require.True(a.Get()) + + a.Set(false) + require.False(a.Get()) +} diff --git a/utils/crypto/secp256k1r.go b/utils/crypto/secp256k1r.go index 85ca27d04af9..5dcd9171751a 100644 --- a/utils/crypto/secp256k1r.go +++ b/utils/crypto/secp256k1r.go @@ -56,7 +56,9 @@ var ( _ PrivateKey = (*PrivateKeySECP256K1R)(nil) ) -type FactorySECP256K1R struct{ Cache cache.LRU } +type FactorySECP256K1R struct { + Cache cache.LRU[ids.ID, *PublicKeySECP256K1R] +} func (*FactorySECP256K1R) NewPrivateKey() (PrivateKey, error) { k, err := secp256k1.GeneratePrivateKey() @@ -91,7 +93,7 @@ func (f *FactorySECP256K1R) RecoverHashPublicKey(hash, sig []byte) (PublicKey, e copy(cacheBytes[len(hash):], sig) id := hashing.ComputeHash256Array(cacheBytes) if cachedPublicKey, ok := f.Cache.Get(id); ok { - return cachedPublicKey.(*PublicKeySECP256K1R), nil + return cachedPublicKey, nil } if err := verifySECP256K1RSignatureFormat(sig); err != nil { diff --git a/utils/crypto/secp256k1r_test.go b/utils/crypto/secp256k1r_test.go index 0c4fc4357d90..67f95146c6e4 100644 --- a/utils/crypto/secp256k1r_test.go +++ b/utils/crypto/secp256k1r_test.go @@ -12,6 +12,7 @@ import ( secp256k1 "github.com/decred/dcrd/dcrec/secp256k1/v3" "github.com/ava-labs/avalanchego/cache" + "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils/hashing" ) @@ -38,7 +39,7 @@ func TestRecover(t *testing.T) { func TestCachedRecover(t *testing.T) { require := require.New(t) - f := FactorySECP256K1R{Cache: cache.LRU{Size: 1}} + f := FactorySECP256K1R{Cache: cache.LRU[ids.ID, *PublicKeySECP256K1R]{Size: 1}} key, err := f.NewPrivateKey() require.NoError(err) diff --git a/utils/hashing/consistent/ring.go b/utils/hashing/consistent/ring.go index 32c53a143934..4fae79dc12cc 100644 --- a/utils/hashing/consistent/ring.go +++ b/utils/hashing/consistent/ring.go @@ -12,8 +12,8 @@ import ( ) var ( - _ Ring = (*hashRing)(nil) - _ btree.Item = (*ringItem)(nil) + _ Ring = (*hashRing)(nil) + _ btree.LessFunc[ringItem] = ringItem.Less errEmptyRing = errors.New("ring doesn't have any members") ) @@ -154,7 +154,7 @@ type hashRing struct { virtualNodes int lock sync.RWMutex - ring *btree.BTree + ring *btree.BTreeG[ringItem] } // RingConfig configures settings for a Ring. @@ -172,7 +172,7 @@ func NewHashRing(config RingConfig) Ring { return &hashRing{ hasher: config.Hasher, virtualNodes: config.VirtualNodes, - ring: btree.New(config.Degree), + ring: btree.NewG(config.Degree, ringItem.Less), } } @@ -200,8 +200,7 @@ func (h *hashRing) get(key Hashable) (Hashable, error) { hash: hash, value: key, }, - func(itemIntf btree.Item) bool { - item := itemIntf.(ringItem) + func(item ringItem) bool { if hash < item.hash { result = item.value return false @@ -213,7 +212,8 @@ func (h *hashRing) get(key Hashable) (Hashable, error) { // If found nothing ascending the tree, we need to wrap around the ring to // the left-most (min) node. 
if result == nil { - result = h.ring.Min().(ringItem).value + min, _ := h.ring.Min() + result = min.value } return result, nil } @@ -260,9 +260,7 @@ func (h *hashRing) remove(key Hashable) bool { item := ringItem{ hash: virtualNodeHash, } - if h.ring.Delete(item) != nil { - removed = true - } + _, removed = h.ring.Delete(item) } return removed } @@ -278,6 +276,6 @@ type ringItem struct { value Hashable } -func (r ringItem) Less(than btree.Item) bool { - return r.hash < than.(ringItem).hash +func (r ringItem) Less(than ringItem) bool { + return r.hash < than.hash } diff --git a/vms/avm/blocks/parser.go b/vms/avm/blocks/parser.go index 2f67bda6bfb7..3d654742da40 100644 --- a/vms/avm/blocks/parser.go +++ b/vms/avm/blocks/parser.go @@ -5,8 +5,11 @@ package blocks import ( "fmt" + "reflect" "github.com/ava-labs/avalanchego/codec" + "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/avalanchego/utils/timer/mockable" "github.com/ava-labs/avalanchego/utils/wrappers" "github.com/ava-labs/avalanchego/vms/avm/fxs" "github.com/ava-labs/avalanchego/vms/avm/txs" @@ -49,6 +52,29 @@ func NewParser(fxs []fxs.Fx) (Parser, error) { }, errs.Err } +func NewCustomParser( + typeToFxIndex map[reflect.Type]int, + clock *mockable.Clock, + log logging.Logger, + fxs []fxs.Fx, +) (Parser, error) { + p, err := txs.NewCustomParser(typeToFxIndex, clock, log, fxs) + if err != nil { + return nil, err + } + c := p.CodecRegistry() + gc := p.GenesisCodecRegistry() + + errs := wrappers.Errs{} + errs.Add( + c.RegisterType(&StandardBlock{}), + gc.RegisterType(&StandardBlock{}), + ) + return &parser{ + Parser: p, + }, errs.Err +} + func (p *parser) ParseBlock(bytes []byte) (Block, error) { return parse(p.Codec(), bytes) } diff --git a/vms/avm/index_test.go b/vms/avm/index_test.go index 784d6e79be5a..d9493c74e4a1 100644 --- a/vms/avm/index_test.go +++ b/vms/avm/index_test.go @@ -76,9 +76,7 @@ func TestIndexTransaction_Ordered(t *testing.T) { utxo := buildPlatformUTXO(utxoID, txAssetID, addr) // save utxo to state - if err := vm.state.PutUTXO(utxo); err != nil { - t.Fatal("Error saving utxo", err) - } + vm.state.AddUTXO(utxo) // issue transaction if _, err := vm.IssueTx(tx.Bytes()); err != nil { @@ -168,9 +166,7 @@ func TestIndexTransaction_MultipleTransactions(t *testing.T) { utxo := buildPlatformUTXO(utxoID, txAssetID, addr) // save utxo to state - if err := vm.state.PutUTXO(utxo); err != nil { - t.Fatal("Error saving utxo", err) - } + vm.state.AddUTXO(utxo) // issue transaction if _, err := vm.IssueTx(tx.Bytes()); err != nil { @@ -264,9 +260,7 @@ func TestIndexTransaction_MultipleAddresses(t *testing.T) { utxo := buildPlatformUTXO(utxoID, txAssetID, addr) // save utxo to state - if err := vm.state.PutUTXO(utxo); err != nil { - t.Fatal("Error saving utxo", err) - } + vm.state.AddUTXO(utxo) var inputUTXOs []*avax.UTXO //nolint:prealloc for _, utxoID := range tx.Unsigned.InputUTXOs() { @@ -325,9 +319,7 @@ func TestIndexTransaction_UnorderedWrites(t *testing.T) { utxo := buildPlatformUTXO(utxoID, txAssetID, addr) // save utxo to state - if err := vm.state.PutUTXO(utxo); err != nil { - t.Fatal("Error saving utxo", err) - } + vm.state.AddUTXO(utxo) // issue transaction if _, err := vm.IssueTx(tx.Bytes()); err != nil { diff --git a/vms/avm/service_test.go b/vms/avm/service_test.go index b36ade7fb2c9..515de9ddaf9c 100644 --- a/vms/avm/service_test.go +++ b/vms/avm/service_test.go @@ -304,8 +304,8 @@ func TestServiceGetBalanceStrict(t *testing.T) { }, } // Insert the UTXO - err = vm.state.PutUTXO(twoOfTwoUTXO) - 
require.NoError(t, err) + vm.state.AddUTXO(twoOfTwoUTXO) + require.NoError(t, vm.state.Commit()) // Check the balance with IncludePartial set to true balanceArgs := &GetBalanceArgs{ @@ -349,8 +349,8 @@ func TestServiceGetBalanceStrict(t *testing.T) { }, } // Insert the UTXO - err = vm.state.PutUTXO(oneOfTwoUTXO) - require.NoError(t, err) + vm.state.AddUTXO(oneOfTwoUTXO) + require.NoError(t, vm.state.Commit()) // Check the balance with IncludePartial set to true balanceArgs = &GetBalanceArgs{ @@ -396,8 +396,8 @@ func TestServiceGetBalanceStrict(t *testing.T) { }, } // Insert the UTXO - err = vm.state.PutUTXO(futureUTXO) - require.NoError(t, err) + vm.state.AddUTXO(futureUTXO) + require.NoError(t, vm.state.Commit()) // Check the balance with IncludePartial set to true balanceArgs = &GetBalanceArgs{ @@ -500,8 +500,8 @@ func TestServiceGetAllBalances(t *testing.T) { }, } // Insert the UTXO - err = vm.state.PutUTXO(twoOfTwoUTXO) - require.NoError(t, err) + vm.state.AddUTXO(twoOfTwoUTXO) + require.NoError(t, vm.state.Commit()) // Check the balance with IncludePartial set to true balanceArgs := &GetAllBalancesArgs{ @@ -542,8 +542,8 @@ func TestServiceGetAllBalances(t *testing.T) { }, } // Insert the UTXO - err = vm.state.PutUTXO(oneOfTwoUTXO) - require.NoError(t, err) + vm.state.AddUTXO(oneOfTwoUTXO) + require.NoError(t, vm.state.Commit()) // Check the balance with IncludePartial set to true balanceArgs = &GetAllBalancesArgs{ @@ -587,8 +587,8 @@ func TestServiceGetAllBalances(t *testing.T) { }, } // Insert the UTXO - err = vm.state.PutUTXO(futureUTXO) - require.NoError(t, err) + vm.state.AddUTXO(futureUTXO) + require.NoError(t, vm.state.Commit()) // Check the balance with IncludePartial set to true balanceArgs = &GetAllBalancesArgs{ @@ -630,8 +630,8 @@ func TestServiceGetAllBalances(t *testing.T) { }, } // Insert the UTXO - err = vm.state.PutUTXO(otherAssetUTXO) - require.NoError(t, err) + vm.state.AddUTXO(otherAssetUTXO) + require.NoError(t, vm.state.Commit()) // Check the balance with IncludePartial set to true balanceArgs = &GetAllBalancesArgs{ @@ -1829,10 +1829,9 @@ func TestServiceGetUTXOs(t *testing.T) { }, }, } - if err := vm.state.PutUTXO(utxo); err != nil { - t.Fatal(err) - } + vm.state.AddUTXO(utxo) } + require.NoError(t, vm.state.Commit()) sm := m.NewSharedMemory(constants.PlatformChainID) diff --git a/vms/avm/state_test.go b/vms/avm/state_test.go index 6179db9d37f8..077bcb69180c 100644 --- a/vms/avm/state_test.go +++ b/vms/avm/state_test.go @@ -8,6 +8,8 @@ import ( "math" "testing" + "github.com/stretchr/testify/require" + "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow/choices" "github.com/ava-labs/avalanchego/snow/engine/common" @@ -75,25 +77,21 @@ func TestSetsAndGets(t *testing.T) { t.Fatal(err) } - if err := state.PutUTXO(utxo); err != nil { - t.Fatal(err) - } - if err := state.PutTx(ids.Empty, tx); err != nil { - t.Fatal(err) - } - if err := state.PutStatus(ids.Empty, choices.Accepted); err != nil { - t.Fatal(err) - } + txID := tx.ID() + + state.AddUTXO(utxo) + state.AddTx(tx) + state.AddStatus(txID, choices.Accepted) resultUTXO, err := state.GetUTXO(utxoID) if err != nil { t.Fatal(err) } - resultTx, err := state.GetTx(ids.Empty) + resultTx, err := state.GetTx(txID) if err != nil { t.Fatal(err) } - resultStatus, err := state.GetStatus(ids.Empty) + resultStatus, err := state.GetStatus(txID) if err != nil { t.Fatal(err) } @@ -142,12 +140,8 @@ func TestFundingNoAddresses(t *testing.T) { Out: &avax.TestVerifiable{}, } - if err := state.PutUTXO(utxo); err 
!= nil { - t.Fatal(err) - } - if err := state.DeleteUTXO(utxo.InputID()); err != nil { - t.Fatal(err) - } + state.AddUTXO(utxo) + state.DeleteUTXO(utxo.InputID()) } func TestFundingAddresses(t *testing.T) { @@ -185,27 +179,18 @@ func TestFundingAddresses(t *testing.T) { }, } - if err := state.PutUTXO(utxo); err != nil { - t.Fatal(err) - } + state.AddUTXO(utxo) + require.NoError(t, state.Commit()) + utxos, err := state.UTXOIDs([]byte{0}, ids.Empty, math.MaxInt32) - if err != nil { - t.Fatal(err) - } - if len(utxos) != 1 { - t.Fatalf("Should have returned 1 utxoIDs") - } - if utxoID := utxos[0]; utxoID != utxo.InputID() { - t.Fatalf("Returned wrong utxoID") - } - if err := state.DeleteUTXO(utxo.InputID()); err != nil { - t.Fatal(err) - } + require.NoError(t, err) + require.Len(t, utxos, 1) + require.Equal(t, utxo.InputID(), utxos[0]) + + state.DeleteUTXO(utxo.InputID()) + require.NoError(t, state.Commit()) + utxos, err = state.UTXOIDs([]byte{0}, ids.Empty, math.MaxInt32) - if err != nil { - t.Fatal(err) - } - if len(utxos) != 0 { - t.Fatalf("Should have returned 0 utxoIDs") - } + require.NoError(t, err) + require.Empty(t, utxos) } diff --git a/vms/avm/states/diff.go b/vms/avm/states/diff.go new file mode 100644 index 000000000000..17031bdf2c6f --- /dev/null +++ b/vms/avm/states/diff.go @@ -0,0 +1,168 @@ +// Copyright (C) 2019-2022, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package states + +import ( + "errors" + "fmt" + "time" + + "github.com/ava-labs/avalanchego/database" + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/vms/avm/blocks" + "github.com/ava-labs/avalanchego/vms/avm/txs" + "github.com/ava-labs/avalanchego/vms/components/avax" +) + +var ( + _ Diff = (*diff)(nil) + + ErrMissingParentState = errors.New("missing parent state") +) + +type Diff interface { + Chain + + Apply(Chain) +} + +type diff struct { + parentID ids.ID + stateVersions Versions + + // map of modified UTXOID -> *UTXO if the UTXO is nil, it has been removed + modifiedUTXOs map[ids.ID]*avax.UTXO + addedTxs map[ids.ID]*txs.Tx // map of txID -> tx + addedBlockIDs map[uint64]ids.ID // map of height -> blockID + addedBlocks map[ids.ID]blocks.Block // map of blockID -> block + + lastAccepted ids.ID + timestamp time.Time +} + +func NewDiff( + parentID ids.ID, + stateVersions Versions, +) (Diff, error) { + parentState, ok := stateVersions.GetState(parentID) + if !ok { + return nil, fmt.Errorf("%w: %s", ErrMissingParentState, parentID) + } + return &diff{ + parentID: parentID, + stateVersions: stateVersions, + modifiedUTXOs: make(map[ids.ID]*avax.UTXO), + addedTxs: make(map[ids.ID]*txs.Tx), + addedBlockIDs: make(map[uint64]ids.ID), + addedBlocks: make(map[ids.ID]blocks.Block), + lastAccepted: parentState.GetLastAccepted(), + timestamp: parentState.GetTimestamp(), + }, nil +} + +func (d *diff) GetUTXO(utxoID ids.ID) (*avax.UTXO, error) { + if utxo, modified := d.modifiedUTXOs[utxoID]; modified { + if utxo == nil { + return nil, database.ErrNotFound + } + return utxo, nil + } + + parentState, ok := d.stateVersions.GetState(d.parentID) + if !ok { + return nil, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID) + } + return parentState.GetUTXO(utxoID) +} + +func (d *diff) AddUTXO(utxo *avax.UTXO) { + d.modifiedUTXOs[utxo.InputID()] = utxo +} + +func (d *diff) DeleteUTXO(utxoID ids.ID) { + d.modifiedUTXOs[utxoID] = nil +} + +func (d *diff) GetTx(txID ids.ID) (*txs.Tx, error) { + if tx, exists := d.addedTxs[txID]; exists { + return tx, nil + } + + 
parentState, ok := d.stateVersions.GetState(d.parentID) + if !ok { + return nil, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID) + } + return parentState.GetTx(txID) +} + +func (d *diff) AddTx(tx *txs.Tx) { + d.addedTxs[tx.ID()] = tx +} + +func (d *diff) GetBlockID(height uint64) (ids.ID, error) { + if blkID, exists := d.addedBlockIDs[height]; exists { + return blkID, nil + } + + parentState, ok := d.stateVersions.GetState(d.parentID) + if !ok { + return ids.Empty, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID) + } + return parentState.GetBlockID(height) +} + +func (d *diff) GetBlock(blkID ids.ID) (blocks.Block, error) { + if blk, exists := d.addedBlocks[blkID]; exists { + return blk, nil + } + + parentState, ok := d.stateVersions.GetState(d.parentID) + if !ok { + return nil, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID) + } + return parentState.GetBlock(blkID) +} + +func (d *diff) AddBlock(blk blocks.Block) { + blkID := blk.ID() + d.addedBlockIDs[blk.Height()] = blkID + d.addedBlocks[blkID] = blk +} + +func (d *diff) GetLastAccepted() ids.ID { + return d.lastAccepted +} + +func (d *diff) SetLastAccepted(lastAccepted ids.ID) { + d.lastAccepted = lastAccepted +} + +func (d *diff) GetTimestamp() time.Time { + return d.timestamp +} + +func (d *diff) SetTimestamp(t time.Time) { + d.timestamp = t +} + +func (d *diff) Apply(state Chain) { + for utxoID, utxo := range d.modifiedUTXOs { + if utxo != nil { + state.AddUTXO(utxo) + } else { + state.DeleteUTXO(utxoID) + } + } + + for _, tx := range d.addedTxs { + state.AddTx(tx) + } + + for _, blk := range d.addedBlocks { + state.AddBlock(blk) + } + + state.SetLastAccepted(d.lastAccepted) + state.SetTimestamp(d.timestamp) +} diff --git a/vms/avm/states/state.go b/vms/avm/states/state.go index 8064765c47cd..5cf5feedd887 100644 --- a/vms/avm/states/state.go +++ b/vms/avm/states/state.go @@ -4,60 +4,566 @@ package states import ( + "fmt" + "time" + "github.com/prometheus/client_golang/prometheus" + "github.com/ava-labs/avalanchego/cache" + "github.com/ava-labs/avalanchego/cache/metercacher" "github.com/ava-labs/avalanchego/database" "github.com/ava-labs/avalanchego/database/prefixdb" + "github.com/ava-labs/avalanchego/database/versiondb" + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow/choices" + "github.com/ava-labs/avalanchego/utils/wrappers" + "github.com/ava-labs/avalanchego/vms/avm/blocks" "github.com/ava-labs/avalanchego/vms/avm/txs" "github.com/ava-labs/avalanchego/vms/components/avax" ) +const ( + statusCacheSize = 8192 + txCacheSize = 8192 + blockIDCacheSize = 8192 + blockCacheSize = 2048 +) + var ( utxoPrefix = []byte("utxo") statusPrefix = []byte("status") - singletonPrefix = []byte("singleton") txPrefix = []byte("tx") + blockIDPrefix = []byte("blockID") + blockPrefix = []byte("block") + singletonPrefix = []byte("singleton") + + isInitializedKey = []byte{0x00} + timestampKey = []byte{0x01} + lastAcceptedKey = []byte{0x02} _ State = (*state)(nil) ) +type Chain interface { + GetUTXO(utxoID ids.ID) (*avax.UTXO, error) + AddUTXO(utxo *avax.UTXO) + DeleteUTXO(utxoID ids.ID) + + GetTx(txID ids.ID) (*txs.Tx, error) + AddTx(tx *txs.Tx) + + GetBlockID(height uint64) (ids.ID, error) + GetBlock(blkID ids.ID) (blocks.Block, error) + AddBlock(block blocks.Block) + + GetLastAccepted() ids.ID + SetLastAccepted(blkID ids.ID) + + GetTimestamp() time.Time + SetTimestamp(t time.Time) +} + // State persistently maintains a set of UTXOs, transaction, statuses, and // singletons. 
type State interface { - avax.UTXOState - avax.StatusState - avax.SingletonState - TxState + Chain + avax.UTXOReader + + IsInitialized() (bool, error) + SetInitialized() error + + // InitializeChainState is called after the VM has been linearized. Calling + // [GetLastAccepted] or [GetTimestamp] before calling this function will + // return uninitialized data. + // + // Invariant: After the chain is linearized, this function is expected to be + // called during startup. + InitializeChainState(stopVertexID ids.ID, genesisTimestamp time.Time) error + + // TODO: deprecate statuses. We should only persist accepted state + // Status returns a status from storage. + GetStatus(id ids.ID) (choices.Status, error) + // AddStatus saves a status in storage. + AddStatus(id ids.ID, status choices.Status) + + // Discard uncommitted changes to the database. + Abort() + + // Commit changes to the base database. + Commit() error + + // Returns a batch of unwritten changes that, when written, will commit all + // pending changes to the base database. + CommitBatch() (database.Batch, error) + + Close() error } +/* + * VMDB + * |- utxos + * | '-- utxoDB + * |- statuses + * | '-- statusDB + * |-. txs + * | '-- txID -> tx bytes + * |-. blockIDs + * | '-- height -> blockID + * |-. blocks + * | '-- blockID -> block bytes + * '-. singletons + * |-- initializedKey -> nil + * |-- timestampKey -> timestamp + * '-- lastAcceptedKey -> lastAccepted + */ type state struct { - avax.UTXOState - avax.StatusState - avax.SingletonState - TxState + parser blocks.Parser + db *versiondb.Database + + modifiedUTXOs map[ids.ID]*avax.UTXO // map of modified UTXOID -> *UTXO if the UTXO is nil, it has been removed + utxoDB database.Database + utxoState avax.UTXOState + + addedStatuses map[ids.ID]choices.Status + statusCache cache.Cacher[ids.ID, *choices.Status] // cache of id -> choices.Status. If the entry is nil, it is not in the database + statusDB database.Database + + addedTxs map[ids.ID]*txs.Tx // map of txID -> *txs.Tx + txCache cache.Cacher[ids.ID, *txs.Tx] // cache of txID -> *txs.Tx. If the entry is nil, it is not in the database + txDB database.Database + + addedBlockIDs map[uint64]ids.ID // map of height -> blockID + blockIDCache cache.Cacher[uint64, ids.ID] // cache of height -> blockID. If the entry is ids.Empty, it is not in the database + blockIDDB database.Database + + addedBlocks map[ids.ID]blocks.Block // map of blockID -> Block + blockCache cache.Cacher[ids.ID, blocks.Block] // cache of blockID -> Block. If the entry is nil, it is not in the database + blockDB database.Database + + // [lastAccepted] is the most recently accepted block. 
+ lastAccepted, persistedLastAccepted ids.ID + timestamp, persistedTimestamp time.Time + singletonDB database.Database } -func New(db database.Database, parser txs.Parser, metrics prometheus.Registerer) (State, error) { +func New( + db *versiondb.Database, + parser blocks.Parser, + metrics prometheus.Registerer, +) (State, error) { utxoDB := prefixdb.New(utxoPrefix, db) statusDB := prefixdb.New(statusPrefix, db) - singletonDB := prefixdb.New(singletonPrefix, db) txDB := prefixdb.New(txPrefix, db) + blockIDDB := prefixdb.New(blockIDPrefix, db) + blockDB := prefixdb.New(blockPrefix, db) + singletonDB := prefixdb.New(singletonPrefix, db) - utxoState, err := avax.NewMeteredUTXOState(utxoDB, parser.Codec(), metrics) + statusCache, err := metercacher.New[ids.ID, *choices.Status]( + "status_cache", + metrics, + &cache.LRU[ids.ID, *choices.Status]{Size: statusCacheSize}, + ) + if err != nil { + return nil, err + } + + txCache, err := metercacher.New[ids.ID, *txs.Tx]( + "tx_cache", + metrics, + &cache.LRU[ids.ID, *txs.Tx]{Size: txCacheSize}, + ) if err != nil { return nil, err } - statusState, err := avax.NewMeteredStatusState(statusDB, metrics) + blockIDCache, err := metercacher.New[uint64, ids.ID]( + "block_id_cache", + metrics, + &cache.LRU[uint64, ids.ID]{Size: blockIDCacheSize}, + ) if err != nil { return nil, err } - txState, err := NewTxState(txDB, parser, metrics) + blockCache, err := metercacher.New[ids.ID, blocks.Block]( + "block_cache", + metrics, + &cache.LRU[ids.ID, blocks.Block]{Size: blockCacheSize}, + ) + if err != nil { + return nil, err + } + + utxoState, err := avax.NewMeteredUTXOState(utxoDB, parser.Codec(), metrics) return &state{ - UTXOState: utxoState, - StatusState: statusState, - SingletonState: avax.NewSingletonState(singletonDB), - TxState: txState, + parser: parser, + db: db, + + modifiedUTXOs: make(map[ids.ID]*avax.UTXO), + utxoDB: utxoDB, + utxoState: utxoState, + + addedStatuses: make(map[ids.ID]choices.Status), + statusCache: statusCache, + statusDB: statusDB, + + addedTxs: make(map[ids.ID]*txs.Tx), + txCache: txCache, + txDB: txDB, + + addedBlockIDs: make(map[uint64]ids.ID), + blockIDCache: blockIDCache, + blockIDDB: blockIDDB, + + addedBlocks: make(map[ids.ID]blocks.Block), + blockCache: blockCache, + blockDB: blockDB, + + singletonDB: singletonDB, }, err } + +func (s *state) GetUTXO(utxoID ids.ID) (*avax.UTXO, error) { + if utxo, exists := s.modifiedUTXOs[utxoID]; exists { + if utxo == nil { + return nil, database.ErrNotFound + } + return utxo, nil + } + return s.utxoState.GetUTXO(utxoID) +} + +func (s *state) UTXOIDs(addr []byte, start ids.ID, limit int) ([]ids.ID, error) { + return s.utxoState.UTXOIDs(addr, start, limit) +} + +func (s *state) AddUTXO(utxo *avax.UTXO) { + s.modifiedUTXOs[utxo.InputID()] = utxo +} + +func (s *state) DeleteUTXO(utxoID ids.ID) { + s.modifiedUTXOs[utxoID] = nil +} + +func (s *state) GetTx(txID ids.ID) (*txs.Tx, error) { + if tx, exists := s.addedTxs[txID]; exists { + return tx, nil + } + if tx, exists := s.txCache.Get(txID); exists { + if tx == nil { + return nil, database.ErrNotFound + } + return tx, nil + } + + txBytes, err := s.txDB.Get(txID[:]) + if err == database.ErrNotFound { + s.txCache.Put(txID, nil) + return nil, database.ErrNotFound + } + if err != nil { + return nil, err + } + + // The key was in the database + tx, err := s.parser.ParseGenesisTx(txBytes) + if err != nil { + return nil, err + } + + s.txCache.Put(txID, tx) + return tx, nil +} + +func (s *state) AddTx(tx *txs.Tx) { + s.addedTxs[tx.ID()] = tx +} + +func (s 
*state) GetBlockID(height uint64) (ids.ID, error) { + if blkID, exists := s.addedBlockIDs[height]; exists { + return blkID, nil + } + if blkID, cached := s.blockIDCache.Get(height); cached { + if blkID == ids.Empty { + return ids.Empty, database.ErrNotFound + } + + return blkID, nil + } + + heightKey := database.PackUInt64(height) + + blkID, err := database.GetID(s.blockIDDB, heightKey) + if err == database.ErrNotFound { + s.blockIDCache.Put(height, ids.Empty) + return ids.Empty, database.ErrNotFound + } + if err != nil { + return ids.Empty, err + } + + s.blockIDCache.Put(height, blkID) + return blkID, nil +} + +func (s *state) GetBlock(blkID ids.ID) (blocks.Block, error) { + if blk, exists := s.addedBlocks[blkID]; exists { + return blk, nil + } + if blk, cached := s.blockCache.Get(blkID); cached { + if blk == nil { + return nil, database.ErrNotFound + } + + return blk, nil + } + + blkBytes, err := s.blockDB.Get(blkID[:]) + if err == database.ErrNotFound { + s.blockCache.Put(blkID, nil) + return nil, database.ErrNotFound + } + if err != nil { + return nil, err + } + + blk, err := s.parser.ParseBlock(blkBytes) + if err != nil { + return nil, err + } + + s.blockCache.Put(blkID, blk) + return blk, nil +} + +func (s *state) AddBlock(block blocks.Block) { + blkID := block.ID() + s.addedBlockIDs[block.Height()] = blkID + s.addedBlocks[blkID] = block +} + +func (s *state) InitializeChainState(stopVertexID ids.ID, genesisTimestamp time.Time) error { + lastAccepted, err := database.GetID(s.singletonDB, lastAcceptedKey) + if err == database.ErrNotFound { + return s.initializeChainState(stopVertexID, genesisTimestamp) + } else if err != nil { + return err + } + s.lastAccepted = lastAccepted + s.persistedLastAccepted = lastAccepted + s.timestamp, err = database.GetTimestamp(s.singletonDB, timestampKey) + s.persistedTimestamp = s.timestamp + return err +} + +func (s *state) initializeChainState(stopVertexID ids.ID, genesisTimestamp time.Time) error { + genesis, err := blocks.NewStandardBlock( + stopVertexID, + 0, + genesisTimestamp, + nil, + s.parser.Codec(), + ) + if err != nil { + return err + } + + s.SetLastAccepted(genesis.ID()) + s.SetTimestamp(genesis.Timestamp()) + s.AddBlock(genesis) + return s.Commit() +} + +func (s *state) IsInitialized() (bool, error) { + return s.singletonDB.Has(isInitializedKey) +} + +func (s *state) SetInitialized() error { + return s.singletonDB.Put(isInitializedKey, nil) +} + +func (s *state) GetLastAccepted() ids.ID { + return s.lastAccepted +} + +func (s *state) SetLastAccepted(lastAccepted ids.ID) { + s.lastAccepted = lastAccepted +} + +func (s *state) GetTimestamp() time.Time { + return s.timestamp +} + +func (s *state) SetTimestamp(t time.Time) { + s.timestamp = t +} + +// TODO: remove status support +func (s *state) GetStatus(id ids.ID) (choices.Status, error) { + if status, exists := s.addedStatuses[id]; exists { + return status, nil + } + if status, found := s.statusCache.Get(id); found { + if status == nil { + return choices.Unknown, database.ErrNotFound + } + return *status, nil + } + + val, err := database.GetUInt32(s.statusDB, id[:]) + if err == database.ErrNotFound { + s.statusCache.Put(id, nil) + return choices.Unknown, database.ErrNotFound + } + if err != nil { + return choices.Unknown, err + } + + status := choices.Status(val) + if err := status.Valid(); err != nil { + return choices.Unknown, err + } + + s.statusCache.Put(id, &status) + return status, nil +} + +// TODO: remove status support +func (s *state) AddStatus(id ids.ID, status choices.Status) { 
+ s.addedStatuses[id] = status +} + +func (s *state) Commit() error { + defer s.Abort() + batch, err := s.CommitBatch() + if err != nil { + return err + } + return batch.Write() +} + +func (s *state) Abort() { + s.db.Abort() +} + +func (s *state) CommitBatch() (database.Batch, error) { + if err := s.write(); err != nil { + return nil, err + } + return s.db.CommitBatch() +} + +func (s *state) Close() error { + errs := wrappers.Errs{} + errs.Add( + s.utxoDB.Close(), + s.statusDB.Close(), + s.txDB.Close(), + s.blockIDDB.Close(), + s.blockDB.Close(), + s.singletonDB.Close(), + s.db.Close(), + ) + return errs.Err +} + +func (s *state) write() error { + errs := wrappers.Errs{} + errs.Add( + s.writeUTXOs(), + s.writeTxs(), + s.writeBlockIDs(), + s.writeBlocks(), + s.writeMetadata(), + s.writeStatuses(), + ) + return errs.Err +} + +func (s *state) writeUTXOs() error { + for utxoID, utxo := range s.modifiedUTXOs { + delete(s.modifiedUTXOs, utxoID) + + if utxo != nil { + if err := s.utxoState.PutUTXO(utxo); err != nil { + return fmt.Errorf("failed to add utxo: %w", err) + } + } else { + if err := s.utxoState.DeleteUTXO(utxoID); err != nil { + return fmt.Errorf("failed to remove utxo: %w", err) + } + } + } + return nil +} + +func (s *state) writeTxs() error { + for txID, tx := range s.addedTxs { + txID := txID + txBytes := tx.Bytes() + + delete(s.addedTxs, txID) + s.txCache.Put(txID, tx) + if err := s.txDB.Put(txID[:], txBytes); err != nil { + return fmt.Errorf("failed to add tx: %w", err) + } + } + return nil +} + +func (s *state) writeBlockIDs() error { + for height, blkID := range s.addedBlockIDs { + heightKey := database.PackUInt64(height) + + delete(s.addedBlockIDs, height) + s.blockIDCache.Put(height, blkID) + if err := database.PutID(s.blockIDDB, heightKey, blkID); err != nil { + return fmt.Errorf("failed to add blockID: %w", err) + } + } + return nil +} + +func (s *state) writeBlocks() error { + for blkID, blk := range s.addedBlocks { + blkID := blkID + blkBytes := blk.Bytes() + + delete(s.addedBlocks, blkID) + s.blockCache.Put(blkID, blk) + if err := s.blockDB.Put(blkID[:], blkBytes); err != nil { + return fmt.Errorf("failed to add block: %w", err) + } + } + return nil +} + +func (s *state) writeMetadata() error { + if !s.persistedTimestamp.Equal(s.timestamp) { + if err := database.PutTimestamp(s.singletonDB, timestampKey, s.timestamp); err != nil { + return fmt.Errorf("failed to write timestamp: %w", err) + } + s.persistedTimestamp = s.timestamp + } + if s.persistedLastAccepted != s.lastAccepted { + if err := database.PutID(s.singletonDB, lastAcceptedKey, s.lastAccepted); err != nil { + return fmt.Errorf("failed to write last accepted: %w", err) + } + s.persistedLastAccepted = s.lastAccepted + } + return nil +} + +func (s *state) writeStatuses() error { + for id, status := range s.addedStatuses { + id := id + status := status + + delete(s.addedStatuses, id) + s.statusCache.Put(id, &status) + if err := database.PutUInt32(s.statusDB, id[:], uint32(status)); err != nil { + return fmt.Errorf("failed to add status: %w", err) + } + } + return nil +} diff --git a/vms/avm/states/state_test.go b/vms/avm/states/state_test.go new file mode 100644 index 000000000000..beea732d889e --- /dev/null +++ b/vms/avm/states/state_test.go @@ -0,0 +1,314 @@ +// Copyright (C) 2019-2022, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
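With that, states.state buffers every mutation (AddUTXO, AddTx, AddBlock, AddStatus) in in-memory maps and only persists on Commit, which flushes through write() into the versioned database while warming the read caches. A distilled, stdlib-only sketch of this write-back pattern; store, Add, and Commit are illustrative names, not the real API:

package main

import "fmt"

// store sketches the write-back pattern: mutations buffer in a pending map
// and only reach the backing database on Commit. A nil value marks deletion.
type store struct {
	pending map[string][]byte // uncommitted writes
	cache   map[string][]byte // read cache, warmed on commit
	db      map[string][]byte // stand-in for the versioned database
}

func newStore() *store {
	return &store{
		pending: map[string][]byte{},
		cache:   map[string][]byte{},
		db:      map[string][]byte{},
	}
}

// Add buffers a write; nothing is persisted yet.
func (s *store) Add(key string, val []byte) { s.pending[key] = val }

// Get checks pending writes first, then the cache, then the database,
// matching the lookup order in the real GetUTXO/GetTx/GetBlock.
func (s *store) Get(key string) ([]byte, bool) {
	if v, ok := s.pending[key]; ok {
		return v, v != nil
	}
	if v, ok := s.cache[key]; ok {
		return v, v != nil
	}
	v, ok := s.db[key]
	return v, ok
}

// Commit drains the pending map into the database, warming the cache the
// way state.write() does for each buffered map.
func (s *store) Commit() {
	for k, v := range s.pending {
		delete(s.pending, k)
		s.cache[k] = v
		if v == nil {
			delete(s.db, k)
			continue
		}
		s.db[k] = v
	}
}

// Abort drops uncommitted writes, mirroring versiondb.Abort.
func (s *store) Abort() { s.pending = map[string][]byte{} }

func main() {
	s := newStore()
	s.Add("utxo", []byte{1})
	_, ok := s.Get("utxo")
	fmt.Println("visible before commit:", ok) // true: pending wins
	s.Commit()
	fmt.Println("entries in db after commit:", len(s.db)) // 1
}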
+ +package states + +import ( + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/stretchr/testify/require" + + "github.com/ava-labs/avalanchego/database" + "github.com/ava-labs/avalanchego/database/memdb" + "github.com/ava-labs/avalanchego/database/versiondb" + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/version" + "github.com/ava-labs/avalanchego/vms/avm/blocks" + "github.com/ava-labs/avalanchego/vms/avm/fxs" + "github.com/ava-labs/avalanchego/vms/avm/txs" + "github.com/ava-labs/avalanchego/vms/components/avax" + "github.com/ava-labs/avalanchego/vms/secp256k1fx" +) + +var ( + parser blocks.Parser + populatedUTXO *avax.UTXO + populatedUTXOID ids.ID + populatedTx *txs.Tx + populatedTxID ids.ID + populatedBlk blocks.Block + populatedBlkHeight uint64 + populatedBlkID ids.ID +) + +func init() { + var err error + parser, err = blocks.NewParser([]fxs.Fx{ + &secp256k1fx.Fx{}, + }) + if err != nil { + panic(err) + } + + populatedUTXO = &avax.UTXO{ + UTXOID: avax.UTXOID{ + TxID: ids.GenerateTestID(), + }, + Asset: avax.Asset{ + ID: ids.GenerateTestID(), + }, + Out: &secp256k1fx.TransferOutput{ + Amt: 1, + }, + } + populatedUTXOID = populatedUTXO.InputID() + + populatedTx = &txs.Tx{Unsigned: &txs.BaseTx{BaseTx: avax.BaseTx{ + BlockchainID: ids.GenerateTestID(), + }}} + err = parser.InitializeTx(populatedTx) + if err != nil { + panic(err) + } + populatedTxID = populatedTx.ID() + + populatedBlk, err = blocks.NewStandardBlock( + ids.GenerateTestID(), + 1, + time.Now(), + []*txs.Tx{ + { + Unsigned: &txs.BaseTx{BaseTx: avax.BaseTx{ + BlockchainID: ids.GenerateTestID(), + }}, + }, + }, + parser.Codec(), + ) + if err != nil { + panic(err) + } + populatedBlkHeight = populatedBlk.Height() + populatedBlkID = populatedBlk.ID() +} + +type versions struct { + chains map[ids.ID]Chain +} + +func (v *versions) GetState(blkID ids.ID) (Chain, bool) { + c, ok := v.chains[blkID] + return c, ok +} + +func TestState(t *testing.T) { + db := memdb.New() + vdb := versiondb.New(db) + s, err := New(vdb, parser, prometheus.NewRegistry()) + require.NoError(t, err) + + s.AddUTXO(populatedUTXO) + s.AddTx(populatedTx) + s.AddBlock(populatedBlk) + require.NoError(t, s.Commit()) + + s, err = New(vdb, parser, prometheus.NewRegistry()) + require.NoError(t, err) + + ChainUTXOTest(t, s) + ChainTxTest(t, s) + ChainBlockTest(t, s) +} + +func TestDiff(t *testing.T) { + db := memdb.New() + vdb := versiondb.New(db) + s, err := New(vdb, parser, prometheus.NewRegistry()) + require.NoError(t, err) + + s.AddUTXO(populatedUTXO) + s.AddTx(populatedTx) + s.AddBlock(populatedBlk) + require.NoError(t, s.Commit()) + + parentID := ids.GenerateTestID() + d, err := NewDiff(parentID, &versions{ + chains: map[ids.ID]Chain{ + parentID: s, + }, + }) + require.NoError(t, err) + + ChainUTXOTest(t, d) + ChainTxTest(t, d) + ChainBlockTest(t, d) +} + +func ChainUTXOTest(t *testing.T, c Chain) { + require := require.New(t) + + fetchedUTXO, err := c.GetUTXO(populatedUTXOID) + require.NoError(err) + + // Compare IDs because [fetchedUTXO] isn't initialized + require.Equal(populatedUTXO.InputID(), fetchedUTXO.InputID()) + + utxo := &avax.UTXO{ + UTXOID: avax.UTXOID{ + TxID: ids.GenerateTestID(), + }, + Asset: avax.Asset{ + ID: ids.GenerateTestID(), + }, + Out: &secp256k1fx.TransferOutput{ + Amt: 1, + }, + } + utxoID := utxo.InputID() + + _, err = c.GetUTXO(utxoID) + require.ErrorIs(err, database.ErrNotFound) + + c.AddUTXO(utxo) + + fetchedUTXO, err = c.GetUTXO(utxoID) + require.NoError(err) + 
require.Equal(utxo, fetchedUTXO) + + c.DeleteUTXO(utxoID) + + _, err = c.GetUTXO(utxoID) + require.ErrorIs(err, database.ErrNotFound) +} + +func ChainTxTest(t *testing.T, c Chain) { + require := require.New(t) + + fetchedTx, err := c.GetTx(populatedTxID) + require.NoError(err) + + // Compare IDs because [fetchedTx] differs between nil and empty fields + require.Equal(populatedTx.ID(), fetchedTx.ID()) + + // Pull again for the cached path + fetchedTx, err = c.GetTx(populatedTxID) + require.NoError(err) + require.Equal(populatedTx.ID(), fetchedTx.ID()) + + tx := &txs.Tx{Unsigned: &txs.BaseTx{BaseTx: avax.BaseTx{ + BlockchainID: ids.GenerateTestID(), + }}} + require.NoError(parser.InitializeTx(tx)) + txID := tx.ID() + + _, err = c.GetTx(txID) + require.ErrorIs(err, database.ErrNotFound) + + // Pull again for the cached path + _, err = c.GetTx(txID) + require.ErrorIs(err, database.ErrNotFound) + + c.AddTx(tx) + + fetchedTx, err = c.GetTx(txID) + require.NoError(err) + require.Equal(tx, fetchedTx) +} + +func ChainBlockTest(t *testing.T, c Chain) { + require := require.New(t) + + fetchedBlkID, err := c.GetBlockID(populatedBlkHeight) + require.NoError(err) + require.Equal(populatedBlkID, fetchedBlkID) + + fetchedBlk, err := c.GetBlock(populatedBlkID) + require.NoError(err) + require.Equal(populatedBlk.ID(), fetchedBlk.ID()) + + // Pull again for the cached path + fetchedBlkID, err = c.GetBlockID(populatedBlkHeight) + require.NoError(err) + require.Equal(populatedBlkID, fetchedBlkID) + + fetchedBlk, err = c.GetBlock(populatedBlkID) + require.NoError(err) + require.Equal(populatedBlk.ID(), fetchedBlk.ID()) + + blk, err := blocks.NewStandardBlock( + ids.GenerateTestID(), + 10, + time.Now(), + []*txs.Tx{ + { + Unsigned: &txs.BaseTx{BaseTx: avax.BaseTx{ + BlockchainID: ids.GenerateTestID(), + }}, + }, + }, + parser.Codec(), + ) + if err != nil { + panic(err) + } + blkID := blk.ID() + blkHeight := blk.Height() + + _, err = c.GetBlockID(blkHeight) + require.ErrorIs(err, database.ErrNotFound) + + _, err = c.GetBlock(blkID) + require.ErrorIs(err, database.ErrNotFound) + + // Pull again for the cached path + _, err = c.GetBlockID(blkHeight) + require.ErrorIs(err, database.ErrNotFound) + + _, err = c.GetBlock(blkID) + require.ErrorIs(err, database.ErrNotFound) + + c.AddBlock(blk) + + fetchedBlkID, err = c.GetBlockID(blkHeight) + require.NoError(err) + require.Equal(blkID, fetchedBlkID) + + fetchedBlk, err = c.GetBlock(blkID) + require.NoError(err) + require.Equal(blk, fetchedBlk) +} + +func TestInitializeChainState(t *testing.T) { + require := require.New(t) + + db := memdb.New() + vdb := versiondb.New(db) + s, err := New(vdb, parser, prometheus.NewRegistry()) + require.NoError(err) + + stopVertexID := ids.GenerateTestID() + genesisTimestamp := version.XChainMigrationDefaultTime + err = s.InitializeChainState(stopVertexID, genesisTimestamp) + require.NoError(err) + + lastAcceptedID := s.GetLastAccepted() + genesis, err := s.GetBlock(lastAcceptedID) + require.NoError(err) + require.Equal(stopVertexID, genesis.Parent()) + require.Equal(genesisTimestamp.UnixNano(), genesis.Timestamp().UnixNano()) + + childBlock, err := blocks.NewStandardBlock( + genesis.ID(), + genesis.Height()+1, + genesisTimestamp, + nil, + parser.Codec(), + ) + require.NoError(err) + + s.AddBlock(childBlock) + s.SetLastAccepted(childBlock.ID()) + err = s.Commit() + require.NoError(err) + + err = s.InitializeChainState(stopVertexID, genesisTimestamp) + require.NoError(err) + + lastAcceptedID = s.GetLastAccepted() + lastAccepted, err := 
s.GetBlock(lastAcceptedID) + require.NoError(err) + require.Equal(genesis.ID(), lastAccepted.Parent()) +} diff --git a/vms/avm/states/tx_state.go b/vms/avm/states/tx_state.go deleted file mode 100644 index 5b830221133d..000000000000 --- a/vms/avm/states/tx_state.go +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright (C) 2019-2022, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package states - -import ( - "github.com/prometheus/client_golang/prometheus" - - "github.com/ava-labs/avalanchego/cache" - "github.com/ava-labs/avalanchego/cache/metercacher" - "github.com/ava-labs/avalanchego/database" - "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/vms/avm/txs" -) - -const txCacheSize = 8192 - -var _ TxState = (*txState)(nil) - -// TxState is a thin wrapper around a database to provide, caching, -// serialization, and de-serialization of transactions. -type TxState interface { - // Tx attempts to load a transaction from storage. - GetTx(txID ids.ID) (*txs.Tx, error) - - // PutTx saves the provided transaction to storage. - PutTx(txID ids.ID, tx *txs.Tx) error - - // DeleteTx removes the provided transaction from storage. - DeleteTx(txID ids.ID) error -} - -type txState struct { - parser txs.Parser - - // Caches TxID -> *Tx. If the *Tx is nil, that means the tx is not in - // storage. - txCache cache.Cacher - txDB database.Database -} - -func NewTxState(db database.Database, parser txs.Parser, metrics prometheus.Registerer) (TxState, error) { - cache, err := metercacher.New( - "tx_cache", - metrics, - &cache.LRU{Size: txCacheSize}, - ) - return &txState{ - parser: parser, - - txCache: cache, - txDB: db, - }, err -} - -func (s *txState) GetTx(txID ids.ID) (*txs.Tx, error) { - if txIntf, found := s.txCache.Get(txID); found { - if txIntf == nil { - return nil, database.ErrNotFound - } - return txIntf.(*txs.Tx), nil - } - - txBytes, err := s.txDB.Get(txID[:]) - if err == database.ErrNotFound { - s.txCache.Put(txID, nil) - return nil, database.ErrNotFound - } - if err != nil { - return nil, err - } - - // The key was in the database - tx, err := s.parser.ParseGenesisTx(txBytes) - if err != nil { - return nil, err - } - - s.txCache.Put(txID, tx) - return tx, nil -} - -func (s *txState) PutTx(txID ids.ID, tx *txs.Tx) error { - s.txCache.Put(txID, tx) - return s.txDB.Put(txID[:], tx.Bytes()) -} - -func (s *txState) DeleteTx(txID ids.ID) error { - s.txCache.Put(txID, nil) - return s.txDB.Delete(txID[:]) -} diff --git a/vms/avm/states/tx_state_test.go b/vms/avm/states/tx_state_test.go deleted file mode 100644 index d6d072690830..000000000000 --- a/vms/avm/states/tx_state_test.go +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright (C) 2019-2022, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. 
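The txState being deleted above already used the negative-caching trick that the unified state keeps for GetTx, GetBlock, GetBlockID, and GetStatus: a database miss is cached as a nil entry, so repeated lookups for an unknown ID never touch disk again. A self-contained sketch of that read path, with hypothetical getter names:

package main

import (
	"errors"
	"fmt"
)

var errNotFound = errors.New("not found")

// getter caches misses as nil pointers so a second lookup for the same
// absent key is answered without a database read.
type getter struct {
	cache   map[string]*string // nil entry = known-absent
	db      map[string]string
	dbReads int
}

func (g *getter) Get(key string) (string, error) {
	if v, ok := g.cache[key]; ok {
		if v == nil {
			return "", errNotFound // miss served from cache
		}
		return *v, nil
	}

	g.dbReads++
	v, ok := g.db[key]
	if !ok {
		g.cache[key] = nil // remember the miss
		return "", errNotFound
	}
	g.cache[key] = &v
	return v, nil
}

func main() {
	g := &getter{cache: map[string]*string{}, db: map[string]string{}}
	_, _ = g.Get("missing")
	_, _ = g.Get("missing")
	fmt.Println("db reads for two misses:", g.dbReads) // 1
}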
- -package states - -import ( - "testing" - - "github.com/prometheus/client_golang/prometheus" - - "github.com/stretchr/testify/require" - - "github.com/ava-labs/avalanchego/database" - "github.com/ava-labs/avalanchego/database/memdb" - "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/utils/crypto" - "github.com/ava-labs/avalanchego/utils/units" - "github.com/ava-labs/avalanchego/vms/avm/fxs" - "github.com/ava-labs/avalanchego/vms/avm/txs" - "github.com/ava-labs/avalanchego/vms/components/avax" - "github.com/ava-labs/avalanchego/vms/nftfx" - "github.com/ava-labs/avalanchego/vms/propertyfx" - "github.com/ava-labs/avalanchego/vms/secp256k1fx" -) - -var ( - networkID uint32 = 10 - chainID = ids.ID{5, 4, 3, 2, 1} - assetID = ids.ID{1, 2, 3} - keys = crypto.BuildTestKeys() -) - -func TestTxState(t *testing.T) { - require := require.New(t) - - db := memdb.New() - parser, err := txs.NewParser([]fxs.Fx{ - &secp256k1fx.Fx{}, - &nftfx.Fx{}, - &propertyfx.Fx{}, - }) - require.NoError(err) - - stateIntf, err := NewTxState(db, parser, prometheus.NewRegistry()) - require.NoError(err) - - s := stateIntf.(*txState) - - _, err = s.GetTx(ids.Empty) - require.Equal(database.ErrNotFound, err) - - tx := &txs.Tx{ - Unsigned: &txs.BaseTx{ - BaseTx: avax.BaseTx{ - NetworkID: networkID, - BlockchainID: chainID, - Ins: []*avax.TransferableInput{{ - UTXOID: avax.UTXOID{ - TxID: ids.Empty, - OutputIndex: 0, - }, - Asset: avax.Asset{ID: assetID}, - In: &secp256k1fx.TransferInput{ - Amt: 20 * units.KiloAvax, - Input: secp256k1fx.Input{ - SigIndices: []uint32{ - 0, - }, - }, - }, - }}, - }, - }, - } - - err = tx.SignSECP256K1Fx(parser.Codec(), [][]*crypto.PrivateKeySECP256K1R{{keys[0]}}) - require.NoError(err) - - err = s.PutTx(ids.Empty, tx) - require.NoError(err) - - loadedTx, err := s.GetTx(ids.Empty) - require.NoError(err) - require.Equal(tx.ID(), loadedTx.ID()) - - s.txCache.Flush() - - loadedTx, err = s.GetTx(ids.Empty) - require.NoError(err) - require.Equal(tx.ID(), loadedTx.ID()) - - err = s.DeleteTx(ids.Empty) - require.NoError(err) - - _, err = s.GetTx(ids.Empty) - require.Equal(database.ErrNotFound, err) - - s.txCache.Flush() - - _, err = s.GetTx(ids.Empty) - require.Equal(database.ErrNotFound, err) -} diff --git a/vms/avm/states/versions.go b/vms/avm/states/versions.go new file mode 100644 index 000000000000..19461b7b7b17 --- /dev/null +++ b/vms/avm/states/versions.go @@ -0,0 +1,14 @@ +// Copyright (C) 2019-2022, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package states + +import ( + "github.com/ava-labs/avalanchego/ids" +) + +type Versions interface { + // GetState returns the state of the chain after [blkID] has been accepted. + // If the state is not known, `false` will be returned. + GetState(blkID ids.ID) (Chain, bool) +} diff --git a/vms/avm/unique_tx.go b/vms/avm/unique_tx.go index 9ea6db54fb3f..cc44b7ddcf04 100644 --- a/vms/avm/unique_tx.go +++ b/vms/avm/unique_tx.go @@ -27,8 +27,8 @@ var ( ) var ( - _ snowstorm.Tx = (*UniqueTx)(nil) - _ cache.Evictable = (*UniqueTx)(nil) + _ snowstorm.Tx = (*UniqueTx)(nil) + _ cache.Evictable[ids.ID] = (*UniqueTx)(nil) ) // UniqueTx provides a de-duplication service for txs. 
This only provides a @@ -104,13 +104,12 @@ func (tx *UniqueTx) Evict() { tx.deps = nil } -func (tx *UniqueTx) setStatus(status choices.Status) error { +func (tx *UniqueTx) setStatus(status choices.Status) { tx.refresh() - if tx.status == status { - return nil + if tx.status != status { + tx.status = status + tx.vm.state.AddStatus(tx.ID(), status) } - tx.status = status - return tx.vm.state.PutStatus(tx.ID(), status) } // ID returns the wrapped txID @@ -118,7 +117,7 @@ func (tx *UniqueTx) ID() ids.ID { return tx.txID } -func (tx *UniqueTx) Key() interface{} { +func (tx *UniqueTx) Key() ids.ID { return tx.txID } @@ -129,7 +128,6 @@ func (tx *UniqueTx) Accept(context.Context) error { } txID := tx.ID() - defer tx.vm.db.Abort() // Fetch the input UTXOs inputUTXOIDs := tx.InputUTXOs() @@ -162,26 +160,20 @@ func (tx *UniqueTx) Accept(context.Context) error { continue } utxoID := utxo.InputID() - if err := tx.vm.state.DeleteUTXO(utxoID); err != nil { - return fmt.Errorf("couldn't delete UTXO %s: %w", utxoID, err) - } + tx.vm.state.DeleteUTXO(utxoID) } // Add new utxos for _, utxo := range outputUTXOs { - if err := tx.vm.state.PutUTXO(utxo); err != nil { - return fmt.Errorf("couldn't put UTXO %s: %w", utxo.InputID(), err) - } + tx.vm.state.AddUTXO(utxo) } + tx.setStatus(choices.Accepted) - if err := tx.setStatus(choices.Accepted); err != nil { - return fmt.Errorf("couldn't set status of tx %s: %w", txID, err) - } - - commitBatch, err := tx.vm.db.CommitBatch() + commitBatch, err := tx.vm.state.CommitBatch() if err != nil { return fmt.Errorf("couldn't create commitBatch while processing tx %s: %w", txID, err) } + defer tx.vm.state.Abort() err = tx.Tx.Unsigned.Visit(&executeTx{ tx: tx.Tx, batch: commitBatch, @@ -201,22 +193,14 @@ func (tx *UniqueTx) Accept(context.Context) error { // Reject is called when the transaction was finalized as rejected by consensus func (tx *UniqueTx) Reject(context.Context) error { - defer tx.vm.db.Abort() - - if err := tx.setStatus(choices.Rejected); err != nil { - tx.vm.ctx.Log.Error("failed to reject tx", - zap.Stringer("txID", tx.txID), - zap.Error(err), - ) - return err - } + tx.setStatus(choices.Rejected) txID := tx.ID() tx.vm.ctx.Log.Debug("rejecting tx", zap.Stringer("txID", txID), ) - if err := tx.vm.db.Commit(); err != nil { + if err := tx.vm.state.Commit(); err != nil { tx.vm.ctx.Log.Error("failed to commit reject", zap.Stringer("txID", tx.txID), zap.Error(err), @@ -227,7 +211,6 @@ func (tx *UniqueTx) Reject(context.Context) error { tx.vm.walletService.decided(txID) tx.deps = nil // Needed to prevent a memory leak - return nil } diff --git a/vms/avm/vm.go b/vms/avm/vm.go index ad7723cf7160..b2411ddb6197 100644 --- a/vms/avm/vm.go +++ b/vms/avm/vm.go @@ -36,7 +36,9 @@ import ( "github.com/ava-labs/avalanchego/utils/set" "github.com/ava-labs/avalanchego/utils/timer" "github.com/ava-labs/avalanchego/utils/timer/mockable" + "github.com/ava-labs/avalanchego/utils/wrappers" "github.com/ava-labs/avalanchego/version" + "github.com/ava-labs/avalanchego/vms/avm/blocks" "github.com/ava-labs/avalanchego/vms/avm/states" "github.com/ava-labs/avalanchego/vms/avm/txs" "github.com/ava-labs/avalanchego/vms/components/avax" @@ -81,7 +83,7 @@ type VM struct { // Used to check local time clock mockable.Clock - parser txs.Parser + parser blocks.Parser pubsub *pubsub.Server @@ -95,7 +97,7 @@ type VM struct { feeAssetID ids.ID // Asset ID --> Bit set with fx IDs the asset supports - assetToFxCache *cache.LRU + assetToFxCache *cache.LRU[ids.ID, set.Bits64] // Transaction issuing timer 
*timer.Timer @@ -113,7 +115,7 @@ type VM struct { addressTxsIndexer index.AddressTxsIndexer - uniqueTxs cache.Deduplicator + uniqueTxs cache.Deduplicator[ids.ID, *UniqueTx] } func (*VM) Connected(context.Context, ids.NodeID, *version.Application) error { @@ -173,7 +175,7 @@ func (vm *VM) Initialize( vm.toEngine = toEngine vm.baseDB = db vm.db = versiondb.New(db) - vm.assetToFxCache = &cache.LRU{Size: assetToFxCacheSize} + vm.assetToFxCache = &cache.LRU[ids.ID, set.Bits64]{Size: assetToFxCacheSize} vm.pubsub = pubsub.New(ctx.Log) @@ -195,7 +197,7 @@ func (vm *VM) Initialize( } vm.typeToFxIndex = map[reflect.Type]int{} - vm.parser, err = txs.NewCustomParser( + vm.parser, err = blocks.NewCustomParser( vm.typeToFxIndex, &vm.clock, ctx.Log, @@ -227,7 +229,7 @@ func (vm *VM) Initialize( go ctx.Log.RecoverAndPanic(vm.timer.Dispatch) vm.batchTimeout = batchTimeout - vm.uniqueTxs = &cache.EvictableLRU{ + vm.uniqueTxs = &cache.EvictableLRU[ids.ID, *UniqueTx]{ Size: txDeduplicatorSize, } vm.walletService.vm = vm @@ -248,7 +250,7 @@ func (vm *VM) Initialize( return fmt.Errorf("failed to initialize disabled indexer: %w", err) } } - return vm.db.Commit() + return vm.state.Commit() } // onBootstrapStarted is called by the consensus engine when it starts bootstrapping this chain @@ -293,7 +295,12 @@ func (vm *VM) Shutdown(context.Context) error { vm.timer.Stop() vm.ctx.Lock.Lock() - return vm.baseDB.Close() + errs := wrappers.Errs{} + errs.Add( + vm.state.Close(), + vm.baseDB.Close(), + ) + return errs.Err } func (*VM) Version(context.Context) (string, error) { @@ -373,8 +380,9 @@ func (*VM) LastAccepted(context.Context) (ids.ID, error) { ****************************************************************************** */ -func (*VM) Linearize(context.Context, ids.ID) error { - return errUnimplemented +func (vm *VM) Linearize(_ context.Context, stopVertexID ids.ID) error { + time := version.GetXChainMigrationTime(vm.ctx.NetworkID) + return vm.state.InitializeChainState(stopVertexID, time) } func (vm *VM) PendingTxs(context.Context) []snowstorm.Tx { @@ -478,10 +486,10 @@ func (vm *VM) initGenesis(genesisBytes []byte) error { return errGenesisAssetMustHaveState } - tx := txs.Tx{ + tx := &txs.Tx{ Unsigned: &genesisTx.CreateAssetTx, } - if err := vm.parser.InitializeGenesisTx(&tx); err != nil { + if err := vm.parser.InitializeGenesisTx(tx); err != nil { return err } @@ -491,9 +499,7 @@ func (vm *VM) initGenesis(genesisBytes []byte) error { } if !stateInitialized { - if err := vm.initState(tx); err != nil { - return err - } + vm.initState(tx) } if index == 0 { vm.ctx.Log.Info("fee asset is established", @@ -511,23 +517,16 @@ func (vm *VM) initGenesis(genesisBytes []byte) error { return nil } -func (vm *VM) initState(tx txs.Tx) error { +func (vm *VM) initState(tx *txs.Tx) { txID := tx.ID() vm.ctx.Log.Info("initializing genesis asset", zap.Stringer("txID", txID), ) - if err := vm.state.PutTx(txID, &tx); err != nil { - return err - } - if err := vm.state.PutStatus(txID, choices.Accepted); err != nil { - return err - } + vm.state.AddTx(tx) + vm.state.AddStatus(txID, choices.Accepted) for _, utxo := range tx.UTXOs() { - if err := vm.state.PutUTXO(utxo); err != nil { - return err - } + vm.state.AddUTXO(utxo) } - return nil } func (vm *VM) parseTx(bytes []byte) (*UniqueTx, error) { @@ -548,13 +547,9 @@ func (vm *VM) parseTx(bytes []byte) (*UniqueTx, error) { } if tx.Status() == choices.Unknown { - if err := vm.state.PutTx(tx.ID(), tx.Tx); err != nil { - return nil, err - } - if err := tx.setStatus(choices.Processing); 
err != nil { - return nil, err - } - return tx, vm.db.Commit() + vm.state.AddTx(tx.Tx) + tx.setStatus(choices.Processing) + return tx, vm.state.Commit() } return tx, nil @@ -607,9 +602,8 @@ func (vm *VM) getFx(val interface{}) (int, error) { func (vm *VM) verifyFxUsage(fxID int, assetID ids.ID) bool { // Check cache to see whether this asset supports this fx - fxIDsIntf, assetInCache := vm.assetToFxCache.Get(assetID) - if assetInCache { - return fxIDsIntf.(set.Bits64).Contains(uint(fxID)) + if fxIDs, ok := vm.assetToFxCache.Get(assetID); ok { + return fxIDs.Contains(uint(fxID)) } // Caches doesn't say whether this asset support this fx. // Get the tx that created the asset and check. @@ -1121,5 +1115,5 @@ func (*VM) AppGossip(context.Context, ids.NodeID, []byte) error { // UniqueTx de-duplicates the transaction. func (vm *VM) DeduplicateTx(tx *UniqueTx) *UniqueTx { - return vm.uniqueTxs.Deduplicate(tx).(*UniqueTx) + return vm.uniqueTxs.Deduplicate(tx) } diff --git a/vms/avm/vm_benchmark_test.go b/vms/avm/vm_benchmark_test.go index 3fb1fa7b8255..cf1170f940e5 100644 --- a/vms/avm/vm_benchmark_test.go +++ b/vms/avm/vm_benchmark_test.go @@ -9,6 +9,8 @@ import ( "math/rand" "testing" + "github.com/stretchr/testify/require" + "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils/set" "github.com/ava-labs/avalanchego/vms/components/avax" @@ -96,31 +98,20 @@ func GetAllUTXOsBenchmark(b *testing.B, utxoCount int) { }, } - if err := vm.state.PutUTXO(utxo); err != nil { - b.Fatal(err) - } + vm.state.AddUTXO(utxo) } + require.NoError(b, vm.state.Commit()) addrsSet := set.Set[ids.ShortID]{} addrsSet.Add(addr) - var ( - err error - notPaginatedUTXOs []*avax.UTXO - ) - b.ResetTimer() for i := 0; i < b.N; i++ { // Fetch all UTXOs older version - notPaginatedUTXOs, err = avax.GetAllUTXOs(vm.state, addrsSet) - if err != nil { - b.Fatal(err) - } - - if len(notPaginatedUTXOs) != utxoCount { - b.Fatalf("Wrong number of utxos. 
Expected (%d) returned (%d)", utxoCount, len(notPaginatedUTXOs)) - } + notPaginatedUTXOs, err := avax.GetAllUTXOs(vm.state, addrsSet) + require.NoError(b, err) + require.Len(b, notPaginatedUTXOs, utxoCount) } } diff --git a/vms/avm/vm_test.go b/vms/avm/vm_test.go index e79e26e1f498..0ffdc034c345 100644 --- a/vms/avm/vm_test.go +++ b/vms/avm/vm_test.go @@ -19,9 +19,11 @@ import ( "github.com/ava-labs/avalanchego/api/keystore" "github.com/ava-labs/avalanchego/chains/atomic" + "github.com/ava-labs/avalanchego/database" "github.com/ava-labs/avalanchego/database/manager" - "github.com/ava-labs/avalanchego/database/mockdb" + "github.com/ava-labs/avalanchego/database/memdb" "github.com/ava-labs/avalanchego/database/prefixdb" + "github.com/ava-labs/avalanchego/database/versiondb" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow" "github.com/ava-labs/avalanchego/snow/engine/common" @@ -61,7 +63,6 @@ var ( otherAssetName = "OTHER" errMissing = errors.New("missing") - errTest = errors.New("non-nil error") ) func init() { @@ -1114,24 +1115,22 @@ func TestTxCached(t *testing.T) { _, err := vm.ParseTx(context.Background(), txBytes) require.NoError(t, err) - db := mockdb.New() - called := new(bool) - db.OnGet = func([]byte) ([]byte, error) { - *called = true - return nil, errTest - } - registerer := prometheus.NewRegistry() err = vm.metrics.Initialize("", registerer) require.NoError(t, err) - vm.state, err = states.New(prefixdb.New([]byte("tx"), db), vm.parser, registerer) + db := memdb.New() + vdb := versiondb.New(db) + vm.state, err = states.New(vdb, vm.parser, registerer) require.NoError(t, err) _, err = vm.ParseTx(context.Background(), txBytes) require.NoError(t, err) - require.False(t, *called, "shouldn't have called the DB") + + count, err := database.Count(vdb) + require.NoError(t, err) + require.Zero(t, count) } func TestTxNotCached(t *testing.T) { @@ -1150,30 +1149,25 @@ func TestTxNotCached(t *testing.T) { _, err := vm.ParseTx(context.Background(), txBytes) require.NoError(t, err) - db := mockdb.New() - called := new(bool) - db.OnGet = func([]byte) ([]byte, error) { - *called = true - return nil, errTest - } - db.OnPut = func([]byte, []byte) error { - return nil - } - registerer := prometheus.NewRegistry() require.NoError(t, err) err = vm.metrics.Initialize("", registerer) require.NoError(t, err) - vm.state, err = states.New(db, vm.parser, registerer) + db := memdb.New() + vdb := versiondb.New(db) + vm.state, err = states.New(vdb, vm.parser, registerer) require.NoError(t, err) vm.uniqueTxs.Flush() _, err = vm.ParseTx(context.Background(), txBytes) require.NoError(t, err) - require.True(t, *called, "should have called the DB") + + count, err := database.Count(vdb) + require.NoError(t, err) + require.NotZero(t, count) } func TestTxVerifyAfterIssueTx(t *testing.T) { diff --git a/vms/components/avax/singleton_state.go b/vms/components/avax/singleton_state.go deleted file mode 100644 index 62d069c95d63..000000000000 --- a/vms/components/avax/singleton_state.go +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (C) 2019-2022, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package avax - -import ( - "github.com/ava-labs/avalanchego/database" -) - -const ( - IsInitializedKey byte = iota -) - -var ( - isInitializedKey = []byte{IsInitializedKey} - _ SingletonState = (*singletonState)(nil) -) - -// SingletonState is a thin wrapper around a database to provide, caching, -// serialization, and de-serialization of singletons. 
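The rewritten tests above swap the callback-counting mockdb for a real memdb wrapped in a versiondb, then assert with database.Count whether anything was persisted. That works because a version layer stages writes until Commit, which is also why the VM code in this diff can call state.Commit()/state.Abort() instead of managing the database directly. A minimal sketch of those semantics, assuming versiondb behaves the way this diff uses it (staged writes; Commit flushes; Abort discards):

package main

import "fmt"

// versioned is a sketch of a write-buffering layer: Put stages changes in
// memory, Commit flushes them to the base map, Abort drops them. It is a
// stand-in for versiondb, not its implementation.
type versioned struct {
	base    map[string]string
	pending map[string]string
}

func newVersioned(base map[string]string) *versioned {
	return &versioned{base: base, pending: map[string]string{}}
}

func (v *versioned) Put(k, val string) { v.pending[k] = val }

func (v *versioned) Commit() {
	for k, val := range v.pending {
		v.base[k] = val
	}
	v.pending = map[string]string{}
}

func (v *versioned) Abort() { v.pending = map[string]string{} }

func main() {
	base := map[string]string{}
	v := newVersioned(base)

	v.Put("status", "accepted")
	fmt.Println(len(base)) // 0: nothing is visible below until Commit

	v.Abort() // discard the staged write
	v.Put("status", "accepted")
	v.Commit()
	fmt.Println(base["status"]) // accepted
}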
-type SingletonState interface { - IsInitialized() (bool, error) - SetInitialized() error -} - -type singletonState struct { - singletonDB database.Database -} - -func NewSingletonState(db database.Database) SingletonState { - return &singletonState{ - singletonDB: db, - } -} - -func (s *singletonState) IsInitialized() (bool, error) { - return s.singletonDB.Has(isInitializedKey) -} - -func (s *singletonState) SetInitialized() error { - return s.singletonDB.Put(isInitializedKey, nil) -} diff --git a/vms/components/avax/singleton_state_test.go b/vms/components/avax/singleton_state_test.go deleted file mode 100644 index 5b8e97dccd88..000000000000 --- a/vms/components/avax/singleton_state_test.go +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (C) 2019-2022, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package avax - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "github.com/ava-labs/avalanchego/database/memdb" -) - -func TestSingletonState(t *testing.T) { - require := require.New(t) - - db := memdb.New() - s := NewSingletonState(db) - - isInitialized, err := s.IsInitialized() - require.NoError(err) - require.False(isInitialized) - - err = s.SetInitialized() - require.NoError(err) - - isInitialized, err = s.IsInitialized() - require.NoError(err) - require.True(isInitialized) -} diff --git a/vms/components/avax/status_state.go b/vms/components/avax/status_state.go deleted file mode 100644 index acbc8474a26b..000000000000 --- a/vms/components/avax/status_state.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright (C) 2019-2022, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package avax - -import ( - "github.com/prometheus/client_golang/prometheus" - - "github.com/ava-labs/avalanchego/cache" - "github.com/ava-labs/avalanchego/cache/metercacher" - "github.com/ava-labs/avalanchego/database" - "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/snow/choices" -) - -const ( - statusCacheSize = 8192 -) - -// StatusState is a thin wrapper around a database to provide, caching, -// serialization, and de-serialization for statuses. -type StatusState interface { - // Status returns a status from storage. - GetStatus(id ids.ID) (choices.Status, error) - - // PutStatus saves a status in storage. - PutStatus(id ids.ID, status choices.Status) error - - // DeleteStatus removes a status from storage. - DeleteStatus(id ids.ID) error -} - -type statusState struct { - // ID -> Status of thing with that ID, or nil if StatusState doesn't have - // that status. 
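The StatusState deleted below illustrates a convention that survives in the caches elsewhere in this diff: a cached nil entry means "known to be absent", so repeated lookups of a missing key never touch the database again (later comments phrase it as "if the entry is nil, it is not in the database"). A small self-contained sketch of that negative-caching pattern:

package main

import (
	"errors"
	"fmt"
)

var errNotFound = errors.New("not found")

// lookup memoizes both hits and misses: a nil pointer in the cache records
// a known-absent key, so the backing store is consulted at most once per key.
type lookup struct {
	cache map[string]*string // nil value: known to be absent
	disk  map[string]string
	reads int // backing-store reads, counted for illustration
}

func (l *lookup) get(key string) (string, error) {
	if v, ok := l.cache[key]; ok {
		if v == nil {
			return "", errNotFound // cached miss
		}
		return *v, nil // cached hit
	}
	l.reads++
	v, ok := l.disk[key]
	if !ok {
		l.cache[key] = nil // remember the miss
		return "", errNotFound
	}
	l.cache[key] = &v
	return v, nil
}

func main() {
	l := &lookup{cache: map[string]*string{}, disk: map[string]string{}}
	_, _ = l.get("missing")
	_, _ = l.get("missing")
	fmt.Println(l.reads) // 1: the second miss is served from the cache
}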
- statusCache cache.Cacher - statusDB database.Database -} - -func NewStatusState(db database.Database) StatusState { - return &statusState{ - statusCache: &cache.LRU{Size: statusCacheSize}, - statusDB: db, - } -} - -func NewMeteredStatusState(db database.Database, metrics prometheus.Registerer) (StatusState, error) { - cache, err := metercacher.New( - "status_cache", - metrics, - &cache.LRU{Size: statusCacheSize}, - ) - return &statusState{ - statusCache: cache, - statusDB: db, - }, err -} - -func (s *statusState) GetStatus(id ids.ID) (choices.Status, error) { - if statusIntf, found := s.statusCache.Get(id); found { - if statusIntf == nil { - return choices.Unknown, database.ErrNotFound - } - return statusIntf.(choices.Status), nil - } - - val, err := database.GetUInt32(s.statusDB, id[:]) - if err == database.ErrNotFound { - s.statusCache.Put(id, nil) - return choices.Unknown, database.ErrNotFound - } - if err != nil { - return choices.Unknown, err - } - - status := choices.Status(val) - if err := status.Valid(); err != nil { - return choices.Unknown, err - } - - s.statusCache.Put(id, status) - return status, nil -} - -func (s *statusState) PutStatus(id ids.ID, status choices.Status) error { - s.statusCache.Put(id, status) - return database.PutUInt32(s.statusDB, id[:], uint32(status)) -} - -func (s *statusState) DeleteStatus(id ids.ID) error { - s.statusCache.Put(id, nil) - return s.statusDB.Delete(id[:]) -} diff --git a/vms/components/avax/status_state_test.go b/vms/components/avax/status_state_test.go deleted file mode 100644 index 955cb8372b64..000000000000 --- a/vms/components/avax/status_state_test.go +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (C) 2019-2022, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package avax - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "github.com/ava-labs/avalanchego/database" - "github.com/ava-labs/avalanchego/database/memdb" - "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/snow/choices" -) - -func TestStatusState(t *testing.T) { - require := require.New(t) - id0 := ids.GenerateTestID() - - db := memdb.New() - s := NewStatusState(db) - - _, err := s.GetStatus(id0) - require.Equal(database.ErrNotFound, err) - - _, err = s.GetStatus(id0) - require.Equal(database.ErrNotFound, err) - - err = s.PutStatus(id0, choices.Accepted) - require.NoError(err) - - status, err := s.GetStatus(id0) - require.NoError(err) - require.Equal(choices.Accepted, status) - - err = s.DeleteStatus(id0) - require.NoError(err) - - _, err = s.GetStatus(id0) - require.Equal(database.ErrNotFound, err) - - err = s.PutStatus(id0, choices.Accepted) - require.NoError(err) - - s = NewStatusState(db) - - status, err = s.GetStatus(id0) - require.NoError(err) - require.Equal(choices.Accepted, status) -} diff --git a/vms/components/avax/utxo_state.go b/vms/components/avax/utxo_state.go index f8a7aed1cdce..560d02530b43 100644 --- a/vms/components/avax/utxo_state.go +++ b/vms/components/avax/utxo_state.go @@ -63,39 +63,39 @@ type utxoState struct { codec codec.Manager // UTXO ID -> *UTXO. 
If the *UTXO is nil the UTXO doesn't exist - utxoCache cache.Cacher + utxoCache cache.Cacher[ids.ID, *UTXO] utxoDB database.Database indexDB database.Database - indexCache cache.Cacher + indexCache cache.Cacher[string, linkeddb.LinkedDB] } func NewUTXOState(db database.Database, codec codec.Manager) UTXOState { return &utxoState{ codec: codec, - utxoCache: &cache.LRU{Size: utxoCacheSize}, + utxoCache: &cache.LRU[ids.ID, *UTXO]{Size: utxoCacheSize}, utxoDB: prefixdb.New(utxoPrefix, db), indexDB: prefixdb.New(indexPrefix, db), - indexCache: &cache.LRU{Size: indexCacheSize}, + indexCache: &cache.LRU[string, linkeddb.LinkedDB]{Size: indexCacheSize}, } } func NewMeteredUTXOState(db database.Database, codec codec.Manager, metrics prometheus.Registerer) (UTXOState, error) { - utxoCache, err := metercacher.New( + utxoCache, err := metercacher.New[ids.ID, *UTXO]( "utxo_cache", metrics, - &cache.LRU{Size: utxoCacheSize}, + &cache.LRU[ids.ID, *UTXO]{Size: utxoCacheSize}, ) if err != nil { return nil, err } - indexCache, err := metercacher.New( + indexCache, err := metercacher.New[string, linkeddb.LinkedDB]( "index_cache", metrics, - &cache.LRU{ + &cache.LRU[string, linkeddb.LinkedDB]{ Size: indexCacheSize, }, ) @@ -111,11 +111,11 @@ func NewMeteredUTXOState(db database.Database, codec codec.Manager, metrics prom } func (s *utxoState) GetUTXO(utxoID ids.ID) (*UTXO, error) { - if utxoIntf, found := s.utxoCache.Get(utxoID); found { - if utxoIntf == nil { + if utxo, found := s.utxoCache.Get(utxoID); found { + if utxo == nil { return nil, database.ErrNotFound } - return utxoIntf.(*UTXO), nil + return utxo, nil } bytes, err := s.utxoDB.Get(utxoID[:]) @@ -214,7 +214,7 @@ func (s *utxoState) UTXOIDs(addr []byte, start ids.ID, limit int) ([]ids.ID, err func (s *utxoState) getIndexDB(addr []byte) linkeddb.LinkedDB { addrStr := string(addr) if indexList, exists := s.indexCache.Get(addrStr); exists { - return indexList.(linkeddb.LinkedDB) + return indexList } indexDB := prefixdb.NewNested(addr, s.indexDB) diff --git a/vms/components/chain/state.go b/vms/components/chain/state.go index 7a24ca803964..9fecc567a812 100644 --- a/vms/components/chain/state.go +++ b/vms/components/chain/state.go @@ -43,17 +43,14 @@ type State struct { // therefore currently in consensus. verifiedBlocks map[ids.ID]*BlockWrapper // decidedBlocks is an LRU cache of decided blocks. - // Every value in [decidedBlocks] is a (*BlockWrapper) - decidedBlocks cache.Cacher + decidedBlocks cache.Cacher[ids.ID, *BlockWrapper] // unverifiedBlocks is an LRU cache of blocks with status processing // that have not yet passed verification. - // Every value in [unverifiedBlocks] is a (*BlockWrapper) - unverifiedBlocks cache.Cacher + unverifiedBlocks cache.Cacher[ids.ID, *BlockWrapper] // missingBlocks is an LRU cache of missing blocks - // Every value in [missingBlocks] is an empty struct. - missingBlocks cache.Cacher + missingBlocks cache.Cacher[ids.ID, struct{}] // string([byte repr. 
of block]) --> the block's ID - bytesToIDCache cache.Cacher + bytesToIDCache cache.Cacher[string, ids.ID] lastAcceptedBlock *BlockWrapper } @@ -141,10 +138,10 @@ func (s *State) initialize(config *Config) { func NewState(config *Config) *State { c := &State{ verifiedBlocks: make(map[ids.ID]*BlockWrapper), - decidedBlocks: &cache.LRU{Size: config.DecidedCacheSize}, - missingBlocks: &cache.LRU{Size: config.MissingCacheSize}, - unverifiedBlocks: &cache.LRU{Size: config.UnverifiedCacheSize}, - bytesToIDCache: &cache.LRU{Size: config.BytesToIDCacheSize}, + decidedBlocks: &cache.LRU[ids.ID, *BlockWrapper]{Size: config.DecidedCacheSize}, + missingBlocks: &cache.LRU[ids.ID, struct{}]{Size: config.MissingCacheSize}, + unverifiedBlocks: &cache.LRU[ids.ID, *BlockWrapper]{Size: config.UnverifiedCacheSize}, + bytesToIDCache: &cache.LRU[string, ids.ID]{Size: config.BytesToIDCacheSize}, } c.initialize(config) return c @@ -154,34 +151,34 @@ func NewMeteredState( registerer prometheus.Registerer, config *Config, ) (*State, error) { - decidedCache, err := metercacher.New( + decidedCache, err := metercacher.New[ids.ID, *BlockWrapper]( "decided_cache", registerer, - &cache.LRU{Size: config.DecidedCacheSize}, + &cache.LRU[ids.ID, *BlockWrapper]{Size: config.DecidedCacheSize}, ) if err != nil { return nil, err } - missingCache, err := metercacher.New( + missingCache, err := metercacher.New[ids.ID, struct{}]( "missing_cache", registerer, - &cache.LRU{Size: config.MissingCacheSize}, + &cache.LRU[ids.ID, struct{}]{Size: config.MissingCacheSize}, ) if err != nil { return nil, err } - unverifiedCache, err := metercacher.New( + unverifiedCache, err := metercacher.New[ids.ID, *BlockWrapper]( "unverified_cache", registerer, - &cache.LRU{Size: config.UnverifiedCacheSize}, + &cache.LRU[ids.ID, *BlockWrapper]{Size: config.UnverifiedCacheSize}, ) if err != nil { return nil, err } - bytesToIDCache, err := metercacher.New( + bytesToIDCache, err := metercacher.New[string, ids.ID]( "bytes_to_id_cache", registerer, - &cache.LRU{Size: config.BytesToIDCacheSize}, + &cache.LRU[string, ids.ID]{Size: config.BytesToIDCacheSize}, ) if err != nil { return nil, err @@ -265,11 +262,11 @@ func (s *State) getCachedBlock(blkID ids.ID) (snowman.Block, bool) { } if blk, ok := s.decidedBlocks.Get(blkID); ok { - return blk.(snowman.Block), true + return blk, true } if blk, ok := s.unverifiedBlocks.Get(blkID); ok { - return blk.(snowman.Block), true + return blk, true } return nil, false @@ -289,11 +286,10 @@ func (s *State) GetBlockInternal(ctx context.Context, blkID ids.ID) (snowman.Blo // appropriate caching layer if successful. func (s *State) ParseBlock(ctx context.Context, b []byte) (snowman.Block, error) { // See if we've cached this block's ID by its byte repr. - blkIDIntf, blkIDCached := s.bytesToIDCache.Get(string(b)) + cachedBlkID, blkIDCached := s.bytesToIDCache.Get(string(b)) if blkIDCached { - blkID := blkIDIntf.(ids.ID) // See if we have this block cached - if cachedBlk, ok := s.getCachedBlock(blkID); ok { + if cachedBlk, ok := s.getCachedBlock(cachedBlkID); ok { return cachedBlk, nil } } @@ -334,14 +330,13 @@ func (s *State) BatchedParseBlock(ctx context.Context, blksBytes [][]byte) ([]sn unparsedBlksBytes := make([][]byte, 0, len(blksBytes)) for i, blkBytes := range blksBytes { // See if we've cached this block's ID by its byte repr. 
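ParseBlock above keys a cache by the block's raw byte string, so a previously parsed block is found by content before any deserialization happens; the generics change merely removes the ids.ID assertion on the cached value. A simplified sketch of the two-level lookup — the hash stands in for real parsing, and none of these names are from chain.State:

package main

import (
	"crypto/sha256"
	"fmt"
)

type blockID = [32]byte

// parser sketches the bytesToIDCache pattern: string(rawBytes) -> ID, then
// ID -> already-parsed block, so identical bytes are never parsed twice.
type parser struct {
	bytesToID map[string]blockID // string([]byte) -> ID of the parsed block
	byID      map[blockID]string // "parsed" blocks, keyed by ID
	parses    int
}

func (p *parser) parse(raw []byte) blockID {
	if id, ok := p.bytesToID[string(raw)]; ok {
		if _, ok := p.byID[id]; ok {
			return id // cache hit: skip deserialization entirely
		}
	}
	p.parses++
	id := blockID(sha256.Sum256(raw)) // stand-in for real deserialization
	p.bytesToID[string(raw)] = id
	p.byID[id] = string(raw)
	return id
}

func main() {
	p := &parser{bytesToID: map[string]blockID{}, byID: map[blockID]string{}}
	raw := []byte("block bytes")
	p.parse(raw)
	p.parse(raw)
	fmt.Println(p.parses) // 1
}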
- blkIDIntf, blkIDCached := s.bytesToIDCache.Get(string(blkBytes)) + blkID, blkIDCached := s.bytesToIDCache.Get(string(blkBytes)) idWasCached[i] = blkIDCached if !blkIDCached { unparsedBlksBytes = append(unparsedBlksBytes, blkBytes) continue } - blkID := blkIDIntf.(ids.ID) // See if we have this block cached if cachedBlk, ok := s.getCachedBlock(blkID); ok { blks[i] = cachedBlk diff --git a/vms/components/state/state.go b/vms/components/state/state.go index efccc0eac7c0..22c3afb3e9c9 100644 --- a/vms/components/state/state.go +++ b/vms/components/state/state.go @@ -90,7 +90,7 @@ type state struct { // Keys: Type ID // Values: Cache that stores uniqueIDs for values that were put with that type ID // (Saves us from having to re-compute uniqueIDs) - uniqueIDCaches map[uint64]*cache.LRU + uniqueIDCaches map[uint64]*cache.LRU[ids.ID, ids.ID] } func (s *state) RegisterType( @@ -211,11 +211,11 @@ func (s *state) GetTime(db database.Database, key ids.ID) (time.Time, error) { func (s *state) uniqueID(id ids.ID, typeID uint64) ids.ID { uIDCache, cacheExists := s.uniqueIDCaches[typeID] if cacheExists { - if uID, uIDExists := uIDCache.Get(id); uIDExists { // Get the uniqueID associated with [typeID] and [ID] - return uID.(ids.ID) + if uID, exists := uIDCache.Get(id); exists { // Get the uniqueID associated with [typeID] and [ID] + return uID } } else { - s.uniqueIDCaches[typeID] = &cache.LRU{Size: cacheSize} + s.uniqueIDCaches[typeID] = &cache.LRU[ids.ID, ids.ID]{Size: cacheSize} } uID := id.Prefix(typeID) s.uniqueIDCaches[typeID].Put(id, uID) @@ -227,7 +227,7 @@ func NewState() (State, error) { state := &state{ marshallers: make(map[uint64]func(interface{}) ([]byte, error)), unmarshallers: make(map[uint64]func([]byte) (interface{}, error)), - uniqueIDCaches: make(map[uint64]*cache.LRU), + uniqueIDCaches: make(map[uint64]*cache.LRU[ids.ID, ids.ID]), } // Register ID, Status and time.Time so they can be put/get without client code diff --git a/vms/platformvm/blocks/builder/builder.go b/vms/platformvm/blocks/builder/builder.go index 7c55f5c0d497..fec8180de0e5 100644 --- a/vms/platformvm/blocks/builder/builder.go +++ b/vms/platformvm/blocks/builder/builder.go @@ -273,7 +273,7 @@ func (b *builder) setNextBuildBlockTime() { ctx.Lock.Lock() defer ctx.Lock.Unlock() - if !b.txExecutorBackend.Bootstrapped.GetValue() { + if !b.txExecutorBackend.Bootstrapped.Get() { ctx.Log.Verbo("skipping block timer reset", zap.String("reason", "not bootstrapped"), ) diff --git a/vms/platformvm/blocks/builder/builder_test.go b/vms/platformvm/blocks/builder/builder_test.go index 5b14dcf50367..df6762a51c79 100644 --- a/vms/platformvm/blocks/builder/builder_test.go +++ b/vms/platformvm/blocks/builder/builder_test.go @@ -106,7 +106,7 @@ func TestPreviouslyDroppedTxsCanBeReAddedToMempool(t *testing.T) { func TestNoErrorOnUnexpectedSetPreferenceDuringBootstrapping(t *testing.T) { env := newEnvironment(t) env.ctx.Lock.Lock() - env.isBootstrapped.SetValue(false) + env.isBootstrapped.Set(false) env.ctx.Log = logging.NoWarn{} defer func() { require.NoError(t, shutdownEnvironment(env)) diff --git a/vms/platformvm/blocks/builder/helpers_test.go b/vms/platformvm/blocks/builder/helpers_test.go index 69c0cedda071..642e2e892feb 100644 --- a/vms/platformvm/blocks/builder/helpers_test.go +++ b/vms/platformvm/blocks/builder/helpers_test.go @@ -93,7 +93,7 @@ type environment struct { mempool mempool.Mempool sender *common.SenderTest - isBootstrapped *utils.AtomicBool + isBootstrapped *utils.Atomic[bool] config *config.Config clk *mockable.Clock 
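The test helpers below migrate utils.AtomicBool (GetValue/SetValue) to the generic utils.Atomic[bool] (Get/Set). A sketch of the shape such a type can take — the read/write-locked value here is an assumption about the implementation, but it matches how every call site in this diff uses it:

package main

import (
	"fmt"
	"sync"
)

// Atomic is a sketch of a generic, concurrency-safe value holder with typed
// Get/Set, replacing a bool-only wrapper.
type Atomic[T any] struct {
	lock  sync.RWMutex
	value T
}

func (a *Atomic[T]) Get() T {
	a.lock.RLock()
	defer a.lock.RUnlock()
	return a.value
}

func (a *Atomic[T]) Set(value T) {
	a.lock.Lock()
	defer a.lock.Unlock()
	a.value = value
}

func main() {
	bootstrapped := &Atomic[bool]{} // usable zero value: false
	bootstrapped.Set(true)
	fmt.Println(bootstrapped.Get()) // true
}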
baseDB *versiondb.Database @@ -110,11 +110,11 @@ type environment struct { func newEnvironment(t *testing.T) *environment { res := &environment{ - isBootstrapped: &utils.AtomicBool{}, + isBootstrapped: &utils.Atomic[bool]{}, config: defaultConfig(), clk: defaultClock(), } - res.isBootstrapped.SetValue(true) + res.isBootstrapped.Set(true) baseDBManager := manager.NewMemDB(version.Semantic1_0_0) res.baseDB = versiondb.New(baseDBManager.Current().Database) @@ -123,7 +123,7 @@ func newEnvironment(t *testing.T) *environment { res.ctx.Lock.Lock() defer res.ctx.Lock.Unlock() - res.fx = defaultFx(res.clk, res.ctx.Log, res.isBootstrapped.GetValue()) + res.fx = defaultFx(res.clk, res.ctx.Log, res.isBootstrapped.Get()) rewardsCalc := reward.NewCalculator(res.config.RewardConfig) res.state = defaultState(res.config, res.ctx, res.baseDB, rewardsCalc) @@ -436,7 +436,7 @@ func buildGenesisTest(ctx *snow.Context) []byte { } func shutdownEnvironment(env *environment) error { - if env.isBootstrapped.GetValue() { + if env.isBootstrapped.Get() { primaryValidatorSet, exist := env.config.Validators.Get(constants.PrimaryNetworkID) if !exist { return errMissingPrimaryValidators diff --git a/vms/platformvm/blocks/builder/network.go b/vms/platformvm/blocks/builder/network.go index 6f29eba83fc0..d831b1f6720b 100644 --- a/vms/platformvm/blocks/builder/network.go +++ b/vms/platformvm/blocks/builder/network.go @@ -41,7 +41,7 @@ type network struct { // gossip related attributes appSender common.AppSender - recentTxs *cache.LRU + recentTxs *cache.LRU[ids.ID, struct{}] } func NewNetwork( @@ -53,7 +53,7 @@ func NewNetwork( ctx: ctx, blkBuilder: blkBuilder, appSender: appSender, - recentTxs: &cache.LRU{Size: recentCacheSize}, + recentTxs: &cache.LRU[ids.ID, struct{}]{Size: recentCacheSize}, } } @@ -153,7 +153,7 @@ func (n *network) GossipTx(tx *txs.Tx) error { if _, has := n.recentTxs.Get(txID); has { return nil } - n.recentTxs.Put(txID, nil) + n.recentTxs.Put(txID, struct{}{}) n.ctx.Log.Debug("gossiping tx", zap.Stringer("txID", txID), diff --git a/vms/platformvm/blocks/executor/acceptor.go b/vms/platformvm/blocks/executor/acceptor.go index 34f7700877a3..23740d3c86da 100644 --- a/vms/platformvm/blocks/executor/acceptor.go +++ b/vms/platformvm/blocks/executor/acceptor.go @@ -26,7 +26,7 @@ type acceptor struct { *backend metrics metrics.Metrics recentlyAccepted window.Window[ids.ID] - bootstrapped *utils.AtomicBool + bootstrapped *utils.Atomic[bool] } func (a *acceptor) BanffAbortBlock(b *blocks.BanffAbortBlock) error { @@ -180,7 +180,7 @@ func (a *acceptor) abortBlock(b blocks.Block) error { return fmt.Errorf("%w: %s", state.ErrMissingParentState, parentID) } - if a.bootstrapped.GetValue() { + if a.bootstrapped.Get() { if parentState.initiallyPreferCommit { a.metrics.MarkOptionVoteLost() } else { @@ -198,7 +198,7 @@ func (a *acceptor) commitBlock(b blocks.Block) error { return fmt.Errorf("%w: %s", state.ErrMissingParentState, parentID) } - if a.bootstrapped.GetValue() { + if a.bootstrapped.Get() { if parentState.initiallyPreferCommit { a.metrics.MarkOptionVoteWon() } else { diff --git a/vms/platformvm/blocks/executor/acceptor_test.go b/vms/platformvm/blocks/executor/acceptor_test.go index 80a2695db049..5fdabb9e2878 100644 --- a/vms/platformvm/blocks/executor/acceptor_test.go +++ b/vms/platformvm/blocks/executor/acceptor_test.go @@ -284,7 +284,7 @@ func TestAcceptorVisitCommitBlock(t *testing.T) { MaxSize: 1, TTL: time.Hour, }), - bootstrapped: &utils.AtomicBool{}, + bootstrapped: &utils.Atomic[bool]{}, } blk, err := 
blocks.NewApricotCommitBlock(parentID, 1 /*height*/) @@ -374,7 +374,7 @@ func TestAcceptorVisitAbortBlock(t *testing.T) { MaxSize: 1, TTL: time.Hour, }), - bootstrapped: &utils.AtomicBool{}, + bootstrapped: &utils.Atomic[bool]{}, } blk, err := blocks.NewApricotAbortBlock(parentID, 1 /*height*/) diff --git a/vms/platformvm/blocks/executor/helpers_test.go b/vms/platformvm/blocks/executor/helpers_test.go index 38d24508978d..5d7dc0c5fc97 100644 --- a/vms/platformvm/blocks/executor/helpers_test.go +++ b/vms/platformvm/blocks/executor/helpers_test.go @@ -111,7 +111,7 @@ type environment struct { mempool mempool.Mempool sender *common.SenderTest - isBootstrapped *utils.AtomicBool + isBootstrapped *utils.Atomic[bool] config *config.Config clk *mockable.Clock baseDB *versiondb.Database @@ -132,16 +132,16 @@ func (*environment) ResetBlockTimer() { func newEnvironment(t *testing.T, ctrl *gomock.Controller) *environment { res := &environment{ - isBootstrapped: &utils.AtomicBool{}, + isBootstrapped: &utils.Atomic[bool]{}, config: defaultConfig(), clk: defaultClock(), } - res.isBootstrapped.SetValue(true) + res.isBootstrapped.Set(true) baseDBManager := db_manager.NewMemDB(version.Semantic1_0_0) res.baseDB = versiondb.New(baseDBManager.Current().Database) res.ctx = defaultCtx(res.baseDB) - res.fx = defaultFx(res.clk, res.ctx.Log, res.isBootstrapped.GetValue()) + res.fx = defaultFx(res.clk, res.ctx.Log, res.isBootstrapped.Get()) rewardsCalc := reward.NewCalculator(res.config.RewardConfig) res.atomicUTXOs = avax.NewAtomicUTXOManager(res.ctx.SharedMemory, txs.Codec) @@ -472,7 +472,7 @@ func shutdownEnvironment(t *environment) error { return nil } - if t.isBootstrapped.GetValue() { + if t.isBootstrapped.Get() { primaryValidatorSet, exist := t.config.Validators.Get(constants.PrimaryNetworkID) if !exist { return errMissingPrimaryValidators diff --git a/vms/platformvm/service.go b/vms/platformvm/service.go index fb64e27d3b80..54cf39f19a44 100644 --- a/vms/platformvm/service.go +++ b/vms/platformvm/service.go @@ -15,6 +15,7 @@ import ( "go.uber.org/zap" "github.com/ava-labs/avalanchego/api" + "github.com/ava-labs/avalanchego/cache" "github.com/ava-labs/avalanchego/database" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils" @@ -28,6 +29,7 @@ import ( "github.com/ava-labs/avalanchego/utils/wrappers" "github.com/ava-labs/avalanchego/vms/components/avax" "github.com/ava-labs/avalanchego/vms/components/keystore" + "github.com/ava-labs/avalanchego/vms/platformvm/fx" "github.com/ava-labs/avalanchego/vms/platformvm/reward" "github.com/ava-labs/avalanchego/vms/platformvm/signer" "github.com/ava-labs/avalanchego/vms/platformvm/stakeable" @@ -51,6 +53,10 @@ const ( // Minimum amount of delay to allow a transaction to be issued through the // API minAddStakerDelay = 2 * executor.SyncBound + + // Note: Staker attributes cache should be large enough so that no evictions + // happen when the API loops through all stakers. + stakerAttributesCacheSize = 100_000 ) var ( @@ -74,8 +80,18 @@ var ( // Service defines the API calls that can be made to the platform chain type Service struct { - vm *VM - addrManager avax.AddressManager + vm *VM + addrManager avax.AddressManager + stakerAttributesCache *cache.LRU[ids.ID, *stakerAttributes] +} + +// All attributes are optional and may not be filled for each stakerTx. 
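The stakerAttributesCache introduced below is a load-through memoization: loadStakerTxAttributes checks the cache, falls back to state on a miss, and populates the cache before returning, and the 100_000 size is chosen so a full pass over all stakers does not evict entries it will need again. A minimal sketch of the pattern, with purely illustrative names:

package main

import "fmt"

type attributes struct{ shares uint32 }

// attrLoader sketches the load-through lookup: cache first, then the slow
// path (a stand-in for GetTx plus the type switch), then memoize.
type attrLoader struct {
	cache     map[string]*attributes
	diskReads int
}

func (l *attrLoader) load(txID string) *attributes {
	if attr, ok := l.cache[txID]; ok {
		return attr // fast path: no state access
	}
	l.diskReads++ // slow path: fetch and decode the staker tx
	attr := &attributes{shares: 20_000}
	l.cache[txID] = attr
	return attr
}

func main() {
	l := &attrLoader{cache: map[string]*attributes{}}
	l.load("tx1")
	l.load("tx1")
	fmt.Println(l.diskReads) // 1: the second call is served from the cache
}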
+type stakerAttributes struct { + shares uint32 + rewardsOwner fx.Owner + validationRewardsOwner fx.Owner + delegationRewardsOwner fx.Owner + proofOfPossession *signer.ProofOfPossession } type GetHeightResponse struct { @@ -673,6 +689,48 @@ type GetCurrentValidatorsReply struct { Validators []interface{} `json:"validators"` } +func (s *Service) loadStakerTxAttributes(txID ids.ID) (*stakerAttributes, error) { + // Lookup tx from the cache first. + attr, found := s.stakerAttributesCache.Get(txID) + if found { + return attr, nil + } + + // Tx not available in cache; pull it from disk and populate the cache. + tx, _, err := s.vm.state.GetTx(txID) + if err != nil { + return nil, err + } + + switch stakerTx := tx.Unsigned.(type) { + case txs.ValidatorTx: + var pop *signer.ProofOfPossession + if staker, ok := stakerTx.(*txs.AddPermissionlessValidatorTx); ok { + if s, ok := staker.Signer.(*signer.ProofOfPossession); ok { + pop = s + } + } + + attr = &stakerAttributes{ + shares: stakerTx.Shares(), + validationRewardsOwner: stakerTx.ValidationRewardsOwner(), + delegationRewardsOwner: stakerTx.DelegationRewardsOwner(), + proofOfPossession: pop, + } + + case txs.DelegatorTx: + attr = &stakerAttributes{ + rewardsOwner: stakerTx.RewardsOwner(), + } + + default: + return nil, fmt.Errorf("unexpected staker tx type %T", tx.Unsigned) + } + + s.stakerAttributesCache.Put(txID, attr) + return attr, nil +} + // GetCurrentValidators returns current validators and delegators func (s *Service) GetCurrentValidators(_ *http.Request, args *GetCurrentValidatorsArgs, reply *GetCurrentValidatorsReply) error { s.vm.ctx.Log.Debug("Platform: GetCurrentValidators called") @@ -685,29 +743,48 @@ func (s *Service) GetCurrentValidators(_ *http.Request, args *GetCurrentValidato // Create set of nodeIDs nodeIDs := set.Set[ids.NodeID]{} nodeIDs.Add(args.NodeIDs...) - includeAllNodes := nodeIDs.Len() == 0 - - currentStakerIterator, err := s.vm.state.GetCurrentStakerIterator() - if err != nil { - return err - } - defer currentStakerIterator.Release() - // TODO: do not iterate over all stakers when nodeIDs given. 
Use currentValidators.ValidatorSet for iteration - for currentStakerIterator.Next() { // Iterates in order of increasing stop time - currentStaker := currentStakerIterator.Value() - if args.SubnetID != currentStaker.SubnetID { - continue + numNodeIDs := nodeIDs.Len() + targetStakers := make([]*state.Staker, 0, numNodeIDs) + if numNodeIDs == 0 { // Include all nodes + currentStakerIterator, err := s.vm.state.GetCurrentStakerIterator() + if err != nil { + return err } - if !includeAllNodes && !nodeIDs.Contains(currentStaker.NodeID) { - continue + for currentStakerIterator.Next() { + staker := currentStakerIterator.Value() + if args.SubnetID != staker.SubnetID { + continue + } + targetStakers = append(targetStakers, staker) } + currentStakerIterator.Release() + } else { + for nodeID := range nodeIDs { + staker, err := s.vm.state.GetCurrentValidator(args.SubnetID, nodeID) + switch err { + case nil: + case database.ErrNotFound: + // nothing to do, continue + continue + default: + return err + } + targetStakers = append(targetStakers, staker) - tx, _, err := s.vm.state.GetTx(currentStaker.TxID) - if err != nil { - return err + delegatorsIt, err := s.vm.state.GetCurrentDelegatorIterator(args.SubnetID, nodeID) + if err != nil { + return err + } + for delegatorsIt.Next() { + staker := delegatorsIt.Value() + targetStakers = append(targetStakers, staker) + } + delegatorsIt.Release() } + } + for _, currentStaker := range targetStakers { nodeID := currentStaker.NodeID weight := json.Uint64(currentStaker.Weight) apiStaker := platformapi.Staker{ @@ -717,11 +794,16 @@ func (s *Service) GetCurrentValidators(_ *http.Request, args *GetCurrentValidato StakeAmount: &weight, NodeID: nodeID, } - potentialReward := json.Uint64(currentStaker.PotentialReward) - switch staker := tx.Unsigned.(type) { - case txs.ValidatorTx: - shares := staker.Shares() + + switch currentStaker.Priority { + case txs.PrimaryNetworkValidatorCurrentPriority, txs.SubnetPermissionlessValidatorCurrentPriority: + attr, err := s.loadStakerTxAttributes(currentStaker.TxID) + if err != nil { + return err + } + + shares := attr.shares delegationFee := json.Float32(100 * float32(shares) / float32(reward.PercentDenominator)) uptime, err := s.getAPIUptime(currentStaker) @@ -734,14 +816,14 @@ func (s *Service) GetCurrentValidators(_ *http.Request, args *GetCurrentValidato validationRewardOwner *platformapi.Owner delegationRewardOwner *platformapi.Owner ) - validationOwner, ok := staker.ValidationRewardsOwner().(*secp256k1fx.OutputOwners) + validationOwner, ok := attr.validationRewardsOwner.(*secp256k1fx.OutputOwners) if ok { validationRewardOwner, err = s.getAPIOwner(validationOwner) if err != nil { return err } } - delegationOwner, ok := staker.DelegationRewardsOwner().(*secp256k1fx.OutputOwners) + delegationOwner, ok := attr.delegationRewardsOwner.(*secp256k1fx.OutputOwners) if ok { delegationRewardOwner, err = s.getAPIOwner(delegationOwner) if err != nil { @@ -758,19 +840,18 @@ func (s *Service) GetCurrentValidators(_ *http.Request, args *GetCurrentValidato ValidationRewardOwner: validationRewardOwner, DelegationRewardOwner: delegationRewardOwner, DelegationFee: delegationFee, + Signer: attr.proofOfPossession, } + reply.Validators = append(reply.Validators, vdr) - if staker, ok := staker.(*txs.AddPermissionlessValidatorTx); ok { - if signer, ok := staker.Signer.(*signer.ProofOfPossession); ok { - vdr.Signer = signer - } + case txs.PrimaryNetworkDelegatorCurrentPriority, txs.SubnetPermissionlessDelegatorCurrentPriority: + attr, err := 
s.loadStakerTxAttributes(currentStaker.TxID) + if err != nil { + return err } - reply.Validators = append(reply.Validators, vdr) - - case txs.DelegatorTx: var rewardOwner *platformapi.Owner - owner, ok := staker.RewardsOwner().(*secp256k1fx.OutputOwners) + owner, ok := attr.rewardsOwner.(*secp256k1fx.OutputOwners) if ok { rewardOwner, err = s.getAPIOwner(owner) if err != nil { @@ -784,7 +865,8 @@ func (s *Service) GetCurrentValidators(_ *http.Request, args *GetCurrentValidato PotentialReward: &potentialReward, } vdrToDelegators[delegator.NodeID] = append(vdrToDelegators[delegator.NodeID], delegator) - case *txs.AddSubnetValidatorTx: + + case txs.SubnetPermissionedValidatorCurrentPriority: uptime, err := s.getAPIUptime(currentStaker) if err != nil { return err @@ -795,8 +877,9 @@ func (s *Service) GetCurrentValidators(_ *http.Request, args *GetCurrentValidato Connected: connected, Uptime: uptime, }) + default: - return fmt.Errorf("expected validator but got %T", tx.Unsigned) + return fmt.Errorf("unexpected staker priority %d", currentStaker.Priority) } } @@ -841,28 +924,48 @@ func (s *Service) GetPendingValidators(_ *http.Request, args *GetPendingValidato // Create set of nodeIDs nodeIDs := set.Set[ids.NodeID]{} nodeIDs.Add(args.NodeIDs...) - includeAllNodes := nodeIDs.Len() == 0 - pendingStakerIterator, err := s.vm.state.GetPendingStakerIterator() - if err != nil { - return err - } - defer pendingStakerIterator.Release() - - for pendingStakerIterator.Next() { // Iterates in order of increasing start time - pendingStaker := pendingStakerIterator.Value() - if args.SubnetID != pendingStaker.SubnetID { - continue + numNodeIDs := nodeIDs.Len() + targetStakers := make([]*state.Staker, 0, numNodeIDs) + if numNodeIDs == 0 { // Include all nodes + pendingStakerIterator, err := s.vm.state.GetPendingStakerIterator() + if err != nil { + return err } - if !includeAllNodes && !nodeIDs.Contains(pendingStaker.NodeID) { - continue + for pendingStakerIterator.Next() { // Iterates in order of increasing stop time + staker := pendingStakerIterator.Value() + if args.SubnetID != staker.SubnetID { + continue + } + targetStakers = append(targetStakers, staker) } + pendingStakerIterator.Release() + } else { + for nodeID := range nodeIDs { + staker, err := s.vm.state.GetPendingValidator(args.SubnetID, nodeID) + switch err { + case nil: + case database.ErrNotFound: + // nothing to do, continue + continue + default: + return err + } + targetStakers = append(targetStakers, staker) - tx, _, err := s.vm.state.GetTx(pendingStaker.TxID) - if err != nil { - return err + delegatorsIt, err := s.vm.state.GetPendingDelegatorIterator(args.SubnetID, nodeID) + if err != nil { + return err + } + for delegatorsIt.Next() { + staker := delegatorsIt.Value() + targetStakers = append(targetStakers, staker) + } + delegatorsIt.Release() } + } + for _, pendingStaker := range targetStakers { nodeID := pendingStaker.NodeID weight := json.Uint64(pendingStaker.Weight) apiStaker := platformapi.Staker{ @@ -873,9 +976,14 @@ func (s *Service) GetPendingValidators(_ *http.Request, args *GetPendingValidato StakeAmount: &weight, } - switch staker := tx.Unsigned.(type) { - case txs.ValidatorTx: - shares := staker.Shares() + switch pendingStaker.Priority { + case txs.PrimaryNetworkValidatorPendingPriority, txs.SubnetPermissionlessValidatorPendingPriority: + attr, err := s.loadStakerTxAttributes(pendingStaker.TxID) + if err != nil { + return err + } + + shares := attr.shares delegationFee := json.Float32(100 * float32(shares) / 
float32(reward.PercentDenominator)) connected := s.vm.uptimeManager.IsConnected(nodeID, args.SubnetID) @@ -883,27 +991,22 @@ func (s *Service) GetPendingValidators(_ *http.Request, args *GetPendingValidato Staker: apiStaker, DelegationFee: delegationFee, Connected: connected, + Signer: attr.proofOfPossession, } - - if staker, ok := staker.(*txs.AddPermissionlessValidatorTx); ok { - if signer, ok := staker.Signer.(*signer.ProofOfPossession); ok { - vdr.Signer = signer - } - } - reply.Validators = append(reply.Validators, vdr) - case txs.DelegatorTx: + case txs.PrimaryNetworkDelegatorApricotPendingPriority, txs.PrimaryNetworkDelegatorBanffPendingPriority, txs.SubnetPermissionlessDelegatorPendingPriority: reply.Delegators = append(reply.Delegators, apiStaker) - case *txs.AddSubnetValidatorTx: + case txs.SubnetPermissionedValidatorPendingPriority: connected := s.vm.uptimeManager.IsConnected(nodeID, args.SubnetID) reply.Validators = append(reply.Validators, platformapi.PermissionedValidator{ Staker: apiStaker, Connected: connected, }) + default: - return fmt.Errorf("expected validator but got %T", tx.Unsigned) + return fmt.Errorf("unexpected staker priority %d", pendingStaker.Priority) } } return nil @@ -2355,7 +2458,9 @@ func (s *Service) getAPIUptime(staker *state.Staker) (*json.Float32, error) { if err != nil { return nil, err } - uptime := json.Float32(rawUptime) + // Transform this to a percentage (0-100) to make it consistent + // with observedUptime in info.peers API + uptime := json.Float32(rawUptime * 100) return &uptime, nil } diff --git a/vms/platformvm/service_test.go b/vms/platformvm/service_test.go index 09356cae4866..a455b61644c4 100644 --- a/vms/platformvm/service_test.go +++ b/vms/platformvm/service_test.go @@ -17,6 +17,7 @@ import ( "github.com/ava-labs/avalanchego/api" "github.com/ava-labs/avalanchego/api/keystore" + "github.com/ava-labs/avalanchego/cache" "github.com/ava-labs/avalanchego/chains/atomic" "github.com/ava-labs/avalanchego/database/manager" "github.com/ava-labs/avalanchego/database/prefixdb" @@ -77,6 +78,9 @@ func defaultService(t *testing.T) (*Service, *mutableSharedMemory) { return &Service{ vm: vm, addrManager: avax.NewAddressManager(vm.ctx), + stakerAttributesCache: &cache.LRU[ids.ID, *stakerAttributes]{ + Size: stakerAttributesCacheSize, + }, }, mutableSharedMemory } diff --git a/vms/platformvm/state/diff.go b/vms/platformvm/state/diff.go index 236bdf6a35a5..7b9bfe8bbef8 100644 --- a/vms/platformvm/state/diff.go +++ b/vms/platformvm/state/diff.go @@ -47,10 +47,8 @@ type diff struct { addedChains map[ids.ID][]*txs.Tx cachedChains map[ids.ID][]*txs.Tx - // map of txID -> []*UTXO addedRewardUTXOs map[ids.ID][]*avax.UTXO - // map of txID -> {*txs.Tx, Status} addedTxs map[ids.ID]*txAndStatus // map of modified UTXOID -> *UTXO if the UTXO is nil, it has been removed @@ -439,12 +437,11 @@ func (d *diff) Apply(baseState State) { } for _, subnetValidatorDiffs := range d.currentStakerDiffs.validatorDiffs { for _, validatorDiff := range subnetValidatorDiffs { - if validatorDiff.validatorModified { - if validatorDiff.validatorDeleted { - baseState.DeleteCurrentValidator(validatorDiff.validator) - } else { - baseState.PutCurrentValidator(validatorDiff.validator) - } + if validatorDiff.validatorAdded { + baseState.PutCurrentValidator(validatorDiff.validator) + } + if validatorDiff.validatorDeleted { + baseState.DeleteCurrentValidator(validatorDiff.validator) } addedDelegatorIterator := NewTreeIterator(validatorDiff.addedDelegators) @@ -460,12 +457,11 @@ func (d *diff) 
Apply(baseState State) { } for _, subnetValidatorDiffs := range d.pendingStakerDiffs.validatorDiffs { for _, validatorDiff := range subnetValidatorDiffs { - if validatorDiff.validatorModified { - if validatorDiff.validatorDeleted { - baseState.DeletePendingValidator(validatorDiff.validator) - } else { - baseState.PutPendingValidator(validatorDiff.validator) - } + if validatorDiff.validatorAdded { + baseState.PutPendingValidator(validatorDiff.validator) + } + if validatorDiff.validatorDeleted { + baseState.DeletePendingValidator(validatorDiff.validator) } addedDelegatorIterator := NewTreeIterator(validatorDiff.addedDelegators) diff --git a/vms/platformvm/state/diff_test.go b/vms/platformvm/state/diff_test.go index e4a74c9881c2..9d026668cb7a 100644 --- a/vms/platformvm/state/diff_test.go +++ b/vms/platformvm/state/diff_test.go @@ -109,6 +109,7 @@ func TestDiffCurrentValidator(t *testing.T) { d.DeleteCurrentValidator(currentValidator) // Make sure the deletion worked + state.EXPECT().GetCurrentValidator(currentValidator.SubnetID, currentValidator.NodeID).Return(nil, database.ErrNotFound).Times(1) _, err = d.GetCurrentValidator(currentValidator.SubnetID, currentValidator.NodeID) require.ErrorIs(err, database.ErrNotFound) } @@ -146,6 +147,7 @@ func TestDiffPendingValidator(t *testing.T) { d.DeletePendingValidator(pendingValidator) // Make sure the deletion worked + state.EXPECT().GetPendingValidator(pendingValidator.SubnetID, pendingValidator.NodeID).Return(nil, database.ErrNotFound).Times(1) _, err = d.GetPendingValidator(pendingValidator.SubnetID, pendingValidator.NodeID) require.ErrorIs(err, database.ErrNotFound) } diff --git a/vms/platformvm/state/staker.go b/vms/platformvm/state/staker.go index a03b5226fe83..32e91c6d8aa5 100644 --- a/vms/platformvm/state/staker.go +++ b/vms/platformvm/state/staker.go @@ -14,7 +14,7 @@ import ( "github.com/ava-labs/avalanchego/vms/platformvm/txs" ) -var _ btree.Item = (*Staker)(nil) +var _ btree.LessFunc[*Staker] = (*Staker).Less // StakerIterator defines an interface for iterating over a set of stakers. type StakerIterator interface { @@ -33,6 +33,7 @@ type StakerIterator interface { // Staker contains all information required to represent a validator or // delegator in the current and pending validator sets. +// Invariant: Staker's size is bounded to prevent OOM DoS attacks. type Staker struct { TxID ids.ID NodeID ids.NodeID @@ -63,11 +64,7 @@ type Staker struct { // lesser one. // 3. If the priorities are also the same, the one with the lesser txID is // lesser. -// -// Invariant: [thanIntf] is a *Staker. 
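The ordering rules enumerated above (earlier NextTime first, ties broken by lower priority, remaining ties by lesser txID) form a lexicographic comparison; the change below only moves it from the interface-based btree.Item to a typed receiver. A self-contained sketch of that comparator chain with simplified field types:

package main

import (
	"fmt"
	"time"
)

// staker is a stand-in with just enough fields to show the three-level
// ordering; the real Staker has more fields and different types.
type staker struct {
	TxID     string
	Priority int
	NextTime time.Time
}

func (s *staker) Less(than *staker) bool {
	if s.NextTime.Before(than.NextTime) {
		return true // rule 1: earlier NextTime is lesser
	}
	if than.NextTime.Before(s.NextTime) {
		return false
	}
	if s.Priority != than.Priority {
		return s.Priority < than.Priority // rule 2: lower priority is lesser
	}
	return s.TxID < than.TxID // rule 3: lesser txID breaks the final tie
}

func main() {
	now := time.Now()
	a := &staker{TxID: "a", Priority: 1, NextTime: now}
	b := &staker{TxID: "b", Priority: 1, NextTime: now}
	fmt.Println(a.Less(b), b.Less(a)) // true false: txID breaks the tie
}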
-func (s *Staker) Less(thanIntf btree.Item) bool { - than := thanIntf.(*Staker) - +func (s *Staker) Less(than *Staker) bool { if s.NextTime.Before(than.NextTime) { return true } diff --git a/vms/platformvm/state/stakers.go b/vms/platformvm/state/stakers.go index b3deae019741..043bf452fca0 100644 --- a/vms/platformvm/state/stakers.go +++ b/vms/platformvm/state/stakers.go @@ -90,20 +90,20 @@ type PendingStakers interface { type baseStakers struct { // subnetID --> nodeID --> current state for the validator of the subnet validators map[ids.ID]map[ids.NodeID]*baseStaker - stakers *btree.BTree + stakers *btree.BTreeG[*Staker] // subnetID --> nodeID --> diff for that validator since the last db write validatorDiffs map[ids.ID]map[ids.NodeID]*diffValidator } type baseStaker struct { validator *Staker - delegators *btree.BTree + delegators *btree.BTreeG[*Staker] } func newBaseStakers() *baseStakers { return &baseStakers{ validators: make(map[ids.ID]map[ids.NodeID]*baseStaker), - stakers: btree.New(defaultTreeDegree), + stakers: btree.NewG(defaultTreeDegree, (*Staker).Less), validatorDiffs: make(map[ids.ID]map[ids.NodeID]*diffValidator), } } @@ -128,8 +128,7 @@ func (v *baseStakers) PutValidator(staker *Staker) { validator.validator = staker validatorDiff := v.getOrCreateValidatorDiff(staker.SubnetID, staker.NodeID) - validatorDiff.validatorModified = true - validatorDiff.validatorDeleted = false + validatorDiff.validatorAdded = true validatorDiff.validator = staker v.stakers.ReplaceOrInsert(staker) @@ -141,7 +140,6 @@ func (v *baseStakers) DeleteValidator(staker *Staker) { v.pruneValidator(staker.SubnetID, staker.NodeID) validatorDiff := v.getOrCreateValidatorDiff(staker.SubnetID, staker.NodeID) - validatorDiff.validatorModified = true validatorDiff.validatorDeleted = true validatorDiff.validator = staker @@ -163,13 +161,13 @@ func (v *baseStakers) GetDelegatorIterator(subnetID ids.ID, nodeID ids.NodeID) S func (v *baseStakers) PutDelegator(staker *Staker) { validator := v.getOrCreateValidator(staker.SubnetID, staker.NodeID) if validator.delegators == nil { - validator.delegators = btree.New(defaultTreeDegree) + validator.delegators = btree.NewG(defaultTreeDegree, (*Staker).Less) } validator.delegators.ReplaceOrInsert(staker) validatorDiff := v.getOrCreateValidatorDiff(staker.SubnetID, staker.NodeID) if validatorDiff.addedDelegators == nil { - validatorDiff.addedDelegators = btree.New(defaultTreeDegree) + validatorDiff.addedDelegators = btree.NewG(defaultTreeDegree, (*Staker).Less) } validatorDiff.addedDelegators.ReplaceOrInsert(staker) @@ -244,17 +242,18 @@ func (v *baseStakers) getOrCreateValidatorDiff(subnetID ids.ID, nodeID ids.NodeI type diffStakers struct { // subnetID --> nodeID --> diff for that validator validatorDiffs map[ids.ID]map[ids.NodeID]*diffValidator - addedStakers *btree.BTree + addedStakers *btree.BTreeG[*Staker] deletedStakers map[ids.ID]*Staker } type diffValidator struct { - validatorModified bool - // [validatorDeleted] implies [validatorModified] + // Invariant: [validatorAdded] and [validatorDeleted] will not be set at the + // same time. + validatorAdded bool validatorDeleted bool validator *Staker - addedDelegators *btree.BTree + addedDelegators *btree.BTreeG[*Staker] deletedDelegators map[ids.ID]*Staker } @@ -266,6 +265,8 @@ type diffValidator struct { // 2. If the validator was removed in this diff, [nil, true] will be returned. // 3. If the validator was not modified by this diff, [nil, false] will be // returned. 
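Every btree.New(defaultTreeDegree) in this file becomes btree.NewG(defaultTreeDegree, (*Staker).Less): the method expression (*Staker).Less has type func(*Staker, *Staker) bool, which satisfies btree.LessFunc[*Staker]. A small usage sketch, assuming the generic API of github.com/google/btree that this diff targets (NewG, ReplaceOrInsert, Ascend):

package main

import (
	"fmt"

	"github.com/google/btree" // assumes the generic (Go 1.18+) API
)

type staker struct {
	txID   string
	weight uint64
}

// Less orders stakers by txID; the method value (*staker).Less is passed
// directly as the tree's comparator, just as the diff does for *Staker.
func (s *staker) Less(than *staker) bool {
	return s.txID < than.txID
}

func main() {
	tree := btree.NewG(2 /*degree*/, (*staker).Less)
	tree.ReplaceOrInsert(&staker{txID: "b", weight: 2})
	tree.ReplaceOrInsert(&staker{txID: "a", weight: 1})

	// Ascend visits items in comparator order; items come back typed,
	// with no btree.Item assertion as in the pre-generics API.
	tree.Ascend(func(s *staker) bool {
		fmt.Println(s.txID, s.weight)
		return true
	})
}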
+// +// Invariant: Assumes that the validator will never be removed and then added. func (s *diffStakers) GetValidator(subnetID ids.ID, nodeID ids.NodeID) (*Staker, bool) { subnetValidatorDiffs, ok := s.validatorDiffs[subnetID] if !ok { @@ -277,38 +278,41 @@ func (s *diffStakers) GetValidator(subnetID ids.ID, nodeID ids.NodeID) (*Staker, return nil, false } - if !validatorDiff.validatorModified { - return nil, false - } - - if validatorDiff.validatorDeleted { + switch { + case validatorDiff.validatorAdded: + return validatorDiff.validator, true + case validatorDiff.validatorDeleted: return nil, true + default: + return nil, false } - return validatorDiff.validator, true } func (s *diffStakers) PutValidator(staker *Staker) { validatorDiff := s.getOrCreateDiff(staker.SubnetID, staker.NodeID) - validatorDiff.validatorModified = true - validatorDiff.validatorDeleted = false + validatorDiff.validatorAdded = true validatorDiff.validator = staker if s.addedStakers == nil { - s.addedStakers = btree.New(defaultTreeDegree) + s.addedStakers = btree.NewG(defaultTreeDegree, (*Staker).Less) } s.addedStakers.ReplaceOrInsert(staker) } func (s *diffStakers) DeleteValidator(staker *Staker) { validatorDiff := s.getOrCreateDiff(staker.SubnetID, staker.NodeID) - validatorDiff.validatorModified = true - validatorDiff.validatorDeleted = true - validatorDiff.validator = staker - - if s.deletedStakers == nil { - s.deletedStakers = make(map[ids.ID]*Staker) + if validatorDiff.validatorAdded { + validatorDiff.validatorAdded = false + s.addedStakers.Delete(validatorDiff.validator) + validatorDiff.validator = nil + } else { + validatorDiff.validatorDeleted = true + validatorDiff.validator = staker + if s.deletedStakers == nil { + s.deletedStakers = make(map[ids.ID]*Staker) + } + s.deletedStakers[staker.TxID] = staker } - s.deletedStakers[staker.TxID] = staker } func (s *diffStakers) GetDelegatorIterator( @@ -339,12 +343,12 @@ func (s *diffStakers) GetDelegatorIterator( func (s *diffStakers) PutDelegator(staker *Staker) { validatorDiff := s.getOrCreateDiff(staker.SubnetID, staker.NodeID) if validatorDiff.addedDelegators == nil { - validatorDiff.addedDelegators = btree.New(defaultTreeDegree) + validatorDiff.addedDelegators = btree.NewG(defaultTreeDegree, (*Staker).Less) } validatorDiff.addedDelegators.ReplaceOrInsert(staker) if s.addedStakers == nil { - s.addedStakers = btree.New(defaultTreeDegree) + s.addedStakers = btree.NewG(defaultTreeDegree, (*Staker).Less) } s.addedStakers.ReplaceOrInsert(staker) } diff --git a/vms/platformvm/state/stakers_test.go b/vms/platformvm/state/stakers_test.go index d022d9cd1e75..d5650971eedd 100644 --- a/vms/platformvm/state/stakers_test.go +++ b/vms/platformvm/state/stakers_test.go @@ -166,14 +166,30 @@ func TestDiffStakersValidator(t *testing.T) { v.DeleteValidator(staker) - returnedStaker, ok = v.GetValidator(staker.SubnetID, staker.NodeID) - require.True(ok) - require.Nil(returnedStaker) + _, ok = v.GetValidator(staker.SubnetID, staker.NodeID) + require.False(ok) stakerIterator = v.GetStakerIterator(EmptyIterator) assertIteratorsEqual(t, NewSliceIterator(delegator), stakerIterator) } +func TestDiffStakersDeleteValidator(t *testing.T) { + require := require.New(t) + staker := newTestStaker() + delegator := newTestStaker() + + v := diffStakers{} + + _, ok := v.GetValidator(ids.GenerateTestID(), delegator.NodeID) + require.False(ok) + + v.DeleteValidator(staker) + + returnedStaker, ok := v.GetValidator(staker.SubnetID, staker.NodeID) + require.True(ok) + require.Nil(returnedStaker) +} + 
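The new test above pins down the asymmetry DeleteValidator now has: deleting a validator that the same diff added simply cancels the addition (GetValidator then reports "untouched"), while deleting one that exists only in base state records a deletion (GetValidator reports "deleted"). A toy sketch of that bookkeeping, using illustrative names:

package main

import "fmt"

// diffEntry sketches the mutually exclusive validatorAdded/validatorDeleted
// flags: an add followed by a delete in the same diff cancels out instead of
// recording a deletion against the underlying state.
type diffEntry struct {
	added, deleted bool
	value          string
}

func (d *diffEntry) put(v string) {
	d.added = true
	d.value = v
}

func (d *diffEntry) del(v string) {
	if d.added { // the add never reached base state; just undo it
		d.added = false
		d.value = ""
		return
	}
	d.deleted = true
	d.value = v
}

// get mirrors GetValidator's three-way result:
// (value, true) if added, ("", true) if deleted, ("", false) if untouched.
func (d *diffEntry) get() (string, bool) {
	switch {
	case d.added:
		return d.value, true
	case d.deleted:
		return "", true
	default:
		return "", false
	}
}

func main() {
	d := &diffEntry{}
	d.put("validator")
	d.del("validator")
	_, ok := d.get()
	fmt.Println(ok) // false: the add/delete pair cancels out in the diff
}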
func TestDiffStakersDelegator(t *testing.T) { staker := newTestStaker() delegator := newTestStaker() diff --git a/vms/platformvm/state/state.go b/vms/platformvm/state/state.go index a595259521f6..95970ac998f2 100644 --- a/vms/platformvm/state/state.go +++ b/vms/platformvm/state/state.go @@ -139,8 +139,8 @@ type State interface { // Commit changes to the base database. Commit() error - // Returns a batch of unwritten changes that, when written, will be commit - // all pending changes to the base database. + // Returns a batch of unwritten changes that, when written, will commit all + // pending changes to the base database. CommitBatch() (database.Batch, error) Close() error @@ -228,8 +228,10 @@ type state struct { currentHeight uint64 addedBlocks map[ids.ID]stateBlk // map of blockID -> Block - blockCache cache.Cacher // cache of blockID -> Block, if the entry is nil, it is not in the database - blockDB database.Database + // cache of blockID -> Block + // If the block isn't known, nil is cached. + blockCache cache.Cacher[ids.ID, *stateBlk] + blockDB database.Database validatorsDB database.Database currentValidatorsDB database.Database @@ -251,18 +253,18 @@ type state struct { pendingSubnetDelegatorBaseDB database.Database pendingSubnetDelegatorList linkeddb.LinkedDB - validatorWeightDiffsCache cache.Cacher // cache of heightWithSubnet -> map[ids.NodeID]*ValidatorWeightDiff + validatorWeightDiffsCache cache.Cacher[string, map[ids.NodeID]*ValidatorWeightDiff] // cache of heightWithSubnet -> map[ids.NodeID]*ValidatorWeightDiff validatorWeightDiffsDB database.Database - validatorPublicKeyDiffsCache cache.Cacher // cache of height -> map[ids.NodeID]*bls.PublicKey + validatorPublicKeyDiffsCache cache.Cacher[uint64, map[ids.NodeID]*bls.PublicKey] // cache of height -> map[ids.NodeID]*bls.PublicKey validatorPublicKeyDiffsDB database.Database - addedTxs map[ids.ID]*txAndStatus // map of txID -> {*txs.Tx, Status} - txCache cache.Cacher // cache of txID -> {*txs.Tx, Status} if the entry is nil, it is not in the database + addedTxs map[ids.ID]*txAndStatus // map of txID -> {*txs.Tx, Status} + txCache cache.Cacher[ids.ID, *txAndStatus] // txID -> {*txs.Tx, Status}. 
If the entry is nil, it isn't in the database txDB database.Database - addedRewardUTXOs map[ids.ID][]*avax.UTXO // map of txID -> []*UTXO - rewardUTXOsCache cache.Cacher // cache of txID -> []*UTXO + addedRewardUTXOs map[ids.ID][]*avax.UTXO // map of txID -> []*UTXO + rewardUTXOsCache cache.Cacher[ids.ID, []*avax.UTXO] // txID -> []*UTXO rewardUTXODB database.Database modifiedUTXOs map[ids.ID]*avax.UTXO // map of modified UTXOID -> *UTXO if the UTXO is nil, it has been removed @@ -274,17 +276,17 @@ type state struct { subnetBaseDB database.Database subnetDB linkeddb.LinkedDB - transformedSubnets map[ids.ID]*txs.Tx // map of subnetID -> transformSubnetTx - transformedSubnetCache cache.Cacher // cache of subnetID -> transformSubnetTx if the entry is nil, it is not in the database + transformedSubnets map[ids.ID]*txs.Tx // map of subnetID -> transformSubnetTx + transformedSubnetCache cache.Cacher[ids.ID, *txs.Tx] // cache of subnetID -> transformSubnetTx if the entry is nil, it is not in the database transformedSubnetDB database.Database - modifiedSupplies map[ids.ID]uint64 // map of subnetID -> current supply - supplyCache cache.Cacher // cache of subnetID -> current supply if the entry is nil, it is not in the database + modifiedSupplies map[ids.ID]uint64 // map of subnetID -> current supply + supplyCache cache.Cacher[ids.ID, *uint64] // cache of subnetID -> current supply if the entry is nil, it is not in the database supplyDB database.Database - addedChains map[ids.ID][]*txs.Tx // maps subnetID -> the newly added chains to the subnet - chainCache cache.Cacher // cache of subnetID -> the chains after all local modifications []*txs.Tx - chainDBCache cache.Cacher // cache of subnetID -> linkedDB + addedChains map[ids.ID][]*txs.Tx // maps subnetID -> the newly added chains to the subnet + chainCache cache.Cacher[ids.ID, []*txs.Tx] // cache of subnetID -> the chains after all local modifications []*txs.Tx + chainDBCache cache.Cacher[ids.ID, linkeddb.LinkedDB] // cache of subnetID -> linkedDB chainDB database.Database // The persisted fields represent the current database value @@ -370,10 +372,10 @@ func new( metricsReg prometheus.Registerer, rewards reward.Calculator, ) (*state, error) { - blockCache, err := metercacher.New( + blockCache, err := metercacher.New[ids.ID, *stateBlk]( "block_cache", metricsReg, - &cache.LRU{Size: blockCacheSize}, + &cache.LRU[ids.ID, *stateBlk]{Size: blockCacheSize}, ) if err != nil { return nil, err @@ -396,39 +398,39 @@ func new( pendingSubnetDelegatorBaseDB := prefixdb.New(subnetDelegatorPrefix, pendingValidatorsDB) validatorWeightDiffsDB := prefixdb.New(validatorWeightDiffsPrefix, validatorsDB) - validatorWeightDiffsCache, err := metercacher.New( + validatorWeightDiffsCache, err := metercacher.New[string, map[ids.NodeID]*ValidatorWeightDiff]( "validator_weight_diffs_cache", metricsReg, - &cache.LRU{Size: validatorDiffsCacheSize}, + &cache.LRU[string, map[ids.NodeID]*ValidatorWeightDiff]{Size: validatorDiffsCacheSize}, ) if err != nil { return nil, err } validatorPublicKeyDiffsDB := prefixdb.New(validatorPublicKeyDiffsPrefix, validatorsDB) - validatorPublicKeyDiffsCache, err := metercacher.New( + validatorPublicKeyDiffsCache, err := metercacher.New[uint64, map[ids.NodeID]*bls.PublicKey]( "validator_pub_key_diffs_cache", metricsReg, - &cache.LRU{Size: validatorDiffsCacheSize}, + &cache.LRU[uint64, map[ids.NodeID]*bls.PublicKey]{Size: validatorDiffsCacheSize}, ) if err != nil { return nil, err } - txCache, err := metercacher.New( + txCache, err := 
metercacher.New[ids.ID, *txAndStatus]( "tx_cache", metricsReg, - &cache.LRU{Size: txCacheSize}, + &cache.LRU[ids.ID, *txAndStatus]{Size: txCacheSize}, ) if err != nil { return nil, err } rewardUTXODB := prefixdb.New(rewardUTXOsPrefix, baseDB) - rewardUTXOsCache, err := metercacher.New( + rewardUTXOsCache, err := metercacher.New[ids.ID, []*avax.UTXO]( "reward_utxos_cache", metricsReg, - &cache.LRU{Size: rewardUTXOsCacheSize}, + &cache.LRU[ids.ID, []*avax.UTXO]{Size: rewardUTXOsCacheSize}, ) if err != nil { return nil, err @@ -442,37 +444,37 @@ func new( subnetBaseDB := prefixdb.New(subnetPrefix, baseDB) - transformedSubnetCache, err := metercacher.New( + transformedSubnetCache, err := metercacher.New[ids.ID, *txs.Tx]( "transformed_subnet_cache", metricsReg, - &cache.LRU{Size: chainCacheSize}, + &cache.LRU[ids.ID, *txs.Tx]{Size: chainCacheSize}, ) if err != nil { return nil, err } - supplyCache, err := metercacher.New( + supplyCache, err := metercacher.New[ids.ID, *uint64]( "supply_cache", metricsReg, - &cache.LRU{Size: chainCacheSize}, + &cache.LRU[ids.ID, *uint64]{Size: chainCacheSize}, ) if err != nil { return nil, err } - chainCache, err := metercacher.New( + chainCache, err := metercacher.New[ids.ID, []*txs.Tx]( "chain_cache", metricsReg, - &cache.LRU{Size: chainCacheSize}, + &cache.LRU[ids.ID, []*txs.Tx]{Size: chainCacheSize}, ) if err != nil { return nil, err } - chainDBCache, err := metercacher.New( + chainDBCache, err := metercacher.New[ids.ID, linkeddb.LinkedDB]( "chain_db_cache", metricsReg, - &cache.LRU{Size: chainDBCacheSize}, + &cache.LRU[ids.ID, linkeddb.LinkedDB]{Size: chainDBCacheSize}, ) if err != nil { return nil, err @@ -656,11 +658,11 @@ func (s *state) GetSubnetTransformation(subnetID ids.ID) (*txs.Tx, error) { return tx, nil } - if txIntf, cached := s.transformedSubnetCache.Get(subnetID); cached { - if txIntf == nil { + if tx, cached := s.transformedSubnetCache.Get(subnetID); cached { + if tx == nil { return nil, database.ErrNotFound } - return txIntf.(*txs.Tx), nil + return tx, nil } transformSubnetTxID, err := database.GetID(s.transformedSubnetDB, subnetID[:]) @@ -686,8 +688,8 @@ func (s *state) AddSubnetTransformation(transformSubnetTxIntf *txs.Tx) { } func (s *state) GetChains(subnetID ids.ID) ([]*txs.Tx, error) { - if chainsIntf, cached := s.chainCache.Get(subnetID); cached { - return chainsIntf.([]*txs.Tx), nil + if chains, cached := s.chainCache.Get(subnetID); cached { + return chains, nil } chainDB := s.getChainDB(subnetID) chainDBIt := chainDB.NewIterator() @@ -718,16 +720,15 @@ func (s *state) AddChain(createChainTxIntf *txs.Tx) { createChainTx := createChainTxIntf.Unsigned.(*txs.CreateChainTx) subnetID := createChainTx.SubnetID s.addedChains[subnetID] = append(s.addedChains[subnetID], createChainTxIntf) - if chainsIntf, cached := s.chainCache.Get(subnetID); cached { - chains := chainsIntf.([]*txs.Tx) + if chains, cached := s.chainCache.Get(subnetID); cached { chains = append(chains, createChainTxIntf) s.chainCache.Put(subnetID, chains) } } func (s *state) getChainDB(subnetID ids.ID) linkeddb.LinkedDB { - if chainDBIntf, cached := s.chainDBCache.Get(subnetID); cached { - return chainDBIntf.(linkeddb.LinkedDB) + if chainDB, cached := s.chainDBCache.Get(subnetID); cached { + return chainDB } rawChainDB := prefixdb.New(subnetID[:], s.chainDB) chainDB := linkeddb.NewDefault(rawChainDB) @@ -739,11 +740,10 @@ func (s *state) GetTx(txID ids.ID) (*txs.Tx, status.Status, error) { if tx, exists := s.addedTxs[txID]; exists { return tx.tx, tx.status, nil } - if txIntf, 
cached := s.txCache.Get(txID); cached { - if txIntf == nil { + if tx, cached := s.txCache.Get(txID); cached { + if tx == nil { return nil, status.Unknown, database.ErrNotFound } - tx := txIntf.(*txAndStatus) return tx.tx, tx.status, nil } txBytes, err := s.txDB.Get(txID[:]) @@ -785,7 +785,7 @@ func (s *state) GetRewardUTXOs(txID ids.ID) ([]*avax.UTXO, error) { return utxos, nil } if utxos, exists := s.rewardUTXOsCache.Get(txID); exists { - return utxos.([]*avax.UTXO), nil + return utxos, nil } rawTxDB := prefixdb.New(txID[:], s.rewardUTXODB) @@ -869,12 +869,12 @@ func (s *state) GetCurrentSupply(subnetID ids.ID) (uint64, error) { return supply, nil } - supplyIntf, ok := s.supplyCache.Get(subnetID) + cachedSupply, ok := s.supplyCache.Get(subnetID) if ok { - if supplyIntf == nil { + if cachedSupply == nil { return 0, database.ErrNotFound } - return supplyIntf.(uint64), nil + return *cachedSupply, nil } supply, err := database.GetUInt64(s.supplyDB, subnetID[:]) @@ -886,7 +886,7 @@ func (s *state) GetCurrentSupply(subnetID ids.ID) (uint64, error) { return 0, err } - s.supplyCache.Put(subnetID, supply) + s.supplyCache.Put(subnetID, &supply) return supply, nil } @@ -929,8 +929,8 @@ func (s *state) GetValidatorWeightDiffs(height uint64, subnetID ids.ID) (map[ids } prefixStr := string(prefixBytes) - if weightDiffsIntf, ok := s.validatorWeightDiffsCache.Get(prefixStr); ok { - return weightDiffsIntf.(map[ids.NodeID]*ValidatorWeightDiff), nil + if weightDiffs, ok := s.validatorWeightDiffsCache.Get(prefixStr); ok { + return weightDiffs, nil } rawDiffDB := prefixdb.New(prefixBytes, s.validatorWeightDiffsDB) @@ -959,8 +959,8 @@ func (s *state) GetValidatorWeightDiffs(height uint64, subnetID ids.ID) (map[ids } func (s *state) GetValidatorPublicKeyDiffs(height uint64) (map[ids.NodeID]*bls.PublicKey, error) { - if publicKeyDiffsIntf, ok := s.validatorPublicKeyDiffsCache.Get(height); ok { - return publicKeyDiffsIntf.(map[ids.NodeID]*bls.PublicKey), nil + if publicKeyDiffs, ok := s.validatorPublicKeyDiffsCache.Get(height); ok { + return publicKeyDiffs, nil } heightBytes := database.PackUInt64(height) @@ -1228,7 +1228,7 @@ func (s *state) loadCurrentValidators() error { validator := s.currentStakers.getOrCreateValidator(staker.SubnetID, staker.NodeID) if validator.delegators == nil { - validator.delegators = btree.New(defaultTreeDegree) + validator.delegators = btree.NewG(defaultTreeDegree, (*Staker).Less) } validator.delegators.ReplaceOrInsert(staker) @@ -1314,7 +1314,7 @@ func (s *state) loadPendingValidators() error { validator := s.pendingStakers.getOrCreateValidator(staker.SubnetID, staker.NodeID) if validator.delegators == nil { - validator.delegators = btree.New(defaultTreeDegree) + validator.delegators = btree.NewG(defaultTreeDegree, (*Staker).Less) } validator.delegators.ReplaceOrInsert(staker) @@ -1513,7 +1513,7 @@ func (s *state) writeBlocks() error { } delete(s.addedBlocks, blkID) - s.blockCache.Put(blkID, stateBlk) + s.blockCache.Put(blkID, &stBlk) if err := s.blockDB.Put(blkID[:], blockBytes); err != nil { return fmt.Errorf("failed to write block %s: %w", blkID, err) } @@ -1522,15 +1522,13 @@ func (s *state) writeBlocks() error { } func (s *state) GetStatelessBlock(blockID ids.ID) (blocks.Block, choices.Status, error) { - if blk, exists := s.addedBlocks[blockID]; exists { + if blk, ok := s.addedBlocks[blockID]; ok { return blk.Blk, blk.Status, nil } - if blkIntf, cached := s.blockCache.Get(blockID); cached { - if blkIntf == nil { - return nil, choices.Processing, database.ErrNotFound // status 
does not matter here + if blkState, ok := s.blockCache.Get(blockID); ok { + if blkState == nil { + return nil, choices.Processing, database.ErrNotFound } - - blkState := blkIntf.(stateBlk) return blkState.Blk, blkState.Status, nil } @@ -1553,7 +1551,7 @@ func (s *state) GetStatelessBlock(blockID ids.ID) (blocks.Block, choices.Status, return nil, choices.Processing, err } - s.blockCache.Put(blockID, blkState) + s.blockCache.Put(blockID, &blkState) return blkState.Blk, blkState.Status, nil } @@ -1592,58 +1590,55 @@ func (s *state) writeCurrentStakers(updateValidators bool, height uint64) error // Copy [nodeID] so it doesn't get overwritten next iteration. nodeID := nodeID - var ( - weightDiff = &ValidatorWeightDiff{} - isNewValidator bool - ) - if validatorDiff.validatorModified { - // This validator is being added or removed. + weightDiff := &ValidatorWeightDiff{ + Decrease: validatorDiff.validatorDeleted, + } + switch { + case validatorDiff.validatorAdded: staker := validatorDiff.validator - - weightDiff.Decrease = validatorDiff.validatorDeleted weightDiff.Amount = staker.Weight - if validatorDiff.validatorDeleted { - // Invariant: Only the Primary Network contains non-nil - // public keys. - if staker.PublicKey != nil { - // Record the public key of the validator being removed. - pkDiffs[nodeID] = staker.PublicKey - - pkBytes := bls.PublicKeyToBytes(staker.PublicKey) - if err := pkDiffDB.Put(nodeID[:], pkBytes); err != nil { - return err - } - } + // The validator is being added. + vdr := &uptimeAndReward{ + txID: staker.TxID, + lastUpdated: staker.StartTime, - if err := validatorDB.Delete(staker.TxID[:]); err != nil { - return fmt.Errorf("failed to delete current staker: %w", err) - } + UpDuration: 0, + LastUpdated: uint64(staker.StartTime.Unix()), + PotentialReward: staker.PotentialReward, + } - s.validatorUptimes.DeleteUptime(nodeID, subnetID) - } else { - // The validator is being added. - vdr := &uptimeAndReward{ - txID: staker.TxID, - lastUpdated: staker.StartTime, - - UpDuration: 0, - LastUpdated: uint64(staker.StartTime.Unix()), - PotentialReward: staker.PotentialReward, - } + vdrBytes, err := blocks.GenesisCodec.Marshal(blocks.Version, vdr) + if err != nil { + return fmt.Errorf("failed to serialize current validator: %w", err) + } - vdrBytes, err := blocks.GenesisCodec.Marshal(blocks.Version, vdr) - if err != nil { - return fmt.Errorf("failed to serialize current validator: %w", err) - } + if err = validatorDB.Put(staker.TxID[:], vdrBytes); err != nil { + return fmt.Errorf("failed to write current validator to list: %w", err) + } + + s.validatorUptimes.LoadUptime(nodeID, subnetID, vdr) + case validatorDiff.validatorDeleted: + staker := validatorDiff.validator + weightDiff.Amount = staker.Weight + + // Invariant: Only the Primary Network contains non-nil + // public keys. + if staker.PublicKey != nil { + // Record the public key of the validator being removed. 
+ pkDiffs[nodeID] = staker.PublicKey - if err = validatorDB.Put(staker.TxID[:], vdrBytes); err != nil { - return fmt.Errorf("failed to write current validator to list: %w", err) + pkBytes := bls.PublicKeyToBytes(staker.PublicKey) + if err := pkDiffDB.Put(nodeID[:], pkBytes); err != nil { + return err } + } - s.validatorUptimes.LoadUptime(nodeID, subnetID, vdr) - isNewValidator = true + if err := validatorDB.Delete(staker.TxID[:]); err != nil { + return fmt.Errorf("failed to delete current staker: %w", err) } + + s.validatorUptimes.DeleteUptime(nodeID, subnetID) } err := writeCurrentDelegatorDiff( @@ -1683,7 +1678,7 @@ func (s *state) writeCurrentStakers(updateValidators bool, height uint64) error if weightDiff.Decrease { err = validators.RemoveWeight(s.cfg.Validators, subnetID, nodeID, weightDiff.Amount) } else { - if isNewValidator { + if validatorDiff.validatorAdded { staker := validatorDiff.validator err = validators.Add( s.cfg.Validators, @@ -1781,17 +1776,16 @@ func writePendingDiff( pendingDelegatorList linkeddb.LinkedDB, validatorDiff *diffValidator, ) error { - if validatorDiff.validatorModified { - staker := validatorDiff.validator - - var err error - if validatorDiff.validatorDeleted { - err = pendingValidatorList.Delete(staker.TxID[:]) - } else { - err = pendingValidatorList.Put(staker.TxID[:], nil) + if validatorDiff.validatorAdded { + err := pendingValidatorList.Put(validatorDiff.validator.TxID[:], nil) + if err != nil { + return fmt.Errorf("failed to add pending validator: %w", err) } + } + if validatorDiff.validatorDeleted { + err := pendingValidatorList.Delete(validatorDiff.validator.TxID[:]) if err != nil { - return fmt.Errorf("failed to update pending validator: %w", err) + return fmt.Errorf("failed to delete pending validator: %w", err) } } @@ -1903,8 +1897,9 @@ func (s *state) writeTransformedSubnets() error { func (s *state) writeSubnetSupplies() error { for subnetID, supply := range s.modifiedSupplies { + supply := supply delete(s.modifiedSupplies, subnetID) - s.supplyCache.Put(subnetID, supply) + s.supplyCache.Put(subnetID, &supply) if err := database.PutUInt64(s.supplyDB, subnetID[:], supply); err != nil { return fmt.Errorf("failed to write subnet supply: %w", err) } diff --git a/vms/platformvm/state/tree_iterator.go b/vms/platformvm/state/tree_iterator.go index b700bc894f08..b138dea98111 100644 --- a/vms/platformvm/state/tree_iterator.go +++ b/vms/platformvm/state/tree_iterator.go @@ -23,7 +23,7 @@ type treeIterator struct { // NewTreeIterator returns a new iterator of the stakers in [tree] in ascending // order. Note that it isn't safe to modify [tree] while iterating over it. 
-func NewTreeIterator(tree *btree.BTree) StakerIterator { +func NewTreeIterator(tree *btree.BTreeG[*Staker]) StakerIterator { if tree == nil { return EmptyIterator } @@ -34,9 +34,9 @@ func NewTreeIterator(tree *btree.BTree) StakerIterator { it.wg.Add(1) go func() { defer it.wg.Done() - tree.Ascend(func(i btree.Item) bool { + tree.Ascend(func(i *Staker) bool { select { - case it.next <- i.(*Staker): + case it.next <- i: return true case <-it.release: return false diff --git a/vms/platformvm/state/tree_iterator_test.go b/vms/platformvm/state/tree_iterator_test.go index e1af4fe100f2..e57a6761ba15 100644 --- a/vms/platformvm/state/tree_iterator_test.go +++ b/vms/platformvm/state/tree_iterator_test.go @@ -31,7 +31,7 @@ func TestTreeIterator(t *testing.T) { }, } - tree := btree.New(defaultTreeDegree) + tree := btree.NewG(defaultTreeDegree, (*Staker).Less) for _, staker := range stakers { require.Nil(tree.ReplaceOrInsert(staker)) } @@ -68,7 +68,7 @@ func TestTreeIteratorEarlyRelease(t *testing.T) { }, } - tree := btree.New(defaultTreeDegree) + tree := btree.NewG(defaultTreeDegree, (*Staker).Less) for _, staker := range stakers { require.Nil(tree.ReplaceOrInsert(staker)) } diff --git a/vms/platformvm/state/versions.go b/vms/platformvm/state/versions.go index 3668da30fe99..dd3cd23bf575 100644 --- a/vms/platformvm/state/versions.go +++ b/vms/platformvm/state/versions.go @@ -8,5 +8,7 @@ import ( ) type Versions interface { + // GetState returns the state of the chain after [blkID] has been accepted. + // If the state is not known, `false` will be returned. GetState(blkID ids.ID) (Chain, bool) } diff --git a/vms/platformvm/txs/executor/backend.go b/vms/platformvm/txs/executor/backend.go index 4f7ac74cd9c8..596012918b20 100644 --- a/vms/platformvm/txs/executor/backend.go +++ b/vms/platformvm/txs/executor/backend.go @@ -22,5 +22,5 @@ type Backend struct { FlowChecker utxo.Verifier Uptimes uptime.Manager Rewards reward.Calculator - Bootstrapped *utils.AtomicBool + Bootstrapped *utils.Atomic[bool] } diff --git a/vms/platformvm/txs/executor/helpers_test.go b/vms/platformvm/txs/executor/helpers_test.go index edbc017a8f46..5451558c4cf2 100644 --- a/vms/platformvm/txs/executor/helpers_test.go +++ b/vms/platformvm/txs/executor/helpers_test.go @@ -84,7 +84,7 @@ type mutableSharedMemory struct { } type environment struct { - isBootstrapped *utils.AtomicBool + isBootstrapped *utils.Atomic[bool] config *config.Config clk *mockable.Clock baseDB *versiondb.Database @@ -113,8 +113,8 @@ func (e *environment) SetState(blkID ids.ID, chainState state.Chain) { } func newEnvironment(postBanff bool) *environment { - var isBootstrapped utils.AtomicBool - isBootstrapped.SetValue(true) + var isBootstrapped utils.Atomic[bool] + isBootstrapped.Set(true) config := defaultConfig(postBanff) clk := defaultClock(postBanff) @@ -123,7 +123,7 @@ func newEnvironment(postBanff bool) *environment { baseDB := versiondb.New(baseDBManager.Current().Database) ctx, msm := defaultCtx(baseDB) - fx := defaultFx(&clk, ctx.Log, isBootstrapped.GetValue()) + fx := defaultFx(&clk, ctx.Log, isBootstrapped.Get()) rewards := reward.NewCalculator(config.RewardConfig) baseState := defaultState(&config, ctx, baseDB, rewards) @@ -426,7 +426,7 @@ func buildGenesisTest(ctx *snow.Context) []byte { } func shutdownEnvironment(env *environment) error { - if env.isBootstrapped.GetValue() { + if env.isBootstrapped.Get() { primaryValidatorSet, exist := env.config.Validators.Get(constants.PrimaryNetworkID) if !exist { return errMissingPrimaryValidators diff --git 
a/vms/platformvm/txs/executor/staker_tx_verification.go b/vms/platformvm/txs/executor/staker_tx_verification.go index 635a3ad6ef4e..2a9f12934cb8 100644 --- a/vms/platformvm/txs/executor/staker_tx_verification.go +++ b/vms/platformvm/txs/executor/staker_tx_verification.go @@ -84,7 +84,7 @@ func verifyAddValidatorTx( copy(outs, tx.Outs) copy(outs[len(tx.Outs):], tx.StakeOuts) - if !backend.Bootstrapped.GetValue() { + if !backend.Bootstrapped.Get() { return outs, nil } @@ -163,7 +163,7 @@ func verifyAddSubnetValidatorTx( return errStakeTooLong } - if !backend.Bootstrapped.GetValue() { + if !backend.Bootstrapped.Get() { return nil } @@ -279,7 +279,7 @@ func removeSubnetValidatorValidation( return nil, false, errRemovePermissionlessValidator } - if !backend.Bootstrapped.GetValue() { + if !backend.Bootstrapped.Get() { // Not bootstrapped yet -- don't need to do full verification. return vdr, isCurrentValidator, nil } @@ -342,7 +342,7 @@ func verifyAddDelegatorTx( copy(outs, tx.Outs) copy(outs[len(tx.Outs):], tx.StakeOuts) - if !backend.Bootstrapped.GetValue() { + if !backend.Bootstrapped.Get() { return outs, nil } @@ -427,7 +427,7 @@ func verifyAddPermissionlessValidatorTx( return err } - if !backend.Bootstrapped.GetValue() { + if !backend.Bootstrapped.Get() { return nil } @@ -606,7 +606,7 @@ func verifyAddPermissionlessDelegatorTx( return err } - if !backend.Bootstrapped.GetValue() { + if !backend.Bootstrapped.Get() { return nil } diff --git a/vms/platformvm/txs/executor/staker_tx_verification_test.go b/vms/platformvm/txs/executor/staker_tx_verification_test.go index bf813a6dcad8..350d0cea28bd 100644 --- a/vms/platformvm/txs/executor/staker_tx_verification_test.go +++ b/vms/platformvm/txs/executor/staker_tx_verification_test.go @@ -118,7 +118,7 @@ func TestVerifyAddPermissionlessValidatorTx(t *testing.T) { backendF: func(*gomock.Controller) *Backend { return &Backend{ Ctx: snow.DefaultContextTest(), - Bootstrapped: &utils.AtomicBool{}, + Bootstrapped: &utils.Atomic[bool]{}, } }, stateF: func(ctrl *gomock.Controller) state.Chain { @@ -135,8 +135,8 @@ func TestVerifyAddPermissionlessValidatorTx(t *testing.T) { { name: "start time too early", backendF: func(*gomock.Controller) *Backend { - bootstrapped := &utils.AtomicBool{} - bootstrapped.SetValue(true) + bootstrapped := &utils.Atomic[bool]{} + bootstrapped.Set(true) return &Backend{ Ctx: snow.DefaultContextTest(), Bootstrapped: bootstrapped, @@ -158,8 +158,8 @@ func TestVerifyAddPermissionlessValidatorTx(t *testing.T) { { name: "weight too low", backendF: func(*gomock.Controller) *Backend { - bootstrapped := &utils.AtomicBool{} - bootstrapped.SetValue(true) + bootstrapped := &utils.Atomic[bool]{} + bootstrapped.Set(true) return &Backend{ Ctx: snow.DefaultContextTest(), Bootstrapped: bootstrapped, @@ -184,8 +184,8 @@ func TestVerifyAddPermissionlessValidatorTx(t *testing.T) { { name: "weight too high", backendF: func(*gomock.Controller) *Backend { - bootstrapped := &utils.AtomicBool{} - bootstrapped.SetValue(true) + bootstrapped := &utils.Atomic[bool]{} + bootstrapped.Set(true) return &Backend{ Ctx: snow.DefaultContextTest(), Bootstrapped: bootstrapped, @@ -210,8 +210,8 @@ func TestVerifyAddPermissionlessValidatorTx(t *testing.T) { { name: "insufficient delegation fee", backendF: func(*gomock.Controller) *Backend { - bootstrapped := &utils.AtomicBool{} - bootstrapped.SetValue(true) + bootstrapped := &utils.Atomic[bool]{} + bootstrapped.Set(true) return &Backend{ Ctx: snow.DefaultContextTest(), Bootstrapped: bootstrapped, @@ -237,8 +237,8 @@ func 
TestVerifyAddPermissionlessValidatorTx(t *testing.T) { { name: "duration too short", backendF: func(*gomock.Controller) *Backend { - bootstrapped := &utils.AtomicBool{} - bootstrapped.SetValue(true) + bootstrapped := &utils.Atomic[bool]{} + bootstrapped.Set(true) return &Backend{ Ctx: snow.DefaultContextTest(), Bootstrapped: bootstrapped, @@ -267,8 +267,8 @@ func TestVerifyAddPermissionlessValidatorTx(t *testing.T) { { name: "duration too long", backendF: func(*gomock.Controller) *Backend { - bootstrapped := &utils.AtomicBool{} - bootstrapped.SetValue(true) + bootstrapped := &utils.Atomic[bool]{} + bootstrapped.Set(true) return &Backend{ Ctx: snow.DefaultContextTest(), Bootstrapped: bootstrapped, @@ -297,8 +297,8 @@ func TestVerifyAddPermissionlessValidatorTx(t *testing.T) { { name: "wrong assetID", backendF: func(*gomock.Controller) *Backend { - bootstrapped := &utils.AtomicBool{} - bootstrapped.SetValue(true) + bootstrapped := &utils.Atomic[bool]{} + bootstrapped.Set(true) return &Backend{ Ctx: snow.DefaultContextTest(), Bootstrapped: bootstrapped, @@ -329,8 +329,8 @@ func TestVerifyAddPermissionlessValidatorTx(t *testing.T) { { name: "duplicate validator", backendF: func(*gomock.Controller) *Backend { - bootstrapped := &utils.AtomicBool{} - bootstrapped.SetValue(true) + bootstrapped := &utils.Atomic[bool]{} + bootstrapped.Set(true) return &Backend{ Ctx: snow.DefaultContextTest(), Bootstrapped: bootstrapped, @@ -355,8 +355,8 @@ func TestVerifyAddPermissionlessValidatorTx(t *testing.T) { { name: "validator not subset of primary network validator", backendF: func(*gomock.Controller) *Backend { - bootstrapped := &utils.AtomicBool{} - bootstrapped.SetValue(true) + bootstrapped := &utils.Atomic[bool]{} + bootstrapped.Set(true) return &Backend{ Ctx: snow.DefaultContextTest(), Bootstrapped: bootstrapped, @@ -387,8 +387,8 @@ func TestVerifyAddPermissionlessValidatorTx(t *testing.T) { { name: "flow check fails", backendF: func(ctrl *gomock.Controller) *Backend { - bootstrapped := &utils.AtomicBool{} - bootstrapped.SetValue(true) + bootstrapped := &utils.Atomic[bool]{} + bootstrapped.Set(true) flowChecker := utxo.NewMockVerifier(ctrl) flowChecker.EXPECT().VerifySpend( @@ -433,8 +433,8 @@ func TestVerifyAddPermissionlessValidatorTx(t *testing.T) { { name: "starts too far in the future", backendF: func(ctrl *gomock.Controller) *Backend { - bootstrapped := &utils.AtomicBool{} - bootstrapped.SetValue(true) + bootstrapped := &utils.Atomic[bool]{} + bootstrapped.Set(true) flowChecker := utxo.NewMockVerifier(ctrl) flowChecker.EXPECT().VerifySpend( @@ -483,8 +483,8 @@ func TestVerifyAddPermissionlessValidatorTx(t *testing.T) { { name: "success", backendF: func(ctrl *gomock.Controller) *Backend { - bootstrapped := &utils.AtomicBool{} - bootstrapped.SetValue(true) + bootstrapped := &utils.Atomic[bool]{} + bootstrapped.Set(true) flowChecker := utxo.NewMockVerifier(ctrl) flowChecker.EXPECT().VerifySpend( diff --git a/vms/platformvm/txs/executor/standard_tx_executor.go b/vms/platformvm/txs/executor/standard_tx_executor.go index a78dd79577db..064c23c13764 100644 --- a/vms/platformvm/txs/executor/standard_tx_executor.go +++ b/vms/platformvm/txs/executor/standard_tx_executor.go @@ -136,7 +136,7 @@ func (e *StandardTxExecutor) ImportTx(tx *txs.ImportTx) error { utxoIDs[i] = utxoID[:] } - if e.Bootstrapped.GetValue() { + if e.Bootstrapped.Get() { if err := verify.SameSubnet(context.TODO(), e.Ctx, tx.SourceChain); err != nil { return err } @@ -204,7 +204,7 @@ func (e *StandardTxExecutor) ExportTx(tx *txs.ExportTx) 
error { copy(outs, tx.Outs) copy(outs[len(tx.Outs):], tx.ExportedOutputs) - if e.Bootstrapped.GetValue() { + if e.Bootstrapped.Get() { if err := verify.SameSubnet(context.TODO(), e.Ctx, tx.DestinationChain); err != nil { return err } diff --git a/vms/platformvm/txs/executor/standard_tx_executor_test.go b/vms/platformvm/txs/executor/standard_tx_executor_test.go index 68d9694b0c05..b85a9b31c925 100644 --- a/vms/platformvm/txs/executor/standard_tx_executor_test.go +++ b/vms/platformvm/txs/executor/standard_tx_executor_test.go @@ -1103,7 +1103,7 @@ func TestStandardExecutorRemoveSubnetValidatorTx(t *testing.T) { Config: &config.Config{ BanffTime: env.banffTime, }, - Bootstrapped: &utils.AtomicBool{}, + Bootstrapped: &utils.Atomic[bool]{}, Fx: env.fx, FlowChecker: env.flowChecker, Ctx: &snow.Context{}, @@ -1111,7 +1111,7 @@ func TestStandardExecutorRemoveSubnetValidatorTx(t *testing.T) { Tx: env.tx, State: env.state, } - e.Bootstrapped.SetValue(true) + e.Bootstrapped.Set(true) return env.unsignedTx, e }, shouldErr: false, @@ -1128,7 +1128,7 @@ func TestStandardExecutorRemoveSubnetValidatorTx(t *testing.T) { Config: &config.Config{ BanffTime: env.banffTime, }, - Bootstrapped: &utils.AtomicBool{}, + Bootstrapped: &utils.Atomic[bool]{}, Fx: env.fx, FlowChecker: env.flowChecker, Ctx: &snow.Context{}, @@ -1136,7 +1136,7 @@ func TestStandardExecutorRemoveSubnetValidatorTx(t *testing.T) { Tx: env.tx, State: env.state, } - e.Bootstrapped.SetValue(true) + e.Bootstrapped.Set(true) return env.unsignedTx, e }, shouldErr: true, @@ -1153,7 +1153,7 @@ func TestStandardExecutorRemoveSubnetValidatorTx(t *testing.T) { Config: &config.Config{ BanffTime: env.banffTime, }, - Bootstrapped: &utils.AtomicBool{}, + Bootstrapped: &utils.Atomic[bool]{}, Fx: env.fx, FlowChecker: env.flowChecker, Ctx: &snow.Context{}, @@ -1161,7 +1161,7 @@ func TestStandardExecutorRemoveSubnetValidatorTx(t *testing.T) { Tx: env.tx, State: env.state, } - e.Bootstrapped.SetValue(true) + e.Bootstrapped.Set(true) return env.unsignedTx, e }, shouldErr: true, @@ -1182,7 +1182,7 @@ func TestStandardExecutorRemoveSubnetValidatorTx(t *testing.T) { Config: &config.Config{ BanffTime: env.banffTime, }, - Bootstrapped: &utils.AtomicBool{}, + Bootstrapped: &utils.Atomic[bool]{}, Fx: env.fx, FlowChecker: env.flowChecker, Ctx: &snow.Context{}, @@ -1190,7 +1190,7 @@ func TestStandardExecutorRemoveSubnetValidatorTx(t *testing.T) { Tx: env.tx, State: env.state, } - e.Bootstrapped.SetValue(true) + e.Bootstrapped.Set(true) return env.unsignedTx, e }, shouldErr: true, @@ -1209,7 +1209,7 @@ func TestStandardExecutorRemoveSubnetValidatorTx(t *testing.T) { Config: &config.Config{ BanffTime: env.banffTime, }, - Bootstrapped: &utils.AtomicBool{}, + Bootstrapped: &utils.Atomic[bool]{}, Fx: env.fx, FlowChecker: env.flowChecker, Ctx: &snow.Context{}, @@ -1217,7 +1217,7 @@ func TestStandardExecutorRemoveSubnetValidatorTx(t *testing.T) { Tx: env.tx, State: env.state, } - e.Bootstrapped.SetValue(true) + e.Bootstrapped.Set(true) return env.unsignedTx, e }, shouldErr: true, @@ -1235,7 +1235,7 @@ func TestStandardExecutorRemoveSubnetValidatorTx(t *testing.T) { Config: &config.Config{ BanffTime: env.banffTime, }, - Bootstrapped: &utils.AtomicBool{}, + Bootstrapped: &utils.Atomic[bool]{}, Fx: env.fx, FlowChecker: env.flowChecker, Ctx: &snow.Context{}, @@ -1243,7 +1243,7 @@ func TestStandardExecutorRemoveSubnetValidatorTx(t *testing.T) { Tx: env.tx, State: env.state, } - e.Bootstrapped.SetValue(true) + e.Bootstrapped.Set(true) return env.unsignedTx, e }, shouldErr: true, @@ 
-1268,7 +1268,7 @@ func TestStandardExecutorRemoveSubnetValidatorTx(t *testing.T) { Config: &config.Config{ BanffTime: env.banffTime, }, - Bootstrapped: &utils.AtomicBool{}, + Bootstrapped: &utils.Atomic[bool]{}, Fx: env.fx, FlowChecker: env.flowChecker, Ctx: &snow.Context{}, @@ -1276,7 +1276,7 @@ func TestStandardExecutorRemoveSubnetValidatorTx(t *testing.T) { Tx: env.tx, State: env.state, } - e.Bootstrapped.SetValue(true) + e.Bootstrapped.Set(true) return env.unsignedTx, e }, shouldErr: true, @@ -1304,7 +1304,7 @@ func TestStandardExecutorRemoveSubnetValidatorTx(t *testing.T) { Config: &config.Config{ BanffTime: env.banffTime, }, - Bootstrapped: &utils.AtomicBool{}, + Bootstrapped: &utils.Atomic[bool]{}, Fx: env.fx, FlowChecker: env.flowChecker, Ctx: &snow.Context{}, @@ -1312,7 +1312,7 @@ func TestStandardExecutorRemoveSubnetValidatorTx(t *testing.T) { Tx: env.tx, State: env.state, } - e.Bootstrapped.SetValue(true) + e.Bootstrapped.Set(true) return env.unsignedTx, e }, shouldErr: true, @@ -1468,7 +1468,7 @@ func TestStandardExecutorTransformSubnetTx(t *testing.T) { Config: &config.Config{ BanffTime: env.banffTime, }, - Bootstrapped: &utils.AtomicBool{}, + Bootstrapped: &utils.Atomic[bool]{}, Fx: env.fx, FlowChecker: env.flowChecker, Ctx: &snow.Context{}, @@ -1476,7 +1476,7 @@ func TestStandardExecutorTransformSubnetTx(t *testing.T) { Tx: env.tx, State: env.state, } - e.Bootstrapped.SetValue(true) + e.Bootstrapped.Set(true) return env.unsignedTx, e }, err: txs.ErrNilTx, @@ -1492,7 +1492,7 @@ func TestStandardExecutorTransformSubnetTx(t *testing.T) { Config: &config.Config{ BanffTime: env.banffTime, }, - Bootstrapped: &utils.AtomicBool{}, + Bootstrapped: &utils.Atomic[bool]{}, Fx: env.fx, FlowChecker: env.flowChecker, Ctx: &snow.Context{}, @@ -1500,7 +1500,7 @@ func TestStandardExecutorTransformSubnetTx(t *testing.T) { Tx: env.tx, State: env.state, } - e.Bootstrapped.SetValue(true) + e.Bootstrapped.Set(true) return env.unsignedTx, e }, err: errMaxStakeDurationTooLarge, @@ -1518,7 +1518,7 @@ func TestStandardExecutorTransformSubnetTx(t *testing.T) { BanffTime: env.banffTime, MaxStakeDuration: math.MaxInt64, }, - Bootstrapped: &utils.AtomicBool{}, + Bootstrapped: &utils.Atomic[bool]{}, Fx: env.fx, FlowChecker: env.flowChecker, Ctx: &snow.Context{}, @@ -1526,7 +1526,7 @@ func TestStandardExecutorTransformSubnetTx(t *testing.T) { Tx: env.tx, State: env.state, } - e.Bootstrapped.SetValue(true) + e.Bootstrapped.Set(true) return env.unsignedTx, e }, err: errWrongNumberOfCredentials, @@ -1554,7 +1554,7 @@ func TestStandardExecutorTransformSubnetTx(t *testing.T) { BanffTime: env.banffTime, MaxStakeDuration: math.MaxInt64, }, - Bootstrapped: &utils.AtomicBool{}, + Bootstrapped: &utils.Atomic[bool]{}, Fx: env.fx, FlowChecker: env.flowChecker, Ctx: &snow.Context{}, @@ -1562,7 +1562,7 @@ func TestStandardExecutorTransformSubnetTx(t *testing.T) { Tx: env.tx, State: env.state, } - e.Bootstrapped.SetValue(true) + e.Bootstrapped.Set(true) return env.unsignedTx, e }, err: errFlowCheckFailed, @@ -1595,7 +1595,7 @@ func TestStandardExecutorTransformSubnetTx(t *testing.T) { BanffTime: env.banffTime, MaxStakeDuration: math.MaxInt64, }, - Bootstrapped: &utils.AtomicBool{}, + Bootstrapped: &utils.Atomic[bool]{}, Fx: env.fx, FlowChecker: env.flowChecker, Ctx: &snow.Context{}, @@ -1603,7 +1603,7 @@ func TestStandardExecutorTransformSubnetTx(t *testing.T) { Tx: env.tx, State: env.state, } - e.Bootstrapped.SetValue(true) + e.Bootstrapped.Set(true) return env.unsignedTx, e }, err: nil, diff --git 
a/vms/platformvm/txs/executor/state_changes.go b/vms/platformvm/txs/executor/state_changes.go index d45f19388ea7..f87f37d0c809 100644 --- a/vms/platformvm/txs/executor/state_changes.go +++ b/vms/platformvm/txs/executor/state_changes.go @@ -119,6 +119,16 @@ func AdvanceTimeTo( // Add to the staker set any pending stakers whose start time is at or // before the new timestamp + + // Note: we first process the pending stakers that are ready to be promoted + // to current stakers, and then we process the current stakers that should + // be demoted out of the staker set. It is guaranteed that no newly promoted + // staker is demoted immediately. A failure of this invariant would cause a + // staker to be added to StateChanges and be persisted among the current + // stakers even though it had already expired. The following invariants + // ensure this does not happen: + // Invariant: minimum stake duration is > 0, so staker.StartTime != staker.EndTime. + // Invariant: [newChainTime] does not skip staker set change times. + for pendingStakerIterator.Next() { stakerToRemove := pendingStakerIterator.Value() if stakerToRemove.StartTime.After(newChainTime) { @@ -130,12 +140,6 @@ func AdvanceTimeTo( stakerToAdd.Priority = txs.PendingToCurrentPriorities[stakerToRemove.Priority] if stakerToRemove.Priority == txs.SubnetPermissionedValidatorPendingPriority { - // Invariant: [txTimestamp] <= [nextStakerChangeTime]. - // Invariant: minimum stake duration is > 0. - // - // Both of the above invariants ensure the staker we are adding here - // should never be attempted to be removed in the following loop. - changes.currentValidatorsToAdd = append(changes.currentValidatorsToAdd, &stakerToAdd) changes.pendingValidatorsToRemove = append(changes.pendingValidatorsToRemove, stakerToRemove) continue diff --git a/vms/platformvm/txs/mempool/mempool.go b/vms/platformvm/txs/mempool/mempool.go index 22a1971fc195..8434e1c424c6 100644 --- a/vms/platformvm/txs/mempool/mempool.go +++ b/vms/platformvm/txs/mempool/mempool.go @@ -91,7 +91,7 @@ type mempool struct { // Key: Tx ID // Value: String repr. of the verification error - droppedTxIDs *cache.LRU + droppedTxIDs *cache.LRU[ids.ID, string] consumedUTXOs set.Set[ids.ID] @@ -136,7 +136,7 @@ func NewMempool( bytesAvailable: maxMempoolSize, unissuedDecisionTxs: unissuedDecisionTxs, unissuedStakerTxs: unissuedStakerTxs, - droppedTxIDs: &cache.LRU{Size: droppedTxIDsCacheSize}, + droppedTxIDs: &cache.LRU[ids.ID, string]{Size: droppedTxIDsCacheSize}, consumedUTXOs: set.NewSet[ids.ID](initialConsumedUTXOsSize), dropIncoming: false, // enable tx adding by default blkTimer: blkTimer, @@ -284,7 +284,7 @@ func (m *mempool) GetDropReason(txID ids.ID) (string, bool) { if !exist { return "", false } - return reason.(string), true + return reason, true } func (m *mempool) register(tx *txs.Tx) { diff --git a/vms/platformvm/vm.go b/vms/platformvm/vm.go index cc308bb6fce3..7b9a82ea1339 100644 --- a/vms/platformvm/vm.go +++ b/vms/platformvm/vm.go @@ -66,7 +66,6 @@ var ( _ validators.State = (*VM)(nil) _ validators.SubnetConnector = (*VM)(nil) - errWrongCacheType = errors.New("unexpectedly cached type") errMissingValidatorSet = errors.New("missing validator set") errMissingValidator = errors.New("missing validator") ) @@ -93,12 +92,12 @@ type VM struct { codecRegistry codec.Registry // Bootstrapped remembers if this chain has finished bootstrapping or not - bootstrapped utils.AtomicBool + bootstrapped utils.Atomic[bool] // Maps caches for each subnet that is currently tracked.
// Key: Subnet ID // Value: cache mapping height -> validator set map - validatorSetCaches map[ids.ID]cache.Cacher + validatorSetCaches map[ids.ID]cache.Cacher[uint64, map[ids.NodeID]*validators.GetValidatorOutput] // sliding window of blocks that were recently accepted recentlyAccepted window.Window[ids.ID] @@ -144,7 +143,7 @@ func (vm *VM) Initialize( return err } - vm.validatorSetCaches = make(map[ids.ID]cache.Cacher) + vm.validatorSetCaches = make(map[ids.ID]cache.Cacher[uint64, map[ids.NodeID]*validators.GetValidatorOutput]) vm.recentlyAccepted = window.New[ids.ID]( window.Config{ Clock: &vm.clock, @@ -275,16 +274,16 @@ func (vm *VM) createSubnet(subnetID ids.ID) error { // onBootstrapStarted marks this VM as bootstrapping func (vm *VM) onBootstrapStarted() error { - vm.bootstrapped.SetValue(false) + vm.bootstrapped.Set(false) return vm.fx.Bootstrapping() } // onNormalOperationsStarted marks this VM as bootstrapped func (vm *VM) onNormalOperationsStarted() error { - if vm.bootstrapped.GetValue() { + if vm.bootstrapped.Get() { return nil } - vm.bootstrapped.SetValue(true) + vm.bootstrapped.Set(true) if err := vm.fx.Bootstrapped(); err != nil { return err @@ -336,7 +335,7 @@ func (vm *VM) Shutdown(context.Context) error { vm.Builder.Shutdown() - if vm.bootstrapped.GetValue() { + if vm.bootstrapped.Get() { primaryVdrIDs, exists := vm.getValidatorIDs(constants.PrimaryNetworkID) if !exists { return errMissingValidatorSet @@ -425,6 +424,9 @@ func (vm *VM) CreateHandlers(context.Context) (map[string]*common.HTTPHandler, e &Service{ vm: vm, addrManager: avax.NewAddressManager(vm.ctx), + stakerAttributesCache: &cache.LRU[ids.ID, *stakerAttributes]{ + Size: stakerAttributesCacheSize, + }, }, "platform", ); err != nil { @@ -477,18 +479,14 @@ func (vm *VM) Disconnected(_ context.Context, nodeID ids.NodeID) error { func (vm *VM) GetValidatorSet(ctx context.Context, height uint64, subnetID ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) { validatorSetsCache, exists := vm.validatorSetCaches[subnetID] if !exists { - validatorSetsCache = &cache.LRU{Size: validatorSetsCacheSize} + validatorSetsCache = &cache.LRU[uint64, map[ids.NodeID]*validators.GetValidatorOutput]{Size: validatorSetsCacheSize} // Only cache tracked subnets if subnetID == constants.PrimaryNetworkID || vm.TrackedSubnets.Contains(subnetID) { vm.validatorSetCaches[subnetID] = validatorSetsCache } } - if validatorSetIntf, ok := validatorSetsCache.Get(height); ok { - validatorSet, ok := validatorSetIntf.(map[ids.NodeID]*validators.GetValidatorOutput) - if !ok { - return nil, errWrongCacheType - } + if validatorSet, ok := validatorSetsCache.Get(height); ok { vm.metrics.IncValidatorSetsCached() return validatorSet, nil } diff --git a/vms/platformvm/vm_regression_test.go b/vms/platformvm/vm_regression_test.go index ac4c5a8279e8..88c60a430d10 100644 --- a/vms/platformvm/vm_regression_test.go +++ b/vms/platformvm/vm_regression_test.go @@ -1252,3 +1252,239 @@ func TestAddDelegatorTxAddBeforeRemove(t *testing.T) { // total stake weight would go over the limit. 
require.Error(vm.Builder.AddUnverifiedTx(addSecondDelegatorTx)) } + +func TestRemovePermissionedValidatorDuringPendingToCurrentTransitionNotTracked(t *testing.T) { + require := require.New(t) + + validatorStartTime := banffForkTime.Add(txexecutor.SyncBound).Add(1 * time.Second) + validatorEndTime := validatorStartTime.Add(360 * 24 * time.Hour) + + vm, _, _ := defaultVM() + + vm.ctx.Lock.Lock() + defer func() { + err := vm.Shutdown(context.Background()) + require.NoError(err) + + vm.ctx.Lock.Unlock() + }() + + key, err := testKeyFactory.NewPrivateKey() + require.NoError(err) + + id := key.PublicKey().Address() + changeAddr := keys[0].PublicKey().Address() + + addValidatorTx, err := vm.txBuilder.NewAddValidatorTx( + defaultMaxValidatorStake, + uint64(validatorStartTime.Unix()), + uint64(validatorEndTime.Unix()), + ids.NodeID(id), + id, + reward.PercentDenominator, + []*crypto.PrivateKeySECP256K1R{keys[0], keys[1]}, + changeAddr, + ) + require.NoError(err) + + err = vm.Builder.AddUnverifiedTx(addValidatorTx) + require.NoError(err) + + // trigger block creation for the validator tx + addValidatorBlock, err := vm.Builder.BuildBlock(context.Background()) + require.NoError(err) + require.NoError(addValidatorBlock.Verify(context.Background())) + require.NoError(addValidatorBlock.Accept(context.Background())) + require.NoError(vm.SetPreference(context.Background(), vm.manager.LastAccepted())) + + createSubnetTx, err := vm.txBuilder.NewCreateSubnetTx( + 1, + []ids.ShortID{changeAddr}, + []*crypto.PrivateKeySECP256K1R{keys[0], keys[1]}, + changeAddr, + ) + require.NoError(err) + + err = vm.Builder.AddUnverifiedTx(createSubnetTx) + require.NoError(err) + + // trigger block creation for the subnet tx + createSubnetBlock, err := vm.Builder.BuildBlock(context.Background()) + require.NoError(err) + require.NoError(createSubnetBlock.Verify(context.Background())) + require.NoError(createSubnetBlock.Accept(context.Background())) + require.NoError(vm.SetPreference(context.Background(), vm.manager.LastAccepted())) + + addSubnetValidatorTx, err := vm.txBuilder.NewAddSubnetValidatorTx( + defaultMaxValidatorStake, + uint64(validatorStartTime.Unix()), + uint64(validatorEndTime.Unix()), + ids.NodeID(id), + createSubnetTx.ID(), + []*crypto.PrivateKeySECP256K1R{keys[0], keys[1]}, + changeAddr, + ) + require.NoError(err) + + err = vm.Builder.AddUnverifiedTx(addSubnetValidatorTx) + require.NoError(err) + + // trigger block creation for the validator tx + addSubnetValidatorBlock, err := vm.Builder.BuildBlock(context.Background()) + require.NoError(err) + require.NoError(addSubnetValidatorBlock.Verify(context.Background())) + require.NoError(addSubnetValidatorBlock.Accept(context.Background())) + require.NoError(vm.SetPreference(context.Background(), vm.manager.LastAccepted())) + + emptyValidatorSet, err := vm.GetValidatorSet( + context.Background(), + addSubnetValidatorBlock.Height(), + createSubnetTx.ID(), + ) + require.NoError(err) + require.Empty(emptyValidatorSet) + + removeSubnetValidatorTx, err := vm.txBuilder.NewRemoveSubnetValidatorTx( + ids.NodeID(id), + createSubnetTx.ID(), + []*crypto.PrivateKeySECP256K1R{keys[0], keys[1]}, + changeAddr, + ) + require.NoError(err) + + // Set the clock so that the validator will be moved from the pending + // validator set into the current validator set. 
+ vm.clock.Set(validatorStartTime) + + err = vm.Builder.AddUnverifiedTx(removeSubnetValidatorTx) + require.NoError(err) + + // trigger block creation for the validator tx + removeSubnetValidatorBlock, err := vm.Builder.BuildBlock(context.Background()) + require.NoError(err) + require.NoError(removeSubnetValidatorBlock.Verify(context.Background())) + require.NoError(removeSubnetValidatorBlock.Accept(context.Background())) + require.NoError(vm.SetPreference(context.Background(), vm.manager.LastAccepted())) + + emptyValidatorSet, err = vm.GetValidatorSet( + context.Background(), + addSubnetValidatorBlock.Height(), + createSubnetTx.ID(), + ) + require.NoError(err) + require.Empty(emptyValidatorSet) +} + +func TestRemovePermissionedValidatorDuringPendingToCurrentTransitionTracked(t *testing.T) { + require := require.New(t) + + validatorStartTime := banffForkTime.Add(txexecutor.SyncBound).Add(1 * time.Second) + validatorEndTime := validatorStartTime.Add(360 * 24 * time.Hour) + + vm, _, _ := defaultVM() + + vm.ctx.Lock.Lock() + defer func() { + err := vm.Shutdown(context.Background()) + require.NoError(err) + + vm.ctx.Lock.Unlock() + }() + + key, err := testKeyFactory.NewPrivateKey() + require.NoError(err) + + id := key.PublicKey().Address() + changeAddr := keys[0].PublicKey().Address() + + addValidatorTx, err := vm.txBuilder.NewAddValidatorTx( + defaultMaxValidatorStake, + uint64(validatorStartTime.Unix()), + uint64(validatorEndTime.Unix()), + ids.NodeID(id), + id, + reward.PercentDenominator, + []*crypto.PrivateKeySECP256K1R{keys[0], keys[1]}, + changeAddr, + ) + require.NoError(err) + + err = vm.Builder.AddUnverifiedTx(addValidatorTx) + require.NoError(err) + + // trigger block creation for the validator tx + addValidatorBlock, err := vm.Builder.BuildBlock(context.Background()) + require.NoError(err) + require.NoError(addValidatorBlock.Verify(context.Background())) + require.NoError(addValidatorBlock.Accept(context.Background())) + require.NoError(vm.SetPreference(context.Background(), vm.manager.LastAccepted())) + + createSubnetTx, err := vm.txBuilder.NewCreateSubnetTx( + 1, + []ids.ShortID{changeAddr}, + []*crypto.PrivateKeySECP256K1R{keys[0], keys[1]}, + changeAddr, + ) + require.NoError(err) + + err = vm.Builder.AddUnverifiedTx(createSubnetTx) + require.NoError(err) + + // trigger block creation for the subnet tx + createSubnetBlock, err := vm.Builder.BuildBlock(context.Background()) + require.NoError(err) + require.NoError(createSubnetBlock.Verify(context.Background())) + require.NoError(createSubnetBlock.Accept(context.Background())) + require.NoError(vm.SetPreference(context.Background(), vm.manager.LastAccepted())) + + vm.TrackedSubnets.Add(createSubnetTx.ID()) + subnetValidators := validators.NewSet() + err = vm.state.ValidatorSet(createSubnetTx.ID(), subnetValidators) + require.NoError(err) + + added := vm.Validators.Add(createSubnetTx.ID(), subnetValidators) + require.True(added) + + addSubnetValidatorTx, err := vm.txBuilder.NewAddSubnetValidatorTx( + defaultMaxValidatorStake, + uint64(validatorStartTime.Unix()), + uint64(validatorEndTime.Unix()), + ids.NodeID(id), + createSubnetTx.ID(), + []*crypto.PrivateKeySECP256K1R{keys[0], keys[1]}, + changeAddr, + ) + require.NoError(err) + + err = vm.Builder.AddUnverifiedTx(addSubnetValidatorTx) + require.NoError(err) + + // trigger block creation for the validator tx + addSubnetValidatorBlock, err := vm.Builder.BuildBlock(context.Background()) + require.NoError(err) + require.NoError(addSubnetValidatorBlock.Verify(context.Background())) + 
require.NoError(addSubnetValidatorBlock.Accept(context.Background())) + require.NoError(vm.SetPreference(context.Background(), vm.manager.LastAccepted())) + + removeSubnetValidatorTx, err := vm.txBuilder.NewRemoveSubnetValidatorTx( + ids.NodeID(id), + createSubnetTx.ID(), + []*crypto.PrivateKeySECP256K1R{keys[0], keys[1]}, + changeAddr, + ) + require.NoError(err) + + // Set the clock so that the validator will be moved from the pending + // validator set into the current validator set. + vm.clock.Set(validatorStartTime) + + err = vm.Builder.AddUnverifiedTx(removeSubnetValidatorTx) + require.NoError(err) + + // trigger block creation for the validator tx + removeSubnetValidatorBlock, err := vm.Builder.BuildBlock(context.Background()) + require.NoError(err) + require.NoError(removeSubnetValidatorBlock.Verify(context.Background())) + require.NoError(removeSubnetValidatorBlock.Accept(context.Background())) + require.NoError(vm.SetPreference(context.Background(), vm.manager.LastAccepted())) +} diff --git a/vms/platformvm/vm_test.go b/vms/platformvm/vm_test.go index 5ccb38a2275f..75a908077b05 100644 --- a/vms/platformvm/vm_test.go +++ b/vms/platformvm/vm_test.go @@ -1650,7 +1650,6 @@ func TestBootstrapPartiallyAccepted(t *testing.T) { consensusCtx := snow.DefaultConsensusContextTest() consensusCtx.Context = ctx - consensusCtx.SetState(snow.Initializing) ctx.Lock.Lock() msgChan := make(chan common.Message, 1) @@ -1838,7 +1837,6 @@ func TestBootstrapPartiallyAccepted(t *testing.T) { msgChan, nil, time.Hour, - p2p.EngineType_ENGINE_TYPE_SNOWMAN, cpuTracker, vm, ) @@ -2828,3 +2826,106 @@ func copySubnetValidator(vdr *validators.Validator) *validators.Validator { newVdr.PublicKey = nil return &newVdr } + +func TestRemovePermissionedValidatorDuringAddPending(t *testing.T) { + require := require.New(t) + + validatorStartTime := banffForkTime.Add(txexecutor.SyncBound).Add(1 * time.Second) + validatorEndTime := validatorStartTime.Add(360 * 24 * time.Hour) + + vm, _, _ := defaultVM() + + vm.ctx.Lock.Lock() + defer func() { + err := vm.Shutdown(context.Background()) + require.NoError(err) + + vm.ctx.Lock.Unlock() + }() + + keyIntf, err := testKeyFactory.NewPrivateKey() + require.NoError(err) + key := keyIntf.(*crypto.PrivateKeySECP256K1R) + + id := key.PublicKey().Address() + + addValidatorTx, err := vm.txBuilder.NewAddValidatorTx( + defaultMaxValidatorStake, + uint64(validatorStartTime.Unix()), + uint64(validatorEndTime.Unix()), + ids.NodeID(id), + id, + reward.PercentDenominator, + []*crypto.PrivateKeySECP256K1R{keys[0]}, + keys[0].Address(), + ) + require.NoError(err) + + err = vm.Builder.AddUnverifiedTx(addValidatorTx) + require.NoError(err) + + // trigger block creation for the validator tx + addValidatorBlock, err := vm.Builder.BuildBlock(context.Background()) + require.NoError(err) + require.NoError(addValidatorBlock.Verify(context.Background())) + require.NoError(addValidatorBlock.Accept(context.Background())) + require.NoError(vm.SetPreference(context.Background(), vm.manager.LastAccepted())) + + createSubnetTx, err := vm.txBuilder.NewCreateSubnetTx( + 1, + []ids.ShortID{id}, + []*crypto.PrivateKeySECP256K1R{keys[0]}, + keys[0].Address(), + ) + require.NoError(err) + + err = vm.Builder.AddUnverifiedTx(createSubnetTx) + require.NoError(err) + + // trigger block creation for the subnet tx + createSubnetBlock, err := vm.Builder.BuildBlock(context.Background()) + require.NoError(err) + require.NoError(createSubnetBlock.Verify(context.Background())) + 
require.NoError(createSubnetBlock.Accept(context.Background())) + require.NoError(vm.SetPreference(context.Background(), vm.manager.LastAccepted())) + + addSubnetValidatorTx, err := vm.txBuilder.NewAddSubnetValidatorTx( + defaultMaxValidatorStake, + uint64(validatorStartTime.Unix()), + uint64(validatorEndTime.Unix()), + ids.NodeID(id), + createSubnetTx.ID(), + []*crypto.PrivateKeySECP256K1R{key, keys[1]}, + keys[1].Address(), + ) + require.NoError(err) + + removeSubnetValidatorTx, err := vm.txBuilder.NewRemoveSubnetValidatorTx( + ids.NodeID(id), + createSubnetTx.ID(), + []*crypto.PrivateKeySECP256K1R{key, keys[2]}, + keys[2].Address(), + ) + require.NoError(err) + + statelessBlock, err := blocks.NewBanffStandardBlock( + vm.state.GetTimestamp(), + createSubnetBlock.ID(), + createSubnetBlock.Height()+1, + []*txs.Tx{ + addSubnetValidatorTx, + removeSubnetValidatorTx, + }, + ) + require.NoError(err) + + blockBytes := statelessBlock.Bytes() + block, err := vm.ParseBlock(context.Background(), blockBytes) + require.NoError(err) + require.NoError(block.Verify(context.Background())) + require.NoError(block.Accept(context.Background())) + require.NoError(vm.SetPreference(context.Background(), vm.manager.LastAccepted())) + + _, err = vm.state.GetPendingValidator(createSubnetTx.ID(), ids.NodeID(id)) + require.ErrorIs(err, database.ErrNotFound) +} diff --git a/vms/proposervm/indexer/height_indexer.go b/vms/proposervm/indexer/height_indexer.go index 4ad35caa2f48..cbbf868e868c 100644 --- a/vms/proposervm/indexer/height_indexer.go +++ b/vms/proposervm/indexer/height_indexer.go @@ -63,18 +63,18 @@ type heightIndexer struct { server BlockServer log logging.Logger - jobDone utils.AtomicBool + jobDone utils.Atomic[bool] state state.State commitFrequency int } func (hi *heightIndexer) IsRepaired() bool { - return hi.jobDone.GetValue() + return hi.jobDone.Get() } func (hi *heightIndexer) MarkRepaired(repaired bool) { - hi.jobDone.SetValue(repaired) + hi.jobDone.Set(repaired) } // RepairHeightIndex ensures the height -> proBlkID height block index is well formed. diff --git a/vms/proposervm/state/block_height_index.go b/vms/proposervm/state/block_height_index.go index 7dfac36c40be..ee849ef931b1 100644 --- a/vms/proposervm/state/block_height_index.go +++ b/vms/proposervm/state/block_height_index.go @@ -59,7 +59,7 @@ type heightIndex struct { versiondb.Commitable // Caches block height -> proposerVMBlockID. - heightsCache cache.Cacher + heightsCache cache.Cacher[uint64, ids.ID] heightDB database.Database metadataDB database.Database @@ -69,16 +69,15 @@ func NewHeightIndex(db database.Database, commitable versiondb.Commitable) Heigh return &heightIndex{ Commitable: commitable, - heightsCache: &cache.LRU{Size: cacheSize}, + heightsCache: &cache.LRU[uint64, ids.ID]{Size: cacheSize}, heightDB: prefixdb.New(heightPrefix, db), metadataDB: prefixdb.New(metadataPrefix, db), } } func (hi *heightIndex) GetBlockIDAtHeight(height uint64) (ids.ID, error) { - if blkIDIntf, found := hi.heightsCache.Get(height); found { - res, _ := blkIDIntf.(ids.ID) - return res, nil + if blkID, found := hi.heightsCache.Get(height); found { + return blkID, nil } key := database.PackUInt64(height) diff --git a/vms/proposervm/state/block_state.go b/vms/proposervm/state/block_state.go index f98ae0dfb237..b0ea8d31bbb9 100644 --- a/vms/proposervm/state/block_state.go +++ b/vms/proposervm/state/block_state.go @@ -33,7 +33,7 @@ type BlockState interface { type blockState struct { // Caches BlockID -> Block. 
If the Block is nil, that means the block is not // in storage. - blkCache cache.Cacher + blkCache cache.Cacher[ids.ID, *blockWrapper] db database.Database } @@ -47,16 +47,16 @@ type blockWrapper struct { func NewBlockState(db database.Database) BlockState { return &blockState{ - blkCache: &cache.LRU{Size: blockCacheSize}, + blkCache: &cache.LRU[ids.ID, *blockWrapper]{Size: blockCacheSize}, db: db, } } func NewMeteredBlockState(db database.Database, namespace string, metrics prometheus.Registerer) (BlockState, error) { - blkCache, err := metercacher.New( + blkCache, err := metercacher.New[ids.ID, *blockWrapper]( fmt.Sprintf("%s_block_cache", namespace), metrics, - &cache.LRU{Size: blockCacheSize}, + &cache.LRU[ids.ID, *blockWrapper]{Size: blockCacheSize}, ) return &blockState{ @@ -66,12 +66,8 @@ func NewMeteredBlockState(db database.Database, namespace string, metrics promet } func (s *blockState) GetBlock(blkID ids.ID) (block.Block, choices.Status, error) { - if blkIntf, found := s.blkCache.Get(blkID); found { - if blkIntf == nil { - return nil, choices.Unknown, database.ErrNotFound - } - blk, ok := blkIntf.(*blockWrapper) - if !ok { + if blk, found := s.blkCache.Get(blkID); found { + if blk == nil { return nil, choices.Unknown, database.ErrNotFound } return blk.block, blk.Status, nil diff --git a/vms/proposervm/vm.go b/vms/proposervm/vm.go index d12421735a8f..e522ba04be2e 100644 --- a/vms/proposervm/vm.go +++ b/vms/proposervm/vm.go @@ -92,7 +92,7 @@ type VM struct { // Only contains post-fork blocks near the tip so that the cache doesn't get // filled with random blocks every time this node parses blocks while // processing a GetAncestors message from a bootstrapping node. - innerBlkCache cache.Cacher + innerBlkCache cache.Cacher[ids.ID, snowman.Block] preferred ids.ID consensusState snow.State context context.Context @@ -171,10 +171,10 @@ func (vm *VM) Initialize( vm.State = state.New(vm.db) vm.Windower = proposer.New(chainCtx.ValidatorState, chainCtx.SubnetID, chainCtx.ChainID) vm.Tree = tree.New() - innerBlkCache, err := metercacher.New( + innerBlkCache, err := metercacher.New[ids.ID, snowman.Block]( "inner_block_cache", registerer, - &cache.LRU{Size: innerBlkCacheSize}, + &cache.LRU[ids.ID, snowman.Block]{Size: innerBlkCacheSize}, ) if err != nil { return err @@ -789,8 +789,8 @@ func (vm *VM) optimalPChainHeight(ctx context.Context, minPChainHeight uint64) ( // the inner block happens to be cached, then the inner block will not be // parsed. 
func (vm *VM) parseInnerBlock(ctx context.Context, outerBlkID ids.ID, innerBlkBytes []byte) (snowman.Block, error) { - if innerBlkIntf, ok := vm.innerBlkCache.Get(outerBlkID); ok { - return innerBlkIntf.(snowman.Block), nil + if innerBlk, ok := vm.innerBlkCache.Get(outerBlkID); ok { + return innerBlk, nil } innerBlk, err := vm.ChainVM.ParseBlock(ctx, innerBlkBytes) diff --git a/vms/proposervm/vm_test.go b/vms/proposervm/vm_test.go index c49dbdda4cdf..fe0baccb598b 100644 --- a/vms/proposervm/vm_test.go +++ b/vms/proposervm/vm_test.go @@ -2502,9 +2502,8 @@ func TestVMInnerBlkCacheDeduplicationRegression(t *testing.T) { bBlock.(*postForkBlock).innerBlk.Status(), ) - xBlockIntf, ok := proVM.innerBlkCache.Get(bBlock.ID()) + cachedXBlock, ok := proVM.innerBlkCache.Get(bBlock.ID()) require.True(ok) - cachedXBlock := xBlockIntf.(snowman.Block) require.Equal( choices.Accepted, cachedXBlock.Status(), diff --git a/vms/secp256k1fx/fx.go b/vms/secp256k1fx/fx.go index 6d27cd76fc85..eadd28169bac 100644 --- a/vms/secp256k1fx/fx.go +++ b/vms/secp256k1fx/fx.go @@ -8,6 +8,7 @@ import ( "fmt" "github.com/ava-labs/avalanchego/cache" + "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils/crypto" "github.com/ava-labs/avalanchego/utils/hashing" "github.com/ava-labs/avalanchego/utils/wrappers" @@ -52,7 +53,7 @@ func (fx *Fx) Initialize(vmIntf interface{}) error { log.Debug("initializing secp256k1 fx") fx.SECPFactory = crypto.FactorySECP256K1R{ - Cache: cache.LRU{Size: defaultCacheSize}, + Cache: cache.LRU[ids.ID, *crypto.PublicKeySECP256K1R]{Size: defaultCacheSize}, } c := fx.VM.CodecRegistry() errs := wrappers.Errs{}
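The bulk of this patch is one mechanical migration: `interface{}`-based caches become type-parameterized `cache.Cacher[K, V]` values, which is why call sites shed assertions such as `txIntf.(*txAndStatus)` and why the `errWrongCacheType` sentinel in `vm.go` can be deleted outright. Below is a minimal, self-contained sketch of what the change buys; the `Cacher` interface here is a deliberately simplified stand-in (the real one carries more methods, such as eviction), and `mapCache` and `tx` are hypothetical helpers for illustration only.

```go
package main

import "fmt"

// Cacher is a simplified stand-in for the generic cache interface this
// patch migrates to; only the Get/Put surface used here is shown.
type Cacher[K comparable, V any] interface {
	Put(key K, value V)
	Get(key K) (V, bool)
}

// mapCache is a toy Cacher backed by a plain map, with no eviction.
type mapCache[K comparable, V any] struct {
	entries map[K]V
}

func newMapCache[K comparable, V any]() *mapCache[K, V] {
	return &mapCache[K, V]{entries: make(map[K]V)}
}

func (c *mapCache[K, V]) Put(key K, value V) { c.entries[key] = value }

func (c *mapCache[K, V]) Get(key K) (V, bool) {
	v, ok := c.entries[key]
	return v, ok
}

type tx struct {
	id string
}

func main() {
	// Pre-generics, a cache held interface{} values, so every hit needed a
	// runtime type assertion:
	//
	//   txIntf, ok := cache.Get(id)
	//   tx := txIntf.(*tx) // panics, or needs an errWrongCacheType check
	//
	// With a typed cache the compiler enforces what the cache may hold:
	var txCache Cacher[string, *tx] = newMapCache[string, *tx]()
	txCache.Put("abc", &tx{id: "abc"})

	if cached, ok := txCache.Get("abc"); ok {
		fmt.Println("hit:", cached.id) // cached is already *tx
	}
}
```

The same type parameters flow through the metered wrapper, as in `metercacher.New[ids.ID, *txAndStatus]` above; the next sketch shows that decorator shape.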
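Each `metercacher.New[K, V]` call in the patch wraps a typed LRU so hits and misses stay observable after the migration. The sketch below reproduces the decorator shape with plain counters in place of the Prometheus metrics the real `metercacher` registers; `meteredCache` and `mapCache` are illustrative names, not the library's.

```go
package main

import "fmt"

// Cacher matches the minimal Get/Put surface used in this patch.
type Cacher[K comparable, V any] interface {
	Put(key K, value V)
	Get(key K) (V, bool)
}

type mapCache[K comparable, V any] struct {
	entries map[K]V
}

func (c *mapCache[K, V]) Put(key K, value V) { c.entries[key] = value }

func (c *mapCache[K, V]) Get(key K) (V, bool) {
	v, ok := c.entries[key]
	return v, ok
}

// meteredCache decorates any Cacher with hit/miss counters; the wrapped
// cache keeps the same type parameters, so no assertions reappear.
type meteredCache[K comparable, V any] struct {
	Cacher[K, V]
	hits, misses int
}

func (m *meteredCache[K, V]) Get(key K) (V, bool) {
	v, ok := m.Cacher.Get(key)
	if ok {
		m.hits++
	} else {
		m.misses++
	}
	return v, ok
}

func main() {
	base := &mapCache[string, uint64]{entries: map[string]uint64{}}
	metered := &meteredCache[string, uint64]{Cacher: base}

	metered.Put("supply", 42)
	metered.Get("supply")  // hit
	metered.Get("missing") // miss

	fmt.Printf("hits=%d misses=%d\n", metered.hits, metered.misses)
}
```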
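Two details of the `supplyCache` change from `uint64` to `*uint64` are easy to miss. First, a stored nil pointer lets `GetCurrentSupply` cache the absence of a row (the field comment's "if the entry is nil, it is not in the database"), so repeated lookups of a missing subnet skip the database. Second, `writeSubnetSupplies` now shadows the range variable (`supply := supply`) before taking its address, since before Go 1.22 a loop variable is reused across iterations and every cached pointer would otherwise alias the same value. A toy sketch of both behaviors, with a plain map standing in for the cache:

```go
package main

import (
	"errors"
	"fmt"
)

var errNotFound = errors.New("not found")

// supplies stands in for supplyCache: a nil *uint64 records "looked up and
// known to be absent from the database" (negative caching), which a plain
// uint64 value could not express.
var supplies = map[string]*uint64{}

func getCurrentSupply(subnetID string) (uint64, error) {
	if cached, ok := supplies[subnetID]; ok {
		if cached == nil {
			return 0, errNotFound // cached miss: the database is skipped
		}
		return *cached, nil
	}
	// A real implementation would fall through to the database here and
	// cache whatever it finds, including the absence.
	return 0, errNotFound
}

func main() {
	modified := map[string]uint64{"subnetA": 100, "subnetB": 200}

	for subnetID, supply := range modified {
		// Before Go 1.22 a range variable is reused across iterations, so
		// storing &supply directly would make every entry alias one address.
		// This shadowing copy is why the patch adds `supply := supply` in
		// writeSubnetSupplies.
		supply := supply
		supplies[subnetID] = &supply
	}

	// Record that subnetC was looked up and is known to be absent.
	supplies["subnetC"] = nil

	for _, id := range []string{"subnetA", "subnetB", "subnetC"} {
		supply, err := getCurrentSupply(id)
		fmt.Println(id, supply, err)
	}
}
```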
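The `utils.AtomicBool` to `utils.Atomic[bool]` rename rides the same generics sweep: one wrapper type now serves every value type, with `Get`/`Set` replacing the per-type `GetValue`/`SetValue`. A minimal sketch of such a wrapper, assuming a mutex-guarded implementation (avalanchego's actual internals may differ):

```go
package main

import (
	"fmt"
	"sync"
)

// Atomic is a sketch of a generic, lock-guarded value with the Get/Set
// shape this patch migrates to. The zero value is ready to use, which is
// why fields like `bootstrapped utils.Atomic[bool]` need no constructor.
type Atomic[T any] struct {
	lock  sync.RWMutex
	value T
}

func (a *Atomic[T]) Get() T {
	a.lock.RLock()
	defer a.lock.RUnlock()
	return a.value
}

func (a *Atomic[T]) Set(value T) {
	a.lock.Lock()
	defer a.lock.Unlock()
	a.value = value
}

func main() {
	var bootstrapped Atomic[bool]
	fmt.Println(bootstrapped.Get()) // false: the zero value is usable

	bootstrapped.Set(true)
	if bootstrapped.Get() {
		fmt.Println("bootstrapped")
	}
}
```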
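Finally, the staker trees move from `btree.New` and the `btree.Item` interface to `btree.NewG(defaultTreeDegree, (*Staker).Less)`, so `Ascend` callbacks receive `*Staker` directly. The runnable sketch below pairs a typed tree with a simplified version of the channel-based iterator in `tree_iterator.go`; the `StakerIterator` interface, `EmptyIterator`, and the nil-tree guard are omitted, and `Staker` is reduced to the one field needed for ordering.

```go
package main

import (
	"fmt"
	"sync"

	"github.com/google/btree"
)

type Staker struct {
	NextTime int64
}

// Less supplies the ordering for btree.NewG, mirroring how the patch
// passes the method expression (*Staker).Less.
func (s *Staker) Less(than *Staker) bool {
	return s.NextTime < than.NextTime
}

// treeIterator streams the tree's contents over a channel so callers get a
// pull-style iterator; Release stops the producing goroutine early.
type treeIterator struct {
	current *Staker
	next    chan *Staker
	release chan struct{}
	wg      sync.WaitGroup
}

func newTreeIterator(tree *btree.BTreeG[*Staker]) *treeIterator {
	it := &treeIterator{
		next:    make(chan *Staker),
		release: make(chan struct{}),
	}
	it.wg.Add(1)
	go func() {
		defer it.wg.Done()
		tree.Ascend(func(i *Staker) bool {
			select {
			case it.next <- i: // hand the next staker to the consumer
				return true
			case <-it.release: // consumer gave up; stop the ascent
				return false
			}
		})
		close(it.next)
	}()
	return it
}

func (it *treeIterator) Next() bool {
	staker, ok := <-it.next
	it.current = staker
	return ok
}

func (it *treeIterator) Value() *Staker { return it.current }

// Release must be called exactly once when done iterating.
func (it *treeIterator) Release() {
	close(it.release)
	it.wg.Wait()
}

func main() {
	tree := btree.NewG(2, (*Staker).Less)
	for _, t := range []int64{30, 10, 20} {
		tree.ReplaceOrInsert(&Staker{NextTime: t})
	}

	it := newTreeIterator(tree)
	for it.Next() {
		fmt.Println(it.Value().NextTime) // 10, 20, 30: typed, no assertions
	}
	it.Release()
}
```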