Refactor compile_engine to introduce TETranslator #6888

Closed · wants to merge 5 commits
8 changes: 6 additions & 2 deletions python/tvm/relay/backend/compile_engine.py
@@ -255,7 +255,7 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True)


@tvm._ffi.register_func("relay.backend.lower_call")
def lower_call(call, inputs, target):
def lower_call(call, inputs, target, no_trace=False):
"""Lower the call expression to op implementation and tensor outputs."""
assert isinstance(call.op, tvm.ir.Op)
op = call.op
@@ -283,7 +283,7 @@ def lower_call(call, inputs, target):
env = autotvm.task.TaskExtractEnv.current
reenable_tracing = False
if env is not None and env.tracing:
if env.wanted_relay_ops is not None and op not in env.wanted_relay_ops:
if (env.wanted_relay_ops is not None and op not in env.wanted_relay_ops) or no_trace:
env.tracing = False
reenable_tracing = True

@@ -410,3 +410,7 @@ def get():
The compile engine.
"""
return _backend._CompileEngineGlobal()


def translate_to_te(prim_func, target):
"""Translate a fused primitive Relay function into a TE graph of input and output tensors."""
return _backend._TranslateToTE(prim_func, target)
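
For reviewers who want to try the new Python entry point, here is a minimal usage sketch. It is not part of the diff: the example graph, the manually attached Primitive attribute, and the llvm target are illustrative assumptions, and reading te_graph.inputs / te_graph.outputs from Python assumes TEGraphNode exposes those fields over the FFI.

import tvm
from tvm import relay
from tvm.relay.backend import compile_engine

# Build a small fused function by hand (FuseOps would normally produce this).
x = relay.var("x", shape=(4, 4), dtype="float32")
y = relay.var("y", shape=(4, 4), dtype="float32")
func = relay.Function([x, y], relay.add(x, y))
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

# Type inference is needed so the translator can read checked_type().
mod = relay.transform.InferType()(tvm.IRModule.from_expr(func))

te_graph = compile_engine.translate_to_te(mod["main"], tvm.target.Target("llvm"))
print(te_graph.inputs)   # placeholder tensors for x and y
print(te_graph.outputs)  # the te.Tensor computed by the fused add
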
289 changes: 178 additions & 111 deletions src/relay/backend/compile_engine.cc
@@ -52,6 +52,7 @@ namespace tvm {
namespace relay {

TVM_REGISTER_NODE_TYPE(LoweredOutputNode);
TVM_REGISTER_NODE_TYPE(TEGraphNode);
TVM_REGISTER_NODE_TYPE(CachedFuncNode);
TVM_REGISTER_NODE_TYPE(CCacheKeyNode);
TVM_REGISTER_NODE_TYPE(CCacheValueNode);
@@ -94,26 +95,44 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
return res;
}

// The getter to get schedule from compile engine.
// Get schedule from functor.
class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
te::Tensor GetScalar(const ConstantNode* op) {
using tir::make_const;
ICHECK(op->is_scalar());
void* data = op->data->data;
DataType dtype = DataType(op->data->dtype);
auto value = te::compute(
{},
[&](const Array<tvm::tir::Var>&) {
if (dtype == DataType::Int(32)) {
return make_const(dtype, static_cast<const int32_t*>(data)[0]);
} else if (dtype == DataType::Int(64)) {
return make_const(dtype, static_cast<const int64_t*>(data)[0]);
} else if (dtype == DataType::Float(32)) {
return make_const(dtype, static_cast<const float*>(data)[0]);
} else if (dtype == DataType::Float(64)) {
return make_const(dtype, static_cast<const double*>(data)[0]);
} else if (dtype == DataType::Bool()) {
return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
} else {
LOG(FATAL) << "not handled";
return tvm::PrimExpr();
}
},
"compile_engine_const", topi::kBroadcast);
return value;
}

class TETranslator : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
public:
explicit ScheduleGetter(Target target)
: target_(target), device_copy_op_(Op::Get("device_copy")) {
// Whether to use auto_scheduler schedule.
use_auto_scheduler_ = transform::PassContext::Current()
->GetConfig<Bool>("relay.backend.use_auto_scheduler", Bool(false))
.value();
}
explicit TETranslator(Target target) : target_(target), device_copy_op_(Op::Get("device_copy")) {}

CachedFunc Create(const Function& prim_func) {
auto cache_node = make_object<CachedFuncNode>();
cache_node->target = target_;
TEGraph Translate(const Function& prim_func) {
auto graph_node = make_object<TEGraphNode>();
for (Var param : prim_func->params) {
Array<tvm::te::Tensor> inputs;
if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
cache_node->inputs.push_back(tensor);
graph_node->inputs.push_back(tensor);
inputs.push_back(tensor);
} else {
// flatten tuple of tensor type.
@@ -123,14 +142,123 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
// TODO(@icemelon): Allow recursive tuple
ICHECK(ttype != nullptr);
tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
cache_node->inputs.push_back(tensor);
graph_node->inputs.push_back(tensor);
inputs.push_back(tensor);
}
}
memo_[param] = inputs;
}
graph_node->outputs = this->VisitExpr(prim_func->body);
return TEGraph(graph_node);
}

Array<te::Tensor> VisitExpr_(const VarNode* op) final {
LOG(FATAL) << "Free variable " << op->name_hint();
return {};
}

Array<te::Tensor> VisitExpr_(const ConstantNode* op) final { return {GetScalar(op)}; }

Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
ICHECK(flower_call) << "relay.backend.lower_call is not registered.";

Array<te::Tensor> inputs;
int count_tuple = 0;
for (Expr arg : call_node->args) {
if (arg->checked_type().as<TupleTypeNode>()) {
++count_tuple;
}
for (te::Tensor tensor : VisitExpr(arg)) {
inputs.push_back(tensor);
}
}
if (count_tuple) {
ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input";
}

ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);

Array<te::Tensor> outputs;
OpImplementation impl;
// Skip fcompute for device copy operators as it is not registered.
if (op == device_copy_op_) {
const auto* copy_input = inputs[0].operator->();
outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype, te::Operation(), 0));
} else {
LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_, true);
outputs = lowered_out->outputs;
}

if (outputs.size() != 1) {
const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
ICHECK(tuple_type) << "Expect output to be a tuple type";
ICHECK_EQ(tuple_type->fields.size(), outputs.size());
}
return outputs;
}

Array<te::Tensor> VisitExpr_(const FunctionNode* op) final {
LOG(FATAL) << "Do not support sub function";
return Array<te::Tensor>();
}

Array<te::Tensor> VisitExpr_(const LetNode* op) final {
Array<te::Tensor> val = VisitExpr(op->value);
ICHECK(!memo_.count(op->var));
memo_[op->var] = val;
return VisitExpr(op->body);
}

Array<te::Tensor> VisitExpr_(const TupleNode* op) final {
Array<te::Tensor> fields;
for (Expr field : op->fields) {
ICHECK(field->checked_type().as<TensorTypeNode>()) << "Only allow Tuple of Tensor";
Array<te::Tensor> res = VisitExpr(field);
ICHECK_EQ(res.size(), 1);
fields.push_back(res[0]);
}
return fields;
}

Array<te::Tensor> VisitExpr_(const TupleGetItemNode* op) final {
const auto* tuple_type = op->tuple->type_as<TupleTypeNode>();
Array<te::Tensor> tuple = VisitExpr(op->tuple);
ICHECK_EQ(tuple_type->fields.size(), tuple.size());
ICHECK_GE(op->index, 0);
ICHECK_LT(static_cast<size_t>(op->index), tuple.size());
return {tuple[op->index]};
}

private:
tvm::Target target_;
// Cache device copy op for equivalence checking to reduce registry lookup
// overhead for each invocation of call node when retrieving schedules.
const Op& device_copy_op_;
};

// The getter to get schedule from compile engine.
// Get schedule from functor.
class ScheduleGetter : public ExprVisitor {
public:
explicit ScheduleGetter(Target target)
: target_(target), device_copy_op_(Op::Get("device_copy")) {
// Whether to use auto_scheduler schedule.
use_auto_scheduler_ = transform::PassContext::Current()
->GetConfig<Bool>("relay.backend.use_auto_scheduler", Bool(false))
.value();
}

CachedFunc Create(const Function& prim_func) {
auto translator = TETranslator(target_);
auto te_graph = translator.Translate(prim_func);
auto cache_node = make_object<CachedFuncNode>();
cache_node->target = target_;
cache_node->inputs = te_graph->inputs;
cache_node->outputs = te_graph->outputs;
readable_name_stream_ << "fused";
cache_node->outputs = this->VisitExpr(prim_func->body);
this->VisitExpr(prim_func->body);
auto candidate_name = readable_name_stream_.str();
constexpr static size_t kMaxFuncNameLength = 80;
if (candidate_name.size() > kMaxFuncNameLength) {
@@ -166,7 +294,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
}
}

// Use TOPI schdule if user specificed, or the function has no auto_scheduler schedule.
// Use TOPI schedule if user specified, or the function has no auto_scheduler schedule.
if (!schedule.defined()) {
ICHECK(anchor_implementation_.defined());
schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
@@ -181,72 +309,50 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
return CachedFunc(cache_node);
}

Array<te::Tensor> VisitExpr_(const VarNode* op) final {
LOG(FATAL) << "Free variable " << op->name_hint();
return {};
}

Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
using tir::make_const;
ICHECK(op->is_scalar());
void* data = op->data->data;
DataType dtype = DataType(op->data->dtype);
auto value = te::compute(
{},
[&](const Array<tvm::tir::Var>&) {
if (dtype == DataType::Int(32)) {
return make_const(dtype, static_cast<const int32_t*>(data)[0]);
} else if (dtype == DataType::Int(64)) {
return make_const(dtype, static_cast<const int64_t*>(data)[0]);
} else if (dtype == DataType::Float(32)) {
return make_const(dtype, static_cast<const float*>(data)[0]);
} else if (dtype == DataType::Float(64)) {
return make_const(dtype, static_cast<const double*>(data)[0]);
} else if (dtype == DataType::Bool()) {
return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
} else {
LOG(FATAL) << "not handled";
return tvm::PrimExpr();
}
},
"compile_engine_const", topi::kBroadcast);
void VisitExpr_(const ConstantNode* op) final {
auto value = GetScalar(op);
scalars_.push_back(value->op);
return {value};
}

Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
void VisitExpr_(const CallNode* call_node) final {
static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
ICHECK(flower_call) << "relay.backend.lower_call is not registered.";

Array<te::Tensor> inputs;
int count_tuple = 0;
Array<te::Tensor> inputs;
for (Expr arg : call_node->args) {
if (arg->checked_type().as<TupleTypeNode>()) {
++count_tuple;
}
for (te::Tensor tensor : VisitExpr(arg)) {
VisitExpr(arg);
if (const auto* ttype = arg->checked_type().as<TensorTypeNode>()) {
tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
inputs.push_back(tensor);
} else {
ICHECK_EQ(count_tuple, 0) << "Only allow function with a single tuple input";
// flatten tuple of tensor type.
const auto* tuple_type = arg->type_as<TupleTypeNode>();
for (Type field : tuple_type->fields) {
const auto* ttype = field.as<TensorTypeNode>();
// TODO(@icemelon): Allow recursive tuple
ICHECK(ttype != nullptr);
tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
inputs.push_back(tensor);
++count_tuple;
}
}
}
if (count_tuple) {
ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input";
}

ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);

Array<te::Tensor> outputs;
OpImplementation impl;
// Skip fcompute for device copy operators as it is not registered.
if (op == device_copy_op_) {
const auto* copy_input = inputs[0].operator->();
outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype, te::Operation(), 0));
} else {
LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
outputs = lowered_out->outputs;
impl = lowered_out->implementation;
// Set the name to `__copy`. It will be detected in graph runtime to perform
// data copy across devices.
readable_name_stream_.str(std::string());
readable_name_stream_ << "__copy";
return;
}
LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_, false);
OpImplementation impl = lowered_out->implementation;

int op_pattern = fpattern[op];
if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
@@ -260,52 +366,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
anchor_op_pattern_ = op_pattern;
anchor_implementation_ = impl;
}
if (outputs.size() != 1) {
const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
ICHECK(tuple_type) << "Expect output to be a tuple type";
ICHECK_EQ(tuple_type->fields.size(), outputs.size());
}
// Set the name to `__copy`. It will be detected in graph runtime to perform
// data copy across devices.
if (op == device_copy_op_) {
readable_name_stream_.str(std::string());
readable_name_stream_ << "__copy";
} else {
readable_name_stream_ << '_' << op->name;
}
return outputs;
}

Array<te::Tensor> VisitExpr_(const FunctionNode* op) final {
LOG(FATAL) << "Do not support sub function";
return Array<te::Tensor>();
}

Array<te::Tensor> VisitExpr_(const LetNode* op) final {
Array<te::Tensor> val = VisitExpr(op->value);
ICHECK(!memo_.count(op->var));
memo_[op->var] = val;
return VisitExpr(op->body);
}

Array<te::Tensor> VisitExpr_(const TupleNode* op) final {
Array<te::Tensor> fields;
for (Expr field : op->fields) {
ICHECK(field->checked_type().as<TensorTypeNode>()) << "Only allow Tuple of Tensor";
Array<te::Tensor> res = VisitExpr(field);
ICHECK_EQ(res.size(), 1);
fields.push_back(res[0]);
}
return fields;
}

Array<te::Tensor> VisitExpr_(const TupleGetItemNode* op) final {
const auto* tuple_type = op->tuple->type_as<TupleTypeNode>();
Array<te::Tensor> tuple = VisitExpr(op->tuple);
ICHECK_EQ(tuple_type->fields.size(), tuple.size());
ICHECK_GE(op->index, 0);
ICHECK_LT(static_cast<size_t>(op->index), tuple.size());
return {tuple[op->index]};
readable_name_stream_ << '_' << op->name;
}

private:
@@ -836,6 +897,12 @@ CompileEngine& CompileEngine::Global() {

TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool);

TVM_REGISTER_GLOBAL("relay.backend._TranslateToTE")
.set_body_typed([](Function prim_func, Target target) {
auto translator = TETranslator(target);
return translator.Translate(prim_func);
});

TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput")
.set_body_typed([](tvm::Array<te::Tensor> outputs, OpImplementation impl) {
return LoweredOutput(outputs, impl);
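
As a design note, the refactor splits the old ScheduleGetter in two: TETranslator only lowers a primitive Relay function to TE tensors (passing no_trace=true so autotvm task extraction is not triggered during translation), while ScheduleGetter keeps the function naming, anchor-op selection, and schedule creation on top of the translated graph. Below is a minimal sketch of calling the registered global directly, which is what the Python translate_to_te helper wraps; prim_func and target are assumed to be built as in the earlier example.

import tvm

# "relay.backend._TranslateToTE" is the global registered in this diff.
translate = tvm.get_global_func("relay.backend._TranslateToTE")
# te_graph = translate(prim_func, tvm.target.Target("llvm"))
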