## Data Quality and Quarantine Framework: Validation Engine Logic

Data source configuration (Catalog, Schema, Tables, PK)

In [0]:
from pyspark.sql.functions import expr
from pyspark.sql.functions import split, trim, col

In [0]:
# ─────────────────────────────────
# Widgets & Setup
# ─────────────────────────────────
# Register one free-text widget per pipeline input; every widget defaults to "".
for _widget_name in ("catalog", "schema", "tables", "Input_file_location_pk"):
    dbutils.widgets.text(_widget_name, "")
del _widget_name  # keep the notebook namespace free of the loop variable
def _get_widget(name: str) -> str:
    """Return the value of widget *name* with surrounding whitespace removed.

    A falsy widget value (``None`` or an empty string) yields ``""``.
    """
    raw_value = dbutils.widgets.get(name)
    if not raw_value:
        return ""
    return raw_value.strip()

# Read every widget up front so all inputs can be validated together.
catalog    = _get_widget("catalog")
schema     = _get_widget("schema")
tables_raw = _get_widget("tables")
input_pk   = _get_widget("Input_file_location_pk")

# Validate BEFORE touching storage: building the volume path from empty
# widgets would otherwise fail with an opaque path error instead of these
# actionable messages.
if not catalog or not schema:
    raise RuntimeError("Please fill 'catalog' and 'schema' widgets.")
if not tables_raw:
    raise RuntimeError("Please fill 'tables' widget.")
if not input_pk:
    # Without a file name the path would point at the schema's volume root.
    raise RuntimeError("Please fill 'Input_file_location_pk' widget.")

# The PK definition file is read as plain text (one row per line, column "value").
file_path = f"/Volumes/{catalog}/{schema}/{input_pk}"
pk_df = spark.read.text(file_path)

# Two-part namespace used to qualify every framework and data table.
meta = f"{catalog}.{schema}"

import re
from typing import Dict, List

def normalize_logical_name(table_name: str) -> str:
    """Return the logical table name for *table_name*.

    The input is stripped, everything up to and including the last "." is
    discarded, and a single leading ``bronze_`` prefix (matched
    case-insensitively) is removed from what remains.
    """
    _, _, unqualified = table_name.strip().rpartition(".")
    return re.sub(r'^(bronze)_', '', unqualified, flags=re.IGNORECASE)

def source_table_name(meta: str, source_hint: str) -> str:
    """Return the table name to read for *source_hint*.

    A hint that already contains a "." is treated as qualified and returned
    unchanged; otherwise it is prefixed with ``meta``.
    """
    if "." in source_hint:
        return source_hint
    return f"{meta}.{source_hint}"

def silver_table_name(meta: str, logical_name: str) -> str:
    """Return the Silver table name ``<meta>.silver_<logical_name>_valid``."""
    return "{}.silver_{}_valid".format(meta, logical_name)

def get_pk_dict(pk_df):
    """Parse the PK definition DataFrame into ``{table_name: [pk_column, ...]}``.

    Each row's ``value`` is expected to look like ``table_name: pk1, pk2``.
    The text before the first ":" is the table name, the text between the
    first and second ":" is a comma-separated list of PK columns; all parts
    are trimmed.

    Blank lines, lines without a ":" and empty column names (for example from
    a trailing comma) are skipped, so a malformed entry later shows up as
    "PK not defined" rather than as a ``TypeError`` on a null column list.
    When a table appears more than once, the last usable line wins.

    Args:
        pk_df: DataFrame with a string column named ``value``.

    Returns:
        dict mapping table name to its list of primary-key column names.
    """
    parsed_df = (
        pk_df
        .withColumn("table_name", trim(split(col("value"), ":")[0]))
        .withColumn("pk_columns_raw", trim(split(col("value"), ":")[1]))
        .withColumn("primary_keys", split(col("pk_columns_raw"), ","))
        .drop("value", "pk_columns_raw")
        .withColumn("primary_keys", expr("transform(primary_keys, x -> trim(x))"))
    )

    mapping = {}
    for row in parsed_df.collect():
        table_name = row["table_name"]
        # primary_keys is null when the line had no ":" separator.
        pk_columns = [name for name in (row["primary_keys"] or []) if name]
        if table_name and pk_columns:
            mapping[table_name] = pk_columns
    return mapping


