In [0]:
import os
import mlflow

from pyspark.ml.functions import vector_to_array
from pyspark.sql.functions import col, when
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator

In [0]:
gold_path = "/Volumes/workspace/default/pujitha_patient-readmission/gold"
gold_df = spark.read.format("delta").load(gold_path)

gold_df.limit(5).show()
gold_df.groupBy("readmitted_binary").count().show()

+---------------+------+-------+--------------+----------------+-----------------+----------------+----------------+------------------+--------------+---------------+----------------+---------+-------+------+-----------+------+-----------------+
|           race|gender|    age|admission_type|days_in_hospital|number_outpatient|number_emergency|number_inpatient|num_lab_procedures|num_procedures|num_medications|number_diagnoses|metformin|insulin|change|diabetesMed|diag_1|readmitted_binary|
+---------------+------+-------+--------------+----------------+-----------------+----------------+----------------+------------------+--------------+---------------+----------------+---------+-------+------+-----------+------+-----------------+
|      Caucasian|Female|[10-20)|     Emergency|               3|                0|               0|               0|                59|             0|             18|               9|       No|     Up|    Ch|        Yes|   276|                0|
|AfricanAmerican

In [0]:
categorical_cols = ["race", "gender", "age", "admission_type"]
numeric_cols = [
    "days_in_hospital",
    "number_outpatient", "number_emergency", "number_inpatient",
    "num_lab_procedures", "num_procedures", "num_medications", "number_diagnoses"
]

raw_df = gold_df.select(*(categorical_cols + numeric_cols + ["readmitted_binary"])) \
                .withColumnRenamed("readmitted_binary", "label")

In [0]:
from pyspark.sql.functions import col

for c in numeric_cols:
    raw_df = raw_df.withColumn(c, col(c).cast("double"))


train_raw, test_raw = raw_df.randomSplit([0.8, 0.2], seed=42)

print("Train rows:", train_raw.count())
print("Test rows:", test_raw.count())
train_raw.groupBy("label").count().show()
test_raw.groupBy("label").count().show()

Train rows: 78277
Test rows: 19776
+-----+-----+
|label|count|
+-----+-----+
|    1| 8824|
|    0|69453|
+-----+-----+

+-----+-----+
|label|count|
+-----+-----+
|    1| 2242|
|    0|17534|
+-----+-----+



In [0]:
indexers = [
    StringIndexer(inputCol=c, outputCol=f"{c}_idx", handleInvalid="keep")
    for c in categorical_cols
]

feature_cols = [f"{c}_idx" for c in categorical_cols] + numeric_cols
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")

lr = LogisticRegression(featuresCol="features", labelCol="label")

pipeline = Pipeline(stages=indexers + [assembler, lr])
pipeline_model = pipeline.fit(train_raw)
lr_test_pred = pipeline_model.transform(test_raw)

THRESHOLD = 0.2  

lr_test_pred_thr = lr_test_pred.withColumn(
    "pred_thr",
    when(vector_to_array(col("probability"))[1] >= THRESHOLD, 1.0).otherwise(0.0)
)
# Positive class metrics (label=1)
lr_tp = lr_test_pred_thr.filter((col("label") == 1) & (col("pred_thr") == 1)).count()
lr_fn = lr_test_pred_thr.filter((col("label") == 1) & (col("pred_thr") == 0)).count()
lr_fp = lr_test_pred_thr.filter((col("label") == 0) & (col("pred_thr") == 1)).count()

lr_pos_recall = lr_tp / (lr_tp + lr_fn) if (lr_tp + lr_fn) > 0 else 0.0
lr_pos_precision = lr_tp / (lr_tp + lr_fp) if (lr_tp + lr_fp) > 0 else 0.0

print("LR threshold:", THRESHOLD)
print("LR pos_recall:", lr_pos_recall)
print("LR pos_precision:", lr_pos_precision)

lr_test_pred.select("label", "prediction", "probability").show(5, truncate=False)


LR threshold: 0.2
LR pos_recall: 0.1128456735057984
LR pos_precision: 0.2786343612334802
+-----+----------+----------------------------------------+
|label|prediction|probability                             |
+-----+----------+----------------------------------------+
|0    |0.0       |[0.9406335491392352,0.05936645086076475]|
|0    |0.0       |[0.9346979801585779,0.06530201984142214]|
|0    |0.0       |[0.919577729965166,0.08042227003483404] |
|0    |0.0       |[0.9289042039714294,0.07109579602857063]|
|0    |0.0       |[0.8753602432966225,0.12463975670337746]|
+-----+----------+----------------------------------------+
only showing top 5 rows


