## Validation: External Cohort 1 (MSK)

This notebook validates the model on the external cohort provided by Jerry.

In [None]:
# !pip install seaborn==0.12.0 lifelines scikit-learn==1.1.3 scikit-survival pymongo==3.12.0 python-dotenv pycox numpy==1.20 

In [None]:
# https://stats.stackexchange.com/questions/518773/statistical-test-for-comparing-performance-metrics-of-two-regression-models-on-a

In [None]:
import os

import numpy as np
import pandas as pd
import random
import torch
from joblib import load, dump
from lifelines import KaplanMeierFitter
from sklearn.model_selection import train_test_split

from dotenv import load_dotenv

load_dotenv()

In [None]:
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cycler

# Ten-colour palette (hex RGBA strings) installed below as the default colour cycle.
color_list = [
    "#E64B35FF",
    "#4DBBD5FF",
    "#00A087FF",
    "#3C5488FF",
    "#F39B7FFF",
    "#8491B4FF",
    "#91D1C2FF",
    "#DC0000FF",
    "#7E6148FF",
    "#B09C85FF",
]

# `plt.rcParams` is the same object as `matplotlib.rcParams`, so one update
# call applies the font, colour cycle and line-width defaults to all figures.
plt.rcParams.update(
    {
        "font.family": "Arial",
        "axes.prop_cycle": cycler(color=color_list),
        "font.size": 18,
        "axes.linewidth": 2,
    }
)

In [None]:
# Seed every random number generator used in this notebook (NumPy, PyTorch and
# the stdlib `random` module) from the RANDOM_SEED environment variable.
_seed_env = os.getenv("RANDOM_SEED")
if _seed_env is None:
    # int(None) would fail with an opaque TypeError; raise an actionable error instead.
    raise RuntimeError(
        "RANDOM_SEED is not set; define it in the environment or in the .env file"
    )
seed = int(_seed_env)
np.random.seed(seed)
torch.manual_seed(seed)
random.seed(seed)

In [None]:
import os
import sys
from pathlib import Path
import pandas as pd
from torch import nn, optim
import torchtuples as tt
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib
from sklearn.utils import resample
from sklearn.metrics import roc_auc_score

# Lift pandas' display limits: full cell contents, every column and every row.
for _display_option in ("max_colwidth", "display.max_columns", "display.max_rows"):
    pd.set_option(_display_option, None)

# Make ../scripts importable (added once) for the project modules imported below.
module_path = str(Path("../scripts").resolve())
if module_path not in sys.path:
    sys.path.append(module_path)

from vte_deephit import get_target, get_best_params, LabTransform, c_stat
from utils import get_parent_dir
from utils import calc_ci, plot_roc , plot_calibration, bootstrap_ci, plot_grouped_risks

In [None]:
# Display what `get_best_params` returns for the "ext_cohort_1" feature set.
get_best_params("ext_cohort_1")

In [None]:
# The external cohort's variables do not match any variable set in our list of
# variable sets, so a model was trained on the feature set available in the
# external cohort.
feature = "ext_cohort_1"

# The CSV location is built from the DATA_DIR and EXT_COHORT_1 environment
# variables; if either is unset, os.getenv returns None and the path join fails.
test_cohort_data = pd.read_csv(
    get_parent_dir() / os.getenv("DATA_DIR") / os.getenv("EXT_COHORT_1")
)

In [None]:
# Row count per KS value (KS is presumably the Khorana score, cf. `df_compare_khorana` later on).
test_cohort_data.KS.value_counts()

In [None]:
# Load the modified observation times: only AUDIT_SEQ and OBS_TIME_MOD are read.
# NOTE(review): the file name is hard-coded here, unlike the cohort CSV which comes from environment variables.
new_obs_time = pd.read_csv(get_parent_dir() / "assets/data_asset/jerry_cohort_modified_2023_03_20.csv", usecols=["AUDIT_SEQ", "OBS_TIME_MOD"])

In [None]:
# Left-join the modified observation times onto the cohort. No `on=` is given,
# so pandas joins on every column name the two frames share.
# NOTE(review): presumably the only shared key is AUDIT_SEQ; consider an explicit
# `on="AUDIT_SEQ"` and `validate="many_to_one"` so duplicate keys cannot silently add rows.
test_cohort_data = test_cohort_data.merge(new_obs_time, how="left")

In [None]:
# Rows with EVENT == 0 (presumably censored) that have a modified observation
# time take OBS_TIME_MOD; every other row keeps its original OBS_TIME.
use_modified_time = test_cohort_data.EVENT.eq(0) & test_cohort_data.OBS_TIME_MOD.notna()
test_cohort_data["OBS_TIME"] = np.where(
    use_modified_time,
    test_cohort_data.OBS_TIME_MOD,
    test_cohort_data.OBS_TIME,
)

In [None]:
# Row/column count after the merge (a changed row count would indicate duplicated keys).
test_cohort_data.shape

In [None]:
# Number of rows that still have no observation time.
test_cohort_data["OBS_TIME"].isna().sum()

In [None]:
## CENSORING DEATHS
# Back-fill OBS_TIME_MOD with OBS_TIME wherever it is missing, so the column has
# a value for every row that has an OBS_TIME.
# NOTE(review): the heading says deaths are censored, but this cell only fills
# missing values; no EVENT code is changed here. Confirm the intent.

test_cohort_data["OBS_TIME_MOD"] = np.where(test_cohort_data["OBS_TIME_MOD"].isna(), test_cohort_data.OBS_TIME, test_cohort_data["OBS_TIME_MOD"])

