# Data Quality Checks - Enterprise Data Platform

## Overview
This notebook performs comprehensive data quality checks on Gold layer tables.

**Checks Performed:**
- Referential integrity (FK validation)
- Null value analysis
- Business rule compliance
- Data distribution analysis
- Anomaly detection

**Prerequisites:**
- Gold star schema created (run 03_build_gold_star_schema.ipynb first)

In [None]:
# Import required libraries
# NOTE: the wildcard import of pyspark.sql.functions rebinds several Python
# builtins in this notebook's namespace. Later cells call `round`, `min` and
# `max` with column arguments and therefore depend on the Spark versions;
# any code that needs the Python builtins of the same name will not get them.
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from datetime import datetime
import json

# Print the start time; the summary cell prints the completion time.
print(f"Data Quality Checks Started: {datetime.now()}")

## Check 1: Referential Integrity

In [None]:
print("\n" + "="*80)
print("CHECK 1: Referential Integrity Validation")
print("="*80)

# Check 1 treats every fact-table column ending in `_id` / `_key` as a candidate
# foreign key, infers the dimension table it points at from naming conventions,
# and reports key values that have no matching row in that dimension.
# Outcomes are collected in `integrity_results` (one dict per attempted check),
# which the summary cell at the end of the notebook reads.

# Fail fast with actionable guidance when the notebook has no lakehouse attached.
try:
    current_db = spark.catalog.currentDatabase()
    print(f"\n‚úÖ Using database/lakehouse: {current_db}")
except Exception:
    print("\n‚ùå ERROR: No lakehouse attached!")
    print("Please attach a lakehouse to this notebook before running.")
    print("Steps:")
    print("  1. Click on the lakehouse icon in the left panel")
    print("  2. Select 'Add' and choose your lakehouse")
    print("  3. Re-run this cell")
    raise Exception("No lakehouse attached. Please attach a lakehouse and try again.") from None

# Get all Gold fact and dimension tables dynamically (single catalog listing)
try:
    catalog_table_names = [t.name for t in spark.catalog.listTables()]
except Exception as exc:
    print(f"\n‚ùå ERROR: Cannot list tables: {str(exc)}")
    print("Make sure you have attached a lakehouse with Gold layer tables.")
    raise

gold_facts = [name for name in catalog_table_names if name.startswith("gold_fact")]
gold_dims = [name for name in catalog_table_names if name.startswith(("gold_dim", "dim"))]

# Dimension DataFrame handles, loaded at most once per run of this cell.
_dim_frame_cache = {}


def _load_dim_frame(dim_table):
    """Return the DataFrame for `dim_table`, loading it on first use only."""
    if dim_table not in _dim_frame_cache:
        _dim_frame_cache[dim_table] = spark.table(dim_table)
    return _dim_frame_cache[dim_table]


def _candidate_dim_names(fk_col):
    """Return lower-case dimension table names `fk_col` may reference, best guess first.

    Only the trailing `_id` / `_key` suffix is removed; an `_id` / `_key`
    fragment in the middle of the column name is left intact.
    """
    if fk_col.endswith('_date_id'):
        return ['dimdate', 'gold_dimdate']
    for suffix in ('_id', '_key'):
        if fk_col.endswith(suffix):
            base_name = fk_col[:-len(suffix)]
            compact_name = base_name.replace("_", "")
            return [
                f'gold_dim{base_name}',
                f'dim{base_name}',
                f'gold_dim{compact_name}',
                f'dim{compact_name}',
            ]
    return []


def _candidate_pk_names(fk_col):
    """Return possible primary-key column names for `fk_col`, most specific first.

    A role-playing key such as `order_date_id` also yields its last two
    segments (`date_id`).
    """
    candidates = [fk_col]
    parts = fk_col.split('_')
    if len(parts) > 2:
        candidates.append('_'.join(parts[-2:]))
    return candidates


def _integrity_record(fact_table, fk_col, dim_table, pk_col=None,
                      orphan_count=None, total_distinct=None, error=None):
    """Build one `integrity_results` entry.

    `passed` is True only when the check completed and found no orphans.
    `error` carries the reason when the check could not be completed.
    """
    return {
        "fact_table": fact_table,
        "fk_column": fk_col,
        "dim_table": dim_table,
        "pk_column": pk_col,
        "orphan_count": orphan_count,
        "total_distinct": total_distinct,
        "passed": error is None and orphan_count == 0,
        "error": error,
    }


