From 971db0e6503206d40400ea0faa9b3646a53f2d3f Mon Sep 17 00:00:00 2001 From: wrongtest Date: Wed, 22 Jun 2022 21:17:19 +0800 Subject: [PATCH] do not bind non-index type value of lets in compact buffer --- src/tir/transforms/compact_buffer_region.cc | 20 +++++++++++++------ ...est_tir_transform_compact_buffer_region.py | 15 ++++++++++++++ 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index e0efec79b052..46f64d4edf09 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -150,18 +150,26 @@ class BufferAccessRegionCollector : public StmtExprVisitor { void VisitStmt_(const LetStmtNode* op) final { StmtExprVisitor::VisitExpr(op->value); - dom_analyzer_.Bind(op->var, op->value); - dom_map_.emplace(op->var.get(), arith::IntSet::SinglePoint(op->value)); + if (arith::IsIndexType(op->value->dtype)) { + dom_analyzer_.Bind(op->var, op->value); + dom_map_.emplace(op->var.get(), arith::IntSet::SinglePoint(op->value)); + } StmtExprVisitor::VisitStmt(op->body); - dom_map_.erase(op->var.get()); + if (arith::IsIndexType(op->value->dtype)) { + dom_map_.erase(op->var.get()); + } } void VisitExpr_(const LetNode* op) final { StmtExprVisitor::VisitExpr(op->value); - dom_analyzer_.Bind(op->var, op->value); - dom_map_.emplace(op->var.get(), arith::IntSet::SinglePoint(op->value)); + if (arith::IsIndexType(op->value->dtype)) { + dom_analyzer_.Bind(op->var, op->value); + dom_map_.emplace(op->var.get(), arith::IntSet::SinglePoint(op->value)); + } StmtExprVisitor::VisitExpr(op->body); - dom_map_.erase(op->var.get()); + if (arith::IsIndexType(op->value->dtype)) { + dom_map_.erase(op->var.get()); + } } void VisitStmt_(const IfThenElseNode* op) final { diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py index 974e59356326..af206ef1862c 100644 --- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py +++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -737,6 +737,21 @@ def func_with_let_binding(): _check(func_with_let_binding, func_with_let_binding) + @T.prim_func + def func_with_non_index_let_binding(): + x1 = T.call_extern("get", dtype="float16") + x2 = T.call_extern("get", dtype="float32") + x3 = T.call_extern("get", dtype="float64") + x4 = T.call_extern("get", dtype="uint8") + x5 = T.call_extern("get", dtype="int32x16") + x6 = T.call_extern("get", dtype="handle") + x7 = T.call_extern("get", dtype="") + A = T.alloc_buffer((64), "float32") + for rk in range(64): + A[rk] = T.call_extern("load_ptr", x1, x2, x3, x4, x5, x6, x7, dtype="float32") + + _check(func_with_non_index_let_binding, func_with_non_index_let_binding) + def test_compact_spatial_tiled_pad_and_pooling(): @T.prim_func