In [0]:
# Default for the "metastore_id" widget: the last ':'-separated token of current_metastore().
default_metastore_id = (
    sql("SELECT element_at(split(current_metastore(), ':'), -1) AS metastore")
    .first()
    .metastore
)

# Notebook parameters, exposed as widgets so each value can be overridden per run.
dbutils.widgets.dropdown("aws_cmk_creation", defaultValue="True", choices=["True", "False"])
dbutils.widgets.text("schema", defaultValue="human_resources")
dbutils.widgets.text("region", defaultValue="eu-west-1")
dbutils.widgets.text("metastore_id", defaultValue=default_metastore_id)
dbutils.widgets.text("catalog", defaultValue="production")
dbutils.widgets.text("uc_service_credential", defaultValue="production-aws-kms")

# Read the widget values back into the Python variables used by later cells.
metastore_id, catalog, region, schema, service_credential = (
    dbutils.widgets.get(widget_name)
    for widget_name in ("metastore_id", "catalog", "region", "schema", "uc_service_credential")
)

## Step 1
* Download the titanic dataset and store it in a UC volume for raw files. 
* We'll use this to simulate a dataset that contains PII (Name, Age, Sex)

In [0]:
%sql
-- Make the target catalog the session default, then make sure the schema and
-- the raw-files volume exist (IF NOT EXISTS keeps the cell safe to re-run).
USE CATALOG IDENTIFIER(:catalog);

CREATE SCHEMA IF NOT EXISTS IDENTIFIER(:catalog || '.' || :schema);

CREATE VOLUME IF NOT EXISTS IDENTIFIER(:catalog || '.' || :schema || '.raw_files');

In [0]:
import subprocess

# Source CSV and its destination inside the UC volume created in the previous cell.
file_url = "https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv"
volume_dir = f"/Volumes/{catalog}/{schema}/raw_files/"
volume_path = volume_dir + "titanic.csv"

# check=True makes the cell fail if wget exits with a non-zero status.
wget_command = ["wget", file_url, "-O", volume_path]
subprocess.run(wget_command, check=True)

# List the volume so the downloaded file is visible in the cell output.
display(dbutils.fs.ls(volume_dir))

## Step 2
* Create an AWS KMS key to generate our DEKs
* For this to work, your UC service credential needs the following privileges: `"kms:CreateKey", "kms:CreateAlias"`
* You can optionally create your KMS key manually, via your own scripts, or via IaC (Infrastructure as Code) tooling such as Terraform

> ### IMPORTANT: 
The key created via this function is intended largely for demonstration purposes. Please review the key policy and other security-related configuration when creating your actual keys.

In [0]:
import boto3

# Helper notebook that provides create_kms_key() and generate_data_key().
crypto_functions = dbutils.import_notebook("notebooks.envelope_encryption_v2.common.aws_crypto_functions")

# boto3 session backed by the Unity Catalog service credential chosen in the widgets.
session = boto3.Session(
    botocore_session=dbutils.credentials.getServiceCredentialsProvider(service_credential),
    region_name=region,
)

# Alias of the customer managed key (CMK) for this metastore/catalog pair.
key_alias = f"alias/unity_catalog/{metastore_id}/{catalog}/cmk"

# Compare the widget text directly instead of eval()-ing it: widget values can be
# overridden (e.g. by job parameters), and eval() would execute whatever text is supplied.
create_cmk = dbutils.widgets.get("aws_cmk_creation") == "True"

if create_cmk:
    print(f"Creating CMK '{key_alias}'...")
    kms = crypto_functions.create_kms_key(
        session=session,
        alias=key_alias,
        description=f"CMK for AES encryption of UC catalog {catalog} in metastore {metastore_id}",
        tags=[],
    )

## Step 3
* Create a `crypto.keyvault` table to store our encrypted DEKs
* Generate an encrypted DEK for our schema
* Store the encrypted DEK in the `crypto.keyvault` table

In [0]:
%sql
CREATE SCHEMA IF NOT EXISTS IDENTIFIER(concat(:catalog, '.crypto'));
CREATE TABLE IF NOT EXISTS IDENTIFIER(:catalog || '.crypto.keyvault') (
  id BIGINT GENERATED BY DEFAULT AS IDENTITY,
  created_date DATE, 
  created_time TIMESTAMP,
  last_modified_time TIMESTAMP,
  created_by STRING,
  managed_by STRING,
  key_alias STRING,
  key_enabled BOOLEAN,
  key_version INT,
  key_type STRING,
  key BINARY);

