In [0]:
# Create the text widgets that parameterize this notebook; every widget defaults to "".
# table_reference is expected in the form "catalog.schema.table". The two rollback
# targets are individually optional, but at least one must be filled in before the
# rollback cell runs (validate_rollback_params raises otherwise).
dbutils.widgets.text("table_reference", "", "Table Reference")
dbutils.widgets.text("rollback_version", "", "Rollback Version (Optional)")
dbutils.widgets.text("rollback_timestamp", "", "Rollback Timestamp (Optional)")


In [0]:
# Read the current widget values, in the same order the widgets were declared.
# Blank widgets come back as "" (their declared default).
table_reference, rollback_version, rollback_timestamp = (
    dbutils.widgets.get(widget_name)
    for widget_name in ("table_reference", "rollback_version", "rollback_timestamp")
)

In [0]:
from datetime import datetime
import json
from pyspark.sql import SparkSession, Row
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, TimestampType

# Reuse the notebook's existing Spark session (getOrCreate returns it if one is
# already active, which is the normal case on Databricks).
spark = (
    SparkSession.builder
    .getOrCreate()
)

def parse_table_reference(table_reference):
    """Split a three-part ``catalog.schema.table`` reference into its parts.

    Surrounding whitespace is stripped from the whole reference and from each
    part. Backtick-quoted identifiers that themselves contain dots are not
    supported.

    Args:
        table_reference: The reference string, e.g. ``"my_catalog.my_schema.my_table"``.

    Returns:
        list[str]: ``[catalog_name, schema_name, table_name]``.

    Raises:
        ValueError: If the reference is empty/None or does not consist of
            exactly three non-empty dot-separated parts.
    """
    if not table_reference or not table_reference.strip():
        raise ValueError("table_reference must be provided as 'catalog.schema.table'.")
    parts = [part.strip() for part in table_reference.strip().split('.')]
    # Callers unpack exactly three names; fail here with a clear message rather
    # than with an opaque "not enough values to unpack" error.
    if len(parts) != 3 or not all(parts):
        raise ValueError(
            f"table_reference must be of the form 'catalog.schema.table', got {table_reference!r}."
        )
    return parts

def validate_rollback_params(rollback_version, rollback_timestamp):
    """Ensure that at least one rollback target was supplied.

    Both arguments are tested for truthiness, so ``None`` and ``""`` both count
    as "not supplied". Supplying both is accepted; ``generate_rollback_clause``
    then prefers the version.

    Raises:
        ValueError: When neither a version nor a timestamp is given.
    """
    target_supplied = any((rollback_version, rollback_timestamp))
    if not target_supplied:
        raise ValueError("Either rollback_version or rollback_timestamp must be provided.")

def generate_rollback_clause(rollback_version, rollback_timestamp):
    """Build the ``TO VERSION AS OF`` / ``TO TIMESTAMP AS OF`` clause for RESTORE.

    The values come from notebook widgets and are interpolated directly into a
    SQL statement, so they are validated here instead of being pasted verbatim.

    Args:
        rollback_version: Target Delta version. It must parse as a non-negative
            integer and takes precedence when both arguments are truthy.
        rollback_timestamp: Target timestamp text, used only when no version is
            given. It is emitted inside a single-quoted SQL string literal.

    Returns:
        str | None: The clause, or ``None`` when neither argument is truthy.
        Callers are expected to run ``validate_rollback_params`` first.

    Raises:
        ValueError: If the version is not a non-negative integer, or the
            timestamp is blank or contains a quote/backslash that could break
            out of the SQL string literal.
    """
    if rollback_version:
        try:
            version_number = int(str(rollback_version).strip())
        except ValueError:
            version_number = -1  # force the error below with a single message
        if version_number < 0:
            raise ValueError(
                f"rollback_version must be a non-negative integer, got {rollback_version!r}."
            )
        return f"TO VERSION AS OF {version_number}"
    if rollback_timestamp:
        timestamp_text = str(rollback_timestamp).strip()
        # Reject characters that could terminate or escape the quoted literal.
        if not timestamp_text or "'" in timestamp_text or "\\" in timestamp_text:
            raise ValueError(
                f"rollback_timestamp is not a valid timestamp string: {rollback_timestamp!r}."
            )
        return f"TO TIMESTAMP AS OF '{timestamp_text}'"
    return None

