In [1]:
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional

import trimesh
import torch
import numpy as np
import torch.nn.functional as F
from einops import rearrange, reduce
from scipy.spatial.transform import Rotation
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from torchtyping import TensorType

from meshae.utils import compute_face_edges, compute_normalized_mesh, compute_sorted_faces


def collate_fn(data, padding_idx: int = -1):

    def _collate(key, create_mask=False):
        batch, s_len = [], []
        for datum in data:
            batch.append(datum[key])
            s_len.append(datum[key].size(0))

        batch = pad_sequence(batch, batch_first=True, padding_value=padding_idx)
        if create_mask:
            array = torch.arange(max(s_len), device=batch.device).unsqueeze(0)
            masks = torch.cat(
                [array < len_ for len_ in s_len],
                dim=0,
            )
            return batch, masks

        return batch

    faces, face_masks = _collate("faces", create_mask=True)
    edges, edge_masks = _collate("edges", create_mask=True)
    batch = {
        "vertices": _collate("vertices"),
        "faces": faces,
        "edges": edges,
        "face_masks": face_masks,
        "edge_masks": edge_masks,
    }

    return batch


class MeshAEDataset(Dataset):
    r"""
    """

    def __init__(
        self,
        path: Path,
        *,
        sort_face_by: str = "zxy",
        neighbor_if_share_one_vertex: bool = False,
        include_self: bool = False,
    ) -> None:
        super().__init__()

        self.path = Path(path)
        self.objects = list(self.path.glob("*.glb"))
        if len(self.objects) == 0:
            msg = f"No '.glb' object found under directory <{str(path)}>."
            raise RuntimeError(msg)

        self.sort_face_by = sort_face_by
        self.neighbor_if_share_one_vertex = neighbor_if_share_one_vertex
        self.include_self = include_self

    def __len__(self) -> int:
        return len(self.objects)

    def __getitem__(self, idx: int) -> dict:
        return self.load_and_process(self.objects[idx])

    def load_and_process(self, path: Path) -> dict[str, TensorType]:
        r"""Load and normalize a single mesh object from given path.

        Parameters
        ----------
        path : pathlib.Path
            Path to the source mesh object.

        Returns
        -------
        faces : TensorType["n_face", 3, int]
            Face tensor.
        edges : TensorType["n_edge", 2, int]
            Face edge tensor.
        vertices : TensorType["n_vrtx", 3, 3, float]
            Vertex tensor.
        """
        mesh = trimesh.load(path, file_type="glb", force="mesh", process=False)
        mesh, _, _ = compute_normalized_mesh(mesh)

        vertices = torch.from_numpy(mesh.vertices)
        faces = compute_sorted_faces(mesh, by=self.sort_face_by, return_tensor=True)
        edges = compute_face_edges(
            faces,
            neighbor_if_share_one_vertex=self.neighbor_if_share_one_vertex,
            include_self=self.include_self,
        )
        return {"faces": faces, "edges": edges, "vertices": vertices}

In [2]:
from meshae.config import MeshAEFeatEmbedConfig
from meshae.model import MeshAEModel


feat_configs = {
    "vrtx": MeshAEFeatEmbedConfig(high_low=(1.0, -1.0)),
    "acos": MeshAEFeatEmbedConfig(high_low=(np.pi, 0.0)),
    "norm": MeshAEFeatEmbedConfig(high_low=(1.0, -1.0)),
    "area": MeshAEFeatEmbedConfig(high_low=(1.0, 0.0)),
}
model = MeshAEModel(feat_configs, num_sageconv_layers=0, num_quantizers=2)

dataset = MeshAEDataset("../data/objaverse/train/")
dataloader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
for batch in dataloader:
    model(**batch)
    break

RuntimeError: Sizes of tensors must match except in dimension 3. Expected size 3 but got size 1 for tensor number 1 in the list.

In [3]:
vertices = batch["vertices"]
faces = batch["faces"]
face_masks = batch["face_masks"]

faces = faces.masked_fill(~face_masks.unsqueeze(-1), 0)
index = torch.arange(vertices.size(0), device=vertices.device)[:, None, None]

coords = vertices[index, faces]

In [10]:
features = model.embedding.extract_features(coords)
features["acos"].shape

embeds = torch.cat(
    [
        model.embedding.embeddings[name](indices)
        for name, indices in model.embedding.extract_features(coords).items()
    ],
    dim=-1,
)

RuntimeError: Sizes of tensors must match except in dimension 3. Expected size 3 but got size 1 for tensor number 1 in the list.

In [24]:
embeds = [
    model.embedding.embeddings[name](indices).flatten(-2)
    for name, indices in model.embedding.extract_features(coords).items()
]

In [25]:
embeds[3].shape

torch.Size([2, 480, 384])

In [15]:
list({"x": 12}.items())

[('x', 12)]