# Process Bronze-to-Silver

In [0]:
from pyspark.sql.functions import explode
from pyspark.sql import SparkSession
from functools import reduce
from delta.tables import DeltaTable
import pyspark.sql.functions as F
from pyspark.sql.functions import col
from pyspark.sql import Window, Column, DataFrame, SparkSession
import logging
import sys

In [None]:
# Set up a module-level logger that emits INFO and above.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Console handler writing to stdout so messages appear in the notebook cell output.
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
console_handler.setFormatter(formatter)

# Attach the handler only once, so re-running this cell does not duplicate output.
# `logger.handlers` is inspected instead of `logger.hasHandlers()`: the latter also
# returns True when an *ancestor* logger (e.g. root) has handlers, which would
# prevent this console handler from ever being attached.
if not logger.handlers:
    logger.addHandler(console_handler)

## Define auxiliary functions

In [None]:
def create_constraints_table(schema: str, table_name: str):
    """
    Persist the constraint statements for ``silver_sales`` as a table.

    A JSON document holding the target table name and a list of ``ALTER TABLE``
    statements (qualified with ``schema``) is parsed with Spark, the
    ``constraints`` array is exploded into one row per statement, and the
    result overwrites ``{schema}.{table_name}``.

    Relies on the notebook-global ``spark`` session.

    Args:
        schema (str): Schema used both inside the generated statements and as
            the schema of the table being written.
        table_name (str): Name of the table that receives the constraint rows.

    Returns:
        str: The JSON document the constraint rows were built from.
    """
    constraints_json = f"""{{
        "tableName": "silver_sales",
        "constraints": [
            "ALTER TABLE {schema}.silver_sales ALTER COLUMN sales_id SET NOT NULL;",
            "ALTER TABLE {schema}.silver_sales ALTER COLUMN client_id SET NOT NULL;",
            "ALTER TABLE {schema}.silver_sales ALTER COLUMN product_id SET NOT NULL;",
            "ALTER TABLE {schema}.silver_sales ADD CONSTRAINT dateWithinRange  CHECK (date between '2020-01-01' and '2025-01-01');",
            "ALTER TABLE {schema}.silver_sales ADD CONSTRAINT validQuantity  CHECK (quantity > 0);"
        ]
    }}"""

    # Parse the single JSON document into a one-row DataFrame.
    json_rdd = spark.sparkContext.parallelize([constraints_json])
    parsed = spark.read.json(json_rdd)

    # One output row per constraint statement, next to the table it belongs to.
    per_constraint = parsed.select("tableName", explode("constraints"))

    target = f"{schema}.{table_name}"
    per_constraint.write.mode("overwrite").saveAsTable(target)

    return constraints_json

In [None]:
def apply_table_constaints(source_constraints_tablename: str, table_name: str) -> None:
    """
    Apply table-specific SQL constraints to the given Delta table.

    Reads ``source_constraints_tablename``, keeps the rows whose ``tableName``
    equals ``table_name`` and executes the statement stored in the second
    column of each row via ``spark.sql``. Relies on the notebook-global
    ``spark`` session and ``logger``.

    Args:
        source_constraints_tablename (str): The table containing the constraints definitions.
        table_name (str): The table to which constraints are being applied.

    Returns:
        None: This function applies constraints directly to the table using SQL statements.

    Raises:
        Exception: Whatever ``spark.sql`` raises for a failing statement is
            logged and re-raised; remaining statements are not executed.
    """
    logger.info("Applying table constraints to table: %s", table_name)

    # Read the constraints from the source constraints table
    constraints_df = spark.read.table(source_constraints_tablename)

    # Filter constraints relevant to the target table. A Column comparison is
    # used instead of interpolating table_name into a SQL string, so a name
    # containing quotes cannot break (or alter) the predicate.
    table_constraints = constraints_df.filter(F.col("tableName") == table_name)

    # Apply each constraint via SQL execution
    for row in table_constraints.collect():
        statement = row[1]  # second column holds the SQL statement
        logger.info("Applying constraint: %s", statement)
        try:
            spark.sql(statement)
        except Exception as e:
            # logger.exception keeps the original message and adds the traceback.
            logger.exception("Error applying constraint %s to table %s: %s", statement, table_name, str(e))
            raise

