In [0]:
# Look up the storage URL of the `bronze` and `silver` external locations: the first
# value of the `url` column returned by DESCRIBE EXTERNAL LOCATION for each name.
bronze, silver = (
    spark.sql(f''' describe external location `{location}` ''').select('url').collect()[0][0]
    for location in ("bronze", "silver")
)

In [0]:
from pyspark.sql.functions import col, when, trim, lower, length, lpad, sum as pyspark_sum
from pyspark.sql.types import StringType, IntegerType
from delta.tables import DeltaTable

In [0]:
# Read the raw dh_causal_lookup data from the bronze location.
# Parquet files embed their own schema, so no schema-inference option is set
# (`inferSchema` applies to text sources such as CSV/JSON, not to Parquet).
df = (
    spark.read.format("PARQUET")
    .load(f"{bronze}/dh_causal_lookup")
)
df.display()

### Remove Duplicates

In [0]:
def remove_duplicates(df, key_cols=("upc", "store", "week")):
    """Drop rows that repeat the same business key.

    Prints the row count of the incoming DataFrame (this triggers a count job), then
    keeps one row per distinct combination of ``key_cols``. Which of several duplicate
    rows is kept is decided by ``dropDuplicates`` and is not controlled here.

    Args:
        df: DataFrame containing every column named in ``key_cols``.
        key_cols: Column name, or iterable of column names, that identify a row.
            Defaults to ``("upc", "store", "week")``.

    Returns:
        A new DataFrame with one row per key combination.
    """
    # A bare string would otherwise be split into single characters by list().
    if isinstance(key_cols, str):
        key_cols = [key_cols]
    print(f"df before removing duplicates: {df.count()}")
    df_non_duplicates = df.dropDuplicates(list(key_cols))
    return df_non_duplicates

### Handle Missing Values

In [0]:
def handle_missing_values(df):
    """Fill missing description values with the placeholder "Not Specified".

    Only `feature_desc` and `display_desc` are filled; every other column is passed
    through as-is.

    Args:
        df: DataFrame exposing ``fillna``.

    Returns:
        The DataFrame returned by ``fillna`` for the two description columns.
    """
    placeholder = "Not Specified"
    fill_values = {name: placeholder for name in ("feature_desc", "display_desc")}
    return df.fillna(fill_values)

### Standardize Formats

In [0]:
def standardize_formats(df):
    """Lower-case and trim the `feature_desc` and `display_desc` columns.

    Args:
        df: DataFrame containing both description columns.

    Returns:
        A new DataFrame in which each description column has been replaced by
        ``trim(lower(<column>))``.
    """
    normalized = df
    for name in ("feature_desc", "display_desc"):
        normalized = normalized.withColumn(name, trim(lower(col(name))))
    return normalized

### Address Outliers

In [0]:
def filter_outliers(df, week_range=(1, 104), geography_range=(1, 2)):
    """Keep only rows whose `week` and `geography` fall inside the given ranges.

    Prints the row count of the incoming DataFrame (this triggers a count job) before
    the filter is applied.

    Args:
        df: DataFrame containing `week` and `geography` columns.
        week_range: ``(low, high)`` bounds passed to ``between`` for `week`.
        geography_range: ``(low, high)`` bounds passed to ``between`` for `geography`.

    Returns:
        A new DataFrame restricted to rows satisfying both range conditions.
    """
    print(f"df before outliers check: {df.count()}")
    week_in_range = col("week").between(week_range[0], week_range[1])
    geography_in_range = col("geography").between(geography_range[0], geography_range[1])
    return df.filter(week_in_range & geography_in_range)

### Validate Data Types

In [0]:
def enforce_data_types(df):
    """Cast every column to its target type and zero-pad short `upc` values.

    * `upc` is cast to string and left-padded with "0" up to 10 characters. Values that
      are already 10 or more characters long are kept unchanged, so an over-length UPC
      stays visible to a later length check instead of being cut down to 10 characters.
    * `store`, `week` and `geography` are cast to integer.
    * `feature_desc` and `display_desc` are cast to string.

    Args:
        df: DataFrame containing the six columns listed above.

    Returns:
        A new DataFrame with the same columns in their enforced types.
    """
    upc_text = col("upc").cast(StringType())
    # lpad() shortens any value longer than the requested width, which would silently
    # corrupt over-length UPCs; only pad values that are actually too short.
    upc_padded = when(length(upc_text) < 10, lpad(upc_text, 10, "0")).otherwise(upc_text)
    df_types = df.withColumn("upc", upc_padded) \
             .withColumn("store", col("store").cast(IntegerType())) \
             .withColumn("week", col("week").cast(IntegerType())) \
             .withColumn("geography", col("geography").cast(IntegerType())) \
             .withColumn("feature_desc", col("feature_desc").cast(StringType())) \
             .withColumn("display_desc", col("display_desc").cast(StringType()))
    return df_types

