# Dimension Modeling

## Dim_model Sink - Initial and Incremental run

In [0]:
from pyspark.sql import functions as F
from pyspark.sql.types import DateType, IntegerType, LongType
from pyspark.sql.window import Window

# Source side of the merge: the distinct (subcategory_key, subcategory_name)
# pairs found in the silver-layer contoso_sales parquet data.
df_src = spark.sql('''
select distinct subcategory_key, subcategory_name
  from parquet.`abfss://silver@contosoprojectstorage.dfs.core.windows.net/contoso_sales`
''')

# Target side of the merge: the existing gold dimension when it is already
# registered in the catalog, otherwise (first run) a zero-row DataFrame that
# only carries the dimension's columns and types.
if not spark.catalog.tableExists('contoso_catalog.gold.dim_subcategory'):
    df_tgt = (
        df_src
        .withColumn("subcategory_id", F.lit(0).cast(LongType()))
        .withColumn("is_current", F.lit(1).cast(IntegerType()))
        .withColumn("start_date", F.current_date())
        .withColumn("end_date", F.lit(None).cast(DateType()))
        .filter("1 = 0")  # always-false predicate: keep the schema, drop every row
    )
else:
    df_tgt = spark.sql('''
        SELECT  subcategory_id, 
               subcategory_key, 
               subcategory_name,
                is_current,
                start_date,
                end_date
        FROM contoso_catalog.gold.dim_subcategory
    ''')

def generate_surrogate_key(df, start_value):
    """Return ``df`` with a ``subcategory_id`` column derived from key order.

    Rows are numbered with ``row_number()`` over one unpartitioned window
    ordered by ``subcategory_key``, and ``start_value`` is added to that
    number. Because ``row_number()`` is 1-based, ``start_value`` acts as an
    offset: the smallest id produced is ``start_value + 1``.

    Args:
        df: DataFrame containing a ``subcategory_key`` column.
        start_value: Integer offset added to each row number.

    Returns:
        ``df`` with ``subcategory_id`` added (or replaced if already present).
    """
    # No partitionBy: the numbering runs across the whole DataFrame.
    key_order = Window.orderBy("subcategory_key")
    surrogate_id = F.row_number().over(key_order) + F.lit(start_value)
    return df.withColumn("subcategory_id", surrogate_id)

