* https://docs.databricks.com/aws/en/udf/python-batch-udf
* https://docs.databricks.com/aws/en/connect/unity-catalog/cloud-services/use-service-credentials 
* https://docs.databricks.com/aws/en/sql/language-manual/functions/aes_decrypt
* https://docs.databricks.com/aws/en/sql/language-manual/functions/aes_encrypt
* https://docs.databricks.com/aws/en/udf/udf-task-context?language=PySpark+UDF
* https://docs.databricks.com/aws/en/udf/pandas
* https://docs.google.com/document/d/1fj2Mt9FtzWr5wB7kpa-TxVasO3wzqxaQ5uFLBci7_Q4/edit?tab=t.0

In [0]:
%pip install pycryptodome==3.22.0

In [0]:
# Keys are wrapped with pycryptodome -> stored in AWS Secrets Manager (SM)

    # One KEK per catalog? Stored as a secret named like ->
        # <env>/<catalog_name>/<year>
        # example: prod/aweaver_catalog_1323553108280374/2025

            # kek_password = (String)
            # kek_salt = (String)

    # The KEK is used to encrypt/decrypt the DEK material:
        # encrypted_dek (String)
        # encrypted_iv (String)
        # encrypted_aad (String)
        
        # For each of the three values above, also store:
            # nonce (String)
            # tag (String)

In [0]:
import secrets
from base64 import b64encode
from Crypto.Random import get_random_bytes

# KEK password: 32 random bytes from the stdlib `secrets` module, kept as
# base64 text so it can be pasted into a secret as a string.
password_bytes = secrets.token_bytes(32)
kek_password = b64encode(password_bytes).decode("utf-8")
print(f"kek_password: {kek_password}")

# KEK salt: 32 random bytes from pycryptodome, also base64 text. It is
# base64-decoded again before being handed to scrypt.
salt_bytes = get_random_bytes(32)
kek_salt = b64encode(salt_bytes).decode("utf-8")
print(f"kek_salt: {kek_salt}")

In [0]:
import random
import secrets
import string
from base64 import b64encode

# Alphabet for the printable IV / AAD values.
_KEY_ALPHABET = string.ascii_uppercase + string.digits

# DEK: 24 random bytes -> 32 base64 characters.
dek = b64encode(secrets.token_bytes(24)).decode('utf-8')

# IV (12 chars) and AAD (8 chars) are key material too, so they are drawn from
# the OS CSPRNG via `secrets.choice`. `random.choices` uses the Mersenne
# Twister, whose output is predictable and must not be used for secrets.
iv = ''.join(secrets.choice(_KEY_ALPHABET) for _ in range(12))
aad = ''.join(secrets.choice(_KEY_ALPHABET) for _ in range(8))

# NOTE: prints cleartext key material; clear the cell output before sharing.
print(f"dek: {dek}")
print(f"iv: {iv}")
print(f"aad: {aad}")

In [0]:
import base64

from Crypto.Cipher import AES
from Crypto.Protocol.KDF import scrypt


def encrypt_with_kek(kek_password, kek_salt, to_encrypt):
    """Encrypt a string with AES-GCM under a KEK derived from a password.

    The KEK is derived with scrypt (32-byte key, N=2**17, r=8, p=1) from
    ``kek_password`` and the base64-decoded ``kek_salt``.

    Args:
        kek_password: Password the KEK is derived from.
        kek_salt: Base64-encoded scrypt salt.
        to_encrypt: Text to encrypt; it is UTF-8 encoded first.

    Returns:
        Tuple ``(ciphertext, nonce, tag)``, each a base64-encoded string.
        All three are needed to decrypt.
    """
    # `base64` and `AES` are imported in this cell so the function works on a
    # fresh kernel; previously it depended on imports made in later cells.
    salt_bytes = base64.b64decode(kek_salt)
    kek = scrypt(kek_password, salt_bytes, key_len=32, N=2**17, r=8, p=1)

    # No nonce is passed, so pycryptodome generates a random one per call.
    cipher = AES.new(kek, AES.MODE_GCM)
    ciphertext, tag_bytes = cipher.encrypt_and_digest(to_encrypt.encode('utf-8'))

    encrypted_string = base64.b64encode(ciphertext).decode('utf-8')
    nonce = base64.b64encode(cipher.nonce).decode('utf-8')
    tag = base64.b64encode(tag_bytes).decode('utf-8')

    return encrypted_string, nonce, tag