In [0]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# AUC
auc = BinaryClassificationEvaluator(labelCol="label").evaluate(lr_test_pred)

# Accuracy
total = lr_test_pred.count()
correct = lr_test_pred.filter(lr_test_pred.label == lr_test_pred.prediction).count()
accuracy = correct / total

# Recall for label=1
actual_ones = lr_test_pred_thr.filter(col("label") == 1).count()
caught_ones = lr_test_pred_thr.filter((col("label") == 1) & (col("pred_thr") == 1)).count()
recall = caught_ones / actual_ones if actual_ones > 0 else 0


print("Logistic Regression Metrics (Test Data)")
print("AUC:", auc)
print("Accuracy:", accuracy)
print("Recall:", recall)


Logistic Regression Metrics (Test Data)
AUC: 0.6340763127521732
Accuracy: 0.8872370550161812
Recall: 0.1128456735057984


In [0]:
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml import Pipeline

rf = RandomForestClassifier(
    featuresCol="features",
    labelCol="label",
    numTrees=30,
    maxDepth=6,
    seed=42
)

# IMPORTANT: use the SAME indexers + assembler you used for LR
rf_pipeline = Pipeline(stages=indexers + [assembler, rf])

rf_model = rf_pipeline.fit(train_raw)

rf_test_pred = rf_model.transform(test_raw)
rf_test_pred_thr = rf_test_pred.withColumn(
    "pred_thr",
    when(vector_to_array(col("probability"))[1] >= THRESHOLD, 1.0).otherwise(0.0)
)

rf_tp = rf_test_pred_thr.filter((col("label") == 1) & (col("pred_thr") == 1)).count()
rf_fn = rf_test_pred_thr.filter((col("label") == 1) & (col("pred_thr") == 0)).count()
rf_fp = rf_test_pred_thr.filter((col("label") == 0) & (col("pred_thr") == 1)).count()

rf_pos_recall = rf_tp / (rf_tp + rf_fn) if (rf_tp + rf_fn) > 0 else 0.0
rf_pos_precision = rf_tp / (rf_tp + rf_fp) if (rf_tp + rf_fp) > 0 else 0.0

print("RF threshold:", THRESHOLD)
print("RF pos_recall:", rf_pos_recall)
print("RF pos_precision:", rf_pos_precision)

rf_test_pred.select("label", "prediction", "probability").show(5, truncate=False)


RF threshold: 0.2
RF pos_recall: 0.07136485280999108
RF pos_precision: 0.3611738148984199
+-----+----------+----------------------------------------+
|label|prediction|probability                             |
+-----+----------+----------------------------------------+
|0    |0.0       |[0.9226309453469873,0.07736905465301275]|
|0    |0.0       |[0.916900004238408,0.08309999576159209] |
|0    |0.0       |[0.8925765472072239,0.10742345279277614]|
|0    |0.0       |[0.9047617670402586,0.0952382329597414] |
|0    |0.0       |[0.8361455345285335,0.16385446547146656]|
+-----+----------+----------------------------------------+
only showing top 5 rows


In [0]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# AUC
rf_auc = BinaryClassificationEvaluator(labelCol="label").evaluate(rf_test_pred)
# Accuracy
rf_total = rf_test_pred.count()
rf_correct = rf_test_pred.filter(rf_test_pred.label == rf_test_pred.prediction).count()
rf_accuracy = rf_correct / rf_total
# Recall (label=1)
rf_actual_ones = rf_test_pred.filter(rf_test_pred.label == 1).count()
rf_caught_ones = rf_test_pred.filter((rf_test_pred.label == 1) & (rf_test_pred.prediction == 1)).count()
rf_recall = rf_caught_ones / rf_actual_ones if rf_actual_ones > 0 else 0
print("Random Forest Metrics (Test Data)")
print("AUC:", rf_auc)
print("Accuracy:", rf_accuracy)
print("Recall:", rf_recall)

Random Forest Metrics (Test Data)
AUC: 0.6353271513166668
Accuracy: 0.8866302588996764
Recall: 0.0008920606601248885


