Skip to content

Commit

Permalink
Add ConvGroup; use keyword params everywhere
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergey Tsimfer authored and Sergey Tsimfer committed Jan 28, 2020
1 parent 8ca3922 commit ec151fd
Show file tree
Hide file tree
Showing 16 changed files with 235 additions and 141 deletions.
6 changes: 3 additions & 3 deletions batchflow/models/tf/encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def encoder(cls, inputs, name='encoder', **kwargs):

for letter in order:
if letter == 'b':
x = base_block(x, name='block', **args)
x = conv_block(x, base_block=base_block, name='block', **args)
elif letter == 's':
encoder_outputs.append(x)
elif letter in ['d', 'p']:
Expand Down Expand Up @@ -318,7 +318,7 @@ def embedding(cls, inputs, name='embedding', **kwargs):
tf.Tensor
"""
base_block = kwargs.get('base', cls.block)
return base_block(inputs, name=name, **kwargs)
return conv_block(inputs, base_block=base_block, name=name, **kwargs)

@classmethod
def decoder(cls, inputs, name='decoder', return_all=False, **kwargs):
Expand Down Expand Up @@ -419,7 +419,7 @@ def decoder(cls, inputs, name='decoder', return_all=False, **kwargs):

for letter in order:
if letter == 'b':
x = base_block(x, name='block', **args)
x = conv_block(x, base_block=base_block, name='block', **args)
elif letter in ['u']:
if upsample.get('layout') is not None:
x = cls.upsample(x, name='upsample', **upsample_args)
Expand Down
4 changes: 2 additions & 2 deletions batchflow/models/tf/fcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def body(cls, inputs, num_classes, name='body', **kwargs):

x = cls.upsample(x, factor=2, filters=num_classes, name='fcn32_upsample', **upsample_args, **kwargs)

skip = conv_block(skip, 'c', filters=num_classes, kernel_size=1, name='pool4', **kwargs)
skip = conv_block(skip, layout='c', filters=num_classes, kernel_size=1, name='pool4', **kwargs)
x = cls.crop(x, skip, kwargs.get('data_format'))
output = tf.add(x, skip, name='output')
return output
Expand Down Expand Up @@ -287,7 +287,7 @@ def body(cls, inputs, num_classes, name='body', **kwargs):
x = FCN16.body((x, skip1), filters=filters, num_classes=num_classes, name='fcn16', **kwargs)
x = cls.upsample(x, factor=2, filters=num_classes, name='fcn16_upsample', **upsample_args, **kwargs)

skip2 = conv_block(skip2, 'c', num_classes, 1, name='pool3', **kwargs)
skip2 = conv_block(skip2, layout='c', filters=num_classes, kernel_size=1, name='pool3', **kwargs)

x = cls.crop(x, skip2, kwargs.get('data_format'))
output = tf.add(x, skip2, name='output')
Expand Down
10 changes: 6 additions & 4 deletions batchflow/models/tf/inception_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,12 @@ def reduction_block(cls, inputs, filters, layout='cna', name='reduction_block',
tf.Tensor
"""
with tf.variable_scope(name):
branch_3 = conv_block(inputs, layout, filters[3], 3, name='conv_3', strides=2, padding='valid', **kwargs)
branch_1_3 = conv_block(inputs, layout*2, [filters[0]]+[filters[1]], [1, 3], name='conv_1_3', **kwargs)
branch_1_3_3 = conv_block(branch_1_3, layout, filters[2], 3, name='conv_1_3_3', strides=2,
padding='valid', **kwargs)
branch_3 = conv_block(inputs, layout=layout, filters=filters[3], kernel_size=3, strides=2, padding='valid',
name='conv_3', **kwargs)
branch_1_3 = conv_block(inputs, layout=layout*2, filters=[filters[0]]+[filters[1]], kernel_size=[1, 3],
name='conv_1_3', **kwargs)
branch_1_3_3 = conv_block(branch_1_3, layout=layout, filters=filters[2], kernel_size=3, strides=2,
padding='valid', name='conv_1_3_3', **kwargs)

