# Survival Cross Validation Evaluation

This notebook evaluates the cross-validation results of the best survival models.

In [11]:
import pandas as pd

import plotly.express as px
import plotly.graph_objects as go
import utilities.latex_figures as latex_figs

In [3]:
# Load the cross-validation results; the device count is cast to a string
# and the rows are ordered by that string column.
cv_df = (
    pd.read_csv("./data/cross_validation/surv_cross_validation.csv")
    .astype({"params.n_dev": str})
    .sort_values(by="params.n_dev")
)

In [4]:
# Display label for each metric column.
metric_name_to_metric = dict(
    [
        ("metrics.c_index_ipcw", "C-Index IPCW"),
        ("metrics.ibs", "IBS"),
    ]
)

# One row per model class: (class name, display label, abbreviation).
_model_rows = (
    ("kaplan_meier_estimator", "Kaplan-Meier", "km"),
    ("CoxPHFitter", "Cox-PH", "cph"),
    ("RandomSurvivalForest", "Random Survival Forest", "rsf"),
)
# Both lookups are derived from the same rows so their keys always match.
model_class_to_model = {name: label for name, label, _abbr in _model_rows}
model_class_to_abbr = {name: abbr for name, _label, abbr in _model_rows}

In [5]:
# Sanity check: (rows, columns) of the loaded cross-validation results.
cv_df.shape

(180, 22)

In [6]:
# Show the distinct values of "params.by_metric" present in the results.
print(cv_df["params.by_metric"].unique())

['metrics.c_index_ipcw' 'metrics.ibs']


In [7]:
def calculate_and_print_improvements(group: pd.DataFrame, by_metric: str) -> None:
    """Print LaTeX table rows comparing the best augmented and unaugmented run.

    For every value of ``params.n_dev`` in ``group``, the best row with
    ``Augmentation == "Ja"`` and the best row with ``Augmentation == "Nein"``
    are selected by ``by_metric``. Two table rows are printed, each followed
    by ``\\hline``: the augmented row (including the relative change against
    the unaugmented value in percent) and the unaugmented row. Decimal points
    in the printed rows are replaced by commas.

    Args:
        group: Frame with the columns ``params.n_dev``, ``params.n_aug``,
            ``Augmentation`` and the column named by ``by_metric``.
        by_metric: Name of the metric column. ``"metrics.ibs"`` is ranked
            ascending (smallest value is best); every other metric is ranked
            descending.

    Returns:
        None. The rows are written to stdout.

    Warns:
        UserWarning: If a device count has no augmented or no unaugmented
            rows; that device count is skipped instead of raising IndexError.
    """
    import warnings

    # Only IBS is sorted ascending; hoisted because it does not depend on n_dev.
    lower_is_better = by_metric == "metrics.ibs"

    def _best_row(rows: pd.DataFrame, augmentation_flag: str):
        """Return the best row for one augmentation flag, or None if there is none."""
        candidates = rows[rows["Augmentation"] == augmentation_flag]
        if candidates.empty:
            return None
        return candidates.sort_values(by=by_metric, ascending=lower_is_better).iloc[0]

    for n_dev, sub_group in group.groupby(by="params.n_dev"):
        best_with = _best_row(sub_group, "Ja")
        best_with_out = _best_row(sub_group, "Nein")
        if best_with is None or best_with_out is None:
            # Without both sides there is nothing to compare for this device count.
            warnings.warn(
                f"Skipping n_dev={n_dev}: augmented and unaugmented runs are both "
                f"required to compare '{by_metric}'.",
                stacklevel=2,
            )
            continue

        best_with_value = best_with[by_metric]
        best_with_out_value = best_with_out[by_metric]
        improvement = best_with_value - best_with_out_value
        improvement_percent = (improvement / best_with_out_value) * 100
        # Negative values already carry their sign; zero is printed without one.
        sign = "+" if improvement_percent > 0 else ""

        with_str = f"{n_dev} & {best_with['params.n_aug']} & {best_with_value:.3f} ({sign}{improvement_percent:.1f} \\%) \\\\".replace(".", ",")
        with_out_str = f"{n_dev} & {best_with_out['params.n_aug']} & {best_with_out_value:.3f} \\\\".replace(".", ",")
        print(with_str)
        print("\\hline")
        print(with_out_str)
        print("\\hline")

In [8]:
from utilities import latex_tables