In [None]:
def get_table_constraints_conditions(schema: str, table_name: str) -> Column:
    """
    Generate a composite filter condition that is true for rows VIOLATING the table's constraints.

    The condition combines:
      * one ``<column> IS NOT NULL`` check per non-nullable field in the table schema, and
      * every expression stored under a ``delta.constraints.*`` table property.

    A row matches the returned condition when at least one of those checks is
    false *or evaluates to NULL*. Relies on the notebook-global ``spark``
    session and ``logger``.

    Args:
        schema (str): The schema of the table.
        table_name (str): The name of the table for which constraints are generated.

    Returns:
        Column: A PySpark Column that is True for records that should be
            quarantined and False otherwise (never NULL).
    """
    logger.info("Generating table constraint conditions for table: %s.%s", schema, table_name)

    full_name = f"{schema}.{table_name}"

    # Identify non-nullable fields and generate a filter for them.
    # Names are backtick-quoted so columns with spaces/special characters parse.
    nullable_filters = [
        "`{}` IS NOT NULL".format(field.name.replace("`", "``"))
        for field in spark.table(full_name).schema.fields
        if not field.nullable
    ]
    logger.info("Non-nullable field filters: %s", nullable_filters)

    # Retrieve custom table constraints (from Delta properties)
    constraints_df = (
        spark.sql(f"SHOW TBLPROPERTIES {full_name}")
        .filter(F.col("key").startswith("delta.constraints."))
        .select("value")
    )

    # Collect the constraints from the table properties
    constraints_filter = [row[0] for row in constraints_df.collect()]
    logger.info("Custom constraints from TBLPROPERTIES: %s", constraints_filter)

    # Combine nullable filters and custom constraints into a single condition
    constraints = nullable_filters + constraints_filter

    # Reduce the list of constraints to a composite filter condition using logical OR:
    # if any constraint fails, the record must be quarantined.
    # coalesce(..., False) treats a NULL result (e.g. `quantity > 0` with a NULL
    # quantity) as a failure. Without it the combined condition becomes NULL, so
    # the row matches neither `filter(cond)` nor `filter(~cond)` and is silently
    # dropped from both the quarantine and the valid set.
    combined_condition = reduce(
        lambda acc, expr: acc | ~F.coalesce(F.expr(expr), F.lit(False)),
        constraints,
        F.lit(False),
    )
    logger.info("Combined constraint condition generated.")

    return combined_condition

In [None]:
# Preview the contents of the constraints table.
# NOTE(review): this reads from the `data_quality_demo` schema, while the
# constraints table is written to `silver_schema` ('demo') further down in this
# notebook — confirm which schema is intended.
display(table("data_quality_demo.tableconstraints"))

In [None]:
def stream_merge_into_delta(
    batch_df: DataFrame, 
    batch_id: int, 
    table_schema: str, 
    table_name: str, 
    ids: list, 
    max_col: str = None, 
    quarantine_schema: str = "quarantine", 
    table_constraints_name: str = "tableConstraints"
) -> None:
    """
    Deduplicate a micro-batch, quarantine invalid rows and merge the rest into a Delta table.

    Steps:
      1. Deduplicate ``batch_df`` on ``ids`` (keeping the row with the highest
         ``max_col`` when given).
      2. Use the existing Delta table ``{table_schema}.{table_name}``, or create
         it from the batch schema and apply the constraints registered in
         ``{table_schema}.{table_constraints_name}``.
      3. Split the batch into rows violating the table constraints and valid rows.
      4. Append violating rows to ``{table_schema}_{quarantine_schema}.{table_name}``.
      5. MERGE valid rows into the Delta table on ``ids`` (update all / insert all).

    Relies on the notebook-global ``spark`` session and ``logger``.

    Args:
        batch_df (DataFrame): The micro-batch to process.
        batch_id (int): Micro-batch identifier (not used by the function body).
        table_schema (str): Schema of the target Delta table.
        table_name (str): Name of the target Delta table.
        ids (list): Key column names used for deduplication and as merge keys.
        max_col (str, optional): Column used to pick the row to keep per key
            (highest value wins). When omitted, an arbitrary row per key is kept.
        quarantine_schema (str): Suffix appended to ``table_schema`` to build the
            quarantine schema name.
        table_constraints_name (str): Name of the table holding the constraint
            statements, looked up inside ``table_schema``.

    Returns:
        None

    Raises:
        ValueError: If ``ids`` is empty (no merge condition can be built).
    """
    if not ids:
        raise ValueError("ids must contain at least one key column")

    full_table_name = f"{table_schema}.{table_name}"
    logger.info("Starting stream_merge_into_delta for table: %s.%s", table_schema, table_name)
    
    # Step 1: Dedupe input batch, use max_col if provided
    if max_col:
        logger.info("Deduplication using max column: %s", max_col)
        w = Window.partitionBy(ids).orderBy(col(max_col).desc())
        # row_number (not rank) guarantees exactly one row per key: rank gives
        # every tied row the value 1, which would leave duplicate keys in the
        # batch and make the MERGE match several source rows for one target row.
        df_dedupe = (
            batch_df
            .withColumn("rank", F.row_number().over(w))
            .filter(col("rank") == 1)
            .drop("rank")
        )
    else:
        logger.info("Deduplication using primary keys: %s", ids)
        df_dedupe = batch_df.dropDuplicates(ids)

    # Step 2: Check if Delta table exists
    if spark.catalog.tableExists(full_table_name):
        logger.info("Delta table exists: %s.%s", table_schema, table_name)
        delta = DeltaTable.forName(spark, full_table_name)
    else:
        # Step 3: Create Delta table if it doesn't exist
        logger.info("Creating Delta table: %s.%s", table_schema, table_name)
        delta = (DeltaTable.create(spark)
                 .tableName(full_table_name)
                 .addColumns(df_dedupe.schema)
                 .execute())
        
        # Apply table constraints after table creation
        apply_table_constaints(f"{table_schema}.{table_constraints_name}", table_name)

    # Step 4: Quarantine invalid records based on table constraints
    constraints_conditions = get_table_constraints_conditions(table_schema, table_name)
    
    quarantine_records = df_dedupe.filter(constraints_conditions)
    valid_records = df_dedupe.filter(~constraints_conditions)

    # Write quarantine records to the specified quarantine schema
    quarantine_records.write.mode("append").saveAsTable(f"{table_schema}_{quarantine_schema}.{table_name}")
    
    # Each count() below is a separate Spark action over the batch.
    logger.info("Total Records: %d", df_dedupe.count())
    logger.info("Records moved to quarantine: %d", quarantine_records.count())
    logger.info("Valid Records: %d", valid_records.count())

    # Step 5: Create merge condition based on primary keys (ids)
    condition = " AND ".join(f"l.{c} = r.{c}" for c in ids)
    logger.info("Merge condition: %s", condition)

    # Step 6: Perform the merge operation into Delta table
    merge = (
        delta.alias("l")
        .merge(valid_records.alias("r"), condition)
        .whenMatchedUpdateAll()
        .whenNotMatchedInsertAll()
    )
    
    merge.execute()
    logger.info("Merge operation completed for table: %s.%s", table_schema, table_name)

