In [1]:
import os
import torch
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import yaml
import re
from typing import Optional
from nerfstudio.cameras.rays import RaySamples, Frustums
from nerfstudio.cameras.cameras import Cameras, CameraType

from reni.configs.reni_config import RENIField
from reni.configs.sh_sg_envmap_configs import SHField, SGField
from reni.pipelines.reni_pipeline import RENIPipeline
from reni.field_components.field_heads import RENIFieldHeadNames
from reni.data.datamanagers.reni_datamanager import RENIDataManager
from reni.utils.utils import find_nerfstudio_project_root, rot_z, rot_y
from reni.utils.colourspace import linear_to_sRGB

# setup config
test_mode = 'val'
world_size = 1
local_rank = 0
device = 'cuda:0'

project_root = find_nerfstudio_project_root(Path(os.getcwd()))
# set current working directory to nerfstudio project root
os.chdir(project_root)

def load_model(load_dir: Path, load_step: Optional[int] = None, eval_rotation: Optional[int] = None):
    ckpt_dir = load_dir / 'nerfstudio_models'
    def clean_and_load_yaml(yaml_content):
        # Remove !!python related tags
        cleaned_content = re.sub(r'!!python[^\s]*', '', yaml_content)
        
        # Load the cleaned content
        return yaml.safe_load(cleaned_content)

    if load_step is None:
        load_step = sorted(int(x[x.find("-") + 1 : x.find(".")]) for x in os.listdir(ckpt_dir))[-1]
    
    ckpt = torch.load(ckpt_dir / f'step-{load_step:09d}.ckpt', map_location=device)
    reni_model_dict = {}
    for key in ckpt['pipeline'].keys():
        if key.startswith('_model.'):
            reni_model_dict[key[7:]] = ckpt['pipeline'][key]
    
    config_path = load_dir / 'config.yml'
    with open(config_path, 'r') as f:
        content = f.read()
        config = clean_and_load_yaml(content)
    
    if 'latent_dim' in config['pipeline']['model']['field'].keys():
        
        model_config = RENIField.config
        model_config.pipeline.datamanager.dataparser.convert_to_ldr = config['pipeline']['datamanager']['dataparser']['convert_to_ldr']
        model_config.pipeline.datamanager.dataparser.convert_to_log_domain = config['pipeline']['datamanager']['dataparser']['convert_to_log_domain']
        if config['pipeline']['datamanager']['dataparser']['eval_mask_path'] is not None:
            eval_mask_path = Path(os.path.join(*config['pipeline']['datamanager']['dataparser']['eval_mask_path']))
            model_config.pipeline.datamanager.dataparser.eval_mask_path = eval_mask_path
        else:
            model_config.pipeline.datamanager.dataparser.eval_mask_path = None
        if config['pipeline']['datamanager']['dataparser']['min_max_normalize'].__class__ == list:
            model_config.pipeline.datamanager.dataparser.min_max_normalize = tuple(config['pipeline']['datamanager']['dataparser']['min_max_normalize'])
        else:
            model_config.pipeline.datamanager.dataparser.min_max_normalize = config['pipeline']['datamanager']['dataparser']['min_max_normalize']
        model_config.pipeline.datamanager.dataparser.augment_with_mirror = config['pipeline']['datamanager']['dataparser']['augment_with_mirror']
        model_config.pipeline.model.loss_inclusions = config['pipeline']['model']['loss_inclusions']
        model_config.pipeline.model.field.conditioning = config['pipeline']['model']['field']['conditioning']
        model_config.pipeline.model.field.invariant_function = config['pipeline']['model']['field']['invariant_function']
        model_config.pipeline.model.field.equivariance = config['pipeline']['model']['field']['equivariance']
        model_config.pipeline.model.field.axis_of_invariance = config['pipeline']['model']['field']['axis_of_invariance']
        model_config.pipeline.model.field.positional_encoding = config['pipeline']['model']['field']['positional_encoding']
        model_config.pipeline.model.field.encoded_input = config['pipeline']['model']['field']['encoded_input']
        model_config.pipeline.model.field.latent_dim = config['pipeline']['model']['field']['latent_dim']
        model_config.pipeline.model.field.hidden_features = config['pipeline']['model']['field']['hidden_features']
        model_config.pipeline.model.field.hidden_layers = config['pipeline']['model']['field']['hidden_layers']
        model_config.pipeline.model.field.mapping_layers = config['pipeline']['model']['field']['mapping_layers']
        model_config.pipeline.model.field.mapping_features = config['pipeline']['model']['field']['mapping_features']
        model_config.pipeline.model.field.num_attention_heads = config['pipeline']['model']['field']['num_attention_heads']
        model_config.pipeline.model.field.num_attention_layers = config['pipeline']['model']['field']['num_attention_layers']
        model_config.pipeline.model.field.output_activation = config['pipeline']['model']['field']['output_activation']
        model_config.pipeline.model.field.last_layer_linear = config['pipeline']['model']['field']['last_layer_linear']
        model_config.pipeline.model.field.trainable_scale = config['pipeline']['model']['field']['trainable_scale']
        model_config.pipeline.model.field.old_implementation = config['pipeline']['model']['field']['old_implementation']
        model_config.pipeline.model.loss_inclusions = config['pipeline']['model']['loss_inclusions']
    elif 'spherical_harmonic_order' in config['pipeline']['model']['field'].keys():
        model_config = SHField.config
        model_config.pipeline.model.field.spherical_harmonic_order = config['pipeline']['model']['field']['spherical_harmonic_order']
    elif 'row_col_gaussian_dims' in config['pipeline']['model']['field'].keys():
        model_config = SGField.config
        model_config.pipeline.model.field.row_col_gaussian_dims = config['pipeline']['model']['field']['row_col_gaussian_dims']

    if eval_rotation is not None:
        model_config.pipeline.datamanager.dataparser.apply_eval_rotation = eval_rotation
    else:
        model_config.pipeline.datamanager.dataparser.apply_eval_rotation = False

    pipeline: RENIPipeline = model_config.pipeline.setup(
      device=device,
      test_mode=test_mode,
      world_size=world_size,
      local_rank=local_rank,
      grad_scaler=None,
    )

    datamanager = pipeline.datamanager

    model = pipeline.model

    model.to(device)
    model.load_state_dict(reni_model_dict)
    model.eval()

    return pipeline, datamanager, model


