In [1]:
import os
import torch
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import yaml
import re
from typing import Optional
from nerfstudio.cameras.rays import RaySamples, Frustums
from nerfstudio.cameras.cameras import Cameras, CameraType

from reni.configs.reni_config import RENIField
from reni.pipelines.reni_pipeline import RENIPipeline
from reni.field_components.field_heads import RENIFieldHeadNames
from reni.data.datamanagers.reni_datamanager import RENIDataManager
from reni.utils.utils import find_nerfstudio_project_root, rot_z, rot_y
from reni.utils.colourspace import linear_to_sRGB

In [2]:
# setup config
test_mode = 'val'
world_size = 1
local_rank = 0
device = 'cuda:0'

project_root = find_nerfstudio_project_root(Path(os.getcwd()))
# set current working directory to nerfstudio project root
os.chdir(project_root)

def load_model(load_dir: Path, load_step: Optional[int] = None):
    ckpt_dir = load_dir / 'nerfstudio_models'
    def clean_and_load_yaml(yaml_content):
        # Remove !!python related tags
        cleaned_content = re.sub(r'!!python[^\s]*', '', yaml_content)
        
        # Load the cleaned content
        return yaml.safe_load(cleaned_content)

    if load_step is None:
        load_step = sorted(int(x[x.find("-") + 1 : x.find(".")]) for x in os.listdir(ckpt_dir))[-1]
    
    ckpt = torch.load(ckpt_dir / f'step-{load_step:09d}.ckpt', map_location=device)
    reni_model_dict = {}
    for key in ckpt['pipeline'].keys():
        if key.startswith('_model.'):
            reni_model_dict[key[7:]] = ckpt['pipeline'][key]
    
    config_path = load_dir / 'config.yml'
    with open(config_path, 'r') as f:
        content = f.read()
        config = clean_and_load_yaml(content)
    
    reni_field_config = RENIField.config

    reni_field_config.pipeline.datamanager.dataparser.convert_to_ldr = config['pipeline']['datamanager']['dataparser']['convert_to_ldr']
    reni_field_config.pipeline.datamanager.dataparser.convert_to_log_domain = config['pipeline']['datamanager']['dataparser']['convert_to_log_domain']
    if config['pipeline']['datamanager']['dataparser']['min_max_normalize'].__class__ == list:
        reni_field_config.pipeline.datamanager.dataparser.min_max_normalize = tuple(config['pipeline']['datamanager']['dataparser']['min_max_normalize'])
    else:
        reni_field_config.pipeline.datamanager.dataparser.min_max_normalize = config['pipeline']['datamanager']['dataparser']['min_max_normalize']
    reni_field_config.pipeline.datamanager.dataparser.augment_with_mirror = config['pipeline']['datamanager']['dataparser']['augment_with_mirror']
    reni_field_config.pipeline.model.loss_inclusions = config['pipeline']['model']['loss_inclusions']
    reni_field_config.pipeline.model.field.conditioning = config['pipeline']['model']['field']['conditioning']
    reni_field_config.pipeline.model.field.invariant_function = config['pipeline']['model']['field']['invariant_function']
    reni_field_config.pipeline.model.field.equivariance = config['pipeline']['model']['field']['equivariance']
    reni_field_config.pipeline.model.field.axis_of_invariance = config['pipeline']['model']['field']['axis_of_invariance']
    reni_field_config.pipeline.model.field.positional_encoding = config['pipeline']['model']['field']['positional_encoding']
    reni_field_config.pipeline.model.field.encoded_input = config['pipeline']['model']['field']['encoded_input']
    reni_field_config.pipeline.model.field.latent_dim = config['pipeline']['model']['field']['latent_dim']
    reni_field_config.pipeline.model.field.hidden_features = config['pipeline']['model']['field']['hidden_features']
    reni_field_config.pipeline.model.field.hidden_layers = config['pipeline']['model']['field']['hidden_layers']
    reni_field_config.pipeline.model.field.mapping_layers = config['pipeline']['model']['field']['mapping_layers']
    reni_field_config.pipeline.model.field.mapping_features = config['pipeline']['model']['field']['mapping_features']
    reni_field_config.pipeline.model.field.num_attention_heads = config['pipeline']['model']['field']['num_attention_heads']
    reni_field_config.pipeline.model.field.num_attention_layers = config['pipeline']['model']['field']['num_attention_layers']
    reni_field_config.pipeline.model.field.output_activation = config['pipeline']['model']['field']['output_activation']
    reni_field_config.pipeline.model.field.last_layer_linear = config['pipeline']['model']['field']['last_layer_linear']
    reni_field_config.pipeline.model.field.trainable_scale = config['pipeline']['model']['field']['trainable_scale']
    reni_field_config.pipeline.model.field.old_implementation = config['pipeline']['model']['field']['old_implementation']
    reni_field_config.pipeline.model.loss_inclusions = config['pipeline']['model']['loss_inclusions']

    pipeline: RENIPipeline = reni_field_config.pipeline.setup(
      device=device,
      test_mode=test_mode,
      world_size=world_size,
      local_rank=local_rank,
      grad_scaler=None,
    )

    datamanager = pipeline.datamanager

    model = pipeline.model

    model.to(device)
    print(model.field.train_mu.shape)
    model.load_state_dict(reni_model_dict)
    model.eval()

    return pipeline, datamanager, model

