# Conformalized Survival Analysis for General Right-Censored Data

### Imports

In [55]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy
from tqdm import tqdm
import torch
import torchtuples as tt
import seaborn as sns
import matplotlib.patches as mpatches

import random
import warnings
warnings.filterwarnings("ignore")

from datasets import generate_data, load_dataset
from plotting import plot_lpbs_ablation, plot_lpbs, plot_coverage_and_lpb_comperison, plot_c_ind_ibs_and_dcal_comperison, plot_censorship_rate_lpb_coverage, plot_lpb_coverage_n_samples
from training import train_models_get_data, get_target, train_early_event_models
from calibration import adaptive_conformal_cov, csd_cov, base_model_cov, csd_metrics_estimation, adaptive_conformal_metrics_estimation, base_model_metrics_estimation

### Define the experiment functions


In [58]:
def run_coverage_comparison_same_model(args, n_runs, target_alphas, alphas, frac_early_cens=0.1, threshold_early_cens=0.15, frac_early_surv=0.1, threshold_early_surv=0.12, retrain=True, settings=range(1, 7), include_baselines=True):
    """Compare lower prediction bounds (LPBs) of all calibration methods on shared base models.

    For every setting and each of ``n_runs`` repetitions, obtains a survival
    model, a classifier and an early-event model (via ``train_models_get_data``
    and ``train_early_event_models``) and evaluates the Naive, Focused and
    Fused adaptive conformal variants -- plus, when ``include_baselines`` is
    true, the uncalibrated base model and CSD -- at each level in
    ``target_alphas``.

    Args:
        args: training configuration forwarded to ``train_models_get_data``.
        n_runs: number of repetitions per setting.
        target_alphas: levels to evaluate; also an index level of the result.
        alphas: alpha grid forwarded to ``adaptive_conformal_cov`` and
            ``train_early_event_models``.
        frac_early_cens, threshold_early_cens, frac_early_surv,
        threshold_early_surv: forwarded to ``train_models_get_data``.
        retrain: if True, data and models are regenerated on every run;
            otherwise they are built once per setting and each run only
            reshuffles the calibration/test split with ``repartition``.
        settings: settings to run. For integer settings 1-6 the helpers also
            return coverage; any other setting (e.g. the named real datasets)
            yields only A_hat / LPB.
        include_baselines: whether to evaluate 'Uncalibrated' and 'CSD'.

    Returns:
        DataFrame indexed by (Method, Target Alpha, Setting, Run) with columns
        'Coverage', 'A_hat', 'LPB', 'Include Proportion'. Cells that are never
        written stay NaN (e.g. 'A_hat' for the baselines, 'Include Proportion'
        for every method but Fused, 'Coverage' outside settings 1-6).
    """
    # NOTE(review): the helpers called below presumably consume global RNG
    # state (repartition does, via DataFrame.sample); keep the call order
    # fixed so seeded runs stay reproducible.
    index = pd.MultiIndex.from_product(
        [['Focused', 'Fused', 'Naive', 'CSD', 'Uncalibrated'], 
         target_alphas, 
         settings, 
         range(n_runs)], 
        names=['Method', 'Target Alpha', 'Setting', 'Run']
    )
    
    df = pd.DataFrame(index=index, columns=['Coverage', 'A_hat', 'LPB', 'Include Proportion'])
    for setting in settings:
        print(f'\n Setting {setting}')
        if not retrain:
            # Train once per setting; the runs below only reshuffle cal/test.
            surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test = train_models_get_data(setting, args, frac_early_cens=frac_early_cens, threshold_early_cens=threshold_early_cens, frac_early_surv=frac_early_surv, threshold_early_surv=threshold_early_surv)
            early_event_model = train_early_event_models(surv_model, alphas, x_train, df_train, setting)
        for i in tqdm(range(n_runs)):
            if retrain:
                # Fresh data and freshly trained models for every run.
                surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test = train_models_get_data(setting, args, frac_early_cens=frac_early_cens, threshold_early_cens=threshold_early_cens, frac_early_surv=frac_early_surv, threshold_early_surv=threshold_early_surv)
                early_event_model = train_early_event_models(surv_model, alphas, x_train, df_train, setting)
            else:
                # Repartition the calibration and test sets
                df_cal, df_test, x_cal, x_test = repartition(df_cal, df_test, x_cal, x_test)
            durations, events = get_target(df_cal)
            # Settings 1-6 also return coverage, so the helpers' return
            # tuples have one extra leading element in this branch.
            if setting in range(1,7):
                if include_baselines:
                    coverage_base, lengths_base = base_model_cov(target_alphas, setting, surv_model, df_test, x_test)
                    coverage_csd, lengths_csd = csd_cov(target_alphas, setting, surv_model, df_train, df_cal, df_test, x_cal, x_test, durations)
                coverages_naive, a_hats_naive, lengths_naive = adaptive_conformal_cov(target_alphas, alphas, setting, False, 'naive', surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test, durations, events)
                coverages_focused, a_hats_focused, lengths_focused = adaptive_conformal_cov(target_alphas, alphas, setting, False, 'focus', surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test, durations, events)
                # Only the fused variant takes the early-event model and
                # reports an include proportion.
                coverages_fused, a_hats_fused, lengths_fused, include_proportion = adaptive_conformal_cov(target_alphas, alphas, setting, False, 'fused', surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test, durations, events, early_event_model)
            else:
                if include_baselines:  
                    lengths_base = base_model_cov(target_alphas, setting, surv_model, df_test, x_test)
                    lengths_csd = csd_cov(target_alphas, setting, surv_model, df_train, df_cal, df_test, x_cal, x_test, durations)
                a_hats_naive, lengths_naive = adaptive_conformal_cov(target_alphas, alphas, setting, False, 'naive', surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test, durations, events)
                a_hats_focused, lengths_focused = adaptive_conformal_cov(target_alphas, alphas, setting, False, 'focus', surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test, durations, events)
                a_hats_fused, lengths_fused, include_proportion = adaptive_conformal_cov(target_alphas, alphas, setting, False, 'fused', surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test, durations, events, early_event_model)

            # Each helper returns one value per entry of target_alphas.
            for j, alpha in enumerate(target_alphas):
                if include_baselines:
                    df.loc[('Uncalibrated', alpha, setting, i), 'LPB'] = lengths_base[j]

                    df.loc[('CSD', alpha, setting, i), 'LPB'] = lengths_csd[j]
                
                df.loc[('Naive', alpha, setting, i), 'A_hat'] = a_hats_naive[j]
                df.loc[('Naive', alpha, setting, i), 'LPB'] = lengths_naive[j]

                df.loc[('Focused', alpha, setting, i), 'A_hat'] = a_hats_focused[j]
                df.loc[('Focused', alpha, setting, i), 'LPB'] = lengths_focused[j]

                df.loc[('Fused', alpha, setting, i), 'A_hat'] = a_hats_fused[j]
                df.loc[('Fused', alpha, setting, i), 'LPB'] = lengths_fused[j]
                df.loc[('Fused', alpha, setting, i), 'Include Proportion'] = include_proportion[j]

                if setting in range(1, 7):
                    if include_baselines:
                        df.loc[('Uncalibrated', alpha, setting, i), 'Coverage'] = coverage_base[j]
                        df.loc[('CSD', alpha, setting, i), 'Coverage'] = coverage_csd[j]
                    df.loc[('Naive', alpha, setting, i), 'Coverage'] = coverages_naive[j]
                    df.loc[('Focused', alpha, setting, i), 'Coverage'] = coverages_focused[j]
                    df.loc[('Fused', alpha, setting, i), 'Coverage'] = coverages_fused[j]
    return df

