diff --git a/include/tvm/relay/attrs/image.h b/include/tvm/relay/attrs/image.h index be207a2d0593..e0ee6dc748c2 100644 --- a/include/tvm/relay/attrs/image.h +++ b/include/tvm/relay/attrs/image.h @@ -276,23 +276,44 @@ struct GridSampleAttrs : public tvm::AttrsNode { String method; String layout; String padding_mode; + bool align_corners; TVM_DECLARE_ATTRS(GridSampleAttrs, "relay.attrs.GridSampleAttrs") { TVM_ATTR_FIELD(method) .set_default("bilinear") .describe( "Specify the mode to use for scaling." - "bilinear - Bilinear Interpolation"); + "nearest - 2D or 3D Nearest Interpolation." + "bilinear - '2D Bilinear' or '3D Trilinear' Interpolation." + "bicubic - 2D Bicubic Interpolation."); TVM_ATTR_FIELD(layout).set_default("NCHW").describe( - "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Resize is applied on the 'H' and" - "'W' dimensions."); + "Dimension ordering of input data. Can be 'NCHW', 'NCDHW', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively." + "2D Resize is applied on the 'H' and 'W' dimensions." + "3D Resize is applied on the 'D' and 'H' and 'W' dimensions."); TVM_ATTR_FIELD(padding_mode) .set_default("zeros") .describe( - "Specify the padding mode to use." - "zeros, border etc."); + "If :attr:'grid' has values outside the range of '[-1, 1]', the corresponding" + "outputs are handled as defined by padding_mode. Options are" + "padding_mode='zeros': use '0' for out-of-bound grid locations," + "padding_mode='border': use border values for out-of-bound grid locations" + "padding_mode='reflection': use values at locations reflected by" + "the border for out-of-bound grid locations. 
For location far away" + "from the border, it will keep being reflected until becoming in bound," + "e.g., (normalized) pixel location 'x = -3.5' reflects by border '-1'" + "and becomes 'x' = 1.5, then reflects by border '1' and becomes" + "'x' = -0.5"); + TVM_ATTR_FIELD(align_corners) + .set_default(true) + .describe( + "Geometrically, we consider the pixels of the" + "input as squares rather than points." + "If set to True, the extrema (-1 and 1) are considered as referring" + "to the center points of the input's corner pixels. If set to False, they" + "are instead considered as referring to the corner points of the input's corner" + "pixels, making the sampling more resolution agnostic."); } }; diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 9d3980d0f151..f4c09447db3b 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2931,23 +2931,46 @@ def mv(self, inputs, _): return _op.transform.squeeze(dense_result) def grid_sampler(self, inputs, input_types): - if inputs[2] == 0: - mode = "bilinear" + interpolate_mode = inputs[2] + padding_mode = inputs[3] + align_corners = inputs[4] + data_shape = self.infer_shape_with_prelude(inputs[0]) + + if len(data_shape) == 4: + layout = "NCHW" + axes = [0, 3, 1, 2] + grid = _op.transform.transpose(inputs[1], axes) + elif len(data_shape) == 5: + layout = "NCDHW" + axes = [0, 4, 1, 2, 3] + grid = _op.transform.transpose(inputs[1], axes) else: - msg = "Only bilinear mode is supported in grid_sampler" - raise NotImplementedError(msg) - - if inputs[3] == 0: - padding_mode = "zeros" - elif inputs[3] == 1: - padding_mode = "border" + msg = f"only 4D and 5D are supported." 
+ raise ValueError(msg) + + if interpolate_mode == 0: + interpolate_str = "bilinear" + elif interpolate_mode == 1: + interpolate_str = "nearest" + elif interpolate_mode == 2: + interpolate_str = "bicubic" else: - msg = "Only zeros and border padding mode are supported in grid_sampler" - raise NotImplementedError(msg) + msg = f"interpolation method {interpolate_mode} is not supported" + raise ValueError(msg) + + if padding_mode == 0: + padding_mode_str = "zeros" + elif padding_mode == 1: + padding_mode_str = "border" + elif padding_mode == 2: + padding_mode_str = "reflection" + else: + msg = f"padding_mode {padding_mode} is not supported" + raise ValueError(msg) - axes = [0, 3, 1, 2] - grid = _op.transform.transpose(inputs[1], axes) - return _op.image.grid_sample(inputs[0], grid, mode, "NCHW", padding_mode) + return _op.image.grid_sample( + inputs[0], grid, interpolate_str, layout, padding_mode_str, align_corners + ) # Operator mappings def create_convert_map(self): diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py index ec25198adf68..f46a04bd0592 100644 --- a/python/tvm/relay/op/image/_image.py +++ b/python/tvm/relay/op/image/_image.py @@ -366,14 +366,17 @@ def compute_grid_sample(attrs, inputs, out_dtype): method = attrs.method layout = attrs.layout padding_mode = attrs.padding_mode - return [topi.image.grid_sample(inputs[0], inputs[1], method, layout, padding_mode)] + align_corners = attrs.align_corners + return [ + topi.image.grid_sample(inputs[0], inputs[1], method, layout, padding_mode, align_corners) + ] reg.register_injective_schedule("image.grid_sample") @script -def _grid_sample_func(data, grid): +def _grid_sample_func_nchw(data, grid): out = output_tensor((4,), "int64") out[0] = int64(data[0]) out[1] = int64(data[1]) @@ -382,9 +385,27 @@ def _grid_sample_func(data, grid): return out +@script +def _grid_sample_func_ncdhw(data, grid): + out = output_tensor((5,), "int64") + out[0] = int64(data[0]) + out[1] = 
int64(data[1]) + out[2] = int64(grid[2]) + out[3] = int64(grid[3]) + out[4] = int64(grid[4]) + return out + + @reg.register_shape_func("image.grid_sample", False) def grid_sample_func(attrs, inputs, _): """ Shape function for grid_sample op. """ - return [_grid_sample_func(inputs[0], inputs[1])] + if attrs.layout == "NCHW": + script_func = _grid_sample_func_nchw + elif attrs.layout == "NCDHW": + script_func = _grid_sample_func_ncdhw + else: + msg = f"layout {attrs.layout} is not supported" + raise ValueError(msg) + return [script_func(inputs[0], inputs[1])] diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index eb6c316402c6..b5886300cbed 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -455,22 +455,33 @@ def affine_grid(data, target_shape=None): return _make.affine_grid(data, target_shape) -def grid_sample(data, grid, method="bilinear", layout="NCHW", padding_mode="zeros"): - """Applies bilinear sampling to input feature map. +def grid_sample( + data, grid, method="bilinear", layout="NCHW", padding_mode="zeros", align_corners=True +): + """Applies grid sampling to input feature map. - Given :math:`data` and :math:`grid`, then the output is computed by + Given :math:`data` and :math:`grid`, then for 4-D the output is computed by .. math:: x_{src} = grid[batch, 0, y_{dst}, x_{dst}] \\ y_{src} = grid[batch, 1, y_{dst}, x_{dst}] \\ - output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src}) + output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src}]) :math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and :math:`G()` denotes the interpolation function. - The out-boundary points will be padded with zeros if padding_mode is "zeros". 
+ + The out-boundary points will be padded with zeros if padding_mode is "zeros", or + border pixel value if padding_mode is "border", or + inner pixel value if padding_mode is "reflection". + + The left-top corner (-1, -1) and right-bottom corner (1, 1) in grid will be mapped to + (0, 0) and (h - 1, w - 1) of data if align_corners is "True", or + (-0.5, -0.5) and (h - 0.5, w - 0.5) of data if align_corners is "False". + The shape of the output will be - (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]). + 4-D (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]), or + 5-D (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3], grid.shape[4]). The operator assumes that :math:`grid` has been normalized to [-1, 1]. @@ -479,23 +490,34 @@ def grid_sample(data, grid, method="bilinear", layout="NCHW", padding_mode="zero Parameters ---------- data : tvm.Tensor - 4-D with shape [batch, in_channel, in_height, in_width] + 4-D with shape [batch, in_channel, in_height, in_width], or + 5-D with shape [batch, in_channel, in_depth, in_height, in_width] grid : tvm.Tensor - 4-D with shape [batch, 2, out_height, out_width] + 4-D with shape [batch, 2, out_height, out_width], or + 5-D with shape [batch, 3, out_depth, out_height, out_width] method : str - The interpolation method. Only 'bilinear' is supported. + The interpolation method, 4-D "nearest", "bilinear", "bicubic" and + 5-D "nearest", "bilinear"("trilinear") are supported. layout : str The layout of input data and the output. padding_mode : str - The padding mode for outside grid values. + The padding mode for outside grid values, "zeros", "border", "reflection" are supported. + + align_corners: bool + Geometrically, we consider the pixels of the input as squares rather than points. + If set to "True", the extrema ("-1" and "1") are considered as referring + to the center points of the input corner pixels. 
If set to "False", they + are instead considered as referring to the corner points of the input corner + pixels, making the sampling more resolution agnostic. Returns ------- Output : tvm.Tensor - 4-D with shape [batch, 2, out_height, out_width] + 4-D with shape [batch, in_channel, out_height, out_width], or + 5-D with shape [batch, in_channel, out_depth, out_height, out_width] """ - return _make.grid_sample(data, grid, method, layout, padding_mode) + return _make.grid_sample(data, grid, method, layout, padding_mode, align_corners) diff --git a/python/tvm/topi/image/grid_sample.py b/python/tvm/topi/image/grid_sample.py index e3a6dd80405a..705df8db7b54 100644 --- a/python/tvm/topi/image/grid_sample.py +++ b/python/tvm/topi/image/grid_sample.py @@ -59,10 +59,12 @@ def _compute(n, dim, i, j): return te.compute(oshape, _compute, tag="affine_grid") -def grid_sample(data, grid, method="bilinear", layout="NCHW", padding_mode="zeros"): - """Applies bilinear sampling to input feature map. +def _grid_sample_2d( + data, grid, method="bilinear", layout="NCHW", padding_mode="zeros", align_corners=True +): + """Applies bilinear/nearest/bicubic sampling to input feature map. - Given :math:`data` and :math:`grid`, assuming NCHW layout, then the output is computed by + Given :math:`data` and :math:`grid` assuming NCHW layout, then the output is computed by .. math:: @@ -72,9 +74,16 @@ def grid_sample(data, grid, method="bilinear", layout="NCHW", padding_mode="zero :math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and :math:`G()` denotes the interpolation method. - The out-boundary points will be padded with zeros if the padding_mode is "zeros". - The shape of the output will be - (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]). + + The out-boundary points will be padded with zeros if padding_mode is "zeros", or + border pixel value if padding_mode is "border", or + inner pixel value if padding_mode is "reflection". 
+ + The left-top corner (-1, -1) and right-bottom corner (1, 1) in grid will be mapped to + (0, 0) and (h - 1, w - 1) of data if align_corners is "True", or + (-0.5, -0.5) and (h - 0.5, w - 0.5) of data if align_corners is "False". + + The shape of the output will be (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]). The operator assumes that :math:`grid` has been normalized to [-1, 1]. @@ -89,44 +98,99 @@ def grid_sample(data, grid, method="bilinear", layout="NCHW", padding_mode="zero 4-D with shape [batch, 2, out_height, out_width] method : str - The interpolation method. Only 'bilinear' is supported. + The interpolation method "nearest", "bilinear", "bicubic" are supported. layout : str The layout of input data and the output. + padding_mode : str + The padding mode for outside grid values, "zeros", "border", "reflection" are supported. + + align_corners: bool + Geometrically, we consider the pixels of the input as squares rather than points. + If set to "True", the extrema ("-1" and "1") are considered as referring + to the center points of the input corner pixels. If set to "False", they + are instead considered as referring to the corner points of the input corner + pixels, making the sampling more resolution agnostic. 
+ Returns ------- Output : tvm.Tensor 4-D with shape [batch, in_channel, out_height, out_width] """ + + assert method in ("bilinear", "nearest", "bicubic"), f"{method} is not supported" + assert padding_mode in ("zeros", "border", "reflection"), f"{padding_mode} is not supported" + assert layout == "NCHW", f"{layout} is not supported" + batch, in_channel, in_height, in_width = data.shape out_height, out_width = grid.shape[2:] - assert method == "bilinear", "Only bilinear is supported" - assert layout == "NCHW", "Only NCHW is supported" def _get_pixel_value(n, c, h, w): - if padding_mode == "zeros": - return te.if_then_else( - te.all(h >= 0, w >= 0, h < in_height, w < in_width), - data[n, c, h, w], - tir.const(0.0, dtype=data.dtype), + return te.if_then_else( + te.all(h >= 0, w >= 0, h < in_height, w < in_width), + data[n, c, h, w], + tir.const(0.0, dtype=data.dtype), + ) + + def _unnormalize(h, w): + if align_corners: + y = (h + 1) * (in_height - 1) / 2 + x = (w + 1) * (in_width - 1) / 2 + else: + y = -0.5 + (h + 1) * in_height / 2 + x = -0.5 + (w + 1) * in_width / 2 + return (y, x) + + def _clip_coordinates(x, size): + return te.min(te.max(x, 0), size - 1) + + def _compute_source_index(n, h, w): + y = grid[n, 1, h, w] + x = grid[n, 0, h, w] + y, x = _unnormalize(y, x) + + if padding_mode == "reflection": + y = _reflect_coordinates(y, in_height) + x = _reflect_coordinates(x, in_width) + y = _clip_coordinates(y, in_height) + x = _clip_coordinates(x, in_width) + elif padding_mode == "border": + y = _clip_coordinates(y, in_height) + x = _clip_coordinates(x, in_width) + + return (y, x) + + def _reflect_coordinates(x, size): + def __refelection(x, size, corner_start): + def __reflect(index, size, corner_start): + index_align_corner = te.abs(corner_start - index) + size_times = te.truncdiv(index_align_corner.astype("int32"), size).astype("int32") + t = tir.Mod(size_times, 2) + extra = index_align_corner - size_times * size + return tir.if_then_else( + tir.EQ(t, 0), extra 
+ corner_start, size - extra + corner_start + ) + + return tir.if_then_else( + tir.all(x >= corner_start, x <= size + corner_start), + x, + __reflect(x, size, corner_start), ) - if padding_mode == "border": - h_b = te.max(te.min(h, in_height - 1), 0) - w_b = te.max(te.min(w, in_width - 1), 0) - return data[n, c, h_b, w_b] - raise AssertionError("unsupported padding_mode") + if align_corners: + new_x = __refelection(x, size - 1, 0) + else: + new_x = __refelection(x, size, -0.5) + return new_x def _bilinear_sample(n, c, h, w): - x = grid[n, 0, h, w] - y = grid[n, 1, h, w] - y = (y + 1) * (in_height - 1) / 2 - x = (x + 1) * (in_width - 1) / 2 - x0 = te.floor(x).astype("int32") + y, x = _compute_source_index(n, h, w) y0 = te.floor(y).astype("int32") - x1 = x0 + tir.const(1, "int32") + x0 = te.floor(x).astype("int32") y1 = y0 + tir.const(1, "int32") + x1 = x0 + tir.const(1, "int32") + return ( _get_pixel_value(n, c, y0, x0) * (1.0 - (y - y0)) * (1.0 - (x - x0)) + _get_pixel_value(n, c, y0, x1) * (1.0 - (y - y0)) * (x - x0) @@ -134,6 +198,332 @@ def _bilinear_sample(n, c, h, w): + _get_pixel_value(n, c, y1, x1) * (y - y0) * (x - x0) ) + def _nearest_sample(n, c, h, w): + y, x = _compute_source_index(n, h, w) + y_new = te.round(y).astype("int32") + x_new = te.round(x).astype("int32") + + return _get_pixel_value(n, c, y_new, x_new) + + def _bicubic_sample(n, c, h, w): + A = -0.75 # 0.75 is used in pytorch, it maybe different in other frameworks + + def cubic_weight_1(fraction): + return ((A + 2) * fraction - (A + 3)) * fraction * fraction + 1 + + def cubic_weight_2(fraction): + return ((A * fraction - 5 * A) * fraction + 8 * A) * fraction - 4 * A + + def cubic_interp_1d(pixel_0, pixel_1, pixel_2, pixel_3, fraction): + weights = [0] * 4 + weights[0] = cubic_weight_2(fraction + 1) + weights[1] = cubic_weight_1(fraction) + weights[2] = cubic_weight_1(1 - fraction) + weights[3] = cubic_weight_2(2 - fraction) + return ( + pixel_0 * weights[0] + + pixel_1 * weights[1] + + 
pixel_2 * weights[2] + + pixel_3 * weights[3] + ) + + y = grid[n, 1, h, w] + x = grid[n, 0, h, w] + y, x = _unnormalize(y, x) + y_floor = te.floor(y).astype("int32") + x_floor = te.floor(x).astype("int32") + y_fraction = y - y_floor + x_fraction = x - x_floor + + coefficients = [0] * 4 + + for i in range(4): + y_ = y_floor - 1 + i + x_0 = x_floor - 1 + x_1 = x_floor + 0 + x_2 = x_floor + 1 + x_3 = x_floor + 2 + + if padding_mode == "border": + y_ = _clip_coordinates(y_, in_height).astype("int32") + x_0 = _clip_coordinates(x_0, in_width).astype("int32") + x_1 = _clip_coordinates(x_1, in_width).astype("int32") + x_2 = _clip_coordinates(x_2, in_width).astype("int32") + x_3 = _clip_coordinates(x_3, in_width).astype("int32") + + elif padding_mode == "reflection": + y_ = _reflect_coordinates(y_, in_height) + x_0 = _reflect_coordinates(x_0, in_width) + x_1 = _reflect_coordinates(x_1, in_width) + x_2 = _reflect_coordinates(x_2, in_width) + x_3 = _reflect_coordinates(x_3, in_width) + + y_ = _clip_coordinates(y_, in_height).astype("int32") + x_0 = _clip_coordinates(x_0, in_width).astype("int32") + x_1 = _clip_coordinates(x_1, in_width).astype("int32") + x_2 = _clip_coordinates(x_2, in_width).astype("int32") + x_3 = _clip_coordinates(x_3, in_width).astype("int32") + + coefficients[i] = cubic_interp_1d( + _get_pixel_value(n, c, y_, x_0), + _get_pixel_value(n, c, y_, x_1), + _get_pixel_value(n, c, y_, x_2), + _get_pixel_value(n, c, y_, x_3), + x_fraction, + ) + + return cubic_interp_1d( + coefficients[0], coefficients[1], coefficients[2], coefficients[3], y_fraction + ) + + if method == "bilinear": + interpolation = _bilinear_sample + elif method == "nearest": + interpolation = _nearest_sample + else: # method == "bicubic" + interpolation = _bicubic_sample + + return te.compute((batch, in_channel, out_height, out_width), interpolation, tag="grid_sample") + + +def _grid_sample_3d( + data, grid, method="bilinear", layout="NCDHW", padding_mode="zeros", align_corners=True +): + 
"""Applies bilinear/nearest sampling to input feature map. + + Given :math:`data` and :math:`grid` assuming NCDHW layout, then the output is computed by + + .. math:: + + x_{src} = grid[batch, 0, z_{dst}, y_{dst}, x_{dst}] \\ + y_{src} = grid[batch, 1, z_{dst}, y_{dst}, x_{dst}] \\ + z_{src} = grid[batch, 2, z_{dst}, y_{dst}, x_{dst}] \\ + output[batch, channel, z_{dst}, y_{dst}, x_{dst}] + = G(data[batch, channel, z_{src}, y_{src}, x_{src}]) + + :math:`x_{dst}`, :math:`y_{dst}`, :math:`z_{dst}` enumerate all spatial locations + in :math:`output`, and :math:`G()` denotes the interpolation method. + + The out-boundary points will be padded with zeros if padding_mode is "zeros", or + border pixel value if padding_mode is "border", or + inner pixel value if padding_mode is "reflection". + + The left-top corner (-1, -1, -1) and right-bottom corner (1, 1, 1) in grid will be mapped to + (0, 0, 0) and (d - 1, h - 1, w - 1) of data if align_corners is "True", or + (-0.5, -0.5, -0.5) and (d - 0.5, h - 0.5, w - 0.5) of data if align_corners is "False". + + The shape of the output will be + (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3], grid.shape[4]). + + The operator assumes that :math:`grid` has been normalized to [-1, 1]. + + grid_sample often cooperates with affine_grid which generates sampling grids for grid_sample. + + Parameters + ---------- + data : tvm.Tensor + 5-D with shape [batch, in_channel, in_depth, in_height, in_width] + + grid : tvm.Tensor + 5-D with shape [batch, 3, out_depth, out_height, out_width] + + method : str + The interpolation method "nearest", "bilinear"("trilinear") are supported. + + layout : str + The layout of input data and the output. + + padding_mode : str + The padding mode for outside grid values, "zeros", "border", "reflection" are supported. + + align_corners: bool + Geometrically, we consider the pixels of the input as squares rather than points. 
+ If set to "True", the extrema ("-1" and "1") are considered as referring + to the center points of the input corner pixels. If set to "False", they + are instead considered as referring to the corner points of the input corner + pixels, making the sampling more resolution agnostic. + + Returns + ------- + Output : tvm.Tensor + 5-D with shape [batch, in_channel, out_depth, out_height, out_width] + """ + + assert method in ("bilinear", "nearest"), f"{method} is not supported" + assert padding_mode in ("zeros", "border", "reflection"), f"{padding_mode} is not supported" + assert layout == "NCDHW", f"{layout} is not supported" + + batch, in_channel, in_depth, in_height, in_width = data.shape + out_depth, out_height, out_width = grid.shape[2:] + + def _get_pixel_value(n, c, d, h, w): + return te.if_then_else( + te.all(d >= 0, h >= 0, w >= 0, d < in_depth, h < in_height, w < in_width), + data[n, c, d, h, w], + tir.const(0.0, dtype=data.dtype), + ) + + def _compute_source_index(n, d, h, w): + z = grid[n, 2, d, h, w] + y = grid[n, 1, d, h, w] + x = grid[n, 0, d, h, w] + + if align_corners: + z = (z + 1) * (in_depth - 1) / 2 + y = (y + 1) * (in_height - 1) / 2 + x = (x + 1) * (in_width - 1) / 2 + else: + z = -0.5 + (z + 1) * in_depth / 2 + y = -0.5 + (y + 1) * in_height / 2 + x = -0.5 + (x + 1) * in_width / 2 + + if padding_mode == "reflection": + z = _reflect_coordinates(z, in_depth) + y = _reflect_coordinates(y, in_height) + x = _reflect_coordinates(x, in_width) + z = _clip_coordinates(z, in_depth) + y = _clip_coordinates(y, in_height) + x = _clip_coordinates(x, in_width) + elif padding_mode == "border": + z = _clip_coordinates(z, in_depth) + y = _clip_coordinates(y, in_height) + x = _clip_coordinates(x, in_width) + + return (z, y, x) + + def _clip_coordinates(x, size): + return te.min(te.max(x, 0), size - 1) + + def _reflect_coordinates(x, size): + def __refelection(x, size, corner_start): + def __reflect(index, size, corner_start): + index_align_corner = 
te.abs(corner_start - index) + size_times = te.truncdiv(index_align_corner.astype("int32"), size).astype("int32") + t = tir.Mod(size_times, 2) + extra = index_align_corner - size_times * size + return tir.if_then_else( + tir.EQ(t, 0), extra + corner_start, size - extra + corner_start + ) + + return tir.if_then_else( + tir.all(x >= corner_start, x <= size + corner_start), + x, + __reflect(x, size, corner_start), + ) + + if align_corners: + return __refelection(x, size - 1, 0) + return __refelection(x, size, -0.5) + + def _trilinear_sample(n, c, d, h, w): + z, y, x = _compute_source_index(n, d, h, w) + z0 = te.floor(z).astype("int32") + y0 = te.floor(y).astype("int32") + x0 = te.floor(x).astype("int32") + z1 = z0 + tir.const(1, "int32") + y1 = y0 + tir.const(1, "int32") + x1 = x0 + tir.const(1, "int32") + + return ( + _get_pixel_value(n, c, z0, y0, x0) * (1 - (x - x0)) * (1 - (y - y0)) * (1 - (z - z0)) + + _get_pixel_value(n, c, z0, y0, x1) * (x - x0) * (1 - (y - y0)) * (1 - (z - z0)) + + _get_pixel_value(n, c, z1, y1, x0) * (1 - (x - x0)) * (y - y0) * (z - z0) + + _get_pixel_value(n, c, z1, y1, x1) * (x - x0) * (y - y0) * (z - z0) + + _get_pixel_value(n, c, z0, y1, x0) * (1 - (x - x0)) * (y - y0) * (1 - (z - z0)) + + _get_pixel_value(n, c, z1, y0, x1) * (x - x0) * (1 - (y - y0)) * (z - z0) + + _get_pixel_value(n, c, z1, y0, x0) * (1 - (x - x0)) * (1 - (y - y0)) * (z - z0) + + _get_pixel_value(n, c, z0, y1, x1) * (x - x0) * (y - y0) * (1 - (z - z0)) + ) + + def _nearest_sample(n, c, d, h, w): + z, y, x = _compute_source_index(n, d, h, w) + z_new = te.round(z).astype("int32") + y_new = te.round(y).astype("int32") + x_new = te.round(x).astype("int32") + return _get_pixel_value(n, c, z_new, y_new, x_new) + + if method == "bilinear": + interpolation = _trilinear_sample + else: # method == "nearest" + interpolation = _nearest_sample + return te.compute( - (batch, in_channel, out_height, out_width), _bilinear_sample, tag="grid_sample" + (batch, in_channel, out_depth, 
out_height, out_width), interpolation, tag="grid_sample" ) + + +def grid_sample( + data, grid, method="bilinear", layout="NCHW", padding_mode="zeros", align_corners=True +): + """Applies grid sampling to input feature map. + + Given :math:`data` and :math:`grid`, then for 4-D the output is computed by + + .. math:: + + x_{src} = grid[batch, 0, y_{dst}, x_{dst}] \\ + y_{src} = grid[batch, 1, y_{dst}, x_{dst}] \\ + output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src}]) + + :math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and + :math:`G()` denotes the interpolation function. + + The out-boundary points will be padded with zeros if padding_mode is "zeros", or + border pixel value if padding_mode is "border", or + inner pixel value if padding_mode is "reflection". + + The left-top corner (-1, -1) and right-bottom corner (1, 1) in grid will be map to + (0, 0) and (h - 1, w - 1) of data if align_corners is "True", or + (-0.5, -0.5) and (h + 0.5, w + 0.5) of data if align_corners is "False". + + The shape of the output will be + 4-D (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]), or + 5-D (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3], grid.shape[4]). + + The operator assumes that :math:`grid` has been normalized to [-1, 1]. + + grid_sample often cooperates with affine_grid which generates sampling grids for grid_sample. + + Parameters + ---------- + data : tvm.Tensor + 4-D with shape [batch, in_channel, in_height, in_width], or + 5-D with shape [batch, in_channel, in_depth, in_height, in_width] + + grid : tvm.Tensor + 4-D with shape [batch, 2, out_height, out_width], or + 5-D with shape [batch, 3, out_depth, out_height, out_width] + + method : str + The interpolation method, 4-D "nearest", "bilinear", "bicubic" and + 5-D "nearest", "bilinear"("trilinear") are supported. + + layout : str + The layout of input data and the output. 
+ + padding_mode : str + The padding mode for outside grid values, "zeros", "border", "reflection" are supported. + + align_corners: bool + Geometrically, we consider the pixels of the input as squares rather than points. + If set to "True", the extrema ("-1" and "1") are considered as referring + to the center points of the input corner pixels. If set to "False", they + are instead considered as referring to the corner points of the input corner + pixels, making the sampling more resolution agnostic. + + Returns + ------- + Output : tvm.Tensor + 4-D with shape [batch, in_channel, out_height, out_width], or + 5-D with shape [batch, in_channel, out_depth, out_height, out_width] + """ + + if len(layout) == 4: + compute = _grid_sample_2d + elif len(layout) == 5: + compute = _grid_sample_3d + else: + msg = f"layout {layout} is not supported" + raise ValueError(msg) + + return compute(data, grid, method, layout, padding_mode, align_corners) diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index c3d222cfd120..21ddf6fc5536 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -68,7 +68,7 @@ dispatch, ) from .adaptive_pool_python import adaptive_pool -from .grid_sample_python import affine_grid_python, grid_sample_nchw_python +from .grid_sample_python import affine_grid_python, grid_sample_python from .matrix_set_diag import matrix_set_diag from .space_to_batch_nd import space_to_batch_nd_python from .batch_to_space_nd import batch_to_space_nd_python diff --git a/python/tvm/topi/testing/grid_sample_python.py b/python/tvm/topi/testing/grid_sample_python.py index e6b0bef38685..07a7c10d8db2 100644 --- a/python/tvm/topi/testing/grid_sample_python.py +++ b/python/tvm/topi/testing/grid_sample_python.py @@ -29,71 +29,368 @@ def affine_grid_python(data, target_shape): return data.reshape(-1, 3).dot(grid).reshape(data.shape[0], 2, *target_shape) -def _bilinear_sample_nchw_python(data, grid, padding_mode): 
- batch, in_channel, in_height, in_width = data.shape - _, _, out_height, out_width = grid.shape - out = np.zeros((batch, in_channel, out_height, out_width), dtype=data.dtype) - - def _within_bound(y, x): - return 0 <= y < in_height and 0 <= x < in_width - - def compute_padding_mode_zeros(): - for n in range(0, batch): - for h in range(0, out_height): - for w in range(0, out_width): - x, y = grid[n, :, h, w] - y = (y + 1) * (in_height - 1) / 2 - x = (x + 1) * (in_width - 1) / 2 - y0 = int(math.floor(y)) - x0 = int(math.floor(x)) - y1 = y0 + 1 - x1 = x0 + 1 - if _within_bound(y0, x0): - out[n, :, h, w] += data[n, :, y0, x0] * (1.0 - (y - y0)) * (1.0 - (x - x0)) - if _within_bound(y0, x1): - out[n, :, h, w] += data[n, :, y0, x1] * (1.0 - (y - y0)) * (x - x0) - if _within_bound(y1, x0): - out[n, :, h, w] += data[n, :, y1, x0] * (y - y0) * (1.0 - (x - x0)) - if _within_bound(y1, x1): - out[n, :, h, w] += data[n, :, y1, x1] * (y - y0) * (x - x0) - - return out - - def get_pixel_value(x, x_max): - return max(min(x, x_max - 1), 0) - - def compute_padding_mode_border(): - for n in range(0, batch): - for h in range(0, out_height): - for w in range(0, out_width): - x, y = grid[n, :, h, w] - y = (y + 1) * (in_height - 1) / 2 - x = (x + 1) * (in_width - 1) / 2 - y0 = int(math.floor(y)) - x0 = int(math.floor(x)) - y1 = y0 + 1 - x1 = x0 + 1 - y0 = get_pixel_value(y0, in_height) - y1 = get_pixel_value(y1, in_height) - x0 = get_pixel_value(x0, in_width) - x1 = get_pixel_value(x1, in_width) - out[n, :, h, w] = data[n, :, y0, x0] * (1.0 - (y - y0)) * (1.0 - (x - x0)) - out[n, :, h, w] += data[n, :, y0, x1] * (1.0 - (y - y0)) * (x - x0) - out[n, :, h, w] += data[n, :, y1, x0] * (y - y0) * (1.0 - (x - x0)) - out[n, :, h, w] += data[n, :, y1, x1] * (y - y0) * (x - x0) - - return out - - if padding_mode == "zeros": - return compute_padding_mode_zeros() - if padding_mode == "border": - return compute_padding_mode_border() - - raise ValueError("invalid padding_mode") - - -def 
grid_sample_nchw_python(data, grid, method="bilinear", padding_mode="zeros"): +def grid_sample_2d( + data: np.ndarray, + grid: np.ndarray, + method="bilinear", + layout="NCHW", + padding_mode="zeros", + align_corners=True, +): + r"""grid_sample_2d for NCHW layout""" + + assert method in ("bilinear", "nearest", "bicubic"), f"{method} is not supported" + assert layout == "NCHW" + assert padding_mode in ("zeros", "border", "reflection"), f"{padding_mode} is not supported" + assert len(data.shape) == len(grid.shape) == 4 + + batch, channel = data.shape[:2] + in_height, in_width = data.shape[2:] + out_height, out_width = grid.shape[2:] + out_shape = [batch, channel, out_height, out_width] + out = np.zeros(out_shape) + + def _get_pixel(b, c, h, w): + if 0 <= h <= in_height - 1 and 0 <= w <= in_width - 1: + return data[b, c, h, w] + return 0 + + def _unnormalize(h, w): + if align_corners: + new_h = (h + 1) * (in_height - 1) / 2 + new_w = (w + 1) * (in_width - 1) / 2 + else: + new_h = -0.5 + (h + 1) * in_height / 2 + new_w = -0.5 + (w + 1) * in_width / 2 + return (new_h, new_w) + + def _clip_coordinates(x, size): + return min(max(x, 0), size - 1) + + def _reflect_coordinates(i, size): + def __refelection(i, size, corner_start): + def __reflect(index, size, corner_start): + index_align_corner = abs(corner_start - index) + size_times = index_align_corner // size + even = size_times % 2 == 0 + extra = index_align_corner - size_times * size + return extra + corner_start if even else size - extra + corner_start + + if corner_start <= i <= size + corner_start: + new_i = i + else: + new_i = __reflect(i, size, corner_start) + return new_i + + if align_corners: + x = __refelection(i, size - 1, 0) + else: + x = __refelection(i, size, -0.5) + return x + + def _compute_source_index(b, h, w): + y = grid[b, 1, h, w] + x = grid[b, 0, h, w] + y, x = _unnormalize(y, x) + + if padding_mode == "reflection": + y = _reflect_coordinates(y, in_height) + x = _reflect_coordinates(x, in_width) + y 
= _clip_coordinates(y, in_height) + x = _clip_coordinates(x, in_width) + elif padding_mode == "border": + y = _clip_coordinates(y, in_height) + x = _clip_coordinates(x, in_width) + + return (y, x) + + def _nearest_sample(): + for _b in range(batch): + for _c in range(channel): + for _h in range(out_height): + for _w in range(out_width): + y, x = _compute_source_index(_b, _h, _w) + # python round is not used here, + # beacause it is done toward the even choice + new_y = int(y + 0.5) if y > 0 else int(y - 0.5) + new_x = int(x + 0.5) if x > 0 else int(x - 0.5) + out[_b, _c, _h, _w] = _get_pixel(_b, _c, new_y, new_x) + + def _bilinear_sample(): + for _b in range(batch): + for _c in range(channel): + for _h in range(out_height): + for _w in range(out_width): + y, x = _compute_source_index(_b, _h, _w) + y0 = int(math.floor(y)) + x0 = int(math.floor(x)) + y1 = y0 + 1 + x1 = x0 + 1 + + out[_b, _c, _h, _w] = ( + _get_pixel(_b, _c, y0, x0) * (1.0 - (y - y0)) * (1.0 - (x - x0)) + + _get_pixel(_b, _c, y0, x1) * (1.0 - (y - y0)) * (x - x0) + + _get_pixel(_b, _c, y1, x0) * (y - y0) * (1.0 - (x - x0)) + + _get_pixel(_b, _c, y1, x1) * (y - y0) * (x - x0) + ) + + def _bicubic_sample(): + A = -0.75 + + def cubic_weight_1(x_fraction): + return ((A + 2) * x_fraction - (A + 3)) * x_fraction * x_fraction + 1 + + def cubic_weight_2(x_fraction): + return ((A * x_fraction - 5 * A) * x_fraction + 8 * A) * x_fraction - 4 * A + + def cubic_interp_1d(pixel_0, pixel_1, pixel_2, pixel_3, x_fraction): + weights = [0] * 4 + weights[0] = cubic_weight_2(x_fraction + 1) + weights[1] = cubic_weight_1(x_fraction) + weights[2] = cubic_weight_1(1 - x_fraction) + weights[3] = cubic_weight_2(2 - x_fraction) + + return ( + pixel_0 * weights[0] + + pixel_1 * weights[1] + + pixel_2 * weights[2] + + pixel_3 * weights[3] + ) + + def coefficients_along_x(x_floor, y_floor, x_fraction): + coefficients = [0] * 4 + + for i in range(4): + y_ = y_floor - 1 + i + x_0 = x_floor - 1 + x_1 = x_floor + 0 + x_2 = x_floor + 
1 + x_3 = x_floor + 2 + + if padding_mode == "border": + y_ = _clip_coordinates(y_, in_height) + x_0 = _clip_coordinates(x_0, in_width) + x_1 = _clip_coordinates(x_1, in_width) + x_2 = _clip_coordinates(x_2, in_width) + x_3 = _clip_coordinates(x_3, in_width) + + elif padding_mode == "reflection": + y_ = _reflect_coordinates(y_, in_height) + x_0 = _reflect_coordinates(x_0, in_width) + x_1 = _reflect_coordinates(x_1, in_width) + x_2 = _reflect_coordinates(x_2, in_width) + x_3 = _reflect_coordinates(x_3, in_width) + + y_ = int(_clip_coordinates(y_, in_height)) + x_0 = int(_clip_coordinates(x_0, in_width)) + x_1 = int(_clip_coordinates(x_1, in_width)) + x_2 = int(_clip_coordinates(x_2, in_width)) + x_3 = int(_clip_coordinates(x_3, in_width)) + + coefficients[i] = cubic_interp_1d( + _get_pixel(_b, _c, y_, x_0), + _get_pixel(_b, _c, y_, x_1), + _get_pixel(_b, _c, y_, x_2), + _get_pixel(_b, _c, y_, x_3), + x_fraction, + ) + return coefficients + + for _b in range(batch): + for _c in range(channel): + for _h in range(out_height): + for _w in range(out_width): + y = grid[_b, 1, _h, _w] + x = grid[_b, 0, _h, _w] + y, x = _unnormalize(y, x) + y_floor = int(math.floor(y)) + x_floor = int(math.floor(x)) + y_fraction = y - y_floor + x_fraction = x - x_floor + + coefficients = coefficients_along_x(x_floor, y_floor, x_fraction) + + out[_b, _c, _h, _w] = cubic_interp_1d( + coefficients[0], + coefficients[1], + coefficients[2], + coefficients[3], + y_fraction, + ) + if method == "bilinear": - return _bilinear_sample_nchw_python(data, grid, padding_mode) + _bilinear_sample() + elif method == "nearest": + _nearest_sample() + else: # mode == "bicubic": + _bicubic_sample() + + return out + + +def grid_sample_3d( + data: np.ndarray, + grid: np.ndarray, + method="bilinear", + layout="NCDHW", + padding_mode="zeros", + align_corners=True, +): + r"""grid_sample_3d for NCDHW layout""" + + assert method in ("bilinear", "nearest"), f"{method} is not supported" + assert layout == "NCDHW" + 
assert padding_mode in ("zeros", "border", "reflection"), f"{padding_mode} is not supported" + assert len(data.shape) == len(grid.shape) == 5 + + batch, channel = data.shape[:2] + in_depth, in_height, in_width = data.shape[2:] + out_depth, out_height, out_width = grid.shape[2:] + out_shape = [batch, channel, out_depth, out_height, out_width] + out = np.zeros(out_shape) + + def _get_pixel(b, c, d, h, w): + if 0 <= d <= in_depth - 1 and 0 <= h <= in_height - 1 and 0 <= w <= in_width - 1: + return data[b, c, d, h, w] + return 0 + + def _unnormalize(d, h, w): + if align_corners: + new_d = (d + 1) * (in_depth - 1) / 2 + new_h = (h + 1) * (in_height - 1) / 2 + new_w = (w + 1) * (in_width - 1) / 2 + else: + new_d = -0.5 + (d + 1) * in_depth / 2 + new_h = -0.5 + (h + 1) * in_height / 2 + new_w = -0.5 + (w + 1) * in_width / 2 + return (new_d, new_h, new_w) + + def _clip_coordinates(x, size): + return min(max(x, 0), size - 1) + + def _reflect_coordinates(i, size): + def __refelection(i, size, corner_start): + def __reflect(index, size, corner_start): + index_align_corner = abs(corner_start - index) + size_times = index_align_corner // size + even = size_times % 2 == 0 + extra = index_align_corner - size_times * size + return extra + corner_start if even else size - extra + corner_start + + if corner_start <= i <= size + corner_start: + new_i = i + else: + new_i = __reflect(i, size, corner_start) + return new_i + + if align_corners: + x = __refelection(i, size - 1, 0) + else: + x = __refelection(i, size, -0.5) + return x + + def _compute_source_index(b, d, h, w): + z = grid[b, 2, d, h, w] + y = grid[b, 1, d, h, w] + x = grid[b, 0, d, h, w] + z, y, x = _unnormalize(z, y, x) + + if padding_mode == "reflection": + z = _reflect_coordinates(z, in_depth) + y = _reflect_coordinates(y, in_height) + x = _reflect_coordinates(x, in_width) + z = _clip_coordinates(z, in_depth) + y = _clip_coordinates(y, in_height) + x = _clip_coordinates(x, in_width) + elif padding_mode == "border": + z = 
_clip_coordinates(z, in_depth) + y = _clip_coordinates(y, in_height) + x = _clip_coordinates(x, in_width) + return (z, y, x) + + def _nearest_sample(): + for _b in range(batch): + for _c in range(channel): + for _d in range(out_depth): + for _h in range(out_height): + for _w in range(out_width): + z, y, x = _compute_source_index(_b, _d, _h, _w) + # python round is not used here, + # beacause it is done toward the even choice + new_z = int(z + 0.5) if z > 0 else int(z - 0.5) + new_y = int(y + 0.5) if y > 0 else int(y - 0.5) + new_x = int(x + 0.5) if x > 0 else int(x - 0.5) + out[_b, _c, _d, _h, _w] = _get_pixel(_b, _c, new_z, new_y, new_x) + + def _triilinear_sample(): + for _b in range(batch): + for _c in range(channel): + for _d in range(out_depth): + for _h in range(out_height): + for _w in range(out_width): + z, y, x = _compute_source_index(_b, _d, _h, _w) + z0 = int(math.floor(z)) + y0 = int(math.floor(y)) + x0 = int(math.floor(x)) + z1 = z0 + 1 + y1 = y0 + 1 + x1 = x0 + 1 + + out[_b, _c, _d, _h, _w] = ( + _get_pixel(_b, _c, z0, y0, x0) + * (1 - (x - x0)) + * (1 - (y - y0)) + * (1 - (z - z0)) + + _get_pixel(_b, _c, z0, y0, x1) + * (x - x0) + * (1 - (y - y0)) + * (1 - (z - z0)) + + _get_pixel(_b, _c, z1, y1, x0) + * (1 - (x - x0)) + * (y - y0) + * (z - z0) + + _get_pixel(_b, _c, z1, y1, x1) * (x - x0) * (y - y0) * (z - z0) + + _get_pixel(_b, _c, z0, y1, x0) + * (1 - (x - x0)) + * (y - y0) + * (1 - (z - z0)) + + _get_pixel(_b, _c, z1, y0, x1) + * (x - x0) + * (1 - (y - y0)) + * (z - z0) + + _get_pixel(_b, _c, z1, y0, x0) + * (1 - (x - x0)) + * (1 - (y - y0)) + * (z - z0) + + _get_pixel(_b, _c, z0, y1, x1) + * (x - x0) + * (y - y0) + * (1 - (z - z0)) + ) + + if method == "bilinear": + _triilinear_sample() + else: # method == "nearest": + _nearest_sample() + + return out + + +def grid_sample_python( + data: np.ndarray, + grid: np.ndarray, + method="bilinear", + layout="NCHW", + padding_mode="zeros", + align_corners=True, +): + r"""grid_sample_3d for NCDHW layout or 
grid_sample_2d for NCHW layout""" + + if len(data.shape) == 4: + grid_sample = grid_sample_2d + elif len(data.shape) == 5: + grid_sample = grid_sample_3d + else: + raise ValueError("invalid shape") - raise ValueError("invalid method") + return grid_sample(data, grid, method, layout, padding_mode, align_corners) diff --git a/src/relay/op/image/grid_sample.cc b/src/relay/op/image/grid_sample.cc index e0282cc2e8c7..689a71ebc53b 100644 --- a/src/relay/op/image/grid_sample.cc +++ b/src/relay/op/image/grid_sample.cc @@ -103,24 +103,44 @@ bool GridSampleRel(const Array& types, int num_inputs, const Attrs& attrs, if (!data || !grid) return false; const auto* param = attrs.as(); ICHECK(param); - static const Layout kNCHW("NCHW"); const Layout in_layout(param->layout); - auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); - auto oshape = layout_converter.ForwardShape(data->shape); - oshape.Set(2, grid->shape[2]); - oshape.Set(3, grid->shape[3]); - // assign output type - reporter->Assign(types[2], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); - return true; + + if (data->shape.size() == 4) { + static const Layout kNCHW("NCHW"); + auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); + auto oshape = layout_converter.ForwardShape(data->shape); + oshape.Set(2, grid->shape[2]); + oshape.Set(3, grid->shape[3]); + + // assign output type + reporter->Assign(types[2], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); + return true; + } else if (data->shape.size() == 5) { + static const Layout kNDCHW("NCDHW"); + auto layout_converter = tir::BijectiveLayout(in_layout, kNDCHW); + auto oshape = layout_converter.ForwardShape(data->shape); + oshape.Set(2, grid->shape[2]); + oshape.Set(3, grid->shape[3]); + oshape.Set(4, grid->shape[4]); + + // assign output type + reporter->Assign(types[2], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); + return true; + } + + return false; } // Positional relay function to create 
affine_grid operator // used by frontend FFI. -Expr MakeGridSample(Expr data, Expr grid, String method, String layout, String padding_mode) { +Expr MakeGridSample(Expr data, Expr grid, String method, String layout, String padding_mode, + bool align_corners) { auto attrs = make_object(); attrs->method = std::move(method); attrs->layout = std::move(layout); attrs->padding_mode = std::move(padding_mode); + attrs->align_corners = std::move(align_corners); + static const Op& op = Op::Get("image.grid_sample"); return Call(op, {data, grid}, Attrs(attrs), {}); } @@ -133,29 +153,51 @@ RELAY_REGISTER_OP("image.grid_sample") Given :math:`data` and :math:`grid`, then the output is computed by .. math:: + x_{src} = grid[batch, 0, y_{dst}, x_{dst}] \\ y_{src} = grid[batch, 1, y_{dst}, x_{dst}] \\ - output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src}) + output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src}]) + +For 5-D, the output is computed by + +.. math:: + + x_{src} = grid[batch, 0, z_{dst}, y_{dst}, x_{dst}] \\ + y_{src} = grid[batch, 1, z_{dst}, y_{dst}, x_{dst}] \\ + z_{src} = grid[batch, 2, z_{dst}, y_{dst}, x_{dst}] \\ + output[batch, channel, z_{dst}, y_{dst}, x_{dst}] + = G(data[batch, channel, z_{src}, y_{src}, x_{src}]) :math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and :math:`G()` denotes the interpolation function. -The out-boundary points will be padded with zeros. The shape of the output will be -(data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]). -The operator assumes that :math:`data` has 'NCHW' layout and :math:`grid` has been normalized to [-1, 1]. +The out-boundary points will be padded with zeros if padding_mode is "zeros", or +border pixel value if padding_mode is "border", or +inner pixel value if padding_mode is "reflection". 
+ +The left-top corner (-1, -1) and right-bottom corner (1, 1) in grid will be mapped to +(0, 0) and (h - 1, w - 1) of data if align_corners is "True", or +(-0.5, -0.5) and (h - 0.5, w - 0.5) of data if align_corners is "False". + +The shape of the output will be +4-D (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]), or +5-D (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3], grid.shape[4]). + +The operator assumes that :math:`data` has 'NCHW' or 'NCDHW' layout and :math:`grid` has been normalized to [-1, 1]. grid_sample often cooperates with affine_grid which generates sampling grids for grid_sample. -- **data**: data is 4D array of shape - (batch_size, channels, in_height, in_width) for NCHW - (batch_size, in_height, in_width, channels) for NHWC +- **data**: data is of 4-D shape (batch_size, channels, in_height, in_width), or + of 5-D shape (batch_size, channels, in_depth, in_height, in_width) -- **grid**: grid is 4D array of shape [batch, 2, out_height, out_width], where each vector - :math:`out[b, :, h, w]` represents the coordinate :math:`(x, y)` +- **grid**: grid is of 4-D shape [batch, 2, out_height, out_width] + where each vector :math:`grid[b, :, h, w]` represents the coordinate :math:`(x, y)`, + or of 5-D shape [batch, 3, out_depth, out_height, out_width] + where each vector :math:`grid[b, :, d, h, w]` represents the coordinate + :math:`(x, y, z)` -- **out**: out is 4D array of shape - (batch, in_channel, out_height, out_width) for NCHW - (batch_size, in_height, in_width, channels) for NHWC +- **out**: out is of 4-D shape (batch, in_channel, out_height, out_width), or + of 5-D shape [batch, channel, out_depth, out_height, out_width] )code" TVM_ADD_FILELINE) .set_num_inputs(2) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 14705209b464..4cdf55ba34d0 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -16,6 +16,7 @@ # under the License. 
# pylint: disable=import-self, invalid-name, unused-argument """Unit tests for various models and operators""" +from contextlib import suppress import os import sys from time import time @@ -4171,22 +4172,44 @@ def test_fn(m, v): def test_grid_sample(): - class Grid_sample_zeros(Module): - def forward(self, x, y): - return torch.nn.functional.grid_sample( - input=x, grid=y, mode="bilinear", padding_mode="zeros", align_corners=True - ) + class Grid_sample(Module): + def __init__(self, method, padding_mode, align_corners): + super().__init__() + self._method = method + self._padding_mode = padding_mode + self._align_corners = align_corners - class Grid_sample_border(Module): def forward(self, x, y): return torch.nn.functional.grid_sample( - input=x, grid=y, mode="bilinear", padding_mode="border", align_corners=True + input=x, + grid=y, + mode=self._method, + padding_mode=self._padding_mode, + align_corners=self._align_corners, ) - data = torch.rand([4, 4, 16, 32]).float() - grid = torch.rand([4, 8, 8, 2]).float() - verify_model(Grid_sample_zeros(), input_data=[data, grid]) - verify_model(Grid_sample_border(), input_data=[data, grid]) + methods = ["nearest", "bilinear", "bicubic"] + padding_modes = ["zeros", "border", "reflection"] + align_corners = [True, False] + + data_2D = torch.rand([4, 4, 8, 8]).float() + grid_2D = torch.rand([4, 16, 16, 2]).float() + data_3D = torch.rand([4, 4, 8, 8, 8]).float() + grid_3D = torch.rand([4, 16, 16, 16, 3]).float() + + for _method in methods: + for _padding in padding_modes: + for _align in align_corners: + # ATTENTION: + # "nearest" + "reflection" result may be different with pytorch on cpu device, + # because pytorch's cpu result is different with gpu result, + # and gpu result used here as baseline in tvm topi.image.grid_sample. 
+ model = Grid_sample(_method, _padding, _align) + verify_model(model, input_data=[data_2D, grid_2D]) + + # 3D "bicubic"(tricubic) is not supported in pytorch + if _method != "bicubic": + verify_model(model, input_data=[data_3D, grid_3D]) def test_list_tuple(): diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index f162917974a8..10cd91415724 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -1397,23 +1397,42 @@ def verify_affine_grid(num_batch, target_shape): @tvm.testing.uses_gpu def test_grid_sample(): - def verify_grid_sample(data_shape, grid_shape, padding_mode="zeros"): + def verify_grid_sample( + data_shape, grid_shape, method="bilinear", padding_mode="zeros", align_corners=True + ): dtype = "float32" - batch, channel, _, _ = data_shape - _, _, out_height, out_width = grid_shape data = relay.var("data", relay.ty.TensorType(data_shape, dtype)) grid = relay.var("grid", relay.ty.TensorType(grid_shape, dtype)) + + if len(data_shape) == 4: + layout = "NCHW" + batch, channel, _, _ = data_shape + _, _, out_height, out_width = grid_shape + tensor_type = relay.TensorType((batch, channel, out_height, out_width), dtype) + else: # len(data_shape) == 5: + layout = "NCDHW" + batch, channel, _, _, _ = data_shape + _, _, out_depth, out_height, out_width = grid_shape + tensor_type = relay.TensorType( + (batch, channel, out_depth, out_height, out_width), dtype + ) + y = relay.image.grid_sample( - data, grid, method="bilinear", layout="NCHW", padding_mode=padding_mode + data, + grid, + method=method, + layout=layout, + padding_mode=padding_mode, + align_corners=align_corners, ) yy = run_infer_type(y) - assert yy.checked_type == relay.TensorType((batch, channel, out_height, out_width), dtype) + assert yy.checked_type == tensor_type func = relay.Function([data, grid], y) data_np = np.random.uniform(size=data_shape).astype(dtype) grid_np = np.random.uniform(size=grid_shape, low=-1.5, 
high=1.5).astype(dtype) - ref_res = tvm.topi.testing.grid_sample_nchw_python( - data_np, grid_np, method="bilinear", padding_mode=padding_mode + ref_res = tvm.topi.testing.grid_sample_python( + data_np, grid_np, method, layout, padding_mode, align_corners ) for target, dev in tvm.testing.enabled_targets(): @@ -1423,10 +1442,23 @@ def verify_grid_sample(data_shape, grid_shape, padding_mode="zeros"): ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) - verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8)) - verify_grid_sample((4, 4, 16, 32), (4, 2, 32, 32)) - verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8), "border") - verify_grid_sample((4, 4, 16, 32), (4, 2, 32, 32), "border") + methods = ["nearest", "bilinear", "bicubic"] + padding_modes = ["zeros", "border", "reflection"] + align_corners = [True, False] + + data_2D_shape = (4, 4, 8, 8) + grid_2D_shape = (4, 2, 16, 16) + data_3D_shape = (4, 4, 8, 8, 8) + grid_3D_shape = (4, 3, 16, 16, 16) + + for _method in methods: + for _padding in padding_modes: + for _align in align_corners: + verify_grid_sample(data_2D_shape, grid_2D_shape, _method, _padding, _align) + + # 3D "bicubic"(tricubic) is not supported in pytorch + if _method != "bicubic": + verify_grid_sample(data_3D_shape, grid_3D_shape, _method, _padding, _align) @tvm.testing.uses_gpu diff --git a/tests/python/topi/python/test_topi_image.py b/tests/python/topi/python/test_topi_image.py index 9f4b67354075..3aedc8ef4399 100644 --- a/tests/python/topi/python/test_topi_image.py +++ b/tests/python/topi/python/test_topi_image.py @@ -274,19 +274,26 @@ def check_target(target, dev): @tvm.testing.uses_gpu def test_grid_sample(): - def verify_grid_sample(data_shape, grid_shape, padding_mode="zeros"): + def verify_grid_sample( + data_shape, + grid_shape, + method="bilinear", + layout="NCHW", + padding_mode="zeros", + align_corners=True, + ): dtype = "float32" data = te.placeholder(data_shape, dtype=dtype) grid = te.placeholder(grid_shape, dtype=dtype) 
- out = topi.image.grid_sample(data, grid, "bilinear", padding_mode=padding_mode) + out = topi.image.grid_sample(data, grid, method, layout, padding_mode, align_corners) @memoize("topi.tests.test_grid_sample.verify_grid_sample") def get_ref_data(): data_np = np.random.uniform(size=data_shape).astype(dtype) # allow grid values to be out-of-bound grid_np = np.random.uniform(size=grid_shape, low=-1.5, high=1.5).astype(dtype) - out_np = tvm.topi.testing.grid_sample_nchw_python( - data_np, grid_np, "bilinear", padding_mode + out_np = tvm.topi.testing.grid_sample_python( + data_np, grid_np, method, layout, padding_mode, align_corners ) return data_np, grid_np, out_np @@ -307,9 +314,28 @@ def check_target(target, dev): for target, dev in tvm.testing.enabled_targets(): check_target(target, dev) - verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8)) - verify_grid_sample((4, 4, 16, 32), (4, 2, 32, 32), "border") - verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8), "border") + methods = ["nearest", "bilinear", "bicubic"] + padding_modes = ["zeros", "border", "reflection"] + align_corners = [True, False] + data_2D_shape = (4, 4, 8, 8) + grid_2D_shape = (4, 2, 16, 16) + layout_2D = "NCHW" + data_3D_shape = (4, 4, 8, 8, 8) + grid_3D_shape = (4, 3, 16, 16, 16) + layout_3D = "NCDHW" + + for _method in methods: + for _padding in padding_modes: + for _align in align_corners: + verify_grid_sample( + data_2D_shape, grid_2D_shape, _method, layout_2D, _padding, _align + ) + + # 3D "bicubic"(tricubic) is not supported in pytorch + if _method != "bicubic": + verify_grid_sample( + data_3D_shape, grid_3D_shape, _method, layout_3D, _padding, _align + ) if __name__ == "__main__":