Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Complete pytorch grid_sample #10504

Merged
merged 1 commit into from
Apr 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions include/tvm/relay/attrs/image.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,23 +276,44 @@ struct GridSampleAttrs : public tvm::AttrsNode<GridSampleAttrs> {
String method;
String layout;
String padding_mode;
bool align_corners;

TVM_DECLARE_ATTRS(GridSampleAttrs, "relay.attrs.GridSampleAttrs") {
TVM_ATTR_FIELD(method)
.set_default("bilinear")
.describe(
"Specify the mode to use for scaling."
"bilinear - Bilinear Interpolation");
"nearest - 2D or 3D Nearest Interpolation."
"bilinear - '2D Bilinear' or '3D Trilinear' Interpolation."
"bicubic - 2D Bicubic Interpolation.");
TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
"Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Resize is applied on the 'H' and"
"'W' dimensions.");
"Dimension ordering of input data. Can be 'NCHW', 'NCDHW', etc."
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
"dimensions respectively."
"2D Resize is applied on the 'H' and 'W' dimensions."
"3D Resize is applied on the 'D' and 'H' and 'W' dimensions.");
TVM_ATTR_FIELD(padding_mode)
.set_default("zeros")
.describe(
"Specify the padding mode to use."
"zeros, border etc.");
"If :attr:'grid' has values outside the range of '[-1, 1]', the corresponding"
"outputs are handled as defined by padding_mode. Options are"
"padding_mode='zeros': use '0' for out-of-bound grid locations,"
"padding_mode='border': use border values for out-of-bound grid locations"
"padding_mode='reflection': use values at locations reflected by"
"the border for out-of-bound grid locations. For location far away"
"from the border, it will keep being reflected until becoming in bound,"
"e.g., (normalized) pixel location 'x = -3.5' reflects by border '-1'"
"and becomes 'x' = 1.5, then reflects by border '1' and becomes"
"'x' = -0.5");
TVM_ATTR_FIELD(align_corners)
.set_default(true)
.describe(
"Geometrically, we consider the pixels of the"
"input as squares rather than points."
"If set to True, the extrema (-1 and 1) are considered as referring"
"to the center points of the input's corner pixels. If set to False, they"
"are instead considered as referring to the corner points of the input's corner"
"pixels, making the sampling more resolution agnostic.");
}
};

Expand Down
51 changes: 37 additions & 14 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2931,23 +2931,46 @@ def mv(self, inputs, _):
return _op.transform.squeeze(dense_result)

def grid_sampler(self, inputs, input_types):
if inputs[2] == 0:
mode = "bilinear"
interpolate_mode = inputs[2]
padding_mode = inputs[3]
align_corners = inputs[4]
data_shape = self.infer_shape_with_prelude(inputs[0])

if len(data_shape) == 4:
layout = "NCHW"
axes = [0, 3, 1, 2]
grid = _op.transform.transpose(inputs[1], axes)
elif len(data_shape) == 5:
layout = "NCDHW"
axes = [0, 4, 1, 2, 3]
grid = _op.transform.transpose(inputs[1], axes)
else:
msg = "Only bilinear mode is supported in grid_sampler"
raise NotImplementedError(msg)

if inputs[3] == 0:
padding_mode = "zeros"
elif inputs[3] == 1:
padding_mode = "border"
msg = f"only 4D and 5D are supported."
raise ValueError(msg)

if interpolate_mode == 0:
interpolate_str = "bilinear"
elif interpolate_mode == 1:
interpolate_str = "nearest"
elif interpolate_mode == 2:
interpolate_str = "bicubic"
else:
msg = "Only zeros and border padding mode are supported in grid_sampler"
raise NotImplementedError(msg)
msg = f"interpolation method {interpolate_mode} is not supported"
raise ValueError(msg)

if padding_mode == 0:
padding_mode_str = "zeros"
elif padding_mode == 1:
padding_mode_str = "border"
elif padding_mode == 2:
padding_mode_str = "reflection"
else:
msg = f"padding_mode {padding_mode} is not supported"
raise ValueError(msg)

axes = [0, 3, 1, 2]
grid = _op.transform.transpose(inputs[1], axes)
return _op.image.grid_sample(inputs[0], grid, mode, "NCHW", padding_mode)
return _op.image.grid_sample(
inputs[0], grid, interpolate_str, layout, padding_mode_str, align_corners
)

# Operator mappings
def create_convert_map(self):
Expand Down
27 changes: 24 additions & 3 deletions python/tvm/relay/op/image/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,14 +366,17 @@ def compute_grid_sample(attrs, inputs, out_dtype):
method = attrs.method
layout = attrs.layout
padding_mode = attrs.padding_mode
return [topi.image.grid_sample(inputs[0], inputs[1], method, layout, padding_mode)]
align_corners = attrs.align_corners
return [
topi.image.grid_sample(inputs[0], inputs[1], method, layout, padding_mode, align_corners)
]


