In [None]:
import os
import torch
import yaml
from pathlib import Path
import random
import matplotlib.pyplot as plt
import numpy as np
import tqdm
import plotly.graph_objects as go
from torch.utils.data import Dataset
import pyexr

from nerfstudio.configs import base_config as cfg
from nerfstudio.configs.method_configs import method_configs
from nerfstudio.data.dataparsers.nerfosr_dataparser import NeRFOSR, NeRFOSRDataParserConfig
from nerfstudio.pipelines.base_pipeline import VanillaDataManager
from nerfstudio.field_components.field_heads import FieldHeadNames
from nerfstudio.cameras.rays import RayBundle, RaySamples, Frustums
from nerfstudio.utils.colormaps import apply_depth_colormap
from nerfstudio.field_components.encodings import SHEncoding, NeRFEncoding
from nerfstudio.viewer.server import viewer_utils
import tinycudann as tcnn

from reni_neus.models.reni_neus_model import RENINeuSFactoModelConfig, RENINeuSFactoModel
from reni_neus.utils.utils import look_at_target, random_points_on_unit_sphere
from reni_neus.data.datamanagers.reni_neus_datamanager import RENINeuSDataManagerConfig, RENINeuSDataManager
from reni_neus.configs.ddf_config import DirectionalDistanceField
from reni_neus.configs.reni_neus_config import RENINeuS
from reni_neus.utils.utils import find_nerfstudio_project_root, rot_z

from reni.illumination_fields.environment_map_field import EnvironmentMapFieldConfig

project_root = find_nerfstudio_project_root(Path(os.getcwd()))
# set current working directory to nerfstudio project root
os.chdir(project_root)

# setup config
test_mode = 'test'
world_size = 1
local_rank = 0
device = 'cuda:0'

scene = 'site1'

# Per-scene session holdout indices written into the dataparser config below.
SESSION_HOLDOUT_INDICES = {
    'site1': [0, 0, 0, 0, 0],
    'site2': [1, 2, 2, 7, 9],
    'site3': [0, 6, 6, 2, 11],
    'stjacob': [0, 0, 0],
}
# Fail fast (before loading a large checkpoint) instead of silently keeping the
# config's default holdouts for a scene that has no entry above.
if scene not in SESSION_HOLDOUT_INDICES:
    raise ValueError(f"Unknown scene {scene!r}; expected one of {sorted(SESSION_HOLDOUT_INDICES)}")

# NOTE: reni_neus_config is the same object as RENINeuS, so edits below mutate RENINeuS.
reni_neus_config = RENINeuS
reni_neus_config.config.pipeline.datamanager.dataparser.session_holdout_indices = SESSION_HOLDOUT_INDICES[scene]

# Checkpoint directory follows the selected scene so changing `scene` cannot load another scene's weights.
reni_neus_ckpt_path = f'/workspace/reni_neus/models/{scene}'  # model without vis
step = 100000
reni_neus_ckpt = torch.load(
    Path(reni_neus_ckpt_path) / 'nerfstudio_models' / f'step-{step:09d}.ckpt',
    map_location=device,
)

# Keep only the model weights from the pipeline state dict, stripping the '_model.' prefix
# so the keys match the bare model's parameter names.
MODEL_PREFIX = '_model.'
reni_neus_model_dict = {
    key[len(MODEL_PREFIX):]: value
    for key, value in reni_neus_ckpt['pipeline'].items()
    if key.startswith(MODEL_PREFIX)
}

datamanager: RENINeuSDataManager = reni_neus_config.config.pipeline.datamanager.setup(
    device=device, test_mode=test_mode, world_size=world_size, local_rank=local_rank,
)
datamanager.to(device)

# instantiate model with config with vis
# NOTE(review): the checkpoint above is described as "without vis" while this config is
# described as "with vis"; load_state_dict is strict by default, so any key mismatch raises — confirm.
model = reni_neus_config.config.pipeline.model.setup(
    scene_box=datamanager.train_dataset.scene_box,
    num_train_data=len(datamanager.train_dataset),
    num_val_data=datamanager.num_val,
    num_test_data=datamanager.num_test,
    test_mode=test_mode,
)

