Update to TF v2 #7

Open · wants to merge 1 commit into base: master
nets/classification_net.py (11 changes: 8 additions & 3 deletions)
@@ -10,7 +10,9 @@ class ClassificationNet(keras.Model):
     def __init__(self, num_class, **kwargs):
         super().__init__(self, **kwargs)
         # classification net
-        self.conv1 = DeformableConvLayer(32, [5, 5], num_deformable_group=1, activation='relu')  # out 24
+        self.conv1 = DeformableConvLayer(32, [5, 5],
+                                         num_deformable_group=1,
+                                         activation='relu')  # out 24
         # self.conv1 = Conv2D(32, [5, 5], activation='relu')
         self.conv2 = Conv2D(32, [5, 5], activation='relu')  # out 20
         self.max_pool1 = MaxPool2D(2, [2, 2])  # out 10
@@ -34,12 +36,15 @@ def call(self, inputs, training=None, mask=None):
     def train(self, optimizer, x, y):
         with tf.GradientTape() as tape:
            logits = self.__call__(x)
-            loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits)
+            loss = tf.nn.softmax_cross_entropy_with_logits(labels=y,
+                                                           logits=logits)
             loss = tf.reduce_mean(loss)
         grads = tape.gradient(loss, self.variables)
         optimizer.apply_gradients(zip(grads, self.variables))
         return loss, tf.nn.softmax(logits)

     def accuracy(self, prediction, y):
-        eq = tf.to_float(tf.equal(tf.argmax(prediction, axis=-1), tf.argmax(y, axis=-1)))
+        eq = tf.cast(
+            tf.equal(tf.argmax(prediction, axis=-1), tf.argmax(y, axis=-1)),
+            tf.float32)
         return tf.reduce_mean(eq)
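
Note: the two substitutions in this file follow the standard TF 1.x to 2.x migration: tf.nn.softmax_cross_entropy_with_logits_v2 lost its _v2 suffix, and tf.to_float(x) became tf.cast(x, tf.float32). A minimal sketch of the replacements in isolation (toy values, chosen only for illustration):

import tensorflow as tf

# Hypothetical toy batch: 2 examples, 3 classes.
y_true = tf.constant([[0., 0., 1.], [1., 0., 0.]])
logits = tf.constant([[0.1, 0.2, 3.0], [2.0, 0.3, 0.1]])

# TF 1.x: tf.nn.softmax_cross_entropy_with_logits_v2; TF 2.x drops the suffix.
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=logits))

# TF 1.x: tf.to_float(x)  ->  TF 2.x: tf.cast(x, tf.float32)
matches = tf.equal(tf.argmax(logits, axis=-1), tf.argmax(y_true, axis=-1))
accuracy = tf.reduce_mean(tf.cast(matches, tf.float32))
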
nets/deformable_conv_layer.py (160 changes: 91 additions & 69 deletions)
@@ -33,68 +33,68 @@ def __init__(self,
         :param num_deformable_group: split output channels into groups, offset shared in each group. If
             this parameter is None, then set num_deformable_group=filters.
         """
-        super().__init__(
-            filters=filters,
-            kernel_size=kernel_size,
-            strides=strides,
-            padding=padding,
-            data_format=data_format,
-            dilation_rate=dilation_rate,
-            activation=activation,
-            use_bias=use_bias,
-            kernel_initializer=kernel_initializer,
-            bias_initializer=bias_initializer,
-            kernel_regularizer=kernel_regularizer,
-            bias_regularizer=bias_regularizer,
-            activity_regularizer=activity_regularizer,
-            kernel_constraint=kernel_constraint,
-            bias_constraint=bias_constraint,
-            **kwargs)
+        super().__init__(filters=filters,
+                         kernel_size=kernel_size,
+                         strides=strides,
+                         padding=padding,
+                         data_format=data_format,
+                         dilation_rate=dilation_rate,
+                         activation=activation,
+                         use_bias=use_bias,
+                         kernel_initializer=kernel_initializer,
+                         bias_initializer=bias_initializer,
+                         kernel_regularizer=kernel_regularizer,
+                         bias_regularizer=bias_regularizer,
+                         activity_regularizer=activity_regularizer,
+                         kernel_constraint=kernel_constraint,
+                         bias_constraint=bias_constraint,
+                         **kwargs)
         self.kernel = None
         self.bias = None
         self.offset_layer_kernel = None
         self.offset_layer_bias = None
         if num_deformable_group is None:
             num_deformable_group = filters
         if filters % num_deformable_group != 0:
-            raise ValueError('"filters" mod "num_deformable_group" must be zero')
+            raise ValueError(
+                '"filters" mod "num_deformable_group" must be zero')
         self.num_deformable_group = num_deformable_group

     def build(self, input_shape):
         input_dim = int(input_shape[-1])
         # kernel_shape = self.kernel_size + (input_dim, self.filters)
         # we want to use depth-wise conv
         kernel_shape = self.kernel_size + (self.filters * input_dim, 1)
-        self.kernel = self.add_weight(
-            name='kernel',
-            shape=kernel_shape,
-            initializer=self.kernel_initializer,
-            regularizer=self.kernel_regularizer,
-            constraint=self.kernel_constraint,
-            trainable=True,
-            dtype=self.dtype)
+        self.kernel = self.add_weight(name='kernel',
+                                      shape=kernel_shape,
+                                      initializer=self.kernel_initializer,
+                                      regularizer=self.kernel_regularizer,
+                                      constraint=self.kernel_constraint,
+                                      trainable=True,
+                                      dtype=self.dtype)
         if self.use_bias:
-            self.bias = self.add_weight(
-                name='bias',
-                shape=(self.filters,),
-                initializer=self.bias_initializer,
-                regularizer=self.bias_regularizer,
-                constraint=self.bias_constraint,
-                trainable=True,
-                dtype=self.dtype)
+            self.bias = self.add_weight(name='bias',
+                                        shape=(self.filters, ),
+                                        initializer=self.bias_initializer,
+                                        regularizer=self.bias_regularizer,
+                                        constraint=self.bias_constraint,
+                                        trainable=True,
+                                        dtype=self.dtype)

         # create offset conv layer
-        offset_num = self.kernel_size[0] * self.kernel_size[1] * self.num_deformable_group
+        offset_num = self.kernel_size[0] * self.kernel_size[
+            1] * self.num_deformable_group
         self.offset_layer_kernel = self.add_weight(
             name='offset_layer_kernel',
-            shape=self.kernel_size + (input_dim, offset_num * 2),  # 2 means x and y axis
+            shape=self.kernel_size +
+            (input_dim, offset_num * 2),  # 2 means x and y axis
             initializer=tf.zeros_initializer(),
             regularizer=self.kernel_regularizer,
             trainable=True,
             dtype=self.dtype)
         self.offset_layer_bias = self.add_weight(
             name='offset_layer_bias',
-            shape=(offset_num * 2,),
+            shape=(offset_num * 2, ),
             initializer=tf.zeros_initializer(),
             # initializer=tf.random_uniform_initializer(-5, 5),
             regularizer=self.bias_regularizer,
@@ -105,7 +105,7 @@ def build(self, input_shape):
     def call(self, inputs, training=None, **kwargs):
         # get offset, shape [batch_size, out_h, out_w, filter_h * filter_w * channel_out * 2]
         offset = tf.nn.conv2d(inputs,
-                              filter=self.offset_layer_kernel,
+                              filters=self.offset_layer_kernel,
                               strides=[1, *self.strides, 1],
                               padding=self.padding.upper(),
                               dilations=[1, *self.dilation_rate, 1])
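
Note: besides the reformatting, this hunk renames the tf.nn.conv2d keyword from filter (TF 1.x) to filters (TF 2.x). A small sketch of the renamed call, with made-up shapes:

import tensorflow as tf

x = tf.random.normal([1, 28, 28, 3])   # hypothetical NHWC input
k = tf.random.normal([5, 5, 3, 8])     # 5x5 kernel, 3 in / 8 out channels

# TF 1.x spelled this argument `filter`; TF 2.x renamed it to `filters`.
y = tf.nn.conv2d(x, filters=k, strides=[1, 1, 1, 1], padding='SAME')
print(y.shape)  # (1, 28, 28, 8)
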
@@ -117,8 +117,10 @@ def call(self, inputs, training=None, **kwargs):
         # some length
         batch_size = int(inputs.get_shape()[0])
         channel_in = int(inputs.get_shape()[-1])
-        in_h, in_w = [int(i) for i in inputs.get_shape()[1: 3]]  # input feature map size
-        out_h, out_w = [int(i) for i in offset.get_shape()[1: 3]]  # output feature map size
+        in_h, in_w = [int(i) for i in inputs.get_shape()[1:3]
+                      ]  # input feature map size
+        out_h, out_w = [int(i) for i in offset.get_shape()[1:3]
+                        ]  # output feature map size
         filter_h, filter_w = self.kernel_size

         # get x, y axis offset
@@ -128,28 +130,34 @@ def call(self, inputs, training=None, **kwargs):
         # input feature map grid coordinates
         y, x = self._get_conv_indices([in_h, in_w])
         y, x = [tf.expand_dims(i, axis=-1) for i in [y, x]]
-        y, x = [tf.tile(i, [batch_size, 1, 1, 1, self.num_deformable_group]) for i in [y, x]]
-        y, x = [tf.reshape(i, [*i.shape[0: 3], -1]) for i in [y, x]]
-        y, x = [tf.to_float(i) for i in [y, x]]
+        y, x = [
+            tf.tile(i, [batch_size, 1, 1, 1, self.num_deformable_group])
+            for i in [y, x]
+        ]
+        y, x = [tf.reshape(i, [*i.shape[0:3], -1]) for i in [y, x]]
+        y, x = [tf.cast(i, tf.float32) for i in [y, x]]

         # add offset
         y, x = y + y_off, x + x_off
         y = tf.clip_by_value(y, 0, in_h - 1)
         x = tf.clip_by_value(x, 0, in_w - 1)

         # get four coordinates of points around (x, y)
-        y0, x0 = [tf.to_int32(tf.floor(i)) for i in [y, x]]
+        y0, x0 = [tf.cast(tf.floor(i), tf.int32) for i in [y, x]]
         y1, x1 = y0 + 1, x0 + 1
         # clip
         y0, y1 = [tf.clip_by_value(i, 0, in_h - 1) for i in [y0, y1]]
         x0, x1 = [tf.clip_by_value(i, 0, in_w - 1) for i in [x0, x1]]

         # get pixel values
         indices = [[y0, x0], [y0, x1], [y1, x0], [y1, x1]]
-        p0, p1, p2, p3 = [DeformableConvLayer._get_pixel_values_at_point(inputs, i) for i in indices]
+        p0, p1, p2, p3 = [
+            DeformableConvLayer._get_pixel_values_at_point(inputs, i)
+            for i in indices
+        ]

         # cast to float
-        x0, x1, y0, y1 = [tf.to_float(i) for i in [x0, x1, y0, y1]]
+        x0, x1, y0, y1 = [tf.cast(i, tf.float32) for i in [x0, x1, y0, y1]]
         # weights
         w0 = (y1 - y) * (x1 - x)
         w1 = (y1 - y) * (x - x0)
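
Note: w0 and w1 above (with w2 and w3 in the collapsed context) are the standard bilinear interpolation weights; each fractional sample is the area-weighted average of its four integer neighbours, and the weights sum to 1. A standalone sketch of the same arithmetic at a single made-up point:

import math
import tensorflow as tf

y, x = 1.3, 2.6                                   # hypothetical fractional location
y0, x0 = int(math.floor(y)), int(math.floor(x))   # (1, 2)
y1, x1 = y0 + 1, x0 + 1                           # (2, 3)

img = tf.reshape(tf.range(16, dtype=tf.float32), [4, 4])
p0, p1 = img[y0, x0], img[y0, x1]
p2, p3 = img[y1, x0], img[y1, x1]

w0 = (y1 - y) * (x1 - x)   # 0.28
w1 = (y1 - y) * (x - x0)   # 0.42
w2 = (y - y0) * (x1 - x)   # 0.12
w3 = (y - y0) * (x - x0)   # 0.18, weights sum to 1
value = w0 * p0 + w1 * p1 + w2 * p2 + w3 * p3
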
@@ -161,19 +169,28 @@
         pixels = tf.add_n([w0 * p0, w1 * p1, w2 * p2, w3 * p3])

         # reshape the "big" feature map
-        pixels = tf.reshape(pixels, [batch_size, out_h, out_w, filter_h, filter_w, self.num_deformable_group, channel_in])
+        pixels = tf.reshape(pixels, [
+            batch_size, out_h, out_w, filter_h, filter_w,
+            self.num_deformable_group, channel_in
+        ])
         pixels = tf.transpose(pixels, [0, 1, 3, 2, 4, 5, 6])
-        pixels = tf.reshape(pixels, [batch_size, out_h * filter_h, out_w * filter_w, self.num_deformable_group, channel_in])
+        pixels = tf.reshape(pixels, [
+            batch_size, out_h * filter_h, out_w * filter_w,
+            self.num_deformable_group, channel_in
+        ])

         # copy channels to same group
         feat_in_group = self.filters // self.num_deformable_group
         pixels = tf.tile(pixels, [1, 1, 1, 1, feat_in_group])
-        pixels = tf.reshape(pixels, [batch_size, out_h * filter_h, out_w * filter_w, -1])
+        pixels = tf.reshape(
+            pixels, [batch_size, out_h * filter_h, out_w * filter_w, -1])

         # depth-wise conv
-        out = tf.nn.depthwise_conv2d(pixels, self.kernel, [1, filter_h, filter_w, 1], 'VALID')
+        out = tf.nn.depthwise_conv2d(pixels, self.kernel,
+                                     [1, filter_h, filter_w, 1], 'VALID')
         # add the output feature maps in the same group
-        out = tf.reshape(out, [batch_size, out_h, out_w, self.filters, channel_in])
+        out = tf.reshape(out,
+                         [batch_size, out_h, out_w, self.filters, channel_in])
         out = tf.reduce_sum(out, axis=-1)
         if self.use_bias:
             out += self.bias
@@ -188,25 +205,30 @@ def _pad_input(self, inputs):
         # When padding is 'same', we should pad the feature map.
         # if padding == 'same', output size should be `ceil(input / stride)`
         if self.padding == 'same':
-            in_shape = inputs.get_shape().as_list()[1: 3]
+            in_shape = inputs.get_shape().as_list()[1:3]
             padding_list = []
             for i in range(2):
                 filter_size = self.kernel_size[i]
                 dilation = self.dilation_rate[i]
-                dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1)
-                same_output = (in_shape[i] + self.strides[i] - 1) // self.strides[i]
-                valid_output = (in_shape[i] - dilated_filter_size + self.strides[i]) // self.strides[i]
+                dilated_filter_size = filter_size + (filter_size -
+                                                     1) * (dilation - 1)
+                same_output = (in_shape[i] + self.strides[i] -
+                               1) // self.strides[i]
+                valid_output = (in_shape[i] - dilated_filter_size +
+                                self.strides[i]) // self.strides[i]
                 if same_output == valid_output:
                     padding_list += [0, 0]
                 else:
                     p = dilated_filter_size - 1
                     p_0 = p // 2
                     padding_list += [p_0, p - p_0]
             if sum(padding_list) != 0:
-                padding = [[0, 0],
-                           [padding_list[0], padding_list[1]],  # top, bottom padding
-                           [padding_list[2], padding_list[3]],  # left, right padding
-                           [0, 0]]
+                padding = [
+                    [0, 0],
+                    [padding_list[0], padding_list[1]],  # top, bottom padding
+                    [padding_list[2], padding_list[3]],  # left, right padding
+                    [0, 0]
+                ]
                 inputs = tf.pad(inputs, padding)
         return inputs
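
Note: this branch reproduces 'same' padding for the VALID convolution that follows. A worked example with assumed sizes (input 10, stride 2, kernel 5, dilation 1):

in_size, stride, filter_size, dilation = 10, 2, 5, 1

dilated = filter_size + (filter_size - 1) * (dilation - 1)  # 5
same_output = (in_size + stride - 1) // stride              # ceil(10 / 2) = 5
valid_output = (in_size - dilated + stride) // stride       # 3

# The outputs differ, so pad p = dilated - 1 = 4 pixels, split 2 before / 2 after;
# a VALID conv over the padded input then matches the SAME output size.
p = dilated - 1
assert (in_size + p - dilated + stride) // stride == same_output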

@@ -216,16 +238,17 @@ def _get_conv_indices(self, feature_map_size):
         :param feature_map_size:
         :return: y, x with shape [1, out_h, out_w, filter_h * filter_w]
         """
-        feat_h, feat_w = [int(i) for i in feature_map_size[0: 2]]
+        feat_h, feat_w = [int(i) for i in feature_map_size[0:2]]

         x, y = tf.meshgrid(tf.range(feat_w), tf.range(feat_h))
-        x, y = [tf.reshape(i, [1, *i.get_shape(), 1]) for i in [x, y]]  # shape [1, h, w, 1]
-        x, y = [tf.image.extract_image_patches(i,
-                                               [1, *self.kernel_size, 1],
-                                               [1, *self.strides, 1],
-                                               [1, *self.dilation_rate, 1],
-                                               'VALID')
-                for i in [x, y]]  # shape [1, out_h, out_w, filter_h * filter_w]
+        x, y = [tf.reshape(i, [1, *i.get_shape(), 1])
+                for i in [x, y]]  # shape [1, h, w, 1]
+        x, y = [
+            tf.image.extract_patches(i, [1, *self.kernel_size, 1],
+                                     [1, *self.strides, 1],
+                                     [1, *self.dilation_rate, 1], 'VALID')
+            for i in [x, y]
+        ]  # shape [1, out_h, out_w, filter_h * filter_w]
         return y, x

     @staticmethod
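
Note: tf.image.extract_image_patches (TF 1.x) is tf.image.extract_patches in TF 2.x; the (sizes, strides, rates, padding) arguments are unchanged. A tiny sketch on a made-up 4x4 image:

import tensorflow as tf

img = tf.reshape(tf.range(16, dtype=tf.float32), [1, 4, 4, 1])
patches = tf.image.extract_patches(img,
                                   sizes=[1, 3, 3, 1],
                                   strides=[1, 1, 1, 1],
                                   rates=[1, 1, 1, 1],
                                   padding='VALID')
print(patches.shape)  # (1, 2, 2, 9): each 3x3 window flattened into last axis
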
@@ -237,10 +260,9 @@ def _get_pixel_values_at_point(inputs, indices):
         :return:
         """
         y, x = indices
-        batch, h, w, n = y.get_shape().as_list()[0: 4]
+        batch, h, w, n = y.get_shape().as_list()[0:4]

         batch_idx = tf.reshape(tf.range(0, batch), (batch, 1, 1, 1))
         b = tf.tile(batch_idx, (1, h, w, n))
         pixel_idx = tf.stack([b, y, x], axis=-1)
         return tf.gather_nd(inputs, pixel_idx)
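
Note: a quick usage sketch of the updated layer under TF 2.x; the implementation reads static shapes (batch size included), so they are fixed up front. Values are illustrative:

import tensorflow as tf
from nets.deformable_conv_layer import DeformableConvLayer

inputs = tf.random.normal([4, 28, 28, 3])   # hypothetical NHWC batch
layer = DeformableConvLayer(filters=32,
                            kernel_size=[5, 5],
                            num_deformable_group=1,
                            activation='relu')
outputs = layer(inputs)
print(outputs.shape)  # (4, 24, 24, 32): 5x5 'valid' convolution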
