Skip to content

Commit

Permalink
[RELAY] Enable registering op with python (#8002)
Browse files Browse the repository at this point in the history
Add a new API register_op

Note: Implementing a op by pure python is still limited:
  1. Custom type relation (add_type_rel()) is still not
     available in python.

  2. Setting number inputs (set_num_inputs()) needs
     plevel > 128 in python.
     (see tests/python/relay/test_ir_op.py)
  • Loading branch information
zackcquic committed May 7, 2021
1 parent 8d9a1df commit 254563a
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 3 deletions.
2 changes: 1 addition & 1 deletion python/tvm/ir/__init__.py
Expand Up @@ -23,7 +23,7 @@
from .tensor_type import TensorType
from .type_relation import TypeCall, TypeRelation
from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range
from .op import Op, register_op_attr, register_intrin_lowering
from .op import Op, register_op, register_op_attr, register_intrin_lowering
from .function import CallingConv, BaseFunc
from .adt import Constructor, TypeData
from .module import IRModule
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/ir/op.py
Expand Up @@ -86,6 +86,18 @@ def reset_attr(self, attr_name):
_ffi_api.OpResetAttr(self, attr_name)


def register_op(op_name):
"""Register an operator by name
Parameters
----------
op_name : str
The name of new operator
"""

_ffi_api.RegisterOp(op_name)


def register_op_attr(op_name, attr_key, value=None, level=10):
"""Register an operator property of an operator by name.
Expand Down
6 changes: 6 additions & 0 deletions src/ir/op.cc
Expand Up @@ -102,6 +102,12 @@ TVM_REGISTER_GLOBAL("ir.OpResetAttr").set_body_typed([](Op op, String attr_name)
reg.reset_attr(attr_name);
});

TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name) {
const OpRegEntry* reg = OpRegistry::Global()->Get(op_name);
ICHECK(reg == nullptr) << "AttributeError: Operator " << op_name << " is registered before";
OpRegistry::Global()->RegisterOrGet(op_name).set_name();
});

TVM_REGISTER_GLOBAL("ir.RegisterOpAttr")
.set_body_typed([](String op_name, String attr_key, runtime::TVMArgValue value, int plevel) {
auto& reg = OpRegistry::Global()->RegisterOrGet(op_name).set_name();
Expand Down
16 changes: 14 additions & 2 deletions tests/python/relay/test_ir_op.py
Expand Up @@ -32,7 +32,7 @@ def test(x):


def test_op_reset_attr():
""" Tests reset_attr functionality. """
"""Tests reset_attr functionality."""

def add1(x):
return x + 1
Expand Down Expand Up @@ -60,7 +60,7 @@ def add2(x):


def test_op_temp_attr():
""" Tests reset_attr functionality. """
"""Tests reset_attr functionality."""

def add1(x):
return x + 1
Expand Down Expand Up @@ -99,9 +99,21 @@ def test_op_level3():
assert y.args[0] == x


def test_op_register():
"""Tests register_op functionality."""
op_name = "custom_op"

tvm.ir.register_op(op_name)
tvm.ir.register_op_attr(op_name, "num_inputs", 2, 256)

assert tvm.ir.Op.get(op_name).name == op_name
assert tvm.ir.Op.get(op_name).num_inputs == 2


if __name__ == "__main__":
test_op_attr()
test_op_reset_attr()
test_op_temp_attr()
test_op_level1()
test_op_level3()
test_op_register()

0 comments on commit 254563a

Please sign in to comment.