From 1937176f2bea4ad5031dafd34af57de6433183f7 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 6 Mar 2023 22:11:21 -0500 Subject: [PATCH] [Unity][Transform] Memory plan across the IRModule Previously the static memory planning pass only works at single function level - each function inside an IRModule will be independently planned. This is not perfect for the VM to reuse allocated memory across different functions. Therefore, this PR turns the static memory planning pass into a module pass. Now the plan is done across the IRModule, so that memory allocation in different functions can share the same storage token when planning. With this PR, it is hoped that the VM will find more opportunities for memory reuse. --- .../transform/static_plan_block_memory.cc | 101 +++++++++----- ...test_transform_static_plan_block_memory.py | 123 +++++++++++++++++- 2 files changed, 191 insertions(+), 33 deletions(-) diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index 8b7adae246eb..ba5177fec065 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -86,11 +86,6 @@ class StorageTokenNode : public Object { DataType dtype; /*! \brief The storage id, reserved for debug and demo use. */ int storage_id{-1}; - /*! - * \brief The variable corresponding to the allocated storage, which is NullOpt - * before definition. - */ - Optional storage{NullOpt}; static constexpr const char* _type_key = "relax.transform.StorageToken"; TVM_DECLARE_BASE_OBJECT_INFO(StorageTokenNode, Object); @@ -287,23 +282,36 @@ class StorageAllocatorBaseVisitor : public ExprVisitor { */ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { public: - explicit StorageAllocatorInit(const IRModule& ctx_mod) : ctx_mod_(ctx_mod) {} - /*! * \brief The entry of the initialization. + * \param mod The IRModule to be planned * \return The mapping from each Expr to the token it uses. 
*/ - std::unordered_map Initialize(const Function& func) { + static std::unordered_map Initialize(const IRModule& mod) { + StorageAllocatorInit initializer(mod); + + for (auto it : mod->functions) { + const auto* func = it.second.as(); + if (func == nullptr) { + continue; + } + initializer(GetRef(func)); + } + return initializer.token_map_; + } + + private: + using ExprVisitor::VisitExpr_; + + explicit StorageAllocatorInit(const IRModule& ctx_mod) : ctx_mod_(ctx_mod) {} + + void VisitExpr_(const FunctionNode* func) final { // Recurse into the function to get its tokens. Tokens body_tokens = GetTokens(func->body); // Discard the tokens used by the function return value, as they are external referenced. DiscardTokensIn(body_tokens); - return this->token_map_; } - private: - using ExprVisitor::VisitExpr_; - void VisitExpr_(const CallNode* call) final { static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor"); if (call->op == alloc_tensor_op) { @@ -501,6 +509,16 @@ class StorageAllocator : public StorageAllocatorBaseVisitor { this->token_map_ = std::move(token_map); } + void Allocate(const IRModule& mod) { + for (auto it : mod->functions) { + const auto* func = it.second.as(); + if (func == nullptr) { + continue; + } + this->VisitExpr_(func); + } + } + /*! * \brief The mapping from each `builtin.alloc_tensor` to its corresponding * underlying storage token that it is using. 
@@ -629,14 +647,29 @@ class StorageAllocator : public StorageAllocatorBaseVisitor { class StorageAllocationRewriter : public ExprMutator { public: explicit StorageAllocationRewriter( - std::unordered_map alloc_tensor2token, + IRModule mod, std::unordered_map alloc_tensor2token, std::unordered_map> expr2killed_tensors, std::unordered_map> block2tokens) - : alloc_tensor2token_(std::move(alloc_tensor2token)), + : ExprMutator(std::move(mod)), + alloc_tensor2token_(std::move(alloc_tensor2token)), expr2killed_tensors_(std::move(expr2killed_tensors)), block2tokens_(std::move(block2tokens)) {} + IRModule Rewrite() { + const IRModule& mod = builder_->GetContextIRModule(); + for (const auto& [gv, base_func] : mod->functions) { + const auto* func_ = base_func.as(); + if (func_ == nullptr) { + continue; + } + token2storage_var_.clear(); + Function func = Downcast(this->VisitExpr_(func_)); + builder_->UpdateFunction(gv, func); + } + return builder_->GetContextIRModule(); + } + private: using ExprMutator::VisitExpr_; @@ -648,9 +681,10 @@ class StorageAllocationRewriter : public ExprMutator { // Insert `memory.kill_storage` for the storage tokens allocated inside this block. for (const StorageTokenNode* token : block2tokens_[block]) { - ICHECK(token->storage.defined()); + auto it_token = token2storage_var_.find(token); + ICHECK(it_token != token2storage_var_.end()); static const Op& mem_kill_storage = Op::Get("relax.memory.kill_storage"); - this->builder_->Emit(Call(mem_kill_storage, {token->storage.value()}), /*name_hint=*/"_"); + this->builder_->Emit(Call(mem_kill_storage, {it_token->second}), /*name_hint=*/"_"); } BindingBlock new_block = builder_->EndBlock(); @@ -682,7 +716,9 @@ class StorageAllocationRewriter : public ExprMutator { // If the token is visited for the first time, create a storage variable using // `memory.alloc_storage` for it. 
StorageToken token = it->second; - if (!token->storage.defined()) { + Var storage_var{nullptr}; + auto it_token = token2storage_var_.find(token.get()); + if (it_token == token2storage_var_.end()) { static const Op& mem_alloc_storage = Op::Get("relax.memory.alloc_storage"); ShapeExpr size({tir::make_const(DataType::Int(64), token->bytes)}); PrimValue virtual_device_index = runtime_device_index; @@ -692,15 +728,17 @@ class StorageAllocationRewriter : public ExprMutator { mem_alloc_storage, {std::move(size), virtual_device_index, StringImm(storage_scope), DataTypeImm(dtype)}, Attrs()); - token->storage = builder_->Emit(alloc_storage, "storage"); + storage_var = builder_->Emit(alloc_storage, "storage"); + token2storage_var_[token.get()] = storage_var; + } else { + storage_var = it_token->second; } // And always create a `memory.alloc_tensor` for the old `builtin.alloc_tensor`. static const Op& mem_alloc_tensor = Op::Get("relax.memory.alloc_tensor"); PrimValue offset = PrimValue::Int64(0); DataType dtype = sinfo->dtype; - return Call(mem_alloc_tensor, - {token->storage.value(), offset, sinfo->shape.value(), DataTypeImm(dtype)}, + return Call(mem_alloc_tensor, {storage_var, offset, sinfo->shape.value(), DataTypeImm(dtype)}, Attrs()); } @@ -716,31 +754,30 @@ class StorageAllocationRewriter : public ExprMutator { std::unordered_map> expr2killed_tensors_; /*! \brief The mapping from each binding block to the storage tokens that are create inside. */ std::unordered_map> block2tokens_; + /*! \brief The mapping from each token to its corresponding storage var in each function. */ + std::unordered_map token2storage_var_; }; -Expr StaticPlanBlockMemory(Function func, const IRModule& ctx_mod) { +IRModule StaticPlanBlockMemory(IRModule mod) { // Step 1. Initialize. - StorageAllocatorInit initializer(ctx_mod); - std::unordered_map token_map = initializer.Initialize(func); + std::unordered_map token_map = StorageAllocatorInit::Initialize(mod); // Step 2. 
Collect the memory allocation info. StorageAllocator allocator(std::move(token_map)); - allocator(func); + allocator.Allocate(mod); // Step 3. Rewrite the function. - StorageAllocationRewriter rewriter(std::move(allocator.alloc_tensor2token), + StorageAllocationRewriter rewriter(std::move(mod), // + std::move(allocator.alloc_tensor2token), std::move(allocator.expr2killed_tensors), std::move(allocator.block2tokens)); - func = Downcast(rewriter(func)); - return func; + return rewriter.Rewrite(); } namespace transform { Pass StaticPlanBlockMemory() { - runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(StaticPlanBlockMemory(std::move(f), m)); - }; - return CreateFunctionPass(pass_func, /*opt_level=*/0, "StaticPlanBlockMemory", {}); + runtime::TypedPackedFunc pass_func = + [=](IRModule m, PassContext pc) { return relax::StaticPlanBlockMemory(std::move(m)); }; + return CreateModulePass(pass_func, /*opt_level=*/0, "StaticPlanBlockMemory", {}); } TVM_REGISTER_GLOBAL("relax.transform.StaticPlanBlockMemory").set_body_typed(StaticPlanBlockMemory); diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py index 1b556139ccc9..2f04e74062af 100644 --- a/tests/python/relax/test_transform_static_plan_block_memory.py +++ b/tests/python/relax/test_transform_static_plan_block_memory.py @@ -18,7 +18,7 @@ import tvm import tvm.testing from tvm import relax -from tvm.script import relax as R, tir as T +from tvm.script import ir as I, relax as R, tir as T def test_basic(): @@ -608,5 +608,126 @@ def main( tvm.ir.assert_structural_equal(mod, Module) +def test_multiple_functions(): + @tvm.script.ir_module + class Module: + @T.prim_func + def add( + A: T.Buffer((T.int64(2), T.int64(3)), "float32"), + B: T.Buffer((T.int64(2), T.int64(3)), "float32"), + C: T.Buffer((T.int64(2), T.int64(3)), "float32"), + ): + T.evaluate(0) + + @T.prim_func + def add1( + A: 
T.Buffer((T.int64(2), T.int64(3)), "int32"), + B: T.Buffer((T.int64(2), T.int64(3)), "int32"), + C: T.Buffer((T.int64(2), T.int64(3)), "int32"), + ): + T.evaluate(0) + + @R.function + def func1( + x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="int32") + ) -> R.Tensor((2, 3), dtype="float32"): + alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _: R.Tuple() = add(x, x, alloc) + gv: R.Tensor((2, 3), dtype="float32") = alloc + alloc1: R.Tensor((2, 3), dtype="int32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="int32", runtime_device_index=0 + ) + _1: R.Tuple() = add1(y, y, alloc1) + gv1: R.Tensor((2, 3), dtype="int32") = alloc1 + return x + + @R.function + def func2( + x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32") + ) -> R.Tensor((2, 3), dtype="float32"): + alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _: R.Tuple() = add(x, x, alloc) + gv: R.Tensor((2, 3), dtype="float32") = alloc + alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _1: R.Tuple() = add(y, y, alloc1) + gv1: R.Tensor((2, 3), dtype="float32") = alloc1 + return x + + @I.ir_module + class Expected: + @T.prim_func + def add( + A: T.Buffer((T.int64(2), T.int64(3)), "float32"), + B: T.Buffer((T.int64(2), T.int64(3)), "float32"), + C: T.Buffer((T.int64(2), T.int64(3)), "float32"), + ): + T.evaluate(0) + + @T.prim_func + def add1( + A: T.Buffer((T.int64(2), T.int64(3)), "int32"), + B: T.Buffer((T.int64(2), T.int64(3)), "int32"), + C: T.Buffer((T.int64(2), T.int64(3)), "int32"), + ): + T.evaluate(0) + + @R.function + def func1( + x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="int32") + ) -> R.Tensor((2, 3), dtype="float32"): + storage: R.Object = R.memory.alloc_storage( + R.shape([24]), 
virtual_device_index=0, storage_scope="global", dtype="float32" + ) + alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage, 0, R.shape([2, 3]), dtype="float32" + ) + _: R.Tuple() = add(x, x, alloc) + _1: R.Tuple() = R.memory.kill_tensor(alloc) + gv1: R.Tensor((2, 3), dtype="float32") = alloc + storage1: R.Object = R.memory.alloc_storage( + R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="int32" + ) + alloc1: R.Tensor((2, 3), dtype="int32") = R.memory.alloc_tensor( + storage1, 0, R.shape([2, 3]), dtype="int32" + ) + _2: R.Tuple() = add1(y, y, alloc1) + _3: R.Tuple() = R.memory.kill_tensor(alloc1) + gv12: R.Tensor((2, 3), dtype="int32") = alloc1 + _5: R.Tuple() = R.memory.kill_storage(storage) + _4: R.Tuple() = R.memory.kill_storage(storage1) + return x + + @R.function + def func2( + x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32") + ) -> R.Tensor((2, 3), dtype="float32"): + storage: R.Object = R.memory.alloc_storage( + R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" + ) + alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage, 0, R.shape([2, 3]), dtype="float32" + ) + _: R.Tuple() = add(x, x, alloc) + _1: R.Tuple() = R.memory.kill_tensor(alloc) + gv1: R.Tensor((2, 3), dtype="float32") = alloc + alloc1: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage, 0, R.shape([2, 3]), dtype="float32" + ) + _2: R.Tuple() = add(y, y, alloc1) + _3: R.Tuple() = R.memory.kill_tensor(alloc1) + gv12: R.Tensor((2, 3), dtype="float32") = alloc1 + _4: R.Tuple() = R.memory.kill_storage(storage) + return x + + mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Expected) + + if __name__ == "__main__": tvm.testing.main()