In [1]:
import os
import torch
import yaml
from pathlib import Path
import random
import matplotlib.pyplot as plt
import numpy as np
import tqdm
import plotly.graph_objects as go
from torch.utils.data import Dataset

from nerfstudio.configs import base_config as cfg
from nerfstudio.configs.method_configs import method_configs
from nerfstudio.data.dataparsers.nerfosr_dataparser import NeRFOSR, NeRFOSRDataParserConfig
from nerfstudio.pipelines.base_pipeline import VanillaDataManager
from nerfstudio.field_components.field_heads import FieldHeadNames
from nerfstudio.cameras.rays import RayBundle, RaySamples, Frustums
from nerfstudio.utils.colormaps import apply_depth_colormap
from nerfstudio.field_components.encodings import SHEncoding, NeRFEncoding
import tinycudann as tcnn

from reni_neus.models.reni_neus_model import RENINeuSFactoModelConfig, RENINeuSFactoModel
from reni_neus.utils.utils import look_at_target, random_points_on_unit_sphere
from reni_neus.data.datamanagers.reni_neus_datamanager import RENINeuSDataManagerConfig, RENINeuSDataManager
from reni_neus.configs.ddf_config import DirectionalDistanceField
from reni_neus.configs.reni_neus_config import RENINeuS
from reni_neus.utils.utils import find_nerfstudio_project_root, rot_z

project_root = find_nerfstudio_project_root(Path(os.getcwd()))
# set current working directory to nerfstudio project root
os.chdir(project_root)

# setup config
test_mode = 'val'
world_size = 1
local_rank = 0
device = 'cuda:0'

reni_neus_config = RENINeuS
reni_neus_config.config.pipeline.visibility_ckpt_path = Path('outputs/ddf/ddf/2023-08-31_101658/')
reni_neus_config.config.pipeline.visibility_ckpt_step = 20000
reni_neus_config.config.pipeline.reni_neus_ckpt_path = Path('outputs/reni-neus/reni-neus/2023-08-30_111340/')
reni_neus_config.config.pipeline.reni_neus_ckpt_step = 100000
reni_neus_config.config.pipeline.model.use_visibility = True
reni_neus_config.config.pipeline.model.visibility_threshold = 0.1
reni_neus_config.config.pipeline.model.fit_visibility_field = False


pipeline = reni_neus_config.config.pipeline.setup(device=device, test_mode=test_mode, world_size=world_size, local_rank=local_rank)
datamanager = pipeline.datamanager
model = pipeline.model
model = model.eval()

Output()

Output()

Output()

Output()

In [2]:
batch = next(iter(datamanager.eval_session_holdout_dataloader))

In [4]:
session_image_idxs = datamanager.eval_dataset.metadata["session_eval_indices"]
# currently session_image_idxs is the image idxs relative to session
# but we want it to be relative to the whole dataset
image_idxs_holdout = []
for session_relative_idx, session in zip(session_image_idxs, datamanager.eval_dataset.metadata["session_to_indices"]):
  image_idxs_holdout.append(int(session[session_relative_idx]))

In [6]:
datamanager.eval_dataset.metadata["session_to_indices"]

{'01-08_07_30': [0],
 '08-08_16_00': [1],
 '28-07_10_00': [2],
 '29-07_12_00': [3],
 '29-07_20_30': [4]}

In [15]:
subset = torch.utils.data.Subset(datamanager.eval_dataset, [0, 1, 2])

In [20]:
datamanager.eval_dataset.metadata

{'semantics': Semantics(filenames=['data/NeRF-OSR/Data/lk2/final/validation/cityscapes_mask/01-08_07_30_IMG_6660.png', 'data/NeRF-OSR/Data/lk2/final/validation/cityscapes_mask/08-08_16_00_IMG_7850.png', 'data/NeRF-OSR/Data/lk2/final/validation/cityscapes_mask/28-07_10_00_IMG_5262.png', 'data/NeRF-OSR/Data/lk2/final/validation/cityscapes_mask/29-07_12_00_IMG_5512.png', 'data/NeRF-OSR/Data/lk2/final/validation/cityscapes_mask/29-07_20_30_IMG_5592.png'], classes=['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', 'bicycle'], colors=tensor([[128,  64, 128],
         [244,  35, 232],
         [ 70,  70,  70],
         [102, 102, 156],
         [190, 153, 153],
         [153, 153, 153],
         [250, 170,  30],
         [220, 220,   0],
         [107, 142,  35],
         [152, 251, 152],
         [ 70, 130, 180],
         [220,  20,  60],
         [255,   0

In [17]:
batch

{'image_idx': 0,
 'image': tensor([[[0.8941, 0.8980, 0.9176],
          [0.8980, 0.9020, 0.9216],
          [0.8941, 0.8980, 0.9176],
          ...,
          [0.9255, 0.9216, 0.9529],
          [0.9255, 0.9216, 0.9529],
          [0.9255, 0.9216, 0.9529]],
 
         [[0.8980, 0.8941, 0.9137],
          [0.8980, 0.8941, 0.9137],
          [0.8941, 0.8980, 0.9176],
          ...,
          [0.9255, 0.9216, 0.9529],
          [0.9255, 0.9216, 0.9529],
          [0.9294, 0.9255, 0.9569]],
 
         [[0.8941, 0.8902, 0.9098],
          [0.8941, 0.8902, 0.9098],
          [0.8941, 0.8980, 0.9176],
          ...,
          [0.9294, 0.9255, 0.9569],
          [0.9294, 0.9255, 0.9569],
          [0.9294, 0.9255, 0.9569]],
 
         ...,
 
         [[0.5020, 0.4745, 0.4510],
          [0.4549, 0.4275, 0.4039],
          [0.4588, 0.4314, 0.4078],
          ...,
          [0.4745, 0.4667, 0.4196],
          [0.4784, 0.4627, 0.4157],
          [0.4941, 0.4824, 0.4235]],
 
         [[0.5255, 0.4