def restore_table(catalog_name, schema_name, table_name, rollback_clause):
    """Run Delta's ``RESTORE TABLE`` against ``catalog.schema.table``.

    Args:
        catalog_name: Catalog part of the three-level name.
        schema_name: Schema part of the three-level name.
        table_name: Table part of the three-level name.
        rollback_clause: The ``TO VERSION AS OF ...`` / ``TO TIMESTAMP AS OF ...``
            suffix, as produced by ``generate_rollback_clause``.
    """
    qualified_name = f"{catalog_name}.{schema_name}.{table_name}"
    spark.sql(f"RESTORE TABLE {qualified_name} {rollback_clause}")

def get_latest_history_entry(catalog_name, schema_name, table_name):
    """Return the first row of ``DESCRIBE HISTORY ... LIMIT 1`` for the table.

    Delta lists history newest-first, so this row is the most recent commit.
    After ``restore_table`` it is expected to be the RESTORE itself.

    Raises:
        IndexError: If the history query returns no rows.
    """
    qualified_name = f"{catalog_name}.{schema_name}.{table_name}"
    history_rows = spark.sql(f"DESCRIBE HISTORY {qualified_name} LIMIT 1").collect()
    return history_rows[0]

def extract_operation_parameters(latest_history_entry):
    """Return the ``operationParameters`` map of a history entry as a dict.

    Args:
        latest_history_entry: A ``DESCRIBE HISTORY`` row, or anything that
            supports ``entry['operationParameters']``.

    Returns:
        dict: A copy of the parameters. A null or empty map yields ``{}``, so
        callers can always call ``.get(...)`` on the result.
    """
    operation_parameters = latest_history_entry['operationParameters']
    return dict(operation_parameters) if operation_parameters else {}

def get_created_by(catalog_name, schema_name, table_name):
    """Look up the table's ``created_by`` value in ``<catalog>.information_schema.tables``.

    Args:
        catalog_name: Catalog whose information_schema is queried.
        schema_name: Value matched against ``table_schema``.
        table_name: Value matched against ``table_name``.

    Returns:
        The ``created_by`` value of the first matching row, or ``None`` when no
        row matches.

    NOTE(review): names are compared literally. If information_schema stores
    identifiers lowercased, a mixed-case reference will not match; confirm.
    """
    created_by_df = spark.sql(f"""
        SELECT created_by
        FROM {catalog_name}.information_schema.tables
        WHERE table_schema = '{schema_name}' AND table_name = '{table_name}'
    """)
    # first() executes the query once and returns None for an empty result.
    # The previous count() + collect() pair ran the query twice.
    first_row = created_by_df.first()
    return first_row['created_by'] if first_row is not None else None

def get_schema_info(catalog_name, schema_name, table_name):
    """Return ``{column_name: data_type}`` for the table from information_schema.columns.

    Rows are ordered by ordinal position, so the dict (and the ``schema_json``
    built from it) follows the table's column order deterministically rather
    than whatever order the query happens to return.

    Args:
        catalog_name: Catalog whose information_schema is queried.
        schema_name: Value matched against ``table_schema``.
        table_name: Value matched against ``table_name``.

    Returns:
        dict[str, str]: Column name to data type. It is empty if nothing matches.
    """
    schema_df = spark.sql(f"""
        SELECT COLUMN_NAME, DATA_TYPE
        FROM {catalog_name}.information_schema.columns
        WHERE table_schema = '{schema_name}' AND table_name = '{table_name}'
        ORDER BY ORDINAL_POSITION
    """)
    return {row['COLUMN_NAME']: row['DATA_TYPE'] for row in schema_df.collect()}

def get_new_schema_version(catalog_name, table_name):
    """Return the next ``schema_version`` to use for ``table_name`` in the registry.

    Args:
        catalog_name: Catalog holding ``default.schema_registry``.
        table_name: Registry ``table_name`` to look up.

    Returns:
        int: ``MAX(schema_version) + 1``, or ``1`` when the table has no
        registry rows yet.

    NOTE(review): rows are matched on table_name only, not schema_name. Confirm
    that table names are unique within the catalog's registry.
    """
    max_version_df = spark.sql(f"""
        SELECT MAX(schema_version) AS max_version
        FROM {catalog_name}.default.schema_registry
        WHERE table_name = '{table_name}'
    """)
    # An aggregate always yields exactly one row; when nothing matches, that row's
    # max_version is NULL. The old count() > 0 guard was therefore always true
    # and crashed with None + 1 for a table not yet in the registry.
    result_row = max_version_df.first()
    max_version = result_row['max_version'] if result_row is not None else None
    return 1 if max_version is None else max_version + 1

