In [0]:
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window
from delta.tables import DeltaTable

## Get last run time

In [0]:
# Look up the timestamp recorded for the 'silver_transformation' stage in the control
# table; the incremental bronze reads use it as their lower bound.
last_ingest_df = spark.sql("SELECT last_timestamp FROM control.ctl.control_dates WHERE stage_name = 'silver_transformation'")
last_ingest_row = last_ingest_df.first()  # None when the query returns no rows

# Fail loudly instead of raising a bare IndexError (missing row) or storing the
# string "None" (NULL timestamp), which the downstream timestamp comparison
# cannot interpret as a valid lower bound.
if last_ingest_row is None or last_ingest_row["last_timestamp"] is None:
    raise ValueError(
        "control.ctl.control_dates has no usable last_timestamp for stage_name = 'silver_transformation'"
    )

last_ingest_time = last_ingest_row["last_timestamp"]
# Spark conf values are strings, so the timestamp is stored in its str() form.
spark.conf.set("last_ingest_time", str(last_ingest_time))

## Reading Data

In [0]:
%sql
-- Make `bronze` the session's current catalog so that the two-part table names
-- read afterwards (e.g. sales.transactions_s001) resolve inside it.
USE CATALOG bronze;

In [0]:
# Incremental load: keep only rows whose ingest_timestamp is later than the value
# stored under the "last_ingest_time" Spark conf key.
ingested_since_last_run = col("ingest_timestamp") > spark.conf.get("last_ingest_time")

df_s1 = spark.read.table("sales.transactions_s001").filter(ingested_since_last_run)
df_s2 = spark.read.table("sales.transactions_s002").filter(ingested_since_last_run)

Cleansing to be done:
- combine the two source tables into a single DataFrame
- deduplicate on transaction_id, keeping the most complete values from the duplicate rows
- check for valid prices
- remove records with an invalid customer or product

In [0]:
# Stack the two source frames, matching columns by name rather than by position.
df_combined = df_s1.unionByName(df_s2)

# Re-render ingest_timestamp as a "yyyy-MM-dd HH:mm:ss" formatted string.
df = df_combined.withColumn(
    "ingest_timestamp",
    date_format(col("ingest_timestamp"), "yyyy-MM-dd HH:mm:ss"),
)

In [0]:
# Score each row by the first quality check it passes (lower is better):
#   1 - customer_id is present
#   2 - product_id is not the "INVALID" marker
#   3 - quantity is positive
#   4 - none of the above
priority_rank = (
    when(col("customer_id").isNotNull(), 1)
    .when(col("product_id") != "INVALID", 2)
    .when(col("quantity") > 0, 3)
    .otherwise(4)
)

df = df.withColumn("priority", priority_rank)

In [0]:
# Fill customer_id / product_id / quantity on every duplicate of a transaction with
# the best valid value found anywhere among that transaction's rows.
#
# - The frame is widened to the whole partition. With an ORDER BY and no explicit
#   frame, Spark only looks from the start of the partition up to the current row,
#   so a higher-ranked row could never pick up a valid value held by a lower-ranked
#   duplicate.
# - ingest_timestamp (descending) breaks ties between rows of equal priority so the
#   chosen value is deterministic and favours the most recent row.
window_spec = (
    Window.partitionBy("transaction_id")
    .orderBy("priority", desc("ingest_timestamp"))
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
)

# Invalid values become NULL (a `when` without `otherwise` yields NULL) so that
# first(..., ignorenulls=True) skips them instead of returning them.
valid_product_id = when(col("product_id") != "INVALID", col("product_id"))
valid_quantity = when(col("quantity") > 0, col("quantity"))

df = df\
    .withColumn("customer_id", first("customer_id", ignorenulls=True).over(window_spec))\
    .withColumn("product_id", first(valid_product_id, ignorenulls=True).over(window_spec))\
    .withColumn("quantity", first(valid_quantity, ignorenulls=True).over(window_spec))\
    .drop("priority")

In [0]:
# Keep one row per transaction_id: the one with the greatest ingest_timestamp.
latest_first = Window.partitionBy("transaction_id").orderBy(col("ingest_timestamp").desc())

df = (
    df.withColumn("rn", row_number().over(latest_first))
    .where(col("rn") == 1)
    .drop("rn")
)