In [0]:
# Wrap the generated DEK with the KEK; the nonce and tag are kept for decryption.
# The bare last expression displays the base64 ciphertext.
encrypted_string, nonce, tag = encrypt_with_kek(kek_password, kek_salt, dek)
encrypted_string

In [0]:
import base64

from Crypto.Cipher import AES
from Crypto.Protocol.KDF import scrypt


def decrypt_with_kek(kek_password, kek_salt, to_decrypt, nonce, tag):
    """Decrypt and authenticate an AES-GCM ciphertext with a password-derived KEK.

    The KEK is derived with scrypt (32-byte key, N=2**17, r=8, p=1) from
    ``kek_password`` and the base64-decoded ``kek_salt``.

    Args:
        kek_password: Password the KEK is derived from.
        kek_salt: Base64-encoded scrypt salt.
        to_decrypt: Base64-encoded ciphertext.
        nonce: Base64-encoded GCM nonce used at encryption time.
        tag: Base64-encoded GCM authentication tag.

    Returns:
        The decrypted plaintext, decoded as UTF-8.

    Raises:
        ValueError: If the tag does not match (wrong key or tampered data).
    """
    # `AES` and `scrypt` are imported in this cell so the function does not
    # depend on other cells having been run first.
    salt_bytes = base64.b64decode(kek_salt)
    kek = scrypt(kek_password, salt_bytes, key_len=32, N=2**17, r=8, p=1)
    cipher = AES.new(kek, AES.MODE_GCM, nonce=base64.b64decode(nonce))

    # One call decrypts and checks the tag. It raises ValueError on mismatch,
    # which is the same exception the old try/except re-raised unchanged.
    decrypted = cipher.decrypt_and_verify(
        base64.b64decode(to_decrypt),
        base64.b64decode(tag),
    )

    return decrypted.decode('utf-8')

In [0]:
# Round-trip check: unwrap the ciphertext produced by encrypt_with_kek above.
# The displayed value should equal the `dek` that was passed in for wrapping.
decrypted = decrypt_with_kek(kek_password, kek_salt, encrypted_string, nonce, tag)
decrypted

In [0]:
import boto3
from botocore.exceptions import ClientError
import json
import pandas as pd

# boto3 session backed by the Unity Catalog service credential.
boto3_session = boto3.Session(
    botocore_session=dbutils.credentials.getServiceCredentialsProvider('aweaver-secrets-manager'),
    region_name='eu-west-1',
)

sm = boto3_session.client('secretsmanager')
secret_name = "prod/aweaver_catalog_1323553108280374/2026"

# A failed lookup raises botocore's ClientError, which is left to propagate.
# For the list of error codes, see
# https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
get_secret_value_response = sm.get_secret_value(SecretId=secret_name)

secret = json.loads(get_secret_value_response['SecretString'])

# Show only the field names. Echoing `secret` itself would write the key
# material into the notebook output, where it is saved and shared.
sorted(secret)

In [0]:
%sql
-- Batch Python UDF: for each secret name, fetch the secret from AWS Secrets
-- Manager, unwrap its DEK with the KEK derived from the secret's
-- kek_password/kek_salt, and return (dek, iv, aad).
CREATE OR REPLACE FUNCTION aweaver_catalog_1323553108280374.crypto.unwrap_dek(secret_name STRING) 
RETURNS STRUCT<dek: STRING, iv: STRING, aad: STRING>
LANGUAGE PYTHON
PARAMETER STYLE PANDAS
HANDLER 'batchhandler'
CREDENTIALS (
  `aweaver-secrets-manager` DEFAULT
)
ENVIRONMENT (
  dependencies = '["pycryptodome==3.22.0"]',
  environment_version = 'None'
)
AS $$
import boto3
from pyspark.taskcontext import TaskContext
from botocore.exceptions import ClientError
from Crypto.Cipher import AES
from Crypto.Protocol.KDF import scrypt
import base64
from typing import Iterator
import json
import pandas as pd

