/
res_unet.py
149 lines (110 loc) · 4.9 KB
/
res_unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
'''
=======================
Residual U-Net (ResUNet): a 2-D U-Net whose encoder, bridge and decoder
are built from residual blocks (res_block).
=======================
'''
import tensorflow as tf
import os
# Expose only GPU 1 to this process. This overrides any CUDA_VISIBLE_DEVICES
# value the caller already exported.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Let TensorFlow grow GPU memory on demand rather than reserving it up front.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.Session(config=config)
# Register the configured session with Keras. Without this call Keras lazily
# creates its own default session and the one above is never used for the model.
import keras.backend.tensorflow_backend as KTF
KTF.set_session(session)
# import packages
from functools import partial
import os
from keras.models import *
from keras.layers import Input, Conv2D, UpSampling2D, BatchNormalization, Activation, add, concatenate
from keras.optimizers import Adam
from keras import callbacks
from keras import backend as K
import keras.backend.tensorflow_backend as KTF
# import configurations
import configs
# All tensors in this module are (batch, rows, cols, channels); the decoder's
# concatenate(axis=3) relies on this.
K.set_image_data_format('channels_last')

# Volume geometry and class count, read once from the project configuration.
# NOTE(review): image_rows/image_cols/image_depth and BASE are not referenced
# in this module; presumably other modules import them from here - confirm.
image_rows, image_cols, image_depth = (
    configs.VOLUME_ROWS,
    configs.VOLUME_COLS,
    configs.VOLUME_DEPS,
)
num_classes = configs.NUM_CLASSES

