In [None]:
# Unity Catalog-aware serverless tokenization notebook  
# Uses dbutils.secrets.get() + UC HTTP connections for serverless compatibility
import json
import os
from pyspark.sql import SparkSession
from pyspark.dbutils import DBUtils

# Build (or reuse) the Spark session for this notebook. The builder sets the
# application name plus four configuration keys: the cluster profile
# ("serverless"), Delta auto-compaction, adaptive query execution, and
# adaptive partition coalescing.
spark = (
    SparkSession.builder
    .appName("SkyflowTokenization")
    .config("spark.databricks.cluster.profile", "serverless")
    .config("spark.databricks.delta.autoCompact.enabled", "true")
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
    .getOrCreate()
)

# dbutils handle bound to the session above; used below for widgets and secrets.
dbutils = DBUtils(spark)

print("✓ Running on Databricks serverless compute")
print(f"✓ Spark version: {spark.version}")

# Performance configuration
MAX_MERGE_BATCH_SIZE = 10000  # Maximum records per MERGE operation
COLLECT_BATCH_SIZE = 1000     # Maximum records to collect() from Spark at once

# Define widgets to receive input parameters (all default to an empty string)
dbutils.widgets.text("table_name", "")
dbutils.widgets.text("pii_columns", "")
dbutils.widgets.text("batch_size", "")  # Skyflow API batch size

# Read widget values
table_name = dbutils.widgets.get("table_name").strip()

# "".split(",") returns [""], which is truthy, so blank entries are dropped
# here; otherwise an empty widget would slip past the validation below.
# Surrounding whitespace ("email, phone") is stripped for the same reason.
pii_columns = [
    name.strip()
    for name in dbutils.widgets.get("pii_columns").split(",")
    if name.strip()
]

if not table_name or not pii_columns:
    raise ValueError("Both 'table_name' and 'pii_columns' must be provided.")

# The batch size is used as a range() step later on, so it must be a
# positive integer. Fail here with a message that names the widget.
_raw_batch_size = dbutils.widgets.get("batch_size").strip()
try:
    SKYFLOW_BATCH_SIZE = int(_raw_batch_size)
except ValueError:
    raise ValueError(
        f"'batch_size' must be a positive integer, got {_raw_batch_size!r}"
    ) from None
if SKYFLOW_BATCH_SIZE <= 0:
    raise ValueError(
        f"'batch_size' must be a positive integer, got {SKYFLOW_BATCH_SIZE}"
    )

print(f"Tokenizing table: {table_name}")
print(f"PII columns: {', '.join(pii_columns)}")
print(f"Skyflow API batch size: {SKYFLOW_BATCH_SIZE}")
print(f"MERGE batch size limit: {MAX_MERGE_BATCH_SIZE:,} records")
print(f"Collect batch size: {COLLECT_BATCH_SIZE:,} records")

def _quote_identifier(name):
    """Wrap one identifier part in backticks, doubling any embedded backtick."""
    return "`" + name.replace("`", "``") + "`"


# Extract catalog and schema from the table name if it is qualified.
# The rest of the notebook wraps table_name in a single pair of backticks, so
# a dotted name must be reduced to the bare table name here; otherwise
# `schema.table` would be looked up as one identifier.
if '.' in table_name:
    parts = table_name.split('.')
    if not all(parts):
        raise ValueError(f"Invalid table name (empty component): {table_name!r}")

    if len(parts) == 3:  # catalog.schema.table
        catalog_name, schema_name, table_name_only = parts

        # Set the catalog and schema context for this session. Identifiers are
        # backtick-quoted so names with special characters are accepted.
        print(f"Setting catalog context to: {catalog_name}")
        spark.sql(f"USE CATALOG {_quote_identifier(catalog_name)}")
        spark.sql(f"USE SCHEMA {_quote_identifier(schema_name)}")

        # Use the simple table name for queries since context is set
        table_name = table_name_only
        print(f"✓ Catalog context set, using table name: {table_name}")
    elif len(parts) == 2:  # schema.table (catalog taken from the session)
        schema_name, table_name_only = parts

        print(f"Setting schema context to: {schema_name}")
        spark.sql(f"USE SCHEMA {_quote_identifier(schema_name)}")

        table_name = table_name_only
        print(f"✓ Schema context set, using table name: {table_name}")

print("✓ Using dbutils.secrets.get() + UC HTTP connections for serverless compatibility")

