# Survival Experiments

This notebook contains all our code for survival modeling. It does not include the result visualizations.

The experiments test multimodal fusion of survival models and varying dimensionality reductions for high-dimensional embeddings.

Specifically, we experiment with 5 modalities:
* Patient demographics (sex, age - binned, race, ethnicity)
* Cancer type (we use the TCGA project ID as a proxy for cancer type)
* RNA-seq gene expression (`BulkRNABert` embeddings)
* Whole slide histology images (`UNI2` embeddings)
* Pathology reports (`BioMistral` embeddings)

We additionally experiment with using summarizations of the pathology reports. These are summarized by `Llama-3.1-8B-Instruct` and embedded with `BioMistral`. We provide a convenience toggle at the top of this notebook to run experiments with the summarized reports: `use_summarized`.

Run experiments by executing all cells of this notebook. Results are saved in the `results` subdirectory at the root of the repo. Analysis and visualization are done using tools that are also in the `results` folder.

In [None]:
# Switches the run to the summarized-report variant: it selects the report
# embedding file (`summ.h5` instead of `text.h5`) and the names of the saved
# result files.
use_summarized = False

In [None]:
from itertools import chain, combinations
from collections import defaultdict
import h5py
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.decomposition import PCA
from sklearn.model_selection import StratifiedKFold
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.metrics import concordance_index_censored

In [None]:
# Load the clinical table and collect, per data source, the set of case IDs
# it contains (the top-level keys of each HDF5 embedding store).
df = pd.read_csv("../data/clinical.csv")
clin_case_ids = set(df["case_id"])

# Report embeddings come from the summarized store when the toggle is on.
text_file = "../embed/summ.h5" if use_summarized else "../embed/text.h5"

with (
    h5py.File("../embed/expr.h5", "r") as expr_h5,
    h5py.File("../embed/hist.h5", "r") as hist_h5,
    h5py.File(text_file, "r") as text_h5,
):
    expr_case_ids = set(expr_h5.keys())
    hist_case_ids = set(hist_h5.keys())
    text_case_ids = set(text_h5.keys())

In [None]:
# Restrict the cohort to cases present in the clinical table and in every
# embedding store, in a deterministic (sorted) order.
shared_case_ids = clin_case_ids & expr_case_ids & hist_case_ids & text_case_ids
case_ids = sorted(shared_case_ids)

# Sort the clinical rows by case ID as well, so row i of `df` and entry i of
# `case_ids` refer to the same case as long as the IDs are unique.
df = (
    df.loc[df["case_id"].isin(case_ids)]
    .sort_values("case_id")
    .reset_index(drop=True)
)
assert df["case_id"].is_unique

In [None]:
# Discretize age into 20-year bins; the labels spell out each half-open
# interval between consecutive edges, e.g. "(0, 20]".
age_bin_edges = [0, 20, 40, 60, 80, 100]
age_bin_labels = [
    f"({low}, {high}]" for low, high in zip(age_bin_edges[:-1], age_bin_edges[1:])
]
df["age_binned"] = pd.cut(df["age"], bins=age_bin_edges, labels=age_bin_labels)

In [None]:
# Event indicator: True where the patient is recorded as dead.
dead = df["vital_status"].eq("Dead")

# Time to event: days to death for deceased patients, otherwise days to the
# last follow-up (censoring time).
death_days = df["days_to_death"]
follow_up_days = df["days_to_last_follow_up"]
days_to_event = np.where(dead, death_days, follow_up_days)

# Every case must have a usable time value.
assert not np.any(np.isnan(days_to_event))

In [None]:
# Structured survival target with one record per case: a boolean event flag
# and the time to event / censoring in days.
y = np.empty(len(dead), dtype=[('Status', '?'), ('Survival_in_days', '<f8')])
y["Status"] = dead.to_numpy()
y["Survival_in_days"] = days_to_event

In [None]:
# Both categorical modalities use identically configured one-hot encoders:
# dense float32 output, with one column dropped for binary features.
ohe_options = dict(drop="if_binary", sparse_output=False, dtype=np.float32)
demo_ohe = OneHotEncoder(**ohe_options)
canc_ohe = OneHotEncoder(**ohe_options)

