Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix reorg balance #50

Merged
merged 5 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
28 changes: 16 additions & 12 deletions persist/sqlite/consensus.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,20 +170,20 @@ func (ut *updateTx) AddSiacoinElements(elements []types.SiacoinElement) error {
}
defer addrStmt.Close()

inserStmt, err := ut.tx.Prepare(`INSERT INTO siacoin_elements (id, siacoin_value, merkle_proof, leaf_index, maturity_height, address_id) VALUES ($1, $2, $3, $4, $5, $6)`)
insertStmt, err := ut.tx.Prepare(`INSERT INTO siacoin_elements (id, siacoin_value, merkle_proof, leaf_index, maturity_height, address_id) VALUES ($1, $2, $3, $4, $5, $6)`)
if err != nil {
return fmt.Errorf("failed to prepare insert statement: %w", err)
}
defer inserStmt.Close()
defer insertStmt.Close()

for _, se := range elements {
var addressID int64
err := addrStmt.QueryRow(encode(se.SiacoinOutput.Address), encode(types.ZeroCurrency), 0).Scan(&addressID)
err = addrStmt.QueryRow(encode(se.SiacoinOutput.Address), encode(types.ZeroCurrency), 0).Scan(&addressID)
if err != nil {
return fmt.Errorf("failed to query address: %w", err)
}

_, err = inserStmt.Exec(encode(se.ID), encode(se.SiacoinOutput.Value), encodeSlice(se.MerkleProof), se.LeafIndex, se.MaturityHeight, addressID)
_, err = insertStmt.Exec(encode(se.ID), encode(se.SiacoinOutput.Value), encodeSlice(se.MerkleProof), se.LeafIndex, se.MaturityHeight, addressID)
if err != nil {
return fmt.Errorf("failed to execute statement: %w", err)
}
Expand Down Expand Up @@ -215,20 +215,20 @@ func (ut *updateTx) AddSiafundElements(elements []types.SiafundElement) error {
}
defer addrStmt.Close()

inserStmt, err := ut.tx.Prepare(`INSERT INTO siafund_elements (id, siafund_value, merkle_proof, leaf_index, claim_start, address_id) VALUES ($1, $2, $3, $4, $5, $6)`)
insertStmt, err := ut.tx.Prepare(`INSERT INTO siafund_elements (id, siafund_value, merkle_proof, leaf_index, claim_start, address_id) VALUES ($1, $2, $3, $4, $5, $6)`)
if err != nil {
return fmt.Errorf("failed to prepare statement: %w", err)
}
defer inserStmt.Close()
defer insertStmt.Close()

for _, se := range elements {
var addressID int64
err := addrStmt.QueryRow(encode(se.SiafundOutput.Address), encode(types.ZeroCurrency), 0).Scan(&addressID)
err = addrStmt.QueryRow(encode(se.SiafundOutput.Address), encode(types.ZeroCurrency), 0).Scan(&addressID)
if err != nil {
return fmt.Errorf("failed to query address: %w", err)
}

_, err = inserStmt.Exec(encode(se.ID), se.SiafundOutput.Value, encodeSlice(se.MerkleProof), se.LeafIndex, encode(se.ClaimStart), addressID)
_, err = insertStmt.Exec(encode(se.ID), se.SiafundOutput.Value, encodeSlice(se.MerkleProof), se.LeafIndex, encode(se.ClaimStart), addressID)
if err != nil {
return fmt.Errorf("failed to execute statement: %w", err)
}
Expand All @@ -254,7 +254,7 @@ func (ut *updateTx) RemoveSiafundElements(elements []types.SiafundOutputID) erro
}

func (ut *updateTx) AddEvents(events []wallet.Event) error {
indexStmt, err := ut.tx.Prepare(`INSERT INTO chain_indices (height, block_id) VALUES ($1, $2) ON CONFLICT (block_id) DO UPDATE SET height=EXCLUDED.height RETURNING id`)
indexStmt, err := insertIndexStmt(ut.tx)
if err != nil {
return fmt.Errorf("failed to prepare index statement: %w", err)
}
Expand Down Expand Up @@ -321,10 +321,10 @@ func (ut *updateTx) AddEvents(events []wallet.Event) error {
return nil
}

// RevertEvents reverts the events that were added in the given block.
func (ut *updateTx) RevertEvents(blockID types.BlockID) error {
// RevertEvents reverts any events that were added by the index
func (ut *updateTx) RevertEvents(index types.ChainIndex) error {
var id int64
err := ut.tx.QueryRow(`DELETE FROM chain_indices WHERE block_id=$1 RETURNING id`, encode(blockID)).Scan(&id)
err := ut.tx.QueryRow(`DELETE FROM chain_indices WHERE block_id=$1 AND height=$2 RETURNING id`, encode(index.ID), index.Height).Scan(&id)
ChrisSchinnerl marked this conversation as resolved.
Show resolved Hide resolved
if errors.Is(err, sql.ErrNoRows) {
return nil
}
Expand Down Expand Up @@ -399,3 +399,7 @@ func setLastCommittedIndex(tx *txn, index types.ChainIndex) error {
func insertAddressStatement(tx *txn) (*stmt, error) {
return tx.Prepare(`INSERT INTO sia_addresses (sia_address, siacoin_balance, immature_siacoin_balance, siafund_balance) VALUES ($1, $2, $2, $3) ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address RETURNING id`)
}

func insertIndexStmt(tx *txn) (*stmt, error) {
return tx.Prepare(`INSERT INTO chain_indices (height, block_id) VALUES ($1, $2) ON CONFLICT (block_id) DO UPDATE SET height=EXCLUDED.height RETURNING id`)
ChrisSchinnerl marked this conversation as resolved.
Show resolved Hide resolved
}
137 changes: 136 additions & 1 deletion persist/sqlite/consensus_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ func TestReorg(t *testing.T) {
}

expectedPayout := cm.TipState().BlockReward()
maturityHeight := cm.TipState().MaturityHeight()
// mine a block sending the payout to the wallet
if err := cm.AddBlocks([]types.Block{mineBlock(cm.TipState(), nil, addr)}); err != nil {
t.Fatal(err)
Expand All @@ -136,6 +137,18 @@ func TestReorg(t *testing.T) {
t.Fatalf("expected payout event, got %v", events[0].Data.EventType())
}

// check that the utxo was created
utxos, err := db.UnspentSiacoinOutputs("test")
if err != nil {
t.Fatal(err)
} else if len(utxos) != 1 {
t.Fatalf("expected 1 output, got %v", len(utxos))
} else if utxos[0].SiacoinOutput.Value.Cmp(expectedPayout) != 0 {
t.Fatalf("expected %v, got %v", expectedPayout, utxos[0].SiacoinOutput.Value)
} else if utxos[0].MaturityHeight != maturityHeight {
t.Fatalf("expected %v, got %v", maturityHeight, utxos[0].MaturityHeight)
}

// mine to trigger a reorg
var blocks []types.Block
state := genesisState
Expand Down Expand Up @@ -163,6 +176,106 @@ func TestReorg(t *testing.T) {
} else if len(events) != 0 {
t.Fatalf("expected 0 events, got %v", len(events))
}

// check that the utxo was removed
utxos, err = db.UnspentSiacoinOutputs("test")
if err != nil {
t.Fatal(err)
} else if len(utxos) != 0 {
t.Fatalf("expected 0 outputs, got %v", len(utxos))
}

// mine a new payout
expectedPayout = cm.TipState().BlockReward()
maturityHeight = cm.TipState().MaturityHeight()
if err := cm.AddBlocks([]types.Block{mineBlock(cm.TipState(), nil, addr)}); err != nil {
t.Fatal(err)
}

// check that the payout was received
balance, err = db.AddressBalance(addr)
if err != nil {
t.Fatal(err)
} else if !balance.ImmatureSiacoins.Equals(expectedPayout) {
t.Fatalf("expected %v, got %v", expectedPayout, balance.ImmatureSiacoins)
}

// check that a payout event was recorded
events, err = db.WalletEvents("test", 0, 100)
if err != nil {
t.Fatal(err)
} else if len(events) != 1 {
t.Fatalf("expected 1 event, got %v", len(events))
} else if events[0].Data.EventType() != wallet.EventTypeMinerPayout {
t.Fatalf("expected payout event, got %v", events[0].Data.EventType())
}

// check that the utxo was created
utxos, err = db.UnspentSiacoinOutputs("test")
if err != nil {
t.Fatal(err)
} else if len(utxos) != 1 {
t.Fatalf("expected 1 output, got %v", len(utxos))
} else if utxos[0].SiacoinOutput.Value.Cmp(expectedPayout) != 0 {
t.Fatalf("expected %v, got %v", expectedPayout, utxos[0].SiacoinOutput.Value)
} else if utxos[0].MaturityHeight != maturityHeight {
t.Fatalf("expected %v, got %v", maturityHeight, utxos[0].MaturityHeight)
}

// mine until the payout matures
var prevState consensus.State
for i := cm.TipState().Index.Height; i < maturityHeight+1; i++ {
if err := cm.AddBlocks([]types.Block{mineBlock(cm.TipState(), nil, types.VoidAddress)}); err != nil {
t.Fatal(err)
}
if i == maturityHeight-5 {
prevState = cm.TipState()
}
}

// check that the balance was updated
balance, err = db.AddressBalance(addr)
if err != nil {
t.Fatal(err)
} else if !balance.ImmatureSiacoins.IsZero() {
t.Fatalf("expected %v, got %v", types.ZeroCurrency, balance.ImmatureSiacoins)
} else if !balance.Siacoins.Equals(expectedPayout) {
t.Fatalf("expected %v, got %v", expectedPayout, balance.Siacoins)
}

// reorg the last few blocks to re-mature the payout
blocks = nil
state = prevState
for i := 0; i < 10; i++ {
blocks = append(blocks, mineBlock(state, nil, types.VoidAddress))
state.Index.ID = blocks[len(blocks)-1].ID()
state.Index.Height = state.Index.Height + 1
}
if err := cm.AddBlocks(blocks); err != nil {
t.Fatal(err)
}

// check that the balance is correct
balance, err = db.AddressBalance(addr)
if err != nil {
t.Fatal(err)
} else if !balance.ImmatureSiacoins.IsZero() {
t.Fatalf("expected %v, got %v", types.ZeroCurrency, balance.ImmatureSiacoins)
} else if !balance.Siacoins.Equals(expectedPayout) {
t.Fatalf("expected %v, got %v", expectedPayout, balance.Siacoins)
}

// check that only the single utxo still exists
utxos, err = db.UnspentSiacoinOutputs("test")
if err != nil {
t.Fatal(err)
} else if len(utxos) != 1 {
t.Fatalf("expected 1 output, got %v", len(utxos))
} else if utxos[0].SiacoinOutput.Value.Cmp(expectedPayout) != 0 {
t.Fatalf("expected %v, got %v", expectedPayout, utxos[0].SiacoinOutput.Value)
} else if utxos[0].MaturityHeight != maturityHeight {
t.Fatalf("expected %v, got %v", maturityHeight, utxos[0].MaturityHeight)
}
}

func TestEphemeralBalance(t *testing.T) {
Expand Down Expand Up @@ -205,8 +318,10 @@ func TestEphemeralBalance(t *testing.T) {

expectedPayout := cm.TipState().BlockReward()
maturityHeight := cm.TipState().MaturityHeight() + 1
block := mineBlock(cm.TipState(), nil, addr)
minerPayoutID := block.ID().MinerOutputID(0)
// mine a block sending the payout to the wallet
if err := cm.AddBlocks([]types.Block{mineBlock(cm.TipState(), nil, addr)}); err != nil {
if err := cm.AddBlocks([]types.Block{block}); err != nil {
t.Fatal(err)
}

Expand All @@ -226,6 +341,8 @@ func TestEphemeralBalance(t *testing.T) {
t.Fatalf("expected 1 event, got %v", len(events))
} else if events[0].Data.EventType() != wallet.EventTypeMinerPayout {
t.Fatalf("expected payout event, got %v", events[0].Data.EventType())
} else if events[0].ID != types.Hash256(minerPayoutID) {
t.Fatalf("expected %v, got %v", minerPayoutID, events[0].ID)
}

// mine until the payout matures
Expand Down Expand Up @@ -306,6 +423,24 @@ func TestEphemeralBalance(t *testing.T) {
t.Fatalf("expected 0, got %v", balance.Siacoins)
}

// check that both transactions were added
events, err = db.WalletEvents("test", 0, 100)
if err != nil {
t.Fatal(err)
} else if len(events) != 3 { // 1 payout, 2 transactions
t.Fatalf("expected 3 events, got %v", len(events))
} else if events[2].Data.EventType() != wallet.EventTypeMinerPayout {
t.Fatalf("expected miner payout event, got %v", events[2].Data.EventType())
} else if events[1].Data.EventType() != wallet.EventTypeTransaction {
t.Fatalf("expected transaction event, got %v", events[1].Data.EventType())
} else if events[0].Data.EventType() != wallet.EventTypeTransaction {
t.Fatalf("expected transaction event, got %v", events[0].Data.EventType())
} else if events[1].ID != types.Hash256(parentTxn.ID()) { // parent txn first
t.Fatalf("expected %v, got %v", parentTxn.ID(), events[1].ID)
} else if events[0].ID != types.Hash256(txn.ID()) { // child txn second
t.Fatalf("expected %v, got %v", txn.ID(), events[0].ID)
}

// trigger a reorg
var blocks []types.Block
state := revertState
Expand Down
3 changes: 2 additions & 1 deletion persist/sqlite/init.sql
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ CREATE TABLE siacoin_elements (
address_id INTEGER NOT NULL REFERENCES sia_addresses (id)
);
CREATE INDEX siacoin_elements_address_id ON siacoin_elements (address_id);
CREATE INDEX siacoin_elements_maturity_height ON siacoin_elements (maturity_height);

CREATE TABLE siafund_elements (
id BLOB PRIMARY KEY,
Expand Down Expand Up @@ -48,9 +49,9 @@ CREATE INDEX wallet_addresses_address_id ON wallet_addresses (address_id);
CREATE TABLE events (
id INTEGER PRIMARY KEY,
event_id BLOB NOT NULL,
index_id BLOB NOT NULL REFERENCES chain_indices (id) ON DELETE CASCADE,
maturity_height INTEGER NOT NULL,
date_created INTEGER NOT NULL,
index_id BLOB NOT NULL REFERENCES chain_indices (id) ON DELETE CASCADE,
event_type TEXT NOT NULL,
event_data TEXT NOT NULL
);
Expand Down
2 changes: 1 addition & 1 deletion persist/sqlite/wallet.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func getWalletEvents(tx *txn, walletID string, offset, limit int) (events []wall
FROM events ev
INNER JOIN chain_indices ci ON (ev.index_id = ci.id)
WHERE ev.id IN (SELECT event_id FROM event_addresses WHERE address_id IN (SELECT address_id FROM wallet_addresses WHERE wallet_id=$1))
ORDER BY ev.maturity_height DESC
ORDER BY ev.maturity_height DESC, ev.id DESC
LIMIT $2 OFFSET $3`

rows, err := tx.Query(query, walletID, limit, offset)
Expand Down