Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change Call with TIRCallAttrs to call_lowered op #9312

Merged
merged 5 commits into from
Nov 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 0 additions & 11 deletions include/tvm/relay/attrs/annotation.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,17 +116,6 @@ struct CompilerAttrs : public tvm::AttrsNode<CompilerAttrs> {
}
};

/*!
 * \brief Metadata for calls to TIR functions, useful for program analysis crossing Relay and TIR.
 *
 * The node holds a single field, a string-keyed map of arbitrary objects, and
 * declares it through the TVM attrs macros below.
 */
struct TIRCallAttrs : public tvm::AttrsNode<TIRCallAttrs> {
/*!
 * \brief The metadata attached to the call node.
 *
 * Keys are strings and values are generic ObjectRefs, so consumers have to
 * check or downcast each value to the type they expect.
 */
Map<String, ObjectRef> metadata;

// Declares the attrs node with the type key "relay.attrs.TIRCallAttrs" and
// lists `metadata` as its only field, together with a description string.
TVM_DECLARE_ATTRS(TIRCallAttrs, "relay.attrs.TIRCallAttrs") {
TVM_ATTR_FIELD(metadata).describe("Metadata attached to the TIR function call.");
}
};

} // namespace relay
} // namespace tvm
Expand Down
48 changes: 48 additions & 0 deletions include/tvm/relay/attrs/call.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relay/attrs/call.h
* \brief Attribute for call_lowered operator.
*/
#ifndef TVM_RELAY_ATTRS_CALL_H_
#define TVM_RELAY_ATTRS_CALL_H_

#include <tvm/ir/attrs.h>

#include <string>

