## Main Cohort Metrics and Plots

In [None]:
# !pip install seaborn==0.12.2 lifelines scikit-learn==1.1.3 python-dotenv torchtuples pymongo==3.12.0 scikit-survival pycox seaborn

In [None]:
import os
import random
import sys

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from dotenv import load_dotenv
from joblib import load
from sklearn import metrics
import scipy.stats as st
from lifelines.utils import concordance_index

# Make the sibling ``scripts`` directory importable from this notebook.
module_path = os.path.abspath("../scripts")
if module_path not in sys.path:
    sys.path.append(module_path)


# Pull environment settings (including RANDOM_SEED) from the local .env file.
load_dotenv()

# Read the seed once and feed the same value to every random number generator.
_seed = int(os.getenv("RANDOM_SEED"))

torch.manual_seed(_seed)
# Ask cuDNN for deterministic kernels and switch off its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.cuda.manual_seed(_seed)
random.seed(_seed)
np.random.seed(_seed)

from run_models import CauseSpecificNet, DeepHit, get_preprocessed_datasets
from utils import VTEDataLoader, get_logger, get_parent_dir, plot_roc, plot_calibration, plot_grouped_risks, bootstrap_ci
from vte_deephit import get_datasets, get_best_params


In [None]:
from matplotlib import cycler

# Ten 8-digit hex colours used as the default colour cycle for every plot.
color_list = [
    "#E64B35FF",
    "#4DBBD5FF",
    "#00A087FF",
    "#3C5488FF",
    "#F39B7FFF",
    "#8491B4FF",
    "#91D1C2FF",
    "#DC0000FF",
    "#7E6148FF",
    "#B09C85FF",
]

# Shared figure styling, applied in one go.
matplotlib.rcParams.update(
    {
        "font.family": "Arial",
        "axes.prop_cycle": cycler(color=color_list),
        "font.size": 18,
        "axes.linewidth": 2,
    }
)

In [None]:
# Notebook-wide logger, created through the project's get_logger helper.
logger = get_logger("insights-notebook")

In [None]:
# Get Datasets
# `data` is the loader's raw table; `datasets` is the mapping whose entries are
# pulled out with .get() further down.
dl = VTEDataLoader()
data = dl.raw_data
datasets = get_datasets();

In [None]:
# Median of the OBS_TIME column in the raw data.
data.OBS_TIME.median()

In [None]:
# Shape of the raw data table.
data.shape

In [None]:
# Shape of the subset with OBS_TIME <= 180 and EVENT equal to 1 or 3.
# NOTE(review): the meaning of event code 3 is not visible here — TODO confirm.
data[(data.OBS_TIME<=180) & ((data.EVENT==1) | (data.EVENT==3))].shape

In [None]:
# Number of rows with EVENT_6 == 1, then the same count as a fraction of all rows.
print((data.EVENT_6 == 1).sum())
(data.EVENT_6 == 1).sum()/len(data)

In [None]:
# Row counts per SEX value.
data.SEX.value_counts()

In [None]:
# create all datasets
# Each name below is bound to the entry stored under the identically named key.
(
    x_train,
    x_test,
    y_train,
    y_train_6,
    y_test,
    x_train_ks,
    x_test_ks,
    y_test_ks,
    labtrans,
    labtrans_6,
) = map(
    datasets.get,
    (
        "x_train",
        "x_test",
        "y_train",
        "y_train_6",
        "y_test",
        "x_train_ks",
        "x_test_ks",
        "y_test_ks",
        "labtrans",
        "labtrans_6",
    ),
)

In [None]:
# Shape of the training feature table.
x_train.shape

In [None]:
# Shape of the test feature table.
x_test.shape

In [None]:
# Row counts per KS value in the KS test table.
x_test_ks.KS.value_counts()

