In [78]:
import torch

In [105]:
def _closest_point_on_segment(p, start, end):
    """Return the point on the segment ``start``-``end`` closest to ``p`` (batched)."""
    seg = end - start
    seg_len_sq = torch.sum(seg * seg, dim=-1, keepdim=True)
    # A zero-length segment collapses to its start point: divide by 1 there so
    # the parameter becomes 0 instead of NaN (the numerator is 0 as well).
    safe_len_sq = torch.where(seg_len_sq > 0, seg_len_sq, torch.ones_like(seg_len_sq))
    t = torch.sum((p - start) * seg, dim=-1, keepdim=True) / safe_len_sq
    return start + seg * torch.clamp(t, 0, 1)


def closest_point_on_triangle(p: torch.Tensor, a: torch.Tensor,
                              b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """
    Find, for every query point, the closest point on the matching triangle.

    p: (N, 3) query points
    a, b, c: (N, 3) triangle vertices
    Returns: (N, 3) closest points on each triangle

    All four tensors are broadcast against each other, so a single point of
    shape (3,) or a single triangle of shape (1, 3) may be paired with a batch.
    Integer inputs are converted to the default floating dtype.

    The orthogonal projection of ``p`` onto the triangle's plane is used when
    it falls inside the triangle. Otherwise the answer lies on the boundary,
    so the nearest of the three clamped edge projections is returned.
    Independently clamping barycentric coordinates and renormalising them is
    NOT equivalent and yields wrong points in the vertex/edge regions.
    Degenerate (collinear or zero-area) triangles are handled by the edge
    fallback.
    """
    # Work in one common floating dtype so every torch.where below sees
    # operands of matching type.
    dtype = torch.promote_types(
        torch.promote_types(p.dtype, a.dtype),
        torch.promote_types(b.dtype, c.dtype),
    )
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()
    p, a, b, c = (t.to(dtype) for t in (p, a, b, c))

    ab = b - a
    ac = c - a
    ap = p - a

    d1 = torch.sum(ab * ap, dim=-1, keepdim=True)
    d2 = torch.sum(ac * ap, dim=-1, keepdim=True)
    d3 = torch.sum(ab * ab, dim=-1, keepdim=True)
    d4 = torch.sum(ab * ac, dim=-1, keepdim=True)
    d5 = torch.sum(ac * ac, dim=-1, keepdim=True)

    # denom = |ab|^2 |ac|^2 - (ab . ac)^2 = |ab x ac|^2; it vanishes for
    # degenerate triangles. Compare against a scale-relative threshold so the
    # check does not depend on the triangle's absolute size.
    denom = d3 * d5 - d4 * d4
    non_degenerate = denom > torch.finfo(dtype).eps * d3 * d5
    safe_denom = torch.where(non_degenerate, denom, torch.ones_like(denom))

    # Barycentric coordinates (w.r.t. b and c) of the projection onto the plane.
    v = (d5 * d1 - d4 * d2) / safe_denom
    w = (d3 * d2 - d4 * d1) / safe_denom
    inside = non_degenerate & (v >= 0) & (w >= 0) & (v + w <= 1)
    face_point = a + ab * v + ac * w

    # Boundary candidates: closest point on each of the three edges.
    on_ab = _closest_point_on_segment(p, a, b)
    on_bc = _closest_point_on_segment(p, b, c)
    on_ca = _closest_point_on_segment(p, c, a)

    dist_ab = torch.sum((p - on_ab) ** 2, dim=-1, keepdim=True)
    dist_bc = torch.sum((p - on_bc) ** 2, dim=-1, keepdim=True)
    dist_ca = torch.sum((p - on_ca) ** 2, dim=-1, keepdim=True)

    # Running minimum over the three edge candidates (where-chains broadcast,
    # unlike torch.stack, when the inputs have different batch shapes).
    edge_point, edge_dist = on_ab, dist_ab
    closer = dist_bc < edge_dist
    edge_point = torch.where(closer, on_bc, edge_point)
    edge_dist = torch.where(closer, dist_bc, edge_dist)
    closer = dist_ca < edge_dist
    edge_point = torch.where(closer, on_ca, edge_point)

    return torch.where(inside, face_point, edge_point)

In [106]:
# Single query point of shape (3,), i.e. without a leading batch dimension.
query_point = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32)

In [107]:
# Triangle vertex A as a (1, 3) tensor: a batch containing one triangle.
a = torch.tensor([[1, 0, 0]], dtype=torch.float32)

In [108]:
# Triangle vertex B as a (1, 3) tensor.
b = torch.tensor([[0, 1, 0]], dtype=torch.float32)

In [109]:
# Triangle vertex C as a (1, 3) tensor.
c = torch.tensor([[0, 0, 1]], dtype=torch.float32)

In [110]:
# query_point has shape (3,) and broadcasts against the (1, 3) vertices,
# so the displayed result has shape (1, 3).
closest_point_on_triangle(query_point, a, b, c)

tensor([[0.3333, 0.3333, 0.3333]])

In [112]:
def test_closest_point_on_triangle():
    """
    Check closest_point_on_triangle on a right triangle in the z = 0 plane.

    The same triangle (A=(0,0,0), B=(1,0,0), C=(0,1,0)) is paired with six
    query points covering the interior, a vertex region, each edge and a point
    above the face. Prints a pass/fail report; on a mismatch it additionally
    raises so that a full notebook run (or any automated run) stops instead of
    silently continuing past a failed check.

    Raises:
        AssertionError: if the output differs from the expected points by more
            than atol=1e-5.
    """
    a = torch.tensor([[0., 0., 0.]])
    b = torch.tensor([[1., 0., 0.]])
    c = torch.tensor([[0., 1., 0.]])

    # Test points
    test_points = torch.tensor([
        [0.25, 0.25, 0.0],   # inside triangle
        [-1.0, -1.0, 0.0],   # closest to vertex A
        [0.5, -0.5, 0.0],    # closest to edge AB
        [-0.5, 0.5, 0.0],    # closest to edge AC
        [0.5, 0.5, 0.0],     # closest to edge BC
        [0.25, 0.25, 1.0],   # above triangle, project to face
    ])

    expected = torch.tensor([
        [0.25, 0.25, 0.0],  # same (inside)
        [0.0, 0.0, 0.0],    # vertex A
        [0.5, 0.0, 0.0],    # projected to AB
        [0.0, 0.5, 0.0],    # projected to AC
        [0.5, 0.5, 0.0],    # projected to BC
        [0.25, 0.25, 0.0],  # projected to face
    ])

    # Repeat the single triangle once per query point (expand creates views,
    # not copies).
    n = len(test_points)
    a = a.expand(n, -1)
    b = b.expand(n, -1)
    c = c.expand(n, -1)

    output = closest_point_on_triangle(test_points, a, b, c)

    # Compare outputs; keep the human-readable report, then fail loudly.
    if not torch.allclose(output, expected, atol=1e-5):
        print("❌ Test failed")
        print("Expected:\n", expected)
        print("Got:\n", output)
        raise AssertionError("closest_point_on_triangle returned unexpected points")

    print("✅ Test passed")

✅ Test passed


In [113]:
# Run the check against the closest_point_on_triangle currently defined in the kernel.
test_closest_point_on_triangle()

✅ Test passed
