# 多模态医学影像Embedding融合

基于MedFuse的LSTM-based fusion方法，融合SAM-Med2D、MedCLIP和RadFM的image embedding。

## 功能
- 融合三种模型的embedding（SAM-Med2D, MedCLIP, RadFM）
- CT和PET分别进行fusion
- 分别计算进展和死亡的ROC AUC


In [90]:
# 导入必要的库
import argparse
import json
import logging
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

print("库导入完成！")


库导入完成！


## 1. 定义MedFuse LSTM融合模型


In [91]:
class MedFuseLSTM(nn.Module):
    """
    MedFuse LSTM-based fusion module
    可以处理单模态或多模态输入
    """
    def __init__(
        self,
        input_dims: List[int],
        hidden_dim: int = 128,
        num_layers: int = 2,
        dropout: float = 0.3,
        output_dim: int = 1,
    ):
        super(MedFuseLSTM, self).__init__()
        self.num_modalities = len(input_dims)
        self.hidden_dim = hidden_dim
        
        # 为每个模态创建投影层
        self.projection_layers = nn.ModuleList([
            nn.Linear(dim, hidden_dim) for dim in input_dims
        ])
        
        # LSTM层
        self.lstm = nn.LSTM(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=False,
        )
        
        # 输出层
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, output_dim),
        )
        
    def forward(self, modality_embeddings: List[torch.Tensor]) -> torch.Tensor:
        """
        Args:
            modality_embeddings: List of tensors, each of shape [batch_size, embedding_dim]
                                可以是1个或多个模态
        Returns:
            output: [batch_size, output_dim]
        """
        batch_size = modality_embeddings[0].shape[0]
        
        # 投影每个模态的embedding
        projected = []
        for i, emb in enumerate(modality_embeddings):
            if emb is not None:
                proj = self.projection_layers[i](emb)  # [batch_size, hidden_dim]
                projected.append(proj)
        
        if not projected:
            # 如果没有可用模态，返回零向量
            return torch.zeros(batch_size, 1, device=modality_embeddings[0].device)
        
        # 堆叠为序列 [batch_size, num_modalities, hidden_dim]
        sequence = torch.stack(projected, dim=1)
        
        # LSTM处理
        lstm_out, (h_n, c_n) = self.lstm(sequence)
        
        # 使用最后一个时间步的输出
        last_output = lstm_out[:, -1, :]  # [batch_size, hidden_dim]
        
        # 全连接层
        output = self.fc(last_output)
        
        return output

print("MedFuseLSTM模型定义完成！")


MedFuseLSTM模型定义完成！


## 2. 定义数据集类


In [92]:
class MultiModalDataset(Dataset):
    """多模态数据集"""
    def __init__(
        self,
        patient_data: Dict[str, Dict],
        labels: Dict[str, int],
        modality_keys: List[str],
        embedding_dims: Dict[str, int],
    ):
        self.patient_data = patient_data
        self.labels = labels
        self.modality_keys = modality_keys
        self.embedding_dims = embedding_dims
        self.patient_ids = list(patient_data.keys())
        
    def __len__(self):
        return len(self.patient_ids)
    
    def __getitem__(self, idx):
        patient_id = self.patient_ids[idx]
        data = self.patient_data[patient_id]
        label = self.labels.get(patient_id, 0)
        
        # 提取各模态的embedding
        embeddings = []
        for key in self.modality_keys:
            if key in data and data[key] is not None:
                emb_array = data[key]
                # 确保是numpy数组
                if not isinstance(emb_array, np.ndarray):
                    emb_array = np.array(emb_array)
                
                # 处理不同维度的数组
                if emb_array.ndim == 0:
                    # 标量，转换为1维数组
                    emb_array = np.array([emb_array])
                elif emb_array.ndim > 1:
                    # 多维数组，展平
                    emb_array = emb_array.flatten()
                
                # 确保维度正确
                expected_dim = self.embedding_dims.get(key, emb_array.shape[0] if emb_array.ndim > 0 else 1)
                if emb_array.shape[0] != expected_dim:
                    # 如果维度不匹配，调整
                    if emb_array.shape[0] > expected_dim:
                        emb_array = emb_array[:expected_dim]
                    else:
                        # 用零填充
                        padding = np.zeros(expected_dim - emb_array.shape[0], dtype=emb_array.dtype)
                        emb_array = np.concatenate([emb_array, padding])
                
                emb = torch.tensor(emb_array, dtype=torch.float32)
            else:
                # 如果缺失，使用零向量
                dim = self.embedding_dims.get(key, 1)
                emb = torch.zeros(dim, dtype=torch.float32)
            embeddings.append(emb)
        
        return {
            'embeddings': embeddings,
            'label': torch.tensor(label, dtype=torch.float32),
            'patient_id': patient_id,
        }

print("MultiModalDataset定义完成！")


MultiModalDataset定义完成！


## 3. 加载Embedding数据


In [93]:
def normalize_patient_name(name: str) -> str:
    """标准化患者名称，用于匹配"""
    if pd.isna(name):
        return ""
    name = str(name).strip()
    # 去除多余空格
    name = " ".join(name.split())
    # 统一转换为小写（可选，根据实际情况调整）
    # name = name.lower()
    return name