def mark_previous_schema_inactive(catalog_name, table_name):
    """Set every ``Active`` registry row for ``table_name`` to ``Inactive``.

    Args:
        catalog_name: Catalog holding ``default.schema_registry``.
        table_name: Registry ``table_name`` whose active rows are deactivated.

    NOTE(review): rows are matched on table_name only, although the registry
    also stores schema_name. Same-named tables in different schemas of this
    catalog would deactivate each other's rows; confirm that this is intended.
    """
    spark.sql(f"""
        UPDATE {catalog_name}.default.schema_registry 
        SET status = 'Inactive' 
        WHERE table_name = '{table_name}' AND status = 'Active'
    """)

def create_rollback_entry(catalog_name, schema_name, table_name, new_schema_version, created_by, modified_by,
                          new_modified_timestamp, schema_json, change_type_value, new_table_version,
                          table_version_timestamp):
    """Build the schema-registry Row that records a rollback.

    Args:
        catalog_name, schema_name, table_name: Identify the restored table.
        new_schema_version: Registry schema_version for this entry.
        created_by: Table creator, or ``None`` if unknown.
        modified_by: User who ran the RESTORE.
        new_modified_timestamp: Timestamp of the RESTORE commit.
        schema_json: Column-name to data-type mapping. It is stored JSON-encoded.
        change_type_value: Human-readable description, e.g. "ROLLBACK TO VERSION 20".
        new_table_version: Delta version created by the RESTORE.
        table_version_timestamp: Timestamp associated with that version.

    Returns:
        Row: The entry, with status "Active", schema_change_alert_status
        "Pending" and rollback_notification_status "Rollbacked".
    """
    from datetime import timezone

    return Row(
        catalog_name=catalog_name,
        schema_name=schema_name,
        table_name=table_name,
        schema_version=new_schema_version,
        created_by=created_by,
        modified_by=modified_by,
        modified_timestamp=new_modified_timestamp,
        schema_json=json.dumps(schema_json),
        change_type=change_type_value,
        column_name=None,
        table_version=new_table_version,
        table_version_timestamp=table_version_timestamp,
        status="Active",
        # Timezone-aware UTC: PySpark interprets *naive* datetimes in the driver's
        # local timezone, so the old utcnow() was only correct on a UTC driver.
        # datetime.utcnow() is also deprecated since Python 3.12.
        check_timestamp=datetime.now(timezone.utc),
        schema_change_alert_status="Pending",
        rollback_notification_status="Rollbacked"
    )

def save_rollback_entry_to_registry(rollback_entry, catalog_name):
    """Append a single rollback Row to ``<catalog>.default.schema_registry``.

    The row is shown with ``display`` and then appended as Delta. The schema is
    spelled out explicitly, with every column nullable, instead of being
    inferred from a single row. In that row some values, such as column_name,
    are always None.
    """
    column_types = [
        ("catalog_name", StringType()),
        ("schema_name", StringType()),
        ("table_name", StringType()),
        ("schema_version", IntegerType()),
        ("created_by", StringType()),
        ("modified_by", StringType()),
        ("modified_timestamp", TimestampType()),
        ("schema_json", StringType()),
        ("change_type", StringType()),
        ("column_name", StringType()),
        ("table_version", IntegerType()),
        ("table_version_timestamp", TimestampType()),
        ("status", StringType()),
        ("check_timestamp", TimestampType()),
        ("schema_change_alert_status", StringType()),
        ("rollback_notification_status", StringType()),
    ]
    registry_schema = StructType(
        [StructField(column, data_type, True) for column, data_type in column_types]
    )
    rollback_df = spark.createDataFrame([rollback_entry], schema=registry_schema)
    display(rollback_df)
    target_table = f"{catalog_name}.default.schema_registry"
    rollback_df.write.format("delta").mode("append").saveAsTable(target_table)

def _describe_rollback_target(rollback_version, rollback_timestamp, operation_parameters):
    """Return a label such as ``"VERSION 20"`` or ``"TIMESTAMP <ts>"`` for the restored target.

    User-supplied values win. Otherwise the value recorded in the RESTORE's
    ``operationParameters`` is used.
    """
    recorded = operation_parameters or {}
    target_version = rollback_version if rollback_version else recorded.get("version")
    if target_version:
        return f"VERSION {target_version}"
    target_timestamp = rollback_timestamp if rollback_timestamp else recorded.get("timestamp")
    return f"TIMESTAMP {target_timestamp}"


