01_ingest_eda.ipynb → inspect raw data, build patient_summary.

02_label_features.ipynb → compute labels & features.

03_train_xgb.ipynb → train/evaluate/log to MLflow, export ONNX.

04_monitoring_drift.ipynb → Evidently reports & gating logic.

# Bootstrap (env, Java, Spark)

In [42]:
import os
from dotenv import load_dotenv, find_dotenv
from pyspark.sql import SparkSession

In [43]:
# Load .env (find it even if notebook is under notebooks/)
load_dotenv(find_dotenv(usecwd=True), override=True)

# Ensure Java 17 is visible to this kernel (adjust if your path differs)
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-17-openjdk-amd64"
os.environ["PATH"] = f'{os.environ["JAVA_HOME"]}/bin:' + os.environ["PATH"]

In [44]:
# Data directory
DATA_DIR = os.getenv("DATA_DIR")
assert DATA_DIR and os.path.isdir(DATA_DIR), f"DATA_DIR not set or invalid: {DATA_DIR}"
print("DATA_DIR =", DATA_DIR)

DATA_DIR = /home/utsajinlab/health_claims_ml/data/raw/synthea_1k/csv


In [45]:
# Start Spark (local)
spark = SparkSession.builder.appName("HealthClaimsEDA").getOrCreate()
print("Spark version:", spark.version)

Spark version: 4.0.1


# Load core tables

In [46]:
import os

def load_csv(name: str):
    return spark.read.csv(os.path.join(DATA_DIR, f"{name}.csv"), header=True, inferSchema=True)

# List of all CSVs
csv_files = [
    "claims", "observations", "medications", "careplans", "organizations", "claims_transactions",
    "supplies", "imaging_studies", "devices", "immunizations", "allergies", "payers",
    "conditions", "payer_transitions", "procedures", "encounters", "patients", "providers"
]

# Register them
all_tables = {}
for name in csv_files:
    all_tables[name] = load_csv(name)



In [47]:
# Show sample rows and schema for a few important ones
all_tables["patients"].printSchema()
all_tables["patients"].show(5)

all_tables["encounters"].printSchema()
all_tables["encounters"].show(5)

all_tables["conditions"].printSchema()
all_tables["conditions"].show(5)

all_tables["claims"].printSchema()
all_tables["claims"].show(5)

all_tables["procedures"].printSchema()
all_tables["procedures"].show(5)

all_tables["medications"].printSchema()
all_tables["medications"].show(5)

all_tables["observations"].printSchema()
all_tables["observations"].show(5)


root
 |-- Id: string (nullable = true)
 |-- BIRTHDATE: date (nullable = true)
 |-- DEATHDATE: date (nullable = true)
 |-- SSN: string (nullable = true)
 |-- DRIVERS: string (nullable = true)
 |-- PASSPORT: string (nullable = true)
 |-- PREFIX: string (nullable = true)
 |-- FIRST: string (nullable = true)
 |-- LAST: string (nullable = true)
 |-- SUFFIX: string (nullable = true)
 |-- MAIDEN: string (nullable = true)
 |-- MARITAL: string (nullable = true)
 |-- RACE: string (nullable = true)
 |-- ETHNICITY: string (nullable = true)
 |-- GENDER: string (nullable = true)
 |-- BIRTHPLACE: string (nullable = true)
 |-- ADDRESS: string (nullable = true)
 |-- CITY: string (nullable = true)
 |-- STATE: string (nullable = true)
 |-- COUNTY: string (nullable = true)
 |-- ZIP: integer (nullable = true)
 |-- LAT: double (nullable = true)
 |-- LON: double (nullable = true)
 |-- HEALTHCARE_EXPENSES: double (nullable = true)
 |-- HEALTHCARE_COVERAGE: double (nullable = true)

+--------------------+-----

# Register SQL views & quick sanity

In [48]:
for name in ["patients", "encounters", "conditions", "claims", "procedures", "medications", "observations"]:
    all_tables[name].createOrReplaceTempView(name)

