Skip to content

Commit

Permalink
Integrate tir constant nodes in compilation pipeline
Browse files Browse the repository at this point in the history
This PR integrates tir.allocate_const into the compilation pipeline to support --link-params.

Change-Id: Ic8d0cb75d596299fcae7078b304598afbf0c5494

Co-authored-by: Giuseppe Rossini <giuseros85@gmail.com>
Change-Id: Id98cc682bbfacfe75c4d8b260fd41658f1f196b2
  • Loading branch information
2 people authored and d-smirnov committed Sep 7, 2021
1 parent 2a3def8 commit a4d809e
Show file tree
Hide file tree
Showing 26 changed files with 452 additions and 210 deletions.
36 changes: 0 additions & 36 deletions include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,42 +151,6 @@ class PrimFunc : public BaseFunc {
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode);
};

/*!
 * \brief Describes one parameter that should be linked into the generated module.
 *
 * When parameters are to be linked in with generated code (i.e. on target_host-compatible
 * backends), Relay attaches instances of this object to a global TIR function. Code-generators
 * use the information contained in this node to include the parameter data in the generated
 * module.
 */
class LinkedParamNode : public Object {
 public:
  /*! \brief Unique numeric identifier used by runtimes to lookup this parameter. */
  int64_t id;

  /*! \brief Parameter data which should get linked into the final module. */
  ::tvm::runtime::NDArray param;

  /*! \brief Register fields with TVM's reflection machinery (serialization, FFI access). */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("id", &id);
    v->Visit("param", &param);
  }

  /*! \brief Type key identifying this node type in TVM's object system. */
  static constexpr const char* _type_key = "tir.LinkedParam";
  TVM_DECLARE_FINAL_OBJECT_INFO(LinkedParamNode, Object);
};

/*!
 * \brief Managed reference to LinkedParamNode.
 * \sa LinkedParamNode
 */
class LinkedParam : public ObjectRef {
 public:
  /*!
   * \brief Construct a LinkedParam.
   * \param id Unique numeric identifier used by runtimes to look up this parameter.
   * \param param Parameter data which should get linked into the final module.
   */
  TVM_DLL LinkedParam(int64_t id, ::tvm::runtime::NDArray param);

  TVM_DEFINE_OBJECT_REF_METHODS(LinkedParam, ObjectRef, LinkedParamNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
};

/*!
* \brief Specialize parameters of PrimFunc.
* \param func The PrimFunc to be specialized.
Expand Down
36 changes: 36 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,42 @@ class Allocate : public Stmt {
TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
};

/*!
 * \brief Describes one parameter that should be linked into the generated module.
 *
 * When parameters are to be linked in with generated code (i.e. on target_host-compatible
 * backends), Relay attaches instances of this object to a global TIR function. Code-generators
 * use the information contained in this node to include the parameter data in the generated
 * module.
 *
 * NOTE(review): this class was moved here verbatim from include/tvm/tir/function.h by this
 * change; the definition itself is unchanged.
 */
class LinkedParamNode : public Object {
 public:
  /*! \brief Unique numeric identifier used by runtimes to lookup this parameter. */
  int64_t id;

  /*! \brief Parameter data which should get linked into the final module. */
  ::tvm::runtime::NDArray param;

  /*! \brief Register fields with TVM's reflection machinery (serialization, FFI access). */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("id", &id);
    v->Visit("param", &param);
  }

  /*! \brief Type key identifying this node type in TVM's object system. */
  static constexpr const char* _type_key = "tir.LinkedParam";
  TVM_DECLARE_FINAL_OBJECT_INFO(LinkedParamNode, Object);
};

/*!
 * \brief Managed reference to LinkedParamNode.
 * \sa LinkedParamNode
 */
class LinkedParam : public ObjectRef {
 public:
  /*!
   * \brief Construct a LinkedParam.
   * \param id Unique numeric identifier used by runtimes to look up this parameter.
   * \param param Parameter data which should get linked into the final module.
   */
  TVM_DLL LinkedParam(int64_t id, ::tvm::runtime::NDArray param);

