From 9e525b43691d76e6ee880432e33c2d086522d04b Mon Sep 17 00:00:00 2001
From: Zhi Chen
Date: Thu, 28 Nov 2019 00:24:44 +0000
Subject: [PATCH 1/4] [relay][op] Add shape func to tile

---
 python/tvm/relay/op/_transform.py | 32 +++++++++++++++++++++
 src/relay/op/tensor/transform.cc  | 47 +++++++++++++++++++++----------
 tests/python/relay/test_any.py    | 20 +++++++++++++
 3 files changed, 84 insertions(+), 15 deletions(-)

diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 13f41fc87001..4e724da366da 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -495,3 +495,35 @@ def reshape_like_shape_func(attrs, inputs, _):
     Shape function for reshape_like op.
     """
     return [_reshape_like_shape_func(inputs[1])]
+
+@script
+def _tile_shape_func(data, reps, ndim, tndim, rndim):
+    out = output_tensor((tndim,), "int64")
+
+    if ndim == rndim:
+        for i in const_range(tndim):
+            out[i] = data[i] * int64(reps[i])
+    elif ndim > rndim:
+        ngap = ndim - rndim
+        for i in const_range(ndim):
+            if i < ngap:
+                out[i] = data[i]
+            else:
+                out[i] = data[i] * int64(reps[i - ngap])
+    else:
+        rgap = rndim - ndim
+        for i in const_range(rndim):
+            if i < rgap:
+                out[i] = int64(reps[i])
+            else:
+                out[i] = int64(reps[i]) * data[i - rgap]
+    return out
+
+@_reg.register_shape_func("tile", False)
+def tile_shape_func(attrs, inputs, _):
+    reps = get_const_tuple(attrs.reps)
+    ndim = inputs[0].shape[0].value
+    rndim = len(reps)
+    tndim = ndim if ndim > rndim else rndim
+    return [_tile_shape_func(inputs[0], convert(reps), convert(ndim),
+                             convert(tndim), convert(rndim))]
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 203a0411d3c4..f4f22e7dc8e5 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1393,28 +1393,45 @@ bool TileRel(const Array<Type>& types,
   reps_shape.reserve(tndim);
   if (ndim == rndim) {
     for (size_t i = 0; i < tndim; ++i) {
-      data_shape.emplace_back(data->shape[i]);
-      reps_shape.emplace_back(reps[i]);
+      data_shape.emplace_back(data->shape[i]);
+      reps_shape.emplace_back(reps[i]);
     }
   } else if (ndim > rndim) {
-    for (size_t i = 0; i < ndim; ++i)
-      data_shape.emplace_back(data->shape[i]);
-    for (size_t i = 0; i < (ndim - rndim); ++i)
-      reps_shape.emplace_back(1);
-    for (size_t i = 0; i < rndim; ++i)
-      reps_shape.emplace_back(reps[i]);
+    for (size_t i = 0; i < ndim; ++i) {
+      data_shape.emplace_back(data->shape[i]);
+    }
+    for (size_t i = 0; i < (ndim - rndim); ++i) {
+      reps_shape.emplace_back(1);
+    }
+    for (size_t i = 0; i < rndim; ++i) {
+      reps_shape.emplace_back(reps[i]);
+    }
   } else {
-    for (size_t i = 0; i < rndim; ++i)
-      reps_shape.emplace_back(reps[i]);
-    for (size_t i = 0; i < (rndim - ndim); ++i)
-      data_shape.emplace_back(1);
-    for (size_t i = 0; i < ndim; ++i)
-      data_shape.emplace_back(data->shape[i]);
+    for (size_t i = 0; i < rndim; ++i) {
+      reps_shape.emplace_back(reps[i]);
+    }
+    for (size_t i = 0; i < (rndim - ndim); ++i) {
+      data_shape.emplace_back(1);
+    }
+    for (size_t i = 0; i < ndim; ++i) {
+      data_shape.emplace_back(data->shape[i]);
+    }
   }
   std::vector<IndexExpr> oshape;
   oshape.reserve(tndim);
+  bool is_dynamic_shape = false;
+  for (size_t i = 0; i < data->shape.size(); ++i) {
+    if (!data->shape[i].as<IntImm>()) {
+      is_dynamic_shape = true;
+      break;
+    }
+  }
   for (size_t i = 0; i < tndim; ++i) {
-    oshape.emplace_back(data_shape[i] * reps_shape[i]);
+    if (is_dynamic_shape) {
+      oshape.emplace_back(Any::make());
+    } else {
+      oshape.emplace_back(data_shape[i] * reps_shape[i]);
+    }
   }
   reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
   return true;
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 75be88cbcb19..8001c8ee6b34 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -166,6 +166,25 @@ def test_any_take():
     verify_any_take(any_dims(2), any_dims(3), None, (4, 5), (2, 3, 4))
     verify_any_take(any_dims(2), any_dims(4), -1, (4, 5), (2, 3, 4, 5))
+
+def verify_any_tile(dshape, reps, np_dshape, np_reps):
+    mod = relay.Module()
+    x = relay.var("x", shape=dshape, dtype="float32")
+    y = relay.tile(x, reps=reps)
+    mod["main"] = relay.Function([x], y)
+    x_data = np.random.uniform(size=np_dshape).astype("float32")
+    ref_res = np.tile(x_data, reps=np_reps)
+
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        res = ex.evaluate()(x_data)
+        tvm.testing.assert_allclose(res.asnumpy(), ref_res, rtol=1e-5)
+
+def test_any_tile():
+    verify_any_tile(any_dims(3), (3, 2, 1), (2, 3, 4), (3, 2, 1))
+    verify_any_tile(any_dims(3), (1, 2), (2, 3, 4), (1, 2))
+    verify_any_tile(any_dims(2), (3, 2, 1), (2, 3), (3, 2, 1))
+    verify_any_tile(any_dims(3), (1,), (2, 3, 4), (1,))
 
 def test_any_shape_of():
     x = relay.var('x', shape=any_dims(2), dtype='float32')
     y = relay.shape_of(x)
@@ -558,6 +577,7 @@ def _body(i, st):
     test_any_concat()
     test_any_reshape()
     test_any_take()
+    test_any_tile()
     test_any_shape_of()
     test_any_reduce()
     test_any_layout_transform()

From 72a65c6f171d4bef00f5a33c2c8d1a4a73433a75 Mon Sep 17 00:00:00 2001
From: Zhi Chen
Date: Thu, 28 Nov 2019 01:54:08 +0000
Subject: [PATCH 2/4] retrigger ci

From 99382369913db5dbfa14b4f951f0540fea6c26b4 Mon Sep 17 00:00:00 2001
From: Zhi Chen
Date: Mon, 2 Dec 2019 22:05:16 +0000
Subject: [PATCH 3/4] check dynamic axes

---
 src/relay/op/tensor/transform.cc | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index f4f22e7dc8e5..944d0dbe9852 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1419,15 +1419,9 @@ bool TileRel(const Array<Type>& types,
   }
   std::vector<IndexExpr> oshape;
   oshape.reserve(tndim);
-  bool is_dynamic_shape = false;
-  for (size_t i = 0; i < data->shape.size(); ++i) {
-    if (!data->shape[i].as<IntImm>()) {
-      is_dynamic_shape = true;
-      break;
-    }
-  }
   for (size_t i = 0; i < tndim; ++i) {
-    if (is_dynamic_shape) {
+    // Save Any if it is dynamic shape
+    if (!data_shape[i].as<IntImm>()) {
       oshape.emplace_back(Any::make());
     } else {
       oshape.emplace_back(data_shape[i] * reps_shape[i]);

From 2597b9ded4de43b11e72162eeafb9431524113e8 Mon Sep 17 00:00:00 2001
From: Zhi Chen
Date: Tue, 3 Dec 2019 01:38:03 +0000
Subject: [PATCH 4/4] retrigger ci