Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion aescbc/aescbc.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,18 @@ func Encrypt(random io.Reader, key []byte, plaintext []byte) ([]byte, error) {

// PKCS7 unpadding
func pkcs7Unpad(data []byte) ([]byte, error) {
if len(data) == 0 {
return nil, fmt.Errorf("invalid padding")
}
padding := int(data[len(data)-1])
if padding < 1 || padding > aes.BlockSize {
if padding < 1 || padding > aes.BlockSize || padding > len(data) {
return nil, fmt.Errorf("invalid padding")
}
for _, b := range data[len(data)-padding:] {
if b != byte(padding) {
return nil, fmt.Errorf("invalid padding")
}
}
return data[:len(data)-padding], nil
}

Expand Down
6 changes: 6 additions & 0 deletions aesgcm/aesgcm.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@ import (

// Encrypt encrypts plaintext using key and random entropy. Key must be a valid AES-256 key with a length of 32 bytes.
// The result is a concatenation of nonce (using standard 12-byte nonce size) and the actual ciphertext.
//
// random must be a cryptographically secure random source (e.g. crypto/rand.Reader or an NSM session).
// AES-GCM is catastrophically broken under nonce reuse.
func Encrypt(random io.Reader, key []byte, plaintext []byte, additionalData []byte) ([]byte, error) {
if random == nil {
return nil, fmt.Errorf("random source must not be nil")
}
if len(key) != 32 {
return nil, fmt.Errorf("key must be 32 bytes for AES-256 but was %d", len(key))
}
Expand Down
7 changes: 3 additions & 4 deletions attestation/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@ func Middleware(enc *enclave.Enclave, errorFn func(http.ResponseWriter, error),
return context.WithValue(r.Context(), contextKey, att), cancelFunc, nil
}

runPostMiddleware := func(w http.ResponseWriter, r *http.Request, body []byte, nonce []byte) (err error) {
runPostMiddleware := func(w http.ResponseWriter, r *http.Request, reqBody []byte, resBody []byte, nonce []byte) (err error) {
log := loggerFromContextFn(r.Context())
ctx, span := tracing.Trace(r.Context(), "attestation.Middleware")
defer func() {
span.RecordError(err)
span.End()
}()

userData, err := generateUserData(r, body)
userData, err := generateUserData(r, reqBody, resBody)
if err != nil {
return err
}
Expand Down Expand Up @@ -109,8 +109,7 @@ func Middleware(enc *enclave.Enclave, errorFn func(http.ResponseWriter, error),

next.ServeHTTP(ww, r.WithContext(ctx))

r.Body = io.NopCloser(bytes.NewBuffer(reqBody))
if err := runPostMiddleware(ww, r, body.Bytes(), nonce); err != nil {
if err := runPostMiddleware(ww, r, reqBody, body.Bytes(), nonce); err != nil {
errorFn(w, err)
return
}
Expand Down
16 changes: 2 additions & 14 deletions attestation/userdata.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
package attestation

import (
"bytes"
"crypto/sha256"
"encoding/base64"
"fmt"
"io"
"net/http"
)

Expand All @@ -24,20 +22,10 @@ func (u *userData) String() string {
return fmt.Sprintf("%s/%d:%s", u.Prefix, u.Version, base64.StdEncoding.EncodeToString(u.Hash))
}

func generateUserData(r *http.Request, resBody []byte) ([]byte, error) {
func generateUserData(r *http.Request, reqBody []byte, resBody []byte) ([]byte, error) {
hasher := sha256.New()
hasher.Write([]byte(r.Method + " " + r.URL.Path + "\n"))

var reqBody []byte
var err error
if r.Body != nil {
reqBody, err = io.ReadAll(r.Body)
if err != nil {
return nil, fmt.Errorf("failed to read request body: %w", err)
}
r.Body = io.NopCloser(bytes.NewBuffer(reqBody))
hasher.Write(reqBody)
}
hasher.Write(reqBody)
hasher.Write([]byte("\n"))
hasher.Write(resBody)
hash := hasher.Sum(nil)
Expand Down
13 changes: 5 additions & 8 deletions cms/ber.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,14 @@ func readObject(ber []byte, offset int) (asn1Object, int, error) {
for ber[offset] >= 0x80 {
tag = tag*128 + ber[offset] - 0x80
offset++
if offset > berLen {
if offset >= berLen {
return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached")
}
}
// jvehent 20170227: this doesn't appear to be used anywhere...
//tag = tag*128 + ber[offset] - 0x80
offset++
if offset > berLen {
if offset >= berLen {
return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached")
}
}
Expand All @@ -173,15 +173,15 @@ func readObject(ber []byte, offset int) (asn1Object, int, error) {
var length int
l := ber[offset]
offset++
if offset > berLen {
return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached")
}
indefinite := false
if l > 0x80 {
numberOfBytes := (int)(l & 0x7F)
if numberOfBytes > 4 { // int is only guaranteed to be 32bit
return nil, 0, errors.New("ber2der: BER tag length too long")
}
if offset+numberOfBytes > berLen {
return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached")
}
if numberOfBytes == 4 && (int)(ber[offset]) > 0x7F {
return nil, 0, errors.New("ber2der: BER tag length is negative")
}
Expand All @@ -193,9 +193,6 @@ func readObject(ber []byte, offset int) (asn1Object, int, error) {
for i := 0; i < numberOfBytes; i++ {
length = length*256 + (int)(ber[offset])
offset++
if offset > berLen {
return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached")
}
}
} else if l == 0x80 {
indefinite = true
Expand Down
4 changes: 4 additions & 0 deletions cms/cms.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ func (ek *EncryptedKey) Decrypt(key *rsa.PrivateKey) ([]byte, error) {
return nil, errors.New("pkcs7: encryption algorithm parameters are malformed")
}

if len(ek.cipherText) == 0 || len(ek.cipherText)%block.BlockSize() != 0 {
return nil, fmt.Errorf("cms: ciphertext length %d is not a multiple of block size %d", len(ek.cipherText), block.BlockSize())
}

mode := cipher.NewCBCDecrypter(block, ek.iv)
plaintext := make([]byte, len(ek.cipherText))
mode.CryptBlocks(plaintext, ek.cipherText)
Expand Down
35 changes: 35 additions & 0 deletions cms/cms_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,27 @@ cJEGAbCDYhyjvtjBLNy7YDQ1hdmCnqMxg/5AIwUMkvTTRg+qepfboA==
}
)

func TestParse_malformedBER(t *testing.T) {
tests := []struct {
name string
input []byte
}{
{"multi-byte tag truncated", []byte{0x1F, 0x80}},
{"multi-byte tag no length", []byte{0x1F, 0x01}},
{"tag only", []byte{0x30}},
{"long-form length truncated", []byte{0x30, 0x82}},
{"long-form length partial", []byte{0x30, 0x82, 0x01}},
{"length exceeds data", []byte{0x30, 0x10, 0x00}},
{"empty input", []byte{}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := cms.Parse(tt.input)
require.Error(t, err)
})
}
}

func TestDecodeCiphertextForRecipient(t *testing.T) {
block, _ := pem.Decode([]byte(testPrivateKey))
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
Expand All @@ -65,3 +86,17 @@ func TestDecodeCiphertextForRecipient(t *testing.T) {

require.Equal(t, plaintextKey, dataKey)
}

func TestDecryptEnvelopedKey_truncatedCiphertext(t *testing.T) {
block, _ := pem.Decode([]byte(testPrivateKey))
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
require.NoError(t, err)

ciphertext, err := base64.StdEncoding.DecodeString(testCiphertextString)
require.NoError(t, err)

// Truncate by one byte to misalign the inner ciphertext off a block boundary
truncated := ciphertext[:len(ciphertext)-1]
_, err = cms.DecryptEnvelopedKey(key, truncated)
require.Error(t, err)
}
13 changes: 11 additions & 2 deletions enclave/attestation.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,13 @@ func (a *Attestation) Document() []byte {

// Decrypt requests a decryption operation from KMS on ciphertext. If the key used to encrypt the
// original data is not one of allowedKeyIDs, Decrypt returns an error.
//
// allowedKeyIDs must not be empty — callers must explicitly specify which KMS keys are acceptable.
func (a *Attestation) Decrypt(ctx context.Context, ciphertext []byte, allowedKeyIDs []string) ([]byte, error) {
if len(allowedKeyIDs) == 0 {
return nil, fmt.Errorf("allowedKeyIDs must not be empty")
}

params := &kms.DecryptInput{
CiphertextBlob: ciphertext,
EncryptionAlgorithm: types.EncryptionAlgorithmSpecSymmetricDefault,
Expand All @@ -65,6 +71,9 @@ func (a *Attestation) Decrypt(ctx context.Context, ciphertext []byte, allowedKey
}

// Verify that the key used to decrypt was one of the allowed keys
if out.KeyId == nil {
return nil, fmt.Errorf("KMS response missing KeyId, cannot verify against allowed keys")
}
if keyID, ok := keyIsAllowed(out.KeyId, allowedKeyIDs); !ok {
return nil, fmt.Errorf("KMS key not allowed for this operation: %q", keyID)
}
Expand Down Expand Up @@ -117,8 +126,8 @@ func (a *Attestation) GenerateDataKey(ctx context.Context, keyID string) (*DataK
}

func keyIsAllowed(key *string, allowedKeys []string) (string, bool) {
if key == nil || len(allowedKeys) == 0 {
return "", true
if key == nil {
return "", false
Comment thread
patrislav marked this conversation as resolved.
}

for _, v := range allowedKeys {
Expand Down
72 changes: 68 additions & 4 deletions enclave/attestation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,11 @@ func TestNitroAttestation_Decrypt(t *testing.T) {
doc, err := nitro.Parse(params.Recipient.AttestationDocument)
require.NoError(t, err)
assert.Equal(t, []byte("nonce"), doc.Nonce)
assert.NoError(t, doc.Validate(nitro.WithRootFingerprint("14e8bc5fabb52876f35f122289eaabfa08885837cc7f161149c6d242596258aa")))
assert.NoError(t, doc.Validate(nitro.WithRootFingerprint(doc.RootCertFingerprint())))
assert.NoError(t, doc.Verify())
assert.Equal(t, types.KeyEncryptionMechanismRsaesOaepSha256, params.Recipient.KeyEncryptionAlgorithm)
return &kms.DecryptOutput{CiphertextForRecipient: ciphertextForRecipient}, nil
keyID := "arn:aws:kms:us-east-1:000000000000:key/test-key-id"
return &kms.DecryptOutput{KeyId: &keyID, CiphertextForRecipient: ciphertextForRecipient}, nil
},
}

Expand All @@ -92,14 +93,77 @@ func TestNitroAttestation_Decrypt(t *testing.T) {
att, err := e.GetAttestation(context.Background(), []byte("nonce"), []byte("user-data"))
require.NoError(t, err)

plaintext, err := att.Decrypt(context.Background(), []byte("ciphertext"), nil)
plaintext, err := att.Decrypt(context.Background(), []byte("ciphertext"), []string{"arn:aws:kms:us-east-1:000000000000:key/test-key-id"})
require.NoError(t, err)
assert.Equal(t, expectedPlaintext, plaintext)
}()
}
wg.Wait()
}

func TestNitroAttestation_Decrypt_nilKeyId(t *testing.T) {
block, _ := pem.Decode([]byte(testPrivateKey))
privKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
require.NoError(t, err)

kmsMock := &mockKMS{
decrypt: func(params *kms.DecryptInput) (*kms.DecryptOutput, error) {
return &kms.DecryptOutput{KeyId: nil, Plaintext: []byte("plaintext")}, nil
},
}

e, err := enclave.New(context.Background(), enclave.DummyProvider(nil), kmsMock, privKey)
require.NoError(t, err)

att, err := e.GetAttestation(context.Background(), []byte("nonce"), nil)
require.NoError(t, err)

_, err = att.Decrypt(context.Background(), []byte("ciphertext"), []string{"some-key"})
require.Error(t, err)
assert.ErrorContains(t, err, "KMS response missing KeyId")
}

func TestNitroAttestation_Decrypt_emptyAllowedKeys(t *testing.T) {
block, _ := pem.Decode([]byte(testPrivateKey))
privKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
require.NoError(t, err)

kmsMock := &mockKMS{}

e, err := enclave.New(context.Background(), enclave.DummyProvider(nil), kmsMock, privKey)
require.NoError(t, err)

att, err := e.GetAttestation(context.Background(), []byte("nonce"), nil)
require.NoError(t, err)

_, err = att.Decrypt(context.Background(), []byte("ciphertext"), nil)
require.Error(t, err)
assert.ErrorContains(t, err, "allowedKeyIDs must not be empty")
}

func TestNitroAttestation_Decrypt_wrongKey(t *testing.T) {
block, _ := pem.Decode([]byte(testPrivateKey))
privKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
require.NoError(t, err)

kmsMock := &mockKMS{
decrypt: func(params *kms.DecryptInput) (*kms.DecryptOutput, error) {
keyID := "arn:aws:kms:us-east-1:000000000000:key/other-key"
return &kms.DecryptOutput{KeyId: &keyID, Plaintext: []byte("plaintext")}, nil
},
}

e, err := enclave.New(context.Background(), enclave.DummyProvider(nil), kmsMock, privKey)
require.NoError(t, err)

att, err := e.GetAttestation(context.Background(), []byte("nonce"), nil)
require.NoError(t, err)

_, err = att.Decrypt(context.Background(), []byte("ciphertext"), []string{"arn:aws:kms:us-east-1:000000000000:key/expected-key"})
require.Error(t, err)
assert.ErrorContains(t, err, "KMS key not allowed")
}

func TestAttestation_GenerateDataKey(t *testing.T) {
block, _ := pem.Decode([]byte(testPrivateKey))
privKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
Expand All @@ -116,7 +180,7 @@ func TestAttestation_GenerateDataKey(t *testing.T) {
doc, err := nitro.Parse(params.Recipient.AttestationDocument)
require.NoError(t, err)
assert.Equal(t, []byte("nonce"), doc.Nonce)
assert.NoError(t, doc.Validate(nitro.WithRootFingerprint("14e8bc5fabb52876f35f122289eaabfa08885837cc7f161149c6d242596258aa")))
assert.NoError(t, doc.Validate(nitro.WithRootFingerprint(doc.RootCertFingerprint())))
assert.NoError(t, doc.Verify())
assert.Equal(t, types.KeyEncryptionMechanismRsaesOaepSha256, params.Recipient.KeyEncryptionAlgorithm)
return &kms.GenerateDataKeyOutput{
Expand Down
Loading
Loading