In [0]:
import mlflow
from pyspark.sql import functions as F

In [0]:
# MLflow configuration for this notebook.
# TODO(review): the experiment path is tied to one user's workspace folder;
# consider deriving it from the current user or a job parameter.
EXPERIMENT_PATH = "/Users/evansavo@gmail.com/phr_risk_exp"
REGISTRY_URI = "databricks-uc"

mlflow.set_experiment(EXPERIMENT_PATH)

# Register models in Unity Catalog instead of the workspace model registry.
mlflow.set_registry_uri(REGISTRY_URI)

# Data Preprocessing

In [0]:
# Load the gold-layer readmission feature table (schema segment is backtick-quoted).
FEATURE_TABLE = 'phr.`03_gold`.ml_readmission_features'
ml_df = spark.read.table(FEATURE_TABLE)

In [0]:
# Replace nulls with 0. An integer fill value is applied to numeric columns only;
# columns of other types are left untouched.
ml_df = ml_df.na.fill(0)

In [0]:
# Cast demographic codes and flag columns to int.
# The earlier fillna(0) only fills numeric columns. Casting can also yield null
# (e.g. from null or non-numeric strings), so nulls are re-filled with 0 for
# exactly these columns after the cast.
INT_COLUMNS = [
    "BENE_SEX_IDENT_CD",
    "BENE_RACE_CD",
    "SP_ALZHDMTA",
    "SP_CHF",
    "SP_CHRNKIDN",
    "SP_CNCR",
    "SP_COPD",
    "SP_DEPRESSN",
    "SP_DIABETES",
    "SP_ISCHMCHT",
    "had_er_visit_7days_prior",
]

for col_name in INT_COLUMNS:
    # withColumn on an existing name replaces that column in place.
    ml_df = ml_df.withColumn(col_name, F.col(col_name).cast("int"))

ml_df = ml_df.fillna(0, subset=INT_COLUMNS)

In [0]:
# SPLIT DATA at patient level: the split is drawn over distinct DESYNPUF_IDs,
# so every claim of a given patient lands on the same side of the split.
patients = ml_df.select("DESYNPUF_ID").distinct()

train_ids, test_ids = patients.randomSplit([0.8, 0.2], seed=42)

# Identifier / code columns that are not used as model features.
DROP_COLS = ["CLM_ID", "DESYNPUF_ID", "SP_STATE_CODE", "CLM_DRG_CD"]

# Attach each claim to its split via the patient id, then drop non-feature columns.
df_train, df_test = (
    ml_df.join(split_ids, on="DESYNPUF_ID", how="inner").drop(*DROP_COLS)
    for split_ids in (train_ids, test_ids)
)

In [0]:
%skip
# NOTE(review): this cell begins with `%skip` (presumably the Databricks magic
# that skips execution). It sketches smoothed target encoding for CLM_DRG_CD.
# However, the SPLIT DATA cell already drops CLM_DRG_CD from df_train/df_test,
# so re-enabling this cell as-is would fail on the missing column. The encoding
# would have to run before that drop.
# Encoding statistics are computed from df_train only and then joined onto
# df_test, so no test labels feed the encoding.
# # fix CLM_DRG_CD [train]
# global_mean = df_train.agg({"target": "mean"}).first()[0]

# drg_stats = (
#     df_train
#     .groupBy("CLM_DRG_CD")
#     .agg(
#         F.mean("target").alias("drg_mean"),
#         F.count("*").alias("drg_count")
#     )
# )


# SMOOTHING = 20  # tune: 10–100 typical

# Smoothed encoding: (drg_count * drg_mean + SMOOTHING * global_mean) / (drg_count + SMOOTHING),
# i.e. rare DRGs are pulled toward the global target mean.
# drg_stats = drg_stats.withColumn(
#     "drg_target_enc",
#     (
#         F.col("drg_count") * F.col("drg_mean") +
#         SMOOTHING * F.lit(global_mean)
#     ) / (F.col("drg_count") + SMOOTHING)
# )


