In [None]:
import pandas as pd
import numpy as np
import os
import math
import random
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import time
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import KFold, StratifiedKFold
import torch
import torch.nn as nn
from torch.amp import GradScaler, autocast
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import shap
from lifelines.utils import concordance_index
from sksurv.metrics import cumulative_dynamic_auc
import matplotlib.cm as cm
import matplotlib.colors as mcolors
from functools import partial
from scipy.stats import chi2_contingency, ttest_ind

# Show every column and every row when a DataFrame is printed (no truncation).
pd.set_option ('display.max_columns', None)
pd.set_option ('display.max_rows', None)
import warnings
# Silence all warnings for the rest of the session.
warnings.filterwarnings('ignore')

# Report the installed PyTorch version.
print(f"PyTorch Version: {torch.__version__}")

In [None]:
def set_seed(seed_value=42):
    """Seed the Python, NumPy and PyTorch random generators with one value.

    When CUDA is available, the CUDA generators are seeded as well and cuDNN
    is switched to deterministic mode with benchmarking disabled.

    Args:
        seed_value: Seed applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed_value)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

    print(f"随机数种子已设置为: {seed_value}")

def preprocess_dataframe(df):
    """Return a re-indexed copy of ``df`` with the value columns' gaps filled.

    Missing entries in ``VALUE_NUMERIC`` become ``0.0`` and missing entries in
    ``VALUE_CATEGORICAL`` become ``'Missing'``. Either column is skipped when
    absent. The input frame is not modified.

    Args:
        df: Source DataFrame.

    Returns:
        A copy of ``df`` with a fresh 0..n-1 index.
    """
    cleaned = df.copy().reset_index(drop=True)
    fill_defaults = (
        ('VALUE_NUMERIC', 0.0),
        ('VALUE_CATEGORICAL', 'Missing'),
    )
    for column, default in fill_defaults:
        if column in cleaned.columns:
            cleaned[column] = cleaned[column].fillna(default)
    return cleaned
    
def encode_categorical_features_leakproof(train_df, val_df, categorical_cols):
    """Integer-encode categorical columns using a vocabulary built from train only.

    For each column, the distinct training values (as strings, in order of
    first appearance) get the ids 1..n. Id 0 is reserved for ``'<PAD>'`` and
    id n+1 for ``'<UNK>'``. Validation values that never occurred in training
    are mapped to ``'<UNK>'``. The codes are written to ``<col>_encoded``.

    Args:
        train_df: Training DataFrame; it is copied, not modified.
        val_df: Validation DataFrame; it is copied, not modified.
        categorical_cols: Names of the columns to encode.

    Returns:
        ``(train_encoded, val_encoded, vocab_mappings)`` where
        ``vocab_mappings[col]`` is ``{'vocab': dict, 'vocab_size': int}``.
    """
    vocab_mappings = {}
    train_df_encoded = train_df.copy()
    val_df_encoded = val_df.copy()
    for col in categorical_cols:
        encoded_col = col + '_encoded'
        train_df_encoded[col] = train_df_encoded[col].astype(str)

        vocab = {val: i + 1 for i, val in enumerate(train_df_encoded[col].unique())}
        vocab['<PAD>'] = 0
        vocab['<UNK>'] = len(vocab)

        train_df_encoded[encoded_col] = train_df_encoded[col].map(vocab)

        # Assign the filled Series back instead of calling fillna(inplace=True)
        # on a column selection: under pandas copy-on-write the in-place call
        # acts on a temporary and unseen categories would remain NaN.
        val_codes = val_df_encoded[col].astype(str).map(vocab)
        val_df_encoded[encoded_col] = val_codes.fillna(vocab['<UNK>']).astype('int64')

        vocab_mappings[col] = {'vocab': vocab, 'vocab_size': len(vocab)}
    return train_df_encoded, val_df_encoded, vocab_mappings

def normalize_numerical_features_leakproof(train_df, val_df, numerical_cols):
    """Standardize numerical columns with statistics fitted on train only.

    Args:
        train_df: Training DataFrame; it is copied, not modified.
        val_df: Validation DataFrame; it is copied, not modified.
        numerical_cols: Names of the columns to standardize. When empty, both
            frames are returned unchanged and the scaler is left unfitted.

    Returns:
        ``(train_normalized, val_normalized, scaler)``.
    """
    scaler = StandardScaler()
    train_df_normalized = train_df.copy()
    val_df_normalized = val_df.copy()
    if numerical_cols:
        # Replace the columns rather than writing through ``.loc[:, cols]``:
        # the scaled values are floats, and an in-place write into integer
        # columns is a deprecated silent upcast in recent pandas releases.
        train_df_normalized[numerical_cols] = scaler.fit_transform(train_df[numerical_cols])
        val_df_normalized[numerical_cols] = scaler.transform(val_df[numerical_cols])
    return train_df_normalized, val_df_normalized, scaler

class PatientSequenceDataset(Dataset):
    """Dataset with one item per ``SAMPLE_ID`` group of a long-format frame.

    Each item is ``(x_numerical, x_categorical, (time, dead))``:
    ``x_numerical`` is a float32 tensor of the group's numerical columns,
    ``x_categorical`` maps each categorical name (the encoded column name with
    ``'_encoded'`` removed) to a long tensor, and the labels are taken from the
    first row of the group.
    """

    def __init__(self, df, numerical_cols, categorical_cols_encoded):
        """Build and cache the tensors of every sample.

        Args:
            df: Frame with ``SAMPLE_ID``, ``time``, ``dead`` and feature columns.
            numerical_cols: Names of the numerical feature columns.
            categorical_cols_encoded: Names of the integer-encoded columns.
        """
        self.sample_groups = {}
        self.sample_ids = []
        # groupby yields the groups in sorted key order, which fixes item order.
        for sample_id, rows in df.groupby('SAMPLE_ID'):
            self.sample_ids.append(sample_id)

            numeric_part = torch.tensor(rows[numerical_cols].values, dtype=torch.float32)

            categorical_part = {}
            for encoded_name in categorical_cols_encoded:
                plain_name = encoded_name.replace('_encoded', '')
                categorical_part[plain_name] = torch.tensor(rows[encoded_name].values, dtype=torch.long)

            time_label = torch.tensor(rows['time'].iloc[0], dtype=torch.float32)
            dead_label = torch.tensor(rows['dead'].iloc[0], dtype=torch.float32)

            self.sample_groups[sample_id] = (numeric_part, categorical_part, (time_label, dead_label))

    def __len__(self):
        """Return the number of distinct samples."""
        return len(self.sample_ids)

    def __getitem__(self, idx):
        """Return the cached item of the ``idx``-th sample id."""
        sample_id = self.sample_ids[idx]
        return self.sample_groups[sample_id]

def collate_fn_pad(batch):
    """Collate variable-length samples into zero-padded batch tensors.

    Args:
        batch: Sequence of ``(x_numerical, x_categorical, (time, dead))`` items.

    Returns:
        ``(padded_numericals, padded_categoricals, (times, deads))``. Sequences
        are padded to the longest one in the batch, numerical values with
        ``0.0`` and categorical ids with ``0``. The categorical dict is empty
        when the first sample carries no categorical features.
    """
    numeric_seqs, categorical_dicts, label_pairs = zip(*batch)
    numeric_batch = pad_sequence(numeric_seqs, batch_first=True, padding_value=0.0)

    categorical_batch = {}
    if categorical_dicts and categorical_dicts[0]:
        # The first sample's keys define which features are batched.
        for name in categorical_dicts[0].keys():
            per_sample = [entry[name] for entry in categorical_dicts]
            categorical_batch[name] = pad_sequence(per_sample, batch_first=True, padding_value=0)

    times, deads = zip(*label_pairs)
    return numeric_batch, categorical_batch, (torch.stack(times), torch.stack(deads))

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch-first sequence.

    The table is precomputed for ``max_len`` positions and stored as the
    non-trainable buffer ``pe`` of shape ``(max_len, 1, d_model)``. Even
    feature indices hold sines and odd indices hold cosines.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        """
        Args:
            d_model: Feature width of the input; odd widths are supported.
            dropout: Dropout probability applied after adding the encoding.
            max_len: Number of positions for which encodings are precomputed.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine slot than sine slots, so
        # trim the angles to the d_model // 2 odd indices that exist. For even
        # d_model this slice is a no-op.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x + pe)`` for ``x`` indexed as (batch, seq, d_model).

        The sequence length ``x.size(1)`` selects how many table rows are added.
        """
        x = x + self.pe[:x.size(1)].transpose(0, 1)
        return self.dropout(x)

class SurvivalTransformer(nn.Module):
    """Transformer encoder that turns a patient event sequence into a risk score.

    Per time step, the numerical features are concatenated with the embedding
    of every categorical feature, projected to ``d_model``, scaled by
    ``sqrt(d_model)``, given a positional encoding and passed through a stack
    of ``nn.TransformerEncoderLayer`` modules. The outputs of the non-padded
    steps are averaged and fed to a linear head with a single output.
    """

    def __init__(self, vocab_sizes, embedding_dims, num_numerical_features, d_model=128, nhead=8, num_encoder_layers=4, dim_feedforward=256, dropout_prob=0.5):
        """
        Args:
            vocab_sizes: Mapping from categorical feature name to vocabulary size.
            embedding_dims: Mapping from categorical feature name to embedding width.
            num_numerical_features: Number of numerical features per time step.
            d_model: Width of the transformer representation.
            nhead: Number of attention heads per encoder layer.
            num_encoder_layers: Number of stacked encoder layers.
            dim_feedforward: Hidden width of each layer's feed-forward block.
            dropout_prob: Dropout used by the positional encoding and the layers.
        """
        super().__init__()
        # Sorted so that embeddings are always concatenated in the same order.
        self.categorical_keys = sorted(vocab_sizes.keys())

        embedding_table = {}
        for name in self.categorical_keys:
            # Index 0 is the padding id; its embedding stays a zero vector.
            embedding_table[name] = nn.Embedding(vocab_sizes[name], embedding_dims[name], padding_idx=0)
        self.embedding_layers = nn.ModuleDict(embedding_table)

        feature_width = num_numerical_features + sum(embedding_dims.values())
        self.input_projection = nn.Linear(feature_width, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout_prob)

        encoder_stack = []
        for _ in range(num_encoder_layers):
            encoder_stack.append(
                nn.TransformerEncoderLayer(
                    d_model=d_model,
                    nhead=nhead,
                    dim_feedforward=dim_feedforward,
                    dropout=dropout_prob,
                    batch_first=True,
                    activation='gelu',
                )
            )
        self.encoder_layers = nn.ModuleList(encoder_stack)

        self.fc = nn.Linear(d_model, 1)
        self.d_model = d_model

    def _padding_mask(self, x_numerical, x_categorical, pre_embedded_categorical):
        """Return a boolean (batch, seq) mask that is True at padded steps."""
        if x_categorical is not None and self.categorical_keys:
            # Padded steps carry id 0 in the first (sorted) categorical feature.
            return x_categorical[self.categorical_keys[0]] == 0
        if pre_embedded_categorical is not None and pre_embedded_categorical:
            # With pre-computed embeddings, a step whose first embedding sums
            # to zero is treated as padding.
            return pre_embedded_categorical[0].sum(dim=-1) == 0
        return torch.zeros(x_numerical.shape[0], x_numerical.shape[1], dtype=torch.bool, device=x_numerical.device)

    @staticmethod
    def _layer_with_attention(layer, hidden, padding_mask):
        """Run one encoder layer step by step to expose per-head attention.

        Applies self-attention, residual + norm1, the feed-forward block and
        residual + norm2 using the layer's own sub-modules.
        """
        attn_out, weights = layer.self_attn(
            hidden, hidden, hidden,
            key_padding_mask=padding_mask,
            need_weights=True,
            average_attn_weights=False
        )
        hidden = layer.norm1(hidden + layer.dropout1(attn_out))
        feed_forward = layer.linear2(layer.dropout(layer.activation(layer.linear1(hidden))))
        hidden = layer.norm2(hidden + layer.dropout2(feed_forward))
        return hidden, weights

    def forward(self, x_numerical, x_categorical=None, pre_embedded_categorical=None, return_attention=False):
        """Compute one risk score per sequence.

        Args:
            x_numerical: Numerical features indexed as (batch, seq, feature).
            x_categorical: Mapping from categorical name to id tensor; used for
                the embedding lookup unless ``pre_embedded_categorical`` is given.
            pre_embedded_categorical: Optional list of already-embedded
                categorical tensors, used instead of the embedding layers.
            return_attention: When True, also return the per-head attention
                weights of the last encoder layer.

        Returns:
            The risk-score tensor, or ``(risk_score, attention_weights)`` when
            ``return_attention`` is True.
        """
        padding_mask = self._padding_mask(x_numerical, x_categorical, pre_embedded_categorical)

        if pre_embedded_categorical is not None:
            embeds = pre_embedded_categorical
        elif self.categorical_keys:
            embeds = [self.embedding_layers[name](x_categorical[name]) for name in self.categorical_keys]
        else:
            embeds = []

        stacked = torch.cat([x_numerical] + embeds, dim=2)
        scaled = self.input_projection(stacked) * math.sqrt(self.d_model)
        hidden = self.pos_encoder(scaled)

        attention_weights = None
        last_index = len(self.encoder_layers) - 1
        for index, layer in enumerate(self.encoder_layers):
            if return_attention and index == last_index:
                hidden, attention_weights = self._layer_with_attention(layer, hidden, padding_mask)
            else:
                hidden = layer(hidden, src_key_padding_mask=padding_mask)

        # Mean-pool over the real steps only; the clamp avoids dividing by
        # zero when a sequence is entirely padding.
        valid = ~padding_mask
        valid_counts = valid.sum(dim=1, keepdim=True)
        pooled = (hidden * valid.unsqueeze(-1)).sum(dim=1) / valid_counts.clamp(min=1)
        risk_score = self.fc(pooled)

        if return_attention:
            return risk_score, attention_weights
        return risk_score

def cox_loss(risk_scores, times, events):
    """Negative Cox partial log-likelihood, averaged over the observed events.

    Samples are sorted by descending time, so the running log-sum-exp of the
    risk scores at position i covers every sample sorted at or before i. Tied
    times are not pooled into a shared risk set; their relative order is
    whatever ``argsort`` returns.

    Args:
        risk_scores: Predicted risks; a trailing singleton dimension is removed.
        times: Observed times, one per sample.
        events: Event indicators, non-zero where the event was observed.

    Returns:
        Scalar loss tensor. When the batch holds no events the (zero) sum is
        returned without normalization.
    """
    risk_scores = risk_scores.squeeze(-1)
    # Accumulate in at least float32: half-precision model outputs (e.g. from
    # autocast) would otherwise lose precision in the running log-sum-exp.
    if risk_scores.dtype in (torch.float16, torch.bfloat16):
        risk_scores = risk_scores.float()

    order = torch.argsort(times, descending=True)
    risk_sorted = risk_scores[order]
    event_mask = events[order].bool()

    # logcumsumexp is the overflow-safe form of log(cumsum(exp(x))); the naive
    # form returns inf as soon as exp(x) exceeds the float range.
    log_risk_set_sum = torch.logcumsumexp(risk_sorted, dim=0)

    loss = -torch.sum(risk_sorted[event_mask] - log_risk_set_sum[event_mask])
    num_events = torch.sum(events)
    if num_events > 0:
        loss = loss / num_events
    return loss

def train_one_epoch(model, dataloader, optimizer, loss_fn, scaler, device):
    model.train()
    for batch_numerical_cpu, batch_categorical_cpu, (times_cpu, events_cpu) in dataloader:
        batch_numerical = batch_numerical_cpu.to(device)
        batch_categorical = {k: v.to(device) for k, v in batch_categorical_cpu.items()}
        times, events = times_cpu.to(device), events_cpu.to(device)
        optimizer.zero_grad(set_to_none=True)
        with autocast(device_type=str(device).split(":")[0]):
            risk_scores = model(batch_numerical, batch_categorical)
            loss = loss_fn(risk_scores, times, events)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

def evaluate_model(model, dataloader, loss_fn, train_df_outcomes, device):
    """Evaluate the model on ``dataloader`` and compute survival metrics.

    Args:
        model: Network called as ``model(numerical, categorical)``.
        dataloader: Iterable of ``(numerical, categorical, (times, events))``.
        loss_fn: Callable ``loss_fn(risk_scores, times, events)``.
        train_df_outcomes: Frame with ``time`` and ``dead`` columns of the
            training samples; used as the reference set for the
            time-dependent AUC and to choose its evaluation times.
        device: Device on which the model runs.

    Returns:
        ``(mean_loss, c_index, mean_auc, td_auc_df)``. ``mean_loss`` is the
        summed batch loss divided by the number of batches. If the
        time-dependent AUC cannot be computed, a warning is printed and
        ``mean_auc`` is NaN and ``td_auc_df`` is None.
    """
    model.eval()
    autocast_target = str(device).split(":")[0]
    total_loss = 0
    all_risk_scores, all_times, all_events = [], [], []
    with torch.no_grad():
        for batch_numerical, batch_categorical, (times, events) in dataloader:
            batch_numerical = batch_numerical.to(device)
            batch_categorical = {k: v.to(device) for k, v in batch_categorical.items()}
            with autocast(device_type=autocast_target):
                risk_scores = model(batch_numerical, batch_categorical)
                loss = loss_fn(risk_scores, times.to(device), events.to(device))
            total_loss += loss.item()
            # Autocast can yield float16/bfloat16 outputs. NumPy has no
            # bfloat16, so upcast before the later .numpy() conversion.
            all_risk_scores.append(risk_scores.float().cpu())
            all_times.append(times.cpu())
            all_events.append(events.cpu())

    # reshape(-1) keeps a 1-D array even when only one sample was evaluated.
    all_risk_scores_np = torch.cat(all_risk_scores).numpy().reshape(-1)
    all_times_np = torch.cat(all_times).numpy()
    all_events_np = torch.cat(all_events).numpy()

    # Scores are negated so that a higher risk ranks as a shorter survival.
    c_index = concordance_index(all_times_np, -all_risk_scores_np, all_events_np)

    outcome_dtype = [('event', bool), ('time', float)]
    train_outcomes_struct = np.array(
        list(zip(train_df_outcomes['dead'].astype(bool), train_df_outcomes['time'])),
        dtype=outcome_dtype,
    )
    val_outcomes_struct = np.array(
        list(zip(all_events_np.astype(bool), all_times_np)),
        dtype=outcome_dtype,
    )

    # Evaluation grid: 100 points between the 10th and 90th percentile of the
    # training event times, or three quartiles of all training times when
    # there are too few events.
    event_times = train_df_outcomes[train_df_outcomes['dead'] == 1]['time']
    if len(event_times) > 10:
        min_time, max_time = np.quantile(event_times, [0.1, 0.9])
        times_for_auc = np.linspace(min_time, max_time, 100) if max_time > min_time else [min_time]
    else:
        times_for_auc = np.quantile(train_df_outcomes['time'], [0.25, 0.5, 0.75])

    # Best-effort: the metric is optional, so a failure is reported and
    # evaluation continues.
    try:
        auc, mean_auc = cumulative_dynamic_auc(train_outcomes_struct, val_outcomes_struct, all_risk_scores_np, times_for_auc)
        td_auc_df = pd.DataFrame({'time': times_for_auc, 'auc': auc})
    except Exception as e:
        print(f"  - 警告: 计算 TD-AUC 失败。原因: {e}")
        mean_auc, td_auc_df = np.nan, None

    return total_loss / len(dataloader), c_index, mean_auc, td_auc_df

def plot_fold_history(history, fold_num, output_dir="cv_plots"):
    """Plot and save the loss, C-index and iAUC curves of one CV fold.

    Args:
        history: Mapping with the per-epoch lists ``train_loss``, ``val_loss``,
            ``train_c_index``, ``val_c_index``, ``train_iauc`` and ``val_iauc``.
        fold_num: Zero-based fold index; shown and saved as ``fold_num + 1``.
        output_dir: Directory for the PNG; created when it does not exist.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    fold_label = fold_num + 1
    epochs = range(1, len(history['train_loss']) + 1)

    fig, axes = plt.subplots(1, 3, figsize=(24, 7))

    # (history key suffix, metric name used in legend/title, y-axis label)
    panels = (
        ('loss', 'Loss', 'Cox Loss'),
        ('c_index', 'C-Index', 'C-Index'),
        ('iauc', 'iAUC', 'iAUC'),
    )
    for axis, (key, metric, y_label) in zip(axes, panels):
        axis.plot(epochs, history[f'train_{key}'], 'bo-', label=f'Train {metric}', markersize=3)
        axis.plot(epochs, history[f'val_{key}'], 'ro-', label=f'Validation {metric}', markersize=3)
        axis.set_title(f'Fold {fold_label}: {metric} Curve')
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        axis.legend()
        axis.grid(True)

    plt.suptitle(f'Fold {fold_label} Training & Validation History', fontsize=16)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])

    save_path = os.path.join(output_dir, f"fold_{fold_label}_training_history.png")
    plt.savefig(save_path, dpi=200)
    plt.show()
    plt.close()
    print(f"图表已保存至: {save_path}")