def run_metrics_comparison_same_model(args, n_runs, target_alphas, alphas, frac_early_cens=0.1, threshold_early_cens=0.15, frac_early_surv=0.1, threshold_early_surv=0.12, retrain=True, settings=range(1, 7)):
    """Compare C-index, IBS and D-calibration of every calibration method.

    For each setting and each of ``n_runs`` repetitions, obtains the models
    (retrained every run if ``retrain`` is true, otherwise trained once per
    setting with the calibration/test split reshuffled by ``repartition``),
    then evaluates the uncalibrated base model, CSD and the Naive / Focused /
    Fused adaptive conformal variants via the ``*_metrics_estimation``
    helpers.

    Args:
        args: training configuration forwarded to ``train_models_get_data``.
        n_runs: number of repetitions per setting.
        target_alphas: levels forwarded to the metric helpers.
        alphas: alpha grid forwarded to the conformal helpers and to
            ``train_early_event_models``.
        frac_early_cens, threshold_early_cens, frac_early_surv,
        threshold_early_surv: forwarded to ``train_models_get_data``.
        retrain: see above.
        settings: settings to run.

    Returns:
        DataFrame indexed by (Method, Setting, Run) with columns 'C-Index',
        'IBS' and 'D-Cal'.
    """
    # NOTE(review): helper calls presumably consume global RNG state; keep
    # their order fixed so seeded runs stay reproducible.
    index = pd.MultiIndex.from_product(
        [['Focused', 'Fused', 'Naive', 'CSD', 'Uncalibrated'], 
         settings, 
         range(n_runs)], 
        names=['Method', 'Setting', 'Run']
    )
    
    df = pd.DataFrame(index=index, columns=['C-Index', 'IBS', 'D-Cal'])
    for setting in settings:
        print(f'\n Setting {setting}')
        if not retrain:
            # Train once per setting; the runs below only reshuffle cal/test.
            surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test = train_models_get_data(setting, args, frac_early_cens=frac_early_cens, threshold_early_cens=threshold_early_cens, frac_early_surv=frac_early_surv, threshold_early_surv=threshold_early_surv)
            early_event_model = train_early_event_models(surv_model, alphas, x_train, df_train, setting)
        for i in tqdm(range(n_runs)):
            if retrain:
                # Fresh data and freshly trained models for every run.
                surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test = train_models_get_data(setting, args, frac_early_cens=frac_early_cens, threshold_early_cens=threshold_early_cens, frac_early_surv=frac_early_surv, threshold_early_surv=threshold_early_surv)
                early_event_model = train_early_event_models(surv_model, alphas, x_train, df_train, setting)
            else:
                # Repartition the calibration and test sets
                df_cal, df_test, x_cal, x_test = repartition(df_cal, df_test, x_cal, x_test)
            durations, events = get_target(df_cal)
            c_ind_base, ibs_base, dcal_base = base_model_metrics_estimation(target_alphas, setting, surv_model, df_test, x_test, df_train)
            c_ind_csd, ibs_csd, dcal_csd = csd_metrics_estimation(target_alphas, setting, surv_model, df_train, df_cal, df_test, x_cal, x_test, durations)
            c_ind_naive, ibs_naive, dcal_naive = adaptive_conformal_metrics_estimation(target_alphas, alphas, setting, False, 'naive', surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test, durations, events)
            c_ind_focused, ibs_focused, dcal_focused = adaptive_conformal_metrics_estimation(target_alphas, alphas, setting, False, 'focus', surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test, durations, events)
            # Only the fused variant additionally takes the early-event model.
            c_ind_fused, ibs_fused, dcal_fused = adaptive_conformal_metrics_estimation(target_alphas, alphas, setting, False, 'fused', surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test, durations, events, early_event_model)

            df.loc[('Uncalibrated', setting, i), 'C-Index'] = c_ind_base
            df.loc[('Uncalibrated', setting, i), 'IBS'] = ibs_base
            df.loc[('Uncalibrated', setting, i), 'D-Cal'] = dcal_base

            df.loc[('CSD', setting, i), 'C-Index'] = c_ind_csd
            df.loc[('CSD', setting, i), 'IBS'] = ibs_csd
            df.loc[('CSD', setting, i), 'D-Cal'] = dcal_csd

            df.loc[('Naive', setting, i), 'C-Index'] = c_ind_naive
            df.loc[('Naive', setting, i), 'IBS'] = ibs_naive
            df.loc[('Naive', setting, i), 'D-Cal'] = dcal_naive

            df.loc[('Focused', setting, i), 'C-Index'] = c_ind_focused
            df.loc[('Focused', setting, i), 'IBS'] = ibs_focused
            df.loc[('Focused', setting, i), 'D-Cal'] = dcal_focused

            df.loc[('Fused', setting, i), 'C-Index'] = c_ind_fused
            df.loc[('Fused', setting, i), 'IBS'] = ibs_fused
            df.loc[('Fused', setting, i), 'D-Cal'] = dcal_fused
    return df

