Skip to content

Commit

Permalink
[SPARK-43286][SQL] Updates aes_encrypt CBC mode to generate random IVs
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

The current implementation of AES-CBC mode called via `aes_encrypt` and `aes_decrypt` uses a key derivation function (KDF) based on OpenSSL's [EVP_BytesToKey](https://www.openssl.org/docs/man3.0/man3/EVP_BytesToKey.html). This is intended for generating keys based on passwords and OpenSSL's documents discourage its use: "Newer applications should use a more modern algorithm".

`aes_encrypt` and `aes_decrypt` should use the key directly in CBC mode, as it does for both GCM and ECB mode. The output should then be the initialization vector (IV) prepended to the ciphertext – as is done with GCM mode:
`[16-byte randomly generated IV | AES-CBC encrypted ciphertext]`

### Why are the changes needed?

We want to have the ciphertext output similar across different modes. OpenSSL's EVP_BytesToKey is effectively deprecated and their own documentation says not to use it. Instead, CBC mode will generate a random vector.

### Does this PR introduce _any_ user-facing change?

AES-CBC output generated by the previous format will be incompatible with this change. That change was recently landed and we want to land this before CBC mode is used in practice.

### How was this patch tested?

A new unit test in `DataFrameFunctionsSuite` was added to test both GCM and CBC modes. Also, a new standalone unit test suite was added in `ExpressionImplUtilsSuite` to test all the modes and various key lengths.
```
build/sbt "sql/test:testOnly org.apache.spark.sql.DataFrameFunctionsSuite"
build/sbt "sql/test:testOnly org.apache.spark.sql.catalyst.expressions.ExpressionImplUtilsSuite"
```

CBC values can be verified with `openssl enc` using the following command:
```
echo -n "[INPUT]" | openssl enc -a -e -aes-256-cbc -iv [HEX IV] -K [HEX KEY]
echo -n "Spark" | openssl enc -a -e -aes-256-cbc -iv f8c832cc9c61bac6151960a58e4edf86 -K 6162636465666768696a6b6c6d6e6f7031323334353637384142434445464748
```

Closes apache#40969 from sweisdb/SPARK-43286.

Authored-by: Steve Weis <steve.weis@databricks.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
  • Loading branch information
