diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 48014280a558..68900e107d7c 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -364,6 +364,19 @@ class ScheduleNode : public runtime::Object { */ virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) = 0; + /*! + * \brief Create a block that read/write a buffer region into a read/write cache with reindexing. + * The layout of the cache will be the same as by the iterators of the block that reads/writes the + * buffer. It requires: + * 1) There is only one block who reads/writes the target buffer + * 2) There is only one buffer load/store of this buffer in the block + * \param block_rv The block operates on the target buffer. + * \param buffer_index The index of the buffer in block's read or write region. + * \param buffer_index_type The type of the buffer index, kRead or kWrite. + * \return The reindex stage block. + */ + virtual BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type) = 0; /******** Schedule: Compute location ********/ /*! * \brief Move a producer block under the specific loop, and regenerate the diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index f86228848b9d..4179088aa534 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -1056,6 +1056,79 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: self, block, write_buffer_index, storage_scope ) + @type_checked + def reindex(self, block: BlockRV, buffer_index: int, buffer_index_type: str) -> BlockRV: + """Create a block that read/write a buffer region into a read/write cache with reindexing. + The layout of the cache will be the same as by the iterators of the block that reads/writes + the buffer. It requires: + 1) There is only one block who reads/writes the target buffer + 2) There is only one buffer load/store of this buffer in the block + + Parameters + ---------- + block: BlockRV + The block that accesses the target buffer + buffer_index: int + The index of the buffer in block's read or write region + buffer_index_type : str + Type of the buffer index, "read" or "write" + + Returns + ------- + reindex_block : BlockRV + The block of the reindex stage + + Examples + -------- + + Before transform_layout, in TensorIR, the IR is: + + .. code-block:: python + + @T.prim_func + def before_reindex( + A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128, 128), "float32"] + ) -> None: + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vj, vi] * 2.0 + + Create the schedule and do transform_layout: + + .. code-block:: python + + sch = tir.Schedule(before_reindex) + block = sch.get_block("B") + sch.reindex(block, 0, "read) + + After applying reindex, the IR becomes: + + .. code-block:: python + + @T.prim_func + def after_reindex( + A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128, 128), "float32"] + ) -> None: + A_reindex = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("A_reindex"): + vi, vj = T.axis.remap("SS", [i, j]) + A_reindex[vi, vj] = A[vj, vi] + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A_reindex[vi, vj] * 2.0 + + """ + assert buffer_index_type in ["read", "write"], "Invalid buffer_index_type" + buffer_index_type_enum = 0 if buffer_index_type == "read" else 1 + return _ffi_api.ScheduleReIndex( # type: ignore # pylint: disable=no-member + self, block, buffer_index, buffer_index_type_enum + ) + ########## Schedule: Compute location ########## @type_checked diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 8066d85a8e7d..d54d7f6021e1 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -511,6 +511,16 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff return CreateRV(result); } +BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::ReIndex(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type); + TVM_TIR_SCHEDULE_END("reindex", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(result); +} + /******** Schedule: Compute location ********/ void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 8e83aac2ce82..70c0265611c3 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -109,6 +109,8 @@ class ConcreteScheduleNode : public ScheduleNode { const String& storage_scope) override; BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) override; + BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type) override; /******** Schedule: Compute location ********/ void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) override; void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 50dedf71ff52..f4dba69c6b15 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -253,6 +253,21 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int r */ TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, const String& storage_scope); +/*! + *! + * \brief Create a block that read/write a buffer region into a read/write cache with reindexing. + * The layout of the cache will be the same as by the iterators of the block that reads/writes the + * buffer. It requires: + * 1) There is only one block who reads/writes the target buffer + * 2) There is only one buffer load/store of this buffer in the block + * \param self The state of the schedule + * \param block_rv The block operates on the target buffer. + * \param buffer_index The index of the buffer in block's read or write region. + * \param buffer_index_type The type of the buffer index, kRead or kWrite. + * \return The reindex stage block. + */ +TVM_DLL StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_index, + BufferIndexType buffer_index_type); /******** Schedule: Compute location ********/ /*! * \brief Move a producer block under the specific loop, and regenerate the diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index 1bba2ae4fc61..c96f88e1f633 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -160,6 +160,121 @@ Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info, return block; } +/*! + * \brief Create the reindex block and generate the corresponding outer loops. + * \details The reindex block is a data copy block between the reindex buffer (the intermediate + * buffer), and the target buffer. + If buffer_index_type == kWrite, copy from the reindex buffer to the target buffer. + If buffer_index_type == kRead, copy from the target buffer to the reindex buffer. + The reindex block has the same block iters and the surrounding loops as the input block. + However, if a block iter is not used in the indices of the target buffer being reindexed, the + domain of the block iter, and the corresponding outer loop, will become constant value one, making + it a trivial iter. + * \param block The block to be reindexed + * \param info The cache info + * \param covered The set of block iter vars covered in the buffer access indices + * \param original_indices The original buffer access indices + * \param buffer_index The index of the target buffer + * \param buffer_index_type The type of buffer index + * \return The reindex block. + */ +Block MakeReIndexStage(const Block& block, CacheStageInfo* info, + const std::unordered_set& covered, + const Array& original_indices, int buffer_index, + BufferIndexType buffer_index_type) { + // iters of the reindex block + Array new_block_iters; + // the substition map from the original block iter to the iters of the reindex block + std::unordered_map block_var_replace_map; + // block access region of reindexed buffer and target buffer + Region reindex_region, target_region; + // indices to access the reindex buffer and the target buffer + Array reindex_indices, target_indices; + + // Step 1: Create block iters, access regions of the reindex block, and accessing indices to the + // reindex buffer. + for (const IterVar& iter : block->iter_vars) { + Var var("v" + std::to_string(new_block_iters.size())); + bool used = covered.count(iter->var); + new_block_iters.push_back(IterVar(/*dom=*/used ? iter->dom : Range::FromMinExtent(0, 1), + /*var=*/var, + /*IterVarType=*/kDataPar)); + if (used) { + reindex_indices.push_back(var); + reindex_region.push_back(Range::FromMinExtent(var, 1)); + } + block_var_replace_map[iter->var] = var; + } + + // Step 2: Replace the original block iters with the new block iters + BufferRegion buffer_region = buffer_index_type == BufferIndexType::kWrite + ? block->writes[buffer_index] + : block->reads[buffer_index]; + target_region = Substitute(buffer_region->region, block_var_replace_map); + for (const PrimExpr& index : original_indices) { + target_indices.push_back(Substitute(index, block_var_replace_map)); + } + + // Step 3: Create the reindex block + + // The src and the dst region and indices of the data copy + Region src_region{nullptr}; + Region dst_region{nullptr}; + Array src_indices{nullptr}; + Array dst_indices{nullptr}; + + if (buffer_index_type == BufferIndexType::kWrite) { + src_region = reindex_region; + dst_region = target_region; + src_indices = reindex_indices; + dst_indices = target_indices; + } else { + src_region = target_region; + dst_region = reindex_region; + src_indices = target_indices; + dst_indices = reindex_indices; + } + + // Create the body block + Block new_block( + /*iter_vars=*/new_block_iters, + /*reads=*/ + {BufferRegion(info->read_buffer, src_region)}, + /*writes=*/ + {BufferRegion(info->write_buffer, dst_region)}, + /*name_hint=*/buffer_region->buffer->name + "_reindex", + /*body=*/ + BufferStore(info->write_buffer, BufferLoad(info->read_buffer, src_indices), dst_indices)); + + // Step 4: Create surrounding loops + + // Create loop vars and bindings for block iters + std::vector loop_vars; // loop variables + std::vector iter_values; // bindings in block realize + for (int i = 0; i < static_cast(block->iter_vars.size()); ++i) { + Var loop_var("ax" + std::to_string(loop_vars.size())); + loop_vars.push_back(loop_var); + iter_values.push_back(loop_var); + } + + // Create the block realize node + Stmt body = BlockRealize(/*values=*/iter_values, + /*predicate=*/const_true(), + /*block=*/new_block); + + // Create the chain of loops + for (int i = static_cast(new_block_iters.size()) - 1; i >= 0; --i) { + body = For(/*loop_var=*/loop_vars[i], + /*min=*/new_block_iters[i]->dom->min, + /*extent=*/new_block_iters[i]->dom->extent, + /*kind=*/ForKind::kSerial, + /*body=*/std::move(body)); + } + // Update cache info, which will be used in the later rewriting. + info->cache_stage = std::move(body); + return new_block; +} + /*! * \brief Recalculate the `affine_binding` flag of a specifc block * \param block_sref The sref to the specific block @@ -599,6 +714,252 @@ class CacheWriteRewriter : public StmtExprMutator { bool under_writer_block_{false}; }; +/*! + * \brief Create a new buffer by change the shape with block iters to be used as the reindex buffer + * \param buffer The given buffer. + * \param block_iters The block iters. + * \param covered Set of block iter vars covered by the buffer access indices + * \return The new buffer with target shape. + */ +Buffer CreateReindexBuffer(const Buffer& buffer, const Array& block_iters, + const std::unordered_set& covered) { + ObjectPtr new_buffer = make_object(*buffer.get()); + ObjectPtr new_var = make_object(*buffer->data.get()); + std::vector new_shape; + std::vector new_strides; + for (const auto& iter : block_iters) { + if (covered.count(iter->var)) { + new_shape.push_back(iter->dom->min + iter->dom->extent); + } + } + new_strides.clear(); + new_buffer->shape = new_shape; + new_buffer->strides = new_strides; + new_buffer->data = buffer->data.copy_with_suffix("_reindex"); + new_buffer->name = buffer->name + "_reindex"; + return Buffer(new_buffer); +} + +/*! \brief The schedule error that the target is not a leaf block. */ +class NotLeafBlockError : public ScheduleError { + public: + NotLeafBlockError(IRModule mod, Block block) : mod_(std::move(mod)), block_(std::move(block)) {} + String FastErrorString() const final { + return "ScheduleError: The target block is not a leaf block."; + } + + String DetailRenderTemplate() const final { return "The target block {0} is not a leaf block."; } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + IRModule mod_; + Block block_; +}; + +/*! \brief The schedule error that the buffer access is invalid for reindex. */ +class InvalidBufferAccessError : public ScheduleError { + public: + enum class ErrorKind { + kNoAccess, // buffer access not found + kNonUniqueAccess, // multiple buffer accesses with different indices + kOpaqueAccess, // opaque access to the buffer + }; + + InvalidBufferAccessError(IRModule mod, Buffer buffer, Block block, ErrorKind kind) + : mod_(std::move(mod)), buffer_(std::move(buffer)), block_(std::move(block)), kind_(kind) {} + String FastErrorString() const final { + return "ScheduleError: The target buffer should be accessed via BufferLoad or BufferStore. The " + "indices should be the same if there are multiple accesses to the target buffer."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The target buffer " << buffer_->name + << " should be accessed in the leaf block {0} via BufferLoad or BufferStore. The indices " + "should be the same if there are multiple accesses to the target buffer. "; + if (kind_ == ErrorKind::kNoAccess) { + os << "No buffer accesses found."; + } else if (kind_ == ErrorKind::kNonUniqueAccess) { + os << "Multiple buffer accesses have non-unique indices."; + } else if (kind_ == ErrorKind::kOpaqueAccess) { + os << "Opaque buffer accesses found."; + } + return os.str(); + } + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Buffer buffer_; + Block block_; + ErrorKind kind_; +}; + +/*! \brief Collect the related Load/Store to reindex */ +class ReIndexCollector : public StmtExprVisitor { + public: + static Array Collect(const IRModule& mod, const Buffer& buffer, const Block& block) { + ReIndexCollector collector(mod, buffer, block); + collector(block->body); + if (!collector.buffer_access_indices_.defined()) { + throw InvalidBufferAccessError(mod, buffer, block, + InvalidBufferAccessError::ErrorKind::kNoAccess); + } + return collector.buffer_access_indices_.value(); + } + + private: + explicit ReIndexCollector(const IRModule& mod, const Buffer& buffer, const Block& block) + : mod_(mod), buffer_(buffer), block_(block) {} + + void VisitExpr_(const BufferLoadNode* load) final { + StmtExprVisitor::VisitExpr_(load); + if (load->buffer.same_as(buffer_)) { + CheckAndUpdateBufferAccessIndices(load->indices); + } + } + + void VisitStmt_(const BlockNode* block) final { + // no sub-blocks under this block + throw NotLeafBlockError(mod_, block_); + } + + void VisitStmt_(const BufferStoreNode* store) final { + StmtExprVisitor::VisitStmt_(store); + if (store->buffer.same_as(buffer_)) { + CheckAndUpdateBufferAccessIndices(store->indices); + } + } + + void CheckAndUpdateBufferAccessIndices(const Array indices) { + if (!buffer_access_indices_.defined()) { + buffer_access_indices_ = indices; + return; + } else if (!std::equal(buffer_access_indices_.value().begin(), + buffer_access_indices_.value().end(), indices.begin(), indices.end(), + ExprDeepEqual())) { + throw InvalidBufferAccessError(mod_, buffer_, block_, + InvalidBufferAccessError::ErrorKind::kNonUniqueAccess); + } + } + + void VisitExpr_(const VarNode* var) final { + if (var == buffer_->data.get()) { + throw InvalidBufferAccessError(mod_, buffer_, block_, + InvalidBufferAccessError::ErrorKind::kOpaqueAccess); + } + } + /*! \brief The IR module */ + IRModule mod_; + /*! \brief The buffer to rewrite */ + Buffer buffer_; + /*! \brief The block to visit */ + Block block_; + /*! \brief The indices of buffer acess to rewrite */ + Optional> buffer_access_indices_; +}; + +/*! \brief Mutator of ReIndex */ +class ReIndexRewriter : public StmtExprMutator { + public: + static Stmt Rewrite(const StmtSRef& scope_sref, const StmtSRef& block_sref, CacheStageInfo* info, + const std::unordered_set& covered) { + ReIndexRewriter rewriter(block_sref, info, covered); + return rewriter(GetRef(scope_sref->stmt)); + } + + private: + explicit ReIndexRewriter(const StmtSRef& block_sref, CacheStageInfo* info, + const std::unordered_set& covered) + : block_sref_(block_sref), info_(info), covered_(covered) { + new_buffer_ = info->alloc; + old_buffer_ = info->read_buffer.same_as(new_buffer_) ? info->write_buffer : info->read_buffer; + } + + Stmt VisitStmt_(const BlockNode* block) final { + Block old_stmt = GetRef(block); + if (is_scope_) { + is_scope_ = false; + Block stmt = Downcast(StmtExprMutator::VisitStmt_(block)); + // Insert cache stage into the loop + ObjectPtr n = make_object(*stmt.as()); + n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); + n->alloc_buffers.push_back(info_->alloc); + stmt = Block(n); + info_->block_reuse.Set(old_stmt, stmt); + return stmt; + } + + // Visiting the blokc being reindexed + if (block == block_sref_->stmt) { + // Collect the updated indices and regions + for (const IterVar& iter : block->iter_vars) { + if (covered_.count(iter->var)) { + indices_.push_back(iter->var); + region_.push_back(Range::FromMinExtent(iter->var, 1)); + } + } + Block stmt = Downcast(StmtExprMutator::VisitStmt_(block)); + // Update block reads/writes to use the intermediate reindex buffer + auto writes = + ReplaceBufferRegion(block->writes, old_buffer_, BufferRegion{new_buffer_, region_}); + auto reads = + ReplaceBufferRegion(block->reads, old_buffer_, BufferRegion{new_buffer_, region_}); + auto match_buffers = ReplaceBufferRegion(block->match_buffers, old_buffer_, + BufferRegion{new_buffer_, region_}); + if (!writes.same_as(block->writes) || !reads.same_as(block->reads) || + !match_buffers.same_as(block->match_buffers)) { + ObjectPtr n = make_object(*stmt.as()); + n->writes = std::move(writes); + n->reads = std::move(reads); + n->match_buffers = std::move(match_buffers); + stmt = Block(n); + } + info_->block_reuse.Set(old_stmt, stmt); + return stmt; + } + return old_stmt; + } + + template + Node VisitBufferAccess(Node node) { + if (node->buffer.same_as(old_buffer_)) { + auto* n = node.CopyOnWrite(); + n->buffer = new_buffer_; + n->indices = indices_; + } + return node; + } + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore buffer_store = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(buffer_store)); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad buffer_load = Downcast(StmtExprMutator::VisitExpr_(op)); + return VisitBufferAccess(std::move(buffer_load)); + } + + private: + /*! \brief The parent scope of the insertion. */ + const StmtSRef& block_sref_; + /*! \brief The info for inserting reindex stage. */ + CacheStageInfo* info_; + /*! \brief Whether old block var is covered in the indices */ + const std::unordered_set& covered_; + /*! \brief Whether the current block is scope block */ + bool is_scope_{true}; + /*! \brief The buffer to be replaced */ + Buffer old_buffer_; + /*! \brief The reindex buffer */ + Buffer new_buffer_; + /*! \brief The new indices */ + Array indices_; + /*! \brief The new region */ + Region region_; +}; + /******** Implementation ********/ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, @@ -729,6 +1090,80 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu return result_block_sref; } +StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_index, + BufferIndexType buffer_index_type) { + const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref); + Block block = GetRef(block_ptr); + Buffer buffer = + GetNthAccessBuffer(self, block, buffer_index, buffer_index_type == BufferIndexType::kWrite); + StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); + arith::Analyzer analyzer; + + // Step 1. Collect the original indices and check there's only single pattern of related + // Load/Store and the buffer is not accessed opaquely + Array original_indices = ReIndexCollector::Collect(self->mod, buffer, block); + // Simplify the indices if possible + for (const IterVar& iter : block->iter_vars) { + analyzer.Bind(iter->var, iter->dom); + } + original_indices.MutateByApply( + [&analyzer](const PrimExpr& expr) { return analyzer.Simplify(expr); }); + + // Collect block iters appearing in the original_indices + std::unordered_set covered; + for (const PrimExpr& index : original_indices) { + PreOrderVisit(index, [&](const ObjectRef& obj) -> bool { + if (const VarNode* var = obj.as()) { + covered.insert(GetRef(var)); + } + return true; + }); + } + + // Step 2. Creating CacheStageInfo + CacheStageInfo info; + // Create the corresponding buffer to be read(write), i.e. the result of reindex read(write) + if (buffer_index_type == BufferIndexType::kWrite) { + info.read_buffer = CreateReindexBuffer(buffer, block->iter_vars, covered); + info.write_buffer = buffer; + info.alloc = info.read_buffer; + } else { + info.read_buffer = buffer; + info.write_buffer = CreateReindexBuffer(buffer, block->iter_vars, covered); + info.alloc = info.write_buffer; + } + + // Step 3. Check the block belongs to a chain loop nesting under the scope, + // and get the insert location + const StmtSRefNode* loop; + for (loop = block_sref->parent; loop->parent != scope_sref.get();) { + const ForNode* outer = loop->parent->StmtAs(); + const ForNode* inner = loop->StmtAs(); + ICHECK(outer != nullptr && inner != nullptr); + ICHECK(outer->body.get() == inner); + loop = loop->parent; + } + + info.loc_pos = loop->seq_index == -1 ? 0 : loop->seq_index; + if (buffer_index_type == BufferIndexType::kWrite) { + info.loc_pos++; + } + + // Step 4. Making new reindex stage block and rewrite + Block reindex_stage = + MakeReIndexStage(block, &info, covered, original_indices, buffer_index, buffer_index_type); + Stmt new_scope = ReIndexRewriter::Rewrite(scope_sref, block_sref, &info, covered); + + // Step 5. Replacing and updating flags + self->Replace(scope_sref, new_scope, info.block_reuse); + StmtSRef result_block_sref = self->stmt2ref.at(reindex_stage.get()); + BlockInfo& block_info = self->block_info[result_block_sref]; + block_info.affine_binding = CalculateAffineFlag(self, result_block_sref); + block_info.region_cover = true; + block_info.scope->stage_pipeline = true; + return result_block_sref; +} + /******** Instruction Registration ********/ struct CacheReadTraits : public UnpackedInstTraits { @@ -787,7 +1222,40 @@ struct CacheWriteTraits : public UnpackedInstTraits { friend struct ::tvm::tir::UnpackedInstTraits; }; +struct ReIndexTraits : public UnpackedInstTraits { + static constexpr const char* kName = "ReIndex"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 0; + + static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer buffer_index, + Integer buffer_index_type) { + return sch->ReIndex(block, buffer_index, + static_cast(buffer_index_type->value)); + } + + static String UnpackedAsPython(Array outputs, String block, Integer buffer_index, + Integer buffer_index_type) { + PythonAPICall py("reindex"); + py.Input("block", block); + py.Input("buffer_index", buffer_index); + py.Input("buffer_index_type", '"' + + std::string(BufferIndexType2Str( + static_cast(buffer_index_type->value))) + + '"'); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + TVM_REGISTER_INST_KIND_TRAITS(CacheReadTraits); TVM_REGISTER_INST_KIND_TRAITS(CacheWriteTraits); +TVM_REGISTER_INST_KIND_TRAITS(ReIndexTraits); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index fb884ce77f7b..3880d0b19eeb 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -165,6 +165,11 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead") .set_body_method(&ScheduleNode::CacheRead); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite") .set_body_method(&ScheduleNode::CacheWrite); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReIndex") + .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index, + int buffer_index_type) { + return self->ReIndex(block_rv, buffer_index, static_cast(buffer_index_type)); + }); /******** (FFI) Compute location ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeAt") .set_body_method(&ScheduleNode::ComputeAt); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 8156480a4516..d2f627edfd11 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -265,6 +265,18 @@ BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer return result; } +BlockRV TracedScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type) { + BlockRV result = ConcreteScheduleNode::ReIndex(block_rv, buffer_index, buffer_index_type); + + static const InstructionKind& kind = InstructionKind::Get("ReIndex"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{Integer(buffer_index), Integer(buffer_index_type)}, + /*outputs=*/{result})); + return result; +} + /******** Schedule: Compute location ********/ void TracedScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index d1860be9512d..ba4a4b99cbb2 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -73,6 +73,8 @@ class TracedScheduleNode : public ConcreteScheduleNode { const String& storage_scope) final; BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) final; + BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type) final; /******** Schedule: Compute location ********/ void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) final; void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index 79802ecd65db..67d0f55f20b9 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -70,6 +70,32 @@ Array ReplaceBuffer(Array match_buffers, c return match_buffers; } +Array ReplaceBufferRegion(Array regions, const Buffer& source_buffer, + const BufferRegion& target) { + regions.MutateByApply([&source_buffer, &target](const BufferRegion& region) -> BufferRegion { + if (region->buffer.same_as(source_buffer)) { + return target; + } + return region; + }); + return regions; +} + +Array ReplaceBufferRegion(Array match_buffers, + const Buffer& source_buffer, + const BufferRegion& target) { + match_buffers.MutateByApply([&source_buffer, &target]( + const MatchBufferRegion& match_buffer) -> MatchBufferRegion { + if (match_buffer->source->buffer.same_as(source_buffer)) { + ObjectPtr n = make_object(*match_buffer.get()); + n->source = target; + return MatchBufferRegion(n); + } + return match_buffer; + }); + return match_buffers; +} + /******** ReplaceBufferMutator ********/ ReplaceBufferMutator::ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_buffer, Map* block_sref_reuse) diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h index 192d44d9e9ad..908a823c2d86 100644 --- a/src/tir/schedule/transform.h +++ b/src/tir/schedule/transform.h @@ -73,6 +73,27 @@ Array ReplaceBuffer(Array regions, const Buffer& sou Array ReplaceBuffer(Array match_buffers, const Buffer& source, const Buffer& target); +/*! + * \brief Replaces the buffer region within the specific sequence of regions + * \param regions The regions to be replaced + * \param source_buffer The buffer to whose region is to be replaced + * \param target The buffer region to be replaced to + * \return The new sequence of regions after replacement + */ +Array ReplaceBufferRegion(Array regions, const Buffer& source_buffer, + const BufferRegion& target); + +/*! + * \brief Replaces the buffer region within the specific sequence of match_buffers + * \param regions The match_buffers to be replaced + * \param source_buffer The buffer to whose region is to be replaced + * \param target The buffer region to be replaced to + * \return The new sequence of match_buffers after replacement + */ +Array ReplaceBufferRegion(Array match_buffers, + const Buffer& source_buffer, + const BufferRegion& target); + /*! * \brief A helper mutator which recursively replaces the old buffer with the new buffer and * collects the block sref reuse information for the following replacement. diff --git a/tests/python/unittest/test_tir_schedule_reindex.py b/tests/python/unittest/test_tir_schedule_reindex.py new file mode 100644 index 000000000000..9b2e37a19813 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_reindex.py @@ -0,0 +1,203 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import pytest +import tvm +import tvm.testing +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.schedule import ScheduleError +from tvm.tir.schedule.testing import verify_trace_roundtrip + + +@T.prim_func +def transpose_elementwise( + A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"] +) -> None: + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vj, vi] * 2.0 + + +@T.prim_func +def transpose_elementwise_reindex_read( + A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"] +) -> None: + A_reindex = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("A_reindex"): + vi, vj = T.axis.remap("SS", [i, j]) + A_reindex[vi, vj] = A[vj, vi] + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A_reindex[vi, vj] * 2.0 + + +@T.prim_func +def conv2d_nhwc( + Input: T.Buffer[(1, 224, 224, 3), "float32"], + Weight: T.Buffer[(7, 7, 3, 64), "float32"], + Conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float32"], +) -> None: + PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 230, 230, 3): + with T.block("PadInput"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( + ((((i1_1 >= 3) and (i1_1 < 227)) and (i2_1 >= 3)) and (i2_1 < 227)), + Input[i0_1, (i1_1 - 3), (i2_1 - 3), i3_1], + T.float32(0), + dtype="float32", + ) + for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3): + with T.block("conv2d_nhwc"): + n, h, w, co, rh, rw, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + with T.init(): + Conv2d_nhwc[n, h, w, co] = T.float32(0) + Conv2d_nhwc[n, h, w, co] = Conv2d_nhwc[n, h, w, co] + ( + PadInput[n, ((h * 2) + rh), ((w * 2) + rw), ((T.floordiv(co, 64) * 3) + rc)] + * Weight[rh, rw, rc, co] + ) + + +@T.prim_func +def conv2d_nhwc_reindex_weight( + var_inputs: T.handle, var_weight: T.handle, var_conv2d_nhwc: T.handle +) -> None: + inputs = T.match_buffer(var_inputs, [1, 224, 224, 3], dtype="float32") + weight = T.match_buffer(var_weight, [7, 7, 3, 64], dtype="float32") + conv2d_nhwc = T.match_buffer(var_conv2d_nhwc, [1, 112, 112, 64], dtype="float32") + PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32") + weight_reindex = T.alloc_buffer([64, 7, 7, 3], dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 230, 230, 3): + with T.block("PadInput"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1]) + T.writes(PadInput[i0_1, i1_1, i2_1, i3_1]) + PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( + i1_1 >= 3 and i1_1 < 227 and i2_1 >= 3 and i2_1 < 227, + inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1], + T.float32(0), + dtype="float32", + ) + for ax0, ax1, ax2, ax3, ax4, ax5, ax6 in T.grid(1, 1, 1, 64, 7, 7, 3): + with T.block("weight_reindex"): + v0, v1, v2, v3, v4, v5, v6 = T.axis.remap( + "SSSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5, ax6] + ) + T.reads(weight[v4, v5, v6, v3]) + T.writes(weight_reindex[v3, v4, v5, v6]) + weight_reindex[v3, v4, v5, v6] = weight[v4, v5, v6, v3] + for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3): + with T.block("conv2d_nhwc"): + n, h, w, co, rh, rw, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + T.reads( + PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], + weight_reindex[co, rh, rw, rc], + ) + T.writes(conv2d_nhwc[n, h, w, co]) + with T.init(): + conv2d_nhwc[n, h, w, co] = T.float32(0) + conv2d_nhwc[n, h, w, co] = ( + conv2d_nhwc[n, h, w, co] + + PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc] + * weight_reindex[co, rh, rw, rc] + ) + + +@T.prim_func +def matmul( + A: T.Buffer[(512, 512), "float32"], + B: T.Buffer[(512, 512), "float32"], + C: T.Buffer[(512, 512), "float32"], +) -> None: + for i0, i1, i2 in T.grid(512, 512, 512): + with T.block("matmul"): + i, j, k = T.axis.remap("SSR", [i0, i1, i2]) + T.reads(C[i, j], A[i, k], B[k, j]) + T.writes(C[i, j]) + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + +@T.prim_func +def matmul_reindex_write( + A: T.Buffer[(512, 512), "float32"], + B: T.Buffer[(512, 512), "float32"], + C: T.Buffer[(512, 512), "float32"], +) -> None: + C_reindex = T.alloc_buffer([512, 512], dtype="float32") + for i0, i1, i2 in T.grid(512, 512, 512): + with T.block("matmul"): + i, j, k = T.axis.remap("SSR", [i0, i1, i2]) + T.reads(C_reindex[i, j], A[i, k], B[k, j]) + T.writes(C_reindex[i, j]) + with T.init(): + C_reindex[i, j] = T.float32(0) + C_reindex[i, j] = C_reindex[i, j] + A[i, k] * B[k, j] + for i0, i1, i2 in T.grid(512, 512, 1): + with T.block("C_reindex"): + v0, v1, v2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(C_reindex[v0, v1]) + T.writes(C[v0, v1]) + C[v0, v1] = C_reindex[v0, v1] + + +@T.prim_func +def multiple_read(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]) -> None: + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vj, vi] + A[vi, vj] + + +def test_reindex_read_basic(): + sch = tir.Schedule(transpose_elementwise) + block = sch.get_block("B") + sch.reindex(block, 0, "read") + tvm.ir.assert_structural_equal(transpose_elementwise_reindex_read, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=transpose_elementwise) + + +def test_conv2d_reindex_read(): + sch = tir.Schedule(conv2d_nhwc) + block = sch.get_block("conv2d_nhwc") + sch.reindex(block, 1, "read") + tvm.ir.assert_structural_equal(conv2d_nhwc_reindex_weight, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=conv2d_nhwc) + + +def test_matmul_reindex_write(): + sch = tir.Schedule(matmul) + block = sch.get_block("matmul") + sch.reindex(block, 0, "write") + tvm.ir.assert_structural_equal(matmul_reindex_write, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=matmul) + + +def test_reindex_fail_multiple_read(): + sch = tir.Schedule(multiple_read) + block = sch.get_block("B") + with pytest.raises(ScheduleError): + sch.reindex(block, 0, "read") + + +if __name__ == "__main__": + tvm.testing.main()