Skip to content

Commit

Permalink
Implement printing/parsing of dynamic tuple indices
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg committed Nov 6, 2023
1 parent 4e32c73 commit 1b3b5bf
Show file tree
Hide file tree
Showing 8 changed files with 158 additions and 5 deletions.
5 changes: 1 addition & 4 deletions python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,7 @@ def __getitem__(self, index: Union[Expr, PrimExpr, int]) -> "ExprWithOp":
result: ExprWithOp
The result expression.
"""
if not isinstance(index, Expr):
index = PrimValue(index)

return _ffi_api.tuple_get_item(self, index)
return tvm.relax.op.tuple_get_item(self, index)


@tvm._ffi.register_object("relax.expr.Call")
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
from .set import unique
from .statistical import cumsum, max, mean, min, prod, std, sum, variance
from .ternary import ewise_fma
from .tuple import tuple_get_item, tuple_get_item_dyn
from .unary import (
abs,
acos,
Expand Down
93 changes: 93 additions & 0 deletions python/tvm/relax/op/tuple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tuple operators."""
from typing import Union

import tvm
from tvm.ir.expr import PrimExpr

from . import _ffi_api
from ..expr import Expr, PrimValue


def tuple_get_item(tuple: Expr, index: Union[int, PrimExpr, Expr]) -> Expr:
"""Perform tuple access
Use of this method is recommended, rather than constructing a
`relax.TupleGetItem` directly.
1. May resolve to the tuple's contents, avoiding the intermediate
`TupleGetItem`.
2. Handles access of a tuple at a dynamic index, where
`TupleGetItem` requires a statically-known index.
Parameters
----------
tuple: Expr
The tuple to be accessed. The tuple is not required to be an
in-line `relax.Tuple`, but must have `TupleStructInfo`
index: Union[int, PrimExpr, Expr]
The index at which the tuple is accessed. The index may be
static or dynamic.
Returns
-------
Expr
An expression representing the item in the tuple.
"""

if not isinstance(index, Expr):
index = PrimValue(index)

return _ffi_api.tuple_get_item(tuple, index) # type: ignore


def tuple_get_item_dyn(tuple: Expr, index: Union[int, PrimExpr, Expr]) -> Expr:
"""Explicitly generate a call to tuple_get_item_dyn
This method is not recommended for general use, and is provided to
ensure round-trip consistency in TVMScript. In most cases, the
`tuple_get_item` method should be used, which will delegate to the
dynamic builtin for cases where the index is dynamic.
Parameters
----------
tuple: Expr
The tuple to be accessed. The tuple is not required to be an
in-line `relax.Tuple`, but must have `TupleStructInfo`
index: Union[int, PrimExpr, Expr]
The index at which the tuple is accessed. The index may be
static or dynamic.
Returns
-------
Expr
An expression representing the item in the tuple.
"""
if not isinstance(index, Expr):
index = PrimValue(index)
return tvm.relax.Call(tvm.ir.Op.get("relax.tuple_get_item_dyn"), [tuple, index])
4 changes: 4 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@
tile,
tril,
triu,
tuple_get_item,
tuple_get_item_dyn,
unique,
vm,
where,
Expand Down Expand Up @@ -775,6 +777,8 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"tril",
"triu",
"tuple",
"tuple_get_item",
"tuple_get_item_dyn",
"unique",
"variance",
"vm",
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/script/parser/relax/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def eval_struct_info_proxy(self: Parser, node: doc.expr) -> StructInfoProxy:
return annotation
raise TypeError(f"Expected StructInfoProxy but got {type(annotation)}.")
except Exception as err:
raise
self.report_error(node, str(err))
raise err

Expand All @@ -108,6 +109,7 @@ def eval_struct_info(self: Parser, node: doc.expr, eval_str: bool = False) -> St
try:
return eval_struct_info_proxy(self, node).as_struct_info(var_table)
except Exception as err:
raise
self.report_error(node, str(err))
raise err