# df_train_enc = (
#     df_train
#     .join(
#         drg_stats.select("CLM_DRG_CD", "drg_target_enc"),
#         on="CLM_DRG_CD",
#         how="left"
#     )
# )

# df_test_enc = (
#     df_test
#     .join(
#         drg_stats.select("CLM_DRG_CD", "drg_target_enc"),
#         on="CLM_DRG_CD",
#         how="left"
#     )
#     .fillna({"drg_target_enc": global_mean})  # unseen DRGs
# )


# df_train_enc = df_train_enc.drop("CLM_DRG_CD")
# df_test_enc = df_test_enc.drop("CLM_DRG_CD")

In [0]:
# Collect both Spark splits to the driver as pandas DataFrames.
df_train_pd = df_train.toPandas()
df_test_pd = df_test.toPandas()

# Separate the label column from the feature matrix for each split.
LABEL_COL = 'target'
y_train, X_train = df_train_pd[LABEL_COL], df_train_pd.drop(columns=[LABEL_COL])
y_test, X_test = df_test_pd[LABEL_COL], df_test_pd.drop(columns=[LABEL_COL])

# XGBoost

In [0]:
import xgboost as xgb
import mlflow
import mlflow.pyfunc
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
    confusion_matrix,
    classification_report,
    roc_curve,
    precision_recall_curve,
)
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import warnings

# 1. Configuration & Setup
# ---------------------------------------------------------
# Suppress warnings for cleaner logs
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# Use non-interactive backend for plots (prevents crashes in Jobs)
plt.switch_backend("Agg")

# Set MLflow registry to Unity Catalog (repeated so this cell is self-contained)
mlflow.set_registry_uri("databricks-uc")
UC_MODEL_NAME = "phr.03_gold.xgb_readmission_model"

# 2. Data Preparation
# ---------------------------------------------------------
# Requires X_train, X_test, y_train, y_test from the pandas-conversion cell.

# Class balance of the training labels; scale_pos_weight = #negatives / #positives.
n_positive = int(y_train.sum())
n_negative = len(y_train) - n_positive
if n_positive == 0 or n_negative == 0:
    # A single-class training set would make the ratio 0, inf or NaN and the
    # downstream binary metrics meaningless, so fail loudly here instead.
    raise ValueError(
        f"Training labels contain a single class (positives={n_positive}, "
        f"negatives={n_negative}); cannot train a binary classifier."
    )
scale_pos_weight = n_negative / n_positive

print("📊 Class Distribution Analysis:")
print(f"  • Imbalance ratio: 1:{n_negative/n_positive:.1f}")
print(f"  • Scale pos weight: {scale_pos_weight:.2f}\n")

# Cast to float32 for XGBoost stability
X_train_float = X_train.astype("float32")
X_test_float = X_test.astype("float32")


# 3. Define Custom Wrapper for Optimal Threshold
# ---------------------------------------------------------
class ThresholdXGBModel(mlflow.pyfunc.PythonModel):
    """
    MLflow pyfunc wrapper that converts classifier probabilities into 0/1
    labels using a tuned decision threshold rather than the default 0.5.

    Args:
        model: Fitted estimator exposing ``predict_proba`` (an
            ``xgb.XGBClassifier`` in this notebook); column 1 is treated as
            the positive class.
        threshold: Probability cut-off; rows whose positive-class probability
            is >= threshold are labelled 1, all others 0.

    NOTE: the generic pyfunc interface returned by ``mlflow.pyfunc.load_model``
    calls ``predict`` only. ``predict_proba`` must be called on the Python
    model object itself (e.g. via ``unwrap_python_model()`` in recent MLflow).
    """

    def __init__(self, model, threshold):
        self.model = model
        self.threshold = threshold

    def predict(self, context, model_input):
        """Return integer labels: 1 where P(positive) >= self.threshold, else 0."""
        positive_proba = self.predict_proba(context, model_input)[:, 1]
        return (positive_proba >= self.threshold).astype(int)

    def predict_proba(self, context, model_input):
        """Return the wrapped model's ``predict_proba`` output unchanged (one column per class)."""
        return self.model.predict_proba(model_input)