spark.sql("""
SELECT 'patients' AS table, COUNT(*) AS n FROM patients
UNION ALL SELECT 'encounters', COUNT(*) FROM encounters
UNION ALL SELECT 'conditions', COUNT(*) FROM conditions
UNION ALL SELECT 'claims', COUNT(*) FROM claims
UNION ALL SELECT 'procedures', COUNT(*) FROM procedures
UNION ALL SELECT 'medications', COUNT(*) FROM medications
UNION ALL SELECT 'observations', COUNT(*) FROM observations
""").show()

spark.sql("""
SELECT gender, COUNT(*) AS n
FROM patients
GROUP BY gender
ORDER BY n DESC
""").show()

+------------+------+
|       table|     n|
+------------+------+
|    patients|  1163|
|  encounters| 61459|
|  conditions| 38094|
|      claims|117889|
|  procedures| 83823|
| medications| 56430|
|observations|531144|
+------------+------+

+------+---+
|gender|  n|
+------+---+
|     F|616|
|     M|547|
+------+---+



# Age bands, top encounters, top conditions, monthly trend

In [53]:
# A) Age bands
spark.sql("""
SELECT
  CASE
    WHEN FLOOR(DATEDIFF(current_date(), to_date(BIRTHDATE)) / 365.25) < 18 THEN '<18'
    WHEN FLOOR(DATEDIFF(current_date(), to_date(BIRTHDATE)) / 365.25) BETWEEN 18 AND 34 THEN '18-34'
    WHEN FLOOR(DATEDIFF(current_date(), to_date(BIRTHDATE)) / 365.25) BETWEEN 35 AND 49 THEN '35-49'
    WHEN FLOOR(DATEDIFF(current_date(), to_date(BIRTHDATE)) / 365.25) BETWEEN 50 AND 64 THEN '50-64'
    ELSE '65+'
  END AS age_band,
  COUNT(*) AS n
FROM patients
GROUP BY age_band
ORDER BY n DESC
""").show()

# B) Top encounter descriptions
spark.sql("""
SELECT description, COUNT(*) AS n
FROM encounters
GROUP BY description
ORDER BY n DESC
LIMIT 12
""").show()

# C) Top conditions (code + description)
spark.sql("""
SELECT code, description, COUNT(*) AS n
FROM conditions
GROUP BY code, description
ORDER BY n DESC
LIMIT 12
""").show()

# D) Monthly encounters — last 10 years (ignore very old simulated dates)
spark.sql("""
SELECT date_format(to_date(START), 'yyyy-MM') AS ym, COUNT(*) AS n
FROM encounters
WHERE year(to_date(START)) >= year(current_date()) - 10
GROUP BY ym
ORDER BY ym
""").show(60, truncate=False)

# Top Procedures
spark.sql("""
SELECT code, description, COUNT(*) AS n
FROM procedures
GROUP BY code, description
ORDER BY n DESC
LIMIT 10
""").show()

# Top Medications
spark.sql("""
SELECT code, description, COUNT(*) AS n
FROM medications
GROUP BY code, description
ORDER BY n DESC
LIMIT 10
""").show()

# Top Observations
spark.sql("""
SELECT code, description, COUNT(*) AS n
FROM observations
GROUP BY code, description
ORDER BY n DESC
LIMIT 10
""").show()


+--------+---+
|age_band|  n|
+--------+---+
|     65+|322|
|   18-34|248|
|   50-64|226|
|   35-49|196|
|     <18|171|
+--------+---+

+--------------------+-----+
|         description|    n|
+--------------------+-----+
|General examinati...|19374|
|Encounter for che...| 4778|
|Well child visit ...| 4507|
|Encounter for sym...| 4154|
|      Prenatal visit| 2804|
|Encounter for pro...| 2611|
|Encounter for pro...| 2571|
|Urgent care clini...| 2564|
| Follow-up encounter| 2489|
|Outpatient procedure| 2010|
|Patient encounter...| 1783|
|Consultation for ...| 1344|
+--------------------+-----+

