Skip to content

Commit

Permalink
Introduce call_lowered op
Browse files Browse the repository at this point in the history
  • Loading branch information
electriclilies committed Nov 9, 2021
1 parent bd00c66 commit 5861dc6
Show file tree
Hide file tree
Showing 14 changed files with 300 additions and 238 deletions.
6 changes: 3 additions & 3 deletions include/tvm/relay/attrs/call.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
*/

/*!
* \file tvm/relay/attrs/annotation.h
* \brief Attribute for annotation operators.
* \file tvm/relay/attrs/call.h
* \brief Attribute for call_lowered operator.
*/
#ifndef TVM_RELAY_ATTRS_CALL_H_
#define TVM_RELAY_ATTRS_CALL_H_
Expand All @@ -39,7 +39,7 @@ struct CallLoweredAttrs : public tvm::AttrsNode<CallLoweredAttrs> {
Map<String, ObjectRef> metadata;

TVM_DECLARE_ATTRS(CallLoweredAttrs, "relay.attrs.CallLoweredAttrs") {
TVM_ATTR_FIELD(metadata).describe("Metadata attached to the TIR function call.");
TVM_ATTR_FIELD(metadata).describe("Metadata attached to the lowered function call.");
}
};

Expand Down
52 changes: 25 additions & 27 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,20 +82,23 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
Array<Expr> args;

if (call_node->op == CallLoweredOp()) {
// Extract function and arguments from the call_lowered op
std::pair<GlobalVar, Array<Expr>> func_and_args = ExtractFunctionAndArgs(call_node);
func = func_and_args.first;
args = func_and_args.second;

} else {
ICHECK(call_node->op.as<FunctionNode>())
<< "Expect call to be call_lowered op or function node. ";
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
func = call_lowered_props.lowered_func;
args = call_lowered_props.arguments;
} else { // Relay functions that have not been lowered and lowered extern functions
func = call_node->op;
args = call_node->args;
if (call_node->op.as<GlobalVarNode>()) { // Lowered extern function
ICHECK(!(call_node->attrs.defined())) << "Extern functions should have null attributes.";
} else { // Relay function which has not been lowered yet
ICHECK(call_node->op.as<FunctionNode>())
<< "Expected the call to be to a lowered primfunc, a lowered extern function or a "
"unlowered Relay function.";
}
}
VisitExpr(func);
CreateStorage(call_node);
for (Expr arg : args) {
for (const Expr& arg : args) {
GetStorage(arg);
}
AssignReturnSid(GetRef<Expr>(call_node));
Expand Down Expand Up @@ -306,23 +309,18 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}

