Skip to content

Commit

Permalink
Use multi-tensor sumSQ in clip_global_norm (apache#17652)
Browse files (browse the repository at this point in the history)
* Use multi-tensor sumSQ in clip_global_norm

* fix pylint
  • Loading branch information
MoisesHer committed Apr 10, 2020
1 parent 27f8ba6 commit f9b6ed6
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
24 changes: 16 additions & 8 deletions python/mxnet/gluon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,23 @@ def clip_global_norm(arrays, max_norm, check_isfinite=True):
False. Otherwise a float is returned.
"""
# Helper: return the squared L2 norm (sum of squared elements) of `array`.
def _norm(array):
# NOTE(review): 'default' stype is presumably MXNet's dense storage format - confirm.
if array.stype == 'default':
# Flatten to 1-D so that the dot product of the vector with itself
# yields the sum of squares.
x = array.reshape((-1,))
return ndarray.dot(x, x)
# Any other storage type: take the array's norm and square it.
return array.norm().square()
assert len(arrays) > 0
# group arrays by ctx
# Helper: bucket the arrays in `arr_list` by their `.context` attribute.
# Returns a defaultdict mapping each context to the list of arrays that
# report it, in the order they were encountered.
def group_by_ctx(arr_list):
# defaultdict(list) creates the per-context list on first access.
groups = collections.defaultdict(list)
for arr in arr_list:
ctx = arr.context
groups[ctx].append(arr)
return groups
arrays_groups = group_by_ctx(arrays)
all_ctx_sum = []
ctx = arrays[0].context
total_norm = ndarray.add_n(*[_norm(arr).as_in_context(ctx) for arr in arrays])
total_norm = ndarray.sqrt(total_norm)
for group in arrays_groups:
sum_sq = ndarray.multi_sum_sq(*arrays_groups[group],
num_arrays=len(arrays_groups[group]))
sum_sq = ndarray.add_n(*sum_sq)
all_ctx_sum.append(sum_sq.as_in_context(ctx))
# global reduce
total_norm = ndarray.add_n(*all_ctx_sum).sqrt()
if check_isfinite:
if not np.isfinite(total_norm.asscalar()):
warnings.warn(
Expand Down
14 changes: 9 additions & 5 deletions tests/python/gpu/test_gluon_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,14 +336,18 @@ def test_global_norm_clip_multi_device():
for check_isfinite in [True, False]:
x1 = mx.nd.ones((3, 3), ctx=mx.gpu(0))
x2 = mx.nd.ones((4, 4), ctx=mx.cpu(0))
x3 = mx.nd.ones((7, 4), ctx=mx.gpu(0))
x4 = mx.nd.ones((7, 4), ctx=mx.cpu(0))
norm = gluon.utils.clip_global_norm(
[x1, x2], 1.0, check_isfinite=check_isfinite)
[x1, x2, x3, x4], 1.0, check_isfinite=check_isfinite)
if check_isfinite:
assert norm == 5.0
assert norm == 9.0
else:
assert norm.asscalar() == 5.0
assert_almost_equal(x1, np.ones((3, 3)) / 5)
assert_almost_equal(x2, np.ones((4, 4)) / 5)
assert norm.asscalar() == 9.0
assert_almost_equal(x1, np.ones((3, 3)) / 9)
assert_almost_equal(x2, np.ones((4, 4)) / 9)
assert_almost_equal(x3, np.ones((7, 4)) / 9)
assert_almost_equal(x4, np.ones((7, 4)) / 9)


def _check_batchnorm_result(input, num_devices=1, cuda=False):
Expand Down

0 comments on commit f9b6ed6

Please sign in to comment.