test_cohort_data.shape

## Plot Kaplan–Meier (KM) curves

In [None]:
# Kaplan-Meier curve on OBS_TIME. Only EVENT == 1 counts as an observed event;
# every other event code is treated as right-censored.
kmf_1 = KaplanMeierFitter()
kmf_1.fit_right_censoring(durations=test_cohort_data["OBS_TIME"], event_observed=test_cohort_data["EVENT"]==1)
kmf_1.plot_survival_function()

In [None]:
# Same Kaplan-Meier fit, but on the back-filled OBS_TIME_MOD durations.
# This rebinds `kmf_1`, so the fitter from the previous cell is no longer reachable.
kmf_1 = KaplanMeierFitter()
kmf_1.fit_right_censoring(durations=test_cohort_data["OBS_TIME_MOD"], event_observed=test_cohort_data["EVENT"]==1)
kmf_1.plot_survival_function()

In [None]:
# ROC plot of the KS column against EVENT over the whole cohort (project helper `plot_roc`).
# NOTE(review): EVENT appears to take more than two codes (the notebook tests == 0, == 1 and == 2);
# how `plot_roc` binarises it is not visible here.
plot_roc(test_cohort_data, "KS", "EVENT", "External Cohort A - KS AUC", "external_cohort_A_KS_AUC")

In [None]:
from lifelines.utils import concordance_index

In [None]:
# Bootstrap the concordance index of KS over the whole cohort with the project
# helper `bootstrap_ci`; it returns a lower bound, an upper bound, a mean and a fourth value.
# NOTE(review): lifelines' concordance_index treats higher scores as longer survival;
# whether `bootstrap_ci` flips the sign of a risk score such as KS is not visible here.
lower_cidx, upper_cidx, mean_cidx, cidxs = bootstrap_ci(test_cohort_data, concordance_index, "EVENT", "KS", "OBS_TIME")

In [None]:
# Report as "mean (lower, upper)".
print(f"{mean_cidx} ({lower_cidx}, {upper_cidx})")

In [None]:
# Hyper-parameters for this feature set; read later when building optimizers and DeepHit models.
hyper_params = get_best_params(feature)

In [None]:
# Select every column whose name contains "CHEMO_" and cap its values at 28.
chemo_cols = list(test_cohort_data.filter(regex="CHEMO_").columns)
# chemo_cols = [a for a in chemo_cols if "_DATE" not in a]
test_cohort_data.loc[:, chemo_cols] = test_cohort_data.loc[:, chemo_cols].clip(upper=28)
# Sanity checks on a handful of columns: the maximum must equal the cap...
assert test_cohort_data.CHEMO_alkylating.max() == 28
assert test_cohort_data.CHEMO_antibiotic.max() == 28
assert test_cohort_data.CHEMO_cdki.max() == 28
assert test_cohort_data.CHEMO_vegfi.max() == 28
assert test_cohort_data.CHEMO_serm.max() == 28
# ...and the minimum must match the expected value.
assert test_cohort_data.CHEMO_alkylating.min() == 0
assert test_cohort_data.CHEMO_antibiotic.min() == 0
# NOTE(review): every other minimum checked here is 0. Confirm that 8 is the
# true minimum of CHEMO_cdki in this cohort and not a typo for 0.
assert test_cohort_data.CHEMO_cdki.min() == 8
assert test_cohort_data.CHEMO_vegfi.min() == 0
assert test_cohort_data.CHEMO_serm.min() == 0

In [None]:
# List the columns that were clipped above.
chemo_cols

In [None]:
# Raw values of the rows with cancer type "low_grade_glioma".
# NOTE(review): this prints every column of those rows; clear the output before sharing if it is patient-level data.
test_cohort_data[test_cohort_data.CANCER_TYPE_FINAL=="low_grade_glioma"].values

In [None]:
# The original data has only one `low_grade_glioma` row, so that row is duplicated.
# The duplication is guarded so that re-running this cell does not keep
# appending further copies (the unguarded assignment added one row per run).
# NOTE(review): the copy is a synthetic patient that ends up in the evaluation
# data; confirm this is acceptable for the reported metrics.
is_low_grade_glioma = test_cohort_data.CANCER_TYPE_FINAL == "low_grade_glioma"
n_low_grade_glioma = int(is_low_grade_glioma.sum())
if n_low_grade_glioma == 0:
    raise ValueError("No `low_grade_glioma` row found to duplicate")
if n_low_grade_glioma == 1:
    # len(index) is used as the new row label, which assumes the default 0..n-1 index.
    test_cohort_data.loc[len(test_cohort_data.index)] = test_cohort_data[is_low_grade_glioma].values[0]

In [None]:
# Rows per cancer type after the duplication above.
test_cohort_data.CANCER_TYPE_FINAL.value_counts()

In [None]:
# Rows per EVENT code.
test_cohort_data["EVENT"].value_counts()

In [None]:
# Share of 352 rows out of 5402 + 496 + 352.
# NOTE(review): these counts are hard-coded, presumably copied from the EVENT
# value counts above; they go stale if the cohort changes.
352/(5402+496+352)

In [None]:
# Share of 496 rows out of 5402 + 496 + 352 (hard-coded counts; same caveat as the previous cell).
496/(5402+496+352)

In [None]:
# Distribution of observation times in the cohort.
sns.histplot(test_cohort_data.OBS_TIME)

