Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 51 additions & 102 deletions internal/crypto/aes.go
Original file line number Diff line number Diff line change
@@ -1,143 +1,92 @@
package crypto

import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"errors"
"io"

"golang.org/x/crypto/hkdf"
)

// var bytes = []byte{35, 46, 57, 24, 85, 35, 24, 74, 87, 35, 88, 98, 66, 32, 14, 0o5}
// hkdfInfo is the fixed info string used for AES key derivation.
const hkdfInfo = "authorizer-aes-key"

// const (
// // Static key for encryption
// encryptionKey = "authorizerdev"
// )
// deriveAESKey derives a 32-byte AES key from the provided input keying
// material using HKDF-SHA256 with a fixed info string and no salt.
func deriveAESKey(ikm string) ([]byte, error) {
reader := hkdf.New(sha256.New, []byte(ikm), nil, []byte(hkdfInfo))
key := make([]byte, 32)
if _, err := io.ReadFull(reader, key); err != nil {
return nil, err
}
return key, nil
}

// EncryptAES method is to encrypt or hide any classified text
// EncryptAES encrypts plaintext using AES-256-GCM. The nonce is prepended to
// the ciphertext and the result is encoded as base64 RawURL.
func EncryptAES(key, text string) (string, error) {
keyBytes := []byte(ensureHashKey(key))
keyBytes, err := deriveAESKey(key)
if err != nil {
return "", err
}

block, err := aes.NewCipher(keyBytes)
if err != nil {
return "", err
}

// The IV needs to be unique, but not secure. Therefore, it's common to
// include it at the beginning of the ciphertext.
ciphertext := make([]byte, aes.BlockSize+len(text))
iv := ciphertext[:aes.BlockSize]
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
}

stream := cipher.NewCFBEncrypter(block, iv)
stream.XORKeyStream(ciphertext[aes.BlockSize:], []byte(text))
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", err
}

// Encode the ciphertext to URL-safe base64 without padding
// Seal appends the encrypted and authenticated ciphertext to nonce.
ciphertext := gcm.Seal(nonce, nonce, []byte(text), nil)
return base64.RawURLEncoding.EncodeToString(ciphertext), nil
}

// DecryptAES method is to extract back the encrypted text
// DecryptAES decrypts a base64 RawURL-encoded AES-256-GCM ciphertext produced
// by EncryptAES. Returns an error if authentication fails or input is malformed.
func DecryptAES(key, encryptedText string) (string, error) {
keyBytes := []byte(ensureHashKey(key))
ciphertext, err := base64.RawURLEncoding.DecodeString(encryptedText)
keyBytes, err := deriveAESKey(key)
if err != nil {
return "", err
}

block, err := aes.NewCipher(keyBytes)
data, err := base64.RawURLEncoding.DecodeString(encryptedText)
if err != nil {
return "", err
}

if len(ciphertext) < aes.BlockSize {
return "", errors.New("ciphertext too short")
block, err := aes.NewCipher(keyBytes)
if err != nil {
return "", err
}

iv := ciphertext[:aes.BlockSize]
ciphertext = ciphertext[aes.BlockSize:]

stream := cipher.NewCFBDecrypter(block, iv)
stream.XORKeyStream(ciphertext, ciphertext)
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
}

return string(ciphertext), nil
}
nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return "", errors.New("ciphertext too short")
}

// ensureHashKey ensure the key is 32 bytes long
// if short it will append 0's to the key
// if long it will truncate the key
func ensureHashKey(key string) string {
if len(key) < 32 {
return key + string(bytes.Repeat([]byte{0}, 32-len(key)))
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return "", err
}
return key[:32]
}

// EncryptAESEnv encrypts data using AES algorithm
// kept for the backward compatibility of env data encryption
// TODO: Check if this is still needed
// func EncryptAESEnv(text []byte) ([]byte, error) {
// var res []byte
// key := []byte(encryptionKey)
// c, err := aes.NewCipher(key)
// if err != nil {
// return res, err
// }

