In [0]:
# Notebook parameters. Every downstream cell reads these values, so the notebook can be
# pointed at another catalog / schema / region / service credential without code edits.
dbutils.widgets.dropdown("aws_cmk_creation", defaultValue="True", choices=["True", "False"])
for widget_name, widget_default in (
    ("schema", "marketing"),
    ("region", "eu-west-1"),
    ("catalog", "production"),
    ("uc_service_credential", "production-aws-kms"),
):
    dbutils.widgets.text(widget_name, defaultValue=widget_default)

# The metastore id is the last ':'-separated component of current_metastore().
metastore_id = (
    sql("SELECT element_at(split(current_metastore(), ':'), -1) AS metastore")
    .first()
    .metastore
)
catalog, region, schema = (dbutils.widgets.get(name) for name in ("catalog", "region", "schema"))
service_credential = dbutils.widgets.get("uc_service_credential")

## Step 1
* Download the Titanic dataset and store it in a Unity Catalog (UC) volume for raw files.
* We'll use it to simulate a dataset that contains PII (Name, Age, Sex).

In [0]:
%sql
-- Target catalog/schema plus a volume for the raw landing files; each statement is a no-op if the object already exists.
USE CATALOG IDENTIFIER(:catalog);
CREATE SCHEMA IF NOT EXISTS IDENTIFIER(:catalog || '.' || :schema);
CREATE VOLUME IF NOT EXISTS IDENTIFIER(:catalog || '.' || :schema || '.raw_files');

In [0]:
import shutil
import subprocess
import urllib.request

# Public Titanic CSV and its landing location inside the raw_files UC volume.
file_url = "https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv"
volume_dir = f"/Volumes/{catalog}/{schema}/raw_files"
volume_path = f"{volume_dir}/titanic.csv"

# Stream the download into the volume with the Python standard library instead of shelling
# out to `wget`, which may not be installed on every compute type. The timeout keeps a
# stalled connection from hanging the notebook; HTTP errors raise, like check=True did.
with urllib.request.urlopen(file_url, timeout=60) as response, open(volume_path, "wb") as out_file:
    shutil.copyfileobj(response, out_file)

display(dbutils.fs.ls(f"{volume_dir}/"))

## Step 2
* Create an AWS KMS key (the CMK) that will be used to generate our data encryption keys (DEKs)
* For this to work, your UC service credential needs the following permissions: `"kms:CreateKey", "kms:CreateAlias"`
* You can optionally create the KMS key manually, with your own scripts, or via Infrastructure as Code (IaC) such as Terraform

> ### IMPORTANT: 
The key created by this function is intended mainly for demonstration purposes. Please review the key policy and other security-related configuration when creating your actual keys.

In [0]:
import boto3

# Helper functions (create_kms_key, generate_data_key) imported from the shared crypto_functions notebook.
crypto_functions = dbutils.import_notebook("notebooks.envelope_encryption_v2.common.crypto_functions")

# boto3 session backed by the UC service credential (presumably so KMS calls run under that credential's IAM role).
session = boto3.Session(botocore_session=dbutils.credentials.getServiceCredentialsProvider(service_credential), region_name=region)

# One CMK alias per (metastore, catalog). The alias layout, including the trailing "/4", must match
# the alias that crypto.unwrap_kms_key() rebuilds from its `context` argument.
key_alias = f"alias/unity_catalog/{metastore_id}/{catalog}/4"

# Compare the dropdown value explicitly. eval() would execute whatever text the widget holds
# (e.g. a value passed in as a job parameter), which is both unsafe and unnecessary here.
create_cmk = dbutils.widgets.get("aws_cmk_creation") == "True"
if create_cmk:
    kms = crypto_functions.create_kms_key(session=session, alias=key_alias, description=f"CMK for AES encryption of UC catalog {catalog} in metastore {metastore_id}", tags=[] )
key_alias

## Step 3
* Generate a DEK for the schema
* Create a `crypto.keyvault` table to store our encrypted DEKs
* Store the encrypted DEK in the `crypto.keyvault` table

In [0]:
from databricks.sdk import WorkspaceClient
from pyspark.sql.types import StructType, StructField, IntegerType, DateType, TimestampType, StringType, BooleanType, BinaryType
from datetime import datetime
from datetime import date

