
[AutoScheduler] Make SearchTask and ComputeDAG serializable #6842

Merged · 7 commits · Nov 7, 2020
23 changes: 18 additions & 5 deletions python/tvm/auto_scheduler/compute_dag.py
@@ -21,14 +21,14 @@

 import tvm._ffi
 from tvm.runtime import Object
-from tvm.te import PlaceholderOp, ComputeOp
+from tvm.runtime._ffi_node_api import LoadJSON, SaveJSON
+from tvm.te import ComputeOp, PlaceholderOp
 
+from . import _ffi_api
 from .loop_state import State, StateObject
 from .utils import get_const_tuple
 from .workload_registry import workload_key_to_tensors
 
-from . import _ffi_api
-
 
 @tvm._ffi.register_object("auto_scheduler.ComputeDAG")
 class ComputeDAG(Object):
@@ -63,7 +63,10 @@ def __init__(self, compute_or_sche):
         elif isinstance(compute_or_sche, list):
             for item in compute_or_sche:
                 if not isinstance(item, tvm.te.Tensor):
-                    raise ValueError("The input of ComputeDAG should be a list of Tensor")
+                    raise ValueError(
+                        "The input of ComputeDAG should be a list of Tensor, but got %s"
+                        % type(item)
+                    )
             compute = compute_or_sche
             sche = None
         elif isinstance(compute_or_sche, tvm.te.Schedule):
@@ -72,8 +75,10 @@ def __init__(self, compute_or_sche):
         else:
             raise ValueError(
                 "Invalid compute type: %s. ComputeDAG expects string, list of Tensor, or Schedule"
-                % type(compute)
+                % type(compute_or_sche)
             )
+        self.compute = compute
+        self.sche = sche
         self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, compute, sche)
 
     def get_init_state(self):
@@ -182,3 +187,11 @@ def hash_key(self):

         str_key = str_key.encode(encoding="utf-8")
         return hashlib.md5(str_key).hexdigest()
+
+    def __getstate__(self):
+        return {"compute": SaveJSON(self.compute), "sche": SaveJSON(self.sche)}
+
+    def __setstate__(self, state):
+        self.compute = LoadJSON(state["compute"])  # pylint: disable=assignment-from-no-return
+        self.sche = LoadJSON(state["sche"])  # pylint: disable=assignment-from-no-return
+        self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, self.compute, self.sche)

merrymercy (Member), Nov 7, 2020:

@comaniac Do we need to explicitly call SaveJSON and LoadJSON here? I notice that you don't call these functions in SearchTask.__getstate__.

comaniac (Contributor, Author), Nov 7, 2020:

We don't call them in SearchTask because all of its members have __setstate__ and __getstate__ implemented correctly. This function provides the correct implementation for ComputeDAG, and it is invoked when the ComputeDAG member is processed during SearchTask serialization.

As for SaveJSON/LoadJSON here: we could instead use pickle.dumps/pickle.loads and let them call SaveJSON/LoadJSON through Object. However, that would introduce a dependency on pickle in this function, and it triggers a bug in the unit test. You can reproduce it by replacing SaveJSON/LoadJSON with pickle calls: in the unit test I added, the placeholder A appears twice after the DAG is loaded back.
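With this pair in place, a ComputeDAG survives a pickle round trip. Below is a minimal sketch of the behavior under discussion, assuming a TVM build that includes this PR; the toy matmul compute is illustrative only and not code from this PR:

import pickle

from tvm import te, auto_scheduler

# A toy compute definition; the tensor names here are illustrative only.
N = 128
A = te.placeholder((N, N), name="A")
B = te.placeholder((N, N), name="B")
k = te.reduce_axis((0, N), name="k")
C = te.compute((N, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")

# pickle dispatches to the __getstate__/__setstate__ pair added above, which
# serializes the compute and schedule through SaveJSON/LoadJSON.
dag = auto_scheduler.ComputeDAG([A, B, C])
loaded_dag = pickle.loads(pickle.dumps(dag))

# The recovered DAG should describe the same compute; in particular, the
# placeholder A must appear exactly once (the duplication that the explicit
# SaveJSON/LoadJSON calls avoid).
assert str(loaded_dag.get_init_state()) == str(dag.get_init_state())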
29 changes: 29 additions & 0 deletions python/tvm/auto_scheduler/search_task.py
@@ -42,6 +42,35 @@ class SearchTask(Object):
"""

def __init__(self, dag, workload_key, target, target_host=None, hardware_params=None):
self.dag = dag
self.workload_key = workload_key
self.target = target
self.target_host = target_host
self.hardware_params = hardware_params
self.__init_handle_by_constructor__(
_ffi_api.SearchTask, dag, workload_key, target, target_host, hardware_params
)

def __getstate__(self):
return {
"dag": self.dag,
"workload_key": self.workload_key,
"target": self.target,
"target_host": self.target_host,
"hardware_params": self.hardware_params,
}

def __setstate__(self, state):
self.dag = state["dag"]
self.workload_key = state["workload_key"]
self.target = state["target"]
self.target_host = state["target_host"]
self.hardware_params = state["hardware_params"]
self.__init_handle_by_constructor__(
_ffi_api.SearchTask,
self.dag,
self.workload_key,
self.target,
self.target_host,
self.hardware_params,
)
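Note the pattern here: SearchTask caches its constructor arguments as Python attributes, returns them from __getstate__, and rebuilds the underlying handle in __setstate__, so each member (ComputeDAG, Target, HardwareParams) serializes itself. A short usage sketch, reusing the hypothetical dag from the example above and mirroring the unit test added below:

import pickle

import tvm
from tvm import auto_scheduler

# "test1" is a placeholder workload key, as in the unit test.
task = auto_scheduler.SearchTask(dag, "test1", tvm.target.Target("llvm"))
task2 = pickle.loads(pickle.dumps(task))

# The deserialized task reconstructs its handle from the restored members.
assert task.workload_key == task2.workload_key
assert str(task.dag.get_init_state()) == str(task2.dag.get_init_state())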
4 changes: 1 addition & 3 deletions tests/python/unittest/test_auto_scheduler_common.py
@@ -161,14 +161,12 @@ def conv2d_winograd_nhwc_auto_scheduler_test(
     r = KW
     m = tile_size
     alpha = m + r - 1
-    A, B, G = winograd_transform_matrices(m, r, "float32")
+    A, B, _ = winograd_transform_matrices(m, r, "float32")
 
     H = (H + 2 * HPAD - KH) // HSTR + 1
     W = (W + 2 * WPAD - KW) // WSTR + 1
     nH, nW = (H + m - 1) // m, (W + m - 1) // m
     P = N * nH * nW
-    r_kh = te.reduce_axis((0, KH), name="r_kh")
-    r_kw = te.reduce_axis((0, KW), name="r_kw")
     kshape = (alpha, alpha, CI, CO)
     kernel_pack = te.placeholder(kshape, inputs.dtype, name="weight")
30 changes: 29 additions & 1 deletion tests/python/unittest/test_auto_scheduler_compute_dag.py
@@ -16,6 +16,7 @@
 # under the License.
 
 """Test ComputeDAG (replay, infer bound)"""
+import pickle
 
 import tvm
 from tvm import topi
@@ -32,7 +33,7 @@ def test_apply_steps():
     dag, s = get_tiled_matmul()
     dag.print_python_code_from_state(s)
     sch, tensors = dag.apply_steps_from_state(s)
-    stmt = tvm.lower(sch, tensors, simple_mode=True)
+    tvm.lower(sch, tensors, simple_mode=True)
 
 
 def test_infer_bound():
@@ -61,6 +62,7 @@ def test_estimate_flop():


 def test_stage_order():
+    """Test if the stage order is preserved when recovering a DAG."""
     N = 512
     A, B, C, D, E = parallel_matmul_auto_scheduler_test(N)
     sch = te.create_schedule([D.op, E.op])
@@ -87,6 +89,11 @@ def test_stage_order():
         elif op.name in ["B", "C"]:
             assert stage_ops_1[idx + 1].name == "%s.shared" % op.name
 
+    # Serialize and deserialize the ComputeDAG constructed by a schedule.
+    loaded_dag = pickle.loads(pickle.dumps(dag))
+    assert str(loaded_dag.get_init_state()) == str(dag.get_init_state())
+    assert len(loaded_dag.get_init_state().stage_ops) == len(dag.get_init_state().stage_ops)
+
     # Apply the same schedule to Ansor state and it should have the same stage order
     dag = auto_scheduler.ComputeDAG([A, B, C, D, E])
     state = dag.get_init_state()
@@ -105,6 +112,27 @@ def test_stage_order():
     for op1, op2 in zip(stage_ops_1, stage_ops_2):
         assert op1.name == op2.name
 
+    # Serialize and deserialize the ComputeDAG constructed by a list of tensor ops.
+    loaded_dag = pickle.loads(pickle.dumps(dag))
+    assert str(loaded_dag.get_init_state()) == str(dag.get_init_state())
+    assert len(loaded_dag.get_init_state().stage_ops) == len(dag.get_init_state().stage_ops)
+
+    # Serialize and deserialize the search task.
+    task = auto_scheduler.SearchTask(
+        dag,
+        "test1",
+        tvm.target.Target("llvm"),
+        hardware_params=auto_scheduler.HardwareParams(100000, 16, 64),
+    )
+    task2 = pickle.loads(pickle.dumps(task))
+    assert str(task.dag.get_init_state()) == str(task2.dag.get_init_state())
+    assert len(task.dag.get_init_state().stage_ops) == len(task2.dag.get_init_state().stage_ops)
+    assert task.workload_key == task2.workload_key
+    assert str(task.target) == str(task2.target)
+    assert task.hardware_params.num_cores == task2.hardware_params.num_cores
+    assert task.hardware_params.vector_unit_bytes == task2.hardware_params.vector_unit_bytes
+    assert task.hardware_params.cache_line_bytes == task2.hardware_params.cache_line_bytes
+
 
 if __name__ == "__main__":
     test_apply_steps()