Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wallet: Implement WalletPuzzleStore.delete_wallet #15125

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 21 additions & 0 deletions chia/wallet/wallet_puzzle_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,3 +358,24 @@ async def get_unused_derivation_path(self) -> Optional[uint32]:
return uint32(row[0])

return None

async def delete_wallet(self, wallet_id: uint32) -> None:
async with self.db_wrapper.writer_maybe_transaction() as conn:
# First fetch all puzzle hashes since we need them to drop them from the cache
rows = await conn.execute_fetchall(
"SELECT puzzle_hash FROM derivation_paths WHERE wallet_id=?", (wallet_id,)
)
cursor = await conn.execute("DELETE FROM derivation_paths WHERE wallet_id=?;", (wallet_id,))
await cursor.close()
# Clear caches
puzzle_hashes = set(bytes32.fromhex(row[0]) for row in rows)
for puzzle_hash in puzzle_hashes:
try:
self.wallet_identifier_cache.remove(puzzle_hash)
except KeyError:
pass
try:
self.last_wallet_derivation_index.pop(wallet_id)
except KeyError:
pass
self.last_derivation_index = None
62 changes: 62 additions & 0 deletions tests/wallet/test_puzzle_store.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,44 @@
from __future__ import annotations

from dataclasses import dataclass, field
from secrets import token_bytes
from typing import Dict, List

import pytest
from blspy import AugSchemeMPL

from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.ints import uint32
from chia.wallet.derivation_record import DerivationRecord
from chia.wallet.util.wallet_types import WalletIdentifier, WalletType
from chia.wallet.wallet_puzzle_store import WalletPuzzleStore
from tests.util.db_connection import DBConnection


def get_dummy_record(index: int, wallet_id: int) -> DerivationRecord:
return DerivationRecord(
uint32(index),
bytes32(token_bytes(32)),
AugSchemeMPL.key_gen(token_bytes(32)).get_g1(),
WalletType.STANDARD_WALLET,
uint32(wallet_id),
False,
)


@dataclass
class DummyDerivationRecords:
index_per_wallet: Dict[int, int] = field(default_factory=dict)
records_per_wallet: Dict[int, List[DerivationRecord]] = field(default_factory=dict)

def generate(self, wallet_id: int, count: int) -> None:
records = self.records_per_wallet.setdefault(wallet_id, [])
self.index_per_wallet.setdefault(wallet_id, 0)
for _ in range(count):
records.append(get_dummy_record(self.index_per_wallet[wallet_id], wallet_id))
self.index_per_wallet[wallet_id] += 1


class TestPuzzleStore:
@pytest.mark.asyncio
async def test_puzzle_store(self):
Expand Down Expand Up @@ -67,3 +94,38 @@ async def test_puzzle_store(self):
await db.set_used_up_to(249)

assert await db.get_unused_derivation_path() == 250


@pytest.mark.asyncio
async def test_delete_wallet() -> None:
dummy_records = DummyDerivationRecords()
for i in range(5):
dummy_records.generate(i, i * 5)
async with DBConnection(1) as wrapper:
db = await WalletPuzzleStore.create(wrapper)
# Add the records per wallet and verify them
for wallet_id, records in dummy_records.records_per_wallet.items():
await db.add_derivation_paths(records)
for record in records:
assert await db.get_derivation_record(record.index, record.wallet_id, record.hardened) == record
assert await db.get_wallet_identifier_for_puzzle_hash(record.puzzle_hash) == WalletIdentifier(
record.wallet_id, record.wallet_type
)
# Remove one wallet after the other and verify before and after each
for wallet_id, records in dummy_records.records_per_wallet.items():
# Assert the existence again here to make sure the previous removals did not affect other wallet_ids
for record in records:
assert await db.get_derivation_record(record.index, record.wallet_id, record.hardened) == record
assert await db.get_wallet_identifier_for_puzzle_hash(record.puzzle_hash) == WalletIdentifier(
record.wallet_id, record.wallet_type
)
assert await db.get_last_derivation_path_for_wallet(wallet_id) is not None
# Remove the wallet_id and make sure its removed fully
await db.delete_wallet(wallet_id)
for record in records:
assert await db.get_derivation_record(record.index, record.wallet_id, record.hardened) is None
assert await db.get_wallet_identifier_for_puzzle_hash(record.puzzle_hash) is None
assert await db.get_last_derivation_path_for_wallet(wallet_id) is None
assert await db.get_last_derivation_path() is None
assert db.last_derivation_index is None
assert len(db.last_wallet_derivation_index) == 0