reg.register_injective_schedule("image.grid_sample")


@script
def _grid_sample_func(data, grid):
def _grid_sample_func_nchw(data, grid):
out = output_tensor((4,), "int64")
out[0] = int64(data[0])
out[1] = int64(data[1])
Expand All @@ -382,9 +385,27 @@ def _grid_sample_func(data, grid):
return out


# Hybrid-script helper that builds the 5-entry output shape for grid_sample;
# grid_sample_func selects it when attrs.layout == "NCDHW".
# NOTE(review): `data` and `grid` are presumably the shape tensors of the two
# operands rather than the operands themselves — confirm against the caller.
@script
def _grid_sample_func_ncdhw(data, grid):
# One int64 entry per output dimension (five for the NCDHW case).
out = output_tensor((5,), "int64")
# Entries 0 and 1 are copied from `data` (batch and channel, going by the layout name).
out[0] = int64(data[0])
out[1] = int64(data[1])
# Entries 2..4 are copied from the same positions of `grid`
# (presumably out_depth, out_height, out_width — confirm).
out[2] = int64(grid[2])
out[3] = int64(grid[3])
out[4] = int64(grid[4])
return out


@reg.register_shape_func("image.grid_sample", False)
def grid_sample_func(attrs, inputs, _):
"""
Shape function for grid_sample op.
"""
return [_grid_sample_func(inputs[0], inputs[1])]
if attrs.layout == "NCHW":
script_func = _grid_sample_func_nchw
elif attrs.layout == "NCDHW":
script_func = _grid_sample_func_ncdhw
else:
msg = f"layout {attrs.layout} is not supported"
raise ValueError(msg)
return [script_func(inputs[0], inputs[1])]
46 changes: 34 additions & 12 deletions python/tvm/relay/op/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,22 +455,33 @@ def affine_grid(data, target_shape=None):
return _make.affine_grid(data, target_shape)


def grid_sample(data, grid, method="bilinear", layout="NCHW", padding_mode="zeros"):
"""Applies bilinear sampling to input feature map.
def grid_sample(
data, grid, method="bilinear", layout="NCHW", padding_mode="zeros", align_corners=True
):
"""Applies grid sampling to input feature map.

Given :math:`data` and :math:`grid`, then the output is computed by
Given :math:`data` and :math:`grid`, the output for 4-D input is computed by

.. math::

x_{src} = grid[batch, 0, y_{dst}, x_{dst}] \\
y_{src} = grid[batch, 1, y_{dst}, x_{dst}] \\
output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src})
output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src}])

:math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and
:math:`G()` denotes the interpolation function.
The out-boundary points will be padded with zeros if padding_mode is "zeros".

Out-of-bound points are padded with zeros if padding_mode is "zeros", with the
border pixel value if padding_mode is "border", or with the pixel value at the location
reflected by the border if padding_mode is "reflection".

The left-top corner (-1, -1) and right-bottom corner (1, 1) in grid will be mapped to
(0, 0) and (h - 1, w - 1) of data if align_corners is "True", or
(-0.5, -0.5) and (h - 0.5, w - 0.5) of data if align_corners is "False".

The shape of the output will be
(data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]).
4-D (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]), or
5-D (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3], grid.shape[4]).

The operator assumes that :math:`grid` has been normalized to [-1, 1].

Expand All @@ -479,23 +490,34 @@ def grid_sample(data, grid, method="bilinear", layout="NCHW", padding_mode="zero
Parameters
----------
data : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width]
4-D with shape [batch, in_channel, in_height, in_width], or
5-D with shape [batch, in_channel, in_depth, in_height, in_width]

grid : tvm.Tensor
4-D with shape [batch, 2, out_height, out_width]
4-D with shape [batch, 2, out_height, out_width], or
5-D with shape [batch, 3, out_depth, out_height, out_width]

method : str
The interpolation method. Only 'bilinear' is supported.
The interpolation method. "nearest", "bilinear" and "bicubic" are supported for 4-D input;
"nearest" and "bilinear" (i.e. trilinear) are supported for 5-D input.

layout : str
The layout of input data and the output.

padding_mode : str
The padding mode for outside grid values.
The padding mode for grid values outside [-1, 1]; "zeros", "border" and "reflection" are supported.

align_corners : bool
Geometrically, we consider the pixels of the input as squares rather than points.
If set to "True", the extrema ("-1" and "1") are considered as referring
to the center points of the input corner pixels. If set to "False", they
are instead considered as referring to the corner points of the input corner
pixels, making the sampling more resolution agnostic.

Returns
-------
Output : tvm.Tensor
4-D with shape [batch, 2, out_height, out_width]
4-D with shape [batch, in_channel, out_height, out_width], or
5-D with shape [batch, in_channel, out_depth, out_height, out_width]
"""
return _make.grid_sample(data, grid, method, layout, padding_mode)
return _make.grid_sample(data, grid, method, layout, padding_mode, align_corners)
Loading