In [3]:
import torch
import numpy as np

In [4]:
torch.cuda.is_available()

True

In [20]:
def closest_point_on_triangle(p, triangle):
    """
        p: (N, 3) query points
        triangle: (N, 3, 3) triangle vertices
        Returns: (N, 3) closest points on each triangle
    """
    ab = triangle[:, 1] - triangle[:, 0]
    ac = triangle[:, 2] - triangle[:, 0]
    ap = p - triangle[:, 0]

    d1 = torch.sum(ab * ap, dim=1, keepdim=True)
    d2 = torch.sum(ac * ap, dim=1, keepdim=True)
    d3 = torch.sum(ab * ab, dim=1, keepdim=True)
    d4 = torch.sum(ab * ac, dim=1, keepdim=True)
    d5 = torch.sum(ac * ac, dim=1, keepdim=True)

    denom = d3 * d5 - d4 * d4
    v = (d5 * d1 - d4 * d2) / (denom + 1e-10)
    w = (d3 * d2 - d4 * d1) / (denom + 1e-10)
    u = 1 - v - w

    v_clamped = torch.clamp(v, 0, 1)
    w_clamped = torch.clamp(w, 0, 1)
    u_clamped = torch.clamp(1 - v_clamped - w_clamped, 0, 1)

    sum_clamped = u_clamped + v_clamped + w_clamped
    u = u_clamped / sum_clamped
    v = v_clamped / sum_clamped
    w = w_clamped / sum_clamped

    return triangle[:, 0] * u + triangle[:, 1] * v + triangle[:, 2] * w

In [21]:
query_point = torch.tensor([[0.5, 0.5, 0.5]], dtype=torch.float32)

In [22]:
triangle = torch.tensor([[[1, 0, 0], [0, 1, 0], [0, 0, 1]]], dtype=torch.float32)

In [23]:
closest_point_on_triangle(query_point, triangle)

tensor([[0.3333, 0.3333, 0.3333]])

In [26]:
def test_closest_point_on_triangle():
    triangle = torch.tensor([[[0, 0, 0], [1, 0, 0], [0, 1, 0]]], dtype=torch.float32)

    # Test points
    test_points = torch.tensor([
        [0.25, 0.25, 0.0],   # inside triangle
        [-1.0, -1.0, 0.0],   # closest to vertex A
        [0.5, -0.5, 0.0],    # closest to edge AB
        [-0.5, 0.5, 0.0],    # closest to edge AC
        [0.5, 0.5, 0.0],     # closest to edge BC
        [0.25, 0.25, 1.0],   # above triangle, project to face
    ])

    expected = torch.tensor([
        [0.25, 0.25, 0.0],  # same (inside)
        [0.0, 0.0, 0.0],    # vertex A
        [0.5, 0.0, 0.0],    # projected to AB
        [0.0, 0.5, 0.0],    # projected to AC
        [0.5, 0.5, 0.0],    # projected to BC
        [0.25, 0.25, 0.0],  # projected to face
    ])

    n = len(test_points)
    output = closest_point_on_triangle(test_points, triangle)

    # Compare outputs
    if not torch.allclose(output, expected, atol=1e-5):
        print("❌ Test failed")
        print("Expected:\n", expected)
        print("Got:\n", output)
    else:
        print("✅ Test passed")

In [27]:
test_closest_point_on_triangle()

✅ Test passed


## Profile using real triangle data

In [28]:
def load_numpy(filename: str) -> np.ndarray:
    return np.load(f'./numpy/{filename}')

In [29]:
vertices = torch.tensor(load_numpy('piece_vertices.npy')).cuda()

In [30]:
triangles = torch.tensor(load_numpy('triangle_vertices.npy')).cuda()

In [251]:
repeat_triangles = triangles.repeat(3205, 1, 1)

In [274]:
%timeit triangles.repeat(3205, 1, 1)

1.32 ms ± 31.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [256]:
repeat_vertices = vertices.repeat_interleave(1968, dim=0)

In [254]:
%timeit vertices.repeat_interleave(1968, dim=0)

419 µs ± 7.11 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [68]:
%timeit closest_point_on_triangle(repeat_vertices, repeat_triangles)

32 ms ± 46.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [280]:
closest_points = closest_point_on_triangle(repeat_vertices, repeat_triangles)

In [286]:
distances = torch.norm(repeat_vertices - closest_pts, dim=1)

In [288]:
%timeit torch.norm(repeat_vertices - closest_pts, dim=1)

2.02 ms ± 10.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


All together about 36 ms

## Double check a single value

In [289]:
distances = distances.view(3205, 1968); distances

tensor([[0.0000, 0.0305, 0.0610,  ..., 0.4243, 0.3152, 0.4412],
        [0.0102, 0.0203, 0.0508,  ..., 0.4146, 0.3050, 0.4321],
        [0.0203, 0.0102, 0.0407,  ..., 0.4048, 0.2948, 0.4231],
        ...,
        [0.8206, 0.8052, 0.7907,  ..., 0.5809, 0.7107, 0.5091],
        [0.8259, 0.8102, 0.7955,  ..., 0.5812, 0.7123, 0.5098],
        [0.8313, 0.8154, 0.8003,  ..., 0.5816, 0.7141, 0.5107]],
       device='cuda:0')

In [267]:
p = closest_point_on_triangle(vertices[2].unsqueeze(0), triangles[1].unsqueeze(0))

In [268]:
v = vertices[2].unsqueeze(0) - p

In [269]:
v **= 2

In [270]:
v.sum()

tensor(0.0102, device='cuda:0')

In [293]:
closest_triangle_ind = torch.argmin(distances, dim=1)

In [294]:
%timeit torch.argmin(distances, dim=1)

91.2 µs ± 11.9 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