def get_pk_for_table_from_widget(source_input_name: str, src_cols: List[str]) -> List[str]:
    """Return the primary-key columns configured for *source_input_name*.

    The PK mapping is parsed from the module-level ``pk_df`` on every call.
    The table is looked up first under the name exactly as given, then under
    its normalized logical name.

    Args:
        source_input_name: Table name as supplied by the caller.
        src_cols: Column names present in the source table.

    Raises:
        RuntimeError: If neither name has a PK entry, or if any configured PK
            column is absent from ``src_cols``.
    """
    pk_lookup = get_pk_dict(pk_df)
    candidate_names = (source_input_name, normalize_logical_name(source_input_name))

    matched_name = next((name for name in candidate_names if name in pk_lookup), None)
    if matched_name is None:
        raise RuntimeError(f"PK not defined for table {source_input_name}")
    pk_cols = pk_lookup[matched_name]

    # Every configured key must exist in the source (exact-match comparison).
    missing = [pk for pk in pk_cols if pk not in src_cols]
    if missing:
        raise RuntimeError(f"Missing PK columns {missing} in {source_input_name}")

    return pk_cols

def _fmt_line(label: str, value: str, width: int = 12):
    """Print ``label`` left-aligned in ``width`` columns, then "  : " and ``value``.

    NOTE: a later definition of ``_fmt_line`` in this file (default width 22)
    replaces this one once it has been executed.
    """
    padded_label = format(label, f"<{width}")
    print(f"{padded_label}  : {value}")

# Comma-separated widget value -> list of trimmed, non-empty table names.
tables_list = [
    name
    for name in (part.strip() for part in tables_raw.split(","))
    if name
]



#Validation Engine Logic

In [0]:
# ─────────────────────────────────────────────────────────────
# FULL DATA QUALITY PIPELINE (FINAL – STABLE, METADATA-DRIVEN)
# PKs from PK file ONLY
# dq_column_group → merge logic ONLY
# ALL source columns preserved in Silver
# ─────────────────────────────────────────────────────────────

from pyspark.sql import functions as F
from pyspark.sql.window import Window
from functools import reduce
import time
from datetime import datetime

# =============================================================
# PIPELINE METADATA
# =============================================================
# Captured once when this cell runs and shared by every table processed below.
PIPELINE_TS     = datetime.now()            # naive local wall-clock time of the run
PIPELINE_RUN_ID = int(1000 * time.time())   # epoch milliseconds, used as the run identifier

# =============================================================
# FRAMEWORK TABLES
# =============================================================
# Fully-qualified names of the framework tables, all living under `meta`.
# NOTE(review): "rule_defination" looks like a misspelling of "definition";
# it is kept as-is because it must match the existing table name — confirm.
(
    RULES_DEF_TBL,
    COLUMN_MAP_TBL,
    COLUMN_GROUP_TBL,
    QUARANTINE_TBL,
    OBS_TBL,
) = (
    f"{meta}.{table}"
    for table in (
        "rule_defination",
        "column_map",
        "dq_column_group",
        "bronze_dq_quarantine",
        "dq_observability_summary",
    )
)

# Rule catalogue (only the columns the engine reads) and the column -> rules map.
rules_def  = spark.table(RULES_DEF_TBL).select(["rule_id", "rule_type", "threshold"])
column_map = spark.table(COLUMN_MAP_TBL)

# =============================================================
# HELPERS
# =============================================================
def _fmt_line(label, value, width=22):
    """Print ``label`` left-aligned in ``width`` columns, then " : " and ``value``."""
    padded_label = format(label, f"<{width}")
    print(f"{padded_label} : {value}")

