[Bugfix][Transform] Keep private non-primitive functions in FuseTIR (#16565)

Prior to this commit, private non-primitive relax functions would be
discarded by `FuseTIR`.  If any calls to these functions existed, the
resulting `IRModule` would be ill-formed.  This commit updates
`FuseTIR` so that it only updates functions annotated with
`attr::kPrimitive`, along with the call sites of those functions.

To retain backwards compatibility, `DeadCodeElimination` is applied as
a post-processing step.
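
At the user level, the fix can be seen with a minimal sketch in
TVMScript (the module below is illustrative, mirroring the regression
test added in this commit; it is not part of the commit itself):

import tvm
from tvm import relax
from tvm.script import ir as I, relax as R


@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"):
        cls = Module
        with R.dataflow():
            gv = cls.helper(x)
            R.output(gv)
        return gv

    # Private and *not* annotated with attr::kPrimitive.  Previously,
    # FuseTIR discarded this function while the call in `main` remained.
    @R.function(private=True)
    def helper(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"):
        with R.dataflow():
            gv = R.add(x, x)
            R.output(gv)
        return gv


# With this commit, `helper` is preserved and the result stays well-formed.
after = relax.transform.FuseTIR()(Module)
assert relax.analysis.well_formed(after)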
Lunderberg committed Feb 14, 2024
1 parent 2b813ec commit 274c368
Showing 3 changed files with 187 additions and 113 deletions.
5 changes: 4 additions & 1 deletion include/tvm/relax/transform.h
@@ -578,9 +578,12 @@ TVM_DLL Pass ConvertToDataflow(int min_size = 2);
*
* Any binding blocks that are left empty will be removed by the normalizer.
*
* \param entry_functions Names of functions that should be considered
* as entry points, in addition to any externally exposed functions.
*
* \return The Pass.
*/
TVM_DLL Pass DeadCodeElimination(Array<runtime::String> entry_functions);
TVM_DLL Pass DeadCodeElimination(Array<runtime::String> entry_functions = {});

/*!
* \brief Pass that changes calls to operators that can be done in-place
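The new default argument makes the pass constructible with no
arguments.  A minimal sketch of both call forms from the Python side
(the toy module is illustrative, not part of this commit):

from tvm import relax
from tvm.script import ir as I, relax as R


@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((2,), "float32")) -> R.Tensor((2,), "float32"):
        return x

    @R.function(private=True)
    def unused(x: R.Tensor((2,), "float32")) -> R.Tensor((2,), "float32"):
        return x


# Default: every externally exposed function (here `main`) is an entry
# point, so the unreferenced private function is removed.
after = relax.transform.DeadCodeElimination()(Module)
assert "unused" not in [gv.name_hint for gv in after.get_global_vars()]

# Explicit names are entry points in addition to the externally exposed
# functions, so `unused` is kept alive even though it is private.
kept = relax.transform.DeadCodeElimination(["unused"])(Module)
assert "unused" in [gv.name_hint for gv in kept.get_global_vars()]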
235 changes: 123 additions & 112 deletions src/relax/transform/fuse_tir.cc
@@ -961,57 +961,73 @@ std::vector<size_t> GetTupleAccessedIndices(const FunctionNode* func, const Var&
*/
class TIRFuseMutator : public ExprMutator {
public:
static IRModule Transform(const IRModule& mod) {
Map<GlobalVar, BaseFunc> funcs_to_keep;
for (const auto& [gv, func] : mod->functions) {
// 1. If a TIR function has a global symbol, we keep the function.
// 2. Always keep ExternFunc.
if (const auto* prim_func = func.as<tir::PrimFuncNode>()) {
if (prim_func->GetAttr<String>("global_symbol").defined()) {
funcs_to_keep.Set(gv, func);
static IRModule Transform(IRModule mod) {
// Collect all primitive relax functions
Map<GlobalVar, Function> primitive_relax;
for (const auto& [gvar, base_func] : mod->functions) {
// Only fuse primitive relax functions
if (base_func->HasNonzeroAttr(attr::kPrimitive)) {
if (auto func = base_func.as<relax::Function>()) {
primitive_relax.Set(gvar, func.value());
}
} else if (func->IsInstance<ExternFuncNode>()) {
funcs_to_keep.Set(gv, func);
}
}

if (primitive_relax.empty()) {
return mod;
}

mod.CopyOnWrite();

IRModule updates;
std::unordered_map<GlobalVar, Replacement, ObjectPtrHash, ObjectPtrEqual> replacements;

// Since TIRFuseMutator will delete a number of PrimFuncs, we create an empty block builder.
TIRFuseMutator mutator(mod);

// Step 1. Fuse all primitive relax functions, store the result in `fused_tir_funcs_`
for (const auto& [gv, func] : mod->functions) {
// Only fuse primitive relax functions
if (func->IsInstance<relax::FunctionNode>() && func->HasNonzeroAttr(attr::kPrimitive)) {
const auto& [prim_func, indices] = FusedTIRConstructor::GetFusedTIR(mod, gv);
mutator.fused_tir_funcs_.Set(gv, prim_func);
if (!indices.empty()) {
mutator.inplace_indices_.Set(gv, indices);
}
}
for (const auto& [old_gvar, func] : primitive_relax) {
const auto& [prim_func, indices] = FusedTIRConstructor::GetFusedTIR(mod, old_gvar);

GlobalVar new_gvar(old_gvar->name_hint);
UpdateStructInfo(new_gvar,
FuncStructInfo::OpaqueFunc(StructInfoFromType(prim_func->ret_type)));

mod->Remove(old_gvar);
updates->Add(new_gvar, prim_func);
replacements[old_gvar] = Replacement{new_gvar, func, indices};
}

TIRFuseMutator mutator(replacements);

// Step 2. Update all non-primitive relax functions and add them,
// along with their dependent functions, into the new IRModule.

for (const auto& [gv, func] : mod->functions) {
if (func->IsInstance<relax::FunctionNode>() && !func->HasNonzeroAttr(attr::kPrimitive)) {
if (func->IsInstance<relax::FunctionNode>()) {
ICHECK(!func->HasNonzeroAttr(attr::kPrimitive))
<< "Module should not contain any primitive relax functions at this point";
relax::Function update_func = Downcast<Function>(mutator.VisitExpr(func));
mutator.builder_->AddFunction(update_func, gv->name_hint);
}
}

// Step 3. Add all functions that need to be kept.
auto modified_mod = mutator.builder_->GetContextIRModule();
for (const auto& [gv, func] : funcs_to_keep) {
if (!modified_mod->ContainGlobalVar(gv->name_hint)) {
modified_mod->Add(gv, func);
if (!update_func.same_as(func)) {
updates->Add(gv, update_func);
}
}
}

// Step 4. Copy over module attributes and return.
if (mod->attrs.defined()) modified_mod = WithAttrs(modified_mod, mod->attrs->dict);
return modified_mod;
// Step 4. Copy over updated functions and return.
mod->Update(updates);
return mod;
}

private:
explicit TIRFuseMutator(const IRModule& mod) : mod_(mod) {}
struct Replacement {
GlobalVar fused_tir_gvar;
Function original_function;
Array<Integer> inplace_indices;
};

explicit TIRFuseMutator(
std::unordered_map<GlobalVar, Replacement, ObjectPtrHash, ObjectPtrEqual> replacements)
: replacements_(replacements) {}

using ExprMutator::VisitExpr_;

@@ -1035,92 +1051,86 @@ class TIRFuseMutator : public ExprMutator {

Call call = Downcast<Call>(builder_->Normalize(ExprMutator::VisitExpr_(op)));

if (call->op->IsInstance<GlobalVarNode>()) {
// Case 1. It is a relax cross function call
GlobalVar old_gv = Downcast<GlobalVar>(call->op);
auto relax_func = Downcast<Function>(mod_->Lookup(old_gv));
auto it = fused_tir_funcs_.find(old_gv);
if (it != fused_tir_funcs_.end()) {
const tir::PrimFunc& fused_tir = (*it).second;
// Case 1.1. It calls a primitive relax function, update the call into a call_tir
GlobalVar fused_tir_gv = this->builder_->AddFunction(fused_tir, old_gv->name_hint);
// Step a. Flatten all args since call_tir does not support Tuple value.
Array<Expr> arg_list;
Array<PrimExpr> tir_vars;
for (size_t i = 0; i < call->args.size(); ++i) {
auto arg = call->args[i];
auto sinfo = GetStructInfo(arg);

ICHECK(!relax_func->params[i]->struct_info_->IsInstance<TupleStructInfoNode>() &&
!sinfo.as<TupleStructInfoNode>())
<< "InternalError: "
<< "All tuple parameters should be expanded before this point in FuseTIR. "
<< "However, argument " << arg << " with struct info " << arg->struct_info_
<< " is passed as argument " << i << " to Primitive Relax function " << old_gv
<< ", which expects parameter " << relax_func->params[i] << " to have struct info "
<< relax_func->params[i]->struct_info_;

if (const auto* shape = sinfo.as<ShapeStructInfoNode>()) {
CHECK(shape->values.defined())
<< "FuseTIR requires all shape input has struct_info value.";
for (const PrimExpr& prim_value : shape->values.value()) {
CHECK(prim_value->IsInstance<tir::VarNode>())
<< "All shape inputs are expected to be single tir var.";
tir_vars.push_back(prim_value);
}
} else if (const auto* prim_value = sinfo.as<PrimStructInfoNode>()) {
CHECK(prim_value->value.defined())
<< "FuseTIR requires all R.Prim arguments to have a known value.";
PrimExpr expr = prim_value->value.value();
CHECK(expr->IsInstance<tir::VarNode>()) << "FuseTIR currently requires all R.Prim "
"arguments to provide a single tir::Var.";
tir_vars.push_back(expr);

} else {
arg_list.push_back(arg);
}
}
// Step b. Create call_tir or call_tir_inplace
Array<Expr> call_args = {fused_tir_gv, Tuple(arg_list)};
if (!tir_vars.empty()) {
call_args.push_back(ShapeExpr(tir_vars));
}
Op call_op = call_tir_op_;
Attrs call_attrs = call->attrs;
if (auto it = inplace_indices_.find(old_gv); it != inplace_indices_.end()) {
call_op = call_tir_inplace_op_;
auto inplace_attrs = make_object<CallTIRInplaceAttrs>();
inplace_attrs->inplace_indices = (*it).second;
call_attrs = Attrs(inplace_attrs);
auto opt_gvar = call->op.as<GlobalVar>();
if (!opt_gvar) {
// Case 1. The Call isn't a relax-to-relax function call, no need to update.
return call;
}
GlobalVar old_gvar = opt_gvar.value();

auto it = replacements_.find(old_gvar);
if (it == replacements_.end()) {
// Case 2. The callee function is not a primitive relax
// function, no need to update.
return call;
}
const Replacement& replacement = it->second;
const GlobalVar& fused_tir_gv = replacement.fused_tir_gvar;
const Function& relax_func = replacement.original_function;

// Case 3. It calls a primitive relax function, update the call
// into a call_tir or call_tir_inplace.

// Step a. Collect all relax/symbolic arguments. Tuple arguments
// are not supported by PrimFunc, so this step verifies that
// ExpandTupleArguments has already removed them.
Array<Expr> arg_list;
Array<PrimExpr> tir_vars;
for (size_t i = 0; i < call->args.size(); ++i) {
auto arg = call->args[i];
auto sinfo = GetStructInfo(arg);

ICHECK(!relax_func->params[i]->struct_info_->IsInstance<TupleStructInfoNode>() &&
!sinfo.as<TupleStructInfoNode>())
<< "InternalError: "
<< "All tuple parameters should be expanded before this point in FuseTIR. "
<< "However, argument " << arg << " with struct info " << arg->struct_info_
<< " is passed as argument " << i << " to Primitive Relax function " << old_gvar
<< ", which expects parameter " << relax_func->params[i] << " to have struct info "
<< relax_func->params[i]->struct_info_;

if (const auto* shape = sinfo.as<ShapeStructInfoNode>()) {
CHECK(shape->values.defined()) << "FuseTIR requires all shape inputs to have a defined struct_info value.";
for (const PrimExpr& prim_value : shape->values.value()) {
CHECK(prim_value->IsInstance<tir::VarNode>())
<< "All shape inputs are expected to be single tir var.";
tir_vars.push_back(prim_value);
}
return Call(call_op, call_args, call_attrs, {GetStructInfo(call)});
} else if (const auto* prim_value = sinfo.as<PrimStructInfoNode>()) {
CHECK(prim_value->value.defined())
<< "FuseTIR requires all R.Prim arguments to have a known value.";
PrimExpr expr = prim_value->value.value();
CHECK(expr->IsInstance<tir::VarNode>()) << "FuseTIR currently requires all R.Prim "
"arguments to provide a single tir::Var.";
tir_vars.push_back(expr);

} else {
// Case 1.2. The callee function is not primitive, nothing to do.
return call;
}
} else if (call->op == call_tir_op_ || call->op == call_tir_inplace_op_) {
// Case 2. It is a call_tir or call_tir_inplace, re-emit the PrimFunc.
if (const auto* gv = call->args[0].as<GlobalVarNode>()) {
tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(GetRef<GlobalVar>(gv)));
GlobalVar new_gv = this->builder_->AddFunction(func, gv->name_hint);
Array<Expr> new_args = call->args;
new_args.Set(0, new_gv);
return Call(call->op, new_args, call->attrs, call->sinfo_args, call->span);
arg_list.push_back(arg);
}
}

// Case 3. CallNode in other types. Leave it as it is.
return call;
// Step b. Create call_tir or call_tir_inplace
Array<Expr> call_args = {fused_tir_gv, Tuple(arg_list)};
if (!tir_vars.empty()) {
call_args.push_back(ShapeExpr(tir_vars));
}
Op call_op = call_tir_op_;
Attrs call_attrs = call->attrs;
if (replacement.inplace_indices.size()) {
call_op = call_tir_inplace_op_;
auto inplace_attrs = make_object<CallTIRInplaceAttrs>();
inplace_attrs->inplace_indices = replacement.inplace_indices;
call_attrs = Attrs(inplace_attrs);
}
return Call(call_op, call_args, call_attrs, {GetStructInfo(call)});
}

private:
/*! \brief The IRModule */
const IRModule& mod_;
/*! \brief The map from global var of primitive relax function to generated prim func. */
Map<GlobalVar, tir::PrimFunc> fused_tir_funcs_;
/*! \brief The map from global var of primitive relax function to in-place indices
* (if there are any). */
Map<GlobalVar, Array<Integer>> inplace_indices_;
/*! \brief The map from global var to how it should be replaced
*
* Has one entry for each primitive relax function in the IRModule.
*/
std::unordered_map<GlobalVar, Replacement, ObjectPtrHash, ObjectPtrEqual> replacements_;
};

IRModule FuseTIR(IRModule mod) {
@@ -1142,6 +1152,7 @@ Pass FuseTIR() {
ExpandTupleArguments(),
RemoveUnusedParameters(),
inner_pass,
DeadCodeElimination(),
},
"FuseTIR");
}
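For an end-to-end view of the rewrite implemented above, here is a
minimal sketch (illustrative names, not taken from this commit):
`fused_add` is annotated as primitive, so FuseTIR lowers it to a fused
PrimFunc and rewrites the call in `main` into R.call_tir; the
DeadCodeElimination() step added to the FuseTIR pipeline above then
drops the no-longer-referenced originals.

import tvm
from tvm import relax
from tvm.script import ir as I, relax as R, tir as T


@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((8,), "float32")) -> R.Tensor((8,), "float32"):
        cls = Module
        with R.dataflow():
            gv = cls.fused_add(x)
            R.output(gv)
        return gv

    # Annotated as primitive, in the shape FuseOps would produce.
    @R.function(private=True)
    def fused_add(x: R.Tensor((8,), "float32")) -> R.Tensor((8,), "float32"):
        R.func_attr({"Primitive": 1})
        cls = Module
        with R.dataflow():
            gv = R.call_tir(cls.add, (x,), out_sinfo=R.Tensor((8,), "float32"))
            R.output(gv)
        return gv

    @T.prim_func(private=True)
    def add(
        A: T.Buffer((T.int64(8),), "float32"),
        Out: T.Buffer((T.int64(8),), "float32"),
    ):
        for i in range(T.int64(8)):
            with T.block("add"):
                vi = T.axis.spatial(T.int64(8), i)
                Out[vi] = A[vi] + A[vi]


after = relax.transform.FuseTIR()(Module)
print(after.script())  # `main` now calls the fused PrimFunc via R.call_tir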
60 changes: 60 additions & 0 deletions tests/python/relax/test_transform_fuse_tir.py
@@ -2254,5 +2254,65 @@ def main(
_check(Module, Expected)


def test_private_nonprimitive_func():
"""Input IRModule may contain calls to non-primitive functions
This is a regression test. Prior implementations did not preserve
relax-to-relax function calls.
"""

@I.ir_module
class Before:
@R.function
def main(
input_ids: R.Tensor((1,), dtype="int32"),
input_embeds: R.Tensor((4096, 4096), dtype="float16"),
) -> R.Tensor((1, 4096), dtype="float16"):
cls = Before
with R.dataflow():
gv = cls.fused_func(input_ids, input_embeds)
R.output(gv)
return gv

@R.function(private=True)
def fused_func(
input_ids: R.Tensor((1,), dtype="int32"),
input_embeds: R.Tensor((4096, 4096), dtype="float16"),
) -> R.Tensor((1, 4096), dtype="float16"):
cls = Before
with R.dataflow():
lv = R.call_tir(
cls.add, (input_embeds,), out_sinfo=R.Tensor((4096, 4096), dtype="float16")
)
gv = R.call_tir(
cls.take, (lv, input_ids), out_sinfo=R.Tensor((1, 4096), dtype="float16")
)
R.output(gv)
return gv

@T.prim_func(private=True)
def add(
A: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
Out: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
):
for i, j in T.grid(T.int64(4096), T.int64(4096)):
with T.block("add"):
vi, vj = T.axis.remap("SS", [i, j])
Out[vi, vj] = A[vi, vj] + T.float16(1.0)

@T.prim_func(private=True)
def take(
A: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
B: T.Buffer((T.int64(1),), "int32"),
T_take: T.Buffer((T.int64(1), T.int64(4096)), "float16"),
):
for ax0, ax1 in T.grid(T.int64(1), T.int64(4096)):
with T.block("T_take"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T_take[v_ax0, v_ax1] = A[B[v_ax0], v_ax1]

_check(Before, Before)


if __name__ == "__main__":
tvm.testing.main()