def rollback_table_with_confirmation(table_reference, rollback_version=None, rollback_timestamp=None):
    """Restore a Delta table to an earlier version/timestamp and record it in the schema registry.

    "Confirmation" means that after the RESTORE the latest history entry is
    re-read, and no registry entry is written unless it is a RESTORE. There is
    no interactive prompt.

    Args:
        table_reference: Three-part name ``"catalog.schema.table"``.
        rollback_version: Target Delta version. It takes precedence over the
            timestamp when both are given.
        rollback_timestamp: Target timestamp, used when no version is given.

    Raises:
        ValueError: If the reference is malformed, no target is supplied, or the
            latest history entry is not a RESTORE.

    NOTE(review): the RESTORE is committed before the registry is touched. If a
    later step fails, the table stays restored but the registry is not updated.
    """
    # Presumably the registry's change_type column limit; the original truncated
    # at 50 as well. Confirm against the table DDL.
    change_type_max_len = 50

    catalog_name, schema_name, table_name = parse_table_reference(table_reference)
    validate_rollback_params(rollback_version, rollback_timestamp)

    rollback_clause = generate_rollback_clause(rollback_version, rollback_timestamp)
    restore_table(catalog_name, schema_name, table_name, rollback_clause)

    # Verify that the commit we are about to record really is our RESTORE.
    latest_history_entry = get_latest_history_entry(catalog_name, schema_name, table_name)
    if latest_history_entry['operation'] != "RESTORE":
        raise ValueError("The operation recorded in history is not a RESTORE operation.")

    target_label = _describe_rollback_target(
        rollback_version,
        rollback_timestamp,
        extract_operation_parameters(latest_history_entry),
    )

    created_by = get_created_by(catalog_name, schema_name, table_name)
    schema_json = get_schema_info(catalog_name, schema_name, table_name)
    new_schema_version = get_new_schema_version(catalog_name, table_name)

    rollback_entry = create_rollback_entry(
        catalog_name=catalog_name,
        schema_name=schema_name,
        table_name=table_name,
        new_schema_version=new_schema_version,
        created_by=created_by,
        modified_by=latest_history_entry['userName'],
        new_modified_timestamp=latest_history_entry['timestamp'],
        schema_json=schema_json,
        change_type_value=f"ROLLBACK TO {target_label}"[:change_type_max_len],
        new_table_version=latest_history_entry['version'],
        table_version_timestamp=latest_history_entry['timestamp']
    )

    # Deactivate the old entries only once the replacement is built, immediately
    # before it is written. This keeps the window short in which the table has no
    # Active registry row.
    mark_previous_schema_inactive(catalog_name, table_name)
    save_rollback_entry_to_registry(rollback_entry, catalog_name)

    print(f"Table {table_name} has been rolled back to {target_label} and recorded in the schema registry.")



In [0]:
# Run the rollback with the widget values. The text widgets yield strings, so
# surrounding whitespace is trimmed, and blank or whitespace-only entries are
# passed as None ("not provided"). Otherwise a stray space would count as a
# supplied target and produce invalid SQL.
rollback_table_with_confirmation(
    table_reference=table_reference.strip(),
    rollback_version=rollback_version.strip() or None,
    rollback_timestamp=rollback_timestamp.strip() or None,
)

catalog_name,schema_name,table_name,schema_version,created_by,modified_by,modified_timestamp,schema_json,change_type,column_name,table_version,table_version_timestamp,status,check_timestamp,schema_change_alert_status,rollback_notification_status
ds_training_1,ds_silver,customer_silver_vishal,17,vishal.kokkula@latentviewo365.onmicrosoft.com,brindavivek.kotha@latentviewo365.onmicrosoft.com,2024-10-23T10:34:35Z,"{""customer_id"": ""INT"", ""name"": ""STRING"", ""age"": ""INT"", ""gender"": ""STRING"", ""phone_number"": ""STRING"", ""email"": ""STRING"", ""account_id"": ""INT"", ""account_type"": ""STRING"", ""balance"": ""INT"", ""opened_date"": ""DATE"", ""status"": ""STRING"", ""business_date"": ""STRING"", ""test_column"": ""DATE"", ""test_column_18_10_24"": ""INT"", ""test_column_23_10"": ""INT""}",ROLLBACK TO VERSION 20,,25,2024-10-23T10:34:35Z,Active,2024-10-23T10:34:46.169252Z,Pending,Rollbacked


Table customer_silver_vishal has been rolled back to VERSION 20 and recorded in the schema registry.
