## Setup (env, Spark, load features)

In [1]:
import os
from dotenv import load_dotenv, find_dotenv
from pyspark.sql import SparkSession

In [2]:
# 1) Env & paths
load_dotenv(find_dotenv(usecwd=True), override=True)
DATA_DIR = os.getenv("DATA_DIR")
assert DATA_DIR and os.path.isdir(DATA_DIR), f"DATA_DIR invalid: {DATA_DIR}"

# 2) Spark (used only to load parquet and convert to pandas)
spark = SparkSession.builder.appName("HealthClaims_TrainXGB").getOrCreate()

# 3) Load features saved by 02_label_features
feat_dir = os.path.abspath(os.path.join(DATA_DIR, "..", "processed", "features_parquet"))
features = spark.read.parquet(feat_dir).cache()
features.createOrReplaceTempView("features_v0")

print("Features rows:", features.count())
features.printSchema()
features.limit(5).show(truncate=False)


Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
25/10/28 16:48:08 WARN Utils: Your hostname, JINUTSA, resolves to a loopback address: 127.0.1.1; using 10.4.8.103 instead (on interface enp37s0f0)
25/10/28 16:48:08 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/10/28 16:48:09 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/10/28 16:48:09 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
25/10/28 16:48:09 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.
25/10/28 16:48:09 WARN Utils: Service 'SparkUI' could not bind on port 4042. Attempting port 4043.
25/10/28 16:48:09 WARN Utils: Ser

Features rows: 1163
root
 |-- patient_id: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- race: string (nullable = true)
 |-- ethnicity: string (nullable = true)
 |-- age_at_index: decimal(13,0) (nullable = true)
 |-- index_date: date (nullable = true)
 |-- last_enc_date: date (nullable = true)
 |-- n_encounters: long (nullable = true)
 |-- n_conditions: long (nullable = true)
 |-- n_procedures: long (nullable = true)
 |-- n_medications: long (nullable = true)
 |-- n_observations: long (nullable = true)
 |-- n_claims: long (nullable = true)
 |-- hist_total_cost: double (nullable = true)
 |-- n_unique_providers: long (nullable = true)
 |-- n_unique_departments: long (nullable = true)
 |-- n_claims_with_diag: long (nullable = true)
 |-- claim_span_days: integer (nullable = true)
 |-- cost_next_window: double (nullable = true)
 |-- label: integer (nullable = true)

+------------------------------------+------+-----+-----------+------------+----------+-------------+----

## Time-based split (index_date) & pandas conversion

In [3]:
import pandas as pd
from pyspark.sql import functions as F


# Time split (mimics production): train on earlier index dates, test on later ones
CUT = "2020-10-15"
train_df = spark.sql(f"SELECT * FROM features_v0 WHERE index_date < DATE('{CUT}')")
test_df  = spark.sql(f"SELECT * FROM features_v0 WHERE index_date >= DATE('{CUT}')")

print("Train rows:", train_df.count(), " Test rows:", test_df.count())

# Columns we’ll use (v0 features)
CAT = ["gender","race","ethnicity"]
NUM = [
    "age_at_index", "n_encounters", "n_conditions", "n_procedures", "n_medications",
    "n_observations", "n_claims", "hist_total_cost",
    "n_unique_providers", "n_unique_departments", "claim_span_days"
]
TARGET = "label"

train_pd = train_df.select(*(CAT + NUM + [TARGET])).toPandas()
test_pd  = test_df.select(*(CAT + NUM + [TARGET])).toPandas()

# Sanity on class balance (train/test)
print("Train label counts:\n", train_pd[TARGET].value_counts(dropna=False))
print("Test  label counts:\n", test_pd[TARGET].value_counts(dropna=False))


Train rows: 914  Test rows: 249
Train label counts:
 label
0    662
1    252
Name: count, dtype: int64
Test  label counts:
 label
0    152
1     97
Name: count, dtype: int64


## One-hot encoding, alignment, and class weight

In [4]:
# Categoricals → category dtype
for c in CAT:
    train_pd[c] = train_pd[c].astype("category")
    test_pd[c]  = test_pd[c].astype("category")

# One-hot (drop_first avoids perfect multicollinearity)
X_train = pd.get_dummies(train_pd[CAT + NUM], drop_first=True)
X_test  = pd.get_dummies(test_pd[CAT + NUM], drop_first=True)

