Skip to content

Commit

Permalink
[SPARK-43286][SQL] Updates aes_encrypt CBC mode to generate random IVs
Browse files Browse the repository at this point in the history
  • Loading branch information
sweisdb committed May 2, 2023
1 parent 8f24a7f commit cd33dd5
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 108 deletions.
Expand Up @@ -26,27 +26,54 @@
import javax.crypto.spec.SecretKeySpec;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Arrays;
import java.security.spec.AlgorithmParameterSpec;

import static java.nio.charset.StandardCharsets.US_ASCII;

/**
* A utility class for constructing expressions.
*/
public class ExpressionImplUtils {
// Single SecureRandom instance shared through a static field.
private static final SecureRandom secureRandom = new SecureRandom();
// Per-thread SecureRandom; each thread's instance is created by the supplier
// the first time that thread calls get().
private static final ThreadLocal<SecureRandom> threadLocalSecureRandom =
ThreadLocal.withInitial(SecureRandom::new);

// GCM IV length in bytes (used as the size of the generated IV array).
private static final int GCM_IV_LEN = 12;
// GCM authentication tag length as passed to GCMParameterSpec, which takes it in bits.
private static final int GCM_TAG_LEN = 128;

// CBC IV length in bytes.
private static final int CBC_IV_LEN = 16;
// Length in bytes of the random salt used by the OpenSSL-style key/IV derivation.
private static final int CBC_SALT_LEN = 8;
/** OpenSSL's magic initial bytes. */
private static final String SALTED_STR = "Salted__";
// US-ASCII encoding of SALTED_STR, compared against / written as the ciphertext prefix.
private static final byte[] SALTED_MAGIC = SALTED_STR.getBytes(US_ASCII);

/**
 * The AES cipher modes this utility supports. Each constant carries the IV length (bytes),
 * the tag length handed to the parameter spec, the JCE transformation string, and whether
 * {@code Cipher.init} needs an {@code AlgorithmParameterSpec} for the mode.
 */
enum CipherMode {
  ECB("ECB", 0, 0, "AES/ECB/PKCS5Padding", false),
  CBC("CBC", CBC_IV_LEN, 0, "AES/CBC/PKCS5Padding", true),
  GCM("GCM", GCM_IV_LEN, GCM_TAG_LEN, "AES/GCM/NoPadding", true);

  /** User-facing mode name, matched case-insensitively by {@link #fromString}. */
  private final String name;
  final int ivLength;
  final int tagLength;
  final String transformation;
  final boolean usesSpec;

  CipherMode(
      String modeLabel,
      int ivBytes,
      int tagBits,
      String jceTransformation,
      boolean needsSpec) {
    name = modeLabel;
    ivLength = ivBytes;
    tagLength = tagBits;
    transformation = jceTransformation;
    usesSpec = needsSpec;
  }

