Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/tvm/contrib/hexagon/meta_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import tvm
from tvm.driver import build as tvm_build
from tvm.ir.module import IRModule
from tvm.ir.utils import derived_object
from tvm.runtime import Module, Tensor
from tvm.s_tir.meta_schedule.builder import LocalBuilder
from tvm.s_tir.meta_schedule.runner import (
Expand All @@ -36,7 +37,7 @@
default_alloc_argument,
default_run_evaluator,
)
from tvm.s_tir.meta_schedule.utils import cpu_count, derived_object
from tvm.s_tir.meta_schedule.utils import cpu_count
from tvm.s_tir.transform import RemoveWeightLayoutRewriteBlock
from tvm.support.popen_pool import PopenPoolExecutor
from tvm.target import Target
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/runtime/support.py → python/tvm/ir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Runtime support infra of TVM."""
"""Utilities shared across TVM IR packages."""

from typing import TypeVar

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/expr_functor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
import tvm_ffi

from tvm.ir import Op
from tvm.ir.utils import derived_object
from tvm.runtime import Object
from tvm.runtime.support import derived_object

from ..ir.module import IRModule
from . import _ffi_api
Expand Down
1 change: 0 additions & 1 deletion python/tvm/s_tir/meta_schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,4 @@
from .tir_integration import tune_tir
from .tune import tune_tasks
from .tune_context import TuneContext
from .utils import derived_object
from .post_optimization import post_opt
3 changes: 2 additions & 1 deletion python/tvm/s_tir/meta_schedule/builder/local_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@
from tvm_ffi import register_global_func

from tvm.ir import IRModule
from tvm.ir.utils import derived_object
from tvm.runtime import Module, Tensor, load_param_dict, save_param_dict
from tvm.support.popen_pool import MapResult, PopenPoolExecutor, StatusKind
from tvm.target import Target

from ..logging import get_logger
from ..utils import cpu_count, derived_object, get_global_func_with_default_on_worker
from ..utils import cpu_count, get_global_func_with_default_on_worker
from .builder import BuilderInput, BuilderResult, PyBuilder

logger = get_logger(__name__) # pylint: disable=invalid-name
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/s_tir/meta_schedule/cost_model/mlp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import torch # type: ignore

import tvm
from tvm.ir.utils import derived_object
from tvm.support.tar import tar, untar

from ....runtime import Tensor
Expand All @@ -43,7 +44,7 @@
from ..runner import RunnerResult
from ..search_strategy import MeasureCandidate
from ..tune_context import TuneContext
from ..utils import derived_object, shash2hex
from ..utils import shash2hex

logger = get_logger("mlp_model") # pylint: disable=invalid-name

Expand Down
3 changes: 2 additions & 1 deletion python/tvm/s_tir/meta_schedule/cost_model/random_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
Random cost model
"""

from tvm.ir.utils import derived_object

from ..cost_model import PyCostModel
from ..runner import RunnerResult
from ..search_strategy import MeasureCandidate
from ..tune_context import TuneContext
from ..utils import derived_object # type: ignore


@derived_object
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/s_tir/meta_schedule/cost_model/xgb_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import numpy as np # type: ignore

from tvm.ir.utils import derived_object
from tvm.support.tar import tar, untar

from ....runtime import Tensor
Expand All @@ -33,7 +34,7 @@
from ..logging import get_logger
from ..runner import RunnerResult
from ..search_strategy import MeasureCandidate
from ..utils import cpu_count, derived_object, shash2hex
from ..utils import cpu_count, shash2hex
from .metric import max_curve

if TYPE_CHECKING:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
import numpy as np # type: ignore

import tvm.runtime
from tvm.ir.utils import derived_object

from ..feature_extractor import PyFeatureExtractor
from ..search_strategy import MeasureCandidate
from ..tune_context import TuneContext
from ..utils import derived_object


@derived_object
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/s_tir/meta_schedule/runner/local_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@
from contextlib import contextmanager

import tvm
from tvm.ir.utils import derived_object
from tvm.support.popen_pool import PopenPoolExecutor

