From 9b95c3547b001e7881dba0c3c4d0825770503e78 Mon Sep 17 00:00:00 2001 From: arvidn Date: Wed, 8 May 2024 15:44:12 +0200 Subject: [PATCH] optimize key derivation in the wallet. instead of deriving the same keys for every wallet, derive the keys once and re-use them for the derived puzzle hashes per wallet --- chia/wallet/wallet_state_manager.py | 129 ++++++++++++++++------------ 1 file changed, 72 insertions(+), 57 deletions(-) diff --git a/chia/wallet/wallet_state_manager.py b/chia/wallet/wallet_state_manager.py index a0b28741da70..f8b820513249 100644 --- a/chia/wallet/wallet_state_manager.py +++ b/chia/wallet/wallet_state_manager.py @@ -416,86 +416,101 @@ async def create_more_puzzle_hashes( self.log.debug(f"Requested to generate puzzle hashes to at least index {unused}") start_t = time.time() to_generate = num_additional_phs if num_additional_phs is not None else self.initial_num_public_keys - new_paths: bool = False + # iterate all wallets that need derived keys and establish the start + # index for all of them + start_index: int = 0 + start_index_by_wallet: Dict[uint32, int] = {} + last_index = unused + to_generate for wallet_id in targets: target_wallet = self.wallets[wallet_id] if not target_wallet.require_derivation_paths(): self.log.debug("Skipping wallet %s as no derivation paths required", wallet_id) continue + if from_zero: + start_index_by_wallet[wallet_id] = 0 + continue last: Optional[uint32] = await self.puzzle_store.get_last_derivation_path_for_wallet(wallet_id) - self.log.debug( - "Fetched last record for wallet %r: %s (from_zero=%r, unused=%r)", wallet_id, last, from_zero, unused - ) - start_index = 0 - derivation_paths: List[DerivationRecord] = [] - if last is not None: - start_index = last + 1 - - # If the key was replaced (from_zero=True), we should generate the puzzle hashes for the new key - if from_zero: - start_index = 0 - last_index = unused + to_generate - if start_index >= last_index: - self.log.debug(f"Nothing to create for for wallet_id: {wallet_id}, index: {start_index}") + if last + 1 >= last_index: + self.log.debug(f"Nothing to create for for wallet_id: {wallet_id}, index: {start_index}") + continue + start_index = min(start_index, last + 1) + start_index_by_wallet[wallet_id] = last + 1 else: - creating_msg = ( - f"Creating puzzle hashes from {start_index} to {last_index - 1} for wallet_id: {wallet_id}" - ) - self.log.info(f"Start: {creating_msg}") - if self.private_key is not None: - intermediate_sk = master_sk_to_wallet_sk_intermediate(self.private_key) - intermediate_pk_un = master_pk_to_wallet_pk_unhardened_intermediate(self.root_pubkey) - for index in range(start_index, last_index): - if target_wallet.type() == WalletType.POOLING_WALLET: - continue + start_index_by_wallet[wallet_id] = 0 - if self.private_key is not None: - # Hardened - pubkey: G1Element = _derive_path(intermediate_sk, [index]).get_g1() - puzzlehash: bytes32 = target_wallet.puzzle_hash_for_pk(pubkey) - self.log.debug(f"Puzzle at index {index} wallet ID {wallet_id} puzzle hash {puzzlehash.hex()}") - new_paths = True - derivation_paths.append( - DerivationRecord( - uint32(index), - puzzlehash, - pubkey, - target_wallet.type(), - uint32(target_wallet.id()), - True, - ) - ) - # Unhardened - pubkey_unhardened: G1Element = _derive_pk_unhardened(intermediate_pk_un, [index]) - puzzlehash_unhardened: bytes32 = target_wallet.puzzle_hash_for_pk(pubkey_unhardened) - self.log.debug( - f"Puzzle at index {index} wallet ID {wallet_id} puzzle hash {puzzlehash_unhardened.hex()}" - ) - # We await sleep here to allow an asyncio context switch (since the other parts of this loop do - # not have await and therefore block). This can prevent networking layer from responding to ping. - await asyncio.sleep(0) + if len(start_index_by_wallet) == 0: + return + + # now derive the keysfrom start_index to last_index + # these maps derivation index to public key + hardened_keys: Dict[int, G1Element] = {} + unhardened_keys: Dict[int, G1Element] = {} + + if self.private_key is not None: + # Hardened + intermediate_sk = master_sk_to_wallet_sk_intermediate(self.private_key) + for index in range(start_index, last_index): + hardened_keys[index] = _derive_path(intermediate_sk, [index]).get_g1() + + # Unhardened + intermediate_pk_un = master_pk_to_wallet_pk_unhardened_intermediate(self.root_pubkey) + for index in range(start_index, last_index): + unhardened_keys[index] = _derive_pk_unhardened(intermediate_pk_un, [index]) + + for wallet_id, start_index in start_index_by_wallet.items(): + target_wallet = self.wallets[wallet_id] + assert target_wallet.type() != WalletType.POOLING_WALLET + assert start_index < last_index + + derivation_paths: List[DerivationRecord] = [] + creating_msg = f"Creating puzzle hashes from {start_index} to {last_index - 1} for wallet_id: {wallet_id}" + self.log.info(f"Start: {creating_msg}") + for index in range(start_index, last_index): + pubkey: Optional[G1Element] = hardened_keys.get(index) + if pubkey is not None: + # Hardened + puzzlehash: bytes32 = target_wallet.puzzle_hash_for_pk(pubkey) + self.log.debug(f"Puzzle at index {index} wallet ID {wallet_id} puzzle hash {puzzlehash.hex()}") derivation_paths.append( DerivationRecord( uint32(index), - puzzlehash_unhardened, - pubkey_unhardened, + puzzlehash, + pubkey, target_wallet.type(), uint32(target_wallet.id()), - False, + True, ) ) - self.log.info(f"Done: {creating_msg} Time: {time.time() - start_t} seconds") - await self.puzzle_store.add_derivation_paths(derivation_paths) + # Unhardened + pubkey = unhardened_keys.get(index) + assert pubkey is not None + puzzlehash_unhardened: bytes32 = target_wallet.puzzle_hash_for_pk(pubkey) + self.log.debug( + f"Puzzle at index {index} wallet ID {wallet_id} puzzle hash {puzzlehash_unhardened.hex()}" + ) + derivation_paths.append( + DerivationRecord( + uint32(index), + puzzlehash_unhardened, + pubkey, + target_wallet.type(), + uint32(target_wallet.id()), + False, + ) + ) + self.log.info(f"Done: {creating_msg} Time: {time.time() - start_t} seconds") if len(derivation_paths) > 0: + await self.puzzle_store.add_derivation_paths(derivation_paths) if wallet_id == self.main_wallet.id(): await self.wallet_node.new_peak_queue.subscribe_to_puzzle_hashes( [record.puzzle_hash for record in derivation_paths] ) - self.state_changed("new_derivation_index", data_object={"index": derivation_paths[-1].index}) + if len(unhardened_keys) > 0: + self.state_changed("new_derivation_index", data_object={"index": last_index - 1}) # By default, we'll mark previously generated unused puzzle hashes as used if we have new paths - if mark_existing_as_used and unused > 0 and new_paths: + if mark_existing_as_used and unused > 0 and len(unhardened_keys) > 0: self.log.info(f"Updating last used derivation index: {unused - 1}") await self.puzzle_store.set_used_up_to(uint32(unused - 1))