/
tf_models.py
135 lines (88 loc) · 5.92 KB
/
tf_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import tensorflow as tf
from tensorflow.contrib.keras.python.keras.layers import Conv3D, MaxPooling3D, UpSampling3D, Activation, Conv3DTranspose
from tf_layers import *
def PlainCounterpart(input, name):
    """Chain thirteen ``Conv3DWithBN`` stages and expose two intermediate taps.

    Every stage is called with ``ksize=3``, ``strides=1`` and
    ``padding='same'``. The requested filter count starts at 24 and increases
    by 12 per stage: 24..96 over the first seven stages, 108..168 over the
    remaining six.

    Args:
        input: Tensor passed to the first stage.
        name: Prefix for the stage names. Stages are named
            ``<name>_conv_15rf_<k>x`` for k = 1..7, then
            ``<name>_conv_27rf_<k>x`` for k = 1..6.

    Returns:
        A tuple ``(out_15rf, out_27rf)``: the output of the seventh stage and
        the output of the thirteenth (last) stage.
    """
    # NOTE(review): the '15rf' / '27rf' tags presumably stand for receptive
    # fields of 15 and 27 voxels (1 + 2 per 3-wide convolution) -- confirm
    # against the Conv3DWithBN implementation.
    # Each group is (name tag, filters of its first stage, number of stages).
    stage_groups = (('15rf', 24, 7), ('27rf', 108, 6))

    features = input
    taps = []
    for tag, first_filters, depth in stage_groups:
        for idx in range(depth):
            features = Conv3DWithBN(
                features,
                filters=first_filters + 12 * idx,
                ksize=3,
                strides=1,
                padding='same',
                name=name + '_conv_' + tag + '_' + str(idx + 1) + 'x',
            )
        # Remember the output at the end of each group.
        taps.append(features)

    out_15rf, out_27rf = taps
    return out_15rf, out_27rf
def BraTS2ScaleDenseNetConcat(input, name):
    """Build a two-stage dense feature extractor and return both feature taps.

    The input goes through an initial ``Conv3D`` (24 filters, kernel 3) and two
    successive ``DenseNetUnit3D`` blocks (``growth_rate=12``, ``rep=6``). After
    each block a side branch applies BatchNormalization, ReLU and a
    ``Conv3DWithBN`` with ``ksize=1`` (96 filters after the first block, 168
    after the second). The trunk itself continues from the raw block output,
    not from the side branch.

    Args:
        input: Tensor fed to the initial convolution.
        name: Prefix used for every named layer created here.

    Returns:
        A tuple ``(out_15rf, out_27rf)`` holding the side-branch output after
        the first and after the second dense block.
    """

    def _tap(features, filters, suffix):
        # Side branch: BatchNormalization -> ReLU -> Conv3DWithBN (ksize=1).
        normed = BatchNormalization(center=True, scale=True)(features)
        activated = Activation('relu')(normed)
        return Conv3DWithBN(activated, filters=filters, ksize=1, strides=1,
                            name=name + suffix)

    trunk = Conv3D(filters=24, kernel_size=3, strides=1, padding='same',
                   name=name + '_conv_init')(input)

    trunk = DenseNetUnit3D(trunk, growth_rate=12, ksize=3, rep=6,
                           name=name + '_denseblock1')
    out_15rf = _tap(trunk, 96, '_out_15_postconv')

    trunk = DenseNetUnit3D(trunk, growth_rate=12, ksize=3, rep=6,
                           name=name + '_denseblock2')
    out_27rf = _tap(trunk, 168, '_out_27_postconv')

    return out_15rf, out_27rf