def tokenize_column_values(column_name, values):
    """
    Tokenize a list of PII values through the Unity Catalog HTTP connection.

    Secrets are read with dbutils.secrets.get() and the request is issued with
    the SQL http_request() function over the 'skyflow_conn' connection.

    Args:
        column_name: Label used only in progress and error messages.
        values: List of raw values. None entries are not sent to the API.

    Returns:
        A list with the same length as `values`. Position i holds the token
        returned for values[i], or None where values[i] is None.

    Raises:
        RuntimeError: If the response contains an "error" key, lacks a
            "records" key, returns a different number of records than were
            sent, or a record has no token for the configured Skyflow column.
        json.JSONDecodeError: If the response body is not valid JSON.
    """
    if not values:
        return []

    # Remember where the non-null inputs sit so returned tokens can be placed
    # back at the matching positions (nulls are skipped in the request).
    present_positions = [i for i, value in enumerate(values) if value is not None]
    tokens = [None] * len(values)
    if not present_positions:
        return tokens

    # Get secrets using dbutils (works in serverless)
    table_column = dbutils.secrets.get("skyflow-secrets", "skyflow_table_column")
    skyflow_table = dbutils.secrets.get("skyflow-secrets", "skyflow_table")

    # One record per non-null value, in input order
    skyflow_records = [
        {"fields": {table_column: str(values[i])}} for i in present_positions
    ]

    # Create Skyflow tokenization payload
    payload = {
        "records": skyflow_records,
        "tokenization": True
    }

    print(f"  Tokenizing {len(skyflow_records)} values for {column_name}")

    # Connection base_path is /v1/vaults/{vault_id}, so we add /{table_name}.
    # The path and JSON body are bound as named parameters (PySpark 3.4+)
    # rather than spliced into the SQL text: quotes or backslash escapes in
    # the serialized payload would otherwise be reinterpreted by the SQL
    # string-literal parser and could corrupt the request.
    result_df = spark.sql(
        """
        SELECT http_request(
            conn => 'skyflow_conn',
            method => 'POST',
            path => :tokenize_path,
            headers => map(
                'Content-Type', 'application/json',
                'Accept', 'application/json'
            ),
            json => :json_payload
        ) as full_response
        """,
        args={
            "tokenize_path": f"/{skyflow_table}",
            "json_payload": json.dumps(payload),
        },
    )

    # Parse response
    full_response = result_df.collect()[0]['full_response']
    result = json.loads(full_response.text)

    # Fail fast if API response indicates error
    if "error" in result:
        raise RuntimeError(f"Skyflow API error: {result['error']}")

    if "records" not in result:
        raise RuntimeError(f"Invalid Skyflow API response - missing 'records': {result}")

    records = result["records"]
    if len(records) != len(present_positions):
        raise RuntimeError(
            f"Skyflow returned {len(records)} records for "
            f"{len(present_positions)} submitted values in {column_name}"
        )

    # Place each token at the position of the value it was generated from
    for position, record in zip(present_positions, records):
        if "tokens" in record and table_column in record["tokens"]:
            tokens[position] = record["tokens"][table_column]
        else:
            raise RuntimeError(f"Tokenization failed for value {position+1} in {column_name}. Record: {record}")

    # A value counts as tokenized when a non-empty token differs from the input
    successful_tokens = sum(
        1
        for position in present_positions
        if tokens[position] and str(tokens[position]) != str(values[position])
    )
    print(f"    Successfully tokenized {successful_tokens}/{len(values)} values")

    return tokens

def perform_chunked_merge(table_name, column, update_data, key_column="customer_id"):
    """
    Apply (key, new_value) updates to one column via chunked MERGE statements.

    Each chunk of at most MAX_MERGE_BATCH_SIZE rows is registered as a
    temporary view, merged into the target table on the key column, and the
    view is dropped again.

    Args:
        table_name: Target table name (wrapped in backticks in the SQL).
        column: Target column to overwrite (wrapped in backticks in the SQL).
        update_data: List of (key, new_value) tuples.
        key_column: Column used to match source and target rows.

    Returns:
        The number of update rows submitted across all chunks. This is the
        size of the input, not a row count reported by the MERGE itself.
    """
    if not update_data:
        return 0

    total_updated = 0
    chunk_size = MAX_MERGE_BATCH_SIZE
    total_chunks = (len(update_data) + chunk_size - 1) // chunk_size

    print(f"  Splitting {len(update_data):,} updates into {total_chunks} MERGE operations (max {chunk_size:,} per chunk)")

    for chunk_num, start in enumerate(range(0, len(update_data), chunk_size), start=1):
        chunk_data = update_data[start:start + chunk_size]

        # The source column and view names are fixed / hash-derived rather
        # than built from the raw column name, so a column containing spaces
        # or punctuation cannot produce an invalid unquoted identifier.
        temp_df = spark.createDataFrame(chunk_data, [key_column, "new_value"])
        temp_view_name = f"temp_merge_chunk_{chunk_num}_{hash(column) % 1000}"
        temp_df.createOrReplaceTempView(temp_view_name)

        merge_sql = f"""
            MERGE INTO `{table_name}` AS target
            USING {temp_view_name} AS source
            ON target.`{key_column}` = source.`{key_column}`
            WHEN MATCHED THEN
                UPDATE SET `{column}` = source.new_value
        """

        try:
            spark.sql(merge_sql)
        finally:
            # Drop the view even when the MERGE fails so it does not linger
            spark.catalog.dropTempView(temp_view_name)

        chunk_updated = len(chunk_data)
        total_updated += chunk_updated

        print(f"    Chunk {chunk_num}/{total_chunks}: Updated {chunk_updated:,} rows")

    return total_updated