In [0]:
from databricks.sdk import WorkspaceClient
from pyspark.sql.types import StructType, StructField, IntegerType, DateType, TimestampType, StringType, BooleanType, BinaryType
from datetime import datetime
from datetime import date

ws = WorkspaceClient()

# Request a new data encryption key (DEK) under the CMK and keep only its encrypted
# form ("CiphertextBlob"). NOTE(review): generate_data_key() is defined in the helper
# notebook; presumably it wraps the KMS GenerateDataKey API - confirm.
dek = crypto_functions.generate_data_key(
    session=session,
    key_alias=key_alias,
    encryption_context={"metastore/catalog": f"{metastore_id}/{catalog}"},
).get("CiphertextBlob")

# DEKs are stored per schema, keyed by "<catalog>.<schema>".
alias = f"{catalog}.{schema}"

# Bind the alias as a named parameter rather than formatting it into the SQL text,
# so widget-supplied catalog/schema values cannot alter the query.
latest_version = spark.sql(
    "SELECT MAX(key_version) AS max_version FROM crypto.keyvault WHERE key_alias = :alias",
    args={"alias": alias},
).first().max_version

# First key for this alias gets version 1; otherwise increment the latest version.
key_version = 1 if not latest_version else latest_version + 1

keyvault_schema = StructType([
    StructField("created_date", DateType(), False),
    StructField("created_time", TimestampType(), False),
    StructField("last_modified_time", TimestampType(), False),
    StructField("created_by", StringType(), False),
    StructField("managed_by", StringType(), False),
    StructField("key_alias", StringType(), False),
    StructField("key_enabled", BooleanType(), False),
    StructField("key_version", IntegerType(), True),
    StructField("key_type", StringType(), False),
    StructField("key", BinaryType(), False)
])

# Take one timestamp and one user lookup so that the created/modified columns of the
# row agree with each other and the workspace API is called only once.
now = datetime.now()
current_user_name = ws.current_user.me().user_name

keyvault_data = [{
    "created_date": now.date(),
    "created_time": now,
    "last_modified_time": now,
    "created_by": current_user_name,
    "managed_by": current_user_name,
    "key_enabled": True,
    "key_version": key_version,
    "key_type": "ENCRYPTED_DEK",
    "key_alias": alias,
    "key": dek,
}]

# Expose the new row as a temp view so that the next %sql cell can insert it.
spark.createDataFrame(keyvault_data, keyvault_schema).createOrReplaceTempView("tmp_keyvault")

In [0]:
%sql
-- Append the staged DEK row; the id column is left out so that the table's
-- identity column generates it.
INSERT INTO crypto.keyvault (
  created_date,
  created_time,
  last_modified_time,
  created_by,
  managed_by,
  key_alias,
  key_enabled,
  key_version,
  key_type,
  key
)
SELECT
  created_date,
  created_time,
  last_modified_time,
  created_by,
  managed_by,
  key_alias,
  key_enabled,
  key_version,
  key_type,
  key
FROM tmp_keyvault

In [0]:
%sql
-- Show the enabled keys stored for '<catalog>.<schema>', newest version first.
SELECT *
FROM crypto.keyvault
WHERE key_enabled = true
  AND key_alias = :catalog || '.' || :schema
ORDER BY key_version DESC

## Step 4
* Create an `unwrap_kms_key()` function that can be used to return a decrypted DEK 
* Test our `unwrap_kms_key()` function by encrypting/decrypting some data

