In [0]:
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window
from delta.tables import DeltaTable

## Get last run time

In [0]:
# Look up the high-water mark recorded for the 'silver_transformation' stage.
last_ingest_df = spark.sql("SELECT last_timestamp FROM control.ctl.control_dates WHERE stage_name = 'silver_transformation'")
# first() returns None when the control table has no row for this stage,
# whereas collect()[0] would raise an opaque IndexError.
last_ingest_row = last_ingest_df.first()
if last_ingest_row is None or last_ingest_row['last_timestamp'] is None:
    # No previous run recorded: fall back to the same "beginning of time" sentinel
    # this notebook uses elsewhere so the next cell picks up every bronze row.
    # Without this, a NULL timestamp would be stored as the string "None" and the
    # ingest_timestamp filter would silently match nothing.
    last_ingest_time = '1900-01-01 00:00:00'
else:
    last_ingest_time = last_ingest_row['last_timestamp']
# Stored as a string in the Spark conf so the read cell can fetch it with spark.conf.get
spark.conf.set("last_ingest_time", str(last_ingest_time))

## Reading Data

In [0]:
%sql
-- Make bronze the current catalog for this session; the read cell below
-- references "products.products" without a catalog prefix.
USE CATALOG bronze;

In [0]:
# Only pick up rows ingested after the cutoff stored in the Spark conf
# under "last_ingest_time" (the value comes back as a string).
last_ingest_cutoff = spark.conf.get("last_ingest_time")
df = (
    spark.read.table("products.products")
    .where(col("ingest_timestamp") > last_ingest_cutoff)
)

## Clean data

In [0]:
# A category counts as missing when it is NULL or an empty string
category_is_blank = col("category").isNull() | (col("category") == "")

df = (
    df
    # Replace each price with its absolute value
    .withColumn("price", abs(col("price")))
    # Missing categories are bucketed under the literal "UNKNOWN"
    .withColumn("category", when(category_is_blank, "UNKNOWN").otherwise(col("category")))
    # Render ingest_timestamp as a 'yyyy-MM-dd HH:mm:ss' formatted string
    .withColumn("ingest_timestamp", date_format('ingest_timestamp', 'yyyy-MM-dd HH:mm:ss'))
)

### Split data for normalization

In [0]:
# Categories lookup: one row per category, stamped with the most recent
# ingest_timestamp seen for that category in this batch.
# dropDuplicates(["category"]) would keep an arbitrary row's timestamp, making the
# category's last_updated nondeterministic between runs. The cleaning step formats
# ingest_timestamp as 'yyyy-MM-dd HH:mm:ss', so max() orders it chronologically.
# Every category present in the batch starts out flagged as active.
df_categories = df.groupBy("category")\
    .agg(max("ingest_timestamp").alias("ingest_timestamp"))\
    .withColumn("active_flag", lit(True))
# Product -> category assignments (deduplicated later, when the bridge table is built)
df_product_categories = df.select("product_id", "category", "ingest_timestamp")
# Product attributes without the category
df_products = df.select("product_id", "product_name", "price", "active_flag", "ingest_timestamp")

### Handle Products Table

Left join the newly arrived data to existing data to find new records vs updated records <br>
If there is no existing table, create an empty dataframe to simulate the existing data with 0 rows <br>
Split new records and updated records to handle separately <br>
Union them back together for later merging <br>
Dedupe as needed

In [0]:
# Builds df_products_final: the incoming products with surrogate keys and audit
# columns (date_created / last_updated), ready to be merged into silver.products.products.
if spark.catalog.tableExists("silver.products.products"):
    # Get existing products (only relevant columns)
    df_products_existing = spark.sql("SELECT product_key, product_id, date_created, last_updated FROM silver.products.products")
    # Find current max value for surrogate key.
    # MAX() is NULL when the table exists but is empty; fall back to 0 so the
    # keys generated below are never NULL.
    max_prod_key = spark.sql("SELECT MAX(product_key) AS max_prod_key FROM silver.products.products").collect()[0]["max_prod_key"] or 0
    # Set boolean to mark table exists for later merge
    prod_table_exists = True
else:
    # Create an empty df with the same columns as the branch above;
    # limit(0) keeps the schema but no rows
    df_products_existing = df_products.select("product_id")\
        .withColumn("product_key", lit(None).cast(IntegerType()))\
        .withColumn("date_created", lit('1900-01-01 00:00:00').cast(TimestampType()))\
        .withColumn("last_updated", lit('1900-01-01 00:00:00').cast(TimestampType()))\
        .limit(0)
    max_prod_key = 0
    prod_table_exists = False

# Renaming columns before join to avoid any conflicts
df_products_existing = df_products_existing.withColumnRenamed("product_key", "existing_product_key")\
    .withColumnRenamed("product_id", "existing_product_id")\
    .withColumnRenamed("date_created", "existing_date_created")\
    .withColumnRenamed("last_updated", "existing_last_updated")