sweisdb authored and LuciferYang committed May 10, 2023
1 parent 49cad5a commit aad8616
Show file tree
Hide file tree
Showing 7 changed files with 232 additions and 141 deletions.
5 changes: 0 additions & 5 deletions core/src/main/resources/error/error-classes.json
Expand Up @@ -1016,11 +1016,6 @@
"expects a binary value with 16, 24 or 32 bytes, but got <actualLength> bytes."
]
},
"AES_SALTED_MAGIC" : {
"message" : [
"Initial bytes from input <saltedMagic> do not match 'Salted__' (0x53616C7465645F5F)."
]
},
"PATTERN" : {
"message" : [
"<value>."
Expand Down
Expand Up @@ -26,27 +26,54 @@
import javax.crypto.spec.SecretKeySpec;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Arrays;
import java.security.spec.AlgorithmParameterSpec;

import static java.nio.charset.StandardCharsets.US_ASCII;

/**
* An utility class for constructing expressions.
* A utility class for constructing expressions.
*/
public class ExpressionImplUtils {
private static final SecureRandom secureRandom = new SecureRandom();
private static final ThreadLocal<SecureRandom> threadLocalSecureRandom =
ThreadLocal.withInitial(SecureRandom::new);

private static final int GCM_IV_LEN = 12;
private static final int GCM_TAG_LEN = 128;

private static final int CBC_IV_LEN = 16;
private static final int CBC_SALT_LEN = 8;
/** OpenSSL's magic initial bytes. */
private static final String SALTED_STR = "Salted__";
private static final byte[] SALTED_MAGIC = SALTED_STR.getBytes(US_ASCII);

enum CipherMode {
ECB("ECB", 0, 0, "AES/ECB/PKCS5Padding", false),
CBC("CBC", CBC_IV_LEN, 0, "AES/CBC/PKCS5Padding", true),
GCM("GCM", GCM_IV_LEN, GCM_TAG_LEN, "AES/GCM/NoPadding", true);

private final String name;
final int ivLength;
final int tagLength;
final String transformation;
final boolean usesSpec;

CipherMode(String name, int ivLen, int tagLen, String transformation, boolean usesSpec) {
this.name = name;
this.ivLength = ivLen;
this.tagLength = tagLen;
this.transformation = transformation;
this.usesSpec = usesSpec;
}

static CipherMode fromString(String modeName, String padding) {
if (modeName.equalsIgnoreCase(ECB.name) &&
(padding.equalsIgnoreCase("PKCS") || padding.equalsIgnoreCase("DEFAULT"))) {
return ECB;
} else if (modeName.equalsIgnoreCase(CBC.name) &&
(padding.equalsIgnoreCase("PKCS") || padding.equalsIgnoreCase("DEFAULT"))) {
return CBC;
} else if (modeName.equalsIgnoreCase(GCM.name) &&
(padding.equalsIgnoreCase("NONE") || padding.equalsIgnoreCase("DEFAULT"))) {
return GCM;
}
throw QueryExecutionErrors.aesModeUnsupportedError(modeName, padding);
}
}

/**
* Function to check if a given number string is a valid Luhn number
Expand Down Expand Up @@ -85,113 +112,73 @@ public static byte[] aesDecrypt(byte[] input, byte[] key, UTF8String mode, UTF8S
return aesInternal(input, key, mode.toString(), padding.toString(), Cipher.DECRYPT_MODE);
}

private static SecretKeySpec getSecretKeySpec(byte[] key) {
switch (key.length) {
case 16: case 24: case 32:
return new SecretKeySpec(key, 0, key.length, "AES");
default:
throw QueryExecutionErrors.invalidAesKeyLengthError(key.length);
}
}

private static byte[] generateIv(CipherMode mode) {
byte[] iv = new byte[mode.ivLength];
threadLocalSecureRandom.get().nextBytes(iv);
return iv;
}

private static AlgorithmParameterSpec getParamSpec(CipherMode mode, byte[] input) {
switch (mode) {
case CBC:
return new IvParameterSpec(input, 0, mode.ivLength);
case GCM:
return new GCMParameterSpec(mode.tagLength, input, 0, mode.ivLength);
default:
return null;
}
}

private static byte[] aesInternal(
byte[] input,
byte[] key,
String mode,
String padding,
int opmode) {
SecretKeySpec secretKey;

switch (key.length) {
case 16:
case 24:
case 32:
secretKey = new SecretKeySpec(key, 0, key.length, "AES");
break;
default:
throw QueryExecutionErrors.invalidAesKeyLengthError(key.length);
}

try {
if (mode.equalsIgnoreCase("ECB") &&
(padding.equalsIgnoreCase("PKCS") || padding.equalsIgnoreCase("DEFAULT"))) {
Cipher cipher = Cipher.getInstance("AES/ECB/PKCS5Padding");
cipher.init(opmode, secretKey);
return cipher.doFinal(input, 0, input.length);
} else if (mode.equalsIgnoreCase("GCM") &&
(padding.equalsIgnoreCase("NONE") || padding.equalsIgnoreCase("DEFAULT"))) {
Cipher cipher = Cipher.getInstance("AES/GCM/NoPadding");
if (opmode == Cipher.ENCRYPT_MODE) {
byte[] iv = new byte[GCM_IV_LEN];
secureRandom.nextBytes(iv);
GCMParameterSpec parameterSpec = new GCMParameterSpec(GCM_TAG_LEN, iv);
cipher.init(Cipher.ENCRYPT_MODE, secretKey, parameterSpec);
byte[] encrypted = cipher.doFinal(input, 0, input.length);
SecretKeySpec secretKey = getSecretKeySpec(key);
CipherMode cipherMode = CipherMode.fromString(mode, padding);
Cipher cipher = Cipher.getInstance(cipherMode.transformation);
if (opmode == Cipher.ENCRYPT_MODE) {
// This IV will be 0-length for ECB
byte[] iv = generateIv(cipherMode);
if (cipherMode.usesSpec) {
AlgorithmParameterSpec algSpec = getParamSpec(cipherMode, iv);
cipher.init(opmode, secretKey, algSpec);
} else {
cipher.init(opmode, secretKey);
}
byte[] encrypted = cipher.doFinal(input, 0, input.length);
if (iv.length > 0) {
ByteBuffer byteBuffer = ByteBuffer.allocate(iv.length + encrypted.length);
byteBuffer.put(iv);
byteBuffer.put(encrypted);
return byteBuffer.array();
} else {
assert(opmode == Cipher.DECRYPT_MODE);
GCMParameterSpec parameterSpec = new GCMParameterSpec(GCM_TAG_LEN, input, 0, GCM_IV_LEN);
cipher.init(Cipher.DECRYPT_MODE, secretKey, parameterSpec);
return cipher.doFinal(input, GCM_IV_LEN, input.length - GCM_IV_LEN);
return encrypted;
}
} else if (mode.equalsIgnoreCase("CBC") &&
(padding.equalsIgnoreCase("PKCS") || padding.equalsIgnoreCase("DEFAULT"))) {
Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
if (opmode == Cipher.ENCRYPT_MODE) {
byte[] salt = new byte[CBC_SALT_LEN];
secureRandom.nextBytes(salt);
final byte[] keyAndIv = getKeyAndIv(key, salt);
final byte[] keyValue = Arrays.copyOfRange(keyAndIv, 0, key.length);
final byte[] iv = Arrays.copyOfRange(keyAndIv, key.length, key.length + CBC_IV_LEN);
cipher.init(
Cipher.ENCRYPT_MODE,
new SecretKeySpec(keyValue, "AES"),
new IvParameterSpec(iv));
byte[] encrypted = cipher.doFinal(input, 0, input.length);
ByteBuffer byteBuffer = ByteBuffer.allocate(
SALTED_MAGIC.length + CBC_SALT_LEN + encrypted.length);
byteBuffer.put(SALTED_MAGIC);
byteBuffer.put(salt);
byteBuffer.put(encrypted);
return byteBuffer.array();
} else {
assert(opmode == Cipher.DECRYPT_MODE);
if (cipherMode.usesSpec) {
AlgorithmParameterSpec algSpec = getParamSpec(cipherMode, input);
cipher.init(opmode, secretKey, algSpec);
return cipher.doFinal(input, cipherMode.ivLength, input.length - cipherMode.ivLength);
} else {
assert(opmode == Cipher.DECRYPT_MODE);
final byte[] shouldBeMagic = Arrays.copyOfRange(input, 0, SALTED_MAGIC.length);
if (!Arrays.equals(shouldBeMagic, SALTED_MAGIC)) {
throw QueryExecutionErrors.aesInvalidSalt(shouldBeMagic);
}
final byte[] salt = Arrays.copyOfRange(
input, SALTED_MAGIC.length, SALTED_MAGIC.length + CBC_SALT_LEN);
final byte[] keyAndIv = getKeyAndIv(key, salt);
final byte[] keyValue = Arrays.copyOfRange(keyAndIv, 0, key.length);
final byte[] iv = Arrays.copyOfRange(keyAndIv, key.length, key.length + CBC_IV_LEN);
cipher.init(
Cipher.DECRYPT_MODE,
new SecretKeySpec(keyValue, "AES"),
new IvParameterSpec(iv, 0, CBC_IV_LEN));
return cipher.doFinal(input, CBC_IV_LEN, input.length - CBC_IV_LEN);
cipher.init(opmode, secretKey);
return cipher.doFinal(input, 0, input.length);
}
} else {
throw QueryExecutionErrors.aesModeUnsupportedError(mode, padding);
}
} catch (GeneralSecurityException e) {
throw QueryExecutionErrors.aesCryptoError(e.getMessage());
}
}

// Derive the key and init vector in the same way as OpenSSL's EVP_BytesToKey
// since the version 1.1.0c which switched to SHA-256 as the hash.
private static byte[] getKeyAndIv(byte[] key, byte[] salt) throws NoSuchAlgorithmException {
final byte[] keyAndSalt = arrConcat(key, salt);
byte[] hash = new byte[0];
byte[] keyAndIv = new byte[0];
for (int i = 0; i < 3 && keyAndIv.length < key.length + CBC_IV_LEN; i++) {
final byte[] hashData = arrConcat(hash, keyAndSalt);
final MessageDigest md = MessageDigest.getInstance("SHA-256");
hash = md.digest(hashData);
keyAndIv = arrConcat(keyAndIv, hash);
}
return keyAndIv;
}

private static byte[] arrConcat(final byte[] arr1, final byte[] arr2) {
final byte[] res = new byte[arr1.length + arr2.length];
System.arraycopy(arr1, 0, res, 0, arr1.length);
System.arraycopy(arr2, 0, res, arr1.length, arr2.length);
return res;
}
}
Expand Up @@ -334,7 +334,7 @@ case class CurrentUser() extends LeafExpression with Unevaluable {
> SELECT base64(_FUNC_('Spark SQL', '1234567890abcdef', 'ECB', 'PKCS'));
3lmwu+Mw0H3fi5NDvcu9lg==
> SELECT base64(_FUNC_('Apache Spark', '1234567890abcdef', 'CBC', 'DEFAULT'));
U2FsdGVkX18JQ84pfRUwonUrFzpWQ46vKu4+MkJVFGM=
2NYmDCjgXTbbxGA3/SnJEfFC/JQ7olk2VQWReIAAFKo=
""",
since = "3.3.0",
group = "misc_funcs")
Expand Down Expand Up @@ -399,7 +399,7 @@ case class AesEncrypt(
Spark SQL
> SELECT _FUNC_(unbase64('3lmwu+Mw0H3fi5NDvcu9lg=='), '1234567890abcdef', 'ECB', 'PKCS');
Spark SQL
> SELECT _FUNC_(unbase64('U2FsdGVkX18JQ84pfRUwonUrFzpWQ46vKu4+MkJVFGM='), '1234567890abcdef', 'CBC');
> SELECT _FUNC_(unbase64('2NYmDCjgXTbbxGA3/SnJEfFC/JQ7olk2VQWReIAAFKo='), '1234567890abcdef', 'CBC');
Apache Spark
""",
since = "3.3.0",
Expand Down
Expand Up @@ -2656,15 +2656,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
"detailMessage" -> detailMessage))
}

def aesInvalidSalt(saltedMagic: Array[Byte]): RuntimeException = {
new SparkRuntimeException(
errorClass = "INVALID_PARAMETER_VALUE.AES_SALTED_MAGIC",
messageParameters = Map(
"parameter" -> toSQLId("expr"),
"functionName" -> toSQLId("aes_decrypt"),
"saltedMagic" -> saltedMagic.map("%02X" format _).mkString("0x", "", "")))
}

def hiveTableWithAnsiIntervalsError(tableName: String): SparkUnsupportedOperationException = {
new SparkUnsupportedOperationException(
errorClass = "_LEGACY_ERROR_TEMP_2276",
Expand Down
Expand Up @@ -344,6 +344,30 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
}

test("misc aes function") {
val key32 = "abcdefghijklmnop12345678ABCDEFGH"
val encryptedEcb = "9J3iZbIxnmaG+OIA9Amd+A=="
val encryptedGcm = "y5la3muiuxN2suj6VsYXB+0XUFjtrUD0/zv5eDafsA3U"
val encryptedCbc = "+MgyzJxhusYVGWCljk7fhhl6C6oUqWmtdqoaG93KvhY="
val df1 = Seq("Spark").toDF

// Successful decryption of fixed values
Seq(
(key32, encryptedEcb, "ECB"),
(key32, encryptedGcm, "GCM"),
(key32, encryptedCbc, "CBC")).foreach {
case (key, encryptedText, mode) =>
checkAnswer(
df1.selectExpr(
s"cast(aes_decrypt(unbase64('$encryptedText'), '$key', '$mode') as string)"),
Seq(Row("Spark")))
checkAnswer(
df1.selectExpr(
s"cast(aes_decrypt(unbase64('$encryptedText'), binary('$key'), '$mode') as string)"),
Seq(Row("Spark")))
}
}

test("misc aes ECB function") {
val key16 = "abcdefghijklmnop"
val key24 = "abcdefghijklmnop12345678"
val key32 = "abcdefghijklmnop12345678ABCDEFGH"
Expand All @@ -358,15 +382,15 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {

// Successful encryption
Seq(
(key16, encryptedText16, encryptedEmptyText16),
(key24, encryptedText24, encryptedEmptyText24),
(key32, encryptedText32, encryptedEmptyText32)).foreach {
case (key, encryptedText, encryptedEmptyText) =>
(key16, encryptedText16, encryptedEmptyText16, "ECB"),
(key24, encryptedText24, encryptedEmptyText24, "ECB"),
(key32, encryptedText32, encryptedEmptyText32, "ECB")).foreach {
case (key, encryptedText, encryptedEmptyText, mode) =>
checkAnswer(
df1.selectExpr(s"base64(aes_encrypt(value, '$key', 'ECB'))"),
df1.selectExpr(s"base64(aes_encrypt(value, '$key', '$mode'))"),
Seq(Row(encryptedText), Row(encryptedEmptyText)))
checkAnswer(
df1.selectExpr(s"base64(aes_encrypt(binary(value), '$key', 'ECB'))"),
df1.selectExpr(s"base64(aes_encrypt(binary(value), '$key', '$mode'))"),
Seq(Row(encryptedText), Row(encryptedEmptyText)))
}

Expand Down

0 comments on commit aad8616

Please sign in to comment.