# For every selection metric: print the improvement table per model class,
# draw one box plot per model class and collect the plots in a LaTeX
# subfigure grid that is saved and written to a .tex file.
for by_metric, metric_group in cv_df.groupby(by="params.by_metric"):
    print(by_metric)
    print(metric_group.shape)
    # Metric column name without the "metrics." prefix; reused in labels and paths.
    metric_key = by_metric.split(".")[1]
    subfig_grid = latex_figs.LatexSubfigureGrid(
        caption=f"Vergleich des \\gls{{cv}} {metric_name_to_metric[by_metric]}s der besten Modelle.",
        label=f"comp_surv_cv_{metric_key}",
    )

    for model_class, model_group in metric_group.groupby(by="params.model_class"):
        # Build a new frame instead of writing into the groupby slice: assigning
        # a column to (or sorting in place) a slice triggers pandas'
        # SettingWithCopyWarning. `map` is used rather than `replace` to turn the
        # boolean flag into the "Ja"/"Nein" labels.
        model_group = model_group.assign(
            Augmentation=model_group["params.n_aug"]
            .ne(0)
            .map({True: "Ja", False: "Nein"})
        ).sort_values(by=["Augmentation", "params.n_dev"])
        print(model_class)
        calculate_and_print_improvements(model_group, by_metric)

        fig = px.box(
            model_group,
            x="params.n_dev",
            y=by_metric,
            color="Augmentation",
            points="all",
            title=f"Vergleich des Cross-Validation {metric_name_to_metric[by_metric]}s der besten {model_class_to_model[model_class]}-Modelle mit und ohne Augmentation.",
            width=1000,
            height=600,
            range_y=[0, 1] if by_metric == "metrics.c_index_ipcw" else [0, 0.5],
        )

        fig.update_layout(
            xaxis_title="Anzahl der Trainingsgeräte",
            yaxis_title=metric_name_to_metric[by_metric],
        )
        fig.show()

        model_abbr = model_class_to_abbr[model_class]
        subfig = latex_figs.LatexSubfigure(
            f"resources/figures/survival_cv/{metric_key}_{model_abbr}_cv.png",
            fig,
            caption=f"Vergleich des \\gls{{cv}} \\gls{{{metric_key}}}s der besten \\gls{{{model_abbr}}}-Modelle.",
            label=f"comp_surv_cv_{metric_key}_{model_abbr}",
        )
        subfig_grid.add_subfigure(subfig)
        # One subfigure per row of the grid.
        subfig_grid.add_newline()
    subfig_grid.save_figure()
    subfig_grid.write_latex_code_to_file(f"comp_surv_cv_{metric_key}.tex")

metrics.c_index_ipcw
(72, 22)
CoxPHFitter
10 & 3 & 0,874 (+2,8 \%) \\
\hline
10 & 0 & 0,850 \\
\hline
20 & 10 & 0,908 (+12,3 \%) \\
\hline
20 & 0 & 0,808 \\
\hline
40 & 10 & 0,894 (+2,3 \%) \\
\hline
40 & 0 & 0,874 \\
\hline
63 & 3 & 0,940 (0,0 \%) \\
\hline
63 & 0 & 0,940 \\
\hline


RandomSurvivalForest
10 & 3 & 0,887 (+4,9 \%) \\
\hline
10 & 0 & 0,845 \\
\hline
20 & 3 & 0,879 (+1,2 \%) \\
\hline
20 & 0 & 0,869 \\
\hline
40 & 10 & 0,906 (+2,0 \%) \\
\hline
40 & 0 & 0,889 \\
\hline
63 & 3 & 0,920 (+5,2 \%) \\
\hline
63 & 0 & 0,874 \\
\hline


Saving subfigure to "/home/nkuechen/Documents/Thesis/latex/Bachelor Thesis/resources/figures/survival_cv/c_index_ipcw_cph_cv.png"...
Done!
Saving subfigure to "/home/nkuechen/Documents/Thesis/latex/Bachelor Thesis/resources/figures/survival_cv/c_index_ipcw_rsf_cv.png"...
Done!
Writing latex code to /home/nkuechen/Documents/Thesis/latex/Bachelor Thesis/comp_surv_cv_c_index_ipcw.tex
Done!
metrics.ibs
(108, 22)
CoxPHFitter
10 & 3 & 0,073 (-40,0 \%) \\
\hline
10 & 0 & 0,122 \\
\hline
20 & 3 & 0,080 (-24,0 \%) \\
\hline
20 & 0 & 0,106 \\
\hline
40 & 3 & 0,123 (+16,4 \%) \\
\hline
40 & 0 & 0,106 \\
\hline
63 & 3 & 0,068 (-41,7 \%) \\
\hline
63 & 0 & 0,116 \\
\hline


RandomSurvivalForest
10 & 3 & 0,102 (-28,9 \%) \\
\hline
10 & 0 & 0,144 \\
\hline
20 & 1 & 0,094 (-29,0 \%) \\
\hline
20 & 0 & 0,132 \\
\hline
40 & 3 & 0,070 (-30,1 \%) \\
\hline
40 & 0 & 0,100 \\
\hline
63 & 1 & 0,075 (-26,0 \%) \\
\hline
63 & 0 & 0,101 \\
\hline


kaplan_meier_estimator
10 & 3 & 0,167 (-31,8 \%) \\
\hline
10 & 0 & 0,245 \\
\hline
20 & 3 & 0,166 (-24,9 \%) \\
\hline
20 & 0 & 0,221 \\
\hline
40 & 3 & 0,189 (-19,9 \%) \\
\hline
40 & 0 & 0,236 \\
\hline
63 & 3 & 0,174 (-24,3 \%) \\
\hline
63 & 0 & 0,230 \\
\hline


