205 changes: 96 additions & 109 deletions service/s3/s3crypto/kms_key_handler.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package s3crypto

import (
"fmt"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/kms"
Expand All @@ -12,16 +10,15 @@ import (
const (
// KMSWrap is a constant used during decryption to build a KMS key handler.
KMSWrap = "kms"

// KMSContextWrap is a constant used during decryption to build a kms+context key handler
KMSContextWrap = "kms+context"
)

// kmsKeyHandler will make calls to KMS to get the masterkey
type kmsKeyHandler struct {
kms kmsiface.KMSAPI
cmkID *string
withContext bool
kms kmsiface.KMSAPI
cmkID *string

// useProvidedCMK is toggled when using `kms` key wrapper with V2 client
useProvidedCMK bool

CipherData
}
Expand All @@ -30,116 +27,118 @@ type kmsKeyHandler struct {
// description.
//
// Example:
// sess := session.New(&aws.Config{})
// sess := session.Must(session.NewSession())
// cmkID := "arn to key"
// matdesc := s3crypto.MaterialDescription{}
// handler := s3crypto.NewKMSKeyGenerator(kms.New(sess), cmkID)
//
// deprecated: See NewKMSContextKeyGenerator
// deprecated: This feature is in maintenance mode, no new updates will be released. Please see https://docs.aws.amazon.com/general/latest/gr/aws_sdk_cryptography.html for more information.
func NewKMSKeyGenerator(kmsClient kmsiface.KMSAPI, cmkID string) CipherDataGenerator {
return NewKMSKeyGeneratorWithMatDesc(kmsClient, cmkID, MaterialDescription{})
}

// NewKMSContextKeyGenerator builds a new kms+context key provider using the customer key ID and material
// description.
//
// Example:
// sess := session.New(&aws.Config{})
// cmkID := "arn to key"
// matdesc := s3crypto.MaterialDescription{}
// handler := s3crypto.NewKMSContextKeyGenerator(kms.New(sess), cmkID)
func NewKMSContextKeyGenerator(client kmsiface.KMSAPI, cmkID string) CipherDataGeneratorWithCEKAlg {
return NewKMSContextKeyGeneratorWithMatDesc(client, cmkID, MaterialDescription{})
}

func newKMSKeyHandler(client kmsiface.KMSAPI, cmkID string, withContext bool, matdesc MaterialDescription) *kmsKeyHandler {
func newKMSKeyHandler(client kmsiface.KMSAPI, cmkID string, matdesc MaterialDescription) *kmsKeyHandler {
// These values are read only making them thread safe
kp := &kmsKeyHandler{
kms: client,
cmkID: &cmkID,
withContext: withContext,
kms: client,
cmkID: &cmkID,
}

if matdesc == nil {
matdesc = MaterialDescription{}
}

// These values are read only making them thread safe
if kp.withContext {
kp.CipherData.WrapAlgorithm = KMSContextWrap
} else {
matdesc["kms_cmk_id"] = &cmkID
kp.CipherData.WrapAlgorithm = KMSWrap
}
matdesc["kms_cmk_id"] = &cmkID

kp.CipherData.WrapAlgorithm = KMSWrap
kp.CipherData.MaterialDescription = matdesc

return kp
}

// NewKMSKeyGeneratorWithMatDesc builds a new KMS key provider using the customer key ID and material
// description.
//
// Example:
// sess := session.New(&aws.Config{})
// sess := session.Must(session.NewSession())
// cmkID := "arn to key"
// matdesc := s3crypto.MaterialDescription{}
// handler := s3crypto.NewKMSKeyGeneratorWithMatDesc(kms.New(sess), cmkID, matdesc)
//
// deprecated: See NewKMSContextKeyGeneratorWithMatDesc
// deprecated: This feature is in maintenance mode, no new updates will be released. Please see https://docs.aws.amazon.com/general/latest/gr/aws_sdk_cryptography.html for more information.
func NewKMSKeyGeneratorWithMatDesc(kmsClient kmsiface.KMSAPI, cmkID string, matdesc MaterialDescription) CipherDataGenerator {
return newKMSKeyHandler(kmsClient, cmkID, false, matdesc)
}

// NewKMSContextKeyGeneratorWithMatDesc builds a new kms+context key provider using the customer key ID and material
// description.
//
// Example:
// sess := session.New(&aws.Config{})
// cmkID := "arn to key"
// matdesc := s3crypto.MaterialDescription{}
// handler := s3crypto.NewKMSKeyGeneratorWithMatDesc(kms.New(sess), cmkID, matdesc)
func NewKMSContextKeyGeneratorWithMatDesc(kmsClient kmsiface.KMSAPI, cmkID string, matdesc MaterialDescription) CipherDataGeneratorWithCEKAlg {
return newKMSKeyHandler(kmsClient, cmkID, true, matdesc)
return newKMSKeyHandler(kmsClient, cmkID, matdesc)
}

// NewKMSWrapEntry builds returns a new KMS key provider and its decrypt handler.
//
// Example:
// sess := session.New(&aws.Config{})
// sess := session.Must(session.NewSession())
// customKMSClient := kms.New(sess)
// decryptHandler := s3crypto.NewKMSWrapEntry(customKMSClient)
//
// svc := s3crypto.NewDecryptionClient(sess, func(svc *s3crypto.DecryptionClient) {
// svc.WrapRegistry[s3crypto.KMSWrap] = decryptHandler
// }))
//
// deprecated: See NewKMSContextWrapEntry
// deprecated: This feature is in maintenance mode, no new updates will be released. Please see https://docs.aws.amazon.com/general/latest/gr/aws_sdk_cryptography.html for more information.
func NewKMSWrapEntry(kmsClient kmsiface.KMSAPI) WrapEntry {
// These values are read only making them thread safe
kp := &kmsKeyHandler{
kms: kmsClient,
}

kp := newKMSWrapEntry(kmsClient)
return kp.decryptHandler
}

// NewKMSContextWrapEntry builds returns a new KMS key provider and its decrypt handler.
// RegisterKMSWrapWithCMK registers the `kms` wrapping algorithm to the given WrapRegistry. The wrapper will be
// configured to call KMS Decrypt with the provided CMK.
//
// Example:
// sess := session.New(&aws.Config{})
// customKMSClient := kms.New(sess)
// decryptHandler := s3crypto.NewKMSContextWrapEntry(customKMSClient)
// sess := session.Must(session.NewSession())
// cr := s3crypto.NewCryptoRegistry()
// if err := s3crypto.RegisterKMSWrapWithCMK(cr, kms.New(sess), "cmkId"); err != nil {
// panic(err) // handle error
// }
//
// svc := s3crypto.NewDecryptionClient(sess, func(svc *s3crypto.DecryptionClient) {
// svc.WrapRegistry[s3crypto.KMSContextWrap] = decryptHandler
// }))
func NewKMSContextWrapEntry(kmsClient kmsiface.KMSAPI) WrapEntry {
// deprecated: This feature is in maintenance mode, no new updates will be released. Please see https://docs.aws.amazon.com/general/latest/gr/aws_sdk_cryptography.html for more information.
func RegisterKMSWrapWithCMK(registry *CryptoRegistry, client kmsiface.KMSAPI, cmkID string) error {
if registry == nil {
return errNilCryptoRegistry
}
return registry.AddWrap(KMSWrap, newKMSWrapEntryWithCMK(client, cmkID))
}

// RegisterKMSWrapWithAnyCMK registers the `kms` wrapping algorithm to the given WrapRegistry. The wrapper will be
// configured to call KMS Decrypt without providing a CMK.
//
// Example:
// sess := session.Must(session.NewSession())
// cr := s3crypto.NewCryptoRegistry()
// if err := s3crypto.RegisterKMSWrapWithAnyCMK(cr, kms.New(sess)); err != nil {
// panic(err) // handle error
// }
//
// deprecated: This feature is in maintenance mode, no new updates will be released. Please see https://docs.aws.amazon.com/general/latest/gr/aws_sdk_cryptography.html for more information.
func RegisterKMSWrapWithAnyCMK(registry *CryptoRegistry, client kmsiface.KMSAPI) error {
if registry == nil {
return errNilCryptoRegistry
}
return registry.AddWrap(KMSWrap, NewKMSWrapEntry(client))
}

// newKMSWrapEntryWithCMK builds returns a new KMS key provider and its decrypt handler. The wrap entry will be configured
// to only attempt to decrypt the data key using the provided CMK.
func newKMSWrapEntryWithCMK(kmsClient kmsiface.KMSAPI, cmkID string) WrapEntry {
kp := newKMSWrapEntry(kmsClient)
kp.useProvidedCMK = true
kp.cmkID = &cmkID
return kp.decryptHandler
}

func newKMSWrapEntry(kmsClient kmsiface.KMSAPI) *kmsKeyHandler {
// These values are read only making them thread safe
kp := &kmsKeyHandler{
kms: kmsClient,
withContext: true,
kms: kmsClient,
}

return kp.decryptHandler
return kp
}

// decryptHandler initializes a KMS keyprovider with a material description. This
Expand All @@ -151,17 +150,14 @@ func (kp kmsKeyHandler) decryptHandler(env Envelope) (CipherDataDecrypter, error
return nil, err
}

cmkID, ok := m["kms_cmk_id"]
if !kp.withContext && !ok {
_, ok := m["kms_cmk_id"]
if !ok {
return nil, awserr.New("MissingCMKIDError", "Material description is missing CMK ID", nil)
}

kp.CipherData.MaterialDescription = m
kp.cmkID = cmkID
kp.WrapAlgorithm = KMSWrap
if kp.withContext {
kp.WrapAlgorithm = KMSContextWrap
}

return &kp, nil
}

Expand All @@ -172,12 +168,18 @@ func (kp *kmsKeyHandler) DecryptKey(key []byte) ([]byte, error) {

// DecryptKeyWithContext makes a call to KMS to decrypt the key with request context.
func (kp *kmsKeyHandler) DecryptKeyWithContext(ctx aws.Context, key []byte) ([]byte, error) {
out, err := kp.kms.DecryptWithContext(ctx,
&kms.DecryptInput{
EncryptionContext: kp.CipherData.MaterialDescription,
CiphertextBlob: key,
GrantTokens: []*string{},
})
in := &kms.DecryptInput{
EncryptionContext: kp.MaterialDescription,
CiphertextBlob: key,
GrantTokens: []*string{},
}

// useProvidedCMK will be true if a constructor was used with the new V2 client
if kp.useProvidedCMK {
in.KeyId = kp.cmkID
}

out, err := kp.kms.DecryptWithContext(ctx, in)
if err != nil {
return nil, err
}
Expand All @@ -190,31 +192,14 @@ func (kp *kmsKeyHandler) GenerateCipherData(keySize, ivSize int) (CipherData, er
return kp.GenerateCipherDataWithContext(aws.BackgroundContext(), keySize, ivSize)
}

func (kp kmsKeyHandler) GenerateCipherDataWithCEKAlg(keySize, ivSize int, cekAlgorithm string) (CipherData, error) {
return kp.GenerateCipherDataWithCEKAlgWithContext(aws.BackgroundContext(), keySize, ivSize, cekAlgorithm)
}

// GenerateCipherDataWithContext makes a call to KMS to generate a data key,
// Upon making the call, it also sets the encrypted key.
func (kp *kmsKeyHandler) GenerateCipherDataWithContext(ctx aws.Context, keySize, ivSize int) (CipherData, error) {
return kp.GenerateCipherDataWithCEKAlgWithContext(ctx, keySize, ivSize, "")
}

func (kp kmsKeyHandler) GenerateCipherDataWithCEKAlgWithContext(ctx aws.Context, keySize int, ivSize int, cekAlgorithm string) (CipherData, error) {
md := kp.CipherData.MaterialDescription

wrapAlgorithm := KMSWrap
if kp.withContext {
wrapAlgorithm = KMSContextWrap
if len(cekAlgorithm) == 0 {
return CipherData{}, fmt.Errorf("CEK algorithm identifier must not be empty")
}
md["aws:"+cekAlgorithmHeader] = &cekAlgorithm
}
cd := kp.CipherData.Clone()

out, err := kp.kms.GenerateDataKeyWithContext(ctx,
&kms.GenerateDataKeyInput{
EncryptionContext: md,
EncryptionContext: cd.MaterialDescription,
KeyId: kp.cmkID,
KeySpec: aws.String("AES_256"),
})
Expand All @@ -227,19 +212,21 @@ func (kp kmsKeyHandler) GenerateCipherDataWithCEKAlgWithContext(ctx aws.Context,
return CipherData{}, err
}

cd := CipherData{
Key: out.Plaintext,
IV: iv,
WrapAlgorithm: wrapAlgorithm,
MaterialDescription: md,
EncryptedKey: out.CiphertextBlob,
}
cd.Key = out.Plaintext
cd.IV = iv
cd.EncryptedKey = out.CiphertextBlob

return cd, nil
}

func (kp *kmsKeyHandler) isUsingDeprecatedFeatures() error {
if !kp.withContext {
return errDeprecatedCipherDataGenerator
}
return nil
func (kp kmsKeyHandler) isAWSFixture() bool {
return true
}

var (
_ CipherDataGenerator = (*kmsKeyHandler)(nil)
_ CipherDataGeneratorWithContext = (*kmsKeyHandler)(nil)
_ CipherDataDecrypter = (*kmsKeyHandler)(nil)
_ CipherDataDecrypterWithContext = (*kmsKeyHandler)(nil)
_ awsFixture = (*kmsKeyHandler)(nil)
)
142 changes: 55 additions & 87 deletions service/s3/s3crypto/kms_key_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
Expand All @@ -17,15 +16,15 @@ import (
"github.com/aws/aws-sdk-go/service/kms"
)

func TestBuildKMSEncryptHandler(t *testing.T) {
func TestNewKMSKeyGenerator(t *testing.T) {
svc := kms.New(unit.Session)
handler := NewKMSKeyGenerator(svc, "testid")
if handler == nil {
t.Error("expected non-nil handler")
}
}

func TestBuildKMSEncryptHandlerWithMatDesc(t *testing.T) {
func TestNewKMSKeyGeneratorWithMatDesc(t *testing.T) {
svc := kms.New(unit.Session)
handler := NewKMSKeyGeneratorWithMatDesc(svc, "testid", MaterialDescription{
"Testing": aws.String("123"),
Expand All @@ -45,7 +44,7 @@ func TestBuildKMSEncryptHandlerWithMatDesc(t *testing.T) {
}
}

func TestKMSGenerateCipherData(t *testing.T) {
func TestKmsKeyHandler_GenerateCipherData(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, `{"CiphertextBlob":"AQEDAHhqBCCY1MSimw8gOGcUma79cn4ANvTtQyv9iuBdbcEF1QAAAH4wfAYJKoZIhvcNAQcGoG8wbQIBADBoBgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDJ6IcN5E4wVbk38MNAIBEIA7oF1E3lS7FY9DkoxPc/UmJsEwHzL82zMqoLwXIvi8LQHr8If4Lv6zKqY8u0+JRgSVoqCvZDx3p8Cn6nM=","KeyId":"arn:aws:kms:us-west-2:042062605278:key/c80a5cdb-8d09-4f9f-89ee-df01b2e3870a","Plaintext":"6tmyz9JLBE2yIuU7iXpArqpDVle172WSmxjcO6GNT7E="}`)
}))
Expand Down Expand Up @@ -77,10 +76,19 @@ func TestKMSGenerateCipherData(t *testing.T) {
}
}

func TestKMSDecrypt(t *testing.T) {
func TestKmsKeyHandler_DecryptKey(t *testing.T) {
key, _ := hex.DecodeString("31bdadd96698c204aa9ce1448ea94ae1fb4a9a0b3c9d773b51bb1822666b8f22")
keyB64 := base64.URLEncoding.EncodeToString(key)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Errorf("expected no error, got %v", err)
w.WriteHeader(500)
return
}
if bytes.Contains(body, []byte(`"KeyId":"test"`)) {
t.Errorf("expected CMK to not be sent")
}
fmt.Fprintln(w, fmt.Sprintf("%s%s%s", `{"KeyId":"test-key-id","Plaintext":"`, keyB64, `"}`))
}))
defer ts.Close()
Expand All @@ -92,7 +100,7 @@ func TestKMSDecrypt(t *testing.T) {
S3ForcePathStyle: aws.Bool(true),
Region: aws.String("us-west-2"),
})
handler, err := (kmsKeyHandler{kms: kms.New(sess)}).decryptHandler(Envelope{MatDesc: `{"kms_cmk_id":"test"}`})
handler, err := (kmsKeyHandler{kms: kms.New(sess)}).decryptHandler(Envelope{WrapAlg: KMSWrap, MatDesc: `{"kms_cmk_id":"test"}`})
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
Expand All @@ -107,37 +115,22 @@ func TestKMSDecrypt(t *testing.T) {
}
}

func TestKMSContextGenerateCipherData(t *testing.T) {
func TestKmsKeyHandler_DecryptKey_WithCMK(t *testing.T) {
key, _ := hex.DecodeString("31bdadd96698c204aa9ce1448ea94ae1fb4a9a0b3c9d773b51bb1822666b8f22")
keyB64 := base64.URLEncoding.EncodeToString(key)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
bodyBytes, err := ioutil.ReadAll(r.Body)
body, err := ioutil.ReadAll(r.Body)
if err != nil {
w.WriteHeader(500)
return
}
var body map[string]interface{}
err = json.Unmarshal(bodyBytes, &body)
if err != nil {
w.WriteHeader(500)
return
}

md, ok := body["EncryptionContext"].(map[string]interface{})
if !ok {
t.Errorf("expected no error, got %v", err)
w.WriteHeader(500)
return
}

exEncContext := map[string]interface{}{
"aws:" + cekAlgorithmHeader: "cekAlgValue",
}

if e, a := exEncContext, md; !reflect.DeepEqual(e, a) {
w.WriteHeader(500)
t.Errorf("expected %v, got %v", e, a)
return
if !bytes.Contains(body, []byte(`"KeyId":"thisKey"`)) {
t.Errorf("expected CMK to be sent")
}

fmt.Fprintln(w, `{"CiphertextBlob":"AQEDAHhqBCCY1MSimw8gOGcUma79cn4ANvTtQyv9iuBdbcEF1QAAAH4wfAYJKoZIhvcNAQcGoG8wbQIBADBoBgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDJ6IcN5E4wVbk38MNAIBEIA7oF1E3lS7FY9DkoxPc/UmJsEwHzL82zMqoLwXIvi8LQHr8If4Lv6zKqY8u0+JRgSVoqCvZDx3p8Cn6nM=","KeyId":"arn:aws:kms:us-west-2:042062605278:key/c80a5cdb-8d09-4f9f-89ee-df01b2e3870a","Plaintext":"6tmyz9JLBE2yIuU7iXpArqpDVle172WSmxjcO6GNT7E="}`)
fmt.Fprintln(w, fmt.Sprintf("%s%s%s", `{"KeyId":"test-key-id","Plaintext":"`, keyB64, `"}`))
}))
defer ts.Close()

Expand All @@ -148,79 +141,54 @@ func TestKMSContextGenerateCipherData(t *testing.T) {
S3ForcePathStyle: aws.Bool(true),
Region: aws.String("us-west-2"),
})

svc := kms.New(sess)
handler := NewKMSContextKeyGenerator(svc, "testid")

keySize := 32
ivSize := 16

cd, err := handler.GenerateCipherDataWithCEKAlg(keySize, ivSize, "cekAlgValue")
handler, err := newKMSWrapEntryWithCMK(kms.New(sess), "thisKey")(Envelope{WrapAlg: KMSWrap, MatDesc: `{"kms_cmk_id":"test"}`})
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if keySize != len(cd.Key) {
t.Errorf("expected %d, but received %d", keySize, len(cd.Key))

plaintextKey, err := handler.DecryptKey([]byte{1, 2, 3, 4})
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if ivSize != len(cd.IV) {
t.Errorf("expected %d, but received %d", ivSize, len(cd.IV))
if !bytes.Equal(key, plaintextKey) {
t.Errorf("expected %v, but received %v", key, plaintextKey)
}
}

func TestKMSContextDecrypt(t *testing.T) {
key, _ := hex.DecodeString("31bdadd96698c204aa9ce1448ea94ae1fb4a9a0b3c9d773b51bb1822666b8f22")
keyB64 := base64.URLEncoding.EncodeToString(key)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
bodyBytes, err := ioutil.ReadAll(r.Body)
if err != nil {
w.WriteHeader(500)
return
}
var body map[string]interface{}
err = json.Unmarshal(bodyBytes, &body)
if err != nil {
w.WriteHeader(500)
return
}
func TestRegisterKMSWrapWithAnyCMK(t *testing.T) {
kmsClient := kms.New(unit.Session.Copy())

md, ok := body["EncryptionContext"].(map[string]interface{})
if !ok {
w.WriteHeader(500)
return
}
cr := NewCryptoRegistry()
if err := RegisterKMSWrapWithAnyCMK(cr, kmsClient); err != nil {
t.Errorf("expected no error, got %v", err)
}

exEncContext := map[string]interface{}{
"aws:" + cekAlgorithmHeader: "cekAlgValue",
}
if wrap, ok := cr.GetWrap(KMSWrap); !ok {
t.Errorf("expected wrapped to be present")
} else if wrap == nil {
t.Errorf("expected wrap to not be nil")
}

if e, a := exEncContext, md; !reflect.DeepEqual(e, a) {
w.WriteHeader(500)
t.Errorf("expected %v, got %v", e, a)
return
}
if err := RegisterKMSWrapWithCMK(cr, kmsClient, "test-key-id"); err == nil {
t.Error("expected error, got none")
}
}

fmt.Fprintln(w, fmt.Sprintf("%s%s%s", `{"KeyId":"test-key-id","Plaintext":"`, keyB64, `"}`))
}))
defer ts.Close()
func TestRegisterKMSWrapWithCMK(t *testing.T) {
kmsClient := kms.New(unit.Session.Copy())

sess := unit.Session.Copy(&aws.Config{
MaxRetries: aws.Int(0),
Endpoint: aws.String(ts.URL),
DisableSSL: aws.Bool(true),
S3ForcePathStyle: aws.Bool(true),
Region: aws.String("us-west-2"),
})
handler, err := NewKMSContextWrapEntry(kms.New(sess))(Envelope{MatDesc: `{"aws:x-amz-cek-alg": "cekAlgValue"}`})
if err != nil {
t.Errorf("expected no error, but received %v", err)
cr := NewCryptoRegistry()
if err := RegisterKMSWrapWithCMK(cr, kmsClient, "cmkId"); err != nil {
t.Errorf("expected no error, got %v", err)
}

plaintextKey, err := handler.DecryptKey([]byte{1, 2, 3, 4})
if err != nil {
t.Errorf("expected no error, but received %v", err)
if wrap, ok := cr.GetWrap(KMSWrap); !ok {
t.Errorf("expected wrapped to be present")
} else if wrap == nil {
t.Errorf("expected wrap to not be nil")
}

if !bytes.Equal(key, plaintextKey) {
t.Errorf("expected %v, but received %v", key, plaintextKey)
if err := RegisterKMSWrapWithAnyCMK(cr, kmsClient); err == nil {
t.Error("expected error, got none")
}
}
20 changes: 20 additions & 0 deletions service/s3/s3crypto/mat_desc.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,18 @@ import (
// key has been used.
type MaterialDescription map[string]*string

// Clone returns a copy of the MaterialDescription
func (md MaterialDescription) Clone() (clone MaterialDescription) {
if md == nil {
return nil
}
clone = make(MaterialDescription, len(md))
for k, v := range md {
clone[k] = copyPtrString(v)
}
return clone
}

func (md *MaterialDescription) encodeDescription() ([]byte, error) {
v, err := json.Marshal(&md)
return v, err
Expand All @@ -16,3 +28,11 @@ func (md *MaterialDescription) encodeDescription() ([]byte, error) {
func (md *MaterialDescription) decodeDescription(b []byte) error {
return json.Unmarshal(b, &md)
}

func copyPtrString(v *string) *string {
if v == nil {
return nil
}
ns := *v
return &ns
}
31 changes: 31 additions & 0 deletions service/s3/s3crypto/mat_desc_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
// +build go1.7

package s3crypto

import (
Expand Down Expand Up @@ -33,3 +35,32 @@ func TestDecodeMaterialDescription(t *testing.T) {
t.Error("expected material description to be equivalent, but received otherwise")
}
}

func TestMaterialDescription_Clone(t *testing.T) {
tests := map[string]struct {
md MaterialDescription
wantClone MaterialDescription
}{
"it handles nil": {
md: nil,
wantClone: nil,
},
"it copies all values": {
md: MaterialDescription{
"key1": aws.String("value1"),
"key2": aws.String("value2"),
},
wantClone: MaterialDescription{
"key1": aws.String("value1"),
"key2": aws.String("value2"),
},
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
if gotClone := tt.md.Clone(); !reflect.DeepEqual(gotClone, tt.wantClone) {
t.Errorf("Clone() = %v, want %v", gotClone, tt.wantClone)
}
})
}
}
69 changes: 63 additions & 6 deletions service/s3/s3crypto/migrations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ func ExampleNewEncryptionClientV2_migration00() {
// Usage of NewKMSKeyGenerator (kms) key wrapping algorithm must be migrated to NewKMSContextKeyGenerator (kms+context) key wrapping algorithm
//
// cipherDataGenerator := s3crypto.NewKMSKeyGenerator(kmsClient, cmkID)
cipherDataGenerator := s3crypto.NewKMSContextKeyGenerator(kmsClient, cmkID)
cipherDataGenerator := s3crypto.NewKMSContextKeyGenerator(kmsClient, cmkID, s3crypto.MaterialDescription{})

// Usage of AESCBCContentCipherBuilder (AES/CBC/PKCS5Padding) must be migrated to AESGCMContentCipherBuilder (AES/GCM/NoPadding)
//
// contentCipherBuilder := s3crypto.AESCBCContentCipherBuilder(cipherDataGenerator, s3crypto.AESCBCPadder)
contentCipherBuilder := s3crypto.AESGCMContentCipherBuilder(cipherDataGenerator)
contentCipherBuilder := s3crypto.AESGCMContentCipherBuilderV2(cipherDataGenerator)

// Construction of an encryption client should be done using NewEncryptionClientV2
//
Expand Down Expand Up @@ -59,9 +59,9 @@ func ExampleNewEncryptionClientV2_migration01() {
kmsClient := kms.New(sess)
cmkID := "1234abcd-12ab-34cd-56ef-1234567890ab"

cipherDataGenerator := s3crypto.NewKMSContextKeyGenerator(kmsClient, cmkID)
cipherDataGenerator := s3crypto.NewKMSContextKeyGenerator(kmsClient, cmkID, s3crypto.MaterialDescription{})

contentCipherBuilder := s3crypto.AESGCMContentCipherBuilder(cipherDataGenerator)
contentCipherBuilder := s3crypto.AESGCMContentCipherBuilderV2(cipherDataGenerator)

// Overriding of the encryption client options is possible by passing in functional arguments that override the
// provided EncryptionClientOptions.
Expand Down Expand Up @@ -98,7 +98,46 @@ func ExampleNewDecryptionClientV2_migration00() {
// The V2 decryption client is able to decrypt object encrypted by the V1 client.
//
// decryptionClient := s3crypto.NewDecryptionClient(sess)
decryptionClient := s3crypto.NewDecryptionClientV2(sess)

// The V2 decryption client requires you to explicitly register the key wrap algorithms and content encryption algorithms
// that you want to explicitly support decryption for.
registry := s3crypto.NewCryptoRegistry()

kmsClient := kms.New(sess)

// If you need support for unwrapping data keys wrapped using the `kms` wrap algorithm you can use RegisterKMSWrapWithAnyCMK.
// Alternatively you may use RegisterKMSWrapWithCMK if you wish to limit KMS decrypt calls to a specific CMK.
if err := s3crypto.RegisterKMSWrapWithAnyCMK(registry, kmsClient); err != nil {
fmt.Printf("error: %v", err)
return
}

// For unwrapping data keys wrapped using the new `kms+context` key wrap algorithm you can use RegisterKMSContextWrapWithAnyCMK.
// Alternatively you may use RegisterKMSWrapWithCMK if you wish to limit KMS decrypt calls to a specific CMK.
if err := s3crypto.RegisterKMSContextWrapWithAnyCMK(registry, kmsClient); err != nil {
fmt.Printf("error: %v", err)
return
}

// If you need to decrypt objects encrypted using the V1 AES/CBC/PCKS5Padding cipher you can do so with RegisterAESCBCContentCipher
if err := s3crypto.RegisterAESCBCContentCipher(registry, s3crypto.AESCBCPadder); err != nil {
fmt.Printf("error: %v", err)
return
}

// For decrypting objects encrypted in V1 or V2 using AES/GCM/NoPadding cipher you can do so with RegisterAESGCMContentCipher.
if err := s3crypto.RegisterAESGCMContentCipher(registry); err != nil {
fmt.Printf("error: %v", err)
return
}

// Instantiate a new decryption client, and provided the Wrap, cek, and Padder that have been registered
// with your desired algorithms.
decryptionClient, err := s3crypto.NewDecryptionClientV2(sess, registry)
if err != nil {
fmt.Printf("error: %v", err)
return
}

getObject, err := decryptionClient.GetObject(&s3.GetObjectInput{
Bucket: aws.String("your_bucket"),
Expand Down Expand Up @@ -127,9 +166,27 @@ func ExampleNewDecryptionClientV2_migration01() {
// decryptionClient := s3crypto.NewDecryptionClient(sess, func(o *s3crypto.DecryptionClient) {
// o.S3Client = s3.New(sess, &aws.Config{Region: aws.String("us-west-2")})
//})
decryptionClient := s3crypto.NewDecryptionClientV2(sess, func(o *s3crypto.DecryptionClientOptions) {
registry := s3crypto.NewCryptoRegistry()

kmsClient := kms.New(sess)
if err := s3crypto.RegisterKMSWrapWithAnyCMK(registry, kmsClient); err != nil {
fmt.Printf("error: %v", err)
return
}

// If you need to decrypt objects encrypted using AES/GCM/NoPadding cipher you can do so with RegisterAESGCMContentCipher
if err := s3crypto.RegisterAESGCMContentCipher(registry); err != nil {
fmt.Printf("error: %v", err)
return
}

decryptionClient, err := s3crypto.NewDecryptionClientV2(sess, registry, func(o *s3crypto.DecryptionClientOptions) {
o.S3Client = s3.New(sess, &aws.Config{Region: aws.String("us-west-2")})
})
if err != nil {
fmt.Printf("error: %v", err)
return
}

getObject, err := decryptionClient.GetObject(&s3.GetObjectInput{
Bucket: aws.String("your_bucket"),
Expand Down
88 changes: 75 additions & 13 deletions service/s3/s3crypto/mock_test.go
Original file line number Diff line number Diff line change
@@ -1,50 +1,93 @@
package s3crypto_test
package s3crypto

import (
"bytes"
"fmt"
"io"
"io/ioutil"

"github.com/aws/aws-sdk-go/service/s3/s3crypto"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/kms/kmsiface"
)

type mockGenerator struct{}

func (m mockGenerator) GenerateCipherData(keySize, ivSize int) (s3crypto.CipherData, error) {
cd := s3crypto.CipherData{
func (m mockGenerator) GenerateCipherData(keySize, ivSize int) (CipherData, error) {
cd := CipherData{
Key: make([]byte, keySize),
IV: make([]byte, ivSize),
}
return cd, nil
}

func (m mockGenerator) EncryptKey(key []byte) ([]byte, error) {
size := len(key)
b := bytes.Repeat([]byte{1}, size)
return b, nil
func (m mockGenerator) DecryptKey(key []byte) ([]byte, error) {
return make([]byte, 16), nil
}

func (m mockGenerator) DecryptKey(key []byte) ([]byte, error) {
type mockGeneratorV2 struct{}

func (m mockGeneratorV2) GenerateCipherDataWithCEKAlg(ctx aws.Context, keySize int, ivSize int, cekAlg string) (CipherData, error) {
cd := CipherData{
Key: make([]byte, keySize),
IV: make([]byte, ivSize),
}
return cd, nil
}

func (m mockGeneratorV2) DecryptKey(key []byte) ([]byte, error) {
return make([]byte, 16), nil
}

func (m mockGeneratorV2) isEncryptionVersionCompatible(version clientVersion) error {
if version != v2ClientVersion {
return fmt.Errorf("mock error about version")
}
return nil
}

type mockCipherBuilder struct {
generator s3crypto.CipherDataGenerator
generator CipherDataGenerator
}

func (builder mockCipherBuilder) ContentCipher() (s3crypto.ContentCipher, error) {
func (builder mockCipherBuilder) isEncryptionVersionCompatible(version clientVersion) error {
if version != v1ClientVersion {
return fmt.Errorf("mock error about version")
}
return nil
}

func (builder mockCipherBuilder) ContentCipher() (ContentCipher, error) {
cd, err := builder.generator.GenerateCipherData(32, 16)
if err != nil {
return nil, err
}
return &mockContentCipher{cd}, nil
}

type mockCipherBuilderV2 struct {
generator CipherDataGeneratorWithCEKAlg
}

func (builder mockCipherBuilderV2) isEncryptionVersionCompatible(version clientVersion) error {
if version != v2ClientVersion {
return fmt.Errorf("mock error about version")
}
return nil
}

func (builder mockCipherBuilderV2) ContentCipher() (ContentCipher, error) {
cd, err := builder.generator.GenerateCipherDataWithCEKAlg(aws.BackgroundContext(), 32, 16, "mock-cek-alg")
if err != nil {
return nil, err
}
return &mockContentCipher{cd}, nil
}

type mockContentCipher struct {
cd s3crypto.CipherData
cd CipherData
}

func (cipher *mockContentCipher) GetCipherData() s3crypto.CipherData {
func (cipher *mockContentCipher) GetCipherData() CipherData {
return cipher.cd
}

Expand All @@ -66,3 +109,22 @@ func (cipher *mockContentCipher) DecryptContents(src io.ReadCloser) (io.ReadClos
size := len(b)
return ioutil.NopCloser(bytes.NewReader(make([]byte, size))), nil
}

type mockKMS struct {
kmsiface.KMSAPI
}

type mockPadder struct {
}

func (m mockPadder) Pad(i []byte, i2 int) ([]byte, error) {
return i, nil
}

func (m mockPadder) Unpad(i []byte) ([]byte, error) {
return i, nil
}

func (m mockPadder) Name() string {
return "mockPadder"
}
25 changes: 16 additions & 9 deletions service/s3/s3crypto/shared_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ import (
"github.com/aws/aws-sdk-go/service/s3"
)

// clientConstructionErrorCode is used for operations that can't be completed due to invalid client construction
const clientConstructionErrorCode = "ClientConstructionError"

// mismatchWrapError is an error returned if a wrapping handler receives an unexpected envelope
var mismatchWrapError = awserr.New(clientConstructionErrorCode, "wrap algorithm provided did not match handler", nil)

func putObjectRequest(c EncryptionClientOptions, input *s3.PutObjectInput) (*request.Request, *s3.PutObjectOutput) {
req, out := c.S3Client.PutObjectRequest(input)

Expand Down Expand Up @@ -45,9 +51,9 @@ func putObjectRequest(c EncryptionClientOptions, input *s3.PutObjectInput) (*req
return
}

md5 := newMD5Reader(input.Body)
lengthReader := newContentLengthReader(input.Body)
sha := newSHA256Writer(dst)
reader, err := encryptor.EncryptContents(md5)
reader, err := encryptor.EncryptContents(lengthReader)
if err != nil {
r.Error = err
return
Expand All @@ -60,7 +66,7 @@ func putObjectRequest(c EncryptionClientOptions, input *s3.PutObjectInput) (*req
}

data := encryptor.GetCipherData()
env, err := encodeMeta(md5, data)
env, err := encodeMeta(lengthReader, data)
if err != nil {
r.Error = err
return
Expand Down Expand Up @@ -101,7 +107,7 @@ func getObjectRequest(options DecryptionClientOptions, input *s3.GetObjectInput)
return
}

// If KMS should return the correct CEK algorithm with the proper
// If KMS should return the correct cek algorithm with the proper
// KMS key provider
cipher, err := contentCipherFromEnvelope(options, r.Context(), env)
if err != nil {
Expand Down Expand Up @@ -143,8 +149,9 @@ func contentCipherFromEnvelope(options DecryptionClientOptions, ctx aws.Context,
}

func wrapFromEnvelope(options DecryptionClientOptions, env Envelope) (CipherDataDecrypter, error) {
f, ok := options.WrapRegistry[env.WrapAlg]
f, ok := options.CryptoRegistry.GetWrap(env.WrapAlg)
if !ok || f == nil {

return nil, awserr.New(
"InvalidWrapAlgorithmError",
"wrap algorithm isn't supported, "+env.WrapAlg,
Expand All @@ -155,7 +162,7 @@ func wrapFromEnvelope(options DecryptionClientOptions, env Envelope) (CipherData
}

func cekFromEnvelope(options DecryptionClientOptions, ctx aws.Context, env Envelope, decrypter CipherDataDecrypter) (ContentCipher, error) {
f, ok := options.CEKRegistry[env.CEKAlg]
f, ok := options.CryptoRegistry.GetCEK(env.CEKAlg)
if !ok || f == nil {
return nil, awserr.New(
"InvalidCEKAlgorithmError",
Expand Down Expand Up @@ -197,11 +204,11 @@ func cekFromEnvelope(options DecryptionClientOptions, ctx aws.Context, env Envel
// If there wasn't a cek algorithm specific padder, we check the padder itself.
// We return a no unpadder, if no unpadder was found. This means any customization
// either contained padding within the cipher implementation, and to maintain
// backwards compatility we will simply not unpad anything.
// backwards compatibility we will simply not unpad anything.
func getPadder(options DecryptionClientOptions, cekAlg string) Padder {
padder, ok := options.PadderRegistry[cekAlg]
padder, ok := options.CryptoRegistry.GetPadder(cekAlg)
if !ok {
padder, ok = options.PadderRegistry[cekAlg[strings.LastIndex(cekAlg, "/")+1:]]
padder, ok = options.CryptoRegistry.GetPadder(cekAlg[strings.LastIndex(cekAlg, "/")+1:])
if !ok {
return NoPadder
}
Expand Down
1 change: 0 additions & 1 deletion service/s3/s3crypto/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ func (load HeaderV2LoadStrategy) Load(req *request.Request) (Envelope, error) {
env.WrapAlg = req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, wrapAlgorithmHeader}, "-"))
env.CEKAlg = req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, cekAlgorithmHeader}, "-"))
env.TagLen = req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, tagLengthHeader}, "-"))
env.UnencryptedMD5 = req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, unencryptedMD5Header}, "-"))
env.UnencryptedContentLen = req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, unencryptedContentLengthHeader}, "-"))
return env, nil
}
Expand Down