In [None]:
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error
import matplotlib.pyplot as plt
import pandas as pd
from sksurv.svm import FastKernelSurvivalSVM
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sksurv.metrics import concordance_index_censored, integrated_brier_score, brier_score
from lifelines.utils.sklearn_adapter import sklearn_adapter
from lifelines.utils import concordance_index
from scipy.stats import gumbel_r, norm, logistic
import lifelines.datasets as dset
import math
from sklearn.calibration import calibration_curve, CalibrationDisplay
from matplotlib.gridspec import GridSpec
import warnings
from scipy.special import erf
from sklearn.preprocessing import StandardScaler
warnings.filterwarnings("ignore")
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torchtuples as tt

from pycox.preprocessing.label_transforms import LabTransDiscreteTime
from pycox.models import DeepHit
from pycox.evaluation import EvalSurv
import pandas as pd
import numpy as np
import os,sys
sys.path.append("../")
from dirac_phi import DiracPhi
from survival import MixExpPhiStochastic,InnerGenerator,InnerGenerator2,sample,HACSurv_3D_Sym_shared
import torch.optim as optim
import torch
from tqdm import tqdm


# Fix the random seeds. The model initialisation and every stochastic step
# further down (resample_M, sample) run through torch, so seeding NumPy alone
# does not make a run repeatable.
np.random.seed(42)
torch.manual_seed(42)

df = pd.read_csv('./mimic_time_data_withoutETT.csv')

# Data split: 20% of the rows go to the test set, 20% of the remainder to the
# validation set, and what is left is shuffled and re-indexed for training.
df_test = df.sample(frac=0.2, random_state=42)
df_train = df.drop(df_test.index)
df_val = df_train.sample(frac=0.2, random_state=42)
df_train = df_train.drop(df_val.index)
df_train = df_train.sample(frac=1, random_state=42).reset_index(drop=True)

# Features: every column except the time, event and label columns.
x_train = df_train.drop(columns=['time', 'death_reason', 'label']).values.astype('float32')
x_val = df_val.drop(columns=['time', 'death_reason', 'label']).values.astype('float32')
x_test = df_test.drop(columns=['time', 'death_reason', 'label']).values.astype('float32')


# Times and event labels.
def get_target(df):
    """Return the ``time`` and ``death_reason`` columns of *df* as arrays."""
    return df['time'].values, df['death_reason'].values


time_train, event_train = get_target(df_train)
time_val, event_val = get_target(df_val)
time_test, event_test = get_target(df_test)

# Pick the compute device.
if torch.cuda.is_available():
    # Prefer the second GPU, but fall back to the first one on single-GPU
    # hosts, where 'cuda:1' would be an invalid device ordinal.
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')

# Convert everything to float64 tensors on the selected device.
covariate_tensor_train = torch.tensor(x_train, dtype=torch.float64).to(device)
covariate_tensor_val = torch.tensor(x_val, dtype=torch.float64).to(device)
covariate_tensor_test = torch.tensor(x_test, dtype=torch.float64).to(device)

times_tensor_train = torch.tensor(time_train, dtype=torch.float64).to(device)
event_indicator_tensor_train = torch.tensor(event_train, dtype=torch.float64).to(device)

times_tensor_val = torch.tensor(time_val, dtype=torch.float64).to(device)
event_indicator_tensor_val = torch.tensor(event_val, dtype=torch.float64).to(device)

times_tensor_test = torch.tensor(time_test, dtype=torch.float64).to(device)
event_indicator_tensor_test = torch.tensor(event_test, dtype=torch.float64).to(device)

torch.set_num_threads(16)
# Tensors created without an explicit dtype from here on are float64, matching
# the data tensors above (replaces the deprecated set_default_tensor_type).
torch.set_default_dtype(torch.float64)

# Training configuration.
num_epochs = 100000
batch_size = 300  # not used by the loop below, which feeds the full training tensors
early_stop_epochs = 1500
val_interval = 100  # number of epochs between two validation passes

selected_indicators = [1,3,5]
indicators_str = ''.join(map(str, selected_indicators))

