In [0]:
%pip install -r ../../requirements.txt -q

In [0]:
# Restart the Python process. This cell runs right after the %pip install cell,
# presumably so the freshly installed packages are importable -- TODO confirm.
dbutils.library.restartPython()

In [0]:
# Default for the metastore_id widget: the last ':'-separated segment of current_metastore().
default_metastore_id = (
    sql("SELECT element_at(split(current_metastore(), ':'), -1) AS metastore")
    .first()
    .metastore
)

# Notebook parameters. The dropdown is created first so the widget order is unchanged.
dbutils.widgets.dropdown("azure_key_creation", defaultValue="True", choices=["True", "False"])
for widget_name, widget_default in (
    ("uc_service_credential", "production-akv"),
    ("key_vault_url", ""),
    ("metastore_id", default_metastore_id),
    ("catalog", "production"),
    ("schema", "finance"),
):
    dbutils.widgets.text(widget_name, defaultValue=widget_default)

# Current widget values, read once and reused by the cells below.
service_credential = dbutils.widgets.get("uc_service_credential")
metastore_id = dbutils.widgets.get("metastore_id")
key_vault_url = dbutils.widgets.get("key_vault_url")
catalog = dbutils.widgets.get("catalog")
schema = dbutils.widgets.get("schema")

## Step 1
* Download the Titanic dataset and store it in a Unity Catalog (UC) volume for raw files.
* We'll use this to simulate a dataset that contains PII (Name, Age, Sex)

In [0]:
%sql
-- Make the target catalog current, then ensure the schema and its raw_files volume exist.
USE CATALOG IDENTIFIER(:catalog);
CREATE SCHEMA IF NOT EXISTS IDENTIFIER(:catalog || '.' || :schema);
CREATE VOLUME IF NOT EXISTS IDENTIFIER(:catalog || '.' || :schema || '.raw_files');

In [0]:
import subprocess

# Public CSV used as the sample dataset.
file_url = "https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv"

# Target location inside the raw_files volume of the selected catalog/schema.
volume_dir = f"/Volumes/{catalog}/{schema}/raw_files"
volume_path = f"{volume_dir}/titanic.csv"

# Download straight into the volume; check=True raises if wget exits non-zero.
subprocess.run(["wget", file_url, "-O", volume_path], check=True)

# List the volume contents to confirm the file landed.
display(dbutils.fs.ls(f"{volume_dir}/"))

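## Step 2
* Create a key encryption key (KEK) in Azure Key Vault (skipped when `azure_key_creation` is `False`)
* Generate a random 256-bit data encryption key (DEK)
* Wrap (encrypt) the DEK with the KEK, then unwrap it again to check the round trip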

In [0]:
from azure.keyvault.keys import KeyClient, KeyType

# Fail fast with a clear message: the key_vault_url widget defaults to an empty string.
if not key_vault_url:
    raise ValueError(
        "The 'key_vault_url' widget is empty; set it to your Azure Key Vault URL "
        "(e.g. https://<vault-name>.vault.azure.net) and re-run."
    )

# Authenticate to Azure Key Vault through the Unity Catalog service credential.
credential = dbutils.credentials.getServiceCredentialsProvider(service_credential)
client = KeyClient(vault_url=key_vault_url, credential=credential)

# One KEK per metastore + catalog; underscores in the catalog name are replaced with hyphens.
kek_name = f"unity-catalog-{metastore_id}-{catalog.replace('_', '-')}-kek"

# Widget values are plain strings. Compare explicitly instead of eval()-ing them:
# eval() would execute arbitrary text if the notebook is parameterised externally.
if dbutils.widgets.get("azure_key_creation") == "True":
    key = client.create_key(name=kek_name, key_type=KeyType.rsa)
    print(f"Created kek '{key.name}': {key.id}")

In [0]:
import secrets

# 32 random bytes from the OS CSPRNG = a 256-bit data encryption key (DEK).
dek = secrets.token_bytes(32)

# Show only the length: the plaintext DEK is a secret and must not be persisted in cell output.
len(dek)

In [0]:
from azure.keyvault.keys.crypto import EncryptionAlgorithm

# Cryptography client bound to the KEK named above.
crypto_client = client.get_cryptography_client(key_name=kek_name)

# Wrap (encrypt) the DEK with the KEK using RSA-OAEP-256 and keep only the ciphertext.
wrap_result = crypto_client.wrap_key(EncryptionAlgorithm.rsa_oaep_256, dek)
encrypted_dek = wrap_result.encrypted_key
encrypted_dek

In [0]:
# Unwrap the encrypted DEK with the same KEK and algorithm used to wrap it.
unwrapped_dek = crypto_client.unwrap_key(EncryptionAlgorithm.rsa_oaep_256, encrypted_dek).key

# Verify the round trip without echoing the plaintext key into the notebook output.
assert unwrapped_dek == dek, "Unwrapped DEK does not match the original DEK"
print("DEK wrap/unwrap round trip OK")

## Step 3
* Create a `crypto.keyvault` table to store our encrypted DEKs
* Generate an encrypted DEK for our schema
* Store the encrypted DEK in the `crypto.keyvault` table

In [0]:
%sql
-- Schema and table that hold the wrapped (encrypted) data encryption keys.
CREATE SCHEMA IF NOT EXISTS IDENTIFIER(:catalog || '.crypto');

CREATE TABLE IF NOT EXISTS IDENTIFIER(:catalog || '.crypto.keyvault') (
  id                 BIGINT GENERATED BY DEFAULT AS IDENTITY,
  created_date       DATE,
  created_time       TIMESTAMP,
  last_modified_time TIMESTAMP,
  created_by         STRING,
  managed_by         STRING,
  key_name           STRING,    -- the notebook stores '<catalog>.<schema>' here
  key_enabled        BOOLEAN,
  key_version        INT,       -- incremented by the notebook for each new key of the same name
  key_type           STRING,    -- the notebook stores 'ENCRYPTED_DEK' here
  key                BINARY,    -- the wrapped key bytes
  CONSTRAINT pk_key_name_and_version
    PRIMARY KEY (key_name, key_version)
)
TBLPROPERTIES (
  'delta.appendOnly' = 'true', -- we only INSERT new versions
  'delta.autoOptimize.optimizeWrite' = 'true',
  'delta.autoOptimize.autoCompact'  = 'true'
);

In [0]:
from databricks.sdk import WorkspaceClient
from pyspark.sql.types import StructType, StructField, IntegerType, DateType, TimestampType, StringType, BooleanType, BinaryType
from datetime import datetime
from datetime import date

ws = WorkspaceClient()

# Look up the highest existing version for this catalog.schema key.
# catalog/schema are bound as named parameters rather than interpolated into the SQL text,
# so widget input cannot alter the statement.
max_version = spark.sql(
    "SELECT MAX(key_version) AS max_version FROM crypto.keyvault "
    "WHERE key_name = concat(:catalog, '.', :schema)",
    args={"catalog": catalog, "schema": schema},
).first().max_version

# MAX() returns NULL when no key exists yet; start at version 1, otherwise increment.
key_version = 1 if max_version is None else max_version + 1

# Columns match the INSERT column list of crypto.keyvault (the identity column `id` is omitted).
keyvault_schema = StructType([
    StructField("created_date", DateType(), False),
    StructField("created_time", TimestampType(), False),
    StructField("last_modified_time", TimestampType(), False),
    StructField("created_by", StringType(), False),
    StructField("managed_by", StringType(), False),
    StructField("key_name", StringType(), False),
    StructField("key_enabled", BooleanType(), False),
    StructField("key_version", IntegerType(), True),
    StructField("key_type", StringType(), False),
    StructField("key", BinaryType(), False)
])

# Resolve the timestamp and the current user once, so the created/modified values are
# identical and only one API call is made.
now = datetime.now()
current_user_name = ws.current_user.me().user_name

keyvault_data = [{
    "created_date": now.date(),
    "created_time": now,
    "last_modified_time": now,
    "created_by": current_user_name,
    "managed_by": current_user_name,
    "key_enabled": True,
    "key_version": key_version,
    "key_type": "ENCRYPTED_DEK",
    "key_name": f"{catalog}.{schema}",
    "key": encrypted_dek,
}]

# Expose the single row as a temp view so the following %sql cell can INSERT it.
spark.createDataFrame(keyvault_data, keyvault_schema).createOrReplaceTempView("tmp_keyvault")

In [0]:
%sql
-- Append the prepared key row; `id` is omitted and filled by the identity column.
INSERT INTO crypto.keyvault (
  created_date, created_time, last_modified_time, created_by, managed_by,
  key_name, key_enabled, key_version, key_type, key
)
SELECT
  created_date, created_time, last_modified_time, created_by, managed_by,
  key_name, key_enabled, key_version, key_type, key
FROM tmp_keyvault

In [0]:
%sql
-- Enabled keys for the selected catalog.schema, newest version first.
SELECT *
FROM crypto.keyvault
WHERE key_enabled = true
  AND key_name = concat(:catalog, '.', :schema)
ORDER BY key_version DESC

## Step 4
* Create an `unwrap_akv_key()` function that can be used to return a decrypted DEK 
* Test our `unwrap_akv_key()` function by encrypting/decrypting some data

In [0]:
%sql
USE SCHEMA crypto;
-- IMPORTANT!!! 
---> BEFORE RUNNING THIS STEP PLEASE UPDATE THE CREDENTIALS() SECTION TO REFERENCE YOUR uc_service_credential
CREATE OR REPLACE FUNCTION crypto.unwrap_akv_key(key_vault_url STRING, key_name STRING, encrypted_dek BINARY)
RETURNS BINARY
LANGUAGE PYTHON
PARAMETER STYLE PANDAS
HANDLER 'batchhandler'
CREDENTIALS (
    -- IMPORTANT! REPLACE THIS WITH YOUR UC SERVICE CREDENTIAL!!!
    `production-akv` DEFAULT)
ENVIRONMENT (
  dependencies = '["azure-keyvault-keys==4.11.0", "azure-identity==1.23.1"]', environment_version = 'None')
AS $$
import hashlib
from functools import lru_cache
from azure.identity import DefaultAzureCredential
from azure.keyvault.keys.crypto import CryptographyClient, EncryptionAlgorithm
from typing import Iterator, Tuple
import pandas as pd

_cred = DefaultAzureCredential()

# In-memory blob cache (to avoid storing large encrypted blobs multiple times).
# Maps SHA-256 digest of a wrapped DEK -> the wrapped DEK bytes.
_blob_cache = {}

@lru_cache(maxsize=128)
def _get_client(key_uri: str) -> CryptographyClient:
    """
    Returns a cached CryptographyClient for the given Azure Key Vault key URI.

    Args:
        key_uri (str): Azure Key Vault key ID (e.g., https://vault/keys/my-key).

    Returns:
        CryptographyClient: A client capable of wrap/unwrap operations.
    """
    return CryptographyClient(key_uri, credential=_cred)

@lru_cache(maxsize=2048)
def _unwrap(key_uri: str, blob_hash: bytes) -> bytes:
    """
    Internal cached unwrap using a blob hash to deduplicate wrapped DEK calls.

    The wrapped DEK itself is looked up in _blob_cache, so the caller must have
    stored it there under blob_hash before calling.

    Args:
        key_uri (str): URI of the KEK in Azure Key Vault.
        blob_hash (bytes): SHA-256 digest of the wrapped DEK.

    Returns:
        bytes: The unwrapped DEK.
    """
    return _get_client(key_uri).unwrap_key(EncryptionAlgorithm.rsa_oaep_256, _blob_cache[blob_hash]).key

def _unwrap_row(key_vault_url: str, key_name: str, encrypted_dek: bytes) -> bytes:
    """
    Unwraps a single wrapped DEK, going through the hash-keyed caches.

    Args:
        key_vault_url (str): Base URI of the Key Vault (e.g., https://myvault-kv.vault.azure.net).
        key_name (str): Name of the RSA KEK in Key Vault.
        encrypted_dek (bytes): RSA-wrapped DEK to be unwrapped.

    Returns:
        bytes: The unwrapped DEK.
    """
    h = hashlib.sha256(encrypted_dek).digest()
    _blob_cache.setdefault(h, encrypted_dek)
    key_uri = f"{key_vault_url.rstrip('/')}/keys/{key_name}"
    return _unwrap(key_uri, h)

def batchhandler(batch_iter: Iterator[Tuple[pd.Series, pd.Series, pd.Series]]) -> Iterator[pd.Series]:
    """
    Batch handler: unwraps the DEK of every input row using Azure Key Vault.

    Each batch is a tuple of three equally long Series, in the order of the SQL
    function's parameters (key_vault_url, key_name, encrypted_dek). One output
    value is produced per input row, so the yielded Series always has the same
    length as the batch. Repeated inputs are served from the caches above.

    Args:
        batch_iter: Iterator of (key_vault_url, key_name, encrypted_dek) Series tuples.

    Yields:
        pd.Series: The unwrapped DEKs for the batch, one per input row.

    Raises:
        Exception: Any error from hashing or the Key Vault call is re-raised
            with a troubleshooting note attached.
    """

    for urls, names, deks in batch_iter:

        try:
            # Process every row of the batch, not just the first one.
            unwrapped = [
                _unwrap_row(key_vault_url, key_name, encrypted_dek)
                for key_vault_url, key_name, encrypted_dek in zip(urls, names, deks)
            ]

        except Exception as e:
            # Raw string: the ASCII art contains backslashes that are not escape sequences.
            e.add_note(r"""
              ___ _ __ _ __ ___  _ __ 
             / _ \ '__| '__/ _ \| '__|
            |  __/ |  | | | (_) | |   
             \___|_|  |_|  \___/|_|
            
            Failed to unwrap key! Please check:

            1. The user is a member of the <catalog>.<schema>.crypto.user account group
            2. The UC service credential has the right permissions to use the AKV key
            3. That the network you're connecting from is allowed to access the AKV
            """)
            raise

        yield pd.Series(unwrapped, index=urls.index, dtype=object)
$$;

In [0]:
# Show the KEK name that the SQL cells below rebuild from the widgets.
# (The Python variable is `kek_name`; `key_name` is never defined in this notebook.)
kek_name

In [0]:
%sql
-- Round-trip check: encrypt a random string with the unwrapped DEK, then decrypt it again.
-- The unwrap call is defined once in a CTE so the encrypt and decrypt sides cannot drift apart.
WITH unwrapped AS (
  SELECT crypto.unwrap_akv_key(
    key_vault_url => :key_vault_url,
    key_name => CONCAT('unity-catalog-', :metastore_id, '-', REPLACE(:catalog, '_', '-'), '-kek'),
    encrypted_dek => (
      SELECT key
      FROM crypto.keyvault
      WHERE key_name = CONCAT(:catalog, '.', :schema) AND key_enabled = true
      ORDER BY key_version DESC
      LIMIT 1)
  ) AS dek
)
SELECT CAST(
  try_aes_decrypt(
    unbase64(base64(
      aes_encrypt(
        randstr(30),
        (SELECT dek FROM unwrapped),
        'GCM',
        'DEFAULT'))),
    (SELECT dek FROM unwrapped),
    'GCM',
    'DEFAULT') AS STRING) AS decrypted_random_string

## Step 5
* Create `encrypt()` and `decrypt()` functions that can be used to encrypt/decrypt data within our catalog. 
* These functions will call our more privileged `unwrap_akv_key()` function in order to unwrap DEKs and encrypt or decrypt the data

> ### IMPORTANT: 
Before running this section please create the account level group `<catalog>.<schema>.crypto.user` and add your user as a member