In [0]:
%sql
USE SCHEMA crypto;
-- IMPORTANT!!! 
---> BEFORE RUNNING THIS STEP PLEASE UPDATE THE CREDENTIALS() SECTION TO REFERENCE YOUR uc_service_credential
-- unwrap_kms_key(context, encrypted_dek): decrypts an encrypted DEK with the CMK
-- 'alias/unity_catalog/<context>/cmk' and returns the plaintext key bytes.
CREATE OR REPLACE FUNCTION crypto.unwrap_kms_key(context STRING, encrypted_dek BINARY) 
RETURNS BINARY
LANGUAGE PYTHON
PARAMETER STYLE PANDAS
HANDLER 'batchhandler'
CREDENTIALS (
  `production-aws-kms` DEFAULT -- IMPORTANT! REPLACE THIS WITH YOUR UC SERVICE CREDENTIAL!!!
)
AS $$
import boto3
from pyspark.taskcontext import TaskContext
from botocore.exceptions import ClientError
from base64 import b64encode
from base64 import b64decode
from typing import Iterator, Tuple
import pandas as pd

def setup_session():
  """Create a KMS client in the region reported by the cluster usage tags."""
  session = boto3.Session()
  region = TaskContext.get().getLocalProperty("spark.databricks.clusterUsageTags.region")
  client = session.client("kms", region_name=region)
  return client

# Built once at module level and reused by every batch the handler processes.
client = setup_session()

def _unwrap(context, encrypted_dek):
    """Decrypt one encrypted DEK with the CMK derived from `context`.

    Raises botocore's ClientError (with a troubleshooting note attached) when KMS
    rejects the request.
    """
    alias = f"alias/unity_catalog/{context}/cmk"
    try:
        response = client.decrypt(CiphertextBlob=encrypted_dek, KeyId=alias, EncryptionContext={'metastore/catalog': context})
    except ClientError as e:
        # Raw string: the ASCII art contains backslashes that are not valid escape sequences.
        e.add_note(r"""
              ___ _ __ _ __ ___  _ __ 
             / _ \ '__| '__/ _ \| '__|
            |  __/ |  | | | (_) | |   
             \___|_|  |_|  \___/|_|
            
            Failed to unwrap key! Please check:

            1. The user is a member of the <catalog>.<schema>.crypto.user account group
            2. The UC service credential has the right permissions to use the CMK (on both the IAM and key policy)
            3. That the network you're connecting from is allowed in the IAM or key policy
            """)
        raise
    return response.get("Plaintext")

def batchhandler(batch_iter: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]:
    """Unwrap every (context, encrypted_dek) row of each incoming batch.

    One result is yielded per input row, so the output Series always has the same
    length as the input batch. Identical (context, encrypted_dek) pairs within a
    batch are decrypted once and the plaintext is reused for the repeated rows.
    """
    for contexts, encrypted_deks in batch_iter:
        unwrapped_by_input = {}
        results = []
        for context, encrypted_dek in zip(contexts, encrypted_deks):
            # bytes() gives a hashable cache key whatever binary type pandas hands over.
            cache_key = (context, None if encrypted_dek is None else bytes(encrypted_dek))
            if cache_key not in unwrapped_by_input:
                unwrapped_by_input[cache_key] = _unwrap(context, encrypted_dek)
            results.append(unwrapped_by_input[cache_key])
        yield pd.Series(results, dtype=object)
$$;

In [0]:
%sql
-- Round-trip smoke test for crypto.unwrap_kms_key():
--   1. look up the newest enabled encrypted DEK for '<catalog>.<schema>' in crypto.keyvault
--   2. unwrap it and AES-GCM encrypt a random 30-character string with it
--   3. pass the ciphertext through base64/unbase64, then decrypt it with the same unwrapped DEK
-- try_aes_decrypt yields NULL instead of raising when decryption fails, so a non-NULL
-- decrypted_random_string means that the unwrap/encrypt/decrypt path works end to end.
SELECT CAST(try_aes_decrypt(
  unbase64(base64(
  aes_encrypt(
    randstr(30),
    (SELECT crypto.unwrap_kms_key(
      context=>CONCAT(:metastore_id, '/', :catalog),
      encrypted_dek=>(SELECT key FROM crypto.keyvault 
      WHERE key_alias = CONCAT(:catalog, '.', :schema) AND key_enabled = true ORDER BY key_version DESC LIMIT 1))),
      'GCM',  
      'DEFAULT'))),
  (SELECT crypto.unwrap_kms_key(
    context=>CONCAT(:metastore_id, '/', :catalog),
    encrypted_dek=>(SELECT key FROM crypto.keyvault 
    WHERE key_alias = CONCAT(:catalog, '.', :schema) AND key_enabled = true ORDER BY key_version DESC LIMIT 1))),
    'GCM',
    'DEFAULT') AS STRING) AS decrypted_random_string

