-
Notifications
You must be signed in to change notification settings - Fork 81
/
model.py
209 lines (176 loc) · 8.48 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import tensorflow as tf
from six.moves import cPickle
# Load the network skeleton: an iterable of (parameter name, shape) pairs,
# iterated by DeepLabLFOVModel._create_variables to build every variable.
# NOTE(review): the path is relative to the process's current working
# directory, so the importing process must run from the directory that
# contains util/ — confirm this matches how the module is launched.
# NOTE(review): cPickle.load can execute arbitrary code; only load a trusted file.
with open("./util/net_skeleton.ckpt", "rb") as f:
    net_skeleton = cPickle.load(f)
# The DeepLab-LargeFOV model can be represented as follows:
## input -> [conv-relu](dilation=1, channels=64) x 2 -> [max_pool](stride=2)
## -> [conv-relu](dilation=1, channels=128) x 2 -> [max_pool](stride=2)
## -> [conv-relu](dilation=1, channels=256) x 3 -> [max_pool](stride=2)
## -> [conv-relu](dilation=1, channels=512) x 3 -> [max_pool](stride=1)
## -> [conv-relu](dilation=2, channels=512) x 3 -> [max_pool](stride=1) -> [avg_pool](stride=1)
## -> [conv-relu](dilation=12, channels=1024) -> [dropout]
## -> [conv-relu](dilation=1, channels=1024) -> [dropout]
## -> [conv](dilation=1, channels=21) -> [pixel-wise softmax loss].
## The final classification convolution is NOT followed by a ReLU.

# Dilation (atrous rate) of every convolution, one inner list per block.
# The last entry is the classification layer.
dilations = [
    [1] * 2,
    [1] * 2,
    [1] * 3,
    [1] * 3,
    [2] * 3,
    [12],
    [1],
    [1],
]
# Number of convolution layers in each block (the lengths of the lists above).
num_layers = list(map(len, dilations))

# Number of output classes (channels of the classification layer).
n_classes = 21

# Side length of the pooling windows (max_pool / avg_pool use ks x ks).
# With 'SAME' padding, stride-1 pooling preserves the spatial size, while
# stride-2 pooling roughly halves it. Convolution kernel sizes are not set
# here; they come from the shapes stored in net_skeleton.
ks = 3
def create_variable(name, shape):
    """Return a new float32 ``tf.Variable`` named ``name`` with shape ``shape``.

    The initial values are drawn from the Xavier (Glorot) initialiser for
    conv2d filters, see
    http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf.
    """
    xavier = tf.contrib.layers.xavier_initializer_conv2d(dtype=tf.float32)
    initial_value = xavier(shape=shape)
    return tf.Variable(initial_value, name=name)
def create_bias_variable(name, shape):
    """Return a new float32 ``tf.Variable`` named ``name`` of shape ``shape``, filled with zeros."""
    zeros = tf.constant_initializer(value=0.0, dtype=tf.float32)
    return tf.Variable(zeros(shape=shape), name=name)