In [None]:
# 80/20 split of the external cohort, stratified on the EVENT code.
# The split uses the global seed (read from RANDOM_SEED) offset by one.
x_train, x_test = train_test_split(
    test_cohort_data,
    test_size=0.2,
    stratify=test_cohort_data["EVENT"],
    random_state=seed + 1,
)

In [None]:
# Percentage of test-split rows at each KS value.
100*x_test.KS.value_counts()/len(x_test)

In [None]:
# c-index for whole cohort
# NOTE(review): the results are bound to `*_test` names although they describe the
# whole cohort, and the next cell overwrites them with the test-split values.
lower_cidx_test, upper_cidx_test, mean_cidx_test, stat_ks_test_cidx_test = bootstrap_ci(test_cohort_data, concordance_index, "EVENT", "KS", "OBS_TIME")

print(f"{mean_cidx_test} ({lower_cidx_test}, {upper_cidx_test})")

In [None]:
# c-index for test set only (overwrites the whole-cohort values bound in the previous cell)
lower_cidx_test, upper_cidx_test, mean_cidx_test, stat_ks_test_cidx_test = bootstrap_ci(x_test, concordance_index, "EVENT", "KS", "OBS_TIME")

print(f"{mean_cidx_test} ({lower_cidx_test}, {upper_cidx_test})")

In [None]:
# Peek at the first training rows.
x_train.head()

In [None]:
# Sizes of the training split (printed) and the test split (cell output).
print(x_train.shape)
x_test.shape

In [None]:
# One more than the largest (integer-truncated) OBS_TIME in the training split.
num_durations = int(max(x_train["OBS_TIME"])) + 1  # for cut-points
labtrans = LabTransform(num_durations)
# NOTE(review): `labtrans_6` is not referenced again in the visible part of the notebook.
labtrans_6 = LabTransform(181)

In [None]:
# Inspect the cut points held by the label transform.
labtrans.cuts

In [None]:
# Build structured (event, times) arrays for both splits. Both fields are
# integers, so fractional observation times are truncated.
event_type = int
target_dtype = [("event", event_type), ("times", int)]

y_train = np.array(list(zip(x_train.EVENT, x_train.OBS_TIME)), dtype=target_dtype)
y_test = np.array(list(zip(x_test.EVENT, x_test.OBS_TIME)), dtype=target_dtype)

# The label transform is fitted on the training targets and then applied
# unchanged to the test targets.
y_train = labtrans.fit_transform(*get_target(y_train))
y_test = labtrans.transform(*get_target(y_test))

In [None]:
# Largest value in the second element of the transformed test target (used as the event code below).
y_test[1].max()

In [None]:
# Remove the outcome columns from both splits before they are passed to the
# preprocessing pipeline. The frames are rebound instead of modified with
# inplace=True (in-place edits on split frames can trigger pandas'
# SettingWithCopyWarning), and errors="ignore" lets this cell be re-run after
# the columns have already been dropped.
outcome_cols = ["OBS_TIME_MOD", "OBS_TIME", "EVENT"]
x_train = x_train.drop(columns=outcome_cols, errors="ignore")
x_test = x_test.drop(columns=outcome_cols, errors="ignore")

In [None]:
# Load the preprocessing pipeline saved for this feature set.
# NOTE(review): joblib files are pickles; only load artefacts from a trusted source.
transformation_pipeline = load(
    get_parent_dir() / f"models/{feature}/preprocessing_fit.joblib"
)
# Apply the loaded preprocessing to both splits (transform only; nothing is refit here)
features_train = transformation_pipeline.transform(x_train).astype("float32");
features_test = transformation_pipeline.transform(x_test).astype("float32");

In [None]:
# Names of the columns produced by the preprocessing pipeline.
transformation_pipeline.get_feature_names_out()

In [None]:
# Shape of the transformed test features.
print(features_test.shape)

In [None]:
# full_data = np.vstack([features_train, features_val, features_test])
# Stack the training rows followed by the test rows (same order as `full_target` below).
full_data = np.vstack([features_train, features_test])

In [None]:
# Shape of the stacked train + test feature matrix.
full_data.shape

In [None]:
# Concatenate each component of the targets (element 0, then element 1) in the
# same train-then-test order as `full_data`.
full_target = tuple(
    np.concatenate((y_train[component], y_test[component])) for component in (0, 1)
)

In [None]:
# Show the hyper-parameters in use.
hyper_params

In [None]:
# Largest value in the first component of the combined target.
full_target[0].max()

In [None]:
import torch
from pycox.models import DeepHit
from vte_deephit import CauseSpecificNet

# Use the first GPU when CUDA is available, otherwise the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)

n = 30  # number of checkpoints (model_0.pt .. model_{n-1}.pt) in the ensemble

# Rebuild each ensemble member from the saved architecture parameters and
# restore its weights. The optimizer is required by the DeepHit constructor;
# no training happens in this cell.
params = load(get_parent_dir() / f"models/{feature}/params.pkl")
models = []
for i in range(n):
    net = CauseSpecificNet(**params)
    optimizer = tt.optim.AdamWR(
        lr=.1 * hyper_params["lr"],
        decoupled_weight_decay=hyper_params["L2_par"],
        cycle_eta_multiplier=hyper_params["eta_par"],
    )
    m = DeepHit(
        net,
        optimizer=optimizer,
        alpha=hyper_params["alpha_par"],
        sigma=hyper_params["sigma_par"],
        device=device,
        duration_index=labtrans.cuts,
    )
    m.load_model_weights(get_parent_dir() / f"models/{feature}/model_{i}.pt")
    models.append(m)