model.to(device)
model.load_state_dict(reni_neus_model_dict)
model.eval()




In [2]:
import os
import torch
import yaml
from pathlib import Path
import random
import matplotlib.pyplot as plt
import numpy as np
import tqdm
import plotly.graph_objects as go
from torch.utils.data import Dataset
import pyexr

from nerfstudio.configs import base_config as cfg
from nerfstudio.configs.method_configs import method_configs
from nerfstudio.data.dataparsers.nerfosr_dataparser import NeRFOSR, NeRFOSRDataParserConfig
from nerfstudio.pipelines.base_pipeline import VanillaDataManager
from nerfstudio.field_components.field_heads import FieldHeadNames
from nerfstudio.cameras.rays import RayBundle, RaySamples, Frustums
from nerfstudio.utils.colormaps import apply_depth_colormap
from nerfstudio.field_components.encodings import SHEncoding, NeRFEncoding
from nerfstudio.viewer.server import viewer_utils
import tinycudann as tcnn

from reni_neus.models.reni_neus_model import RENINeuSFactoModelConfig, RENINeuSFactoModel
from reni_neus.utils.utils import look_at_target, random_points_on_unit_sphere
from reni_neus.data.datamanagers.reni_neus_datamanager import RENINeuSDataManagerConfig, RENINeuSDataManager
from reni_neus.configs.ddf_config import DirectionalDistanceField
from reni_neus.configs.reni_neus_config import RENINeuS
from reni_neus.utils.utils import find_nerfstudio_project_root, rot_z

from reni.illumination_fields.environment_map_field import EnvironmentMapFieldConfig

project_root = find_nerfstudio_project_root(Path(os.getcwd()))
# set current working directory to nerfstudio project root
os.chdir(project_root)

# setup config
test_mode = 'test'
world_size = 1
local_rank = 0
device = 'cuda:0'

scene = 'site1'

# Single source of truth for the trained RENI-NeuS run: the DDF pipeline config and the
# weight loading below both use it, and it follows `scene`.
reni_neus_ckpt_path = f'/workspace/reni_neus/models/{scene}'  # model without vis
step = 100000

ddf_config = DirectionalDistanceField
ddf_config.config.pipeline.reni_neus_ckpt_path = Path(reni_neus_ckpt_path)

trainer = ddf_config.config.setup(local_rank=local_rank, world_size=world_size)
trainer.setup(test_mode=test_mode)
pipeline = trainer.pipeline
datamanager = pipeline.datamanager
model = pipeline.model
model = model.eval()

reni_neus_ckpt = torch.load(
    Path(reni_neus_ckpt_path) / 'nerfstudio_models' / f'step-{step:09d}.ckpt',
    map_location=device,
)

# Extract only the visibility-field weights from the RENI-NeuS pipeline checkpoint and
# strip their prefix so the keys match the DDF model's own parameter names.
VISIBILITY_PREFIX = '_model.visibility_field.'
ddf_model_dict = {
    key[len(VISIBILITY_PREFIX):]: value
    for key, value in reni_neus_ckpt['pipeline'].items()
    if key.startswith(VISIBILITY_PREFIX)
}


model.load_state_dict(ddf_model_dict)

Output()

Output()

Output()

<All keys matched successfully>

In [5]:
from nerfstudio.utils.io import load_from_json
from nerfstudio.cameras.cameras import Cameras, CameraType
from datetime import datetime
import math
from nerfstudio.utils import colormaps
import os
import cv2
import imageio
import matplotlib.cm as cm

camera_poses_path = '/workspace/reni_neus/publication/ddf_camera_path.json'
meta = load_from_json(Path(camera_poses_path))
fps = meta['fps']

# Each run writes into a fresh timestamped folder:
# /workspace/reni_neus/publication/animations/ddf_<YYYYmmdd_HHMMSS>
datetime_str = datetime.now().strftime("%Y%m%d_%H%M%S")
folder_path = f'/workspace/reni_neus/publication/animations/ddf_{datetime_str}'
os.makedirs(folder_path, exist_ok=True)

