In [14]:
import os
from dotenv import load_dotenv, find_dotenv
from pyspark.sql import SparkSession, functions as F, types as T

## Environment setup and Helper to load CSVs

In [15]:

# ---- env (DATA_DIR) ----
load_dotenv(find_dotenv(usecwd=True), override=True)

DATA_DIR = os.getenv("DATA_DIR")
assert DATA_DIR and os.path.isdir(DATA_DIR), f"DATA_DIR not set or invalid: {DATA_DIR}"

# ---- Spark ----
spark = (SparkSession.builder
    .appName("HealthClaims_LabelFeatures")
    .config("spark.driver.memory", "4g")  # Increase if you have more RAM
    .config("spark.executor.memory", "4g")
    .config("spark.sql.shuffle.partitions", "20")  # Reduce from default 200 for small data
    .config("spark.driver.maxResultSize", "2g")
    .config("spark.sql.adaptive.enabled", "true")  # Enable adaptive query execution
    .getOrCreate())

# helper
def load_csv(name: str):
    return spark.read.csv(os.path.join(DATA_DIR, f"{name}.csv"), header=True, inferSchema=True)

## Load patient summary and raw tables

In [16]:
# Load patient_summary Parquet
ps_dir = os.path.abspath(os.path.join(DATA_DIR, "..", "processed", "patient_summary_parquet"))
patient_summary = spark.read.parquet(ps_dir)
patient_summary.createOrReplaceTempView("patient_summary")

def load_csv(name: str):
    return spark.read.csv(os.path.join(DATA_DIR, f"{name}.csv"), header=True, inferSchema=True)
# Load CSVs
patients     = load_csv("patients");     patients.createOrReplaceTempView("patients")
encounters   = load_csv("encounters");   encounters.createOrReplaceTempView("encounters")
conditions   = load_csv("conditions");   conditions.createOrReplaceTempView("conditions")
procedures   = load_csv("procedures");   procedures.createOrReplaceTempView("procedures")
medications  = load_csv("medications");  medications.createOrReplaceTempView("medications")
observations = load_csv("observations"); observations.createOrReplaceTempView("observations")
claims       = load_csv("claims");       claims.createOrReplaceTempView("claims")


##  Standardize encounter costs

In [17]:
enc_cols_lower = {c.lower(): c for c in encounters.columns}
for candidate in ["total_claim_cost", "base_encounter_cost", "encounter_cost", "cost"]:
    if candidate in enc_cols_lower:
        cost_col_real = enc_cols_lower[candidate]
        break
else:
    raise ValueError(f"No cost column found in encounters: {encounters.columns}")

enc_costs = (
    encounters
    .select(
        F.col("PATIENT").alias("patient_id"),
        F.to_date("START").alias("svc_date"),
        F.col(cost_col_real).cast("double").alias("cost")
    )
    .where(F.col("svc_date").isNotNull() & F.col("cost").isNotNull())
    .cache()  
)
enc_costs.count()  
enc_costs.createOrReplaceTempView("enc_costs")

## Compute dataset end and parameterize lookahead window

In [18]:
DATA_END_ENC = spark.sql("SELECT MAX(svc_date) AS max_dt FROM enc_costs").first().max_dt
DATA_END_CLAIM = (
    claims.select(F.to_date("servicedate").alias("svc_date"))
          .agg(F.max("svc_date").alias("max_dt")).first().max_dt
)
DATA_END = max(DATA_END_ENC, DATA_END_CLAIM)

LOOKAHEAD_DAYS = 365   # change freely: e.g., 90, 180, 365

valid_index_max_expr = F.date_sub(F.lit(DATA_END), LOOKAHEAD_DAYS)

print(f"DATA_END        : {DATA_END}")
print(f"LOOKAHEAD_DAYS  : {LOOKAHEAD_DAYS}")
print("valid_index_max :", spark.range(1).select(valid_index_max_expr.alias("valid_index_max")).first().valid_index_max)

DATA_END        : 2021-11-19
LOOKAHEAD_DAYS  : 365
valid_index_max : 2020-11-19


## Create index_date (anchor per patient) and filter fully observed ones