branch_pool = conv_block(inputs, layout='p', pool_size=3, pool_strides=2, name='max_pooling',
padding='valid', **kwargs)
Expand Down
78 changes: 40 additions & 38 deletions batchflow/models/tf/inception_resnet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,23 +62,28 @@ def initial_block(cls, inputs, name='initial_block', **kwargs):
layout, filters = cls.pop(['layout', 'filters'], kwargs)
axis = cls.channels_axis(kwargs['data_format'])

x = conv_block(inputs, layout*2, filters[0]*2, 3, name='conv_3_3', padding='valid',
strides=[2, 1], **kwargs)
x = conv_block(x, layout, filters[1], 3, name='conv_3_3_3', **kwargs)
x = conv_block(inputs, layout=layout*2, filters=filters[0]*2, kernel_size=3, strides=[2, 1],
padding='valid', name='conv_3_3', **kwargs)
x = conv_block(x, layout=layout, filters=filters[1], kernel_size=3,
name='conv_3_3_3', **kwargs)

branch_3 = conv_block(x, layout, filters[2], 3, name='conv_3', strides=2, padding='valid', **kwargs)
branch_3 = conv_block(x, layout=layout, filters=filters[2], kernel_size=3, strides=2,
padding='valid', name='conv_3', **kwargs)
branch_pool = conv_block(x, layout='p', name='max_pool', padding='valid', **kwargs)
x = tf.concat([branch_3, branch_pool], axis, name='concat_3_and_pool')

branch_1 = conv_block(x, layout, filters[1], 1, name='conv_1', **kwargs)
branch_1_3 = conv_block(branch_1, layout, filters[2], 3, name='conv_1_3', padding='valid', **kwargs)
branch_1 = conv_block(x, layout=layout, filters=filters[1], kernel_size=1, name='conv_1', **kwargs)
branch_1_3 = conv_block(branch_1, layout=layout, filters=filters[2], kernel_size=3, padding='valid',
name='conv_1_3', **kwargs)

branch_1_7 = conv_block(x, layout*3, [filters[1]]*3, [1, [7, 1], [1, 7]], name='conv_1_7', **kwargs)
branch_1_7_3 = conv_block(branch_1_7, layout, filters[2], 3, name='conv_1_7_3', padding='valid', **kwargs)
branch_1_7 = conv_block(x, layout=layout*3, filters=[filters[1]]*3, kernel_size=[1, [7, 1], [1, 7]],
name='conv_1_7', **kwargs)
branch_1_7_3 = conv_block(branch_1_7, layout=layout, filters=filters[2], kernel_size=3,
padding='valid', name='conv_1_7_3', **kwargs)
x = tf.concat([branch_1_3, branch_1_7_3], axis, name='concat_1_3_and_1_7_3')

branch_out_3 = conv_block(x, layout, filters[3], 3, name='conv_out_3', strides=2,
padding='valid', **kwargs)
branch_out_3 = conv_block(x, layout=layout, filters=filters[3], kernel_size=3, strides=2,
padding='valid', name='conv_out_3', **kwargs)
branch_out_pool = conv_block(x, layout='p', name='out_max_pooling', padding='valid', **kwargs)

output = tf.concat([branch_out_3, branch_out_pool], axis, name='output')
Expand Down Expand Up @@ -149,18 +154,18 @@ def block_a(cls, inputs, filters, layout='cna', name='block_a', **kwargs):
with tf.variable_scope(name):
x = tf.nn.relu(inputs)

