In [None]:
pip install lifelines

In [None]:
from lifelines import KaplanMeierFitter
# Shared module-level Kaplan-Meier fitter; it is re-fit in place wherever it is used.
kmf = KaplanMeierFitter()
from lifelines import CoxPHFitter
# Shared module-level Cox PH fitter; re-fit in place for every simulated trial in main().
coxph=CoxPHFitter()
import numpy as np
import time
import math
import pandas as pd

# Simulated dataset of each trial (a DataFrame), keyed by trial index; filled by main().
df_record={}

def c(a, b):
    """Join two arrays end to end along the first axis (R-style ``c(a, b)``)."""
    joined = np.concatenate([a, b])
    return joined

def rep(val, n):
    """Return a length-``n`` array whose entries all equal ``val`` (R-style ``rep``).

    Built as a float64 array of ones scaled by ``val``.
    """
    ones = np.ones(shape=(n,))
    return ones * val

def sum_v_2d(x):
    """Return the upper-triangular matrix of contiguous partial sums of ``x``.

    Entry ``[i, j]`` (i <= j) equals ``sum(x[i:j+1])``; entries below the
    diagonal are zero.
    """
    running = np.cumsum(x)
    # Row 0 is the plain prefix sum; row k >= 1 removes the prefix ending at k-1.
    shifted = running[None, :] - running[:-1, None]
    return np.triu(np.vstack((running, shifted)))

def mono_haz(df1_y1, df1_cov1, df1_delta, df1_truncation, adjust, beta_hat):
    """Estimate a non-decreasing, piecewise-constant hazard under left truncation.

    The time grid ``z_k`` is the sorted set of distinct observed and truncation
    times.  For each grid interval j the estimate is the min-max solution

        lamda[j] = max_{i <= j} min_{k >= j} D[i..k] / R[i..k]

    where D[i..k] is the number of deaths at grid points i..k and R[i..k] is the
    time at risk accumulated over the corresponding grid intervals.

    Args:
        df1_y1: observed times (event or censoring).
        df1_cov1: covariate values; only matter when ``adjust`` is True.
        df1_delta: event indicators (1/True = death observed).
        df1_truncation: left-truncation (entry) times.
        adjust: if True, weight the time at risk by ``exp(cov * beta_hat)``
            (Cox-adjusted baseline hazard); otherwise use the risk-set size.
        beta_hat: regression coefficient used for the weighting.

    Returns:
        ``(lamda, z_k)``: hazard value per grid point and the grid, both length N.
        The last grid point reuses the value of the one before it.
    """
    z_k = np.unique(np.concatenate((df1_y1, df1_truncation)))
    # Keep every death time (with multiplicity) so that tied deaths are all counted;
    # comparing against np.unique(...) would count a tie only once.
    death_times = df1_y1[df1_delta == 1]
    N = z_k.shape[0]
    death_sum = np.zeros(N)
    risk_set = np.zeros(N)
    exp_term = np.zeros(N)
    lamda = np.zeros(N)
    for h, z in enumerate(z_k):
        # At risk at z: already entered (truncation <= z) and still observed (y1 >= z).
        index = np.logical_and(df1_y1 >= z, df1_truncation <= z)
        risk_set[h] = np.sum(index)
        death_sum[h] = np.sum(death_times == z)
        exp_term[h] = np.sum(np.exp(df1_cov1[index] * beta_hat))

    difference = np.diff(z_k)
    nom_2d = sum_v_2d(death_sum[:-1])
    if adjust:
        denom_2d = sum_v_2d(exp_term[:-1] * difference)
    else:
        denom_2d = sum_v_2d(risk_set[:-1] * difference)
    # Guard against division by zero (below-diagonal zeros and empty risk time).
    ratio_2d = nom_2d / (denom_2d + 1e-4 * (denom_2d == 0))

    for j in range(N - 1):
        lamda[j] = np.max(np.min(ratio_2d[0:j + 1, j:N - 1], axis=1), axis=0)
    if N >= 2:
        lamda[N - 1] = lamda[N - 2]
    return lamda, z_k

