Skip to content

Commit

Permalink
Add deftet losses (#496)
Browse files Browse the repository at this point in the history
* Implementation of loss functions and operations in DefTet with pytests

Signed-off-by: Michael Li <michaeli@nvidia.com>

* fix naming and doc

Signed-off-by: Clement Fuji Tsang <cfujitsang@nvidia.com>

* fix tests

Signed-off-by: Clement Fuji Tsang <cfujitsang@nvidia.com>

* fix small indent issue

Signed-off-by: Clement Fuji Tsang <cfujitsang@nvidia.com>

* remove torch.abs from tetrahedrons_volume

Signed-off-by: Clement Fuji Tsang <cfujitsang@nvidia.com>

Co-authored-by: Michael Li <michaeli@nvidia.com>
  • Loading branch information
Caenorst and Michael Li committed Dec 16, 2021
1 parent dd98c85 commit 0236fbc
Show file tree
Hide file tree
Showing 13 changed files with 446 additions and 4 deletions.
4 changes: 4 additions & 0 deletions docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

This guide is for developers who write API documentation. To build the documentation, run

`pip install -r tools/doc_requirements.txt` to install the dependencies for documentation.

Then, run

```make html``` on Linux

```make.bat html``` on Windows
Expand Down
3 changes: 2 additions & 1 deletion docs/kaolin_ext.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -36,6 +36,7 @@ def run_apidoc(_):
"kaolin/ops/conversions/tetmesh.py",
"kaolin/ops/mesh/check_sign.py",
"kaolin/ops/mesh/mesh.py",
"kaolin/ops/mesh/tetmesh.py",
"kaolin/ops/mesh/trianglemesh.py",
"kaolin/ops/spc/spc.py",
"kaolin/ops/spc/convolution.py",
Expand Down
2 changes: 2 additions & 0 deletions docs/modules/kaolin.metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Metrics are differentiable operators that can be used to compute loss or accurac

We currently provide an IoU for voxelgrid, sided distance based metrics such as chamfer distance,
point_to_mesh_distance and other simple regularization such as uniform_laplacian_smoothing.
For tetrahedral mesh, we support the equivolume and AMIPS losses.

.. toctree::
:maxdepth: 2
Expand All @@ -16,3 +17,4 @@ point_to_mesh_distance and other simple regularization such as uniform_laplacian
kaolin.metrics.render
kaolin.metrics.trianglemesh
kaolin.metrics.voxelgrid
kaolin.metrics.tetmesh
12 changes: 12 additions & 0 deletions docs/modules/kaolin.metrics.tetmesh.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
.. _kaolin.metrics.tetmesh:

kaolin.metrics.tetmesh
======================

API
---

.. automodule:: kaolin.metrics.tetmesh
:members:
:undoc-members:
:show-inheritance:
28 changes: 27 additions & 1 deletion docs/modules/kaolin.ops.mesh.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,33 @@
.. _kaolin.ops.mesh:

kaolin.ops.mesh
===============
***********************

A mesh is a 3D object representation consisting of a collection of vertices and polygons.

Triangular meshes
==================

Triangular meshes comprise of a set of triangles that are connected by their common edges or corners. In Kaolin, they are usually represented as a set of two tensors:

* ``vertices``: A :class:`torch.Tensor`, of shape :math:`(\text{batch_size}, \text{num_vertices}, 3)`, contains the vertices coordinates.

* ``faces``: A :class:`torch.LongTensor`, of shape :math:`(\text{batch_size}, \text{num_faces}, 3)`, contains the mesh topology, by listing the vertices index for each face.

Both tensors can be combined using :func:`kaolin.ops.mesh.index_vertices_by_faces`, to form ``face_vertices``, of shape :math:`(\text{batch_size}, \text{num_faces}, 3, 3)`, listing the vertices coordinate for each face.


Tetrahedral meshes
==================

A tetrahedron or triangular pyramid is a polyhedron composed of four triangular faces, six straight edges, and four vertex corners. Tetrahedral meshes inside Kaolin are composed of two tensors:

* ``vertices``: A :class:`torch.Tensor`, of shape :math:`(\text{batch_size}, \text{num_vertices}, 3)`, contains the vertices coordinates.

* ``tet``: A :class:`torch.LongTensor`, of shape :math:`(\text{batch_size}, \text{num_tet}, 4)`, contains the tetrahedral mesh topology, by listing the vertices index for each tetrahedron.

Both tensors can be combined, to form ``tet_vertices``, of shape :math:`(\text{batch_size}, \text{num_tet}, 4, 3)`, listing the tetrahedrons vertices coordinates for each face.


API
---
Expand Down
1 change: 1 addition & 0 deletions kaolin/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from . import trianglemesh
from . import pointcloud
from . import render
from . import tetmesh
192 changes: 192 additions & 0 deletions kaolin/metrics/tetmesh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
# Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from kaolin.ops.mesh.tetmesh import _validate_tet_vertices


def tetrahedron_volume(tet_vertices):
r"""Compute the volume of tetrahedrons.
Args:
tet_vertices (torch.Tensor):
Batched tetrahedrons, of shape
:math:`(\\text{batch_size}, \\text{num_tetrahedrons}, 4, 3)`.
Returns:
(torch.Tensor):
volume of each tetrahedron in each mesh, of shape
:math:`(\\text{batch_size}, \\text{num_tetrahedrons})`.
Example:
>>> tet_vertices = torch.tensor([[[[0.5000, 0.5000, 0.4500],
... [0.4500, 0.5000, 0.5000],
... [0.4750, 0.4500, 0.4500],
... [0.5000, 0.5000, 0.5000]]]])
>>> tetrahedron_volume(tet_vertices)
tensor([[2.0833e-05]])
"""
_validate_tet_vertices(tet_vertices)

# split the tensor
A, B, C, D = [split.squeeze(2) for split in
torch.split(tet_vertices, split_size_or_sections=1, dim=2)]

# compute the volume of each tetrahedron directly by using V = |(a - d) * ((b - d) x (c - d))| / 6
volumes = torch.div(
((A - D) * torch.cross(input=(B - D), other=(C - D), dim=2)).sum(dim=2), 6)

return volumes

def equivolume(tet_vertices, tetrahedrons_mean=None, pow=4):
r"""Compute the EquiVolume loss as devised by *Gao et al.* in `Learning Deformable Tetrahedral Meshes for 3D
Reconstruction <https://nv-tlabs.github.io/DefTet/>`_ NeurIPS 2020.
See `supplementary material <https://nv-tlabs.github.io/DefTet/files/supplement.pdf>`_ for the definition of the loss function.
Args:
tet_vertices (torch.Tensor):
Batched tetrahedrons, of shape
:math:`(\\text{batch_size}, \\text{num_tetrahedrons}, 4, 3)`.
tetrahedrons_mean (torch.Tensor):
Mean volume of all tetrahedrons in a grid, of shape :math:`(1, 1)`.
pow (int):
Power for the equivolume loss.
Increasing power puts more emphasis on the larger tetrahedron deformation.
Default: 4.
Returns:
(torch.Tensor):
EquiVolume loss for each mesh, of shape :math:`(\\text{batch_size})`.
Example:
>>> tet_vertices = torch.tensor([[[[0.5000, 0.5000, 0.7500],
... [0.4500, 0.8000, 0.6000],
... [0.4750, 0.4500, 0.2500],
... [0.5000, 0.3000, 0.3000]],
... [[0.4750, 0.4500, 0.2500],
... [0.5000, 0.9000, 0.3000],
... [0.4500, 0.4000, 0.9000],
... [0.4500, 0.4500, 0.7000]]],
... [[[0.7000, 0.3000, 0.4500],
... [0.4800, 0.2000, 0.3000],
... [0.9000, 0.4500, 0.4500],
... [0.2000, 0.5000, 0.1000]],
... [[0.3750, 0.4500, 0.2500],
... [0.9000, 0.8000, 0.7000],
... [0.6000, 0.9000, 0.3000],
... [0.5500, 0.3500, 0.9000]]]])
>>> equivolume(tet_vertices, pow=4)
tensor([[2.2898e-15],
[2.9661e-10]])
"""
_validate_tet_vertices(tet_vertices)

# compute the volume of each tetrahedron
volumes = tetrahedron_volume(tet_vertices)

if tetrahedrons_mean is None:
# finding the mean volume of all tetrahedrons in the tetrahedron grid
tetrahedrons_mean = torch.mean(volumes, dim=-1, keepdim=True)

# compute EquiVolume loss
equivolume_loss = torch.mean(torch.pow(
torch.abs(volumes - tetrahedrons_mean), exponent=pow),
dim=-1, keepdim=True)

return equivolume_loss


def amips(tet_vertices, inverse_offset_matrix):
r"""Compute the AMIPS (Advanced MIPS) loss as devised by *Fu et al.* in
`Computing Locally Injective Mappings by Advanced MIPS. \
<https://www.microsoft.com/en-us/research/publication/computing-locally-injective-mappings-advanced-mips/>`_
ACM Transactions on Graphics (TOG) - Proceedings of ACM SIGGRAPH 2015.
The Jacobian can be derived as: :math:`J = (g(x) - g(x_0)) / (x - x_0)`
Only components where the determinant of the Jacobian is positive, are included in the calculation of AMIPS.
This is because the AMIPS Loss is only defined for tetrahedrons whose determinant of the Jacobian is positive.
Args:
tet_vertices (torch.Tensor):
Batched tetrahedrons, of shape
:math:`(\\text{batch_size}, \\text{num_tetrahedrons}, 4, 3)`.
inverse_offset_matrix (torch.LongTensor): The inverse of the offset matrix is of shape
:math:`(\\text{batch_size}, \\text{num_tetrahedrons}, 3, 3)`.
Refer to :func:`kaolin.ops.mesh.tetmesh.inverse_vertices_offset`.
Returns:
(torch.Tensor):
AMIPS loss for each mesh, of shape :math:`(\\text{batch_size})`.
Example:
>>> tet_vertices = torch.tensor([[[[1.7000, 2.3000, 4.4500],
... [3.4800, 0.2000, 5.3000],
... [4.9000, 9.4500, 6.4500],
... [6.2000, 8.5000, 7.1000]],
... [[-1.3750, 1.4500, 3.2500],
... [4.9000, 1.8000, 2.7000],
... [3.6000, 1.9000, 2.3000],
... [1.5500, 1.3500, 2.9000]]],
... [[[1.7000, 2.3000, 4.4500],
... [3.4800, 0.2000, 5.3000],
... [4.9000, 9.4500, 6.4500],
... [6.2000, 8.5000, 7.1000]],
... [[-1.3750, 1.4500, 3.2500],
... [4.9000, 1.8000, 2.7000],
... [3.6000, 1.9000, 2.3000],
... [1.5500, 1.3500, 2.9000]]]])
>>> inverse_offset_matrix = torch.tensor([[[[ -1.1561, -1.1512, -1.9049],
... [1.5138, 1.0108, 3.4302],
... [1.6538, 1.0346, 4.2223]],
... [[ 2.9020, -1.0995, -1.8744],
... [ 1.1554, 1.1519, 1.7780],
... [-0.0766, 1.6350, 1.1064]]],
... [[[-0.9969, 1.4321, -0.3075],
... [-1.3414, 1.5795, -1.6571],
... [-0.1775, -0.4349, 1.1772]],
... [[-1.1077, -1.2441, 1.8037],
... [-0.5722, 0.1755, -2.4364],
... [-0.5263, 1.5765, 1.5607]]]])
>>> amips(tet_vertices, inverse_offset_matrix)
tensor([[13042.3408],
[ 2376.2517]])
"""
_validate_tet_vertices(tet_vertices)

# split the tensor
A, B, C, D = torch.split(tet_vertices, split_size_or_sections=1, dim=2)

# compute the offset matrix of the tetrahedrons w.r.t. vertex A.
offset_matrix = torch.cat([B - A, C - A, D - A], dim=2)

# compute the Jacobian for each tetrahedron - the Jacobian represents the unique 3D deformation that transforms the
# tetrahedron t into a regular tetrahedron.
jacobian = torch.matmul(offset_matrix, inverse_offset_matrix)

# compute determinant of Jacobian
j_det = torch.det(jacobian)

# compute the trace of J * J.T
jacobian_squared = torch.matmul(jacobian, torch.transpose(jacobian, -2, -1))
trace = torch.diagonal(jacobian_squared, dim1=-2, dim2=-1).sum(-1)

# compute the determinant of the Jacobian to the 2/3
EPS = 1e-10
denominator = torch.pow(torch.pow(j_det, 2) + EPS, 1 / 3)

# compute amips energy for positive tetrahedrons whose determinant of their Jacobian is positive
amips_energy = torch.mean(torch.div(trace, denominator) * (j_det >= 0).float(),
dim=1, keepdim=True)

return amips_energy
1 change: 1 addition & 0 deletions kaolin/ops/mesh/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .mesh import *
from .trianglemesh import *
from .check_sign import check_sign
from .tetmesh import *

__all__ = [k for k in locals().keys() if not k.startswith('__')]
76 changes: 76 additions & 0 deletions kaolin/ops/mesh/tetmesh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch

def _validate_tet_vertices(tet_vertices):
"""Helper method to validate the dimensions of the batched tetrahedrons tensor.
Args:
tet_vertices (torch.Tensor):
Batched tetrahedrons, of shape
:math:`(\text{batch_size}, \text{num_tetrahedrons}, 4, 3)`.
"""
assert tet_vertices.ndim == 4, \
f"tetrahedrons has {tetrahedrons.ndim} but must have 4 dimensions."
assert tet_vertices.shape[2] == 4, \
f"The third dimension of the tetrahedrons must be 4 " \
f"but the input has {tetrahedrons.shape[2]}. Each tetrahedron has 4 vertices."
assert tet_vertices.shape[3] == 3, \
f"The fourth dimension of the tetrahedrons must be 3 " \
f"but the input has {tetrahedrons.shape[3]}. Each vertex must have 3 dimensions."


def inverse_vertices_offset(tet_vertices):
r"""Given tetrahedrons with 4 vertices A, B, C, D. Compute the inverse of the offset matrix w.r.t. vertex A for each
tetrahedron. The offset matrix is obtained by the concatenation of `B - A`, `C - A` and `D - A`. The resulting shape
of the offset matrix is :math:`(\\text{batch_size}, \\text{num_tetrahedrons}, 3, 3)`. The inverse of the offset matrix
is computed by this function.
Args:
tet_vertices (torch.Tensor):
Batched tetrahedrons, of shape
:math:`(\\text{batch_size}, \\text{num_tetrahedrons}, 4, 3)`.
Returns:
(torch.Tensor):
Batched inverse offset matrix, of shape
:math:`(\\text{batch_size}, \\text{num_tetrahedrons}, 3, 3)`.
Each offset matrix is of shape :math:`(3, 3)`,
hence its inverse is also of shape :math:`(3, 3)`.
Example:
>>> tet_vertices = torch.tensor([[[[-0.0500, 0.0000, 0.0500],
... [-0.0250, -0.0500, 0.0000],
... [ 0.0000, 0.0000, 0.0500],
... [0.5000, 0.5000, 0.4500]]]])
>>> inverse_vertices_offset(tet_vertices)
tensor([[[[ 0.0000, 20.0000, 0.0000],
[ 79.9999, -149.9999, 10.0000],
[ -99.9999, 159.9998, -10.0000]]]])
"""
_validate_tet_vertices(tet_vertices)

# split the tensor
A, B, C, D = torch.split(tet_vertices, split_size_or_sections=1, dim=2)

# compute the offset matrix w.r.t. vertex A
offset_matrix = torch.cat([B - A, C - A, D - A], dim=2)

# compute the inverse of the offset matrix
inverse_offset_matrix = torch.inverse(offset_matrix)

return inverse_offset_matrix
2 changes: 1 addition & 1 deletion kaolin/render/mesh/deftet.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def forward(ctx, pixel_coords, render_ranges, face_vertices_z,
sorted_w2 = (sorted_face_idx != -1).float() - (sorted_w0 + sorted_w1)
_idx = sorted_face_idx + 1
_idx = _idx.reshape(batch_size, -1, 1, 1).expand(
batch_size, pixel_num * knum, 3, feat_dim)
batch_size, pixel_num * knum, 3, feat_dim)
selected_features = torch.gather(
torch.nn.functional.pad(face_features, (0, 0, 0, 0, 1, 0), value=0.), 1, _idx).reshape(
batch_size, pixel_num, knum, 3, feat_dim)
Expand Down

0 comments on commit 0236fbc

Please sign in to comment.