In [2]:
import pytorch_lightning as pl
import torch.cuda
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from wsi_dataset import WSIDataModule
import yaml
import importlib
from models import MILModel,compute_c_index
import os
import random
import numpy as np
import torchmetrics.functional as tf
from models import compute_c_index,calculate_auc  
import glob
import json
import pandas as pd
from sklearn.model_selection import StratifiedKFold, GroupKFold
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.strategies import DeepSpeedStrategy
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
import os
from lightning.pytorch.accelerators import find_usable_cuda_devices     
from wsi_dataset import *
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningDataModule
from torch.utils.data.distributed import DistributedSampler
from copy import deepcopy
from torch.utils.data import Sampler
from sklearn.metrics import accuracy_score, roc_auc_score
from pprint import pprint

In [3]:
def read_config(fname):
    with open(f"/home/huruizhen/mil/configs/{fname}.yaml", mode="r",encoding='utf-8') as file:
        yml = yaml.load(file, Loader=yaml.Loader)
        return yml


def get_obj_from_str(string, reload=False):     # string: "models.report_only.report_only"  
    module, cls = string.rsplit(".", 1)  
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)

class CoxSurvLoss(nn.Module):

    def __init__(self, reduction='mean'):
        super(CoxSurvLoss, self).__init__()

    def forward(self, hazards, time, c):
        '''
        # hazards: risk value (log risk) from the model output
        # time: event occurrence or observation time
        # c: Whether the event occurred (1 means the event occurred, 0 means truncated)
        '''

        # This calculation credit to Travers Ching https://github.com/traversc/cox-nnet
        # Cox-nnet: An artificial neural network method for prognosis prediction of high-throughput omics data
        
        hazards = hazards.squeeze()

        current_batch_len = len(time)

        R_mat = torch.zeros(
            [current_batch_len, current_batch_len], 
            dtype=int, 
            device=hazards.device
            )

        for i in range(current_batch_len):
            for j in range(current_batch_len):
                R_mat[i,j] = time[j] >= time[i]
                
        theta = hazards.reshape(-1)
        exp_theta = torch.exp(theta)
        loss_cox = -torch.mean((theta - torch.log(torch.sum(exp_theta*R_mat, dim=1))) * c)

        return loss_cox
    
def calculate_auc(logits, labels):
    # Convert logits to probabilities using softmax
    probabilities = torch.softmax(logits, dim=1)
    # Convert probabilities and labels to CPU tensors
    probabilities = probabilities.cpu().detach().numpy()
    labels = labels.cpu().numpy()
    
    num_classes = probabilities.shape[1]
    
    if num_classes == 2:
        # Binary classification case
        unique_labels = np.unique(labels)
        # print(f'unique_labels: {unique_labels}')
        if len(unique_labels) == 2:
            # Binary classification with two classes (0 and 1)
            binary_probabilities = probabilities[:, 1]
            auc = roc_auc_score(labels, binary_probabilities)  
        else:
            # Binary classification with more than two classes
            print(f'*************There is only one real tag in the binary classification task, and the AUC cannot be calculated.**************')
            auc = 0
        
    else:
        # Multiclass case
        aucs = []
        unique_labels = np.unique(labels)
        if len(unique_labels) == 1 :
            print(f'*************There is only one real tag in a multi-class task, and the AUC cannot be calculated. tag in the multiclass classification task, and the AUC cannot be calculated.**************')
            auc = 0
            return auc
        else:
            for i in unique_labels: 
                i = int(i)
                binary_labels = (labels == i).astype(int)
                class_probabilities = probabilities[:, i]
                auc = roc_auc_score(binary_labels, class_probabilities)
                aucs.append(auc)
            auc = sum(aucs) / len(unique_labels)  
    
    return auc