In [None]:
# Missing Data
# Percentage of rows with a missing value, shown only for covariates that have any.
covars = [
    "AGE",
    "SEX",
    "CANCER_TYPE_FINAL",
    "SAMPLE_TYPE",
    "DX_delta",
    "PROC_delta",
    "SODIUM",
    "POTASSIUM",
    "CHLORIDE",
    "CALCIUM",
    "CO2",
    "GLUCOSE",
    "UREA",
    "CREATININE",
    "TPROTEIN",
    "AST",
    "ALT",
    "TBILI",
    "ALKPHOS",
    "ALBUMIN",
    "HB",
]
missing_counts = data[covars].isna().sum()
100 * missing_counts[missing_counts > 0] / len(data)

In [None]:
# Row counts per SAMPLE_TYPE in the training features.
x_train.SAMPLE_TYPE.value_counts()

In [None]:
# Row counts per SAMPLE_TYPE in the test features.
x_test.SAMPLE_TYPE.value_counts()

In [None]:
# Row counts per SAMPLE_TYPE in the raw data.
data.SAMPLE_TYPE.value_counts()

In [None]:
# Row counts per KS value in the KS test table (same expression as an earlier cell).
x_test_ks.KS.value_counts()

In [None]:
# Keep only the KS test rows that have a KS value, then attach the outcome columns.
ks_metrics = x_test_ks[x_test_ks.KS.notna()].copy()
print(ks_metrics.shape)
# NOTE(review): assumes y_test_ks[0] / y_test_ks[1] line up with the filtered rows
# (same length or index-aligned) — TODO confirm.
ks_metrics["OBS_TIME"] = y_test_ks[0]
ks_metrics["EVENT"] = y_test_ks[1]

In [None]:
# Bootstrap lifelines' concordance_index for KS on the rows prepared above.
# NOTE(review): how the column names ("EVENT", "KS", "OBS_TIME") are mapped onto the
# metric's arguments is decided inside utils.bootstrap_ci, which is not shown here —
# confirm against that helper.
lower_cidx, upper_cidx, mean_cidx, cidxs = bootstrap_ci(ks_metrics, concordance_index, "EVENT", "KS", "OBS_TIME")

In [None]:
# Report the bootstrap result as "mean (lower, upper)".
print(f"{mean_cidx} ({lower_cidx}, {upper_cidx})")

In [None]:
# Only a cell's final expression is echoed, so this length is evaluated but not shown.
len(cidxs)

# Distribution plot of the values in cidxs.
sns.displot(cidxs)

In [None]:
# Histogram of OBS_TIME for rows with OBS_TIME > 0 and EVENT == 0.
sns.histplot(dl.raw_data[(dl.raw_data.OBS_TIME>0)&(dl.raw_data.EVENT==0)]["OBS_TIME"])

In [None]:
# Fraction of rows taking each EVENT_6 value.
dl.raw_data.EVENT_6.value_counts()/len(dl.raw_data.EVENT_6)

In [None]:
# Whole days from TM_DX_DTE to REPORT_DTE for every row.
dx_to_cohort_entry = (data.REPORT_DTE - data.TM_DX_DTE).dt.days

In [None]:
# Summary statistics of the day differences computed above.
dx_to_cohort_entry.describe()

In [None]:
# Bin edges in days: quarter, half-year, one- and two-year marks, then one edge per
# further year; the last edge is 10 * 365 with no extra day.
bins = (
    [0, 91, 181, 366, 731]
    + [years * 365 + 1 for years in range(3, 10)]
    + [10 * 365]
)

In [None]:
# Bar chart of patient counts by time from diagnosis to cohort entry.
# Every bar but the last covers the half-open range [bins[i], bins[i + 1]);
# the last bar collects everything at or beyond the final edge.
fig = plt.figure(figsize=(10, 10))
last = len(bins) - 1
for i, lower in enumerate(bins):
    if i == last:
        in_bin = dx_to_cohort_entry >= lower
        label = f">{lower // 365} year"
    else:
        upper = bins[i + 1]
        # The lower edge is inclusive: with a strict ">" a value sitting exactly
        # on an edge (0, 91, 181, ...) matched no bar at all, although labels
        # such as "0-90" include it.
        in_bin = (dx_to_cohort_entry >= lower) & (dx_to_cohort_entry < upper)
        if lower > 365:
            label = f"{lower // 365}-{upper // 365} year"
        else:
            label = f"{lower}-{upper - 1}"
    height = int(in_bin.sum())
    plt.bar(label, height)