# Ensemble prediction on train + test rows: average the members' CIFs.
cifs_full = [member.predict_cif(full_data) for member in models]
cif_full = np.mean(cifs_full, dtype=np.float64, axis=0)


In [None]:
# Architecture of the first ensemble member.
models[0].net

In [None]:
# Concordance statistics (project helper `c_stat`) of the averaged CIF on
# train + test rows, using the first 181 entries along the second axis.
# NOTE(review): the suffix ends in "_ks" although these are DeepHit predictions,
# and the same suffix is reused for the fine-tuned test evaluation further down.
c_stat_full = c_stat(
    cif_full[:, :181, :], full_target[0], full_target[1], labtrans.cuts, suffix="ext_cohort_1_ks",
)
c_stat_full

In [None]:
# Same evaluation restricted to the held-out test split: predict with every
# ensemble member, average, then score the first 181 time entries.
cifs_test = []

for sm in models:
    cifs_test.append(sm.predict_cif(features_test))

cif_test = np.mean(cifs_test, dtype=np.float64, axis=0)
c_stat_test = c_stat(
    cif_test[:, :181, :],
    y_test[0],
    y_test[1],
    labtrans.cuts,
    suffix="ext_cohort_1_test",
)
c_stat_test

## Transfer learning - only last layer

In [None]:
# Print the layer structure (the parameter names in `layers_to_tune` below refer to it).
print(models[0].net)

In [None]:
# Shape of the transformed training features.
features_train.shape

In [None]:
# Name of the feature set in use.
feature

In [None]:
# Number of trainable parameters in the first ensemble member.
sum(p.numel() for p in models[0].net.parameters() if p.requires_grad)

In [None]:
# Cut points passed as `duration_index` to every DeepHit model.
labtrans.cuts

In [None]:
# Transfer learning, last layer only: fine-tune the final layer of both
# cause-specific sub-networks on the external training split, or reload the
# checkpoints written by an earlier run of this cell.
n = 30  # number of ensemble members / checkpoints
pretrained_dir = get_parent_dir() / f"models/{feature}"
finetuned_dir = get_parent_dir() / f"models_finetuned/{feature}"

# Testing only for the directory is not reliable: it is also created as the
# parent of the `full_retrain` / `bias_only` sub-directories further down, and
# by a save that was interrupted. Require the parameter file and every checkpoint.
finetuned_available = (finetuned_dir / "params.pkl").is_file() and all(
    (finetuned_dir / f"model_{i}.pt").is_file() for i in range(n)
)
weights_dir = finetuned_dir if finetuned_available else pretrained_dir
if not finetuned_available:
    print("No finetuned model found")

# Architecture parameters belonging to the checkpoints in `weights_dir`.
params = load(weights_dir / "params.pkl")

tl_models = []
for i in range(n):
    net = CauseSpecificNet(**params)
    optimizer = tt.optim.AdamWR(
        lr=.1 * hyper_params["lr"],  # fine-tuning uses a tenth of the tuned learning rate
        decoupled_weight_decay=hyper_params["L2_par"],
        cycle_eta_multiplier=hyper_params["eta_par"],
    )
    tm = DeepHit(
        net,
        optimizer=optimizer,
        alpha=hyper_params["alpha_par"],
        sigma=hyper_params["sigma_par"],
        device=device,
        duration_index=labtrans.cuts,
    )
    tm.load_model_weights(weights_dir / f"model_{i}.pt")
    tl_models.append(tm)

if not finetuned_available:
    layers_to_tune = [
        "risk_nets.0.net.2.weight",
        "risk_nets.0.net.2.bias",
        "risk_nets.1.net.2.weight",
        "risk_nets.1.net.2.bias",
    ]

    for m in tl_models:
        # Freeze every parameter except the listed last-layer weights and biases.
        # NOTE(review): requires_grad=False stops gradient updates only. If the network
        # contains BatchNorm or dropout layers, they still behave in training mode during fit.
        for name, prms in m.net.named_parameters():
            prms.requires_grad = name in layers_to_tune

        total_trainable_params = sum(p.numel() for p in m.net.parameters() if p.requires_grad)
        total_non_trainable_params = sum(p.numel() for p in m.net.parameters() if not p.requires_grad)
        print(f"Total Trainable params = {total_trainable_params}\nTotal non-trainable params = {total_non_trainable_params}")

        # NOTE(review): the batch size equals the number of input features
        # (features_train.shape[1]). Confirm this is intended and not a mix-up
        # with a batch-size hyper-parameter.
        log = m.fit(
            features_train,
            (y_train[0], y_train[1]),
            batch_size=features_train.shape[1],
            epochs=30,
            verbose=True,
        )

    print(f"Saving models for {feature} to models_finetuned/{feature}")
    finetuned_dir.mkdir(parents=True, exist_ok=True)
    # Save the exact parameters the networks above were built from, so the file
    # always matches the checkpoints (no hand-maintained copy with a fixed output size).
    dump(params, finetuned_dir / "params.pkl")
    for i, m in enumerate(tl_models):
        m.save_model_weights(finetuned_dir / f"model_{i}.pt")
    print(f"Saved fine_tuned models for feature {feature}")

In [None]:
# Signed sum of the last-layer weight differences between the fine-tuned and the original first member.
# NOTE(review): positive and negative changes can cancel; `.abs().sum()` would show whether the weights changed at all.
(tl_models[0].net.risk_nets[0].net[2].weight - models[0].net.risk_nets[0].net[2].weight).sum()

