In [None]:
import numpy as np
import torch
import tqdm
import trimesh
from torch.utils.data import DataLoader

from meshae.config import MeshAEFeatEmbedConfig
from meshae.data import MeshAEDataset
from meshae.model import MeshAEModel
from meshae.utils import dequantize


# Seed every RNG that influences this run (weight init, any numpy sampling) so
# Restart & Run All reproduces the same losses. torch.manual_seed also seeds CUDA.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the first GPU but fall back to CPU instead of crashing on machines
# without CUDA.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Per-feature embedding/quantization settings.
# NOTE(review): `high_low` presumably is the (upper, lower) value range used for
# binning, and `num_bins` overrides the default bin count (only set for "norm")
# — confirm against MeshAEFeatEmbedConfig.
feat_configs = {
    "vrtx": MeshAEFeatEmbedConfig(high_low=(1.0, -1.0)),
    "acos": MeshAEFeatEmbedConfig(high_low=(np.pi, 0.0)),
    "norm": MeshAEFeatEmbedConfig(high_low=(1.0, -1.0), num_bins=512),
    "area": MeshAEFeatEmbedConfig(high_low=(1.0, 0.0)),
}
model = MeshAEModel(
    feat_configs,
    num_sageconv_layers=1,
    num_quantizers=2,
    num_encoder_layers=2,
    num_decoder_layers=2,
    num_refiner_layers=2,
).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-6)

In [2]:
# Step budget for this quick smoke-test run (the previous `if i > 16: break`
# also performed 18 optimizer steps, i = 0..17).
MAX_STEPS = 18

# Send batches to wherever the model lives rather than repeating a hard-coded
# device string.
device = next(model.parameters()).device

dataset = MeshAEDataset("../data/objaverse/train/")
dataloader = DataLoader(dataset, batch_size=8, collate_fn=dataset.collate_fn)

loop = tqdm.tqdm(dataloader)
losses = []
for i, batch in enumerate(loop):
    batch = {k: v.to(device) for k, v in batch.items()}
    # `l` is the scalar total loss; `p` (and the last `batch`) are reused by
    # the decoding cells below. TODO: give b/p/c descriptive names once their
    # meaning is confirmed in MeshAEModel.forward.
    l, b, p, c = model(**batch)
    losses.append(l.item())
    loop.set_description(f"total loss = {losses[-1]:.8f}")
    l.backward()
    optimizer.step()
    optimizer.zero_grad()

    if i + 1 >= MAX_STEPS:
        break

# Breaking out of the loop leaves the progress bar open; finalize it explicitly.
loop.close()

  0%|          | 0/132 [00:00<?, ?it/s]

total loss = 4.15646648:  13%|█▎        | 17/132 [01:13<08:16,  4.32s/it]


In [3]:
# Decode the vertex-coordinate predictions `p` from the last training batch:
# take the highest-scoring bin along the last axis, then map bin indices back
# to coordinates in the "vrtx" feature's configured range.
cfg = feat_configs["vrtx"]
faces = dequantize(
    # argmax returns integer indices, which never carry autograd history,
    # so no .detach() is needed.
    p.argmax(-1),
    high_low=cfg.high_low,
    num_bins=cfg.num_bins,
).reshape(p.size(0), -1, 3, 3)  # presumably (batch, num_faces, 3 corners, xyz) — confirm vs. model output layout

In [4]:
# Keep only the valid (unpadded) faces of the first sample in the batch, as a
# NumPy array for trimesh. Both operands are moved to CPU before masking.
sample_faces = faces[0].cpu()
sample_mask = batch["face_masks"][0].cpu()
face = sample_faces[sample_mask].numpy()
face.shape

(489, 3, 3)

In [5]:
def create_mesh_from_faces(face_vertices):
    """
    Build a trimesh.Trimesh from per-face corner coordinates.

    Corners with exactly equal coordinates are merged into one vertex, so the
    resulting mesh shares vertices between adjacent faces. The vertex array is
    sorted lexicographically (a consequence of ``np.unique``).

    Args:
        face_vertices: numpy array of shape (n, 3, 3) -- n faces, each with
                       3 corners, each corner an xyz coordinate.

    Returns:
        trimesh.Trimesh whose ``faces`` index into the deduplicated vertices.
    """
    # Flatten faces into one row per corner: (n*3, 3).
    corner_xyz = face_vertices.reshape(-1, 3)

    # Deduplicate corners; `corner_to_vertex[k]` is the unique-vertex index of
    # corner k. Reshaping with (-1, 3) also copes with NumPy versions that
    # return the inverse as a 2-D column.
    unique_xyz, corner_to_vertex = np.unique(
        corner_xyz, axis=0, return_inverse=True
    )
    return trimesh.Trimesh(
        vertices=unique_xyz,
        faces=corner_to_vertex.reshape(-1, 3),
    )


# Rebuild a mesh from the masked predicted face coordinates and render it.
# NOTE(review): the saved output is a NameError for `trimesh`, even though
# `import trimesh` is in the first cell (saved with execution count None), so
# the kernel was stale when this ran — Restart & Run All before this cell.
mesh = create_mesh_from_faces(face)
mesh.show()

NameError: name 'trimesh' is not defined