In [0]:
%sh
# Print a fixed greeting from a shell cell.
echo "Hello from shell"

Hello from shell


In [0]:
# Data Quality Monitoring Framework - Data Profiling and Validation Component
# This notebook analyzes data to create quality profiles, remediate issues, and validate results

# Import necessary libraries
from pyspark.sql import SparkSession
from pyspark.sql.functions import count, countDistinct, avg, min, max, col, lit, when, isnan, udf, expr, array, collect_list, struct
from pyspark.sql.types import *
import json
import datetime
import uuid
import pandas as pd
import urllib.parse
from sqlalchemy import create_engine

# Create input widgets for parameters
dbutils.widgets.text("dataset_name", "", "Dataset Name")
dbutils.widgets.text("run_mode", "manual", "Run Mode (manual/pipeline)")
dbutils.widgets.text("fileName", "", "File Name")
dbutils.widgets.text("run_id", "", "Run ID (optional)")

# Get parameter values. Each widget is read once; surrounding whitespace is
# stripped so a blank-looking value is treated the same as an empty one.
dataset_name = (dbutils.widgets.get("dataset_name") or "").strip()
run_mode = dbutils.widgets.get("run_mode")
run_id_param = (dbutils.widgets.get("run_id") or "").strip()

# Fall back to the file name when no dataset name was supplied.
if not dataset_name:
    file_name = (dbutils.widgets.get("fileName") or "").strip()
    # Assuming file name format is "<dataset>_dirty.csv": drop that marker only
    # when it is really the suffix, never from the middle of the name.
    dirty_suffix = "_dirty.csv"
    if file_name.endswith(dirty_suffix):
        dataset_name = file_name[: -len(dirty_suffix)]
    else:
        dataset_name = file_name