/*!
* \brief Call a function with a given name
* \brief Create a function call
* \param call_lowered_props The lowered function and the arguments to call it with
* \param call The call node from which the lowered function and its arguments were extracted
*/
void CreateFuncCall(const CallNode* call_node) {
Call call = GetRef<Call>(call_node);

// Extract function and arguments from the call_lowered op
std::pair<GlobalVar, Array<Expr>> func_and_args = ExtractFunctionAndArgs(call_node);
GlobalVar func = func_and_args.first;
Array<Expr> call_args = func_and_args.second;

std::string func_name = func->name_hint;
void CreateFuncCall(CallLoweredProps call_lowered_props, Call call) {
std::string func_name = call_lowered_props.lowered_func->name_hint;

tvm::Array<PrimExpr> args{tvm::tir::StringImm(func_name)};
std::vector<tir::Stmt> create_func_call_stmts;
// Pack the inputs

for (Expr arg : call_args) {
// Pack the inputs
for (const Expr& arg : call_lowered_props.arguments) {
if (params_by_expr_.find(arg) != params_by_expr_.end()) {
auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
{tir::StringImm(params_by_expr_[arg])});
Expand Down Expand Up @@ -402,18 +400,18 @@ class AOTExecutorCodegen : public MixedModeVisitor {

void VisitExpr_(const CallNode* call_node) override {
// Descend the call tree
ICHECK(call_node->op == CallLoweredOp()) << "Only expect call_lowered op at this point";
ICHECK(call_node->op == CallLoweredOp())
<< "Operators should be transformed away; Try applying the fuse_ops transformation to the "
"expression.";

// Extract function and arguments from the call_lowered op
std::pair<GlobalVar, Array<Expr>> func_and_args = ExtractFunctionAndArgs(call_node);
GlobalVar func = func_and_args.first;
Array<Expr> call_args = func_and_args.second;
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);

for (auto arg : call_args) {
for (auto arg : call_lowered_props.arguments) {
VisitExpr(arg);
}

CreateFuncCall(call_node);
CreateFuncCall(call_lowered_props, GetRef<Call>(call_node));
}

void VisitExpr_(const VarNode* op) override {
Expand Down
12 changes: 11 additions & 1 deletion src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/relay/attrs/call.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/memory.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>

#include "../../../op/call/call.h"

namespace tvm {
namespace relay {
namespace contrib {
Expand Down Expand Up @@ -109,7 +113,13 @@ class ConvertAddToSubtract : public MixedModeMutator {
GlobalVar new_global_var(func_name.value());
new_global_var->checked_type_ = func->checked_type();
ReplaceAddWithSubtractPrimFunc(new_global_var, GetRef<Function>(func));
return Call(new_global_var, call->args, call->attrs, call->type_args, call->span);

// Since we are replacing the Relay function with a call to a TIR function, we must use the
// call_lowered op.
auto call_lowered_attrs = make_object<CallLoweredAttrs>();
call_lowered_attrs->metadata.Set("relay_attrs", call->attrs);
return CallLowered(std::move(new_global_var), call->args,
std::move(Attrs(call_lowered_attrs)), call->type_args, call->span);
}
}

Expand Down
92 changes: 46 additions & 46 deletions src/relay/backend/graph_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -407,58 +407,58 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<

std::vector<GraphNodeRef> GraphAddCallNode(const CallNode* call_node, GraphAttrs attrs) {
Call call = GetRef<Call>(call_node);
ICHECK(call->op == CallLoweredOp())
<< "Non-primitive-call nodes should have been transformed away.\n"
<< "The graph executor code generator expects all calls to be call_lowered, "
<< "but found: " << std::endl
<< PrettyPrint(call);
std::vector<GraphNodeRef> inputs;
std::string func_name;

// Extract function and arguments from the call_lowered op
std::pair<GlobalVar, Array<Expr>> func_and_args = ExtractFunctionAndArgs(call_node);
GlobalVar func = func_and_args.first;
Array<Expr> call_args = func_and_args.second;
if (call->op == CallLoweredOp()) {
// Extract function and arguments from the call_lowered op
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);

std::string func_name = func->name_hint;
func_name = call_lowered_props.lowered_func->name_hint;

std::vector<GraphNodeRef> inputs;
// Visit all the arguments to call_lowered
for (Expr arg : call_args) {
for (auto n : VisitExpr(arg)) {
inputs.push_back(n);
for (const Expr& arg : call_lowered_props.arguments) {
for (auto n : VisitExpr(arg)) {
inputs.push_back(n);
}
}
}

/// An adapted version of the storage optimization for the time being.
bool reshape_only = false;
ICHECK(call_node->attrs.defined()) << "Attrs should be defined!";
auto call_lowered_attrs = call_node->attrs.as<CallLoweredAttrs>();
ICHECK(call_lowered_attrs) << "Expected call_lowered to have CallLoweredAttrs";

// Need to check if this is an extern or not
Map<String, ObjectRef> metadata = call_lowered_attrs->metadata;
if (metadata.count(attr::kReshapeOnly) &&
Downcast<tvm::Integer>(metadata[attr::kReshapeOnly])->value == 1) {
reshape_only = true;
}

if (!call_lowered_attrs->metadata.count(
"extern_func")) { // Extern funcs won't have relay attrs
// In main, I don't understand why this was running properly unless this was not actually
// called on the function?? Looks like maybe something is messed up with how call nodes are
// getting passed around. IDK tho
ICHECK(call_lowered_attrs->metadata.count("relay_attrs"))
<< "Expected there to be relay attrs stored in the metadata. ";
auto relay_attrs = Downcast<DictAttrs>(call_lowered_attrs->metadata["relay_attrs"]);
for (auto p : relay_attrs->dict) {
if (p.second.as<StringObj>()) {
attrs[p.first] = std::string(Downcast<String>(p.second));
if (call_lowered_props.attrs.metadata.count("relay_attrs")) {
if (auto relay_attrs =
call_lowered_props.attrs.metadata["relay_attrs"].as<DictAttrsNode>()) {
for (auto p : relay_attrs->dict) {
if (p.second.as<StringObj>()) {
attrs[p.first] = std::string(Downcast<String>(p.second));
}
}
}
}
}

if (reshape_only && ShareSameStorage(GetRef<Expr>(call_node), func)) {
auto node = GraphOpNode::make_node_ptr("reshape_nop", GraphAttrs(), "__nop", inputs, attrs);
return AddNode(node, call);
bool reshape_only = false;
if (call_lowered_props.attrs.metadata.count(attr::kReshapeOnly) &&
Downcast<tvm::Integer>(call_lowered_props.attrs.metadata[attr::kReshapeOnly])->value ==
1) {
reshape_only = true;
}
if (reshape_only &&
ShareSameStorage(GetRef<Expr>(call_node), call_lowered_props.arguments[0])) {
auto node = GraphOpNode::make_node_ptr("reshape_nop", GraphAttrs(), "__nop", inputs, attrs);
return AddNode(node, call);
}
} else if (!call_node->attrs.defined()) { // Call is an extern function
std::cout << "call_node: \n" << PrettyPrint(call) << std::endl;
const auto* func = call_node->op.as<GlobalVarNode>();
ICHECK(func) << "Expected the operator to be a global var, but got "
<< call_node->op->GetTypeKey(); // getting a relay fn here, not sure why.
func_name = func->name_hint;

for (const Expr& arg : call_node->args) {
for (auto n : VisitExpr(arg)) {
inputs.push_back(n);
}
}
} else {
LOG(FATAL) << "Non-primitive-call nodes should have been transformed away.\n"
<< "The graph executor code generator expects all calls to be call_lowered, "
<< "but found: " << std::endl
<< PrettyPrint(call);
}

// Compute the operator name, because we used the get unique name when generating the kernel.
Expand Down
16 changes: 11 additions & 5 deletions src/relay/backend/graph_plan_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ class StorageAllocaBaseVisitor : public transform::DeviceAwareExprVisitor {
*/
const std::vector<StorageToken*>& GetToken(const Expr& expr) {
this->VisitExpr(expr);
// Return empty if called on a Function
// Functions don't require data storage, represented by the empty token
if (expr->checked_type().as<FuncTypeNode>()) {
return no_tokens_;
}
Expand All @@ -173,6 +173,13 @@ class StorageAllocaBaseVisitor : public transform::DeviceAwareExprVisitor {
can_realloc);
}

/*!
* \brief Allocates (or reuses if \p can_realloc is true) a storage token for holding
* the result of evaluating \p op on \p device_type.
*/
virtual void CreateTokenOnDevice(const ExprNode* op, DLDeviceType device_type,
bool can_realloc) = 0;
};

/*! \brief Associate storage with every expression without any concern for sharing. */
class StorageAllocaInit : protected StorageAllocaBaseVisitor {
Expand Down Expand Up @@ -326,7 +333,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
std::vector<StorageToken*> args;
// for each input, visit argument token.

for (Expr arg : call_node->args) {
for (const Expr& arg : call_node->args) {
// Note: GetToken skips GlobalVars and handles tuples properly, so we don't need to treat
// call_lowered specially.
for (StorageToken* tok : GetToken(arg)) {
Expand Down Expand Up @@ -379,9 +386,8 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
}

if (call->op == CallLoweredOp()) {
auto call_lowered_attrs = call->attrs.as<CallLoweredAttrs>();
ICHECK(call_lowered_attrs) << "Expected call_lowered to have CallLoweredAttrs";
Map<String, ObjectRef> metadata = call_lowered_attrs->metadata;
CallLoweredProps call_lowered_props = GetCallLoweredProps(call);
Map<String, ObjectRef> metadata = call_lowered_props.attrs.metadata;
return metadata.count(attr::kReshapeOnly) &&
(Downcast<tvm::Integer>(metadata[attr::kReshapeOnly])->value == 1);
}
Expand Down
48 changes: 23 additions & 25 deletions src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -680,54 +680,52 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,

ObjectRef VisitExpr_(const CallNode* call_node) final {
if (call_node->op == CallLoweredOp()) { // Special case: Call a lowered TIR function.
const CallLoweredAttrs* attrs = call_node->attrs.as<CallLoweredAttrs>();
ICHECK(attrs) << "Expected call_lowered to have CallLoweredAttrs";

// Extract function and arguments from the call_lowered op
std::pair<GlobalVar, Array<Expr>> func_and_args = ExtractFunctionAndArgs(call_node);
GlobalVar func = func_and_args.first;
Array<Expr> call_args = func_and_args.second;
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);

// Evaluate only function args
std::vector<ObjectRef> args;
for (auto arg : call_args) {
for (auto arg : call_lowered_props.arguments) {
args.push_back(Eval(arg));
}

// TODO(mbs): Make calling convention first-class in Relay.
Array<GlobalVar> all_prim_fn_vars;
if (attrs->metadata.count("all_prim_fn_vars")) {
all_prim_fn_vars = Downcast<Array<GlobalVar>>(attrs->metadata.at("all_prim_fn_vars"));
if (call_lowered_props.attrs.metadata.count("all_prim_fn_vars")) {
all_prim_fn_vars =
Downcast<Array<GlobalVar>>(call_lowered_props.attrs.metadata.at("all_prim_fn_vars"));
}
GlobalVar prim_shape_fn_var;
if (attrs->metadata.count("prim_shape_fn_var")) {
prim_shape_fn_var = Downcast<GlobalVar>(attrs->metadata.at("prim_shape_fn_var"));
if (call_lowered_props.attrs.metadata.count("prim_shape_fn_var")) {
prim_shape_fn_var =
Downcast<GlobalVar>(call_lowered_props.attrs.metadata.at("prim_shape_fn_var"));
}
Array<GlobalVar> all_prim_shape_fn_vars;
if (attrs->metadata.count("all_prim_shape_fn_vars")) {
all_prim_shape_fn_vars =
Downcast<Array<GlobalVar>>(attrs->metadata.at("all_prim_shape_fn_vars"));
if (call_lowered_props.attrs.metadata.count("all_prim_shape_fn_vars")) {
all_prim_shape_fn_vars = Downcast<Array<GlobalVar>>(
call_lowered_props.attrs.metadata.at("all_prim_shape_fn_vars"));
}
Array<Integer> prim_shape_fn_states;
if (attrs->metadata.count("prim_shape_fn_states")) {
prim_shape_fn_states = Downcast<Array<Integer>>(attrs->metadata.at("prim_shape_fn_states"));
if (call_lowered_props.attrs.metadata.count("prim_shape_fn_states")) {
prim_shape_fn_states =
Downcast<Array<Integer>>(call_lowered_props.attrs.metadata.at("prim_shape_fn_states"));
}

size_t num_shape_inputs = 0;
if (attrs->metadata.count("prim_shape_fn_num_inputs")) {
if (call_lowered_props.attrs.metadata.count("prim_shape_fn_num_inputs")) {
num_shape_inputs = static_cast<size_t>(
Downcast<Integer>(attrs->metadata.at("prim_shape_fn_num_inputs"))->value);
Downcast<Integer>(call_lowered_props.attrs.metadata.at("prim_shape_fn_num_inputs"))
->value);
}
size_t num_shape_outputs = 0;
if (attrs->metadata.count("prim_shape_fn_num_outputs")) {
if (call_lowered_props.attrs.metadata.count("prim_shape_fn_num_outputs")) {
num_shape_outputs = static_cast<size_t>(
Downcast<Integer>(attrs->metadata.at("prim_shape_fn_num_outputs"))->value);
Downcast<Integer>(call_lowered_props.attrs.metadata.at("prim_shape_fn_num_outputs"))
->value);
}

return InvokePrimitiveOp(func, all_prim_fn_vars, target_, prim_shape_fn_var,
all_prim_shape_fn_vars, prim_shape_fn_states, num_shape_inputs,
num_shape_outputs, cpu_target_, args);

return InvokePrimitiveOp(call_lowered_props.lowered_func, all_prim_fn_vars, target_,
prim_shape_fn_var, all_prim_shape_fn_vars, prim_shape_fn_states,
num_shape_inputs, num_shape_outputs, cpu_target_, args);
} else { // All other calls
// Evaluate all arguments
std::vector<ObjectRef> args;
Expand Down

0 comments on commit 5861dc6

Please sign in to comment.