# Visualized Results


**可视化流程**

- 加载模型
- 加载数据
- 使用模型推理
- 可视化推理结果

In [1]:
import pandas as pd
import os
import torch
import shutil
import argparse
import numpy as np
from tqdm import tqdm
import nibabel as nib
from matplotlib import pyplot as plt
from tabulate import tabulate
from itertools import product
import shutil
from tabulate import tabulate
import time
from torch.nn import functional as F
from torch.optim import RMSprop, AdamW
from torch.amp import GradScaler, autocast
from torch.utils.data import DataLoader

from datasets.BraTS21 import BraTS21_3D
from datasets.transforms import Compose, FrontGroundNormalize, RandomCrop3D, ToTensor
from lossFunc import DiceLoss, CELoss
from metrics import *
from utils.logger_tools import custom_logger, get_current_date, get_current_time
from utils.ckpt_tools import load_checkpoint
from nnArchitecture.baselines.UNet3d import UNet3D
from nnArchitecture.baselines.AttentionUNet import AttentionUNet3D

from nnArchitecture.optimization_nets.DasppResAtteUNet import DasppResAtteUNet
from nnArchitecture.optimization_nets.ScgaResAtteUNet import ScgaResAtteUNet
from nnArchitecture.optimization_nets.AA_UNet import AAUNet

from nnArchitecture.ref_homo_nets.unetr import UNETR
from nnArchitecture.ref_homo_nets.unetrpp import UNETR_PP
from nnArchitecture.ref_homo_nets.segFormer3d import SegFormer3D

from nnArchitecture.ref_hetero_nets.Mamba3d import Mamba3d
from nnArchitecture.ref_hetero_nets.MogaNet import MogaNet


# Compute device for every model in this notebook: CUDA when available, otherwise CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

Importing from timm.models.layers is deprecated, please import via timm.layers


## 加载模型

In [2]:

def load_model(model_name, in_channels=4, out_channels=4):
    """Instantiate a segmentation network by name and move it to ``DEVICE``.

    Only the architecture is built here; no checkpoint is loaded.

    Args:
        model_name: One of the keys of ``builders`` below. Note the historical
            naming quirk: ``'unetr'`` builds ``UNETR`` while ``'UNETR'`` builds
            ``UNETR_PP``. ``'UNETR_PP'`` is accepted as an unambiguous alias for
            the latter; the legacy ``'UNETR'`` key is kept for existing callers.
        in_channels: Input channels passed to the constructor (default 4,
            presumably the four BraTS MRI modalities -- confirm).
        out_channels: Output channels passed to the constructor (default 4).

    Returns:
        The constructed model, moved to ``DEVICE``.

    Raises:
        ValueError: If ``model_name`` is not a known key.
    """
    # Lambdas keep construction lazy: only the requested architecture is built.
    builders = {
        'UNet3D': lambda: UNet3D(in_channels=in_channels, out_channels=out_channels),
        'AttentionUNet3D': lambda: AttentionUNet3D(in_channels=in_channels, out_channels=out_channels),
        'unetr': lambda: UNETR(in_channels=in_channels, out_channels=out_channels),
        # Legacy key: has always built UNETR_PP (not UNETR); kept for backward compatibility.
        'UNETR': lambda: UNETR_PP(in_channels=in_channels, out_channels=out_channels),
        'UNETR_PP': lambda: UNETR_PP(in_channels=in_channels, out_channels=out_channels),
        'SegFormer3D': lambda: SegFormer3D(in_channels=in_channels, out_channels=out_channels),
        'Mamba3d': lambda: Mamba3d(in_channels=in_channels, out_channels=out_channels),
        'MogaNet': lambda: MogaNet(in_channels=in_channels, out_channels=out_channels),
        'DasppResAtteUNet': lambda: DasppResAtteUNet(in_channels=in_channels, out_channels=out_channels),
        'ScgaResAtteUNet': lambda: ScgaResAtteUNet(in_channels=in_channels, out_channels=out_channels),
        'AAUNet': lambda: AAUNet(in_channels=in_channels, out_channels=out_channels),
    }
    try:
        builder = builders[model_name]
    except (KeyError, TypeError):
        # TypeError covers unhashable names, which the old if/elif chain also rejected with ValueError.
        raise ValueError(f"Unknown model name: {model_name}") from None

    return builder().to(DEVICE)

In [3]:

