In [0]:
%pip install -r requirements.txt

In [0]:
# Notebook parameters.
# - create_aws_secret: whether the generated keys are written to AWS Secrets Manager.
# - dek_name: optional prefix applied to every DEK-related key of the secret.
# - region / uc_service_credential: used to build the boto3 session.
# - secret_name: name of the AWS secret to create.
dbutils.widgets.dropdown("create_aws_secret", defaultValue="True", choices=["True", "False"])

text_widget_defaults = {
    "dek_name": "",
    "region": "eu-west-1",
    "secret_name": "",
    "uc_service_credential": "",
}
for widget_name, widget_default in text_widget_defaults.items():
    dbutils.widgets.text(widget_name, defaultValue=widget_default)

In [0]:
# Shared crypto helpers live in a sibling notebook.
crypto_functions = dbutils.import_notebook("notebooks.envelope_encryption_v2.common.crypto_functions")

# Generate a key-encryption key (KEK) and a data-encryption key (DEK), then
# encrypt the DEK's "private_key" value with the KEK.
kek = crypto_functions.generate_kek()
dek = crypto_functions.generate_dek()

encrypted_dek = crypto_functions.encrypt_with_kek(
    kek_password=kek.get("kek_password"),
    kek_salt=kek.get("kek_salt"),
    to_encrypt=dek.get("private_key"),
)

# Keep the encrypted value under the "dek" key and drop the plaintext key so
# that only the encrypted form ends up in the secret payload.
encrypted_dek["dek"] = encrypted_dek.pop("encrypted_string")
dek.pop("private_key")
secret = {**encrypted_dek, **dek}

# An optional DEK name namespaces every key of the payload as "<dek_name>_<key>".
dek_name = dbutils.widgets.get("dek_name")
if dek_name:
    secret = {f"{dek_name}_{field}": value for field, value in secret.items()}

In [0]:
import json 
import boto3

# Create the secret in AWS Secrets Manager when the "create_aws_secret" widget
# is set to "True".
#
# The widget value is compared as a string instead of being passed to eval():
# widget values can be supplied by whoever runs or schedules the notebook, and
# eval() would execute any Python expression placed there.
if dbutils.widgets.get("create_aws_secret") == "True":

    # The stored payload is the KEK material merged with the encrypted DEK fields.
    secret_string = json.dumps(kek | secret)

    # Authenticate to AWS through the Unity Catalog service credential named in
    # the "uc_service_credential" widget.
    boto3_session = boto3.Session(
        botocore_session=dbutils.credentials.getServiceCredentialsProvider(
            dbutils.widgets.get("uc_service_credential")
        ),
        region_name=dbutils.widgets.get("region"),
    )

    # NOTE: from here on `secret` holds the value returned by create_aws_secret
    # (read below via .get('Name') / .get('ARN') / .get('VersionId')), no longer
    # the key material assembled earlier.
    secret = crypto_functions.create_aws_secret(
        session=boto3_session,
        secret_name=dbutils.widgets.get("secret_name"),
        secret_description="KEK and encrypted DEKs",
        secret_string=secret_string,
        tags=[],
        kms_key="alias/aws/secretsmanager")
    print(f"Successfully created secret in AWS!\nName: {secret.get('Name')}\nARN: {secret.get('ARN')}\nVersion: {secret.get('VersionId')}")

In [0]:
%sql
-- Show the catalog that is currently active for this session.
SELECT current_catalog()

In [0]:
%sql
-- Create the `crypto` schema if it does not exist yet. The name is not
-- catalog-qualified, so it resolves against the current catalog.
CREATE SCHEMA IF NOT EXISTS crypto

In [0]:
%sql
-- Make `crypto` the current schema for the statements that follow.
USE SCHEMA crypto

In [0]:
import boto3
from botocore.exceptions import ClientError
import json
import pandas as pd

# Read the secret back from AWS Secrets Manager to check what was stored.
boto3_session = boto3.Session(
    botocore_session=dbutils.credentials.getServiceCredentialsProvider('aweaver-secrets-manager'),
    region_name='eu-west-1'
)

sm = boto3_session.client('secretsmanager')
secret_name = "prod/aweaver_catalog_1323553108280374/2026"

# get_secret_value raises botocore's ClientError on failure; it is left to
# propagate unchanged. For a list of exceptions thrown, see
# https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
get_secret_value_response = sm.get_secret_value(SecretId=secret_name)

secret = json.loads(get_secret_value_response['SecretString'])

# Display the key names only. Echoing the dict itself would write the KEK
# password and salt into the saved notebook output.
sorted(secret)

In [0]:
# Lay the secret out with the columns the UDF is expected to read; keys missing
# from the secret show up as NaN.
pdf = pd.DataFrame([secret], columns = ["kek_password", "kek_salt", "dek", "dek_nonce", "dek_tag", "iv", "aad"])

# Display which fields are present rather than their values, so the key
# material is not rendered into the notebook output.
pdf.notna()

In [0]:
%sql
-- unwrap_dek(secret_name, dek_name)
--
-- Batch Python UDF that reads a secret from AWS Secrets Manager, derives the
-- KEK from the stored password and salt, decrypts the wrapped DEK and returns
-- it together with the stored iv and aad values.
--
-- secret_name: id of the AWS secret to read.
-- dek_name:    optional key prefix. When the secret contains "<dek_name>_dek",
--              "<dek_name>_nonce", "<dek_name>_tag", "<dek_name>_iv" and
--              "<dek_name>_aad", those are used; otherwise the unprefixed
--              "dek", "dek_nonce", "dek_tag", "iv" and "aad" keys are used.
--
-- A NULL secret_name yields a struct of NULLs.
CREATE OR REPLACE FUNCTION aweaver_catalog_1323553108280374.crypto.unwrap_dek(secret_name STRING, dek_name STRING)
RETURNS STRUCT<dek: STRING, iv: STRING, aad: STRING>
LANGUAGE PYTHON
PARAMETER STYLE PANDAS
HANDLER 'batchhandler'
CREDENTIALS (
  `aweaver-secrets-manager` DEFAULT
)
ENVIRONMENT (
  dependencies = '["pycryptodome==3.22.0"]',
  environment_version = 'None'
)
AS $$
import base64
import json
from typing import Iterator, Tuple

import boto3
import pandas as pd
from Crypto.Cipher import AES
from Crypto.Protocol.KDF import scrypt
from pyspark.taskcontext import TaskContext

OUTPUT_COLUMNS = ['dek', 'iv', 'aad']


def setup_session():
  """Build a Secrets Manager client in the region reported by the task context."""
  session = boto3.Session()
  region = TaskContext.get().getLocalProperty("spark.databricks.clusterUsageTags.region")
  client = session.client("secretsmanager", region_name=region)
  return client


def decrypt_with_kek(kek_password, kek_salt, dek, nonce, tag):
  """Decrypt a wrapped DEK with a KEK derived from a password and salt.

  kek_salt, dek, nonce and tag are base64-encoded strings. The KEK is derived
  with scrypt and the DEK is decrypted with AES-GCM.

  Returns the decrypted DEK decoded as UTF-8.
  Raises ValueError when the authentication tag does not match.
  """
  salt = base64.b64decode(kek_salt)
  kek = scrypt(kek_password, salt, key_len=32, N=2**17, r=8, p=1)
  cipher = AES.new(kek, AES.MODE_GCM, nonce=base64.b64decode(nonce))
  # Decrypt and authenticate in one step so no plaintext is produced from a
  # ciphertext whose tag does not verify.
  decrypted = cipher.decrypt_and_verify(base64.b64decode(dek), base64.b64decode(tag))
  return decrypted.decode('utf-8')


def _dek_field(secret, dek_name, suffix, legacy_key, required=True):
  """Look up one DEK field, preferring "<dek_name>_<suffix>" over legacy_key."""
  candidates = ([f"{dek_name}_{suffix}"] if dek_name else []) + [legacy_key]
  for key in candidates:
    if key in secret:
      return secret[key]
  if required:
    # Only key names are reported, never secret values.
    raise KeyError(f"secret contains none of the keys {candidates}")
  return None


def _unwrap(secret_name, dek_name):
  """Fetch one secret and return its {'dek', 'iv', 'aad'} row with the DEK decrypted."""
  # Failures surface as botocore ClientError and are left to propagate.
  response = client.get_secret_value(SecretId=secret_name)
  secret = json.loads(response['SecretString'])
  return {
    'dek': decrypt_with_kek(
      kek_password=secret["kek_password"],
      kek_salt=secret["kek_salt"],
      dek=_dek_field(secret, dek_name, "dek", "dek"),
      nonce=_dek_field(secret, dek_name, "nonce", "dek_nonce"),
      tag=_dek_field(secret, dek_name, "tag", "dek_tag")),
    'iv': _dek_field(secret, dek_name, "iv", "iv", required=False),
    'aad': _dek_field(secret, dek_name, "aad", "aad", required=False),
  }


client = setup_session()


def batchhandler(batch_iter: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.DataFrame]:
  """Unwrap the DEK for every input row.

  The function takes two arguments, so each batch arrives as a
  (secret_name, dek_name) pair of Series. One DataFrame is yielded per batch
  with exactly one output row per input row.
  """
  # Cache per (secret_name, dek_name) for the lifetime of this invocation: the
  # scrypt derivation and the Secrets Manager call are expensive and the same
  # pair typically repeats on every row.
  unwrapped = {}
  for secret_names, dek_names in batch_iter:
    rows = []
    for secret_name, dek_name in zip(secret_names, dek_names):
      if pd.isna(secret_name):
        rows.append(dict.fromkeys(OUTPUT_COLUMNS))
        continue
      if pd.isna(dek_name):
        dek_name = None
      cache_key = (secret_name, dek_name)
      if cache_key not in unwrapped:
        unwrapped[cache_key] = _unwrap(secret_name, dek_name)
      rows.append(unwrapped[cache_key])
    yield pd.DataFrame(rows, columns=OUTPUT_COLUMNS)
$$;

In [0]:
%sql
-- Smoke test: call unwrap_dek on two identical (secret_name, dek_name) rows.
SELECT aweaver_catalog_1323553108280374.crypto.unwrap_dek(secret_name, dek_name)
FROM VALUES
('prod/aweaver_catalog_1323553108280374/2026', 'titanic'),
('prod/aweaver_catalog_1323553108280374/2026', 'titanic')
AS t(secret_name, dek_name)

In [0]:
%sql
-- Smoke test against the prod secret. unwrap_dek is declared with two
-- parameters (secret_name, dek_name) and no defaults, so both are passed.
SELECT * FROM (SELECT aweaver_catalog_1323553108280374.crypto.unwrap_dek(secret_name, dek_name)
FROM (
  SELECT 'prod/aweaver_catalog_1323553108280374/2026' AS secret_name, 'titanic' AS dek_name
) t)

In [0]:
#pd.Series(response['SecretString'])

In [0]:
%sql
-- Smoke test: call unwrap_dek for the 'testing/...' secret with dek_name 'titanic'.
SELECT * FROM (SELECT aweaver_catalog_1323553108280374.crypto.unwrap_dek(secret_name, dek_name)
FROM (
  SELECT 'testing/aweaver_catalog_1323553108280374/2026' AS secret_name, 'titanic' AS dek_name
) t)

In [0]:
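# Keys presumably present in a secret generated with dek_name = "titanic"
# (TODO: confirm against the stored secret):
# kek_password
# kek_salt
# titanic_nonce
# titanic_tag
# titanic_dek
# titanic_iv
# titanic_aad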