  /**
   * Resolves a (mode, padding) pair to a supported cipher mode.
   *
   * ECB and CBC accept padding "PKCS" or "DEFAULT"; GCM accepts "NONE" or "DEFAULT".
   * Both arguments are compared ignoring case.
   *
   * @param modeName the requested mode name
   * @param padding the requested padding name
   * @return the matching mode
   * @throws the error built by QueryExecutionErrors.aesModeUnsupportedError when no
   *         supported combination matches
   */
  static CipherMode fromString(String modeName, String padding) {
    for (CipherMode candidate : values()) {
      // The mode name is tested first so padding is only inspected once a mode matches.
      if (!modeName.equalsIgnoreCase(candidate.name)) {
        continue;
      }
      final String explicitPadding = (candidate == GCM) ? "NONE" : "PKCS";
      if (padding.equalsIgnoreCase(explicitPadding) || padding.equalsIgnoreCase("DEFAULT")) {
        return candidate;
      }
    }
    throw QueryExecutionErrors.aesModeUnsupportedError(modeName, padding);
  }
}

/**
* Function to check if a given number string is a valid Luhn number
Expand Down Expand Up @@ -85,113 +112,73 @@ public static byte[] aesDecrypt(byte[] input, byte[] key, UTF8String mode, UTF8S
return aesInternal(input, key, mode.toString(), padding.toString(), Cipher.DECRYPT_MODE);
}

/**
 * Wraps raw key bytes as an AES {@code SecretKeySpec}.
 *
 * @param key raw key material; must be 16, 24 or 32 bytes long
 * @return a key spec over the full array for algorithm "AES"
 * @throws the error built by QueryExecutionErrors.invalidAesKeyLengthError for any other length
 */
private static SecretKeySpec getSecretKeySpec(byte[] key) {
  final int keyLen = key.length;
  if (keyLen != 16 && keyLen != 24 && keyLen != 32) {
    throw QueryExecutionErrors.invalidAesKeyLengthError(keyLen);
  }
  return new SecretKeySpec(key, 0, keyLen, "AES");
}

/**
 * Produces a fresh random IV sized for the given mode.
 *
 * @param mode cipher mode whose {@code ivLength} determines the array size
 * @return {@code mode.ivLength} random bytes from the calling thread's SecureRandom
 *         (an empty array when the length is 0)
 */
private static byte[] generateIv(CipherMode mode) {
  final byte[] freshIv = new byte[mode.ivLength];
  final SecureRandom rng = threadLocalSecureRandom.get();
  rng.nextBytes(freshIv);
  return freshIv;
}

/**
 * Builds the parameter spec for modes that need one, reading the IV from the start of
 * {@code input}.
 *
 * @param mode cipher mode
 * @param input array whose first {@code mode.ivLength} bytes are the IV
 * @return an IvParameterSpec for CBC, a GCMParameterSpec for GCM, or null for any other mode
 */
private static AlgorithmParameterSpec getParamSpec(CipherMode mode, byte[] input) {
  // Dereference first so a null mode fails immediately instead of falling through to null.
  final int ivLen = mode.ivLength;
  if (mode == CipherMode.CBC) {
    return new IvParameterSpec(input, 0, ivLen);
  }
  if (mode == CipherMode.GCM) {
    return new GCMParameterSpec(mode.tagLength, input, 0, ivLen);
  }
  return null;
}

/**
* Shared implementation behind AES encryption and decryption.
*
* NOTE(review): as shown here the body interleaves two variants of the implementation
* (an inline per-mode if/else chain and a CipherMode-driven one) -- confirm which lines
* belong to the current revision before relying on this description.
*
* In the CipherMode-driven variant the key is checked by getSecretKeySpec, the mode/padding
* pair is resolved by CipherMode.fromString, and a Cipher is obtained for that mode's
* transformation. When encrypting, an IV is produced by generateIv (zero-length for ECB) and,
* if non-empty, is written ahead of the ciphertext in the returned array. When decrypting a
* mode that uses a parameter spec, the first ivLength bytes of input are read as the IV and
* only the remaining bytes are passed to doFinal.
*
* @param input bytes to encrypt or decrypt
* @param key raw AES key; lengths other than 16, 24 or 32 bytes are rejected
* @param mode cipher mode name, compared ignoring case
* @param padding padding name, compared ignoring case
* @param opmode Cipher.ENCRYPT_MODE or Cipher.DECRYPT_MODE
* @return the cipher output; a GeneralSecurityException is converted through
*         QueryExecutionErrors.aesCryptoError using only the exception message
*/
private static byte[] aesInternal(
byte[] input,
byte[] key,
String mode,
String padding,
int opmode) {
SecretKeySpec secretKey;

switch (key.length) {
case 16:
case 24:
case 32:
secretKey = new SecretKeySpec(key, 0, key.length, "AES");
break;
default:
throw QueryExecutionErrors.invalidAesKeyLengthError(key.length);
}

try {
if (mode.equalsIgnoreCase("ECB") &&
(padding.equalsIgnoreCase("PKCS") || padding.equalsIgnoreCase("DEFAULT"))) {
Cipher cipher = Cipher.getInstance("AES/ECB/PKCS5Padding");
cipher.init(opmode, secretKey);
return cipher.doFinal(input, 0, input.length);
} else if (mode.equalsIgnoreCase("GCM") &&
(padding.equalsIgnoreCase("NONE") || padding.equalsIgnoreCase("DEFAULT"))) {
Cipher cipher = Cipher.getInstance("AES/GCM/NoPadding");
if (opmode == Cipher.ENCRYPT_MODE) {
byte[] iv = new byte[GCM_IV_LEN];
secureRandom.nextBytes(iv);
GCMParameterSpec parameterSpec = new GCMParameterSpec(GCM_TAG_LEN, iv);
cipher.init(Cipher.ENCRYPT_MODE, secretKey, parameterSpec);
byte[] encrypted = cipher.doFinal(input, 0, input.length);
SecretKeySpec secretKey = getSecretKeySpec(key);
CipherMode cipherMode = CipherMode.fromString(mode, padding);
Cipher cipher = Cipher.getInstance(cipherMode.transformation);
if (opmode == Cipher.ENCRYPT_MODE) {
// This IV will be 0-length for ECB
byte[] iv = generateIv(cipherMode);
if (cipherMode.usesSpec) {
AlgorithmParameterSpec algSpec = getParamSpec(cipherMode, iv);
cipher.init(opmode, secretKey, algSpec);
} else {
cipher.init(opmode, secretKey);
}
byte[] encrypted = cipher.doFinal(input, 0, input.length);
if (iv.length > 0) {
ByteBuffer byteBuffer = ByteBuffer.allocate(iv.length + encrypted.length);
byteBuffer.put(iv);
byteBuffer.put(encrypted);
return byteBuffer.array();
} else {
assert(opmode == Cipher.DECRYPT_MODE);
GCMParameterSpec parameterSpec = new GCMParameterSpec(GCM_TAG_LEN, input, 0, GCM_IV_LEN);
cipher.init(Cipher.DECRYPT_MODE, secretKey, parameterSpec);
return cipher.doFinal(input, GCM_IV_LEN, input.length - GCM_IV_LEN);
return encrypted;
}
} else if (mode.equalsIgnoreCase("CBC") &&
(padding.equalsIgnoreCase("PKCS") || padding.equalsIgnoreCase("DEFAULT"))) {
Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
if (opmode == Cipher.ENCRYPT_MODE) {
byte[] salt = new byte[CBC_SALT_LEN];
secureRandom.nextBytes(salt);
final byte[] keyAndIv = getKeyAndIv(key, salt);
final byte[] keyValue = Arrays.copyOfRange(keyAndIv, 0, key.length);
final byte[] iv = Arrays.copyOfRange(keyAndIv, key.length, key.length + CBC_IV_LEN);
cipher.init(
Cipher.ENCRYPT_MODE,
new SecretKeySpec(keyValue, "AES"),
new IvParameterSpec(iv));
byte[] encrypted = cipher.doFinal(input, 0, input.length);
ByteBuffer byteBuffer = ByteBuffer.allocate(
SALTED_MAGIC.length + CBC_SALT_LEN + encrypted.length);
byteBuffer.put(SALTED_MAGIC);
byteBuffer.put(salt);
byteBuffer.put(encrypted);
return byteBuffer.array();
} else {
assert(opmode == Cipher.DECRYPT_MODE);
if (cipherMode.usesSpec) {
AlgorithmParameterSpec algSpec = getParamSpec(cipherMode, input);
cipher.init(opmode, secretKey, algSpec);
return cipher.doFinal(input, cipherMode.ivLength, input.length - cipherMode.ivLength);
} else {
assert(opmode == Cipher.DECRYPT_MODE);
final byte[] shouldBeMagic = Arrays.copyOfRange(input, 0, SALTED_MAGIC.length);
if (!Arrays.equals(shouldBeMagic, SALTED_MAGIC)) {
throw QueryExecutionErrors.aesInvalidSalt(shouldBeMagic);
}
final byte[] salt = Arrays.copyOfRange(
input, SALTED_MAGIC.length, SALTED_MAGIC.length + CBC_SALT_LEN);
final byte[] keyAndIv = getKeyAndIv(key, salt);
final byte[] keyValue = Arrays.copyOfRange(keyAndIv, 0, key.length);
final byte[] iv = Arrays.copyOfRange(keyAndIv, key.length, key.length + CBC_IV_LEN);
cipher.init(
Cipher.DECRYPT_MODE,
new SecretKeySpec(keyValue, "AES"),
new IvParameterSpec(iv, 0, CBC_IV_LEN));
return cipher.doFinal(input, CBC_IV_LEN, input.length - CBC_IV_LEN);
cipher.init(opmode, secretKey);
return cipher.doFinal(input, 0, input.length);
}
} else {
throw QueryExecutionErrors.aesModeUnsupportedError(mode, padding);
}
} catch (GeneralSecurityException e) {
throw QueryExecutionErrors.aesCryptoError(e.getMessage());
}
}

/**
 * Derives key and IV material from a key and salt in the same way as OpenSSL's
 * EVP_BytesToKey since version 1.1.0c, which switched to SHA-256 as the hash.
 *
 * Each round hashes (previous round's hash || key || salt) and appends the digest to the
 * output. Rounds stop once at least {@code key.length + CBC_IV_LEN} bytes exist, or after
 * three rounds. Whole digests are appended, so the result may be longer than requested;
 * callers slice out the ranges they need.
 *
 * @param key the input key bytes
 * @param salt the salt bytes
 * @return the concatenated digests, key material first and IV material after it
 * @throws NoSuchAlgorithmException if no SHA-256 MessageDigest provider is available
 */
private static byte[] getKeyAndIv(byte[] key, byte[] salt) throws NoSuchAlgorithmException {
  final int requiredLen = key.length + CBC_IV_LEN;
  // One digest instance serves every round: digest() resets it after producing a hash.
  final MessageDigest md = MessageDigest.getInstance("SHA-256");
  byte[] hash = new byte[0];
  byte[] keyAndIv = new byte[0];
  for (int i = 0; i < 3 && keyAndIv.length < requiredLen; i++) {
    // Feed the pieces incrementally instead of building temporary concatenated copies
    // of the key material on every round.
    md.update(hash);
    md.update(key);
    md.update(salt);
    hash = md.digest();
    keyAndIv = arrConcat(keyAndIv, hash);
  }
  return keyAndIv;
}

/**
 * Returns a new array holding the bytes of {@code arr1} followed by the bytes of {@code arr2}.
 * Neither input is modified.
 */
private static byte[] arrConcat(final byte[] arr1, final byte[] arr2) {
  final int headLen = arr1.length;
  final int tailLen = arr2.length;
  // copyOf both allocates the combined array and fills in the leading arr1 bytes.
  final byte[] joined = Arrays.copyOf(arr1, headLen + tailLen);
  System.arraycopy(arr2, 0, joined, headLen, tailLen);
  return joined;
}
}
Expand Up @@ -334,7 +334,7 @@ case class CurrentUser() extends LeafExpression with Unevaluable {
> SELECT base64(_FUNC_('Spark SQL', '1234567890abcdef', 'ECB', 'PKCS'));
3lmwu+Mw0H3fi5NDvcu9lg==
> SELECT base64(_FUNC_('Apache Spark', '1234567890abcdef', 'CBC', 'DEFAULT'));
U2FsdGVkX18JQ84pfRUwonUrFzpWQ46vKu4+MkJVFGM=
2NYmDCjgXTbbxGA3/SnJEfFC/JQ7olk2VQWReIAAFKo=
""",
since = "3.3.0",
group = "misc_funcs")
Expand Down Expand Up @@ -399,7 +399,7 @@ case class AesEncrypt(
Spark SQL
> SELECT _FUNC_(unbase64('3lmwu+Mw0H3fi5NDvcu9lg=='), '1234567890abcdef', 'ECB', 'PKCS');
Spark SQL
> SELECT _FUNC_(unbase64('U2FsdGVkX18JQ84pfRUwonUrFzpWQ46vKu4+MkJVFGM='), '1234567890abcdef', 'CBC');
> SELECT _FUNC_(unbase64('2NYmDCjgXTbbxGA3/SnJEfFC/JQ7olk2VQWReIAAFKo='), '1234567890abcdef', 'CBC');
Apache Spark
""",
since = "3.3.0",
Expand Down
Expand Up @@ -344,6 +344,30 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
}

// Decrypts one fixed base64 ciphertext per mode (ECB, GCM, CBC) with the same 32-character
// key and expects the plaintext "Spark" each time.
test("misc aes function") {
val key32 = "abcdefghijklmnop12345678ABCDEFGH"
val encryptedEcb = "9J3iZbIxnmaG+OIA9Amd+A=="
val encryptedGcm = "y5la3muiuxN2suj6VsYXB+0XUFjtrUD0/zv5eDafsA3U"
val encryptedCbc = "+MgyzJxhusYVGWCljk7fhhl6C6oUqWmtdqoaG93KvhY="
val df1 = Seq("Spark").toDF

// Successful decryption of fixed values
Seq(
(key32, encryptedEcb, "ECB"),
(key32, encryptedGcm, "GCM"),
(key32, encryptedCbc, "CBC")).foreach {
case (key, encryptedText, mode) =>
// Key supplied as a string literal.
checkAnswer(
df1.selectExpr(
s"cast(aes_decrypt(unbase64('$encryptedText'), '$key', '$mode') as string)"),
Seq(Row("Spark")))
// Same check with the key wrapped in binary(...).
checkAnswer(
df1.selectExpr(
s"cast(aes_decrypt(unbase64('$encryptedText'), binary('$key'), '$mode') as string)"),
Seq(Row("Spark")))
}
}

test("misc aes ECB function") {
val key16 = "abcdefghijklmnop"
val key24 = "abcdefghijklmnop12345678"
val key32 = "abcdefghijklmnop12345678ABCDEFGH"
Expand All @@ -358,15 +382,15 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {

// Successful encryption
Seq(
(key16, encryptedText16, encryptedEmptyText16),
(key24, encryptedText24, encryptedEmptyText24),
(key32, encryptedText32, encryptedEmptyText32)).foreach {
case (key, encryptedText, encryptedEmptyText) =>
(key16, encryptedText16, encryptedEmptyText16, "ECB"),
(key24, encryptedText24, encryptedEmptyText24, "ECB"),
(key32, encryptedText32, encryptedEmptyText32, "ECB")).foreach {
case (key, encryptedText, encryptedEmptyText, mode) =>
checkAnswer(
df1.selectExpr(s"base64(aes_encrypt(value, '$key', 'ECB'))"),
df1.selectExpr(s"base64(aes_encrypt(value, '$key', '$mode'))"),
Seq(Row(encryptedText), Row(encryptedEmptyText)))
checkAnswer(
df1.selectExpr(s"base64(aes_encrypt(binary(value), '$key', 'ECB'))"),
df1.selectExpr(s"base64(aes_encrypt(binary(value), '$key', '$mode'))"),
Seq(Row(encryptedText), Row(encryptedEmptyText)))
}

Expand Down

0 comments on commit cd33dd5

Please sign in to comment.