# Survival Experiments

This notebook contains all our code for survival modeling.

The experiments test multimodal fusion of survival models and varying dimensionality reductions for high-dimensional embeddings.

Specifically, we experiment with 5 modalities:
* Patient demographics (sex, binned age, race, ethnicity)
* Cancer type (we use the TCGA project ID as a proxy for cancer type)
* RNA-seq gene expression (`BulkRNABert` embeddings)
* Whole slide histology images (`UNI2` embeddings)
* Pathology reports (`BioMistral` embeddings)

We additionally experiment with various alternate embeddings, including:
* `BioMistral` embeddings of pathology report summaries generated by `Llama-3.1-8B-Instruct`
* `Mistral-7B-Instruct-v0.1` embeddings of pathology reports
* `Mistral-7B-Instruct-v0.1` embeddings of pathology report summaries generated by `Llama-3.1-8B-Instruct`
* `UCE` embeddings of RNA-seq gene expression

To use these alternate embeddings, modify the input file variables in the first code cell of this notebook, and update the output file names in the cells that save results so that earlier results are not overwritten.

Run experiments by executing all cells of this notebook. Results are saved in the `results` subdirectory at the root of the repo. Analysis and visualization are done with tools that are also located in the `results` folder.

In [None]:
# Input embedding files (opened with h5py further down). Point these at other
# files to run the notebook with one of the alternate embeddings listed above.
expr_file = "../embed/expr.h5" # BulkRNABert
hist_file = "../embed/hist.h5" # UNI2
text_file = "../embed/summ.h5" # BioMistral - Summarized

In [None]:
from itertools import chain, combinations
from collections import defaultdict
import h5py
import pandas as pd
import numpy as np
from tqdm import tqdm
from pqdm.processes import pqdm
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.decomposition import PCA
from sklearn.model_selection import StratifiedKFold
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.metrics import concordance_index_censored, integrated_brier_score, cumulative_dynamic_auc

In [None]:
# Clinical table and the set of case IDs it contains.
df = pd.read_csv("../data/clinical.csv")
clin_case_ids = set(df["case_id"])


def _top_level_keys(h5_path):
    """Return the set of top-level keys of the HDF5 file at *h5_path*."""
    with h5py.File(h5_path, "r") as h5:
        return set(h5.keys())


# The top-level keys of each embedding file are used as case IDs.
expr_case_ids = _top_level_keys(expr_file)
hist_case_ids = _top_level_keys(hist_file)
text_case_ids = _top_level_keys(text_file)

In [None]:
# Fraction of all clinical rows whose time (days_to_death when recorded,
# otherwise days_to_last_follow_up) is at most 365 days.
(df["days_to_death"].fillna(df["days_to_last_follow_up"]) <= 365).sum() / len(df)

In [None]:
# Same fraction as in the previous cell, with a 5-year (5 * 365 days) threshold.
(df["days_to_death"].fillna(df["days_to_last_follow_up"]) <= 5*365).sum() / len(df)

In [None]:
# Keep only the cases present in the clinical table AND in all three embedding files.
shared_case_ids = clin_case_ids & expr_case_ids & hist_case_ids & text_case_ids
case_ids = sorted(shared_case_ids)

# Sort the clinical rows by case_id and reset the index so that row positions
# follow the same ordering as the sorted `case_ids` list.
in_all_modalities = df["case_id"].isin(case_ids)
df = df.loc[in_all_modalities].sort_values("case_id").reset_index(drop=True)
assert df["case_id"].is_unique

In [None]:
# Bin age into 20-year intervals; each label spells out its "(low, high]" interval.
age_bin_edges = [0, 20, 40, 60, 80, 100]
df["age_binned"] = pd.cut(
    df["age"],
    bins=age_bin_edges,
    labels=[f"({low}, {high}]" for low, high in zip(age_bin_edges[:-1], age_bin_edges[1:])],
)