Expand Down
2 changes: 1 addition & 1 deletion src/relax/op/tuple.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ Expr tuple_get_item(Expr tuple, Expr index) {
}
}

TVM_REGISTER_GLOBAL("relax.tuple_get_item").set_body_typed(tuple_get_item);
TVM_REGISTER_GLOBAL("relax.op.tuple_get_item").set_body_typed(tuple_get_item);

} // namespace relax
} // namespace tvm
22 changes: 22 additions & 0 deletions src/script/printer/relax/call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,24 @@ Optional<ExprDoc> PrintRelaxPrint(const relax::Call& n, const ObjectPath& n_p,
return Relax(d, "print")->Call(args, {"format"}, {first_arg});
}

Optional<ExprDoc> PrintTupleGetItem(const relax::Call& call, const ObjectPath& path,
const IRDocsifier& doc) {
static const Op& print_op = Op::Get("relax.tuple_get_item_dyn");
if (!call->op.same_as(print_op)) {
return NullOpt;
}

if (!doc->cfg->syntax_sugar) {
// Fall back to the default printing for builtins as `R.tuple_get_item_dyn`
return NullOpt;
}

ICHECK_EQ(call->args.size(), 2);
ExprDoc tuple = doc->AsDoc<ExprDoc>(call->args[0], path->Attr("args")->ArrayIndex(0));
ExprDoc index = doc->AsDoc<ExprDoc>(call->args[1], path->Attr("args")->ArrayIndex(1));
return tuple[{index}];
}

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<relax::Call>( //
"", [](relax::Call n, ObjectPath n_p, IRDocsifier d) -> Doc {
Expand All @@ -269,6 +287,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
if (Optional<ExprDoc> doc = PrintRelaxPrint(n, n_p, d)) {
return doc.value();
}
// Special case: tuple_get_item_dyn
if (Optional<ExprDoc> doc = PrintTupleGetItem(n, n_p, d)) {
return doc.value();
}
ExprDoc prefix{nullptr};
Array<ExprDoc> args;
Array<String> kwargs_keys;
Expand Down
34 changes: 34 additions & 0 deletions tests/python/relax/test_tuple_get_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from tvm import relax
from tvm.script import relax as R, tir as T

import pytest

exec_mode = tvm.testing.parameter("bytecode", "compiled")

tuple_type_annotation = tvm.testing.parameter(
Expand All @@ -33,6 +35,8 @@

tuple_index_type = tvm.testing.parameter("static", "dynamic")

syntax_sugar = tvm.testing.parameter(by_dict={"sugared": True, "unsugared": False})


def test_vm_tuple_get_item(exec_mode, tuple_type_annotation, tuple_index_type):
def access_tuple(tuple_obj, dyn_index):
Expand Down Expand Up @@ -61,5 +65,35 @@ def func(arg: tuple_type_annotation, index_param: R.Prim(value="index_var")):
assert res == 17


def test_dynamic_index_printing(syntax_sugar: bool):
"""Check syntax-sugar for dynamic tuple indices
The "relax.tuple_get_item_dyn" operator should be printed as
`my_tuple[my_index]` by default, which will regenerate the
original operator when parsed. If syntax sugar is disabled, it
should display the `R.tuple_get_item_dyn` directly.
"""

@R.function(private=True)
def func(
arg_tuple: R.Tuple([R.Prim("int64"), R.Prim("float32")]),
arg_index: R.Prim(value="index_var"),
):
return arg_tuple[arg_index]

script = func.script(syntax_sugar=syntax_sugar)

if syntax_sugar:
assert "arg_tuple[arg_index]" in script
assert "tuple_get_item_dyn" not in script
else:
assert "arg_tuple[arg_index]" not in script
assert "tuple_get_item_dyn" in script

roundtrip = tvm.script.from_source(script)

tvm.ir.assert_structural_equal(func, roundtrip)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 1b3b5bf

Please sign in to comment.