def load_embeddings(
    embedding_paths: Dict[str, str],
    patient_key_col: str = 'patient_name',
) -> Tuple[Dict[str, Dict[str, np.ndarray]], Dict[str, str]]:
    """
    加载所有embedding文件
    
    Args:
        embedding_paths: {embedding_name: file_path}
        patient_key_col: 患者标识列名
    
    Returns:
        (all_embeddings, original_to_normalized): 
        - all_embeddings: {normalized_patient_id: {embedding_name: embedding_array}}
        - original_to_normalized: {original_name: normalized_name} 映射
    """
    all_embeddings = {}
    original_to_normalized = {}  # 原始名称到标准化名称的映射
    
    for emb_name, emb_path in embedding_paths.items():
        logging.info(f"加载 {emb_name} from {emb_path}")
        
        if emb_path.endswith('.npz'):
            data = np.load(emb_path, allow_pickle=True)
            
            # 尝试不同的键名
            if 'embeddings' in data:
                embeddings = data['embeddings']
            elif 'embedding' in data:
                embeddings = data['embedding']
            else:
                # 取第一个数组
                keys = list(data.keys())
                embeddings = data[keys[0]]
            
            # 获取患者名称
            if 'patient_names' in data:
                patient_names = data['patient_names']
            elif 'patient_name' in data:
                patient_names = data['patient_name']
            elif patient_key_col in data:
                patient_names = data[patient_key_col]
            else:
                # 尝试从CSV读取
                csv_path = emb_path.replace('.npz', '.csv')
                if os.path.exists(csv_path):
                    df = pd.read_csv(csv_path, nrows=1000)  # 只读前1000行来检查
                    if patient_key_col in df.columns:
                        # 重新读取完整文件
                        df = pd.read_csv(csv_path)
                        embeddings = df.drop(columns=[patient_key_col, 'center'] if 'center' in df.columns else [patient_key_col]).values
                        patient_names = df[patient_key_col].values
                    else:
                        logging.warning(f"无法找到患者名称列，跳过 {emb_name}")
                        continue
                else:
                    logging.warning(f"无法找到患者名称，跳过 {emb_name}")
                    continue
            
            # 处理patient_names并标准化
            if isinstance(patient_names, np.ndarray):
                if patient_names.dtype == object:
                    patient_names = [str(p) for p in patient_names]
                else:
                    patient_names = patient_names.astype(str)
            else:
                patient_names = [str(p) for p in patient_names]
            
            # 存储每个患者的embedding（使用标准化名称）
            # 对于RadFM，同一个患者可能有多个embedding（多个token），需要聚合
            patient_embeddings_dict = {}  # {patient_id: [emb1, emb2, ...]}
            
            for i, original_name in enumerate(patient_names):
                normalized_name = normalize_patient_name(original_name)
                if not normalized_name:
                    continue
                
                # 保存映射
                original_to_normalized[original_name] = normalized_name
                
                # 如果是2D数组，取第i行
                if embeddings.ndim == 2:
                    emb = embeddings[i]
                else:
                    emb = embeddings[i] if i < len(embeddings) else None
                
                if emb is not None:
                    if normalized_name not in patient_embeddings_dict:
                        patient_embeddings_dict[normalized_name] = []
                    patient_embeddings_dict[normalized_name].append(emb)
            
            # 聚合同一患者的多个embedding（对于RadFM等）
            for normalized_name, emb_list in patient_embeddings_dict.items():
                if normalized_name not in all_embeddings:
                    all_embeddings[normalized_name] = {}
                
                if len(emb_list) == 1:
                    # 只有一个embedding，直接使用
                    all_embeddings[normalized_name][emb_name] = emb_list[0]
                else:
                    # 多个embedding，进行平均池化
                    emb_array = np.array(emb_list)
                    # 如果是3D数组 [num_tokens, embedding_dim]，平均池化
                    if emb_array.ndim == 2:
                        # 对token维度求平均
                        aggregated_emb = np.mean(emb_array, axis=0)
                    else:
                        # 展平后求平均
                        aggregated_emb = np.mean(emb_array.flatten())
                    all_embeddings[normalized_name][emb_name] = aggregated_emb
        
        elif emb_path.endswith('.xlsx') or emb_path.endswith('.xls'):
            # 处理Excel文件（SAM-Med2D）
            try:
                df = pd.read_excel(emb_path, engine='openpyxl')
                all_columns = df.columns.tolist()
                logging.info(f"  Excel文件列名: {all_columns[:10]}... (共{len(all_columns)}列)")
                
                # 尝试找到患者名称列
                patient_col = None
                for col in all_columns:
                    col_lower = str(col).lower()
                    if 'patient' in col_lower or 'name' in col_lower or '患者' in str(col) or '姓名' in str(col):
                        patient_col = col
                        break
                
                # 如果找不到，使用第一列
                if patient_col is None:
                    patient_col = all_columns[0]
                    logging.info(f"  未找到患者名称列，使用第一列: {patient_col}")
                else:
                    logging.info(f"  使用患者名称列: {patient_col}")
                
                # 提取embedding列（排除患者名称列和center列）
                exclude_cols = [patient_col, 'center', '中心', 'file_key', 'Center']
                emb_cols = [col for col in df.columns if col not in exclude_cols]
                
                # 进一步过滤：只保留数值列
                numeric_cols = []
                for col in emb_cols:
                    col_str = str(col).lower()
                    if col_str.startswith('feature_') or col_str.startswith('embedding_') or col_str.replace('_', '').replace('.', '').isdigit():
                        numeric_cols.append(col)
                
                if not numeric_cols:
                    # 如果没找到特征列，尝试所有非排除列
                    numeric_cols = emb_cols
                    logging.info(f"  使用所有非排除列作为embedding列（共{len(numeric_cols)}列）")
                
                if not numeric_cols:
                    logging.warning(f"  未找到embedding列")
                    continue
                
                # 转换数值列
                emb_data = df[numeric_cols].copy()
                for col in emb_data.columns:
                    emb_data[col] = pd.to_numeric(emb_data[col], errors='coerce')
                emb_data = emb_data.fillna(0)
                
                embeddings_array = emb_data.values.astype(np.float32)
                patient_names = df[patient_col].astype(str).values
                
                # 对于RadFM，同一个患者可能有多个embedding，需要聚合
                patient_embeddings_dict = {}  # {patient_id: [emb1, emb2, ...]}
                
                for i, original_name in enumerate(patient_names):
                    normalized_name = normalize_patient_name(original_name)
                    if not normalized_name:
                        continue
                    
                    # 保存映射
                    original_to_normalized[original_name] = normalized_name
                    
                    emb = embeddings_array[i]
                    if normalized_name not in patient_embeddings_dict:
                        patient_embeddings_dict[normalized_name] = []
                    patient_embeddings_dict[normalized_name].append(emb)
                
                # 聚合同一患者的多个embedding
                for normalized_name, emb_list in patient_embeddings_dict.items():
                    if normalized_name not in all_embeddings:
                        all_embeddings[normalized_name] = {}
                    
                    if len(emb_list) == 1:
                        # 只有一个embedding，直接使用
                        emb = emb_list[0]
                        if not isinstance(emb, np.ndarray):
                            emb = np.array(emb, dtype=np.float32)
                        all_embeddings[normalized_name][emb_name] = emb
                    else:
                        # 多个embedding，进行平均池化
                        try:
                            emb_arrays = []
                            for emb in emb_list:
                                if not isinstance(emb, np.ndarray):
                                    emb = np.array(emb, dtype=np.float32)
                                emb_arrays.append(emb)
                            
                            emb_array = np.array(emb_arrays)
                            if emb_array.dtype == object:
                                emb_array = np.array([np.array(e, dtype=np.float32) for e in emb_arrays])
                            
                            if emb_array.ndim == 2:
                                aggregated_emb = np.mean(emb_array, axis=0).astype(np.float32)
                            elif emb_array.ndim > 2:
                                aggregated_emb = np.mean(emb_array.flatten()).astype(np.float32)
                            else:
                                aggregated_emb = np.mean(emb_array).astype(np.float32)
                            
                            all_embeddings[normalized_name][emb_name] = aggregated_emb
                        except Exception as e:
                            logging.warning(f"  聚合患者 {normalized_name} 的embedding失败: {e}")
                            if emb_list:
                                emb = emb_list[0]
                                if not isinstance(emb, np.ndarray):
                                    emb = np.array(emb, dtype=np.float32)
                                all_embeddings[normalized_name][emb_name] = emb
                
                logging.info(f"  {emb_name}: 加载了 {len(set(original_to_normalized.values()))} 个唯一患者")
                
            except Exception as e:
                logging.error(f"读取Excel文件 {emb_path} 失败: {e}")
                continue
        
        elif emb_path.endswith('.csv'):
            # 先读取第一行来检查列名
            try:
                sample_df = pd.read_csv(emb_path, nrows=5)
                all_columns = sample_df.columns.tolist()
                logging.info(f"  CSV文件列名: {all_columns[:10]}... (共{len(all_columns)}列)")
                
                # 尝试找到患者名称列
                patient_col = None
                for col in all_columns:
                    col_lower = str(col).lower()
                    if 'patient' in col_lower or 'name' in col_lower or '患者' in str(col) or '姓名' in str(col):
                        patient_col = col
                        break
                
                # 如果找不到，使用第一列
                if patient_col is None:
                    patient_col = all_columns[0]
                    logging.info(f"  未找到患者名称列，使用第一列: {patient_col}")
                else:
                    logging.info(f"  使用患者名称列: {patient_col}")
                
            except Exception as e:
                logging.error(f"  读取CSV文件头失败: {e}")
                continue
            
            # 对于大文件，分块读取
            chunk_size = 1000
            patient_names = []
            embeddings_list = []
            
            try:
                for chunk in pd.read_csv(emb_path, chunksize=chunk_size):
                    if patient_col not in chunk.columns:
                        logging.warning(f"CSV文件 {emb_path} 中未找到 {patient_col} 列")
                        break
                    
                    # 提取embedding列（排除患者名称列、center列和其他非数值列）
                    exclude_cols = [patient_col, 'center', '中心', 'file_key', 'Center']
                    emb_cols = [col for col in chunk.columns if col not in exclude_cols]
                    
                    # 进一步过滤：只保留数值列（feature_开头或embedding_开头或纯数字列名）
                    numeric_cols = []
                    for col in emb_cols:
                        col_str = str(col).lower()
                        if col_str.startswith('feature_') or col_str.startswith('embedding_') or col_str.replace('_', '').replace('.', '').isdigit():
                            numeric_cols.append(col)
                    
                    if not numeric_cols:
                        # 如果没找到特征列，尝试所有非排除列
                        numeric_cols = emb_cols
                        logging.info(f"  使用所有非排除列作为embedding列（共{len(numeric_cols)}列）")
                    
                    if not numeric_cols:
                        logging.warning(f"  未找到embedding列")
                        break
                    
                    # 确保只选择数值列并转换类型
                    try:
                        # 先尝试选择数值类型
                        emb_data = chunk[numeric_cols].select_dtypes(include=[np.number])
                        
                        # 如果为空或列数不对，手动转换
                        if emb_data.empty or len(emb_data.columns) != len(numeric_cols):
                            # 逐列转换，将非数值转换为NaN
                            emb_data = chunk[numeric_cols].copy()
                            for col in emb_data.columns:
                                emb_data[col] = pd.to_numeric(emb_data[col], errors='coerce')
                            # 填充NaN为0
                            emb_data = emb_data.fillna(0)
                        
                        # 转换为numpy数组并确保是float32
                        emb_values = emb_data.values.astype(np.float32)
                        embeddings_list.append(emb_values)
                        patient_names.extend(chunk[patient_col].astype(str).values)
                        
                    except Exception as e:
                        logging.warning(f"  处理embedding列时出错: {e}，尝试备用方法")
                        # 备用方法：逐列处理
                        try:
                            emb_list = []
                            for col in numeric_cols:
                                col_data = pd.to_numeric(chunk[col], errors='coerce').fillna(0).values
                                emb_list.append(col_data)
                            
                            if emb_list:
                                emb_array = np.column_stack(emb_list).astype(np.float32)
                                embeddings_list.append(emb_array)
                                patient_names.extend(chunk[patient_col].astype(str).values)
                            else:
                                logging.error(f"  无法提取任何数值列")
                                break
                        except Exception as e2:
                            logging.error(f"  备用方法也失败: {e2}")
                            break
                
                if embeddings_list:
                    embeddings_array = np.vstack(embeddings_list)
                    
                    # 对于RadFM，同一个患者可能有多个embedding，需要聚合
                    patient_embeddings_dict = {}  # {patient_id: [emb1, emb2, ...]}
                    
                    for i, original_name in enumerate(patient_names):
                        normalized_name = normalize_patient_name(original_name)
                        if not normalized_name:
                            continue
                        
                        # 保存映射
                        original_to_normalized[original_name] = normalized_name
                        
                        emb = embeddings_array[i]
                        if normalized_name not in patient_embeddings_dict:
                            patient_embeddings_dict[normalized_name] = []
                        patient_embeddings_dict[normalized_name].append(emb)
                    
                    # 聚合同一患者的多个embedding
                    for normalized_name, emb_list in patient_embeddings_dict.items():
                        if normalized_name not in all_embeddings:
                            all_embeddings[normalized_name] = {}
                        
                        if len(emb_list) == 1:
                            # 只有一个embedding，直接使用
                            emb = emb_list[0]
                            # 确保是numpy数组
                            if not isinstance(emb, np.ndarray):
                                emb = np.array(emb, dtype=np.float32)
                            all_embeddings[normalized_name][emb_name] = emb
                        else:
                            # 多个embedding，进行平均池化
                            try:
                                # 确保所有embedding都是numpy数组
                                emb_arrays = []
                                for emb in emb_list:
                                    if not isinstance(emb, np.ndarray):
                                        emb = np.array(emb, dtype=np.float32)
                                    emb_arrays.append(emb)
                                
                                emb_array = np.array(emb_arrays)
                                # 确保是数值类型
                                if emb_array.dtype == object:
                                    # 如果包含字符串，尝试转换
                                    emb_array = np.array([np.array(e, dtype=np.float32) for e in emb_arrays])
                                
                                if emb_array.ndim == 2:
                                    # 对token维度求平均
                                    aggregated_emb = np.mean(emb_array, axis=0).astype(np.float32)
                                elif emb_array.ndim > 2:
                                    # 多维数组，展平后求平均
                                    aggregated_emb = np.mean(emb_array.flatten()).astype(np.float32)
                                else:
                                    # 一维数组，直接求平均
                                    aggregated_emb = np.mean(emb_array).astype(np.float32)
                                
                                all_embeddings[normalized_name][emb_name] = aggregated_emb
                            except Exception as e:
                                logging.warning(f"  聚合患者 {normalized_name} 的embedding失败: {e}")
                                # 如果聚合失败，使用第一个embedding
                                if emb_list:
                                    emb = emb_list[0]
                                    if not isinstance(emb, np.ndarray):
                                        emb = np.array(emb, dtype=np.float32)
                                    all_embeddings[normalized_name][emb_name] = emb
            except Exception as e:
                logging.warning(f"读取CSV文件 {emb_path} 失败: {e}")
                continue
        
        logging.info(f"  {emb_name}: 加载了 {len(set(original_to_normalized.values()))} 个唯一患者")
    
    return all_embeddings, original_to_normalized

