In [None]:
%load_ext autoreload
%autoreload 2


import math
import os
import sys
from pathlib import Path

# Put the parent of the current working directory on sys.path (only once),
# presumably so that `trigrad` and `tests.testcases` imported below resolve
# when the notebook is started from its own folder — confirm repo layout.
repo_root = Path(os.getcwd()).parent.resolve()
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

import trigrad
import torch
import matplotlib.pyplot as plt
import time
from tests.testcases import *
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Tensors created without an explicit dtype/device from here on use
# trigrad's precision and live on the CUDA device.
torch.set_default_dtype(trigrad.precision)
torch.set_default_device("cuda")
print(trigrad.precision)

def aligned_colorbar():
    """
    Draw a colorbar in a new axes appended to the right of the current axes.

    The extra axes takes 5% of the current axes' size with a pad of 0.05 and
    is created through ``make_axes_locatable``.
    """
    colorbar_axes = make_axes_locatable(plt.gca()).append_axes(
        "right", size="5%", pad=0.05
    )
    plt.colorbar(cax=colorbar_axes)

def imshow(im: torch.Tensor, depthmap: torch.Tensor = None):
    """
    Show a rendered image with pyplot, optionally next to its depth map.

    Args:
        im: Image tensor. It is detached, moved to the CPU and clipped to
            [0, 1] for display; the tensor itself is not modified. When the
            last axis has four channels, the fourth channel is displayed as
            ``1 - value``.
        depthmap: Optional depth tensor. When given, it is drawn in the right
            half of a 1x2 subplot grid (plasma colormap plus colorbar) and the
            image goes into the left half.
    """
    im_np = im.detach().cpu().numpy().clip(0, 1)
    # Only invert channel 3 when the image really has four channels. Indexing
    # it unconditionally raises on RGB input and flips pixel column 3 of a
    # 2-D (single channel) image.
    if im_np.ndim >= 3 and im_np.shape[-1] == 4:
        im_np[..., 3] = 1 - im_np[..., 3]

    if depthmap is not None:
        plt.subplot(1, 2, 2)
        depth_np = depthmap.detach().cpu().numpy()
        plt.imshow(depth_np, extent=[-1, 1, -1, 1], origin="lower", cmap="plasma")
        plt.title("Depth Map")
        aligned_colorbar()
        # Switch back to the left panel so the image call below lands there.
        plt.subplot(1, 2, 1)
        plt.title("Rendered Image")
    plt.imshow(im_np, extent=[-1, 1, -1, 1], origin="lower")

def cu_time(func, *args, **kwargs):
    """
    Call ``func(*args, **kwargs)`` and measure its wall-clock duration.

    When CUDA is available the device is synchronised before the clock starts
    and again before it stops, so the measurement covers the GPU work issued
    by ``func``. Without CUDA the synchronisation is skipped instead of
    raising, which keeps the helper usable on CPU-only machines.

    Args:
        func: Callable to time.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Tuple ``(result, seconds)`` with the return value of ``func`` and the
        elapsed time in seconds.
    """
    sync_device = torch.cuda.is_available()
    if sync_device:
        torch.cuda.synchronize()
    start = time.perf_counter()
    result = func(*args, **kwargs)
    if sync_device:
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    return result, elapsed


# Render

## Small Triangles

In [None]:
# Build the grid test mesh and give every vertex an opacity of 1.0.
# NOTE(review): grid_mesh arguments are presumably
# (nx, ny, x_min, x_max, y_min, y_max) — confirm in tests.testcases.
vertices, indices, colors = grid_mesh(500, 500, -0.9, 0.9, -0.9, 0.9)
opacities = torch.ones(vertices.shape[0])
# Time one render call; `t` is in seconds and shown in ms in the title.
im, t = cu_time(trigrad.render, vertices, indices, colors, opacities)
# The render result is unpacked into imshow's (im, depthmap) parameters.
imshow(*im)
plt.suptitle(f"{vertices.shape[0]} vertices, {indices.shape[0]} triangles\n{t*1000:.2f} ms")
plt.show()

