In [0]:

from delta.tables import DeltaTable
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    avg,
    col,
    count,
    countDistinct,
    current_date,
    date_sub,
    desc,
    month,
    sum as _sum,
    year,
)

# Spark session
#
# Obtain (or reuse) the session registered under the "MediPulseCapstone"
# application name.
_session_builder = SparkSession.builder.appName("MediPulseCapstone")
spark = _session_builder.getOrCreate()

# Use the LEGACY datetime parser policy for the rest of this session.
spark.conf.set("spark.sql.legacy.timeParserPolicy", "LEGACY")


# Sample source files (DBFS)
#
# Start from an empty landing directory, then write one small file per
# source: patients (CSV), hospitals (one JSON object per line) and two
# appointment batches (CSV).

base_path = "/tmp/medi_pulse/"

# Landing locations of the four raw files.
raw_patients = base_path + "patients.csv"
raw_hospitals = base_path + "hospitals.json"
raw_appointments = base_path + "appointments_day1.csv"
raw_appointments_day2 = base_path + "appointments_day2.csv"

# Patient dimension.
patients_csv = """patient_id,name,age,gender,region
P001,Arjun Mehta,34,M,North
P002,Neha Sharma,29,F,South
P003,Rahul Gupta,40,M,East
P004,Sneha Nair,25,F,West
"""

# Hospital dimension: newline-separated JSON records, no trailing newline.
_hospital_records = (
    '{"hospital_id":"H001","hospital_name":"City Care","region":"North"}',
    '{"hospital_id":"H002","hospital_name":"LifePlus","region":"South"}',
    '{"hospital_id":"H003","hospital_name":"MediHope","region":"East"}',
    '{"hospital_id":"H004","hospital_name":"CureWell","region":"West"}',
)
hospitals_json = '\n'.join(_hospital_records)

# First appointment batch; A1003 is still "Pending" here.
appointments_day1 = """appointment_id,patient_id,hospital_id,appointment_date,diagnosis,cost,status
A1001,P001,H001,2024-01-10,Diabetes,400,Completed
A1002,P002,H002,2024-01-11,Flu,250,Completed
A1003,P003,H003,2024-01-11,Heart Disease,1000,Pending
A1004,P004,H004,2024-01-12,Allergy,300,Completed
"""

# Second appointment batch: two new appointments plus A1003 re-sent as
# "Completed".
appointments_day2 = """appointment_id,patient_id,hospital_id,appointment_date,diagnosis,cost,status
A1005,P001,H001,2024-02-05,Diabetes,450,Completed
A1006,P003,H003,2024-02-06,Cardiology,1500,Completed
A1003,P003,H003,2024-01-11,Heart Disease,1000,Completed
"""

# Recreate the landing directory and write the files in a fixed order:
# patients, hospitals, day 1, day 2.
dbutils.fs.rm(base_path, True)
dbutils.fs.mkdirs(base_path)
for _target, _payload in (
    (raw_patients, patients_csv),
    (raw_hospitals, hospitals_json),
    (raw_appointments, appointments_day1),
    (raw_appointments_day2, appointments_day2),
):
    dbutils.fs.put(_target, _payload, overwrite=True)


# Bronze Layer: Raw Ingestion
#
# Read each raw file as-is (CSV files have a header row and their column
# types are inferred) and land it unchanged in its own Delta location,
# replacing whatever was stored there before.
_csv_reader_options = {"header": True, "inferSchema": True}

bronze_patients = spark.read.csv(raw_patients, **_csv_reader_options)
bronze_hospitals = spark.read.json(raw_hospitals)
bronze_appointments = spark.read.csv(raw_appointments, **_csv_reader_options)

for _frame, _delta_path in (
    (bronze_patients, "/tmp/bronze_patients"),
    (bronze_hospitals, "/tmp/bronze_hospitals"),
    (bronze_appointments, "/tmp/bronze_appointments"),
):
    _frame.write.format("delta").mode("overwrite").save(_delta_path)


# Silver Layer: Data Cleansing & Transformation
#
# Keep only completed appointments, attach patient and hospital attributes,
# derive year/month from the appointment date, and persist the result.