def apply_scd_type2_changes(df_src, df_tgt):
    """Merge source rows into the subcategory dimension using SCD Type 2 rules.

    Source rows are matched on ``subcategory_key`` against the *current*
    target rows (``is_current == 1``):

    * Keys with no current target row become new rows.
    * Keys whose ``subcategory_name`` differs from the current target row get
      a new current version; all existing target rows for that key are marked
      ``is_current = 0`` and the row that was current gets ``end_date`` set to
      yesterday.
    * All remaining target rows are carried over untouched.

    A NULL ``subcategory_name`` on either side is not treated as a change
    (the comparison result is coalesced to ``False``).

    New surrogate ids continue directly after the largest existing
    ``subcategory_id``: new records first, then new versions of changed
    records.

    Args:
        df_src: DataFrame with ``subcategory_key`` and ``subcategory_name``.
        df_tgt: Current dimension contents with the columns listed in
            ``final_columns`` below; may be empty on the first run.

    Returns:
        DataFrame with the columns of ``final_columns`` holding the complete
        new contents of the dimension.
    """
    # Define the final column schema for consistency
    final_columns = [
        "subcategory_id",
        "subcategory_key",
        "subcategory_name",
        "is_current",
        "start_date",
        "end_date",
    ]

    # Step 1: Handle empty target (initial load).
    # DataFrame.isEmpty() avoids the RDD API, which is unavailable on
    # Spark Connect / shared-access clusters.
    if df_tgt.isEmpty():
        print("Target DataFrame is empty. Initializing with source records.")
        initial_df = df_src.withColumn("start_date", F.current_date())\
                           .withColumn("end_date", F.lit(None).cast(DateType()))\
                           .withColumn("is_current", F.lit(1).cast(IntegerType()))
        # Offset 0 -> ids 1..N (generate_surrogate_key adds a 1-based row number).
        return generate_surrogate_key(initial_df, 0).select(final_columns)

    # Step 2: Get max surrogate key (None when every id is NULL -> 0)
    max_surrogate_key = df_tgt.agg(F.max("subcategory_id")).collect()[0][0] or 0

    # Zero-row frame with the output layout, used when a category is empty.
    # Derived from df_tgt so it does not depend on the caller's column order
    # or on a global SparkSession name.
    empty_df = df_tgt.select(final_columns).limit(0)

    # Step 3: Join source to the current target rows.
    # A left join is enough: only source-side rows are consumed below, and it
    # keeps target-only rows out of the new/changed filters entirely.
    src = df_src.alias("src")
    tgt = df_tgt.filter(F.col("is_current") == 1).alias("tgt")
    joined_df = src.join(tgt, "subcategory_key", "left")

    # Step 4: Identify new records (no current target row for the key)
    new_records = joined_df.filter(F.col("tgt.subcategory_id").isNull())\
                           .select("src.*")
    new_records_count = new_records.count()

    if new_records_count > 0:
        # generate_surrogate_key adds a 1-based row number to the offset, so
        # passing max_surrogate_key makes the first new id max + 1 (passing
        # max + 1 would skip one id on every run).
        new_records = generate_surrogate_key(new_records, max_surrogate_key)\
                      .withColumn("start_date", F.current_date())\
                      .withColumn("end_date", F.lit(None).cast(DateType()))\
                      .withColumn("is_current", F.lit(1).cast(IntegerType()))\
                      .select(final_columns)
    else:
        new_records = empty_df

    # Step 5: Identify changed records (key matched, name differs).
    # A NULL on either side makes the comparison NULL, which is coalesced to
    # False, i.e. "not changed".
    changed_records = joined_df.filter(
        (F.col("tgt.subcategory_id").isNotNull()) &
        (F.coalesce(F.col("src.subcategory_name") != F.col("tgt.subcategory_name"), F.lit(False)))
    )
    changed_records_count = changed_records.count()
    changed_keys = changed_records.select("subcategory_key")

    if changed_records_count > 0:
        # Step 5.1: Create new versions for changed records
        new_versions = changed_records.select("src.*")\
                                      .withColumn("start_date", F.current_date())\
                                      .withColumn("end_date", F.lit(None).cast(DateType()))\
                                      .withColumn("is_current", F.lit(1).cast(IntegerType()))
        # Continue numbering right after the ids handed to the new records.
        version_offset = max_surrogate_key + new_records_count
        new_versions = generate_surrogate_key(new_versions, version_offset).select(final_columns)

        # Step 5.2: Close old versions. Every target row of a changed key ends
        # up non-current; only the row that was current gets a fresh end_date.
        old_versions = df_tgt.join(changed_keys, "subcategory_key", "inner")\
            .withColumn("end_date", F.when(
                F.col("is_current") == 1, F.date_sub(F.current_date(), 1)
            ).otherwise(F.col("end_date")))\
            .withColumn("is_current", F.lit(0).cast(IntegerType()))\
            .select(final_columns)
    else:
        new_versions = empty_df
        old_versions = empty_df

    # Step 6: Unchanged records (all target rows whose key did not change)
    unchanged_records = df_tgt.join(changed_keys, "subcategory_key", "leftanti")\
                              .select(final_columns)

    # Step 7: Combine results. union() is positional; every input above has
    # been projected to final_columns in the same order.
    final_df = unchanged_records.union(new_versions).union(old_versions).union(new_records)

    return final_df.select(final_columns)

# Run the SCD Type 2 merge of the source rows into the target dimension,
# then render the merged result in the notebook.
result_df = apply_scd_type2_changes(df_src=df_src, df_tgt=df_tgt)
display(result_df)

subcategory_id,subcategory_key,subcategory_name,is_current,start_date,end_date
1,101,MP4&MP3,1,2025-01-16,
2,104,Recording Pen,1,2025-01-16,
3,106,Bluetooth Headphones,1,2025-01-16,
4,201,Televisions,1,2025-01-16,
5,202,VCD & DVD,1,2025-01-16,
6,203,Home Theater System,1,2025-01-16,
7,205,Car Video,1,2025-01-16,
9,303,Desktops,1,2025-01-16,
10,304,Monitors,1,2025-01-16,
11,305,Projectors & Screens,1,2025-01-16,


## Write data to catalog

In [0]:
# Persist the merged dimension: overwrite the Delta table registered as
# contoso_catalog.gold.dim_subcategory, stored at the explicit gold-layer path.
(
    result_df.write
    .format('delta')
    .mode('overwrite')
    .option('path', 'abfss://gold@contosoprojectstorage.dfs.core.windows.net/dim_subcategory')
    .saveAsTable('contoso_catalog.gold.dim_subcategory')
)