def BraTS2ScaleDenseNetConcat_large(input, name):
    """Wider variant of the two-stage dense feature extractor.

    Layout: an initial ``Conv3D`` with 48 filters (kernel 3), a first
    ``DenseNetUnit3D`` block (``growth_rate=12``, ``rep=6``) and a second one
    (``growth_rate=24``, ``rep=6``). After each block a side branch applies
    BatchNormalization, ReLU and a ``Conv3DWithBN`` with ``ksize=1`` (192 and
    336 filters respectively). The trunk continues from the raw block output.

    Args:
        input: Tensor fed to the initial convolution.
        name: Prefix used for every named layer created here.

    Returns:
        A tuple ``(out_15rf, out_27rf)`` holding the side-branch output after
        the first and after the second dense block.
    """
    # NOTE(review): the first dense block uses growth_rate=12 while the second
    # uses 24. If DenseNetUnit3D adds `growth_rate` channels per repetition
    # (assumption -- the helper is not visible here), growth_rate=24 would give
    # 48 + 6 * 24 = 192 channels, matching the 192-filter tap below. Confirm
    # that 12 is intentional before changing it: existing checkpoints depend
    # on the current value.

    def _tap(features, filters, suffix):
        # Side branch: BatchNormalization -> ReLU -> Conv3DWithBN (ksize=1).
        normed = BatchNormalization(center=True, scale=True)(features)
        activated = Activation('relu')(normed)
        return Conv3DWithBN(activated, filters=filters, ksize=1, strides=1,
                            name=name + suffix)

    trunk = Conv3D(filters=48, kernel_size=3, strides=1, padding='same',
                   name=name + '_conv_init')(input)

    trunk = DenseNetUnit3D(trunk, growth_rate=12, ksize=3, rep=6,
                           name=name + '_denseblock1')
    out_15rf = _tap(trunk, 192, '_out_15_postconv')

    trunk = DenseNetUnit3D(trunk, growth_rate=24, ksize=3, rep=6,
                           name=name + '_denseblock2')
    out_27rf = _tap(trunk, 336, '_out_27_postconv')

    return out_15rf, out_27rf
def BraTS2ScaleDenseNet(input, num_labels):
    """Build the two-scale dense network and return the cropped label scores.

    Two ``DenseNetUnit3D`` blocks (``growth_rate=12``, ``rep=6``) follow an
    initial 24-filter ``Conv3D``. After each block a side branch
    (BatchNormalization -> ReLU -> ``Conv3DWithBN`` with ``ksize=1``) feeds a
    ``Conv3D`` with ``num_labels`` filters and kernel size 1. Both score maps
    are cropped to indices 13..24 along axes 1-3 and added together.

    Args:
        input: Tensor fed to the initial convolution. The final crop indexes
            five axes and needs at least 25 entries along axes 1-3;
            presumably laid out as (batch, x, y, z, channels) -- confirm
            against the caller.
        num_labels: Number of filters of the two scoring convolutions.

    Returns:
        The element-wise sum of the two cropped score tensors.
    """

    def _tap(features, filters, layer_name):
        # Side branch: BatchNormalization -> ReLU -> Conv3DWithBN (ksize=1).
        normed = BatchNormalization(center=True, scale=True)(features)
        activated = Activation('relu')(normed)
        return Conv3DWithBN(activated, filters=filters, ksize=1, strides=1,
                            name=layer_name)

    trunk = Conv3D(filters=24, kernel_size=3, strides=1, padding='same')(input)

    trunk = DenseNetUnit3D(trunk, growth_rate=12, ksize=3, rep=6)
    out_15rf = _tap(trunk, 96, 'out_15_postconv')

    trunk = DenseNetUnit3D(trunk, growth_rate=12, ksize=3, rep=6)
    out_27rf = _tap(trunk, 168, 'out_27_postconv')

    # One scoring convolution per scale, created in the same order as the taps.
    score_15rf, score_27rf = [
        Conv3D(num_labels, kernel_size=1, strides=1, padding='same')(features)
        for features in (out_15rf, out_27rf)
    ]

    # Same subscript as [:, 13:25, 13:25, 13:25, :], built once and reused.
    center = (slice(None),) + (slice(13, 25),) * 3 + (slice(None),)
    return score_15rf[center] + score_27rf[center]
