In [0]:
from pyspark.sql.functions import *

In [0]:
# --- Load parameters --------------------------------------------------------
# Source and target objects, both addressed inside the `workspace` catalog.
source_schema, source_object = 'silver', 'silver_bookings'
target_schema, target_object = 'gold', 'fact_bookings'

# Fully qualified table names derived from the parts above.
source_fact_table = f"workspace.{source_schema}.{source_object}"
target_fact_table = f"workspace.{target_schema}.{target_object}"

# Column compared against the last-load value to pick up new/changed rows.
cdc_column = 'modified_date'

# Leave empty for a normal incremental run; a non-empty value is used
# verbatim as the last-load value instead of the one read from the target.
backdated_refresh = ""

# Surrogate-key columns that form the MERGE match condition.
fact_key_cols = ['dim_flights_key', 'dim_passengers_key', 'dim_airports_key']



In [0]:
# Dimension tables joined to the fact source. Each entry provides the fully
# qualified table name, the SQL alias, and the join keys as
# (fact_col, dim_col) pairs. The alias also determines the surrogate-key
# column that is selected: `<alias>.<alias>_key`.
dimensions = [
    {
        "table": f"workspace.{target_schema}.dim_passengers",
        "alias": "dim_passengers",
        "join_keys": [("passenger_id", "passenger_id")],  # (fact_col, dim_col)
    },
    {
        "table": f"workspace.{target_schema}.dim_flights",
        "alias": "dim_flights",
        "join_keys": [("flight_id", "flight_id")],
    },
    {
        "table": f"workspace.{target_schema}.dim_airports",
        "alias": "dim_airports",
        "join_keys": [("airport_id", "airport_id")],
    },
]

# Columns to keep in Fact table (besides the surrogate keys).
# The CDC column has to be carried into the fact table: the last-load lookup
# runs max(cdc_column) against the target, and the MERGE update condition
# compares src.<cdc_column> with trg.<cdc_column>. Without it both statements
# fail with an unresolved-column error once the target table exists.
fact_columns = ["amount", "booking_date", cdc_column]

In [0]:
# Resolve the inclusive lower bound used to filter the source on the CDC column.
initial_load_from = "1900-01-01 00:00:00"

if backdated_refresh:
    # An explicit back-dated refresh overrides whatever the target holds.
    last_load = backdated_refresh
elif spark.catalog.tableExists(target_fact_table):
    last_load = spark.sql(
        f"SELECT max({cdc_column}) as last_load FROM {target_fact_table}"
    ).first()[0]
    # max() over an existing but empty table is NULL (None in Python). Without
    # this fallback the incremental filter would be rendered as >= 'None'.
    if last_load is None:
        last_load = initial_load_from
else:
    # First run: the target does not exist yet, so load everything.
    last_load = initial_load_from

print(f"Last load: {last_load}")


In [0]:
# dynamic fact query


def generate_fact_query_incremental(fact_table, dimensions, fact_columns, last_load, cdc_col=None):
    """Build the SQL that selects new/changed fact rows joined to their dimensions.

    Args:
        fact_table: Name of the source table; it is aliased as ``f``.
        dimensions: List of dicts with the keys ``table``, ``alias`` and
            ``join_keys`` (a list of ``(fact_col, dim_col)`` pairs). Every
            dimension is LEFT JOINed and contributes ``<alias>.<alias>_key``
            to the projection.
        fact_columns: Source columns to project, in the given order.
        last_load: Inclusive lower bound for the CDC column. ``None`` means
            "no watermark" and produces a query without a WHERE clause.
        cdc_col: Name of the CDC column to filter on. When omitted, the
            notebook-level ``cdc_column`` setting is used.

    Returns:
        The SQL query as a single string.
    """
    fact_alias = "f"

    if cdc_col is None:
        # Backward-compatible default: fall back to the notebook-level setting.
        cdc_col = cdc_column

    surrogate_keys = []
    joins = []
    for dim in dimensions:
        alias = dim["alias"]
        surrogate_keys.append(f"{alias}.{alias}_key")
        on_condition = " AND ".join(
            f"{fact_alias}.{fk} = {alias}.{dk}" for fk, dk in dim["join_keys"]
        )
        joins.append(f"LEFT JOIN {dim['table']} {alias} ON {on_condition}")

    projection = [f"{fact_alias}.{col_name}" for col_name in fact_columns] + surrogate_keys
    select_clause = ",\n    ".join(projection)

    lines = ["SELECT", f"    {select_clause}", f"FROM {fact_table} {fact_alias}"]
    lines.extend(joins)
    if last_load is not None:
        # Inclusive bound: rows stamped exactly at the watermark are re-read.
        lines.append(f"WHERE {fact_alias}.{cdc_col} >= '{last_load}'")

    return "\n".join(lines)


In [0]:
# Render the incremental SELECT from the notebook configuration and print it
# so the generated SQL can be reviewed before it is executed.
query = generate_fact_query_incremental(
    fact_table=source_fact_table,
    dimensions=dimensions,
    fact_columns=fact_columns,
    last_load=last_load,
)
print(query)

In [0]:
# Run the generated incremental query; the resulting DataFrame is the batch
# that the MERGE cell later upserts into the target fact table.
df_fact = spark.sql(query)
# Render the batch for inspection before it is written.
df_fact.display()

In [0]:
# Upsert: build the MERGE match condition, one equality per surrogate-key
# column, comparing the target alias `trg` with the source alias `src`.
fact_key_cols_str = " AND ".join(
    f"trg.{key_col} = src.{key_col} " for key_col in fact_key_cols
)
fact_key_cols_str

In [0]:
from delta.tables import DeltaTable

# Upsert the incremental batch into the target fact table.
if spark.catalog.tableExists(target_fact_table):
    # DeltaTable.forName takes the SparkSession as its first argument; passing
    # only the table name raises a TypeError.
    (
        DeltaTable.forName(spark, target_fact_table)
        .alias("trg")
        .merge(df_fact.alias("src"), fact_key_cols_str)
        # Overwrite a matched row only when the incoming version is at least
        # as recent as the stored one.
        .whenMatchedUpdateAll(condition=f"src.{cdc_column} >= trg.{cdc_column}")
        .whenNotMatchedInsertAll()
        .execute()
    )
else:
    # First run: the target does not exist yet, so create it from the batch.
    df_fact.write.format("delta").mode("append").saveAsTable(target_fact_table)


In [0]:
# Preview the loaded fact table. The name comes from `target_fact_table`
# rather than a hard-coded literal, so the preview follows the configured
# target schema and object.
spark.table(target_fact_table).display()