diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index a1a407287d20..b48b6ad313f7 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1458,6 +1458,15 @@ def _impl(inputs, attr, params, mod): return ret + def _dyn(): + for d in data_shape: + if not isinstance(d, int): + return True + return False + + if _dyn(): + return _op.strided_slice(inputs[0], begin, end, stride) + def _transform_mask(stride_dim, ellipsis_mask): """Handle mask inputs to create new begin, end, stride and output shape""" m_begin = [0] * data_dim diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index b1c2d8b23373..8dbe49cb30ef 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2146,7 +2146,18 @@ Array StridedSliceCompute(const Attrs& attrs, const Array(); CHECK(param != nullptr); - if (param->begin && param->end && param->strides) { + + bool dyn = false; + for (auto& v : out_type.as()->shape) { + if (const tir::VarNode* var_node = v.as()) { + if (var_node->name_hint == "any_dim") { + dyn = true; + break; + } + } + } + + if (param->begin && param->end && param->strides && !dyn) { Array begin, end, strides; begin = param->begin.value(); end = param->end.value(); diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 182c2d72447a..bd7fc3ece6ad 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -179,7 +179,7 @@ def run_tf_graph(sess, input_data, input_node, output_node): def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, no_gpu=False, opt_level=3, mode='graph_runtime', - cuda_layout="NCHW"): + cuda_layout="NCHW", ignore_in_shape=False): """Generic function to generate and compare tensorflow and TVM output""" def name_without_num(name): return name.split(':')[0] if ":" in name else name @@ -208,7 +208,7 @@ def name_without_num(name): tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target=device, out_names=out_name, num_output=len(out_name), opt_level=opt_level, mode=mode, - cuda_layout=cuda_layout) + cuda_layout=cuda_layout, ignore_in_shape=ignore_in_shape) # since the names from tensorflow and relay runs are not exactly same, # first len(tf_output) will be compared for i in range(len(tf_output)): @@ -1354,19 +1354,30 @@ def test_forward_batch_matmul(): def _test_stridedslice(ip_shape, begin, end, stride, dtype, begin_mask=0, end_mask=0, new_axis_mask=0, - shrink_axis_mask=0, ellipsis_mask=0): + shrink_axis_mask=0, ellipsis_mask=0, dynamic_input=False): """ One iteration of a Stridedslice """ + var_shape = ip_shape + if dynamic_input: + # Generate a dynamic shape + assert isinstance(var_shape, tuple) + var_shape = list(var_shape) + var_shape[0] = None + tf.reset_default_graph() with tf.Graph().as_default(): - in_data = tf.placeholder(dtype, ip_shape, name="in_data") + in_data = tf.placeholder(dtype, var_shape, name="in_data") tf.strided_slice(in_data, begin, end, stride, begin_mask=begin_mask, end_mask=end_mask, new_axis_mask=new_axis_mask, shrink_axis_mask=shrink_axis_mask, ellipsis_mask=ellipsis_mask, name="strided_slice") np_data = np.random.uniform(size=ip_shape).astype(dtype) - compare_tf_with_tvm(np_data, 'in_data:0', 'strided_slice:0') + if dynamic_input: + compare_tf_with_tvm(np_data, 'in_data:0', 'strided_slice:0', mode="vm", + ignore_in_shape=True) + else: + compare_tf_with_tvm(np_data, 'in_data:0', 'strided_slice:0') def test_forward_stridedslice(): @@ -1421,6 +1432,8 @@ def test_forward_stridedslice(): _test_stridedslice((3, 4, 5, 4, 5, 6), [1, 2, 0, -3], [4, 5, 3, 3], [2, 2, 1, 1], 'float32', shrink_axis_mask=8, new_axis_mask=1, ellipsis_mask=2, begin_mask=5, end_mask=8) + _test_stridedslice((2, 1), [0, 0], [1, 1], [1, 1], 'float32', dynamic_input=True) + ####################################################################### # FloorDiv, RealDiv diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 6d940a563566..47a1c369cdec 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -657,7 +657,6 @@ def verify_any_strided_slice(data_shape, begin_shape, end_shape, strides_shape, mod = tvm.IRModule() data = relay.var('data', shape=data_shape, dtype='float32') if const_attrs: - data = relay.var('data', shape=data_np_shape, dtype='float32') begin = relay.const(np_begin) end = relay.const(np_end) strides = relay.const(np_strides)