def compute_P_value(hazard_scores: torch.Tensor, labels:torch.Tensor, status:torch.Tensor) -> np.ndarray:  
    from lifelines.statistics import logrank_test  

    all_hazard_scores = hazard_scores.cpu().numpy()
    all_hazard_ratios = np.exp(all_hazard_scores)
    all_labels = labels.cpu().numpy()
    all_status = status.cpu().numpy()
    assert len(all_hazard_scores) == len(all_labels) == len(all_status)

    # Divide high and low risk groups by median all_hazard_ratios
    median_risk = np.median(all_hazard_scores)
    high_risk_mask = all_hazard_scores > median_risk
    low_risk_mask = ~high_risk_mask  # Reverse, indicating low risk

    # Obtain HR in high and low risk groups
    high_risk_HR = all_hazard_scores[high_risk_mask]
    low_risk_HR = all_hazard_scores[low_risk_mask]

    # Get indexes for high-risk groups
    high_risk_indices = np.where(high_risk_mask)[0]
    low_risk_indices = np.where(low_risk_mask)[0]

    # Get the corresponding data by index
    high_risk_labels = all_labels[high_risk_indices]
    high_risk_status = all_status[high_risk_indices]

    low_risk_labels = all_labels[low_risk_indices]
    low_risk_status = all_status[low_risk_indices]


    # 4.Calculate p_value
    results = logrank_test(
        durations_A=high_risk_labels,
        durations_B=low_risk_labels,
        event_observed_A=high_risk_status,
        event_observed_B=low_risk_status,
    )
    p_value = results.p_value
    return p_value


# seed everything
def fix_seed(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)   # Setting python hash seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # torch.set_deterministic(True)  # torch < 1.8
    torch.use_deterministic_algorithms(True, warn_only=True)  # torch >= 1.8

In [4]:
# Use bootstrap to calculate AUC and C-index 95% Confidence Interval
def bootstrap_auc(y_true: np.ndarray | torch.Tensor | list, 
                  y_prob: np.ndarray | torch.Tensor | list, 
                  n_bootstrap: int = 3000, 
                  alpha: float = 0.05) -> tuple[float, float, float, float]:
    """
    Compute the 95% CI, auc_mean and auc_std using the Bootstrap method.

    Parameters:
        y_true (array-like): True labels. Can be a numpy array, torch tensor, or list.
        y_prob (array-like): Predicted probabilities. [n_samples, n_classes] Can be a numpy array, torch tensor, or list.
        n_bootstrap (int, optional): Number of bootstrap resampling iterations. Default is 3000.
        alpha (float, optional): Significance level for confidence interval. Default is 0.05 (corresponding to 95% CI).

    Returns:
        tuple: A tuple containing:
            - auc_mean (float): Mean AUC for binary, Mean Macro-AUC for Multiclasses from bootstrap samples.
            - ci_lower (float): Lower bound of the 95% confidence interval.
            - ci_upper (float): Upper bound of the 95% confidence interval.
            - auc_std (float): Standard deviation of AUC values from bootstrap samples.
    """
    # Convert inputs to numpy arrays if they are not already
    def to_numpy(x):
        if isinstance(x, torch.Tensor):
            return x.cpu().numpy()  # Move tensor to CPU and convert to np.ndarray
        return np.array(x) if isinstance(x, list) else x

    y_true = to_numpy(y_true)
    if len(np.unique(y_true)) <= 1:
        raise ValueError("y_true contains only one unique class. At least two classes are required for AUC computation.")
    y_true_numbers = np.unique(y_true)
    y_prob = to_numpy(y_prob)
    n_classes = y_prob.shape[1]
    n_samples = len(y_true)

    auc_values = []
    if n_classes == 2:
        # Binary classification
        assert len(np.unique(y_true)) == 2, "y_true should contain only two unique classes for binary classification."
        for _ in range(n_bootstrap):
            indices = np.random.choice(n_samples, size=n_samples, replace=True)
            if len(np.unique(y_true[indices])) == 2:
                auc = roc_auc_score(y_true[indices], y_prob[indices][:, 1])
                auc_values.append(auc)

    elif n_classes > 2:  # Multiclass classification
        if len(np.unique(y_true)) < n_classes:
            # Handle the case where len(np.unique(y_true)) < n_classes
            # Calculate AUC for each existing class and average them
            existing_classes = np.unique(y_true)
            for _ in range(n_bootstrap):
                indices = np.random.choice(n_samples, size=n_samples, replace=True)
                class_aucs = []
                for cls in existing_classes:
                    cls = int(cls)
                    binary_y_true = (y_true[indices] == cls).astype(int)  # Convert to binary classification for the current class
                    if len(np.unique(binary_y_true)) == 2:  # Ensure both classes are present
                        auc = roc_auc_score(binary_y_true, y_prob[indices][:, cls])
                        class_aucs.append(auc)
                if len(class_aucs) > 0:
                    auc_values.append(np.mean(class_aucs))

        elif len(np.unique(y_true)) == n_classes:
            # Handle the case where len(np.unique(y_true)) == n_classes. Use 'Macro-AUC' to calculate AUC for each class and average them
            for _ in range(n_bootstrap):
                indices = np.random.choice(n_samples, size=n_samples, replace=True)
                if len(np.unique(y_true[indices])) == n_classes:
                    auc = roc_auc_score(y_true[indices], y_prob[indices], average='macro', multi_class='ovr')
                    auc_values.append(auc)
    
    if len(auc_values) == 0:
        raise ValueError("Bootstrap did not generate valid samples. This may occur if the data contains only one class or is highly imbalanced.")
    
    auc_values = np.array(auc_values)
    
    # Calculate 95% CI and mean, standard deviation
    ci_lower = np.percentile(auc_values, 100 * alpha / 2)  # Lower bound of CI
    ci_upper = np.percentile(auc_values, 100 * (1 - alpha / 2))  # Upper bound of CI
    auc_mean = np.mean(auc_values)  # Mean AUC
    auc_std = np.std(auc_values)  # Standard deviation of AUC values
    
    return auc_mean, ci_lower, ci_upper, auc_std, auc_values


