In [0]:
from pyspark.sql.functions import *
from pyspark.sql.types import *
from delta.tables import DeltaTable
from pyspark.sql import Window

#### Create Watermark Table

In [0]:
# Bootstrap the watermark table the first time the notebook runs.
# It holds one row (id = 1); the timestamp starts out NULL, which the read
# step treats as "nothing processed yet".
if not spark.catalog.tableExists("project.sch.watermark_tbl"):
    # Both columns are declared nullable.
    schema = (
        StructType()
        .add("id", IntegerType(), True)
        .add("last_processed_timestamp", TimestampType(), True)
    )

    # Single seed row: id 1, no timestamp yet.
    initial_data = [(1, None)]
    df_init = spark.createDataFrame(initial_data, schema)

    (
        df_init.write
        .format("delta")
        .saveAsTable("project.sch.watermark_tbl")
    )

#### Read new records from silver layer

In [0]:
# ADLS configuration
# NOTE(review): "<<Access_key>>" is a placeholder. Do not paste a real key
# here; load it from a secret store (e.g. a Databricks secret scope) instead.
spark.conf.set(
  "fs.azure.account.key.2adls.dfs.core.windows.net",
  "<<Access_key>>"
)

# Paths
# abfss:// URIs must point at the account's "dfs" endpoint - the same host the
# account key above is registered for. Without ".dfs" the key does not apply.
silver_path = "abfss://silver@2adls.dfs.core.windows.net/"
dim_patient_path = "abfss://gold@2adls.dfs.core.windows.net/dim_patient/"
dim_department_path = "abfss://gold@2adls.dfs.core.windows.net/dim_department/"
fact_tbl_path = "abfss://gold@2adls.dfs.core.windows.net/fact_tbl/"

# Minimum number of new silver rows required for the run to continue.
MIN_NEW_RECORDS = 10

# Read the watermark (row id = 1). The timestamp is NULL until the first
# successful run has updated it.
watermark_df = spark.table("project.sch.watermark_tbl").filter(col("id") == 1)
last_processed_val = watermark_df.collect()[0]["last_processed_timestamp"]

# Read new data from silver layer.
# With no watermark yet every silver row is new; otherwise keep only rows
# ingested strictly after the watermark.
silver_df = spark.read.format("delta").load(silver_path)
if last_processed_val is not None:
    silver_df = silver_df.filter(col("ingestion_time") > last_processed_val)

# Stop execution if the number of new records is below the threshold
if silver_df.count() < MIN_NEW_RECORDS:
    dbutils.notebook.exit("Stopping execution")

#### Patient Dimension Table

In [0]:
# Patient dimension, maintained as a Type-2 slowly changing dimension:
#   * a changed patient's current row is closed (is_current = false,
#     effective_to = this run's timestamp);
#   * new patients and new versions of changed patients are appended as
#     current rows with fresh surrogate keys.

# Capture one timestamp for the whole run. current_timestamp() is evaluated
# again for every Spark action, so using it directly would give the closed
# row's effective_to and the new row's effective_from different values.
run_ts = spark.range(1).select(current_timestamp().alias("ts")).first()["ts"]

# Hash over the tracked attributes (gender, age); NULLs are mapped to "NA" so
# they take part in the comparison. The expression is unresolved, so it can be
# applied to both the incoming and the target DataFrame.
patient_attr_hash = sha2(
    concat_ws(
        "||",
        coalesce(col("gender"), lit("NA")),
        coalesce(col("age").cast("string"), lit("NA")),
    ),
    256,
)

# Keep exactly one incoming row per patient. Several rows for the same patient
# would put duplicates in the dimension and can make the MERGE below fail when
# more than one source row matches a target row.
# NOTE(review): the row with the greatest ingestion_time is treated as the
# latest version of the patient - confirm this matches the silver data.
latest_first = Window.partitionBy("patient_id").orderBy(col("ingestion_time").desc())

