Skip to content

Commit

Permalink
refactor GetNodeKeys() to separate out GetNode() + getDerivationIndex…
Browse files Browse the repository at this point in the history
…esFromPath() + add tests for deriving indexes from derivation path
  • Loading branch information
ashfame committed Mar 2, 2021
1 parent 11464a0 commit f057f38
Show file tree
Hide file tree
Showing 2 changed files with 223 additions and 23 deletions.
105 changes: 82 additions & 23 deletions wallet.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,47 +233,106 @@ func (w *Wallet) ExportSeed() (seed string, err error) {
return fmt.Sprintf("%x", w.seed), nil
}

// GetNodeKeys return private and public key for the derivation path specified
func (w *Wallet) GetNodeKeys(path string) (xprv string, xpub string, err error) {
// ensure wallet is unlocked before trying to work with it
if !w.isInitialized {
err = errWalletNotInitialized
return
}

node, err := hdkeychain.NewMaster(w.seed, w.getChaincfgParams())
if err != nil {
return
}

// make sure path is correctly specified
// also remove trailing slash, if present
// Function accepts the derivation path in string and returns an array of indexes to derive nodes
func (w *Wallet) getDerivationIndexesFromPath(path string) (d []uint32, err error) {
// remove trailing slash, if present
// let's make sure path is correctly specified
pathArr := strings.Split(strings.TrimSuffix(path, "/"), "/")
for index, pathPart := range pathArr {
if index == 0 {
if pathPart != "m" {
// path must start with m or M
if pathPart != "m" && pathPart != "M" {
err = errIncorrectDerivationPath
return
}

continue
}

trimmed := strings.TrimSuffix(pathPart, "H")
t, _ := strconv.ParseUint(trimmed, 10, 32)
deriveIndex := uint32(t)
// suffix was actually present, so hardened derivation
if trimmed != pathPart {
// do we need to do a hardened derivation?
// represented by 1H, 1h, 1'
hardenedDerivation := false
var trimmed string
for _, hardenedMarker := range []string{"H", "h", "'"} {
trimmed = strings.TrimSuffix(pathPart, hardenedMarker)
if hardenedDerivation {
// if already a single hardened marker has been found
if trimmed != pathPart {
// still another hardened marker found
return nil, errIncorrectDerivationPath
}
} else {
// hardened marker was found for the first time in this loop
if trimmed != pathPart {
// hardened suffix was actually present
hardenedDerivation = true
}
// overwrite trimmed over pathPart so that in next iteraton,
// we can catch if further hardened markers are removed in cases like Hh,
// which will be treated as invalid
pathPart = trimmed
}
}

var deriveIndex uint32
if _, err := strconv.Atoi(trimmed); err == nil {
// looks like a number
t, _ := strconv.ParseUint(trimmed, 10, 32)
deriveIndex = uint32(t)
} else {
// invalid character, an alphabet was encountered
return nil, errIncorrectDerivationPath
}

if hardenedDerivation {
if deriveIndex > hdkeychain.HardenedKeyStart {
return nil, errIncorrectDerivationPath
}
deriveIndex += uint32(hdkeychain.HardenedKeyStart)
}

node, err = node.Derive(deriveIndex)
d = append(d, deriveIndex)
}

return d, nil
}

// GetNode returns the node for a particular path
func (w *Wallet) GetNode(path string) (node *hdkeychain.ExtendedKey, err error) {
// ensure wallet is unlocked before trying to work with it
if !w.isInitialized {
err = errWalletNotInitialized
return
}

node, err = hdkeychain.NewMaster(w.seed, w.getChaincfgParams())
if err != nil {
return
}

indexes, err := w.getDerivationIndexesFromPath(path)
if err != nil {
return
}

for _, d := range indexes {
node, err = node.Derive(d)
if err != nil {
log.Println(err)
return "", "", errIncorrectDerivationPath
return nil, errIncorrectDerivationPath
}
}

return node, nil
}

// GetNodeKeys return private and public key for the derivation path specified
func (w *Wallet) GetNodeKeys(path string) (xprv string, xpub string, err error) {
node, err := w.GetNode(path)
if err != nil {
return
}

// get private key
xprv = node.String()
// get public key
Expand Down
141 changes: 141 additions & 0 deletions wallet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"encoding/hex"
"flag"
"testing"

"github.com/btcsuite/btcutil/hdkeychain"
)

var debug bool
Expand All @@ -18,6 +20,145 @@ func init() {
flag.BoolVar(&debug, "debug", false, "specify if you want debug logs in standard output")
}

// TestDerivationIndexesFromPath tests the conversion of derivation path string into an array of indexes based on which child nodes will be repeatedly derived
func TestDerivationIndexesFromPath(t *testing.T) {
tables := []struct {
path string
dI []uint32
err error
}{
{
path: "",
dI: []uint32{},
err: errIncorrectDerivationPath,
},
{
path: "m",
dI: []uint32{},
err: nil,
},
{
path: "m/",
dI: []uint32{},
err: nil,
},
{
path: "a",
dI: []uint32{},
err: errIncorrectDerivationPath,
},
{
path: "m/1",
dI: []uint32{
1,
},
err: nil,
},
{
path: "M/1",
dI: []uint32{
1,
},
err: nil,
},
{
path: "m/1/2",
dI: []uint32{
1,
2,
},
err: nil,
},
{
path: "m/a",
err: errIncorrectDerivationPath,
},
{
path: "m/1H",
dI: []uint32{
1 + hdkeychain.HardenedKeyStart,
},
err: nil,
},
{
path: "m/H1",
err: errIncorrectDerivationPath,
},
{
path: "m/1Hh",
err: errIncorrectDerivationPath,
},
{
path: "m/1H'",
err: errIncorrectDerivationPath,
},
{
path: "m/1Ha",
err: errIncorrectDerivationPath,
},
{
path: "m/2h",
dI: []uint32{
2 + hdkeychain.HardenedKeyStart,
},
err: nil,
},
{
path: "m/3'",
dI: []uint32{
3 + hdkeychain.HardenedKeyStart,
},
err: nil,
},
{
path: "m/4294967295", // max valid index, (2^32 - 1)
dI: []uint32{
4294967295,
},
err: nil,
},
{
path: "m/4294967295H",
err: errIncorrectDerivationPath,
},
{
path: "m/1/2'/3/4'/5/6h/7H",
dI: []uint32{
1,
2 + hdkeychain.HardenedKeyStart,
3,
4 + hdkeychain.HardenedKeyStart,
5,
6 + hdkeychain.HardenedKeyStart,
7 + hdkeychain.HardenedKeyStart,
},
err: nil,
},
}

vault := NewWallet()
for _, table := range tables {
dI, err := vault.getDerivationIndexesFromPath(table.path)
if err != table.err {
t.Errorf("unexpected error encountered while generating derivation indexes from path: %s", table.path)
}

indexesMismatch := false
if len(dI) != len(table.dI) {
indexesMismatch = true
}
for index := range dI {
if dI[index] != table.dI[index] {
indexesMismatch = true
}
}

if indexesMismatch {
t.Errorf("derivation indexes do not match for path: %s, expected: %v, got: %v", table.path, table.dI, dI)
}
}
}

// TestBIP32SpecTestVector tests the private & pubic key derived at certain derivation paths as specified in BIP32 standard specification
func TestBIP32SpecTestVector(t *testing.T) {
tables := []struct {
Expand Down

0 comments on commit f057f38

Please sign in to comment.