In [None]:
# Signed sum of the last-layer bias differences (same cancellation caveat as for the weights).
(tl_models[0].net.risk_nets[0].net[2].bias - models[0].net.risk_nets[0].net[2].bias).sum()

In [None]:
# Averaged CIF of the fine-tuned ensemble. Despite the "_full" suffix, the
# predictions are for the held-out test rows only.
cifs_tl_full = [member.predict_cif(features_test) for member in tl_models]
cif_tl_full = np.mean(cifs_tl_full, dtype=np.float64, axis=0)

In [None]:
# Concordance statistics of the fine-tuned ensemble on the test split.
# NOTE(review): the suffix "ext_cohort_1_ks" was already used for the
# whole-cohort evaluation of the original ensemble; if `c_stat` writes or keys
# results by suffix, the two would collide.
c_stat(
    cif_tl_full[:, :181, :],
    y_test[0],
    y_test[1],
    labtrans.cuts,
    suffix="ext_cohort_1_ks",
)


In [None]:
# Averaged CIF of the original (not fine-tuned) ensemble on the held-out test rows.
cifs_non_tl_full = [member.predict_cif(features_test) for member in models]
cif_non_tl_full = np.mean(cifs_non_tl_full, dtype=np.float64, axis=0)

In [None]:
# Per-row predicted risk before fine-tuning (first risk, time index 180)
# paired with a flag for event code 1.
cif_full_jerry_non_tl = pd.DataFrame(
    {"cif": cif_non_tl_full[0, 180, :], "event": y_test[1] == 1}
)

In [None]:
# Event codes of the test split.
y_test[1]

In [None]:
# Per-row predicted risk after last-layer fine-tuning (first risk, time index
# 180) paired with a flag for event code 1.
cif_full_jerry = pd.DataFrame(
    {"cif": cif_tl_full[0, 180, :], "event": y_test[1] == 1}
)

In [None]:
# ROC of the original ensemble on the test split; saved under the given file name.
plot_roc(cif_full_jerry_non_tl, "cif", "event", "External Cohort A - DeepHit Limited Model",save=True, fname="ext_cohort_a_roc_before_tl")

In [None]:
# ROC of the fine-tuned ensemble on the test split (same plot title as the "before" figure).
plot_roc(cif_full_jerry,
         "cif", "event", "External Cohort A - DeepHit Limited Model",
         save=True,
         fname="ext_cohort_a_roc_after_tl")

In [None]:
# Bootstrap comparison on the test split: for each of 200 stratified resamples,
# score the fine-tuned and the original ensemble on the same resampled rows.
# NOTE(review): every iteration re-runs all members of both ensembles on the
# resampled rows; indexing the already computed test predictions with the
# resampled row indices would presumably give the same scores at far lower cost.
samples = 200
scores = []
for j in range(samples):
    # The iteration number is the seed, so each resample is reproducible.
    sub_test, sub_test_y_0, sub_test_y_1 = resample(
            features_test, y_test[0], y_test[1], stratify=y_test[1],random_state=j
        )
    cifs_tl = []
    for sm in tl_models:
        cifs_tl.append(sm.predict_cif(sub_test))

    cif_tl = np.mean(cifs_tl, dtype=np.float64, axis=0)
    c_stat_tl = c_stat(
            cif_tl[:, :181, :],
            sub_test_y_0,
            sub_test_y_1,
            tl_models[0].duration_index,
            suffix="test_jerry_tl_ks"
        )

    cifs_non_tl = []
    for sm in models:
        cifs_non_tl.append(sm.predict_cif(sub_test))

    cif_non_tl = np.mean(cifs_non_tl, dtype=np.float64, axis=0)
    # Merge the original-ensemble scores into the same dict (keys differ by suffix).
    c_stat_tl.update(c_stat(
            cif_non_tl[:, :181, :],
            sub_test_y_0,
            sub_test_y_1,
            models[0].duration_index,
            suffix="test_jerry_wo_tl_ks"
        ))

    scores.append(c_stat_tl)

# After the loop, `cif_tl` and `cif_non_tl` hold the predictions for the last
# resample only (j == samples - 1), not for the test split itself.
assert len(scores) == samples

In [None]:
# One frame with the scores of every bootstrap resample, tagged with the feature-set name.
res = pd.concat([pd.DataFrame(resample_scores) for resample_scores in scores]).assign(
    feature=feature
)

In [None]:
# Mean and confidence interval (project helper `calc_ci`) of every score column.
res.groupby("feature").agg(["mean", calc_ci])

In [None]:
# KS On test set: 0.65 (.58,  0.71)

In [None]:
# Write the same summary to a CSV in the current working directory.
# NOTE(review): `index` is documented as a bool; None is falsy, so the index
# ("feature") is presumably omitted. `index=False` would state that explicitly.
res.groupby("feature").agg(["mean", calc_ci]).to_csv("jerry_cohort_tl.csv", index=None)

In [None]:
# Mean bootstrap score of the fine-tuned ensemble.
res.td_c_idx_vte_test_jerry_tl_ks.mean()

In [None]:
# Mean bootstrap score of the original ensemble.
res.td_c_idx_vte_test_jerry_wo_tl_ks.mean()

In [None]:
# Distribution of the original ensemble's predicted risk (first risk, time
# index 180) on the held-out test split. The bootstrap loop variable
# `cif_non_tl` is not used because it only holds the last resample.
sns.histplot(cif_non_tl_full[0, 180, :])

