In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when, avg, sum as spark_sum, countDistinct, round

# getOrCreate() reuses an already-running session (e.g. in a notebook) and
# only builds a new one when none exists.
spark = SparkSession.builder.appName("GoldPatientReadmission").getOrCreate()

# -------------------------
# Paths
# -------------------------
silver_fact_path = "/mnt/silver/fact_visit"
silver_patient_path = "/mnt/silver/dim_patient"
silver_hospital_path = "/mnt/silver/dim_hospital"
silver_diag_path = "/mnt/silver/dim_diagnosis"
gold_path = "/mnt/gold/gold_patient_readmission"
checkpoint_path = "/mnt/chk/gold_patient_readmission"

# -------------------------
# Read Silver Tables
# -------------------------
# The fact table drives the incremental load, so it must be a *streaming*
# read: the Gold DataFrame derived from it is written with `.writeStream`,
# which is only valid on a streaming DataFrame (a batch DataFrame raises
# AnalysisException there).
# NOTE(review): a Delta streaming source expects append-only commits by
# default - confirm the Silver fact table is never updated/deleted in place.
df_fact = spark.readStream.format("delta").load(silver_fact_path)

# Dimensions remain static (batch) reads; joining them to the fact stream
# with the stream on the left side is a supported stream-static left join.
df_patient = spark.read.format("delta").load(silver_patient_path)
df_hospital = spark.read.format("delta").load(silver_hospital_path)
df_diag = spark.read.format("delta").load(silver_diag_path)

# -------------------------
# Join Fact with Dimensions
# -------------------------
# Trim each dimension to its surrogate key plus the attributes used downstream.
patient_attrs = df_patient.select("patient_sk", "patient_id_masked", "age_group", "gender")
hospital_attrs = df_hospital.select("hospital_sk", "hospital_name", "city", "state")
diagnosis_attrs = df_diag.select("diagnosis_sk", "diagnosis_name", "risk_category")

# Left joins keep every fact row; attributes of an unmatched dimension are NULL.
df_gold_base = df_fact.join(patient_attrs, on="patient_sk", how="left")
df_gold_base = df_gold_base.join(hospital_attrs, on="hospital_sk", how="left")
df_gold_base = df_gold_base.join(diagnosis_attrs, on="diagnosis_sk", how="left")

# -------------------------
# Derive KPIs
# -------------------------
df_gold = (
    df_gold_base
        # "Yes" only when the 30-day readmission flag equals 1; any other
        # value, including NULL, falls through to "No".
        .withColumn("was_readmitted", when(col("is_readmission_30d") == 1, "Yes").otherwise("No"))
        # Simple additive risk score (0-8):
        #   +2 when age_group is "65+"
        #   +3 when the previous discharge was at most 30 days ago
        #   +3 when the diagnosis risk_category is "High"
        # A NULL input makes its condition non-true, so it contributes 0.
        .withColumn("risk_score",
            (when(col("age_group") == "65+", 2).otherwise(0) +
             when(col("days_since_last_discharge") <= 30, 3).otherwise(0) +
             when(col("risk_category") == "High", 3).otherwise(0))
        )
        # load_timestamp is carried through unchanged from the upstream row;
        # re-assigning the column to itself would be a no-op, so it is only
        # listed in the projection below.
        .select(
            "patient_id_masked",
            "age_group",
            "gender",
            "hospital_name",
            "city",
            "state",
            "diagnosis_name",
            "admission_date",
            "discharge_date",
            "prev_discharge",
            "days_since_last_discharge",
            "was_readmitted",
            "risk_score",
            "cost",
            "load_timestamp"
        )
)

# -------------------------
# Merge into Gold Table Incrementally
# -------------------------
from delta.tables import DeltaTable

def merge_gold_patient(batch_df, batch_id):
    """Upsert one micro-batch into the Gold Delta table at ``gold_path``.

    Written to be used as a ``foreachBatch`` handler. If no Delta table
    exists at ``gold_path`` yet, the batch is written out to create it;
    otherwise the batch is merged on
    ``(patient_id_masked, admission_date)``: matching rows are fully
    updated and unmatched rows are inserted.

    Args:
        batch_df: DataFrame holding the rows of the current micro-batch.
        batch_id: Micro-batch identifier passed by ``foreachBatch``; unused.

    Returns:
        None.
    """
    # Use the session that owns this micro-batch instead of the module-level
    # `spark`: foreachBatch handlers may run in a context where the global
    # session is not the one backing `batch_df`.
    session = batch_df.sparkSession
    target = gold_path

    # First run: nothing to merge into yet, so the batch becomes the table.
    if not DeltaTable.isDeltaTable(session, target):
        batch_df.write.format("delta").mode("overwrite").save(target)
        return

    gold = DeltaTable.forPath(session, target)

    # NOTE(review): Delta MERGE fails when several source rows match the same
    # target row - this assumes each batch holds at most one row per
    # (patient_id_masked, admission_date); confirm against the Silver layer.
    (
        gold.alias("t")
        .merge(
            batch_df.alias("s"),
            "t.patient_id_masked = s.patient_id_masked AND t.admission_date = s.admission_date"
        )
        .whenMatchedUpdateAll()
        .whenNotMatchedInsertAll()
        .execute()
    )

