From 02e4019c4211bf2a189a10498da6af304263c476 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 27 May 2026 22:47:53 +0000 Subject: [PATCH] [REFACTOR][PYTHON] Consolidate derived_object into tvm.ir.utils derived_object was duplicated byte-for-byte across python/tvm/runtime/support.py and python/tvm/s_tir/meta_schedule/utils.py. The function is not a runtime feature and is used outside meta_schedule (tvm.relax, tvm.tirx). Move the single canonical definition into python/tvm/ir/utils.py. tvm.ir loads before tvm.tirx and tvm.s_tir, so eager top-level imports work from every consumer without load-order shenanigans. Rewrite all 25 caller imports (runtime/support importers, meta_schedule production sites, tests). Delete python/tvm/runtime/support.py and remove the duplicate from python/tvm/s_tir/meta_schedule/utils.py. --- python/tvm/contrib/hexagon/meta_schedule.py | 3 +- .../tvm/{runtime/support.py => ir/utils.py} | 3 +- python/tvm/relax/expr_functor.py | 2 +- python/tvm/s_tir/meta_schedule/__init__.py | 1 - .../meta_schedule/builder/local_builder.py | 3 +- .../meta_schedule/cost_model/mlp_model.py | 3 +- .../meta_schedule/cost_model/random_model.py | 3 +- .../meta_schedule/cost_model/xgb_model.py | 3 +- .../random_feature_extractor.py | 2 +- .../meta_schedule/runner/local_runner.py | 3 +- .../s_tir/meta_schedule/runner/rpc_runner.py | 2 +- .../meta_schedule/testing/dummy_object.py | 2 +- python/tvm/s_tir/meta_schedule/utils.py | 140 ------------------ python/tvm/tirx/functor.py | 35 +---- .../test_meta_schedule_cost_model.py | 2 +- .../test_meta_schedule_database.py | 5 +- .../test_meta_schedule_feature_extractor.py | 2 +- .../test_meta_schedule_measure_callback.py | 15 +- .../test_meta_schedule_post_order_apply.py | 2 +- .../test_meta_schedule_runner.py | 2 +- .../test_meta_schedule_search_strategy.py | 2 +- .../test_meta_schedule_space_generator.py | 2 +- .../test_meta_schedule_task_scheduler.py | 7 +- .../test_meta_schedule_tune_tir.py | 3 +- 24 files changed, 43 insertions(+), 204 deletions(-) rename python/tvm/{runtime/support.py => ir/utils.py} (99%) diff --git a/python/tvm/contrib/hexagon/meta_schedule.py b/python/tvm/contrib/hexagon/meta_schedule.py index 5582f697464e..0084d1da7f56 100644 --- a/python/tvm/contrib/hexagon/meta_schedule.py +++ b/python/tvm/contrib/hexagon/meta_schedule.py @@ -23,6 +23,7 @@ import tvm from tvm.driver import build as tvm_build from tvm.ir.module import IRModule +from tvm.ir.utils import derived_object from tvm.runtime import Module, Tensor from tvm.s_tir.meta_schedule.builder import LocalBuilder from tvm.s_tir.meta_schedule.runner import ( @@ -36,7 +37,7 @@ default_alloc_argument, default_run_evaluator, ) -from tvm.s_tir.meta_schedule.utils import cpu_count, derived_object +from tvm.s_tir.meta_schedule.utils import cpu_count from tvm.s_tir.transform import RemoveWeightLayoutRewriteBlock from tvm.support.popen_pool import PopenPoolExecutor from tvm.target import Target diff --git a/python/tvm/runtime/support.py b/python/tvm/ir/utils.py similarity index 99% rename from python/tvm/runtime/support.py rename to python/tvm/ir/utils.py index d9762ef57116..a2068d47598f 100644 --- a/python/tvm/runtime/support.py +++ b/python/tvm/ir/utils.py @@ -14,8 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -"""Runtime support infra of TVM.""" +"""Utilities shared across TVM IR packages.""" from typing import TypeVar diff --git a/python/tvm/relax/expr_functor.py b/python/tvm/relax/expr_functor.py index 5ac77da3c04d..c9ea88d11100 100644 --- a/python/tvm/relax/expr_functor.py +++ b/python/tvm/relax/expr_functor.py @@ -23,8 +23,8 @@ import tvm_ffi from tvm.ir import Op +from tvm.ir.utils import derived_object from tvm.runtime import Object -from tvm.runtime.support import derived_object from ..ir.module import IRModule from . import _ffi_api diff --git a/python/tvm/s_tir/meta_schedule/__init__.py b/python/tvm/s_tir/meta_schedule/__init__.py index f3601f6e6df2..3fbdd37859d2 100644 --- a/python/tvm/s_tir/meta_schedule/__init__.py +++ b/python/tvm/s_tir/meta_schedule/__init__.py @@ -53,5 +53,4 @@ from .tir_integration import tune_tir from .tune import tune_tasks from .tune_context import TuneContext -from .utils import derived_object from .post_optimization import post_opt diff --git a/python/tvm/s_tir/meta_schedule/builder/local_builder.py b/python/tvm/s_tir/meta_schedule/builder/local_builder.py index 2a88c0167be3..aa563294e210 100644 --- a/python/tvm/s_tir/meta_schedule/builder/local_builder.py +++ b/python/tvm/s_tir/meta_schedule/builder/local_builder.py @@ -25,12 +25,13 @@ from tvm_ffi import register_global_func from tvm.ir import IRModule +from tvm.ir.utils import derived_object from tvm.runtime import Module, Tensor, load_param_dict, save_param_dict from tvm.support.popen_pool import MapResult, PopenPoolExecutor, StatusKind from tvm.target import Target from ..logging import get_logger -from ..utils import cpu_count, derived_object, get_global_func_with_default_on_worker +from ..utils import cpu_count, get_global_func_with_default_on_worker from .builder import BuilderInput, BuilderResult, PyBuilder logger = get_logger(__name__) # pylint: disable=invalid-name diff --git a/python/tvm/s_tir/meta_schedule/cost_model/mlp_model.py b/python/tvm/s_tir/meta_schedule/cost_model/mlp_model.py index 162110371ffb..a9bb7c784d32 100644 --- a/python/tvm/s_tir/meta_schedule/cost_model/mlp_model.py +++ b/python/tvm/s_tir/meta_schedule/cost_model/mlp_model.py @@ -32,6 +32,7 @@ import torch # type: ignore import tvm +from tvm.ir.utils import derived_object from tvm.support.tar import tar, untar from ....runtime import Tensor @@ -43,7 +44,7 @@ from ..runner import RunnerResult from ..search_strategy import MeasureCandidate from ..tune_context import TuneContext -from ..utils import derived_object, shash2hex +from ..utils import shash2hex logger = get_logger("mlp_model") # pylint: disable=invalid-name diff --git a/python/tvm/s_tir/meta_schedule/cost_model/random_model.py b/python/tvm/s_tir/meta_schedule/cost_model/random_model.py index 292fd4a96417..86df91d58dbf 100644 --- a/python/tvm/s_tir/meta_schedule/cost_model/random_model.py +++ b/python/tvm/s_tir/meta_schedule/cost_model/random_model.py @@ -18,11 +18,12 @@ Random cost model """ +from tvm.ir.utils import derived_object + from ..cost_model import PyCostModel from ..runner import RunnerResult from ..search_strategy import MeasureCandidate from ..tune_context import TuneContext -from ..utils import derived_object # type: ignore @derived_object diff --git a/python/tvm/s_tir/meta_schedule/cost_model/xgb_model.py b/python/tvm/s_tir/meta_schedule/cost_model/xgb_model.py index 3bc0b4d769bb..8d6aa49b10e7 100644 --- a/python/tvm/s_tir/meta_schedule/cost_model/xgb_model.py +++ b/python/tvm/s_tir/meta_schedule/cost_model/xgb_model.py @@ -25,6 +25,7 @@ import numpy as np # type: ignore +from tvm.ir.utils import derived_object from tvm.support.tar import tar, untar from ....runtime import Tensor @@ -33,7 +34,7 @@ from ..logging import get_logger from ..runner import RunnerResult from ..search_strategy import MeasureCandidate -from ..utils import cpu_count, derived_object, shash2hex +from ..utils import cpu_count, shash2hex from .metric import max_curve if TYPE_CHECKING: diff --git a/python/tvm/s_tir/meta_schedule/feature_extractor/random_feature_extractor.py b/python/tvm/s_tir/meta_schedule/feature_extractor/random_feature_extractor.py index 8cf7e2f2bfc3..fc42b36604a5 100644 --- a/python/tvm/s_tir/meta_schedule/feature_extractor/random_feature_extractor.py +++ b/python/tvm/s_tir/meta_schedule/feature_extractor/random_feature_extractor.py @@ -19,11 +19,11 @@ import numpy as np # type: ignore import tvm.runtime +from tvm.ir.utils import derived_object from ..feature_extractor import PyFeatureExtractor from ..search_strategy import MeasureCandidate from ..tune_context import TuneContext -from ..utils import derived_object @derived_object diff --git a/python/tvm/s_tir/meta_schedule/runner/local_runner.py b/python/tvm/s_tir/meta_schedule/runner/local_runner.py index c55925fd0bef..b56cd6613cf9 100644 --- a/python/tvm/s_tir/meta_schedule/runner/local_runner.py +++ b/python/tvm/s_tir/meta_schedule/runner/local_runner.py @@ -22,12 +22,13 @@ from contextlib import contextmanager import tvm +from tvm.ir.utils import derived_object from tvm.support.popen_pool import PopenPoolExecutor from ....runtime import Device, Module from ..logging import get_logger from ..profiler import Profiler -from ..utils import derived_object, get_global_func_with_default_on_worker +from ..utils import get_global_func_with_default_on_worker from .config import EvaluatorConfig from .runner import PyRunner, PyRunnerFuture, RunnerFuture, RunnerInput, RunnerResult from .utils import ( diff --git a/python/tvm/s_tir/meta_schedule/runner/rpc_runner.py b/python/tvm/s_tir/meta_schedule/runner/rpc_runner.py index 435cfd8b4d3b..27ab71e66917 100644 --- a/python/tvm/s_tir/meta_schedule/runner/rpc_runner.py +++ b/python/tvm/s_tir/meta_schedule/runner/rpc_runner.py @@ -21,6 +21,7 @@ from collections.abc import Callable from contextlib import contextmanager +from tvm.ir.utils import derived_object from tvm.rpc import RPCSession from tvm.runtime import Device, Module from tvm.support.popen_pool import PopenPoolExecutor @@ -28,7 +29,6 @@ from ..logging import get_logger from ..profiler import Profiler from ..utils import ( - derived_object, get_global_func_on_rpc_session, get_global_func_with_default_on_worker, ) diff --git a/python/tvm/s_tir/meta_schedule/testing/dummy_object.py b/python/tvm/s_tir/meta_schedule/testing/dummy_object.py index 007de8a9de0a..d3e0d55a936e 100644 --- a/python/tvm/s_tir/meta_schedule/testing/dummy_object.py +++ b/python/tvm/s_tir/meta_schedule/testing/dummy_object.py @@ -18,13 +18,13 @@ import random +from tvm.ir.utils import derived_object from tvm.s_tir.schedule import Trace from ..builder import BuilderInput, BuilderResult, PyBuilder from ..mutator import PyMutator from ..runner import PyRunner, PyRunnerFuture, RunnerFuture, RunnerInput, RunnerResult from ..tune_context import TuneContext # pylint: disable=unused-import -from ..utils import derived_object @derived_object diff --git a/python/tvm/s_tir/meta_schedule/utils.py b/python/tvm/s_tir/meta_schedule/utils.py index 42f52a6c1e4b..775054c4ce07 100644 --- a/python/tvm/s_tir/meta_schedule/utils.py +++ b/python/tvm/s_tir/meta_schedule/utils.py @@ -32,146 +32,6 @@ from tvm.tirx import FloatImm, IntImm -def derived_object(cls: type) -> type: - """A decorator to register derived subclasses for TVM objects. - - Parameters - ---------- - cls : type - The derived class to be registered. - - Returns - ------- - cls : type - The decorated TVM object. - - Example - ------- - .. code-block:: python - - @register_object("s_tir.meta_schedule.PyRunner") - class _PyRunner(meta_schedule.Runner): - def __init__(self, f_run: Callable = None): - self.__init_handle_by_constructor__(_ffi_api.RunnerPyRunner, f_run) - - class PyRunner: - _tvm_metadata = { - "cls": _PyRunner, - "methods": ["run"] - } - def run(self, runner_inputs): - raise NotImplementedError - - @derived_object - class LocalRunner(PyRunner): - def run(self, runner_inputs): - ... - """ - - import functools # pylint: disable=import-outside-toplevel - import weakref # pylint: disable=import-outside-toplevel - - def _extract(inst: type, name: str): - """Extract function from intrinsic class.""" - - def method(*args, **kwargs): - return getattr(inst, name)(*args, **kwargs) - - for inherit_cls, base_cls in zip(cls.__mro__, cls.__mro__[1:]): - # extract functions that differ from the base class - if not hasattr(base_cls, name): - continue - if getattr(base_cls, name) is getattr(inherit_cls, name) and name != "__str__": - continue - return method - - # for task scheduler return None means calling default function - # otherwise it will trigger a TVMError of method not implemented - # on the c++ side when you call the method, __str__ not required - return None - - assert isinstance(cls.__base__, type) - if hasattr(cls, "_type") and cls._type == "TVMDerivedObject": # type: ignore - raise TypeError( - f"Inheritance from a decorated object `{cls.__name__}` is not allowed. " - f"Please inherit from `{cls.__name__}._cls`." - ) - assert hasattr(cls, "_tvm_metadata"), ( - "Please use the user-facing method overriding class, i.e., PyRunner." - ) - - base = cls.__base__ - metadata = getattr(base, "_tvm_metadata") - fields = metadata.get("fields", []) - methods = metadata.get("methods", []) - - base_cls = metadata["cls"] - slots = [] - if getattr(base_cls, "__dictoffset__", 0) == 0: - slots.append("__dict__") - if getattr(base_cls, "__weakrefoffset__", 0) == 0: - slots.append("__weakref__") - - class TVMDerivedObject(base_cls): # type: ignore - """The derived object to avoid cyclic dependency.""" - - __slots__ = tuple(slots) - - _cls = cls - _type = "TVMDerivedObject" - - def __init__(self, *args, **kwargs): - """Constructor.""" - self._inst = cls(*args, **kwargs) - - super().__init__( - # the constructor's parameters, builder, runner, etc. - *[getattr(self._inst, name) for name in fields], - # the function methods, init_with_tune_context, build, run, etc. - *[_extract(self._inst, name) for name in methods], - ) - - # for task scheduler hybrid funcs in c++ & python side - # using weakref to avoid cyclic dependency - self._inst._outer = weakref.ref(self) - - def __getattr__(self, name): - import inspect # pylint: disable=import-outside-toplevel - - try: - # fall back to instance attribute if there is not any - # return self._inst.__getattribute__(name) - result = self._inst.__getattribute__(name) - except AttributeError: - result = super().__getattr__(name) - - if inspect.ismethod(result): - - def method(*args, **kwargs): - return result(*args, **kwargs) - - # set __own__ to aviod implicit deconstruction - setattr(method, "__own__", self) - return method - - return result - - def __setattr__(self, name, value): - if name not in ["_inst", "key", "handle"]: - self._inst.__setattr__(name, value) - else: - super().__setattr__(name, value) - - functools.update_wrapper(TVMDerivedObject.__init__, cls.__init__) # type: ignore - TVMDerivedObject.__name__ = cls.__name__ - TVMDerivedObject.__doc__ = cls.__doc__ - TVMDerivedObject.__module__ = cls.__module__ - for key, value in cls.__dict__.items(): - if isinstance(value, classmethod | staticmethod): - setattr(TVMDerivedObject, key, value) - return TVMDerivedObject - - @register_global_func("s_tir.meta_schedule.cpu_count") def _cpu_count_impl(logical: bool = True) -> int: """Return the number of logical or physical CPUs in the system diff --git a/python/tvm/tirx/functor.py b/python/tvm/tirx/functor.py index 4619c0b51fbb..ab2af06a7912 100644 --- a/python/tvm/tirx/functor.py +++ b/python/tvm/tirx/functor.py @@ -23,7 +23,7 @@ import tvm_ffi from tvm.ir import PrimExpr -from tvm.runtime.support import derived_object +from tvm.ir.utils import derived_object from . import _ffi_api from .expr import ( @@ -78,39 +78,10 @@ While, ) +# visitor and mutator are aliases for derived_object visitor = derived_object -""" -A decorator to wrap user-customized PyStmtExprVisitor as TVM object _PyStmtExprVisitor. - -Parameters ----------- -visitor_cls : PyStmtExprVisitor - The user-customized PyStmtExprVisitor. - -Returns -------- -cls : _PyStmtExprVisitor - The decorated TVM object _PyStmtExprVisitor(StmtExprVisitor on the C++ side). - -Example -------- -.. code-block:: python - - @tirx.functor.stmt_expr_visitor - class MyStmtExprVisitor(PyStmtExprVisitor): - # customize visit function - def visit_call_(self, op: Call) -> None: - # just for demo purposes - ... - # myvisitor is now a special visitor that visit every Call with - # user-customized visit_call_ - myvisitor = MyStmtExprVisitor() - # apply myvisitor to PrimExpr and Stmt - myvisitor.visit_expr(expr) - myvisitor.visit_stmt(stmt) -""" - mutator = derived_object + """ A decorator to wrap user-customized PyStmtExprMutator as TVM object _PyStmtExprMutator. diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_cost_model.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_cost_model.py index 4ba49ebe2402..b2385597ab92 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_cost_model.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_cost_model.py @@ -27,13 +27,13 @@ import tvm import tvm.testing +from tvm.ir.utils import derived_object from tvm.s_tir.meta_schedule.cost_model import PyCostModel, RandomModel, XGBModel from tvm.s_tir.meta_schedule.cost_model.xgb_model import PackSum, _get_custom_call_back from tvm.s_tir.meta_schedule.feature_extractor import RandomFeatureExtractor from tvm.s_tir.meta_schedule.runner import RunnerResult from tvm.s_tir.meta_schedule.search_strategy import MeasureCandidate from tvm.s_tir.meta_schedule.tune_context import TuneContext -from tvm.s_tir.meta_schedule.utils import derived_object from tvm.s_tir.schedule.schedule import Schedule from tvm.script import tirx as T diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_database.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_database.py index 9314dedf578d..ffe4945f6883 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_database.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_database.py @@ -29,6 +29,7 @@ import tvm.testing from tvm import tirx from tvm.ir.module import IRModule +from tvm.ir.utils import derived_object from tvm.s_tir import Schedule from tvm.s_tir import meta_schedule as ms from tvm.s_tir.meta_schedule.database import TuningRecord, Workload @@ -113,7 +114,7 @@ def _equal_record(a: ms.database.TuningRecord, b: ms.database.TuningRecord): assert str(arg0.as_json()) == str(arg1.as_json()) -@ms.utils.derived_object +@derived_object class PyMemoryDatabaseDefault(ms.database.PyDatabase): def __init__(self): super().__init__() @@ -156,7 +157,7 @@ def __len__(self) -> int: return len(self.tuning_records_) -@ms.utils.derived_object +@derived_object class PyMemoryDatabaseOverride(ms.database.PyDatabase): def __init__(self): super().__init__() diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_feature_extractor.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_feature_extractor.py index 1d336d9b5aa0..91723c539c50 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_feature_extractor.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_feature_extractor.py @@ -20,10 +20,10 @@ import numpy as np import tvm.runtime +from tvm.ir.utils import derived_object from tvm.s_tir.meta_schedule import TuneContext from tvm.s_tir.meta_schedule.feature_extractor import PyFeatureExtractor from tvm.s_tir.meta_schedule.search_strategy import MeasureCandidate -from tvm.s_tir.meta_schedule.utils import derived_object def test_meta_schedule_feature_extractor(): diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_measure_callback.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_measure_callback.py index 2d6182920309..b9f2bcab7a6e 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_measure_callback.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_measure_callback.py @@ -21,6 +21,7 @@ import pytest import tvm +from tvm.ir.utils import derived_object from tvm.s_tir import meta_schedule as ms from tvm.s_tir.schedule import Schedule from tvm.script import tirx as T @@ -48,7 +49,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: def test_meta_schedule_measure_callback(): - @ms.derived_object + @derived_object class FancyMeasureCallback(ms.measure_callback.PyMeasureCallback): def apply( self, @@ -82,7 +83,7 @@ def apply( def test_meta_schedule_measure_callback_fail(): - @ms.derived_object + @derived_object class FailingMeasureCallback(ms.measure_callback.PyMeasureCallback): def apply( self, @@ -106,7 +107,7 @@ def apply( def test_meta_schedule_measure_callback_as_string(): - @ms.derived_object + @derived_object class NotSoFancyMeasureCallback(ms.measure_callback.PyMeasureCallback): def apply( self, @@ -125,7 +126,7 @@ def apply( @pytest.mark.skip("Tuning test - launches runner") def test_meta_schedule_measure_callback_update_cost_model_with_zero(): - @ms.derived_object + @derived_object class AllZeroRunnerFuture(ms.runner.PyRunnerFuture): def done(self) -> bool: return True @@ -133,7 +134,7 @@ def done(self) -> bool: def result(self) -> ms.runner.RunnerResult: return ms.runner.RunnerResult([0.0, 0.0], None) - @ms.derived_object + @derived_object class AllZeroRunner(ms.runner.PyRunner): def run(self, runner_inputs: list[ms.runner.RunnerInput]) -> list[ms.runner.RunnerResult]: return [AllZeroRunnerFuture() for _ in runner_inputs] @@ -151,7 +152,7 @@ def run(self, runner_inputs: list[ms.runner.RunnerInput]) -> list[ms.runner.Runn @pytest.mark.skip("Tuning test - launches runner") def test_meta_schedule_measure_callback_update_cost_model_with_runtime_error(): - @ms.derived_object + @derived_object class EmptyRunnerFuture(ms.runner.PyRunnerFuture): def done(self) -> bool: return True @@ -159,7 +160,7 @@ def done(self) -> bool: def result(self) -> ms.runner.RunnerResult: return ms.runner.RunnerResult(None, "error") - @ms.derived_object + @derived_object class EmptyRunner(ms.runner.PyRunner): def run(self, runner_inputs: list[ms.runner.RunnerInput]) -> list[ms.runner.RunnerResult]: return [EmptyRunnerFuture() for _ in runner_inputs] diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_post_order_apply.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_post_order_apply.py index ee9b74d92d6c..46d71ca6e745 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_post_order_apply.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_post_order_apply.py @@ -28,10 +28,10 @@ from tvm import te from tvm.error import TVMError from tvm.ir.module import IRModule +from tvm.ir.utils import derived_object from tvm.s_tir.meta_schedule import TuneContext from tvm.s_tir.meta_schedule.schedule_rule import PyScheduleRule from tvm.s_tir.meta_schedule.space_generator import PostOrderApply -from tvm.s_tir.meta_schedule.utils import derived_object from tvm.s_tir.schedule import SBlockRV, Schedule from tvm.script import tirx as T from tvm.target import Target diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_runner.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_runner.py index 9c267a69c6e4..b23c603a4b39 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_runner.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_runner.py @@ -28,6 +28,7 @@ import tvm import tvm.testing +from tvm.ir.utils import derived_object from tvm.rpc import RPCSession from tvm.runtime import Device, Module from tvm.s_tir.meta_schedule.arg_info import TensorInfo @@ -53,7 +54,6 @@ ) from tvm.s_tir.meta_schedule.testing.local_rpc import LocalRPC from tvm.s_tir.meta_schedule.utils import ( - derived_object, get_global_func_with_default_on_worker, ) from tvm.script import tirx as T diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py index 0f393e23abd4..002741c6bf9e 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py @@ -23,9 +23,9 @@ import tvm import tvm.testing +from tvm.ir.utils import derived_object from tvm.s_tir import meta_schedule as ms from tvm.s_tir.meta_schedule.testing.dummy_object import DummyMutator -from tvm.s_tir.meta_schedule.utils import derived_object from tvm.s_tir.schedule import Schedule, Trace from tvm.script import tirx as T diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_generator.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_generator.py index a783cf587214..5515d66f9a02 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_generator.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_generator.py @@ -24,13 +24,13 @@ import tvm import tvm.testing from tvm.base import TVMError +from tvm.ir.utils import derived_object from tvm.s_tir.meta_schedule.space_generator import ( PySpaceGenerator, ScheduleFn, SpaceGeneratorUnion, ) from tvm.s_tir.meta_schedule.tune_context import TuneContext -from tvm.s_tir.meta_schedule.utils import derived_object from tvm.s_tir.schedule import Schedule from tvm.script import tirx as T diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_task_scheduler.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_task_scheduler.py index 1ffedc30cae9..61f5583c2a83 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_task_scheduler.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_task_scheduler.py @@ -24,6 +24,7 @@ import tvm import tvm.testing +from tvm.ir.utils import derived_object from tvm.s_tir import Schedule from tvm.s_tir import meta_schedule as ms from tvm.s_tir.meta_schedule.testing.dummy_object import DummyBuilder, DummyRunner @@ -119,7 +120,7 @@ def _schedule_batch_matmul(sch: Schedule): sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3, t_0, t_1) -@ms.derived_object +@derived_object class MyTaskScheduler(ms.task_scheduler.PyTaskScheduler): done: set = set() @@ -233,7 +234,7 @@ def test_meta_schedule_task_scheduler_multiple(): def test_meta_schedule_task_scheduler_NIE(): # pylint: disable=invalid-name - @ms.derived_object + @derived_object class NIETaskScheduler(ms.task_scheduler.PyTaskScheduler): pass @@ -360,7 +361,7 @@ def test_meta_schedule_task_scheduler_gradient_based_with_null_search_strategy() the scheduler should continue working as normal for other tasks """ - @ms.derived_object + @derived_object class NullSearchStrategy(ms.search_strategy.PySearchStrategy): def __init__(self, rounds_with_empty_candidates): self.rounds_with_empty_candidates = rounds_with_empty_candidates diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_tune_tir.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_tune_tir.py index 97f803fc4848..8430072223bc 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_tune_tir.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_tune_tir.py @@ -23,6 +23,7 @@ import tvm import tvm.testing +from tvm.ir.utils import derived_object from tvm.s_tir import meta_schedule as ms from tvm.s_tir.meta_schedule.testing.custom_builder_runner import run_module_via_rpc from tvm.s_tir.meta_schedule.testing.local_rpc import LocalRPC @@ -147,7 +148,7 @@ def f_timer(rt_mod, dev, input_data): @pytest.mark.skip("Integration test") def test_tune_block_cpu(): - @ms.derived_object + @derived_object class RemoveBlock(ms.schedule_rule.PyScheduleRule): def _initialize_with_tune_context(self, context: ms.TuneContext) -> None: pass