Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/dev/relay_op_strategy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,15 @@ so. You can find more examples in ``vta/python/vta/top/op.py``.
def conv2d_strategy_mytarget(attrs, inputs, out_type, target):
...

Additionally, you can extend a native (C++) strategy.

.. code:: python

@tvm.target.extend_native_generic_func("add_strategy", "mytarget")
def add_strategy_mytarget(attrs, inputs, out_type, target):
# Adds a specialization of the addition strategy for 'mytarget'
...


Select Implementation from Op Strategy
--------------------------------------
Expand Down
7 changes: 6 additions & 1 deletion python/tvm/target/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,12 @@
from .target import cuda, rocm, mali, intel_graphics, arm_cpu, rasp, vta, bifrost, hexagon
from .tag import list_tags
from .generic_func import GenericFunc
from .generic_func import generic_func, get_native_generic_func, override_native_generic_func
from .generic_func import (
generic_func,
get_native_generic_func,
override_native_generic_func,
extend_native_generic_func,
)
from . import datatype
from . import codegen
from .intrin import register_intrin_rule
63 changes: 62 additions & 1 deletion python/tvm/target/generic_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,67 @@ def dispatch_func(func, *args, **kwargs):
return fdecorate


def extend_native_generic_func(func_name, key, override=True):
"""Extend a generic function defined in C++

Generic function allows registration of further functions
that can be dispatched on current target context.

Parameters
----------
func_name : string
The name of the generic func to be extended.

key : str or list of str
The key to be registered.

override : bool, optional
Whether to override existing registration.

Returns
-------
fregister : function
A decorator function for registering the decorated function as a specialization
for the provided generic function

Example
-------
.. code-block:: python

import tvm
# register a specialization of a native generic function "my_func"
@tvm.target.extend_native_generic_func("my_func", "cuda")
def my_func_cuda(a):
return a + 1
# retrieve generic func
my_func = tvm.target.get_native_generic_func("my_func")
# displays result of the native generic function (possibly registered in C++)
print(my_func(2))
# displays 3, because my_func_cuda is called
with tvm.target.cuda():
print(my_func(2))
"""
generic_func_node = get_native_generic_func(func_name)

def fregister(func):
"""Register function as a specialization of a native generic function

Parameters
----------
func : function
The function to be registered.

Returns
-------
func : function
The provided function.
"""
generic_func_node.register(func, key, override)
return func

return fregister


def generic_func(fdefault):
"""Wrap a target generic function.

Expand Down Expand Up @@ -266,7 +327,7 @@ def _do_reg(myf):
return _do_reg

def dispatch_func(func, *args, **kwargs):
"""The wrapped dispath function"""
"""The wrapped dispatch function"""
target = Target.current()
if target is None:
return func(*args, **kwargs)
Expand Down
38 changes: 37 additions & 1 deletion tests/python/unittest/test_target_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,19 @@
import json
import tvm
from tvm import te
from tvm.target import cuda, rocm, mali, intel_graphics, arm_cpu, vta, bifrost, hexagon
from tvm.target import (
cuda,
rocm,
mali,
intel_graphics,
arm_cpu,
vta,
bifrost,
hexagon,
get_native_generic_func,
override_native_generic_func,
extend_native_generic_func,
)


@tvm.target.generic_func
Expand Down Expand Up @@ -60,6 +72,29 @@ def test_target_dispatch():
assert tvm.target.Target.current() is None


@override_native_generic_func("tests.my_native_generic")
def my_native_generic(data):
# default generic function
return data + 1


@extend_native_generic_func("tests.my_native_generic", "cuda")
def my_native_generic_cuda(data):
return data + 2


def test_target_dispatch_native_generic_function():
func = get_native_generic_func("tests.my_native_generic")

with tvm.target.cuda():
assert func(1) == 3

with tvm.target.arm_cpu():
assert func(1) == 2

assert tvm.target.Target.current() is None


def test_target_string_parse():
target = tvm.target.Target("cuda -model=unknown -libs=cublas,cudnn")

Expand Down Expand Up @@ -133,6 +168,7 @@ def test_composite_target():

if __name__ == "__main__":
test_target_dispatch()
test_target_dispatch_native_generic_function()
test_target_string_parse()
test_target_create()
test_target_config()
Expand Down