Skip to content

Commit

Permalink
[TIR] Fix RenewDef for symbolic input shapes (#15163)
Browse files Browse the repository at this point in the history
There are cases where the shapes of input buffers are symbolic, but the
first symbol is a composite PrimExpr rather than a TIR Var, which the
original implementation does not take this into account.

Example:

```python
@T.prim_func
def main(a: T.handle, b: T.handle):
    m = T.int64()
    A = T.match_buffer(a, (m * 2,))  // `m` first appears as composite
    B = T.match_buffer(b, (m, 2))
```
  • Loading branch information
junrushao committed Jun 27, 2023
1 parent 48f295f commit bca7ebf
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
12 changes: 12 additions & 0 deletions src/tir/transforms/renew_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,18 @@ class RenewDefMutator : public StmtExprMutator {
for (const auto& param : func->params) {
params.push_back(generator.ReDefineVar(param));
}
for (const auto& param : func->params) {
if (param->dtype.is_handle()) {
const Buffer& buffer = func->buffer_map.at(param);
for (const PrimExpr& e : buffer->shape) {
if (const auto* v = e.as<VarNode>()) {
if (generator.remap_.count(GetRef<Var>(v)) == 0) {
generator.ReDefineVar(GetRef<Var>(v));
}
}
}
}
}
// Redefine buffers in order
// TODO(Siyuan Feng): checking var is used after define
Map<tir::Var, Buffer> buffer_map;
Expand Down
20 changes: 19 additions & 1 deletion tests/python/unittest/test_tir_renew_defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
# specific language governing permissions and limitations
# under the License.

import pytest
import sys

import pytest
import tvm
import tvm.testing
from tvm.script import tir as T
Expand Down Expand Up @@ -76,6 +76,7 @@ def elementwise(A: T.Buffer((128, 128), "float32")):
assert f1.body.block.body.loop_var != f2.body.block.body.loop_var
# check remap of j
assert f1.body.block.body.body.loop_var != f2.body.block.body.body.loop_var

# check inner block
def _get_block(f):
return f.body.block.body.body.body.block
Expand Down Expand Up @@ -169,5 +170,22 @@ def symbolic_func(a: T.handle, b: T.handle, n: T.int32):
tvm.ir.assert_structural_equal(f1, f2)


def test_buffer_map():
@T.prim_func
def main(a: T.handle, b: T.handle):
m = T.int64()
A = T.match_buffer(a, (m * 2,))
B = T.match_buffer(b, (m, 2))
for i, j in T.grid(m, 2):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi * 2 + vj]

f1 = main
f2 = tvm.tir.stmt_functor.renew_defs(main)
tvm.ir.assert_structural_equal(f1, f2)
assert f1.buffer_map[f1.params[1]].shape[0] != f2.buffer_map[f2.params[1]].shape[0]


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit bca7ebf

Please sign in to comment.