Commit
viewport matrix and some convenience properties added (#751)
set height and width with fov invariant

Fix torch dtype incorrect for projection matrix

Tests for new camera additions

Viewport matrix bug fix

MR fixes

retrigger ci

MR fixes + add device and dtype coverage

Fix flakey test

add missing doc



Fix doc build

Signed-off-by: operel <operel@nvidia.com>
Co-authored-by: operel <operel@nvidia.com>
orperel and operel committed Aug 1, 2023
1 parent 89dfbc9 commit 3482a9d
Showing 8 changed files with 596 additions and 160 deletions.
10 changes: 10 additions & 0 deletions kaolin/render/camera/camera.py
@@ -442,11 +442,21 @@ def width(self) -> int:
"""Camera image plane width (pixel resolution)"""
return self.intrinsics.width

@width.setter
def width(self, value: int) -> None:
"""Camera image plane width (pixel resolution)"""
self.intrinsics.width = value

@property
def height(self) -> int:
"""Camera image plane height (pixel resolution)"""
return self.intrinsics.height

@height.setter
def height(self, value: int) -> None:
"""Camera image plane height (pixel resolution)"""
self.intrinsics.height = value
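
For context, a minimal usage sketch (an editorial illustration, not part of this diff): it assumes a Camera built with Camera.from_args, as in the kaolin docs, and shows the new setters forwarding to the intrinsics.

import math
import torch
from kaolin.render.camera import Camera

# Hypothetical example camera; arguments follow the documented Camera.from_args signature
camera = Camera.from_args(
    eye=torch.tensor([0.0, 0.0, 3.0]),
    at=torch.tensor([0.0, 0.0, 0.0]),
    up=torch.tensor([0.0, 1.0, 0.0]),
    fov=math.pi / 4,      # vertical field of view, in radians
    width=512, height=512,
)
camera.width = 1024       # forwards to camera.intrinsics.width
camera.height = 768       # forwards to camera.intrinsics.height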

@property
def lens_type(self) -> str:
r"""A textual description of the camera lens type. (i.e 'pinhole', 'ortho')
8 changes: 6 additions & 2 deletions kaolin/render/camera/extrinsics.py
@@ -948,8 +948,12 @@ def cam_up(self) -> torch.Tensor:
return self.R.transpose(2, 1) @ self._world_y()

def cam_forward(self) -> torch.Tensor:
"""Returns:
(torch.Tensor): the camera forward axis, in world coordinates"""
r""" Returns the camera forward axis -
See: https://www.scratchapixel.com/lessons/mathematics-physics-for-computer-graphics/lookat-function/framing-lookat-function.html
Returns:
(torch.Tensor): the camera forward axis, in world coordinates."""
return self.R.transpose(2, 1) @ self._world_z()

def parameters(self) -> torch.Tensor:
90 changes: 90 additions & 0 deletions kaolin/render/camera/intrinsics.py
@@ -136,6 +136,96 @@ def projection_matrix(self):
raise NotImplementedError('The projection of this camera type is non-linear in homogeneous coordinates '
'and therefore does not support a projection matrix. Use self.transform() instead.')

def viewport_matrix(self, vl=0, vr=None, vb=0, vt=None, min_depth=0.0, max_depth=1.0) -> torch.Tensor:
r"""Constructs a viewport matrix which transforms coordinates from NDC space to pixel space.
This is the general matrix form of glViewport, familiar from OpenGL.
NDC coordinates are expected to be in:
* [-1, 1] for the (x,y) coordinates.
* [ndc_min, ndc_max] for the (z) coordinate.
Pixel coordinates are in:
* [vl, vr] for the (x) coordinate.
* [vb, vt] for the (y) coordinate.
* [0, 1] for the (z) coordinate (yielding normalized depth).
When used in conjunction with a :func:`projection_matrix()`, a transformation from camera view space to
window space can be obtained.
Note that for the purpose of rendering with OpenGL shaders, this matrix is not required, as viewport
transformation is already applied by the hardware.
By default, this matrix assumes the NDC screen space has the y axis pointing up.
Under this assumption, and a [-1, 1] NDC space,
the default values of this method are compatible with OpenGL glViewport.
.. seealso::
glViewport() at https://registry.khronos.org/OpenGL-Refpages/gl4/html/glViewport.xhtml
and https://en.wikibooks.org/wiki/GLSL_Programming/Vertex_Transformations#Viewport_Transformation
projection_matrix() which converts coordinates from camera view space to NDC space.
.. note::
1. This matrix changes form depending on the NDC space used.
2. Returned values are floating points, rather than integers
(thus this method is compatible with antialiasing ops).
Args:
vl (int): Viewport left (pixel coordinates x) - where the viewport starts. Default is 0.
vr (int): Viewport right (pixel coordinates x) - where the viewport ends. Default is camera width.
vb (int): Viewport bottom (pixel coordinates y) - where the viewport starts. Default is 0.
vt (int): Viewport top (pixel coordinates y) - where the viewport ends. Default is camera height.
min_depth (float): Minimum of output depth range. Default is 0.0.
max_depth (float): Maximum of output depth range. Default is 1.0.
Returns:
(torch.Tensor): the viewport matrix, of shape :math:`(1, 4, 4)`.
"""
if vr is None:
vr = self.width
if vt is None:
vt = self.height
vl = float(vl)
vr = float(vr)
vb = float(vb)
vt = float(vt)

# From NDC space
ndc_min_x = -1.0
ndc_min_y = -1.0
ndc_min_z = self.ndc_min
ndc_max_x = 1.0
ndc_max_y = 1.0
ndc_max_z = self.ndc_max
ndc_width = ndc_max_x - ndc_min_x # All ndc spaces assume x clip coordinates in [-1, 1]
ndc_height = ndc_max_y - ndc_min_y # All ndc spaces assume y clip coordinates in [-1, 1]
ndc_depth = ndc_max_z - ndc_min_z # NDC depth range, this is NDC space dependent

# To screen space
vw = vr - vl # Viewport width
vh = vt - vb # Viewport height
out_depth_range = max_depth - min_depth # By default, normalized depth is assumed [0, 1]

# Recall that for OpenGL NDC space and full screen viewport, the following matrix is given,
# where vw, vh stand for screen width and height:
# [vw/2, 0.0, 0.0, vw/2] @ [ x ] = .. perspective = [(x/w + 1) * (vw/2)]
# [0.0, vh/2, 0.0, vh/2] [ y ] division [(y/w + 1) * (vh/2)]
# [0.0, 0.0, 1/2, 1/2] [ z ] ------------> [(z/w + 1) / 2]
# [0.0, 0.0, 0.0, 1.0] [ w ] (/w) [ 1.0 ]

# The matrix is non-differentiable, as viewport coordinates are a fixed standard set by the graphics API
ndc_mat = self.params.new_tensor([
[vw / ndc_width, 0.0, 0.0, -(ndc_min_x / ndc_width) * vw + vl],
[0.0, vh / ndc_height, 0.0, -(ndc_min_y / ndc_height) * vh + vb],
[0.0, 0.0, out_depth_range / ndc_depth, -(ndc_min_z / ndc_depth) * out_depth_range + min_depth],
[0.0, 0.0, 0.0, 1.0]
])

# Add batch dim, to allow broadcasting
return ndc_mat.unsqueeze(0)
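
A hedged sketch (editorial, not part of this commit) of the composition described in the docstring: chaining projection_matrix(), a perspective divide, and viewport_matrix() maps a view-space point to window (pixel) coordinates. It assumes a single pinhole intrinsics object bound to a hypothetical variable `intrinsics`.

import torch

# Homogeneous view-space point in front of the camera (the camera looks down -z)
point = torch.tensor([[0.0, 0.0, -2.0, 1.0]])
clip = (intrinsics.projection_matrix()[0] @ point.T).T    # view space -> clip space
ndc = clip / clip[..., -1:]                                # perspective divide -> NDC
window = (intrinsics.viewport_matrix()[0] @ ndc.T).T       # NDC -> pixel space (floats); z holds normalized depth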

@abstractmethod
def transform(self, vectors: torch.Tensor) -> torch.Tensor:
r"""Projects the vectors from view space / camera space to NDC (normalized device coordinates) space.
83 changes: 74 additions & 9 deletions kaolin/render/camera/intrinsics_pinhole.py
@@ -388,7 +388,7 @@ def ndc_matrix(self, left, right, bottom, top, near, far) -> torch.Tensor:
ty = -(top + bottom) / (top - bottom)
# tz = -(far + near) / (far - near) # Not used explicitly here, but makes easier to follow derivations

# Some examples of U,V choices to control the NDC space obtained:
# Some examples of U,V choices to control the depth of the NDC space obtained:
# ------------------------------------------------------------------------------------------------------
# | NDC in [-1, 1] | U = -2.0 * near * far / (far - near) | i.e. OpenGL NDC space
# | | V = -(far + near) / (far - near) |
@@ -451,7 +451,7 @@ def ndc_matrix(self, left, right, bottom, top, near, far) -> torch.Tensor:
[0.0, 2.0 / (top - bottom), 0.0, -ty ],
[0.0, 0.0, U, V ],
[0.0, 0.0, 0.0, -1.0]
])
], dtype=self.dtype)

# Add batch dim, to allow broadcasting
return ndc_mat.unsqueeze(0)
@@ -483,13 +483,15 @@ def projection_matrix(self) -> torch.Tensor:
proj = ndc @ persp_matrix
return proj

def transform(self, vectors: torch.Tensor) -> torch.Tensor:
def project(self, vectors: torch.Tensor) -> torch.Tensor:
r"""
Applies perspective projection to actual NDC coordinates (this function also performs perspective division).
Applies perspective projection to obtain Clip Coordinates
(this function does not perform perspective division to obtain the actual Normalized Device Coordinates).
Assumptions:
* Camera is looking down the negative z axis (that is: Z axis points outwards from screen, OpenGL compatible).
* Camera is looking down the negative "z" axis
(that is: camera forward axis points outwards from screen, OpenGL compatible).
* Practitioners are advised to keep near-far gap as narrow as possible,
to avoid inherent depth precision errors.
@@ -502,23 +504,50 @@ def transform(self, vectors: torch.Tensor) -> torch.Tensor:
or :math:`(\text{num_cameras}, \text{num_vectors}, 3)`
Returns:
(torch.Tensor): the transformed vectors, of same shape than ``vectors`` but last dim 3
(torch.Tensor): the transformed vectors, of the same shape as ``vectors`` but with homogeneous coordinates,
i.e. the last dim is 4
"""
proj = self.projection_matrix()

# Expand input vectors to 4D homogeneous coordinates if needed
homogeneous_vecs = up_to_homogeneous(vectors)

num_cameras = len(self) # C - number of cameras
num_cameras = len(self) # C - number of cameras
batch_size = vectors.shape[-2] # B - number of vectors

v = homogeneous_vecs.expand(num_cameras, batch_size, 4)[..., None] # Expand as (C, B, 4, 1)
proj = proj[:, None].expand(num_cameras, batch_size, 4, 4) # Expand as (C, B, 4, 4)
proj = proj[:, None].expand(num_cameras, batch_size, 4, 4) # Expand as (C, B, 4, 4)

transformed_v = proj @ v
transformed_v = transformed_v.squeeze(-1) # Reshape: (C, B, 4)
normalized_v = down_from_homogeneous(transformed_v)

return transformed_v # Return shape: (C, B, 4)

def transform(self, vectors: torch.Tensor) -> torch.Tensor:
r"""
Applies perspective projection to obtain Normalized Device Coordinates
(this function also performs perspective division).
Assumptions:
* Camera is looking down the negative z axis (that is: Z axis points outwards from screen, OpenGL compatible).
* Practitioners are advised to keep near-far gap as narrow as possible,
to avoid inherent depth precision errors.
Args:
vectors (torch.Tensor):
the vectors to be transformed,
can be homogeneous of shape :math:`(\text{num_vectors}, 4)`
or :math:`(\text{num_cameras}, \text{num_vectors}, 4)`
or non-homogeneous of shape :math:`(\text{num_vectors}, 3)`
or :math:`(\text{num_cameras}, \text{num_vectors}, 3)`
Returns:
(torch.Tensor): the transformed vectors, of the same shape as ``vectors`` but with non-homogeneous coords,
i.e. the last dim is 3
"""
transformed_v = self.project(vectors) # Project with homogeneous coords to shape (C, B, 4)
normalized_v = down_from_homogeneous(transformed_v) # Perspective divide to shape: (C, B, 3)
return normalized_v # Return shape: (C, B, 3)
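
A short editorial sketch contrasting the two methods, assuming a single pinhole intrinsics object `intrinsics`: project() keeps homogeneous clip coordinates, while transform() additionally applies the perspective divide via down_from_homogeneous().

verts = torch.rand(100, 3) * 2.0 - 1.0    # random points around the origin
verts[..., 2] -= 3.0                       # push them in front of the camera (negative z)
clip = intrinsics.project(verts)           # (1, 100, 4): homogeneous clip coordinates
ndc = intrinsics.transform(verts)          # (1, 100, 3): normalized device coordinates
assert torch.allclose(ndc, clip[..., :3] / clip[..., -1:])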

def normalize_depth(self, depth: torch.Tensor) -> torch.Tensor:
@@ -549,6 +578,26 @@ def normalize_depth(self, depth: torch.Tensor) -> torch.Tensor:
normalized_depth = torch.clamp(normalized_depth, min=0.0, max=1.0)
return normalized_depth

@CameraIntrinsics.width.setter
def width(self, value: int) -> None:
""" Updates the width of the image plane.
The fov will remain invariant, and the focal length may change instead.
"""
# Keep the fov invariant and change focal length instead
fov = self.fov_x
self._shared_fields['width'] = value
self.fov_x = fov

@CameraIntrinsics.height.setter
def height(self, value: int) -> None:
""" Updates the hieght of the image plane.
The fov will remain invariant, and the focal length may change instead.
"""
# Keep the fov invariant and change focal length instead
fov = self.fov_y
self._shared_fields['height'] = value
self.fov_y = fov
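
The invariant these setters maintain, as an editorial sketch: doubling the width leaves fov_x unchanged, so focal_x doubles to compensate (by the usual pinhole relation fov_x = 2 * atan(width / (2 * focal_x))).

fov_before = intrinsics.fov_x.clone()
focal_before = intrinsics.focal_x.clone()
intrinsics.width = intrinsics.width * 2
assert torch.allclose(intrinsics.fov_x, fov_before)             # fov is invariant
assert torch.allclose(intrinsics.focal_x, 2.0 * focal_before)   # focal length rescaled instead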

@property
def x0(self) -> torch.FloatTensor:
"""The horizontal offset from the NDC origin in image space
@@ -571,6 +620,22 @@ def y0(self) -> torch.FloatTensor:
def y0(self, val: Union[float, torch.Tensor]) -> None:
self._set_param(val, PinholeParamsDefEnum.y0)

@property
def cx(self) -> torch.FloatTensor:
"""The principal point X coordinate.
Note: By default, the principal point is the canvas center (kaolin defines the NDC origin at the canvas center).
"""
# Assumes the NDC x origin is at the center of the canvas
return self.width / 2.0 + self.params[:, PinholeParamsDefEnum.x0]

@property
def cy(self) -> torch.FloatTensor:
"""The principal point Y coordinate.
Note: By default, the principal point is the canvas center (kaolin defines the NDC origin at the canvas center).
"""
# Assumes the NDC y origin is at the center of the canvas
return self.height / 2.0 + self.params[:, PinholeParamsDefEnum.y0]
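
As a usage note (an editorial sketch, not an API added by this commit), these properties make it straightforward to assemble the conventional 3x3 intrinsic matrix K; this assumes a single camera and that a focal_y property exists alongside focal_x.

fx, fy = intrinsics.focal_x[0].item(), intrinsics.focal_y[0].item()
cx, cy = intrinsics.cx[0].item(), intrinsics.cy[0].item()
# Note: kaolin's NDC y axis points up, so a flip may be needed for image conventions with y pointing down
K = torch.tensor([[fx, 0.0, cx],
                  [0.0, fy, cy],
                  [0.0, 0.0, 1.0]])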

@property
def focal_x(self) -> torch.FloatTensor:
return self.params[:, PinholeParamsDefEnum.focal_x]