# Process each column individually (streaming approach)
print("Starting column-by-column tokenization with streaming chunked processing...")

for column in pii_columns:
    print(f"\nProcessing column: {column}")

    # Get total count first for progress tracking
    total_count = spark.sql(f"""
        SELECT COUNT(*) as count
        FROM `{table_name}`
        WHERE `{column}` IS NOT NULL
    """).collect()[0]['count']

    if total_count == 0:
        print(f"  No data found in column {column}")
        continue

    print(f"  Found {total_count:,} total values to tokenize")

    # Rows are read and tokenized in chunks of COLLECT_BATCH_SIZE. The table
    # is only written after every chunk has been read, so the OFFSET-based
    # paging below is not disturbed by this loop's own updates.
    all_update_data = []  # Collect all updates before final MERGE
    processed_count = 0

    for offset in range(0, total_count, COLLECT_BATCH_SIZE):
        chunk_size = min(COLLECT_BATCH_SIZE, total_count - offset)
        chunk_number = offset // COLLECT_BATCH_SIZE + 1
        chunk_label = f"{column}_chunk_{chunk_number}"
        print(f"  Processing chunk {chunk_number} ({chunk_size:,} records, offset {offset:,})...")

        # NOTE(review): paging by ORDER BY customer_id + LIMIT/OFFSET is only
        # stable between queries if customer_id is unique - confirm.
        chunk_df = spark.sql(f"""
            SELECT customer_id, `{column}`
            FROM `{table_name}`
            WHERE `{column}` IS NOT NULL
            ORDER BY customer_id
            LIMIT {chunk_size} OFFSET {offset}
        """)

        chunk_rows = chunk_df.collect()
        if not chunk_rows:
            continue

        # Extract customer IDs and values for this chunk
        chunk_customer_ids = [row['customer_id'] for row in chunk_rows]
        chunk_column_values = [row[column] for row in chunk_rows]

        # Tokenize this chunk's values in Skyflow API batches
        if len(chunk_column_values) <= SKYFLOW_BATCH_SIZE:  # Single API batch
            chunk_tokens = tokenize_column_values(chunk_label, chunk_column_values)
        else:  # Multiple API batches within this chunk
            chunk_tokens = []
            batch_starts = range(0, len(chunk_column_values), SKYFLOW_BATCH_SIZE)
            for api_batch_number, start in enumerate(batch_starts, start=1):
                api_batch_values = chunk_column_values[start:start + SKYFLOW_BATCH_SIZE]
                chunk_tokens.extend(
                    tokenize_column_values(f"{chunk_label}_api_{api_batch_number}", api_batch_values)
                )

        # Verify token count matches input count (fail fast)
        if len(chunk_tokens) != len(chunk_customer_ids):
            raise RuntimeError(f"Token count mismatch: got {len(chunk_tokens)} tokens for {len(chunk_customer_ids)} input values")

        # Keep only rows whose token is non-empty and differs from the stored
        # value. The three lists are positionally aligned, so each token is
        # compared with its own original even when a customer_id repeats.
        all_update_data.extend(
            (customer_id, token)
            for customer_id, original, token in zip(chunk_customer_ids, chunk_column_values, chunk_tokens)
            if token and str(token) != str(original)
        )

        processed_count += len(chunk_rows)
        print(f"    Processed {processed_count:,}/{total_count:,} records ({(processed_count/total_count)*100:.1f}%)")

    # Perform final chunked MERGE operations for all collected updates
    if all_update_data:
        print(f"  Performing final chunked MERGE of {len(all_update_data):,} changed rows...")
        total_updated = perform_chunked_merge(table_name, column, all_update_data)
        print(f"  ✓ Successfully updated {total_updated:,} rows in column {column}")
    else:
        print(f"  No updates needed - all tokens match original values")

print("\nOptimized streaming tokenization completed!")

# Post-run check: for every processed column, show up to three distinct
# values together with how many rows carry each one.
print("\nFinal verification:")
for column in pii_columns:
    sample_query = f"""
        SELECT `{column}`, COUNT(*) as count 
        FROM `{table_name}` 
        GROUP BY `{column}` 
        LIMIT 3
    """
    sample_df = spark.sql(sample_query)
    print(f"\nSample values in {column}:")
    sample_df.show(truncate=False)

# Report the overall row count of the table.
count_query = f"SELECT COUNT(*) as count FROM `{table_name}`"
count_row = spark.sql(count_query).collect()[0]
total_rows = count_row["count"]
print(f"\nTable size: {total_rows:,} total rows")

column_total = len(pii_columns)
print(f"Optimized streaming tokenization completed for {column_total} columns")