In [1]:
import os
os.chdir("/workspace/")
import sys
sys.path.append("/workspace/reni_neus")


import torch
import yaml
from pathlib import Path
import random
import matplotlib.pyplot as plt
import numpy as np
import tqdm
import plotly.graph_objects as go
from torch.utils.data import Dataset
import imageio

from nerfstudio.configs import base_config as cfg
from nerfstudio.configs.method_configs import method_configs
from nerfstudio.data.dataparsers.nerfosr_dataparser import NeRFOSR, NeRFOSRDataParserConfig
from nerfstudio.pipelines.base_pipeline import VanillaDataManager
from nerfstudio.field_components.field_heads import FieldHeadNames
from nerfstudio.cameras.rays import RayBundle
from nerfstudio.utils.colormaps import apply_depth_colormap
from nerfstudio.field_components.encodings import SHEncoding, NeRFEncoding
import tinycudann as tcnn

from reni.reni_config import RENIField
from reni.field_components.field_heads import RENIFieldHeadNames
from reni.data.reni_datamanager import RENIDataManagerConfig, RENIDataManager

def rotation_matrix(axis: np.ndarray, angle: float) -> torch.Tensor:
    """
    Return the 3x3 matrix that rotates counter-clockwise (right-hand rule) by ``angle`` about ``axis``.

    Built with the Euler-Rodrigues formula, so ``R @ v`` rotates a column vector ``v``.

    Args:
        axis: rotation axis as any 3-element array-like. It is normalised here, so it
            does not need to be unit length.
        angle: rotation angle in radians.

    Returns:
        A float32 ``torch.Tensor`` of shape (3, 3) on the CPU. Callers move it to their
        device themselves.

    Raises:
        ValueError: if ``axis`` is not a 3-vector or has zero length. A zero-length axis
            would otherwise produce a NaN matrix.
    """
    axis = np.asarray(axis, dtype=np.float64)
    if axis.shape != (3,):
        raise ValueError(f"axis must be a 3-vector, got shape {axis.shape}")
    norm = np.linalg.norm(axis)
    if norm == 0.0:
        raise ValueError("axis must be non-zero")
    axis = axis / norm

    # Euler-Rodrigues parameters (a, b, c, d) form a unit quaternion for the rotation.
    a = np.cos(angle / 2.0)
    b, c, d = -axis * np.sin(angle / 2.0)
    aa, bb, cc, dd = a * a, b * b, c * c, d * d
    bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
    rotation = np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
                         [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
                         [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]])
    # Downstream code works with float32 torch tensors.
    return torch.from_numpy(rotation).float()

# setup config
test_mode = 'val'
world_size = 1
local_rank = 0
device = 'cuda:0'

reni_ckpt_path = '/workspace/outputs/unnamed/reni/2023-07-24_145239/' # model without vis
step = 50000

# Join with pathlib: `reni_ckpt_path` already ends in '/', so plain string concatenation
# produced a doubled separator.
ckpt_file = Path(reni_ckpt_path) / 'nerfstudio_models' / f'step-{step:09d}.ckpt'
ckpt = torch.load(ckpt_file, map_location=device)

# Model weights are stored in the pipeline state dict under a '_model.' prefix. Strip it
# so the dict can be passed straight to `model.load_state_dict` below.
model_prefix = '_model.'
reni_model_dict = {
    key[len(model_prefix):]: value
    for key, value in ckpt['pipeline'].items()
    if key.startswith(model_prefix)
}

datamanager: RENIDataManager = RENIField.config.pipeline.datamanager.setup(
    device=device, test_mode=test_mode, world_size=world_size, local_rank=local_rank,
)
datamanager.to(device)

num_train_data = datamanager.num_train
num_eval_data = datamanager.num_eval

# Instantiate the model from the RENI method config, sized by the datamanager.
# NOTE(review): the checkpoint above is labelled "model without vis" but this was originally
# commented "config with vis". Confirm the config matches the checkpoint. The strict
# `load_state_dict` below will raise on missing/unexpected keys.
model = RENIField.config.pipeline.model.setup(
    scene_box=datamanager.train_dataset.scene_box,
    num_train_data=num_train_data,
    num_eval_data=num_eval_data,
    metadata=datamanager.train_dataset.metadata,
)

model.to(device)
model.load_state_dict(reni_model_dict)
model.eval()

print('Model loaded')

Computing min and max values of the dataset...
Min and max values of the dataset are (-8.033231735229492, 5.505331516265869).


Output()

Output()

Model loaded


In [2]:
# Counter advanced by the rendering cell below; starting at -1 means its first run sees 0.
i: int = -1

In [29]:
# Render one eval image with the RENI field using all-zero latent codes, then display it.
# NOTE(review): a CUDA "device-side assert" (as this cell previously raised) leaves the CUDA
# context unusable, and every later CUDA call re-raises it. Restart the kernel and re-run from
# the top. Set CUDA_LAUNCH_BLOCKING=1 before importing torch to get an accurate stack trace.
i = i + 1  # NOTE(review): `i` is advanced on every run but not used below; presumably it was meant to select the image. Confirm.
image_idx = 0
rotation_deg = 0.0
latent_shape = (1, 36, 3)  # NOTE(review): assumed to match the model's latent code shape; confirm against the RENI config.

# Validate on the host before the index reaches a CUDA kernel. On the GPU, an out-of-range
# index surfaces only as an opaque device-side assert.
if not 0 <= image_idx < num_eval_data:
    raise IndexError(f"image_idx={image_idx} is out of range for {num_eval_data} eval images")

ray_bundle, batch = datamanager.fixed_indices_eval_dataloader.get_data_from_image_idx(image_idx)
H, W = ray_bundle.shape
ray_bundle = ray_bundle.reshape(-1)

rotation = rotation_matrix(np.array([0, 1, 0]), np.deg2rad(rotation_deg)).to(device)
latent_codes = torch.zeros(latent_shape, device=device)

# Inference only: avoid building an autograd graph over all H*W rays.
with torch.no_grad():
    field_outputs = model.field.forward(ray_bundle, rotation=rotation, latent_codes=latent_codes)
    outputs = {
        "rgb": field_outputs[RENIFieldHeadNames.RGB].reshape(H, W, 3),
        "mu": field_outputs[RENIFieldHeadNames.MU],
        "log_var": field_outputs[RENIFieldHeadNames.LOG_VAR],
    }
    metrics_dict, image_dict = model.get_image_metrics_and_images(outputs, batch)

fig, ax = plt.subplots()
ax.imshow(image_dict['img'].detach().cpu().numpy())
ax.set_title(f"Eval image {image_idx}")
ax.axis("off")
plt.show()

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