def BraTS3ScaleDenseNet(input, num_labels):
    """Build the three-scale dense network and return the cropped label scores.

    Three ``DenseNetUnit3D`` blocks (``growth_rate=12``, ``rep=5``) follow an
    initial 24-filter ``Conv3D``. After each block a side branch
    (BatchNormalization -> ReLU -> ``Conv3DWithBN`` with ``ksize=1`` and 84,
    144 or 204 filters) feeds a ``Conv3D`` with ``num_labels`` filters and
    kernel size 1. The three score maps are cropped to indices 16..27 along
    axes 1-3 and added together.

    Args:
        input: Tensor fed to the initial convolution. The final crop indexes
            five axes and needs at least 28 entries along axes 1-3;
            presumably laid out as (batch, x, y, z, channels) -- confirm
            against the caller.
        num_labels: Number of filters of the three scoring convolutions.

    Returns:
        The element-wise sum of the three cropped score tensors.
    """

    def _tap(features, filters, layer_name):
        # Side branch: BatchNormalization -> ReLU -> Conv3DWithBN (ksize=1).
        normed = BatchNormalization(center=True, scale=True)(features)
        activated = Activation('relu')(normed)
        return Conv3DWithBN(activated, filters=filters, ksize=1, strides=1,
                            name=layer_name)

    trunk = Conv3D(filters=24, kernel_size=3, strides=1, padding='same')(input)

    trunk = DenseNetUnit3D(trunk, growth_rate=12, ksize=3, rep=5)
    out_13rf = _tap(trunk, 84, 'out_13_postconv')

    trunk = DenseNetUnit3D(trunk, growth_rate=12, ksize=3, rep=5)
    out_23rf = _tap(trunk, 144, 'out_23_postconv')

    trunk = DenseNetUnit3D(trunk, growth_rate=12, ksize=3, rep=5)
    out_33rf = _tap(trunk, 204, 'out_33_postconv')

    # One scoring convolution per scale, created in the same order as the taps.
    score_13rf, score_23rf, score_33rf = [
        Conv3D(num_labels, kernel_size=1, strides=1, padding='same')(features)
        for features in (out_13rf, out_23rf, out_33rf)
    ]

    # Same subscript as [:, 16:28, 16:28, 16:28, :], built once and reused.
    center = (slice(None),) + (slice(16, 28),) * 3 + (slice(None),)
    return score_13rf[center] + score_23rf[center] + score_33rf[center]
def BraTS1ScaleDenseNet(input, num_labels):
    """Build the single-scale dense network and return the cropped label scores.

    An initial ``Conv3D`` (36 filters, kernel 5) is followed by one
    ``DenseNetUnit3D`` block (``growth_rate=18``, ``rep=6``), then
    BatchNormalization, ReLU and two ``Conv3DWithBN`` layers (144 filters,
    ``ksize=1``). A final ``Conv3D`` with ``num_labels`` filters and kernel
    size 1 produces the scores, which are cropped to indices 8..19 along
    axes 1-3.

    Args:
        input: Tensor fed to the initial convolution. The final crop indexes
            five axes and needs at least 20 entries along axes 1-3;
            presumably laid out as (batch, x, y, z, channels) -- confirm
            against the caller.
        num_labels: Number of filters of the scoring convolution.

    Returns:
        The cropped score tensor.
    """
    trunk = Conv3D(filters=36, kernel_size=5, strides=1, padding='same')(input)
    trunk = DenseNetUnit3D(trunk, growth_rate=18, ksize=3, rep=6)

    head = BatchNormalization(center=True, scale=True)(trunk)
    head = Activation('relu')(head)
    # Two stacked ksize=1 post-convolutions; their layer names carry 'out_17'.
    for layer_name in ('out_17_postconv1', 'out_17_postconv2'):
        head = Conv3DWithBN(head, filters=144, ksize=1, strides=1,
                            name=layer_name)

    logits = Conv3D(num_labels, kernel_size=1, strides=1, padding='same')(head)

    # Same subscript as [:, 8:20, 8:20, 8:20, :].
    center = (slice(None),) + (slice(8, 20),) * 3 + (slice(None),)
    return logits[center]