def setup_session():
  """Build a Secrets Manager client in the cluster's region.

  boto3.Session() picks up the DEFAULT service credential declared in the
  CREDENTIALS clause of this function.
  """
  session = boto3.Session()
  region = TaskContext.get().getLocalProperty("spark.databricks.clusterUsageTags.region")
  client = session.client("secretsmanager", region_name=region)
  return client

def decrypt_with_kek(kek_password, kek_salt, dek, nonce, tag):
  """Unwrap a base64 AES-GCM ciphertext with a scrypt-derived KEK.

  Raises ValueError if the GCM tag does not match.
  """
  salt = base64.b64decode(kek_salt)
  kek = scrypt(kek_password, salt, key_len=32, N=2**17, r=8, p=1)
  cipher = AES.new(kek, AES.MODE_GCM, nonce=base64.b64decode(nonce))
  # Decrypts and authenticates in one call.
  decrypted = cipher.decrypt_and_verify(base64.b64decode(dek), base64.b64decode(tag))
  return decrypted.decode('utf-8')

client = setup_session()

def unwrap_secret(secret_id):
  """Fetch one secret and return its dek/iv/aad with the DEK unwrapped.

  A NULL secret name yields NULL fields. ClientError from Secrets Manager and
  ValueError from a failed tag check propagate to the caller.
  """
  if secret_id is None or pd.isna(secret_id):
    return {"dek": None, "iv": None, "aad": None}

  response = client.get_secret_value(SecretId=secret_id)
  secret = json.loads(response['SecretString'])
  return {
    "dek": decrypt_with_kek(
      kek_password=secret["kek_password"],
      kek_salt=secret["kek_salt"],
      dek=secret["dek"],
      nonce=secret["dek_nonce"],
      tag=secret["dek_tag"]),
    # NOTE(review): iv and aad are returned exactly as stored in the secret;
    # confirm they are not stored wrapped like the DEK.
    "iv": secret.get("iv"),
    "aad": secret.get("aad"),
  }

def batchhandler(batch_iter: Iterator[pd.Series]) -> Iterator[pd.DataFrame]:
  """Return one output row per input row, one DataFrame per input batch.

  The previous version read only the first value of each batch and yielded a
  single frame at the end, so the output row count did not match the input as
  soon as a batch held more than one row.
  """
  # Each distinct secret is fetched and unwrapped once per invocation, because
  # both the API call and the scrypt derivation (N=2**17) are expensive.
  unwrapped = {}
  for batch in batch_iter:
    rows = []
    for secret_id in batch:
      cache_key = None if (secret_id is None or pd.isna(secret_id)) else secret_id
      if cache_key not in unwrapped:
        unwrapped[cache_key] = unwrap_secret(cache_key)
      rows.append(unwrapped[cache_key])
    yield pd.DataFrame(rows, columns=['dek', 'iv', 'aad'])
$$;

In [0]:
%sql
-- Smoke test: call unwrap_dek for one secret name and project only the `dek`
-- field of the returned struct.
SELECT * FROM (SELECT aweaver_catalog_1323553108280374.crypto.unwrap_dek(secret_name).dek AS dek
FROM (
  SELECT 'prod/aweaver_catalog_1323553108280374/2026' AS secret_name
) t)

In [0]:
%sql
-- Smoke test: same call as the previous query, but returning the whole
-- STRUCT<dek, iv, aad> instead of a single field.
SELECT * FROM (SELECT aweaver_catalog_1323553108280374.crypto.unwrap_dek(secret_name)
FROM (
  SELECT 'prod/aweaver_catalog_1323553108280374/2026' AS secret_name
) t)

In [0]:
%sql
-- Build the encrypted copy of raw.titanic: Name, Sex and Age are replaced by
-- base64-encoded AES-GCM ciphertext; every other column is carried over as is.
-- The DEK lookup is named once in a CTE and referenced by each column.
CREATE OR REPLACE TABLE aweaver_catalog_1323553108280374.encrypted.titanic AS
WITH unwrapped AS (
  SELECT aweaver_catalog_1323553108280374.crypto.unwrap_dek(secret_name).dek AS dek
  FROM (SELECT 'prod/aweaver_catalog_1323553108280374/2026' AS secret_name) t
)
SELECT
  base64(aes_encrypt(Name, (SELECT dek FROM unwrapped), 'GCM', 'DEFAULT')) AS encrypted_name,
  base64(aes_encrypt(Sex, (SELECT dek FROM unwrapped), 'GCM', 'DEFAULT')) AS encrypted_sex,
  base64(aes_encrypt(CAST(Age AS STRING), (SELECT dek FROM unwrapped), 'GCM', 'DEFAULT')) AS encrypted_age,
  * EXCEPT (Name, Sex, Age)
