# 单模型多模态可视化

In [301]:
import os
import copy

from pathlib import Path # 路径操作
import pandas as pd
import nibabel as nib
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import interact, IntSlider
from IPython.display import display, clear_output
from matplotlib.colors import ListedColormap, BoundaryNorm
import tracemalloc

# 配置Jupyter友好显示

%matplotlib inline
plt.ioff()  # 关闭交互模式避免重复绘图

<contextlib.ExitStack at 0x748f2c810880>

In [302]:
def convert_path(path: str) -> str:
    """
    将Windows路径转换为WSL/Linux路径
    支持所有盘符转换 (D: -> /mnt/d, E: -> /mnt/e)
    自动标准化路径格式
    """
    # 将路径转换为Path对象自动处理不同OS的分隔符
    p = str(Path(path)) 
    
    # 处理盘符转换 (如 D: -> /mnt/d)
    if ':' in p:  # 仅处理包含盘符的路径
        drive, rest = p.split(':')
        # print(rest)
        rest = str(rest).replace('\\', '/').split('/')  # 替换反斜杠并分割路径
        # print(rest)
        linux_path = f"/mnt/{drive.lower()}/{'/'.join(rest)}"
        p = Path(linux_path)
        
    return str(p) # 返回绝对路径

## 创建颜色映射

In [303]:
COLOR_MAP = ListedColormap([(0,0,0), (1,0,0), (0,1,0), (0,0,1)], name='custom_discrete', N=4)
BOUNDARIES = [0, 1, 2, 3, 4]
NORM = BoundaryNorm(BOUNDARIES, COLOR_MAP.N)

## 数据集分析

In [None]:

import os
import csv
from pathlib import Path
from tqdm import tqdm