In [None]:
# Distribution of the fine-tuned ensemble's predicted risk (first risk, time
# index 180) on the held-out test split. The bootstrap loop variable `cif_tl`
# is not used because it only holds the last resample.
sns.histplot(cif_tl_full[0, 180, :])

In [None]:
# First five event codes of the test split.
y_test[1][:5]

In [None]:
# Use 180 as the duration wherever the event code is 2, otherwise keep y_test[0];
# the result is bound to both `obs_time` and `estimate`.
# NOTE(review): neither name is read again in the visible cells
# (`df_compare_khorana` uses y_test[0] directly). Confirm whether this was
# meant to feed the reclassification table.
obs_time = estimate = np.where(y_test[1]==2, 180, y_test[0])

In [None]:
# Averaged CIF of the fine-tuned ensemble on the test rows. This is the same
# computation that produced `cif_tl_full` earlier in the notebook.
full_cifs_tl = [member.predict_cif(features_test) for member in tl_models]
full_cif_tl = np.mean(full_cifs_tl, dtype=np.float64, axis=0)

In [None]:
# Predicted risk per test row for the first risk at time index 180.
cif_at_180_test = full_cif_tl[0][180, :]

In [None]:
# Quick look at the predicted risks in row order (no axis labels or title).
plt.plot(cif_at_180_test)

In [None]:
# KS values of the test split (a Series that keeps x_test's index).
KS = x_test.KS

In [None]:
# Combine predicted risk, KS, event code and duration per test row.
# `KS` is a Series, so the frame takes its index; the NumPy arrays are placed
# positionally. This assumes the rows of `features_test` are in x_test order — TODO confirm.
df_compare_khorana = pd.DataFrame({"cif": cif_at_180_test, "ks": KS, "event": y_test[1], "obs_time": y_test[0]})

In [None]:
# Should have one row per test patient.
df_compare_khorana.shape

In [None]:
# Peek at the comparison frame.
df_compare_khorana.head()

In [None]:
# Distribution of the predicted risks in the comparison frame.
sns.displot(df_compare_khorana.cif)

In [None]:
from utils import get_estimated_cif

In [None]:
def get_pair_counts_and_vte(df, ks_condition, cif_condition):
    """Count the rows that satisfy both conditions and estimate their incidence.

    Parameters
    ----------
    df : DataFrame with at least the columns "obs_time" and "event".
    ks_condition, cif_condition : boolean masks over the rows of `df`;
        a row is selected only when both are true.

    Returns
    -------
    tuple
        (number of selected rows, value returned by `get_estimated_cif` for
        the selected rows' "obs_time" and "event" columns).
    """
    selected = df.loc[ks_condition & cif_condition]
    return len(selected), get_estimated_cif(selected["obs_time"], selected["event"])

# Predicted-risk cut-off separating "high" from "low" model risk.
# NOTE(review): 0.09 is a magic number here; its derivation is not visible in this notebook.
int_risk_ppv=.09
# KS >= 2 is treated as high risk, KS < 2 as low risk.
high_risk_condition = (df_compare_khorana.ks >= 2)
low_risk_condition = (df_compare_khorana.ks < 2)
high_cif_condition = (df_compare_khorana.cif >= int_risk_ppv)
low_cif_condition = (df_compare_khorana.cif < int_risk_ppv)

# "*_high_risk" groups: the model says high risk; KS agrees (concordant) or says low (discordant).
concordant_pairs_high_risk, concordant_high_risk_vte = get_pair_counts_and_vte(df_compare_khorana, high_risk_condition, high_cif_condition)
discordant_pairs_high_risk, discordant_high_risk_vte = get_pair_counts_and_vte(df_compare_khorana, low_risk_condition, high_cif_condition)

# "*_low_risk" groups: the model says low risk; KS agrees (concordant) or says high (discordant).
concordant_pairs_low_risk, concordant_low_risk_vte = get_pair_counts_and_vte(df_compare_khorana, low_risk_condition, low_cif_condition)
discordant_pairs_low_risk, discordant_low_risk_vte = get_pair_counts_and_vte(df_compare_khorana, high_risk_condition, low_cif_condition)


In [None]:
# Reclassification table contents: one entry per group, listing the KS label,
# the model label, the group size and the estimated incidence.
data = {
    "Concordant Pairs": {
        "KS": ["High Risk", "Low Risk"],
        "DeepVTE": ["High Risk", "Low Risk"],
        "No": [
           concordant_pairs_high_risk,
           concordant_pairs_low_risk,
        ],
        "Incidence VTE": [concordant_high_risk_vte, concordant_low_risk_vte]
    },
    "Discordant Pairs": {
        # Order matches the values below: (KS low, model high) first, then (KS high, model low).
        "KS": ["Low Risk", "High Risk"],
        "DeepVTE": ["High Risk", "Low Risk"],
        "No": [
            discordant_pairs_high_risk,
            discordant_pairs_low_risk,
        ],
        "Incidence VTE": [discordant_high_risk_vte, discordant_low_risk_vte]
    },
}

In [None]:
# Display concordant and discordant groups as one table (row labels repeat 0, 1 for each part).
pd.concat([pd.DataFrame.from_dict(data["Concordant Pairs"]), pd.DataFrame.from_dict(data["Discordant Pairs"])])

In [None]:
# Rebuild the same table and write it to a CSV in the current working directory.
pd.concat([pd.DataFrame.from_dict(data["Concordant Pairs"]), pd.DataFrame.from_dict(data["Discordant Pairs"])]).to_csv("external_cohort_a_reclassification.csv")