+---------+--------------------+-----+
|     code|         description|    n|
+---------+--------------------+-----+
|160903007|Full-time employm...|13805|
| 73595000|    Stress (finding)| 5137|
|160904001|Part-time employm...| 2426|
|422650009|Social isolation ...| 1243|
|444814009|Viral sinusitis (...| 1233|
|423315002|Limited social co...| 1200|
|741062008|Not in labor forc...| 1077|
|70689300

# Build patient_summary temp views

In [54]:
# Age from birthdate
spark.sql("""
CREATE OR REPLACE TEMP VIEW v_patients_basic AS
SELECT
  Id AS patient_id,
  gender,
  RACE       AS race,
  ETHNICITY  AS ethnicity,
  BIRTHDATE,
  FLOOR(DATEDIFF(current_date(), to_date(BIRTHDATE)) / 365.25) AS age_years
FROM patients
""")

# Encounters per patient (+ include cost if a cost-like column exists)
enc_cols = [c.lower() for c in encounters.columns]
enc_cost_col = next((c for c in ["total_claim_cost","base_encounter_cost","cost","encounter_cost"] if c in enc_cols), None)

enc_sql = f"""
CREATE OR REPLACE TEMP VIEW v_patient_encounters AS
SELECT
  patient AS patient_id,
  COUNT(*)            AS n_encounters,
  MIN(to_date(START)) AS first_enc_date,
  MAX(to_date(START)) AS last_enc_date
  {", ROUND(SUM(" + enc_cost_col + "),2) AS enc_total_cost" if enc_cost_col else ""}
FROM encounters
GROUP BY patient
"""
spark.sql(enc_sql)

# Conditions per patient
spark.sql("""
CREATE OR REPLACE TEMP VIEW v_patient_conditions AS
SELECT patient AS patient_id, COUNT(*) AS n_conditions
FROM conditions
GROUP BY patient
""")

# Number of Procedures per Patient
spark.sql("""
CREATE OR REPLACE TEMP VIEW v_patient_procedures AS
SELECT patient AS patient_id, COUNT(*) AS n_procedures
FROM procedures
GROUP BY patient
""")

# Number of Medications per Patient
spark.sql("""
CREATE OR REPLACE TEMP VIEW v_patient_medications AS
SELECT patient AS patient_id, COUNT(*) AS n_medications
FROM medications
GROUP BY patient
""")

# Number of Observations per Patient
spark.sql("""
CREATE OR REPLACE TEMP VIEW v_patient_observations AS
SELECT patient AS patient_id, COUNT(*) AS n_observations
FROM observations
GROUP BY patient
""")

DataFrame[]

# Optional claims (if present) and final join

In [65]:
# Claims (optional) and patient_summary join
import os

claims_path = os.path.join(DATA_DIR, "claims.csv")
if os.path.exists(claims_path):
    claims_df = spark.read.csv(claims_path, header=True, inferSchema=True)
    for c in claims_df.columns:
        claims_df = claims_df.withColumnRenamed(c, c.lower())
    claims_df.createOrReplaceTempView("claims")

    # total_claim_cost available?
    cl_cols = [c for c in claims_df.columns]
    if "total_claim_cost" in cl_cols:
        spark.sql("""
        CREATE OR REPLACE TEMP VIEW v_patient_claims AS
        SELECT patient AS patient_id, ROUND(SUM(total_claim_cost),2) AS total_claim_cost
        FROM claims
        GROUP BY patient
        """)
    else:
        spark.sql("""
        CREATE OR REPLACE TEMP VIEW v_patient_claims AS
        SELECT NULL AS patient_id, NULL AS total_claim_cost
        WHERE false
        """)
else:
    spark.sql("""
    CREATE OR REPLACE TEMP VIEW v_patient_claims AS
    SELECT NULL AS patient_id, NULL AS total_claim_cost
    WHERE false
    """)

patient_summary = spark.sql("""
SELECT
  p.patient_id,
  p.gender, p.race, p.ethnicity,
  p.age_years,

  -- Encounters
  COALESCE(e.n_encounters, 0) AS n_encounters,
  e.first_enc_date, e.last_enc_date,
  COALESCE(c.n_conditions, 0) AS n_conditions,

  -- Additional features
  COALESCE(pr.n_procedures, 0) AS n_procedures,
  COALESCE(m.n_medications, 0) AS n_medications,
  COALESCE(o.n_observations, 0) AS n_observations,

  -- Total cost from claims or fallback to encounters
  COALESCE(cl.total_claim_cost, e.enc_total_cost) AS total_cost

FROM v_patients_basic p
LEFT JOIN v_patient_encounters    e  ON e.patient_id = p.patient_id
LEFT JOIN v_patient_conditions    c  ON c.patient_id = p.patient_id
LEFT JOIN v_patient_procedures    pr ON pr.patient_id = p.patient_id
LEFT JOIN v_patient_medications   m  ON m.patient_id = p.patient_id
LEFT JOIN v_patient_observations  o  ON o.patient_id = p.patient_id
LEFT JOIN v_patient_claims        cl ON cl.patient_id = p.patient_id
""")


from pyspark.sql.functions import coalesce, lit

# Aggregate claims features
claims_features = spark.sql("""
SELECT
  patientid AS patient_id,
  COUNT(*) AS n_claims,
  COUNT(DISTINCT providerid) AS n_unique_providers,
  COUNT(DISTINCT departmentid) AS n_unique_departments,
  COUNT(diagnosis1) AS n_claims_with_diag,
  ROUND(AVG(COALESCE(outstanding1, 0) + COALESCE(outstanding2, 0) + COALESCE(outstandingp, 0)), 2) AS avg_outstanding_total,
  DATEDIFF(MAX(servicedate), MIN(servicedate)) AS claim_span_days
FROM claims
GROUP BY patientid
""")

from pyspark.sql.functions import coalesce, lit

patient_summary = patient_summary.join(claims_features, on="patient_id", how="left")

# Fill missing claim stats with 0 or -1 where needed
patient_summary = (
    patient_summary
    .withColumn("n_claims", coalesce("n_claims", lit(0)))
    .withColumn("n_unique_providers", coalesce("n_unique_providers", lit(0)))
    .withColumn("n_unique_departments", coalesce("n_unique_departments", lit(0)))
    .withColumn("n_claims_with_diag", coalesce("n_claims_with_diag", lit(0)))
    .withColumn("avg_outstanding_total", coalesce("avg_outstanding_total", lit(0.0)))
    .withColumn("claim_span_days", coalesce("claim_span_days", lit(-1)))  # span can be negative if dates are missing
)

patient_summary.createOrReplaceTempView("patient_summary")
patient_summary.show(10, truncate=False)

+------------------------------------+------+-----+-----------+---------+------------+--------------+-------------+------------+------------+-------------+--------------+----------+--------+------------------+--------------------+------------------+---------------------+---------------+
|patient_id                          |gender|race |ethnicity  |age_years|n_encounters|first_enc_date|last_enc_date|n_conditions|n_procedures|n_medications|n_observations|total_cost|n_claims|n_unique_providers|n_unique_departments|n_claims_with_diag|avg_outstanding_total|claim_span_days|
+------------------------------------+------+-----+-----------+---------+------------+--------------+-------------+------------+------------+-------------+--------------+----------+--------+------------------+--------------------+------------------+---------------------+---------------+
|d488232e-bf14-4bed-08c0-a82f34b6a197|F     |white|nonhispanic|22       |31          |2011-05-31    |2021-09-27   |10          |86      

# Save Parquet + validation

In [70]:
# persist & validate
out_dir = os.path.abspath(os.path.join(DATA_DIR, "..", "processed", "patient_summary_parquet"))
print("Saving parquet to:", out_dir)
patient_summary.write.mode("overwrite").parquet(out_dir)

# Basic checks
spark.sql("""
SELECT
  COUNT(*) AS n_patients,
  ROUND(AVG(age_years),1)      AS avg_age,
  ROUND(AVG(n_encounters),2)   AS avg_encounters,
  ROUND(AVG(n_conditions),2)   AS avg_conditions,
  ROUND(AVG(n_procedures),2)   AS avg_procedures,
  ROUND(AVG(n_medications),2)   AS avg_medications,
  ROUND(AVG(n_observations),2)   AS avg_observations,
  ROUND(AVG(n_claims),2)   AS avg_claims,
  ROUND(AVG(n_unique_providers),2)   AS avg_unique_providers,
  ROUND(AVG(n_claims_with_diag),2)   AS avg_claims_with_diag,
  ROUND(AVG(avg_outstanding_total),2)   AS avg_outstanding_total,
  ROUND(AVG(claim_span_days),2)   AS avg_claim_span_days,
  ROUND(AVG(COALESCE(total_cost,0)),2) AS avg_cost_nonnull
FROM patient_summary
""").show()

# Nulls overview
spark.sql("""
SELECT
  SUM(CASE WHEN patient_id   IS NULL THEN 1 ELSE 0 END) AS null_patient_id,
  SUM(CASE WHEN gender       IS NULL THEN 1 ELSE 0 END) AS null_gender,
  SUM(CASE WHEN age_years    IS NULL THEN 1 ELSE 0 END) AS null_age_years,
  SUM(CASE WHEN n_encounters IS NULL THEN 1 ELSE 0 END) AS null_n_encounters,
  SUM(CASE WHEN n_conditions IS NULL THEN 1 ELSE 0 END) AS null_n_conditions,
  SUM(CASE WHEN n_procedures IS NULL THEN 1 ELSE 0 END) AS null_n_procedures,
  SUM(CASE WHEN n_medications IS NULL THEN 1 ELSE 0 END) AS null_n_medications,
  SUM(CASE WHEN n_observations IS NULL THEN 1 ELSE 0 END) AS null_n_observations,
  SUM(CASE WHEN n_claims IS NULL THEN 1 ELSE 0 END) AS null_n_claims,
  SUM(CASE WHEN n_unique_providers IS NULL THEN 1 ELSE 0 END) AS null_n_unique_providers,
  SUM(CASE WHEN n_claims_with_diag IS NULL THEN 1 ELSE 0 END) AS null_n_claims_with_diag,
  SUM(CASE WHEN avg_outstanding_total IS NULL THEN 1 ELSE 0 END) AS null_avg_outstanding_total,
  SUM(CASE WHEN claim_span_days IS NULL THEN 1 ELSE 0 END) AS null_claim_span_days
FROM patient_summary
""").show()

Saving parquet to: /home/utsajinlab/health_claims_ml/data/raw/synthea_1k/processed/patient_summary_parquet
+----------+-------+--------------+--------------+--------------+---------------+----------------+----------+--------------------+--------------------+---------------------+-------------------+----------------+
|n_patients|avg_age|avg_encounters|avg_conditions|avg_procedures|avg_medications|avg_observations|avg_claims|avg_unique_providers|avg_claims_with_diag|avg_outstanding_total|avg_claim_span_days|avg_cost_nonnull|
+----------+-------+--------------+--------------+--------------+---------------+----------------+----------+--------------------+--------------------+---------------------+-------------------+----------------+
|      1163|   48.4|         52.85|         32.75|         72.07|          48.52|           456.7|    101.37|                 2.6|              101.37|                  0.0|           11636.87|       219289.62|
+----------+-------+--------------+--------------

25/10/27 11:36:33 WARN HostPort: Bad Authority: [jinutsa.utsarr.net/..]
25/10/27 11:36:33 WARN HttpChannel: /..\..\..\..\..\..\..\..\..\..\windows\win.ini
java.net.URISyntaxException: Illegal character in path at index 33: http://jinutsa.utsarr.net:4040/..\..\..\..\..\..\..\..\..\..\windows\win.ini
	at java.base/java.net.URI$Parser.fail(URI.java:2976)
	at java.base/java.net.URI$Parser.checkChars(URI.java:3147)
	at java.base/java.net.URI$Parser.parseHierarchical(URI.java:3229)
	at java.base/java.net.URI$Parser.parse(URI.java:3177)
	at java.base/java.net.URI.<init>(URI.java:623)
	at org.apache.spark.ui.JettyUtils$$anon$2.doRequest(JettyUtils.scala:153)
	at org.apache.spark.ui.JettyUtils$$anon$2.doGet(JettyUtils.scala:138)
	at jakarta.servlet.http.HttpServlet.service(HttpServlet.java:500)
	at jakarta.servlet.http.HttpServlet.service(HttpServlet.java:587)
	at org.sparkproject.jetty.servlet.ServletHolder.handle(ServletHolder.java:764)
	at org.sparkproject.jetty.servlet.ServletHandler$ChainE