FROM aweaver_catalog_1323553108280374.raw.titanic;

In [0]:
%sql
-- Decrypt encrypted_name back to Name and return it with the remaining columns.
-- try_aes_decrypt yields NULL instead of failing when a value cannot be decrypted.
-- Reads the table created by the CTAS in this notebook; the unqualified
-- `titanic_encrypted` used before is not created anywhere in this notebook.
SELECT
  cast(try_aes_decrypt(
    unbase64(encrypted_name),
    (SELECT * FROM (SELECT aweaver_catalog_1323553108280374.crypto.unwrap_dek(secret_name).dek AS dek FROM (SELECT 'prod/aweaver_catalog_1323553108280374/2026' AS secret_name) t)),
    'GCM',
    'DEFAULT') AS STRING) AS Name,
  * EXCEPT (encrypted_name)
FROM aweaver_catalog_1323553108280374.encrypted.titanic

In [0]:
%sql
-- Show the decrypted name next to every stored column (ciphertext included)
-- for a side-by-side check.
-- Reads the table created by the CTAS in this notebook; the unqualified
-- `titanic_encrypted` used before is not created anywhere in this notebook.
SELECT
  cast(try_aes_decrypt(
    unbase64(encrypted_name),
    (SELECT * FROM (SELECT aweaver_catalog_1323553108280374.crypto.unwrap_dek(secret_name).dek AS dek FROM (SELECT 'prod/aweaver_catalog_1323553108280374/2026' AS secret_name) t)),
    'GCM',
    'DEFAULT') AS STRING) AS decrypted_name,
  *
FROM aweaver_catalog_1323553108280374.encrypted.titanic

In [0]:
%sql
-- List the tables in the catalog's `raw` schema.
SHOW TABLES IN aweaver_catalog_1323553108280374.raw

In [0]:
# Both names are interpolated into the DDL below and were undefined when this
# cell ran on a fresh kernel. They are read from notebook widgets, which is how
# `kek_name` is obtained elsewhere in this notebook.
secret_scope = dbutils.widgets.get("secret_scope")
kek_name = dbutils.widgets.get("kek_name")

# Defines sys.crypto.encrypt(col): AES-GCM encrypts `col` and returns base64
# text. The DEK, IV and AAD are read from the secret scope and unwrapped with
# the named KEK.
# NOTE(review): the same stored IV is passed for every value. Reusing a GCM IV
# under one key makes the encryption deterministic and weakens GCM's
# guarantees; confirm this is an accepted trade-off.
sql(f"""CREATE OR REPLACE FUNCTION sys.crypto.encrypt(col STRING) 
RETURNS STRING
RETURN 
    base64(aes_encrypt(col, 
    sys.crypto.unwrap_key(secret('{secret_scope}', 'dek'), '{kek_name}'),
    'GCM',  
    'DEFAULT',
    sys.crypto.unwrap_key(secret('{secret_scope}', 'iv'), '{kek_name}'),
    sys.crypto.unwrap_key(secret('{secret_scope}', 'aad'), '{kek_name}')
    ))""")

In [0]:
%sql
-- Reads the `kek` field of the struct returned by crypto.get_keys for one secret name.
-- NOTE(review): get_keys is not defined in this notebook; confirm it exists in the catalog.
SELECT aweaver_catalog_1323553108280374.crypto.get_keys(secret_name).kek AS kek
FROM (
  SELECT 'aweaver_catalog_1323553108280374' AS secret_name
) t

In [0]:
%sql
-- Reads the `secret_value` field returned by crypto.get_secret for one secret name.
-- NOTE(review): get_secret is not defined in this notebook; confirm it exists in the catalog.
SELECT aweaver_catalog_1323553108280374.crypto.get_secret(secret_name).secret_value FROM VALUES
('aweaver/envelope_encryption') AS t(secret_name)

