# 04 Data Quality and Validation

This notebook runs lightweight data quality checks across Bronze, Silver, and Gold layers.

Checks included:
- Row counts (flagging drops between layers that exceed a configurable threshold)
- Schema comparison (column set + type mismatches)
- Null checks on required columns
- Duplicate checks on primary keys
- Optional distinct-PK reconciliation between layers

In [None]:
# Storage format used by any dataset entry that does not override it.
# Supported values: "delta", "parquet", "csv".
DEFAULT_FORMAT = "delta"

# Largest tolerated fractional row loss between layers. Kept deliberately
# loose to avoid false failures: with 0.30, silver is flagged only when it
# has lost more than 30% of the rows present in bronze.
MAX_ROW_DROP_PCT = 0.30

# When True, the notebook raises an exception if any failure was recorded
# (in a Databricks job this can fail the task).
FAIL_PIPELINE_ON_ERROR = False

# Datasets to validate -- replace the placeholder paths with the ones used
# by your project. For an aggregated Gold table without a stable primary
# key, set pk_cols=[] and compare_pk=False.
DATASETS = [
    dict(
        name="customers",
        format=DEFAULT_FORMAT,
        bronze_path="/mnt/your_mount/bronze/customers",
        silver_path="/mnt/your_mount/silver/customers",
        gold_path="/mnt/your_mount/gold/customers",
        pk_cols=["CustomerID"],
        not_null_cols=["CustomerID"],
        compare_pk=True,
    ),
    dict(
        name="sales",
        format=DEFAULT_FORMAT,
        bronze_path="/mnt/your_mount/bronze/sales",
        silver_path="/mnt/your_mount/silver/sales",
        gold_path="/mnt/your_mount/gold/sales",
        pk_cols=["SalesOrderID", "SalesOrderDetailID"],
        not_null_cols=["SalesOrderID"],
        compare_pk=True,
    ),
    dict(
        name="products",
        format=DEFAULT_FORMAT,
        bronze_path="/mnt/your_mount/bronze/products",
        silver_path="/mnt/your_mount/silver/products",
        gold_path="/mnt/your_mount/gold/products",
        pk_cols=["ProductID"],
        not_null_cols=["ProductID"],
        compare_pk=True,
    ),
]


In [None]:
# =========================
# HELPERS
# =========================

from typing import Dict, List, Any

def _safe_display(obj):
    """Render ``obj`` with ``display()``; print it if that call fails.

    Any exception raised by the ``display`` call -- including ``display``
    not being defined in this environment -- triggers the ``print`` fallback.
    """
    rendered = False
    try:
        display(obj)  # type: ignore
        rendered = True
    except Exception:
        rendered = False
    if not rendered:
        print(obj)

def load_df(path: str, fmt: str):
    """Read the data at ``path`` into a Spark DataFrame.

    ``fmt`` is matched case-insensitively against "delta", "parquet" and
    "csv" (CSV is read with the header option enabled).

    Raises:
        ValueError: if ``fmt`` is none of the supported formats.
    """
    kind = fmt.lower()
    # Readers are wrapped in lambdas so `spark` is only touched for the
    # format that was actually requested.
    readers = {
        "delta": lambda: spark.read.format("delta").load(path),
        "parquet": lambda: spark.read.parquet(path),
        "csv": lambda: spark.read.option("header", "true").csv(path),
    }
    reader = readers.get(kind)
    if reader is None:
        raise ValueError(f"Unsupported format: {fmt}")
    return reader()

def exists_path(path: str) -> bool:
    """Report whether ``path`` can be listed with ``dbutils.fs.ls``.

    Any exception raised by the listing call (including ``dbutils`` not
    being defined) is treated as "path does not exist".
    """
    try:
        dbutils.fs.ls(path)  # type: ignore
    except Exception:
        return False
    else:
        return True

def schema_map(df) -> Dict[str, str]:
    """Map each field name in ``df.schema`` to its ``simpleString()`` type."""
    mapping: Dict[str, str] = {}
    for field in df.schema.fields:
        mapping[field.name] = field.dataType.simpleString()
    return mapping

def count_nulls(df, cols: List[str]) -> Dict[str, int]:
    """Count NULL values in each of ``cols`` using a single Spark aggregation.

    Args:
        df: Spark DataFrame to inspect.
        cols: Column names that are required to be non-null.

    Returns:
        A dict keyed by column name (in the order given in ``cols``). The
        value is the number of rows where the column is NULL, or ``-1`` when
        the column is not present in ``df``.
    """
    from pyspark.sql import functions as F

    available = set(df.columns)  # build once instead of once per column
    out: Dict[str, int] = {c: -1 for c in cols}  # -1 marks a missing column

    # De-duplicate while keeping order so each column is aggregated once.
    present = [c for c in dict.fromkeys(cols) if c in available]
    if not present:
        return out

    # count() ignores NULLs, and when() without otherwise() yields NULL for
    # non-matching rows, so each expression counts exactly the NULL rows and
    # returns 0 (not NULL) on an empty DataFrame. All columns are evaluated
    # in one pass rather than one job per column.
    exprs = [
        F.count(F.when(F.col(c).isNull(), F.lit(1))).alias(f"nulls_{i}")
        for i, c in enumerate(present)
    ]
    totals = df.agg(*exprs).first()
    for i, c in enumerate(present):
        out[c] = int(totals[i])
    return out

