# Analyze Survival Model Results

The primary endpoints of our analysis are the concordance index (C-index) and risk stratification. This notebook analyzes several facets of our results. Each subheading in this notebook should be self-contained (that is, it neither depends on nor influences other sections of the notebook). The only exception is the **Setup** section, which must be run before any other section.

## Setup

In [None]:
from collections import defaultdict
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from sksurv.metrics import concordance_index_censored
from sksurv.nonparametric import kaplan_meier_estimator

In [None]:
# Load the per-case table and key it by case ID; duplicated case IDs would
# make the index ambiguous, so fail fast if any are present.
df = pd.read_csv("split_cases.csv")
assert not df["case_id"].duplicated().any()
df = df.set_index("case_id")

In [None]:
def make_outcome_array(df):
    """Build a structured survival-outcome array from a case table.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide the columns ``dead`` (event indicator) and
        ``days_to_death_or_censor`` (follow-up time).

    Returns
    -------
    numpy.ndarray
        One-dimensional structured array with one record per row of ``df``
        and two fields: ``Status`` (bool) and ``Survival_in_days`` (float64).
    """
    outcome_dtype = [("Status", "?"), ("Survival_in_days", "<f8")]
    outcome = np.empty(len(df), dtype=outcome_dtype)
    # Fill field by field; assignment casts to the field dtype.
    outcome["Status"] = df["dead"].to_numpy()
    outcome["Survival_in_days"] = df["days_to_death_or_censor"].to_numpy()
    return outcome

In [None]:
def fit_survival_curve(y, max_time=3650, interpolate=True):
    """Estimate a Kaplan-Meier survival curve from an outcome array.

    Parameters
    ----------
    y : numpy.ndarray
        Structured array with fields ``Status`` and ``Survival_in_days``.
    max_time : int
        Last day of the interpolation grid (only used when ``interpolate``).
    interpolate : bool
        If true, resample the curve onto the daily grid
        ``0, 1, ..., max_time`` with linear interpolation (``np.interp``).

    Returns
    -------
    tuple of numpy.ndarray
        ``(time, survival_probability)``; either the estimator's own time
        points or the daily grid, depending on ``interpolate``.
    """
    km_time, km_prob = kaplan_meier_estimator(y["Status"], y["Survival_in_days"])
    # The first estimated time point is overwritten (not prepended) so that
    # every curve starts at day 0.
    km_time[0] = 0
    if interpolate:
        day_grid = np.linspace(0, max_time, max_time + 1)
        return day_grid, np.interp(day_grid, km_time, km_prob)
    return km_time, km_prob

In [None]:
def per_split_lo_hi_risk_curves(risk_scores, split_df, interpolate=True):
    """Fit survival curves for the low- and high-risk halves of one split.

    Cases are divided at the median predicted risk: scores strictly below
    the median form the low-risk group, all others the high-risk group.

    Parameters
    ----------
    risk_scores : array-like
        Predicted risk per case; ``risk_scores[k]`` must belong to the
        ``k``-th row of ``split_df``.
    split_df : pandas.DataFrame
        Cases of the split, in the same order as ``risk_scores``. Rows are
        selected by position, so the index labels do not matter.
    interpolate : bool
        Forwarded to ``fit_survival_curve``.

    Returns
    -------
    dict
        ``{"lo_risk": {"time", "prob"}, "hi_risk": {"time", "prob"}}``.
    """
    risk_scores = np.asarray(risk_scores)
    pivot = np.median(risk_scores)
    group_masks = {
        "lo_risk": risk_scores < pivot,
        "hi_risk": risk_scores >= pivot,
    }

    curves = {}
    for group, mask in group_masks.items():
        # Positional selection: the mask positions come from the score array,
        # so they are row positions and not index labels.
        group_df = split_df.iloc[np.flatnonzero(mask)]
        group_y = make_outcome_array(group_df)
        time, prob = fit_survival_curve(group_y, interpolate=interpolate)
        curves[group] = {"time": time, "prob": prob}
    return curves

