diff --git a/lib/trie/database.go b/lib/trie/database.go index 8218bd1b38..1f15b0159a 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -18,6 +18,7 @@ package trie import ( "bytes" + "errors" "fmt" "github.com/ChainSafe/gossamer/lib/common" @@ -25,6 +26,9 @@ import ( "github.com/ChainSafe/chaindb" ) +// ErrEmptyProof indicates the proof slice is empty +var ErrEmptyProof = errors.New("proof slice empty") + // Store stores each trie node in the database, where the key is the hash of the encoded node and the value is the encoded node. // Generally, this will only be used for the genesis trie. func (t *Trie) Store(db chaindb.Database) error { @@ -73,6 +77,64 @@ func (t *Trie) store(db chaindb.Batch, curr node) error { return nil } +// LoadFromProof create a partial trie based on the proof slice, as it only contains nodes that are in the proof afaik. +func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { + if len(proof) == 0 { + return ErrEmptyProof + } + + mappedNodes := make(map[string]node, len(proof)) + + // map all the proofs hash -> decoded node + // and takes the loop to indentify the root node + for _, rawNode := range proof { + decNode, err := decodeBytes(rawNode) + if err != nil { + return err + } + + decNode.setDirty(false) + decNode.setEncodingAndHash(rawNode, nil) + + _, computedRoot, err := decNode.encodeAndHash() + if err != nil { + return err + } + + mappedNodes[common.BytesToHex(computedRoot)] = decNode + + if bytes.Equal(computedRoot, root) { + t.root = decNode + } + } + + t.loadProof(mappedNodes, t.root) + return nil +} + +// loadProof is a recursive function that will create all the trie paths based +// on the mapped proofs slice starting by the root +func (t *Trie) loadProof(proof map[string]node, curr node) { + c, ok := curr.(*branch) + if !ok { + return + } + + for i, child := range c.children { + if child == nil { + continue + } + + proofNode, ok := proof[common.BytesToHex(child.getHash())] + if !ok { + continue + } + + c.children[i] = proofNode + t.loadProof(proof, proofNode) + } +} + // Load reconstructs the trie from the database from the given root hash. Used when restarting the node to load the current state trie. func (t *Trie) Load(db chaindb.Database, root common.Hash) error { if root == EmptyHash { diff --git a/lib/trie/lookup.go b/lib/trie/lookup.go index dd8600963e..67f2792692 100644 --- a/lib/trie/lookup.go +++ b/lib/trie/lookup.go @@ -2,86 +2,37 @@ package trie import ( "bytes" - "errors" - - "github.com/ChainSafe/chaindb" -) - -var ( - // ErrProofNodeNotFound when a needed proof node is not in the database - ErrProofNodeNotFound = errors.New("cannot find a trie node in the database") ) -// lookup struct holds the state root and database reference -// used to retrieve trie information from database -type lookup struct { - // root to start the lookup - root []byte - db chaindb.Database +// findAndRecord search for a desired key recording all the nodes in the path including the desired node +func findAndRecord(t *Trie, key []byte, recorder *recorder) error { + return find(t.root, key, recorder) } -// newLookup returns a Lookup to helps the proof generator -func newLookup(rootHash []byte, db chaindb.Database) *lookup { - lk := &lookup{db: db} - lk.root = make([]byte, len(rootHash)) - copy(lk.root, rootHash) - - return lk -} - -// find will return the desired value or nil if key cannot be found and will record visited nodes -func (l *lookup) find(key []byte, recorder *recorder) ([]byte, error) { - partial := key - hash := l.root - - for { - nodeData, err := l.db.Get(hash) - if err != nil { - return nil, ErrProofNodeNotFound - } - - nodeHash := make([]byte, len(hash)) - copy(nodeHash, hash) - - recorder.record(nodeHash, nodeData) - - decoded, err := decodeBytes(nodeData) - if err != nil { - return nil, err - } +func find(parent node, key []byte, recorder *recorder) error { + enc, hash, err := parent.encodeAndHash() + if err != nil { + return err + } - switch currNode := decoded.(type) { - case nil: - return nil, nil + recorder.record(hash, enc) - case *leaf: - if bytes.Equal(currNode.key, partial) { - return currNode.value, nil - } - return nil, nil + b, ok := parent.(*branch) + if !ok { + return nil + } - case *branch: - switch len(partial) { - case 0: - return currNode.value, nil - default: - if !bytes.HasPrefix(partial, currNode.key) { - return nil, nil - } + length := lenCommonPrefix(b.key, key) - if bytes.Equal(partial, currNode.key) { - return currNode.value, nil - } + // found the value at this node + if bytes.Equal(b.key, key) || len(key) == 0 { + return nil + } - length := lenCommonPrefix(currNode.key, partial) - switch child := currNode.children[partial[length]].(type) { - case nil: - return nil, nil - default: - partial = partial[length+1:] - copy(hash, child.getHash()) - } - } - } + // did not find value + if bytes.Equal(b.key[:length], key) && len(key) < len(b.key) { + return nil } + + return find(b.children[key[length]], key[length+1:], recorder) } diff --git a/lib/trie/proof.go b/lib/trie/proof.go index 7668b69df8..a4a83b919f 100644 --- a/lib/trie/proof.go +++ b/lib/trie/proof.go @@ -17,7 +17,10 @@ package trie import ( + "bytes" + "encoding/hex" "errors" + "fmt" "github.com/ChainSafe/chaindb" "github.com/ChainSafe/gossamer/lib/common" @@ -26,6 +29,15 @@ import ( var ( // ErrEmptyTrieRoot occurs when trying to craft a prove with an empty trie root ErrEmptyTrieRoot = errors.New("provided trie must have a root") + + // ErrValueNotFound indicates that a returned verify proof value doesnt match with the expected value on items array + ErrValueNotFound = errors.New("expected value not found in the trie") + + // ErrDuplicateKeys not allowed to verify proof with duplicate keys + ErrDuplicateKeys = errors.New("duplicate keys on verify proof") + + // ErrLoadFromProof occurs when there are problems with the proof slice while building the partial proof trie + ErrLoadFromProof = errors.New("failed to build the proof trie") ) // GenerateProof receive the keys to proof, the trie root and a reference to database @@ -33,13 +45,16 @@ var ( func GenerateProof(root []byte, keys [][]byte, db chaindb.Database) ([][]byte, error) { trackedProofs := make(map[string][]byte) + proofTrie := NewEmptyTrie() + if err := proofTrie.Load(db, common.BytesToHash(root)); err != nil { + return nil, err + } + for _, k := range keys { nk := keyToNibbles(k) - lookup := newLookup(root, db) recorder := new(recorder) - - _, err := lookup.find(nk, recorder) + err := findAndRecord(proofTrie, nk, recorder) if err != nil { return nil, err } @@ -54,10 +69,43 @@ func GenerateProof(root []byte, keys [][]byte, db chaindb.Database) ([][]byte, e } proofs := make([][]byte, 0) - for _, p := range trackedProofs { proofs = append(proofs, p) } return proofs, nil } + +// Pair holds the key and value to check while verifying the proof +type Pair struct{ Key, Value []byte } + +// VerifyProof ensure a given key is inside a proof by creating a proof trie based on the proof slice +// this function ignores the order of proofs +func VerifyProof(proof [][]byte, root []byte, items []Pair) (bool, error) { + set := make(map[string]struct{}, len(items)) + + // check for duplicate keys + for _, item := range items { + hexKey := hex.EncodeToString(item.Key) + if _, ok := set[hexKey]; ok { + return false, ErrDuplicateKeys + } + set[hexKey] = struct{}{} + } + + proofTrie := NewEmptyTrie() + if err := proofTrie.LoadFromProof(proof, root); err != nil { + return false, fmt.Errorf("%w: %s", ErrLoadFromProof, err) + } + + for _, item := range items { + recValue := proofTrie.Get(item.Key) + + // here we need to compare value only if the caller pass the value + if item.Value != nil && !bytes.Equal(item.Value, recValue) { + return false, ErrValueNotFound + } + } + + return true, nil +} diff --git a/lib/trie/proof_test.go b/lib/trie/proof_test.go index 2340395363..72a42ad81b 100644 --- a/lib/trie/proof_test.go +++ b/lib/trie/proof_test.go @@ -25,6 +25,7 @@ import ( ) func TestProofGeneration(t *testing.T) { + t.Parallel() tmp, err := ioutil.TempDir("", "*-test-trie") require.NoError(t, err) @@ -34,10 +35,12 @@ func TestProofGeneration(t *testing.T) { }) require.NoError(t, err) + expectedValue := rand32Bytes() + trie := NewEmptyTrie() trie.Put([]byte("cat"), rand32Bytes()) trie.Put([]byte("catapulta"), rand32Bytes()) - trie.Put([]byte("catapora"), rand32Bytes()) + trie.Put([]byte("catapora"), expectedValue) trie.Put([]byte("dog"), rand32Bytes()) trie.Put([]byte("doguinho"), rand32Bytes()) @@ -50,6 +53,123 @@ func TestProofGeneration(t *testing.T) { proof, err := GenerateProof(hash.ToBytes(), [][]byte{[]byte("catapulta"), []byte("catapora")}, memdb) require.NoError(t, err) - // TODO: use the verify_proof function to assert the tests (#1790) require.Equal(t, 5, len(proof)) + + pl := []Pair{ + {Key: []byte("catapora"), Value: expectedValue}, + } + + v, err := VerifyProof(proof, hash.ToBytes(), pl) + require.True(t, v) + require.NoError(t, err) +} + +func testGenerateProof(t *testing.T, entries []Pair, keys [][]byte) ([]byte, [][]byte, []Pair) { + t.Helper() + + tmp, err := ioutil.TempDir("", "*-test-trie") + require.NoError(t, err) + + memdb, err := chaindb.NewBadgerDB(&chaindb.Config{ + InMemory: true, + DataDir: tmp, + }) + require.NoError(t, err) + + trie := NewEmptyTrie() + for _, e := range entries { + trie.Put(e.Key, e.Value) + } + + err = trie.Store(memdb) + require.NoError(t, err) + + root := trie.root.getHash() + proof, err := GenerateProof(root, keys, memdb) + require.NoError(t, err) + + items := make([]Pair, len(keys)) + for idx, key := range keys { + value := trie.Get(key) + require.NotNil(t, value) + + itemFromDB := Pair{ + Key: key, + Value: value, + } + items[idx] = itemFromDB + } + + return root, proof, items +} + +func TestVerifyProof_ShouldReturnTrue(t *testing.T) { + t.Parallel() + + entries := []Pair{ + {Key: []byte("alpha"), Value: make([]byte, 32)}, + {Key: []byte("bravo"), Value: []byte("bravo")}, + {Key: []byte("do"), Value: []byte("verb")}, + {Key: []byte("dog"), Value: []byte("puppy")}, + {Key: []byte("doge"), Value: make([]byte, 32)}, + {Key: []byte("horse"), Value: []byte("stallion")}, + {Key: []byte("house"), Value: []byte("building")}, + } + + keys := [][]byte{ + []byte("do"), + []byte("dog"), + []byte("doge"), + } + + root, proof, pl := testGenerateProof(t, entries, keys) + + v, err := VerifyProof(proof, root, pl) + require.True(t, v) + require.NoError(t, err) +} + +func TestVerifyProof_ShouldReturnDuplicateKeysError(t *testing.T) { + t.Parallel() + + pl := []Pair{ + {Key: []byte("do"), Value: []byte("verb")}, + {Key: []byte("do"), Value: []byte("puppy")}, + } + + v, err := VerifyProof([][]byte{}, []byte{}, pl) + require.False(t, v) + require.Error(t, err, ErrDuplicateKeys) +} + +func TestVerifyProof_ShouldReturnTrueWithouCompareValues(t *testing.T) { + t.Parallel() + + entries := []Pair{ + {Key: []byte("alpha"), Value: make([]byte, 32)}, + {Key: []byte("bravo"), Value: []byte("bravo")}, + {Key: []byte("do"), Value: []byte("verb")}, + {Key: []byte("dog"), Value: []byte("puppy")}, + {Key: []byte("doge"), Value: make([]byte, 32)}, + {Key: []byte("horse"), Value: []byte("stallion")}, + {Key: []byte("house"), Value: []byte("building")}, + } + + keys := [][]byte{ + []byte("do"), + []byte("dog"), + []byte("doge"), + } + + root, proof, _ := testGenerateProof(t, entries, keys) + + pl := []Pair{ + {Key: []byte("do"), Value: nil}, + {Key: []byte("dog"), Value: nil}, + {Key: []byte("doge"), Value: nil}, + } + + v, err := VerifyProof(proof, root, pl) + require.True(t, v) + require.NoError(t, err) } diff --git a/lib/trie/recorder.go b/lib/trie/recorder.go index 7c2b9a40c9..5443e55401 100644 --- a/lib/trie/recorder.go +++ b/lib/trie/recorder.go @@ -6,15 +6,15 @@ type nodeRecord struct { hash []byte } -// Recorder keeps the list of nodes find by Lookup.Find +// recorder keeps the list of nodes find by Lookup.Find type recorder []nodeRecord -// Record insert a node insede the recorded list +// record insert a node inside the recorded list func (r *recorder) record(h, rd []byte) { *r = append(*r, nodeRecord{rawData: rd, hash: h}) } -// Next returns the current item the cursor is on and increment the cursor by 1 +// next returns the current item the cursor is on and increment the cursor by 1 func (r *recorder) next() *nodeRecord { if !r.isEmpty() { n := (*r)[0] @@ -25,7 +25,7 @@ func (r *recorder) next() *nodeRecord { return nil } -// IsEmpty returns bool if there is data inside the slice +// isEmpty returns bool if there is data inside the slice func (r *recorder) isEmpty() bool { return len(*r) <= 0 }