In [1]:
import json

import numpy as np
import torch
from torch.utils.data import DataLoader, Subset

from meshae.config import MeshAEModelConfig
from meshae.dataset import MeshAEDataset, MeshAECollateFn
from meshae.model import MeshAEModel

In [None]:
with open("../configs/default/config.json") as fin:
    config = MeshAEModelConfig.from_dict(json.load(fin))



state_dict = torch.load("../checkpoints/default/ckpt-01-0100.pt").get("model_state_dict")
model = MeshAEModel(**config.to_dict())
model.load_state_dict(state_dict)
model = model.to("cuda")

collate_fn = MeshAECollateFn()
dataset = MeshAEDataset("../data/objaverse/train", sort_by="zyx")
dataset.objects = dataset.objects[350 * 8:550*8]
dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=False,
    collate_fn=collate_fn,
    pin_memory=True,
    num_workers=1,
)

In [10]:
with torch.no_grad():
    for i, batch in enumerate(dataloader):
        batch = {k: v.to("cuda") for k, v in batch.items()}
        loss, (recon_loss, commit_loss), logits, _ = model(**batch)
        print(i, loss, recon_loss, commit_loss)

        if np.isnan(loss.cpu().numpy()):
            break

0 tensor(4.7432, device='cuda:0') tensor(4.7430, device='cuda:0') tensor(0.0002, device='cuda:0')
1 tensor(4.4589, device='cuda:0') tensor(4.4588, device='cuda:0') tensor(5.2520e-05, device='cuda:0')
2 tensor(4.7622, device='cuda:0') tensor(4.7622, device='cuda:0') tensor(8.1368e-05, device='cuda:0')
3 tensor(4.6708, device='cuda:0') tensor(4.6568, device='cuda:0') tensor(0.0140, device='cuda:0')
4 tensor(4.4988, device='cuda:0') tensor(4.4986, device='cuda:0') tensor(0.0002, device='cuda:0')
5 tensor(4.5389, device='cuda:0') tensor(4.5387, device='cuda:0') tensor(0.0001, device='cuda:0')
6 tensor(4.8936, device='cuda:0') tensor(4.8935, device='cuda:0') tensor(6.6920e-05, device='cuda:0')
7 tensor(4.6487, device='cuda:0') tensor(4.6487, device='cuda:0') tensor(6.9527e-05, device='cuda:0')
8 tensor(4.5516, device='cuda:0') tensor(4.5515, device='cuda:0') tensor(7.9763e-05, device='cuda:0')
9 tensor(4.7768, device='cuda:0') tensor(4.7767, device='cuda:0') tensor(0.0001, device='cuda:0')


In [4]:
batch

{'vertices': tensor([[[-1.3152e-01, -4.0179e-01, -1.2543e-01],
          [-1.3940e-01, -4.1330e-01, -1.3330e-01],
          [-1.3940e-01, -4.1330e-01,  1.3330e-01],
          ...,
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00]],
 
         [[-6.5738e-02,  6.5738e-02, -2.6858e-01],
          [-5.4697e-02,  1.9721e-01, -1.9721e-01],
          [-1.6409e-01,  1.6409e-01, -1.6409e-01],
          ...,
          [-1.9721e-01, -1.9721e-01,  5.4697e-02],
          [-2.0308e-01, -1.9870e-01,  5.3460e-02],
          [-2.0308e-01, -1.9870e-01, -5.3461e-02]],
 
         [[-2.3676e-01, -1.4593e-01,  3.8506e-01],
          [-2.3676e-01, -1.4593e-01,  3.6286e-01],
          [-2.1518e-01, -1.4593e-01,  3.6286e-01],
          ...,
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00]],
 
         ...,
 