def add_row_hash_and_id(df):
    """Return *df* with two helper columns, ``_row_hash`` and ``_row_id``.

    ``_row_hash`` is the SHA-256 of every column cast to string and joined
    with "||"; nulls are replaced by a placeholder literal first so that they
    still contribute to the hash. ``_row_id`` comes from
    ``monotonically_increasing_id()``.
    """
    # NOTE(review): the placeholder literal looks mis-encoded (probably meant
    # to be the empty-set sign); it is left untouched because it feeds the hash.
    null_safe_values = [
        F.coalesce(F.col(name).cast("string"), F.lit("‚àÖ"))
        for name in df.columns
    ]
    row_digest = F.sha2(F.concat_ws("||", *null_safe_values), 256)

    with_hash = df.withColumn("_row_hash", row_digest)
    return with_hash.withColumn("_row_id", F.monotonically_increasing_id())

# =============================================================
# CORE VALIDATION
# =============================================================
def _parse_range_threshold(thr: str):
    """Split a ``<low>-<high>`` threshold into two floats, allowing negative bounds."""
    for idx, ch in enumerate(thr):
        # A "-" at index 0 is the sign of the lower bound, never the separator.
        if ch != "-" or idx == 0:
            continue
        try:
            return float(thr[:idx]), float(thr[idx + 1:])
        except ValueError:
            # This "-" is not the separator (e.g. it belongs to an exponent or
            # is followed by the sign of the upper bound); keep scanning.
            continue
    raise ValueError(f"Invalid range threshold {thr!r}; expected '<low>-<high>'")