In [None]:
def cross_val_lo_hi_risk_curves(preds, df, n_splits=5):
    """Average low/high-risk survival curves and C-index over CV splits.

    Parameters
    ----------
    preds : sequence or mapping
        Indexable by split number ``0 .. n_splits - 1``; each entry provides
        ``"c_index"`` and ``"y_test_pred"`` (risk scores ordered by the
        split's ``split_order``).
    df : pandas.DataFrame
        Case table with ``split`` and ``split_order`` columns plus the
        outcome columns read by ``make_outcome_array``.
    n_splits : int
        Number of cross-validation splits to aggregate (default 5).

    Returns
    -------
    dict
        ``"lo_risk"`` and ``"hi_risk"`` each hold ``"time"`` (mean over
        splits) and ``"prob"`` with ``"mean"``/``"std"`` over splits;
        ``"c_index"`` holds the ``"mean"``/``"std"`` of the per-split C-index.
    """
    groups = ("lo_risk", "hi_risk")
    times = {group: [] for group in groups}
    probs = {group: [] for group in groups}
    c_idxs = []

    for i in range(n_splits):
        # Order the split's cases like the predictions and move the existing
        # index into a column so rows are numbered 0..n-1.
        split_df = df[df["split"] == i].sort_values("split_order").reset_index()
        c_idxs.append(preds[i]["c_index"])

        curves = per_split_lo_hi_risk_curves(preds[i]["y_test_pred"], split_df)
        for group in groups:
            times[group].append(curves[group]["time"])
            probs[group].append(curves[group]["prob"])

    # Per-split curves are interpolated onto a common grid by default, which
    # is what makes the element-wise mean/std across splits meaningful.
    results = {
        group: {
            "time": np.mean(times[group], axis=0),
            "prob": {
                "mean": np.mean(probs[group], axis=0),
                "std": np.std(probs[group], axis=0),
            },
        }
        for group in groups
    }
    results["c_index"] = {
        "mean": np.mean(c_idxs),
        "std": np.std(c_idxs),
    }
    return results

In [None]:
def plot_results(
    *,  # enforce kwargs
    ax: plt.Axes,
    results: dict,
    name: str,
    color: str,
    plot_std: bool = True,
    linestyle: str | None = None,
    include_cidx: bool = True,
):
    """Draw the low- and high-risk survival curves of one model on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    results : dict with ``"lo_risk"``/``"hi_risk"`` entries (each holding
        ``"time"`` and ``"prob"`` with ``"mean"``/``"std"``) and a
        ``"c_index"`` entry with ``"mean"``/``"std"``.
    name : legend label for the model.
    color : colour shared by both curves and their bands.
    plot_std : if true, shade mean ± std around each curve.
    linestyle : line style for both curves (``None`` = matplotlib default).
    include_cidx : if true, append the mean C-index to the legend label.

    Only the low-risk curve (and its band) get legend entries, so each model
    appears once in the legend. A ``None`` C-index std is tolerated: it is
    simply left out of the labels.
    """
    c_idx_mean = results["c_index"]["mean"]
    c_idx_std = results["c_index"]["std"]
    has_std = c_idx_std is not None

    label = name
    if include_cidx:
        label += f", C-index = {c_idx_mean:0.2f}"
        # With the band hidden, the std is reported inline instead.
        if not plot_std and has_std:
            label += f"±{c_idx_std:0.2f}"

    # The band's legend entry reports the C-index std when it is available;
    # formatting None would raise, so fall back to a plain label.
    band_label = "±1 std. dev."
    if has_std:
        band_label += f" = {c_idx_std:0.2f}"

    for group, labelled in (("lo_risk", True), ("hi_risk", False)):
        time = results[group]["time"]
        mean = results[group]["prob"]["mean"]

        line_kwargs = {"label": label} if labelled else {}
        ax.step(
            time,
            mean,
            where="post",
            color=color,
            linestyle=linestyle,
            **line_kwargs,
        )
        if plot_std:
            std = results[group]["prob"]["std"]
            band_kwargs = {"label": band_label} if labelled else {}
            ax.fill_between(
                time,
                mean - std,
                mean + std,
                alpha=0.25,
                step="post",
                color=color,
                **band_kwargs,
            )

