Setting things up

In [0]:
import pyspark.sql.functions as F
from pyspark.sql import Window as W
import pyspark.sql.types as T


In [0]:
# Install the H3 library (only once per cluster)
# NOTE(review): the install is unpinned. A later cell calls `h3.latlng_to_cell`,
# which presumably requires the h3 v4 API — confirm and consider pinning the version.
%pip install h3

import h3
import pyspark.sql.functions as F
from pyspark.sql.types import StringType

In [0]:
# Configuration parameters for `Step 4`
# Target Delta table (schema + name) and the temp view the load reads from.
targetDB = "gold"
targetTableName = "dim_franchise"
sourceView = "vw_source_franchise"

# Natural (business) key used to match source rows to target rows.
naturalKey = "franchiseID"

# Comma-separated list of business columns carried into the dimension.
businessColumns = ",".join([
    "franchiseID",
    "name",
    "city",
    "district",
    "zipcode",
    "country",
    "size",
    "longitude",
    "latitude",
    "supplierID",
    "location_index",
])

# Schema evolution would normally be enabled here, but this setting isn't available in Free Edition:
# spark.conf.set("spark.databricks.delta.schema.autoMerge.enabled", "true")

Loading Data

In [0]:
# Read the raw franchise table from the sample catalog.
df_franchises = spark.read.table(
    "samples.bakehouse.sales_franchises"
)

# Inspect the schema, then preview a few rows as the cell output.
df_franchises.printSchema()
df_franchises.limit(3).toPandas()

In [0]:
# Sanity check: longitude/latitude extremes for each franchise `size`.
# groupBy() already yields exactly one row per `size`, so the former trailing
# distinct() was a no-op that only added an extra shuffle; it has been removed.
(
    df_franchises
        .groupBy('size')
        .agg(
            F.min('longitude').alias('min_longitude'),
            F.max('longitude').alias('max_longitude'),
            F.min('latitude').alias('min_latitude'),
            F.max('latitude').alias('max_latitude')
        )
        .orderBy('size')
        .show(50)
)

`latitude/longitude` extreme values + `size` values look good

In [0]:
# List every distinct (country, city) pair to eyeball spelling and casing.
(
    df_franchises
        .select('country', 'city')
        .distinct()
        .orderBy('country', 'city')
        .show(50)
)

cities look good

In [0]:
# Count NULL values per each column of a dataframe.
# F.when(...) without otherwise() is NULL for non-matching rows, and F.count
# ignores NULLs, so each aggregate counts only the rows where the column is NULL.
(
    df_franchises
        .select(*(
            F.count(F.when(F.col(c).isNull(), c)).alias(c)
            for c in df_franchises.columns
        ))
        .toPandas()
)

In [0]:
# Count distinct values per each column of a dataframe (one aggregate per column).
(
    df_franchises
        .select(*(
            F.countDistinct(F.col(c)).alias(c)
            for c in df_franchises.columns
        ))
        .toPandas()
)

Transformations

In [0]:
# Standardizing some fields
def clean_str(colname):
    """Return a column expression that trims `colname` and then applies initcap.

    Args:
        colname: Column name (or column expression) to normalize.

    Returns:
        The expression `initcap(trim(colname))`.
    """
    trimmed = F.trim(colname)
    return F.initcap(trimmed)

# Normalize the free-text columns. The "US" country code is expanded to
# "United States" first, then the result goes through the same clean-up
# (trim + initcap) as the other text columns.
df_franchises_standardized = (
    df_franchises
        .withColumn("name", clean_str("name"))
        .withColumn("city", clean_str("city"))
        .withColumn("district", clean_str("district"))
        .withColumn(
            "country",
            clean_str(
                F.when(F.col("country") == "US", "United States")
                 .otherwise(F.col("country"))
            ),
        )
)

df_franchises_standardized.limit(10).toPandas()

Add an H3 cell index (`location_index`) for spatial joins/analytics

In [0]:
# h3.latlng_to_cell(lat, lon, resolution)
H3_RESOLUTION = 9  # resolution argument passed to h3.latlng_to_cell


