In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, ArrayType
from concurrent.futures import ThreadPoolExecutor

# Create the Spark session for this validation run (getOrCreate reuses an
# already-active session if one exists).
spark = (
    SparkSession.builder
    .appName("AutomatedValidation")
    .getOrCreate()
)

# Source catalog / database pair.
source_catalog, source_db = "core_tst_sys9", "rr_source"
# Target catalog / database pair.
target_catalog, target_db = "core_tst_std001", "ods"

# Names of the tables to process.
table_names = ["ptable1", "btable2", "ctable3"]

# Date-range filter bounds as ISO-8601 timestamps with a UTC (+00:00) offset.
# NOTE(review): whether these bounds are inclusive or exclusive depends on the
# filtering code that consumes them — confirm there.
filter_start = "2025-05-12T00:00:01.0025+00:00"
filter_end = "2025-05-13T00:00:01.0025+00:00"


In [None]:
def convert_boolean_columns(df):
    """Rewrite "t"/"f" flag columns of ``df`` as 1/0 integer columns.

    A column counts as a flag column when every distinct value it holds,
    compared case-insensitively after ``str()``, is "t", "f" or null (a real
    null or the literal string "null"). Such a column is replaced so that
    "t"/"T" becomes 1 and every other value becomes 0.

    NOTE(review): a column that is entirely null (or has no rows) also passes
    the check and is turned into a constant-0 column, whatever its type —
    confirm that is intended.

    Args:
        df: Spark DataFrame to inspect.

    Returns:
        ``(converted_df, converted_cols)`` where ``converted_cols`` lists the
        rewritten column names in ``df.columns`` order.
    """
    converted_cols = []
    for column in df.columns:
        # One Spark job per column. Only about 21 distinct raw values can map
        # to {"t", "f", "null"} case-insensitively, so sampling 100 distinct
        # values is enough to reject any non-flag column.
        unique_vals = df.select(column).distinct().limit(100).rdd.flatMap(lambda x: x).collect()
        string_vals = {str(v).lower() if v is not None else "null" for v in unique_vals}
        if string_vals.issubset({"t", "f", "null"}):
            # Detection is case-insensitive, so the mapping must be too: match
            # "t" and "T" (the only strings whose lower() is "t"). Nulls, "f"/"F"
            # and "null" all fall through to 0.
            df = df.withColumn(column, when(col(column).isin("t", "T"), 1).otherwise(0))
            converted_cols.append(column)
    return df, converted_cols

def to_lowercase_columns(df):
    """Return a projection of ``df`` in which every column name is lower-cased."""
    lowered = [col(name).alias(name.lower()) for name in df.columns]
    return df.select(*lowered)

def cast_columns_to_common_type(source_df, target_df, columns):
    """Cast ``columns`` of ``source_df`` to the types they have in ``target_df``.

    The target DataFrame's type is authoritative: whenever a column exists in
    the target, the source column is cast to that type (when it differs).
    Target columns are never changed, since casting them to their own type
    would be a no-op.

    Args:
        source_df: DataFrame whose columns are cast.
        target_df: DataFrame providing the reference types.
        columns: Column names to align and project.

    Returns:
        ``(source_df, target_df)``, each projected to ``columns`` in that order.

    Raises:
        pyspark AnalysisException: if a name in ``columns`` is missing from
        either DataFrame (raised by Spark when it resolves the cast or the
        final ``select``).
    """
    # Casting one column does not change another's type, so both schemas can be
    # read once up front instead of re-resolving them on every iteration.
    source_types = dict(source_df.dtypes)
    target_types = dict(target_df.dtypes)
    for col_name in columns:
        target_type = target_types.get(col_name)
        if not target_type:
            # No reference type available; leave the source column as-is.
            continue
        if source_types.get(col_name) != target_type:
            source_df = source_df.withColumn(col_name, col(col_name).cast(target_type))
    return source_df.select(columns), target_df.select(columns)

def normalize_target_decimal_to_double(target_df):
    """Cast every ``decimal(...)`` column of ``target_df`` to ``double``.

    Non-decimal columns are left untouched; each decimal column is replaced
    under its own name via ``withColumn``.
    """
    # NOTE(review): presumably done so decimal target values compare equal to
    # double-typed source values — confirm against the caller.
    decimal_names = [name for name, dtype in target_df.dtypes if dtype.startswith("decimal")]
    for name in decimal_names:
        target_df = target_df.withColumn(name, col(name).cast("double"))
    return target_df

def get_mismatched_columns(df1, df2, columns):
    """Return the names in ``columns`` whose distinct values differ between frames.

    A column is reported when some distinct value of it appears in one
    DataFrame but not the other, in either direction. Comparison uses
    ``subtract`` (EXCEPT DISTINCT), so duplicate counts are ignored. The check
    is symmetric: swapping ``df1`` and ``df2`` yields the same result.

    Args:
        df1: First DataFrame.
        df2: Second DataFrame.
        columns: Column names to check, in the order they are reported.

    Returns:
        List of mismatching column names, in ``columns`` order.
    """
    mismatched = []
    for col_name in columns:
        left = df1.select(col_name)
        right = df2.select(col_name)
        # Only existence of a difference matters: take(1) stops after one row
        # instead of counting them all, and `or` skips the reverse check once
        # a difference has been found.
        if left.subtract(right).take(1) or right.subtract(left).take(1):
            mismatched.append(col_name)
    return mismatched

def compare_dataframes(source_comp_df, target_comp_df, common_columns):
    """Compare two aligned DataFrames by whole rows and per column.

    Args:
        source_comp_df: Source-side DataFrame.
        target_comp_df: Target-side DataFrame.
        common_columns: Column names to check individually.

    Returns:
        ``(source_only, target_only, mismatches, mismatches_target)``:
        ``source_only`` / ``target_only`` count the distinct rows present only in
        the source / only in the target (``subtract`` is EXCEPT DISTINCT).
        ``mismatches`` and ``mismatches_target`` are the columns in
        ``common_columns`` whose distinct values differ between the two frames.
    """
    source_only = source_comp_df.subtract(target_comp_df).count()
    target_only = target_comp_df.subtract(source_comp_df).count()
    mismatches = get_mismatched_columns(source_comp_df, target_comp_df, common_columns)
    # get_mismatched_columns checks both subtraction directions, so calling it
    # again with the frames swapped returns exactly the same list. Reuse the
    # result instead of re-running every Spark job. Return a separate list
    # object so callers can mutate one without affecting the other.
    mismatches_target = list(mismatches)
    return source_only, target_only, mismatches, mismatches_target
