Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 69 additions & 32 deletions src/relax/transform/static_plan_block_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,6 @@ class StorageTokenNode : public Object {
DataType dtype;
/*! \brief The storage id, reserved for debug and demo use. */
int storage_id{-1};
/*!
* \brief The variable corresponding to the allocated storage, which is NullOpt
* before definition.
*/
Optional<Var> storage{NullOpt};

static constexpr const char* _type_key = "relax.transform.StorageToken";
TVM_DECLARE_BASE_OBJECT_INFO(StorageTokenNode, Object);
Expand Down Expand Up @@ -287,23 +282,36 @@ class StorageAllocatorBaseVisitor : public ExprVisitor {
*/
class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
public:
explicit StorageAllocatorInit(const IRModule& ctx_mod) : ctx_mod_(ctx_mod) {}

/*!
* \brief The entry of the initialization.
* \param mod The IRModule to be planned
* \return The mapping from each Expr to the token it uses.
*/
std::unordered_map<const ExprNode*, Tokens> Initialize(const Function& func) {
static std::unordered_map<const ExprNode*, Tokens> Initialize(const IRModule& mod) {
StorageAllocatorInit initializer(mod);

for (auto it : mod->functions) {
const auto* func = it.second.as<FunctionNode>();
if (func == nullptr) {
continue;
}
initializer(GetRef<Function>(func));
}
return initializer.token_map_;
}

private:
using ExprVisitor::VisitExpr_;

explicit StorageAllocatorInit(const IRModule& ctx_mod) : ctx_mod_(ctx_mod) {}

void VisitExpr_(const FunctionNode* func) final {
// Recurse into the function to get its tokens.
Tokens body_tokens = GetTokens(func->body);
// Discard the tokens used by the function return value, as they are external referenced.
DiscardTokensIn(body_tokens);
return this->token_map_;
}

private:
using ExprVisitor::VisitExpr_;

void VisitExpr_(const CallNode* call) final {
static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
if (call->op == alloc_tensor_op) {
Expand Down Expand Up @@ -501,6 +509,16 @@ class StorageAllocator : public StorageAllocatorBaseVisitor {
this->token_map_ = std::move(token_map);
}

void Allocate(const IRModule& mod) {
for (auto it : mod->functions) {
const auto* func = it.second.as<FunctionNode>();
if (func == nullptr) {
continue;
}
this->VisitExpr_(func);
}
}

/*!
* \brief The mapping from each `builtin.alloc_tensor` to its corresponding
* underlying storage token that it is using.
Expand Down Expand Up @@ -629,14 +647,29 @@ class StorageAllocator : public StorageAllocatorBaseVisitor {
class StorageAllocationRewriter : public ExprMutator {
public:
explicit StorageAllocationRewriter(
std::unordered_map<const ExprNode*, StorageToken> alloc_tensor2token,
IRModule mod, std::unordered_map<const ExprNode*, StorageToken> alloc_tensor2token,
std::unordered_map<const ExprNode*, std::vector<Var>> expr2killed_tensors,
std::unordered_map<const BindingBlockNode*, std::vector<const StorageTokenNode*>>
block2tokens)
: alloc_tensor2token_(std::move(alloc_tensor2token)),
: ExprMutator(std::move(mod)),
alloc_tensor2token_(std::move(alloc_tensor2token)),
expr2killed_tensors_(std::move(expr2killed_tensors)),
block2tokens_(std::move(block2tokens)) {}

IRModule Rewrite() {
const IRModule& mod = builder_->GetContextIRModule();
for (const auto& [gv, base_func] : mod->functions) {
const auto* func_ = base_func.as<FunctionNode>();
if (func_ == nullptr) {
continue;
}
token2storage_var_.clear();
Function func = Downcast<Function>(this->VisitExpr_(func_));
builder_->UpdateFunction(gv, func);
}
return builder_->GetContextIRModule();
}

private:
using ExprMutator::VisitExpr_;

Expand All @@ -648,9 +681,10 @@ class StorageAllocationRewriter : public ExprMutator {

// Insert `memory.kill_storage` for the storage tokens allocated inside this block.
for (const StorageTokenNode* token : block2tokens_[block]) {
ICHECK(token->storage.defined());
auto it_token = token2storage_var_.find(token);
ICHECK(it_token != token2storage_var_.end());
static const Op& mem_kill_storage = Op::Get("relax.memory.kill_storage");
this->builder_->Emit(Call(mem_kill_storage, {token->storage.value()}), /*name_hint=*/"_");
this->builder_->Emit(Call(mem_kill_storage, {it_token->second}), /*name_hint=*/"_");
}

BindingBlock new_block = builder_->EndBlock();
Expand Down Expand Up @@ -682,7 +716,9 @@ class StorageAllocationRewriter : public ExprMutator {
// If the token is visited for the first time, create a storage variable using
// `memory.alloc_storage` for it.
StorageToken token = it->second;
if (!token->storage.defined()) {
Var storage_var{nullptr};
auto it_token = token2storage_var_.find(token.get());
if (it_token == token2storage_var_.end()) {
static const Op& mem_alloc_storage = Op::Get("relax.memory.alloc_storage");
ShapeExpr size({tir::make_const(DataType::Int(64), token->bytes)});
PrimValue virtual_device_index = runtime_device_index;
Expand All @@ -692,15 +728,17 @@ class StorageAllocationRewriter : public ExprMutator {
mem_alloc_storage,
{std::move(size), virtual_device_index, StringImm(storage_scope), DataTypeImm(dtype)},
Attrs());
token->storage = builder_->Emit(alloc_storage, "storage");
storage_var = builder_->Emit(alloc_storage, "storage");
token2storage_var_[token.get()] = storage_var;
} else {
storage_var = it_token->second;
}

// And always create a `memory.alloc_tensor` for the old `builtin.alloc_tensor`.
static const Op& mem_alloc_tensor = Op::Get("relax.memory.alloc_tensor");
PrimValue offset = PrimValue::Int64(0);
DataType dtype = sinfo->dtype;
return Call(mem_alloc_tensor,
{token->storage.value(), offset, sinfo->shape.value(), DataTypeImm(dtype)},
return Call(mem_alloc_tensor, {storage_var, offset, sinfo->shape.value(), DataTypeImm(dtype)},
Attrs());
}

Expand All @@ -716,31 +754,30 @@ class StorageAllocationRewriter : public ExprMutator {
std::unordered_map<const ExprNode*, std::vector<Var>> expr2killed_tensors_;
/*! \brief The mapping from each binding block to the storage tokens that are create inside. */
std::unordered_map<const BindingBlockNode*, std::vector<const StorageTokenNode*>> block2tokens_;
/*! \brief The mapping from each token to its corresponding storage var in each function. */
std::unordered_map<const StorageTokenNode*, Var> token2storage_var_;
};

Expr StaticPlanBlockMemory(Function func, const IRModule& ctx_mod) {
IRModule StaticPlanBlockMemory(IRModule mod) {
// Step 1. Initialize.
StorageAllocatorInit initializer(ctx_mod);
std::unordered_map<const ExprNode*, Tokens> token_map = initializer.Initialize(func);
std::unordered_map<const ExprNode*, Tokens> token_map = StorageAllocatorInit::Initialize(mod);
// Step 2. Collect the memory allocation info.
StorageAllocator allocator(std::move(token_map));
allocator(func);
allocator.Allocate(mod);
// Step 3. Rewrite the function.
StorageAllocationRewriter rewriter(std::move(allocator.alloc_tensor2token),
StorageAllocationRewriter rewriter(std::move(mod), //
std::move(allocator.alloc_tensor2token),
std::move(allocator.expr2killed_tensors),
std::move(allocator.block2tokens));
func = Downcast<Function>(rewriter(func));
return func;
return rewriter.Rewrite();
}

namespace transform {

Pass StaticPlanBlockMemory() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(StaticPlanBlockMemory(std::move(f), m));
};
return CreateFunctionPass(pass_func, /*opt_level=*/0, "StaticPlanBlockMemory", {});
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule m, PassContext pc) { return relax::StaticPlanBlockMemory(std::move(m)); };
return CreateModulePass(pass_func, /*opt_level=*/0, "StaticPlanBlockMemory", {});
}

TVM_REGISTER_GLOBAL("relax.transform.StaticPlanBlockMemory").set_body_typed(StaticPlanBlockMemory);
Expand Down
123 changes: 122 additions & 1 deletion tests/python/relax/test_transform_static_plan_block_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import tvm
import tvm.testing
from tvm import relax
from tvm.script import relax as R, tir as T
from tvm.script import ir as I, relax as R, tir as T


def test_basic():
Expand Down Expand Up @@ -608,5 +608,126 @@ def main(
tvm.ir.assert_structural_equal(mod, Module)


def test_multiple_functions():
@tvm.script.ir_module
class Module:
@T.prim_func
def add(
A: T.Buffer((T.int64(2), T.int64(3)), "float32"),
B: T.Buffer((T.int64(2), T.int64(3)), "float32"),
C: T.Buffer((T.int64(2), T.int64(3)), "float32"),
):
T.evaluate(0)

@T.prim_func
def add1(
A: T.Buffer((T.int64(2), T.int64(3)), "int32"),
B: T.Buffer((T.int64(2), T.int64(3)), "int32"),
C: T.Buffer((T.int64(2), T.int64(3)), "int32"),
):
T.evaluate(0)

@R.function
def func1(
x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="int32")
) -> R.Tensor((2, 3), dtype="float32"):
alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
R.shape([2, 3]), dtype="float32", runtime_device_index=0
)
_: R.Tuple() = add(x, x, alloc)
gv: R.Tensor((2, 3), dtype="float32") = alloc
alloc1: R.Tensor((2, 3), dtype="int32") = R.builtin.alloc_tensor(
R.shape([2, 3]), dtype="int32", runtime_device_index=0
)
_1: R.Tuple() = add1(y, y, alloc1)
gv1: R.Tensor((2, 3), dtype="int32") = alloc1
return x

@R.function
def func2(
x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")
) -> R.Tensor((2, 3), dtype="float32"):
alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
R.shape([2, 3]), dtype="float32", runtime_device_index=0
)
_: R.Tuple() = add(x, x, alloc)
gv: R.Tensor((2, 3), dtype="float32") = alloc
alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
R.shape([2, 3]), dtype="float32", runtime_device_index=0
)
_1: R.Tuple() = add(y, y, alloc1)
gv1: R.Tensor((2, 3), dtype="float32") = alloc1
return x

@I.ir_module
class Expected:
@T.prim_func
def add(
A: T.Buffer((T.int64(2), T.int64(3)), "float32"),
B: T.Buffer((T.int64(2), T.int64(3)), "float32"),
C: T.Buffer((T.int64(2), T.int64(3)), "float32"),
):
T.evaluate(0)

@T.prim_func
def add1(
A: T.Buffer((T.int64(2), T.int64(3)), "int32"),
B: T.Buffer((T.int64(2), T.int64(3)), "int32"),
C: T.Buffer((T.int64(2), T.int64(3)), "int32"),
):
T.evaluate(0)

@R.function
def func1(
x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="int32")
) -> R.Tensor((2, 3), dtype="float32"):
storage: R.Object = R.memory.alloc_storage(
R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32"
)
alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
storage, 0, R.shape([2, 3]), dtype="float32"
)
_: R.Tuple() = add(x, x, alloc)
_1: R.Tuple() = R.memory.kill_tensor(alloc)
gv1: R.Tensor((2, 3), dtype="float32") = alloc
storage1: R.Object = R.memory.alloc_storage(
R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="int32"
)
alloc1: R.Tensor((2, 3), dtype="int32") = R.memory.alloc_tensor(
storage1, 0, R.shape([2, 3]), dtype="int32"
)
_2: R.Tuple() = add1(y, y, alloc1)
_3: R.Tuple() = R.memory.kill_tensor(alloc1)
gv12: R.Tensor((2, 3), dtype="int32") = alloc1
_5: R.Tuple() = R.memory.kill_storage(storage)
_4: R.Tuple() = R.memory.kill_storage(storage1)
return x

@R.function
def func2(
x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")
) -> R.Tensor((2, 3), dtype="float32"):
storage: R.Object = R.memory.alloc_storage(
R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32"
)
alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
storage, 0, R.shape([2, 3]), dtype="float32"
)
_: R.Tuple() = add(x, x, alloc)
_1: R.Tuple() = R.memory.kill_tensor(alloc)
gv1: R.Tensor((2, 3), dtype="float32") = alloc
alloc1: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
storage, 0, R.shape([2, 3]), dtype="float32"
)
_2: R.Tuple() = add(y, y, alloc1)
_3: R.Tuple() = R.memory.kill_tensor(alloc1)
gv12: R.Tensor((2, 3), dtype="float32") = alloc1
_4: R.Tuple() = R.memory.kill_storage(storage)
return x

mod = relax.transform.StaticPlanBlockMemory()(Module)
tvm.ir.assert_structural_equal(mod, Expected)


if __name__ == "__main__":
tvm.testing.main()