# Calculate C-index, Concordance Index
from sksurv.metrics import concordance_index_censored

def bootstrap_cindex(risks: np.ndarray | torch.Tensor | list,
                     pfs: np.ndarray | torch.Tensor | list, 
                     status: np.ndarray | torch.Tensor | list, 
                     n_bootstrap: int = 3000,
                     alpha: float = 0.05) -> tuple[float, float, float, float]:
    """
    Compute the 95% CI, Cindex_mean and Cindex_std using the Bootstrap method.

    Parameters:
        risks (array-like): Predicted risks. Can be a numpy array, torch tensor, or list.
        pfs (array-like): True survival times. Can be a numpy array, torch tensor, or list.
        status (array-like): Event indicators (1 if event occurred, 0 otherwise). Can be a numpy array, torch tensor, or list.
        n_bootstrap (int, optional): Number of bootstrap resampling iterations. Default is 3000.

    Returns:
        tuple: A tuple containing:
            - cindex_mean (float): Mean C-index from bootstrap samples.
            - ci_lower (float): Lower bound of the 95% confidence interval.
            - ci_upper (float): Upper bound of the 95% confidence interval.
            - cindex_std (float): Standard deviation of C-index values from bootstrap samples.
    """
    # Convert inputs to numpy arrays if they are not already
    def to_numpy(x):
        if isinstance(x, torch.Tensor):
            return x.cpu().numpy()  # Move tensor to CPU and convert to np.ndarray
        return np.array(x) if isinstance(x, list) else x

    risks = to_numpy(risks)
    pfs = to_numpy(pfs)
    status = to_numpy(status)

    cindex_values = []
    
    # Bootstrap resampling
    for _ in range(n_bootstrap):
        indices = np.random.choice(len(risks), size=len(risks), replace=True)
        if len(np.unique(status[indices])) > 1:
            status = status.astype(bool)
            cindex = concordance_index_censored(status[indices], pfs[indices], np.squeeze(risks[indices]))[0]
            cindex_values.append(cindex)

    if len(cindex_values) == 0:
        raise ValueError("Bootstrap did not generate valid samples. This may occur if the data contains only one class or is highly imbalanced.")
    
    cindex_values = np.array(cindex_values)

    # Calculate 95% CI and mean, standard deviation
    ci_lower = np.percentile(cindex_values, 100 * alpha / 2)  # Lower bound of CI
    ci_upper = np.percentile(cindex_values, 100 * (1 - alpha / 2))  # Upper bound of CI
    cindex_mean = np.mean(cindex_values)  # Mean C-index
    cindex_std = np.std(cindex_values)  # Standard deviation of C-index values
    return cindex_mean, ci_lower, ci_upper, cindex_std, cindex_values


#### Internal Test set

In [None]:
# ! Inetrnal Test Set,NEVA multi-modal
all_top_entries = []
MultiModal_list=['config_cls_hazard_level_MultiModal','config_cls_subtype_MultiModal',
                 'config_cls_shimada_MultiModal','config_cls_mki_MultiModal',
                 'config_cls_alk_MultiModal','config_cls_cmyc_MultiModal',
                 'config_cls_nmyc_MultiModal','config_cls_p36_MultiModal','config_cls_q23_MultiModal',
                 'config_reg_pfs_MultiModal','config_reg_os_MultiModal',]