# Build the path of a NIfTI file written to a model's output folder.
def get_nii_path(out_dir, model_name, patient_idx, modal):
    """Return ``{out_dir}/{model_name}/outputs/BraTS2021_{patient_idx}/BraTS2021_{patient_idx}_{modal}.nii.gz``.

    ``modal`` is the file suffix used in this notebook (e.g. ``'flair'``, ``'seg'``, ``'pred'``).
    """
    case_id = f'BraTS2021_{patient_idx}'
    case_dir = f'{out_dir}/{model_name}/outputs/{case_id}'
    return f'{case_dir}/{case_id}_{modal}.nii.gz'

## 加载数据

In [4]:
def load_data(test_csv, local_train=True, test_length=10, batch_size=1, num_workers=4):
    """Build the (unshuffled) BraTS21 test DataLoader.

    Args:
        test_csv: CSV file passed to ``BraTS21_3D`` as ``data_file``.
        local_train: Forwarded to ``BraTS21_3D``.
        test_length: Forwarded to ``BraTS21_3D`` as ``length`` (presumably caps
            the number of cases -- confirm).
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes; 0 loads in the main process.

    Returns:
        A ``DataLoader`` over the test dataset with ``shuffle=False``.
    """
    # NOTE(review): (155, 240, 240) presumably equals the full BraTS volume, making the
    # random crop a no-op at test time -- confirm RandomCrop3D semantics.
    TransMethods_test = Compose([
        ToTensor(),
        RandomCrop3D(size=(155, 240, 240)),
        FrontGroundNormalize(),
    ])

    test_dataset = BraTS21_3D(
        data_file=test_csv,
        transform=TransMethods_test,
        local_train=local_train,
        length=test_length,
    )

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        # Pinned host memory only helps host->GPU copies; skip it on CPU-only machines.
        pin_memory=torch.cuda.is_available(),
        # Keeping workers alive saves re-initialisation, but DataLoader raises
        # ValueError if persistent_workers=True while num_workers == 0.
        persistent_workers=num_workers > 0,
    )

    print(f"已加载测试数据: {len(test_loader)}")
    return test_loader

## 滑窗推理

In [5]:
from optimized_inference import inference, load_model, load_data


# Paths: prefer the shared results drive and fall back to the repo-local output
# folder when that drive is not mounted.
csv_file = '/root/workspace/VoxelMedix/data/raw/brats21_original/test.csv'
out_dir = '/mnt/d/results'
out_dir = out_dir if os.path.exists(out_dir) else '../../output'

model_names = ['UNet3D', 'AttentionUNet3D']

# Test split, both as a raw table and as a DataLoader over the same CSV.
test_df = pd.read_csv(csv_file)
test_loader = load_data(csv_file)

# One network per compared architecture.
# NOTE(review): this cell's recorded output ends with
# "AssertionError: factor too big, channels // self.group > 4"; presumably one of
# these constructors raises it -- confirm which one.
unet_model = load_model('UNet3D')
attention_unet_model = load_model('AttentionUNet3D')
attention_unet_daspp_model = load_model('DasppResAtteUNet')
attention_unet_scga_model = load_model('ScgaResAtteUNet')
attention_unet_aa_model = load_model('AAUNet')

# TODO: update (add) model checkpoint paths.
# NOTE(review): these keys ('AttnUNet3D_*') differ from the names passed to load_model.
ckpt_path = {
    'UNet3D': f'{out_dir}/1_UNet3D/checkpoints/best@e117_UNet3D__diceloss0.1605_dice0.8397_2025-02-16_18-07-54_19.pth',
    'AttentionUNet3D': f'{out_dir}/2_AttentionUNet3D/checkpoints/best@e187_AttentionUNet3D__diceloss0.1476_dice0.8526_2025-02-15_22-19-03_21.pth',
    'AttnUNet3D_DenseASPP': f'{out_dir}/2_AttentionUNet3D_DenseASPP/checkpoints/best@e50_DasppResAtteUNet__diceloss0.1436_dice0.8567_2025-02-25_18-24-52_11.pth',
    'AttnUNet3D_SCGA': f'{out_dir}/2_AttentionUNet3D_SCGA/checkpoints/best@e53_ScgaResAtteUNet__diceloss0.1556_dice0.8447_2025-02-25_18-29-04_14.pth',
    'AttnUNet3D_AA': f'{out_dir}/3_AA_UNet/checkpoints/best@e136_AAUNet__diceloss0.1404_dice0.8599_2025-02-25_19-55-50_9.pth',
}


