# In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lag, to_date, datediff, current_timestamp
from pyspark.sql.window import Window
from delta.tables import DeltaTable

spark = (
    SparkSession.builder
    .appName("SilverFactVisit")
    .getOrCreate()
)

# -------------------------
# Storage locations (Bronze input, Silver dimensions/fact, checkpoint)
# -------------------------
bronze_visit_path = "/mnt/bronze/visit_raw"
silver_patient_path = "/mnt/silver/dim_patient"
silver_hospital_path = "/mnt/silver/dim_hospital"
silver_diag_path = "/mnt/silver/dim_diagnosis"
silver_fact_path = "/mnt/silver/fact_visit"
checkpoint_path = "/mnt/chk/fact_visit"

# -------------------------
# Read Bronze Visit
# NOTE: this is a plain batch read, so the whole Bronze table is loaded.
# -------------------------
df_visit_new = spark.read.format("delta").load(bronze_visit_path)

# -------------------------
# Read Dimension Tables (Silver): natural key -> surrogate key lookups
# -------------------------
df_patient = (
    spark.read.format("delta")
    .load(silver_patient_path)
    .select("patient_id", "patient_sk")
)
df_hospital = (
    spark.read.format("delta")
    .load(silver_hospital_path)
    .select("hospital_id", "hospital_sk")
)
df_diag = (
    spark.read.format("delta")
    .load(silver_diag_path)
    .select("diagnosis_code", "diagnosis_sk")
)

# -------------------------
# Join Visit with Dimensions to get SKs
# Left joins keep visits whose dimension row is missing (their SK is null).
# -------------------------
df_visit_with_sk = df_visit_new.join(df_patient, on="patient_id", how="left")
df_visit_with_sk = df_visit_with_sk.join(df_hospital, on="hospital_id", how="left")
df_visit_with_sk = df_visit_with_sk.join(df_diag, on="diagnosis_code", how="left")

# Normalise the visit dates to DateType and stamp the load time.
df_fact_new = (
    df_visit_with_sk
    .withColumn("admission_date", to_date(col("admission_date")))
    .withColumn("discharge_date", to_date(col("discharge_date")))
    .withColumn("load_timestamp", current_timestamp())
)

# -------------------------
# Combine with existing fact (to calculate prev_discharge)
# -------------------------
# Columns carried from both the existing fact table and the new visits into
# the readmission calculation.
fact_columns = [
    "visit_id", "patient_id", "admission_date", "discharge_date", "cost",
    "hospital_id", "diagnosis_code", "patient_sk", "hospital_sk", "diagnosis_sk",
]

if DeltaTable.isDeltaTable(spark, silver_fact_path):
    df_fact_existing = spark.read.format("delta").load(silver_fact_path)
    df_new_core = df_fact_new.select(*fact_columns)

    # A visit_id present in both the fact table and the new data must appear
    # only once: duplicates would distort the lag() over a patient's visits and
    # make the Delta MERGE fail because several source rows match one target
    # row. The freshly loaded version of the visit wins.
    df_history = df_fact_existing.select(*fact_columns).join(
        df_new_core.select("visit_id"), on="visit_id", how="left_anti"
    )
    df_all = df_history.unionByName(df_new_core)
else:
    # First load: no history to combine with.
    # NOTE(review): duplicates of visit_id inside the new data itself are not
    # removed here -- confirm the Bronze table is unique on visit_id.
    df_all = df_fact_new

# -------------------------
# Compute prev_discharge and readmission info
# -------------------------
# Visits are ordered per patient by admission date; visit_id breaks ties so
# that two admissions on the same day always get the same relative order
# (ordering by admission_date alone leaves lag() non-deterministic for ties).
window_patient = Window.partitionBy("patient_id").orderBy("admission_date", "visit_id")

df_fact_processed = (
    df_all
        # Discharge date of the patient's previous visit; null for the first visit.
        .withColumn("prev_discharge", lag("discharge_date").over(window_patient))
        # Days from the previous discharge to this admission; null when there is
        # no previous visit.
        .withColumn("days_since_last_discharge", datediff(col("admission_date"), col("prev_discharge")))
        # 1 when admitted within 30 days of the previous discharge, 0 otherwise;
        # stays null when days_since_last_discharge is null.
        .withColumn("is_readmission_30d", (col("days_since_last_discharge") <= 30).cast("int"))
        .withColumn("load_timestamp", current_timestamp())
)

# -------------------------
# Merge into Silver fact_visit
# -------------------------
def merge_fact_visit(batch_df, batch_id):
    """Upsert ``batch_df`` into the Silver fact_visit Delta table.

    When no Delta table exists at ``silver_fact_path`` the table is created
    from ``batch_df``. Otherwise the rows are merged on ``visit_id``: matched
    rows have all columns updated and unmatched rows are inserted.

    Args:
        batch_df: DataFrame holding the rows to write.
        batch_id: Batch identifier, as passed by ``foreachBatch``; not used.
    """
    fact_path = silver_fact_path

    if DeltaTable.isDeltaTable(spark, fact_path):
        # Existing table: upsert keyed on visit_id.
        merge_builder = DeltaTable.forPath(spark, fact_path).alias("t").merge(
            batch_df.alias("s"),
            "t.visit_id = s.visit_id",
        )
        merge_builder = merge_builder.whenMatchedUpdateAll()
        merge_builder = merge_builder.whenNotMatchedInsertAll()
        merge_builder.execute()
    else:
        # First run: the batch itself becomes the table.
        batch_df.write.format("delta").save(fact_path)

# -------------------------
# Run the merge once as a batch job
# -------------------------
# df_fact_processed is built from spark.read, i.e. it is a batch DataFrame.
# DataFrame.writeStream is only valid on streaming DataFrames and raises an
# AnalysisException otherwise, so the merge is invoked directly instead of
# through foreachBatch. The second argument is a placeholder batch id, which
# merge_fact_visit does not use. checkpoint_path is not consumed by this
# batch run.
merge_fact_visit(df_fact_processed, 0)