In [0]:
%sql
-- Decrypts a literal base64 ciphertext with AES-GCM, using the value returned
-- by crypto.get_secret as the key, and casts the result to STRING.
-- NOTE(review): get_secret is not defined in this notebook; confirm it exists in the catalog.
SELECT CAST(aes_decrypt(unbase64('4N0pCAsQ5F+4dG1L6apcnjyGr1dyGaPZkAeWl00THCndcP5uTE07NqK6071FaenIiaBq6gQqlZYz\r\nGEB/'), (SELECT aweaver_catalog_1323553108280374.crypto.get_secret(secret_name).secret_value FROM VALUES
('aweaver/envelope_encryption') AS t(secret_name)), 'GCM', 'DEFAULT') AS STRING) AS decrypted_dek

In [0]:
# NOTE(review): `kek` is a key literal hardcoded in the notebook. If it is (or
# ever was) a real key, it should be rotated and loaded from a secret store.
kek = 'y3v66E8GsiDOvnpbaL4TbsBW4bkaRT3d'
# Base64 ciphertext of the wrapped DEK; the literal contains a \r\n escape.
encrypted_dek = '4N0pCAsQ5F+4dG1L6apcnjyGr1dyGaPZkAeWl00THCndcP5uTE07NqK6071FaenIiaBq6gQqlZYz\r\nGEB/'
# Decrypt with SQL aes_decrypt (GCM). Both values are interpolated into the SQL
# text, so the key also appears in the submitted query string.
decrypted_dek = sql(f"SELECT CAST(aes_decrypt(unbase64('{encrypted_dek}'), '{kek}', 'GCM', 'DEFAULT') AS STRING)").first()[0]
decrypted_dek

In [0]:
import boto3
from botocore.exceptions import ClientError
import json
import pandas as pd

# boto3 session backed by the Unity Catalog service credential.
boto3_session = boto3.Session(
    botocore_session=dbutils.credentials.getServiceCredentialsProvider('aweaver-secrets-manager'),
    region_name='eu-west-1',
)

sm = boto3_session.client('secretsmanager')
secret_name = "aweaver_catalog_1323553108280374"

# A failed lookup raises botocore's ClientError, which is left to propagate.
# For the list of error codes, see
# https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
get_secret_value_response = sm.get_secret_value(SecretId=secret_name)

secret = json.loads(get_secret_value_response['SecretString'])
pdf = pd.DataFrame([secret])

# Display only the field names. Rendering `pdf` would write the key material
# into the notebook output, where it is saved and shared.
display(pd.DataFrame({"field": sorted(secret)}))

In [0]:
# Build a one-row frame holding the `kek` and `dek` fields of the fetched secret.
# NOTE: the display below renders both values in the notebook output.
secret_fields = json.loads(get_secret_value_response['SecretString'])
kek = secret_fields['kek']
dek = secret_fields['dek']
results = [(kek, dek)]
pdf2 = pd.DataFrame(results, columns=['kek', 'dek'])
display(pdf2)

In [0]:
import string
import random
import secrets
from base64 import b64encode

# Alphabet for the printable IV / AAD values.
_KEY_ALPHABET = string.ascii_uppercase + string.digits

# `urandom` was never imported in this notebook, so the old line raised
# NameError. secrets.token_bytes draws from the same OS CSPRNG.
dek = b64encode(secrets.token_bytes(24)).decode('utf-8')

# IV and AAD are key material, so they use `secrets.choice` rather than the
# predictable Mersenne Twister behind `random.choices`.
iv = ''.join(secrets.choice(_KEY_ALPHABET) for _ in range(12))
aad = ''.join(secrets.choice(_KEY_ALPHABET) for _ in range(8))

# Wrap each value with the KEK using SQL aes_encrypt (GCM) and keep the base64 text.
# NOTE(review): the KEK and cleartext values are interpolated into the SQL
# text, so they appear in the submitted query string.
encrypted_dek = sql(f"SELECT base64(aes_encrypt('{dek}', '{kek}', 'GCM', 'DEFAULT'))").first()[0]
encrypted_iv = sql(f"SELECT base64(aes_encrypt('{iv}', '{kek}', 'GCM', 'DEFAULT'))").first()[0]
encrypted_aad = sql(f"SELECT base64(aes_encrypt('{aad}', '{kek}', 'GCM', 'DEFAULT'))").first()[0]

