[TIR] Fix buffer shape and IndexMap indices dtype mismatch#13463
[TIR] Fix buffer shape and IndexMap indices dtype mismatch#13463vinx13 merged 5 commits intoapache:mainfrom
Conversation
|
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot |
Lunderberg
left a comment
There was a problem hiding this comment.
Good catch, and the change looks reasonable. The MapIndices method already uses SubstituteWithDataTypeLegalization, so it makes sense for the transform layout to use it as well.
|
Regarding the readability of the unit test, I did some poking around, and the following PrimFunc triggers the same error on main when the layout is transformed using @T.prim_func
def func(A: T.Buffer[T.int64(58), "int32"]):
for i in T.serial(T.int64(58)):
with T.block("block"):
vi = T.axis.remap("S", [i])
A[vi] = 0 |
|
@Lunderberg Thanks for finding the simple test case! I hit this error when running this Hexagon test . The original test case in this PR, with itscache_read, index map etc, was directly taken from this Hexagon test.
The bug is triggered when we hit the code path I was not sure what makes this code path hit and why existing tests didn't hit it. |
0bc7436 to
aa7f08e
Compare
|
@Lunderberg @vinx13 PTAL, thanks. |
Lunderberg
left a comment
There was a problem hiding this comment.
Thank you for making the changes, and LGTM!
After the PR #13327, I'm getting a dtype-mismatch error at
tvm/src/tir/schedule/primitive/layout_transformation.cc
Line 310 in b6fae9b
dim.dtype()is now int64 whilevirtual_var.dtype()is int32. The dtypes ofinitial_indicesinIndexMapare fixed to int32 (see below), this is in conflict with the above PR which made the dtypes of buffer shapes int64.tvm/src/tir/ir/index_map.cc
Line 51 in 458ca81
tvm/python/tvm/tir/function.py
Line 395 in 78b5322
Since
initial_indicesis used everywhere to constructloop_var/iter_var/iter_valuesofForandBlocketc, I'm adding a dtype legalization at the beginning ofTransformLayout, when the dtypes of the input buffer andinitial_indicesdo not match.@vinx13 @Lunderberg