## Directional Distance Function Experiments
This notebook contains experiments with using a directional distance field for visibility in RENI-NeuS.

In [None]:
# Restrict this process to a single GPU. CUDA only honours this variable if it
# is set before CUDA is initialised, so keep it ahead of any torch import.
import os

os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})

In [None]:
import os
os.chdir("/workspace/")
import sys
sys.path.append("/workspace/reni_neus")


import torch
import yaml
from pathlib import Path
import random
import matplotlib.pyplot as plt
import numpy as np
import tqdm
import plotly.graph_objects as go
from torch.utils.data import Dataset

from nerfstudio.configs import base_config as cfg
from nerfstudio.configs.method_configs import method_configs
from nerfstudio.data.dataparsers.nerfosr_dataparser import NeRFOSR, NeRFOSRDataParserConfig
from nerfstudio.pipelines.base_pipeline import VanillaDataManager
from nerfstudio.field_components.field_heads import FieldHeadNames
from nerfstudio.cameras.rays import RayBundle
from nerfstudio.utils.colormaps import apply_depth_colormap
from nerfstudio.field_components.encodings import SHEncoding, NeRFEncoding
import tinycudann as tcnn

from reni_neus.reni_neus_model import RENINeuSFactoModelConfig, RENINeuSFactoModel
from reni_neus.utils.utils import get_directions, get_sineweight
from reni_neus.illumination_fields.reni_field import RENIField
from reni_neus.data.reni_neus_datamanager import RENINeuSDataManagerConfig, RENINeuSDataManager
from reni_neus.reni_neus_config import RENINeuS as RENINeuSMethodSpecification

def _detached_clone(value):
    """Return a detached copy of ``value`` if it is a tensor, otherwise ``value`` itself.

    ``None`` and other non-tensor values pass through unchanged, so optional
    fields of the bundle survive the copy instead of raising AttributeError.
    """
    if isinstance(value, torch.Tensor):
        return value.detach().clone()
    return value


def make_ray_bundle_clone(ray_bundle):
    """Return an independent, gradient-free copy of a ``RayBundle``.

    Every tensor field (origins, directions, pixel_area, camera_indices, nears,
    fars, and each metadata entry) is detached from the autograd graph and
    cloned, so the copy can be modified freely without touching the original.
    Optional fields that are ``None`` stay ``None``.

    Args:
        ray_bundle: The bundle to copy.

    Returns:
        A new ``RayBundle`` built from the copied fields.
    """
    metadata_copy = {key: _detached_clone(value) for key, value in ray_bundle.metadata.items()}

    # Forward ``times`` only when the bundle actually carries it: the original
    # clone silently dropped it, and some RayBundle versions may not define the
    # field at all, so it cannot be passed unconditionally.
    extra_fields = {}
    times = getattr(ray_bundle, "times", None)
    if times is not None:
        extra_fields["times"] = _detached_clone(times)

    return RayBundle(
        origins=_detached_clone(ray_bundle.origins),
        directions=_detached_clone(ray_bundle.directions),
        pixel_area=_detached_clone(ray_bundle.pixel_area),
        metadata=metadata_copy,
        camera_indices=_detached_clone(ray_bundle.camera_indices),
        nears=_detached_clone(ray_bundle.nears),
        fars=_detached_clone(ray_bundle.fars),
        **extra_fields,
    )

def make_batch_clone(batch):
    """Copy a batch dict, detaching and cloning every tensor value.

    Keys keep their original order. Non-tensor values (ints, strings, lists,
    ...) are carried over by reference rather than copied.
    """
    return {
        key: (value.detach().clone() if isinstance(value, torch.Tensor) else value)
        for key, value in batch.items()
    }

