-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add intepolte_v2 for Paddle2.0 #26520
Conversation
Thanks for your contribution! |
input : (N,C,H_in,W_in) | ||
output: (N,C,H_out,W_out) where: | ||
H_out = round(H_{in} * scale_{factor}) | ||
W_out = round(W_{in} * scale_{factor}) | ||
|
||
Bilinear interpolation: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove :alias
from docstring
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Variable
-> Tensor
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
update code samples to 2.0 API
@@ -327,7 +322,10 @@ def interpolate(input, | |||
|
|||
if align_mode != 0 and align_mode != 1: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add fast path for dygraph mode
@@ -327,7 +322,10 @@ def interpolate(input, | |||
|
|||
if align_mode != 0 and align_mode != 1: | |||
raise ValueError("align_mode can only be 0 or 1") | |||
|
|||
if align_corners != 0 and resample == 'NEAREST': | |||
raise ValueError( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also enforce in v2 op.
else: | ||
raise TypeError( | ||
"Attr(scale)'s type should be float, int or Variable.") | ||
|
||
"Attr(scale)'s type should be float, int, list or Variable.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Variable
-> Tensor
https://en.wikipedia.org/wiki/Trilinear_interpolation. | ||
|
||
Parameters: | ||
input (Variable): 3-D, 4-D or 5-D Tensor, its data type is float32, float64, or uint8, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Variable
.. code-block:: python | ||
import paddle | ||
import numpy as np | ||
import paddle.fluid.dygraph as dg |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
update to 2.0 API
python/paddle/nn/layer/common.py
Outdated
This op upsamples a batch of images, using nearest neighbours' pixel values. | ||
The input must be a 4-D Tensor of the shape (num_batches, channels, in_h, in_w), | ||
and the upsampling only applies on the two dimensions(height and width). | ||
**Warning:** the parameter :attr:`actual_shape` will be deprecated in the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove
python/paddle/nn/layer/common.py
Outdated
# [2L, 3L, 12L, 12L] | ||
""" | ||
|
||
def __init__(self, size=None, scale_factor=None, data_format='NCHW'): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
name
python/paddle/nn/layer/common.py
Outdated
self.mode = mode.lower() | ||
self.data_format = data_format | ||
|
||
def forward(self, input): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
input
-> x
inputs=inputs, | ||
outputs={"Out": out}, | ||
attrs=attrs) | ||
return out | ||
|
||
|
||
def upsample(input, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
input
-> x
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Above problems has been fixed, please review again
|
||
input_data = np.random.rand(2,3,6,10).astype("float32") | ||
x = paddle.to_tensor(input_data) | ||
output = F.interpolate(x=x, size=[12,12]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是否需要多举个例子说明不同mode的区别
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done~ thx!
9c80d4f
to
25da67a
Compare
interpolating functions of three variables (e.g. D-direction, | ||
H-direction and W-direction in this op) on a rectilinear 3D grid. | ||
The linear interpolation is performed on three directions. | ||
Align_corners and align_mode are optional parameters,the calculation method |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Align -> align
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, thx~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@XiaoguangHu01 @lanxianghit 因为这个OP不符合精度规范,请问可以豁免么? |
这个问题应该是因为 numpy 和 c++ 在非整数计算插值的时候精度达不到要求。这次没有改旧的OP,之前的版本一直就在白名单里,新版本也需要加上才能过单测。对比过pytorch和paddle的输出结果,精度误差在1e-6可以对齐。 |
确认可以先加入白名单。跟之前的op保持一致。@luotao1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
增补一个pr注释说明一下新增v2 op的考量。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Performance optimization
PR changes
APIs
Describe
Update interpolate API
新增interpolate_v2 原因:
修改了scale_factor参数的类型,没有办法和之前的OP做到兼容。旧的模型参数无法被新的OP加载,因此增加一个V2 版本。