In [1]:
import os
import torch
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import yaml
import re
from typing import Optional
from tqdm import tqdm
from nerfstudio.cameras.rays import RaySamples, Frustums
from nerfstudio.cameras.cameras import Cameras, CameraType
from nerfstudio.utils import colormaps, misc

from reni.configs.reni_config import RENIField
from reni.configs.sh_sg_envmap_configs import SHField, SGField
from reni.pipelines.reni_pipeline import RENIPipeline
from reni.field_components.field_heads import RENIFieldHeadNames
from reni.data.datamanagers.reni_datamanager import RENIDataManager
from reni.utils.utils import find_nerfstudio_project_root, rot_z, rot_y
from reni.utils.colourspace import linear_to_sRGB

# Single-process run on the first GPU.
world_size, local_rank = 1, 0
device = 'cuda:0'

# Locate the nerfstudio project root starting from the current working directory and
# switch into it, so relative paths used later in this notebook resolve against it.
project_root = find_nerfstudio_project_root(Path.cwd())
os.chdir(project_root)

def load_model(load_dir: Path, load_step: Optional[int] = None):
    """
    Rebuild a RENI / SH / SG pipeline from a nerfstudio run directory and load its model weights.

    The field type is chosen from which key appears under ``pipeline.model.field`` in
    ``config.yml``: ``latent_dim`` -> RENIField, ``spherical_harmonic_order`` -> SHField,
    ``row_col_gaussian_dims`` -> SGField. The matching default config is deep-copied and
    overwritten with the saved values, so repeated calls do not leak settings into each other.

    :param load_dir: run directory containing ``config.yml`` and ``nerfstudio_models/step-<N>.ckpt``
    :param load_step: checkpoint step to load; when None the highest step found is used
    :return: (pipeline, datamanager, model) with the model moved to ``device`` and set to eval mode
    :raises FileNotFoundError: if load_step is None and no ``step-<N>.ckpt`` file exists
    :raises ValueError: if the field type cannot be determined from ``config.yml``
    """
    import copy  # function-scope import keeps this notebook cell self-contained

    ckpt_dir = load_dir / 'nerfstudio_models'

    def clean_and_load_yaml(yaml_content):
        # Strip !!python/... tags so yaml.safe_load can parse the nerfstudio config without
        # constructing Python objects; tagged values then load as plain mappings/lists.
        cleaned_content = re.sub(r'!!python[^\s]*', '', yaml_content)
        return yaml.safe_load(cleaned_content)

    if load_step is None:
        # Only files named step-<N>.ckpt count; stray files (e.g. .DS_Store) are ignored.
        matches = (re.fullmatch(r'step-(\d+)\.ckpt', name) for name in os.listdir(ckpt_dir))
        steps = [int(m.group(1)) for m in matches if m]
        if not steps:
            raise FileNotFoundError(f'No step-<N>.ckpt checkpoints found in {ckpt_dir}')
        load_step = max(steps)

    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    ckpt = torch.load(ckpt_dir / f'step-{load_step:09d}.ckpt', map_location=device)
    # Model weights are stored in the pipeline state dict under a '_model.' prefix;
    # strip it so the keys match model.load_state_dict.
    prefix = '_model.'
    reni_model_dict = {
        key[len(prefix):]: value
        for key, value in ckpt['pipeline'].items()
        if key.startswith(prefix)
    }

    with open(load_dir / 'config.yml', 'r', encoding='utf-8') as f:
        config = clean_and_load_yaml(f.read())

    field_cfg = config['pipeline']['model']['field']

    if 'latent_dim' in field_cfg:
        model_config = copy.deepcopy(RENIField.config)

        dataparser_cfg = config['pipeline']['datamanager']['dataparser']
        dataparser = model_config.pipeline.datamanager.dataparser
        dataparser.convert_to_ldr = dataparser_cfg['convert_to_ldr']
        dataparser.convert_to_log_domain = dataparser_cfg['convert_to_log_domain']
        # With tags stripped, the saved path loads as a list of path components.
        if dataparser_cfg['eval_mask_path'] is not None:
            dataparser.eval_mask_path = Path(os.path.join(*dataparser_cfg['eval_mask_path']))
        else:
            dataparser.eval_mask_path = None
        # YAML lists come back for saved tuples; convert back (presumably the config expects a tuple).
        min_max = dataparser_cfg['min_max_normalize']
        dataparser.min_max_normalize = tuple(min_max) if isinstance(min_max, list) else min_max
        dataparser.augment_with_mirror = dataparser_cfg['augment_with_mirror']

        model_config.pipeline.model.loss_inclusions = config['pipeline']['model']['loss_inclusions']

        # Field attributes whose names match the saved config keys one-to-one.
        reni_field_keys = (
            'conditioning',
            'invariant_function',
            'equivariance',
            'axis_of_invariance',
            'positional_encoding',
            'encoded_input',
            'latent_dim',
            'hidden_features',
            'hidden_layers',
            'mapping_layers',
            'mapping_features',
            'num_attention_heads',
            'num_attention_layers',
            'output_activation',
            'last_layer_linear',
            'trainable_scale',
            'old_implementation',
        )
        for name in reni_field_keys:
            setattr(model_config.pipeline.model.field, name, field_cfg[name])
    elif 'spherical_harmonic_order' in field_cfg:
        model_config = copy.deepcopy(SHField.config)
        model_config.pipeline.model.field.spherical_harmonic_order = field_cfg['spherical_harmonic_order']
    elif 'row_col_gaussian_dims' in field_cfg:
        model_config = copy.deepcopy(SGField.config)
        model_config.pipeline.model.field.row_col_gaussian_dims = field_cfg['row_col_gaussian_dims']
    else:
        raise ValueError(
            f'Cannot determine field type from {load_dir / "config.yml"}: expected one of '
            "'latent_dim', 'spherical_harmonic_order' or 'row_col_gaussian_dims' under pipeline.model.field"
        )

    model_config.pipeline.test_mode = config['pipeline']['test_mode']
    test_mode = model_config.pipeline.test_mode

    pipeline: RENIPipeline = model_config.pipeline.setup(
        device=device,
        test_mode=test_mode,
        world_size=world_size,
        local_rank=local_rank,
        grad_scaler=None,
    )

    datamanager = pipeline.datamanager
    model = pipeline.model

    model.to(device)
    model.load_state_dict(reni_model_dict)
    model.eval()

    return pipeline, datamanager, model