print(f"\nüìä Found {len(gold_facts)} fact tables and {len(gold_dims)} dimension tables to validate")

integrity_results = []

if not gold_facts:
    print("\n‚ö†Ô∏è  WARNING: No Gold fact tables found!")
    print("Please run notebook 03_build_gold_star_schema.ipynb first to create Gold tables.")
else:
    total_checks = 0
    passed_checks = 0
    failed_checks = 0

    # Case-insensitive lookup of dimension names; the first table wins on a clash.
    dims_by_lower_name = {}
    for dim_name in gold_dims:
        dims_by_lower_name.setdefault(dim_name.lower(), dim_name)

    # For each fact table, verify all foreign key relationships
    for fact_table_name in sorted(gold_facts):
        try:
            fact_df = spark.table(fact_table_name)

            print(f"\nüîç Validating {fact_table_name}...")

            # Find all potential foreign key columns (ending with _id or _key)
            fk_columns = [column_name for column_name in fact_df.columns
                          if column_name.endswith(('_id', '_key'))]

            if not fk_columns:
                print(f"   ‚ÑπÔ∏è  No foreign key columns found")
                continue

            for fk_col in fk_columns:
                # First try the naming conventions ...
                matching_dim = next(
                    (dims_by_lower_name[name] for name in _candidate_dim_names(fk_col)
                     if name in dims_by_lower_name),
                    None,
                )

                # ... then fall back to any dimension that has a column of the same name.
                if not matching_dim:
                    for dim_table in gold_dims:
                        if fk_col in _load_dim_frame(dim_table).columns:
                            matching_dim = dim_table
                            break

                if not matching_dim:
                    # FK column doesn't match any dimension table (might be a measure or non-FK column)
                    print(f"   ‚ÑπÔ∏è  {fk_col}: No matching dimension table found (might not be a FK)")
                    continue

                total_checks += 1
                pk_col = None
                try:
                    dim_df = _load_dim_frame(matching_dim)

                    # Find the primary key column in dimension
                    pk_col = next(
                        (candidate for candidate in _candidate_pk_names(fk_col)
                         if candidate in dim_df.columns),
                        None,
                    )

                    if not pk_col:
                        print(f"   ‚ö†Ô∏è  {fk_col} ‚Üí {matching_dim}: Cannot find PK column")
                        failed_checks += 1
                        # Record the failure so the final quality score accounts for it.
                        integrity_results.append(_integrity_record(
                            fact_table_name, fk_col, matching_dim,
                            error="Cannot find PK column"))
                        continue

                    # Find orphaned FKs (excluding nulls)
                    orphaned = fact_df.select(fk_col).distinct() \
                        .join(dim_df.select(pk_col),
                              fact_df[fk_col] == dim_df[pk_col],
                              "left_anti") \
                        .filter(col(fk_col).isNotNull())

                    orphan_count = orphaned.count()
                    total_distinct = fact_df.select(fk_col).filter(col(fk_col).isNotNull()).distinct().count()

                    status = "‚úÖ PASS" if orphan_count == 0 else "‚ùå FAIL"

                    print(f"   {status} {fk_col} ‚Üí {matching_dim}.{pk_col}")
                    print(f"        Orphaned: {orphan_count:,} / {total_distinct:,} distinct values")

                    if orphan_count == 0:
                        passed_checks += 1
                    else:
                        failed_checks += 1
                        # Show (at most five) sample orphaned values
                        if orphan_count <= 5:
                            print(f"        Orphaned values:")
                            orphaned.show(orphan_count, truncate=False)
                        else:
                            print(f"        Sample orphaned values:")
                            orphaned.show(5, truncate=False)

                    integrity_results.append(_integrity_record(
                        fact_table_name, fk_col, matching_dim, pk_col,
                        orphan_count=orphan_count, total_distinct=total_distinct))

                except Exception as exc:
                    print(f"   ‚ùå Error checking {fk_col} ‚Üí {matching_dim}: {str(exc)}")
                    failed_checks += 1
                    # Record the failure so the final quality score accounts for it.
                    integrity_results.append(_integrity_record(
                        fact_table_name, fk_col, matching_dim, pk_col,
                        error=str(exc)))

        except Exception as exc:
            print(f"‚ùå Error processing {fact_table_name}: {str(exc)}")

    # Summary
    print(f"\n{'='*80}")
    print("REFERENTIAL INTEGRITY VALIDATION SUMMARY")
    print(f"{'='*80}")
    print(f"Total relationship checks: {total_checks}")
    print(f"‚úÖ Passed: {passed_checks}")
    print(f"‚ùå Failed: {failed_checks}")

    if failed_checks == 0 and total_checks > 0:
        print("\nüéâ All referential integrity checks passed!")
    elif total_checks == 0:
        print("\n‚ö†Ô∏è  No relationships could be validated")
    else:
        print(f"\n‚ö†Ô∏è  {failed_checks} relationship(s) have orphaned records")
        print("   Review the failed checks above and consider cleaning data or updating dimension tables")