// // gcm or Galois/Counter Mode, is a mode of operation
// // for symmetric key cryptographic block ciphers
// // - https://en.wikipedia.org/wiki/Galois/Counter_Mode
// gcm, err := cipher.NewGCM(c)
// if err != nil {
// return res, err
// }

// // creates a new byte array the size of the nonce
// // which must be passed to Seal
// nonce := make([]byte, gcm.NonceSize())
// // populates our nonce with a cryptographically secure
// // random sequence
// if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
// return res, err
// }

// // here we encrypt our text using the Seal function
// // Seal encrypts and authenticates plaintext, authenticates the
// // additional data and appends the result to dst, returning the updated
// // slice. The nonce must be NonceSize() bytes long and unique for all
// // time, for a given key.
// return gcm.Seal(nonce, nonce, text, nil), nil
// }

// // DecryptAES decrypts data using AES algorithm
// // Kept for the backward compatibility of env data decryption
// // TODO: Check if this is still needed
// func DecryptAESEnv(ciphertext []byte) ([]byte, error) {
// var res []byte
// key := []byte(encryptionKey)
// c, err := aes.NewCipher(key)
// if err != nil {
// return res, err
// }

// gcm, err := cipher.NewGCM(c)
// if err != nil {
// return res, err
// }

// nonceSize := gcm.NonceSize()
// if len(ciphertext) < nonceSize {
// return res, err
// }

// nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
// plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
// if err != nil {
// return res, err
// }

// return plaintext, nil
// }
return string(plaintext), nil
}
95 changes: 95 additions & 0 deletions internal/crypto/aes_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package crypto

import (
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestEncryptDecryptAES(t *testing.T) {
key := "test-client-secret"
plaintext := `{"sub":"user123","roles":["admin"],"nonce":"abc"}`

encrypted, err := EncryptAES(key, plaintext)
require.NoError(t, err)
assert.NotEmpty(t, encrypted)
assert.NotEqual(t, plaintext, encrypted)

decrypted, err := DecryptAES(key, encrypted)
require.NoError(t, err)
assert.Equal(t, plaintext, decrypted)
}

func TestEncryptAES_DifferentNonces(t *testing.T) {
key := "test-key"
plaintext := "same-text"

enc1, err := EncryptAES(key, plaintext)
require.NoError(t, err)
enc2, err := EncryptAES(key, plaintext)
require.NoError(t, err)

// Same plaintext should produce different ciphertexts due to random nonce
assert.NotEqual(t, enc1, enc2)
}

func TestDecryptAES_TamperedCiphertext(t *testing.T) {
key := "test-key"
plaintext := "sensitive-data"

encrypted, err := EncryptAES(key, plaintext)
require.NoError(t, err)

// Tamper with the ciphertext (flip a character in the middle)
runes := []rune(encrypted)
mid := len(runes) / 2
if runes[mid] == 'A' {
runes[mid] = 'B'
} else {
runes[mid] = 'A'
}
tampered := string(runes)

_, err = DecryptAES(key, tampered)
assert.Error(t, err, "GCM should detect tampered ciphertext")
}

func TestDecryptAES_WrongKey(t *testing.T) {
encrypted, err := EncryptAES("correct-key", "secret")
require.NoError(t, err)

_, err = DecryptAES("wrong-key", encrypted)
assert.Error(t, err, "Should fail with wrong key")
}

func TestDecryptAES_TooShort(t *testing.T) {
_, err := DecryptAES("key", "dG9vc2hvcnQ")
assert.Error(t, err)
}

func TestEncryptDecryptAES_ShortKey(t *testing.T) {
// Even a short key should work properly via HKDF derivation
key := "short"
plaintext := "test data"

encrypted, err := EncryptAES(key, plaintext)
require.NoError(t, err)

decrypted, err := DecryptAES(key, encrypted)
require.NoError(t, err)
assert.Equal(t, plaintext, decrypted)
}

func TestEncryptDecryptAES_LongKey(t *testing.T) {
key := strings.Repeat("a", 100)
plaintext := "test data"

encrypted, err := EncryptAES(key, plaintext)
require.NoError(t, err)

decrypted, err := DecryptAES(key, encrypted)
require.NoError(t, err)
assert.Equal(t, plaintext, decrypted)
}