namespace tvm {
namespace relay {

/*!
 * \brief Metadata for calls to TIR functions, useful for program analysis crossing Relay and TIR.
 *
 * Attribute node for the call_lowered operator. It holds a single field, a
 * string-keyed map of arbitrary objects, and declares it through the TVM attrs
 * macros below.
 */
struct CallLoweredAttrs : public tvm::AttrsNode<CallLoweredAttrs> {
/*!
 * \brief The metadata attached to the call node.
 *
 * Keys are strings and values are generic ObjectRefs, so consumers have to
 * check or downcast each value to the type they expect.
 */
Map<String, ObjectRef> metadata;

// Declares the attrs node with the type key "relay.attrs.CallLoweredAttrs" and
// lists `metadata` as its only field, together with a description string.
TVM_DECLARE_ATTRS(CallLoweredAttrs, "relay.attrs.CallLoweredAttrs") {
TVM_ATTR_FIELD(metadata).describe("Metadata attached to the lowered function call.");
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_CALL_H_
77 changes: 55 additions & 22 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include <tvm/ir/module.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/attrs/call.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/object.h>
Expand All @@ -40,6 +41,7 @@
#include <vector>

#include "../op/annotation/annotation.h"
#include "../op/call/call.h"
#include "../transforms/device_aware_visitors.h"
#include "./te_compiler.h"
#include "./utils.h"
Expand Down Expand Up @@ -72,14 +74,34 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
AssignReturnSid(GetRef<Expr>(op));
}

void DeviceAwareVisitExpr_(const CallNode* op) final {
// create token for the call node.
VisitExpr(op->op);
CreateStorage(op);
for (Expr arg : op->args) {
void DeviceAwareVisitExpr_(const CallNode* call_node) final {
// AOTOnDemandAllocator is run both before and after lowering, so we need to handle the case
// where the op of the call is a generic function

Expr func;
Array<Expr> args;

if (call_node->op == CallLoweredOp()) {
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
func = call_lowered_props.lowered_func;
args = call_lowered_props.arguments;
} else { // Relay functions that have not been lowered and lowered extern functions
func = call_node->op;
args = call_node->args;
if (call_node->op.as<GlobalVarNode>()) { // Lowered extern function
ICHECK(!(call_node->attrs.defined())) << "Extern functions should have null attributes.";
} else { // Relay function which has not been lowered yet
ICHECK(call_node->op.as<FunctionNode>())
<< "Expected the call to be to a lowered primfunc, a lowered extern function or a "
"unlowered Relay function.";
}
}
VisitExpr(func);
CreateStorage(call_node);
for (const Expr& arg : args) {
GetStorage(arg);
}
AssignReturnSid(GetRef<Expr>(op));
AssignReturnSid(GetRef<Expr>(call_node));
}

void VisitExpr_(const VarNode* op) final { AssignReturnSid(GetRef<Expr>(op)); }
Expand Down Expand Up @@ -287,13 +309,18 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}

/*!
* \brief Call a function with a given name
* \brief Create a function call
* \param call_lowered_props The lowered function and the arguments to call it with
* \param call The call we got func and args from
*/
void CreateFuncCall(Call call, std::string func_name) {
void CreateFuncCall(CallLoweredProps call_lowered_props, Call call) {
std::string func_name = call_lowered_props.lowered_func->name_hint;

tvm::Array<PrimExpr> args{tvm::tir::StringImm(func_name)};
std::vector<tir::Stmt> create_func_call_stmts;

// Pack the inputs
for (Expr arg : call->args) {
for (const Expr& arg : call_lowered_props.arguments) {
if (params_by_expr_.find(arg) != params_by_expr_.end()) {
auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
{tir::StringImm(params_by_expr_[arg])});
Expand Down Expand Up @@ -371,21 +398,25 @@ class AOTExecutorCodegen : public MixedModeVisitor {
return ss.str();
}

void VisitExpr_(const CallNode* op) override {
void VisitExpr_(const CallNode* call_node) override {
// Descend the call tree
for (auto arg : op->args) {
VisitExpr(arg);
}

if (op->op.as<OpNode>()) {
LOG(FATAL) << "Operators should be transformed away; try applying"
<< "the fuse_ops transformation to the expression.";
} else if (op->op.as<GlobalVarNode>()) {
GlobalVar node = GetRef<GlobalVar>(op->op.as<GlobalVarNode>());
CreateFuncCall(GetRef<Call>(op), node->name_hint);
CallLoweredProps call_lowered_props;
if (const auto* gvn = call_node->op.as<GlobalVarNode>()) { // Lowered extern function
ICHECK(!(call_node->attrs.defined())) << "Extern functions should have null attributes.";
for (const auto& arg : call_node->args) {
VisitExpr(arg);
}
call_lowered_props = CallLoweredProps{GetRef<GlobalVar>(gvn), call_node->args, {}};
} else {
LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey();
ICHECK(call_node->op == CallLoweredOp()) << "Operators should be transformed away; Try "
"applying the fuse_ops transformation to the "
"expression.";
call_lowered_props = GetCallLoweredProps(call_node);
for (const auto& arg : call_lowered_props.arguments) {
VisitExpr(arg);
}
}
CreateFuncCall(call_lowered_props, GetRef<Call>(call_node));
}

void VisitExpr_(const VarNode* op) override {
Expand Down Expand Up @@ -443,7 +474,9 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}
void VisitExpr_(const TupleGetItemNode* op) override { VisitExpr(op->tuple); }
void VisitExpr_(const OpNode* op) override {
LOG(FATAL) << "All OpNodes should have been expanded";
if (GetRef<Op>(op) != CallLoweredOp()) {
LOG(FATAL) << "All OpNodes except for call_lowered should have been expanded";
}
}
void VisitExpr_(const IfNode* op) override {
LOG(FATAL) << "All GlobalVarNodes should be removed before AOT executor's Codegen is called";
Expand Down
12 changes: 11 additions & 1 deletion src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/relay/attrs/call.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/memory.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>

#include "../../../op/call/call.h"

namespace tvm {
namespace relay {
namespace contrib {
Expand Down Expand Up @@ -109,7 +113,13 @@ class ConvertAddToSubtract : public MixedModeMutator {
GlobalVar new_global_var(func_name.value());
new_global_var->checked_type_ = func->checked_type();
ReplaceAddWithSubtractPrimFunc(new_global_var, GetRef<Function>(func));
return Call(new_global_var, call->args, call->attrs, call->type_args, call->span);

// Since we are replacing the Relay function with a call to a TIR function, we must use the
// call_lowered op.
auto call_lowered_attrs = make_object<CallLoweredAttrs>();
call_lowered_attrs->metadata.Set("relay_attrs", call->attrs);
return CallLowered(std::move(new_global_var), call->args,
std::move(Attrs(call_lowered_attrs)), call->type_args, call->span);
}
}

Expand Down
89 changes: 51 additions & 38 deletions src/relay/backend/graph_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <dmlc/json.h>
#include <tvm/ir/module.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/attrs/call.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/object.h>
Expand All @@ -37,6 +38,7 @@
#include <vector>

#include "../op/annotation/annotation.h"
#include "../op/call/call.h"
#include "../transforms/device_aware_visitors.h"
#include "./te_compiler.h"
#include "./utils.h"
Expand Down Expand Up @@ -403,64 +405,75 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
return lhs_storage_id == rhs_storage_id;
}

std::vector<GraphNodeRef> GraphAddCallNode(const CallNode* op, const std::string& func_name,
GraphAttrs attrs) {
std::vector<GraphNodeRef> GraphAddCallNode(const CallNode* call_node, GraphAttrs attrs) {
Call call = GetRef<Call>(call_node);
std::vector<GraphNodeRef> inputs;
for (auto arg : op->args) {
auto res = VisitExpr(arg);
for (auto nr : res) {
inputs.push_back(nr);
}
}
std::string func_name;

/// An adapted version of the storage optimization for the time being.
bool reshape_only = false;
if (op->attrs.defined()) {
if (auto tir_call_attrs = op->attrs.as<TIRCallAttrs>()) {
Map<String, ObjectRef> metadata = tir_call_attrs->metadata;
if (metadata.count(attr::kReshapeOnly) &&
Downcast<tvm::Integer>(metadata[attr::kReshapeOnly])->value == 1) {
reshape_only = true;
}
if (call->op == CallLoweredOp()) {
// Extract function and arguments from the call_lowered op
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);

auto relay_attrs = Downcast<DictAttrs>(tir_call_attrs->metadata["relay_attrs"]);
func_name = call_lowered_props.lowered_func->name_hint;

for (auto p : relay_attrs->dict) {
if (p.second.as<StringObj>()) {
attrs[p.first] = std::string(Downcast<String>(p.second));
for (const Expr& arg : call_lowered_props.arguments) {
for (auto n : VisitExpr(arg)) {
inputs.push_back(n);
}
}
if (call_lowered_props.attrs.metadata.count("relay_attrs")) {
if (auto relay_attrs =
call_lowered_props.attrs.metadata["relay_attrs"].as<DictAttrsNode>()) {
for (auto p : relay_attrs->dict) {
if (p.second.as<StringObj>()) {
attrs[p.first] = std::string(Downcast<String>(p.second));
}
}
}
}
}

if (reshape_only && ShareSameStorage(GetRef<Expr>(op), op->args[0])) {
auto node = GraphOpNode::make_node_ptr("reshape_nop", GraphAttrs(), "__nop", inputs, attrs);
return AddNode(node, GetRef<Expr>(op));
bool reshape_only = false;
if (call_lowered_props.attrs.metadata.count(attr::kReshapeOnly) &&
Downcast<tvm::Integer>(call_lowered_props.attrs.metadata[attr::kReshapeOnly])->value ==
1) {
reshape_only = true;
}
if (reshape_only &&
ShareSameStorage(GetRef<Expr>(call_node), call_lowered_props.arguments[0])) {
auto node = GraphOpNode::make_node_ptr("reshape_nop", GraphAttrs(), "__nop", inputs, attrs);
return AddNode(node, call);
}
} else if (!call_node->attrs.defined()) { // Call is an extern function
std::cout << "call_node: \n" << PrettyPrint(call) << std::endl;
const auto* func = call_node->op.as<GlobalVarNode>();
ICHECK(func) << "Expected the operator to be a global var, but got "
<< call_node->op->GetTypeKey(); // getting a relay fn here, not sure why.
func_name = func->name_hint;

for (const Expr& arg : call_node->args) {
for (auto n : VisitExpr(arg)) {
inputs.push_back(n);
}
}
} else {
LOG(FATAL) << "Non-primitive-call nodes should have been transformed away.\n"
<< "The graph executor code generator expects all calls to be call_lowered, "
<< "but found: " << std::endl
<< PrettyPrint(call);
}

// Compute the operator name, because a unique name was used when generating the kernel.
auto op_name = _GetUniqueName(func_name);
auto node = GraphOpNode::make_node_ptr(op_name, GraphAttrs(), func_name, inputs, attrs);
return AddNode(node, GetRef<Expr>(op));
return AddNode(node, call);
}

std::vector<GraphNodeRef> VisitExpr_(const CallNode* call_node) override {
relay::Call call = GetRef<Call>(call_node);
auto props = GetOnDeviceProps(call_node);
if (props.body.defined()) {
// See through "on_device" calls.
return VisitExpr(props.body);
}

const auto* global_node = call->op.as<GlobalVarNode>();
ICHECK(global_node)
<< "Non-primitive-call nodes should have been transformed away.\n"
<< "The graph executor code generator expects all calls to have their callee "
"normalized to a GlobalVar, but found:"
<< std::endl
<< PrettyPrint(call);
auto prim_fn_name = global_node->name_hint;
return GraphAddCallNode(call_node, prim_fn_name, GraphAttrs());
return GraphAddCallNode(call_node, GraphAttrs());
}

std::vector<GraphNodeRef> VisitExpr_(const LetNode* op) override {
Expand Down