Skip to content

Commit

Permalink
Update MatchShape AST Node (apache#11)
Browse files Browse the repository at this point in the history
* Update MatchShape AST Node.

* Update.

* Update.
  • Loading branch information
ZihengJiang authored and junrushao committed Feb 9, 2023
1 parent f04dbcb commit 1c1e214
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 33 deletions.
15 changes: 10 additions & 5 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,22 +196,26 @@ class Binding : public ObjectRef {
class MatchShape;
class MatchShapeNode : public BindingNode {
public:
Array<PrimExpr> pattern;
Expr value;
Array<PrimExpr> pattern;
Var var;

void VisitAttrs(AttrVisitor* v) {
v->Visit("pattern", &pattern);
v->Visit("value", &value);
v->Visit("pattern", &pattern);
v->Visit("var", &var);
v->Visit("span", &span);
}

bool SEqualReduce(const MatchShapeNode* other, SEqualReducer equal) const {
return equal(pattern, other->pattern) && equal(value, other->value);
return equal(value, other->value) && equal(pattern, other->pattern)
&& equal(var, other->var);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(pattern);
hash_reduce(value);
hash_reduce(pattern);
hash_reduce(var);
}

static constexpr const char* _type_key = "relax.expr.MatchShape";
Expand All @@ -222,7 +226,8 @@ class MatchShapeNode : public BindingNode {

class MatchShape : public Binding {
public:
TVM_DLL explicit MatchShape(Array<PrimExpr> pattern, Expr value, Span span = Span());
TVM_DLL explicit MatchShape(Expr value, Array<PrimExpr> pattern,
Var var, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(MatchShape, Binding, MatchShapeNode);
};

Expand Down
7 changes: 4 additions & 3 deletions python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,12 @@ def __init__(self, span: Span = None) -> None:

@tvm._ffi.register_object("relax.expr.MatchShape")
class MatchShape(Binding):
pattern: List[PrimExpr]
value: Expr
pattern: List[PrimExpr]
var: Var

def __init__(self, pattern: List[PrimExpr], value: Expr, span: Span = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.MatchShape, pattern, value, span)
def __init__(self, value: Expr, pattern: List[PrimExpr], var: Var, span: Span = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.MatchShape, value, pattern, var, span)


@tvm._ffi.register_object("relax.expr.VarBinding")
Expand Down
38 changes: 21 additions & 17 deletions src/relax/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ Var::Var(Id vid, Optional<Expr> shape_annotation, Optional<Type> type_annotation
}

TVM_REGISTER_GLOBAL("relax.Var")
.set_body_typed([](String name_hint, Optional<Expr> shape_annotation,
Optional<Type> type_annotation, Span span) {
return Var(name_hint, shape_annotation, type_annotation, span);
});
.set_body_typed([](String name_hint, Optional<Expr> shape_annotation,
Optional<Type> type_annotation, Span span) {
return Var(name_hint, shape_annotation, type_annotation, span);
});

TVM_REGISTER_NODE_TYPE(DataflowVarNode);

Expand All @@ -83,10 +83,10 @@ DataflowVar::DataflowVar(Id vid, Optional<Expr> shape_annotation, Optional<Type>
}

TVM_REGISTER_GLOBAL("relax.DataflowVar")
.set_body_typed([](String name_hint, Optional<Expr> shape_annotation,
Optional<Type> type_annotation, Span span) {
return DataflowVar(name_hint, shape_annotation, type_annotation, span);
});
.set_body_typed([](String name_hint, Optional<Expr> shape_annotation,
Optional<Type> type_annotation, Span span) {
return DataflowVar(name_hint, shape_annotation, type_annotation, span);
});

Binding::Binding(Span span) {
ObjectPtr<BindingNode> n = make_object<BindingNode>();
Expand All @@ -96,22 +96,25 @@ Binding::Binding(Span span) {

TVM_REGISTER_NODE_TYPE(BindingNode);

TVM_REGISTER_GLOBAL("relax.Binding").set_body_typed([](Span span) { return Binding(span); });
TVM_REGISTER_GLOBAL("relax.Binding").set_body_typed([](Span span) {
return Binding(span);
});

TVM_REGISTER_NODE_TYPE(MatchShapeNode);

MatchShape::MatchShape(Array<PrimExpr> pattern, Expr value, Span span) {
MatchShape::MatchShape(Expr value, Array<PrimExpr> pattern, Var var, Span span) {
ObjectPtr<MatchShapeNode> n = make_object<MatchShapeNode>();
n->pattern = std::move(pattern);
n->value = std::move(value);
n->pattern = std::move(pattern);
n->var = std::move(var);
n->span = span;
data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("relax.MatchShape")
.set_body_typed([](Array<PrimExpr> pattern, Expr value, Span span) {
return MatchShape(pattern, value, span);
});
.set_body_typed([](Expr value, Array<PrimExpr> pattern, Var var, Span span) {
return MatchShape(value, pattern, var, span);
});

TVM_REGISTER_NODE_TYPE(VarBindingNode);

Expand Down Expand Up @@ -182,9 +185,10 @@ Function::Function(runtime::Optional<GlobalVar> name, Array<Var> params, Expr bo
}

TVM_REGISTER_GLOBAL("relax.Function")
.set_body_typed([](runtime::Optional<GlobalVar> name, Array<Var> params, Expr body,
Type ret_type,
Span span) { return Function(name, params, body, ret_type, span); });
.set_body_typed([](runtime::Optional<GlobalVar> name, Array<Var> params,
Expr body, Type ret_type, Span span) {
return Function(name, params, body, ret_type, span);
});

TVM_REGISTER_NODE_TYPE(ExternFuncNode);

Expand Down
36 changes: 28 additions & 8 deletions tests/python/relax/test_expr.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import tvm
from tvm import tir
from tvm import relax as rx
from tvm.ir import TensorType
import numpy as np


Expand All @@ -11,7 +10,7 @@ def test_var() -> None:
assert v0.shape_ is None
assert v0.type_annotation is None
shape_anno = [54, 96]
type_anno = TensorType(shape_anno, "float32")
type_anno = rx.DynTensorType(2, "float32")
v1 = rx.Var("v1", shape_anno, type_anno)
assert v1.name_hint == "v1"
for s0, s1 in zip(v1.shape_, shape_anno):
Expand All @@ -25,7 +24,7 @@ def test_dataflow_var() -> None:
assert v0.shape_ is None
assert v0.type_annotation is None
shape_anno = [54, 96]
type_anno = TensorType(shape_anno, "float16")
type_anno = rx.DynTensorType(2, "float16")
v1 = rx.DataflowVar("v1", shape_anno, type_anno)
assert v1.name_hint == "v1"
for s0, s1 in zip(v1.shape_, shape_anno):
Expand All @@ -35,13 +34,34 @@ def test_dataflow_var() -> None:


def test_match_shape() -> None:
# match_shape([16, 8], [m, n])
m = tir.Var("m", dtype="int32")
n = tir.Var("n", dtype="int32")
shape = rx.const([16, 8], "int32")
b0 = rx.MatchShape([m, n], shape)
var = rx.Var("v0", type_annotation=rx.ShapeType())
b0 = rx.MatchShape(shape, [m, n], var)
assert b0.value == shape
assert b0.pattern[0] == m
assert b0.pattern[1] == n
assert b0.value == shape
assert b0.var is not None
assert b0.var.checked_type_ == rx.ShapeType()

# var1: Tensor[(m, n), "float32"] =
# match_shape(var0: Tensor[_, "float32"], [m, n])
type_anno0 = rx.DynTensorType(-1, "float32")
value = rx.Var("value", type_annotation=type_anno0)

shape_anno = [m, n]
type_anno = rx.DynTensorType(2, "float32")
var = rx.Var("v1", shape_anno, type_anno)
b1 = rx.MatchShape(value, [m, n], var)
assert b1.value == value
assert b1.pattern[0] == m
assert b1.pattern[1] == n
assert b1.var is not None
for s0, s1 in zip(b1.var.shape, [m, n]):
assert s0 == s1
assert b1.var.checked_type_ == rx.DynTensorType(2, "float32")


def test_var_binding() -> None:
Expand All @@ -56,7 +76,7 @@ def test_binding_block() -> None:
m = tir.Var("m", dtype="int32")
n = tir.Var("n", dtype="int32")
shape = rx.const([16, 8], "int32")
b0 = rx.MatchShape([m, n], shape)
b0 = rx.MatchShape(shape, [m, n], rx.Var("v0"))

v0 = rx.Var("v0")
val = rx.const(np.random.rand(24, 56))
Expand All @@ -71,7 +91,7 @@ def test_dataflow_block() -> None:
m = tir.Var("m", dtype="int32")
n = tir.Var("n", dtype="int32")
shape = rx.const([16, 8], "int32")
b0 = rx.MatchShape([m, n], shape)
b0 = rx.MatchShape(shape, [m, n], rx.Var("v0"))

v0 = rx.Var("v0")
val = rx.const(np.random.rand(24, 56))
Expand Down Expand Up @@ -105,7 +125,7 @@ def test_func():
bindings = [rx.VarBinding(x, rx.const(1))]
blocks = [rx.BindingBlock(bindings)]
seqe = rx.SeqExpr(blocks, x)
ret_type = TensorType(None, "float32")
ret_type = rx.DynTensorType(-1, "float32")
func = rx.Function([x], seqe, ret_type, rx.GlobalVar("func"))
assert func.params[0] == x
assert func.body == seqe
Expand Down

0 comments on commit 1c1e214

Please sign in to comment.