ws = WorkspaceClient()

# Generate a DEK under the catalog CMK and keep only its encrypted form (CiphertextBlob).
# The encryption context must be identical when the DEK is later unwrapped: crypto.unwrap_kms_key
# passes {"metastore/catalog": <its context argument>}.
dek = crypto_functions.generate_data_key(session=session, key_alias=key_alias, encryption_context={"metastore/catalog": f"{metastore_id}/{catalog}"}).get("CiphertextBlob")

keyvault_schema = StructType([
    StructField("id", IntegerType(), False),
    StructField("created_date", DateType(), False),
    StructField("created_time", TimestampType(), False),
    StructField("last_modified_time", TimestampType(), False),
    StructField("created_by", StringType(), False),
    StructField("managed_by", StringType(), False),
    StructField("key_alias", StringType(), False),
    StructField("key_enabled", BooleanType(), False),
    StructField("key_version", IntegerType(), True),
    StructField("key_type", StringType(), False),
    StructField("key", BinaryType(), False)
])

# Resolve the caller and the timestamp once. The *_by / *_time columns then agree with each
# other, and the workspace API is called once instead of twice.
current_user = ws.current_user.me().user_name
now = datetime.now()
key_scope = f"{catalog}.{schema}"

keyvault_data = [{
    "id": 1,
    "created_date": now.date(),
    "created_time": now,
    "last_modified_time": now,
    "created_by": current_user,
    "managed_by": current_user,
    "key_enabled": True,
    "key_version": 1,
    "key_type": "ENCRYPTED_DEK",
    "key_alias": key_scope,
    "key": dek,
}]

# The crypto schema is never created elsewhere; saveAsTable needs it to exist.
sql(f"CREATE SCHEMA IF NOT EXISTS `{catalog}`.crypto")

# Replace only this scope's rows. A plain overwrite would delete the DEKs of every other
# catalog.schema in the key vault, leaving data already encrypted with them undecryptable.
(
    spark.createDataFrame(keyvault_data, keyvault_schema)
    .write.mode("overwrite")
    .option("replaceWhere", f"key_alias = '{key_scope}'")
    .saveAsTable(f"`{catalog}`.crypto.keyvault")
)

In [0]:
%sql
-- Newest enabled key-vault entry for this catalog.schema scope.
-- The `key` column holds the KMS-encrypted DEK (key_type 'ENCRYPTED_DEK'), not the plaintext key.
SELECT *
FROM crypto.keyvault
WHERE key_enabled = true
  AND key_alias = :catalog || '.' || :schema
ORDER BY key_version DESC
LIMIT 1

## Step 4
* Create an `unwrap_kms_key()` function that returns a decrypted DEK
* Test the `unwrap_kms_key()` function by encrypting and decrypting some data

In [0]:
%sql
-- IMPORTANT!!! 
---> BEFORE RUNNING THIS STEP PLEASE UPDATE THE CREDENTIALS() SECTION TO REFERENCE YOUR uc_service_credential
-- Privileged helper: decrypts ("unwraps") a KMS-encrypted DEK and returns the plaintext key bytes.
-- `context` is "<metastore_id>/<catalog>". It selects the CMK alias and is passed as the KMS
-- encryption context, so it must equal the context used when the DEK was generated.
CREATE OR REPLACE FUNCTION crypto.unwrap_kms_key(context STRING, encrypted_dek BINARY) 
RETURNS BINARY
LANGUAGE PYTHON
PARAMETER STYLE PANDAS
HANDLER 'batchhandler'
CREDENTIALS (
  `production-aws-kms` DEFAULT -- IMPORTANT! REPLACE THIS WITH YOUR UC SERVICE CREDENTIAL!!!
  -- service credential should align with the catalog!
)
ENVIRONMENT (
  dependencies = '["pycryptodome==3.22.0"]',
  environment_version = 'None'
)
AS $$
import boto3
from pyspark.taskcontext import TaskContext
from botocore.exceptions import ClientError
from base64 import b64encode
from base64 import b64decode
from typing import Iterator, Tuple
import pandas as pd

def setup_session():
  """Build a KMS client in the region reported by the cluster usage tags."""
  session = boto3.Session()
  region = TaskContext.get().getLocalProperty("spark.databricks.clusterUsageTags.region")
  client = session.client("kms", region_name=region)
  return client