In [None]:
# Timings noted for the statement below (presumably from earlier runs with
# different renderer variants — confirm hardware and settings):
# 10.9ms without per pixel sort
# 11.6ms with per pixel sort
# 11.5ms with better pixel sort
# The explicit synchronize is part of the timed statement.
%timeit trigrad.render(vertices, indices, colors, opacities); torch.cuda.synchronize()

## Large Triangles

In [None]:
# Mesh from the overlapping_squares(1000) test case; every vertex gets an
# opacity of 0.01.
vertices, indices, colors = overlapping_squares(1000)
opacities = torch.ones(vertices.shape[0]) * 0.01

# NOTE(review): max_layers=-1 presumably disables the layer limit — confirm
# against trigrad.render.
im, t = cu_time(trigrad.render, vertices, indices, colors, opacities, max_layers=-1)
imshow(*im)
plt.suptitle(f"{vertices.shape[0]} vertices, {indices.shape[0]} triangles\n{t*1000:.2f} ms")
plt.show()

In [None]:
# NOTE(review): the single render in the previous cell passes max_layers=-1,
# this timed call does not — confirm the two are meant to differ.
%timeit trigrad.render(vertices, indices, colors, opacities); torch.cuda.synchronize()

# Depth sorting test

In [None]:
N = 10
vertices, indices, colors = overlapping_squares(N)
opacities = torch.ones(vertices.shape[0])
# Shuffle the triangle order (rows of `indices`) before rendering; the
# permutation is not seeded, so it differs between runs.
perm = torch.randperm(indices.shape[0])
indices = indices[perm]

# `im` holds the full render result, which is unpacked into imshow's
# (im, depthmap) parameters.
im = trigrad.render(vertices, indices, colors, opacities)
imshow(*im)
plt.show()

# Transparency

In [None]:
vertices, indices, colors = overlapping_triangles()
# One opacity of 0.5 per vertex, with the dtype/device of `vertices`.
opacities = torch.full_like(vertices[:, 0], 0.5)
im = trigrad.render(vertices, indices, colors, opacities)
# The title states the expected layering for visual inspection.
plt.suptitle("Red is in front of blue")
imshow(*im)
plt.show()

# Intersecting Triangles

In [None]:
# depth_overlap() supplies the opacities as well, unlike the other test cases
# used in this notebook.
vertices, indices, colors, opacities = depth_overlap()
im = trigrad.render(vertices, indices, colors, opacities)
imshow(*im)
plt.show()

# Clip Test

In [None]:
def getProjectionMatrix(znear, zfar, fovX, fovY):
    """
    Build a 4x4 perspective projection matrix for a symmetric view frustum.

    Applied to a column vector ``(x, y, z, 1)`` the matrix yields ``w = z``
    and a clip-space depth of ``zfar * (z - znear) / (zfar - znear)``, so the
    depth after the perspective divide is 0 at ``znear`` and 1 at ``zfar``.

    Args:
        znear: Distance of the near plane.
        zfar: Distance of the far plane.
        fovX: Horizontal field of view in radians.
        fovY: Vertical field of view in radians.

    Returns:
        Tensor of shape (4, 4) created with torch's default dtype and device.
    """
    # math.tan instead of np.tan: the notebook never imports numpy itself, so
    # `np` only existed if the star import happened to provide it.
    tan_half_fov_y = math.tan(fovY / 2)
    tan_half_fov_x = math.tan(fovX / 2)

    # Extents of the frustum on the near plane (symmetric around the axis).
    top = tan_half_fov_y * znear
    bottom = -top
    right = tan_half_fov_x * znear
    left = -right

    # Sign applied to the w row and to the z scale.
    z_sign = 1.0

    P = torch.zeros(4, 4)
    P[0, 0] = 2.0 * znear / (right - left)
    P[1, 1] = 2.0 * znear / (top - bottom)
    # Both offsets are zero for a symmetric frustum; kept in general form.
    P[0, 2] = (right + left) / (right - left)
    P[1, 2] = (top + bottom) / (top - bottom)
    P[3, 2] = z_sign
    P[2, 2] = z_sign * zfar / (zfar - znear)
    P[2, 3] = -(zfar * znear) / (zfar - znear)
    return P


