In [52]:
import seaborn as sns
import numpy as np
import mlflow
import mlflow
import mlflow.xgboost
import shap
import matplotlib.pyplot as plt
from sklearn.metrics import (
    roc_auc_score, average_precision_score,
    f1_score, recall_score, brier_score_loss,
    confusion_matrix, RocCurveDisplay,
    PrecisionRecallDisplay
)

In [54]:
# MLflow connection settings for this notebook.
TRACKING_URI = "http://localhost:5001"
EXPERIMENT_NAME = "brfss_heart_attack_risk"

mlflow.set_tracking_uri(TRACKING_URI)
# Left as the last expression so the notebook displays the selected
# Experiment object (MLflow creates the experiment when it does not exist).
mlflow.set_experiment(EXPERIMENT_NAME)

2025/05/09 19:18:03 INFO mlflow.tracking.fluent: Experiment with name 'brfss_heart_attack_risk' does not exist. Creating a new experiment.


<Experiment: artifact_location='/Users/rev/IUB/Projects/HeartAttackRiskPrediction/mlruns/1', creation_time=1746832683767, experiment_id='1', last_update_time=1746832683767, lifecycle_stage='active', name='brfss_heart_attack_risk', tags={}>

In [56]:
from pathlib import Path
import joblib

# Artifact directory; the relative path is resolved against the kernel's
# current working directory.
model_dir = Path("..") / "models"

# Load, in order: the fitted pipeline, then the feature matrix and labels
# that the logging cell below evaluates it on.
xgb_model, X_test, y_test = (
    joblib.load(model_dir / artifact_name)
    for artifact_name in (
        "xgb_top25_shap.joblib",
        "X_test25.joblib",
        "y_test25.joblib",
    )
)


In [58]:
import os

# Print the kernel's working directory; the relative paths used in this
# notebook (e.g. "../models") are resolved against it.
working_dir = os.getcwd()
print(working_dir)


/Users/rev/IUB/Projects/HeartAttackRiskPrediction/notebooks


In [92]:
from datetime import datetime
from itertools import chain

import mlflow.sklearn
from mlflow.models.signature import infer_signature


def _log_current_figure(filename, artifact_path="plots"):
    """Save the active matplotlib figure, log it to the active MLflow run, close it.

    Args:
        filename: Name of the PNG written to the current working directory.
        artifact_path: Run-relative artifact directory the file is logged under.
    """
    plt.savefig(filename)
    mlflow.log_artifact(filename, artifact_path=artifact_path)
    plt.close()


# --- Params & metrics -------------------------------------------------------
params = xgb_model.get_params()

# Probability of the positive class, thresholded at 0.5 for the label metrics.
y_proba = xgb_model.predict_proba(X_test)[:, 1]
y_pred = (y_proba >= 0.5).astype(int)

metrics = {
    "auroc": roc_auc_score(y_test, y_proba),
    "auprc": average_precision_score(y_test, y_proba),
    "f1": f1_score(y_test, y_pred),
    "recall": recall_score(y_test, y_pred),
    "brier_score": brier_score_loss(y_test, y_proba),
}

# --- SHAP values on the preprocessed inputs ---------------------------------
# The pipeline exposes its preprocessing step as "pre" and its estimator as
# "clf"; TreeExplainer is given the bare estimator and the transformed matrix.
raw_xgb = xgb_model.named_steps["clf"]
preprocessor = xgb_model.named_steps["pre"]
X_transformed = preprocessor.transform(X_test)

# Column lists of the first and second fitted transformers.
# NOTE(review): this relies on the categorical transformer being registered
# first and the numerical one second - confirm against the pipeline definition.
cat_cols = preprocessor.transformers_[0][2]
num_cols = preprocessor.transformers_[1][2]

# Expanded names of the encoded categorical columns.
ohe = preprocessor.named_transformers_["cat"]
ohe_feature_names = ohe.get_feature_names_out(cat_cols)

# Feature names in transformed-column order: encoded categoricals, then numerics.
feature_names = list(chain(ohe_feature_names, num_cols))

explainer = shap.TreeExplainer(raw_xgb)
shap_values = explainer.shap_values(X_transformed)

# Drop excluded features from the SHAP summary. This is pure computation, so it
# is done before the run is opened.
# NOTE(review): matching is by exact transformed name; if `prev_chd_or_mi` is
# one-hot encoded its columns would carry a suffix and would NOT be excluded -
# confirm.
exclude_features = ["prev_chd_or_mi"]
filtered_indices = [
    i for i, name in enumerate(feature_names) if name not in exclude_features
]
filtered_shap_values = shap_values[:, filtered_indices]
filtered_feature_names = [feature_names[i] for i in filtered_indices]

# --- Log everything in one run ----------------------------------------------
run_name = f"xgb_top25_shap__{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"