def baseline_survival_sim(t1, df1_y1, df1_cov1, df1_delta, df1_truncation, hazard_est):
    """Return exp(-Lambda(t1)) for a piecewise-constant hazard on the observed-time grid.

    The grid ``z_k`` is the sorted set of distinct values in ``df1_y1`` and
    ``df1_truncation`` (the same grid ``mono_haz`` returns).  ``hazard_est[k]`` is
    the hazard on ``(z_{k-1}, z_k]`` with ``z_{-1} = 0``.  Beyond the last grid
    time no further hazard is accumulated.

    Args:
        t1: scalar evaluation time.
        df1_y1, df1_truncation: observed and truncation times defining the grid.
        df1_cov1, df1_delta: unused; kept for call compatibility.
        hazard_est: one hazard value per grid time (length ``len(z_k)``).

    Returns:
        The survival value.  It is a numpy scalar when ``t1 <= z_k[0]`` and a
        length-1 array otherwise (kept for existing callers).
    """
    z_k = np.unique(np.concatenate((df1_y1, df1_truncation)))
    # Interval lengths z_k - z_{k-1}, with 0 as the left end of the first interval.
    widths = np.diff(np.concatenate((np.zeros(1), z_k)))
    # right=True gives z_I < t1 <= z_{I+1}.
    I = np.digitize(np.array([t1]), bins=z_k, right=True) - 1
    if t1 <= z_k[0]:
        cumhaz1 = 0
        rest = t1 * hazard_est[0]
    elif t1 > z_k[-1]:
        # Past the grid: the full cumulative hazard, no extrapolation.
        cumhaz1 = np.cumsum(widths * hazard_est)[I]
        rest = 0
    else:
        # Includes t1 == z_k[-1]: the last interval must contribute its partial
        # hazard (a `>=` test here dropped it and made the curve jump at z_k[-1]).
        cumhaz1 = np.cumsum(widths * hazard_est)[I]
        rest = (t1 - z_k[I]) * hazard_est[I + 1]
    integration = cumhaz1 + rest
    return np.exp(-integration)

def rweibull(n, shape, scale):
    """Draw ``n`` Weibull(shape, scale) variates from numpy's global RNG."""
    standard_draws = np.random.weibull(shape, n)
    return standard_draws * scale

def med_weibull(shape,scale):
    """Return the median of a Weibull(shape, scale) distribution: scale * ln(2)**(1/shape)."""
    # `^` is bitwise XOR in Python (a TypeError on floats); exponentiation is `**`.
    return (-np.log(0.5)) ** (1 / shape) * scale