def plot_average_risk_difference_histogram(all_results_dfs, set_name):
    """Plot a histogram of the risk differences pooled over all CV folds.

    Args:
        all_results_dfs: List of per-fold DataFrames, each with a
            ``risk_difference`` column. Nothing is plotted when it is empty.
        set_name: Name of the output sub-directory; the PNG is written to
            ``./<set_name>/``, which is created when missing.
    """
    if not all_results_dfs:
        print("警告: 没有可供汇总的风险差异结果。")
        return

    # 1. Merge the per-fold DataFrames into a single frame.
    combined_results_df = pd.concat(all_results_dfs, ignore_index=True)

    print(f"\n--- 正在生成所有 {len(all_results_dfs)} 折的汇总风险差异直方图 ---")
    print(f"  - 总共分析了 {len(combined_results_df)} 个患者-进展事件。")

    # 2. Draw the histogram of the pooled data with a median marker.
    risk_diff = combined_results_df['risk_difference']
    plt.figure(figsize=(10, 6))

    sns.histplot(data=combined_results_df, x='risk_difference', kde=False, bins=30, edgecolor="black")

    median_risk_diff = risk_diff.median()
    plt.axvline(median_risk_diff, color='red', linestyle='--', label=f'Overall Median Difference: {median_risk_diff:.4f}')

    # Arrow pointing at the 90th percentile; the y positions (5 and 50) and
    # the +1 x offset of the text are fixed values in data coordinates.
    p90 = risk_diff.quantile(0.9)
    plt.annotate(
        'High-Impact (90th percentile)',
        xy=(p90, 5),
        xytext=(p90 + 1, 50),
        arrowprops=dict(facecolor='black', shrink=0.05),
        fontsize=12
    )

    plt.xlabel("Risk Difference (Factual Risk - Counterfactual Risk)", fontsize=12)
    plt.ylabel("Number of Patients", fontsize=12)
    plt.legend()
    plt.tight_layout()

    # 3. Save the figure; create the target directory first so that savefig
    # does not fail when this is the first file written for this set.
    output_dir = f"./{set_name}/"
    os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, "average_progression_risk_difference_histogram.png")
    plt.savefig(save_path)
    plt.close()
    print(f"\n汇总的风险差异分布图已保存至: {save_path}")

