In [7]:
import trimesh
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from factok.model import FaceTokenConfig, FaceTokenModel
from factok.utils.mesh import argsort_face_vertices, normalize_mesh
from factok.utils.plot import plot_pointcloud

# Fixed seed so face shuffling, minibatch sampling (numpy `rng`) and torch's global
# RNG (used by default nn.Module weight initialisation) are reproducible across
# Restart & Run All.
SEED = 0
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)


def normalize_face_vertices_batch(face_vertices, eps=1e-12):
    """
    Normalize every face independently: center it on its own centroid and divide by
    its largest absolute centered coordinate, so each coordinate lies in [-1, 1].

    Args:
        face_vertices: numpy array of shape (N, 3, 3) where:
                       - N: number of faces
                       - 3: vertices per face
                       - 3: xyz coordinates
        eps: small constant added to each face's scale so degenerate faces
             (all vertices coincident) do not cause division by zero.

    Returns:
        A tuple ``(normalized, centroids, scales)``:
        - normalized: array of shape (N, 3, 3), the centered and scaled faces.
        - centroids: array of shape (N, 1, 3), the mean vertex of each face.
        - scales: array of shape (N, 1, 1), the per-face divisor
          (max absolute centered coordinate + eps).
        The input is recovered as ``normalized * scales + centroids``.
    """
    face_vertices = np.asarray(face_vertices)

    # Per-face centroid; keepdims so it broadcasts against the (N, 3, 3) input.
    centroids = face_vertices.mean(axis=1, keepdims=True)
    centered = face_vertices - centroids

    # Reduce over both the vertex and xyz axes -> one scalar extent per face.
    # eps is added once and the same value is both used and returned, so the
    # caller can invert the normalization exactly.
    scales = np.abs(centered).max(axis=(1, 2), keepdims=True) + eps

    normalized = centered / scales
    return normalized, centroids, scales


# Load the example mesh, reorder each face's vertices with argsort_face_vertices,
# shuffle the faces, and build one 12-value row per face:
#   [3 per-face-normalized vertices (9 values) | unit direction to the face centroid (3 values)].
mesh, _, _ = normalize_mesh(trimesh.load("../examples/dolphin.obj"), 10)
face_vertex_order = argsort_face_vertices(mesh.vertices[mesh.faces])
data = mesh.vertices[np.take_along_axis(mesh.faces, face_vertex_order, axis=1)]
data = rng.permutation(data, axis=0)
data, centroids, scales = normalize_face_vertices_batch(data)
data = torch.from_numpy(data).to("cuda:0", dtype=torch.float32)
# reshape (not squeeze) so a single-face mesh still yields shape (1, 3) instead of (3,).
centroids = torch.from_numpy(centroids.reshape(-1, 3)).to("cuda:0", dtype=torch.float32)
centroids_norm = centroids.norm(dim=-1, keepdim=True)
# Guard the division: a centroid exactly at the origin would otherwise give 0/0 = NaN
# and poison the training data. centroids_norm itself keeps the true (possibly 0) norm.
centroids_unit = centroids / centroids_norm.clamp_min(1e-12)
scales = torch.from_numpy(scales).view(-1, 1).to("cuda:0", dtype=torch.float32)
data = torch.cat([data, centroids_unit.view(-1, 1, 3)], dim=1)
data = data.view(-1, 12)

In [8]:
# Face autoencoder. Each row of `data` is one face (9 normalized vertex coordinates
# + 3 centroid-direction values); the input/output width is read from the prepared
# data so it cannot drift from the layout built in the data cell.
feature_dim = data.size(1)
config = FaceTokenConfig(
    size_in=feature_dim,
    size_out=feature_dim,
    size_hidden=512,
    size_intermediate=512,
    num_encoder_layers=4,
    num_decoder_layers=4,
    codebook_dim=128,
    codebook_size=1024,
    codebook_heads=4,
)
model = FaceTokenModel(config).to("cuda:0")
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-5)

In [9]:
# Train the face autoencoder on random minibatches of face rows.
batch_size = 32
losses = []
# Explicitly enable training mode in case an inference cell switched the model to eval.
model.train()
loop = tqdm(range(5001))
for i in loop:
    # rng.integers samples indices with replacement.
    faces = data[rng.integers(0, data.size(0), batch_size)]
    reconstructs, loss_commit = model(faces)

    # Reconstruction MSE plus the auxiliary loss returned by the model
    # (presumably the VQ commitment loss, given its name — confirm in FaceTokenModel).
    loss = F.mse_loss(reconstructs, faces) + loss_commit
    losses.append(loss.item())
    loop.set_description(f"total loss = {losses[-1]:.8f}")
    # Zero before backward so gradients left over from an interrupted previous run
    # of this cell are not folded into the first step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

total loss = 0.00469922: 100%|██████████| 5001/5001 [01:36<00:00, 52.07it/s]


