In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit, current_date
from pyspark.sql.types import *

spark = SparkSession.builder \
    .appName("Spark SCD2 Pipeline") \
    .config("spark.jars.packages", "org.postgresql:postgresql:42.6.0") \
    .getOrCreate()
spark

In [6]:
jdbc_url = "jdbc:postgresql://postgres_container:5432/dbt_db"
db_props = {
    "user": "dbtuser",
    "password": "test123",
    "driver": "org.postgresql.Driver"}

incoming = spark.createDataFrame([
    (1, "Alice", "Berlin"),     # Alice changed city (was London)
    (2, "Bob", "Rome"),         # Bob unchanged
    (3, "Charlie", "New York"), # Charlie unchanged
    (4, "Daisy", "Tokyo"),      # New record
], ["customer_id", "name", "city"])

# -------------------------
# 4Ô∏è‚É£ Read existing dim table from PostgreSQL
# -------------------------
dim_df = spark.read.jdbc(jdbc_url, "public.dim_customers", properties=db_props)

print("======= EXISTING DIM TABLE =======")
dim_df.show()

+-----------+-------+--------+--------------+-----------+----------+
|customer_id|   name|    city|effective_date|expiry_date|is_current|
+-----------+-------+--------+--------------+-----------+----------+
|          1|  Alice|  London|    2024-01-01|       NULL|      true|
|          2|    Bob|   Paris|    2024-01-01|       NULL|      true|
|          3|Charlie|New York|    2024-01-01|       NULL|      true|
|          1|  Alice|  London|    2023-01-01| 2024-01-01|     false|
|          2|    Bob|   Paris|    2023-01-01| 2024-01-01|     false|
+-----------+-------+--------+--------------+-----------+----------+



In [7]:
joined = incoming.alias("stg").join(
    dim_df.filter(col("is_current") == True).alias("dim"),
    on="customer_id",
    how="left"
)

# -------------------------
# 6Ô∏è‚É£ Identify new and changed records
# -------------------------
changed = joined.filter(
    (col("dim.customer_id").isNotNull()) &
    (
        (col("stg.name") != col("dim.name")) |
        (col("stg.city") != col("dim.city"))
    )
)

new_records = joined.filter(col("dim.customer_id").isNull())

# -------------------------
# 7Ô∏è‚É£ Prepare expired records
# -------------------------
expired = changed.select("dim.*") \
    .withColumn("expiry_date", current_date()) \
    .withColumn("is_current", lit(False))

In [8]:
new_versions = changed.select("stg.*").union(new_records.select("stg.*")) \
    .withColumn("effective_date", current_date()) \
    .withColumn("expiry_date", lit(None).cast("date")) \
    .withColumn("is_current", lit(True))

# -------------------------
# 9Ô∏è‚É£ Keep unchanged records
# -------------------------
unchanged = dim_df.join(changed.select("dim.customer_id"), "customer_id", "left_anti")

# -------------------------
# üîü Combine all together
# -------------------------
final_df = unchanged.union(expired).union(new_versions)

print("======= FINAL DIM TABLE =======")
final_df.show()

# -------------------------
# 1Ô∏è‚É£1Ô∏è‚É£ Write back to PostgreSQL
# -------------------------
final_df.write.jdbc(jdbc_url, "public.dim_customers", mode="overwrite", properties=db_props)

print("‚úÖ SCD2 pipeline successfully updated dim_customers table.")

spark.stop()

+-----------+-------+--------+--------------+-----------+----------+
|customer_id|   name|    city|effective_date|expiry_date|is_current|
+-----------+-------+--------+--------------+-----------+----------+
|          3|Charlie|New York|    2024-01-01|       NULL|      true|
|          1|  Alice|  London|    2024-01-01| 2025-11-13|     false|
|          2|    Bob|   Paris|    2024-01-01| 2025-11-13|     false|
|          1|  Alice|  Berlin|    2025-11-13|       NULL|      true|
|          2|    Bob|    Rome|    2025-11-13|       NULL|      true|
|          4|  Daisy|   Tokyo|    2025-11-13|       NULL|      true|
+-----------+-------+--------+--------------+-----------+----------+

‚úÖ SCD2 pipeline successfully updated dim_customers table.