## Step 5
* Create `encrypt()` and `decrypt()` functions that can be used to encrypt/decrypt data within our catalog. 
* These functions will call our more privileged `unwrap_kms_key()` function in order to unwrap DEKs and encrypt or decrypt the data

> ### IMPORTANT: 
Before running this section please create the account level group `<catalog>.<schema>.crypto.user` and add your user as a member

In [0]:
# Stop early unless the current user belongs to the account-level group that
# crypto.decrypt_kms() checks before it decrypts anything.
crypto_user_group = f"{catalog}.{schema}.crypto.user"

# Bind the group name as a named parameter rather than formatting it into the SQL
# text, so widget-supplied catalog/schema values cannot alter the query.
is_account_group_member = spark.sql(
    "SELECT is_account_group_member(:group_name) AS is_account_group_member",
    args={"group_name": crypto_user_group},
).first().is_account_group_member

if not is_account_group_member:
    raise Exception(f"Please add your user to the '{catalog}.{schema}.crypto.user' group before proceeding")

In [0]:
%sql
-- Returns the highest key_version among the enabled crypto.keyvault rows whose
-- key_alias is '<catalog>.<schema>'.
CREATE OR REPLACE FUNCTION crypto.get_latest_dek_version(catalog STRING, schema STRING)
RETURNS INT
RETURN (
  SELECT MAX(key_version)
  FROM crypto.keyvault
  WHERE key_enabled = true
    AND key_alias = CONCAT(catalog, '.', schema)
)

In [0]:
%sql
-- Returns the encrypted DEK stored in crypto.keyvault for '<catalog>.<schema>'
-- at the requested key_version, considering enabled rows only.
CREATE OR REPLACE FUNCTION crypto.get_encrypted_dek(catalog STRING, schema STRING, version INT)
RETURNS BINARY
RETURN (
  SELECT FIRST(key)
  FROM crypto.keyvault
  WHERE key_version = version
    AND key_enabled = true
    AND key_alias = CONCAT(catalog, '.', schema)
)

In [0]:
# DDL for crypto.encrypt_kms(col, key): unwraps the encrypted DEK passed as `key` and
# returns the base64-encoded AES-GCM ciphertext of `col`. The "<metastore_id>/<catalog>"
# context is taken from the notebook variables and baked into the function body when
# this cell runs.
encrypt_kms_ddl = f"""
CREATE OR REPLACE FUNCTION crypto.encrypt_kms(col STRING, key BINARY)
RETURNS STRING
RETURN base64(aes_encrypt(
            col, 
            (SELECT * FROM (SELECT crypto.unwrap_kms_key(
                context=>"{metastore_id}/{catalog}",
                encrypted_dek=>key))),
            'GCM',  
            'DEFAULT'
        ))"""

sql(encrypt_kms_ddl)

In [0]:
# DDL for crypto.decrypt_kms(col, catalog, schema, key): members of the account group
# '<catalog>.<schema>.crypto.user' get `col` decrypted with the unwrapped DEK; everyone
# else, and any value that fails to decrypt, gets `col` back unchanged. The
# "<metastore_id>/<catalog>" context is baked in from the notebook variables when this
# cell runs.
decrypt_kms_ddl = f"""
    CREATE OR REPLACE FUNCTION crypto.decrypt_kms(col STRING, catalog STRING, schema STRING, key BINARY)
    RETURNS STRING
    RETURN CASE WHEN is_account_group_member(CONCAT(catalog, '.', schema, '.crypto.user')) THEN
    nvl(CAST(try_aes_decrypt(unbase64(col), 
        (SELECT * FROM (SELECT crypto.unwrap_kms_key(
                context=>"{metastore_id}/{catalog}",
                encrypted_dek=>key))),
        'GCM',  
        'DEFAULT') AS STRING), 
        col)
        ELSE col END
    """

sql(decrypt_kms_ddl)

