## Install the relevant libraries
If you have [serverless egress control](https://learn.microsoft.com/en-gb/azure/databricks/security/network/serverless-network-security/network-policies) configured (recommended), you'll need to either download the following Python wheels:

* [msal](https://pypi.org/project/msal/)
* [msal-extensions](https://pypi.org/project/msal-extensions/)
* [azure-identity](https://pypi.org/project/azure-identity/)
* [azure-keyvault-keys](https://pypi.org/project/azure-keyvault-keys/)

And install them via:

* [A UC volume](volumes)
* [Workspace files](https://learn.microsoft.com/en-gb/azure/databricks/libraries/workspace-files-libraries)
* [Serverless environments](https://learn.microsoft.com/en-gb/azure/databricks/compute/serverless/dependencies)

Or use [Private Link](https://learn.microsoft.com/en-gb/azure/databricks/security/network/serverless-network-security/pl-to-internal-network) to connect to your own artifact repository!

NB - you can also use [Private Link](https://learn.microsoft.com/en-gb/azure/databricks/security/network/serverless-network-security/serverless-private-link) to connect to your Azure Key Vault!

In [0]:
%pip install -r ../../requirements.txt -q

In [0]:
dbutils.library.restartPython()
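
After the restart, a quick check can confirm that the libraries listed above are actually installed. The snippet below is a minimal sketch using only the standard library; the helper name `missing_packages` is hypothetical and not part of this notebook.

```python
from importlib import metadata

# Distribution names taken from the list of wheels above.
REQUIRED = ["msal", "msal-extensions", "azure-identity", "azure-keyvault-keys"]


def missing_packages(names):
    """Return the distribution names that are not installed in this environment."""
    missing = []
    for name in names:
        try:
            metadata.version(name)
        except metadata.PackageNotFoundError:
            missing.append(name)
    return missing


print("Missing packages:", missing_packages(REQUIRED) or "none")
```

If anything is reported as missing, re-check the install route you chose (volume, workspace files, serverless environment or artifact repository).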

In [0]:
dbutils.widgets.dropdown("azure_key_creation", defaultValue="True", choices=["True", "False"])
dbutils.widgets.text("uc_service_credential", defaultValue="production-akv")
dbutils.widgets.text("key_vault_url", defaultValue="")
dbutils.widgets.text("metastore_id", defaultValue=sql("SELECT element_at(split(current_metastore(), ':'), -1) AS metastore").first().metastore)
dbutils.widgets.text("catalog", defaultValue="production")
dbutils.widgets.text("schema", defaultValue="finance")

service_credential = dbutils.widgets.get("uc_service_credential")
metastore_id = dbutils.widgets.get("metastore_id")
key_vault_url = dbutils.widgets.get("key_vault_url")
catalog = dbutils.widgets.get("catalog")
schema = dbutils.widgets.get("schema")
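
Widget values always arrive as strings, so it can help to validate them before they reach Key Vault or SQL. The following is a sketch under the assumption that `azure_key_creation` only ever holds `"True"` or `"False"` and that `key_vault_url` should be an HTTPS URL; both helper names are hypothetical.

```python
from urllib.parse import urlparse


def parse_bool_widget(value: str) -> bool:
    """Map the dropdown strings "True"/"False" to a bool without evaluating them as code."""
    normalized = value.strip().lower()
    if normalized not in ("true", "false"):
        raise ValueError(f"Expected 'True' or 'False', got {value!r}")
    return normalized == "true"


def validate_key_vault_url(url: str) -> str:
    """Return the URL without a trailing slash, or raise if it is empty or not HTTPS."""
    parsed = urlparse(url.strip())
    if parsed.scheme != "https" or not parsed.netloc:
        raise ValueError(f"key_vault_url must be an https:// URL, got {url!r}")
    return url.strip().rstrip("/")
```

Because the `key_vault_url` widget defaults to an empty string, a check like this fails early with a clear message instead of a less obvious error from the Key Vault client later on.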

## Step 1
* Download the Titanic dataset and store it in a UC volume for raw files.
* We'll use this to simulate a dataset that contains PII (Name, Age, Sex)

In [0]:
%sql
USE CATALOG IDENTIFIER(:catalog);
CREATE SCHEMA IF NOT EXISTS IDENTIFIER(concat(:catalog, '.', :schema));
CREATE VOLUME IF NOT EXISTS IDENTIFIER(concat(:catalog, '.', :schema, '.raw_files'));

In [0]:
import subprocess

file_url = "https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv"
volume_path = f"/Volumes/{catalog}/{schema}/raw_files/titanic.csv"

subprocess.run(["wget", file_url, "-O", volume_path], check=True)
display(dbutils.fs.ls(f"/Volumes/{catalog}/{schema}/raw_files/"))
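
Since the later steps encrypt the `Name`, `Age` and `Sex` columns, it may be worth confirming that the downloaded file actually contains them. This is a rough sketch using the standard `csv` module; the helper name and the two-line sample are hypothetical stand-ins for the real file.

```python
import csv
import io

PII_COLUMNS = {"Name", "Age", "Sex"}


def missing_pii_columns(csv_text: str) -> set:
    """Return the expected PII columns that are absent from the CSV header row."""
    header = next(csv.reader(io.StringIO(csv_text)))
    return PII_COLUMNS - set(header)


# Hypothetical two-line sample standing in for the downloaded file.
sample = 'PassengerId,Name,Sex,Age\n1,"Doe, Mr. John",male,30\n'
print(missing_pii_columns(sample))
```

In the notebook, the same function could be called with the text read from `volume_path`.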

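## Step 2
* Create a key encryption key (KEK) in Azure Key Vault (skipped when `azure_key_creation` is `False`)
* Generate a random 32-byte data encryption key (DEK) and wrap it with the KEK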

In [0]:
from azure.keyvault.keys import KeyClient, KeyType

credential = dbutils.credentials.getServiceCredentialsProvider(service_credential)
client = KeyClient(vault_url=key_vault_url, credential=credential)
kek_name = f"unity-catalog-{metastore_id}-{catalog.replace('_', '-')}-kek"

# Compare the widget string directly rather than eval()-ing user-supplied input
if dbutils.widgets.get("azure_key_creation") == "True":
    key = client.create_key(name=kek_name, key_type=KeyType.rsa)
    print(f"Created kek '{key.name}': {key.id}")
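
The KEK name is derived from the metastore ID and the catalog name, with underscores swapped for dashes. The sketch below wraps that convention in a hypothetical helper and checks the result, assuming the Key Vault rule that object names are 1 to 127 characters long and contain only letters, digits and dashes.

```python
import re

# Assumed Key Vault naming rule: 1-127 characters, letters, digits and dashes only.
_KEY_NAME_PATTERN = re.compile(r"^[0-9a-zA-Z-]{1,127}$")


def build_kek_name(metastore_id: str, catalog: str) -> str:
    """Build the KEK name used in this notebook and check it is a valid Key Vault name."""
    name = f"unity-catalog-{metastore_id}-{catalog.replace('_', '-')}-kek"
    if not _KEY_NAME_PATTERN.fullmatch(name):
        raise ValueError(f"Invalid Key Vault key name: {name!r}")
    return name


# Hypothetical metastore ID and catalog name, for illustration only.
print(build_kek_name("1234abcd", "my_catalog"))
```

This also explains the `replace('_', '-')` call above: underscores are common in catalog names but would not pass the naming rule.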

In [0]:
import secrets
from azure.keyvault.keys.crypto import EncryptionAlgorithm

dek = secrets.token_bytes(32)
crypto_client = client.get_cryptography_client(key_name=kek_name)

encrypted_dek = crypto_client.wrap_key(EncryptionAlgorithm.rsa_oaep_256, dek).encrypted_key
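
A 32-byte DEK is an AES-256 key, and it comfortably fits inside a single RSA-OAEP wrap operation. The sketch below works through the size limit from RFC 8017 (message length at most `k - 2*hLen - 2` bytes), assuming the KEK is a 2048-bit RSA key and that SHA-256 is the OAEP hash, as `rsa_oaep_256` implies. The helper name is hypothetical.

```python
import secrets


def max_oaep_payload_bytes(rsa_key_bits: int, hash_bytes: int = 32) -> int:
    """Largest message RSA-OAEP can wrap: k - 2*hLen - 2 (RFC 8017, section 7.1.1)."""
    return rsa_key_bits // 8 - 2 * hash_bytes - 2


dek = secrets.token_bytes(32)  # 32 bytes = 256 bits, i.e. an AES-256 key
limit = max_oaep_payload_bytes(2048)  # assumption: the KEK is a 2048-bit RSA key
assert len(dek) <= limit, "DEK is too large to wrap with this KEK"
print(f"DEK is {len(dek)} bytes; the wrap limit is {limit} bytes")
```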

## Step 3
* Create a `crypto.keyvault` table to store our encrypted DEKs
* Work out the next key version for our schema's encrypted DEK
* Store the encrypted DEK in the `crypto.keyvault` table

In [0]:
%sql
CREATE SCHEMA IF NOT EXISTS IDENTIFIER(concat(:catalog, '.crypto'));
CREATE TABLE IF NOT EXISTS IDENTIFIER(:catalog || '.crypto.keyvault') (
  id BIGINT GENERATED BY DEFAULT AS IDENTITY,
  created_date DATE, 
  created_time TIMESTAMP,
  last_modified_time TIMESTAMP,
  created_by STRING,
  managed_by STRING,
  key_name STRING,
  key_enabled BOOLEAN,
  key_version INT,
  key_type STRING,
  key BINARY,
  CONSTRAINT pk_key_name_and_version
        PRIMARY KEY (key_name, key_version))
  TBLPROPERTIES (
  'delta.appendOnly' = 'true', -- we only INSERT new versions
  'delta.autoOptimize.optimizeWrite' = 'true',
  'delta.autoOptimize.autoCompact'  = 'true'
);

In [0]:
from databricks.sdk import WorkspaceClient
from pyspark.sql.types import StructType, StructField, IntegerType, DateType, TimestampType, StringType, BooleanType, BinaryType
from datetime import datetime
from datetime import date

ws = WorkspaceClient()

key_version = sql(f"SELECT MAX(key_version) AS max_version FROM crypto.keyvault WHERE key_name = concat('{catalog}', '.', '{schema}')").first().max_version

if not key_version:
  key_version = 1
else:
  key_version += 1

keyvault_schema = StructType([
    StructField("created_date", DateType(), False),
    StructField("created_time", TimestampType(), False),
    StructField("last_modified_time", TimestampType(), False),
    StructField("created_by", StringType(), False),
    StructField("managed_by", StringType(), False),
    StructField("key_name", StringType(), False),
    StructField("key_enabled", BooleanType(), False),
    StructField("key_version", IntegerType(), True),
    StructField("key_type", StringType(), False),
    StructField("key", BinaryType(), False)
])

now = datetime.now()
current_user = ws.current_user.me().user_name

keyvault_data = [{
    "created_date": date.today(),
    "created_time": now,
    "last_modified_time": now,
    "created_by": current_user,
    "managed_by": current_user,
    "key_name": f"{catalog}.{schema}",
    "key_enabled": True,
    "key_version": key_version,
    "key_type": "ENCRYPTED_DEK",
    "key": encrypted_dek,
}]

spark.createDataFrame(keyvault_data, keyvault_schema).createOrReplaceTempView("tmp_keyvault")
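
The versioning rule in the cell above is simple: the first DEK for a schema gets version 1, and every later DEK gets the current maximum plus one. A minimal sketch of that rule as a standalone function (the name `next_key_version` is hypothetical) makes it easy to test in isolation:

```python
from typing import Optional


def next_key_version(max_version: Optional[int]) -> int:
    """Return 1 for the first DEK of a schema, otherwise the current maximum plus one."""
    return 1 if max_version is None else max_version + 1


print(next_key_version(None), next_key_version(1))
```

`MAX(key_version)` returns `NULL` when no rows match, which arrives in Python as `None`, hence the first branch.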

In [0]:
%sql
INSERT INTO crypto.keyvault (created_date, created_time, last_modified_time, created_by, managed_by, key_name, key_enabled, key_version, key_type, key)
SELECT * FROM tmp_keyvault

In [0]:
%sql
SELECT * 
FROM crypto.keyvault 
WHERE key_name = concat(:catalog, '.', :schema) 
AND key_enabled = true
ORDER BY key_version DESC 

## Step 4
* Create an `unwrap_akv_key()` function that can be used to return a decrypted DEK 
* Test our `unwrap_akv_key()` function by encrypting/decrypting some data

> ### IMPORTANT: 
Before running this section please:
1. Update the `CREDENTIALS()` section of the code below to reference your uc_service_credential
2. Update the `dependencies` section of the code below to reference your libraries, as installed above

In [0]:
%sql
USE SCHEMA crypto;
-- IMPORTANT!!! 
---> BEFORE RUNNING THIS STEP PLEASE UPDATE THE CREDENTIALS() SECTION TO REFERENCE YOUR uc_service_credential
CREATE OR REPLACE FUNCTION crypto.unwrap_akv_key(key_vault_url STRING, key_name STRING, encrypted_dek BINARY)
RETURNS BINARY
LANGUAGE PYTHON
PARAMETER STYLE PANDAS
HANDLER 'batchhandler'
CREDENTIALS (
    -- IMPORTANT! REPLACE THIS WITH YOUR UC SERVICE CREDENTIAL!!!
    `production-akv` DEFAULT)
ENVIRONMENT (
    -- IMPORTANT! REPLACE THIS SECTION WITH YOUR LIBRARIES!
  dependencies = '["azure-keyvault-keys==4.11.0", "azure-identity==1.23.1"]', 
    --   dependencies = '[
    --     "/Volumes/production/default/packages/msal-1.32.3-py3-none-any.whl",
    --     "/Volumes/production/default/packages/msal_extensions-1.3.1-py3-none-any.whl", 
    --     "/Volumes/production/default/packages/azure_identity-1.23.1-py3-none-any.whl", 
    --     "/Volumes/production/default/packages/azure_keyvault_keys-4.11.0-py3-none-any.whl"
    --     ]', 
  environment_version = 'None')
AS $$
import hashlib
from functools import lru_cache
from azure.identity import DefaultAzureCredential
from azure.keyvault.keys.crypto import CryptographyClient, EncryptionAlgorithm
from typing import Iterator, Tuple
import pandas as pd

_cred = DefaultAzureCredential()

# In-memory blob cache per executor (to avoid storing large encrypted blobs multiple times)
_blob_cache = {}

@lru_cache(maxsize=128)
def _get_client(key_uri: str) -> CryptographyClient:
    """
    Returns a cached CryptographyClient for the given Azure Key Vault key URI.

    Args:
        key_uri (str): Full Azure Key Vault key ID (e.g., https://vault/keys/my-key/version).

    Returns:
        CryptographyClient: A client capable of wrap/unwrap operations.
    """
    return CryptographyClient(key_uri, credential=_cred)

@lru_cache(maxsize=2048)
def _unwrap(key_uri: str, blob_hash: bytes) -> bytes:
    """
    Internal cached unwrap using a blob hash to deduplicate wrapped DEK calls.

    Args:
        key_uri (str): Full URI of the KEK in Azure Key Vault.
        blob_hash (bytes): SHA-256 digest of the wrapped DEK.

    Returns:
        bytes: Unwrapped 32-byte DEK.
    """
    return _get_client(key_uri).unwrap_key(EncryptionAlgorithm.rsa_oaep_256, _blob_cache[blob_hash]).key

def batchhandler(batch_iter: Iterator[Tuple[pd.Series, pd.Series, pd.Series]]) -> Iterator[pd.Series]:

    """
    Securely unwraps a DEK using Azure Key Vault and executor-local cache.

    Args:
        batch_iter: Iterator of batches. Each batch is a tuple of three pandas Series:
            key_vault_url (str): Base URI of the Key Vault (e.g., https://myvault-kv.vault.azure.net).
            key_name (str): Name of the RSA KEK in Key Vault.
            encrypted_dek (bytes): RSA-wrapped DEK to be unwrapped.

    Yields:
        pd.Series: The unwrapped 32-byte DEK.
    """

    for url, name, dek in batch_iter:

        key_vault_url = url[0]
        key_name = name[0]
        encrypted_dek = dek[0]

        try:
            h = hashlib.sha256(encrypted_dek).digest()
            _blob_cache.setdefault(h, encrypted_dek)
            key_uri = f"{key_vault_url.rstrip('/')}/keys/{key_name}"
            unwrapped = _unwrap(key_uri, h)

        except Exception as e:
            e.add_note("""
              ___ _ __ _ __ ___  _ __ 
             / _ \ '__| '__/ _ \| '__|
            |  __/ |  | | | (_) | |   
             \___|_|  |_|  \___/|_|
            
            Failed to unwrap key! Please check:

            1. The user is a member of the <catalog>.<schema>.crypto.user account group
            2. The UC service credential has the right permissions to use the AKV key
            3. That the network you're connecting from is allowed to access the AKV
            """)
            raise e

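        # Emit one value per input row; only the first row of each batch is unwrapped,
        # so every row in a batch is assumed to carry the same wrapped DEK.
        yield pd.Series([unwrapped] * len(dek))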
$$;

In [0]:
%sql
SELECT CAST(try_aes_decrypt(
  unbase64(base64(
  aes_encrypt(
    randstr(30),
    (SELECT crypto.unwrap_akv_key(
      key_vault_url=>:key_vault_url,
      key_name=>CONCAT('unity-catalog-', :metastore_id, '-', REPLACE(:catalog, '_', '-'), '-kek'),
      encrypted_dek=>(SELECT key FROM crypto.keyvault 
      WHERE key_name = CONCAT(:catalog, '.', :schema) AND key_enabled = true ORDER BY key_version DESC LIMIT 1))),
      'GCM',  
      'DEFAULT'))),
  (SELECT crypto.unwrap_akv_key(
      key_vault_url=>:key_vault_url, 
      key_name=>CONCAT('unity-catalog-', :metastore_id, '-', REPLACE(:catalog, '_', '-'), '-kek'),
    encrypted_dek=>(SELECT key FROM crypto.keyvault 
    WHERE key_name = CONCAT(:catalog, '.', :schema) AND key_enabled = true ORDER BY key_version DESC LIMIT 1))),
    'GCM',
    'DEFAULT') AS STRING) AS decrypted_random_string

## Step 5
* Create `encrypt()` and `decrypt()` functions that can be used to encrypt/decrypt data within our catalog. 
* These functions will call our more privileged `unwrap_akv_key()` function in order to unwrap DEKs and encrypt or decrypt the data

> ### IMPORTANT: 
Before running this section please create the account level group `<catalog>.<schema>.crypto.user` and add your user as a member

In [0]:
is_account_group_member = sql(f"SELECT is_account_group_member('{catalog}.{schema}.crypto.user') AS is_account_group_member").first().is_account_group_member
if not is_account_group_member:
    raise Exception(f"Please add your user to the '{catalog}.{schema}.crypto.user' account group before proceeding")

In [0]:
%sql
CREATE OR REPLACE FUNCTION crypto.get_latest_dek_version(catalog STRING, schema STRING)
RETURNS INT
RETURN (SELECT MAX(key_version) FROM crypto.keyvault 
    WHERE key_name = CONCAT(catalog, '.', schema) 
         AND key_enabled = true)  

In [0]:
%sql
CREATE OR REPLACE FUNCTION crypto.get_encrypted_dek(catalog STRING, schema STRING, version INT)
RETURNS BINARY
RETURN (SELECT FIRST(key)
         FROM crypto.keyvault 
         WHERE key_name = CONCAT(catalog, '.', schema) 
         AND key_enabled = true
         AND key_version = version)

In [0]:
sql(f"""
CREATE OR REPLACE FUNCTION crypto.encrypt_akv(col STRING, key BINARY)
RETURNS STRING
RETURN base64(aes_encrypt(
            col, 
            (SELECT * FROM (SELECT crypto.unwrap_akv_key(
                key_vault_url=>"{key_vault_url}",
                key_name=>"unity-catalog-{metastore_id}-{catalog.replace('_', '-')}-kek",
                encrypted_dek=>key))),
            'GCM',  
            'DEFAULT'
        ))""")

In [0]:
sql(f"""
    CREATE OR REPLACE FUNCTION crypto.decrypt_akv(col STRING, catalog STRING, schema STRING, key BINARY)
    RETURNS STRING
    RETURN CASE WHEN is_account_group_member(CONCAT(catalog, '.', schema, '.crypto.user')) THEN
    nvl(CAST(try_aes_decrypt(unbase64(col), 
        (SELECT * FROM (SELECT crypto.unwrap_akv_key(
                key_vault_url=>"{key_vault_url}",
                key_name=>"unity-catalog-{metastore_id}-{catalog.replace('_', '-')}-kek",
                encrypted_dek=>key))),
        'GCM',  
        'DEFAULT') AS STRING), 
        col)
        ELSE col END
    """)

In [0]:
%sql
CREATE OR REPLACE FUNCTION crypto.encrypt(col STRING, catalog STRING, schema STRING)
RETURNS STRING
RETURN crypto.encrypt_akv(col, crypto.get_encrypted_dek(catalog, schema, crypto.get_latest_dek_version(catalog, schema)));

CREATE OR REPLACE FUNCTION crypto.decrypt(col STRING, catalog STRING, schema STRING)
RETURNS STRING
RETURN crypto.decrypt_akv(col, catalog, schema, crypto.get_encrypted_dek(catalog, schema, crypto.get_latest_dek_version(catalog, schema)))
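
The `decrypt_akv()` function has three possible outcomes, which are easier to see in plain Python. The sketch below mirrors its `CASE` and `nvl` logic only; `toy_decrypt` is a hypothetical stand-in for AES-GCM decryption and both function names are made up for illustration.

```python
from typing import Callable, Optional


def decrypt_or_passthrough(
    value: str,
    is_member: bool,
    try_decrypt: Callable[[str], Optional[str]],
) -> str:
    """Mirror the CASE/nvl logic: members get plaintext when decryption works, else the input."""
    if not is_member:
        return value  # non-members always see the stored ciphertext
    plaintext = try_decrypt(value)  # None plays the role of try_aes_decrypt returning NULL
    return value if plaintext is None else plaintext


def toy_decrypt(value: str) -> Optional[str]:
    """Hypothetical stand-in for AES-GCM decryption: strips an 'enc:' prefix."""
    return value[4:] if value.startswith("enc:") else None


print(decrypt_or_passthrough("enc:Alice", True, toy_decrypt))
print(decrypt_or_passthrough("enc:Alice", False, toy_decrypt))
```

Note the third case: a group member whose decryption attempt fails gets the stored value back unchanged rather than an error.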

## Step 6
* Create a table from the raw data we downloaded above, encrypting the columns that contain sensitive data

In [0]:
%sql
CREATE OR REPLACE TABLE IDENTIFIER(:catalog || '.' || :schema || '.titanic') AS (
SELECT 
PassengerId,
crypto.encrypt(Name, :catalog, :schema) AS Name,
crypto.encrypt(Age, :catalog, :schema) AS Age,
crypto.encrypt(Sex, :catalog, :schema) AS Sex,
* EXCEPT(PassengerId, Name, Age, Sex)
FROM read_files(
  concat('/Volumes/', :catalog, '/', :schema, '/raw_files/titanic.csv'),
  format => 'csv',
  header => true,
  mode => 'FAILFAST'));
  SELECT * FROM IDENTIFIER(:catalog || '.' || :schema || '.titanic')

## Step 7 
* Check that the decrypt functions work as expected...

In [0]:
%sql
SELECT
PassengerId,
crypto.decrypt(Name, :catalog, :schema) AS Name,
crypto.decrypt(Age, :catalog, :schema) AS Age,
crypto.decrypt(Sex, :catalog, :schema) AS Sex,
* EXCEPT(PassengerId, Name, Age, Sex)
FROM IDENTIFIER(:catalog || '.' || :schema || '.titanic')

## Step 8
* You can also add a column mask to the encrypted table
* A column mask serves the following purposes:
  * The calling users don't even need permissions to the `encrypt_akv()` and `decrypt_akv()` functions or the `crypto` schema
  * The whole process of encryption/decryption is abstracted away from them

> ### NOTE: 
Adding a column mask is likely to cause calling the `decrypt_akv()` function directly to fail, since the column mask will try to decrypt the data automatically and you'll be trying to decrypt the already decrypted results!

> ### IMPORTANT: 
Please update the `USING COLUMNS('production', 'finance')` section below with your `catalog` and `schema` names

In [0]:
%sql
-- IMPORTANT!!! 
---> BEFORE RUNNING THIS STEP PLEASE UPDATE THE USING COLUMNS() SECTION TO REFERENCE YOUR UC CATALOG AND SCHEMA NAMES
ALTER TABLE IDENTIFIER(:catalog || '.' || :schema || '.titanic') ALTER COLUMN Name SET MASK crypto.decrypt USING COLUMNS ('production', 'finance');
ALTER TABLE IDENTIFIER(:catalog || '.' || :schema || '.titanic') ALTER COLUMN Age SET MASK crypto.decrypt USING COLUMNS ('production', 'finance');
ALTER TABLE IDENTIFIER(:catalog || '.' || :schema || '.titanic') ALTER COLUMN Sex SET MASK crypto.decrypt USING COLUMNS ('production', 'finance');
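
One way to avoid hand-editing the literals in `USING COLUMNS()` is to generate the statements from the widget values. The helper below is a hypothetical sketch, not part of this notebook, and it assumes the catalog, schema, table and column names are trusted identifiers that contain no quotes or backticks.

```python
def build_mask_statements(catalog: str, schema: str, table: str, columns: list) -> list:
    """Return one ALTER TABLE ... SET MASK statement per column."""
    # Assumption: the names are trusted identifiers without quotes or backticks.
    full_name = f"{catalog}.{schema}.{table}"
    return [
        f"ALTER TABLE {full_name} ALTER COLUMN {column} "
        f"SET MASK crypto.decrypt USING COLUMNS ('{catalog}', '{schema}')"
        for column in columns
    ]


for statement in build_mask_statements("production", "finance", "titanic", ["Name", "Age", "Sex"]):
    print(statement)  # in the notebook, each statement could be passed to sql(...)
```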

In [0]:
%sql
SELECT 
*
FROM IDENTIFIER(:catalog || '.' || :schema || '.titanic');

## Idea:
Why not use [Unity Catalog attribute-based access control (ABAC)](https://learn.microsoft.com/en-gb/azure/databricks/data-governance/unity-catalog/abac/) to automatically scale your encryption/decryption UDFs across an entire catalog? You could even use [Data Classification](https://learn.microsoft.com/en-gb/azure/databricks/lakehouse-monitoring/data-classification) to automatically detect and tag sensitive data!