In [19]:
# Create index_date (anchor per patient) and filter fully observed ones
idx = (
    patient_summary
    .select(
        "patient_id",
        "last_enc_date",
        F.date_sub(F.col("last_enc_date"), LOOKAHEAD_DAYS).alias("index_date")
    )
)
idx_valid = idx.where(F.col("index_date") <= valid_index_max_expr).cache()
idx_valid.count()  
idx_valid.createOrReplaceTempView("v_index_date")
# Demographics + index join
v_demo_valid = (
    patient_summary.alias("ps")
    .join(idx_valid.alias("idx"), on="patient_id", how="inner")
    .select(
        F.col("ps.patient_id"),
        F.col("ps.gender"),
        F.col("ps.race"),
        F.col("ps.ethnicity"),
        F.col("ps.age_years").alias("age_at_index"),
        F.col("idx.index_date"),
        F.col("ps.last_enc_date")
    )
    .cache()  
)
v_demo_valid.count()  
v_demo_valid.createOrReplaceTempView("v_demo_valid")

## Build label = total cost in next LOOKAHEAD_DAYS after index_date

In [20]:
label = (
    enc_costs.join(v_demo_valid.select("patient_id", "index_date"), "patient_id")
    .where(
        (F.col("svc_date") > F.col("index_date")) &
        (F.col("svc_date") <= F.date_add(F.col("index_date"), LOOKAHEAD_DAYS))
    )
    .groupBy("patient_id")
    .agg(F.round(F.sum("cost"), 2).alias("cost_next_window"))
    .cache()  
)
label.count()  
label.createOrReplaceTempView("v_label")


## Historical features (same as before, just restricted to valid cohort)

In [21]:
hist_enc = (
    enc_costs.join(v_demo_valid.select("patient_id","index_date"), "patient_id")
             .where(F.col("svc_date") <= F.col("index_date"))
             .groupBy("patient_id")
             .agg(F.count("*").alias("n_encounters"),
                  F.round(F.sum("cost"), 2).alias("hist_total_cost"))
             .cache()
)
hist_enc.count()
hist_enc.createOrReplaceTempView("v_hist_enc")

# Conditions
conditions_lc = conditions.withColumnRenamed("PATIENT", "patient_id").withColumn("start_dt", F.to_date("START"))
hist_cond = (
    conditions_lc
    .join(v_demo_valid.select("patient_id","index_date"), "patient_id")
    .where(F.col("start_dt") <= F.col("index_date"))
    .groupBy("patient_id")
    .agg(F.count("*").alias("n_conditions"))
    .cache()
)
hist_cond.count()
hist_cond.createOrReplaceTempView("v_hist_cond")

# Procedures
hist_proc = (
    procedures.withColumnRenamed("PATIENT", "patient_id")
              .withColumn("start_dt", F.to_date("START"))
              .join(v_demo_valid.select("patient_id","index_date"), "patient_id")
              .where(F.col("start_dt") <= F.col("index_date"))
              .groupBy("patient_id").agg(F.count("*").alias("n_procedures"))
              .cache()
)
hist_proc.count()
hist_proc.createOrReplaceTempView("v_hist_proc")

# Medications
hist_med = (
    medications.withColumnRenamed("PATIENT", "patient_id")
               .withColumn("start_dt", F.to_date("START"))
               .join(v_demo_valid.select("patient_id","index_date"), "patient_id")
               .where(F.col("start_dt") <= F.col("index_date"))
               .groupBy("patient_id").agg(F.count("*").alias("n_medications"))
               .cache()
)
hist_med.count()
hist_med.createOrReplaceTempView("v_hist_med")

# Observations
hist_obs = (
    observations.withColumnRenamed("PATIENT", "patient_id")
                .withColumn("start_dt", F.to_date("DATE"))
                .join(v_demo_valid.select("patient_id","index_date"), "patient_id")
                .where(F.col("start_dt") <= F.col("index_date"))
                .groupBy("patient_id").agg(F.count("*").alias("n_observations"))
                .cache()
)
hist_obs.count()
hist_obs.createOrReplaceTempView("v_hist_obs")

# Claims
claims_hist = (
    claims.withColumnRenamed("PATIENTID", "patient_id")
          .withColumn("start_dt", F.to_date("servicedate"))
          .join(v_demo_valid.select("patient_id","index_date"), "patient_id")
          .where(F.col("start_dt") <= F.col("index_date"))
          .cache()
)
claims_hist.count()
claims_hist.createOrReplaceTempView("claims_hist")

hist_claims = claims_hist.groupBy("patient_id").agg(F.count("*").alias("n_claims")).cache()
hist_claims.count()
hist_claims.createOrReplaceTempView("v_hist_claims")