In [0]:
%sql
-- encrypt(col, catalog, schema): encrypts col with the newest enabled DEK stored for
-- '<catalog>.<schema>'.
-- Every helper is referenced as crypto.<name>, so creating these functions does not
-- depend on an earlier cell having run USE SCHEMA crypto.
CREATE OR REPLACE FUNCTION crypto.encrypt(col STRING, catalog STRING, schema STRING)
RETURNS STRING
RETURN crypto.encrypt_kms(
  col,
  crypto.get_encrypted_dek(catalog, schema, crypto.get_latest_dek_version(catalog, schema))
);

-- decrypt(col, catalog, schema): decrypts col with the newest enabled DEK stored for
-- '<catalog>.<schema>' (crypto.decrypt_kms performs the group-membership check).
CREATE OR REPLACE FUNCTION crypto.decrypt(col STRING, catalog STRING, schema STRING)
RETURNS STRING
RETURN crypto.decrypt_kms(
  col,
  catalog,
  schema,
  crypto.get_encrypted_dek(catalog, schema, crypto.get_latest_dek_version(catalog, schema))
)

## Step 6
* Create a table from the raw data we downloaded above, encrypting the columns that contain sensitive data

In [0]:
%sql
-- Build <catalog>.<schema>.titanic from the raw CSV, storing the sensitive columns
-- (Name, Age, Sex) encrypted and every other column as read.
CREATE OR REPLACE TABLE IDENTIFIER(:catalog || '.' || :schema || '.titanic') AS (
  SELECT
    PassengerId,
    crypto.encrypt(Name, :catalog, :schema) AS Name,
    crypto.encrypt(Age, :catalog, :schema) AS Age,
    crypto.encrypt(Sex, :catalog, :schema) AS Sex,
    * EXCEPT(PassengerId, Name, Age, Sex)
  FROM read_files(
    concat('/Volumes/', :catalog, '/', :schema, '/raw_files/titanic.csv'),
    format => 'csv',
    header => true,
    mode => 'FAILFAST'
  )
);

-- Display the table as stored (the three columns hold ciphertext).
SELECT *
FROM IDENTIFIER(:catalog || '.' || :schema || '.titanic')

## Step 7 
* Check that the decrypt functions work as expected...

In [0]:
%sql
-- Read the table back, decrypting the three encrypted columns on the fly.
SELECT
  PassengerId,
  crypto.decrypt(Name, :catalog, :schema) AS Name,
  crypto.decrypt(Age, :catalog, :schema) AS Age,
  crypto.decrypt(Sex, :catalog, :schema) AS Sex,
  * EXCEPT(PassengerId, Name, Age, Sex)
FROM IDENTIFIER(:catalog || '.' || :schema || '.titanic')

## Step 8
* You can also add a column mask to the encrypted table
* A column mask serves the following purposes:
  * The calling users don't even need permissions to the `encrypt_kms()` and `decrypt_kms()` functions or the `crypto` schema
  * The whole process of encryption/decryption is abstracted away from them

> ### NOTE: 
Once a column mask is set, calling the `decrypt_kms()` function directly is likely to fail, because the column mask already decrypts the data automatically and you would be trying to decrypt results that are already decrypted!

> ### IMPORTANT: 
Please update the `USING COLUMNS('production', 'human_resources')` section below with your own `catalog` and `schema` names

In [0]:
%sql
-- IMPORTANT!!!
---> BEFORE RUNNING THIS STEP PLEASE UPDATE EACH USING COLUMNS() SECTION TO REFERENCE YOUR UC CATALOG AND SCHEMA NAMES
-- Attach crypto.decrypt as the column mask of each encrypted column; the two literals
-- are passed to the mask function as its catalog and schema arguments.
ALTER TABLE IDENTIFIER(:catalog || '.' || :schema || '.titanic')
  ALTER COLUMN Name SET MASK crypto.decrypt USING COLUMNS ('production', 'human_resources');

ALTER TABLE IDENTIFIER(:catalog || '.' || :schema || '.titanic')
  ALTER COLUMN Age SET MASK crypto.decrypt USING COLUMNS ('production', 'human_resources');

ALTER TABLE IDENTIFIER(:catalog || '.' || :schema || '.titanic')
  ALTER COLUMN Sex SET MASK crypto.decrypt USING COLUMNS ('production', 'human_resources');

In [0]:
%sql
-- Query the table directly; the column masks set above are applied to Name, Age and Sex.
SELECT * FROM IDENTIFIER(:catalog || '.' || :schema || '.titanic');