forked from spiffe/spire
/
keymanagerbase.go
356 lines (305 loc) · 10.6 KB
/
keymanagerbase.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
package keymanagerbase
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/hex"
"fmt"
"sort"
"sync"
keymanagerv1 "github.com/accuknox/spire-plugin-sdk/proto/spire/plugin/server/keymanager/v1"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)
// KeyEntry is an entry maintained by the key manager. It pairs the signer
// that backs a key with the public key metadata reported to callers.
type KeyEntry struct {
	// PrivateKey performs signing operations for this entry.
	PrivateKey crypto.Signer

	// PublicKey is embedded so its fields (Id, Type, PkixData, Fingerprint)
	// are promoted onto KeyEntry, e.g. entry.Id.
	*keymanagerv1.PublicKey
}
// Config is a collection of optional callbacks used by Base. Any field may be
// left unset.
type Config struct {
	// Generator is an optional key generator. When nil, New installs a
	// default generator that draws keys from crypto/rand.
	Generator Generator

	// WriteEntries is an optional callback used to persist key entries. When
	// set, it receives the full, id-sorted set of entries every time a key is
	// generated; if it returns an error, the newly generated key is rolled
	// back. When nil, entries are kept only in memory.
	WriteEntries func(ctx context.Context, entries []*KeyEntry) error
}
// Generator is a key generator. Base calls the method that matches the key
// type requested in a GenerateKey call.
type Generator interface {
	// GenerateRSA2048Key returns a new 2048-bit RSA private key.
	GenerateRSA2048Key() (*rsa.PrivateKey, error)
	// GenerateRSA4096Key returns a new 4096-bit RSA private key.
	GenerateRSA4096Key() (*rsa.PrivateKey, error)
	// GenerateEC256Key returns a new ECDSA private key on curve P-256.
	GenerateEC256Key() (*ecdsa.PrivateKey, error)
	// GenerateEC384Key returns a new ECDSA private key on curve P-384.
	GenerateEC384Key() (*ecdsa.PrivateKey, error)
}
// Base is the base KeyManager implementation. It keeps key entries in memory,
// indexed by key id, and relies on its Config for key generation and optional
// persistence.
//
// The zero value is not usable because generating a key writes into the
// entries map. Construct instances with New.
type Base struct {
	keymanagerv1.UnsafeKeyManagerServer

	config Config

	// mu guards entries.
	mu      sync.RWMutex
	entries map[string]*KeyEntry
}
// New creates a new base key manager using the provided config. A nil
// config.Generator is replaced by the default crypto/rand-backed generator.
// The manager starts with no entries.
func New(config Config) *Base {
	generator := config.Generator
	if generator == nil {
		generator = defaultGenerator{}
	}
	config.Generator = generator

	return &Base{
		config:  config,
		entries: map[string]*KeyEntry{},
	}
}
// SetEntries replaces the set of managed entries. Implementations generally
// call it when they are first loaded, to install the initial entries.
//
// The fingerprint of every stored entry is recomputed from its PkixData. The
// map holds the caller's pointers, so this updates the supplied entries in
// place. When several entries share an id, the last one wins.
func (m *Base) SetEntries(entries []*KeyEntry) {
	m.mu.Lock()
	defer m.mu.Unlock()

	byID := entriesMapFromSlice(entries)
	m.entries = byID
	for _, e := range byID {
		// Derive the fingerprint from the key material instead of trusting
		// whatever value the caller supplied.
		e.Fingerprint = makeFingerprint(e.PkixData)
	}
}
// GenerateKey implements the KeyManager RPC of the same name. Failures are
// reported as gRPC status errors whose message starts with
// "failed to generate key: ".
func (m *Base) GenerateKey(ctx context.Context, req *keymanagerv1.GenerateKeyRequest) (*keymanagerv1.GenerateKeyResponse, error) {
	resp, err := m.generateKey(ctx, req)
	if err != nil {
		return resp, prefixStatus(err, "failed to generate key")
	}
	return resp, nil
}
// GetPublicKey implements the KeyManager RPC of the same name. It returns a
// copy of the public key stored under req.KeyId. An unknown id is not an
// error: the response then carries no public key. An empty id is rejected
// with InvalidArgument.
func (m *Base) GetPublicKey(ctx context.Context, req *keymanagerv1.GetPublicKeyRequest) (*keymanagerv1.GetPublicKeyResponse, error) {
	if req.KeyId == "" {
		return nil, status.Error(codes.InvalidArgument, "key id is required")
	}

	m.mu.RLock()
	defer m.mu.RUnlock()

	if entry := m.entries[req.KeyId]; entry != nil {
		return &keymanagerv1.GetPublicKeyResponse{
			PublicKey: clonePublicKey(entry.PublicKey),
		}, nil
	}
	return &keymanagerv1.GetPublicKeyResponse{}, nil
}
// GetPublicKeys implements the KeyManager RPC of the same name. It returns
// copies of all managed public keys, ordered by key id.
func (m *Base) GetPublicKeys(ctx context.Context, req *keymanagerv1.GetPublicKeysRequest) (*keymanagerv1.GetPublicKeysResponse, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	sorted := entriesSliceFromMap(m.entries)
	var publicKeys []*keymanagerv1.PublicKey
	for i := range sorted {
		publicKeys = append(publicKeys, clonePublicKey(sorted[i].PublicKey))
	}
	return &keymanagerv1.GetPublicKeysResponse{PublicKeys: publicKeys}, nil
}
// SignData implements the KeyManager RPC of the same name. Failures are
// reported as gRPC status errors whose message starts with
// "failed to sign data: ". The context is not used.
func (m *Base) SignData(ctx context.Context, req *keymanagerv1.SignDataRequest) (*keymanagerv1.SignDataResponse, error) {
	resp, err := m.signData(req)
	if err != nil {
		return resp, prefixStatus(err, "failed to sign data")
	}
	return resp, nil
}
// generateKey validates req, creates a fresh key of the requested type and
// stores it under req.KeyId, replacing any existing entry with that id.
//
// When Config.WriteEntries is set, the full sorted entry set is passed to it
// while the write lock is held. If it fails, the previous entry (or its
// absence) is restored and the callback's error is returned.
func (m *Base) generateKey(ctx context.Context, req *keymanagerv1.GenerateKeyRequest) (*keymanagerv1.GenerateKeyResponse, error) {
	switch {
	case req.KeyId == "":
		return nil, status.Error(codes.InvalidArgument, "key id is required")
	case req.KeyType == keymanagerv1.KeyType_UNSPECIFIED_KEY_TYPE:
		return nil, status.Error(codes.InvalidArgument, "key type is required")
	}

	// Generate before taking the lock so potentially expensive key
	// generation stays outside the critical section.
	entry, err := m.generateKeyEntry(req.KeyId, req.KeyType)
	if err != nil {
		return nil, err
	}

	m.mu.Lock()
	defer m.mu.Unlock()

	previous, replaced := m.entries[req.KeyId]
	m.entries[req.KeyId] = entry

	if write := m.config.WriteEntries; write != nil {
		if err := write(ctx, entriesSliceFromMap(m.entries)); err != nil {
			// Roll back so the in-memory set matches its state before this call.
			if replaced {
				m.entries[req.KeyId] = previous
			} else {
				delete(m.entries, req.KeyId)
			}
			return nil, err
		}
	}

	return &keymanagerv1.GenerateKeyResponse{
		PublicKey: clonePublicKey(entry.PublicKey),
	}, nil
}
// signData validates req, translates its signer options into
// crypto.SignerOpts and signs req.Data with the private key registered under
// req.KeyId.
//
// req.Data is handed to crypto.Signer.Sign as-is; this function does not hash
// it. Errors are gRPC status errors:
//   - InvalidArgument for a malformed request, including a hash algorithm
//     value that the SDK's HashAlgorithm enum does not define;
//   - NotFound when no key exists for req.KeyId;
//   - Internal when the signing operation itself fails.
func (m *Base) signData(req *keymanagerv1.SignDataRequest) (*keymanagerv1.SignDataResponse, error) {
	if req.KeyId == "" {
		return nil, status.Error(codes.InvalidArgument, "key id is required")
	}
	if req.SignerOpts == nil {
		return nil, status.Error(codes.InvalidArgument, "signer opts is required")
	}

	// The crypto.Hash conversions below rely on the SDK's HashAlgorithm
	// numbering lining up with crypto.Hash values.
	var signerOpts crypto.SignerOpts
	switch opts := req.SignerOpts.(type) {
	case *keymanagerv1.SignDataRequest_HashAlgorithm:
		if opts.HashAlgorithm == keymanagerv1.HashAlgorithm_UNSPECIFIED_HASH_ALGORITHM {
			return nil, status.Error(codes.InvalidArgument, "hash algorithm is required")
		}
		if !isKnownHashAlgorithm(opts.HashAlgorithm) {
			return nil, status.Errorf(codes.InvalidArgument, "unsupported hash algorithm %d", int32(opts.HashAlgorithm))
		}
		signerOpts = crypto.Hash(opts.HashAlgorithm)
	case *keymanagerv1.SignDataRequest_PssOptions:
		if opts.PssOptions == nil {
			return nil, status.Error(codes.InvalidArgument, "PSS options are nil")
		}
		if opts.PssOptions.HashAlgorithm == keymanagerv1.HashAlgorithm_UNSPECIFIED_HASH_ALGORITHM {
			return nil, status.Error(codes.InvalidArgument, "hash algorithm in PSS options is required")
		}
		if !isKnownHashAlgorithm(opts.PssOptions.HashAlgorithm) {
			return nil, status.Errorf(codes.InvalidArgument, "unsupported hash algorithm %d in PSS options", int32(opts.PssOptions.HashAlgorithm))
		}
		signerOpts = &rsa.PSSOptions{
			SaltLength: int(opts.PssOptions.SaltLength),
			Hash:       crypto.Hash(opts.PssOptions.HashAlgorithm),
		}
	default:
		return nil, status.Errorf(codes.InvalidArgument, "unsupported signer opts type %T", opts)
	}

	privateKey, fingerprint, ok := m.getPrivateKeyAndFingerprint(req.KeyId)
	if !ok {
		return nil, status.Errorf(codes.NotFound, "no such key %q", req.KeyId)
	}

	signature, err := privateKey.Sign(rand.Reader, req.Data, signerOpts)
	if err != nil {
		return nil, status.Errorf(codes.Internal, "keypair %q signing operation failed: %v", req.KeyId, err)
	}

	return &keymanagerv1.SignDataResponse{
		Signature:      signature,
		KeyFingerprint: fingerprint,
	}, nil
}

