Skip to content

Commit

Permalink
[SPARK-43038][SQL] Support the CBC mode by aes_encrypt()/`aes_decry…
Browse files Browse the repository at this point in the history
…pt()`

### What changes were proposed in this pull request?
In the PR, I propose a new AES mode for the `aes_encrypt()`/`aes_decrypt()` functions - `CBC` ([Cipher Block Chaining](https://www.ibm.com/docs/en/linux-on-systems?topic=operation-cipher-block-chaining-cbc-mode)) with the padding `PKCS7(5)`. The `aes_encrypt()` function returns a binary value which consists of the following fields:
1. The salt magic prefix `Salted__` with the length of 8 bytes.
2. A salt generated on every `aes_encrypt()` call using `java.security.SecureRandom`. Its length is 8 bytes.
3. The encrypted input.

The encrypt function derives the secret key and initialization vector (16 bytes) from the salt and user's key using the same algorithm as OpenSSL's `EVP_BytesToKey()` (versions >= 1.1.0c).

The `aes_decrypt()` function assumes that its input has the fields shown above.

For example:
```sql
spark-sql> SELECT base64(aes_encrypt('Apache Spark', '0000111122223333', 'CBC', 'PKCS'));
U2FsdGVkX1/ERGxwEOTDpDD4bQvDtQaNe+gXGudCcUk=
spark-sql> SELECT aes_decrypt(unbase64('U2FsdGVkX1/ERGxwEOTDpDD4bQvDtQaNe+gXGudCcUk='), '0000111122223333', 'CBC', 'PKCS');
Apache Spark
```

### Why are the changes needed?
To achieve feature parity with other systems/frameworks, and make the migration process from them to Spark SQL easier. For example, the `CBC` mode is supported by:
- BigQuery: https://cloud.google.com/bigquery/docs/reference/standard-sql/aead-encryption-concepts#block_cipher_modes
- Snowflake: https://docs.snowflake.com/en/sql-reference/functions/encrypt.html

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
By running new checks:
```
$ build/sbt "sql/testOnly *QueryExecutionErrorsSuite"
$ build/sbt "sql/test:testOnly org.apache.spark.sql.expressions.ExpressionInfoSuite"
$ build/sbt "test:testOnly org.apache.spark.sql.MiscFunctionsSuite"
$ build/sbt "core/testOnly *SparkThrowableSuite"
```
and checked compatibility with LibreSSL/OpenSSL:
```
$ openssl version
LibreSSL 3.3.6
$ echo -n 'Apache Spark' | openssl enc -e -aes-128-cbc -pass pass:0000111122223333 -a
U2FsdGVkX1+5GyAmmG7wDWWDBAuUuxjMy++cMFytpls=
```
```sql
spark-sql (default)> SELECT aes_decrypt(unbase64('U2FsdGVkX1+5GyAmmG7wDWWDBAuUuxjMy++cMFytpls='), '0000111122223333', 'CBC');
Apache Spark
```
and decrypted Spark's output with OpenSSL:
```sql
spark-sql (default)> SELECT base64(aes_encrypt('Apache Spark', 'abcdefghijklmnop12345678ABCDEFGH', 'CBC', 'PKCS'));
U2FsdGVkX1+maU2vmxrulgxXuQSyZ3ODnlHKqnt2fDA=
```
```
$ echo 'U2FsdGVkX1+maU2vmxrulgxXuQSyZ3ODnlHKqnt2fDA=' | openssl aes-256-cbc -a -d -pass pass:abcdefghijklmnop12345678ABCDEFGH
Apache Spark
```

Closes #40704 from MaxGekk/aes-cbc.

Authored-by: Max Gekk <max.gekk@gmail.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
  • Loading branch information
MaxGekk committed Apr 12, 2023
1 parent 74d840c commit dabd771
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 25 deletions.
5 changes: 5 additions & 0 deletions core/src/main/resources/error/error-classes.json
Expand Up @@ -978,6 +978,11 @@
"expects a binary value with 16, 24 or 32 bytes, but got <actualLength> bytes."
]
},
"AES_SALTED_MAGIC" : {
"message" : [
"Initial bytes from input <saltedMagic> do not match 'Salted__' (0x53616C7465645F5F)."
]
},
"PATTERN" : {
"message" : [
"<value>."
Expand Down
Expand Up @@ -22,10 +22,16 @@

import javax.crypto.Cipher;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Arrays;

import static java.nio.charset.StandardCharsets.US_ASCII;

/**
* A utility class for constructing expressions.
Expand All @@ -35,6 +41,13 @@ public class ExpressionImplUtils {
private static final int GCM_IV_LEN = 12;
private static final int GCM_TAG_LEN = 128;

private static final int CBC_IV_LEN = 16;
private static final int CBC_SALT_LEN = 8;
/** OpenSSL's magic initial bytes. */
private static final String SALTED_STR = "Salted__";
private static final byte[] SALTED_MAGIC = SALTED_STR.getBytes(US_ASCII);


/**
* Function to check if a given number string is a valid Luhn number
* @param numberString
Expand Down Expand Up @@ -115,11 +128,70 @@ private static byte[] aesInternal(
cipher.init(Cipher.DECRYPT_MODE, secretKey, parameterSpec);
return cipher.doFinal(input, GCM_IV_LEN, input.length - GCM_IV_LEN);
}
} else if (mode.equalsIgnoreCase("CBC") &&
(padding.equalsIgnoreCase("PKCS") || padding.equalsIgnoreCase("DEFAULT"))) {
Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
if (opmode == Cipher.ENCRYPT_MODE) {
byte[] salt = new byte[CBC_SALT_LEN];
secureRandom.nextBytes(salt);
final byte[] keyAndIv = getKeyAndIv(key, salt);
final byte[] keyValue = Arrays.copyOfRange(keyAndIv, 0, key.length);
final byte[] iv = Arrays.copyOfRange(keyAndIv, key.length, key.length + CBC_IV_LEN);
cipher.init(
Cipher.ENCRYPT_MODE,
new SecretKeySpec(keyValue, "AES"),
new IvParameterSpec(iv));
byte[] encrypted = cipher.doFinal(input, 0, input.length);
ByteBuffer byteBuffer = ByteBuffer.allocate(
SALTED_MAGIC.length + CBC_SALT_LEN + encrypted.length);
byteBuffer.put(SALTED_MAGIC);
byteBuffer.put(salt);
byteBuffer.put(encrypted);
return byteBuffer.array();
} else {
assert(opmode == Cipher.DECRYPT_MODE);
final byte[] shouldBeMagic = Arrays.copyOfRange(input, 0, SALTED_MAGIC.length);
if (!Arrays.equals(shouldBeMagic, SALTED_MAGIC)) {
throw QueryExecutionErrors.aesInvalidSalt(shouldBeMagic);
}
final byte[] salt = Arrays.copyOfRange(
input, SALTED_MAGIC.length, SALTED_MAGIC.length + CBC_SALT_LEN);
final byte[] keyAndIv = getKeyAndIv(key, salt);
final byte[] keyValue = Arrays.copyOfRange(keyAndIv, 0, key.length);
final byte[] iv = Arrays.copyOfRange(keyAndIv, key.length, key.length + CBC_IV_LEN);
cipher.init(
Cipher.DECRYPT_MODE,
new SecretKeySpec(keyValue, "AES"),
new IvParameterSpec(iv, 0, CBC_IV_LEN));
return cipher.doFinal(input, CBC_IV_LEN, input.length - CBC_IV_LEN);
}
} else {
throw QueryExecutionErrors.aesModeUnsupportedError(mode, padding);
}
} catch (GeneralSecurityException e) {
throw QueryExecutionErrors.aesCryptoError(e.getMessage());
}
}

/**
 * Derives the material for the secret key and the initialization vector in the same way as
 * OpenSSL's EVP_BytesToKey since the version 1.1.0c, which switched to SHA-256 as the hash.
 *
 * Each round hashes the previous round's digest (empty on the first round) followed by
 * {@code key} and {@code salt}, and appends the whole digest to the result. Rounds stop once
 * at least {@code key.length + CBC_IV_LEN} bytes have been produced, or after 3 rounds.
 * Because whole digests are appended, the returned array may be longer than requested; the
 * caller is expected to slice out the key and IV portions it needs.
 *
 * @param key the user-supplied key bytes
 * @param salt the salt bytes mixed into every round
 * @return the concatenated digests; the leading bytes are meant for the key, the following
 *         bytes for the IV
 * @throws NoSuchAlgorithmException if the SHA-256 digest is not available
 */
private static byte[] getKeyAndIv(byte[] key, byte[] salt) throws NoSuchAlgorithmException {
  final byte[] secretWithSalt = arrConcat(key, salt);
  final int wantedLen = key.length + CBC_IV_LEN;
  byte[] derived = new byte[0];
  byte[] lastDigest = new byte[0];
  int round = 0;
  while (round < 3 && derived.length < wantedLen) {
    final MessageDigest sha256 = MessageDigest.getInstance("SHA-256");
    // Chain the rounds: the previous digest is the prefix of this round's hash input.
    lastDigest = sha256.digest(arrConcat(lastDigest, secretWithSalt));
    derived = arrConcat(derived, lastDigest);
    round++;
  }
  return derived;
}

/**
 * Returns a newly allocated array holding the bytes of {@code arr1} followed by the bytes
 * of {@code arr2}. Neither input array is modified.
 *
 * @param arr1 the bytes that form the head of the result
 * @param arr2 the bytes that form the tail of the result
 * @return a new array of length {@code arr1.length + arr2.length}
 */
private static byte[] arrConcat(final byte[] arr1, final byte[] arr2) {
  final int totalLen = arr1.length + arr2.length;
  // copyOf allocates the result and fills its head with arr1; the tail is written below.
  final byte[] joined = Arrays.copyOf(arr1, totalLen);
  System.arraycopy(arr2, 0, joined, arr1.length, arr2.length);
  return joined;
}
}
Expand Up @@ -313,17 +313,17 @@ case class CurrentUser() extends LeafExpression with Unevaluable {
@ExpressionDescription(
usage = """
_FUNC_(expr, key[, mode[, padding]]) - Returns an encrypted value of `expr` using AES in given `mode` with the specified `padding`.
Key lengths of 16, 24 and 32 bits are supported. Supported combinations of (`mode`, `padding`) are ('ECB', 'PKCS') and ('GCM', 'NONE').
Key lengths of 16, 24 and 32 bits are supported. Supported combinations of (`mode`, `padding`) are ('ECB', 'PKCS'), ('GCM', 'NONE') and ('CBC', 'PKCS').
The default mode is GCM.
""",
arguments = """
Arguments:
* expr - The binary value to encrypt.
* key - The passphrase to use to encrypt the data.
* mode - Specifies which block cipher mode should be used to encrypt messages.
Valid modes: ECB, GCM.
Valid modes: ECB, GCM, CBC.
* padding - Specifies how to pad messages whose length is not a multiple of the block size.
Valid values: PKCS, NONE, DEFAULT. The DEFAULT padding means PKCS for ECB and NONE for GCM.
Valid values: PKCS, NONE, DEFAULT. The DEFAULT padding means PKCS for ECB, NONE for GCM and PKCS for CBC.
""",
examples = """
Examples:
Expand All @@ -333,6 +333,8 @@ case class CurrentUser() extends LeafExpression with Unevaluable {
6E7CA17BBB468D3084B5744BCA729FB7B2B7BCB8E4472847D02670489D95FA97DBBA7D3210
> SELECT base64(_FUNC_('Spark SQL', '1234567890abcdef', 'ECB', 'PKCS'));
3lmwu+Mw0H3fi5NDvcu9lg==
> SELECT base64(_FUNC_('Apache Spark', '1234567890abcdef', 'CBC', 'DEFAULT'));
U2FsdGVkX18JQ84pfRUwonUrFzpWQ46vKu4+MkJVFGM=
""",
since = "3.3.0",
group = "misc_funcs")
Expand Down Expand Up @@ -377,17 +379,17 @@ case class AesEncrypt(
@ExpressionDescription(
usage = """
_FUNC_(expr, key[, mode[, padding]]) - Returns a decrypted value of `expr` using AES in `mode` with `padding`.
Key lengths of 16, 24 and 32 bits are supported. Supported combinations of (`mode`, `padding`) are ('ECB', 'PKCS') and ('GCM', 'NONE').
Key lengths of 16, 24 and 32 bits are supported. Supported combinations of (`mode`, `padding`) are ('ECB', 'PKCS'), ('GCM', 'NONE') and ('CBC', 'PKCS').
The default mode is GCM.
""",
arguments = """
Arguments:
* expr - The binary value to decrypt.
* key - The passphrase to use to decrypt the data.
* mode - Specifies which block cipher mode should be used to decrypt messages.
Valid modes: ECB, GCM.
Valid modes: ECB, GCM, CBC.
* padding - Specifies how to pad messages whose length is not a multiple of the block size.
Valid values: PKCS, NONE, DEFAULT. The DEFAULT padding means PKCS for ECB and NONE for GCM.
Valid values: PKCS, NONE, DEFAULT. The DEFAULT padding means PKCS for ECB, NONE for GCM and PKCS for CBC.
""",
examples = """
Examples:
Expand All @@ -397,6 +399,8 @@ case class AesEncrypt(
Spark SQL
> SELECT _FUNC_(unbase64('3lmwu+Mw0H3fi5NDvcu9lg=='), '1234567890abcdef', 'ECB', 'PKCS');
Spark SQL
> SELECT _FUNC_(unbase64('U2FsdGVkX18JQ84pfRUwonUrFzpWQ46vKu4+MkJVFGM='), '1234567890abcdef', 'CBC');
Apache Spark
""",
since = "3.3.0",
group = "misc_funcs")
Expand Down
Expand Up @@ -2651,6 +2651,15 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
"detailMessage" -> detailMessage))
}