plt.xticks(rotation=90)
plt.xlabel("Days to Cohort Entry from Diagnosis")
plt.ylabel("Patient Count")
plt.tight_layout()
plt.savefig(
    get_parent_dir() / "visualizations/dx-to-cohort-entry.svg", dpi=300, format="svg"
)

### Persist test and train data with Audit SEQ

In [None]:
# Training-set AUDIT_SEQ ids next to both label sets; index 0 of each label pair
# feeds the OBS_TIME* column and index 1 the EVENT* column.
train_seq = pd.DataFrame(
    {
        "train": x_train.AUDIT_SEQ.values,
        "EVENT": y_train[1],
        "EVENT_6": y_train_6[1],
        "OBS_TIME": y_train[0],
        "OBS_TIME_6": y_train_6[0],
    }
)

In [None]:
# Test-set AUDIT_SEQ ids next to the test labels (y_test[1] -> EVENT_6,
# y_test[0] -> OBS_TIME_6).
test_seq = pd.DataFrame(
    {
        "test": x_test.AUDIT_SEQ.values,
        "EVENT_6": y_test[1],
        "OBS_TIME_6": y_test[0],
    }
)

In [None]:
# print(train_seq.head())
# print(test_seq.head())

# train_seq.to_csv(get_parent_dir() / "assets/data_asset/train_seq.csv", index=None)
# test_seq.to_csv(get_parent_dir() / "assets/data_asset/test_seq.csv", index=None)

## 1. List of genes in the models

In [None]:
# First five training column names that end in "_alt".
list(x_train.filter(regex="_alt$").columns)[:5]

## 2. Kaplan-Meier estimate for the whole dataset

In [None]:
from lifelines import KaplanMeierFitter, AalenJohansenFitter
from lifelines.plotting import add_at_risk_counts

In [None]:
def plot_compare_cif(data, event_col, duration_col):
    """Plot cumulative incidence of event 1 with and without competing risks.

    Fits a Kaplan-Meier estimator (rows whose event code equals 1 are events,
    all other rows are treated as censored) and an Aalen-Johansen estimator
    (event of interest 1) on the same durations, draws both cumulative-density
    curves on one axes and saves the figure under ``visualizations/``.

    Args:
        data: DataFrame holding the event-code and duration columns.
        event_col: Name of the event-code column.
        duration_col: Name of the duration column.

    Returns:
        Tuple ``(plt, kmf, ajf)``: the pyplot module, the fitted
        ``KaplanMeierFitter`` and the fitted ``AalenJohansenFitter``.
    """
    # Work on a copy so the caller's frame is left untouched.
    data_competing_risks = data.copy()
    durations = data_competing_risks[duration_col]
    events = data_competing_risks[event_col]

    # Kaplan-Meier: only event code 1 counts as an observed event.
    kmf = KaplanMeierFitter()
    kmf.fit_right_censoring(durations, events == 1, label='Kaplan Meier')

    # Aalen-Johansen: cumulative incidence of event 1 in the presence of the
    # other event codes.
    ajf = AalenJohansenFitter()
    ajf.fit_right_censoring(durations.values, events.values, 1, label='Competing Risk')

    # Plot the results
    fig, ax = plt.subplots(figsize=(16, 12))

    kmf.plot_cumulative_density(ax=ax, ci_show=True)
    add_at_risk_counts(kmf, labels=['VTE'], ax=ax)
    ajf.plot(ax=ax, ci_show=True)

    ax.set_title('Risk Comparison (with and without Competing Risk)')
    # Durations are plotted exactly as supplied; this notebook reads the fits at
    # t=180 for its "6 Months" estimate, so the axis unit is days.
    ax.set_xlabel('Time (in Days)')
    ax.set_ylabel('Cumulative Incidence')
    ax.legend(loc="upper left")

    # File name kept exactly as before so existing references to the artifact
    # keep working.
    plt.savefig(get_parent_dir() / "visualizations/comapre_km_cr.svg", dpi=300, format="svg", bbox_inches='tight')
    return (plt, kmf, ajf)