def repartition(df_cal, df_test, x_cal, x_test):
    """Pool the calibration and test sets and re-split them 50/50 at random.

    Rows of the DataFrames and of the feature arrays stay aligned: row ``k``
    of each returned DataFrame corresponds to row ``k`` of the matching
    returned array.

    Args:
        df_cal, df_test: calibration / test DataFrames with the same columns.
        x_cal, x_test: feature arrays whose rows line up with ``df_cal`` and
            ``df_test`` respectively.

    Returns:
        ``(df_cal, df_test, x_cal, x_test)`` for the new split. Both
        DataFrames get a fresh ``0..n-1`` index, so index labels equal row
        positions in the corresponding arrays.

    Raises:
        ValueError: if a DataFrame and its feature array differ in length
            (the split would otherwise silently misalign features and targets).
    """
    if len(df_cal) != len(x_cal) or len(df_test) != len(x_test):
        raise ValueError(
            "repartition: each DataFrame must have as many rows as its feature array"
        )

    # Pool both sets; ignore_index makes the labels 0..n-1, i.e. exactly the
    # row positions in the concatenated feature array.
    pooled_df = pd.concat([df_cal, df_test], ignore_index=True)
    pooled_x = np.concatenate([x_cal, x_test])

    # Draw the new test half. With random_state=None pandas uses NumPy's
    # global RNG, so np.random.seed controls this split.
    new_df_test = pooled_df.sample(frac=0.5)
    test_pos = new_df_test.index.to_numpy()
    new_x_test = pooled_x[test_pos]
    new_x_cal = np.delete(pooled_x, test_pos, axis=0)
    # drop() keeps the remaining rows in ascending label order, which is the
    # same order np.delete leaves the remaining array rows in.
    new_df_cal = pooled_df.drop(new_df_test.index)

    # Reset labels so they once again match positions in the feature arrays.
    return (
        new_df_cal.reset_index(drop=True),
        new_df_test.reset_index(drop=True),
        new_x_cal,
        new_x_test,
    )

### Run the experiments

#### Synthetic data