branch_1 = conv_block(x, layout, filters[0], 1, name='conv_1', **kwargs)
branch_2 = conv_block(x, layout*2, [filters[1], filters[2]], [1, 3], name='conv_2', **kwargs)
branch_3 = conv_block(x, layout*3, [filters[3], filters[4], filters[5]], [1, 3, 3], name='conv_3', **kwargs)
branch_1 = conv_block(x, layout=layout, filters=filters[0], kernel_size=1, name='conv_1', **kwargs)
branch_2 = conv_block(x, layout=layout*2, filters=[filters[1], filters[2]], kernel_size=[1, 3],
name='conv_2', **kwargs)
branch_3 = conv_block(x, layout=layout*3, filters=[filters[3], filters[4], filters[5]],
kernel_size=[1, 3, 3], name='conv_3', **kwargs)

axis = cls.channels_axis(kwargs['data_format'])
branch_1 = tf.concat([branch_1, branch_2, branch_3], axis=axis)
branch_1 = conv_block(branch_1, 'c', filters[6], 1, name='conv_1x1', **kwargs)
branch_1 = conv_block(branch_1, layout='c', filters=filters[6], kernel_size=1, name='conv_1x1', **kwargs)

x = x + branch_1

x = tf.nn.relu(x)

return x

@classmethod
Expand All @@ -185,18 +190,16 @@ def block_b(cls, inputs, filters, layout='cna', name='block_b', **kwargs):
with tf.variable_scope(name):
x = tf.nn.relu(inputs)

branch_1 = conv_block(x, layout, filters[0], 1, name='conv_1', **kwargs)
branch_2 = conv_block(x, layout*3, [filters[1], filters[2], filters[3]], [1, (1, 7), (7, 1)],
name='conv_2', **kwargs)
branch_1 = conv_block(x, layout=layout, filters=filters[0], kernel_size=1, name='conv_1', **kwargs)
branch_2 = conv_block(x, layout=layout*3, filters=[filters[1], filters[2], filters[3]],
kernel_size=[1, (1, 7), (7, 1)], name='conv_2', **kwargs)

axis = cls.channels_axis(kwargs['data_format'])
branch_1 = tf.concat([branch_1, branch_2], axis=axis)
branch_1 = conv_block(branch_1, 'c', filters[4], 1, name='conv_1x1', **kwargs)
branch_1 = conv_block(branch_1, layout='c', filters=filters[4], kernel_size=1, name='conv_1x1', **kwargs)

x = x + branch_1

x = tf.nn.relu(x)

return x

@classmethod
Expand All @@ -221,18 +224,16 @@ def block_c(cls, inputs, filters, layout='cna', name='block_c', **kwargs):
with tf.variable_scope(name):
x = tf.nn.relu(inputs)

branch_1 = conv_block(x, layout, filters[0], 1, name='conv_1', **kwargs)
branch_2 = conv_block(x, layout*3, [filters[1], filters[2], filters[3]], [1, (1, 3), (3, 1)],
name='conv_2', **kwargs)
branch_1 = conv_block(x, layout=layout, filters=filters[0], kernel_size=1, name='conv_1', **kwargs)
branch_2 = conv_block(x, layout=layout*3, filters=[filters[1], filters[2], filters[3]],
kernel_size=[1, (1, 3), (3, 1)], name='conv_2', **kwargs)

axis = cls.channels_axis(kwargs['data_format'])
branch_1 = tf.concat([branch_1, branch_2], axis=axis)
branch_1 = conv_block(branch_1, 'c', filters[4], 1, name='conv_1x1', **kwargs)
branch_1 = conv_block(branch_1, layout='c', filters=filters[4], kernel_size=1, name='conv_1x1', **kwargs)

x = x + branch_1

x = tf.nn.relu(x)

return x