print(f"{'='*80}")

## Check 2: Null Value Analysis

In [None]:
print("\n" + "="*80)
print("CHECK 2: Null Value Analysis")
print("="*80)

# Check 2 reports, for every Gold dimension/fact table, each column that
# contains nulls together with its null count and percentage of rows.

# Get all Gold tables dynamically
all_gold_tables = spark.catalog.listTables()
tables_to_check = [t.name for t in all_gold_tables
                   if t.name.startswith(("gold_dim", "gold_fact"))]

if not tables_to_check:
    print("\n‚ö†Ô∏è  No Gold tables found to analyze")
else:
    print(f"\nüìä Analyzing {len(tables_to_check)} tables for null values\n")

    for table_name in sorted(tables_to_check):
        try:
            print(f"\n{table_name}:")
            print("-" * 80)

            df = spark.table(table_name)
            column_names = df.columns

            # One aggregation per table: the row count plus a null count for
            # every column, instead of a separate Spark job per column.
            # count() skips nulls and when() without otherwise() yields null
            # for non-matching rows, so each expression counts the null rows.
            # Results are read by position, so column names never have to be
            # valid (or unique) aliases.
            counts_row = df.agg(
                count("*"),
                *[count(when(col(c).isNull(), 1)) for c in column_names]
            ).collect()[0]
            total_rows = counts_row[0]

            # Keep only the columns that actually contain nulls
            null_stats = []
            for idx, col_name in enumerate(column_names, start=1):
                null_count = counts_row[idx]
                if null_count > 0:
                    null_stats.append({
                        "column": col_name,
                        "null_count": null_count,
                        # total_rows > 0 whenever null_count > 0; guard kept for safety
                        "null_percentage": (null_count / total_rows * 100) if total_rows > 0 else 0,
                    })

            if null_stats:
                print(f"  {'Column':<30s} | {'Null Count':>12} | {'Percentage':>10}")
                print("  " + "-" * 76)
                # Worst columns first; more than 10% nulls gets the warning marker
                for stat in sorted(null_stats, key=lambda x: x["null_percentage"], reverse=True):
                    status = "‚ö†Ô∏è" if stat["null_percentage"] > 10 else "‚ÑπÔ∏è"
                    print(f"  {status} {stat['column']:28s} | {stat['null_count']:>12,} | {stat['null_percentage']:>9.2f}%")
            else:
                print("  ‚úÖ No null values found in any column")

        except Exception as exc:
            print(f"  ‚è≠Ô∏è  Skipping: {str(exc)}")

## Check 3: Business Rule Validation

In [None]:
print("\n" + "="*80)
print("CHECK 3: Business Rule Validation")
print("="*80)

# Check 3 applies three generic rules to every Gold fact table:
#   1. numeric amount-like columns must be non-negative,
#   2. numeric quantity columns must be strictly positive,
#   3. known pairs of date keys must be in chronological order.
# A rule "fails" when at least one row violates it.

# Substrings that mark a column as a monetary amount (Rule 1)
_AMOUNT_KEYWORDS = ('amount', 'value', 'price', 'cost', 'revenue', 'total')

# (later column, earlier column, description) pairs checked by Rule 3.
# NOTE(review): the comparison is done on the *_date_id keys themselves, which
# presumably sort chronologically (e.g. yyyymmdd integers) — confirm against
# the date dimension.
_DATE_ORDER_RULES = (
    ('ship_date_id', 'order_date_id', 'Ship date should be >= Order date'),
    ('delivery_date_id', 'ship_date_id', 'Delivery date should be >= Ship date'),
    ('end_date_id', 'start_date_id', 'End date should be >= Start date'),
    ('close_date_id', 'create_date_id', 'Close date should be >= Create date'),
)