model_dir = './MIMIC-III/checkpoint'
figures_dir = './MIMIC-III/figure'
# Create the output directories up front so that the first checkpoint/figure
# write does not fail on a fresh checkout.
os.makedirs(model_dir, exist_ok=True)
os.makedirs(figures_dir, exist_ok=True)
checkpoint_path = os.path.join(model_dir, f'MIMIC_e{indicators_str}_sharedHACSurv.pth')
plot_path = os.path.join(figures_dir, f'MIMIC_e{indicators_str}_sharedHACSurv.png')

phi = MixExpPhiStochastic(device)
model = HACSurv_3D_Sym_shared(phi, device=device, num_features=x_train.shape[1], tol=1e-14, hidden_size=100).to(device)
optimizer = optim.AdamW([
    {"params": model.sumo_e1.parameters(), "lr": 1e-4},
    {"params": model.sumo_e2.parameters(), "lr": 1e-4},
    {"params": model.sumo_c.parameters(), "lr": 1e-4},
    {"params": model.phi.parameters(), "lr": 2e-4},
])

best_val_loglikelihood = float('-inf')
epochs_no_improve = 0
for epoch in range(num_epochs):
    # One full-batch gradient step on the training log-likelihood.
    optimizer.zero_grad()
    model.phi.resample_M(200)
    logloss = model(covariate_tensor_train, times_tensor_train, event_indicator_tensor_train,
                    max_iter=10000, selected_indicators=selected_indicators)
    (-logloss).backward()  # maximise the log-likelihood
    optimizer.step()

    # Validate every `val_interval` epochs, skipping epoch 0.
    if epoch == 0 or epoch % val_interval != 0:
        continue

    print("Epoch", epoch, "Train loglikelihood:", logloss.item())
    model.phi.resample_M(200)
    # Keep only the Python float: holding on to the tensor would keep the
    # autograd graph of the validation pass alive and would put a device
    # tensor into the checkpoint.
    val_loglikelihood = model(covariate_tensor_val, times_tensor_val, event_indicator_tensor_val,
                              max_iter=10000, selected_indicators=selected_indicators).item()
    print("Validation likelihood:", val_loglikelihood)

    if val_loglikelihood > best_val_loglikelihood:
        # New best model: checkpoint it and save a scatter plot of samples.
        best_val_loglikelihood = val_loglikelihood
        epochs_no_improve = 0
        torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(), 'loss': best_val_loglikelihood}, checkpoint_path)

        print('Scatter sampling')
        samples = sample(model, 2, 2000, device=device)
        plt.scatter(samples[:, 0].cpu(), samples[:, 1].cpu(), s=15)
        plt.savefig(plot_path)
        plt.clf()
    else:
        epochs_no_improve += val_interval

    # Early stopping once validation has not improved for long enough.
    if epochs_no_improve >= early_stop_epochs:
        print('Early stopping triggered at epoch:', epoch)
        break

Epoch 100 Train loglikelihood: -40505.15495092473
Validation likelihood: -9898.372913866388
Scatter sampling
Sampling from dim: 1
Epoch 200 Train loglikelihood: -33856.54557412451
Validation likelihood: -8503.786815441183
Scatter sampling
Sampling from dim: 1
Epoch 300 Train loglikelihood: -29735.021735904418
Validation likelihood: -7484.613893502482
Scatter sampling
Sampling from dim: 1
Epoch 400 Train loglikelihood: -23328.427001659595
Validation likelihood: -5833.037466623371
Scatter sampling
Sampling from dim: 1
Epoch 500 Train loglikelihood: -16115.06751842512
Validation likelihood: -4056.2389932412666
Scatter sampling
Sampling from dim: 1
Epoch 600 Train loglikelihood: -13236.004046969005
Validation likelihood: -3345.1118275094623
Scatter sampling
Sampling from dim: 1
Epoch 700 Train loglikelihood: -11726.61714453782
Validation likelihood: -2949.408806550214
Scatter sampling
Sampling from dim: 1
Epoch 800 Train loglikelihood: -11379.137098943957
Validation likelihood: -2862.58069