In [None]:
# Fit and plot both estimators on the raw data; the fitted objects are reused below.
compare_cr_plot, kmfit, ajffit = plot_compare_cif(data, "EVENT", "OBS_TIME")

In [None]:
# Kaplan-Meier cumulative density evaluated at t = 180.
km_6 = kmfit.cumulative_density_at_times(180).values[0]

In [None]:
# Kaplan-Meier cumulative density evaluated at t = 1064.
km_end = kmfit.cumulative_density_at_times(1064).values[0]

In [None]:
# Aalen-Johansen cumulative density at index label 180.0. `.loc` is an exact-label
# lookup, so this raises KeyError if 180.0 is not in the fitted index.
ajffit.cumulative_density_.loc[180.0].values[0]

In [None]:
# Aalen-Johansen cumulative density at index label 180.0 (exact-label lookup).
ajf_6 = ajffit.cumulative_density_.loc[180.0].values[0]

In [None]:
# Aalen-Johansen cumulative density at index label 1064.0 (exact-label lookup).
ajf_end = ajffit.cumulative_density_.loc[1064.0].values[0]

In [None]:
# 2x2 table: rows KM / CR, columns "6 Months" / "End of Study".
estimates = pd.DataFrame({"6 Months": [km_6, ajf_6], "End of Study": [km_end, ajf_end]}, index=["KM", "CR"])

In [None]:
# get KM and AJ estimates
# Shown as percentages rounded to one decimal place.
(estimates*100).round(1)

In [None]:
# Persist the unscaled estimates table under results/.
estimates.to_csv(get_parent_dir() / "results/estimates_main.csv")

In [None]:
# Feature-set name; it selects the preprocessing and the models/<feature>/ directory below.
feature = "no_genes"

In [None]:
# Number of saved ensemble members to load.
n = 30
# Preprocess all four feature tables for the chosen feature set.
(
    feature_train,
    feature_test,
    feature_train_ks,
    feature_test_ks,
) = get_preprocessed_datasets(feature, x_train, x_test, x_train_ks, x_test_ks)

logger.info(f"Running for feature: {feature}")
# Saved keyword arguments for CauseSpecificNet.
params = load(get_parent_dir() / f"models/{feature}/params.pkl")
# Rebuild each ensemble member and restore its weights from model_<i>.pt.
models = []
for i in range(n):
    net = CauseSpecificNet(**params)
    m = DeepHit(net)
    m.load_model_weights(get_parent_dir() / f"models/{feature}/model_{i}.pt")
    models.append(m)

In [None]:
# Show the network held by the first loaded model.
models[0].net

## 3. CIF for the whole cohort using DH model

In [None]:
# Stack the train and test feature matrices row-wise.
full_feature = np.vstack([feature_train, feature_test])

In [None]:
# Shape of the stacked feature matrix.
full_feature.shape

In [None]:
# One CIF prediction per ensemble member on the stacked train + test features,
# averaged across members and scaled by 100.
cifs = [sm.predict_cif(full_feature) for sm in models]

cif = np.mean(cifs, axis=0, dtype=np.float32) * 100

In [None]:
# Shape of the averaged CIF array.
cif.shape

In [None]:
# First entry of cif as a DataFrame indexed by the first model's duration_index.
cif1 = pd.DataFrame(cif[0], models[0].duration_index)

In [None]:
# First entry along axis 0 of cif; treated as the VTE CIF going by the variable name.
vte_cif = cif[0]

