In [0]:
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window
from delta.tables import DeltaTable

In [0]:
# The "init_load_flag" widget value is converted to an int here; later cells
# compare it against 0 and 1 to choose between an incremental and an initial load.
_init_load_flag_raw = dbutils.widgets.get("init_load_flag")
init_load_flag = int(_init_load_flag_raw)

In [0]:
# Load every column of the silver customers table as the source frame.
silver_customers_query = """
SELECT * FROM databricks_dev.silver.silver_customers
"""
df = spark.sql(silver_customers_query)

In [0]:
# Keep a single row per customer_id.
# NOTE(review): no ordering is applied beforehand, so which duplicate row
# survives is presumably arbitrary — confirm this is acceptable.
dedupe_keys = ['customer_id']
df = df.drop_duplicates(subset=dedupe_keys)

In [0]:
# Choose the query that supplies the already-loaded dimension rows:
#   * init_load_flag == 0 -> read the keys and audit dates from the gold table;
#   * any other value     -> build an empty frame with the same four column
#     names (WHERE 1=0 returns no rows), so every incoming customer counts as new.
if init_load_flag == 0:
    existing_rows_query = """
    SELECT dim_customer_key, customer_id, create_date, update_date 
    FROM databricks_dev.gold.dim_customers
    """
else:
    existing_rows_query = """
    SELECT 0 dim_customer_key, 0 customer_id, 0 create_date, 0 update_date 
    FROM databricks_dev.silver.silver_customers
    WHERE 1=0
    """

df_old = spark.sql(existing_rows_query)

In [0]:
# Prefix the existing-row columns with "old_" so they stay distinguishable
# from the incoming columns once the two frames are joined.
_old_column_names = {
    "dim_customer_key": "old_dim_customer_key",
    "customer_id": "old_customer_id",
    "create_date": "old_create_date",
    "update_date": "old_update_date",
}
for _current_name, _prefixed_name in _old_column_names.items():
    df_old = df_old.withColumnRenamed(_current_name, _prefixed_name)

In [0]:
# Left-join the incoming customers to the existing rows on customer id:
# every incoming row is kept, and the old_* columns are NULL when no
# existing row matches.
customer_match = df["customer_id"] == df_old["old_customer_id"]
df_join = df.join(df_old, on=customer_match, how="left")

In [0]:
# Preview the first ten joined rows.
df_join_preview = df_join.limit(10)
df_join_preview.display()

In [0]:
# Rows with no existing surrogate key after the left join are new customers.
has_no_existing_key = df_join["old_dim_customer_key"].isNull()
df_new = df_join.where(has_no_existing_key)


In [0]:
# Rows that did find an existing surrogate key are customers already in the
# dimension; df_old is rebound to these joined rows from here on.
has_existing_key = df_join["old_dim_customer_key"].isNotNull()
df_old = df_join.where(has_existing_key)

In [0]:
# Shape the already-existing customers for the final load:
#   1. discard the helper columns that are no longer needed,
#   2. restore the original surrogate key and creation date under their
#      final column names,
#   3. cast create_date to a timestamp,
#   4. stamp update_date with the current time.
df_old = (
    df_old
    .drop("old_customer_id", "old_update_date")
    .withColumnRenamed("old_dim_customer_key", "dim_customer_key")
    .withColumnRenamed("old_create_date", "create_date")
    .withColumn("create_date", to_timestamp(col("create_date")))
    .withColumn("update_date", current_timestamp())
)

In [0]:
# Preview the first ten existing-customer rows.
df_old_preview = df_old.limit(10)
df_old_preview.display()

In [0]:
# Shape the new customers for the final load: remove all old_* helper columns
# (they are NULL for these rows) and stamp both audit dates with the current time.
df_new = (
    df_new
    .drop("old_dim_customer_key", "old_customer_id", "old_update_date", "old_create_date")
    .withColumn("update_date", current_timestamp())
    .withColumn("create_date", current_timestamp())
)

In [0]:
# Preview only the first ten new customers, matching the other preview cells,
# instead of rendering the entire frame.
df_new.limit(10).display()

In [0]:
# Number the new customers 1..N in customer_id order. Partitioning by the
# constant lit(1) places every row in one window partition, so row_number()
# yields a single global sequence.
single_partition = Window.partitionBy(lit(1))
window_spec = single_partition.orderBy("customer_id")
df_new = df_new.withColumn("row_num", row_number().over(window_spec))

In [0]:
# Determine the highest surrogate key already in use.
#   * Initial load (init_load_flag == 1): nothing exists yet, start from 0.
#   * Otherwise: read the current maximum from the gold dimension table.
if init_load_flag == 1:
    max_surrogate_key = 0
else:
    df_maxsur = spark.sql("SELECT MAX(dim_customer_key) AS max_surrogate_key FROM databricks_dev.gold.dim_customers")
    max_surrogate_key = df_maxsur.first()['max_surrogate_key']
    # MAX() over an empty table (or over all-NULL keys) returns NULL, which
    # arrives here as None. Adding lit(None) below would turn every new
    # surrogate key into NULL, so fall back to 0 in that case.
    if max_surrogate_key is None:
        max_surrogate_key = 0

# Assign new surrogate keys: continue the sequence after the current maximum
# using the 1-based row number, then discard the helper column.
df_new = df_new.withColumn("dim_customer_key", col("row_num") + lit(max_surrogate_key)) \
               .drop("row_num")

In [0]:
# Stack the new customers and the refreshed existing customers into a single
# frame; columns are matched by name rather than by position.
df_final = df_new.unionByName(other=df_old)

In [0]:
# Preview only the first ten rows of the final frame, matching the other
# preview cells, instead of rendering the entire frame.
df_final.limit(10).display()

In [0]:
# Persist the dimension:
#   * target table missing -> create it from df_final as a Delta table;
#   * target table present -> upsert on customer_id (matched rows have all
#     columns updated, unmatched rows are inserted).
target_table = "databricks_dev.gold.dim_customers"

if not spark.catalog.tableExists(target_table):
    initial_writer = df_final.write.format("delta").mode("overwrite")
    initial_writer.saveAsTable(target_table)
else:
    delta_tbl = DeltaTable.forName(spark, target_table)
    upsert = delta_tbl.alias("trg").merge(
        df_final.alias("src"),
        "trg.customer_id = src.customer_id",
    )
    upsert = upsert.whenMatchedUpdateAll().whenNotMatchedInsertAll()
    upsert.execute()

In [0]:
%sql
-- Inspect the contents of the gold customer dimension after the load above.
SELECT * FROM databricks_dev.gold.dim_customers;