Labels: needs-triage, type: bug
Description
When applying the `LiftTransformParams()` transformation, the module produced by running it inside a `tvm.transform.Sequential` pipeline (`mod_seq`) is structurally inconsistent with the module produced by applying the pass directly (`mod`). In addition, building the transformed module with `relax.build` crashes.
The error may be related to how `m` is handled: it is a symbolic (dynamic) shape variable whose value must be computed at runtime, and it may not be properly resolved during the transformation and build processes.
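As a rough mental model only (a simplified, hypothetical sketch, not TVM's actual `vm_shape_lower.cc` logic), the failing check requires every symbolic dimension used in an output shape to have been bound from some parameter's shape first. The helper name `unbound_symbolic_dims` and the parameter lists below are illustrative assumptions; they do not describe what `LiftTransformParams` actually emits.

```python
from typing import List, Sequence, Union

# An int is a static extent; a str names a symbolic variable such as "m".
Dim = Union[int, str]


def unbound_symbolic_dims(
    param_shapes: Sequence[Sequence[Dim]], output_shape: Sequence[Dim]
) -> List[str]:
    """Toy model: return symbolic dims in output_shape that no parameter shape binds."""
    # Collect every symbolic name that appears in some parameter's shape.
    bound = {d for shape in param_shapes for d in shape if isinstance(d, str)}
    # Any symbolic output dim not in that set could not be "computed".
    return [d for d in output_shape if isinstance(d, str) and d not in bound]


# The original `main` signature: "m" is bound by the shapes of w1/w2.
print(unbound_symbolic_dims(
    [(1, 16, 224, "n"), (16, "m", 3, 3), (16, "m", 3, 3)], (16, "m", 3, 3)
))
# A hypothetical signature where only x is visible: "m" cannot be computed.
print(unbound_symbolic_dims([(1, 16, 224, "n")], (16, "m", 3, 3)))
```

The first call prints `[]` and the second prints `['m']`, mirroring the shape of the "PrimExpr m is not computed" message without claiming which lifted function triggers it.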
Actual behavior
File "/software/tvm/src/relax/backend/vm/vm_shape_lower.cc", line 463
InternalError: Check failed: (!require_value_computed) is false: PrimExpr m is not computed
Steps to reproduce
import tvm
from tvm import relax
import numpy as np
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R


@I.ir_module
class Module:
    @T.prim_func(private=True)
    def tir_acos(var_x: T.handle, var_compute: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        m = T.int64()
        x = T.match_buffer(var_x, (T.int64(16), m, T.int64(3), T.int64(3)))
        compute = T.match_buffer(var_compute, (T.int64(16), m, T.int64(3), T.int64(3)))
        # with T.block("root"):
        for i0, i1, i2, i3 in T.grid(T.int64(16), m, T.int64(3), T.int64(3)):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(x[v_i0, v_i1, v_i2, v_i3])
                T.writes(compute[v_i0, v_i1, v_i2, v_i3])
                compute[v_i0, v_i1, v_i2, v_i3] = T.acos(x[v_i0, v_i1, v_i2, v_i3])

    @R.function
    def main(x: R.Tensor((1, 16, 224, "n"), dtype="float32"), w1: R.Tensor((16, "m", 3, 3), dtype="float32"), w2: R.Tensor((16, "m", 3, 3), dtype="float32")) -> R.Tensor((16, "m", 3, 3), dtype="float32"):
        m = T.int64()
        n = T.int64()
        # Only the first parameter (x) is a runtime input; w1 and w2 are weights.
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            gv = R.call_tir(cls.tir_acos, (w1,), out_sinfo=R.Tensor((16, m, 3, 3), dtype="float32"))
            R.output(gv)
        return gv


mod = Module
# Apply the pass through a Sequential pipeline...
mod_seq = tvm.transform.Sequential([relax.transform.LiftTransformParams()])(mod)
# ...and apply it directly.
mod = relax.transform.LiftTransformParams()(mod)
ex = relax.build(mod, target="llvm")  # crashes here with the InternalError shown above
tvm.ir.assert_structural_equal(mod_seq, mod)