In [None]:
# One-hot encode the demographic columns and the TCGA project column into the
# feature matrices of the two categorical modalities.
demo_features = df[["sex", "age_binned", "race", "ethnicity"]]
canc_features = df[["project"]]

demo_X = demo_ohe.fit_transform(demo_features)
canc_X = canc_ohe.fit_transform(canc_features)

In [None]:
# Show the categories the fitted encoder found for each demographic column.
demo_ohe.categories_

In [None]:
# Show the categories the fitted encoder found for the project column.
canc_ohe.categories_

In [None]:
def extract_case_emb_from_h5(case_ids: list[str], h5: h5py.File):
    """Build a per-case embedding matrix from an HDF5 store.

    For each case ID the matching group in ``h5`` is read, all datasets in
    that group are stacked, and their element-wise mean becomes the case
    embedding.

    Args:
        case_ids: Keys to look up in ``h5``; the output rows follow this order.
        h5: Open HDF5 file whose top-level keys are case IDs. Each entry is
            expected to expose ``.values()`` yielding datasets readable with
            ``[:]``; datasets within a group must share a shape so they can
            be stacked.

    Returns:
        Array with one row (the mean embedding) per entry of ``case_ids``.
    """

    def _mean_embedding(group):
        # Load every dataset of the case fully, then average across datasets.
        stacked = np.stack([dataset[:] for dataset in group.values()], axis=0)
        return stacked.mean(axis=0)

    per_case = [_mean_embedding(h5[case_id]) for case_id in tqdm(case_ids)]
    return np.stack(per_case, axis=0)

In [None]:
# Build one embedding matrix per high-dimensional modality; the rows of every
# matrix follow the order of `case_ids`.
with h5py.File("../embed/expr.h5", "r") as expr_h5:
    expr_X = extract_case_emb_from_h5(case_ids, expr_h5)

with h5py.File("../embed/hist.h5", "r") as hist_h5:
    hist_X = extract_case_emb_from_h5(case_ids, hist_h5)

# Read the report embeddings from `text_file`, the path chosen from
# `use_summarized` when the case IDs were collected. Hard-coding
# "../embed/text.h5" here made the summarized run silently train on the
# full-report embeddings.
with h5py.File(text_file, "r") as text_h5:
    text_X = extract_case_emb_from_h5(case_ids, text_h5)

In [None]:
# 5-fold cross-validation, stratified jointly on outcome, cancer type and
# demographics so every fold has a similar mix of each.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# One composite label per case. `vital_status` used to be concatenated twice;
# a repeated column cannot change which rows share a label, so it is included
# only once. `age_binned` is categorical and must be cast to str before
# string concatenation.
splitter = (
    df["vital_status"]
    + "_"
    + df["project"]
    + "_"
    + df["sex"]
    + "_"
    + df["age_binned"].astype(str)
    + "_"
    + df["race"]
    + "_"
    + df["ethnicity"]
)

n = len(df)
# Only the test indices of each fold are kept; the training indices are the
# complement and are derived where the folds are consumed. The X passed to
# `split` is a placeholder of the right length.
test_splits = [split_idxs for _, split_idxs in skf.split(X=np.zeros(n), y=splitter)]