  TVM_DEFINE_OBJECT_REF_METHODS(LinkedParam, ObjectRef, LinkedParamNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
};

/*!
* \brief Allocate a buffer that can be used in body.
*/
Expand Down
4 changes: 4 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
#define TVM_TIR_TRANSFORM_H_

#include <tvm/ir/transform.h>
#include <tvm/relay/expr.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>

#include <string>
#include <vector>

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -463,6 +465,8 @@ TVM_DLL Pass UnifyThreadBinding();
*/
TVM_DLL Pass MergeDynamicSharedMemoryAllocations();

/*!
 * \brief Bind the given Relay constants to the PrimFuncs of the module the pass is applied to.
 *
 * Used by the TE compiler after lowering a schedule: the constants collected from
 * cfunc->constant_tensors are handed to this pass, presumably so that references to them are
 * materialized as tir.allocate_const nodes in the lowered functions — TODO(review): confirm
 * against the pass implementation, which is not shown here.
 *
 * NOTE(review): taking relay::ConstantNode here makes the tir transform layer depend on relay
 * (see the new <tvm/relay/expr.h> include above) — consider whether a tir-level representation
 * of the constant data would avoid this layering inversion.
 *
 * \param constants The Relay constants to bind into the module's functions.
 * \return The pass.
 */
TVM_DLL Pass BindParams(const std::vector<const relay::ConstantNode*>& constants);

} // namespace transform
} // namespace tir
} // namespace tvm
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/te/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from .tensor import TensorSlice, Tensor
from .tensor_intrin import decl_tensor_intrin
from .tag import tag_scope
from .operation import placeholder, compute, scan, extern, var, size_var
from .operation import placeholder, compute, scan, extern, var, size_var, const
from .operation import thread_axis, reduce_axis
from .operation import create_prim_func

Expand Down
22 changes: 22 additions & 0 deletions python/tvm/te/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,28 @@ def var(name="tindex", dtype="int32", span=None):
return tvm.tir.Var(name, dtype, span)


def const(name="tindex", dtype="int32", span=None):
    """Create a new TIR constant with the specified name and dtype.

    NOTE(review): this mirrors ``var`` above; the default ``name="tindex"`` and the
    docstring appear copy-pasted from ``var`` — confirm that ``tvm.tir.Const`` exists
    and accepts ``(name, dtype, span)``, and consider a more fitting default name.

    Parameters
    ----------
    name : str
        The name of the constant.

    dtype : str
        The data type of the constant.

    span : Optional[Span]
        The location of this constant in the source.

    Returns
    -------
    const : tir.Const
        The result constant.
    """
    return tvm.tir.Const(name, dtype, span)


