Skip to content

Commit

Permalink
Complete pytorch grid_sample
Browse files Browse the repository at this point in the history
Pytorch's grid_sample() supports various interpolation options:
(1) data dimension: 2D / 3D
(2) interpolation method: nearest / bilinear / bicubic
(3) padding_mode: zeros / border / reflection
(4) align_corners: True / False

However, TVM only supports a part of above options:
(1) data dimension: 2D
(2) interpolation method: bilinear
(3) padding_mode: zeros / border
(4) align_corners: True

This commit completes the options not supported by TVM, and keeps existing
grid_sample of onnx/pytorch uninfluenced.

Co-authored-by:  shukun.net
  • Loading branch information
Ziqiang XU committed Apr 5, 2022
1 parent ae285c6 commit 435f094
Show file tree
Hide file tree
Showing 11 changed files with 1,094 additions and 177 deletions.
35 changes: 28 additions & 7 deletions include/tvm/relay/attrs/image.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,23 +276,44 @@ struct GridSampleAttrs : public tvm::AttrsNode<GridSampleAttrs> {
String method;
String layout;
String padding_mode;
bool align_corners;

TVM_DECLARE_ATTRS(GridSampleAttrs, "relay.attrs.GridSampleAttrs") {
TVM_ATTR_FIELD(method)
.set_default("bilinear")
.describe(
"Specify the mode to use for scaling."
"bilinear - Bilinear Interpolation");
"nearest - 2D or 3D Nearest Interpolation."
"bilinear - '2D Bilinear' or '3D Trilinear' Interpolation."
"bicubic - 2D Bicubic Interpolation.");
TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
"Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Resize is applied on the 'H' and"
"'W' dimensions.");
"Dimension ordering of input data. Can be 'NCHW', 'NCDHW', etc."
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
"dimensions respectively."
"2D Resize is applied on the 'H' and 'W' dimensions."
"3D Resize is applied on the 'D' and 'H' and 'W' dimensions.");
TVM_ATTR_FIELD(padding_mode)
.set_default("zeros")
.describe(
"Specify the padding mode to use."
"zeros, border etc.");
"If :attr:'grid' has values outside the range of '[-1, 1]', the corresponding"
"outputs are handled as defined by padding_mode. Options are"
"padding_mode='zeros': use '0' for out-of-bound grid locations,"
"padding_mode='border': use border values for out-of-bound grid locations"
"padding_mode='reflection': use values at locations reflected by"
"the border for out-of-bound grid locations. For location far away"
"from the border, it will keep being reflected until becoming in bound,"
"e.g., (normalized) pixel location 'x = -3.5' reflects by border '-1'"
"and becomes 'x' = 1.5, then reflects by border '1' and becomes"
"'x' = -0.5");
TVM_ATTR_FIELD(align_corners)
.set_default(true)
.describe(
"Geometrically, we consider the pixels of the"
"input as squares rather than points."
"If set to True, the extrema (-1 and 1) are considered as referring"
"to the center points of the input's corner pixels. If set to False, they"
"are instead considered as referring to the corner points of the input's corner"
"pixels, making the sampling more resolution agnostic.");
}
};

Expand Down
51 changes: 37 additions & 14 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2931,23 +2931,46 @@ def mv(self, inputs, _):
return _op.transform.squeeze(dense_result)

def grid_sampler(self, inputs, input_types):
if inputs[2] == 0:
mode = "bilinear"
interpolate_mode = inputs[2]
padding_mode = inputs[3]
align_corners = inputs[4]
data_shape = self.infer_shape_with_prelude(inputs[0])

if len(data_shape) == 4:
layout = "NCHW"
axes = [0, 3, 1, 2]
grid = _op.transform.transpose(inputs[1], axes)
elif len(data_shape) == 5:
layout = "NCDHW"
axes = [0, 4, 1, 2, 3]
grid = _op.transform.transpose(inputs[1], axes)
else:
msg = "Only bilinear mode is supported in grid_sampler"
raise NotImplementedError(msg)

if inputs[3] == 0:
padding_mode = "zeros"
elif inputs[3] == 1:
padding_mode = "border"
msg = f"only 4D and 5D are supported."
raise ValueError(msg)

if interpolate_mode == 0:
interpolate_str = "bilinear"
elif interpolate_mode == 1:
interpolate_str = "nearest"
elif interpolate_mode == 2:
interpolate_str = "bicubic"
else:
msg = "Only zeros and border padding mode are supported in grid_sampler"
raise NotImplementedError(msg)
msg = f"interpolation method {interpolate_mode} is not supported"
raise ValueError(msg)

if padding_mode == 0:
padding_mode_str = "zeros"
elif padding_mode == 1:
padding_mode_str = "border"
elif padding_mode == 2:
padding_mode_str = "reflection"
else:
msg = f"padding_mode {padding_mode} is not supported"
raise ValueError(msg)

axes = [0, 3, 1, 2]
grid = _op.transform.transpose(inputs[1], axes)
return _op.image.grid_sample(inputs[0], grid, mode, "NCHW", padding_mode)
return _op.image.grid_sample(
inputs[0], grid, interpolate_str, layout, padding_mode_str, align_corners
)