def get_latents(model_path, rotations, eval_latent_batch_type):
    model_latents = {}
    model_metrics = {}
    
    for rotation in rotations:
        pipeline, datamanager, model = load_model(Path(model_path), eval_rotation=rotation)
        model.config.eval_latent_batch_type = eval_latent_batch_type

        metrics_dict = pipeline.get_average_eval_image_metrics(optimise_latents=True)
        latents = model.field.eval_mu

        model_latents[rotation] = latents
        model_metrics[rotation] = metrics_dict
    
    return model_latents, model_metrics

In [2]:
model_path = '/workspace/outputs/reni/reni_plus_plus_models/latent_dim_100'
rotations = [0, 5, 20, 45, 90, 180, 270]

model_latents, model_metrics = get_latents(model_path, rotations, eval_latent_batch_type='full_image')

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

In [3]:
import roma

import torch
from torch import cos, sin

def compute_relative_error_per_latent(Z1, Z2, gt_rotation):
    Z1 = Z1.float()
    Z2 = Z2.float()

    M = torch.linalg.lstsq(Z1, Z2).solution
    R_est = roma.special_procrustes(M)
    E = (torch.norm(Z1 @ R_est - Z2) / torch.norm(Z2)).item()

    # Compute the ground truth rotation matrix
    R_gt = rot_z(gt_rotation).type_as(R_est)

    # Compute the Frobenius norm of the difference between R_gt and R_est
    R_diff_norm = torch.norm(R_gt - R_est, p='fro').item()

    return E, R_diff_norm

def compute_mean_error_for_rotations(reference_latents, latent_dict):
    errors = {}
    for rotation_degrees, latents in latent_dict.items():
        if rotation_degrees == 0:  # Skip the reference rotation
            continue
        # Convert degrees to radians for the rotation
        gamma = torch.deg2rad(torch.tensor([rotation_degrees], dtype=latents.dtype))
        errors_for_rotation = []
        rotation_diff_norms = []
        for i in range(latents.size(0)):
            error, rotation_diff_norm = compute_relative_error_per_latent(reference_latents[i], latents[i], gamma)
            errors_for_rotation.append(error)
            rotation_diff_norms.append(rotation_diff_norm)
        mean_error = torch.tensor(errors_for_rotation).mean().item()
        std_error = torch.tensor(errors_for_rotation).std().item()
        mean_rotation_diff_norm = torch.tensor(rotation_diff_norms).mean().item()
        std_rotation_diff_norm = torch.tensor(rotation_diff_norms).std().item()
        errors[rotation_degrees] = [mean_error, std_error, mean_rotation_diff_norm, std_rotation_diff_norm]
    return errors



