# In [None]:
import operator
from functools import reduce

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit

# Create (or reuse) the Spark session for the data-quality run, requesting 4g of
# memory for both executors and the driver.
# NOTE(review): getOrCreate() returns an already-running session if one exists,
# in which case these memory settings may not take effect — confirm in the
# notebook environment.
spark = (
    SparkSession.builder
    .appName("DataQualityChecks")
    .config("spark.executor.memory", "4g")
    .config("spark.driver.memory", "4g")
    .getOrCreate()
)

# Load the fact and dimension tables from the Spark catalog, in this order.
(
    dq_fact_immigration_table,
    dq_dimension_arrival_date_table,
    dq_dimension_airports_table,
    dq_dimension_demographics_table,
) = (
    spark.table(catalog_name)
    for catalog_name in (
        "fact_immigration_table",
        "dimension_arrival_date_table",
        "dimension_airports_table",
        "dimension_demographics_table",
    )
)
# Currently disabled:
# dq_dimension_countries_temperature_table = spark.table("dimension_countries_temperature_table")

# Data Quality Checks
#
# Each per-table check below takes (df, table_name) and returns a lazily
# evaluated DataFrame with an extra literal 'table' column naming the table it
# came from (this overwrites any existing 'table' column of df). A check only
# returns rows when it has something to report, so an empty result means the
# check passed. `limit` caps how many reported rows are kept.


# 1. Check for Missing Values
def missing_values_check(df, table_name, limit=100):
    """Return up to ``limit`` rows of ``df`` that contain at least one null value."""
    if not df.columns:
        # reduce() needs at least one column; a column-less frame has nothing to check.
        return df.withColumn("table", lit(table_name)).limit(0)
    # OR together one isNull() predicate per column -> "row has any null".
    any_null = reduce(operator.or_, [col(c).isNull() for c in df.columns])
    return df.filter(any_null).withColumn("table", lit(table_name)).limit(limit)


# 2. Check Data Types
def data_types_check(df, table_name, expected_types=None, limit=100):
    """Report columns whose Spark type differs from the expected one.

    ``expected_types`` maps column name -> type string in the format reported by
    ``DataFrame.dtypes`` (e.g. ``"int"``, ``"string"``, ``"double"``). Each
    mismatch becomes one row (column_name, actual_type, expected_type); an
    expected column missing from ``df`` is reported with actual_type null.
    With no expectations there is nothing to compare, so the result is empty.
    Only the schema is inspected, so no scan of the data is needed.
    """
    actual_types = dict(df.dtypes)
    mismatches = [
        (column_name, actual_types.get(column_name), expected)
        for column_name, expected in (expected_types or {}).items()
        if actual_types.get(column_name) != expected
    ]
    return (
        spark.createDataFrame(
            mismatches, "column_name string, actual_type string, expected_type string"
        )
        .withColumn("table", lit(table_name))
        .limit(limit)
    )


# 3. Duplicate Records
def duplicate_records_check(df, table_name, limit=100):
    """Return up to ``limit`` distinct rows occurring more than once in ``df``.

    Each returned row carries its occurrence ``count``; the most duplicated rows
    come first, so ``limit`` keeps the worst offenders.
    """
    return (
        df.groupBy(df.columns).count()
        .filter(col("count") > 1)
        .orderBy(col("count").desc())
        .withColumn("table", lit(table_name))
        .limit(limit)
    )

# 4. Referential Integrity
# Fact rows whose arrive_date has no matching arrival_date in the arrival-date
# dimension (this includes rows whose arrive_date is null). A left_anti join
# returns exactly those fact rows, keeping only the fact table's columns instead
# of padding them with the dimension's all-null columns.
referential_integrity_check = (
    dq_fact_immigration_table
    .join(
        dq_dimension_arrival_date_table,
        dq_fact_immigration_table["arrive_date"]
        == dq_dimension_arrival_date_table["arrival_date"],
        "left_anti",
    )
    .withColumn("table", lit("fact_immigration_table"))
    .limit(100)  # Adjust the limit as needed
)

# 5. Data Distribution
def data_distribution_check(df, column, table_name=None, limit=100):
    """Return the value frequencies of ``column`` in ``df``, most frequent first.

    ``table_name`` fills the result's 'table' column (``"unknown"`` when
    omitted); previously this was mislabelled with ``df``'s first column name.
    Ordering by descending ``count`` makes ``limit`` keep the most common values
    rather than an arbitrary subset of groups.
    """
    label = "unknown" if table_name is None else table_name
    return (
        df.groupBy(column).count()
        .orderBy(col("count").desc())
        .withColumn("table", lit(label))
        .limit(limit)
    )


# Example: Check distribution of ages in demographics table
age_distribution_check = data_distribution_check(
    dq_dimension_demographics_table, "median_age", "dimension_demographics_table"
)

# Execute Data Quality Checks
# Collect one lazily evaluated result DataFrame per (table, check); nothing is
# computed until the results are displayed.
data_quality_checks_results = []

tables_to_check = [
    (dq_fact_immigration_table, "fact_immigration_table"),
    (dq_dimension_arrival_date_table, "dimension_arrival_date_table"),
    (dq_dimension_airports_table, "dimension_airports_table"),
    (dq_dimension_demographics_table, "dimension_demographics_table"),
    # (dq_dimension_countries_temperature_table, "dimension_countries_temperature_table"),  # disabled
]

# Perform checks on each table: 1. missing values, 2. data types, 3. duplicates.
for table_df, table_name in tables_to_check:
    for check in (missing_values_check, data_types_check, duplicate_records_check):
        data_quality_checks_results.append(check(table_df, table_name))

# 4. The referential-integrity check spans two tables, so it is built once
# above rather than per table. Add it here so it is actually executed and
# reported; previously it was defined but never used.
data_quality_checks_results.append(referential_integrity_check)

# Display Results
# A non-empty result DataFrame is treated as an issue reported for the table
# named in its 'table' column.
try:
    for result_df in data_quality_checks_results:
        # One cheap action (fetch at most one row) instead of a full count()
        # followed by a distinct().collect() on the same result.
        first_rows = result_df.head(1)
        if not first_rows:
            continue
        table_name = first_rows[0]["table"]
        # A non-empty result means the check found something, not that it passed.
        print(f"Data Quality Checks found issues for {table_name}:")
        result_df.show(5)
finally:
    # Stop Spark session even if one of the checks raises.
    spark.stop()
