Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY] Allow StructuralEqual/Hash via Var.vid #6424

Merged
merged 4 commits into from Sep 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/tvm/relay/base.h
Expand Up @@ -97,7 +97,15 @@ class IdNode : public Object {

void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name_hint", &name_hint); }

bool SEqualReduce(const IdNode* other, SEqualReducer equal) const {
return equal.FreeVarEqualImpl(this, other);
}

void SHashReduce(SHashReducer hash_reduce) const { hash_reduce.FreeVarHashImpl(this); }

static constexpr const char* _type_key = "relay.Id";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(IdNode, Object);
};

Expand Down
6 changes: 4 additions & 2 deletions include/tvm/relay/expr.h
Expand Up @@ -182,12 +182,14 @@ class VarNode : public ExprNode {
}

bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
return equal(type_annotation, other->type_annotation) && equal.FreeVarEqualImpl(this, other);
equal->MarkGraphNode();
return equal(type_annotation, other->type_annotation) && equal(vid, other->vid);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(type_annotation);
hash_reduce.FreeVarHashImpl(this);
hash_reduce(vid);
}

static constexpr const char* _type_key = "relay.Var";
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/packed_func.h
Expand Up @@ -49,7 +49,7 @@
// Always inline macro only use in template
// expansion cases where we know inline is important.
#ifdef _MSC_VER
#define TVM_ALWAYS_INLINE __forceinline inline
#define TVM_ALWAYS_INLINE __forceinline
#else
#define TVM_ALWAYS_INLINE inline __attribute__((always_inline))
#endif
Expand Down
12 changes: 12 additions & 0 deletions tests/python/relay/test_ir_structural_equal_hash.py
Expand Up @@ -705,7 +705,19 @@ def test_fn_attribute():
assert not consistent_equal(add_fn, add_1_fn)


def test_fn_vid_map():
def get_fn(with_vid):
x = relay.var("x", shape=(10,), dtype="float32")
f = relay.Function([x], x).with_attr(
"dict", {x.vid: 1} if with_vid else {x : 1})
return f

assert consistent_equal(get_fn(True), get_fn(True))
assert consistent_equal(get_fn(False), get_fn(False))


if __name__ == "__main__":
test_fn_vid_map()
test_tensor_type_sequal()
test_incomplete_type_sequal()
test_constant_sequal()
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/test_op_level5.py
Expand Up @@ -112,7 +112,7 @@ def verify_resize(dshape, scale, method, layout):
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4, atol=1e-6)
for method in ["trilinear", "nearest_neighbor"]:
for layout in ["NDHWC", "NCDHW"]:
verify_resize((1, 4, 4, 4, 4), 2, method, layout)
Expand Down