mean_errors = compute_mean_error_for_rotations(model_latents[0], model_latents)
# Begin LaTeX table code
# Begin LaTeX table code
latex_output = "\\begin{table}[h]\n"
latex_output += "\\begin{center}\n"
latex_output += "\\caption{Mean relative error and rotation matrix discrepancy for optimised latent codes across various rotation angles of the test set.}\n"
latex_output += "\\label{tab:rotation_error_comparison}\n"
latex_output += "\\begin{tabular}{| c | c | c |}\n"
latex_output += "\\hline\n"
latex_output += "\\makecell{Rotation Angle \\\\ (Degrees)} & \\makecell{Relative Error} & \\makecell{Ground Truth Rotation \\\\ Matrix Error} \\\\\n"
latex_output += "\\hline\n"

# Here you would insert your data row

# Generate LaTeX table rows with math mode for the plus-minus symbol and the rotation matrix discrepancy
for rotation, errors in mean_errors.items():
    latex_output += f"{rotation} & ${errors[0]:.3f} \\pm {errors[1]:.3f}$ & ${errors[2]:.3f} \\pm {errors[3]:.3f}$ \\\\\n"
    latex_output += "\\hline\n"

# End LaTeX table code
latex_output += "\\end{tabular}\n"
latex_output += "\\end{center}\n"
latex_output += "\\end{table}\n"

# Print the LaTeX table code
print(latex_output)



\begin{table}[h]
\begin{center}
\caption{Mean relative error and rotation matrix discrepancy for optimised latent codes across various rotation angles of the test set.}
\label{tab:rotation_error_comparison}
\begin{tabular}{| c | c | c |}
\hline
\makecell{Rotation Angle \\ (Degrees)} & \makecell{Relative Error} & \makecell{Ground Truth Rotation \\ Matrix Error} \\
\hline
5 & $0.582 \pm 0.140$ & $0.150 \pm 0.108$ \\
\hline
20 & $0.762 \pm 0.140$ & $0.243 \pm 0.123$ \\
\hline
45 & $0.838 \pm 0.098$ & $0.416 \pm 0.291$ \\
\hline
90 & $0.049 \pm 0.049$ & $0.008 \pm 0.009$ \\
\hline
180 & $0.050 \pm 0.045$ & $0.007 \pm 0.005$ \\
\hline
270 & $0.044 \pm 0.048$ & $0.006 \pm 0.005$ \\
\hline
\end{tabular}
\end{center}
\end{table}



In [8]:
pipeline, datamanager, model = load_model(Path(model_path), eval_rotation=None)# Pseudo-Function Definitions
# Assuming you have all the necessary imports and model setup done

