MXNet: support broadcasting deferred initialization parameters in Gluon (horovod#915)

* Create DistributedInitializer to broadcast deferred-init param

Signed-off-by: Yuxi Hu <darrenyxhu@gmail.com>

* inject broadcast to init_impl

Signed-off-by: Yuxi Hu <darrenyxhu@gmail.com>

* add unit test

Signed-off-by: Yuxi Hu <darrenyxhu@gmail.com>
Signed-off-by: Yana Shchyokotova <yana.shchyokotova@intel.com>
yuxihu authored and shirosankaku committed May 30, 2019
1 parent 2b35e92 commit 3646e38
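
For orientation, a rough sketch of the behaviour this commit enables (illustration only, not part of the change; the Dense layer and shapes below are arbitrary): a Gluon parameter whose shape is only known at the first forward pass can now be handed to hvd.broadcast_parameters up front, and the broadcast is deferred until the parameter is actually initialized.

import mxnet as mx
import horovod.mxnet as hvd

hvd.init()

# in_units is not specified, so the weight shape is unknown here and its
# initialization is deferred until the first forward pass.
net = mx.gluon.nn.Dense(10)
net.initialize()

# Previously, deferred-init parameters were silently skipped at this point;
# with this commit a broadcast is injected to run right after they are
# actually initialized.
hvd.broadcast_parameters(net.collect_params(), root_rank=0)

# The first forward pass infers the shapes, initializes the weight, and then
# performs the pending broadcast from root_rank.
net(mx.nd.ones((2, 20)))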
Showing 3 changed files with 44 additions and 6 deletions.
examples/mxnet_mnist.py: 5 changes (2 additions, 3 deletions)
@@ -112,13 +112,12 @@ def evaluate(model, data_iter, context):
model.cast(args.dtype)
model.hybridize()

# Define hyper parameters
# Create optimizer
optimizer_params = {'momentum': args.momentum,
                    'learning_rate': args.lr * hvd.size(),
                    'rescale_grad': 1.0 / args.batch_size}

# Add Horovod Distributed Optimizer
opt = mx.optimizer.create('sgd', **optimizer_params)
# Horovod: wrap optimizer with DistributedOptimizer
opt = hvd.DistributedOptimizer(opt)

# Initialize parameters
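As context for the example above, a self-contained sketch of how the wrapped optimizer is typically consumed (assumed setup, not shown in this diff; the model and hyperparameter values are placeholders): gluon.Trainer accepts an Optimizer instance, so the Horovod-wrapped 'sgd' can be passed to it directly, with the kvstore disabled because Horovod handles gradient aggregation.

import mxnet as mx
import horovod.mxnet as hvd

hvd.init()

# Placeholder model and hyperparameters standing in for the MNIST example.
model = mx.gluon.nn.Dense(10, in_units=20)
model.initialize()

optimizer_params = {'momentum': 0.9,
                    'learning_rate': 0.01 * hvd.size(),
                    'rescale_grad': 1.0 / 64}
opt = mx.optimizer.create('sgd', **optimizer_params)
opt = hvd.DistributedOptimizer(opt)

# gluon.Trainer accepts an Optimizer instance directly; the kvstore is
# disabled since Horovod performs the cross-worker gradient aggregation.
trainer = mx.gluon.Trainer(model.collect_params(), opt, kvstore=None)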
horovod/mxnet/__init__.py: 17 changes (15 additions, 2 deletions)
@@ -30,6 +30,7 @@
from horovod.mxnet.mpi_ops import mpi_threads_supported

import mxnet as mx
import types


# This is where Horovod's DistributedOptimizer wrapper for MXNet goes
@@ -68,6 +69,16 @@ def set_wd_mult(self, args_wd_mult):
        self._optimizer.set_wd_mult(args_wd_mult)


# Wrapper to inject Horovod broadcast after parameter initialization
def _append_broadcast_init(param, root_rank):
    init_impl = getattr(param, '_init_impl')
    def wrapped_init_impl(self, *args, **kwargs):
        init_impl(*args, **kwargs)
        broadcast_(self.data(), root_rank=root_rank)
        self.data().wait_to_read()
    return wrapped_init_impl


def broadcast_parameters(params, root_rank=0):
"""
Broadcasts the parameters from root rank to all other processes.
@@ -89,8 +100,10 @@ def broadcast_parameters(params, root_rank=0):
            try:
                tensors.append(p.data())
            except mx.gluon.parameter.DeferredInitializationError:
                # skip broadcasting deferred init param
                pass
                # Inject wrapper method with post-initialization broadcast to
                # handle parameters with deferred initialization
                new_init = _append_broadcast_init(p, root_rank)
                p._init_impl = types.MethodType(new_init, p)
    else:
        raise ValueError('invalid params of type: %s' % type(params))

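To clarify the injection mechanism used above (a minimal stand-alone sketch with a toy class, not MXNet code): _append_broadcast_init captures the parameter's original bound _init_impl, wraps it so the real initialization runs first and the broadcast runs immediately after, and the wrapper is then re-bound to the instance with types.MethodType. The same pattern in isolation:

import types

class Param(object):
    """Toy stand-in for mx.gluon.Parameter (illustration only)."""
    def _init_impl(self, data):
        self.value = data

def _append_post_init_hook(param, hook):
    init_impl = param._init_impl        # original bound method, captured before patching
    def wrapped_init_impl(self, *args, **kwargs):
        init_impl(*args, **kwargs)      # run the real initialization first
        hook(self)                      # then the injected step (broadcast_ in Horovod)
    return wrapped_init_impl

p = Param()
p._init_impl = types.MethodType(_append_post_init_hook(p, lambda q: print('hook ran')), p)
p._init_impl(42)                        # prints 'hook ran'; p.value is now 42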
test/test_mxnet.py: 28 changes (27 additions, 1 deletion)
@@ -355,7 +355,6 @@ def test_horovod_broadcast_grad(self):
        shapes = [(), (17), (17, 17), (17, 17, 17)]
        root_rank = 1
        tensor_dict = {}
        broadcast_dict = {}
        root_dict = {}
        for dtype, dim, in itertools.product(dtypes, dims):
            tensor_dict[count] = mx.nd.ones(shapes[dim], ctx=ctx) * rank
@@ -445,5 +444,32 @@ def test_horovod_broadcast_rank_error(self):
        except (MXNetError, RuntimeError):
            pass

    def test_horovod_broadcast_deferred_init_parameters(self):
        """Test that the deferred initialized parameters are broadcasted."""
        hvd.init()
        root_rank = 0
        rank = hvd.rank()

        # This test does not apply if there is only one worker.
        if hvd.size() == 1:
            return

        mx.random.seed(rank)
        layer = mx.gluon.nn.Conv2D(10, 2)
        layer.initialize()
        hvd.broadcast_parameters(layer.collect_params(), root_rank=root_rank)

        x = mx.nd.ones((5, 4, 10, 10))
        layer(x)
        tensors = [p.data() for _, p in sorted(layer.collect_params().items())]
        root_tensors = []
        for tensor in tensors:
            root_tensors.append(hvd.broadcast(tensor, root_rank=root_rank))

        for tensor, root_tensor in zip(tensors, root_tensors):
            assert same(tensor.asnumpy(), root_tensor.asnumpy()), \
                'horovod did not broadcast deferred initialized parameter correctly'


if __name__ == '__main__':
    unittest.main()