In [59]:
# Training configuration for the synthetic-data experiments.
# num_nodes=[5] is the configuration the ablation below labels 'Shallow q';
# n_samples is presumably the synthetic dataset size (TODO confirm in datasets.py).
# NOTE(review): one EarlyStopping instance is shared by every model trained
# with this dict; confirm torchtuples resets its state between fits.
args = dict(
    n_samples=2000,
    num_nodes=[5],
    batch_norm='batch',
    dropout=0.1,
    batch_size=256,
    lr=0.002,
    epochs=50,
    callbacks=[tt.callbacks.EarlyStopping(patience=5)],
    verbose=False,
)

In [None]:
# Seed Python, NumPy and PyTorch RNGs so the synthetic runs are reproducible.
random.seed(12)
np.random.seed(12)
_ = torch.manual_seed(12)

# 30 log-spaced candidate alphas in [0.01, 0.1]; results are reported at 0.1.
alphas = np.logspace(start=-2, stop=-1, num=30)
target_alphas = np.array([0.1]).round(5)

# Settings 1-6, 20 runs each, with models retrained on every run.
df_synthetic = run_coverage_comparison_same_model(
    args,
    n_runs=20,
    target_alphas=target_alphas,
    alphas=alphas,
    frac_early_cens=0.12,
    frac_early_surv=0.05,
    retrain=True,
    settings=range(1, 7),
)

In [None]:
# Plot the synthetic-data results (df_synthetic) at target alpha 0.1 using
# the helper imported from plotting.py; fontsize is passed through to it.
plot_coverage_and_lpb_comperison(df_synthetic, 0.1, fontsize=30)

In [45]:
# alphas = np.linspace(0.02, 0.98, 98)
# target_alphas = np.round(np.linspace(0.1, 0.9, 20), 10)
# df_metrics_synth = run_metrics_comparison_same_model(args, 5, target_alphas, alphas, retrain=False, settings=range(1,7), frac_early_surv=0.05, frac_early_cens=0.12)

In [46]:
# plot_c_ind_ibs_and_dcal_comperison(df_metrics_synth)

#### Real-world data (TCGA-BRCA, SUPPORT, METABRIC, Churn, NACD, GBSG)

In [47]:
# Training configuration for the real-data experiments: two hidden layers of
# 32 units and a larger epoch budget than the synthetic configuration.
# No n_samples key here (presumably real datasets are used at full size).
# NOTE(review): one EarlyStopping instance is shared by every model trained
# with this dict; confirm torchtuples resets its state between fits.
args_real = dict(
    num_nodes=[32, 32],
    batch_norm='batch',
    dropout=0.1,
    batch_size=256,
    lr=0.002,
    epochs=1000,
    callbacks=[tt.callbacks.EarlyStopping(patience=5)],
    verbose=False,
)

In [None]:
# Seed Python, NumPy and PyTorch RNGs so the real-data runs are reproducible.
random.seed(12)
np.random.seed(12)
_ = torch.manual_seed(12)

# 30 log-spaced candidate alphas in [0.01, 0.1]; results are reported at 0.1.
alphas = np.logspace(start=-2, stop=-1, num=30)
target_alphas = np.array([0.1]).round(5)

# Models are trained once per dataset; each of the 100 runs re-splits cal/test.
df_real = run_coverage_comparison_same_model(
    args_real,
    n_runs=100,
    target_alphas=target_alphas,
    alphas=alphas,
    retrain=False,
    settings=['tcga', 'support', 'metabric', 'churn', 'nacd', 'gbsg'],
)

In [None]:
# Print the mean include proportion per dataset for the Fused method.
# Only 'Fused' rows ever receive this column, so select them explicitly
# instead of printing all-NaN rows for the other methods; the column has
# object dtype (the DataFrame was created without a dtype), hence astype(float).
print(
    df_real.xs('Fused', level='Method')['Include Proportion']
    .astype(float)
    .groupby(level='Setting')
    .mean()
)

In [None]:
# Plot the real-data results (df_real) at target alpha 0.1 using the
# plot_lpbs helper imported from plotting.py.
plot_lpbs(df_real, 0.1)


In [16]:
# alphas = np.linspace(0.02, 0.98, 98)
# target_alphas = np.round(np.linspace(0.1, 0.9, 10), 10)
# df_metrics_real = run_metrics_comparison_same_model(args_real, 5, target_alphas, alphas, retrain=False, settings=['support', 'metabric', 'churn', 'tcga', 'nacd', 'gbsg'])
# plot_c_ind_ibs_and_dcal_comperison(df_metrics_real)

### Hyperparameter ablation (model depth)