def _is_numeric_column(df, column_name):
    """Return True when `column_name` has a numeric Spark type checked by the rules.

    Decimal columns render as `decimal(p,s)`, so they are matched by prefix
    rather than by equality.
    """
    type_name = df.schema[column_name].dataType.simpleString()
    return type_name in ('double', 'float', 'int', 'bigint') or type_name.startswith('decimal')


# Get all Gold fact tables
gold_facts = [t.name for t in spark.catalog.listTables()
              if t.name.startswith("gold_fact")]

if not gold_facts:
    print("\n‚ö†Ô∏è  No Gold fact tables found for business rule validation")
else:
    print(f"\nüìä Validating business rules for {len(gold_facts)} fact table(s)\n")

    business_rule_violations = 0
    business_rule_checks = 0

    for fact_table in sorted(gold_facts):
        try:
            df = spark.table(fact_table)
            columns = df.columns
            total_rows = df.count()

            print(f"\nüîç Validating {fact_table}:")
            print("-" * 80)

            # Each rule is (description, filter selecting violating rows, error label)
            rules = []

            # Rule 1: Numeric amount/value columns should be non-negative
            amount_cols = [c for c in columns
                           if any(keyword in c.lower() for keyword in _AMOUNT_KEYWORDS)
                           and _is_numeric_column(df, c)]
            for amt_col in amount_cols:
                rules.append((f"Non-negative values in '{amt_col}'",
                              col(amt_col) < 0,
                              f"Error checking {amt_col}"))

            # Rule 2: Quantity columns should be positive.
            # The parentheses matter: the type check must apply to both name patterns.
            qty_cols = [c for c in columns
                        if ('quantity' in c.lower() or 'qty' in c.lower())
                        and _is_numeric_column(df, c)]
            for qty_col in qty_cols:
                rules.append((f"Positive values in '{qty_col}'",
                              col(qty_col) <= 0,
                              f"Error checking {qty_col}"))

            # Rule 3: Date logic - ship_date >= order_date, end_date >= start_date, etc.
            # Rows where either side is null are not treated as violations.
            for later_date, earlier_date, rule_desc in _DATE_ORDER_RULES:
                if later_date in columns and earlier_date in columns:
                    rules.append((rule_desc,
                                  (col(later_date).isNotNull()) &
                                  (col(earlier_date).isNotNull()) &
                                  (col(later_date) < col(earlier_date)),
                                  "Error checking date rule"))

            table_has_rules = bool(rules)

            for rule_label, violation_filter, error_label in rules:
                try:
                    violation_count = df.filter(violation_filter).count()
                except Exception as exc:
                    # A rule that could not be evaluated is neither a pass nor a fail.
                    print(f"  ‚ö†Ô∏è  {error_label}: {str(exc)}")
                    continue

                business_rule_checks += 1
                status = "‚úÖ PASS" if violation_count == 0 else "‚ùå FAIL"

                print(f"  {status} {rule_label}")
                print(f"       Violations: {violation_count:,} / {total_rows:,} rows")

                if violation_count > 0:
                    business_rule_violations += 1

            if not table_has_rules:
                print(f"  ‚ÑπÔ∏è  No standard business rules applicable to this table")

        except Exception as exc:
            print(f"\n‚ùå Error processing {fact_table}: {str(exc)}")

    # Summary
    print("\n" + "="*80)
    print("BUSINESS RULE VALIDATION SUMMARY")
    print("="*80)
    print(f"Total business rule checks: {business_rule_checks}")
    print(f"‚úÖ Passed: {business_rule_checks - business_rule_violations}")
    print(f"‚ùå Failed: {business_rule_violations}")

    if business_rule_violations == 0 and business_rule_checks > 0:
        print("\nüéâ All business rules validated successfully!")
    elif business_rule_checks == 0:
        print("\n‚ö†Ô∏è  No business rules were validated")
    else:
        print(f"\n‚ö†Ô∏è  {business_rule_violations} business rule violation(s) found")

    print("="*80)

## Check 4: Data Distribution Analysis

In [None]:
print("\n" + "="*80)
print("CHECK 4: Data Distribution Analysis")
print("="*80)

# Fact tables are the catalog entries whose name starts with "gold_fact"
gold_facts = [tbl.name for tbl in spark.catalog.listTables()
              if tbl.name.startswith("gold_fact")]