# 4. Training & Logging
# ---------------------------------------------------------
def _f1_optimal_threshold(y_true, y_score):
    """Pick the decision threshold that maximises F1.

    Args:
        y_true: Binary ground-truth labels.
        y_score: Predicted positive-class probabilities.

    Returns:
        (threshold, precisions, recalls): the F1-maximising cut-off as a Python
        float, plus the precision/recall arrays from ``precision_recall_curve``
        (reused for plotting).
    """
    precisions, recalls, thresholds = precision_recall_curve(y_true, y_score)
    # F1 = 2PR / (P + R); write 0 where P + R == 0 instead of dividing by zero.
    denominator = precisions + recalls
    f1_scores = np.divide(
        2 * (precisions * recalls),
        denominator,
        out=np.zeros_like(denominator),
        where=denominator != 0,
    )
    best_idx = int(np.argmax(f1_scores))
    # precisions/recalls carry one more point than thresholds (the final
    # precision=1, recall=0 point), which has no threshold of its own.
    threshold = thresholds[best_idx] if best_idx < len(thresholds) else 0.5
    return float(threshold), precisions, recalls


with mlflow.start_run(run_name="xgb_readmission_balanced") as run:

    # --- Hyperparameters ---
    params = {
        "max_depth": 6,
        "learning_rate": 0.05,
        "n_estimators": 200,
        "objective": "binary:logistic",
        "eval_metric": "aucpr",
        "scale_pos_weight": scale_pos_weight,  # Handles imbalance
        "subsample": 0.8,
        "colsample_bytree": 0.8,
        "min_child_weight": 1,
        "gamma": 0.1,
        "random_state": 42,
        "tree_method": "hist",
    }

    mlflow.log_params(params)
    mlflow.log_param("class_imbalance_ratio", f"1:{n_negative/n_positive:.1f}")

    # --- Train ---
    print("🔄 Training XGBoost model...")
    model = xgb.XGBClassifier(**params)
    # eval_set only records the aucpr curve: no early stopping is configured,
    # so the test set does not influence the fitted trees.
    model.fit(X_train_float, y_train, eval_set=[(X_test_float, y_test)], verbose=False)

    # --- Threshold Optimization ---
    y_pred_proba = model.predict_proba(X_test_float)[:, 1]
    # NOTE(review): the threshold is tuned on the same test set the metrics below
    # are reported on, so those metrics are optimistically biased. Tuning on a
    # separate validation split would give an unbiased estimate.
    optimal_threshold, precisions, recalls = _f1_optimal_threshold(y_test, y_pred_proba)

    # Generate predictions using optimal threshold
    y_pred_optimal = (y_pred_proba >= optimal_threshold).astype(int)

    # --- Metrics ---
    metrics = {
        "accuracy": accuracy_score(y_test, y_pred_optimal),
        "precision": precision_score(y_test, y_pred_optimal),
        "recall": recall_score(y_test, y_pred_optimal),
        "f1_score": f1_score(y_test, y_pred_optimal),
        "roc_auc": roc_auc_score(y_test, y_pred_proba),
        "optimal_threshold": optimal_threshold,
    }
    mlflow.log_metrics(metrics)

    # --- Plots ---
    fig, axes = plt.subplots(2, 2, figsize=(10, 8))
    ax_roc, ax_pr, ax_cm, ax_imp = axes.ravel()

    # ROC
    fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
    ax_roc.plot(fpr, tpr, label=f"AUC = {metrics['roc_auc']:.3f}")
    ax_roc.plot([0, 1], [0, 1], "k--")
    ax_roc.set(title="ROC Curve", xlabel="False positive rate", ylabel="True positive rate")
    ax_roc.legend()

    # PR curve, with the recall reached at the chosen threshold marked
    ax_pr.plot(recalls, precisions)
    ax_pr.axvline(metrics["recall"], color="r", linestyle="--", label="Optimal")
    ax_pr.set(title="PR Curve", xlabel="Recall", ylabel="Precision")
    ax_pr.legend()

    # Confusion matrix (rows = actual, columns = predicted), annotated with counts
    cm = confusion_matrix(y_test, y_pred_optimal)
    ax_cm.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
    for (row, col), count in np.ndenumerate(cm):
        ax_cm.text(
            col, row, str(count), ha="center", va="center",
            color="white" if count > cm.max() / 2 else "black",
        )
    ax_cm.set_xticks(range(cm.shape[1]))
    ax_cm.set_yticks(range(cm.shape[0]))
    ax_cm.set(
        title=f"Conf. Matrix (Thresh={optimal_threshold:.2f})",
        xlabel="Predicted",
        ylabel="Actual",
    )

    # Feature Importance (top 10 by XGBoost's feature_importances_)
    importance = model.feature_importances_
    feats = (
        list(X_train.columns)
        if hasattr(X_train, "columns")
        else [f"Feature {i}" for i in range(len(importance))]
    )
    top_idx = np.argsort(importance)[-10:]
    ax_imp.barh(range(len(top_idx)), importance[top_idx])
    ax_imp.set_yticks(range(len(top_idx)))
    ax_imp.set_yticklabels([feats[i] for i in top_idx])
    ax_imp.set_title("Top 10 Features")

    fig.tight_layout()
    mlflow.log_figure(fig, "model_evaluation.png")
    plt.close(fig)

    # --- Logging Artifacts ---

    # 1. Log the text report
    report = classification_report(y_test, y_pred_optimal)
    mlflow.log_text(report, "classification_report.txt")

    # 2. Log the Model (Wrapped with Threshold)
    wrapped_model = ThresholdXGBModel(model, float(optimal_threshold))

    # Infer the signature from what the served wrapper actually returns
    # (thresholded integer labels), not from the bare classifier.
    signature = mlflow.models.infer_signature(
        X_train_float[:5],
        wrapped_model.predict(None, X_train_float[:5]),
    )

    # Log the PyFunc model to Unity Catalog. Dependencies are pinned to the
    # training environment's versions: the XGBClassifier is pickled inside the
    # wrapper, and XGBoost does not guarantee pickles load across versions.
    model_info = mlflow.pyfunc.log_model(
        artifact_path="model",
        python_model=wrapped_model,
        registered_model_name=UC_MODEL_NAME,
        signature=signature,
        input_example=X_train_float[:5],
        pip_requirements=[
            f"xgboost=={xgb.__version__}",
            f"numpy=={np.__version__}",
            f"pandas=={pd.__version__}",
        ],
    )

    print(f"✓ Run ID: {run.info.run_id}")
    print(f"✓ Model logged to: {UC_MODEL_NAME}")
    print(f"✓ Optimal Threshold baked into model: {optimal_threshold:.3f}")
    print("  (Inference will now automatically use this threshold)")
    print("\n📊 Final Performance:")
    print(f"  • F1 Score: {metrics['f1_score']:.4f}")
    print(f"  • Recall:   {metrics['recall']:.4f}")

# 5. Model Management (Set Alias)
# ---------------------------------------------------------
client = mlflow.tracking.MlflowClient()

# Promote the version registered by the log_model call in the training cell.
# search_model_versions gives no ordering guarantee, so taking [0] could pick
# an older version. Prefer the version recorded on model_info, and fall back
# to the highest version number if that attribute is unavailable.
latest_version = getattr(model_info, "registered_model_version", None)
if latest_version is None:
    versions = client.search_model_versions(f"name='{UC_MODEL_NAME}'")
    latest_version = max(versions, key=lambda mv: int(mv.version)).version

# NOTE(review): the alias is moved unconditionally; the new version is not
# compared against the current "champion" before promotion.
client.set_registered_model_alias(
    name=UC_MODEL_NAME, alias="champion", version=latest_version
)

client.update_model_version(
    name=UC_MODEL_NAME,
    version=latest_version,
    description=f"Auto-threshold wrapper (Thresh={optimal_threshold:.3f}). F1={metrics['f1_score']:.4f}",
)

print(f"✓ Version {latest_version} set as 'champion'")