incoming_patient = (
    silver_df
    .withColumn("_rn", row_number().over(latest_first))
    .filter(col("_rn") == 1)
    .select("patient_id", "gender", "age")
    .withColumn("effective_from", lit(run_ts))
)

if not DeltaTable.isDeltaTable(spark, dim_patient_path):
    # Initial run: every incoming patient becomes a current row.
    (
        incoming_patient
        .withColumn("surrogate_key", monotonically_increasing_id())
        .withColumn("effective_to", lit(None).cast("timestamp"))
        .withColumn("is_current", lit(True))
        .write.format("delta").mode("overwrite").save(dim_patient_path)
    )
else:
    # Incremental run.
    incoming_patient = incoming_patient.withColumn("_hash", patient_attr_hash)

    target_patient = (
        spark.read.format("delta").load(dim_patient_path)
        .withColumn("_target_hash", patient_attr_hash)
    )
    current_patient = target_patient.filter(col("is_current") == True)

    # Current rows whose tracked attributes differ from the incoming version.
    changes_df = (
        current_patient.alias("t")
        .join(incoming_patient.alias("i"), on="patient_id", how="inner")
        .filter(col("t._target_hash") != col("i._hash"))
        .select(col("t.surrogate_key"), col("i.effective_from"))
    )

    # Close the changed rows.
    target_patient_tbl = DeltaTable.forPath(spark, dim_patient_path)
    target_patient_tbl.alias("target").merge(
        source=changes_df.alias("updates"),
        condition="target.surrogate_key = updates.surrogate_key",
    ).whenMatchedUpdate(set={
        "is_current": "false",
        "effective_to": "updates.effective_from",
    }).execute()

    # Rows to insert: incoming patients with no *current* row carrying the same
    # attribute hash. This covers brand-new patients and changed patients, and
    # it gives the same answer whether the target is read before or after the
    # MERGE above has closed the old rows.
    inserts_df = (
        incoming_patient.alias("i")
        .join(
            current_patient.alias("t"),
            on=(col("i.patient_id") == col("t.patient_id"))
               & (col("i._hash") == col("t._target_hash")),
            how="left_anti",
        )
        .select(col("i.patient_id"), col("i.gender"), col("i.age"), col("i.effective_from"))
    )

    # Continue surrogate keys after the current maximum (start at 0 when the
    # table has no rows, where max() returns None).
    max_surrogate_key = target_patient.select(max("surrogate_key")).collect()[0][0]
    next_surrogate_key = 0 if max_surrogate_key is None else max_surrogate_key + 1

    inserts_df = (
        inserts_df
        .withColumn("surrogate_key", monotonically_increasing_id() + lit(next_surrogate_key))
        .withColumn("effective_to", lit(None).cast("timestamp"))
        .withColumn("is_current", lit(True))
        .select("surrogate_key", "patient_id", "gender", "age",
                "effective_from", "effective_to", "is_current")
    )

    if inserts_df.count() > 0:
        inserts_df.write.format("delta").mode("append").save(dim_patient_path)
    



#### Department Dimension Table

In [0]:
# Department dimension: one row per distinct (department, hospital_id) pair.
incoming_dept = (
    silver_df
    .select("department", "hospital_id")
    .dropDuplicates(["department", "hospital_id"])  # Remove duplicates
)

if not DeltaTable.isDeltaTable(spark, dim_department_path):
    # Initial run: every incoming pair is new.
    (
        incoming_dept
        .withColumn("surrogate_key", monotonically_increasing_id())
        .select("surrogate_key", "department", "hospital_id")
        .write.format("delta").mode("overwrite").save(dim_department_path)
    )
