Skip to content

Commit

Permalink
[AUTOTVM] Fix a bug in generating the search space (#4779)
Browse files Browse the repository at this point in the history
- Do not use numpy.prod which ignores integer (64 bits) overflows.
  This leads to an incorrect number of points in the search space.
  • Loading branch information
wpan11nv committed Jan 29, 2020
1 parent 3827ccb commit 1b8522e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
4 changes: 3 additions & 1 deletion python/tvm/autotvm/task/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,9 @@ def __init__(self, axes, policy, **kwargs):
def _generate_space(self, now, tmp_stack, enforce_no_tail=False):
"""Generate space by DFS"""
if now == self.num_output - 1:
prod = np.prod(tmp_stack, dtype=np.int64)
prod = functools.reduce(lambda x, y: x * y, tmp_stack)
if prod > self.product:
return
if self.product % prod == 0 or (not enforce_no_tail and prod < self.product):
self.entities.append(SplitEntity([-1] + tmp_stack[::-1]))
else:
Expand Down
15 changes: 15 additions & 0 deletions tests/python/unittest/test_autotvm_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,21 @@ def test_split():
cfg.define_split('tile_c', cfg.axis(224), policy='verbose', num_outputs=3)
assert len(cfg.space_map['tile_c']) == 84

# Count the number of non-negative integer solutions of a + b + c + d = n
def count4(n):
cnt = 0
for a in range(0, n + 1):
for b in range(0, n - a + 1):
cnt += n - a - b + 1
return cnt

# test overflow
n = 25
cfg = ConfigSpace()
cfg.define_split('x', cfg.axis(2**n), policy='factors', num_outputs=4)
# count4(25) is 3276.
assert len(cfg.space_map['x']) == count4(n)

# test fallback
cfg = FallbackConfigEntity()
cfg.define_split('tile_n', cfg.axis(128), num_outputs=3)
Expand Down

0 comments on commit 1b8522e

Please sign in to comment.