In [5]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn_pandas import DataFrameMapper
from sklearn.metrics import mean_squared_error, roc_auc_score
from scipy.integrate import trapz
import torch
import torchtuples as tt
from pycox.models import CoxCC
from pycox.evaluation import EvalSurv
from survival_evaluation import d_calibration, l1
import statistics
from pycox.models.cox_time import MLPVanillaCoxTime
from pycox.models import LogisticHazard, PMF, DeepHitSingle, CoxPH, MTLR, CoxTime
from sksurv.metrics import cumulative_dynamic_auc

In [6]:
AUC_scores = []

# Read and process the data
path = './datacsv/battery_signature_features.csv'
D = pd.read_csv(path)

# Configure feature columns as the last 15 columns
x_cols = D.columns[-15:].tolist()  # The last 15 columns are features
battery_battery_name = ['sheet_cycle']
event_col = ['event']  # Survival event column
time_col = ['time']    # Survival time column

# Data cleaning and column selection
D = D[x_cols + event_col + time_col]

# Copy DataFrame to avoid modifying the original data
d = D.copy()

# Set columns to be standardized
cols_standardize = x_cols.copy()
n_exp = 5  # Number of experiments

# Model evaluation metrics
CI = []
IBS = []
L1_hinge = []
RMSE = []
AUC = []

In [7]:
# 数据分割函数
def train_val_test_stratified_split(df, stratify_colname='y', frac_train=0.6, frac_val=0.15, frac_test=0.25, random_state=None):
    if frac_train + frac_val + frac_test != 1.0:
        raise ValueError('fractions %f, %f, %f do not add up to 1.0' % (frac_train, frac_val, frac_test))
    if stratify_colname not in df.columns:
        raise ValueError('%s is not a column in the dataframe' % (stratify_colname))
    X = df
    y = df[[stratify_colname]]
    df_train, df_temp, y_train, y_temp = train_test_split(X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state)
    relative_frac_test = frac_test / (frac_val + frac_test)
    if relative_frac_test == 1.0:
        df_val, df_test, y_val, y_test = [], df_temp, [], y_temp
    else:
        df_val, df_test, y_val, y_test = train_test_split(df_temp, y_temp, stratify=y_temp, test_size=relative_frac_test, random_state=random_state)
    assert len(df) == len(df_train) + len(df_val) + len(df_test)
    return df_train, df_val, df_test