# Keep only the most recently ingested row per product BEFORE assigning keys.
# Otherwise a new product that arrives several times in one batch is given
# several surrogate keys, all but one of which are thrown away afterwards,
# leaving gaps in the key sequence.
prod_latest_w_spec = Window.partitionBy("product_id").orderBy(desc("ingest_timestamp"))
df_products_latest = df_products\
    .withColumn("rn", row_number().over(prod_latest_w_spec))\
    .filter(col("rn") == 1)\
    .drop("rn")

# Left join to split new products and existing products:
# rows without a matching existing_product_key are new
df_products_joined = df_products_latest.join(df_products_existing, on=df_products_latest.product_id == df_products_existing.existing_product_id, how="left")
df_products_new = df_products_joined.filter(col("existing_product_key").isNull())
df_products_existing = df_products_joined.filter(col("existing_product_key").isNotNull())

# Existing products already have a product_key, date_created, and last_updated
# So we remove the extra product_id column
# Then remove the "existing" from the column names
# And set last_updated to ingest_timestamp as this will be when the record has most recently been updated
df_products_existing = df_products_existing.drop("existing_product_id", "existing_last_updated")\
    .withColumnRenamed("ingest_timestamp", "last_updated")\
    .withColumnRenamed("existing_product_key", "product_key")\
    .withColumnRenamed("existing_date_created", "date_created")

# For new products, we need to assign a product_key
# Which we do using row_number() over ordering by product_id
# Adding on the max product_key that already exists to ensure uniqueness
window_spec_prod = Window.orderBy("product_id")
# The new products get their product_key, date_created, and last_updated columns created
# Before dropping all the other columns
df_products_new = df_products_new\
    .withColumn("date_created", col("ingest_timestamp"))\
    .withColumn("last_updated", col("ingest_timestamp"))\
    .withColumn("product_key", row_number().over(window_spec_prod) + lit(max_prod_key))\
    .drop("existing_product_id", "existing_last_updated", "existing_product_key", "existing_date_created", "ingest_timestamp")

# Union the new products and existing products to have a single dataset to merge into the silver layer
df_products_final = df_products_new.unionByName(df_products_existing)

# Final guard: still keep exactly one row per product_id (the latest), in case the
# existing table itself held more than one row for a product and fanned out the join
prod_dedupe_w_spec = Window.partitionBy("product_id").orderBy(desc("last_updated"))

df_products_final = df_products_final\
    .withColumn("rn", row_number().over(prod_dedupe_w_spec))\
    .filter(col("rn") == 1)\
    .drop("rn")

### Handle Categories Lookup

Same as above, but no deduping necessary as dupes were removed when we split up the dataframes for normalization

In [0]:
# Builds df_categories_final: the incoming categories with surrogate keys and audit
# columns, ready to be merged into silver.products.categories_lookup.
if spark.catalog.tableExists("silver.products.categories_lookup"):
    # Create df with existing categories.
    # Both reads must target the same table that is checked above and written by
    # the write cell (categories_lookup); any other name fails once the table exists.
    df_categories_existing = spark.sql("SELECT category_key, category, date_created, last_updated FROM silver.products.categories_lookup")
    # MAX() is NULL when the table exists but is empty; fall back to 0 so the
    # keys generated below are never NULL.
    max_cat_key = spark.sql("SELECT MAX(category_key) AS max_cat_key FROM silver.products.categories_lookup").collect()[0]["max_cat_key"] or 0
    cat_lookup_exists = True
else:
    # Create an empty df with the same columns as the branch above;
    # limit(0) keeps the schema but no rows
    df_categories_existing = df_categories.select("category")\
        .withColumn("category_key", lit(None).cast(IntegerType()))\
        .withColumn("date_created", lit('1900-01-01 00:00:00').cast(TimestampType()))\
        .withColumn("last_updated", lit('1900-01-01 00:00:00').cast(TimestampType()))\
        .limit(0)
    max_cat_key = 0
    cat_lookup_exists = False

# Prefix the existing columns so they cannot clash with the incoming ones in the join
df_categories_existing = df_categories_existing.withColumnRenamed("category_key", "existing_category_key")\
    .withColumnRenamed("category", "existing_category")\
    .withColumnRenamed("date_created", "existing_date_created")\
    .withColumnRenamed("last_updated", "existing_last_updated")

# Left join to split new categories (no existing key) from already-known ones
df_categories_joined = df_categories.join(df_categories_existing, on=df_categories.category == df_categories_existing.existing_category, how="left")
df_categories_new = df_categories_joined.filter(col("existing_category_key").isNull())
df_categories_existing = df_categories_joined.filter(col("existing_category_key").isNotNull())