In [None]:
def plot_comparison(
    *,  # enforce kwargs
    modes: list[
        tuple[
            str,  # modality name
            str,  # modality color
            dict,  # modality results
        ],
    ],
    save_path: str | list[str] | None = None,
    plot_std: bool = True,
    include_cidx: bool = True,
    linestyles: list[str] | None = None,
    fig: plt.Figure | None = None,
    ax: plt.Axes | None = None,
):
    """Overlay the risk-stratified survival curves of several modalities.

    Parameters
    ----------
    modes : ``(name, color, results)`` triples, one per modality; each is
        forwarded to ``plot_results``.
    save_path : one path or a list of paths (e.g. different file
        extensions) to save the figure to; ``None`` disables saving.
    plot_std, include_cidx : forwarded to ``plot_results``.
    linestyles : optional line style per entry of ``modes`` (same order).
    fig, ax : existing figure/axes to draw on. If ``ax`` is omitted a new
        5x5 inch figure is created, and ``fig`` must then be omitted too.

    ``tight_layout`` is only applied when a figure was passed in or created
    here, so a caller-managed grid of axes is not re-laid-out. Saving,
    however, works in every case: if only ``ax`` was given, the figure that
    owns ``ax`` is saved.
    """
    if ax is None:
        assert fig is None, "pass `ax` together with `fig`, or neither"
        fig, ax = plt.subplots(figsize=(5, 5))

    for i, (name, color, results) in enumerate(modes):
        linestyle = None if linestyles is None else linestyles[i]
        plot_results(
            ax=ax,
            results=results,
            name=name,
            color=color,
            plot_std=plot_std,
            linestyle=linestyle,
            include_cidx=include_cidx,
        )

    ax.set_ylim(0, 1.05)
    ax.set_xlim(0, 3650)
    ax.legend(loc="lower left")
    ax.set_ylabel("Survival Probability")
    ax.set_xlabel("Days")

    if fig is not None:
        fig.tight_layout()

    if save_path is not None:
        # Fall back to the axes' own figure so a requested save is never
        # silently dropped when only `ax` was supplied.
        target_fig = fig if fig is not None else ax.figure
        paths = [save_path] if isinstance(save_path, str) else save_path
        for sp in paths:  # accommodates different file endings
            target_fig.savefig(sp, dpi=300)

## Project Comparisons

In [None]:
# Number of PCA components whose predictions are analysed; used as the
# top-level key into each prediction file.
pca_components = 256
# mode name -> (modality key inside the prediction file, prediction file).
# "orig" reads the "text" modality from the non-summarized predictions.
modes = {
    "demo": ("demo", "predictions_summarized.npy"),
    # "canc": ("canc", "predictions_summarized.npy"),
    "expr": ("expr", "predictions_summarized.npy"),
    "hist": ("hist", "predictions_summarized.npy"),
    "text": ("text", "predictions_summarized.npy"),
    "orig": ("text", "predictions.npy"),
    "canc-demo-expr-hist-text": ("canc-demo-expr-hist-text", "predictions_summarized.npy"),
}
# Several modes share a file, so each file is loaded and unpickled only once.
# NOTE: allow_pickle=True executes arbitrary code embedded in the file; only
# load prediction files from a trusted source.
loaded_preds = dict()
data = dict()
for mode, (key, pred_file) in modes.items():
    if pred_file not in loaded_preds:
        loaded_preds[pred_file] = np.load(pred_file, allow_pickle=True).item()[pca_components]
    preds = loaded_preds[pred_file]
    # One entry per cross-validation split (5 splits).
    datum = [preds[i][key] for i in range(5)]
    data[mode] = datum

In [None]:
# Attach each case's project label from the clinical table.
clinical = pd.read_csv("../data/clinical.csv")
assert not clinical["case_id"].duplicated().any()
clinical = clinical.set_index("case_id")
# `meta` is the case table joined with all clinical columns (inner join on
# case_id).
meta = pd.merge(df, clinical, left_index=True, right_index=True)
# Read `project` straight from the clinical table rather than from the merged
# frame: once `df` already has a `project` column (i.e. when this cell is
# re-run) the merge renames the clashing columns to project_x / project_y and
# a lookup of "project" on `meta` would fail. A case missing from the
# clinical table still raises KeyError here.
df["project"] = clinical.loc[df.index, "project"]

