diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index c4d49e82c908..028e66075100 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -194,18 +194,14 @@ def _init_kvstore(self): if config['update_on_kvstore'] is not None: update_on_kvstore = config['update_on_kvstore'] - if kvstore: if self._compression_params: kvstore.set_gradient_compression(self._compression_params) self._distributed = 'dist' in kvstore.type if self._distributed: # kv.pull(row_sparse_grad) is not supported for dist kvstore - # Captures condition for dist_async, dist_device_sync or based on config for - # update_on_kvstore update_on_kvstore = self._contains_sparse_weight or self._contains_sparse_grad \ - or 'device' in kvstore.type or 'async' in kvstore.type \ - or config['update_on_kvstore'] + or 'async' in kvstore.type if update_on_kvstore: # optimizer preferably needs to be set before init for multiprecision kvstore.set_optimizer(self._optimizer) @@ -273,20 +269,13 @@ def step(self, batch_size, ignore_stale_grad=False): If true, ignores Parameters with stale gradient (gradient that has not been updated by `backward` after last step) and skip update. """ - rescale_grad = self._scale / batch_size - if self._update_on_kvstore and self._distributed and \ - self._optimizer.rescale_grad != rescale_grad: - raise UserWarning('Possible change in the `batch_size` from previous `step` detected.' \ - 'Optimizer gradient normalizing factor will not change w.r.t new batch_size when ' \ - 'update_on_kvstore=True and when distributed `kvstore` is used.') - - self._optimizer.rescale_grad = rescale_grad - if not self._kv_initialized: self._init_kvstore() if self._params_to_init: self._init_params() + self._optimizer.rescale_grad = self._scale / batch_size + self._allreduce_grads() self._update(ignore_stale_grad) diff --git a/tests/nightly/dist_device_sync_kvstore.py b/tests/nightly/dist_device_sync_kvstore.py index 7fd0333aea79..75b48f42c5e8 100644 --- a/tests/nightly/dist_device_sync_kvstore.py +++ b/tests/nightly/dist_device_sync_kvstore.py @@ -90,25 +90,6 @@ def check_init(kv, cur_keys, cur_shape, device=False): my_rank = kv.rank print('worker ' + str(my_rank) + ' is initialized') -def test_gluon_trainer_type(): - def check_trainer_kv_update(update_on_kv): - params = mx.gluon.ParameterDict() - x = params.get('x', shape=(10,1), lr_mult=1.0) - params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') - try: - trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv, update_on_kvstore=update_on_kv) - trainer._init_kvstore() - assert trainer._kv_initialized - assert trainer._update_on_kvstore is True - except ValueError: - assert update_on_kv is False - - check_trainer_kv_update(False) - check_trainer_kv_update(True) - check_trainer_kv_update(None) - my_rank = kv.rank - print('worker ' + str(my_rank) + ' passed test_gluon_trainer_type') - if __name__ == "__main__": test_sync_init() test_sync_push_pull()