In [None]:
from pyspark.sql.functions import col, explode_outer, lit, when, row_number, trim
from pyspark.sql.types import *
import os
from datetime import datetime
from delta.tables import *

##### dim_client as a SCD Type 2 using PySpark

In [None]:
# Distinct (ClientCode, ClientName) pairs taken from the incoming dataframe
dim_client = df.select('ClientCode', 'ClientName').dropDuplicates()

# Create the dim_client Delta table with an explicit schema unless it already exists
(
    DeltaTable.createIfNotExists(spark)
    .tableName('dim_client')
    .addColumn('ClientCode', StringType())
    .addColumn('ClientName', StringType())
    .addColumn('StartDate', TimestampType())
    .addColumn('EndDate', TimestampType())
    .addColumn('CurrentFlag', BooleanType())
    .execute()
)

# SCD Type 2 metadata: every incoming row is staged as an open-ended, current version
# that starts at current_datetime (defined outside this cell)
source_table = dim_client.select(
    '*',
    lit(current_datetime).cast("timestamp").alias("StartDate"),
    lit(None).cast("timestamp").alias("EndDate"),
    lit("Y").cast("boolean").alias("CurrentFlag"),
)

# Handle on the existing dimension, used as the merge target in the next cell
target_table = DeltaTable.forPath(spark, 'Tables/dim_client')

In [None]:
# Insert / Update with SCD Type 2 logic, in two steps:
#   1. MERGE  - close the current version of every client whose name changed and
#               insert clients that have no current version yet.
#   2. APPEND - add a fresh current version for every client closed in step 1.
# CurrentFlag is a BooleanType column, so it is compared and written with real
# booleans instead of relying on implicit 'Y'/'N' string coercion.
# Comparisons use the null-safe operator (<=>): a plain = / != evaluates to NULL
# when either side is NULL, which would hide changes to or from NULL and would
# re-insert rows with a NULL ClientCode on every run.
# NOTE(review): the merge needs at most one source row per ClientCode; the source
# is de-duplicated on (ClientCode, ClientName) only - confirm the data guarantees it.
target_table.alias("target") \
    .merge(
        source_table.alias("source"),
        "target.ClientCode <=> source.ClientCode AND target.CurrentFlag = true"
    ) \
    .whenMatchedUpdate(
        condition = "NOT (target.ClientName <=> source.ClientName)",
        set = {
            "EndDate": lit(current_datetime).cast("timestamp"),
            "CurrentFlag": lit(False)
        }
    ) \
    .whenNotMatchedInsert(values = {
        "ClientCode": "source.ClientCode",
        "ClientName": "source.ClientName",
        "StartDate": "source.StartDate",
        "EndDate": "source.EndDate",
        "CurrentFlag": lit(True)
    }
    ) \
    .execute()

# Insert the new versions of the records that have changed.
# After the merge, a changed client is exactly a source row whose ClientCode no
# longer has a current row in the target. Selecting on that - instead of joining
# to every closed historical row with a different name - stops unchanged clients
# with older history from being re-inserted on each run, and makes a re-run of
# this cell a no-op.
current_target_rows = target_table.toDF().filter(col("CurrentFlag") == lit(True))

# left_anti keeps only the source columns, which already carry the typed
# StartDate / EndDate / CurrentFlag metadata added when source_table was built
changed_records = source_table.alias("source").join(
    current_target_rows.alias("target"),
    col("source.ClientCode").eqNullSafe(col("target.ClientCode")),
    "left_anti"
)

changed_records.write.format("delta").mode("append").save("Tables/dim_client")

##### Special Duplicates: removing rows with less information

In [None]:
from pyspark.sql.window import Window # Provides tools for creating window specifications that define how to partition and order data for window functions
from functools import reduce

# Function to count non-NULL values per row
def count_non_nulls(*cols):
    """Build a Column expression counting how many of the given columns are non-NULL.

    Args:
        *cols: Names of the columns to inspect; each is resolved with ``col()``.

    Returns:
        A Column that evaluates, per row, to the number of listed columns whose
        value is not NULL. With no column names the result is the literal 0.
    """
    # reduce() without an initial value raises TypeError on an empty sequence
    if not cols:
        return lit(0)
    # One 0/1 indicator per column, summed into a single expression
    non_null_flags = [when(col(c).isNotNull(), 1).otherwise(0) for c in cols]
    return reduce(lambda a, b: a + b, non_null_flags)

In [None]:
# Distinct collection address rows taken from the incoming dataframe
dim_collection = df \
    .select('CollectionCompany', 'CollectionAddressLine1', 'CollectionAddressLine2', 'CollectionAddressLine3', 'CollectionAddressLine4', 'CollectionPostcode', 'CollectionAlternativePlace', 'CollectionCountryCode', 'CollectionCountryName') \
    .dropDuplicates()

# Among rows that share company, address line 1 and postcode, keep the one with the most populated columns
# Create a column to count non-null values in each row
columns_to_check = dim_collection.columns
df_with_non_null_count = dim_collection.withColumn("NonNullCount", count_non_nulls(*columns_to_check))

# Define a window specification to partition by the columns we consider for duplicates.
# Rows are ranked by NonNullCount first; the remaining columns act as a tie-breaker so
# that the surviving row is the same on every run (an arbitrary winner among ties could
# change between runs and show up downstream as a spurious attribute change).
window_spec = Window.partitionBy("CollectionCompany", "CollectionAddressLine1", "CollectionPostcode").orderBy(
    col("NonNullCount").desc(),
    *[col(c).asc_nulls_last() for c in columns_to_check]
)

