From 5a3523d383e99025db293bdf2ff285091b7a1e56 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 24 Jun 2023 16:58:31 -0700 Subject: [PATCH] [TIR][Schedule] Enhance `compute-inline` for fusion (#15142) This PR enhances `compute-inline` and `reverse-compute-inline` to handle more complicated fusion patterns where the RHS of the equation could be generic expressions rather than strict buffer indexing. This could be used in multiple cases in LLM inference. --- .../meta_schedule/testing/space_generation.py | 2 +- src/tir/schedule/primitive/compute_inline.cc | 262 +++++++------- ..._meta_schedule_schedule_rule_mlt_intrin.py | 3 +- ...test_meta_schedule_schedule_rule_mlt_tc.py | 25 +- .../test_meta_schedule_space_cuda_winograd.py | 57 ++-- .../test_tir_schedule_compute_inline.py | 320 +++++++++++++++++- ...est_tir_schedule_tensorize_ldmatrix_mma.py | 24 +- 7 files changed, 510 insertions(+), 183 deletions(-) diff --git a/python/tvm/meta_schedule/testing/space_generation.py b/python/tvm/meta_schedule/testing/space_generation.py index 45cd6659b6e0..6689e45245e8 100644 --- a/python/tvm/meta_schedule/testing/space_generation.py +++ b/python/tvm/meta_schedule/testing/space_generation.py @@ -127,7 +127,7 @@ def check_sketches( def print_sketches(sketches: List[Schedule]): for i, sch in enumerate(sketches): print(f"###### {i}") - sch.mod.show() + sch.mod.show(black_format=False) for inst in sch.trace.insts: if inst in sch.trace.decisions: print(f'("{inst.kind.name}", {sch.trace.decisions[inst]}),') diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index b64351186ac5..31e31294948e 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -22,9 +22,8 @@ namespace tvm { namespace tir { static const char kErrBodyInline[] = R"(The body of the inlined block should be in form of - 'A[i, j, k, ...] = f(i, j, k, ...)', -where the indices on the left are distinct atomic variables, -and there should be no variables other than the index variables)"; + 'A[f(i, j, k, ...)] = g(i, j, k, ...)', +where the store indices mapping f on the left are bijective affine.)"; static const char kErrBodyReverseInline[] = R"(The body of the inlined block should be in form of `B[...] 
= g(i, j, k, A[f(i, j, k, ...)] ...)`, @@ -172,20 +171,53 @@ class NonSingleProducerError : public ScheduleError { */ static StmtSRef Check(const ScheduleState& self, const StmtSRef& consumer_block_sref, const StmtSRef& scope_root_sref) { - BlockScope scope = self->GetBlockScope(scope_root_sref); - Array producers = scope->GetDepsByDst(consumer_block_sref); - StmtSRef producer_block_sref{nullptr}; - if (producers.size() == 1 && producers[0]->kind == DepKind::kRAW) { - producer_block_sref = producers[0]->src; - if (IsCompleteBlock(self, producer_block_sref, scope_root_sref)) { - Array consumers = scope->GetDepsBySrc(producer_block_sref); - if (consumers.size() == 1) { - return producer_block_sref; + const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_root_sref); + const BlockNode* consumer_block = TVM_SREF_TO_BLOCK(consumer_block_sref); + Buffer consumer_buffer = NotSingleReadWriteBuffer::GetSingleRead( + self, GetRef(consumer_block), scope_root_sref); + class ProducerFinder : public StmtVisitor { + public: + static std::vector GetProducer(const Buffer& buffer, const Block& scope_block) { + ProducerFinder finder(buffer); + finder(scope_block); + return finder.producer_across_scope_.back(); + } + + private: + explicit ProducerFinder(const Buffer& buffer) : buffer_(buffer) { + producer_across_scope_.push_back({}); + } + + void VisitStmt_(const BlockNode* node) final { + producer_across_scope_.push_back({}); + StmtVisitor::VisitStmt_(node); + // not a leaf block + if (!producer_across_scope_.back().empty()) { + auto producer_under_block = producer_across_scope_.back(); + producer_across_scope_.pop_back(); + producer_across_scope_.back().insert(producer_across_scope_.back().end(), + producer_under_block.begin(), + producer_under_block.end()); + return; + } + // leaf block + producer_across_scope_.pop_back(); + for (const auto& write : node->writes) { + if (write->buffer.same_as(buffer_)) { + producer_across_scope_.back().push_back(GetRef(node)); + break; + } } } + Buffer buffer_; + std::vector> producer_across_scope_; + }; + std::vector producer_across_scope = + ProducerFinder::GetProducer(consumer_buffer, GetRef(scope_block)); + if (producer_across_scope.size() != 1) { + throw NonSingleProducerError(self->mod, GetRef(consumer_block)); } - const BlockNode* block = TVM_SREF_TO_BLOCK(consumer_block_sref); - throw NonSingleProducerError(self->mod, GetRef(block)); + return self->stmt2ref.at(producer_across_scope[0].get()); } }; @@ -269,7 +301,7 @@ class BaseInliner : public StmtExprMutator { return StmtExprMutator::VisitStmt_(loop); } - Stmt VisitStmt_(const BlockNode* block) final { + Stmt VisitStmt_(const BlockNode* block) { CheckMatchBufferRegion(block); AddBuffersInBlockSignature(block); Block src_block = GetRef(block); @@ -284,31 +316,6 @@ class BaseInliner : public StmtExprMutator { return std::move(tgt_block); } - /*! - * \brief Count the number of undefined variables that are not used - * as buffer objects. - * - * This is used to determine whether inlining or reverse inlining is - * possible. The only undefined variables present should be the - * load/store indices, or buffer access based on those indices. - * - * \param stmt The statement in which to count undefined variables - */ - static int GetNumUndefinedNonpointerVars(const Stmt& stmt) { - auto undefined_vars = UndefinedVars(stmt, {}); - // Buffer pointers and the inlined indices are allowed, but no - // other variables may appear in the inlined block. 
- int num_nonpointer_vars = 0; - for (const auto& var : undefined_vars) { - bool is_pointer = var->dtype.is_handle() && var->type_annotation.defined() && - var->type_annotation.as(); - if (!is_pointer) { - num_nonpointer_vars++; - } - } - return num_nonpointer_vars; - } - private: /*! * \brief Add the buffers in the block signature to the `buffer_var_map_`, @@ -406,7 +413,7 @@ class BaseInliner : public StmtExprMutator { /*! \brief Maps a buffer's data field to itself */ Map buffer_var_map_; /*! \brief The indices used for indexing the buffer to be inlined */ - std::vector idx_vars_; + std::vector idx_vars_; /*! \brief The mapping to substitute index variables to PrimExprs */ std::unordered_map idx_sub_; @@ -443,10 +450,62 @@ class ComputeInliner : public BaseInliner { return false; } - int n_vars = GetNumUndefinedNonpointerVars(GetRef(inlined_store_)); - if (!UpdateAndCheckIndexVars(inlined_store_->indices, n_vars)) { + // Fast path on trivial case: + // Check the store indices are same with the block iters; + store_value_ = inlined_store_->value; + size_t num_iters = producer_block->iter_vars.size(); + size_t buffer_ndim = inlined_store_->indices.size(); + if (num_iters == buffer_ndim) { + std::vector idx_vars; + idx_vars.reserve(num_iters); + for (size_t i = 0; i < num_iters; ++i) { + const IterVar& iter = producer_block->iter_vars[i]; + const PrimExpr& e = inlined_store_->indices[i]; + if (e.same_as(iter->var) || + (analyzer_.CanProveEqual(e, 0) && analyzer_.CanProveEqual(iter->dom->min, 0) && + analyzer_.CanProveEqual(iter->dom->extent, 1))) { + idx_vars.push_back(iter->var); + } else { + break; + } + } + if (idx_vars.size() == num_iters) { + // match success + idx_vars_ = std::move(idx_vars); + return true; + } + } + + // If the mapping for store indices is non-trivial + // check bijective mapping from producer iter var to store indices + Map producer_iter_doms; + for (const auto& iter : producer_block->iter_vars) { + producer_iter_doms.Set(iter->var, iter->dom); + } + arith::IterMapResult res = arith::DetectIterMap( + /*indices=*/inlined_store_->indices, + /*input_iters=*/producer_iter_doms, + /*predicate=*/true, + /*check_level=*/arith::IterMapLevel::Bijective, + /*analyzer=*/&analyzer_, + /*simplify_trivial_iterators=*/false); + if (!res->errors.empty()) { + // Failure: indices of BufferStore are not bijective affine return false; } + idx_vars_.resize(buffer_ndim); + for (size_t i = 0; i < idx_vars_.size(); ++i) { + idx_vars_[i] = Var("ph_" + std::to_string(i), inlined_store_->indices[i].dtype()); + } + auto inverse_iter_map = arith::InverseAffineIterMap( + res->indices, Array(idx_vars_.begin(), idx_vars_.end())); + for (const auto& iter : producer_block->iter_vars) { + if (is_const_int(iter->dom->min) && analyzer_.CanProveEqual(iter->dom->extent, 1)) { + // fallback mapping for constant iters + inverse_iter_map.Set(iter->var, iter->dom->min); + } + } + store_value_ = Substitute(store_value_, inverse_iter_map); return true; } @@ -464,45 +523,7 @@ class ComputeInliner : public BaseInliner { PrimExpr ReplaceInlinedBuffer(BufferLoad load) { SetIndexSubstitution(load->indices); - return Substitute(inlined_store_->value, idx_sub_); - } - - /*! - * \brief Check if the indices are atomic distinct variables and the access is n-dimensional. - * If so, set `self->idx_vars_` properly. 
- * \param indices The indices to be extracted - * \param expected_ndim The expected ndim of the access - * \return A boolean flag indicating if the check is successful - */ - bool UpdateAndCheckIndexVars(const Array& indices, int expected_ndim) { - int n = indices.size(); - if (n != expected_ndim) { - // Failure: dimension mismatch - return false; - } - std::vector result; - result.reserve(n); - for (const PrimExpr& i : indices) { - if (const auto* var = i.as()) { - result.push_back(var); - } else { - // Failure: indexing expression is not a variable - return false; - } - } - using DistinctSet = std::unordered_set; - int n_distinct = DistinctSet(result.begin(), result.end()).size(); - if (n != n_distinct) { - // Failure: indexing variables are not distinct - return false; - } - if (idx_vars_.empty()) { - idx_vars_ = std::move(result); - } else if (!support::ArrayWithSameContent(idx_vars_, result)) { - // Failure: indexing variables are not consitent in different BufferLoads - return false; - } - return true; + return Substitute(store_value_, idx_sub_); } /*! @@ -512,11 +533,17 @@ class ComputeInliner : public BaseInliner { void SetIndexSubstitution(const Array& indices) { ICHECK_EQ(indices.size(), idx_vars_.size()); int n = idx_vars_.size(); - idx_sub_.reserve(n); for (int i = 0; i < n; ++i) { - idx_sub_[idx_vars_[i]] = indices[i]; + idx_sub_[idx_vars_[i].get()] = indices[i]; } } + + /*! \brief The arithmetic analyzer */ + arith::Analyzer analyzer_; + /*! \brief The store value for inlinement. If the producer + store indices are trivial, it is wrt the producer block iter var, + otherwise it is wrt to the placeholder vars of store indices. */ + PrimExpr store_value_; }; /*! @@ -534,7 +561,9 @@ class ReverseComputeInliner : public BaseInliner { private: PrimExpr VisitExpr_(const VarNode* var) final { auto it = self_->idx_sub_.find(var); - ICHECK(it != self_->idx_sub_.end()); + if (it == self_->idx_sub_.end()) { + return GetRef(var); + } return (*it).second; } @@ -552,8 +581,7 @@ class ReverseComputeInliner : public BaseInliner { const StmtSRef& scope_root_sref, const IRModule& mod) : BaseInliner(inlined_buffer, consumer_block_realize->block, scope_root_sref), producer_block_(producer_block), - consumer_block_(consumer_block_realize->block.get()), - mod_(mod) { + consumer_block_(consumer_block_realize->block.get()) { // Initialize the predicates to ensure consumer block iters are in-bound consumer_iter_in_bound_ = Bool(true); for (const IterVar& iter : consumer_block_realize->block->iter_vars) { @@ -596,7 +624,7 @@ class ReverseComputeInliner : public BaseInliner { } } - auto res = arith::DetectIterMap( + arith::IterMapResult res = arith::DetectIterMap( /*indices=*/buffer_load_indices_, /*input_iters=*/consumer_iter_doms, /*predicate=*/true, @@ -609,7 +637,15 @@ class ReverseComputeInliner : public BaseInliner { return false; } - const BufferStoreNode* producer_store = producer_block_->body.as(); + const BufferStoreNode* producer_store = nullptr; + if (const auto* producer_if = producer_block_->body.as()) { + if (producer_if->else_case.defined()) { + return false; + } + producer_store = producer_if->then_case.as(); + } else { + producer_store = producer_block_->body.as(); + } if (producer_store == nullptr) { // Failure: producer block body is not BufferStore return false; @@ -628,39 +664,41 @@ class ReverseComputeInliner : public BaseInliner { using BaseInliner::VisitStmt_; /*! 
\brief Generate the predicate after inlining based on the consumer predicate */ - PrimExpr BuildInlinedConsumerPredicate(const BlockRealizeNode* producer_block_realize) { + Block BuildInlinedConsumerPredicate(const BlockNode* producer_block) { // Bind the producer block iter domains for simplification Map subst_map; - for (int i = 0, n = producer_block_realize->iter_values.size(); i < n; ++i) { - const IterVar& iter = producer_block_realize->block->iter_vars[i]; + for (int i = 0, n = producer_block->iter_vars.size(); i < n; ++i) { + const IterVar& iter = producer_block->iter_vars[i]; analyzer_.Bind(iter->var, Range::FromMinExtent(iter->dom->min, iter->dom->extent)); - subst_map.Set(iter->var, producer_block_realize->iter_values[i]); } // Substitute the consumer block iters with the corresponding iters in the producer blocks PrimExpr predicate = Substituter(this)(consumer_iter_in_bound_); // Simplify the predicate using the producer block iter domains predicate = analyzer_.Simplify(predicate); - // Substitute the producer block iters with the its bindings since the predicate in BlockRealize - // should not contain the block iters - predicate = Substitute(predicate, subst_map); - predicate = analyzer_.Simplify(predicate); - return predicate; + ObjectPtr block = make_object(*producer_block); + if (is_one(predicate)) { + return Block(block); + } + if (const auto* if_ = producer_block->body.as()) { + PrimExpr if_predicate = analyzer_.Simplify(if_->condition); + if (!StructuralEqual()(predicate, if_predicate)) { + predicate = analyzer_.Simplify(predicate && if_->condition); + } + block->body = IfThenElse(predicate, if_->then_case); + return Block(block); + } + block->body = IfThenElse(predicate, block->body); + return Block(block); } - Stmt VisitStmt_(const BlockRealizeNode* op) final { - BlockRealize new_block_realize = Downcast(StmtMutator::VisitStmt_(op)); - if (op->block.get() == producer_block_) { - auto new_predicate = BuildInlinedConsumerPredicate(new_block_realize.get()); - - With ctx(&analyzer_, new_predicate); - if (!analyzer_.CanProve(op->predicate)) { - // We do not allow cases where the new predicate for the inlined block cannot - // imply the original predicate in the producer block. - throw ProducerHasNonTrivialPredicateError(mod_, GetRef(op), new_predicate); - } - new_block_realize.CopyOnWrite()->predicate = new_predicate; + Stmt VisitStmt_(const BlockNode* op) final { + Block src_block = GetRef(op); + Block tgt_block = Downcast(BaseInliner::VisitStmt_(op)); + if (op == producer_block_) { + tgt_block = BuildInlinedConsumerPredicate(tgt_block.get()); + block_reuse.Set(src_block, tgt_block); } - return std::move(new_block_realize); + return std::move(tgt_block); } Stmt VisitStmt_(const BufferStoreNode* _store) final { @@ -774,8 +812,6 @@ class ReverseComputeInliner : public BaseInliner { PrimExpr consumer_iter_in_bound_{nullptr}; /*! \brief The arithmetic analyzer */ arith::Analyzer analyzer_; - /*! \brief The target module, only used for error reporting. 
*/ - const IRModule& mod_; }; void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref, diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py index 1fadce6957a3..1f682d8018bc 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py @@ -21,6 +21,7 @@ from tvm.meta_schedule.testing.space_generation import ( check_sketches, generate_design_space, + print_sketches, ) from tvm.script import tir as T from tvm.target import Target @@ -276,7 +277,7 @@ def _dense(m, n, k, in_dtype, out_dtype): actual = generate_design_space( kind="cuda", mod=mod, - target=Target("cuda"), + target=Target("cuda --arch=sm_70"), types=None, sch_rules=[ ms.schedule_rule.MultiLevelTilingWithIntrin( diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py index e101a63d138b..7cf06b54cac7 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py @@ -26,6 +26,7 @@ check_sketches, generate_design_space, get_rules, + print_sketches, ) from tvm.script import tir as T from tvm.tir.tensor_intrin.cuda import get_wmma_intrin_group @@ -211,7 +212,7 @@ def matmul_relu_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "f actual = generate_design_space( kind="cuda", mod=mod, - target=tvm.target.Target("cuda"), + target=tvm.target.Target("cuda --arch=sm_70"), types=None, sch_rules=[ multi_level_tiling_tensor_core( @@ -362,7 +363,7 @@ def matmul_relu_fallback_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, actual = generate_design_space( kind="cuda", mod=mod, - target=tvm.target.Target("cuda"), + target=tvm.target.Target("cuda --arch=sm_70"), types=None, sch_rules=[ multi_level_tiling_tensor_core(), @@ -525,7 +526,7 @@ def conv2d_0(inputs: T.Buffer((1, 16, 16, 32), "float16"), weight: T.Buffer((3, actual = generate_design_space( kind="cuda", mod=mod, - target=tvm.target.Target("cuda"), + target=tvm.target.Target("cuda --arch=sm_70"), types=None, sch_rules=[ multi_level_tiling_tensor_core( @@ -545,7 +546,7 @@ def conv2d_0(inputs: T.Buffer((1, 16, 16, 32), "float16"), weight: T.Buffer((3, actual = generate_design_space( kind="cuda", mod=mod, - target=tvm.target.Target("cuda"), + target=tvm.target.Target("cuda --arch=sm_70"), types=None, sch_rules=[ multi_level_tiling_tensor_core( @@ -709,7 +710,7 @@ def matmul_relu_pipeline_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, actual = generate_design_space( kind="cuda", mod=mod, - target=tvm.target.Target("cuda"), + target=tvm.target.Target("cuda --arch=sm_70"), types=None, sch_rules=[ multi_level_tiling_tensor_core( @@ -739,7 +740,7 @@ def test_matmul_relu_non_tensorizable(): (sch,) = generate_design_space( kind="cuda", mod=mod, - target=tvm.target.Target("cuda"), + target=tvm.target.Target("cuda --arch=sm_70"), types=None, sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="shared")] + get_rules("cuda", ms.schedule_rule.AutoInline), @@ -848,17 +849,17 @@ def padded_matmul_relu_0(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 1 C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] for ax0_ax1_ax3_ax4_ax5_fused in range(512): with T.block("C_reindex_shared"): - v0 = T.axis.spatial(4, 
T.Add(ax0_0_0_ax1_0_0_fused // 2, 0)) - v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_ax3_ax4_ax5_fused % 512 // 256) + v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2) + v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_ax3_ax4_ax5_fused // 256) v2 = T.axis.spatial(2, ax2) v3 = T.axis.spatial(1, 0) v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) - T.where(ax0_0_0_ax1_0_0_fused // 2 * 32 + ax2 * 16 + ax0_ax1_ax3_ax4_ax5_fused % 256 // 16 < 127 and ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_ax3_ax4_ax5_fused % 512 // 256 * 16 + ax0_ax1_ax3_ax4_ax5_fused % 16 < 127) T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5]) T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16]) T.block_attr({"meta_schedule.cooperative_fetch": 4}) - compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0)) + if v0 * 32 + v2 * 16 + v4 < 127 and v1 * 16 + v5 < 127: + compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0)) # fmt: on decision_0 = [ @@ -882,7 +883,7 @@ def padded_matmul_relu_0(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 1 actual = generate_design_space( kind="cuda", mod=mod, - target=tvm.target.Target("cuda"), + target=tvm.target.Target("cuda --arch=sm_70"), types=None, sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="shared")] + get_rules("cuda", ms.schedule_rule.AutoInline), @@ -1039,7 +1040,7 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer( actual = generate_design_space( kind="cuda", mod=mod, - target=tvm.target.Target("cuda"), + target=tvm.target.Target("cuda --arch=sm_70"), types=None, sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="shared")] + get_rules("cuda", ms.schedule_rule.AutoInline), diff --git a/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py b/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py index 27fe47ab8699..a8a50b6f129c 100644 --- a/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py +++ b/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py @@ -534,33 +534,36 @@ def nchw_add_relu_scheduled(p0: T.Buffer((2, 2048, 50, 75), "float32"), p1: T.Bu T.reads(bgemm_local[v0, v1, v2, v3]) T.writes(bgemm[v0, v1, v2, v3]) bgemm[v0, v1, v2, v3] = bgemm_local[v0, v1, v2, v3] - for i0, i1, i2_0, i3_0, ax0, ax1 in T.grid(2, 2048, 25, 38, 1, 1): - for ax2 in T.unroll(2): - for ax3 in T.unroll(2): - for ax4 in T.unroll(4): - for ax5 in T.unroll(4): - with T.block("inverse"): - co = T.axis.spatial(2048, i1 + ax0) - p = T.axis.spatial(1900, i0 * 950 + i2_0 * 38 + i3_0 + ax1) - vh, vw, r_a, r_b = T.axis.remap("SSRR", [ax2, ax3, ax4, ax5]) - T.reads(bgemm[r_a, r_b, co, p]) - T.writes(inverse_local[co, p, vh, vw]) - T.block_attr({"schedule_rule": "conv2d_nchw_winograd_inverse"}) - with T.init(): - inverse_local[co, p, vh, vw] = T.float32(0) - inverse_local[co, p, vh, vw] = inverse_local[co, p, vh, vw] + bgemm[r_a, r_b, co, p] * T.Select(r_a % 4 == 3 and vh % 2 == 1, T.float32(1), T.Select(r_a % 4 == 3 and vh % 2 == 0, T.float32(0), T.Select(r_a % 4 == 2 and vh % 2 == 1, T.float32(1), T.Select(r_a % 4 == 2 and vh % 2 == 0, T.float32(1), T.Select(r_a % 4 == 1 and vh % 2 == 1, T.float32(-1), T.Select(r_a % 4 == 1 and vh % 2 == 0, T.float32(1), T.Select(r_a % 4 == 0 and vh % 2 == 1, T.float32(0), 
T.Select(r_a % 4 == 0 and vh % 2 == 0, T.float32(1), T.float32(0))))))))) * T.Select(r_b % 4 == 3 and vw % 2 == 1, T.float32(1), T.Select(r_b % 4 == 3 and vw % 2 == 0, T.float32(0), T.Select(r_b % 4 == 2 and vw % 2 == 1, T.float32(1), T.Select(r_b % 4 == 2 and vw % 2 == 0, T.float32(1), T.Select(r_b % 4 == 1 and vw % 2 == 1, T.float32(-1), T.Select(r_b % 4 == 1 and vw % 2 == 0, T.float32(1), T.Select(r_b % 4 == 0 and vw % 2 == 1, T.float32(0), T.Select(r_b % 4 == 0 and vw % 2 == 0, T.float32(1), T.float32(0))))))))) - for i0_i1_i2_i3_fused_1 in T.thread_binding(256, thread="blockIdx.x"): - for i0_i1_i2_i3_fused_2 in T.thread_binding(1024, thread="threadIdx.x"): - for i0_i1_i2_i3_fused_0 in range(59): - with T.block("T_add"): - ax0 = T.axis.spatial(2, (i0_i1_i2_i3_fused_0 * 262144 + i0_i1_i2_i3_fused_1 * 1024 + i0_i1_i2_i3_fused_2) // 7680000) - ax1 = T.axis.spatial(2048, (i0_i1_i2_i3_fused_0 * 262144 + i0_i1_i2_i3_fused_1 * 1024 + i0_i1_i2_i3_fused_2) % 7680000 // 3750) - ax2 = T.axis.spatial(50, (i0_i1_i2_i3_fused_0 * 262144 + i0_i1_i2_i3_fused_1 * 1024 + i0_i1_i2_i3_fused_2) % 3750 // 75) - ax3 = T.axis.spatial(75, (i0_i1_i2_i3_fused_0 * 262144 + i0_i1_i2_i3_fused_1 * 1024 + i0_i1_i2_i3_fused_2) % 75) - T.where((i0_i1_i2_i3_fused_0 * 256 + i0_i1_i2_i3_fused_1) * 1024 + i0_i1_i2_i3_fused_2 < 15360000) - T.reads(inverse_local[ax1, ax0 * 950 + ax2 // 2 * 38 + ax3 // 2, ax2 % 2, ax3 % 2], p2[0, ax1, 0, 0]) - T.writes(T_relu[ax0, ax1, ax2, ax3]) - T_relu[ax0, ax1, ax2, ax3] = T.max(inverse_local[ax1, ax0 * 950 + ax2 // 2 * 38 + ax3 // 2, ax2 % 2, ax3 % 2] + p2[0, ax1, 0, 0], T.float32(0)) + for i0_i1_i2_0_i3_0_fused_1 in T.thread_binding(256, thread="blockIdx.x"): + for i0_i1_i2_0_i3_0_fused_2 in T.thread_binding(1024, thread="threadIdx.x"): + for i0_i1_i2_0_i3_0_fused_0 in range(15): + for ax0, ax1 in T.grid(1, 1): + for ax2 in T.unroll(2): + for ax3 in T.unroll(2): + for ax4 in T.unroll(4): + for ax5 in T.unroll(4): + with T.block("inverse"): + co = T.axis.spatial(2048, (i0_i1_i2_0_i3_0_fused_0 * 262144 + i0_i1_i2_0_i3_0_fused_1 * 1024 + i0_i1_i2_0_i3_0_fused_2) % 1945600 // 950 + ax0) + p = T.axis.spatial(1900, (i0_i1_i2_0_i3_0_fused_0 * 262144 + i0_i1_i2_0_i3_0_fused_1 * 1024 + i0_i1_i2_0_i3_0_fused_2) // 1945600 * 950 + (i0_i1_i2_0_i3_0_fused_0 * 262144 + i0_i1_i2_0_i3_0_fused_1 * 1024 + i0_i1_i2_0_i3_0_fused_2) % 950 + ax1) + vh, vw, r_a, r_b = T.axis.remap("SSRR", [ax2, ax3, ax4, ax5]) + T.where((i0_i1_i2_0_i3_0_fused_0 * 256 + i0_i1_i2_0_i3_0_fused_1) * 1024 + i0_i1_i2_0_i3_0_fused_2 < 3891200) + T.reads(bgemm[r_a, r_b, co, p]) + T.writes(inverse_local[co, p, vh, vw]) + T.block_attr({"schedule_rule": "conv2d_nchw_winograd_inverse"}) + with T.init(): + inverse_local[co, p, vh, vw] = T.float32(0) + inverse_local[co, p, vh, vw] = inverse_local[co, p, vh, vw] + bgemm[r_a, r_b, co, p] * T.Select(r_a % 4 == 3 and vh % 2 == 1, T.float32(1), T.Select(r_a % 4 == 3 and vh % 2 == 0, T.float32(0), T.Select(r_a % 4 == 2 and vh % 2 == 1, T.float32(1), T.Select(r_a % 4 == 2 and vh % 2 == 0, T.float32(1), T.Select(r_a % 4 == 1 and vh % 2 == 1, T.float32(-1), T.Select(r_a % 4 == 1 and vh % 2 == 0, T.float32(1), T.Select(r_a % 4 == 0 and vh % 2 == 1, T.float32(0), T.Select(r_a % 4 == 0 and vh % 2 == 0, T.float32(1), T.float32(0))))))))) * T.Select(r_b % 4 == 3 and vw % 2 == 1, T.float32(1), T.Select(r_b % 4 == 3 and vw % 2 == 0, T.float32(0), T.Select(r_b % 4 == 2 and vw % 2 == 1, T.float32(1), T.Select(r_b % 4 == 2 and vw % 2 == 0, T.float32(1), T.Select(r_b % 4 == 1 and vw % 2 == 1, T.float32(-1), 
T.Select(r_b % 4 == 1 and vw % 2 == 0, T.float32(1), T.Select(r_b % 4 == 0 and vw % 2 == 1, T.float32(0), T.Select(r_b % 4 == 0 and vw % 2 == 0, T.float32(1), T.float32(0))))))))) + for i2_1, i3_1 in T.grid(2, 2): + with T.block("conv2d_winograd"): + n = T.axis.spatial(2, (i0_i1_i2_0_i3_0_fused_0 * 262144 + i0_i1_i2_0_i3_0_fused_1 * 1024 + i0_i1_i2_0_i3_0_fused_2) // 1945600) + co = T.axis.spatial(2048, (i0_i1_i2_0_i3_0_fused_0 * 262144 + i0_i1_i2_0_i3_0_fused_1 * 1024 + i0_i1_i2_0_i3_0_fused_2) % 1945600 // 950) + h = T.axis.spatial(50, (i0_i1_i2_0_i3_0_fused_0 * 262144 + i0_i1_i2_0_i3_0_fused_1 * 1024 + i0_i1_i2_0_i3_0_fused_2) % 950 // 38 * 2 + i2_1) + w = T.axis.spatial(75, (i0_i1_i2_0_i3_0_fused_0 * 262144 + i0_i1_i2_0_i3_0_fused_1 * 1024 + i0_i1_i2_0_i3_0_fused_2) % 38 * 2 + i3_1) + T.where(((i0_i1_i2_0_i3_0_fused_0 * 256 + i0_i1_i2_0_i3_0_fused_1) * 1024 + i0_i1_i2_0_i3_0_fused_2) % 38 * 2 + i3_1 < 75 and (i0_i1_i2_0_i3_0_fused_0 * 256 + i0_i1_i2_0_i3_0_fused_1) * 1024 + i0_i1_i2_0_i3_0_fused_2 < 3891200) + T.reads(inverse_local[co, n * 950 + h // 2 * 38 + w // 2, h % 2, w % 2], p2[0, co, 0, 0]) + T.writes(T_relu[n, co, h, w]) + T_relu[n, co, h, w] = T.max(inverse_local[co, n * 950 + h // 2 * 38 + w // 2, h % 2, w % 2] + p2[0, co, 0, 0], T.float32(0)) + # fmt: on decision_0 = [ ("SamplePerfectTile", [2, 1, 2, 1, 1]), diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py index 63df2de23129..8d90189507d7 100644 --- a/tests/python/unittest/test_tir_schedule_compute_inline.py +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -15,11 +15,10 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring -import sys - import pytest import tvm import tvm.testing +import tvm.tir.tensor_intrin from tvm import tir from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip @@ -622,8 +621,8 @@ def elementwise_overcomputed_producer_reverse_inlined( for i, j in T.grid(128, 128): with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) - T.where(i < 127 and j < 127) - C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + if vi < 127 and vj < 127: + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 @T.prim_func @@ -650,8 +649,8 @@ def elementwise_overcomputed_producer_simplify_predicate_reverse_inlined( with T.block("B"): vi = T.axis.spatial(128, i // 128) vj = T.axis.spatial(128, i % 128) - T.where(i < 16255 and i % 128 < 127) - C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + if vi < 127 and vj < 127: + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 @T.prim_func @@ -676,8 +675,8 @@ def elementwise_overcomputed_producer_injective_load_reverse_inlined( for i0, j0, i1, j1 in T.grid(8, 8, 16, 16): with T.block("B"): vi, vj, vm, vn = T.axis.remap("SSSS", [i0, j0, i1, j1]) - T.where(i0 * 16 + i1 < 127 and j0 * 16 + j1 < 127) - C[vm + vi * 16, vn + vj * 16] = A[vi * 16 + vm, vj * 16 + vn] * 2.0 + 1.0 + if vi * 16 + vm < 127 and vj * 16 + vn < 127: + C[vm + vi * 16, vn + vj * 16] = A[vi * 16 + vm, vj * 16 + vn] * 2.0 + 1.0 @T.prim_func @@ -721,12 +720,13 @@ def elementwise_predicate_producer_inlined(a: T.handle, c: T.handle) -> None: vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj]) T.writes(C[vi, vj]) - C[vi, vj] = A[vi, vj] * T.float32(2) + T.float32(1) + if vi < 127: + C[vi, vj] = A[vi, vj] * T.float32(2) + T.float32(1) # fmt: off @tvm.script.ir_module -class Conv2dInt8_TensorCore_with_predicate: +class Conv2dInt8_TensorCore_with_predicate_before: 
@T.prim_func def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), "int8"), p2: T.Buffer((1, 1, 1, 256), "int32"), p3: T.Buffer((1, 1, 1, 256), "int32"), p4: T.Buffer(256, "int32"), p5: T.Buffer(256, "int32"), p6: T.Buffer(256, "int32"), p7: T.Buffer((), "int32"), p8: T.Buffer(1, "int32"), p9: T.Buffer((16, 56, 56, 256), "int32"), compute: T.Buffer((16, 56, 56, 256), "int32")): # function attr dict @@ -845,6 +845,118 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " T.reads(compute_3[i0_13, i1_13, i2_13, i3_13], p9[i0_13, i1_13, i2_13, i3_13]) T.writes(compute[i0_13, i1_13, i2_13, i3_13]) compute[i0_13, i1_13, i2_13, i3_13] = T.max(T.min(compute_3[i0_13, i1_13, i2_13, i3_13] + T.q_multiply_shift(p9[i0_13, i1_13, i2_13, i3_13], 2101000910, 31, 0, dtype="int32"), 255), 0) + +@tvm.script.ir_module +class Conv2dInt8_TensorCore_with_predicate_after: + @T.prim_func + def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), "int8"), p2: T.Buffer((1, 1, 1, 256), "int32"), p3: T.Buffer((1, 1, 1, 256), "int32"), p4: T.Buffer((256,), "int32"), p5: T.Buffer((256,), "int32"), p6: T.Buffer((256,), "int32"), p7: T.Buffer((), "int32"), p8: T.Buffer((1,), "int32"), p9: T.Buffer((16, 56, 56, 256), "int32"), compute: T.Buffer((16, 56, 56, 256), "int32")): + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + with T.block("root"): + T.reads() + T.writes() + T.block_attr({"meta_schedule.unroll_explicit": 1024}) + conv2d_nhwc_reindex_shared = T.alloc_buffer((50176, 256), "int32", scope="shared") + conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((50176, 256), "int32", scope="wmma.accumulator") + pad_temp_reindex_shared = T.alloc_buffer((50176, 64), "int8", scope="shared") + p1_reindex_shared = T.alloc_buffer((1, 1, 256, 64), "int8", scope="shared") + pad_temp_reindex_shared_wmma_matrix_a = T.alloc_buffer((50176, 64), "int8", scope="wmma.matrix_a") + p1_reindex_shared_wmma_matrix_b = T.alloc_buffer((1, 1, 256, 64), "int8", scope="wmma.matrix_b") + for ax2_0_0_ax3_0_0_fused in T.thread_binding(32, thread="blockIdx.y"): + for ax2_0_1_ax3_0_1_fused in T.thread_binding(196, thread="blockIdx.x"): + for ax2_0_2_ax3_0_2_fused in T.thread_binding(4, thread="threadIdx.y"): + for ax0_0, ax1_0, ax4_0_0 in T.grid(1, 1, 2): + for ax0_ax1_fused in range(1024): + with T.block("pad_temp_reindex_shared"): + v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 4 * 6272 + ax2_0_1_ax3_0_1_fused * 32 + ax0_ax1_fused // 32) + v1 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_fused % 32) + T.reads(p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) + T.writes(pad_temp_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 16]], "meta_schedule.cooperative_fetch": 4}) + pad_temp_reindex_shared[v0, v1] = p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1] + for ax0_ax1_ax2_ax3_fused in range(2048): + with T.block("p1_reindex_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(1, 0) + v2 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 4 * 64 + ax0_ax1_ax2_ax3_fused // 32) + v3 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_ax2_ax3_fused % 32) + T.reads(p1[v2, v0, v1, v3]) + T.writes(p1_reindex_shared[v0, v1, v2, v3]) + T.block_attr({"buffer_dim_align": [[0, 2, 32, 16]], "meta_schedule.cooperative_fetch": 3}) + p1_reindex_shared[v0, v1, v2, v3] = p1[v2, v0, v1, v3] + for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 2): + for ax0_0_1, ax1_0_1 in T.grid(1, 1): + with T.block("pad_temp_reindex_shared_wmma.matrix_a_o"): + v0_o = 
T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2) + v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1) + T.reads(pad_temp_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_s8_a_shared"}) + for ax0_1_1, ax1_1_1 in T.grid(16, 16): + with T.block("pad_temp_reindex_shared_wmma.matrix_a"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1_1, ax1_1_1]) + T.reads(pad_temp_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = pad_temp_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 2, 1): + with T.block("p1_reindex_shared_wmma.matrix_b_o"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(1, 0) + v2_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax2_0) + v3_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1) + T.reads(p1_reindex_shared[v0, v1, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) + T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_s8_b_trans_shared"}) + for ax2_1, ax3_1 in T.grid(16, 16): + with T.block("p1_reindex_shared_wmma.matrix_b"): + v2_i, v3_i = T.axis.remap("SS", [ax2_1, ax3_1]) + T.reads(p1_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) + T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) + p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = p1_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] + for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 1, 1, 2): + with T.block("conv2d_nhwc_o"): + v0 = T.axis.reduce(1, 0) + v1 = T.axis.reduce(1, 0) + v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3 + ax2_0_4) + v3_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax3_0_3 * 2 + ax3_0_4) + v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 + ax4_0_2) + T.reads(pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16:v2_o * 16 + 16, v4_o * 16:v4_o * 16 + 16], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16:v3_o * 16 + 16, v4_o * 16:v4_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_s8s8s32_trans", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_s32", "meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "warp_execution": 1}) + with T.init(): + for ax2_1, ax3_1 in T.grid(16, 16): + with T.block("conv2d_nhwc_init"): + v2_i_init, v3_i_init = T.axis.remap("SS", [ax2_1, ax3_1]) + T.reads() + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init]) + conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init] = 0 + for ax2_1, ax3_1, ax4_1 in T.grid(16, 16, 16): + with T.block("conv2d_nhwc"): + v2_i, v3_i, v4_i = T.axis.remap("SSR", [ax2_1, ax3_1, ax4_1]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 
+ v2_i, v3_o * 16 + v3_i], pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 + v3_i, v4_o * 16 + v4_i]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] + T.Cast("int32", pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i]) * T.Cast("int32", p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 + v3_i, v4_o * 16 + v4_i]) + for ax0_0, ax1_0 in T.grid(1, 2): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): + v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2) + v1_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax1_0) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_s32_shared"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0, ax1_0, ax1_1, ax1_2, ax1_3 in T.grid(32, 1, 4, 32, 2): + with T.block("conv2d_nhwc_reindex_shared"): + v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 4 * 6272 + ax2_0_1_ax3_0_1_fused * 32 + ax0) + v1 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 4 * 64 + (ax1_0 * 256 + ax1_1 * 64 + ax1_2 * 2 + ax1_3)) + T.where(((ax1_0 * 4 + ax1_1) * 32 + ax1_2) * 2 + ax1_3 < 64) + T.reads(p7[()], conv2d_nhwc_reindex_shared[v0, v1], p2[0, 0, 0, v1], p3[0, 0, 0, v1], p4[v1], p5[v1], p6[v1], p8[0], p9[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) + T.writes(compute[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) + compute[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1] = T.max(T.min(T.q_multiply_shift(T.max(T.min(p7[()] + T.q_multiply_shift_per_axis(conv2d_nhwc_reindex_shared[v0, v1] - p2[0, 0, 0, v1] + p3[0, 0, 0, v1], p4[v1], p5[v1], p6[v1], 31, T.bool(False), T.bool(True)), 255), 0) - p8[0], 1457846997, 31, 0) + T.q_multiply_shift(p9[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1], 2101000910, 31, 0), 255), 0) # fmt: on # pylint: enable=no-member,invalid-name,unused-variable @@ -1148,15 +1260,189 @@ def test_reverse_compute_inline_producer_predicate_disallowed(): implied by the synthesized predicate of the new inlined block. 
""" - sch = tir.Schedule(Conv2dInt8_TensorCore_with_predicate, debug_mask="all") + sch = tir.Schedule(Conv2dInt8_TensorCore_with_predicate_before, debug_mask="all") + sch.reverse_compute_inline(sch.get_block("compute_4")) + tvm.ir.assert_structural_equal( + Conv2dInt8_TensorCore_with_predicate_after["main"], sch.mod["main"] + ) + - with pytest.raises(tvm.tir.ScheduleError) as e: - sch.reverse_compute_inline(sch.get_block("compute_4")) +def test_compute_inline_softmax(): + # fmt: off + @T.prim_func + def before(p_lv44: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n, m = T.int64(), T.int64() + lv44 = T.match_buffer(p_lv44, (T.int64(1), T.int64(32), n, m)) + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m), "float16") + T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) + T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) + T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) + var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_maxelem"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv44[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv44[v_i0, v_i1, v_i2, v_k]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_exp"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(lv44[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) + T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) + T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv44[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_expsum"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) + T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_norm"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) + T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"axis": 3}) + var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) + var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) - assert ( - "that cannot be implied by the synthesized predicate T.bool(True) of the new inlined block" - in str(e) - ) + @T.prim_func + def after(p_lv44: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n, m = T.int64(), T.int64() + lv44 = 
T.match_buffer(p_lv44, (T.int64(1), T.int64(32), n, m)) + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m), "float16") + # with T.block("root"): + T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) + T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) + var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_maxelem"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv44[v_i0, v_i1, v_i2, v_k]) + T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv44[v_i0, v_i1, v_i2, v_k]) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_expsum"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv44[v_i0, v_i1, v_i2, v_k], T_softmax_maxelem[v_i0, v_i1, v_i2]) + T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) + T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T.exp(lv44[v_i0, v_i1, v_i2, v_k] - T_softmax_maxelem[v_i0, v_i1, v_i2]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("T_softmax_norm"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(lv44[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2], T_softmax_expsum[v_i0, v_i1, v_i2]) + T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"axis": 3}) + var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T.exp(lv44[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) / T_softmax_expsum[v_i0, v_i1, v_i2] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) + var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) + # fmt: on + + sch = tir.Schedule(before) + sch.compute_inline(sch.get_block("T_softmax_exp")) + tvm.ir.assert_structural_equal(after, sch.mod["main"]) + + +def test_reverse_compute_inline_layer_norm(): + # fmt: off + @T.prim_func + def before(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias: T.Buffer((T.int64(2560),), "float32"), p_output0: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + n = T.int64() + lv6 = T.match_buffer(p_lv6, (T.int64(1), n, T.int64(2560))) + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") + A_red_temp_v0_shared = T.alloc_buffer((T.int64(1), n), scope="shared") + A_red_temp_v1_shared = T.alloc_buffer((T.int64(1), n), scope="shared") + var_T_layer_norm_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) + for ax0_ax1_fused in T.thread_binding(n, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(1), T.int64(10)): + for ax2_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("A_red_temp"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(n, 
ax0_ax1_fused + ax1) + v_k2 = T.axis.reduce(T.int64(2560), ax2_0 * T.int64(256) + ax2_1) + T.reads(lv6[v_ax0, v_ax1, v_k2]) + T.writes(A_red_temp_v0_shared[v_ax0, v_ax1], A_red_temp_v1_shared[v_ax0, v_ax1]) + with T.init(): + A_red_temp_v0_shared[v_ax0, v_ax1] = T.float32(0) + A_red_temp_v1_shared[v_ax0, v_ax1] = T.float32(0) + v_A_red_temp_v0: T.float32 = A_red_temp_v0_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] + v_A_red_temp_v1: T.float32 = A_red_temp_v1_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0, v_ax1, v_k2] + A_red_temp_v0_shared[v_ax0, v_ax1] = v_A_red_temp_v0 + A_red_temp_v1_shared[v_ax0, v_ax1] = v_A_red_temp_v1 + for ax2_0 in range(T.int64(10)): + for ax2_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("T_layer_norm"): + v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_ax1 = T.axis.spatial(n, ax0_ax1_fused) + v_ax2 = T.axis.spatial(T.int64(2560), ax2_0 * T.int64(256) + ax2_1) + T.reads(lv6[v_ax0, v_ax1, v_ax2], A_red_temp_v0_shared[v_ax0, v_ax1], A_red_temp_v1_shared[v_ax0, v_ax1], weight1[v_ax2], bias[v_ax2]) + T.writes(var_T_layer_norm_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_layer_norm_intermediate[v_ax0, v_ax1, v_ax2] = (lv6[v_ax0, v_ax1, v_ax2] - A_red_temp_v0_shared[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1_shared[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0_shared[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0_shared[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * weight1[v_ax2] + bias[v_ax2] + for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_layer_norm_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) + var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_layer_norm_intermediate[v_i0, v_i1, v_i2]) + + @T.prim_func + def after(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias: T.Buffer((T.int64(2560),), "float32"), p_output0: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + n = T.int64() + lv6 = T.match_buffer(p_lv6, (T.int64(1), n, T.int64(2560))) + var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") + # with T.block("root"): + A_red_temp_v0_shared = T.alloc_buffer((T.int64(1), n), scope="shared") + A_red_temp_v1_shared = T.alloc_buffer((T.int64(1), n), scope="shared") + for ax0_ax1_fused in T.thread_binding(n, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(1), T.int64(10)): + for ax2_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("A_red_temp"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(n, ax0_ax1_fused + ax1) + v_k2 = T.axis.reduce(T.int64(2560), ax2_0 * T.int64(256) + ax2_1) + T.reads(lv6[v_ax0, v_ax1, v_k2]) + T.writes(A_red_temp_v0_shared[v_ax0, v_ax1], A_red_temp_v1_shared[v_ax0, v_ax1]) + with T.init(): + A_red_temp_v0_shared[v_ax0, v_ax1] = T.float32(0) + A_red_temp_v1_shared[v_ax0, v_ax1] = T.float32(0) + v_A_red_temp_v0: T.float32 = A_red_temp_v0_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] + v_A_red_temp_v1: T.float32 = A_red_temp_v1_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0, v_ax1, v_k2] + A_red_temp_v0_shared[v_ax0, v_ax1] = v_A_red_temp_v0 + A_red_temp_v1_shared[v_ax0, 
v_ax1] = v_A_red_temp_v1 + for ax2_0 in range(T.int64(10)): + for ax2_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + with T.block("T_layer_norm"): + v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_ax1 = T.axis.spatial(n, ax0_ax1_fused) + v_ax2 = T.axis.spatial(T.int64(2560), ax2_0 * T.int64(256) + ax2_1) + T.reads(lv6[v_ax0, v_ax1, v_ax2], A_red_temp_v0_shared[v_ax0, v_ax1], A_red_temp_v1_shared[v_ax0, v_ax1], weight1[v_ax2], bias[v_ax2]) + T.writes(var_compute_intermediate[v_ax0, v_ax1, v_ax2]) + var_compute_intermediate[v_ax0, v_ax1, v_ax2] = T.Cast("float16", (lv6[v_ax0, v_ax1, v_ax2] - A_red_temp_v0_shared[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1_shared[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0_shared[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0_shared[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * weight1[v_ax2] + bias[v_ax2]) + # fmt: on + + sch = tir.Schedule(before) + sch.reverse_compute_inline(sch.get_block("compute")) + tvm.ir.assert_structural_equal(after, sch.mod["main"]) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py index 2eda2b9ec458..2a853a24318c 100644 --- a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py +++ b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py @@ -15,35 +15,35 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-docstring +import numpy as np +import pytest import tvm +import tvm.testing from tvm import te +from tvm.testing.tir import mma_schedule from tvm.tir.tensor_intrin.cuda import ( LDMATRIX_16x16_A_INTRIN, LDMATRIX_16x16_B_INTRIN, LDMATRIX_16x16_B_TRANS_INTRIN, LDMATRIX_16x32_A_INTRIN, - LDMATRIX_32x16_B_INTRIN, LDMATRIX_16x32_B_TRANS_INTRIN, - MMA_f16f16f32_INTRIN, - MMA_f16f16f32_TRANS_INTRIN, + LDMATRIX_32x16_B_INTRIN, MMA_f16f16f16_INTRIN, MMA_f16f16f16_TRANS_INTRIN, - MMA_i8i8i32_INTRIN, - MMA_i8i8i32_TRANS_INTRIN, - MMA_fill_16x16_f32_INTRIN, + MMA_f16f16f32_INTRIN, + MMA_f16f16f32_TRANS_INTRIN, MMA_fill_16x16_f16_INTRIN, + MMA_fill_16x16_f32_INTRIN, MMA_fill_16x16_i32_INTRIN, - MMA_store_16x16_f32_global_INTRIN, + MMA_i8i8i32_INTRIN, + MMA_i8i8i32_TRANS_INTRIN, MMA_store_16x16_f16_global_INTRIN, + MMA_store_16x16_f32_global_INTRIN, MMA_store_16x16_i32_global_INTRIN, shared_16x16_to_ldmatrix_32x8_layout, - shared_32x16_to_ldmatrix_32x16_layout, shared_16x32_to_ldmatrix_32x16_layout, + shared_32x16_to_ldmatrix_32x16_layout, ) -import tvm.testing -import numpy as np -from tvm.testing.tir import mma_schedule - M = 4096 N = 4096
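
Illustration only, not part of the patch: a minimal sketch of the fusion pattern this change enables, where the producer stores through a bijective affine index expression (`vi * 8 + vj`) instead of plain per-axis block variables, which the enhanced `compute_inline` can invert via `DetectIterMap`/`InverseAffineIterMap`. The buffer shapes and block names below are hypothetical; the schedule calls (`tir.Schedule`, `get_block`, `compute_inline`) follow the same API used in the tests above.

from tvm import tir
from tvm.script import tir as T


@T.prim_func
def before(A: T.Buffer((8, 8), "float32"), C: T.Buffer((64,), "float32")):
    # Producer "B" writes through a fused index vi * 8 + vj: a bijective
    # affine mapping of its block iters rather than distinct atomic variables.
    B = T.alloc_buffer((64,), "float32")
    for i, j in T.grid(8, 8):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi * 8 + vj] = A[vi, vj] * 2.0
    for i in range(64):
        with T.block("C"):
            vi = T.axis.spatial(64, i)
            C[vi] = B[vi] + 1.0


sch = tir.Schedule(before)
sch.compute_inline(sch.get_block("B"))
# After inlining, block "C" should read A[vi // 8, vi % 8] directly
# and the intermediate buffer B should be removed.
sch.mod.show()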