[AutoTVM] Re-enable ref_input #8113

Merged (8 commits, Jul 17, 2021)
53 changes: 41 additions & 12 deletions python/tvm/autotvm/measure/measure_methods.py
@@ -32,6 +32,7 @@
import typing
from collections import namedtuple
from random import getrandbits
import warnings

import tvm._ffi
import tvm.ir.transform
@@ -235,13 +236,33 @@ def __init__(
self.number = number
self.repeat = repeat
self.min_repeat_ms = min_repeat_ms
self._ref_input = None

self.enable_cpu_cache_flush = enable_cpu_cache_flush
self.cooldown_interval = cooldown_interval
self.module_loader = module_loader

self.executor = LocalExecutor(timeout=timeout * (self.n_parallel + 1))

@property
def ref_input(self):
"""
Fixed input for tuning special operators, e.g., sparse operators
requiring indices as input.
"""
return self._ref_input

@ref_input.setter
def ref_input(self, val):
warnings.warn(
"You are specifying fixed input for tuning the operator. "
"Be sure your input always fits the operator. Some "
"operators may conduct layout transformation during tuning, "
"which can lead to unexpected behavior.",
RuntimeWarning,
)
self._ref_input = val

def set_task(self, task):
self.task = task
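
For context, here is a minimal sketch of how a user would exercise the new setter; the runner arguments and input shapes are hypothetical, and the RuntimeWarning above fires on assignment:

import numpy as np
from tvm import autotvm

# Hypothetical usage: pin the measurement inputs (e.g. fixed sparse indices)
# instead of letting AutoTVM fill them with random data.
runner = autotvm.LocalRunner(number=10, repeat=1)
runner.ref_input = [np.random.rand(128, 128).astype("float32")]  # emits RuntimeWarning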

@@ -308,6 +329,7 @@ def run(self, measure_inputs, build_results):
self.min_repeat_ms,
self.cooldown_interval,
remote_kwargs,
self.ref_input,
self.enable_cpu_cache_flush,
module_loader,
)
@@ -508,6 +530,7 @@ def run_through_rpc(
min_repeat_ms,
cooldown_interval,
remote_kwargs,
ref_input,
enable_cpu_cache_flush=False,
module_loader=None,
):
@@ -539,6 +562,8 @@
The cool down interval between two measurements
remote_kwargs: dict
Passed to module_loader(). Ultimately, keyword args to request_remote().
ref_input: List of np.ndarray
The reference inputs used for tuning. If None or empty, the inputs are filled randomly.
enable_cpu_cache_flush: bool
Whether to flush cache on CPU between repeated measurements.
Flushing cache can make the measured latency of one operator closer to
its actual latency during end-to-end inference.
@@ -573,18 +598,22 @@
f_preproc=f_prepare,
)

if ref_input:
args = [nd.array(x, device=dev) for x in ref_input]
else:
try:
random_fill = remote.get_function("tvm.contrib.random.random_fill")
except AttributeError:
raise AttributeError(
"Please make sure USE_RANDOM is ON in the config.cmake "
"on the remote devices"
)
args = [nd.empty(x[0], x[1], dev) for x in build_result.arg_info]
if "scatter" not in measure_input.task.name:
# the index tensor of scatter op cannot be randomly initialized
for arg in args:
random_fill(arg)
dev.sync()

costs = time_f(*args).results

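To make the branch above concrete, here is a self-contained sketch of the argument-preparation dispatch, assuming a local CPU device, a toy arg_info, and a TVM build with USE_RANDOM enabled; it mirrors, but is not, the PR's exact code path:

import numpy as np
import tvm
from tvm import nd

# Toy stand-ins for the remote device and build_result.arg_info used above.
dev = tvm.cpu(0)
arg_info = [((4, 4), "float32"), ((4, 4), "float32")]
ref_input = [np.ones(shape, dtype=dtype) for shape, dtype in arg_info]

if ref_input:
    # Fixed inputs: upload the user-provided arrays to the device as-is.
    args = [nd.array(x, device=dev) for x in ref_input]
else:
    # Random inputs: allocate empty buffers and fill them on the device.
    random_fill = tvm.get_global_func("tvm.contrib.random.random_fill")
    args = [nd.empty(shape, dtype, dev) for shape, dtype in arg_info]
    for arg in args:
        random_fill(arg)
dev.sync()
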
24 changes: 24 additions & 0 deletions tests/python/unittest/test_autotvm_measure.py
@@ -26,6 +26,8 @@
from test_autotvm_common import DummyRunner, bad_matmul, get_sample_task
from tvm import autotvm
from tvm.autotvm.measure.measure import MeasureErrorNo, MeasureResult
from tvm.autotvm import measure
from inspect import Signature


def test_task_tuner_without_measurement():
@@ -60,8 +62,30 @@ def test_task_tuner_without_measurement_spawn():
p.join()


def test_task_runner_with_ref_input():
"""test runner ref_input without measurement"""
refinp = [np.random.rand(128, 128) for _ in range(3)]
runner = measure.LocalRunner()
runner.ref_input = refinp

class DummyExecutor(measure.executor.Executor):
def __init__(self):
self.ran_dummy_executor = False

def submit(self, func, *args, **kwargs):
self.ran_dummy_executor = True
sig = Signature.from_callable(func)
assert sig.bind(*args, **kwargs).arguments["ref_input"] == refinp
Contributor review comment (addressed by the ran_dummy_executor flag in this test): can you set a variable here in the outer test_task_runner_with_ref_input (e.g. ran_dummy_executor = True) and then assert on that in the test to ensure that this function and assert ran?

return measure.local_executor.LocalFutureNoFork(None)

runner.executor = DummyExecutor()
runner.run([None], [None])
assert runner.executor.ran_dummy_executor
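
As an aside, the Signature.bind check used in the dummy executor behaves as in this minimal standalone snippet (the function f here is hypothetical, unrelated to the TVM API):

from inspect import Signature

def f(a, b, ref_input=None):
    pass

sig = Signature.from_callable(f)
bound = sig.bind(1, 2, ref_input=[3])
assert bound.arguments["ref_input"] == [3]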


if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)

test_task_tuner_without_measurement()
test_task_tuner_without_measurement_spawn()
test_task_runner_with_ref_input()