Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 13 additions & 30 deletions src/relax/transform/fuse_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,13 @@ class SymbolicMatcher : ExprFunctor<bool(const PrimExpr& n, const PrimExpr& othe
/*!
* \brief Substitute a given source buffer with a given target buffer in statements or expressions.
*/
class FuseTIRBufferSubstitor : private StmtExprMutator {
class FuseTIRBufferSubstitutor : private StmtExprMutator {
public:
explicit FuseTIRBufferSubstitor(const Map<Buffer, Buffer>& buffer_map,
const Map<Var, Var>& var_map) {
explicit FuseTIRBufferSubstitutor(const Map<Buffer, Buffer>& buffer_map,
const Map<Var, Var>& var_map) {
buffer_remap_ = buffer_map;
var_remap_ = var_map;
for (const auto& kv : buffer_map) {
const Buffer& src = kv.first;
const Buffer& tgt = kv.second;
for (const auto& [src, tgt] : buffer_map) {
var_remap_.Set(src->data, tgt->data);
}
}
Expand Down Expand Up @@ -246,8 +244,6 @@ class FuseTIRBufferSubstitor : private StmtExprMutator {
Map<tir::Buffer, tir::Buffer> buffer_remap_;
/*! \brief Mapping from src tir var to tgt var. */
Map<tir::Var, tir::Var> var_remap_;
/*! \brief The structural equality checker */
StructuralEqual structural_equal_;

Array<tir::BufferRegion> UnionAccessRegion(const Array<BufferRegion>& regions) const {
// For now we only allow Buffer access the same elements.
Expand All @@ -262,8 +258,6 @@ class FuseTIRBufferSubstitor : private StmtExprMutator {
if (it == buffer_region_set.end()) {
ret.push_back(region);
buffer_region_set[region->buffer.get()] = region->region;
} else {
ICHECK(structural_equal_(region->region, it->second));
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry the diff is noisy due to the unrelated style updates, but this is the only important diff.

}
}

Expand Down Expand Up @@ -351,10 +345,8 @@ class FusedTIRConstructor : public ExprVisitor {
// It's a symbolic shape var, no need to alloc Buffers.
continue;
}
auto ret = CreateParamsAndBuffers(GetStructInfo(relax_param), //
relax_param->name_hint());
const Array<tir::Var>& params = ret.first;
const Array<tir::Buffer>& buffers = ret.second;
auto [params, buffers] = CreateParamsAndBuffers(GetStructInfo(relax_param), //
relax_param->name_hint());
ICHECK_EQ(params.size(), buffers.size());
for (size_t i = 0; i < params.size(); ++i) {
func_info_.buffer_map.Set(params[i], buffers[i]);
Expand Down Expand Up @@ -384,10 +376,8 @@ class FusedTIRConstructor : public ExprVisitor {
// Step 4. Append symbolic vars
const relax::Var& last_relax_param = func->params.back();
if (GetStructInfo(last_relax_param)->IsInstance<ShapeStructInfoNode>()) {
auto ret =
auto [params, buffers] =
CreateParamsAndBuffers(GetStructInfo(last_relax_param), last_relax_param->name_hint());
const Array<tir::Var>& params = ret.first;
const Array<tir::Buffer>& buffers = ret.second;
ICHECK(buffers.empty());
for (size_t i = 0; i < params.size(); ++i) {
func_info_.params.push_back(params[i]);
Expand Down Expand Up @@ -682,9 +672,7 @@ class FusedTIRConstructor : public ExprVisitor {
"list.";
if (index == -1) index = 0;
for (size_t i = 0; i < tuple->fields.size(); ++i) {
auto ret = CreateParamsAndBuffers(tuple->fields[i], name_hint, index);
const Array<tir::Var>& ret_params = ret.first;
const Array<tir::Buffer>& ret_buffers = ret.second;
auto [ret_params, ret_buffers] = CreateParamsAndBuffers(tuple->fields[i], name_hint, index);
ICHECK_EQ(ret_params.size(), ret_buffers.size());
// Adding tuple field results to the end of params and buffers.
params.insert(params.end(), ret_params.begin(), ret_params.end());
Expand Down Expand Up @@ -714,19 +702,18 @@ class FusedTIRConstructor : public ExprVisitor {
tir::PrimFunc ConstructFunc() {
Map<String, ObjectRef> attr_map;
attr_map.Set("tir.noalias", tir::const_true());
tir::FuseTIRBufferSubstitor substitor(func_info_.buffer_subst_map,
func_info_.symbolic_var_remap);
tir::FuseTIRBufferSubstitutor subst(func_info_.buffer_subst_map, func_info_.symbolic_var_remap);
ICHECK(func_info_.global_name != "fused");
// Remove output buffers from func_info_.alloc_buffers
Array<tir::Buffer> alloc_buffers;
for (const tir::Buffer& buf : func_info_.alloc_buffers) {
if (func_info_.output_buffers.count(buf.get()) == 0) {
alloc_buffers.push_back(substitor.SubstituteAllocatedBuffer(buf));
alloc_buffers.push_back(subst.SubstituteAllocatedBuffer(buf));
}
}
tir::Stmt body = tir::BlockNameDeduplicator()(tir::SeqStmt::Flatten(func_info_.bodies));

body = substitor.Substitute(body);
body = subst.Substitute(body);
body = tir::Block({}, {}, {}, "root", std::move(body), NullOpt, alloc_buffers);
body = tir::BlockRealize({}, Bool(true), Downcast<tir::Block>(body));
tir::PrimFunc func(func_info_.params, body, VoidType(), func_info_.buffer_map,
Expand Down Expand Up @@ -804,9 +791,7 @@ class TIRFuseMutator : public ExprMutator {
// Since TIRFuseMutator will delete bunch of PrimFunc, we create an empty block builder.
TIRFuseMutator mutator(mod);
// Step 1. Fuse all primitive relax functions, store the result in `fused_tir_funcs_`
for (const auto& kv : mod->functions) {
const GlobalVar& gv = kv.first;
const BaseFunc& func = kv.second;
for (const auto& [gv, func] : mod->functions) {
// Only fuse primitive relax functions
if (func->IsInstance<relax::FunctionNode>() && func->HasNonzeroAttr(attr::kPrimitive)) {
tir::PrimFunc fused_tir = FusedTIRConstructor::GetFusedTIR(mod, gv);
Expand All @@ -816,9 +801,7 @@ class TIRFuseMutator : public ExprMutator {

// Step 2. Update all non-primitive relax functions and add it, with the dependent function,
// into the new IRModule
for (const auto& kv : mod->functions) {
const GlobalVar& gv = kv.first;
const BaseFunc& func = kv.second;
for (const auto& [gv, func] : mod->functions) {
if (func->IsInstance<relax::FunctionNode>() && !func->HasNonzeroAttr(attr::kPrimitive)) {
relax::Function update_func = Downcast<Function>(mutator.VisitExpr(func));
mutator.builder_->AddFunction(update_func, gv->name_hint);
Expand Down
120 changes: 120 additions & 0 deletions tests/python/relax/test_transform_fuse_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,5 +883,125 @@ def main(
_check(Before, Expected)


def test_same_buffer_multiple_read():
    """FuseTIR regression test: a fused call passes the SAME relax tensor
    (``inp_0``) as two distinct PrimFunc parameters with DIFFERENT access
    patterns. The pass must union the two access regions instead of
    requiring them to be structurally equal."""

    @I.ir_module
    class Module:
        @T.prim_func
        def concatenate(
            rxplaceholder: T.Buffer((T.int64(1), T.int64(4), T.int64(64), T.int64(64)), "float32"),
            rxplaceholder_1: T.Buffer(
                (T.int64(1), T.int64(4), T.int64(64), T.int64(64)), "float32"
            ),
            T_concat: T.Buffer((T.int64(2), T.int64(4), T.int64(64), T.int64(64)), "float32"),
        ):
            T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), T.int64(64), T.int64(64)):
                with T.block("T_concat"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    # NOTE: after fusion, rxplaceholder and rxplaceholder_1 are both
                    # bound to the same buffer inp_0, but with different access regions.
                    T.reads(
                        rxplaceholder_1[v_ax0 - T.int64(1), v_ax1, v_ax2, v_ax3],
                        rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3],
                    )
                    T.writes(T_concat[v_ax0, v_ax1, v_ax2, v_ax3])
                    T_concat[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(
                        T.int64(1) <= v_ax0,
                        rxplaceholder_1[v_ax0 - T.int64(1), v_ax1, v_ax2, v_ax3],
                        rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3],
                    )

        @T.prim_func
        def transpose2(
            rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(64), T.int64(64)), "float32"),
            T_transpose: T.Buffer((T.int64(2), T.int64(64), T.int64(64), T.int64(4)), "float32"),
        ):
            T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(64), T.int64(64), T.int64(4)):
                with T.block("T_transpose"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(rxplaceholder[v_ax0, v_ax3, v_ax1, v_ax2])
                    T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
                    T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[
                        v_ax0, v_ax3, v_ax1, v_ax2
                    ]

        @R.function
        def fused_concatenate_transpose2(
            inp_0: R.Tensor((1, 4, 64, 64), dtype="float32")
        ) -> R.Tensor((2, 64, 64, 4), dtype="float32"):
            R.func_attr({"Primitive": 1})
            cls = Module
            with R.dataflow():
                # inp_0 is deliberately passed twice: both PrimFunc params alias it.
                lv = R.call_tir(
                    cls.concatenate,
                    (inp_0, inp_0),
                    out_sinfo=R.Tensor((2, 4, 64, 64), dtype="float32"),
                )
                gv = R.call_tir(
                    cls.transpose2, (lv,), out_sinfo=R.Tensor((2, 64, 64, 4), dtype="float32")
                )
                R.output(gv)
            return gv

        @R.function
        def main(
            inp_0: R.Tensor((1, 4, 64, 64), dtype="float32")
        ) -> R.Tensor((2, 64, 64, 4), dtype="float32"):
            R.func_attr({"num_input": 3})
            cls = Module
            with R.dataflow():
                lv = cls.fused_concatenate_transpose2(inp_0)
                R.output(lv)
            return lv

    @I.ir_module
    class Expected:
        @T.prim_func
        def fused_concatenate_transpose2(
            inp_0: T.Buffer((T.int64(1), T.int64(4), T.int64(64), T.int64(64)), "float32"),
            T_transpose_handle_intermediate: T.Buffer(
                (T.int64(2), T.int64(64), T.int64(64), T.int64(4)), "float32"
            ),
        ):
            T.func_attr({"tir.noalias": T.bool(True)})
            T_concat_handle_intermediate = T.alloc_buffer(
                (T.int64(2), T.int64(4), T.int64(64), T.int64(64))
            )
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), T.int64(64), T.int64(64)):
                with T.block("T_concat"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(inp_0[v_ax0 - T.int64(1), v_ax1, v_ax2, v_ax3])
                    T.writes(T_concat_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                    T_concat_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(
                        T.int64(1) <= v_ax0,
                        inp_0[v_ax0 - T.int64(1), v_ax1, v_ax2, v_ax3],
                        inp_0[v_ax0, v_ax1, v_ax2, v_ax3],
                    )
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(64), T.int64(64), T.int64(4)):
                with T.block("T_transpose"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(T_concat_handle_intermediate[v_ax0, v_ax3, v_ax1, v_ax2])
                    T.writes(T_transpose_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                    T_transpose_handle_intermediate[
                        v_ax0, v_ax1, v_ax2, v_ax3
                    ] = T_concat_handle_intermediate[v_ax0, v_ax3, v_ax1, v_ax2]

        @R.function
        def main(
            inp_0: R.Tensor((1, 4, 64, 64), dtype="float32")
        ) -> R.Tensor((2, 64, 64, 4), dtype="float32"):
            R.func_attr({"num_input": 3})
            cls = Expected
            with R.dataflow():
                lv = R.call_tir(
                    cls.fused_concatenate_transpose2,
                    (inp_0,),
                    out_sinfo=R.Tensor((2, 64, 64, 4), dtype="float32"),
                )
                R.output(lv)
            return lv

    _check(Module, Expected)


# Allow invoking this test file directly (e.g. `python test_transform_fuse_tir.py`);
# tvm.testing.main() dispatches to the test runner.
if __name__ == "__main__":
    tvm.testing.main()