KeyboardInterrupt: 

<Figure size 640x480 with 0 Axes>

In [None]:
## Train the inner generator
## Two inner copulas should be learned: one for "135" and one for "24"
from torch.utils.data import TensorDataset, DataLoader
from survival import sample,Copula,HACSurv_2D_shared,HACSurv_3D_Sym_shared
def sampleInner(phi, ndims, n, M0=None):
    """Draw ``n`` samples with ``ndims`` columns from the generator ``phi``.

    Args:
        phi: generator object exposing ``mu``, ``beta``, ``psi.sample_M``,
            ``sample_M`` and ``forward``.
        ndims: number of columns of the returned sample.
        n: number of rows to draw.
        M0: optional pre-drawn outer variable; when ``None`` it is drawn with
            ``phi.psi.sample_M(n)``.

    Returns:
        The result of ``phi.forward(E / M)``, where ``E`` is an ``(n, ndims)``
        draw of unit-rate exponentials and ``M`` is the per-row mixing value
        repeated across the ``ndims`` columns.
    """
    # Run everything on the device that holds phi.mu.
    dev = phi.mu.device

    if M0 is None:
        M0 = phi.psi.sample_M(n).to(dev)

    # Poisson number of jumps per row, with rate exp(beta) * M0.
    jump_counts = torch.poisson(torch.exp(phi.beta) * M0).int()

    # For each row, draw that many values from phi.sample_M and add them up.
    jump_sums = [phi.sample_M(jump_counts[idx].item()).sum() for idx in range(n)]
    c = torch.tensor(jump_sums, device=dev)

    # Per-row mixing value, repeated across the ndims columns.
    mixing = torch.exp(phi.mu) * M0 + c
    M = mixing[:, None].expand(-1, ndims)

    # Unit-rate exponential draws, one per output entry.
    unit_rates = torch.ones((n, ndims), device=dev)
    E = torch.distributions.exponential.Exponential(unit_rates).sample()

    return phi.forward(E / M)
# Use float64 as the default dtype (replaces the deprecated
# set_default_tensor_type call).
torch.set_default_dtype(torch.float64)
phi = MixExpPhiStochastic(device)
# To learn the "24" inner copula instead, build the 2-D model:
# model = HACSurv_2D_shared(phi, device=device, num_features=x_train.shape[1], tol=1e-10, hidden_size=100).to(device)
# Model for the "135" inner copula.
model = HACSurv_3D_Sym_shared(phi, device=device, num_features=x_train.shape[1], tol=1e-14, hidden_size=100).to(device)

print(device)
# Copula parameters of the inner structure to be learned. map_location lets a
# checkpoint written from one device be restored on whatever `device` is now.
checkpoint = torch.load('./MIMIC-III/checkpoint/MIMIC_e135_sharedHACSurv.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

model.phi.resample_M(200)
samples = sample(model, 2, 5000, device=device)
num_train_samples = 4000
num_test_samples = 1000  # could also be computed as samples.shape[0] - num_train_samples

os.makedirs('./MIMIC-III/figure', exist_ok=True)
plt.scatter(samples[:, 0].cpu(), samples[:, 1].cpu(), s=15)
plt.savefig('./MIMIC-III/figure/HACMIMIC_135.png')
# Clear the figure: otherwise later scatter plots are drawn on top of these
# points and get saved together with them.
plt.clf()

# Split the samples into a training part and a held-out part.
train_data = samples[:num_train_samples]
test_data = samples[num_train_samples:num_train_samples + num_test_samples]

# Outer copula: restore the generator weights from an earlier checkpoint.
psi = MixExpPhiStochastic(device)
ckpt_path_out = './MIMIC-III/checkpoint/MIMIC_e05.pth'
ckpt_out = torch.load(ckpt_path_out, map_location=device)

# Keep only the 'phi' entries (excluding 'phi_inv') and strip the 'phi.' prefix
# so that the keys match the stand-alone generator.
phi_out_keys = {
    k.replace('phi.', ''): v
    for k, v in ckpt_out['model_state_dict'].items()
    if 'phi' in k and 'phi_inv' not in k
}
psi.load_state_dict(phi_out_keys)

psi_cop = Copula(psi,device)

# Inner generator built on top of the outer one, wrapped as a copula.
phi = InnerGenerator(psi,device)
net = Copula(phi,device)
optim_args = {
    # Kept small because the log-likelihood below is a torch.sum over the
    # batch rather than a torch.mean.
    'lr': 1e-5,
    'momentum': 0.9,
}

optimizer = optim.SGD(net.parameters(), **optim_args)
num_epochs = 10000
batch_size = 4000
chkpt_freq = 500
patience = 1000
val_interval = 100  # number of epochs between two validation passes
best_val_loss = np.inf
epochs_no_improve = 0
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1000000, shuffle=True)

