In [0]:
%pip install -r requirements.txt

In [0]:
# Notebook parameters. The dropdown decides whether the generated keys are written to AWS;
# the text widgets supply the DEK name prefix, AWS region, secret name and UC service credential.
dbutils.widgets.dropdown("create_aws_secret", defaultValue="True", choices=["True", "False"])

for _widget_name, _widget_default in (
    ("dek_name", ""),
    ("region", "eu-west-1"),
    ("secret_name", ""),
    ("uc_service_credential", ""),
):
    dbutils.widgets.text(_widget_name, defaultValue=_widget_default)

In [0]:
# Shared helpers for key generation and envelope encryption.
crypto_functions = dbutils.import_notebook("notebooks.envelope_encryption_v2.common.crypto_functions")

# Generate a fresh key-encryption key (KEK) and data-encryption key (DEK).
kek = crypto_functions.generate_kek()
dek = crypto_functions.generate_dek()

# Wrap the DEK's private key with the KEK.
encrypted_dek = crypto_functions.encrypt_with_kek(
    kek_password=kek.get("kek_password"),
    kek_salt=kek.get("kek_salt"),
    to_encrypt=dek.get("private_key"),
)

# Store the wrapped key under "dek" and drop both the helper's original field name
# and the plaintext private key, so only the encrypted form ends up in the secret.
encrypted_dek["dek"] = encrypted_dek.get("encrypted_string")
del encrypted_dek["encrypted_string"]
del dek["private_key"]
secret = {**encrypted_dek, **dek}

# Optionally prefix every field with the DEK name (e.g. "<dek_name>_dek").
dek_name = dbutils.widgets.get("dek_name")
if dek_name != "":
    secret = {f"{dek_name}_{field}": value for field, value in secret.items()}

In [0]:
import json 
import boto3

# The dropdown only offers the strings "True" / "False", so compare the text directly
# instead of passing widget input (which a job parameter can override) to eval().
if dbutils.widgets.get("create_aws_secret").strip().lower() == "true":

    # One JSON document holding the KEK fields plus the wrapped DEK fields.
    secret_string = json.dumps(kek | secret)

    # boto3 session backed by the Unity Catalog service credential named in the widget.
    credential_provider = dbutils.credentials.getServiceCredentialsProvider(
        dbutils.widgets.get("uc_service_credential"))
    boto3_session = boto3.Session(
        botocore_session=credential_provider,
        region_name=dbutils.widgets.get("region"))

    # Keep the AWS response in its own variable: rebinding `secret` here would make a
    # second run of this cell serialise the previous API response instead of the key fields.
    created_secret = crypto_functions.create_aws_secret(
        session=boto3_session,
        secret_name=dbutils.widgets.get("secret_name"),
        secret_description="KEK and encrypted DEKs",
        secret_string=secret_string,
        tags=[],
        kms_key="alias/aws/secretsmanager")
    print(f"Successfully created secret in AWS!\nName: {created_secret.get('Name')}\nARN: {created_secret.get('ARN')}\nVersion: {created_secret.get('VersionId')}")

In [0]:
%sql
-- Show which catalog this session is using; the schema below is created inside it.
SELECT current_catalog();
-- Create the crypto schema if needed and make it current, so unqualified names
-- in later cells resolve to it.
CREATE SCHEMA IF NOT EXISTS crypto;
USE SCHEMA crypto;

In [0]:
%sql
-- IMPORTANT: before running this cell, edit the CREDENTIALS (...) clause below so that it
-- references your own UC service credential (your uc_service_credential).

-- unwrap_key(secret_name, key_name): Python UDF using the pandas (batch) parameter style.
-- The Python handler `batchhandler` defined in the body below produces the STRING result.
CREATE OR REPLACE FUNCTION unwrap_key(secret_name STRING, key_name STRING) 
RETURNS STRING
LANGUAGE PYTHON
PARAMETER STYLE PANDAS
HANDLER 'batchhandler'
CREDENTIALS (
  `aweaver-secrets-manager` DEFAULT -- IMPORTANT: replace this with your own UC service credential.
)
-- pycryptodome is pinned because the body imports Crypto.Cipher.AES and Crypto.Protocol.KDF.scrypt.
ENVIRONMENT (
  dependencies = '["pycryptodome==3.22.0"]',
  environment_version = 'None'
)
AS $$
import boto3
from pyspark.taskcontext import TaskContext
from botocore.exceptions import ClientError
from Crypto.Cipher import AES
from Crypto.Protocol.KDF import scrypt
import base64
from typing import Iterator, Tuple
import json
import pandas as pd