In [None]:
# data_by_project[project][mode] collects one {"c_index", "y_test_pred"}
# record per cross-validation split, restricted to that project's cases.
data_by_project = defaultdict(lambda: defaultdict(list))
# Projects are visited from most to fewest cases (value_counts order).
# To restrict to the largest projects, iterate over
# df["project"].value_counts().head(8).index instead.
projects_by_size = df["project"].value_counts().index
for proj_rank, project in enumerate(projects_by_size):
    print(proj_rank, project)
    proj_df = df[df["project"] == project]
    for i in range(5):
        proj_split_df = proj_df[proj_df["split"] == i].sort_values("split_order")
        # split_order doubles as the position of each case in the split's
        # prediction array.
        proj_split_idxs = proj_split_df["split_order"].to_numpy()
        proj_split_y_test = make_outcome_array(proj_split_df)

        events = proj_split_y_test["Status"]
        durations = proj_split_y_test["Survival_in_days"]
        n_deaths = events.sum()

        # Decide whether a C-index can be computed for this project/split.
        # Given 5-fold cross validation with splits stratified by death, we
        # can pretty much guarantee that projects with < 10 deaths will have
        # "bad" survival data, so such splits are reported and flagged
        # instead of raising.
        if n_deaths == 0:
            reason = "all censored"
        elif n_deaths < 2 and durations[events][0] == durations.max():
            # a single death that is also the longest follow-up
            reason = "no comparable pairs"
        else:
            reason = None

        skip = reason is not None
        if skip:
            print(f"Bad Survival Data {project} Split {i} because {reason}")

        for mode, datum in data.items():
            proj_split_y_test_pred = datum[i]["y_test_pred"][proj_split_idxs]
            if skip:
                # sentinel: no C-index available for this split
                proj_split_c_index = -1
            else:
                proj_split_c_index = concordance_index_censored(
                    event_indicator=events,
                    event_time=durations,
                    estimate=proj_split_y_test_pred,
                )[0]
            record = {
                "c_index": proj_split_c_index,
                "y_test_pred": proj_split_y_test_pred,
            }
            data_by_project[project][mode].append(record)

In [None]:
# Build a table of mean per-project C-index for every modality.
# Projects with at least one split lacking a valid C-index (sentinel -1 in
# the "demo" records) are skipped entirely.
# fig, axs = plt.subplots(nrows=2, ncols=4, figsize=(20, 10))
results_by_project = dict()
for i, project in enumerate(df["project"].value_counts().index):
# for i, project in enumerate(df["project"].value_counts().head(8).index):
    # ax = axs[i // 4, i % 4]
    # Use a dedicated name here: rebinding the global `data` would clobber
    # the predictions dict that the per-project C-index cell reads.
    proj_data = data_by_project[project]
    if any(x["c_index"] == -1 for x in proj_data["demo"]):
        print("Skipping", i, project)
        continue
    proj_df = df[df["project"] == project]

    demo_results = cross_val_lo_hi_risk_curves(proj_data["demo"], proj_df)
    # canc_results = cross_val_lo_hi_risk_curves(proj_data["canc"], proj_df)
    expr_results = cross_val_lo_hi_risk_curves(proj_data["expr"], proj_df)
    hist_results = cross_val_lo_hi_risk_curves(proj_data["hist"], proj_df)
    text_results = cross_val_lo_hi_risk_curves(proj_data["text"], proj_df)
    orig_results = cross_val_lo_hi_risk_curves(proj_data["orig"], proj_df)
    mult_results = cross_val_lo_hi_risk_curves(proj_data["canc-demo-expr-hist-text"], proj_df)

    # plot_comparison(
    #     modes=[
    #         ("Demographics", "tab:blue", demo_results),
    #         # ("Cancer Type", "tab:blue", canc_results),
    #         ("RNA-seq", "tab:orange", expr_results),
    #         ("Histology", "tab:green", hist_results),
    #         ("Text", "tab:purple", text_results),
    #         ("Multimodal", "tab:red", mult_results),
    #     ],
    #     plot_std=False,
    #     ax=ax,
    # )

    # One column per project (without the "TCGA-" prefix); the two summary
    # rows are pre-formatted strings, the modality rows are float C-indices.
    results_by_project[project.replace("TCGA-", "")] = {
        "Mean Split Size": f'{proj_df.groupby("split").size().mean():0.1f}',
        "Mean Split Mortality": f'{proj_df.groupby("split")["dead"].sum().mean():0.1f}',
        "Demographics": demo_results["c_index"]["mean"],
        # "Cancer type": canc_results["c_index"]["mean"],
        "RNA-seq": expr_results["c_index"]["mean"],
        "Histology": hist_results["c_index"]["mean"],
        "Text": text_results["c_index"]["mean"],
        "Orig": orig_results["c_index"]["mean"],
        "Multimodal": mult_results["c_index"]["mean"],
    }
# Rows are the summary/modality names, columns the projects; the row labels
# are then moved into a regular "Modality" column.
results_by_project = pd.DataFrame(results_by_project)
results_by_project.index.name = "Modality"
results_by_project = results_by_project.reset_index()

