-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Hexagon] Support template-free meta schedule tuning #12854
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,15 +21,20 @@ | |
import tempfile | ||
|
||
import tvm.testing | ||
from tvm import te | ||
import tvm.topi.testing | ||
from tvm import te, relay | ||
from tvm import meta_schedule as ms | ||
from tvm.meta_schedule.arg_info import TensorInfo | ||
from tvm.meta_schedule.builder import BuilderInput | ||
from tvm.meta_schedule import postproc, schedule_rule | ||
from tvm.script import tir as T | ||
from tvm.tir import FloatImm | ||
from tvm.tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN | ||
from tvm.meta_schedule.runner import RunnerInput | ||
from tvm.contrib.hexagon.meta_schedule import get_hexagon_local_builder, get_hexagon_rpc_runner | ||
from tvm.relay.backend import Executor | ||
from tvm.topi.utils import get_const_tuple | ||
from tvm.meta_schedule.testing import te_workload | ||
|
||
MATMUL_N = 16 | ||
MATMUL_M = 32 | ||
|
@@ -166,7 +171,6 @@ def verify_dense(sch, target, M, N, K, hexagon_session): | |
print("%f ms, %f GOPS" % (time_ms, gflops / (time_ms / 1e3))) | ||
|
||
|
||
@pytest.mark.skip(reason="xgboost not installed on CI") | ||
@tvm.testing.requires_hexagon | ||
def test_vrmpy_dense(hexagon_launcher): | ||
if hexagon_launcher._serial_number == "simulator": | ||
|
@@ -209,3 +213,207 @@ def schedule_dense_for_tune(sch): | |
|
||
with hexagon_launcher.start_session() as session: | ||
verify_dense(sch, target, M, N, K, session) | ||
|
||
|
||
# This is an example of a schedule found by vrmpy auto tensorization. | ||
# It gets 440 GFLOPS on SD888. | ||
# Hand-checked-in TVMScript module: an int32 = uint8 x uint8 dense kernel over
# a (128, 768) input `X`, a pre-packed weight `packedW` and a (128, 768) output
# `compute`. The packed weight is indexed as [j_o, k_o, 0:32, 0:4], i.e. 24
# output tiles of 32 columns and 192 reduction tiles of 4 elements.
@tvm.script.ir_module | ||
class Module_vrmpy_auto_tensorize: | ||
@T.prim_func | ||
def main( | ||
X: T.Buffer[(128, 768), "uint8"], | ||
packedW: T.Buffer[(24, 192, 32, 4), "uint8"], | ||
compute: T.Buffer[(128, 768), "int32"], | ||
) -> None: | ||
T.func_attr({"global_symbol": "main", "tir.noalias": True}) | ||
# Outer loop: 512 parallel iterations. `// 8 * 2` selects a pair of output
# rows (64 * 2 = 128) and `% 8 * 3` selects a group of three 32-wide column
# tiles (8 * 3 = 24).
for i0_0_i1_0_0_fused in T.parallel( | ||
512, annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1} | ||
): | ||
# Initialization: zero each 32-element output slice handled by this
# parallel iteration, using a vectorized inner loop.
for i0_1_init, i1_0_1_init, i0_2_init, i1_0_2_init in T.grid(2, 3, 1, 1): | ||
with T.block("compute_o_init"): | ||
i = T.axis.spatial(128, i0_0_i1_0_0_fused // 8 * 2 + i0_1_init + i0_2_init) | ||
j_o = T.axis.spatial(24, i1_0_2_init + i0_0_i1_0_0_fused % 8 * 3 + i1_0_1_init) | ||
T.reads() | ||
T.writes(compute[i, j_o * 32 : j_o * 32 + 32]) | ||
for i1_1 in T.vectorized(32): | ||
with T.block("compute_init"): | ||
j_i_init = T.axis.spatial(32, i1_1) | ||
T.reads() | ||
T.writes(compute[i, j_o * 32 + j_i_init]) | ||
compute[i, j_o * 32 + j_i_init] = 0 | ||
# Update: the reduction over 192 tiles is split as 32 (outer) x 6 (inner);
# every step accumulates into one 32-element output slice.
for i2_0_0, i0_1, i1_0_1, i2_0_1, i0_2, i1_0_2 in T.grid(32, 2, 3, 6, 1, 1): | ||
with T.block("compute_o_update"): | ||
i = T.axis.spatial(128, i0_0_i1_0_0_fused // 8 * 2 + i0_1 + i0_2) | ||
j_o = T.axis.spatial(24, i1_0_2 + i0_0_i1_0_0_fused % 8 * 3 + i1_0_1) | ||
k_o = T.axis.reduce(192, i2_0_0 * 6 + i2_0_1) | ||
T.reads( | ||
compute[i, j_o * 32 : j_o * 32 + 32], | ||
X[i, k_o * 4 : k_o * 4 + 4], | ||
packedW[j_o, k_o, 0:32, 0:4], | ||
) | ||
T.writes(compute[i, j_o * 32 : j_o * 32 + 32]) | ||
# Sub-buffer views of the operands of one tensorized step:
# A = 4 uint8 inputs, B = 32x4 uint8 weights, C = 32 int32 accumulators.
A = T.match_buffer( | ||
X[i, k_o * 4 : k_o * 4 + 4], [4], dtype="uint8", offset_factor=1 | ||
) | ||
B = T.match_buffer( | ||
packedW[j_o, k_o, 0:32, 0:4], [32, 4], dtype="uint8", offset_factor=1 | ||
) | ||
C = T.match_buffer( | ||
compute[i, j_o * 32 : j_o * 32 + 32], [32], dtype="int32", offset_factor=1 | ||
) | ||
# Reinterpret the uint8 operands as int32 lanes before the intrinsic call.
A_u8x4: T.uint8x4 = A[0:4] | ||
A_i32: T.int32 = T.reinterpret(A_u8x4, dtype="int32") | ||
B_i32x32: T.int32x32 = T.reinterpret(B[0, 0:128], dtype="int32x32") | ||
# NOTE(review): 4390 is an LLVM intrinsic ID; presumably the Hexagon
# vrmpy-accumulate intrinsic for the LLVM build that produced this
# script -- confirm, since intrinsic IDs can differ across LLVM versions.
C[0:32] = T.call_llvm_pure_intrin( | ||
4390, T.uint32(3), C[0:32], B_i32x32, A_i32, dtype="int32x32" | ||
) | ||
|
||
|
||
@tvm.testing.requires_hexagon
def test_vrmpy_dense_auto_tensorize(hexagon_launcher):
    """Tune a dense workload with template-free meta schedule and verify it on device.

    The schedule rules ask meta schedule to tile the workload and tensorize it
    with ``VRMPY_u8u8i32_INTRIN``; the resulting schedule is then handed to
    ``verify_dense`` inside a Hexagon session.

    Parameters
    ----------
    hexagon_launcher
        Launcher fixture; the test is skipped when its ``_serial_number`` is
        ``"simulator"``.
    """
    if hexagon_launcher._serial_number == "simulator":
        # Pass the reason positionally: the ``msg=`` keyword was deprecated in
        # pytest 7.0 and removed afterwards, while ``reason=`` does not exist
        # before 7.0. The positional form works on every version.
        pytest.skip("Tuning on simulator not supported.")

    target_hexagon = tvm.target.hexagon("v68")
    target = tvm.target.Target(target_hexagon, host=target_hexagon)

    M, N, K = 128, 768, 768
    workload = te.create_prim_func(dense(M, N, K))

    sch_rules = [
        schedule_rule.MultiLevelTilingWithIntrin(
            VRMPY_u8u8i32_INTRIN,
            structure="SRSRS",
            tile_binds=None,
            max_innermost_factor=64,
            vector_load_lens=None,
            reuse_read=None,
            reuse_write=schedule_rule.ReuseType(
                req="may",
                levels=[1, 2],
                scope="global",
            ),
        ),
        schedule_rule.ParallelizeVectorizeUnroll(
            max_jobs_per_core=16,
            max_vectorize_extent=128,
            unroll_max_steps=[0, 16, 64, 512],
            unroll_explicit=True,
        ),
    ]

    postprocs = [
        postproc.RewriteParallelVectorizeUnroll(),
        postproc.RewriteReductionBlock(),
        postproc.RewriteTensorize(vectorize_init_loop=True),
    ]

    # Deliberately kept as a switch so that both paths can be experimented
    # with: set to False to skip tuning and use the checked-in schedule
    # ``Module_vrmpy_auto_tensorize`` instead.
    run_tuning = True

    if run_tuning:
        with tempfile.TemporaryDirectory() as work_dir:
            config = ms.TuneConfig(
                strategy="replay_trace",
                num_trials_per_iter=8,
                max_trials_per_task=8,
                max_trials_global=8,
            )

            sch = ms.tune_tir(
                mod=workload,
                target=target,
                config=config,
                work_dir=work_dir,
                sch_rules=lambda: sch_rules,
                postprocs=lambda: postprocs,
                builder=get_hexagon_local_builder(),
                runner=get_hexagon_rpc_runner(hexagon_launcher, number=10),
            )
    else:
        sch = tvm.tir.Schedule(Module_vrmpy_auto_tensorize, debug_mask="all")

    with hexagon_launcher.start_session() as session:
        verify_dense(sch, target, M, N, K, session)
|
||
|
||
@tvm.testing.requires_hexagon
def test_conv2d_relay_auto_schedule(hexagon_launcher):
    """Tune a float16 NHWC conv2d + bias Relay module and run it on Hexagon.

    A reference result is computed with an ``llvm`` build on the host CPU; the
    tuned module is then executed in a Hexagon session and the max / mean
    absolute differences against the reference are printed.

    Parameters
    ----------
    hexagon_launcher
        Launcher fixture; the test is skipped when its ``_serial_number`` is
        ``"simulator"``.
    """
    if hexagon_launcher._serial_number == "simulator":
        # Pass the reason positionally: the ``msg=`` keyword was deprecated in
        # pytest 7.0 and removed afterwards, while ``reason=`` does not exist
        # before 7.0. The positional form works on every version.
        pytest.skip("Tuning on simulator not supported.")

    target_hexagon = tvm.target.hexagon("v69")
    target = tvm.target.Target(target_hexagon, host=target_hexagon)
    I, O, H, W = 64, 64, 56, 56
    kH = kW = 3

    strides = (1, 1)
    padding = (1, 1)

    d_shape = (1, H, W, I)
    w_shape = (kH, kW, I, O)
    bias_shape = (1, 1, 1, w_shape[3])
    out_channel = w_shape[3]

    data = relay.var("data", shape=d_shape, dtype="float16")
    weight = relay.var("weight", shape=w_shape, dtype="float16")
    bias = relay.var("bias", shape=bias_shape, dtype="float16")
    conv2d = relay.nn.conv2d(
        data=data,
        weight=weight,
        kernel_size=(kH, kW),
        channels=out_channel,
        padding=padding,
        strides=strides,
        out_dtype="float16",
        data_layout="NHWC",
        kernel_layout="HWIO",
    )
    mod = tvm.IRModule.from_expr(conv2d + bias)

    data_np = np.random.randn(*d_shape).astype("float16")
    weight_np = np.random.randn(*w_shape).astype("float16")
    bias_np = np.random.randn(*bias_shape).astype("float16")
    params = {"weight": weight_np, "bias": bias_np}

    # Reference result from a host-side llvm build of the same module.
    target_llvm = tvm.target.Target("llvm")

    with tvm.transform.PassContext(
        opt_level=3,
    ):
        lib_ref = relay.build(mod, target=target_llvm, params=params)

    rt_mod_ref = tvm.contrib.graph_executor.GraphModule(lib_ref["default"](tvm.cpu(0)))

    rt_mod_ref.set_input("data", data_np)

    rt_mod_ref.run()

    ref = rt_mod_ref.get_output(0).numpy()

    config = ms.TuneConfig(
        strategy="replay_trace",
        num_trials_per_iter=8,
        max_trials_per_task=8,
        max_trials_global=8,
    )

    with tempfile.TemporaryDirectory() as work_dir:
        executor = Executor("graph", {"link-params": True})
        lib = ms.tune_relay(
            mod=mod,
            params=params,
            target=target,
            config=config,
            work_dir=work_dir,
            builder=get_hexagon_local_builder(),
            runner=get_hexagon_rpc_runner(hexagon_launcher, number=20),
            executor=executor,
        )

    with hexagon_launcher.start_session() as session:
        rt_mod = session.get_executor_from_factory(lib)

        rt_mod.set_input("data", data_np)

        rt_mod.run()

        out = rt_mod.get_output(0).numpy()
        # NOTE(review): the difference is only printed, not asserted, so this
        # test cannot fail on numerical mismatch -- consider adding a
        # tolerance check once an acceptable fp16 error bound is agreed on.
        print(np.max(np.abs(ref - out)), np.mean(np.abs(ref - out)))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See
tvm/python/tvm/meta_schedule/relay_integration.py
Lines 87 to 91 in 370abe6
relay.FuseOps.link_params
config, others are for compatibility with the existing code.