import torch


def split_clipped(clip_vertices, indices, vertex_colors, vertex_opacities, near_plane_w=1e-5):
    """
    Clips triangles against the plane ``z + w >= 0`` in homogeneous clip space.

    Triangles entirely on the non-negative side are kept unchanged, triangles
    entirely on the negative side are dropped, and straddling triangles are
    cut along the plane: one inside vertex gives one smaller triangle, two
    inside vertices give a quad that is emitted as two triangles. Position,
    color and opacity of the new vertices are linearly interpolated along the
    cut edges. Every emitted triangle keeps the winding order of its source.

    Args:
        clip_vertices: Tensor (N, 4) - Homogeneous coords (x, y, z, w)
        indices: Tensor (M, 3) - Vertex indices per triangle
        vertex_colors: Tensor (N, C) - Per-vertex colors
        vertex_opacities: Tensor (N,) or (N, 1) - Per-vertex opacity
        near_plane_w: float - Currently NOT used by the clip test (which is
            ``z + w >= 0``); kept so existing callers keep working.

    Returns:
        new_vertices, new_indices, new_colors, new_opacities
        (Returned as a "Triangle Soup" - indices will just be 0,1,2, 3,4,5...)
        Indices are int32 on the device of ``clip_vertices``. When nothing
        survives, the empty outputs keep the dtype, device and trailing
        dimensions of the corresponding inputs.
    """
    out_verts = []
    out_colors = []
    out_opacities = []

    def original(idx):
        """Attributes (position, color, opacity) of an existing vertex."""
        return clip_vertices[idx], vertex_colors[idx], vertex_opacities[idx]

    def interpolate(idx_a, idx_b, t):
        """Attributes at parameter t on the edge from vertex a (t=0) to b (t=1)."""
        return tuple(a + t * (b - a) for a, b in zip(original(idx_a), original(idx_b)))

    def emit(*corners):
        """Append the given (position, color, opacity) corners to the output."""
        for vert, color, opacity in corners:
            out_verts.append(vert)
            out_colors.append(color)
            out_opacities.append(opacity)

    def empty_rows(reference):
        """Zero-row tensor matching dtype, device and trailing dims of `reference`."""
        return reference.new_empty((0,) + tuple(reference.shape[1:]))

    # Signed distance of every vertex to the clip plane, computed once.
    # Positive = Inside (Keep), Negative = Outside (Clip)
    plane_dist = clip_vertices[:, 2] + clip_vertices[:, 3]

    # Fetch all indices in one transfer instead of one per element.
    triangles = indices.tolist() if torch.is_tensor(indices) else indices

    # (Note: For massive meshes, this should be vectorized or done in a CUDA
    # kernel, but for Python prototyping, a loop is clearest)
    for tri in triangles:
        corner_ids = (tri[0], tri[1], tri[2])
        inside = [bool(plane_dist[i] >= 0) for i in corner_ids]
        inside_count = sum(inside)

        if inside_count == 3:
            # All inside: Keep original triangle
            emit(*(original(i) for i in corner_ids))

        elif inside_count == 0:
            # All outside: Discard
            continue

        elif inside_count == 1:
            # 1 Inside, 2 Outside -> one smaller triangle.
            # Rotate so the inside vertex comes first; a cyclic rotation
            # keeps the winding order.
            k = inside.index(True)
            in_idx, out1_idx, out2_idx = (corner_ids[(k + j) % 3] for j in range(3))

            # t = distance_in / (distance_in - distance_out)
            d_in = plane_dist[in_idx]
            t1 = d_in / (d_in - plane_dist[out1_idx])
            t2 = d_in / (d_in - plane_dist[out2_idx])

            # Winding: In -> New1 -> New2
            emit(
                original(in_idx),
                interpolate(in_idx, out1_idx, t1),
                interpolate(in_idx, out2_idx, t2),
            )

        else:
            # 2 Inside, 1 Outside -> Quad (In1, In2, New2, New1) -> 2 triangles.
            # Rotate so the outside vertex comes first.
            k = inside.index(False)
            out_idx, in1_idx, in2_idx = (corner_ids[(k + j) % 3] for j in range(3))

            # Interpolate FROM each inside vertex TO the outside vertex.
            d_out = plane_dist[out_idx]
            d_in1 = plane_dist[in1_idx]
            d_in2 = plane_dist[in2_idx]
            new1 = interpolate(in1_idx, out_idx, d_in1 / (d_in1 - d_out))
            new2 = interpolate(in2_idx, out_idx, d_in2 / (d_in2 - d_out))

            # T1: In1, In2, New1   T2: New1, In2, New2
            emit(original(in1_idx), original(in2_idx), new1)
            emit(new1, original(in2_idx), new2)

    if not out_verts:
        return (
            empty_rows(clip_vertices),
            torch.empty((0, 3), dtype=torch.int32, device=clip_vertices.device),
            empty_rows(vertex_colors),
            empty_rows(vertex_opacities),
        )

    # Generate simple sequential indices (0, 1, 2), (3, 4, 5)...
    out_indices_t = torch.arange(
        len(out_verts), dtype=torch.int32, device=clip_vertices.device
    ).view(-1, 3)

    return (
        torch.stack(out_verts),
        out_indices_t,
        torch.stack(out_colors),
        torch.stack(out_opacities),
    )