In [0]:
print("MODEL COMPARISON (Test Data)")
print("---------------------------")
print("Logistic Regression - AUC:", auc)
print("Logistic Regression - Accuracy:", accuracy)
print("Logistic Regression - Recall:", recall)
print()
print("Random Forest       - AUC:", rf_auc)
print("Random Forest       - Accuracy:", rf_accuracy)
print("Random Forest       - Recall:", rf_recall)

MODEL COMPARISON (Test Data)
---------------------------
Logistic Regression - AUC: 0.6340763127521732
Logistic Regression - Accuracy: 0.8872370550161812
Logistic Regression - Recall: 0.1128456735057984

Random Forest       - AUC: 0.6353271513166668
Random Forest       - Accuracy: 0.8866302588996764
Random Forest       - Recall: 0.0008920606601248885


In [0]:
import mlflow
import mlflow.spark
import os
os.environ["MLFLOW_DFS_TMP"] = "/Volumes/workspace/default/pujitha_patient-readmission/mlflow_tmp"
input_example = train_raw.select(*(categorical_cols + numeric_cols)).limit(5).toPandas()
dfs_tmp = "/Volumes/workspace/default/pujitha_patient-readmission/mlflow_tmp"
dbutils.fs.mkdirs(dfs_tmp)
with mlflow.start_run(run_name="LogisticRegression") as run:
    mlflow.log_param("model_type", "LogisticRegression")
    mlflow.log_param("max_iter", lr.getMaxIter())
    mlflow.log_param("reg_param", lr.getRegParam())
    mlflow.log_param("elastic_net", lr.getElasticNetParam())
    mlflow.log_param("threshold", THRESHOLD)
    mlflow.log_metric("auc", float(auc))
    mlflow.log_metric("accuracy", float(accuracy))
    mlflow.log_metric("pos_recall", float(lr_pos_recall))
    mlflow.log_metric("pos_precision", float(lr_pos_precision))
    # log model
    mlflow.spark.log_model(
        pipeline_model,
        artifact_path="model",
        input_example=input_example,
        dfs_tmpdir=os.environ["MLFLOW_DFS_TMP"],
        pip_requirements=[
            "pyspark==4.0.0",
            "mlflow>=2.2.0"
        ]
    )
    print("LR run_id:", run.info.run_id)
with mlflow.start_run(run_name="RandomForest") as run:
    mlflow.log_param("model_type", "RandomForest")
    mlflow.log_param("num_trees", rf.getNumTrees)
    mlflow.log_param("max_depth", rf.getMaxDepth())
    mlflow.log_param("threshold", THRESHOLD)
    mlflow.log_metric("auc", float(rf_auc))
    mlflow.log_metric("accuracy", float(rf_accuracy))
    mlflow.log_metric("pos_recall", float(rf_pos_recall))
    mlflow.log_metric("pos_precision", float(rf_pos_precision))
    # log model
    mlflow.spark.log_model(
        rf_model,
        artifact_path="model",
        input_example=input_example,
        dfs_tmpdir=os.environ["MLFLOW_DFS_TMP"],
        pip_requirements=[
            "pyspark==4.0.0",
            "mlflow>=2.2.0"
        ]
    )
    print("RF run_id:", run.info.run_id)
best_model_name = "RandomForest" if rf_pos_recall > lr_pos_recall else "LogisticRegression"
best_model = rf_model if rf_pos_recall > lr_pos_recall else pipeline_model
print("Best model based on Positive Recall:", best_model_name)
best_predictions = best_model.transform(raw_df)   # raw_df is your full dataset


LR run_id: fcdc3a27af464ba29b0bb87b955db5b2
RF run_id: a5a52810936e42d1896adcbdeebe67cc
Best model based on Positive Recall: LogisticRegression


In [0]:
best_predictions.select(
    "label", "prediction", "probability"
).show(10, truncate=False)