In [None]:
# Averaged CIF of the original ensemble on the test rows. This is the same
# computation that produced `cif_non_tl_full` earlier in the notebook.
full_cifs_non_tl = [member.predict_cif(features_test) for member in models]
full_cif_non_tl = np.mean(full_cifs_non_tl, dtype=np.float64, axis=0)

In [None]:
# Grouped-risk plot for the original ensemble (project helper `plot_grouped_risks`).
# NOTE(review): `time_of_interest=181` here, while the other cells index the
# CIF at 180. Confirm how the helper interprets this value.
plot_grouped_risks(full_cif_non_tl, time_of_interest=181, save=True, name="grouped_risks_ext_cohort_a_validation_non_tl")

In [None]:
# Calibration of the original ensemble's predicted risk (first risk, time index 180) on the test split.
plot_calibration(full_cif_non_tl[0][180, :],
                 durations=y_test[0],
                 events=y_test[1],
                 feature="External Cohort A",
                 name="Before Transfer Learning",
                 save=True)

In [None]:
# Grouped-risk plot for the fine-tuned ensemble (same `time_of_interest=181` caveat as the "non_tl" plot).
plot_grouped_risks(full_cif_tl, time_of_interest=181, save=True, name="grouped_risks_ext_cohort_a_validation_tl")

In [None]:
# Calibration of the fine-tuned ensemble's predicted risk (first risk, time index 180) on the test split.
plot_calibration(full_cif_tl[0][180, :],
                 durations=y_test[0],
                 events=y_test[1],
                 feature="External Cohort A",
                 name="After Transfer Learning",
                 save=True)

### TL Full retrain

In [None]:
# Transfer learning, full retrain: continue training every parameter of the
# pretrained ensemble on the external training split, or reload the checkpoints
# written by an earlier run of this cell.
n = 30  # number of ensemble members / checkpoints
pretrained_dir = get_parent_dir() / f"models/{feature}"
full_retrain_dir = get_parent_dir() / f"models_finetuned/{feature}/full_retrain"

# Require the parameter file and every checkpoint, not just the directory, so
# a partially written directory from an interrupted save is not treated as a finished run.
full_retrain_available = (full_retrain_dir / "params.pkl").is_file() and all(
    (full_retrain_dir / f"model_{i}.pt").is_file() for i in range(n)
)
weights_dir = full_retrain_dir if full_retrain_available else pretrained_dir
if not full_retrain_available:
    print("No finetuned model found")

# Architecture parameters belonging to the checkpoints in `weights_dir`.
params = load(weights_dir / "params.pkl")

tl_models_full_retrain = []
for i in range(n):
    net = CauseSpecificNet(**params)
    optimizer = tt.optim.AdamWR(
        lr=hyper_params["lr"],  # full learning rate (the last-layer run uses a tenth of it)
        decoupled_weight_decay=hyper_params["L2_par"],
        cycle_eta_multiplier=hyper_params["eta_par"],
    )
    tm = DeepHit(
        net,
        optimizer=optimizer,
        alpha=hyper_params["alpha_par"],
        sigma=hyper_params["sigma_par"],
        device=device,
        duration_index=labtrans.cuts,
    )
    tm.load_model_weights(weights_dir / f"model_{i}.pt")
    tl_models_full_retrain.append(tm)

if not full_retrain_available:
    for m in tl_models_full_retrain:
        # Nothing is frozen in this variant; the counts are printed for comparison
        # with the partially frozen runs.
        total_trainable_params = sum(p.numel() for p in m.net.parameters() if p.requires_grad)
        total_non_trainable_params = sum(p.numel() for p in m.net.parameters() if not p.requires_grad)
        print(f"Total Trainable params = {total_trainable_params}\nTotal non-trainable params = {total_non_trainable_params}")

        # NOTE(review): the batch size equals the number of input features
        # (features_train.shape[1]). Confirm this is intended.
        log = m.fit(
            features_train,
            (y_train[0], y_train[1]),
            batch_size=features_train.shape[1],
            epochs=30,
            verbose=True,
        )

    print(f"Saving models for {feature} to models_finetuned/{feature}/full_retrain")
    full_retrain_dir.mkdir(parents=True, exist_ok=True)
    # Save the exact parameters the networks above were built from, so the file
    # always matches the checkpoints (no hand-maintained copy with a fixed output size).
    dump(params, full_retrain_dir / "params.pkl")
    for i, m in enumerate(tl_models_full_retrain):
        m.save_model_weights(full_retrain_dir / f"model_{i}.pt")
    print(f"Saved fine_tuned models for feature {feature}")

In [None]:
# Averaged CIF of the fully retrained ensemble on the held-out test rows.
cifs_full_retrain = [member.predict_cif(features_test) for member in tl_models_full_retrain]
cif_full_retrain = np.mean(cifs_full_retrain, dtype=np.float64, axis=0)