# -------------------------
# Streaming Write (availableNow)
# -------------------------
# `writeStream` requires df_gold to be a streaming DataFrame.
# availableNow processes everything currently pending and then stops the
# query on its own; the checkpoint records progress so a rerun only picks up
# new data.
(
    df_gold.writeStream
        .foreachBatch(merge_gold_patient)
        .outputMode("update")
        .trigger(availableNow=True)
        .option("checkpointLocation", checkpoint_path)
        .start()
        # Block until the run finishes so the script/job does not exit while
        # batches are still being merged, and so a failure inside the query
        # is raised here instead of being silently dropped.
        .awaitTermination()
)


## Optimized gold layer

In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when
from delta.tables import DeltaTable
from pyspark.sql.streaming import DataStreamWriter

# getOrCreate() reuses an already-running session and only builds a new one
# when none exists.
spark = SparkSession.builder.appName("GoldSimple").getOrCreate()

silver_fact_path = "/mnt/silver/fact_visit"
# NOTE(review): these are the same Gold table and checkpoint locations used by
# the GoldPatientReadmission pipeline earlier in this file - confirm that
# sharing a checkpoint between the two query definitions is intended.
gold_path = "/mnt/gold/gold_patient_readmission"
checkpoint = "/mnt/chk/gold_patient_readmission"

# Read the Silver fact table as a stream: the Gold DataFrame derived from it
# is written with `.writeStream`, which is only valid on a streaming
# DataFrame (a batch DataFrame raises AnalysisException there).
# NOTE(review): a Delta streaming source expects append-only commits by
# default - confirm the Silver fact table is never updated/deleted in place.
df_fact = spark.readStream.format("delta").load(silver_fact_path)

# Build the Gold projection with the derived KPI columns.
# NOTE(review): this reads patient/hospital/diagnosis attributes (age_group,
# hospital_name, risk_category, ...) directly from the fact table, whereas the
# GoldPatientReadmission pipeline earlier in this file obtains them by joining
# the dimension tables - confirm the fact table really carries these columns.
df_gold = (
    df_fact
    # "Yes" only when the 30-day readmission flag equals 1; any other value,
    # including NULL, falls through to "No".
    .withColumn("was_readmitted", when(col("is_readmission_30d") == 1, "Yes").otherwise("No"))
    # Additive risk score (0-8): +2 for age_group "65+", +3 when the previous
    # discharge was at most 30 days ago, +3 for a "High" risk_category.
    # A NULL input makes its condition non-true, so it contributes 0.
    .withColumn("risk_score",
        (when(col("age_group") == "65+", 2).otherwise(0) +
         when(col("days_since_last_discharge") <= 30, 3).otherwise(0) +
         when(col("risk_category") == "High", 3).otherwise(0))
    )
    # Column set kept identical to the one the GoldPatientReadmission pipeline
    # writes to the same gold_path (including "state" and "load_timestamp"):
    # the MERGE uses update-all/insert-all, which needs every target column to
    # be present in the source.
    .select(
        "patient_id_masked",
        "age_group",
        "gender",
        "hospital_name",
        "city",
        "state",
        "diagnosis_name",
        "admission_date",
        "discharge_date",
        "prev_discharge",
        "days_since_last_discharge",
        "was_readmitted",
        "risk_score",
        "cost",
        "load_timestamp"
    )
)

# Incremental upsert of each micro-batch into the Gold table.
def merge_gold(batch_df, batch_id):
    """Upsert one micro-batch into the Gold Delta table at ``gold_path``.

    Written to be used as a ``foreachBatch`` handler. If no Delta table
    exists at ``gold_path`` yet, the batch is written out to create it;
    otherwise the batch is merged on
    ``(patient_id_masked, admission_date)``: matching rows are fully
    updated and unmatched rows are inserted.

    Args:
        batch_df: DataFrame holding the rows of the current micro-batch.
        batch_id: Micro-batch identifier passed by ``foreachBatch``; unused.

    Returns:
        None.
    """
    # Use the session that owns this micro-batch instead of the module-level
    # `spark`: foreachBatch handlers may run in a context where the global
    # session is not the one backing `batch_df`.
    session = batch_df.sparkSession

    if not DeltaTable.isDeltaTable(session, gold_path):
        # First run: nothing to merge into yet, so the batch becomes the table.
        batch_df.write.format("delta").mode("overwrite").save(gold_path)
    else:
        gold = DeltaTable.forPath(session, gold_path)
        # NOTE(review): Delta MERGE fails when several source rows match the
        # same target row - this assumes each batch holds at most one row per
        # (patient_id_masked, admission_date); confirm against the Silver layer.
        (
            gold.alias("t")
            .merge(
                batch_df.alias("s"),
                "t.patient_id_masked = s.patient_id_masked AND t.admission_date = s.admission_date"
            )
            .whenMatchedUpdateAll()
            .whenNotMatchedInsertAll()
            .execute()
        )

# Streaming write with the availableNow trigger: process everything currently
# pending (possibly split over several micro-batches), then stop. The
# checkpoint records progress so a rerun only picks up new data.
# `writeStream` requires df_gold to be a streaming DataFrame.
(
    df_gold.writeStream
    .foreachBatch(merge_gold)
    .outputMode("update")
    .trigger(availableNow=True)
    .option("checkpointLocation", checkpoint)
    .start()
    # Block until the run finishes so the script/job does not exit while
    # batches are still being merged, and so a failure inside the query is
    # raised here instead of being silently dropped.
    .awaitTermination()
)
