In [2]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn_pandas import DataFrameMapper
from scipy.integrate import trapz
import torch
import torchtuples as tt
from pycox.datasets import metabric
from pycox.models import LogisticHazard, PMF, DeepHitSingle, CoxPH, MTLR
from pycox.evaluation import EvalSurv
from survival_evaluation import d_calibration, l1, one_calibration
import random
import statistics
from pdb import set_trace
from sklearn.metrics import mean_squared_error, roc_auc_score
from sksurv.metrics import cumulative_dynamic_auc

In [3]:
path = './datacsv/battery_signature_features.csv'
D = pd.read_csv(path)
x_cols = D.columns[-15:].tolist()
event_col = ['event']
time_col = ['time']
D = D[x_cols + event_col + time_col]
d = D.copy()

cols_standardize = d.columns.values.tolist()
cols_standardize.pop(1)
cols_standardize.pop(4)
if len(cols_standardize) > 4:
    cols_standardize.pop(1)
    cols_standardize.pop(4)

n_exp = 5
CI = []
IBS = []
L1_hinge = []
L1_margin = []
AUC_scores = []

In [4]:
def train_val_test_stratified_split(df, stratify_colname='y', frac_train=0.6, frac_val=0.15, frac_test=0.25, random_state=None):
    if frac_train + frac_val + frac_test != 1.0:
        raise ValueError('fractions %f, %f, %f do not add up to 1.0' % (frac_train, frac_val, frac_test))
    if stratify_colname not in df.columns:
        raise ValueError('%s is not a column in the dataframe' % (stratify_colname))
    X = df
    y = df[[stratify_colname]]
    df_train, df_temp, y_train, y_temp = train_test_split(X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state)
    relative_frac_test = frac_test / (frac_val + frac_test)
    if relative_frac_test == 1.0:
        df_val, df_test, y_val, y_test = [], df_temp, [], y_temp
    else:
        df_val, df_test, y_val, y_test = train_test_split(df_temp, y_temp, stratify=y_temp, test_size=relative_frac_test, random_state=random_state)
    assert len(df) == len(df_train) + len(df_val) + len(df_test)
    return df_train, df_val, df_test

In [5]:
for i in range(n_exp):
    df_train, df_val, df_test = train_val_test_stratified_split(d, 'event', frac_train=0.80, frac_val=0.05, frac_test=0.15, random_state=10)

    standardize = [([col], StandardScaler()) for col in cols_standardize]
    x_mapper = DataFrameMapper(standardize)
    x_train = x_mapper.fit_transform(df_train).astype('float32')
    x_val = x_mapper.transform(df_val).astype('float32')
    x_test = x_mapper.transform(df_test).astype('float32')
    in_features = x_train.shape[1]

    num_durations = 10
    labtrans = DeepHitSingle.label_transform(num_durations)
    get_target = lambda df: (df['time'].values, df['event'].values)
    y_train = labtrans.fit_transform(*get_target(df_train))
    y_val = labtrans.transform(*get_target(df_val))
    val = tt.tuplefy(x_val, y_val)
    durations_test, events_test = get_target(df_test)

    out_features = labtrans.out_features
    num_nodes = [32, 32]
    batch_norm = True
    dropout = 0.1
    net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm, dropout)
    model = DeepHitSingle(net, tt.optim.Adam, duration_index=labtrans.cuts)

    batch_size = 256
    lr_finder = model.lr_finder(x_train, y_train, batch_size, tolerance=6)
    model.optimizer.set_lr(0.01)

    epochs = 30
    callbacks = [tt.cb.EarlyStopping()]
    model.fit(x_train, y_train, batch_size, epochs, callbacks, val_data=val, verbose=0)

    surv = model.interpolate(10).predict_surv_df(x_test)
    surv_df = pd.DataFrame(surv)
    surv_df.index.name = 'time'
    surv_df.columns.name = 'survival_function'
    surv_df.to_csv("./datacsv/test-data/discharge_deephit.csv", index=True)

    survival_predictions = pd.Series(trapz(surv.values.T, surv.index), index=df_test.index)
    l1_hinge_value = l1(df_test.time, df_test.event, survival_predictions, l1_type='hinge')
    l1_margin_value = l1(df_test.time, df_test.event, survival_predictions, df_train.time, df_train.event, l1_type='margin')

    ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
    c_index = ev.concordance_td('antolini')
    time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
    brier = ev.integrated_brier_score(time_grid)

    quantiles = np.sort(df_test['time'].unique())
    labels_train = np.array([(e, t) for e, t in zip(df_train['event'], df_train['time'])], dtype=[('event', 'bool'), ('time', 'float')])
    labels_test = np.array([(e, t) for e, t in zip(df_test['event'], df_test['time'])], dtype=[('event', 'bool'), ('time', 'float')])

    auc_scores = []
    for eval_time in quantiles:
        try:
            # 修复：使用 surv.index 替代 time_grid_train
            interp_time_index = np.argmin(np.abs(eval_time - surv.index.values))
            surv_values_at_eval_time = surv.iloc[interp_time_index].values
            estimated_risks = 1 - surv_values_at_eval_time

            if np.min(estimated_risks) == np.max(estimated_risks):
                continue

            auc = cumulative_dynamic_auc(labels_train, labels_test, estimated_risks, times=[eval_time])[0][0]

            if not np.isnan(auc) and not np.isinf(auc):
                auc_scores.append(auc)
        except Exception as e:
            print(f"AUC calculation failed: {e}, eval_time={eval_time}")

    AUC_scores.append(np.mean(auc_scores) if auc_scores else 0.5)
    CI.append(c_index)
    IBS.append(brier)
    L1_hinge.append(l1_hinge_value)
    L1_margin.append(l1_margin_value)

