From 2d7e065ea0891f98c1b752292618a04fbb48c896 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Tue, 30 Apr 2024 11:26:13 -0500
Subject: [PATCH] [Relax][Transform] Handle identical PrimFunc with distinct VDevice

Prior to this commit, if an `IRModule` contained two expressions whose
argument types differed only by the `VDevice`, these would be legalized
to produce a single PrimFunc. This PrimFunc would have a
`tvm::attr::kTarget` annotation specific to one of those expressions,
and would be incorrect for use in the other location.

This commit updates the `LegalizeOps` transform to handle this case,
producing multiple TIR PrimFuncs if required by the `VDevice`
annotations.
---
 src/relax/transform/legalize_ops.cc           |  95 +++++++++++++--
 src/tir/transforms/ir_utils.cc                |  36 ++++++
 .../relax/test_transform_legalize_ops.py      | 113 ++++++++++++++++++
 3 files changed, 236 insertions(+), 8 deletions(-)

diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc
index e2e463ff2b2f8..d31284b0fe356 100644
--- a/src/relax/transform/legalize_ops.cc
+++ b/src/relax/transform/legalize_ops.cc
@@ -28,6 +28,7 @@
 #include <tvm/relax/expr_functor.h>
 #include <tvm/relax/op_attr_types.h>
 #include <tvm/relax/transform.h>
+#include <tvm/tir/transform.h>
 
 namespace tvm {
 namespace relax {
@@ -83,7 +84,12 @@
         builder_->UpdateFunction(gv, f);
       }
     }
-    return builder_->GetContextIRModule();
+    IRModule output = builder_->GetContextIRModule();
+    if (requires_tir_convert_ssa_) {
+      output = tir::transform::ConvertSSA()(output);
+    }
+
+    return output;
   }
 
  private:
@@ -129,7 +135,7 @@ class LegalizeMutator : public ExprMutator {
     return Call(call_pure_packed_op, ret_args, ret->attrs, ret->sinfo_args);
   }
 
-  Target GetTarget(const Array<StructInfo>& sinfos) {
+  Optional<Target> GetTarget(const Array<StructInfo>& sinfos) {
     for (auto sinfo : sinfos) {
       if (const auto* tinfo = sinfo.as<TensorStructInfoNode>()) {
         if (tinfo->vdevice.defined()) {
@@ -142,20 +148,90 @@ class LegalizeMutator : public ExprMutator {
         return GetTarget(tup_sinfo->fields);
       }
     }
-    return Target();
+    return NullOpt;
   }
 
   void SaveTarget(const Expr& expr) {
     if (expr->IsInstance<CallNode>()) {
       auto call = Downcast<Call>(expr);
-      auto target = GetTarget(call->sinfo_args);
-      const GlobalVarNode* gvar_node;
-      if (target.defined() && (gvar_node = call->args[0].as<GlobalVarNode>())) {
-        this->tmap_.Set(GetRef<GlobalVar>(gvar_node), target);
+
+      if (auto target = GetTarget(call->sinfo_args)) {
+        if (auto gvar = call->args[0].as<GlobalVar>()) {
+          this->tmap_.Set(gvar.value(), target.value());
+        }
       }
     }
   }
 
+  Expr BindTarget(Expr expr) {
+    if (!expr->IsInstance<CallNode>()) {
+      // FLegalize returned something other than a relax::Call. This
+      // post-processing only handles cases where legalization
+      // produces a lowered call node. In principle, this
+      // post-processing isn't necessary, and FLegalize should already
+      // have generated vdevice-aware kernels, so hopefully the
+      // FLegalize implementation did so.
+      return expr;
+    }
+
+    auto call = Downcast<Call>(expr);
+
+    auto vdevice_target = GetTarget(call->sinfo_args);
+    if (!vdevice_target.defined()) {
+      // No vdevice annotation is present, so we don't need to apply
+      // any updates.
+      return expr;
+    }
+
+    if (call->args.empty()) {
+      return expr;
+    }
+
+    auto gvar = call->args[0].as<GlobalVar>();
+    if (!gvar.defined()) {
+      // This is not a call into a legalized function within the
+      // current IRModule, so no post-processing is required.
+      return expr;
+    }
+
+    auto base_func = builder_->GetContextIRModule()->Lookup(gvar.value());
+    auto opt_prim_func = base_func.as<tir::PrimFunc>();
+    if (!opt_prim_func) {
+      // The call is to something other than a PrimFunc. It may be
+      // another Relax function, in which case the legalization of its
+      // body will handle any additional target annotations.
+      return expr;
+    }
+    auto prim_func = opt_prim_func.value();
+
+    auto func_target = prim_func->GetAttr<Target>(tvm::attr::kTarget);
+    if (func_target && func_target.value()->kind == vdevice_target.value()->kind) {
+      // The function already has compatible annotations for the
+      // target, so no modifications are required.
+      return expr;
+    }
+
+    // The FLegalize function generated a PrimFunc, but that PrimFunc
+    // doesn't have annotations compatible with the vdevice required
+    // by the Relax StructInfo. Update the call to instead call a
+    // `PrimFunc` with the appropriate target annotation. In the
+    // future, this may be treated as a bug in the FLegalize
+    // implementation, rather than expected output from it.
+    auto new_prim_func = WithAttr(prim_func, tvm::attr::kTarget, vdevice_target.value());
+    auto new_gvar_name = [&]() -> std::string {
+      std::stringstream ss;
+      ss << gvar.value()->name_hint;
+      ss << "_";
+      ss << vdevice_target.value()->kind->name;
+      return ss.str();
+    }();
+    auto new_gvar = builder_->AddFunction(new_prim_func, new_gvar_name);
+    requires_tir_convert_ssa_ = true;
+
+    call.CopyOnWrite()->args.Set(0, new_gvar);
+    return call;
+  }
+
   Expr VisitExpr_(const CallNode* call) final {
     Call visited_call = Downcast<Call>(this->VisitExprPostOrder_(call));
     static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
@@ -268,8 +344,10 @@ class LegalizeMutator : public ExprMutator {
     }
 
     Expr legalized = legalization_func(builder_, visited_call);
+    legalized = BindTarget(legalized);
+
     // Save the expected target info. into tmap_
-    SaveTarget(legalized);
+    // SaveTarget(legalized);
 
     legalized = builder_->Normalize(legalized);
 
@@ -305,6 +383,7 @@ class LegalizeMutator : public ExprMutator {
   Map<String, PackedFunc> cmap_;
   /*! \brief The map from GlobalVar of PrimFunc to compilation Target. */
   Map<GlobalVar, Target> tmap_;
+  bool requires_tir_convert_ssa_{false};
   /*!
    * \brief A boolean value indicating if to print warnings for CallNode whose op's
    * legalization function is not registered.
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index c52027acba134..50edc676bd70d 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -246,6 +246,42 @@ class IRConvertSSA final : public StmtExprMutator {
     return std::move(decl);
   }
 
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Block block = GetRef<Block>(op);
+
+    // The BlockNode is the point of definition for the IterVar
+    // instances. These re-defines must be present before visiting
+    // the body of the BlockNode.
+    std::vector<ScopedRedefine> redefines;
+    Array<IterVar> iter_vars = op->iter_vars.Map([&](IterVar iter_var) {
+      if (defined_.count(iter_var->var.get())) {
+        redefines.emplace_back(this, iter_var->var);
+        iter_var.CopyOnWrite()->var = redefines.back().new_var;
+      } else {
+        defined_.insert(iter_var->var.get());
+      }
+      return iter_var;
+    });
+    Array<BufferRegion> reads =
+        block->reads.Map([&](const auto& region) { return VisitBufferAccess(region); });
+    Array<BufferRegion> writes =
+        block->writes.Map([&](const auto& region) { return VisitBufferAccess(region); });
+
+    if (!reads.same_as(block->reads) || !writes.same_as(block->writes) ||
+        !iter_vars.same_as(op->iter_vars)) {
+      auto write_ptr = block.CopyOnWrite();
+      write_ptr->reads = reads;
+      write_ptr->writes = writes;
+      write_ptr->iter_vars = iter_vars;
+    }
+
+    Stmt output = Downcast<Stmt>(StmtExprMutator::VisitStmt_(block.get()));
+
+    while (redefines.size()) redefines.pop_back();
+
+    return output;
+  }
+
   template <typename Node>
   Node VisitBufferAccess(Node node) {
     Buffer new_buf = GetRemappedBuffer(node->buffer);
diff --git a/tests/python/relax/test_transform_legalize_ops.py b/tests/python/relax/test_transform_legalize_ops.py
index 47eeb68341b35..3846d40ea49cc 100644
--- a/tests/python/relax/test_transform_legalize_ops.py
+++ b/tests/python/relax/test_transform_legalize_ops.py
@@ -356,5 +356,118 @@ def main(
     tvm.ir.assert_structural_equal(AfterFirstIter, AfterSecondIter)
 
 
+def test_legalize_with_vdevice():
+    """Legalization may generate kernels for multiple targets
+
+    This is a regression test. In previous implementations, Relax
+    expressions whose argument types differed only by their `vdevice`
+    would be legalized to use the same `PrimFunc`.
+
+    """
+
+    @I.ir_module
+    class Before:
+        I.module_global_infos({"vdevice": [I.vdevice("llvm")]})
+
+        @R.function
+        def func_cuda(A: R.Tensor([32, 32], "float32"), B: R.Tensor([32, 32], "float32")):
+            C = R.add(A, B)
+            return C
+
+        @R.function
+        def func_llvm(
+            A: R.Tensor([32, 32], "float32", "llvm"), B: R.Tensor([32, 32], "float32", "llvm")
+        ):
+            C = R.add(A, B)
+            return C
+
+    @I.ir_module
+    class Expected:
+        I.module_global_infos(
+            {
+                "vdevice": [
+                    I.vdevice(
+                        {
+                            "keys": ["cpu"],
+                            "kind": "llvm",
+                            "mtriple": "x86_64-pc-linux-gnu",
+                            "tag": "",
+                        },
+                        0,
+                        "global",
+                    )
+                ]
+            }
+        )
+
+        @T.prim_func(private=True)
+        def add(
+            A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            T_add: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(T.int64(32), T.int64(32)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
+                    T.writes(T_add[v_ax0, v_ax1])
+                    T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1]
+
+        @T.prim_func(private=True)
+        def add_llvm(
+            A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            T_add: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+        ):
+            T.func_attr(
+                {
+                    "target": T.target(
+                        {
+                            "keys": ["cpu"],
+                            "kind": "llvm",
+                            "mtriple": "x86_64-pc-linux-gnu",
+                            "tag": "",
+                        }
+                    ),
+                    "tir.noalias": T.bool(True),
+                }
+            )
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(T.int64(32), T.int64(32)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
+                    T.writes(T_add[v_ax0, v_ax1])
+                    T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1]
+
+        @R.function
+        def func_cuda(
+            A: R.Tensor((32, 32), dtype="float32"), B: R.Tensor((32, 32), dtype="float32")
+        ) -> R.Tensor((32, 32), dtype="float32"):
+            cls = Expected
+            C = R.call_tir(cls.add, (A, B), out_sinfo=R.Tensor((32, 32), dtype="float32"))
+            return C
+
+        @R.function
+        def func_llvm(
+            A: R.Tensor((32, 32), dtype="float32", vdevice="llvm:0"),
+            B: R.Tensor((32, 32), dtype="float32", vdevice="llvm:0"),
+        ) -> R.Tensor((32, 32), dtype="float32", vdevice="llvm:0"):
+            cls = Expected
+            C = R.call_tir(
+                cls.add_llvm,
+                (A, B),
+                out_sinfo=R.Tensor((32, 32), dtype="float32", vdevice="llvm:0"),
+            )
+            return C
+
+    with tvm.target.Target("cuda"):
+        After = tvm.relax.transform.LegalizeOps()(Before)
+
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
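
Illustrative sketch, not part of the patch: the standalone script below mirrors
the `Before` module from the new test case and shows the scenario described in
the commit message, two call sites whose argument types differ only by their
`vdevice`. The module name and the CUDA target context are assumptions borrowed
from that test; with this change, `LegalizeOps` is expected to emit separate
PrimFuncs (e.g. `add` and `add_llvm`) rather than one shared kernel.

import tvm
from tvm.script import ir as I, relax as R


# Hypothetical module name; contents mirror `Before` in the test above.
@I.ir_module
class Module:
    I.module_global_infos({"vdevice": [I.vdevice("llvm")]})

    @R.function
    def func_cuda(A: R.Tensor([32, 32], "float32"), B: R.Tensor([32, 32], "float32")):
        C = R.add(A, B)
        return C

    @R.function
    def func_llvm(
        A: R.Tensor([32, 32], "float32", "llvm"), B: R.Tensor([32, 32], "float32", "llvm")
    ):
        C = R.add(A, B)
        return C


# Legalize under a CUDA target context, as the test does; the call sites
# annotated with the "llvm" vdevice should receive their own
# target-annotated PrimFunc instead of reusing the default one.
with tvm.target.Target("cuda"):
    print(tvm.relax.transform.LegalizeOps()(Module))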