def generate_csv(dirInput, csvPath, rowNameList):
    """生成CSV数据
    Args:
        dirInput: 输入路径
        csvPath: 输出CSV文件路径
        rowNameList: 列名列表
    Returns:
        data: 包含目录结构的列表，每个元素是一个子列表，包含子目录名、绝对路径、文件名和文件绝对路径
    """
    dirInput = Path(dirInput)
    dataList = []
    
    # 获取所有子目录
    subdirs = [d for d in dirInput.iterdir() if d.is_dir()]
    
    # 主进度条
    for subdir in tqdm(subdirs, desc="Processing directories", unit="dir"):
        subdir_name = subdir.name
        subdir_abs_path = str(subdir.resolve())
        
        # 获取当前子目录下的所有文件
        files = [f for f in subdir.iterdir() if f.is_file()]
        
        # 子目录文件处理进度条
        for file_item in tqdm(files, desc=f"Files in {subdir_name}", unit="file", leave=False):
            file_name = file_item.name
            file_abs_path = str(file_item.resolve())
            
            # 添加到数据列表
            dataList.append([
                subdir_name,
                subdir_abs_path,
                file_name,
                file_abs_path
            ])
            
        # 写入CSV文件
    with open(csvPath, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        # 写入表头
        writer.writerow(rowNameList)
        # 写入数据
        writer.writerows(dataList)
    
    print(f"成功生成CSV文件：{csvPath}")

# 添加modal列
def add_modal_type(filename):
    return filename.split('.')[0].split('_')[-1]
def add_patient_id(dirName):
    return dirName.split('_')[-1]  # 假设ID在第二个位置

def process_pd(df: pd.DataFrame):
    # 添加modal列和patient_id列
    ## 添加PatientID列
    df["PatientID"] = df["subDirName"].apply(add_patient_id) 
    ## 添加ModalType列
    df["ModalType"] = df["fileName"].apply(add_modal_type)   
    return df

# # 生成整个数据集的CSV文件
# generate_csv(
#     dirInput=convert_path(r"/root/workspace/DCLA-UNet/data/brats21/BraTS2021_Training_Data"),
#     csvPath='./datasets_dir_anasys.csv',
#     rowNameList=['subDirName', 'absPathOfSubDir', 'fileName', 'absPathOfFile']
# )

# # 生成模型预测的CSV文件
# generate_csv(
#     dirInput=convert_path(r"D:\results\2_消融实验结果\【final】DCLA_UNet\outputs\DCLA_UNet_final_20250601110011"),
#     csvPath='./DCLA_UNet_final_pred_dir_anasys.csv',
#     rowNameList=['subDirName', 'absPathOfSubDir', 'fileName', 'absPathOfFile']
# )

model_config = {
    # "UNet3D":               convert_path(r"D:\results\1_对比网络结果\output\UNet3D_20250528111955"),
    # "AttUNet3D":            convert_path(r"D:\results\1_对比网络结果\output\AttUNet3D_20250528112832"),
    # "DCLA_UNet_final":      convert_path(r"D:\results\2_消融实验结果\【final】DCLA_UNet\outputs\DCLA_UNet_final_20250601110011"),
    "BaseLine_S_SLK_final": convert_path(r"D:\results\2_消融实验结果\【final】DCLA_UNet\outputs\BaseLine_S_SLK_final_20250601111332"),
    "BaseLine_S_DCLA_SLK_final":    convert_path(r"D:\results\2_消融实验结果\【final】DCLA_UNet\outputs\BaseLine_S_DCLA_SLK_final_20250601112220"),
    "BaseLine_S_MSF_final":         convert_path(r"D:\results\2_消融实验结果\【final】DCLA_UNet\outputs\BaseLine_S_MSF_final_20250601113110"),
    "BaseLine_S_DCLA_MSF_final":    convert_path(r"D:\results\2_消融实验结果\【final】DCLA_UNet\outputs\BaseLine_S_DCLA_MSF_final_20250601114225")
}

def multi_generate_csv(model_config: dict):
    """
    批量生成模型预测的CSV文件
    Args:
        model_config: 包含模型名称和对应路径的字典
    """
    for model_name, dirInput in model_config.items():
        csvPath = f'./csv/{model_name}_pred_dir_anasys.csv'
        Path(csvPath).parent.mkdir(parents=True, exist_ok=True)  # 确保目录存在
        print(f"正在生成 {model_name} 的CSV文件...")
        generate_csv(
            dirInput=dirInput,
            csvPath=csvPath,
            rowNameList=['subDirName', 'absPathOfSubDir', 'fileName', 'absPathOfFile']
        )
        
multi_generate_csv(model_config)


# # 读取生成的CSV文件
# datasets_df = pd.read_csv('./datasets_dir_anasys.csv') # 获取整个数据集的CSV
# pred_df = pd.read_csv('./DCLA_UNet_final_pred_dir_anasys.csv')         # 获取测试数据集（预测）的CSV

# # 添加modal列和patient_id列
# ## 添加PatientID列
# datasets_df["PatientID"] = datasets_df["subDirName"].apply(add_patient_id) 
# pred_df["PatientID"] = pred_df["subDirName"].apply(add_patient_id)

# ## 添加ModalType列
# datasets_df["ModalType"] = datasets_df["fileName"].apply(add_modal_type)     
# pred_df["ModalType"] = pred_df["fileName"].apply(add_modal_type)


# # 筛选出原始数据集中对应病例的原始数据
# pred_mask_pd = pred_df[pred_df["ModalType"] == "pred"]
# gt_mask_pd = pred_df[pred_df["ModalType"] == "mask"]

# mask_ids = set(pred_mask_pd["PatientID"].unique())

# ## 从原始数据集中筛选出对应pred的t1数据
# conditions = datasets_df["PatientID"].isin(mask_ids) & (datasets_df["ModalType"] == "t1")

# ## 在原始数据集中添加setOfDatasets列，并打上标签test
# datasets_df['setOfDatasets'] = None
# datasets_df.loc[conditions, "setOfDatasets"] = "test"

# test_orginal_pd = datasets_df[datasets_df['setOfDatasets']== "test"]

# # 检测是否有重复的PatientID
# set(test_orginal_pd["PatientID"].unique()) == set(pred_mask_pd["PatientID"].unique())


# # 测试获取某个病例的原始数据路径
# ids = "01128"
# original_path = test_orginal_pd.loc[test_orginal_pd["PatientID"]==ids, "absPathOfFile"].item()

# pred_path = pred_mask_pd.loc[pred_mask_pd["PatientID"]==ids, "absPathOfFile"].item()

# gt_path = gt_mask_pd.loc[gt_mask_pd["PatientID"]==ids, "absPathOfFile"].item()


# print(f"Original Path: {original_path}")
# print(f"Predicted Path: {pred_path}")
# print(f"Ground Truth Path: {gt_path}")

# path

Processing directories:   0%|          | 0/1251 [00:00<?, ?dir/s]

Processing directories: 100%|██████████| 1251/1251 [00:06<00:00, 181.52dir/s]


成功生成CSV文件：./datasets_dir_anasys.csv


Processing directories:   4%|▍         | 5/126 [00:01<00:23,  5.15dir/s]

In [None]:
datasets_df

Unnamed: 0,subDirName,absPathOfSubDir,fileName,absPathOfFile,PatientID,ModalType,setOfDatasets
0,BraTS2021_01128,/root/data/BraTS21_original_kaggle/BraTS2021_T...,BraTS2021_01128_t2.nii.gz,/root/data/BraTS21_original_kaggle/BraTS2021_T...,01128,t2,
1,BraTS2021_01128,/root/data/BraTS21_original_kaggle/BraTS2021_T...,BraTS2021_01128_seg.nii.gz,/root/data/BraTS21_original_kaggle/BraTS2021_T...,01128,seg,
2,BraTS2021_01128,/root/data/BraTS21_original_kaggle/BraTS2021_T...,BraTS2021_01128_flair.nii.gz,/root/data/BraTS21_original_kaggle/BraTS2021_T...,01128,flair,
3,BraTS2021_01128,/root/data/BraTS21_original_kaggle/BraTS2021_T...,BraTS2021_01128_t1.nii.gz,/root/data/BraTS21_original_kaggle/BraTS2021_T...,01128,t1,test
4,BraTS2021_01128,/root/data/BraTS21_original_kaggle/BraTS2021_T...,BraTS2021_01128_t1ce.nii.gz,/root/data/BraTS21_original_kaggle/BraTS2021_T...,01128,t1ce,
...,...,...,...,...,...,...,...
6639,BraTS2021_00831,/root/data/BraTS21_original_kaggle/BraTS2021_T...,BraTS2021_00831_t2.nii.gz,/root/data/BraTS21_original_kaggle/BraTS2021_T...,00831,t2,
6640,BraTS2021_00831,/root/data/BraTS21_original_kaggle/BraTS2021_T...,BraTS2021_00831_flair.nii.gz,/root/data/BraTS21_original_kaggle/BraTS2021_T...,00831,flair,
6641,BraTS2021_00831,/root/data/BraTS21_original_kaggle/BraTS2021_T...,BraTS2021_00831_t1.nii.gz,/root/data/BraTS21_original_kaggle/BraTS2021_T...,00831,t1,
6642,BraTS2021_00831,/root/data/BraTS21_original_kaggle/BraTS2021_T...,BraTS2021_00831_seg.nii.gz,/root/data/BraTS21_original_kaggle/BraTS2021_T...,00831,seg,


In [None]:
def plotter(original_path, pred_path, gt_path, slice_index=80, axis=0):
    orginal_img = nib.load(original_path).get_fdata().transpose(2, 0, 1)  # 获取原始图像数据
    pred_img = nib.load(pred_path).get_fdata().transpose(2, 0, 1)  # 获取预测图像数据
    gt_img = nib.load(gt_path).get_fdata().transpose(2, 0, 1)  # 获取Ground Truth图像数据
    
    # 获取原始图像和预测图像的形状
    if orginal_img.shape != pred_img.shape or orginal_img.shape != gt_img.shape:
        raise ValueError("Original and predicted images must have the same shape.")
    if axis not in [0, 1, 2]:
        raise ValueError("Axis must be 0, 1, or 2.")
    if slice_index < 0 or slice_index >= orginal_img.shape[axis]:
        raise ValueError(f"Slice index must be in range [0, {orginal_img.shape[axis] - 1}].")
    
    # 定义切片函数
    slicer = {
        0: lambda img, idx: img[idx, :, :],  # 沿着第一个轴切片
        1: lambda img, idx: img[:, idx, :],  # 沿着第二个轴切片
        2: lambda img, idx: img[:, :, idx]   # 沿着第三个轴切片
    }
    
    original_slice = slicer[axis](orginal_img, slice_index)  # 获取指定切片的原始图像
    pred_slice = slicer[axis](gt_img, slice_index)  # 获取指定切片的预测图像
    gt_slice = slicer[axis](gt_img, slice_index)  # 获取指定切片的Ground Truth图像
    
    et_data = np.zeros_like(original_slice, dtype=np.int8)  # 创建一个与data相同形状的全零数组
    tc_data = np.zeros_like(original_slice, dtype=np.int8)  # 创建一个与data相同形状的全零数组
    wt_data = np.zeros_like(original_slice, dtype=np.int8)  # 创建一个与data相同形状的全零数组
    
    
    pred_et_slice = np.where(pred_slice == 3, 1, et_data) # 获取预测的增强肿瘤切片
    pred_tc_slice = np.where((pred_slice == 3) | (pred_slice == 1), 1, tc_data) # 获取预测的肿瘤核心切片
    pred_wt_slice = np.where((pred_slice == 3) | (pred_slice == 2) | (pred_slice == 1), 1, wt_data) # 获取预测的全肿瘤切片
    
    gt_et_slice = np.where(gt_slice == 3, 1, et_data) # 获取Ground Truth的增强肿瘤切片
    gt_tc_slice = np.where((gt_slice == 3) | (gt_slice == 1), 1, tc_data) # 获取Ground Truth的肿瘤核心切片
    gt_wt_slice = np.where((gt_slice == 3) | (gt_slice == 2) | (gt_slice == 1), 1, wt_data) # 获取Ground Truth的全肿瘤切片
    
    # 绘制映射分割图    
    fig, axes = plt.subplots(1, 5, figsize=(20, 8))
    axes[0].imshow(original_slice, cmap='gray')
    original_gt = np.ma.masked_where(gt_slice == 0, gt_slice)  # 将Ground Truth中为0的部分遮罩掉
    axes[0].imshow(original_gt, cmap=COLOR_MAP, norm=NORM, alpha=0.5)  # 显示Ground Truth
    axes[0].set_title('Ground Truth')
    axes[0].axis('off')
    
    axes[1].imshow(original_slice, cmap='gray')
    original_pred = np.ma.masked_where(pred_slice == 0, pred_slice)  # 将pred_slice中为0的部分遮罩掉
    axes[1].imshow(original_pred, cmap=COLOR_MAP, norm=NORM, alpha=0.5)  # 显示预测结果
    axes[1].set_title('Prediction')
    axes[1].axis('off')
    plt.tight_layout()
    
    # axes[2].imshow(original_slice, cmap='gray')
    # original_et = np.ma.masked_where(pred_et_slice == 0, pred_et_slice)  # 将Ground Truth的增强肿瘤部分遮罩掉
    axes[2].imshow(pred_et_slice, cmap='gray', alpha=0.9)  # 显示预测的增强肿瘤
    axes[2].set_title('Enhanced Tumor Prediction')
    axes[2].axis('off')
    plt.tight_layout()
    
    # axes[3].imshow(original_slice, cmap='gray')
    # original_tc = np.ma.masked_where(pred_tc_slice == 0, pred_tc_slice)
    axes[3].imshow(pred_tc_slice, cmap='gray', alpha=0.9)
    axes[3].set_title('Tumor Core Prediction')
    axes[3].axis('off')
    plt.tight_layout()
     
    # axes[4].imshow(original_slice, cmap='gray')
    # original_wt = np.ma.masked_where(pred_wt_slice == 0, pred_wt_slice)
    axes[4].imshow(pred_wt_slice, cmap='gray', alpha=0.9)
    axes[4].set_title('Whole Tumor Prediction')
    axes[4].axis('off')
    plt.tight_layout()
    plt.show() 
    

            
    
    
    
def mul_plotter(original_path_dict, pred_path_dict, gt_path_dict, slice_index=80, axis=0):


    # 交互式封装
def interactive_plot(original_path, pred_path, gt_path):
    def update(slice_index=80, axis=0):
        plotter(original_path, pred_path, gt_path, slice_index=slice_index, axis=axis)
    
    # 动态获取最大切片数
    img_shape = nib.load(original_path).get_fdata().transpose(2, 0, 1).shape
    max_slices = {0: img_shape[0]-1, 1: img_shape[1]-1, 2: img_shape[2]-1}
    
    interact(update, 
             slice_index=IntSlider(min=0, max=max_slices[0], value=80),
             axis=IntSlider(min=0, max=2, value=0))


interactive_plot(original_path, pred_path, gt_path)


interactive(children=(IntSlider(value=80, description='slice_index', max=154), IntSlider(value=0, description=…

In [None]:
import os
import pandas as pd
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider, Dropdown
from matplotlib.colors import ListedColormap, BoundaryNorm
from pathlib import Path

# 配置可视化参数
COLOR_MAP = ListedColormap([(0,0,0), (1,0,0), (0,1,0), (0,0,1)], name='custom_discrete', N=4)
BOUNDARIES = [0, 1, 2, 3, 4]
NORM = BoundaryNorm(BOUNDARIES, COLOR_MAP.N)
%matplotlib inline

class MultiModelVisualizer:
    def __init__(self, model_config):
        """
        初始化多模型可视化器
        Args:
            model_config: 字典 {模型名: 模型预测路径}
        """
        self.models = model_config
        self.current_model = next(iter(model_config))
        self.current_patient = None
        
    def load_data(self, patient_id):
        """加载指定患者的多模态数据"""
        self.current_patient = patient_id
        # 原始数据
        original_path = self.test_orginal_pd.loc[
            self.test_orginal_pd["PatientID"]==patient_id, 
            "absPathOfFile"].item()
        self.original_img = nib.load(original_path).get_fdata().transpose(2, 0, 1)
        
        # 各模型预测数据
        self.model_preds = {}
        for model_name, _ in self.models.items():
            pred_path = self.pred_dfs[model_name].loc[
                self.pred_dfs[model_name]["PatientID"]==patient_id & 
                (self.pred_dfs[model_name]["ModalType"] == "pred"), 
                "absPathOfFile"].item()
            self.model_preds[model_name] = nib.load(pred_path).get_fdata().transpose(2, 0, 1)
            
        # Ground Truth
        gt_path = self.gt_mask_pd.loc[
            self.gt_mask_pd["PatientID"]==patient_id, 
            "absPathOfFile"].item()
        self.gt_img = nib.load(gt_path).get_fdata().transpose(2, 0, 1)
        
        return self.original_img.shape

    def plot_comparison(self, slice_index=80, axis=0, model_name=None):
        """绘制多模型比较视图"""
        model_name = model_name or self.current_model
        pred_img = self.model_preds[model_name]
        
        # 获取切片
        slicer = {
            0: lambda img, idx: img[idx, :, :],
            1: lambda img, idx: img[:, idx, :],
            2: lambda img, idx: img[:, :, idx]
        }
        
        original_slice = slicer[axis](self.original_img, slice_index)
        pred_slice = slicer[axis](pred_img, slice_index)
        gt_slice = slicer[axis](self.gt_img, slice_index)
        
        # 创建子图
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        
        # 原始图像+GT
        axes[0,0].imshow(original_slice, cmap='gray')
        axes[0,0].imshow(np.ma.masked_where(gt_slice==0, gt_slice), 
                        cmap=COLOR_MAP, norm=NORM, alpha=0.5)
        axes[0,0].set_title('Ground Truth')
        axes[0,0].axis('off')
        
        # 原始图像+预测
        axes[0,1].imshow(original_slice, cmap='gray')
        axes[0,1].imshow(np.ma.masked_where(pred_slice==0, pred_slice),
                        cmap=COLOR_MAP, norm=NORM, alpha=0.5)
        axes[0,1].set_title(f'{model_name} Prediction')
        axes[0,1].axis('off')
        
        # 差异图
        diff = np.where(pred_slice != gt_slice, 1, 0)
        axes[0,2].imshow(original_slice, cmap='gray')
        axes[0,2].imshow(diff, cmap='Reds', alpha=0.5)
        axes[0,2].set_title('Differences (Red=Wrong)')
        axes[0,2].axis('off')
        
        # 肿瘤区域细分
        tumor_types = {
            'Enhanced Tumor': (pred_slice == 3),
            'Tumor Core': (pred_slice == 1) | (pred_slice == 3),
            'Whole Tumor': (pred_slice >= 1)
        }
        
        for i, (title, mask) in enumerate(tumor_types.items()):
            ax = axes[1,i]
            ax.imshow(original_slice, cmap='gray')
            ax.imshow(mask, cmap='autumn', alpha=0.5)
            ax.set_title(title)
            ax.axis('off')
            
        plt.tight_layout()
        plt.show()
        

    def interactive_visualization(self, patient_ids):
        """创建交互式可视化界面"""
        def update(patient_id, model_name, slice_index=80, axis=0):
            if patient_id != self.current_patient:
                self.load_data(patient_id)
            self.plot_comparison(slice_index, axis, model_name)
            
        img_shape = self.load_data(patient_ids[0])
        max_slices = {0: img_shape[0]-1, 1: img_shape[1]-1, 2: img_shape[2]-1}
        
        interact(update,
                patient_id=Dropdown(options=patient_ids, description='Patient ID'),
                model_name=Dropdown(options=list(self.models.keys()), description='Model'),
                slice_index=IntSlider(min=0, max=max_slices[0], value=80),
                axis=IntSlider(min=0, max=2, value=0, description='Axis (0/1/2)'))




# 使用示例
if __name__ == "__main__":
    # 1. 初始化模型配置
    model_config = {
        "UNet3D": "path/to/unet/predictions",
        "AttUNet3D": "path/to/attunet/predictions",
        "DCLA_UNet": "path/to/dcla/predictions"
    }
    
    # 2. 创建可视化器实例
    visualizer = MultiModelVisualizer(model_config)
    
    # 3. 设置数据 (需预先加载CSV数据)
    visualizer.test_orginal_pd = test_orginal_pd  # 原始数据DataFrame
    visualizer.pred_dfs = {
        "UNet3D": process_pd(pd.read_csv('./csv/UNet3D_dir_anasys.csv')),
        "AttUNet3D": process_pd(pd.read_csv('./csv/AttUNet3D_dir_anasys.csv')),
        "DCLA_UNet": process_pd(pd.read_csv('./csv/DCLA_UNet_final_dir_anasys.csv'))
    }
    visualizer.gt_mask_pd = gt_mask_pd


In [None]:
    
# 4. 启动交互式可视化
patient_ids = list(test_orginal_pd["PatientID"].unique())
visualizer.interactive_visualization(patient_ids)

ValueError: can only convert an array of size 1 to a Python scalar