Skip to content

Commit

Permalink
Store function result's virtual device in the virtual device field
Browse files Browse the repository at this point in the history
  • Loading branch information
electriclilies committed Jan 11, 2022
1 parent e7f3648 commit 800db32
Show file tree
Hide file tree
Showing 36 changed files with 319 additions and 167 deletions.
7 changes: 7 additions & 0 deletions include/tvm/ir/expr.h
Expand Up @@ -180,6 +180,12 @@ class RelayExprNode : public BaseExprNode {
* the call to the function or closure is stored (instead of where the function itself is stored).
* The VirtualDevice's Target field describes how the body of the function should be compiled.
*
* Set to VirtualDevice::FullyUnconstrained by default.
*
* Set to VirtualDevice::FullyUnconstrained by default.
*
* Set to VirtualDevice::FullyUnconstrained by default.
*
* \note Unfortunately, the type of virtual_device_ needs to be ObjectRef to avoid a circular
* import.
*/
Expand Down Expand Up @@ -221,6 +227,7 @@ class GlobalVarNode : public RelayExprNode {

void VisitAttrs(AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down
10 changes: 0 additions & 10 deletions include/tvm/ir/function.h
Expand Up @@ -200,16 +200,6 @@ constexpr const char* kGlobalSymbol = "global_symbol";
*/
constexpr const char* kParamVirtualDevice = "param_virtual_devices";

/*!
* \brief The \p VirtualDevice which will hold the function result.
*
* Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but
* may be included as an annotation on user programs.
*
* Type: VirtualDevice
*/
constexpr const char* kResultVirtualDevice = "result_virtual_device";

} // namespace attr
} // namespace tvm
#endif // TVM_IR_FUNCTION_H_
1 change: 1 addition & 0 deletions include/tvm/relay/adt.h
Expand Up @@ -299,6 +299,7 @@ class MatchNode : public ExprNode {
v->Visit("data", &data);
v->Visit("clauses", &clauses);
v->Visit("complete", &complete);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down
10 changes: 10 additions & 0 deletions include/tvm/relay/expr.h
Expand Up @@ -72,6 +72,7 @@ class ConstantNode : public ExprNode {

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("data", &data);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -108,6 +109,7 @@ class TupleNode : public ExprNode {

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("fields", &fields);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -196,6 +198,7 @@ class VarNode : public ExprNode {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("vid", &vid);
v->Visit("type_annotation", &type_annotation);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -319,6 +322,7 @@ class CallNode : public ExprNode {
v->Visit("args", &args);
v->Visit("attrs", &attrs);
v->Visit("type_args", &type_args);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -425,6 +429,7 @@ class LetNode : public ExprNode {
v->Visit("var", &var);
v->Visit("value", &value);
v->Visit("body", &body);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -516,6 +521,7 @@ class IfNode : public ExprNode {
v->Visit("cond", &cond);
v->Visit("true_branch", &true_branch);
v->Visit("false_branch", &false_branch);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -589,6 +595,7 @@ class TupleGetItemNode : public ExprNode {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("tuple_value", &tuple);
v->Visit("index", &index);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -652,6 +659,7 @@ class RefCreateNode : public ExprNode {

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("value", &value);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -713,6 +721,7 @@ class RefReadNode : public ExprNode {

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("ref", &ref);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -776,6 +785,7 @@ class RefWriteNode : public ExprNode {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("ref", &ref);
v->Visit("value", &value);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down
1 change: 1 addition & 0 deletions include/tvm/relay/function.h
Expand Up @@ -64,6 +64,7 @@ class FunctionNode : public BaseFuncNode {
v->Visit("ret_type", &ret_type);
v->Visit("type_params", &type_params);
v->Visit("attrs", &attrs);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down
128 changes: 90 additions & 38 deletions python/tvm/ir/json_compact.py
Expand Up @@ -57,6 +57,53 @@ def _updater(data):
return _updater


def create_updater_08_to_09():
"""
Create an update to upgrade json from v0.8 to v0.9
Returns
-------
fupdater : function
The updater function
"""

def _initialize_virtual_device(item, _):
if "virtual_device_" not in item["attrs"].keys():
item["attrs"]["virtual_device_"] = "0"
return item

node_map = {
# Base IR
"GlobalVar": _initialize_virtual_device,
"relay.Var": _initialize_virtual_device,
"relay.Function": _initialize_virtual_device,
"relay.Tuple": _initialize_virtual_device,
"relay.Call": _initialize_virtual_device,
"relay.Let": _initialize_virtual_device,
"relay.If": _initialize_virtual_device,
"relay.TupleGetItem": _initialize_virtual_device,
"relay.RefCreate": _initialize_virtual_device,
"relay.RefRead": _initialize_virtual_device,
"relay.RefWrite": _initialize_virtual_device,
"relay.Match": _initialize_virtual_device,
}

return create_updater(node_map, "0.8", "0.9")


def create_updater_07_to_08():
"""Create an update to upgrade json from v0.7 to v0.8"""

def _initialize_module_attributes(item, _):
assert item["type_key"] == "IRModule", "Only initialize the attributes for IRModules"
if "attrs" not in item["attrs"].keys():
item["attrs"]["attrs"] = "0"
return item

node_map = {"IRModule": _initialize_module_attributes}
return create_updater(node_map, "0.7", "0.8")


def create_updater_06_to_07():
"""Create an update to upgrade json from v0.6 to v0.7
Expand Down Expand Up @@ -127,7 +174,7 @@ def _convert(item, nodes):
"relay.IncompleteType": _rename("IncompleteType"),
"relay.TypeRelation": _rename("TypeRelation"),
"relay.TypeCall": _rename("TypeCall"),
"relay.Constructor": [_update_from_std_str("name_hint")],
"relay.Constructor": _update_from_std_str("name_hint"),
"relay.Module": _rename("IRModule"),
"relay.SourceName": _rename("SourceName"),
"relay.Span": _rename("Span"),
Expand All @@ -143,43 +190,43 @@ def _convert(item, nodes):
"Variable": [_update_tir_var("tir.Var"), _update_from_std_str("name")],
"SizeVar": [_update_tir_var("tir.SizeVar"), _update_from_std_str("name")],
"StringImm": [_rename("tir.StringImm"), _update_from_std_str("value")],
"Cast": [_rename("tir.Cast")],
"Add": [_rename("tir.Add")],
"Sub": [_rename("tir.Sub")],
"Mul": [_rename("tir.Mul")],
"Div": [_rename("tir.Div")],
"Mod": [_rename("tir.Mod")],
"FloorDiv": [_rename("tir.FloorDiv")],
"FloorMod": [_rename("tir.FloorMod")],
"Min": [_rename("tir.Min")],
"Max": [_rename("tir.Max")],
"EQ": [_rename("tir.EQ")],
"NE": [_rename("tir.NE")],
"LT": [_rename("tir.LT")],
"LE": [_rename("tir.LE")],
"GT": [_rename("tir.GT")],
"GE": [_rename("tir.GE")],
"And": [_rename("tir.And")],
"Or": [_rename("tir.Or")],
"Not": [_rename("tir.Not")],
"Select": [_rename("tir.Select")],
"Load": [_rename("tir.Load")],
"BufferLoad": [_rename("tir.BufferLoad")],
"Ramp": [_rename("tir.Ramp")],
"Broadcast": [_rename("tir.Broadcast")],
"Shuffle": [_rename("tir.Shuffle")],
"Cast": _rename("tir.Cast"),
"Add": _rename("tir.Add"),
"Sub": _rename("tir.Sub"),
"Mul": _rename("tir.Mul"),
"Div": _rename("tir.Div"),
"Mod": _rename("tir.Mod"),
"FloorDiv": _rename("tir.FloorDiv"),
"FloorMod": _rename("tir.FloorMod"),
"Min": _rename("tir.Min"),
"Max": _rename("tir.Max"),
"EQ": _rename("tir.EQ"),
"NE": _rename("tir.NE"),
"LT": _rename("tir.LT"),
"LE": _rename("tir.LE"),
"GT": _rename("tir.GT"),
"GE": _rename("tir.GE"),
"And": _rename("tir.And"),
"Or": _rename("tir.Or"),
"Not": _rename("tir.Not"),
"Select": _rename("tir.Select"),
"Load": _rename("tir.Load"),
"BufferLoad": _rename("tir.BufferLoad"),
"Ramp": _rename("tir.Ramp"),
"Broadcast": _rename("tir.Broadcast"),
"Shuffle": _rename("tir.Shuffle"),
"Call": [_rename("tir.Call"), _update_from_std_str("name")],
"Let": [_rename("tir.Let")],
"Any": [_rename("tir.Any")],
"LetStmt": [_rename("tir.LetStmt")],
"AssertStmt": [_rename("tir.AssertStmt")],
"Store": [_rename("tir.Store")],
"BufferStore": [_rename("tir.BufferStore")],
"BufferRealize": [_rename("tir.BufferRealize")],
"Allocate": [_rename("tir.Allocate")],
"IfThenElse": [_rename("tir.IfThenElse")],
"Evaluate": [_rename("tir.Evaluate")],
"Prefetch": [_rename("tir.Prefetch")],
"Let": _rename("tir.Let"),
"Any": _rename("tir.Any"),
"LetStmt": _rename("tir.LetStmt"),
"AssertStmt": _rename("tir.AssertStmt"),
"Store": _rename("tir.Store"),
"BufferStore": _rename("tir.BufferStore"),
"BufferRealize": _rename("tir.BufferRealize"),
"Allocate": _rename("tir.Allocate"),
"IfThenElse": _rename("tir.IfThenElse"),
"Evaluate": _rename("tir.Evaluate"),
"Prefetch": _rename("tir.Prefetch"),
"AttrStmt": [_rename("tir.AttrStmt"), _update_from_std_str("attr_key")],
"Layout": [_rename("tir.Layout"), _update_from_std_str("name")],
"Buffer": [
Expand All @@ -206,8 +253,13 @@ def upgrade_json(json_str):
"""
data = json.loads(json_str)
from_version = data["attrs"]["tvm_version"]

if from_version.startswith("0.6"):
data = create_updater_06_to_07()(data)
data = create_updater_08_to_09()(create_updater_07_to_08()(create_updater_06_to_07()(data)))
elif from_version.startswith("0.7"):
data = create_updater_08_to_09()(create_updater_07_to_08()(data))
elif from_version.startswith("0.8"):
data = create_updater_08_to_09()(data)
else:
raise ValueError("Cannot update from version %s" % from_version)
return json.dumps(data, indent=2)
1 change: 1 addition & 0 deletions src/printer/relay_text_printer.cc
Expand Up @@ -458,6 +458,7 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) {
if (fn->ret_type.defined()) {
doc << "-> " << Print(fn->ret_type) << " ";
}
doc << "Virtual Device: " << Print(fn->virtual_device()) << " \n";
doc << PrintBody(fn->body);
return doc;
}
Expand Down
7 changes: 3 additions & 4 deletions src/relay/backend/contrib/cmsisnn/extract_constants.cc
Expand Up @@ -55,8 +55,8 @@ class ExtractConstantsMutator : public MixedModeMutator {
auto new_body = VisitExpr(func->body);
functions_.pop_back();
if (function_to_constants_[func].size()) {
func = Function(FreeVars(new_body), new_body, func->ret_type, FreeTypeVars(new_body, mod_),
func->attrs);
func = WithFields(std::move(func), FreeVars(new_body), new_body, func->ret_type,
FreeTypeVars(new_body, mod_), func->attrs);
}
return std::move(func);
}
Expand Down Expand Up @@ -147,8 +147,7 @@ IRModule ExtractConstants(const IRModule& mod) {
auto new_main_body = extract_constants.VisitExpr(main_func->body);
if (!new_main_body.same_as(main_func->body)) {
auto main_var = mod->GetGlobalVar("main");
auto new_main_func = Function(main_func->params, new_main_body, main_func->ret_type,
main_func->type_params, main_func->attrs);
Function new_main_func = WithFields(std::move(main_func), main_func->params, new_main_body);
mod->Update(main_var, new_main_func);
}
return mod;
Expand Down
9 changes: 2 additions & 7 deletions src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
Expand Up @@ -44,13 +44,8 @@ class RelayToTIRVisitor : public MixedModeMutator {

IRModule Mutate() {
GlobalVar main_global_var = ir_module_->GetGlobalVar("main");
BaseFunc main = ir_module_->Lookup(main_global_var);
Function main_func = GetRef<Function>(main.as<FunctionNode>());

// Copy everything across and mutate the body
Function mutated_main =
Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type,
main_func->type_params, main_func->attrs, main_func->span);
Function main = Downcast<Function>(ir_module_->Lookup(main_global_var));
Function mutated_main = WithFields(std::move(main), main->params, VisitExpr(main->body));

ir_module_->Update(main_global_var, mutated_main);

Expand Down
8 changes: 2 additions & 6 deletions src/relay/backend/contrib/ethosu/codegen.cc
Expand Up @@ -56,12 +56,8 @@ class RelayToTIRMutator : public MixedModeMutator {

IRModule operator()() {
GlobalVar main_global_var = ir_module_->GetGlobalVar("main");
Function main_func = Downcast<Function>(ir_module_->Lookup(main_global_var));

// Copy everything across and mutate the body
Function mutated_main =
Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type,
main_func->type_params, main_func->attrs, main_func->span);
Function main = Downcast<Function>(ir_module_->Lookup(main_global_var));
Function mutated_main = WithFields(std::move(main), main->params, VisitExpr(main->body));

ir_module_->Update(main_global_var, mutated_main);
ir_module_ = WithAttr(ir_module_, "device_contexts", device_contexts_);
Expand Down
9 changes: 6 additions & 3 deletions src/relay/backend/contrib/ethosu/preprocess.cc
Expand Up @@ -177,11 +177,14 @@ class ExternalFuncIOHandler : public ExprRewriter {
reshaped_outputs.push_back(CreateFlattenTensor(out));
}
auto concat_out = CreateConcatTensor(reshaped_outputs);
auto f = Function(params, concat_out, concat_out->checked_type_, {}, func->attrs);
Function f = WithFields(std::move(func), std::move(params), std::move(concat_out),
std::move(concat_out->checked_type_),
Array<TypeVar>() /* erase type params */);
return InferType(f, this->module_);
} else {
auto f =
Function(params, core_compute_expr, core_compute_expr->checked_type_, {}, func->attrs);
Function f = WithFields(std::move(func), std::move(params), std::move(core_compute_expr),
std::move(core_compute_expr->checked_type_),
Array<TypeVar>() /* erase type params */);
return InferType(f, this->module_);
}
}
Expand Down
Expand Up @@ -43,13 +43,8 @@ class ConvertAddToSubtract : public MixedModeMutator {

IRModule Mutate() {
GlobalVar main_global_var = ir_module_->GetGlobalVar("main");
BaseFunc main = ir_module_->Lookup(main_global_var);
Function main_func = GetRef<Function>(main.as<FunctionNode>());

// Copy everything across and mutate the body
Function mutated_main =
Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type,
main_func->type_params, main_func->attrs, main_func->span);
Function main = GetRef<Function>(ir_module_->Lookup(main_global_var).as<FunctionNode>());
Function mutated_main = WithFields(std::move(main), main->params, VisitExpr(main->body));

ir_module_->Update(main_global_var, mutated_main);

Expand Down
2 changes: 2 additions & 0 deletions src/relay/backend/graph_plan_memory.cc
Expand Up @@ -279,6 +279,8 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
smap.Set(GetRef<Expr>(kv.first), storage_info);
}
// Either all or none of the nodes should be annotated.
VLOG(1) << "num annotated nodes / num_nodes: " << num_annotated_nodes << " / " << num_nodes
<< std::endl;
if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) {
LOG(FATAL) << num_annotated_nodes << " out of " << num_nodes
<< "expressions are assigned with virtual device types. Either all "
Expand Down

0 comments on commit 800db32

Please sign in to comment.