In [None]:
import pandas as pd
from aequitas.group import Group
from aequitas.bias import Bias
from aequitas.fairness import Fairness
from aequitas.plotting import Plot
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

# Output locations and export settings shared by every figure-saving cell.
FIGDST = Path("../figures_out/temp_trash")
SAVE_FMT = {"format": "tiff", "dpi": 300}
AEQDST = FIGDST / "aequitas"
# savefig does not create missing directories, so make sure the target folder
# exists before any figure is written (e.g. on a fresh checkout).
AEQDST.mkdir(parents=True, exist_ok=True)
plt.style.use("seaborn-v0_8")

In [None]:
# NOTE(review): COHORT_DATA_FILE is not assigned in any cell visible in this
# notebook -- confirm it is defined before this cell runs on a fresh kernel.
main_cohort = pd.read_parquet(COHORT_DATA_FILE)

# Keep only rows where all three model outputs are present and the Ultromics
# classification is anything other than "Uncertain".
required_scores = ["mayo_score", "ultromics_prediction", "echonet_prediction"]
mask_both = main_cohort[required_scores].notna().all(axis=1) & main_cohort[
    "ultromics_classification"
].ne("Uncertain")
echonet_matched_cohort = main_cohort.loc[mask_both]

## Setup

In [None]:
def _to_aequitas_frame(cohort, score_col, model_id):
    """Return a renamed copy of ``cohort`` for a single model.

    ``score_col`` becomes ``score`` and ``true_label`` becomes ``label_value``
    (the column names the later Aequitas cells select), and a constant
    ``model_id`` column is added. ``cohort`` itself is left untouched.
    """
    return cohort.rename(
        columns={score_col: "score", "true_label": "label_value"}
    ).assign(model_id=model_id)


# One frame per model, each built from the same matched cohort. Previously all
# three renames were applied to the EchoNet frame, which left the Mayo and
# EchoGo frames without "score"/"label_value" and gave the EchoNet frame
# several columns named "score".
aequitas_cohort__echonet = _to_aequitas_frame(
    echonet_matched_cohort, "echonet_prediction", "EchoNet-LVH"
)
aequitas_cohort__mayo = _to_aequitas_frame(
    echonet_matched_cohort, "mayo_score", "Mayo ATTR-CM Score"
)
aequitas_cohort__ult = _to_aequitas_frame(
    echonet_matched_cohort, "ultromics_prediction", "EchoGo Amyloidosis"
)

# Stack the three per-model frames into one long frame (one row per
# patient/model pair).
aequitas_cohort = pd.concat(
    [aequitas_cohort__mayo, aequitas_cohort__echonet, aequitas_cohort__ult]
)

# Bin SDI_score into four bands labelled 0-3. Values that are missing or fall
# outside [0, 100] come back as NaN from pd.cut.
sdi_band = pd.cut(
    aequitas_cohort.SDI_score,
    [0, 25, 50, 75, 100],
    include_lowest=True,
    ordered=False,
    labels=[0, 1, 2, 3],
)

# Missingness has to be tested on the categorical *before* the string cast:
# astype(str) turns NaN into the literal "nan", which notna() would accept,
# so the SDI part of the filter previously never removed any row.
has_all_attributes = (
    aequitas_cohort.Sex.notna() & sdi_band.notna() & aequitas_cohort.Race.notna()
)

aequitas_cohort["SDI"] = sdi_band.astype(str)
aequitas_cohort = aequitas_cohort.loc[has_all_attributes]

In [None]:
# Cross-tabulate the stacked cohort by Sex and Race at the three score
# thresholds listed below.
group = Group()
xtab_raw, temp = group.get_multimodel_crosstabs(
    aequitas_cohort,
    attr_cols=["Sex", "Race"],
    score_thresholds={"score_val": [6, 0.8, 0.06]},
)

# xtab_raw carries (model_id, score_threshold) pairs that are not needed;
# keep only the rows where a model is paired with its own threshold label.
model_thresholds = [
    ("Mayo ATTR-CM Score", "6_val"),
    ("EchoNet-LVH", "0.8_val"),
    ("EchoGo Amyloidosis", "0.06_val"),
]
wanted_rows = np.logical_or.reduce(
    [
        (
            (xtab_raw.model_id == model_name)
            & (xtab_raw.score_threshold == threshold_label)
        ).to_numpy()
        for model_name, threshold_label in model_thresholds
    ]
)
xtab = xtab_raw.loc[wanted_rows, :]

## Absolute Metrics

In [None]:
# Single Aequitas Plot object, reused by every plotting cell that follows.
aequitas_plot = Plot()