In [6]:
# Small autoencoder for 3-value inputs; the loop that follows trains it on
# `centroids_unit` (the per-face centroid directions).
tran_config = FaceTokenConfig(
    size_in=3,
    size_out=3,
    size_hidden=512,
    size_intermediate=512,
    num_encoder_layers=2,
    num_decoder_layers=2,
    codebook_dim=512,
    codebook_size=128,
    codebook_heads=4,
)
model_tran = FaceTokenModel(tran_config).to("cuda:0")
optimizer_tran = torch.optim.Adam(
    model_tran.parameters(),
    lr=3e-4,
    weight_decay=1e-5,
)


# Train the centroid-direction autoencoder. Its loss history goes to `losses_tran`
# so this cell no longer overwrites the face model's `losses`.
batch_size = 32
losses_tran = []
model_tran.train()
loop = tqdm(range(5001))
for i in loop:
    # Draw indices from the tensor actually being indexed.
    cs = centroids_unit[rng.integers(0, centroids_unit.size(0), batch_size)]
    reconstructs, loss_commit = model_tran(cs)

    loss = F.mse_loss(reconstructs, cs) + loss_commit
    losses_tran.append(loss.item())
    loop.set_description(f"total loss = {losses_tran[-1]:.8f}")
    # Zero before backward so stale gradients from an interrupted run are discarded.
    optimizer_tran.zero_grad()
    loss.backward()
    optimizer_tran.step()

total loss = 0.00067574: 100%|██████████| 5001/5001 [01:25<00:00, 58.65it/s]


In [12]:
# Reconstruct every face row with the trained face model and split each 12-value
# output into 3 normalized vertices (`vs`) and a world-space centroid (`cs`).
# NOTE: the centroid comes from `model`'s 4th output row; `model_tran` is not used here.
# Run in eval mode (disables dropout or any other training-only behaviour the model
# may have — TODO confirm what FaceTokenModel does per mode), then restore the
# previous mode so later training cells are unaffected.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        reconstructs, _ = model(data)
finally:
    model.train(was_training)
reconstructs = reconstructs.view(-1, 4, 3)
vs = reconstructs[:, :3]  # (N, 3, 3) per-face normalized vertices
# Predicted centroid direction scaled by the true centroid distance -> (N, 1, 3) world-space centroid.
cs = reconstructs[:, 3:4] * centroids_norm.view(-1, 1, 1)

In [13]:
import trimesh


def create_mesh_from_faces(face_vertices):
    """
    Build a trimesh.Trimesh from per-face corner coordinates.

    Corners with exactly equal coordinates are shared between faces; the
    resulting face array indexes into that de-duplicated vertex list.

    Args:
        face_vertices: numpy array of shape (n, 3, 3) — n faces, 3 corners per
                       face, xyz per corner.

    Returns:
        trimesh.Trimesh object
    """
    # Flatten to one row per corner: (n * 3, 3).
    corners = face_vertices.reshape(-1, 3)

    # Exact-equality de-duplication; `corner_ids` maps each corner to its unique row.
    unique_corners, corner_ids = np.unique(corners, axis=0, return_inverse=True)

    # Regroup the per-corner ids three at a time to recover the (n, 3) face indices.
    return trimesh.Trimesh(vertices=unique_corners, faces=corner_ids.reshape(-1, 3))


# Undo the per-face normalization: rescale each face's normalized vertices by its
# extent and translate them to the reconstructed centroid. `cs` already holds
# direction * centroids_norm (world space), so it is added as-is — multiplying by
# centroids_norm again would place every face at norm**2 along its direction.
re_face_vertices = (vs * scales.view(-1, 1, 1) + cs).cpu().numpy()
re_mesh = create_mesh_from_faces(re_face_vertices)
scene = trimesh.Scene([re_mesh])
scene.show()

In [24]:
# Reconstructed vertex offsets relative to each face's centroid, rescaled by the
# per-face extent. `reconstructs` is (N, 4, 3) at this point (vertices + centroid
# row), so re-viewing it as (-1, 3, 3) would fail; use the vertex part `vs` instead.
vs * scales.view(-1, 1, 1)

tensor([[[ 0.0699, -0.0459, -0.0474],
         [-0.0270,  0.0373,  0.0684],
         [-0.0436,  0.0089, -0.0223]],

        [[ 0.0043, -0.0325, -0.1277],
         [-0.0191,  0.0700, -0.0055],
         [ 0.0154, -0.0384,  0.1350]],

        [[-0.0170, -0.0548, -0.1258],
         [ 0.0286,  0.0918,  0.1896],
         [-0.0149, -0.0376, -0.0664]],

        ...,

        [[ 0.0515,  0.0807, -0.0547],
         [ 0.1030,  0.0276,  0.0126],
         [-0.1572, -0.1090,  0.0448]],

        [[ 0.0736, -0.0196, -0.0692],
         [-0.0968,  0.0387, -0.0452],
         [ 0.0236, -0.0192,  0.1145]],

        [[ 0.0130, -0.0649, -0.1088],
         [-0.0449,  0.0898, -0.0006],
         [ 0.0328, -0.0259,  0.1107]]], device='cuda:0')

torch.Size([5120, 1])