In [0]:
# Prints the cleartext dek, iv and aad; clear this cell's output before sharing.
print(f"dek: {dek}\n iv: {iv}\n aad: {aad}")

In [0]:
# Scratch note (not executable): a bare SQL expression in a Python cell is a
# SyntaxError. This is the body of the crypto.unwrap_key SQL function that is
# created in a %sql cell of this notebook.
#
#   aes_decrypt(unbase64(key_to_unwrap), (SELECT FIRST(key) FROM sys.crypto.key_vault WHERE key_enabled AND key_name = key_to_use), 'GCM', 'DEFAULT')

In [0]:
# Unwrap `encrypted_dek` with the KEK via SQL aes_decrypt (GCM) and show the result.
unwrap_query = f"SELECT CAST(aes_decrypt(unbase64('{encrypted_dek}'), '{kek}', 'GCM', 'DEFAULT') AS STRING)"
first_row = sql(unwrap_query).first()
decrypted_dek = first_row[0]
decrypted_dek

In [0]:
%sql
-- Decrypts a base64-encoded, AES-GCM-encrypted key (`key_to_unwrap`) using the
-- first enabled key named `key_to_use` in sys.crypto.key_vault.
-- NOTE(review): this creates the function in aweaver_catalog_1323553108280374.crypto,
-- while the encrypt/decrypt definitions in this notebook call sys.crypto.unwrap_key.
-- Confirm which catalog is intended.
CREATE OR REPLACE FUNCTION aweaver_catalog_1323553108280374.crypto.unwrap_key(key_to_unwrap STRING, key_to_use STRING) 
RETURNS STRING
RETURN aes_decrypt(unbase64(key_to_unwrap), (SELECT FIRST(key) FROM sys.crypto.key_vault WHERE key_enabled AND key_name = key_to_use), 'GCM', 'DEFAULT')

In [0]:
# Name of the KEK in the key vault, supplied through a notebook widget.
kek_name = dbutils.widgets.get("kek_name")
# `secret_scope` is interpolated into the DDL below but was never defined, so
# this cell raised NameError. It is read from a widget in the same way as
# `kek_name`.
secret_scope = dbutils.widgets.get("secret_scope")

# Defines sys.crypto.encrypt(col): AES-GCM encrypts `col` and returns base64
# text. The DEK, IV and AAD are read from the secret scope and unwrapped with
# the named KEK.
# NOTE(review): the same stored IV is passed for every value. Reusing a GCM IV
# under one key makes the encryption deterministic and weakens GCM's
# guarantees; confirm this is an accepted trade-off.
sql(f"""CREATE OR REPLACE FUNCTION sys.crypto.encrypt(col STRING) 
RETURNS STRING
RETURN 
    base64(aes_encrypt(col, 
    sys.crypto.unwrap_key(secret('{secret_scope}', 'dek'), '{kek_name}'),
    'GCM',  
    'DEFAULT',
    sys.crypto.unwrap_key(secret('{secret_scope}', 'iv'), '{kek_name}'),
    sys.crypto.unwrap_key(secret('{secret_scope}', 'aad'), '{kek_name}')
    ))""")

In [0]:
# Defines sys.crypto.decrypt(col).
# Members of the 'keyvault_user' account group get `col` decrypted with AES-GCM,
# using the DEK and AAD unwrapped from the secret scope. If try_aes_decrypt
# returns NULL, nvl falls back to the original `col`.
# Everyone else gets `col` back unchanged.
decrypt_function_ddl = f"""CREATE OR REPLACE FUNCTION sys.crypto.decrypt(col STRING) 
RETURNS STRING
RETURN 
    CASE WHEN is_account_group_member('keyvault_user') THEN 
    nvl(CAST(try_aes_decrypt(unbase64(col), 
    sys.crypto.unwrap_key(secret('{secret_scope}', 'dek'), '{kek_name}'),
    'GCM',  
    'DEFAULT',
    sys.crypto.unwrap_key(secret('{secret_scope}', 'aad'), '{kek_name}')) AS STRING), 
    col)
    ELSE col END;"""
sql(decrypt_function_ddl)