Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions include/tvm/relax/attrs/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,31 @@ struct AllClassNonMaximumSuppressionAttrs
AllClassNonMaximumSuppressionAttrs, BaseAttrsNode);
}; // struct AllClassNonMaximumSuppressionAttrs

/*! \brief Attributes used in ROIAlign operator */
struct ROIAlignAttrs : public AttrsNodeReflAdapter<ROIAlignAttrs> {
ffi::Array<int64_t> pooled_size;
double spatial_scale;
int sample_ratio;
bool aligned;
ffi::String layout;
ffi::String mode;

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<ROIAlignAttrs>()
.def_ro("pooled_size", &ROIAlignAttrs::pooled_size, "Output size of roi align.")
.def_ro("spatial_scale", &ROIAlignAttrs::spatial_scale,
"Ratio of input feature map height (or width) to raw image height (or width).")
.def_ro("sample_ratio", &ROIAlignAttrs::sample_ratio,
"Optional sampling ratio of ROI align, using adaptive size by default.")
.def_ro("aligned", &ROIAlignAttrs::aligned,
"Whether to use the aligned ROIAlign semantics without the legacy 1-pixel clamp.")
.def_ro("layout", &ROIAlignAttrs::layout, "Dimension ordering of the input data.")
.def_ro("mode", &ROIAlignAttrs::mode, "Mode for ROI Align. Can be 'avg' or 'max'.");
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIAlignAttrs", ROIAlignAttrs, BaseAttrsNode);
}; // struct ROIAlignAttrs

} // namespace relax
} // namespace tvm

Expand Down
67 changes: 66 additions & 1 deletion python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2421,6 +2421,71 @@ def _impl_v12(cls, bb, inputs, attr, params):
return bb.emit_te(topi.einsum, equation, *inputs)


class RoiAlign(OnnxOpConverter):
"""Converts an onnx RoiAlign node into an equivalent Relax expression."""

@classmethod
def _impl(cls, bb, inputs, attr, params, default_coordinate_transformation_mode):
if len(inputs) != 3:
raise ValueError("RoiAlign expects exactly 3 inputs")

data = inputs[0]
rois = inputs[1]
batch_indices = inputs[2]
rois_dtype = rois.struct_info.dtype

mode = attr.get("mode", b"avg")
if isinstance(mode, bytes):
mode = mode.decode("ascii")
if mode not in ("avg", "max"):
raise NotImplementedError("RoiAlign in Relax only supports avg and max modes")

output_height = attr.get("output_height", 1)
output_width = attr.get("output_width", 1)
sampling_ratio = attr.get("sampling_ratio", 0)
spatial_scale = attr.get("spatial_scale", 1.0)
coordinate_transformation_mode = attr.get(
"coordinate_transformation_mode", default_coordinate_transformation_mode
)
if isinstance(coordinate_transformation_mode, bytes):
coordinate_transformation_mode = coordinate_transformation_mode.decode("ascii")

if coordinate_transformation_mode == "half_pixel":
offset = relax.const([-0.5, -0.5, -0.5, -0.5], rois_dtype)
rois = relax.op.add(rois, offset)
aligned = True
elif coordinate_transformation_mode != "output_half_pixel":
raise NotImplementedError(
"RoiAlign only supports coordinate_transformation_mode "
"'half_pixel' and 'output_half_pixel'"
)
Comment on lines +2457 to +2461
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The NotImplementedError message is slightly misleading. It states that RoiAlign only supports 'half_pixel' and 'output_half_pixel', but the else branch implicitly handles 'output_half_pixel' by setting aligned = False. The error should only be raised if coordinate_transformation_mode is neither 'half_pixel' nor 'output_half_pixel'.

Suggested change
elif coordinate_transformation_mode != "output_half_pixel":
raise NotImplementedError(
"RoiAlign only supports coordinate_transformation_mode "
"'half_pixel' and 'output_half_pixel'"
)
elif coordinate_transformation_mode == "output_half_pixel":
aligned = False
else:
raise NotImplementedError(
"RoiAlign only supports coordinate_transformation_mode "
"'half_pixel' and 'output_half_pixel'"
)

else:
aligned = False

batch_indices = relax.op.expand_dims(batch_indices, axis=1)
batch_indices = relax.op.astype(batch_indices, rois_dtype)
rois = relax.op.concat([batch_indices, rois], axis=1)