def count_duplicates(df, pk_cols: List[str]) -> int:
    """Count primary-key values that occur on more than one row.

    Returns ``0`` when ``pk_cols`` is empty, ``-1`` when any key column is
    absent from ``df``, and otherwise the number of distinct key groups
    whose row count is greater than one.
    """
    if not pk_cols:
        return 0
    from pyspark.sql import functions as F

    if any(key not in df.columns for key in pk_cols):
        return -1  # missing pk col

    group_sizes = df.groupBy(*pk_cols).agg(F.count("*").alias("n"))
    repeated_groups = group_sizes.filter(F.col("n") > 1)
    return repeated_groups.count()

def distinct_pk_count(df, pk_cols: List[str]) -> int:
    """Count distinct primary-key combinations in ``df``.

    Returns ``0`` when ``pk_cols`` is empty and ``-1`` when any key column
    is absent from ``df``.
    """
    if not pk_cols:
        return 0
    available = df.columns
    if not all(key in available for key in pk_cols):
        return -1
    key_frame = df.select(*pk_cols)
    return key_frame.distinct().count()

def pct_drop(a: int, b: int) -> float:
    """Fractional decrease going from ``a`` to ``b``, never negative.

    Returns 0.0 when ``a`` is zero or negative, and also when ``b`` is not
    smaller than ``a`` (an increase is not reported as a drop).
    """
    if a <= 0:
        return 0.0
    change = (a - b) / float(a)
    return change if change > 0.0 else 0.0


In [None]:
# =========================
# RUN VALIDATION
# =========================

# Accumulators shared by every dataset: `results` feeds the summary table,
# `failures` collects human-readable messages for hard check violations.
results: List[Dict[str, Any]] = []
failures: List[str] = []

