Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix multi_gpu API bug for CPU. Fix PEP. Fix bias_add #64

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
59 changes: 33 additions & 26 deletions keras/backend/mxnet_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import warnings
import mxnet as mx
import numpy as np
from subprocess import CalledProcessError
from numbers import Number
from functools import wraps
from collections import defaultdict

from .common import floatx, epsilon, set_image_data_format, image_data_format
from .common import floatx, epsilon, image_data_format

_UID_PREFIXES = defaultdict(int)
_LEARNING_PHASE = 1
Expand Down Expand Up @@ -195,6 +196,7 @@ def to_dense(tensor):
"""
raise NotImplementedError('MXNet Backend: Sparse operations are not supported yet.')


def variable(value, dtype=None, name=None, constraint=None):
"""Instantiates a variable and returns it.

Expand Down Expand Up @@ -2571,20 +2573,22 @@ def rnn(step_function, inputs, initial_states,
raise ValueError('MXNet Backend: Unrolling requires a fixed number of time-steps.')

if not unroll and dshape[1] is None:
raise NotImplementedError('MXNet Backend: unroll=False '
'is not supported yet in RNN.\n'
'MXNet Backend: Does not support Variable '
'Length input(Samples of different length). '
'Please pad your input to a constant length, '
'provide `input_shape` and set `unroll=True`'
'Ex: new_x_train = keras.preprocessing.sequence.pad_sequences(old_x_train, '
'maxlen=MAX_LEN_OF_INPUT_SAMPLE_TYPE_INT). '
'More Details - https://github.com/deep-learning-tools/keras/wiki/Limitations-and-workaround-of-RNN-layer-using-MXNet-backend')
raise NotImplementedError(
'MXNet Backend: unroll=False '
'is not supported yet in RNN.\n'
'MXNet Backend: Does not support Variable '
'Length input(Samples of different length). '
'Please pad your input to a constant length, '
'provide `input_shape` and set `unroll=True`'
'Ex: new_x_train = keras.preprocessing.sequence.pad_sequences(old_x_train, '
'maxlen=MAX_LEN_OF_INPUT_SAMPLE_TYPE_INT). '
'More Details - '
'https://github.com/awslabs/keras-apache-mxnet/wiki/Using-RNN-with-MXNet-backend')

if not unroll and dshape[1] is not None:
warnings.warn('MXNet Backend: `unroll=False` is not supported yet in RNN. Since the input_shape is known, '
'setting `unroll=True` and continuing the execution.'
'More Details - https://github.com/deep-learning-tools/keras/wiki/Limitations-and-workaround-of-RNN-layer-using-MXNet-backend',
'More Details - https://github.com/awslabs/keras-apache-mxnet/wiki/Using-RNN-with-MXNet-backend',
stacklevel=2)

# Split the inputs across time dimension and generate the list of inputs
Expand Down Expand Up @@ -3056,7 +3060,7 @@ def conv1d(x, kernel, strides=1, padding='valid',
kernel = expand_dims(kernel, axis=1)

output = _convnd(x, kernel, name='conv1d', strides=strides, filter_dilation=dilation_rate,
padding_mode=padding, data_format=data_format)
padding_mode=padding, data_format=data_format)

# Remove added extra dimension
# remove added dim
Expand Down Expand Up @@ -3241,7 +3245,7 @@ def conv3d_transpose(x, kernel, output_shape, strides=(1, 1, 1),
"""
# MXNet only support Conv3D with GPU and CUDNN
gpus = mx.test_utils.list_gpus()
if gpus and len(gpus) > 0 :
if gpus and len(gpus) > 0:
if data_format is None:
data_format = image_data_format()
_validate_data_format(data_format)
Expand All @@ -3255,8 +3259,6 @@ def conv3d_transpose(x, kernel, output_shape, strides=(1, 1, 1),
raise NotImplementedError('MXNet Backend: Conv3D Transpose is only supported on GPU with CUDNN')




@keras_mxnet_symbol
def pool2d(x, pool_size, strides=(1, 1),
padding='valid', data_format=None,
Expand Down Expand Up @@ -3331,7 +3333,6 @@ def bias_add(x, bias, data_format='channels_last'):
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('MXNet Backend: Unknown data_format ' + str(data_format))
bias_shape = int_shape(bias)
x_shape = int_shape(x)
x_dim = ndim(x)
if len(bias_shape) != 1 and len(bias_shape) != x_dim - 1:
raise ValueError('MXNet Backend: Unexpected bias dimensions %d, expect to be 1 or %d dimensions'
Expand Down Expand Up @@ -3674,10 +3675,10 @@ def bind(self, data):
else:
self.tensor = data
if self.name in self._bind_values:
assert (self._bind_values[self.name].shape == data.shape,
'Redefinition of variable %s' % self.name)
assert (self._bind_values[self.name].dtype == data.dtype,
'Redefinition of variable %s' % self.name)
assert self._bind_values[self.name].shape == data.shape, \
'Redefinition of variable %s' % self.name
assert self._bind_values[self.name].dtype == data.dtype, \
'Redefinition of variable %s' % self.name
if _MODEL is not None and self.name in _MODEL._args:
_MODEL._set_weights({self.name: data}, {})
if _MODEL is not None and self.name in _MODEL._auxs:
Expand Down Expand Up @@ -4196,6 +4197,7 @@ def _convnd(x, kernel, strides, filter_dilation, name=None, padding_mode='valid'
result = _postprocess_convnd_output(KerasSymbol(conv), data_format)
return result


@keras_mxnet_symbol
def _convnd_transpose(x, kernel, output_shape, strides, data_format, name=None):
# Handle Data Format
Expand Down Expand Up @@ -4224,6 +4226,7 @@ def _convnd_transpose(x, kernel, output_shape, strides, data_format, name=None):
result = _postprocess_convnd_output(KerasSymbol(deconv), data_format)
return result


# Pooling helpers
def _calculate_pool_output_size(input_length, filter_size, padding, stride,
dilation=1):
Expand Down Expand Up @@ -4266,7 +4269,7 @@ def _preprocess_pooling_padding_mode(padding_mode, input_shape, kernel, strides)
if padding_mode == 'same':
padding, is_slice, out_size = zip(
*[_calculate_pool_padding_requirement(input_shape[2 + i], kernel[i],
strides[i], padding_mode)
strides[i], padding_mode)
for i in range(nd)])
elif padding_mode == 'valid':
padding = (0,) * nd
Expand Down Expand Up @@ -4475,8 +4478,8 @@ def _adjust_module(self, inputs, phase):
# adjust module data shape
if inputs[0].shape[0] != self._module._curr_module._exec_group.batch_size:
self._module._curr_module.reshape(data_shapes, label_shapes)
assert (inputs[0].shape[0] == self._module._curr_module._exec_group.batch_size,
'Reshape failed')
assert inputs[0].shape[0] == self._module._curr_module._exec_group.batch_size, \
'Reshape failed'

return data, label, phase, data_shapes, label_shapes

Expand Down Expand Up @@ -4585,7 +4588,7 @@ def _create_predict_module(self):
trainable_weights = set([x.name for x in self.trainable_weights])
self._fixed_weights = [x for x in self._arg_names if x not in trainable_weights]
self._args = {x: bind_values[x] for x in self._arg_names if x in bind_values}
self._auxs = {x: bind_values[x] for x in self._aux_names if x in bind_values}
self._auxs = {x: bind_values[x] for x in self._aux_names if x in bind_values}
self._weights_dirty = False

# set module for prediction only
Expand All @@ -4599,13 +4602,17 @@ def sym_gen(phase):
context=self._context,
fixed_param_names=self._fixed_weights)

def get_mxnet_context(self, context):
@staticmethod
def get_mxnet_context(context):
mxnet_context = []

if context is None:
# If user does not provide any context, if GPUs are detected, by default it runs on first available
# GPU device. If not GPUs are detected, then it falls back to CPU.
gpus = mx.test_utils.list_gpus()
try:
gpus = mx.test_utils.list_gpus()
except CalledProcessError:
gpus = []
if gpus and len(gpus) > 0:
mxnet_context.append(mx.gpu(gpus[0]))
else:
Expand Down
2 changes: 2 additions & 0 deletions tests/keras/layers/recurrent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
num_samples, timesteps, embedding_dim, units = 2, 5, 4, 3
embedding_num = 12


@keras_test
def rnn_test(f):
"""
Expand Down Expand Up @@ -344,6 +345,7 @@ def test_specify_initial_state_non_keras_tensor(layer_class):
targets = np.random.random((num_samples, units))
model.fit(inputs, targets)


@rnn_test
def test_reset_states_with_values(layer_class):
num_states = 2 if layer_class is recurrent.LSTM else 1
Expand Down