# Align columns (add missing columns to test, in same order as train)
X_test = X_test.reindex(columns=X_train.columns, fill_value=0)

y_train = train_pd[TARGET].astype(int).values
y_test  = test_pd[TARGET].astype(int).values

# Optional: drop any constant columns (rare but possible)
const_cols = [c for c in X_train.columns if X_train[c].nunique() <= 1]
if const_cols:
    X_train = X_train.drop(columns=const_cols)
    X_test  = X_test.drop(columns=const_cols)
    print("Dropped constant columns:", const_cols)

# Compute scale_pos_weight to handle imbalance: (#neg / #pos) on train
import numpy as np
pos = np.sum(y_train == 1)
neg = np.sum(y_train == 0)
scale_pos_weight = (neg / pos) if pos > 0 else 1.0
print(f"scale_pos_weight (neg/pos): {scale_pos_weight:.2f} (neg={neg}, pos={pos})")

# Keep metadata for inference
feature_columns = X_train.columns.tolist()
cat_snapshot = {c: list(train_pd[c].cat.categories) for c in CAT}


scale_pos_weight (neg/pos): 2.63 (neg=662, pos=252)


## LightGBM with MLflow

In [5]:
from lightgbm import LGBMClassifier
from sklearn.model_selection import StratifiedKFold, cross_validate
from sklearn.metrics import (
    roc_auc_score, average_precision_score,
    precision_score, recall_score, f1_score,
    confusion_matrix, ConfusionMatrixDisplay
)
from mlflow.tracking import MlflowClient
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import mlflow
import os
from itertools import product

# -------------------- MLflow setup --------------------
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
mlruns_clean_path = os.path.join(project_root, "mlruns_clean")
os.makedirs(mlruns_clean_path, exist_ok=True)
mlflow.set_tracking_uri(f"file://{mlruns_clean_path}")
print("Tracking URI set to:", mlflow.get_tracking_uri())

experiment_name = "health_claims_highcost_lgbm"
client = MlflowClient()
if not client.get_experiment_by_name(experiment_name):
    client.create_experiment(name=experiment_name)
mlflow.set_experiment(experiment_name)

# -------------------- Cross-validation setup --------------------
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# -------------------- Parameter grid --------------------
param_grid = {
    "num_leaves": [15, 31],
    "learning_rate": [0.01, 0.05],
    "n_estimators": [300, 500],
    "subsample": [0.8],
    "colsample_bytree": [0.8],
}

# -------------------- Run all parameter combinations --------------------
for i, (nl, lr, ne, ss, csbt) in enumerate(product(
    param_grid["num_leaves"],
    param_grid["learning_rate"],
    param_grid["n_estimators"],
    param_grid["subsample"],
    param_grid["colsample_bytree"]
)):
    with mlflow.start_run(run_name=f"lgbm_run_cv_{i}"):
        mlflow.set_tag("mlflow.runName", f"lgbm_run_cv_{i}")
        mlflow.log_params({
            "num_leaves": nl,
            "learning_rate": lr,
            "n_estimators": ne,
            "subsample": ss,
            "colsample_bytree": csbt,
            "scale_pos_weight": scale_pos_weight
        })

        # Define model
        model = LGBMClassifier(
            num_leaves=nl,
            learning_rate=lr,
            n_estimators=ne,
            subsample=ss,
            colsample_bytree=csbt,
            scale_pos_weight=scale_pos_weight,
            objective="binary",
            random_state=42
        )

        # -------------------- Cross-validation with extra metrics --------------------
        scoring = {
            "roc_auc": "roc_auc",
            "pr_auc": "average_precision",
            "precision": "precision",
            "recall": "recall",
            "f1": "f1"
        }

        scores = cross_validate(
            model,
            X_train, y_train,
            scoring=scoring,
            cv=cv,
            return_train_score=False
        )

        for metric in scoring.keys():
            mlflow.log_metric(f"cv_{metric}", scores[f"test_{metric}"].mean())

        # -------------------- Train and evaluate on test set --------------------
        model.fit(X_train, y_train)
        proba_test = model.predict_proba(X_test)[:, 1]
        pred_test = (proba_test >= 0.5).astype(int)

        roc_auc = roc_auc_score(y_test, proba_test)
        pr_auc = average_precision_score(y_test, proba_test)
        precision = precision_score(y_test, pred_test)
        recall = recall_score(y_test, pred_test)
        f1 = f1_score(y_test, pred_test)

        mlflow.log_metric("test_roc_auc", roc_auc)
        mlflow.log_metric("test_pr_auc", pr_auc)
        mlflow.log_metric("test_precision", precision)
        mlflow.log_metric("test_recall", recall)
        mlflow.log_metric("test_f1", f1)

        print(f"\nRun {i} | leaves={nl}, lr={lr}, est={ne}")
        print(f"Precision: {precision:.3f} | Recall: {recall:.3f} | F1: {f1:.3f}")
        print(f"ROC-AUC: {roc_auc:.3f} | PR-AUC: {pr_auc:.3f}")

        # -------------------- Confusion Matrix --------------------
        # fig, ax = plt.subplots(figsize=(4, 4))
        # ConfusionMatrixDisplay.from_predictions(y_test, pred_test, ax=ax)
        # plt.title(f"LGBM Confusion Matrix (Run {i})")
        # cm_path = f"conf_matrix_lgbm_{i}.png"
        # plt.savefig(cm_path)
        # mlflow.log_artifact(cm_path)
        # plt.close()

