# Converting the old RENI models into the Nerfstudio format for comparisons.

In [1]:
import os
import torch
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import yaml
from nerfstudio.cameras.rays import RaySamples, Frustums

from reni.configs.reni_config import RENIField
from reni.field_components.field_heads import RENIFieldHeadNames
from reni.data.datamanagers.reni_datamanager import RENIDataManager
from reni.utils.utils import find_nerfstudio_project_root, rot_z
from reni.utils.colourspace import linear_to_sRGB

# Run settings.
# NOTE(review): none of these four names is referenced in the visible cells.
test_mode, world_size, local_rank = 'val', 1, 0
device = 'cuda:0'

# Locate the Nerfstudio project root starting from the current directory and
# make it the working directory for the rest of the notebook.
project_root = find_nerfstudio_project_root(Path.cwd())
os.chdir(project_root)

In [2]:
# Latent dimensionality of the old RENI model being converted. It selects both
# the Nerfstudio template checkpoint below and the old model directory.
ndims = 49

# Nerfstudio checkpoint used as a template: later cells replace its latents and
# MLP weights with the old model's and save it back to this same path.
# NOTE(review): despite the name, this is a fixed step (1000), not a lookup of
# the most recent checkpoint.
latest_version = (
    '/workspace/outputs/reni/old_reni_models/'
    f'latent_dim_{ndims}/nerfstudio_models/step-000001000.ckpt'
)

example_ckpt = torch.load(latest_version, map_location='cpu')

In [3]:
# Files of the old (pre-Nerfstudio) RENI model with latent size `ndims`.
path_to_old_models = Path('/workspace/outputs/old_RENI/RENIVariationalAutoDecoder/network_128_5')
old_model_files = path_to_old_models / f'ndims_{ndims}' / 'files'
config_path = old_model_files / 'config.yaml'
eval_reni_path = old_model_files / 'RENI_Latent.pt'

# Load the old run's config. The context manager guarantees the file handle is
# closed once parsing finishes.
# NOTE(review): FullLoader is kept for compatibility with the existing config;
# prefer yaml.safe_load if the file is ever sourced from elsewhere.
with open(config_path, 'r') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# Load the old checkpoint onto the CPU; its 'mu', 'log_var' and 'net*' entries
# are read by later cells.
ckpt = torch.load(eval_reni_path, map_location='cpu')

In [4]:
# Compare the eval latent shapes of the Nerfstudio template and the old model.
# Each pair is (Nerfstudio field attribute, old checkpoint key).
pipeline_state = example_ckpt['pipeline']
for ns_name, old_name in (('eval_mu', 'mu'), ('eval_logvar', 'log_var')):
    ns_shape = pipeline_state[f'_model.field.{ns_name}'].shape
    print(f'Nerfstudio Model {ns_name} shape = {ns_shape}')
    print(f'Old model {ns_name} shape = {ckpt[old_name].shape}')

Nerfstudio Model eval_mu shape = torch.Size([21, 49, 3])
Old model eval_mu shape = torch.Size([21, 49, 3])
Nerfstudio Model eval_logvar shape = torch.Size([21, 49, 3])
Old model eval_logvar shape = torch.Size([21, 49, 3])


In [5]:
# Print name and shape of every MLP parameter in both checkpoints, so their key
# naming and layer sizes can be compared before copying weights:
#   - example_ckpt['pipeline'] keys starting with '_model.field.network'
#   - old checkpoint keys starting with 'net'
for param_name, param in example_ckpt['pipeline'].items():
    if not param_name.startswith('_model.field.network'):
        continue
    print(f'{param_name} shape = {param.shape}')

for param_name, param in ckpt.items():
    if not param_name.startswith('net'):
        continue
    print(f'{param_name} shape = {param.shape}')