# Add a row number within each partition to identify the row with the most non-null values
df_with_row_number = df_with_non_null_count.withColumn("row_number", row_number().over(window_spec))

# Filter to keep only the rows with row number 1 (the row with the most non-null values in each group)
dim_collection = df_with_row_number.filter(col("row_number") == 1).drop("NonNullCount", "row_number")

##### dim_collection as a SCD Type 2 using PySpark

In [None]:
# Create the dim_collection Delta table with an explicit schema unless it already exists
(
    DeltaTable.createIfNotExists(spark)
    .tableName('dim_collection')
    .addColumn('CollectionCompany', StringType())
    .addColumn('CollectionAddressLine1', StringType())
    .addColumn('CollectionAddressLine2', StringType())
    .addColumn('CollectionAddressLine3', StringType())
    .addColumn('CollectionAddressLine4', StringType())
    .addColumn('CollectionPostcode', StringType())
    .addColumn('CollectionAlternativePlace', StringType())
    .addColumn('CollectionCountryCode', StringType())
    .addColumn('CollectionCountryName', StringType())
    .addColumn('StartDate', TimestampType())
    .addColumn('EndDate', TimestampType())
    .addColumn('CurrentFlag', BooleanType())
    .execute()
)

# SCD Type 2 metadata: every incoming row is staged as an open-ended, current version
# that starts at current_datetime (defined outside this cell)
source_table = dim_collection.select(
    '*',
    lit(current_datetime).cast("timestamp").alias("StartDate"),
    lit(None).cast("timestamp").alias("EndDate"),
    lit("Y").cast("boolean").alias("CurrentFlag"),
)

# Handle on the existing dimension, used as the merge target in the next cell
target_table = DeltaTable.forPath(spark, 'Tables/dim_collection')

In [None]:
# Insert / Update with SCD Type 2 logic, in two steps:
#   1. MERGE  - close the current version of every collection address whose tracked
#               attributes changed and insert addresses that have no current version yet.
#   2. APPEND - add a fresh current version for every address closed in step 1.
# CurrentFlag is a BooleanType column, so it is compared and written with real
# booleans instead of relying on implicit 'Y'/'N' string coercion.
# Comparisons use the null-safe operator (<=>): a plain = / != evaluates to NULL
# when either side is NULL, which would hide changes to or from NULL (common for
# optional address lines) and would re-insert rows with a NULL key on every run.
# NOTE(review): the merge needs at most one source row per
# (CollectionCompany, CollectionAddressLine1); the de-duplication window earlier in
# the notebook also partitions by CollectionPostcode, so two postcodes for the same
# company/address would reach this merge as two source rows - confirm the intended key.

# Attributes whose change closes the current version and opens a new one
collection_tracked_columns = [
    "CollectionAddressLine2",
    "CollectionAddressLine3",
    "CollectionAddressLine4",
    "CollectionPostcode",
    "CollectionAlternativePlace",
    "CollectionCountryCode",
    "CollectionCountryName",
]
collection_change_condition = " OR ".join(
    f"NOT (target.{c} <=> source.{c})" for c in collection_tracked_columns
)

target_table.alias("target") \
    .merge(
        source_table.alias("source"),
        "target.CollectionCompany <=> source.CollectionCompany AND target.CollectionAddressLine1 <=> source.CollectionAddressLine1 AND target.CurrentFlag = true"
    ) \
    .whenMatchedUpdate(
        condition = collection_change_condition,
        set = {
            "EndDate": lit(current_datetime).cast("timestamp"),
            "CurrentFlag": lit(False)
        }
    ) \
    .whenNotMatchedInsert(values = {
        "CollectionCompany": "source.CollectionCompany",
        "CollectionAddressLine1": "source.CollectionAddressLine1",
        "CollectionAddressLine2": "source.CollectionAddressLine2",
        "CollectionAddressLine3": "source.CollectionAddressLine3",
        "CollectionAddressLine4": "source.CollectionAddressLine4",
        "CollectionPostcode": "source.CollectionPostcode",
        "CollectionAlternativePlace": "source.CollectionAlternativePlace",
        "CollectionCountryCode": "source.CollectionCountryCode",
        "CollectionCountryName": "source.CollectionCountryName",
        "StartDate": "source.StartDate",
        "EndDate": "source.EndDate",
        "CurrentFlag": lit(True)
    }
    ) \
    .execute()

# Insert the new versions of the records that have changed.
# After the merge, a changed address is exactly a source row whose key no longer
# has a current row in the target. Selecting on that - instead of joining to every
# closed historical row with different attributes - stops unchanged addresses with
# older history from being re-inserted on each run, and makes a re-run of this
# cell a no-op.
current_target_rows = target_table.toDF().filter(col("CurrentFlag") == lit(True))

# left_anti keeps only the source columns, which already carry the typed
# StartDate / EndDate / CurrentFlag metadata added when source_table was built
changed_records = source_table.alias("source").join(
    current_target_rows.alias("target"),
    col("source.CollectionCompany").eqNullSafe(col("target.CollectionCompany"))
    & col("source.CollectionAddressLine1").eqNullSafe(col("target.CollectionAddressLine1")),
    "left_anti"
)

# Append to the same table the merge ran against; writing the new versions to a
# different path would leave the closed rows without a current replacement.
changed_records.write.format("delta").mode("append").save("Tables/dim_collection")