Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BYOC] Prevent duplicate outputs in subgraph Tuple #5320

Merged
merged 10 commits into from
Apr 15, 2020
226 changes: 125 additions & 101 deletions src/relay/transforms/partition_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,99 +205,13 @@ class Partitioner : public ExprMutator {
// region_function_calls is map that maintains
// (each annotated regions) --> created function

if (region_function_calls.find(region) != region_function_calls.end()) {
// This section is executed only if there are multiple outputs in the
// region Thus, the function is always created and at the end there
// would be a tuple node Therefore, we insert a tuple get item node.

// Use the already created tuple node
auto sg_call = region_function_calls[region];
int index = GetRetIdx(region, GetRef<Call>(call));
CHECK_NE(index, -1);

auto tuple_get_item_ = TupleGetItem(sg_call, index);
tuple_get_item_->checked_type_ = GetRef<Call>(call)->args[0]->checked_type_;
return std::move(tuple_get_item_);
} else {
// First time this region is encountered in the traversal
// Creating the function

Array<Expr> fields;

for (auto ret : region->GetOutputs()) {
auto ret_expr = VisitExpr(Downcast<Call>(ret)->args[0]);
fields.push_back(ret_expr);
}
int index = GetRetIdx(region, GetRef<Call>(call));
CHECK_NE(index, -1);

Array<Var> params;
Array<Expr> param_expr;
std::unordered_map<std::string, runtime::NDArray> params_bind;

for (auto pair : region_args[region]) {
params.push_back(pair.first);
if (const auto* cn = pair.second.as<ConstantNode>()) {
params_bind[pair.first->name_hint()] = cn->data;
} else {
param_expr.push_back(pair.second);
}
}

Function global_region_func;
if (region->GetOutputs().size() == 1) {
// If there are only a single output; no need to add a tuple
global_region_func =
Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
} else {
auto tuple = Tuple(fields);
global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
}

std::string target = call->attrs.as<CompilerAttrs>()->compiler;
std::string name = target + "_" + std::to_string(region->GetID());

global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
runtime::String(name));
global_region_func =
WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
tvm::runtime::String(target));
global_region_func =
WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));

// Constant propagation
if (!params_bind.empty()) {
global_region_func = backend::BindParamsByName(global_region_func, params_bind);
}

std::string fname = name;
CHECK(!module_->ContainGlobalVar(fname))
<< "Global function " << fname << " already exists";
// Create a global function and add it to the IRModule for the region.
// This way we lift the functions that should be handled by external
// codegen to the module scope and rely on the pass manager to prevent
// relay function level passes (i.e. simplify inference and fusion)
// optimizing it.
GlobalVar glob_func(fname);
module_->Add(glob_func, global_region_func);

// The return type of callnode is the same as the type of the
// compiler_end node.
auto ret = Call(glob_func, param_expr);
region_function_calls[region] = ret;

if (region->GetOutputs().size() == 1) {
// If there is only a single output; no need to add a tuplegetitem
// node
return std::move(ret);
} else {
// Add a tuplegetitem node to select this output out of many
auto tuple_get_item_ = TupleGetItem(ret, index);
tuple_get_item_->checked_type_ = GetRef<Call>(call)->args[0]->checked_type_;
return std::move(tuple_get_item_);
}
if (region_function_calls.find(region) == region_function_calls.end()) {
// First time this region is encountered in the traversal.
// Creating the function.
CreateFunction(region, call);
}
// Retrieve this particular output of function.
return GetFunctionOutput(region, GetRef<Call>(call));
}
}

Expand Down Expand Up @@ -456,18 +370,111 @@ class Partitioner : public ExprMutator {
}

