-
Notifications
You must be signed in to change notification settings - Fork 1
/
model.py
68 lines (60 loc) · 3.34 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from utils import *
import tensorflow as tf
def _se_upsample_concat(skip, below, filters, up_size):
    """Gate ``skip`` with an SE block, upsample ``below``, and concatenate them.

    The layers are created in the same order the original inline code used
    (``se_block`` first, then ``UpSampling2D``, then ``concatenate``).

    Args:
        skip: Skip-connection feature map passed through ``se_block``.
        below: Lower-resolution feature map to upsample.
        filters: Channel count forwarded to ``se_block``.
        up_size: Integer upsampling factor applied to both spatial axes.

    Returns:
        ``[upsampled below, SE-weighted skip]`` concatenated on axis 3
        (the channel axis for the channels_last tensors used here).
    """
    att = se_block(skip, filters)
    up = layers.UpSampling2D(size=(up_size, up_size),
                             data_format="channels_last")(below)
    return layers.concatenate([up, att], axis=3)


def SE_UResNet(input_shape, NUM_CLASSES=1, dropout_rate=0.0, batch_norm=True,
               *, filter_num=32, block_dropout=0.2, final_activation='sigmoid'):
    '''Build a U-Net-style model with SE-gated skips and an extra "W" stage.

    Architecture (as built below):
      * Encoder: five ``conv_block`` levels with ``filter_num`` x 1, 2, 4, 8, 16
        filters, separated by four 2x2 max-poolings.
      * "W" stage: upsample the bottleneck and concatenate it with the
        SE-weighted level-4 features. Then apply ``resb``, pool again, and apply
        a 16x ``conv_block``.
      * Decoder: four stages, each doing SE-weighted skip + 2x upsampling +
        channel concatenation + ``conv_block``.
      * Head: 1x1 ``Conv2D`` -> ``BatchNormalization`` -> ``final_activation``.

    NOTE(review): assuming ``conv_block``/``resb`` preserve spatial size
    (not visible here), the input height and width must be divisible by 16.
    Otherwise the upsampled tensors will not match their skips at concatenation.

    Args:
        input_shape: Per-sample shape given to ``layers.Input``. It must be
            channels-last, because concatenation uses axis 3.
        NUM_CLASSES: Number of output channels of the final 1x1 convolution.
        dropout_rate: Passed as the 4th argument of the UpRes 7 ``conv_block``
            only. NOTE(review): this looks unintentional; every other block
            uses ``block_dropout``. It is kept as-is for compatibility.
        batch_norm: Forwarded to every ``conv_block`` / ``resb`` call.
        filter_num: Base filter count of the first level (was hard-coded 32).
        block_dropout: 4th argument of every other ``conv_block`` / ``resb``
            call (was hard-coded 0.2). Presumably a dropout rate; confirm
            against ``utils``.
        final_activation: Keras activation applied to the output (default
            'sigmoid'; e.g. 'softmax' for mutually exclusive multi-class
            output).

    Returns:
        A Keras ``models.Model`` named "Attention_UWNet".
    '''
    FILTER_SIZE = 3  # spatial size of the convolution kernels
    UP_SAMP_SIZE = 2  # upsampling factor, matching the 2x2 pooling

    inputs = layers.Input(input_shape, dtype=tf.float32)

    # Encoder. The 5th conv_block argument looks like a block index
    # (possibly used for layer naming) -- TODO confirm in utils.
    conv_128 = conv_block(inputs, FILTER_SIZE, filter_num,
                          block_dropout, 1, batch_norm)
    pool_64 = layers.MaxPooling2D(pool_size=(2, 2))(conv_128)
    conv_64 = conv_block(pool_64, FILTER_SIZE, 2 * filter_num,
                         block_dropout, 2, batch_norm)
    pool_32 = layers.MaxPooling2D(pool_size=(2, 2))(conv_64)
    conv_32 = conv_block(pool_32, FILTER_SIZE, 4 * filter_num,
                         block_dropout, 3, batch_norm)
    pool_16 = layers.MaxPooling2D(pool_size=(2, 2))(conv_32)
    conv_16 = conv_block(pool_16, FILTER_SIZE, 8 * filter_num,
                         block_dropout, 4, batch_norm)
    pool_8 = layers.MaxPooling2D(pool_size=(2, 2))(conv_16)
    # Bottleneck: convolution only, no pooling afterwards.
    conv_8 = conv_block(pool_8, FILTER_SIZE, 16 * filter_num,
                        block_dropout, 5, batch_norm)

    # "W" stage: go up one level with a residual block, then back down.
    upw_16 = _se_upsample_concat(conv_16, conv_8, 8 * filter_num, UP_SAMP_SIZE)
    up_convw_16 = resb(upw_16, FILTER_SIZE, 8 * filter_num,
                       block_dropout, 6, batch_norm)
    poolw_8 = layers.MaxPooling2D(pool_size=(2, 2))(up_convw_16)
    convw_16 = conv_block(poolw_8, FILTER_SIZE, 16 * filter_num,
                          block_dropout, 7, batch_norm)

    # Decoder: SE-gated skip + upsampling + conv_block at each level.
    # UpRes 6 uses the W-stage output (not the encoder's conv_16) as its skip.
    up_16 = _se_upsample_concat(up_convw_16, convw_16, 8 * filter_num,
                                UP_SAMP_SIZE)
    up_conv_16 = conv_block(up_16, FILTER_SIZE, 8 * filter_num,
                            block_dropout, 8, batch_norm)
    # UpRes 7: the only block that receives the caller's dropout_rate.
    up_32 = _se_upsample_concat(conv_32, up_conv_16, 4 * filter_num,
                                UP_SAMP_SIZE)
    up_conv_32 = conv_block(up_32, FILTER_SIZE, 4 * filter_num,
                            dropout_rate, 9, batch_norm)
    # UpRes 8
    up_64 = _se_upsample_concat(conv_64, up_conv_32, 2 * filter_num,
                                UP_SAMP_SIZE)
    up_conv_64 = conv_block(up_64, FILTER_SIZE, 2 * filter_num,
                            block_dropout, 10, batch_norm)
    # UpRes 9
    up_128 = _se_upsample_concat(conv_128, up_conv_64, filter_num,
                                 UP_SAMP_SIZE)
    up_conv_128 = conv_block(up_128, FILTER_SIZE, filter_num,
                             block_dropout, 11, batch_norm)

    # Output head: 1x1 projection to NUM_CLASSES channels.
    conv_final = layers.Conv2D(NUM_CLASSES, kernel_size=(1, 1))(up_conv_128)
    conv_final = layers.BatchNormalization(axis=3)(conv_final)
    conv_final = layers.Activation(final_activation)(conv_final)

    model = models.Model(inputs, conv_final, name="Attention_UWNet")
    return model