def main(a,x_time,n_trials,n_samples_old,method_i,modify=False, no_est=False,trun_scale=1200):
    """Simulate left-truncated survival data and compare baseline-survival estimators.

    Each trial draws two Weibull(shape=a) groups of ``n_samples_old // 2`` subjects
    (cov1=0: scale 1000, cov1=1: scale 2000), each with Weibull censoring of the same
    scale.  It then generates truncation times with the scheme ``method_i`` and
    estimates the cov1 == 0 survival at ``x_time`` with:
      * wu   - Cox-adjusted monotone hazard (``mono_haz(adjust=True)``) on all data,
      * tsai - unadjusted monotone hazard on the cov1 == 0 group,
      * cox  - Breslow baseline survival from the Cox fit,
      * km   - Kaplan-Meier on the cov1 == 0 group.
    The summary prints ``mean - 0.5`` (labelled "bias", so it assumes ``x_time`` is
    the true baseline median) and the sample standard deviation (labelled "variance").

    Args:
        a: Weibull shape parameter.
        x_time: scalar time at which survival is evaluated.
        n_trials: number of trials; trial i uses ``np.random.seed(i)``.
        n_samples_old: nominal sample size per trial (rounded down to an even number).
        method_i: truncation scheme.
            1 = uniform fraction of y1.
            2 = resample until truncation <= y1.
            3 = swap violating pairs.
            4 = drop violating pairs.
        modify: if True, keep only subjects whose truncation exceeds the smallest
            y1, plus the subject(s) attaining it.
        no_est: if True, only simulate and record data; skip estimation.
        trun_scale: Weibull scale of the truncation times for method 4.

    Raises:
        ValueError: if ``method_i`` is not 1, 2, 3 or 4.

    Side effects:
        Fills the module-level ``df_record``, re-fits ``kmf``/``coxph``, mounts Google
        Drive and writes ``data_monotone_hazard.npz`` (unless ``no_est``) and
        ``df_concat.csv`` there.
    """
    if method_i not in (1, 2, 3, 4):
        raise ValueError("method_i must be 1, 2, 3 or 4, got %r" % (method_i,))

    surv_median_est = []
    surv_median_tsai = []
    km_med = np.zeros(n_trials)
    cox_med = np.zeros(n_trials)
    elapsed = 0
    save_data = {"y1": [], "truncation": [], "cov1": [], "delta": [],
                 "tsai_weibull": [], "wu_weibull": [],
                 "surv_res_est": [], "surv_res_tsai": [],
                 "tsai_zk": [], "wu_zk": [], "death_sum": [], "exp_term": []}
    min_t = None  # smallest "largest grid time" over trials and both hazard estimates
    min_x = None  # smallest observed time over trials

    dataframe_list = []
    sample_size = []

    for i in range(n_trials):
        tt1 = time.time()
        np.random.seed(i)
        half = n_samples_old // 2
        # Both groups have `half` subjects, so the working size is 2*half; using an
        # odd n_samples_old directly caused shape mismatches in the schemes below.
        n_samples = 2 * half
        man1 = rweibull(half, shape=a, scale=1000)
        cen1 = rweibull(half, shape=a, scale=1000)
        woman2 = rweibull(half, shape=a, scale=2000)
        cen2 = rweibull(half, shape=a, scale=2000)
        censor = c(cen1, cen2)
        failure = c(man1, woman2)
        cov1 = c(rep(0, half), rep(1, half))
        delta = failure <= censor
        y1 = delta * failure + (1 - delta) * censor

        if method_i == 1:
            # original method: truncation is a uniform fraction of y1
            trun = y1 * np.random.uniform(0, 1, n_samples)
            truncation = np.where(trun > y1, y1 * 0.9, trun)
        elif method_i == 2:
            # resample violating truncation times until all truncation <= y1
            truncation = rweibull(n_samples, shape=a, scale=300)
            indices = np.where(truncation > y1)[0]
            while len(indices) > 0:
                big_n = n_samples * 2
                new_truncation = rweibull(len(indices) * big_n, shape=a, scale=1000).reshape(len(indices), big_n)
                truncation[indices] = np.min(new_truncation, axis=1)
                indices = np.where(truncation > y1)[0]
        elif method_i == 3:
            # swap y1 and truncation where truncation > y1
            truncation = rweibull(n_samples, shape=a, scale=500)
            indices = truncation > y1
            y1[indices], truncation[indices] = truncation[indices], y1[indices]
        else:
            # method 4: drop the pairs with truncation > y1
            truncation = rweibull(n_samples, shape=a, scale=trun_scale)
            indices = np.where(truncation <= y1)[0]
            y1 = y1[indices]
            cov1 = cov1[indices]
            delta = delta[indices]
            truncation = truncation[indices]
            n_samples = len(indices)

        if modify:
            minimum_y = np.min(y1)
            keep_indices = np.where(np.logical_or(truncation > minimum_y, y1 == minimum_y))[0]
            y1 = y1[keep_indices]
            truncation = truncation[keep_indices]
            delta = delta[keep_indices]
            cov1 = cov1[keep_indices]
            n_samples = len(keep_indices)

        d = {'y1': y1, 'truncation': truncation, 'delta': delta, 'cov1': cov1, "trial_i": i + cov1 * 0}
        df2 = pd.DataFrame(data=d)

        df_record[i] = df2.copy()
        dataframe_list.append(df_record[i])
        if no_est:
            print("Trial", i)
            continue

        control = cov1 == 0
        kmf.fit(durations=y1[control], event_observed=delta[control], entry=truncation[control])
        km_med[i] = np.mean(kmf.survival_function_at_times(x_time))
        coxph.fit(df=df2, duration_col="y1", event_col="delta", entry_col='truncation', formula="cov1")
        # params_ is a Series indexed by covariate name; positional [] access is deprecated.
        beta_hat = coxph.params_.iloc[0]

        # Breslow baseline (cov1 = 0) hazard increments on the distinct observed times.
        z_k = np.unique(y1)
        event_times = y1[delta == 1]
        death_sum = np.zeros(len(z_k))
        exp_term = np.zeros(len(z_k))
        for k, z in enumerate(z_k):
            index = np.logical_and(y1 >= z, truncation <= z)
            death_sum[k] = np.sum(event_times == z)  # counts tied deaths
            exp_term[k] = np.sum(np.exp(cov1[index] * beta_hat))
        # S(x) = exp(-sum of increments at times <= x), right-continuous like KM;
        # survival is 1 before the first observed time.
        bashaz_index = np.digitize(x_time, bins=z_k) - 1
        cox_curve = np.exp(-np.cumsum(death_sum / exp_term))
        cox_med[i] = 1.0 if bashaz_index < 0 else cox_curve[bashaz_index]

        tsai_weibull, tsai_zk = mono_haz(y1[control], cov1[control], delta[control], truncation[control],
                                         adjust=False, beta_hat=0)
        wu_weibull, wu_zk = mono_haz(y1, cov1, delta, truncation, adjust=True, beta_hat=beta_hat)
        grid_end = min(np.max(tsai_zk), np.max(wu_zk))
        min_t = grid_end if min_t is None else min(min_t, grid_end)
        min_x = np.min(y1) if min_x is None else min(min_x, np.min(y1))

        surv_res_est = baseline_survival_sim(x_time, y1, cov1, delta, truncation, hazard_est=wu_weibull)
        surv_median_est.append(surv_res_est)
        surv_res_tsai = baseline_survival_sim(x_time, y1[control], cov1[control], delta[control], truncation[control], hazard_est=tsai_weibull)
        surv_median_tsai.append(surv_res_tsai)
        dt = time.time() - tt1
        elapsed += dt
        eta = (n_trials - i - 1) * (elapsed / (i + 1))

        save_data["y1"].append(y1)
        save_data["tsai_zk"].append(tsai_zk)
        save_data["wu_zk"].append(wu_zk)
        save_data["cov1"].append(cov1)
        save_data["truncation"].append(truncation)
        save_data["delta"].append(delta)
        save_data["tsai_weibull"].append(tsai_weibull)
        save_data["wu_weibull"].append(wu_weibull)
        save_data["surv_res_est"].append(surv_res_est)
        save_data["surv_res_tsai"].append(surv_res_tsai)
        save_data["death_sum"].append(death_sum)
        save_data["exp_term"].append(exp_term)
        # baseline_survival_sim may return a length-1 array; convert before %-formatting
        # (formatting an ndim>0 array as a float is deprecated in NumPy).
        print("Trials:%04d/%04d  dt:%.3fs  elapsed:%.3fs  ETA:%.3fs  | OUTPUT=wu:%.4f tsai:%.4f  cox:%.4f km:%.4f| min_x:%.4f min_t:%.4f  valid_n:%4d"
              % (i, n_trials, dt, elapsed, eta, np.asarray(surv_res_est).item(), np.asarray(surv_res_tsai).item(),
                 cox_med[i], km_med[i], min_x, min_t, n_samples))
        sample_size.append(n_samples)

    if not no_est:
        print("#" * 40)
        print("Summary:   bias: %.4f| %.4f| %.4f| %.4f   variance: %.4f| %.4f| %.4f| %.4f avg_n:%4d" % (
            np.mean(surv_median_est) - 0.5, np.mean(surv_median_tsai) - 0.5,
            np.mean(cox_med) - 0.5, np.mean(km_med) - 0.5,
            np.std(surv_median_est, ddof=1), np.std(surv_median_tsai, ddof=1),
            np.std(cox_med, ddof=1), np.std(km_med, ddof=1), np.mean(sample_size)
        ))

    # save data
    from google.colab import drive
    drive.mount('/content/drive')
    if not no_est:
        np.savez("/content/drive/My Drive/data_monotone_hazard.npz", data=save_data)
    pd.concat(dataframe_list).to_csv('/content/drive/My Drive/df_concat.csv')
