diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index f7face272de5..59548634fc4a 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -522,6 +522,11 @@ def _as_range(dom: Union[ir.Range, List[PrimExpr]]) -> ir.Range: if isinstance(dom, ir.Range): return dom if isinstance(dom, (list, tuple)): + from tvm.arith import Analyzer # pylint: disable=import-outside-toplevel + + extent = Analyzer().simplify(dom[1] - dom[0]) + if isinstance(extent, tir.IntImm): + return ir.Range.from_min_extent(dom[0], extent) return ir.Range(dom[0], dom[1]) if hasattr(dom, "dtype"): return ir.Range(IntImm(dom.dtype, 0), dom) diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 31815fc71060..2709bd2f94ca 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -33,6 +34,7 @@ #include #include +#include "../../support/array.h" #include "../../tir/ir/functor_common.h" #include "../../tir/transforms/ir_utils.h" #include "../schedule/graph.h" @@ -180,30 +182,97 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator { "workload"}; }; -BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op, - const Array& tensors, Array bindings, - PrimExpr expr_body, CreateFuncInfo* info, - arith::Analyzer* analyzer) { - // Step 1. Push_back data_par axis and reduce_axis into block_vars. - Array iter_vars; - std::unordered_map var_map; - iter_vars.reserve(compute_op->axis.size() + compute_op->reduce_axis.size()); - auto f_push_block_vars = [&iter_vars, &var_map, &analyzer](const Array& iters) { - for (IterVar iter_var : iters) { - // Create new var - Var new_var("v_" + iter_var->var->name_hint, iter_var->var->dtype); - var_map[iter_var->var.get()] = new_var; - - PrimExpr dom_min = analyzer->Simplify(iter_var->dom->min); - PrimExpr dom_extent = analyzer->Simplify(iter_var->dom->extent); - iter_vars.push_back(IterVar(Range::FromMinExtent(dom_min, dom_extent), new_var, - iter_var->iter_type, iter_var->thread_tag, iter_var->span)); +/**! + * \brief The iter levels specify nested structure wrt iteration domain dependencies. + * (1) Each iter should reside in exactly one level. + * (2) The domain of low level iter should be either free or ony depend on iters in high level. + **/ +using NestedIterLevels = std::vector>; + +NestedIterLevels GenerateNestedIterLevels(const Array& axes, arith::Analyzer* analyzer) { + int global_max_depth = 0; + std::unordered_map depth; + std::unordered_map var2iter; + for (const auto& axis : axes) { + var2iter[axis->var] = axis; + } + + std::function traverse = [&](const IterVar& axis) -> int { + auto depth_it = depth.find(axis->var); + if (depth_it != depth.end()) { // cache + return depth_it->second; + } + std::vector dep_vars; + for (const Var& v : UndefinedVars(analyzer->Simplify(axis->dom->min))) { + dep_vars.push_back(v); } + for (const Var& v : UndefinedVars(analyzer->Simplify(axis->dom->extent))) { + dep_vars.push_back(v); + } + int cur_depth = 0; + for (const Var& v : dep_vars) { + auto it = var2iter.find(v); + if (it == var2iter.end()) { + // not axis var dependency, maybe a symbolic shape var or others. 
+ continue; + } + int depth = traverse(it->second); + cur_depth = std::max(cur_depth, depth + 1); + } + depth.emplace_hint(depth_it, axis->var, cur_depth); + global_max_depth = std::max(global_max_depth, cur_depth); + return cur_depth; }; - f_push_block_vars(compute_op->axis); - f_push_block_vars(compute_op->reduce_axis); - // Step 2. + for (const auto& axis : axes) { + traverse(axis); + } + NestedIterLevels levels; + levels.resize(global_max_depth + 1); + for (const auto& axis : axes) { + const Var& var = axis->var; + levels[depth[var]].push_back(axis); + } + return levels; +} + +/*! + * \brief Generate output buffers from compute op's output tensors, and bind to context func info. + * \param compute_op The target compute op. + * \param info Generation context info. + * \returns The output buffer objects, ordered by compute op's outputs. + **/ +Array GenerateOutputBuffers(const te::ComputeOp& compute_op, CreateFuncInfo* info) { + // Step 1. Collect output tensors in TE operation. + Array tensors; + if (compute_op->body[0]->IsInstance()) { + auto f_reducer_equal = [](const ReduceNode* a, const ReduceNode* b) -> bool { + StructuralEqual eq; + return eq(a->combiner, b->combiner) && // + eq(a->source, b->source) && // + eq(a->axis, b->axis) && // + eq(a->condition, b->condition) && // + eq(a->init, b->init); + }; + PrimExpr expr_body = compute_op->body[0]; + tensors.push_back(compute_op.output(0)); + const tir::ReduceNode* reduce = expr_body.as(); + // specially handle reduction inline for multiplre reductions. + for (size_t k = 1; k < compute_op->body.size(); ++k) { + const tir::ReduceNode* reduce_ = compute_op->body[k].as(); + ICHECK(reduce_); + ICHECK(f_reducer_equal(reduce_, reduce)) + << "The Reduce inputs of ComputeOp should have the same attribute except value_index, " + << "but the first argument has body " << GetRef(reduce_) << ", while the " << k + << "-th argument has body " << GetRef(reduce); + tensors.push_back(compute_op.output(k)); + } + } else { + for (size_t k = 0; k < compute_op->body.size(); ++k) { + tensors.push_back(compute_op.output(k)); + } + } + // Step 2. Prepare buffers for compute outputs // - Declare buffers // - Update `op2buffers` // - Add the non-argument tensors to `alloc_buffer` of the root block @@ -212,32 +281,94 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op, Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint(), "global"); info->tensor2buffers[tensor] = buffer; buffers.push_back(buffer); - if (!info->IsArg(tensor)) { info->root_alloc.push_back(info->tensor2buffers[tensor]); } } + return buffers; +} - // Step 3. Calculate indices for BufferStore - Array indices; - indices.reserve(compute_op->axis.size()); - for (const IterVar& iter_var : compute_op->axis) { - auto it = var_map.find(iter_var->var.get()); - ICHECK(it != var_map.end()); - indices.push_back(it->second); +/*! + * \brief Generate block annotation dict from compute op attrs. + * \param compute_op The target compute op. + * \param info Generation context info. + * \returns The block annotation dict. 
+ **/ +Map GenerateBlockAnnotations(const te::ComputeOp& compute_op, + CreateFuncInfo* info) { + Map annotations; + auto mutate_attr = [&info](const ObjectRef& value) -> ObjectRef { + if (auto tensor_value = value.as()) { + return info->tensor2buffers.at(tensor_value.value()); + } else { + return value; + } + }; + for (const auto& pair : compute_op->attrs) { + const String& key = pair.first; + const ObjectRef& value = pair.second; + // TensorIR will not allow Tensor data structure + if (value->IsInstance()) { + const auto array_value = Downcast>(value); + annotations.Set(key, array_value.Map(mutate_attr)); + } else { + annotations.Set(key, mutate_attr(value)); + } } + // Set script_parsing_detect_access + annotations.Set(tir::attr::script_parsing_detect_access, IntImm(DataType::Int(32), 3)); + return annotations; +} - // Step 4. Create block body. +/*! + * \brief Generate init stmt for reduction. + * \param indices Target store indices for the block. + * \param buffers Target store buffers for the block. + * \param reduce Reduce description node. + * \param var_map Var re-mapping for TE compute axes. + * \param info Generation context info. + * \returns Init stmt. + **/ +Stmt GenerateInitStmt(const Array& indices, const Array& buffers, + const ReduceNode* reduce, const Map& var_map, + CreateFuncInfo* info) { // helper to transform the expr and remap iters to the block domain auto f_transform_and_remap = [&](const PrimExpr& e) { return Substitute(info->transformer(e), var_map); }; - String block_name{nullptr}; Optional init = NullOpt; Stmt body; + int n_buffers = buffers.size(); + Array init_stmts; + init_stmts.reserve(n_buffers); + for (int i = 0; i < n_buffers; ++i) { + const Buffer& buffer = buffers[i]; + PrimExpr identity = f_transform_and_remap(reduce->combiner->identity_element[i]); + init_stmts.push_back(BufferStore(buffer, identity, indices)); + } + return SeqStmt::Flatten(init_stmts); +} + +/*! + * \brief Generate body execution stmt. + * \param indices Target store indices for the block. + * \param buffers Target store buffers for the block. + * \param var_map Var re-mapping for TE compute axes. + * \param expr_body Target computation expression. + * \param info Generation context info. + * \param analyzer Arithmetic analyzer in context. + * \returns Init stmt. + **/ +Stmt GenerateBodyStmt(const Array& indices, const Array& buffers, + const Map& var_map, PrimExpr expr_body, CreateFuncInfo* info, + arith::Analyzer* analyzer) { + // helper to transform the expr and remap iters to the block domain + auto f_transform_and_remap = [&](const PrimExpr& e) { + return Substitute(info->transformer(e), var_map); + }; + Stmt body; if (const auto* reduce = expr_body.as()) { // Case 1. Reduce compute - block_name = info->FreshName(compute_op->name); int n_buffers = buffers.size(); Array lhs; @@ -258,10 +389,8 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op, Array temp_vars; Array body_stmts; - Array init_stmts; temp_vars.reserve(n_buffers); body_stmts.reserve(n_buffers); - init_stmts.reserve(n_buffers); // - When there is only one buffer, we directly create a BufferStore which stores "combiner(lhs, // rhs)" into the target buffer position. @@ -270,8 +399,6 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op, // then store the value of the variables into the target buffer positions. 
for (int i = 0; i < n_buffers; ++i) { const Buffer& buffer = buffers[i]; - PrimExpr identity = f_transform_and_remap(reduce->combiner->identity_element[i]); - init_stmts.push_back(BufferStore(buffer, identity, indices)); PrimExpr value{nullptr}; if (n_buffers > 1) { temp_vars.push_back(Var("v_" + buffer->name, PrimType(lhs[i].dtype()))); @@ -282,8 +409,6 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op, } body_stmts.push_back(BufferStore(buffer, value, indices)); } - - init = SeqStmt::Flatten(init_stmts); body = SeqStmt::Flatten(body_stmts); if (n_buffers > 1) { // When there are multiple buffers, we wrap the body with LetStmts. @@ -294,116 +419,198 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op, } } else { // Case 2. Data parallel compute - ICHECK_EQ(tensors.size(), 1); - block_name = info->FreshName(tensors[0]->GetNameHint()); + ICHECK_EQ(buffers.size(), 1); const PrimExpr& compute_body = f_transform_and_remap(expr_body); - body = BufferStore(info->tensor2buffers[tensors[0]], analyzer->Simplify(compute_body), indices); + body = BufferStore(buffers[0], analyzer->Simplify(compute_body), indices); } + return std::move(body); +} - // Step 5. Add script_parsing_detect_access attr for auto complete the whole IR. - Map annotations; - auto mutate_attr = [&info](const ObjectRef& value) -> ObjectRef { - if (auto tensor_value = value.as()) { - return info->tensor2buffers.at(tensor_value.value()); - } else { - return value; +/*! \brief Record loops, block vars and binding in the single level scope. */ +struct NestedScopeInfo { + // loop var and range in the scope. + std::vector> loop_vars; + // block iters for current level's block. + Array block_iters; + // block bindings for current level's block. + Array bindings; + // store indices for current level's block. + Array store_indices; + // mapping from original TE compute axes to new block vars. + Map axes_remap; + + // helper to add new block var + void AddBlockIter(const Optional& origin_axis, const IterVar& iter, + const PrimExpr& value) { + block_iters.push_back(iter); + bindings.push_back(value); + if (origin_axis.defined()) { + if (iter->iter_type != IterVarType::kCommReduce) { + store_indices.push_back(iter->var); + } + axes_remap.Set(origin_axis.value()->var, iter->var); } - }; + } - for (const auto& pair : compute_op->attrs) { - const String& key = pair.first; - const ObjectRef& value = pair.second; - // TensorIR will not allow Tensor data structure - if (value->IsInstance()) { - const auto array_value = Downcast>(value); - annotations.Set(key, array_value.Map(mutate_attr)); - } else { - annotations.Set(key, mutate_attr(value)); + // helper to renew leaf block var defs to ensure SSA. + void Renew(const Array& origin_axes) { + block_iters.MutateByApply([](const IterVar& itervar) { + auto n = make_object(*itervar.get()); + n->var = n->var.copy_with_suffix(""); + return IterVar(n); + }); + for (size_t i = 0; i < origin_axes.size(); ++i) { + Var block_var = block_iters[i]->var; + if (origin_axes[i]->iter_type != IterVarType::kCommReduce) { + store_indices.Set(i, block_var); + } + axes_remap.Set(origin_axes[i]->var, block_var); } } - // Set script_parsing_detect_access - annotations.Set(tir::attr::script_parsing_detect_access, IntImm(DataType::Int(32), 3)); - if (iter_vars.empty()) { - IterVar iter(Range::FromMinExtent(0, 1), Var("vi", DataType::Int(32)), IterVarType::kDataPar); - PrimExpr binding(0); - iter_vars.push_back(iter); - bindings.push_back(binding); - } - - // Step 6. 
Create Block and BlockRealize. - return BlockRealize(/*iter_values=*/std::move(bindings), - /*predicate=*/Bool(true), - /*block=*/ - Block(/*iter_vars=*/std::move(iter_vars), - /*reads=*/{}, - /*writes=*/{}, - /*name_hint=*/block_name, - /*body=*/std::move(body), - /*init=*/std::move(init), - /*alloc_buffers=*/{}, - /*match_buffers=*/{}, - /*annotations=*/std::move(annotations))); -} +}; Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* info, arith::Analyzer* analyzer) { - // Step 1. Creating loop vars for block bindings. + // Step 1. Collect all iter axes in original TE compute op Array axes = compute_op->axis; axes.insert(axes.end(), compute_op->reduce_axis.begin(), compute_op->reduce_axis.end()); - Array bindings = axes.Map([&](IterVar iter_var) -> PrimExpr { - int bits = std::max(iter_var->dom->min.dtype().bits(), iter_var->dom->extent.dtype().bits()); - return Var(iter_var->var->name_hint, runtime::DataType::Int(bits)); - }); + // Step 2. Prepare nested iteration scopes. + // For each axis, we generate loop and the first block binding at the level it belongs to. + // In lower levels, we just create new block var and bind it to the previous level block var. + auto axes_levels = GenerateNestedIterLevels(axes, analyzer); + ICHECK(!axes_levels.empty()); + std::vector scopes; + scopes.reserve(axes_levels.size()); + std::unordered_set defined_axes; + for (size_t i = 0; i < axes_levels.size(); ++i) { + NestedScopeInfo cur_scope; + for (size_t j = 0; j < axes.size(); ++j) { + const IterVar& axis = axes[j]; + DataType index_type = + DataType::Int(std::max(axis->dom->min.dtype().bits(), axis->dom->extent.dtype().bits())); + bool first_times_define = + std::find(axes_levels[i].begin(), axes_levels[i].end(), axis) != axes_levels[i].end(); + if (first_times_define) { + Var loop_var = Var(axis->var->name_hint, index_type); + Var block_var("v_" + axis->var->name_hint, index_type); + PrimExpr min = axis->dom->min; + PrimExpr extent = axis->dom->extent; + if (i > 0) { + const auto& scope_repl = scopes[i - 1].axes_remap; + min = Substitute(min, scope_repl); + extent = Substitute(extent, scope_repl); + } + Range dom = Range::FromMinExtent(analyzer->Simplify(min), analyzer->Simplify(extent)); + IterVar new_block_iter(dom, block_var, axis->iter_type, axis->thread_tag, axis->span); + cur_scope.loop_vars.emplace_back(loop_var, dom); + cur_scope.AddBlockIter(axis, new_block_iter, loop_var); + defined_axes.insert(axis->var); + } else if (defined_axes.count(axis->var)) { + ICHECK_GT(i, 0); + ICHECK(scopes[i - 1].axes_remap.count(axis->var)); + PrimExpr prev_binding = scopes[i - 1].axes_remap.at(axis->var); + Var block_var("v_" + axis->var->name_hint, index_type); + Range dom = Range::FromMinExtent(prev_binding, make_const(index_type, 1)); + IterVar new_block_iter(dom, block_var, axis->iter_type, axis->thread_tag, axis->span); + cur_scope.AddBlockIter(axis, new_block_iter, prev_binding); + } + } + if (i == axes_levels.size() - 1 && cur_scope.block_iters.empty()) { + // for the leaf scope, we ensure at least one block var exists + IterVar dummy(Range::FromMinExtent(0, 1), Var("vi", DataType::Int(32)), + IterVarType::kDataPar); + cur_scope.AddBlockIter(NullOpt, dummy, 0); + } + scopes.push_back(cur_scope); + } - // Step 2. Generate block bodies. 
- Array seq_stmt; - if (compute_op->body[0]->IsInstance()) { - auto f_reducer_equal = [](const ReduceNode* a, const ReduceNode* b) -> bool { - StructuralEqual eq; - return eq(a->combiner, b->combiner) && // - eq(a->source, b->source) && // - eq(a->axis, b->axis) && // - eq(a->condition, b->condition) && // - eq(a->init, b->init); - }; + // Step 3. Generate output buffers for each output tensor + Array buffers = GenerateOutputBuffers(compute_op, info); + // Step 4. Generate leaf block stmts. + Array seq_stmt; + auto leaf = scopes.back(); + Map annotations = GenerateBlockAnnotations(compute_op, info); + const ReduceNode* reduce = compute_op->body[0].as(); + if (reduce) { PrimExpr expr_body = compute_op->body[0]; - Array tensors = {compute_op.output(0)}; - const tir::ReduceNode* reduce = expr_body.as(); - // specially handle reduction inline for multiplre reductions. - for (size_t k = 1; k < compute_op->body.size(); ++k) { - const tir::ReduceNode* reduce_ = compute_op->body[k].as(); - ICHECK(reduce_); - ICHECK(f_reducer_equal(reduce_, reduce)) - << "The Reduce inputs of ComputeOp should have the same attribute except value_index, " - << "but the first argument has body " << GetRef(reduce_) << ", while the " << k - << "-th argument has body " << GetRef(reduce); - tensors.push_back(compute_op.output(k)); - } + Stmt init = GenerateInitStmt(leaf.store_indices, buffers, reduce, leaf.axes_remap, info); + Stmt body = + GenerateBodyStmt(leaf.store_indices, buffers, leaf.axes_remap, expr_body, info, analyzer); + seq_stmt.push_back(BlockRealize(/*iter_values=*/leaf.bindings, + /*predicate=*/Bool(true), + /*block=*/ + Block(/*iter_vars=*/leaf.block_iters, + /*reads=*/{}, + /*writes=*/{}, + /*name_hint=*/info->FreshName(compute_op->name), + /*body=*/body, + /*init=*/init, + /*alloc_buffers=*/{}, + /*match_buffers=*/{}, + /*annotations=*/annotations))); - seq_stmt.push_back(GenerateBlockFromTensors(compute_op, tensors, bindings, std::move(expr_body), - info, analyzer)); } else { for (int i = 0; i < compute_op->num_outputs(); ++i) { - const te::Tensor& tensor = compute_op.output(i); + if (i > 0) { + // Renew block var defs to ensure SSA + leaf.Renew(axes); + } PrimExpr expr_body = compute_op->body[i]; - seq_stmt.push_back(GenerateBlockFromTensors(compute_op, {tensor}, bindings, - std::move(expr_body), info, analyzer)); + Stmt body = GenerateBodyStmt(leaf.store_indices, {buffers[i]}, leaf.axes_remap, expr_body, + info, analyzer); + seq_stmt.push_back(BlockRealize(/*iter_values=*/leaf.bindings, + /*predicate=*/Bool(true), + /*block=*/ + Block(/*iter_vars=*/leaf.block_iters, + /*reads=*/{}, + /*writes=*/{}, + /*name_hint=*/info->FreshName(buffers[i]->name), + /*body=*/body, + /*init=*/NullOpt, + /*alloc_buffers=*/{}, + /*match_buffers=*/{}, + /*annotations=*/annotations))); } } - Stmt body = SeqStmt::Flatten(seq_stmt); - // Step 3. Generate loop nesting. - for (size_t i = axes.size(); i > 0; --i) { - const IterVar& axis = axes[i - 1]; - PrimExpr dom_min = analyzer->Simplify(axis->dom->min); - PrimExpr dom_extent = analyzer->Simplify(axis->dom->extent); - const Var& loop_var = Downcast(bindings[i - 1]); - body = For(loop_var, dom_min, dom_extent, ForKind::kSerial, body); - } + // Step 4. Generate nested parent scopes. 
+ for (size_t i = scopes.size(); i > 0; --i) { + const auto& cur = scopes[i - 1]; + if (i < scopes.size()) { + auto block_name = info->FreshName(compute_op->name + "_l" + std::to_string(i)); + const auto& block_iters = cur.block_iters; + + Optional init{NullOpt}; + if (reduce && std::any_of(block_iters.begin(), block_iters.end(), [](const IterVar& iter) { + return iter->iter_type == IterVarType::kCommReduce; + })) { + // if the reduce axis defined in non-leaf scopes, the nested block is also + // a reduction block, thus we should also insert init stmt in the parent level. + init = GenerateInitStmt(cur.store_indices, buffers, reduce, cur.axes_remap, info); + } + // wrap nested block + body = BlockRealize(/*iter_values=*/cur.bindings, + /*predicate=*/Bool(true), + /*block=*/ + Block(/*iter_vars=*/block_iters, + /*reads=*/{}, + /*writes=*/{}, + /*name_hint=*/block_name, + /*body=*/body, + /*init=*/init, + /*alloc_buffers=*/{}, + /*match_buffers=*/{}, + /*annotations=*/annotations)); + } + for (size_t j = cur.loop_vars.size(); j > 0; --j) { + const auto& [loop_var, dom] = cur.loop_vars[j - 1]; + body = For(loop_var, dom->min, dom->extent, ForKind::kSerial, body); + } + } return body; } diff --git a/tests/python/te/test_te_create_primfunc.py b/tests/python/te/test_te_create_primfunc.py index 1a7e03188a25..0fb64e8d0f32 100644 --- a/tests/python/te/test_te_create_primfunc.py +++ b/tests/python/te/test_te_create_primfunc.py @@ -45,8 +45,12 @@ def test_unique_name_reduction_block(): assert isinstance(s.get_sref(s.get_block("sum_1")), tir.schedule.StmtSRef) -def _check_workload(te_workload, tir_workload, index_dtype_override=None): +def _check_workload(te_workload, tir_workload, index_dtype_override=None, do_simplify=False): func = te.create_prim_func(te_workload(), index_dtype_override) + if do_simplify: + simplify = tir.transform.Simplify() + func = simplify(tvm.IRModule.from_expr(func))["main"] + tir_workload = simplify(tvm.IRModule.from_expr(tir_workload))["main"] tvm.ir.assert_structural_equal(func, tir_workload) # make sure that we can create schedule from the func s = tir.Schedule(func, debug_mask="all") @@ -887,5 +891,102 @@ def te_workload(): _check_workload(te_workload, tir_workload) +def test_adaptive_pooling_window(): + @T.prim_func + def tir_workload( + x: T.Buffer((1, 1024, 16, 40), "float32"), + adaptive_pool_avg: T.Buffer((1, 1024, 12, 30), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True), "global_symbol": "main"}) + # fmt: off + adaptive_pool_sum = T.alloc_buffer((1, 1024, 12, 30)) + for ax0, ax1, ax2, ax3 in T.grid(1, 1024, 12, 30): + with T.block("adaptive_pool_sum_1"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(x[v_ax0, v_ax1, v_ax2 * 16 // 12:v_ax2 * 16 // 12 + ((v_ax2 % 3 * 4 + 16) // 12 + 1), v_ax3 * 40 // 30:v_ax3 * 40 // 30 + ((v_ax3 % 3 * 10 + 40) // 30 + 1)]) + T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3]) + for rv0, rv1 in T.grid(T.Select((v_ax2 * 4 + 4) % 12 == 0, (v_ax2 * 16 + 16) // 12, (v_ax2 * 16 + 16) // 12 + 1) - v_ax2 * 16 // 12, T.Select((v_ax3 * 10 + 10) % 30 == 0, (v_ax3 * 40 + 40) // 30, (v_ax3 * 40 + 40) // 30 + 1) - v_ax3 * 40 // 30): + with T.block("adaptive_pool_sum"): + v_ax0_1 = T.axis.spatial((v_ax0, v_ax0 + 1), v_ax0) + v_ax1_1 = T.axis.spatial((v_ax1, v_ax1 + 1), v_ax1) + v_ax2_1 = T.axis.spatial((v_ax2, v_ax2 + 1), v_ax2) + v_ax3_1 = T.axis.spatial((v_ax3, v_ax3 + 1), v_ax3) + v_rv0, v_rv1 = T.axis.remap("RR", [rv0, rv1]) + T.reads(x[v_ax0_1, v_ax1_1, v_ax2_1 * 16 // 12 + v_rv0, v_ax3_1 * 
40 // 30 + v_rv1]) + T.writes(adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1]) + with T.init(): + adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] = T.float32(0.0) + adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] = adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] + x[v_ax0_1, v_ax1_1, v_ax2_1 * 16 // 12 + v_rv0, v_ax3_1 * 40 // 30 + v_rv1] + for ax0, ax1, ax2, ax3 in T.grid(1, 1024, 12, 30): + with T.block("adaptive_pool_avg"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3]) + T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"}) + adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float32", T.Select((v_ax2 * 4 + 4) % 12 == 0, (v_ax2 * 16 + 16) // 12, (v_ax2 * 16 + 16) // 12 + 1) - v_ax2 * 16 // 12) * T.Cast("float32", T.Select((v_ax3 * 10 + 10) % 30 == 0, (v_ax3 * 40 + 40) // 30, (v_ax3 * 40 + 40) // 30 + 1) - v_ax3 * 40 // 30)) + # fmt: on + + def te_workload(): + x = te.placeholder([1, 1024, 16, 40], "float32", "x") + y = topi.nn.adaptive_pool(x, [12, 30], pool_type="avg") + f = te.create_prim_func([x, y]) + return [x, y] + + _check_workload(te_workload, tir_workload) + + +def test_nested_reduce_domain_dependency(): + @T.prim_func + def tir_workload( + x: T.Buffer((8, 8, 8, 8, 8), "float32"), compute: T.Buffer((8, 8, 8), "float32") + ): + T.func_attr({"tir.noalias": T.bool(True), "global_symbol": "main"}) + for i0, i1, i2 in T.grid(8, 8, 8): + with T.block("compute_2"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(x[v_i0, v_i1, v_i2, 0:v_i1, 0 : v_i1 - 1]) + T.writes(compute[v_i0, v_i1, v_i2]) + for rv in range(v_i1): + with T.block("compute_1"): + v_i0_1 = T.axis.spatial((v_i0, v_i0 + 1), v_i0) + v_i1_1 = T.axis.spatial((v_i1, v_i1 + 1), v_i1) + v_i2_1 = T.axis.spatial((v_i2, v_i2 + 1), v_i2) + v_rv = T.axis.reduce(v_i1, rv) + T.reads(x[v_i0_1, v_i1_1, v_i2_1, v_rv, 0:v_rv]) + T.writes(compute[v_i0_1, v_i1_1, v_i2_1]) + with T.init(): + compute[v_i0_1, v_i1_1, v_i2_1] = T.float32(0.0) + for rv_1 in range(v_rv): + with T.block("compute"): + v_i0_2 = T.axis.spatial((v_i0_1, v_i0_1 + 1), v_i0_1) + v_i1_2 = T.axis.spatial((v_i1_1, v_i1_1 + 1), v_i1_1) + v_i2_2 = T.axis.spatial((v_i2_1, v_i2_1 + 1), v_i2_1) + v_rv_1 = T.axis.reduce((v_rv, v_rv + 1), v_rv) + v_rv_2 = T.axis.reduce(v_rv, rv_1) + T.reads(x[v_i0_2, v_i1_2, v_i2_2, v_rv_1, v_rv_2]) + T.writes(compute[v_i0_2, v_i1_2, v_i2_2]) + with T.init(): + compute[v_i0_2, v_i1_2, v_i2_2] = T.float32(0.0) + compute[v_i0_2, v_i1_2, v_i2_2] = ( + compute[v_i0_2, v_i1_2, v_i2_2] + + x[v_i0_2, v_i1_2, v_i2_2, v_rv_1, v_rv_2] + ) + + def te_workload(): + x = te.placeholder([8, 8, 8, 8, 8], "float32", "x") + + def fcompute(*axes): + r1 = te.reduce_axis(tvm.ir.Range.from_min_extent(0, axes[1])) + r2 = te.reduce_axis(tvm.ir.Range.from_min_extent(0, r1)) + all_axes = [*axes, r1, r2] + return te.sum(x(*all_axes), [r1, r2]) + + y = te.compute([8, 8, 8], fcompute) + f = te.create_prim_func([x, y]) + return [x, y] + + _check_workload(te_workload, tir_workload) + + if __name__ == "__main__": tvm.testing.main()
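
The core of the new lowering is GenerateNestedIterLevels: each compute axis is assigned a depth equal to one plus the maximum depth of the axes its domain (min/extent) refers to, and axes of equal depth are emitted together as one loop/block nesting level. The toy function below is not TVM API, only an illustrative sketch with a made-up Axis record; it reproduces that depth computation on the shape of the test_nested_reduce_domain_dependency workload, where rv's extent depends on i1 and rv2's extent depends on rv:

from dataclasses import dataclass

@dataclass
class Axis:
    name: str
    dom_deps: tuple = ()  # names of axes this axis' domain (min/extent) refers to

def nested_iter_levels(axes):
    by_name = {a.name: a for a in axes}
    depth = {}

    def traverse(axis):
        if axis.name in depth:          # cached result
            return depth[axis.name]
        d = 0
        for dep in axis.dom_deps:
            if dep in by_name:          # ignore symbolic shape vars and other free vars
                d = max(d, traverse(by_name[dep]) + 1)
        depth[axis.name] = d
        return d

    for a in axes:
        traverse(a)
    levels = [[] for _ in range(max(depth.values()) + 1)]
    for a in axes:
        levels[depth[a.name]].append(a.name)
    return levels

# Shape of test_nested_reduce_domain_dependency: rv's extent is i1, rv2's extent is rv.
axes = [Axis("i0"), Axis("i1"), Axis("i2"), Axis("rv", ("i1",)), Axis("rv2", ("rv",))]
print(nested_iter_levels(axes))  # [['i0', 'i1', 'i2'], ['rv'], ['rv2']]

Each level then becomes one scope in GenerateStmtFromCompute: the level that first defines an axis gets the loop and the loop-bound block var, and deeper levels re-bind that axis to the previous level's block var with a unit-extent domain.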
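
A quick way to see the effect of the _as_range change in python/tvm/script/ir_builder/tir/ir.py: when a block-var domain is given as a (min, max) pair, the extent max - min is now simplified before the Range is built, so a pair such as (v, v + 1) produces a constant-extent range. This keeps script-built domains structurally equal to the Range::FromMinExtent(prev_binding, 1) domains that the C++ side emits for nested blocks, e.g. T.axis.spatial((v_ax2, v_ax2 + 1), v_ax2) in the expected TIR above. The snippet below is a minimal illustration, assuming a standard TVM build:

import tvm
from tvm import tir
from tvm.arith import Analyzer

v = tir.Var("v", "int32")
# Without simplification the extent would remain the symbolic expression (v + 1) - v.
extent = Analyzer().simplify((v + 1) - v)
assert isinstance(extent, tir.IntImm) and extent.value == 1
# _as_range now takes the IntImm branch and builds a constant-extent range.
print(tvm.ir.Range.from_min_extent(v, extent))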