In [None]:
# Group metrics drawn for each model in the absolute-metrics figure.
absolute_metrics = [
    "pprev",
    "ppr",
    "fdr",
    "for",
    "fpr",
    "fnr",
    "tpr",
    "tnr",
    "npv",
    "precision",
]

# One figure per model, saved under AEQDST with the model name in the filename.
for model_id in ("Mayo ATTR-CM Score", "EchoNet-LVH", "EchoGo Amyloidosis"):
    model_xtab = xtab[xtab.model_id == model_id]
    temp_fig = aequitas_plot.plot_group_metric_all(
        model_xtab, metrics=list(absolute_metrics)
    )
    temp_fig.tight_layout()
    out_path = AEQDST / f"{model_id}_absolute_metrics_full.{SAVE_FMT['format']}"
    temp_fig.savefig(out_path, **SAVE_FMT)

## Bias

In [None]:
# Reference group per attribute against which disparities are computed.
reference_groups = {
    "Sex": "male",
    "Race": "White",
}
# Columns of the stacked cohort handed to Aequitas as the original data.
bias_input_cols = ["score", "label_value", "Sex", "Race", "model_id"]

bias = Bias()
bias_df = bias.get_disparity_predefined_groups(
    xtab,
    original_df=aequitas_cohort[bias_input_cols],
    ref_groups_dict=reference_groups,
    alpha=0.05,
    check_significance=True,
    mask_significance=True,
)

## Fairness

In [None]:
# Fairness evaluation with the library's default settings, applied to the
# disparity table from the Bias section.
fairness = Fairness()
fairness_df = fairness.get_group_value_fairness(bias_df)

In [None]:
# One all-metrics fairness figure per model, five panels per row.
for model_id in ("Mayo ATTR-CM Score", "EchoNet-LVH", "EchoGo Amyloidosis"):
    model_fairness = fairness_df[fairness_df.model_id == model_id]
    temp_fig = aequitas_plot.plot_fairness_group_all(
        model_fairness, metrics="all", ncols=5
    )
    temp_fig.tight_layout()
    out_path = AEQDST / f"{model_id}_fairness.{SAVE_FMT['format']}"
    temp_fig.savefig(out_path, **SAVE_FMT)

### Demographic Parity

In [None]:
def _pprev_parity_rule(tau):
    """Return the pass/fail test applied to each disparity value.

    ``tau`` is accepted to match the ``fair_eval`` call signature and is
    ignored. The returned function yields NaN for a NaN disparity, True when
    the disparity is at least 0.8 (no upper bound), and False otherwise.
    """

    def evaluate(x):
        if np.isnan(x):
            return np.nan
        return bool(0.8 <= x <= np.inf)

    return evaluate


# The fairness table does not depend on the model being plotted, so it is
# computed once here rather than once per loop iteration.
fairness_pprev = Fairness(fair_eval=_pprev_parity_rule)
fairness_df_pprev = fairness_pprev.get_group_value_fairness(bias_df)

for model_id in ["Mayo ATTR-CM Score", "EchoNet-LVH", "EchoGo Amyloidosis"]:
    temp_fig = aequitas_plot.plot_fairness_disparity_all(
        fairness_df_pprev[fairness_df_pprev.model_id == model_id],
        metrics=["pprev"],
        show_figure=False,
    )
    # Collapse doubled asterisks in the annotations on the second axes
    # (presumably significance markers -- confirm) and enlarge the text.
    for text in temp_fig.axes[1].texts:
        text.set_text(text.get_text().replace("**", "*"))
        text.set_fontsize(22)
    temp_fig.savefig(
        AEQDST / f"{model_id}_pprev_disparity.{SAVE_FMT['format']}", **SAVE_FMT
    )

### Predictive Parity

In [None]:
def _precision_parity_rule(tau):
    """Return the pass/fail test applied to each disparity value.

    ``tau`` is accepted to match the ``fair_eval`` call signature and is
    ignored. The returned function yields NaN for a NaN disparity, True when
    the disparity is at least 0.8 (no upper bound), and False otherwise.
    """

    def evaluate(x):
        if np.isnan(x):
            return np.nan
        return bool(0.8 <= x <= np.inf)

    return evaluate


# The fairness table does not depend on the model being plotted, so it is
# computed once here rather than once per loop iteration.
fairness_precision = Fairness(fair_eval=_precision_parity_rule)
fairness_df_precision = fairness_precision.get_group_value_fairness(bias_df)