else:
    # Incremental run: append only the pairs that are not in the table yet.
    target_dept = spark.read.format("delta").load(dim_department_path)

    new_records = incoming_dept.join(
        target_dept,
        on=["department", "hospital_id"],
        how="left_anti"
    )

    # Continue surrogate keys after the current maximum (start at 0 when the
    # table has no rows, where max() returns None).
    max_sk = target_dept.select(max("surrogate_key")).collect()[0][0]
    next_sk = 0 if max_sk is None else max_sk + 1

    new_records = new_records.withColumn(
        "surrogate_key", monotonically_increasing_id() + lit(next_sk)
    )

    if new_records.count() > 0:
        (
            new_records
            .select("surrogate_key", "department", "hospital_id")
            .write.format("delta").mode("append").save(dim_department_path)
        )

#### Fact Table

In [0]:
# Build the fact rows for this batch and append them to the fact table.

# Current patient rows only, so each patient_id resolves to its active version.
dim_patient_df = (
    spark.read.format("delta").load(dim_patient_path)
    .filter(col("is_current") == True)
    .select(col("surrogate_key").alias("surrogate_key_patient"), "patient_id", "gender", "age")
)

dim_dept_df = (
    spark.read.format("delta").load(dim_department_path)
    .select(col("surrogate_key").alias("surrogate_key_dept"), "department", "hospital_id")
)

# Length of stay in hours, from the difference of the two epoch-second values.
stay_hours = (
    unix_timestamp(col("discharge_time")) - unix_timestamp(col("admission_time"))
) / 3600.0

# NOTE(review): a NULL discharge_time falls through to otherwise() and yields
# False - confirm that is the intended meaning for patients with no discharge
# recorded.
currently_admitted = (
    when(col("discharge_time") > current_timestamp(), lit(True)).otherwise(lit(False))
)

fact_df = (
    silver_df
    .select("patient_id", "department", "hospital_id", "admission_time", "discharge_time", "bed_id")
    .withColumn("admission_date", to_date("admission_time"))
    # Left joins keep events whose dimension row is missing (keys stay NULL).
    .join(dim_patient_df, on="patient_id", how="left")
    .join(dim_dept_df, on=["department", "hospital_id"], how="left")
    .withColumn("length_of_stay_hours", stay_hours)
    .withColumn("is_currently_admitted", currently_admitted)
    .withColumn("event_ingestion_time", current_timestamp())
)

# fact_id continues after the current maximum. It starts at 0 when the table
# does not exist yet, or exists but has no rows (max() returns None).
next_fact_id = 0
if DeltaTable.isDeltaTable(spark, fact_tbl_path):
    max_fact_id = spark.read.format("delta").load(fact_tbl_path).select(max("fact_id")).collect()[0][0]
    if max_fact_id is not None:
        next_fact_id = max_fact_id + 1

fact_df = fact_df.withColumn("fact_id", monotonically_increasing_id() + lit(next_fact_id))

# Select columns
fact_df = fact_df.select(
    "fact_id",
    col("surrogate_key_patient").alias("patient_sk"),
    col("surrogate_key_dept").alias("department_sk"),
    "admission_time",
    "discharge_time",
    "admission_date",
    "length_of_stay_hours",
    "is_currently_admitted",
    "bed_id",
    "event_ingestion_time"
)

# Write to fact table
fact_df.write.format("delta").mode("append").save(fact_tbl_path)

#### Update Watermark table

In [0]:
# Update Watermark Table
# Advance the watermark (row id = 1) to the greatest ingestion_time in the
# batch that was just processed.
new_max_timestamp = silver_df.select(max("ingestion_time")).collect()[0][0]

# max() returns None when no row has an ingestion_time; leave the watermark
# untouched in that case instead of writing a bogus value.
if new_max_timestamp is not None:
    # Pass the timestamp as a typed literal rather than formatting it into SQL
    # text, so no string quoting or re-parsing is involved.
    DeltaTable.forName(spark, "project.sch.watermark_tbl").update(
        condition="id = 1",
        set={"last_processed_timestamp": lit(new_max_timestamp)},
    )