In [8]:
for i in range(n_exp):
    df_train, df_val, df_test = train_val_test_stratified_split(d, 'event', frac_train=0.8, frac_val=0.05, frac_test=0.15, random_state=10)
    standardize = [([col], StandardScaler()) for col in cols_standardize]
    leave = []
    x_mapper = DataFrameMapper(standardize + leave)
    x_train = x_mapper.fit_transform(df_train).astype('float32')
    x_val = x_mapper.transform(df_val).astype('float32')
    x_test = x_mapper.transform(df_test).astype('float32')
    in_features = x_train.shape[1]

    num_durations = 10
    get_target = lambda df: (df['time'].values, df['event'].values)
    y_train = get_target(df_train)
    y_val = get_target(df_val)
    val = (x_val, y_val)
    durations_test, events_test = get_target(df_test)

    out_features = 1
    num_nodes = [32, 32]
    batch_norm = True
    dropout = 0.1
    output_bias = False

    net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm,
                              dropout, output_bias=output_bias)

    model = CoxCC(net, tt.optim.Adam)

    batch_size = 256
    lr_finder = model.lr_finder(x_train, y_train, batch_size, tolerance=10)
    model.optimizer.set_lr(0.01)

    epochs = 512
    callbacks = [tt.cb.EarlyStopping()]
    model.fit(x_train, y_train, batch_size, epochs, callbacks, val_data=val, verbose=0, val_batch_size=batch_size)

    _ = model.compute_baseline_hazards()
    surv = model.predict_surv_df(x_test)

    surv_df = pd.DataFrame(surv)
    surv_df.index.name = 'time'
    surv_df.columns.name = 'survival_function'
    surv_df.to_csv("./datacsv/test-data/discharge_Cox.csv", index=True)

    survival_predictions = pd.Series(trapz(surv.values.T, surv.index), index=df_test.index)
    l1_hinge = l1(df_test.time, df_test.event, survival_predictions, l1_type = 'hinge')

    ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
    c_index = ev.concordance_td('antolini')
    time_grid = np.linspace(durations_test.min(), durations_test.max(), 10)
    brier = ev.integrated_brier_score(time_grid)

    quantiles = np.sort(df_test['time'].unique())

    labels_train = np.array([(e, t) for e, t in zip(df_train['event'], df_train['time'])], dtype=[('event', 'bool'), ('time', 'float')])
    labels_test = np.array([(e, t) for e, t in zip(df_test['event'], df_test['time'])], dtype=[('event', 'bool'), ('time', 'float')])
    time_grid_train = np.unique(df_train['time'])

    auc_scores = []
    for eval_time in quantiles:
        try:
            interp_time_index = np.argmin(np.abs(eval_time - surv.index))
            surv_values_at_eval_time = surv.iloc[interp_time_index].values
            estimated_risks = 1 - surv_values_at_eval_time
            if np.min(estimated_risks) == np.max(estimated_risks):
                continue
            auc = cumulative_dynamic_auc(labels_train, labels_test, estimated_risks, times=[eval_time])[0][0]
            if not np.isnan(auc) and not np.isinf(auc):
                auc_scores.append(auc)
        except Exception as e:
            print(f"AUC calculation failed: {e}, eval_time={eval_time}")

    AUC_scores.append(np.mean(auc_scores) if auc_scores else 0.5)
    CI.append(c_index)
    IBS.append(brier)
    L1_hinge.append(l1_hinge)

def safe_stat(data):
    return round(statistics.mean(data), 3), round(statistics.stdev(data), 3) if len(data) > 1 else (0.0, 0.0)
auc_mean, auc_std = safe_stat(AUC_scores)

print('CI:', round(statistics.mean(CI), 3), '±', round(statistics.stdev(CI), 3))
print('IBS:', round(statistics.mean(IBS), 3), '±', round(statistics.stdev(IBS), 3))
print('L1_hinge:', round(statistics.mean(L1_hinge), 3), '±', round(statistics.stdev(L1_hinge), 3))
print(f'AUC: {auc_mean} ± {auc_std}')

d_calib = d_calibration(df_test['event'], surv.iloc[6])
print('d_calibration_p_value:', round(d_calib['p_value'], 3))
print('D-Calibration (bin proportions):', round(sum(d_calib['bin_proportions']), 3))
print('D-Calibration (censored contributions):', round(sum(d_calib['censored_contributions']), 3))
print('D-Calibration (uncensored contributions):', round(sum(d_calib['uncensored_contributions']), 3))

AUC calculation failed: all times must be within follow-up time of test data: [2211.953; 3489.094[, eval_time=3489.094
AUC calculation failed: all times must be within follow-up time of test data: [2211.953; 3489.094[, eval_time=3489.094
AUC calculation failed: all times must be within follow-up time of test data: [2211.953; 3489.094[, eval_time=3489.094
AUC calculation failed: all times must be within follow-up time of test data: [2211.953; 3489.094[, eval_time=3489.094
AUC calculation failed: all times must be within follow-up time of test data: [2211.953; 3489.094[, eval_time=3489.094
CI: 0.965 ± 0.004
IBS: 0.01 ± 0.002
L1_hinge: 2111.432 ± 1.101
AUC: 0.999 ± 0.001
d_calibration_p_value: 0.0
D-Calibration (bin proportions): 1.0
D-Calibration (censored contributions): 0.198
D-Calibration (uncensored contributions): 0.802
