diff --git a/core/bloombits/matcher.go b/core/bloombits/matcher.go index e33de018a..32a660337 100644 --- a/core/bloombits/matcher.go +++ b/core/bloombits/matcher.go @@ -18,6 +18,7 @@ package bloombits import ( "bytes" + "context" "errors" "math" "sort" @@ -60,6 +61,8 @@ type Retrieval struct { Bit uint Sections []uint64 Bitsets [][]byte + Error error + Context context.Context } // Matcher is a pipelined system of schedulers and logic matchers which perform @@ -137,7 +140,7 @@ func (m *Matcher) addScheduler(idx uint) { // Start starts the matching process and returns a stream of bloom matches in // a given range of blocks. If there are no more matches in the range, the result // channel is closed. -func (m *Matcher) Start(begin, end uint64, results chan uint64) (*MatcherSession, error) { +func (m *Matcher) Start(ctx context.Context, begin, end uint64, results chan uint64) (*MatcherSession, error) { // Make sure we're not creating concurrent sessions if atomic.SwapUint32(&m.running, 1) == 1 { return nil, errors.New("matcher already running") @@ -149,6 +152,7 @@ func (m *Matcher) Start(begin, end uint64, results chan uint64) (*MatcherSession matcher: m, quit: make(chan struct{}), kill: make(chan struct{}), + ctx: ctx, } for _, scheduler := range m.schedulers { scheduler.reset() @@ -502,15 +506,28 @@ func (m *Matcher) distributor(dist chan *request, session *MatcherSession) { type MatcherSession struct { matcher *Matcher - quit chan struct{} // Quit channel to request pipeline termination - kill chan struct{} // Term channel to signal non-graceful forced shutdown - pend sync.WaitGroup + quit chan struct{} // Quit channel to request pipeline termination + kill chan struct{} // Term channel to signal non-graceful forced shutdown + ctx context.Context + err error + stopping bool + lock sync.Mutex + pend sync.WaitGroup } // Close stops the matching process and waits for all subprocesses to terminate // before returning. The timeout may be used for graceful shutdown, allowing the // currently running retrievals to complete before this time. -func (s *MatcherSession) Close(timeout time.Duration) { +func (s *MatcherSession) Close() { + s.lock.Lock() + stopping := s.stopping + s.stopping = true + s.lock.Unlock() + // ensure that we only close the session once + if stopping { + return + } + // Bail out if the matcher is not running select { case <-s.quit: @@ -519,10 +536,26 @@ func (s *MatcherSession) Close(timeout time.Duration) { } // Signal termination and wait for all goroutines to tear down close(s.quit) - time.AfterFunc(timeout, func() { close(s.kill) }) + time.AfterFunc(time.Second, func() { close(s.kill) }) s.pend.Wait() } +// setError sets an error and stops the session +func (s *MatcherSession) setError(err error) { + s.lock.Lock() + s.err = err + s.lock.Unlock() + s.Close() +} + +// Error returns an error if one has happened during the session +func (s *MatcherSession) Error() error { + s.lock.Lock() + defer s.lock.Unlock() + + return s.err +} + // AllocateRetrieval assigns a bloom bit index to a client process that can either // immediately reuest and fetch the section contents assigned to this bit or wait // a little while for more sections to be requested. @@ -618,9 +651,13 @@ func (s *MatcherSession) Multiplex(batch int, wait time.Duration, mux chan chan case mux <- request: // Retrieval accepted, something must arrive before we're aborting - request <- &Retrieval{Bit: bit, Sections: sections} + request <- &Retrieval{Bit: bit, Sections: sections, Context: s.ctx} result := <-request + if result.Error != nil { + s.setError(result.Error) + } + s.DeliverSections(result.Bit, result.Sections, result.Bitsets) } } diff --git a/core/bloombits/matcher_test.go b/core/bloombits/matcher_test.go index 2e15e7aac..0d8544136 100644 --- a/core/bloombits/matcher_test.go +++ b/core/bloombits/matcher_test.go @@ -17,6 +17,7 @@ package bloombits import ( + "context" "math/rand" "sync/atomic" "testing" @@ -144,7 +145,7 @@ func testMatcher(t *testing.T, filter [][]bloomIndexes, blocks uint64, intermitt quit := make(chan struct{}) matches := make(chan uint64, 16) - session, err := matcher.Start(0, blocks-1, matches) + session, err := matcher.Start(context.Background(), 0, blocks-1, matches) if err != nil { t.Fatalf("failed to stat matcher session: %v", err) } @@ -163,13 +164,13 @@ func testMatcher(t *testing.T, filter [][]bloomIndexes, blocks uint64, intermitt } // If we're testing intermittent mode, abort and restart the pipeline if intermittent { - session.Close(time.Second) + session.Close() close(quit) quit = make(chan struct{}) matches = make(chan uint64, 16) - session, err = matcher.Start(i+1, blocks-1, matches) + session, err = matcher.Start(context.Background(), i+1, blocks-1, matches) if err != nil { t.Fatalf("failed to stat matcher session: %v", err) } @@ -183,7 +184,7 @@ func testMatcher(t *testing.T, filter [][]bloomIndexes, blocks uint64, intermitt t.Errorf("filter = %v blocks = %v intermittent = %v: expected closed channel, got #%v", filter, blocks, intermittent, match) } // Clean up the session and ensure we match the expected retrieval count - session.Close(time.Second) + session.Close() close(quit) if retrievals != 0 && requested != retrievals { diff --git a/core/chain_indexer.go b/core/chain_indexer.go index f4c207dcc..837c908ab 100644 --- a/core/chain_indexer.go +++ b/core/chain_indexer.go @@ -36,13 +36,14 @@ import ( type ChainIndexerBackend interface { // Reset initiates the processing of a new chain segment, potentially terminating // any partially completed operations (in case of a reorg). - Reset(section uint64) + Reset(section uint64, lastSectionHead common.Hash) error // Process crunches through the next header in the chain segment. The caller // will ensure a sequential order of headers. Process(header *types.Header) - // Commit finalizes the section metadata and stores it into the database. + // Commit finalizes the section metadata and stores it into the database. This + // interface will usually be a batch writer. Commit() error } @@ -100,11 +101,34 @@ func NewChainIndexer(chainDb, indexDb ethdb.Database, backend ChainIndexerBacken return c } +// AddKnownSectionHead marks a new section head as known/processed if it is newer +// than the already known best section head +func (c *ChainIndexer) AddKnownSectionHead(section uint64, shead common.Hash) { + c.lock.Lock() + defer c.lock.Unlock() + + if section < c.storedSections { + return + } + c.setSectionHead(section, shead) + c.setValidSections(section + 1) +} + +// IndexerChain interface is used for connecting the indexer to a blockchain +type IndexerChain interface { + CurrentHeader() *types.Header + SubscribeChainEvent(ch chan<- ChainEvent) event.Subscription +} + // Start creates a goroutine to feed chain head events into the indexer for // cascading background processing. Children do not need to be started, they // are notified about new events by their parents. -func (c *ChainIndexer) Start(currentHeader *types.Header, chainEventer func(ch chan<- ChainEvent) event.Subscription) { - go c.eventLoop(currentHeader, chainEventer) +func (c *ChainIndexer) Start(chain IndexerChain) { + ch := make(chan ChainEvent, 10) + sub := chain.SubscribeChainEvent(ch) + currentHeader := chain.CurrentHeader() + + go c.eventLoop(currentHeader, ch, sub) } // Close tears down all goroutines belonging to the indexer and returns any error @@ -125,12 +149,14 @@ func (c *ChainIndexer) Close() error { errs = append(errs, err) } } + // Close all children for _, child := range c.children { if err := child.Close(); err != nil { errs = append(errs, err) } } + // Return any failures switch { case len(errs) == 0: @@ -147,12 +173,10 @@ func (c *ChainIndexer) Close() error { // eventLoop is a secondary - optional - event loop of the indexer which is only // started for the outermost indexer to push chain head events into a processing // queue. -func (c *ChainIndexer) eventLoop(currentHeader *types.Header, chainEventer func(ch chan<- ChainEvent) event.Subscription) { +func (c *ChainIndexer) eventLoop(currentHeader *types.Header, ch chan ChainEvent, sub event.Subscription) { // Mark the chain indexer as active, requiring an additional teardown atomic.StoreUint32(&c.active, 1) - events := make(chan ChainEvent, 10) - sub := chainEventer(events) defer sub.Unsubscribe() // Fire the initial new head event to start any outstanding processing @@ -169,7 +193,7 @@ func (c *ChainIndexer) eventLoop(currentHeader *types.Header, chainEventer func( errc <- nil return - case ev, ok := <-events: + case ev, ok := <-ch: // Received a new event, ensure it's not nil (closing) and update if !ok { errc := <-c.quit @@ -178,7 +202,9 @@ func (c *ChainIndexer) eventLoop(currentHeader *types.Header, chainEventer func( } header := ev.Block.Header() if header.ParentHash != prevHash { - c.newHead(FindCommonAncestor(c.chainDb, prevHeader, header).Number.Uint64(), true) + if h := FindCommonAncestor(c.chainDb, prevHeader, header); h != nil { + c.newHead(h.Number.Uint64(), true) + } } c.newHead(header.Number.Uint64(), false) @@ -233,9 +259,10 @@ func (c *ChainIndexer) newHead(head uint64, reorg bool) { // down into the processing backend. func (c *ChainIndexer) updateLoop() { var ( - updating bool - updated time.Time + updated time.Time + updateMsg bool ) + for { select { case errc := <-c.quit: @@ -250,7 +277,7 @@ func (c *ChainIndexer) updateLoop() { // Periodically print an upgrade log message to the user if time.Since(updated) > 8*time.Second { if c.knownSections > c.storedSections+1 { - updating = true + updateMsg = true c.log.Info("Upgrading chain index", "percentage", c.storedSections*100/c.knownSections) } updated = time.Now() @@ -259,7 +286,7 @@ func (c *ChainIndexer) updateLoop() { section := c.storedSections var oldHead common.Hash if section > 0 { - oldHead = c.sectionHead(section - 1) + oldHead = c.SectionHead(section - 1) } // Process the newly defined section in the background c.lock.Unlock() @@ -270,11 +297,11 @@ func (c *ChainIndexer) updateLoop() { c.lock.Lock() // If processing succeeded and no reorgs occcurred, mark the section completed - if err == nil && oldHead == c.sectionHead(section-1) { + if err == nil && oldHead == c.SectionHead(section-1) { c.setSectionHead(section, newHead) c.setValidSections(section + 1) - if c.storedSections == c.knownSections && updating { - updating = false + if c.storedSections == c.knownSections && updateMsg { + updateMsg = false c.log.Info("Finished upgrading chain index") } @@ -311,7 +338,11 @@ func (c *ChainIndexer) processSection(section uint64, lastHead common.Hash) (com c.log.Trace("Processing new chain section", "section", section) // Reset and partial processing - c.backend.Reset(section) + + if err := c.backend.Reset(section, lastHead); err != nil { + c.setValidSections(0) + return common.Hash{}, err + } for number := section * c.sectionSize; number < (section+1)*c.sectionSize; number++ { hash := GetCanonicalHash(c.chainDb, number) @@ -341,7 +372,7 @@ func (c *ChainIndexer) Sections() (uint64, uint64, common.Hash) { c.lock.Lock() defer c.lock.Unlock() - return c.storedSections, c.storedSections*c.sectionSize - 1, c.sectionHead(c.storedSections - 1) + return c.storedSections, c.storedSections*c.sectionSize - 1, c.SectionHead(c.storedSections - 1) } // AddChildIndexer adds a child ChainIndexer that can use the output of this one @@ -383,7 +414,7 @@ func (c *ChainIndexer) setValidSections(sections uint64) { // sectionHead retrieves the last block hash of a processed section from the // index database. -func (c *ChainIndexer) sectionHead(section uint64) common.Hash { +func (c *ChainIndexer) SectionHead(section uint64) common.Hash { var data [8]byte binary.BigEndian.PutUint64(data[:], section) diff --git a/core/chain_indexer_test.go b/core/chain_indexer_test.go index b761e8a5b..d685d3f8d 100644 --- a/core/chain_indexer_test.go +++ b/core/chain_indexer_test.go @@ -23,6 +23,7 @@ import ( "testing" "time" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" ) @@ -208,9 +209,10 @@ func (b *testChainIndexBackend) reorg(headNum uint64) uint64 { return b.stored * b.indexer.sectionSize } -func (b *testChainIndexBackend) Reset(section uint64) { +func (b *testChainIndexBackend) Reset(section uint64, lastSectionHead common.Hash) error { b.section = section b.headerCnt = 0 + return nil } func (b *testChainIndexBackend) Process(header *types.Header) { diff --git a/core/database_util.go b/core/database_util.go index 1730a048e..c6b125dae 100644 --- a/core/database_util.go +++ b/core/database_util.go @@ -74,9 +74,9 @@ var ( preimageHitCounter = metrics.NewCounter("db/preimage/hits") ) -// txLookupEntry is a positional metadata to help looking up the data content of +// TxLookupEntry is a positional metadata to help looking up the data content of // a transaction or receipt given only its hash. -type txLookupEntry struct { +type TxLookupEntry struct { BlockHash common.Hash BlockIndex uint64 Index uint64 @@ -260,7 +260,7 @@ func GetTxLookupEntry(db DatabaseReader, hash common.Hash) (common.Hash, uint64, return common.Hash{}, 0, 0 } // Parse and return the contents of the lookup entry - var entry txLookupEntry + var entry TxLookupEntry if err := rlp.DecodeBytes(data, &entry); err != nil { log.Error("Invalid lookup entry RLP", "hash", hash, "err", err) return common.Hash{}, 0, 0 @@ -296,7 +296,7 @@ func GetTransaction(db DatabaseReader, hash common.Hash) (*types.Transaction, co if len(data) == 0 { return nil, common.Hash{}, 0, 0 } - var entry txLookupEntry + var entry TxLookupEntry if err := rlp.DecodeBytes(data, &entry); err != nil { return nil, common.Hash{}, 0, 0 } @@ -332,14 +332,13 @@ func GetReceipt(db DatabaseReader, hash common.Hash) (*types.Receipt, common.Has // GetBloomBits retrieves the compressed bloom bit vector belonging to the given // section and bit index from the. -func GetBloomBits(db DatabaseReader, bit uint, section uint64, head common.Hash) []byte { +func GetBloomBits(db DatabaseReader, bit uint, section uint64, head common.Hash) ([]byte, error) { key := append(append(bloomBitsPrefix, make([]byte, 10)...), head.Bytes()...) binary.BigEndian.PutUint16(key[1:], uint16(bit)) binary.BigEndian.PutUint64(key[3:], section) - bits, _ := db.Get(key) - return bits + return db.Get(key) } // WriteCanonicalHash stores the canonical hash for the given block number. @@ -465,7 +464,7 @@ func WriteBlockReceipts(db ethdb.Putter, hash common.Hash, number uint64, receip func WriteTxLookupEntries(db ethdb.Putter, block *types.Block) error { // Iterate over each transaction and encode its metadata for i, tx := range block.Transactions() { - entry := txLookupEntry{ + entry := TxLookupEntry{ BlockHash: block.Hash(), BlockIndex: block.NumberU64(), Index: uint64(i), diff --git a/core/tx_list.go b/core/tx_list.go index 2935929d7..94721aa5f 100644 --- a/core/tx_list.go +++ b/core/tx_list.go @@ -384,13 +384,13 @@ func (h *priceHeap) Pop() interface{} { // txPricedList is a price-sorted heap to allow operating on transactions pool // contents in a price-incrementing way. type txPricedList struct { - all *map[common.Hash]*types.Transaction // Pointer to the map of all transactions - items *priceHeap // Heap of prices of all the stored transactions - stales int // Number of stale price points to (re-heap trigger) + all *map[common.Hash]txLookupRec // Pointer to the map of all transactions + items *priceHeap // Heap of prices of all the stored transactions + stales int // Number of stale price points to (re-heap trigger) } // newTxPricedList creates a new price-sorted transaction heap. -func newTxPricedList(all *map[common.Hash]*types.Transaction) *txPricedList { +func newTxPricedList(all *map[common.Hash]txLookupRec) *txPricedList { return &txPricedList{ all: all, items: new(priceHeap), @@ -416,7 +416,7 @@ func (l *txPricedList) Removed() { l.stales, l.items = 0, &reheap for _, tx := range *l.all { - *l.items = append(*l.items, tx) + *l.items = append(*l.items, tx.tx) } heap.Init(l.items) } diff --git a/core/tx_pool.go b/core/tx_pool.go index a705e36d6..5fdc91e65 100644 --- a/core/tx_pool.go +++ b/core/tx_pool.go @@ -192,17 +192,22 @@ type TxPool struct { locals *accountSet // Set of local transaction to exepmt from evicion rules journal *txJournal // Journal of local transaction to back up to disk - pending map[common.Address]*txList // All currently processable transactions - queue map[common.Address]*txList // Queued but non-processable transactions - beats map[common.Address]time.Time // Last heartbeat from each known account - all map[common.Hash]*types.Transaction // All transactions to allow lookups - priced *txPricedList // All transactions sorted by price + pending map[common.Address]*txList // All currently processable transactions + queue map[common.Address]*txList // Queued but non-processable transactions + beats map[common.Address]time.Time // Last heartbeat from each known account + all map[common.Hash]txLookupRec // All transactions to allow lookups + priced *txPricedList // All transactions sorted by price wg sync.WaitGroup // for shutdown sync homestead bool } +type txLookupRec struct { + tx *types.Transaction + pending bool +} + // NewTxPool creates a new transaction pool to gather, sort and filter inbound // trnsactions from the network. func NewTxPool(config TxPoolConfig, chainconfig *params.ChainConfig, chain blockChain) *TxPool { @@ -218,7 +223,7 @@ func NewTxPool(config TxPoolConfig, chainconfig *params.ChainConfig, chain block pending: make(map[common.Address]*txList), queue: make(map[common.Address]*txList), beats: make(map[common.Address]time.Time), - all: make(map[common.Hash]*types.Transaction), + all: make(map[common.Hash]txLookupRec), chainHeadCh: make(chan ChainHeadEvent, chainHeadChanSize), gasPrice: new(big.Int).SetUint64(config.PriceLimit), } @@ -594,7 +599,7 @@ func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { func (pool *TxPool) add(tx *types.Transaction, local bool) (bool, error) { // If the transaction is already known, discard it hash := tx.Hash() - if pool.all[hash] != nil { + if _, ok := pool.all[hash]; ok { log.Trace("Discarding already known transaction", "hash", hash) return false, fmt.Errorf("known transaction: %x", hash) } @@ -635,7 +640,7 @@ func (pool *TxPool) add(tx *types.Transaction, local bool) (bool, error) { pool.priced.Removed() pendingReplaceCounter.Inc(1) } - pool.all[tx.Hash()] = tx + pool.all[tx.Hash()] = txLookupRec{tx, false} pool.priced.Put(tx) pool.journalTx(from, tx) @@ -682,7 +687,7 @@ func (pool *TxPool) enqueueTx(hash common.Hash, tx *types.Transaction) (bool, er pool.priced.Removed() queuedReplaceCounter.Inc(1) } - pool.all[hash] = tx + pool.all[hash] = txLookupRec{tx, false} pool.priced.Put(tx) return old != nil, nil } @@ -725,10 +730,13 @@ func (pool *TxPool) promoteTx(addr common.Address, hash common.Hash, tx *types.T pendingReplaceCounter.Inc(1) } - // Failsafe to work around direct pending inserts (tests) - if pool.all[hash] == nil { - pool.all[hash] = tx + if pool.all[hash].tx == nil { + // Failsafe to work around direct pending inserts (tests) + pool.all[hash] = txLookupRec{tx, true} pool.priced.Put(tx) + } else { + // set pending flag to true + pool.all[hash] = txLookupRec{tx, true} } // Set the potentially new pending nonce and notify any subsystems of the new tx pool.beats[addr] = time.Now() @@ -755,14 +763,16 @@ func (pool *TxPool) AddRemote(tx *types.Transaction) error { // marking the senders as a local ones in the mean time, ensuring they go around // the local pricing constraints. func (pool *TxPool) AddLocals(txs []*types.Transaction) error { - return pool.addTxs(txs, !pool.config.NoLocals) + pool.addTxs(txs, !pool.config.NoLocals) + return nil } // AddRemotes enqueues a batch of transactions into the pool if they are valid. // If the senders are not among the locally tracked ones, full pricing constraints // will apply. func (pool *TxPool) AddRemotes(txs []*types.Transaction) error { - return pool.addTxs(txs, false) + pool.addTxs(txs, false) + return nil } // addTx enqueues a single transaction into the pool if it is valid. @@ -784,7 +794,7 @@ func (pool *TxPool) addTx(tx *types.Transaction, local bool) error { } // addTxs attempts to queue a batch of transactions if they are valid. -func (pool *TxPool) addTxs(txs []*types.Transaction, local bool) error { +func (pool *TxPool) addTxs(txs []*types.Transaction, local bool) []error { pool.mu.Lock() defer pool.mu.Unlock() @@ -793,11 +803,13 @@ func (pool *TxPool) addTxs(txs []*types.Transaction, local bool) error { // addTxsLocked attempts to queue a batch of transactions if they are valid, // whilst assuming the transaction pool lock is already held. -func (pool *TxPool) addTxsLocked(txs []*types.Transaction, local bool) error { +func (pool *TxPool) addTxsLocked(txs []*types.Transaction, local bool) []error { // Add the batch of transaction, tracking the accepted ones dirty := make(map[common.Address]struct{}) - for _, tx := range txs { - if replace, err := pool.add(tx, local); err == nil { + txErr := make([]error, len(txs)) + for i, tx := range txs { + var replace bool + if replace, txErr[i] = pool.add(tx, local); txErr[i] == nil { if !replace { from, _ := types.Sender(pool.signer, tx) // already validated dirty[from] = struct{}{} @@ -812,7 +824,58 @@ func (pool *TxPool) addTxsLocked(txs []*types.Transaction, local bool) error { } pool.promoteExecutables(addrs) } - return nil + return txErr +} + +// TxStatusData is returned by AddOrGetTxStatus for each transaction +type TxStatusData struct { + Status uint + Data []byte +} + +const ( + TxStatusUnknown = iota + TxStatusQueued + TxStatusPending + TxStatusIncluded // Data contains a TxChainPos struct + TxStatusError // Data contains the error string +) + +// AddOrGetTxStatus returns the status (unknown/pending/queued) of a batch of transactions +// identified by their hashes in txHashes. Optionally the transactions themselves can be +// passed too in txs, in which case the function will try adding the previously unknown ones +// to the pool. If a new transaction cannot be added, TxStatusError is returned. Adding already +// known transactions will return their previous status. +// If txs is specified, txHashes is still required and has to match the transactions in txs. + +// Note: TxStatusIncluded is never returned by this function since the pool does not track +// mined transactions. Included status can be checked by the caller (as it happens in the +// LES protocol manager) +func (pool *TxPool) AddOrGetTxStatus(txs []*types.Transaction, txHashes []common.Hash) []TxStatusData { + status := make([]TxStatusData, len(txHashes)) + if txs != nil { + if len(txs) != len(txHashes) { + panic(nil) + } + txErr := pool.addTxs(txs, false) + for i, err := range txErr { + if err != nil { + status[i] = TxStatusData{TxStatusError, ([]byte)(err.Error())} + } + } + } + + for i, hash := range txHashes { + r, ok := pool.all[hash] + if ok { + if r.pending { + status[i] = TxStatusData{TxStatusPending, nil} + } else { + status[i] = TxStatusData{TxStatusQueued, nil} + } + } + } + return status } // Get returns a transaction if it is contained in the pool @@ -821,17 +884,18 @@ func (pool *TxPool) Get(hash common.Hash) *types.Transaction { pool.mu.RLock() defer pool.mu.RUnlock() - return pool.all[hash] + return pool.all[hash].tx } // removeTx removes a single transaction from the queue, moving all subsequent // transactions back to the future queue. func (pool *TxPool) removeTx(hash common.Hash) { // Fetch the transaction we wish to delete - tx, ok := pool.all[hash] + txl, ok := pool.all[hash] if !ok { return } + tx := txl.tx addr, _ := types.Sender(pool.signer, tx) // already validated during insertion // Remove it from the list of known transactions diff --git a/eth/backend.go b/eth/backend.go index 6a06bd829..1cd9e8fff 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -54,6 +54,7 @@ type LesServer interface { Start(srvr *p2p.Server) Stop() Protocols() []p2p.Protocol + SetBloomBitsIndexer(bbIndexer *core.ChainIndexer) } // Ethereum implements the Ethereum full node service. @@ -95,6 +96,7 @@ type Ethereum struct { func (s *Ethereum) AddLesServer(ls LesServer) { s.lesServer = ls + ls.SetBloomBitsIndexer(s.bloomIndexer) } // New creates a new Ethereum object (including the @@ -154,7 +156,7 @@ func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) { eth.blockchain.SetHead(compat.RewindTo) core.WriteChainConfig(chainDb, genesisHash, chainConfig) } - eth.bloomIndexer.Start(eth.blockchain.CurrentHeader(), eth.blockchain.SubscribeChainEvent) + eth.bloomIndexer.Start(eth.blockchain) if config.TxPool.Journal != "" { config.TxPool.Journal = ctx.ResolvePath(config.TxPool.Journal) diff --git a/eth/bloombits.go b/eth/bloombits.go index 32f6c7b31..c5597391c 100644 --- a/eth/bloombits.go +++ b/eth/bloombits.go @@ -58,15 +58,18 @@ func (eth *Ethereum) startBloomHandlers() { case request := <-eth.bloomRequests: task := <-request - task.Bitsets = make([][]byte, len(task.Sections)) for i, section := range task.Sections { head := core.GetCanonicalHash(eth.chainDb, (section+1)*params.BloomBitsBlocks-1) - blob, err := bitutil.DecompressBytes(core.GetBloomBits(eth.chainDb, task.Bit, section, head), int(params.BloomBitsBlocks)/8) - if err != nil { - panic(err) + if compVector, err := core.GetBloomBits(eth.chainDb, task.Bit, section, head); err == nil { + if blob, err := bitutil.DecompressBytes(compVector, int(params.BloomBitsBlocks)/8); err == nil { + task.Bitsets[i] = blob + } else { + task.Error = err + } + } else { + task.Error = err } - task.Bitsets[i] = blob } request <- task } @@ -111,12 +114,10 @@ func NewBloomIndexer(db ethdb.Database, size uint64) *core.ChainIndexer { // Reset implements core.ChainIndexerBackend, starting a new bloombits index // section. -func (b *BloomIndexer) Reset(section uint64) { +func (b *BloomIndexer) Reset(section uint64, lastSectionHead common.Hash) error { gen, err := bloombits.NewGenerator(uint(b.size)) - if err != nil { - panic(err) - } b.gen, b.section, b.head = gen, section, common.Hash{} + return err } // Process implements core.ChainIndexerBackend, adding a new header's bloom into diff --git a/eth/filters/filter.go b/eth/filters/filter.go index 026cbf95c..d16af84ee 100644 --- a/eth/filters/filter.go +++ b/eth/filters/filter.go @@ -19,7 +19,6 @@ package filters import ( "context" "math/big" - "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" @@ -136,11 +135,11 @@ func (f *Filter) indexedLogs(ctx context.Context, end uint64) ([]*types.Log, err // Create a matcher session and request servicing from the backend matches := make(chan uint64, 64) - session, err := f.matcher.Start(uint64(f.begin), end, matches) + session, err := f.matcher.Start(ctx, uint64(f.begin), end, matches) if err != nil { return nil, err } - defer session.Close(time.Second) + defer session.Close() f.backend.ServiceFilter(ctx, session) @@ -152,9 +151,13 @@ func (f *Filter) indexedLogs(ctx context.Context, end uint64) ([]*types.Log, err case number, ok := <-matches: // Abort if all matches have been fulfilled if !ok { - f.begin = int64(end) + 1 - return logs, nil + err := session.Error() + if err == nil { + f.begin = int64(end) + 1 + } + return logs, err } + f.begin = int64(number) + 1 // Retrieve the suggested block and pull any truly matching logs header, err := f.backend.HeaderByNumber(ctx, rpc.BlockNumber(number)) if header == nil || err != nil { diff --git a/eth/filters/filter_system_test.go b/eth/filters/filter_system_test.go index bc3511f23..7da114fda 100644 --- a/eth/filters/filter_system_test.go +++ b/eth/filters/filter_system_test.go @@ -109,7 +109,7 @@ func (b *testBackend) ServiceFilter(ctx context.Context, session *bloombits.Matc for i, section := range task.Sections { if rand.Int()%4 != 0 { // Handle occasional missing deliveries head := core.GetCanonicalHash(b.db, (section+1)*params.BloomBitsBlocks-1) - task.Bitsets[i] = core.GetBloomBits(b.db, task.Bit, section, head) + task.Bitsets[i], _ = core.GetBloomBits(b.db, task.Bit, section, head) } } request <- task diff --git a/ethstats/ethstats.go b/ethstats/ethstats.go index bb03dc72b..77784ff4a 100644 --- a/ethstats/ethstats.go +++ b/ethstats/ethstats.go @@ -379,7 +379,7 @@ func (s *Service) login(conn *websocket.Conn) error { protocol = fmt.Sprintf("eth/%d", eth.ProtocolVersions[0]) } else { network = fmt.Sprintf("%d", infos.Protocols["les"].(*eth.EthNodeInfo).Network) - protocol = fmt.Sprintf("les/%d", les.ProtocolVersions[0]) + protocol = fmt.Sprintf("les/%d", les.ClientProtocolVersions[0]) } auth := &authMsg{ Id: s.node, diff --git a/les/api_backend.go b/les/api_backend.go index 0d2d31b67..56f617a7d 100644 --- a/les/api_backend.go +++ b/les/api_backend.go @@ -174,8 +174,15 @@ func (b *LesApiBackend) AccountManager() *accounts.Manager { } func (b *LesApiBackend) BloomStatus() (uint64, uint64) { - return params.BloomBitsBlocks, 0 + if b.eth.bloomIndexer == nil { + return 0, 0 + } + sections, _, _ := b.eth.bloomIndexer.Sections() + return light.BloomTrieFrequency, sections } func (b *LesApiBackend) ServiceFilter(ctx context.Context, session *bloombits.MatcherSession) { + for i := 0; i < bloomFilterThreads; i++ { + go session.Multiplex(bloomRetrievalBatch, bloomRetrievalWait, b.eth.bloomRequests) + } } diff --git a/les/backend.go b/les/backend.go index 4c33417c0..3a68d13eb 100644 --- a/les/backend.go +++ b/les/backend.go @@ -27,6 +27,7 @@ import ( "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/consensus" "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/bloombits" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/eth/downloader" @@ -61,6 +62,9 @@ type LightEthereum struct { // DB interfaces chainDb ethdb.Database // Block chain database + bloomRequests chan chan *bloombits.Retrieval // Channel receiving bloom data retrieval requests + bloomIndexer, chtIndexer, bloomTrieIndexer *core.ChainIndexer + ApiBackend *LesApiBackend eventMux *event.TypeMux @@ -87,47 +91,61 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { peers := newPeerSet() quitSync := make(chan struct{}) - eth := &LightEthereum{ - chainConfig: chainConfig, - chainDb: chainDb, - eventMux: ctx.EventMux, - peers: peers, - reqDist: newRequestDistributor(peers, quitSync), - accountManager: ctx.AccountManager, - engine: eth.CreateConsensusEngine(ctx, config, chainConfig, chainDb), - shutdownChan: make(chan bool), - networkId: config.NetworkId, + leth := &LightEthereum{ + chainConfig: chainConfig, + chainDb: chainDb, + eventMux: ctx.EventMux, + peers: peers, + reqDist: newRequestDistributor(peers, quitSync), + accountManager: ctx.AccountManager, + engine: eth.CreateConsensusEngine(ctx, config, chainConfig, chainDb), + shutdownChan: make(chan bool), + networkId: config.NetworkId, + bloomRequests: make(chan chan *bloombits.Retrieval), + bloomIndexer: eth.NewBloomIndexer(chainDb, light.BloomTrieFrequency), + chtIndexer: light.NewChtIndexer(chainDb, true), + bloomTrieIndexer: light.NewBloomTrieIndexer(chainDb, true), } - eth.relay = NewLesTxRelay(peers, eth.reqDist) - eth.serverPool = newServerPool(chainDb, quitSync, ð.wg) - eth.retriever = newRetrieveManager(peers, eth.reqDist, eth.serverPool) - eth.odr = NewLesOdr(chainDb, eth.retriever) - if eth.blockchain, err = light.NewLightChain(eth.odr, eth.chainConfig, eth.engine); err != nil { + leth.relay = NewLesTxRelay(peers, leth.reqDist) + leth.serverPool = newServerPool(chainDb, quitSync, &leth.wg) + leth.retriever = newRetrieveManager(peers, leth.reqDist, leth.serverPool) + leth.odr = NewLesOdr(chainDb, leth.chtIndexer, leth.bloomTrieIndexer, leth.bloomIndexer, leth.retriever) + if leth.blockchain, err = light.NewLightChain(leth.odr, leth.chainConfig, leth.engine); err != nil { return nil, err } + leth.bloomIndexer.Start(leth.blockchain) // Rewind the chain in case of an incompatible config upgrade. if compat, ok := genesisErr.(*params.ConfigCompatError); ok { log.Warn("Rewinding chain to upgrade configuration", "err", compat) - eth.blockchain.SetHead(compat.RewindTo) + leth.blockchain.SetHead(compat.RewindTo) core.WriteChainConfig(chainDb, genesisHash, chainConfig) } - eth.txPool = light.NewTxPool(eth.chainConfig, eth.blockchain, eth.relay) - if eth.protocolManager, err = NewProtocolManager(eth.chainConfig, true, config.NetworkId, eth.eventMux, eth.engine, eth.peers, eth.blockchain, nil, chainDb, eth.odr, eth.relay, quitSync, ð.wg); err != nil { + leth.txPool = light.NewTxPool(leth.chainConfig, leth.blockchain, leth.relay) + if leth.protocolManager, err = NewProtocolManager(leth.chainConfig, true, ClientProtocolVersions, config.NetworkId, leth.eventMux, leth.engine, leth.peers, leth.blockchain, nil, chainDb, leth.odr, leth.relay, quitSync, &leth.wg); err != nil { return nil, err } - eth.ApiBackend = &LesApiBackend{eth, nil} + leth.ApiBackend = &LesApiBackend{leth, nil} gpoParams := config.GPO if gpoParams.Default == nil { gpoParams.Default = config.GasPrice } - eth.ApiBackend.gpo = gasprice.NewOracle(eth.ApiBackend, gpoParams) - return eth, nil + leth.ApiBackend.gpo = gasprice.NewOracle(leth.ApiBackend, gpoParams) + return leth, nil } -func lesTopic(genesisHash common.Hash) discv5.Topic { - return discv5.Topic("LES@" + common.Bytes2Hex(genesisHash.Bytes()[0:8])) +func lesTopic(genesisHash common.Hash, protocolVersion uint) discv5.Topic { + var name string + switch protocolVersion { + case lpv1: + name = "LES" + case lpv2: + name = "LES2" + default: + panic(nil) + } + return discv5.Topic(name + common.Bytes2Hex(genesisHash.Bytes()[0:8])) } type LightDummyAPI struct{} @@ -200,9 +218,13 @@ func (s *LightEthereum) Protocols() []p2p.Protocol { // Start implements node.Service, starting all internal goroutines needed by the // Ethereum protocol implementation. func (s *LightEthereum) Start(srvr *p2p.Server) error { + s.startBloomHandlers() log.Warn("Light client mode is an experimental feature") s.netRPCService = ethapi.NewPublicNetAPI(srvr, s.networkId) - s.serverPool.start(srvr, lesTopic(s.blockchain.Genesis().Hash())) + // search the topic belonging to the oldest supported protocol because + // servers always advertise all supported protocols + protocolVersion := ClientProtocolVersions[len(ClientProtocolVersions)-1] + s.serverPool.start(srvr, lesTopic(s.blockchain.Genesis().Hash(), protocolVersion)) s.protocolManager.Start() return nil } @@ -211,6 +233,15 @@ func (s *LightEthereum) Start(srvr *p2p.Server) error { // Ethereum protocol. func (s *LightEthereum) Stop() error { s.odr.Stop() + if s.bloomIndexer != nil { + s.bloomIndexer.Close() + } + if s.chtIndexer != nil { + s.chtIndexer.Close() + } + if s.bloomTrieIndexer != nil { + s.bloomTrieIndexer.Close() + } s.blockchain.Stop() s.protocolManager.Stop() s.txPool.Stop() diff --git a/les/bloombits.go b/les/bloombits.go new file mode 100644 index 000000000..dff83d349 --- /dev/null +++ b/les/bloombits.go @@ -0,0 +1,84 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package les + +import ( + "time" + + "github.com/ethereum/go-ethereum/common/bitutil" + "github.com/ethereum/go-ethereum/light" +) + +const ( + // bloomServiceThreads is the number of goroutines used globally by an Ethereum + // instance to service bloombits lookups for all running filters. + bloomServiceThreads = 16 + + // bloomFilterThreads is the number of goroutines used locally per filter to + // multiplex requests onto the global servicing goroutines. + bloomFilterThreads = 3 + + // bloomRetrievalBatch is the maximum number of bloom bit retrievals to service + // in a single batch. + bloomRetrievalBatch = 16 + + // bloomRetrievalWait is the maximum time to wait for enough bloom bit requests + // to accumulate request an entire batch (avoiding hysteresis). + bloomRetrievalWait = time.Microsecond * 100 +) + +// startBloomHandlers starts a batch of goroutines to accept bloom bit database +// retrievals from possibly a range of filters and serving the data to satisfy. +func (eth *LightEthereum) startBloomHandlers() { + for i := 0; i < bloomServiceThreads; i++ { + go func() { + for { + select { + case <-eth.shutdownChan: + return + + case request := <-eth.bloomRequests: + task := <-request + task.Bitsets = make([][]byte, len(task.Sections)) + compVectors, err := light.GetBloomBits(task.Context, eth.odr, task.Bit, task.Sections) + if err == nil { + for i, _ := range task.Sections { + if blob, err := bitutil.DecompressBytes(compVectors[i], int(light.BloomTrieFrequency/8)); err == nil { + task.Bitsets[i] = blob + } else { + task.Error = err + } + } + } else { + task.Error = err + } + request <- task + } + } + }() + } +} + +const ( + // bloomConfirms is the number of confirmation blocks before a bloom section is + // considered probably final and its rotated bits are calculated. + bloomConfirms = 256 + + // bloomThrottling is the time to wait between processing two consecutive index + // sections. It's useful during chain upgrades to prevent disk overload. + bloomThrottling = 100 * time.Millisecond +) diff --git a/les/handler.go b/les/handler.go index df7eb6af5..de07b7244 100644 --- a/les/handler.go +++ b/les/handler.go @@ -18,6 +18,7 @@ package les import ( + "bytes" "encoding/binary" "errors" "fmt" @@ -35,6 +36,7 @@ import ( "github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/event" + "github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/discover" @@ -50,13 +52,14 @@ const ( ethVersion = 63 // equivalent eth version for the downloader - MaxHeaderFetch = 192 // Amount of block headers to be fetched per retrieval request - MaxBodyFetch = 32 // Amount of block bodies to be fetched per retrieval request - MaxReceiptFetch = 128 // Amount of transaction receipts to allow fetching per request - MaxCodeFetch = 64 // Amount of contract codes to allow fetching per request - MaxProofsFetch = 64 // Amount of merkle proofs to be fetched per retrieval request - MaxHeaderProofsFetch = 64 // Amount of merkle proofs to be fetched per retrieval request - MaxTxSend = 64 // Amount of transactions to be send per request + MaxHeaderFetch = 192 // Amount of block headers to be fetched per retrieval request + MaxBodyFetch = 32 // Amount of block bodies to be fetched per retrieval request + MaxReceiptFetch = 128 // Amount of transaction receipts to allow fetching per request + MaxCodeFetch = 64 // Amount of contract codes to allow fetching per request + MaxProofsFetch = 64 // Amount of merkle proofs to be fetched per retrieval request + MaxHelperTrieProofsFetch = 64 // Amount of merkle proofs to be fetched per retrieval request + MaxTxSend = 64 // Amount of transactions to be send per request + MaxTxStatus = 256 // Amount of transactions to queried per request disableClientRemovePeer = false ) @@ -86,8 +89,7 @@ type BlockChain interface { } type txPool interface { - // AddRemotes should add the given transactions to the pool. - AddRemotes([]*types.Transaction) error + AddOrGetTxStatus(txs []*types.Transaction, txHashes []common.Hash) []core.TxStatusData } type ProtocolManager struct { @@ -125,7 +127,7 @@ type ProtocolManager struct { // NewProtocolManager returns a new ethereum sub protocol manager. The Ethereum sub protocol manages peers capable // with the ethereum network. -func NewProtocolManager(chainConfig *params.ChainConfig, lightSync bool, networkId uint64, mux *event.TypeMux, engine consensus.Engine, peers *peerSet, blockchain BlockChain, txpool txPool, chainDb ethdb.Database, odr *LesOdr, txrelay *LesTxRelay, quitSync chan struct{}, wg *sync.WaitGroup) (*ProtocolManager, error) { +func NewProtocolManager(chainConfig *params.ChainConfig, lightSync bool, protocolVersions []uint, networkId uint64, mux *event.TypeMux, engine consensus.Engine, peers *peerSet, blockchain BlockChain, txpool txPool, chainDb ethdb.Database, odr *LesOdr, txrelay *LesTxRelay, quitSync chan struct{}, wg *sync.WaitGroup) (*ProtocolManager, error) { // Create the protocol manager with the base fields manager := &ProtocolManager{ lightSync: lightSync, @@ -147,15 +149,16 @@ func NewProtocolManager(chainConfig *params.ChainConfig, lightSync bool, network manager.retriever = odr.retriever manager.reqDist = odr.retriever.dist } + // Initiate a sub-protocol for every implemented version we can handle - manager.SubProtocols = make([]p2p.Protocol, 0, len(ProtocolVersions)) - for i, version := range ProtocolVersions { + manager.SubProtocols = make([]p2p.Protocol, 0, len(protocolVersions)) + for _, version := range protocolVersions { // Compatible, initialize the sub-protocol version := version // Closure for the run manager.SubProtocols = append(manager.SubProtocols, p2p.Protocol{ Name: "les", Version: version, - Length: ProtocolLengths[i], + Length: ProtocolLengths[version], Run: func(p *p2p.Peer, rw p2p.MsgReadWriter) error { var entry *poolEntry peer := manager.newPeer(int(version), networkId, p, rw) @@ -315,7 +318,7 @@ func (pm *ProtocolManager) handle(p *peer) error { } } -var reqList = []uint64{GetBlockHeadersMsg, GetBlockBodiesMsg, GetCodeMsg, GetReceiptsMsg, GetProofsMsg, SendTxMsg, GetHeaderProofsMsg} +var reqList = []uint64{GetBlockHeadersMsg, GetBlockBodiesMsg, GetCodeMsg, GetReceiptsMsg, GetProofsV1Msg, SendTxMsg, SendTxV2Msg, GetTxStatusMsg, GetHeaderProofsMsg, GetProofsV2Msg, GetHelperTrieProofsMsg} // handleMsg is invoked whenever an inbound message is received from a remote // peer. The remote connection is torn down upon returning any error. @@ -362,11 +365,23 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { // Block header query, collect the requested headers and reply case AnnounceMsg: p.Log().Trace("Received announce message") + if p.requestAnnounceType == announceTypeNone { + return errResp(ErrUnexpectedResponse, "") + } var req announceData if err := msg.Decode(&req); err != nil { return errResp(ErrDecode, "%v: %v", msg, err) } + + if p.requestAnnounceType == announceTypeSigned { + if err := req.checkSignature(p.pubKey); err != nil { + p.Log().Trace("Invalid announcement signature", "err", err) + return err + } + p.Log().Trace("Valid announcement signature") + } + p.Log().Trace("Announce message content", "number", req.Number, "hash", req.Hash, "td", req.Td, "reorg", req.ReorgDepth) if pm.fetcher != nil { pm.fetcher.announce(p, &req) @@ -655,7 +670,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { Obj: resp.Receipts, } - case GetProofsMsg: + case GetProofsV1Msg: p.Log().Trace("Received proofs request") // Decode the retrieval message var req struct { @@ -690,9 +705,10 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { } } if tr != nil { - proof := tr.Prove(req.Key) + var proof light.NodeList + tr.Prove(req.Key, 0, &proof) proofs = append(proofs, proof) - bytes += len(proof) + bytes += proof.DataSize() } } } @@ -701,7 +717,67 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost) return p.SendProofs(req.ReqID, bv, proofs) - case ProofsMsg: + case GetProofsV2Msg: + p.Log().Trace("Received les/2 proofs request") + // Decode the retrieval message + var req struct { + ReqID uint64 + Reqs []ProofReq + } + if err := msg.Decode(&req); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + // Gather state data until the fetch or network limits is reached + var ( + lastBHash common.Hash + lastAccKey []byte + tr, str *trie.Trie + ) + reqCnt := len(req.Reqs) + if reject(uint64(reqCnt), MaxProofsFetch) { + return errResp(ErrRequestRejected, "") + } + + nodes := light.NewNodeSet() + + for _, req := range req.Reqs { + if nodes.DataSize() >= softResponseLimit { + break + } + if tr == nil || req.BHash != lastBHash { + if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil { + tr, _ = trie.New(header.Root, pm.chainDb) + } else { + tr = nil + } + lastBHash = req.BHash + str = nil + } + if tr != nil { + if len(req.AccKey) > 0 { + if str == nil || !bytes.Equal(req.AccKey, lastAccKey) { + sdata := tr.Get(req.AccKey) + str = nil + var acc state.Account + if err := rlp.DecodeBytes(sdata, &acc); err == nil { + str, _ = trie.New(acc.Root, pm.chainDb) + } + lastAccKey = common.CopyBytes(req.AccKey) + } + if str != nil { + str.Prove(req.Key, req.FromLevel, nodes) + } + } else { + tr.Prove(req.Key, req.FromLevel, nodes) + } + } + } + proofs := nodes.NodeList() + bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost) + pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost) + return p.SendProofsV2(req.ReqID, bv, proofs) + + case ProofsV1Msg: if pm.odr == nil { return errResp(ErrUnexpectedResponse, "") } @@ -710,14 +786,35 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { // A batch of merkle proofs arrived to one of our previous requests var resp struct { ReqID, BV uint64 - Data [][]rlp.RawValue + Data []light.NodeList } if err := msg.Decode(&resp); err != nil { return errResp(ErrDecode, "msg %v: %v", msg, err) } p.fcServer.GotReply(resp.ReqID, resp.BV) deliverMsg = &Msg{ - MsgType: MsgProofs, + MsgType: MsgProofsV1, + ReqID: resp.ReqID, + Obj: resp.Data, + } + + case ProofsV2Msg: + if pm.odr == nil { + return errResp(ErrUnexpectedResponse, "") + } + + p.Log().Trace("Received les/2 proofs response") + // A batch of merkle proofs arrived to one of our previous requests + var resp struct { + ReqID, BV uint64 + Data light.NodeList + } + if err := msg.Decode(&resp); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + p.fcServer.GotReply(resp.ReqID, resp.BV) + deliverMsg = &Msg{ + MsgType: MsgProofsV2, ReqID: resp.ReqID, Obj: resp.Data, } @@ -738,22 +835,25 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { proofs []ChtResp ) reqCnt := len(req.Reqs) - if reject(uint64(reqCnt), MaxHeaderProofsFetch) { + if reject(uint64(reqCnt), MaxHelperTrieProofsFetch) { return errResp(ErrRequestRejected, "") } + trieDb := ethdb.NewTable(pm.chainDb, light.ChtTablePrefix) for _, req := range req.Reqs { if bytes >= softResponseLimit { break } if header := pm.blockchain.GetHeaderByNumber(req.BlockNum); header != nil { - if root := getChtRoot(pm.chainDb, req.ChtNum); root != (common.Hash{}) { - if tr, _ := trie.New(root, pm.chainDb); tr != nil { + sectionHead := core.GetCanonicalHash(pm.chainDb, (req.ChtNum+1)*light.ChtV1Frequency-1) + if root := light.GetChtRoot(pm.chainDb, req.ChtNum, sectionHead); root != (common.Hash{}) { + if tr, _ := trie.New(root, trieDb); tr != nil { var encNumber [8]byte binary.BigEndian.PutUint64(encNumber[:], req.BlockNum) - proof := tr.Prove(encNumber[:]) + var proof light.NodeList + tr.Prove(encNumber[:], 0, &proof) proofs = append(proofs, ChtResp{Header: header, Proof: proof}) - bytes += len(proof) + estHeaderRlpSize + bytes += proof.DataSize() + estHeaderRlpSize } } } @@ -762,6 +862,73 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost) return p.SendHeaderProofs(req.ReqID, bv, proofs) + case GetHelperTrieProofsMsg: + p.Log().Trace("Received helper trie proof request") + // Decode the retrieval message + var req struct { + ReqID uint64 + Reqs []HelperTrieReq + } + if err := msg.Decode(&req); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + // Gather state data until the fetch or network limits is reached + var ( + auxBytes int + auxData [][]byte + ) + reqCnt := len(req.Reqs) + if reject(uint64(reqCnt), MaxHelperTrieProofsFetch) { + return errResp(ErrRequestRejected, "") + } + + var ( + lastIdx uint64 + lastType uint + root common.Hash + tr *trie.Trie + ) + + nodes := light.NewNodeSet() + + for _, req := range req.Reqs { + if nodes.DataSize()+auxBytes >= softResponseLimit { + break + } + if tr == nil || req.HelperTrieType != lastType || req.TrieIdx != lastIdx { + var prefix string + root, prefix = pm.getHelperTrie(req.HelperTrieType, req.TrieIdx) + if root != (common.Hash{}) { + if t, err := trie.New(root, ethdb.NewTable(pm.chainDb, prefix)); err == nil { + tr = t + } + } + lastType = req.HelperTrieType + lastIdx = req.TrieIdx + } + if req.AuxReq == auxRoot { + var data []byte + if root != (common.Hash{}) { + data = root[:] + } + auxData = append(auxData, data) + auxBytes += len(data) + } else { + if tr != nil { + tr.Prove(req.Key, req.FromLevel, nodes) + } + if req.AuxReq != 0 { + data := pm.getHelperTrieAuxData(req) + auxData = append(auxData, data) + auxBytes += len(data) + } + } + } + proofs := nodes.NodeList() + bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost) + pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost) + return p.SendHelperTrieProofs(req.ReqID, bv, HelperTrieResps{Proofs: proofs, AuxData: auxData}) + case HeaderProofsMsg: if pm.odr == nil { return errResp(ErrUnexpectedResponse, "") @@ -782,9 +949,30 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { Obj: resp.Data, } + case HelperTrieProofsMsg: + if pm.odr == nil { + return errResp(ErrUnexpectedResponse, "") + } + + p.Log().Trace("Received helper trie proof response") + var resp struct { + ReqID, BV uint64 + Data HelperTrieResps + } + if err := msg.Decode(&resp); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + + p.fcServer.GotReply(resp.ReqID, resp.BV) + deliverMsg = &Msg{ + MsgType: MsgHelperTrieProofs, + ReqID: resp.ReqID, + Obj: resp.Data, + } + case SendTxMsg: if pm.txpool == nil { - return errResp(ErrUnexpectedResponse, "") + return errResp(ErrRequestRejected, "") } // Transactions arrived, parse all of them and deliver to the pool var txs []*types.Transaction @@ -796,13 +984,82 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { return errResp(ErrRequestRejected, "") } - if err := pm.txpool.AddRemotes(txs); err != nil { - return errResp(ErrUnexpectedResponse, "msg: %v", err) + txHashes := make([]common.Hash, len(txs)) + for i, tx := range txs { + txHashes[i] = tx.Hash() } + pm.addOrGetTxStatus(txs, txHashes) _, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost) pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost) + case SendTxV2Msg: + if pm.txpool == nil { + return errResp(ErrRequestRejected, "") + } + // Transactions arrived, parse all of them and deliver to the pool + var req struct { + ReqID uint64 + Txs []*types.Transaction + } + if err := msg.Decode(&req); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + reqCnt := len(req.Txs) + if reject(uint64(reqCnt), MaxTxSend) { + return errResp(ErrRequestRejected, "") + } + + txHashes := make([]common.Hash, len(req.Txs)) + for i, tx := range req.Txs { + txHashes[i] = tx.Hash() + } + + res := pm.addOrGetTxStatus(req.Txs, txHashes) + + bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost) + pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost) + return p.SendTxStatus(req.ReqID, bv, res) + + case GetTxStatusMsg: + if pm.txpool == nil { + return errResp(ErrUnexpectedResponse, "") + } + // Transactions arrived, parse all of them and deliver to the pool + var req struct { + ReqID uint64 + TxHashes []common.Hash + } + if err := msg.Decode(&req); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + reqCnt := len(req.TxHashes) + if reject(uint64(reqCnt), MaxTxStatus) { + return errResp(ErrRequestRejected, "") + } + + res := pm.addOrGetTxStatus(nil, req.TxHashes) + + bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost) + pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost) + return p.SendTxStatus(req.ReqID, bv, res) + + case TxStatusMsg: + if pm.odr == nil { + return errResp(ErrUnexpectedResponse, "") + } + + p.Log().Trace("Received tx status response") + var resp struct { + ReqID, BV uint64 + Status []core.TxStatusData + } + if err := msg.Decode(&resp); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + + p.fcServer.GotReply(resp.ReqID, resp.BV) + default: p.Log().Trace("Received unknown message", "code", msg.Code) return errResp(ErrInvalidMsgCode, "%v", msg.Code) @@ -820,6 +1077,47 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { return nil } +// getHelperTrie returns the post-processed trie root for the given trie ID and section index +func (pm *ProtocolManager) getHelperTrie(id uint, idx uint64) (common.Hash, string) { + switch id { + case htCanonical: + sectionHead := core.GetCanonicalHash(pm.chainDb, (idx+1)*light.ChtFrequency-1) + return light.GetChtV2Root(pm.chainDb, idx, sectionHead), light.ChtTablePrefix + case htBloomBits: + sectionHead := core.GetCanonicalHash(pm.chainDb, (idx+1)*light.BloomTrieFrequency-1) + return light.GetBloomTrieRoot(pm.chainDb, idx, sectionHead), light.BloomTrieTablePrefix + } + return common.Hash{}, "" +} + +// getHelperTrieAuxData returns requested auxiliary data for the given HelperTrie request +func (pm *ProtocolManager) getHelperTrieAuxData(req HelperTrieReq) []byte { + if req.HelperTrieType == htCanonical && req.AuxReq == auxHeader { + if len(req.Key) != 8 { + return nil + } + blockNum := binary.BigEndian.Uint64(req.Key) + hash := core.GetCanonicalHash(pm.chainDb, blockNum) + return core.GetHeaderRLP(pm.chainDb, hash, blockNum) + } + return nil +} + +func (pm *ProtocolManager) addOrGetTxStatus(txs []*types.Transaction, txHashes []common.Hash) []core.TxStatusData { + status := pm.txpool.AddOrGetTxStatus(txs, txHashes) + for i, _ := range status { + blockHash, blockNum, txIndex := core.GetTxLookupEntry(pm.chainDb, txHashes[i]) + if blockHash != (common.Hash{}) { + enc, err := rlp.EncodeToBytes(core.TxLookupEntry{BlockHash: blockHash, BlockIndex: blockNum, Index: txIndex}) + if err != nil { + panic(err) + } + status[i] = core.TxStatusData{Status: core.TxStatusIncluded, Data: enc} + } + } + return status +} + // NodeInfo retrieves some protocol metadata about the running host node. func (self *ProtocolManager) NodeInfo() *eth.EthNodeInfo { return ð.EthNodeInfo{ diff --git a/les/handler_test.go b/les/handler_test.go index b1f1aa095..a094cdc84 100644 --- a/les/handler_test.go +++ b/les/handler_test.go @@ -17,7 +17,10 @@ package les import ( + "bytes" + "math/big" "math/rand" + "runtime" "testing" "github.com/ethereum/go-ethereum/common" @@ -26,7 +29,9 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" ) @@ -39,9 +44,29 @@ func expectResponse(r p2p.MsgReader, msgcode, reqID, bv uint64, data interface{} return p2p.ExpectMsg(r, msgcode, resp{reqID, bv, data}) } +func testCheckProof(t *testing.T, exp *light.NodeSet, got light.NodeList) { + if exp.KeyCount() > len(got) { + t.Errorf("proof has fewer nodes than expected") + return + } + if exp.KeyCount() < len(got) { + t.Errorf("proof has more nodes than expected") + return + } + for _, node := range got { + n, _ := exp.Get(crypto.Keccak256(node)) + if !bytes.Equal(n, node) { + t.Errorf("proof contents mismatch") + return + } + } +} + // Tests that block headers can be retrieved from a remote chain based on user queries. func TestGetBlockHeadersLes1(t *testing.T) { testGetBlockHeaders(t, 1) } +func TestGetBlockHeadersLes2(t *testing.T) { testGetBlockHeaders(t, 2) } + func testGetBlockHeaders(t *testing.T, protocol int) { db, _ := ethdb.NewMemDatabase() pm := newTestProtocolManagerMust(t, false, downloader.MaxHashFetch+15, nil, nil, nil, db) @@ -171,6 +196,8 @@ func testGetBlockHeaders(t *testing.T, protocol int) { // Tests that block contents can be retrieved from a remote chain based on their hashes. func TestGetBlockBodiesLes1(t *testing.T) { testGetBlockBodies(t, 1) } +func TestGetBlockBodiesLes2(t *testing.T) { testGetBlockBodies(t, 2) } + func testGetBlockBodies(t *testing.T, protocol int) { db, _ := ethdb.NewMemDatabase() pm := newTestProtocolManagerMust(t, false, downloader.MaxBlockFetch+15, nil, nil, nil, db) @@ -247,6 +274,8 @@ func testGetBlockBodies(t *testing.T, protocol int) { // Tests that the contract codes can be retrieved based on account addresses. func TestGetCodeLes1(t *testing.T) { testGetCode(t, 1) } +func TestGetCodeLes2(t *testing.T) { testGetCode(t, 2) } + func testGetCode(t *testing.T, protocol int) { // Assemble the test environment db, _ := ethdb.NewMemDatabase() @@ -280,6 +309,8 @@ func testGetCode(t *testing.T, protocol int) { // Tests that the transaction receipts can be retrieved based on hashes. func TestGetReceiptLes1(t *testing.T) { testGetReceipt(t, 1) } +func TestGetReceiptLes2(t *testing.T) { testGetReceipt(t, 2) } + func testGetReceipt(t *testing.T, protocol int) { // Assemble the test environment db, _ := ethdb.NewMemDatabase() @@ -307,6 +338,8 @@ func testGetReceipt(t *testing.T, protocol int) { // Tests that trie merkle proofs can be retrieved func TestGetProofsLes1(t *testing.T) { testGetProofs(t, 1) } +func TestGetProofsLes2(t *testing.T) { testGetProofs(t, 2) } + func testGetProofs(t *testing.T, protocol int) { // Assemble the test environment db, _ := ethdb.NewMemDatabase() @@ -315,8 +348,11 @@ func testGetProofs(t *testing.T, protocol int) { peer, _ := newTestPeer(t, "peer", protocol, pm, true) defer peer.close() - var proofreqs []ProofReq - var proofs [][]rlp.RawValue + var ( + proofreqs []ProofReq + proofsV1 [][]rlp.RawValue + ) + proofsV2 := light.NewNodeSet() accounts := []common.Address{testBankAddress, acc1Addr, acc2Addr, {}} for i := uint64(0); i <= bc.CurrentBlock().NumberU64(); i++ { @@ -331,14 +367,124 @@ func testGetProofs(t *testing.T, protocol int) { } proofreqs = append(proofreqs, req) - proof := trie.Prove(crypto.Keccak256(acc[:])) - proofs = append(proofs, proof) + switch protocol { + case 1: + var proof light.NodeList + trie.Prove(crypto.Keccak256(acc[:]), 0, &proof) + proofsV1 = append(proofsV1, proof) + case 2: + trie.Prove(crypto.Keccak256(acc[:]), 0, proofsV2) + } } } // Send the proof request and verify the response - cost := peer.GetRequestCost(GetProofsMsg, len(proofreqs)) - sendRequest(peer.app, GetProofsMsg, 42, cost, proofreqs) - if err := expectResponse(peer.app, ProofsMsg, 42, testBufLimit, proofs); err != nil { - t.Errorf("proofs mismatch: %v", err) + switch protocol { + case 1: + cost := peer.GetRequestCost(GetProofsV1Msg, len(proofreqs)) + sendRequest(peer.app, GetProofsV1Msg, 42, cost, proofreqs) + if err := expectResponse(peer.app, ProofsV1Msg, 42, testBufLimit, proofsV1); err != nil { + t.Errorf("proofs mismatch: %v", err) + } + case 2: + cost := peer.GetRequestCost(GetProofsV2Msg, len(proofreqs)) + sendRequest(peer.app, GetProofsV2Msg, 42, cost, proofreqs) + msg, err := peer.app.ReadMsg() + if err != nil { + t.Errorf("Message read error: %v", err) + } + var resp struct { + ReqID, BV uint64 + Data light.NodeList + } + if err := msg.Decode(&resp); err != nil { + t.Errorf("reply decode error: %v", err) + } + if msg.Code != ProofsV2Msg { + t.Errorf("Message code mismatch") + } + if resp.ReqID != 42 { + t.Errorf("ReqID mismatch") + } + if resp.BV != testBufLimit { + t.Errorf("BV mismatch") + } + testCheckProof(t, proofsV2, resp.Data) + } +} + +func TestTransactionStatusLes2(t *testing.T) { + db, _ := ethdb.NewMemDatabase() + pm := newTestProtocolManagerMust(t, false, 0, nil, nil, nil, db) + chain := pm.blockchain.(*core.BlockChain) + txpool := core.NewTxPool(core.DefaultTxPoolConfig, params.TestChainConfig, chain) + pm.txpool = txpool + peer, _ := newTestPeer(t, "peer", 2, pm, true) + defer peer.close() + + var reqID uint64 + + test := func(tx *types.Transaction, send bool, expStatus core.TxStatusData) { + reqID++ + if send { + cost := peer.GetRequestCost(SendTxV2Msg, 1) + sendRequest(peer.app, SendTxV2Msg, reqID, cost, types.Transactions{tx}) + } else { + cost := peer.GetRequestCost(GetTxStatusMsg, 1) + sendRequest(peer.app, GetTxStatusMsg, reqID, cost, []common.Hash{tx.Hash()}) + } + if err := expectResponse(peer.app, TxStatusMsg, reqID, testBufLimit, []core.TxStatusData{expStatus}); err != nil { + t.Errorf("transaction status mismatch") + } + } + + signer := types.HomesteadSigner{} + + // test error status by sending an underpriced transaction + tx0, _ := types.SignTx(types.NewTransaction(0, acc1Addr, big.NewInt(10000), bigTxGas, nil, nil), signer, testBankKey) + test(tx0, true, core.TxStatusData{Status: core.TxStatusError, Data: []byte("transaction underpriced")}) + + tx1, _ := types.SignTx(types.NewTransaction(0, acc1Addr, big.NewInt(10000), bigTxGas, big.NewInt(100000000000), nil), signer, testBankKey) + test(tx1, false, core.TxStatusData{Status: core.TxStatusUnknown}) // query before sending, should be unknown + test(tx1, true, core.TxStatusData{Status: core.TxStatusPending}) // send valid processable tx, should return pending + test(tx1, true, core.TxStatusData{Status: core.TxStatusPending}) // adding it again should not return an error + + tx2, _ := types.SignTx(types.NewTransaction(1, acc1Addr, big.NewInt(10000), bigTxGas, big.NewInt(100000000000), nil), signer, testBankKey) + tx3, _ := types.SignTx(types.NewTransaction(2, acc1Addr, big.NewInt(10000), bigTxGas, big.NewInt(100000000000), nil), signer, testBankKey) + // send transactions in the wrong order, tx3 should be queued + test(tx3, true, core.TxStatusData{Status: core.TxStatusQueued}) + test(tx2, true, core.TxStatusData{Status: core.TxStatusPending}) + // query again, now tx3 should be pending too + test(tx3, false, core.TxStatusData{Status: core.TxStatusPending}) + + // generate and add a block with tx1 and tx2 included + gchain, _ := core.GenerateChain(params.TestChainConfig, chain.GetBlockByNumber(0), db, 1, func(i int, block *core.BlockGen) { + block.AddTx(tx1) + block.AddTx(tx2) + }) + if _, err := chain.InsertChain(gchain); err != nil { + panic(err) + } + + // check if their status is included now + block1hash := core.GetCanonicalHash(db, 1) + tx1pos, _ := rlp.EncodeToBytes(core.TxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 0}) + tx2pos, _ := rlp.EncodeToBytes(core.TxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 1}) + test(tx1, false, core.TxStatusData{Status: core.TxStatusIncluded, Data: tx1pos}) + test(tx2, false, core.TxStatusData{Status: core.TxStatusIncluded, Data: tx2pos}) + + // create a reorg that rolls them back + gchain, _ = core.GenerateChain(params.TestChainConfig, chain.GetBlockByNumber(0), db, 2, func(i int, block *core.BlockGen) {}) + if _, err := chain.InsertChain(gchain); err != nil { + panic(err) + } + // wait until TxPool processes the reorg + for { + if pending, _ := txpool.Stats(); pending == 3 { + break + } + runtime.Gosched() } + // check if their status is pending again + test(tx1, false, core.TxStatusData{Status: core.TxStatusPending}) + test(tx2, false, core.TxStatusData{Status: core.TxStatusPending}) } diff --git a/les/helper_test.go b/les/helper_test.go index b33454e1d..a06f84cca 100644 --- a/les/helper_test.go +++ b/les/helper_test.go @@ -43,7 +43,7 @@ import ( var ( testBankKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") testBankAddress = crypto.PubkeyToAddress(testBankKey.PublicKey) - testBankFunds = big.NewInt(1000000) + testBankFunds = big.NewInt(1000000000000000000) acc1Key, _ = crypto.HexToECDSA("8a1f9a8f95be41cd7ccb6168179afb4504aefe388d1e14474d32c45c72ce7b7a") acc2Key, _ = crypto.HexToECDSA("49a7b37aa6f6645917e7b807e9d1c00d4fa71f18343b0d4122a4d2df64dd6fee") @@ -156,7 +156,13 @@ func newTestProtocolManager(lightSync bool, blocks int, generator func(int, *cor chain = blockchain } - pm, err := NewProtocolManager(gspec.Config, lightSync, NetworkId, evmux, engine, peers, chain, nil, db, odr, nil, make(chan struct{}), new(sync.WaitGroup)) + var protocolVersions []uint + if lightSync { + protocolVersions = ClientProtocolVersions + } else { + protocolVersions = ServerProtocolVersions + } + pm, err := NewProtocolManager(gspec.Config, lightSync, protocolVersions, NetworkId, evmux, engine, peers, chain, nil, db, odr, nil, make(chan struct{}), new(sync.WaitGroup)) if err != nil { return nil, err } diff --git a/les/odr.go b/les/odr.go index 3f7584b48..986630dbf 100644 --- a/les/odr.go +++ b/les/odr.go @@ -19,6 +19,7 @@ package les import ( "context" + "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/log" @@ -26,33 +27,56 @@ import ( // LesOdr implements light.OdrBackend type LesOdr struct { - db ethdb.Database - stop chan struct{} - retriever *retrieveManager + db ethdb.Database + chtIndexer, bloomTrieIndexer, bloomIndexer *core.ChainIndexer + retriever *retrieveManager + stop chan struct{} } -func NewLesOdr(db ethdb.Database, retriever *retrieveManager) *LesOdr { +func NewLesOdr(db ethdb.Database, chtIndexer, bloomTrieIndexer, bloomIndexer *core.ChainIndexer, retriever *retrieveManager) *LesOdr { return &LesOdr{ - db: db, - retriever: retriever, - stop: make(chan struct{}), + db: db, + chtIndexer: chtIndexer, + bloomTrieIndexer: bloomTrieIndexer, + bloomIndexer: bloomIndexer, + retriever: retriever, + stop: make(chan struct{}), } } +// Stop cancels all pending retrievals func (odr *LesOdr) Stop() { close(odr.stop) } +// Database returns the backing database func (odr *LesOdr) Database() ethdb.Database { return odr.db } +// ChtIndexer returns the CHT chain indexer +func (odr *LesOdr) ChtIndexer() *core.ChainIndexer { + return odr.chtIndexer +} + +// BloomTrieIndexer returns the bloom trie chain indexer +func (odr *LesOdr) BloomTrieIndexer() *core.ChainIndexer { + return odr.bloomTrieIndexer +} + +// BloomIndexer returns the bloombits chain indexer +func (odr *LesOdr) BloomIndexer() *core.ChainIndexer { + return odr.bloomIndexer +} + const ( MsgBlockBodies = iota MsgCode MsgReceipts - MsgProofs + MsgProofsV1 + MsgProofsV2 MsgHeaderProofs + MsgHelperTrieProofs ) // Msg encodes a LES message that delivers reply data for a request @@ -64,7 +88,7 @@ type Msg struct { // Retrieve tries to fetch an object from the LES network. // If the network retrieval was successful, it stores the object in local db. -func (self *LesOdr) Retrieve(ctx context.Context, req light.OdrRequest) (err error) { +func (odr *LesOdr) Retrieve(ctx context.Context, req light.OdrRequest) (err error) { lreq := LesRequest(req) reqID := genReqID() @@ -84,9 +108,9 @@ func (self *LesOdr) Retrieve(ctx context.Context, req light.OdrRequest) (err err }, } - if err = self.retriever.retrieve(ctx, reqID, rq, func(p distPeer, msg *Msg) error { return lreq.Validate(self.db, msg) }); err == nil { + if err = odr.retriever.retrieve(ctx, reqID, rq, func(p distPeer, msg *Msg) error { return lreq.Validate(odr.db, msg) }, odr.stop); err == nil { // retrieved from network, store in db - req.StoreResult(self.db) + req.StoreResult(odr.db) } else { log.Debug("Failed to retrieve data from network", "err", err) } diff --git a/les/odr_requests.go b/les/odr_requests.go index 1f853b341..937a4f1d9 100644 --- a/les/odr_requests.go +++ b/les/odr_requests.go @@ -36,13 +36,15 @@ import ( var ( errInvalidMessageType = errors.New("invalid message type") - errMultipleEntries = errors.New("multiple response entries") + errInvalidEntryCount = errors.New("invalid number of response entries") errHeaderUnavailable = errors.New("header unavailable") errTxHashMismatch = errors.New("transaction hash mismatch") errUncleHashMismatch = errors.New("uncle hash mismatch") errReceiptHashMismatch = errors.New("receipt hash mismatch") errDataHashMismatch = errors.New("data hash mismatch") errCHTHashMismatch = errors.New("cht hash mismatch") + errCHTNumberMismatch = errors.New("cht number mismatch") + errUselessNodes = errors.New("useless nodes in merkle proof nodeset") ) type LesOdrRequest interface { @@ -64,6 +66,8 @@ func LesRequest(req light.OdrRequest) LesOdrRequest { return (*CodeRequest)(r) case *light.ChtRequest: return (*ChtRequest)(r) + case *light.BloomRequest: + return (*BloomRequest)(r) default: return nil } @@ -101,7 +105,7 @@ func (r *BlockRequest) Validate(db ethdb.Database, msg *Msg) error { } bodies := msg.Obj.([]*types.Body) if len(bodies) != 1 { - return errMultipleEntries + return errInvalidEntryCount } body := bodies[0] @@ -157,7 +161,7 @@ func (r *ReceiptsRequest) Validate(db ethdb.Database, msg *Msg) error { } receipts := msg.Obj.([]types.Receipts) if len(receipts) != 1 { - return errMultipleEntries + return errInvalidEntryCount } receipt := receipts[0] @@ -186,7 +190,14 @@ type TrieRequest light.TrieRequest // GetCost returns the cost of the given ODR request according to the serving // peer's cost table (implementation of LesOdrRequest) func (r *TrieRequest) GetCost(peer *peer) uint64 { - return peer.GetRequestCost(GetProofsMsg, 1) + switch peer.version { + case lpv1: + return peer.GetRequestCost(GetProofsV1Msg, 1) + case lpv2: + return peer.GetRequestCost(GetProofsV2Msg, 1) + default: + panic(nil) + } } // CanSend tells if a certain peer is suitable for serving the given request @@ -197,12 +208,12 @@ func (r *TrieRequest) CanSend(peer *peer) bool { // Request sends an ODR request to the LES network (implementation of LesOdrRequest) func (r *TrieRequest) Request(reqID uint64, peer *peer) error { peer.Log().Debug("Requesting trie proof", "root", r.Id.Root, "key", r.Key) - req := &ProofReq{ + req := ProofReq{ BHash: r.Id.BlockHash, AccKey: r.Id.AccKey, Key: r.Key, } - return peer.RequestProofs(reqID, r.GetCost(peer), []*ProofReq{req}) + return peer.RequestProofs(reqID, r.GetCost(peer), []ProofReq{req}) } // Valid processes an ODR request reply message from the LES network @@ -211,20 +222,38 @@ func (r *TrieRequest) Request(reqID uint64, peer *peer) error { func (r *TrieRequest) Validate(db ethdb.Database, msg *Msg) error { log.Debug("Validating trie proof", "root", r.Id.Root, "key", r.Key) - // Ensure we have a correct message with a single proof - if msg.MsgType != MsgProofs { + switch msg.MsgType { + case MsgProofsV1: + proofs := msg.Obj.([]light.NodeList) + if len(proofs) != 1 { + return errInvalidEntryCount + } + nodeSet := proofs[0].NodeSet() + // Verify the proof and store if checks out + if _, err, _ := trie.VerifyProof(r.Id.Root, r.Key, nodeSet); err != nil { + return fmt.Errorf("merkle proof verification failed: %v", err) + } + r.Proof = nodeSet + return nil + + case MsgProofsV2: + proofs := msg.Obj.(light.NodeList) + // Verify the proof and store if checks out + nodeSet := proofs.NodeSet() + reads := &readTraceDB{db: nodeSet} + if _, err, _ := trie.VerifyProof(r.Id.Root, r.Key, reads); err != nil { + return fmt.Errorf("merkle proof verification failed: %v", err) + } + // check if all nodes have been read by VerifyProof + if len(reads.reads) != nodeSet.KeyCount() { + return errUselessNodes + } + r.Proof = nodeSet + return nil + + default: return errInvalidMessageType } - proofs := msg.Obj.([][]rlp.RawValue) - if len(proofs) != 1 { - return errMultipleEntries - } - // Verify the proof and store if checks out - if _, err := trie.VerifyProof(r.Id.Root, r.Key, proofs[0]); err != nil { - return fmt.Errorf("merkle proof verification failed: %v", err) - } - r.Proof = proofs[0] - return nil } type CodeReq struct { @@ -249,11 +278,11 @@ func (r *CodeRequest) CanSend(peer *peer) bool { // Request sends an ODR request to the LES network (implementation of LesOdrRequest) func (r *CodeRequest) Request(reqID uint64, peer *peer) error { peer.Log().Debug("Requesting code data", "hash", r.Hash) - req := &CodeReq{ + req := CodeReq{ BHash: r.Id.BlockHash, AccKey: r.Id.AccKey, } - return peer.RequestCode(reqID, r.GetCost(peer), []*CodeReq{req}) + return peer.RequestCode(reqID, r.GetCost(peer), []CodeReq{req}) } // Valid processes an ODR request reply message from the LES network @@ -268,7 +297,7 @@ func (r *CodeRequest) Validate(db ethdb.Database, msg *Msg) error { } reply := msg.Obj.([][]byte) if len(reply) != 1 { - return errMultipleEntries + return errInvalidEntryCount } data := reply[0] @@ -280,10 +309,36 @@ func (r *CodeRequest) Validate(db ethdb.Database, msg *Msg) error { return nil } +const ( + // helper trie type constants + htCanonical = iota // Canonical hash trie + htBloomBits // BloomBits trie + + // applicable for all helper trie requests + auxRoot = 1 + // applicable for htCanonical + auxHeader = 2 +) + +type HelperTrieReq struct { + HelperTrieType uint + TrieIdx uint64 + Key []byte + FromLevel, AuxReq uint +} + +type HelperTrieResps struct { // describes all responses, not just a single one + Proofs light.NodeList + AuxData [][]byte +} + +// legacy LES/1 type ChtReq struct { - ChtNum, BlockNum, FromLevel uint64 + ChtNum, BlockNum uint64 + FromLevel uint } +// legacy LES/1 type ChtResp struct { Header *types.Header Proof []rlp.RawValue @@ -295,7 +350,14 @@ type ChtRequest light.ChtRequest // GetCost returns the cost of the given ODR request according to the serving // peer's cost table (implementation of LesOdrRequest) func (r *ChtRequest) GetCost(peer *peer) uint64 { - return peer.GetRequestCost(GetHeaderProofsMsg, 1) + switch peer.version { + case lpv1: + return peer.GetRequestCost(GetHeaderProofsMsg, 1) + case lpv2: + return peer.GetRequestCost(GetHelperTrieProofsMsg, 1) + default: + panic(nil) + } } // CanSend tells if a certain peer is suitable for serving the given request @@ -303,17 +365,21 @@ func (r *ChtRequest) CanSend(peer *peer) bool { peer.lock.RLock() defer peer.lock.RUnlock() - return r.ChtNum <= (peer.headInfo.Number-light.ChtConfirmations)/light.ChtFrequency + return peer.headInfo.Number >= light.HelperTrieConfirmations && r.ChtNum <= (peer.headInfo.Number-light.HelperTrieConfirmations)/light.ChtFrequency } // Request sends an ODR request to the LES network (implementation of LesOdrRequest) func (r *ChtRequest) Request(reqID uint64, peer *peer) error { peer.Log().Debug("Requesting CHT", "cht", r.ChtNum, "block", r.BlockNum) - req := &ChtReq{ - ChtNum: r.ChtNum, - BlockNum: r.BlockNum, + var encNum [8]byte + binary.BigEndian.PutUint64(encNum[:], r.BlockNum) + req := HelperTrieReq{ + HelperTrieType: htCanonical, + TrieIdx: r.ChtNum, + Key: encNum[:], + AuxReq: auxHeader, } - return peer.RequestHeaderProofs(reqID, r.GetCost(peer), []*ChtReq{req}) + return peer.RequestHelperTrieProofs(reqID, r.GetCost(peer), []HelperTrieReq{req}) } // Valid processes an ODR request reply message from the LES network @@ -322,35 +388,179 @@ func (r *ChtRequest) Request(reqID uint64, peer *peer) error { func (r *ChtRequest) Validate(db ethdb.Database, msg *Msg) error { log.Debug("Validating CHT", "cht", r.ChtNum, "block", r.BlockNum) - // Ensure we have a correct message with a single proof element - if msg.MsgType != MsgHeaderProofs { + switch msg.MsgType { + case MsgHeaderProofs: // LES/1 backwards compatibility + proofs := msg.Obj.([]ChtResp) + if len(proofs) != 1 { + return errInvalidEntryCount + } + proof := proofs[0] + + // Verify the CHT + var encNumber [8]byte + binary.BigEndian.PutUint64(encNumber[:], r.BlockNum) + + value, err, _ := trie.VerifyProof(r.ChtRoot, encNumber[:], light.NodeList(proof.Proof).NodeSet()) + if err != nil { + return err + } + var node light.ChtNode + if err := rlp.DecodeBytes(value, &node); err != nil { + return err + } + if node.Hash != proof.Header.Hash() { + return errCHTHashMismatch + } + // Verifications passed, store and return + r.Header = proof.Header + r.Proof = light.NodeList(proof.Proof).NodeSet() + r.Td = node.Td + case MsgHelperTrieProofs: + resp := msg.Obj.(HelperTrieResps) + if len(resp.AuxData) != 1 { + return errInvalidEntryCount + } + nodeSet := resp.Proofs.NodeSet() + headerEnc := resp.AuxData[0] + if len(headerEnc) == 0 { + return errHeaderUnavailable + } + header := new(types.Header) + if err := rlp.DecodeBytes(headerEnc, header); err != nil { + return errHeaderUnavailable + } + + // Verify the CHT + var encNumber [8]byte + binary.BigEndian.PutUint64(encNumber[:], r.BlockNum) + + reads := &readTraceDB{db: nodeSet} + value, err, _ := trie.VerifyProof(r.ChtRoot, encNumber[:], reads) + if err != nil { + return fmt.Errorf("merkle proof verification failed: %v", err) + } + if len(reads.reads) != nodeSet.KeyCount() { + return errUselessNodes + } + + var node light.ChtNode + if err := rlp.DecodeBytes(value, &node); err != nil { + return err + } + if node.Hash != header.Hash() { + return errCHTHashMismatch + } + if r.BlockNum != header.Number.Uint64() { + return errCHTNumberMismatch + } + // Verifications passed, store and return + r.Header = header + r.Proof = nodeSet + r.Td = node.Td + default: return errInvalidMessageType } - proofs := msg.Obj.([]ChtResp) - if len(proofs) != 1 { - return errMultipleEntries - } - proof := proofs[0] + return nil +} - // Verify the CHT - var encNumber [8]byte - binary.BigEndian.PutUint64(encNumber[:], r.BlockNum) +type BloomReq struct { + BloomTrieNum, BitIdx, SectionIdx, FromLevel uint64 +} - value, err := trie.VerifyProof(r.ChtRoot, encNumber[:], proof.Proof) - if err != nil { - return err +// ODR request type for requesting headers by Canonical Hash Trie, see LesOdrRequest interface +type BloomRequest light.BloomRequest + +// GetCost returns the cost of the given ODR request according to the serving +// peer's cost table (implementation of LesOdrRequest) +func (r *BloomRequest) GetCost(peer *peer) uint64 { + return peer.GetRequestCost(GetHelperTrieProofsMsg, len(r.SectionIdxList)) +} + +// CanSend tells if a certain peer is suitable for serving the given request +func (r *BloomRequest) CanSend(peer *peer) bool { + peer.lock.RLock() + defer peer.lock.RUnlock() + + if peer.version < lpv2 { + return false } - var node light.ChtNode - if err := rlp.DecodeBytes(value, &node); err != nil { - return err + return peer.headInfo.Number >= light.HelperTrieConfirmations && r.BloomTrieNum <= (peer.headInfo.Number-light.HelperTrieConfirmations)/light.BloomTrieFrequency +} + +// Request sends an ODR request to the LES network (implementation of LesOdrRequest) +func (r *BloomRequest) Request(reqID uint64, peer *peer) error { + peer.Log().Debug("Requesting BloomBits", "bloomTrie", r.BloomTrieNum, "bitIdx", r.BitIdx, "sections", r.SectionIdxList) + reqs := make([]HelperTrieReq, len(r.SectionIdxList)) + + var encNumber [10]byte + binary.BigEndian.PutUint16(encNumber[0:2], uint16(r.BitIdx)) + + for i, sectionIdx := range r.SectionIdxList { + binary.BigEndian.PutUint64(encNumber[2:10], sectionIdx) + reqs[i] = HelperTrieReq{ + HelperTrieType: htBloomBits, + TrieIdx: r.BloomTrieNum, + Key: common.CopyBytes(encNumber[:]), + } } - if node.Hash != proof.Header.Hash() { - return errCHTHashMismatch + return peer.RequestHelperTrieProofs(reqID, r.GetCost(peer), reqs) +} + +// Valid processes an ODR request reply message from the LES network +// returns true and stores results in memory if the message was a valid reply +// to the request (implementation of LesOdrRequest) +func (r *BloomRequest) Validate(db ethdb.Database, msg *Msg) error { + log.Debug("Validating BloomBits", "bloomTrie", r.BloomTrieNum, "bitIdx", r.BitIdx, "sections", r.SectionIdxList) + + // Ensure we have a correct message with a single proof element + if msg.MsgType != MsgHelperTrieProofs { + return errInvalidMessageType + } + resps := msg.Obj.(HelperTrieResps) + proofs := resps.Proofs + nodeSet := proofs.NodeSet() + reads := &readTraceDB{db: nodeSet} + + r.BloomBits = make([][]byte, len(r.SectionIdxList)) + + // Verify the proofs + var encNumber [10]byte + binary.BigEndian.PutUint16(encNumber[0:2], uint16(r.BitIdx)) + + for i, idx := range r.SectionIdxList { + binary.BigEndian.PutUint64(encNumber[2:10], idx) + value, err, _ := trie.VerifyProof(r.BloomTrieRoot, encNumber[:], reads) + if err != nil { + return err + } + r.BloomBits[i] = value } - // Verifications passed, store and return - r.Header = proof.Header - r.Proof = proof.Proof - r.Td = node.Td + if len(reads.reads) != nodeSet.KeyCount() { + return errUselessNodes + } + r.Proofs = nodeSet return nil } + +// readTraceDB stores the keys of database reads. We use this to check that received node +// sets contain only the trie nodes necessary to make proofs pass. +type readTraceDB struct { + db trie.DatabaseReader + reads map[string]struct{} +} + +// Get returns a stored node +func (db *readTraceDB) Get(k []byte) ([]byte, error) { + if db.reads == nil { + db.reads = make(map[string]struct{}) + } + db.reads[string(k)] = struct{}{} + return db.db.Get(k) +} + +// Has returns true if the node set contains the given key +func (db *readTraceDB) Has(key []byte) (bool, error) { + _, err := db.Get(key) + return err == nil, nil +} diff --git a/les/odr_test.go b/les/odr_test.go index f56c4036d..865f5d83e 100644 --- a/les/odr_test.go +++ b/les/odr_test.go @@ -29,6 +29,7 @@ import ( "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" + "github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/params" @@ -39,6 +40,8 @@ type odrTestFn func(ctx context.Context, db ethdb.Database, config *params.Chain func TestOdrGetBlockLes1(t *testing.T) { testOdr(t, 1, 1, odrGetBlock) } +func TestOdrGetBlockLes2(t *testing.T) { testOdr(t, 2, 1, odrGetBlock) } + func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { var block *types.Block if bc != nil { @@ -55,6 +58,8 @@ func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainCon func TestOdrGetReceiptsLes1(t *testing.T) { testOdr(t, 1, 1, odrGetReceipts) } +func TestOdrGetReceiptsLes2(t *testing.T) { testOdr(t, 2, 1, odrGetReceipts) } + func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { var receipts types.Receipts if bc != nil { @@ -71,6 +76,8 @@ func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.Chain func TestOdrAccountsLes1(t *testing.T) { testOdr(t, 1, 1, odrAccounts) } +func TestOdrAccountsLes2(t *testing.T) { testOdr(t, 2, 1, odrAccounts) } + func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { dummyAddr := common.HexToAddress("1234567812345678123456781234567812345678") acc := []common.Address{testBankAddress, acc1Addr, acc2Addr, dummyAddr} @@ -100,6 +107,8 @@ func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainCon func TestOdrContractCallLes1(t *testing.T) { testOdr(t, 1, 2, odrContractCall) } +func TestOdrContractCallLes2(t *testing.T) { testOdr(t, 2, 2, odrContractCall) } + type callmsg struct { types.Message } @@ -154,7 +163,7 @@ func testOdr(t *testing.T, protocol int, expFail uint64, fn odrTestFn) { rm := newRetrieveManager(peers, dist, nil) db, _ := ethdb.NewMemDatabase() ldb, _ := ethdb.NewMemDatabase() - odr := NewLesOdr(ldb, rm) + odr := NewLesOdr(ldb, light.NewChtIndexer(db, true), light.NewBloomTrieIndexer(db, true), eth.NewBloomIndexer(db, light.BloomTrieFrequency), rm) pm := newTestProtocolManagerMust(t, false, 4, testChainGen, nil, nil, db) lpm := newTestProtocolManagerMust(t, true, 0, nil, peers, odr, ldb) _, err1, lpeer, err2 := newTestPeerPair("peer", protocol, pm, lpm) diff --git a/les/peer.go b/les/peer.go index 3ba2df3fe..104afb6dc 100644 --- a/les/peer.go +++ b/les/peer.go @@ -18,6 +18,8 @@ package les import ( + "crypto/ecdsa" + "encoding/binary" "errors" "fmt" "math/big" @@ -25,9 +27,11 @@ import ( "time" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/les/flowcontrol" + "github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/rlp" ) @@ -40,14 +44,23 @@ var ( const maxResponseErrors = 50 // number of invalid responses tolerated (makes the protocol less brittle but still avoids spam) +const ( + announceTypeNone = iota + announceTypeSimple + announceTypeSigned +) + type peer struct { *p2p.Peer + pubKey *ecdsa.PublicKey rw p2p.MsgReadWriter version int // Protocol version negotiated network uint64 // Network ID being on + announceType, requestAnnounceType uint64 + id string headInfo *announceData @@ -68,9 +81,11 @@ type peer struct { func newPeer(version int, network uint64, p *p2p.Peer, rw p2p.MsgReadWriter) *peer { id := p.ID() + pubKey, _ := id.Pubkey() return &peer{ Peer: p, + pubKey: pubKey, rw: rw, version: version, network: network, @@ -197,16 +212,31 @@ func (p *peer) SendReceiptsRLP(reqID, bv uint64, receipts []rlp.RawValue) error return sendResponse(p.rw, ReceiptsMsg, reqID, bv, receipts) } -// SendProofs sends a batch of merkle proofs, corresponding to the ones requested. +// SendProofs sends a batch of legacy LES/1 merkle proofs, corresponding to the ones requested. func (p *peer) SendProofs(reqID, bv uint64, proofs proofsData) error { - return sendResponse(p.rw, ProofsMsg, reqID, bv, proofs) + return sendResponse(p.rw, ProofsV1Msg, reqID, bv, proofs) } -// SendHeaderProofs sends a batch of header proofs, corresponding to the ones requested. +// SendProofsV2 sends a batch of merkle proofs, corresponding to the ones requested. +func (p *peer) SendProofsV2(reqID, bv uint64, proofs light.NodeList) error { + return sendResponse(p.rw, ProofsV2Msg, reqID, bv, proofs) +} + +// SendHeaderProofs sends a batch of legacy LES/1 header proofs, corresponding to the ones requested. func (p *peer) SendHeaderProofs(reqID, bv uint64, proofs []ChtResp) error { return sendResponse(p.rw, HeaderProofsMsg, reqID, bv, proofs) } +// SendHelperTrieProofs sends a batch of HelperTrie proofs, corresponding to the ones requested. +func (p *peer) SendHelperTrieProofs(reqID, bv uint64, resp HelperTrieResps) error { + return sendResponse(p.rw, HelperTrieProofsMsg, reqID, bv, resp) +} + +// SendTxStatus sends a batch of transaction status records, corresponding to the ones requested. +func (p *peer) SendTxStatus(reqID, bv uint64, status []core.TxStatusData) error { + return sendResponse(p.rw, TxStatusMsg, reqID, bv, status) +} + // RequestHeadersByHash fetches a batch of blocks' headers corresponding to the // specified header query, based on the hash of an origin block. func (p *peer) RequestHeadersByHash(reqID, cost uint64, origin common.Hash, amount int, skip int, reverse bool) error { @@ -230,7 +260,7 @@ func (p *peer) RequestBodies(reqID, cost uint64, hashes []common.Hash) error { // RequestCode fetches a batch of arbitrary data from a node's known state // data, corresponding to the specified hashes. -func (p *peer) RequestCode(reqID, cost uint64, reqs []*CodeReq) error { +func (p *peer) RequestCode(reqID, cost uint64, reqs []CodeReq) error { p.Log().Debug("Fetching batch of codes", "count", len(reqs)) return sendRequest(p.rw, GetCodeMsg, reqID, cost, reqs) } @@ -242,20 +272,58 @@ func (p *peer) RequestReceipts(reqID, cost uint64, hashes []common.Hash) error { } // RequestProofs fetches a batch of merkle proofs from a remote node. -func (p *peer) RequestProofs(reqID, cost uint64, reqs []*ProofReq) error { +func (p *peer) RequestProofs(reqID, cost uint64, reqs []ProofReq) error { p.Log().Debug("Fetching batch of proofs", "count", len(reqs)) - return sendRequest(p.rw, GetProofsMsg, reqID, cost, reqs) + switch p.version { + case lpv1: + return sendRequest(p.rw, GetProofsV1Msg, reqID, cost, reqs) + case lpv2: + return sendRequest(p.rw, GetProofsV2Msg, reqID, cost, reqs) + default: + panic(nil) + } + +} + +// RequestHelperTrieProofs fetches a batch of HelperTrie merkle proofs from a remote node. +func (p *peer) RequestHelperTrieProofs(reqID, cost uint64, reqs []HelperTrieReq) error { + p.Log().Debug("Fetching batch of HelperTrie proofs", "count", len(reqs)) + switch p.version { + case lpv1: + reqsV1 := make([]ChtReq, len(reqs)) + for i, req := range reqs { + if req.HelperTrieType != htCanonical || req.AuxReq != auxHeader || len(req.Key) != 8 { + return fmt.Errorf("Request invalid in LES/1 mode") + } + blockNum := binary.BigEndian.Uint64(req.Key) + // convert HelperTrie request to old CHT request + reqsV1[i] = ChtReq{ChtNum: (req.TrieIdx+1)*(light.ChtFrequency/light.ChtV1Frequency) - 1, BlockNum: blockNum, FromLevel: req.FromLevel} + } + return sendRequest(p.rw, GetHeaderProofsMsg, reqID, cost, reqsV1) + case lpv2: + return sendRequest(p.rw, GetHelperTrieProofsMsg, reqID, cost, reqs) + default: + panic(nil) + } } -// RequestHeaderProofs fetches a batch of header merkle proofs from a remote node. -func (p *peer) RequestHeaderProofs(reqID, cost uint64, reqs []*ChtReq) error { - p.Log().Debug("Fetching batch of header proofs", "count", len(reqs)) - return sendRequest(p.rw, GetHeaderProofsMsg, reqID, cost, reqs) +// RequestTxStatus fetches a batch of transaction status records from a remote node. +func (p *peer) RequestTxStatus(reqID, cost uint64, txHashes []common.Hash) error { + p.Log().Debug("Requesting transaction status", "count", len(txHashes)) + return sendRequest(p.rw, GetTxStatusMsg, reqID, cost, txHashes) } +// SendTxStatus sends a batch of transactions to be added to the remote transaction pool. func (p *peer) SendTxs(reqID, cost uint64, txs types.Transactions) error { p.Log().Debug("Fetching batch of transactions", "count", len(txs)) - return p2p.Send(p.rw, SendTxMsg, txs) + switch p.version { + case lpv1: + return p2p.Send(p.rw, SendTxMsg, txs) // old message format does not include reqID + case lpv2: + return sendRequest(p.rw, SendTxV2Msg, reqID, cost, txs) + default: + panic(nil) + } } type keyValueEntry struct { @@ -289,7 +357,7 @@ func (l keyValueList) decode() keyValueMap { func (m keyValueMap) get(key string, val interface{}) error { enc, ok := m[key] if !ok { - return errResp(ErrHandshakeMissingKey, "%s", key) + return errResp(ErrMissingKey, "%s", key) } if val == nil { return nil @@ -348,6 +416,9 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis list := server.fcCostStats.getCurrentList() send = send.add("flowControl/MRC", list) p.fcCosts = list.decode() + } else { + p.requestAnnounceType = announceTypeSimple // set to default until "very light" client mode is implemented + send = send.add("announceType", p.requestAnnounceType) } recvList, err := p.sendReceiveHandshake(send) if err != nil { @@ -392,6 +463,9 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis /*if recv.get("serveStateSince", nil) == nil { return errResp(ErrUselessPeer, "wanted client, got server") }*/ + if recv.get("announceType", &p.announceType) != nil { + p.announceType = announceTypeSimple + } p.fcClient = flowcontrol.NewClientNode(server.fcManager, server.defParams) } else { if recv.get("serveChainSince", nil) != nil { @@ -456,11 +530,15 @@ func newPeerSet() *peerSet { // notify adds a service to be notified about added or removed peers func (ps *peerSet) notify(n peerSetNotify) { ps.lock.Lock() - defer ps.lock.Unlock() - ps.notifyList = append(ps.notifyList, n) + peers := make([]*peer, 0, len(ps.peers)) for _, p := range ps.peers { - go n.registerPeer(p) + peers = append(peers, p) + } + ps.lock.Unlock() + + for _, p := range peers { + n.registerPeer(p) } } @@ -468,8 +546,6 @@ func (ps *peerSet) notify(n peerSetNotify) { // peer is already known. func (ps *peerSet) Register(p *peer) error { ps.lock.Lock() - defer ps.lock.Unlock() - if ps.closed { return errClosed } @@ -478,8 +554,12 @@ func (ps *peerSet) Register(p *peer) error { } ps.peers[p.id] = p p.sendQueue = newExecQueue(100) - for _, n := range ps.notifyList { - go n.registerPeer(p) + peers := make([]peerSetNotify, len(ps.notifyList)) + copy(peers, ps.notifyList) + ps.lock.Unlock() + + for _, n := range peers { + n.registerPeer(p) } return nil } @@ -488,19 +568,22 @@ func (ps *peerSet) Register(p *peer) error { // actions to/from that particular entity. It also initiates disconnection at the networking layer. func (ps *peerSet) Unregister(id string) error { ps.lock.Lock() - defer ps.lock.Unlock() - if p, ok := ps.peers[id]; !ok { + ps.lock.Unlock() return errNotRegistered } else { - for _, n := range ps.notifyList { - go n.unregisterPeer(p) + delete(ps.peers, id) + peers := make([]peerSetNotify, len(ps.notifyList)) + copy(peers, ps.notifyList) + ps.lock.Unlock() + + for _, n := range peers { + n.unregisterPeer(p) } p.sendQueue.quit() p.Peer.Disconnect(p2p.DiscUselessPeer) + return nil } - delete(ps.peers, id) - return nil } // AllPeerIDs returns a list of all registered peer IDs diff --git a/les/protocol.go b/les/protocol.go index 33d930ee0..146b02030 100644 --- a/les/protocol.go +++ b/les/protocol.go @@ -18,24 +18,34 @@ package les import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "errors" "fmt" "io" "math/big" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/secp256k1" "github.com/ethereum/go-ethereum/rlp" ) // Constants to match up protocol versions and messages const ( lpv1 = 1 + lpv2 = 2 ) -// Supported versions of the les protocol (first is primary). -var ProtocolVersions = []uint{lpv1} +// Supported versions of the les protocol (first is primary) +var ( + ClientProtocolVersions = []uint{lpv2, lpv1} + ServerProtocolVersions = []uint{lpv2, lpv1} +) // Number of implemented message corresponding to different protocol versions. -var ProtocolLengths = []uint64{15} +var ProtocolLengths = map[uint]uint64{lpv1: 15, lpv2: 22} const ( NetworkId = 1 @@ -53,13 +63,21 @@ const ( BlockBodiesMsg = 0x05 GetReceiptsMsg = 0x06 ReceiptsMsg = 0x07 - GetProofsMsg = 0x08 - ProofsMsg = 0x09 + GetProofsV1Msg = 0x08 + ProofsV1Msg = 0x09 GetCodeMsg = 0x0a CodeMsg = 0x0b SendTxMsg = 0x0c GetHeaderProofsMsg = 0x0d HeaderProofsMsg = 0x0e + // Protocol messages belonging to LPV2 + GetProofsV2Msg = 0x0f + ProofsV2Msg = 0x10 + GetHelperTrieProofsMsg = 0x11 + HelperTrieProofsMsg = 0x12 + SendTxV2Msg = 0x13 + GetTxStatusMsg = 0x14 + TxStatusMsg = 0x15 ) type errCode int @@ -79,7 +97,7 @@ const ( ErrUnexpectedResponse ErrInvalidResponse ErrTooManyTimeouts - ErrHandshakeMissingKey + ErrMissingKey ) func (e errCode) String() string { @@ -101,7 +119,13 @@ var errorToString = map[int]string{ ErrUnexpectedResponse: "Unexpected response", ErrInvalidResponse: "Invalid response", ErrTooManyTimeouts: "Too many request timeouts", - ErrHandshakeMissingKey: "Key missing from handshake message", + ErrMissingKey: "Key missing from list", +} + +type announceBlock struct { + Hash common.Hash // Hash of one particular block being announced + Number uint64 // Number of one particular block being announced + Td *big.Int // Total difficulty of one particular block being announced } // announceData is the network packet for the block announcements. @@ -113,6 +137,32 @@ type announceData struct { Update keyValueList } +// sign adds a signature to the block announcement by the given privKey +func (a *announceData) sign(privKey *ecdsa.PrivateKey) { + rlp, _ := rlp.EncodeToBytes(announceBlock{a.Hash, a.Number, a.Td}) + sig, _ := crypto.Sign(crypto.Keccak256(rlp), privKey) + a.Update = a.Update.add("sign", sig) +} + +// checkSignature verifies if the block announcement has a valid signature by the given pubKey +func (a *announceData) checkSignature(pubKey *ecdsa.PublicKey) error { + var sig []byte + if err := a.Update.decode().get("sign", &sig); err != nil { + return err + } + rlp, _ := rlp.EncodeToBytes(announceBlock{a.Hash, a.Number, a.Td}) + recPubkey, err := secp256k1.RecoverPubkey(crypto.Keccak256(rlp), sig) + if err != nil { + return err + } + pbytes := elliptic.Marshal(pubKey.Curve, pubKey.X, pubKey.Y) + if bytes.Equal(pbytes, recPubkey) { + return nil + } else { + return errors.New("Wrong signature") + } +} + type blockInfo struct { Hash common.Hash // Hash of one particular block being announced Number uint64 // Number of one particular block being announced diff --git a/les/request_test.go b/les/request_test.go index 6b594462d..c13625de8 100644 --- a/les/request_test.go +++ b/les/request_test.go @@ -24,6 +24,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/light" ) @@ -38,24 +39,32 @@ type accessTestFn func(db ethdb.Database, bhash common.Hash, number uint64) ligh func TestBlockAccessLes1(t *testing.T) { testAccess(t, 1, tfBlockAccess) } +func TestBlockAccessLes2(t *testing.T) { testAccess(t, 2, tfBlockAccess) } + func tfBlockAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { return &light.BlockRequest{Hash: bhash, Number: number} } func TestReceiptsAccessLes1(t *testing.T) { testAccess(t, 1, tfReceiptsAccess) } +func TestReceiptsAccessLes2(t *testing.T) { testAccess(t, 2, tfReceiptsAccess) } + func tfReceiptsAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { return &light.ReceiptsRequest{Hash: bhash, Number: number} } func TestTrieEntryAccessLes1(t *testing.T) { testAccess(t, 1, tfTrieEntryAccess) } +func TestTrieEntryAccessLes2(t *testing.T) { testAccess(t, 2, tfTrieEntryAccess) } + func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { return &light.TrieRequest{Id: light.StateTrieID(core.GetHeader(db, bhash, core.GetBlockNumber(db, bhash))), Key: testBankSecureTrieKey} } func TestCodeAccessLes1(t *testing.T) { testAccess(t, 1, tfCodeAccess) } +func TestCodeAccessLes2(t *testing.T) { testAccess(t, 2, tfCodeAccess) } + func tfCodeAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { header := core.GetHeader(db, bhash, core.GetBlockNumber(db, bhash)) if header.Number.Uint64() < testContractDeployed { @@ -73,7 +82,7 @@ func testAccess(t *testing.T, protocol int, fn accessTestFn) { rm := newRetrieveManager(peers, dist, nil) db, _ := ethdb.NewMemDatabase() ldb, _ := ethdb.NewMemDatabase() - odr := NewLesOdr(ldb, rm) + odr := NewLesOdr(ldb, light.NewChtIndexer(db, true), light.NewBloomTrieIndexer(db, true), eth.NewBloomIndexer(db, light.BloomTrieFrequency), rm) pm := newTestProtocolManagerMust(t, false, 4, testChainGen, nil, nil, db) lpm := newTestProtocolManagerMust(t, true, 0, nil, peers, odr, ldb) diff --git a/les/retrieve.go b/les/retrieve.go index b060e0b0d..dd15b56ac 100644 --- a/les/retrieve.go +++ b/les/retrieve.go @@ -22,6 +22,7 @@ import ( "context" "crypto/rand" "encoding/binary" + "fmt" "sync" "time" @@ -111,12 +112,14 @@ func newRetrieveManager(peers *peerSet, dist *requestDistributor, serverPool pee // that is delivered through the deliver function and successfully validated by the // validator callback. It returns when a valid answer is delivered or the context is // cancelled. -func (rm *retrieveManager) retrieve(ctx context.Context, reqID uint64, req *distReq, val validatorFunc) error { +func (rm *retrieveManager) retrieve(ctx context.Context, reqID uint64, req *distReq, val validatorFunc, shutdown chan struct{}) error { sentReq := rm.sendReq(reqID, req, val) select { case <-sentReq.stopCh: case <-ctx.Done(): sentReq.stop(ctx.Err()) + case <-shutdown: + sentReq.stop(fmt.Errorf("Client is shutting down")) } return sentReq.getError() } diff --git a/les/server.go b/les/server.go index 8b2730714..d8f93cd87 100644 --- a/les/server.go +++ b/les/server.go @@ -18,10 +18,11 @@ package les import ( + "crypto/ecdsa" "encoding/binary" + "fmt" "math" "sync" - "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" @@ -34,7 +35,6 @@ import ( "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/rlp" - "github.com/ethereum/go-ethereum/trie" ) type LesServer struct { @@ -42,23 +42,55 @@ type LesServer struct { fcManager *flowcontrol.ClientManager // nil if our node is client only fcCostStats *requestCostStats defParams *flowcontrol.ServerParams - lesTopic discv5.Topic + lesTopics []discv5.Topic + privateKey *ecdsa.PrivateKey quitSync chan struct{} + + chtIndexer, bloomTrieIndexer *core.ChainIndexer } func NewLesServer(eth *eth.Ethereum, config *eth.Config) (*LesServer, error) { quitSync := make(chan struct{}) - pm, err := NewProtocolManager(eth.BlockChain().Config(), false, config.NetworkId, eth.EventMux(), eth.Engine(), newPeerSet(), eth.BlockChain(), eth.TxPool(), eth.ChainDb(), nil, nil, quitSync, new(sync.WaitGroup)) + pm, err := NewProtocolManager(eth.BlockChain().Config(), false, ServerProtocolVersions, config.NetworkId, eth.EventMux(), eth.Engine(), newPeerSet(), eth.BlockChain(), eth.TxPool(), eth.ChainDb(), nil, nil, quitSync, new(sync.WaitGroup)) if err != nil { return nil, err } - pm.blockLoop() + + lesTopics := make([]discv5.Topic, len(ServerProtocolVersions)) + for i, pv := range ServerProtocolVersions { + lesTopics[i] = lesTopic(eth.BlockChain().Genesis().Hash(), pv) + } srv := &LesServer{ - protocolManager: pm, - quitSync: quitSync, - lesTopic: lesTopic(eth.BlockChain().Genesis().Hash()), + protocolManager: pm, + quitSync: quitSync, + lesTopics: lesTopics, + chtIndexer: light.NewChtIndexer(eth.ChainDb(), false), + bloomTrieIndexer: light.NewBloomTrieIndexer(eth.ChainDb(), false), + } + logger := log.New() + + chtV1SectionCount, _, _ := srv.chtIndexer.Sections() // indexer still uses LES/1 4k section size for backwards server compatibility + chtV2SectionCount := chtV1SectionCount / (light.ChtFrequency / light.ChtV1Frequency) + if chtV2SectionCount != 0 { + // convert to LES/2 section + chtLastSection := chtV2SectionCount - 1 + // convert last LES/2 section index back to LES/1 index for chtIndexer.SectionHead + chtLastSectionV1 := (chtLastSection+1)*(light.ChtFrequency/light.ChtV1Frequency) - 1 + chtSectionHead := srv.chtIndexer.SectionHead(chtLastSectionV1) + chtRoot := light.GetChtV2Root(pm.chainDb, chtLastSection, chtSectionHead) + logger.Info("CHT", "section", chtLastSection, "sectionHead", fmt.Sprintf("%064x", chtSectionHead), "root", fmt.Sprintf("%064x", chtRoot)) + } + + bloomTrieSectionCount, _, _ := srv.bloomTrieIndexer.Sections() + if bloomTrieSectionCount != 0 { + bloomTrieLastSection := bloomTrieSectionCount - 1 + bloomTrieSectionHead := srv.bloomTrieIndexer.SectionHead(bloomTrieLastSection) + bloomTrieRoot := light.GetBloomTrieRoot(pm.chainDb, bloomTrieLastSection, bloomTrieSectionHead) + logger.Info("BloomTrie", "section", bloomTrieLastSection, "sectionHead", fmt.Sprintf("%064x", bloomTrieSectionHead), "root", fmt.Sprintf("%064x", bloomTrieRoot)) } + + srv.chtIndexer.Start(eth.BlockChain()) pm.server = srv srv.defParams = &flowcontrol.ServerParams{ @@ -77,17 +109,28 @@ func (s *LesServer) Protocols() []p2p.Protocol { // Start starts the LES server func (s *LesServer) Start(srvr *p2p.Server) { s.protocolManager.Start() - go func() { - logger := log.New("topic", s.lesTopic) - logger.Info("Starting topic registration") - defer logger.Info("Terminated topic registration") + for _, topic := range s.lesTopics { + topic := topic + go func() { + logger := log.New("topic", topic) + logger.Info("Starting topic registration") + defer logger.Info("Terminated topic registration") + + srvr.DiscV5.RegisterTopic(topic, s.quitSync) + }() + } + s.privateKey = srvr.PrivateKey + s.protocolManager.blockLoop() +} - srvr.DiscV5.RegisterTopic(s.lesTopic, s.quitSync) - }() +func (s *LesServer) SetBloomBitsIndexer(bloomIndexer *core.ChainIndexer) { + bloomIndexer.AddChildIndexer(s.bloomTrieIndexer) } // Stop stops the LES service func (s *LesServer) Stop() { + s.chtIndexer.Close() + // bloom trie indexer is closed by parent bloombits indexer s.fcCostStats.store() s.fcManager.Stop() go func() { @@ -273,10 +316,7 @@ func (pm *ProtocolManager) blockLoop() { pm.wg.Add(1) headCh := make(chan core.ChainHeadEvent, 10) headSub := pm.blockchain.SubscribeChainHeadEvent(headCh) - newCht := make(chan struct{}, 10) - newCht <- struct{}{} go func() { - var mu sync.Mutex var lastHead *types.Header lastBroadcastTd := common.Big0 for { @@ -299,26 +339,37 @@ func (pm *ProtocolManager) blockLoop() { log.Debug("Announcing block to peers", "number", number, "hash", hash, "td", td, "reorg", reorg) announce := announceData{Hash: hash, Number: number, Td: td, ReorgDepth: reorg} + var ( + signed bool + signedAnnounce announceData + ) + for _, p := range peers { - select { - case p.announceChn <- announce: - default: - pm.removePeer(p.id) + switch p.announceType { + + case announceTypeSimple: + select { + case p.announceChn <- announce: + default: + pm.removePeer(p.id) + } + + case announceTypeSigned: + if !signed { + signedAnnounce = announce + signedAnnounce.sign(pm.server.privateKey) + signed = true + } + + select { + case p.announceChn <- signedAnnounce: + default: + pm.removePeer(p.id) + } } } } } - newCht <- struct{}{} - case <-newCht: - go func() { - mu.Lock() - more := makeCht(pm.chainDb) - mu.Unlock() - if more { - time.Sleep(time.Millisecond * 10) - newCht <- struct{}{} - } - }() case <-pm.quitSync: headSub.Unsubscribe() pm.wg.Done() @@ -327,86 +378,3 @@ func (pm *ProtocolManager) blockLoop() { } }() } - -var ( - lastChtKey = []byte("LastChtNumber") // chtNum (uint64 big endian) - chtPrefix = []byte("cht") // chtPrefix + chtNum (uint64 big endian) -> trie root hash -) - -func getChtRoot(db ethdb.Database, num uint64) common.Hash { - var encNumber [8]byte - binary.BigEndian.PutUint64(encNumber[:], num) - data, _ := db.Get(append(chtPrefix, encNumber[:]...)) - return common.BytesToHash(data) -} - -func storeChtRoot(db ethdb.Database, num uint64, root common.Hash) { - var encNumber [8]byte - binary.BigEndian.PutUint64(encNumber[:], num) - db.Put(append(chtPrefix, encNumber[:]...), root[:]) -} - -func makeCht(db ethdb.Database) bool { - headHash := core.GetHeadBlockHash(db) - headNum := core.GetBlockNumber(db, headHash) - - var newChtNum uint64 - if headNum > light.ChtConfirmations { - newChtNum = (headNum - light.ChtConfirmations) / light.ChtFrequency - } - - var lastChtNum uint64 - data, _ := db.Get(lastChtKey) - if len(data) == 8 { - lastChtNum = binary.BigEndian.Uint64(data[:]) - } - if newChtNum <= lastChtNum { - return false - } - - var t *trie.Trie - if lastChtNum > 0 { - var err error - t, err = trie.New(getChtRoot(db, lastChtNum), db) - if err != nil { - lastChtNum = 0 - } - } - if lastChtNum == 0 { - t, _ = trie.New(common.Hash{}, db) - } - - for num := lastChtNum * light.ChtFrequency; num < (lastChtNum+1)*light.ChtFrequency; num++ { - hash := core.GetCanonicalHash(db, num) - if hash == (common.Hash{}) { - panic("Canonical hash not found") - } - td := core.GetTd(db, hash, num) - if td == nil { - panic("TD not found") - } - var encNumber [8]byte - binary.BigEndian.PutUint64(encNumber[:], num) - var node light.ChtNode - node.Hash = hash - node.Td = td - data, _ := rlp.EncodeToBytes(node) - t.Update(encNumber[:], data) - } - - root, err := t.Commit() - if err != nil { - lastChtNum = 0 - } else { - lastChtNum++ - - log.Trace("Generated CHT", "number", lastChtNum, "root", root.Hex()) - - storeChtRoot(db, lastChtNum, root) - var data [8]byte - binary.BigEndian.PutUint64(data[:], lastChtNum) - db.Put(lastChtKey, data[:]) - } - - return newChtNum > lastChtNum -} diff --git a/light/lightchain.go b/light/lightchain.go index 4c877a771..30baeaccb 100644 --- a/light/lightchain.go +++ b/light/lightchain.go @@ -95,15 +95,8 @@ func NewLightChain(odr OdrBackend, config *params.ChainConfig, engine consensus. if bc.genesisBlock == nil { return nil, core.ErrNoGenesis } - if bc.genesisBlock.Hash() == params.MainnetGenesisHash { - // add trusted CHT - WriteTrustedCht(bc.chainDb, TrustedCht{Number: 1040, Root: common.HexToHash("bb4fb4076cbe6923c8a8ce8f158452bbe19564959313466989fda095a60884ca")}) - log.Info("Added trusted CHT for mainnet") - } - if bc.genesisBlock.Hash() == params.TestnetGenesisHash { - // add trusted CHT - WriteTrustedCht(bc.chainDb, TrustedCht{Number: 400, Root: common.HexToHash("2a4befa19e4675d939c3dc22dca8c6ae9fcd642be1f04b06bd6e4203cc304660")}) - log.Info("Added trusted CHT for ropsten testnet") + if cp, ok := trustedCheckpoints[bc.genesisBlock.Hash()]; ok { + bc.addTrustedCheckpoint(cp) } if err := bc.loadLastState(); err != nil { @@ -120,6 +113,22 @@ func NewLightChain(odr OdrBackend, config *params.ChainConfig, engine consensus. return bc, nil } +// addTrustedCheckpoint adds a trusted checkpoint to the blockchain +func (self *LightChain) addTrustedCheckpoint(cp trustedCheckpoint) { + if self.odr.ChtIndexer() != nil { + StoreChtRoot(self.chainDb, cp.sectionIdx, cp.sectionHead, cp.chtRoot) + self.odr.ChtIndexer().AddKnownSectionHead(cp.sectionIdx, cp.sectionHead) + } + if self.odr.BloomTrieIndexer() != nil { + StoreBloomTrieRoot(self.chainDb, cp.sectionIdx, cp.sectionHead, cp.bloomTrieRoot) + self.odr.BloomTrieIndexer().AddKnownSectionHead(cp.sectionIdx, cp.sectionHead) + } + if self.odr.BloomIndexer() != nil { + self.odr.BloomIndexer().AddKnownSectionHead(cp.sectionIdx, cp.sectionHead) + } + log.Info("Added trusted checkpoint", "chain name", cp.name) +} + func (self *LightChain) getProcInterrupt() bool { return atomic.LoadInt32(&self.procInterrupt) == 1 } @@ -449,10 +458,13 @@ func (self *LightChain) GetHeaderByNumberOdr(ctx context.Context, number uint64) } func (self *LightChain) SyncCht(ctx context.Context) bool { + if self.odr.ChtIndexer() == nil { + return false + } headNum := self.CurrentHeader().Number.Uint64() - cht := GetTrustedCht(self.chainDb) - if headNum+1 < cht.Number*ChtFrequency { - num := cht.Number*ChtFrequency - 1 + chtCount, _, _ := self.odr.ChtIndexer().Sections() + if headNum+1 < chtCount*ChtFrequency { + num := chtCount*ChtFrequency - 1 header, err := GetHeaderByNumber(ctx, self.odr, num) if header != nil && err == nil { self.mu.Lock() diff --git a/light/nodeset.go b/light/nodeset.go new file mode 100644 index 000000000..c530a4fbe --- /dev/null +++ b/light/nodeset.go @@ -0,0 +1,141 @@ +// Copyright 2014 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package light + +import ( + "errors" + "sync" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie" +) + +// NodeSet stores a set of trie nodes. It implements trie.Database and can also +// act as a cache for another trie.Database. +type NodeSet struct { + db map[string][]byte + dataSize int + lock sync.RWMutex +} + +// NewNodeSet creates an empty node set +func NewNodeSet() *NodeSet { + return &NodeSet{ + db: make(map[string][]byte), + } +} + +// Put stores a new node in the set +func (db *NodeSet) Put(key []byte, value []byte) error { + db.lock.Lock() + defer db.lock.Unlock() + + if _, ok := db.db[string(key)]; !ok { + db.db[string(key)] = common.CopyBytes(value) + db.dataSize += len(value) + } + return nil +} + +// Get returns a stored node +func (db *NodeSet) Get(key []byte) ([]byte, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + if entry, ok := db.db[string(key)]; ok { + return entry, nil + } + return nil, errors.New("not found") +} + +// Has returns true if the node set contains the given key +func (db *NodeSet) Has(key []byte) (bool, error) { + _, err := db.Get(key) + return err == nil, nil +} + +// KeyCount returns the number of nodes in the set +func (db *NodeSet) KeyCount() int { + db.lock.RLock() + defer db.lock.RUnlock() + + return len(db.db) +} + +// DataSize returns the aggregated data size of nodes in the set +func (db *NodeSet) DataSize() int { + db.lock.RLock() + defer db.lock.RUnlock() + + return db.dataSize +} + +// NodeList converts the node set to a NodeList +func (db *NodeSet) NodeList() NodeList { + db.lock.RLock() + defer db.lock.RUnlock() + + var values NodeList + for _, value := range db.db { + values = append(values, value) + } + return values +} + +// Store writes the contents of the set to the given database +func (db *NodeSet) Store(target trie.Database) { + db.lock.RLock() + defer db.lock.RUnlock() + + for key, value := range db.db { + target.Put([]byte(key), value) + } +} + +// NodeList stores an ordered list of trie nodes. It implements trie.DatabaseWriter. +type NodeList []rlp.RawValue + +// Store writes the contents of the list to the given database +func (n NodeList) Store(db trie.Database) { + for _, node := range n { + db.Put(crypto.Keccak256(node), node) + } +} + +// NodeSet converts the node list to a NodeSet +func (n NodeList) NodeSet() *NodeSet { + db := NewNodeSet() + n.Store(db) + return db +} + +// Put stores a new node at the end of the list +func (n *NodeList) Put(key []byte, value []byte) error { + *n = append(*n, value) + return nil +} + +// DataSize returns the aggregated data size of nodes in the list +func (n NodeList) DataSize() int { + var size int + for _, node := range n { + size += len(node) + } + return size +} diff --git a/light/odr.go b/light/odr.go index d19a488f6..e2c3d9c5a 100644 --- a/light/odr.go +++ b/light/odr.go @@ -25,9 +25,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" - "github.com/ethereum/go-ethereum/rlp" ) // NoOdr is the default context passed to an ODR capable function when the ODR @@ -37,6 +35,9 @@ var NoOdr = context.Background() // OdrBackend is an interface to a backend service that handles ODR retrievals type type OdrBackend interface { Database() ethdb.Database + ChtIndexer() *core.ChainIndexer + BloomTrieIndexer() *core.ChainIndexer + BloomIndexer() *core.ChainIndexer Retrieve(ctx context.Context, req OdrRequest) error } @@ -80,23 +81,12 @@ type TrieRequest struct { OdrRequest Id *TrieID Key []byte - Proof []rlp.RawValue + Proof *NodeSet } // StoreResult stores the retrieved data in local database func (req *TrieRequest) StoreResult(db ethdb.Database) { - storeProof(db, req.Proof) -} - -// storeProof stores the new trie nodes obtained from a merkle proof in the database -func storeProof(db ethdb.Database, proof []rlp.RawValue) { - for _, buf := range proof { - hash := crypto.Keccak256(buf) - val, _ := db.Get(hash) - if val == nil { - db.Put(hash, buf) - } - } + req.Proof.Store(db) } // CodeRequest is the ODR request type for retrieving contract code @@ -138,14 +128,14 @@ func (req *ReceiptsRequest) StoreResult(db ethdb.Database) { core.WriteBlockReceipts(db, req.Hash, req.Number, req.Receipts) } -// TrieRequest is the ODR request type for state/storage trie entries +// ChtRequest is the ODR request type for state/storage trie entries type ChtRequest struct { OdrRequest ChtNum, BlockNum uint64 ChtRoot common.Hash Header *types.Header Td *big.Int - Proof []rlp.RawValue + Proof *NodeSet } // StoreResult stores the retrieved data in local database @@ -155,5 +145,27 @@ func (req *ChtRequest) StoreResult(db ethdb.Database) { hash, num := req.Header.Hash(), req.Header.Number.Uint64() core.WriteTd(db, hash, num, req.Td) core.WriteCanonicalHash(db, hash, num) - //storeProof(db, req.Proof) +} + +// BloomRequest is the ODR request type for retrieving bloom filters from a CHT structure +type BloomRequest struct { + OdrRequest + BloomTrieNum uint64 + BitIdx uint + SectionIdxList []uint64 + BloomTrieRoot common.Hash + BloomBits [][]byte + Proofs *NodeSet +} + +// StoreResult stores the retrieved data in local database +func (req *BloomRequest) StoreResult(db ethdb.Database) { + for i, sectionIdx := range req.SectionIdxList { + sectionHead := core.GetCanonicalHash(db, (sectionIdx+1)*BloomTrieFrequency-1) + // if we don't have the canonical hash stored for this section head number, we'll still store it under + // a key with a zero sectionHead. GetBloomBits will look there too if we still don't have the canonical + // hash. In the unlikely case we've retrieved the section head hash since then, we'll just retrieve the + // bit vector again from the network. + core.WriteBloomBits(db, req.BitIdx, sectionIdx, sectionHead, req.BloomBits[i]) + } } diff --git a/light/odr_test.go b/light/odr_test.go index c0c5438fd..e6afb1a48 100644 --- a/light/odr_test.go +++ b/light/odr_test.go @@ -77,7 +77,9 @@ func (odr *testOdr) Retrieve(ctx context.Context, req OdrRequest) error { req.Receipts = core.GetBlockReceipts(odr.sdb, req.Hash, core.GetBlockNumber(odr.sdb, req.Hash)) case *TrieRequest: t, _ := trie.New(req.Id.Root, odr.sdb) - req.Proof = t.Prove(req.Key) + nodes := NewNodeSet() + t.Prove(req.Key, 0, nodes) + req.Proof = nodes case *CodeRequest: req.Data, _ = odr.sdb.Get(req.Hash[:]) } diff --git a/light/odr_util.go b/light/odr_util.go index fcdfdb82c..a0eb6303d 100644 --- a/light/odr_util.go +++ b/light/odr_util.go @@ -19,56 +19,16 @@ package light import ( "bytes" "context" - "errors" - "math/big" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/rlp" ) var sha3_nil = crypto.Keccak256Hash(nil) -var ( - ErrNoTrustedCht = errors.New("No trusted canonical hash trie") - ErrNoHeader = errors.New("Header not found") - - ChtFrequency = uint64(4096) - ChtConfirmations = uint64(2048) - trustedChtKey = []byte("TrustedCHT") -) - -type ChtNode struct { - Hash common.Hash - Td *big.Int -} - -type TrustedCht struct { - Number uint64 - Root common.Hash -} - -func GetTrustedCht(db ethdb.Database) TrustedCht { - data, _ := db.Get(trustedChtKey) - var res TrustedCht - if err := rlp.DecodeBytes(data, &res); err != nil { - return TrustedCht{0, common.Hash{}} - } - return res -} - -func WriteTrustedCht(db ethdb.Database, cht TrustedCht) { - data, _ := rlp.EncodeToBytes(cht) - db.Put(trustedChtKey, data) -} - -func DeleteTrustedCht(db ethdb.Database) { - db.Delete(trustedChtKey) -} - func GetHeaderByNumber(ctx context.Context, odr OdrBackend, number uint64) (*types.Header, error) { db := odr.Database() hash := core.GetCanonicalHash(db, number) @@ -81,12 +41,29 @@ func GetHeaderByNumber(ctx context.Context, odr OdrBackend, number uint64) (*typ return header, nil } - cht := GetTrustedCht(db) - if number >= cht.Number*ChtFrequency { + var ( + chtCount, sectionHeadNum uint64 + sectionHead common.Hash + ) + if odr.ChtIndexer() != nil { + chtCount, sectionHeadNum, sectionHead = odr.ChtIndexer().Sections() + canonicalHash := core.GetCanonicalHash(db, sectionHeadNum) + // if the CHT was injected as a trusted checkpoint, we have no canonical hash yet so we accept zero hash too + for chtCount > 0 && canonicalHash != sectionHead && canonicalHash != (common.Hash{}) { + chtCount-- + if chtCount > 0 { + sectionHeadNum = chtCount*ChtFrequency - 1 + sectionHead = odr.ChtIndexer().SectionHead(chtCount - 1) + canonicalHash = core.GetCanonicalHash(db, sectionHeadNum) + } + } + } + + if number >= chtCount*ChtFrequency { return nil, ErrNoTrustedCht } - r := &ChtRequest{ChtRoot: cht.Root, ChtNum: cht.Number, BlockNum: number} + r := &ChtRequest{ChtRoot: GetChtRoot(db, chtCount-1, sectionHead), ChtNum: chtCount - 1, BlockNum: number} if err := odr.Retrieve(ctx, r); err != nil { return nil, err } else { @@ -162,3 +139,61 @@ func GetBlockReceipts(ctx context.Context, odr OdrBackend, hash common.Hash, num } return r.Receipts, nil } + +// GetBloomBits retrieves a batch of compressed bloomBits vectors belonging to the given bit index and section indexes +func GetBloomBits(ctx context.Context, odr OdrBackend, bitIdx uint, sectionIdxList []uint64) ([][]byte, error) { + db := odr.Database() + result := make([][]byte, len(sectionIdxList)) + var ( + reqList []uint64 + reqIdx []int + ) + + var ( + bloomTrieCount, sectionHeadNum uint64 + sectionHead common.Hash + ) + if odr.BloomTrieIndexer() != nil { + bloomTrieCount, sectionHeadNum, sectionHead = odr.BloomTrieIndexer().Sections() + canonicalHash := core.GetCanonicalHash(db, sectionHeadNum) + // if the BloomTrie was injected as a trusted checkpoint, we have no canonical hash yet so we accept zero hash too + for bloomTrieCount > 0 && canonicalHash != sectionHead && canonicalHash != (common.Hash{}) { + bloomTrieCount-- + if bloomTrieCount > 0 { + sectionHeadNum = bloomTrieCount*BloomTrieFrequency - 1 + sectionHead = odr.BloomTrieIndexer().SectionHead(bloomTrieCount - 1) + canonicalHash = core.GetCanonicalHash(db, sectionHeadNum) + } + } + } + + for i, sectionIdx := range sectionIdxList { + sectionHead := core.GetCanonicalHash(db, (sectionIdx+1)*BloomTrieFrequency-1) + // if we don't have the canonical hash stored for this section head number, we'll still look for + // an entry with a zero sectionHead (we store it with zero section head too if we don't know it + // at the time of the retrieval) + bloomBits, err := core.GetBloomBits(db, bitIdx, sectionIdx, sectionHead) + if err == nil { + result[i] = bloomBits + } else { + if sectionIdx >= bloomTrieCount { + return nil, ErrNoTrustedBloomTrie + } + reqList = append(reqList, sectionIdx) + reqIdx = append(reqIdx, i) + } + } + if reqList == nil { + return result, nil + } + + r := &BloomRequest{BloomTrieRoot: GetBloomTrieRoot(db, bloomTrieCount-1, sectionHead), BloomTrieNum: bloomTrieCount - 1, BitIdx: bitIdx, SectionIdxList: reqList} + if err := odr.Retrieve(ctx, r); err != nil { + return nil, err + } else { + for i, idx := range reqIdx { + result[idx] = r.BloomBits[i] + } + return result, nil + } +} diff --git a/light/postprocess.go b/light/postprocess.go new file mode 100644 index 000000000..e7e513880 --- /dev/null +++ b/light/postprocess.go @@ -0,0 +1,295 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package light + +import ( + "encoding/binary" + "errors" + "fmt" + "math/big" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/bitutil" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie" +) + +const ( + ChtFrequency = 32768 + ChtV1Frequency = 4096 // as long as we want to retain LES/1 compatibility, servers generate CHTs with the old, higher frequency + HelperTrieConfirmations = 2048 // number of confirmations before a server is expected to have the given HelperTrie available + HelperTrieProcessConfirmations = 256 // number of confirmations before a HelperTrie is generated +) + +// trustedCheckpoint represents a set of post-processed trie roots (CHT and BloomTrie) associated with +// the appropriate section index and head hash. It is used to start light syncing from this checkpoint +// and avoid downloading the entire header chain while still being able to securely access old headers/logs. +type trustedCheckpoint struct { + name string + sectionIdx uint64 + sectionHead, chtRoot, bloomTrieRoot common.Hash +} + +var ( + mainnetCheckpoint = trustedCheckpoint{ + name: "ETH mainnet", + sectionIdx: 129, + sectionHead: common.HexToHash("64100587c8ec9a76870056d07cb0f58622552d16de6253a59cac4b580c899501"), + chtRoot: common.HexToHash("bb4fb4076cbe6923c8a8ce8f158452bbe19564959313466989fda095a60884ca"), + bloomTrieRoot: common.HexToHash("0db524b2c4a2a9520a42fd842b02d2e8fb58ff37c75cf57bd0eb82daeace6716"), + } + + ropstenCheckpoint = trustedCheckpoint{ + name: "Ropsten testnet", + sectionIdx: 50, + sectionHead: common.HexToHash("00bd65923a1aa67f85e6b4ae67835784dd54be165c37f056691723c55bf016bd"), + chtRoot: common.HexToHash("6f56dc61936752cc1f8c84b4addabdbe6a1c19693de3f21cb818362df2117f03"), + bloomTrieRoot: common.HexToHash("aca7d7c504d22737242effc3fdc604a762a0af9ced898036b5986c3a15220208"), + } +) + +// trustedCheckpoints associates each known checkpoint with the genesis hash of the chain it belongs to +var trustedCheckpoints = map[common.Hash]trustedCheckpoint{ + params.MainnetGenesisHash: mainnetCheckpoint, + params.TestnetGenesisHash: ropstenCheckpoint, +} + +var ( + ErrNoTrustedCht = errors.New("No trusted canonical hash trie") + ErrNoTrustedBloomTrie = errors.New("No trusted bloom trie") + ErrNoHeader = errors.New("Header not found") + chtPrefix = []byte("chtRoot-") // chtPrefix + chtNum (uint64 big endian) -> trie root hash + ChtTablePrefix = "cht-" +) + +// ChtNode structures are stored in the Canonical Hash Trie in an RLP encoded format +type ChtNode struct { + Hash common.Hash + Td *big.Int +} + +// GetChtRoot reads the CHT root assoctiated to the given section from the database +// Note that sectionIdx is specified according to LES/1 CHT section size +func GetChtRoot(db ethdb.Database, sectionIdx uint64, sectionHead common.Hash) common.Hash { + var encNumber [8]byte + binary.BigEndian.PutUint64(encNumber[:], sectionIdx) + data, _ := db.Get(append(append(chtPrefix, encNumber[:]...), sectionHead.Bytes()...)) + return common.BytesToHash(data) +} + +// GetChtV2Root reads the CHT root assoctiated to the given section from the database +// Note that sectionIdx is specified according to LES/2 CHT section size +func GetChtV2Root(db ethdb.Database, sectionIdx uint64, sectionHead common.Hash) common.Hash { + return GetChtRoot(db, (sectionIdx+1)*(ChtFrequency/ChtV1Frequency)-1, sectionHead) +} + +// StoreChtRoot writes the CHT root assoctiated to the given section into the database +// Note that sectionIdx is specified according to LES/1 CHT section size +func StoreChtRoot(db ethdb.Database, sectionIdx uint64, sectionHead, root common.Hash) { + var encNumber [8]byte + binary.BigEndian.PutUint64(encNumber[:], sectionIdx) + db.Put(append(append(chtPrefix, encNumber[:]...), sectionHead.Bytes()...), root.Bytes()) +} + +// ChtIndexerBackend implements core.ChainIndexerBackend +type ChtIndexerBackend struct { + db, cdb ethdb.Database + section, sectionSize uint64 + lastHash common.Hash + trie *trie.Trie +} + +// NewBloomTrieIndexer creates a BloomTrie chain indexer +func NewChtIndexer(db ethdb.Database, clientMode bool) *core.ChainIndexer { + cdb := ethdb.NewTable(db, ChtTablePrefix) + idb := ethdb.NewTable(db, "chtIndex-") + var sectionSize, confirmReq uint64 + if clientMode { + sectionSize = ChtFrequency + confirmReq = HelperTrieConfirmations + } else { + sectionSize = ChtV1Frequency + confirmReq = HelperTrieProcessConfirmations + } + return core.NewChainIndexer(db, idb, &ChtIndexerBackend{db: db, cdb: cdb, sectionSize: sectionSize}, sectionSize, confirmReq, time.Millisecond*100, "cht") +} + +// Reset implements core.ChainIndexerBackend +func (c *ChtIndexerBackend) Reset(section uint64, lastSectionHead common.Hash) error { + var root common.Hash + if section > 0 { + root = GetChtRoot(c.db, section-1, lastSectionHead) + } + var err error + c.trie, err = trie.New(root, c.cdb) + c.section = section + return err +} + +// Process implements core.ChainIndexerBackend +func (c *ChtIndexerBackend) Process(header *types.Header) { + hash, num := header.Hash(), header.Number.Uint64() + c.lastHash = hash + + td := core.GetTd(c.db, hash, num) + if td == nil { + panic(nil) + } + var encNumber [8]byte + binary.BigEndian.PutUint64(encNumber[:], num) + data, _ := rlp.EncodeToBytes(ChtNode{hash, td}) + c.trie.Update(encNumber[:], data) +} + +// Commit implements core.ChainIndexerBackend +func (c *ChtIndexerBackend) Commit() error { + batch := c.cdb.NewBatch() + root, err := c.trie.CommitTo(batch) + if err != nil { + return err + } else { + batch.Write() + if ((c.section+1)*c.sectionSize)%ChtFrequency == 0 { + log.Info("Storing CHT", "idx", c.section*c.sectionSize/ChtFrequency, "sectionHead", fmt.Sprintf("%064x", c.lastHash), "root", fmt.Sprintf("%064x", root)) + } + StoreChtRoot(c.db, c.section, c.lastHash, root) + } + return nil +} + +const ( + BloomTrieFrequency = 32768 + ethBloomBitsSection = 4096 + ethBloomBitsConfirmations = 256 +) + +var ( + bloomTriePrefix = []byte("bltRoot-") // bloomTriePrefix + bloomTrieNum (uint64 big endian) -> trie root hash + BloomTrieTablePrefix = "blt-" +) + +// GetBloomTrieRoot reads the BloomTrie root assoctiated to the given section from the database +func GetBloomTrieRoot(db ethdb.Database, sectionIdx uint64, sectionHead common.Hash) common.Hash { + var encNumber [8]byte + binary.BigEndian.PutUint64(encNumber[:], sectionIdx) + data, _ := db.Get(append(append(bloomTriePrefix, encNumber[:]...), sectionHead.Bytes()...)) + return common.BytesToHash(data) +} + +// StoreBloomTrieRoot writes the BloomTrie root assoctiated to the given section into the database +func StoreBloomTrieRoot(db ethdb.Database, sectionIdx uint64, sectionHead, root common.Hash) { + var encNumber [8]byte + binary.BigEndian.PutUint64(encNumber[:], sectionIdx) + db.Put(append(append(bloomTriePrefix, encNumber[:]...), sectionHead.Bytes()...), root.Bytes()) +} + +// BloomTrieIndexerBackend implements core.ChainIndexerBackend +type BloomTrieIndexerBackend struct { + db, cdb ethdb.Database + section, parentSectionSize, bloomTrieRatio uint64 + trie *trie.Trie + sectionHeads []common.Hash +} + +// NewBloomTrieIndexer creates a BloomTrie chain indexer +func NewBloomTrieIndexer(db ethdb.Database, clientMode bool) *core.ChainIndexer { + cdb := ethdb.NewTable(db, BloomTrieTablePrefix) + idb := ethdb.NewTable(db, "bltIndex-") + backend := &BloomTrieIndexerBackend{db: db, cdb: cdb} + var confirmReq uint64 + if clientMode { + backend.parentSectionSize = BloomTrieFrequency + confirmReq = HelperTrieConfirmations + } else { + backend.parentSectionSize = ethBloomBitsSection + confirmReq = HelperTrieProcessConfirmations + } + backend.bloomTrieRatio = BloomTrieFrequency / backend.parentSectionSize + backend.sectionHeads = make([]common.Hash, backend.bloomTrieRatio) + return core.NewChainIndexer(db, idb, backend, BloomTrieFrequency, confirmReq-ethBloomBitsConfirmations, time.Millisecond*100, "bloomtrie") +} + +// Reset implements core.ChainIndexerBackend +func (b *BloomTrieIndexerBackend) Reset(section uint64, lastSectionHead common.Hash) error { + var root common.Hash + if section > 0 { + root = GetBloomTrieRoot(b.db, section-1, lastSectionHead) + } + var err error + b.trie, err = trie.New(root, b.cdb) + b.section = section + return err +} + +// Process implements core.ChainIndexerBackend +func (b *BloomTrieIndexerBackend) Process(header *types.Header) { + num := header.Number.Uint64() - b.section*BloomTrieFrequency + if (num+1)%b.parentSectionSize == 0 { + b.sectionHeads[num/b.parentSectionSize] = header.Hash() + } +} + +// Commit implements core.ChainIndexerBackend +func (b *BloomTrieIndexerBackend) Commit() error { + var compSize, decompSize uint64 + + for i := uint(0); i < types.BloomBitLength; i++ { + var encKey [10]byte + binary.BigEndian.PutUint16(encKey[0:2], uint16(i)) + binary.BigEndian.PutUint64(encKey[2:10], b.section) + var decomp []byte + for j := uint64(0); j < b.bloomTrieRatio; j++ { + data, err := core.GetBloomBits(b.db, i, b.section*b.bloomTrieRatio+j, b.sectionHeads[j]) + if err != nil { + return err + } + decompData, err2 := bitutil.DecompressBytes(data, int(b.parentSectionSize/8)) + if err2 != nil { + return err2 + } + decomp = append(decomp, decompData...) + } + comp := bitutil.CompressBytes(decomp) + + decompSize += uint64(len(decomp)) + compSize += uint64(len(comp)) + if len(comp) > 0 { + b.trie.Update(encKey[:], comp) + } else { + b.trie.Delete(encKey[:]) + } + } + + batch := b.cdb.NewBatch() + root, err := b.trie.CommitTo(batch) + if err != nil { + return err + } else { + batch.Write() + sectionHead := b.sectionHeads[b.bloomTrieRatio-1] + log.Info("Storing BloomTrie", "section", b.section, "sectionHead", fmt.Sprintf("%064x", sectionHead), "root", fmt.Sprintf("%064x", root), "compression ratio", float64(compSize)/float64(decompSize)) + StoreBloomTrieRoot(b.db, b.section, sectionHead, root) + } + + return nil +} diff --git a/trie/proof.go b/trie/proof.go index 298f648c4..5e886a259 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -18,11 +18,10 @@ package trie import ( "bytes" - "errors" "fmt" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto/sha3" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/rlp" ) @@ -36,7 +35,7 @@ import ( // contains all nodes of the longest existing prefix of the key // (at least the root node), ending with the node that proves the // absence of the key. -func (t *Trie) Prove(key []byte) []rlp.RawValue { +func (t *Trie) Prove(key []byte, fromLevel uint, proofDb DatabaseWriter) error { // Collect all nodes on the path to key. key = keybytesToHex(key) nodes := []node{} @@ -61,67 +60,63 @@ func (t *Trie) Prove(key []byte) []rlp.RawValue { tn, err = t.resolveHash(n, nil) if err != nil { log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) - return nil + return err } default: panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) } } hasher := newHasher(0, 0) - proof := make([]rlp.RawValue, 0, len(nodes)) for i, n := range nodes { // Don't bother checking for errors here since hasher panics // if encoding doesn't work and we're not writing to any database. n, _, _ = hasher.hashChildren(n, nil) hn, _ := hasher.store(n, nil, false) - if _, ok := hn.(hashNode); ok || i == 0 { + if hash, ok := hn.(hashNode); ok || i == 0 { // If the node's database encoding is a hash (or is the // root node), it becomes a proof element. - enc, _ := rlp.EncodeToBytes(n) - proof = append(proof, enc) + if fromLevel > 0 { + fromLevel-- + } else { + enc, _ := rlp.EncodeToBytes(n) + if !ok { + hash = crypto.Keccak256(enc) + } + proofDb.Put(hash, enc) + } } } - return proof + return nil } // VerifyProof checks merkle proofs. The given proof must contain the // value for key in a trie with the given root hash. VerifyProof // returns an error if the proof contains invalid trie nodes or the // wrong value. -func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value []byte, err error) { +func VerifyProof(rootHash common.Hash, key []byte, proofDb DatabaseReader) (value []byte, err error, nodes int) { key = keybytesToHex(key) - sha := sha3.NewKeccak256() - wantHash := rootHash.Bytes() - for i, buf := range proof { - sha.Reset() - sha.Write(buf) - if !bytes.Equal(sha.Sum(nil), wantHash) { - return nil, fmt.Errorf("bad proof node %d: hash mismatch", i) + wantHash := rootHash[:] + for i := 0; ; i++ { + buf, _ := proofDb.Get(wantHash) + if buf == nil { + return nil, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash[:]), i } n, err := decodeNode(wantHash, buf, 0) if err != nil { - return nil, fmt.Errorf("bad proof node %d: %v", i, err) + return nil, fmt.Errorf("bad proof node %d: %v", i, err), i } keyrest, cld := get(n, key) switch cld := cld.(type) { case nil: - if i != len(proof)-1 { - return nil, fmt.Errorf("key mismatch at proof node %d", i) - } else { - // The trie doesn't contain the key. - return nil, nil - } + // The trie doesn't contain the key. + return nil, nil, i case hashNode: key = keyrest wantHash = cld case valueNode: - if i != len(proof)-1 { - return nil, errors.New("additional nodes at end of proof") - } - return cld, nil + return cld, nil, i + 1 } } - return nil, errors.New("unexpected end of proof") } func get(tn node, key []byte) ([]byte, node) { diff --git a/trie/proof_test.go b/trie/proof_test.go index 91ebcd4a5..fff313d7f 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -24,7 +24,8 @@ import ( "time" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/ethdb" ) func init() { @@ -35,13 +36,13 @@ func TestProof(t *testing.T) { trie, vals := randomTrie(500) root := trie.Hash() for _, kv := range vals { - proof := trie.Prove(kv.k) - if proof == nil { + proofs, _ := ethdb.NewMemDatabase() + if trie.Prove(kv.k, 0, proofs) != nil { t.Fatalf("missing key %x while constructing proof", kv.k) } - val, err := VerifyProof(root, kv.k, proof) + val, err, _ := VerifyProof(root, kv.k, proofs) if err != nil { - t.Fatalf("VerifyProof error for key %x: %v\nraw proof: %x", kv.k, err, proof) + t.Fatalf("VerifyProof error for key %x: %v\nraw proof: %v", kv.k, err, proofs) } if !bytes.Equal(val, kv.v) { t.Fatalf("VerifyProof returned wrong value for key %x: got %x, want %x", kv.k, val, kv.v) @@ -52,16 +53,14 @@ func TestProof(t *testing.T) { func TestOneElementProof(t *testing.T) { trie := new(Trie) updateString(trie, "k", "v") - proof := trie.Prove([]byte("k")) - if proof == nil { - t.Fatal("nil proof") - } - if len(proof) != 1 { + proofs, _ := ethdb.NewMemDatabase() + trie.Prove([]byte("k"), 0, proofs) + if len(proofs.Keys()) != 1 { t.Error("proof should have one element") } - val, err := VerifyProof(trie.Hash(), []byte("k"), proof) + val, err, _ := VerifyProof(trie.Hash(), []byte("k"), proofs) if err != nil { - t.Fatalf("VerifyProof error: %v\nraw proof: %x", err, proof) + t.Fatalf("VerifyProof error: %v\nproof hashes: %v", err, proofs.Keys()) } if !bytes.Equal(val, []byte("v")) { t.Fatalf("VerifyProof returned wrong value: got %x, want 'k'", val) @@ -72,12 +71,18 @@ func TestVerifyBadProof(t *testing.T) { trie, vals := randomTrie(800) root := trie.Hash() for _, kv := range vals { - proof := trie.Prove(kv.k) - if proof == nil { - t.Fatal("nil proof") + proofs, _ := ethdb.NewMemDatabase() + trie.Prove(kv.k, 0, proofs) + if len(proofs.Keys()) == 0 { + t.Fatal("zero length proof") } - mutateByte(proof[mrand.Intn(len(proof))]) - if _, err := VerifyProof(root, kv.k, proof); err == nil { + keys := proofs.Keys() + key := keys[mrand.Intn(len(keys))] + node, _ := proofs.Get(key) + proofs.Delete(key) + mutateByte(node) + proofs.Put(crypto.Keccak256(node), node) + if _, err, _ := VerifyProof(root, kv.k, proofs); err == nil { t.Fatalf("expected proof to fail for key %x", kv.k) } } @@ -104,8 +109,9 @@ func BenchmarkProve(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { kv := vals[keys[i%len(keys)]] - if trie.Prove(kv.k) == nil { - b.Fatalf("nil proof for %x", kv.k) + proofs, _ := ethdb.NewMemDatabase() + if trie.Prove(kv.k, 0, proofs); len(proofs.Keys()) == 0 { + b.Fatalf("zero length proof for %x", kv.k) } } } @@ -114,16 +120,18 @@ func BenchmarkVerifyProof(b *testing.B) { trie, vals := randomTrie(100) root := trie.Hash() var keys []string - var proofs [][]rlp.RawValue + var proofs []*ethdb.MemDatabase for k := range vals { keys = append(keys, k) - proofs = append(proofs, trie.Prove([]byte(k))) + proof, _ := ethdb.NewMemDatabase() + trie.Prove([]byte(k), 0, proof) + proofs = append(proofs, proof) } b.ResetTimer() for i := 0; i < b.N; i++ { im := i % len(keys) - if _, err := VerifyProof(root, []byte(keys[im]), proofs[im]); err != nil { + if _, err, _ := VerifyProof(root, []byte(keys[im]), proofs[im]); err != nil { b.Fatalf("key %x: %v", keys[im], err) } }