In [1]:
import torch
import numpy as np

In [2]:
torch.cuda.is_available()

True

In [3]:
def closest_point_on_triangle(p, triangle):
    """
        p: (N, 3) query points
        triangle: (N, 3, 3) triangle vertices
        Returns: (N, 3) closest points on each triangle
    """
    ab = triangle[:, 1] - triangle[:, 0]
    ac = triangle[:, 2] - triangle[:, 0]
    ap = p - triangle[:, 0]

    d1 = torch.sum(ab * ap, dim=1, keepdim=True)
    d2 = torch.sum(ac * ap, dim=1, keepdim=True)
    d3 = torch.sum(ab * ab, dim=1, keepdim=True)
    d4 = torch.sum(ab * ac, dim=1, keepdim=True)
    d5 = torch.sum(ac * ac, dim=1, keepdim=True)

    denom = d3 * d5 - d4 * d4
    v = (d5 * d1 - d4 * d2) / (denom + 1e-10)
    w = (d3 * d2 - d4 * d1) / (denom + 1e-10)
    u = 1 - v - w

    v_clamped = torch.clamp(v, 0, 1)
    w_clamped = torch.clamp(w, 0, 1)
    u_clamped = torch.clamp(1 - v_clamped - w_clamped, 0, 1)

    sum_clamped = u_clamped + v_clamped + w_clamped
    u = u_clamped / sum_clamped
    v = v_clamped / sum_clamped
    w = w_clamped / sum_clamped

    return triangle[:, 0] * u + triangle[:, 1] * v + triangle[:, 2] * w

In [4]:
query_point = torch.tensor([[0.5, 0.5, 0.5]], dtype=torch.float32)

In [5]:
triangle = torch.tensor([[[1, 0, 0], [0, 1, 0], [0, 0, 1]]], dtype=torch.float32)

In [6]:
closest_point_on_triangle(query_point, triangle)

tensor([[0.3333, 0.3333, 0.3333]])

In [7]:
def test_closest_point_on_triangle():
    triangle = torch.tensor([[[0, 0, 0], [1, 0, 0], [0, 1, 0]]], dtype=torch.float32)
    triangles = triangle.repeat(6, 1, 1)

    # Test points
    test_points = torch.tensor([
        [0.25, 0.25, 0.0],   # inside triangle
        [-1.0, -1.0, 0.0],   # closest to vertex A
        [0.5, -0.5, 0.0],    # closest to edge AB
        [-0.5, 0.5, 0.0],    # closest to edge AC
        [0.5, 0.5, 0.0],     # closest to edge BC
        [0.25, 0.25, 1.0],   # above triangle, project to face
    ])

    expected = torch.tensor([
        [0.25, 0.25, 0.0],  # same (inside)
        [0.0, 0.0, 0.0],    # vertex A
        [0.5, 0.0, 0.0],    # projected to AB
        [0.0, 0.5, 0.0],    # projected to AC
        [0.5, 0.5, 0.0],    # projected to BC
        [0.25, 0.25, 0.0],  # projected to face
    ])

    n = len(test_points)
    output = closest_point_on_triangle(test_points, triangles)

    # Compare outputs
    if not torch.allclose(output, expected, atol=1e-5):
        print("❌ Test failed")
        print("Expected:\n", expected)
        print("Got:\n", output)
    else:
        print("✅ Test passed")

In [8]:
test_closest_point_on_triangle()

✅ Test passed


## Profile using real triangle data

In [9]:
def load_numpy(filename: str) -> np.ndarray:
    return np.load(f'./numpy/{filename}')

In [10]:
vertices = torch.tensor(load_numpy('piece_vertices.npy')).cuda()

In [11]:
triangles = torch.tensor(load_numpy('triangle_vertices.npy')).cuda()

In [12]:
normals = torch.tensor(load_numpy('triangle_normals.npy')).cuda()

In [13]:
repeat_triangles = triangles.repeat(3205, 1, 1)