def run_validation_for_table(source_input_name: str, run_id: int):
    """Validate one source table, quarantine failing rows and rebuild its Silver table.

    Steps, in order:
      1. Flag exact duplicate rows (every occurrence after the first).
      2. Apply the column-level rules configured for the table in the column map
         (``null_check``, ``regex_*``, ``range_*``; other rule types are ignored).
      3. Append all failures to the quarantine table.
      4. Build the clean set (rows with no failure).
      5. Overwrite the Silver table: rows are grouped by PK and "concat" columns
         from the column-group table are merged; all other columns are kept.
      6. Print a run summary.
      7. Append run metrics to the observability table.

    Args:
        source_input_name: Table name as supplied by the caller; bare or qualified.
        run_id: Identifier stamped on quarantine and observability rows.

    Raises:
        RuntimeError: Propagated from the PK lookup when the table has no PK
            entry or a PK column is missing from the source.
        ValueError: When a ``range_*`` rule threshold is not ``<low>-<high>``.
    """
    source_table = source_table_name(meta, source_input_name)
    logical_name = normalize_logical_name(source_input_name)
    silver_table = silver_table_name(meta, logical_name)

    df = spark.table(source_table)
    base_cols = df.columns

    # PKs come strictly from the PK file.
    pk_cols = get_pk_for_table_from_widget(source_input_name, base_cols)

    # NOTE(review): _row_id is produced by monotonically_increasing_id() and
    # df_hashed is re-evaluated by several actions below; the joins on _row_id
    # assume every evaluation assigns the same ids — confirm for this source,
    # or materialise df_hashed once.
    df_hashed = add_row_hash_and_id(df)

    # =========================================================
    # 1. FULL ROW DUPLICATES -> QUARANTINE
    # =========================================================
    dup_failures = (
        df_hashed
        .withColumn(
            "dup_rank",
            F.row_number().over(
                Window.partitionBy("_row_hash").orderBy("_row_id")
            )
        )
        # Rank 1 is the occurrence that is kept; everything after it is a duplicate.
        .filter(F.col("dup_rank") > 1)
        .select(
            "_row_id",
            F.struct(*base_cols).alias("row_struct"),
            F.lit(2).alias("failed_rule_id"),
            F.lit("full_row_duplicate").alias("failed_rule_type"),
            F.lit("Exact duplicate row detected").alias("failure_reason"),
            F.lit("_all_columns").alias("column_name")
        )
    )

    # =========================================================
    # 2. COLUMN-LEVEL RULE FAILURES
    # =========================================================
    # One row per (column, rule): rule_ids is a comma-separated list of ids.
    rule_map = (
        column_map
        .filter(F.col("table_name") == logical_name)
        .withColumn(
            "rule_id",
            F.explode(F.split(F.regexp_replace("rule_ids", "\\s+", ""), ","))
        )
        .withColumn("rule_id", F.col("rule_id").cast("int"))
        .join(rules_def, "rule_id")
        .filter(F.col("column_name").isin(base_cols))
    )

    def failing_rows_for_rule(df, r):
        """Return the rows of *df* violating rule *r*, or None for unsupported rule types."""
        # Named col_name so the imported pyspark `col` function is not shadowed.
        col_name = r["column_name"]
        rtype = r["rule_type"] or ""
        raw_thr = r.get("threshold")
        # str() guards against a non-string threshold column.
        thr = "" if raw_thr is None else str(raw_thr).strip()

        if rtype == "null_check":
            cond = F.col(col_name).isNull()
        elif rtype.startswith("regex_"):
            # Nulls are left to null_check; only non-null mismatches fail.
            cond = F.col(col_name).isNotNull() & (~F.col(col_name).cast("string").rlike(thr))
        elif rtype.startswith("range_"):
            lo, hi = _parse_range_threshold(thr)
            as_double = F.col(col_name).cast("double")
            # Non-null values that are not numeric, or fall outside [lo, hi], fail.
            cond = (
                F.col(col_name).isNotNull() &
                (as_double.isNull() | (as_double < lo) | (as_double > hi))
            )
        else:
            return None

        return (
            df.filter(cond)
              .select("_row_id", F.struct(*base_cols).alias("row_struct"))
              .withColumn("failed_rule_id", F.lit(r["rule_id"]))
              .withColumn("failed_rule_type", F.lit(rtype))
              .withColumn("failure_reason", F.lit(f"{col_name} failed {rtype}"))
              .withColumn("column_name", F.lit(col_name))
        )

    rule_failures = []
    for r in rule_map.collect():
        fr = failing_rows_for_rule(df_hashed, r.asDict())
        if fr is not None:
            rule_failures.append(fr)

    # Duplicates seed the union; with no rule failures this is just dup_failures.
    all_failures = reduce(lambda a, b: a.unionByName(b), rule_failures, dup_failures)

    invalid_rows = all_failures.select("_row_id").distinct()

    # =========================================================
    # 3. QUARANTINE WRITE
    # =========================================================
    (
        all_failures
        .withColumn("run_id", F.lit(run_id))
        .withColumn("source_table", F.lit(logical_name))
        .withColumn("invalid_data", F.to_json("row_struct"))
        .withColumn("processed_timestamp", F.current_timestamp())
        .withColumnRenamed("_row_id", "row_id")
        .select(
            "run_id", "source_table", "invalid_data",
            "failed_rule_id", "failed_rule_type",
            "failure_reason", "processed_timestamp", "row_id"
        )
        .write.mode("append")
        .saveAsTable(QUARANTINE_TBL)
    )

    # =========================================================
    # 4. CLEAN DATA (PRE-AGG)
    # =========================================================
    clean_df = (
        df_hashed
        .join(invalid_rows, "_row_id", "left_anti")
        .drop("_row_hash", "_row_id")
    )

    clean_pre_agg_cnt = clean_df.count()

    # =========================================================
    # 5. SILVER - MERGE COLUMNS ONLY, KEEP ALL OTHERS
    # =========================================================
    cg_rows = (
        spark.table(COLUMN_GROUP_TBL)
        .filter((F.col("table_name") == logical_name) & (F.col("active_flag") == "Y"))
        .select("column_name", "merge_strategy", "delimiter")
        .collect()
    )

    # Only columns that really get a merge expression count as merge columns.
    # Entries with another strategy, unknown columns and PK columns fall through
    # to the pass-through list, so no source column is dropped from Silver and
    # no output column is produced twice. The first entry per column wins.
    merge_specs = {}
    for r in cg_rows:
        name = r.column_name
        if r.merge_strategy != "concat" or name not in base_cols or name in pk_cols:
            continue
        merge_specs.setdefault(name, r.delimiter or ",")

    agg_exprs = [
        F.concat_ws(delimiter, F.sort_array(F.collect_set(name))).alias(name)
        for name, delimiter in merge_specs.items()
    ]

    pass_through_cols = [
        c for c in base_cols if c not in pk_cols and c not in merge_specs
    ]

    if agg_exprs:
        silver_df = (
            clean_df
            .groupBy(*pk_cols)
            .agg(
                *agg_exprs,
                *[F.first(c, ignorenulls=True).alias(c) for c in pass_through_cols]
            )
            .withColumn("load_timestamp", F.current_timestamp())
        )
    else:
        silver_df = (
            clean_df
            .dropDuplicates(pk_cols)
            .withColumn("load_timestamp", F.current_timestamp())
        )

    silver_df.write.mode("overwrite").option("mergeSchema", "true").saveAsTable(silver_table)

    # Count what was actually written instead of recomputing the whole plan.
    silver_cnt = spark.table(silver_table).count()
    # Rows that were folded into another row with the same PK.
    business_agg_cnt = clean_pre_agg_cnt - silver_cnt
    invalid_cnt = invalid_rows.count()
    total_src = df.count()

    # =========================================================
    # 6. PRINT BLOCK
    # =========================================================
    separator = "─" * 44

    _fmt_line("SOURCE", source_table)
    _fmt_line("COLUMN MAP", COLUMN_MAP_TBL)
    _fmt_line("COLUMN GROUP", COLUMN_GROUP_TBL)
    _fmt_line("QUARANTINE", QUARANTINE_TBL)
    _fmt_line("SILVER VALID", silver_table)

    print(f"PK column(s)               : {', '.join(pk_cols)}")
    print(f"Validation Run ID          : {run_id}")
    print(separator)
    print(f"Total source rows          : {total_src}")
    print(f"Quarantined rows           : {invalid_cnt}")
    print(f"Clean rows before agg      : {clean_pre_agg_cnt}")
    print(f"Business aggregated rows   : {business_agg_cnt}")
    print(f"Silver valid rows          : {silver_cnt}")
    print(separator)
    print(f"Accounting check           : {silver_cnt} + {business_agg_cnt} + {invalid_cnt} = {total_src}")
    print(separator)
    print(f"✅ Run complete for table: {logical_name}")

    # =========================================================
    # 7. OBSERVABILITY
    # =========================================================
    metrics = [
        ("Volume", "total_raw_records", str(total_src)),
        ("Volume", "total_valid_records", str(silver_cnt)),
        ("Volume", "total_failed_records", str(invalid_cnt)),
        ("Quality", "business_aggregated_records", str(business_agg_cnt)),
        # Guard against an empty source table (division by zero).
        ("Quality", "quarantined_percentage",
         str(round((invalid_cnt * 100.0) / total_src, 2) if total_src else 0))
    ]

    (
        spark.createDataFrame(metrics, ["metric_category", "metric_name", "metric_value"])
        .withColumn("run_id", F.lit(run_id))
        .withColumn("source_table", F.lit(logical_name))
        .withColumn("metric_ts", F.lit(PIPELINE_TS))
        .withColumn("metric_date", F.lit(PIPELINE_TS.date()))
        .withColumn("created_ts", F.current_timestamp())
        .write.mode("append")
        .saveAsTable(OBS_TBL)
    )

# =============================================================
# EXECUTION
# =============================================================
# Announce the run-level metadata shared by every table in this execution.
print(f"🚀 PIPELINE RUN ID : {PIPELINE_RUN_ID}")
print(f"⏱️ PIPELINE TS     : {PIPELINE_TS}")

# Tables are processed one after another under the same run id; an exception
# raised for one table stops the loop before the remaining tables run.
for src in tables_list:
    run_validation_for_table(src, PIPELINE_RUN_ID)

print("🎯 VALIDATION + OBSERVABILITY COMPLETED SUCCESSFULLY")