In [None]:
# Print the per-project table as markdown, split into several narrower
# tables so each fits on a page. The "Orig" row is left out of the printed
# tables. Columns are chunked dynamically, so the cell works for any number
# of projects (for 27 project columns this yields tables of 7, 7, 7 and 6).
projects_per_table = 7
printed_rows = results_by_project[results_by_project["Modality"] != "Orig"]
project_columns = list(printed_rows.columns[1:])
for start in range(0, len(project_columns), projects_per_table):
    table_columns = ["Modality"] + project_columns[start:start + projects_per_table]
    print(printed_rows[table_columns].to_markdown(index=False, floatfmt="0.3f"))


## Multimodal Comparisons

In [None]:
# Load the summarized trial results; the first (unnamed) CSV column holds the
# modality combination and is given the name "combo".
trials = pd.read_csv("results_summarized.csv").rename(
    columns={"Unnamed: 0": "combo"}
)

In [None]:
# rebuttal to reviewer oaBB
# Table of C-index for every combination that includes "canc", taken from
# the "256" column -- except rows 0 and 5, which use the "4" column instead
# (presumably those combinations were not run with 256 components; confirm).
uses_canc = trials["combo"].str.contains("canc")
override_rows = [0, 5]
temp = trials.loc[uses_canc, ["combo", "256"]]
temp.loc[override_rows, "256"] = trials.loc[override_rows, "4"]
temp = temp.rename(columns={"256": "C-index"})
print(temp.to_markdown(index=False, floatfmt="0.3f"))

## Qualitative Analysis of Hallucination Corrections

In [None]:
import pandas as pd
import difflib
import html

# copy paste from SO: https://stackoverflow.com/a/64404008
# (ANSI-terminal variants, kept for reference)
# red = lambda text: f"\033[38;2;255;0;0m{text}\033[38;2;255;255;255m"
# green = lambda text: f"\033[38;2;0;255;0m{text}\033[38;2;255;255;255m"
# blue = lambda text: f"\033[38;2;0;0;255m{text}\033[38;2;255;255;255m"
# white = lambda text: f"\033[38;2;255;255;255m{text}\033[38;2;255;255;255m"

# HTML variants. The text is HTML-escaped so that "<", ">" and "&" occurring
# in the compared strings are displayed literally instead of being parsed as
# markup by the browser.


def red(text):
    """Return *text* as escaped HTML wrapped in a red ``<span>``."""
    return f'<span style="color: red;">{html.escape(text, quote=False)}</span>'


def green(text):
    """Return *text* as escaped HTML wrapped in a green ``<span>``."""
    return f'<span style="color: green;">{html.escape(text, quote=False)}</span>'


def blue(text):
    """Return *text* as escaped HTML wrapped in a blue ``<span>``."""
    return f'<span style="color: blue;">{html.escape(text, quote=False)}</span>'


def white(text):
    """Return *text* as escaped HTML without any colouring."""
    return html.escape(text, quote=False)


def get_edits_string(old, new):
    """Render a character-level diff between two strings as HTML.

    Unchanged text is emitted as-is (escaped), text only present in ``old``
    is coloured red, text only present in ``new`` is coloured green, and a
    replacement is shown as the red old text followed by the green new text.

    Parameters
    ----------
    old, new : str
        The original and the edited string.

    Returns
    -------
    str
        HTML fragment describing the edits that turn ``old`` into ``new``.
    """
    pieces = []
    matcher = difflib.SequenceMatcher(a=old, b=new)
    # Each opcode covers old[i1:i2] and new[j1:j2].
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            pieces.append(white(old[i1:i2]))
        elif tag == "delete":
            pieces.append(red(old[i1:i2]))
        elif tag == "insert":
            pieces.append(green(new[j1:j2]))
        elif tag == "replace":
            pieces.append(red(old[i1:i2]) + green(new[j1:j2]))
    return "".join(pieces)

In [None]:
# Load the sampled summaries together with their corrected versions; the
# cell below reads the columns "case_id", "summ" and "corrected".
# NOTE(review): this rebinds `df`, replacing the case table loaded in Setup,
# so Setup must be re-run before returning to any section that uses it.
df = pd.read_csv("../data/sampled_corrected.csv")

In [None]:
# copy the output into an HTML file, easier copy/pasting from the browser rendered text than in notebook
# Only cases whose summary was actually changed are shown: a heading with the
# case ID followed by the colour-coded diff and a blank line.
changed = df[df["summ"] != df["corrected"]]
for case_id, summ, corrected in zip(
    changed["case_id"], changed["summ"], changed["corrected"]
):
    print(f"<h3>{case_id}</h3>")
    print(get_edits_string(summ, corrected))
    print()