# Both dimensions carry a "region" column; give each one its own name so the
# joined frame has no duplicate column.
patients_df = bronze_patients.withColumnRenamed("region", "patient_region")
hospitals_df = bronze_hospitals.withColumnRenamed("region", "hospital_region")

# Aliases let the final projection address dimension columns as p.* / h.*.
patients_alias = patients_df.alias("p")
hospitals_alias = hospitals_df.alias("h")

_completed_appointments = bronze_appointments.filter(col("status") == "Completed")

# Inner joins: appointments without a matching patient or hospital are dropped.
_with_dimensions = (
    _completed_appointments
    .join(patients_alias, "patient_id")
    .join(hospitals_alias, "hospital_id")
)

_with_calendar = (
    _with_dimensions
    .withColumn("year", year(col("appointment_date")))
    .withColumn("month", month(col("appointment_date")))
)

_appointment_columns = (
    "appointment_id",
    "patient_id",
    "hospital_id",
    "appointment_date",
    "diagnosis",
    "cost",
    "status",
)

silver_appointments = _with_calendar.select(
    *[col(_name) for _name in _appointment_columns],
    col("p.patient_region").alias("patient_region"),
    col("h.hospital_region").alias("hospital_region"),
    col("h.hospital_name").alias("hospital_name"),
    col("year"),
    col("month"),
)

silver_appointments.write.format("delta").mode("overwrite").save("/tmp/silver_appointments")

# Gold Layer: Analytical Aggregations
#
# All aggregates below are computed from the completed-appointment silver
# frame built above.

# Total revenue per hospital.
gold_revenue = silver_appointments.groupBy("hospital_name").agg(
    _sum("cost").alias("total_revenue")
)
gold_revenue.show()

# Patients per hospital region. Distinct patient ids are counted so that a
# patient with several completed appointments in a region is counted once;
# a plain count("patient_id") would count appointments instead.
gold_patients_region = silver_appointments.groupBy("hospital_region").agg(
    countDistinct("patient_id").alias("total_patients")
)
gold_patients_region.show()

# Top 3 diagnoses by total cost. The diagnosis name is a secondary sort key
# so that ties on total_cost are resolved the same way on every run.
gold_top_diagnosis = (
    silver_appointments.groupBy("diagnosis")
    .agg(_sum("cost").alias("total_cost"))
    .orderBy(desc("total_cost"), "diagnosis")
    .limit(3)
)
gold_top_diagnosis.show()

# Cost and appointment count per hospital / region / diagnosis, persisted as
# the gold summary table.
gold_summary = silver_appointments.groupBy(
    "hospital_name", "hospital_region", "diagnosis"
).agg(
    _sum("cost").alias("total_cost"),
    count("appointment_id").alias("num_appointments"),
)
gold_summary.write.format("delta").mode("overwrite").save("/tmp/gold_healthcare_summary")


# Incremental Load Simulation
#
# Upsert the day-2 appointment batch into the silver Delta table, then
# rebuild the gold summary from the refreshed table.


def _enrich_appointments(appointments):
    """Shape raw appointment rows like the silver table.

    Joins the patient ("p") and hospital ("h") dimension frames, derives
    ``year``/``month`` from ``appointment_date`` and projects the silver
    column list in the silver column order.

    The joins are inner, so a row whose ``patient_id`` or ``hospital_id`` has
    no match in the dimension frames is dropped.
    """
    return (
        appointments
        .join(patients_alias, "patient_id")
        .join(hospitals_alias, "hospital_id")
        .withColumn("year", year(col("appointment_date")))
        .withColumn("month", month(col("appointment_date")))
        .select(
            col("appointment_id"),
            col("patient_id"),
            col("hospital_id"),
            col("appointment_date"),
            col("diagnosis"),
            col("cost"),
            col("status"),
            col("p.patient_region").alias("patient_region"),
            col("h.hospital_region").alias("hospital_region"),
            col("h.hospital_name").alias("hospital_name"),
            col("year"),
            col("month"),
        )
    )


appointments_day2_df = spark.read.csv(raw_appointments_day2, header=True, inferSchema=True)

# Enrich the batch BEFORE merging. Merging the raw seven-column batch would
# leave patient_region / hospital_region / hospital_name / year / month NULL
# on inserted rows and stale on updated rows in the stored silver table.
_new_silver_rows = _enrich_appointments(appointments_day2_df)

