From 6e4d0d25a9939349439ebd50f286a1cf72571256 Mon Sep 17 00:00:00 2001 From: Lakhan Samani Date: Fri, 3 Apr 2026 21:02:15 +0530 Subject: [PATCH] fix(crypto): switch AES-CFB to AES-GCM and use HKDF key derivation AES-CFB provides no integrity/authentication, allowing attackers to bit-flip encrypted session tokens to alter user ID or roles without detection. AES-GCM provides authenticated encryption. Also replaces null-byte key padding with HKDF-SHA256 key derivation, ensuring full entropy utilization regardless of input key length. Fixes: C1 (Critical), H3 (High) --- internal/crypto/aes.go | 153 ++++++++++++------------------------ internal/crypto/aes_test.go | 95 ++++++++++++++++++++++ 2 files changed, 146 insertions(+), 102 deletions(-) create mode 100644 internal/crypto/aes_test.go diff --git a/internal/crypto/aes.go b/internal/crypto/aes.go index 5c8fec23..7edef295 100644 --- a/internal/crypto/aes.go +++ b/internal/crypto/aes.go @@ -1,143 +1,92 @@ package crypto import ( - "bytes" "crypto/aes" "crypto/cipher" "crypto/rand" + "crypto/sha256" "encoding/base64" "errors" "io" + + "golang.org/x/crypto/hkdf" ) -// var bytes = []byte{35, 46, 57, 24, 85, 35, 24, 74, 87, 35, 88, 98, 66, 32, 14, 0o5} +// hkdfInfo is the fixed info string used for AES key derivation. +const hkdfInfo = "authorizer-aes-key" -// const ( -// // Static key for encryption -// encryptionKey = "authorizerdev" -// ) +// deriveAESKey derives a 32-byte AES key from the provided input keying +// material using HKDF-SHA256 with a fixed info string and no salt. +func deriveAESKey(ikm string) ([]byte, error) { + reader := hkdf.New(sha256.New, []byte(ikm), nil, []byte(hkdfInfo)) + key := make([]byte, 32) + if _, err := io.ReadFull(reader, key); err != nil { + return nil, err + } + return key, nil +} -// EncryptAES method is to encrypt or hide any classified text +// EncryptAES encrypts plaintext using AES-256-GCM. The nonce is prepended to +// the ciphertext and the result is encoded as base64 RawURL. func EncryptAES(key, text string) (string, error) { - keyBytes := []byte(ensureHashKey(key)) + keyBytes, err := deriveAESKey(key) + if err != nil { + return "", err + } + block, err := aes.NewCipher(keyBytes) if err != nil { return "", err } - // The IV needs to be unique, but not secure. Therefore, it's common to - // include it at the beginning of the ciphertext. - ciphertext := make([]byte, aes.BlockSize+len(text)) - iv := ciphertext[:aes.BlockSize] - if _, err := io.ReadFull(rand.Reader, iv); err != nil { + gcm, err := cipher.NewGCM(block) + if err != nil { return "", err } - stream := cipher.NewCFBEncrypter(block, iv) - stream.XORKeyStream(ciphertext[aes.BlockSize:], []byte(text)) + nonce := make([]byte, gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return "", err + } - // Encode the ciphertext to URL-safe base64 without padding + // Seal appends the encrypted and authenticated ciphertext to nonce. + ciphertext := gcm.Seal(nonce, nonce, []byte(text), nil) return base64.RawURLEncoding.EncodeToString(ciphertext), nil } -// DecryptAES method is to extract back the encrypted text +// DecryptAES decrypts a base64 RawURL-encoded AES-256-GCM ciphertext produced +// by EncryptAES. Returns an error if authentication fails or input is malformed. func DecryptAES(key, encryptedText string) (string, error) { - keyBytes := []byte(ensureHashKey(key)) - ciphertext, err := base64.RawURLEncoding.DecodeString(encryptedText) + keyBytes, err := deriveAESKey(key) if err != nil { return "", err } - block, err := aes.NewCipher(keyBytes) + data, err := base64.RawURLEncoding.DecodeString(encryptedText) if err != nil { return "", err } - if len(ciphertext) < aes.BlockSize { - return "", errors.New("ciphertext too short") + block, err := aes.NewCipher(keyBytes) + if err != nil { + return "", err } - iv := ciphertext[:aes.BlockSize] - ciphertext = ciphertext[aes.BlockSize:] - - stream := cipher.NewCFBDecrypter(block, iv) - stream.XORKeyStream(ciphertext, ciphertext) + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", err + } - return string(ciphertext), nil -} + nonceSize := gcm.NonceSize() + if len(data) < nonceSize { + return "", errors.New("ciphertext too short") + } -// ensureHashKey ensure the key is 32 bytes long -// if short it will append 0's to the key -// if long it will truncate the key -func ensureHashKey(key string) string { - if len(key) < 32 { - return key + string(bytes.Repeat([]byte{0}, 32-len(key))) + nonce, ciphertext := data[:nonceSize], data[nonceSize:] + plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return "", err } - return key[:32] -} -// EncryptAESEnv encrypts data using AES algorithm -// kept for the backward compatibility of env data encryption -// TODO: Check if this is still needed -// func EncryptAESEnv(text []byte) ([]byte, error) { -// var res []byte -// key := []byte(encryptionKey) -// c, err := aes.NewCipher(key) -// if err != nil { -// return res, err -// } - -// // gcm or Galois/Counter Mode, is a mode of operation -// // for symmetric key cryptographic block ciphers -// // - https://en.wikipedia.org/wiki/Galois/Counter_Mode -// gcm, err := cipher.NewGCM(c) -// if err != nil { -// return res, err -// } - -// // creates a new byte array the size of the nonce -// // which must be passed to Seal -// nonce := make([]byte, gcm.NonceSize()) -// // populates our nonce with a cryptographically secure -// // random sequence -// if _, err = io.ReadFull(rand.Reader, nonce); err != nil { -// return res, err -// } - -// // here we encrypt our text using the Seal function -// // Seal encrypts and authenticates plaintext, authenticates the -// // additional data and appends the result to dst, returning the updated -// // slice. The nonce must be NonceSize() bytes long and unique for all -// // time, for a given key. -// return gcm.Seal(nonce, nonce, text, nil), nil -// } - -// // DecryptAES decrypts data using AES algorithm -// // Kept for the backward compatibility of env data decryption -// // TODO: Check if this is still needed -// func DecryptAESEnv(ciphertext []byte) ([]byte, error) { -// var res []byte -// key := []byte(encryptionKey) -// c, err := aes.NewCipher(key) -// if err != nil { -// return res, err -// } - -// gcm, err := cipher.NewGCM(c) -// if err != nil { -// return res, err -// } - -// nonceSize := gcm.NonceSize() -// if len(ciphertext) < nonceSize { -// return res, err -// } - -// nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:] -// plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) -// if err != nil { -// return res, err -// } - -// return plaintext, nil -// } + return string(plaintext), nil +} diff --git a/internal/crypto/aes_test.go b/internal/crypto/aes_test.go new file mode 100644 index 00000000..e8675741 --- /dev/null +++ b/internal/crypto/aes_test.go @@ -0,0 +1,95 @@ +package crypto + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEncryptDecryptAES(t *testing.T) { + key := "test-client-secret" + plaintext := `{"sub":"user123","roles":["admin"],"nonce":"abc"}` + + encrypted, err := EncryptAES(key, plaintext) + require.NoError(t, err) + assert.NotEmpty(t, encrypted) + assert.NotEqual(t, plaintext, encrypted) + + decrypted, err := DecryptAES(key, encrypted) + require.NoError(t, err) + assert.Equal(t, plaintext, decrypted) +} + +func TestEncryptAES_DifferentNonces(t *testing.T) { + key := "test-key" + plaintext := "same-text" + + enc1, err := EncryptAES(key, plaintext) + require.NoError(t, err) + enc2, err := EncryptAES(key, plaintext) + require.NoError(t, err) + + // Same plaintext should produce different ciphertexts due to random nonce + assert.NotEqual(t, enc1, enc2) +} + +func TestDecryptAES_TamperedCiphertext(t *testing.T) { + key := "test-key" + plaintext := "sensitive-data" + + encrypted, err := EncryptAES(key, plaintext) + require.NoError(t, err) + + // Tamper with the ciphertext (flip a character in the middle) + runes := []rune(encrypted) + mid := len(runes) / 2 + if runes[mid] == 'A' { + runes[mid] = 'B' + } else { + runes[mid] = 'A' + } + tampered := string(runes) + + _, err = DecryptAES(key, tampered) + assert.Error(t, err, "GCM should detect tampered ciphertext") +} + +func TestDecryptAES_WrongKey(t *testing.T) { + encrypted, err := EncryptAES("correct-key", "secret") + require.NoError(t, err) + + _, err = DecryptAES("wrong-key", encrypted) + assert.Error(t, err, "Should fail with wrong key") +} + +func TestDecryptAES_TooShort(t *testing.T) { + _, err := DecryptAES("key", "dG9vc2hvcnQ") + assert.Error(t, err) +} + +func TestEncryptDecryptAES_ShortKey(t *testing.T) { + // Even a short key should work properly via HKDF derivation + key := "short" + plaintext := "test data" + + encrypted, err := EncryptAES(key, plaintext) + require.NoError(t, err) + + decrypted, err := DecryptAES(key, encrypted) + require.NoError(t, err) + assert.Equal(t, plaintext, decrypted) +} + +func TestEncryptDecryptAES_LongKey(t *testing.T) { + key := strings.Repeat("a", 100) + plaintext := "test data" + + encrypted, err := EncryptAES(key, plaintext) + require.NoError(t, err) + + decrypted, err := DecryptAES(key, encrypted) + require.NoError(t, err) + assert.Equal(t, plaintext, decrypted) +}