def visualize_attention_for_patient(
    model,
    patient_df,
    numerical_cols,
    categorical_cols,
    sample_id,
    fold_num,
    death_status,
    device,
):
    """Save a heat map of the last-layer attention for one patient.

    The attention weights returned by the model are averaged over heads and
    drawn as a query-by-key image whose ticks are labelled with the patient's
    ``EVENT_SUBTYPE`` values. The PNG goes to
    ``./<SET>/shap_plots_fold_<fold_num>/``, where ``SET`` is a module-level
    name defined outside this function.

    Args:
        model: Model supporting ``return_attention=True``.
        patient_df: Encoded and normalized rows of the patient.
        numerical_cols: Names of the numerical feature columns.
        categorical_cols: Names of the categorical columns (without suffix).
        sample_id: Identifier used in the title and file name.
        fold_num: Fold number used in the title and output directory.
        death_status: Status text shown in the title.
        device: Device on which the model is evaluated.
    """
    print(f"  - 正在为样本 {sample_id} 生成注意力图...")
    model.to(device).eval()

    encoded_names = [name + '_encoded' for name in categorical_cols]
    loader = DataLoader(
        PatientSequenceDataset(patient_df, numerical_cols, encoded_names),
        batch_size=1,
        shuffle=False,
        collate_fn=collate_fn_pad,
    )
    # Only the first sample produced by the loader is visualized.
    numeric, categorical, _ = next(iter(loader))
    numeric = numeric.to(device)
    categorical = {name: tensor.to(device) for name, tensor in categorical.items()}

    with torch.no_grad():
        _, per_head_attention = model(numeric, categorical, return_attention=True)
    # Drop the batch dimension, then average over the heads.
    attention_map = per_head_attention.squeeze(0).cpu().mean(dim=0).numpy()

    event_labels = [f"T{i}: {st}" for i, st in enumerate(patient_df['EVENT_SUBTYPE'])]
    num_events = len(event_labels)
    tick_positions = np.arange(num_events)

    fig, ax = plt.subplots(figsize=(max(8, num_events/1.5), max(6, num_events/2)))
    image = ax.imshow(attention_map, cmap='viridis')
    fig.colorbar(image, label='Attention Weight')
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(event_labels, rotation=90, ha="right", fontsize=9)
    ax.set_yticklabels(event_labels, fontsize=9)
    ax.set_xlabel("Key (Attended To)", fontsize=12)
    ax.set_ylabel("Query (Attending From)", fontsize=12)
    plt.title(f"Attention Map for Sample {sample_id} (Status: {death_status}, Fold {fold_num})", fontsize=14, pad=20)

    plot_dir = f"./{SET}/shap_plots_fold_{fold_num}"
    os.makedirs(plot_dir, exist_ok=True)
    save_path = os.path.join(plot_dir, f"attention_map_{sample_id}.png")
    plt.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close(fig)
    print(f"    - 注意力图已保存至: {save_path}")

def plot_shap_over_time(shap_values_tensor, input_tensor, sample_df, sample_id,
                        feature_names, sample_ids_list, true_seq_lengths,
                        fold_num, death_status, top_n_features=5):
    """Plot the per-step SHAP values of one sample's most important features.

    Features are ranked by the sum of absolute SHAP values over the sample's
    steps and the ``top_n_features`` highest are drawn as lines. The PNG goes
    to ``./<SET>/shap_plots_fold_<fold_num>/``, where ``SET`` is a module-level
    name defined outside this function; the directory is created when missing.

    Args:
        shap_values_tensor: Array indexed as (sample, step, feature).
        input_tensor: Not used by this function.
        sample_df: Frame with ``SAMPLE_ID`` and ``EVENT_SUBTYPE`` columns.
        sample_id: Sample to plot.
        feature_names: Feature names aligned with the last array axis.
        sample_ids_list: Sample ids aligned with the first array axis.
        true_seq_lengths: Mapping from sample id to unpadded sequence length;
            the full padded length is used when the id is missing.
        fold_num: Fold number used in the title and output directory.
        death_status: Status text shown in the title.
        top_n_features: Number of features to draw.
    """
    print(f"  - 正在为样本 {sample_id} 生成SHAP时序图...")
    try:
        sample_idx = sample_ids_list.index(sample_id)
    except (ValueError, IndexError):
        print(f"警告: 在SHAP分析的样本列表中未找到ID {sample_id}。")
        return

    patient_shap = shap_values_tensor[sample_idx, :, :]
    true_length = true_seq_lengths.get(sample_id, patient_shap.shape[0])
    patient_shap = patient_shap[:true_length, :]  # drop padded steps

    sample_rows = sample_df[sample_df['SAMPLE_ID'] == sample_id]
    event_labels = [f"T{i}: {st}" for i, st in enumerate(sample_rows['EVENT_SUBTYPE'])]

    total_importance = np.sum(np.abs(patient_shap), axis=0)
    top_feature_indices = np.argsort(total_importance)[-top_n_features:]

    plt.figure(figsize=(max(12, len(event_labels) * 0.8), 7))
    # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in
    # Matplotlib 3.9; both return the colormap resampled to the given size.
    colors = plt.get_cmap('tab10', len(top_feature_indices))
    for i, f_idx in enumerate(top_feature_indices):
        plt.plot(patient_shap[:, f_idx], marker='o', linestyle='-', label=feature_names[f_idx], color=colors(i))
    plt.axhline(0, color='grey', linestyle='--', linewidth=0.8)
    plt.xticks(ticks=np.arange(len(event_labels)), labels=event_labels, rotation=45, ha="right")
    plt.xlabel("Event Sequence")
    plt.ylabel("SHAP Value (Impact on Risk)")
    plt.title(f"Temporal SHAP Values for Sample {sample_id} (Status: {death_status}, Fold {fold_num})")
    plt.legend(title="Top Features")
    plt.grid(axis='x', linestyle=':', alpha=0.6)
    plt.tight_layout()

    plot_dir = f"./{SET}/shap_plots_fold_{fold_num}"
    # Create the directory here as well, so the function does not rely on an
    # earlier plotting call having created it.
    os.makedirs(plot_dir, exist_ok=True)
    save_path = os.path.join(plot_dir, f"temporal_shap_{sample_id}.png")
    plt.savefig(save_path)
    plt.close()
    print(f"    - SHAP时序图已保存至: {save_path}")

def plot_waterfall(shap_values, feature_names, sample_id, title_suffix, save_dir, base_value, top_n=7):
    """Draw and save a SHAP waterfall plot for one sample.

    Args:
        shap_values: Per-feature SHAP values of the sample.
        feature_names: Names aligned with ``shap_values``.
        sample_id: Identifier used in the file name.
        title_suffix: Text appended to the file name; spaces become
            underscores and parentheses are removed.
        save_dir: Directory the PNG is written to.
        base_value: Expected model output, shown as the x-axis label.
        top_n: Maximum number of rows displayed in the plot.
    """
    waterfall_input = shap.Explanation(
        values=shap_values,
        base_values=base_value,
        feature_names=feature_names,
    )

    plt.figure()
    shap.plots.waterfall(waterfall_input, max_display=top_n, show=False)
    plt.gca().set_xlabel(f"Base Value = {base_value:.3f}", fontsize=12)
    plt.tight_layout()

    # Make the suffix file-name friendly in a single pass.
    cleanup_table = str.maketrans({' ': '_', '(': None, ')': None})
    safe_suffix = title_suffix.translate(cleanup_table)

    fname = os.path.join(save_dir, f"waterfall_{sample_id}_{safe_suffix}.png")
    plt.savefig(fname, bbox_inches='tight')
    plt.close()
    print(f"    - 瀑布图已保存: {fname}")