Saving subfigure to "/home/nkuechen/Documents/Thesis/latex/Bachelor Thesis/resources/figures/survival_cv/ibs_cph_cv.png"...
Done!
Saving subfigure to "/home/nkuechen/Documents/Thesis/latex/Bachelor Thesis/resources/figures/survival_cv/ibs_rsf_cv.png"...
Done!
Saving subfigure to "/home/nkuechen/Documents/Thesis/latex/Bachelor Thesis/resources/figures/survival_cv/ibs_km_cv.png"...
Done!
Writing latex code to /home/nkuechen/Documents/Thesis/latex/Bachelor Thesis/comp_surv_cv_ibs.tex
Done!


In [15]:
# Bar chart with two series over the four device counts, built from
# hard-coded values.
# NOTE(review): the y values look like placeholders rather than computed
# results - confirm before using this figure.
device_counts = ["10", "20", "40", "63"]
bar_series = {
    "Augmentiert": [100, 200, 500, 673],
    "Unaugmentiert": [56, 123, 982, 213],
}

fig = go.Figure()
for series_name, heights in bar_series.items():
    fig.add_trace(go.Bar(name=series_name, x=device_counts, y=heights))

fig.show()

In [115]:
# Group the runs by device count, augmentation count and model class.
grouped = cv_df.groupby(["params.n_dev", "params.n_aug", "params.model_class"])

# Bar fill pattern per model class. It is the same for every metric, so it is
# built once instead of on every loop iteration.
pattern_styles = {
    "kaplan_meier_estimator": "",
    "CoxPHFitter": "x",
    "RandomSurvivalForest": ".",
}

for metric in ["metrics.ibs", "metrics.c_index_ipcw"]:
    # The best IBS is the smallest value; for the C-index the largest is taken.
    lower_is_better = metric == "metrics.ibs"
    reducer = "min" if lower_is_better else "max"
    metric_label = metric_name_to_metric[metric]

    # Best metric value per (n_dev, n_aug, model_class).
    best_metric = grouped[metric].agg(reducer).reset_index()
    if not lower_is_better:
        # Kaplan-Meier rows are left out of the C-index comparison.
        best_metric = best_metric[
            best_metric["params.model_class"] != "kaplan_meier_estimator"
        ]

    fig = go.Figure()

    for model_class, pattern in pattern_styles.items():
        is_model = best_metric["params.model_class"] == model_class

        best_unaugmented = best_metric[(best_metric["params.n_aug"] == 0) & is_model]
        # Collapse every n_aug != 0 setting to one best value per device count.
        best_augmented = (
            best_metric[(best_metric["params.n_aug"] != 0) & is_model]
            .groupby(["params.n_dev", "params.model_class"])
            .agg(reducer)
            .reset_index()
        )

        # One red bar series (n_aug == 0) and one blue series (n_aug != 0) per model.
        for frame, trace_name, colour in (
            (best_unaugmented, f"{model_class} params.n_aug == 0", "red"),
            (best_augmented, f"{model_class} params.n_aug != 0", "blue"),
        ):
            fig.add_trace(
                go.Bar(
                    x=frame["params.n_dev"],
                    y=frame[metric],
                    name=trace_name,
                    marker_color=colour,
                    marker_pattern_shape=pattern,
                    showlegend=False,  # legend entries come from the dummy traces below
                )
            )

    # Dummy traces that exist only to populate the legend.
    # Colour legend.
    for legend_name, colour in (("Unaugmentiert", "red"), ("Augmentiert", "blue")):
        fig.add_trace(go.Bar(x=[None], y=[None], name=legend_name, marker_color=colour))

    # Pattern legend.
    for model_class, pattern in pattern_styles.items():
        fig.add_trace(
            go.Bar(
                x=[None],
                y=[None],
                name=f"Modelltyp {model_class}",
                marker_color="white",
                marker_pattern_shape=pattern,
            )
        )

    # Layout. Axis titles are set once via update_xaxes/update_yaxes below
    # rather than first in update_layout and then overwritten.
    fig.update_layout(
        title=dict(
            text=f"Vergleich der besten {metric_label} Werte mit der Menge an Trainingsgeräten und dem Modelltyp",
            font=dict(size=24),
        ),
        barmode="group",
        legend_title=dict(text="Legende", font=dict(size=18)),
        legend=dict(font=dict(size=16)),
        width=1200,
        height=600,
        bargap=0,
        bargroupgap=0.1,
    )
    fig.update_xaxes(
        title=dict(text="Anzahl an Trainingsgeräten", font=dict(size=20)),
    )
    fig.update_yaxes(title=dict(text=metric_label, font=dict(size=20)))

    # Show the plot.
    fig.show()