print("\n All LightGBM runs complete. Check MLflow UI or printouts for best F1/precision.")


Tracking URI set to: file:///home/utsajinlab/health_claims_ml/notebooks/mlruns_clean
[LightGBM] [Info] Number of positive: 202, number of negative: 529
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000258 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 1257
[LightGBM] [Info] Number of data points in the train set: 731, number of used features: 14
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.276334 -> initscore=-0.962721
[LightGBM] [Info] Start training from score -0.962721
[LightGBM] [Info] Number of positive: 202, number of negative: 529
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000183 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 1262
[LightGBM] [Info] Number of data points in the train set: 731, number of used features: 14
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.276334 -> initscore=-0.

In [6]:
from mlflow.tracking import MlflowClient
import pandas as pd

experiment_name = "health_claims_highcost_lgbm"
client = MlflowClient()
experiment = client.get_experiment_by_name(experiment_name)

runs = client.search_runs(
    experiment_ids=[experiment.experiment_id],
    order_by=["metrics.test_f1 DESC"]
)

records = []
for run in runs:
    metrics = run.data.metrics
    params = run.data.params

    record = {
        "run_name": run.info.run_name,
        "num_leaves": params.get("num_leaves"),
        "learning_rate": params.get("learning_rate"),
        "n_estimators": params.get("n_estimators"),
        "subsample": params.get("subsample"),
        "colsample_bytree": params.get("colsample_bytree"),
        "cv_f1": metrics.get("cv_f1"),
        "test_precision": metrics.get("test_precision"),
        "test_recall": metrics.get("test_recall"),
        "test_f1": metrics.get("test_f1"),
        "test_roc_auc": metrics.get("test_roc_auc"),
        "test_pr_auc": metrics.get("test_pr_auc")
    }
    records.append(record)

df_lgbm = pd.DataFrame(records).sort_values(by="test_f1", ascending=False).reset_index(drop=True)
df_lgbm = df_lgbm.round(3)

print("\n===== LightGBM MLflow Results (sorted by test F1) =====")
display(df_lgbm)



===== LightGBM MLflow Results (sorted by test F1) =====


Unnamed: 0,run_name,num_leaves,learning_rate,n_estimators,subsample,colsample_bytree,cv_f1,test_precision,test_recall,test_f1,test_roc_auc,test_pr_auc
0,lgbm_run_cv_1,15,0.01,500,0.8,0.8,0.532,0.588,0.691,0.635,0.758,0.607
1,lgbm_run_cv_0,15,0.01,300,0.8,0.8,0.542,0.589,0.68,0.632,0.767,0.589
2,lgbm_run_cv_4,31,0.01,300,0.8,0.8,0.532,0.606,0.619,0.612,0.763,0.594
3,lgbm_run_cv_6,31,0.05,300,0.8,0.8,0.533,0.651,0.577,0.612,0.736,0.574
4,lgbm_run_cv_5,31,0.01,500,0.8,0.8,0.524,0.62,0.588,0.603,0.752,0.596
5,lgbm_run_cv_2,15,0.05,300,0.8,0.8,0.5,0.615,0.577,0.596,0.74,0.582
6,lgbm_run_cv_7,31,0.05,500,0.8,0.8,0.534,0.631,0.546,0.586,0.735,0.573
7,lgbm_run_cv_3,15,0.05,500,0.8,0.8,0.494,0.624,0.546,0.582,0.732,0.568
