Skip to content


[Relax][Transform] Handle identical PrimFunc with distinct VDevice
Browse files Browse the repository at this point in the history
Prior to this commit, if an `IRModule` contained two expressions,
where the types of the arguments differed only by the `VDevice`, these
would be legalized to produce a single PrimFunc.  This PrimFunc would
have a `tvm::attr::kTarget` annotation specific to one of those
expressions, and would be incorrect for use in the other location.

This commit updates the `LegalizeOps` transform to handle this case,
producing multiple TIR PrimFuncs if required by the `VDevice` annotations.
  • Loading branch information
Lunderberg committed Apr 30, 2024
1 parent 6252fa5 commit 2d7e065
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 8 deletions.
95 changes: 87 additions & 8 deletions src/relax/transform/
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <tvm/relax/op_attr_types.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relax/transform.h>
#include <tvm/tir/transform.h>

namespace tvm {
namespace relax {
Expand Down Expand Up @@ -83,7 +84,12 @@ class LegalizeMutator : public ExprMutator {
builder_->UpdateFunction(gv, f);
return builder_->GetContextIRModule();
IRModule output = builder_->GetContextIRModule();
if (requires_tir_convert_ssa_) {
output = tir::transform::ConvertSSA()(output);

return output;

Expand Down Expand Up @@ -129,7 +135,7 @@ class LegalizeMutator : public ExprMutator {
return Call(call_pure_packed_op, ret_args, ret->attrs, ret->sinfo_args);

Target GetTarget(const Array<StructInfo>& sinfos) {
Optional<Target> GetTarget(const Array<StructInfo>& sinfos) {
for (auto sinfo : sinfos) {
if (const auto* tinfo =<TensorStructInfoNode>()) {
if (tinfo->vdevice.defined()) {
Expand All @@ -142,20 +148,90 @@ class LegalizeMutator : public ExprMutator {
return GetTarget(tup_sinfo->fields);
return Target();
return NullOpt;

void SaveTarget(const Expr& expr) {
if (expr->IsInstance<CallNode>()) {
auto call = Downcast<Call>(expr);
auto target = GetTarget(call->sinfo_args);
const GlobalVarNode* gvar_node;
if (target.defined() && (gvar_node = call->args[0].as<GlobalVarNode>())) {
this->tmap_.Set(GetRef<GlobalVar>(gvar_node), target);

if (auto target = GetTarget(call->sinfo_args)) {
if (auto gvar = call->args[0].as<GlobalVar>()) {
this->tmap_.Set(gvar.value(), target.value());

Expr BindTarget(Expr expr) {
if (!expr->IsInstance<CallNode>()) {
// FLegalize returned something other than a relax::Call. This
// post-processing only handles cases where legalization
// produces a lowered call node. In principle, this
// post-processing isn't necessary, and FLegalize should already
// have generated vdevice-aware kernels, so hopefully the
// FLegalize implementation did so.
return expr;

auto call = Downcast<Call>(expr);

auto vdevice_target = GetTarget(call->sinfo_args);
if (!vdevice_target.defined()) {
// No vdevice annotation is present, so we don't need to apply
// any updates.
return expr;

if (call->args.empty()) {
return expr;

auto gvar = call->args[0].as<GlobalVar>();
if (!gvar.defined()) {
// This is not a call into a legalized function within the
// current IRModule, so no post-processing is required.
return expr;

auto base_func = builder_->GetContextIRModule()->Lookup(gvar.value());
auto opt_prim_func =<tir::PrimFunc>();
if (!opt_prim_func) {
// The call is to something other than a PrimFunc. It may be
// another Relax function, in which case the legalization of its
// body will handle any additional target annotations.
return expr;
auto prim_func = opt_prim_func.value();

auto func_target = prim_func->GetAttr<Target>(tvm::attr::kTarget);
if (func_target && func_target.value()->kind == vdevice_target.value()->kind) {
// The function already has compatible annotations for the
// target, so no modifications are required.
return expr;

// The FLegalize function generated a PrimFunc, but that PrimFunc
// doesn't have annotations compatible with the vdevice required
// by the Relax StructInfo. Update the call to instead call a
// `PrimFunc` with the appropriate target annotation. In the
// future, this may be treated as a bug in the FLegalize
// implementation, rather than expected output from it.
auto new_prim_func = WithAttr(prim_func, tvm::attr::kTarget, vdevice_target.value());
auto new_gvar_name = [&]() -> std::string {
std::stringstream ss;
ss << gvar.value()->name_hint;
ss << "_";
ss << vdevice_target.value()->kind->name;
return ss.str();
auto new_gvar = builder_->AddFunction(new_prim_func, new_gvar_name);
requires_tir_convert_ssa_ = true;

call.CopyOnWrite()->args.Set(0, new_gvar);
return call;

Expr VisitExpr_(const CallNode* call) final {
Call visited_call = Downcast<Call>(this->VisitExprPostOrder_(call));
static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
Expand Down Expand Up @@ -268,8 +344,10 @@ class LegalizeMutator : public ExprMutator {
Expr legalized = legalization_func(builder_, visited_call);

legalized = BindTarget(legalized);

// Save the expected target info. into tmap_
// SaveTarget(legalized);

legalized = builder_->Normalize(legalized);

Expand Down Expand Up @@ -305,6 +383,7 @@ class LegalizeMutator : public ExprMutator {
Map<String, PackedFunc> cmap_;
/*! \brief The map from GlobalVar of PrimFunc to compilation Target. */
Map<GlobalVar, Target> tmap_;
bool requires_tir_convert_ssa_{false};
* \brief A boolean value indicating if to print warnings for CallNode whose op's
* legalization function is not registered.
Expand Down
36 changes: 36 additions & 0 deletions src/tir/transforms/
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,42 @@ class IRConvertSSA final : public StmtExprMutator {
return std::move(decl);

Stmt VisitStmt_(const BlockNode* op) final {
Block block = GetRef<Block>(op);

// The BlockNode is the point of definition for the IterVar
// instances. These re-defines must be present before visiting
// the body of the BlockNode.
std::vector<ScopedRedefine> redefines;
Array<IterVar> iter_vars = op->iter_vars.Map([&](IterVar iter_var) {
if (defined_.count(iter_var->var.get())) {
redefines.emplace_back(this, iter_var->var);
iter_var.CopyOnWrite()->var = redefines.back().new_var;
} else {
return iter_var;
Array<BufferRegion> reads =
block->reads.Map([&](const auto& region) { return VisitBufferAccess(region); });
Array<BufferRegion> writes =
block->writes.Map([&](const auto& region) { return VisitBufferAccess(region); });

if (!reads.same_as(block->reads) || !writes.same_as(block->writes) ||
!iter_vars.same_as(op->iter_vars)) {
auto write_ptr = block.CopyOnWrite();
write_ptr->reads = reads;
write_ptr->writes = writes;
write_ptr->iter_vars = iter_vars;

Stmt output = Downcast<Block>(StmtExprMutator::VisitStmt_(block.get()));

while (redefines.size()) redefines.pop_back();

return output;

template <typename Node>
Node VisitBufferAccess(Node node) {
Buffer new_buf = GetRemappedBuffer(node->buffer);
Expand Down
113 changes: 113 additions & 0 deletions tests/python/relax/
Original file line number Diff line number Diff line change
Expand Up @@ -356,5 +356,118 @@ def main(, AfterSecondIter)

def test_legalize_with_vdevice():
"""Legalization may generate kernels for multiple targets
This is a regression test. In previous implementations, Relax
expressions whose argument types differed only by their `vdevice`
would be legalized to use the same `PrimFunc`.

class Before:
I.module_global_infos({"vdevice": [I.vdevice("llvm")]})

def func_cuda(A: R.Tensor([32, 32], "float32"), B: R.Tensor([32, 32], "float32")):
C = R.add(A, B)
return C

def func_llvm(
A: R.Tensor([32, 32], "float32", "llvm"), B: R.Tensor([32, 32], "float32", "llvm")
C = R.add(A, B)
return C

class Expected:
"vdevice": [
"keys": ["cpu"],
"kind": "llvm",
"mtriple": "x86_64-pc-linux-gnu",
"tag": "",

def add(
A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
T_add: T.Buffer((T.int64(32), T.int64(32)), "float32"),
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(32), T.int64(32)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
T.writes(T_add[v_ax0, v_ax1])
T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1]

def add_llvm(
A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
T_add: T.Buffer((T.int64(32), T.int64(32)), "float32"),
"keys": ["cpu"],
"kind": "llvm",
"mtriple": "x86_64-pc-linux-gnu",
"tag": "",
"tir.noalias": T.bool(True),
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(32), T.int64(32)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
T.writes(T_add[v_ax0, v_ax1])
T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1]

def func_cuda(
A: R.Tensor((32, 32), dtype="float32"), B: R.Tensor((32, 32), dtype="float32")
) -> R.Tensor((32, 32), dtype="float32"):
cls = Expected
C = R.call_tir(cls.add, (A, B), out_sinfo=R.Tensor((32, 32), dtype="float32"))
return C

def func_llvm(
A: R.Tensor((32, 32), dtype="float32", vdevice="llvm:0"),
B: R.Tensor((32, 32), dtype="float32", vdevice="llvm:0"),
) -> R.Tensor((32, 32), dtype="float32", vdevice="llvm:0"):
cls = Expected
C = R.call_tir(
(A, B),
out_sinfo=R.Tensor((32, 32), dtype="float32", vdevice="llvm:0"),
return C

After = tvm.relax.transform.LegalizeOps()(Before), After)

if __name__ == "__main__":

0 comments on commit 2d7e065

Please sign in to comment.