+-----+----------+----------------------------------------+
|label|prediction|probability                             |
+-----+----------+----------------------------------------+
|0    |0.0       |[0.91509710493452,0.08490289506547999]  |
|0    |0.0       |[0.9225457850583596,0.0774542149416404] |
|0    |0.0       |[0.9195601718717679,0.08043982812823214]|
|0    |0.0       |[0.9276526080684361,0.07234739193156392]|
|0    |0.0       |[0.9204335756866239,0.07956642431337613]|
|0    |0.0       |[0.9033464230600552,0.09665357693994481]|
|0    |0.0       |[0.8979732133703908,0.10202678662960918]|
|0    |0.0       |[0.8841304481856026,0.1158695518143974] |
|0    |0.0       |[0.9080935280776313,0.09190647192236867]|
|0    |0.0       |[0.9018711443468335,0.09812885565316654]|
+-----+----------+----------------------------------------+
only showing top 10 rows


In [0]:
from pyspark.ml.functions import vector_to_array
from pyspark.sql.functions import col, when

best_predictions = best_predictions.withColumn(
    "readmission_risk",
    vector_to_array(col("probability"))[1]
)

best_predictions = best_predictions.withColumn(
    "risk_bucket",
    when(col("readmission_risk") >= 0.70, "HIGH")
    .when(col("readmission_risk") >= 0.40, "MEDIUM")
    .otherwise("LOW")
)

In [0]:
pred_path = "/Volumes/workspace/default/pujitha_patient-readmission/gold_predictions_best"

best_predictions.select(
    "race","gender","age","admission_type","days_in_hospital",
    "number_outpatient","number_emergency","number_inpatient",
    "num_lab_procedures","num_procedures","num_medications","number_diagnoses",
    "label","prediction","readmission_risk","risk_bucket"
).write.format("delta").mode("overwrite").option("overwriteSchema","true").save(pred_path)

print("Saved best predictions to:", pred_path)

Saved best predictions to: /Volumes/workspace/default/pujitha_patient-readmission/gold_predictions_best


In [0]:
from pyspark.sql.functions import col
from pyspark.ml.functions import vector_to_array

# Convert vector probability -> array, then take index 1 (class=1)
final_predictions = lr_test_pred.withColumn(
    "readmission_risk",
    vector_to_array(col("probability"))[1]
)

final_predictions.select("readmission_risk", "label", "prediction").show(5, truncate=False)

+-------------------+-----+----------+
|readmission_risk   |label|prediction|
+-------------------+-----+----------+
|0.05936645086076475|0    |0.0       |
|0.06530201984142214|0    |0.0       |
|0.08042227003483404|0    |0.0       |
|0.07109579602857063|0    |0.0       |
|0.12463975670337746|0    |0.0       |
+-------------------+-----+----------+
only showing top 5 rows


In [0]:
from pyspark.sql.functions import col, when
from pyspark.ml.functions import vector_to_array

THRESHOLD = 0.2
pred_path = "/Volumes/workspace/default/pujitha_patient-readmission/gold_predictions"

pred_scored = pipeline_model.transform(gold_df).withColumn(
    "risk_score", vector_to_array(col("probability"))[1]
).withColumn(
    "pred_thr", when(col("risk_score") >= THRESHOLD, 1).otherwise(0)
)

pred_out = pred_scored.select(
    "race",
    "gender",
    "age",
    "admission_type",
    col("days_in_hospital").cast("int").alias("days_in_hospital"),
    "risk_score",
    "pred_thr"
)

dbutils.fs.rm(pred_path, True)
pred_out.write.format("delta").mode("overwrite").save(pred_path)

In [0]:
pred_out.groupBy("pred_thr").count().show()
pred_out.orderBy(col("risk_score").desc()).show(5,truncate=False)

+--------+-----+
|pred_thr|count|
+--------+-----+
|       1| 4628|
|       0|93425|
+--------+-----+

+---------+------+-------+--------------+----------------+------------------+--------+
|race     |gender|age    |admission_type|days_in_hospital|risk_score        |pred_thr|
+---------+------+-------+--------------+----------------+------------------+--------+
|Caucasian|Female|[20-30)|Emergency     |3               |0.9550251889352136|1       |
|Caucasian|Female|[20-30)|Emergency     |1               |0.9482106964451944|1       |
|Caucasian|Female|[20-30)|Emergency     |7               |0.932753366093359 |1       |
|Caucasian|Male  |[40-50)|Unknown       |8               |0.9226187728442351|1       |
|Caucasian|Female|[30-40)|Urgent        |14              |0.9108307912991878|1       |
+---------+------+-------+--------------+----------------+------------------+--------+
only showing top 5 rows


In [0]:
pred_out.orderBy(col("risk_score").desc()).show(50, truncate=False)

