Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions python/tvm/relax/frontend/tflite/tflite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3280,10 +3280,27 @@ def convert_batch_to_space_nd(self, op):
input_tensor_idx = input_tensor.tensor_idx
in_expr = self.get_expr(input_tensor_idx)

block_shape = list(self.get_tensor_value(input_tensors[1]))
crops = self.get_tensor_value(input_tensors[2]).tolist()
block_shape = to_int_list(self.get_tensor_value(input_tensors[1]))
crops = self.get_tensor_value(input_tensors[2])
crop_begin = to_int_list(crops[:, 0])
crop_end = to_int_list(crops[:, 1])

out = relax.op.nn.batch_to_space_nd(in_expr, block_shape, crops)
output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"
output_tensor = output_tensors[0]
output_shape = to_int_list(self.get_tensor_shape(output_tensor))
output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type())

out = relax.op.call_dps_packed(
"topi.nn.batch_to_space_nd",
(
in_expr,
relax.ShapeExpr(block_shape),
relax.ShapeExpr(crop_begin),
relax.ShapeExpr(crop_end),
),
out_sinfo=relax.TensorStructInfo(output_shape, output_dtype),
)

return out

Expand Down Expand Up @@ -3389,10 +3406,28 @@ def convert_space_to_batch_nd(self, op):
input_tensor_idx = input_tensor.tensor_idx
in_expr = self.get_expr(input_tensor_idx)

block_shape = list(self.get_tensor_value(input_tensors[1]))
paddings = self.get_tensor_value(input_tensors[2]).tolist()
block_shape = to_int_list(self.get_tensor_value(input_tensors[1]))
paddings = self.get_tensor_value(input_tensors[2])
pad_before = to_int_list(paddings[:, 0])
pad_after = to_int_list(paddings[:, 1])

out = relax.op.nn.space_to_batch_nd(in_expr, block_shape, paddings)
output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"
output_tensor = output_tensors[0]
output_shape = to_int_list(self.get_tensor_shape(output_tensor))
output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type())

out = relax.op.call_dps_packed(
"topi.nn.space_to_batch_nd",
(
in_expr,
relax.ShapeExpr(block_shape),
relax.ShapeExpr(pad_before),
relax.ShapeExpr(pad_after),
0.0,
),
out_sinfo=relax.TensorStructInfo(output_shape, output_dtype),
)

return out

Expand Down
68 changes: 68 additions & 0 deletions tests/python/relax/test_frontend_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3099,6 +3099,74 @@ def main(
verify(SpaceToDepth, Expected)


@pytest.mark.parametrize(
"input_shape, block_shape, paddings, expected_out_shape",
[
((1, 2, 2, 1), [2, 2], [[0, 0], [0, 0]], (4, 1, 1, 1)),
((1, 2, 3, 1), [2, 2], [[0, 0], [1, 0]], (4, 1, 2, 1)),
],
)
def test_space_to_batch_nd(input_shape, block_shape, paddings, expected_out_shape):
"""SPACE_TO_BATCH_ND imports to Relax and preserves expected output shape."""

class SpaceToBatchND(tf.Module):
@tf.function(input_signature=[tf.TensorSpec(shape=input_shape, dtype=tf.float32)])
def func(self, x):
return tf.space_to_batch_nd(
x,
tf.constant(block_shape, dtype=tf.int32),
tf.constant(paddings, dtype=tf.int32),
)

cf = SpaceToBatchND().func.get_concrete_function()
mod = _get_mod_from_cfunc(cf)
ir = mod.script()

assert "space_to_batch_nd" in ir
assert len(mod["main"].params) == 1
tvm.ir.assert_structural_equal(
mod["main"].ret_struct_info,
relax.TensorStructInfo(expected_out_shape, "float32"),
)

if "CI_ENV_NIGHTLY" in os.environ:
verify(SpaceToBatchND)


@pytest.mark.parametrize(
"input_shape, block_shape, crops, expected_out_shape",
[
((4, 1, 1, 1), [2, 2], [[0, 0], [0, 0]], (1, 2, 2, 1)),
((4, 1, 2, 1), [2, 2], [[0, 0], [1, 0]], (1, 2, 3, 1)),
],
)
def test_batch_to_space_nd(input_shape, block_shape, crops, expected_out_shape):
"""BATCH_TO_SPACE_ND imports to Relax and preserves expected output shape."""

class BatchToSpaceND(tf.Module):
@tf.function(input_signature=[tf.TensorSpec(shape=input_shape, dtype=tf.float32)])
def func(self, x):
return tf.raw_ops.BatchToSpaceND(
input=x,
block_shape=tf.constant(block_shape, dtype=tf.int32),
crops=tf.constant(crops, dtype=tf.int32),
)

cf = BatchToSpaceND().func.get_concrete_function()
mod = _get_mod_from_cfunc(cf)
ir = mod.script()

assert "batch_to_space_nd" in ir
assert len(mod["main"].params) == 1
tvm.ir.assert_structural_equal(
mod["main"].ret_struct_info,
relax.TensorStructInfo(expected_out_shape, "float32"),
)

if "CI_ENV_NIGHTLY" in os.environ:
verify(BatchToSpaceND)


def test_leaky_relu():
class LeakyReLU(tf.Module):
@tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)])
Expand Down
Loading