In [1]:
import os
import torch
import yaml
from pathlib import Path
import random
import matplotlib.pyplot as plt
import numpy as np
import tqdm
import plotly.graph_objects as go
from torch.utils.data import Dataset

from nerfstudio.configs import base_config as cfg
from nerfstudio.configs.method_configs import method_configs
from nerfstudio.data.dataparsers.nerfosr_dataparser import NeRFOSR, NeRFOSRDataParserConfig
from nerfstudio.pipelines.base_pipeline import VanillaDataManager
from nerfstudio.field_components.field_heads import FieldHeadNames
from nerfstudio.cameras.rays import RayBundle, RaySamples, Frustums
from nerfstudio.utils.colormaps import apply_depth_colormap
from nerfstudio.field_components.encodings import SHEncoding, NeRFEncoding
import tinycudann as tcnn

from reni_neus.models.reni_neus_model import RENINeuSFactoModelConfig, RENINeuSFactoModel
from reni_neus.utils.utils import look_at_target, random_points_on_unit_sphere
from reni_neus.data.datamanagers.reni_neus_datamanager import RENINeuSDataManagerConfig, RENINeuSDataManager
from reni_neus.configs.ddf_config import DirectionalDistanceField
from reni_neus.configs.reni_neus_config import RENINeuS
from reni_neus.utils.utils import find_nerfstudio_project_root, rot_z

project_root = find_nerfstudio_project_root(Path(os.getcwd()))
# set current working directory to nerfstudio project root
os.chdir(project_root)

# setup config
test_mode = 'test'
world_size = 1
local_rank = 0
device = 'cuda:0'

reni_neus_config = RENINeuS
# reni_neus_config.config.pipeline.visibility_ckpt_path = Path('outputs/ddf/ddf/2023-08-31_101658/')
# reni_neus_config.config.pipeline.visibility_ckpt_step = 20000
# reni_neus_config.config.pipeline.reni_neus_ckpt_path = Path('outputs/reni-neus/reni-neus/2023-08-30_111340/')
# reni_neus_config.config.pipeline.reni_neus_ckpt_step = 100000
reni_neus_config.config.pipeline.model.use_visibility = True
reni_neus_config.config.pipeline.model.visibility_threshold = 0.1
reni_neus_config.config.pipeline.model.fit_visibility_field = False


pipeline = reni_neus_config.config.pipeline.setup(device=device, test_mode=test_mode, world_size=world_size, local_rank=local_rank)
datamanager = pipeline.datamanager
model = pipeline.model
model = model.eval()

Output()

Output()

Output()

Output()

In [2]:
ray_bundle, batch = datamanager.next_train(0)
# # need to update batch['indices'][:, 0] using indices_to_session which is a dict of keys: indices (ints) and values: session (ints)
# batch['indices'][:, 0] = torch.tensor([indices_to_session[i.item()] for i in batch['indices'][:, 0]]).type_as(batch['indices'][:, 0])

In [37]:
image_idx, camera_ray_bundle, batch = datamanager.next_eval_image(1)

In [38]:
camera_ray_bundle.camera_indices[:, :, 0]

tensor([[2, 2, 2,  ..., 2, 2, 2],
        [2, 2, 2,  ..., 2, 2, 2],
        [2, 2, 2,  ..., 2, 2, 2],
        ...,
        [2, 2, 2,  ..., 2, 2, 2],
        [2, 2, 2,  ..., 2, 2, 2],
        [2, 2, 2,  ..., 2, 2, 2]], device='cuda:0')

In [5]:
batch['image_idx']

0

In [31]:
camera_ray_bundle, batch = next(datamanager.iter_eval_dataloader)
batch

{'image_idx': 32,
 'image': tensor([[[0.8039, 0.8078, 0.8275],
          [0.8039, 0.8078, 0.8275],
          [0.8000, 0.8039, 0.8235],
          ...,
          [0.9255, 0.9216, 0.9451],
          [0.9333, 0.9216, 0.9490],
          [0.9333, 0.9216, 0.9490]],
 
         [[0.8039, 0.8000, 0.8235],
          [0.8039, 0.8000, 0.8235],
          [0.8000, 0.8039, 0.8235],
          ...,
          [0.9333, 0.9294, 0.9529],
          [0.9333, 0.9216, 0.9490],
          [0.9333, 0.9216, 0.9490]],
 
         [[0.8118, 0.8078, 0.8314],
          [0.8078, 0.8039, 0.8275],
          [0.8000, 0.8039, 0.8235],
          ...,
          [0.9294, 0.9255, 0.9490],
          [0.9333, 0.9216, 0.9490],
          [0.9333, 0.9216, 0.9490]],
 
         ...,
 
         [[0.5647, 0.5373, 0.4980],
          [0.5922, 0.5686, 0.5216],
          [0.6980, 0.6588, 0.6118],
          ...,
          [0.5647, 0.4784, 0.3882],
          [0.5373, 0.4588, 0.3608],
          [0.4980, 0.4196, 0.3216]],
 
         [[0.6784, 0.