Skip to content

Commit

Permalink
call_dps -> call_tir (apache#60)
Browse files Browse the repository at this point in the history
* Rename call_dps to call_tir

* Rename call_dps_rewrite.cc
  • Loading branch information
electriclilies authored and YuchenJin committed Nov 17, 2022
1 parent a55c4b8 commit 7f8cb36
Show file tree
Hide file tree
Showing 15 changed files with 72 additions and 72 deletions.
4 changes: 2 additions & 2 deletions include/tvm/relax/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,11 @@ TVM_DLL Pass FMARewrite();
TVM_DLL Pass ToNonDataflow();

/*!
* \brief Perform explicit tensor allocation for call_dps.
* \brief Perform explicit tensor allocation for call_tir.
*
* \return The Pass.
*/
TVM_DLL Pass CallDPSRewrite();
TVM_DLL Pass CallTIRRewrite();

/*!
* \brief Transform Relax IR to A-normal form.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
load_exec_from_file = vm.load_exec_from_file

# Operator
from .op.base import call_dps
from .op.base import call_tir
from .op.op_attrs import AllocStorageAttrs, AllocTensorAttrs

# IRBuilder
Expand Down
10 changes: 5 additions & 5 deletions python/tvm/relax/block_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from tvm import relax as rx
from tvm import tir
from .expr import *
from .op.base import call_dps
from .op.base import call_tir
from tvm._ffi.base import _LIB, check_call
from . import _ffi_api

Expand Down Expand Up @@ -334,7 +334,7 @@ def te_func(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_comp
@R.function
def rx_func(x: Tensor[(n, m), "float32"], y: Tensor[(n, m), "float32"]) -> Tensor:
# block 0
gv = relax.call_dps((128, 128), "te_func", (x, y))
gv = relax.call_tir((128, 128), "te_func", (x, y))
return gv
Example
Expand Down Expand Up @@ -380,7 +380,7 @@ def te_func(var_rxplaceholder: T.handle, var_compute: T.handle, n: T.int64) -> N
@R.function
def rx_func(x: Tensor[(n,), "float32"], y: Tensor[((n + 1),), "float32"]) -> Tensor[_, "float32"]:
# block 0
gv: Tensor[((n + 1),), "float32"] = relax.call_dps(((n + 1),), te_func, (y,), (n,))
gv: Tensor[((n + 1),), "float32"] = relax.call_tir(((n + 1),), te_func, (y,), (n,))
return gv
"""
new_args, te_arg_list = self._convert_te_arg(args)
Expand All @@ -404,9 +404,9 @@ def rx_func(x: Tensor[(n,), "float32"], y: Tensor[((n + 1),), "float32"]) -> Ten
call_args = [x.op.value for x in inputs[:-1]]
# add arguments for extra parameters from unbound var
if (len(unbound_tir_vars) > 0):
call = call_dps(inputs[-1].shape, gvar, call_args, tir_vars=ShapeExpr(unbound_tir_vars))
call = call_tir(inputs[-1].shape, gvar, call_args, tir_vars=ShapeExpr(unbound_tir_vars))
else:
call = call_dps(inputs[-1].shape, gvar, call_args)
call = call_tir(inputs[-1].shape, gvar, call_args)
return _ffi_api.BlockBuilderEmit(self, call)


Expand Down
6 changes: 3 additions & 3 deletions python/tvm/relax/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Union, List


def call_dps(
def call_tir(
shape: Union[ShapeExpr, List[int]], func: Expr, args: Union[Tuple, List[Expr]],
tir_vars: ShapeExpr = None
) -> Call:
Expand All @@ -43,10 +43,10 @@ def call_dps(
Returns
-------
ret: Call
A call node for the call_dps operator.
A call node for the call_tir operator.
"""
if isinstance(shape, (list, tuple, Array)):
shape = ShapeExpr(shape)
if isinstance(args, (list, tuple)):
args = Tuple(args)
return _ffi_api.call_dps(shape, func, args, tir_vars)
return _ffi_api.call_tir(shape, func, args, tir_vars)
6 changes: 3 additions & 3 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ def ToNonDataflow() -> tvm.ir.transform.Pass:
return _ffi_api.ToNonDataflow()


def CallDPSRewrite() -> tvm.ir.transform.Pass:
"""Perform explicit tensor allocation for call_dps.
def CallTIRRewrite() -> tvm.ir.transform.Pass:
"""Perform explicit tensor allocation for call_tir.
Returns
-------
ret: tvm.ir.transform.Pass
"""
return _ffi_api.CallDPSRewrite()
return _ffi_api.CallTIRRewrite()


def VMMemoryLower() -> tvm.ir.transform.Pass:
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def foo(x: Tensor[(3, 4), "float32"], y: Tensor[(3, 4), "float32"]):
ex, lib = relax.vm.build(mod, target)
"""
passes = [relax.transform.ToNonDataflow()]
passes.append(relax.transform.CallDPSRewrite())
passes.append(relax.transform.CallTIRRewrite())
passes.append(relax.transform.VMMemoryLower())
passes.append(relax.transform.VMShapeLower())
seq = tvm.transform.Sequential(passes)
Expand Down
10 changes: 5 additions & 5 deletions python/tvm/script/relax/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _is_registered(op_name: str, op_set=None) -> bool:
return op_name in op_set


# NOTE: call_dps is an actual registered operator
# NOTE: call_tir is an actual registered operator
class SpecialOp(Enum):
"""Relax operators that have special semantics handled by the parser."""

Expand Down Expand Up @@ -845,7 +845,7 @@ def parse_attr(self, expr: ast.Attr) -> relax.Expr:
relay.op.get("relax.shape_of"), [obj], span=self.to_tvm_span(expr.span)
)
else:
# assume it's a hierarchical op identifier (e.g. nn.softmax, relax.call_dps)
# assume it's a hierarchical op identifier (e.g. nn.softmax, relax.call_tir)
op_name = self._parse_attrs_to_str(expr)
# NOTE: at least for now, all special operators are namespaced
try:
Expand Down Expand Up @@ -890,8 +890,8 @@ def parse_call(self, expr: ast.Call) -> Union[tir.PrimExpr, relax.Expr]:
elif isinstance(op, tvm.ir.Op):
args = [self.transform_expr(arg) for arg in expr.params]
# check call arity eagerly
if op.name == "relax.call_dps":
# call_dps is special case because last argument is optional
if op.name == "relax.call_tir":
# call_tir is special case because last argument is optional
if len(args) != 3 and len(args) != 4:
self.report_error(
f"{op.name} expects {op.num_inputs} arguments but got {len(args)}", expr.span
Expand All @@ -900,7 +900,7 @@ def parse_call(self, expr: ast.Call) -> Union[tir.PrimExpr, relax.Expr]:
self.report_error(
f"{op.name} expects {op.num_inputs} arguments but got {len(args)}", expr.span
)
if op.name == "relax.call_dps" and isinstance(args[1], str):
if op.name == "relax.call_tir" and isinstance(args[1], str):
# extern function call case: rewrite identifier to an ExternFunc
args[1] = relax.ExternFunc(args[1], self.to_tvm_span(expr.params[1].span))

Expand Down
12 changes: 6 additions & 6 deletions src/relax/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,18 @@ bool EqualCheck(const PrimExpr& lhs, const PrimExpr& rhs) {
return false;
}

// call_dps
// call_tir

RELAY_REGISTER_OP("relax.call_dps")
RELAY_REGISTER_OP("relax.call_tir")
.set_num_inputs(4)
.add_argument("shape", "Expr", "The output shape.")
.add_argument("func", "Expr", "The destination-passing-style function.")
.add_argument("args", "Tuple", "The input arguments.")
.add_argument("packed_ints", "Expr",
"ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from args if unused");

Expr MakeCallDPS(Expr shape, Expr func, Tuple args, Optional<Expr> packed_ints) {
static const Op& op = Op::Get("relax.call_dps");
Expr MakeCallTIR(Expr shape, Expr func, Tuple args, Optional<Expr> packed_ints) {
static const Op& op = Op::Get("relax.call_tir");
Call call;
if (!packed_ints) {
// don't use additional optional argument
Expand All @@ -73,8 +73,8 @@ Expr MakeCallDPS(Expr shape, Expr func, Tuple args, Optional<Expr> packed_ints)
return call;
}

TVM_REGISTER_GLOBAL("relax.op.call_dps")
.set_body_typed(MakeCallDPS);
TVM_REGISTER_GLOBAL("relax.op.call_tir")
.set_body_typed(MakeCallTIR);

// shape_of

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
* under the License.
*/
/*!
* \file src/relax/transform/call_dps_rewrite.cc
* \brief Perform explicit tensor allocation for call_dps.
* \file src/relax/transform/call_tir_rewrite.cc
* \brief Perform explicit tensor allocation for call_tir.
*/
#include <tvm/relax/attrs/memory.h>
#include <tvm/relax/expr_functor.h>
Expand All @@ -32,26 +32,26 @@ namespace tvm {
namespace relax {

// ==================
// CallDPSMutator
// Perform explicit tensor allocation for call_dps.
// CallTIRMutator
// Perform explicit tensor allocation for call_tir.
// Example:
// lv0: Tensor[n, m] = rx.call_dps((n, m), op.identity, (x))
// lv0: Tensor[n, m] = rx.call_tir((n, m), op.identity, (x))
// -->
// gv0 = rx.call("relax.builtin.alloc_tensor", [n, m])
// rx.call_packed(op.identity, x, gv0)

class CallDPSMutator : public ExprMutator {
class CallTIRMutator : public ExprMutator {
public:
Expr VisitExpr_(const CallNode* call) override {
// post-order mutation
Expr expr = VisitExprPostOrder_(call);
call = expr.as<CallNode>();

static const Op& call_dps_op = Op::Get("relax.call_dps");
static const Op& call_tir_op = Op::Get("relax.call_tir");
static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
static const Op& call_tir_dyn_op = Op::Get("relax.vm.call_tir_dyn");

if (call->op == call_dps_op) {
if (call->op == call_tir_op) {
ShapeExpr output_shape = Downcast<ShapeExpr>(call->args[0]);
Var tensor = builder_->Emit(Call(alloc_tensor_op, {output_shape}), "alloc");
Array<Expr> args;
Expand All @@ -76,17 +76,17 @@ class CallDPSMutator : public ExprMutator {
}
};

Expr CallDPSRewrite(const Expr& e) { return CallDPSMutator().VisitExpr(e); }
Expr CallTIRRewrite(const Expr& e) { return CallTIRMutator().VisitExpr(e); }

namespace transform {

Pass CallDPSRewrite() {
Pass CallTIRRewrite() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(CallDPSRewrite(f)); };
return CreateFunctionPass(pass_func, 0, "CallDPSRewrite", {});
[=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(CallTIRRewrite(f)); };
return CreateFunctionPass(pass_func, 0, "CallTIRRewrite", {});
}

TVM_REGISTER_GLOBAL("relax.transform.CallDPSRewrite").set_body_typed(CallDPSRewrite);
TVM_REGISTER_GLOBAL("relax.transform.CallTIRRewrite").set_body_typed(CallTIRRewrite);

} // namespace transform

Expand Down
8 changes: 4 additions & 4 deletions tests/python/relax/test_blockbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def get_tir_func():
# check TIR structure matches expected
assert_structural_equal(mod["te_func"].body, get_tir_func().body)

# check Relax function calls TIR function with call_dps call
# check Relax function calls TIR function with call_tir call
assert rx_func.params[0] == x
assert rx_func.params[1] == y
assert rx_func.params[2] == z
Expand All @@ -311,7 +311,7 @@ def get_tir_func():
assert len(rx_func.body.blocks) == 1
assert len(rx_func.body.blocks[0].bindings) == 1
assert isinstance(rx_func.body.blocks[0].bindings[0].value, rx.Call)
assert rx_func.body.blocks[0].bindings[0].value.op == relay.op.get("relax.call_dps")
assert rx_func.body.blocks[0].bindings[0].value.op == relay.op.get("relax.call_tir")
assert len(rx_func.body.blocks[0].bindings[0].value.args) == 3
assert rx_func.body.blocks[0].bindings[0].value.args[1].name_hint == "te_func"
assert rx_func.body.blocks[0].bindings[0].value.args[2][0] == x
Expand Down Expand Up @@ -354,12 +354,12 @@ def test_emit_te_extern():
mod = bb.get()
rx_func = mod["rx_cblas_matmul"]

# check Relax function calls TIR function with call_dps call
# check Relax function calls TIR function with call_tir call
assert rx_func.params[0] == x
assert rx_func.params[1] == y
assert len(rx_func.body.blocks) == 1
assert isinstance(rx_func.body.blocks[0].bindings[0].value, rx.Call)
assert rx_func.body.blocks[0].bindings[0].value.op == relay.op.get("relax.call_dps")
assert rx_func.body.blocks[0].bindings[0].value.op == relay.op.get("relax.call_tir")
assert len(rx_func.body.blocks[0].bindings[0].value.args) == 3
assert rx_func.body.blocks[0].bindings[0].value.args[1].name_hint == "matmul"
assert rx_func.body.blocks[0].bindings[0].value.args[2][0] == x
Expand Down
8 changes: 4 additions & 4 deletions tests/python/relax/test_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ def identity_tir(a: T.handle, b: T.handle) -> None:
B[vi, vj] = A[vi, vj]


def test_call_dps() -> None:
def test_call_tir() -> None:
shape_anno = [54, 96]
type_anno = rx.DynTensorType(2, "float32")
v0 = rx.Var("v0", shape_anno, type_anno)
v1 = rx.call_dps([54, 96], rx.extern("test.op.identity"), [v0])
v1 = rx.call_dps([54, 96], identity_tir, [v0])
v1 = rx.call_tir([54, 96], rx.extern("test.op.identity"), [v0])
v1 = rx.call_tir([54, 96], identity_tir, [v0])


if __name__ == "__main__":
test_call_dps()
test_call_tir()
12 changes: 6 additions & 6 deletions tests/python/relax/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
C[vi, vj] = 0.0
C[vi, vj] += A[vi, vk] * B[vj, vk]

z = relax.call_dps((B, 128), my_matmul, (x, y))
z = relax.call_tir((B, 128), my_matmul, (x, y))
return z

x, y = f.params
Expand All @@ -433,7 +433,7 @@ def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:

check_call(
z_bind.value,
"relax.call_dps",
"relax.call_tir",
[relax.ShapeExpr([B, tir.IntImm("int64", 128)]), mm_bind.var, relax.Tuple([x, y])],
)

Expand Down Expand Up @@ -475,18 +475,18 @@ def f(x: Tensor[(n, m), "float32"]):
assert_structural_equal(sh_bind.value.values, [tir.Add(n, m), tir.FloorDiv(n, m)])


def test_call_dps_extern():
def test_call_tir_extern():
@R.function
def f(x: Tensor):
z = relax.call_dps((10,), "my_extern", (x,))
z = relax.call_tir((10,), "my_extern", (x,))
return z

x = f.params[0]
(z_bind,) = f.body.blocks[0].bindings

check_call(
z_bind.value,
"relax.call_dps",
"relax.call_tir",
[
relax.ShapeExpr([tir.IntImm("int64", 10)]),
relax.ExternFunc("my_extern"),
Expand Down Expand Up @@ -517,7 +517,7 @@ def f(x: Tensor[(n, n), _]) -> Tensor:

@R.function
def g(y: Tensor[(n, n), _]) -> Tensor:
return relax.call_dps((n, n), my_matmul, (y, y))
return relax.call_tir((n, n), my_matmul, (y, y))

@R.function
def h(x, y, z):
Expand Down
8 changes: 4 additions & 4 deletions tests/python/relax/test_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
C[vi, vj] = 0.0
C[vi, vj] += A[vi, vk] * B[vj, vk]

z = relax.call_dps((B, 128), my_matmul, (x, y))
z = relax.call_tir((B, 128), my_matmul, (x, y))
return z

check_roundtrip(foo)
Expand Down Expand Up @@ -171,10 +171,10 @@ def foo(x: Tensor[(n, m), "float32"]):
check_roundtrip(foo)


def test_call_dps_extern():
def test_call_tir_extern():
@R.function
def foo(x: Tensor):
z = relax.call_dps((10,), "my_extern", (x,))
z = relax.call_tir((10,), "my_extern", (x,))
return z

check_roundtrip(foo)
Expand Down Expand Up @@ -202,7 +202,7 @@ def f(x: Tensor[(n, n), _]) -> Tensor:

@R.function
def g(y: Tensor[(n, n), _]) -> Tensor:
return relax.call_dps((n, n), my_matmul, (y, y))
return relax.call_tir((n, n), my_matmul, (y, y))

@R.function
def h(x, y, z):
Expand Down
Loading

0 comments on commit 7f8cb36

Please sign in to comment.