def cfg_generator(model, 
                  ckpt_path=None,
                  test_df=test_df, 
                  test_loader=test_loader, 
                  out_dir=out_dir, 
                  scaler=None, 
                  metricer=None):
    """Build the keyword-argument dict meant to be unpacked into ``inference(**cfg)``.

    Args:
        model: Network to evaluate; the AdamW optimizer is built over its parameters.
        ckpt_path: Checkpoint file passed through under ``'ckpt_path'``.
        test_df: Test table (defaults to the module-level ``test_df`` bound at definition time).
        test_loader: Test DataLoader (defaults to the module-level ``test_loader``).
        out_dir: Passed through as ``'output_root'`` (defaults to the module-level ``out_dir``).
        scaler: ``GradScaler`` to use; a fresh one is created for each call when None.
        metricer: ``EvaluationMetrics`` to use; a fresh one is created for each call when None.

    Returns:
        dict with keys ``test_df``, ``test_loader``, ``output_root``, ``model``,
        ``metricer``, ``scaler``, ``optimizer`` and ``ckpt_path``.
    """
    # Created per call: a `GradScaler()` / `EvaluationMetrics()` default in the signature
    # would be evaluated once and silently shared by every generated config.
    if scaler is None:
        scaler = GradScaler()
    if metricer is None:
        metricer = EvaluationMetrics()

    return {
        'test_df': test_df,
        'test_loader': test_loader,
        'output_root': out_dir,
        'model': model,
        'metricer': metricer,
        'scaler': scaler,
        'optimizer': AdamW(model.parameters(), lr=0.0001, betas=(0.9, 0.99), weight_decay=0.00001),
        'ckpt_path': ckpt_path,
    }

已加载测试数据: 10


AssertionError: factor too big, channels // self.group > 4

In [None]:
# 初始化配置
# unet_cfg = cfg_generator(unet_model, ckpt_path=ckpt_path['UNet3D'])
# attention_unet_cfg = cfg_generator(attention_unet_model, ckpt_path=ckpt_path['AttentionUNet3D'])
# attention_unet_denseaspp_cfg = cfg_generator(attention_unet_daspp_model, ckpt_path=ckpt_path['AttnUNet3D_DenseASPP'])
# attention_unet_scga_cfg = cfg_generator(attention_unet_scga_model, ckpt_path=ckpt_path['AttnUNet3D_SCGA'])
# attention_unet_aa_cfg = cfg_generator(attention_unet_aa_model, ckpt_path=ckpt_path['AttnUNet3D_AA'])

# 执行推理
# # unet_results = inference(**unet_cfg)
# attention_unet_results = inference(**attention_unet_cfg)
# attention_unet_denseaspp_results = inference(**attention_unet_denseaspp_cfg)
# attention_unet_scga_results = inference(**attention_unet_scga_cfg)
# attention_unet_aa_results = inference(**attention_unet_aa_cfg)


# attention_unet_config = {
#     'test_df': test_df,
#     'test_loader': test_loader,
#     'output_root': out_dir,
#     'model': attention_unet_model,
#     'metricer': EvaluationMetrics(),
#     'scaler': GradScaler(),
#     'optimizer': AdamW(attention_unet_model.parameters(), lr=0.0001, betas=(0.9, 0.99), weight_decay=0.00001),
#     'ckpt_path': f'{out_dir}/2_AttentionUNet3D/checkpoints/best@e187_AttentionUNet3D__diceloss0.1476_dice0.8526_2025-02-15_22-19-03_21.pth'
# }
# attention_unet_denseaspp_config = {
#     'test_df': test_df,
#     'test_loader': test_loader,
#     'output_root': out_dir,
#     'model': attention_unet_denseaspp_model,
#     'metricer': EvaluationMetrics(),
#     'scaler': GradScaler(),
#     'optimizer': AdamW(attention_unet_denseaspp_model.parameters(), lr=0.0001, betas=(0.9, 0.99), weight_decay=0.00001),
#     'ckpt_path': f'{out_dir}/2_AttentionUNet3D_DenseASPP/checkpoints/best@e50_DasppResAtteUNet__diceloss0.1436_dice0.8567_2025-02-25_18-24-52_11.pth'
# }
# attention_unet_scga_config = {
#     'test_df': test_df,
#     'test_loader': test_loader,
#     'output_root': out_dir,
#     'model': attention_unet_scga_model,
#     'metricer': EvaluationMetrics(),
#     'scaler': GradScaler(),
#     'optimizer': AdamW(attention_unet_scga_model.parameters(), lr=0.0001, betas=(0.9, 0.99), weight_decay=0.00001),
#     'ckpt_path': f'{out_dir}/2_AttentionUNet3D_SCGA/checkpoints/best@e53_ScgaResAtteUNet__diceloss0.1556_dice0.8447_2025-02-25_18-29-04_14.pth'
# }

