Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity] Support storage reuse for dynamic shapes #16500

Merged
merged 3 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
99 changes: 69 additions & 30 deletions src/relax/transform/static_plan_block_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,23 @@ class StorageTokenNode : public Object {
/*! \brief Reference counter. */
int ref_counter{0};
/*! \brief Number of bytes that this token requires. */
int64_t bytes;
PrimExpr bytes;
/*! \brief The dtype of this token. */
DataType dtype;
/*! \brief The storage id, reserved for debug and demo use. */
int storage_id{-1};

/*! \brief Get the constant number of bytes that this token requires, or -1 if the number of bytes
* is symbolic */
int64_t const_bytes() const {
const int64_t* const_val = tir::as_const_int(bytes);
if (const_val) {
return *const_val;
} else {
return -1;
}
}

static constexpr const char* _type_key = "relax.transform.StorageToken";
TVM_DECLARE_BASE_OBJECT_INFO(StorageTokenNode, Object);
};
Expand All @@ -117,19 +128,22 @@ class StorageToken : public ObjectRef {
public:
explicit StorageToken(Array<PrimExpr> shape, DataType dtype) {
// Compute the tensor size from the shape.
int64_t size = 1;
int64_t const_coeff = dtype.bytes() * dtype.lanes();
PrimExpr size = tir::make_const(DataType::Int(64), 1);
for (const PrimExpr& dim_len : shape) {
const auto* int_len = dim_len.as<IntImmNode>();
ICHECK_NOTNULL(int_len);
size *= int_len->value;
if (const IntImmNode* const_dim_len = dim_len.as<IntImmNode>()) {
const_coeff *= const_dim_len->value;
} else {
size *= dim_len;
}
}
size = tir::make_const(DataType::Int(64), const_coeff) * size;

ObjectPtr<StorageTokenNode> n = make_object<StorageTokenNode>();
n->bytes = size * dtype.bytes() * dtype.lanes();
n->bytes = size;
n->dtype = dtype;
data_ = std::move(n);
}

TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(StorageToken, ObjectRef, StorageTokenNode);
};