def get_model_output(latent):
    model.eval()
    idx = 0
    _, ray_bundle, batch = datamanager.next_eval_image(idx)
    H, W = model.metadata["image_height"], model.metadata["image_width"]

    # High res image:
    H = 64
    W = H * 2
    cx = torch.tensor(W // 2, dtype=torch.float32).repeat(1)
    cy = torch.tensor(H // 2, dtype=torch.float32).repeat(1)
    fx = torch.tensor(H, dtype=torch.float32).repeat(1)
    fy = torch.tensor(H, dtype=torch.float32).repeat(1)

    c2w = torch.tensor([[[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0]]], dtype=torch.float32).repeat(1, 1, 1)

    cameras = Cameras(fx=fx, fy=fy, cx=cx, cy=cy, camera_to_worlds=c2w, camera_type=CameraType.EQUIRECTANGULAR)

    ray_bundle = cameras.generate_rays(0).flatten().to(device)
    ray_bundle.camera_indices = torch.ones_like(ray_bundle.camera_indices) * idx

    batch['image'] = batch['image'].to(device)

    # check if the model has attribute old_implementation
    if hasattr(model.field, 'old_implementation'):
        if model.field.old_implementation:
            get_rotation = rot_y
        else:
            get_rotation = rot_z

    ray_samples = model.create_ray_samples(ray_bundle.origins, ray_bundle.directions, ray_bundle.camera_indices)

    latent = latent.repeat(ray_samples.shape[0], 1, 1)

    field_outputs = model.field.forward(ray_samples, latent_codes=latent)

    pred_img = model.field.unnormalise(field_outputs[RENIFieldHeadNames.RGB])
    pred_img = pred_img.reshape(H, W, 3)

    pred_img = linear_to_sRGB(pred_img, use_quantile=True)
    return pred_img


# Pseudo-Function Definitions
def compute_mse(output1, output2):
    return ((output1 - output2) ** 2).mean()

def inverse_rotation_matrix(angle):
    # Assuming `rot_z` generates a rotation matrix for rotating by `angle` radians,
    # its inverse would be a rotation by `-angle`.
    return rot_z(-angle)

# Step 1: Generate reference outputs for each reference latent
reference_outputs = []
for ref_latent in model_latents[0]:  # Assuming this is a batch of reference latents
    ref_output = get_model_output(ref_latent.unsqueeze(0))  # Forward pass for each individual latent
    reference_outputs.append(ref_output)

# Step 2 & 3: Rotate latents, generate outputs, and compute MSE for each latent
mse_errors = {}
for rotation_degrees, latents in model_latents.items():
    if rotation_degrees == 0:
        continue
    # Convert degrees to radians
    radians = np.deg2rad(rotation_degrees)
    # Get the inverse rotation matrix
    inverse_rotation = inverse_rotation_matrix(torch.tensor(radians))
    inverse_rotation = inverse_rotation.to(device)
    
    errors_for_rotation = []
    for i, latent in enumerate(latents):
        # Apply inverse rotation to each latent
        inverse_rotation = inverse_rotation.type_as(latent)
        unrotated_latent = latent @ inverse_rotation.T  # Matrix multiplication to unrotate

        # Generate model output for the unrotated latent
        unrotated_output = get_model_output(unrotated_latent.unsqueeze(0))

        # Compute MSE with the corresponding reference output
        mse = compute_mse(unrotated_output, reference_outputs[i])
        errors_for_rotation.append(mse.item())
    
    # Store the mean MSE for this rotation
    mean_mse_for_rotation = sum(errors_for_rotation) / len(errors_for_rotation)
    mse_errors[rotation_degrees] = mean_mse_for_rotation

    # Print the mean MSE for the current rotation
    print(f'Mean MSE for {rotation_degrees} degrees rotation: {mean_mse_for_rotation}')

Output()

Output()

Mean MSE for 5 degrees rotation: 0.004223449937334018
Mean MSE for 20 degrees rotation: 0.018485892730365907
Mean MSE for 45 degrees rotation: 0.035348255098575636
Mean MSE for 90 degrees rotation: 0.03950929202671562
Mean MSE for 180 degrees rotation: 0.00010352867079312827
Mean MSE for 270 degrees rotation: 0.039573391899466515


In [35]:
model.eval()
idx = 0
_, ray_bundle, batch = datamanager.next_eval_image(idx)
H, W = model.metadata["image_height"], model.metadata["image_width"]

# High res image:
H = 64
W = H * 2
cx = torch.tensor(W // 2, dtype=torch.float32).repeat(1)
cy = torch.tensor(H // 2, dtype=torch.float32).repeat(1)
fx = torch.tensor(H, dtype=torch.float32).repeat(1)
fy = torch.tensor(H, dtype=torch.float32).repeat(1)

c2w = torch.tensor([[[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0]]], dtype=torch.float32).repeat(1, 1, 1)

cameras = Cameras(fx=fx, fy=fy, cx=cx, cy=cy, camera_to_worlds=c2w, camera_type=CameraType.EQUIRECTANGULAR)

ray_bundle = cameras.generate_rays(0).flatten().to(device)
ray_bundle.camera_indices = torch.ones_like(ray_bundle.camera_indices) * idx

batch['image'] = batch['image'].to(device)

# check if the model has attribute old_implementation
if hasattr(model.field, 'old_implementation'):
    if model.field.old_implementation:
        get_rotation = rot_y
    else:
        get_rotation = rot_z

rotation = get_rotation(torch.tensor(np.deg2rad(45.0)).float())
rotation = rotation.to(device)

ray_samples = model.create_ray_samples(ray_bundle.origins, ray_bundle.directions, ray_bundle.camera_indices)

field_outputs = model.field.forward(ray_samples, rotation=rotation)

pred_img = model.field.unnormalise(field_outputs[RENIFieldHeadNames.RGB])
gt_image = model.field.unnormalise(batch['image'])

# rehsape to H, W, C
gt_image = gt_image.reshape(H, W, 3)
pred_img = pred_img.reshape(H, W, 3)

pred_img = linear_to_sRGB(pred_img, use_quantile=True)
gt_img = linear_to_sRGB(gt_image, use_quantile=True)