/*!
* \brief Get the index of the return(output);
* this is to be used as tuplegetitem idx
* \brief This function is called first time that we encounter a compiler_end
* node to create the function for the subgraph.
*/
int GetRetIdx(AnnotatedRegion sg, const Expr& arg) {
int idx = 0;
for (auto arg_ : sg->GetOutputs()) {
if (arg == arg_) {
return idx;
void CreateFunction(AnnotatedRegion region, const CallNode* call) {
zhiics marked this conversation as resolved.
Show resolved Hide resolved
// Create fields which is a unique list of outputs. Also populate
// region_return_indices_ map which maps parent of compiler_end node to
// corresponding index in fields.
Array<Expr> fields;
int i = 0;
for (auto ret : region->GetOutputs()) {
auto ret_node = Downcast<Call>(ret)->args[0];
// Don't duplicate outputs.
if (!region_return_indices_.count(region) ||
!region_return_indices_[region].count(ret_node)) {
auto ret_expr = VisitExpr(ret_node);
fields.push_back(ret_expr);
region_return_indices_[region][ret_node] = i;
i++;
}
idx++;
}
return -1;

Array<Var> params;
Array<Expr> param_expr;
std::unordered_map<std::string, runtime::NDArray> params_bind;

for (auto pair : region_args[region]) {
params.push_back(pair.first);
if (const auto* cn = pair.second.as<ConstantNode>()) {
params_bind[pair.first->name_hint()] = cn->data;
} else {
param_expr.push_back(pair.second);
}
}

Function global_region_func;
if (fields.size() == 1) {
// If there are only a single output; no need to add a tuple
global_region_func =
Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
} else {
auto tuple = Tuple(fields);
global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
}

std::string target = call->attrs.as<CompilerAttrs>()->compiler;
std::string name = target + "_" + std::to_string(region->GetID());

global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
runtime::String(name));
global_region_func =
WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
tvm::runtime::String(target));
global_region_func =
WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));

// Constant propagation
if (!params_bind.empty()) {
global_region_func = backend::BindParamsByName(global_region_func, params_bind);
}

std::string fname = name;
CHECK(!module_->ContainGlobalVar(fname))
<< "Global function " << fname << " already exists";
// Create a global function and add it to the IRModule for the region.
// This way we lift the functions that should be handled by external
// codegen to the module scope and rely on the pass manager to prevent
// relay function level passes (i.e. simplify inference and fusion)
// optimizing it.
GlobalVar glob_func(fname);
module_->Add(glob_func, global_region_func);

// The return type of callnode is the same as the type of the
// compiler_end node.
auto ret = Call(glob_func, param_expr);
region_function_calls[region] = ret;
}

/*!
* \brief Get the return(output) of the function for compiler end node "end_arg".
* This will return either a Call (for a function with a single output) or a
* TupleGetItem (for a function with multiple outputs).
*/
Expr GetFunctionOutput(AnnotatedRegion region, const Expr& end_arg) {
Expr arg = Downcast<Call>(end_arg)->args[0];
// Function has one output.
if (region_return_indices_[region].size() == 1) {
return region_function_calls[region];
}
// Function has multiple outputs.
// Use already made TupleGetItem.
if (region_return_tuplegetitem_.count(region) &&
region_return_tuplegetitem_[region].count(arg)) {
return region_return_tuplegetitem_[region][arg];
}
// Create new TupleGetItem.
CHECK(region_return_indices_.count(region) &&
region_return_indices_[region].count(arg));
int index = region_return_indices_[region][arg];

auto func_call = region_function_calls[region];
auto tuple_get_item_ = TupleGetItem(func_call, index);
tuple_get_item_->checked_type_ = arg->checked_type_;
region_return_tuplegetitem_[region][arg] = tuple_get_item_;
return std::move(tuple_get_item_);
}

/*!
Expand All @@ -485,6 +492,23 @@ class Partitioner : public ExprMutator {
std::unordered_map<AnnotatedRegion, std::vector<std::pair<Var, Expr>>, ObjectHash, ObjectEqual>
region_args;

/*!
* \brief This map maintains the index of an output in the subgraph function
* for a given region. If there are multiple entries for a region, then the
* function has a tuple of multiple outputs for its return.
*/
using RegionRetIndexMap = std::unordered_map<Expr, int, ObjectHash, ObjectEqual>;
std::unordered_map<AnnotatedRegion, RegionRetIndexMap, ObjectHash, ObjectEqual>
region_return_indices_;

/*!
* \brief This map holds already created TupleGetItem nodes for accessing
* outputs of a function.
*/
using RegionRetTupleGetItemMap = std::unordered_map<Expr, TupleGetItem, ObjectHash, ObjectEqual>;
std::unordered_map<AnnotatedRegion, RegionRetTupleGetItemMap, ObjectHash, ObjectEqual>
region_return_tuplegetitem_;

/*!
* \brief Each region set is associated with a function in the module.
* This map maintains the mapping between regionsets and the function it
Expand Down