diff --git a/docs/dev/relay_op_strategy.rst b/docs/dev/relay_op_strategy.rst
index c40251d22433..e4fcc1a7ee07 100644
--- a/docs/dev/relay_op_strategy.rst
+++ b/docs/dev/relay_op_strategy.rst
@@ -253,6 +253,15 @@ so. You can find more examples in ``vta/python/vta/top/op.py``.
     def conv2d_strategy_mytarget(attrs, inputs, out_type, target):
         ...
 
+Additionally, you can extend a native (C++) strategy.
+
+.. code:: python
+
+    @tvm.target.extend_native_generic_func("add_strategy", "mytarget")
+    def add_strategy_mytarget(attrs, inputs, out_type, target):
+        # Adds a specialization of the addition strategy for 'mytarget'
+        ...
+
 Select Implementation from Op Strategy
 --------------------------------------
 
diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py
index a9a12bb158c5..e4eadbf0f9fd 100644
--- a/python/tvm/target/__init__.py
+++ b/python/tvm/target/__init__.py
@@ -58,7 +58,12 @@
 from .target import cuda, rocm, mali, intel_graphics, arm_cpu, rasp, vta, bifrost, hexagon
 from .tag import list_tags
 from .generic_func import GenericFunc
-from .generic_func import generic_func, get_native_generic_func, override_native_generic_func
+from .generic_func import (
+    generic_func,
+    get_native_generic_func,
+    override_native_generic_func,
+    extend_native_generic_func,
+)
 from . import datatype
 from . import codegen
 from .intrin import register_intrin_rule
diff --git a/python/tvm/target/generic_func.py b/python/tvm/target/generic_func.py
index 932eaa47f112..e545e9aad60f 100644
--- a/python/tvm/target/generic_func.py
+++ b/python/tvm/target/generic_func.py
@@ -195,6 +195,67 @@ def dispatch_func(func, *args, **kwargs):
     return fdecorate
 
 
+def extend_native_generic_func(func_name, key, override=True):
+    """Extend a generic function defined in C++
+
+    Generic function allows registration of further functions
+    that can be dispatched on current target context.
+
+    Parameters
+    ----------
+    func_name : string
+        The name of the generic func to be extended.
+
+    key : str or list of str
+        The key to be registered.
+
+    override : bool, optional
+        Whether to override existing registration.
+
+    Returns
+    -------
+    fregister : function
+        A decorator function for registering the decorated function as a specialization
+        for the provided generic function
+
+    Example
+    -------
+    .. code-block:: python
+
+      import tvm
+      # register a specialization of a native generic function "my_func"
+      @tvm.target.extend_native_generic_func("my_func", "cuda")
+      def my_func_cuda(a):
+          return a + 1
+      # retrieve generic func
+      my_func = tvm.target.get_native_generic_func("my_func")
+      # displays result of the native generic function (possibly registered in C++)
+      print(my_func(2))
+      # displays 3, because my_func_cuda is called
+      with tvm.target.cuda():
+          print(my_func(2))
+    """
+    generic_func_node = get_native_generic_func(func_name)
+
+    def fregister(func):
+        """Register function as a specialization of a native generic function
+
+        Parameters
+        ----------
+        func : function
+            The function to be registered.
+
+        Returns
+        -------
+        func : function
+            The provided function.
+        """
+        generic_func_node.register(func, key, override)
+        return func
+
+    return fregister
+
+
 def generic_func(fdefault):
     """Wrap a target generic function.
 
@@ -266,7 +327,7 @@ def _do_reg(myf):
         return _do_reg
 
     def dispatch_func(func, *args, **kwargs):
-        """The wrapped dispath function"""
+        """The wrapped dispatch function"""
         target = Target.current()
         if target is None:
             return func(*args, **kwargs)
diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py
index 643043f13663..e6512357680b 100644
--- a/tests/python/unittest/test_target_target.py
+++ b/tests/python/unittest/test_target_target.py
@@ -17,7 +17,19 @@
 import json
 import tvm
 from tvm import te
-from tvm.target import cuda, rocm, mali, intel_graphics, arm_cpu, vta, bifrost, hexagon
+from tvm.target import (
+    cuda,
+    rocm,
+    mali,
+    intel_graphics,
+    arm_cpu,
+    vta,
+    bifrost,
+    hexagon,
+    get_native_generic_func,
+    override_native_generic_func,
+    extend_native_generic_func,
+)
 
 
 @tvm.target.generic_func
@@ -60,6 +72,29 @@ def test_target_dispatch():
     assert tvm.target.Target.current() is None
 
 
+@override_native_generic_func("tests.my_native_generic")
+def my_native_generic(data):
+    # default generic function
+    return data + 1
+
+
+@extend_native_generic_func("tests.my_native_generic", "cuda")
+def my_native_generic_cuda(data):
+    return data + 2
+
+
+def test_target_dispatch_native_generic_function():
+    func = get_native_generic_func("tests.my_native_generic")
+
+    with tvm.target.cuda():
+        assert func(1) == 3
+
+    with tvm.target.arm_cpu():
+        assert func(1) == 2
+
+    assert tvm.target.Target.current() is None
+
+
 def test_target_string_parse():
     target = tvm.target.Target("cuda -model=unknown -libs=cublas,cudnn")
 
@@ -133,6 +168,7 @@ def test_composite_target():
 
 if __name__ == "__main__":
     test_target_dispatch()
+    test_target_dispatch_native_generic_function()
     test_target_string_parse()
     test_target_create()
     test_target_config()