/**
 * Builds the error reported when the leading bytes of an `aes_decrypt` input do not match
 * the expected salt magic.
 *
 * @param saltedMagic the bytes found where the magic prefix was expected; they are rendered
 *                    in the message as upper-case, two-digit hex values prefixed with "0x"
 * @return the exception, for the caller to throw
 */
def aesInvalidSalt(saltedMagic: Array[Byte]): RuntimeException = {
  val parameterId = toSQLId("expr")
  val functionId = toSQLId("aes_decrypt")
  val magicAsHex = saltedMagic.map(b => "%02X".format(b)).mkString("0x", "", "")
  val params = Map(
    "parameter" -> parameterId,
    "functionName" -> functionId,
    "saltedMagic" -> magicAsHex)
  new SparkRuntimeException(
    errorClass = "INVALID_PARAMETER_VALUE.AES_SALTED_MAGIC",
    messageParameters = params)
}

def hiveTableWithAnsiIntervalsError(tableName: String): SparkUnsupportedOperationException = {
new SparkUnsupportedOperationException(
errorClass = "_LEGACY_ERROR_TEMP_2276",
Expand Down
Expand Up @@ -62,21 +62,26 @@ class MiscFunctionsSuite extends QueryTest with SharedSparkSession {
}
}

test("SPARK-37591: AES functions - GCM mode") {
test("SPARK-37591, SPARK-43038: AES functions - GCM/CBC mode") {
Seq(
("abcdefghijklmnop", ""),
("abcdefghijklmnop", "abcdefghijklmnop"),
("abcdefghijklmnop12345678", "Spark"),
("abcdefghijklmnop12345678ABCDEFGH", "GCM mode")
).foreach { case (key, input) =>
val df = Seq((key, input)).toDF("key", "input")
val encrypted = df.selectExpr("aes_encrypt(input, key, 'GCM', 'NONE') AS enc", "input", "key")
assert(encrypted.schema("enc").dataType === BinaryType)
assert(encrypted.filter($"enc" === $"input").isEmpty)
val result = encrypted.selectExpr(
"CAST(aes_decrypt(enc, key, 'GCM', 'NONE') AS STRING) AS res", "input")
assert(!result.filter($"res" === $"input").isEmpty &&
result.filter($"res" =!= $"input").isEmpty)
"GCM" -> "NONE",
"CBC" -> "PKCS").foreach { case (mode, padding) =>
Seq(
("abcdefghijklmnop", ""),
("abcdefghijklmnop", "abcdefghijklmnop"),
("abcdefghijklmnop12345678", "Spark"),
("abcdefghijklmnop12345678ABCDEFGH", "GCM mode")
).foreach { case (key, input) =>
val df = Seq((key, input)).toDF("key", "input")
val encrypted = df.selectExpr(
s"aes_encrypt(input, key, '$mode', '$padding') AS enc", "input", "key")
assert(encrypted.schema("enc").dataType === BinaryType)
assert(encrypted.filter($"enc" === $"input").isEmpty)
val result = encrypted.selectExpr(
s"CAST(aes_decrypt(enc, key, '$mode', '$padding') AS STRING) AS res", "input")
assert(!result.filter($"res" === $"input").isEmpty &&
result.filter($"res" =!= $"input").isEmpty)
}
}
}
}
Expand Down
Expand Up @@ -140,6 +140,25 @@ class QueryExecutionErrorsSuite
}
}