client = setup_session()

def unwrap(context, encrypted_dek):
    """Decrypt one encrypted DEK with the CMK aliased for `context`; KMS errors propagate."""
    alias = f"alias/unity_catalog/{context}/4"
    response = client.decrypt(CiphertextBlob=encrypted_dek, KeyId=alias, EncryptionContext={'metastore/catalog': context})
    return response.get("Plaintext")

def batchhandler(batch_iter: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]:
    """Unwrap every row of every batch, returning one key per input row.

    The output Series must have the same length as the input batch, so each row is handled
    rather than only the first. Identical (context, DEK) pairs within a batch are decrypted
    once, so a constant DEK does not cost one KMS call per row.
    """
    for contexts, encrypted_deks in batch_iter:
        unwrapped_cache = {}
        keys = []
        for context, encrypted_dek in zip(contexts, encrypted_deks):
            if encrypted_dek is None or pd.isna(context):
                # Usually means no enabled key-vault row matched; fail loudly instead of
                # letting callers silently encrypt with a NULL key.
                raise ValueError(f"unwrap_kms_key: missing context or encrypted_dek (context={context!r})")
            # bytes() makes the blob hashable (binary values may arrive as bytearray).
            cache_key = (context, bytes(encrypted_dek))
            if cache_key not in unwrapped_cache:
                unwrapped_cache[cache_key] = unwrap(*cache_key)
            keys.append(unwrapped_cache[cache_key])
        yield pd.Series(keys)
$$;

In [0]:
%sql
-- Smoke test: AES-GCM encrypt a sample string with the unwrapped DEK for :catalog.:schema.
-- The KMS context "<metastore_id>/<catalog>" is derived from the session, so it is not tied to one metastore.
SELECT base64(
  aes_encrypt(
    "I love Nelson so much he is my world",
    crypto.unwrap_kms_key(
      context => concat(element_at(split(current_metastore(), ':'), -1), '/', :catalog),
      encrypted_dek => (SELECT key FROM crypto.keyvault WHERE key_alias = concat(:catalog, '.', :schema) AND key_enabled = true ORDER BY key_version DESC LIMIT 1)
    ),
    'GCM',
    'DEFAULT'
  )
) AS test

In [0]:
%sql
-- Round-trip check: encrypt a sample string, base64-encode it (the stored format), then decode
-- and decrypt it with the same DEK. Being self-contained, this works for any key vault rather
-- than only for a ciphertext copied from an earlier run.
WITH dek AS (
  SELECT crypto.unwrap_kms_key(
    context => concat(element_at(split(current_metastore(), ':'), -1), '/', :catalog),
    encrypted_dek => (SELECT key FROM crypto.keyvault WHERE key_alias = concat(:catalog, '.', :schema) AND key_enabled = true ORDER BY key_version DESC LIMIT 1)
  ) AS key
)
SELECT CAST(
  aes_decrypt(
    unbase64(base64(aes_encrypt("I love Nelson so much he is my world", key, 'GCM', 'DEFAULT'))),
    key,
    'GCM',
    'DEFAULT'
  ) AS STRING
) AS test
FROM dek

## Step 5
* Create `encrypt_kms()` and `decrypt_kms()` functions that can be used to encrypt/decrypt data within our catalog. 
* These functions will call our more privileged `unwrap_kms_key()` function in order to unwrap DEKs and encrypt or decrypt the data

In [0]:
# TODO: calling this variant fails, apparently because the UC service credential is not passed
# from query -> function -> nested query. Check whether this is expected and/or whether a fix
# is planned. The next cell replaces this definition with a (col, key BINARY) variant.
#
# This definition is kept correct so it works once credential propagation does:
#   * The parameter is referenced as `encrypt_kms.key_alias`. Inside the keyvault subquery, an
#     unqualified `key_alias` resolves to the table column, so `key_alias = key_alias` compared
#     the column with itself and matched every key.
#   * The KMS context is built from this notebook's metastore_id / catalog, not a hard-coded value.
sql(f"""
    CREATE OR REPLACE FUNCTION crypto.encrypt_kms(col STRING, key_alias STRING) 
    RETURNS STRING
    RETURN base64(aes_encrypt(
            col, 
            (SELECT crypto.unwrap_kms_key(
                context=>'{metastore_id}/{catalog}', 
                encrypted_dek=>(SELECT key FROM crypto.keyvault 
                WHERE key_alias = encrypt_kms.key_alias AND key_enabled = true ORDER BY key_version DESC LIMIT 1))),
            'GCM',  
            'DEFAULT'
        ))
    """)