In [None]:
def run_split(
    *,  # keyword-only arguments
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_test: np.ndarray,
    y_test: np.ndarray,
    pca_components: int | None,
    standardize: bool,
    name: str = "",
    verbose: bool = False,
) -> dict:
    """Fit a Cox proportional hazards model on one split and score it.

    Optional preprocessing (z-scoring, then PCA) is fitted on the training
    features only and applied unchanged to the test features.

    Args:
        X_train: Training feature matrix.
        y_train: Structured survival target for the training rows.
        X_test: Test feature matrix.
        y_test: Structured survival target for the test rows; the fields
            "Status" and "Survival_in_days" are read for evaluation.
        pca_components: Number of PCA components to keep, or None to skip
            dimensionality reduction.
        standardize: Whether to z-score the features before PCA / fitting.
        name: Label used only in verbose output.
        verbose: Print which steps are executed.

    Returns:
        Dict with the test concordance index ("c_index") and the model
        predictions on the test ("y_test_pred") and training
        ("y_train_pred") rows.
    """
    if verbose:
        print(f"Running {name}")

    # Start from the raw features; each enabled step replaces both matrices.
    train_features, test_features = X_train, X_test

    if standardize:
        if verbose:
            print("--standardized")
        scaler = StandardScaler()
        train_features = scaler.fit_transform(train_features)
        test_features = scaler.transform(test_features)

    if pca_components is not None:
        if verbose:
            print("--reduced")
        pca = PCA(n_components=pca_components, random_state=42)
        train_features = pca.fit_transform(train_features)
        test_features = pca.transform(test_features)

    # Survival model with a fixed regularization strength.
    cox = CoxPHSurvivalAnalysis(alpha=0.1)
    cox.fit(train_features, y_train)

    y_train_pred = cox.predict(train_features)
    y_test_pred = cox.predict(test_features)

    # Only the first element of the returned tuple (the c-index) is kept.
    c_index, *_ = concordance_index_censored(
        event_indicator=y_test["Status"],
        event_time=y_test["Survival_in_days"],
        estimate=y_test_pred,
    )

    return {
        "c_index": c_index,
        "y_test_pred": y_test_pred,
        "y_train_pred": y_train_pred,
    }

def run_unimodal_split(
    *,  # keyword-only arguments
    X: np.ndarray,
    y: np.ndarray,
    test_idxs: np.ndarray,
    train_idxs: np.ndarray,
    pca_components: int | None,
    standardize: bool,
    name: str = "",
    verbose: bool = False,
) -> dict:
    """Run `run_split` for one modality on the given train/test indices.

    Args:
        X: Full feature matrix of the modality (one row per case).
        y: Full structured survival target, row-aligned with ``X``.
        test_idxs: Row indices forming the test set.
        train_idxs: Row indices forming the training set.
        pca_components: Forwarded to `run_split`.
        standardize: Forwarded to `run_split`.
        name: Forwarded to `run_split`.
        verbose: Forwarded to `run_split`.

    Returns:
        The result dict produced by `run_split`.
    """
    # Slice features and targets with the same indices so they stay aligned.
    return run_split(
        X_train=X[train_idxs],
        y_train=y[train_idxs],
        X_test=X[test_idxs],
        y_test=y[test_idxs],
        pca_components=pca_components,
        standardize=standardize,
        name=name,
        verbose=verbose,
    )

def powerset(s):
    """Return an iterator over all subsets of ``s`` as tuples.

    Subsets are produced by increasing size, starting with the empty tuple
    and ending with the tuple of all elements.
    """
    subsets_by_size = (combinations(s, size) for size in range(len(s) + 1))
    return chain.from_iterable(subsets_by_size)