def safe_stat(data):
    return round(statistics.mean(data), 3), round(statistics.stdev(data), 3) if len(data) > 1 else (0.0, 0.0)

auc_mean, auc_std = safe_stat(AUC_scores)

print('CI:', round(statistics.mean(CI), 3), round(statistics.stdev(CI), 3))
print('IBS:', round(statistics.mean(IBS), 3), round(statistics.stdev(IBS), 3))
print('L1_hinge:', round(statistics.mean(L1_hinge), 3), round(statistics.stdev(L1_hinge), 3))
print('L1_margin:', round(statistics.mean(L1_margin), 3), round(statistics.stdev(L1_margin), 3))
print(f'AUC: {auc_mean} ± {auc_std}')

print('d_calibration_p_value:', round(d_calibration(df_test.event, surv.iloc[6])['p_value'], 3))
print('D-Calibration:', round(sum(d_calibration(df_test.event, surv.iloc[6])['bin_proportions']), 3))
print('d_calibration_bin_proportions:')
for i in d_calibration(df_test.event, surv.iloc[6])['bin_proportions']:
    print(i)
print('D-Calibration_censored:', round(sum(d_calibration(df_test.event, surv.iloc[6])['censored_contributions']), 3))
print('d_calibration_censored_contributions:')
for i in d_calibration(df_test.event, surv.iloc[6])['censored_contributions']:
    print(i)
print('D-Calibration_uncensored:', round(sum(d_calibration(df_test.event, surv.iloc[6])['uncensored_contributions']), 3))
print('d_calibration_uncensored_contributions:')
for i in d_calibration(df_test.event, surv.iloc[6])['uncensored_contributions']:
    print(i)


  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]


AUC calculation failed: all times must be within follow-up time of test data: [2211.953; 3489.094[, eval_time=3489.094


  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]


AUC calculation failed: all times must be within follow-up time of test data: [2211.953; 3489.094[, eval_time=3489.094


  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]


AUC calculation failed: all times must be within follow-up time of test data: [2211.953; 3489.094[, eval_time=3489.094


  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]


AUC calculation failed: all times must be within follow-up time of test data: [2211.953; 3489.094[, eval_time=3489.094


  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]


AUC calculation failed: all times must be within follow-up time of test data: [2211.953; 3489.094[, eval_time=3489.094
CI: 0.93 0.02
IBS: 0.045 0.002
L1_hinge: 60.889 9.043
L1_margin: 75.932 11.274
AUC: 0.975 ± 0.005
d_calibration_p_value: 0.0
D-Calibration: 1.0
d_calibration_bin_proportions:
0.01980869099497795
0.01980869099497795
0.01980869099497795
0.01980869099497795
0.01980869099497795
0.01980869099497795
0.01980869099497795
0.01980869099497795
0.01980869099497795
0.8217217609286308
D-Calibration_censored: 0.198
d_calibration_censored_contributions:
0.01980869099497795
0.01980869099497795
0.01980869099497795
0.01980869099497795
0.01980869099497795
0.01980869099497795
0.01980869099497795
0.01980869099497795
0.01980869099497795
0.019638427595297492
D-Calibration_uncensored: 0.802
d_calibration_uncensored_contributions:
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.8020833333333334