def sRGB(imgs):
    """Tone-map linear images to display-ready sRGB values in ``[0, 1]``.

    Each image is divided by its own 98th-percentile value, clamped to
    ``[0, 1]`` and then encoded with the standard sRGB transfer function.

    Args:
        imgs: Floating-point tensor holding a single image (3-D) or a batch of
            images (4-D, first dimension indexes images). All non-batch
            elements of an image are pooled when computing its percentile;
            the channel layout is not inspected.

    Returns:
        A 4-D tensor of sRGB values in ``[0, 1]``. A 3-D input gains a leading
        batch dimension of size 1, which is not squeezed back out.
    """
    # Add batch dimension if necessary
    if imgs.ndim == 3:
        imgs = imgs.unsqueeze(0)

    # Per-image 98th percentile. ``reshape`` (unlike ``view``) also accepts
    # non-contiguous inputs such as a permuted HWC -> CHW tensor.
    q = torch.quantile(imgs.reshape(imgs.size(0), -1), 0.98, dim=1)

    # A non-positive percentile (e.g. an all-black image) would make the
    # division 0/0 = NaN or flip signs; leave such images unscaled instead.
    q = torch.where(q > 0, q, torch.ones_like(q))

    # Normalise each image by its percentile and clamp to [0, 1].
    imgs = torch.clamp(imgs / q.view(-1, 1, 1, 1), 0.0, 1.0)

    # Linear -> sRGB: linear segment near black, 1/2.4 power curve above it.
    return torch.where(
        imgs <= 0.0031308,
        12.92 * imgs,
        1.055 * torch.pow(torch.abs(imgs), 1 / 2.4) - 0.055,
    )

def rotation_matrix(axis, angle):
    """Return the 3x3 matrix that rotates by ``angle`` radians about ``axis``.

    Built with the Euler-Rodrigues formula. The rotation is right-handed
    (counter-clockwise when viewed from the tip of ``axis`` towards the
    origin), e.g. ``rotation_matrix([0, 0, 1], np.pi / 2) @ [1, 0, 0]`` is
    ``[0, 1, 0]``.

    Args:
        axis: Length-3 rotation axis; it is normalised internally, so it need
            not be unit length.
        angle: Rotation angle in radians.

    Returns:
        A ``(3, 3)`` NumPy array.

    Raises:
        ValueError: If ``axis`` is not a length-3 vector or has zero length
            (the rotation direction would be undefined and every entry NaN).
    """
    axis = np.asarray(axis, dtype=float)
    if axis.shape != (3,):
        raise ValueError(f"rotation axis must have shape (3,), got {axis.shape}")
    norm = np.linalg.norm(axis)
    if norm == 0:
        raise ValueError("rotation axis must be non-zero")
    axis = axis / norm

    # Euler-Rodrigues parameters (a, b, c, d) form the unit quaternion of the rotation.
    a = np.cos(angle / 2.0)
    b, c, d = -axis * np.sin(angle / 2.0)
    aa, bb, cc, dd = a * a, b * b, c * c, d * d
    bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
    return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
                     [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
                     [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]])

# setup config
# Build the RENI-NeuS datamanager and model from the method config, then fetch
# the rays/data for one eval image to experiment with.
test_mode = 'val'  # passed to the datamanager; exact meaning is defined by RENINeuSDataManager
world_size = 1  # single-process run
local_rank = 0
# NOTE(review): with CUDA_VISIBLE_DEVICES="1" set in the first cell (before
# CUDA is initialised), 'cuda:0' refers to physical GPU 1.
device = 'cuda:0'

datamanager: RENINeuSDataManager = RENINeuSMethodSpecification.config.pipeline.datamanager.setup(
    device=device, test_mode=test_mode, world_size=world_size, local_rank=local_rank, 
)
datamanager.to(device)

# includes num_eval_data as needed for reni latent code fitting.
# The model is sized from the datamanager (train/eval image counts, scene box,
# train metadata), so the datamanager must be set up first.
model = RENINeuSMethodSpecification.config.pipeline.model.setup(
    scene_box=datamanager.train_dataset.scene_box,
    num_train_data=len(datamanager.train_dataset),
    num_eval_data=len(datamanager.eval_dataset),
    metadata=datamanager.train_dataset.metadata,
    world_size=world_size,
    local_rank=local_rank,
    eval_latent_optimisation_source=RENINeuSMethodSpecification.config.pipeline.eval_latent_optimisation_source,
)
model.to(device)

# Rays and data batch for a single eval image. These "_original" objects are
# left untouched; later cells clone them before experimenting.
image_idx_original = 3
camera_ray_bundle_original, batch_original = datamanager.eval_dataloader.get_data_from_image_idx(image_idx_original)

# Bare final expression: Jupyter displays `True` as this cell's result.
True # printing to hide long cell output

In [2]:
# Work on detached copies so the experiments below never mutate the rays and
# batch fetched from the datamanager.
camera_ray_bundle, batch = (
    make_ray_bundle_clone(camera_ray_bundle_original),
    make_batch_clone(batch_original),
)