fix and refactor test_warp_perspective
edgarriba committed Apr 3, 2019
1 parent 944620a commit d19121e
Showing 1 changed file with 126 additions and 80 deletions.
206 changes: 126 additions & 80 deletions test/test_imgwarp.py
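Context for the tests below: tgm.warp_perspective applies a per-sample 3x3 projective transform to a batched image tensor and resamples it onto a dsize output grid. A minimal sketch of the call pattern this file exercises, assuming the torchgeometry API exactly as the tests use it (the identity homography is just an illustrative input):

import torch
import torchgeometry as tgm

patch = torch.rand(1, 3, 16, 32)        # batched image: (B, C, H, W)
M = torch.eye(3).unsqueeze(0)           # per-sample 3x3 homography: (B, 3, 3)
warped = tgm.warp_perspective(patch, M, dsize=(16, 32))
assert warped.shape == (1, 3, 16, 32)   # resampled on the (height, width) grid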
@@ -3,13 +3,13 @@
import torch
import torchgeometry as tgm
from torch.autograd import gradcheck
from torch.testing import assert_allclose

import utils # test utils
from common import device_type


@pytest.mark.parametrize("batch_shape",
[(1, 1, 7, 32), (2, 3, 16, 31)])
@pytest.mark.parametrize("batch_shape", [(1, 1, 7, 32), (2, 3, 16, 31)])
def test_warp_perspective_rotation(batch_shape, device_type):
# generate input data
batch_size, channels, height, width = batch_shape
@@ -29,23 +29,28 @@ def test_warp_perspective_rotation(batch_shape, device_type):
# apply transformation and inverse
_, _, h, w = patch.shape
patch_warped = tgm.warp_perspective(patch, M, dsize=(height, width))
patch_warped_inv = tgm.warp_perspective(patch_warped, torch.inverse(M),
dsize=(height, width))
patch_warped_inv = tgm.warp_perspective(
patch_warped, torch.inverse(M), dsize=(height, width))

# generate mask to compute error
mask = torch.ones_like(patch)
mask_warped_inv = tgm.warp_perspective(
tgm.warp_perspective(mask, M, dsize=(height, width)),
torch.inverse(M), dsize=(height, width))
torch.inverse(M),
dsize=(height, width))

assert utils.check_equal_torch(mask_warped_inv * patch,
mask_warped_inv * patch_warped_inv)
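The mask round trip is what makes this assert meaningful: warp_perspective fills pixels that sample outside the source with zeros, so a warped-and-unwarped ones tensor acts as a validity mask, and multiplying both sides by it excludes pixels that fell off the canvas. A self-contained sketch of that zero-fill behaviour (the 5 px shift is an illustrative value, not from the test):

import torch
import torchgeometry as tgm

shift = torch.eye(3).unsqueeze(0)
shift[:, 0, 2] = 5.0                    # hypothetical 5 px horizontal translation
ones = torch.ones(1, 1, 8, 8)
round_trip = tgm.warp_perspective(
    tgm.warp_perspective(ones, shift, dsize=(8, 8)),
    torch.inverse(shift), dsize=(8, 8))
# round_trip stays ~1 in the interior but drops to ~0 in the 5 px band that
# left the image, exactly the pixels the assert above masks out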

# evaluate function gradient
patch = utils.tensor_to_gradcheck_var(patch) # to var
M = utils.tensor_to_gradcheck_var(M, requires_grad=False) # to var
assert gradcheck(tgm.warp_perspective, (patch, M, (height, width,)),
raise_exception=True)
assert gradcheck(
tgm.warp_perspective, (patch, M, (
height,
width,
)),
raise_exception=True)
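gradcheck compares analytic gradients against finite differences and needs float64 inputs with requires_grad set; utils.tensor_to_gradcheck_var is the repo helper for that conversion. A standalone sketch of the same pattern, with a guessed-at equivalent of the helper (its real body may differ):

import torch
import torchgeometry as tgm
from torch.autograd import gradcheck

def tensor_to_gradcheck_var(t, requires_grad=True):
    # presumed equivalent of utils.tensor_to_gradcheck_var: float64 + grad flag
    return t.to(torch.float64).requires_grad_(requires_grad)

patch = tensor_to_gradcheck_var(torch.rand(1, 1, 5, 6))
M = tensor_to_gradcheck_var(torch.eye(3).unsqueeze(0), requires_grad=False)
assert gradcheck(tgm.warp_perspective, (patch, M, (5, 6)), raise_exception=True)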


@pytest.mark.parametrize("batch_size", [1, 2, 5])
@@ -74,8 +79,12 @@ def test_get_perspective_transform(batch_size, device_type):
# compute gradient check
points_src = utils.tensor_to_gradcheck_var(points_src) # to var
points_dst = utils.tensor_to_gradcheck_var(points_dst) # to var
assert gradcheck(tgm.get_perspective_transform,
(points_src, points_dst,), raise_exception=True)
assert gradcheck(
tgm.get_perspective_transform, (
points_src,
points_dst,
),
raise_exception=True)
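get_perspective_transform solves for the 3x3 homography that maps four source points onto four destination points, batched over the leading dimension (the same contract as OpenCV's getPerspectiveTransform). A hedged sketch with a hand-picked correspondence whose answer is known in closed form:

import torch
import torchgeometry as tgm

# four [x, y] correspondences per sample: (B, 4, 2)
points_src = torch.FloatTensor([[[0., 0.], [1., 0.], [1., 1.], [0., 1.]]])
points_dst = 2.0 * points_src           # a pure 2x upscale
M = tgm.get_perspective_transform(points_src, points_dst)  # (B, 3, 3)
# for this correspondence M should be ~diag(2, 2, 1)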


@pytest.mark.parametrize("batch_size", [1, 2, 5])
@@ -126,75 +135,107 @@ def test_rotation_matrix2d(batch_size, device_type):
center = utils.tensor_to_gradcheck_var(center) # to var
angle = utils.tensor_to_gradcheck_var(angle) # to var
scale = utils.tensor_to_gradcheck_var(scale) # to var
assert gradcheck(tgm.get_rotation_matrix2d, (center, angle, scale),
raise_exception=True)


@pytest.mark.parametrize("batch_size", [1, 2, 5])
@pytest.mark.parametrize("channels", [1, 5])
def test_warp_perspective_crop(batch_size, device_type, channels):
# generate input data
src_h, src_w = 3, 4
dst_h, dst_w = 3, 2
device = torch.device(device_type)

# [x, y] origin
# top-left, top-right, bottom-right, bottom-left
points_src = torch.rand(batch_size, 4, 2).to(device)
points_src[:, :, 0] *= dst_h
points_src[:, :, 1] *= dst_w

# [x, y] destination
# top-left, top-right, bottom-right, bottom-left
points_dst = torch.zeros_like(points_src)
points_dst[:, 1, 0] = dst_w - 1
points_dst[:, 2, 0] = dst_w - 1
points_dst[:, 2, 1] = dst_h - 1
points_dst[:, 3, 1] = dst_h - 1

# compute transformation between points
dst_pix_trans_src_pix = tgm.get_perspective_transform(
points_src, points_dst)

# create points grid in normalized coordinates
grid_src_norm = tgm.utils.create_meshgrid(
src_h, src_w, normalized_coordinates=True)
grid_src_norm = grid_src_norm.repeat(batch_size, 1, 1, 1).to(device)

# create points grid in pixel coordinates
grid_src_pix = tgm.utils.create_meshgrid(
src_h, src_w, normalized_coordinates=False)
grid_src_pix = grid_src_pix.repeat(batch_size, 1, 1, 1).to(device)

src_norm_trans_src_pix = tgm.normal_transform_pixel(src_h, src_w).repeat(
batch_size, 1, 1).to(device)
src_pix_trans_src_norm = torch.inverse(src_norm_trans_src_pix)

dst_norm_trans_dst_pix = tgm.normal_transform_pixel(dst_h, dst_w).repeat(
batch_size, 1, 1).to(device)

# transform pixel grid
grid_dst_pix = tgm.transform_points(
dst_pix_trans_src_pix.unsqueeze(1), grid_src_pix)
grid_dst_norm = tgm.transform_points(
dst_norm_trans_dst_pix.unsqueeze(1), grid_dst_pix)

# transform norm grid
dst_norm_trans_src_norm = torch.matmul(
dst_norm_trans_dst_pix, torch.matmul(
dst_pix_trans_src_pix, src_pix_trans_src_norm))
grid_dst_norm2 = tgm.transform_points(
dst_norm_trans_src_norm.unsqueeze(1), grid_src_norm)

# grids should be equal
# TODO: investigate why the precision is that low
assert utils.check_equal_torch(grid_dst_norm, grid_dst_norm2, 1e-2)

# warp tensor
patch = torch.rand(batch_size, channels, src_h, src_w)
patch_warped = tgm.warp_perspective(
patch, dst_pix_trans_src_pix, (dst_h, dst_w))
assert patch_warped.shape == (batch_size, channels, dst_h, dst_w)
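The grid comparison in the removed test checks the conjugation identity dst_norm_trans_src_norm = dst_norm_trans_dst_pix @ dst_pix_trans_src_pix @ src_pix_trans_src_norm, i.e. that a pixel-space homography can be moved into the [-1, 1] normalized coordinates grid_sample consumes. Assuming normal_transform_pixel follows the usual align-corners convention, it is a plain scale-and-shift; a sketch of that assumed behaviour:

import torch

def normal_transform_pixel(height, width):
    # assumed behaviour: map pixel range [0, w-1] x [0, h-1] onto [-1, 1] x [-1, 1]
    tr = torch.eye(3)
    tr[0, 0] = 2.0 / (width - 1)
    tr[1, 1] = 2.0 / (height - 1)
    tr[0, 2] = -1.0
    tr[1, 2] = -1.0
    return tr.unsqueeze(0)  # (1, 3, 3)

# pixel-space homography H_pix conjugated into normalized space:
# H_norm = normal_transform_pixel(dst_h, dst_w) @ H_pix
#          @ torch.inverse(normal_transform_pixel(src_h, src_w))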
assert gradcheck(
tgm.get_rotation_matrix2d, (center, angle, scale),
raise_exception=True)
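For reference, get_rotation_matrix2d mirrors OpenCV's getRotationMatrix2D: a rotation center, an angle in degrees, and an isotropic scale per sample, returning a batch of 2x3 affine matrices. A small sketch under that assumption:

import torch
import torchgeometry as tgm

center = torch.FloatTensor([[0., 0.]])   # (B, 2) rotation center [x, y]
angle = torch.FloatTensor([90.])         # (B,) angle in degrees
scale = torch.FloatTensor([1.])          # (B,) isotropic scale
M = tgm.get_rotation_matrix2d(center, angle, scale)  # expected shape: (B, 2, 3)
# about the origin at 90 degrees the linear part should be ~[[0, 1], [-1, 0]]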


class TestWarpPerspective:
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("channels", [1, 5])
def test_crop(self, device_type, batch_size, channels):
# generate input data
src_h, src_w = 3, 3
dst_h, dst_w = 3, 3
device = torch.device(device_type)

# [x, y] origin
# top-left, top-right, bottom-right, bottom-left
points_src = torch.FloatTensor([[
[0, 0],
[0, src_w - 1],
[src_h - 1, src_w - 1],
[src_h - 1, 0],
]])

# [x, y] destination
# top-left, top-right, bottom-right, bottom-left
points_dst = torch.FloatTensor([[
[0, 0],
[0, dst_w - 1],
[dst_h - 1, dst_w - 1],
[dst_h - 1, 0],
]])

# compute transformation between points
dst_trans_src = tgm.get_perspective_transform(points_src,
points_dst).expand(
batch_size, -1, -1)

# warp tensor
patch = torch.FloatTensor([[[
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16],
]]]).expand(batch_size, channels, -1, -1)

expected = torch.FloatTensor([[[
[1, 2, 3],
[5, 6, 7],
[9, 10, 11],
]]])

# warp and assert
patch_warped = tgm.warp_perspective(patch, dst_trans_src,
(dst_h, dst_w))
assert_allclose(patch_warped, expected)
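In test_crop the source and destination corner lists are identical, so the recovered homography is the identity up to solver noise; warping the 4x4 patch onto a 3x3 output grid then simply reads off the top-left 3x3 block, which is exactly the expected tensor. A quick check of that reasoning, reusing the corner tensors defined above:

# identical point sets should recover an (approximately) identity homography
eye = tgm.get_perspective_transform(points_src, points_dst)
assert_allclose(eye, torch.eye(3).unsqueeze(0))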

def test_crop_center_resize(self, device_type):
# generate input data
dst_h, dst_w = 4, 4
device = torch.device(device_type)

# [x, y] origin
# top-left, top-right, bottom-right, bottom-left
points_src = torch.FloatTensor([[
[1, 1],
[1, 2],
[2, 2],
[2, 1],
]])

# [x, y] destination
# top-left, top-right, bottom-right, bottom-left
points_dst = torch.FloatTensor([[
[0, 0],
[0, dst_w - 1],
[dst_h - 1, dst_w - 1],
[dst_h - 1, 0],
]])

# compute transformation between points
dst_trans_src = tgm.get_perspective_transform(points_src, points_dst)

# warp tensor
patch = torch.FloatTensor([[[
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16],
]]])

expected = torch.FloatTensor([[[
[6.000, 6.333, 6.666, 7.000],
[7.333, 7.666, 8.000, 8.333],
[8.666, 9.000, 9.333, 9.666],
[10.000, 10.333, 10.666, 11.000],
]]])

# warp and assert
patch_warped = tgm.warp_perspective(patch, dst_trans_src,
(dst_h, dst_w))
assert_allclose(patch_warped, expected)
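The expected values here follow from bilinear interpolation alone: the 4x4 patch is linear in its indices, f(x, y) = 1 + x + 4y with x the column and y the row, and the crop-and-resize samples it at x, y in {1, 4/3, 5/3, 2}, a 4-point grid spanning the square with corners (1, 1) and (2, 2). A quick numeric check of that claim:

import torch

# patch value at integer (row, col) is 1 + col + 4*row, and bilinear
# interpolation of a linear function reproduces it exactly
coords = torch.linspace(1.0, 2.0, 4)      # {1, 4/3, 5/3, 2}
y, x = torch.meshgrid(coords, coords)     # 'ij' indexing on torch of this era
expected = 1.0 + x + 4.0 * y
# first row: [6.000, 6.333, 6.666, 7.000]; last row: [10.000, ..., 11.000]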


class TestWarpAffine:
@@ -223,5 +264,10 @@ def test_gradcheck(self):
aff_ab = utils.tensor_to_gradcheck_var(
aff_ab, requires_grad=False) # to var
img_b = utils.tensor_to_gradcheck_var(img_b) # to var
assert gradcheck(tgm.warp_affine, (img_b, aff_ab, (height, width),),
raise_exception=True)
assert gradcheck(
tgm.warp_affine, (
img_b,
aff_ab,
(height, width),
),
raise_exception=True)
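warp_affine is the 2x3 counterpart of warp_perspective seen above. A minimal hedged sketch of the call pattern the gradcheck exercises, with an illustrative pure-translation matrix:

import torch
import torchgeometry as tgm

img = torch.rand(1, 2, 32, 32)             # (B, C, H, W)
aff = torch.FloatTensor([[[1., 0., 4.],    # 2x3 affine per sample: here a
                          [0., 1., 2.]]])  # translation by (4, 2) pixels
out = tgm.warp_affine(img, aff, (32, 32))
assert out.shape == (1, 2, 32, 32)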
