# Calculate the metrics for the validation set

In [2]:
import os

# Pair every "<case>-groundtruth.nii.gz" in the working directory with the
# matching "<case>-predict.nii.gz" and collect the pairs in ``test_dict``.
niigz_files = os.listdir()
niigz_files.sort()

groundtruth_suffix = "-groundtruth.nii.gz"
predict_suffix = "-predict.nii.gz"

# Match on the exact suffix. A plain substring test ("groundtruth" in name)
# would also accept unrelated files whose names merely contain the word,
# creating duplicate or non-existent case entries.
groundtruth_files = sorted(f for f in niigz_files if f.endswith(groundtruth_suffix))
predict_files = sorted(f for f in niigz_files if f.endswith(predict_suffix))

# Strip the suffix (rather than split on "-groundtruth") so a case name that
# itself contains "-groundtruth" is not truncated.
case_names = sorted(f[: -len(groundtruth_suffix)] for f in groundtruth_files)
print(case_names)

# Report ground-truth cases without a prediction up front; otherwise the
# problem only surfaces later as a read error inside the metrics loop.
available_predictions = set(predict_files)
missing_predictions = [
    name + predict_suffix
    for name in case_names
    if name + predict_suffix not in available_predictions
]
if missing_predictions:
    print("WARNING: missing prediction files: {0}".format(missing_predictions))

test_dict = {
    name: {
        "groundtruth": name + groundtruth_suffix,
        "predict": name + predict_suffix,
    }
    for name in case_names
}

print(len(test_dict))

['ATM_029_0000', 'ATM_054_0000', 'ATM_055_0000', 'ATM_057_0000', 'ATM_091_0000', 'ATM_174_0000', 'ATM_215_0000', 'ATM_505_0000', 'ATM_688_0000']
9


In [3]:
import SimpleITK as sitk
import numpy as np

def load_CT_scan_3D_image(niigz_file_name):
    """Read a volume with SimpleITK and return it as NumPy data.

    Args:
        niigz_file_name: Path of the image file to read.

    Returns:
        A tuple ``(voxels, origin, spacing)``. ``voxels`` is the array from
        ``sitk.GetArrayFromImage``. ``origin`` and ``spacing`` are the image's
        origin and spacing with their axis order reversed, so they follow the
        array's index order rather than SimpleITK's (x, y, z) order.
    """
    image = sitk.ReadImage(niigz_file_name)
    voxels = sitk.GetArrayFromImage(image)
    # SimpleITK reports origin/spacing as (x, y, z); flip to match the array.
    origin = np.array(image.GetOrigin()[::-1])
    spacing = np.array(image.GetSpacing()[::-1])
    return voxels, origin, spacing

def false_positive_rate_calculation(pred, label, smooth=1e-5):
    """Return the false-positive rate of a binary prediction, in percent.

    FPR = FP / (number of background voxels in ``label``) * 100, rounded to
    3 decimals. ``smooth`` is added to both numerator and denominator so an
    all-foreground label does not divide by zero. This also means a
    prediction with no false positives reports a tiny non-zero value before
    rounding.

    Args:
        pred: Prediction mask of any shape. Values are expected to be 0/1
            (bool masks are accepted).
        label: Ground-truth mask with the same number of elements as ``pred``.
        smooth: Small constant guarding against division by zero.

    Returns:
        The FPR as a percentage rounded to 3 decimals.
    """
    # Work in float64. NumPy raises TypeError for ``-`` on boolean arrays,
    # and unsigned-integer masks could wrap around on subtraction. For 0/1
    # integer masks the result is identical to the integer computation.
    pred = np.asarray(pred, dtype=np.float64).ravel()
    label = np.asarray(label, dtype=np.float64).ravel()
    fp = np.sum(pred - pred * label) + smooth
    fpr = round(fp * 100 / (np.sum(1.0 - label) + smooth), 3)
    return fpr

def false_negative_rate_calculation(pred, label, smooth=1e-5):
    """Return the false-negative rate of a binary prediction, in percent.

    FNR = FN / (number of foreground voxels in ``label``) * 100, rounded to
    3 decimals. ``smooth`` is added to both numerator and denominator so an
    empty label does not divide by zero.

    Args:
        pred: Prediction mask of any shape. Values are expected to be 0/1
            (bool masks are accepted).
        label: Ground-truth mask with the same number of elements as ``pred``.
        smooth: Small constant guarding against division by zero.

    Returns:
        The FNR as a percentage rounded to 3 decimals.
    """
    # Work in float64. NumPy raises TypeError for ``-`` on boolean arrays,
    # and unsigned-integer masks could wrap around on subtraction. For 0/1
    # integer masks the result is identical to the integer computation.
    pred = np.asarray(pred, dtype=np.float64).ravel()
    label = np.asarray(label, dtype=np.float64).ravel()
    fn = np.sum(label - pred * label) + smooth
    fnr = round(fn * 100 / (np.sum(label) + smooth), 3)
    return fnr