In [None]:
# Event indicator: True where the recorded vital status is "Dead".
dead = df["vital_status"].eq("Dead")
# Time to event: days to death for dead cases, days to last follow-up otherwise.
days_to_event = np.where(dead, df["days_to_death"], df["days_to_last_follow_up"])
# Every case must have a usable time.
assert not np.any(np.isnan(days_to_event))

In [None]:
# Structured survival target: one (event indicator, time in days) record per case.
y = np.empty(len(dead), dtype=[("Status", "?"), ("Survival_in_days", "<f8")])
y["Status"] = dead.to_numpy()
y["Survival_in_days"] = days_to_event

In [None]:
# Both encoders share one configuration: dense float32 output, and a single
# column for binary features (drop="if_binary").
ohe_kwargs = dict(drop="if_binary", sparse_output=False, dtype=np.float32)
demo_ohe = OneHotEncoder(**ohe_kwargs)
canc_ohe = OneHotEncoder(**ohe_kwargs)

In [None]:
# One-hot encode the demographic columns and the project ID (the cancer-type proxy).
# The encoders are fitted on all retained cases, i.e. before any train/test split.
demo_X = demo_ohe.fit_transform(df[["sex", "age_binned", "race", "ethnicity"]])
canc_X = canc_ohe.fit_transform(df[["project"]])

In [None]:
# Shape of the one-hot demographic matrix.
demo_X.shape

In [None]:
# Shape of the one-hot project matrix.
canc_X.shape

In [None]:
# Categories learned for each demographic column.
demo_ohe.categories_

In [None]:
# Categories learned for the project column.
canc_ohe.categories_

In [None]:
def extract_case_emb_from_h5(case_ids: list[str], h5: h5py.File):
    """Build one embedding per case by averaging everything stored under it.

    For each ID in ``case_ids`` the group ``h5[case_id]`` is read, all of its
    members are stacked along a new leading axis and averaged over that axis.

    Args:
        case_ids: Keys to look up in ``h5``; the output rows follow this order.
        h5: Open HDF5 file with one group per case.

    Returns:
        Array whose first axis indexes ``case_ids`` and whose remaining axes
        are those of a single stored member.
    """

    def mean_embedding(case_id):
        members = [dataset[:] for dataset in h5[case_id].values()]
        return np.mean(np.stack(members, axis=0), axis=0)

    return np.stack([mean_embedding(case_id) for case_id in tqdm(case_ids)], axis=0)

In [None]:
def _load_case_embeddings(h5_path):
    """Open *h5_path* read-only and return the per-case embedding matrix for `case_ids`."""
    with h5py.File(h5_path, "r") as h5:
        return extract_case_emb_from_h5(case_ids, h5)


# One matrix per embedding modality, with rows ordered like `case_ids`.
expr_X = _load_case_embeddings(expr_file)
hist_X = _load_case_embeddings(hist_file)
text_X = _load_case_embeddings(text_file)

In [None]:
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Stratification label: the "_"-joined combination of outcome, project and
# demographic values for each case.
strata_columns = [
    df["vital_status"],
    df["project"],
    df["sex"],
    df["age_binned"].astype(str),
    df["race"],
    df["ethnicity"],
]
splitter = strata_columns[0]
for column in strata_columns[1:]:
    splitter = splitter + "_" + column

n = len(df)
# Only the test indices of each fold are kept; the features passed to split()
# are a placeholder of the right length.
test_splits = [fold_test_idxs for _, fold_test_idxs in skf.split(X=np.zeros(n), y=splitter)]

In [None]:
# For every case, record which fold tests it and its position within that fold's test set.
fold_of_case = np.full(len(df), -1, dtype=np.int64)
order_in_fold = np.full(len(df), -1, dtype=np.int64)
for fold, fold_test_idxs in enumerate(test_splits):
    fold_of_case[fold_test_idxs] = fold
    order_in_fold[fold_test_idxs] = np.arange(len(fold_test_idxs))

meta_df = df[["case_id"]].copy()
meta_df["split"] = fold_of_case
meta_df["split_order"] = order_in_fold
meta_df["dead"] = y["Status"]
meta_df["days_to_death_or_censor"] = y["Survival_in_days"]
# meta_df.to_csv("../results/split_cases.csv", index=False)