In [17]:
def run_coverage_comparison_ablation(args, n_runs, target_alphas, alphas, frac_early_cens=0.1, threshold_early_cens=0.15, frac_early_surv=0.1, threshold_early_surv=0.12, retrain=True, settings=range(1, 7)):
    """Ablate model depth for the Naive, Focused and Fused conformal methods.

    For each setting, every combination of
    ``num_nodes`` in {'Shallow q': [5], 'Deep q': [5, 5, 5]} and
    ``max_depth_w`` in {'Shallow w': 2, 'Deep w': 6} is evaluated over
    ``n_runs`` repetitions (retraining every run if ``retrain`` is true,
    otherwise training once per combination and re-splitting cal/test with
    ``repartition``).

    Args:
        args: training configuration forwarded to ``train_models_get_data``.
            It is copied, so the caller's dict is NOT modified.
        n_runs: number of repetitions per configuration.
        target_alphas: levels to evaluate; also an index level of the result.
        alphas: alpha grid forwarded to ``adaptive_conformal_cov`` and
            ``train_early_event_models``.
        frac_early_cens, threshold_early_cens, frac_early_surv,
        threshold_early_surv: forwarded to ``train_models_get_data``.
        retrain: see above.
        settings: settings to run; coverage is recorded only for 1-6.

    Returns:
        DataFrame indexed by (Method, Target Alpha, Setting, Run,
        Depth Quantiles, Depth Weights) with columns 'Coverage', 'A_hat',
        'LPB', 'Include Proportion'.
    """
    # Work on a shallow copy: the loop overwrites 'num_nodes', and mutating
    # the caller's dict would leak the last ablation architecture ([5, 5, 5])
    # into every later experiment that reuses the same ``args``.
    args = dict(args)
    dqs = {'Shallow q': [5], 'Deep q': [5, 5, 5]}
    dws = {'Shallow w': 2, 'Deep w': 6}
    index = pd.MultiIndex.from_product(
        [['Focused', 'Fused', 'Naive'], 
         target_alphas, 
         settings, 
         range(n_runs),
         dqs.keys(),
         dws.keys()],
        names=['Method', 'Target Alpha', 'Setting', 'Run', 'Depth Quantiles', 'Depth Weights']
    )
    df = pd.DataFrame(index=index, columns=['Coverage', 'A_hat', 'LPB', 'Include Proportion'])
    for setting in settings:
        print(f'\n Setting {setting}')
        for dq in dqs.keys():
            for dw in dws.keys():
                args["num_nodes"] = dqs[dq]
                if not retrain:
                    surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test = train_models_get_data(setting, args, frac_early_cens=frac_early_cens, threshold_early_cens=threshold_early_cens, frac_early_surv=frac_early_surv, threshold_early_surv=threshold_early_surv, max_depth_w=dws[dw])
                    early_event_model = train_early_event_models(surv_model, alphas, x_train, df_train, setting)
                for i in tqdm(range(n_runs)):
                    if retrain:
                        surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test = train_models_get_data(setting, args, frac_early_cens=frac_early_cens, threshold_early_cens=threshold_early_cens, frac_early_surv=frac_early_surv, threshold_early_surv=threshold_early_surv, max_depth_w=dws[dw])
                        early_event_model = train_early_event_models(surv_model, alphas, x_train, df_train, setting)
                    else:
                        # Repartition the calibration and test sets
                        df_cal, df_test, x_cal, x_test = repartition(df_cal, df_test, x_cal, x_test)
                    durations, events = get_target(df_cal)
                    # Settings 1-6 additionally return coverage.
                    if setting in range(1,7):
                        coverages_naive, a_hats_naive, lengths_naive = adaptive_conformal_cov(target_alphas, alphas, setting, False, 'naive', surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test, durations, events)
                        coverages_focused, a_hats_focused, lengths_focused = adaptive_conformal_cov(target_alphas, alphas, setting, False, 'focus', surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test, durations, events)
                        coverages_fused, a_hats_fused, lengths_fused, include_proportion = adaptive_conformal_cov(target_alphas, alphas, setting, False, 'fused', surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test, durations, events, early_event_model)
                    else:
                        a_hats_naive, lengths_naive = adaptive_conformal_cov(target_alphas, alphas, setting, False, 'naive', surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test, durations, events)
                        a_hats_focused, lengths_focused = adaptive_conformal_cov(target_alphas, alphas, setting, False, 'focus', surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test, durations, events)
                        a_hats_fused, lengths_fused, include_proportion = adaptive_conformal_cov(target_alphas, alphas, setting, False, 'fused', surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test, durations, events, early_event_model)

                    for j, alpha in enumerate(target_alphas):
                        df.loc[('Naive', alpha, setting, i, dq, dw), 'A_hat'] = a_hats_naive[j]
                        df.loc[('Naive', alpha, setting, i, dq, dw), 'LPB'] = lengths_naive[j]

                        df.loc[('Focused', alpha, setting, i, dq, dw), 'A_hat'] = a_hats_focused[j]
                        df.loc[('Focused', alpha, setting, i, dq, dw), 'LPB'] = lengths_focused[j]

                        df.loc[('Fused', alpha, setting, i, dq, dw), 'A_hat'] = a_hats_fused[j]
                        df.loc[('Fused', alpha, setting, i, dq, dw), 'LPB'] = lengths_fused[j]
                        df.loc[('Fused', alpha, setting, i, dq, dw), 'Include Proportion'] = include_proportion[j]

                        if setting in range(1, 7):
                            df.loc[('Naive', alpha, setting, i, dq, dw), 'Coverage'] = coverages_naive[j]
                            df.loc[('Focused', alpha, setting, i, dq, dw), 'Coverage'] = coverages_focused[j]
                            df.loc[('Fused', alpha, setting, i, dq, dw), 'Coverage'] = coverages_fused[j]
    return df