# attention_aa_unet_cfg = {
#     'test_df': test_df,
#     'test_loader': test_loader,
#     'output_root': out_dir,
#     'model': attention_unet_aa_model,
#     'metricer': EvaluationMetrics(),
#     'scaler': GradScaler(),
#     'optimizer': AdamW(attention_unet_aa_model.parameters(), lr=0.0001, betas=(0.9, 0.99), weight_decay=0.00001),
#     'ckpt_path': f'{out_dir}/3_AA_UNet/checkpoints/best@e136_AAUNet__diceloss0.1404_dice0.8599_2025-02-25_19-55-50_9.pth'
#     }

# 执行推理
# unet_results = inference(**unet_config)
# attention_unet_results = inference(**attention_unet_config)
# attention_unet_denseaspp_results = inference(**attention_unet_denseaspp_config)
# attention_unet_scga_results = inference(**attention_unet_scga_config)
# attention_unet_aa_results = inference(**attention_aa_unet_cfg)

In [None]:
# Compare the parameters stored in the SCGA checkpoint with the current ScgaResAtteUNet
# definition. A dedicated variable name keeps the `ckpt_path` dict from the inference
# cell intact.
scga_ckpt_path = '/root/workspace/BraTS_Solution/results/best@e53_ScgaResAtteUNet__diceloss0.1556_dice0.8447_2025-02-25_18-29-04_14.pth'

checkpoint = torch.load(scga_ckpt_path, map_location='cpu')
saved_state = checkpoint['model_state_dict']
# Same constructor arguments load_model uses for this architecture.
current_model = ScgaResAtteUNet(in_channels=4, out_channels=4)
current_state = current_model.state_dict()
saved_keys = saved_state.keys()
current_keys = current_state.keys()

# In the checkpoint but not in the model (load_state_dict calls these "unexpected").
missing = [k for k in saved_keys if k not in current_keys]
# In the model but not in the checkpoint (load_state_dict calls these "missing").
unexpected = [k for k in current_keys if k not in saved_keys]
# Shared keys whose tensor shapes differ; a strict load_state_dict would also fail on these.
shape_mismatch = [
    (k, tuple(saved_state[k].shape), tuple(current_state[k].shape))
    for k in saved_keys
    if k in current_keys and saved_state[k].shape != current_state[k].shape
]

print("Missing in current model:", missing)
print("Unexpected in current model:", unexpected)
print("Shape mismatch (key, saved, current):", shape_mismatch)
    

## 结果可视化

In [None]:
import nibabel as nib

import numpy as np

from matplotlib import pyplot as plt

patient_idx = '00150'
# Prefer the shared results drive; fall back to the repo-local output folder when it
# is not mounted (the same rule the inference cell applies to `out_dir`).
output_dir = '/mnt/d/results'
if not os.path.exists(output_dir):
    output_dir = '../../output'

# Image and ground truth are read from the UNet3D output folder.
# NOTE(review): presumably every model folder holds identical flair/seg copies -- confirm.
nii_path = get_nii_path(output_dir, '1_UNet3D', patient_idx, 'flair')
gt_path = get_nii_path(output_dir, '1_UNet3D', patient_idx, 'seg')

# Prediction volume of each compared model for the same patient.
pred_paths = {
    'UNet3D': get_nii_path(output_dir, '1_UNet3D', patient_idx, 'pred'),
    'AttentionUNet3D': get_nii_path(output_dir, '2_AttentionUNet3D', patient_idx, 'pred'),
    'AttentionUNet3D_DenseASPP': get_nii_path(output_dir, '2_AttentionUNet3D_DenseASPP', patient_idx, 'pred'),
    'AttentionUNet3D_SCGA': get_nii_path(output_dir, '2_AttentionUNet3D_SCGA', patient_idx, 'pred'),
    'AA_UNet': get_nii_path(output_dir, '3_AA_UNet', patient_idx, 'pred')
}