In [0]:
# Working variant: the caller passes the already-unwrapped DEK, so the function body itself
# needs no KMS access. Output is the AES-GCM ciphertext, base64-encoded into a STRING.
encrypt_kms_ddl = """
    CREATE OR REPLACE FUNCTION crypto.encrypt_kms(col STRING, key BINARY) 
    RETURNS STRING
    RETURN base64(aes_encrypt(
            col, 
            key,
            'GCM',  
            'DEFAULT'
        ))
    """
sql(encrypt_kms_ddl)

In [0]:
# Decrypt counterpart of crypto.encrypt_kms(col, key). try_aes_decrypt + nvl means a value that
# cannot be decrypted with `key` is returned unchanged instead of raising an error.
# NOTE(review): the CASE gate is hard-wired to TRUE. The commented-out is_account_group_member(...)
# check is a placeholder; it references `schema_name`, which is not a parameter of this function.
decrypt_kms_ddl = f"""
    CREATE OR REPLACE FUNCTION crypto.decrypt_kms(col STRING, key BINARY) 
    RETURNS STRING
    RETURN 
        CASE WHEN TRUE
        --is_account_group_member(CONCAT('{catalog}_', schema_name, '_decrypt')) 
        THEN 
        nvl(CAST(try_aes_decrypt(unbase64(col), 
        key,
        'GCM',  
        'DEFAULT') AS STRING), 
        col)
        ELSE col END
    """
sql(decrypt_kms_ddl)

In [0]:
%sql
-- Kept for reference: the intended call shape of the (col STRING, key_alias STRING) variant of
-- crypto.encrypt_kms, which is reported as failing in the TODO on that definition. It stays
-- commented out because the active crypto.encrypt_kms takes the unwrapped DEK (BINARY) instead.
--SELECT crypto.encrypt_kms("Hello my name is Weaver", "production.schema") AS result

In [0]:
%sql
-- Encrypt a sample value with the newest enabled DEK for :catalog.:schema.
-- The key-vault filter names the scope explicitly: an unqualified `key_alias = key_alias`
-- would compare the column with itself and match every key in the vault.
SELECT crypto.encrypt_kms(
  "Hello my name is Weaver",
  crypto.unwrap_kms_key(
    context => concat(element_at(split(current_metastore(), ':'), -1), '/', :catalog),
    encrypted_dek => (SELECT key FROM crypto.keyvault WHERE key_alias = concat(:catalog, '.', :schema) AND key_enabled = true ORDER BY key_version DESC LIMIT 1)
  )
) AS result

In [0]:
%sql
-- Round trip through crypto.encrypt_kms / crypto.decrypt_kms with the newest enabled DEK for
-- :catalog.:schema. The DEK is unwrapped once in the CTE. Encrypting inside the same query keeps
-- the test independent of ciphertext produced by a previous run or a different key.
WITH dek AS (
  SELECT crypto.unwrap_kms_key(
    context => concat(element_at(split(current_metastore(), ':'), -1), '/', :catalog),
    encrypted_dek => (SELECT key FROM crypto.keyvault WHERE key_alias = concat(:catalog, '.', :schema) AND key_enabled = true ORDER BY key_version DESC LIMIT 1)
  ) AS key
)
SELECT crypto.decrypt_kms(crypto.encrypt_kms("Hello my name is Weaver", key), key) AS result
FROM dek

## Step 6
* Create a table from the raw data downloaded above, encrypting the columns that contain sensitive data