In [None]:
# Seed Python, NumPy and PyTorch RNGs so the ablation is reproducible.
random.seed(1)
np.random.seed(1)
_ = torch.manual_seed(1)

# 60 evenly spaced candidate alphas in [0.02, 0.15]; results reported at 0.1.
alphas = np.linspace(start=0.02, stop=0.15, num=60)
target_alphas = np.array([0.1]).round(5)

# Setting 3 only, 10 retrained runs per depth combination.
df_ablation = run_coverage_comparison_ablation(
    args,
    n_runs=10,
    target_alphas=target_alphas,
    alphas=alphas,
    retrain=True,
    settings=range(3, 4),
)

In [None]:
# Plot the depth-ablation results (df_ablation) at target alpha 0.1 using the
# plot_lpbs_ablation helper imported from plotting.py.
plot_lpbs_ablation(df_ablation, 0.1)


### Censorship proportion experiment

In [None]:
def run_comparison_censorship(args, n_runs, target_alphas, alphas, frac_early_cens=0.1, threshold_early_cens=0.15, frac_early_surv=0.1, threshold_early_surv=0.12, retrain=True, settings=range(1, 7)):
    """Evaluate the conformal methods across end-of-trial times (censoring levels).

    For each setting and each of 10 end-of-trial times in
    ``np.linspace(0.5, 10, 10)`` (forwarded to ``train_models_get_data`` as
    ``eot``), runs the Naive, Focused and Fused variants ``n_runs`` times
    (retraining every run if ``retrain`` is true, otherwise training once per
    ``eot`` and re-splitting cal/test with ``repartition``).

    Args:
        args: training configuration forwarded to ``train_models_get_data``.
        n_runs: number of repetitions per end-of-trial time.
        target_alphas: levels to evaluate; also an index level of the result.
        alphas: alpha grid forwarded to ``adaptive_conformal_cov`` and
            ``train_early_event_models``.
        frac_early_cens, threshold_early_cens, frac_early_surv,
        threshold_early_surv: forwarded to ``train_models_get_data``.
        retrain: see above.
        settings: settings to run; coverage is recorded only for 1-6.

    Returns:
        DataFrame indexed by (Method, Target Alpha, Setting, Run,
        End of trial time) with columns 'Coverage', 'A_hat', 'LPB',
        'Include Proportion', 'Censorship Rate'. 'Censorship Rate' is
        ``1 - df_train['event'].mean()`` for the run's training data.
    """
    # NOTE(review): helper calls presumably consume global RNG state; keep
    # their order fixed so seeded runs stay reproducible.
    end_of_trial = np.linspace(0.5, 10, 10)
    index = pd.MultiIndex.from_product(
        [['Focused', 'Fused', 'Naive'], 
         target_alphas, 
         settings, 
         range(n_runs),
         end_of_trial],
        names=['Method', 'Target Alpha', 'Setting', 'Run', 'End of trial time']
    )
    df = pd.DataFrame(index=index, columns=['Coverage', 'A_hat', 'LPB', 'Include Proportion', 'Censorship Rate'])
    for setting in settings:
        print(f'\n Setting {setting}')
        for eot in end_of_trial:
            if not retrain:
                # Train once per end-of-trial time; runs only reshuffle cal/test.
                surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test = train_models_get_data(setting, args, frac_early_cens=frac_early_cens, threshold_early_cens=threshold_early_cens, frac_early_surv=frac_early_surv, threshold_early_surv=threshold_early_surv, eot=eot)
                early_event_model = train_early_event_models(surv_model, alphas, x_train, df_train, setting)
            for i in tqdm(range(n_runs)):
                if retrain:
                    surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test = train_models_get_data(setting, args, frac_early_cens=frac_early_cens, threshold_early_cens=threshold_early_cens, frac_early_surv=frac_early_surv, threshold_early_surv=threshold_early_surv, eot=eot)
                    early_event_model = train_early_event_models(surv_model, alphas, x_train, df_train, setting)
                else:
                    # Repartition the calibration and test sets
                    df_cal, df_test, x_cal, x_test = repartition(df_cal, df_test, x_cal, x_test)
                durations, events = get_target(df_cal)
                # Settings 1-6 additionally return coverage.
                if setting in range(1,7):
                    coverages_naive, a_hats_naive, lengths_naive = adaptive_conformal_cov(target_alphas, alphas, setting, False, 'naive', surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test, durations, events)
                    coverages_focused, a_hats_focused, lengths_focused = adaptive_conformal_cov(target_alphas, alphas, setting, False, 'focus', surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test, durations, events)
                    coverages_fused, a_hats_fused, lengths_fused, include_proportion = adaptive_conformal_cov(target_alphas, alphas, setting, False, 'fused', surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test, durations, events, early_event_model)
                else:
                    a_hats_naive, lengths_naive = adaptive_conformal_cov(target_alphas, alphas, setting, False, 'naive', surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test, durations, events)
                    a_hats_focused, lengths_focused = adaptive_conformal_cov(target_alphas, alphas, setting, False, 'focus', surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test, durations, events)
                    a_hats_fused, lengths_fused, include_proportion = adaptive_conformal_cov(target_alphas, alphas, setting, False, 'fused', surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test, durations, events, early_event_model)

                # The censorship rate is measured on the training split and is
                # the same value for all three methods in a given run.
                for j, alpha in enumerate(target_alphas):
                    df.loc[('Naive', alpha, setting, i, eot), 'A_hat'] = a_hats_naive[j]
                    df.loc[('Naive', alpha, setting, i, eot), 'LPB'] = lengths_naive[j]
                    df.loc[('Naive', alpha, setting, i, eot), 'Censorship Rate'] = 1 - df_train['event'].mean()

                    df.loc[('Focused', alpha, setting, i, eot), 'A_hat'] = a_hats_focused[j]
                    df.loc[('Focused', alpha, setting, i, eot), 'LPB'] = lengths_focused[j]
                    df.loc[('Focused', alpha, setting, i, eot), 'Censorship Rate'] = 1 - df_train['event'].mean()

                    df.loc[('Fused', alpha, setting, i, eot), 'A_hat'] = a_hats_fused[j]
                    df.loc[('Fused', alpha, setting, i, eot), 'LPB'] = lengths_fused[j]
                    df.loc[('Fused', alpha, setting, i, eot), 'Include Proportion'] = include_proportion[j]
                    df.loc[('Fused', alpha, setting, i, eot), 'Censorship Rate'] = 1 - df_train['event'].mean()

                    if setting in range(1, 7):
                        df.loc[('Naive', alpha, setting, i, eot), 'Coverage'] = coverages_naive[j]
                        df.loc[('Focused', alpha, setting, i, eot), 'Coverage'] = coverages_focused[j]
                        df.loc[('Fused', alpha, setting, i, eot), 'Coverage'] = coverages_fused[j]
    return df

