-
Notifications
You must be signed in to change notification settings - Fork 1
/
test_util.py
131 lines (115 loc) · 4.97 KB
/
test_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import h5py
import math
import nibabel as nib
import numpy as np
from medpy import metric
import torch
import torch.nn.functional as F
from tqdm import tqdm
from utils.util import converToSlice
def test_all_case(net, image_list, num_classes, patch_size=(192, 192, 96), stride_xy=18, stride_z=4, save_result=True, test_save_path=None, preproc_fn=None):
    """Evaluate ``net`` on every HDF5 case in ``image_list`` and return the mean metrics.

    Each file is opened read-only and its ``'image'`` and ``'label'``
    datasets are loaded into memory. The image is optionally transformed by
    ``preproc_fn``, then segmented with ``test_single_case``. The result is
    scored against the label with ``calculate_metric_percase``. A prediction
    with no foreground voxels is scored as ``(0, 0, 0, 0)``.

    Args:
        net: model forwarded to ``test_single_case``.
        image_list: sequence of HDF5 file paths.
        num_classes: number of output classes forwarded to ``test_single_case``.
        patch_size: sliding-window size forwarded to ``test_single_case``.
        stride_xy: window stride along the first two axes.
        stride_z: window stride along the third axis.
        save_result: when True, write prediction, image and label as NIfTI
            files named ``test_save_path + <file name> + suffix``.
        test_save_path: output path prefix (concatenated as a string, so it
            should end with a path separator); required when ``save_result``.
        preproc_fn: optional callable applied to each image before inference.

    Returns:
        The per-case metric tuples averaged over all cases, as a numpy array.

    Raises:
        ValueError: if ``image_list`` is empty, or if ``save_result`` is True
            but ``test_save_path`` is None.
    """
    # Validate up front instead of failing after a full (expensive) inference pass.
    if len(image_list) == 0:
        raise ValueError("image_list is empty; nothing to evaluate")
    if save_result and test_save_path is None:
        raise ValueError("test_save_path is required when save_result is True")

    total_metric = 0.0
    for image_path in tqdm(image_list):
        print("ssssss", image_path)
        case_id = image_path.split('/')[-1]
        # The context manager guarantees the HDF5 handle is closed; ``[:]``
        # reads each dataset into memory, so the arrays outlive the file.
        with h5py.File(image_path, 'r') as h5f:
            image = h5f['image'][:]
            label = h5f['label'][:]
        if preproc_fn is not None:
            image = preproc_fn(image)
            print("preprocess")
        prediction, _ = test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=num_classes)
        if np.sum(prediction) == 0:
            # Empty prediction: score it as zero rather than calling the
            # surface-distance metrics (medpy's hd95/asd reject empty masks).
            single_metric = (0, 0, 0, 0)
        else:
            single_metric = calculate_metric_percase(prediction, label)
        total_metric += np.asarray(single_metric)
        if save_result:
            outputs = (("_pred.nii.gz", prediction), ("_img.nii.gz", image), ("_gt.nii.gz", label))
            for suffix, volume in outputs:
                nib.save(nib.Nifti1Image(volume.astype(np.float32), np.eye(4)), test_save_path + case_id + suffix)
    avg_metric = total_metric / len(image_list)
    print('average metric is {}'.format(avg_metric))
    return avg_metric