def run_shap_analysis(model, train_df, val_df, numerical_cols, categorical_cols,
                      fold_num, device='cuda', background_size=50, test_size=30):
    """Run a DeepExplainer SHAP analysis for one fold and save the plots.

    Steps performed:
      1. Predict a risk score for every validation sample.
      2. Take the first ``background_size`` training sample ids as background
         and the first ``test_size`` validation sample ids as samples to explain.
      3. Pad both groups together in one batch, embed the categorical ids and
         explain a wrapper that takes the numerical tensor plus the
         pre-embedded categorical tensors.
      4. Sum the SHAP values of each embedding over its last axis so every
         categorical feature contributes one value per step.
      5. Save a global summary plot and waterfall plots for the first
         "dead with above-average risk" and the first "alive with
         below-average risk" sample among the explained ones.

    Reads the module-level names ``BATCH_SIZE`` and ``SET``, which are defined
    outside this function. Plots are written to
    ``./<SET>/shap_plots_fold_<fold_num>``.

    Args:
        model: Trained model exposing ``categorical_keys`` and ``embedding_layers``.
        train_df: Encoded and normalized training frame.
        val_df: Encoded and normalized validation frame.
        numerical_cols: Names of the numerical feature columns (a list; it is
            concatenated with the categorical key list).
        categorical_cols: Names of the categorical columns (without suffix).
        fold_num: Fold number used in messages and the output directory.
        device: Device on which the model and tensors are placed.
        background_size: Number of training samples used as background.
        test_size: Number of validation samples to explain.

    Returns:
        None. The function returns early, after printing a warning, when a
        required dataset or sample group is empty.
    """

    print(f"\n--- [Fold {fold_num}] 开始SHAP可解释性分析 (device={device}) ---")

    model.to(device).eval()

    # --- 1. Risk prediction for the whole validation set ---
    print("  - 正在对验证集进行风险预测以筛选典型样本...")
    categorical_cols_encoded = [c + '_encoded' for c in categorical_cols]
    val_dataset_full = PatientSequenceDataset(val_df, numerical_cols, categorical_cols_encoded)
    # Truthiness of the dataset falls back to its __len__.
    if not val_dataset_full:
        print("警告: 验证集为空，跳过SHAP分析。")
        return

    val_loader_full = DataLoader(val_dataset_full, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn_pad)

    all_val_risks = []
    with torch.no_grad():
        for num, cat, _ in val_loader_full:
            num = num.to(device)
            cat_device = {k: v.to(device) for k, v in cat.items()}
            risks = model(num, cat_device)
            all_val_risks.extend(risks.squeeze().detach().cpu().numpy().flatten())

    # The loader is not shuffled, so the risks line up with the dataset's sample_ids.
    risk_df = pd.DataFrame({'SAMPLE_ID': val_dataset_full.sample_ids, 'predicted_risk': all_val_risks})
    val_outcomes_df = val_df[['SAMPLE_ID', 'dead']].drop_duplicates()
    risk_df = pd.merge(risk_df, val_outcomes_df, on='SAMPLE_ID')
    average_risk = risk_df['predicted_risk'].mean()
    print(f"  - 验证集平均预测风险: {average_risk:.4f}")

    # --- 2. Choose background (train) and explained (validation) samples ---
    background_samples = list(train_df['SAMPLE_ID'].unique()[:background_size])
    background_df = train_df[train_df['SAMPLE_ID'].isin(background_samples)].copy()
    test_samples_for_shap = list(risk_df['SAMPLE_ID'].unique()[:test_size])
    test_df_shap = val_df[val_df['SAMPLE_ID'].isin(test_samples_for_shap)].copy()

    if not background_samples or not test_samples_for_shap:
        print("警告: 背景或测试样本为空，跳过SHAP分析。")
        return

    # --- 3. Pad both groups in one batch so they share the sequence length ---
    combined_df = pd.concat([background_df, test_df_shap])
    combined_dataset = PatientSequenceDataset(combined_df, numerical_cols, categorical_cols_encoded)
    if not combined_dataset:
        print("警告: 组合数据集为空，跳过SHAP。")
        return

    combined_loader = DataLoader(combined_dataset, batch_size=len(combined_dataset), shuffle=False, collate_fn=collate_fn_pad)
    all_numericals_padded, all_categoricals_padded, _ = next(iter(combined_loader))

    all_numericals_padded = all_numericals_padded.to(device)
    all_categoricals_padded = {k: v.to(device) for k, v in all_categoricals_padded.items()}
    combined_sample_order = combined_dataset.sample_ids

    # Row positions of each group inside the combined batch.
    background_idx = [combined_sample_order.index(sid) for sid in background_samples if sid in combined_sample_order]
    test_idx = [combined_sample_order.index(sid) for sid in test_samples_for_shap if sid in combined_sample_order]
    if not background_idx or not test_idx:
        print("警告: 批次中无背景或测试样本，跳过SHAP。")
        return

    background_numericals_t = all_numericals_padded[background_idx]
    test_numericals_t = all_numericals_padded[test_idx]
    background_categoricals_t = {k: v[background_idx] for k, v in all_categoricals_padded.items()}
    test_categoricals_t = {k: v[test_idx] for k, v in all_categoricals_padded.items()}

    # Embed the categorical ids up front; the explainer then receives the
    # embedding tensors as inputs instead of integer ids.
    categorical_keys_sorted = model.categorical_keys
    with torch.no_grad():
        background_embeddings = [model.embedding_layers[key](background_categoricals_t[key]) for key in categorical_keys_sorted]
        test_embeddings = [model.embedding_layers[key](test_categoricals_t[key]) for key in categorical_keys_sorted]

    class SHAPWrapper(nn.Module):
        # Adapter: positional tensors in, model called with pre-embedded categoricals.
        def __init__(self, m): 
            super().__init__(); self.m = m
        def forward(self, x_num, *pre_emb):
            return self.m(x_num, pre_embedded_categorical=list(pre_emb))   

    wrapper = SHAPWrapper(model)
    background_inputs = [background_numericals_t.float()] + [be.float() for be in background_embeddings]
    test_inputs = [test_numericals_t.float()] + [te.float() for te in test_embeddings]

    explainer = shap.DeepExplainer(wrapper, background_inputs)

    shap_values_list = explainer.shap_values(test_inputs, check_additivity=False)

    def to_numpy(x):
        # Recursively convert lists/tuples/tensors to NumPy arrays.
        if isinstance(x, list) or isinstance(x, tuple):
            return [to_numpy(xx) for xx in x]
        if isinstance(x, np.ndarray):
            return x
        try:
            if hasattr(x, 'detach'):
                return x.detach().cpu().numpy()
        except Exception:
            pass
        return np.array(x)

    shap_values_list = to_numpy(shap_values_list)

    # --- 4. Assemble one (sample, step, feature) SHAP array ---
    shap_numerical = shap_values_list[0]
    if shap_numerical is None:
        print("警告: 数值特征的 SHAP 值为空，跳过后续处理。")
        return

    # A trailing singleton axis (when present) is dropped.
    if shap_numerical.ndim == 4 and shap_numerical.shape[-1] == 1:
        shap_numerical = np.squeeze(shap_numerical, axis=-1)  # shape -> (batch, seq, num_numerical_features)

    shap_categorical_summed = []
    for s in shap_values_list[1:]:
        if s is None:
            continue
        if s.ndim == 4 and s.shape[-1] == 1:
            s = np.squeeze(s, axis=-1)  # (batch, seq, embedding_dim)
        # Collapse the embedding axis: one SHAP value per categorical feature.
        summed_s = np.sum(s, axis=-1, keepdims=True)
        shap_categorical_summed.append(summed_s)

    all_shap_tensors = [shap_numerical] + shap_categorical_summed
    try:
        shap_values_temporal = np.concatenate(all_shap_tensors, axis=2)
    except Exception as e:
        print("错误: 在连接 SHAP 张量时发生异常：", e)
        print([arr.shape for arr in all_shap_tensors])
        return

    print(f"  - SHAP值张量已成功构建，形状为: {shap_values_temporal.shape}")

    # --- 5. Plots ---
    plot_dir = f"./{SET}/shap_plots_fold_{fold_num}"; os.makedirs(plot_dir, exist_ok=True)
    feature_names = numerical_cols + categorical_keys_sorted
    
    # Flatten samples and steps so each row is one (sample, step) pair.
    shap_values_2d = shap_values_temporal.reshape(-1, shap_values_temporal.shape[-1])
    # NOTE(review): the frame passed as `features` below is built from the SHAP
    # values themselves, not from the model inputs, so the point colors of the
    # summary plot would encode SHAP values rather than feature values —
    # confirm whether this is intended.
    input_df_for_shap = pd.DataFrame(shap_values_2d, columns=feature_names) 

    plt.figure()
    shap.summary_plot(shap_values_2d, features=input_df_for_shap, feature_names=feature_names, show=False, max_display=15)
    #plt.title(f"Global Feature Importance (Fold {fold_num}) - Top 15")
    summary_path = os.path.join(plot_dir, "summary_plot.png"); plt.savefig(summary_path, bbox_inches='tight'); plt.close()
    print(f"  - 全局SHAP摘要图已保存: {summary_path}")

    # Sample ids in the same order as the rows of the SHAP array.
    sample_ids_in_shap_calc = [combined_sample_order[i] for i in test_idx]
    risk_df_subset = risk_df[risk_df['SAMPLE_ID'].isin(sample_ids_in_shap_calc)]

    # "Typical" samples: dead with above-average risk / alive with below-average risk.
    high_risk_dead = risk_df_subset[(risk_df_subset['dead'] == 1) & (risk_df_subset['predicted_risk'] > average_risk)]['SAMPLE_ID'].tolist()
    low_risk_alive = risk_df_subset[(risk_df_subset['dead'] == 0) & (risk_df_subset['predicted_risk'] < average_risk)]['SAMPLE_ID'].tolist()

    print(f"\n  - 找到 {len(high_risk_dead)} 个典型死亡样本 (死亡且风险 > 平均值)。")
    print(f"  - 找到 {len(low_risk_alive)} 个典型存活样本 (存活且风险 < 平均值)。")

    # Sum over the step axis: one aggregated SHAP value per sample and feature.
    per_sample_feature_sum = np.sum(shap_values_temporal, axis=1)

    # Reduce the explainer's expected value to a plain float; fall back to 0.0.
    ev = explainer.expected_value
    try:
        if hasattr(ev, 'item'):
            base_value = ev.item()
        elif isinstance(ev, (list, tuple, np.ndarray)):
            base_value = float(ev[0])
        else:
            base_value = float(ev)
    except Exception:
        base_value = float(np.zeros(1))

    if high_risk_dead:
        sid = high_risk_dead[0]
        risk_value = risk_df.loc[risk_df['SAMPLE_ID'] == sid, 'predicted_risk'].item()
        print(f"\n  分析典型死亡样本: {sid} (预测风险: {risk_value:.4f})")
        sample_pos = sample_ids_in_shap_calc.index(sid)
        plot_waterfall(per_sample_feature_sum[sample_pos], feature_names, sid, "Dead (High Risk)", plot_dir, base_value)

    if low_risk_alive:
        sid = low_risk_alive[0]
        risk_value = risk_df.loc[risk_df['SAMPLE_ID'] == sid, 'predicted_risk'].item()
        print(f"\n  分析典型存活样本: {sid} (预测风险: {risk_value:.4f})")
        sample_pos = sample_ids_in_shap_calc.index(sid)
        plot_waterfall(per_sample_feature_sum[sample_pos], feature_names, sid, "Alive (Low Risk)", plot_dir, base_value)

    print(f"\n--- [Fold {fold_num}] SHAP可解释性分析完成 ---\n")