class DeepLabLFOVModel(object):
    """DeepLab-LargeFOV model with atrous convolution and bilinear upsampling.

    This class implements a multi-layer convolutional neural network for the
    semantic image segmentation task. It follows the model described in
    https://arxiv.org/abs/1412.7062 - please look there for details.
    """

    def __init__(self, weights_path=None):
        """Create the model.

        Args:
          weights_path: the path to the ckpt file with a dictionary of weights
            from a .caffemodel. If None, all variables are initialised randomly.
        """
        self.variables = self._create_variables(weights_path)

    def _create_variables(self, weights_path):
        """Create all variables used by the network.

        Creating them once here allows them to be shared between multiple
        calls to preds() / loss().

        Args:
          weights_path: the path to the ckpt file with a dictionary of weights
            from a .caffemodel. If None, initialise all variables randomly.

        Returns:
          A list of tf.Variable objects in net_skeleton order.
          _create_network() reads the weight of layer k at index 2k and its
          bias at index 2k + 1.
        """
        variables = []
        if weights_path is not None:
            # NOTE(review): unpickling can execute arbitrary code; only load
            # weight files from a trusted source.
            with open(weights_path, "rb") as f:
                weights = cPickle.load(f)  # Load pre-trained weights.
            for name, _ in net_skeleton:
                variables.append(tf.Variable(weights[name], name=name))
            del weights  # Release the loaded weight dictionary.
        else:
            # Initialise all weight filters randomly with the Xavier scheme
            # and all biases to zero.
            for name, shape in net_skeleton:
                if "/w" in name:  # Weight filter.
                    variables.append(create_variable(name, list(shape)))
                else:
                    variables.append(create_bias_variable(name, list(shape)))
        return variables

    def _create_network(self, input_batch, keep_prob):
        """Construct the DeepLab-LargeFOV network.

        Args:
          input_batch: batch of pre-processed images (callers cast it to
            tf.float32; presumably NHWC layout - confirm).
          keep_prob: probability of keeping neurons intact in dropout.

        Returns:
          Raw (pre-softmax) class scores at reduced spatial resolution: the
          three stride-2 max-pools shrink each spatial side by roughly 8x.
        """
        current = input_batch
        v_idx = 0  # Index of the current (weight, bias) pair in self.variables.

        # Every block except the last one; the last block is the classifier.
        # `range` (not `xrange`) keeps this working on Python 3 as well.
        for b_idx in range(len(dilations) - 1):
            for dilation in dilations[b_idx]:
                w = self.variables[v_idx * 2]
                b = self.variables[v_idx * 2 + 1]
                if dilation == 1:
                    conv = tf.nn.conv2d(current, w, strides=[1, 1, 1, 1], padding='SAME')
                else:
                    conv = tf.nn.atrous_conv2d(current, w, dilation, padding='SAME')
                current = tf.nn.relu(tf.nn.bias_add(conv, b))
                v_idx += 1

            # Pooling or dropout after each block.
            if b_idx <= 4:
                # Blocks 0-2 downsample (stride 2); blocks 3-4 keep the
                # resolution (stride 1) so later atrous convs see denser maps.
                stride = 2 if b_idx < 3 else 1
                current = tf.nn.max_pool(current,
                                         ksize=[1, ks, ks, 1],
                                         strides=[1, stride, stride, 1],
                                         padding='SAME')
                if b_idx == 4:
                    current = tf.nn.avg_pool(current,
                                             ksize=[1, ks, ks, 1],
                                             strides=[1, 1, 1, 1],
                                             padding='SAME')
            elif b_idx <= 6:
                current = tf.nn.dropout(current, keep_prob=keep_prob)

        # Classification layer; no ReLU.
        w = self.variables[v_idx * 2]
        b = self.variables[v_idx * 2 + 1]
        conv = tf.nn.conv2d(current, w, strides=[1, 1, 1, 1], padding='SAME')
        return tf.nn.bias_add(conv, b)

    def prepare_label(self, input_batch, new_size):
        """Resize masks and perform one-hot encoding.

        Args:
          input_batch: input tensor of shape [batch_size, H, W, 1].
          new_size: a 1-D int32 tensor [new_height, new_width].

        Returns:
          A tensor of shape [batch_size, new_height, new_width, n_classes]
          whose last dimension contains only 0's and 1's.
        """
        with tf.name_scope('label_encode'):
            # Labels are integer class ids, so nearest-neighbour interpolation
            # is used to avoid blending ids together.
            input_batch = tf.image.resize_nearest_neighbor(input_batch, new_size)
            input_batch = tf.squeeze(input_batch, squeeze_dims=[3])  # Drop the channel dimension.
            input_batch = tf.one_hot(input_batch, depth=n_classes)
        return input_batch

    def preds(self, input_batch):
        """Create the network and run inference on the input batch.

        Args:
          input_batch: batch of pre-processed images.

        Returns:
          A uint8 tensor of shape [batch_size, H, W, 1] holding, per pixel,
          the argmax over class scores upsampled to the input's height/width.
        """
        raw_output = self._create_network(tf.cast(input_batch, tf.float32), keep_prob=tf.constant(1.0))
        # Upsample the low-resolution scores back to the input spatial size.
        raw_output = tf.image.resize_bilinear(raw_output, tf.shape(input_batch)[1:3])
        raw_output = tf.argmax(raw_output, dimension=3)
        raw_output = tf.expand_dims(raw_output, dim=3)  # Create 4D-tensor.
        return tf.cast(raw_output, tf.uint8)

    def loss(self, img_batch, label_batch):
        """Create the network, run inference on the input batch and compute loss.

        Args:
          img_batch: batch of pre-processed images.
          label_batch: batch of integer label masks of shape [batch_size, H, W, 1].

        Returns:
          Pixel-wise softmax cross-entropy loss, averaged over all pixels.
        """
        raw_output = self._create_network(tf.cast(img_batch, tf.float32), keep_prob=tf.constant(0.5))
        prediction = tf.reshape(raw_output, [-1, n_classes])

        # Resize labels to the (downsampled) output resolution and one-hot
        # encode them. tf.shape works even when the static spatial shape is
        # unknown, and avoids tf.pack (removed in TensorFlow 1.0).
        label_batch = self.prepare_label(label_batch, tf.shape(raw_output)[1:3])
        gt = tf.reshape(label_batch, [-1, n_classes])

        # Pixel-wise softmax loss. Keyword arguments make the logits/labels
        # roles explicit and are required by TensorFlow >= 1.0.
        loss = tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=gt)
        return tf.reduce_mean(loss)