return relax.op.vision.roi_align(
data,
rois,
pooled_size=(output_height, output_width),
spatial_scale=spatial_scale,
sample_ratio=sampling_ratio,
aligned=aligned,
layout="NCHW",
mode=mode,
)

@classmethod
def _impl_v10(cls, bb, inputs, attr, params):
return cls._impl(bb, inputs, attr, params, b"output_half_pixel")

@classmethod
def _impl_v16(cls, bb, inputs, attr, params):
return cls._impl(bb, inputs, attr, params, b"half_pixel")


class Range(OnnxOpConverter):
"""Converts an onnx Range node into an equivalent Relax expression."""

Expand Down Expand Up @@ -4083,7 +4148,7 @@ def _get_convert_map():
"NonZero": NonZero,
# "If": If,
# "MaxRoiPool": MaxRoiPool,
# "RoiAlign": RoiAlign,
"RoiAlign": RoiAlign,
"NonMaxSuppression": NonMaxSuppression,
"AllClassNMS": AllClassNMS,
"GridSample": GridSample,
Expand Down
37 changes: 37 additions & 0 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,6 +1123,42 @@ def _grid_sampler_2d(self, node: fx.Node) -> relax.Var:
)
)

def _torchvision_roi_align(self, node: fx.Node) -> relax.Var:
"""Convert torchvision.ops.roi_align to relax.op.vision.roi_align."""
args = self.retrieve_args(node)
data = args[0]
rois = args[1]
spatial_scale = args[2] if len(args) > 2 else 1.0
pooled_height = args[3] if len(args) > 3 else 1
pooled_width = args[4] if len(args) > 4 else pooled_height
sampling_ratio = args[5] if len(args) > 5 else -1
aligned = args[6] if len(args) > 6 else False
Comment on lines +1134 to +1135
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The sampling_ratio default value is set to -1 in the roi_align function signature, but here it's 0 in the attr.get call. It's better to be consistent with the Relax op's default value.

Suggested change
sampling_ratio = args[5] if len(args) > 5 else -1
aligned = args[6] if len(args) > 6 else False
sampling_ratio = args[5] if len(args) > 5 else -1


if aligned:
batch_indices = self.block_builder.emit(
relax.op.strided_slice(rois, axes=[1], begin=[0], end=[1])
)
boxes = self.block_builder.emit(
relax.op.strided_slice(rois, axes=[1], begin=[1], end=[5])
)
boxes = self.block_builder.emit(
relax.op.subtract(boxes, relax.const(0.5, rois.struct_info.dtype))
)
rois = self.block_builder.emit(relax.op.concat([batch_indices, boxes], axis=1))

return self.block_builder.emit(
relax.op.vision.roi_align(
data,
rois,
pooled_size=(pooled_height, pooled_width),
spatial_scale=spatial_scale,
sample_ratio=sampling_ratio,
aligned=aligned,
layout="NCHW",
mode="avg",
)
)