print("load_embeddings函数定义完成！")


load_embeddings函数定义完成！


## 4. 加载标签数据（进展和死亡）


In [94]:
def load_labels_from_excel(excel_path: str) -> Tuple[Dict[str, int], Dict[str, int], Dict[str, str]]:
    """
    从Excel文件加载标签（进展/死亡）
    返回两个字典：进展标签和死亡标签，以及原始名称到标准化名称的映射
    """
    try:
        df = pd.read_excel(excel_path, engine='openpyxl')
        logging.info(f"从 {excel_path} 加载标签，列名: {df.columns.tolist()}")
        
        # 尝试找到患者名称列和标签列
        patient_col = None
        progress_col = None
        death_col = None
        
        for col in df.columns:
            col_lower = str(col).lower()
            col_str = str(col)
            if '患者' in col_str or 'patient' in col_str or '姓名' in col_str or 'name' in col_str:
                patient_col = col
            if '进展' in col_str:
                progress_col = col
            if '死亡' in col_str:
                death_col = col
        
        if patient_col is None:
            logging.warning(f"未找到患者列，使用第一列作为患者名")
            patient_col = df.columns[0]
        
        progress_labels = {}
        death_labels = {}
        original_to_normalized = {}
        
        for _, row in df.iterrows():
            original_patient_id = str(row[patient_col])
            patient_id = normalize_patient_name(original_patient_id)
            
            if not patient_id:
                continue
            
            # 保存映射
            original_to_normalized[original_patient_id] = patient_id
            
            # 处理进展标签
            if progress_col and progress_col in row:
                progress_val = row[progress_col]
                if not pd.isna(progress_val):
                    if isinstance(progress_val, (int, float)):
                        progress_labels[patient_id] = int(progress_val)
                    elif isinstance(progress_val, str):
                        if '进展' in progress_val or '1' in progress_val or 'yes' in progress_val.lower():
                            progress_labels[patient_id] = 1
                        else:
                            progress_labels[patient_id] = 0
                    else:
                        progress_labels[patient_id] = int(progress_val)
            
            # 处理死亡标签
            if death_col and death_col in row:
                death_val = row[death_col]
                if not pd.isna(death_val):
                    if isinstance(death_val, (int, float)):
                        death_labels[patient_id] = int(death_val)
                    elif isinstance(death_val, str):
                        if '死亡' in death_val or '1' in death_val or 'yes' in death_val.lower():
                            death_labels[patient_id] = 1
                        else:
                            death_labels[patient_id] = 0
                    else:
                        death_labels[patient_id] = int(death_val)
        
        logging.info(f"进展标签: {len(progress_labels)} 个，正样本: {sum(progress_labels.values())}")
        logging.info(f"死亡标签: {len(death_labels)} 个，正样本: {sum(death_labels.values())}")
        return progress_labels, death_labels, original_to_normalized
    
    except Exception as e:
        logging.error(f"加载标签失败: {e}")
        return {}, {}, {}

