From 767ff66c4ea058475e8fe778c720f185c861ce0b Mon Sep 17 00:00:00 2001 From: Christopher Schinnerl Date: Mon, 8 Jan 2018 12:12:32 -0500 Subject: [PATCH] persist tracked transaction sets --- modules/wallet/broadcast.go | 47 +++++++++++++++- modules/wallet/broadcast_test.go | 94 ++++++++++++++++++++++++++++++++ modules/wallet/database.go | 49 ++++++++++++++++- modules/wallet/update.go | 18 ++++-- modules/wallet/wallet.go | 6 ++ 5 files changed, 207 insertions(+), 7 deletions(-) diff --git a/modules/wallet/broadcast.go b/modules/wallet/broadcast.go index 26a7189fe9..d1b0cbe140 100644 --- a/modules/wallet/broadcast.go +++ b/modules/wallet/broadcast.go @@ -1,6 +1,10 @@ package wallet -import "github.com/NebulousLabs/Sia/types" +import ( + "github.com/NebulousLabs/Sia/crypto" + "github.com/NebulousLabs/Sia/modules" + "github.com/NebulousLabs/Sia/types" +) // broadcastedTSet is a helper struct to keep track of transaction sets and to // help rebroadcast them. @@ -9,11 +13,44 @@ type broadcastedTSet struct { lastTry types.BlockHeight // last time the tSet was broadcasted confirmedTxn map[types.TransactionID]bool // tracks confirmed txns of set transactions []types.Transaction // the tSet + id modules.TransactionSetID // the tSet's ID + w *Wallet +} + +// persistBTS is the on-disk version of the broadcastedTSets structure. This is +// necessary since we can't marshal a map directly. Instead we make sure that +// confirmedTxn[i] corresponds to the confirmation state of transactions[i]. +type persistBTS struct { + FirstTry types.BlockHeight // first time the tSet was broadcasted + LastTry types.BlockHeight // last time the tSet was broadcasted + ConfirmedTxn []bool // tracks confirmed txns of set + Transactions []types.Transaction // the tSet +} + +// confirmed is a helper function that sets a certain transactions to confirmed +// or unconfirmed. It also updates the state on disk. +func (bts *broadcastedTSet) confirmed(txid types.TransactionID, confirmed bool) error { + bts.confirmedTxn[txid] = confirmed + return dbPutBroadcastedTSet(bts.w.dbTx, *bts) +} + +// deleteBroadcastedTSet removes a broadcastedTSet from the wallet and disk +func (w *Wallet) deleteBroadcastedTSet(tSetID modules.TransactionSetID) error { + // Remove it from wallet + delete(w.broadcastedTSets, tSetID) + + // Remove it from disk + if err := dbDeleteBroadcastedTSet(w.dbTx, tSetID); err != nil { + return err + } + return nil } // newBroadcastedTSet creates a broadcastedTSet from a normal tSet func (w *Wallet) newBroadcastedTSet(tSet []types.Transaction) (bts *broadcastedTSet, err error) { - bts = &broadcastedTSet{} + bts = &broadcastedTSet{ + w: w, + } // Set the height of the first and last try bts.firstTry, err = dbGetConsensusHeight(w.dbTx) if err != nil { @@ -27,5 +64,11 @@ func (w *Wallet) newBroadcastedTSet(tSet []types.Transaction) (bts *broadcastedT bts.confirmedTxn[txn.ID()] = false bts.transactions = append(bts.transactions, txn) } + + // Persist the new tSet + bts.id = modules.TransactionSetID(crypto.HashAll(tSet)) + if err := dbPutBroadcastedTSet(w.dbTx, *bts); err != nil { + return nil, err + } return } diff --git a/modules/wallet/broadcast_test.go b/modules/wallet/broadcast_test.go index c8f2d8532c..ccb52aa566 100644 --- a/modules/wallet/broadcast_test.go +++ b/modules/wallet/broadcast_test.go @@ -1,10 +1,12 @@ package wallet import ( + "path/filepath" "testing" "github.com/NebulousLabs/Sia/crypto" "github.com/NebulousLabs/Sia/modules" + "github.com/NebulousLabs/Sia/modules/miner" "github.com/NebulousLabs/Sia/types" ) @@ -87,3 +89,95 @@ func TestRebroadcastTransactions(t *testing.T) { t.Fatalf("Wallet should drop txnSet after %v blocks", rebroadcastTimeout) } } + +// TestRebroadcastTransactionsPersist checks if the wallet keeps tracking +// transactions for rebroadcasting after a restart +func TestRebroadcastTransactionsPersist(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + wt, err := createWalletTester(t.Name(), &ProductionDependencies{}) + if err != nil { + t.Fatal(err) + } + defer wt.closeWt() + + // Get an address to send money to + uc, err := wt.wallet.NextAddress() + if err != nil { + t.Fatal(err) + } + // Send money to the address + tSet, err := wt.wallet.SendSiacoins(types.SiacoinPrecision, uc.UnlockHash()) + if err != nil { + t.Fatal(err) + } + // The wallet should track the tSet + if len(wt.wallet.broadcastedTSets) != 1 { + t.Fatalf("len(broadcastedTSets) should be %v but was %v", + 1, len(wt.wallet.broadcastedTSets)) + } + // Mine a block to get the tSet confirmed + if _, err := wt.miner.AddBlock(); err != nil { + t.Fatal(err) + } + // Corrupt the tSet to make sure the wallet believes it is not confirmed + tSetID := modules.TransactionSetID(crypto.HashAll(tSet)) + bts := wt.wallet.broadcastedTSets[tSetID] + bts.confirmedTxn[tSet[0].ID()] = false + if err := bts.confirmed(tSet[0].ID(), false); err != nil { + t.Fatal(err) + } + + // Close and restart the wallet and miner + if err := wt.wallet.Close(); err != nil { + t.Fatal(err) + } + if err := wt.miner.Close(); err != nil { + t.Fatal(err) + } + wt.wallet, err = New(wt.cs, wt.tpool, filepath.Join(wt.persistDir, modules.WalletDir)) + if err != nil { + t.Fatal(err) + } + if err := wt.wallet.Unlock(wt.walletMasterKey); err != nil { + t.Fatal(err) + } + wt.miner, err = miner.New(wt.cs, wt.tpool, wt.wallet, filepath.Join(wt.persistDir, modules.WalletDir)) + if err != nil { + t.Fatal(err) + } + // The wallet should still track the new tSet + if len(wt.wallet.broadcastedTSets) != 1 { + t.Fatalf("len(broadcastedTSets) should be %v but was %v", + 1, len(wt.wallet.broadcastedTSets)) + } + // The same transactions should be marked as confirmed + btsNew := wt.wallet.broadcastedTSets[tSetID] + for key := range btsNew.confirmedTxn { + if btsNew.confirmedTxn[key] != bts.confirmedTxn[key] { + t.Fatalf("txn confirmation state should be %v but was %v", + bts.confirmedTxn[key], btsNew.confirmedTxn[key]) + } + } + // Mine rebroadcastInterval blocks. The wallet should keep tracking the + // tSet afterwards + for i := 0; i < rebroadcastInterval+1; i++ { + if _, err := wt.miner.AddBlock(); err != nil { + t.Fatal(err) + } + } + if len(wt.wallet.broadcastedTSets) != 1 { + t.Fatalf("The wallet should still track the tSet") + } + // Continue mining to make sure that the wallet stops tracking the tSet + // once the max number of retries is reached + for i := types.BlockHeight(0); i < rebroadcastTimeout; i++ { + if _, err := wt.miner.AddBlock(); err != nil { + t.Fatal(err) + } + } + if _, exists := wt.wallet.broadcastedTSets[tSetID]; exists { + t.Fatalf("Wallet should drop txnSet after %v blocks", rebroadcastTimeout) + } +} diff --git a/modules/wallet/database.go b/modules/wallet/database.go index 36a90c6fb2..42020605a9 100644 --- a/modules/wallet/database.go +++ b/modules/wallet/database.go @@ -6,6 +6,7 @@ import ( "reflect" "time" + "github.com/NebulousLabs/Sia/crypto" "github.com/NebulousLabs/Sia/encoding" "github.com/NebulousLabs/Sia/modules" "github.com/NebulousLabs/Sia/types" @@ -38,9 +39,13 @@ var ( // bucketWallet contains various fields needed by the wallet, such as its // UID, EncryptionVerification, and PrimarySeedFile. bucketWallet = []byte("bucketWallet") + // bucketBroadcastedTSets contains the transaction sets that are tracked + // for rebroadcasting + bucketBroadcastedTSets = []byte("bucketBroadcastedTSets") dbBuckets = [][]byte{ bucketProcessedTransactions, + bucketBroadcastedTSets, bucketAddrTransactions, bucketSiacoinOutputs, bucketSiafundOutputs, @@ -207,7 +212,6 @@ func dbGetSpentOutput(tx *bolt.Tx, id types.OutputID) (height types.BlockHeight, func dbDeleteSpentOutput(tx *bolt.Tx, id types.OutputID) error { return dbDelete(tx.Bucket(bucketSpentOutputs), id) } - func dbPutAddrTransactions(tx *bolt.Tx, addr types.UnlockHash, txns []uint64) error { return dbPut(tx.Bucket(bucketAddrTransactions), addr, txns) } @@ -215,6 +219,49 @@ func dbGetAddrTransactions(tx *bolt.Tx, addr types.UnlockHash) (txns []uint64, e err = dbGet(tx.Bucket(bucketAddrTransactions), addr, &txns) return } +func dbPutBroadcastedTSet(tx *bolt.Tx, bts broadcastedTSet) error { + persist := persistBTS{ + FirstTry: bts.firstTry, + LastTry: bts.lastTry, + Transactions: bts.transactions, + } + // Convert bts.confirmedTxn to a boolean array + persist.ConfirmedTxn = make([]bool, len(bts.confirmedTxn)) + for i, txn := range persist.Transactions { + persist.ConfirmedTxn[i] = bts.confirmedTxn[txn.ID()] + } + return dbPut(tx.Bucket(bucketBroadcastedTSets), bts.id, persist) +} +func dbDeleteBroadcastedTSet(tx *bolt.Tx, tSetID modules.TransactionSetID) error { + return dbDelete(tx.Bucket(bucketBroadcastedTSets), tSetID) +} + +// dbLoadBroadcastedTSets returns all the broadcasted tSets from the database +func dbLoadBroadcastedTSets(tx *bolt.Tx) (tSets map[modules.TransactionSetID]*broadcastedTSet, err error) { + tSets = make(map[modules.TransactionSetID]*broadcastedTSet) + err = tx.Bucket(bucketBroadcastedTSets).ForEach(func(k []byte, v []byte) error { + // Load the persisted structure from disk + var pbts persistBTS + if err := encoding.Unmarshal(v, &pbts); err != nil { + return err + } + // Convert it to the in-memory structure + bts := broadcastedTSet{ + firstTry: pbts.FirstTry, + lastTry: pbts.LastTry, + transactions: pbts.Transactions, + } + bts.confirmedTxn = make(map[types.TransactionID]bool) + for i, txn := range pbts.Transactions { + bts.confirmedTxn[txn.ID()] = pbts.ConfirmedTxn[i] + } + + bts.id = modules.TransactionSetID(crypto.HashAll(bts.transactions)) + tSets[bts.id] = &bts + return nil + }) + return +} // dbAddAddrTransaction appends a single transaction index to the set of // transactions associated with addr. If the index is already in the set, it is diff --git a/modules/wallet/update.go b/modules/wallet/update.go index 93e077c9b0..1c4e2e78ed 100644 --- a/modules/wallet/update.go +++ b/modules/wallet/update.go @@ -482,7 +482,10 @@ func (w *Wallet) rebroadcastOldTransactions(tx *bolt.Tx, cc modules.ConsensusCha for _, bts := range w.broadcastedTSets { for _, txn := range block.Transactions { if _, exists := bts.confirmedTxn[txn.ID()]; exists { - bts.confirmedTxn[txn.ID()] = false + err = bts.confirmed(txn.ID(), false) + } + if err != nil { + return err } } } @@ -493,7 +496,10 @@ func (w *Wallet) rebroadcastOldTransactions(tx *bolt.Tx, cc modules.ConsensusCha for _, bts := range w.broadcastedTSets { for _, txn := range block.Transactions { if _, exists := bts.confirmedTxn[txn.ID()]; exists { - bts.confirmedTxn[txn.ID()] = true + err = bts.confirmed(txn.ID(), true) + } + if err != nil { + return err } } } @@ -511,7 +517,9 @@ func (w *Wallet) rebroadcastOldTransactions(tx *bolt.Tx, cc modules.ConsensusCha // If the transaction set has been confirmed for one broadcast cycle it // should be safe to remove it if confirmed && consensusHeight > bts.lastTry+rebroadcastInterval { - delete(w.broadcastedTSets, tSetID) + if err := w.deleteBroadcastedTSet(tSetID); err != nil { + return err + } continue } // If the transaction set has been confirmed recently we wait a little @@ -531,7 +539,9 @@ func (w *Wallet) rebroadcastOldTransactions(tx *bolt.Tx, cc modules.ConsensusCha // Delete the transaction set once we have tried for RespendTimeout // blocks if consensusHeight >= bts.firstTry+rebroadcastTimeout { - delete(w.broadcastedTSets, tSetID) + if err := w.deleteBroadcastedTSet(tSetID); err != nil { + return err + } } } } diff --git a/modules/wallet/wallet.go b/modules/wallet/wallet.go index c2a4034c66..24fed5fdc3 100644 --- a/modules/wallet/wallet.go +++ b/modules/wallet/wallet.go @@ -142,6 +142,12 @@ func newWallet(cs modules.ConsensusSet, tpool modules.TransactionPool, persistDi w.log.Critical("ERROR: failed to start database update:", err) } + // retrieve the previously tracked broadcasted tSets from the database + w.broadcastedTSets, err = dbLoadBroadcastedTSets(w.dbTx) + if err != nil { + return nil, err + } + // make sure we commit on shutdown w.tg.AfterStop(func() { err := w.dbTx.Commit()