[Bug] warp_perspective does not return the input if used with identity matrix #747

Closed
pmeier opened this issue Oct 26, 2020 · 20 comments · Fixed by #896
Labels
1 Priority 1 🚨 High priority bug 🐛 Something isn't working help wanted Extra attention is needed module: geometry

Comments

@pmeier
Contributor

pmeier commented Oct 26, 2020

🐛 Bug

The output of kornia.geometry.warp_perspective does not equal the input if the identity matrix is used.

To Reproduce

Steps to reproduce the behavior:

import torch
from kornia.geometry import warp_perspective

torch.manual_seed(0)

dsize = (32, 16)
src = torch.rand(1, 3, *dsize)
M = torch.eye(3).unsqueeze(0)

dst = warp_perspective(src, M, dsize)

mae = torch.mean(torch.abs(dst - src))
print(mae.item())
0.14952071011066437

Expected behavior

0.0

Environment

kornia==0.4.1

@edgarriba
Member

@pmeier I can confirm the same happens in 0.4.0, 0.3.2, and 0.3.1, but not in 0.3.0.

@edgarriba edgarriba changed the title warp_perspective does not return the input if used with identity matrix [Bug] warp_perspective does not return the input if used with identity matrix Oct 26, 2020
@edgarriba edgarriba added bug 🐛 Something isn't working help wanted Extra attention is needed 1 Priority 1 🚨 High priority module: geometry labels Oct 26, 2020
@edgarriba
Member

@pmeier align_corners should be set to True. The change of default was introduced in #574.

See: https://github.com/kornia/kornia/blob/master/kornia/geometry/transform/imgwarp.py#L48 where it is now set to False by default.
Not sure how much this will affect other components.

This definitely needs better testing: https://github.com/kornia/kornia/blob/master/test/geometry/transform/test_imgwarp.py#L12

@stale

stale bot commented Dec 26, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions, and happy coding day 😎

@stale stale bot added the wontfix This will not be worked on label Dec 26, 2020
@pmeier
Contributor Author

pmeier commented Dec 26, 2020

@edgarriba This should not be closed by the bot, since the bug won't go away if no one has posted anything new.

@stale

stale bot commented Feb 24, 2021

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions, and happy coding day 😎

@stale stale bot added the wontfix This will not be worked on label Feb 24, 2021
@edgarriba edgarriba removed the wontfix This will not be worked on label Feb 25, 2021
@edgarriba edgarriba linked a pull request Mar 7, 2021 that will close this issue
22 tasks
@ducha-aiki
Member

ducha-aiki commented Mar 10, 2021

@pmeier
Hi, that is indeed a complex issue. The problem is the following:

When the transform is "non-destructive", i.e. it results in sampling ONLY integer locations from the original image (like identity, rotation by 90 degrees, or a shift by an integer number of pixels), the optimal and correct way is to use "align_corners=True", so the data is not changed.
However, for "destructive" transforms, which do require interpolation, align_corners=False is a better choice; see the image from here:
https://discuss.pytorch.org/t/what-we-should-use-align-corners-false/22663/9?u=ducha-aiki

[Image: the "blue dots" diagram from the linked PyTorch forum post]

Probably the best solution would be to detect which kind of transform we get by checking the warp matrix, but that is non-trivial to implement. We will be glad to receive help here :)

In the meantime, I would recommend passing the "align_corners" flag manually.
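
As a rough illustration of the "detect the kind of transform" idea, here is a minimal brute-force sketch. It is not part of Kornia: the helper name `samples_only_integer_pixels` is hypothetical, and it assumes M is a 3x3 matrix mapping source pixel coordinates to destination pixel coordinates, so that the sampling locations are obtained by applying the inverse of M to every destination pixel.

```python
import numpy as np


def samples_only_integer_pixels(M, dsize, atol=1e-6):
    """Hypothetical helper: True if warping with M reads only integer source pixels."""
    h_out, w_out = dsize
    # Homogeneous pixel coordinates (x, y, 1) of every destination pixel.
    ys, xs = np.mgrid[0:h_out, 0:w_out]
    dst = np.stack([xs.ravel(), ys.ravel(), np.ones(xs.size)], axis=0).astype(np.float64)
    # Assumption: M maps source -> destination, so sampling locations are M^-1 @ dst.
    src = np.linalg.inv(M) @ dst
    src_xy = src[:2] / src[2:3]
    # "Non-destructive" here means every sampling location is (close to) an integer.
    return bool(np.all(np.abs(src_xy - np.round(src_xy)) < atol))


identity = np.eye(3)
shift = np.array([[1.0, 0.0, 2.0], [0.0, 1.0, -1.0], [0.0, 0.0, 1.0]])
rot90 = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
scale2 = np.diag([2.0, 2.0, 1.0])

for name, M in [("identity", identity), ("shift", shift), ("rot90", rot90), ("scale2", scale2)]:
    print(name, samples_only_integer_pixels(M, (4, 4)))
```

Checking every destination pixel is wasteful for large images; it is only meant to make the distinction between the two kinds of transform concrete.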

@pmeier
Contributor Author

pmeier commented Mar 10, 2021

@ducha-aiki Thanks for the explanation. Given that this leads to a lot of confusion, I suggest adding an explanation, probably with more details and examples, to the documentation so that users can make an informed choice on how to set align_corners.

@shijianjian
Member

I agree with @pmeier that we might add it to the docs to make users more aware of how the parameter choice changes the result. This pic is a nice demonstration.

@edgarriba
Member

A separate page in the docs, as a side note that the functions link to, would probably fit. @ducha-aiki or @shijianjian could you go for this?

@shijianjian
Member

Probably add it to the homepage of the geometry module: https://kornia.readthedocs.io/en/latest/geometry.html#

@anibali
Contributor

anibali commented Sep 30, 2021

I don't see how this issue has been resolved. Doesn't the change of defaults to align_corners=True just hide the fact that there's a problem in warp_perspective with align_corners=False?

Consider the original scenario of warping with the identity matrix and align_corners=False. This works as intended for warp_affine (returning the input unchanged) but not for warp_perspective. Digging into the source code I think the issue is to do with the grid creation.

For warp_affine, the grid is generated with align_corners taken into consideration:

grid = F.affine_grid(src_norm_trans_dst_norm[:, :2, :], [B, C, dsize[0], dsize[1]], align_corners=align_corners)

For warp_perspective, align_corners is not considered at all:

# this piece of code substitutes F.affine_grid since it does not support 3x3
grid = (
    create_meshgrid(h_out, w_out, normalized_coordinates=True, device=src.device).to(src.dtype).repeat(B, 1, 1, 1)
)
grid = transform_points(src_norm_trans_dst_norm[:, None, None], grid)

Basically Kornia is not applying the "correction" for when align_corners=False, as can be seen here in the PyTorch source: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AffineGridGenerator.cpp#L12-L14
I could be mistaken, but isn't this the problem?
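
To make the "correction" concrete, here is a small PyTorch-only sketch (no Kornia involved) that compares the grids F.affine_grid produces for an identity transform under both settings, and then feeds a [-1, 1] grid to F.grid_sample with align_corners=False. The 3x3 source values are borrowed from the test script later in this thread; the variable names are mine.

```python
import torch
import torch.nn.functional as F

src = torch.tensor([[0., 2., 4.],
                    [2., 4., 6.],
                    [4., 6., 8.]])[None, None]
theta = torch.eye(2, 3).unsqueeze(0)  # identity affine transform, shape (1, 2, 3)

# F.affine_grid shrinks the grid by (n - 1) / n when align_corners=False.
grid_true = F.affine_grid(theta, list(src.shape), align_corners=True)
grid_false = F.affine_grid(theta, list(src.shape), align_corners=False)
print(grid_true[0, 0, :, 0])   # x coordinates of the first output row
print(grid_false[0, 0, :, 0])

# Grid and flag agree: the identity transform reproduces the input.
out_ok = F.grid_sample(src, grid_false, align_corners=False)
# A grid spanning [-1, 1] combined with align_corners=False: the input is not reproduced.
out_bad = F.grid_sample(src, grid_true, align_corners=False)
print(out_ok[0, 0])
print(out_bad[0, 0])
```

With the uncorrected grid, the outermost samples fall half a pixel outside the image and get blended with the zero padding, which is consistent with the darkened border reported for warp_perspective with align_corners=False.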

@edgarriba
Member

@anibali possibly yes - what would you suggest here ?

@Multihuntr

I was investigating how the different libraries handle the sampling issues of interpolation, and I wrote a test script that compares PyTorch, OpenCV and Kornia on a few different settings. There seem to be a few inconsistencies between Kornia and the others. I apologise, but I don't have a solution. Could it be an issue in the PyTorch functions you call?

  • With identity matrix:
    • source image == cv2.warpAffine == cv2.warpPerspective == cv2.resize == torch.nn.Upsample (regardless of align_corners)
    • With align_corners=False:
      • kornia.geometry.warp_affine == source ✔️
      • kornia.geometry.warp_perspective cuts off bottom right and uses border constant of 0. ❗
    • With align_corners=True:
      • kornia.geometry.warp_affine testing a 3x3 stretches top-left 2x2 into the space of 3x3. I suspect an off-by-one in indexing. ❗
      • kornia.geometry.warp_perspective == source ✔️
  • With affine transformation scaling by factor of 2:
    • cv2.warpAffine == cv2.warpPerspective
    • With align_corners=False:
      • cv2.resize == torch.nn.Upsample which is not quite the same as the diagram with blue dots implies;
      • kornia.geometry.warp_affine != kornia.geometry.warp_perspective != cv2.warpAffine:
        • kornia.geometry.warp_affine cuts off bottom right and uses border constant of 0 ❗
        • kornia.geometry.warp_perspective cuts off bottom right, ends up with a border of 0s ❗
    • With align_corners=True:
      • torch.nn.Upsample as per the diagram with the blue dots.
      • kornia.geometry.warp_affine == kornia.geometry.warp_perspective == cv2.warpAffine == cv2.warpPerspective ✔️
        • What's the justification for calling this "align_corners=True"? It's not using a fixed spacing between sampled points like in torch.nn.Upsample.
Here's my code:
import torch
import torch.nn
import torch.nn.functional as F

import kornia.geometry
import cv2
import numpy as np

print('Kornia version', kornia.__version__)

src = torch.tensor([
    [ 0,  2,  4,],
    [ 2,  4,  6,],
    [ 4,  6,  8,],
], dtype=torch.float32)
src_th = src[None, None]
src_cv2 = src.numpy()[..., None]
size = np.array(src.shape)

M1 = torch.eye(3).unsqueeze(0)
S = 2
M2 = torch.tensor([
    [S, 0, 0],
    [0, S, 0],
    [0, 0, 1],
], dtype=torch.float32).unsqueeze(0)


kornia_aff_1F = kornia.geometry.warp_affine(src_th, M1[:, :2], size, align_corners=False)
# Note: this call passes M2 (the 2x scale), not the identity M1.
kornia_aff_1T = kornia.geometry.warp_affine(src_th, M2[:, :2], size, align_corners=True)
kornia_aff_2F = kornia.geometry.warp_affine(src_th, M2[:, :2], size*S, align_corners=False)
kornia_aff_2T = kornia.geometry.warp_affine(src_th, M2[:, :2], size*S, align_corners=True)
kornia_per_1F = kornia.geometry.warp_perspective(src_th, M1, size, align_corners=False)
kornia_per_1T = kornia.geometry.warp_perspective(src_th, M1, size, align_corners=True)
kornia_per_2F = kornia.geometry.warp_perspective(src_th, M2, size*S, align_corners=False)
kornia_per_2T = kornia.geometry.warp_perspective(src_th, M2, size*S, align_corners=True)

th_ups_1F = torch.nn.Upsample([*size], mode='bilinear', align_corners=False)(src_th)
th_ups_1T = torch.nn.Upsample([*size], mode='bilinear', align_corners=True)(src_th)
th_ups_2F = torch.nn.Upsample([*(size*S)], mode='bilinear', align_corners=False)(src_th)
th_ups_2T = torch.nn.Upsample([*(size*S)], mode='bilinear', align_corners=True)(src_th)

cv2_aff_1 = cv2.warpAffine(src_cv2, M1[0, :2].numpy(), size, flags=cv2.INTER_LINEAR)
cv2_aff_2 = cv2.warpAffine(src_cv2, M2[0, :2].numpy(), size*S, flags=cv2.INTER_LINEAR)
cv2_res_1 = cv2.resize(src_cv2, size)
cv2_res_2 = cv2.resize(src_cv2, size*S)
cv2_per_1 = cv2.warpPerspective(src_cv2, M1[0].numpy(), size, flags=cv2.INTER_LINEAR)
cv2_per_2 = cv2.warpPerspective(src_cv2, M2[0].numpy(), size*S, flags=cv2.INTER_LINEAR)

print()
print('=== Test 1: Identity ===')
print('Source tensor:')
print(src)
print()
print('  = Resize =')
print('   - OpenCV')
print(cv2_res_1)
print()
print('  = Upsample =')
print('   - Torch  (align_corners=False):')
print(th_ups_1F)
print()
print('   - Torch  (align_corners=True):')
print(th_ups_1T)
print()
print('  = Warp Affine =')
print('   - Kornia (align_corners=False):')
print(kornia_aff_1F)
print()
print('   - Kornia (align_corners=True):')
print(kornia_aff_1T)
print()
print('   - OpenCV:')
print(cv2_aff_1)
print()
print('  = Warp Perspective =')
print('   - Kornia (align_corners=False):')
print(kornia_per_1F)
print()
print('   - Kornia (align_corners=True):')
print(kornia_per_1T)
print()
print('   - OpenCV:')
print(cv2_per_1)

print()
print()

print()
print(f'=== Test 2: Scale {S:d}x ===')
print('Source tensor:')
print(src)
print()
print('  = Resize =')
print('   - OpenCV')
print(cv2_res_2)
print()
print('  = Upsample =')
print('   - Torch  (align_corners=False):')
print(th_ups_2F)
print()
print('   - Torch  (align_corners=True):')
print(th_ups_2T)
print()
print('  = Warp Affine =')
print('   - Kornia (align_corners=False):')
print(kornia_aff_2F)
print()
print('   - Kornia (align_corners=True):')
print(kornia_aff_2T)
print()
print('   - OpenCV:')
print(cv2_aff_2)
print()
print('  = Warp Perspective =')
print('   - Kornia (align_corners=False):')
print(kornia_per_2F)
print()
print('   - Kornia (align_corners=True):')
print(kornia_per_2T)
print()
print('   - OpenCV:')
print(cv2_per_2)
And here's the output:
Kornia version 0.5.11

=== Test 1: Identity ===
Source tensor:
tensor([[0., 2., 4.],
        [2., 4., 6.],
        [4., 6., 8.]])

  = Resize =
   - OpenCV
[[0. 2. 4.]
 [2. 4. 6.]
 [4. 6. 8.]]

  = Upsample =
   - Torch  (align_corners=False):
tensor([[[[0., 2., 4.],
          [2., 4., 6.],
          [4., 6., 8.]]]])

   - Torch  (align_corners=True):
tensor([[[[0., 2., 4.],
          [2., 4., 6.],
          [4., 6., 8.]]]])

  = Warp Affine =
   - Kornia (align_corners=False):
tensor([[[[0., 2., 4.],
          [2., 4., 6.],
          [4., 6., 8.]]]])

   - Kornia (align_corners=True):
tensor([[[[0., 1., 2.],
          [1., 2., 3.],
          [2., 3., 4.]]]])

   - OpenCV:
[[0. 2. 4.]
 [2. 4. 6.]
 [4. 6. 8.]]

  = Warp Perspective =
   - Kornia (align_corners=False):
tensor([[[[0., 1., 1.],
          [1., 4., 3.],
          [1., 3., 2.]]]])

   - Kornia (align_corners=True):
tensor([[[[0., 2., 4.],
          [2., 4., 6.],
          [4., 6., 8.]]]])

   - OpenCV:
[[0. 2. 4.]
 [2. 4. 6.]
 [4. 6. 8.]]



=== Test 2: Scale 2x ===
Source tensor:
tensor([[0., 2., 4.],
        [2., 4., 6.],
        [4., 6., 8.]])

  = Resize =
   - OpenCV
[[0.  0.5 1.5 2.5 3.5 4. ]
 [0.5 1.  2.  3.  4.  4.5]
 [1.5 2.  3.  4.  5.  5.5]
 [2.5 3.  4.  5.  6.  6.5]
 [3.5 4.  5.  6.  7.  7.5]
 [4.  4.5 5.5 6.5 7.5 8. ]]

  = Upsample =
   - Torch  (align_corners=False):
tensor([[[[0.0000, 0.5000, 1.5000, 2.5000, 3.5000, 4.0000],
          [0.5000, 1.0000, 2.0000, 3.0000, 4.0000, 4.5000],
          [1.5000, 2.0000, 3.0000, 4.0000, 5.0000, 5.5000],
          [2.5000, 3.0000, 4.0000, 5.0000, 6.0000, 6.5000],
          [3.5000, 4.0000, 5.0000, 6.0000, 7.0000, 7.5000],
          [4.0000, 4.5000, 5.5000, 6.5000, 7.5000, 8.0000]]]])

   - Torch  (align_corners=True):
tensor([[[[0.0000, 0.8000, 1.6000, 2.4000, 3.2000, 4.0000],
          [0.8000, 1.6000, 2.4000, 3.2000, 4.0000, 4.8000],
          [1.6000, 2.4000, 3.2000, 4.0000, 4.8000, 5.6000],
          [2.4000, 3.2000, 4.0000, 4.8000, 5.6000, 6.4000],
          [3.2000, 4.0000, 4.8000, 5.6000, 6.4000, 7.2000],
          [4.0000, 4.8000, 5.6000, 6.4000, 7.2000, 8.0000]]]])

  = Warp Affine =
   - Kornia (align_corners=False):
tensor([[[[0.0000, 0.7109, 1.7266, 2.7422, 2.2344, 0.2031],
          [0.7109, 1.7500, 3.0000, 4.2500, 3.3516, 0.3047],
          [1.7266, 3.0000, 4.2500, 5.5000, 4.2109, 0.3828],
          [2.7422, 4.2500, 5.5000, 6.7500, 5.0703, 0.4609],
          [2.2344, 3.3516, 4.2109, 5.0703, 3.7812, 0.3438],
          [0.2031, 0.3047, 0.3828, 0.4609, 0.3438, 0.0313]]]])

   - Kornia (align_corners=True):
tensor([[[[0.0000, 1.0000, 2.0000, 3.0000, 4.0000, 2.0000],
          [1.0000, 2.0000, 3.0000, 4.0000, 5.0000, 2.5000],
          [2.0000, 3.0000, 4.0000, 5.0000, 6.0000, 3.0000],
          [3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 3.5000],
          [4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 4.0000],
          [2.0000, 2.5000, 3.0000, 3.5000, 4.0000, 2.0000]]]])

   - OpenCV:
[[0.  1.  2.  3.  4.  2. ]
 [1.  2.  3.  4.  5.  2.5]
 [2.  3.  4.  5.  6.  3. ]
 [3.  4.  5.  6.  7.  3.5]
 [4.  5.  6.  7.  8.  4. ]
 [2.  2.5 3.  3.5 4.  2. ]]

  = Warp Perspective =
   - Kornia (align_corners=False):
tensor([[[[0.0000, 0.2500, 1.0000, 1.7500, 1.0000, 0.0000],
          [0.2500, 1.0000, 2.5000, 4.0000, 2.2500, 0.0000],
          [1.0000, 2.5000, 4.0000, 5.5000, 3.0000, 0.0000],
          [1.7500, 4.0000, 5.5000, 7.0000, 3.7500, 0.0000],
          [1.0000, 2.2500, 3.0000, 3.7500, 2.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])

   - Kornia (align_corners=True):
tensor([[[[0.0000, 1.0000, 2.0000, 3.0000, 4.0000, 2.0000],
          [1.0000, 2.0000, 3.0000, 4.0000, 5.0000, 2.5000],
          [2.0000, 3.0000, 4.0000, 5.0000, 6.0000, 3.0000],
          [3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 3.5000],
          [4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 4.0000],
          [2.0000, 2.5000, 3.0000, 3.5000, 4.0000, 2.0000]]]])

   - OpenCV:
[[0.  1.  2.  3.  4.  2. ]
 [1.  2.  3.  4.  5.  2.5]
 [2.  3.  4.  5.  6.  3. ]
 [3.  4.  5.  6.  7.  3.5]
 [4.  5.  6.  7.  8.  4. ]
 [2.  2.5 3.  3.5 4.  2. ]]

On the blue dots diagram:

Specifically when resizing, OpenCV and PyTorch don't extrapolate beyond the source image boundaries (in contrast to warping, where OpenCV does extrapolate). Instead, the interpolation indices get clipped to the image boundary, so when resizing in PyTorch and OpenCV we actually end up with data like this:
[Image: better-blue]
This explains my earlier comment that it didn't match the blue dots diagram referenced earlier in this issue. The original blue dots diagram is more appropriate for discussions on warping, since you choose how to pad the original image for warping, and thus can extrapolate beyond the original image boundaries.
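
The clipping described above can be sketched in one dimension with NumPy. This is a simplified model under stated assumptions, not the actual OpenCV or PyTorch implementation: `resize_1d_half_pixel` is a hypothetical helper that maps each destination pixel centre back to the source with the half-pixel convention (align_corners=False) and clips the position to the valid index range before interpolating linearly.

```python
import numpy as np


def resize_1d_half_pixel(row, out_size):
    """Sketch of 1D linear resizing with half-pixel centres and clipped indices."""
    row = np.asarray(row, dtype=np.float64)
    in_size = row.size
    scale = in_size / out_size
    # Source position of each destination pixel centre.
    pos = (np.arange(out_size) + 0.5) * scale - 0.5
    # Clip to the valid index range instead of extrapolating past the border.
    pos = np.clip(pos, 0, in_size - 1)
    return np.interp(pos, np.arange(in_size), row)


first_row = resize_1d_half_pixel([0, 2, 4], 6)
print(first_row)
```

For the first row of the 3x3 test image this gives the same values as the first row of the cv2.resize and torch.nn.Upsample(align_corners=False) outputs listed above.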

@edgarriba
Member

Our initial intention here was to mimic OpenCV behaviour. In the early stages, PyTorch defaulted to align_corners=True, which, as shown here, is what we mostly support and assume by default. It was only later that align_corners was introduced for generating the grid, which overcomplicated our logic. Possibly the question is whether Kornia should really expose align_corners at all, and whether it makes sense for our users. /cc @ducha-aiki

This is possibly one of the most core functionalities in the lib so I’d love to hear opinions and use cases on that end.

@Multihuntr

Multihuntr commented Oct 4, 2021

Because I can't help myself, I dug into the code and figured out what was going on. The issue is that your code incorrectly uses F.grid_sample(..., align_corners=False). More specifically, your meshgrid coordinates are always in [-1, 1], but passing align_corners=False tells F.grid_sample that your indices into the original image are in the range [-(n-1)/n, (n-1)/n], which they never are. For example:

src = torch.tensor([
    [10., 20, 30],
    [40, 50, 60],
    [70, 80, 90],
])[None, None]
points_1 = torch.tensor([[-1., -1.], [1., 1.]])[None, None]
points_2 = torch.tensor([[-2/3, -2/3], [2/3, 2/3]])[None, None]
sample_1 = F.grid_sample(src, points_1, align_corners=True)
sample_2 = F.grid_sample(src, points_2, align_corners=False)
print(sample_1[0, 0, 0, :]) # tensor([10., 90])
print(sample_2[0, 0, 0, :]) # tensor([10., 90])

Another way to put it: the align_corners argument of F.grid_sample is asking "did you already align_corners?", not, as your code currently assumes "should I align_corners?" That is, because your meshgrid is always in the range [-1, 1], the answer should always be "yes".
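
The two conventions can be written out as plain arithmetic. The sketch below mirrors, as far as I understand it, how F.grid_sample converts a normalized coordinate into a pixel index along an axis of length n; the function name `unnormalize` is mine and is not a public PyTorch API.

```python
def unnormalize(x, n, align_corners):
    """Map a normalized coordinate in [-1, 1] to a pixel index along an axis of length n."""
    if align_corners:
        # -1 and +1 refer to the centres of the first and last pixels.
        return (x + 1) / 2 * (n - 1)
    # -1 and +1 refer to the outer edges of the first and last pixels.
    return ((x + 1) * n - 1) / 2


n = 3
# align_corners=True: the grid extremes land exactly on pixels 0 and n - 1.
print(unnormalize(-1.0, n, True), unnormalize(1.0, n, True))
# align_corners=False: pixels 0 and n - 1 are reached at -(n-1)/n and (n-1)/n (up to rounding).
print(unnormalize(-2 / 3, n, False), unnormalize(2 / 3, n, False))
# align_corners=False with a [-1, 1] grid: half a pixel outside the image on both sides.
print(unnormalize(-1.0, n, False), unnormalize(1.0, n, False))
```

The last line is the situation described above: a meshgrid spanning [-1, 1] sampled with align_corners=False reads positions -0.5 and n - 0.5, which lie outside the pixel centres.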

tl;dr kornia.geometry.warp_perspective should always use F.grid_sample(..., align_corners=True).

Does the concept of align_corners even make sense for warp_perspective?

I don't think it does. We can reproduce both torch.nn.Upsample(..., align_corners=True) and torch.nn.Upsample(..., align_corners=False) with kornia.geometry.warp_perspective, given the right arguments. So, including align_corners as an argument for warp_perspective doesn't seem to make sense.

Here's a code snippet showing this
src = torch.tensor([
    [ 0,  2,  4,],
    [ 2,  4,  6,],
    [ 4,  6,  8,],
], dtype=torch.float32)
src_th = src[None, None]
size = np.array(src.shape)
h, w = size

# as torch.nn.Upsample(..., align_corners=False)
M_F = torch.tensor([
    [2, 0, 0.5],
    [0, 2, 0.5],
    [0, 0, 1],
], dtype=torch.float32).unsqueeze(0)

th_upsample_F = torch.nn.Upsample([*(size*2)], mode='bilinear', align_corners=False)(src_th)
kornia_as_th_F = kornia.geometry.warp_perspective(
    src_th, M_F, size*2,
    padding_mode='border',
    align_corners=True)


# as torch.nn.Upsample(..., align_corners=True)
R = (2*3-1)/2 # I don't know the true formula for this...
M_T = torch.tensor([
    [R, 0, 0],
    [0, R, 0],
    [0, 0, 1],
], dtype=torch.float32).unsqueeze(0)

th_upsample_T = torch.nn.Upsample([*(size*2)], mode='bilinear', align_corners=True)(src_th)
kornia_as_th_T = kornia.geometry.warp_perspective(
    src_th, M_T, size*2,
    padding_mode='reflection',
    align_corners=True)

print('  ==  align_corners=False  ==')
print('   - Torch upsample')
print(th_upsample_F[0, 0])
print('   - Kornia warp_perspective')
print(kornia_as_th_F[0, 0])
print('  ==  align_corners=True   ==')
print('   - Torch upsample')
print(th_upsample_T[0, 0])
print('   - Kornia warp_perspective')
print(kornia_as_th_T[0, 0])
  ==  align_corners=False  ==
   - Torch upsample
tensor([[0.0000, 0.5000, 1.5000, 2.5000, 3.5000, 4.0000],
        [0.5000, 1.0000, 2.0000, 3.0000, 4.0000, 4.5000],
        [1.5000, 2.0000, 3.0000, 4.0000, 5.0000, 5.5000],
        [2.5000, 3.0000, 4.0000, 5.0000, 6.0000, 6.5000],
        [3.5000, 4.0000, 5.0000, 6.0000, 7.0000, 7.5000],
        [4.0000, 4.5000, 5.5000, 6.5000, 7.5000, 8.0000]])
   - Kornia warp_perspective
tensor([[0.0000, 0.5000, 1.5000, 2.5000, 3.5000, 4.0000],
        [0.5000, 1.0000, 2.0000, 3.0000, 4.0000, 4.5000],
        [1.5000, 2.0000, 3.0000, 4.0000, 5.0000, 5.5000],
        [2.5000, 3.0000, 4.0000, 5.0000, 6.0000, 6.5000],
        [3.5000, 4.0000, 5.0000, 6.0000, 7.0000, 7.5000],
        [4.0000, 4.5000, 5.5000, 6.5000, 7.5000, 8.0000]])
  ==  align_corners=True   ==
   - Torch upsample
tensor([[0.0000, 0.8000, 1.6000, 2.4000, 3.2000, 4.0000],
        [0.8000, 1.6000, 2.4000, 3.2000, 4.0000, 4.8000],
        [1.6000, 2.4000, 3.2000, 4.0000, 4.8000, 5.6000],
        [2.4000, 3.2000, 4.0000, 4.8000, 5.6000, 6.4000],
        [3.2000, 4.0000, 4.8000, 5.6000, 6.4000, 7.2000],
        [4.0000, 4.8000, 5.6000, 6.4000, 7.2000, 8.0000]])
   - Kornia warp_perspective
tensor([[0.0000, 0.8000, 1.6000, 2.4000, 3.2000, 4.0000],
        [0.8000, 1.6000, 2.4000, 3.2000, 4.0000, 4.8000],
        [1.6000, 2.4000, 3.2000, 4.0000, 4.8000, 5.6000],
        [2.4000, 3.2000, 4.0000, 4.8000, 5.6000, 6.4000],
        [3.2000, 4.0000, 4.8000, 5.6000, 6.4000, 7.2000],
        [4.0000, 4.8000, 5.6000, 6.4000, 7.2000, 8.0000]])

The purpose of align_corners for resizing is to choose between alternative resolutions to sampling problems when resizing images (see this blog post). These resolutions are needed because of a lack of flexibility in the input options of the resize functions. But we already have enough flexibility in the other arguments to warp_perspective to do anything we would want.
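
For reference, the two matrices used in the snippet above appear to follow from the usual pixel-coordinate conventions for resizing. The sketch below is my own derivation, not something taken from Kornia or PyTorch: `upsample_matrix` is a hypothetical helper, it assumes the same size change on both axes and in_size > 1, and it returns the source-to-destination matrix in pixel coordinates.

```python
import numpy as np


def upsample_matrix(in_size, out_size, align_corners):
    """Hypothetical helper: 3x3 source -> destination matrix for a uniform resize."""
    if align_corners:
        # First and last pixel centres coincide: dst = src * (out - 1) / (in - 1).
        s = (out_size - 1) / (in_size - 1)
        t = 0.0
    else:
        # Half-pixel centres: dst + 0.5 = (src + 0.5) * out / in.
        s = out_size / in_size
        t = 0.5 * s - 0.5
    return np.array([[s, 0.0, t],
                     [0.0, s, t],
                     [0.0, 0.0, 1.0]])


M_F_check = upsample_matrix(3, 6, align_corners=False)
M_T_check = upsample_matrix(3, 6, align_corners=True)
print(M_F_check)
print(M_T_check)
```

For a 3 to 6 resize this yields scale 2 with offset 0.5, and scale 2.5 with no offset, matching M_F and M_T. Under this reading, R = (2*3-1)/2 is (out_size - 1) / (in_size - 1) with in_size = 3 and out_size = 6.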

@edgarriba
Member

@Multihuntr thanks for the long, detailed explanation. One possibility I see here would be to allow create_meshgrid to receive align_corners to compensate. But as you said, there are too many possible combinations versus use cases.

@Multihuntr

Multihuntr commented Oct 6, 2021

@edgarriba I apologise for the long detailed explanation being overly long and detailed and possibly hard to read. I'm still straightening this out in my own head. 😅

Conceptually torch.nn.Upsample is a special case of warp_affine and has an align_corners argument because they recognised two possible affine transformations for the same scaling factor. So my opinion is that align_corners is a nonsense argument in warp_affine and warp_perspective and the "solution" is to get rid of it.

Note 1: I disagree with ducha-aiki. The blue dots diagram they shared is for torch.nn.Upsample and doesn't apply to warp_affine/warp_perspective: these functions should not aim to replicate the sampling shown in the diagram via an "align_corners" parameter. Instead, it is up to the user to decide by providing whatever transform they actually want.

Note 2: There are more than two possible interpretations of "upsample my image by 2", but all possible interpretations can be expressed using the other existing arguments (M, dsize, mode, and padding_mode).

@edgarriba
Member

@Multihuntr it possibly makes sense to get rid of align_corners from warp_affine and warp_perspective. This might require some work to verify that we don't break any other component, since we rely heavily on this operator.

@pmeier @anibali @ducha-aiki any opinions about this deprecation? I think we introduced support for align_corners because PyTorch did, but we never discussed the real need for us.

@anibali
Contributor

anibali commented Oct 7, 2021

I agree with the deprecation. It is further supported by the fact that no other library seems to have align_corners or equivalent for image-transform-by-matrix functions:

In my opinion the addition of an align_corners option is confusing, of limited practical use, and increases the likelihood of doing bad things like transforming images and point annotations in subtly different ways despite using the same transformation matrix.

@edgarriba
Member

Thanks so much. I'll open a separate issue to track the deprecation of align_corners. This might require some extra work, so I'm not sure we can fit it into the next 0.6 release coming next week.

6 participants