for fname in MultiModal_list:
    config_yaml = read_config(fname)
    original_csv = config_yaml['Data']['dataframe']  
    proj_name = config_yaml['Data']['label_name']   
    if proj_name in ['pfs','os']:
        mode = 'mean'
        print(f'\n{proj_name},mode:{mode}')
    else:
        mode = 'best'
        if proj_name == 'hazard_level':
            print(f'\nRisk Group,mode:{mode}')
        else:
            print(f'\n{proj_name},mode:{mode}')


    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    num_gpus = 1
    dist = False 
    seed_chosen,fold_chosen = config_yaml['seed&fold']

    seed_list = [42,184,762,381,493,526,307, 648, 255, 739]
    for seed in seed_list:
        if seed != seed_chosen:
            continue

        wts_dir = f"/home/huruizhen/mil/workspace/2048维度/多模态/动态权重/模型/{fname}/{proj_name}/seed{seed}"
        csv_path = f"{wts_dir}/five_fold.csv"   
        all_data_csv_path = f"/home/huruizhen/mil/workspace/splits/{proj_name}.csv"
        print(f"csv_path: {csv_path}")

        assert os.path.exists(csv_path), f"CSV file not found: {csv_path}"
        assert os.path.exists(all_data_csv_path), f"CSV file not found: {all_data_csv_path}"
        df_all_data = pd.read_csv(all_data_csv_path)
        df = pd.read_csv(csv_path)

        metric = []
        for fold in range(5):
            if fold != fold_chosen:
                continue

            test_df = df[df['fold'] == fold].copy()
            test_df_case_ids = test_df['case_id'].tolist()
            train_df = df[df['fold'] != fold].copy()
            train_df_case_ids = train_df['case_id'].tolist()
            os.makedirs(f"/home/huruizhen/mil/workspace/独立测试集/多中心测试/内部测试集/{proj_name}/", exist_ok=True)

            if proj_name not in ['pfs','os']:
                if test_df[str(proj_name)].nunique() == 1:
                    print(f'{proj_name}There is only one category, cannot calculate auc/cindex')
                    continue
             
            patient_files = {}
            for index, row_test in test_df.iterrows():
                case_id = row_test['case_id']  
                if case_id not in patient_files:
                    patient_files[case_id] = []
                patient_files[case_id].append(row_test['filename'])

            training_all_df = pd.DataFrame()  
            for index, row_train in train_df.iterrows():
                case_id = row_train['case_id']  
                for index, row_all in df_all_data.iterrows():
                    if row_all['case_id'] == case_id:
                        training_all_df = pd.concat([training_all_df, row_all.to_frame().T], ignore_index=True)

            SAVE = False
            if SAVE:
                training_all_df.to_csv(f"/home/huruizhen/mil/workspace/独立测试集/多中心测试/内部测试集/{proj_name}/train_seed{seed}_fold{fold}.csv", index=False)
                test_df.to_csv(f"/home/huruizhen/mil/workspace/独立测试集/多中心测试/内部测试集/{proj_name}/test_seed{seed}_fold{fold}.csv", index=False)


            feature_dict = {}
            for file_name in test_df['filename']:
                file_path = '/home/huruizhen/多尺度特征/images_multisacle_2048/' + file_name
                feature_dict[file_name[:-len('.pt')]] = file_path

            patient_files = {str(k): [feature_dict[i[:-len('.pt')]] for i in v] for k, v in patient_files.items()}
            patient_feature_vision = {case_id: [torch.load(file_name,map_location=device).unsqueeze(0).float() for file_name in file_list] for case_id, file_list in patient_files.items()}

            test_df.drop_duplicates(subset=['case_id'], inplace=True)  
            test_df.reset_index(drop=True, inplace=True) 
            test_df['case_id'] = test_df['case_id'].astype(str)  
            if config_yaml['Data']['label_name'] in ['os','pfs']:
                all_labels = torch.tensor(test_df['time'].values)
                all_status = torch.tensor(test_df['status'].values)
                all_case_id_label = {case_id : (torch.tensor(test_df[test_df['case_id']==case_id]['time'].tolist()),torch.tensor(test_df[test_df['case_id']==case_id]['status'].tolist())) for case_id,feature_list in patient_feature_vision.items()}
            else:
                all_labels = torch.tensor(test_df[config_yaml['Data']['label_name']].values)  
                all_case_id_label = {case_id : torch.tensor(test_df[test_df['case_id']==case_id][config_yaml['Data']['label_name']].tolist()) for case_id,feature_list in patient_feature_vision.items()}
            
            patient_feature_report = {str(case_id):torch.load('/home/huruizhen/mil_dataset_1024/reports_1024/'+str(case_id)+'.pt',map_location=device).unsqueeze(dim=0).float() for case_id in test_df['case_id']}
            assert len(patient_feature_vision.keys())==len(patient_feature_report.keys())

            with torch.inference_mode():
                wts_path = f"/home/huruizhen/mil/workspace/2048维度/多模态/动态权重/模型/{fname}/{proj_name}/seed{seed}/fold_{fold}/fold_{fold}.pth"
                # print(f'seed：{seed} fold：{fold}\n加载模型的路径是：{wts_path}')
                wts = torch.load(wts_path)
                save_path = None
                model = MILModel(config_yaml, save_path=str(save_path)).to(device)
                model.load_state_dict(wts,strict=True)
                criterion = get_obj_from_str(config_yaml["Loss"]["name"])(**config_yaml["Loss"]["params"])
                case_id_logits={}

                for case_id,feature_list_vision in patient_feature_vision.items():  

                    if case_id not in patient_feature_report:
                        print(f"{case_id} Not in the patient's pathological report")
                        print(patient_feature_report)
                    feature_report = patient_feature_report[case_id].to(device)
                    # print(feature_report.shape)
                    logits = []
                    for feature_vision in feature_list_vision:
                        feature_vision = feature_vision.to(device)
                        # print(case_id,feature_vision.shape)
                        with torch.inference_mode():  
                            logit, results_dict = model((feature_vision,feature_report))

                        logits.append(logit)

                    # print(logits)

                    if mode == 'mean':
                        case_id_logits[case_id]=torch.cat(logits,dim=0).mean(dim=0,keepdim=True)  # 同一个人的不同file获得的logit取平均
                    # print(case_id_logits[case_id].shape)
                    if mode == 'best':
                        case_id_logits[case_id] = logits[0]
                        best_logit = logits[0]
                        if config_yaml['Data']['label_name'] in ['os','pfs']:
                            label,status = all_case_id_label[case_id]
                            label = label.long().to(device)
                            status = status.long().to(device)
                            criterion = CoxSurvLoss()
                            best_crossentropy = criterion(logits[0].float(),label,status)
                            for logit in logits[1:]:
                                crossentropy = criterion(logit,label,status)
                                if crossentropy <= best_crossentropy:
                                    case_id_logits[case_id] = logit
                        else:
                            label = all_case_id_label[case_id].long().to(device)

                            criterion = nn.CrossEntropyLoss()
                            best_crossentropy = criterion(logits[0].float(),label)
                            # print('交叉熵损失是：',best_crossentropy)
                            for logit in logits[1:]:
                                crossentropy = criterion(logit,label)
                                # print(crossentropy)
                                if crossentropy <= best_crossentropy:
                                    case_id_logits[case_id] = logit
                    # print(f"case_id:{case_id},logits:{case_id_logits[case_id]},lebels:{all_case_id_label[case_id]}")
                del patient_feature_vision, patient_feature_report
                all_logits = torch.cat([logits for case_id,logits in case_id_logits.items()],dim = 0)

                if config_yaml['Data']['label_name'] in ['os','pfs']:

                    risks_list = all_logits
                    labels_list = all_labels
                    status_list = all_status
                    c_index =float(round(compute_c_index(risks_list, labels_list, status_list), 4))
                    p_value = float(round(compute_P_value(risks_list, labels_list, status_list), 4))
                    print(f'c_index:{c_index},p_value:{p_value}')
                    metric.append(c_index)
                    
                    bootstrap_cindex_mean, ci_lower, ci_upper, cindex_std, cindex_values = bootstrap_cindex(risks_list, labels_list, status_list)
                    print(f'bootstrap_cindex_mean & CI:\n{bootstrap_cindex_mean:.4f} ({ci_lower:.4f}--{ci_upper:.4f}), {p_value}\n')

                else:

                    auc = float(round(calculate_auc(all_logits, all_labels), 4))
                    # print(f'auc:{auc}')
                    metric.append(auc)

                    all_probs = torch.softmax(all_logits, dim=1)
                    print(f"all_labels: {all_labels.shape}, all_probs: {all_probs.shape}")
                    bootstrap_auc_mean, ci_lower, ci_upper, auc_std, auc_values = bootstrap_auc(all_labels, all_probs)
                    print(f'bootstrap_auc_mean & CI:\n{bootstrap_auc_mean:.4f} ({ci_lower:.4f}--{ci_upper:.4f})\n')
                # del patient_feature_vision
                del model
                del all_logits
                del all_labels
                torch.cuda.empty_cache()





