In [0]:
%sql
-- Point the session at the TPC-H sample data first. Every statement below
-- that reads from it uses a fully qualified name.
USE CATALOG samples;
USE SCHEMA tpch;

-- Switch to the target catalog so the unqualified schema name `bronze`
-- is created there.
USE CATALOG catalog_cp;
CREATE SCHEMA IF NOT EXISTS bronze;

-- Copy each TPC-H sample table into the bronze schema.
-- CREATE OR REPLACE makes the cell safe to re-run.
CREATE OR REPLACE TABLE catalog_cp.bronze.customer AS
  SELECT * FROM samples.tpch.customer;

CREATE OR REPLACE TABLE catalog_cp.bronze.orders AS
  SELECT * FROM samples.tpch.orders;

CREATE OR REPLACE TABLE catalog_cp.bronze.lineitem AS
  SELECT * FROM samples.tpch.lineitem;

CREATE OR REPLACE TABLE catalog_cp.bronze.nation AS
  SELECT * FROM samples.tpch.nation;

CREATE OR REPLACE TABLE catalog_cp.bronze.part AS
  SELECT * FROM samples.tpch.part;

CREATE OR REPLACE TABLE catalog_cp.bronze.partsupp AS
  SELECT * FROM samples.tpch.partsupp;

CREATE OR REPLACE TABLE catalog_cp.bronze.supplier AS
  SELECT * FROM samples.tpch.supplier;

CREATE OR REPLACE TABLE catalog_cp.bronze.region AS
  SELECT * FROM samples.tpch.region;

-- Quick sanity check on one of the copied tables; this is the cell's displayed result.
SELECT COUNT(*) FROM catalog_cp.bronze.customer;



In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import md5, concat_ws, col, collect_list, lit, when
from pyspark.sql.functions import coalesce, concat, conv, count, substring
from pyspark.sql.functions import sum as spark_sum
from pyspark.sql.utils import AnalysisException

# Reuse the already-running Spark session if there is one; otherwise create it.
spark = SparkSession.builder.getOrCreate()

# Where the reference data lives and where the copies were written.
source_catalog, source_schema = "samples", "tpch"
target_catalog, target_schema = "catalog_cp", "bronze"

# Tables to validate; each name must exist under both locations above.
tables = [
    "customer",
    "orders",
    "lineitem",
    "nation",
    "part",
    "partsupp",
    "supplier",
    "region",
]

def get_table_stats(catalog, schema, table):
    """Build a one-row DataFrame summarising the content of a table.

    The checksum is independent of row order: each row is hashed with MD5 and
    the per-row digests are combined with a sum, so two tables holding the
    same rows always yield the same checksum no matter how Spark happens to
    scan or partition them.

    Args:
        catalog: Catalog name of the table.
        schema: Schema name of the table.
        table: Table name; also emitted verbatim in the ``table_name`` column.

    Returns:
        A lazy, single-row DataFrame with columns ``table_name`` (string),
        ``row_count`` (long) and ``checksum`` (32-character hex string).
        An empty table yields ``row_count`` 0 and a fixed checksum.

    Note:
        Column values are joined with ``"||"`` before hashing, so values that
        themselves contain ``"||"`` can in principle collide across adjacent
        columns.
    """
    df = spark.table(f"{catalog}.{schema}.{table}")

    # concat_ws silently skips NULLs, which would make (a, NULL) and (NULL, a)
    # hash identically. Tag every value so a NULL ("N") can never be confused
    # with a real value ("V" + text), and so NULL positions are preserved.
    tagged_cols = [
        when(col(name).isNull(), lit("N"))
        .otherwise(concat(lit("V"), col(name).cast("string")))
        for name in df.columns
    ]
    row_hashes = df.select(md5(concat_ws("||", *tagged_cols)).alias("row_hash"))

    def _digest_sum(start):
        # 15 hex digits = 60 bits: small enough for conv() to return a
        # non-negative integer, and decimal(38,0) leaves ample headroom for
        # the sum. Addition is commutative, hence order-independent.
        part = conv(substring("row_hash", start, 15), 16, 10).cast("decimal(38,0)")
        return spark_sum(part)

    # Count and digest are computed in a single aggregation, so the table is
    # scanned once and both figures describe the same snapshot.
    stats = row_hashes.agg(
        count(lit(1)).alias("row_count"),
        _digest_sum(1).alias("digest_hi"),
        _digest_sum(16).alias("digest_lo"),
    )

    # sum() over zero rows is NULL; map it to "0" so empty tables still get a
    # well-defined checksum.
    checksum = md5(
        concat_ws(
            "||",
            coalesce(col("digest_hi").cast("string"), lit("0")),
            coalesce(col("digest_lo").cast("string"), lit("0")),
        )
    )
    return stats.select(
        lit(table).alias("table_name"),
        col("row_count"),
        checksum.alias("checksum"),
    )

# Per-table validation frames, keyed by table name, kept for later inspection.
validation_dfs = {}

all_match = True


def _stats_or_placeholder(catalog, schema, table):
    """Return get_table_stats(...), or an all-NULL stats row if the table cannot be resolved.

    A table that does not exist raises AnalysisException while the stats frame
    is being built. Substituting a placeholder row with NULL count/checksum
    lets the comparison below report MISSING_TABLE instead of aborting the
    whole run on the first missing table.
    """
    try:
        return get_table_stats(catalog, schema, table)
    except AnalysisException as exc:
        # Surface the reason: AnalysisException also covers problems other
        # than a missing table, and the reader should be able to tell.
        reason = (str(exc).splitlines() or [""])[0]
        print(f"Could not read '{catalog}.{schema}.{table}': {reason}")
        return spark.createDataFrame(
            [(table, None, None)],
            "table_name string, row_count long, checksum string",
        )


for table in tables:
    source_stats = _stats_or_placeholder(source_catalog, source_schema, table) \
        .withColumnRenamed("row_count", "source_count") \
        .withColumnRenamed("checksum", "source_checksum")
    target_stats = _stats_or_placeholder(target_catalog, target_schema, table) \
        .withColumnRenamed("row_count", "target_count") \
        .withColumnRenamed("checksum", "target_checksum")

    # Both sides are single-row frames carrying the same table_name, so the
    # join always produces exactly one row.
    validation_df = source_stats.join(target_stats, on="table_name", how="outer")

    # A NULL count makes the equality test evaluate to NULL (not true), so
    # placeholder rows fall through to the MISSING_TABLE branch.
    validation_df = validation_df.withColumn(
        "validation_result",
        when((col("source_count") == col("target_count")) & (col("source_checksum") == col("target_checksum")), "MATCH")
        .when((col("source_count").isNull()) | (col("target_count").isNull()), "MISSING_TABLE")
        .otherwise("MISMATCH")
    )

    validation_dfs[table] = validation_df

    result = validation_df.select("validation_result").collect()[0][0]
    print(f"Table '{table}': {result}")

    if result != "MATCH":
        all_match = False

if all_match:
    print("\nAll tables match.")
else:
    print("\nOne or more tables have mismatches or are missing.")