# 60 degree field of view in both directions. math.radians is used because the
# notebook never imports numpy itself.
proj = getProjectionMatrix(0.1, 100.0, math.radians(60.0), math.radians(60.0))

# One triangle with a distinct color per corner.
vertices = torch.tensor(
    [
        [-1, -1, 2],
        [1.0, -1, 2.0],
        [0.0, 0, 1.0],
    ]
)
indices = torch.tensor([[0, 1, 2]], dtype=torch.int32)
colors = torch.tensor(
    [
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
    ]
)
opacities = torch.tensor([1.0, 1.0, 1.0])

# Append w = 1 and project (row vectors, hence the transpose).
verts_h = torch.cat([vertices, torch.ones(vertices.shape[0], 1, device=vertices.device)], dim=1)
verts_2d = verts_h @ proj.T
# Clip in homogeneous space, i.e. before the perspective divide below.
verts_2d, indices, colors, opacities = split_clipped(verts_2d, indices, colors, opacities)

# Perspective divide: x, y, z are divided by w and the last component stores 1 / w.
ndc = torch.ones_like(verts_2d)
ndc[..., :3] = verts_2d[..., :3] / (verts_2d[..., 3:4])
ndc[..., 3] = 1 / (verts_2d[..., 3])

im = trigrad.render(ndc, indices, colors, opacities)
imshow(*im)
plt.show()

In [None]:
# Uses `im` left behind by the previous cell: keep only its second element
# and convert it to a NumPy array.
_, d = im
d = d.cpu().numpy()
# Left panel: the full array as an image.
plt.subplot(1, 2, 1)
plt.imshow(d, origin="lower")
plt.colorbar()
# Right panel: profile along column 200, rows 20..249 (hard-coded).
plt.subplot(1, 2, 2)
ds = d[20:250, 200]
plt.plot(ds)
plt.tight_layout()

# Test early stopping

In [None]:
N = 10000
opacity = 0.1
th = 1 / 256

# N squares covering [-0.9, 0.9]^2. Each square owns four vertices
# (x, y, z, w) with w = 1, laid out corner by corner with a stride of 4.
vertices = torch.zeros((N * 4, 4))
vertices[:, 3] = 1.0
vertices[::4, 0] = -0.9
vertices[::4, 1] = -0.9
vertices[1::4, 0] = 0.9
vertices[1::4, 1] = -0.9
vertices[2::4, 0] = 0.9
vertices[2::4, 1] = 0.9
vertices[3::4, 0] = -0.9
vertices[3::4, 1] = 0.9