def explain_single_df(model, df_to_explain, background_df, scaler, vocabs, numerical_cols, categorical_cols, device):
    """Compute step-aggregated SHAP values for one sample (counterfactual use).

    Both frames are given every column the scaler was fitted on (missing ones
    are filled with 0.0), scaled, and integer-encoded with the stored
    vocabularies. The sample is then explained with ``shap.DeepExplainer``
    against the background samples.

    Side effect: the model is moved to the CPU and put in eval mode; it is not
    moved back. The ``device`` argument is not used.

    Args:
        model: Trained model exposing ``categorical_keys`` and ``embedding_layers``.
        df_to_explain: Raw rows of the sample to explain. The sample is
            identified by the ``SAMPLE_ID`` of its first row, so the frame is
            expected to hold a single sample.
        background_df: Raw rows of the background samples. Rows sharing the
            explained sample's ``SAMPLE_ID`` are ignored so the two sequences
            are not merged into one.
        scaler: Fitted scaler exposing ``feature_names_in_`` and ``transform``.
        vocabs: Mapping ``col -> {'vocab': dict, ...}`` with an ``'<UNK>'`` entry.
        numerical_cols: Names of the numerical feature columns fed to the model.
        categorical_cols: Names of the raw categorical columns.
        device: Unused.

    Returns:
        ``(aggregated_shap_values, base_value)``: the SHAP values summed over
        the sequence (one per feature) and the explainer's expected value, or
        ``(None, None)`` when there is nothing to explain or no background.
    """
    if df_to_explain.empty:
        return None, None

    expected_numerical_features = list(scaler.feature_names_in_)
    target_id = df_to_explain['SAMPLE_ID'].iloc[0]

    def _scale(frame):
        # Copy, add any feature the scaler expects but the frame lacks, scale.
        prepared = frame.copy()
        for col in expected_numerical_features:
            if col not in prepared.columns:
                prepared[col] = 0.0
        # Column replacement avoids writing floats in place into int columns.
        prepared[expected_numerical_features] = scaler.transform(prepared[expected_numerical_features])
        return prepared

    df_to_explain_norm = _scale(df_to_explain)
    background_df_norm = _scale(background_df)
    # A background sequence with the same id would be grouped together with
    # the explained rows into a single sequence; leave it out.
    background_df_norm = background_df_norm[background_df_norm['SAMPLE_ID'] != target_id]

    categorical_cols_encoded = []
    for col in categorical_cols:
        if col in vocabs:
            vocab = vocabs[col]['vocab']
            encoded_col_name = col + '_encoded'
            categorical_cols_encoded.append(encoded_col_name)
            df_to_explain_norm[encoded_col_name] = df_to_explain_norm[col].astype(str).map(vocab).fillna(vocab['<UNK>'])
            background_df_norm[encoded_col_name] = background_df_norm[col].astype(str).map(vocab).fillna(vocab['<UNK>'])

    combined_df = pd.concat([background_df_norm, df_to_explain_norm])
    dataset = PatientSequenceDataset(combined_df, numerical_cols, categorical_cols_encoded)
    if not dataset:
        return None, None

    loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False, collate_fn=collate_fn_pad)
    try:
        numericals, categoricals, _ = next(iter(loader))
    except StopIteration:
        return None, None

    # The dataset orders samples by sorted SAMPLE_ID, so the explained sample
    # is not necessarily the last row of the batch; locate it by id.
    target_pos = dataset.sample_ids.index(target_id)
    background_pos = [i for i in range(len(dataset)) if i != target_pos]
    if not background_pos:
        return None, None

    model_cpu = model.to('cpu').eval()
    keys = model_cpu.categorical_keys
    # The embeddings are plain inputs for the explainer; no graph is needed.
    with torch.no_grad():
        bg_tensors = [numericals[background_pos].cpu().float()] + [
            model_cpu.embedding_layers[k](categoricals[k][background_pos].cpu()) for k in keys
        ]
        test_tensors = [numericals[target_pos:target_pos + 1].cpu().float()] + [
            model_cpu.embedding_layers[k](categoricals[k][target_pos:target_pos + 1].cpu()) for k in keys
        ]

    class SHAPWrapper(nn.Module):
        # Adapter: positional tensors in, model called with pre-embedded categoricals.
        def __init__(self, m):
            super().__init__()
            self.m = m

        def forward(self, x_num, *pre_emb):
            return self.m(x_num, pre_embedded_categorical=list(pre_emb))

    explainer = shap.DeepExplainer(SHAPWrapper(model_cpu), bg_tensors)
    shap_values_list = explainer.shap_values(test_tensors, check_additivity=False)

    shap_numerical = shap_values_list[0]
    if shap_numerical.ndim == 4 and shap_numerical.shape[-1] == 1:
        shap_numerical = np.squeeze(shap_numerical, axis=-1)

    shap_categorical_summed = []
    for s in shap_values_list[1:]:
        if s.ndim == 4 and s.shape[-1] == 1:
            s = np.squeeze(s, axis=-1)
        # Collapse the embedding axis: one SHAP value per categorical feature.
        shap_categorical_summed.append(np.sum(s, axis=-1, keepdims=True))

    shap_temporal = np.concatenate([shap_numerical] + shap_categorical_summed, axis=2)

    # Sum over the steps and drop the single-sample axis.
    aggregated_shap_values = np.sum(shap_temporal, axis=1).squeeze(0)
    base_value = float(np.asarray(explainer.expected_value).reshape(-1)[0])

    return aggregated_shap_values, base_value

def _get_prediction_for_processed_df(df_processed, model, numerical_cols, categorical_cols_encoded, device):
    if df_processed.empty:
        return np.nan
        
    temp_dataset = PatientSequenceDataset(df_processed, numerical_cols, categorical_cols_encoded)
    if len(temp_dataset) == 0: return np.nan
    
    temp_loader = DataLoader(temp_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn_pad)
    
    try:
        batch_numerical, batch_categorical, _ = next(iter(temp_loader))
    except StopIteration:
        return np.nan

    model.to(device)
    model.eval()

    batch_numerical = batch_numerical.to(device)
    batch_categorical = {k: v.to(device) for k, v in batch_categorical.items()}

    with torch.no_grad():
        risk_score = model(batch_numerical, batch_categorical)
        
    return risk_score.squeeze().cpu().item()

def _create_predictor_from_raw(model, numerical_cols, categorical_cols, scaler, vocabs, device):
    """Build a closure that scores a raw (unscaled, unencoded) event table.

    Args:
        model: model handed through to ``_get_prediction_for_processed_df``.
        numerical_cols: numerical feature columns the model consumes.
        categorical_cols: raw categorical column names; only those present in
            ``vocabs`` are encoded.
        scaler: fitted scaler exposing ``feature_names_in_`` and ``transform``.
        vocabs: mapping ``col -> {'vocab': {value: index, '<UNK>': index}, ...}``.
        device: device used for inference.

    Returns:
        ``predictor(raw_df)``, which returns the model output as a float, or
        ``np.nan`` for an empty frame. The input frame is never modified.
    """
    def predictor(raw_df):
        if raw_df.empty:
            return np.nan

        df_processed = raw_df.copy()

        expected_numerical_features = list(scaler.feature_names_in_)

        # Columns the scaler was fitted on but that are absent here are filled with 0.0.
        for col in expected_numerical_features:
            if col not in df_processed.columns:
                df_processed[col] = 0.0

        scaled_data = np.asarray(scaler.transform(df_processed[expected_numerical_features]))
        # Replace the columns wholesale instead of writing through ``.loc``:
        # an in-place write of floats into integer columns is deprecated in
        # pandas and becomes an error in newer releases.
        df_processed[expected_numerical_features] = scaled_data

        categorical_cols_encoded = []
        for col in categorical_cols:
            if col not in vocabs:
                continue
            vocab = vocabs[col]['vocab']
            encoded_col_name = col + '_encoded'
            categorical_cols_encoded.append(encoded_col_name)
            # Assign the filled result back explicitly. The previous chained
            # ``df[col].fillna(..., inplace=True)`` does not update the frame
            # under pandas Copy-on-Write, leaving NaN codes for unseen values.
            df_processed[encoded_col_name] = (
                df_processed[col].astype(str).map(vocab).fillna(vocab['<UNK>'])
            )

        return _get_prediction_for_processed_df(
            df_processed, model, numerical_cols, categorical_cols_encoded, device
        )
    return predictor

from scipy.stats import ttest_ind

def _first_record_per_patient(val_df):
    """Return one row per PATIENT_ID: the row with the smallest START_DATE."""
    return val_df.sort_values('START_DATE').drop_duplicates(subset='PATIENT_ID', keep='first')


def _label_extreme_impact_groups(df):
    """Label rows by risk-difference quartile and keep only the two extreme groups.

    Rows with ``risk_difference`` at or below the 25% quantile become
    'Low-Impact', rows above the 75% quantile become 'High-Impact'; everything
    else (including NaN) is dropped. Works on a copy, so the caller's frame is
    left untouched. Returns ``None`` when fewer than two groups are populated.
    """
    risk = df['risk_difference']
    low_threshold, high_threshold = risk.quantile([0.25, 0.75])

    # Direct comparisons give the same intervals as
    # pd.cut(bins=[-inf, q25, q75, inf]) but do not raise when q25 == q75.
    labels = np.select(
        [(risk <= low_threshold).to_numpy(), (risk > high_threshold).to_numpy()],
        ['Low-Impact', 'High-Impact'],
        default='Mid-Impact',
    )
    labelled = df.copy()
    labelled['impact_group'] = labels
    extremes = labelled[labelled['impact_group'] != 'Mid-Impact'].copy()
    if extremes['impact_group'].nunique() < 2:
        return None
    return extremes


def _median_iqr(values):
    """Format a series as 'median [Q1-Q3]' with no decimals."""
    q1, q3 = values.quantile([0.25, 0.75])
    return f"{values.median():.0f} [{q1:.0f}-{q3:.0f}]"


def _count_percent(values, group_size):
    """Format the sum of an indicator column as 'count (percent%)'."""
    count = values.sum()
    percent = (count / group_size * 100) if group_size > 0 else 0
    return f"{int(count)} ({percent:.1f}%)"


def _build_impact_comparison_table(analysis_df, static_features_list):
    """Compare baseline features between the Low- and High-Impact groups.

    Float columns, and integer columns with more than two distinct values, are
    summarised as median [IQR] and compared with Welch's t-test. All other
    columns are summarised as count (%) and compared with a chi-square test.
    Features with fewer than two distinct values are skipped.
    """
    group_low = analysis_df[analysis_df['impact_group'] == 'Low-Impact']
    group_high = analysis_df[analysis_df['impact_group'] == 'High-Impact']
    n_low, n_high = len(group_low), len(group_high)
    low_label = f'Low-Impact Group (N={n_low})'
    high_label = f'High-Impact Group (N={n_high})'

    table_data = []
    for feature in static_features_list:
        if feature not in analysis_df.columns:
            continue
        column = analysis_df[feature]
        if column.nunique() < 2:
            continue

        stat_row = {'Baseline Characteristic': feature}
        is_continuous = pd.api.types.is_float_dtype(column) or (
            pd.api.types.is_integer_dtype(column) and column.nunique() > 2
        )

        if is_continuous:
            stat_row[low_label] = _median_iqr(group_low[feature])
            stat_row[high_label] = _median_iqr(group_high[feature])
            _, p_value = ttest_ind(group_low[feature].dropna(), group_high[feature].dropna(), equal_var=False)
        else:
            stat_row[low_label] = _count_percent(group_low[feature], n_low)
            stat_row[high_label] = _count_percent(group_high[feature], n_high)
            contingency_table = pd.crosstab(analysis_df['impact_group'], column)
            _, p_value, _, _ = chi2_contingency(contingency_table)

        stat_row['p-value'] = "<0.001" if p_value < 0.001 else f"{p_value:.3f}"
        table_data.append(stat_row)

    return pd.DataFrame(table_data)


