137 changes: 137 additions & 0 deletions python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -132,6 +132,7 @@ def __init__(self, model, subgraph, exp_tab, ctx):
"CEIL": functools.partial(self._convert_unary_elemwise, relax_op=_op.ceil),
"CONCATENATION": self.convert_concatenation,
"CONV_2D": functools.partial(self.convert_conv, conv_type="conv2d"),
"CONV_3D": self.convert_conv3d,
"COS": functools.partial(self._convert_unary_elemwise, relax_op=_op.cos),
"CUMSUM": self.convert_cumsum,
"DENSIFY": self.convert_densify,
@@ -2449,6 +2450,142 @@ def convert_conv(self, op, conv_type):
out = self.convert_fused_activation_function(out, fused_activation_fn)
return out

def convert_conv3d(self, op):
"""3D convolution implementation."""

from tflite.BuiltinOptions import BuiltinOptions
from tflite.Conv3DOptions import Conv3DOptions
from tflite.Padding import Padding
from tflite.TensorType import TensorType

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) >= 2, "input tensors length should be >= 2"

input_tensor = input_tensors[0]
input_tensor_idx = input_tensor.tensor_idx
weight_tensor = input_tensors[1]

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"
output_tensor = output_tensors[0]

assert op.BuiltinOptionsType() == BuiltinOptions.Conv3DOptions
op_options = op.BuiltinOptions()
conv3d_options = Conv3DOptions()
conv3d_options.Init(op_options.Bytes, op_options.Pos)

stride_d = conv3d_options.StrideD()
stride_h = conv3d_options.StrideH()
stride_w = conv3d_options.StrideW()
dilation_d = conv3d_options.DilationDFactor()
dilation_h = conv3d_options.DilationHFactor()
dilation_w = conv3d_options.DilationWFactor()
padding = conv3d_options.Padding()
fused_activation_fn = conv3d_options.FusedActivationFunction()

_, input_d, input_h, input_w, input_c = to_int_list(self.get_tensor_shape(input_tensor))
# TFLite Conv3D kernel layout is already DHWIO:
# KD KH KW IC OC
kernel_d, kernel_h, kernel_w, in_channels, output_channels = to_int_list(
self.get_tensor_shape(weight_tensor)
)

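# Effective kernel extent once dilation is applied: k_eff = d * (k - 1) + 1.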
dilated_kernel_d = dilation_d * (kernel_d - 1) + 1
dilated_kernel_h = dilation_h * (kernel_h - 1) + 1
dilated_kernel_w = dilation_w * (kernel_w - 1) + 1

params = {
"strides": [stride_d, stride_h, stride_w],
"dilation": [dilation_d, dilation_h, dilation_w],
"padding": [0, 0, 0, 0, 0, 0],
"data_layout": "NDHWC",
}
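# The 6-way "padding" above is ordered [front, top, left, back, bottom, right];
# it is filled in below for SAME padding.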

params["kernel_layout"] = "DHWIO"
if input_c != in_channels:
assert (
input_c % in_channels == 0
), "Input channel count must be divisible by the kernel's in_channels."
params["groups"] = input_c // in_channels

# weight tensor type should be INT8/UINT8 (quantization) or FLOAT32
weight_tensor_type = weight_tensor.tensor.Type()
assert weight_tensor_type in (
TensorType.INT8,
TensorType.UINT8,
TensorType.FLOAT32,
)
weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type)

in_expr = self.get_expr(input_tensor_idx)

# TFLite Conv3D kernel is already in DHWIO layout, no transpose needed.
if self.has_expr(weight_tensor.tensor_idx):
weight_expr = self.get_expr(weight_tensor.tensor_idx)
else:
if self.is_prefetched(weight_tensor.tensor_idx):
weight_value = self.get_prefetched_node(weight_tensor.tensor_idx)
else:
weight_value = self.get_tensor_value(weight_tensor)

weight_expr = self.exp_tab.new_const(
weight_value,
dtype=weight_tensor_type_str,
source_name=weight_tensor.tensor.Name(),
)

if padding == Padding.VALID:
pass
elif padding == Padding.SAME:
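# SAME padding keeps out = ceil(in / stride); get_pad_value returns the
# (before, after) split of the total pad for one spatial dimension.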
pad_front, pad_back = get_pad_value(input_d, dilated_kernel_d, stride_d)
pad_top, pad_bottom = get_pad_value(input_h, dilated_kernel_h, stride_h)
pad_left, pad_right = get_pad_value(input_w, dilated_kernel_w, stride_w)

do_pad = any((pad_front, pad_back, pad_top, pad_bottom, pad_left, pad_right))
if do_pad:
params["padding"] = [pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right]
else:
raise tvm.error.OpAttributeUnImplemented(
f"Padding format {padding} is not supported for operator Conv3D."
)

if input_tensor.qnn_params:
raise tvm.error.OpNotImplemented(
"Quantized Conv3D is not yet supported in the Relax frontend."
)

out = relax.op.nn.conv3d(in_expr, weight_expr, **params)

# Add bias if the op carries a third input tensor.
if len(input_tensors) == 3:
bias_tensor = input_tensors[2]
if bias_tensor.tensor_idx != -1:
bias_tensor_type = bias_tensor.tensor.Type()
# bias tensor type should be INT32 (int8 qnn) or INT64 (int16 qnn) or FLOAT32
assert bias_tensor_type in (TensorType.INT32, TensorType.INT64, TensorType.FLOAT32)
bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
if self.has_expr(bias_tensor.tensor_idx):
bias_expr = self.get_expr(bias_tensor.tensor_idx)
else:
bias_expr = self.exp_tab.new_const(
self.get_tensor_value(bias_tensor),
dtype=bias_tensor_type_str,
source_name=bias_tensor.tensor.Name(),
)
out = relax.op.add(out, bias_expr)

if output_tensor.qnn_params:
raise tvm.error.OpNotImplemented(
"Quantized Conv3D is not yet supported in the Relax frontend."
)

# Handle fused activation.
out = self.convert_fused_activation_function(out, fused_activation_fn)
return out

def convert_split(self, op):
"""split implementation."""

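For reference, the SAME-padding arithmetic above follows the standard TensorFlow rule: choose the total pad so that out = ceil(in / stride). A minimal standalone sketch of one spatial dimension (assuming get_pad_value matches these TF semantics; this is an illustration, not the frontend's actual helper):

import math

def same_pad(in_size: int, dilated_kernel: int, stride: int):
    """Return the (before, after) SAME padding for one spatial dim."""
    out_size = math.ceil(in_size / stride)
    total = max((out_size - 1) * stride + dilated_kernel - in_size, 0)
    before = total // 2
    # TensorFlow places the extra pixel, if any, on the back/bottom/right.
    return before, total - before

# same_pad(8, 3, 1) -> (1, 1): matches padding=[1, 1, 1, 1, 1, 1] in the
# SAME test below.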
83 changes: 83 additions & 0 deletions tests/python/relax/test_frontend_tflite.py
@@ -1611,6 +1611,89 @@ def main(
verify(Conv2DModule, Expected)


def _make_conv3d_module(data_shape, kernel_shape, strides, padding):
class Conv3DModule(tf.Module):
@tf.function(
input_signature=[
tf.TensorSpec(shape=data_shape, dtype=tf.float32),
tf.TensorSpec(shape=kernel_shape, dtype=tf.float32),
]
)
def func(self, data, kernel):
return tf.nn.conv3d(
input=data,
filters=kernel,
strides=strides,
padding=padding,
)

return Conv3DModule

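(Note: tf.nn.conv3d takes a length-5 strides vector in NDHWC order, [1, stride_d, stride_h, stride_w, 1]; the batch and channel entries must be 1, which is why the tests below pass (1, 1, 1, 1, 1).)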

def test_conv3d_valid():
Conv3DModule = _make_conv3d_module(
(1, 8, 8, 8, 3), (3, 3, 3, 3, 16), (1, 1, 1, 1, 1), "VALID"
)

@I.ir_module
class Expected:
@R.function
def main(
data: R.Tensor((1, 8, 8, 8, 3), dtype="float32"),
kernel: R.Tensor((3, 3, 3, 3, 16), dtype="float32"),
) -> R.Tensor((1, 6, 6, 6, 16), dtype="float32"):
R.func_attr({"num_input": 2})
with R.dataflow():
gv: R.Tensor((1, 6, 6, 6, 16), dtype="float32") = R.nn.conv3d(
data,
kernel,
strides=[1, 1, 1],
padding=[0, 0, 0, 0, 0, 0],
dilation=[1, 1, 1],
groups=1,
data_layout="NDHWC",
kernel_layout="DHWIO",
out_layout="NDHWC",
out_dtype="void",
)
R.output(gv)
return gv

verify(Conv3DModule, Expected)


def test_conv3d_same():
Conv3DModule = _make_conv3d_module(
(1, 8, 8, 8, 3), (3, 3, 3, 3, 16), (1, 1, 1, 1, 1), "SAME"
)

@I.ir_module
class Expected:
@R.function
def main(
data: R.Tensor((1, 8, 8, 8, 3), dtype="float32"),
kernel: R.Tensor((3, 3, 3, 3, 16), dtype="float32"),
) -> R.Tensor((1, 8, 8, 8, 16), dtype="float32"):
R.func_attr({"num_input": 2})
with R.dataflow():
gv: R.Tensor((1, 8, 8, 8, 16), dtype="float32") = R.nn.conv3d(
data,
kernel,
strides=[1, 1, 1],
padding=[1, 1, 1, 1, 1, 1],
dilation=[1, 1, 1],
groups=1,
data_layout="NDHWC",
kernel_layout="DHWIO",
out_layout="NDHWC",
out_dtype="void",
)
R.output(gv)
return gv

verify(Conv3DModule, Expected)

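Sanity check on the expected shapes: VALID gives out = (in - k_eff) // stride + 1 = (8 - 3) // 1 + 1 = 6 per spatial dim, while SAME keeps out = ceil(in / stride) = 8 and therefore needs one pad pixel on each side of each dim, hence padding=[1, 1, 1, 1, 1, 1].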

def _make_pool2d_module(pool, data_shape, ksize, data_format, strides, padding):
class Pool2DModule(tf.Module):
@tf.function(