def generate_images_from_models(image_indices, model_paths):
    all_model_outputs = {}
    
    for model_path in model_paths:
        model_name = model_path.split("/")[-1]
        pipeline, datamanager, model = load_model(Path(model_path))
        
        model_outputs = {}
        
        for idx in image_indices:
            # Your code to produce an image would go here.
            model.eval()
            _, ray_bundle, batch = datamanager.next_eval_image(idx)
            H, W = model.metadata["image_height"], model.metadata["image_width"]

            # High res image:
            H = 128
            W = H * 2
            cx = torch.tensor(W // 2, dtype=torch.float32).repeat(1)
            cy = torch.tensor(H // 2, dtype=torch.float32).repeat(1)
            fx = torch.tensor(H, dtype=torch.float32).repeat(1)
            fy = torch.tensor(H, dtype=torch.float32).repeat(1)

            c2w = torch.tensor([[[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0]]], dtype=torch.float32).repeat(1, 1, 1)

            cameras = Cameras(fx=fx, fy=fy, cx=cx, cy=cy, camera_to_worlds=c2w, camera_type=CameraType.EQUIRECTANGULAR)

            ray_bundle = cameras.generate_rays(0).flatten().to(device)
            ray_bundle.camera_indices = torch.ones_like(ray_bundle.camera_indices) * idx

            batch['image'] = batch['image'].to(device)

            if model.field.old_implementation:
                get_rotation = rot_y
            else:
                get_rotation = rot_z

            rotation = get_rotation(torch.tensor(np.deg2rad(0.0)).float())
            rotation = rotation.to(device)

            outputs = model.get_outputs_for_camera_ray_bundle(ray_bundle)
            outputs['rgb'] = outputs['rgb'].reshape(H, W, 3)
            pred_img = model.field.unnormalise(outputs['rgb'])
            pred_img = linear_to_sRGB(pred_img, use_quantile=True)
            model_outputs[idx] = pred_img
            
        all_model_outputs[model_name] = model_outputs
    
    return all_model_outputs

In [3]:
# Example usage
image_indices = [1, 2, 3]
model_paths = [
    '/workspace/outputs/reni/reni_plus_plus_models/latent_dim_9',
    '/workspace/outputs/reni/reni_plus_plus_models/latent_dim_100'
]

output_images = generate_images_from_models(image_indices, model_paths)


Output()

Output()

torch.Size([3346, 9, 3])


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Output()

Output()

torch.Size([3346, 100, 3])


In [5]:
output_images['latent_dim_100']

{1: tensor([[[0.2142, 0.3746, 0.5365],
          [0.2142, 0.3746, 0.5364],
          [0.2142, 0.3746, 0.5364],
          ...,
          [0.2143, 0.3748, 0.5366],
          [0.2143, 0.3747, 0.5365],
          [0.2143, 0.3747, 0.5365]],
 
         [[0.2207, 0.3818, 0.5429],
          [0.2206, 0.3817, 0.5427],
          [0.2205, 0.3815, 0.5424],
          ...,
          [0.2210, 0.3823, 0.5434],
          [0.2209, 0.3822, 0.5432],
          [0.2208, 0.3820, 0.5430]],
 
         [[0.2207, 0.3826, 0.5432],
          [0.2207, 0.3825, 0.5432],
          [0.2207, 0.3824, 0.5431],
          ...,
          [0.2206, 0.3828, 0.5433],
          [0.2206, 0.3827, 0.5433],
          [0.2207, 0.3826, 0.5432]],
 
         ...,
 
         [[0.6215, 0.6048, 0.5447],
          [0.6212, 0.6045, 0.5444],
          [0.6209, 0.6041, 0.5441],
          ...,
          [0.6224, 0.6057, 0.5454],
          [0.6221, 0.6054, 0.5452],
          [0.6218, 0.6051, 0.5449]],
 
         [[0.6154, 0.5982, 0.5379],
         