print("load_labels_from_excel函数定义完成！")


load_labels_from_excel函数定义完成！


## 5. 配置参数和加载数据


In [95]:
# ========== 配置参数 ==========
MODALITY = "ct"  # "ct" 或 "pet"
LABEL_TYPE = "progress"  # "progress" 或 "death"
EMBEDDING_DIR = "."
LABEL_FILE = "名单.xlsx"
OUTPUT_DIR = "fusion_results"

# 训练参数
HIDDEN_DIM = 128
NUM_LAYERS = 2
DROPOUT = 0.3
BATCH_SIZE = 32
NUM_EPOCHS = 50
LR = 0.001
TEST_SIZE = 0.2
VAL_SIZE = 0.2
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"配置完成！")
print(f"  模态: {MODALITY.upper()}")
print(f"  标签类型: {LABEL_TYPE}")
print(f"  设备: {DEVICE}")


配置完成！
  模态: CT
  标签类型: progress
  设备: cuda


In [96]:
# ========== 定义embedding文件路径 ==========
# SAM-Med2D使用xlsx，其他使用CSV
if MODALITY.lower() == "ct":
    embedding_paths = {
        "sam_med2d": os.path.join(EMBEDDING_DIR, "embeddings_output_CT.xlsx"),  # 注意是复数形式embeddings
        "medclip": os.path.join(EMBEDDING_DIR, "medclip_embeddings", "medclip_embeddings_ct_vit.csv"),
        "radfm": os.path.join(EMBEDDING_DIR, "ct_embeddings_tokens.csv"),
    }
else:  # PET
    embedding_paths = {
        "sam_med2d": os.path.join(EMBEDDING_DIR, "embeddings_output.xlsx"),  # 注意是复数形式embeddings
        "medclip": os.path.join(EMBEDDING_DIR, "medclip_embeddings", "medclip_embeddings_vit_dropcol.csv"),
        "radfm": os.path.join(EMBEDDING_DIR, "pet_embedding_tokens_with_labels.csv"),
    }

print("定义的embedding文件路径:")
for emb_name, path in embedding_paths.items():
    print(f"  {emb_name}: {path}")

# 检查文件是否存在
existing_paths = {}
for emb_name, path in embedding_paths.items():
    if os.path.exists(path):
        existing_paths[emb_name] = path
    else:
        print(f"警告: 未找到 {emb_name} 文件: {path}")

if not existing_paths:
    print("错误: 未找到任何embedding文件")
    print("尝试查找的文件:")
    for emb_name, path in embedding_paths.items():
        print(f"  {emb_name}: {path}")
else:
    print(f"找到 {len(existing_paths)} 个embedding文件:")
    for emb_name, path in existing_paths.items():
        print(f"  {emb_name}: {path}")


定义的embedding文件路径:
  sam_med2d: .\embeddings_output_CT.xlsx
  medclip: .\medclip_embeddings\medclip_embeddings_ct_vit.csv
  radfm: .\ct_embeddings_tokens.csv
找到 3 个embedding文件:
  sam_med2d: .\embeddings_output_CT.xlsx
  medclip: .\medclip_embeddings\medclip_embeddings_ct_vit.csv
  radfm: .\ct_embeddings_tokens.csv


In [97]:
# ========== 加载embeddings ==========
all_embeddings, emb_name_mapping = load_embeddings(existing_paths)
print(f"\n总共加载了 {len(all_embeddings)} 个患者的embedding数据")

# 统计每个模态的患者数量
for emb_name in existing_paths.keys():
    count = sum(1 for pid, emb_dict in all_embeddings.items() if emb_name in emb_dict)
    print(f"  {emb_name}: {count} 个患者")
    
    # 显示一些患者名称示例
    sample_patients = [pid for pid, emb_dict in list(all_embeddings.items())[:3] if emb_name in emb_dict]
    if sample_patients:
        print(f"    患者名称示例: {sample_patients}")
    
    # 对于RadFM，显示一些示例信息
    if 'radfm' in emb_name.lower() and sample_patients:
        sample_pid = sample_patients[0]
        sample_emb = all_embeddings[sample_pid][emb_name]
        if hasattr(sample_emb, 'shape'):
            print(f"    embedding形状: {sample_emb.shape}, dtype: {sample_emb.dtype}")
        else:
            print(f"    embedding类型: {type(sample_emb)}")