def analyze_progression_impact_heterogeneity(results_df, val_df, static_features_list, fold_num, set_name=None):
    """Write the per-fold Low- vs High-Impact baseline comparison table.

    Args:
        results_df: per-patient results with 'PATIENT_ID' and 'risk_difference'.
            It is not modified. ``None`` or fewer than 10 rows skips the analysis.
        val_df: event table; the earliest row per patient supplies the features.
        static_features_list: candidate feature columns to compare.
        fold_num: fold number used in messages and in the output directory name.
        set_name: dataset name for the output directory; defaults to the
            module-level ``SET`` when omitted.

    Returns:
        None. The table is printed and saved as CSV under
        ``./<set_name>/impact_heterogeneity_tables_fold_<fold_num>/``.
    """
    print(f"\n--- [Fold {fold_num}] 开始生成进展影响异质性总结表 ---")

    if results_df is None or len(results_df) < 10:
        print("  - 结果太少，无法进行有意义的异质性分析。跳过。")
        return

    analysis_df = _label_extreme_impact_groups(results_df)
    if analysis_df is None:
        print("  - 未能成功划分出高/低影响组。跳过。")
        return

    analysis_df = pd.merge(analysis_df, _first_record_per_patient(val_df), on='PATIENT_ID', how='left')

    n_low = int((analysis_df['impact_group'] == 'Low-Impact').sum())
    n_high = int((analysis_df['impact_group'] == 'High-Impact').sum())
    print(f"  - 低影响组 N={n_low}, 高影响组 N={n_high}")

    summary_table = _build_impact_comparison_table(analysis_df, static_features_list)

    dataset_name = SET if set_name is None else set_name
    output_dir = f"./{dataset_name}/impact_heterogeneity_tables_fold_{fold_num}"
    os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, "table2_impact_group_comparison.csv")
    summary_table.to_csv(save_path, index=False)

    print("\n--- 异质性分析总结表 (Table 2) ---")
    print(summary_table.to_string(index=False))
    print(f"\n  表格已保存至: {save_path}")


def summarize_heterogeneity_across_folds(analysis_data_list, static_features_list, set_name):
    """Pool all folds and write one Low- vs High-Impact baseline comparison table.

    Args:
        analysis_data_list: one dict per fold with keys 'results_df'
            (per-patient 'PATIENT_ID' / 'risk_difference') and 'val_df'
            (event table supplying the baseline features).
        static_features_list: candidate feature columns to compare.
        set_name: dataset name; the CSV is written into ``./<set_name>/``,
            which is created if missing.

    Returns:
        None. The table is printed and saved as CSV.
    """
    print(f"\n--- 正在汇总所有 {len(analysis_data_list)} 折的进展影响异质性分析 ---")

    if not analysis_data_list:
        print("  - 没有可供汇总的异质性分析数据。")
        return

    # Attach each fold's baseline features before pooling, so the impact
    # thresholds are computed once over all folds together.
    merged_folds = [
        pd.merge(data['results_df'], _first_record_per_patient(data['val_df']), on='PATIENT_ID', how='left')
        for data in analysis_data_list
    ]
    combined_analysis_df = pd.concat(merged_folds, ignore_index=True)

    analysis_df = _label_extreme_impact_groups(combined_analysis_df)
    if analysis_df is None:
        print("  - 未能从汇总数据中成功划分出高/低影响组。跳过。")
        return

    n_low = int((analysis_df['impact_group'] == 'Low-Impact').sum())
    n_high = int((analysis_df['impact_group'] == 'High-Impact').sum())
    print(f"  - 汇总分析: 低影响组 N={n_low}, 高影响组 N={n_high}")

    summary_table = _build_impact_comparison_table(analysis_df, static_features_list)

    output_dir = f"./{set_name}/"
    os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, f"summary_table_impact_group_comparison_{set_name}.csv")

    summary_table.to_csv(save_path, index=False)

    print("\n--- 汇总的异质性分析总结表 (Table 2) ---")
    print(summary_table.to_string(index=False))
    print(f"\n  汇总表格已保存至: {save_path}")

def explain_counterfactual_pair(model, factual_df, counterfactual_df, background_df,
                                scaler, vocabs, numerical_cols, categorical_cols, device):
    """Compute per-feature SHAP values for a factual/counterfactual pair.

    Background, factual and counterfactual frames are each scaled with
    ``scaler`` and encoded with ``vocabs``, concatenated in that order, batched
    through ``PatientSequenceDataset``, and explained with ``shap.DeepExplainer``
    on CPU. SHAP values are summed over the embedding axis for categorical
    inputs and then over axis 1 of the batch tensors.

    Args:
        model: module exposing ``embedding_layers``, ``categorical_keys`` and a
            forward accepting ``pre_embedded_categorical``. It is moved to CPU
            for the explanation and moved back to its original device afterwards.
        factual_df, counterfactual_df, background_df: raw event tables; none of
            them is modified.
        scaler: fitted scaler exposing ``feature_names_in_`` and ``transform``.
        vocabs: mapping ``col -> {'vocab': {value: index, '<UNK>': index}, ...}``.
        numerical_cols: numerical feature columns consumed by the dataset.
        categorical_cols: raw categorical columns; only those in ``vocabs`` are used.
        device: unused here; kept for signature compatibility.

    Returns:
        ``(factual_shap, counterfactual_shap, base_value)``, or
        ``(None, None, None)`` when there is nothing to explain.
    """
    if factual_df.empty or counterfactual_df.empty:
        return None, None, None

    expected_numerical_features = list(scaler.feature_names_in_)
    active_categorical_cols = [col for col in categorical_cols if col in vocabs]
    categorical_cols_encoded = [col + '_encoded' for col in active_categorical_cols]

    processed_frames = []
    for raw_df in (background_df, factual_df, counterfactual_df):
        df_norm = raw_df.copy()
        for col in expected_numerical_features:
            if col not in df_norm.columns:
                df_norm[col] = 0.0
        # Replace the columns instead of writing through ``.loc``: an in-place
        # write of floats into integer columns is deprecated in pandas.
        df_norm[expected_numerical_features] = np.asarray(
            scaler.transform(df_norm[expected_numerical_features])
        )
        for col in active_categorical_cols:
            vocab = vocabs[col]['vocab']
            df_norm[col + '_encoded'] = df_norm[col].astype(str).map(vocab).fillna(vocab['<UNK>'])
        processed_frames.append(df_norm)

    combined_df = pd.concat(processed_frames)
    dataset = PatientSequenceDataset(combined_df, numerical_cols, categorical_cols_encoded)
    if len(dataset) == 0:
        return None, None, None

    loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False, collate_fn=collate_fn_pad)
    batch = next(iter(loader), None)
    if batch is None:
        return None, None, None
    numericals, categoricals, _ = batch

    # NOTE(review): the slicing below assumes the dataset yields the factual
    # and counterfactual tables as the LAST TWO, separate sequences. In the
    # visible caller both share one PATIENT_ID -- confirm that
    # PatientSequenceDataset does not merge them into a single sequence.
    class SHAPWrapper(nn.Module):
        def __init__(self, m): super().__init__(); self.m = m
        def forward(self, x_num, *pre_emb): return self.m(x_num, pre_embedded_categorical=list(pre_emb))

    # nn.Module.to() moves the module in place, so remember where the caller
    # had it and put it back once the CPU-only explanation is done.
    first_param = next(model.parameters(), None)
    original_device = first_param.device if first_param is not None else None
    model_cpu = model.to('cpu').eval()
    try:
        def _as_inputs(rows):
            embedded = [model_cpu.embedding_layers[k](categoricals[k][rows].cpu()) for k in model_cpu.categorical_keys]
            return [numericals[rows].cpu().float()] + embedded

        bg_tensors = _as_inputs(slice(None, -2))
        test_tensors = _as_inputs(slice(-2, None))

        explainer = shap.DeepExplainer(SHAPWrapper(model_cpu), bg_tensors)
        shap_values_list = explainer.shap_values(test_tensors, check_additivity=False)
        expected_value = explainer.expected_value
    finally:
        if original_device is not None:
            model.to(original_device)

    def _drop_trailing_unit_axis(values):
        # Drop a trailing singleton axis when the SHAP output is 4-D.
        if values.ndim == 4 and values.shape[-1] == 1:
            return np.squeeze(values, axis=-1)
        return values

    shap_numerical = _drop_trailing_unit_axis(shap_values_list[0])
    # Collapse each categorical input's last (embedding) axis to one value.
    shap_categorical_summed = [
        np.sum(_drop_trailing_unit_axis(s), axis=-1, keepdims=True)
        for s in shap_values_list[1:]
    ]

    shap_temporal = np.concatenate([shap_numerical] + shap_categorical_summed, axis=2)
    aggregated_shap_values = np.sum(shap_temporal, axis=1)

    factual_shap = aggregated_shap_values[0]
    counterfactual_shap = aggregated_shap_values[1]
    base_value = expected_value[0].item() if hasattr(expected_value, 'item') else expected_value[0]

    return factual_shap, counterfactual_shap, base_value

