Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions python/tvm/relax/frontend/nn/_tensor_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,22 @@ def __radd__(self, other):
other = _convert_scalar(other, self)
return _op().add(self, other)

def __sub__(self, other):
    """Element-wise subtraction: ``self - other`` (scalars are converted first)."""
    return _op().subtract(self, _convert_scalar(other, self))

def __rsub__(self, other):
    """Reflected subtraction: ``other - self`` — note the swapped operand order."""
    return _op().subtract(_convert_scalar(other, self), self)

def __mul__(self, other):
    """Element-wise multiplication: ``self * other`` (scalars are converted first)."""
    return _op().multiply(self, _convert_scalar(other, self))

def __rmul__(self, other):
    """Reflected multiplication: commutative, so operands stay in (self, other) order."""
    return _op().multiply(self, _convert_scalar(other, self))

def __truediv__(self, other):
    """Element-wise division: ``self / other`` (scalars are converted first)."""
    return _op().divide(self, _convert_scalar(other, self))
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relax/frontend/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def __init__(

def forward(self, x: Tensor) -> Tensor:
"""
Forward method for convtranspose1d layer.
Forward method for conv transpose 1d layer.

Parameters
----------
Expand All @@ -321,7 +321,7 @@ def forward(self, x: Tensor) -> Tensor:
Returns
-------
ret : Tensor
The output tensor for the convtranspose1d layer.
The output tensor for the conv transpose 1d layer.
"""
return op.conv1d_transpose(
x,
Expand Down
76 changes: 75 additions & 1 deletion python/tvm/relax/frontend/nn/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1461,13 +1461,87 @@ def _convert(arg):
OutType = TypeVar("OutType", bound=Union[Tensor, Sequence[Tensor]])


def tensor_ir_op(
    func: _tir.PrimFunc,
    name_hint: str,
    args: Union[Tensor, Sequence[Union[Tensor, _tir.Var]]],
    out: OutType,
) -> OutType:
    """Create a `call_tir` binding with given PrimFunc

    Parameters
    ----------
    func : _tir.PrimFunc
        The PrimFunc to call.

    name_hint : str
        Name hint used when adding `func` to the enclosing IRModule; the
        builder may rename it on collision.
        NOTE(review): if `func` has a "global_symbol" attr it is a public
        function that must not be renamed — symbol duplication is currently
        not checked here and is left to future work.

    args : Union[Tensor, Sequence[Union[Tensor, _tir.Var]]]
        The arguments to pass to the PrimFunc. `Tensor`s become positional
        `call_tir` arguments; `tir.Var`s are forwarded via `tir_vars`.

    out : Union[Tensor, List[Tensor]]
        Placeholder tensor(s) whose struct info describes the output(s).

    Returns
    -------
    result : Tensor
        The result tensor(s) wrapping the emitted `call_tir`.

    Raises
    ------
    TypeError
        If any element of `args` is neither a `Tensor` nor a `tir.Var`.
    """
    from tvm import relax as rx  # pylint: disable=import-outside-toplevel

    call_tir_args, tir_vars = [], []
    # Accept a bare Tensor as a convenience for single-argument calls.
    if not isinstance(args, (tuple, list)):
        args = [args]

    # Split arguments: Tensors feed the PrimFunc's buffers, tir.Vars feed
    # its trailing symbolic-shape parameters.
    for arg in args:
        if isinstance(arg, Tensor):
            call_tir_args.append(arg._expr)
        elif isinstance(arg, _tir.Var):
            tir_vars.append(arg)
        else:
            raise TypeError(
                f"Unsupported type: tensor_ir_op args expect Tensor or tir.Var, but got {type(arg)}"
            )

    # A single placeholder yields one output; a list yields a tuple of outputs.
    if isinstance(out, Tensor):
        out_sinfo = [out._expr.struct_info]
    else:
        out_sinfo = [x._expr.struct_info for x in out]

    bb = BlockBuilder.current()
    # Register the PrimFunc in the module under construction and call it.
    global_var = bb.add_func(func, name_hint)

    return wrap_nested(
        bb.emit(rx.call_tir(global_var, call_tir_args, out_sinfo, tir_vars=tir_vars)),
        name=name_hint,
    )


def extern(
name: str,
args: Sequence[Union[Tensor, _tir.PrimExpr, int, float, str]],
out: OutType,
) -> OutType:
"""Invoke an extern function during runtime. The extern function must be registered with the "
TVM runtime using `TVM_REGISTER_GLOBAL` (C++), or `tvm.register_func` (Python)."""
TVM runtime using `TVM_REGISTER_GLOBAL` (C++), or `tvm.register_func` (Python).

Parameters
----------
name : str
The name of the extern function to call.

args : Sequence[Union[Tensor, _tir.PrimExpr, int, float, str]]
The arguments to pass to the extern function.

out : Union[Tensor, List[Tensor]]
The output tensors, only

Returns
-------
result : Tensor
The result
"""
from tvm import relax as rx # pylint: disable=import-outside-toplevel

def _convert(arg, name: str):
Expand Down
130 changes: 130 additions & 0 deletions tests/python/relax/test_frontend_nn_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring, invalid-name
import tvm
import tvm.testing
from tvm import tir
Expand Down Expand Up @@ -508,5 +509,134 @@ def test(x: R.Tensor((10, 10), dtype="float32"), _io: R.Object) -> R.Tuple(R.Ten
tvm.ir.assert_structural_equal(irmodule, Expected)


def test_tensor_ir_op():
    """op.tensor_ir_op emits a call_tir to the added PrimFunc, passing
    Tensor args positionally and tir.Var args through tir_vars."""
    num_q_heads, num_kv_heads, head_dim = 8, 8, 16
    fused_heads = num_q_heads + num_kv_heads * 2
    dtype = "float16"

    # Minimal stand-in for a fused-RoPE kernel: only the buffer signatures
    # matter for this test, so the body just evaluates `offset`.
    @T.prim_func(private=True)
    def fused_rope(  # pylint: disable=too-many-locals
        var_qkv: T.handle,
        offset: T.int64,
        var_q: T.handle,
        var_k: T.handle,
        var_v: T.handle,
    ):
        batch_size = T.int64()
        seq_len = T.int64()
        qkv = T.match_buffer(var_qkv, (batch_size, seq_len, fused_heads, head_dim), dtype)
        q = T.match_buffer(var_q, (batch_size, seq_len, num_q_heads, head_dim), dtype)
        k = T.match_buffer(var_k, (batch_size, seq_len, num_kv_heads, head_dim), dtype)
        v = T.match_buffer(var_v, (batch_size, seq_len, num_kv_heads, head_dim), dtype)
        T.evaluate(offset)

    class Model(Module):
        def test(self, qkv: Tensor, offset: tir.Var):
            # One Tensor arg plus one tir.Var arg; three output placeholders.
            tensor_expr_op_out = op.tensor_ir_op(
                fused_rope,
                "llama_fused_rope",
                args=[qkv, offset],
                out=[
                    Tensor.placeholder((1, 1, num_q_heads, head_dim), dtype),
                    Tensor.placeholder((1, 1, num_kv_heads, head_dim), dtype),
                    Tensor.placeholder((1, 1, num_kv_heads, head_dim), dtype),
                ],
            )
            return tensor_expr_op_out

    # fmt: off
    @I.ir_module
    class Expected:
        @T.prim_func(private=True)
        def llama_fused_rope(var_qkv: T.handle, offset: T.int64, var_q: T.handle, var_k: T.handle, var_v: T.handle):
            batch_size, seq_len = T.int64(), T.int64()
            qkv = T.match_buffer(var_qkv, (batch_size, seq_len, 24, 16), "float16")
            q = T.match_buffer(var_q, (batch_size, seq_len, 8, 16), "float16")
            k = T.match_buffer(var_k, (batch_size, seq_len, 8, 16), "float16")
            v = T.match_buffer(var_v, (batch_size, seq_len, 8, 16), "float16")
            T.evaluate(offset)

        @R.function
        def _initialize_effect() -> R.Tuple(R.Object):
            with R.dataflow():
                _io: R.Object = R.null_value()
                lv: R.Tuple(R.Object) = (_io,)
                gv: R.Tuple(R.Object) = lv
                R.output(gv)
            return gv

        @R.function
        def test(qkv: R.Tensor((1, 1, 24, 16), dtype="float16"), offset: R.Shape(["offset_1"]), _io: R.Object) -> R.Tuple(R.Tuple(R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16")), R.Tuple(R.Object)):
            offset_1 = T.int64()
            R.func_attr({"num_input": 3})
            cls = Expected
            with R.dataflow():
                lv1 = R.call_tir(cls.llama_fused_rope, (qkv,), out_sinfo=[R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16")], tir_vars=R.shape([offset_1]))
                llama_fused_rope_0: R.Tensor((1, 1, 8, 16), dtype="float16") = lv1[0]
                llama_fused_rope_1: R.Tensor((1, 1, 8, 16), dtype="float16") = lv1[1]
                llama_fused_rope_2: R.Tensor((1, 1, 8, 16), dtype="float16") = lv1[2]
                gv1: R.Tuple(R.Tuple(R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16")), R.Tuple(R.Object)) = (llama_fused_rope_0, llama_fused_rope_1, llama_fused_rope_2), (_io,)
                R.output(gv1)
            return gv1
    # fmt: on

    m = Model()
    irmodule, _ = m.export_tvm(
        spec={
            "test": {"qkv": spec.Tensor([1, 1, fused_heads, head_dim], "float16"), "offset": int}
        },
        debug=True,
    )
    tvm.ir.assert_structural_equal(irmodule, Expected)


def test_extern():
    """op.extern lowers to R.call_dps_packed: scalar args become R.prim_value
    and the placeholder's shape/dtype becomes out_sinfo."""
    class Model(Module):
        def test(self, q: Tensor, k: Tensor, v: Tensor):
            b, s, h_q, d = q.shape
            # Mixed args: Tensors pass through, ints/floats are wrapped as
            # prim values; the placeholder only declares the output struct info.
            tensor_expr_op_out = op.extern(
                name="flashinfer.single_decode",
                args=[q, k, v, 0, 0, 1.0, 10000.0],
                out=Tensor.placeholder((b, s, h_q * d), dtype="float16"),
            )
            return tensor_expr_op_out

    # fmt: off
    @I.ir_module
    class Expected:
        @R.function
        def _initialize_effect() -> R.Tuple(R.Object):
            with R.dataflow():
                _io: R.Object = R.null_value()
                lv: R.Tuple(R.Object) = (_io,)
                gv: R.Tuple(R.Object) = lv
                R.output(gv)
            return gv

        @R.function
        def test(q: R.Tensor((1, 1, 16, 8), dtype="float32"), k: R.Tensor((64, 16, 8), dtype="float32"), v: R.Tensor((64, 16, 8), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((1, 1, 128), dtype="float16"), R.Tuple(R.Object)):
            R.func_attr({"num_input": 4})
            with R.dataflow():
                flashinfer_single_decode = R.call_dps_packed("flashinfer.single_decode", (q, k, v, R.prim_value(0), R.prim_value(0), R.prim_value(T.float64(1)), R.prim_value(T.float64(10000))), out_sinfo=R.Tensor((1, 1, 128), dtype="float16"))
                gv1: R.Tuple(R.Tensor((1, 1, 128), dtype="float16"), R.Tuple(R.Object)) = flashinfer_single_decode, (_io,)
                R.output(gv1)
            return gv1
    # fmt: on

    # Shapes chosen so h_q * d == 128 matches the Expected out_sinfo.
    batch, seq, t, d, h_q, h_kv = 1, 1, 64, 8, 16, 16
    m = Model()
    irmodule, _ = m.export_tvm(
        spec={
            "test": {
                "q": spec.Tensor([batch, seq, h_q, d], "float32"),
                "k": spec.Tensor([t, h_kv, d], "float32"),
                "v": spec.Tensor([t, h_kv, d], "float32"),
            }
        },
        debug=True,
    )
    tvm.ir.assert_structural_equal(irmodule, Expected)


# Allow running this test file directly; tvm.testing.main discovers and
# runs all test_* functions in the module.
if __name__ == "__main__":
    tvm.testing.main()