2025-11-29 16:09:52,351 - INFO - 加载 sam_med2d from .\embeddings_output_CT.xlsx
2025-11-29 16:10:34,371 - INFO -   Excel文件列名: ['医院', '患者名称', 'file_name', 'embedding_dim', 'embedding_0', 'embedding_1', 'embedding_2', 'embedding_3', 'embedding_4', 'embedding_5']... (共16383列)
2025-11-29 16:10:34,373 - INFO -   使用患者名称列: 患者名称
2025-11-29 16:10:38,099 - INFO -   sam_med2d: 加载了 163 个唯一患者
2025-11-29 16:10:38,100 - INFO -   sam_med2d: 加载了 163 个唯一患者
2025-11-29 16:10:38,101 - INFO - 加载 medclip from .\medclip_embeddings\medclip_embeddings_ct_vit.csv
2025-11-29 16:10:38,111 - INFO -   CSV文件列名: ['patient_name', 'center', 'embedding_000', 'embedding_001', 'embedding_002', 'embedding_003', 'embedding_004', 'embedding_005', 'embedding_006', 'embedding_007']... (共514列)
2025-11-29 16:10:38,112 - INFO -   使用患者名称列: patient_name
2025-11-29 16:10:40,827 - INFO -   medclip: 加载了 163 个唯一患者
2025-11-29 16:10:40,828 - INFO - 加载 radfm from .\ct_embeddings_tokens.csv
2025-11-29 16:10:40,886 - INFO -   CSV文件列名: ['中心', 


总共加载了 163 个患者的embedding数据
  sam_med2d: 163 个患者
    患者名称示例: ['20211227zhoulinmei', '20160926fuzhongyou', '20170215ruanjieliu']
  medclip: 109 个患者
    患者名称示例: ['20211227zhoulinmei', '20160926fuzhongyou', '20170215ruanjieliu']
  radfm: 158 个患者
    患者名称示例: ['20160926fuzhongyou', '20170215ruanjieliu']
    embedding形状: (5120,), dtype: float32


In [98]:
# ========== 加载标签 ==========
progress_labels, death_labels, label_name_mapping = load_labels_from_excel(LABEL_FILE)

# 选择要使用的标签
if LABEL_TYPE == "progress":
    labels = progress_labels
    label_name = "进展"
else:
    labels = death_labels
    label_name = "死亡"

print(f"\n使用标签: {label_name}")
print(f"标签数量: {len(labels)}, 正样本: {sum(labels.values())}")


2025-11-29 16:10:45,173 - INFO - 从 名单.xlsx 加载标签，列名: ['医院', '姓名', '进展', '死亡']


2025-11-29 16:10:45,182 - INFO - 进展标签: 161 个，正样本: 87
2025-11-29 16:10:45,183 - INFO - 死亡标签: 161 个，正样本: 25



使用标签: 进展
标签数量: 161, 正样本: 87


In [99]:
# ========== 准备数据 ==========
# 使用所有患者，即使不匹配也保留
patient_data = {}
patient_labels = {}

# 获取所有患者ID（embedding和标签的并集）
all_patient_ids = set(all_embeddings.keys()) | set(labels.keys())

# 优先使用有embedding的患者，即使没有标签也保留（用默认标签0）
# 对于只有标签的患者，用零向量填充embedding
for patient_id in all_patient_ids:
    # 只要有embedding就保留
    if patient_id in all_embeddings:
        patient_data[patient_id] = all_embeddings[patient_id]
        # 如果有标签就用标签，没有就用0
        patient_labels[patient_id] = labels.get(patient_id, 0)
    elif patient_id in labels:
        # 只有标签没有embedding，创建空的embedding字典（后续会用零向量填充）
        patient_data[patient_id] = {}
        patient_labels[patient_id] = labels[patient_id]

# 诊断信息
matched_with_both = sum(1 for pid in patient_data.keys() if pid in all_embeddings and pid in labels)
only_embeddings = sum(1 for pid in patient_data.keys() if pid in all_embeddings and pid not in labels)
only_labels = sum(1 for pid in patient_data.keys() if pid not in all_embeddings and pid in labels)

print(f"\n数据匹配情况:")
print(f"  Embedding中的患者数: {len(all_embeddings)}")
print(f"  标签中的患者数: {len(labels)}")
print(f"  总患者数（并集）: {len(all_patient_ids)}")
print(f"  准备使用的患者数: {len(patient_data)}")
print(f"  同时有embedding和标签: {matched_with_both}")
print(f"  只有embedding（无标签，使用默认标签0）: {only_embeddings}")
print(f"  只有标签（无embedding，用零向量填充）: {only_labels}")

# 显示一些不匹配的患者示例
if only_labels > 0:
    only_label_patients = [pid for pid in patient_data.keys() if pid not in all_embeddings and pid in labels][:5]
    print(f"\n  只有标签的患者示例（前5个）: {only_label_patients}")

if only_embeddings > 0:
    only_emb_patients = [pid for pid in patient_data.keys() if pid in all_embeddings and pid not in labels][:5]
    print(f"  只有embedding的患者示例（前5个）: {only_emb_patients}")

if not patient_data:
    print("错误: 没有找到匹配的患者数据")
else:
    print(f"\n准备训练数据: {len(patient_data)} 个患者 ({label_name})")
    print(f"  有标签的患者: {matched_with_both + only_labels}")
    print(f"  有embedding的患者: {matched_with_both + only_embeddings}")



数据匹配情况:
  Embedding中的患者数: 163
  标签中的患者数: 161
  总患者数（并集）: 195
  准备使用的患者数: 195
  同时有embedding和标签: 129
  只有embedding（无标签，使用默认标签0）: 34
  只有标签（无embedding，用零向量填充）: 32

  只有标签的患者示例（前5个）: ['z507468', 'z560929', 'z824491', 'z632091', 'z681396']
  只有embedding的患者示例（前5个）: ['20180625yangshiyi', '20180828hexiaoxian', '20180808limengjie', '20170215ruanjieliu', '20191125zhangyehong']

准备训练数据: 195 个患者 (进展)
  有标签的患者: 161
  有embedding的患者: 163


In [100]:
# ========== 获取embedding维度 ==========
modality_keys = list(existing_paths.keys())
embedding_dims = {}
input_dims = []

# 从所有患者中找到每个模态的维度
for key in modality_keys:
    dim = None
    # 优先从有embedding的患者中找
    for patient_id, emb_dict in patient_data.items():
        if key in emb_dict and emb_dict[key] is not None:
            emb_array = emb_dict[key]
            if not isinstance(emb_array, np.ndarray):
                emb_array = np.array(emb_array)
            
            # 处理维度
            if emb_array.ndim == 0:
                dim = 1  # 标量
            elif emb_array.ndim == 1:
                dim = emb_array.shape[0]  # 一维数组
            else:
                dim = emb_array.size  # 多维数组，展平后的维度
            
            if dim is not None and dim > 0:
                break
    
    # 如果还是找不到，尝试从all_embeddings中找
    if dim is None:
        for patient_id, emb_dict in all_embeddings.items():
            if key in emb_dict and emb_dict[key] is not None:
                emb_array = emb_dict[key]
                if not isinstance(emb_array, np.ndarray):
                    emb_array = np.array(emb_array)
                
                if emb_array.ndim == 0:
                    dim = 1
                elif emb_array.ndim == 1:
                    dim = emb_array.shape[0]
                else:
                    dim = emb_array.size
                
                if dim is not None and dim > 0:
                    break
    
    if dim is None or dim == 0:
        dim = 1  # 默认维度
        print(f"警告: 无法确定 {key} 的维度，使用默认值 1")
    else:
        print(f"  {key}: {dim} 维")
    
    embedding_dims[key] = dim
    input_dims.append(dim)

print(f"\nEmbedding维度: {dict(zip(modality_keys, input_dims))}")


  sam_med2d: 16380 维
  medclip: 512 维
  radfm: 5120 维

Embedding维度: {'sam_med2d': 16380, 'medclip': 512, 'radfm': 5120}


## 6. 划分数据集


In [101]:
# 划分数据集
patient_ids = list(patient_data.keys())
train_ids, temp_ids = train_test_split(patient_ids, test_size=TEST_SIZE + VAL_SIZE, random_state=42)
val_ids, test_ids = train_test_split(temp_ids, test_size=TEST_SIZE / (TEST_SIZE + VAL_SIZE), random_state=42)

train_data = {pid: patient_data[pid] for pid in train_ids}
val_data = {pid: patient_data[pid] for pid in val_ids}
test_data = {pid: patient_data[pid] for pid in test_ids}

train_labels = {pid: patient_labels[pid] for pid in train_ids}
val_labels = {pid: patient_labels[pid] for pid in val_ids}
test_labels = {pid: patient_labels[pid] for pid in test_ids}

print(f"训练集: {len(train_ids)} 个患者")
print(f"验证集: {len(val_ids)} 个患者")
print(f"测试集: {len(test_ids)} 个患者")


训练集: 117 个患者
验证集: 39 个患者
测试集: 39 个患者


In [102]:
# 创建数据集
train_dataset = MultiModalDataset(train_data, train_labels, modality_keys, embedding_dims)
val_dataset = MultiModalDataset(val_data, val_labels, modality_keys, embedding_dims)
test_dataset = MultiModalDataset(test_data, test_labels, modality_keys, embedding_dims)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print("数据集创建完成！")


数据集创建完成！


## 7. 创建和训练模型


In [103]:
# 创建模型
device = torch.device(DEVICE)
model = MedFuseLSTM(
    input_dims=input_dims,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT,
).to(device)

print(f"模型参数数量: {sum(p.numel() for p in model.parameters())}")
print(f"模型已移动到设备: {device}")


模型参数数量: 3090433
模型已移动到设备: cuda


In [104]:
# 定义训练函数
def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    num_epochs: int = 50,
    lr: float = 0.001,
) -> Tuple[nn.Module, List[float], List[float]]:
    """训练模型"""
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)
    
    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    best_model_state = None
    
    for epoch in range(num_epochs):
        # 训练
        model.train()
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]"):
            embeddings = [e.to(device) for e in batch['embeddings']]
            labels = batch['label'].to(device)
            
            optimizer.zero_grad()
            outputs = model(embeddings)
            loss = criterion(outputs.squeeze(), labels)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
        
        train_loss /= len(train_loader)
        train_losses.append(train_loss)
        
        # 验证
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]"):
                embeddings = [e.to(device) for e in batch['embeddings']]
                labels = batch['label'].to(device)
                
                outputs = model(embeddings)
                loss = criterion(outputs.squeeze(), labels)
                val_loss += loss.item()
        
        val_loss /= len(val_loader)
        val_losses.append(val_loss)
        
        scheduler.step(val_loss)
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_state = model.state_dict().copy()
        
        print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
    
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    
    return model, train_losses, val_losses

print("训练函数定义完成！")


训练函数定义完成！


In [105]:
# 开始训练
model, train_losses, val_losses = train_model(
    model, train_loader, val_loader, device, NUM_EPOCHS, LR
)


