Skip to content

Commit

Permalink
Implement closest_point function (#173)
Browse files Browse the repository at this point in the history
Background:

Training SDF + texture model was not available because closest_point
function wasn't implemented. (see #162)
To traing the SDF + texture model, I implement closest_point function.

Changes:

1. **Change all float variables** in
wisp/csrc/external/mesh2sdf_kernel.cu **into double variables**.
Previously, when using float variables, there were accumulated numerical
errors on SDF and closest points' coordinates. (the numerical SDF errors
even reaches 1.xxx ...)
2. Add mesh2sdf_triangle_gpu_fast_nopre function in
wisp/csrc/external/mesh2sdf_kernel.cu which can **return both SDF and
closest triangle**.
3. Implement closest_point function in wisp/ops/mesh/closest_point.py.
closest points are calculated from closest triangles returned from
mesh2sdf_triangle_gpu_fast_nopre function.

Tests:

Unfortunately, I'm not good at pytest code. I just tested the code by
printing results on terminal.

Test process:

1. Add print() in closest_tex() function (wisp/ops/mesh/closest_tex.py) 
```python
import torch
import numpy as np
from .barycentric_coordinates import barycentric_coordinates
from .closest_point import closest_point
from .sample_tex import sample_tex

def closest_tex(
    V : torch.Tensor, 
    F : torch.Tensor,
    TV : torch.Tensor,
    TF : torch.Tensor,
    materials,
    points : torch.Tensor):
    """Returns the closest texture for a set of points.

        V (torch.FloatTensor): mesh vertices of shape [V, 3] 
        F (torch.LongTensor): mesh face indices of shape [F, 3]
        TV (torch.FloatTensor): 
        TF (torch.FloatTensor):
        materials:
        points (torch.FloatTensor): sample locations of shape [N, 3]

    Returns:
        (torch.FloatTensor): texture samples of shape [N, 3]
    """

    TV = TV.to(V.device)
    TF = TF.to(V.device)
    points = points.to(V.device)

    dist, hit_pts, hit_tidx = closest_point(V, F, points)

    ##### !!!test by printing!!! #####
    ######################### 
    print(dist)  # should be same with below result
    print(((hit_pts - points) ** 2).sum(dim=1).sqrt())  # distance between hit_pts and points should be same with dist
    ######################### 
    ######################### 

    hit_F = F[hit_tidx]
    hit_V = V[hit_F]
    BC = barycentric_coordinates(hit_pts.cuda(), hit_V[:,0], hit_V[:,1], hit_V[:,2])

    hit_TF = TF[hit_tidx]
    hit_TM = hit_TF[...,3]
    hit_TF = hit_TF[...,:3]

    if TV.shape[0] > 0:
        hit_TV = TV[hit_TF]
        hit_Tp = (hit_TV * BC.unsqueeze(-1)).sum(1)
    else:
        hit_Tp = BC
    
    rgb = sample_tex(hit_Tp, hit_TM, materials)

    return rgb, hit_pts, dist

```
2. Run following code
```
from wisp.datasets.formats import MeshSampledSDFDataset


dataset = MeshSampledSDFDataset(mesh_path="data/obj/mesh.obj", split="train", sample_mode=["trace"], num_samples=5000, sample_tex=True, mode_norm="sphere")
```

Results:

I also trained the SDF + texture model on some .obj files after
implementing closest_point function. (However, I don't include training
code on this PR)

First image is .obj model in blender and second image is the SDF +
texture model result. (wrong eye coloring is due to wrong .obj texture
data, so you can ignore it)

![Untitled
(9)](https://github.com/NVIDIAGameWorks/kaolin-wisp/assets/98357201/f58e8554-6bf5-4a18-89e7-4ce71b2aa8ed)
![Untitled
(10)](https://github.com/NVIDIAGameWorks/kaolin-wisp/assets/98357201/2bccf84e-19a6-470a-94d9-2f6ab24870f2)
  • Loading branch information
jskim-research committed Aug 3, 2023
1 parent 07539e4 commit 246b1ec
Show file tree
Hide file tree
Showing 5 changed files with 886 additions and 356 deletions.
1 change: 1 addition & 0 deletions wisp/csrc/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
render.def("find_depth_bound_cuda", &find_depth_bound_cuda);
py::module external = m.def_submodule("external");
external.def("mesh_to_sdf_cuda", &mesh_to_sdf_cuda);
external.def("mesh_to_sdf_triangle_cuda", &mesh_to_sdf_triangle_cuda);
py::module ops = m.def_submodule("ops");
ops.def("hashgrid_interpolate_cuda", &hashgrid_interpolate_cuda);
ops.def("hashgrid_interpolate_backward_cuda", &hashgrid_interpolate_backward_cuda);
Expand Down
Loading

0 comments on commit 246b1ec

Please sign in to comment.