In [1]:
import json

import numpy as np
import torch
from torch.utils.data import DataLoader, Subset

from meshae.config import MeshAEModelConfig
from meshae.dataset import MeshAEDataset, MeshAECollateFn
from meshae.model import MeshAEModel

In [11]:
with open("../configs/default/config.json") as fin:
    config = MeshAEModelConfig.from_dict(json.load(fin))



state_dict = torch.load("../checkpoints/default/ckpt-01-0100.pt").get("model_state_dict")
model = MeshAEModel(**config.to_dict())
model.load_state_dict(state_dict)
model = model.to("cuda")

collate_fn = MeshAECollateFn()
dataset = MeshAEDataset("../data/objaverse/train", sort_by="zyx")
dataset.objects = dataset.objects[350 * 8:550*8]
dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=False,
    collate_fn=collate_fn,
    pin_memory=True,
    num_workers=1,
)

In [12]:
with torch.no_grad():
    for i, batch in enumerate(dataloader):
        batch = {k: v.to("cuda") for k, v in batch.items()}
        loss, (recon_loss, commit_loss), logits, _ = model(**batch)
        print(i, loss, recon_loss, commit_loss)

        if np.isnan(loss.cpu().numpy()):
            break

0 tensor(4.5410, device='cuda:0') tensor(4.5409, device='cuda:0') tensor(9.3491e-05, device='cuda:0')
1 tensor(4.6029, device='cuda:0') tensor(4.6029, device='cuda:0') tensor(6.7874e-05, device='cuda:0')
2 tensor(4.6786, device='cuda:0') tensor(4.6785, device='cuda:0') tensor(4.6270e-05, device='cuda:0')
3 tensor(4.5315, device='cuda:0') tensor(4.5314, device='cuda:0') tensor(0.0001, device='cuda:0')
4 tensor(4.6723, device='cuda:0') tensor(4.6721, device='cuda:0') tensor(0.0002, device='cuda:0')
5 tensor(4.6653, device='cuda:0') tensor(4.6653, device='cuda:0') tensor(6.0893e-05, device='cuda:0')
6 tensor(4.5286, device='cuda:0') tensor(4.5283, device='cuda:0') tensor(0.0002, device='cuda:0')
7 tensor(4.6491, device='cuda:0') tensor(4.6489, device='cuda:0') tensor(0.0002, device='cuda:0')
8 tensor(4.5871, device='cuda:0') tensor(4.5871, device='cuda:0') tensor(7.1031e-05, device='cuda:0')
9 tensor(4.8269, device='cuda:0') tensor(4.8268, device='cuda:0') tensor(9.1633e-05, device='cuda:

In [4]:
batch

{'vertices': tensor([[[-1.3152e-01, -4.0179e-01, -1.2543e-01],
          [-1.3940e-01, -4.1330e-01, -1.3330e-01],
          [-1.3940e-01, -4.1330e-01,  1.3330e-01],
          ...,
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00]],
 
         [[-6.5738e-02,  6.5738e-02, -2.6858e-01],
          [-5.4697e-02,  1.9721e-01, -1.9721e-01],
          [-1.6409e-01,  1.6409e-01, -1.6409e-01],
          ...,
          [-1.9721e-01, -1.9721e-01,  5.4697e-02],
          [-2.0308e-01, -1.9870e-01,  5.3460e-02],
          [-2.0308e-01, -1.9870e-01, -5.3461e-02]],
 
         [[-2.3676e-01, -1.4593e-01,  3.8506e-01],
          [-2.3676e-01, -1.4593e-01,  3.6286e-01],
          [-2.1518e-01, -1.4593e-01,  3.6286e-01],
          ...,
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00]],
 
         ...,
 