In [None]:
def load_nii(nii_path, type='mask', trans_position=(0, 1, 2)):
    """Load a NIfTI volume and permute its axes.

    Args:
        nii_path: Path to the ``.nii`` / ``.nii.gz`` file.
        type: ``'mask'`` returns an ``np.uint8`` label array; ``'image'`` returns the
            floating-point data from ``get_fdata()`` unchanged.
        trans_position: Axis permutation passed to ``ndarray.transpose``.

    Returns:
        The transposed volume as a numpy array.

    Raises:
        ValueError: If ``type`` is neither ``'mask'`` nor ``'image'`` (checked before any file I/O).
    """
    if type not in ('mask', 'image'):
        raise ValueError('type must be mask or image')

    data = nib.load(nii_path).get_fdata().transpose(trans_position)
    if type == 'mask':
        # get_fdata() yields floats; round first so a value like 0.9999 is not
        # truncated to the wrong label by the uint8 cast.
        return np.rint(data).astype(np.uint8)
    return data

In [None]:
# All volumes get the same axis permutation so image, ground truth and predictions
# stay voxel-aligned when sliced.
volume_axes = (2, 0, 1)

nii_data = load_nii(nii_path, type='image', trans_position=volume_axes)
gt_data = load_nii(gt_path, type='mask', trans_position=volume_axes)

pred_dict = {}
for key, path in pred_paths.items():
    pred_dict[key] = pred_data = load_nii(path, type='mask', trans_position=volume_axes)
    print(f'{key} pred shape: {pred_data.shape}')

for volume in (nii_data, gt_data):
    print(volume.shape)


In [None]:
# print(np.unique(nii_data))
# Label values present in the ground truth, then in two of the model predictions.
print(np.unique(gt_data))
for model_key in ('UNet3D', 'AttentionUNet3D'):
    print(np.unique(pred_dict[model_key]))

In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm
import numpy as np

# Discrete overlay colormap with matching normalisation: BoundaryNorm bins the
# integer labels so 0 -> black, 1 -> red, 2 -> green, 3 -> blue.
color_list = [(0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1)]
cmap = ListedColormap(color_list, name='custom_discrete', N=len(color_list))
bounds = list(range(len(color_list) + 1))  # [0, 1, 2, 3, 4]: one bin per label
norm = BoundaryNorm(bounds, cmap.N)

def get_slices(nii_data, gt_data, pred_dict, slice_idx=80, axis=0):
    """Take the same 2-D slice from the image, the ground truth and every prediction.

    Args:
        nii_data: 3-D image array.
        gt_data: 3-D ground-truth array indexed like ``nii_data``.
        pred_dict: Mapping of model name -> 3-D prediction array indexed like ``nii_data``.
        slice_idx: Index along ``axis``.
        axis: Axis to slice along: 0, 1 or 2.

    Returns:
        ``(nii_slice, gt_slice, pred_slices)``, where ``pred_slices`` maps every key of
        ``pred_dict`` to its slice (for all three axes).

    Raises:
        ValueError: If ``axis`` is not 0, 1 or 2.
    """
    if axis not in (0, 1, 2):
        raise ValueError(f'axis must be 0, 1 or 2, got {axis!r}')

    # Equivalent to arr[slice_idx], arr[:, slice_idx] or arr[:, :, slice_idx].
    index = (slice(None),) * axis + (slice_idx,)
    nii_slice = nii_data[index]
    gt_slice = gt_data[index]
    pred_slices = {name: volume[index] for name, volume in pred_dict.items()}
    return nii_slice, gt_slice, pred_slices