# Two triangles per square: corners (0, 1, 2) and (0, 2, 3).
indices = torch.zeros((N * 2, 3), dtype=torch.int32)
indices[::2, 0] = torch.arange(0, N * 4, 4)
indices[::2, 1] = torch.arange(1, N * 4, 4)
indices[::2, 2] = torch.arange(2, N * 4, 4)
indices[1::2, 0] = torch.arange(0, N * 4, 4)
indices[1::2, 1] = torch.arange(2, N * 4, 4)
indices[1::2, 2] = torch.arange(3, N * 4, 4)

# Square k gets z = linspace(0, 1, N)[k] on all of its vertices.
vertices[indices[::2], 2] = torch.linspace(0, 1, N).unsqueeze(1)
vertices[indices[1::2], 2] = torch.linspace(0, 1, N).unsqueeze(1)
colors = torch.rand((N * 4, 3))
opacities = torch.full((N * 4,), opacity)


plt.figure(figsize=(15, 5))
plt.suptitle(f"{N} overlapping squares with $\\sigma$ = {opacity}")

# Reference render with the early-stopping threshold set to 0. cu_time already
# synchronises and returns the elapsed time, so no extra timing is needed.
plt.subplot(1, 3, 1)
im1, t1 = cu_time(trigrad.render, vertices, indices, colors, opacities, early_stopping_threshold=0, per_pixel_sort=False, max_layers=-1)
plt.title(f"{(t1)*1000:.3f} ms without early stopping")
imshow(im1[0])

# Same scene with the threshold `th`.
plt.subplot(1, 3, 2)
im2, t2 = cu_time(trigrad.render, vertices, indices, colors, opacities, early_stopping_threshold=th, per_pixel_sort=False, max_layers=-1)
# Smallest n with (1 - opacity) ** n <= th.
n_early = math.ceil(math.log(th) / math.log(1 - opacity))
plt.title(f"{(t2)*1000:.3f} ms with early stopping at $\\alpha$ = {th:.4f}\nrenders {n_early}/{N}")
imshow(im2[0])

# Difference between the two renders.
plt.subplot(1, 3, 3)
diff = im1[0] - im2[0]
plt.title(f"Difference\nmax {diff.max():.3f}")
imshow(diff)

plt.tight_layout()
plt.show()

# Colors Backwards Test

In [None]:
# Random target color with the fourth component forced to 1.
target_color = torch.rand((4,))
target_color[3] = 1.0

vertices, indices, colors = overlapping_squares(10)
opacities = torch.ones(vertices.shape[0])
im = trigrad.render(vertices, indices, colors, opacities)
# The vertex colors are the quantity being optimised.
colors.requires_grad_()
im, d = trigrad.render(vertices, indices, colors, opacities)

plt.figure(figsize=(15, 5))
plt.suptitle(f"Optimize towards {target_color.cpu().numpy()}", color=target_color.cpu().numpy())
plt.subplot(1, 3, 1)
plt.title("Initial Image")
imshow(im)

optim = torch.optim.Adam([colors], lr=0.1)
# Two optimisation stages on the same optimizer (5 steps, then 95 more);
# after each stage the scene is re-rendered without gradients and shown in
# the next panel.
stages = ((5, "After 5 iterations"), (95, "After 100 iterations"))
for panel, (n_steps, caption) in enumerate(stages, start=2):
    for i in range(n_steps):
        optim.zero_grad()
        im, d = trigrad.render(vertices, indices, colors, opacities)
        loss = ((im - target_color) ** 2).mean()
        loss.backward()
        optim.step()
    with torch.no_grad():
        im, d = trigrad.render(vertices, indices, colors, opacities)
    plt.subplot(1, 3, panel)
    plt.title(caption)
    imshow(im)
plt.show()

# Transparency Backwards Test