In [None]:
# Shape of the VTE CIF array.
vte_cif.shape

In [None]:
# Mean and standard deviation of (100 - vte_cif) at row 180 across all columns.
# NOTE(review): 180 is a positional row index, not a time label; presumably it
# corresponds to day 180 on the duration grid — TODO confirm.
# NOTE(review): despite the `_test` suffix, vte_cif in the cells above comes from
# full_feature (train and test stacked).
m_test = (100 - vte_cif[180, :]).mean()
std_test = (100 - vte_cif[180, :]).std()

In [None]:
# Print the mean and standard deviation computed above.
print(m_test)
print(std_test)

In [None]:
# Echo the mean.
m_test

In [None]:
# Echo the standard deviation.
std_test

## 4. ROC plot at 6 months on the test set

In [None]:
# Ensemble-averaged CIF on the test features only (unscaled).
cifs = [sm.predict_cif(feature_test) for sm in models]

cif = np.mean(cifs, dtype=np.float32, axis=0)

In [None]:
# Density-normalised histogram (100 bins) of cif[0] at row 180, saved as SVG.
fig = plt.figure(figsize=(10, 10))
plt.hist(cif[0][180, :], density=True, bins=100)
plt.title("Distribution of CIF for VTE")
plt.savefig(get_parent_dir() / "visualizations/cif_density.svg", dpi=300, format="svg")

In [None]:
from sklearn.metrics import roc_auc_score

# Score each test patient with cif[0] at row 180 and compare it with the binary
# "event == 1" label.
vte_scores = cif[0][180, :]
is_vte = y_test[1] == 1

fpr, tpr, _ = metrics.roc_curve(is_vte, vte_scores)
auc = roc_auc_score(is_vte, vte_scores)

# Bootstrap interval for the AUC.
# NOTE(review): the EVENT column handed to bootstrap_ci holds the raw event codes,
# not the binary label used above — confirm how utils.bootstrap_ci treats them.
y_test_df = pd.DataFrame(
    {"OBS_TIME": y_test[0], "EVENT": y_test[1], "cif": vte_scores}
)
low, high, mean_auc, idxs = bootstrap_ci(y_test_df, roc_auc_score, "EVENT", "cif")

# create ROC curve
fig = plt.figure(figsize=(10, 10))
plt.plot(fpr, tpr, linestyle="--", lw=2, label="ROC curve", clip_on=False)
plt.plot([0, 1], [0, 1], linestyle="--")
plt.ylabel("True Positive Rate")
plt.xlabel("False Positive Rate")
print(auc)
plt.title("%s, AUC = %.2f (%.2f, %.2f)" % ("RoC Curve DH Model at 180 days", mean_auc, low, high))
plt.savefig(get_parent_dir() / "visualizations/deephit_auc.svg", dpi=300, format="svg")

In [None]:
# Ensemble-averaged CIF on the KS test features.
cifs_ks = [sm.predict_cif(feature_test_ks) for sm in models]

cif_ks = np.mean(cifs_ks, dtype=np.float32, axis=0)

## 5. DH results for patients with diagnosis time under a year

In [None]:
# Test feature rows whose DX_delta is at most 365.
# NOTE(review): assumes feature_test rows are in the same order as x_test — TODO confirm.
feature_test_dx_365 = feature_test[x_test["DX_delta"] <= 365]

In [None]:
from vte_deephit import c_stat

# Ensemble-averaged CIF for test patients with DX_delta <= 365, scored with the
# project's c_stat helper against the matching outcome rows.
within_year = x_test["DX_delta"] <= 365

cifs_dx_365 = [sm.predict_cif(feature_test_dx_365) for sm in models]
cif_test_dx_365 = np.mean(cifs_dx_365, dtype=np.float32, axis=0)

y_test_0 = y_test[0][within_year]
y_test_1 = y_test[1][within_year]