# Known categories keep their key and date_created; last_updated becomes the incoming ingest_timestamp
df_categories_existing = df_categories_existing.drop("existing_last_updated", "existing_category")\
    .withColumnRenamed("ingest_timestamp", "last_updated")\
    .withColumnRenamed("existing_category_key", "category_key")\
    .withColumnRenamed("existing_date_created", "date_created")

# New categories get keys numbered after the current maximum, ordered by category name
window_spec_cat = Window.orderBy("category")

df_categories_new = df_categories_new\
    .withColumn("last_updated", col("ingest_timestamp"))\
    .withColumn("date_created", col("ingest_timestamp"))\
    .withColumn("category_key", row_number().over(window_spec_cat) + lit(max_cat_key))\
    .drop("existing_category", "existing_category_key", "existing_date_created", "existing_last_updated", "ingest_timestamp")

# Single dataset (new + known categories) to merge into the silver layer
df_categories_final = df_categories_new.unionByName(df_categories_existing)

### Handle joining table

Dedupe on product (each product can only have 1 category), preserving the most recently ingested record <br>
Use the final datasets created above to get the correct surrogate keys <br>
If any surrogate keys already existed, they would have been found above <br>
Only need to keep surrogate keys for easy joining

In [0]:
# Builds the product <-> category bridge rows: (product_key, category_key, last_updated).

# Each product may only have one category: keep its most recently ingested row
prod_cat_dedupe_w_spec = Window.partitionBy("product_id").orderBy(desc("ingest_timestamp"))

df_product_categories_latest = df_product_categories\
    .withColumn("rn", row_number().over(prod_cat_dedupe_w_spec))\
    .filter(col("rn") == 1)\
    .drop("rn")

# Swap the natural keys for surrogate keys.
# All three frames are derived from the same source df and share column names
# (category, product_id, last_updated, active_flag), so DataFrame-attribute
# references such as df_a.category == df_b.category are fragile and can be
# ambiguous. Joining by column name against narrow projections of the lookups
# leaves exactly one copy of every column and nothing to disambiguate.
df_product_categories = df_product_categories_latest\
    .join(df_categories_final.select("category", "category_key"), on="category", how="inner")\
    .join(df_products_final.select("product_id", "product_key"), on="product_id", how="inner")\
    .select("product_key", "category_key", col("ingest_timestamp").alias("last_updated"))

## Write Data

In [0]:
if not prod_table_exists:
    # First load: no silver table yet, so create it straight from the dataframe
    df_products_final.write.mode("overwrite").saveAsTable("silver.products.products")
else:
    # Upsert into the existing Delta table, matching rows on the surrogate key
    dlt_prod = DeltaTable.forName(spark, "silver.products.products")
    (
        dlt_prod.alias("t")
        .merge(df_products_final.alias("s"), "t.product_key = s.product_key")
        .whenMatchedUpdateAll()
        .whenNotMatchedInsertAll()
        .execute()
    )

In [0]:
if not cat_lookup_exists:
    # No lookup table yet: create it from the dataframe.
    # Schema is already set up to save in correct cloud location
    df_categories_final.write.mode("overwrite").saveAsTable("silver.products.categories_lookup")
else:
    # Handle on the existing Delta table
    dlt_cat = DeltaTable.forName(spark, "silver.products.categories_lookup")
    # Upsert: rows matching on category_key are overwritten, the rest are inserted
    (
        dlt_cat.alias("t")
        .merge(df_categories_final.alias("s"), "t.category_key = s.category_key")
        .whenMatchedUpdateAll()
        .whenNotMatchedInsertAll()
        .execute()
    )

In [0]:
if not spark.catalog.tableExists("silver.products.product_categories"):
    # Bridge table does not exist yet: create it from the dataframe
    df_product_categories.write.mode("overwrite").saveAsTable("silver.products.product_categories")
else:
    # Upsert keyed on product_key: a matched product's row is overwritten with
    # the incoming one, unmatched products are inserted
    dlt_prod_cat = DeltaTable.forName(spark, "silver.products.product_categories")
    (
        dlt_prod_cat.alias("t")
        .merge(df_product_categories.alias("s"), "t.product_key = s.product_key")
        .whenMatchedUpdateAll()
        .whenNotMatchedInsertAll()
        .execute()
    )

### Update Categories lookup

Make sure each category has up to date active_flag

In [0]:
%sql
-- Recompute active_flag for every row of the categories lookup.
-- A category is flagged true when its category_key appears in product_categories
-- for at least one product whose own active_flag is true; every other category
-- is flagged false.
UPDATE silver.products.categories_lookup
SET active_flag = CASE
    WHEN category_key IN (
        -- category keys linked to at least one active product
        SELECT pc.category_key
        FROM silver.products.product_categories AS pc
        JOIN silver.products.products AS p
        ON pc.product_key = p.product_key
        WHERE p.active_flag = true
    ) THEN true
    ELSE false
END