from ....runtime import Device, Module
from ..logging import get_logger
from ..profiler import Profiler
from ..utils import derived_object, get_global_func_with_default_on_worker
from ..utils import get_global_func_with_default_on_worker
from .config import EvaluatorConfig
from .runner import PyRunner, PyRunnerFuture, RunnerFuture, RunnerInput, RunnerResult
from .utils import (
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/s_tir/meta_schedule/runner/rpc_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@
from collections.abc import Callable
from contextlib import contextmanager

from tvm.ir.utils import derived_object
from tvm.rpc import RPCSession
from tvm.runtime import Device, Module
from tvm.support.popen_pool import PopenPoolExecutor

from ..logging import get_logger
from ..profiler import Profiler
from ..utils import (
derived_object,
get_global_func_on_rpc_session,
get_global_func_with_default_on_worker,
)
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/s_tir/meta_schedule/testing/dummy_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@

import random

from tvm.ir.utils import derived_object
from tvm.s_tir.schedule import Trace

from ..builder import BuilderInput, BuilderResult, PyBuilder
from ..mutator import PyMutator
from ..runner import PyRunner, PyRunnerFuture, RunnerFuture, RunnerInput, RunnerResult
from ..tune_context import TuneContext # pylint: disable=unused-import
from ..utils import derived_object


@derived_object
Expand Down
140 changes: 0 additions & 140 deletions python/tvm/s_tir/meta_schedule/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,146 +32,6 @@
from tvm.tirx import FloatImm, IntImm


def derived_object(cls: type) -> type:
"""A decorator to register derived subclasses for TVM objects.

Parameters
----------
cls : type
The derived class to be registered.

Returns
-------
cls : type
The decorated TVM object.

Example
-------
.. code-block:: python

@register_object("s_tir.meta_schedule.PyRunner")
class _PyRunner(meta_schedule.Runner):
def __init__(self, f_run: Callable = None):
self.__init_handle_by_constructor__(_ffi_api.RunnerPyRunner, f_run)

class PyRunner:
_tvm_metadata = {
"cls": _PyRunner,
"methods": ["run"]
}
def run(self, runner_inputs):
raise NotImplementedError

@derived_object
class LocalRunner(PyRunner):
def run(self, runner_inputs):
...
"""

import functools # pylint: disable=import-outside-toplevel
import weakref # pylint: disable=import-outside-toplevel

def _extract(inst: type, name: str):
"""Extract function from intrinsic class."""

def method(*args, **kwargs):
return getattr(inst, name)(*args, **kwargs)

for inherit_cls, base_cls in zip(cls.__mro__, cls.__mro__[1:]):
# extract functions that differ from the base class
if not hasattr(base_cls, name):
continue
if getattr(base_cls, name) is getattr(inherit_cls, name) and name != "__str__":
continue
return method

# for task scheduler return None means calling default function
# otherwise it will trigger a TVMError of method not implemented
# on the c++ side when you call the method, __str__ not required
return None

assert isinstance(cls.__base__, type)
if hasattr(cls, "_type") and cls._type == "TVMDerivedObject": # type: ignore
raise TypeError(
f"Inheritance from a decorated object `{cls.__name__}` is not allowed. "
f"Please inherit from `{cls.__name__}._cls`."
)
assert hasattr(cls, "_tvm_metadata"), (
"Please use the user-facing method overriding class, i.e., PyRunner."
)

base = cls.__base__
metadata = getattr(base, "_tvm_metadata")
fields = metadata.get("fields", [])
methods = metadata.get("methods", [])

base_cls = metadata["cls"]
slots = []
if getattr(base_cls, "__dictoffset__", 0) == 0:
slots.append("__dict__")
if getattr(base_cls, "__weakrefoffset__", 0) == 0:
slots.append("__weakref__")

class TVMDerivedObject(base_cls): # type: ignore
"""The derived object to avoid cyclic dependency."""

__slots__ = tuple(slots)

_cls = cls
_type = "TVMDerivedObject"

def __init__(self, *args, **kwargs):
"""Constructor."""
self._inst = cls(*args, **kwargs)

super().__init__(
# the constructor's parameters, builder, runner, etc.
*[getattr(self._inst, name) for name in fields],
# the function methods, init_with_tune_context, build, run, etc.
*[_extract(self._inst, name) for name in methods],
)

# for task scheduler hybrid funcs in c++ & python side
# using weakref to avoid cyclic dependency
self._inst._outer = weakref.ref(self)

def __getattr__(self, name):
import inspect # pylint: disable=import-outside-toplevel

try:
# fall back to instance attribute if there is not any
# return self._inst.__getattribute__(name)
result = self._inst.__getattribute__(name)
except AttributeError:
result = super().__getattr__(name)

if inspect.ismethod(result):

def method(*args, **kwargs):
return result(*args, **kwargs)

# set __own__ to aviod implicit deconstruction
setattr(method, "__own__", self)
return method

return result

def __setattr__(self, name, value):
if name not in ["_inst", "key", "handle"]:
self._inst.__setattr__(name, value)
else:
super().__setattr__(name, value)

functools.update_wrapper(TVMDerivedObject.__init__, cls.__init__) # type: ignore
TVMDerivedObject.__name__ = cls.__name__
TVMDerivedObject.__doc__ = cls.__doc__
TVMDerivedObject.__module__ = cls.__module__
for key, value in cls.__dict__.items():
if isinstance(value, classmethod | staticmethod):
setattr(TVMDerivedObject, key, value)
return TVMDerivedObject


@register_global_func("s_tir.meta_schedule.cpu_count")
def _cpu_count_impl(logical: bool = True) -> int:
"""Return the number of logical or physical CPUs in the system
Expand Down
35 changes: 3 additions & 32 deletions python/tvm/tirx/functor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import tvm_ffi

from tvm.ir import PrimExpr
from tvm.runtime.support import derived_object
from tvm.ir.utils import derived_object

from . import _ffi_api
Comment thread
tqchen marked this conversation as resolved.
from .expr import (
Expand Down Expand Up @@ -78,39 +78,10 @@
While,
)

# visitor and mutator are aliases for derived_object
visitor = derived_object
"""
A decorator to wrap user-customized PyStmtExprVisitor as TVM object _PyStmtExprVisitor.

Parameters
----------
visitor_cls : PyStmtExprVisitor
The user-customized PyStmtExprVisitor.

Returns
-------
cls : _PyStmtExprVisitor
The decorated TVM object _PyStmtExprVisitor(StmtExprVisitor on the C++ side).

Example
-------
.. code-block:: python

@tirx.functor.stmt_expr_visitor
class MyStmtExprVisitor(PyStmtExprVisitor):
# customize visit function
def visit_call_(self, op: Call) -> None:
# just for demo purposes
...
# myvisitor is now a special visitor that visit every Call with
# user-customized visit_call_
myvisitor = MyStmtExprVisitor()
# apply myvisitor to PrimExpr and Stmt
myvisitor.visit_expr(expr)
myvisitor.visit_stmt(stmt)
"""

mutator = derived_object

"""
A decorator to wrap user-customized PyStmtExprMutator as TVM object _PyStmtExprMutator.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@

import tvm
import tvm.testing
from tvm.ir.utils import derived_object
from tvm.s_tir.meta_schedule.cost_model import PyCostModel, RandomModel, XGBModel
from tvm.s_tir.meta_schedule.cost_model.xgb_model import PackSum, _get_custom_call_back
from tvm.s_tir.meta_schedule.feature_extractor import RandomFeatureExtractor
from tvm.s_tir.meta_schedule.runner import RunnerResult
from tvm.s_tir.meta_schedule.search_strategy import MeasureCandidate
from tvm.s_tir.meta_schedule.tune_context import TuneContext
from tvm.s_tir.meta_schedule.utils import derived_object
from tvm.s_tir.schedule.schedule import Schedule
from tvm.script import tirx as T

Expand Down
Loading
Loading