Commit: black
masahi committed Sep 22, 2022
1 parent 14958bb commit 21dbad9
Showing 2 changed files with 28 additions and 22 deletions.
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/default_config.py
@@ -326,7 +326,7 @@ def postprocs() -> List[Postproc]:
             M.RewriteParallelVectorizeUnroll(),
             M.RewriteReductionBlock(),
             # TODO(masahi): Fix RewriteLayout for link-params=True case
-            M.RewriteLayout(),
+            # M.RewriteLayout(),
         ]


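Editor's note: `postprocs()` here supplies meta schedule's default postprocessors, so this hunk disables `RewriteLayout` for every tuning run that does not pass its own list. A user who is not linking parameters can still opt back in by building the list explicitly and handing it to the tuner via `postprocs=lambda: my_postprocs`, as in the `tune_tir` sketch after the second file's diff. A minimal sketch (the `tvm.meta_schedule.postproc` module path is real; the list shown only mirrors the postprocs visible in the hunk above, not the hidden start of the defaults):

```python
# Hedged sketch: opting back into RewriteLayout when link-params=True is not used.
from tvm.meta_schedule import postproc

my_postprocs = [
    postproc.RewriteParallelVectorizeUnroll(),
    postproc.RewriteReductionBlock(),
    postproc.RewriteLayout(),  # the TODO above only bites when link-params=True
]
```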
48 changes: 27 additions & 21 deletions tests/python/contrib/test_hexagon/test_meta_schedule.py
@@ -215,15 +215,20 @@ def schedule_dense_for_tune(sch):
     verify_dense(sch, target, M, N, K, session)


 # This is an example of a schedule found by vrmpy auto tensorization.
 # It gets 440 GFLOPS on SD888.
 @tvm.script.ir_module
 class Module_vrmpy_auto_tensorize:
     @T.prim_func
-    def main(X: T.Buffer[(128, 768), "uint8"], packedW: T.Buffer[(24, 192, 32, 4), "uint8"], compute: T.Buffer[(128, 768), "int32"]) -> None:
-        # function attr dict
+    def main(
+        X: T.Buffer[(128, 768), "uint8"],
+        packedW: T.Buffer[(24, 192, 32, 4), "uint8"],
+        compute: T.Buffer[(128, 768), "int32"],
+    ) -> None:
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
         # body
         # with T.block("root")
-        for i0_0_i1_0_0_fused in T.parallel(512, annotations={"pragma_auto_unroll_max_step":64, "pragma_unroll_explicit":1}):
+        for i0_0_i1_0_0_fused in T.parallel(
+            512, annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}
+        ):
             for i0_1_init, i1_0_1_init, i0_2_init, i1_0_2_init in T.grid(2, 3, 1, 1):
                 with T.block("compute_o_init"):
                     i = T.axis.spatial(128, i0_0_i1_0_0_fused // 8 * 2 + i0_1_init + i0_2_init)
@@ -241,15 +246,27 @@ def main(X: T.Buffer[(128, 768), "uint8"], packedW: T.Buffer[(24, 192, 32, 4), "uint8"], compute: T.Buffer[(128, 768), "int32"]) -> None:
                     i = T.axis.spatial(128, i0_0_i1_0_0_fused // 8 * 2 + i0_1 + i0_2)
                     j_o = T.axis.spatial(24, i1_0_2 + i0_0_i1_0_0_fused % 8 * 3 + i1_0_1)
                     k_o = T.axis.reduce(192, i2_0_0 * 6 + i2_0_1)
-                    T.reads(compute[i, j_o * 32 : j_o * 32 + 32], X[i, k_o * 4 : k_o * 4 + 4], packedW[j_o, k_o, 0 : 32, 0 : 4])
+                    T.reads(
+                        compute[i, j_o * 32 : j_o * 32 + 32],
+                        X[i, k_o * 4 : k_o * 4 + 4],
+                        packedW[j_o, k_o, 0:32, 0:4],
+                    )
                     T.writes(compute[i, j_o * 32 : j_o * 32 + 32])
-                    A = T.match_buffer(X[i, k_o * 4 : k_o * 4 + 4], [4], dtype="uint8", offset_factor=1)
-                    B = T.match_buffer(packedW[j_o, k_o, 0 : 32, 0 : 4], [32, 4], dtype="uint8", offset_factor=1)
-                    C = T.match_buffer(compute[i, j_o * 32 : j_o * 32 + 32], [32], dtype="int32", offset_factor=1)
+                    A = T.match_buffer(
+                        X[i, k_o * 4 : k_o * 4 + 4], [4], dtype="uint8", offset_factor=1
+                    )
+                    B = T.match_buffer(
+                        packedW[j_o, k_o, 0:32, 0:4], [32, 4], dtype="uint8", offset_factor=1
+                    )
+                    C = T.match_buffer(
+                        compute[i, j_o * 32 : j_o * 32 + 32], [32], dtype="int32", offset_factor=1
+                    )
                     A_u8x4: T.uint8x4 = A[0:4]
                     A_i32: T.int32 = T.reinterpret(A_u8x4, dtype="int32")
                     B_i32x32: T.int32x32 = T.reinterpret(B[0, 0:128], dtype="int32x32")
-                    C[0:32] = T.call_llvm_pure_intrin(4390, T.uint32(3), C[0:32], B_i32x32, A_i32, dtype="int32x32")
+                    C[0:32] = T.call_llvm_pure_intrin(
+                        4390, T.uint32(3), C[0:32], B_i32x32, A_i32, dtype="int32x32"
+                    )


 @tvm.testing.requires_hexagon
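An aside on the tensorized update block above: `T.call_llvm_pure_intrin(4390, ...)` invokes the Hexagon HVX vrmpy multiply-accumulate by its numeric LLVM intrinsic id (4390 is whatever this particular TVM/LLVM build resolved it to; treat it as opaque). Given the buffer slices in the block, one call folds four uint8 reduction steps into 32 int32 lanes at once. A plain NumPy sketch of the same update, for reference only:

```python
# Reference semantics of one "compute_o_update" vrmpy call, in NumPy.
# A is X[i, k_o*4 : k_o*4+4], B is packedW[j_o, k_o, 0:32, 0:4],
# C is compute[i, j_o*32 : j_o*32+32]; shapes (4,), (32, 4), (32,).
import numpy as np


def vrmpy_update_ref(C: np.ndarray, A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """C[j] += sum_k A[k] * B[j, k], widened from uint8 to int32."""
    return C + B.astype(np.int32) @ A.astype(np.int32)
```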
@@ -264,16 +281,6 @@ def test_vrmpy_dense_auto_tensorize(hexagon_launcher):
     workload = te.create_prim_func(dense(M, N, K))

     sch_rules = [
-        schedule_rule.AutoInline(
-            into_producer=False,
-            into_consumer=True,
-            inline_const_tensor=True,
-            disallow_if_then_else=True,
-            require_injective=True,
-            require_ordered=True,
-            disallow_op=["tir.exp"],
-        ),
-        schedule_rule.AddRFactor(max_jobs_per_core=16, max_innermost_factor=64),
         schedule_rule.MultiLevelTilingWithIntrin(
             VRMPY_u8u8i32_INTRIN,
             structure="SRSRS",
@@ -296,7 +303,6 @@ def test_vrmpy_dense_auto_tensorize(hexagon_launcher):
     ]

     postprocs = [
-        postproc.DisallowDynamicLoop(),
         postproc.RewriteParallelVectorizeUnroll(),
         postproc.RewriteReductionBlock(),
         postproc.RewriteTensorize(vectorize_init_loop=True),
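For orientation, the `sch_rules` and `postprocs` lists above are what this test hands to the tuner in place of the defaults from `default_config.py`. A sketch of the invocation pattern: the keyword-lambda style matches this era of the meta schedule API, but treat the exact `TuneConfig` fields and `tune_tir` signature as assumptions rather than the test's literal code.

```python
# Hedged sketch of how the test consumes the lists above; `workload`,
# `sch_rules`, and `postprocs` come from the diff, the rest is illustrative.
import tempfile

from tvm import meta_schedule as ms

with tempfile.TemporaryDirectory() as work_dir:
    config = ms.TuneConfig(
        strategy="replay_trace",  # assumed: random replay, not evolutionary search
        num_trials_per_iter=8,
        max_trials_per_task=8,
        max_trials_global=8,
    )
    sch = ms.tune_tir(
        mod=workload,  # te.create_prim_func(dense(M, N, K)) from the diff
        target=target,  # the Hexagon target constructed elsewhere in the test
        config=config,
        work_dir=work_dir,
        sch_rules=lambda: sch_rules,
        postprocs=lambda: postprocs,
    )
```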