In [0]:
%sql
-- Build the encrypted copy of the raw Titanic data.
-- crypto.encrypt_kms takes the unwrapped DEK (BINARY), not a schema name. The DEK is unwrapped
-- in the single-row `dek` CTE and cross-joined onto every row, so the Python UDF runs on one
-- row instead of once per passenger.
USE SCHEMA IDENTIFIER(:schema);
CREATE OR REPLACE TABLE IDENTIFIER(:catalog || '.' || :schema || '.titanic') AS
WITH dek AS (
  SELECT crypto.unwrap_kms_key(
    context => concat(element_at(split(current_metastore(), ':'), -1), '/', :catalog),
    encrypted_dek => (SELECT key FROM crypto.keyvault WHERE key_alias = concat(:catalog, '.', :schema) AND key_enabled = true ORDER BY key_version DESC LIMIT 1)
  ) AS key
)
SELECT
  src.PassengerId,
  crypto.encrypt_kms(src.Name, dek.key) AS Name,
  crypto.encrypt_kms(src.Age, dek.key) AS Age,
  crypto.encrypt_kms(src.Sex, dek.key) AS Sex,
  src.* EXCEPT (PassengerId, Name, Age, Sex)
FROM read_files(
  concat('/Volumes/', :catalog, '/', :schema, '/raw_files/titanic.csv'),
  format => 'csv',
  header => true,
  mode => 'FAILFAST') AS src
CROSS JOIN dek;
SELECT * FROM IDENTIFIER(:catalog || '.' || :schema || '.titanic');

## Step 7
* Check that the decrypt function works as expected

In [0]:
%sql
-- Decrypt the sensitive columns back to plaintext. crypto.decrypt_kms needs the unwrapped DEK
-- (BINARY), so the DEK is unwrapped once in the single-row `dek` CTE and cross-joined.
WITH dek AS (
  SELECT crypto.unwrap_kms_key(
    context => concat(element_at(split(current_metastore(), ':'), -1), '/', :catalog),
    encrypted_dek => (SELECT key FROM crypto.keyvault WHERE key_alias = concat(:catalog, '.', :schema) AND key_enabled = true ORDER BY key_version DESC LIMIT 1)
  ) AS key
)
SELECT
  enc.PassengerId,
  crypto.decrypt_kms(enc.Name, dek.key) AS Name,
  crypto.decrypt_kms(enc.Age, dek.key) AS Age,
  crypto.decrypt_kms(enc.Sex, dek.key) AS Sex,
  enc.* EXCEPT (PassengerId, Name, Age, Sex)
FROM IDENTIFIER(:catalog || '.' || :schema || '.titanic') AS enc
CROSS JOIN dek;

## Step 8
* You can also add a column mask to the encrypted table
* A column mask serves the following purposes:
  * The calling users don't even need permissions to the `encrypt_kms()` and `decrypt_kms()` functions or the `crypto` schema
  * The whole process of encryption/decryption is abstracted away from them

> ### NOTE: 
Once a column mask is in place, calling the `decrypt_kms()` function directly is likely to fail or simply return its input unchanged. The mask already decrypts the data, so you would be trying to decrypt values that are already decrypted.

> ### IMPORTANT: 
Please update the `USING COLUMNS('marketing')` sections below with your `schema` name.

In [0]:
%sql
-- IMPORTANT!!! 
---> BEFORE RUNNING THIS STEP PLEASE UPDATE THE USING COLUMNS() SECTION TO REFERENCE YOUR UC SCHEMA NAME
-- NOTE(review): as written, this cell does not match the functions defined in this notebook:
--   * no function named crypto.decrypt is created here (the decrypt function is crypto.decrypt_kms);
--   * crypto.decrypt_kms expects the unwrapped DEK (BINARY) as its second argument, whereas
--     this passes the schema-name literal 'marketing'.
-- A mask function that unwraps the DEK itself would need the (col, key_alias) pattern that the
-- TODO on the first crypto.encrypt_kms definition reports as failing. Confirm before relying on this step.
ALTER TABLE IDENTIFIER(:catalog || '.' || :schema || '.titanic') ALTER COLUMN Name SET MASK crypto.decrypt USING COLUMNS('marketing');
ALTER TABLE IDENTIFIER(:catalog || '.' || :schema || '.titanic') ALTER COLUMN Age SET MASK crypto.decrypt USING COLUMNS ('marketing');
ALTER TABLE IDENTIFIER(:catalog || '.' || :schema || '.titanic') ALTER COLUMN Sex SET MASK crypto.decrypt USING COLUMNS ('marketing');

In [0]:
%sql
-- Read the table as a regular user would; if the column masks were applied, Name/Age/Sex come back through the mask function.
SELECT * FROM IDENTIFIER(:catalog || '.' || :schema || '.titanic');