In [None]:
# Number of test cases assigned to each fold.
meta_df["split"].value_counts()

In [None]:
def run_split(
    *,  # enforce kwargs
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_test: np.ndarray,
    y_test: np.ndarray,
    pca_components: int | None,
    standardize: bool,
    name: str = "",
    verbose: bool = False,
) -> dict:
    """Fit a Cox model on one train/test split and score it on the test part.

    Steps: optional z-scoring, optional PCA, ``CoxPHSurvivalAnalysis(alpha=0.1)``,
    then C-index, integrated Brier score and cumulative/dynamic AUC on the test
    data. The scaler and the PCA are fitted on the training data only and then
    applied to the test data.

    Args:
        X_train: Training feature matrix.
        y_train: Training survival targets, passed to the model and the metrics.
        X_test: Test feature matrix.
        y_test: Test survival targets; a structured array with the fields
            ``"Status"`` and ``"Survival_in_days"``.
        pca_components: Number of PCA components, or ``None`` to skip PCA.
        standardize: Whether to z-score the features before PCA / fitting.
        name: Label printed when ``verbose`` is set.
        verbose: Print a progress line after each step.

    Returns:
        Dict with the keys ``"c_index"``, ``"ibs"``, ``"cd_auc"``,
        ``"y_test_pred"``, ``"y_train_pred"`` and ``"model"`` (the fitted model).
    """

    def log(message: str) -> None:
        if verbose:
            print(message)

    log(f"Running {name}")

    train_feats, test_feats = X_train, X_test

    # z-score input features
    if standardize:
        log("--standardized")
        scaler = StandardScaler()
        train_feats = scaler.fit_transform(train_feats)
        test_feats = scaler.transform(test_feats)

    # dimensionality reduction
    if pca_components is not None:
        log("--reduced")
        pca = PCA(n_components=pca_components, random_state=42)
        train_feats = pca.fit_transform(train_feats)
        test_feats = pca.transform(test_feats)

    # fit survival model
    cox = CoxPHSurvivalAnalysis(alpha=0.1)
    cox.fit(train_feats, y_train)
    log("--trained")

    # risk scores for both parts of the split
    train_risk = cox.predict(train_feats)
    test_risk = cox.predict(test_feats)

    # survival probabilities on a daily grid, 1 year to 5 year
    eval_times = np.arange(365, 1826)
    test_surv_fns = cox.predict_survival_function(test_feats)
    test_surv_probs = np.asarray([surv_fn(eval_times) for surv_fn in test_surv_fns])
    log("--computed survival probabilities")

    # evaluate predictions; only the first returned value is kept
    c_index, *_ = concordance_index_censored(
        event_indicator=y_test["Status"],
        event_time=y_test["Survival_in_days"],
        estimate=test_risk,
    )
    log("--computed c-index")

    ibs = integrated_brier_score(
        survival_train=y_train,
        survival_test=y_test,
        estimate=test_surv_probs,
        times=eval_times,
    )
    log("--computed IBS")

    # keep the second returned value
    # NOTE(review): per the scikit-survival docs this is the mean AUC over `times` - confirm
    cd_auc = cumulative_dynamic_auc(
        survival_train=y_train,
        survival_test=y_test,
        estimate=test_risk,
        times=eval_times,
    )[1]
    log("--computed cd-AUC")

    return {
        "c_index": c_index,
        "ibs": ibs,
        "cd_auc": cd_auc,
        "y_test_pred": test_risk,
        "y_train_pred": train_risk,
        "model": cox,
    }