def setup_session():
  """Return a boto3 Secrets Manager client for the region reported by the task context.

  The region is read from the Spark local property
  "spark.databricks.clusterUsageTags.region"; the client is created from a
  default boto3 session.
  """
  aws_session = boto3.Session()
  task_context = TaskContext.get()
  cluster_region = task_context.getLocalProperty("spark.databricks.clusterUsageTags.region")
  return aws_session.client("secretsmanager", region_name=cluster_region)

def decrypt_with_kek(kek_password, kek_salt, dek, nonce, tag):
  """Unwrap an AES-GCM encrypted DEK using a key derived from the KEK password.

  Args:
    kek_password: Password the key-encryption key is derived from (via scrypt).
    kek_salt: Base64-encoded scrypt salt.
    dek: Base64-encoded ciphertext of the wrapped DEK.
    nonce: Base64-encoded AES-GCM nonce used when the DEK was wrapped.
    tag: Base64-encoded AES-GCM authentication tag.

  Returns:
    The decrypted DEK, decoded as a UTF-8 string.

  Raises:
    ValueError: If the authentication tag does not match (wrong KEK or
      tampered ciphertext).
  """
  salt = base64.b64decode(kek_salt)
  # Derive the 32-byte KEK; these scrypt parameters must match the ones used when wrapping.
  kek = scrypt(kek_password, salt, key_len=32, N=2**17, r=8, p=1)
  cipher = AES.new(kek, AES.MODE_GCM, nonce=base64.b64decode(nonce))
  # Decrypt and authenticate in one call so unauthenticated plaintext is never handled
  # separately; a tag mismatch raises ValueError.
  plaintext = cipher.decrypt_and_verify(base64.b64decode(dek), base64.b64decode(tag))
  return plaintext.decode('utf-8')

# Module-level Secrets Manager client: created once when this UDF body is executed
# and reused by batchhandler for every lookup.
client = setup_session()

def batchhandler(batch_iter: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]:
  """Unwrap the requested DEK for every row of every input batch.

  Args:
    batch_iter: Iterator of (secret_name, key_name) Series pairs, one pair per batch.

  Yields:
    For each batch, a Series with one plaintext DEK per input row, in row order.

  Raises:
    botocore.exceptions.ClientError: If the secret cannot be fetched.
    KeyError: If the secret lacks a field needed to unwrap the requested key.
    ValueError: If the wrapped DEK fails authentication.
  """
  # (secret_name, key_name) -> plaintext DEK. scrypt with N=2**17 is deliberately slow and
  # each lookup is a network call, so every distinct pair is unwrapped only once.
  unwrapped = {}

  for secret_names, key_names in batch_iter:
    deks = []
    for secret_name, key_name in zip(secret_names, key_names):
      lookup = (secret_name, key_name)
      if lookup not in unwrapped:
        response = client.get_secret_value(SecretId=secret_name)
        secret = json.loads(response.get("SecretString"))

        fields = {
          "kek_password": secret.get("kek_password"),
          "kek_salt": secret.get("kek_salt"),
          "dek": secret.get(f"{key_name}_dek"),
          "nonce": secret.get(f"{key_name}_nonce"),
          "tag": secret.get(f"{key_name}_tag"),
        }
        # Fail with a readable message (field names only, never values) instead of an
        # obscure TypeError from base64-decoding None.
        missing = [field for field, value in fields.items() if value is None]
        if missing:
          raise KeyError(f"Secret '{secret_name}' is missing fields for key '{key_name}': {missing}")

        unwrapped[lookup] = decrypt_with_kek(**fields)
      deks.append(unwrapped[lookup])

    # The output must have exactly one value per input row, aligned with the input batch.
    yield pd.Series(deks, index=secret_names.index, dtype=object)
$$;

In [0]:
%sql
-- Smoke test: call unwrap_key once with literal arguments.
-- NOTE(review): the result is the plaintext DEK; clear this cell's output before sharing the notebook.
SELECT * FROM (SELECT unwrap_key('production/aweaver_catalog_1323553108280374/2026', 'titanic') AS dek)