for ds in DATASETS:
    name = ds["name"]
    fmt = ds.get("format", DEFAULT_FORMAT)
    pk_cols = ds.get("pk_cols", [])
    not_null_cols = ds.get("not_null_cols", [])
    compare_pk = bool(ds.get("compare_pk", False))

    layer_info = {
        "bronze": ds.get("bronze_path"),
        "silver": ds.get("silver_path"),
        "gold": ds.get("gold_path"),
    }

    # Per-dataset state. It is created inside the loop so that metrics from
    # one dataset never carry over into the checks of the next one.
    dfs: Dict[str, Any] = {}
    counts: Dict[str, int] = {}
    schemas: Dict[str, Dict[str, str]] = {}
    nulls: Dict[str, Dict[str, int]] = {}
    dups: Dict[str, int] = {}
    distinct_pk: Dict[str, int] = {}

    # Load each layer if it exists
    for layer, path in layer_info.items():
        if not path:
            continue
        if not exists_path(path):
            results.append({
                "dataset": name,
                "layer": layer,
                "status": "MISSING_PATH",
                "path": path
            })
            continue
        try:
            df = load_df(path, fmt)
            layer_count = df.count()
            layer_schema = schema_map(df)
            layer_nulls = count_nulls(df, not_null_cols)
            layer_dups = count_duplicates(df, pk_cols)
            layer_distinct = distinct_pk_count(df, pk_cols) if compare_pk else 0
        except Exception as e:
            results.append({
                "dataset": name,
                "layer": layer,
                "status": "LOAD_ERROR",
                "path": path,
                "error": str(e)[:250]
            })
            continue

        # Register the layer only after every metric succeeded. A layer that
        # failed part-way must not look "loaded" to the checks below, which
        # index counts[...] and schemas[...] directly.
        dfs[layer] = df
        counts[layer] = layer_count
        schemas[layer] = layer_schema
        nulls[layer] = layer_nulls
        dups[layer] = layer_dups
        distinct_pk[layer] = layer_distinct

    # Summary rows per layer
    for layer in ["bronze", "silver", "gold"]:
        if layer not in dfs:
            continue
        row = {
            "dataset": name,
            "layer": layer,
            "status": "LOADED",
            "row_count": counts.get(layer, None),
            "dup_pk_groups": dups.get(layer, None),
            "distinct_pk": distinct_pk.get(layer, None),
            # -1 from count_nulls means the required column is absent.
            "missing_not_null_cols": ",".join([c for c, v in nulls.get(layer, {}).items() if v == -1]),
            "not_null_violations": ",".join([f"{c}:{v}" for c, v in nulls.get(layer, {}).items() if v not in (-1, 0)]),
        }
        results.append(row)

    # Cross-layer checks
    if "bronze" in dfs and "silver" in dfs:
        drop = pct_drop(counts["bronze"], counts["silver"])
        if drop > MAX_ROW_DROP_PCT:
            failures.append(
                f"{name}: bronze->silver row drop {drop:.1%} exceeds threshold {MAX_ROW_DROP_PCT:.1%}"
            )

        bronze_cols = set(schemas["bronze"].keys())
        silver_cols = set(schemas["silver"].keys())

        # Columns dropped between bronze and silver are reported as a
        # warning row only; they do not add to `failures`.
        missing_in_silver = sorted(bronze_cols - silver_cols)
        if missing_in_silver:
            results.append({
                "dataset": name,
                "layer": "bronze_vs_silver",
                "status": "SCHEMA_WARN",
                "detail": f"Missing in silver (sample): {missing_in_silver[:20]}" + (" ..." if len(missing_in_silver) > 20 else "")
            })

        common = sorted(bronze_cols & silver_cols)
        type_mismatches = [
            c for c in common
            if schemas["bronze"].get(c) != schemas["silver"].get(c)
        ]
        if type_mismatches:
            results.append({
                "dataset": name,
                "layer": "bronze_vs_silver",
                "status": "TYPE_WARN",
                "detail": f"Type mismatches (sample): {type_mismatches[:20]}" + (" ..." if len(type_mismatches) > 20 else "")
            })

        if compare_pk and pk_cols:
            b_pk = distinct_pk.get("bronze", 0)
            s_pk = distinct_pk.get("silver", 0)
            # -1 means a PK column is missing; that case is reported by the
            # duplicate check below, so skip the reconciliation here.
            if b_pk != -1 and s_pk != -1:
                pk_drop = pct_drop(b_pk, s_pk)
                if pk_drop > MAX_ROW_DROP_PCT:
                    failures.append(
                        f"{name}: bronze->silver distinct PK drop {pk_drop:.1%} exceeds threshold {MAX_ROW_DROP_PCT:.1%}"
                    )

    if "silver" in dfs and "gold" in dfs:
        # Gold may be aggregated; only apply row-drop check when compare_pk is True
        if compare_pk and pk_cols:
            drop = pct_drop(counts["silver"], counts["gold"])
            if drop > MAX_ROW_DROP_PCT:
                failures.append(
                    f"{name}: silver->gold row drop {drop:.1%} exceeds threshold {MAX_ROW_DROP_PCT:.1%}"
                )

        silver_cols = set(schemas["silver"].keys())
        gold_cols = set(schemas["gold"].keys())

        missing_in_gold = sorted(silver_cols - gold_cols)
        if missing_in_gold:
            results.append({
                "dataset": name,
                "layer": "silver_vs_gold",
                "status": "SCHEMA_WARN",
                "detail": f"Missing in gold (sample): {missing_in_gold[:20]}" + (" ..." if len(missing_in_gold) > 20 else "")
            })

    # Hard failures based on nulls/dups
    for layer in ["bronze", "silver", "gold"]:
        if layer not in dfs:
            continue

        d = dups.get(layer, 0)
        if d == -1:
            failures.append(f"{name}:{layer}: missing PK columns for duplicate check")
        elif d > 0:
            failures.append(f"{name}:{layer}: duplicate PK groups={d}")

        for c, v in nulls.get(layer, {}).items():
            if v == -1:
                failures.append(f"{name}:{layer}: missing required column {c}")
            elif v > 0:
                failures.append(f"{name}:{layer}: null violations {c}={v}")

print("Validation complete.")
print(f"Datasets checked: {len(DATASETS)}")
print(f"Issues found: {len(failures)}")


In [None]:
# =========================
# RESULTS SUMMARY
# =========================

from pyspark.sql import Row
from pyspark.sql import functions as F

# The dicts in `results` carry different keys depending on their status
# (MISSING_PATH, LOAD_ERROR, LOADED, SCHEMA_WARN, TYPE_WARN). Every entry is
# therefore projected onto one fixed column list, with None for keys it does
# not have, and the DataFrame is built with an explicit schema. This keeps
# the columns aligned across row kinds and guarantees that `dataset` and
# `layer` exist for the orderBy below even when `results` is empty.
result_fields = [
    ("dataset", "string"),
    ("layer", "string"),
    ("status", "string"),
    ("path", "string"),
    ("error", "string"),
    ("row_count", "long"),
    ("dup_pk_groups", "long"),
    ("distinct_pk", "long"),
    ("missing_not_null_cols", "string"),
    ("not_null_violations", "string"),
    ("detail", "string"),
]
result_schema = ", ".join(f"{col} {dtype}" for col, dtype in result_fields)

rows = [Row(**{col: r.get(col) for col, _ in result_fields}) for r in results]
df_results = spark.createDataFrame(rows, schema=result_schema)

_safe_display(df_results.orderBy(F.col("dataset"), F.col("layer")))

if failures:
    print("\nFAILURES:")
    for f in failures:
        print(f"- {f}")
else:
    print("\nNo failures detected.")


In [None]:
# =========================
# OPTIONAL: FAIL THE JOB
# =========================

# Abort only when at least one failure was recorded AND failing the
# pipeline has been switched on in the configuration.
if failures and FAIL_PIPELINE_ON_ERROR:
    raise Exception("Data quality validation failed. See failures list above.")

# To let ADF read the notebook output, exit with a value instead, e.g.:
# dbutils.notebook.exit("OK")
