From 34f48f3f10c32ba802070fb1b7be411682dce0be Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Thu, 13 Jan 2022 16:14:44 -0500
Subject: [PATCH 01/18] [TIR] Add software pipelining

Co-authored-by: Junru Shao
Co-authored-by: Xiyou Zhou
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Siyuan Feng
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Ruihang Lai
---
 include/tvm/tir/stmt.h                        |   6 +
 include/tvm/tir/transform.h                   |  32 +
 python/tvm/tir/transform/transform.py         |  11 +
 src/driver/driver_api.cc                      |   2 +
 .../transforms/inject_software_pipeline.cc    | 776 ++++++++++++++++++
 src/tir/transforms/ir_utils.h                 |  17 +
 .../transforms/tensorcore_infer_fragment.cc   |  23 +-
 ..._tir_transform_inject_software_pipeline.py | 754 +++++++++++++++++
 8 files changed, 1607 insertions(+), 14 deletions(-)
 create mode 100644 src/tir/transforms/inject_software_pipeline.cc
 create mode 100644 tests/python/unittest/test_tir_transform_inject_software_pipeline.py

diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 0a05439b2341..697f89caae66 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -1361,6 +1361,12 @@ constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_
  */
 constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";
 
+/*! \brief Mark the stage of a statement in the software pipeline */
+constexpr const char* software_pipeline_stage = "software_pipeline_stage";
+
+/*! \brief Mark the order of a statement in the software pipeline */
+constexpr const char* software_pipeline_order = "software_pipeline_order";
+
 /*! \brief Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling */
 constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure";
 
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 3a964eb77d1b..bcb1804ba551 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -492,6 +492,38 @@ TVM_DLL Pass ConvertForLoopsToSerial();
  */
 TVM_DLL Pass UnifiedStaticMemoryPlanner();
 
+/*!
+ * \brief Transform annotated loops into pipelined ones that overlap producers and consumers.
+ *
+ * This pass detects loops with software pipeline annotations and rewrites them into pipelined
+ * ones. The rewriting is driven by two annotations on the loop, attr::software_pipeline_stage
+ * and attr::software_pipeline_order, which define the stage and the order, respectively, of
+ * the components of the software pipeline. The components of the software pipeline are the
+ * direct children (ignoring BlockRealize / Block / SeqStmt) of the annotated loop. The value of
+ * both annotations should be an array of integers whose size equals the number of components.
+ *
+ * The result of the rewriting is a block whose three direct child blocks represent the
+ * prologue, the body, and the epilogue of the software pipeline. In the prologue, only
+ * components whose stage is less than max_stage are executed. In the epilogue, only components
+ * whose stage is greater than 0 are executed. In the body, all components are executed. Such
+ * rewriting enables behavior like prefetching; the components are not necessarily executed in
+ * the original order. attr::software_pipeline_order defines the order of each component.
+ * Components belonging to different stages can be reordered.
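+ *
+ * For example, a loop with two components could carry the following annotations (an
+ * illustrative TIR-script sketch; see the unit tests in this patch for complete programs):
+ *
+ * \code{.py}
+ *   for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 1],
+ *                                         "software_pipeline_order": [0, 1]}):
+ *       ...  # two direct children: a stage-0 producer and a stage-1 consumer
+ * \endcode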
+ *
+ * Buffers allocated inside the software pipeline may be resized to accommodate multiple
+ * versions of the original buffer. The block annotation attr::double_buffer_scope can be used
+ * to indicate that the block needs to be written in the double-buffering style.
+ *
+ * Annotations:
+ *   attr::software_pipeline_stage: Array of non-negative integers, each element should be in
+ *     range [0, max_stage], where max_stage is the maximum (inclusive) stage.
+ *   attr::software_pipeline_order: Array of non-negative integers, should be a permutation of
+ *     [0, 1, ..., num_components - 1].
+ *
+ * \return The IR transform pass.
+ */
+
 }  // namespace transform
 }  // namespace tir
 }  // namespace tvm
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 834335766551..d734416570f8 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -749,3 +749,14 @@ def ConvertForLoopsToSerial():
         The result pass
     """
     return _ffi_api.ConvertForLoopsToSerial()  # type: ignore
+
+
+def InjectSoftwarePipeline():
+    """Transform annotated loops into pipelined ones that overlap producers and consumers.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
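+
+    Examples
+    --------
+    A minimal sketch of applying the pass; ``func`` here stands for any PrimFunc whose
+    loops carry the two pipeline annotations (see the unit tests for full programs):
+
+    .. code-block:: python
+
+        mod = tvm.IRModule.from_expr(func)
+        mod = tvm.tir.transform.InjectSoftwarePipeline()(mod)
+        print(mod["main"].script())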
+    """
+    return _ffi_api.InjectSoftwarePipeline()  # type: ignore
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index e750344f4f0c..cdc8902e393d 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -247,6 +247,8 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
   pass_list.push_back(tir::transform::UnifyThreadBinding());
   pass_list.push_back(tir::transform::CompactBufferAllocation());
   pass_list.push_back(tir::transform::LowerMatchBuffer());
+  pass_list.push_back(transform::PrintIR());
+  pass_list.push_back(tir::transform::InjectSoftwarePipeline());
   pass_list.push_back(tir::transform::FlattenBuffer());
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc
new file mode 100644
index 000000000000..4d19128c220f
--- /dev/null
+++ b/src/tir/transforms/inject_software_pipeline.cc
@@ -0,0 +1,776 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file inject_software_pipeline.cc
+ * \brief Transform annotated loops into pipelined ones that overlap producers and consumers
+ */
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include "../../support/utils.h"
+#include "../schedule/utils.h"
+#include "./ir_utils.h"
+
+namespace tvm {
+namespace tir {
+
+namespace software_pipeline {
+
+/*!
+ * \brief Create a block and infer the access region with the given body.
+ *
+ * The result is an opaque block that doesn't contain any block iter vars. In case the body is a
+ * block realize without a predicate, it is unnecessary to create a new block; the block of the
+ * block realize is returned directly.
+ *
+ * \param body The body of the block.
+ * \param buffer_data_to_buffer The map from buffer data to buffer.
+ * \return The result block.
+ */
+Block MakeBlock(const Stmt& body, const Map<Var, Buffer>& buffer_data_to_buffer) {
+  if (const BlockRealizeNode* block_realize = body.as<BlockRealizeNode>()) {
+    if (is_one(block_realize->predicate)) {
+      // no need to create a new block
+      return block_realize->block;
+    }
+  }
+  Block block = Block({}, {}, {}, "", body);
+  auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer);
+  auto* n = block.CopyOnWrite();
+  n->reads = access[0];
+  n->writes = access[1];
+  return block;
+}
+
+/*! \brief Structure that represents the stage and order of a software pipeline component. */
+struct PipelineStageOrder {
+  int stage;
+  int order;
+  explicit PipelineStageOrder(int stage, int order) : stage(stage), order(order) {}
+};
+
+using PipelineInfo = std::unordered_map<Block, PipelineStageOrder, ObjectPtrHash, ObjectPtrEqual>;
+
+struct BufferAccessInfo {
+  int def;  // the defining stage of the buffer
+  int use;  // the last using stage of the buffer
+  BufferAccessInfo(int def = -1, int use = -1) : def(def), use(use) {}
+};
+
+/*!
+ * \brief Rewriter for the body of the software pipeline. This pass inserts `floormod` to indices
+ * of the remapped buffer to select the version corresponding to the pipeline stage.
+ */
+class PipelineBodyRewriter : public StmtExprMutator {
+ public:
+  /*!
+   * \brief Constructor of PipelineBodyRewriter.
+   * \param buffer_data_to_buffer The map from buffer data to buffer.
+   * \param buffer_remap The map from the original buffer to the buffer with an updated shape for
+   *        multi-versioning in the software pipeline.
+   * \param pipeline_loop The original loop to be software pipelined.
+   * \param access_all_versions Whether all versions of the buffers in the software pipeline are
+   *        accessed. This will be used to update the block access region. In the prologue and
+   *        epilogue of a two-stage software pipeline, only one version of these buffers is
+   *        accessed.
+   */
+  PipelineBodyRewriter(const Map<Var, Buffer>& buffer_data_to_buffer,
+                       const Map<Buffer, Buffer>& buffer_remap, For pipeline_loop,
+                       bool access_all_versions)
+      : buffer_data_to_buffer_(buffer_data_to_buffer),
+        buffer_remap_(buffer_remap),
+        pipeline_loop_(pipeline_loop),
+        access_all_versions_(access_all_versions) {}
+
+ private:
+  BufferRegion RewritePipelineBufferRegion(const BufferRegion& buffer_region) const {
+    auto it = buffer_remap_.find(buffer_region->buffer);
+    if (it != buffer_remap_.end()) {
+      Region new_region = buffer_region->region;
+      const Buffer& new_buffer = (*it).second;
+      // For pipeline buffers, relax the access region of the first dimension to the full
+      // extent if access_all_versions == true
+      Range accessed_version =
+          access_all_versions_
+              ? Range::FromMinExtent(0, new_buffer->shape[0])
+              : Range::FromMinExtent(floormod((pipeline_loop_->loop_var - pipeline_loop_->min),
+                                              new_buffer->shape[0]),
+                                     Integer(1));
+      new_region.insert(new_region.begin(), accessed_version);
+      return BufferRegion(new_buffer, new_region);
+    }
+    return buffer_region;
+  }
+
+  Stmt VisitStmt_(const BlockNode* op) final {
+    for (const Buffer& alloc_buffer : op->alloc_buffers) {
+      buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer);
+    }
+    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
+    BlockNode* n = block.CopyOnWrite();
+    n->reads.MutateByApply(std::bind(&PipelineBodyRewriter::RewritePipelineBufferRegion, this,
+                                     std::placeholders::_1));
+    n->writes.MutateByApply(std::bind(&PipelineBodyRewriter::RewritePipelineBufferRegion, this,
+                                      std::placeholders::_1));
+    for (const Buffer& alloc_buffer : op->alloc_buffers) {
+      buffer_data_to_buffer_.erase(alloc_buffer->data);
+    }
+    return std::move(block);
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+    auto it = buffer_remap_.find(store->buffer);
+    if (it == buffer_remap_.end()) {
+      return std::move(store);
+    }
+    const Buffer& new_buffer = (*it).second;
+    auto* n = store.CopyOnWrite();
+    n->buffer = new_buffer;
+    PrimExpr version =
+        floormod((pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
+    n->indices.insert(n->indices.begin(), version);
+    return std::move(store);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+    auto it = buffer_remap_.find(load->buffer);
+    if (it == buffer_remap_.end()) {
+      return std::move(load);
+    }
+    const Buffer& new_buffer = (*it).second;
+    auto* n = load.CopyOnWrite();
+    n->buffer = new_buffer;
+    PrimExpr version =
+        floormod((pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
+    n->indices.insert(n->indices.begin(), version);
+    return std::move(load);
+  }
+
+  int GetWmmaFragmentSize(const Buffer& buffer) {
+    const FragmentInfo& info = fragment_info_.at(buffer->data.get());
+    String scope = buffer.scope();
+    if (scope == "wmma.matrix_a") {
+      return info.m * info.k;
+    } else if (scope == "wmma.matrix_b") {
+      return info.n * info.k;
+    } else if (scope == "wmma.accumulator") {
+      return info.m * info.n;
+    } else {
+      ICHECK(0);
+      throw;
+    }
+  }
+
+  PrimExpr RewriteWmmaFragmentIndex(const Buffer& old_buffer, const Buffer& new_buffer,
+                                    const PrimExpr& old_index) {
+    PrimExpr new_buffer_offset = old_index;
+
+    int fragment_size = GetWmmaFragmentSize(old_buffer);
+    PrimExpr offset =
+        floordiv(foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
+                       make_const(DataType::Int(32), 1), old_buffer->shape),
+                 fragment_size);
+    new_buffer_offset +=
+        floormod(pipeline_loop_->loop_var - pipeline_loop_->min, new_buffer->shape[0]) * offset;
+    return new_buffer_offset;
+  }
+
+  PrimExpr VisitExpr_(const CallNode* op) final {
+    // Intrinsic calls should be handled explicitly here as they are opaque accesses to
+    // buffers.
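+    // The rewritten intrinsics are tvm_load_matrix_sync, tvm_store_matrix_sync, tvm_mma_sync
+    // and tvm_access_ptr: for each, the per-version offset of the remapped buffer is folded
+    // into the fragment index (or the pointer offset) below.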
+    static const auto& load_matrix_sync = builtin::tvm_load_matrix_sync();
+    static const auto& store_matrix_sync = builtin::tvm_store_matrix_sync();
+    static const auto& mma_sync = builtin::tvm_mma_sync();
+    static const auto& access_ptr = builtin::tvm_access_ptr();
+    Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
+    if (call->op.same_as(load_matrix_sync) || call->op.same_as(store_matrix_sync)) {
+      const Buffer& buffer = buffer_data_to_buffer_.at(Downcast<Var>(call->args[0]));
+      auto it = buffer_remap_.find(buffer);
+      if (it != buffer_remap_.end()) {
+        Array<PrimExpr> new_args = call->args;
+        const Buffer& new_buffer = (*it).second;
+        new_args.Set(4, RewriteWmmaFragmentIndex(buffer, new_buffer, call->args[4]));
+        return Call(call->dtype, call->op, new_args, call->span);
+      }
+    } else if (call->op.same_as(mma_sync)) {
+      Array<PrimExpr> new_args = call->args;
+      for (int i = 0; i < 4; i++) {
+        const Var& buffer_var = Downcast<Var>(call->args[i * 2]);
+        const PrimExpr& index = call->args[i * 2 + 1];
+        const Buffer& buffer = buffer_data_to_buffer_.at(buffer_var);
+        auto it = buffer_remap_.find(buffer);
+        if (it != buffer_remap_.end()) {
+          PrimExpr new_index = RewriteWmmaFragmentIndex(buffer, (*it).second, index);
+          new_args.Set(i * 2 + 1, new_index);
+        }
+      }
+      return Call(call->dtype, call->op, new_args, call->span);
+    } else if (call->op.same_as(access_ptr)) {
+      const Buffer& buffer = buffer_data_to_buffer_.at(Downcast<Var>(call->args[1]));
+      auto it = buffer_remap_.find(buffer);
+      if (it != buffer_remap_.end()) {
+        Array<PrimExpr> new_args = call->args;
+        const Buffer& new_buffer = (*it).second;
+        const PrimExpr& old_index = call->args[2];
+        PrimExpr offset;
+        if (new_buffer->strides.empty()) {
+          offset = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
+                         make_const(DataType::Int(32), 1), buffer->shape);
+        } else {
+          offset = new_buffer->strides[0];
+        }
+        PrimExpr new_index = old_index + floormod(pipeline_loop_->loop_var, 2) * offset;
+        new_args.Set(2, new_index);
+        return Call(call->dtype, call->op, new_args, call->span);
+      }
+    }
+    return std::move(call);
+  }
+
+  Map<Var, Buffer> buffer_data_to_buffer_;
+  Map<Buffer, Buffer> buffer_remap_;
+  For pipeline_loop_;
+  bool access_all_versions_;
+  std::unordered_map<const VarNode*, FragmentInfo> fragment_info_;
+};
+
+/*!
+ * \brief Rewriter for the software pipeline that rewrites a loop into a pipelined one.
+ */
+class PipelineRewriter : public StmtExprMutator {
+ public:
+  static Stmt Rewrite(
+      Map<Var, Buffer> buffer_data_to_buffer,
+      const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>& double_buffers,
+      const Array<Buffer> pipeline_allocs, const For& pipeline_loop,
+      const PipelineInfo& pipeline_info,
+      const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info) {
+    PipelineRewriter rewriter(buffer_data_to_buffer, double_buffers, pipeline_allocs,
+                              pipeline_loop, pipeline_info, fragment_info);
+    return rewriter.BuildPipeline();
+  }
+
+ private:
+  PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
+                   const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>& double_buffers,
+                   const Array<Buffer>& pipeline_allocs, const For& pipeline_loop,
+                   const PipelineInfo& pipeline_info,
+                   const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info)
+
+      : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
+        double_buffers_(double_buffers),
+        pipeline_allocs_(pipeline_allocs),
+        pipeline_loop_(pipeline_loop),
+        pipeline_info_(pipeline_info),
+        fragment_info_(fragment_info) {}
+
+  Stmt BuildPipeline() {
+    // Step 1: Analyze accesses to the buffers in the pipeline and compute the number of
+    // versions needed for each buffer.
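+    // (GetBufferAccessInfo, called from RemapPipelineBuffers, also records max_stage_.)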
+    RemapPipelineBuffers(pipeline_allocs_);
+
+    ordered_stmts_.resize(pipeline_info_.size());
+    for (const auto& pair : pipeline_info_) {
+      const Block& block = pair.first;
+      int order = pair.second.order;
+      ordered_stmts_.Set(order, block);
+    }
+
+    // Step 2: Emit the pipeline prologue, body and epilogue.
+    Stmt prologue = EmitImpl(pipeline_loop_->min, pipeline_loop_->min + max_stage_, true);
+    Stmt body = EmitImpl(pipeline_loop_->min + max_stage_,
+                         pipeline_loop_->min + pipeline_loop_->extent, false);
+    Stmt epilogue = EmitImpl(pipeline_loop_->min + pipeline_loop_->extent,
+                             pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true);
+
+    SeqStmt stmt = SeqStmt({prologue, body, epilogue});
+
+    // Step 3: Make a new block that contains new buffer allocations after pipeline rewriting.
+    Array<Buffer> alloc_buffers;
+    for (const auto& alloc : pipeline_allocs_) {
+      auto it = buffer_remap_.find(alloc);
+      if (it != buffer_remap_.end()) {
+        alloc_buffers.push_back((*it).second);
+      } else {
+        alloc_buffers.push_back(alloc);
+      }
+      buffer_data_to_buffer_.erase(alloc->data);
+    }
+    Block block = MakeBlock(stmt, buffer_data_to_buffer_);
+    auto* n = block.CopyOnWrite();
+    n->alloc_buffers = std::move(alloc_buffers);
+    return BlockRealize({}, Bool(true), block);
+  }
+
+ private:
+  /*!
+   * \brief Analyze accesses to the buffers in the software pipeline.
+   *
+   * This method checks the 'define' and 'use' stages of the buffers in the software pipeline,
+   * which can be used to compute the number of versions that need to be maintained after
+   * rewriting.
+   */
+  std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
+  GetBufferAccessInfo() {
+    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual> infos;
+    for (const auto& pair : pipeline_info_) {
+      const Block& block = pair.first;
+      int stage = pair.second.stage;
+      max_stage_ = std::max(max_stage_, stage);
+
+      for (const BufferRegion& write : block->writes) {
+        if (!infos.count(write->buffer)) {
+          infos.emplace(write->buffer, BufferAccessInfo{});
+        }
+        auto& info = infos.at(write->buffer);
+        if (info.def == -1) {
+          info.def = stage;
+        } else {
+          info.def = std::min(info.def, stage);
+        }
+      }
+
+      for (const BufferRegion& read : block->reads) {
+        if (!infos.count(read->buffer)) {
+          infos.emplace(read->buffer, BufferAccessInfo{});
+        }
+        auto& info = infos.at(read->buffer);
+        info.use = std::max(info.use, stage);
+      }
+    }
+    return infos;
+  }
+
+  /*!
+   * \brief Check whether two regions have intersections.
+   * \param region1 The first region.
+   * \param region2 The second region.
+   * \return Whether region1 and region2 have intersections.
+   */
+  bool MayConflict(Region region1, Region region2) {
+    ICHECK(region1.size() == region2.size());
+    for (size_t i = 0; i < region1.size(); i++) {
+      Range dim1 = region1[i];
+      Range dim2 = region2[i];
+      auto int_set1 = arith::IntSet::FromRange(dim1);
+      auto int_set2 = arith::IntSet::FromRange(dim2);
+      if (arith::Intersect({int_set1, int_set2}).IsNothing()) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /*!
+   * \brief Compute the number of versions that need to be maintained for a buffer accessed in
+   * the software pipeline.
+   *
+   * This method applies liveness analysis to the target buffer to compute the number of
+   * versions needed during the software pipeline.
+   * The annotation `attr::double_buffer_scope` is handled here, which provides a way to
+   * override the result of the analysis. Additional double buffering in the software pipeline
+   * can be useful to eliminate synchronizations on GPU devices.
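+   *
+   * For example, a buffer written by a stage-0 producer and read by a stage-1 consumer on an
+   * overlapping region (with the producer ordered before the consumer) gets
+   * `use - def + 1 == 2` versions, i.e. it is double-buffered; if both accesses were in the
+   * same stage, a single version would suffice.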
+   *
+   * \param buffer The target buffer
+   * \param buffer_info The access information of the target buffer.
+   * \return The number of versions required for the target buffer.
+   */
+  int ComputeBufferVersions(const Buffer& buffer, const BufferAccessInfo& buffer_info) {
+    if (buffer_info.def == -1) {
+      // Keep the original number of versions as buffers defined outside the software pipeline
+      // should not be mutated.
+      return 1;
+    }
+
+    // `use - def + 1` is an upper bound of the number of versions needed.
+    // We optimize a few cases where the number of versions can be smaller than the upper bound.
+    int num_versions = buffer_info.use - buffer_info.def + 1;
+    if (num_versions == 2) {
+      // A special case when `use - def + 1 == 2`. Double buffering is only needed in this case
+      // when there exist a writer block_i and a reader block_j such that
+      // order(block_i) < order(block_j) and stage(block_i) < stage(block_j) and the access
+      // regions of block_i and block_j overlap.
+      bool need_multi_version = false;
+      for (const auto& pair1 : pipeline_info_) {
+        const Block& writer_block = pair1.first;
+        const auto& writer_info = pair1.second;
+
+        auto it1 = std::find_if(writer_block->writes.begin(), writer_block->writes.end(),
+                                [&](const BufferRegion& buffer_region) {
+                                  return buffer_region->buffer.same_as(buffer);
+                                });
+        if (it1 == writer_block->writes.end()) {
+          continue;
+        }
+
+        for (const auto& pair2 : pipeline_info_) {
+          const Block& reader_block = pair2.first;
+          const auto& reader_info = pair2.second;
+          auto it2 = std::find_if(reader_block->reads.begin(), reader_block->reads.end(),
+                                  [&](const BufferRegion& buffer_region) {
+                                    return buffer_region->buffer.same_as(buffer);
+                                  });
+          if (it2 == reader_block->reads.end()) {
+            continue;
+          }
+          if (writer_info.order < reader_info.order && writer_info.stage < reader_info.stage &&
+              MayConflict((*it1)->region, (*it2)->region)) {
+            need_multi_version = true;
+            break;
+          }
+        }
+      }
+      if (!need_multi_version) {
+        num_versions = 1;
+      }
+    }
+    if (num_versions == 1 && double_buffers_.count(buffer)) {
+      num_versions = 2;
+    }
+    return num_versions;
+  }
+
+  /*!
+   * \brief Rewrite buffer allocations to create new buffers with new shapes according to
+   * the software pipeline.
+   * \param pipeline_allocs The buffer allocations inside the software pipeline scope.
+   */
+  void RemapPipelineBuffers(Array<Buffer> pipeline_allocs) {
+    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual> infos =
+        GetBufferAccessInfo();
+    for (const Buffer& buffer : pipeline_allocs) {
+      const BufferAccessInfo access_info = infos.at(buffer);
+      int num_versions = ComputeBufferVersions(buffer, access_info);
+      if (num_versions > 1) {
+        Buffer new_buffer = RewriteAllocBuffer(buffer, num_versions);
+        buffer_remap_.Set(buffer, new_buffer);
+      }
+    }
+  }
+
+  /*!
+   * \brief Rewrite a buffer allocation to keep multiple versions of the original buffer for
+   * pipelined accesses.
+   * \param buffer The buffer to be resized.
+   * \param num_versions The number of versions to keep.
+   * \return The resized buffer.
+   */
+  Buffer RewriteAllocBuffer(const Buffer& buffer, int num_versions) {
+    ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get()));
+    new_buffer->shape.insert(new_buffer->shape.begin(), num_versions);
+    if (new_buffer->strides.size()) {
+      ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
+      PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1];
+      new_buffer->strides.insert(new_buffer->strides.begin(), stride_0);
+    }
+    return Buffer(new_buffer);
+  }
+
+  /*!
+   * \brief Emit the pipeline loop in the given range.
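+   *
+   * Components are emitted in their pipeline order. A component at stage `s` runs with the
+   * loop var shifted by `s` and is guarded by an inbound predicate when the shifted iteration
+   * may fall outside the original loop range.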
+ * \param start The start of the range + * \param end The end of the range + * \param unroll_loop Whether the loop should be unrolled. + * \return The result loop. + */ + Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop) { + Array stmts; + PrimExpr new_loop_var; + PrimExpr extent = end - start; + if (!analyzer_.CanProve(extent > 0)) { + return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), buffer_data_to_buffer_)); + } + bool is_unit_loop = analyzer_.CanProveEqual(extent, 1); + if (is_unit_loop) { + new_loop_var = start; // use constants as the loop var for unit loops + } else { + new_loop_var = pipeline_loop_->loop_var.copy_with_suffix(""); + analyzer_.Bind(Downcast(new_loop_var), Range(start, end)); + } + + for (const Block block : ordered_stmts_) { + int stage = pipeline_info_.at(block).stage; + PrimExpr skewed_loop_var = new_loop_var - stage; + PrimExpr inbound = (skewed_loop_var >= pipeline_loop_->min) && + (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent); + inbound = analyzer_.Simplify(inbound); + if (analyzer_.CanProve(!inbound)) { + continue; + } + Block new_block = Downcast(PipelineBodyRewriter( + buffer_data_to_buffer_, buffer_remap_, pipeline_loop_, max_stage_ != 1)(block)); + Map subst_map; + if (is_unit_loop) { + subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var); + } else { + // normalize loop range + subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var + (start - pipeline_loop_->min)); + } + new_block = Downcast(Substitute(new_block, subst_map)); + stmts.push_back(BlockRealize({}, inbound, new_block)); + } + + Stmt new_loop{nullptr}; + + if (stmts.empty()) { + new_loop = Evaluate(0); + } else if (stmts.size() == 1) { + new_loop = stmts[0]; + } else { + new_loop = SeqStmt(stmts); + } + + if (!is_unit_loop) { + new_loop = For(Downcast(new_loop_var), pipeline_loop_->min, extent, + unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, std::move(new_loop)); + } + return BlockRealize({}, Bool(true), MakeBlock(std::move(new_loop), buffer_data_to_buffer_)); + ; + } + + arith::Analyzer analyzer_; + Map buffer_data_to_buffer_; + const std::unordered_set& double_buffers_; + Array pipeline_allocs_; + For pipeline_loop_; + PipelineInfo pipeline_info_; + std::unordered_map fragment_info_; + int max_stage_ = -1; + Map buffer_remap_; + Array ordered_stmts_; +}; + +class PipelineInjector : private StmtExprMutator { + public: + static Stmt Inject(const PrimFunc& func) { + PipelineInjector injector; + for (const auto& kv : func->buffer_map) { + const Buffer& buffer = kv.second; + injector.buffer_data_to_buffer_.Set(buffer->data, buffer); + } + injector.fragment_info_ = GetTensorCoreFragmentInfo(func->body); + return injector(func->body); + } + + private: + PipelineInjector() = default; + + /*! + * \brief Check the pipeline satisfies the following conditions: + * 1) No conflicting order: The order of each statement should be unique. + * 2) No reordering with the same stage: Statements in the same stage are not allowed to be + * reordered. 
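+   *    For example, stages [0, 1, 1] with orders [0, 2, 1] are rejected, because the two
+   *    stage-1 statements would be reordered relative to each other.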
+ */ + void ValidatePipelineBody(const PipelineInfo& pipeline_info, const Array& original_order) { + std::unordered_set used_orders; + std::unordered_map stage_max_order; + for (const Block& block : original_order) { + const auto& stmt_info = pipeline_info.at(block); + int stage = stmt_info.stage; + int order = stmt_info.order; + CHECK(!used_orders.count(order)) + << "ValueError: Two statements in the software pipeline cannot have the same order"; + used_orders.insert(order); + CHECK(!stage_max_order.count(stage) || stage_max_order[stage] < order) + << "ValueError: Statements in the same stage of the software pipeline must have " + "increasing order."; + stage_max_order[stage] = order; + } + } + + Stmt VisitStmt_(const ForNode* op) final { + // Step 1: Recursively rewrite the children first. + For for_node = Downcast(StmtExprMutator::VisitStmt_(op)); + bool is_pipeline = HasPipelineAnnotation(op); + if (!is_pipeline) { + return std::move(for_node); + } + // Step 2: Find the body and buffer allocations of the pipeline. The body can be direct child of + // the for-loop. If the for-loop has BlockRealize as its child, the pipeline body will be the + // child of the block. + Stmt pipeline_body{nullptr}; + Array pipeline_allocs; + if (const auto* realize = for_node->body.as()) { + const auto& block = realize->block; + for (const auto& buffer : block->alloc_buffers) { + ICHECK(buffer->IsInstance()); + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + pipeline_body = block->body; + pipeline_allocs = block->alloc_buffers; + } else { + pipeline_body = for_node->body; + } + + const SeqStmtNode* pipeline_body_seq = pipeline_body.as(); + CHECK(pipeline_body_seq) + << "ValueError: The body of the software pipeline should be SeqStmt, got " + << pipeline_body->GetTypeKey(); + // The SeqStmt before recursive rewriting, which is used for validating the software pipeline. + const SeqStmtNode* original_seq = + op->body->IsInstance() + ? op->body.as()->block->body.as() + : op->body.as(); + ICHECK(original_seq); + + // Step 3: Blockize the components of the pipeline. Each child of the pipelined loop will be + // converted into a block. + PipelineInfo pipeline_info; + Array original_order; // pipeline body blocks in the original order + + auto f_add_child = [&](const Stmt& child) { + const auto* block_realize = child.as(); + Block block = (block_realize && is_one(block_realize->predicate)) + ? 
block_realize->block + : MakeBlock(child, buffer_data_to_buffer_); + original_order.push_back(block); + }; + for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) { + const auto* nested_block_realize = pipeline_body_seq->seq[i].as(); + if (nested_block_realize && is_one(nested_block_realize->predicate) && + nested_block_realize->block->body->IsInstance()) { + const Block& nested_pipeline_block = nested_block_realize->block; + ICHECK( + nested_pipeline_block->match_buffers.empty()); // match_buffer should have been lowered + for (const auto& buffer : nested_pipeline_block->alloc_buffers) { + pipeline_allocs.push_back(buffer); + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + const auto* nested_seq = nested_pipeline_block->body.as(); + for (size_t j = 0; j < nested_seq->seq.size(); j++) { + f_add_child(nested_seq->seq[j]); + } + } else { + f_add_child(pipeline_body_seq->seq[i]); + } + } + + auto pipeline_stages = + Downcast>(op->annotations.at(attr::software_pipeline_stage)); + auto pipeline_orders = + Downcast>(op->annotations.at(attr::software_pipeline_order)); + CHECK_EQ(pipeline_stages.size(), original_order.size()); + CHECK_EQ(pipeline_orders.size(), original_order.size()); + for (size_t i = 0; i < pipeline_stages.size(); i++) { + PipelineStageOrder stage_order(pipeline_stages[i]->value, pipeline_orders[i]->value); + pipeline_info.emplace(original_order[i], stage_order); + } + ValidatePipelineBody(pipeline_info, original_order); + + // Step 4: Rewrite the pipeline body. + Stmt pipeline = + PipelineRewriter::Rewrite(buffer_data_to_buffer_, double_buffers, pipeline_allocs, + GetRef(op), pipeline_info, fragment_info_); + + if (const auto* realize = op->body.as()) { + const auto& block = realize->block; + for (const auto& buffer : block->alloc_buffers) { + buffer_data_to_buffer_.erase(buffer->data); + } + } + return pipeline; + } + + /*! + * \brief Add buffer allocations to a block and update the write region of the block. + * \param n The block pointer to which the buffer allocations are added. + * \param alloc_buffers The buffer allocations to be added. + */ + void AddAllocBuffers(BlockNode* n, const Array alloc_buffers) { + for (const Buffer& alloc_buffer : alloc_buffers) { + n->alloc_buffers.push_back(alloc_buffer); + Region region; + region.reserve(alloc_buffer->shape.size()); + for (const PrimExpr& dim : alloc_buffer->shape) { + region.push_back(Range::FromMinExtent(0, dim)); + } + n->writes.push_back(BufferRegion(alloc_buffer, region)); + } + } + + Stmt VisitStmt_(const BlockNode* op) final { + for (const auto& buffer : op->alloc_buffers) { + ICHECK(buffer->IsInstance()); + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + + auto it = op->annotations.find(attr::double_buffer_scope); + if (it != op->annotations.end()) { + int buffer_index = Downcast((*it).second); + CHECK(buffer_index >= 0 && static_cast(buffer_index) < op->writes.size()) + << "ValueError: Index of the buffer exceeds the size of the write regions of the block. (" + << buffer_index << " vs. 
" << op->writes.size() << ")"; + double_buffers.insert(op->writes[buffer_index]->buffer); + } + Block block = Downcast(StmtExprMutator::VisitStmt_(op)); + + for (const auto& buffer : op->alloc_buffers) { + buffer_data_to_buffer_.erase(buffer->data); + } + return std::move(block); + } + + bool HasPipelineAnnotation(const ForNode* op) const { + auto it1 = op->annotations.find(attr::software_pipeline_stage); + auto it2 = op->annotations.find(attr::software_pipeline_order); + bool has_stage = it1 != op->annotations.end(); + bool has_order = it2 != op->annotations.end(); + if (has_stage && has_order) { + return true; + } + if (has_stage) { + LOG(FATAL) << "ValueError: Order of the software pipeline is not defined."; + } + if (has_order) { + LOG(FATAL) << "ValueError: Stage of the software pipeline is not defined."; + } + return false; + } + + Map buffer_data_to_buffer_; + std::unordered_map fragment_info_; + std::unordered_set double_buffers; +}; + +} // namespace software_pipeline + +namespace transform { + +/*! + * \brief Transform annotated loops into pipelined one that parallelize producers and consumers. + * \return The IR transform pass. + */ +Pass InjectSoftwarePipeline() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* fptr = f.CopyOnWrite(); + fptr->body = software_pipeline::PipelineInjector::Inject(f); + fptr->body = ConvertSSA(std::move(fptr->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.InjectSoftwarePipeline", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.InjectSoftwarePipeline").set_body_typed(InjectSoftwarePipeline); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index da52a82a2f08..610270b5e7e9 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -267,6 +267,23 @@ class ConditionalBoundsContext { std::unordered_map origin_map_; }; +// Information of tensor core fragment. +struct FragmentInfo { + // fragment shape + int m, n, k; + // fragment layout (row-major or column-major) + std::string layout; + FragmentInfo() = default; + FragmentInfo(int _m, int _n, int _k, const std::string& _layout) + : m(_m), n(_n), k(_k), layout(_layout) {} +}; + +/*! + * \brief Extract information of tensor core fragment from the IR. + * \param stmt The stmt to visit. + * \return Map from buffer variables to the fragment info. 
+ */ +std::unordered_map GetTensorCoreFragmentInfo(const Stmt& stmt); } // namespace tir } // namespace tvm #endif // TVM_TIR_TRANSFORMS_IR_UTILS_H_ diff --git a/src/tir/transforms/tensorcore_infer_fragment.cc b/src/tir/transforms/tensorcore_infer_fragment.cc index 1836b8ecec0d..b14eff5973c8 100644 --- a/src/tir/transforms/tensorcore_infer_fragment.cc +++ b/src/tir/transforms/tensorcore_infer_fragment.cc @@ -39,17 +39,6 @@ namespace tir { // Get fragment information from tensor intrinsics class FragmentGetter : public StmtExprVisitor { public: - // fragment metadata - struct FragmentInfo { - // fragment shape - int m, n, k; - // fragment layout (row-major or column-major) - std::string layout; - FragmentInfo() = default; - FragmentInfo(int _m, int _n, int _k, const std::string& _layout) - : m(_m), n(_n), k(_k), layout(_layout) {} - }; - void VisitExpr_(const CallNode* op) final { StmtExprVisitor::VisitExpr_(op); @@ -126,6 +115,12 @@ class FragmentGetter : public StmtExprVisitor { std::unordered_map fragments; }; +std::unordered_map GetTensorCoreFragmentInfo(const Stmt& stmt) { + FragmentGetter getter; + getter(stmt); + return std::move(getter.fragments); +} + // Check shape of fragment making sure it is a valid shape for tvm_mma_sync class FragmentChecker : public StmtExprVisitor { public: @@ -157,8 +152,8 @@ class FragmentChecker : public StmtExprVisitor { bool CheckShape(const VarNode* buffer1, const VarNode* buffer2) { ICHECK(fragment_getter.fragments.count(buffer1)); ICHECK(fragment_getter.fragments.count(buffer2)); - FragmentGetter::FragmentInfo info1 = fragment_getter.fragments.at(buffer1); - FragmentGetter::FragmentInfo info2 = fragment_getter.fragments.at(buffer2); + FragmentInfo info1 = fragment_getter.fragments.at(buffer1); + FragmentInfo info2 = fragment_getter.fragments.at(buffer2); return info1.m == info2.m && info1.n == info2.n && info1.k == info2.k; } // Fragment infomation @@ -175,7 +170,7 @@ class InferFragmenter : public StmtMutator { const VarNode* buffer = op->buffer_var.get(); if (fragment_getter.fragments.count(buffer)) { // Add attribute to fragments allocation - FragmentGetter::FragmentInfo info = fragment_getter.fragments.at(buffer); + FragmentInfo info = fragment_getter.fragments.at(buffer); // Add shape attribute to all fragments std::string shape = diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py new file mode 100644 index 000000000000..87d1143c9e34 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -0,0 +1,754 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import pytest +import sys + +import tvm +from tvm import tir, te, TVMError +from tvm.script import tir as T + + +def _check(original, transformed): + func = original + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.InjectSoftwarePipeline()(mod) + mod = tvm.tir.transform.Simplify()(mod) + print(mod['main'].script()) + tvm.ir.assert_structural_equal(mod["main"], transformed, True) + + +def _check_error(func): + mod = tvm.IRModule.from_expr(func) + with pytest.raises(ValueError): + tvm.tir.transform.InjectSoftwarePipeline()(mod) + + +@T.prim_func +def trivial_pipeline(A: T.Buffer[(16, 1), "float32"], C: T.Buffer[(16, 1), "float32"]): + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial(0, 1, annotations={"software_pipeline_stage": [0, 1], 'software_pipeline_order': [0, 1]}): + with T.block(): + T.reads(A[tx, i]) + T.writes(C[tx, i]) + B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, i]) + T.writes(B[tx, 0]) + B[tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.reads(B[tx, 0]) + T.writes(C[tx, i]) + C[tx, i] = B[tx, 0] + T.float32(1) + + +@T.prim_func +def transformed_trivial_pipeline(A: T.Buffer[(16, 1), "float32"], C: T.Buffer[(16, 1), "float32"]) -> None: + for tx in T.thread_binding(16, thread="threadIdx.x"): + with T.block(): + T.reads(A[tx, 0]) + T.writes(C[tx, 0]) + B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, 0]) + T.writes(B[0, tx, 0]) + B[0, tx, 0] = A[tx, 0] * T.float32(2) + with T.block(): + T.reads() + T.writes() + T.evaluate(0) + with T.block(): + T.reads(B[0, tx, 0]) + T.writes(C[tx, 0]) + C[tx, 0] = B[0, tx, 0] + T.float32(1) + + +@T.prim_func +def simple_compute(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 1], 'software_pipeline_order': [0, 1]}): + with T.block(): + T.reads(A[tx, i]) + T.writes(C[tx, i]) + B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, i]) + T.writes(B[tx, 0]) + B[tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.reads(B[tx, 0]) + T.writes(C[tx, i]) + C[tx, i] = B[tx, 0] + T.float32(1) + + +@T.prim_func +def transformed_simple_compute(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None: + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + with T.block(): + T.reads([A[tx, 0:16]]) + T.writes([C[tx, 0:16]]) + B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") + with T.block(): + T.reads([A[tx, 0]]) + T.writes([B[0, tx, 0]]) + B[0, tx, 0] = A[tx, 0] * T.float32(2) + with T.block(): + T.reads([A[tx, 1:16], B[0:2, tx, 0]]) + T.writes([B[0:2, tx, 0], C[tx, 0:15]]) + for i in T.serial(0, 15): + with T.block(): + T.reads([A[tx, i + 1]]) + T.writes([B[(i + 1) % 2, tx, 0]]) + B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2) + with T.block(): + T.reads([B[i % 2, tx, 0]]) + T.writes([C[tx, i]]) + C[tx, i] = B[i % 2, tx, 0] + T.float32(1) + with T.block(): + T.reads([B[1, tx, 0]]) + T.writes([C[tx, 15]]) + C[tx, 15] = B[1, tx, 0] + T.float32(1) + + +@T.prim_func +def nested_pipeline_simple(A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]): + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial(0, 16, + annotations={"software_pipeline_stage": [0, 1, 1, 1], + "software_pipeline_order": [0, 1, 2, 3]}): + with T.block(): + T.reads(A[tx, 
i, 0:16]) + T.writes(C[tx, i, 0:16]) + A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared") + for j in T.serial(0, 16): + with T.block(): + T.reads(A[tx, i, j]) + T.writes(A_shared[tx, 0, j]) + A_shared[tx, 0, j] = A[tx, i, j] + for j in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 1], + "software_pipeline_order": [0, 1] + }, + ): + with T.block(): + T.reads(A_shared[tx, 0, j]) + T.writes(C[tx, i, j]) + B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A_shared[tx, i, j]) + T.writes(B[tx, i, 0]) + B[tx, i, 0] = A_shared[tx, 0, j] * T.float32(2) + with T.block(): + T.reads(B[tx, i, 0]) + T.writes(C[tx, i, j]) + C[tx, i, j] = B[tx, i, 0] + T.float32(1) + + +@T.prim_func +def transformed_nested_pipeline_simple(A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]) -> None: + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + with T.block(): + T.reads([A[tx, 0:16, 0:16]]) + T.writes([C[tx, 0:16, 0:16]]) + A_shared = T.alloc_buffer([2, 16, 1, 16], dtype="float32", scope="shared") + B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared") + with T.block(): + T.reads([A[tx, 0, 0:16]]) + T.writes([A_shared[0, tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, 0, j]]) + T.writes([A_shared[0, tx, 0, j]]) + A_shared[0, tx, 0, j] = A[tx, 0, j] + with T.block(): + T.reads([A[tx, 1:16, 0:16], A_shared[0:2, tx, 0:15, 0:16], B[0:2, tx, 0:15, 0]]) + T.writes([A_shared[0:2, tx, 0, 0:16], B[0:2, tx, 0:15, 0], C[tx, 0:15, 0:16]]) + for i in T.serial(0, 15): + with T.block(): + T.reads([A[tx, i + 1, 0:16]]) + T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, i + 1, j]]) + T.writes([A_shared[(i + 1) % 2, tx, 0, j]]) + A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1, j] + with T.block(): + T.reads([A_shared[i % 2, tx, i, 0]]) + T.writes([B[0, tx, i, 0]]) + B[0, tx, i, 0] = A_shared[i % 2, tx, 0, 0] * T.float32(2) + with T.block(): + T.reads([A_shared[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]]) + T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_shared[i % 2, tx, i, j + 1]]) + T.writes([B[(j + 1) % 2, tx, i, 0]]) + B[(j + 1) % 2, tx, i, 0] = A_shared[ + i % 2, tx, 0, j + 1 + ] * T.float32(2) + with T.block(): + T.reads([B[j % 2, tx, i, 0]]) + T.writes([C[tx, i, j]]) + C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([B[1, tx, i, 0]]) + T.writes([C[tx, i, 15]]) + C[tx, i, 15] = B[1, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([A_shared[1, tx, 15, 0:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:16]]) + with T.block(): + T.reads([A_shared[1, tx, 15, 0]]) + T.writes([B[0, tx, 15, 0]]) + B[0, tx, 15, 0] = A_shared[1, tx, 0, 0] * T.float32(2) + with T.block(): + T.reads([A_shared[1, tx, 15, 1:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_shared[1, tx, 15, j + 1]]) + T.writes([B[(j + 1) % 2, tx, 15, 0]]) + B[(j + 1) % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2) + with T.block(): + T.reads([B[j % 2, tx, 15, 0]]) + T.writes([C[tx, 15, j]]) + C[tx, 15, j] = B[j % 2, tx, 15, 0] + T.float32(1) + with T.block(): + T.reads([B[1, tx, 15, 0]]) + T.writes([C[tx, 15, 15]]) + C[tx, 15, 15] = B[1, tx, 15, 0] + T.float32(1) + + +@T.prim_func +def nested_pipeline_prefetch_inner(A: T.Buffer[(16, 16, 16), "float32"], C: 
T.Buffer[(16, 16, 16), "float32"]): + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 0, 1, 1], "software_pipeline_order": [0, 2, 1, 3]}): + with T.block(): + T.reads(A[tx, i, 0:16]) + T.writes(C[tx, i, 0:16]) + A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared") + for j in T.serial(0, 16): + with T.block(): + T.reads(A[tx, i, j]) + T.writes(A_shared[tx, 0, j]) + A_shared[tx, 0, j] = A[tx, i, j] + for j in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 1], + "software_pipeline_order": [0, 1], + }, + ): + with T.block(): + T.reads(A_shared[tx, 0, j]) + T.writes(C[tx, i, j]) + B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A_shared[tx, i, j]) + T.writes(B[tx, i, 0]) + B[tx, i, 0] = A_shared[tx, 0, j] * T.float32(2) + with T.block(): + T.reads(B[tx, i, 0]) + T.writes(C[tx, i, j]) + C[tx, i, j] = B[tx, i, 0] + T.float32(1) + + +@T.prim_func +def transformed_nested_pipeline_prefetch_inner(A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]) -> None: + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + with T.block(): + T.reads([A[tx, 0:16, 0:16]]) + T.writes([C[tx, 0:16, 0:16]]) + A_shared = T.alloc_buffer([2, 16, 1, 16], dtype="float32", scope="shared") + B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared") + with T.block(): + T.reads([A[tx, 0, 0:16], A_shared[0, tx, 0, 0]]) + T.writes([A_shared[0, tx, 0, 0:16], B[0, tx, 0, 0]]) + with T.block(): + T.reads([A[tx, 0, 0:16]]) + T.writes([A_shared[0, tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, 0, j]]) + T.writes([A_shared[0, tx, 0, j]]) + A_shared[0, tx, 0, j] = A[tx, 0, j] + with T.block(): + T.reads([A_shared[0, tx, 0, 0]]) + T.writes([B[0, tx, 0, 0]]) + B[0, tx, 0, 0] = A_shared[0, tx, 0, 0] * T.float32(2) + with T.block(): + T.reads([A[tx, 1:16, 0:16], A_shared[0:2, tx, 0:16, 0:16], B[0:2, tx, 0:15, 0]]) + T.writes([A_shared[0:2, tx, 0, 0:16], B[0:2, tx, 0:16, 0], C[tx, 0:15, 0:16]]) + for i in T.serial(0, 15): + with T.block(): + T.reads([A[tx, i + 1, 0:16]]) + T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, i + 1, j]]) + T.writes([A_shared[(i + 1) % 2, tx, 0, j]]) + A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1, j] + with T.block(): + T.reads([A_shared[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]]) + T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_shared[i % 2, tx, i, j + 1]]) + T.writes([B[(j + 1) % 2, tx, i, 0]]) + B[(j + 1) % 2, tx, i, 0] = A_shared[ + i % 2, tx, 0, j + 1 + ] * T.float32(2) + with T.block(): + T.reads([B[j % 2, tx, i, 0]]) + T.writes([C[tx, i, j]]) + C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([A_shared[(i + 1) % 2, tx, i + 1, 0]]) + T.writes([B[0, tx, i + 1, 0]]) + B[0, tx, i + 1, 0] = A_shared[(i + 1) % 2, tx, 0, 0] * T.float32(2) + with T.block(): + T.reads([B[1, tx, i, 0]]) + T.writes([C[tx, i, 15]]) + C[tx, i, 15] = B[1, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([A_shared[1, tx, 15, 1:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:16]]) + with T.block(): + T.reads([A_shared[1, tx, 15, 1:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_shared[1, tx, 15, j + 1]]) + T.writes([B[(j + 1) % 2, tx, 15, 0]]) + B[(j + 1) % 2, 
tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2) + with T.block(): + T.reads([B[j % 2, tx, 15, 0]]) + T.writes([C[tx, 15, j]]) + C[tx, 15, j] = B[j % 2, tx, 15, 0] + T.float32(1) + with T.block(): + T.reads([B[1, tx, 15, 0]]) + T.writes([C[tx, 15, 15]]) + C[tx, 15, 15] = B[1, tx, 15, 0] + T.float32(1) + + +@T.prim_func +def nested_pipeline_interleaving(A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]): + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 0, 0, 1, 1], "software_pipeline_order": [0, 2, 3, 1, 4]}): + with T.block(): + T.reads(A[tx, i, 0:16]) + T.writes(C[tx, i, 0:16]) + A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared") + A_local = T.alloc_buffer((1, 1, 16), dtype="float32", scope="local") + for j in T.serial(0, 16): + with T.block(): + T.reads(A[tx, i, j]) + T.writes(A_shared[tx, 0, j]) + A_shared[tx, 0, j] = A[tx, i, j] + for j in T.serial(0, 16): + with T.block(): + T.reads(A_shared[tx, 0, j]) + T.writes(A_local[0, 0, j]) + A_local[0, 0, j] = A_shared[tx, i, j] + for j in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 1], + "software_pipeline_order": [0, 1], + }, + ): + with T.block(): + T.reads(A_local[0, 0, j]) + T.writes(C[tx, i, j]) + B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A_local[tx, i, j]) + T.writes(B[tx, i, 0]) + B[tx, i, 0] = A_local[0, 0, j] * T.float32(2) + with T.block(): + T.reads(B[tx, i, 0]) + T.writes(C[tx, i, j]) + C[tx, i, j] = B[tx, i, 0] + T.float32(1) + + +@T.prim_func +def transformed_nested_pipeline_interleaving(A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]) -> None: + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + with T.block(): + T.reads([A[tx, 0:16, 0:16]]) + T.writes([C[tx, 0:16, 0:16]]) + A_shared = T.alloc_buffer([16, 1, 16], dtype="float32", scope="shared") + A_local = T.alloc_buffer([1, 1, 16], dtype="float32", scope="local") + B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared") + with T.block(): + T.reads([A[tx, 0, 0:16], A_shared[tx, 0, 0:16], A_local[tx, 0, 0]]) + T.writes([A_shared[tx, 0, 0:16], A_local[0, 0, 0:16], B[0, tx, 0, 0]]) + with T.block(): + T.reads([A[tx, 0, 0:16]]) + T.writes([A_shared[tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, 0, j]]) + T.writes([A_shared[tx, 0, j]]) + A_shared[tx, 0, j] = A[tx, 0, j] + with T.block(): + T.reads([A_shared[tx, 0, 0:16]]) + T.writes([A_local[0, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A_shared[tx, 0, j]]) + T.writes([A_local[0, 0, j]]) + A_local[0, 0, j] = A_shared[tx, 0, j] + with T.block(): + T.reads([A_local[tx, 0, 0]]) + T.writes([B[0, tx, 0, 0]]) + B[0, tx, 0, 0] = A_local[0, 0, 0] * T.float32(2) + with T.block(): + T.reads( + [ + A[tx, 1:16, 0:16], + A_local[tx, 0:16, 0:16], + B[0:2, tx, 0:15, 0], + A_shared[tx, 0, 0:16], + ] + ) + T.writes( + [ + A_shared[tx, 0, 0:16], + B[0:2, tx, 0:16, 0], + C[tx, 0:15, 0:16], + A_local[0, 0, 0:16], + ] + ) + for i in T.serial(0, 15): + with T.block(): + T.reads([A[tx, i + 1, 0:16]]) + T.writes([A_shared[tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, i + 1, j]]) + T.writes([A_shared[tx, 0, j]]) + A_shared[tx, 0, j] = A[tx, i + 1, j] + with T.block(): + T.reads([A_local[tx, i, 1:16], B[0:2, tx, i, 0]]) + T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + 
T.reads([A_local[tx, i, j + 1]]) + T.writes([B[(j + 1) % 2, tx, i, 0]]) + B[(j + 1) % 2, tx, i, 0] = A_local[0, 0, j + 1] * T.float32(2) + with T.block(): + T.reads([B[j % 2, tx, i, 0]]) + T.writes([C[tx, i, j]]) + C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([A_shared[tx, 0, 0:16]]) + T.writes([A_local[0, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A_shared[tx, 0, j]]) + T.writes([A_local[0, 0, j]]) + A_local[0, 0, j] = A_shared[tx, i + 1, j] + with T.block(): + T.reads([A_local[tx, i + 1, 0]]) + T.writes([B[0, tx, i + 1, 0]]) + B[0, tx, i + 1, 0] = A_local[0, 0, 0] * T.float32(2) + with T.block(): + T.reads([B[1, tx, i, 0]]) + T.writes([C[tx, i, 15]]) + C[tx, i, 15] = B[1, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([A_local[tx, 15, 1:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:16]]) + with T.block(): + T.reads([A_local[tx, 15, 1:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_local[tx, 15, j + 1]]) + T.writes([B[(j + 1) % 2, tx, 15, 0]]) + B[(j + 1) % 2, tx, 15, 0] = A_local[0, 0, j + 1] * T.float32(2) + with T.block(): + T.reads([B[j % 2, tx, 15, 0]]) + T.writes([C[tx, 15, j]]) + C[tx, 15, j] = B[j % 2, tx, 15, 0] + T.float32(1) + with T.block(): + T.reads([B[1, tx, 15, 0]]) + T.writes([C[tx, 15, 15]]) + C[tx, 15, 15] = B[1, tx, 15, 0] + T.float32(1) + + +@T.prim_func +def nested_pipeline_double_buffer(A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]): + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 0, 0, 1, 1], "software_pipeline_order": [0, 2, 3, 1, 4]}): + with T.block(): + T.reads(A[tx, i, 0:16]) + T.writes(C[tx, i, 0:16]) + A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared") + A_local = T.alloc_buffer((1, 1, 16), dtype="float32", scope="local") + for j in T.serial(0, 16): + with T.block(): + T.reads(A[tx, i, j]) + T.writes(A_shared[tx, 0, j]) + A_shared[tx, 0, j] = A[tx, i, j] + for j in T.serial(0, 16): + with T.block(): + T.block_attr({"double_buffer_scope": 0}) + T.reads(A_shared[tx, 0, j]) + T.writes(A_local[0, 0, j]) + A_local[0, 0, j] = A_shared[tx, i, j] + for j in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 1], + "software_pipeline_order": [0, 1], + }, + ): + with T.block(): + T.reads(A_local[0, 0, j]) + T.writes(C[tx, i, j]) + B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A_local[tx, i, j]) + T.writes(B[tx, i, 0]) + B[tx, i, 0] = A_local[0, 0, j] * T.float32(2) + with T.block(): + T.reads(B[tx, i, 0]) + T.writes(C[tx, i, j]) + C[tx, i, j] = B[tx, i, 0] + T.float32(1) + + +@T.prim_func +def transformed_nested_pipeline_double_buffer(A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]) -> None: + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + with T.block(): + T.reads([A[tx, 0:16, 0:16]]) + T.writes([C[tx, 0:16, 0:16]]) + A_shared = T.alloc_buffer([16, 1, 16], dtype="float32", scope="shared") + A_local = T.alloc_buffer([2, 1, 1, 16], dtype="float32", scope="local") + B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared") + with T.block(): + T.reads([A[tx, 0, 0:16], A_shared[tx, 0, 0:16], A_local[0, tx, 0, 0]]) + T.writes([A_shared[tx, 0, 0:16], A_local[0, 0, 0, 0:16], B[0, tx, 0, 0]]) + with T.block(): + T.reads([A[tx, 0, 0:16]]) + T.writes([A_shared[tx, 0, 
0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, 0, j]]) + T.writes([A_shared[tx, 0, j]]) + A_shared[tx, 0, j] = A[tx, 0, j] + with T.block(): + T.reads([A_shared[tx, 0, 0:16]]) + T.writes([A_local[0, 0, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A_shared[tx, 0, j]]) + T.writes([A_local[0, 0, 0, j]]) + T.block_attr({"double_buffer_scope": 0}) + A_local[0, 0, 0, j] = A_shared[tx, 0, j] + with T.block(): + T.reads([A_local[0, tx, 0, 0]]) + T.writes([B[0, tx, 0, 0]]) + B[0, tx, 0, 0] = A_local[0, 0, 0, 0] * T.float32(2) + with T.block(): + T.reads( + [ + A[tx, 1:16, 0:16], + A_local[0:2, tx, 0:16, 0:16], + B[0:2, tx, 0:15, 0], + A_shared[tx, 0, 0:16], + ] + ) + T.writes( + [ + A_shared[tx, 0, 0:16], + B[0:2, tx, 0:16, 0], + C[tx, 0:15, 0:16], + A_local[0:2, 0, 0, 0:16], + ] + ) + for i in T.serial(0, 15): + with T.block(): + T.reads([A[tx, i + 1, 0:16]]) + T.writes([A_shared[tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, i + 1, j]]) + T.writes([A_shared[tx, 0, j]]) + A_shared[tx, 0, j] = A[tx, i + 1, j] + with T.block(): + T.reads([A_local[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]]) + T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_local[i % 2, tx, i, j + 1]]) + T.writes([B[(j + 1) % 2, tx, i, 0]]) + B[(j + 1) % 2, tx, i, 0] = A_local[i % 2, 0, 0, j + 1] * T.float32( + 2 + ) + with T.block(): + T.reads([B[j % 2, tx, i, 0]]) + T.writes([C[tx, i, j]]) + C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([A_shared[tx, 0, 0:16]]) + T.writes([A_local[(i + 1) % 2, 0, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A_shared[tx, 0, j]]) + T.writes([A_local[(i + 1) % 2, 0, 0, j]]) + T.block_attr({"double_buffer_scope": 0}) + A_local[(i + 1) % 2, 0, 0, j] = A_shared[tx, i + 1, j] + with T.block(): + T.reads([A_local[(i + 1) % 2, tx, i + 1, 0]]) + T.writes([B[0, tx, i + 1, 0]]) + B[0, tx, i + 1, 0] = A_local[(i + 1) % 2, 0, 0, 0] * T.float32(2) + with T.block(): + T.reads([B[1, tx, i, 0]]) + T.writes([C[tx, i, 15]]) + C[tx, i, 15] = B[1, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([A_local[1, tx, 15, 1:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:16]]) + with T.block(): + T.reads([A_local[1, tx, 15, 1:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_local[1, tx, 15, j + 1]]) + T.writes([B[(j + 1) % 2, tx, 15, 0]]) + B[(j + 1) % 2, tx, 15, 0] = A_local[1, 0, 0, j + 1] * T.float32(2) + with T.block(): + T.reads([B[j % 2, tx, 15, 0]]) + T.writes([C[tx, 15, j]]) + C[tx, 15, j] = B[j % 2, tx, 15, 0] + T.float32(1) + with T.block(): + T.reads([B[1, tx, 15, 0]]) + T.writes([C[tx, 15, 15]]) + C[tx, 15, 15] = B[1, tx, 15, 0] + T.float32(1) + + +@T.prim_func +def simple_compute_incorrect_reorder(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]): + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 1, 1], "software_pipeline_order": [0, 2, 1]}): + with T.block(): + T.reads(A[tx, i]) + T.writes(D[tx, i]) + B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + C = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, i]) + T.writes(B[tx, 0]) + B[tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.reads(B[tx, 0]) + T.writes(C[tx, 0]) + C[tx, 0] = B[tx, 0] + T.float32(2) + with T.block(): + 
T.reads(C[tx, 0]) + T.writes(D[tx, i]) + D[tx, i] = C[tx, 0] + T.float32(1) + + +@T.prim_func +def simple_compute_conflicting_order(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]): + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 1, 1], "software_pipeline_order": [ 0, 1, 1]}): + with T.block(): + T.reads(A[tx, i]) + T.writes(D[tx, i]) + B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + C = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, i]) + T.writes(B[tx, 0]) + B[tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.reads(B[tx, 0]) + T.writes(C[tx, 0]) + C[tx, 0] = B[tx, 0] + T.float32(2) + with T.block(): + T.reads(C[tx, 0]) + T.writes(D[tx, i]) + D[tx, i] = C[tx, 0] + T.float32(1) + + +@T.prim_func +def simple_compute_missing_annotation(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 1]}): + with T.block(): + T.reads(A[tx, i]) + T.writes(C[tx, i]) + B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, i]) + T.writes(B[tx, 0]) + B[tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.reads(B[tx, 0]) + T.writes(C[tx, i]) + C[tx, i] = B[tx, 0] + T.float32(1) + + + +def test_simple_compute(): + _check(simple_compute, transformed_simple_compute) + + +def test_trivial_pipeline(): + _check(trivial_pipeline, transformed_trivial_pipeline) + + +def test_nest_pipeline_simple(): + _check(nested_pipeline_simple, transformed_nested_pipeline_simple) + + +def test_nest_pipeline_prefetch_inner(): + _check(nested_pipeline_prefetch_inner, transformed_nested_pipeline_prefetch_inner) + + +def test_nest_pipeline_interleaving(): + _check(nested_pipeline_interleaving, transformed_nested_pipeline_interleaving) + + +def test_nest_pipeline_double_buffer(): + _check(nested_pipeline_double_buffer, transformed_nested_pipeline_double_buffer) + + +def test_error_reorder(): + _check_error(simple_compute_incorrect_reorder) + + +def test_error_conflicting_order(): + _check_error(simple_compute_conflicting_order) + + +def test_error_missing_annotation(): + _check_error(simple_compute_missing_annotation) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) From 0355e49ab94fc6b6920b96d319124426f742151f Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 25 Jan 2022 18:45:08 -0500 Subject: [PATCH 02/18] fix --- include/tvm/tir/transform.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index bcb1804ba551..a01259920629 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -523,6 +523,7 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner(); * * \return The IR transform pass. 
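+ *
+ * Usage sketch (hedged; this mirrors how the unit tests in this series drive
+ * the pass, assuming `mod` is an IRModule whose loops carry the annotations
+ * described above):
+ *
+ * \code{.py}
+ * import tvm
+ *
+ * mod = tvm.tir.transform.InjectSoftwarePipeline()(mod)
+ * mod = tvm.tir.transform.Simplify()(mod)
+ * \endcode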
*/ +TVM_DLL Pass InjectSoftwarePipeline(); } // namespace transform } // namespace tir From 8b062153ca283f667fdf6c129841730d17b1cacd Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 25 Jan 2022 18:46:35 -0500 Subject: [PATCH 03/18] fix --- src/driver/driver_api.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index cdc8902e393d..24b2bd3eb1c1 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -247,7 +247,6 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::UnifyThreadBinding()); pass_list.push_back(tir::transform::CompactBufferAllocation()); pass_list.push_back(tir::transform::LowerMatchBuffer()); - pass_list.push_back(transform::PrintIR()); pass_list.push_back(tir::transform::InjectSoftwarePipeline()); pass_list.push_back(tir::transform::FlattenBuffer()); pass_list.push_back(tir::transform::BF16Legalize()); From d17e94c8b1f0854856ced4ab2e1099ae33a85719 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 25 Jan 2022 18:50:43 -0500 Subject: [PATCH 04/18] lint --- src/tir/transforms/ir_utils.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index 610270b5e7e9..61f2536df90a 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -284,6 +284,7 @@ struct FragmentInfo { * \return Map from buffer variables to the fragment info. */ std::unordered_map GetTensorCoreFragmentInfo(const Stmt& stmt); + } // namespace tir } // namespace tvm #endif // TVM_TIR_TRANSFORMS_IR_UTILS_H_ From 5c497a8c19622fc2e113474bbaf2e9edc1c000e3 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 25 Jan 2022 20:35:51 -0500 Subject: [PATCH 05/18] fix --- .../transforms/inject_software_pipeline.cc | 31 ++++++++++++------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index 4d19128c220f..94a4b2b66f4d 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -90,14 +90,17 @@ class PipelineBodyRewriter : public StmtExprMutator { * \param access_all_versions Whether all versions the the buffers in the software pipeline are * accessed. This will be used to update block access region. In the prologue and epilogue * of a two-stage software pipeline, only one version of these buffers are accessed. 
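+ *        (In this patch series the call site passes `max_stage_ != 1` for this
+ *        flag, so only a two-stage pipeline restricts its prologue and
+ *        epilogue to a single buffer version.)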
+ * \param fragment_info Information about tensor core fragment */ PipelineBodyRewriter(const Map& buffer_data_to_buffer, const Map& buffer_remap, For pipeline_loop, - bool access_all_versions) + bool access_all_versions, + const std::unordered_map& fragment_info) : buffer_data_to_buffer_(buffer_data_to_buffer), buffer_remap_(buffer_remap), pipeline_loop_(pipeline_loop), - access_all_versions_(access_all_versions) {} + access_all_versions_(access_all_versions), + fragment_info_(fragment_info) {} private: BufferRegion RewritePipelineBufferRegion(const BufferRegion& buffer_region) const { @@ -166,7 +169,9 @@ class PipelineBodyRewriter : public StmtExprMutator { } int GetWmmaFragmentSize(const Buffer& buffer) { - const FragmentInfo& info = fragment_info_.at(buffer->data.get()); + auto it = fragment_info_.find(buffer->data.get()); + ICHECK(it != fragment_info_.end()); + const FragmentInfo& info = (*it).second; String scope = buffer.scope(); if (scope == "wmma.matrix_a") { return info.m * info.k; @@ -250,7 +255,7 @@ class PipelineBodyRewriter : public StmtExprMutator { Map buffer_remap_; For pipeline_loop_; bool access_all_versions_; - std::unordered_map fragment_info_; + const std::unordered_map& fragment_info_; }; /*! @@ -495,8 +500,11 @@ class PipelineRewriter : public StmtExprMutator { Array stmts; PrimExpr new_loop_var; PrimExpr extent = end - start; + + auto make_nop = []() { return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {})); }; + if (!analyzer_.CanProve(extent > 0)) { - return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), buffer_data_to_buffer_)); + return make_nop(); } bool is_unit_loop = analyzer_.CanProveEqual(extent, 1); if (is_unit_loop) { @@ -515,8 +523,9 @@ class PipelineRewriter : public StmtExprMutator { if (analyzer_.CanProve(!inbound)) { continue; } - Block new_block = Downcast(PipelineBodyRewriter( - buffer_data_to_buffer_, buffer_remap_, pipeline_loop_, max_stage_ != 1)(block)); + Block new_block = Downcast(PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_, + pipeline_loop_, max_stage_ != 1, + fragment_info_)(block)); Map subst_map; if (is_unit_loop) { subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var); @@ -531,8 +540,9 @@ class PipelineRewriter : public StmtExprMutator { Stmt new_loop{nullptr}; if (stmts.empty()) { - new_loop = Evaluate(0); - } else if (stmts.size() == 1) { + return make_nop(); + } + if (stmts.size() == 1) { new_loop = stmts[0]; } else { new_loop = SeqStmt(stmts); @@ -543,7 +553,6 @@ class PipelineRewriter : public StmtExprMutator { unroll_loop ? 
ForKind::kUnrolled : pipeline_loop_->kind, std::move(new_loop)); } return BlockRealize({}, Bool(true), MakeBlock(std::move(new_loop), buffer_data_to_buffer_)); - ; } arith::Analyzer analyzer_; @@ -552,7 +561,7 @@ class PipelineRewriter : public StmtExprMutator { Array pipeline_allocs_; For pipeline_loop_; PipelineInfo pipeline_info_; - std::unordered_map fragment_info_; + const std::unordered_map& fragment_info_; int max_stage_ = -1; Map buffer_remap_; Array ordered_stmts_; From 46f4df3807d2e7ef637960b228a997ca4b103773 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 26 Jan 2022 13:17:47 -0500 Subject: [PATCH 06/18] format --- ..._tir_transform_inject_software_pipeline.py | 123 ++++++++++++++---- 1 file changed, 97 insertions(+), 26 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index 87d1143c9e34..a71ee8a00f28 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -27,7 +27,7 @@ def _check(original, transformed): mod = tvm.IRModule.from_expr(func) mod = tvm.tir.transform.InjectSoftwarePipeline()(mod) mod = tvm.tir.transform.Simplify()(mod) - print(mod['main'].script()) + print(mod["main"].script()) tvm.ir.assert_structural_equal(mod["main"], transformed, True) @@ -40,7 +40,9 @@ def _check_error(func): @T.prim_func def trivial_pipeline(A: T.Buffer[(16, 1), "float32"], C: T.Buffer[(16, 1), "float32"]): for tx in T.thread_binding(0, 16, thread="threadIdx.x"): - for i in T.serial(0, 1, annotations={"software_pipeline_stage": [0, 1], 'software_pipeline_order': [0, 1]}): + for i in T.serial( + 0, 1, annotations={"software_pipeline_stage": [0, 1], "software_pipeline_order": [0, 1]} + ): with T.block(): T.reads(A[tx, i]) T.writes(C[tx, i]) @@ -56,7 +58,9 @@ def trivial_pipeline(A: T.Buffer[(16, 1), "float32"], C: T.Buffer[(16, 1), "floa @T.prim_func -def transformed_trivial_pipeline(A: T.Buffer[(16, 1), "float32"], C: T.Buffer[(16, 1), "float32"]) -> None: +def transformed_trivial_pipeline( + A: T.Buffer[(16, 1), "float32"], C: T.Buffer[(16, 1), "float32"] +) -> None: for tx in T.thread_binding(16, thread="threadIdx.x"): with T.block(): T.reads(A[tx, 0]) @@ -79,7 +83,11 @@ def transformed_trivial_pipeline(A: T.Buffer[(16, 1), "float32"], C: T.Buffer[(1 @T.prim_func def simple_compute(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): for tx in T.thread_binding(0, 16, thread="threadIdx.x"): - for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 1], 'software_pipeline_order': [0, 1]}): + for i in T.serial( + 0, + 16, + annotations={"software_pipeline_stage": [0, 1], "software_pipeline_order": [0, 1]}, + ): with T.block(): T.reads(A[tx, i]) T.writes(C[tx, i]) @@ -95,7 +103,9 @@ def simple_compute(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "floa @T.prim_func -def transformed_simple_compute(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None: +def transformed_simple_compute( + A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"] +) -> None: for tx in T.thread_binding(0, 16, thread="threadIdx.x"): with T.block(): T.reads([A[tx, 0:16]]) @@ -124,11 +134,18 @@ def transformed_simple_compute(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16 @T.prim_func -def nested_pipeline_simple(A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]): +def nested_pipeline_simple( + A: 
T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"] +): for tx in T.thread_binding(0, 16, thread="threadIdx.x"): - for i in T.serial(0, 16, - annotations={"software_pipeline_stage": [0, 1, 1, 1], - "software_pipeline_order": [0, 1, 2, 3]}): + for i in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 1, 1, 1], + "software_pipeline_order": [0, 1, 2, 3], + }, + ): with T.block(): T.reads(A[tx, i, 0:16]) T.writes(C[tx, i, 0:16]) @@ -143,7 +160,7 @@ def nested_pipeline_simple(A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16 16, annotations={ "software_pipeline_stage": [0, 1], - "software_pipeline_order": [0, 1] + "software_pipeline_order": [0, 1], }, ): with T.block(): @@ -161,7 +178,9 @@ def nested_pipeline_simple(A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16 @T.prim_func -def transformed_nested_pipeline_simple(A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]) -> None: +def transformed_nested_pipeline_simple( + A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"] +) -> None: for tx in T.thread_binding(0, 16, thread="threadIdx.x"): with T.block(): T.reads([A[tx, 0:16, 0:16]]) @@ -236,9 +255,18 @@ def transformed_nested_pipeline_simple(A: T.Buffer[(16, 16, 16), "float32"], C: @T.prim_func -def nested_pipeline_prefetch_inner(A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]): +def nested_pipeline_prefetch_inner( + A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"] +): for tx in T.thread_binding(0, 16, thread="threadIdx.x"): - for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 0, 1, 1], "software_pipeline_order": [0, 2, 1, 3]}): + for i in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 0, 1, 1], + "software_pipeline_order": [0, 2, 1, 3], + }, + ): with T.block(): T.reads(A[tx, i, 0:16]) T.writes(C[tx, i, 0:16]) @@ -271,7 +299,9 @@ def nested_pipeline_prefetch_inner(A: T.Buffer[(16, 16, 16), "float32"], C: T.Bu @T.prim_func -def transformed_nested_pipeline_prefetch_inner(A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]) -> None: +def transformed_nested_pipeline_prefetch_inner( + A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"] +) -> None: for tx in T.thread_binding(0, 16, thread="threadIdx.x"): with T.block(): T.reads([A[tx, 0:16, 0:16]]) @@ -349,9 +379,18 @@ def transformed_nested_pipeline_prefetch_inner(A: T.Buffer[(16, 16, 16), "float3 @T.prim_func -def nested_pipeline_interleaving(A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]): +def nested_pipeline_interleaving( + A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"] +): for tx in T.thread_binding(0, 16, thread="threadIdx.x"): - for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 0, 0, 1, 1], "software_pipeline_order": [0, 2, 3, 1, 4]}): + for i in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 0, 0, 1, 1], + "software_pipeline_order": [0, 2, 3, 1, 4], + }, + ): with T.block(): T.reads(A[tx, i, 0:16]) T.writes(C[tx, i, 0:16]) @@ -390,7 +429,9 @@ def nested_pipeline_interleaving(A: T.Buffer[(16, 16, 16), "float32"], C: T.Buff @T.prim_func -def transformed_nested_pipeline_interleaving(A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]) -> None: +def transformed_nested_pipeline_interleaving( + A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"] +) -> None: for tx in 
T.thread_binding(0, 16, thread="threadIdx.x"): with T.block(): T.reads([A[tx, 0:16, 0:16]]) @@ -497,9 +538,18 @@ def transformed_nested_pipeline_interleaving(A: T.Buffer[(16, 16, 16), "float32" @T.prim_func -def nested_pipeline_double_buffer(A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]): +def nested_pipeline_double_buffer( + A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"] +): for tx in T.thread_binding(0, 16, thread="threadIdx.x"): - for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 0, 0, 1, 1], "software_pipeline_order": [0, 2, 3, 1, 4]}): + for i in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 0, 0, 1, 1], + "software_pipeline_order": [0, 2, 3, 1, 4], + }, + ): with T.block(): T.reads(A[tx, i, 0:16]) T.writes(C[tx, i, 0:16]) @@ -539,7 +589,9 @@ def nested_pipeline_double_buffer(A: T.Buffer[(16, 16, 16), "float32"], C: T.Buf @T.prim_func -def transformed_nested_pipeline_double_buffer(A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]) -> None: +def transformed_nested_pipeline_double_buffer( + A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"] +) -> None: for tx in T.thread_binding(0, 16, thread="threadIdx.x"): with T.block(): T.reads([A[tx, 0:16, 0:16]]) @@ -650,9 +702,18 @@ def transformed_nested_pipeline_double_buffer(A: T.Buffer[(16, 16, 16), "float32 @T.prim_func -def simple_compute_incorrect_reorder(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]): +def simple_compute_incorrect_reorder( + A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"] +): for tx in T.thread_binding(0, 16, thread="threadIdx.x"): - for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 1, 1], "software_pipeline_order": [0, 2, 1]}): + for i in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 1, 1], + "software_pipeline_order": [0, 2, 1], + }, + ): with T.block(): T.reads(A[tx, i]) T.writes(D[tx, i]) @@ -673,9 +734,18 @@ def simple_compute_incorrect_reorder(A: T.Buffer[(16, 16), "float32"], D: T.Buff @T.prim_func -def simple_compute_conflicting_order(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]): +def simple_compute_conflicting_order( + A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"] +): for tx in T.thread_binding(0, 16, thread="threadIdx.x"): - for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 1, 1], "software_pipeline_order": [ 0, 1, 1]}): + for i in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 1, 1], + "software_pipeline_order": [0, 1, 1], + }, + ): with T.block(): T.reads(A[tx, i]) T.writes(D[tx, i]) @@ -696,7 +766,9 @@ def simple_compute_conflicting_order(A: T.Buffer[(16, 16), "float32"], D: T.Buff @T.prim_func -def simple_compute_missing_annotation(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): +def simple_compute_missing_annotation( + A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"] +): for tx in T.thread_binding(0, 16, thread="threadIdx.x"): for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 1]}): with T.block(): @@ -713,7 +785,6 @@ def simple_compute_missing_annotation(A: T.Buffer[(16, 16), "float32"], C: T.Buf C[tx, i] = B[tx, 0] + T.float32(1) - def test_simple_compute(): _check(simple_compute, transformed_simple_compute) From dc5e81ec01ea61a84f9bb5632200a0e42da9d530 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 26 Jan 2022 13:41:07 -0500 
Subject: [PATCH 07/18] doc --- include/tvm/tir/transform.h | 79 ++++++++++++++++++++++++++++++++++--- 1 file changed, 74 insertions(+), 5 deletions(-) diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index a01259920629..8e017ba8a27f 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -511,15 +511,84 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner(); * executed in the original order. attr::software_pipeline_order defines the order of the each * component. Components belong to different stages can be reordered. * + * Nested software pipelines are allowed. In this case, the inner software pipeline will be + * generated first. As a result, this may affect the number of components, i.e. the number of the + * direct children of the outer loop. In this case, the annotations for the outer software + * pipeline should include the result of the inner software pipeline, which is three blocks as + * discussed above. + * * Buffer allocated inside the software pipeline may be resized to accommodate multiple versions * of the original buffer. Block annotation attr::double_buffer_scope can be used to indicate that * the block need to write in the double-buffering style. * - * Annotations: - * attr::software_pipeline_stage: Array of non-negative integers, each element should be in range - * [0, max_stage], where max_stage is the maximum (inclusive) stage. - * attr::software_pipeline_order: Array of non-negative integers, should be a permutation of - * [0, 1, ..., num_components - 1]. + * The following annotations are used to specify the behavior of this pass: + * attr::software_pipeline_stage: Array of non-negative integers, each element should be in + * range [0, max_stage], where max_stage is the maximum + * (inclusive) stage. + * attr::software_pipeline_order: Array of non-negative integers, should be a permutation of + * [0, 1, ..., num_components - 1]. + * attr::double_buffer_scope: Integer index of the write regions of the block. Mark a buffer + * should be double-buffered during the software pipelining. + * + * Example: + * + * Before this pass, the TIR is: + * + * @T.prim_func + * def before_transform(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None: + * for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + * for i in T.serial(0, 16, + * annotations={"software_pipeline_stage": [0, 1], + * "software_pipeline_order": [0, 1]} + * ): + * with T.block(): + * T.reads(A[tx, i]) + * T.writes(C[tx, i]) + * B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + * with T.block("B"): + * T.reads(A[tx, i]) + * T.writes(B[tx, 0]) + * B[tx, 0] = A[tx, i] * T.float32(2) + * with T.block("C"): + * T.reads(B[tx, 0]) + * T.writes(C[tx, i]) + * C[tx, i] = B[tx, 0] + T.float32(1) + * + * The TIR above annotate the loop as a two-stage pipeline, the components are not reordered. 
+ * After this pass, the TIR is: + * + * @T.prim_func + * def after_transform(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None: + * for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + * with T.block(): + * T.reads([A[tx, 0:16]]) + * T.writes([C[tx, 0:16]]) + * B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") + * with T.block("prologue"): + * T.reads([A[tx, 0]]) + * T.writes([B[0, tx, 0]]) + * B[0, tx, 0] = A[tx, 0] * T.float32(2) + * with T.block("body"): + * T.reads([A[tx, 1:16], B[0:2, tx, 0]]) + * T.writes([B[0:2, tx, 0], C[tx, 0:15]]) + * for i in T.serial(0, 15): + * with T.block("B"): + * T.reads([A[tx, i + 1]]) + * T.writes([B[(i + 1) % 2, tx, 0]]) + * B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2) + * with T.block("C"): + * T.reads([B[i % 2, tx, 0]]) + * T.writes([C[tx, i]]) + * C[tx, i] = B[i % 2, tx, 0] + T.float32(1) + * with T.block("epilogue"): + * T.reads([B[1, tx, 0]]) + * T.writes([C[tx, 15]]) + * C[tx, 15] = B[1, tx, 0] + T.float32(1) + * + * The original loop has two blocks, B and C, as its direct children. The loop annotations indicate + * that block B has stage == 0, order == 0, block C has stage == 1, order == 1. Therefore, block B + * should be executed in advance of block C by one iteration. The order 0 and 1 specifies the order + * of block B and C inside the body block inside the result TIR. * * \return The IR transform pass. */ From 0e309870a50ad1794f2cc8aa0714e5fdbc5f1c95 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 26 Jan 2022 13:42:51 -0500 Subject: [PATCH 08/18] remove print --- .../unittest/test_tir_transform_inject_software_pipeline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index a71ee8a00f28..1432be4efbe1 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -27,7 +27,6 @@ def _check(original, transformed): mod = tvm.IRModule.from_expr(func) mod = tvm.tir.transform.InjectSoftwarePipeline()(mod) mod = tvm.tir.transform.Simplify()(mod) - print(mod["main"].script()) tvm.ir.assert_structural_equal(mod["main"], transformed, True) From 0390b776e0bd8900c9a953bea37b3ef7fbee3504 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 26 Jan 2022 14:11:22 -0500 Subject: [PATCH 09/18] lint --- src/tir/transforms/inject_software_pipeline.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index 94a4b2b66f4d..e62ee8ab288a 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -72,7 +72,7 @@ using PipelineInfo = std::unordered_map Date: Wed, 26 Jan 2022 14:15:41 -0500 Subject: [PATCH 10/18] lint --- src/tir/transforms/inject_software_pipeline.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index e62ee8ab288a..7575fc2c22dc 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -72,7 +72,7 @@ using PipelineInfo = std::unordered_map Date: Wed, 26 Jan 2022 14:28:17 -0500 Subject: [PATCH 11/18] doc --- include/tvm/tir/transform.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/include/tvm/tir/transform.h 
b/include/tvm/tir/transform.h index 8e017ba8a27f..1011081b22d1 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -534,6 +534,7 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner(); * * Before this pass, the TIR is: * + * \code{.py} * @T.prim_func * def before_transform(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None: * for tx in T.thread_binding(0, 16, thread="threadIdx.x"): @@ -553,10 +554,12 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner(); * T.reads(B[tx, 0]) * T.writes(C[tx, i]) * C[tx, i] = B[tx, 0] + T.float32(1) + * \endcode * * The TIR above annotate the loop as a two-stage pipeline, the components are not reordered. * After this pass, the TIR is: * + * \code{.py} * @T.prim_func * def after_transform(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None: * for tx in T.thread_binding(0, 16, thread="threadIdx.x"): @@ -584,6 +587,7 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner(); * T.reads([B[1, tx, 0]]) * T.writes([C[tx, 15]]) * C[tx, 15] = B[1, tx, 0] + T.float32(1) + * \endcode * * The original loop has two blocks, B and C, as its direct children. The loop annotations indicate * that block B has stage == 0, order == 0, block C has stage == 1, order == 1. Therefore, block B From 5537afdaff7cc57a87a8e8fc95fca6170fc854dc Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 8 Feb 2022 10:48:41 -0800 Subject: [PATCH 12/18] Apply suggestions from code review Co-authored-by: Junru Shao --- include/tvm/tir/transform.h | 71 +++++++++---------- .../transforms/inject_software_pipeline.cc | 41 ++++++----- 2 files changed, 53 insertions(+), 59 deletions(-) diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 1011081b22d1..12af0f7c275c 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -493,43 +493,38 @@ TVM_DLL Pass ConvertForLoopsToSerial(); TVM_DLL Pass UnifiedStaticMemoryPlanner(); /*! - * \brief Transform annotated loops into pipelined one that ovarlaps producers and consumers. - * - * This pass detects loops with the software pipeline annotations and rewrite them to pipelined - * ones. The behavior of such rewriting depending on two annotations on the loop, - * attr::software_pipeline_stage, and attr::software_pipeline_order, which defines the stage and the - * order, respectively, of the components of the software pipeline. The components of the software - * pipeline is the direct children (ignoring BlockRealize / Block / SeqStmt) of the annotated loop. - * The value of the both annotations should be array of integers, with its size the same as the - * number of the components. - * - * The result of the rewriting is a block that has three blocks as its direct children which - * represents the prologue, the body, and the epilogue of the software pipeline. In the prologue, - * only components whose stage is less than max_stage will be executed. In the epilogue, only - * components whose stage is greater than 0 will be executed. In the body, all the components will - * be executed. Such rewriting enables behavior like prefetching, the components are not necessarily - * executed in the original order. attr::software_pipeline_order defines the order of the each - * component. Components belong to different stages can be reordered. - * - * Nested software pipelines are allowed. In this case, the inner software pipeline will be - * generated first. As a result, this may affect the number of components, i.e. the number of the - * direct children of the outer loop. 
In this case, the annotations for the outer software
- * pipeline should include the result of the inner software pipeline, which is three blocks as
- * discussed above.
- *
- * Buffer allocated inside the software pipeline may be resized to accommodate multiple versions
- * of the original buffer. Block annotation attr::double_buffer_scope can be used to indicate that
- * the block need to write in the double-buffering style.
- *
- * The following annotations are used to specify the behavior of this pass:
- *   attr::software_pipeline_stage: Array of non-negative integers, each element should be in
- *                                  range [0, max_stage], where max_stage is the maximum
- *                                  (inclusive) stage.
- *   attr::software_pipeline_order: Array of non-negative integers, should be a permutation of
- *                                  [0, 1, ..., num_components - 1].
- *   attr::double_buffer_scope: Integer index of the write regions of the block. Mark a buffer
- *                              should be double-buffered during the software pipelining.
+ * \brief This pass transforms annotated loops into pipelined ones where producers and consumers
+ * are overlapped with the information provided in loop annotations, which enables optimization
+ * techniques like prefetching and pipeline parallelism.
  *
+ * The pipeline scope consists of the direct children of the annotated loop (ignoring BlockRealize,
+ * Block, SeqStmt), and the number of children is denoted by `n` in the documentation.
+ *
+ * The following annotations are used to guide the loop transformation:
+ *
+ * 1) Loop annotation `software_pipeline_stage` defines the pipeline stage.
+ *    An array of `n` integers, where each element should be in range [0, max_stage],
+ *    and max_stage is the maximum (inclusive) stage.
+ * 2) Loop annotation `software_pipeline_order` defines the pipeline order.
+ *    An array of `n` integers, a permutation of [0, 1, ..., n - 1].
+ * 3) Block annotation `double_buffer_scope` controls certain buffer sizes to allow decoupling of
+ *    read/write dependency. It is an integer index into the write regions of the block, marking
+ *    the buffer that should be double-buffered during software pipelining.
+ *
+ * Every annotated loop is transformed into a loop with three blocks as its direct children:
+ *
+ * 1) Prologue block, where components whose stage is less than `max_stage` are executed;
+ *
+ * 2) Body block, where all the components are executed;
+ *
+ * 3) Epilogue block, where only components whose stage is greater than 0 are executed.
+ *
+ * The execution order of the components is controlled by the annotation
+ * `software_pipeline_order`, and thus may differ from the original order.
+ *
+ * Note: For nested software pipelines, the inner software pipeline is generated first,
+ * which may affect the number of the direct children of the outer loop.
+ * In this case, the annotations for the outer software pipeline should account for the result
+ * of the inner software pipeline, i.e. the three blocks discussed above.
+ *
  * Example:
  *
  * Before this pass, the TIR is:
@@ -556,8 +551,8 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner();
  *             C[tx, i] = B[tx, 0] + T.float32(1)
  * \endcode
  *
- * The TIR above annotate the loop as a two-stage pipeline, the components are not reordered.
- * After this pass, the TIR is:
+ * The TIR above annotates the loop as a two-stage pipeline with no reordering.
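+ *
+ * (A hedged reading of this example: max_stage is 1, so the prologue peels the
+ * first iteration of the stage-0 block B, the steady-state body computes B for
+ * iteration i + 1 while C consumes iteration i, and the epilogue drains the
+ * final iteration of the stage-1 block C.)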
+ * After applying this pass, the TIR is transformed into: * * \code{.py} * @T.prim_func diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index 7575fc2c22dc..4a367b298d0b 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -52,9 +52,9 @@ Block MakeBlock(const Stmt& body, const Map& buffer_data_to_buffer) return block_realize->block; } } - Block block = Block({}, {}, {}, "", body); - auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer); - auto* n = block.CopyOnWrite(); + Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/body); + Array> access = GetBlockReadWriteRegion(block, buffer_data_to_buffer); + BlockNode* n = block.CopyOnWrite(); n->reads = access[0]; n->writes = access[1]; return block; @@ -85,7 +85,7 @@ class PipelineBodyRewriter : public StmtExprMutator { * \brief Constructor of PipelineBodyRewriter. * \param buffer_data_to_buffer The map from buffer data to buffer. * \param buffer_remap The map from original buffer to the buffer with updated shape for - * multi-versioning in the sofeware pipeline. + * multi-versioning in the software pipeline. * \param pipeline_loop The original loop to be software pipelined. * \param access_all_versions Whether all versions the the buffers in the software pipeline are * accessed. This will be used to update block access region. In the prologue and epilogue @@ -291,7 +291,14 @@ class PipelineRewriter : public StmtExprMutator { Stmt BuildPipeline() { // Step 1: Analyze accesses to the buffers in the pipeline and compute the number of versions // need to maintain for each buffer. - RemapPipelineBuffers(pipeline_allocs_); + std::unordered_map infos = + GetBufferAccessInfo(); + for (const Buffer& buffer : pipeline_allocs_) { + int num_versions = ComputeBufferVersions(buffer, infos.at(buffer)); + if (num_versions > 1) { + buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions)); + } + } ordered_stmts_.resize(pipeline_info_.size()); for (const auto& pair : pipeline_info_) { @@ -312,17 +319,11 @@ class PipelineRewriter : public StmtExprMutator { // Step 3: Make a new block that contains new buffer allocations after pipeline rewriting. Array alloc_buffers; for (const auto& alloc : pipeline_allocs_) { - auto it = buffer_remap_.find(alloc); - if (it != buffer_remap_.end()) { - alloc_buffers.push_back((*it).second); - } else { - alloc_buffers.push_back(alloc); - } + alloc_buffers.push_back(buffer_remap_.Get(alloc).value_or(alloc)); buffer_data_to_buffer_.erase(alloc->data); } Block block = MakeBlock(stmt, buffer_data_to_buffer_); - auto* n = block.CopyOnWrite(); - n->alloc_buffers = std::move(alloc_buffers); + block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers); return BlockRealize({}, Bool(true), block); } @@ -392,7 +393,7 @@ class PipelineRewriter : public StmtExprMutator { * need to maintain during the software pipeline. * Annotation `attr::double_buffer_scope` is handled here which provides a way to override the * result of the analysis. Additional double buffering in the software pipeline can be useful - * to eliminate synchonizations in GPU devices. + * to eliminate synchronizations in GPU devices. * * \param buffer The target buffer * \param buffer_info The access information of the target buffer. 
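+   *        (A hedged sketch of the rule as implemented: a buffer defined by a
+   *        stage-`def` component and last used by a stage-`use` component needs
+   *        `use - def + 1` versions, and attr::double_buffer_scope forces at
+   *        least two.)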
@@ -462,7 +463,7 @@ class PipelineRewriter : public StmtExprMutator { std::unordered_map infos = GetBufferAccessInfo(); for (const Buffer& buffer : pipeline_allocs) { - const BufferAccessInfo access_info = infos.at(buffer); + const BufferAccessInfo& access_info = infos.at(buffer); int num_versions = ComputeBufferVersions(buffer, access_info); if (num_versions > 1) { Buffer new_buffer = RewriteAllocBuffer(buffer, num_versions); @@ -480,7 +481,7 @@ class PipelineRewriter : public StmtExprMutator { */ Buffer RewriteAllocBuffer(const Buffer& buffer, int num_versions) { ObjectPtr new_buffer = make_object(*(buffer.get())); - new_buffer->shape.insert(new_buffer->shape.begin(), num_versions); + new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions)); if (new_buffer->strides.size()) { ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size()); PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1]; @@ -514,12 +515,11 @@ class PipelineRewriter : public StmtExprMutator { analyzer_.Bind(Downcast(new_loop_var), Range(start, end)); } - for (const Block block : ordered_stmts_) { + for (const Block& block : ordered_stmts_) { int stage = pipeline_info_.at(block).stage; PrimExpr skewed_loop_var = new_loop_var - stage; - PrimExpr inbound = (skewed_loop_var >= pipeline_loop_->min) && + PrimExpr inbound = analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) && (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent); - inbound = analyzer_.Simplify(inbound); if (analyzer_.CanProve(!inbound)) { continue; } @@ -608,8 +608,7 @@ class PipelineInjector : private StmtExprMutator { Stmt VisitStmt_(const ForNode* op) final { // Step 1: Recursively rewrite the children first. For for_node = Downcast(StmtExprMutator::VisitStmt_(op)); - bool is_pipeline = HasPipelineAnnotation(op); - if (!is_pipeline) { + if (!HasPipelineAnnotation(op)) { return std::move(for_node); } // Step 2: Find the body and buffer allocations of the pipeline. 
The body can be direct child of From 95666d1ea732e3576cc5b1ea38bb5725ecc97e7d Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 8 Feb 2022 14:00:40 -0500 Subject: [PATCH 13/18] address comments --- .../transforms/inject_software_pipeline.cc | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index 4a367b298d0b..153a8fc18136 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -52,7 +52,7 @@ Block MakeBlock(const Stmt& body, const Map& buffer_data_to_buffer) return block_realize->block; } } - Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/body); + Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/ body); Array> access = GetBlockReadWriteRegion(block, buffer_data_to_buffer); BlockNode* n = block.CopyOnWrite(); n->reads = access[0]; @@ -64,15 +64,13 @@ Block MakeBlock(const Stmt& body, const Map& buffer_data_to_buffer) struct PipelineStageOrder { int stage; int order; - explicit PipelineStageOrder(int stage, int order) : stage(stage), order(order) {} }; using PipelineInfo = std::unordered_map; struct BufferAccessInfo { - int def; // the defining stage of the buffer - int use; // the last using stage of the buffer - explicit BufferAccessInfo(int def = -1, int use = -1) : def(def), use(use) {} + int def = -1; // the defining stage of the buffer + int use = -1; // the last using stage of the buffer }; /*! @@ -128,10 +126,12 @@ class PipelineBodyRewriter : public StmtExprMutator { } Block block = Downcast(StmtExprMutator::VisitStmt_(op)); BlockNode* n = block.CopyOnWrite(); - n->reads.MutateByApply( - std::bind(&PipelineBodyRewriter::RewritePipelineBufferRegion, this, std::placeholders::_1)); - n->writes.MutateByApply( - std::bind(&PipelineBodyRewriter::RewritePipelineBufferRegion, this, std::placeholders::_1)); + n->reads.MutateByApply([this](const BufferRegion& buffer_region) { + return RewritePipelineBufferRegion(buffer_region); + }); + n->writes.MutateByApply([this](const BufferRegion& buffer_region) { + return RewritePipelineBufferRegion(buffer_region); + }); for (const Buffer& alloc_buffer : op->alloc_buffers) { buffer_data_to_buffer_.erase(alloc_buffer->data); } @@ -678,7 +678,8 @@ class PipelineInjector : private StmtExprMutator { CHECK_EQ(pipeline_stages.size(), original_order.size()); CHECK_EQ(pipeline_orders.size(), original_order.size()); for (size_t i = 0; i < pipeline_stages.size(); i++) { - PipelineStageOrder stage_order(pipeline_stages[i]->value, pipeline_orders[i]->value); + PipelineStageOrder stage_order{/*stage=*/static_cast(pipeline_stages[i]->value), + /*order=*/static_cast(pipeline_orders[i]->value)}; pipeline_info.emplace(original_order[i], stage_order); } ValidatePipelineBody(pipeline_info, original_order); @@ -716,7 +717,6 @@ class PipelineInjector : private StmtExprMutator { Stmt VisitStmt_(const BlockNode* op) final { for (const auto& buffer : op->alloc_buffers) { - ICHECK(buffer->IsInstance()); buffer_data_to_buffer_.Set(buffer->data, buffer); } From 55a2fdc640ead2edd86b488639196a0a7311fdd1 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 8 Feb 2022 15:36:07 -0500 Subject: [PATCH 14/18] address comments --- src/tir/transforms/inject_software_pipeline.cc | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/src/tir/transforms/inject_software_pipeline.cc 
b/src/tir/transforms/inject_software_pipeline.cc index 153a8fc18136..2da388d8a7e6 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -632,12 +632,6 @@ class PipelineInjector : private StmtExprMutator { CHECK(pipeline_body_seq) << "ValueError: The body of the software pipeline should be SeqStmt, got " << pipeline_body->GetTypeKey(); - // The SeqStmt before recursive rewriting, which is used for validating the software pipeline. - const SeqStmtNode* original_seq = - op->body->IsInstance() - ? op->body.as()->block->body.as() - : op->body.as(); - ICHECK(original_seq); // Step 3: Blockize the components of the pipeline. Each child of the pipelined loop will be // converted into a block. @@ -645,11 +639,7 @@ class PipelineInjector : private StmtExprMutator { Array original_order; // pipeline body blocks in the original order auto f_add_child = [&](const Stmt& child) { - const auto* block_realize = child.as(); - Block block = (block_realize && is_one(block_realize->predicate)) - ? block_realize->block - : MakeBlock(child, buffer_data_to_buffer_); - original_order.push_back(block); + original_order.push_back(MakeBlock(child, buffer_data_to_buffer_)); }; for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) { const auto* nested_block_realize = pipeline_body_seq->seq[i].as(); From e1ce7cff91d3e7d13c327edd1fca9b2d3c7a8aad Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 10 Feb 2022 15:55:37 -0500 Subject: [PATCH 15/18] refactor FragmentInfo::GetSize --- .../transforms/inject_software_pipeline.cc | 12 +----------- src/tir/transforms/ir_utils.h | 19 +++++++++++++++++-- .../transforms/tensorcore_infer_fragment.cc | 6 +++--- 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index 2da388d8a7e6..8df472b4aaf6 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -172,17 +172,7 @@ class PipelineBodyRewriter : public StmtExprMutator { auto it = fragment_info_.find(buffer->data.get()); ICHECK(it != fragment_info_.end()); const FragmentInfo& info = (*it).second; - String scope = buffer.scope(); - if (scope == "wmma.matrix_a") { - return info.m * info.k; - } else if (scope == "wmma.matrix_b") { - return info.n * info.k; - } else if (scope == "wmma.accumulator") { - return info.m * info.n; - } else { - ICHECK(0); - throw; - } + return info.GetSize(); } PrimExpr RewriteWmmaFragmentIndex(const Buffer& old_buffer, const Buffer& new_buffer, diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index 61f2536df90a..d7ae362b64d4 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -273,9 +273,24 @@ struct FragmentInfo { int m, n, k; // fragment layout (row-major or column-major) std::string layout; + // scope of the fragment (wmma.matrix_a, wmma.matrix_b, or wmma.accumulator) + std::string scope; FragmentInfo() = default; - FragmentInfo(int _m, int _n, int _k, const std::string& _layout) - : m(_m), n(_n), k(_k), layout(_layout) {} + FragmentInfo(int _m, int _n, int _k, const std::string& _layout, const std::string& _scope) + : m(_m), n(_n), k(_k), layout(_layout), scope(_scope) {} + + int GetSize() const { + if (scope == "wmma.matrix_a") { + return m * k; + } else if (scope == "wmma.matrix_b") { + return n * k; + } else if (scope == "wmma.accumulator") { + return m * n; + } else { + ICHECK(0); + throw; + } + } }; /*! 
diff --git a/src/tir/transforms/tensorcore_infer_fragment.cc b/src/tir/transforms/tensorcore_infer_fragment.cc index b14eff5973c8..89b9307198f6 100644 --- a/src/tir/transforms/tensorcore_infer_fragment.cc +++ b/src/tir/transforms/tensorcore_infer_fragment.cc @@ -72,9 +72,9 @@ class FragmentGetter : public StmtExprVisitor { // store metadata FragmentInfo info; if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { - info = FragmentInfo(m->value, n->value, k->value, layout->value); + info = FragmentInfo(m->value, n->value, k->value, layout->value, scope); } else if (scope == "wmma.accumulator") { - info = FragmentInfo(m->value, n->value, k->value, ""); + info = FragmentInfo(m->value, n->value, k->value, "", scope); } fragments[buffer_var] = info; } @@ -100,7 +100,7 @@ class FragmentGetter : public StmtExprVisitor { ICHECK_EQ(n->value, info.n); ICHECK_EQ(k->value, info.k); } else { - FragmentInfo info(m->value, n->value, k->value, ""); + FragmentInfo info(m->value, n->value, k->value, "", scope); fragments[buffer_var] = info; } } From dccc94cbf34872c3b4fc9ea574310094e42f6cec Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 14 Feb 2022 13:54:27 -0500 Subject: [PATCH 16/18] remove unused --- src/tir/transforms/inject_software_pipeline.cc | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index 8df472b4aaf6..b893219ca62a 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -444,24 +444,6 @@ class PipelineRewriter : public StmtExprMutator { return num_versions; } - /*! - * \brief Rewrite buffer allocations to create new buffers with new shapes according to - * the software pipeline. - * \param pipeline_allocs The buffer allocations inside the software pipeline scope. - */ - void RemapPipelineBuffers(Array pipeline_allocs) { - std::unordered_map infos = - GetBufferAccessInfo(); - for (const Buffer& buffer : pipeline_allocs) { - const BufferAccessInfo& access_info = infos.at(buffer); - int num_versions = ComputeBufferVersions(buffer, access_info); - if (num_versions > 1) { - Buffer new_buffer = RewriteAllocBuffer(buffer, num_versions); - buffer_remap_.Set(buffer, new_buffer); - } - } - } - /*! * \brief Rewrite buffer allocation to keep multiple versions of original buffer for pipelined * accesses. From 0e38874195531bb556b873af99ceac1c8ad59cb5 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 14 Feb 2022 16:12:37 -0500 Subject: [PATCH 17/18] refactor --- src/tir/transforms/inject_software_pipeline.cc | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index b893219ca62a..e6469b9aefb9 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -189,14 +189,13 @@ class PipelineBodyRewriter : public StmtExprMutator { return new_buffer_offset; } - PrimExpr VisitExpr_(const CallNode* op) final { + PrimExpr RewriteOpaqueAccesses(const Call& call) { // Intrinsic calls should be handled explicitly here as they are opaque accesses to // buffer. 
static const auto& load_matrix_sync = builtin::tvm_load_matrix_sync(); static const auto& store_matrix_sync = builtin::tvm_store_matrix_sync(); static const auto& mma_sync = builtin::tvm_mma_sync(); static const auto& access_ptr = builtin::tvm_access_ptr(); - Call call = Downcast(StmtExprMutator::VisitExpr_(op)); if (call->op.same_as(load_matrix_sync) || call->op.same_as(store_matrix_sync)) { const Buffer& buffer = buffer_data_to_buffer_.at(Downcast(call->args[0])); auto it = buffer_remap_.find(buffer); @@ -238,7 +237,12 @@ class PipelineBodyRewriter : public StmtExprMutator { return Call(call->dtype, call->op, new_args, call->span); } } - return std::move(call); + return call; + } + + PrimExpr VisitExpr_(const CallNode* op) final { + Call call = Downcast(StmtExprMutator::VisitExpr_(op)); + return RewriteOpaqueAccesses(call); } Map buffer_data_to_buffer_; From d0af0dd33b1e97243a238ba1acfe269fc87b8421 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 17 Feb 2022 14:25:14 -0500 Subject: [PATCH 18/18] address comments --- .../transforms/inject_software_pipeline.cc | 177 ++++++++++-------- 1 file changed, 102 insertions(+), 75 deletions(-) diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index e6469b9aefb9..b607ba485a6a 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -73,6 +73,104 @@ struct BufferAccessInfo { int use = -1; // the last using stage of the buffer }; +class PipelineOpaqueAccessRewriter { + public: + /*! + * \brief Constructor + * \param buffer_data_to_buffer The map from buffer data to buffer. + * \param buffer_remap The map from original buffer to the buffer with updated shape for + * multi-versioning in the software pipeline. + * \param pipeline_loop The original loop to be software pipelined. + * \param fragment_info Information about tensor core fragment + */ + PipelineOpaqueAccessRewriter( + const Map& buffer_data_to_buffer, const Map& buffer_remap, + const For& pipeline_loop, + const std::unordered_map& fragment_info) + : buffer_data_to_buffer_(buffer_data_to_buffer), + buffer_remap_(buffer_remap), + pipeline_loop_(pipeline_loop), + fragment_info_(fragment_info) {} + + PrimExpr Rewrite(const Call& call) { + // Intrinsic calls should be handled explicitly here as they are opaque accesses to + // buffer. 
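+    // Note: only the four intrinsics below (tvm_load_matrix_sync,
+    // tvm_store_matrix_sync, tvm_mma_sync, tvm_access_ptr) are recognized as
+    // opaque buffer accesses; any other call is returned unchanged.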
+ static const auto& load_matrix_sync = builtin::tvm_load_matrix_sync(); + static const auto& store_matrix_sync = builtin::tvm_store_matrix_sync(); + static const auto& mma_sync = builtin::tvm_mma_sync(); + static const auto& access_ptr = builtin::tvm_access_ptr(); + if (call->op.same_as(load_matrix_sync) || call->op.same_as(store_matrix_sync)) { + const Buffer& buffer = buffer_data_to_buffer_.at(Downcast(call->args[0])); + auto it = buffer_remap_.find(buffer); + if (it != buffer_remap_.end()) { + Array new_args = call->args; + const Buffer& new_buffer = (*it).second; + new_args.Set(4, RewriteWmmaFragmentIndex(buffer, new_buffer, call->args[4])); + return Call(call->dtype, call->op, new_args, call->span); + } + } else if (call->op.same_as(mma_sync)) { + Array new_args = call->args; + for (int i = 0; i < 4; i++) { + const Var& buffer_var = Downcast(call->args[i * 2]); + const PrimExpr& index = call->args[i * 2 + 1]; + const Buffer& buffer = buffer_data_to_buffer_.at(buffer_var); + auto it = buffer_remap_.find(buffer); + if (it != buffer_remap_.end()) { + PrimExpr new_index = RewriteWmmaFragmentIndex(buffer, (*it).second, index); + new_args.Set(i * 2 + 1, new_index); + } + } + return Call(call->dtype, call->op, new_args, call->span); + } else if (call->op.same_as(access_ptr)) { + const Buffer& buffer = buffer_data_to_buffer_.at(Downcast(call->args[1])); + auto it = buffer_remap_.find(buffer); + if (it != buffer_remap_.end()) { + Array new_args = call->args; + const Buffer& new_buffer = (*it).second; + const PrimExpr& old_index = call->args[2]; + PrimExpr offset; + if (new_buffer->strides.empty()) { + offset = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, + make_const(DataType::Int(32), 1), buffer->shape); + } else { + offset = new_buffer->strides[0]; + } + PrimExpr new_index = old_index + floormod(pipeline_loop_->loop_var, 2) * offset; + new_args.Set(2, new_index); + return Call(call->dtype, call->op, new_args, call->span); + } + } + return call; + } + + private: + int GetWmmaFragmentSize(const Buffer& buffer) { + auto it = fragment_info_.find(buffer->data.get()); + ICHECK(it != fragment_info_.end()); + const FragmentInfo& info = (*it).second; + return info.GetSize(); + } + + PrimExpr RewriteWmmaFragmentIndex(const Buffer& old_buffer, const Buffer& new_buffer, + const PrimExpr& old_index) { + PrimExpr new_buffer_offset = old_index; + + int fragment_size = GetWmmaFragmentSize(old_buffer); + PrimExpr offset = + floordiv(foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, + make_const(DataType::Int(32), 1), old_buffer->shape), + fragment_size); + new_buffer_offset += + floormod(pipeline_loop_->loop_var - pipeline_loop_->min, new_buffer->shape[0]) * offset; + return new_buffer_offset; + } + + const Map& buffer_data_to_buffer_; + const Map& buffer_remap_; + const For& pipeline_loop_; + const std::unordered_map& fragment_info_; +}; + /*! * \brief Rewriter for the body of the software pipeline. This pass inserts `floormod` to indices * of the remapped buffer to select the version corresponding to the pipeline stage. 
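+ * For example (hedged, following the transformed unit tests in this series):
+ * with two versions of a buffer B, a store `B[tx, 0] = ...` issued by the
+ * component at logical iteration `i` becomes `B[i % 2, tx, 0] = ...`, so
+ * consecutive iterations write alternating versions and a stage-1 consumer
+ * can read version `i % 2` while the stage-0 producer fills `(i + 1) % 2`.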
@@ -98,7 +196,8 @@ class PipelineBodyRewriter : public StmtExprMutator { buffer_remap_(buffer_remap), pipeline_loop_(pipeline_loop), access_all_versions_(access_all_versions), - fragment_info_(fragment_info) {} + opaque_access_rewriter_(buffer_data_to_buffer_, buffer_remap_, pipeline_loop_, + fragment_info) {} private: BufferRegion RewritePipelineBufferRegion(const BufferRegion& buffer_region) const { @@ -168,88 +267,16 @@ class PipelineBodyRewriter : public StmtExprMutator { return std::move(load); } - int GetWmmaFragmentSize(const Buffer& buffer) { - auto it = fragment_info_.find(buffer->data.get()); - ICHECK(it != fragment_info_.end()); - const FragmentInfo& info = (*it).second; - return info.GetSize(); - } - - PrimExpr RewriteWmmaFragmentIndex(const Buffer& old_buffer, const Buffer& new_buffer, - const PrimExpr& old_index) { - PrimExpr new_buffer_offset = old_index; - - int fragment_size = GetWmmaFragmentSize(old_buffer); - PrimExpr offset = - floordiv(foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, - make_const(DataType::Int(32), 1), old_buffer->shape), - fragment_size); - new_buffer_offset += - floormod(pipeline_loop_->loop_var - pipeline_loop_->min, new_buffer->shape[0]) * offset; - return new_buffer_offset; - } - - PrimExpr RewriteOpaqueAccesses(const Call& call) { - // Intrinsic calls should be handled explicitly here as they are opaque accesses to - // buffer. - static const auto& load_matrix_sync = builtin::tvm_load_matrix_sync(); - static const auto& store_matrix_sync = builtin::tvm_store_matrix_sync(); - static const auto& mma_sync = builtin::tvm_mma_sync(); - static const auto& access_ptr = builtin::tvm_access_ptr(); - if (call->op.same_as(load_matrix_sync) || call->op.same_as(store_matrix_sync)) { - const Buffer& buffer = buffer_data_to_buffer_.at(Downcast(call->args[0])); - auto it = buffer_remap_.find(buffer); - if (it != buffer_remap_.end()) { - Array new_args = call->args; - const Buffer& new_buffer = (*it).second; - new_args.Set(4, RewriteWmmaFragmentIndex(buffer, new_buffer, call->args[4])); - return Call(call->dtype, call->op, new_args, call->span); - } - } else if (call->op.same_as(mma_sync)) { - Array new_args = call->args; - for (int i = 0; i < 4; i++) { - const Var& buffer_var = Downcast(call->args[i * 2]); - const PrimExpr& index = call->args[i * 2 + 1]; - const Buffer& buffer = buffer_data_to_buffer_.at(buffer_var); - auto it = buffer_remap_.find(buffer); - if (it != buffer_remap_.end()) { - PrimExpr new_index = RewriteWmmaFragmentIndex(buffer, (*it).second, index); - new_args.Set(i * 2 + 1, new_index); - } - } - return Call(call->dtype, call->op, new_args, call->span); - } else if (call->op.same_as(access_ptr)) { - const Buffer& buffer = buffer_data_to_buffer_.at(Downcast(call->args[1])); - auto it = buffer_remap_.find(buffer); - if (it != buffer_remap_.end()) { - Array new_args = call->args; - const Buffer& new_buffer = (*it).second; - const PrimExpr& old_index = call->args[2]; - PrimExpr offset; - if (new_buffer->strides.empty()) { - offset = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, - make_const(DataType::Int(32), 1), buffer->shape); - } else { - offset = new_buffer->strides[0]; - } - PrimExpr new_index = old_index + floormod(pipeline_loop_->loop_var, 2) * offset; - new_args.Set(2, new_index); - return Call(call->dtype, call->op, new_args, call->span); - } - } - return call; - } - PrimExpr VisitExpr_(const CallNode* op) final { Call call = Downcast(StmtExprMutator::VisitExpr_(op)); - return 
RewriteOpaqueAccesses(call);
+    return opaque_access_rewriter_.Rewrite(call);
   }
 
   Map<Var, Buffer> buffer_data_to_buffer_;
   Map<Buffer, Buffer> buffer_remap_;
   For pipeline_loop_;
   bool access_all_versions_;
-  const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info_;
+  PipelineOpaqueAccessRewriter opaque_access_rewriter_;
 };
 
 /*!