In [0]:
# Split the cleansed batch into rows already present in the silver table (which keep
# their transaction_key and date_created) and new rows (which receive the next
# surrogate keys), then recombine both sets into df_final.
if spark.catalog.tableExists("silver.sales.store_transactions"):
    # COALESCE: MAX() is NULL when the table exists but holds no rows, and a NULL
    # offset would turn every newly assigned transaction_key into NULL.
    max_surr_key = spark.sql("SELECT COALESCE(MAX(transaction_key), 0) AS max_tran_key FROM silver.sales.store_transactions").collect()[0]['max_tran_key']
    df_existing_records = spark.sql("SELECT transaction_id, transaction_key, date_created, last_updated FROM silver.sales.store_transactions")
    trans_table_exists = True
else:
    # First load: build a typed frame with the same four columns. Every
    # transaction_key is NULL, so the final filter leaves it empty and the left
    # join below classifies every incoming row as new.
    df_existing_records = df.select("transaction_id")\
        .withColumn("transaction_key", lit(None)).withColumn("transaction_key", col("transaction_key").cast(IntegerType()))\
        .withColumn("date_created", lit('1900-01-01 00:00:00')).withColumn("date_created", col("date_created").cast(TimestampType()))\
        .withColumn("last_updated", lit('1900-01-01 00:00:00')).withColumn("last_updated", col("last_updated").cast(TimestampType()))\
        .filter(col("transaction_key").isNotNull())
    max_surr_key = 0
    trans_table_exists = False

# Prefix the lookup columns so they cannot clash with the batch's own columns in the join.
df_existing_records = df_existing_records\
    .withColumnRenamed("transaction_key", "existing_transaction_key")\
    .withColumnRenamed("transaction_id", "existing_transaction_id")\
    .withColumnRenamed("date_created", "existing_date_created")\
    .withColumnRenamed("last_updated", "existing_last_updated")

# Left join keeps every batch row; a NULL existing_transaction_key marks a new transaction.
df_joined = df.join(df_existing_records, df.transaction_id == df_existing_records.existing_transaction_id, "left")
df_existing_records = df_joined.filter(col("existing_transaction_key").isNotNull())
df_new_records = df_joined.filter(col("existing_transaction_key").isNull())

# Already-known rows: keep the stored key and creation date; the batch's
# ingest_timestamp becomes the new last_updated.
# NOTE(review): ingest_timestamp appears to be a formatted string at this point while
# date_created from the table/stub is a timestamp - confirm the column types that
# unionByName resolves to are the intended ones.
df_existing_records = df_existing_records\
    .withColumnRenamed("existing_transaction_key", "transaction_key")\
    .withColumnRenamed("existing_date_created", "date_created")\
    .withColumnRenamed("ingest_timestamp", "last_updated")\
    .drop("existing_transaction_id", "existing_last_updated")

# No partitionBy: row_number() must number the whole set of new rows as one sequence,
# which makes Spark process them in a single partition.
new_surr_key_window = Window.orderBy("transaction_id")

# New rows: continue the key sequence after the current maximum and stamp both
# audit columns with the ingest time.
df_new_records = df_new_records\
    .withColumn("transaction_key", row_number().over(new_surr_key_window) + lit(max_surr_key))\
    .withColumn("date_created", col("ingest_timestamp"))\
    .withColumn("last_updated", col("ingest_timestamp"))\
    .drop("ingest_timestamp", "existing_transaction_key", "existing_transaction_id", "existing_date_created", "existing_last_updated")

df_final = df_existing_records.unionByName(df_new_records)

In [0]:
# Persist df_final: upsert into the silver table when it already exists, otherwise
# create it from this batch.
if trans_table_exists:
    # DeltaTable.forName takes the SparkSession as its first argument; calling it
    # with only the table name raises a TypeError.
    dlt_trans = DeltaTable.forName(spark, "silver.sales.store_transactions")
    # Match on the business key: matched rows are overwritten with the batch values,
    # unmatched rows are inserted.
    dlt_trans.alias("t").merge(
        df_final.alias("s"),
        "t.transaction_id = s.transaction_id",
    ).whenMatchedUpdateAll().whenNotMatchedInsertAll().execute()
else:
    df_final.write.mode("overwrite").saveAsTable("silver.sales.store_transactions")