@classmethod
Expand All @@ -257,10 +258,11 @@ def reduction_a(cls, inputs, filters, layout='cna', name='reduction_a', **kwargs
with tf.variable_scope(name):
x = tf.nn.relu(inputs)

branch_1 = conv_block(x, 'p', pool_strides=2, name='max-pool', **kwargs)
branch_2 = conv_block(x, layout, filters[0], 3, strides=2, name='conv_2', **kwargs)
branch_3 = conv_block(x, layout*3, [filters[1], filters[2], filters[3]], [1, 3, 3],
strides=[1, 1, 2], name='conv_3', **kwargs)
branch_1 = conv_block(x, layout='p', pool_strides=2, name='max-pool', **kwargs)
branch_2 = conv_block(x, layout=layout, filters=filters[0], kernel_size=3, strides=2,
name='conv_2', **kwargs)
branch_3 = conv_block(x, layout=layout*3, filters=[filters[1], filters[2], filters[3]],
kernel_size=[1, 3, 3], strides=[1, 1, 2], name='conv_3', **kwargs)

axis = cls.channels_axis(kwargs['data_format'])
x = tf.concat([branch_1, branch_2, branch_3], axis=axis)
Expand Down Expand Up @@ -289,12 +291,12 @@ def reduction_b(cls, inputs, filters, layout='cna', name='reduction_a', **kwargs
with tf.variable_scope(name):
x = inputs
branch_1 = conv_block(x, 'p', pool_size=3, pool_strides=2, name='max-pool', **kwargs)
branch_2 = conv_block(x, layout*2, [filters[0], filters[1]], [1, 3], strides=[1, 2],
name='conv_2', **kwargs)
branch_3 = conv_block(x, layout*2, [filters[2], filters[3]], [1, 3], strides=[1, 2],
name='conv_3', **kwargs)
branch_4 = conv_block(x, layout*3, [filters[4], filters[5], filters[6]], [1, 3, 3], strides=[1, 1, 2],
name='conv_4', **kwargs)
branch_2 = conv_block(x, layout=layout*2, filters=[filters[0], filters[1]], kernel_size=[1, 3],
strides=[1, 2], name='conv_2', **kwargs)
branch_3 = conv_block(x, layout=layout*2, filters=[filters[2], filters[3]], kernel_size=[1, 3],
strides=[1, 2], name='conv_3', **kwargs)
branch_4 = conv_block(x, layout=layout*3, filters=[filters[4], filters[5], filters[6]],
kernel_size=[1, 3, 3], strides=[1, 1, 2], name='conv_4', **kwargs)

axis = cls.channels_axis(kwargs['data_format'])
x = tf.concat([branch_1, branch_2, branch_3, branch_4], axis)
Expand Down
12 changes: 7 additions & 5 deletions batchflow/models/tf/inception_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,15 @@ def block(cls, inputs, filters, layout='cn', name=None, **kwargs):
tf.Tensor
"""
with tf.variable_scope(name):
branch_1 = conv_block(inputs, layout, filters[0], 1, name='conv_1', **kwargs)
branch_1 = conv_block(inputs, layout=layout, filters=filters[0], kernel_size=1, name='conv_1', **kwargs)

branch_3 = conv_block(inputs, layout*2, [filters[1], filters[2]], [1, 3], name='conv_3', **kwargs)
branch_3 = conv_block(inputs, layout=layout*2, filters=[filters[1], filters[2]], kernel_size=[1, 3],
name='conv_3', **kwargs)

branch_5 = conv_block(inputs, layout*2, [filters[3], filters[4]], [1, 5], name='conv_5', **kwargs)
branch_5 = conv_block(inputs, layout=layout*2, filters=[filters[3], filters[4]], kernel_size=[1, 5],
name='conv_5', **kwargs)

branch_pool = conv_block(inputs, 'p'+layout, filters[5], 1,
branch_pool = conv_block(inputs, layout='p'+layout, filters=filters[5], kernel_size=1,
name='conv_pool', **{**kwargs, 'pool_strides': 1})

axis = cls.channels_axis(kwargs['data_format'])
Expand All @@ -102,5 +104,5 @@ def reduction_block(cls, inputs, layout='p', filters=None, name='reduction_block
-------
tf.Tensor
"""
output = conv_block(inputs, layout, filters=filters, name=name, **kwargs)
output = conv_block(inputs, layout=layout, filters=filters, name=name, **kwargs)
return output

0 comments on commit ec151fd

Please sign in to comment.