[Hexagon] Support template-free meta schedule tuning #12854

Merged 4 commits on Oct 3, 2022
57 changes: 55 additions & 2 deletions python/tvm/meta_schedule/default_config.py
@@ -174,10 +174,12 @@ def schedule_rules( # pylint: disable=redefined-outer-name
return sch_rules()
if sch_rules is not None:
raise TypeError(f"Expected `sch_rules` to be None or callable, but gets: {sch_rules}")
if target.kind.name in ["llvm", "hexagon"]:
if target.kind.name == "llvm":
return _DefaultLLVM.schedule_rules()
if target.kind.name in ["cuda", "rocm", "vulkan"]:
return _DefaultCUDA.schedule_rules()
if target.kind.name == "hexagon":
return _DefaultHexagon.schedule_rules()
raise ValueError(f"Unsupported target: {target}")


@@ -190,10 +192,12 @@ def postproc( # pylint: disable=redefined-outer-name
return postproc()
if postproc is not None:
raise TypeError(f"Expected `postproc` to be None or callable, but gets: {postproc}")
if target.kind.name in ["llvm", "hexagon"]:
if target.kind.name == "llvm":
return _DefaultLLVM.postprocs()
if target.kind.name in ["cuda", "rocm", "vulkan"]:
return _DefaultCUDA.postprocs()
if target.kind.name == "hexagon":
return _DefaultHexagon.postprocs()
raise ValueError(f"Unsupported target: {target}")


@@ -277,6 +281,55 @@ def mutator_probs() -> Dict[Mutator, float]:
}


class _DefaultHexagon:
"""Default tuning configuration for Hexagon."""

@staticmethod
def schedule_rules() -> List[ScheduleRule]:
from tvm.meta_schedule import schedule_rule as M

return [
M.AutoInline(
into_producer=False,
into_consumer=True,
inline_const_tensor=True,
disallow_if_then_else=True,
require_injective=True,
require_ordered=True,
disallow_op=["tir.exp"],
),
M.MultiLevelTilingWideVector(
structure="SRSRS",
vector_length_in_bits=1024,
max_innermost_factor=128,
reuse_read=None,
reuse_write=M.ReuseType(
req="may",
levels=[1, 2],
scope="global",
),
),
M.ParallelizeVectorizeUnroll(
max_jobs_per_core=16,
max_vectorize_extent=128,
unroll_max_steps=[0, 16, 64, 512],
unroll_explicit=True,
),
]

@staticmethod
def postprocs() -> List[Postproc]:
from tvm.meta_schedule import postproc as M

return [
M.DisallowDynamicLoop(),
M.RewriteParallelVectorizeUnroll(),
M.RewriteReductionBlock(),
# TODO(masahi): Fix RewriteLayout for link-params=True case
# M.RewriteLayout(),
]


class _DefaultCUDA:
"""Default tuning configuration for CUDA."""

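For reference, a minimal usage sketch (not part of this diff) of how the dispatch above selects the new Hexagon defaults; it assumes the two-argument helper signatures visible in this file and reuses the v68 target construction from the tests below.

import tvm
from tvm.meta_schedule import default_config

# Hexagon target, constructed the same way as in the tests below.
target_hexagon = tvm.target.hexagon("v68")
target = tvm.target.Target(target_hexagon, host=target_hexagon)

# Passing None falls through to the target-kind dispatch, so a "hexagon"
# target now returns the _DefaultHexagon rules and postprocs.
rules = default_config.schedule_rules(None, target)
postprocs = default_config.postproc(None, target)

print([type(r).__name__ for r in rules])      # AutoInline, MultiLevelTilingWideVector, ...
print([type(p).__name__ for p in postprocs])  # DisallowDynamicLoop, ...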
29 changes: 26 additions & 3 deletions python/tvm/meta_schedule/tune.py
@@ -554,6 +554,7 @@ def tune_relay(
postprocs: Optional[FnPostproc] = None,
mutator_probs: Optional[FnMutatorProb] = None,
num_threads: Optional[int] = None,
executor=None,
) -> Union[Module, vm.Executable]:
"""Tune a Relay IRModule with a given target.

@@ -581,6 +582,9 @@
The callbacks used during tuning.
backend : str = "graph"
The backend to use for relay compilation(graph / vm).
executor : relay.backend.Executor
The executor to be passed to relay.build(...). In particular, its link-params
attribute affects task extraction and workload database lookup.

Returns
-------
@@ -596,8 +600,23 @@
target = default_config.target(target)
# pylint: enable=protected-access,
# parse the tuning contexts

if executor is None:
executor = relay.backend.Executor("graph")

if "link-params" in executor.attrs:
link_params = executor.attrs["link-params"]
else:
link_params = False

with Profiler.timeit("TaskExtraction"):
extracted_tasks = extract_task_from_relay(mod, target, params)
pass_config = {
"relay.FuseOps.link_params": link_params,
"relay.backend.use_meta_schedule": True,
"relay.backend.tir_converter": "default",
}
Member Author:

See

if pass_config is None:
pass_config = {
"relay.backend.use_meta_schedule": True,
"relay.backend.tir_converter": tir_converter,
}
for why this change is necessary. We only need to pass the relay.FuseOps.link_params config; the others are there for compatibility with the existing code.

extracted_tasks = extract_task_from_relay(mod, target, params, pass_config=pass_config)

database = tune_extracted_tasks(
extracted_tasks,
config,
@@ -613,7 +632,7 @@
mutator_probs=mutator_probs,
num_threads=num_threads,
)
relay_build = {"graph": relay.build, "vm": relay.vm.compile}[backend]

with Profiler.timeit("PostTuningCompilation"):
with target, autotvm_silencer(), database:
with PassContext(
@@ -624,4 +643,8 @@
"relay.backend.tir_converter": "default",
},
):
return relay_build(mod, target=target, params=params)
if backend == "graph":
return relay.build(mod, target=target, params=params, executor=executor)

# Executor is not supported by VM
return relay.vm.compile(mod, target=target, params=params)
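For context, here is a hedged usage sketch of the new executor argument together with the Hexagon builder/runner helpers; it mirrors the test_conv2d_relay_auto_schedule test added below, and `mod`, `params`, `target`, and `hexagon_launcher` are placeholders supplied by the caller.

import tempfile

from tvm import meta_schedule as ms
from tvm.contrib.hexagon.meta_schedule import get_hexagon_local_builder, get_hexagon_rpc_runner
from tvm.relay.backend import Executor

config = ms.TuneConfig(
    strategy="replay_trace",
    num_trials_per_iter=8,
    max_trials_per_task=8,
    max_trials_global=8,
)

with tempfile.TemporaryDirectory() as work_dir:
    # link-params=True is read from executor.attrs and forwarded to task
    # extraction as the relay.FuseOps.link_params pass config.
    executor = Executor("graph", {"link-params": True})
    lib = ms.tune_relay(
        mod=mod,
        params=params,
        target=target,
        config=config,
        work_dir=work_dir,
        builder=get_hexagon_local_builder(),
        runner=get_hexagon_rpc_runner(hexagon_launcher, number=20),
        executor=executor,
    )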
212 changes: 210 additions & 2 deletions tests/python/contrib/test_hexagon/test_meta_schedule.py
@@ -21,15 +21,20 @@
import tempfile

import tvm.testing
from tvm import te
import tvm.topi.testing
from tvm import te, relay
from tvm import meta_schedule as ms
from tvm.meta_schedule.arg_info import TensorInfo
from tvm.meta_schedule.builder import BuilderInput
from tvm.meta_schedule import postproc, schedule_rule
from tvm.script import tir as T
from tvm.tir import FloatImm
from tvm.tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN
from tvm.meta_schedule.runner import RunnerInput
from tvm.contrib.hexagon.meta_schedule import get_hexagon_local_builder, get_hexagon_rpc_runner
from tvm.relay.backend import Executor
from tvm.topi.utils import get_const_tuple
from tvm.meta_schedule.testing import te_workload

MATMUL_N = 16
MATMUL_M = 32
@@ -166,7 +171,6 @@ def verify_dense(sch, target, M, N, K, hexagon_session):
print("%f ms, %f GOPS" % (time_ms, gflops / (time_ms / 1e3)))


@pytest.mark.skip(reason="xgboost not installed on CI")
@tvm.testing.requires_hexagon
def test_vrmpy_dense(hexagon_launcher):
if hexagon_launcher._serial_number == "simulator":
@@ -209,3 +213,207 @@ def schedule_dense_for_tune(sch):

with hexagon_launcher.start_session() as session:
verify_dense(sch, target, M, N, K, session)


# This is an example of a schedule found by vrmpy auto tensorization.
# It gets 440 GFLOPS on SD888.
@tvm.script.ir_module
class Module_vrmpy_auto_tensorize:
@T.prim_func
def main(
X: T.Buffer[(128, 768), "uint8"],
packedW: T.Buffer[(24, 192, 32, 4), "uint8"],
compute: T.Buffer[(128, 768), "int32"],
) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
for i0_0_i1_0_0_fused in T.parallel(
512, annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}
):
for i0_1_init, i1_0_1_init, i0_2_init, i1_0_2_init in T.grid(2, 3, 1, 1):
with T.block("compute_o_init"):
i = T.axis.spatial(128, i0_0_i1_0_0_fused // 8 * 2 + i0_1_init + i0_2_init)
j_o = T.axis.spatial(24, i1_0_2_init + i0_0_i1_0_0_fused % 8 * 3 + i1_0_1_init)
T.reads()
T.writes(compute[i, j_o * 32 : j_o * 32 + 32])
for i1_1 in T.vectorized(32):
with T.block("compute_init"):
j_i_init = T.axis.spatial(32, i1_1)
T.reads()
T.writes(compute[i, j_o * 32 + j_i_init])
compute[i, j_o * 32 + j_i_init] = 0
for i2_0_0, i0_1, i1_0_1, i2_0_1, i0_2, i1_0_2 in T.grid(32, 2, 3, 6, 1, 1):
with T.block("compute_o_update"):
i = T.axis.spatial(128, i0_0_i1_0_0_fused // 8 * 2 + i0_1 + i0_2)
j_o = T.axis.spatial(24, i1_0_2 + i0_0_i1_0_0_fused % 8 * 3 + i1_0_1)
k_o = T.axis.reduce(192, i2_0_0 * 6 + i2_0_1)
T.reads(
compute[i, j_o * 32 : j_o * 32 + 32],
X[i, k_o * 4 : k_o * 4 + 4],
packedW[j_o, k_o, 0:32, 0:4],
)
T.writes(compute[i, j_o * 32 : j_o * 32 + 32])
A = T.match_buffer(
X[i, k_o * 4 : k_o * 4 + 4], [4], dtype="uint8", offset_factor=1
)
B = T.match_buffer(
packedW[j_o, k_o, 0:32, 0:4], [32, 4], dtype="uint8", offset_factor=1
)
C = T.match_buffer(
compute[i, j_o * 32 : j_o * 32 + 32], [32], dtype="int32", offset_factor=1
)
A_u8x4: T.uint8x4 = A[0:4]
A_i32: T.int32 = T.reinterpret(A_u8x4, dtype="int32")
B_i32x32: T.int32x32 = T.reinterpret(B[0, 0:128], dtype="int32x32")
C[0:32] = T.call_llvm_pure_intrin(
4390, T.uint32(3), C[0:32], B_i32x32, A_i32, dtype="int32x32"
)


@tvm.testing.requires_hexagon
def test_vrmpy_dense_auto_tensorize(hexagon_launcher):
if hexagon_launcher._serial_number == "simulator":
pytest.skip(msg="Tuning on simulator not supported.")

target_hexagon = tvm.target.hexagon("v68")
target = tvm.target.Target(target_hexagon, host=target_hexagon)

M, N, K = 128, 768, 768
workload = te.create_prim_func(dense(M, N, K))

sch_rules = [
schedule_rule.MultiLevelTilingWithIntrin(
VRMPY_u8u8i32_INTRIN,
structure="SRSRS",
tile_binds=None,
max_innermost_factor=64,
vector_load_lens=None,
reuse_read=None,
reuse_write=schedule_rule.ReuseType(
req="may",
levels=[1, 2],
scope="global",
),
),
schedule_rule.ParallelizeVectorizeUnroll(
max_jobs_per_core=16,
max_vectorize_extent=128,
unroll_max_steps=[0, 16, 64, 512],
unroll_explicit=True,
),
]

postprocs = [
postproc.RewriteParallelVectorizeUnroll(),
postproc.RewriteReductionBlock(),
postproc.RewriteTensorize(vectorize_init_loop=True),
]

if True:
Contributor:

Is this a leftover from something?

Member Author (@masahi, Sep 30, 2022):

I intentionally left it so that people can experiment with both the "then" and "else" paths. The else path just compiles and runs the best schedule found in my experiment, which reproduces the 440 GOPS performance.

with tempfile.TemporaryDirectory() as work_dir:
config = ms.TuneConfig(
strategy="replay_trace",
num_trials_per_iter=8,
max_trials_per_task=8,
max_trials_global=8,
)

sch = ms.tune_tir(
mod=workload,
target=target,
config=config,
work_dir=work_dir,
sch_rules=lambda: sch_rules,
postprocs=lambda: postprocs,
builder=get_hexagon_local_builder(),
runner=get_hexagon_rpc_runner(hexagon_launcher, number=10),
)
else:
sch = tvm.tir.Schedule(Module_vrmpy_auto_tensorize, debug_mask="all")

with hexagon_launcher.start_session() as session:
verify_dense(sch, target, M, N, K, session)


@tvm.testing.requires_hexagon
def test_conv2d_relay_auto_schedule(hexagon_launcher):
if hexagon_launcher._serial_number == "simulator":
pytest.skip(msg="Tuning on simulator not supported.")

target_hexagon = tvm.target.hexagon("v69")
target = tvm.target.Target(target_hexagon, host=target_hexagon)
I, O, H, W = 64, 64, 56, 56
kH = kW = 3

strides = (1, 1)
padding = (1, 1)

d_shape = (1, H, W, I)
w_shape = (kH, kW, I, O)
bias_shape = (1, 1, 1, w_shape[3])
out_channel = w_shape[3]

data = relay.var("data", shape=d_shape, dtype="float16")
weight = relay.var("weight", shape=w_shape, dtype="float16")
bias = relay.var("bias", shape=bias_shape, dtype="float16")
conv2d = relay.nn.conv2d(
data=data,
weight=weight,
kernel_size=(kH, kW),
channels=out_channel,
padding=padding,
strides=strides,
out_dtype="float16",
data_layout="NHWC",
kernel_layout="HWIO",
)
mod = tvm.IRModule.from_expr(conv2d + bias)

data_np = np.random.randn(*d_shape).astype("float16")
weight_np = np.random.randn(*w_shape).astype("float16")
bias_np = np.random.randn(*bias_shape).astype("float16")
params = {"weight": weight_np, "bias": bias_np}

target_llvm = tvm.target.Target("llvm")

with tvm.transform.PassContext(
opt_level=3,
):
lib_ref = relay.build(mod, target=target_llvm, params=params)

rt_mod_ref = tvm.contrib.graph_executor.GraphModule(lib_ref["default"](tvm.cpu(0)))

rt_mod_ref.set_input("data", data_np)

rt_mod_ref.run()

ref = rt_mod_ref.get_output(0).numpy()

config = ms.TuneConfig(
strategy="replay_trace",
num_trials_per_iter=8,
max_trials_per_task=8,
max_trials_global=8,
)

with tempfile.TemporaryDirectory() as work_dir:
executor = Executor("graph", {"link-params": True})
lib = ms.tune_relay(
mod=mod,
params=params,
target=target,
config=config,
work_dir=work_dir,
builder=get_hexagon_local_builder(),
runner=get_hexagon_rpc_runner(hexagon_launcher, number=20),
executor=executor,
)

with hexagon_launcher.start_session() as session:
rt_mod = session.get_executor_from_factory(lib)

rt_mod.set_input("data", data_np)

rt_mod.run()

out = rt_mod.get_output(0).numpy()
print(np.max(np.abs(ref - out)), np.mean(np.abs(ref - out)))
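As a possible follow-up check (not part of this PR), the printed error statistics could be complemented by an explicit assertion; the tolerances below are assumed fp16-level values, not numbers taken from this change.

# Hypothetical accuracy check; rtol/atol are assumed fp16 tolerances.
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)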