# Decision Curve Analysis

In [None]:
import sys
import warnings

sys.path.append("../")
from src.data_utils import get_data, get_models
from src.config import BASE_PATH

import matplotlib.pyplot as plt
from statkit.decision import NetBenefitDisplay
import pandas as pd

Set global constants: output path, test data, models, and plot colors.

In [None]:
# Output directory for the decision-curve PDFs (one file per outcome).
RESULT_PATH = BASE_PATH / "results" / "figures" / "DCA"

## Data
# Test data per outcome, loaded via get_data. The plotting cell reads each
# entry's "X_test" and "y_test" keys (presumably dicts; confirm in data_utils).
OUTCOME_DICT = {
    "surg": get_data("outcome_surg"),
    "bleed": get_data("outcome_bleed"),
    "asp": get_data("outcome_asp"),
    "mort": get_data("outcome_mort"),
}

## Models
# Models loaded from BASE_PATH / "cal_models" (presumably calibrated) for each
# prefix below, keyed by outcome. get_models' return value is iterated with
# .items() by the plotting cell, and its keys are used as legend names.
model_dir = BASE_PATH / "cal_models"
model_prefix_list = ["lr", "lgbm", "svc", "stack", "nn"]
MODEL_DICT = {
    outcome: get_models(model_prefix_list, outcome, model_dir)
    for outcome in OUTCOME_DICT
}

# Line colors assigned to models in plotting order. Every entry is distinct,
# so no two models share a color while the palette lasts. (The previous list
# repeated "tab:olive"; the duplicate is replaced by the unused "tab:brown".)
COLOR_LIST = [
    "tab:blue",
    "tab:orange",
    "tab:green",
    "tab:purple",
    "tab:red",
    "tab:olive",
    "tab:pink",
    "tab:gray",
    "tab:brown",
]

In [None]:
# Draw one decision-curve figure per outcome, overlaying every model for that
# outcome on the held-out test set, then save it to RESULT_PATH/<outcome>.pdf.

# Reference-curve labels that appear in the legend after the model entries.
BASELINE_LABELS = ["Always act", "Never act", "Oracle"]

for outcome_name, outcome_data in OUTCOME_DICT.items():
    X_test = outcome_data["X_test"]
    # Flatten labels to 1-D so they line up with the 1-D probability vector.
    y_test = outcome_data["y_test"].values.ravel()
    cur_model_dict = MODEL_DICT[outcome_name]
    model_names = list(cur_model_dict)

    fig, ax = plt.subplots(figsize=(10, 8))
    for idx, (model_name, model) in enumerate(cur_model_dict.items()):
        # Wrap around the palette instead of zip(): zip would silently stop
        # plotting once there are more models than colors.
        color = COLOR_LIST[idx % len(COLOR_LIST)]
        # Probability of the positive class (column 1); assumes a binary
        # classifier exposing predict_proba.
        y_proba = model.predict_proba(X_test)[:, 1]

        n_lines_before = len(ax.get_lines())
        NetBenefitDisplay.from_predictions(
            y_true=y_test,
            y_pred=y_proba,
            name=model_name,
            ax=ax,
        )
        # Recolor only the line(s) this call just added under the model's
        # name; any other lines it added keep their own styling.
        for line in list(ax.get_lines())[n_lines_before:]:
            if line.get_label() == model_name:
                line.set_color(color)

    # Each call may add the same baseline labels again; building a dict keeps
    # a single handle per label (the last one drawn).
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))

    # Models first (in model-dict order), then baselines; skip absent labels.
    ordered_labels = [
        label for label in model_names + BASELINE_LABELS if label in by_label
    ]
    ax.legend(
        [by_label[label] for label in ordered_labels],
        ordered_labels,
        loc="upper right",
    )

    ax.set_title(f"DCA: Model Comparison for {outcome_name}")
    # NOTE(review): fixed y-range shared by all outcomes; curves whose net
    # benefit exceeds 0.125 will be clipped — confirm for each outcome.
    ax.set_ylim(-0.05, 0.125)

    save_path = RESULT_PATH / f"{outcome_name}.pdf"
    if save_path.exists():
        warnings.warn(
            f"Over-writing DCA for outcome {outcome_name} at path {save_path}"
        )
    save_path.parent.mkdir(exist_ok=True, parents=True)
    fig.savefig(save_path, bbox_inches="tight")
    plt.show()
    # Release the figure so re-running the cell doesn't accumulate figures.
    plt.close(fig)