In [None]:
N = 10
target_color = torch.tensor([0.0, 0.0, 0.0, 1.0])
vertices, indices, colors = overlapping_squares(N)
# Unconstrained per-vertex parameters; the renderer always receives their
# sigmoid, so the optimised opacity stays inside (0, 1).
opacities = torch.ones(vertices.shape[0], requires_grad=True)
im, d = trigrad.render(vertices, indices, colors, opacities.sigmoid())

plt.figure(figsize=(15, 5))
plt.suptitle(f"Optimize towards {target_color.cpu().numpy()}")
plt.subplot(1, 3, 1)
plt.title("Initial Image")
imshow(im)

optim = torch.optim.Adam([opacities], lr=0.1)
# Two optimisation stages on the same optimizer (15 steps, then 85 more);
# after each stage the scene is re-rendered without gradients and shown in
# the next panel.
stages = ((15, "After 15 iterations"), (85, "After 100 iterations"))
for panel, (n_steps, caption) in enumerate(stages, start=2):
    for i in range(n_steps):
        optim.zero_grad()
        im, d = trigrad.render(vertices, indices, colors, opacities.sigmoid())
        loss = ((im - target_color) ** 2).mean()
        loss.backward()
        optim.step()
    with torch.no_grad():
        im, d = trigrad.render(vertices, indices, colors, opacities.sigmoid())
    plt.subplot(1, 3, panel)
    plt.title(caption)
    imshow(im)
plt.show()

# Vertices Backwards Test

In [None]:
# Render the unmodified square once and keep the clipped result as target.
vertices, indices, colors = test_square()
opacities = torch.ones(vertices.shape[0])
im, d = trigrad.render(vertices, indices, colors, opacities)
target = im.detach().clip(0, 1.0)

plt.figure(figsize=(20, 5))
plt.subplot(1, 4, 1)
plt.title("Target")
imshow(target)

# Starting point of the optimisation: x and y of every vertex halved.
plt.subplot(1, 4, 2)
plt.title("initial config")
vertices[..., :2] = vertices[..., :2] / 2
im, d = trigrad.render(vertices, indices, colors, opacities)
imshow(im)

vertices.requires_grad_()
optim = torch.optim.Adam([vertices], lr=0.002)
# Two optimisation stages on the same optimizer (45 steps, then 155 more),
# each drawn into its own panel.
stages = ((3, 45, "After 45 iterations"), (4, 155, "After 200 iterations"))
for panel, n_steps, caption in stages:
    plt.subplot(1, 4, panel)
    plt.title(caption)
    for i in range(n_steps):
        optim.zero_grad()
        im, d = trigrad.render(vertices, indices, colors, opacities)
        loss = ((im - target) ** 2).mean()
        loss.backward()
        # Zero the gradient of the fourth vertex component before the step.
        vertices.grad[:, 3] = 0
        optim.step()
    with torch.no_grad():
        im, d = trigrad.render(vertices, indices, colors, opacities)
    imshow(im)
plt.show()

# Depthmap Backwards Test

In [None]:
vertices, indices, colors, opacities = depth_overlap()
# NOTE(review): this render is repeated two lines below with identical
# arguments — presumably a leftover or warm-up; confirm before removing.
im, d = trigrad.render(vertices, indices, colors, opacities)

plt.suptitle("initial config")
im, d = trigrad.render(vertices, indices, colors, opacities)
imshow(im, d)
plt.show()

vertices.requires_grad_()
optim = torch.optim.Adam([vertices], lr=1e-3)
# Two optimisation stages on the same optimizer (45 steps, then 155 more).
# The loss is the mean absolute deviation from 2.5 over the pixels whose
# depth value is non-zero.
for n_steps, caption in ((45, "After 45 iterations"), (155, "After 200 iterations")):
    plt.suptitle(caption)
    for i in range(n_steps):
        optim.zero_grad()
        im, d = trigrad.render(vertices, indices, colors, opacities)
        mask = d != 0
        loss = (d[mask] - 2.5).abs().mean()
        loss.backward()
        optim.step()
    with torch.no_grad():
        im, d = trigrad.render(vertices, indices, colors, opacities)
    imshow(im, d)
    plt.show()