def run_unimodal_split(
    *,  # enforce kwargs
    X: np.ndarray,
    y: np.ndarray,
    test_idxs: np.ndarray,
    train_idxs: np.ndarray,
    pca_components: int | None,
    standardize: bool,
    name: str = "",
    verbose: bool = False,
) -> dict:
    """Slice ``X`` and ``y`` into train/test parts and evaluate them with ``run_split``.

    Args:
        X: Feature matrix for all cases.
        y: Survival targets for all cases, indexed the same way as ``X``.
        test_idxs: Indices of the test cases.
        train_idxs: Indices of the training cases.
        pca_components: Forwarded to ``run_split``.
        standardize: Forwarded to ``run_split``.
        name: Forwarded to ``run_split``.
        verbose: Forwarded to ``run_split``.

    Returns:
        The result dict produced by ``run_split``.
    """
    return run_split(
        X_train=X[train_idxs],
        y_train=y[train_idxs],
        X_test=X[test_idxs],
        y_test=y[test_idxs],
        pca_components=pca_components,
        standardize=standardize,
        name=name,
        verbose=verbose,
    )

def powerset(s):
    """Return an iterator over every subset of *s*, each as a tuple.

    Subsets are produced in order of increasing size, starting with the empty
    tuple; within one size they follow ``itertools.combinations`` order.

    Args:
        s: Any iterable. It is materialised into a list first, so one-shot
            iterators and generators (which have no ``len``) are accepted too.

    Returns:
        An iterator of tuples.
    """
    items = list(s)
    return chain.from_iterable(combinations(items, r) for r in range(len(items) + 1))

In [None]:
def run_experiment(pca_components: int | None | dict[str, int | None]) -> list[dict]:
    """Run the unimodal and late-fusion survival models on every CV split.

    For each split in the notebook-level ``test_splits``, one Cox model is
    fitted per modality (``demo``, ``canc``, ``expr``, ``hist``, ``text``).
    Every combination of two or more modalities then gets its own Cox model
    whose inputs are the z-scored unimodal risk scores of its members.

    Reads the notebook-level globals ``test_splits``, ``n``, ``y``, ``demo_X``,
    ``canc_X``, ``expr_X``, ``hist_X`` and ``text_X``.

    Args:
        pca_components: Number of PCA components for the embedding modalities.
            Either a single value (an int, or ``None`` for no PCA) applied to
            ``expr``, ``hist`` and ``text``, or a dict with those three keys.
            ``demo`` and ``canc`` are never reduced or standardised.

    Returns:
        A list with one dict per split. Keys are modality names and
        ``"-"``-joined sorted modality combinations; values are the result
        dicts returned by ``run_split``.

    Raises:
        KeyError: If a dict is given without an ``expr``, ``hist`` or ``text`` key.
    """
    if isinstance(pca_components, dict):
        pca_map = pca_components
    else:
        pca_map = dict.fromkeys(("expr", "hist", "text"), pca_components)

    # modality -> (feature matrix, PCA components, standardize?)
    unimodal_configs = {
        "demo": (demo_X, None, False),
        "canc": (canc_X, None, False),
        "expr": (expr_X, pca_map["expr"], True),
        "hist": (hist_X, pca_map["hist"], True),
        "text": (text_X, pca_map["text"], True),
    }

    # The combinations do not depend on the split, so build them once.
    combos = [sorted(x) for x in powerset(["demo", "canc", "expr", "hist", "text"]) if len(x) > 1]
    all_idxs = np.arange(n)

    results = []
    for test_idxs in tqdm(test_splits, desc="Cross Validation Splits"):
        split_results = dict()
        train_idxs = np.setdiff1d(all_idxs, test_idxs)

        for modality, (mode_X, mode_pca, mode_standardize) in unimodal_configs.items():
            split_results[modality] = run_unimodal_split(
                X=mode_X,
                y=y,
                test_idxs=test_idxs,
                train_idxs=train_idxs,
                pca_components=mode_pca,
                standardize=mode_standardize,
            )

        y_train, y_test = y[train_idxs], y[test_idxs]

        # z-score each unimodal risk once per split (scaler fitted on the
        # training risks only) instead of once per combination.
        scaled_risks = dict()
        for modality in unimodal_configs:
            scaler = StandardScaler()
            train_risk = scaler.fit_transform(split_results[modality]["y_train_pred"][:, np.newaxis])
            test_risk = scaler.transform(split_results[modality]["y_test_pred"][:, np.newaxis])
            scaled_risks[modality] = (train_risk, test_risk)

        for combo in combos:
            # np.concatenate rather than np.concat: the latter needs NumPy >= 2.0.
            mult_X_train = np.concatenate([scaled_risks[m][0] for m in combo], axis=1)
            mult_X_test = np.concatenate([scaled_risks[m][1] for m in combo], axis=1)

            split_results["-".join(combo)] = run_split(
                X_train=mult_X_train,
                y_train=y_train,
                X_test=mult_X_test,
                y_test=y_test,
                pca_components=None,
                standardize=False,
            )

        results.append(split_results)
    return results