if __name__ == "__main__":
    # Simulation settings used in the study.  x_time is the true median of the
    # baseline group (cov1 == 0, Weibull scale 1000), i.e. med_weibull(a, 1000).
    scenarios = {
        "S1": dict(a=1, x_time=693.1472, n_trials=1000, n_samples_old=1000, method_i=4, trun_scale=250),
        "S2": dict(a=2, x_time=832.5546, n_trials=1000, n_samples_old=1000, method_i=4, trun_scale=450),
        "S3": dict(a=4, x_time=912.4443, n_trials=10, n_samples_old=1000, method_i=4, trun_scale=500),
        "S4": dict(a=6, x_time=940.7428, n_trials=1000, n_samples_old=1000, method_i=4, trun_scale=500),
        "S5": dict(a=4, x_time=912.4443, n_trials=1000, n_samples_old=2000, method_i=4,
                   modify=True, no_est=False, trun_scale=1200),
    }
    # Set to a scenario name (e.g. "S3") to run it; None runs nothing.  An `if`
    # whose body consists only of commented-out calls is a SyntaxError.
    selected_scenario = None
    if selected_scenario is not None:
        main(**scenarios[selected_scenario])

  print("Trials:%04d/%04d  dt:%.3fs  elapsed:%.3fs  ETA:%.3fs  | OUTPUT=wu:%.4f tsai:%.4f  cox:%.4f km:%.4f| min_x:%.4f min_t:%.4f  valid_n:%4d"%(i, n_trials, dt, elapsed, eta, surv_median_est[i], surv_median_tsai[i],cox_med[i],km_med[i],min_x, min_t, n_samples))