test("INVALID_PARAMETER_VALUE.AES_SALTED_MAGIC: AES decrypt failure - invalid salt") {
  // The payload below decodes to bytes that do not begin with the salt magic, so the
  // CBC decryption is expected to reject it and report the offending leading bytes.
  val query =
    """
      |SELECT aes_decrypt(
      | unbase64('INVALID_SALT_ERGxwEOTDpDD4bQvDtQaNe+gXGudCcUk='),
      | '0000111122223333',
      | 'CBC', 'PKCS')
      |""".stripMargin
  val thrown = intercept[SparkRuntimeException] {
    sql(query).collect()
  }
  val expectedParams = Map(
    "parameter" -> "`expr`",
    "functionName" -> "`aes_decrypt`",
    "saltedMagic" -> "0x20D5402C80D200B4")
  checkError(
    exception = thrown,
    errorClass = "INVALID_PARAMETER_VALUE.AES_SALTED_MAGIC",
    parameters = expectedParams,
    sqlState = "22023")
}

test("UNSUPPORTED_FEATURE: unsupported combinations of AES modes and padding") {
val key16 = "abcdefghijklmnop"
val key32 = "abcdefghijklmnop12345678ABCDEFGH"
Expand All @@ -157,18 +176,20 @@ class QueryExecutionErrorsSuite
}