if gold_facts:
    print(f"\nüìä Analyzing data distribution for {len(gold_facts)} fact table(s)\n")

    # One report per fact table, in alphabetical order
    for fact_table in sorted(gold_facts):
        try:
            print(f"\n{fact_table} - Record Count by Date:")
            print("-" * 80)

            df = spark.table(fact_table)

            # Columns ending in "_date_id" are treated as date keys
            date_cols = [name for name in df.columns if name.endswith('_date_id')]

            if not date_cols:
                # Without a date key only the overall row count is reported
                total = df.count()
                print(f"  Total records: {total:,}")
                print(f"  (No date columns found for temporal distribution)")
            else:
                # Only the first date key is used for the breakdown
                date_col = date_cols[0]

                # The first six characters of the key (cast to string) form the bucket
                year_month_expr = substring(col(date_col).cast("string"), 1, 6)
                monthly_dist = (
                    df.withColumn("year_month", year_month_expr)
                    .groupBy("year_month")
                    .agg(count("*").alias("record_count"))
                    .orderBy("year_month")
                )

                print(f"  Based on column: {date_col}")
                monthly_dist.show(12, truncate=False)

            # Optional breakdown by the "status" column, largest group first
            if "status" in df.columns:
                print(f"\n{fact_table} - Distribution by Status:")
                print("-" * 80)

                status_dist = (
                    df.groupBy("status")
                    .agg(count("*").alias("count"))
                    .withColumn("percentage",
                                round(col("count") / df.count() * 100, 2))
                    .orderBy(desc("count"))
                )

                status_dist.show(truncate=False)

        except Exception as e:
            print(f"  ‚è≠Ô∏è  Skipping {fact_table}: {str(e)}")
else:
    print("\n‚ö†Ô∏è  No Gold fact tables found for distribution analysis")

# Analyze dimension tables for categorical distributions.
# Dimension tables are catalog entries named "gold_dim*" or "dim*".
gold_dims = [t.name for t in spark.catalog.listTables()
             if t.name.startswith("gold_dim") or
                (t.name.startswith("dim") and not t.name.startswith("gold_"))]

if gold_dims:
    print(f"\n\nüìä Analyzing categorical distributions for {len(gold_dims)} dimension table(s)\n")

    for dim_table in sorted(gold_dims)[:3]:  # Show first 3 dimensions to avoid too much output
        try:
            df = spark.table(dim_table)

            # Find categorical columns (string-typed, non-ID, non-date columns)
            categorical_cols = [c for c in df.columns
                                if not c.endswith('_id') and not c.endswith('_key')
                                and 'date' not in c.lower()
                                and df.schema[c].dataType.simpleString() == 'string']

            if categorical_cols:
                # Show distribution for first categorical column
                col_to_analyze = categorical_cols[0]

                print(f"\n{dim_table} - Distribution by {col_to_analyze}:")
                print("-" * 80)

                # Row count and share of the table per distinct value, largest first
                dist = df.groupBy(col_to_analyze) \
                    .agg(count("*").alias("count")) \
                    .withColumn("percentage",
                                round(col("count") / df.count() * 100, 2)) \
                    .orderBy(desc("count"))

                dist.show(10, truncate=False)

        except Exception as e:
            print(f"  ‚è≠Ô∏è  Skipping {dim_table}: {str(e)}")

## Check 5: Anomaly Detection

In [None]:
print("\n" + "="*80)
print("CHECK 5: Anomaly Detection")
print("="*80)

# Check 5 flags, for the first numeric amount-like column of each Gold fact
# table, the rows whose value exceeds mean + 3 standard deviations.
# Only the high side is examined; unusually low values are not reported.


def _outlier_candidate_columns(df):
    """Return the numeric amount-like columns of `df`, in schema order.

    Decimal columns render as `decimal(p,s)`, so they are matched by prefix
    rather than by equality.
    """
    keywords = ('amount', 'value', 'price', 'cost', 'revenue')
    selected = []
    for field in df.schema.fields:
        type_name = field.dataType.simpleString()
        is_numeric = (type_name in ('double', 'float', 'int', 'bigint')
                      or type_name.startswith('decimal'))
        if is_numeric and any(keyword in field.name.lower() for keyword in keywords):
            selected.append(field.name)
    return selected


# Get all Gold fact tables
gold_facts = [t.name for t in spark.catalog.listTables()
              if t.name.startswith("gold_fact")]