In [0]:
# Sanity check: unwrap the DEK locally with the notebook helper.
# The KEK material and wrapped DEK are read from AWS Secrets Manager at run time. Never
# paste them into the notebook: literal values end up in version control and exports.
# NOTE(review): any KEK/DEK values that were previously committed in this cell should be
# treated as compromised and rotated.
_check_session = boto3.Session(
    botocore_session=dbutils.credentials.getServiceCredentialsProvider(
        dbutils.widgets.get("uc_service_credential")),
    region_name=dbutils.widgets.get("region"))
_stored_secret = json.loads(
    _check_session.client("secretsmanager")
    .get_secret_value(SecretId=dbutils.widgets.get("secret_name"))
    .get("SecretString"))

# DEK fields are stored as "<dek_name>_<field>" when a DEK name was supplied, else unprefixed.
_dek_name = dbutils.widgets.get("dek_name")
_prefix = f"{_dek_name}_" if _dek_name != "" else ""

unwrapped_dek = crypto_functions.decrypt_with_kek(
    kek_password=_stored_secret["kek_password"],
    kek_salt=_stored_secret["kek_salt"],
    to_decrypt=_stored_secret[f"{_prefix}dek"],
    nonce=_stored_secret[f"{_prefix}nonce"],
    tag=_stored_secret[f"{_prefix}tag"],
    )
# Confirm success without echoing the plaintext key into the notebook output.
print("DEK unwrapped successfully")

In [0]:
%sql
-- Exercise unwrap_key with column arguments supplied by an inline VALUES row.
-- NOTE(review): the result is the plaintext DEK; clear this cell's output before sharing the notebook.
SELECT unwrap_key(secret_name, key_name) AS deks
FROM VALUES
('production/aweaver_catalog_1323553108280374/2026', 'titanic')
AS t(secret_name, key_name)

In [0]:
%sql
-- encrypt(col): AES-GCM encrypts col with the key returned by unwrap_key for the
-- 'titanic' key of the hardcoded secret, and returns the ciphertext base64-encoded.
CREATE OR REPLACE FUNCTION encrypt(col STRING) 
RETURNS STRING
RETURN 
    base64(aes_encrypt(
        col, 
        (SELECT * FROM (SELECT unwrap_key('production/aweaver_catalog_1323553108280374/2026', 'titanic'))),
        'GCM',  
        'DEFAULT'
    ))

In [0]:
%sql
-- Build the encrypted copy of the raw table: PassengerId stays in clear, Name/Sex/Age are
-- replaced by their encrypt() output, and every remaining column is copied unchanged.
CREATE OR REPLACE TABLE aweaver_catalog_1323553108280374.encrypted.titanic AS 
SELECT 
PassengerId,
encrypt(Name) AS encrypted_name,
encrypt(Sex) AS encrypted_sex,
encrypt(Age) AS encrypted_age,
* EXCEPT (PassengerId, Name, Sex, Age)
FROM aweaver_catalog_1323553108280374.raw.titanic;

In [0]:
%sql
-- Inspect the encrypted table as stored (encrypted_* columns hold base64 ciphertext).
SELECT * FROM aweaver_catalog_1323553108280374.encrypted.titanic

In [0]:
%sql
-- decrypt(col): for members of the 'keyvault_user' account group, base64-decodes col and
-- AES-GCM decrypts it with the key returned by unwrap_key; everyone else gets col back unchanged.
-- The nvl(..., col) fallback returns the original value when try_aes_decrypt yields NULL
-- (presumably when col is not valid ciphertext for this key).
CREATE OR REPLACE FUNCTION decrypt(col STRING) 
RETURNS STRING
RETURN 
    CASE WHEN is_account_group_member('keyvault_user') THEN 
    nvl(CAST(try_aes_decrypt(unbase64(col), 
    (SELECT * FROM (SELECT unwrap_key('production/aweaver_catalog_1323553108280374/2026', 'titanic'))),
    'GCM',  
    'DEFAULT') AS STRING), 
    col)
    ELSE col END

In [0]:
%sql
-- Check whether the current user belongs to the 'keyvault_user' account group,
-- the group decrypt() tests before returning plaintext.
SELECT is_account_group_member('keyvault_user') AS keyvault_user

In [0]:
%sql
-- Read the encrypted table through decrypt(): the decrypted_* columns show plaintext only
-- for 'keyvault_user' members; the original encrypted_* columns are also returned via *.
SELECT 
decrypt(encrypted_name) AS decrypted_name,
decrypt(encrypted_sex) AS decrypted_sex,
decrypt(encrypted_age) AS decrypted_age,
* 
FROM aweaver_catalog_1323553108280374.encrypted.titanic