In [14]:
%timeit triangles.repeat(3205, 1, 1)

1.37 ms ± 13.9 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [15]:
repeat_vertices = vertices.repeat_interleave(1968, dim=0)

In [16]:
%timeit vertices.repeat_interleave(1968, dim=0)

461 µs ± 9.96 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [17]:
%timeit closest_point_on_triangle(repeat_vertices, repeat_triangles)

6.56 ms ± 1.94 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [18]:
closest_points = closest_point_on_triangle(repeat_vertices, repeat_triangles)

In [19]:
distances = torch.norm(repeat_vertices - closest_points, dim=1)

In [20]:
%timeit torch.norm(repeat_vertices - closest_points, dim=1)

2.01 ms ± 4.77 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


All together about 10 ms for the bulk of the work.

## Double check a single value

In [21]:
distances = distances.view(3205, 1968); distances

tensor([[0.4446, 0.7661, 0.5039,  ..., 0.7991, 0.7570, 0.7704],
        [0.4463, 0.7588, 0.5065,  ..., 0.7900, 0.7487, 0.7617],
        [0.4483, 0.7515, 0.5094,  ..., 0.7810, 0.7406, 0.7531],
        ...,
        [0.5759, 0.4122, 0.6016,  ..., 0.6302, 0.5080, 0.5612],
        [0.5846, 0.4092, 0.6103,  ..., 0.6257, 0.5043, 0.5612],
        [0.5933, 0.4064, 0.6189,  ..., 0.6213, 0.5009, 0.5602]],
       device='cuda:0')

In [22]:
p = closest_point_on_triangle(vertices[2].unsqueeze(0), triangles[1].unsqueeze(0))

In [23]:
v = vertices[2].unsqueeze(0) - p

In [24]:
v **= 2

In [25]:
v.sum().sqrt()

tensor(0.7515, device='cuda:0')

In [26]:
closest_triangle_inds = torch.argmin(distances, dim=1)

In [27]:
%timeit torch.argmin(distances, dim=1)

121 µs ± 12.1 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [28]:
closest_points = closest_points.view(3205, 1968, 3)

In [29]:
%timeit closest_points[torch.arange(3205), closest_triangle_inds]

269 µs ± 13.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [30]:
closest_points = closest_points[torch.arange(3205), closest_triangle_inds]

In [31]:
point_vectors = vertices - closest_points

In [32]:
distance_to_closest = torch.norm(point_vectors, dim=1)

In [33]:
distance_to_closest

tensor([0.1372, 0.1283, 0.1196,  ..., 0.0790, 0.0835, 0.0886], device='cuda:0')

In [34]:
distance_to_closest.max()

tensor(0.2128, device='cuda:0')

In [35]:
vertices[-1]

tensor([-0.1458,  1.8576,  0.1198], device='cuda:0')

In [36]:
normal_vectors = normals[closest_triangle_inds]

In [37]:
triangles[:, :, 1].max()

tensor(2.1796, device='cuda:0')

In [38]:
inside_closest_mesh = torch.sum(point_vectors * normal_vectors, dim=1)

In [39]:
(inside_closest_mesh < 0.).sum()

tensor(352, device='cuda:0')

In [44]:
torch.where(inside_closest_mesh < -1e-7)

(tensor([1680, 1683, 1741, 1798, 2245, 2298, 2299, 2300, 2301, 2302, 2303, 2345,
         2346, 2347, 2348, 2349, 2350, 2351, 2352, 2353, 2354, 2355, 2356, 2357,
         2358, 2359, 2360, 2361, 2362, 2406, 2407, 2408, 2456, 2457, 2458, 2459,
         2498, 2499, 2501, 2549, 2600, 2609, 2651, 2655, 2690, 2701, 2741, 2742,
         2793, 2945], device='cuda:0'),)

In [46]:
len(torch.where(inside_closest_mesh < -1e-7)[0])

50

In [40]:
%timeit torch.sum(point_vectors * normal_vectors, dim=1)

204 µs ± 21.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
