Add tetmesh subdivision (#551)
* add tetmesh subdivision

Signed-off-by: Frank Shen <frshen@nvidia.com>

* fix docstring

Signed-off-by: Frank Shen <frshen@nvidia.com>

* fix docstring

Signed-off-by: Frank Shen <frshen@nvidia.com>

* add indentation

Signed-off-by: Frank Shen <frshen@nvidia.com>

Co-authored-by: Frank Shen <frshen@nvidia.com>
frankshen07 committed Apr 20, 2022
1 parent ab85440 commit 7464f76
Showing 2 changed files with 199 additions and 0 deletions.
103 changes: 103 additions & 0 deletions kaolin/ops/mesh/tetmesh.py
@@ -15,6 +15,10 @@


import torch
from kaolin.ops.conversions.tetmesh import _sort_edges

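# The 6 edges of a tetrahedron as flattened vertex-index pairs: (0,1), (0,2), (0,3), (1,2), (1,3), (2,3).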
base_tet_edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long)


def _validate_tet_vertices(tet_vertices):
r"""Helper method to validate the dimensions of the batched tetrahedrons tensor.
@@ -76,3 +80,102 @@ def inverse_vertices_offset(tet_vertices):
inverse_offset_matrix = torch.inverse(offset_matrix)

return inverse_offset_matrix


def subdivide_tetmesh(vertices, tetrahedrons, features=None):
r"""Subdivide each tetrahedron in tetmesh into 8 smaller tetrahedrons
by adding midpoints. If per-vertex features (e.g. SDF value) are given, the features
of the new vertices are computed by averaging the features of vertices on the edge.
For more details and example usage in learning, see
`Deep Marching Tetrahedra\: a Hybrid Representation for High-Resolution 3D Shape Synthesis`_ NeurIPS 2021.
Args:
vertices (torch.Tensor): batched vertices of tetrahedral meshes, of shape
:math:`(\text{batch_size}, \text{num_vertices}, 3)`.
tetrahedrons (torch.LongTensor): unbatched tetrahedral mesh topology, of shape
:math:`(\text{num_tetrahedrons}, 4)`.
features (optional, torch.Tensor): batched per-vertex feature vectors, of shape
:math:`(\text{batch_size}, \text{num_vertices}, \text{feature_dim})`.
Returns:
(torch.Tensor, torch.LongTensor, (optional) torch.Tensor):
- batched vertices of subdivided tetrahedral meshes, of shape
:math:`(\text{batch_size}, \text{new_num_vertices}, 3)`
- unbatched tetrahedral mesh topology, of shape
:math:`(\text{num_tetrahedrons} * 8, 4)`.
- batched per-vertex feature vectors of subdivided tetrahedral meshes, of shape
:math:`(\text{batch_size}, \text{new_num_vertices}, \text{feature_dim})`.
Example:
>>> vertices = torch.tensor([[[0, 0, 0],
... [1, 0, 0],
... [0, 1, 0],
... [0, 0, 1]]], dtype=torch.float)
>>> tetrahedrons = torch.tensor([[0, 1, 2, 3]], dtype=torch.long)
>>> sdf = torch.tensor([[[-1.], [-1.], [0.5], [0.5]]], dtype=torch.float)
>>> new_vertices, new_tetrahedrons, new_sdf = subdivide_tetmesh(vertices, tetrahedrons, sdf)
>>> new_vertices
tensor([[[0.0000, 0.0000, 0.0000],
[1.0000, 0.0000, 0.0000],
[0.0000, 1.0000, 0.0000],
[0.0000, 0.0000, 1.0000],
[0.5000, 0.0000, 0.0000],
[0.0000, 0.5000, 0.0000],
[0.0000, 0.0000, 0.5000],
[0.5000, 0.5000, 0.0000],
[0.5000, 0.0000, 0.5000],
[0.0000, 0.5000, 0.5000]]])
>>> new_tetrahedrons
tensor([[0, 4, 5, 6],
[1, 7, 4, 8],
[2, 5, 7, 9],
[3, 6, 9, 8],
[4, 5, 6, 8],
[4, 5, 8, 7],
[9, 5, 8, 6],
[9, 5, 7, 8]])
>>> new_sdf
tensor([[[-1.0000],
[-1.0000],
[ 0.5000],
[ 0.5000],
[-1.0000],
[-0.2500],
[-0.2500],
[-0.2500],
[-0.2500],
[ 0.5000]]])
.. _Deep Marching Tetrahedra\: a Hybrid Representation for High-Resolution 3D Shape Synthesis:
https://arxiv.org/abs/2111.04276
"""

device = vertices.device
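# Gather the 6 edges of every tetrahedron, sort each edge's endpoints into a canonical order,
# and deduplicate; idx_map assigns every edge the index of its midpoint vertex, offset past the original vertices.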
all_edges = tetrahedrons[:, base_tet_edges].reshape(-1, 2)
all_edges = _sort_edges(all_edges)
unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
idx_map = idx_map + vertices.shape[1]

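# Concatenate positions and (optional) features along the last dim so midpoint positions and
# averaged midpoint features are computed in a single mean over the two edge endpoints.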
pos_feature = torch.cat([vertices, features], -1) if (features is not None) else vertices

mid_pos_feature = pos_feature[:, unique_edges.reshape(-1)].reshape(
pos_feature.shape[0], -1, 2, pos_feature.shape[-1]).mean(2)
new_pos_feature = torch.cat([pos_feature, mid_pos_feature], 1)
new_pos, new_features = new_pos_feature[..., :3], new_pos_feature[..., 3:]

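# Split each parent tetrahedron into 8 children: 4 corner tets (one per original vertex)
# plus 4 tets from the central octahedron, cut along its ac-bd diagonal.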
idx_a, idx_b, idx_c, idx_d = torch.split(tetrahedrons, 1, -1)
idx_ab, idx_ac, idx_ad, idx_bc, idx_bd, idx_cd = idx_map.reshape(-1, 6).split(1, -1)

tet_1 = torch.stack([idx_a, idx_ab, idx_ac, idx_ad], dim=1)
tet_2 = torch.stack([idx_b, idx_bc, idx_ab, idx_bd], dim=1)
tet_3 = torch.stack([idx_c, idx_ac, idx_bc, idx_cd], dim=1)
tet_4 = torch.stack([idx_d, idx_ad, idx_cd, idx_bd], dim=1)
tet_5 = torch.stack([idx_ab, idx_ac, idx_ad, idx_bd], dim=1)
tet_6 = torch.stack([idx_ab, idx_ac, idx_bd, idx_bc], dim=1)
tet_7 = torch.stack([idx_cd, idx_ac, idx_bd, idx_ad], dim=1)
tet_8 = torch.stack([idx_cd, idx_ac, idx_bc, idx_bd], dim=1)

new_tetrahedrons = torch.cat([tet_1, tet_2, tet_3, tet_4, tet_5, tet_6, tet_7, tet_8], dim=0).squeeze(-1)

return (new_pos, new_tetrahedrons) if features is None else (new_pos, new_tetrahedrons, new_features)
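A minimal usage sketch (import path assumed from the tests below, kaolin.ops.mesh.tetmesh): the features argument is optional, and the subdivision can be applied repeatedly to refine the mesh.

import torch
from kaolin.ops.mesh import tetmesh

vertices = torch.tensor([[[0., 0., 0.], [1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]])
tets = torch.tensor([[0, 1, 2, 3]], dtype=torch.long)
sdf = torch.tensor([[[-1.], [-1.], [0.5], [0.5]]])

# Without features: two return values (new vertices and new topology).
new_vertices, new_tets = tetmesh.subdivide_tetmesh(vertices, tets)

# With features: midpoint features are the mean of the two edge endpoints.
new_vertices, new_tets, new_sdf = tetmesh.subdivide_tetmesh(vertices, tets, sdf)

# Repeated subdivision produces a finer tetrahedral mesh.
finer_vertices, finer_tets, finer_sdf = tetmesh.subdivide_tetmesh(new_vertices, new_tets, new_sdf)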
96 changes: 96 additions & 0 deletions tests/python/kaolin/ops/mesh/test_tetmesh.py
@@ -45,3 +45,99 @@ def test_inverse_vertices_offset(self):
[79.9999, -149.9999, 10.0000],
[-99.9999, 159.9998, -10.0000]]]])
assert torch.allclose(tetmesh.inverse_vertices_offset(tetrahedrons), oracle)


@pytest.mark.parametrize('device', ['cpu', 'cuda'])
class TestSubdivideTetmesh:

@pytest.fixture(autouse=True)
def vertices_single_tet(self, device):
return torch.tensor([[[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]], dtype=torch.float, device=device)

@pytest.fixture(autouse=True)
def faces_single_tet(self, device):
return torch.tensor([[0, 1, 2, 3]], dtype=torch.long, device=device)

@pytest.fixture(autouse=True)
def expected_vertices_single_tet(self, device):
return torch.tensor([[[0.0000, 0.0000, 0.0000],
[1.0000, 0.0000, 0.0000],
[0.0000, 1.0000, 0.0000],
[0.0000, 0.0000, 1.0000],
[0.5000, 0.0000, 0.0000],
[0.0000, 0.5000, 0.0000],
[0.0000, 0.0000, 0.5000],
[0.5000, 0.5000, 0.0000],
[0.5000, 0.0000, 0.5000],
[0.0000, 0.5000, 0.5000]]], dtype=torch.float, device=device)

@pytest.fixture(autouse=True)
def expected_faces_single_tet(self, device):
return torch.tensor([[0, 4, 5, 6],
[1, 7, 4, 8],
[2, 5, 7, 9],
[3, 6, 9, 8],
[4, 5, 6, 8],
[4, 5, 8, 7],
[9, 5, 8, 6],
[9, 5, 7, 8]], dtype=torch.long, device=device)

@pytest.fixture(autouse=True)
def faces_two_tets(self, device):
return torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.long, device=device)

@pytest.fixture(autouse=True)
def expected_faces_two_tets(self, device):
return torch.tensor([[0, 4, 5, 6],
[0, 4, 5, 6],
[1, 7, 4, 8],
[1, 7, 4, 8],
[2, 5, 7, 9],
[2, 5, 7, 9],
[3, 6, 9, 8],
[3, 6, 9, 8],
[4, 5, 6, 8],
[4, 5, 6, 8],
[4, 5, 8, 7],
[4, 5, 8, 7],
[9, 5, 8, 6],
[9, 5, 8, 6],
[9, 5, 7, 8],
[9, 5, 7, 8]], dtype=torch.long, device=device)

@pytest.fixture(autouse=True)
def features_single_tet(self, device):
return torch.tensor([[[-1, 2], [-1, 4], [0.5, -2], [0.5, -3]]], dtype=torch.float, device=device)

@pytest.fixture(autouse=True)
def expected_features_single_tet(self, device):
return torch.tensor([[[-1.0000, 2.0000],
[-1.0000, 4.0000],
[0.5000, -2.0000],
[0.5000, -3.0000],
[-1.0000, 3.0000],
[-0.2500, 0.0000],
[-0.2500, -0.5000],
[-0.2500, 1.0000],
[-0.2500, 0.5000],
[0.5000, -2.5000]]], dtype=torch.float, device=device)

def test_subdivide_tetmesh_no_features(self, vertices_single_tet, faces_single_tet, expected_vertices_single_tet, expected_faces_single_tet):
new_vertices, new_faces = tetmesh.subdivide_tetmesh(vertices_single_tet, faces_single_tet)
assert torch.equal(new_vertices, expected_vertices_single_tet)
assert torch.equal(new_faces, expected_faces_single_tet)

def test_subdivide_tetmesh_with_features(self, vertices_single_tet, faces_single_tet, expected_vertices_single_tet, expected_faces_single_tet, features_single_tet, expected_features_single_tet):
new_vertices, new_faces, new_features = tetmesh.subdivide_tetmesh(
vertices_single_tet, faces_single_tet, features_single_tet)
assert torch.equal(new_vertices, expected_vertices_single_tet)
assert torch.equal(new_faces, expected_faces_single_tet)
assert torch.equal(new_features, expected_features_single_tet)

def test_subdivide_tetmesh_shared_verts(self, vertices_single_tet, faces_two_tets, expected_vertices_single_tet, expected_faces_two_tets, features_single_tet, expected_features_single_tet):
# check that shared edges are not duplicated: subdividing two identical tets must not generate redundant midpoint vertices
new_vertices, new_faces, new_features = tetmesh.subdivide_tetmesh(
vertices_single_tet, faces_two_tets, features_single_tet)
assert torch.equal(new_vertices, expected_vertices_single_tet)
assert torch.equal(new_faces, expected_faces_two_tets)
assert torch.equal(new_features, expected_features_single_tet)