# Operator mappings
def create_convert_map(self):
Expand Down
27 changes: 24 additions & 3 deletions python/tvm/relay/op/image/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,14 +366,17 @@ def compute_grid_sample(attrs, inputs, out_dtype):
method = attrs.method
layout = attrs.layout
padding_mode = attrs.padding_mode
return [topi.image.grid_sample(inputs[0], inputs[1], method, layout, padding_mode)]
align_corners = attrs.align_corners
return [
topi.image.grid_sample(inputs[0], inputs[1], method, layout, padding_mode, align_corners)
]


reg.register_injective_schedule("image.grid_sample")


@script
def _grid_sample_func(data, grid):
def _grid_sample_func_nchw(data, grid):
out = output_tensor((4,), "int64")
out[0] = int64(data[0])
out[1] = int64(data[1])
Expand All @@ -382,9 +385,27 @@ def _grid_sample_func(data, grid):
return out


@script
def _grid_sample_func_ncdhw(data, grid):
out = output_tensor((5,), "int64")
out[0] = int64(data[0])
out[1] = int64(data[1])
out[2] = int64(grid[2])
out[3] = int64(grid[3])
out[4] = int64(grid[4])
return out


@reg.register_shape_func("image.grid_sample", False)
def grid_sample_func(attrs, inputs, _):
"""
Shape function for grid_sample op.
"""
return [_grid_sample_func(inputs[0], inputs[1])]
if attrs.layout == "NCHW":
script_func = _grid_sample_func_nchw
elif attrs.layout == "NCDHW":
script_func = _grid_sample_func_ncdhw
else:
msg = f"layout {attrs.layout} is not supported"
raise ValueError(msg)
return [script_func(inputs[0], inputs[1])]
54 changes: 43 additions & 11 deletions python/tvm/relay/op/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,22 +455,43 @@ def affine_grid(data, target_shape=None):
return _make.affine_grid(data, target_shape)


def grid_sample(data, grid, method="bilinear", layout="NCHW", padding_mode="zeros"):
"""Applies bilinear sampling to input feature map.
def grid_sample(
data, grid, method="bilinear", layout="NCHW", padding_mode="zeros", align_corners=True
):
"""Applies grid sampling to input feature map.
Given :math:`data` and :math:`grid`, then the output is computed by
Given :math:`data` and :math:`grid`, then for 4-D the output is computed by
.. math::
x_{src} = grid[batch, 0, y_{dst}, x_{dst}] \\
y_{src} = grid[batch, 1, y_{dst}, x_{dst}] \\
output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src})
For 5-D, the output is computed by
.. math::
x_{src} = grid[batch, 0, z_{dst}, y_{dst}, x_{dst}] \\
y_{src} = grid[batch, 1, z_{dst}, y_{dst}, x_{dst}] \\
z_{src} = grid[batch, 2, z_{dst}, y_{dst}, x_{dst}] \\
output[batch, channel, z_{src}, y_{dst}, x_{dst}]
= G(data[batch, channel, z_{src}, y_{src}, x_{src})
:math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and
:math:`G()` denotes the interpolation function.
The out-boundary points will be padded with zeros if padding_mode is "zeros".
The out-boundary points will be padded with zeros if padding_mode is `zeros`, or
border pixel value if padding_mode is `border`, or
inner pixel value if padding_mode is `reflection`.
The left-top corner (-1, -1) and right-bottom corner (1, 1) in grid will be map to
(0, 0) and (h - 1, w - 1) of data if align_corners is `True`, or
(-0.5, -0.5) and (h + 0.5, w + 0.5) of data if align_corners is `False`.
The shape of the output will be
(data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]).
4-D (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]), or
5-D (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3], grid.shape[4]).
The operator assumes that :math:`grid` has been normalized to [-1, 1].
Expand All @@ -479,23 +500,34 @@ def grid_sample(data, grid, method="bilinear", layout="NCHW", padding_mode="zero
Parameters
----------
data : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width]
4-D with shape [batch, in_channel, in_height, in_width], or
5-D with shape [batch, in_channel, in_depth, in_height, in_width]
grid : tvm.Tensor
4-D with shape [batch, 2, out_height, out_width]
4-D with shape [batch, 2, out_height, out_width], or
5-D with shape [batch, 3, out_depth, out_height, out_width]
method : str
The interpolation method. Only 'bilinear' is supported.
The interpolation method, 4-D `nearest`, `bilinear`, `bicubic` and
5-D `nearest`, `bilinear`(trilinear) are supported.
layout : str
The layout of input data and the output.
padding_mode : str
The padding mode for outside grid values.
The padding mode for outside grid values, `zeros`, `border`, `reflection` are supported.
align_corners: bool
Geometrically, we consider the pixels of the input as squares rather than points.
If set to `True`, the extrema (`-1` and `1`) are considered as referring
to the center points of the input's corner pixels. If set to `False`, they
are instead considered as referring to the corner points of the input's corner
pixels, making the sampling more resolution agnostic.
Returns
-------
Output : tvm.Tensor
4-D with shape [batch, 2, out_height, out_width]
4-D with shape [batch, in_channel, out_height, out_width], or
5-D with shape [batch, in_channel, out_depth, out_height, out_width]
"""
return _make.grid_sample(data, grid, method, layout, padding_mode)
return _make.grid_sample(data, grid, method, layout, padding_mode, align_corners)
Loading

0 comments on commit 435f094

Please sign in to comment.