Trials:0000/0010  dt:3.016s  elapsed:3.016s  ETA:27.147s  | OUTPUT=wu:0.5105 tsai:0.5159  cox:0.5055 km:0.5112| min_x:377.5206 min_t:1381.5167  valid_n: 935


  print("Trials:%04d/%04d  dt:%.3fs  elapsed:%.3fs  ETA:%.3fs  | OUTPUT=wu:%.4f tsai:%.4f  cox:%.4f km:%.4f| min_x:%.4f min_t:%.4f  valid_n:%4d"%(i, n_trials, dt, elapsed, eta, surv_median_est[i], surv_median_tsai[i],cox_med[i],km_med[i],min_x, min_t, n_samples))


Trials:0001/0010  dt:1.994s  elapsed:5.011s  ETA:20.043s  | OUTPUT=wu:0.5205 tsai:0.5265  cox:0.5115 km:0.5165| min_x:319.4844 min_t:1334.0673  valid_n: 949


  print("Trials:%04d/%04d  dt:%.3fs  elapsed:%.3fs  ETA:%.3fs  | OUTPUT=wu:%.4f tsai:%.4f  cox:%.4f km:%.4f| min_x:%.4f min_t:%.4f  valid_n:%4d"%(i, n_trials, dt, elapsed, eta, surv_median_est[i], surv_median_tsai[i],cox_med[i],km_med[i],min_x, min_t, n_samples))


Trials:0002/0010  dt:1.705s  elapsed:6.716s  ETA:15.671s  | OUTPUT=wu:0.4965 tsai:0.4850  cox:0.4855 km:0.4814| min_x:312.1600 min_t:1334.0673  valid_n: 946


  print("Trials:%04d/%04d  dt:%.3fs  elapsed:%.3fs  ETA:%.3fs  | OUTPUT=wu:%.4f tsai:%.4f  cox:%.4f km:%.4f| min_x:%.4f min_t:%.4f  valid_n:%4d"%(i, n_trials, dt, elapsed, eta, surv_median_est[i], surv_median_tsai[i],cox_med[i],km_med[i],min_x, min_t, n_samples))


Trials:0003/0010  dt:2.373s  elapsed:9.089s  ETA:13.634s  | OUTPUT=wu:0.5019 tsai:0.4990  cox:0.4906 km:0.4858| min_x:312.1600 min_t:1334.0673  valid_n: 950


  print("Trials:%04d/%04d  dt:%.3fs  elapsed:%.3fs  ETA:%.3fs  | OUTPUT=wu:%.4f tsai:%.4f  cox:%.4f km:%.4f| min_x:%.4f min_t:%.4f  valid_n:%4d"%(i, n_trials, dt, elapsed, eta, surv_median_est[i], surv_median_tsai[i],cox_med[i],km_med[i],min_x, min_t, n_samples))


Trials:0004/0010  dt:1.811s  elapsed:10.900s  ETA:10.900s  | OUTPUT=wu:0.5225 tsai:0.5196  cox:0.5113 km:0.5104| min_x:282.2222 min_t:1274.4383  valid_n: 937


  print("Trials:%04d/%04d  dt:%.3fs  elapsed:%.3fs  ETA:%.3fs  | OUTPUT=wu:%.4f tsai:%.4f  cox:%.4f km:%.4f| min_x:%.4f min_t:%.4f  valid_n:%4d"%(i, n_trials, dt, elapsed, eta, surv_median_est[i], surv_median_tsai[i],cox_med[i],km_med[i],min_x, min_t, n_samples))


