# 计算测试指标（Dice & HD）
## Diff-nnUNet对20个内部测试集的预测结果
- 修改label（去掉肝脏label）
- 保留最大连通域
- 计算Dice：赛方给的代码
- 计算HD：赛方给的代码

导入库

In [1]:
import os
import SimpleITK as sitk
import numpy as np
from medpy.metric import binary
import glob
import nibabel as nib

保留最大连通域的函数

In [2]:
def connected_domain(itk_mask):
    """
    获取mask中最大连通域
    :param  itk_mask: SimpleITK.Image
    :return:最大连通域  res_itk
    :return:连通域个数  num_connected_label
    :return:每个连通域的体积  each_area
    """

    cc_filter = sitk.ConnectedComponentImageFilter()
    cc_filter.SetFullyConnected(True)
    output_mask = cc_filter.Execute(itk_mask)

    lss_filter = sitk.LabelShapeStatisticsImageFilter()
    lss_filter.Execute(output_mask)

    num_connected_label = cc_filter.GetObjectCount()  # 获取连通域个数

    area_max_label = 0  # 最大的连通域的label
    area_max = 0

    each_area = [lss_filter.GetNumberOfPixels(i) for i in range(1, num_connected_label + 1)]
    # 连通域label从1开始，0表示背景
    for i in range(1, num_connected_label + 1):
        area = lss_filter.GetNumberOfPixels(i)

          # 根据label获取连通域面积
        if area > area_max:
            area_max_label = i
            area_max = area
    
    np_output_mask = sitk.GetArrayFromImage(output_mask)
  
    res_mask = np.zeros_like(np_output_mask)
    res_mask[np_output_mask == area_max_label] = 1

    res_itk = sitk.GetImageFromArray(res_mask)
    res_itk.SetOrigin(itk_mask.GetOrigin())
    res_itk.SetSpacing(itk_mask.GetSpacing())
    res_itk.SetDirection(itk_mask.GetDirection())

    return  num_connected_label, res_itk

计算dice的函数

In [3]:
def cal_dice(seg, gt, classes=2, background_id=0):
    channel_dice = []
    for i in range(classes):
        if i == background_id:
            continue
        cond = i ** 2
        # 计算相交部分
        inter = len(np.where(seg * gt == cond)[0])
        total_pix = len(np.where(seg == i)[0]) + len(np.where(gt == i)[0])
        if total_pix == 0:
            dice = 0
        else:
            dice = (2 * inter) / total_pix
        channel_dice.append(dice)
    return np.array(channel_dice)

计算测试集DSC和HD的函数

In [4]:
def cal_DSC_HD(pre_path,GT_path):
    
    name_list = os.listdir(pre_path)
    Dice={}
    HD={}
    sum_result_dc = np.zeros((1,1)) 
    sum_result_HD=0
    
    for name in name_list:
        print(name)

        prediction_path = os.path.join(pre_path , name)
        gt_path = os.path.join(GT_path , name)
        _, save_biggest_img= connected_domain(sitk.Cast(sitk.ReadImage(prediction_path), sitk.sitkInt32))
        print("     get biggest connected_domain")

        predict = sitk.GetArrayFromImage(save_biggest_img)
        target = sitk.GetArrayFromImage(sitk.ReadImage(gt_path))

        each_dice = cal_dice(predict, target, classes=2, background_id=0)
        sum_result_dc += each_dice
        Dice[name[0:3]] = [round(i,4) for i in each_dice]
        print("     DSC done")

        try:
            tmp = binary.hd95(target, predict)
            sum_result_HD += tmp 
            HD[name[0:3]] = round(tmp,4) 
        except:
            HD[name[0:3]] = 0
            continue
        print("     HD done")
        print()
    
    avg_result_HD = sum_result_HD / len(name_list)
    avg_result_dice = sum_result_dc / len(name_list)

    var_HD = np.var(list(HD.values()))
    var_DSC = np.var(list(Dice.values()))

    return avg_result_dice, avg_result_HD, Dice, HD, var_HD, var_DSC

创建文件夹

