/
kms.go
135 lines (107 loc) · 2.92 KB
/
kms.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
package sender
import (
"context"
"crypto/x509/pkix"
"encoding/asn1"
"fmt"
"math/big"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/kms"
"github.com/aws/aws-sdk-go-v2/service/kms/types"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/crypto/secp256k1"
"golang.org/x/exp/slices"
)
// secp256k1N is the order N of the secp256k1 base point, and secp256k1HalfN
// is floor(N/2). Sign compares a signature's S against secp256k1HalfN and
// replaces an S above it with N-S, keeping S in the lower half of the group.
var (
	secp256k1N     = crypto.S256().Params().N
	secp256k1HalfN = new(big.Int).Rsh(secp256k1N, 1)
)
// kmsSender is a Sender whose signatures are produced by an AWS KMS
// asymmetric key. Public key material is cached on construction so that
// signing only needs one KMS round-trip per hash.
type kmsSender struct {
	keyID   string         // KMS key identifier sent with every Sign request
	pub     []byte         // public key as encoded by secp256k1.S256().Marshal; compared with Ecrecover output to pick V
	address common.Address // address derived from the public key at construction
	client  *kms.Client    // client used to issue Sign requests
}
// Address reports the address cached when the sender was constructed; it
// performs no KMS call.
func (s *kmsSender) Address() common.Address { return s.address }
// Sign asks AWS KMS to sign hash with the configured key. It returns a
// 65-byte signature laid out as R || S || V, with R and S left-padded to
// 32 bytes each and V equal to 0 or 1.
//
// The DER signature from KMS carries no recovery id and may have an S in the
// upper half of the group order. This method first moves S into the lower
// half. It then tries both V values, keeping the one whose recovered public
// key equals the cached s.pub.
//
// An error is returned if:
//   - the KMS call fails;
//   - the signature cannot be decoded or has trailing bytes;
//   - R or S is outside [1, N-1];
//   - neither V recovers s.pub.
func (s *kmsSender) Sign(ctx context.Context, hash common.Hash) ([]byte, error) {
	out, err := s.client.Sign(ctx, &kms.SignInput{
		KeyId:            aws.String(s.keyID),
		SigningAlgorithm: types.SigningAlgorithmSpecEcdsaSha256,
		// hash is already a digest, so KMS is told not to hash it again.
		MessageType: types.MessageTypeDigest,
		Message:     hash[:],
	})
	if err != nil {
		return nil, fmt.Errorf("kms sign with key %q: %w", s.keyID, err)
	}

	var sig ecdsaSigFields
	rest, err := asn1.Unmarshal(out.Signature, &sig)
	if err != nil {
		return nil, fmt.Errorf("decoding kms signature: %w", err)
	}
	if len(rest) != 0 {
		return nil, fmt.Errorf("decoding kms signature: %d trailing bytes", len(rest))
	}

	// SetBytes treats the DER integer contents as unsigned big-endian, so a
	// leading 0x00 sign byte is harmlessly absorbed.
	rInt := new(big.Int).SetBytes(sig.R.Bytes)
	sInt := new(big.Int).SetBytes(sig.S.Bytes)

	// ECDSA components must lie in [1, N-1]. The check also guarantees each
	// value fits in 32 bytes, which FillBytes below relies on (it panics
	// otherwise).
	for _, c := range []*big.Int{rInt, sInt} {
		if c.Sign() <= 0 || c.Cmp(secp256k1N) >= 0 {
			return nil, fmt.Errorf("kms signature component out of range")
		}
	}

	// Keep S in the lower half of the group. (r, N-s) verifies for the same
	// key and hash but needs the opposite V, which the search below finds.
	if sInt.Cmp(secp256k1HalfN) > 0 {
		sInt.Sub(secp256k1N, sInt)
	}

	sigRSV := make([]byte, 65)
	rInt.FillBytes(sigRSV[:32])
	sInt.FillBytes(sigRSV[32:64])

	// Try both recovery ids. A failed recovery with one V does not rule out
	// the other, so an error only becomes fatal after both were attempted.
	var recoverErr error
	for _, v := range []byte{0, 1} {
		sigRSV[64] = v
		recPub, recErr := crypto.Ecrecover(hash[:], sigRSV)
		if recErr != nil {
			recoverErr = recErr
			continue
		}
		if slices.Equal(recPub, s.pub) {
			return sigRSV, nil
		}
	}
	if recoverErr != nil {
		return nil, fmt.Errorf("couldn't choose a working V from the returned R and S: %w", recoverErr)
	}
	return nil, fmt.Errorf("couldn't choose a working V from the returned R and S")
}
// fixLen returns a freshly allocated common.HashLength-byte slice holding in
// right-aligned:
//   - shorter inputs are left-padded with zero bytes;
//   - longer inputs keep only their trailing common.HashLength bytes, and the
//     leading bytes are dropped even if they are non-zero.
//
// The caller's slice is never modified.
func fixLen(in []byte) []byte {
	const width = common.HashLength
	if len(in) > width {
		in = in[len(in)-width:]
	}
	padded := make([]byte, width)
	copy(padded[width-len(in):], in)
	return padded
}
// publicKeyFields has the shape of an ASN.1 SubjectPublicKeyInfo
// (RFC 5280): an algorithm identifier followed by the public key as a BIT
// STRING. FromKMS decodes the DER public key returned by KMS GetPublicKey
// into it and hands the bit string's bytes to crypto.UnmarshalPubkey.
type publicKeyFields struct {
	Algorithm        pkix.AlgorithmIdentifier // key algorithm OID and optional parameters
	SubjectPublicKey asn1.BitString           // encoded public key point
}
// ecdsaSigFields has the shape of a DER ECDSA signature,
// SEQUENCE { r INTEGER, s INTEGER }. R and S are decoded as raw values, so
// their Bytes hold the big-endian integer contents unchanged. That includes
// any leading 0x00 byte DER adds when the high bit is set.
type ecdsaSigFields struct {
	R asn1.RawValue // r component, raw INTEGER contents
	S asn1.RawValue // s component, raw INTEGER contents
}
// FromKMS returns a Sender that signs with the AWS KMS asymmetric key keyID
// using client.
//
// The key's public key is fetched once, here. It is used to derive the
// sender's address and is cached so Sign can determine each signature's V.
// The key must be a secp256k1 key (KMS key spec ECC_SECG_P256K1). Any other
// reported key spec is rejected up front with a descriptive error.
//
// An error is returned if:
//   - the GetPublicKey call fails;
//   - the key spec is wrong;
//   - the DER public key cannot be decoded or has trailing bytes;
//   - the encoded point is not a valid secp256k1 public key.
func FromKMS(ctx context.Context, client *kms.Client, keyID string) (Sender, error) {
	pubResp, err := client.GetPublicKey(ctx, &kms.GetPublicKeyInput{KeyId: aws.String(keyID)})
	if err != nil {
		return nil, fmt.Errorf("fetching public key for kms key %q: %w", keyID, err)
	}

	// Without this check a wrong key type only surfaces as an opaque
	// point-decoding failure. An empty spec is tolerated and left to the
	// decoding below.
	if spec := pubResp.KeySpec; spec != "" && spec != types.KeySpecEccSecgP256k1 {
		return nil, fmt.Errorf("kms key %q has key spec %q, want %q", keyID, spec, types.KeySpecEccSecgP256k1)
	}

	// PublicKey is DER-encoded and decoded into the SubjectPublicKeyInfo shape
	// of publicKeyFields. The BIT STRING is assumed to hold the point encoding
	// that crypto.UnmarshalPubkey accepts (uncompressed); it errors otherwise.
	var spki publicKeyFields
	rest, err := asn1.Unmarshal(pubResp.PublicKey, &spki)
	if err != nil {
		return nil, fmt.Errorf("decoding public key for kms key %q: %w", keyID, err)
	}
	if len(rest) != 0 {
		return nil, fmt.Errorf("decoding public key for kms key %q: %d trailing bytes", keyID, len(rest))
	}

	pub, err := crypto.UnmarshalPubkey(spki.SubjectPublicKey.Bytes)
	if err != nil {
		return nil, fmt.Errorf("parsing secp256k1 public key for kms key %q: %w", keyID, err)
	}

	return &kmsSender{
		keyID: keyID,
		// Same encoding crypto.Ecrecover returns, so Sign can compare directly.
		pub:     secp256k1.S256().Marshal(pub.X, pub.Y),
		address: crypto.PubkeyToAddress(*pub),
		client:  client,
	}, nil
}