def _latlng_to_h3(lat, lon, resolution=H3_RESOLUTION):
    """Convert a latitude/longitude pair to an H3 cell id string.

    Args:
        lat: Latitude value, or None.
        lon: Longitude value, or None.
        resolution: H3 resolution forwarded to `h3.latlng_to_cell`.

    Returns:
        The H3 cell id, or None when either coordinate is missing. Without this
        guard a single NULL coordinate raises inside the UDF and fails the job.
    """
    if lat is None or lon is None:
        return None
    return h3.latlng_to_cell(lat, lon, resolution)


h3_udf = F.udf(_latlng_to_h3, StringType())

df_franchises_standardized_h3 = (
    df_franchises_standardized
        .withColumn("location_index", h3_udf(F.col("latitude"), F.col("longitude")))
)

df_franchises_standardized_h3.limit(10).toPandas()

Referential Integrity w/ Suppliers

In [0]:
# Expose the standardized franchises as a temp view so they can be queried with SQL.
df_franchises_standardized_h3.createOrReplaceTempView("franchises")

# Referential-integrity check: LEFT JOIN to suppliers and keep only the rows
# where no supplier matched. Franchises with a NULL supplierID are returned
# as well, because NULL never satisfies the equality join condition.
query = """
SELECT f.*
FROM franchises f
LEFT JOIN samples.bakehouse.sales_suppliers s
  ON f.supplierID = s.supplierID
WHERE
  s.supplierID IS NULL
"""

# Franchises without a matching supplier ("orphans"); preview up to 10 of them.
df_orphan_franchises = spark.sql(query)
df_orphan_franchises.limit(10).toPandas()
#df_orphan_franchises.count()

In [0]:
# Data-quality log: append the orphaned franchises, stamped with the time of
# logging, to a Delta table. The schema is created first if it is missing.
dq_table = "dq_logs.orphan_franchises"

spark.sql("CREATE SCHEMA IF NOT EXISTS dq_logs")

(
    df_orphan_franchises
        .withColumn("logged_at", F.current_timestamp())
        .write
        .format("delta")
        .mode("append")
        .saveAsTable(dq_table)
)


Surrogate Key Assignment

**Why**
- Business data does not come with a stable surrogate key.
- We need a monotonically increasing `franchise_key` that stays unique across pipeline runs.

**How**
- Get the current max `franchise_key` from the target table.
  - If table doesn’t exist → start from 0
  - If table exists but empty → start from 0
  - Otherwise → continue from last max key
- Assign new keys using row_number() + offset = max_key

**Result**
- Keys are consistent and do not restart from 1 on every run
- Idempotent: safe whether table is empty, existing, or absent

In [0]:
# Get current max franchise_key from target table.
# The existence check is explicit instead of a catch-all `except`: an unrelated
# failure (permissions, missing column, cluster error) must not silently reset
# the offset to 0, because that would hand out surrogate keys already in use.
if spark.catalog.tableExists(f"{targetDB}.{targetTableName}"):
    max_key = (
        spark.table(f"{targetDB}.{targetTableName}")
            .agg(F.max("franchise_key").alias("max_key"))
            .collect()[0]["max_key"]
    )
    if max_key is None:
        max_key = 0  # Table exists but is empty
else:
    max_key = 0  # Table does not exist yet

# Assign surrogate keys: number the rows by franchiseID and shift the numbering
# by the current maximum key. The key is placed first, followed by every
# original column.
df_franchises_standardized_h3_PK = df_franchises_standardized_h3.select(
    (F.row_number().over(W.orderBy("franchiseID")) + max_key).alias("franchise_key"),
    *df_franchises_standardized_h3.columns,
)

df_franchises_standardized_h3_PK.orderBy("franchise_key").limit(10).toPandas()

Technical Columns

In [0]:
# Register the keyed source frame under the configured view name so the DDL
# below can select from it.
df_franchises_standardized_h3_PK.createOrReplaceTempView(sourceView)