# Seed Python, NumPy and PyTorch RNGs so the censorship sweep is reproducible.
random.seed(1)
np.random.seed(1)
_ = torch.manual_seed(1)

# 60 evenly spaced candidate alphas in [0.02, 0.15]; results reported at 0.1.
alphas = np.linspace(start=0.02, stop=0.15, num=60)
target_alphas = np.array([0.1]).round(5)

# Setting 3 only, 10 retrained runs per end-of-trial time.
df_censorship = run_comparison_censorship(
    args,
    n_runs=10,
    target_alphas=target_alphas,
    alphas=alphas,
    retrain=True,
    settings=range(3, 4),
)

In [None]:
# Plot the censorship-sweep results (df_censorship) at target alpha 0.1 using
# the helper imported from plotting.py.
# NOTE(review): the meaning of the third argument (17) is defined in
# plotting.py — presumably a font size; confirm there.
plot_censorship_rate_lpb_coverage(df_censorship, 0.1, 17)

### Sample-size experiment

In [None]:
def run_comparison_samples(args, n_runs, target_alphas, alphas, frac_early_cens=0.1, threshold_early_cens=0.15, frac_early_surv=0.1, threshold_early_surv=0.12, retrain=True, settings=range(1, 7)):
    """Evaluate the conformal methods across dataset sizes.

    For each setting and each ``n_samples`` in
    [200, 500, 800, 1100, 1400, 1700, 2000] (passed to
    ``train_models_get_data`` through ``args['n_samples']``), runs the Naive,
    Focused and Fused variants ``n_runs`` times (retraining every run if
    ``retrain`` is true, otherwise training once per size and re-splitting
    cal/test with ``repartition``).

    Args:
        args: training configuration forwarded to ``train_models_get_data``.
            It is copied, so the caller's dict is NOT modified.
        n_runs: number of repetitions per dataset size.
        target_alphas: levels to evaluate; also an index level of the result.
        alphas: alpha grid forwarded to ``adaptive_conformal_cov`` and
            ``train_early_event_models``.
        frac_early_cens, threshold_early_cens, frac_early_surv,
        threshold_early_surv: forwarded to ``train_models_get_data``.
        retrain: see above.
        settings: settings to run; coverage is recorded only for 1-6.

    Returns:
        DataFrame indexed by (Method, Target Alpha, Setting, Run, Num Samples)
        with columns 'Coverage', 'A_hat', 'LPB', 'Include Proportion',
        'Censorship Rate' (``1 - df_train['event'].mean()`` for the run).
    """
    # Work on a shallow copy: the loop overwrites 'n_samples', which must not
    # leak back into the caller's shared configuration dict.
    args = dict(args)
    num_samples = [200, 500, 800, 1100, 1400, 1700, 2000]
    index = pd.MultiIndex.from_product(
        [['Focused', 'Fused', 'Naive'], 
         target_alphas, 
         settings, 
         range(n_runs),
         num_samples],
        names=['Method', 'Target Alpha', 'Setting', 'Run', 'Num Samples']
    )
    df = pd.DataFrame(index=index, columns=['Coverage', 'A_hat', 'LPB', 'Include Proportion', 'Censorship Rate'])
    for setting in settings:
        print(f'\n Setting {setting}')
        for n_samples in num_samples:
            args['n_samples'] = n_samples
            if not retrain:
                surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test = train_models_get_data(setting, args, frac_early_cens=frac_early_cens, threshold_early_cens=threshold_early_cens, frac_early_surv=frac_early_surv, threshold_early_surv=threshold_early_surv)
                early_event_model = train_early_event_models(surv_model, alphas, x_train, df_train, setting)
            for i in tqdm(range(n_runs)):
                if retrain:
                    surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test = train_models_get_data(setting, args, frac_early_cens=frac_early_cens, threshold_early_cens=threshold_early_cens, frac_early_surv=frac_early_surv, threshold_early_surv=threshold_early_surv)
                    early_event_model = train_early_event_models(surv_model, alphas, x_train, df_train, setting)
                else:
                    # Repartition the calibration and test sets
                    df_cal, df_test, x_cal, x_test = repartition(df_cal, df_test, x_cal, x_test)
                durations, events = get_target(df_cal)
                # Settings 1-6 additionally return coverage.
                if setting in range(1,7):
                    coverages_naive, a_hats_naive, lengths_naive = adaptive_conformal_cov(target_alphas, alphas, setting, False, 'naive', surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test, durations, events)
                    coverages_focused, a_hats_focused, lengths_focused = adaptive_conformal_cov(target_alphas, alphas, setting, False, 'focus', surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test, durations, events)
                    coverages_fused, a_hats_fused, lengths_fused, include_proportion = adaptive_conformal_cov(target_alphas, alphas, setting, False, 'fused', surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test, durations, events, early_event_model)
                else:
                    a_hats_naive, lengths_naive = adaptive_conformal_cov(target_alphas, alphas, setting, False, 'naive', surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test, durations, events)
                    a_hats_focused, lengths_focused = adaptive_conformal_cov(target_alphas, alphas, setting, False, 'focus', surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test, durations, events)
                    a_hats_fused, lengths_fused, include_proportion = adaptive_conformal_cov(target_alphas, alphas, setting, False, 'fused', surv_model, surv_classifier, df_train, df_cal, df_test, x_train, x_cal, x_test, durations, events, early_event_model)

                for j, alpha in enumerate(target_alphas):
                    df.loc[('Naive', alpha, setting, i, n_samples), 'A_hat'] = a_hats_naive[j]
                    df.loc[('Naive', alpha, setting, i, n_samples), 'LPB'] = lengths_naive[j]
                    df.loc[('Naive', alpha, setting, i, n_samples), 'Censorship Rate'] = 1 - df_train['event'].mean()

                    df.loc[('Focused', alpha, setting, i, n_samples), 'A_hat'] = a_hats_focused[j]
                    df.loc[('Focused', alpha, setting, i, n_samples), 'LPB'] = lengths_focused[j]
                    df.loc[('Focused', alpha, setting, i, n_samples), 'Censorship Rate'] = 1 - df_train['event'].mean()

                    df.loc[('Fused', alpha, setting, i, n_samples), 'A_hat'] = a_hats_fused[j]
                    df.loc[('Fused', alpha, setting, i, n_samples), 'LPB'] = lengths_fused[j]
                    df.loc[('Fused', alpha, setting, i, n_samples), 'Include Proportion'] = include_proportion[j]
                    df.loc[('Fused', alpha, setting, i, n_samples), 'Censorship Rate'] = 1 - df_train['event'].mean()

                    if setting in range(1, 7):
                        df.loc[('Naive', alpha, setting, i, n_samples), 'Coverage'] = coverages_naive[j]
                        df.loc[('Focused', alpha, setting, i, n_samples), 'Coverage'] = coverages_focused[j]
                        df.loc[('Fused', alpha, setting, i, n_samples), 'Coverage'] = coverages_fused[j]
    return df

# Seed Python, NumPy and PyTorch RNGs so the sample-size sweep is reproducible.
random.seed(1)
np.random.seed(1)
_ = torch.manual_seed(1)

# 60 evenly spaced candidate alphas in [0.02, 0.15]; results reported at 0.1.
alphas = np.linspace(start=0.02, stop=0.15, num=60)
target_alphas = np.array([0.1]).round(5)

# Setting 3 only, 10 retrained runs per dataset size.
df_samples = run_comparison_samples(
    args,
    n_runs=10,
    target_alphas=target_alphas,
    alphas=alphas,
    retrain=True,
    settings=range(3, 4),
)

In [None]:
# Plot the sample-size sweep results (df_samples) at target alpha 0.1 using
# the plot_lpb_coverage_n_samples helper imported from plotting.py.
plot_lpb_coverage_n_samples(df_samples, 0.1)