# Patch-extraction parameters.
patch_size = configs.PATCH_SIZE
BASE = configs.BASE
# NOTE(review): dice_coef declares its own default smooth=1. and does not read
# this value - confirm which smoothing constant is intended.
smooth = configs.SMOOTH
# Dice similarity coefficient (DSC).
def dice_coef(y_true, y_pred, smooth=1.):
    """Return the soft Dice coefficient between two tensors.

    Both tensors are flattened first, so the score is pooled over every
    element of the batch rather than computed per sample and averaged.

    Args:
        y_true: ground-truth tensor.
        y_pred: predicted tensor; multiplied element-wise with ``y_true``
            after flattening, so it is expected to have the same shape.
        smooth: added to numerator and denominator so the ratio stays
            defined when both tensors sum to zero.

    Returns:
        Scalar tensor ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)
    overlap = K.sum(truth * guess)
    denominator = K.sum(truth) + K.sum(guess) + smooth
    return (2. * overlap + smooth) / denominator
# Proposed multi-class Dice loss.
def dice_coef_loss(y_true, y_pred):
    """Return the sum over classes of ``1 - dice_coef`` for that class.

    Each class is taken from index ``c`` of the fourth axis of both
    tensors, for ``c`` in ``range(num_classes)``. A perfect prediction
    scores 0. Assumes (batch, rows, cols, classes) layout - TODO confirm
    against the data pipeline.
    """
    return sum(
        1 - dice_coef(y_true[:, :, :, c], y_pred[:, :, :, c])
        for c in range(num_classes)
    )
# DSC for a single class.
def label_wise_dice_coefficient(y_true, y_pred, label_index):
    """Return dice_coef computed on class ``label_index`` only.

    The class is selected from the fourth axis of both (4-D) tensors.
    """
    true_channel = y_true[:, :, :, label_index]
    pred_channel = y_pred[:, :, :, label_index]
    return dice_coef(true_channel, pred_channel)
# Factory for per-class DSC metrics.
def get_label_dice_coefficient_function(label_index):
    """Return a metric function reporting the Dice score of one class.

    The returned callable takes ``(y_true, y_pred)`` and forwards to
    ``label_wise_dice_coefficient`` with ``label_index`` bound.

    Args:
        label_index: index of the class on the fourth axis of the tensors.

    Returns:
        A ``functools.partial`` whose ``__name__`` is
        ``'label_<index>_dice_coef'``.
    """
    f = partial(label_wise_dice_coefficient, label_index=label_index)
    # A bare partial has no __name__. Give it one so the metric gets a
    # readable, per-class name (used by Keras when logging metrics).
    f.__name__ = 'label_{0}_dice_coef'.format(label_index)
    return f
def res_block(x, nb_filters, strides):
    """Pre-activation residual block with a 1x1 projection shortcut.

    The residual branch is two units of BatchNorm -> ReLU -> 3x3 'same'
    conv. The shortcut is a 1x1 conv followed by BatchNorm. The two
    branches are summed.

    Args:
        x: input tensor.
        nb_filters: ``nb_filters[0]`` filters for the first 3x3 conv;
            ``nb_filters[1]`` for the second 3x3 conv and the shortcut.
        strides: ``strides[0]`` / ``strides[1]`` for the two 3x3 convs.
            The shortcut applies only ``strides[0]``, so ``strides[1]``
            should be ``(1, 1)`` for the branch shapes to match.

    Returns:
        The element-wise sum of the shortcut and residual branches.
    """
    branch = x
    # Two pre-activation conv units; layer creation order is kept identical.
    for unit in (0, 1):
        branch = BatchNormalization()(branch)
        branch = Activation(activation='relu')(branch)
        branch = Conv2D(filters=nb_filters[unit], kernel_size=(3, 3),
                        padding='same', strides=strides[unit])(branch)
    # Projection so channel count (and stride) match the residual branch.
    skip = Conv2D(nb_filters[1], kernel_size=(1, 1), strides=strides[0])(x)
    skip = BatchNormalization()(skip)
    return add([skip, branch])
def encoder(x):
    """Build the contracting path and return its skip-connection tensors.

    Stage 1 is a residual stem at input resolution (64 filters). It has no
    pre-activation, because the first conv sees the raw input. Stages 2
    and 3 are res_blocks with 128 and 256 filters, each using stride 2 on
    their first conv.

    Returns:
        List ``[stem, stage2, stage3]`` of the three stage outputs, from
        highest to lowest spatial resolution.
    """
    stem = Conv2D(filters=64, kernel_size=(3, 3), padding='same', strides=(1, 1))(x)
    stem = BatchNormalization()(stem)
    stem = Activation(activation='relu')(stem)
    stem = Conv2D(filters=64, kernel_size=(3, 3), padding='same', strides=(1, 1))(stem)
    projection = Conv2D(filters=64, kernel_size=(1, 1), strides=(1, 1))(x)
    projection = BatchNormalization()(projection)
    features = [add([projection, stem])]
    # Each further stage downsamples the previous stage's output.
    for width in (128, 256):
        features.append(res_block(features[-1], [width, width], [(2, 2), (1, 1)]))
    return features
def decoder(x, from_encoder):
    """Build the expanding path of the U-Net.

    Each stage upsamples by 2 and concatenates the matching encoder
    output on axis 3 (channels, given the channels_last setting). It then
    applies a stride-1 res_block. Encoder outputs are consumed from
    lowest to highest resolution.

    Args:
        x: bridge tensor (lowest resolution).
        from_encoder: the three-element list returned by ``encoder``.

    Returns:
        Decoded tensor at the resolution of ``from_encoder[0]`` with 64
        channels.
    """
    decoded = x
    for skip_index, width in ((2, 256), (1, 128), (0, 64)):
        decoded = UpSampling2D(size=(2, 2))(decoded)
        decoded = concatenate([decoded, from_encoder[skip_index]], axis=3)
        decoded = res_block(decoded, [width, width], [(1, 1), (1, 1)])
    return decoded
def build_res_unet(learning_rate=1e-4, include_label_wise_dice_coefficients=True):
    """Build and compile the residual U-Net.

    Pipeline: encoder (three stages), then a 512-filter stride-2 bridge
    res_block, then the decoder, then a 1x1 conv with softmax over
    ``num_classes`` channels.

    Args:
        learning_rate: Adam learning rate (default 1e-4, the previously
            hard-coded value).
        include_label_wise_dice_coefficients: when True and
            ``num_classes > 1``, also report one Dice metric per class.

    Returns:
        A compiled Keras ``Model`` with input shape
        ``(patch_size, patch_size, 1)``, categorical cross-entropy loss and
        ``dice_coef`` (plus optional per-class Dice) as metrics.

    Raises:
        ValueError: if ``patch_size`` is not a multiple of 8.
    """
    # The input is halved three times (two encoder stages and the bridge)
    # and doubled three times in the decoder. The skip concatenations only
    # line up when the patch size divides evenly by 2**3.
    if patch_size is not None and patch_size % 8 != 0:
        raise ValueError(
            'patch_size must be a multiple of 8 so encoder and decoder '
            'feature maps align, got {0}'.format(patch_size))

    inputs = Input((patch_size, patch_size, 1))
    to_decoder = encoder(inputs)
    path = res_block(to_decoder[2], [512, 512], [(2, 2), (1, 1)])
    path = decoder(path, from_encoder=to_decoder)
    path = Conv2D(filters=num_classes, kernel_size=(1, 1), activation='softmax')(path)
    model = Model(inputs=[inputs], outputs=[path])

    metrics = [dice_coef]
    if include_label_wise_dice_coefficients and num_classes > 1:
        metrics.extend(get_label_dice_coefficient_function(index)
                       for index in range(num_classes))

    # `lr` (not `learning_rate`) matches the Keras 2.x / TF1-era API this
    # module targets (it uses tf.ConfigProto / tf.Session).
    model.compile(optimizer=Adam(lr=learning_rate),
                  loss='categorical_crossentropy', metrics=metrics)
    return model