From 76da06da4d853f6dfc66c7bd9713fb9873dea739 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 27 May 2026 14:37:34 +0000 Subject: [PATCH 1/2] [REFACTOR][IR] Inline ReplaceGlobalVars into AttachGlobalSymbol `ReplaceGlobalVars` was declared in `include/tvm/ir/replace_global_vars.h` and dispatched per-function-type via a NodeFunctor vtable populated by relax and tirx static-init blocks. The only in-tree C++ caller was `AttachGlobalSymbol` in `src/relax/transform/attach_global_symbol.cc`, and the only python user was an `IRModule.replace_global_vars` method exercised only by its own tests. Move the dispatch into `attach_global_symbol.cc` as a file-local helper. The helper branches directly on `tirx::PrimFuncNode` / `relax::FunctionNode` / `relax::ExternFuncNode` instead of going through a NodeFunctor vtable, since `attach_global_symbol.cc` already includes the relax + tirx headers. Delete the IR-layer driver, the per-dialect dispatch registration files, the public header, the python wrapper, and its tests. The remaining behavior is still covered by `tests/python/relax/test_transform_attach_global_symbol.py` and by every relax pipeline that runs `AttachGlobalSymbol`. Removes the runtime registration-based coupling from the IR layer to relax/tirx dialects for this specific transform. --- include/tvm/ir/replace_global_vars.h | 57 ---- python/tvm/ir/module.py | 27 -- src/ir/replace_global_vars.cc | 110 ------- src/relax/transform/attach_global_symbol.cc | 118 ++++++- src/relax/transform/replace_global_vars.cc | 83 ----- src/tirx/transform/replace_global_vars.cc | 84 ----- .../ir/test_transform_replace_global_var.py | 308 ------------------ 7 files changed, 116 insertions(+), 671 deletions(-) delete mode 100644 include/tvm/ir/replace_global_vars.h delete mode 100644 src/ir/replace_global_vars.cc delete mode 100644 src/relax/transform/replace_global_vars.cc delete mode 100644 src/tirx/transform/replace_global_vars.cc delete mode 100644 tests/python/ir/test_transform_replace_global_var.py diff --git a/include/tvm/ir/replace_global_vars.h b/include/tvm/ir/replace_global_vars.h deleted file mode 100644 index 0a9b38529637..000000000000 --- a/include/tvm/ir/replace_global_vars.h +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ir/replace_global_vars.h - * - * \brief A utility to replace GlobalVar instances across all TVM IR - * types in an IRMdoule. - */ -#ifndef TVM_IR_REPLACE_GLOBAL_VARS_H_ -#define TVM_IR_REPLACE_GLOBAL_VARS_H_ - -#include - -namespace tvm { -namespace transform { - -/*! - * \brief Replace GlobalVar instances across any IR type. - * - * \param mod The module to update - * - * \param replacements The map, where each entry maps from an old - * `GlobalVar` to the new `GlobalVar` that should replace it. - * - * \return The updated IRModule - */ -TVM_DLL IRModule ReplaceGlobalVars(IRModule mod, ffi::Map replacements); - -struct GlobalVarReplacer { - using FType = NodeFunctor)>; - TVM_DLL static FType& vtable() { - static FType inst; - return inst; - } -}; - -} // namespace transform -} // namespace tvm - -#endif // TVM_IR_REPLACE_GLOBAL_VARS_H_ diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index a9f43e09bd57..95b9d940ecb4 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -195,33 +195,6 @@ def get_global_vars(self): """ return _ffi_api.Module_GetGlobalVars(self) - def replace_global_vars( - self, - replacements: dict[str | _expr.GlobalVar, str | _expr.GlobalVar], - ) -> "IRModule": - """Replace GlobalVar instances within the module - - Replace GlobalVars within the IRModule. Since the IRModule - may contain internal references to a GlobalVar, either in TIR - or in Relax, this method should be used whenever replacing or - renaming a GlobalVar. - - Parameters - ---------- - replacements: Dict[Union[str, _expr.GlobalVar], Union[str, _expr.GlobalVar]] - - A dictionary where each key is a GlobalVar to be replaced, - and the corresponding value is the GlobalVar with which to - replace it. - - Returns - ------- - IRModule - The updated module - - """ - return _ffi_api.Module_ReplaceGlobalVars(self, replacements) - @staticmethod def from_expr(expr, functions=None): """Construct a module from a standalone expression. diff --git a/src/ir/replace_global_vars.cc b/src/ir/replace_global_vars.cc deleted file mode 100644 index 2a3517b4d815..000000000000 --- a/src/ir/replace_global_vars.cc +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/ir/replace_global_vars.cc - * \brief IRModule transform to replace GlobalVar instances across any IR type. - */ - -#include -#include -#include - -#include - -namespace tvm { -namespace transform { - -IRModule ReplaceGlobalVars(IRModule mod, ffi::Map replacements) { - if (replacements.empty()) { - return mod; - } - - std::vector to_remove; - IRModule updates; - - const auto& vtable = GlobalVarReplacer::vtable(); - - for (const auto& [old_gvar, old_func] : mod->functions) { - auto new_gvar = replacements.Get(old_gvar).value_or(old_gvar); - auto new_func = vtable(old_func, replacements); - - if (!new_gvar.same_as(old_gvar)) { - to_remove.push_back(old_gvar); - } - if (!old_gvar.same_as(new_gvar) || !old_func.same_as(new_func)) { - updates->Add(new_gvar, new_func); - } - } - - if (to_remove.size() || updates->functions.size()) { - auto write_ptr = mod.CopyOnWrite(); - for (const auto& old_gvar : to_remove) { - write_ptr->Remove(old_gvar); - } - write_ptr->Update(updates); - } - return mod; -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("transform.ReplaceGlobalVars", ReplaceGlobalVars); -} - -IRModule ModuleReplaceGlobalVars( - IRModule mod, - ffi::Map, ffi::Variant> - replacements) { - ffi::Map gvar_replacements; - for (const auto& [before, after] : replacements) { - GlobalVar gvar_before; - if (auto gvar = before.as()) { - gvar_before = gvar.value(); - } else if (auto str = before.as()) { - gvar_before = mod->GetGlobalVar(str.value()); - } else { - TVM_FFI_THROW(InternalError) - << "ffi::Variant must contain either ffi::String or GlobalVar"; - } - - GlobalVar gvar_after; - if (auto gvar = after.as()) { - gvar_after = gvar.value(); - } else if (auto str = after.as()) { - gvar_after = gvar_before; - gvar_after.CopyOnWrite()->name_hint = str.value(); - } else { - TVM_FFI_THROW(InternalError) - << "ffi::Variant must contain either ffi::String or GlobalVar"; - } - - gvar_replacements.Set(gvar_before, gvar_after); - } - - return ReplaceGlobalVars(mod, gvar_replacements); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ir.Module_ReplaceGlobalVars", ModuleReplaceGlobalVars); -} - -} // namespace transform -} // namespace tvm diff --git a/src/relax/transform/attach_global_symbol.cc b/src/relax/transform/attach_global_symbol.cc index d22b6eb40a52..7b3cd85aec0e 100644 --- a/src/relax/transform/attach_global_symbol.cc +++ b/src/relax/transform/attach_global_symbol.cc @@ -24,15 +24,129 @@ #include #include #include -#include +#include #include #include #include +#include + +#include namespace tvm { namespace relax { namespace transform { +namespace { + +// File-local mutator: replace GlobalVar references inside a relax::Function. +struct RelaxGvarMutator : ExprMutator { + ffi::Map replacements; + explicit RelaxGvarMutator(ffi::Map replacements) + : replacements(replacements) {} + + using ExprMutator::VisitExpr_; + Expr VisitExpr_(const GlobalVarNode* node) override { + auto gvar = ffi::GetRef(node); + return replacements.Get(gvar).value_or(gvar); + } +}; + +// File-local mutator: replace GlobalVar references inside a tirx::PrimFunc. +struct TirxGvarMutator : tirx::StmtExprMutator { + ffi::Map replacements; + explicit TirxGvarMutator(ffi::Map replacements) + : replacements(replacements) {} + + PrimExpr VisitExpr_(const tirx::CallNode* node) override { + auto call = Downcast(tirx::StmtExprMutator::VisitExpr_(node)); + if (auto old_gvar = call->op.as()) { + if (auto new_gvar = replacements.Get(old_gvar.value())) { + call.CopyOnWrite()->op = new_gvar.value(); + } + } + return call; + } +}; + +// Replace GlobalVar references across all functions in the module. +// Direct dispatch on function type — no NodeFunctor indirection needed +// since this file already includes the relax + tirx headers. +IRModule ReplaceGlobalVarsInModule(IRModule mod, ffi::Map replacements) { + if (replacements.empty()) { + return mod; + } + + std::vector to_remove; + IRModule updates; + + for (const auto& [old_gvar, old_func] : mod->functions) { + auto new_gvar = replacements.Get(old_gvar).value_or(old_gvar); + BaseFunc new_func; + + if (auto* prim_func_node = old_func.as()) { + auto func = ffi::GetRef(prim_func_node); + TirxGvarMutator mutator(replacements); + auto new_body = mutator(func->body); + if (!new_body.same_as(func->body)) { + func.CopyOnWrite()->body = new_body; + } + // Update kGlobalSymbol if the function is externally exposed and being renamed. + if (auto opt = func->GetAttr(tvm::attr::kGlobalSymbol)) { + auto name = opt.value(); + for (const auto& [before, after] : replacements) { + if (before->name_hint == name) { + if (after->name_hint != name) { + func = WithAttr(func, tvm::attr::kGlobalSymbol, after->name_hint); + } + break; + } + } + } + new_func = func; + } else if (auto* relax_func_node = old_func.as()) { + RelaxGvarMutator mutator(replacements); + auto new_relax_func = + Downcast(mutator(Downcast(ffi::GetRef(relax_func_node)))); + // Update kGlobalSymbol if the function is externally exposed and being renamed. + if (auto opt = new_relax_func->GetAttr(tvm::attr::kGlobalSymbol)) { + auto name = opt.value(); + for (const auto& [before, after] : replacements) { + if (before->name_hint == name) { + if (after->name_hint != name) { + new_relax_func = WithAttr(new_relax_func, tvm::attr::kGlobalSymbol, after->name_hint); + } + break; + } + } + } + new_func = new_relax_func; + } else if (old_func.as()) { + // ExternFunc: no internal GlobalVar references to update. + new_func = old_func; + } else { + new_func = old_func; + } + + if (!new_gvar.same_as(old_gvar)) { + to_remove.push_back(old_gvar); + } + if (!old_gvar.same_as(new_gvar) || !old_func.same_as(new_func)) { + updates->Add(new_gvar, new_func); + } + } + + if (to_remove.size() || updates->functions.size()) { + auto write_ptr = mod.CopyOnWrite(); + for (const auto& old_gvar : to_remove) { + write_ptr->Remove(old_gvar); + } + write_ptr->Update(updates); + } + return mod; +} + +} // namespace + Pass AttachGlobalSymbol() { auto pass_func = [=](IRModule mod, PassContext pc) { ffi::String c_prefix = mod->GetAttr(tvm::attr::kSystemLibPrefix).value_or(""); @@ -74,7 +188,7 @@ Pass AttachGlobalSymbol() { mod.CopyOnWrite()->Update(updates); if (gvar_updates.size()) { - mod = tvm::transform::ReplaceGlobalVars(mod, gvar_updates); + mod = ReplaceGlobalVarsInModule(mod, gvar_updates); } } return mod; diff --git a/src/relax/transform/replace_global_vars.cc b/src/relax/transform/replace_global_vars.cc deleted file mode 100644 index f895cd50eb54..000000000000 --- a/src/relax/transform/replace_global_vars.cc +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * - * \file src/relax/transform/replace_global_vars.cc - * - * \brief GlobalVar replacement across IR types - */ - -#include -#include -#include -#include -#include - -namespace tvm { -namespace relax { - -namespace { -using tvm::transform::GlobalVarReplacer; - -struct Mutator : ExprMutator { - ffi::Map replacements; - explicit Mutator(ffi::Map replacements) : replacements(replacements) {} - - using ExprMutator::VisitExpr_; - Expr VisitExpr_(const GlobalVarNode* node) override { - auto gvar = ffi::GetRef(node); - return replacements.Get(gvar).value_or(gvar); - } -}; - -} // namespace - -TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable) - .set_dispatch([](const ffi::ObjectRef& func, - ffi::Map replacements) -> BaseFunc { - Mutator mutator(replacements); - auto new_func = Downcast(mutator(Downcast(func))); - - // If the function is externally exposed, and is being replaced - // by a GlobalVar with a new name, then the function's - // kGlobalSymbol must be updated to match. - if (auto opt = new_func->GetAttr(tvm::attr::kGlobalSymbol)) { - auto name = opt.value(); - for (const auto& [before, after] : replacements) { - if (before->name_hint == name) { - if (after->name_hint != name) { - new_func = WithAttr(new_func, tvm::attr::kGlobalSymbol, after->name_hint); - } - break; - } - } - } - - return new_func; - }); - -TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable) - .set_dispatch([](const ffi::ObjectRef& func, - ffi::Map) -> BaseFunc { - return Downcast(func); - }); - -} // namespace relax -} // namespace tvm diff --git a/src/tirx/transform/replace_global_vars.cc b/src/tirx/transform/replace_global_vars.cc deleted file mode 100644 index 289d219b6b1a..000000000000 --- a/src/tirx/transform/replace_global_vars.cc +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * - * \file src/tirx/transform/replace_global_vars.cc - * - * \brief GlobalVar replacement across IR types - */ - -#include -#include -#include - -namespace tvm { -namespace tirx { - -namespace { -using tvm::transform::GlobalVarReplacer; - -struct Mutator : StmtExprMutator { - ffi::Map replacements; - explicit Mutator(ffi::Map replacements) : replacements(replacements) {} - - PrimExpr VisitExpr_(const CallNode* node) override { - auto call = Downcast(StmtExprMutator::VisitExpr_(node)); - if (auto old_gvar = call->op.as()) { - if (auto new_gvar = replacements.Get(old_gvar.value())) { - call.CopyOnWrite()->op = new_gvar.value(); - } - } - return call; - } -}; - -} // namespace - -TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable) - .set_dispatch([](const ffi::ObjectRef& obj, - ffi::Map replacements) -> BaseFunc { - Mutator mutator(replacements); - auto func = Downcast(obj); - auto new_body = mutator(func->body); - - if (!new_body.same_as(func->body)) { - func.CopyOnWrite()->body = new_body; - } - - // If the function is externally exposed, and is being replaced - // by a GlobalVar with a new name, then the function's - // kGlobalSymbol must be updated to match. - if (auto opt = func->GetAttr(tvm::attr::kGlobalSymbol)) { - auto name = opt.value(); - for (const auto& [before, after] : replacements) { - if (before->name_hint == name) { - if (after->name_hint != name) { - func = WithAttr(func, tvm::attr::kGlobalSymbol, after->name_hint); - } - break; - } - } - } - - return func; - }); - -} // namespace tirx -} // namespace tvm diff --git a/tests/python/ir/test_transform_replace_global_var.py b/tests/python/ir/test_transform_replace_global_var.py deleted file mode 100644 index 70a693c06e3e..000000000000 --- a/tests/python/ir/test_transform_replace_global_var.py +++ /dev/null @@ -1,308 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import tvm.testing -from tvm.script import ir as I -from tvm.script import relax as R -from tvm.script import tirx as T - - -def _get_before_module(): - @I.ir_module - class Module: - @R.function - def relax_main(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): - R.func_attr({"relax.force_pure": True}) - - B = Module.relax_subroutine(A) - C = R.call_tir(Module.tir_main, B, out_sinfo=R.Tensor([16], "float32")) - - D = R.builtin.alloc_tensor(R.shape([16]), "float32", runtime_device_index=0) - Module.tir_main(C, D) - - return D - - @R.function(private=True) - def relax_subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): - B = R.add(A, R.prim_value(T.float32(1.0))) - return B - - @T.prim_func(s_tir=True) - def tir_main(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): - Module.tir_subroutine(A.data, B.data) - - @T.prim_func(private=True, s_tir=True) - def tir_subroutine(A_data: T.ptr("float32"), B_data: T.ptr("float32")): - A = T.decl_buffer(16, "float32", data=A_data) - B = T.decl_buffer(16, "float32", data=B_data) - for i in range(16): - B[i] = A[i] + 1.0 - - return Module - - -def test_no_op_if_no_replacements(): - """If no replacements are performed, the IRModule is unmodified""" - - before = _get_before_module() - expected = before - - after = before.replace_global_vars({}) - - tvm.ir.assert_structural_equal(expected, after) - assert before.same_as(after) - - -def test_replace_relax_main(): - """An externally-exposed Relax function may be replaced - - In this example, the "relax_main" function is renamed. This - requires changing both the GlobalVar used to refer to the - function, and the "global_symbol" attribute of the - externally-exposed function. - - """ - - before = _get_before_module() - after = before.replace_global_vars({"relax_main": "relax_main_with_new_name"}) - - @I.ir_module - class Expected: - @R.function - def relax_main_with_new_name(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): - R.func_attr({"relax.force_pure": True}) - - B = Expected.relax_subroutine(A) - C = R.call_tir(Expected.tir_main, B, out_sinfo=R.Tensor([16], "float32")) - - D = R.builtin.alloc_tensor(R.shape([16]), "float32", runtime_device_index=0) - Expected.tir_main(C, D) - - return D - - @R.function(private=True) - def relax_subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): - B = R.add(A, R.prim_value(T.float32(1.0))) - return B - - @T.prim_func(s_tir=True) - def tir_main(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): - Expected.tir_subroutine(A.data, B.data) - - @T.prim_func(private=True, s_tir=True) - def tir_subroutine(A_data: T.ptr("float32"), B_data: T.ptr("float32")): - A = T.decl_buffer(16, "float32", data=A_data) - B = T.decl_buffer(16, "float32", data=B_data) - for i in range(16): - B[i] = A[i] + 1.0 - - tvm.ir.assert_structural_equal(Expected, after) - - -def test_replace_relax_subroutine(): - """An internal Relax function may be replaced - - In this example, the "relax_subroutine" function is renamed. This - requires changing both the GlobalVar used to refer to the - function, and the GlobalVar used to call the subroutine within - "relax_main". The "global_symbol" attribute does not need to be - updated, because internal functions do not have this attribute. - - """ - - before = _get_before_module() - after = before.replace_global_vars({"relax_subroutine": "relax_subroutine_with_new_name"}) - - @I.ir_module - class Expected: - @R.function - def relax_main(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): - R.func_attr({"relax.force_pure": True}) - - B = Expected.relax_subroutine_with_new_name(A) - C = R.call_tir(Expected.tir_main, B, out_sinfo=R.Tensor([16], "float32")) - - D = R.builtin.alloc_tensor(R.shape([16]), "float32", runtime_device_index=0) - Expected.tir_main(C, D) - - return D - - @R.function(private=True) - def relax_subroutine_with_new_name( - A: R.Tensor([16], "float32"), - ) -> R.Tensor([16], "float32"): - B = R.add(A, R.prim_value(T.float32(1.0))) - return B - - @T.prim_func(s_tir=True) - def tir_main(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): - Expected.tir_subroutine(A.data, B.data) - - @T.prim_func(private=True, s_tir=True) - def tir_subroutine(A_data: T.ptr("float32"), B_data: T.ptr("float32")): - A = T.decl_buffer(16, "float32", data=A_data) - B = T.decl_buffer(16, "float32", data=B_data) - for i in range(16): - B[i] = A[i] + 1.0 - - tvm.ir.assert_structural_equal(Expected, after) - - -def test_replace_tir_main(): - """An externally-exposed TIR function may be replaced - - In this example, the "tir_main" function is renamed. This - requires changing both the GlobalVar used to refer to the - function, the "global_symbol" attribute of the externally-exposed - function. In addition, calls to the TIR function should be - updated to use the new GlobalVar. - - """ - - before = _get_before_module() - after = before.replace_global_vars({"tir_main": "tir_main_with_new_name"}) - - @I.ir_module - class Expected: - @R.function - def relax_main(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): - R.func_attr({"relax.force_pure": True}) - - B = Expected.relax_subroutine(A) - C = R.call_tir(Expected.tir_main_with_new_name, B, out_sinfo=R.Tensor([16], "float32")) - - D = R.builtin.alloc_tensor(R.shape([16]), "float32", runtime_device_index=0) - Expected.tir_main_with_new_name(C, D) - - return D - - @R.function(private=True) - def relax_subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): - B = R.add(A, R.prim_value(T.float32(1.0))) - return B - - @T.prim_func(s_tir=True) - def tir_main_with_new_name(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): - Expected.tir_subroutine(A.data, B.data) - - @T.prim_func(private=True, s_tir=True) - def tir_subroutine(A_data: T.ptr("float32"), B_data: T.ptr("float32")): - A = T.decl_buffer(16, "float32", data=A_data) - B = T.decl_buffer(16, "float32", data=B_data) - for i in range(16): - B[i] = A[i] + 1.0 - - tvm.ir.assert_structural_equal(Expected, after) - - -def test_replace_tir_subroutine(): - """An internally-exposed TIR function may be replaced - - In this example, the "tir_subroutine" function is renamed. This - requires changing both the GlobalVar used to refer to the - function, and the GlobalVar used to refer to it. Internal - functions do not have the "global_symbol" attribute, so it does - not need to be updated. - - """ - - before = _get_before_module() - after = before.replace_global_vars({"tir_subroutine": "tir_subroutine_with_new_name"}) - - @I.ir_module - class Expected: - @R.function - def relax_main(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): - R.func_attr({"relax.force_pure": True}) - - B = Expected.relax_subroutine(A) - C = R.call_tir(Expected.tir_main, B, out_sinfo=R.Tensor([16], "float32")) - - D = R.builtin.alloc_tensor(R.shape([16]), "float32", runtime_device_index=0) - Expected.tir_main(C, D) - - return D - - @R.function(private=True) - def relax_subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): - B = R.add(A, R.prim_value(T.float32(1.0))) - return B - - @T.prim_func(s_tir=True) - def tir_main(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): - Expected.tir_subroutine_with_new_name(A.data, B.data) - - @T.prim_func(private=True, s_tir=True) - def tir_subroutine_with_new_name(A_data: T.ptr("float32"), B_data: T.ptr("float32")): - A = T.decl_buffer(16, "float32", data=A_data) - B = T.decl_buffer(16, "float32", data=B_data) - for i in range(16): - B[i] = A[i] + 1.0 - - tvm.ir.assert_structural_equal(Expected, after) - - -def test_simultaneous_replacements(): - """Multiple replacements may be performed simultaneously""" - - before = _get_before_module() - after = before.replace_global_vars( - { - "relax_main": "relax_main_with_new_name", - "relax_subroutine": "relax_subroutine_with_new_name", - "tir_main": "tir_main_with_new_name", - "tir_subroutine": "tir_subroutine_with_new_name", - } - ) - - @I.ir_module - class Expected: - @R.function - def relax_main_with_new_name(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): - R.func_attr({"relax.force_pure": True}) - - B = Expected.relax_subroutine_with_new_name(A) - C = R.call_tir(Expected.tir_main_with_new_name, B, out_sinfo=R.Tensor([16], "float32")) - - D = R.builtin.alloc_tensor(R.shape([16]), "float32", runtime_device_index=0) - Expected.tir_main_with_new_name(C, D) - - return D - - @R.function(private=True) - def relax_subroutine_with_new_name( - A: R.Tensor([16], "float32"), - ) -> R.Tensor([16], "float32"): - B = R.add(A, R.prim_value(T.float32(1.0))) - return B - - @T.prim_func(s_tir=True) - def tir_main_with_new_name(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): - Expected.tir_subroutine_with_new_name(A.data, B.data) - - @T.prim_func(private=True, s_tir=True) - def tir_subroutine_with_new_name(A_data: T.ptr("float32"), B_data: T.ptr("float32")): - A = T.decl_buffer(16, "float32", data=A_data) - B = T.decl_buffer(16, "float32", data=B_data) - for i in range(16): - B[i] = A[i] + 1.0 - - tvm.ir.assert_structural_equal(Expected, after) - - -if __name__ == "__main__": - tvm.testing.main() From 11e5986983943ff85b5dbcc05f32603f5903d5c7 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 27 May 2026 17:21:56 +0000 Subject: [PATCH 2/2] [REFACTOR][IR] ReplaceGlobalVarsInModule: direct name-hint compare MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Since the loop top already computes new_gvar as replacements.Get(old_gvar).value_or(old_gvar), the kGlobalSymbol rename step can directly compare new_gvar->name_hint vs old_gvar->name_hint in O(1) instead of scanning replacements in O(N) on each iteration. Applied identically to the tirx PrimFunc branch and the relax Function branch — they were duplicated O(N) scans. --- src/relax/transform/attach_global_symbol.cc | 24 ++++++--------------- 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/src/relax/transform/attach_global_symbol.cc b/src/relax/transform/attach_global_symbol.cc index 7b3cd85aec0e..0e8cd722c12d 100644 --- a/src/relax/transform/attach_global_symbol.cc +++ b/src/relax/transform/attach_global_symbol.cc @@ -91,15 +91,9 @@ IRModule ReplaceGlobalVarsInModule(IRModule mod, ffi::Map func.CopyOnWrite()->body = new_body; } // Update kGlobalSymbol if the function is externally exposed and being renamed. - if (auto opt = func->GetAttr(tvm::attr::kGlobalSymbol)) { - auto name = opt.value(); - for (const auto& [before, after] : replacements) { - if (before->name_hint == name) { - if (after->name_hint != name) { - func = WithAttr(func, tvm::attr::kGlobalSymbol, after->name_hint); - } - break; - } + if (func->GetAttr(tvm::attr::kGlobalSymbol)) { + if (new_gvar->name_hint != old_gvar->name_hint) { + func = WithAttr(func, tvm::attr::kGlobalSymbol, new_gvar->name_hint); } } new_func = func; @@ -108,15 +102,9 @@ IRModule ReplaceGlobalVarsInModule(IRModule mod, ffi::Map auto new_relax_func = Downcast(mutator(Downcast(ffi::GetRef(relax_func_node)))); // Update kGlobalSymbol if the function is externally exposed and being renamed. - if (auto opt = new_relax_func->GetAttr(tvm::attr::kGlobalSymbol)) { - auto name = opt.value(); - for (const auto& [before, after] : replacements) { - if (before->name_hint == name) { - if (after->name_hint != name) { - new_relax_func = WithAttr(new_relax_func, tvm::attr::kGlobalSymbol, after->name_hint); - } - break; - } + if (new_relax_func->GetAttr(tvm::attr::kGlobalSymbol)) { + if (new_gvar->name_hint != old_gvar->name_hint) { + new_relax_func = WithAttr(new_relax_func, tvm::attr::kGlobalSymbol, new_gvar->name_hint); } } new_func = new_relax_func;