In [0]:
from pyspark.sql.functions import *

In [0]:
# ---- Incremental-load settings ----
# Column compared against the last load value to pick up changed rows
cdc_column = "modifiedDate"

# Set to a date/timestamp string to force a back-dated refresh; leave empty otherwise
backdated_refresh = ""

# ---- Source (fact input) ----
catalog = "flightsproject"
source_schema = "silver"
source_object = "silver_bookings"

# Fully qualified source fact table: <catalog>.<schema>.<object>
fact_table = ".".join([catalog, source_schema, source_object])

# ---- Target (fact output) ----
target_schema = "gold"
target_object = "FactBookings"

# Columns that together identify one fact row (used as the merge condition)
fact_key_cols = ["DimPassengersKey", "DimFlightsKey", "DimAirportsKey", "booking_date"]

In [0]:
# One entry per dimension: (dimension table name, [(fact_col, dim_col), ...])
_dimension_specs = [
    ("DimPassengers", [("passenger_id", "passenger_id")]),
    ("DimFlights", [("flight_id", "flight_id")]),
    ("DimAirports", [("airport_id", "airport_id")]),
]

# Expand each spec into the dict shape consumed by the query builder;
# the dimension name doubles as the SQL alias.
dimensions = [
    {
        "table": f"{catalog}.{target_schema}.{dim_name}",
        "alias": dim_name,
        "join_keys": key_pairs,
    }
    for dim_name, key_pairs in _dimension_specs
]

# Fact-table columns carried into the output alongside the surrogate keys
fact_columns = ["amount", "booking_date", "modifiedDate"]

### Last load date

In [0]:
# Work out the lower bound (last_load) for the incremental extract.
target_table_name = f"{catalog}.{target_schema}.{target_object}"

# Early enough that an initial load picks up everything
initial_load = "1900-01-01 00:00:00"

if backdated_refresh:
    # yes back dated refresh: the caller-supplied value wins
    last_load = backdated_refresh
elif spark.catalog.tableExists(target_table_name):
    # no back dated refresh and the table exists in destination:
    # resume from the newest CDC value already loaded
    last_load = spark.sql(f"select max({cdc_column}) from {target_table_name}").collect()[0][0]
    if last_load is None:
        # max() over an empty table is NULL; treat that like an initial load
        last_load = initial_load
else:
    # initial load, so we can load everything
    last_load = initial_load

# test the last load
last_load

### Dynamic fact query (bring in surrogate keys)

In [0]:
def generate_fact_query_incremental(fact_table, dimensions, fact_columns, cdc_column, processing_date):
    """Build the incremental fact query that left-joins every dimension.

    Args:
        fact_table: Name of the fact table placed in the FROM clause (aliased as ``f``).
        dimensions: List of dicts, each with ``"table"`` (dimension table name),
            ``"alias"`` (SQL alias for it) and ``"join_keys"`` (list of
            ``(fact_col, dim_col)`` pairs that are AND-ed in the ON clause).
        fact_columns: Fact-table columns to select.
        cdc_column: Fact-table column used for incremental filtering.
        processing_date: Lower bound for ``cdc_column``; it is interpolated into
            ``date('...')`` and rows with ``cdc_column >=`` that date are kept.

    Returns:
        The SQL query text. For each dimension the selected surrogate key is
        ``<alias>.<last segment of the table name>Key``.
    """
    fact_alias = "f"

    # base columns to select
    select_cols = [f"{fact_alias}.{column}" for column in fact_columns]

    # build joins dynamically
    join_clauses = []

    for dim in dimensions:
        table_full = dim["table"]
        alias = dim["alias"]
        # the surrogate key column is named after the unqualified table name
        table_name = table_full.split(".")[-1]
        select_cols.append(f"{alias}.{table_name}Key")

        # build on clause
        on_conditions = [
            f"{fact_alias}.{fk} = {alias}.{dk}" for fk, dk in dim["join_keys"]
        ]
        join_clauses.append(f"left join {table_full} {alias} on " + " and ".join(on_conditions))

    # final select and join clauses
    select_clause = ",\n".join(select_cols)
    joins = "\n".join(join_clauses)

    # where clause for incremental filtering; uses the processing_date argument
    # rather than whatever global happens to be defined in the notebook
    where_clause = f"{fact_alias}.{cdc_column} >= date('{processing_date}')"

    # final query
    query = f"""
        select
            {select_clause}
        from
            {fact_table} {fact_alias}
            {joins}
        where
            {where_clause}
    """.strip()
    return query

In [0]:
# Render the incremental query for the configured fact table and print it for inspection
query = generate_fact_query_incremental(
    fact_table=fact_table,
    dimensions=dimensions,
    fact_columns=fact_columns,
    cdc_column=cdc_column,
    processing_date=last_load,
)
print(query)

### df_fact

In [0]:
# Run the generated query; df_fact is the source side of the upsert below
df_fact = spark.sql(query)
# Preview the result
display(df_fact)

### Upsert

In [0]:
# Merge condition: equality on every fact key column between source (src) and target (tgt)
# NOTE(review): the surrogate keys come from left joins, so they may be NULL when a
# dimension row is missing; `=` never matches NULLs — confirm whether that is intended.
fact_key_cols_str = " and ".join(map("src.{0} = tgt.{0}".format, fact_key_cols))
fact_key_cols_str

In [0]:
from delta.tables import DeltaTable

# Fully qualified name of the gold fact table
target_table_name = f"{catalog}.{target_schema}.{target_object}"

if not spark.catalog.tableExists(target_table_name):
    # Target does not exist yet: create it from the current batch
    (
        df_fact.write
        .format("delta")
        .mode("append")
        .saveAsTable(target_table_name)
    )
else:
    # Target exists: upsert the batch on the fact key columns
    dlt_obj = DeltaTable.forName(spark, target_table_name)

    (
        dlt_obj.alias("tgt")
        .merge(df_fact.alias("src"), fact_key_cols_str)
        # only overwrite a matched row when the incoming row is at least as recent
        .whenMatchedUpdateAll(condition=f"src.{cdc_column} >= tgt.{cdc_column}")
        .whenNotMatchedInsertAll()
        .execute()
    )

In [0]:
%sql
-- Inspect the fact table after the upsert
SELECT * FROM flightsproject.gold.factbookings

In [0]:
# checking duplicates: list surrogate key values that occur more than once in each dimension
dimension_key_checks = [
    ("flightsproject.gold.dimairports", "DimAirportsKey"),
    ("flightsproject.gold.dimflights", "DimFlightsKey"),
    ("flightsproject.gold.dimpassengers", "DimPassengersKey"),
]

for dim_table, key_col in dimension_key_checks:
    # display() only renders output, so its result is not assigned to anything
    (
        spark.sql(f"select * from {dim_table}")
        .groupBy(key_col)
        .count()
        .filter(col("count") > 1)
        .display()
    )