Trials:0005/0010  dt:1.634s  elapsed:12.534s  ETA:8.356s  | OUTPUT=wu:0.4992 tsai:0.4969  cox:0.4942 km:0.4911| min_x:282.2222 min_t:1274.4383  valid_n: 940


  print("Trials:%04d/%04d  dt:%.3fs  elapsed:%.3fs  ETA:%.3fs  | OUTPUT=wu:%.4f tsai:%.4f  cox:%.4f km:%.4f| min_x:%.4f min_t:%.4f  valid_n:%4d"%(i, n_trials, dt, elapsed, eta, surv_median_est[i], surv_median_tsai[i],cox_med[i],km_med[i],min_x, min_t, n_samples))


Trials:0006/0010  dt:1.042s  elapsed:13.577s  ETA:5.819s  | OUTPUT=wu:0.5259 tsai:0.5259  cox:0.4953 km:0.5113| min_x:240.6683 min_t:1274.4383  valid_n: 937


  print("Trials:%04d/%04d  dt:%.3fs  elapsed:%.3fs  ETA:%.3fs  | OUTPUT=wu:%.4f tsai:%.4f  cox:%.4f km:%.4f| min_x:%.4f min_t:%.4f  valid_n:%4d"%(i, n_trials, dt, elapsed, eta, surv_median_est[i], surv_median_tsai[i],cox_med[i],km_med[i],min_x, min_t, n_samples))


Trials:0007/0010  dt:1.307s  elapsed:14.883s  ETA:3.721s  | OUTPUT=wu:0.5283 tsai:0.5194  cox:0.5297 km:0.5225| min_x:240.6683 min_t:1245.2418  valid_n: 946


  print("Trials:%04d/%04d  dt:%.3fs  elapsed:%.3fs  ETA:%.3fs  | OUTPUT=wu:%.4f tsai:%.4f  cox:%.4f km:%.4f| min_x:%.4f min_t:%.4f  valid_n:%4d"%(i, n_trials, dt, elapsed, eta, surv_median_est[i], surv_median_tsai[i],cox_med[i],km_med[i],min_x, min_t, n_samples))


Trials:0008/0010  dt:1.434s  elapsed:16.318s  ETA:1.813s  | OUTPUT=wu:0.5015 tsai:0.5050  cox:0.5002 km:0.5004| min_x:240.6683 min_t:1245.2418  valid_n: 936


  print("Trials:%04d/%04d  dt:%.3fs  elapsed:%.3fs  ETA:%.3fs  | OUTPUT=wu:%.4f tsai:%.4f  cox:%.4f km:%.4f| min_x:%.4f min_t:%.4f  valid_n:%4d"%(i, n_trials, dt, elapsed, eta, surv_median_est[i], surv_median_tsai[i],cox_med[i],km_med[i],min_x, min_t, n_samples))


Trials:0009/0010  dt:1.443s  elapsed:17.760s  ETA:0.000s  | OUTPUT=wu:0.4931 tsai:0.5041  cox:0.4854 km:0.4929| min_x:240.6683 min_t:1245.2418  valid_n: 952
########################################
Summary:   bias: 0.0100| 0.0097| 0.0009| 0.0024   variance: 0.0132| 0.0138| 0.0139| 0.0140 avg_n: 942
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# compute probability
# Evaluate every estimator's baseline (cov1 == 0) survival on a time grid for each
# saved trial, so the curves can be averaged over trials and plotted.
import matplotlib.pyplot as plt
from google.colab import drive
drive.mount('/content/drive')
# load data (allow_pickle is required because main() saved a dict; only load trusted files)
loaded_data = np.load("/content/drive/My Drive/data_monotone_hazard.npz", allow_pickle=True)
print(loaded_data['data'].item().keys())
loaded_d = loaded_data['data'].item()
a = 1  # shape; NOTE(review): must match the `a` used by main() to produce this file
n_trials = len(loaded_d["truncation"])
times = np.linspace(0, 1432, 100)
# True baseline survival: Weibull(shape=a, scale=1000), the cov1 == 0 group in main().
true_prob = np.exp(-(times / 1000) ** a)
# <name>_item_list[i][trial] = that estimator's survival at times[i] for the trial.
wu_item_list = [[] for _ in times]
tsai_item_list = [[] for _ in times]
km_item_list = [[] for _ in times]
cox_item_list = [[] for _ in times]
elapsed = 0