def plot_slice(nii_slice, gt_slice, pred_slices, title: str, overlay_cmap=cmap, alpha: float = 1, axis=0):
    """Plot the ground truth and each model's prediction over the same grey-scale slice.

    Args:
        nii_slice: 2-D image slice drawn in grey under every panel.
        gt_slice: 2-D ground-truth label slice; label 0 is left transparent.
        pred_slices: Mapping of model name -> 2-D predicted label slice (one panel each).
        title: Figure heading.
        overlay_cmap: Colormap (or name) for the label overlays, used with the module-level ``norm``.
        alpha: Overlay opacity.
        axis: Axis the slices were taken along; only selects the figure size.

    Raises:
        ValueError: If ``axis`` is not 0, 1 or 2.
    """
    if axis == 0:
        figsize = (18, 10)
    elif axis in (1, 2):
        figsize = (10, 18)
    else:
        raise ValueError(f'axis must be 0, 1 or 2, got {axis!r}')

    # Ground truth first, then one panel per model, all drawn identically.
    panels = [('Ground Truth', gt_slice)]
    panels += [(f"{model_name} Prediction", overlay) for model_name, overlay in pred_slices.items()]

    # squeeze=False keeps a 2-D Axes array even when there is only one panel.
    fig, axes = plt.subplots(1, len(panels), figsize=figsize, squeeze=False)
    for ax, (panel_title, labels) in zip(axes[0], panels):
        ax.imshow(nii_slice, cmap='gray')  # grey-scale base image
        # Mask label 0 so only labelled voxels are coloured.
        ax.imshow(np.ma.masked_where(labels == 0, labels),
                  cmap=overlay_cmap, norm=norm, alpha=alpha)
        ax.set_title(panel_title)
        ax.axis('off')

    fig.suptitle(title)
    fig.tight_layout()
    # plt.show() renders the figure right away in a notebook; with the inline backend's
    # default settings the figure is then closed, so plot_volume's per-slice loop does
    # not accumulate open figures.
    plt.show()


def plot_volume(nii_data, gt_data, pred_dict, title: str, overlay_cmap: str = cmap, alpha: float = 1, axis=0):
    """Plot every slice of the volume along ``axis``, one ``plot_slice`` figure per slice.

    Each figure is titled ``'Slice {index} {title}'``; ``overlay_cmap``, ``alpha`` and
    ``axis`` are forwarded unchanged. This produces ``nii_data.shape[axis]`` figures.
    """
    n_slices = nii_data.shape[axis]
    for idx in range(n_slices):
        panels = get_slices(nii_data, gt_data, pred_dict, slice_idx=idx, axis=axis)
        plot_slice(*panels, title=f'Slice {idx} {title}', overlay_cmap=overlay_cmap, alpha=alpha, axis=axis)
        

# Slice indices for the single-slice views, along axes 0, 1 and 2 of the loaded arrays.
slice_idx_axis0, slice_idx_axis1, slice_idx_axis2 = 50, 170, 100

# Axis-0 view (title label 横断面 = axial).
nii_slice, gt_slice, pred_slices = get_slices(nii_data, gt_data, pred_dict, axis=0, slice_idx=slice_idx_axis0)
plot_slice(nii_slice, gt_slice, pred_slices, axis=0, title=f'Slice {slice_idx_axis0} Axis 横断面')

# Axis-1 view (title label 矢状面 = sagittal).
nii_slice, gt_slice, pred_slices = get_slices(nii_data, gt_data, pred_dict, axis=1, slice_idx=slice_idx_axis1)
plot_slice(nii_slice, gt_slice, pred_slices, axis=1, title=f'Slice {slice_idx_axis1} Axis 矢状面')

# Axis-2 view (title label 冠状面 = coronal), currently disabled.
# nii_slice, gt_slice, pred_slices = get_slices(nii_data, gt_data, pred_dict, slice_idx=slice_idx_axis2, axis=2)
# plot_slice(nii_slice, gt_slice, pred_slices, title=f'Slice {slice_idx_axis2} Axis 冠状面', axis=2)

# Every slice along axis 0, one figure per slice.
plot_volume(nii_data, gt_data, pred_dict, title='Volume')


In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm
import numpy as np

# Discrete overlay colormap with matching normalisation: BoundaryNorm bins the
# integer labels so 0 -> black, 1 -> red, 2 -> green, 3 -> blue.
color_list = [(0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1)]
cmap = ListedColormap(color_list, name='custom_discrete', N=len(color_list))
bounds = list(range(len(color_list) + 1))  # [0, 1, 2, 3, 4]: one bin per label
norm = BoundaryNorm(bounds, cmap.N)