In [0]:
def main(
    df,
    delta_path=f"{silver}/dh_causal_lookup",
    table_name="carbo_catalog.silver.dh_causal_lookup",
    week_range=(1, 104),
    geography_range=(1, 2),
):
    """Clean the dh_causal_lookup data and upsert it into the silver Delta table.

    Runs the cleaning steps in order (fill missing values, remove duplicates,
    standardize text, filter out-of-range rows, enforce data types), printing
    diagnostics and displaying a five-row sample after each step. The cleaned rows are
    then merged into the Delta table at ``delta_path`` on (upc, store, week); if no
    Delta table exists at that path yet, one is created and registered as
    ``table_name``.

    Args:
        df: Raw DataFrame to clean.
        delta_path: Storage path of the target Delta table.
        table_name: Name under which the table is registered when it is first created.
        week_range: ``(low, high)`` bounds forwarded to ``filter_outliers`` for `week`.
        geography_range: ``(low, high)`` bounds forwarded to ``filter_outliers`` for
            `geography`.

    Returns:
        The Delta table at ``delta_path``, re-read after the write.
    """
    # Business key used for de-duplication and as the MERGE join condition.
    key_cols = ["upc", "store", "week"]

    print("Starting dh_causal_lookup cleaning process...")
    raw_count = df.count()
    print("Initial row count:", raw_count)
    display(df.limit(5))

    # Step 1
    df_step1 = handle_missing_values(df)
    print("\nStep 1 - Handle Missing Values")
    null_counts = df_step1.agg(
        *[pyspark_sum(col(c).isNull().cast("int")).alias(f"{c}_null_count") for c in df_step1.columns]
    ).collect()[0].asDict()
    print(f"Null counts after handling: {null_counts}")
    display(df_step1.limit(5))

    # Step 2
    df_step2 = remove_duplicates(df_step1)
    print("\nStep 2 - Remove Duplicates")
    # Step 1 only fills values and never adds or drops rows, so the count taken on the
    # raw input is reused instead of running another full count job.
    deduped_count = df_step2.count()
    print(f"Initial row count: {raw_count}, Final row count: {deduped_count}, Duplicates removed: {raw_count - deduped_count}")
    display(df_step2.limit(5))

    # Step 3
    df_step3 = standardize_formats(df_step2)
    print("\nStep 3 - Standardize Formats")
    print("Distinct feature_desc values:")
    display(df_step3.select("feature_desc").distinct().limit(5))
    print("Distinct display_desc values:")
    display(df_step3.select("display_desc").distinct().limit(5))
    display(df_step3.limit(5))

    # Step 4
    # Pass the ranges explicitly so the values logged below are the ones actually applied.
    df_step4 = filter_outliers(df_step3, week_range=week_range, geography_range=geography_range)
    print("\nStep 4 - Address Outliers")
    # Step 3 only rewrites column values, so its row count equals the de-duplicated count.
    filtered_count = df_step4.count()
    print(f"Rows before filtering: {deduped_count}, Rows after filtering: {filtered_count}, Outliers removed: {deduped_count - filtered_count}")
    print(f"Week range checked: {week_range}, Geography range checked: {geography_range}")
    display(df_step4.limit(5))

    # Step 5
    # Step 2 de-duplicated on the raw key values. Casting/padding can make previously
    # distinct keys equal, and MERGE fails when several source rows match one target
    # row, so the key is de-duplicated again on its final, normalised form.
    df_cleaned = enforce_data_types(df_step4).dropDuplicates(key_cols)
    print("\nStep 5 - Validate Data Types")
    invalid_upc_count = df_cleaned.filter(length(col("upc")) != 10).count()
    print(f"UPCs with incorrect length: {invalid_upc_count}")
    print("Schema after type enforcement:")
    df_cleaned.printSchema()
    display(df_cleaned.limit(5))

    print("\nCleaning process completed!")
    print(f"Final row count: {df_cleaned.count()}")
    display(df_cleaned.limit(5))

    # Delta Lake Merge
    if DeltaTable.isDeltaTable(spark, delta_path):
        print("\nDelta table exists, performing MERGE...")
        merge_condition = " AND ".join(f"target.{c} = source.{c}" for c in key_cols)
        delta_table = DeltaTable.forPath(spark, delta_path)
        delta_table.alias("target")\
            .merge(df_cleaned.alias("source"), merge_condition)\
            .whenMatchedUpdateAll()\
            .whenNotMatchedInsertAll()\
            .execute()
    else:
        print("\nDelta table does not exist, creating new table...")
        df_cleaned.write.format("delta")\
                    .mode("overwrite")\
                    .option("path", delta_path)\
                    .saveAsTable(table_name)

    # Load and log final state
    final_df = spark.read.format("delta").load(delta_path)
    print("\nCleaning and merge completed!")
    print(f"Final row count in Delta table: {final_df.count()}")
    display(final_df.limit(5))

    return final_df

# Run the cleaning pipeline on the DataFrame loaded from bronze. Note that this rebinds
# `df` to the DataFrame returned by main() (the Delta table read back from its path), so
# the raw input is no longer reachable under this name and re-running this line would
# feed main() its own output.
df = main(df)
display(df)

In [0]:
%sql
-- Row count of the registered silver table, e.g. to compare with the count printed by main().
SELECT count(*) FROM `carbo_catalog`.`silver`.`dh_causal_lookup`