Skip to content

Commit

Permalink
[TVMScript] Better Type Annotation for TIR OP (#17107)
Browse files Browse the repository at this point in the history
Enable ParamType for TIR op, so that we can have better experience when
writing TVMScript in Python with tools.

However, ParamType is introduced in Python 3.10, so we only enable it
when Python version is 3.10 or above.
  • Loading branch information
Hzfengsy committed Jun 20, 2024
1 parent 269a4f7 commit 36b9535
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import functools
import inspect
from numbers import Integral
import sys
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

# isort: off
Expand Down Expand Up @@ -1764,14 +1765,31 @@ def f():
# pylint: disable=invalid-name


def _op_wrapper(func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
if "dtype" in kwargs:
kwargs.pop("dtype")
return func(*args, **kwargs)
if sys.version_info >= (3, 10):
from typing import ParamSpec, TypeVar # pylint: disable=import-error

return wrapped
T = TypeVar("T")
P = ParamSpec("P")

def _op_wrapper(func: Callable[P, T]) -> Callable[P, T]:
@functools.wraps(func)
def wrapped(*args, **kwargs) -> T:
if "dtype" in kwargs:
kwargs.pop("dtype")
return func(*args, **kwargs)

return wrapped

else:

def _op_wrapper(func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
if "dtype" in kwargs:
kwargs.pop("dtype")
return func(*args, **kwargs)

return wrapped


abs = _op_wrapper(_tir_op.abs) # pylint: disable=redefined-builtin
Expand Down

0 comments on commit 36b9535

Please sign in to comment.