In [0]:

# Notebook parameters. The metastore id default is the last ':'-separated
# segment of current_metastore(); the other defaults are fixed strings.
metastore_row = sql("SELECT element_at(split(current_metastore(), ':'), -1) AS metastore").first()

widget_defaults = {
    "metastore_id": metastore_row.metastore,
    "catalog": "production",
    "new_schema": "finance",
    "region": "eu-west-1",
    "uc_service_credential": "production-aws-kms",
}

# Create one text widget per parameter, in the order listed above.
for widget_name, widget_default in widget_defaults.items():
    dbutils.widgets.text(widget_name, defaultValue=widget_default)

# Read the current widget values back into plain Python variables used by later cells.
metastore_id = dbutils.widgets.get("metastore_id")
catalog = dbutils.widgets.get("catalog")
region = dbutils.widgets.get("region")
new_schema = dbutils.widgets.get("new_schema")
service_credential = dbutils.widgets.get("uc_service_credential")

## Step 1
* Download the Titanic dataset and store it in a Unity Catalog (UC) volume for raw files.
* We'll use this to simulate a dataset that contains PII (Name, Age, Sex).

In [0]:
%sql
-- Switch to the target catalog, then make sure the schema and its raw_files
-- volume exist. Object names are assembled from the notebook widgets.
USE CATALOG IDENTIFIER(:catalog);

CREATE SCHEMA IF NOT EXISTS IDENTIFIER(:catalog || '.' || :new_schema);

CREATE VOLUME IF NOT EXISTS IDENTIFIER(:catalog || '.' || :new_schema || '.raw_files');

In [0]:
import subprocess

# Source CSV and its destination inside the schema's raw_files volume.
file_url = "https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv"
raw_files_dir = f"/Volumes/{catalog}/{new_schema}/raw_files/"
volume_path = f"/Volumes/{catalog}/{new_schema}/raw_files/titanic.csv"

# check=True makes a non-zero wget exit status raise CalledProcessError
# instead of silently continuing.
download_cmd = ["wget", file_url, "-O", volume_path]
subprocess.run(download_cmd, check=True)

# List the volume directory to confirm the file landed.
display(dbutils.fs.ls(raw_files_dir))

## Step 2
* Generate a new encrypted data encryption key (DEK) for our schema
* Store the encrypted DEK in the `crypto.keyvault` table

In [0]:
import boto3

# Shared crypto helpers live in a separate notebook; import it as a module.
crypto_functions = dbutils.import_notebook("notebooks.envelope_encryption_v2.common.aws_crypto_functions")

# Build a boto3 session backed by the Unity Catalog service credential
# selected in the widgets, pinned to the chosen AWS region.
credentials_provider = dbutils.credentials.getServiceCredentialsProvider(service_credential)
session = boto3.Session(botocore_session=credentials_provider, region_name=region)

# KMS alias used for this catalog's key (passed to generate_data_key later).
key_alias = f"alias/unity_catalog/{metastore_id}/{catalog}/cmk"

In [0]:
from databricks.sdk import WorkspaceClient
from pyspark.sql.types import StructType, StructField, IntegerType, DateType, TimestampType, StringType, BooleanType, BinaryType
from datetime import datetime
from datetime import date

ws = WorkspaceClient()

# Request a new data key for this catalog and keep only the "CiphertextBlob"
# entry of the response (presumably the KMS-encrypted DEK -- the helper lives
# in the imported aws_crypto_functions notebook; confirm there).
dek = crypto_functions.generate_data_key(
    session=session,
    key_alias=key_alias,
    encryption_context={"metastore/catalog": f"{metastore_id}/{catalog}"},
).get("CiphertextBlob")

# The keyvault "key" column is non-nullable, so fail here with a clear message
# rather than with a schema-verification error further down.
if dek is None:
    raise ValueError("generate_data_key() response did not contain a 'CiphertextBlob'")

# Keys are stored per <catalog>.<schema>.
alias = f"{catalog}.{new_schema}"

# Look up the highest existing version for this alias. The alias is passed as a
# named parameter rather than interpolated into the SQL text, because it is
# derived from user-editable widget values.
latest_version = spark.sql(
    "SELECT MAX(key_version) AS max_version FROM crypto.keyvault WHERE key_alias = :alias",
    args={"alias": alias},
).first().max_version