def size_var(name="size", dtype="int32", span=None):
"""Create a new variable represents a tensor shape size, which is non-negative.
Expand Down
2 changes: 1 addition & 1 deletion src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) {
auto ndarray_str = ss.str();

Doc doc;
var_not_in_headers.insert(alloc->buffer_var.get());
var_not_in_headers_.insert(alloc->buffer_var.get());
if (current_num_ != num_child_ - 1) {
doc << "with tir.allocate_const(" << ndarray_str << ", " << PrintDType(alloc->dtype) << ", "
<< Print(alloc->extents) << ")";
Expand Down
31 changes: 22 additions & 9 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -401,22 +401,20 @@ class AOTExecutorCodegen : public MixedModeVisitor {

void VisitExpr_(const ConstantNode* op) override {
Expr expr = GetRef<Expr>(op);
size_t index = params_.size();
std::string name = "p" + std::to_string(index);
StorageInfo& sinfo = storage_device_map_[expr];
param_storage_ids_[name] = sinfo->storage_ids[0];
params_[name] = op->data;
params_by_expr_.Set(expr, name);
std::stringstream ss;
ss << "constant_" << constant_map_.size();

tir::Var constant(ss.str(), PointerType(PrimType(DataType(op->data->dtype))));
constant_map_[constant.operator->()] = op;

// If the Constant node is an output node we need to copy the content of the parameter to the
// output A Var node can only produce a single output
auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]);
if (output_iter != return_sid_.end()) {
int output_index = std::distance(return_sid_.begin(), output_iter);
auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
{tir::StringImm(params_by_expr_[expr])});
CopyToOutput(main_signature_[input_vars_.size() + output_index], param_handle, false,
sinfo->storage_sizes_in_bytes[0]);
CopyToOutput(main_signature_[input_vars_.size() + output_index], constant,
/* pack_input */ false, sinfo->storage_sizes_in_bytes[0]);
}
}

Expand Down Expand Up @@ -490,6 +488,20 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}
}

for (auto kv : constant_map_) {
auto buffer_var = GetRef<tir::Var>(kv.first);
auto dtype = DataType(kv.second->data->dtype);

int ndim = kv.second->data->ndim;
Array<PrimExpr> extents;

for (int i = 0; i < ndim; i++) {
int shape = kv.second->data->shape[i];
extents.push_back(tir::make_const(DataType::Int(32), shape));
}
body = tir::AllocateConst(buffer_var, kv.second->data, dtype, extents, body);
}

// Define the attributes
body = tir::AttrStmt(PrimExpr(), tvm::tir::attr::device_type, 1, body);
body = tir::AttrStmt(PrimExpr(), tvm::tir::attr::device_id, 0, body);
Expand Down Expand Up @@ -538,6 +550,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
Map<Expr, String> params_by_expr_;
/*! \brief mapping between parameter names ("p0", "p1", etc..) and storage identifiers*/
std::unordered_map<std::string, int64_t> param_storage_ids_;
std::unordered_map<const tir::VarNode*, const ConstantNode*> constant_map_;

/*! \brief plan memory of device result */
StorageMap storage_device_map_;
Expand Down
31 changes: 7 additions & 24 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,13 @@ class RelayBuildModule : public runtime::ModuleNode {
}

// Fuse the operations if it is needed.
relay_module = transform::FuseOps()(relay_module);
if (targets.size() == 1) {
const auto& it = targets.begin();
With<Target> tctx((*it).second);
relay_module = transform::FuseOps()(relay_module);
} else {
relay_module = transform::FuseOps()(relay_module);
}

// Do layout rewrite for auto-scheduler.
if (backend::IsAutoSchedulerEnabled() && targets.size() == 1) {
Expand Down Expand Up @@ -495,29 +501,6 @@ class RelayBuildModule : public runtime::ModuleNode {
if (lowered_funcs.find(ext_dev) != lowered_funcs.end()) {
lowered_funcs.Set(ext_dev, IRModule());
}

// Generate a placeholder function that attaches linked params as its arguments.
if (target_host->GetAttr<Bool>("link-params").value_or(Bool(false))) {
CHECK(pf != nullptr) << "Unable to link-params with no target_host and no llvm codegen.";
auto param_ids = executor_codegen_->GetParamIds();
auto link_params = Map<String, tir::LinkedParam>();
for (auto param : ret_.params) {
link_params.Set(param.first, tir::LinkedParam(param_ids[param.first], param.second));
}

Map<String, ObjectRef> dict;
dict.Set(tvm::tir::attr::kLinkedParams, link_params);
dict.Set(tvm::attr::kGlobalSymbol, String(::tvm::runtime::symbol::tvm_lookup_linked_param));
DictAttrs attrs{dict};
auto prim = tir::PrimFunc(Array<tir::Var>(), tir::SeqStmt(Array<tir::Stmt>()), VoidType(),
Map<tir::Var, tir::Buffer>(), attrs);
if (lowered_funcs.find(target_host) == lowered_funcs.end()) {
lowered_funcs.Set(target_host, IRModule(Map<GlobalVar, BaseFunc>({})));
}
lowered_funcs[target_host]->Add(GlobalVar(::tvm::runtime::symbol::tvm_lookup_linked_param),
prim);
}

// When there is no lowered_funcs due to reasons such as optimization.
if (lowered_funcs.size() == 0) {
if (target_host.defined() && target_host->kind->name == "llvm") {
Expand Down
8 changes: 8 additions & 0 deletions src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/tir/transform.h>
#include <tvm/topi/tags.h>

#include <functional>
Expand Down Expand Up @@ -221,10 +222,17 @@ class CompileEngineImpl : public CompileEngineNode {
for (te::Tensor arg : cfunc->outputs) {
all_args.push_back(arg);
}
std::vector<const ConstantNode*> all_consts;
for (auto kv : cfunc->constant_tensors) {
all_args.push_back(kv.second);
all_consts.push_back(kv.first);
}

// lower the function
std::unordered_map<te::Tensor, tir::Buffer> binds;
auto func_name = cfunc->prim_fn_var->name_hint;
cfunc->funcs->Update(tvm::LowerSchedule(cfunc->schedule, all_args, func_name, binds));
cfunc->funcs->Update(tir::transform::BindParams(all_consts)(cfunc->funcs));
value->cached_func = cfunc;

return value;
Expand Down
7 changes: 7 additions & 0 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <tvm/runtime/registry.h>
#include <tvm/te/schedule.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/tir/transform.h>
#include <tvm/topi/tags.h>

#include <functional>
Expand Down Expand Up @@ -246,10 +247,16 @@ class TECompilerImpl : public TECompilerNode {
for (te::Tensor arg : cfunc->outputs) {
all_args.push_back(arg);
}
std::vector<const ConstantNode*> all_consts;
for (auto kv : cfunc->constant_tensors) {
all_args.push_back(kv.second);
all_consts.push_back(kv.first);
}

std::unordered_map<te::Tensor, tir::Buffer> binds;
auto func_name = cfunc->prim_fn_var->name_hint;
cfunc->funcs->Update(tvm::LowerSchedule(cfunc->schedule, all_args, func_name, binds));
cfunc->funcs->Update(tir::transform::BindParams(all_consts)(cfunc->funcs));
value->cached_func = cfunc;
return value;
}
Expand Down
Loading

0 comments on commit a4d809e

Please sign in to comment.