def test(image,xs,ys,zs,patch_size,net,num_classes):
test_patch = image[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]]
test_patch = np.expand_dims(np.expand_dims(test_patch, axis=0), axis=0).astype(np.float32)
test_patch = torch.from_numpy(test_patch).cuda()
# hybrid
output3d = net(test_patch)
result = F.softmax(output3d, dim=1)
return result
def test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=1):
    """Segment a 3-D volume with sliding-window inference and overlap averaging.

    Any axis shorter than the patch is zero-padded. The padding is split
    between both sides, with the floor half on the left. Windows of
    ``patch_size`` are then taken with stride ``stride_xy`` along axes 0/1
    and ``stride_z`` along axis 2, and the last window on each axis is
    clamped to the volume edge. Each window's output from ``test`` is
    accumulated, and the sum is divided by the number of windows covering
    each voxel. Padding is cropped off before returning.

    Args:
        net: model forwarded to ``test``.
        image: 3-D numpy array of shape ``(w, h, d)``.
        stride_xy: window stride along axes 0 and 1.
        stride_z: window stride along axis 2.
        patch_size: ``(p0, p1, p2)`` window size.
        num_classes: number of output channels expected from ``test``.

    Returns:
        ``(label_map, score_map)``.
        - ``num_classes == 1``: ``score_map`` has shape ``(w, h, d)`` and
          ``label_map`` is that map thresholded at 0.5, as 0/1 floats.
        - Otherwise: ``score_map`` has shape ``(num_classes, w, h, d)`` and
          ``label_map`` is its argmax over the class axis.
    """
    w, h, d = image.shape
    print(w, h, d)

    # Per-axis (left, right) zero padding up to the patch size.
    pad_width = []
    for size, patch in zip((w, h, d), patch_size):
        total = max(patch - size, 0)
        pad_width.append((total // 2, total - total // 2))
    add_pad = any(left or right for left, right in pad_width)
    if add_pad:
        image = np.pad(image, pad_width, mode='constant', constant_values=0)
    ww, hh, dd = image.shape
    print(ww, hh, dd)

    # Number of window positions per axis (the last one is clamped to the edge).
    sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1
    sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1
    sz = math.ceil((dd - patch_size[2]) / stride_z) + 1
    print("{}, {}, {}".format(sx, sy, sz))

    score_map = np.zeros((num_classes, ) + image.shape, dtype=np.float32)
    cnt = np.zeros(image.shape, dtype=np.float32)
    for x in range(sx):
        xs = min(stride_xy * x, ww - patch_size[0])
        for y in range(sy):
            ys = min(stride_xy * y, hh - patch_size[1])
            for z in range(sz):
                zs = min(stride_z * z, dd - patch_size[2])
                window = (slice(xs, xs + patch_size[0]),
                          slice(ys, ys + patch_size[1]),
                          slice(zs, zs + patch_size[2]))
                result = test(image, xs, ys, zs, patch_size, net, num_classes)
                result = result.cpu().data.numpy()
                score_map[(slice(None),) + window] += result[0]
                cnt[window] += 1

    if num_classes == 1:
        score_map = score_map[0] / cnt
        label_map = (score_map > 0.5).astype(score_map.dtype)
    else:
        score_map = score_map / np.expand_dims(cnt, axis=0)
        label_map = np.argmax(score_map, axis=0)

    if add_pad:
        crop = tuple(slice(left, left + size) for (left, _), size in zip(pad_width, (w, h, d)))
        label_map = label_map[crop]
        # Ellipsis handles both the 3-D (single-class) and 4-D score maps;
        # a fixed leading ``:`` raised IndexError on the 3-D case.
        score_map = score_map[(Ellipsis,) + crop]
    return label_map, score_map
def cal_dice(prediction, label, num=2):
    """Return the Dice coefficient of each foreground class ``1 .. num-1``.

    For class ``i`` the score is ``2 * |P_i & L_i| / (|P_i| + |L_i|)``, where
    ``P_i`` and ``L_i`` are the voxels equal to ``i`` in ``prediction`` and
    ``label``. A class absent from both arrays scores 0.0, matching
    ``medpy.metric.binary.dc``, instead of the NaN a bare 0/0 would give.

    Args:
        prediction: integer label array.
        label: ground-truth label array, same shape as ``prediction``.
        num: total number of classes including background (class 0).

    Returns:
        numpy float array of length ``num - 1``; entry ``i - 1`` is class ``i``.
    """
    total_dice = np.zeros(num - 1)
    for i in range(1, num):
        pred_mask = prediction == i
        label_mask = label == i
        denominator = np.count_nonzero(pred_mask) + np.count_nonzero(label_mask)
        if denominator == 0:
            continue  # class absent from both: leave the score at 0.0
        total_dice[i - 1] = 2.0 * np.count_nonzero(pred_mask & label_mask) / denominator
    return total_dice
def calculate_metric_percase(pred, gt):
    """Score one predicted segmentation against its ground truth.

    Calls ``medpy.metric.binary`` for, in order, the Dice coefficient
    (``dc``), the Jaccard index (``jc``), the 95th-percentile Hausdorff
    distance (``hd95``) and the average surface distance (``asd``). It
    prints the Dice value and returns all four.

    Returns:
        tuple ``(dice, jc, hd95, asd)``.
    """
    binary_metrics = metric.binary
    scorers = (binary_metrics.dc, binary_metrics.jc, binary_metrics.hd95, binary_metrics.asd)
    scores = tuple(scorer(pred, gt) for scorer in scorers)
    print("dice: {}".format(scores[0]))
    return scores