def run_progression_counterfactual(model, train_df_norm, val_df, final_first_progression_df, progression_event_def, 
                                 counterfactual_change, scaler, vocabs, numerical_cols, 
                                 categorical_cols, fold_num, device='cpu', n=100):
    """Simulate "what if the progression event had been different" per patient.

    For every first-progression row whose patient is in ``val_df``, the model
    is scored on the factual history (all events up to and including the
    progression date) and on a counterfactual history (events strictly before
    that date plus the first matching progression row with
    ``counterfactual_change`` applied). For the first patient that has a
    matching event, SHAP waterfall plots of both versions are also produced.
    A histogram of the risk differences is saved at the end.

    Args:
        model: model to score; moved to ``device`` and put in eval mode.
        train_df_norm: frame sampled (``n`` rows) as the SHAP background.
        val_df: raw validation event table.
        final_first_progression_df: one row per patient with 'PATIENT_ID' and
            the 'START_DATE' of the progression event.
        progression_event_def: dict; only its 'EVENT_SUBTYPE' entry is used to
            locate the event row.
        counterfactual_change: ``{column: value}`` overrides for the event row.
        scaler, vocabs: fitted preprocessing objects for raw frames.
        numerical_cols, categorical_cols: model feature columns.
        fold_num: fold number used in messages and output paths.
        device: inference device.
        n: requested background sample size (capped at ``len(train_df_norm)``).

    Returns:
        DataFrame with PATIENT_ID, factual_risk, counterfactual_risk and
        risk_difference, or ``None`` when no patient could be simulated.
    """
    print(f"\n--- [Fold {fold_num}] 开始疾病进展反事实模拟与解释 ---")
    model.to(device).eval()

    # Cheap membership check first; nothing else is needed if no patient qualifies.
    pids_in_val = val_df['PATIENT_ID'].unique()
    target_progression_df = final_first_progression_df[final_first_progression_df['PATIENT_ID'].isin(pids_in_val)]
    if target_progression_df.empty:
        print("  - 当前验证集中没有找到指定的进展患者。跳过。")
        return None

    predictor = _create_predictor_from_raw(model, numerical_cols, categorical_cols, scaler, vocabs, device)

    # Bound up front: the histogram below needs it even though the directory
    # was previously only named inside the one-time visualisation branch.
    plot_dir = f"./{SET}/counterfactual_plots_fold_{fold_num}"

    results, visualized_one = [], False
    for _, row in target_progression_df.iterrows():
        patient_id = row['PATIENT_ID']
        event_date = row['START_DATE']
        patient_df = val_df[val_df['PATIENT_ID'] == patient_id].copy()

        factual_df = patient_df[patient_df['START_DATE'] <= event_date]
        history_before = patient_df[patient_df['START_DATE'] < event_date]
        prog_event_mask = (patient_df['START_DATE'] == event_date) & (patient_df['EVENT_SUBTYPE'] == progression_event_def['EVENT_SUBTYPE'])
        prog_event = patient_df[prog_event_mask]
        if prog_event.empty:
            continue

        # NOTE(review): the factual history keeps EVERY event on the
        # progression date, while the counterfactual keeps only the first row
        # matching EVENT_SUBTYPE (VALUE_CATEGORICAL is not checked). Any other
        # same-day events therefore also differ between the two -- confirm
        # this is intended.
        cf_event = prog_event.iloc[0:1].copy()
        for col, val in counterfactual_change.items():
            cf_event[col] = val
        counterfactual_df = pd.concat([history_before, cf_event], ignore_index=True)

        factual_risk = predictor(factual_df)
        counterfactual_risk = predictor(counterfactual_df)
        if not np.isnan(factual_risk) and not np.isnan(counterfactual_risk):
            results.append({'PATIENT_ID': patient_id, 'factual_risk': factual_risk, 'counterfactual_risk': counterfactual_risk, 'risk_difference': factual_risk - counterfactual_risk})

        if not visualized_one:
            print(f"  - 为患者 {patient_id} 生成反事实SHAP瀑布图...")
            feature_names = numerical_cols + sorted([c for c in categorical_cols if c in vocabs])
            os.makedirs(plot_dir, exist_ok=True)
            # Cap the sample size: DataFrame.sample raises if n exceeds the row count.
            # NOTE(review): explain_counterfactual_pair applies scaler.transform
            # to this background; if train_df_norm is already scaled (as its
            # name suggests) the background is scaled twice -- confirm.
            background_for_cf = train_df_norm.sample(n=min(n, len(train_df_norm)), random_state=SEED)

            factual_shap, cf_shap, base_val = explain_counterfactual_pair(
                model, factual_df, counterfactual_df, background_for_cf,
                scaler, vocabs, numerical_cols, categorical_cols, device
            )

            if factual_shap is not None and cf_shap is not None:
                title_suffix_factual = f"Factual (Risk_ {factual_risk:.3f})"
                plot_waterfall(factual_shap, feature_names, patient_id, title_suffix_factual, plot_dir, base_val)

                title_suffix_cf = f"Counterfactual (Risk_ {counterfactual_risk:.3f})"
                plot_waterfall(cf_shap, feature_names, patient_id, title_suffix_cf, plot_dir, base_val)

            visualized_one = True

    if not results:
        print("  - 未能成功为任何进展患者生成模拟结果。")
        return None

    results_df = pd.DataFrame(results)
    avg_risk_diff = results_df['risk_difference'].mean()
    print(f"  - 平均风险差异 (真实进展 vs 虚拟稳定): {avg_risk_diff:.4f}")

    # Histogram of the per-patient risk differences with the median marked.
    os.makedirs(plot_dir, exist_ok=True)
    plt.figure(figsize=(10, 6))
    sns.histplot(data=results_df, x='risk_difference', kde=False, bins=25, edgecolor="black")
    median_risk_diff = results_df['risk_difference'].median()
    plt.axvline(median_risk_diff, color='red', linestyle='--', label=f'Median Difference: {median_risk_diff:.4f}')
    plt.title(f"Distribution of Risk Differences (Fold {fold_num})\nProgression vs. No Progression")
    plt.xlabel("Risk Difference (Factual Risk - Counterfactual Risk)")
    plt.ylabel("Number of Patients")
    plt.legend()
    save_path = os.path.join(plot_dir, "progression_risk_difference_histogram.png")
    plt.savefig(save_path)
    plt.close()
    print(f"\n风险差异分布图已保存至: {save_path}")

    return results_df

# =============================================================================
# Main Execution Block (MODIFIED FOR TRAIN/ANALYZE MODES)
# =============================================================================
SEED = 42
set_seed(SEED)
# Prefer the GPU but fall back to CPU so the script still runs without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

MODE = 'analyze'  # 'train' or 'analyze'

SET = 'brca'
data_filename = f'df_{SET}_landmarks.csv'
print(f"\n--- 当前模式: {MODE.upper()} | 数据集: {SET.upper()} ---")

# Per-dataset static (baseline) features fed to the model in addition to CORE_FEATURES.
# NOTE: 'nsclc' previously listed 'STAFE 1', which matches no column and so
# silently dropped the stage-1 indicator. With the spelling fixed the model
# input width for nsclc changes, so nsclc checkpoints must be retrained.
set_features = {
    'brca': ['AGE', 'STAGE 1', 'STAGE 3', 'STAGE 4', 'PTEN', 'ERBB2', 'TP53'],
    'crc': ['AGE', 'BLACK', 'STAGE 1', 'STAGE 2', 'STAGE 4', 'KRAS', 'BRAF'],
    'nsclc': ['AGE', 'MALE', 'STAGE 1', 'STAGE 3', 'STAGE 4', 'PTEN', 'EGFR', 'TP53'],
    'panc': ['AGE', 'MALE', 'STAGE 1', 'STAGE 4', 'KRAS', 'TP53'],
    'prostate': ['AGE', 'BLACK', 'STAGE 4', 'PTEN', 'TP53']
}
CORE_FEATURES = ['START_DATE', 'VALUE_NUMERIC', 'EVENT_DURATION', 'EVENT_TYPE', 'EVENT_SUBTYPE', 'VALUE_CATEGORICAL']
ALL_FEATURES = CORE_FEATURES + set_features[SET]
STANDARDIZE_COLS = ['START_DATE', 'EVENT_DURATION', 'VALUE_NUMERIC', 'AGE']
CATEGORICAL_COLS = ['EVENT_TYPE', 'EVENT_SUBTYPE', 'VALUE_CATEGORICAL']

# Model / optimiser hyper-parameters and cross-validation settings.
model_params = {'d_model': 64, 'nhead': 8, 'num_encoder_layers': 4, 'dim_feedforward': 1024, 'dropout_prob': 0.45}
optimizer_params = {'lr': 8.1e-06, 'weight_decay': 0.00037}
N_SPLITS = 5
NUM_EPOCHS = 50
BATCH_SIZE = 64

# Entry point: load the landmark table, then either train one model per CV
# fold ('train') or reload the saved fold artifacts and evaluate/explain them
# ('analyze').
if not os.path.exists(data_filename):
    print(f"错误: 未找到数据文件 '{data_filename}'。")