hist_claims_extra = (
    claims_hist
    .withColumn("has_diag", F.expr("""
        diagnosis1 IS NOT NULL OR diagnosis2 IS NOT NULL OR diagnosis3 IS NOT NULL OR
        diagnosis4 IS NOT NULL OR diagnosis5 IS NOT NULL OR diagnosis6 IS NOT NULL OR
        diagnosis7 IS NOT NULL OR diagnosis8 IS NOT NULL
    """))
    .groupBy("patient_id")
    .agg(
        F.countDistinct("providerid").alias("n_unique_providers"),
        F.countDistinct("departmentid").alias("n_unique_departments"),
        F.sum(F.col("has_diag").cast("int")).alias("n_claims_with_diag"),
        F.datediff(F.max("start_dt"), F.min("start_dt")).alias("claim_span_days")
    )
    .cache()
)
hist_claims_extra.count()
hist_claims_extra.createOrReplaceTempView("v_hist_claims_extra")

25/10/28 11:18:15 WARN CacheManager: Asked to cache already cached data.
25/10/28 11:18:16 WARN CacheManager: Asked to cache already cached data.
25/10/28 11:18:16 WARN CacheManager: Asked to cache already cached data.
25/10/28 11:18:16 WARN CacheManager: Asked to cache already cached data.
25/10/28 11:18:16 WARN CacheManager: Asked to cache already cached data.
25/10/28 11:18:16 WARN CacheManager: Asked to cache already cached data.
25/10/28 11:18:16 WARN CacheManager: Asked to cache already cached data.
25/10/28 11:18:16 WARN CacheManager: Asked to cache already cached data.


##  Choose label threshold dynamically 

In [38]:
thresh_row = spark.sql("""
  SELECT percentile_approx(cost_next_window, 0.7) AS thr
  FROM v_label
  WHERE cost_next_window IS NOT NULL
""").first()
THRESH = float(thresh_row.thr)
print(f"Suggested THRESH for top-30% high cost (window={LOOKAHEAD_DAYS} days): {THRESH:,.0f}")


Suggested THRESH for top-30% high cost (window=365 days): 14,308


## Assemble labeled feature table

In [39]:
# Final labeled dataset - use DataFrame API instead of SQL for better control
ps_labeled = (
    v_demo_valid
    .join(hist_enc, "patient_id", "left")
    .join(hist_cond, "patient_id", "left")
    .join(hist_proc, "patient_id", "left")
    .join(hist_med, "patient_id", "left")
    .join(hist_obs, "patient_id", "left")
    .join(hist_claims, "patient_id", "left")
    .join(label, "patient_id", "left")
    .join(hist_claims_extra, "patient_id", "left")
    .select(
        "patient_id", "gender", "race", "ethnicity", "age_at_index", "index_date", "last_enc_date",
        F.coalesce("n_encounters", F.lit(0)).alias("n_encounters"),
        F.coalesce("n_conditions", F.lit(0)).alias("n_conditions"),
        F.coalesce("n_procedures", F.lit(0)).alias("n_procedures"),
        F.coalesce("n_medications", F.lit(0)).alias("n_medications"),
        F.coalesce("n_observations", F.lit(0)).alias("n_observations"),
        F.coalesce("n_claims", F.lit(0)).alias("n_claims"),
        F.coalesce("hist_total_cost", F.lit(0.0)).alias("hist_total_cost"),
        F.coalesce("n_unique_providers", F.lit(0)).alias("n_unique_providers"),
        F.coalesce("n_unique_departments", F.lit(0)).alias("n_unique_departments"),
        F.coalesce("n_claims_with_diag", F.lit(0)).alias("n_claims_with_diag"),
        F.coalesce("claim_span_days", F.lit(0)).alias("claim_span_days"),
        F.coalesce("cost_next_window", F.lit(0.0)).alias("cost_next_window"),
        F.when(F.coalesce("cost_next_window", F.lit(0.0)) >= THRESH, 1).otherwise(0).alias("label")
    )
    .cache()  # Cache final result before writing
)

print(f"Computing final dataset with {ps_labeled.count()} patients...")

Computing final dataset with 1163 patients...


25/10/28 16:03:41 WARN CacheManager: Asked to cache already cached data.


## Save processed features

In [40]:
# Write output
out_dir = os.path.abspath(os.path.join(DATA_DIR, "..", "processed", "features_parquet"))
ps_labeled.write.mode("overwrite").parquet(out_dir)
print("Saved features to:", out_dir)



Saved features to: /home/utsajinlab/health_claims_ml/data/raw/synthea_1k/processed/features_parquet


## Quick prevalence check

In [41]:
# Add this import at the top
from pyspark.sql.window import Window

# Then use it correctly:
print("\n📈 Label distribution:")
ps_labeled.groupBy("label").count().withColumn(
    "pct", F.round(100 * F.col("count") / F.sum("count").over(Window.partitionBy(F.lit(1))), 2)
).orderBy(F.desc("label")).show()