Epoch 1/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 60.83it/s]
Epoch 1/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 234.48it/s]


Epoch 1: Train Loss=0.6997, Val Loss=0.6922


Epoch 2/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 100.16it/s]
Epoch 2/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 170.20it/s]


Epoch 2: Train Loss=0.6947, Val Loss=0.6962


Epoch 3/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 83.87it/s]
Epoch 3/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 198.56it/s]


Epoch 3: Train Loss=0.6914, Val Loss=0.6989


Epoch 4/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 72.21it/s]
Epoch 4/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 190.08it/s]


Epoch 4: Train Loss=0.6872, Val Loss=0.7105


Epoch 5/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 77.12it/s]
Epoch 5/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 186.83it/s]


Epoch 5: Train Loss=0.6867, Val Loss=0.7150


Epoch 6/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 72.59it/s]
Epoch 6/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 173.23it/s]


Epoch 6: Train Loss=0.6862, Val Loss=0.7189


Epoch 7/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 76.76it/s]
Epoch 7/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 173.63it/s]


Epoch 7: Train Loss=0.6848, Val Loss=0.7195


Epoch 8/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 72.72it/s]
Epoch 8/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 162.39it/s]


Epoch 8: Train Loss=0.6840, Val Loss=0.7177


Epoch 9/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 78.45it/s]
Epoch 9/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 166.13it/s]


Epoch 9: Train Loss=0.6789, Val Loss=0.7173


Epoch 10/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 78.67it/s]
Epoch 10/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 173.70it/s]


Epoch 10: Train Loss=0.6850, Val Loss=0.7154


Epoch 11/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 73.39it/s]
Epoch 11/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 173.59it/s]


Epoch 11: Train Loss=0.6717, Val Loss=0.7105


Epoch 12/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 73.91it/s]
Epoch 12/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 168.43it/s]


Epoch 12: Train Loss=0.6805, Val Loss=0.7163


Epoch 13/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 73.59it/s]
Epoch 13/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 173.49it/s]


Epoch 13: Train Loss=0.6773, Val Loss=0.7212


Epoch 14/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 74.21it/s]
Epoch 14/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 177.13it/s]


Epoch 14: Train Loss=0.6763, Val Loss=0.7169


Epoch 15/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 73.47it/s]
Epoch 15/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 173.54it/s]


Epoch 15: Train Loss=0.6679, Val Loss=0.7074


Epoch 16/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 79.03it/s]
Epoch 16/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 190.41it/s]


Epoch 16: Train Loss=0.6708, Val Loss=0.7057


Epoch 17/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 77.12it/s]
Epoch 17/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 164.02it/s]


Epoch 17: Train Loss=0.6702, Val Loss=0.7032


Epoch 18/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 76.94it/s]
Epoch 18/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 181.20it/s]


Epoch 18: Train Loss=0.6700, Val Loss=0.6988


Epoch 19/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 77.13it/s]
Epoch 19/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 181.18it/s]


Epoch 19: Train Loss=0.6724, Val Loss=0.6911


Epoch 20/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 80.26it/s]
Epoch 20/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 190.15it/s]


Epoch 20: Train Loss=0.6728, Val Loss=0.6881


Epoch 21/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 80.00it/s]
Epoch 21/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 190.26it/s]


Epoch 21: Train Loss=0.6598, Val Loss=0.6805


Epoch 22/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 73.63it/s]
Epoch 22/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 179.93it/s]


Epoch 22: Train Loss=0.6599, Val Loss=0.6751


Epoch 23/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 80.77it/s]
Epoch 23/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 172.17it/s]


Epoch 23: Train Loss=0.6545, Val Loss=0.6649


Epoch 24/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 73.33it/s]
Epoch 24/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 196.50it/s]


Epoch 24: Train Loss=0.6476, Val Loss=0.6513


Epoch 25/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 72.67it/s]
Epoch 25/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 186.57it/s]


Epoch 25: Train Loss=0.6517, Val Loss=0.6466


Epoch 26/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 79.85it/s]
Epoch 26/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 171.80it/s]


Epoch 26: Train Loss=0.6431, Val Loss=0.6523


Epoch 27/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 81.12it/s]
Epoch 27/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 173.50it/s]


Epoch 27: Train Loss=0.6373, Val Loss=0.6385


Epoch 28/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 83.83it/s]
Epoch 28/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 154.47it/s]


Epoch 28: Train Loss=0.6316, Val Loss=0.6212


Epoch 29/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 74.38it/s]
Epoch 29/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 177.69it/s]


Epoch 29: Train Loss=0.6264, Val Loss=0.6153


Epoch 30/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 82.58it/s]
Epoch 30/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 143.92it/s]


Epoch 30: Train Loss=0.6193, Val Loss=0.6104


Epoch 31/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 75.03it/s]
Epoch 31/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 158.34it/s]


Epoch 31: Train Loss=0.6092, Val Loss=0.5928


Epoch 32/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 56.19it/s]
Epoch 32/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 105.02it/s]


Epoch 32: Train Loss=0.6179, Val Loss=0.5778


Epoch 33/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 58.15it/s]
Epoch 33/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 141.87it/s]


Epoch 33: Train Loss=0.6005, Val Loss=0.5617


Epoch 34/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 72.69it/s]
Epoch 34/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 173.53it/s]


Epoch 34: Train Loss=0.5992, Val Loss=0.5634


Epoch 35/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 79.92it/s]
Epoch 35/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 189.79it/s]


Epoch 35: Train Loss=0.5969, Val Loss=0.5795


Epoch 36/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 80.56it/s]
Epoch 36/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 173.54it/s]


Epoch 36: Train Loss=0.5957, Val Loss=0.5500


Epoch 37/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 73.31it/s]
Epoch 37/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 173.28it/s]


Epoch 37: Train Loss=0.6024, Val Loss=0.5317


Epoch 38/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 71.09it/s]
Epoch 38/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 159.70it/s]


Epoch 38: Train Loss=0.5706, Val Loss=0.5453


Epoch 39/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 72.41it/s]
Epoch 39/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 179.15it/s]


Epoch 39: Train Loss=0.5850, Val Loss=0.5287


Epoch 40/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 63.31it/s]
Epoch 40/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 173.54it/s]


Epoch 40: Train Loss=0.5888, Val Loss=0.5308


Epoch 41/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 67.83it/s]
Epoch 41/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 159.66it/s]


Epoch 41: Train Loss=0.5668, Val Loss=0.5309


Epoch 42/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 69.42it/s]
Epoch 42/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 159.67it/s]


Epoch 42: Train Loss=0.5494, Val Loss=0.5306


Epoch 43/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 69.16it/s]
Epoch 43/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 173.68it/s]


Epoch 43: Train Loss=0.5639, Val Loss=0.5309


Epoch 44/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 66.92it/s]
Epoch 44/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 153.45it/s]


Epoch 44: Train Loss=0.5526, Val Loss=0.5412


Epoch 45/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 70.66it/s]
Epoch 45/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 155.85it/s]


Epoch 45: Train Loss=0.5753, Val Loss=0.5481


Epoch 46/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 71.96it/s]
Epoch 46/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 165.03it/s]


Epoch 46: Train Loss=0.5396, Val Loss=0.5354


Epoch 47/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 75.03it/s]
Epoch 47/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 186.24it/s]


Epoch 47: Train Loss=0.5543, Val Loss=0.5314


Epoch 48/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 77.16it/s]
Epoch 48/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 173.38it/s]


Epoch 48: Train Loss=0.5378, Val Loss=0.5331


Epoch 49/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 79.33it/s]
Epoch 49/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 189.95it/s]


Epoch 49: Train Loss=0.5393, Val Loss=0.5316


Epoch 50/50 [Train]: 100%|██████████| 4/4 [00:00<00:00, 69.52it/s]
Epoch 50/50 [Val]: 100%|██████████| 2/2 [00:00<00:00, 189.92it/s]