// isKnownHashAlgorithm reports whether alg is a value defined by the SDK's
// HashAlgorithm enum. Enum values decoded from the wire may lie outside the
// defined set. Converting such a value to crypto.Hash and passing it to the
// crypto/rsa signing routines can panic, because crypto.Hash.Size panics on
// unknown hash values.
func isKnownHashAlgorithm(alg keymanagerv1.HashAlgorithm) bool {
	_, ok := keymanagerv1.HashAlgorithm_name[int32(alg)]
	return ok
}
// getPrivateKeyAndFingerprint looks up the entry stored under id while holding
// the read lock. It returns the entry's signer and public key fingerprint, or
// false as the last result when no entry exists for id.
func (m *Base) getPrivateKeyAndFingerprint(id string) (crypto.Signer, string, bool) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	entry := m.entries[id]
	if entry == nil {
		return nil, "", false
	}
	return entry.PrivateKey, entry.Fingerprint, true
}
// generateKeyEntry creates a new private key of keyType using the configured
// Generator and wraps it in a KeyEntry identified by keyID.
//
// An unsupported keyType yields an InvalidArgument status. Errors from the
// Generator are returned unchanged. Failing to build the entry from the new
// key yields an Internal status.
func (m *Base) generateKeyEntry(keyID string, keyType keymanagerv1.KeyType) (e *KeyEntry, err error) {
	gen := m.config.Generator

	var signer crypto.Signer
	switch keyType {
	case keymanagerv1.KeyType_EC_P256:
		signer, err = gen.GenerateEC256Key()
	case keymanagerv1.KeyType_EC_P384:
		signer, err = gen.GenerateEC384Key()
	case keymanagerv1.KeyType_RSA_2048:
		signer, err = gen.GenerateRSA2048Key()
	case keymanagerv1.KeyType_RSA_4096:
		signer, err = gen.GenerateRSA4096Key()
	default:
		return nil, status.Errorf(codes.InvalidArgument, "unable to generate key %q for unknown key type %q", keyID, keyType)
	}
	if err != nil {
		return nil, err
	}

	if e, err = makeKeyEntry(keyID, keyType, signer); err != nil {
		return nil, status.Errorf(codes.Internal, "unable to make key entry for new key %q: %v", keyID, err)
	}
	return e, nil
}
// makeKeyEntry builds a KeyEntry for privateKey. The public half is
// PKIX/DER-encoded with x509.MarshalPKIXPublicKey, and the fingerprint is
// derived from that encoding.
func makeKeyEntry(keyID string, keyType keymanagerv1.KeyType, privateKey crypto.Signer) (*KeyEntry, error) {
	der, err := x509.MarshalPKIXPublicKey(privateKey.Public())
	if err != nil {
		return nil, fmt.Errorf("failed to marshal public key for entry %q: %w", keyID, err)
	}

	publicKey := &keymanagerv1.PublicKey{
		Id:          keyID,
		Type:        keyType,
		PkixData:    der,
		Fingerprint: makeFingerprint(der),
	}
	return &KeyEntry{PrivateKey: privateKey, PublicKey: publicKey}, nil
}
// MakeKeyEntryFromKey wraps an existing private key in a KeyEntry with the
// given id. Only *ecdsa.PrivateKey (P-256/P-384) and *rsa.PrivateKey
// (2048/4096-bit) are accepted. Any other key type or size results in an
// error.
func MakeKeyEntryFromKey(id string, privateKey crypto.PrivateKey) (*KeyEntry, error) {
	var (
		signer  crypto.Signer
		keyType keymanagerv1.KeyType
		err     error
	)
	switch key := privateKey.(type) {
	case *ecdsa.PrivateKey:
		signer = key
		keyType, err = ecdsaKeyType(key)
	case *rsa.PrivateKey:
		signer = key
		keyType, err = rsaKeyType(key)
	default:
		return nil, fmt.Errorf("unexpected private key type %T for key %q", privateKey, id)
	}
	if err != nil {
		return nil, fmt.Errorf("unable to make key entry for key %q: %w", id, err)
	}
	return makeKeyEntry(id, keyType, signer)
}
// rsaKeyType maps an RSA private key to the KeyType that matches its modulus
// size. Only 2048- and 4096-bit moduli are supported. Any other size yields
// KeyType_UNSPECIFIED_KEY_TYPE and an error.
func rsaKeyType(privateKey *rsa.PrivateKey) (keymanagerv1.KeyType, error) {
	modulusBits := privateKey.N.BitLen()
	if modulusBits == 2048 {
		return keymanagerv1.KeyType_RSA_2048, nil
	}
	if modulusBits == 4096 {
		return keymanagerv1.KeyType_RSA_4096, nil
	}
	return keymanagerv1.KeyType_UNSPECIFIED_KEY_TYPE, fmt.Errorf("no RSA key type for key bit length: %d", modulusBits)
}
// ecdsaKeyType maps an ECDSA private key to the KeyType for its curve. Only
// P-256 and P-384 are supported. Any other curve yields
// KeyType_UNSPECIFIED_KEY_TYPE and an error naming the curve.
func ecdsaKeyType(privateKey *ecdsa.PrivateKey) (keymanagerv1.KeyType, error) {
	switch curve := privateKey.Curve; curve {
	case elliptic.P256():
		return keymanagerv1.KeyType_EC_P256, nil
	case elliptic.P384():
		return keymanagerv1.KeyType_EC_P384, nil
	default:
		return keymanagerv1.KeyType_UNSPECIFIED_KEY_TYPE, fmt.Errorf("no EC key type for EC curve: %s",
			curve.Params().Name)
	}
}
// defaultGenerator is the Generator that New installs when Config.Generator is
// nil. Every key is generated from crypto/rand.Reader.
type defaultGenerator struct{}