### Run mixed raw/reduced experiment

In [None]:
# Mixed setting: expr and hist embeddings are used at full dimensionality
# (no PCA); only the text embeddings are reduced, to 256 components.
mixed = run_experiment(pca_components={
    "expr": None,
    "hist": None,
    "text": 256,
})
# `mixed` is a list of dicts holding fitted models, so NumPy stores it as a
# pickled object array; reading it back needs np.load(..., allow_pickle=True).
np.save("../results/rebuttal_mixed_raw_embeddings_predictions_summarized.npy", mixed)

In [None]:
# Mean test C-index over the 5 folds for every non-empty modality combination
# of the mixed experiment.
combos = ["-".join(sorted(x)) for x in powerset(["demo", "canc", "expr", "hist", "text"]) if len(x) > 0]
df = defaultdict(dict)
for combo in combos:
    fold_c_indices = [mixed[i][combo]["c_index"] for i in range(5)]
    df[combo]["Full Expr/Hist"] = np.mean(fold_c_indices)
df = pd.DataFrame.from_dict(df, orient="index")
# Order rows by key length, ties broken alphabetically (the sort is stable).
sorted_keys = sorted(sorted(combos), key=len)
df = df.loc[sorted_keys]

In [None]:
# Add the values from a previously saved results table as a comparison column.
# NOTE(review): assumes results_summarized.csv is indexed by modality combination
# and has one column per PCA size ("4", ..., "256") - confirm against the file.
orig = pd.read_csv("../results/results_summarized.csv", index_col=0)
# Strip the trailing "*" from these three index labels so they match df's index.
orig = orig.rename(index={"canc*": "canc", "demo*": "demo", "canc-demo*": "canc-demo"})
df.loc[df.index, "PCA=256"] = orig.loc[df.index, "256"]
# For these three rows the value is taken from column "4" instead of "256".
# presumably because the saved table only fills "4" for models that use no PCA - verify
df.loc[["canc", "demo", "canc-demo"], "PCA=256"] = orig.loc[["canc", "demo", "canc-demo"], "4"]

In [None]:
# Markdown table: the five unimodal models and the full five-modality fusion, both settings.
print(df.loc[["canc", "demo", "expr", "hist", "text", "canc-demo-expr-hist-text"], ["PCA=256", "Full Expr/Hist"]].to_markdown(floatfmt="0.3f"))

In [None]:
# expr vs hist overfitting
# Compare the mean train and test C-index of the unreduced expr / hist models
# stored in `mixed`.
train_vs_test = defaultdict(dict)
all_idxs = np.arange(n)
for mode_X, mode_name in [(expr_X, "expr"), (hist_X, "hist")]:
    fold_scores = {"train": [], "test": []}
    for i, test_idxs in enumerate(test_splits):
        train_idxs = np.setdiff1d(all_idxs, test_idxs)

        # Rebuild the standardized inputs; these models were fitted in the
        # mixed experiment with standardization and without PCA.
        scaler = StandardScaler()
        scaled_inputs = {
            "train": scaler.fit_transform(mode_X[train_idxs]),
            "test": scaler.transform(mode_X[test_idxs]),
        }
        targets = {"train": y[train_idxs], "test": y[test_idxs]}

        model = mixed[i][mode_name]["model"]
        for part in ("train", "test"):
            part_c_index = concordance_index_censored(
                event_indicator=targets[part]["Status"],
                event_time=targets[part]["Survival_in_days"],
                estimate=model.predict(scaled_inputs[part]),
            )[0]
            fold_scores[part].append(part_c_index)

    for part, scores in fold_scores.items():
        train_vs_test[mode_name][part] = np.asarray(scores).mean()