# MAX() yields NULL when no key exists yet for this alias -> start at version 1.
key_version = 1 if latest_version is None else latest_version + 1

keyvault_schema = StructType([
    StructField("created_date", DateType(), False),
    StructField("created_time", TimestampType(), False),
    StructField("last_modified_time", TimestampType(), False),
    StructField("created_by", StringType(), False),
    StructField("managed_by", StringType(), False),
    StructField("key_alias", StringType(), False),
    StructField("key_enabled", BooleanType(), False),
    StructField("key_version", IntegerType(), True),
    StructField("key_type", StringType(), False),
    StructField("key", BinaryType(), False)
])

# Take a single timestamp and a single user lookup so the created/modified
# columns agree with each other and only one API call is made.
now = datetime.now()
current_user_name = ws.current_user.me().user_name

keyvault_data = [{
    "created_date": now.date(),
    "created_time": now,
    "last_modified_time": now,
    "created_by": current_user_name,
    "managed_by": current_user_name,
    "key_enabled": True,
    "key_version": key_version,
    "key_type": "ENCRYPTED_DEK",
    "key_alias": alias,
    "key": dek,
}]

# Expose the single new row as a temp view so the following %sql cell can insert it.
spark.createDataFrame(keyvault_data, keyvault_schema).createOrReplaceTempView("tmp_keyvault")

In [0]:
%sql
-- Append the staged key row from the tmp_keyvault temp view to the keyvault table.
INSERT INTO crypto.keyvault (
  created_date,
  created_time,
  last_modified_time,
  created_by,
  managed_by,
  key_alias,
  key_enabled,
  key_version,
  key_type,
  key
)
SELECT
  created_date,
  created_time,
  last_modified_time,
  created_by,
  managed_by,
  key_alias,
  key_enabled,
  key_version,
  key_type,
  key
FROM tmp_keyvault

In [0]:
%sql
-- Show the enabled keys for this <catalog>.<schema> alias, newest version first.
SELECT *
FROM crypto.keyvault
WHERE key_alias = :catalog || '.' || :new_schema
  AND key_enabled = true
ORDER BY key_version DESC

## Step 3
* Create a table from the raw data we downloaded above, encrypting the columns that contain sensitive data
> ### IMPORTANT:
> Before running this section, please create the account-level group `<catalog>.<schema>.crypto.user` and add your user as a member.

In [0]:
%sql
-- Build <catalog>.<schema>.titanic from the raw CSV in the volume.
-- Name, Age and Sex are passed through crypto.encrypt (defined outside this
-- notebook); all remaining columns are copied unchanged.
-- FAILFAST makes read_files abort on malformed records instead of skipping them.
CREATE OR REPLACE TABLE IDENTIFIER(:catalog || '.' || :new_schema || '.titanic') AS
SELECT
  PassengerId,
  crypto.encrypt(Name, :catalog, :new_schema) AS Name,
  crypto.encrypt(Age, :catalog, :new_schema) AS Age,
  crypto.encrypt(Sex, :catalog, :new_schema) AS Sex,
  * EXCEPT (PassengerId, Name, Age, Sex)
FROM read_files(
  '/Volumes/' || :catalog || '/' || :new_schema || '/raw_files/titanic.csv',
  format => 'csv',
  header => true,
  mode => 'FAILFAST'
);

-- Preview the stored table.
SELECT * FROM IDENTIFIER(:catalog || '.' || :new_schema || '.titanic')

## Step 4
* Check that the decrypt functions work as expected...

In [0]:
%sql
-- Read the table back, passing the three encrypted columns through
-- crypto.decrypt (defined outside this notebook) to check the round trip.
SELECT
  PassengerId,
  crypto.decrypt(Name, :catalog, :new_schema) AS Name,
  crypto.decrypt(Age, :catalog, :new_schema) AS Age,
  crypto.decrypt(Sex, :catalog, :new_schema) AS Sex,
  * EXCEPT (PassengerId, Name, Age, Sex)
FROM IDENTIFIER(:catalog || '.' || :new_schema || '.titanic')