# Create the output directories up front so that saving cannot fail on a
# fresh checkout.
os.makedirs('./MIMIC-III/checkpoint', exist_ok=True)
os.makedirs('./MIMIC-III/figure', exist_ok=True)
checkpoint_path = os.path.join('./MIMIC-III/checkpoint', 'Model_inner_e1e3e5_step2_05outer.pth')
plot_path = os.path.join('./MIMIC-III/figure', 'Model_inner_e1e3e5_step2_05outer.png')

train_loss_per_epoch = []
for epoch in range(num_epochs):
    for data in train_loader:
        optimizer.zero_grad()

        # Draw fresh mixing variables for the outer (if supported) and the
        # inner generator before each step.
        if hasattr(net.phi.psi, 'resample_M'):
            net.phi.psi.resample_M(100)
        net.phi.resample_M(100)

        d = data.detach().clone().to(device)
        p = net(d, mode='pdf')
        # Penalty pulling the mean of net.phi.M towards 1, added to the
        # summed negative log-density.
        scaleloss = torch.square(torch.mean(net.phi.M)-1)
        logloss = -torch.sum(torch.log(p))
        reg_loss = logloss+scaleloss
        reg_loss.backward()
        optimizer.step()

    # Evaluate on the held-out samples every `val_interval` epochs.
    if epoch % val_interval != 0:
        continue

    val_logloss = 0
    for data in test_loader:
        if hasattr(net.phi.psi, 'resample_M'):
            net.phi.psi.resample_M(1000)
        net.phi.resample_M(1000)
        d = data.detach().clone().to(device)
        p = net(d, mode='pdf')
        val_logloss += -torch.mean(torch.log(p)).item()
    val_logloss /= len(test_loader)

    if val_logloss < best_val_loss:
        # New best model: save its state together with the optimizer state.
        best_val_loss = val_logloss
        epochs_no_improve = 0
        torch.save({
            'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': val_logloss,
        }, checkpoint_path)

        # Draw samples from the current inner generator and save a scatter plot.
        samples = sampleInner(net.phi, 2, 5000).detach().cpu()
        plt.scatter(samples[:, 0], samples[:, 1],s=15)
        plt.savefig(plot_path)
        plt.clf()

        # The reported training value is the loss of the last mini-batch.
        print('Epoch {}: Train {}, Val {}'.format(epoch, reg_loss.item(), val_logloss))
    else:
        epochs_no_improve += val_interval

    # Early stopping check.
    if epochs_no_improve >= patience:
        print('Early stopping triggered at epoch:', epoch)
        break

cuda:1
Sampling from dim: 1
Epoch 0: Train -423.20985593340436, Val -0.08711694476418459
Epoch 100: Train -520.6144569841939, Val -0.12144183278371501
Epoch 200: Train -1580.2365629492901, Val -0.42484742569839906
Epoch 300: Train -2221.446200764324, Val -0.5684877581758566
Epoch 400: Train -2542.0710500220607, Val -0.6133586408909045
Epoch 500: Train -2558.506074953206, Val -0.6294615030640216
Epoch 600: Train -2575.3783541036228, Val -0.6534670908730695


KeyboardInterrupt: 

<Figure size 640x480 with 0 Axes>