Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimize key derivation in the wallet #17991

Merged
merged 1 commit into from
May 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
79 changes: 47 additions & 32 deletions chia/wallet/wallet_state_manager.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -416,47 +416,63 @@ async def create_more_puzzle_hashes(
self.log.debug(f"Requested to generate puzzle hashes to at least index {unused}") self.log.debug(f"Requested to generate puzzle hashes to at least index {unused}")
start_t = time.time() start_t = time.time()
to_generate = num_additional_phs if num_additional_phs is not None else self.initial_num_public_keys to_generate = num_additional_phs if num_additional_phs is not None else self.initial_num_public_keys
new_paths: bool = False


# iterate all wallets that need derived keys and establish the start
# index for all of them
start_index: int = 0
start_index_by_wallet: Dict[uint32, int] = {}
last_index = unused + to_generate
for wallet_id in targets: for wallet_id in targets:
target_wallet = self.wallets[wallet_id] target_wallet = self.wallets[wallet_id]
if not target_wallet.require_derivation_paths(): if not target_wallet.require_derivation_paths():
self.log.debug("Skipping wallet %s as no derivation paths required", wallet_id) self.log.debug("Skipping wallet %s as no derivation paths required", wallet_id)
continue continue
if from_zero:
start_index_by_wallet[wallet_id] = 0
continue
last: Optional[uint32] = await self.puzzle_store.get_last_derivation_path_for_wallet(wallet_id) last: Optional[uint32] = await self.puzzle_store.get_last_derivation_path_for_wallet(wallet_id)
self.log.debug(
"Fetched last record for wallet %r: %s (from_zero=%r, unused=%r)", wallet_id, last, from_zero, unused
)
start_index = 0
derivation_paths: List[DerivationRecord] = []

if last is not None: if last is not None:
start_index = last + 1 if last + 1 >= last_index:

# If the key was replaced (from_zero=True), we should generate the puzzle hashes for the new key
if from_zero:
start_index = 0
last_index = unused + to_generate
if start_index >= last_index:
self.log.debug(f"Nothing to create for for wallet_id: {wallet_id}, index: {start_index}") self.log.debug(f"Nothing to create for for wallet_id: {wallet_id}, index: {start_index}")
continue
start_index = min(start_index, last + 1)
start_index_by_wallet[wallet_id] = last + 1
else: else:
creating_msg = ( start_index_by_wallet[wallet_id] = 0
f"Creating puzzle hashes from {start_index} to {last_index - 1} for wallet_id: {wallet_id}"
) if len(start_index_by_wallet) == 0:
self.log.info(f"Start: {creating_msg}") return
arvidn marked this conversation as resolved.
Show resolved Hide resolved

# now derive the keysfrom start_index to last_index
# these maps derivation index to public key
hardened_keys: Dict[int, G1Element] = {}
unhardened_keys: Dict[int, G1Element] = {}

if self.private_key is not None: if self.private_key is not None:
# Hardened
intermediate_sk = master_sk_to_wallet_sk_intermediate(self.private_key) intermediate_sk = master_sk_to_wallet_sk_intermediate(self.private_key)
for index in range(start_index, last_index):
hardened_keys[index] = _derive_path(intermediate_sk, [index]).get_g1()

# Unhardened
intermediate_pk_un = master_pk_to_wallet_pk_unhardened_intermediate(self.root_pubkey) intermediate_pk_un = master_pk_to_wallet_pk_unhardened_intermediate(self.root_pubkey)
for index in range(start_index, last_index): for index in range(start_index, last_index):
if target_wallet.type() == WalletType.POOLING_WALLET: unhardened_keys[index] = _derive_pk_unhardened(intermediate_pk_un, [index])
arvidn marked this conversation as resolved.
Show resolved Hide resolved
continue


if self.private_key is not None: for wallet_id, start_index in start_index_by_wallet.items():
target_wallet = self.wallets[wallet_id]
assert target_wallet.type() != WalletType.POOLING_WALLET
arvidn marked this conversation as resolved.
Show resolved Hide resolved
assert start_index < last_index

derivation_paths: List[DerivationRecord] = []
creating_msg = f"Creating puzzle hashes from {start_index} to {last_index - 1} for wallet_id: {wallet_id}"
self.log.info(f"Start: {creating_msg}")
for index in range(start_index, last_index):
pubkey: Optional[G1Element] = hardened_keys.get(index)
if pubkey is not None:
# Hardened # Hardened
pubkey: G1Element = _derive_path(intermediate_sk, [index]).get_g1()
puzzlehash: bytes32 = target_wallet.puzzle_hash_for_pk(pubkey) puzzlehash: bytes32 = target_wallet.puzzle_hash_for_pk(pubkey)
self.log.debug(f"Puzzle at index {index} wallet ID {wallet_id} puzzle hash {puzzlehash.hex()}") self.log.debug(f"Puzzle at index {index} wallet ID {wallet_id} puzzle hash {puzzlehash.hex()}")
new_paths = True
derivation_paths.append( derivation_paths.append(
DerivationRecord( DerivationRecord(
uint32(index), uint32(index),
Expand All @@ -468,34 +484,33 @@ async def create_more_puzzle_hashes(
) )
) )
# Unhardened # Unhardened
pubkey_unhardened: G1Element = _derive_pk_unhardened(intermediate_pk_un, [index]) pubkey = unhardened_keys.get(index)
puzzlehash_unhardened: bytes32 = target_wallet.puzzle_hash_for_pk(pubkey_unhardened) assert pubkey is not None
puzzlehash_unhardened: bytes32 = target_wallet.puzzle_hash_for_pk(pubkey)
self.log.debug( self.log.debug(
f"Puzzle at index {index} wallet ID {wallet_id} puzzle hash {puzzlehash_unhardened.hex()}" f"Puzzle at index {index} wallet ID {wallet_id} puzzle hash {puzzlehash_unhardened.hex()}"
) )
# We await sleep here to allow an asyncio context switch (since the other parts of this loop do
# not have await and therefore block). This can prevent networking layer from responding to ping.
await asyncio.sleep(0)
Quexington marked this conversation as resolved.
Show resolved Hide resolved
derivation_paths.append( derivation_paths.append(
DerivationRecord( DerivationRecord(
uint32(index), uint32(index),
puzzlehash_unhardened, puzzlehash_unhardened,
pubkey_unhardened, pubkey,
target_wallet.type(), target_wallet.type(),
uint32(target_wallet.id()), uint32(target_wallet.id()),
False, False,
) )
) )
self.log.info(f"Done: {creating_msg} Time: {time.time() - start_t} seconds") self.log.info(f"Done: {creating_msg} Time: {time.time() - start_t} seconds")
await self.puzzle_store.add_derivation_paths(derivation_paths)
if len(derivation_paths) > 0: if len(derivation_paths) > 0:
await self.puzzle_store.add_derivation_paths(derivation_paths)
if wallet_id == self.main_wallet.id(): if wallet_id == self.main_wallet.id():
await self.wallet_node.new_peak_queue.subscribe_to_puzzle_hashes( await self.wallet_node.new_peak_queue.subscribe_to_puzzle_hashes(
[record.puzzle_hash for record in derivation_paths] [record.puzzle_hash for record in derivation_paths]
) )
self.state_changed("new_derivation_index", data_object={"index": derivation_paths[-1].index}) if len(unhardened_keys) > 0:
self.state_changed("new_derivation_index", data_object={"index": last_index - 1})
# By default, we'll mark previously generated unused puzzle hashes as used if we have new paths # By default, we'll mark previously generated unused puzzle hashes as used if we have new paths
if mark_existing_as_used and unused > 0 and new_paths: if mark_existing_as_used and unused > 0 and len(unhardened_keys) > 0:
arvidn marked this conversation as resolved.
Show resolved Hide resolved
self.log.info(f"Updating last used derivation index: {unused - 1}") self.log.info(f"Updating last used derivation index: {unused - 1}")
await self.puzzle_store.set_used_up_to(uint32(unused - 1)) await self.puzzle_store.set_used_up_to(uint32(unused - 1))


Expand Down