# Display information about the environment
print("Starting Data Quality Profiling and Validation")
print(f"Current date: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"Processing dataset: {dataset_name}")
print(f"Run mode: {run_mode}")

# Create a unique run ID for this execution or use provided one
run_id = run_id_param if run_id_param else datetime.datetime.now().strftime("%Y%m%d%H%M%S")
print(f"Run ID: {run_id}")

# Set up Azure Blob Storage Configuration
# These names are combined into wasbs:// URLs of the form
# wasbs://<container>@<storage_account_name>.blob.core.windows.net/<file>.
storage_account_name = "<storage_account_name>"  # Replace with your storage account name
raw_container = "raw-data"
profiled_container = "profiled-data1"
remediated_container = "remediated-data1"
certified_container = "certified-data"  # Create this container if it doesn't exist

# Configure storage access using account key
# NOTE(review): the key is a literal in the notebook source. Consider reading it
# from a secret store (e.g. dbutils.secrets.get) instead — scope/key names are
# not visible here, confirm with the workspace owner.
storage_account_key = "<your_storage_account_key>" 
spark.conf.set(f"fs.azure.account.key.{storage_account_name}.blob.core.windows.net", storage_account_key)

# Print confirmation
print(f"Storage account configuration for {storage_account_name} completed")

# SQL Database Connection Configuration
# NOTE(review): same concern as the storage key — these are literal credentials.
server = "<server_name>"
database = "<database_name>"
username = "<username>"
password = "<password>"

# JDBC URL for SQL Server
# The user and password are embedded in the URL and repeated in
# connection_properties below, so the URL itself is sensitive if it is logged.
jdbc_url = f"jdbc:sqlserver://{server}:1433;database={database};user={username};password={password}"
connection_properties = {
    "user": username,
    "password": password,
    "driver": "com.microsoft.sqlserver.jdbc.SQLServerDriver"
}

# No connection is opened here; this only builds the settings.
print("Database connection configured using JDBC")

###################
# PART 1: PROFILING
###################

print("\n=== PHASE 1: DATA PROFILING ===\n")

# Construct the file path using the wasbs protocol and dataset name parameter
raw_container_root = f"wasbs://{raw_container}@{storage_account_name}.blob.core.windows.net/"
dirty_file_name = f"{dataset_name}_dirty.csv"
file_path = f"{raw_container_root}{dirty_file_name}"

# Load the raw data
try:
    # List the container once. A failure to list (for example missing list
    # permission) is only a warning: the read below is still attempted.
    try:
        available_files = [f.name for f in dbutils.fs.ls(raw_container_root)]
    except Exception as e:
        print(f"Error checking file existence: {str(e)}")
        available_files = None

    # A file that is definitely missing must stop the load. This check sits
    # outside the listing try/except so the FileNotFoundError is not swallowed
    # there and reaches the outer handler, which exits the notebook.
    if available_files is not None and dirty_file_name not in available_files:
        print(f"Warning: File {dataset_name}_dirty.csv not found. Attempting to list available files...")
        for available_name in available_files:
            print(available_name)
        raise FileNotFoundError(f"File {dataset_name}_dirty.csv not found")

    # Read the CSV data
    df = spark.read.format("csv") \
        .option("header", "true") \
        .option("inferSchema", "true") \
        .load(file_path)

    # Show sample data
    print(f"Successfully loaded data from {file_path}")
    row_count = df.count()
    col_count = len(df.columns)
    print(f"Number of rows: {row_count}")
    print(f"Number of columns: {col_count}")
    print("\nSample data:")
    df.show(5)

except Exception as e:
    print(f"Error loading data: {str(e)}")
    dbutils.notebook.exit(f"Failed to load dataset: {dataset_name}")
    
# Normalize column names (replace dots with underscores)
def normalize_column_names(df):
    """
    Create a new DataFrame with normalized column names (dots replaced with underscores)

    When replacing dots would produce a name that is already in use (for example
    both "a.b" and "a_b" exist), the renamed column gets a numeric suffix
    ("a_b_1", "a_b_2", ...) so every output column name stays unique and both
    mappings remain one-to-one. Columns without dots always keep their name.

    Parameters:
    - df: any object exposing `columns` and `selectExpr(*expressions)`
      (a Spark DataFrame in this notebook)

    Returns:
    - The normalized DataFrame
    - A mapping from original to normalized names
    - A mapping from normalized to original names
    """
    # Names that must never be taken over by a renamed column.
    dot_free_names = {name for name in df.columns if "." not in name}
    assigned_names = set()

    normalized_columns = {}
    reverse_mapping = {}
    select_expressions = []

    for col_name in df.columns:
        normalized_name = col_name.replace(".", "_")

        # Only columns that are actually being renamed can collide.
        if normalized_name != col_name:
            base_name = normalized_name
            suffix = 1
            while normalized_name in dot_free_names or normalized_name in assigned_names:
                normalized_name = f"{base_name}_{suffix}"
                suffix += 1

        assigned_names.add(normalized_name)
        normalized_columns[col_name] = normalized_name
        reverse_mapping[normalized_name] = col_name
        # Backticks keep the dotted source name from being parsed as a nested field.
        select_expressions.append(f"`{col_name}` AS `{normalized_name}`")

    normalized_df = df.selectExpr(*select_expressions)

    # Print changed column names for transparency
    changed_columns = [f"{orig} → {norm}" for orig, norm in normalized_columns.items() if orig != norm]
    if changed_columns:
        print(f"Normalized {len(changed_columns)} column names:")
        for change in changed_columns:
            print(f"  - {change}")

    return normalized_df, normalized_columns, reverse_mapping

# Swap the loaded frame for its dot-free version and keep both name mappings.
df, column_mapping, reverse_mapping = normalize_column_names(df)

# Expose the frame to Spark SQL under a dataset-specific view name.
profiling_view_name = f"{dataset_name}_data"
df.createOrReplaceTempView(profiling_view_name)

# Build the column inventory ({"name", "type"} per column) and echo it.
print("\nColumn Names and Types:")
column_info = [{"name": column_name, "type": dtype} for column_name, dtype in df.dtypes]
for col_info in column_info:
    print(f"- {col_info['name']}: {col_info['type']}")

# Row count used by the profiling queries and the quality metrics.
total_rows = df.count()
print(f"\nTotal rows in dataset: {total_rows}")

# Generate column profiles
print("\nGenerating column profiles...")
column_profiles = {}


def _is_numeric_type(spark_type):
    """Return True for dtype strings that should get min/max/avg statistics.

    Decimal dtypes are reported with precision and scale (e.g. "decimal(10,2)"),
    so they are matched by prefix as well as by the bare name.
    """
    return spark_type in ('int', 'double', 'float', 'bigint', 'decimal') or spark_type.startswith('decimal(')


def _float_or_default(value, default=0.0):
    """Convert an aggregate to float; a NULL aggregate (empty or all-NULL input) yields `default`."""
    return default if value is None else float(value)


def _json_safe_number(value):
    """Pass None, int and float through unchanged; convert any other numeric object to float."""
    if value is None or isinstance(value, (int, float)):
        return value
    return float(value)


# Process each column
for col_info in column_info:
    column_name = col_info["name"]
    col_type = col_info["type"]
    is_string = col_type == "string"
    is_numeric = _is_numeric_type(col_type)
    # Backticks for column names with special characters
    quoted_column = f"`{column_name}`"

    print(f"Profiling column: {column_name}")

    try:
        # Assemble one aggregate query per column from shared fragments, using
        # the dataset-specific view name. The column name and type are taken
        # from Python, so they are not selected as SQL string literals.
        select_items = [
            "COUNT(*) as total_count",
            f"COUNT(CASE WHEN {quoted_column} IS NULL THEN 1 END) as null_count",
            f"CAST(COUNT(CASE WHEN {quoted_column} IS NULL THEN 1 END) * 100.0 / COUNT(*) AS DOUBLE) as null_percentage",
        ]

        # For string columns, also count empty strings
        if is_string:
            select_items += [
                f"COUNT(CASE WHEN {quoted_column} = '' THEN 1 END) as empty_count",
                f"CAST(COUNT(CASE WHEN {quoted_column} = '' THEN 1 END) * 100.0 / COUNT(*) AS DOUBLE) as empty_percentage",
            ]

        if is_string or is_numeric:
            select_items += [
                f"COUNT(DISTINCT {quoted_column}) as distinct_count",
                f"CAST(COUNT(DISTINCT {quoted_column}) * 100.0 / COUNT(*) AS DOUBLE) as distinct_percentage",
            ]

        # For numeric columns, also compute min, max, avg
        if is_numeric:
            select_items += [
                f"MIN({quoted_column}) as min_val",
                f"MAX({quoted_column}) as max_val",
                f"AVG({quoted_column}) as avg_val",
            ]

        sql_query = f"SELECT {', '.join(select_items)} FROM {dataset_name}_data"

        # Run the query
        profile_result = spark.sql(sql_query)
        profile_data = profile_result.collect()[0].asDict()

        # Convert to our standard profile format. Percentages are NULL when the
        # view is empty; report 0.0 instead of failing the whole column.
        profile = {
            "column_name": column_name,
            "data_type": col_type,
            "total_count": profile_data["total_count"],
            "null_count": profile_data["null_count"],
            "null_percentage": _float_or_default(profile_data["null_percentage"])
        }

        # Add string-specific metrics
        if is_string:
            profile["empty_count"] = profile_data["empty_count"]
            profile["empty_percentage"] = _float_or_default(profile_data["empty_percentage"])

        if is_string or is_numeric:
            profile["distinct_count"] = profile_data["distinct_count"]
            profile["distinct_percentage"] = _float_or_default(profile_data["distinct_percentage"])

        # Add numeric-specific metrics. MIN/MAX/AVG are NULL for an all-NULL
        # column; keep them as None rather than raising on float(None).
        if is_numeric:
            profile["min"] = _json_safe_number(profile_data["min_val"])
            profile["max"] = _json_safe_number(profile_data["max_val"])
            profile["avg"] = _float_or_default(profile_data["avg_val"], default=None)

        # Get top values for both string and numeric columns
        try:
            value_counts_query = f"""
            SELECT 
              `{column_name}` as value,
              COUNT(*) as count,
              CAST(COUNT(*) * 100.0 / {total_rows} AS DOUBLE) as percentage
            FROM {dataset_name}_data
            WHERE `{column_name}` IS NOT NULL
            GROUP BY `{column_name}`
            ORDER BY COUNT(*) DESC
            LIMIT 10
            """

            value_counts_result = spark.sql(value_counts_query)

            # Convert to our standard format
            top_values = []
            for row in value_counts_result.collect():
                row_dict = row.asDict()
                top_values.append({
                    "value": str(row_dict["value"]),
                    "count": row_dict["count"],
                    "percentage": float(row_dict["percentage"])
                })

            profile["top_values"] = top_values
        except Exception as e:
            print(f"  Warning: Could not get top values: {str(e)}")
            profile["top_values"] = []

        # Store profile in our dictionary
        column_profiles[column_name] = profile

        # Print a summary of the profile
        print(f"  - Data type: {profile['data_type']}")
        print(f"  - Null count: {profile['null_count']} ({profile['null_percentage']:.2f}%)")
        if profile.get('empty_count') is not None:
            print(f"  - Empty count: {profile['empty_count']} ({profile['empty_percentage']:.2f}%)")
        if profile.get('distinct_count') is not None:
            print(f"  - Distinct count: {profile['distinct_count']} ({profile['distinct_percentage']:.2f}%)")

        # For numeric columns, show min/max/avg (skipped when the column has no values)
        if profile.get('min') is not None and profile.get('avg') is not None:
            print(f"  - Min: {profile['min']}, Max: {profile['max']}, Avg: {profile['avg']:.2f}")

        # Show top values (if available)
        if len(profile.get('top_values', [])) > 0:
            print("  - Top values:")
            for i, val in enumerate(profile['top_values'][:3]):
                print(f"    {i+1}. {val['value']} ({val['count']} occurrences, {val['percentage']:.2f}%)")

        print("") # Empty line for readability

    except Exception as e:
        print(f"Error profiling column {column_name}: {str(e)}")

# Calculate overall quality metrics
total_columns = len(df.columns)
columns_with_nulls = sum(1 for p in column_profiles.values() if p["null_count"] > 0)
columns_with_high_nulls = sum(1 for p in column_profiles.values() if p["null_percentage"] > 10)

total_null_values = sum(p["null_count"] for p in column_profiles.values())
total_cells = total_rows * total_columns
completeness_score = 100 - (total_null_values / total_cells * 100) if total_cells > 0 else 0

# Columns whose profiling failed have no entry in column_profiles and therefore
# contribute nothing to the figures above; make that visible to the reader.
unprofiled_columns = [name for name in df.columns if name not in column_profiles]
if unprofiled_columns:
    print(f"\nWarning: {len(unprofiled_columns)} column(s) were not profiled and are excluded from the metrics: {unprofiled_columns}")


def _share_of_columns(column_count):
    """Return column_count as a percentage of total_columns, or 0.0 when there are no columns."""
    return column_count / total_columns * 100 if total_columns > 0 else 0.0


# Display overall quality metrics
print("\nOverall Quality Metrics:")
print(f"Completeness Score: {completeness_score:.2f}%")
print(f"Columns with Nulls: {columns_with_nulls} of {total_columns} ({_share_of_columns(columns_with_nulls):.2f}%)")
print(f"Columns with >10% Nulls: {columns_with_high_nulls} of {total_columns} ({_share_of_columns(columns_with_high_nulls):.2f}%)")

# Add a timestamp column to track when profiling was done
df_with_timestamp = df.withColumn("profiling_timestamp", lit(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))

# Save profiled data to the profiled-data container with dataset name in path
try:
    output_path = f"wasbs://{profiled_container}@{storage_account_name}.blob.core.windows.net/{dataset_name}_profiled_{run_id}.csv"
    
    df_with_timestamp.write \
        .format("csv") \
        .option("header", "true") \
        .mode("overwrite") \
        .save(output_path)
    
    print(f"\nSuccessfully saved profiled data to {output_path}")
except Exception as e:
    print(f"\nError saving profiled data: {str(e)}")

# Save profiling report to profiled-data container as JSON
try:
    # Create profile report
    profile_report = {
        "dataset_name": dataset_name,
        "run_id": run_id,
        "profile_date": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "total_rows": total_rows,
        "total_columns": total_columns,
        "total_cells": total_cells,
        "total_null_values": total_null_values,
        "completeness_score": completeness_score,
        "columns_with_nulls": columns_with_nulls,
        "columns_with_high_nulls": columns_with_high_nulls,
        "column_profiles": column_profiles
    }
    
    # Convert to JSON. default=str keeps the report writable if a statistic
    # turns out not to be a JSON-native type.
    profile_json = json.dumps(profile_report, default=str)
    
    # Save to a DataFrame and write to blob storage
    profile_pdf = pd.DataFrame({"profile_json": [profile_json]})
    profile_spark_df = spark.createDataFrame(profile_pdf)
    
    profile_report_path = f"wasbs://{profiled_container}@{storage_account_name}.blob.core.windows.net/{dataset_name}_profile_report_{run_id}.json"
    # Overwrite like the CSV writer above, so a rerun with the same run_id
    # replaces the earlier report instead of failing on the existing path.
    profile_spark_df.write.mode("overwrite").text(profile_report_path)
    
    print(f"Successfully saved profile report to {profile_report_path}")
except Exception as e:
    print(f"Error saving profile report: {str(e)}")

#########################
# PART 2: RULES ENGINE
#########################

print("\n=== PHASE 2: RULES ENGINE ===\n")


def _generic_rule(rule_id, rule_name, description, rule_type, severity, sql_condition, remediation_action):
    """Build one auto-remediable, dataset-wide rule template (column "*")."""
    return {
        "rule_id": rule_id,
        "rule_name": rule_name,
        "description": description,
        "rule_type": rule_type,
        "dataset": dataset_name,
        "column": "*",
        "severity": severity,
        "sql_condition": sql_condition,
        "can_auto_remediate": True,
        "remediation_action": remediation_action,
    }


# Generic rule templates. The "*" column and the COLUMN_PLACEHOLDER token in
# each condition stand in for a concrete column name.
rules = [
    # COMPLETENESS: missing values represented as '?'
    _generic_rule(
        "R001",
        "GenericMissingValue",
        "Columns should not contain missing values (marked as '?')",
        "COMPLETENESS",
        "HIGH",
        "COLUMN_PLACEHOLDER = '?'",
        {"type": "REPLACE", "value": "Unknown"},
    ),
    # FORMAT: numeric range validation
    _generic_rule(
        "R002",
        "NumericRange",
        "Numeric values should be within a reasonable range",
        "FORMAT",
        "MEDIUM",
        "COLUMN_PLACEHOLDER < 0 OR COLUMN_PLACEHOLDER > 1000000",
        {"type": "TRUNCATE", "min": 0, "max": 1000000},
    ),
    # CONSISTENCY: case standardization (mappings start empty)
    _generic_rule(
        "R003",
        "CaseConsistency",
        "Categorical values should have consistent casing",
        "CONSISTENCY",
        "LOW",
        "COLUMN_PLACEHOLDER != UPPER(COLUMN_PLACEHOLDER) AND COLUMN_PLACEHOLDER != LOWER(COLUMN_PLACEHOLDER)",
        {"type": "STANDARDIZE", "mappings": {}},
    ),
]

# Instantiate specific rules for this dataset based on column types
instantiated_rules = []
rule_counter = 1

# Values that might indicate missing data: '?', 'NA', 'null', 'UNKNOWN', etc.
missing_indicators = ['?', 'NA', 'N/A', 'null', 'NULL', 'unknown', 'UNKNOWN', '']
# None of the indicators contains a quote, so they can be inlined as literals.
missing_indicator_literals = ", ".join(f"'{indicator}'" for indicator in missing_indicators)

# Process string columns for missing values and numeric columns for ranges
for col_info in column_info:
    column_name = col_info["name"]
    col_type = col_info["type"]
    
    # Create missing value rules for string columns
    if col_type == "string":
        # One grouped scan per column instead of one scan per indicator. The
        # result is kept under its own name so the `count` function imported
        # from pyspark.sql.functions is not overwritten.
        indicator_query = f"""
        SELECT `{column_name}` AS value, COUNT(*) AS occurrences
        FROM {dataset_name}_data
        WHERE `{column_name}` IN ({missing_indicator_literals})
        GROUP BY `{column_name}`
        """
        indicator_counts = {
            row["value"]: row["occurrences"]
            for row in spark.sql(indicator_query).collect()
        }
        
        # Emit rules in the indicator list's order so rule ids stay stable
        for indicator in missing_indicators:
            if indicator_counts.get(indicator, 0) > 0:
                # Create a completeness rule for this column
                rule_id = f"R{rule_counter:03d}"
                rule_counter += 1
                
                completeness_rule = {
                    "rule_id": rule_id,
                    "rule_name": f"{column_name}_NotMissing",
                    "description": f"{column_name} should not contain missing values (marked as '{indicator}')",
                    "rule_type": "COMPLETENESS",
                    "dataset": dataset_name,
                    "column": column_name,
                    "severity": "HIGH",
                    "sql_condition": f"`{column_name}` = '{indicator}'",
                    "can_auto_remediate": True,
                    "remediation_action": {
                        "type": "REPLACE",
                        "value": "Unknown"
                    }
                }
                instantiated_rules.append(completeness_rule)
    
    # Create range validation rules for numeric columns
    elif col_type in ('int', 'double', 'float', 'bigint', 'decimal'):
        # Get actual min/max values to set reasonable bounds
        stats_query = f"""
        SELECT MIN(`{column_name}`) as min_val, MAX(`{column_name}`) as max_val
        FROM {dataset_name}_data
        """
        stats = spark.sql(stats_query).collect()[0]
        min_val = stats["min_val"]
        max_val = stats["max_val"]
        
        # Only create rule if we have valid min/max
        if min_val is not None and max_val is not None:
            # Lower bound: 0 when the observed minimum lies strictly between 0
            # and 1000, otherwise the observed minimum itself.
            lower_bound = 0 if 0 < min_val < 1000 else min_val
            # Upper bound: 1.5x the observed maximum when it is positive, else 1000.
            upper_bound = max_val * 1.5 if max_val > 0 else 1000
            # NOTE(review): both bounds are derived from the data being checked
            # and always enclose it, so this rule cannot flag any row of the
            # dataset it was built from — confirm that is intended.
            
            rule_id = f"R{rule_counter:03d}"
            rule_counter += 1
            
            range_rule = {
                "rule_id": rule_id,
                "rule_name": f"{column_name}_ValidRange",
                "description": f"{column_name} should be between {lower_bound} and {upper_bound}",
                "rule_type": "FORMAT",
                "dataset": dataset_name,
                "column": column_name,
                "severity": "MEDIUM",
                "sql_condition": f"`{column_name}` < {lower_bound} OR `{column_name}` > {upper_bound}",
                "can_auto_remediate": True,
                "remediation_action": {
                    "type": "TRUNCATE",
                    "min": lower_bound,
                    "max": upper_bound
                }
            }
            instantiated_rules.append(range_rule)

def _sql_string_literal(text):
    """Render text as a single-quoted Spark SQL string literal.

    Backslashes and single quotes are backslash-escaped so values such as
    O'Brien do not terminate the literal early.
    """
    return "'" + text.replace("\\", "\\\\").replace("'", "\\'") + "'"


# Look for case inconsistency issues in string columns
for col_info in column_info:
    column_name = col_info["name"]
    col_type = col_info["type"]
    
    if col_type == "string":
        # Find mixed-case values (neither all-upper nor all-lower); at most 20 per column
        case_query = f"""
        SELECT DISTINCT `{column_name}` as value
        FROM {dataset_name}_data
        WHERE `{column_name}` IS NOT NULL 
          AND `{column_name}` != ''
          AND `{column_name}` != UPPER(`{column_name}`)
          AND `{column_name}` != LOWER(`{column_name}`)
        LIMIT 20
        """
        
        case_variations = spark.sql(case_query).collect()
        
        # Standard form is str.title() of the value (first letter of each word uppercase)
        mappings = {row["value"]: row["value"].title() for row in case_variations}
        
        if mappings:
            rule_id = f"R{rule_counter:03d}"
            rule_counter += 1
            
            # Build SQL condition matching exactly the values being remapped.
            # Values come from the data, so they are escaped before being
            # placed inside SQL string literals.
            sql_condition = " OR ".join(
                f"`{column_name}` = {_sql_string_literal(orig_val)}"
                for orig_val in mappings.keys()
            )
            
            consistency_rule = {
                "rule_id": rule_id,
                "rule_name": f"{column_name}_Consistency",
                "description": f"{column_name} values should be consistently formatted",
                "rule_type": "CONSISTENCY",
                "dataset": dataset_name,
                "column": column_name,
                "severity": "LOW",
                "sql_condition": sql_condition,
                "can_auto_remediate": True,
                "remediation_action": {
                    "type": "STANDARDIZE",
                    "mappings": mappings
                }
            }
            instantiated_rules.append(consistency_rule)

# If no specific rules were created, use some generic ones
# NOTE(review): the templates in `rules` carry column "*" and the literal token
# COLUMN_PLACEHOLDER in sql_condition; nothing in this span substitutes real
# column names before they are used — confirm this fallback can be evaluated.
if not instantiated_rules:
    print("No dataset-specific rules were created. Using generic rules instead.")
    # Add some generic rules here
    # The template list is reused as-is (same list object, not a copy).
    instantiated_rules = rules

# Run every instantiated rule against the dataset view and collect one issue
# record per rule that matches at least one row.
print("\nEvaluating rules against the data...")
issues = []

for rule in instantiated_rules:
    try:
        print(f"Evaluating rule: {rule['rule_name']} ({rule['rule_id']})")

        # Rows matching the rule's condition are the violations
        violations = spark.sql(
            f"SELECT * FROM {dataset_name}_data WHERE {rule['sql_condition']}"
        )
        violation_count = violations.count()

        if not violation_count > 0:
            print("  No violations found")
            continue

        print(f"  Found {violation_count} violations")

        # Keep up to five offending values of the rule's column for display
        target_column = rule["column"]
        sample_values = [str(row[target_column]) for row in violations.limit(5).collect()]

        issues.append({
            "issue_id": f"ISSUE_{rule['rule_id']}_{run_id}",
            "rule_id": rule["rule_id"],
            "rule_name": rule["rule_name"],
            "dataset_name": rule["dataset"],
            "column_name": target_column,
            "detection_date": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "issue_count": violation_count,
            "sample_values": sample_values,
            "status": "Open",
            "severity": rule["severity"],
            "can_remediate": rule["can_auto_remediate"],
            "remediation_action": rule["remediation_action"]
        })

        # Expose the violating rows as a per-rule temp view when the rule is auto-remediable
        if rule["can_auto_remediate"]:
            violations.createOrReplaceTempView(f"violations_{rule['rule_id']}")

    except Exception as e:
        print(f"Error evaluating rule {rule['rule_name']}: {str(e)}")

# Save issues to a file
try:
    # Convert to JSON
    issues_json = json.dumps(issues)
    
    # Save to a DataFrame and write to blob storage
    issues_pdf = pd.DataFrame({"issues_json": [issues_json]})
    issues_spark_df = spark.createDataFrame(issues_pdf)
    
    issues_path = f"wasbs://{profiled_container}@{storage_account_name}.blob.core.windows.net/{dataset_name}_data_quality_issues_{run_id}.json"
    # Overwrite so a rerun with the same run_id replaces the earlier output
    # instead of failing because the path already exists.
    issues_spark_df.write.mode("overwrite").text(issues_path)
    
    print(f"Saved issues to {issues_path}")
except Exception as e:
    print(f"Error saving issues: {str(e)}")

# Generate summary statistics
print("\nRule Evaluation Summary:")
print(f"Total rules evaluated: {len(instantiated_rules)}")
print(f"Rules with violations: {len(issues)}")
total_violations = sum(issue["issue_count"] for issue in issues)
print(f"Total violations found: {total_violations}")

# Group by severity
severity_counts = {}
for issue in issues:
    severity = issue["severity"]
    issue_count = issue["issue_count"]
    severity_counts[severity] = severity_counts.get(severity, 0) + issue_count

print("\nViolations by Severity:")
# The loop variable is deliberately not called `count`, which would overwrite
# the function of that name imported from pyspark.sql.functions.
for severity, severity_total in severity_counts.items():
    print(f"  {severity}: {severity_total} violations")

#############################
# PART 3: AUTO-REMEDIATION
#############################

print("\n=== PHASE 3: AUTOMATED REMEDIATION ===\n")

# Function to apply remediation actions based on rule type
def apply_remediation(dataframe, rule):
    """
    Apply a remediation action to the DataFrame based on the rule

    Parameters:
    - dataframe: the Spark DataFrame to remediate
    - rule: dict with "rule_id", "rule_name", "column", "sql_condition" and
      "remediation_action"; the action's "type" selects REPLACE, TRUNCATE or
      STANDARDIZE (case-insensitive)

    Returns:
    - The remediated DataFrame
    - The number of rows matching sql_condition before remediation
    - The number of rows matching sql_condition after remediation

    Side effects: registers the temp views "before_remediation" and
    "after_remediation" and prints a progress summary.
    """
    rule_id = rule["rule_id"]
    rule_name = rule["rule_name"]
    column_name = rule["column"]
    # Backticks for column names with special characters
    quoted_column = f"`{column_name}`"
    
    # Get the SQL condition
    sql_condition = rule["sql_condition"]
    
    # Get the remediation action
    remediation_action = rule["remediation_action"]
    action_type = remediation_action.get("type", "").upper()
    
    print(f"Applying remediation for rule {rule_id} - {rule_name}")
    print(f"  Column: {column_name}")
    print(f"  Action: {action_type}")
    
    # Create a temporary view of the data before remediation
    dataframe.createOrReplaceTempView("before_remediation")
    
    # Count violations before remediation
    violations_before = spark.sql(f"SELECT COUNT(*) as count FROM before_remediation WHERE {sql_condition}").collect()[0]["count"]
    print(f"  Violations before remediation: {violations_before}")
    
    # Apply different remediation types
    if action_type == "REPLACE":
        # Replace specific values
        replacement_value = remediation_action.get("value", "")
        print(f"  Replacing values with: {replacement_value}")
        
        # Rows matching the rule condition get the replacement value
        check_expr = f"CASE WHEN {sql_condition} THEN true ELSE false END"
        
        # Apply remediation using when/otherwise
        dataframe = dataframe.withColumn(
            column_name,
            when(expr(check_expr), lit(replacement_value)).otherwise(col(quoted_column))
        )
        
    elif action_type == "TRUNCATE":
        # Truncate values to a specific range
        min_value = remediation_action.get("min")
        max_value = remediation_action.get("max")
        print(f"  Truncating values to range: {min_value} - {max_value}")
        
        # Read the type from the frame being remediated rather than from a
        # snapshot taken earlier, so it reflects the column as it is now.
        col_type = dict(dataframe.dtypes).get(column_name)
        is_numeric = col_type is not None and (
            col_type in ('int', 'double', 'float', 'bigint', 'decimal')
            or col_type.startswith('decimal(')
        )
        
        if is_numeric:
            clamped = (
                when(col(quoted_column) < min_value, lit(min_value))
                .when(col(quoted_column) > max_value, lit(max_value))
                .otherwise(col(quoted_column))
            )
            # A float bound (e.g. max * 1.5) would otherwise widen an integer
            # column to double; cast back so the column keeps its type.
            dataframe = dataframe.withColumn(column_name, clamped.cast(col_type))
        else:  # String column that needs to be cast
            dataframe = dataframe.withColumn(
                column_name,
                when(expr(f"CAST({quoted_column} AS INT) < {min_value}"), lit(str(min_value)))
                .when(expr(f"CAST({quoted_column} AS INT) > {max_value}"), lit(str(max_value)))
                .otherwise(col(quoted_column))
            )
        
    elif action_type == "STANDARDIZE":
        # Standardize values based on mappings
        mappings = remediation_action.get("mappings", {})
        print(f"  Standardizing values with mappings: {mappings}")
        
        # Build a case statement to handle standardization; values without a
        # mapping fall through to the original column value
        case_expr = col(quoted_column)
        for original, standard in mappings.items():
            case_expr = when(col(quoted_column) == original, lit(standard)).otherwise(case_expr)
        
        dataframe = dataframe.withColumn(column_name, case_expr)
    
    else:
        # Unknown action types leave the data untouched; say so instead of
        # reporting an apparently successful no-op.
        print(f"  Warning: unsupported remediation action '{action_type}', data left unchanged")
    
    # Create a temporary view of the data after remediation
    dataframe.createOrReplaceTempView("after_remediation")
    
    # Count violations after remediation
    violations_after = spark.sql(f"SELECT COUNT(*) as count FROM after_remediation WHERE {sql_condition}").collect()[0]["count"]
    print(f"  Violations after remediation: {violations_after}")
    print(f"  Fixed: {violations_before - violations_after} records")
    
    return dataframe, violations_before, violations_after

# Track remediation results
remediation_results = []

# Process each rule with auto-remediation
for rule in instantiated_rules:
    if not rule["can_auto_remediate"]:
        continue

    try:
        # Apply the remediation; `df` is replaced by the remediated frame
        df, before, after = apply_remediation(df, rule)
        fixed_count = before - after

        # Record the result
        remediation_results.append({
            "rule_id": rule["rule_id"],
            "rule_name": rule["rule_name"],
            "column_name": rule["column"],
            "violations_before": before,
            "violations_after": after,
            "fixed_count": fixed_count
        })

        # Update the matching issue. It is only marked "Remediated" when no
        # violations are left; otherwise it stays in its current status and the
        # notes record how many rows still violate the rule.
        for issue in issues:
            if issue["rule_id"] != rule["rule_id"]:
                continue
            issue["remediation_date"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            issue["remediation_notes"] = f"Auto-remediated {fixed_count} records"
            if after == 0:
                issue["status"] = "Remediated"
            else:
                issue["remediation_notes"] += f"; {after} violations remain"

    except Exception as e:
        print(f"Error applying remediation for rule {rule['rule_id']}: {str(e)}")

# Save updated issues to a file
try:
    # Convert to JSON
    issues_json = json.dumps(issues)
    
    # Save to a DataFrame and write to blob storage
    issues_pdf = pd.DataFrame({"issues_json": [issues_json]})
    issues_spark_df = spark.createDataFrame(issues_pdf)
    
    issues_path = f"wasbs://{remediated_container}@{storage_account_name}.blob.core.windows.net/{dataset_name}_data_quality_issues_updated_{run_id}.json"
    # Overwrite, like the CSV writer below, so a rerun with the same run_id
    # replaces the earlier output instead of failing on the existing path.
    issues_spark_df.write.mode("overwrite").text(issues_path)
    
    print(f"Saved updated issues to {issues_path}")
except Exception as e:
    print(f"Error saving updated issues: {str(e)}")

# Save remediation results
try:
    # Convert to JSON
    remediation_json = json.dumps(remediation_results)
    
    # Save to a DataFrame and write to blob storage
    remediation_pdf = pd.DataFrame({"remediation_json": [remediation_json]})
    remediation_spark_df = spark.createDataFrame(remediation_pdf)
    
    remediation_path = f"wasbs://{remediated_container}@{storage_account_name}.blob.core.windows.net/{dataset_name}_remediation_results_{run_id}.json"
    remediation_spark_df.write.mode("overwrite").text(remediation_path)
    
    print(f"Saved remediation results to {remediation_path}")
except Exception as e:
    print(f"Error saving remediation results: {str(e)}")

# Save the remediated data
try:
    # Add a remediation timestamp
    df_remediated = df.withColumn("remediation_timestamp", lit(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
    
    # Save to the remediated container
    remediated_path = f"wasbs://{remediated_container}@{storage_account_name}.blob.core.windows.net/{dataset_name}_remediated_{run_id}.csv"
    df_remediated.write.format("csv").option("header", "true").mode("overwrite").save(remediated_path)
    
    print(f"Saved remediated data to {remediated_path}")
except Exception as e:
    print(f"Error saving remediated data: {str(e)}")

#########################
# PART 4: DATA VALIDATION
#########################

print("\n=== PHASE 4: DATA VALIDATION ===\n")

# Load the original dirty dataset and remediated dataset for comparison
try:
    # `df` is reassigned by every remediation step, so it no longer holds the
    # dirty data. Re-read the raw file and apply the same column renames so
    # the two frames line up by column name.
    try:
        raw_reloaded_df = spark.read.format("csv").option("header", "true").option("inferSchema", "true").load(file_path)
        dirty_df = raw_reloaded_df.toDF(*[column_mapping.get(name, name) for name in raw_reloaded_df.columns])
    except Exception as e:
        # Continue with available data, but make clear that it is not the original
        print(f"Warning: could not reload the original dirty dataset ({str(e)}); "
              "falling back to the in-memory DataFrame, which already contains remediated values")
        dirty_df = df
    
    # Load the remediated dataset
    remediated_file_path = f"wasbs://{remediated_container}@{storage_account_name}.blob.core.windows.net/{dataset_name}_remediated_{run_id}.csv"
    remediated_df = spark.read.format("csv").option("header", "true").option("inferSchema", "true").load(remediated_file_path)
    
    print(f"Original Dirty Dataset: {dirty_df.count()} rows, {len(dirty_df.columns)} columns")
    print(f"Remediated Dataset: {remediated_df.count()} rows, {len(remediated_df.columns)} columns")
except Exception as e:
    print(f"Error loading datasets for validation: {str(e)}")
    # Continue with available data

# Define validation functions
def calculate_completeness(df):
    """Return the percentage (0-100) of cells in ``df`` that hold a real value.

    A cell counts as missing when it is null, an empty string, or the
    placeholder "?". If counting fails for a column, the error is printed and
    that column contributes zero missing cells. Returns 0 for a frame with
    no cells.
    """
    total_cells = df.count() * len(df.columns)

    def missing_in(column_name):
        # Missing-cell count for one column; failures are logged and count as zero
        try:
            values = df[column_name]
            is_missing = values.isNull() | (values == "") | (values == "?")
            return df.filter(is_missing).count()
        except Exception as e:
            print(f"Error counting nulls in column {column_name}: {str(e)}")
            return 0

    total_nulls = sum(missing_in(column_name) for column_name in df.columns)
    if total_cells > 0:
        return 100 - (total_nulls / total_cells * 100)
    return 0

def calculate_column_completeness(df):
    """Map every column of ``df`` to its completeness percentage (0-100).

    A value is missing when it is null, an empty string, or "?". A column
    whose calculation raises is reported and scored 0; every column scores 0
    when the frame has no rows.
    """
    total_rows = df.count()

    def completeness_of(column):
        # Share of rows in this column that are not missing
        values = df[column]
        null_count = df.filter(values.isNull() | (values == "") | (values == "?")).count()
        if total_rows > 0:
            return 100 - (null_count / total_rows * 100)
        return 0

    result = {}
    for column in df.columns:
        try:
            result[column] = completeness_of(column)
        except Exception as e:
            print(f"Error calculating completeness for column {column}: {str(e)}")
            result[column] = 0
    return result

def calculate_data_type_consistency(df):
    """Score, per column, how many non-null values conform to the column's type.

    Only numeric columns ("int", "bigint", "double", "float") are inspected:
    a value is invalid when it is non-null but casting it to the declared
    type yields null. All other columns (including binary) are reported as
    100. A numeric column with no non-null values scores 0, as does a column
    whose check raises (the error is printed).

    The check uses the DataFrame Column API rather than a hand-built SQL
    string, so it needs no ``F`` alias in scope and does not splice raw
    column names into SQL text.

    Returns:
        dict mapping column name -> consistency percentage (0-100).
    """
    numeric_types = ("int", "bigint", "double", "float")
    result = {}
    for column, data_type in df.dtypes:
        if data_type not in numeric_types:
            # Non-numeric columns are not inspected; treat them as consistent
            result[column] = 100
            continue

        try:
            values = df[column]
            non_null = values.isNotNull()
            # Invalid = present, but not representable as the declared type
            invalid_count = df.filter(non_null & values.cast(data_type).isNull()).count()
            total_count = df.filter(non_null).count()
            result[column] = 100 - (invalid_count / total_count * 100) if total_count > 0 else 0
        except Exception as e:
            # Report instead of silently swallowing; an unverifiable column scores 0
            print(f"Error checking type consistency for column {column}: {str(e)}")
            result[column] = 0
    return result

def detect_outliers(df, numeric_cols):
    """Return, per numeric column, the percentage of non-null values that are not outliers.

    Outliers are values outside [Q1 - 1.5*IQR, Q3 + 1.5*IQR], with the
    quartiles taken from ``approxQuantile`` at a 0.05 relative error. A
    column with no non-null values scores 0; a column whose calculation
    raises is reported and scored 100.
    """
    result = {}
    for column_name in numeric_cols:
        try:
            # Approximate quartiles
            quartiles = df.approxQuantile(column_name, [0.25, 0.75], 0.05)
            q1 = quartiles[0]
            q3 = quartiles[1]
            iqr = q3 - q1

            # IQR fences
            lower_bound = q1 - 1.5 * iqr
            upper_bound = q3 + 1.5 * iqr

            values = df[column_name]
            outlier_count = df.filter((values < lower_bound) | (values > upper_bound)).count()
            total_count = df.filter(values.isNotNull()).count()

            if total_count > 0:
                result[column_name] = 100 - (outlier_count / total_count * 100)
            else:
                result[column_name] = 0
        except Exception as e:
            print(f"Error detecting outliers in column {column_name}: {str(e)}")
            # Columns that cannot be evaluated are assumed to be outlier-free
            result[column_name] = 100
    return result

def calculate_value_consistency(df, col_name, valid_values):
    """Return the percentage (0-100) of non-null values in ``col_name`` that belong to ``valid_values``.

    A column with no non-null values scores 100. If the calculation raises,
    the error is printed and 0 is returned.
    """
    try:
        values = df[col_name]
        present = values.isNotNull()

        total_count = df.filter(present).count()
        if not total_count:
            # Nothing to validate, so nothing can be invalid
            return 100

        invalid_count = df.filter(~values.isin(valid_values) & present).count()
        return 100 - (invalid_count / total_count * 100)
    except Exception as e:
        print(f"Error calculating value consistency for column {col_name}: {str(e)}")
        return 0

# Perform comprehensive validation
# Each check below runs in its own try/except so that one failing check does not
# stop the others; on failure the check's result names are set to a fallback value.
print("\nPerforming comprehensive data validation...")

# 1. Completeness validation (null or missing values)
# Overall share of populated cells, computed for both the dirty and remediated frames.
try:
    dirty_completeness = calculate_completeness(dirty_df)
    remediated_completeness = calculate_completeness(remediated_df)
    print(f"Completeness score - Original: {dirty_completeness:.2f}%, Remediated: {remediated_completeness:.2f}%")
except Exception as e:
    print(f"Error calculating completeness: {str(e)}")
    # Fallback: both scores become 0 even if only one of the two calls failed
    dirty_completeness = 0
    remediated_completeness = 0

# 2. Column-level completeness
# Per-column completeness dictionaries (column name -> percentage).
try:
    dirty_column_completeness = calculate_column_completeness(dirty_df)
    remediated_column_completeness = calculate_column_completeness(remediated_df)
    print(f"Column-level completeness calculated for {len(remediated_column_completeness)} columns")
except Exception as e:
    print(f"Error calculating column completeness: {str(e)}")
    # Fallback: empty dicts, which the scoring step treats as "no data"
    dirty_column_completeness = {}
    remediated_column_completeness = {}

# 3. Data type consistency
# Only the remediated frame is checked.
try:
    remediated_type_consistency = calculate_data_type_consistency(remediated_df)
    print(f"Data type consistency calculated for {len(remediated_type_consistency)} columns")
except Exception as e:
    print(f"Error calculating data type consistency: {str(e)}")
    remediated_type_consistency = {}

# 4. Outlier detection for numeric columns
# Numeric columns are selected by their Spark dtype name from the remediated frame.
try:
    numeric_columns = [c[0] for c in remediated_df.dtypes if c[1] in ['int', 'double', 'float', 'bigint']]
    outlier_metrics = detect_outliers(remediated_df, numeric_columns)
    print(f"Outlier detection performed on {len(outlier_metrics)} numeric columns")
except Exception as e:
    print(f"Error detecting outliers: {str(e)}")
    outlier_metrics = {}

# 5. Value consistency checks for categorical columns (example for 'sex' column)
# Results are collected as {column: {'name': rule label, 'score': percentage}}.
# A rule is only evaluated when its column exists in the remediated frame.
column_validation_rules = {}

# Example rule for sex column (customize based on your data)
try:
    if 'sex' in remediated_df.columns:
        sex_consistency = calculate_value_consistency(remediated_df, 'sex', ['Male', 'Female', 'male', 'female'])
        column_validation_rules['sex'] = {'name': 'Valid Sex Values', 'score': sex_consistency}
        print(f"Sex column consistency score: {sex_consistency:.2f}%")
except Exception as e:
    print(f"Error calculating sex column consistency: {str(e)}")

# Example rule for workclass column
try:
    if 'workclass' in remediated_df.columns:
        # Accepted categories; any other non-null value counts as invalid
        workclass_valid_values = ['Private', 'Self-emp-not-inc', 'Self-emp-inc', 'Federal-gov', 
                                'Local-gov', 'State-gov', 'Without-pay', 'Never-worked', 'Unknown']
        workclass_consistency = calculate_value_consistency(remediated_df, 'workclass', workclass_valid_values)
        column_validation_rules['workclass'] = {'name': 'Valid Workclass Values', 'score': workclass_consistency}
        print(f"Workclass column consistency score: {workclass_consistency:.2f}%")
except Exception as e:
    print(f"Error calculating workclass column consistency: {str(e)}")

# Relative weight of each quality dimension in the overall cleanliness score
validation_weights = {
    'completeness': 0.4,            # overall completeness
    'column_completeness': 0.2,     # average per-column completeness
    'type_consistency': 0.2,        # average data type consistency
    'outliers': 0.1,                # average outlier-free share
    'value_consistency': 0.1        # average rule-based value consistency
}

# Weighted contribution of overall completeness
completeness_score = remediated_completeness * validation_weights['completeness']

# Column completeness: no per-column data means no credit for this dimension
if not remediated_column_completeness:
    avg_column_completeness = 0
    column_completeness_score = 0
else:
    avg_column_completeness = sum(remediated_column_completeness.values()) / len(remediated_column_completeness)
    column_completeness_score = avg_column_completeness * validation_weights['column_completeness']

# Type consistency: likewise scores zero when nothing was measured
if not remediated_type_consistency:
    avg_type_consistency = 0
    type_consistency_score = 0
else:
    avg_type_consistency = sum(remediated_type_consistency.values()) / len(remediated_type_consistency)
    type_consistency_score = avg_type_consistency * validation_weights['type_consistency']

# Outliers: with no numeric columns evaluated, the dimension gets full credit
avg_outlier_score = (sum(outlier_metrics.values()) / len(outlier_metrics)) if outlier_metrics else 100
outlier_score = avg_outlier_score * validation_weights['outliers']

# Value consistency rules: with no rules evaluated, the dimension gets full credit
rule_scores = [validated_rule['score'] for validated_rule in column_validation_rules.values()]
avg_rule_score = (sum(rule_scores) / len(rule_scores)) if rule_scores else 100
value_consistency_score = avg_rule_score * validation_weights['value_consistency']

# Total cleanliness score is the sum of the weighted contributions
cleanliness_score = completeness_score + column_completeness_score + type_consistency_score + outlier_score + value_consistency_score

# Map the score to a status; anything at or above 90 is eligible for certification
cleanliness_status = (
    "CERTIFIED" if cleanliness_score >= 98
    else "ACCEPTABLE" if cleanliness_score >= 90
    else "NEEDS_IMPROVEMENT"
)
certification_flag = cleanliness_status != "NEEDS_IMPROVEMENT"

# Count missing values (null, empty string or "?") before and after remediation
try:
    count_dirty_nulls = sum(
        dirty_df.filter(dirty_df[c].isNull() | (dirty_df[c] == "") | (dirty_df[c] == "?")).count()
        for c in dirty_df.columns
    )
    count_remediated_nulls = sum(
        remediated_df.filter(remediated_df[c].isNull() | (remediated_df[c] == "") | (remediated_df[c] == "?")).count()
        for c in remediated_df.columns
    )
    issues_fixed = count_dirty_nulls - count_remediated_nulls
    remaining_issues = count_remediated_nulls
    print(f"Null values - Original: {count_dirty_nulls}, Remediated: {count_remediated_nulls}, Fixed: {issues_fixed}")
except Exception as e:
    print(f"Error counting null values: {str(e)}")
    # Fall back to the violation totals recorded by the earlier remediation phase
    count_dirty_nulls = total_violations
    issues_fixed = sum(entry["fixed_count"] for entry in remediation_results)
    count_remediated_nulls = count_dirty_nulls - issues_fixed
    remaining_issues = count_remediated_nulls

# Assemble the validation report for this run
validation_timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
validation_id = str(uuid.uuid4())

validation_report = dict(
    validationId=validation_id,
    runId=run_id,
    datasetName=dataset_name,
    validationDate=validation_timestamp,
    # Before/after comparison of missing values
    inputMetrics=dict(
        dirtyCompleteness=dirty_completeness,
        remediatedCompleteness=remediated_completeness,
        dirtyNullCount=count_dirty_nulls,
        remediatedNullCount=count_remediated_nulls,
        issuesFixed=issues_fixed,
        remainingIssues=remaining_issues,
    ),
    # Weighted contribution of each dimension to the cleanliness score
    validationMetrics=dict(
        completenessScore=completeness_score,
        columnCompletenessScore=column_completeness_score,
        typeConsistencyScore=type_consistency_score,
        outlierScore=outlier_score,
        valueConsistencyScore=value_consistency_score,
    ),
    # Unweighted per-column results for the remediated data
    columnMetrics=dict(
        columnCompleteness=remediated_column_completeness,
        typeConsistency=remediated_type_consistency,
        outlierDetection=outlier_metrics,
    ),
    ruleValidations=column_validation_rules,
    overallMetrics=dict(
        cleanlinessScore=cleanliness_score,
        cleanlinessStatus=cleanliness_status,
        certificationFlag=certification_flag,
        remainingIssuesCount=remaining_issues,
    ),
)

# Serialise the report and print a short summary
validation_report_json = json.dumps(validation_report, indent=2)
print("\nValidation Report Summary:")
print(f"Cleanliness Score: {cleanliness_score:.2f}%")
print(f"Cleanliness Status: {cleanliness_status}")
print(f"Certification: {'CERTIFIED' if certification_flag else 'NOT CERTIFIED'}")
print(f"Remaining Issues: {remaining_issues}")

# Save report as JSON to remediated container
try:
    validation_report_path = f"wasbs://{remediated_container}@{storage_account_name}.blob.core.windows.net/{dataset_name}_validation_{run_id}.json"
    # Single-row, single-column frame holding the JSON text
    validation_report_pdf = pd.DataFrame({"validation_json": [validation_report_json]})
    # Overwrite so that re-running with the same run_id replaces the previous report
    # instead of failing because the path already exists (matches the CSV writes).
    spark.createDataFrame(validation_report_pdf).write.mode("overwrite").text(validation_report_path)
    print(f"Validation report saved to: {validation_report_path}")
except Exception as e:
    # Best effort: report the failure and let the notebook continue
    print(f"Error saving validation report: {str(e)}")

# Export validation results to SQL database
try:
    # Reference DDL for the target table. It is NOT executed here: the append
    # below assumes dbo.DataQualityValidation already exists.
    create_validation_table_query = """
    IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = 'DataQualityValidation')
    BEGIN
        CREATE TABLE dbo.DataQualityValidation (
            ValidationID int IDENTITY(1,1) PRIMARY KEY,
            RunID varchar(50) NULL,
            DatasetName varchar(100) NULL,
            ValidationDate datetime NULL,
            CleanlinessScore float NULL,
            CleanlinessStatus varchar(50) NULL,
            RemainingIssuesCount int NULL,
            CertificationFlag bit NULL
        )
    END
    """
    
    # We can't easily execute DDL through JDBC, so assume the table exists or has been created
    
    # One row describing this validation run (ValidationID is an IDENTITY column and is omitted)
    validation_data = [(
        run_id,
        dataset_name,
        validation_timestamp,
        float(cleanliness_score),
        cleanliness_status,
        int(remaining_issues),
        certification_flag
    )]

    # CleanlinessScore uses DoubleType: FloatType is 32-bit and would truncate the
    # score before it reaches the SQL `float` column.
    validation_schema = StructType([
        StructField("RunID", StringType(), True),
        StructField("DatasetName", StringType(), True),
        StructField("ValidationDate", StringType(), True),
        StructField("CleanlinessScore", DoubleType(), True),
        StructField("CleanlinessStatus", StringType(), True),
        StructField("RemainingIssuesCount", IntegerType(), True),
        StructField("CertificationFlag", BooleanType(), True)
    ])

    validation_df = spark.createDataFrame(validation_data, validation_schema)

    # Append the row to the validation table
    validation_df.write \
        .format("jdbc") \
        .option("url", jdbc_url) \
        .option("dbtable", "DataQualityValidation") \
        .option("user", username) \
        .option("password", password) \
        .mode("append") \
        .save()

    print("Validation results exported to SQL database")
except Exception as e:
    # Best effort: report the failure and let the notebook continue
    print(f"Error exporting validation to SQL: {str(e)}")

# Copy certified data to certified container if it meets the threshold
if certification_flag:
    try:
        # Probe the certified container so an inaccessible container is reported
        # clearly. The probe does not create the container, and the write below is
        # attempted regardless of its outcome.
        try:
            dbutils.fs.ls(f"wasbs://{certified_container}@{storage_account_name}.blob.core.windows.net/")
        except Exception as probe_error:
            print(f"Certified container '{certified_container}' could not be listed ({str(probe_error)}); attempting to write anyway.")
            
        # Destination of the certified copy
        certified_file_path = f"wasbs://{certified_container}@{storage_account_name}.blob.core.windows.net/{dataset_name}_certified_{run_id}.csv"
        
        # Write the remediated frame out again under the certified path
        remediated_df.write.format("csv").option("header", "true").mode("overwrite").save(certified_file_path)
        
        print(f"Dataset certified and copied to: {certified_file_path}")
    except Exception as e:
        print(f"Error copying to certified container: {str(e)}")
        print("Dataset certified but remained in remediated container")
else:
    print(f"Dataset did not meet certification threshold (score: {cleanliness_score:.2f})")

#########################
# PART 5: SQL EXPORT
#########################

print("\n=== PHASE 5: SQL EXPORT ===\n")

# Append a DataFrame to a SQL table over JDBC
def write_to_sql_jdbc(df, table_name):
    """Append ``df`` to ``table_name`` using the notebook-level ``jdbc_url``.

    Returns True when the write completes and False when anything raises;
    the error is printed rather than propagated.
    """
    try:
        jdbc_writer = df.write.format("jdbc")
        jdbc_writer = jdbc_writer.option("url", jdbc_url)
        jdbc_writer = jdbc_writer.option("dbtable", table_name)
        jdbc_writer.mode("append").save()

        print(f"Successfully exported data to {table_name} table using JDBC")
        return True
    except Exception as e:
        print(f"Error exporting to {table_name}: {str(e)}")
        return False

# Export run-level metrics to SQL
def export_metrics_to_sql(run_id, profile_report, remediation_results):
    """Append one summary row for this run to the DataQualityMetrics table.

    Issue totals are summed from the "violations_before" and "fixed_count"
    entries of ``remediation_results``; row count and completeness come from
    ``profile_report``. Returns the result of ``write_to_sql_jdbc``, or
    False if building the row raises (the error is printed).
    """
    try:
        # Aggregate the per-rule remediation outcomes
        total_issues = 0
        remediated_issues = 0
        for item in remediation_results:
            total_issues += item.get("violations_before", 0)
            remediated_issues += item.get("fixed_count", 0)

        # Single metrics row; dataset_name is the notebook-level parameter
        metrics_row = {
            'RunID': run_id,
            'DatasetName': dataset_name,
            'RunDate': datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            'TotalRows': int(profile_report.get('total_rows', 0)),
            'TotalIssues': int(total_issues),
            'RemediatedIssues': int(remediated_issues),
            'CompletenessScore': float(profile_report.get('completeness_score', 0))
        }

        metrics_df = spark.createDataFrame([metrics_row])
        return write_to_sql_jdbc(metrics_df, "DataQualityMetrics")
    except Exception as e:
        print(f"Error exporting metrics to SQL: {str(e)}")
        return False

# Export profiles to SQL
def _profile_avg_to_float(value, default=0.0):
    """Convert a profile's 'avg' entry to float, falling back to ``default``.

    Handles numbers, numeric strings, non-numeric strings and None.
    """
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def export_profiles_to_sql(run_id, profile_report):
    """Append one row per profiled column to the DataQualityProfiles table.

    Args:
        run_id: identifier stored with every row.
        profile_report: dict whose 'column_profiles' entry maps column name
            to a profile dict (data_type, null_count, null_percentage,
            distinct_count, min, max, avg). Missing entries get defaults.

    Returns:
        True when there is nothing to export (consistent with
        ``export_issues_to_sql``) or when the write succeeds; False if
        anything raises (the error is printed).
    """
    try:
        profiles_data = []

        for col_name, profile in profile_report.get('column_profiles', {}).items():
            profiles_data.append({
                'RunID': run_id,
                'ColumnName': col_name,
                'DataType': profile.get('data_type', 'Unknown'),
                'NullCount': int(profile.get('null_count', 0)),
                'NullPercentage': float(profile.get('null_percentage', 0)),
                'DistinctCount': int(profile.get('distinct_count', 0)),
                'MinValue': str(profile.get('min', 'N/A')),
                'MaxValue': str(profile.get('max', 'N/A')),
                # A None or non-numeric average must not abort the whole export
                'AvgValue': _profile_avg_to_float(profile.get('avg', 0))
            })

        if not profiles_data:
            # createDataFrame cannot infer a schema from an empty list
            print("No column profiles to export")
            return True

        profiles_df = spark.createDataFrame(profiles_data)
        return write_to_sql_jdbc(profiles_df, "DataQualityProfiles")
    except Exception as e:
        print(f"Error exporting profiles to SQL: {str(e)}")
        return False

# Export detected issues to SQL
def export_issues_to_sql(run_id, rules, issues):
    """Append one row per entry of ``issues`` to the DataQualityIssues table.

    ``rules`` is accepted but not read by this function. Returns True when
    there are no issues to write, otherwise the result of
    ``write_to_sql_jdbc``; returns False if anything raises (the error is
    printed).
    """
    try:
        # Rule name falls back to the rule id, then to 'Unknown'
        issues_data = [
            {
                'RunID': run_id,
                'RuleName': issue.get('rule_name', issue.get('rule_id', 'Unknown')),
                'ColumnName': issue.get('column_name', 'Unknown'),
                'IssueCount': int(issue.get('issue_count', 0)),
                'Severity': issue.get('severity', 'MEDIUM'),
                'Status': issue.get('status', 'Open')
            }
            for issue in issues
        ]

        if not issues_data:
            # Nothing to export counts as success
            return True

        issues_df = spark.createDataFrame(issues_data)
        return write_to_sql_jdbc(issues_df, "DataQualityIssues")
    except Exception as e:
        print(f"Error exporting issues to SQL: {str(e)}")
        return False

# Export rules to SQL
def export_rules_to_sql(run_id, rules):
    """Append one row per entry of ``rules`` to the DataQualityRules table.

    Args:
        run_id: identifier stored with every row.
        rules: iterable of rule dicts (rule_name, description, column,
            severity); missing entries get defaults.

    Returns:
        True when there are no rules to write (consistent with
        ``export_issues_to_sql``) or when the write succeeds; False if
        anything raises (the error is printed).
    """
    try:
        rules_data = [
            {
                'RunID': run_id,
                'RuleName': rule.get('rule_name', 'Unknown'),
                'RuleDescription': rule.get('description', ''),
                'ColumnName': rule.get('column', 'Unknown'),
                'Severity': rule.get('severity', 'MEDIUM')
            }
            for rule in rules
        ]

        if not rules_data:
            # createDataFrame cannot infer a schema from an empty list
            return True

        rules_df = spark.createDataFrame(rules_data)
        return write_to_sql_jdbc(rules_df, "DataQualityRules")
    except Exception as e:
        print(f"Error exporting rules to SQL: {str(e)}")
        return False

# Export dataset metadata
def _execute_jdbc_statement(sql):
    """Run one non-query SQL statement (DDL/DML) against ``jdbc_url``.

    Spark's JDBC data source only reads tables or writes whole DataFrames,
    so CREATE / UPDATE / DROP statements are issued through
    ``java.sql.DriverManager`` on the driver JVM instead.
    NOTE(review): this relies on py4j access to the JVM (``_jvm``), which
    some restricted cluster modes block - confirm for the target cluster.
    """
    connection = spark.sparkContext._jvm.java.sql.DriverManager.getConnection(jdbc_url)
    try:
        statement = connection.createStatement()
        try:
            statement.executeUpdate(sql)
        finally:
            statement.close()
    finally:
        connection.close()


def export_dataset_metadata(dataset_name, run_id, total_rows, total_columns, total_issues, completion_score):
    """Insert or update this dataset's row in dbo.DatasetMetadata.

    If the table cannot be read it is created (guarded, so an existing
    table is left alone). An existing row for ``dataset_name`` is updated
    via a temporary staging table that is dropped afterwards; otherwise a
    new row is appended.

    Args:
        dataset_name: dataset identifier (matched against DatasetName).
        run_id: identifier of the current run, stored as LatestRunID.
        total_rows, total_columns, total_issues: dataset statistics.
        completion_score: value stored in CompletenessScore.

    Returns:
        True on success, False if anything raises (the error is printed).
    """
    import re  # local import: only needed to build a safe staging-table name

    try:
        # Check if DatasetMetadata table exists, if not create it
        try:
            spark.read \
              .format("jdbc") \
              .option("url", jdbc_url) \
              .option("dbtable", "DatasetMetadata") \
              .load() \
              .limit(1) \
              .count()
        except Exception as e:
            # Table likely doesn't exist. The OBJECT_ID guard keeps this a no-op
            # when the probe failed for some other reason and the table is there.
            print("Creating DatasetMetadata table...")
            create_table_query = """
            IF OBJECT_ID('dbo.DatasetMetadata', 'U') IS NULL
            CREATE TABLE dbo.DatasetMetadata (
                DatasetID int IDENTITY(1,1) PRIMARY KEY,
                DatasetName varchar(100) NOT NULL,
                DateAdded datetime NOT NULL,
                LastProcessed datetime NULL,
                LatestRunID varchar(50) NULL,
                TotalRows int NULL,
                ColumnCount int NULL,
                TotalIssues int NULL,
                CompletenessScore float NULL,
                Status varchar(50) NULL,
                Description varchar(500) NULL
            )
            """
            _execute_jdbc_statement(create_table_query)

        # Now prepare and insert dataset metadata
        current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # dataset_name comes from a notebook parameter: double any single quotes so
        # it cannot break out of the SQL string literal below.
        escaped_name = str(dataset_name).replace("'", "''")
        query = f"""
        SELECT COUNT(*) as count 
        FROM dbo.DatasetMetadata 
        WHERE DatasetName = '{escaped_name}'
        """

        count_df = spark.read \
          .format("jdbc") \
          .option("url", jdbc_url) \
          .option("dbtable", f"({query}) as tmp") \
          .load()

        dataset_exists = count_df.collect()[0]["count"] > 0

        if dataset_exists:
            # Stage the new values in a one-row table, then update from it
            update_data = [{
                'DatasetName': dataset_name,
                'LastProcessed': current_time,
                'LatestRunID': run_id,
                'TotalRows': total_rows,
                'ColumnCount': total_columns,
                'TotalIssues': total_issues,
                'CompletenessScore': completion_score,
                'Status': 'Processed'
            }]

            update_df = spark.createDataFrame(update_data)

            # run_id is a notebook parameter too: keep only identifier-safe characters
            temp_table = "TempMetadata_" + re.sub(r"\W", "_", str(run_id))
            update_df.write \
              .format("jdbc") \
              .option("url", jdbc_url) \
              .option("dbtable", temp_table) \
              .mode("overwrite") \
              .save()

            update_query = f"""
            UPDATE dm
            SET LastProcessed = tmp.LastProcessed,
                LatestRunID = tmp.LatestRunID,
                TotalRows = tmp.TotalRows,
                ColumnCount = tmp.ColumnCount,
                TotalIssues = tmp.TotalIssues,
                CompletenessScore = tmp.CompletenessScore,
                Status = tmp.Status
            FROM dbo.DatasetMetadata dm
            JOIN {temp_table} tmp ON dm.DatasetName = tmp.DatasetName
            """

            try:
                # Actually apply the update; previously the statement was only built
                _execute_jdbc_statement(update_query)
            finally:
                # Always remove the staging table so it does not accumulate per run
                try:
                    _execute_jdbc_statement(f"DROP TABLE {temp_table}")
                except Exception as cleanup_error:
                    print(f"Warning: could not drop temporary table {temp_table}: {str(cleanup_error)}")

            print(f"Updated metadata for dataset: {dataset_name}")
        else:
            # Insert new dataset
            insert_data = [{
                'DatasetName': dataset_name,
                'DateAdded': current_time,
                'LastProcessed': current_time,
                'LatestRunID': run_id,
                'TotalRows': total_rows,
                'ColumnCount': total_columns,
                'TotalIssues': total_issues,
                'CompletenessScore': completion_score,
                'Status': 'Processed',
                'Description': f"Auto-processed dataset {dataset_name}"
            }]

            # Create DataFrame and write to SQL
            metadata_df = spark.createDataFrame(insert_data)
            metadata_df.write \
              .format("jdbc") \
              .option("url", jdbc_url) \
              .option("dbtable", "DatasetMetadata") \
              .mode("append") \
              .save()

            print(f"Added new dataset metadata: {dataset_name}")

        return True
    except Exception as e:
        print(f"Error exporting dataset metadata: {str(e)}")
        return False

# Export all data to SQL
print("\nExporting data quality results to SQL...")
try:
    # Export metrics data
    metrics_success = export_metrics_to_sql(run_id, profile_report, remediation_results)
    
    # Export profile data
    profiles_success = export_profiles_to_sql(run_id, profile_report)
    
    # Export issues data
    issues_success = export_issues_to_sql(run_id, instantiated_rules, issues)
    
    # Export rules data
    rules_success = export_rules_to_sql(run_id, instantiated_rules)
    
    # Export dataset metadata.
    # Pass the unweighted completeness percentage: by this point `completeness_score`
    # holds only the weighted contribution to the cleanliness score (at most 40),
    # which is the wrong scale for the CompletenessScore column.
    metadata_success = export_dataset_metadata(
        dataset_name, 
        run_id, 
        total_rows, 
        total_columns, 
        total_violations, 
        remediated_completeness
    )
    
    # Every export, including the metadata one, must succeed to report overall success
    if all([metrics_success, profiles_success, issues_success, rules_success, metadata_success]):
        print("Successfully exported all data quality results to SQL database")
    else:
        print("Some SQL exports were not successful. Check the error messages above.")
except Exception as e:
    print(f"Error during SQL export: {str(e)}")

##########################
# PART 6: FINAL REPORTING
##########################

print("\n=== PHASE 6: FINAL REPORTING ===\n")

# Total remediated issues
# total_remediated counts remediation result entries; total_fixed sums their
# "fixed_count" values (raises KeyError if an entry lacks that key).
total_remediated = len(remediation_results)
total_fixed = sum(r["fixed_count"] for r in remediation_results)

print("\nRemediation Summary:")
print(f"Total issues remediated: {total_remediated}")
print(f"Total records fixed: {total_fixed}")

# Repeat the headline validation figures computed in the validation phase
print("\nValidation Summary:")
print(f"Cleanliness Score: {cleanliness_score:.2f}%")
print(f"Cleanliness Status: {cleanliness_status}")
print(f"Certification: {'YES' if certification_flag else 'NO'}")
print(f"Remaining Issues: {remaining_issues}")

# NOTE(review): the completion message and the paths below are printed
# unconditionally, even if an earlier save/export step reported an error.
print("\nData Quality Framework execution completed successfully!")
print(f"Dataset: {dataset_name}")
print(f"Profiled data saved to: {profiled_container}/{dataset_name}_profiled_{run_id}.csv")
print(f"Remediated data saved to: {remediated_container}/{dataset_name}_remediated_{run_id}.csv")
if certification_flag:
    print(f"Certified data saved to: {certified_container}/{dataset_name}_certified_{run_id}.csv")
print(f"Data quality metrics exported to SQL database: {database}")
print(f"Run ID: {run_id}")

# Print connection information for Power BI
# `server` and `database` are defined outside this section of the notebook
print("\nPower BI Connection Information:")
print(f"Server: {server}")
print(f"Database: {database}")
print(f"Tables: DataQualityMetrics, DataQualityProfiles, DataQualityIssues, DataQualityRules, DatasetMetadata, DataQualityValidation")

# Return a JSON summary to the caller when this notebook runs inside a pipeline
if run_mode == "pipeline":
    result = {
        "dataset_name": dataset_name,
        "run_id": run_id,
        # NOTE(review): reported unconditionally, even if an earlier step printed an error
        "success": True,
        "total_rows": total_rows,
        "total_issues": total_violations,
        "fixed_issues": total_fixed,
        # Unweighted completeness percentage of the remediated data (0-100).
        # `completeness_score` holds only the weighted contribution to the
        # cleanliness score (at most 40) and would under-report completeness.
        "completeness_score": remediated_completeness,
        "cleanliness_score": cleanliness_score,
        "cleanliness_status": cleanliness_status,
        "certification_flag": certification_flag
    }
    
    # Ends notebook execution and hands the JSON string back to the caller
    dbutils.notebook.exit(json.dumps(result))

Starting Data Quality Profiling and Validation
Current date: 2025-04-02 22:20:27
Processing dataset: ban
Run mode: manual
Run ID: 20250402222027
Storage account configuration for dataqualitystore11 completed
Database connection configured using JDBC

=== PHASE 1: DATA PROFILING ===

Successfully loaded data from wasbs://raw-data@dataqualitystore11.blob.core.windows.net/ban_dirty.csv
Number of rows: 11205
Number of columns: 17

Sample data:
+---+----------+-------+---------+-------+-------+-------+----+-------+---+-----+--------+--------+-----+--------+--------+-------+
|age|       job|marital|education|default|balance|housing|loan|contact|day|month|duration|campaign|pdays|previous|poutcome|deposit|
+---+----------+-------+---------+-------+-------+-------+----+-------+---+-----+--------+--------+-----+--------+--------+-------+
| 59|    admin.|married|secondary|     no|   2343|    yes|  no|unknown|  5|  may|    1042|       1|   -1|       0| unknown|    yes|
| 56|    admin.|married|seco