In [None]:
def run_experiment(pca_components: int) -> list[dict]:
    """Run all unimodal and late-fusion models over every CV fold.

    For each fold in `test_splits`, a Cox model is fitted per modality. Every
    combination of two or more modalities is then fused by fitting a second
    Cox model on the unimodal predictions, one column per modality.

    Args:
        pca_components: Number of PCA components used for the embedding
            modalities ("expr", "hist", "text"). The one-hot modalities
            ("demo", "canc") are neither standardized nor reduced.

    Returns:
        One dict per fold, in the order of `test_splits`. Keys are modality
        names and "-"-joined sorted modality combinations; values are the
        result dicts of `run_split`.
    """
    modalities = ["demo", "canc", "expr", "hist", "text"]
    one_hot_modalities = {"demo", "canc"}
    features = {
        "demo": demo_X,
        "canc": canc_X,
        "expr": expr_X,
        "hist": hist_X,
        "text": text_X,
    }

    # Fold-independent, so computed once instead of once per fold.
    combos = [sorted(x) for x in powerset(modalities) if len(x) > 1]
    all_idxs = np.arange(n)

    results = []
    for test_idxs in tqdm(test_splits, desc="Cross Validation Splits"):
        split_results = dict()

        # Training rows are the sorted complement of the test rows.
        train_idxs = np.setdiff1d(all_idxs, test_idxs)
        y_train, y_test = y[train_idxs], y[test_idxs]

        for modality in modalities:
            is_embedding = modality not in one_hot_modalities
            split_results[modality] = run_unimodal_split(
                X=features[modality],
                y=y,
                test_idxs=test_idxs,
                train_idxs=train_idxs,
                pca_components=pca_components if is_embedding else None,
                standardize=is_embedding,
            )

        for combo in combos:
            mult_X_train = []
            mult_X_test = []
            for modality in combo:
                # The fusion features are the unimodal predictions; the
                # training column holds predictions on the rows the unimodal
                # model was fitted on.
                x_train = split_results[modality]["y_train_pred"][:, np.newaxis]
                x_test = split_results[modality]["y_test_pred"][:, np.newaxis]
                if modality not in one_hot_modalities:
                    # Embedding-based predictions are z-scored with statistics
                    # from the training column only.
                    scaler = StandardScaler()
                    x_train = scaler.fit_transform(x_train)
                    x_test = scaler.transform(x_test)
                mult_X_train.append(x_train)
                mult_X_test.append(x_test)

            # `np.concatenate` rather than `np.concat`: the latter only exists
            # in NumPy >= 2.0.
            mult_X_train = np.concatenate(mult_X_train, axis=1)
            mult_X_test = np.concatenate(mult_X_test, axis=1)

            split_results["-".join(combo)] = run_split(
                X_train=mult_X_train,
                y_train=y_train,
                X_test=mult_X_test,
                y_test=y_test,
                pca_components=None,
                standardize=False,
            )

        results.append(split_results)
    return results

In [None]:
# Sweep over the number of PCA components; `results` maps each setting to the
# list of per-fold result dicts returned by `run_experiment`.
pca_grid = [4, 8, 16, 32, 64, 128, 256]
results = {}
for pca_components in tqdm(pca_grid):
    results[pca_components] = run_experiment(pca_components=pca_components)

In [None]:
# Persist the raw per-fold results. `results` is a plain dict, so NumPy stores
# it as a pickled object array (loading it back requires allow_pickle=True).
predictions_path = (
    "../results/predictions_summarized.npy"
    if use_summarized
    else "../results/predictions.npy"
)
np.save(predictions_path, results)

In [None]:
# Summary table: mean test c-index over the CV folds, one row per modality
# combination and one column per PCA setting.
combos = ["-".join(sorted(x)) for x in powerset(["demo", "canc", "expr", "hist", "text"]) if len(x) > 0]

# Take the PCA settings and the folds from `results` itself rather than
# repeating the grid and the fold count from the cells that produced it.
pca_settings = list(results)
first_pca_setting = pca_settings[0]

mean_c_index = defaultdict(dict)
for pca_components in tqdm(pca_settings):
    fold_results = results[pca_components]
    for combo in combos:
        c_idx = np.mean([fold[combo]["c_index"] for fold in fold_results])
        row_label = combo
        if combo in ["demo", "canc"]:
            # These modalities are fitted without PCA, so they are reported
            # once (under the first PCA setting) and marked with "*".
            row_label = combo + "*"
            if pca_components != first_pca_setting:
                continue
        mean_c_index[row_label][pca_components] = c_idx

df = pd.DataFrame.from_dict(mean_c_index, orient="index")

# Order rows alphabetically within increasing label length. All five
# single-modality names have four characters, so the first two entries are
# "canc" and "demo"; they are swapped for their starred row labels.
sorted_keys = sorted(sorted(combos), key=lambda x: len(x))
sorted_keys = ["canc*", "demo*"] + sorted_keys[2:]
df = df.loc[sorted_keys]
df.columns.name = "pca components"
if not use_summarized:
    df.to_csv("../results/results.csv")
else:
    # In the summarized run the "text" modality is the summarized reports.
    df.index = df.index.str.replace("text", "summ")
    df.to_csv("../results/results_summarized.csv")
df