# Creating a target table if it doesn't exist
# The CTAS is used only to derive the table schema: `WHERE 1=0` selects no rows,
# so the table is created empty. Besides the surrogate key and the business
# columns it defines the SCD Type 2 columns is_active, valid_from and valid_to.
tableDDL = f"""
CREATE TABLE IF NOT EXISTS {targetDB}.{targetTableName} 
USING DELTA 
PARTITIONED BY (size) 
TBLPROPERTIES ('delta.enableChangeDataFeed'='true')
AS 
SELECT
    row_number() OVER (ORDER BY {naturalKey}) AS franchise_key, -- surrogate key
    {businessColumns},
    1 as is_active,                     -- SCD Type2 Fields
    current_timestamp() as valid_from,  -- SCD Type2 Fields
    current_timestamp() as valid_to     -- SCD Type2 Fields
FROM {sourceView}
WHERE 1=0"""

#print(tableDDL)
# Execute the DDL; a no-op when the table already exists.
spark.sql(tableDDL)

SCD Type 2

In [0]:
target_table = f"{targetDB}.{targetTableName}"

# Resolve the SCD2 target with an explicit existence check instead of a bare
# `except:`. The bare form also swallowed unrelated errors (and KeyboardInterrupt)
# and cannot detect a missing table when table resolution is deferred until an
# action runs. df_target is still None when the table is absent.
if spark.catalog.tableExists(target_table):
    df_target = spark.table(target_table)
else:
    df_target = None
    print(f"WARNING: target table {target_table} does not exist; df_target is None")

# Source side of the SCD2 comparison: standardized franchises with surrogate keys.
df_source = df_franchises_standardized_h3_PK


In [0]:
# Quick look at the first row of the target table.
df_target.head()

In [0]:
join_condition = "src.franchiseID = tgt.franchiseID"

# Columns whose values define a "version" of a franchise: every business column
# except the natural key that is used for matching.
_hash_columns = [
    c.strip() for c in businessColumns.split(",") if c.strip() != naturalKey
]


def _with_hash_diff(df, columns=_hash_columns):
    """Return `df` with a `hash_diff` column: SHA-256 over the given columns.

    Each value is cast to string and NULLs are replaced by a marker, so a NULL
    and an empty string hash differently and concat_ws does not drop NULLs.
    """
    normalized = [
        F.coalesce(F.col(c).cast("string"), F.lit("<NULL>")) for c in columns
    ]
    return df.withColumn("hash_diff", F.sha2(F.concat_ws("||", *normalized), 256))


# Previously no target column was selected and neither side carried `hash_diff`,
# so the change detection that reads `src.hash_diff` / `tgt_hash_diff` could not
# resolve its columns. Both sides are now hashed the same way, and the target is
# restricted to the active version so history rows do not duplicate source rows.
df_joined = (
    _with_hash_diff(df_source).alias("src")
    .join(
        _with_hash_diff(df_target.filter(F.col("is_active") == 1)).alias("tgt"),
        F.expr(join_condition),
        "left",
    )
    .select(
        "src.*",
        F.col("tgt.franchiseID").alias("tgt_franchiseID"),
        F.col("tgt.hash_diff").alias("tgt_hash_diff"),
        F.lit(1).alias("tgt_is_current"),
        F.current_timestamp().alias("tgt_valid_from"),
        F.lit('9999-12-31').alias("tgt_valid_to")
    )
)

df_joined.limit(10).toPandas()


In [0]:
# Classify each joined row for the SCD2 load:
#   INSERT    - no target hash (new record)
#   UPDATE    - source and target hashes differ (changed record)
#   UNCHANGED - everything else
# NOTE(review): requires df_joined to expose `hash_diff` on the source side and
# `tgt_hash_diff` from the target — confirm the upstream join provides both.
df_changes = df_joined.withColumn(
    "record_flag",
    F.when(F.col("tgt_hash_diff").isNull(), "INSERT")
     .when(F.col("src.hash_diff") != F.col("tgt_hash_diff"), "UPDATE")
     .otherwise("UNCHANGED"),
)