_model.field.network.net.0.linear.weight shape = torch.Size([128, 2501])
_model.field.network.net.0.linear.bias shape = torch.Size([128])
_model.field.network.net.1.linear.weight shape = torch.Size([128, 128])
_model.field.network.net.1.linear.bias shape = torch.Size([128])
_model.field.network.net.2.linear.weight shape = torch.Size([128, 128])
_model.field.network.net.2.linear.bias shape = torch.Size([128])
_model.field.network.net.3.linear.weight shape = torch.Size([128, 128])
_model.field.network.net.3.linear.bias shape = torch.Size([128])
_model.field.network.net.4.linear.weight shape = torch.Size([128, 128])
_model.field.network.net.4.linear.bias shape = torch.Size([128])
_model.field.network.net.5.linear.weight shape = torch.Size([128, 128])
_model.field.network.net.5.linear.bias shape = torch.Size([128])
_model.field.network.net.6.linear.weight shape = torch.Size([3, 128])
_model.field.network.net.6.linear.bias shape = torch.Size([3])
net.0.linear.weight shape = torch.Size([128,

In [6]:
# Update the Nerfstudio template checkpoint with the old model's eval latents
# and MLP weights. Each old tensor must replace an existing template entry of
# the same shape. All pairs are validated before anything is written, so a
# mismatch raises here and leaves example_ckpt untouched, rather than producing
# a checkpoint with unexpected or mis-shaped entries.
pipeline_state = example_ckpt['pipeline']

# (template key, replacement tensor) pairs
replacements = [
    ('_model.field.eval_mu', ckpt['mu']),
    ('_model.field.eval_logvar', ckpt['log_var']),
]
# Old MLP keys ('net.<i>.linear.*') map onto '_model.field.network.<old key>'.
replacements += [
    ('_model.field.network.' + old_key, old_value)
    for old_key, old_value in ckpt.items()
    if old_key.startswith('net')
]

for target_key, new_value in replacements:
    if target_key not in pipeline_state:
        raise KeyError(f'{target_key} not found in the Nerfstudio checkpoint')
    expected_shape = pipeline_state[target_key].shape
    if new_value.shape != expected_shape:
        raise ValueError(
            f'Shape mismatch for {target_key}: old model has '
            f'{tuple(new_value.shape)}, Nerfstudio checkpoint expects '
            f'{tuple(expected_shape)}'
        )

for target_key, new_value in replacements:
    pipeline_state[target_key] = new_value

# NOTE(review): train_mu is set to the old *eval* latents here (no shape check,
# and train_logvar is left as-is). The later placeholder cell overwrites
# train_mu anyway; confirm whether this assignment is still wanted.
pipeline_state['_model.field.train_mu'] = ckpt['mu']

In [8]:
# Write the updated checkpoint back over the template file at latest_version.
# The template is overwritten in place, so keep a copy if the conversion may
# need to be redone.
# NOTE(review): the execution counts show this cell (In [8]) ran after the
# train-latent placeholder cell (In [7]) that follows it in the file. When the
# notebook is run top to bottom, those placeholders are NOT in the saved file.
torch.save(obj=example_ckpt, f=latest_version)

In [7]:
# Replace the training latents with zero-filled placeholders of shape
# (num_train_latents, ndims, 3), cast with type_as to match the tensors they
# replace. The latent size follows `ndims` so this cell stays correct when
# converting models of other latent dimensionalities.
# NOTE(review): 1673 is presumably the number of training latents the
# Nerfstudio model expects; confirm against the training dataset.
num_train_latents = 1673
pipeline_state = example_ckpt['pipeline']
for latent_key in ('_model.field.train_mu', '_model.field.train_logvar'):
    pipeline_state[latent_key] = torch.zeros(
        (num_train_latents, ndims, 3)
    ).type_as(pipeline_state[latent_key])


In [35]:
# Display the shape of train_logvar in the updated checkpoint.
example_ckpt['pipeline']['_model.field.train_logvar'].size()

torch.Size([1673, 36, 3])