Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -1540,9 +1540,18 @@ inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, Array<PrimExp

inline Tensor arange(const PrimExpr& start, const PrimExpr& stop, const PrimExpr& step,
DataType dtype, std::string name = "T_arange", std::string tag = kInjective) {
PrimExpr num_elem = tvm::cast(
tvm::DataType::Int(32), tvm::ceil(tvm::cast(tvm::DataType::Float(32), stop - start) / step));
Array<PrimExpr> shape;
PrimExpr num_elem;
if (start.dtype().is_int() && stop.dtype().is_int() && step.dtype().is_int()) {
// fast path for integer arange
num_elem = tvm::floordiv((stop - start + step - 1), step);
} else {
num_elem = tvm::cast(DefaultIndexType(),
tvm::ceil(tvm::cast(tvm::DataType::Float(32), stop - start) / step));
}

arith::Analyzer analyzer;
num_elem = analyzer.Simplify(num_elem);

return compute(
{num_elem},
[&](const Array<Var>& indices) { return tvm::cast(dtype, start + step * indices[0]); }, name,
Expand Down
4 changes: 1 addition & 3 deletions python/tvm/relax/block_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,7 @@ def function(
ret: FunctionScope
A FunctionScope for building a Relax function node.
"""
if not params:
params = None
elif isinstance(params, rx.Var):
if isinstance(params, rx.Var):
params = [params]
elif isinstance(params, (list, tuple)):
for param in params:
Expand Down
7 changes: 4 additions & 3 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,6 @@ def _eq(self, node: fx.node.Node) -> relax.Expr:

def _arange(self, node: fx.node.Node) -> relax.Var:
import torch
import numpy as np

start_end_step = [None, None, None]
if "start" in node.kwargs:
Expand Down Expand Up @@ -288,8 +287,10 @@ def _arange(self, node: fx.node.Node) -> relax.Var:
dtype = TorchFXImporter._convert_data_type(torch.get_default_dtype())
else:
dtype = "int64"

return relax.const(np.arange(*start_end_step, dtype=dtype))
start_end_step = [
self.env[x] if isinstance(x, torch.fx.node.Node) else x for x in start_end_step
]
return relax.op.arange(*start_end_step, dtype=dtype)

def _empty(self, node: fx.node.Node) -> relax.Var:
dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env)
Expand Down
56 changes: 54 additions & 2 deletions python/tvm/relax/op/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
"""Creation operators."""
from typing import Optional, Tuple, Union

from tvm import DataType
from tvm import DataType, DataTypeCode
from tvm.ir.expr import PrimExpr

from . import _ffi_api
from ..expr import Expr, ShapeExpr
from ..expr import Expr, PrimValue, ShapeExpr

PrimExprLike = Union[int, PrimExpr]

Expand Down Expand Up @@ -163,6 +163,58 @@ def zeros_like(x: Expr, dtype: Optional[Union[str, DataType]] = None) -> Expr:
return _ffi_api.zeros_like(x, dtype) # type: ignore


def arange(
    start: Union[PrimExprLike, PrimValue],
    end: Optional[Union[PrimExprLike, PrimValue]] = None,
    step: Union[PrimExprLike, PrimValue] = 1,
    dtype: Optional[Union[str, DataType]] = None,
) -> Expr:
    """Construct a tensor with evenly spaced elements.

    Parameters
    ----------
    start : Union[PrimExprLike, PrimValue]
        The start of the interval.

    end : Optional[Union[PrimExprLike, PrimValue]]
        The end of the interval. If not given, it will be set to start,
        and start will be set to 0.

    step : Union[PrimExprLike, PrimValue]
        The step size.

    dtype : Optional[Union[str, DataType]]
        The data type of the created tensor.

    Returns
    -------
    result : relax.Expr
        The result tensor.
    """
    # Single-argument form: arange(n) is shorthand for arange(0, n).
    if end is None:
        start, end = 0, start

    def _is_integer_like(value) -> bool:
        # Python ints and integer-typed PrimExprs count as integral;
        # everything else (floats, float-typed exprs) does not.
        if isinstance(value, int):
            return True
        if isinstance(value, PrimValue):
            value = value.value
        if not isinstance(value, PrimExpr):
            return False
        return DataType(value.dtype).type_code == DataTypeCode.INT  # type: ignore

    # Default dtype: int64 when every bound is integral, float32 otherwise.
    if dtype is None:
        all_int = all(_is_integer_like(v) for v in (start, end, step))
        dtype = "int64" if all_int else "float32"

    def _as_prim_value(value):
        return value if isinstance(value, PrimValue) else PrimValue(value)

    return _ffi_api.arange(  # type: ignore
        _as_prim_value(start), _as_prim_value(end), _as_prim_value(step), dtype
    )


def tril(x: Expr, k: int = 0) -> Expr:
"""Return the lower triangular part of a matrix or a batch of matrices.

Expand Down
25 changes: 22 additions & 3 deletions python/tvm/relax/transform/legalize_ops/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@
"""Default legalization function for creation operators."""
from typing import Optional

from tvm import topi, tir
import numpy as np

from tvm import tir, topi

from ...block_builder import BlockBuilder
from ...expr import Call, Expr
from .common import LegalizeFunc, register_legalize, _try_convert_to_scalar_const
from ...expr import Call, Expr, PrimValue, const
from .common import LegalizeFunc, _try_convert_to_scalar_const, register_legalize


def _full(is_like: bool, fill_value: Optional[float], primfunc_name: str) -> LegalizeFunc:
Expand Down Expand Up @@ -64,3 +67,19 @@ def tril_triu_call_te(bb: BlockBuilder, call: Call) -> Expr:
register_legalize("relax.zeros_like", _full(is_like=True, fill_value=0.0, primfunc_name="zeros"))
register_legalize("relax.tril", _tril_triu(is_upper=False, primfunc_name="tril"))
register_legalize("relax.triu", _tril_triu(is_upper=True, primfunc_name="triu"))


@register_legalize("relax.arange")
def _arange(bb: BlockBuilder, call: Call) -> Expr:
    """Legalize relax.arange.

    When start/end/step are all immediate constants (IntImm/FloatImm), the op
    is folded at compile time into a numpy-backed constant tensor; otherwise
    it is lowered to a topi.arange call_te with the (possibly symbolic)
    PrimExpr bounds.
    """
    assert len(call.args) == 3
    # Generator expressions instead of all([...]) — no throwaway lists.
    assert all(isinstance(x, PrimValue) for x in call.args)
    start, end, step = (x.value for x in call.args)
    dtype = call.attrs.dtype

    def is_const_scalar(x: PrimValue):
        # Only immediates can be constant-folded; tir.Var bounds cannot.
        return isinstance(x.value, (tir.IntImm, tir.FloatImm))

    if all(is_const_scalar(x) for x in call.args):
        # .value extracts the plain python number from the IntImm/FloatImm.
        return const(np.arange(start.value, end.value, step.value, dtype=dtype), dtype=dtype)
    else:
        return bb.call_te(topi.arange, start, end, step, dtype)
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
atan,
atanh,
add,
arange,
argmax,
argmin,
assert_op,
Expand Down Expand Up @@ -543,6 +544,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"atan",
"atanh",
"add",
"arange",
"arg",
"argmax",
"argmin",
Expand Down
54 changes: 54 additions & 0 deletions src/relax/op/tensor/create.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@

#include "create.h"

#include <tvm/arith/analyzer.h>

#include <string>
#include <utility>

namespace tvm {
Expand Down Expand Up @@ -219,6 +222,57 @@ TVM_REGISTER_OP("relax.zeros_like")
.add_argument("x", "Tensor", "The input tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoOnesLikeZerosLike);

/* relax.arange */
// Build a relax.arange call node. Only constructs the Call; the output
// shape/dtype are derived later by InferStructInfoArange through the
// registered FInferStructInfo attribute.
Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype) {
  ObjectPtr<InitAttrs> attrs = make_object<InitAttrs>();
  attrs->dtype = dtype;
  static const Op& op = Op::Get("relax.arange");
  return Call(op, {std::move(start), std::move(stop), std::move(step)}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relax.op.arange").set_body_typed(arange);

// Infer the struct info of relax.arange: a 1-D tensor whose length is
// computed symbolically from the three PrimValue arguments.
StructInfo InferStructInfoArange(const Call& call, const BlockBuilder& ctx) {
  if (call->args.size() != 3) {
    ctx->ReportFatal(
        Diagnostic::Error(call)
        << "Arange should have 3 arguments, which are `start`, `end` and `step`, but got "
        << call->args.size() << " arguments");
  }
  // TODO(Siyuan): Support indirect prim_values
  // Extract the PrimExpr payload of a PrimValue argument; fatal otherwise.
  auto get_prim_value = [&ctx](const Expr& expr, std::string key) {
    if (!expr->IsInstance<PrimValueNode>()) {
      ctx->ReportFatal(Diagnostic::Error(expr)
                       << "Arange expects the `" << key << "` to be a PrimValue, but got "
                       << expr->GetTypeKey());
    }
    return expr.as<PrimValueNode>()->value;
  };
  PrimExpr start = get_prim_value(call->args[0], "start");
  PrimExpr end = get_prim_value(call->args[1], "end");
  PrimExpr step = get_prim_value(call->args[2], "step");
  DataType dtype = call->attrs.as<InitAttrs>()->dtype;
  PrimExpr num_elem;
  if (start.dtype().is_int() && end.dtype().is_int() && step.dtype().is_int()) {
    // Integer fast path: ceil((end - start) / step) as an exact integer
    // ceiling-division, avoiding any float rounding.
    // NOTE(review): (end - start + step - 1) / step is only a ceiling
    // division for step > 0 — confirm negative steps are rejected upstream.
    num_elem = tvm::floordiv((end - start + step - 1), step);
  } else {
    // Float path: length = ceil((end - start) / step), cast back to int64.
    // NOTE(review): the float32 cast may lose precision for very large
    // int64 ranges — confirm acceptable.
    num_elem = tvm::cast(tvm::DataType::Int(64),
                         tvm::ceil(tvm::cast(tvm::DataType::Float(32), end - start) / step));
  }
  // Fold constant bounds so static shapes stay static (e.g. (10,) not
  // an unevaluated expression).
  arith::Analyzer analyzer;
  num_elem = analyzer.Simplify(num_elem);
  return TensorStructInfo(ShapeExpr({num_elem}), dtype);
}

// Register relax.arange: three PrimValue inputs, dtype carried in InitAttrs,
// struct info derived by InferStructInfoArange.
TVM_REGISTER_OP("relax.arange")
    .set_attrs_type<InitAttrs>()
    .set_num_inputs(3)
    .add_argument("start", "PrimValue", "The starting value for the set of points.")
    .add_argument("end", "PrimValue", "The ending value for the set of points.")
    .add_argument("step", "PrimValue", "The gap between each pair of adjacent points.")
    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoArange)
    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);

/* relax.tril & relax.triu */
TVM_REGISTER_NODE_TYPE(TriluAttrs);

Expand Down
3 changes: 3 additions & 0 deletions src/relax/op/tensor/create.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ Expr zeros(Expr shape, DataType dtype);
/*! \brief Construct a tensor with all zeros, with shape of the input tensor shape. */
Expr zeros_like(Expr x, DataType dtype);

/*! \brief Construct a tensor with evenly spaced elements. */
Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype);

/*! \brief Return the lower triangular part of a matrix or a batch of matrices. */
Expr tril(Expr x, int k);

Expand Down
73 changes: 71 additions & 2 deletions tests/python/relax/test_op_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
# specific language governing permissions and limitations
# under the License.
import pytest

import tvm
import tvm.testing
from tvm import relax, tir
from tvm import TVMError
from tvm import TVMError, relax, tir
from tvm.ir import Op
from tvm.script import relax as R
from tvm.script import tir as T


def test_op_correctness():
Expand All @@ -32,6 +33,7 @@ def test_op_correctness():
assert relax.op.ones_like(x).op == Op.get("relax.ones_like")
assert relax.op.zeros((2, 3), "float32").op == Op.get("relax.zeros")
assert relax.op.zeros_like(x).op == Op.get("relax.zeros_like")
assert relax.op.arange(3, 4, 1, "float32").op == Op.get("relax.arange")
assert relax.op.tril(x).op == Op.get("relax.tril")
assert relax.op.triu(x).op == Op.get("relax.triu")

Expand Down Expand Up @@ -534,6 +536,73 @@ def test_ones_like_zeros_like_infer_struct_info_wrong_input_type():
bb.normalize(relax.op.zeros_like(x1))


def test_arange_infer_struct_info():
    """arange with constant bounds infers a static 1-D shape and dtype."""
    bb = relax.BlockBuilder()

    # (arange args, expected length, expected dtype)
    cases = [
        ((10,), 10, "int64"),
        ((1, 10), 9, "int64"),
        ((0, 10, 2), 5, "int64"),
        ((1, 10, 2), 5, "int64"),
        ((10.0,), 10, "float32"),
        ((1.0, 10), 9, "float32"),
        ((0, 20, 2.5), 8, "float32"),
        ((1, 10, 2.3), 4, "float32"),
    ]
    for args, length, dtype in cases:
        _check_inference(bb, relax.op.arange(*args), relax.TensorStructInfo((length,), dtype))


def test_arange_infer_struct_info_shape_var():
    """arange with symbolic (tir.Var) bounds keeps a symbolic output length.

    Integer bounds: length is the exact ceiling division
    (stop + step - start - 1) // step. Float bounds: length is
    ceil((stop - start) / step) cast to int64. The expected expressions
    below must match the inferred ones structurally, post-simplification.
    """
    bb = relax.BlockBuilder()
    start = tir.Var("start", "int64")
    stop = tir.Var("stop", "int64")
    step = tir.Var("step", "int64")

    _check_inference(bb, relax.op.arange(stop), relax.TensorStructInfo((stop,), "int64"))
    _check_inference(bb, relax.op.arange(1, stop), relax.TensorStructInfo((stop - 1,), "int64"))
    _check_inference(
        bb, relax.op.arange(start, stop), relax.TensorStructInfo((stop - start,), "int64")
    )
    _check_inference(
        bb,
        relax.op.arange(start, stop, 2),
        relax.TensorStructInfo(((stop + 1 - start) // 2,), "int64"),
    )
    _check_inference(
        bb,
        relax.op.arange(start, stop, step),
        relax.TensorStructInfo(((stop + step - start - 1) // step,), "int64"),
    )

    # Float bounds: lengths involve T.ceil and an int64 cast.
    start = tir.Var("start", "float32")
    stop = tir.Var("stop", "float32")
    step = tir.Var("step", "float32")

    _check_inference(
        bb,
        relax.op.arange(stop),
        relax.TensorStructInfo((T.cast(T.ceil(stop), "int64"),), "float32"),
    )
    _check_inference(
        bb,
        relax.op.arange(1, stop),
        relax.TensorStructInfo((T.cast(T.ceil(stop - 1.0), "int64"),), "float32"),
    )
    _check_inference(
        bb,
        relax.op.arange(start, stop),
        relax.TensorStructInfo((T.cast(T.ceil(stop - start), "int64"),), "float32"),
    )
    # Constant step 2 simplifies the division to a multiply by 0.5.
    _check_inference(
        bb,
        relax.op.arange(start, stop, 2),
        relax.TensorStructInfo((T.cast(T.ceil((stop - start) * 0.5), "int64"),), "float32"),
    )
    _check_inference(
        bb,
        relax.op.arange(start, stop, step),
        relax.TensorStructInfo((T.cast(T.ceil((stop - start) / step), "int64"),), "float32"),
    )


def test_tril_triu_infer_struct_info():
bb = relax.BlockBuilder()
x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
Expand Down
Loading