# Output resolution deliberately overrides meta['render_height'] / meta['render_width'].
# Use 144 x 256 for quick previews, 1080 x 1920 for final renders.
render_height = 1080
render_width = 1920
cx = render_width / 2.0
cy = render_height / 2.0
fov = meta['keyframes'][0]['fov']  # degrees; the first keyframe's FOV is used for every frame
aspect = render_width / render_height
# Camera paths exported by the nerfstudio viewer store a three.js FOV, which is the
# *vertical* field of view, so the focal length is derived from the image height
# (matching nerfstudio's own camera-path renderer).
fy = render_height / (2 * math.tan(math.radians(fov) / 2))
fx = fy
c2w = torch.eye(4)[:3, :4]  # placeholder pose; replaced per frame before generating rays

camera = Cameras(
    camera_to_worlds=c2w,
    fx=fx,
    fy=fy,
    cx=cx,
    cy=cy,
    width=render_width,
    height=render_height,
    camera_type=CameraType.PERSPECTIVE,
)

# NOTE(review): not used by the rendering in this cell; kept in case later cells rely on it.
base_ray_bundle = datamanager.train_dataset.cameras[0].generate_rays(0)
base_ray_bundle = base_ray_bundle.to(device)


def save_model_output(model_output, frame_num, path, near_plane=0.0, far_plane=2.0):
    """Colour-map the model's predicted termination distance for one frame and save it as a PNG.

    Args:
        model_output: Dict of model outputs; only ``'expected_termination_dist'`` is read.
        frame_num: Frame index, zero-padded to three digits in the file name
            (``frame_000_render.png``, ``frame_001_render.png``, ...).
        path: Existing directory the PNG is written into.
        near_plane: Lower distance bound passed to ``apply_depth_colormap`` (default 0.0).
        far_plane: Upper distance bound passed to ``apply_depth_colormap`` (default 2.0).
    """
    distances = model_output['expected_termination_dist']
    rendered_image = (
        colormaps.apply_depth_colormap(distances, near_plane=near_plane, far_plane=far_plane)
        .cpu()
        .detach()
        .numpy()
    )
    plt.imsave(f'{path}/frame_{str(frame_num).zfill(3)}_render.png', rendered_image)

# Render every frame of a saved camera path with the module-level `camera` and `model`.
def process_scene(camera_poses_path, output_folder=None):
    """Render each pose in a camera-path JSON and save one PNG per frame.

    Each pose's translation is rescaled onto the unit sphere before rendering.

    Args:
        camera_poses_path: Path to a camera-path JSON whose ``camera_path`` entries each
            hold a flattened 4x4 ``camera_to_world`` matrix.
        output_folder: Directory for the frames; defaults to the module-level ``folder_path``.

    Raises:
        ValueError: If a pose's camera sits exactly at the origin (it cannot be normalised).
    """
    if output_folder is None:
        output_folder = folder_path
    meta = load_from_json(Path(camera_poses_path))
    os.makedirs(output_folder, exist_ok=True)

    for frame_idx, frame in enumerate(meta['camera_path']):
        camera_to_world = torch.tensor(frame['camera_to_world'], dtype=torch.float32).reshape(4, 4)
        # ensure the position is normalised onto the unit sphere
        position = camera_to_world[:3, 3]
        distance = torch.norm(position)
        if distance == 0:
            raise ValueError(f'Frame {frame_idx}: camera at the origin cannot be normalised onto the unit sphere')
        camera_to_world[:3, 3] = position / distance
        camera.camera_to_worlds = camera_to_world[:3, :4]
        ray_bundle = camera.generate_rays(0).to(device)
        # NOTE(review): every ray is tagged with camera index 1 — presumably to select a
        # specific per-camera embedding inside the model; confirm the intended index.
        ray_bundle.camera_indices = torch.ones_like(ray_bundle.camera_indices)
        print(f'Rendering frame_idx: {frame_idx}')
        model_output = model.get_outputs_for_camera_ray_bundle(camera_ray_bundle=ray_bundle)
        save_model_output(model_output, frame_idx, output_folder)