In [None]:
# Bootstrap the test-split concordance of the fully retrained ensemble over 200
# stratified resamples. The seeds (0..199) match the earlier bootstrap, so the
# resampled rows are the same as there.
samples = 200
scores_full_retrain = []
for j in range(samples):
    sub_test, sub_test_y_0, sub_test_y_1 = resample(
            features_test, y_test[0], y_test[1], stratify=y_test[1],random_state=j
        )
    cifs_tl_full_retrain = []
    for sm in tl_models_full_retrain:
        cifs_tl_full_retrain.append(sm.predict_cif(sub_test))

    cif_tl_full_retrain = np.mean(cifs_tl_full_retrain, dtype=np.float64, axis=0)
    c_stat_tl_full_retrain = c_stat(
            cif_tl_full_retrain[:, :181, :],
            sub_test_y_0,
            sub_test_y_1,
            tl_models_full_retrain[0].duration_index,
            suffix="test_jerry_tl_full_retrain"
        )
    
    scores_full_retrain.append(c_stat_tl_full_retrain)

# One score dict per resample.
assert len(scores_full_retrain) == samples

In [None]:
# One frame with the full-retrain scores of every bootstrap resample, tagged with the feature-set name.
res_full_retrain = pd.concat(
    [pd.DataFrame(resample_scores) for resample_scores in scores_full_retrain]
).assign(feature=feature)

In [None]:
# Mean and confidence interval (project helper `calc_ci`) of the full-retrain scores.
res_full_retrain.groupby("feature").agg(["mean", calc_ci])

In [None]:
# Calibration of the fully retrained ensemble's predicted risk (first risk, time index 180) on the test split.
plot_calibration(cif_full_retrain[0][180, :],
                 durations=y_test[0],
                 events=y_test[1],
                 feature="External Cohort A",
                 name="After Retraining",
                 save=True)

## Transfer Learning Bias Only

In [None]:
# Transfer learning, bias only: fine-tune just the last-layer biases of both
# cause-specific sub-networks on the external training split, or reload the
# checkpoints written by an earlier run of this cell.
n = 30  # number of ensemble members / checkpoints
pretrained_dir = get_parent_dir() / f"models/{feature}"
bias_only_dir = get_parent_dir() / f"models_finetuned/{feature}/bias_only"

# Require the parameter file and every checkpoint, not just the directory, so
# a partially written directory from an interrupted save is not treated as a finished run.
bias_only_available = (bias_only_dir / "params.pkl").is_file() and all(
    (bias_only_dir / f"model_{i}.pt").is_file() for i in range(n)
)
weights_dir = bias_only_dir if bias_only_available else pretrained_dir
if not bias_only_available:
    print("No finetuned model found")

# Architecture parameters belonging to the checkpoints in `weights_dir`.
params = load(weights_dir / "params.pkl")

tl_models_bias_only = []
for i in range(n):
    net = CauseSpecificNet(**params)
    # The learning rate used for the bias-only training is also used when the
    # checkpoints are merely reloaded (the reload path previously built its
    # unused optimizer with a tenth of it).
    # NOTE(review): bias-only tuning trains at the full learning rate, whereas the
    # last-layer run uses 0.1 * lr. Confirm the difference is intentional.
    optimizer = tt.optim.AdamWR(
        lr=hyper_params["lr"],
        decoupled_weight_decay=hyper_params["L2_par"],
        cycle_eta_multiplier=hyper_params["eta_par"],
    )
    tm = DeepHit(
        net,
        optimizer=optimizer,
        alpha=hyper_params["alpha_par"],
        sigma=hyper_params["sigma_par"],
        device=device,
        duration_index=labtrans.cuts,
    )
    tm.load_model_weights(weights_dir / f"model_{i}.pt")
    tl_models_bias_only.append(tm)

if not bias_only_available:
    layers_to_tune = [
        "risk_nets.0.net.2.bias",
        "risk_nets.1.net.2.bias",
    ]

    for m in tl_models_bias_only:
        # Freeze every parameter except the two last-layer bias vectors.
        # NOTE(review): requires_grad=False stops gradient updates only. If the network
        # contains BatchNorm or dropout layers, they still behave in training mode during fit.
        for name, prms in m.net.named_parameters():
            prms.requires_grad = name in layers_to_tune

        total_trainable_params = sum(p.numel() for p in m.net.parameters() if p.requires_grad)
        total_non_trainable_params = sum(p.numel() for p in m.net.parameters() if not p.requires_grad)
        print(f"Total Trainable params = {total_trainable_params}\nTotal non-trainable params = {total_non_trainable_params}")

        # NOTE(review): the batch size equals the number of input features
        # (features_train.shape[1]). Confirm this is intended.
        log = m.fit(
            features_train,
            (y_train[0], y_train[1]),
            batch_size=features_train.shape[1],
            epochs=30,
            verbose=True,
        )

    print(f"Saving models for {feature} to models_finetuned/{feature}/bias_only")
    bias_only_dir.mkdir(parents=True, exist_ok=True)
    # Save the exact parameters the networks above were built from, so the file
    # always matches the checkpoints (no hand-maintained copy with a fixed output size).
    dump(params, bias_only_dir / "params.pkl")
    for i, m in enumerate(tl_models_bias_only):
        m.save_model_weights(bias_only_dir / f"model_{i}.pt")
    print(f"Saved fine_tuned bias only models for feature {feature}")

In [None]:
# Averaged CIF of the bias-only fine-tuned ensemble on the held-out test rows.
cifs_full_bias = [member.predict_cif(features_test) for member in tl_models_bias_only]
cif_full_bias = np.mean(cifs_full_bias, dtype=np.float64, axis=0)

In [None]:
# Calibration of the bias-only fine-tuned ensemble's predicted risk (first risk, time index 180) on the test split.
plot_calibration(cif_full_bias[0][180, :],
                 durations=y_test[0],
                 events=y_test[1],
                 feature="External Cohort A",
                 name="After Bias Tuning",
                 save=True)