Epoch 50: Train Loss=0.5320, Val Loss=0.5421





## 8. 评估模型


In [106]:
# 定义评估函数
def evaluate_model(
    model: nn.Module,
    test_loader: DataLoader,
    device: torch.device,
) -> Tuple[float, np.ndarray, np.ndarray]:
    """评估模型，返回ROC AUC、预测概率和真实标签"""
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            embeddings = [e.to(device) for e in batch['embeddings']]
            labels = batch['label'].cpu().numpy()
            
            outputs = model(embeddings)
            probs = torch.sigmoid(outputs).squeeze().cpu().numpy()
            
            all_preds.extend(probs)
            all_labels.extend(labels)
    
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    
    if len(np.unique(all_labels)) < 2:
        print("警告: 测试集中只有一个类别，无法计算ROC AUC")
        return 0.0, all_preds, all_labels
    
    auc = roc_auc_score(all_labels, all_preds)
    return auc, all_preds, all_labels

print("评估函数定义完成！")


评估函数定义完成！


In [107]:
# 评估模型（包括训练集、验证集和测试集）
def evaluate_all_sets(model, train_loader, val_loader, test_loader, device):
    """评估所有数据集"""
    train_auc, train_preds, train_labels = evaluate_model(model, train_loader, device)
    val_auc, val_preds, val_labels = evaluate_model(model, val_loader, device)
    test_auc, test_preds, test_labels = evaluate_model(model, test_loader, device)
    
    return {
        'train': (train_auc, train_preds, train_labels),
        'val': (val_auc, val_preds, val_labels),
        'test': (test_auc, test_preds, test_labels),
    }

# 评估所有数据集
results_all = evaluate_all_sets(model, train_loader, val_loader, test_loader, device)

train_auc, train_preds, train_labels = results_all['train']
val_auc, val_preds, val_labels = results_all['val']
test_auc, test_preds, test_labels = results_all['test']

print(f"\n{'='*60}")
print(f"评估结果 ({MODALITY.upper()} - {label_name})")
print(f"{'='*60}")
print(f"训练集 ROC AUC: {train_auc:.4f}")
print(f"验证集 ROC AUC: {val_auc:.4f}")
print(f"测试集 ROC AUC: {test_auc:.4f}")
print(f"{'='*60}")


Evaluating: 100%|██████████| 4/4 [00:00<00:00, 104.29it/s]
Evaluating: 100%|██████████| 2/2 [00:00<00:00, 128.89it/s]
Evaluating: 100%|██████████| 2/2 [00:00<00:00, 128.91it/s]


评估结果 (CT - 进展)
训练集 ROC AUC: 0.8033
验证集 ROC AUC: 0.7315
测试集 ROC AUC: 0.6176





## 9. 保存结果


In [108]:
# 创建输出目录
output_dir = Path(OUTPUT_DIR) / MODALITY.upper() / LABEL_TYPE
output_dir.mkdir(parents=True, exist_ok=True)

# 保存结果
results = {
    "modality": MODALITY.upper(),
    "label_type": LABEL_TYPE,
    "train_auc": float(train_auc),
    "val_auc": float(val_auc),
    "test_auc": float(test_auc),
    "num_train": len(train_ids),
    "num_val": len(val_ids),
    "num_test": len(test_ids),
}

