diff --git a/RELEASES.md b/RELEASES.md index 1902cc7977d4..c9e529b6ef86 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,5 +1,25 @@ # Release Notes +## [v1.9.11](https://github.com/ava-labs/avalanchego/releases/tag/v1.9.11) + +This version is backwards compatible to [v1.9.0](https://github.com/ava-labs/avalanchego/releases/tag/v1.9.0). It is optional, but encouraged. The supported plugin version is `24`. + +### Plugins + +- Removed error from `logging.NoLog#Write` +- Added logging to the static VM factory usage +- Fixed incorrect error being returned from `subprocess.Bootstrap` + +### Ledger + +- Added ledger tx parsing support + +### MerkleDB + +- Added explicit consistency guarantees when committing multiple `merkledb.trieView`s to disk at once +- Removed reliance on premature root calculations for `merkledb.trieView` validity tracking +- Updated `x/merkledb/README.md` + ## [v1.9.10](https://github.com/ava-labs/avalanchego/releases/tag/v1.9.10) This version is backwards compatible to [v1.9.0](https://github.com/ava-labs/avalanchego/releases/tag/v1.9.0). It is optional, but encouraged. The supported plugin version is `24`. diff --git a/config/config.go b/config/config.go index e7f891f7a16e..2b9ea7b8139e 100644 --- a/config/config.go +++ b/config/config.go @@ -47,7 +47,6 @@ import ( "github.com/ava-labs/avalanchego/utils/set" "github.com/ava-labs/avalanchego/utils/storage" "github.com/ava-labs/avalanchego/utils/timer" - "github.com/ava-labs/avalanchego/vms" "github.com/ava-labs/avalanchego/vms/platformvm/reward" "github.com/ava-labs/avalanchego/vms/proposervm" ) @@ -925,21 +924,21 @@ func getChainAliases(v *viper.Viper) (map[ids.ID][]string, error) { return getAliases(v, "chain aliases", ChainAliasesContentKey, ChainAliasesFileKey) } -func getVMManager(v *viper.Viper) (vms.Manager, error) { +func getVMAliaser(v *viper.Viper) (ids.Aliaser, error) { vmAliases, err := getVMAliases(v) if err != nil { return nil, err } - manager := vms.NewManager() + aliser := ids.NewAliaser() for vmID, aliases := range vmAliases { for _, alias := range aliases { - if err := manager.Alias(vmID, alias); err != nil { + if err := aliser.Alias(vmID, alias); err != nil { return nil, err } } } - return manager, nil + return aliser, nil } // getPathFromDirKey reads flag value from viper instance and then checks the folder existence @@ -1401,7 +1400,7 @@ func GetNodeConfig(v *viper.Viper) (node.Config, error) { } // VM Aliases - nodeConfig.VMManager, err = getVMManager(v) + nodeConfig.VMAliaser, err = getVMAliaser(v) if err != nil { return node.Config{}, err } diff --git a/node/config.go b/node/config.go index a9e52bc74e8c..9040c4c4559d 100644 --- a/node/config.go +++ b/node/config.go @@ -24,7 +24,6 @@ import ( "github.com/ava-labs/avalanchego/utils/profiler" "github.com/ava-labs/avalanchego/utils/set" "github.com/ava-labs/avalanchego/utils/timer" - "github.com/ava-labs/avalanchego/vms" ) type IPCConfig struct { @@ -186,7 +185,7 @@ type Config struct { ChainConfigs map[string]chains.ChainConfig `json:"-"` ChainAliases map[ids.ID][]string `json:"chainAliases"` - VMManager vms.Manager `json:"-"` + VMAliaser ids.Aliaser `json:"-"` // Halflife to use for the processing requests tracker. // Larger halflife --> usage metrics change more slowly. diff --git a/node/node.go b/node/node.go index bb433619e057..166f1e109525 100644 --- a/node/node.go +++ b/node/node.go @@ -70,6 +70,7 @@ import ( "github.com/ava-labs/avalanchego/utils/timer" "github.com/ava-labs/avalanchego/utils/wrappers" "github.com/ava-labs/avalanchego/version" + "github.com/ava-labs/avalanchego/vms" "github.com/ava-labs/avalanchego/vms/avm" "github.com/ava-labs/avalanchego/vms/nftfx" "github.com/ava-labs/avalanchego/vms/platformvm" @@ -94,8 +95,9 @@ var ( // Node is an instance of an Avalanche node. type Node struct { - Log logging.Logger - LogFactory logging.Factory + Log logging.Logger + VMFactoryLog logging.Logger + LogFactory logging.Factory // This node's unique ID used when communicating with other nodes // (in consensus, for example) @@ -176,6 +178,8 @@ type Node struct { MetricsRegisterer *prometheus.Registry MetricsGatherer metrics.MultiGatherer + VMManager vms.Manager + // VM endpoint registry VMRegistry registry.VMRegistry @@ -627,7 +631,7 @@ func (n *Node) addDefaultVMAliases() error { for vmID, aliases := range vmAliases { for _, alias := range aliases { - if err := n.Config.VMManager.Alias(vmID, alias); err != nil { + if err := n.Config.VMAliaser.Alias(vmID, alias); err != nil { return err } } @@ -695,7 +699,7 @@ func (n *Node) initChainManager(avaxAssetID ids.ID) error { StakingBLSKey: n.Config.StakingSigningKey, Log: n.Log, LogFactory: n.LogFactory, - VMManager: n.Config.VMManager, + VMManager: n.VMManager, DecisionAcceptorGroup: n.DecisionAcceptorGroup, ConsensusAcceptorGroup: n.ConsensusAcceptorGroup, DBManager: n.DBManager, @@ -755,9 +759,10 @@ func (n *Node) initVMs() error { } vmRegisterer := registry.NewVMRegisterer(registry.VMRegistererConfig{ - APIServer: n.APIServer, - Log: n.Log, - VMManager: n.Config.VMManager, + APIServer: n.APIServer, + Log: n.Log, + VMFactoryLog: n.VMFactoryLog, + VMManager: n.VMManager, }) // Register the VMs that Avalanche supports @@ -801,9 +806,9 @@ func (n *Node) initVMs() error { }, }), vmRegisterer.Register(context.TODO(), constants.EVMID, &coreth.Factory{}), - n.Config.VMManager.RegisterFactory(context.TODO(), secp256k1fx.ID, &secp256k1fx.Factory{}), - n.Config.VMManager.RegisterFactory(context.TODO(), nftfx.ID, &nftfx.Factory{}), - n.Config.VMManager.RegisterFactory(context.TODO(), propertyfx.ID, &propertyfx.Factory{}), + n.VMManager.RegisterFactory(context.TODO(), secp256k1fx.ID, &secp256k1fx.Factory{}), + n.VMManager.RegisterFactory(context.TODO(), nftfx.ID, &nftfx.Factory{}), + n.VMManager.RegisterFactory(context.TODO(), propertyfx.ID, &propertyfx.Factory{}), ) if errs.Errored() { return errs.Err @@ -816,7 +821,7 @@ func (n *Node) initVMs() error { n.VMRegistry = registry.NewVMRegistry(registry.VMRegistryConfig{ VMGetter: registry.NewVMGetter(registry.VMGetterConfig{ FileReader: filesystem.NewReader(), - Manager: n.Config.VMManager, + Manager: n.VMManager, PluginDirectory: n.Config.PluginDir, CPUTracker: n.resourceManager, RuntimeTracker: n.runtimeManager, @@ -920,7 +925,7 @@ func (n *Node) initAdminAPI() error { ProfileDir: n.Config.ProfilerConfig.Dir, LogFactory: n.LogFactory, NodeConfig: n.Config, - VMManager: n.Config.VMManager, + VMManager: n.VMManager, VMRegistry: n.VMRegistry, }, ) @@ -978,11 +983,11 @@ func (n *Node) initInfoAPI() error { AddPrimaryNetworkDelegatorFee: n.Config.AddPrimaryNetworkDelegatorFee, AddSubnetValidatorFee: n.Config.AddSubnetValidatorFee, AddSubnetDelegatorFee: n.Config.AddSubnetDelegatorFee, - VMManager: n.Config.VMManager, + VMManager: n.VMManager, }, n.Log, n.chainManager, - n.Config.VMManager, + n.VMManager, n.Config.NetworkConfig.MyIPPort, n.Net, primaryValidators, @@ -1235,12 +1240,19 @@ func (n *Node) Initialize( zap.Reflect("config", n.Config), ) + var err error + n.VMFactoryLog, err = logFactory.Make("vm-factory") + if err != nil { + return fmt.Errorf("problem creating vm logger: %w", err) + } + + n.VMManager = vms.NewManager(n.VMFactoryLog, config.VMAliaser) + if err := n.initBeacons(); err != nil { // Configure the beacons return fmt.Errorf("problem initializing node beacons: %w", err) } // Set up tracer - var err error n.tracer, err = trace.New(n.Config.TraceConfig) if err != nil { return fmt.Errorf("couldn't initialize tracer: %w", err) diff --git a/snow/consensus/snowball/tree.go b/snow/consensus/snowball/tree.go index 2ec00b82327b..834fff6c3540 100644 --- a/snow/consensus/snowball/tree.go +++ b/snow/consensus/snowball/tree.go @@ -179,7 +179,7 @@ func (u *unaryNode) DecidedPrefix() int { return u.decidedPrefix } -//nolint:gofmt,gofmpt,gofumpt,goimports // this comment is formatted as intended +//nolint:gofmt,gofumpt,goimports // this comment is formatted as intended // // This is by far the most complicated function in this algorithm. // The intuition is that this instance represents a series of consecutive unary diff --git a/utils/crypto/keychain/keychain.go b/utils/crypto/keychain/keychain.go index 28ee4ad0e738..3306bf6b129a 100644 --- a/utils/crypto/keychain/keychain.go +++ b/utils/crypto/keychain/keychain.go @@ -25,6 +25,7 @@ var ( // to sign a hash type Signer interface { SignHash([]byte) ([]byte, error) + Sign([]byte) ([]byte, error) Address() ids.ShortID } @@ -120,6 +121,7 @@ func (l *ledgerKeychain) Get(addr ids.ShortID) (Signer, bool) { }, true } +// expects to receive a hash of the unsigned tx bytes func (l *ledgerSigner) SignHash(b []byte) ([]byte, error) { // Sign using the address with index l.idx on the ledger device. The number // of returned signatures should be the same length as the provided indices. @@ -139,6 +141,26 @@ func (l *ledgerSigner) SignHash(b []byte) ([]byte, error) { return sigs[0], err } +// expects to receive the unsigned tx bytes +func (l *ledgerSigner) Sign(b []byte) ([]byte, error) { + // Sign using the address with index l.idx on the ledger device. The number + // of returned signatures should be the same length as the provided indices. + sigs, err := l.ledger.Sign(b, []uint32{l.idx}) + if err != nil { + return nil, err + } + + if sigsLen := len(sigs); sigsLen != 1 { + return nil, fmt.Errorf( + "%w. expected 1, got %d", + ErrInvalidNumSignatures, + sigsLen, + ) + } + + return sigs[0], err +} + func (l *ledgerSigner) Address() ids.ShortID { return l.addr } diff --git a/utils/crypto/keychain/ledger.go b/utils/crypto/keychain/ledger.go index c2d025d0c350..237a2ba42f1a 100644 --- a/utils/crypto/keychain/ledger.go +++ b/utils/crypto/keychain/ledger.go @@ -14,6 +14,6 @@ type Ledger interface { Address(displayHRP string, addressIndex uint32) (ids.ShortID, error) Addresses(addressIndices []uint32) ([]ids.ShortID, error) SignHash(hash []byte, addressIndices []uint32) ([][]byte, error) - // TODO: add SignTransaction + Sign(unsignedTxBytes []byte, addressIndices []uint32) ([][]byte, error) Disconnect() error } diff --git a/utils/crypto/keychain/mock_ledger.go b/utils/crypto/keychain/mock_ledger.go index fb37d6399863..0191e257cdd6 100644 --- a/utils/crypto/keychain/mock_ledger.go +++ b/utils/crypto/keychain/mock_ledger.go @@ -82,6 +82,21 @@ func (mr *MockLedgerMockRecorder) Disconnect() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockLedger)(nil).Disconnect)) } +// Sign mocks base method. +func (m *MockLedger) Sign(arg0 []byte, arg1 []uint32) ([][]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Sign", arg0, arg1) + ret0, _ := ret[0].([][]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Sign indicates an expected call of Sign. +func (mr *MockLedgerMockRecorder) Sign(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sign", reflect.TypeOf((*MockLedger)(nil).Sign), arg0, arg1) +} + // SignHash mocks base method. func (m *MockLedger) SignHash(arg0 []byte, arg1 []uint32) ([][]byte, error) { m.ctrl.T.Helper() diff --git a/utils/crypto/ledger/ledger.go b/utils/crypto/ledger/ledger.go index 8ff6f04361de..c31cb7f7929e 100644 --- a/utils/crypto/ledger/ledger.go +++ b/utils/crypto/ledger/ledger.go @@ -10,10 +10,15 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils/crypto/keychain" + "github.com/ava-labs/avalanchego/utils/hashing" "github.com/ava-labs/avalanchego/version" ) -const rootPath = "m/44'/9000'/0'" +const ( + rootPath = "m/44'/9000'/0'" + ledgerBufferLimit = 8192 + ledgerPathSize = 9 +) var _ keychain.Ledger = (*Ledger)(nil) @@ -79,6 +84,35 @@ func (l *Ledger) SignHash(hash []byte, addressIndices []uint32) ([][]byte, error return responses, nil } +func (l *Ledger) Sign(txBytes []byte, addressIndices []uint32) ([][]byte, error) { + // will pass to the ledger addressIndices both as signing paths and change paths + numSigningPaths := len(addressIndices) + numChangePaths := len(addressIndices) + if len(txBytes)+(numSigningPaths+numChangePaths)*ledgerPathSize > ledgerBufferLimit { + // There is a limit on the tx length that can be parsed by the ledger + // app. When the tx that is being signed is too large, we sign with hash + // instead. + // + // Ref: https://github.com/ava-labs/avalanche-wallet-sdk/blob/9a71f05e424e06b94eaccf21fd32d7983ed1b040/src/Wallet/Ledger/provider/ZondaxProvider.ts#L68 + unsignedHash := hashing.ComputeHash256(txBytes) + return l.SignHash(unsignedHash, addressIndices) + } + strIndices := convertToSigningPaths(addressIndices) + response, err := l.device.Sign(rootPath, strIndices, txBytes, strIndices) + if err != nil { + return nil, fmt.Errorf("%w: unable to sign transaction", err) + } + responses := make([][]byte, len(strIndices)) + for i, index := range strIndices { + sig, ok := response.Signature[index] + if !ok { + return nil, fmt.Errorf("missing signature %s", index) + } + responses[i] = sig + } + return responses, nil +} + func (l *Ledger) Version() (*version.Semantic, error) { resp, err := l.device.GetVersion() if err != nil { diff --git a/utils/logging/test_log.go b/utils/logging/test_log.go index ea08c53263b3..43c90456e300 100644 --- a/utils/logging/test_log.go +++ b/utils/logging/test_log.go @@ -4,7 +4,6 @@ package logging import ( - "errors" "io" "go.uber.org/zap" @@ -14,15 +13,13 @@ var ( // Discard is a mock WriterCloser that drops all writes and close requests Discard io.WriteCloser = discard{} - errNoLoggerWrite = errors.New("NoLogger can't write") - _ Logger = NoLog{} ) type NoLog struct{} -func (NoLog) Write([]byte) (int, error) { - return 0, errNoLoggerWrite +func (NoLog) Write(b []byte) (int, error) { + return len(b), nil } func (NoLog) Fatal(string, ...zap.Field) {} diff --git a/version/compatibility.json b/version/compatibility.json index 9024236ac044..6509b43ce4bb 100644 --- a/version/compatibility.json +++ b/version/compatibility.json @@ -1,6 +1,7 @@ { "24": [ - "v1.9.10" + "v1.9.10", + "v1.9.11" ], "23": [ "v1.9.9" diff --git a/version/constants.go b/version/constants.go index b2378b3270ba..a031cbaeae93 100644 --- a/version/constants.go +++ b/version/constants.go @@ -21,7 +21,7 @@ var ( Current = &Semantic{ Major: 1, Minor: 9, - Patch: 10, + Patch: 11, } CurrentApp = &Application{ Major: Current.Major, diff --git a/vms/manager.go b/vms/manager.go index 1c536e82fc93..28d0fe8e8c10 100644 --- a/vms/manager.go +++ b/vms/manager.go @@ -62,6 +62,8 @@ type manager struct { // alias of the VM. That is, [vmID].String() is an alias for [vmID]. ids.Aliaser + log logging.Logger + lock sync.RWMutex // Key: A VM's ID @@ -74,9 +76,10 @@ type manager struct { } // NewManager returns an instance of a VM manager -func NewManager() Manager { +func NewManager(log logging.Logger, aliaser ids.Aliaser) Manager { return &manager{ - Aliaser: ids.NewAliaser(), + Aliaser: aliaser, + log: log, factories: make(map[ids.ID]Factory), versions: make(map[ids.ID]string), } @@ -105,8 +108,7 @@ func (m *manager) RegisterFactory(ctx context.Context, vmID ids.ID, factory Fact m.factories[vmID] = factory - // TODO: Pass in a VM specific logger - vm, err := factory.New(logging.NoLog{}) + vm, err := factory.New(m.log) if err != nil { return err } diff --git a/vms/registry/vm_registerer.go b/vms/registry/vm_registerer.go index 2897c20bd376..4487d4620581 100644 --- a/vms/registry/vm_registerer.go +++ b/vms/registry/vm_registerer.go @@ -36,9 +36,10 @@ type registerer interface { // VMRegistererConfig configures settings for VMRegisterer. type VMRegistererConfig struct { - APIServer server.Server - Log logging.Logger - VMManager vms.Manager + APIServer server.Server + Log logging.Logger + VMFactoryLog logging.Logger + VMManager vms.Manager } type vmRegisterer struct { @@ -89,8 +90,7 @@ func (r *vmRegisterer) createStaticHandlers( vmID ids.ID, factory vms.Factory, ) (map[string]*common.HTTPHandler, error) { - // TODO: Pass in a VM specific logger - vm, err := factory.New(logging.NoLog{}) + vm, err := factory.New(r.config.VMFactoryLog) if err != nil { return nil, err } diff --git a/vms/registry/vm_registerer_test.go b/vms/registry/vm_registerer_test.go index db8180bfcd8f..8347e498afcc 100644 --- a/vms/registry/vm_registerer_test.go +++ b/vms/registry/vm_registerer_test.go @@ -433,9 +433,10 @@ func initRegistererTest(t *testing.T) *vmRegistererTestResources { mockLog := logging.NewMockLogger(ctrl) registerer := NewVMRegisterer(VMRegistererConfig{ - APIServer: mockServer, - Log: mockLog, - VMManager: mockManager, + APIServer: mockServer, + Log: mockLog, + VMFactoryLog: logging.NoLog{}, + VMManager: mockManager, }) mockLog.EXPECT().Error(gomock.Any(), gomock.Any()).AnyTimes() diff --git a/vms/rpcchainvm/runtime/subprocess/runtime.go b/vms/rpcchainvm/runtime/subprocess/runtime.go index d606426d10d3..c3ff52060545 100644 --- a/vms/rpcchainvm/runtime/subprocess/runtime.go +++ b/vms/rpcchainvm/runtime/subprocess/runtime.go @@ -137,7 +137,7 @@ func Bootstrap( if intitializer.err != nil { stopper.Stop(ctx) - return nil, nil, fmt.Errorf("%w: %v", runtime.ErrHandshakeFailed, err) + return nil, nil, fmt.Errorf("%w: %v", runtime.ErrHandshakeFailed, intitializer.err) } log.Info("plugin handshake succeeded", diff --git a/wallet/chain/p/signer_visitor.go b/wallet/chain/p/signer_visitor.go index cd07a0544c4b..29a2ca527e3c 100644 --- a/wallet/chain/p/signer_visitor.go +++ b/wallet/chain/p/signer_visitor.go @@ -14,7 +14,6 @@ import ( "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/utils/crypto/keychain" "github.com/ava-labs/avalanchego/utils/crypto/secp256k1" - "github.com/ava-labs/avalanchego/utils/hashing" "github.com/ava-labs/avalanchego/vms/components/avax" "github.com/ava-labs/avalanchego/vms/components/verify" "github.com/ava-labs/avalanchego/vms/platformvm/stakeable" @@ -266,7 +265,6 @@ func sign(tx *txs.Tx, txSigners [][]keychain.Signer) error { if err != nil { return fmt.Errorf("couldn't marshal unsigned tx: %w", err) } - unsignedHash := hashing.ComputeHash256(unsignedBytes) if expectedLen := len(txSigners); expectedLen != len(tx.Creds) { tx.Creds = make([]verify.Verifiable, expectedLen) @@ -309,7 +307,7 @@ func sign(tx *txs.Tx, txSigners [][]keychain.Signer) error { continue } - sig, err := signer.SignHash(unsignedHash) + sig, err := signer.Sign(unsignedBytes) if err != nil { return fmt.Errorf("problem signing tx: %w", err) } diff --git a/wallet/chain/x/signer.go b/wallet/chain/x/signer.go index 7c3bdbb62a95..13932c0c82f2 100644 --- a/wallet/chain/x/signer.go +++ b/wallet/chain/x/signer.go @@ -13,7 +13,6 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils/crypto/keychain" "github.com/ava-labs/avalanchego/utils/crypto/secp256k1" - "github.com/ava-labs/avalanchego/utils/hashing" "github.com/ava-labs/avalanchego/vms/avm/fxs" "github.com/ava-labs/avalanchego/vms/avm/txs" "github.com/ava-labs/avalanchego/vms/components/avax" @@ -262,7 +261,6 @@ func sign(tx *txs.Tx, creds []verify.Verifiable, txSigners [][]keychain.Signer) if err != nil { return fmt.Errorf("couldn't marshal unsigned tx: %w", err) } - unsignedHash := hashing.ComputeHash256(unsignedBytes) if expectedLen := len(txSigners); expectedLen != len(tx.Creds) { tx.Creds = make([]*fxs.FxCredential, expectedLen) @@ -318,7 +316,7 @@ func sign(tx *txs.Tx, creds []verify.Verifiable, txSigners [][]keychain.Signer) continue } - sig, err := signer.SignHash(unsignedHash) + sig, err := signer.Sign(unsignedBytes) if err != nil { return fmt.Errorf("problem signing tx: %w", err) } diff --git a/x/merkledb/README.md b/x/merkledb/README.md index d99f5842e729..6682a6d2921e 100644 --- a/x/merkledb/README.md +++ b/x/merkledb/README.md @@ -2,15 +2,10 @@ ## TODOs -- [ ] Simplify trieview rootID tracking to only track the direct parent's rootID. - [ ] Improve invariants around trieview commitment. Either: - - [ ] Guarantee atomicity of internal parent view commitments. - - [ ] Remove internal parent view commitments. - [ ] Consider allowing a child view to commit into a parent view without committing to the base DB. - [ ] Allow concurrent reads into the trieview. - [ ] Remove special casing around the root node from the physical structure of the hashed tree. -- [ ] Remove the implied prefix from the `dbNode`'s `child` -- [ ] Fix intermediate node eviction panic when encountering errors - [ ] Analyze performance impact of needing to skip intermediate nodes when generating range and change proofs - [ ] Consider moving nodes with values to a separate db prefix - [ ] Replace naive concurrent hashing with a more optimized implementation @@ -48,6 +43,10 @@ To reduce the depth of nodes in the trie, a `Merkle Node` utilizes path compress A `Merkle Node` holds the IDs of its children, its value, as well as any path extension. This simplifies some logic and allows all of the data about a node to be loaded in a single database read. This trades off a small amount of storage efficiency (some fields may be `nil` but are still stored for every node). +### Validity + +A `trieView` is built atop another trie, and that trie could change at any point. If it does, all descendants of the trie will be marked invalid before the edit of the trie occurs. If an operation is performed on an invalid trie, an ErrInvalid error will be returned instead of the expected result. When a view is committed, all of its sibling views (the views that share the same parent) are marked invalid and any child views of the view have their parent updated to exclude any committed views between them and the db. + ### Locking A `trieView` is built atop another trie, which may be the underlying `Database` or another `trieView`. @@ -55,7 +54,8 @@ It's important to guarantee atomicity/consistency of trie operations. That is, if a view method is executing, the views/database underneath the view shouldn't be changing. To prevent this, we need to use locking. -`trieView` has a `Mutex` named `lock` that's held when its methods are executing. +`trieView` has a `Mutex` named `lock` that's held when most of its methods are executing. +It also has a `Mutex` named `invalidationLock` that is held during methods that change the view's validity or tracking of child views' validity. Trie methods also grab the write `lock` for all views that its built atop, and a read lock for the underlying `Database`. The exception is `Commit`, which grabs a write lock for the `Database`. This is the only `trieView` method that modifies the underlying `Database`. @@ -65,7 +65,7 @@ To prevent deadlocks, `trieView` and `Database` never lock a view that is built That is, locking is always done from a view down to the underlying `Database`, never the other way around. In some of `Database`'s methods, we create a `trieView` and call unexported methods on it without locking it. We do so because the exported counterpart of the method read locks the `Database`, which is already locked. -This pattern is safe because the `Database` is locked, so no data under the view is changing, and nobody else has a reference to the view, so there can't be any concurrent access. +This pattern is safe because the `Database` is locked, so no data under the view is changing, and nobody else has a reference to the view, so there can't be any concurrent access. Additionally, any function that takes the `invalidationLock` should avoid taking the `trieView.lock` as this will likely trigger a deadlock as well. `Database` has a `RWMutex` named `lock`. Its read operations don't store data in a map, so a read lock suffices for read operations. `trieView`'s `Commit` method explicitly grabs this lock. diff --git a/x/merkledb/db.go b/x/merkledb/db.go index bd634ed5bfa8..28384cb00752 100644 --- a/x/merkledb/db.go +++ b/x/merkledb/db.go @@ -98,6 +98,14 @@ type Database struct { // The root of this trie. root *node + + // Valid children of this trie. + childViews []*trieView +} + +func (*Database) calculateIDs(context.Context) error { + // no-op as the db is always up to date with all ids + return nil } func newDatabase( @@ -113,6 +121,7 @@ func newDatabase( history: newTrieHistory(config.HistoryLength), tracer: config.Tracer, valueCache: cache.LRU[string, Maybe[[]byte]]{Size: config.ValueCacheSize}, + childViews: make([]*trieView, 0, defaultPreallocationSize), } // Note: trieDB.OnEviction is responsible for writing intermediary nodes to @@ -172,7 +181,7 @@ func (db *Database) rebuild(ctx context.Context) error { it := db.nodeDB.NewIterator() defer it.Release() - currentView, err := db.newView(ctx) + currentView, err := db.newUntrackedView(ctx) if err != nil { return err } @@ -183,10 +192,10 @@ func (db *Database) rebuild(ctx context.Context) error { ) for it.Next() { if currentViewSize >= viewSizeLimit { - if err := currentView.Commit(ctx); err != nil { + if err := currentView.commitToDB(ctx, nil); err != nil { return err } - currentView, err = db.newView(ctx) + currentView, err = db.newUntrackedView(ctx) if err != nil { return err } @@ -214,7 +223,7 @@ func (db *Database) rebuild(ctx context.Context) error { if err := it.Error(); err != nil { return err } - if err := currentView.Commit(ctx); err != nil { + if err := currentView.commitToDB(ctx, nil); err != nil { return err } return db.nodeDB.Compact(nil, nil) @@ -238,7 +247,7 @@ func (db *Database) CommitChangeProof(ctx context.Context, proof *ChangeProof) e if err != nil { return err } - return view.commit(ctx) + return view.commitToDB(ctx, nil) } // Commits the key/value pairs within the [proof] to the db. @@ -251,7 +260,7 @@ func (db *Database) CommitRangeProof(ctx context.Context, start []byte, proof *R if err != nil { return err } - return view.commit(ctx) + return view.commitToDB(ctx, nil) } func (db *Database) Compact(start []byte, limit []byte) error { @@ -389,7 +398,7 @@ func (db *Database) GetProof(ctx context.Context, key []byte) (*Proof, error) { // Returns a proof of the existence/non-existence of [key] in this trie. // Assumes [db.lock] is read locked. func (db *Database) getProof(ctx context.Context, key []byte) (*Proof, error) { - view, err := db.newView(ctx) + view, err := db.newUntrackedView(ctx) if err != nil { return nil, err } @@ -555,24 +564,33 @@ func (db *Database) NewView(ctx context.Context) (TrieView, error) { return db.NewPreallocatedView(ctx, defaultPreallocationSize) } +// Returns a new view that isn't tracked in [db.childViews]. +// For internal use only, namely in methods that create short-lived views. // Assumes [db.lock] is read locked. -func (db *Database) newView(ctx context.Context) (*trieView, error) { +func (db *Database) newUntrackedView(ctx context.Context) (*trieView, error) { return db.newPreallocatedView(ctx, defaultPreallocationSize) } -// Same as NewView except that the view will be preallocated to hold at least [estimatedSize] -// value changes at a time. If more changes are made, additional memory will be allocated. +// Returns a new view preallocated to hold at least [estimatedSize] value changes at a time. +// If more changes are made, additional memory will be allocated. +// The returned view is added to [db.childViews]. // Assumes [db.lock] isn't held. func (db *Database) NewPreallocatedView(ctx context.Context, estimatedSize int) (TrieView, error) { - db.lock.RLock() - defer db.lock.RUnlock() + db.lock.Lock() + defer db.lock.Unlock() - return newTrieView(ctx, db, nil, nil, estimatedSize) + newView, err := db.newPreallocatedView(ctx, estimatedSize) + if err != nil { + return nil, err + } + db.childViews = append(db.childViews, newView) + return newView, nil } // Assumes [db.lock] is read locked. +// Assumes that this view is temporary and doesn't require validity tracking func (db *Database) newPreallocatedView(ctx context.Context, estimatedSize int) (*trieView, error) { - return newTrieView(ctx, db, nil, nil, estimatedSize) + return newTrieView(ctx, db, db, nil, estimatedSize) } func (db *Database) Has(k []byte) (bool, error) { @@ -598,7 +616,7 @@ func (db *Database) Insert(ctx context.Context, k, v []byte) error { db.lock.Lock() defer db.lock.Unlock() - view, err := db.newView(ctx) + view, err := db.newUntrackedView(ctx) if err != nil { return err } @@ -606,7 +624,7 @@ func (db *Database) Insert(ctx context.Context, k, v []byte) error { if err := view.insert(ctx, k, v); err != nil { return err } - return view.commit(ctx) + return view.commitToDB(ctx, nil) } func (db *Database) NewBatch() database.Batch { @@ -684,7 +702,7 @@ func (db *Database) Remove(ctx context.Context, key []byte) error { db.lock.Lock() defer db.lock.Unlock() - view, err := db.newView(ctx) + view, err := db.newUntrackedView(ctx) if err != nil { return err } @@ -692,7 +710,7 @@ func (db *Database) Remove(ctx context.Context, key []byte) error { if err = view.remove(ctx, key); err != nil { return err } - return view.commit(ctx) + return view.commitToDB(ctx, nil) } // Assumes [db.lock] is held. @@ -701,12 +719,18 @@ func (db *Database) commitBatch(ops []database.BatchOp) error { if err != nil { return err } - return view.commit(context.Background()) + return view.commitToDB(context.Background(), nil) } -// Applies unwritten changes into the db. // Assumes [db.lock] is held. -func (db *Database) commitChanges(ctx context.Context, changes *changeSummary) error { +func (db *Database) commitChanges(ctx context.Context, trieToCommit *trieView) error { + if trieToCommit == nil { + return nil + } + if trieToCommit.isInvalid() { + return ErrInvalid + } + changes := trieToCommit.changes _, span := db.tracer.Start(ctx, "MerkleDB.commitChanges", oteltrace.WithAttributes( attribute.Int("nodesChanged", len(changes.nodes)), attribute.Int("valuesChanged", len(changes.values)), @@ -717,6 +741,16 @@ func (db *Database) commitChanges(ctx context.Context, changes *changeSummary) e return database.ErrClosed } + db.invalidateChildrenExcept(trieToCommit) + + // move any child views of the committed trie onto the db + for _, childView := range trieToCommit.childViews { + // It's safe to manipulate [childView.parentTrie] because we hold + // [db.lock] so all calls to [childView.lockStack] are blocking. + childView.parentTrie = db + db.childViews = append(db.childViews, childView) + } + if len(changes.nodes) == 0 { return nil } @@ -725,7 +759,6 @@ func (db *Database) commitChanges(ctx context.Context, changes *changeSummary) e if !ok { return errNoNewRoot } - changes.rootID = rootChange.after.id // commit any outstanding cache evicted nodes. // Note that we do this here because below we may Abort @@ -796,6 +829,30 @@ func (db *Database) commitChanges(ctx context.Context, changes *changeSummary) e return nil } +// Applies unwritten changes into the db. +// Assumes [db.lock] is held. +func (db *Database) commitToDB(ctx context.Context, trieToCommit *trieView) error { + return db.commitChanges(ctx, trieToCommit) +} + +// invalidate and remove any child views that aren't the exception +// Assumes [db.lock] is held. +func (db *Database) invalidateChildrenExcept(exception *trieView) { + isTrackedView := false + + for _, childView := range db.childViews { + if childView != exception { + childView.invalidate() + } else { + isTrackedView = true + } + } + db.childViews = make([]*trieView, 0, defaultPreallocationSize) + if isTrackedView { + db.childViews = append(db.childViews, exception) + } +} + func (db *Database) initializeRootIfNeeded(_ context.Context) (ids.ID, error) { // ensure that root exists nodeBytes, err := db.nodeDB.Get(rootKey) @@ -846,14 +903,14 @@ func (db *Database) getHistoricalViewForRangeProof( // looking for the trie's current root id, so return the trie unmodified if currentRootID == rootID { - return newTrieView(ctx, db, nil, nil, 100) + return newTrieView(ctx, db, db, nil, 100) } changeHistory, err := db.history.getChangesToGetToRoot(rootID, start, end) if err != nil { return nil, err } - return newTrieView(ctx, db, nil, changeHistory, len(changeHistory.nodes)) + return newTrieView(ctx, db, db, changeHistory, len(changeHistory.nodes)) } // Returns all of the keys in range [start, end] that aren't in [keySet]. diff --git a/x/merkledb/db_test.go b/x/merkledb/db_test.go index cd4826f2545f..383484ec29cf 100644 --- a/x/merkledb/db_test.go +++ b/x/merkledb/db_test.go @@ -28,16 +28,7 @@ func newNoopTracer() trace.Tracer { func Test_MerkleDB_DB_Interface(t *testing.T) { for _, test := range database.Tests { - db, err := New( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - HistoryLength: 300, - ValueCacheSize: minCacheSize, - NodeCacheSize: minCacheSize, - }, - ) + db, err := getBasicDB() require.NoError(t, err) test(t, db) } @@ -47,16 +38,7 @@ func Benchmark_MerkleDB_DBInterface(b *testing.B) { for _, size := range database.BenchmarkSizes { keys, values := database.SetupBenchmark(b, size[0], size[1], size[2]) for _, bench := range database.Benchmarks { - db, err := New( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - HistoryLength: 300, - ValueCacheSize: minCacheSize, - NodeCacheSize: minCacheSize, - }, - ) + db, err := getBasicDB() require.NoError(b, err) bench(b, db, "merkledb", keys, values) } @@ -87,7 +69,7 @@ func Test_MerkleDB_DB_Load_Root_From_DB(t *testing.T) { k := []byte(strconv.Itoa(i)) require.NoError(view.Insert(context.Background(), k, hashing.ComputeHash256(k))) } - require.NoError(view.Commit(context.Background())) + require.NoError(view.commitToDB(context.Background(), nil)) root, err := db.GetMerkleRoot(context.Background()) require.NoError(err) @@ -138,7 +120,7 @@ func Test_MerkleDB_DB_Rebuild(t *testing.T) { k := []byte(strconv.Itoa(i)) require.NoError(view.Insert(context.Background(), k, hashing.ComputeHash256(k))) } - require.NoError(view.Commit(context.Background())) + require.NoError(view.CommitToDB(context.Background())) root, err := db.GetMerkleRoot(context.Background()) require.NoError(err) @@ -221,17 +203,32 @@ func Test_MerkleDB_Value_Cache(t *testing.T) { require.ErrorIs(t, err, database.ErrNotFound) } +func Test_MerkleDB_Invalidate_Siblings_On_Commit(t *testing.T) { + dbTrie, err := getBasicDB() + require.NoError(t, err) + require.NotNil(t, dbTrie) + + viewToCommit, err := dbTrie.NewView(context.Background()) + require.NoError(t, err) + + sibling1, err := dbTrie.NewView(context.Background()) + require.NoError(t, err) + sibling2, err := dbTrie.NewView(context.Background()) + require.NoError(t, err) + + require.False(t, sibling1.(*trieView).isInvalid()) + require.False(t, sibling2.(*trieView).isInvalid()) + + require.NoError(t, viewToCommit.Insert(context.Background(), []byte{0}, []byte{0})) + require.NoError(t, viewToCommit.CommitToDB(context.Background())) + + require.True(t, sibling1.(*trieView).isInvalid()) + require.True(t, sibling2.(*trieView).isInvalid()) + require.False(t, viewToCommit.(*trieView).isInvalid()) +} + func Test_MerkleDB_Commit_Proof_To_Empty_Trie(t *testing.T) { - db, err := New( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - HistoryLength: 300, - ValueCacheSize: minCacheSize, - NodeCacheSize: minCacheSize, - }, - ) + db, err := getBasicDB() require.NoError(t, err) batch := db.NewBatch() err = batch.Put([]byte("key1"), []byte("1")) @@ -246,16 +243,7 @@ func Test_MerkleDB_Commit_Proof_To_Empty_Trie(t *testing.T) { proof, err := db.GetRangeProof(context.Background(), []byte("key1"), []byte("key3"), 10) require.NoError(t, err) - freshDB, err := New( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - HistoryLength: 300, - ValueCacheSize: minCacheSize, - NodeCacheSize: minCacheSize, - }, - ) + freshDB, err := getBasicDB() require.NoError(t, err) err = freshDB.CommitRangeProof(context.Background(), []byte("key1"), proof) @@ -273,16 +261,7 @@ func Test_MerkleDB_Commit_Proof_To_Empty_Trie(t *testing.T) { } func Test_MerkleDB_Commit_Proof_To_Filled_Trie(t *testing.T) { - db, err := New( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - HistoryLength: 300, - ValueCacheSize: minCacheSize, - NodeCacheSize: minCacheSize, - }, - ) + db, err := getBasicDB() require.NoError(t, err) batch := db.NewBatch() err = batch.Put([]byte("key1"), []byte("1")) @@ -297,16 +276,7 @@ func Test_MerkleDB_Commit_Proof_To_Filled_Trie(t *testing.T) { proof, err := db.GetRangeProof(context.Background(), []byte("key1"), []byte("key3"), 10) require.NoError(t, err) - freshDB, err := New( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - HistoryLength: 300, - ValueCacheSize: minCacheSize, - NodeCacheSize: minCacheSize, - }, - ) + freshDB, err := getBasicDB() require.NoError(t, err) batch = freshDB.NewBatch() err = batch.Put([]byte("key1"), []byte("3")) @@ -334,17 +304,31 @@ func Test_MerkleDB_Commit_Proof_To_Filled_Trie(t *testing.T) { require.Equal(t, oldRoot, freshRoot) } +func Test_MerkleDB_GetValues(t *testing.T) { + db, err := getBasicDB() + require.NoError(t, err) + + writeBasicBatch(t, db) + keys := [][]byte{{0}, {1}, {2}, {10}} + values, errors := db.GetValues(context.Background(), keys) + require.Len(t, values, len(keys)) + require.Len(t, errors, len(keys)) + + // first 3 have values + // last was not found + require.NoError(t, errors[0]) + require.NoError(t, errors[1]) + require.NoError(t, errors[2]) + require.ErrorIs(t, errors[3], database.ErrNotFound) + + require.Equal(t, []byte{0}, values[0]) + require.Equal(t, []byte{1}, values[1]) + require.Equal(t, []byte{2}, values[2]) + require.Nil(t, values[3]) +} + func Test_MerkleDB_InsertNil(t *testing.T) { - db, err := New( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - HistoryLength: 300, - ValueCacheSize: minCacheSize, - NodeCacheSize: minCacheSize, - }, - ) + db, err := getBasicDB() require.NoError(t, err) batch := db.NewBatch() err = batch.Put([]byte("key0"), nil) @@ -362,16 +346,7 @@ func Test_MerkleDB_InsertNil(t *testing.T) { } func Test_MerkleDB_InsertAndRetrieve(t *testing.T) { - db, err := New( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - HistoryLength: 300, - ValueCacheSize: minCacheSize, - NodeCacheSize: minCacheSize, - }, - ) + db, err := getBasicDB() require.NoError(t, err) // value hasn't been inserted so shouldn't exist @@ -389,16 +364,7 @@ func Test_MerkleDB_InsertAndRetrieve(t *testing.T) { } func Test_MerkleDB_HealthCheck(t *testing.T) { - db, err := New( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - HistoryLength: 300, - ValueCacheSize: minCacheSize, - NodeCacheSize: minCacheSize, - }, - ) + db, err := getBasicDB() require.NoError(t, err) val, err := db.HealthCheck(context.Background()) require.NoError(t, err) @@ -406,16 +372,7 @@ func Test_MerkleDB_HealthCheck(t *testing.T) { } func Test_MerkleDB_Overwrite(t *testing.T) { - db, err := New( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - HistoryLength: 300, - ValueCacheSize: minCacheSize, - NodeCacheSize: minCacheSize, - }, - ) + db, err := getBasicDB() require.NoError(t, err) err = db.Put([]byte("key"), []byte("value0")) @@ -434,16 +391,7 @@ func Test_MerkleDB_Overwrite(t *testing.T) { } func Test_MerkleDB_Delete(t *testing.T) { - db, err := New( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - HistoryLength: 300, - ValueCacheSize: minCacheSize, - NodeCacheSize: minCacheSize, - }, - ) + db, err := getBasicDB() require.NoError(t, err) err = db.Put([]byte("key"), []byte("value0")) @@ -462,22 +410,194 @@ func Test_MerkleDB_Delete(t *testing.T) { } func Test_MerkleDB_DeleteMissingKey(t *testing.T) { - db, err := New( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - HistoryLength: 300, - ValueCacheSize: minCacheSize, - NodeCacheSize: minCacheSize, - }, - ) + db, err := getBasicDB() require.NoError(t, err) err = db.Delete([]byte("key")) require.NoError(t, err) } +// Test that untracked views aren't persisted to [db.childViews]. +func TestDatabaseNewUntrackedView(t *testing.T) { + require := require.New(t) + + db, err := getBasicDB() + require.NoError(err) + + // Create a new untracked view. + view, err := db.newUntrackedView(context.Background()) + require.NoError(err) + require.Empty(db.childViews) + + // Write to the untracked view. + err = view.Insert(context.Background(), []byte{1}, []byte{1}) + require.NoError(err) + + // Commit the view + err = view.CommitToDB(context.Background()) + require.NoError(err) + + // The untracked view should not be tracked by the parent database. + require.Empty(db.childViews) +} + +// Test that tracked views are persisted to [db.childViews]. +func TestDatabaseNewPreallocatedViewTracked(t *testing.T) { + require := require.New(t) + + db, err := getBasicDB() + require.NoError(err) + + // Create a new tracked view. + view, err := db.NewPreallocatedView(context.Background(), 10) + require.NoError(err) + require.Len(db.childViews, 1) + + // Write to the view. + err = view.Insert(context.Background(), []byte{1}, []byte{1}) + require.NoError(err) + + // Commit the view + err = view.CommitToDB(context.Background()) + require.NoError(err) + + // The untracked view should be tracked by the parent database. + require.Contains(db.childViews, view) + require.Len(db.childViews, 1) +} + +func TestDatabaseCommitChanges(t *testing.T) { + require := require.New(t) + + db, err := getBasicDB() + require.NoError(err) + dbRoot := db.getMerkleRoot() + + // Committing a nil view should be a no-op. + err = db.commitChanges(context.Background(), nil) + require.NoError(err) + require.Equal(dbRoot, db.getMerkleRoot()) // Root didn't change + + // Committing an invalid view should fail. + invalidView := &trieView{ + invalidated: true, + } + err = db.commitChanges(context.Background(), invalidView) + require.ErrorIs(err, ErrInvalid) + + // Add key-value pairs to the database + err = db.Put([]byte{1}, []byte{1}) + require.NoError(err) + err = db.Put([]byte{2}, []byte{2}) + require.NoError(err) + + // Make a view and inser/delete a key-value pair. + view1Intf, err := db.NewView(context.Background()) + require.NoError(err) + view1, ok := view1Intf.(*trieView) + require.True(ok) + err = view1.Insert(context.Background(), []byte{3}, []byte{3}) + require.NoError(err) + err = view1.Remove(context.Background(), []byte{1}) + require.NoError(err) + view1Root, err := view1.getMerkleRoot(context.Background()) + require.NoError(err) + + // Make a second view + view2Intf, err := db.NewView(context.Background()) + require.NoError(err) + view2, ok := view2Intf.(*trieView) + require.True(ok) + + // Make a view atop a view + view3Intf, err := view1.NewView(context.Background()) + require.NoError(err) + view3, ok := view3Intf.(*trieView) + require.True(ok) + + // view3 + // | + // view1 view2 + // \ / + // db + + // Commit view1 + err = db.commitChanges(context.Background(), view1) + require.NoError(err) + + // Make sure the key-value pairs are correct. + _, err = db.Get([]byte{1}) + require.ErrorIs(err, database.ErrNotFound) + value, err := db.Get([]byte{2}) + require.NoError(err) + require.Equal([]byte{2}, value) + value, err = db.Get([]byte{3}) + require.NoError(err) + require.Equal([]byte{3}, value) + + // Make sure the root is right + require.Equal(view1Root, db.getMerkleRoot()) + + // Make sure view2 is invalid and view1 and view3 is valid. + require.False(view1.invalidated) + require.True(view2.invalidated) + require.False(view3.invalidated) + + // Make sure view2 isn't tracked by the database. + require.NotContains(db.childViews, view2) + + // Make sure view1 and view3 is tracked by the database. + require.Contains(db.childViews, view1) + require.Contains(db.childViews, view3) + + // Make sure view3 is now a child of db. + require.Equal(db, view3.parentTrie) +} + +func TestDatabaseInvalidateChildrenExcept(t *testing.T) { + require := require.New(t) + + db, err := getBasicDB() + require.NoError(err) + + // Create children + view1Intf, err := db.NewView(context.Background()) + require.NoError(err) + view1, ok := view1Intf.(*trieView) + require.True(ok) + + view2Intf, err := db.NewView(context.Background()) + require.NoError(err) + view2, ok := view2Intf.(*trieView) + require.True(ok) + + view3Intf, err := db.NewView(context.Background()) + require.NoError(err) + view3, ok := view3Intf.(*trieView) + require.True(ok) + + db.invalidateChildrenExcept(view1) + + // Make sure view1 is valid and view2 and view3 are invalid. + require.False(view1.invalidated) + require.True(view2.invalidated) + require.True(view3.invalidated) + require.Contains(db.childViews, view1) + require.Len(db.childViews, 1) + + db.invalidateChildrenExcept(nil) + + // Make sure all views are invalid. + require.True(view1.invalidated) + require.True(view2.invalidated) + require.True(view3.invalidated) + require.Empty(db.childViews) + + // Calling with an untracked view doesn't add the untracked view + db.invalidateChildrenExcept(view1) + require.Empty(db.childViews) +} + func Test_MerkleDB_Random_Insert_Ordering(t *testing.T) { totalState := 1000 var ( @@ -526,16 +646,7 @@ func Test_MerkleDB_Random_Insert_Ordering(t *testing.T) { } ops = append(ops, &testOperation{key: key, value: value}) } - db, err := New( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - HistoryLength: 300, - ValueCacheSize: minCacheSize, - NodeCacheSize: minCacheSize, - }, - ) + db, err := getBasicDB() require.NoError(t, err) result, err := applyOperations(db, ops) require.NoError(t, err) @@ -583,16 +694,7 @@ func Test_MerkleDB_RandomCases(t *testing.T) { require := require.New(t) for i := 150; i < 500; i += 10 { - db, err := New( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - HistoryLength: 300, - ValueCacheSize: minCacheSize, - NodeCacheSize: minCacheSize, - }, - ) + db, err := getBasicDB() require.NoError(err) r := rand.New(rand.NewSource(int64(i))) // #nosec G404 runRandDBTest(require, db, r, generate(require, r, i, .01)) @@ -602,16 +704,7 @@ func Test_MerkleDB_RandomCases(t *testing.T) { func Test_MerkleDB_RandomCases_InitialValues(t *testing.T) { require := require.New(t) - db, err := New( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - HistoryLength: 300, - ValueCacheSize: minCacheSize, - NodeCacheSize: minCacheSize, - }, - ) + db, err := getBasicDB() require.NoError(err) r := rand.New(rand.NewSource(int64(0))) // #nosec G404 runRandDBTest(require, db, r, generateInitialValues(require, r, 2000, 2500, 0.0)) diff --git a/x/merkledb/proof.go b/x/merkledb/proof.go index 81847a67725b..df9a6310678d 100644 --- a/x/merkledb/proof.go +++ b/x/merkledb/proof.go @@ -428,7 +428,7 @@ func (proof *ChangeProof) Verify( defer db.lock.RUnlock() // Don't need to lock [view] because nobody else has a reference to it. - view, err := db.newView(ctx) + view, err := db.newUntrackedView(ctx) if err != nil { return err } @@ -610,7 +610,7 @@ func valueOrHashMatches(value Maybe[[]byte], valueOrHash Maybe[[]byte]) bool { // Assumes [t]'s view stack is locked. func addPathInfo( ctx context.Context, - t TrieView, + t *trieView, proofPath []ProofNode, startPath path, endPath path, @@ -662,7 +662,7 @@ func addPathInfo( return nil } -func getEmptyTrieView(ctx context.Context) (TrieView, error) { +func getEmptyTrieView(ctx context.Context) (*trieView, error) { tracer, err := trace.New(trace.Config{Enabled: false}) if err != nil { return nil, err @@ -681,5 +681,5 @@ func getEmptyTrieView(ctx context.Context) (TrieView, error) { return nil, err } - return db.NewView(ctx) + return db.newUntrackedView(ctx) } diff --git a/x/merkledb/trie.go b/x/merkledb/trie.go index c6c2dff157dd..5e6522a76ddc 100644 --- a/x/merkledb/trie.go +++ b/x/merkledb/trie.go @@ -74,6 +74,16 @@ type Trie interface { // Insert a key/value pair into the Trie Insert(ctx context.Context, key, value []byte) error + + // ensures that all changed nodes have their new ids calculated + calculateIDs(ctx context.Context) error + + // commits changes in the trieToCommit into the current trie + commitChanges(ctx context.Context, trieToCommit *trieView) error + + // commits changes in the trieToCommit into the current trie + // then commits the combined changes down the stack until all changes in the stack commit to the database + commitToDB(ctx context.Context, trieToCommit *trieView) error } // Invariant: unexported methods (except lockStack) are only called when the @@ -84,16 +94,5 @@ type TrieView interface { // Commit the changes from this Trie into the database. // Any views that this Trie is built on will also be committed, starting at // the oldest. - Commit(ctx context.Context) error - - // Insert key/value into the trie and get back the node associated with the - // key. - // Updates nodes in the trie, whereas Trie.Insert records the key/value - // without updating any trie nodes. - insertIntoTrie(ctx context.Context, keyPath path, value Maybe[[]byte]) (*node, error) - - // Remove the key's value from the trie. - // Updates nodes in the trie, whereas Trie.Remove records the key without - // updating any trie nodes. - removeFromTrie(ctx context.Context, keyPath path) error + CommitToDB(ctx context.Context) error } diff --git a/x/merkledb/trie_test.go b/x/merkledb/trie_test.go index 6a46d3e2eeb9..4750744fded7 100644 --- a/x/merkledb/trie_test.go +++ b/x/merkledb/trie_test.go @@ -19,7 +19,7 @@ import ( func getNodeValue(t ReadOnlyTrie, key string) ([]byte, error) { if asTrieView, ok := t.(*trieView); ok { - if err := asTrieView.CalculateIDs(context.Background()); err != nil { + if err := asTrieView.calculateIDs(context.Background()); err != nil { return nil, err } path := newPath([]byte(key)) @@ -57,17 +57,7 @@ func getNodeValue(t ReadOnlyTrie, key string) ([]byte, error) { func TestTrieViewGetPathTo(t *testing.T) { require := require.New(t) - db, err := newDatabase( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - ValueCacheSize: 1000, - HistoryLength: 1000, - NodeCacheSize: 1000, - }, - &mockMetrics{}, - ) + db, err := getBasicDB() require.NoError(err) trieIntf, err := db.NewView(context.Background()) @@ -86,7 +76,7 @@ func TestTrieViewGetPathTo(t *testing.T) { key1 := []byte{0} err = trie.Insert(context.Background(), key1, []byte("value")) require.NoError(err) - err = trie.CalculateIDs(context.Background()) + err = trie.calculateIDs(context.Background()) require.NoError(err) path, err = trie.getPathTo(context.Background(), newPath(key1)) @@ -101,7 +91,7 @@ func TestTrieViewGetPathTo(t *testing.T) { key2 := []byte{0, 1} err = trie.Insert(context.Background(), key2, []byte("value")) require.NoError(err) - err = trie.CalculateIDs(context.Background()) + err = trie.calculateIDs(context.Background()) require.NoError(err) path, err = trie.getPathTo(context.Background(), newPath(key2)) @@ -115,7 +105,7 @@ func TestTrieViewGetPathTo(t *testing.T) { key3 := []byte{255} err = trie.Insert(context.Background(), key3, []byte("value")) require.NoError(err) - err = trie.CalculateIDs(context.Background()) + err = trie.calculateIDs(context.Background()) require.NoError(err) path, err = trie.getPathTo(context.Background(), newPath(key3)) @@ -149,131 +139,72 @@ func TestTrieViewGetPathTo(t *testing.T) { require.Equal(trie.root, path[0]) } -func Test_Trie_Partial_Commit_Leaves_Valid_Tries(t *testing.T) { - dbTrie, err := newDatabase( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - ValueCacheSize: 1000, - HistoryLength: 1000, - NodeCacheSize: 1000, - }, - &mockMetrics{}, - ) +func Test_Trie_ViewOnCommitedView(t *testing.T) { + dbTrie, err := getBasicDB() require.NoError(t, err) require.NotNil(t, dbTrie) - trie2, err := dbTrie.NewView(context.Background()) + committedTrie, err := dbTrie.NewView(context.Background()) require.NoError(t, err) - err = trie2.Insert(context.Background(), []byte("key"), []byte("value")) + err = committedTrie.Insert(context.Background(), []byte{0}, []byte{0}) require.NoError(t, err) - trie3, err := trie2.NewView(context.Background()) - require.NoError(t, err) - err = trie3.Insert(context.Background(), []byte("key1"), []byte("value1")) - require.NoError(t, err) - - trie4, err := trie3.NewView(context.Background()) - require.NoError(t, err) - err = trie4.Insert(context.Background(), []byte("key2"), []byte("value2")) - require.NoError(t, err) + require.NoError(t, committedTrie.CommitToDB(context.Background())) - trie5, err := trie4.NewView(context.Background()) - require.NoError(t, err) - err = trie5.Insert(context.Background(), []byte("key3"), []byte("value3")) + newView, err := committedTrie.NewView(context.Background()) require.NoError(t, err) - err = trie3.Commit(context.Background()) + err = newView.Insert(context.Background(), []byte{1}, []byte{1}) require.NoError(t, err) + require.NoError(t, newView.CommitToDB(context.Background())) - root, err := trie3.GetMerkleRoot(context.Background()) + val0, err := dbTrie.GetValue(context.Background(), []byte{0}) require.NoError(t, err) - - dbRoot, err := dbTrie.GetMerkleRoot(context.Background()) + require.Equal(t, []byte{0}, val0) + val1, err := dbTrie.GetValue(context.Background(), []byte{1}) require.NoError(t, err) - - require.Equal(t, root, dbRoot) + require.Equal(t, []byte{1}, val1) } -func Test_Trie_Collapse_After_Commit(t *testing.T) { - dbTrie, err := newDatabase( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - ValueCacheSize: 1000, - HistoryLength: 1000, - NodeCacheSize: 1000, - }, - &mockMetrics{}, - ) +func Test_Trie_Partial_Commit_Leaves_Valid_Tries(t *testing.T) { + dbTrie, err := getBasicDB() require.NoError(t, err) require.NotNil(t, dbTrie) - trie1 := Trie(dbTrie) - trie2, err := trie1.NewView(context.Background()) + trie2, err := dbTrie.NewView(context.Background()) require.NoError(t, err) err = trie2.Insert(context.Background(), []byte("key"), []byte("value")) require.NoError(t, err) - trie2Root, err := trie2.GetMerkleRoot(context.Background()) - require.NoError(t, err) trie3, err := trie2.NewView(context.Background()) require.NoError(t, err) err = trie3.Insert(context.Background(), []byte("key1"), []byte("value1")) require.NoError(t, err) - trie3Root, err := trie3.GetMerkleRoot(context.Background()) - require.NoError(t, err) trie4, err := trie3.NewView(context.Background()) require.NoError(t, err) err = trie4.Insert(context.Background(), []byte("key2"), []byte("value2")) require.NoError(t, err) - trie4Root, err := trie4.GetMerkleRoot(context.Background()) + + trie5, err := trie4.NewView(context.Background()) + require.NoError(t, err) + err = trie5.Insert(context.Background(), []byte("key3"), []byte("value3")) require.NoError(t, err) - err = trie4.Commit(context.Background()) + err = trie3.CommitToDB(context.Background()) require.NoError(t, err) - root, err := trie4.GetMerkleRoot(context.Background()) + root, err := trie3.GetMerkleRoot(context.Background()) require.NoError(t, err) dbRoot, err := dbTrie.GetMerkleRoot(context.Background()) require.NoError(t, err) require.Equal(t, root, dbRoot) - - // ensure each root is in the history - require.Contains(t, dbTrie.history.lastChanges, trie2Root) - require.Contains(t, dbTrie.history.lastChanges, trie3Root) - require.Contains(t, dbTrie.history.lastChanges, trie4Root) - - // ensure that they are in the correct order - _, _ = dbTrie.history.history.DeleteMin() // First one is root; ignore - got, ok := dbTrie.history.history.DeleteMin() - require.True(t, ok) - require.Equal(t, trie2Root, got.rootID) - got, ok = dbTrie.history.history.DeleteMin() - require.True(t, ok) - require.Equal(t, trie3Root, got.rootID) - got, ok = dbTrie.history.history.DeleteMin() - require.True(t, ok) - require.Equal(t, trie4Root, got.rootID) } func Test_Trie_WriteToDB(t *testing.T) { - dbTrie, err := newDatabase( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - ValueCacheSize: 1000, - HistoryLength: 1000, - NodeCacheSize: 1000, - }, - &mockMetrics{}, - ) + dbTrie, err := getBasicDB() require.NoError(t, err) require.NotNil(t, dbTrie) trie, err := dbTrie.NewView(context.Background()) @@ -292,7 +223,7 @@ func Test_Trie_WriteToDB(t *testing.T) { require.NoError(t, err) require.Equal(t, []byte("value"), value) - err = trie.Commit(context.Background()) + err = trie.CommitToDB(context.Background()) require.NoError(t, err) p := newPath([]byte("key")) rawBytes, err := dbTrie.nodeDB.Get(p.Bytes()) @@ -303,17 +234,7 @@ func Test_Trie_WriteToDB(t *testing.T) { } func Test_Trie_InsertAndRetrieve(t *testing.T) { - dbTrie, err := newDatabase( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - ValueCacheSize: 1000, - HistoryLength: 1000, - NodeCacheSize: 1000, - }, - &mockMetrics{}, - ) + dbTrie, err := getBasicDB() require.NoError(t, err) require.NotNil(t, dbTrie) trie := Trie(dbTrie) @@ -333,17 +254,7 @@ func Test_Trie_InsertAndRetrieve(t *testing.T) { } func Test_Trie_Overwrite(t *testing.T) { - dbTrie, err := newDatabase( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - ValueCacheSize: 1000, - HistoryLength: 1000, - NodeCacheSize: 1000, - }, - &mockMetrics{}, - ) + dbTrie, err := getBasicDB() require.NoError(t, err) require.NotNil(t, dbTrie) trie := Trie(dbTrie) @@ -364,17 +275,7 @@ func Test_Trie_Overwrite(t *testing.T) { } func Test_Trie_Delete(t *testing.T) { - dbTrie, err := newDatabase( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - ValueCacheSize: 1000, - HistoryLength: 1000, - NodeCacheSize: 1000, - }, - &mockMetrics{}, - ) + dbTrie, err := getBasicDB() require.NoError(t, err) require.NotNil(t, dbTrie) trie := Trie(dbTrie) @@ -395,17 +296,7 @@ func Test_Trie_Delete(t *testing.T) { } func Test_Trie_DeleteMissingKey(t *testing.T) { - trie, err := newDatabase( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - ValueCacheSize: 1000, - HistoryLength: 1000, - NodeCacheSize: 1000, - }, - &mockMetrics{}, - ) + trie, err := getBasicDB() require.NoError(t, err) require.NotNil(t, trie) @@ -414,17 +305,7 @@ func Test_Trie_DeleteMissingKey(t *testing.T) { } func Test_Trie_ExpandOnKeyPath(t *testing.T) { - dbTrie, err := newDatabase( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - ValueCacheSize: 1000, - HistoryLength: 1000, - NodeCacheSize: 1000, - }, - &mockMetrics{}, - ) + dbTrie, err := getBasicDB() require.NoError(t, err) require.NotNil(t, dbTrie) trie := Trie(dbTrie) @@ -464,17 +345,7 @@ func Test_Trie_ExpandOnKeyPath(t *testing.T) { } func Test_Trie_CompressedPaths(t *testing.T) { - dbTrie, err := newDatabase( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - ValueCacheSize: 1000, - HistoryLength: 1000, - NodeCacheSize: 1000, - }, - &mockMetrics{}, - ) + dbTrie, err := getBasicDB() require.NoError(t, err) require.NotNil(t, dbTrie) trie := Trie(dbTrie) @@ -514,17 +385,7 @@ func Test_Trie_CompressedPaths(t *testing.T) { } func Test_Trie_SplitBranch(t *testing.T) { - dbTrie, err := newDatabase( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - ValueCacheSize: 1000, - HistoryLength: 1000, - NodeCacheSize: 1000, - }, - &mockMetrics{}, - ) + dbTrie, err := getBasicDB() require.NoError(t, err) require.NotNil(t, dbTrie) trie := Trie(dbTrie) @@ -545,17 +406,7 @@ func Test_Trie_SplitBranch(t *testing.T) { } func Test_Trie_HashCountOnBranch(t *testing.T) { - dbTrie, err := newDatabase( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - ValueCacheSize: 1000, - HistoryLength: 1000, - NodeCacheSize: 1000, - }, - &mockMetrics{}, - ) + dbTrie, err := getBasicDB() require.NoError(t, err) require.NotNil(t, dbTrie) trie := Trie(dbTrie) @@ -572,17 +423,7 @@ func Test_Trie_HashCountOnBranch(t *testing.T) { } func Test_Trie_HashCountOnDelete(t *testing.T) { - trie, err := newDatabase( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - ValueCacheSize: 1000, - HistoryLength: 1000, - NodeCacheSize: 1000, - }, - &mockMetrics{}, - ) + trie, err := getBasicDB() require.NoError(t, err) require.NotNil(t, trie) @@ -608,7 +449,7 @@ func Test_Trie_HashCountOnDelete(t *testing.T) { require.NoError(t, err) err = view.Remove(context.Background(), []byte("key")) require.NoError(t, err) - err = view.Commit(context.Background()) + err = view.CommitToDB(context.Background()) require.NoError(t, err) // the root is the only updated node so only one new hash @@ -616,17 +457,7 @@ func Test_Trie_HashCountOnDelete(t *testing.T) { } func Test_Trie_NoExistingResidual(t *testing.T) { - dbTrie, err := newDatabase( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - ValueCacheSize: 1000, - HistoryLength: 1000, - NodeCacheSize: 1000, - }, - &mockMetrics{}, - ) + dbTrie, err := getBasicDB() require.NoError(t, err) require.NotNil(t, dbTrie) trie := Trie(dbTrie) @@ -657,18 +488,109 @@ func Test_Trie_NoExistingResidual(t *testing.T) { require.Equal(t, []byte("4"), value) } +func Test_Trie_CommitChanges(t *testing.T) { + require := require.New(t) + + db, err := getBasicDB() + require.NoError(err) + + view1Intf, err := db.NewView(context.Background()) + require.NoError(err) + view1, ok := view1Intf.(*trieView) + require.True(ok) + + err = view1.Insert(context.Background(), []byte{1}, []byte{1}) + require.NoError(err) + + // view1 + // | + // db + + // Case: Committing to an invalid view + view1.invalidated = true + err = view1.commitChanges(context.Background(), &trieView{}) + require.ErrorIs(err, ErrInvalid) + view1.invalidated = false // Reset + + // Case: Committing a nil view is a no-op + oldRoot, err := view1.getMerkleRoot(context.Background()) + require.NoError(err) + err = view1.commitChanges(context.Background(), nil) + require.NoError(err) + newRoot, err := view1.getMerkleRoot(context.Background()) + require.NoError(err) + require.Equal(oldRoot, newRoot) + + // Case: Committing a view with the wrong parent. + err = view1.commitChanges(context.Background(), &trieView{}) + require.ErrorIs(err, ErrViewIsNotAChild) + + // Case: Committing a view which is invalid + err = view1.commitChanges(context.Background(), &trieView{ + parentTrie: view1, + invalidated: true, + }) + require.ErrorIs(err, ErrInvalid) + + // Make more views atop the existing one + view2Intf, err := view1.NewView(context.Background()) + require.NoError(err) + view2, ok := view2Intf.(*trieView) + require.True(ok) + + err = view2.Insert(context.Background(), []byte{2}, []byte{2}) + require.NoError(err) + err = view2.Remove(context.Background(), []byte{1}) + require.NoError(err) + + view2Root, err := view2.getMerkleRoot(context.Background()) + require.NoError(err) + + // view1 has 1 --> 1 + // view2 has 2 --> 2 + + view3Intf, err := view1.NewView(context.Background()) + require.NoError(err) + view3, ok := view3Intf.(*trieView) + require.True(ok) + + view4Intf, err := view2.NewView(context.Background()) + require.NoError(err) + view4, ok := view4Intf.(*trieView) + require.True(ok) + + // view4 + // | + // view2 view3 + // | / + // view1 + // | + // db + + // Commit view2 to view1 + err = view1.commitChanges(context.Background(), view2) + require.NoError(err) + + // All siblings of view2 should be invalidated + require.True(view3.invalidated) + + // Children of view2 are now children of view1 + require.Equal(view1, view4.parentTrie) + require.Contains(view1.childViews, view4) + + // Value changes from view2 are reflected in view1 + newView1Root, err := view1.getMerkleRoot(context.Background()) + require.NoError(err) + require.Equal(view2Root, newView1Root) + _, err = view1.GetValue(context.Background(), []byte{1}) + require.ErrorIs(err, database.ErrNotFound) + got, err := view1.GetValue(context.Background(), []byte{2}) + require.NoError(err) + require.Equal([]byte{2}, got) +} + func Test_Trie_BatchApply(t *testing.T) { - dbTrie, err := newDatabase( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - ValueCacheSize: 1000, - HistoryLength: 1000, - NodeCacheSize: 1000, - }, - &mockMetrics{}, - ) + dbTrie, err := getBasicDB() require.NoError(t, err) require.NotNil(t, dbTrie) trie, err := dbTrie.NewView(context.Background()) @@ -697,17 +619,7 @@ func Test_Trie_BatchApply(t *testing.T) { } func Test_Trie_ChainDeletion(t *testing.T) { - trie, err := newDatabase( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - ValueCacheSize: 1000, - HistoryLength: 1000, - NodeCacheSize: 1000, - }, - &mockMetrics{}, - ) + trie, err := getBasicDB() require.NoError(t, err) require.NotNil(t, trie) newTrie, err := trie.NewView(context.Background()) @@ -721,7 +633,7 @@ func Test_Trie_ChainDeletion(t *testing.T) { require.NoError(t, err) err = newTrie.Insert(context.Background(), []byte("key1"), []byte("value3")) require.NoError(t, err) - err = newTrie.(*trieView).CalculateIDs(context.Background()) + err = newTrie.(*trieView).calculateIDs(context.Background()) require.NoError(t, err) root, err := newTrie.getNode(context.Background(), EmptyPath) require.NoError(t, err) @@ -735,7 +647,7 @@ func Test_Trie_ChainDeletion(t *testing.T) { require.NoError(t, err) err = newTrie.Remove(context.Background(), []byte("key1")) require.NoError(t, err) - err = newTrie.(*trieView).CalculateIDs(context.Background()) + err = newTrie.(*trieView).calculateIDs(context.Background()) require.NoError(t, err) root, err = newTrie.getNode(context.Background(), EmptyPath) require.NoError(t, err) @@ -743,175 +655,62 @@ func Test_Trie_ChainDeletion(t *testing.T) { require.Equal(t, 0, len(root.children)) } -func Test_Trie_NodeCollapse(t *testing.T) { - dbTrie, err := newDatabase( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - ValueCacheSize: 1000, - HistoryLength: 1000, - NodeCacheSize: 1000, - }, - &mockMetrics{}, - ) +func Test_Trie_Invalidate_Children_On_Edits(t *testing.T) { + dbTrie, err := getBasicDB() require.NoError(t, err) require.NotNil(t, dbTrie) - trie, err := dbTrie.NewView(context.Background()) - require.NoError(t, err) - - err = trie.Insert(context.Background(), []byte("k"), []byte("value0")) - require.NoError(t, err) - err = trie.Insert(context.Background(), []byte("ke"), []byte("value1")) - require.NoError(t, err) - err = trie.Insert(context.Background(), []byte("key"), []byte("value2")) - require.NoError(t, err) - err = trie.Insert(context.Background(), []byte("key1"), []byte("value3")) - require.NoError(t, err) - err = trie.Insert(context.Background(), []byte("key2"), []byte("value4")) - require.NoError(t, err) - - err = trie.(*trieView).CalculateIDs(context.Background()) - require.NoError(t, err) - root, err := trie.getNode(context.Background(), EmptyPath) - require.NoError(t, err) - require.Equal(t, 1, len(root.children)) - root, err = trie.getNode(context.Background(), EmptyPath) - require.NoError(t, err) - require.Equal(t, 1, len(root.children)) - - firstNode, err := trie.getNode(context.Background(), root.getSingleChildPath()) + trie, err := dbTrie.NewView(context.Background()) require.NoError(t, err) - require.Equal(t, 1, len(firstNode.children)) - // delete the middle values - err = trie.Remove(context.Background(), []byte("k")) + childTrie1, err := trie.NewView(context.Background()) require.NoError(t, err) - err = trie.Remove(context.Background(), []byte("ke")) + childTrie2, err := trie.NewView(context.Background()) require.NoError(t, err) - err = trie.Remove(context.Background(), []byte("key")) + childTrie3, err := trie.NewView(context.Background()) require.NoError(t, err) - err = trie.(*trieView).CalculateIDs(context.Background()) - require.NoError(t, err) + require.False(t, childTrie1.(*trieView).isInvalid()) + require.False(t, childTrie2.(*trieView).isInvalid()) + require.False(t, childTrie3.(*trieView).isInvalid()) - root, err = trie.getNode(context.Background(), EmptyPath) + err = trie.Insert(context.Background(), []byte{0}, []byte{0}) require.NoError(t, err) - require.Equal(t, 1, len(root.children)) - firstNode, err = trie.getNode(context.Background(), root.getSingleChildPath()) - require.NoError(t, err) - require.Equal(t, 2, len(firstNode.children)) + require.True(t, childTrie1.(*trieView).isInvalid()) + require.True(t, childTrie2.(*trieView).isInvalid()) + require.True(t, childTrie3.(*trieView).isInvalid()) } -func Test_Trie_Duplicate_Commit(t *testing.T) { - // create two views with the same changes - // create a view on top of one of them - // commit the other duplicate view - // should still be able to commit the view on top of the uncommitted duplicate - - dbTrie, err := newDatabase( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - ValueCacheSize: 1000, - HistoryLength: 1000, - NodeCacheSize: 1000, - }, - &mockMetrics{}, - ) +func Test_Trie_Invalidate_Siblings_On_Commit(t *testing.T) { + dbTrie, err := getBasicDB() require.NoError(t, err) require.NotNil(t, dbTrie) - trie, err := dbTrie.NewView(context.Background()) - require.NoError(t, err) - - err = trie.Insert(context.Background(), []byte("k"), []byte("value0")) - require.NoError(t, err) - err = trie.Insert(context.Background(), []byte("ke"), []byte("value1")) - require.NoError(t, err) - err = trie.Insert(context.Background(), []byte("key"), []byte("value2")) - require.NoError(t, err) - err = trie.Insert(context.Background(), []byte("key1"), []byte("value3")) - require.NoError(t, err) - err = trie.Insert(context.Background(), []byte("key2"), []byte("value4")) - require.NoError(t, err) - committedView, err := trie.NewView(context.Background()) + baseView, err := dbTrie.NewView(context.Background()) require.NoError(t, err) - uncommittedView, err := trie.NewView(context.Background()) + viewToCommit, err := baseView.NewView(context.Background()) require.NoError(t, err) - err = committedView.Insert(context.Background(), []byte("k2"), []byte("value02")) + sibling1, err := baseView.NewView(context.Background()) require.NoError(t, err) - - err = uncommittedView.Insert(context.Background(), []byte("k2"), []byte("value02")) + sibling2, err := baseView.NewView(context.Background()) require.NoError(t, err) - err = committedView.Commit(context.Background()) - require.NoError(t, err) - - committedRoot, err := committedView.GetMerkleRoot(context.Background()) - require.NoError(t, err) - - uncommittedRoot, err := uncommittedView.GetMerkleRoot(context.Background()) - require.NoError(t, err) + require.False(t, sibling1.(*trieView).isInvalid()) + require.False(t, sibling2.(*trieView).isInvalid()) - require.Equal(t, committedRoot, uncommittedRoot) + require.NoError(t, viewToCommit.Insert(context.Background(), []byte{0}, []byte{0})) + require.NoError(t, viewToCommit.CommitToDB(context.Background())) - newView, err := uncommittedView.NewView(context.Background()) - require.NoError(t, err) - - err = newView.Insert(context.Background(), []byte("k3"), []byte("value03")) - require.NoError(t, err) - - // ok because uncommittedView's root has already been committed by committedView - err = newView.Commit(context.Background()) - require.NoError(t, err) + require.True(t, sibling1.(*trieView).isInvalid()) + require.True(t, sibling2.(*trieView).isInvalid()) + require.False(t, viewToCommit.(*trieView).isInvalid()) } -func Test_Trie_ChangedRoot(t *testing.T) { - dbTrie, err := newDatabase( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - ValueCacheSize: 1000, - HistoryLength: 1000, - NodeCacheSize: 1000, - }, - &mockMetrics{}, - ) - require.NoError(t, err) - require.NotNil(t, dbTrie) - err = dbTrie.Insert(context.Background(), []byte("key1"), []byte("value1")) - require.NoError(t, err) - trie, err := dbTrie.NewView(context.Background()) - require.NoError(t, err) - err = trie.Insert(context.Background(), []byte("key2"), []byte("value2")) - require.NoError(t, err) - - err = dbTrie.Insert(context.Background(), []byte("key3"), []byte("value3")) - require.NoError(t, err) - - _, err = trie.GetValue(context.Background(), []byte("key3")) - require.ErrorIs(t, err, ErrChangedBaseRoot) -} - -func Test_Trie_CommittedView_Validate(t *testing.T) { - dbTrie, err := newDatabase( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - ValueCacheSize: 1000, - HistoryLength: 1000, - NodeCacheSize: 1000, - }, - &mockMetrics{}, - ) +func Test_Trie_NodeCollapse(t *testing.T) { + dbTrie, err := getBasicDB() require.NoError(t, err) require.NotNil(t, dbTrie) trie, err := dbTrie.NewView(context.Background()) @@ -919,95 +718,47 @@ func Test_Trie_CommittedView_Validate(t *testing.T) { err = trie.Insert(context.Background(), []byte("k"), []byte("value0")) require.NoError(t, err) - - committedView, err := trie.NewView(context.Background()) - require.NoError(t, err) - - err = committedView.Commit(context.Background()) - require.NoError(t, err) - - err = committedView.(*trieView).validateDBRoot(context.Background()) - require.NoError(t, err) - - err = dbTrie.Insert(context.Background(), []byte("k2"), []byte("value02")) + err = trie.Insert(context.Background(), []byte("ke"), []byte("value1")) require.NoError(t, err) - - err = committedView.(*trieView).validateDBRoot(context.Background()) - require.ErrorIs(t, err, ErrChangedBaseRoot) -} - -func Test_Trie_OtherViewCommitBeforeValidate(t *testing.T) { - dbTrie, err := newDatabase( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - ValueCacheSize: 1000, - HistoryLength: 1000, - NodeCacheSize: 1000, - }, - &mockMetrics{}, - ) + err = trie.Insert(context.Background(), []byte("key"), []byte("value2")) require.NoError(t, err) - require.NotNil(t, dbTrie) - trie, err := dbTrie.NewView(context.Background()) + err = trie.Insert(context.Background(), []byte("key1"), []byte("value3")) require.NoError(t, err) - - err = trie.Insert(context.Background(), []byte("k"), []byte("value0")) + err = trie.Insert(context.Background(), []byte("key2"), []byte("value4")) require.NoError(t, err) - committedView, err := trie.NewView(context.Background()) + err = trie.(*trieView).calculateIDs(context.Background()) require.NoError(t, err) - - uncommittedView, err := committedView.NewView(context.Background()) + root, err := trie.getNode(context.Background(), EmptyPath) require.NoError(t, err) + require.Equal(t, 1, len(root.children)) - err = uncommittedView.Insert(context.Background(), []byte("k2"), []byte("value02")) + root, err = trie.getNode(context.Background(), EmptyPath) require.NoError(t, err) + require.Equal(t, 1, len(root.children)) - err = uncommittedView.Insert(context.Background(), []byte("k3"), []byte("value03")) + firstNode, err := trie.getNode(context.Background(), root.getSingleChildPath()) require.NoError(t, err) + require.Equal(t, 1, len(firstNode.children)) - newView, err := uncommittedView.NewView(context.Background()) + // delete the middle values + err = trie.Remove(context.Background(), []byte("k")) require.NoError(t, err) - - err = committedView.Commit(context.Background()) + err = trie.Remove(context.Background(), []byte("ke")) require.NoError(t, err) - - // should still be valid because db is at root for a view in this view's viewstack - err = newView.(*trieView).validateDBRoot(context.Background()) + err = trie.Remove(context.Background(), []byte("key")) require.NoError(t, err) -} -func Test_Trie_ChangeLock(t *testing.T) { - dbTrie, err := newDatabase( - context.Background(), - memdb.New(), - Config{ - Tracer: newNoopTracer(), - ValueCacheSize: 1000, - HistoryLength: 1000, - NodeCacheSize: 1000, - }, - &mockMetrics{}, - ) - require.NoError(t, err) - require.NotNil(t, dbTrie) - err = dbTrie.Insert(context.Background(), []byte("key1"), []byte("value1")) - require.NoError(t, err) - trie, err := dbTrie.NewView(context.Background()) - require.NoError(t, err) - err = trie.Insert(context.Background(), []byte("key2"), []byte("value2")) + err = trie.(*trieView).calculateIDs(context.Background()) require.NoError(t, err) - higherView, err := trie.NewView(context.Background()) + root, err = trie.getNode(context.Background(), EmptyPath) require.NoError(t, err) + require.Equal(t, 1, len(root.children)) - err = higherView.Insert(context.Background(), []byte("key3"), []byte("value3")) + firstNode, err = trie.getNode(context.Background(), root.getSingleChildPath()) require.NoError(t, err) - - err = trie.Insert(context.Background(), []byte("key4"), []byte("value4")) - require.ErrorIs(t, err, ErrEditLocked) + require.Equal(t, 2, len(firstNode.children)) } func Test_Trie_MultipleStates(t *testing.T) { @@ -1047,7 +798,7 @@ func Test_Trie_MultipleStates(t *testing.T) { require.NoError(t, err) if commitApproach == "before" { - require.NoError(t, root.Commit(context.Background())) + require.NoError(t, root.CommitToDB(context.Background())) } // Populate additional states @@ -1059,7 +810,7 @@ func Test_Trie_MultipleStates(t *testing.T) { } if commitApproach == "after" { - require.NoError(t, root.Commit(context.Background())) + require.NoError(t, root.CommitToDB(context.Background())) } // Process ops @@ -1103,3 +854,283 @@ func Test_Trie_MultipleStates(t *testing.T) { }) } } + +func TestNewViewOnCommittedView(t *testing.T) { + require := require.New(t) + + db, err := getBasicDB() + require.NoError(err) + + // Create a view + view1Intf, err := db.NewView(context.Background()) + require.NoError(err) + view1, ok := view1Intf.(*trieView) + require.True(ok) + + // view1 + // | + // db + + require.Len(db.childViews, 1) + require.Contains(db.childViews, view1) + require.Equal(db, view1.parentTrie) + + err = view1.Insert(context.Background(), []byte{1}, []byte{1}) + require.NoError(err) + + // Commit the view + err = view1.CommitToDB(context.Background()) + require.NoError(err) + + // view1 (committed) + // | + // db + + require.Len(db.childViews, 1) + require.Contains(db.childViews, view1) + require.Equal(db, view1.parentTrie) + + // Create a new view on the committed view + view2Intf, err := view1.NewView(context.Background()) + require.NoError(err) + view2, ok := view2Intf.(*trieView) + require.True(ok) + + // view2 + // | + // view1 (committed) + // | + // db + + require.Equal(db, view2.parentTrie) + require.Contains(db.childViews, view1) + require.Contains(db.childViews, view2) + require.Len(db.childViews, 2) + + // Make sure the new view has the right value + got, err := view2.GetValue(context.Background(), []byte{1}) + require.NoError(err) + require.Equal([]byte{1}, got) + + // Make another view + view3Intf, err := view2.NewView(context.Background()) + require.NoError(err) + view3, ok := view3Intf.(*trieView) + require.True(ok) + + // view3 + // | + // view2 + // | + // view1 (committed) + // | + // db + + require.Equal(view2, view3.parentTrie) + require.Contains(view2.childViews, view3) + require.Len(view2.childViews, 1) + require.Contains(db.childViews, view1) + require.Contains(db.childViews, view2) + require.Len(db.childViews, 2) + + // Commit view2 + err = view2.CommitToDB(context.Background()) + require.NoError(err) + + // view3 + // | + // view2 (committed) + // | + // view1 (committed) + // | + // db + + // Note that view2 being committed invalidates view1 + require.True(view1.invalidated) + require.Contains(db.childViews, view2) + require.Contains(db.childViews, view3) + require.Len(db.childViews, 2) + require.Equal(db, view3.parentTrie) + + // Commit view3 + err = view3.CommitToDB(context.Background()) + require.NoError(err) + + // view3 being committed invalidates view2 + require.True(view2.invalidated) + require.Contains(db.childViews, view3) + require.Len(db.childViews, 1) + require.Equal(db, view3.parentTrie) +} + +func TestTrieViewNewView(t *testing.T) { + require := require.New(t) + + db, err := getBasicDB() + require.NoError(err) + + // Create a view + view1Intf, err := db.NewView(context.Background()) + require.NoError(err) + view1, ok := view1Intf.(*trieView) + require.True(ok) + + // Create a view atop view1 + view2Intf, err := view1.NewView(context.Background()) + require.NoError(err) + view2, ok := view2Intf.(*trieView) + require.True(ok) + + // view2 + // | + // view1 + // | + // db + + // Assert view2's parent is view1 + require.Equal(view1, view2.parentTrie) + require.Contains(view1.childViews, view2) + require.Len(view1.childViews, 1) + + // Commit view1 + err = view1.CommitToDB(context.Background()) + require.NoError(err) + + // Make another view atop view1 + view3Intf, err := view1.NewView(context.Background()) + require.NoError(err) + view3, ok := view3Intf.(*trieView) + require.True(ok) + + // view3 + // | + // view2 + // | + // view1 + // | + // db + + // Assert view3's parent is db + require.Equal(db, view3.parentTrie) + require.Contains(db.childViews, view3) + require.NotContains(view1.childViews, view3) + + // Assert that NewPreallocatedView on an invalid view fails + invalidView := &trieView{invalidated: true} + _, err = invalidView.NewView(context.Background()) + require.ErrorIs(err, ErrInvalid) +} + +func TestTrieViewInvalidate(t *testing.T) { + require := require.New(t) + + db, err := getBasicDB() + require.NoError(err) + + // Create a view + view1Intf, err := db.NewView(context.Background()) + require.NoError(err) + view1, ok := view1Intf.(*trieView) + require.True(ok) + + // Create 2 views atop view1 + view2Intf, err := view1.NewView(context.Background()) + require.NoError(err) + view2, ok := view2Intf.(*trieView) + require.True(ok) + + view3Intf, err := view1.NewView(context.Background()) + require.NoError(err) + view3, ok := view3Intf.(*trieView) + require.True(ok) + + // view2 view3 + // | / + // view1 + // | + // db + + // Invalidate view1 + view1.invalidate() + + require.Empty(view1.childViews) + require.True(view1.invalidated) + require.True(view2.invalidated) + require.True(view3.invalidated) +} + +func TestTrieViewMoveChildViewsToView(t *testing.T) { + require := require.New(t) + + db, err := getBasicDB() + require.NoError(err) + + // Create a view + view1Intf, err := db.NewView(context.Background()) + require.NoError(err) + view1, ok := view1Intf.(*trieView) + require.True(ok) + + // Create a view atop view1 + view2Intf, err := view1.NewView(context.Background()) + require.NoError(err) + view2, ok := view2Intf.(*trieView) + require.True(ok) + + // Create a view atop view2 + view3Intf, err := view1.NewView(context.Background()) + require.NoError(err) + view3, ok := view3Intf.(*trieView) + require.True(ok) + + // view3 + // | + // view2 + // | + // view1 + // | + // db + + view1.moveChildViewsToView(view2) + + require.Equal(view1, view3.parentTrie) + require.Contains(view1.childViews, view3) + require.Contains(view1.childViews, view2) + require.Len(view1.childViews, 2) +} + +func TestTrieViewInvalidChildrenExcept(t *testing.T) { + require := require.New(t) + + db, err := getBasicDB() + require.NoError(err) + + // Create a view + view1Intf, err := db.NewView(context.Background()) + require.NoError(err) + view1, ok := view1Intf.(*trieView) + require.True(ok) + + // Create 2 views atop view1 + view2Intf, err := view1.NewView(context.Background()) + require.NoError(err) + view2, ok := view2Intf.(*trieView) + require.True(ok) + + view3Intf, err := view1.NewView(context.Background()) + require.NoError(err) + view3, ok := view3Intf.(*trieView) + require.True(ok) + + view1.invalidateChildrenExcept(view2) + + require.False(view2.invalidated) + require.True(view3.invalidated) + require.Contains(view1.childViews, view2) + require.Len(view1.childViews, 1) + + view1.invalidateChildrenExcept(nil) + require.True(view2.invalidated) + require.True(view3.invalidated) + require.Empty(view1.childViews) +} diff --git a/x/merkledb/trieview.go b/x/merkledb/trieview.go index 3c2e16c15de8..7bd6a8ee1699 100644 --- a/x/merkledb/trieview.go +++ b/x/merkledb/trieview.go @@ -30,28 +30,63 @@ const ( ) var ( - ErrCommitted = errors.New("view has been committed") - ErrChangedBaseRoot = errors.New("the trie this view was based on has changed its root") - ErrEditLocked = errors.New( - "view has been edit locked. Any view generated from this view would be corrupted by edits", - ) + ErrCommitted = errors.New("view has been committed") + ErrInvalid = errors.New("the trie this view was based on has changed, rending this view invalid") ErrOddLengthWithValue = errors.New( "the underlying db only supports whole number of byte keys, so cannot record changes with odd nibble length", ) - ErrGetClosestNodeFailure = errors.New("GetClosestNode failed to return the closest node") - ErrStartAfterEnd = errors.New("start key > end key") + ErrGetPathToFailure = errors.New("GetPathTo failed to return the closest node") + ErrStartAfterEnd = errors.New("start key > end key") + ErrViewIsNotAChild = errors.New("passed in view is required to be a child of the current view") _ TrieView = &trieView{} numCPU = runtime.NumCPU() ) -// Editable view of a trie, collects changes on top of a base trie. +// Editable view of a trie, collects changes on top of a parent trie. // Delays adding key/value pairs to the trie. type trieView struct { - // Must be held when reading/writing fields. + // Must be held when reading/writing fields except + // [childViews] and [invalidated]. lock sync.Mutex + // Controls the trie's invalidation related fields. + // Must be held while reading/writing [childViews], [invalidated], and [parentTrie]. + // Must not grab the [lock] of this trie or any ancestor while this is held. + invalidationLock sync.RWMutex + + // If true, this view has been invalidated and can't be used. + // + // Invariant: This view is marked as invalid before any of its ancestors change. + // Since we hold locks on ancestors when query/modify them, we're + // guaranteed that no ancestor changes if this view is valid + // after we grab the view stack locks until we release them. + // Namely if we have a method with: + // + // t.lockStack() + // defer t.unlockStack() + // t.invalidationLock.Lock() + // if t.invalidated { + // t.invalidationLock.Unlock() + // return ErrInvalid + // } + // t.invalidationLock.Unlock() + // + // Then we're guaranteed no ancestor changes after the if statement + // and before the method returns. + // + // [invalidationLock] must be held when reading/writing this field. + invalidated bool + + // the uncommitted parent trie of this view + // [invalidationLock] must be held when reading/writing this field. + parentTrie Trie + + // The valid children of this trie. + // [invalidationLock] must be held when reading/writing this field. + childViews []*trieView + // Changes made to this view. // May include nodes that haven't been updated // but will when their ID is recalculated. @@ -63,18 +98,7 @@ type trieView struct { // A Nothing value indicates that the key has been removed. unappliedValueChanges map[path]Maybe[[]byte] - // The trie below this one in the current view stack. - // This is either [baseView] or [db]. - // Used to get information missing from the local view. - baseTrie Trie - - // The root of [db] when this view was created. - basedOnRoot ids.ID - db *Database - - // the view that this view is based upon (if it exists, nil otherwise). - // If non-nil, is [baseTrie]. - baseView *trieView + db *Database // The root of the trie represented by this view. root *node @@ -86,14 +110,11 @@ type trieView struct { // Calls to Insert and Remove will return ErrCommitted. committed bool - // If true, this view has been edit locked because another view - // exists atop it. - // Calls to Insert and Remove will return ErrEditLocked. - changeLocked bool estimatedSize int } // Returns a new view on top of this one. +// Adds the new view to [t.childViews]. // Assumes this view stack is unlocked. func (t *trieView) NewView(ctx context.Context) (TrieView, error) { return t.NewPreallocatedView(ctx, defaultPreallocationSize) @@ -101,47 +122,62 @@ func (t *trieView) NewView(ctx context.Context) (TrieView, error) { // Returns a new view on top of this one with memory allocated to store the // [estimatedChanges] number of key/value changes. +// If this view is already committed, the new view's parent will +// be set to the parent of the current view. +// Otherwise adds the new view to [t.childViews]. // Assumes this view stack is unlocked. func (t *trieView) NewPreallocatedView( ctx context.Context, estimatedChanges int, ) (TrieView, error) { - t.lockStack() - defer t.unlockStack() + if t.isInvalid() { + return nil, ErrInvalid + } + + // lock local trie view while checking for committed + t.lock.Lock() + defer t.lock.Unlock() + + if t.committed { + return t.getParentTrie().NewPreallocatedView(ctx, estimatedChanges) + } + + // lock the rest of the stack while generating + t.getParentTrie().lockStack() + defer t.getParentTrie().unlockStack() + + newView, err := newTrieView(ctx, t.db, t, nil, estimatedChanges) + if err != nil { + return nil, err + } - return newTrieView(ctx, t.db, t, nil, estimatedChanges) + t.invalidationLock.Lock() + defer t.invalidationLock.Unlock() + + if t.invalidated { + return nil, ErrInvalid + } + t.childViews = append(t.childViews, newView) + + return newView, nil } -// Creates a new view atop the given [baseView]. -// If [baseView] is nil, the view is created atop [db]. -// If [baseView] isn't nil, sets [baseView.changeLocked] to true. +// Creates a new view with the given [parentTrie]. // If [changes] is nil, a new changeSummary is created. -// Assumes [db.lock] is read locked. -// Assumes [baseView] is nil or locked. +// Assumes [parentTrie] and its ancestors are read locked. func newTrieView( ctx context.Context, db *Database, - baseView *trieView, + parentTrie Trie, changes *changeSummary, estimatedSize int, ) (*trieView, error) { if changes == nil { changes = newChangeSummary(estimatedSize) } - - baseTrie := Trie(db) - if baseView != nil { - baseTrie = baseView - baseView.changeLocked = true - } - - baseRoot := db.getMerkleRoot() - result := &trieView{ db: db, - baseView: baseView, - baseTrie: baseTrie, - basedOnRoot: baseRoot, + parentTrie: parentTrie, changes: changes, estimatedSize: estimatedSize, unappliedValueChanges: make(map[path]Maybe[[]byte], estimatedSize), @@ -154,28 +190,20 @@ func newTrieView( // Write locks this view and read locks all views/the database below it. func (t *trieView) lockStack() { t.lock.Lock() - t.baseTrie.lockStack() + t.getParentTrie().lockStack() } func (t *trieView) unlockStack() { - t.baseTrie.unlockStack() + t.getParentTrie().unlockStack() t.lock.Unlock() } -// Calculates the IDs of all nodes in this trie. -func (t *trieView) CalculateIDs(ctx context.Context) error { - ctx, span := t.db.tracer.Start(ctx, "MerkleDB.trieview.CalculateIDs") - defer span.End() - - t.lockStack() - defer t.unlockStack() - - return t.calculateIDs(ctx) -} - // Recalculates the node IDs for all changed nodes in the trie. // Assumes this view stack is locked. func (t *trieView) calculateIDs(ctx context.Context) error { + if t.isInvalid() { + return ErrInvalid + } if !t.needsRecalculation { return nil } @@ -192,8 +220,8 @@ func (t *trieView) calculateIDs(ctx context.Context) error { defer span.End() // ensure that the view under this one is up to date before potentially pulling in nodes from it - if t.baseView != nil { - if err := t.baseView.calculateIDs(ctx); err != nil { + if t.parentTrie != nil { + if err := t.getParentTrie().calculateIDs(ctx); err != nil { return err } } @@ -215,6 +243,7 @@ func (t *trieView) calculateIDs(ctx context.Context) error { return err } t.needsRecalculation = false + t.changes.rootID = t.root.id return nil } @@ -431,51 +460,8 @@ func (t *trieView) getRangeProof( return &result, nil } -// Removes from the view stack views that have been committed or whose -// changes are already in the database. -// Returns true if [t]'s changes are already in the database. -// Assumes this view stack is locked. -func (t *trieView) cleanupCommittedViews(ctx context.Context) (bool, error) { - if t.committed { - return true, nil - } - - root, err := t.getMerkleRoot(ctx) - if err != nil { - return false, err - } - - if root == t.db.getMerkleRoot() { - // this view's root matches the db's root, so the changes in it are already in the db. - t.markViewStackCommitted() - return true, nil - } - - if t.baseView == nil { - // There are no views under this one so we're done cleaning the view stack. - return false, nil - } - - inDatabase, err := t.baseView.cleanupCommittedViews(ctx) - if err != nil { - return false, err - } - if !inDatabase { - // [t.baseView]'s changes aren't in the database yet - // so we can't remove our reference to it. - return false, nil - } - - // [t.baseView]'s changes are in the database, so we can remove our reference to it. - // We don't need to commit it to the database. - t.baseView = nil - // There's no view under this one, so we should read/write changes to the database. - t.baseTrie = t.db - return false, nil -} - // Commits changes from this trie to the underlying DB. -func (t *trieView) Commit(ctx context.Context) error { +func (t *trieView) CommitToDB(ctx context.Context) error { ctx, span := t.db.tracer.Start(ctx, "MerkleDB.trieview.Commit") defer span.End() @@ -490,17 +476,90 @@ func (t *trieView) Commit(ctx context.Context) error { // to modify [t.db]. No other view's call to lockStack() can proceed // until this method returns because we hold [t.db]'s write lock. - if err := t.validateDBRoot(ctx); err != nil { + return t.commitToDB(ctx, nil) +} + +// Adds the changes from [trieToCommit] to this trie. +// Assumes [trieToCommit] is a child of this trie. +// Assumes [t.db.lock] is held. +// Note this means [lockStack] is blocking for all other views. +func (t *trieView) commitChanges(ctx context.Context, trieToCommit *trieView) error { + _, span := t.db.tracer.Start(ctx, "MerkleDB.triview.commitChanges", oteltrace.WithAttributes( + attribute.Int("changeCount", len(t.changes.values)), + )) + defer span.End() + + switch { + case t.isInvalid(): + // don't apply changes to an invalid view + return ErrInvalid + case trieToCommit == nil: + // no changes to apply + return nil + case trieToCommit.parentTrie != t: + // trieToCommit needs to be a child of t, otherwise the changes merge would not work + return ErrViewIsNotAChild + case trieToCommit.isInvalid(): + // don't apply changes from an invalid view + return ErrInvalid + } + + // Invalidate all child views except the view being committed. + // Note that we invalidate children before modifying their ancestor [t] + // to uphold the invariant on [t.invalidated]. + t.invalidateChildrenExcept(trieToCommit) + + // ensure that the changes from the incoming trie are ready to be merged into the current trie. + // Note that we hold [db.lock] so no other thread can be modifying a trie, including [trieToCommit], + // since calls to [lockStack] will block. So it's safe to call [calculateIDs] here. + if err := trieToCommit.calculateIDs(ctx); err != nil { return err } - return t.commit(ctx) + // no changes in the trie, so there isn't anything to do + if len(t.changes.nodes) == 0 { + return nil + } + + for key, nodeChange := range trieToCommit.changes.nodes { + if existing, ok := t.changes.nodes[key]; ok { + existing.after = nodeChange.after + } else { + t.changes.nodes[key] = &change[*node]{ + before: nodeChange.before, + after: nodeChange.after, + } + } + } + + for key, valueChange := range trieToCommit.changes.values { + if existing, ok := t.changes.values[key]; ok { + existing.after = valueChange.after + } else { + t.changes.values[key] = &change[Maybe[[]byte]]{ + before: valueChange.before, + after: valueChange.after, + } + } + } + // update this view's root info to match the newly committed root + t.root = trieToCommit.changes.nodes[RootPath].after + t.changes.rootID = trieToCommit.changes.rootID + + // move the children from the incoming trieview to the current trieview + // do this after the current view has been updated + // this allows child views calls to their parent to remain consistent during the move + t.moveChildViewsToView(trieToCommit) + + return nil } -// Commits the changes from this trie to the underlying DB. -// Assumes [t.lock] and [t.db.lock] are held. -func (t *trieView) commit(ctx context.Context) error { - ctx, span := t.db.tracer.Start(ctx, "MerkleDB.triview.commit", oteltrace.WithAttributes( +// Commits the changes from [trieToCommit] to this view, +// this view to its parent, and so on until committing to the db. +// Assumes [t.db.lock] is held. +// Note this means [lockStack] is blocking for all other views. +func (t *trieView) commitToDB(ctx context.Context, trieToCommit *trieView) error { + ctx, span := t.db.tracer.Start(ctx, "MerkleDB.triview.commitToDB", oteltrace.WithAttributes( attribute.Int("changeCount", len(t.changes.values)), )) defer span.End() @@ -509,32 +568,107 @@ func (t *trieView) commit(ctx context.Context) error { return ErrCommitted } + // ensure all of this view's changes have been calculated if err := t.calculateIDs(ctx); err != nil { return err } - // ensure we don't recommit any committed tries - if alreadyCommitted, err := t.cleanupCommittedViews(ctx); alreadyCommitted || err != nil { + // overwrite this view with changes from the incoming view + if err := t.commitChanges(ctx, trieToCommit); err != nil { return err } - // commit [t.baseView] before committing the current view - if t.baseView != nil { - // We have [db.lock] here so [t.baseView] can't be changing. - if err := t.baseView.commit(ctx); err != nil { - return err - } - t.baseView = nil - t.baseTrie = t.db - } - - if err := t.db.commitChanges(ctx, t.changes); err != nil { + // pass the result onto the parent trie to merge and then commit to db + if err := t.getParentTrie().commitToDB(ctx, t); err != nil { return err } t.committed = true + + // now that this view is committed, all child views have been moved to the db, so none need to be tracked by this view + t.clearChildView() return nil } +// Assumes [t.invalidationLock] isn't held. +func (t *trieView) isInvalid() bool { + t.invalidationLock.RLock() + defer t.invalidationLock.RUnlock() + + return t.invalidated +} + +// Invalidates this view and all descendants. +// Assumes [t.invalidationLock] isn't held. +func (t *trieView) invalidate() { + t.invalidationLock.Lock() + defer t.invalidationLock.Unlock() + + t.invalidated = true + + for _, childView := range t.childViews { + childView.invalidate() + } + + // after invalidating the children, they no longer need to be tracked + t.childViews = make([]*trieView, 0, defaultPreallocationSize) +} + +// Invalidates all children of this view. +// Assumes [t.invalidationLock] isn't held. +func (t *trieView) invalidateChildren() { + t.invalidateChildrenExcept(nil) +} + +// move any child views from the trieToCommit to the current trie view +// assumes that the [db.lock] is held +func (t *trieView) moveChildViewsToView(trieToCommit *trieView) { + t.invalidationLock.Lock() + defer t.invalidationLock.Unlock() + + for _, childView := range trieToCommit.childViews { + childView.updateParent(t) + t.childViews = append(t.childViews, childView) + } +} + +func (t *trieView) updateParent(newParent Trie) { + t.invalidationLock.Lock() + defer t.invalidationLock.Unlock() + + t.parentTrie = newParent +} + +// Removes all tracked child views from [childViews] +// Assumes [t.invalidationLock] isn't held. +func (t *trieView) clearChildView() { + t.invalidationLock.Lock() + defer t.invalidationLock.Unlock() + + t.childViews = make([]*trieView, 0, defaultPreallocationSize) +} + +// Invalidates all children of this view except [exception]. +// [t.childViews] will only contain the exception after invalidation is complete. +// Assumes [t.invalidationLock] isn't held. +func (t *trieView) invalidateChildrenExcept(exception *trieView) { + t.invalidationLock.Lock() + defer t.invalidationLock.Unlock() + + for _, childView := range t.childViews { + if childView != exception { + childView.invalidate() + } + } + + // after invalidating the children, they no longer need to be tracked + t.childViews = make([]*trieView, 0, defaultPreallocationSize) + + // add back in the exception view since it is still valid + if exception != nil { + t.childViews = append(t.childViews, exception) + } +} + // Returns the ID of the root of this trie. func (t *trieView) GetMerkleRoot(ctx context.Context) (ids.ID, error) { t.lockStack() @@ -554,7 +688,7 @@ func (t *trieView) getMerkleRoot(ctx context.Context) (ids.ID, error) { // Returns up to [maxLength] key/values from keys in closed range [start, end]. // Acts similarly to the merge step of a merge sort to combine state from the view -// with state from the base trie. +// with state from the parent trie. // Assumes this view stack is locked. func (t *trieView) getKeyValues( ctx context.Context, @@ -570,6 +704,10 @@ func (t *trieView) getKeyValues( return nil, fmt.Errorf("%w but was %d", ErrInvalidMaxLength, maxLength) } + if t.isInvalid() { + return nil, ErrInvalid + } + // collect all values that have changed or been deleted changes := make([]KeyValue, 0, len(t.changes.values)) for key, change := range t.changes.values { @@ -583,12 +721,12 @@ func (t *trieView) getKeyValues( }) } } - // sort [changes] so they can be merged with the base trie's state + // sort [changes] so they can be merged with the parent trie's state slices.SortFunc(changes, func(a, b KeyValue) bool { return bytes.Compare(a.Key, b.Key) == -1 }) - baseKeyValues, err := t.baseTrie.getKeyValues(ctx, start, end, maxLength, keysToIgnore) + baseKeyValues, err := t.getParentTrie().getKeyValues(ctx, start, end, maxLength, keysToIgnore) if err != nil { return nil, err } @@ -675,13 +813,6 @@ func (t *trieView) GetValues(ctx context.Context, keys [][]byte) ([][]byte, []er results := make([][]byte, len(keys)) errors := make([]error, len(keys)) - if err := t.validateDBRoot(ctx); err != nil { - for i := range keys { - errors[i] = err - } - return results, errors - } - for i, key := range keys { results[i], errors[i] = t.getValue(ctx, newPath(key)) } @@ -694,14 +825,15 @@ func (t *trieView) GetValue(ctx context.Context, key []byte) ([]byte, error) { t.lockStack() defer t.unlockStack() - if err := t.validateDBRoot(ctx); err != nil { - return nil, err - } return t.getValue(ctx, newPath(key)) } // Assumes this view stack is locked. func (t *trieView) getValue(ctx context.Context, key path) ([]byte, error) { + if t.isInvalid() { + return nil, ErrInvalid + } + if change, ok := t.changes.values[key]; ok { t.db.metrics.ViewValueCacheHit() if change.after.IsNothing() { @@ -711,8 +843,8 @@ func (t *trieView) getValue(ctx context.Context, key path) ([]byte, error) { } t.db.metrics.ViewValueCacheMiss() - // if we don't have local copy of the key, then grab a copy from the base trie - value, err := t.baseTrie.getValue(ctx, key) + // if we don't have local copy of the key, then grab a copy from the parent trie + value, err := t.getParentTrie().getValue(ctx, key) if err != nil { return nil, err } @@ -728,14 +860,20 @@ func (t *trieView) Insert(ctx context.Context, key []byte, value []byte) error { } // Assumes this view stack is locked. +// Assumes [t.invalidationLock] isn't held. func (t *trieView) insert(ctx context.Context, key []byte, value []byte) error { if t.committed { return ErrCommitted } - if t.changeLocked { - return ErrEditLocked + if t.isInvalid() { + return ErrInvalid } + + // the trie has been changed, so invalidate all children and remove them from tracking + t.invalidateChildren() + valCopy := slices.Clone(value) + return t.recordValueChange(ctx, newPath(key), Some(valCopy)) } @@ -748,54 +886,20 @@ func (t *trieView) Remove(ctx context.Context, key []byte) error { } // Assumes this view stack is locked. +// Assumes [t.invalidationLock] isn't held. func (t *trieView) remove(ctx context.Context, key []byte) error { if t.committed { return ErrCommitted } - if t.changeLocked { - return ErrEditLocked - } - - return t.recordValueChange(ctx, newPath(key), Nothing[[]byte]()) -} - -// Returns nil iff at least one of the following is true: -// - The root of the db hasn't changed since this view was created. -// - This view's root is the same as the db's root. -// - This method returns nil for the view under this one. -// -// Assumes this view stack is locked. -func (t *trieView) validateDBRoot(ctx context.Context) error { - dbRoot := t.db.getMerkleRoot() - - // the root has not changed, so the trieview is still valid - if dbRoot == t.basedOnRoot { - return nil - } - - if t.baseView != nil { - // if the view that this view is based on is valid, - // then this view is valid too. - if err := t.baseView.validateDBRoot(ctx); err == nil { - return nil - } + if t.isInvalid() { + return ErrInvalid } - // this view has no base view or an invalid base view. - // calculate the current view's root and check if it matches the db. - localRoot, err := t.getMerkleRoot(ctx) - if err != nil { - return err - } - - // the roots don't match, which means that that the changes - // in this view aren't already represented in the db - if localRoot != dbRoot { - return ErrChangedBaseRoot - } + // the trie has been changed, so invalidate all children and remove them from tracking + t.invalidateChildren() - return nil + return t.recordValueChange(ctx, newPath(key), Nothing[[]byte]()) } // Assumes this view stack is locked. @@ -937,6 +1041,9 @@ func getLengthOfCommonPrefix(first, second path) int { // Assumes this view stack is locked. func (t *trieView) getNode(ctx context.Context, key path) (*node, error) { + if t.isInvalid() { + return nil, ErrInvalid + } if err := t.calculateIDs(ctx); err != nil { return nil, err } @@ -1032,9 +1139,9 @@ func (t *trieView) insertIntoTrie( // the existing child's key is of length: len(closestNodekey) + 1 for the child index + len(existing child's compressed key) // if that length is less than or equal to the branch node's key that implies that the existing child's key matched the key to be inserted - // since it matched the key to be inserted, it should have been returned by getClosestNode + // since it matched the key to be inserted, it should have been returned by GetPathTo if len(existingChildKey) <= len(branchNode.key) { - return nil, ErrGetClosestNodeFailure + return nil, ErrGetPathToFailure } branchNode.addChildWithoutNode( @@ -1046,16 +1153,6 @@ func (t *trieView) insertIntoTrie( return nodeWithValue, t.recordNodeChange(ctx, branchNode) } -// Mark this view and all views under this view as committed. -// Assumes this view stack is locked. -func (t *trieView) markViewStackCommitted() { - currentView := t - for currentView != nil { - currentView.committed = true - currentView = currentView.baseView - } -} - // Records that a node has been changed. // Assumes this view stack is locked. func (t *trieView) recordNodeChange(ctx context.Context, after *node) error { @@ -1082,7 +1179,7 @@ func (t *trieView) recordKeyChange(ctx context.Context, key path, after *node) e return nil } - before, err := t.baseTrie.getNode(ctx, key) + before, err := t.getParentTrie().getNode(ctx, key) if err != nil { if err != database.ErrNotFound { return err @@ -1116,7 +1213,7 @@ func (t *trieView) recordValueChange(ctx context.Context, key path, value Maybe[ // grab the before value var beforeMaybe Maybe[[]byte] - before, err := t.baseTrie.getValue(ctx, key) + before, err := t.getParentTrie().getValue(ctx, key) switch err { case nil: beforeMaybe = Some(before) @@ -1189,7 +1286,7 @@ func (t *trieView) getNodeFromParent(ctx context.Context, parent *node, key path } // Retrieves a node with the given [key]. -// If the node is fetched from [t.baseTrie] and [id] isn't empty, +// If the node is fetched from [t.parentTrie] and [id] isn't empty, // sets the node's ID to [id]. // Returns database.ErrNotFound if the node doesn't exist. // Assumes this view stack is locked. @@ -1203,19 +1300,25 @@ func (t *trieView) getNodeWithID(ctx context.Context, id ids.ID, key path) (*nod return nodeChange.after, nil } - // get the node from the base trie and store a localy copy - baseTrieNode, err := t.baseTrie.getNode(ctx, key) + // get the node from the parent trie and store a localy copy + parentTrieNode, err := t.getParentTrie().getNode(ctx, key) if err != nil { return nil, err } - // copy the node so any alterations to it don't affect the base trie - node := baseTrieNode.clone() + // copy the node so any alterations to it don't affect the parent trie + node := parentTrieNode.clone() - // only need to initialize the id if it's from the base trie. + // only need to initialize the id if it's from the parent trie. // nodes in the current view change list have already been initialized. if id != ids.Empty { node.id = id } return node, nil } + +func (t *trieView) getParentTrie() Trie { + t.invalidationLock.Lock() + defer t.invalidationLock.Unlock() + return t.parentTrie +}