# Pretrained RENI++ run directory; load_model reads config.yml and nerfstudio_models/ from it.
model_path = Path('/workspace/neusky/ns_reni/models') / 'reni_plus_plus_models' / 'latent_dim_100'
pipeline, datamanager, model = load_model(model_path, load_step=None)  # None -> highest checkpoint step

Output()

Output()

In [2]:
import random
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

import random
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from PIL import Image
import io
import itertools

def get_3d_vector_plot(latent_code, vectors_to_show, seed=42):
    """
    Plot a random subset of 3D vectors from the latent code and return the rendered image.

    Each selected vector is drawn as a quiver from the origin. Vectors longer than 2 are
    rescaled to length 2 so they stay inside the [-2, 2] ticks. Zero-length vectors are
    skipped because they have no direction. Colours cycle through a fixed list, so a given
    seed always produces the same subset with the same colours.

    :param latent_code: torch.Tensor of shape [N, 3]
    :param vectors_to_show: int, number of vectors to draw (clamped to N)
    :param seed: seed for choosing the subset. A private RNG is used, so the global
        ``random`` module state is left untouched.
    :return: np.ndarray [H, W, 3] RGB image (the PNG render with alpha dropped and a fixed crop)
    :raises ValueError: if latent_code is not a 2D tensor with shape [N, 3]
    """
    # Validate before doing any work.
    if len(latent_code.shape) != 2 or latent_code.shape[1] != 3:
        raise ValueError("latent_code must be a 2D tensor with shape [N, 3]")

    # A private RNG gives a repeatable subset without reseeding the process-wide `random` state.
    rng = random.Random(seed)

    colors = ['red', 'green', 'blue', 'purple', 'orange', 'cyan']
    color_cycle = itertools.cycle(colors)

    max_length = 2  # vectors are clamped to this length; also the axis tick extent
    fixed_arrowhead_size = 0.2  # arrowhead size in data units, the same for every vector

    vectors_to_show = min(vectors_to_show, latent_code.shape[0])
    indices = rng.sample(range(latent_code.shape[0]), vectors_to_show)
    vectors = latent_code[indices]

    fig = plt.figure()
    try:
        ax = fig.add_subplot(111, projection='3d')

        for vector in vectors:
            vector = vector.cpu().numpy()
            length = np.linalg.norm(vector)
            # Take the colour first so skipping a vector does not shift the others' colours.
            color = next(color_cycle)

            if length == 0:
                # A zero vector has no direction; normalising it would produce NaNs.
                continue

            if length > max_length:
                vector = (vector / length) * max_length
                length = max_length

            # arrow_length_ratio is relative to the arrow length; dividing by the length
            # keeps the absolute arrowhead size constant.
            arrow_length_ratio = fixed_arrowhead_size / length

            ax.quiver(0, 0, 0, vector[0], vector[1], vector[2], color=color, length=length,
                      normalize=True, arrow_length_ratio=arrow_length_ratio)

        ticks = [-max_length, 0, max_length]
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_zticks(ticks)
        ax.grid(False)

        # Render to an in-memory PNG and read it back as an array.
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png', bbox_inches='tight')
            buf.seek(0)
            with Image.open(buf) as img:
                img_array = np.array(img)
    finally:
        plt.close(fig)  # free the figure even if drawing/saving fails

    # Fixed pixel crop to trim surrounding whitespace and drop alpha. The bounds look
    # hand-tuned for matplotlib's default figure size/DPI — TODO confirm if those change.
    img_array = img_array[30:395, 20:395, :3]

    return img_array

