From 99b25e063ed782aa737ee02674fe4d88eedae7d0 Mon Sep 17 00:00:00 2001 From: Jon Griffiths Date: Fri, 17 Dec 2021 09:50:25 +1300 Subject: [PATCH] bip32: Allow public path derivation with hardened child elements If we start with an extended private key, derive privately until the last hardened child element, then continue publicly. If hardened derivation is required and the starting key is neutered, then fail. This change allows generating all pubkeys from master extended private keys without exposing the derived private key outside of the library (Previously you would have to derive privately and ignore the resulting private key). --- include/wally_bip32.h | 6 ++++ src/bip32.c | 30 +++++++++++++++---- src/test/test_bip32.py | 67 +++++++++++++++++++++++++++++------------- 3 files changed, 77 insertions(+), 26 deletions(-) diff --git a/include/wally_bip32.h b/include/wally_bip32.h index ebbfaaa2c..6fc719280 100644 --- a/include/wally_bip32.h +++ b/include/wally_bip32.h @@ -240,6 +240,9 @@ WALLY_CORE_API int bip32_key_from_parent_alloc( * :param child_path_len: The number of child numbers in ``child_path``. * :param flags: ``BIP32_FLAG_`` Flags indicating the type of derivation wanted. * :param output: Destination for the resulting child extended key. + * + * .. note:: If ``child_path`` contains hardened child numbers, then ``hdkey`` + * must be an extended private key or this function will fail. */ WALLY_CORE_API int bip32_key_from_parent_path( const struct ext_key *hdkey, @@ -269,6 +272,9 @@ WALLY_CORE_API int bip32_key_from_parent_path_alloc( * :param child_num: The child number to use if ``path_str`` contains a ``*`` wildcard. * :param flags: ``BIP32_FLAG_`` Flags indicating the type of derivation wanted. * :param output: Destination for the resulting child extended key. + * + * .. note:: If ``child_path`` contains hardened child numbers, then ``hdkey`` + * must be an extended private key or this function will fail. */ int bip32_key_from_parent_path_str( const struct ext_key *hdkey, diff --git a/src/bip32.c b/src/bip32.c index 17811745f..360508bf7 100644 --- a/src/bip32.c +++ b/src/bip32.c @@ -617,10 +617,8 @@ int bip32_key_from_parent_path(const struct ext_key *hdkey, const uint32_t *child_path, size_t child_path_len, uint32_t flags, struct ext_key *key_out) { - /* Optimization: We can skip hash calculations for internal nodes */ - uint32_t derivation_flags = flags | BIP32_FLAG_SKIP_HASH; struct ext_key tmp[2]; - size_t i, tmp_idx = 0; + size_t i, tmp_idx = 0, private_until = 0; int ret; if (flags & ~BIP32_ALL_DEFINED_FLAGS) @@ -629,14 +627,34 @@ int bip32_key_from_parent_path(const struct ext_key *hdkey, if (!hdkey || !child_path || !child_path_len || child_path_len > BIP32_PATH_MAX_LEN || !key_out) return WALLY_EINVAL; + if (flags & BIP32_FLAG_KEY_PUBLIC) { + /* Public derivation: Check for intermediate hardened keys */ + for (i = 0; i < child_path_len; ++i) { + if (child_is_hardened(child_path[i])) + private_until = i + 1; /* Derive privately until this index */ + } + if (private_until && !key_is_private(hdkey)) + return WALLY_EINVAL; /* Unsupported derivation */ + } + for (i = 0; i < child_path_len; ++i) { struct ext_key *derived = &tmp[tmp_idx]; + uint32_t derivation_flags = flags; + + if (private_until && i < private_until - 1) { + /* Derive privately until we reach the last hardened child */ + derivation_flags &= ~BIP32_FLAG_KEY_PUBLIC; + derivation_flags |= BIP32_FLAG_KEY_PRIVATE; + } + if (i + 2 < child_path_len) + derivation_flags |= BIP32_FLAG_SKIP_HASH; /* Skip hash for internal keys */ + #ifdef BUILD_ELEMENTS if (flags & BIP32_FLAG_KEY_TWEAK_SUM) - memcpy(derived->pub_key_tweak_sum, hdkey->pub_key_tweak_sum, sizeof(hdkey->pub_key_tweak_sum)); + memcpy(derived->pub_key_tweak_sum, + hdkey->pub_key_tweak_sum, sizeof(hdkey->pub_key_tweak_sum)); #endif /* BUILD_ELEMENTS */ - if (i + 2 >= child_path_len) - derivation_flags = flags; /* Use callers flags for the final derivations */ + ret = bip32_key_from_parent(hdkey, child_path[i], derivation_flags, derived); if (ret != WALLY_OK) break; diff --git a/src/test/test_bip32.py b/src/test/test_bip32.py index 4b3e7a4f5..05c4ccd76 100755 --- a/src/test/test_bip32.py +++ b/src/test/test_bip32.py @@ -12,6 +12,7 @@ ALL_DEFINED_FLAGS = FLAG_KEY_PRIVATE | FLAG_KEY_PUBLIC | FLAG_SKIP_HASH BIP32_SERIALIZED_LEN = 78 BIP32_FLAG_SKIP_HASH = 0x2 +EMPTY_PRIV_KEY = utf8('01' + ('00') * 32) # These vectors are expressed in binary rather than base 58. The spec base 58 # representation just obfuscates the data we are validating. For example, the @@ -167,7 +168,8 @@ def path_to_c(self, path): return c_path def str_to_path(self, path_str, wildcard): - path = path_str.replace('*h', str(2147483648 + wildcard)) + path = path_str.replace('1h', '2147483649') + path = path.replace('*h', str(2147483648 + wildcard)) path = path.replace('*', str(wildcard)).replace('m/', '').split('/') return [int(v) for v in path] @@ -354,6 +356,10 @@ def create_master_pub_priv(self): # Derive the same child public and private keys from master priv = self.derive_key(master, 1, FLAG_KEY_PRIVATE) pub = self.derive_key(master, 1, FLAG_KEY_PUBLIC) + # Verify both derviation types resulted in the same pubkey + self.assertEqual(h(priv.pub_key), h(pub.pub_key)) + # Verify that the public derivation does not contain a private key + self.assertEqual(h(pub.priv_key), EMPTY_PRIV_KEY) return master, pub, priv def test_public_derivation_identities(self): @@ -457,33 +463,54 @@ def get_paths(path): ret = bip32_key_from_parent_path_str_n(m, path, len(path), wildcard, flags, key_out) self.assertEqual(ret, WALLY_EINVAL) - # After stripping the parents' private key, hardened path derivation fails + # Hardened derivation is possible from a full key + fn = lambda f: bip32_key_from_parent_path_str_n(m, 'm/1h/1h', 7, 0, f, key_out) + self.assertEqual(fn(FLAG_KEY_PRIVATE), WALLY_OK) + self.assertEqual(fn(FLAG_KEY_PUBLIC), WALLY_OK) + # After stripping the parents' private key, hardened derivation fails self.assertEqual(bip32_key_strip_private_key(m), WALLY_OK) - ret = bip32_key_from_parent_path_str_n(m, 'm/1h', len('m/1h'), 0, FLAG_KEY_PUBLIC, key_out) - self.assertEqual(ret, WALLY_EINVAL) + self.assertEqual(fn(FLAG_KEY_PRIVATE), WALLY_EINVAL) + self.assertEqual(fn(FLAG_KEY_PUBLIC), WALLY_EINVAL) def test_wildcard(self): master, pub, priv = self.create_master_pub_priv() m = byref(master) - flags = FLAG_STR_WILDCARD | FLAG_KEY_PRIVATE key_out, int_key_out = ext_key(), ext_key() - cases = [('m/1/*', 55), - ('m/*', 55), - ('m/1/*/1', 55), - ('m/1/*h', 55), - ('m/*h', 55), - ('m/1/*h/1', 55)] + cases = [('m/1/*', 55), + ('m/*', 55), + ('m/1/*/1', 55), + ('m/1h/*', 55), + ('m/1h/*/1', 55), + ('m/1h/*/1h', 55), + ('m/1/*h', 55), + ('m/*h', 55), + ('m/1/*h/1', 55), + ('m/1/*h/1h', 55)] for path, wildcard in cases: - ret = bip32_key_from_parent_path_str(m, path, wildcard, flags, byref(key_out)) - self.assertEqual(ret, WALLY_OK) - - # Verify the result matches a key derived using the non-string version - path = self.str_to_path(path, wildcard) - c_path = self.path_to_c(path) - ret = bip32_key_from_parent_path(m, c_path, len(path), flags, byref(int_key_out)) - self.assertEqual(ret, WALLY_OK) - self.compare_keys(key_out, int_key_out, flags) + pub_key_hex = '' + for flag in [FLAG_KEY_PRIVATE, FLAG_KEY_PUBLIC]: + flags = flag | FLAG_STR_WILDCARD | FLAG_SKIP_HASH + ret = bip32_key_from_parent_path_str(m, path, wildcard, flags, byref(key_out)) + self.assertEqual(ret, WALLY_OK) + if flag == FLAG_KEY_PRIVATE: + pub_key_hex = h(key_out.pub_key) + else: + # Check that public derivation computed the same pubkey + # that private derivation did. + self.assertEqual(pub_key_hex, h(key_out.pub_key)) + # Check that public derivation did not return a private key + self.assertEqual(h(key_out.priv_key), EMPTY_PRIV_KEY) + # Verify the result matches a key derived using the non-string version + int_path = self.str_to_path(path, wildcard) + path_len = len(int_path) + c_path = self.path_to_c(int_path) + ret = bip32_key_from_parent_path(m, c_path, path_len, flags, byref(int_key_out)) + self.assertEqual(ret, WALLY_OK) + self.compare_keys(key_out, int_key_out, flags) + self.assertEqual(pub_key_hex, h(key_out.pub_key)) + if flag != FLAG_KEY_PRIVATE: + self.assertEqual(h(key_out.priv_key), EMPTY_PRIV_KEY) def test_free_invalid(self): self.assertEqual(WALLY_EINVAL, bip32_key_free(None))