From e4472ded39ad0908d18a9ad7708345e47e8a908d Mon Sep 17 00:00:00 2001 From: LudovicoYIN Date: Thu, 26 Mar 2026 06:37:30 +0000 Subject: [PATCH 1/2] [Relax][ONNX][Torch] Add roi_align support and frontend integration --- include/tvm/relax/attrs/vision.h | 25 +++ .../tvm/relax/frontend/onnx/onnx_frontend.py | 67 +++++- .../torch/exported_program_translator.py | 37 ++++ python/tvm/relax/op/__init__.py | 2 +- python/tvm/relax/op/op_attrs.py | 5 + python/tvm/relax/op/vision/__init__.py | 1 + python/tvm/relax/op/vision/roi_align.py | 78 +++++++ .../relax/transform/legalize_ops/vision.py | 15 ++ python/tvm/topi/testing/roi_align_python.py | 15 +- python/tvm/topi/vision/__init__.py | 1 + python/tvm/topi/vision/roi_align.py | 204 +++++++++++++++++ src/relax/op/vision/roi_align.cc | 141 ++++++++++++ src/relax/op/vision/roi_align.h | 42 ++++ .../test_frontend_from_exported_program.py | 91 ++++++++ tests/python/relax/test_frontend_onnx.py | 60 ++++- tests/python/relax/test_op_vision.py | 207 ++++++++++++++++++ .../relax/test_tvmscript_parser_op_vision.py | 32 +++ 17 files changed, 1014 insertions(+), 9 deletions(-) create mode 100644 python/tvm/relax/op/vision/roi_align.py create mode 100644 python/tvm/topi/vision/roi_align.py create mode 100644 src/relax/op/vision/roi_align.cc create mode 100644 src/relax/op/vision/roi_align.h diff --git a/include/tvm/relax/attrs/vision.h b/include/tvm/relax/attrs/vision.h index 2fd98533b589..59a1dd7314fc 100644 --- a/include/tvm/relax/attrs/vision.h +++ b/include/tvm/relax/attrs/vision.h @@ -48,6 +48,31 @@ struct AllClassNonMaximumSuppressionAttrs AllClassNonMaximumSuppressionAttrs, BaseAttrsNode); }; // struct AllClassNonMaximumSuppressionAttrs +/*! 
\brief Attributes used in ROIAlign operator */ +struct ROIAlignAttrs : public AttrsNodeReflAdapter { + ffi::Array pooled_size; + double spatial_scale; + int sample_ratio; + bool aligned; + ffi::String layout; + ffi::String mode; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("pooled_size", &ROIAlignAttrs::pooled_size, "Output size of roi align.") + .def_ro("spatial_scale", &ROIAlignAttrs::spatial_scale, + "Ratio of input feature map height (or width) to raw image height (or width).") + .def_ro("sample_ratio", &ROIAlignAttrs::sample_ratio, + "Optional sampling ratio of ROI align, using adaptive size by default.") + .def_ro("aligned", &ROIAlignAttrs::aligned, + "Whether to use the aligned ROIAlign semantics without the legacy 1-pixel clamp.") + .def_ro("layout", &ROIAlignAttrs::layout, "Dimension ordering of the input data.") + .def_ro("mode", &ROIAlignAttrs::mode, "Mode for ROI Align. Can be 'avg' or 'max'."); + } + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIAlignAttrs", ROIAlignAttrs, BaseAttrsNode); +}; // struct ROIAlignAttrs + } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index f08505951de8..f11f6269a93c 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -2421,6 +2421,71 @@ def _impl_v12(cls, bb, inputs, attr, params): return bb.emit_te(topi.einsum, equation, *inputs) +class RoiAlign(OnnxOpConverter): + """Converts an onnx RoiAlign node into an equivalent Relax expression.""" + + @classmethod + def _impl(cls, bb, inputs, attr, params, default_coordinate_transformation_mode): + if len(inputs) != 3: + raise ValueError("RoiAlign expects exactly 3 inputs") + + data = inputs[0] + rois = inputs[1] + batch_indices = inputs[2] + rois_dtype = rois.struct_info.dtype + + mode = attr.get("mode", b"avg") + if isinstance(mode, bytes): + 
mode = mode.decode("ascii") + if mode not in ("avg", "max"): + raise NotImplementedError("RoiAlign in Relax only supports avg and max modes") + + output_height = attr.get("output_height", 1) + output_width = attr.get("output_width", 1) + sampling_ratio = attr.get("sampling_ratio", 0) + spatial_scale = attr.get("spatial_scale", 1.0) + coordinate_transformation_mode = attr.get( + "coordinate_transformation_mode", default_coordinate_transformation_mode + ) + if isinstance(coordinate_transformation_mode, bytes): + coordinate_transformation_mode = coordinate_transformation_mode.decode("ascii") + + if coordinate_transformation_mode == "half_pixel": + offset = relax.const([-0.5, -0.5, -0.5, -0.5], rois_dtype) + rois = relax.op.add(rois, offset) + aligned = True + elif coordinate_transformation_mode != "output_half_pixel": + raise NotImplementedError( + "RoiAlign only supports coordinate_transformation_mode " + "'half_pixel' and 'output_half_pixel'" + ) + else: + aligned = False + + batch_indices = relax.op.expand_dims(batch_indices, axis=1) + batch_indices = relax.op.astype(batch_indices, rois_dtype) + rois = relax.op.concat([batch_indices, rois], axis=1) + + return relax.op.vision.roi_align( + data, + rois, + pooled_size=(output_height, output_width), + spatial_scale=spatial_scale, + sample_ratio=sampling_ratio, + aligned=aligned, + layout="NCHW", + mode=mode, + ) + + @classmethod + def _impl_v10(cls, bb, inputs, attr, params): + return cls._impl(bb, inputs, attr, params, b"output_half_pixel") + + @classmethod + def _impl_v16(cls, bb, inputs, attr, params): + return cls._impl(bb, inputs, attr, params, b"half_pixel") + + class Range(OnnxOpConverter): """Converts an onnx Range node into an equivalent Relax expression.""" @@ -4083,7 +4148,7 @@ def _get_convert_map(): "NonZero": NonZero, # "If": If, # "MaxRoiPool": MaxRoiPool, - # "RoiAlign": RoiAlign, + "RoiAlign": RoiAlign, "NonMaxSuppression": NonMaxSuppression, "AllClassNMS": AllClassNMS, "GridSample": GridSample, diff 
--git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 67e0e45da0a3..47633c69b5f5 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -1123,6 +1123,42 @@ def _grid_sampler_2d(self, node: fx.Node) -> relax.Var: ) ) + def _torchvision_roi_align(self, node: fx.Node) -> relax.Var: + """Convert torchvision.ops.roi_align to relax.op.vision.roi_align.""" + args = self.retrieve_args(node) + data = args[0] + rois = args[1] + spatial_scale = args[2] if len(args) > 2 else 1.0 + pooled_height = args[3] if len(args) > 3 else 1 + pooled_width = args[4] if len(args) > 4 else pooled_height + sampling_ratio = args[5] if len(args) > 5 else -1 + aligned = args[6] if len(args) > 6 else False + + if aligned: + batch_indices = self.block_builder.emit( + relax.op.strided_slice(rois, axes=[1], begin=[0], end=[1]) + ) + boxes = self.block_builder.emit( + relax.op.strided_slice(rois, axes=[1], begin=[1], end=[5]) + ) + boxes = self.block_builder.emit( + relax.op.subtract(boxes, relax.const(0.5, rois.struct_info.dtype)) + ) + rois = self.block_builder.emit(relax.op.concat([batch_indices, boxes], axis=1)) + + return self.block_builder.emit( + relax.op.vision.roi_align( + data, + rois, + pooled_size=(pooled_height, pooled_width), + spatial_scale=spatial_scale, + sample_ratio=sampling_ratio, + aligned=aligned, + layout="NCHW", + mode="avg", + ) + ) + def _scalar_tensor(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) scalar_value = args[0] @@ -1732,6 +1768,7 @@ def create_convert_map( "zeros.default": self._zeros, "zeros_like.default": self._zeros_like, "grid_sampler_2d.default": self._grid_sampler_2d, + "roi_align.default": self._torchvision_roi_align, # datatype "to.dtype": self._to, "to.dtype_layout": self._to, diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 
7c3f75298b6c..0bc3f6578432 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -157,7 +157,7 @@ tanh, trunc, ) -from .vision import all_class_non_max_suppression +from .vision import all_class_non_max_suppression, roi_align def _register_op_make(): diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index f07623bd38e3..a3b6544dcc6e 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -246,6 +246,11 @@ class AllClassNonMaximumSuppressionAttrs(Attrs): """Attributes for vision.all_class_non_max_suppression""" +@tvm_ffi.register_object("relax.attrs.ROIAlignAttrs") +class ROIAlignAttrs(Attrs): + """Attributes for vision.roi_align""" + + @tvm_ffi.register_object("relax.attrs.Conv1DAttrs") class Conv1DAttrs(Attrs): """Attributes for nn.conv1d""" diff --git a/python/tvm/relax/op/vision/__init__.py b/python/tvm/relax/op/vision/__init__.py index ea20d2b40000..76d9ea35a11c 100644 --- a/python/tvm/relax/op/vision/__init__.py +++ b/python/tvm/relax/op/vision/__init__.py @@ -18,3 +18,4 @@ """VISION operators.""" from .nms import * +from .roi_align import * diff --git a/python/tvm/relax/op/vision/roi_align.py b/python/tvm/relax/op/vision/roi_align.py new file mode 100644 index 000000000000..8db694c7f2e8 --- /dev/null +++ b/python/tvm/relax/op/vision/roi_align.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""ROI Align operator""" + +from ..base import Expr +from . import _ffi_api + + +def roi_align( + data: Expr, + rois: Expr, + pooled_size: int | tuple[int, int] | list[int], + spatial_scale: float, + sample_ratio: int = -1, + aligned: bool = False, + layout: str = "NCHW", + mode: str = "avg", +): + """ROI Align operator. + + Parameters + ---------- + data : relax.Expr + 4-D input tensor. + + rois : relax.Expr + 2-D input tensor with shape `(num_roi, 5)` in + `[batch_idx, x1, y1, x2, y2]` format. + + pooled_size : Union[int, Tuple[int, int], List[int]] + Output pooled size. + + spatial_scale : float + Ratio of input feature map height (or width) to raw image height (or width). + + sample_ratio : int, optional + Sampling ratio for ROI align. Non-positive values use adaptive sampling. + + aligned : bool, optional + Whether to use aligned ROIAlign semantics without the legacy 1-pixel clamp. + + layout : str, optional + Layout of the input data. Supported values are `NCHW` and `NHWC`. + + mode : str, optional + Mode for ROI align. Supported values are `avg` and `max`. + + Returns + ------- + result : relax.Expr + The computed result. 
+ """ + if isinstance(pooled_size, int): + pooled_size = (pooled_size, pooled_size) + return _ffi_api.roi_align( + data, + rois, + pooled_size, + spatial_scale, + sample_ratio, + aligned, + layout, + mode, + ) diff --git a/python/tvm/relax/transform/legalize_ops/vision.py b/python/tvm/relax/transform/legalize_ops/vision.py index f95dfa35b642..7a1e305f39f0 100644 --- a/python/tvm/relax/transform/legalize_ops/vision.py +++ b/python/tvm/relax/transform/legalize_ops/vision.py @@ -103,3 +103,18 @@ def compute_end(i): # Return trimmed indices along with num_total_detections for compatibility return relax.Tuple([trimmed_indices, num_total_detections]) + + +@register_legalize("relax.vision.roi_align") +def _roi_align(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + topi.vision.roi_align, + call.args[0], + call.args[1], + pooled_size=call.attrs.pooled_size, + spatial_scale=call.attrs.spatial_scale, + mode=call.attrs.mode, + sample_ratio=call.attrs.sample_ratio, + aligned=call.attrs.aligned, + layout=call.attrs.layout, + ) diff --git a/python/tvm/topi/testing/roi_align_python.py b/python/tvm/topi/testing/roi_align_python.py index 9fc72074e4d6..19725a0c5bee 100644 --- a/python/tvm/topi/testing/roi_align_python.py +++ b/python/tvm/topi/testing/roi_align_python.py @@ -59,6 +59,7 @@ def roi_align_common( pooled_size_w, spatial_scale, sample_ratio, + aligned, avg_mode, max_mode, height, @@ -72,8 +73,8 @@ def roi_align_common( roi = rois_np[i] batch_index = int(roi[0]) roi_start_w, roi_start_h, roi_end_w, roi_end_h = roi[1:] * spatial_scale - roi_h = max(roi_end_h - roi_start_h, 1.0) - roi_w = max(roi_end_w - roi_start_w, 1.0) + roi_h = roi_end_h - roi_start_h if aligned else max(roi_end_h - roi_start_h, 1.0) + roi_w = roi_end_w - roi_start_w if aligned else max(roi_end_w - roi_start_w, 1.0) bin_h = roi_h / pooled_size_h bin_w = roi_w / pooled_size_w @@ -115,7 +116,9 @@ def roi_align_common( return b_np -def roi_align_nchw_python(a_np, rois_np, pooled_size, 
spatial_scale, sample_ratio, mode=b"avg"): +def roi_align_nchw_python( + a_np, rois_np, pooled_size, spatial_scale, sample_ratio, mode=b"avg", aligned=False +): """Roi align NCHW in python""" avg_mode = mode in (b"avg", "avg", 0) max_mode = mode in (b"max", "max", 1) @@ -137,6 +140,7 @@ def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_rati pooled_size_w, spatial_scale, sample_ratio, + aligned, avg_mode, max_mode, height, @@ -145,7 +149,9 @@ def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_rati ) -def roi_align_nhwc_python(a_np, rois_np, pooled_size, spatial_scale, sample_ratio, mode=b"avg"): +def roi_align_nhwc_python( + a_np, rois_np, pooled_size, spatial_scale, sample_ratio, mode=b"avg", aligned=False +): """Roi align NHWC in python""" avg_mode = mode in (b"avg", "avg", 0) max_mode = mode in (b"max", "max", 1) @@ -169,6 +175,7 @@ def roi_align_nhwc_python(a_np, rois_np, pooled_size, spatial_scale, sample_rati pooled_size_w, spatial_scale, sample_ratio, + aligned, avg_mode, max_mode, height, diff --git a/python/tvm/topi/vision/__init__.py b/python/tvm/topi/vision/__init__.py index c637b9cab2aa..75725a8a4bea 100644 --- a/python/tvm/topi/vision/__init__.py +++ b/python/tvm/topi/vision/__init__.py @@ -18,3 +18,4 @@ """Vision operators.""" from .nms import * +from .roi_align import * diff --git a/python/tvm/topi/vision/roi_align.py b/python/tvm/topi/vision/roi_align.py new file mode 100644 index 000000000000..2c2d0faec176 --- /dev/null +++ b/python/tvm/topi/vision/roi_align.py @@ -0,0 +1,204 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""ROI Align operator"""
+
+import tvm
+from tvm import te
+
+from ..cpp.utils import bilinear_sample_nchw, bilinear_sample_nhwc
+
+
+def _sample_common(
+    i,
+    c,
+    ph,
+    pw,
+    rois,
+    pooled_size_h,
+    pooled_size_w,
+    spatial_scale,
+    sample_ratio,
+    aligned,
+    dtype,
+    avg_mode,
+    bilinear_func,
+):
+    roi = rois[i]
+    batch_index = roi[0].astype("int32")
+    roi_start_w = roi[1] * spatial_scale
+    roi_start_h = roi[2] * spatial_scale
+    roi_end_w = roi[3] * spatial_scale
+    roi_end_h = roi[4] * spatial_scale
+
+    if aligned:
+        roi_h = roi_end_h - roi_start_h
+        roi_w = roi_end_w - roi_start_w
+    else:
+        roi_h = te.max(roi_end_h - roi_start_h, tvm.tir.const(1.0, dtype))
+        roi_w = te.max(roi_end_w - roi_start_w, tvm.tir.const(1.0, dtype))
+
+    pooled_size_h_const = tvm.tir.const(pooled_size_h, dtype)
+    pooled_size_w_const = tvm.tir.const(pooled_size_w, dtype)
+    bin_h = roi_h / pooled_size_h_const
+    bin_w = roi_w / pooled_size_w_const
+
+    if sample_ratio > 0:
+        roi_bin_grid_h = tvm.tir.const(sample_ratio, "int32")
+        roi_bin_grid_w = tvm.tir.const(sample_ratio, "int32")
+    else:
+        roi_bin_grid_h = te.ceil(roi_h / pooled_size_h_const).astype("int32")
+        roi_bin_grid_w = te.ceil(roi_w / pooled_size_w_const).astype("int32")
+
+    count = roi_bin_grid_h * roi_bin_grid_w
+    rh = te.reduce_axis((0, roi_bin_grid_h), name="rh")
+    rw = te.reduce_axis((0, roi_bin_grid_w), name="rw")
+    roi_start_h = roi_start_h + tvm.tir.Cast(dtype, ph) * bin_h
+    roi_start_w = roi_start_w + tvm.tir.Cast(dtype, pw) * bin_w
+
+    def sample_value(rh_idx, rw_idx):
+        return bilinear_func(
+            batch_index,
+            c,
+            roi_start_h
+            + (tvm.tir.Cast(dtype, rh_idx) + tvm.tir.const(0.5, dtype))
+            * bin_h
+            / tvm.tir.Cast(dtype, roi_bin_grid_h),
+            roi_start_w
+            + (tvm.tir.Cast(dtype, rw_idx) + tvm.tir.const(0.5, dtype))
+            * bin_w
+            / tvm.tir.Cast(dtype, roi_bin_grid_w),
+        )
+
+    if avg_mode:
+        return te.sum(
+            sample_value(rh, rw) / tvm.tir.Cast(dtype, count),
+            axis=[rh, rw],
+        )
+    return te.max(sample_value(rh, rw), axis=[rh, rw])
+
+
+def roi_align_nchw(data, rois, pooled_size, spatial_scale, mode, sample_ratio=-1, aligned=False):
+    """ROI align operator in NCHW layout."""
+    avg_mode = mode in (b"avg", "avg", 0)
+    max_mode = mode in (b"max", "max", 1)
+    assert avg_mode or max_mode, "Mode must be avg or max. Please pass in a valid mode."
+
+    _, channel, height, width = data.shape
+    num_roi, _ = rois.shape
+    dtype = rois.dtype
+
+    if isinstance(pooled_size, int):
+        pooled_size_h = pooled_size_w = pooled_size
+    else:
+        pooled_size_h, pooled_size_w = pooled_size
+
+    height_f = tvm.tir.Cast(dtype, height)
+    width_f = tvm.tir.Cast(dtype, width)
+    zero = tvm.tir.const(0.0, data.dtype)
+
+    def _bilinear(n, c, y, x):
+        outside = tvm.tir.any(y < -1.0, x < -1.0, y > height_f, x > width_f)
+        y = te.min(te.max(y, 0.0), tvm.tir.Cast(dtype, height - 1))
+        x = te.min(te.max(x, 0.0), tvm.tir.Cast(dtype, width - 1))
+        val = bilinear_sample_nchw(data, (n, c, y, x), height - 1, width - 1)
+        return tvm.tir.if_then_else(outside, zero, val)
+
+    return te.compute(
+        (num_roi, channel, pooled_size_h, pooled_size_w),
+        lambda i, c, ph, pw: _sample_common(
+            i,
+            c,
+            ph,
+            pw,
+            rois,
+            pooled_size_h,
+            pooled_size_w,
+            spatial_scale,
+            sample_ratio,
+            aligned,
+            dtype,
+            avg_mode,
+            _bilinear,
+        ),
+        tag="pool,roi_align_nchw",
+    )
+
+
+def roi_align_nhwc(data, rois, pooled_size, spatial_scale, mode, sample_ratio=-1, aligned=False):
+    """ROI align operator in NHWC layout."""
+    avg_mode = mode in (b"avg", "avg", 0)
+    max_mode = mode in (b"max", "max", 1)
+    assert avg_mode or max_mode, "Mode must be avg or max. Please pass in a valid mode."
+
+    _, height, width, channel = data.shape
+    num_roi, _ = rois.shape
+    dtype = rois.dtype
+
+    if isinstance(pooled_size, int):
+        pooled_size_h = pooled_size_w = pooled_size
+    else:
+        pooled_size_h, pooled_size_w = pooled_size
+
+    height_f = tvm.tir.Cast(dtype, height)
+    width_f = tvm.tir.Cast(dtype, width)
+    zero = tvm.tir.const(0.0, data.dtype)
+
+    def _bilinear(n, c, y, x):
+        outside = tvm.tir.any(y < -1.0, x < -1.0, y > height_f, x > width_f)
+        y = te.min(te.max(y, 0.0), tvm.tir.Cast(dtype, height - 1))
+        x = te.min(te.max(x, 0.0), tvm.tir.Cast(dtype, width - 1))
+        val = bilinear_sample_nhwc(data, (n, y, x, c), height - 1, width - 1)
+        return tvm.tir.if_then_else(outside, zero, val)
+
+    return te.compute(
+        (num_roi, pooled_size_h, pooled_size_w, channel),
+        lambda i, ph, pw, c: _sample_common(
+            i,
+            c,
+            ph,
+            pw,
+            rois,
+            pooled_size_h,
+            pooled_size_w,
+            spatial_scale,
+            sample_ratio,
+            aligned,
+            dtype,
+            avg_mode,
+            _bilinear,
+        ),
+        tag="pool,roi_align_nhwc",
+    )
+
+
+def roi_align(
+    data,
+    rois,
+    pooled_size,
+    spatial_scale,
+    mode="avg",
+    sample_ratio=-1,
+    aligned=False,
+    layout="NCHW",
+):
+    """ROI align operator."""
+    if layout == "NCHW":
+        return roi_align_nchw(data, rois, pooled_size, spatial_scale, mode, sample_ratio, aligned)
+    if layout == "NHWC":
+        return roi_align_nhwc(data, rois, pooled_size, spatial_scale, mode, sample_ratio, aligned)
+    raise ValueError(f"Unsupported layout for roi_align: {layout}")
diff --git a/src/relax/op/vision/roi_align.cc b/src/relax/op/vision/roi_align.cc
new file mode 100644
index 000000000000..ae5185d6d4fa
--- /dev/null
+++ b/src/relax/op/vision/roi_align.cc
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file roi_align.cc + * \brief ROI Align operators. + */ + +#include "roi_align.h" + +#include + +#include + +namespace tvm { +namespace relax { + +TVM_FFI_STATIC_INIT_BLOCK() { ROIAlignAttrs::RegisterReflection(); } + +Expr roi_align(Expr data, Expr rois, ffi::Array pooled_size, double spatial_scale, + int sample_ratio, bool aligned, ffi::String layout, ffi::String mode) { + if (pooled_size.size() == 1) { + pooled_size.push_back(pooled_size[0]); + } + TVM_FFI_ICHECK_EQ(pooled_size.size(), 2) + << "The input pooled_size length is expected to be 2. 
However, the given pooled_size is " + << pooled_size; + + auto attrs = ffi::make_object(); + attrs->pooled_size = std::move(pooled_size); + attrs->spatial_scale = spatial_scale; + attrs->sample_ratio = sample_ratio; + attrs->aligned = aligned; + attrs->layout = layout; + attrs->mode = mode; + + static const Op& op = Op::Get("relax.vision.roi_align"); + return Call(op, {std::move(data), std::move(rois)}, Attrs(attrs), {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.vision.roi_align", roi_align); +} + +StructInfo InferStructInfoROIAlign(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 2) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ROIAlign expects two arguments, while the given number of arguments is " + << call->args.size()); + } + + const auto* data_sinfo = GetStructInfoAs(call->args[0]); + const auto* rois_sinfo = GetStructInfoAs(call->args[1]); + if (data_sinfo == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ROIAlign expects the input data to be a Tensor, while the given data is " + << call->args[0]->GetTypeKey()); + } + if (rois_sinfo == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ROIAlign expects the rois to be a Tensor, while the given rois is " + << call->args[1]->GetTypeKey()); + } + if (!data_sinfo->IsUnknownNdim() && data_sinfo->ndim != 4) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ROIAlign expects the input data to be 4-D, while the given data has ndim " + << data_sinfo->ndim); + } + if (!rois_sinfo->IsUnknownNdim() && rois_sinfo->ndim != 2) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ROIAlign expects the rois tensor to be 2-D, while the given rois has ndim " + << rois_sinfo->ndim); + } + + const auto* attrs = call->attrs.as(); + TVM_FFI_ICHECK(attrs != nullptr) << "Invalid ROIAlign attrs"; + if (attrs->layout != "NCHW" && attrs->layout != "NHWC") { + ctx->ReportFatal(Diagnostic::Error(call) + << "ROIAlign only 
supports NCHW and NHWC layout, but got " << attrs->layout); + } + if (attrs->mode != "avg" && attrs->mode != "max") { + ctx->ReportFatal(Diagnostic::Error(call) + << "ROIAlign only supports avg and max mode, but got " << attrs->mode); + } + + const auto* rois_shape = rois_sinfo->shape.as(); + if (rois_shape != nullptr) { + const auto* last_dim = rois_shape->values[1].as(); + if (last_dim != nullptr && last_dim->value != 5) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ROIAlign expects rois to have shape (num_roi, 5), but got last " + "dimension " + << last_dim->value); + } + } + + if (data_sinfo->shape.as() == nullptr || rois_shape == nullptr) { + return TensorStructInfo(data_sinfo->dtype, 4, data_sinfo->vdevice); + } + + ffi::Array data_shape = data_sinfo->shape.as()->values; + ffi::Array out_shape; + if (attrs->layout == "NCHW") { + out_shape = {rois_shape->values[0], data_shape[1], Integer(attrs->pooled_size[0]), + Integer(attrs->pooled_size[1])}; + } else { + out_shape = {rois_shape->values[0], Integer(attrs->pooled_size[0]), + Integer(attrs->pooled_size[1]), data_shape[3]}; + } + return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); +} + +TVM_REGISTER_OP("relax.vision.roi_align") + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("rois", "Tensor", + "The input rois with shape (num_roi, 5) in [batch_idx, x1, y1, x2, y2] format.") + .set_attr("FInferStructInfo", InferStructInfoROIAlign) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/vision/roi_align.h b/src/relax/op/vision/roi_align.h new file mode 100644 index 000000000000..e2b861ac64bb --- /dev/null +++ b/src/relax/op/vision/roi_align.h @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file roi_align.h + * \brief The functions to make Relax ROI Align operator calls. + */ + +#ifndef TVM_RELAX_OP_VISION_ROI_ALIGN_H_ +#define TVM_RELAX_OP_VISION_ROI_ALIGN_H_ + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! \brief ROI Align operator. 
*/ +Expr roi_align(Expr data, Expr rois, ffi::Array pooled_size, double spatial_scale, + int sample_ratio, bool aligned, ffi::String layout, ffi::String mode); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_VISION_ROI_ALIGN_H_ diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index e1cadb9d0228..5749f8d9423f 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -20,6 +20,7 @@ import numpy as np import pytest import torch +import torchvision from torch import nn from torch.export import export from torch.nn import Module @@ -8746,6 +8747,96 @@ def main( verify_model(GridSample(), example_args, {}, expected) +def test_torchvision_roi_align(): + class ROIAlign(Module): + def forward(self, input, rois): + return torchvision.ops.roi_align( + input, + rois, + output_size=(3, 3), + spatial_scale=1.0, + sampling_ratio=2, + aligned=False, + ) + + @tvm.script.ir_module + class expected: + @R.function + def main( + input_1: R.Tensor((1, 3, 8, 8), dtype="float32"), + rois: R.Tensor((2, 5), dtype="float32"), + ) -> R.Tuple(R.Tensor((2, 3, 3, 3), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((2, 3, 3, 3), dtype="float32") = R.vision.roi_align( + input_1, + rois, + pooled_size=(3, 3), + spatial_scale=1.0, + sample_ratio=2, + layout="NCHW", + mode="avg", + ) + gv: R.Tuple(R.Tensor((2, 3, 3, 3), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = ( + torch.randn(1, 3, 8, 8, dtype=torch.float32), + torch.tensor([[0.0, 1.0, 1.0, 6.0, 6.0], [0.0, 0.5, 0.5, 7.0, 7.0]], dtype=torch.float32), + ) + verify_model(ROIAlign(), example_args, {}, expected) + + +def test_torchvision_roi_align_aligned(): + class ROIAlign(Module): + def forward(self, input, rois): + return torchvision.ops.roi_align( + input, + rois, + output_size=(1, 1), + spatial_scale=1.0, + sampling_ratio=2, + 
aligned=True, + ) + + @tvm.script.ir_module + class expected: + @R.function + def main( + input_1: R.Tensor((1, 1, 4, 4), dtype="float32"), + rois: R.Tensor((1, 5), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")): + R.func_attr({"tir_var_lower_bound": {"s86": 2}}) + with R.dataflow(): + lv1: R.Tensor((1, 1), dtype="float32") = R.strided_slice( + rois, axes=[1], begin=[0], end=[1] + ) + lv2: R.Tensor((1, 4), dtype="float32") = R.strided_slice( + rois, axes=[1], begin=[1], end=[5] + ) + lv3: R.Tensor((1, 4), dtype="float32") = R.subtract(lv2, R.const(0.5, "float32")) + lv4: R.Tensor((1, 5), dtype="float32") = R.concat((lv1, lv3), axis=1) + lv: R.Tensor((1, 1, 1, 1), dtype="float32") = R.vision.roi_align( + input_1, + lv4, + pooled_size=(1, 1), + spatial_scale=1.0, + sample_ratio=2, + aligned=True, + layout="NCHW", + mode="avg", + ) + gv: R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = ( + torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4), + torch.tensor([[0.0, 1.0, 1.0, 1.2, 1.2]], dtype=torch.float32), + ) + verify_model(ROIAlign(), example_args, {}, expected) + + def test_upsample_nearest2d(): class UpsampleNearest2dScale(Module): def forward(self, input): diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 8740720205ea..9316a78fd668 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -3979,6 +3979,7 @@ def test_nms_score_threshold(): tvm_selected[:min_rows], ort_selected[:min_rows], rtol=1e-5, atol=1e-5 ) + @pytest.mark.parametrize("mode", ["bilinear", "nearest", "bicubic"]) @pytest.mark.parametrize("padding_mode", ["zeros", "border", "reflection"]) @pytest.mark.parametrize("align_corners", [0, 1]) @@ -4021,6 +4022,7 @@ def test_grid_sample(mode, padding_mode, align_corners): opset=16, ) + def test_grid_sample_linear_mode_translation(): """Test that ONNX mode='linear' is 
correctly translated to 'bilinear'. @@ -4047,7 +4049,9 @@ def test_grid_sample_linear_mode_translation(): helper.make_tensor_value_info("grid", TensorProto.FLOAT, grid_shape), ], outputs=[ - helper.make_tensor_value_info("Y", TensorProto.FLOAT, [x_shape[0], x_shape[1], grid_shape[1], grid_shape[2]]), + helper.make_tensor_value_info( + "Y", TensorProto.FLOAT, [x_shape[0], x_shape[1], grid_shape[1], grid_shape[2]] + ), ], ) @@ -4082,7 +4086,9 @@ def test_grid_sample_cubic_mode_translation(): helper.make_tensor_value_info("grid", TensorProto.FLOAT, grid_shape), ], outputs=[ - helper.make_tensor_value_info("Y", TensorProto.FLOAT, [x_shape[0], x_shape[1], grid_shape[1], grid_shape[2]]), + helper.make_tensor_value_info( + "Y", TensorProto.FLOAT, [x_shape[0], x_shape[1], grid_shape[1], grid_shape[2]] + ), ], ) @@ -4091,6 +4097,54 @@ def test_grid_sample_cubic_mode_translation(): # Verify 'cubic' was translated to 'bicubic' in the Relax IR assert 'method="bicubic"' in str(tvm_model) + +@pytest.mark.parametrize( + ("coordinate_transformation_mode", "rois"), + [ + ( + "output_half_pixel", + np.array([[1.0, 1.0, 6.0, 6.0], [2.0, 0.5, 7.0, 7.0]], dtype="float32"), + ), + ("half_pixel", np.array([[1.0, 1.0, 1.2, 1.2], [2.0, 0.5, 1.1, 1.1]], dtype="float32")), + ], +) +def test_roi_align(coordinate_transformation_mode, rois): + x_shape = [1, 4, 8, 8] + rois_shape = [2, 4] + batch_indices_shape = [2] + out_shape = [2, 4, 3, 3] + + node = helper.make_node( + "RoiAlign", + inputs=["X", "rois", "batch_indices"], + outputs=["Y"], + output_height=3, + output_width=3, + sampling_ratio=2, + spatial_scale=1.0, + mode="avg", + coordinate_transformation_mode=coordinate_transformation_mode, + ) + + graph = helper.make_graph( + [node], + "roi_align_test", + inputs=[ + helper.make_tensor_value_info("X", TensorProto.FLOAT, x_shape), + helper.make_tensor_value_info("rois", TensorProto.FLOAT, rois_shape), + helper.make_tensor_value_info("batch_indices", TensorProto.INT64, batch_indices_shape), + 
], + outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, out_shape)], + ) + + model = helper.make_model(graph, producer_name="roi_align_test") + inputs = { + "X": rg.standard_normal(size=x_shape).astype("float32"), + "rois": rois, + "batch_indices": np.array([0, 0], dtype="int64"), + } + check_correctness(model, inputs=inputs, opset=16, rtol=1e-5, atol=1e-5) + + if __name__ == "__main__": tvm.testing.main() - diff --git a/tests/python/relax/test_op_vision.py b/tests/python/relax/test_op_vision.py index 753ee14140bf..61c380dae5e3 100644 --- a/tests/python/relax/test_op_vision.py +++ b/tests/python/relax/test_op_vision.py @@ -15,21 +15,228 @@ # specific language governing permissions and limitations # under the License. +import importlib.util +from pathlib import Path + import numpy as np import pytest import tvm import tvm.testing from tvm import TVMError, relax, tirx +from tvm.ir import Op from tvm.relax.transform import LegalizeOps from tvm.script import relax as R +_ROI_ALIGN_PYTHON_SPEC = importlib.util.spec_from_file_location( + "roi_align_python", + Path(__file__).resolve().parents[3] / "python/tvm/topi/testing/roi_align_python.py", +) +_ROI_ALIGN_PYTHON = importlib.util.module_from_spec(_ROI_ALIGN_PYTHON_SPEC) +assert _ROI_ALIGN_PYTHON_SPEC.loader is not None +_ROI_ALIGN_PYTHON_SPEC.loader.exec_module(_ROI_ALIGN_PYTHON) + def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): ret = bb.normalize(call) tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) +def test_roi_align_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + rois = relax.Var("rois", R.Tensor((4, 5), "float32")) + assert relax.op.vision.roi_align(x, rois, (7, 7), 1.0).op == Op.get("relax.vision.roi_align") + + +def test_roi_align_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32")) + rois = 
relax.Var("rois", R.Tensor((5, 5), "float32")) + + _check_inference( + bb, + relax.op.vision.roi_align(x0, rois, (7, 7), 0.25), + relax.TensorStructInfo((5, 3, 7, 7), "float32"), + ) + _check_inference( + bb, + relax.op.vision.roi_align(x1, rois, (5, 7), 1.0, layout="NHWC"), + relax.TensorStructInfo((5, 5, 7, 3), "float32"), + ) + + +def test_roi_align_infer_struct_info_aligned(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + rois = relax.Var("rois", R.Tensor((5, 5), "float32")) + + _check_inference( + bb, + relax.op.vision.roi_align(x, rois, (7, 7), 1.0, aligned=True), + relax.TensorStructInfo((5, 3, 7, 7), "float32"), + ) + + +def test_roi_align_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + n = tirx.Var("n", "int64") + c = tirx.Var("c", "int64") + h = tirx.Var("h", "int64") + w = tirx.Var("w", "int64") + num_roi = tirx.Var("num_roi", "int64") + + x = relax.Var("x", R.Tensor((n, c, h, w), "float32")) + rois = relax.Var("rois", R.Tensor((num_roi, 5), "float32")) + + _check_inference( + bb, + relax.op.vision.roi_align(x, rois, (7, 7), 0.5), + relax.TensorStructInfo((num_roi, c, 7, 7), "float32"), + ) + + +def test_roi_align_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + rois0 = relax.Var("rois", R.Tensor((4,), "float32")) + rois1 = relax.Var("rois", R.Tensor((4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.roi_align(x0, rois1, (7, 7), 1.0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.roi_align(x1, rois0, (7, 7), 1.0)) + + +def test_roi_align_wrong_rois_last_dim(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + rois = relax.Var("rois", R.Tensor((4, 4), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.roi_align(x, rois, (7, 7), 1.0)) + + +def test_roi_align_wrong_layout(): + 
bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + rois = relax.Var("rois", R.Tensor((4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.roi_align(x, rois, (7, 7), 1.0, layout="HWCN")) + + +def test_roi_align_legalize_e2e(): + @tvm.script.ir_module + class ROIAlign: + @R.function + def main( + x: R.Tensor((1, 2, 8, 8), "float32"), + rois: R.Tensor((2, 5), "float32"), + ) -> R.Tensor((2, 2, 3, 3), "float32"): + gv: R.Tensor((2, 2, 3, 3), "float32") = R.vision.roi_align( + x, + rois, + pooled_size=(3, 3), + spatial_scale=1.0, + sample_ratio=2, + layout="NCHW", + mode="avg", + ) + return gv + + mod = LegalizeOps()(ROIAlign) + assert "call_tir" in str(mod) + + x_data = np.arange(1 * 2 * 8 * 8, dtype=np.float32).reshape((1, 2, 8, 8)) + rois_data = np.array( + [ + [0.0, 1.0, 1.0, 6.0, 6.0], + [0.0, 2.0, 2.0, 4.5, 5.5], + ], + dtype=np.float32, + ) + expected = _ROI_ALIGN_PYTHON.roi_align_nchw_python( + x_data, rois_data, pooled_size=(3, 3), spatial_scale=1.0, sample_ratio=2, mode="avg" + ) + + exe = tvm.compile(mod, target="llvm") + vm = relax.VirtualMachine(exe, tvm.cpu()) + result = vm["main"]( + tvm.runtime.tensor(x_data, tvm.cpu()), tvm.runtime.tensor(rois_data, tvm.cpu()) + ) + tvm.testing.assert_allclose(result.numpy(), expected, rtol=1e-5, atol=1e-5) + + +def test_roi_align_legalize_e2e_aligned(): + @tvm.script.ir_module + class ROIAlign: + @R.function + def main( + x: R.Tensor((1, 1, 4, 4), "float32"), + rois: R.Tensor((1, 5), "float32"), + ) -> R.Tensor((1, 1, 1, 1), "float32"): + gv: R.Tensor((1, 1, 1, 1), "float32") = R.vision.roi_align( + x, + rois, + pooled_size=(1, 1), + spatial_scale=1.0, + sample_ratio=2, + aligned=True, + layout="NCHW", + mode="avg", + ) + return gv + + mod = LegalizeOps()(ROIAlign) + x_data = np.arange(16, dtype=np.float32).reshape((1, 1, 4, 4)) + rois_data = np.array([[0.0, 0.5, 0.5, 0.7, 0.7]], dtype=np.float32) + expected = _ROI_ALIGN_PYTHON.roi_align_nchw_python( + 
x_data, + rois_data, + pooled_size=(1, 1), + spatial_scale=1.0, + sample_ratio=2, + mode="avg", + aligned=True, + ) + + exe = tvm.compile(mod, target="llvm") + vm = relax.VirtualMachine(exe, tvm.cpu()) + result = vm["main"]( + tvm.runtime.tensor(x_data, tvm.cpu()), tvm.runtime.tensor(rois_data, tvm.cpu()) + ) + tvm.testing.assert_allclose(result.numpy(), expected, rtol=1e-5, atol=1e-5) + + +def test_roi_align_legalize_sample_ratio_zero(): + @tvm.script.ir_module + class ROIAlign: + @R.function + def main( + x: R.Tensor((1, 2, 8, 8), "float32"), + rois: R.Tensor((1, 5), "float32"), + ) -> R.Tensor((1, 2, 2, 2), "float32"): + gv: R.Tensor((1, 2, 2, 2), "float32") = R.vision.roi_align( + x, + rois, + pooled_size=(2, 2), + spatial_scale=1.0, + sample_ratio=0, + layout="NCHW", + mode="avg", + ) + return gv + + mod = LegalizeOps()(ROIAlign) + assert "call_tir" in str(mod) + tvm.ir.assert_structural_equal( + mod["main"].ret_struct_info, + relax.TensorStructInfo((1, 2, 2, 2), "float32"), + ) + + def test_all_class_non_max_suppression_infer_struct_info(): bb = relax.BlockBuilder() batch_size, num_classes, num_boxes = 10, 8, 5 diff --git a/tests/python/relax/test_tvmscript_parser_op_vision.py b/tests/python/relax/test_tvmscript_parser_op_vision.py index 10817c1dc45d..c4e8ff0c9d22 100644 --- a/tests/python/relax/test_tvmscript_parser_op_vision.py +++ b/tests/python/relax/test_tvmscript_parser_op_vision.py @@ -75,5 +75,37 @@ def foo( _check(foo, bb.get()["foo"]) +def test_roi_align(): + @R.function + def foo( + x: R.Tensor((1, 2, 8, 8), "float32"), + rois: R.Tensor((2, 5), "float32"), + ) -> R.Tensor((2, 2, 3, 3), "float32"): + gv: R.Tensor((2, 2, 3, 3), "float32") = R.vision.roi_align( + x, + rois, + pooled_size=(3, 3), + spatial_scale=1.0, + sample_ratio=2, + layout="NCHW", + mode="avg", + ) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 8, 8), "float32")) + rois = relax.Var("rois", R.Tensor((2, 5), "float32")) + + bb = relax.BlockBuilder() + with bb.function("foo", [x, 
rois]): + gv = bb.emit( + relax.op.vision.roi_align( + x, rois, (3, 3), 1.0, sample_ratio=2, layout="NCHW", mode="avg" + ) + ) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + if __name__ == "__main__": tvm.testing.main() From 06241799585f209ee9cade699835a8191d881eee Mon Sep 17 00:00:00 2001 From: LudovicoYIN Date: Thu, 26 Mar 2026 09:09:15 +0000 Subject: [PATCH 2/2] [Relax][ONNX][Torch] Refine roi_align tests and frontend handling --- .../test_frontend_from_exported_program.py | 33 +---------- tests/python/relax/test_op_vision.py | 57 +++---------------- 2 files changed, 10 insertions(+), 80 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 5749f8d9423f..7a3548b4cfd5 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -8799,42 +8799,11 @@ def forward(self, input, rois): aligned=True, ) - @tvm.script.ir_module - class expected: - @R.function - def main( - input_1: R.Tensor((1, 1, 4, 4), dtype="float32"), - rois: R.Tensor((1, 5), dtype="float32"), - ) -> R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")): - R.func_attr({"tir_var_lower_bound": {"s86": 2}}) - with R.dataflow(): - lv1: R.Tensor((1, 1), dtype="float32") = R.strided_slice( - rois, axes=[1], begin=[0], end=[1] - ) - lv2: R.Tensor((1, 4), dtype="float32") = R.strided_slice( - rois, axes=[1], begin=[1], end=[5] - ) - lv3: R.Tensor((1, 4), dtype="float32") = R.subtract(lv2, R.const(0.5, "float32")) - lv4: R.Tensor((1, 5), dtype="float32") = R.concat((lv1, lv3), axis=1) - lv: R.Tensor((1, 1, 1, 1), dtype="float32") = R.vision.roi_align( - input_1, - lv4, - pooled_size=(1, 1), - spatial_scale=1.0, - sample_ratio=2, - aligned=True, - layout="NCHW", - mode="avg", - ) - gv: R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")) = (lv,) - R.output(gv) - return gv - example_args = ( torch.arange(16, 
dtype=torch.float32).reshape(1, 1, 4, 4), torch.tensor([[0.0, 1.0, 1.0, 1.2, 1.2]], dtype=torch.float32), ) - verify_model(ROIAlign(), example_args, {}, expected) + verify_model_numerically(ROIAlign(), example_args, rtol=1e-5, atol=1e-5) def test_upsample_nearest2d(): diff --git a/tests/python/relax/test_op_vision.py b/tests/python/relax/test_op_vision.py index 61c380dae5e3..b902518b49bb 100644 --- a/tests/python/relax/test_op_vision.py +++ b/tests/python/relax/test_op_vision.py @@ -15,9 +15,6 @@ # specific language governing permissions and limitations # under the License. -import importlib.util -from pathlib import Path - import numpy as np import pytest @@ -28,14 +25,6 @@ from tvm.relax.transform import LegalizeOps from tvm.script import relax as R -_ROI_ALIGN_PYTHON_SPEC = importlib.util.spec_from_file_location( - "roi_align_python", - Path(__file__).resolve().parents[3] / "python/tvm/topi/testing/roi_align_python.py", -) -_ROI_ALIGN_PYTHON = importlib.util.module_from_spec(_ROI_ALIGN_PYTHON_SPEC) -assert _ROI_ALIGN_PYTHON_SPEC.loader is not None -_ROI_ALIGN_PYTHON_SPEC.loader.exec_module(_ROI_ALIGN_PYTHON) - def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): ret = bb.normalize(call) @@ -127,7 +116,7 @@ def test_roi_align_wrong_layout(): bb.normalize(relax.op.vision.roi_align(x, rois, (7, 7), 1.0, layout="HWCN")) -def test_roi_align_legalize_e2e(): +def test_roi_align_legalize(): @tvm.script.ir_module class ROIAlign: @R.function @@ -148,28 +137,13 @@ def main( mod = LegalizeOps()(ROIAlign) assert "call_tir" in str(mod) - - x_data = np.arange(1 * 2 * 8 * 8, dtype=np.float32).reshape((1, 2, 8, 8)) - rois_data = np.array( - [ - [0.0, 1.0, 1.0, 6.0, 6.0], - [0.0, 2.0, 2.0, 4.5, 5.5], - ], - dtype=np.float32, - ) - expected = _ROI_ALIGN_PYTHON.roi_align_nchw_python( - x_data, rois_data, pooled_size=(3, 3), spatial_scale=1.0, sample_ratio=2, mode="avg" - ) - - exe = tvm.compile(mod, target="llvm") - vm = 
relax.VirtualMachine(exe, tvm.cpu()) - result = vm["main"]( - tvm.runtime.tensor(x_data, tvm.cpu()), tvm.runtime.tensor(rois_data, tvm.cpu()) + tvm.ir.assert_structural_equal( + mod["main"].ret_struct_info, + relax.TensorStructInfo((2, 2, 3, 3), "float32"), ) - tvm.testing.assert_allclose(result.numpy(), expected, rtol=1e-5, atol=1e-5) -def test_roi_align_legalize_e2e_aligned(): +def test_roi_align_legalize_aligned(): @tvm.script.ir_module class ROIAlign: @R.function @@ -190,24 +164,11 @@ def main( return gv mod = LegalizeOps()(ROIAlign) - x_data = np.arange(16, dtype=np.float32).reshape((1, 1, 4, 4)) - rois_data = np.array([[0.0, 0.5, 0.5, 0.7, 0.7]], dtype=np.float32) - expected = _ROI_ALIGN_PYTHON.roi_align_nchw_python( - x_data, - rois_data, - pooled_size=(1, 1), - spatial_scale=1.0, - sample_ratio=2, - mode="avg", - aligned=True, - ) - - exe = tvm.compile(mod, target="llvm") - vm = relax.VirtualMachine(exe, tvm.cpu()) - result = vm["main"]( - tvm.runtime.tensor(x_data, tvm.cpu()), tvm.runtime.tensor(rois_data, tvm.cpu()) + assert "call_tir" in str(mod) + tvm.ir.assert_structural_equal( + mod["main"].ret_struct_info, + relax.TensorStructInfo((1, 1, 1, 1), "float32"), ) - tvm.testing.assert_allclose(result.numpy(), expected, rtol=1e-5, atol=1e-5) def test_roi_align_legalize_sample_ratio_zero():