+---------------+------+-------+--------------+----------------+------------------+--------+
|race           |gender|age    |admission_type|days_in_hospital|risk_score        |pred_thr|
+---------------+------+-------+--------------+----------------+------------------+--------+
|Caucasian      |Female|[20-30)|Emergency     |3               |0.9550251889352136|1       |
|Caucasian      |Female|[20-30)|Emergency     |1               |0.9482106964451944|1       |
|Caucasian      |Female|[20-30)|Emergency     |7               |0.932753366093359 |1       |
|Caucasian      |Male  |[40-50)|Unknown       |8               |0.9226187728442351|1       |
|Caucasian      |Female|[30-40)|Urgent        |14              |0.9108307912991878|1       |
|Caucasian      |Female|[30-40)|Emergency     |1               |0.8836308592994626|1       |
|Caucasian      |Female|[20-30)|Unknown       |5               |0.8730545054716601|1       |
|Caucasian      |Female|[30-40)|Urgent        |3               |0.8724

In [0]:
pred_out.sample(fraction=0.01, seed=42).show(50, truncate=False)

+---------------+------+-------+--------------+----------------+-------------------+--------+
|race           |gender|age    |admission_type|days_in_hospital|risk_score         |pred_thr|
+---------------+------+-------+--------------+----------------+-------------------+--------+
|Caucasian      |Female|[60-70)|Unknown       |1               |0.07299697420897999|0       |
|Caucasian      |Male  |[30-40)|Unknown       |12              |0.1338644210794987 |0       |
|Caucasian      |Female|[80-90)|Unknown       |3               |0.07206325205031472|0       |
|Caucasian      |Female|[40-50)|Urgent        |9               |0.09358459812704656|0       |
|Caucasian      |Female|[60-70)|Unknown       |6               |0.09449935680269983|0       |
|Caucasian      |Female|[70-80)|Emergency     |3               |0.10991040679572828|0       |
|Caucasian      |Female|[80-90)|Emergency     |6               |0.10012435818042953|0       |
|AfricanAmerican|Female|[40-50)|Emergency     |4            

In [0]:
%sql
SELECT * FROM delta.`/Volumes/workspace/default/pujitha_patient-readmission/gold_predictions`;

race,gender,age,admission_type,days_in_hospital,risk_score,pred_thr
Caucasian,Female,[10-20),Emergency,3,0.0849028950654799,0
AfricanAmerican,Female,[20-30),Emergency,2,0.0774542149416404,0
Caucasian,Male,[30-40),Emergency,2,0.0804398281282321,0
Caucasian,Male,[40-50),Emergency,1,0.0723473919315639,0
Caucasian,Male,[50-60),Urgent,3,0.0795664243133761,0
Caucasian,Male,[60-70),Elective,4,0.0966535769399448,0
Caucasian,Male,[70-80),Emergency,5,0.1020267866296091,0
Caucasian,Female,[80-90),Urgent,13,0.1158695518143974,0
Caucasian,Female,[90-100),Elective,12,0.0919064719223686,0
AfricanAmerican,Female,[40-50),Emergency,9,0.0981288556531665,0


In [0]:
pred_out.filter(col("pred_thr") == 1).limit(10) \
    .union(pred_out.filter(col("pred_thr") == 0).limit(10)) \
    .show(20, truncate=False)

+---------------+------+--------+--------------+----------------+-------------------+--------+
|race           |gender|age     |admission_type|days_in_hospital|risk_score         |pred_thr|
+---------------+------+--------+--------------+----------------+-------------------+--------+
|Caucasian      |Female|[70-80) |Emergency     |11              |0.3953779724053218 |1       |
|AfricanAmerican|Female|[40-50) |Emergency     |2               |0.23164990881366931|1       |
|AfricanAmerican|Female|[70-80) |Emergency     |2               |0.28015818272219606|1       |
|Caucasian      |Female|[50-60) |Emergency     |5               |0.22593995819928492|1       |
|AfricanAmerican|Male  |[20-30) |Emergency     |6               |0.2661277728558631 |1       |
|AfricanAmerican|Male  |[20-30) |Emergency     |3               |0.28969776764586697|1       |
|Caucasian      |Male  |[40-50) |Emergency     |7               |0.20153878762384592|1       |
|AfricanAmerican|Female|[80-90) |Emergency     |9 