Skip to content

Commit

Permalink
Conv3d() operator implementation for Keras2.0 using MXNet backend (#40)
Browse files Browse the repository at this point in the history
* conv3d implementation for keras2.0 as MXNet backend

* conv3d implementation/testing for keras2.0 using MXNet backend

* keeping -n option in pytest.ini file

* fixed comments given by Sandeep
  • Loading branch information
karan6181 authored and sandeep-krishnamurthy committed Feb 12, 2018
1 parent 84bfae4 commit 8325604
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 14 deletions.
23 changes: 19 additions & 4 deletions keras/backend/mxnet_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3037,7 +3037,15 @@ def conv3d(x, kernel, strides=(1, 1, 1), padding='valid',
# Raises
ValueError: if `data_format` is neither `channels_last` or `channels_first`.
"""
raise NotImplementedError('MXNet Backend: conv3d operator is not supported yet.')
if data_format is None:
data_format = image_data_format()
_validate_data_format(data_format)

if padding not in {'same', 'valid'}:
raise ValueError('`padding` should be either `same` or `valid`.')

return _convnd(x, kernel, name='conv3d', strides=strides, filter_dilation=dilation_rate,
padding_mode=padding, data_format=data_format)


def conv3d_transpose(x, kernel, output_shape, strides=(1, 1, 1),
Expand Down Expand Up @@ -3923,9 +3931,16 @@ def _postprocess_convnd_output(x, data_format):

@keras_mxnet_symbol
def _preprocess_convnd_kernel(kernel, data_format):
# Kernel is always provided in TF kernel shape: (rows, cols, input_depth, depth)
# Convert it to MXNet kernel shape: (depth, input_depth, rows, cols)
if len(kernel.shape) > 3:
# Kernel is always provided in TF kernel shape:
# 2-D: (rows, cols, input_depth, depth)
# 3-D: (kernel_depth, kernel_rows, kernel_cols, input_depth, depth)
# Convert it to MXNet kernel shape:
# 2-D: (depth, input_depth, rows, cols)
# 3-D: (depth, input_depth, kernel_depth, kernel_rows, kernel_cols)
#
if len(kernel.shape) > 4:
kernel = KerasSymbol(mx.sym.transpose(data=kernel.symbol, axes=(4, 3, 0, 1, 2)))
elif len(kernel.shape) > 3:
kernel = KerasSymbol(mx.sym.transpose(data=kernel.symbol, axes=(3, 2, 0, 1)))

return kernel
Expand Down
6 changes: 2 additions & 4 deletions tests/keras/backend/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,8 +848,6 @@ def test_conv3d(self):
# TH kernel shape: (depth, input_depth, x, y, z)
# TF kernel shape: (x, y, z, input_depth, depth)

# MXNet backend does not support conv2d yet.

# test in data_format = channels_first
for input_shape in [(2, 3, 4, 5, 4), (2, 3, 5, 4, 6)]:
for kernel_shape in [(2, 2, 2, 3, 4), (3, 2, 4, 3, 4)]:
Expand All @@ -861,13 +859,13 @@ def test_conv3d(self):
input_shape = (1, 2, 2, 2, 1)
kernel_shape = (2, 2, 2, 1, 1)
check_two_tensor_operation('conv3d', input_shape, kernel_shape,
BACKENDS_WITHOUT_MXNET, cntk_dynamicity=True,
BACKENDS, cntk_dynamicity=True,
data_format='channels_last')

xval = np.random.random(input_shape)
kernel_val = np.random.random(kernel_shape) - 0.5
# Test invalid use cases
for k in BACKENDS_WITHOUT_MXNET:
for k in BACKENDS:
with pytest.raises(ValueError):
k.conv3d(k.variable(xval), k.variable(kernel_val), data_format='channels_middle')

Expand Down
8 changes: 2 additions & 6 deletions tests/keras/layers/convolutional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,8 +420,6 @@ def test_averagepooling_2d():
input_shape=(3, 4, 5, 6))


@pytest.mark.skipif((K.backend() == 'mxnet'),
reason='MXNet backend does not support conv3d yet.')
@keras_test
def test_convolution_3d():
num_samples = 2
Expand All @@ -442,9 +440,8 @@ def test_convolution_3d():
'kernel_size': 3,
'padding': padding,
'strides': strides},
input_shape=(num_samples,
input_len_dim1, input_len_dim2, input_len_dim3,
stack_size))
input_shape=(num_samples, stack_size,
input_len_dim1, input_len_dim2, input_len_dim3))

layer_test(convolutional.Convolution3D,
kwargs={'filters': filters,
Expand All @@ -461,7 +458,6 @@ def test_convolution_3d():
input_len_dim1, input_len_dim2, input_len_dim3,
stack_size))


@pytest.mark.skipif((K.backend() == 'mxnet'),
reason='MXNet backend does not support conv3d_transpose yet.')
@keras_test
Expand Down

0 comments on commit 8325604

Please sign in to comment.