else:
    time_start = time.time()
    df_full = pd.read_csv(data_filename)
    if 'time' not in df_full.columns: df_full['time'] = df_full['stop']
    df_processed = preprocess_dataframe(df_full)
    
    # Keep only configured features that actually exist in the file.
    existing_features = [f for f in ALL_FEATURES if f in df_processed.columns]
    final_categorical_cols = [c for c in CATEGORICAL_COLS if c in existing_features]
    final_numerical_cols = sorted(list(set(existing_features) - set(final_categorical_cols)))
    final_standardize_cols = [c for c in STANDARDIZE_COLS if c in final_numerical_cols]

    print(f"\n模型将使用以下特征:\n  - 数值: {final_numerical_cols}\n  - 分类: {final_categorical_cols}")

    # Patient-level stratified split on the 'dead' outcome; the fixed seed
    # makes both modes see the same folds.
    patient_outcomes = df_processed[['PATIENT_ID', 'dead']].drop_duplicates()
    patient_ids = patient_outcomes['PATIENT_ID'].values
    patient_dead_status = patient_outcomes['dead'].values
    kf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
    
    # =========================================================================
    #  TRAINING MODE
    # =========================================================================
    if MODE == 'train':
        print("\n--- 开始训练模式 ---")
        for fold, (train_idx, val_idx) in enumerate(kf.split(patient_ids, patient_dead_status)):
            print(f"\n{'='*80}\n===============  开始训练第 {fold+1}/{N_SPLITS} 折  ===============\n{'='*80}")

            train_patient_ids, val_patient_ids = patient_ids[train_idx], patient_ids[val_idx]
            train_df = df_processed[df_processed['PATIENT_ID'].isin(train_patient_ids)].copy()
            val_df = df_processed[df_processed['PATIENT_ID'].isin(val_patient_ids)].copy()
            
            train_df_encoded, val_df_encoded, vocabs = encode_categorical_features_leakproof(train_df, val_df, final_categorical_cols)
            train_df_norm, val_df_norm, feature_scaler = normalize_numerical_features_leakproof(train_df_encoded, val_df_encoded, final_standardize_cols)

            categorical_cols_encoded = [c + '_encoded' for c in final_categorical_cols]
            train_dataset = PatientSequenceDataset(train_df_norm, final_numerical_cols, categorical_cols_encoded)
            val_dataset = PatientSequenceDataset(val_df_norm, final_numerical_cols, categorical_cols_encoded)
            train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn_pad)
            val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn_pad)
            train_outcomes_df = train_df[['PATIENT_ID', 'dead', 'time']].drop_duplicates()
            
            vocab_sizes = {k: v['vocab_size'] for k, v in vocabs.items()}
            embedding_dims = {k: min(50, (v['vocab_size'] // 2) + 1) for k, v in vocabs.items()}
            model = SurvivalTransformer(vocab_sizes, embedding_dims, len(final_numerical_cols), **model_params).to(DEVICE)
            optimizer = torch.optim.Adam(model.parameters(), **optimizer_params, fused=torch.cuda.is_available())
            grad_scaler = GradScaler()
            
            best_val_iauc = -1

            output_dir = f"./{SET}/fold_{fold+1}/"
            os.makedirs(output_dir, exist_ok=True)
            best_model_path = os.path.join(output_dir, f"best_model_fold_{fold+1}.pth")
            
            # Checkpoint whenever the validation iAUC improves. A NaN iAUC
            # never compares greater, so such epochs are not saved.
            for epoch in range(NUM_EPOCHS):
                train_one_epoch(model, train_loader, optimizer, cox_loss, grad_scaler, DEVICE)
                _, _, val_iauc, _ = evaluate_model(model, val_loader, cox_loss, train_outcomes_df, DEVICE)
                if val_iauc > best_val_iauc:
                    best_val_iauc = val_iauc
                    torch.save(model.state_dict(), best_model_path)
                if (epoch + 1) % 20 == 0:
                    print(f"  Epoch {epoch+1:03d}/{NUM_EPOCHS} | Val iAUC: {val_iauc if not np.isnan(val_iauc) else 0:.4f} (Best: {best_val_iauc:.4f})")

            print(f"--- 第 {fold+1} 折训练完成。最佳验证iAUC: {best_val_iauc:.4f} ---")
            print(f"  - 模型已保存至: {best_model_path}")

            # Persist the fold's preprocessing objects next to the checkpoint.
            with open(os.path.join(output_dir, 'scaler.pkl'), 'wb') as f:
                pickle.dump(feature_scaler, f)
            with open(os.path.join(output_dir, 'vocabs.pkl'), 'wb') as f:
                pickle.dump(vocabs, f)
            print(f"  - Scaler 和 Vocabs 已保存至: {output_dir}")

    # =========================================================================
    #  ANALYSIS MODE
    # =========================================================================
    elif MODE == 'analyze':
        print("\n--- 开始分析模式 ---")
        all_folds_td_auc = []
        final_fold_metrics = []
        all_progression_results_dfs = [] 
        heterogeneity_analysis_data = []

        # Fold-invariant inputs, computed once instead of once per fold.
        fold_splits = list(kf.split(patient_ids, patient_dead_status))
        categorical_cols_encoded = [c + '_encoded' for c in final_categorical_cols]
        # Sorted so the rows of the heterogeneity table come out in a stable order.
        static_features = sorted(set(final_numerical_cols + final_categorical_cols) - set(CORE_FEATURES))

        # First imaging-progression event ('Y') per patient, from the raw table.
        prog_def = {'EVENT_SUBTYPE': 'IMAGING_PROGRESSION', 'VALUE_CATEGORICAL': 'Y'}
        cf_mod = {'VALUE_CATEGORICAL': 'N'}
        prog_events = df_full[(df_full['EVENT_SUBTYPE'] == 'IMAGING_PROGRESSION') & (df_full['VALUE_CATEGORICAL'] == 'Y')]
        first_prog_df = prog_events.loc[prog_events.groupby('PATIENT_ID')['START_DATE'].idxmin()]
        
        for fold in range(N_SPLITS):
            print(f"\n{'='*80}\n===============  开始分析第 {fold+1}/{N_SPLITS} 折  ===============\n{'='*80}")

            input_dir = f"./{SET}/fold_{fold+1}/"
            model_path = os.path.join(input_dir, f"best_model_fold_{fold+1}.pth")
            scaler_path = os.path.join(input_dir, 'scaler.pkl')
            vocabs_path = os.path.join(input_dir, 'vocabs.pkl')
            
            if not all(os.path.exists(p) for p in [model_path, scaler_path, vocabs_path]):
                print(f"错误: 在 '{input_dir}' 中缺少必要的模型或预处理文件。请先运行 'train' 模式。")
                continue

            # NOTE: pickle executes arbitrary code on load; only point this at
            # artifacts written by this script's own 'train' mode.
            with open(scaler_path, 'rb') as f:
                feature_scaler = pickle.load(f)
            with open(vocabs_path, 'rb') as f:
                vocabs = pickle.load(f)
            print(f"  - 已从 '{input_dir}' 加载 Scaler 和 Vocabs。")

            train_idx, val_idx = fold_splits[fold]
            train_patient_ids, val_patient_ids = patient_ids[train_idx], patient_ids[val_idx]
            
            train_df = df_processed[df_processed['PATIENT_ID'].isin(train_patient_ids)].copy()
            val_df = df_processed[df_processed['PATIENT_ID'].isin(val_patient_ids)].copy()

            # The fold's frames are re-encoded/re-scaled from scratch here; the
            # pickled scaler/vocabs are used for the model sizes and for the
            # counterfactual pipeline below.
            train_df_encoded, val_df_encoded, _ = encode_categorical_features_leakproof(train_df, val_df, final_categorical_cols)
            train_df_norm, val_df_norm, _ = normalize_numerical_features_leakproof(train_df_encoded, val_df_encoded, final_standardize_cols)
            
            val_loader = DataLoader(
                PatientSequenceDataset(val_df_norm, final_numerical_cols, categorical_cols_encoded),
                batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn_pad
            )
            train_outcomes_df = train_df[['PATIENT_ID', 'dead', 'time']].drop_duplicates()

            vocab_sizes = {k: v['vocab_size'] for k, v in vocabs.items()}
            embedding_dims = {k: min(50, (v['vocab_size'] // 2) + 1) for k, v in vocabs.items()}
            
            best_model = SurvivalTransformer(vocab_sizes, embedding_dims, len(final_numerical_cols), **model_params)
            # The checkpoint is a plain state_dict, so restricted unpickling is enough.
            best_model.load_state_dict(torch.load(model_path, map_location=DEVICE, weights_only=True))
            best_model.to(DEVICE).eval()
            print(f"  - 已从 '{model_path}' 加载最佳模型。")

            final_loss, final_c, final_iauc, final_td_auc_df = evaluate_model(best_model, val_loader, cox_loss, train_outcomes_df, DEVICE)
            final_fold_metrics.append({'fold': fold + 1, 'val_loss': final_loss, 'val_c_index': final_c, 'val_iauc': final_iauc})
            if final_td_auc_df is not None:
                all_folds_td_auc.append(final_td_auc_df)
                print(f"\n  [Fold {fold+1}] 最终验证集TD-AUC详情:\n{final_td_auc_df.to_string(index=False, float_format='%.4f')}")

            time_1 = time.time()
            # Use the configured DEVICE rather than a hard-coded 'cuda'.
            run_shap_analysis(model=best_model, train_df=train_df_norm, val_df=val_df_norm, numerical_cols=final_numerical_cols, 
                              categorical_cols=final_categorical_cols, fold_num=fold+1, device=DEVICE, background_size=200, test_size=200)
            time_2 = time.time()
            print(f"SHAP分析完成: {time_2-time_1} s")
            
            progression_results_df = run_progression_counterfactual(
                model=best_model, train_df_norm=train_df_norm, val_df=val_df, 
                final_first_progression_df=first_prog_df, progression_event_def=prog_def,
                counterfactual_change=cf_mod, scaler=feature_scaler, vocabs=vocabs,
                numerical_cols=final_numerical_cols, categorical_cols=final_categorical_cols,
                fold_num=fold+1, device=DEVICE, n=200
            )
            
            if progression_results_df is not None:
                all_progression_results_dfs.append(progression_results_df)
                heterogeneity_analysis_data.append({
                    'results_df': progression_results_df,
                    'val_df': val_df 
                })

        plot_average_risk_difference_histogram(all_progression_results_dfs, SET)
        summarize_heterogeneity_across_folds(
            analysis_data_list=heterogeneity_analysis_data,
            static_features_list=static_features,
            set_name=SET
        )
        
        print(f"\n{'='*80}\n======================  最终交叉验证总结 ({SET.upper()})  ======================\n{'='*80}")
        if final_fold_metrics:
            results_df = pd.DataFrame(final_fold_metrics)
            print("每折的最终验证集性能:"); print(results_df.to_string(index=False))
            print("\n平均性能指标 (± 标准差):")
            print(f"  - 验证集 Loss:   {results_df['val_loss'].mean():.4f} ± {results_df['val_loss'].std():.4f}")
            print(f"  - 验证集 C-Index: {results_df['val_c_index'].mean():.4f} ± {results_df['val_c_index'].std():.4f}")
            print(f"  - 验证集 iAUC:    {results_df['val_iauc'].mean():.4f} ± {results_df['val_iauc'].std():.4f}")

        if all_folds_td_auc:
            combined_auc_df = pd.concat(all_folds_td_auc, ignore_index=True)

            # Mean time-dependent AUC across folds with a 95% CI band.
            plt.figure(figsize=(10, 7))

            sns.lineplot(data=combined_auc_df, x='time', y='auc', errorbar=('ci', 95), label='平均AUC (95% CI)')

            total_mean_auc = combined_auc_df['auc'].mean()

            plt.axhline(0.5, color='grey', linestyle='--', label='Random Guess (AUC=0.5)')
            plt.title(f'Average Time-Dependent AUC for {SET.upper()} ({N_SPLITS}-Fold CV)\nOverall Mean AUC = {total_mean_auc:.3f}', fontsize=15)
            plt.xlabel('Time (Days)', fontsize=12)
            plt.ylabel('AUC', fontsize=12)
            plt.ylim(0.4, 1.0) 
            plt.grid(True, linestyle=':', alpha=0.6)
            plt.legend()
            plt.tight_layout()

            save_path_fig = f"./{SET}/average_td_auc_{SET}.png"
            plt.savefig(save_path_fig, dpi=300)
            plt.show()
            print(f"\n平均时间依赖性AUC曲线图已保存至: {save_path_fig}")

            print("\n按时间段划分的平均TD-AUC (± 标准差):")

            # Yearly buckets (in days) over the first five years.
            time_bins = [0, 365, 730, 1095, 1460, 1825]
            time_labels = ['0-1 Year', '1-2 Years', '2-3 Years', '3-4 Years', '4-5 Years']

            summary_df = combined_auc_df.copy()
            summary_df['time_bin'] = pd.cut(summary_df['time'], bins=time_bins, labels=time_labels, right=False)

            # observed=False keeps empty yearly buckets in the table and makes
            # the categorical-groupby behaviour explicit.
            auc_summary = summary_df.groupby('time_bin', observed=False)['auc'].agg(['mean', 'std']).reset_index()
            auc_summary = auc_summary.dropna(subset=['time_bin'])

            auc_summary['mean'] = auc_summary['mean'].map('{:.4f}'.format)
            auc_summary['std'] = auc_summary['std'].map('{:.4f}'.format)
            
            print(auc_summary.to_string(index=False))
    
    print(f"\n脚本总执行时间: {(time.time() - time_start) / 60:.2f} 分钟")