In [10]:
import json

import numpy as np
import torch
from einops import rearrange
from torch.utils.data import DataLoader

from meshae.config import MeshAEModelConfig
from meshae.dataset import MeshAEDataset, MeshAECollateFn
from meshae.model import MeshAEModel
from meshae.utils import quantize, tensor_describe

In [2]:
# Build the model from its JSON config and restore the trained weights.
with open("../configs/default/config.json") as fin:
    config = MeshAEModelConfig.from_dict(json.load(fin))

# map_location="cpu": the model is kept on the CPU in this notebook, so tensors that
# were saved from a GPU run must be remapped, otherwise loading fails without CUDA.
# Index with [...] instead of .get(): a checkpoint lacking the key should raise a
# KeyError here rather than pass None into load_state_dict.
state_dict = torch.load(
    "../checkpoints/default/ckpt-01-0100.pt",
    map_location="cpu",
)["model_state_dict"]
model = MeshAEModel(**config.to_dict())
model.load_state_dict(state_dict)
# This notebook only inspects the trained model, so switch off train-mode behaviour
# (e.g. dropout) before running any forward pass.
model.eval()
# To run on a GPU instead, move both the model and every batch to the same device:
# model = model.to("cuda")

# Batch size is shared by the object slice and the DataLoader below.
batch_size = 8

collate_fn = MeshAECollateFn()
dataset = MeshAEDataset("../data/objaverse/train", sort_by="zyx")
# Keep objects [350 * batch_size, 550 * batch_size): 200 batches' worth of objects,
# served in order because shuffle=False.
dataset.objects = dataset.objects[350 * batch_size:550 * batch_size]
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
    # Pinned host memory only helps host-to-GPU copies; requesting it without CUDA
    # just produces a warning.
    pin_memory=torch.cuda.is_available(),
    num_workers=1,
)

In [None]:
# Run the model on the first batch only, without recording gradients, and print the
# total loss next to its reconstruction and commitment components.
with torch.no_grad():
    # Pairing with range(1) yields at most one (index, batch) pair, so the loop body
    # runs once for a non-empty dataloader and not at all for an empty one.
    for i, batch in zip(range(1), dataloader):
        # For GPU inference, move the batch to the model's device first:
        # batch = {k: v.to("cuda") for k, v in batch.items()}
        loss, (recon_loss, commit_loss), logits, coords = model(**batch)
        print(i, loss, recon_loss, commit_loss)

0 tensor(3.9983) tensor(3.9872) tensor(0.0111)


In [13]:
# Summarise the weight tensor of the module at index 2 of the decoder's `proj_refine`
# container. The value returned by `tensor_describe` is displayed as the cell output
# (a dict with keys such as avg, std, min, max, p25/p50/p75, norm and numel).
tensor_describe(model.decoder.proj_refine[2].weight)

{'avg': -3.4784083254635334e-05,
 'std': 0.03655720874667168,
 'min': -0.1773606538772583,
 'max': 0.19272103905677795,
 'p25': -0.02468550205230713,
 'p50': -2.3601360226166435e-05,
 'p75': 0.024616852402687073,
 'norm': 56.149593353271484,
 'numel': 2359296,
 'numel_nonzero': 2359296,
 'numel_nonzero_pct': 100.0}

In [5]:
# Turn the forward-pass outputs into per-class log-probabilities (`qogits`) and a
# one-hot target of the same shape (`qoords`).

# Move the trailing class axis `q` of `logits` to dim 1, flatten every axis between
# the batch axis and `q` into one, and normalise over the classes.
qogits = rearrange(logits, "b ... q -> b q (...)").log_softmax(1)

# Take the bin count from the logits instead of hard-coding 64, so the indices
# scattered below always stay within dim 1 of `qogits` if the configuration changes.
num_bins = logits.shape[-1]

# Quantized coordinates flattened to (b, 1, n) so they can index dim 1 of `qogits`.
# NOTE(review): assumes `quantize` returns int64 bin indices in [0, num_bins), which
# is what `scatter` requires — confirm against meshae.utils.quantize.
qoords_idx = rearrange(
    quantize(coords, high_low=config.feature_configs["vrtx"].high_low, num_bins=num_bins),
    "b ... -> b 1 (...)",
)

# One-hot encoding: 1.0 at each quantized bin along dim 1, 0.0 everywhere else.
qoords = torch.zeros_like(qogits).scatter(1, qoords_idx, 1.0)