Skip to content

Commit

Permalink
[TIR] Fix RenewDef for symbolic input shapes
Browse files Browse the repository at this point in the history
There are cases where the shapes of input buffers are symbolic, but the
first symbol is a composite PrimExpr rather than a TIR Var, which the
original implementation does not take this into account.

Example:

```python
@T.prim_func
def main(a: T.handle, b: T.handle):
    m = T.int64()
    A = T.match_buffer(a, (m * 2,))  // `m` first appears as composite
    B = T.match_buffer(b, (m, 2))
```
  • Loading branch information
junrushao committed Jun 27, 2023
1 parent 28aead9 commit dc138ac
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
12 changes: 12 additions & 0 deletions src/tir/transforms/renew_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,18 @@ class RenewDefMutator : public StmtExprMutator {
for (const auto& param : func->params) {
params.push_back(generator.ReDefineVar(param));
}
for (const auto& param : func->params) {
if (param->dtype.is_handle()) {
const Buffer& buffer = func->buffer_map.at(param);
for (const PrimExpr& e : buffer->shape) {
if (const auto* v = e.as<VarNode>()) {
if (generator.remap_.count(GetRef<Var>(v)) == 0) {
generator.ReDefineVar(GetRef<Var>(v));
}
}
}
}
}
// Redefine buffers in order
// TODO(Siyuan Feng): checking var is used after define
Map<tir::Var, Buffer> buffer_map;
Expand Down
19 changes: 18 additions & 1 deletion tests/python/unittest/test_tir_renew_defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
# specific language governing permissions and limitations
# under the License.

import pytest
import sys

import pytest
import tvm
import tvm.testing
from tvm.script import tir as T
Expand Down Expand Up @@ -76,6 +76,7 @@ def elementwise(A: T.Buffer((128, 128), "float32")):
assert f1.body.block.body.loop_var != f2.body.block.body.loop_var
# check remap of j
assert f1.body.block.body.body.loop_var != f2.body.block.body.body.loop_var

# check inner block
def _get_block(f):
return f.body.block.body.body.body.block
Expand Down Expand Up @@ -169,5 +170,21 @@ def symbolic_func(a: T.handle, b: T.handle, n: T.int32):
tvm.ir.assert_structural_equal(f1, f2)


def test_buffer_map():
@T.prim_func
def main(a: T.handle, b: T.handle):
m = T.int64()
A = T.match_buffer(a, (m * 2,))
B = T.match_buffer(b, (m, 2))
for i, j in T.grid(m, 2):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi * 2 + vj]

f1 = main
f2 = tvm.tir.stmt_functor.renew_defs(main)
tvm.ir.assert_structural_equal(f1, f2)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit dc138ac

Please sign in to comment.