Expand All @@ -143,6 +157,8 @@ using Tokens = NestedMsg<StorageToken>;
*/
class TokenAllocator1D {
public:
explicit TokenAllocator1D(arith::Analyzer* analyzer) : analyzer_(analyzer) {}

/*!
* \brief Request a storage token from the available token pool for a
* given prototype, or report no appropriate available token in the pool.
Expand All @@ -162,8 +178,24 @@ class TokenAllocator1D {
// Step 1. Get the available pool of the token dtype.
std::multimap<int64_t, StorageToken>& pool = available_pool_[prototype->dtype];

int64_t size = prototype->const_bytes();
if (size == -1) {
// Handle the case where the prototype token has dynamic size. Currently it requires the
// symbolic size to be the same as the prototype token in order to reuse the storage.
auto [begin, end] = pool.equal_range(size);
for (; begin != end; ++begin) {
StorageToken available_token = begin->second;
if (analyzer_->CanProveEqual(prototype->bytes, available_token->bytes)) {
ICHECK_EQ(available_token->ref_counter, 0)
<< "Available tokens are expected to have 0 reference.";
available_token->ref_counter = prototype->ref_counter;
pool.erase(begin);
return available_token;
}
}
return NullOpt;
}
// Step 2. Get the range of memory blocks in [size / match_range_, size * match_range_)
int64_t size = prototype->bytes;
auto begin = pool.lower_bound(size / match_range_);
auto mid = pool.lower_bound(size);
auto end = pool.upper_bound(size * match_range_);
Expand All @@ -172,7 +204,7 @@ class TokenAllocator1D {
StorageToken available_token = mid->second;
ICHECK_EQ(available_token->ref_counter, 0)
<< "Available tokens are expected to have 0 reference.";
ICHECK_LE(size, available_token->bytes);
ICHECK_LE(size, available_token->const_bytes());
available_token->ref_counter = prototype->ref_counter;
pool.erase(mid);
return available_token;
Expand All @@ -181,11 +213,13 @@ class TokenAllocator1D {
if (mid != begin) {
--mid;
StorageToken available_token = mid->second;
int64_t available_size = available_token->const_bytes();
ICHECK_EQ(available_token->ref_counter, 0)
<< "Available tokens are expected to have 0 reference.";
ICHECK_GE(size, available_token->bytes);
ICHECK_GE(available_size, 0);
ICHECK_GE(size, available_size);
// Enlarge the token size.
available_token->bytes = size;
available_token->bytes = tir::make_const(DataType::Int(64), size);
available_token->ref_counter = prototype->ref_counter;
pool.erase(mid);
return available_token;
Expand Down Expand Up @@ -216,7 +250,7 @@ class TokenAllocator1D {
ICHECK_GE(token->storage_id, 0)
<< "The token to be released is expected to be allocated before";
ICHECK_EQ(token->ref_counter, 0) << "The token to be released is expected to have 0 reference.";
available_pool_[token->dtype].insert({token->bytes, token});
available_pool_[token->dtype].insert({token->const_bytes(), token});
}

/*! \brief Clear the allocator. */
Expand All @@ -226,6 +260,8 @@ class TokenAllocator1D {
}

private:
/*! \brief The arithmetic analyzer. */
arith::Analyzer* analyzer_;
/*! \brief A constant scale representing the token search range. */
const int match_range_{16};
/*! \brief The pool of available storage tokens for each dtype. */
Expand Down Expand Up @@ -385,10 +421,12 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
/*!
* \brief The entry of the initialization.
* \param mod The IRModule to be planned
* \param analyzer The arithmetic analyzer.
* \return The mapping from each Expr to the token it uses.
*/
static std::unordered_map<const ExprNode*, Tokens> Initialize(const IRModule& mod) {
StorageAllocatorInit initializer(mod);
static std::unordered_map<const ExprNode*, Tokens> Initialize(const IRModule& mod,
arith::Analyzer* analyzer) {
StorageAllocatorInit initializer(mod, analyzer);

for (auto it : mod->functions) {
const auto* func = it.second.as<FunctionNode>();
Expand All @@ -403,11 +441,12 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
private:
using ExprVisitor::VisitExpr_;

explicit StorageAllocatorInit(const IRModule& ctx_mod) : ctx_mod_(ctx_mod) {}
explicit StorageAllocatorInit(const IRModule& ctx_mod, arith::Analyzer* analyzer)
: ctx_mod_(ctx_mod), analyzer_(analyzer) {}

void VisitExpr_(const FunctionNode* func) final {
// Set the upper bound of TIR variables in the analyzer.
SetTIRVarUpperBound(GetRef<Function>(func), &ana_);
SetTIRVarUpperBound(GetRef<Function>(func), analyzer_);
// Recurse into the function to get its tokens.
Tokens body_tokens = GetTokens(func->body);
// Discard the tokens used by the function return value, as they are external referenced.
Expand Down Expand Up @@ -508,14 +547,9 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
ICHECK(sinfo->dtype == Downcast<DataTypeImm>(call->args[1])->value);
ICHECK(!token_map_.count(call));

// Use the upper bounds of TIR vars as their values.
Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values, &ana_);

// No support for TIR vars that are not bounded.
if (!IsStaticShape(upper_bounded_shape)) {
token_map_[call] = Tokens();
return Tokens();
}
// Use the upper bounds of TIR vars as their values. The upper bound shape can still be dynamic
// if the upper bounds of some variables are not provided.
Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values, analyzer_);

// Create and set token.
StorageToken token(upper_bounded_shape, sinfo->dtype);
Expand Down Expand Up @@ -583,13 +617,13 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
token2block_.erase(token_to_discard.get());
}

/*! \brief The arithmetic analyzer. */
arith::Analyzer ana_;
/*!
* \brief The context IRModule, used for checking if a callee function is
* a PrimFunc inside the IRModule.
*/
const IRModule& ctx_mod_;
/*! \brief The arithmetic analyzer. */
arith::Analyzer* analyzer_;
/*! \brief The mapping from each token to the binding block where it is created. */
std::unordered_map<const StorageTokenNode*, const BindingBlockNode*> token2block_;
/*! \brief The mapping from each token to the Exprs that are using this token. */
Expand All @@ -612,7 +646,9 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
*/
class StorageAllocator : public StorageAllocatorBaseVisitor {
public:
explicit StorageAllocator(std::unordered_map<const ExprNode*, Tokens> token_map) {
explicit StorageAllocator(std::unordered_map<const ExprNode*, Tokens> token_map,
arith::Analyzer* analyzer)
: allocator_(analyzer) {
this->token_map_ = std::move(token_map);
}

Expand Down Expand Up @@ -797,7 +833,7 @@ class StorageAllocationRewriter : public ExprMutator {
Var storage_var{nullptr};
auto it_token = token2storage_var_.find(token.get());
if (it_token == token2storage_var_.end()) {
ShapeExpr size({tir::make_const(DataType::Int(64), token->bytes)});
ShapeExpr size({token->bytes});
PrimValue virtual_device_index = runtime_device_index;
std::string storage_scope = "global";
DataType dtype = token->dtype;
Expand Down Expand Up @@ -868,10 +904,13 @@ class StorageAllocationRewriter : public ExprMutator {
};

IRModule StaticPlanBlockMemory(IRModule mod) {
arith::Analyzer ana;

// Step 1. Initialize.
std::unordered_map<const ExprNode*, Tokens> token_map = StorageAllocatorInit::Initialize(mod);
std::unordered_map<const ExprNode*, Tokens> token_map =
StorageAllocatorInit::Initialize(mod, &ana);
// Step 2. Collect the memory allocation info.
StorageAllocator allocator(std::move(token_map));
StorageAllocator allocator(std::move(token_map), &ana);
allocator.Allocate(mod);
// Step 3. Rewrite the function.
StorageAllocationRewriter rewriter(std::move(mod), //
Expand Down
35 changes: 17 additions & 18 deletions tests/python/relax/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,29 +1226,27 @@ def expected(
lv1: R.Tensor((640, 1280), dtype="float32") = R.permute_dims(lv, axes=None)
lv2: R.Tensor((2, 1024, 1280), dtype="float32") = R.matmul(x1, lv1, out_dtype="void")
lv3: R.Tuple(
R.Tensor((2, 640, 1280), dtype="float32"),
R.Tensor((2, 384, 1280), dtype="float32"),
R.Tensor((2, 0, 1280), dtype="float32"),
) = R.split(lv2, indices_or_sections=[640, 1280], axis=1)
lv0: R.Tensor((2, 640, 1280), dtype="float32") = lv3[0]
lv1_1: R.Tensor((2, 384, 1280), dtype="float32") = lv3[1]
R.Tensor((2, 1024, 640), dtype="float32"),
R.Tensor((2, 1024, 640), dtype="float32"),
) = R.split(lv2, indices_or_sections=[640], axis=-1)
lv0: R.Tensor((2, 1024, 640), dtype="float32") = lv3[0]
lv1_1: R.Tensor((2, 1024, 640), dtype="float32") = lv3[1]
lv_1: R.Tensor((1280, 640), dtype="float32") = R.concat((w2, w3), axis=0)
lv1_2: R.Tensor((640, 1280), dtype="float32") = R.permute_dims(lv_1, axes=None)
lv2_1: R.Tensor((2, 1024, 1280), dtype="float32") = R.matmul(
x2, lv1_2, out_dtype="void"
)
lv3_1: R.Tuple(
R.Tensor((2, 640, 1280), dtype="float32"),
R.Tensor((2, 384, 1280), dtype="float32"),
R.Tensor((2, 0, 1280), dtype="float32"),
) = R.split(lv2_1, indices_or_sections=[640, 1280], axis=1)
lv2_1_1: R.Tensor((2, 640, 1280), dtype="float32") = lv3_1[0]
lv3_1_1: R.Tensor((2, 384, 1280), dtype="float32") = lv3_1[1]
R.Tensor((2, 1024, 640), dtype="float32"),
R.Tensor((2, 1024, 640), dtype="float32"),
) = R.split(lv2_1, indices_or_sections=[640], axis=-1)
lv2_1_1: R.Tensor((2, 1024, 640), dtype="float32") = lv3_1[0]
lv3_1_1: R.Tensor((2, 1024, 640), dtype="float32") = lv3_1[1]
out: R.Tuple(
R.Tensor((2, 640, 1280), dtype="float32"),
R.Tensor((2, 384, 1280), dtype="float32"),
R.Tensor((2, 640, 1280), dtype="float32"),
R.Tensor((2, 384, 1280), dtype="float32"),
R.Tensor((2, 1024, 640), dtype="float32"),
R.Tensor((2, 1024, 640), dtype="float32"),
R.Tensor((2, 1024, 640), dtype="float32"),
R.Tensor((2, 1024, 640), dtype="float32"),
) = (lv0, lv1_1, lv2_1_1, lv3_1_1)
R.output(out)
return out
Expand All @@ -1267,9 +1265,9 @@ def rewriter(matchings, _):

concat = R.concat([w1, w2], axis=0)
matmul = R.matmul(inp, R.permute_dims(concat))
sections = [w1.struct_info.shape[0], w1.struct_info.shape[0] + w2.struct_info.shape[0]]
sections = [w1.struct_info.shape[0]]

chunks = R.split(matmul, sections, 1)
chunks = R.split(matmul, sections, -1)

return {
matchings[matmul1]: chunks[0],
Expand All @@ -1282,6 +1280,7 @@ def rewriter(matchings, _):
# make sure it builds
mod = tvm.IRModule()
mod["main"] = rewritten
print(mod)

rx.build(mod, target="llvm")

Expand Down
47 changes: 39 additions & 8 deletions tests/python/relax/test_transform_static_plan_block_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,9 +701,34 @@ def main(x: R.Tensor(("m", "n"), "float32")):
y: R.Tensor((m, n), dtype="float32") = alloc
return x

# The pass does no change.
@tvm.script.ir_module
class Expected:
@T.prim_func
def exp(var_A: T.handle, var_B: T.handle):
m = T.int64()
n = T.int64()
A = T.match_buffer(var_A, (m, n), "float32")
B = T.match_buffer(var_B, (m, n), "float32")
T.evaluate(0)

@R.function
def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"):
m = T.int64()
n = T.int64()
R.func_attr({"relax.force_pure": True})
cls = Expected
storage: R.Object = R.memory.alloc_storage(
R.shape([4 * (m * n)]), R.prim_value(0), R.str("global"), R.dtype("float32")
)
alloc: R.Tensor((m, n), dtype="float32") = R.memory.alloc_tensor(
storage, R.prim_value(0), R.shape([m, n]), R.dtype("float32")
)
_: R.Tuple = cls.exp(x, alloc)
y: R.Tensor((m, n), dtype="float32") = alloc
return x

mod = relax.transform.StaticPlanBlockMemory()(Module)
tvm.ir.assert_structural_equal(mod, Module)
tvm.ir.assert_structural_equal(mod, Expected)


def test_zero_reference():
Expand Down Expand Up @@ -1198,7 +1223,10 @@ def main(s: R.Shape(["n", "m"])) -> R.Tensor(("n", "m"), dtype="float32"):
alloc2: R.Tensor((n, m), dtype="float32") = R.builtin.alloc_tensor(R.shape([n, m]), R.dtype("float32"), R.prim_value(0))
_2: R.Tuple = cls.tir_exp(lv2, alloc2)
lv3: R.Tensor((n, m), dtype="float32") = alloc2
return lv3
alloc3: R.Tensor((n, m), dtype="float32") = R.builtin.alloc_tensor(R.shape([n, m]), R.dtype("float32"), R.prim_value(0))
_3: R.Tuple = cls.tir_exp(lv3, alloc3)
lv4: R.Tensor((n, m), dtype="float32") = alloc3
return lv4

@I.ir_module
class Expected:
Expand All @@ -1216,19 +1244,22 @@ def main(s: R.Shape(["n", "m"])) -> R.Tensor(("n", "m"), dtype="float32"):
m = T.int64()
R.func_attr({"relax.force_pure": True, "tir_var_upper_bound": {"n": 20}})
cls = Expected
storage: R.Object = R.memory.alloc_storage(R.shape([20 * m * 4]), R.prim_value(0), R.str("global"), R.dtype("float32"))
storage: R.Object = R.memory.alloc_storage(R.shape([80 * m]), R.prim_value(0), R.str("global"), R.dtype("float32"))
alloc: R.Tensor((n, m), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([n, m]), R.dtype("float32"))
_: R.Tuple = R.vm.call_tir_dyn(cls.tir_full, (alloc, R.shape([n, m])))
full: R.Tensor((n, m), dtype="float32") = alloc
storage1: R.Object = R.memory.alloc_storage(R.shape([20 * m * 4]), R.prim_value(0), R.str("global"), R.dtype("float32"))
storage1: R.Object = R.memory.alloc_storage(R.shape([80 * m]), R.prim_value(0), R.str("global"), R.dtype("float32"))
alloc1: R.Tensor((n, m), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([n, m]), R.dtype("float32"))
_1: R.Tuple = cls.tir_exp(full, alloc1)
lv2: R.Tensor((n, m), dtype="float32") = alloc1
storage2: R.Object = R.memory.alloc_storage(R.shape([20 * m * 4]), R.prim_value(0), R.str("global"), R.dtype("float32"))
alloc2: R.Tensor((n, m), dtype="float32") = R.memory.alloc_tensor(storage2, R.prim_value(0), R.shape([n, m]), R.dtype("float32"))
alloc2: R.Tensor((n, m), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([n, m]), R.dtype("float32"))
_2: R.Tuple = cls.tir_exp(lv2, alloc2)
lv3: R.Tensor((n, m), dtype="float32") = alloc2
return lv3
storage2: R.Object = R.memory.alloc_storage(R.shape([20 * m * 4]), R.prim_value(0), R.str("global"), R.dtype("float32"))
alloc3: R.Tensor((n, m), dtype="float32") = R.memory.alloc_tensor(storage2, R.prim_value(0), R.shape([n, m]), R.dtype("float32"))
_3: R.Tuple = cls.tir_exp(lv3, alloc3)
lv4 = alloc3
return lv4
# fmt: on

mod = relax.transform.StaticPlanBlockMemory()(Module)
Expand Down