[TF] Fix some shape mismatches between TF and Relay #6166

Merged: 1 commit, Jul 29, 2020
src/relay/op/tensor/transform.cc (0 additions, 3 deletions)
```diff
@@ -2740,9 +2740,6 @@ bool GatherNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   Array<IndexExpr> oshape;
   for (size_t i = 1; i < kdim + 1; ++i) oshape.push_back(indices->shape[i]);
   for (size_t i = mdim->value; i < ndim; ++i) oshape.push_back(data->shape[i]);
-  if (oshape.size() == 0) {
-    oshape.push_back(tir::make_const(DataType::Int(32), 1));
-  }
   reporter->Assign(types[2], TensorType(oshape, data->dtype));
   return true;
 }
```
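With the padding removed, `GatherNDRel` infers a rank-0 output when the indices consume every data dimension, matching TensorFlow's `gather_nd` instead of forcing the result to shape `(1,)`. A minimal NumPy sketch of the intended semantics (values are illustrative):

```python
import numpy as np

data = np.arange(6, dtype="float32").reshape(2, 3)
index = (1, 2)  # one index per data dimension selects a single element

# tf.gather_nd(data, [1, 2]) returns a rank-0 tensor; after this fix
# Relay's type relation infers the same scalar shape instead of (1,).
result = data[index]
assert result.shape == ()  # scalar, not (1,)
assert result == 5.0
```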
src/relay/op/tensor/unary.cc (1 addition, 1 deletion)
```diff
@@ -462,7 +462,7 @@ bool NdarraySizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
   CHECK(tt != nullptr);
   const auto* param = attrs.as<NdarraySizeAttrs>();
   CHECK(param != nullptr);
-  reporter->Assign(types[1], TensorType({1}, param->dtype));
+  reporter->Assign(types[1], TensorType({}, param->dtype));
   return true;
 }
```
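`ndarray_size` mirrors `tf.size`, which yields a scalar, so the type relation now assigns a rank-0 `TensorType`. A sketch of how the inferred type could be checked from Python (the scaffolding here is illustrative, not from this PR):

```python
import tvm
from tvm import relay

x = relay.var("x", shape=(2, 3), dtype="float32")
mod = tvm.IRModule.from_expr(
    relay.Function([x], relay.ndarray_size(x, dtype="int32")))
mod = relay.transform.InferType()(mod)

# Expected after this fix: a rank-0 int32 tensor, matching tf.size,
# rather than the former 1-element vector.
print(mod["main"].ret_type)
```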

src/relay/transforms/fold_constant.cc (1 addition, 1 deletion)
```diff
@@ -288,7 +288,7 @@ class ConstantFolder : public ExprMutator {
     ctx.device_id = 0;
     runtime::NDArray value;
     DLDataType cdtype = DataType::Int(32);
-    value = runtime::NDArray::Empty({1}, cdtype, ctx);
+    value = runtime::NDArray::Empty({}, cdtype, ctx);
     int32_t* data = static_cast<int32_t*>(value->data);
     if (ishape.size() == 0) {
       *data = 0;
```
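The constant folder materializes the evaluated size directly into an `NDArray`; allocating it with an empty shape makes the folded constant rank-0 as well, consistent with the new type relation. A sketch of observing this from Python, assuming `FoldConstant` folds `ndarray_size` over a statically shaped input (as in the test updated below):

```python
import tvm
from tvm import relay

x = relay.var("x", shape=(2, 3), dtype="float32")
func = relay.Function([x], relay.ndarray_size(x, dtype="int32"))
mod = relay.transform.FoldConstant()(tvm.IRModule.from_expr(func))

# ndarray_size over a static shape folds to a scalar constant.
folded = mod["main"].body
assert isinstance(folded, relay.Constant)
assert folded.data.shape == ()  # rank-0 constant holding the value 6
```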
tests/python/frontend/tensorflow/test_forward.py (3 additions, 1 deletion)
```diff
@@ -73,7 +73,7 @@ def convert_to_list(x):
 
 def vmobj_to_list(o):
     if isinstance(o, tvm.nd.NDArray):
-        return [o.asnumpy().tolist()]
+        return [o.asnumpy()]
     elif isinstance(o, tvm.runtime.container.ADT):
         result = []
         for f in o:
@@ -211,6 +211,8 @@ def name_without_num(name):
     # since the names from tensorflow and relay runs are not exactly same,
     # first len(tf_output) will be compared
     for i in range(len(tf_output)):
+        if not isinstance(tf_output[i], np.ndarray):
+            assert len(tvm_output[i].shape) == 0
         tvm.testing.assert_allclose(
             tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
```
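Keeping `asnumpy()` results as arrays instead of flattening them with `tolist()` preserves rank information through to the comparison, and the new assertion ties the two sides together: when TensorFlow hands back a plain Python scalar, the TVM output must be rank-0. A standalone illustration of that invariant (values are made up):

```python
import numpy as np

tf_out = 6.0                        # TF can return a plain Python scalar
tvm_out = np.array(6.0, "float32")  # rank-0 ndarray from the TVM side

# A Python scalar from TF must pair with a rank-0 TVM tensor.
if not isinstance(tf_out, np.ndarray):
    assert len(tvm_out.shape) == 0

np.testing.assert_allclose(tf_out, tvm_out, atol=1e-5, rtol=1e-5)
```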

tests/python/relay/test_pass_fold_constant.py (1 addition, 1 deletion)
```diff
@@ -175,7 +175,7 @@ def before(dtype):
 def expected(dtype):
     x = relay.var("x", shape=c_shape, dtype="float32")
     y = relay.var("y", shape=c_shape, dtype="float32")
-    z = relay.const([np.size(np.zeros(c_shape))], dtype=dtype)
+    z = relay.const(np.size(np.zeros(c_shape)), dtype=dtype)
     func = relay.Function([x, y], z)
     return func
```
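`relay.const` builds a rank-0 constant from a bare NumPy scalar, while wrapping the value in a list produced shape `(1,)`; dropping the brackets makes the expected IR match the scalar that `ndarray_size` now folds to. A quick sketch of the distinction:

```python
import numpy as np
from tvm import relay

n = np.size(np.zeros((2, 3)))  # plain Python int, value 6

scalar_const = relay.const(n, dtype="int32")    # what the test now expects
vector_const = relay.const([n], dtype="int32")  # the old padded form

print(scalar_const.data.shape)  # ()
print(vector_const.data.shape)  # (1,)
```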

topi/include/topi/transform.h (1 addition, 4 deletions)
```diff
@@ -1126,9 +1126,6 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, std::string n
   for (size_t i = indices_dim0; i < ndim_d; ++i) {
     out_shape.push_back(data->shape[i]);
   }
-  if (out_shape.size() == 0) {
-    out_shape.push_back(make_const(DataType::Int(32), 1));
-  }
   return compute(
       out_shape,
       [&](const Array<Var>& out_index) {
@@ -1401,7 +1398,7 @@ inline Tensor ndarray_size(const Tensor& src, const DataType& dtype,
                            const std::string& name = "ndarray_size",
                            const std::string& tag = kInjective) {
   int ndim = static_cast<int>(src->shape.size());
-  Array<PrimExpr> out_ndarray_size = {1};
+  Array<PrimExpr> out_ndarray_size = {};
   return compute(
       out_ndarray_size,
       [&](const Array<Var>& indices) {
```
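Both TOPI kernels rely on `compute` accepting an empty output-shape array to produce a rank-0 tensor. A rough Python TE analogue of the `ndarray_size` case (a sketch; the name `num_elements` is made up):

```python
import tvm
from tvm import te

src = te.placeholder((2, 3), name="src", dtype="float32")

# compute over an empty shape yields a rank-0 tensor, as the patched
# topi::ndarray_size now does; the body multiplies out src's extents.
num_elements = te.compute((), lambda *i: src.shape[0] * src.shape[1],
                          name="num_elements")
print(num_elements.shape)  # []
```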
topi/tests/python/test_topi_transform.py (1 addition, 1 deletion)
```diff
@@ -1029,7 +1029,7 @@ def check_device(device):
         print("Skip because %s is not enabled" % device)
         return
     tvm_input = tvm.nd.array(input, ctx=ctx)
-    tvm_output = tvm.nd.empty((1,), ctx=ctx, dtype=B.dtype)
+    tvm_output = tvm.nd.empty((), ctx=ctx, dtype=B.dtype)
     print("Running on target: %s" % device)
     with tvm.target.create(device):
         s = topi.testing.get_injective_schedule(device)(B)
```
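`tvm.nd.empty(())` allocates a rank-0 destination buffer, which the test needs now that the kernel's output is scalar; keeping `(1,)` would cause a shape mismatch when the result is written back. A minimal sketch of working with such a buffer:

```python
import numpy as np
import tvm

ctx = tvm.cpu(0)
out = tvm.nd.empty((), ctx=ctx, dtype="int32")  # rank-0 buffer
out.copyfrom(np.array(6, dtype="int32"))        # scalar payload

assert out.shape == ()
assert int(out.asnumpy()) == 6
```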