Skip to content

Commit

Permalink
[Relax] Fix EliminiateCommonSubexpr removing alloc tensor (#16852)
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Apr 6, 2024
1 parent 3e802d1 commit a156181
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 5 deletions.
15 changes: 10 additions & 5 deletions src/relax/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,8 @@ RELAY_REGISTER_OP("relax.builtin.alloc_tensor")
"The storage scope of the storage to allocate. Default is global.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAllocateTensor)
// memory allocation isn't considered a "visible effect" as far as purity is concerned
.set_attr<Bool>("FPurity", Bool(true));
.set_attr<Bool>("FPurity", Bool(true))
.set_attr<Bool>("TAllocator", Bool(true));

Expr MakeAllocTensor(Expr shape, DataTypeImm dtype, PrimValue runtime_device_index,
StringImm storage_scope) {
Expand All @@ -875,7 +876,8 @@ RELAY_REGISTER_OP("relax.memory.alloc_storage")
.add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.")
.set_attr<FInferStructInfo>("FInferStructInfo", ReturnObjectStructInfo)
// memory allocation isn't considered a "visible effect" as far as purity is concerned
.set_attr<Bool>("FPurity", Bool(true));
.set_attr<Bool>("FPurity", Bool(true))
.set_attr<Bool>("TAllocator", Bool(true));

Expr MakeAllocStorage(Expr size, PrimValue virtual_device_index, StringImm storage_scope,
DataTypeImm dtype) {
Expand Down Expand Up @@ -906,7 +908,8 @@ RELAY_REGISTER_OP("relax.memory.alloc_tensor")
.add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoMemAllocTensor)
// memory allocation isn't considered a "visible effect" as far as purity is concerned
.set_attr<Bool>("FPurity", Bool(true));
.set_attr<Bool>("FPurity", Bool(true))
.set_attr<Bool>("TAllocator", Bool(true));

Expr MakeMemAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm dtype) {
static const Op& op = Op::Get("relax.memory.alloc_tensor");
Expand Down Expand Up @@ -960,7 +963,8 @@ RELAY_REGISTER_OP("relax.vm.alloc_storage")
"The storage scope of the storage to allocate. Default is global.")
.set_attr<FInferStructInfo>("FInferStructInfo", ReturnObjectStructInfo)
// memory allocation isn't considered a "visible effect" as far as purity is concerned
.set_attr<Bool>("FPurity", Bool(true));
.set_attr<Bool>("FPurity", Bool(true))
.set_attr<Bool>("TAllocator", Bool(true));

Expr MakeVMAllocStorage(Expr size, PrimValue runtime_device_index, DataTypeImm dtype,
StringImm storage_scope) {
Expand Down Expand Up @@ -998,7 +1002,8 @@ RELAY_REGISTER_OP("relax.vm.alloc_tensor")
.add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoVMAllocTensor)
// memory allocation isn't considered a "visible effect" as far as purity is concerned
.set_attr<Bool>("FPurity", Bool(true));
.set_attr<Bool>("FPurity", Bool(true))
.set_attr<Bool>("TAllocator", Bool(true));

Expr MakeVMAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm dtype) {
static const Op& op = Op::Get("relax.vm.alloc_tensor");
Expand Down
15 changes: 15 additions & 0 deletions src/relax/transform/eliminate_common_subexpr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ class CommonSubexprEliminator : public ExprMutator {
} else if (ContainsImpureCall(bound_value)) {
VLOG(1) << "Since the expression is impure, cannot de-duplicate " << bound_value;

} else if (IsAllocatorCall(bound_value)) {
VLOG(1) << "Skip allocator calls";
} else if (auto it = expr_replacements_.find(lookup_key);
it != expr_replacements_.end() && it->second.size()) {
VLOG(1) << "Value " << bound_value << " has previously been bound as " << it->second[0]
Expand Down Expand Up @@ -186,6 +188,19 @@ class CommonSubexprEliminator : public ExprMutator {
return clean_mutator.VisitExpr(expr);
}

bool IsAllocatorCall(const Expr& expr) {
static const auto& allocator_attr_map = Op::GetAttrMap<Bool>("TAllocator");
if (const auto* call = expr.as<CallNode>()) {
if (const auto* op = call->op.as<OpNode>()) {
bool is_allocator = allocator_attr_map.get(GetRef<Op>(op), Bool(false))->value;
if (is_allocator) {
return true;
}
}
}
return false;
}

bool call_only_{false};
std::unordered_map<ReplacementKey, std::vector<Var>> expr_replacements_;
};
Expand Down
32 changes: 32 additions & 0 deletions tests/python/relax/test_transform_cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,5 +627,37 @@ def foo(
verify(Before, Expected)


def test_keep_alloc_tensor():
@I.ir_module
class Before:
@R.function
def foo(x: R.Tensor((2, 3), dtype="float32")):
tmp_buf1 = R.builtin.alloc_tensor(R.shape([64]), R.dtype("int32"), R.prim_value(0))
tmp_buf2 = R.builtin.alloc_tensor(R.shape([64]), R.dtype("int32"), R.prim_value(0))
out = R.add(tmp_buf1, tmp_buf2)
return out

Expected = Before

verify(Before, Expected)


def test_keep_alloc_storage():
@I.ir_module
class Before:
@R.function
def foo(x: R.Tensor((2, 3), dtype="float32")):
tmp_storage1 = R.vm.alloc_storage(R.shape([64]), runtime_device_index=0, dtype="uint8")
tmp_buf1 = R.vm.alloc_tensor(tmp_storage1, offset=0, shape=R.shape([64]), dtype="int32")
tmp_storage2 = R.vm.alloc_storage(R.shape([64]), runtime_device_index=0, dtype="uint8")
tmp_buf2 = R.vm.alloc_tensor(tmp_storage2, offset=0, shape=R.shape([64]), dtype="int32")
out = R.add(tmp_buf1, tmp_buf2)
return out

Expected = Before

verify(Before, Expected)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit a156181

Please sign in to comment.