# Render every camera-path frame to PNGs inside `folder_path`.
process_scene(camera_poses_path)

def create_animation(folder, image_type, fps, format='gif'):
    """
    Creates an animation from images of a specific type in the given folder.

    Frames are the files ending in ``{image_type}.png``, ordered by the first integer in
    their name (so ``frame_1000`` follows ``frame_999`` regardless of zero-padding width).
    Unreadable files are skipped with a message. Nothing is written if the format is
    unsupported or no frame could be read.

    :param folder: Path to the folder containing the images.
    :param image_type: Type of the image (e.g., 'render', 'albedo', 'normal').
    :param fps: Frames per second for the output animation.
    :param format: Output format of the animation ('gif' or 'mp4').
    """
    import re

    # Validate up front so we don't decode every frame only to reject the format.
    if format not in ('gif', 'mp4'):
        print("Unsupported format. Please choose 'gif' or 'mp4'.")
        return

    def _frame_key(name):
        # Numeric frame order; names without digits go last, ties broken by name.
        match = re.search(r"\d+", name)
        return (int(match.group()) if match else float("inf"), name)

    files = sorted(
        (f for f in os.listdir(folder) if f.endswith(f"{image_type}.png")),
        key=_frame_key,
    )

    if not files:
        print("No images found for the specified type.")
        return

    images = []
    for file in files:
        img_path = os.path.join(folder, file)
        img = cv2.imread(img_path)
        if img is None:
            print(f"Failed to read image: {img_path}")
            continue
        # OpenCV decodes as BGR; the video/GIF writers expect RGB.
        images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    if not images:
        print("No images could be read; animation not created.")
        return

    output_path = os.path.join(folder, f"animation_{image_type}.{format}")
    if format == 'gif':
        # NOTE(review): some newer imageio/pillow versions reject `fps` for GIFs in favour of
        # `duration`; confirm against the installed imageio version.
        imageio.mimsave(output_path, images, fps=fps)
    else:
        # Context manager guarantees the encoder is finalised even if a frame write fails.
        with imageio.get_writer(output_path, fps=fps, codec='libx264') as writer:
            for img in images:
                writer.append_data(img)

    print(f"Animation saved at {output_path}")

# Use the frame rate stored in the camera-path file so playback speed matches the exported path.
create_animation(folder_path, 'render', fps, 'mp4')

Rendering frame_idx: 0
Rendering frame_idx: 1
Rendering frame_idx: 2
Rendering frame_idx: 3
Rendering frame_idx: 4
Rendering frame_idx: 5
Rendering frame_idx: 6
Rendering frame_idx: 7
Rendering frame_idx: 8
Rendering frame_idx: 9
Rendering frame_idx: 10
Rendering frame_idx: 11
Rendering frame_idx: 12
Rendering frame_idx: 13
Rendering frame_idx: 14
Rendering frame_idx: 15
Rendering frame_idx: 16
Rendering frame_idx: 17
Rendering frame_idx: 18
Rendering frame_idx: 19
Rendering frame_idx: 20
Rendering frame_idx: 21
Rendering frame_idx: 22
Rendering frame_idx: 23
Rendering frame_idx: 24
Rendering frame_idx: 25
Rendering frame_idx: 26
Rendering frame_idx: 27
Rendering frame_idx: 28
Rendering frame_idx: 29
Rendering frame_idx: 30
Rendering frame_idx: 31
Rendering frame_idx: 32
Rendering frame_idx: 33
Rendering frame_idx: 34
Rendering frame_idx: 35
Rendering frame_idx: 36
Rendering frame_idx: 37
Rendering frame_idx: 38
Rendering frame_idx: 39
Rendering frame_idx: 40
Rendering frame_idx: 41
Re



Animation saved at /workspace/reni_neus/publication/animations/ddf_20231123_144608/animation_render.mp4
