In [1]:
from types import SimpleNamespace   

from kornia.utils.grid import create_meshgrid3d
import torch

from models.networks import NGPGv2
from models.rendering_NGPA import render
from utils.utils import slim_ckpt, load_ckpt
from datasets import dataset_dict

In [2]:
# Run configuration for evaluating a trained NGPGv2 checkpoint on the
# "breville" scene. Values are exposed as attributes (hparams.<name>) via
# SimpleNamespace, built directly rather than from an intermediate dict.
# NOTE(review): `vocab_size` is not what the model cell passes to NGPGv2 —
# that cell uses `task_curr + 1` (equal to 5 here); confirm which is intended.
hparams = SimpleNamespace(
    root_dir="dataset/WAT/breville",
    dataset_name="colmap_ngpa_render",
    exp_name="breville",
    downsample=1.0,
    num_epochs=20,
    batch_size=8192,
    lr=1e-2,
    eval_lpips=True,
    task_curr=4,
    task_number=5,
    dim_a=48,
    dim_g=16,
    scale=8.0,
    vocab_size=5,
    weight_path="ckpts/NGPGv2/colmap_ngpa/breville/epoch=19-v6.ckpt",
    render_fname="UB",
    val_only=True,
    use_exposure=False,
)

### Setup: load the test split and build the model from the checkpoint

In [4]:
# Look up the dataset class registered under `hparams.dataset_name` and
# instantiate only its test split. `dataset` and `kwargs` stay bound as
# globals in case later cells reuse them.
dataset = dataset_dict[hparams.dataset_name]
kwargs = dict(root_dir=hparams.root_dir, downsample=hparams.downsample)
test_dataset = dataset(split='test', **kwargs)

self.img_wh = (1920, 1440)
[test] near_far = 0.2990965247154236/60.96563720703125, scale = 7.620704650878906
Loading 34 test images ...


100%|██████████| 34/34 [00:02<00:00, 12.23it/s]


self.poses_interpolate = torch.Size([2934, 3, 4]), self.ts_interpolate = torch.Size([2934]), self.task_ids_interpolate = 2934


In [None]:
# Display the first test sample. Its structure is defined by the dataset
# class's __getitem__, which is not visible in this notebook.
test_dataset[0]

In [None]:
# Output activation for the RGB head depends on whether exposure modelling
# is enabled ('None' string is passed through to NGPGv2 as-is).
if hparams.use_exposure:
    rgb_act = 'None'
else:
    rgb_act = 'Sigmoid'

# The model is sized for tasks 0..task_curr, hence vocab_size = task_curr + 1.
model = NGPGv2(
    scale=hparams.scale,
    vocab_size=hparams.task_curr + 1,
    rgb_act=rgb_act,
    dim_a=hparams.dim_a,
    dim_g=hparams.dim_g,
)

# Allocate the grid buffers before loading the checkpoint:
#   density_grid: zeros of shape (cascades, G**3)
#   grid_coords:  every integer (int32) coordinate of a G x G x G grid,
#                 flattened to shape (G**3, 3). The positional False is,
#                 presumably, kornia's `normalized_coordinates` flag.
G = model.grid_size
model.register_buffer('density_grid', torch.zeros(model.cascades, G**3))
model.register_buffer(
    'grid_coords',
    create_meshgrid3d(G, G, G, False, dtype=torch.int32).reshape(-1, 3),
)

In [None]:
# Load the trained weights at hparams.weight_path into `model`.
# NOTE(review): the density_grid / grid_coords buffers are registered in the
# previous cell, presumably so matching checkpoint entries have somewhere to
# load into — confirm against load_ckpt's implementation.
load_ckpt(model, hparams.weight_path)

In [None]:
# Display the module repr to inspect the constructed NGPGv2 architecture.
model

In [None]:
# Materialise every registered buffer as (name, tensor) pairs for inspection,
# e.g. to confirm density_grid / grid_coords after loading the checkpoint.
[*model.named_buffers()]

In [None]:
# Display the dataset's `directions` attribute (presumably per-pixel camera
# ray directions — confirm shape/meaning in the dataset class).
test_dataset.directions