def plot_slice(nii_slice, gt_slice, pred_slices, title: str, overlay_cmap: str = 'cool', alpha: float = 1, axis=0):
    """Show the ground truth and each model prediction over the same grey-scale slice.

    This cell re-defines ``plot_slice``. It accepts the same ``axis`` keyword as the
    earlier definition (used only to pick the figure size), so callers such as
    ``plot_volume`` still work after this cell has run.

    Args:
        nii_slice: 2-D image slice drawn in grey under every panel.
        gt_slice: 2-D ground-truth label slice; label 0 is left transparent.
        pred_slices: Mapping of model name -> 2-D predicted label slice (one panel each).
        title: Figure heading.
        overlay_cmap: Colormap (or name) for the overlays, used with the module-level
            ``norm``. NOTE(review): the ``'cool'`` default is paired with a 4-bin
            BoundaryNorm, so labels presumably map to nearly identical colours; pass
            the discrete ``cmap`` instead.
        alpha: Overlay opacity.
        axis: Axis the slices were taken along; 0 -> wide figure, 1/2 -> tall figure.

    Raises:
        ValueError: If ``axis`` is not 0, 1 or 2.
    """
    if axis == 0:
        figsize = (18, 10)
    elif axis in (1, 2):
        figsize = (10, 18)
    else:
        raise ValueError(f'axis must be 0, 1 or 2, got {axis!r}')

    # squeeze=False keeps the Axes array indexable even with no predictions (single panel).
    fig, axes = plt.subplots(1, len(pred_slices) + 1, figsize=figsize, squeeze=False)
    axes = axes[0]

    # Ground truth over the grey-scale image; label 0 is masked out.
    axes[0].imshow(nii_slice, cmap='gray')
    axes[0].imshow(np.ma.masked_where(gt_slice == 0, gt_slice),
                   cmap=overlay_cmap, norm=norm, alpha=alpha)
    axes[0].set_title('Ground Truth')

    # One panel per model prediction, drawn the same way.
    for ax, (model_name, overlay) in zip(axes[1:], pred_slices.items()):
        ax.imshow(nii_slice, cmap='gray')
        ax.imshow(np.ma.masked_where(overlay == 0, overlay),
                  cmap=overlay_cmap, norm=norm, alpha=alpha)
        ax.set_title(f"{model_name} Prediction")

    fig.suptitle(title)
    plt.tight_layout()
    plt.show()
    
# Derive the title from the plotted index; previously the title read 'Slice 100'
# while slice 50 was plotted.
slice_idx = 50
nii_slice, gt_slice, pred_slices = get_slices(nii_data, gt_data, pred_dict, slice_idx=slice_idx)
plot_slice(nii_slice, gt_slice, pred_slices, title=f'Slice {slice_idx}', overlay_cmap=cmap)

In [None]:
import ipywidgets as widgets
from IPython.display import display
import matplotlib.pyplot as plt

def create_slice_slider(nii_data, initial_slice=0):
    """Build an interactive viewer that scrolls through ``nii_data`` along axis 0.

    Args:
        nii_data: Array whose ``nii_data[i]`` is a 2-D slice.
        initial_slice: Index displayed first.

    Returns:
        A ``widgets.VBox`` holding the slider, a label with the current index, and
        ``fig.canvas``. NOTE(review): embedding ``fig.canvas`` as a widget presumably
        requires the ipympl backend (``%matplotlib widget``) -- confirm.
    """
    # Interactive controls.
    slider = widgets.IntSlider(
        value=initial_slice,
        min=0,
        max=nii_data.shape[0]-1,  # upper bound follows the data's first dimension
        step=1,
        description='Slice Index:',
        continuous_update=False  # redraw only when the slider is released (cheaper)
    )
    current_label = widgets.Label(value=f"当前切片：{initial_slice}")

    # Layout.
    slider.layout.width = '600px'
    controls = widgets.VBox([slider, current_label])

    # Initial image. Without an explicit vmin/vmax, imshow derives the grey-level window
    # from the initial slice only, and set_data() does not rescale it. Every other slice
    # would then be shown with that (possibly all-zero) window, so fix it to the whole
    # volume's range instead.
    fig, ax = plt.subplots(figsize=(8, 6))
    img = ax.imshow(nii_data[initial_slice], cmap='gray',
                    vmin=float(np.min(nii_data)), vmax=float(np.max(nii_data)))
    plt.close(fig)  # avoid displaying the initial image twice

    def update_slice(change):
        """Slider callback: show slice ``change['new']`` and refresh the label."""
        img.set_data(nii_data[change['new']])
        current_label.value = f"当前切片：{change['new']}"
        fig.canvas.draw_idle()

    slider.observe(update_slice, names='value')

    return widgets.VBox([controls, fig.canvas])