# `run` stays bound after the block; the registration cell below reads
# `run.info.run_id`.
with mlflow.start_run(run_name=run_name) as run:
    # Tags are set once here (the original wrote `model_status` twice).
    mlflow.set_tag("model_status", "best_candidate")
    mlflow.set_tag("note", "final filtered SHAP, log model, good CM plot")
    mlflow.log_params(params)
    mlflow.log_metrics(metrics)

    # SHAP summary plot (filtered features only)
    shap.summary_plot(filtered_shap_values, X_transformed[:, filtered_indices], feature_names=filtered_feature_names, show=False)
    plt.tight_layout()
    _log_current_figure("shap_summary_filtered.png")

    # ROC curve
    RocCurveDisplay.from_predictions(y_test, y_proba)
    plt.title("ROC Curve")
    _log_current_figure("roc_curve.png")

    # PR curve
    PrecisionRecallDisplay.from_predictions(y_test, y_proba)
    plt.title("Precision-Recall Curve")
    _log_current_figure("pr_curve.png")

    # Confusion matrix, annotated with the raw counts
    cm = confusion_matrix(y_test, y_pred)
    labels = np.array([[f"{value}" for value in row] for row in cm])

    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=labels, fmt='', cmap="viridis", cbar=True, square=True,
                xticklabels=["Predicted 0", "Predicted 1"],
                yticklabels=["Actual 0", "Actual 1"],
                annot_kws={"fontsize": 12, "color": "orange"})

    plt.title("XGBoost - Confusion Matrix")
    plt.xlabel("Predicted label")
    plt.ylabel("True label")
    plt.tight_layout()
    _log_current_figure("confusion_matrix_annotated.png")

    # NOTE(review): the signature's output schema is inferred from the
    # positive-class probabilities; confirm this matches what the logged
    # model's pyfunc `predict` returns for downstream consumers.
    signature = infer_signature(X_test, y_proba)

    # Log the full pipeline and register it as a new version of the model.
    mlflow.sklearn.log_model(
        sk_model=xgb_model,
        artifact_path="model",
        registered_model_name="HeartAttackRiskModel",
        signature=signature,
    )

Registered model 'HeartAttackRiskModel' already exists. Creating a new version of this model...
2025/05/10 21:38:14 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: HeartAttackRiskModel, version 3


🏃 View run xgb_top25_shap__2025-05-10_21-37-57 at: http://localhost:5001/#/experiments/1/runs/c7f68751b1c34adab829067a5be5d756
🧪 View experiment at: http://localhost:5001/#/experiments/1


Created version '3' of model 'HeartAttackRiskModel'.


In [62]:
from mlflow import register_model

# `run` is the run object bound by `with mlflow.start_run(...) as run` in the
# logging cell; its "model" artifact is what gets registered here.
model_uri = f"runs:/{run.info.run_id}/model"

# Keep the returned ModelVersion: the metadata-export cell at the end of the
# notebook reads `result.version`, which was previously never assigned.
# NOTE(review): the logging cell already passes `registered_model_name` to
# `log_model`, so this call creates an additional version from the same run -
# confirm that is intended.
result = register_model(model_uri=model_uri, name="HeartAttackRiskModel")

# Last expression, so the notebook still displays the ModelVersion.
result

Successfully registered model 'HeartAttackRiskModel'.
2025/05/09 19:18:38 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: HeartAttackRiskModel, version 1
Created version '1' of model 'HeartAttackRiskModel'.


<ModelVersion: aliases=[], creation_timestamp=1746832718446, current_stage='None', description='', last_updated_timestamp=1746832718446, name='HeartAttackRiskModel', run_id='d3e5ff757a1b444693264a92ae732793', run_link='', source='/Users/rev/IUB/Projects/HeartAttackRiskPrediction/mlruns/1/d3e5ff757a1b444693264a92ae732793/artifacts/model', status='READY', status_message=None, tags={}, user_id='', version='1'>

## When we find a better model

In [None]:
# from mlflow.tracking import MlflowClient
# client = MlflowClient("http://localhost:5001")

# client.transition_model_version_stage(
#     name="HeartAttackRiskModel",
#     version="2",                 # version that shows the schema
#     stage="Production",
#     archive_existing_versions=True
# )

In [None]:
import json
import os

# File that receives the model version as JSON.
metadata_path = "/opt/airflow/out/notebook_output_metadata.json"

# `result` is expected to be the ModelVersion object produced by the
# model-registration step earlier in the notebook.
model_version = result.version
metadata = {"model_version": model_version}

# Create the output directory on demand, then write the JSON payload.
os.makedirs(os.path.dirname(metadata_path), exist_ok=True)
with open(metadata_path, "w") as metadata_file:
    metadata_file.write(json.dumps(metadata))

print(f"Registered MLflow model version {model_version} written to metadata.")