/
kms_helper.go
182 lines (165 loc) · 4.42 KB
/
kms_helper.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
package aws_encryption_sdk
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"encoding/binary"
"errors"
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kms"
"golang.org/x/crypto/hkdf"
)
// KmsHelper wraps an AWS KMS client and exposes decryption helpers built on it.
type KmsHelper struct {
// client is the KMS API client on which kmsDecrypt issues Decrypt calls.
client *kms.KMS
}
// NewKmsHelper returns a KmsHelper whose KMS client is bound to region.
//
// When assumedRole is non-empty, the client is configured with STS
// assume-role credentials for that role; otherwise it uses the session as-is.
// Session creation goes through session.Must, so a session error panics.
func NewKmsHelper(region string, assumedRole string) *KmsHelper {
	sess := session.Must(session.NewSession(aws.NewConfig().WithRegion(region)))

	if assumedRole == "" {
		return &KmsHelper{client: kms.New(sess)}
	}

	roleCreds := stscreds.NewCredentials(sess, assumedRole)
	return &KmsHelper{
		client: kms.New(sess, &aws.Config{Credentials: roleCreds}),
	}
}
// decryptDataKeys decrypts the first m.EncDataKeyCount encrypted data keys of
// the message through KMS, passing the message along so its encryption context
// is attached to each request.
//
// It returns the plaintext keys in message order. If the declared key count is
// larger than the number of keys present in m.EncDataKeys, an error is returned
// instead of indexing out of range. The first KMS failure aborts the whole
// operation and is returned unchanged.
func (k *KmsHelper) decryptDataKeys(m *Message) ([][]byte, error) {
	// Guard against a header whose declared count disagrees with the decoded
	// key list; indexing past the end would otherwise panic.
	if int(m.EncDataKeyCount) > len(m.EncDataKeys) {
		return nil, errors.New("encrypted data key count exceeds number of decoded keys")
	}

	keys := make([][]byte, 0, int(m.EncDataKeyCount))
	for i := uint16(0); i < m.EncDataKeyCount; i++ {
		plain, err := k.kmsDecrypt(m.EncDataKeys[i].EncKeyData, m)
		if err != nil {
			return nil, err
		}
		keys = append(keys, plain)
	}
	return keys, nil
}
// getDerivedKey returns the key to use for decrypting message content.
//
// If the message's algorithm has no hash function, key is returned unchanged.
// Otherwise an HKDF reader is created over key (nil salt) with an info value
// made of the big-endian algorithm ID followed by the message ID, and
// m.Algorithm.DataKeyLength bytes are read from it as the derived key.
func (k *KmsHelper) getDerivedKey(key []byte, m *Message) ([]byte, error) {
	if m.Algorithm.HashFunc == nil {
		// No key derivation for this algorithm: use the data key directly.
		return key, nil
	}

	// HKDF info = algorithm ID (big-endian) || message ID.
	var info bytes.Buffer
	if err := binary.Write(&info, binary.BigEndian, m.Algorithm.Id); err != nil {
		return nil, err
	}
	if _, err := info.Write(m.MessageId[:]); err != nil {
		return nil, err
	}

	kdf := hkdf.New(m.Algorithm.HashFunc, key, nil, info.Bytes())
	derived := make([]byte, m.Algorithm.DataKeyLength)
	if _, err := kdf.Read(derived); err != nil {
		return nil, err
	}
	return derived, nil
}
// buildContentAAD assembles the additional authenticated data passed to the
// AEAD when opening a frame.
//
// The result is the concatenation, in order, of: the message ID, the frame's
// AAD content string, the frame sequence number (big-endian), and the frame's
// encrypted content length as a big-endian uint64.
func (k *KmsHelper) buildContentAAD(m *Message, f *Frame) ([]byte, error) {
	var aad bytes.Buffer

	if _, err := aad.Write(m.MessageId[:]); err != nil {
		return nil, err
	}
	if _, err := aad.Write(f.AADContentString); err != nil {
		return nil, err
	}
	if err := binary.Write(&aad, binary.BigEndian, f.SeqNumber); err != nil {
		return nil, err
	}

	// The length is always serialized as 8 bytes regardless of the field's type.
	contentLen := uint64(f.EncContentLength)
	if err := binary.Write(&aad, binary.BigEndian, contentLen); err != nil {
		return nil, err
	}

	return aad.Bytes(), nil
}
// kmsDecrypt sends data to KMS as a ciphertext blob and returns the plaintext
// from the response.
//
// When m is non-nil, the message's encryption context is attached to the
// request; when m is nil, no encryption context is sent. Any error from the
// KMS call is returned unchanged.
func (k *KmsHelper) kmsDecrypt(data []byte, m *Message) ([]byte, error) {
	input := &kms.DecryptInput{
		CiphertextBlob: data,
	}
	if m != nil {
		encContext := make(map[string]*string, len(m.EncContext))
		for key, value := range m.EncContext {
			// aws.String allocates a fresh string per entry. Storing &value
			// would make every entry alias the single range variable on Go
			// versions before 1.22, so all values would equal the last one.
			encContext[key] = aws.String(value)
		}
		input.EncryptionContext = encContext
	}

	result, err := k.client.Decrypt(input)
	if err != nil {
		return nil, err
	}
	return result.Plaintext, nil
}
// Decrypt is the decryption entrypoint.
//
// It first tries to decrypt data directly with KMS. If that succeeds, the KMS
// plaintext is returned. If KMS fails with an error whose text starts with the
// InvalidCiphertextException code, data is instead decoded as a framed message:
// its encrypted data keys are decrypted through KMS, a content key is derived
// from the first data key, and every frame is opened with the message's
// cipher (AES) and mode (GCM). The frame plaintexts are concatenated in order.
//
// Any other KMS error, a decode error, an unsupported algorithm type or mode,
// a message with frames but no data keys, a frame IV of the wrong length, or
// an authentication failure on any frame results in a nil plaintext and a
// non-nil error.
func (k *KmsHelper) Decrypt(data []byte) ([]byte, error) {
	// Try simple KMS decryption first.
	plaintext, err := k.kmsDecrypt(data, nil)
	if err == nil {
		return plaintext, nil
	}
	if !strings.HasPrefix(err.Error(), kms.ErrCodeInvalidCiphertextException) {
		// Unknown error: only an invalid-ciphertext failure falls through to
		// the framed-message path.
		return nil, err
	}

	message := NewMessage()
	if err := message.Decode(bytes.NewReader(data)); err != nil {
		return nil, err
	}
	dataKeys, err := k.decryptDataKeys(message)
	if err != nil {
		return nil, err
	}

	plaintext = make([]byte, 0)
	if len(message.Frames) == 0 {
		// Nothing to decrypt; no key or cipher setup is needed.
		return plaintext, nil
	}
	if len(dataKeys) == 0 {
		// Without this check dataKeys[0] below would panic.
		return nil, errors.New("message contains no encrypted data keys")
	}

	// The content key, block cipher and AEAD depend only on the message, so
	// they are built once rather than once per frame.
	// TODO: support multiple data keys
	contentKey, err := k.getDerivedKey(dataKeys[0], message)
	if err != nil {
		return nil, err
	}

	var block cipher.Block
	switch message.Algorithm.Type {
	case ALGORITHM_TYPE_AES:
		block, err = aes.NewCipher(contentKey)
		if err != nil {
			return nil, err
		}
	default:
		return nil, errors.New("Unknown encryption algorithm type")
	}

	var aead cipher.AEAD
	switch message.Algorithm.Mode {
	case ALGORITHM_MODE_GCM:
		aead, err = cipher.NewGCM(block)
		if err != nil {
			return nil, err
		}
	default:
		return nil, errors.New("Unknown encryption algorithm mode")
	}

	for _, frame := range message.Frames {
		// AEAD.Open panics on a nonce of the wrong size; reject it as an error.
		if len(frame.IV) != aead.NonceSize() {
			return nil, errors.New("frame IV length does not match cipher nonce size")
		}

		// The AEAD expects the auth tag appended to the ciphertext. Build the
		// combined input in a fresh buffer: appending directly to
		// frame.EncContent could write into spare capacity of its backing
		// array and corrupt data that shares it.
		sealed := make([]byte, 0, len(frame.EncContent)+len(frame.AuthTag))
		sealed = append(sealed, frame.EncContent...)
		sealed = append(sealed, frame.AuthTag...)

		contentAAD, err := k.buildContentAAD(message, &frame)
		if err != nil {
			return nil, err
		}
		framePlaintext, err := aead.Open(nil, frame.IV, sealed, contentAAD)
		if err != nil {
			return nil, err
		}

		// Append frame plaintext to overall plaintext.
		plaintext = append(plaintext, framePlaintext...)
	}
	return plaintext, nil
}