for model_id in ["Mayo ATTR-CM Score", "EchoNet-LVH", "EchoGo Amyloidosis"]:
    temp_fig = aequitas_plot.plot_fairness_disparity_all(
        fairness_df_precision[fairness_df_precision.model_id == model_id],
        metrics=["precision"],
        show_figure=False,
    )
    # Collapse doubled asterisks in the annotations on the second axes
    # (presumably significance markers -- confirm) and enlarge the text.
    for text in temp_fig.axes[1].texts:
        text.set_text(text.get_text().replace("**", "*"))
        text.set_fontsize(22)
    temp_fig.savefig(
        AEQDST / f"{model_id}_precision_disparity.{SAVE_FMT['format']}", **SAVE_FMT
    )

### Equal Opportunity

In [None]:
def _fnr_parity_rule(tau):
    """Return the pass/fail test applied to each disparity value.

    ``tau`` is accepted to match the ``fair_eval`` call signature and is
    ignored. The returned function yields NaN for a NaN disparity, True when
    the disparity lies in [0, 1.2], and False otherwise.
    """

    def evaluate(x):
        if np.isnan(x):
            return np.nan
        return bool(0 <= x <= 1.2)

    return evaluate


# The fairness table does not depend on the model being plotted, so it is
# computed once here rather than once per loop iteration.
fairness__fnr = Fairness(fair_eval=_fnr_parity_rule)
fairness_df_fnr = fairness__fnr.get_group_value_fairness(bias_df)

for model_id in ["Mayo ATTR-CM Score", "EchoNet-LVH", "EchoGo Amyloidosis"]:
    temp_fig = aequitas_plot.plot_fairness_disparity_all(
        fairness_df_fnr[fairness_df_fnr.model_id == model_id],
        metrics=["fnr"],
        show_figure=False,
    )
    # Collapse doubled asterisks in the annotations on the second axes
    # (presumably significance markers -- confirm) and enlarge the text.
    for text in temp_fig.axes[1].texts:
        text.set_text(text.get_text().replace("**", "*"))
        text.set_fontsize(22)
    temp_fig.savefig(
        AEQDST / f"{model_id}_fnr_disparity.{SAVE_FMT['format']}", **SAVE_FMT
    )

### Tree Plot

In [None]:
# NOTE(review): project-local import left in this cell; consider moving it to
# the import cell at the top of the notebook.
from create_treeplots import TreePlotConfig, tree_plot

# NOTE(review): this hands tree_plot the TreePlotConfig class itself rather
# than an instance, exactly as before -- confirm that is what it expects.
config = TreePlotConfig

# Read the figures from the folder the earlier cells saved them to (AEQDST)
# and write the combined image next to it, instead of repeating an absolute,
# user-specific path. dst stays a plain string, as it was before.
image = tree_plot(
    src=AEQDST,
    dst=str(FIGDST / "disparity_treemap2.tiff"),
    config=config,
)

In [None]:
# Display the image returned by tree_plot.
plt.imshow(image)

In [None]:
# Fairness criteria summarised in the table and the metric behind each one:
#   demographic parity -> positive prediction rate (pprev)
#   predictive parity  -> positive predictive value (precision)
#   equal opportunity  -> false negative rate (fnr)

# Descriptive columns taken once, from the pprev frame.
common_cols = [
    "model_id",
    "attribute_value",
    "group_size",
    "prev",
    "tpr",
    "tnr",
]


def _race_rows(frame, columns):
    """Return ``columns`` for the rows of ``frame`` whose attribute_name is "Race"."""
    return frame.loc[frame.attribute_name == "Race", columns]


# Each frame is filtered with its own attribute_name mask; previously the
# pprev and fnr frames were filtered with a mask taken from the precision
# frame, which is only correct while all three share the same index.
table6 = pd.concat(
    [
        _race_rows(fairness_df_pprev, common_cols + ["pprev_disparity"]),
        _race_rows(fairness_df_precision, ["precision_disparity"]),
        _race_rows(fairness_df_fnr, ["fnr_disparity"]),
    ],
    axis=1,
)

In [None]:
# Display order for the race groups; each name maps to its position in the list.
race_order = ["White", "Black", "Hispanic", "Other"]
race_map = {race: rank for rank, race in enumerate(race_order)}

# Print each model's rows, sorted by the race order above and rounded to
# two decimals.
for model_id in ("Mayo ATTR-CM Score", "EchoNet-LVH", "EchoGo Amyloidosis"):
    model_rows = table6[table6.model_id == model_id]
    ordered_rows = model_rows.sort_values(
        "attribute_value", key=lambda col: col.map(race_map)
    )
    print(ordered_rows.round(2))