diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 8d112b91d642..e66dff8356c8 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -3280,10 +3280,27 @@ def convert_batch_to_space_nd(self, op): input_tensor_idx = input_tensor.tensor_idx in_expr = self.get_expr(input_tensor_idx) - block_shape = list(self.get_tensor_value(input_tensors[1])) - crops = self.get_tensor_value(input_tensors[2]).tolist() + block_shape = to_int_list(self.get_tensor_value(input_tensors[1])) + crops = self.get_tensor_value(input_tensors[2]) + crop_begin = to_int_list(crops[:, 0]) + crop_end = to_int_list(crops[:, 1]) - out = relax.op.nn.batch_to_space_nd(in_expr, block_shape, crops) + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) == 1, "output tensors length should be 1" + output_tensor = output_tensors[0] + output_shape = to_int_list(self.get_tensor_shape(output_tensor)) + output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type()) + + out = relax.op.call_dps_packed( + "topi.nn.batch_to_space_nd", + ( + in_expr, + relax.ShapeExpr(block_shape), + relax.ShapeExpr(crop_begin), + relax.ShapeExpr(crop_end), + ), + out_sinfo=relax.TensorStructInfo(output_shape, output_dtype), + ) return out @@ -3389,10 +3406,28 @@ def convert_space_to_batch_nd(self, op): input_tensor_idx = input_tensor.tensor_idx in_expr = self.get_expr(input_tensor_idx) - block_shape = list(self.get_tensor_value(input_tensors[1])) - paddings = self.get_tensor_value(input_tensors[2]).tolist() + block_shape = to_int_list(self.get_tensor_value(input_tensors[1])) + paddings = self.get_tensor_value(input_tensors[2]) + pad_before = to_int_list(paddings[:, 0]) + pad_after = to_int_list(paddings[:, 1]) - out = relax.op.nn.space_to_batch_nd(in_expr, block_shape, paddings) + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) == 1, "output tensors length should be 1" + output_tensor = output_tensors[0] + output_shape = to_int_list(self.get_tensor_shape(output_tensor)) + output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type()) + + out = relax.op.call_dps_packed( + "topi.nn.space_to_batch_nd", + ( + in_expr, + relax.ShapeExpr(block_shape), + relax.ShapeExpr(pad_before), + relax.ShapeExpr(pad_after), + 0.0, + ), + out_sinfo=relax.TensorStructInfo(output_shape, output_dtype), + ) return out diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index a2d2612232c0..69e9b290fd32 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -3099,6 +3099,74 @@ def main( verify(SpaceToDepth, Expected) +@pytest.mark.parametrize( + "input_shape, block_shape, paddings, expected_out_shape", + [ + ((1, 2, 2, 1), [2, 2], [[0, 0], [0, 0]], (4, 1, 1, 1)), + ((1, 2, 3, 1), [2, 2], [[0, 0], [1, 0]], (4, 1, 2, 1)), + ], +) +def test_space_to_batch_nd(input_shape, block_shape, paddings, expected_out_shape): + """SPACE_TO_BATCH_ND imports to Relax and preserves expected output shape.""" + + class SpaceToBatchND(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=input_shape, dtype=tf.float32)]) + def func(self, x): + return tf.space_to_batch_nd( + x, + tf.constant(block_shape, dtype=tf.int32), + tf.constant(paddings, dtype=tf.int32), + ) + + cf = SpaceToBatchND().func.get_concrete_function() + mod = _get_mod_from_cfunc(cf) + ir = mod.script() + + assert "space_to_batch_nd" in ir + assert len(mod["main"].params) == 1 + tvm.ir.assert_structural_equal( + mod["main"].ret_struct_info, + relax.TensorStructInfo(expected_out_shape, "float32"), + ) + + if "CI_ENV_NIGHTLY" in os.environ: + verify(SpaceToBatchND) + + +@pytest.mark.parametrize( + "input_shape, block_shape, crops, expected_out_shape", + [ + ((4, 1, 1, 1), [2, 2], [[0, 0], [0, 0]], (1, 2, 2, 1)), + ((4, 1, 2, 1), [2, 2], [[0, 0], [1, 0]], (1, 2, 3, 1)), + ], +) +def test_batch_to_space_nd(input_shape, block_shape, crops, expected_out_shape): + """BATCH_TO_SPACE_ND imports to Relax and preserves expected output shape.""" + + class BatchToSpaceND(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=input_shape, dtype=tf.float32)]) + def func(self, x): + return tf.raw_ops.BatchToSpaceND( + input=x, + block_shape=tf.constant(block_shape, dtype=tf.int32), + crops=tf.constant(crops, dtype=tf.int32), + ) + + cf = BatchToSpaceND().func.get_concrete_function() + mod = _get_mod_from_cfunc(cf) + ir = mod.script() + + assert "batch_to_space_nd" in ir + assert len(mod["main"].params) == 1 + tvm.ir.assert_structural_equal( + mod["main"].ret_struct_info, + relax.TensorStructInfo(expected_out_shape, "float32"), + ) + + if "CI_ENV_NIGHTLY" in os.environ: + verify(BatchToSpaceND) + + def test_leaky_relu(): class LeakyReLU(tf.Module): @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)])