def _resize_to_height(img, height):
    """Return PIL image `img` resized (bicubic) to `height` rows, keeping its aspect ratio."""
    original_width, original_height = img.size
    new_width = int(height * (original_width / original_height))
    return img.resize((new_width, height), Image.BICUBIC)


def generate_rotation_animation(image_idx, frames, model, datamanager, filename='animation.m4v', fps=24, include_field_diagram=False):
    """
    Render `image_idx` rotating through 360 degrees and write the frames to a video file.

    Each frame is laid out left to right as:
    latent-vector plot | field diagram (only if include_field_diagram) | 256x512 equirectangular render.
    Frames are written to `filename` with the mp4v codec as they are produced. If `frames` is 0,
    no file is written.

    :param image_idx: index used for the datamanager eval image, the ray camera indices and sample_latent
    :param frames: number of frames; frame i uses rotation angle i * 360 / frames degrees
    :param model: RENI model (set to eval mode here)
    :param datamanager: datamanager providing next_eval_image
    :param filename: output video path
    :param fps: output frame rate
    :param include_field_diagram: insert the neural-field diagram panel between plot and render
    :raises OSError: if OpenCV cannot open `filename` for writing
    """
    # Imported up front so a missing OpenCV fails before the slow render loop.
    import cv2

    H = 256
    W = H * 2

    # The equirectangular HxW camera is identical for every frame, so it is built once.
    cameras = Cameras(
        fx=torch.tensor(H, dtype=torch.float32).repeat(1),
        fy=torch.tensor(H, dtype=torch.float32).repeat(1),
        cx=torch.tensor(W // 2, dtype=torch.float32).repeat(1),
        cy=torch.tensor(H // 2, dtype=torch.float32).repeat(1),
        camera_to_worlds=torch.tensor([[[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0]]], dtype=torch.float32),
        camera_type=CameraType.EQUIRECTANGULAR,
    )

    # Fields flagged old_implementation rotate about y; all others (or fields without the flag) about z.
    get_rotation = rot_y if getattr(model.field, 'old_implementation', False) else rot_z

    diagram = None
    if include_field_diagram:
        diagram_path = '/workspace/neusky/ns_reni/publication/figures/neural_field.png'
        with Image.open(diagram_path) as diagram_img:
            # convert('RGB') guarantees 3 channels so the panel concatenates with the others.
            resized_diagram = _resize_to_height(diagram_img.convert('RGB'), H)
        diagram = np.array(resized_diagram, dtype=np.float32) / 255.0

    model.eval()
    # The ray bundle and batch returned here are not used for rendering. The call is kept in
    # case next_eval_image has side effects — NOTE(review): confirm whether it can be dropped.
    datamanager.next_eval_image(image_idx)

    writer = None
    try:
        for i in tqdm(range(frames)):
            rotation_angle = i * 360 / frames

            with torch.no_grad():
                # Rays are regenerated each frame so any in-place change the model might make
                # to the bundle cannot carry over into the next frame.
                ray_bundle = cameras.generate_rays(0).flatten().to(device)
                ray_bundle.camera_indices = torch.ones_like(ray_bundle.camera_indices) * image_idx

                rotation = get_rotation(torch.tensor(np.deg2rad(rotation_angle)).float()).to(device)
                outputs = model.get_outputs_for_camera_ray_bundle(ray_bundle, rotation)

                pred_img = model.field.unnormalise(outputs['rgb']).view(H, W, 3)
                pred_img = linear_to_sRGB(pred_img, use_quantile=True)  # [H, W, 3]

                # The plotted latent vectors are multiplied by the rotation for -rotation_angle.
                latent_code, _, _ = model.field.sample_latent(image_idx)
                latent_rotation = get_rotation(torch.tensor(np.deg2rad(-rotation_angle)).float()).to(device)
                latent_code = torch.matmul(latent_code, latent_rotation)

            plot = get_3d_vector_plot(latent_code.cpu().detach(), 50)  # [h, w, 3]
            plot = np.array(_resize_to_height(Image.fromarray(plot), H), dtype=np.float32) / 255.0

            panels = [plot] if diagram is None else [plot, diagram]
            panels.append(pred_img.cpu().detach().numpy())
            frame = np.concatenate(panels, axis=1)

            # VideoWriter expects uint8 BGR. Clip first so out-of-range values saturate
            # instead of overflowing the uint8 cast.
            frame_uint8 = (np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8)
            frame_bgr = cv2.cvtColor(frame_uint8, cv2.COLOR_RGB2BGR)

            if writer is None:
                # The writer is sized from the first frame; every frame has the same layout.
                height, width = frame_bgr.shape[:2]
                writer = cv2.VideoWriter(filename, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
                if not writer.isOpened():
                    raise OSError(f'Could not open video writer for {filename}')
            writer.write(frame_bgr)
    finally:
        if writer is not None:
            writer.release()

In [68]:
# Enable view_train_latents only for this render. It is reset to False in `finally`, so an
# exception during rendering cannot leave the flag switched on.
model.field.config.view_train_latents = True
try:
    path = '/workspace/neusky/ns_reni/publication/figures/reni_plus_plus_teaser.mp4'
    # Other image indices noted for this figure: 136, 219, 252.
    generate_rotation_animation(96, 96, model, datamanager, filename=path, fps=24, include_field_diagram=True)
finally:
    model.field.config.view_train_latents = False

100%|██████████| 96/96 [00:29<00:00,  3.23it/s]