Risk Group,mode:best
csv_path: /home/huruizhen/mil/workspace/2048维度/多模态/动态权重/模型/config_cls_hazard_level_MultiModal/hazard_level/seed493/five_fold.csv
all_labels: torch.Size([93]), all_probs: torch.Size([93, 3])
bootstrap_auc_mean & CI:
0.7858 (0.7017--0.8647)


subtype,mode:best
csv_path: /home/huruizhen/mil/workspace/2048维度/多模态/动态权重/模型/config_cls_subtype_MultiModal/subtype/seed648/five_fold.csv
all_labels: torch.Size([170]), all_probs: torch.Size([170, 3])
bootstrap_auc_mean & CI:
0.9698 (0.9499--0.9865)


shimada,mode:best
csv_path: /home/huruizhen/mil/workspace/2048维度/多模态/动态权重/模型/config_cls_shimada_MultiModal/shimada/seed739/five_fold.csv
all_labels: torch.Size([181]), all_probs: torch.Size([181, 2])
bootstrap_auc_mean & CI:
0.8357 (0.7720--0.8939)


mki,mode:best
csv_path: /home/huruizhen/mil/workspace/2048维度/多模态/动态权重/模型/config_cls_mki_MultiModal/mki/seed526/five_fold.csv
all_labels: torch.Size([127]), all_probs: torch.Size([127, 3])
bootstrap_auc_mean & CI:
0.8235 (0.7634--0.8775

### Prospective test set, external PUFH, SCH, and GCI test sets

In [None]:
# Prospective test set, external PUFH, SCH, and GCI test sets

all_top_entries = []

csv_dir_list = ['/home/huruizhen/mil/workspace/独立测试集/多中心测试/前瞻实验','/home/huruizhen/mil/workspace/独立测试集/多中心测试/北大附一合并起来',
                '/home/huruizhen/mil/workspace/独立测试集/多中心测试/深圳',
                '/home/huruizhen/mil/workspace/独立测试集/多中心测试/贵阳&cbtn&内蒙古',]

for csv_dir in csv_dir_list:  
    hospital_center = os.path.basename(csv_dir)
    if hospital_center == '前瞻实验':
        hospital_center = 'Prospective'
    elif hospital_center == '北大附一合并起来':
        hospital_center = 'PUFH'
    elif hospital_center == '深圳':
        hospital_center = 'SCH'
    elif hospital_center == '贵阳&cbtn&内蒙古':
        hospital_center = 'GCI'
    print(f'\n\n***************Cohorts：{hospital_center}***************')
    proj_name_list = [proj_name.replace('.csv','') for proj_name in os.listdir(csv_dir) if proj_name.endswith('.csv')]
    MultiModal_list = []
    for proj_name in proj_name_list:
        if proj_name in ['pfs','os']:
            MultiModal_list.append('config_reg_'+proj_name+'_MultiModal')
        else:
            MultiModal_list.append('config_cls_'+proj_name+'_MultiModal')
    print(MultiModal_list)

    for fname in MultiModal_list:

        config_yaml = read_config(fname) 
        proj_name = config_yaml['Data']['label_name']

        chosen_seed,chosen_fold = config_yaml['seed&fold']


        if proj_name in ['pfs','os']:
            mode = 'mean'
            print(f'{proj_name}')
        else:
            mode = 'best'
            if proj_name == 'hazard_level':
                print(f'Risk Group')
            elif proj_name == 'p36':
                print(f'1p36')
            elif proj_name == 'q23':
                print(f'11q23')
            else:
                print(f'{proj_name}')



        output_dir = os.path.join(csv_dir,proj_name)
        # os.makedirs(output_dir,exist_ok=True)

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # device = torch.device("cpu")
        num_gpus = 1
        dist = False 
        original_csv = os.path.join(csv_dir,config_yaml['Data']['label_name']+'.csv')   
        print(original_csv)


        test_df = pd.read_csv(original_csv) 
        test_df['case_id'] = test_df['case_id'].astype(str) 
        

        if test_df[str(proj_name)].nunique() == 1:
            print(f"{proj_name} There is only one category, cannot calculate auc/cindex")
            continue

        patient_files = {}
        for index, row in test_df.iterrows():
            case_id = row['case_id'] 
            if case_id not in patient_files:
                patient_files[case_id] = []
            patient_files[case_id].append(row['filename']) 

        feature_dict = {}
        for file_name in test_df['filename']:
            file_path = '/home/huruizhen/多尺度特征/images_multisacle_2048/'+file_name
            feature_dict[file_name[:-len('.pt')]]=file_path

        # print(feature_dict)

        patient_files = {str(k): [feature_dict[i[:-len('.pt')]] for i in v] for k, v in patient_files.items()} 


        patient_feature_vision = {case_id: [torch.load(file_name,map_location=device).unsqueeze(0).float() for file_name in file_list] for case_id, file_list in patient_files.items()}   
        # print(type(patient_feature_vision.keys()))

        test_df.drop_duplicates(subset=['case_id'], inplace=True)  
        test_df.reset_index(drop=True, inplace=True)  
        test_df['case_id'] = test_df['case_id'].astype(str)  

        
        if config_yaml['Data']['label_name'] in ['os','pfs']:
            all_labels = torch.tensor(test_df[config_yaml['Data']['label_name']].values)
            all_status = torch.tensor(test_df['status'].values)
            all_case_id_label = {case_id : (torch.tensor(test_df[test_df['case_id']==case_id][config_yaml['Data']['label_name']].tolist()),torch.tensor(test_df[test_df['case_id']==case_id]['status'].tolist())) for case_id,feature_list in patient_feature_vision.items()}
        else:
            all_labels = torch.tensor(test_df[config_yaml['Data']['label_name']].values) 
            all_case_id_label = {case_id : torch.tensor(test_df[test_df['case_id']==case_id][config_yaml['Data']['label_name']].tolist()) for case_id,feature_list in patient_feature_vision.items()}

        patient_feature_report = {str(case_id):torch.load('/home/huruizhen/mil_dataset_1024/reports_1024/'+str(case_id)+'.pt',map_location=device).unsqueeze(dim=0).float() for case_id in test_df['case_id']}

        assert len(patient_feature_vision.keys())==len(patient_feature_report.keys())
        # if len(patient_feature_report.keys())<=5:
        #     print(f'The number of case_ids in the independent test set of this class is too small:{len(patient_feature_report.keys())}')
        #     continue


        seed_list = [42,184,762,381,493,526,307, 648, 255, 739]
        # 每个病人的所有feature的logits取平均
        for seed in seed_list:
            if seed != chosen_seed:
                continue

            wts_dir = f"/home/huruizhen/mil/workspace/2048维度/多模态/动态权重/模型/{fname}/{proj_name}/seed{seed}/"

            save_path = f"/home/huruizhen/mil/workspace/独立测试集/测试集性能/"
            with torch.inference_mode():
                model = MILModel(config_yaml, save_path=str(save_path)).to(device)  # 定义模型  使用了 pytorch_lighting

                criterion = get_obj_from_str(config_yaml["Loss"]["name"])(**config_yaml["Loss"]["params"])

                case_id_logits={}
                metric = []
                for i in range(5):
                    if i != chosen_fold:
                        continue

                    wts_path = os.path.join(wts_dir, f"fold_{i}/fold_{i}.pth")
                    wts = torch.load(wts_path)


                    model.load_state_dict(wts,strict=True)

                    for case_id,feature_list_vision in patient_feature_vision.items():
                        case_id = str(case_id)
                        if case_id not in patient_feature_report:
                            print(f"{case_id}Not in the patient's pathological report")
                            print(patient_feature_report)
                        feature_report = patient_feature_report[case_id].to(device)
                        # print(feature_report.shape)
                        logits = []
                        for feature_vision in feature_list_vision:
                            feature_vision = feature_vision.to(device)

                            with torch.inference_mode():  
                                logit, results_dict = model((feature_vision,feature_report))

                            logits.append(logit)


                        if mode == 'mean':
                            case_id_logits[case_id]=torch.cat(logits,dim=0).mean(dim=0,keepdim=True)  
                        if mode == 'best':
                            case_id_logits[case_id] = logits[0]
                            best_logit = logits[0]
                            if config_yaml['Data']['label_name'] in ['os','pfs']:
                                label,status = all_case_id_label[case_id]
                                label = label.long().to(device)
                                status = status.long().to(device)
                                criterion = CoxSurvLoss()
                                best_crossentropy = criterion(logits[0].float(),label,status)
                                for logit in logits[1:]:
                                    crossentropy = criterion(logit,label,status)
                                    if crossentropy <= best_crossentropy:
                                        case_id_logits[case_id] = logit
                            else:
                                label = all_case_id_label[case_id].long().to(device)

                                criterion = nn.CrossEntropyLoss()
                                best_crossentropy = criterion(logits[0].float(),label)
                                for logit in logits[1:]:
                                    crossentropy = criterion(logit,label)
                                    if crossentropy <= best_crossentropy:
                                        case_id_logits[case_id] = logit
                        
                    all_logits = torch.cat([logits for case_id,logits in case_id_logits.items()],dim = 0)
                    
                    if config_yaml['Data']['label_name'] in ['os','pfs']:
                        risks_list = all_logits
                        labels_list = all_labels
                        status_list = all_status
                        c_index =float(round(compute_c_index(risks_list, labels_list, status_list),4))
                        p_value = float(round(compute_P_value(risks_list, labels_list, status_list),4))
                        print(f'c_index:{c_index},p_value:{p_value}')
                        metric.append(c_index)
                        bootstrap_cindex_mean, ci_lower, ci_upper, cindex_std, cindex_values = bootstrap_cindex(risks_list, labels_list, status_list)
                        print(f'bootstrap_cindex_mean & CI:\n{bootstrap_cindex_mean:.4f} ({ci_lower:.4f}--{ci_upper:.4f}), {p_value}\n')

                    else:
                        auc = float(round(calculate_auc(all_logits, all_labels), 4))
                        
                        # print(f'auc:{auc}')
                        metric.append(auc)
                        all_probs = torch.softmax(all_logits, dim=1)
                        print(f"all_labels: {all_labels.shape}, all_probs: {all_probs.shape}")
                        bootstrap_auc_mean, ci_lower, ci_upper, auc_std, auc_values = bootstrap_auc(all_labels, all_probs)
                        print(f'bootstrap_auc_mean & CI:\n{bootstrap_auc_mean:.4f} ({ci_lower:.4f}--{ci_upper:.4f})\n')





***************Cohorts：Prospective***************
['config_cls_cmyc_MultiModal', 'config_cls_subtype_MultiModal', 'config_cls_hazard_level_MultiModal', 'config_cls_p36_MultiModal', 'config_cls_shimada_MultiModal', 'config_cls_nmyc_MultiModal', 'config_cls_mki_MultiModal', 'config_cls_alk_MultiModal', 'config_cls_q23_MultiModal']
cmyc
/home/huruizhen/mil/workspace/独立测试集/多中心测试/前瞻实验/cmyc.csv
all_labels: torch.Size([59]), all_probs: torch.Size([59, 2])
bootstrap_auc_mean & CI:
0.8756 (0.7673--0.9578)

subtype
/home/huruizhen/mil/workspace/独立测试集/多中心测试/前瞻实验/subtype.csv
all_labels: torch.Size([67]), all_probs: torch.Size([67, 3])
bootstrap_auc_mean & CI:
0.9484 (0.9048--0.9835)

Risk Group
/home/huruizhen/mil/workspace/独立测试集/多中心测试/前瞻实验/hazard_level.csv
all_labels: torch.Size([23]), all_probs: torch.Size([23, 3])
bootstrap_auc_mean & CI:
0.8814 (0.7263--0.9898)

1p36
/home/huruizhen/mil/workspace/独立测试集/多中心测试/前瞻实验/p36.csv
all_labels: torch.Size([33]), all_probs: torch.Size([33, 2])
bootstrap_