diff --git a/python/tvm/relax/frontend/nn/_tensor_op.py b/python/tvm/relax/frontend/nn/_tensor_op.py index a653c9fa2955..627b8b626c6e 100644 --- a/python/tvm/relax/frontend/nn/_tensor_op.py +++ b/python/tvm/relax/frontend/nn/_tensor_op.py @@ -47,10 +47,22 @@ def __radd__(self, other): other = _convert_scalar(other, self) return _op().add(self, other) + def __sub__(self, other): + other = _convert_scalar(other, self) + return _op().subtract(self, other) + + def __rsub__(self, other): + other = _convert_scalar(other, self) + return _op().subtract(other, self) + def __mul__(self, other): other = _convert_scalar(other, self) return _op().multiply(self, other) + def __rmul__(self, other): + other = _convert_scalar(other, self) + return _op().multiply(self, other) + def __truediv__(self, other): other = _convert_scalar(other, self) return _op().divide(self, other) diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py index b2c97a567ab8..03d6a06994a1 100644 --- a/python/tvm/relax/frontend/nn/modules.py +++ b/python/tvm/relax/frontend/nn/modules.py @@ -311,7 +311,7 @@ def __init__( def forward(self, x: Tensor) -> Tensor: """ - Forward method for convtranspose1d layer. + Forward method for conv transpose 1d layer. Parameters ---------- @@ -321,7 +321,7 @@ def forward(self, x: Tensor) -> Tensor: Returns ------- ret : Tensor - The output tensor for the convtranspose1d layer. + The output tensor for the conv transpose 1d layer. 
""" return op.conv1d_transpose( x, diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 2369451ac98b..3197145289ef 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -1461,13 +1461,87 @@ def _convert(arg): OutType = TypeVar("OutType", bound=Union[Tensor, Sequence[Tensor]]) +def tensor_ir_op( + func: _tir.PrimFunc, + name_hint: str, + args: Union[Tensor, Sequence[Union[Tensor, _tir.Var]]], + out: OutType, +) -> OutType: + """Create a `call_tir` binding with given PrimFunc + + Parameters + ---------- + func : _tir.PrimFunc + The PrimFunc to call. + + name_hint : str + Name hint. + + args : Union[Tensor, Sequence[Union[Tensor, _tir.Var]]] + The arguments to pass to the PrimFunc. + + out : Union[Tensor, List[Tensor]] + The output tensors. + + Returns + ------- + result : Tensor + The result tensor + """ + from tvm import relax as rx # pylint: disable=import-outside-toplevel + + call_tir_args, tir_vars = [], [] + if not isinstance(args, (tuple, list)): + args = [args] + + for arg in args: + if isinstance(arg, Tensor): + call_tir_args.append(arg._expr) + elif isinstance(arg, _tir.Var): + tir_vars.append(arg) + else: + raise TypeError( + f"Unsupported type: tensor_ir_op args expect Tensor or tir.Var, but got {type(arg)}" + ) + + if isinstance(out, Tensor): + out_sinfo = [out._expr.struct_info] + else: + out_sinfo = [x._expr.struct_info for x in out] + + bb = BlockBuilder.current() + global_var = bb.add_func(func, name_hint) + + return wrap_nested( + bb.emit(rx.call_tir(global_var, call_tir_args, out_sinfo, tir_vars=tir_vars)), + name=name_hint, + ) + + def extern( name: str, args: Sequence[Union[Tensor, _tir.PrimExpr, int, float, str]], out: OutType, ) -> OutType: """Invoke an extern function during runtime. 
The extern function must be registered with the " - TVM runtime using `TVM_REGISTER_GLOBAL` (C++), or `tvm.register_func` (Python).""" + TVM runtime using `TVM_REGISTER_GLOBAL` (C++), or `tvm.register_func` (Python). + + Parameters + ---------- + name : str + The name of the extern function to call. + + args : Sequence[Union[Tensor, _tir.PrimExpr, int, float, str]] + The arguments to pass to the extern function. + + out : Union[Tensor, List[Tensor]] + The output tensors. + + Returns + ------- + result : Tensor + The result tensor or tensors. + """ from tvm import relax as rx # pylint: disable=import-outside-toplevel def _convert(arg, name: str): diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index ddaec7234b9a..55870426e485 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License.
+# pylint: disable=missing-docstring, invalid-name import tvm import tvm.testing from tvm import tir @@ -508,5 +509,134 @@ def test(x: R.Tensor((10, 10), dtype="float32"), _io: R.Object) -> R.Tuple(R.Ten tvm.ir.assert_structural_equal(irmodule, Expected) +def test_tensor_ir_op(): + num_q_heads, num_kv_heads, head_dim = 8, 8, 16 + fused_heads = num_q_heads + num_kv_heads * 2 + dtype = "float16" + + @T.prim_func(private=True) + def fused_rope( # pylint: disable=too-many-locals + var_qkv: T.handle, + offset: T.int64, + var_q: T.handle, + var_k: T.handle, + var_v: T.handle, + ): + batch_size = T.int64() + seq_len = T.int64() + qkv = T.match_buffer(var_qkv, (batch_size, seq_len, fused_heads, head_dim), dtype) + q = T.match_buffer(var_q, (batch_size, seq_len, num_q_heads, head_dim), dtype) + k = T.match_buffer(var_k, (batch_size, seq_len, num_kv_heads, head_dim), dtype) + v = T.match_buffer(var_v, (batch_size, seq_len, num_kv_heads, head_dim), dtype) + T.evaluate(offset) + + class Model(Module): + def test(self, qkv: Tensor, offset: tir.Var): + tensor_expr_op_out = op.tensor_ir_op( + fused_rope, + "llama_fused_rope", + args=[qkv, offset], + out=[ + Tensor.placeholder((1, 1, num_q_heads, head_dim), dtype), + Tensor.placeholder((1, 1, num_kv_heads, head_dim), dtype), + Tensor.placeholder((1, 1, num_kv_heads, head_dim), dtype), + ], + ) + return tensor_expr_op_out + + # fmt: off + @I.ir_module + class Expected: + @T.prim_func(private=True) + def llama_fused_rope(var_qkv: T.handle, offset: T.int64, var_q: T.handle, var_k: T.handle, var_v: T.handle): + batch_size, seq_len = T.int64(), T.int64() + qkv = T.match_buffer(var_qkv, (batch_size, seq_len, 24, 16), "float16") + q = T.match_buffer(var_q, (batch_size, seq_len, 8, 16), "float16") + k = T.match_buffer(var_k, (batch_size, seq_len, 8, 16), "float16") + v = T.match_buffer(var_v, (batch_size, seq_len, 8, 16), "float16") + T.evaluate(offset) + + @R.function + def _initialize_effect() -> R.Tuple(R.Object): + with R.dataflow(): 
+ _io: R.Object = R.null_value() + lv: R.Tuple(R.Object) = (_io,) + gv: R.Tuple(R.Object) = lv + R.output(gv) + return gv + + @R.function + def test(qkv: R.Tensor((1, 1, 24, 16), dtype="float16"), offset: R.Shape(["offset_1"]), _io: R.Object) -> R.Tuple(R.Tuple(R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16")), R.Tuple(R.Object)): + offset_1 = T.int64() + R.func_attr({"num_input": 3}) + cls = Expected + with R.dataflow(): + lv1 = R.call_tir(cls.llama_fused_rope, (qkv,), out_sinfo=[R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16")], tir_vars=R.shape([offset_1])) + llama_fused_rope_0: R.Tensor((1, 1, 8, 16), dtype="float16") = lv1[0] + llama_fused_rope_1: R.Tensor((1, 1, 8, 16), dtype="float16") = lv1[1] + llama_fused_rope_2: R.Tensor((1, 1, 8, 16), dtype="float16") = lv1[2] + gv1: R.Tuple(R.Tuple(R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16")), R.Tuple(R.Object)) = (llama_fused_rope_0, llama_fused_rope_1, llama_fused_rope_2), (_io,) + R.output(gv1) + return gv1 + # fmt: on + + m = Model() + irmodule, _ = m.export_tvm( + spec={ + "test": {"qkv": spec.Tensor([1, 1, fused_heads, head_dim], "float16"), "offset": int} + }, + debug=True, + ) + tvm.ir.assert_structural_equal(irmodule, Expected) + + +def test_extern(): + class Model(Module): + def test(self, q: Tensor, k: Tensor, v: Tensor): + b, s, h_q, d = q.shape + tensor_expr_op_out = op.extern( + name="flashinfer.single_decode", + args=[q, k, v, 0, 0, 1.0, 10000.0], + out=Tensor.placeholder((b, s, h_q * d), dtype="float16"), + ) + return tensor_expr_op_out + + # fmt: off + @I.ir_module + class Expected: + @R.function + def _initialize_effect() -> R.Tuple(R.Object): + with R.dataflow(): + _io: R.Object = R.null_value() + lv: R.Tuple(R.Object) = (_io,) + gv: R.Tuple(R.Object) = lv + R.output(gv) 
+ return gv + + @R.function + def test(q: R.Tensor((1, 1, 16, 8), dtype="float32"), k: R.Tensor((64, 16, 8), dtype="float32"), v: R.Tensor((64, 16, 8), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((1, 1, 128), dtype="float16"), R.Tuple(R.Object)): + R.func_attr({"num_input": 4}) + with R.dataflow(): + flashinfer_single_decode = R.call_dps_packed("flashinfer.single_decode", (q, k, v, R.prim_value(0), R.prim_value(0), R.prim_value(T.float64(1)), R.prim_value(T.float64(10000))), out_sinfo=R.Tensor((1, 1, 128), dtype="float16")) + gv1: R.Tuple(R.Tensor((1, 1, 128), dtype="float16"), R.Tuple(R.Object)) = flashinfer_single_decode, (_io,) + R.output(gv1) + return gv1 + # fmt: on + + batch, seq, t, d, h_q, h_kv = 1, 1, 64, 8, 16, 16 + m = Model() + irmodule, _ = m.export_tvm( + spec={ + "test": { + "q": spec.Tensor([batch, seq, h_q, d], "float32"), + "k": spec.Tensor([t, h_kv, d], "float32"), + "v": spec.Tensor([t, h_kv, d], "float32"), + } + }, + debug=True, + ) + tvm.ir.assert_structural_equal(irmodule, Expected) + + if __name__ == "__main__": tvm.testing.main()