silver_table = DeltaTable.forPath(spark, "/tmp/silver_appointments")
(
    silver_table.alias("silver")
    .merge(
        _new_silver_rows.alias("new"),
        "silver.appointment_id = new.appointment_id",
    )
    # The silver table holds completed appointments only: an existing row
    # whose incoming status is no longer 'Completed' (or is NULL) is removed.
    .whenMatchedDelete(condition="NOT (new.status <=> 'Completed')")
    # Any other match is a completed appointment: refresh every column.
    .whenMatchedUpdateAll()
    # New appointments are added only once they are completed.
    .whenNotMatchedInsertAll(condition="new.status <=> 'Completed'")
    .execute()
)

# Every stored row now carries its derived columns, so the refreshed silver
# frame is read straight from the table without re-joining the dimensions.
updated_silver = spark.read.format("delta").load("/tmp/silver_appointments")

updated_gold_summary = updated_silver.groupBy(
    "hospital_name", "hospital_region", "diagnosis"
).agg(
    _sum("cost").alias("total_cost"),
    count("appointment_id").alias("num_appointments"),
)
updated_gold_summary.show()


# Delta Lake Features
_gold_path = "/tmp/gold_healthcare_summary"

# Time travel: load the gold summary as of table version 0 and display it.
# NOTE(review): in the code visible here the refreshed gold summary is only
# shown, never written back to this path, so version 0 may not be a "before
# update" snapshot of anything -- confirm whether a write was intended.
_version_zero_reader = spark.read.format("delta").option("versionAsOf", 0)
gold_before_update = _version_zero_reader.load(_gold_path)
gold_before_update.show()

# Table maintenance on the same path: vacuum with a retention argument of
# 168 (hours, i.e. 7 days), then compact with Z-ordering on hospital_name.
delta_table = DeltaTable.forPath(spark, _gold_path)
delta_table.vacuum(168)
_compaction = delta_table.optimize()
_compaction.executeZOrderBy("hospital_name")

# Analytical Queries
#
# Ad-hoc reports over the refreshed silver frame and gold summary; each one
# is displayed with show() and nothing is persisted.

# Total revenue per hospital
updated_gold_summary.groupBy("hospital_name").agg(
    _sum("total_cost").alias("revenue")
).show()

# Average cost per diagnosis
updated_silver.groupBy("diagnosis").agg(avg("cost").alias("avg_cost")).show()

# Number of patients per hospital region.
# Distinct patient ids are counted: after the incremental load a patient can
# have several appointments in the same region, and count("patient_id")
# would report appointments rather than patients.
updated_silver.groupBy("hospital_region").agg(
    countDistinct("patient_id").alias("num_patients")
).show()

# Trend of appointments month-over-month
updated_silver.groupBy("year", "month").agg(
    count("appointment_id").alias("num_appointments")
).orderBy("year", "month").show()

# Top 5 most expensive treatments last 6 months.
# "6 months" is approximated as the 180 days before the current date, so the
# result depends on the day the cell is run.
six_months_ago = date_sub(current_date(), 180)
updated_silver.filter(col("appointment_date") >= six_months_ago) \
    .groupBy("diagnosis") \
    .agg(_sum("cost").alias("total_cost")) \
    .orderBy(desc("total_cost")) \
    .limit(5).show()


Wrote 143 bytes.
Wrote 266 bytes.
Wrote 275 bytes.
Wrote 236 bytes.
+-------------+-------------+
|hospital_name|total_revenue|
+-------------+-------------+
|    City Care|          400|
|     CureWell|          300|
|     LifePlus|          250|
+-------------+-------------+

+---------------+--------------+
|hospital_region|total_patients|
+---------------+--------------+
|           West|             1|
|          North|             1|
|          South|             1|
+---------------+--------------+

+---------+----------+
|diagnosis|total_cost|
+---------+----------+
| Diabetes|       400|
|  Allergy|       300|
|      Flu|       250|
+---------+----------+

+-------------+---------------+-------------+----------+----------------+
|hospital_name|hospital_region|    diagnosis|total_cost|num_appointments|
+-------------+---------------+-------------+----------+----------------+
|     MediHope|           East|   Cardiology|      1500|               1|
|     MediHope|           East|