Skip to content

Commit

Permalink
[TVMScript] Create loop var with min_val dtype in for frame (#15547)
Browse files Browse the repository at this point in the history
adapt tir for loop var dtype
  • Loading branch information
Lucien0 committed Aug 15, 2023
1 parent 1422c18 commit 208e01f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/script/ir_builder/tir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ Array<Var> Remap(String kinds, Array<PrimExpr> bindings, DataType dtype) {
PrimExpr extent = arith::Analyzer().Simplify(stop - start); \
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>(); \
int bits = std::max(min.dtype().bits(), extent.dtype().bits()); \
n->vars = {Var("v", DataType::Int(bits))}; \
n->vars = {Var("v", DataType(min.dtype().code(), bits, 1))}; \
n->doms = {Range::FromMinExtent(min, extent)}; \
n->f_make_for_loop = [annotations](Array<Var> vars, Array<Range> doms, tvm::tir::Stmt body) { \
ICHECK_EQ(vars.size(), 1); \
Expand All @@ -344,7 +344,7 @@ ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread,
PrimExpr extent = arith::Analyzer().Simplify(stop - start);
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
int bits = std::max(min.dtype().bits(), extent.dtype().bits());
n->vars = {Var("v", DataType::Int(bits))};
n->vars = {Var("v", DataType(min.dtype().code(), bits, 1))};
n->doms = {Range::FromMinExtent(min, extent)};
n->f_make_for_loop = [annotations, thread](Array<Var> vars, Array<Range> doms, Stmt body) -> For {
ICHECK_EQ(vars.size(), 1);
Expand Down
20 changes: 20 additions & 0 deletions tests/python/unittest/test_tvmscript_ir_builder_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,26 @@ def test_ir_builder_tir_for():
assert_structural_equal(for_actual, for_expected, map_free_vars=True)


def test_ir_builder_tir_for_uint():
with IRBuilder() as ib:
with T.serial(tir.const(128, "uint32")) as a:
T.evaluate(0)

# the for generated by IRBuilder
for_actual = ib.get()

for_expected = tir.For(
loop_var=tir.Var("", "uint32"),
min_val=tir.const(0, "uint32"),
extent=tir.const(128, "uint32"),
kind=tir.ForKind.SERIAL,
body=tir.Evaluate(0),
)

# Check if the generated ir is expected
assert_structural_equal(for_actual, for_expected, map_free_vars=True)


def test_ir_builder_tir_assert():
with IRBuilder() as ib:
with T.Assert(T.int32() == 0, message="a is 0"):
Expand Down

0 comments on commit 208e01f

Please sign in to comment.