c_stat(
    cif_test_dx_365,
    y_test_0,
    y_test_1,
    models[0].duration_index,
    suffix="dx_365",
)

## 6. Khorana Score (KS) ROC

In [None]:
# Rows that have a KS value and a positive OBS_TIME_6_ks.
ks_patients = data[data.KS.notna() & (data.OBS_TIME_6_ks > 0)]

In [None]:
# ROC of KS against the EVENT_6_ks column for every KS patient.
plot_roc(
    ks_patients,
    "KS",
    "EVENT_6_ks",
    # The parenthesis after "Patients" is closed, as in the test/train titles.
    f"ROC curve (Khorana Score Patients)\n(n={ks_patients.shape[0]})",
    "KS_AUC",
)

In [None]:
# test set event KS
# Restrict the raw data to the persisted test-set AUDIT_SEQ ids, then keep rows
# that have a KS value and a positive OBS_TIME_6_ks.
test_audit_seq = pd.read_csv(get_parent_dir() / "assets/data_asset/test_seq.csv")
test_KS = data.merge(test_audit_seq["test"], left_on="AUDIT_SEQ", right_on="test")
# The comparison must be parenthesised: `&` binds tighter than `>`, so without the
# brackets the filter evaluated `(notna & OBS_TIME_6_ks) > 0`.
test_KS = test_KS[test_KS.KS.notna() & (test_KS.OBS_TIME_6_ks > 0)]

In [None]:
# Restrict the raw data to the persisted training-set AUDIT_SEQ ids, then keep rows
# that have a KS value and a positive OBS_TIME_6_ks.
train_audit_seq = pd.read_csv(get_parent_dir() / "assets/data_asset/train_seq.csv")
train_KS = data.merge(train_audit_seq["train"], left_on="AUDIT_SEQ", right_on="train")
# The comparison must be parenthesised: `&` binds tighter than `>`, so without the
# brackets the filter evaluated `(notna & OBS_TIME_6_ks) > 0`.
train_KS = train_KS[train_KS.KS.notna() & (train_KS.OBS_TIME_6_ks > 0)]

In [None]:
# ROC of KS against EVENT_6_ks on the test-set KS rows.
plot_roc(
    test_KS,
    "KS",
    "EVENT_6_ks",
    f"ROC curve (Khorana Score Patients)\nTest (n={test_KS.shape[0]})",
    "KS_AUC_TEST",
)

In [None]:
# ROC of KS against EVENT_6_ks on the training-set KS rows.
plot_roc(
    train_KS,
    "KS",
    "EVENT_6_ks",
    f"ROC curve (Khorana Score Patients)\nTrain (n={train_KS.shape[0]})",
    "KS_AUC_TRAIN",
)

## 7. Calibration Plot

In [None]:
from joblib import load
from pycox.models import DeepHit
from run_models import get_preprocessed_datasets
from utils import get_logger
from vte_deephit import CauseSpecificNet

# Rebinds the notebook logger.
# NOTE(review): the name here uses an underscore ("insights_notebook") while the
# first logger cell uses a hyphen ("insights-notebook") — confirm which is intended.
logger = get_logger("insights_notebook")

In [None]:
from sklearn.calibration import calibration_curve

# Reload the "no_genes" ensemble and compute the averaged test-set CIF at row 180.
feature = "no_genes"
# NOTE(review): this rebinds `bins` (previously the list of day edges) to 5, and the
# new value is not referenced by any later visible cell.
bins = 5

n = 30
# NOTE(review): called with two feature tables and unpacked into two values, unlike
# the four-table call earlier — relies on the helper supporting both forms.
(feature_train, feature_test) = get_preprocessed_datasets(feature, x_train, x_test)

logger.info(f"Running for feature: {feature}")
# Saved keyword arguments for CauseSpecificNet.
params = load(get_parent_dir() / f"models/{feature}/params.pkl")
# Rebuild each ensemble member and restore its weights from model_<i>.pt.
models = []
for i in range(n):
    net = CauseSpecificNet(**params)
    m = DeepHit(net)
    m.load_model_weights(get_parent_dir() / f"models/{feature}/model_{i}.pt")
    models.append(m)