// Unsupported AES mode and padding in encrypt
checkUnsupportedMode(df1.selectExpr(s"aes_encrypt(value, '$key16', 'CBC')"),
"CBC", "DEFAULT")
checkUnsupportedMode(df1.selectExpr(s"aes_encrypt(value, '$key16', 'CBC', 'None')"),
"CBC", "None")
checkUnsupportedMode(df1.selectExpr(s"aes_encrypt(value, '$key16', 'ECB', 'NoPadding')"),
"ECB", "NoPadding")

// Unsupported AES mode and padding in decrypt
checkUnsupportedMode(df2.selectExpr(s"aes_decrypt(value16, '$key16', 'GSM')"),
"GSM", "DEFAULT")
"GSM", "DEFAULT")
checkUnsupportedMode(df2.selectExpr(s"aes_decrypt(value16, '$key16', 'GCM', 'PKCS')"),
"GCM", "PKCS")
"GCM", "PKCS")
checkUnsupportedMode(df2.selectExpr(s"aes_decrypt(value32, '$key32', 'ECB', 'None')"),
"ECB", "None")
"ECB", "None")
checkUnsupportedMode(df2.selectExpr(s"aes_decrypt(value32, '$key32', 'CBC', 'NoPadding')"),
"CBC", "NoPadding")
}

test("UNSUPPORTED_FEATURE: unsupported types (map and struct) in lit()") {
Expand Down

0 comments on commit dabd771

Please sign in to comment.