// GenerateRSA2048Key returns a fresh 2048-bit RSA key.
func (defaultGenerator) GenerateRSA2048Key() (*rsa.PrivateKey, error) {
	return rsa.GenerateKey(rand.Reader, 2048)
}

// GenerateRSA4096Key returns a fresh 4096-bit RSA key.
func (defaultGenerator) GenerateRSA4096Key() (*rsa.PrivateKey, error) {
	return rsa.GenerateKey(rand.Reader, 4096)
}

// GenerateEC256Key returns a fresh ECDSA key on curve P-256.
func (defaultGenerator) GenerateEC256Key() (*ecdsa.PrivateKey, error) {
	return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
}

// GenerateEC384Key returns a fresh ECDSA key on curve P-384.
func (defaultGenerator) GenerateEC384Key() (*ecdsa.PrivateKey, error) {
	return ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
}
// entriesSliceFromMap returns the values of entriesMap as a slice sorted by
// key id, giving callers a deterministic order. It returns nil when the map
// is empty or nil.
func entriesSliceFromMap(entriesMap map[string]*KeyEntry) (entriesSlice []*KeyEntry) {
	if len(entriesMap) == 0 {
		return nil
	}
	entriesSlice = make([]*KeyEntry, 0, len(entriesMap))
	for _, e := range entriesMap {
		entriesSlice = append(entriesSlice, e)
	}
	SortKeyEntries(entriesSlice)
	return entriesSlice
}
// entriesMapFromSlice indexes entriesSlice by key id. When several entries
// share an id, the one appearing last in the slice wins.
func entriesMapFromSlice(entriesSlice []*KeyEntry) map[string]*KeyEntry {
	byID := make(map[string]*KeyEntry, len(entriesSlice))
	for i := range entriesSlice {
		byID[entriesSlice[i].Id] = entriesSlice[i]
	}
	return byID
}
// clonePublicKey returns a deep copy of publicKey (via proto.Clone), so that
// responses handed to callers do not alias the manager's stored entries.
func clonePublicKey(publicKey *keymanagerv1.PublicKey) *keymanagerv1.PublicKey {
	cloned := proto.Clone(publicKey)
	return cloned.(*keymanagerv1.PublicKey)
}
// makeFingerprint returns the lowercase hex encoding of the SHA-256 digest of
// pkixData.
func makeFingerprint(pkixData []byte) string {
	digest := sha256.Sum256(pkixData)
	fingerprint := hex.EncodeToString(digest[:])
	return fingerprint
}
// SortKeyEntries sorts entries in place in ascending order of key id. The sort
// is not stable, so the relative order of entries that share an id is
// unspecified.
func SortKeyEntries(entries []*KeyEntry) {
	byID := func(a, b int) bool {
		return entries[a].Id < entries[b].Id
	}
	sort.Slice(entries, byID)
}
// prefixStatus returns err unchanged when it is nil or carries codes.OK.
// Otherwise it returns a new status error with the same code, whose message
// is prefix + ": " + the original message. Errors that are not gRPC statuses
// are treated as codes.Unknown (per status.Convert). Status details are not
// carried over.
func prefixStatus(err error, prefix string) error {
	st := status.Convert(err)
	code := st.Code()
	if code == codes.OK {
		return err
	}
	return status.Error(code, prefix+": "+st.Message())
}