### 计算测试指标（Dice & HD）
#### 原始nnUNet对20个内部测试集的预测结果
- 保留最大连通域
- 计算Dice：赛方给的代码
- 计算HD（实际计算的是95%豪斯多夫距离HD95，调用 medpy 的 `binary.hd95`）：赛方给的代码

导入库

In [1]:
import os
import SimpleITK as sitk
import numpy as np
from medpy.metric import binary
import glob
import nibabel as nib

保留最大连通域的函数

In [2]:
def connected_domain(itk_mask):
    """
    Keep only the largest connected component of a segmentation mask.

    Components are labelled with SimpleITK's ConnectedComponentImageFilter in
    fully-connected mode (neighbours touching by face, edge or corner belong
    to the same component). The component with the most voxels is kept; on a
    tie the one with the smallest label wins.

    :param itk_mask: SimpleITK.Image with an integer pixel type; non-zero
        voxels are treated as foreground by the labelling filter
    :return: tuple ``(num_connected_label, res_itk)``
        - num_connected_label: number of connected components found
        - res_itk: SimpleITK.Image with value 1 inside the largest component
          and 0 elsewhere, carrying the origin, spacing and direction of
          ``itk_mask``. It is all zeros when the input has no foreground.
    """
    cc_filter = sitk.ConnectedComponentImageFilter()
    cc_filter.SetFullyConnected(True)
    labeled_mask = cc_filter.Execute(itk_mask)

    num_connected_label = cc_filter.GetObjectCount()  # number of components
    np_labeled_mask = sitk.GetArrayFromImage(labeled_mask)

    if num_connected_label == 0:
        # No foreground at all. Without this guard the "largest label" would
        # stay 0 (the background label) and the whole background would be
        # returned as foreground.
        res_mask = np.zeros_like(np_labeled_mask)
    else:
        lss_filter = sitk.LabelShapeStatisticsImageFilter()
        lss_filter.Execute(labeled_mask)

        # Component labels start at 1; 0 is the background.
        # max() returns the first maximal label, i.e. the smallest label on ties.
        area_max_label = max(
            range(1, num_connected_label + 1),
            key=lss_filter.GetNumberOfPixels,
        )
        res_mask = (np_labeled_mask == area_max_label).astype(np_labeled_mask.dtype)

    # Going through numpy drops the spatial metadata, so restore it.
    res_itk = sitk.GetImageFromArray(res_mask)
    res_itk.SetOrigin(itk_mask.GetOrigin())
    res_itk.SetSpacing(itk_mask.GetSpacing())
    res_itk.SetDirection(itk_mask.GetDirection())

    return num_connected_label, res_itk

计算dice的函数

In [3]:
def cal_dice(seg, gt, classes=2, background_id=0):
    """
    Compute the Dice coefficient of every non-background class.

    :param seg: predicted label map (array-like of integer class ids)
    :param gt: ground-truth label map, same shape as ``seg``
    :param classes: number of classes; class ids ``0 .. classes - 1`` are scored
    :param background_id: class id that is skipped
    :return: 1-D numpy array with one Dice value per scored class, in
        increasing class-id order. A class absent from both ``seg`` and ``gt``
        gets a Dice of 0.
    """
    seg = np.asarray(seg)
    gt = np.asarray(gt)

    channel_dice = []
    for i in range(classes):
        if i == background_id:
            continue
        seg_is_i = seg == i
        gt_is_i = gt == i
        # Count voxels where BOTH maps carry class i. Comparing the product
        # seg * gt with i**2 would also match unrelated label pairs
        # (e.g. 1 * 4 == 2**2) and can wrap around for small integer dtypes.
        inter = np.count_nonzero(seg_is_i & gt_is_i)
        total_pix = np.count_nonzero(seg_is_i) + np.count_nonzero(gt_is_i)
        if total_pix == 0:
            dice = 0.0
        else:
            dice = (2 * inter) / total_pix
        channel_dice.append(dice)

    return np.array(channel_dice)

计算测试集DSC和HD的函数