def _scalar_tensor(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
scalar_value = args[0]
Expand Down Expand Up @@ -1732,6 +1768,7 @@ def create_convert_map(
"zeros.default": self._zeros,
"zeros_like.default": self._zeros_like,
"grid_sampler_2d.default": self._grid_sampler_2d,
"roi_align.default": self._torchvision_roi_align,
# datatype
"to.dtype": self._to,
"to.dtype_layout": self._to,
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@
tanh,
trunc,
)
from .vision import all_class_non_max_suppression
from .vision import all_class_non_max_suppression, roi_align


def _register_op_make():
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relax/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,11 @@ class AllClassNonMaximumSuppressionAttrs(Attrs):
"""Attributes for vision.all_class_non_max_suppression"""


@tvm_ffi.register_object("relax.attrs.ROIAlignAttrs")
class ROIAlignAttrs(Attrs):
"""Attributes for vision.roi_align"""


@tvm_ffi.register_object("relax.attrs.Conv1DAttrs")
class Conv1DAttrs(Attrs):
"""Attributes for nn.conv1d"""
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/op/vision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@
"""VISION operators."""

from .nms import *
from .roi_align import *
78 changes: 78 additions & 0 deletions python/tvm/relax/op/vision/roi_align.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""ROI Align operator"""

from ..base import Expr
from . import _ffi_api


def roi_align(
data: Expr,
rois: Expr,
pooled_size: int | tuple[int, int] | list[int],
spatial_scale: float,
sample_ratio: int = -1,
aligned: bool = False,
layout: str = "NCHW",
mode: str = "avg",
):
"""ROI Align operator.

Parameters
----------
data : relax.Expr
4-D input tensor.

rois : relax.Expr
2-D input tensor with shape `(num_roi, 5)` in
`[batch_idx, x1, y1, x2, y2]` format.

pooled_size : Union[int, Tuple[int, int], List[int]]
Output pooled size.

spatial_scale : float
Ratio of input feature map height (or width) to raw image height (or width).

sample_ratio : int, optional
Sampling ratio for ROI align. Non-positive values use adaptive sampling.

aligned : bool, optional
Whether to use aligned ROIAlign semantics without the legacy 1-pixel clamp.

layout : str, optional
Layout of the input data. Supported values are `NCHW` and `NHWC`.

mode : str, optional
Mode for ROI align. Supported values are `avg` and `max`.

Returns
-------
result : relax.Expr
The computed result.
"""
if isinstance(pooled_size, int):
pooled_size = (pooled_size, pooled_size)
return _ffi_api.roi_align(
data,
rois,
pooled_size,
spatial_scale,
sample_ratio,
aligned,
layout,
mode,
)
15 changes: 15 additions & 0 deletions python/tvm/relax/transform/legalize_ops/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,18 @@ def compute_end(i):

# Return trimmed indices along with num_total_detections for compatibility
return relax.Tuple([trimmed_indices, num_total_detections])


@register_legalize("relax.vision.roi_align")
def _roi_align(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(
topi.vision.roi_align,
call.args[0],
call.args[1],
pooled_size=call.attrs.pooled_size,
spatial_scale=call.attrs.spatial_scale,
mode=call.attrs.mode,
sample_ratio=call.attrs.sample_ratio,
aligned=call.attrs.aligned,
layout=call.attrs.layout,
)
15 changes: 11 additions & 4 deletions python/tvm/topi/testing/roi_align_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def roi_align_common(
pooled_size_w,
spatial_scale,
sample_ratio,
aligned,
avg_mode,
max_mode,
height,
Expand All @@ -72,8 +73,8 @@ def roi_align_common(
roi = rois_np[i]
batch_index = int(roi[0])
roi_start_w, roi_start_h, roi_end_w, roi_end_h = roi[1:] * spatial_scale
roi_h = max(roi_end_h - roi_start_h, 1.0)
roi_w = max(roi_end_w - roi_start_w, 1.0)
roi_h = roi_end_h - roi_start_h if aligned else max(roi_end_h - roi_start_h, 1.0)
roi_w = roi_end_w - roi_start_w if aligned else max(roi_end_w - roi_start_w, 1.0)

bin_h = roi_h / pooled_size_h
bin_w = roi_w / pooled_size_w
Expand Down Expand Up @@ -115,7 +116,9 @@ def roi_align_common(
return b_np


def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_ratio, mode=b"avg"):
def roi_align_nchw_python(
a_np, rois_np, pooled_size, spatial_scale, sample_ratio, mode=b"avg", aligned=False
):
"""Roi align NCHW in python"""
avg_mode = mode in (b"avg", "avg", 0)
max_mode = mode in (b"max", "max", 1)
Expand All @@ -137,6 +140,7 @@ def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_rati
pooled_size_w,
spatial_scale,
sample_ratio,
aligned,
avg_mode,
max_mode,
height,
Expand All @@ -145,7 +149,9 @@ def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_rati
)


def roi_align_nhwc_python(a_np, rois_np, pooled_size, spatial_scale, sample_ratio, mode=b"avg"):
def roi_align_nhwc_python(
a_np, rois_np, pooled_size, spatial_scale, sample_ratio, mode=b"avg", aligned=False
):
"""Roi align NHWC in python"""
avg_mode = mode in (b"avg", "avg", 0)
max_mode = mode in (b"max", "max", 1)
Expand All @@ -169,6 +175,7 @@ def roi_align_nhwc_python(a_np, rois_np, pooled_size, spatial_scale, sample_rati
pooled_size_w,
spatial_scale,
sample_ratio,
aligned,
avg_mode,
max_mode,
height,
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/vision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@
"""Vision operators."""

from .nms import *
from .roi_align import *
Loading
Loading