with open(output_dir / f"results_{LABEL_TYPE}.json", "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2, ensure_ascii=False)

# 保存所有数据集的预测结果
# 训练集
train_results_df = pd.DataFrame({
    "patient_id": train_ids,
    "true_label": train_labels,
    "pred_prob": train_preds,
})
train_results_df.to_csv(output_dir / f"train_predictions_{LABEL_TYPE}.csv", index=False)

# 验证集
val_results_df = pd.DataFrame({
    "patient_id": val_ids,
    "true_label": val_labels,
    "pred_prob": val_preds,
})
val_results_df.to_csv(output_dir / f"val_predictions_{LABEL_TYPE}.csv", index=False)

# 测试集
test_results_df = pd.DataFrame({
    "patient_id": test_ids,
    "true_label": test_labels,
    "pred_prob": test_preds,
})
test_results_df.to_csv(output_dir / f"test_predictions_{LABEL_TYPE}.csv", index=False)

print(f"\n结果已保存到 {output_dir}")
print(f"  - results_{LABEL_TYPE}.json")
print(f"  - train_predictions_{LABEL_TYPE}.csv")
print(f"  - val_predictions_{LABEL_TYPE}.csv")
print(f"  - test_predictions_{LABEL_TYPE}.csv")



结果已保存到 fusion_results\CT\progress
  - results_progress.json
  - train_predictions_progress.csv
  - val_predictions_progress.csv
  - test_predictions_progress.csv


## 10. 运行所有任务（CT/PET × 进展/死亡）


In [None]:
# ========== 运行所有任务 ==========
# 定义所有任务
ALL_TASKS = [
    {"modality": "ct", "label_type": "progress", "label_name": "进展"},
    {"modality": "ct", "label_type": "death", "label_name": "死亡"},
    {"modality": "pet", "label_type": "progress", "label_name": "进展"},
    {"modality": "pet", "label_type": "death", "label_name": "死亡"},
]

# 存储所有任务的结果
all_results_summary = []

print("="*60)
print("开始运行所有任务")
print("="*60)

for task_idx, task in enumerate(ALL_TASKS, 1):
    print(f"\n{'='*60}")
    print(f"任务 {task_idx}/4: {task['modality'].upper()} - {task['label_name']}")
    print(f"{'='*60}")
    
    # 设置当前任务参数
    MODALITY = task["modality"]
    LABEL_TYPE = task["label_type"]
    label_name = task["label_name"]
    
    # 定义embedding文件路径（SAM-Med2D使用xlsx，其他使用CSV）
    if MODALITY.lower() == "ct":
        embedding_paths = {
            "sam_med2d": os.path.join(EMBEDDING_DIR, "embeddings_output_CT.xlsx"),  # 注意是复数形式embeddings
            "medclip": os.path.join(EMBEDDING_DIR, "medclip_embeddings", "medclip_embeddings_ct_vit.csv"),
            "radfm": os.path.join(EMBEDDING_DIR, "ct_embeddings_tokens.csv"),
        }
    else:  # PET
        embedding_paths = {
            "sam_med2d": os.path.join(EMBEDDING_DIR, "embeddings_output.xlsx"),  # 注意是复数形式embeddings
            "medclip": os.path.join(EMBEDDING_DIR, "medclip_embeddings", "medclip_embeddings_vit_dropcol.csv"),
            "radfm": os.path.join(EMBEDDING_DIR, "pet_embedding_tokens_with_labels.csv"),
        }
    
    print(f"定义的embedding文件路径 ({MODALITY.upper()}):")
    for emb_name, path in embedding_paths.items():
        print(f"  {emb_name}: {path}")
    
    # 检查文件是否存在
    existing_paths = {}
    for emb_name, path in embedding_paths.items():
        if os.path.exists(path):
            existing_paths[emb_name] = path
        else:
            print(f"警告: 未找到 {emb_name} 文件: {path}")
    
    if not existing_paths:
        print(f"警告: 未找到任何embedding文件，跳过此任务")
        continue
    
    print(f"找到 {len(existing_paths)} 个embedding文件")
    
    # 加载embeddings（注意返回的是元组）
    all_embeddings, emb_name_mapping = load_embeddings(existing_paths)
    print(f"加载了 {len(all_embeddings)} 个患者的embedding数据")
    
    # 选择标签
    if LABEL_TYPE == "progress":
        labels = progress_labels
    else:
        labels = death_labels
    
    # 准备数据
    patient_data = {}
    patient_labels = {}
    
    for patient_id, emb_dict in all_embeddings.items():
        if patient_id in labels:
            patient_data[patient_id] = emb_dict
            patient_labels[patient_id] = labels[patient_id]
    
    if not patient_data:
        print(f"警告: 没有找到匹配的患者数据，跳过此任务")
        continue
    
    print(f"准备训练数据: {len(patient_data)} 个患者 ({label_name})")
    
    # 获取embedding维度
    modality_keys = list(existing_paths.keys())
    embedding_dims = {}
    input_dims = []
    
    for key in modality_keys:
        dim = None
        for patient_id, emb_dict in patient_data.items():
            if key in emb_dict and emb_dict[key] is not None:
                emb_array = emb_dict[key]
                if not isinstance(emb_array, np.ndarray):
                    emb_array = np.array(emb_array)
                if emb_array.ndim > 1:
                    dim = emb_array.size
                else:
                    dim = emb_array.shape[0]
                break
        
        if dim is None:
            dim = 1
        embedding_dims[key] = dim
        input_dims.append(dim)
    
    print(f"Embedding维度: {dict(zip(modality_keys, input_dims))}")
    
    # 划分数据集
    patient_ids = list(patient_data.keys())
    train_ids, temp_ids = train_test_split(patient_ids, test_size=TEST_SIZE + VAL_SIZE, random_state=42)
    val_ids, test_ids = train_test_split(temp_ids, test_size=TEST_SIZE / (TEST_SIZE + VAL_SIZE), random_state=42)
    
    train_data = {pid: patient_data[pid] for pid in train_ids}
    val_data = {pid: patient_data[pid] for pid in val_ids}
    test_data = {pid: patient_data[pid] for pid in test_ids}
    
    train_labels_dict = {pid: patient_labels[pid] for pid in train_ids}
    val_labels_dict = {pid: patient_labels[pid] for pid in val_ids}
    test_labels_dict = {pid: patient_labels[pid] for pid in test_ids}
    
    # 创建数据集
    train_dataset = MultiModalDataset(train_data, train_labels_dict, modality_keys, embedding_dims)
    val_dataset = MultiModalDataset(val_data, val_labels_dict, modality_keys, embedding_dims)
    test_dataset = MultiModalDataset(test_data, test_labels_dict, modality_keys, embedding_dims)
    
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    
    # 创建模型
    device = torch.device(DEVICE)
    model = MedFuseLSTM(
        input_dims=input_dims,
        hidden_dim=HIDDEN_DIM,
        num_layers=NUM_LAYERS,
        dropout=DROPOUT,
    ).to(device)
    
    print(f"模型参数数量: {sum(p.numel() for p in model.parameters())}")
    
    # 训练
    print("开始训练...")
    model, train_losses, val_losses = train_model(
        model, train_loader, val_loader, device, NUM_EPOCHS, LR
    )
    
    # 评估
    print("开始评估...")
    results_all = evaluate_all_sets(model, train_loader, val_loader, test_loader, device)
    
    train_auc, train_preds, train_labels_arr = results_all['train']
    val_auc, val_preds, val_labels_arr = results_all['val']
    test_auc, test_preds, test_labels_arr = results_all['test']
    
    print(f"\n评估结果:")
    print(f"  训练集 ROC AUC: {train_auc:.4f}")
    print(f"  验证集 ROC AUC: {val_auc:.4f}")
    print(f"  测试集 ROC AUC: {test_auc:.4f}")
    
    # 保存结果
    output_dir = Path(OUTPUT_DIR) / MODALITY.upper() / LABEL_TYPE
    output_dir.mkdir(parents=True, exist_ok=True)
    
    results = {
        "modality": MODALITY.upper(),
        "label_type": LABEL_TYPE,
        "train_auc": float(train_auc),
        "val_auc": float(val_auc),
        "test_auc": float(test_auc),
        "num_train": len(train_ids),
        "num_val": len(val_ids),
        "num_test": len(test_ids),
    }
    
    with open(output_dir / f"results_{LABEL_TYPE}.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    
    # 保存所有数据集的预测结果
    train_results_df = pd.DataFrame({
        "patient_id": train_ids,
        "true_label": train_labels_arr,
        "pred_prob": train_preds,
    })
    train_results_df.to_csv(output_dir / f"train_predictions_{LABEL_TYPE}.csv", index=False)
    
    val_results_df = pd.DataFrame({
        "patient_id": val_ids,
        "true_label": val_labels_arr,
        "pred_prob": val_preds,
    })
    val_results_df.to_csv(output_dir / f"val_predictions_{LABEL_TYPE}.csv", index=False)
    
    test_results_df = pd.DataFrame({
        "patient_id": test_ids,
        "true_label": test_labels_arr,
        "pred_prob": test_preds,
    })
    test_results_df.to_csv(output_dir / f"test_predictions_{LABEL_TYPE}.csv", index=False)
    
    print(f"结果已保存到 {output_dir}")
    
    # 记录到总结
    all_results_summary.append({
        "modality": MODALITY.upper(),
        "label_type": LABEL_TYPE,
        "label_name": label_name,
        "train_auc": float(train_auc),
        "val_auc": float(val_auc),
        "test_auc": float(test_auc),
    })

print(f"\n{'='*60}")
print("所有任务完成！")
print(f"{'='*60}")

# 打印总结
print("\n结果总结:")
print("-" * 60)
for result in all_results_summary:
    print(f"{result['modality']} - {result['label_name']}:")
    print(f"  训练集 AUC: {result['train_auc']:.4f}")
    print(f"  验证集 AUC: {result['val_auc']:.4f}")
    print(f"  测试集 AUC: {result['test_auc']:.4f}")
    print()

# 保存总结
summary_df = pd.DataFrame(all_results_summary)
summary_df.to_csv(Path(OUTPUT_DIR) / "all_results_summary.csv", index=False)
print(f"总结已保存到 {Path(OUTPUT_DIR) / 'all_results_summary.csv'}")


2025-11-29 16:13:30,045 - INFO - 加载 sam_med2d from .\embeddings_output_CT.xlsx


开始运行所有任务

任务 1/4: CT - 进展
定义的embedding文件路径 (CT):
  sam_med2d: .\embeddings_output_CT.xlsx
  medclip: .\medclip_embeddings\medclip_embeddings_ct_vit.csv
  radfm: .\ct_embeddings_tokens.csv
找到 3 个embedding文件


2025-11-29 16:14:12,734 - INFO -   Excel文件列名: ['医院', '患者名称', 'file_name', 'embedding_dim', 'embedding_0', 'embedding_1', 'embedding_2', 'embedding_3', 'embedding_4', 'embedding_5']... (共16383列)
2025-11-29 16:14:12,735 - INFO -   使用患者名称列: 患者名称
2025-11-29 16:14:16,887 - INFO -   sam_med2d: 加载了 163 个唯一患者
2025-11-29 16:14:16,888 - INFO -   sam_med2d: 加载了 163 个唯一患者
2025-11-29 16:14:16,889 - INFO - 加载 medclip from .\medclip_embeddings\medclip_embeddings_ct_vit.csv
2025-11-29 16:14:16,900 - INFO -   CSV文件列名: ['patient_name', 'center', 'embedding_000', 'embedding_001', 'embedding_002', 'embedding_003', 'embedding_004', 'embedding_005', 'embedding_006', 'embedding_007']... (共514列)
2025-11-29 16:14:16,902 - INFO -   使用患者名称列: patient_name
2025-11-29 16:14:19,741 - INFO -   medclip: 加载了 163 个唯一患者
2025-11-29 16:14:19,742 - INFO - 加载 radfm from .\ct_embeddings_tokens.csv
2025-11-29 16:14:19,800 - INFO -   CSV文件列名: ['中心', '患者名', 'file_key', 'feature_0', 'feature_1', 'feature_2', 'feature_3', 'feature

加载了 2 个患者的embedding数据


AttributeError: 'tuple' object has no attribute 'items'