def sensitivity_calculation(pred, label):
    """Return the sensitivity (true-positive rate) in percent.

    Computed as ``100 - false_negative_rate_calculation(pred, label)``,
    rounded to 3 decimals.
    """
    fnr_percent = false_negative_rate_calculation(pred, label)
    return round(100 - fnr_percent, 3)

def dice_coefficient_score_calculation(pred, label, smooth=1e-5):
    """Return the Dice similarity coefficient of two masks, in percent.

    DSC = (2 * |pred AND label| + smooth) / (|pred| + |label| + smooth) * 100,
    rounded to 2 decimals. With ``smooth`` present, two empty masks score 100.
    """
    flat_pred = pred.ravel()
    flat_label = label.ravel()
    overlap = np.sum(flat_pred * flat_label)
    numerator = 2.0 * overlap + smooth
    denominator = np.sum(flat_pred) + np.sum(flat_label) + smooth
    return round((numerator / denominator) * 100, 2)

def precision_calculation(pred, label, smooth=1e-5):
    """Return the precision (positive predictive value) in percent.

    Precision = (TP + smooth) / (|pred| + smooth) * 100, rounded to 3
    decimals. With ``smooth`` present, an empty prediction scores 100.
    """
    flat_pred = pred.ravel()
    flat_label = label.ravel()
    true_positives = np.sum(flat_pred * flat_label) + smooth
    predicted_positives = np.sum(flat_pred) + smooth
    return round(true_positives * 100 / predicted_positives, 3)


In [15]:
# Evaluate each case in ``test_dict``. Print one LaTeX table row per case
# ("case & FPR & FNR & Sensitivity & Precision & DSC"), then the mean row.

# NOTE(review): these cases are deliberately skipped, but the reason is not
# recorded here; document it. The hard-coded BD/TLD lists below must stay
# in sync with this set.
EXCLUDED_CASES = {"ATM_174_0000", "ATM_505_0000"}

FPR_list = []
FNR_list = []
Sensitivity_list = []
Precision_list = []
DSC_list = []

for case_name, files in test_dict.items():
    if case_name in EXCLUDED_CASES:
        continue
    # Only the voxel arrays are needed; origin and spacing are discarded.
    gt_npy, _, _ = load_CT_scan_3D_image(files["groundtruth"])
    pred_npy, _, _ = load_CT_scan_3D_image(files["predict"])

    FPR = false_positive_rate_calculation(pred_npy, gt_npy)
    FNR = false_negative_rate_calculation(pred_npy, gt_npy)
    Sensitivity = sensitivity_calculation(pred_npy, gt_npy)
    Precision = precision_calculation(pred_npy, gt_npy)
    DSC = dice_coefficient_score_calculation(pred_npy, gt_npy)

    print("{0} & {1} & {2} & {3} & {4} & {5}"
          .format(case_name, FPR, FNR, Sensitivity, Precision, DSC))

    FPR_list.append(FPR)
    FNR_list.append(FNR)
    Sensitivity_list.append(Sensitivity)
    Precision_list.append(Precision)
    DSC_list.append(DSC)

# Per-case BD and TLD values, hard-coded in the evaluation order above.
# NOTE(review): presumably produced by a separate airway tree-metric tool
# (branches detected / tree length detected); confirm the source.
BD_list = [76.92,
           76.16,
           81.17,
           78.35,
           85.46,
           72.25,
           93.94]

TLD_list = [85.85,
            83.93,
            89.09,
            87.98,
            91.99,
            84.32,
            96.31]

# The hard-coded lists are tied to the evaluated cases only by position.
# Fail loudly (e.g. after EXCLUDED_CASES changes) instead of printing a
# summary row averaged over a different set of cases.
if not (len(BD_list) == len(TLD_list) == len(DSC_list)):
    raise ValueError(
        "BD_list ({0}), TLD_list ({1}) and evaluated cases ({2}) differ in length"
        .format(len(BD_list), len(TLD_list), len(DSC_list)))

print("mean & {0} & {1} & {2} & {3} & {4}"
      .format(np.mean(FPR_list),
              np.mean(FNR_list),
              np.mean(Sensitivity_list),
              np.mean(Precision_list),
              np.mean(DSC_list)))

print("{0} & {1}".format(np.mean(BD_list), np.mean(TLD_list)))

ATM_029_0000 & 0.022 & 9.266 & 90.734 & 91.012 & 90.87
ATM_054_0000 & 0.037 & 3.499 & 96.501 & 92.319 & 94.36
ATM_055_0000 & 0.039 & 3.453 & 96.547 & 89.855 & 93.08
ATM_057_0000 & 0.039 & 4.683 & 95.317 & 90.426 & 92.81
ATM_091_0000 & 0.041 & 3.57 & 96.43 & 89.583 & 92.88
ATM_215_0000 & 0.021 & 10.595 & 89.405 & 92.453 & 90.9
ATM_688_0000 & 0.028 & 2.109 & 97.891 & 92.182 & 94.95
mean & 0.03242857142857143 & 5.310714285714286 & 94.68928571428572 & 91.11857142857143 & 92.83571428571429
80.60714285714286 & 88.49571428571429