# Clean up
enc_costs.unpersist()
idx_valid.unpersist()
v_demo_valid.unpersist()
label.unpersist()


📈 Label distribution:
+-----+-----+-----+
|label|count|  pct|
+-----+-----+-----+
|    1|  349|30.01|
|    0|  814|69.99|
+-----+-----+-----+



25/10/28 16:03:44 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
25/10/28 16:03:44 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
25/10/28 16:03:44 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
25/10/28 16:03:44 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
25/10/28 16:03:44 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
25/10/28 16:03:44 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
25/10/28 1

DataFrame[patient_id: string, cost_next_window: double]

## Code to View the Final Table

In [42]:
# 1. Show schema (column names and types)
print(" Schema:")
ps_labeled.printSchema()

# 2. Show first 10 rows
print("\n First 10 rows:")
ps_labeled.show(10, truncate=False)

# 3. Show summary statistics for numeric columns
print("\n Summary statistics:")
ps_labeled.select(
    "age_at_index",
    "n_encounters", 
    "n_conditions",
    "n_procedures",
    "n_medications",
    "hist_total_cost",
    "cost_next_window",
    "label"
).summary().show()

# 4. Compare high-cost vs normal-cost patients (means by group)
print("\n💰 High-cost vs Normal-cost comparison:")
ps_labeled.groupBy("label").agg(
    F.count("*").alias("n_patients"),
    F.round(F.mean("age_at_index"), 1).alias("avg_age"),
    F.round(F.mean("n_encounters"), 1).alias("avg_encounters"),
    F.round(F.mean("n_conditions"), 1).alias("avg_conditions"),
    F.round(F.mean("n_procedures"), 1).alias("avg_procedures"),
    F.round(F.mean("hist_total_cost"), 0).alias("avg_hist_cost"),
    F.round(F.mean("cost_next_window"), 0).alias("avg_future_cost")
).orderBy(F.desc("label")).show()

# 5. Show a few high-cost patients in detail
print("\n Sample of high-cost patients (label=1):")
ps_labeled.filter(F.col("label") == 1).select(
    "patient_id",
    "age_at_index",
    "gender",
    "n_encounters",
    "n_conditions",
    "hist_total_cost",
    "cost_next_window"
).show(5, truncate=False)

# 6. Show a few normal-cost patients
print("\n Sample of normal-cost patients (label=0):")
ps_labeled.filter(F.col("label") == 0).select(
    "patient_id",
    "age_at_index",
    "gender",
    "n_encounters",
    "n_conditions",
    "hist_total_cost",
    "cost_next_window"
).show(5, truncate=False)

# 7. Check for any missing values
print("\n Missing value counts:")
from pyspark.sql.functions import col, sum as spark_sum, when, count
ps_labeled.select([
    spark_sum(when(col(c).isNull(), 1).otherwise(0)).alias(c)
    for c in ps_labeled.columns
]).show(vertical=True)

# 8. Export a sample to pandas for quick viewing (if dataset is small)
print("\n Converting sample to Pandas for easier viewing:")
sample_df = ps_labeled.limit(20).toPandas()
print(sample_df.to_string())

# 9. Check correlation between features and label (which features matter most?)
print("\n Feature correlation with label:")
feature_cols = [
    "age_at_index", "n_encounters", "n_conditions", "n_procedures",
    "n_medications", "n_observations", "n_claims", "hist_total_cost",
    "n_unique_providers", "n_unique_departments", "claim_span_days"
]

for col_name in feature_cols:
    corr = ps_labeled.stat.corr(col_name, "label")
    print(f"  {col_name:25s}: {corr:+.3f}")

 Schema:
root
 |-- patient_id: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- race: string (nullable = true)
 |-- ethnicity: string (nullable = true)
 |-- age_at_index: decimal(13,0) (nullable = true)
 |-- index_date: date (nullable = true)
 |-- last_enc_date: date (nullable = true)
 |-- n_encounters: long (nullable = false)
 |-- n_conditions: long (nullable = false)
 |-- n_procedures: long (nullable = false)
 |-- n_medications: long (nullable = false)
 |-- n_observations: long (nullable = false)
 |-- n_claims: long (nullable = false)
 |-- hist_total_cost: double (nullable = false)
 |-- n_unique_providers: long (nullable = false)
 |-- n_unique_departments: long (nullable = false)
 |-- n_claims_with_diag: long (nullable = false)
 |-- claim_span_days: integer (nullable = false)
 |-- cost_next_window: double (nullable = false)
 |-- label: integer (nullable = false)


 First 10 rows:
+------------------------------------+------+-----+-----------+------------+----------