In [None]:
# Import necessary libraries
import requests
from pyspark.sql import SparkSession
from pyspark.dbutils import DBUtils

# Initialize Spark session and Databricks utilities
spark = SparkSession.builder.appName("SkyflowTokenization").getOrCreate()
dbutils = DBUtils(spark)

# Define widgets to receive input parameters
dbutils.widgets.text("table_name", "")
dbutils.widgets.text("pii_columns", "")

# Read widget values.
# - Whitespace is stripped so "email, phone" yields ["email", "phone"].
# - Empty entries are dropped: "".split(",") is [""], which is truthy and would
#   otherwise slip past the validation below.
# - Duplicates are removed (first occurrence order preserved): processing a
#   column twice would re-tokenize the values it had just tokenized.
table_name = dbutils.widgets.get("table_name").strip()
pii_columns = list(dict.fromkeys(
    name.strip()
    for name in dbutils.widgets.get("pii_columns").split(",")
    if name.strip()
))

if not table_name or not pii_columns:
    raise ValueError("Both 'table_name' and 'pii_columns' must be provided.")

# Skyflow API connection settings. Every "<TODO: ...>" placeholder must be
# replaced with a real value before this notebook can reach the vault.
# NOTE(review): a bearer token hard-coded in a notebook is readable by anyone
# with access to it; consider loading it from a secret store instead.
SKYFLOW_API_URL = (
    "<TODO: SKYFLOW_VAULT_URL>"
    "/v1/vaults/<TODO: SKYFLOW_VAULT_ID>"
    "/pii"
)
SKYFLOW_ACCOUNT_ID, SKYFLOW_BEARER_TOKEN = (
    "<TODO: SKYFLOW_ACCOUNT_ID>",
    "<TODO: SKYFLOW_BEARER_TOKEN>",
)

def tokenize_batch(values, timeout=30):
    """
    Tokenize a batch of PII values via the Skyflow API.

    Values are expected to be strings (they come from the table schema).
    ``None`` entries are not sent to Skyflow.

    Args:
        values: Sequence of PII values to tokenize, in order.
        timeout: Seconds to wait for the Skyflow HTTP request before giving
            up (prevents the notebook from hanging on a stalled connection).

    Returns:
        A list with the same length as ``values``: element ``i`` is the token
        for ``values[i]``, or ``None`` where ``values[i]`` is ``None``. If the
        request fails, the response is malformed, or the number of returned
        tokens does not match the number of values sent, every non-``None``
        position holds the string ``"ERROR"`` instead of a token.
    """
    present = [value for value in values if value is not None]
    if not present:
        # Nothing to tokenize; skip the API round-trip entirely.
        return [None] * len(values)

    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "X-SKYFLOW-ACCOUNT-ID": SKYFLOW_ACCOUNT_ID,
        "Authorization": f"Bearer {SKYFLOW_BEARER_TOKEN}"
    }

    # Format records exactly like the successful example
    payload = {
        "records": [{"fields": {"pii": value}} for value in present],
        "tokenization": True
    }

    tokens = None
    try:
        response = requests.post(
            SKYFLOW_API_URL, headers=headers, json=payload, timeout=timeout
        )
        response.raise_for_status()
        tokens = [record["tokens"]["pii"] for record in response.json()["records"]]
    except requests.exceptions.RequestException as e:
        print(f"Error tokenizing batch: {e}")
        # e.response is None for connection errors and timeouts.
        if getattr(e, "response", None) is not None:
            print(f"Response content: {e.response.text}")
    except (KeyError, TypeError, ValueError) as e:
        # Response body was not the expected {"records": [{"tokens": {"pii": ...}}]} shape.
        print(f"Error tokenizing batch: unexpected response format ({e!r})")

    if tokens is not None and len(tokens) != len(present):
        # A count mismatch would pair tokens with the wrong values downstream.
        print(f"Error tokenizing batch: expected {len(present)} tokens, got {len(tokens)}")
        tokens = None

    if tokens is None:
        return [None if value is None else "ERROR" for value in values]

    # Re-insert None placeholders so the output stays index-aligned with `values`.
    token_iter = iter(tokens)
    return [None if value is None else next(token_iter) for value in values]

def _quote_identifier(name):
    """Backtick-quote a single SQL identifier part, escaping embedded backticks."""
    return "`" + name.replace("`", "``") + "`"


def _quote_table_name(name):
    """Quote a possibly multi-part table name such as ``catalog.schema.table``.

    Each dot-separated part is quoted on its own; wrapping the whole string in
    one pair of backticks would make Spark treat it as a single identifier.
    """
    return ".".join(_quote_identifier(part) for part in name.split("."))


_TOKEN_MAP_VIEW = "_skyflow_token_map"
quoted_table = _quote_table_name(table_name)

for column in pii_columns:
    column = column.strip()
    if not column:
        continue
    quoted_column = _quote_identifier(column)

    # Read distinct non-null PII values
    query = (
        f"SELECT DISTINCT {quoted_column} FROM {quoted_table} "
        f"WHERE {quoted_column} IS NOT NULL"
    )
    values = [row[0] for row in spark.sql(query).collect()]

    if not values:
        print(f"No PII values found for column: {column}")
        continue

    # Tokenize in batches, pairing each value with its own token. Values whose
    # tokenization failed ("ERROR" sentinel / None) are left untouched rather
    # than overwriting real data with a sentinel.
    batch_size = 25
    mapping = []
    failed = 0
    for start in range(0, len(values), batch_size):
        batch = values[start:start + batch_size]
        tokens = tokenize_batch(batch)
        if len(tokens) != len(batch):
            # zip() would silently misalign values and tokens; refuse to write.
            raise RuntimeError(
                f"Tokenizer returned {len(tokens)} tokens for {len(batch)} values "
                f"in column {column}; aborting to avoid corrupting data."
            )
        for value, token in zip(batch, tokens):
            if token is None or token == "ERROR":
                failed += 1
            else:
                # str(value) matches how the value was previously rendered
                # into the SQL literal, so comparison semantics are unchanged.
                mapping.append((str(value), token))

    if mapping:
        # Apply every replacement in one MERGE driven by a value->token table
        # instead of one string-built UPDATE per value. This keeps PII out of
        # SQL text (values such as O'Brien no longer break or inject SQL),
        # rewrites the table once per column instead of once per value, and
        # cannot re-match rows that an earlier replacement already tokenized.
        spark.createDataFrame(mapping, "value string, token string") \
            .createOrReplaceTempView(_TOKEN_MAP_VIEW)
        try:
            spark.sql(
                f"MERGE INTO {quoted_table} AS t "
                f"USING {_TOKEN_MAP_VIEW} AS s "
                f"ON t.{quoted_column} = s.value "
                f"WHEN MATCHED THEN UPDATE SET t.{quoted_column} = s.token"
            )
        finally:
            spark.catalog.dropTempView(_TOKEN_MAP_VIEW)

    if failed:
        print(
            f"WARNING: {failed} value(s) in column {column} could not be "
            f"tokenized and were left unchanged."
        )
    else:
        print(f"Successfully tokenized column: {column}")

# Report completion (table and processed columns) back to the notebook caller.
columns_summary = ", ".join(pii_columns)
dbutils.notebook.exit(
    f"Tokenization completed for table `{table_name}` with columns {columns_summary}."
)