From e2d643777eca05f696c3a267d082ea26e8a7535b Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Tue, 27 Feb 2024 15:44:51 -0800 Subject: [PATCH 1/2] sqlite: fix duplicate elements during resubscribe, move balance updates into store. --- cmd/walletd/main.go | 2 +- persist/sqlite/consensus.go | 331 ++++++++++++++++++++++++++----- persist/sqlite/init.sql | 2 +- persist/sqlite/wallet_test.go | 120 +++++++++++- wallet/update.go | 356 +++++++--------------------------- 5 files changed, 478 insertions(+), 333 deletions(-) diff --git a/cmd/walletd/main.go b/cmd/walletd/main.go index bc19658..acedcd8 100644 --- a/cmd/walletd/main.go +++ b/cmd/walletd/main.go @@ -189,7 +189,7 @@ func main() { consoleEncoder := zapcore.NewConsoleEncoder(consoleCfg) // only log info messages to console unless stdout logging is enabled - consoleCore := zapcore.NewCore(consoleEncoder, zapcore.Lock(os.Stdout), zap.NewAtomicLevelAt(zap.InfoLevel)) + consoleCore := zapcore.NewCore(consoleEncoder, zapcore.Lock(os.Stdout), zap.NewAtomicLevelAt(zap.DebugLevel)) logger := zap.New(consoleCore, zap.AddCaller()) defer logger.Sync() // redirect stdlib log to zap diff --git a/persist/sqlite/consensus.go b/persist/sqlite/consensus.go index c110087..9d22d37 100644 --- a/persist/sqlite/consensus.go +++ b/persist/sqlite/consensus.go @@ -19,6 +19,11 @@ type updateTx struct { relevantAddresses map[types.Address]bool } +type addressRef struct { + ID int64 + Balance wallet.Balance +} + func scanStateElement(s scanner) (se types.StateElement, err error) { err = s.Scan(decode(&se.ID), &se.LeafIndex, decodeSlice(&se.MerkleProof)) return @@ -29,6 +34,11 @@ func scanSiacoinElement(s scanner) (se types.SiacoinElement, err error) { return } +func scanAddress(s scanner) (ab addressRef, err error) { + err = s.Scan(&ab.ID, decode(&ab.Balance.Siacoins), decode(&ab.Balance.ImmatureSiacoins), &ab.Balance.Siafunds) + return +} + func (ut *updateTx) SiacoinStateElements() ([]types.StateElement, error) { const query = `SELECT id, leaf_index, merkle_proof FROM siacoin_elements` rows, err := ut.tx.Query(query) @@ -125,151 +135,378 @@ func (ut *updateTx) AddressBalance(addr types.Address) (balance wallet.Balance, return } -func (ut *updateTx) UpdateBalances(balances []wallet.AddressBalance) error { - const query = `UPDATE sia_addresses SET siacoin_balance=$1, immature_siacoin_balance=$2, siafund_balance=$3 WHERE sia_address=$4` - stmt, err := ut.tx.Prepare(query) +func (ut *updateTx) ApplyMatureSiacoinBalance(index types.ChainIndex) error { + const query = `SELECT se.address_id, se.siacoin_value +FROM siacoin_elements se +WHERE maturity_height=$1` + rows, err := ut.tx.Query(query, index.Height) + if err != nil { + return fmt.Errorf("failed to query siacoin elements: %w", err) + } + defer rows.Close() + + balanceDelta := make(map[int64]types.Currency) + for rows.Next() { + var addressID int64 + var value types.Currency + + if err := rows.Scan(&addressID, decode(&value)); err != nil { + return fmt.Errorf("failed to scan siacoin elements: %w", err) + } + balanceDelta[addressID] = balanceDelta[addressID].Add(value) + } + + getAddressBalanceStmt, err := ut.tx.Prepare(`SELECT siacoin_balance, immature_siacoin_balance FROM sia_addresses WHERE id=$1`) if err != nil { return fmt.Errorf("failed to prepare statement: %w", err) } - defer stmt.Close() + defer getAddressBalanceStmt.Close() - for _, ab := range balances { - _, err := stmt.Exec(encode(ab.Balance.Siacoins), encode(ab.Balance.ImmatureSiacoins), ab.Balance.Siafunds, encode(ab.Address)) + updateAddressBalanceStmt, err := ut.tx.Prepare(`UPDATE sia_addresses SET siacoin_balance=$1, immature_siacoin_balance=$2 WHERE id=$3`) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer updateAddressBalanceStmt.Close() + + for addressID, delta := range balanceDelta { + var balance, immatureBalance types.Currency + err := getAddressBalanceStmt.QueryRow(addressID).Scan(decode(&balance), decode(&immatureBalance)) if err != nil { - return fmt.Errorf("failed to execute statement: %w", err) + return fmt.Errorf("failed to get address balance: %w", err) + } + + balance = balance.Add(delta) + immatureBalance = immatureBalance.Sub(delta) + + res, err := updateAddressBalanceStmt.Exec(encode(balance), encode(immatureBalance), addressID) + if err != nil { + return fmt.Errorf("failed to update address balance: %w", err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n != 1 { + return fmt.Errorf("expected 1 row affected, got %v", n) } } return nil } -func (ut *updateTx) MaturedSiacoinElements(index types.ChainIndex) (elements []types.SiacoinElement, err error) { - const query = `SELECT se.id, se.siacoin_value, se.merkle_proof, se.leaf_index, se.maturity_height, a.sia_address -FROM siacoin_elements se -INNER JOIN sia_addresses a ON (se.address_id=a.id) -WHERE maturity_height=$1` +func (ut *updateTx) RevertMatureSiacoinBalance(index types.ChainIndex) error { + const query = `SELECT se.address_id, se.siacoin_value + FROM siacoin_elements se + WHERE maturity_height=$1` rows, err := ut.tx.Query(query, index.Height) if err != nil { - return nil, fmt.Errorf("failed to query siacoin elements: %w", err) + return fmt.Errorf("failed to query siacoin elements: %w", err) } defer rows.Close() + balanceDelta := make(map[int64]types.Currency) for rows.Next() { - element, err := scanSiacoinElement(rows) - if err != nil { - return nil, fmt.Errorf("failed to scan siacoin element: %w", err) + var addressID int64 + var value types.Currency + + if err := rows.Scan(&addressID, decode(&value)); err != nil { + return fmt.Errorf("failed to scan siacoin elements: %w", err) } - elements = append(elements, element) + balanceDelta[addressID] = balanceDelta[addressID].Add(value) } - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("failed to scan siacoin elements: %w", err) + + getAddressBalanceStmt, err := ut.tx.Prepare(`SELECT siacoin_balance, immature_siacoin_balance FROM sia_addresses WHERE id=$1`) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) } - return + defer getAddressBalanceStmt.Close() + + updateAddressBalanceStmt, err := ut.tx.Prepare(`UPDATE sia_addresses SET siacoin_balance=$1, immature_siacoin_balance=$2 WHERE id=$3`) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer updateAddressBalanceStmt.Close() + + for addressID, delta := range balanceDelta { + var balance, immatureBalance types.Currency + err := getAddressBalanceStmt.QueryRow(addressID).Scan(decode(&balance), decode(&immatureBalance)) + if err != nil { + return fmt.Errorf("failed to get address balance: %w", err) + } + + balance = balance.Sub(delta) + immatureBalance = immatureBalance.Add(delta) + + res, err := updateAddressBalanceStmt.Exec(encode(balance), encode(immatureBalance), addressID) + if err != nil { + return fmt.Errorf("failed to update address balance: %w", err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n != 1 { + return fmt.Errorf("expected 1 row affected, got %v", n) + } + } + return nil } -func (ut *updateTx) AddSiacoinElements(elements []types.SiacoinElement) error { +func (ut *updateTx) AddSiacoinElements(elements []types.SiacoinElement, index types.ChainIndex) error { + if len(elements) == 0 { + return nil + } + addrStmt, err := insertAddressStatement(ut.tx) if err != nil { return fmt.Errorf("failed to prepare address statement: %w", err) } defer addrStmt.Close() - insertStmt, err := ut.tx.Prepare(`INSERT INTO siacoin_elements (id, siacoin_value, merkle_proof, leaf_index, maturity_height, address_id) VALUES ($1, $2, $3, $4, $5, $6)`) + // ignore elements already in the database. + insertStmt, err := ut.tx.Prepare(`INSERT INTO siacoin_elements (id, siacoin_value, merkle_proof, leaf_index, maturity_height, address_id) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (id) DO NOTHING RETURNING id`) if err != nil { return fmt.Errorf("failed to prepare insert statement: %w", err) } defer insertStmt.Close() + balanceChanges := make(map[int64]wallet.Balance) for _, se := range elements { - var addressID int64 - err = addrStmt.QueryRow(encode(se.SiacoinOutput.Address), encode(types.ZeroCurrency), 0).Scan(&addressID) + addrRef, err := scanAddress(addrStmt.QueryRow(encode(se.SiacoinOutput.Address), encode(types.ZeroCurrency), 0)) if err != nil { return fmt.Errorf("failed to query address: %w", err) + } else if _, ok := balanceChanges[addrRef.ID]; !ok { + balanceChanges[addrRef.ID] = addrRef.Balance } - _, err = insertStmt.Exec(encode(se.ID), encode(se.SiacoinOutput.Value), encodeSlice(se.MerkleProof), se.LeafIndex, se.MaturityHeight, addressID) - if err != nil { + var dummyID types.Hash256 + err = insertStmt.QueryRow(encode(se.ID), encode(se.SiacoinOutput.Value), encodeSlice(se.MerkleProof), se.LeafIndex, se.MaturityHeight, addrRef.ID).Scan(decode(&dummyID)) + if errors.Is(err, sql.ErrNoRows) { + continue // skip if the element already exists + } else if err != nil { return fmt.Errorf("failed to execute statement: %w", err) } + + // update the balance if the element does not exist + balance := balanceChanges[addrRef.ID] + if se.MaturityHeight <= index.Height { + balance.Siacoins = balance.Siacoins.Add(se.SiacoinOutput.Value) + } else { + balance.ImmatureSiacoins = balance.ImmatureSiacoins.Add(se.SiacoinOutput.Value) + } + balanceChanges[addrRef.ID] = balance + } + + if len(balanceChanges) == 0 { + return nil + } + + updateAddressBalanceStmt, err := ut.tx.Prepare(`UPDATE sia_addresses SET siacoin_balance=$1, immature_siacoin_balance=$2 WHERE id=$3`) + if err != nil { + return fmt.Errorf("failed to prepare update balance statement: %w", err) + } + defer updateAddressBalanceStmt.Close() + + for addrID, balance := range balanceChanges { + res, err := updateAddressBalanceStmt.Exec(encode(balance.Siacoins), encode(balance.ImmatureSiacoins), addrID) + if err != nil { + return fmt.Errorf("failed to update balance: %w", err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n != 1 { + return fmt.Errorf("expected 1 row affected, got %v", n) + } } return nil } -func (ut *updateTx) RemoveSiacoinElements(elements []types.SiacoinOutputID) error { +func (ut *updateTx) RemoveSiacoinElements(elements []types.SiacoinElement, index types.ChainIndex) error { + if len(elements) == 0 { + return nil + } + + addrStmt, err := insertAddressStatement(ut.tx) + if err != nil { + return fmt.Errorf("failed to prepare address statement: %w", err) + } + defer addrStmt.Close() + stmt, err := ut.tx.Prepare(`DELETE FROM siacoin_elements WHERE id=$1 RETURNING id`) if err != nil { return fmt.Errorf("failed to prepare statement: %w", err) } defer stmt.Close() - for _, id := range elements { + balanceChanges := make(map[int64]wallet.Balance) + for _, se := range elements { + addrRef, err := scanAddress(addrStmt.QueryRow(encode(se.SiacoinOutput.Address), encode(types.ZeroCurrency), 0)) + if err != nil { + return fmt.Errorf("failed to query address: %w", err) + } else if _, ok := balanceChanges[addrRef.ID]; !ok { + balanceChanges[addrRef.ID] = addrRef.Balance + } + var dummy types.Hash256 - err := stmt.QueryRow(encode(id)).Scan(decode(&dummy)) + err = stmt.QueryRow(encode(se.ID)).Scan(decode(&dummy)) + if err != nil { + return fmt.Errorf("failed to delete element %q: %w", se.ID, err) + } + + balance := balanceChanges[addrRef.ID] + if se.MaturityHeight < index.Height { + balance.Siacoins = balance.Siacoins.Sub(se.SiacoinOutput.Value) + } else { + balance.ImmatureSiacoins = balance.ImmatureSiacoins.Sub(se.SiacoinOutput.Value) + } + balanceChanges[addrRef.ID] = balance + } + + if len(balanceChanges) == 0 { + return nil + } + + updateAddressBalanceStmt, err := ut.tx.Prepare(`UPDATE sia_addresses SET siacoin_balance=$1, immature_siacoin_balance=$2 WHERE id=$3`) + if err != nil { + return fmt.Errorf("failed to prepare update balance statement: %w", err) + } + defer updateAddressBalanceStmt.Close() + + for addrID, balance := range balanceChanges { + res, err := updateAddressBalanceStmt.Exec(encode(balance.Siacoins), encode(balance.ImmatureSiacoins), addrID) if err != nil { - return fmt.Errorf("failed to delete element %q: %w", id, err) + return fmt.Errorf("failed to update balance: %w", err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n != 1 { + return fmt.Errorf("expected 1 row affected, got %v", n) } } return nil } -func (ut *updateTx) AddSiafundElements(elements []types.SiafundElement) error { +func (ut *updateTx) AddSiafundElements(elements []types.SiafundElement, index types.ChainIndex) error { + if len(elements) == 0 { + return nil + } + addrStmt, err := insertAddressStatement(ut.tx) if err != nil { return fmt.Errorf("failed to prepare address statement: %w", err) } defer addrStmt.Close() - insertStmt, err := ut.tx.Prepare(`INSERT INTO siafund_elements (id, siafund_value, merkle_proof, leaf_index, claim_start, address_id) VALUES ($1, $2, $3, $4, $5, $6)`) + insertStmt, err := ut.tx.Prepare(`INSERT INTO siafund_elements (id, siafund_value, merkle_proof, leaf_index, claim_start, address_id) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (id) DO NOTHING RETURNING id`) if err != nil { return fmt.Errorf("failed to prepare statement: %w", err) } defer insertStmt.Close() + balanceChanges := make(map[types.Address]uint64) for _, se := range elements { - var addressID int64 - err = addrStmt.QueryRow(encode(se.SiafundOutput.Address), encode(types.ZeroCurrency), 0).Scan(&addressID) + addrRef, err := scanAddress(addrStmt.QueryRow(encode(se.SiafundOutput.Address), encode(types.ZeroCurrency), 0)) if err != nil { return fmt.Errorf("failed to query address: %w", err) + } else if _, ok := balanceChanges[se.SiafundOutput.Address]; !ok { + balanceChanges[se.SiafundOutput.Address] = addrRef.Balance.Siafunds } - _, err = insertStmt.Exec(encode(se.ID), se.SiafundOutput.Value, encodeSlice(se.MerkleProof), se.LeafIndex, encode(se.ClaimStart), addressID) - if err != nil { + var dummy types.Hash256 + err = insertStmt.QueryRow(encode(se.ID), se.SiafundOutput.Value, encodeSlice(se.MerkleProof), se.LeafIndex, encode(se.ClaimStart), addrRef.ID).Scan(decode(&dummy)) + if errors.Is(err, sql.ErrNoRows) { + continue // skip if the element already exists + } else if err != nil { return fmt.Errorf("failed to execute statement: %w", err) } + balanceChanges[se.SiafundOutput.Address] += se.SiafundOutput.Value + } + + if len(balanceChanges) == 0 { + return nil + } + + updateAddressBalanceStmt, err := ut.tx.Prepare(`UPDATE sia_addresses SET siafund_balance=$1 WHERE sia_address=$2`) + if err != nil { + return fmt.Errorf("failed to prepare update balance statement: %w", err) + } + defer updateAddressBalanceStmt.Close() + + for addr, balance := range balanceChanges { + res, err := updateAddressBalanceStmt.Exec(balance, encode(addr)) + if err != nil { + return fmt.Errorf("failed to update balance: %w", err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n != 1 { + return fmt.Errorf("expected 1 row affected, got %v", n) + } } return nil } -func (ut *updateTx) RemoveSiafundElements(elements []types.SiafundOutputID) error { +func (ut *updateTx) RemoveSiafundElements(elements []types.SiafundElement, index types.ChainIndex) error { + addrStmt, err := insertAddressStatement(ut.tx) + if err != nil { + return fmt.Errorf("failed to prepare address statement: %w", err) + } + defer addrStmt.Close() + stmt, err := ut.tx.Prepare(`DELETE FROM siacoin_elements WHERE id=$1 RETURNING id`) if err != nil { return fmt.Errorf("failed to prepare statement: %w", err) } defer stmt.Close() - for _, id := range elements { + balanceChanges := make(map[types.Address]uint64) + for _, se := range elements { + addrRef, err := scanAddress(addrStmt.QueryRow(encode(se.SiafundOutput.Address), encode(types.ZeroCurrency), 0)) + if err != nil { + return fmt.Errorf("failed to query address: %w", err) + } else if _, ok := balanceChanges[se.SiafundOutput.Address]; !ok { + balanceChanges[se.SiafundOutput.Address] = addrRef.Balance.Siafunds + } + var dummy types.Hash256 - err := stmt.QueryRow(encode(id)).Scan(decode(&dummy)) + err = stmt.QueryRow(encode(se.ID)).Scan(decode(&dummy)) if err != nil { - return fmt.Errorf("failed to delete element %q: %w", id, err) + return fmt.Errorf("failed to delete element %q: %w", se.ID, err) + } + + if balanceChanges[se.SiafundOutput.Address] < se.SiafundOutput.Value { + panic("siafund balance cannot be negative") + } + balanceChanges[se.SiafundOutput.Address] -= se.SiafundOutput.Value + } + + updateAddressBalanceStmt, err := ut.tx.Prepare(`UPDATE sia_addresses SET siafund_balance=$1 WHERE sia_address=$2`) + if err != nil { + return fmt.Errorf("failed to prepare update balance statement: %w", err) + } + defer updateAddressBalanceStmt.Close() + + for addr, balance := range balanceChanges { + res, err := updateAddressBalanceStmt.Exec(balance, encode(addr)) + if err != nil { + return fmt.Errorf("failed to update balance: %w", err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n != 1 { + return fmt.Errorf("expected 1 row affected, got %v", n) } } return nil } func (ut *updateTx) AddEvents(events []wallet.Event) error { + if len(events) == 0 { + return nil + } + indexStmt, err := insertIndexStmt(ut.tx) if err != nil { return fmt.Errorf("failed to prepare index statement: %w", err) } defer indexStmt.Close() - eventStmt, err := ut.tx.Prepare(`INSERT INTO events (event_id, maturity_height, date_created, index_id, event_type, event_data) VALUES ($1, $2, $3, $4, $5, $6) RETURNING id`) + insertEventStmt, err := ut.tx.Prepare(`INSERT INTO events (event_id, maturity_height, date_created, index_id, event_type, event_data) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (event_id) DO NOTHING RETURNING id`) if err != nil { return fmt.Errorf("failed to prepare event statement: %w", err) } - defer eventStmt.Close() + defer insertEventStmt.Close() - addrStmt, err := insertAddressStatement(ut.tx) + addrStmt, err := ut.tx.Prepare(`INSERT INTO sia_addresses (sia_address, siacoin_balance, immature_siacoin_balance, siafund_balance) VALUES ($1, $2, $3, 0) ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address RETURNING id`) if err != nil { return fmt.Errorf("failed to prepare address statement: %w", err) } @@ -296,8 +533,10 @@ func (ut *updateTx) AddEvents(events []wallet.Event) error { } var eventID int64 - err = eventStmt.QueryRow(encode(event.ID), event.MaturityHeight, encode(event.Timestamp), chainIndexID, event.Data.EventType(), buf.String()).Scan(&eventID) - if err != nil { + err = insertEventStmt.QueryRow(encode(event.ID), event.MaturityHeight, encode(event.Timestamp), chainIndexID, event.Data.EventType(), buf.String()).Scan(&eventID) + if errors.Is(err, sql.ErrNoRows) { + continue // skip if the event already exists + } else if err != nil { return fmt.Errorf("failed to add event: %w", err) } @@ -400,7 +639,7 @@ func setLastCommittedIndex(tx *txn, index types.ChainIndex) error { } func insertAddressStatement(tx *txn) (*stmt, error) { - return tx.Prepare(`INSERT INTO sia_addresses (sia_address, siacoin_balance, immature_siacoin_balance, siafund_balance) VALUES ($1, $2, $2, $3) ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address RETURNING id`) + return tx.Prepare(`INSERT INTO sia_addresses (sia_address, siacoin_balance, immature_siacoin_balance, siafund_balance) VALUES ($1, $2, $2, $3) ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address RETURNING id, siacoin_balance, immature_siacoin_balance, siafund_balance`) } func insertIndexStmt(tx *txn) (*stmt, error) { diff --git a/persist/sqlite/init.sql b/persist/sqlite/init.sql index 7d12cc0..6a9ee54 100644 --- a/persist/sqlite/init.sql +++ b/persist/sqlite/init.sql @@ -55,7 +55,7 @@ CREATE INDEX wallet_addresses_address_id ON wallet_addresses (address_id); CREATE TABLE events ( id INTEGER PRIMARY KEY, - event_id BLOB NOT NULL, + event_id BLOB UNIQUE NOT NULL, index_id BLOB NOT NULL REFERENCES chain_indices (id) ON DELETE CASCADE, maturity_height INTEGER NOT NULL, date_created INTEGER NOT NULL, diff --git a/persist/sqlite/wallet_test.go b/persist/sqlite/wallet_test.go index 5c1e960..08465f0 100644 --- a/persist/sqlite/wallet_test.go +++ b/persist/sqlite/wallet_test.go @@ -1,4 +1,4 @@ -package sqlite +package sqlite_test import ( "encoding/json" @@ -6,13 +6,16 @@ import ( "testing" "go.sia.tech/core/types" + "go.sia.tech/coreutils" + "go.sia.tech/coreutils/chain" + "go.sia.tech/walletd/persist/sqlite" "go.sia.tech/walletd/wallet" "go.uber.org/zap/zaptest" ) func TestWalletAddresses(t *testing.T) { log := zaptest.NewLogger(t) - db, err := OpenDatabase(filepath.Join(t.TempDir(), "walletd.sqlite3"), log.Named("sqlite3")) + db, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "walletd.sqlite3"), log.Named("sqlite3")) if err != nil { t.Fatal(err) } @@ -104,3 +107,116 @@ func TestWalletAddresses(t *testing.T) { t.Fatal("expected 0 addresses, got", len(addresses)) } } + +func TestResubscribe(t *testing.T) { + log := zaptest.NewLogger(t) + dir := t.TempDir() + db, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), log.Named("sqlite3")) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + bdb, err := coreutils.OpenBoltChainDB(filepath.Join(dir, "consensus.db")) + if err != nil { + t.Fatal(err) + } + defer bdb.Close() + + network, genesisBlock := testV1Network() + + store, genesisState, err := chain.NewDBStore(bdb, network, genesisBlock) + if err != nil { + t.Fatal(err) + } + defer store.Close() + + cm := chain.NewManager(store, genesisState) + + if err := cm.AddSubscriber(db, types.ChainIndex{}); err != nil { + t.Fatal(err) + } + + pk := types.GeneratePrivateKey() + addr := types.StandardUnlockHash(pk.PublicKey()) + + w, err := db.AddWallet(wallet.Wallet{Name: "test"}) + if err != nil { + t.Fatal(err) + } else if err := db.AddWalletAddress(w.ID, wallet.Address{Address: addr}); err != nil { + t.Fatal(err) + } + + expectedPayout := cm.TipState().BlockReward() + maturityHeight := cm.TipState().MaturityHeight() + // mine a block sending the payout to the wallet + if err := cm.AddBlocks([]types.Block{mineBlock(cm.TipState(), nil, addr)}); err != nil { + t.Fatal(err) + } + + // check that the payout was received + balance, err := db.WalletBalance(w.ID) + if err != nil { + t.Fatal(err) + } else if !balance.ImmatureSiacoins.Equals(expectedPayout) { + t.Fatalf("expected %v, got %v", expectedPayout, balance.ImmatureSiacoins) + } + + // check that a payout event was recorded + events, err := db.WalletEvents(w.ID, 0, 100) + if err != nil { + t.Fatal(err) + } else if len(events) != 1 { + t.Fatalf("expected 1 event, got %v", len(events)) + } else if events[0].Data.EventType() != wallet.EventTypeMinerPayout { + t.Fatalf("expected payout event, got %v", events[0].Data.EventType()) + } + + // check that the utxo was created + utxos, err := db.WalletSiacoinOutputs(w.ID, 0, 100) + if err != nil { + t.Fatal(err) + } else if len(utxos) != 1 { + t.Fatalf("expected 1 output, got %v", len(utxos)) + } else if utxos[0].SiacoinOutput.Value.Cmp(expectedPayout) != 0 { + t.Fatalf("expected %v, got %v", expectedPayout, utxos[0].SiacoinOutput.Value) + } else if utxos[0].MaturityHeight != maturityHeight { + t.Fatalf("expected %v, got %v", maturityHeight, utxos[0].MaturityHeight) + } + + cm.RemoveSubscriber(db) + if err := cm.AddSubscriber(db, types.ChainIndex{}); err != nil { + t.Fatal(err) + } + + // check that the balance, events, and utxos did not change + // check that the payout was received + balance, err = db.WalletBalance(w.ID) + if err != nil { + t.Fatal(err) + } else if !balance.ImmatureSiacoins.Equals(expectedPayout) { + t.Fatalf("expected %v, got %v", expectedPayout, balance.ImmatureSiacoins) + } + + // check that a payout event was recorded + events, err = db.WalletEvents(w.ID, 0, 100) + if err != nil { + t.Fatal(err) + } else if len(events) != 1 { + t.Fatalf("expected 1 event, got %v", len(events)) + } else if events[0].Data.EventType() != wallet.EventTypeMinerPayout { + t.Fatalf("expected payout event, got %v", events[0].Data.EventType()) + } + + // check that the utxo was created + utxos, err = db.WalletSiacoinOutputs(w.ID, 0, 100) + if err != nil { + t.Fatal(err) + } else if len(utxos) != 1 { + t.Fatalf("expected 1 output, got %v", len(utxos)) + } else if utxos[0].SiacoinOutput.Value.Cmp(expectedPayout) != 0 { + t.Fatalf("expected %v, got %v", expectedPayout, utxos[0].SiacoinOutput.Value) + } else if utxos[0].MaturityHeight != maturityHeight { + t.Fatalf("expected %v, got %v", maturityHeight, utxos[0].MaturityHeight) + } +} diff --git a/wallet/update.go b/wallet/update.go index f9a140a..2890d3c 100644 --- a/wallet/update.go +++ b/wallet/update.go @@ -22,23 +22,20 @@ type ( SiafundStateElements() ([]types.StateElement, error) UpdateSiafundStateElements([]types.StateElement) error - AddSiacoinElements([]types.SiacoinElement) error - RemoveSiacoinElements([]types.SiacoinOutputID) error + AddSiacoinElements([]types.SiacoinElement, types.ChainIndex) error + RemoveSiacoinElements([]types.SiacoinElement, types.ChainIndex) error - AddSiafundElements([]types.SiafundElement) error - RemoveSiafundElements([]types.SiafundOutputID) error - - MaturedSiacoinElements(types.ChainIndex) ([]types.SiacoinElement, error) + AddSiafundElements([]types.SiafundElement, types.ChainIndex) error + RemoveSiafundElements([]types.SiafundElement, types.ChainIndex) error AddressRelevant(types.Address) (bool, error) - AddressBalance(types.Address) (Balance, error) - UpdateBalances([]AddressBalance) error } // An ApplyTx atomically applies a set of updates to a store. ApplyTx interface { UpdateTx + ApplyMatureSiacoinBalance(types.ChainIndex) error AddEvents([]Event) error } @@ -46,59 +43,18 @@ type ( RevertTx interface { UpdateTx + RevertMatureSiacoinBalance(types.ChainIndex) error RevertEvents(index types.ChainIndex) error } ) // ApplyChainUpdates atomically applies a set of chain updates to a store func ApplyChainUpdates(tx ApplyTx, updates []*chain.ApplyUpdate) error { - var events []Event - balances := make(map[types.Address]Balance) - newSiacoinElements := make(map[types.SiacoinOutputID]types.SiacoinElement) - newSiafundElements := make(map[types.SiafundOutputID]types.SiafundElement) - spentSiacoinElements := make(map[types.SiacoinOutputID]bool) - spentSiafundElements := make(map[types.SiafundOutputID]bool) - - updateBalance := func(addr types.Address, fn func(b *Balance)) error { - balance, ok := balances[addr] - if !ok { - var err error - balance, err = tx.AddressBalance(addr) - if err != nil { - return fmt.Errorf("failed to get address balance: %w", err) - } - } - - fn(&balance) - balances[addr] = balance - return nil - } - - // fetch all siacoin and siafund state elements - siacoinStateElements, err := tx.SiacoinStateElements() - if err != nil { - return fmt.Errorf("failed to get siacoin state elements: %w", err) - } - siafundStateElements, err := tx.SiafundStateElements() - if err != nil { - return fmt.Errorf("failed to get siafund state elements: %w", err) - } - for _, cau := range updates { // update the immature balance of each relevant address - matured, err := tx.MaturedSiacoinElements(cau.State.Index) - if err != nil { + if err := tx.ApplyMatureSiacoinBalance(cau.State.Index); err != nil { return fmt.Errorf("failed to get matured siacoin elements: %w", err) } - for _, se := range matured { - err := updateBalance(se.SiacoinOutput.Address, func(b *Balance) { - b.ImmatureSiacoins = b.ImmatureSiacoins.Sub(se.SiacoinOutput.Value) - b.Siacoins = b.Siacoins.Add(se.SiacoinOutput.Value) - }) - if err != nil { - return fmt.Errorf("failed to update address balance: %w", err) - } - } // determine which siacoin and siafund elements are ephemeral // @@ -122,87 +78,58 @@ func ApplyChainUpdates(tx ApplyTx, updates []*chain.ApplyUpdate) error { } // add new siacoin elements to the store - var siacoinElementErr error + var newSiacoinElements, spentSiacoinElements []types.SiacoinElement cau.ForEachSiacoinElement(func(se types.SiacoinElement, spent bool) { - if siacoinElementErr != nil { - return - } else if ephemeral[se.ID] { + if ephemeral[se.ID] { return } relevant, err := tx.AddressRelevant(se.SiacoinOutput.Address) if err != nil { - siacoinElementErr = fmt.Errorf("failed to check if address is relevant: %w", err) - return + panic(err) } else if !relevant { return } if spent { - delete(newSiacoinElements, types.SiacoinOutputID(se.ID)) - spentSiacoinElements[types.SiacoinOutputID(se.ID)] = true + spentSiacoinElements = append(spentSiacoinElements, se) } else { - newSiacoinElements[types.SiacoinOutputID(se.ID)] = se - } - - err = updateBalance(se.SiacoinOutput.Address, func(b *Balance) { - switch { - case se.MaturityHeight > cau.State.Index.Height: - b.ImmatureSiacoins = b.ImmatureSiacoins.Add(se.SiacoinOutput.Value) - case spent: - b.Siacoins = b.Siacoins.Sub(se.SiacoinOutput.Value) - default: - b.Siacoins = b.Siacoins.Add(se.SiacoinOutput.Value) - } - }) - if err != nil { - siacoinElementErr = fmt.Errorf("failed to update address balance: %w", err) - return + newSiacoinElements = append(newSiacoinElements, se) } }) - if siacoinElementErr != nil { - return fmt.Errorf("failed to add siacoin elements: %w", siacoinElementErr) + + if err := tx.AddSiacoinElements(newSiacoinElements, cau.State.Index); err != nil { + return fmt.Errorf("failed to add siacoin elements: %w", err) + } else if err := tx.RemoveSiacoinElements(spentSiacoinElements, cau.State.Index); err != nil { + return fmt.Errorf("failed to remove siacoin elements: %w", err) } - var siafundElementErr error + var newSiafundElements, spentSiafundElements []types.SiafundElement cau.ForEachSiafundElement(func(se types.SiafundElement, spent bool) { - if siafundElementErr != nil { - return - } else if ephemeral[se.ID] { + if ephemeral[se.ID] { return } relevant, err := tx.AddressRelevant(se.SiafundOutput.Address) if err != nil { - siafundElementErr = fmt.Errorf("failed to check if address is relevant: %w", err) - return + panic(err) } else if !relevant { return } if spent { - delete(newSiafundElements, types.SiafundOutputID(se.ID)) - spentSiafundElements[types.SiafundOutputID(se.ID)] = true + spentSiafundElements = append(spentSiafundElements, se) } else { - newSiafundElements[types.SiafundOutputID(se.ID)] = se - } - - err = updateBalance(se.SiafundOutput.Address, func(b *Balance) { - if spent { - if b.Siafunds < se.SiafundOutput.Value { - panic(fmt.Errorf("negative siafund balance")) - } - b.Siafunds -= se.SiafundOutput.Value - } else { - b.Siafunds += se.SiafundOutput.Value - } - }) - if err != nil { - siafundElementErr = fmt.Errorf("failed to update address balance: %w", err) - return + newSiafundElements = append(newSiafundElements, se) } }) + if err := tx.AddSiafundElements(newSiafundElements, cau.State.Index); err != nil { + return fmt.Errorf("failed to add siafund elements: %w", err) + } else if err := tx.RemoveSiafundElements(spentSiafundElements, cau.State.Index); err != nil { + return fmt.Errorf("failed to remove siafund elements: %w", err) + } + // add events relevant := func(addr types.Address) bool { relevant, err := tx.AddressRelevant(addr) @@ -211,135 +138,44 @@ func ApplyChainUpdates(tx ApplyTx, updates []*chain.ApplyUpdate) error { } return relevant } + if err := tx.AddEvents(AppliedEvents(cau.State, cau.Block, cau, relevant)); err != nil { + return fmt.Errorf("failed to add events: %w", err) + } + + // fetch all siacoin and siafund state elements + siacoinStateElements, err := tx.SiacoinStateElements() if err != nil { - return fmt.Errorf("failed to get applied events: %w", err) + return fmt.Errorf("failed to get siacoin state elements: %w", err) } - events = append(events, AppliedEvents(cau.State, cau.Block, cau, relevant)...) // update siacoin element proofs - for id := range newSiacoinElements { - ele := newSiacoinElements[id] - cau.UpdateElementProof(&ele.StateElement) - newSiacoinElements[id] = ele - } for i := range siacoinStateElements { cau.UpdateElementProof(&siacoinStateElements[i]) } - // update siafund element proofs - for id := range newSiafundElements { - ele := newSiafundElements[id] - cau.UpdateElementProof(&ele.StateElement) - newSiafundElements[id] = ele + if err := tx.UpdateSiacoinStateElements(siacoinStateElements); err != nil { + return fmt.Errorf("failed to update siacoin state elements: %w", err) } - for i := range siafundStateElements { - cau.UpdateElementProof(&siafundStateElements[i]) - } - } - - // update the address balances - balanceChanges := make([]AddressBalance, 0, len(balances)) - for addr, balance := range balances { - balanceChanges = append(balanceChanges, AddressBalance{ - Address: addr, - Balance: balance, - }) - } - if err = tx.UpdateBalances(balanceChanges); err != nil { - return fmt.Errorf("failed to update address balance: %w", err) - } - // add the new siacoin elements - siacoinElements := make([]types.SiacoinElement, 0, len(newSiacoinElements)) - for _, ele := range newSiacoinElements { - siacoinElements = append(siacoinElements, ele) - } - if err = tx.AddSiacoinElements(siacoinElements); err != nil { - return fmt.Errorf("failed to add siacoin elements: %w", err) - } - - // remove the spent siacoin elements - siacoinOutputIDs := make([]types.SiacoinOutputID, 0, len(spentSiacoinElements)) - for id := range spentSiacoinElements { - siacoinOutputIDs = append(siacoinOutputIDs, id) - } - if err = tx.RemoveSiacoinElements(siacoinOutputIDs); err != nil { - return fmt.Errorf("failed to remove siacoin elements: %w", err) - } - - // add the new siafund elements - siafundElements := make([]types.SiafundElement, 0, len(newSiafundElements)) - for _, ele := range newSiafundElements { - siafundElements = append(siafundElements, ele) - } - if err = tx.AddSiafundElements(siafundElements); err != nil { - return fmt.Errorf("failed to add siafund elements: %w", err) - } - - // remove the spent siafund elements - siafundOutputIDs := make([]types.SiafundOutputID, 0, len(spentSiafundElements)) - for id := range spentSiafundElements { - siafundOutputIDs = append(siafundOutputIDs, id) - } - if err = tx.RemoveSiafundElements(siafundOutputIDs); err != nil { - return fmt.Errorf("failed to remove siafund elements: %w", err) - } - - // add new events - if err = tx.AddEvents(events); err != nil { - return fmt.Errorf("failed to add events: %w", err) - } + siafundStateElements, err := tx.SiafundStateElements() + if err != nil { + return fmt.Errorf("failed to get siafund state elements: %w", err) + } - // update the siacoin state elements - filteredStateElements := siacoinStateElements[:0] - for _, se := range siacoinStateElements { - if _, ok := spentSiacoinElements[types.SiacoinOutputID(se.ID)]; !ok { - filteredStateElements = append(filteredStateElements, se) + // update siafund element proofs + for i := range siafundStateElements { + cau.UpdateElementProof(&siafundStateElements[i]) } - } - err = tx.UpdateSiacoinStateElements(filteredStateElements) - if err != nil { - return fmt.Errorf("failed to update siacoin state elements: %w", err) - } - // update the siafund state elements - filteredStateElements = siafundStateElements[:0] - for _, se := range siafundStateElements { - if _, ok := spentSiafundElements[types.SiafundOutputID(se.ID)]; !ok { - filteredStateElements = append(filteredStateElements, se) + if err := tx.UpdateSiafundStateElements(siafundStateElements); err != nil { + return fmt.Errorf("failed to update siacoin state elements: %w", err) } } - if err = tx.UpdateSiafundStateElements(filteredStateElements); err != nil { - return fmt.Errorf("failed to update siafund state elements: %w", err) - } - return nil } // RevertChainUpdate atomically reverts a chain update from a store func RevertChainUpdate(tx RevertTx, cru *chain.RevertUpdate) error { - balances := make(map[types.Address]Balance) - - var deletedSiacoinElements []types.SiacoinOutputID - var addedSiacoinElements []types.SiacoinElement - var deletedSiafundElements []types.SiafundOutputID - var addedSiafundElements []types.SiafundElement - - updateBalance := func(addr types.Address, fn func(b *Balance)) error { - balance, ok := balances[addr] - if !ok { - var err error - balance, err = tx.AddressBalance(addr) - if err != nil { - return fmt.Errorf("failed to get address balance: %w", err) - } - } - - fn(&balance) - balances[addr] = balance - return nil - } - // determine which siacoin and siafund elements are ephemeral // // note: I thought we could use LeafIndex == EphemeralLeafIndex, but @@ -367,34 +203,17 @@ func RevertChainUpdate(tx RevertTx, cru *chain.RevertUpdate) error { ID: cru.Block.ID(), } - matured, err := tx.MaturedSiacoinElements(revertedIndex) - if err != nil { - return fmt.Errorf("failed to get matured siacoin elements: %w", err) - } - for _, se := range matured { - err := updateBalance(se.SiacoinOutput.Address, func(b *Balance) { - b.ImmatureSiacoins = b.ImmatureSiacoins.Add(se.SiacoinOutput.Value) - b.Siacoins = b.Siacoins.Sub(se.SiacoinOutput.Value) - }) - if err != nil { - return fmt.Errorf("failed to update address balance: %w", err) - } - } - - var siacoinElementErr error + var removedSiacoinElements, addedSiacoinElements []types.SiacoinElement cru.ForEachSiacoinElement(func(se types.SiacoinElement, spent bool) { - if siacoinElementErr != nil { + if ephemeral[se.ID] { return } relevant, err := tx.AddressRelevant(se.SiacoinOutput.Address) if err != nil { - siacoinElementErr = fmt.Errorf("failed to check if address is relevant: %w", err) - return + panic(err) } else if !relevant { return - } else if ephemeral[se.ID] { - return } if spent { @@ -402,38 +221,27 @@ func RevertChainUpdate(tx RevertTx, cru *chain.RevertUpdate) error { addedSiacoinElements = append(addedSiacoinElements, se) } else { // delete any created siacoin elements - deletedSiacoinElements = append(deletedSiacoinElements, types.SiacoinOutputID(se.ID)) + removedSiacoinElements = append(removedSiacoinElements, se) } - - siacoinElementErr = updateBalance(se.SiacoinOutput.Address, func(b *Balance) { - switch { - case se.MaturityHeight > cru.State.Index.Height: - b.ImmatureSiacoins = b.ImmatureSiacoins.Sub(se.SiacoinOutput.Value) - case spent: - b.Siacoins = b.Siacoins.Add(se.SiacoinOutput.Value) - default: - b.Siacoins = b.Siacoins.Sub(se.SiacoinOutput.Value) - } - }) }) - if siacoinElementErr != nil { - return fmt.Errorf("failed to update address balance: %w", siacoinElementErr) + + if err := tx.AddSiacoinElements(addedSiacoinElements, revertedIndex); err != nil { + return fmt.Errorf("failed to add siacoin elements: %w", err) + } else if err := tx.RemoveSiacoinElements(removedSiacoinElements, revertedIndex); err != nil { + return fmt.Errorf("failed to remove siacoin elements: %w", err) } - var siafundElementErr error + var removedSiafundElements, addedSiafundElements []types.SiafundElement cru.ForEachSiafundElement(func(se types.SiafundElement, spent bool) { - if siafundElementErr != nil { + if ephemeral[se.ID] { return } relevant, err := tx.AddressRelevant(se.SiafundOutput.Address) if err != nil { - siacoinElementErr = fmt.Errorf("failed to check if address is relevant: %w", err) - return + panic(err) } else if !relevant { return - } else if ephemeral[se.ID] { - return } if spent { @@ -441,40 +249,22 @@ func RevertChainUpdate(tx RevertTx, cru *chain.RevertUpdate) error { addedSiafundElements = append(addedSiafundElements, se) } else { // delete any created siafund elements - deletedSiafundElements = append(deletedSiafundElements, types.SiafundOutputID(se.ID)) + removedSiafundElements = append(removedSiafundElements, se) } - - siafundElementErr = updateBalance(se.SiafundOutput.Address, func(b *Balance) { - if spent { - b.Siafunds += se.SiafundOutput.Value - } else { - b.Siafunds -= se.SiafundOutput.Value - } - }) }) - if siafundElementErr != nil { - return fmt.Errorf("failed to update address balance: %w", siafundElementErr) - } - balanceChanges := make([]AddressBalance, 0, len(balances)) - for addr, balance := range balances { - balanceChanges = append(balanceChanges, AddressBalance{ - Address: addr, - Balance: balance, - }) - } - if err := tx.UpdateBalances(balanceChanges); err != nil { - return fmt.Errorf("failed to update address balance: %w", err) + // revert siafund element changes + if err := tx.AddSiafundElements(addedSiafundElements, revertedIndex); err != nil { + return fmt.Errorf("failed to add siafund elements: %w", err) + } else if err := tx.RemoveSiafundElements(removedSiafundElements, revertedIndex); err != nil { + return fmt.Errorf("failed to remove siafund elements: %w", err) } - // revert siacoin element changes - if err := tx.AddSiacoinElements(addedSiacoinElements); err != nil { - return fmt.Errorf("failed to add siacoin elements: %w", err) - } else if err := tx.RemoveSiacoinElements(deletedSiacoinElements); err != nil { - return fmt.Errorf("failed to remove siacoin elements: %w", err) + // revert mature siacoin balance for each relevant address + if err := tx.RevertMatureSiacoinBalance(revertedIndex); err != nil { + return fmt.Errorf("failed to get matured siacoin elements: %w", err) } - // update siacoin element proofs siacoinElements, err := tx.SiacoinStateElements() if err != nil { return fmt.Errorf("failed to get siacoin state elements: %w", err) @@ -482,12 +272,8 @@ func RevertChainUpdate(tx RevertTx, cru *chain.RevertUpdate) error { for i := range siacoinElements { cru.UpdateElementProof(&siacoinElements[i]) } - - // revert siafund element changes - if err := tx.AddSiafundElements(addedSiafundElements); err != nil { - return fmt.Errorf("failed to add siafund elements: %w", err) - } else if err := tx.RemoveSiafundElements(deletedSiafundElements); err != nil { - return fmt.Errorf("failed to remove siafund elements: %w", err) + if err := tx.UpdateSiacoinStateElements(siacoinElements); err != nil { + return fmt.Errorf("failed to update siacoin state elements: %w", err) } // update siafund element proofs @@ -498,6 +284,10 @@ func RevertChainUpdate(tx RevertTx, cru *chain.RevertUpdate) error { for i := range siafundElements { cru.UpdateElementProof(&siafundElements[i]) } + if err := tx.UpdateSiafundStateElements(siafundElements); err != nil { + return fmt.Errorf("failed to update siafund state elements: %w", err) + } + // revert events return tx.RevertEvents(revertedIndex) } From b9279f998068a2e55c53340533df15c7b29eb136 Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Tue, 27 Feb 2024 16:00:09 -0800 Subject: [PATCH 2/2] sqlite: set siafund address, fix siafund balance update --- persist/sqlite/consensus.go | 34 ++++++++++++++++++-------------- persist/sqlite/consensus_test.go | 30 +++++++++++++++------------- persist/sqlite/wallet_test.go | 8 ++++---- 3 files changed, 39 insertions(+), 33 deletions(-) diff --git a/persist/sqlite/consensus.go b/persist/sqlite/consensus.go index 9d22d37..7d96a4c 100644 --- a/persist/sqlite/consensus.go +++ b/persist/sqlite/consensus.go @@ -394,13 +394,13 @@ func (ut *updateTx) AddSiafundElements(elements []types.SiafundElement, index ty } defer insertStmt.Close() - balanceChanges := make(map[types.Address]uint64) + balanceChanges := make(map[int64]uint64) for _, se := range elements { addrRef, err := scanAddress(addrStmt.QueryRow(encode(se.SiafundOutput.Address), encode(types.ZeroCurrency), 0)) if err != nil { return fmt.Errorf("failed to query address: %w", err) - } else if _, ok := balanceChanges[se.SiafundOutput.Address]; !ok { - balanceChanges[se.SiafundOutput.Address] = addrRef.Balance.Siafunds + } else if _, ok := balanceChanges[addrRef.ID]; !ok { + balanceChanges[addrRef.ID] = addrRef.Balance.Siafunds } var dummy types.Hash256 @@ -410,21 +410,21 @@ func (ut *updateTx) AddSiafundElements(elements []types.SiafundElement, index ty } else if err != nil { return fmt.Errorf("failed to execute statement: %w", err) } - balanceChanges[se.SiafundOutput.Address] += se.SiafundOutput.Value + balanceChanges[addrRef.ID] += se.SiafundOutput.Value } if len(balanceChanges) == 0 { return nil } - updateAddressBalanceStmt, err := ut.tx.Prepare(`UPDATE sia_addresses SET siafund_balance=$1 WHERE sia_address=$2`) + updateAddressBalanceStmt, err := ut.tx.Prepare(`UPDATE sia_addresses SET siafund_balance=$1 WHERE id=$2`) if err != nil { return fmt.Errorf("failed to prepare update balance statement: %w", err) } defer updateAddressBalanceStmt.Close() - for addr, balance := range balanceChanges { - res, err := updateAddressBalanceStmt.Exec(balance, encode(addr)) + for addrID, balance := range balanceChanges { + res, err := updateAddressBalanceStmt.Exec(balance, addrID) if err != nil { return fmt.Errorf("failed to update balance: %w", err) } else if n, err := res.RowsAffected(); err != nil { @@ -449,13 +449,13 @@ func (ut *updateTx) RemoveSiafundElements(elements []types.SiafundElement, index } defer stmt.Close() - balanceChanges := make(map[types.Address]uint64) + balanceChanges := make(map[int64]uint64) for _, se := range elements { addrRef, err := scanAddress(addrStmt.QueryRow(encode(se.SiafundOutput.Address), encode(types.ZeroCurrency), 0)) if err != nil { return fmt.Errorf("failed to query address: %w", err) - } else if _, ok := balanceChanges[se.SiafundOutput.Address]; !ok { - balanceChanges[se.SiafundOutput.Address] = addrRef.Balance.Siafunds + } else if _, ok := balanceChanges[addrRef.ID]; !ok { + balanceChanges[addrRef.ID] = addrRef.Balance.Siafunds } var dummy types.Hash256 @@ -464,20 +464,24 @@ func (ut *updateTx) RemoveSiafundElements(elements []types.SiafundElement, index return fmt.Errorf("failed to delete element %q: %w", se.ID, err) } - if balanceChanges[se.SiafundOutput.Address] < se.SiafundOutput.Value { + if balanceChanges[addrRef.ID] < se.SiafundOutput.Value { panic("siafund balance cannot be negative") } - balanceChanges[se.SiafundOutput.Address] -= se.SiafundOutput.Value + balanceChanges[addrRef.ID] -= se.SiafundOutput.Value + } + + if len(balanceChanges) == 0 { + return nil } - updateAddressBalanceStmt, err := ut.tx.Prepare(`UPDATE sia_addresses SET siafund_balance=$1 WHERE sia_address=$2`) + updateAddressBalanceStmt, err := ut.tx.Prepare(`UPDATE sia_addresses SET siafund_balance=$1 WHERE id=$2`) if err != nil { return fmt.Errorf("failed to prepare update balance statement: %w", err) } defer updateAddressBalanceStmt.Close() - for addr, balance := range balanceChanges { - res, err := updateAddressBalanceStmt.Exec(balance, encode(addr)) + for addrID, balance := range balanceChanges { + res, err := updateAddressBalanceStmt.Exec(balance, addrID) if err != nil { return fmt.Errorf("failed to update balance: %w", err) } else if n, err := res.RowsAffected(); err != nil { diff --git a/persist/sqlite/consensus_test.go b/persist/sqlite/consensus_test.go index 8462428..1892c57 100644 --- a/persist/sqlite/consensus_test.go +++ b/persist/sqlite/consensus_test.go @@ -13,9 +13,10 @@ import ( "go.uber.org/zap/zaptest" ) -func testV1Network() (*consensus.Network, types.Block) { +func testV1Network(siafundAddr types.Address) (*consensus.Network, types.Block) { // use a modified version of Zen n, genesisBlock := chain.TestnetZen() + genesisBlock.Transactions[0].SiafundOutputs[0].Address = siafundAddr n.InitialTarget = types.BlockID{0xFF} n.HardforkDevAddr.Height = 1 n.HardforkTax.Height = 1 @@ -28,9 +29,10 @@ func testV1Network() (*consensus.Network, types.Block) { return n, genesisBlock } -func testV2Network() (*consensus.Network, types.Block) { +func testV2Network(siafundAddr types.Address) (*consensus.Network, types.Block) { // use a modified version of Zen n, genesisBlock := chain.TestnetZen() + genesisBlock.Transactions[0].SiafundOutputs[0].Address = siafundAddr n.InitialTarget = types.BlockID{0xFF} n.HardforkDevAddr.Height = 1 n.HardforkTax.Height = 1 @@ -75,6 +77,9 @@ func mineV2Block(state consensus.State, txns []types.V2Transaction, minerAddr ty } func TestReorg(t *testing.T) { + pk := types.GeneratePrivateKey() + addr := types.StandardUnlockHash(pk.PublicKey()) + log := zaptest.NewLogger(t) dir := t.TempDir() db, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), log.Named("sqlite3")) @@ -89,7 +94,7 @@ func TestReorg(t *testing.T) { } defer bdb.Close() - network, genesisBlock := testV1Network() + network, genesisBlock := testV1Network(addr) store, genesisState, err := chain.NewDBStore(bdb, network, genesisBlock) if err != nil { @@ -103,9 +108,6 @@ func TestReorg(t *testing.T) { t.Fatal(err) } - pk := types.GeneratePrivateKey() - addr := types.StandardUnlockHash(pk.PublicKey()) - w, err := db.AddWallet(wallet.Wallet{Name: "test"}) if err != nil { t.Fatal(err) @@ -280,6 +282,9 @@ func TestReorg(t *testing.T) { } func TestEphemeralBalance(t *testing.T) { + pk := types.GeneratePrivateKey() + addr := types.StandardUnlockHash(pk.PublicKey()) + log := zaptest.NewLogger(t) dir := t.TempDir() db, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), log.Named("sqlite3")) @@ -294,7 +299,7 @@ func TestEphemeralBalance(t *testing.T) { } defer bdb.Close() - network, genesisBlock := testV1Network() + network, genesisBlock := testV1Network(addr) store, genesisState, err := chain.NewDBStore(bdb, network, genesisBlock) if err != nil { @@ -308,9 +313,6 @@ func TestEphemeralBalance(t *testing.T) { t.Fatal(err) } - pk := types.GeneratePrivateKey() - addr := types.StandardUnlockHash(pk.PublicKey()) - w, err := db.AddWallet(wallet.Wallet{Name: "test"}) if err != nil { t.Fatal(err) @@ -475,6 +477,9 @@ func TestEphemeralBalance(t *testing.T) { } func TestV2(t *testing.T) { + pk := types.GeneratePrivateKey() + addr := types.StandardUnlockHash(pk.PublicKey()) + log := zaptest.NewLogger(t) dir := t.TempDir() db, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), log.Named("sqlite3")) @@ -489,7 +494,7 @@ func TestV2(t *testing.T) { } defer bdb.Close() - network, genesisBlock := testV2Network() + network, genesisBlock := testV2Network(addr) store, genesisState, err := chain.NewDBStore(bdb, network, genesisBlock) if err != nil { @@ -503,9 +508,6 @@ func TestV2(t *testing.T) { t.Fatal(err) } - pk := types.GeneratePrivateKey() - addr := types.StandardUnlockHash(pk.PublicKey()) - w, err := db.AddWallet(wallet.Wallet{Name: "test"}) if err != nil { t.Fatal(err) diff --git a/persist/sqlite/wallet_test.go b/persist/sqlite/wallet_test.go index 08465f0..20db201 100644 --- a/persist/sqlite/wallet_test.go +++ b/persist/sqlite/wallet_test.go @@ -109,6 +109,9 @@ func TestWalletAddresses(t *testing.T) { } func TestResubscribe(t *testing.T) { + pk := types.GeneratePrivateKey() + addr := types.StandardUnlockHash(pk.PublicKey()) + log := zaptest.NewLogger(t) dir := t.TempDir() db, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), log.Named("sqlite3")) @@ -123,7 +126,7 @@ func TestResubscribe(t *testing.T) { } defer bdb.Close() - network, genesisBlock := testV1Network() + network, genesisBlock := testV1Network(types.VoidAddress) store, genesisState, err := chain.NewDBStore(bdb, network, genesisBlock) if err != nil { @@ -137,9 +140,6 @@ func TestResubscribe(t *testing.T) { t.Fatal(err) } - pk := types.GeneratePrivateKey() - addr := types.StandardUnlockHash(pk.PublicKey()) - w, err := db.AddWallet(wallet.Wallet{Name: "test"}) if err != nil { t.Fatal(err)