In [4]:
def cal_DSC_HD(pre_path, GT_path):
    """
    Score every prediction in ``pre_path`` against the same-named file in ``GT_path``.

    For each case the prediction is cast to int32 and reduced to its largest
    connected component (``connected_domain``). Then the foreground Dice
    (``cal_dice`` with ``classes=2``) and the HD95 (``binary.hd95`` from medpy)
    are computed against the ground truth.

    NOTE(review): no ``voxelspacing`` is passed to ``binary.hd95``, so the
    distances are presumably in voxel units, not millimetres. Confirm this
    matches the challenge definition.
    NOTE(review): results are keyed by the first three characters of the file
    name, so files sharing a 3-character prefix overwrite each other.

    :param pre_path: directory containing the predicted segmentation files
    :param GT_path: directory containing ground-truth files with the same names
    :return: tuple ``(avg_result_dice, avg_result_HD, Dice, HD, var_HD, var_DSC)``
        - avg_result_dice: numpy array of shape (1, 1), mean Dice over all cases
        - avg_result_HD: mean HD95 over all cases; a case whose HD95 could not
          be computed contributes 0
        - Dice: dict mapping case key -> list with the Dice rounded to 4 decimals
        - HD: dict mapping case key -> HD95 rounded to 4 decimals (0 on failure)
        - var_HD, var_DSC: variance (``np.var``) of the per-case values
    :raises ValueError: if ``pre_path`` contains no files
    """
    # Sort for a reproducible processing order and skip sub-directories
    # (e.g. ".ipynb_checkpoints"), which cannot be read as images.
    name_list = sorted(
        name for name in os.listdir(pre_path)
        if os.path.isfile(os.path.join(pre_path, name))
    )
    if not name_list:
        raise ValueError("no prediction files found in " + repr(pre_path))

    Dice = {}
    HD = {}
    sum_result_dc = np.zeros((1, 1))
    sum_result_HD = 0

    for name in name_list:
        print(name)
        case_key = name[0:3]

        prediction_path = os.path.join(pre_path, name)
        gt_path = os.path.join(GT_path, name)
        _, save_biggest_img = connected_domain(
            sitk.Cast(sitk.ReadImage(prediction_path), sitk.sitkInt32)
        )
        print("     get biggest connected_domain")

        predict = sitk.GetArrayFromImage(save_biggest_img)
        target = sitk.GetArrayFromImage(sitk.ReadImage(gt_path))

        each_dice = cal_dice(predict, target, classes=2, background_id=0)
        sum_result_dc += each_dice
        Dice[case_key] = [round(i, 4) for i in each_dice]
        print("     DSC done")

        try:
            tmp = binary.hd95(target, predict)
        except Exception as exc:
            # Best effort: keep the historical fallback of 0 so the averages
            # are computed the same way as before, but report the failure
            # instead of hiding it.
            HD[case_key] = 0
            print("     HD failed, recorded as 0:", repr(exc))
        else:
            sum_result_HD += tmp
            HD[case_key] = round(tmp, 4)
            print("     HD done")
        print()

    avg_result_HD = sum_result_HD / len(name_list)
    avg_result_dice = sum_result_dc / len(name_list)

    var_HD = np.var(list(HD.values()))
    var_DSC = np.var(list(Dice.values()))

    return avg_result_dice, avg_result_HD, Dice, HD, var_HD, var_DSC

计算指标

In [5]:
# Directory with the predictions to score; every file name found here is also
# looked up under GT_path.
# NOTE(review): absolute machine-specific paths — adjust before running elsewhere.
pre_path = "/root/wjz_diff/test_case/raw_pre20"
# Directory with the ground-truth segmentations (same file names as pre_path).
GT_path = "/root/wjz_diff/test_case/GT_raw"
# Run the evaluation; the results stay in the notebook namespace for the report cell.
avg_result_dice, avg_result_HD, Dice, HD, var_HD, var_DSC = cal_DSC_HD(pre_path, GT_path)

103.nii.gz
     get biggest connected_domain
     DSC done
     HD done

030.nii.gz
     get biggest connected_domain
     DSC done
     HD done

176.nii.gz
     get biggest connected_domain
     DSC done
     HD done

238.nii.gz
     get biggest connected_domain
     DSC done
     HD done

072.nii.gz
     get biggest connected_domain
     DSC done
     HD done

029.nii.gz
     get biggest connected_domain
     DSC done
     HD done

055.nii.gz
     get biggest connected_domain
     DSC done
     HD done

265.nii.gz
     get biggest connected_domain
     DSC done
     HD done

212.nii.gz
     get biggest connected_domain
     DSC done
     HD done

039.nii.gz
     get biggest connected_domain
     DSC done
     HD done

063.nii.gz
     get biggest connected_domain
     DSC done
     HD done

042.nii.gz
     get biggest connected_domain
     DSC done
     HD done

014.nii.gz
     get biggest connected_domain
     DSC done
     HD done

207.nii.gz
     get biggest connected_domain
     D

In [6]:
# Summary line: mean Dice and mean HD (rounded to an integer).
# NOTE(review): the values printed after "±" are var_DSC / var_HD, i.e. the
# variances returned by cal_DSC_HD, not standard deviations.
print("指标: ", "DSC为", avg_result_dice, "±", round(var_DSC,4), "HD为", round(avg_result_HD), "±", round(var_HD,4))
# Per-case values, keyed by the case key built in cal_DSC_HD.
print("每个测试集的DSC为", Dice)
print("每个测试集的HD为", HD)

指标:  DSC为 [[0.7309063]] ± 0.0967 HD为 49 ± 14071.7262
每个测试集的DSC为 {'103': [0.8503], '030': [0.7176], '176': [0.0042], '238': [0.9633], '072': [0.905], '029': [0.8952], '055': [0.8752], '265': [0.9587], '212': [0.9468], '039': [0.5085], '063': [0.8846], '042': [0.0004], '014': [0.9265], '207': [0.9279], '125': [0.1937], '028': [0.4652], '215': [0.9508], '203': [0.8924], '234': [0.8584], '177': [0.8932]}
每个测试集的HD为 {'103': 10.8628, '030': 16.4012, '176': 355.1535, '238': 3.3166, '072': 1.4142, '029': 1.0, '055': 2.0, '265': 4.0, '212': 6.4807, '039': 75.012, '063': 5.9161, '042': 443.2878, '014': 2.0, '207': 3.0, '125': 39.2301, '028': 4.1231, '215': 4.2426, '203': 2.2361, '234': 2.0, '177': 8.1854}
