diff --git a/core/state/access_list.go b/core/state/access_list.go index 37abe2ca0b..8e1ab70aeb 100644 --- a/core/state/access_list.go +++ b/core/state/access_list.go @@ -17,8 +17,11 @@ package state import ( + "fmt" "github.com/CortexFoundation/CortexTheseus/common" "maps" + "slices" + "strings" ) type accessList struct { @@ -129,3 +132,35 @@ func (al *accessList) DeleteSlot(address common.Address, slot common.Hash) { func (al *accessList) DeleteAddress(address common.Address) { delete(al.addresses, address) } + +// Equal returns true if the two access lists are identical +func (al *accessList) Equal(other *accessList) bool { + if !maps.Equal(al.addresses, other.addresses) { + return false + } + return slices.EqualFunc(al.slots, other.slots, + func(m map[common.Hash]struct{}, m2 map[common.Hash]struct{}) bool { + return maps.Equal(m, m2) + }) +} + +// PrettyPrint prints the contents of the access list in a human-readable form +func (al *accessList) PrettyPrint() string { + out := new(strings.Builder) + var sortedAddrs []common.Address + for addr := range al.addresses { + sortedAddrs = append(sortedAddrs, addr) + } + slices.SortFunc(sortedAddrs, common.Address.Cmp) + for _, addr := range sortedAddrs { + idx := al.addresses[addr] + fmt.Fprintf(out, "%#x : (idx %d)\n", addr, idx) + if idx >= 0 { + slotmap := al.slots[idx] + for h := range slotmap { + fmt.Fprintf(out, " %#x\n", h) + } + } + } + return out.String() +} diff --git a/core/state/state_object.go b/core/state/state_object.go index 746a79ee0a..2ebcb68bc1 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -445,22 +445,22 @@ func (s *stateObject) ReturnGas(gas *big.Int) {} func (s *stateObject) deepCopy(db *StateDB) *stateObject { obj := &stateObject{ - db: db, - address: s.address, - addrHash: s.addrHash, - origin: s.origin, - data: s.data, + db: db, + address: s.address, + addrHash: s.addrHash, + origin: s.origin, + data: s.data, + code: s.code, + originStorage: s.originStorage.Copy(), + pendingStorage: s.pendingStorage.Copy(), + dirtyStorage: s.dirtyStorage.Copy(), + dirtyCode: s.dirtyCode, + selfDestructed: s.selfDestructed, + newContract: s.newContract, } if s.trie != nil { obj.trie = db.db.CopyTrie(s.trie) } - obj.code = s.code - obj.originStorage = s.originStorage.Copy() - obj.pendingStorage = s.pendingStorage.Copy() - obj.dirtyStorage = s.dirtyStorage.Copy() - obj.dirtyCode = s.dirtyCode - obj.selfDestructed = s.selfDestructed - obj.newContract = s.newContract return obj } diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index 64630dce8b..1c06aad652 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -20,10 +20,12 @@ import ( "bytes" "encoding/binary" "fmt" + "maps" "math" "math/big" "math/rand" "reflect" + "slices" "strings" "testing" "testing/quick" @@ -406,6 +408,10 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr)) checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr)) checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr)) + // Check newContract-flag + if obj := state.getStateObject(addr); obj != nil { + checkeq("IsNewContract", obj.newContract, checkstate.getStateObject(addr).newContract) + } // Check storage. if obj := state.getStateObject(addr); obj != nil { state.ForEachStorage(addr, func(key, val common.Hash) bool { @@ -414,12 +420,49 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { checkstate.ForEachStorage(addr, func(key, checkval common.Hash) bool { return checkeq("GetState("+key.Hex()+")", state.GetState(addr, key), checkval) }) + other := checkstate.getStateObject(addr) + // Check dirty storage which is not in trie + if !maps.Equal(obj.dirtyStorage, other.dirtyStorage) { + print := func(dirty map[common.Hash]common.Hash) string { + var keys []common.Hash + out := new(strings.Builder) + for key := range dirty { + keys = append(keys, key) + } + slices.SortFunc(keys, common.Hash.Cmp) + for i, key := range keys { + fmt.Fprintf(out, " %d. %v %v\n", i, key, dirty[key]) + } + return out.String() + } + return fmt.Errorf("dirty storage err, have\n%v\nwant\n%v", + print(obj.dirtyStorage), + print(other.dirtyStorage)) + } + } + // Check transient storage. + { + have := state.transientStorage + want := checkstate.transientStorage + eq := maps.EqualFunc(have, want, + func(a Storage, b Storage) bool { + return maps.Equal(a, b) + }) + if !eq { + return fmt.Errorf("transient storage differs ,have\n%v\nwant\n%v", + have.PrettyPrint(), + want.PrettyPrint()) + } } if err != nil { return err } } - + if !checkstate.accessList.Equal(state.accessList) { // Check access lists + return fmt.Errorf("AccessLists are wrong, have \n%v\nwant\n%v", + checkstate.accessList.PrettyPrint(), + state.accessList.PrettyPrint()) + } if state.GetRefund() != checkstate.GetRefund() { return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d", state.GetRefund(), checkstate.GetRefund()) @@ -428,6 +471,23 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v", state.GetLogs(common.Hash{}, 0, common.Hash{}), checkstate.GetLogs(common.Hash{}, 0, common.Hash{})) } + if !maps.Equal(state.journal.dirties, checkstate.journal.dirties) { + getKeys := func(dirty map[common.Address]int) string { + var keys []common.Address + out := new(strings.Builder) + for key := range dirty { + keys = append(keys, key) + } + slices.SortFunc(keys, common.Address.Cmp) + for i, key := range keys { + fmt.Fprintf(out, " %d. %v\n", i, key) + } + return out.String() + } + have := getKeys(state.journal.dirties) + want := getKeys(checkstate.journal.dirties) + return fmt.Errorf("dirty-journal set mismatch.\nhave:\n%v\nwant:\n%v\n", have, want) + } return nil } diff --git a/core/state/transient_storage.go b/core/state/transient_storage.go index f32402aee1..8a5a5775b2 100644 --- a/core/state/transient_storage.go +++ b/core/state/transient_storage.go @@ -17,7 +17,10 @@ package state import ( + "fmt" "github.com/CortexFoundation/CortexTheseus/common" + "slices" + "strings" ) // transientStorage is a representation of EIP-1153 "Transient Storage". @@ -30,10 +33,19 @@ func newTransientStorage() transientStorage { // Set sets the transient-storage `value` for `key` at the given `addr`. func (t transientStorage) Set(addr common.Address, key, value common.Hash) { - if _, ok := t[addr]; !ok { - t[addr] = make(Storage) + if value == (common.Hash{}) { // this is a 'delete' + if _, ok := t[addr]; ok { + delete(t[addr], key) + if len(t[addr]) == 0 { + delete(t, addr) + } + } + } else { + if _, ok := t[addr]; !ok { + t[addr] = make(Storage) + } + t[addr][key] = value } - t[addr][key] = value } // Get gets the transient storage for `key` at the given `addr`. @@ -53,3 +65,27 @@ func (t transientStorage) Copy() transientStorage { } return storage } + +// PrettyPrint prints the contents of the access list in a human-readable form +func (t transientStorage) PrettyPrint() string { + out := new(strings.Builder) + var sortedAddrs []common.Address + for addr := range t { + sortedAddrs = append(sortedAddrs, addr) + slices.SortFunc(sortedAddrs, common.Address.Cmp) + } + + for _, addr := range sortedAddrs { + fmt.Fprintf(out, "%#x:", addr) + var sortedKeys []common.Hash + storage := t[addr] + for key := range storage { + sortedKeys = append(sortedKeys, key) + } + slices.SortFunc(sortedKeys, common.Hash.Cmp) + for _, key := range sortedKeys { + fmt.Fprintf(out, " %X : %X\n", key, storage[key]) + } + } + return out.String() +}