train_vs_test = pd.DataFrame(train_vs_test).T
print(train_vs_test.to_markdown(floatfmt="0.3f"))

### Run the usual experiments with new metrics

In [None]:
# Sweep over PCA sizes; each entry holds the per-split results of one full
# experiment. Filling the dict in a loop keeps finished runs if a later one fails.
results = dict()
for pca_components in tqdm([4, 8, 16, 32, 64, 128, 256]):
    results[pca_components] = run_experiment(pca_components=pca_components)
# Stored as a pickled object array; reading it back needs np.load(..., allow_pickle=True).
np.save("../results/rebuttal_predictions_summarized.npy", results)

In [None]:
# Models that never go through PCA; they are reported once, under the column 4.
non_reduced = {"demo", "canc", "canc-demo"}
combos = ["-".join(sorted(x)) for x in powerset(["demo", "canc", "expr", "hist", "text"]) if len(x) > 0]
pca_grid = [4, 8, 16, 32, 64, 128, 256]

# One table per metric: rows are modality combinations, columns are PCA sizes,
# values are means over the 5 folds.
metric_dfs = dict()
for metric in ["c_index", "ibs", "cd_auc"]:
    df = defaultdict(dict)
    for pca_components in tqdm(pca_grid):
        for combo in combos:
            if combo in non_reduced and pca_components != 4:
                continue
            fold_values = [results[pca_components][i][combo][metric] for i in range(5)]
            df[combo][pca_components] = np.mean(fold_values)
    df = pd.DataFrame.from_dict(df, orient="index")
    # Order rows by key length, ties broken alphabetically (the sort is stable).
    sorted_keys = sorted(sorted(combos), key=len)
    df = df.loc[sorted_keys]
    df.columns.name = "pca components"
    df.to_csv(f"../results/rebuttal_{metric}_results_summarized.csv")
    metric_dfs[metric] = df

In [None]:
# C-index table for the five unimodal models and the full five-modality fusion.
print(metric_dfs["c_index"].loc[["canc", "demo", "expr", "hist", "text", "canc-demo-expr-hist-text"]].to_markdown(floatfmt="0.3f"))

### Extra Metrics (IBS and C/D-AUC)

In [None]:
# Integrated Brier score table for the same six models.
print(metric_dfs["ibs"].loc[["canc", "demo", "expr", "hist", "text", "canc-demo-expr-hist-text"]].to_markdown(floatfmt="0.3f"))

In [None]:
# Cumulative/dynamic AUC table for the same six models.
print(metric_dfs["cd_auc"].loc[["canc", "demo", "expr", "hist", "text", "canc-demo-expr-hist-text"]].to_markdown(floatfmt="0.3f"))

### Multimodal Model Hazard Ratios for Each Modality

In [None]:
# Coefficients of the five-modality fusion model (PCA=256) in each fold.
# Fusion inputs are stacked in sorted modality order, which matches the labels below.
fusion_key = "canc-demo-expr-hist-text"
coefs = np.asarray([results[256][i][fusion_key]["model"].coef_ for i in range(5)])
coefs = pd.DataFrame(
    data=coefs,
    index=pd.Index(np.arange(5), name="Split"),
    columns=["Canc", "Demo", "Expr", "Hist", "Text"],
).T

In [None]:
# Per-modality mean of the coefficients across the five splits.
coefs["Mean"] = coefs.mean(axis=1)

In [None]:
# Exponentiate the coefficients to obtain hazard ratios. The "Mean" column is
# the exponential of the mean coefficient, i.e. the geometric mean of the
# per-split hazard ratios, not their arithmetic mean.
HRs = np.exp(coefs)

In [None]:
# Markdown table of the hazard ratios (rows: modalities; columns: splits and mean).
print(HRs.to_markdown(floatfmt="0.3f"))