## Define parameters

In [None]:
# Source: the bronze table that the stream reads from.
schema = "data_quality_demo"
bronze_table = "bronze_data_quality"

# Target: the silver table; "<silver_schema>_quarantine" holds rejected rows.
silver_schema = "demo"
silver_table = "silver_sales"

# Root folder for streaming checkpoints (one sub-folder per target table).
checkpoint_path = "tmp/testa/_checkpoints"

# Columns selected from bronze, and the key column(s) used to dedupe/merge.
cols = [
    "date",
    "sales_id",
    "client_id",
    "product_id",
    "quantity",
    "sale_amount",
]
ids = ["sales_id"]


## Create Schemas & Constraint table

In [0]:
# Name of the table that stores the constraint statements.
table_constraints_name = 'tableConstraints'

# Make sure the silver schema and its quarantine companion both exist.
for _schema_to_create in (silver_schema, f"{silver_schema}_quarantine"):
    spark.sql(f"CREATE SCHEMA IF NOT EXISTS {_schema_to_create}")

# Write the constraint definitions into the silver schema.
create_constraints_table(silver_schema, table_constraints_name)

## Read Bronze table & process data

In [None]:
# Open the bronze table as a streaming DataFrame.
df = spark.readStream.table(f"{schema}.{bronze_table}")

In [None]:
# Show the bronze stream in the notebook for inspection.
display(df)

In [None]:
def _merge_silver_batch(batch_df, batch_id):
    """Hand one micro-batch to stream_merge_into_delta for the silver table."""
    return stream_merge_into_delta(
        batch_df=batch_df,
        batch_id=batch_id,
        table_schema=silver_schema,
        table_name=silver_table,
        ids=ids,
        max_col=None,
    )


# Checkpoint folder dedicated to the silver table.
silver_checkpoint = f"{checkpoint_path}/{silver_table}/"

# Process everything currently available in bronze, batch by batch, then stop.
query = (
    df.select(cols)
    .writeStream
    .format('delta')
    .option("checkpointLocation", silver_checkpoint)
    .foreachBatch(_merge_silver_batch)
    .trigger(availableNow=True)
    .start()
)

# Block until the stream has finished.
query.awaitTermination()

In [None]:
# Show the resulting silver table.
display(table(f"{silver_schema}.{silver_table}"))

In [None]:
# Show the rows that were written to the quarantine table.
display(table(f"{silver_schema}_quarantine.{silver_table}"))

## Clean environment

In [None]:
# Remove the streaming checkpoint folder (second argument: recurse).
dbutils.fs.rm(checkpoint_path, True)

# Drop the silver schema and its quarantine companion, including their tables.
_drop_schema_sql = "DROP SCHEMA  IF EXISTS {} CASCADE"
spark.sql(_drop_schema_sql.format(silver_schema))
spark.sql(_drop_schema_sql.format(f"{silver_schema}_quarantine"))