if not gold_facts:
    print("\n‚ö†Ô∏è  No Gold fact tables found for anomaly detection")
else:
    print(f"\nüìä Detecting outliers in {len(gold_facts)} fact table(s) (>3œÉ from mean)\n")

    # Analyze each fact table
    for fact_table in sorted(gold_facts):
        try:
            df = spark.table(fact_table)

            # Find numeric amount/value columns
            numeric_cols = _outlier_candidate_columns(df)

            if not numeric_cols:
                print(f"\n{fact_table}:")
                print(f"  ‚ÑπÔ∏è  No numeric amount columns found for outlier detection\n")
                continue

            # Analyze first numeric column found
            amount_col = numeric_cols[0]

            print(f"\n{fact_table} - Outlier Detection on '{amount_col}':")
            print("-" * 80)

            # Calculate statistics in a single pass.
            # `min` / `max` here are the Spark aggregates brought in by the
            # wildcard import, not the Python builtins.
            stats = df.select(
                mean(amount_col).alias("mean_amount"),
                stddev(amount_col).alias("stddev_amount"),
                min(amount_col).alias("min_amount"),
                max(amount_col).alias("max_amount")
            ).collect()[0]

            mean_val = stats["mean_amount"]
            stddev_val = stats["stddev_amount"]
            min_val = stats["min_amount"]
            max_val = stats["max_amount"]

            if mean_val is None or stddev_val is None:
                print(f"  ‚ö†Ô∏è  Cannot calculate statistics (null values)\n")
                continue

            # Decimal columns come back as decimal.Decimal, which cannot be
            # mixed with the float standard deviation in arithmetic.
            mean_val = float(mean_val)
            stddev_val = float(stddev_val)

            threshold = mean_val + (3 * stddev_val)

            # Find outliers, largest first
            outliers = df.filter(col(amount_col) > threshold) \
                .orderBy(desc(amount_col))

            outlier_count = outliers.count()
            total_count = df.count()

            print(f"  Mean: ${mean_val:,.2f}")
            print(f"  Std dev: ${stddev_val:,.2f}")
            print(f"  Range: ${min_val:,.2f} to ${max_val:,.2f}")
            print(f"  Outlier threshold (>3œÉ): ${threshold:,.2f}")
            print(f"  Outliers found: {outlier_count:,} / {total_count:,} records ({outlier_count/total_count*100:.2f}%)")

            if outlier_count > 0:
                print(f"\n  Top 5 outliers:")
                # Show key columns if they exist
                cols_to_show = [c for c in df.columns if c in ['order_id', 'transaction_id', 'id'] + [amount_col]]
                if not cols_to_show:
                    cols_to_show = df.columns[:5]  # Show first 5 columns

                outliers.select(*cols_to_show).show(5, truncate=False)
            else:
                print(f"  ‚úÖ No outliers detected\n")

        except Exception as e:
            print(f"\n{fact_table}:")
            print(f"  ‚è≠Ô∏è  Skipping: {str(e)}\n")

## Quality Report Summary

In [None]:
print("\n" + "="*80)
print("DATA QUALITY REPORT - SUMMARY")
print("="*80)

# Compile overall score.
# Only the referential-integrity results (`integrity_results`, produced by the
# Check 1 cell) contribute to the score; that cell must have been run first.
#
# The passed count is computed with len() rather than sum(): the wildcard
# import of pyspark.sql.functions rebinds `sum` to the Spark aggregate, which
# does not accept a Python generator.
total_checks = len(integrity_results)
checks_passed = len([result for result in integrity_results if result["passed"]])

print(f"\n‚úÖ Checks Passed: {checks_passed}")
print(f"‚ö†Ô∏è  Checks Failed: {total_checks - checks_passed}")
print(f"üìä Total Checks: {total_checks}")

# The score is only meaningful when at least one check ran
if total_checks > 0:
    quality_score = (checks_passed / total_checks) * 100
    print(f"\nüéØ Data Quality Score: {quality_score:.1f}%")

    # Grade bands: >= 90 excellent, >= 75 good, otherwise needs improvement
    if quality_score >= 90:
        print("   ‚úÖ EXCELLENT - Data is production-ready")
    elif quality_score >= 75:
        print("   ‚ö†Ô∏è  GOOD - Minor issues to address")
    else:
        print("   ‚ùå NEEDS IMPROVEMENT - Review failed checks")

print(f"\nCompletion Time: {datetime.now()}")
print("="*80)