
implement get_perspective_transform

edgarriba committed Oct 8, 2018
1 parent cd0ad41 commit a7db348c5978efc7870c649757947123bfbf78fb
Showing with 307 additions and 2 deletions.
  1. +1 −0 docs/source/index.rst
  2. +188 −0 examples/get_perspective_transform.ipynb
  3. +30 −0 test/test_imgwarp.py
  4. +88 −2 torchgeometry/imgwarp.py
@@ -26,6 +26,7 @@ Geometric Image Transformations
The functions in this section perform various geometrical transformations of 2D images.

.. autofunction:: warp_perspective
.. autofunction:: get_perspective_transform


Pinhole

examples/get_perspective_transform.ipynb: large diffs are not rendered by default.
@@ -64,6 +64,36 @@ def test_warp_perspective_gradcheck(self):
raise_exception=True)
self.assertTrue(res)

def test_get_perspective_transform(self):
# generate input data
h, w = 64, 32 # height, width
norm = torch.randn(1, 4, 2)
points_src = torch.FloatTensor([[
[0, 0], [h, 0], [0, w], [h, w],
]])
points_dst = points_src + norm

# compute transform from source to target
dst_homo_src = tgm.get_perspective_transform(points_src, points_dst)

res = utils.check_equal_torch(
tgm.transform_points(dst_homo_src, points_src), points_dst)
self.assertTrue(res)

def test_get_perspective_transform_gradcheck(self):
# generate input data
h, w = 64, 32 # height, width
norm = torch.randn(1, 4, 2)
points_src = torch.FloatTensor([[
[0, 0], [h, 0], [0, w], [h, w],
]])
points_dst = points_src + norm
points_src = utils.tensor_to_gradcheck_var(points_src) # to var
points_dst = utils.tensor_to_gradcheck_var(points_dst) # to var

# compute transform from source to target
res = gradcheck(tgm.get_perspective_transform,
(points_src, points_dst,), raise_exception=True)
self.assertTrue(res)
if __name__ == '__main__':
unittest.main()
@@ -6,12 +6,13 @@

__all__ = [
"warp_perspective",
"get_perspective_transform",
]


def center_transform(transform, height, width):
assert len(transform.shape) == 3, transform.shape
-    # move points to origin
+    # move points from origin
center_mat_origin = torch.unsqueeze(
torch.eye(
3,
@@ -20,7 +21,7 @@ def center_transform(transform, height, width):
dim=0)
center_mat_origin[..., 0, 2] = float(width) / 2
center_mat_origin[..., 1, 2] = float(height) / 2
-    # move points from origin
+    # move points to origin
origin_mat_center = torch.unsqueeze(
torch.eye(
3,
@@ -92,3 +93,88 @@ def warp_perspective(src, M, dsize, flags='bilinear', border_mode=None,
M_new = normalize_transform_to_pix(M_new, height, width)
# warp and return
return homography_warp(src, M_new, dsize)


def get_perspective_transform(src, dst):
"""Calculates a perspective transform from four pairs of the corresponding
points.
The function calculates the matrix of a perspective transform so that:
.. math ::
todo
Args:
src (Tensor): coordinates of quadrangle vertices in the source image.
dst (Tensor): coordinates of the corresponding quadrangle vertices in
the destination image.
Returns:
Tensor: the perspective transformation.
Shape:
- Input: :math:`(B, 4, 2)` and :math:`(B, 4, 2)`
- Output: :math:`(B, 3, 3)`
"""
if not torch.is_tensor(src):
raise TypeError("Input type is not a torch.Tensor. Got {}"
.format(type(src)))
if not torch.is_tensor(dst):
raise TypeError("Input type is not a torch.Tensor. Got {}"
.format(type(dst)))
if not src.shape[-2:] == (4, 2):
raise ValueError("Inputs must be a Bx4x2 tensor. Got {}"
.format(src.shape))
if not src.shape == dst.shape:
raise ValueError("Inputs must have the same shape. Got {}"
.format(dst.shape))
def ax(p, q):
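        # builds one row of the linear system: the x-coordinate equation of a
        # single correspondence p -> q, with unknowns [a, b, c, d, e, f, g, h]:
        #   a*p_x + b*p_y + c - g*p_x*q_x - h*p_y*q_x = q_x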
ones = torch.ones_like(p)[..., 0:1]
zeros = torch.zeros_like(p)[..., 0:1]
return torch.cat(
[ p[:, 0:1], p[:, 1:2], ones, zeros, zeros, zeros,
-p[:, 0:1] * q[:, 0:1], -p[:, 1:2] * q[:, 0:1] ], dim=1)

def ay(p, q):
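        # analogous row for the y-coordinate equation of the same correspondence:
        #   d*p_x + e*p_y + f - g*p_x*q_y - h*p_y*q_y = q_y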
ones = torch.ones_like(p)[..., 0:1]
zeros = torch.zeros_like(p)[..., 0:1]
return torch.cat(
[ zeros, zeros, zeros, p[:, 0:1], p[:, 1:2], ones,
-p[:, 0:1] * q[:, 1:2], -p[:, 1:2] * q[:, 1:2] ], dim=1)
    # we build the 8x8 matrix A from exactly four point correspondences and
    # solve the resulting square linear system exactly; supporting more
    # correspondences would require a least-squares solver instead
p = []
p.append(ax(src[:, 0], dst[:, 0]))
p.append(ay(src[:, 0], dst[:, 0]))

p.append(ax(src[:, 1], dst[:, 1]))
p.append(ay(src[:, 1], dst[:, 1]))

p.append(ax(src[:, 2], dst[:, 2]))
p.append(ay(src[:, 2], dst[:, 2]))

p.append(ax(src[:, 3], dst[:, 3]))
p.append(ay(src[:, 3], dst[:, 3]))

# A is Bx8x8
A = torch.stack(p, dim=1)

# b is a Bx8x1
b = torch.stack([
dst[:, 0:1, 0], dst[:, 0:1, 1],
dst[:, 1:2, 0], dst[:, 1:2, 1],
dst[:, 2:3, 0], dst[:, 2:3, 1],
dst[:, 3:4, 0], dst[:, 3:4, 1],
], dim=1)

# solve the system Ax = b
X, LU = torch.gesv(b, A)
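    # X is the Bx8x1 solution of the system; LU holds the LU factorization and
    # is not used here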

# create variable to return
batch_size = src.shape[0]
M = torch.ones(batch_size, 9, device=src.device, dtype=src.dtype)
M[..., :8] = torch.squeeze(X, dim=-1)
return M.view(-1, 3, 3) # Bx3x3
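
Below is a minimal usage sketch of the new function (not part of the diff), mirroring the unit test above; it assumes torchgeometry is importable as tgm and reuses the existing transform_points helper:

import torch
import torchgeometry as tgm

# four corner correspondences, batch size 1 (Bx4x2)
points_src = torch.FloatTensor([[[0., 0.], [64., 0.], [0., 32.], [64., 32.]]])
points_dst = points_src + torch.randn(1, 4, 2)

# estimate the homography that maps the source corners onto the destination ones
dst_homo_src = tgm.get_perspective_transform(points_src, points_dst)  # Bx3x3

# applying the transform to the source points should recover the destination points
points_dst_hat = tgm.transform_points(dst_homo_src, points_src)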