In [5]:
def ensure_folder_exists(folder_path):
    # 检查文件夹路径是否存在
    if not os.path.exists(folder_path):
        # 如果不存在，则创建文件夹
        os.makedirs(folder_path)
        print(f"文件夹{folder_path}已创建。")
    else:
        print(f"文件夹{folder_path}已存在。")

修改 label

In [8]:
def process_and_save_nii_files(source_folder, target_folder):
    nii_files = glob.glob(os.path.join(source_folder, "*.nii.gz"))

    for file_path in nii_files:
        nii_data = nib.load(file_path)
        data = nii_data.get_fdata()

        # 将数值1置零，将数值2置一
        data[data == 1] = 0
        data[data == 2] = 1

        # 创建新的NIfTI图像对象
        new_nii = nib.Nifti1Image(data, affine=nii_data.affine, header=nii_data.header)

        # 构建目标文件路径
        target_file_path = os.path.join(target_folder, os.path.basename(file_path))

        # 保存处理后的文件
        nib.save(new_nii, target_file_path)
        # print(f"Processed file saved to {target_file_path}")

ori_pre_path = "/root/wjz_diff/test_case/predict"
pre_path = "/root/wjz_diff/test_case/predict_processed"
ensure_folder_exists(pre_path)
process_and_save_nii_files(ori_pre_path, pre_path)

ori_GT_path = "/root/wjz_diff/test_case/GT"
GT_path = "/root/wjz_diff/test_case/GT_processed"
ensure_folder_exists(GT_path)
process_and_save_nii_files(ori_GT_path, GT_path)

文件夹/root/wjz_diff/test_case/predict_processed已创建。
文件夹/root/wjz_diff/test_case/GT_processed已创建。


计算指标

In [6]:
pre_path = "/root/wjz_diff/test_case/predict_processed"
GT_path = "/root/wjz_diff/test_case/GT_processed"
avg_result_dice, avg_result_HD, Dice, HD, var_HD, var_DSC = cal_DSC_HD(pre_path, GT_path)

212.nii.gz
     get biggest connected_domain
     DSC done
     HD done

039.nii.gz
     get biggest connected_domain
     DSC done
     HD done

055.nii.gz
     get biggest connected_domain
     DSC done
     HD done

265.nii.gz
     get biggest connected_domain
     DSC done
     HD done

030.nii.gz
     get biggest connected_domain
     DSC done
     HD done

103.nii.gz
     get biggest connected_domain
     DSC done
     HD done

029.nii.gz
     get biggest connected_domain
     DSC done
     HD done

072.nii.gz
     get biggest connected_domain
     DSC done
     HD done

176.nii.gz
     get biggest connected_domain
     DSC done
     HD done

238.nii.gz
     get biggest connected_domain
     DSC done
     HD done

234.nii.gz
     get biggest connected_domain
     DSC done
     HD done

177.nii.gz
     get biggest connected_domain
     DSC done
     HD done

215.nii.gz
     get biggest connected_domain
     DSC done
     HD done

028.nii.gz
     get biggest connected_domain
     D

In [7]:
print("指标: ", "DSC为", avg_result_dice, "±", round(var_DSC,4), "HD为", round(avg_result_HD), "±", round(var_HD,4))
print("每个测试集的DSC为", Dice)
print("每个测试集的HD为", HD)

指标:  DSC为 [[0.78857059]] ± 0.0512 HD为 15 ± 796.0245
每个测试集的DSC为 {'212': [0.9422], '039': [0.7895], '055': [0.8723], '265': [0.9485], '030': [0.85], '103': [0.866], '029': [0.8871], '072': [0.9113], '176': [0.7692], '238': [0.9518], '234': [0.4657], '177': [0.8294], '215': [0.9446], '028': [0.0], '203': [0.6688], '042': [0.4888], '014': [0.9092], '063': [0.8915], '125': [0.8506], '207': [0.935]}
每个测试集的HD为 {'212': 7.0, '039': 6.1644, '055': 2.2361, '265': 4.1231, '030': 3.4641, '103': 8.8318, '029': 1.0, '072': 1.7321, '176': 7.6158, '238': 4.0, '234': 96.1439, '177': 9.8995, '215': 5.3852, '028': 101.4776, '203': 10.247, '042': 6.7601, '014': 2.0, '063': 4.0, '125': 7.0711, '207': 2.4495}