for trial_i in range(n_trials):
    tt1 = time.time()
    y1 = loaded_d["y1"][trial_i]
    cov1 = loaded_d["cov1"][trial_i]
    delta = loaded_d["delta"][trial_i]
    truncation = loaded_d["truncation"][trial_i]
    wu_weibull = loaded_d["wu_weibull"][trial_i]
    tsai_weibull = loaded_d["tsai_weibull"][trial_i]
    death_sum = loaded_d["death_sum"][trial_i]
    exp_term = loaded_d["exp_term"][trial_i]
    control = cov1 == 0

    # The KM fit and the Breslow step curve do not depend on the evaluation time,
    # so compute them once per trial instead of once per (time, trial) pair.
    kmf.fit(durations=y1[control], event_observed=delta[control], entry=truncation[control])
    km_curve = np.asarray(kmf.survival_function_at_times(times), dtype=float)
    z_k = np.unique(y1)
    cox_curve = np.exp(-np.cumsum(death_sum / exp_term))

    for i, t in enumerate(times):
        wu_item = baseline_survival_sim(t, y1, cov1, delta, truncation, hazard_est=wu_weibull)
        tsai_item = baseline_survival_sim(t, y1[control], cov1[control], delta[control], truncation[control], hazard_est=tsai_weibull)
        wu_item_list[i].append(wu_item.item())
        tsai_item_list[i].append(tsai_item.item())
        km_item_list[i].append(km_curve[i])
        # Right-continuous step lookup (times <= t), consistent with KM; 1 before the first time.
        bashaz_index = np.digitize(t, bins=z_k) - 1
        cox_item_list[i].append(1.0 if bashaz_index < 0 else cox_curve[bashaz_index])

    dt = time.time() - tt1
    elapsed += dt
    eta = (n_trials - trial_i - 1) * (elapsed / (trial_i + 1))
    print("Trials:%04d/%04d  dt:%.3fs  elapsed:%.3fs  ETA:%.3fs" % (trial_i, n_trials, dt, elapsed, eta))



In [None]:
# visualization
# Average each estimator over trials (last axis) and plot it against the true
# baseline survival.  The estimates and `true_prob` were computed on `times` from
# the computation cell, so that grid is reused here; redefining it (previously
# np.linspace(0, 2052, 100)) plotted the curves against the wrong time values.
wu_avg = np.mean(np.array(wu_item_list), axis=-1)
tsai_avg = np.mean(np.array(tsai_item_list), axis=-1)
km_avg = np.mean(np.array(km_item_list), axis=-1)
cox_avg = np.mean(np.array(cox_item_list), axis=-1)
if len(times) != len(wu_avg) or len(times) != len(true_prob):
    raise ValueError("`times` does not match the grid the estimates were computed on")
linewidth = 1
plt.plot(times, wu_avg, label="wu", color="red", linestyle="--", linewidth=linewidth)
plt.plot(times, tsai_avg, label="tsai", color="orange", linestyle="-", linewidth=linewidth)
plt.plot(times, km_avg, label="km", color="green", linestyle="-.", linewidth=linewidth)
plt.plot(times, cox_avg, label="cox", color="blue", linestyle=":", linewidth=linewidth)
plt.plot(times, true_prob, label="true", color="black", linestyle='--', linewidth=linewidth)
#plt.title('shape=4; trun=1200; bias: 0.0332| 0.0340| -0.1458   MCSE: 0.0702| 0.0720| 0.1831; N=1000', fontsize=10)

plt.legend()
plt.xlabel("Time")
plt.ylabel("Survival")
plt.show()

result_df = pd.DataFrame(data={'wu_avg': wu_avg, 'tsai_avg': tsai_avg, 'cox_avg': cox_avg, 'true_prob': true_prob})
# save data
from google.colab import drive
drive.mount('/content/drive')
result_df.to_csv('/content/drive/My Drive/result_df.csv')