# One CIF prediction per member on the test features.
cifs = []
for sm in models:
    cifs.append(sm.predict_cif(feature_test))

# Average across members, then keep row 180 of the first entry.
cif = np.mean(cifs, dtype=np.float32, axis=0)
vte_cif = cif[0][180, :]

In [None]:
# Quantile-binned calibration curve for "event == 1" against the predicted CIF.
# Per scikit-learn, the return order is (prob_true, prob_pred): `a` is the observed
# fraction of positives per bin and `b` the mean predicted probability per bin.
a, b = calibration_curve(y_test[1] == 1, vte_cif, pos_label=1, strategy="quantile")

In [None]:
import matplotlib.lines as mlines
import matplotlib.transforms as mtransforms

fig, ax = plt.subplots(figsize=(10, 10))
# calibration_curve returns (prob_true, prob_pred), so `b` (mean predicted
# probability per bin) goes on x and `a` (observed fraction of positives) on y,
# matching the axis labels set below.
plt.plot(b, a, marker="o", linewidth=1, label="DeepHit")

# Perfect-calibration reference y = x, drawn in data coordinates so it remains the
# identity line even when the x and y limits differ.
line = ax.axline((0, 0), slope=1, color="black")
plt.title("Calibration plot for VTE data")
ax.set_xlabel("Predicted probability")
ax.set_ylabel("True probability in each bin")
plt.legend()
# plt.show()
plt.savefig(
    get_parent_dir() / "visualizations/calibration_vte_probabilities.svg",
    dpi=300,
    format="svg",
)

In [None]:
# Grouped-risk plot built from the averaged CIF and the test outcomes for event 1;
# the grouping itself is defined in utils.plot_grouped_risks.
plot_grouped_risks(cif,
                   y_test[0],
                   y_test[1],
                   name="Main Cohort Validation Set",
                   event_of_interest=1,
                   save=True)

In [None]:
# Calibration plot for the "no_genes" ensemble, labelled LIMITED; explicit bin edges
# are left commented out, so the helper's own binning applies.
plot_calibration(vte_cif,
                 events=y_test[1], durations=y_test[0],
                 # bins=[0, 2, 4, 6, 8, 100],
                 save=True,
                 name="Validation Set",
                 feature="LIMITED")

In [None]:
from sklearn.calibration import calibration_curve

# Same pipeline as for "no_genes", now for the "ext" feature set. This overwrites
# feature_train, feature_test, params, models, cifs and cif.
feature = "ext"
n = 30
(feature_train, feature_test) = get_preprocessed_datasets(feature, x_train, x_test);

logger.info(f"Running for feature: {feature}")
# Saved keyword arguments for CauseSpecificNet.
params = load(get_parent_dir() / f"models/{feature}/params.pkl")
# Rebuild each ensemble member and restore its weights from model_<i>.pt.
models = []
for i in range(n):
    net = CauseSpecificNet(**params)
    m = DeepHit(net)
    m.load_model_weights(get_parent_dir() / f"models/{feature}/model_{i}.pt")
    models.append(m)

# One CIF prediction per member on the test features.
cifs = []
for sm in models:
    cifs.append(sm.predict_cif(feature_test))

# Average across members, then keep row 180 of the first entry.
cif = np.mean(cifs, dtype=np.float32, axis=0)
vte_cif_ext = cif[0][180, :];

In [None]:
# Calibration plot for the "ext" ensemble, labelled EXTENSIVE; explicit bin edges
# are left commented out, so the helper's own binning applies.
plot_calibration(vte_cif_ext,
                 events=y_test[1],
                 durations=y_test[0], 
                 feature="EXTENSIVE", 
                 # bins=[0, 2, 4, 6, 8, 100],
                 save=True,
                 name="Validation Set")