# Render @ pose

## Description

This notebook shows how to load the network weights stored in a `.msgpack` snapshot saved from a pre-trained Instant-NGP model and use them to bootstrap a model in Python, which gives a more convenient (lazy) interface for downstream tasks. As an example, the notebook renders the scene from a specific camera pose.

## Acknowledgement

This notebook is inspired by [test.ipynb](https://github.com/kwea123/ngp_pl/blob/master/test.ipynb) and this [answer](https://github.com/NVlabs/instant-ngp/discussions/522#discussioncomment-3211571) by [@kwea123](https://github.com/kwea123).

In [None]:
# Configure IPython to reload modules automatically when they change on disk.
# This avoids having to restart the kernel every time a project module
# imported by this notebook is modified.
# Reference: https://stackoverflow.com/a/14390676
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import time
import os
import numpy as np
from models.networks import NGP
from models.rendering import render
from metrics import psnr
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from datasets import dataset_dict
from datasets.ray_utils import get_ray_directions, get_rays
from utils import load_ckpt
from train import depth2img
import imageio

import msgpack
from kornia.utils.grid import create_meshgrid3d

# Notebook-wide Matplotlib defaults: a 3x3 inch canvas rendered at 250 DPI.
# Cells that need a different size pass their own `figsize`.
plt.rcParams.update({
    'figure.figsize': [3, 3],
    'figure.dpi': 250,
})

# Load pre-trained network

In [None]:
# Path of the Instant-NGP snapshot (.msgpack) to load. Snapshots that were
# used as alternatives during development:
#   /mnt/datasets/my_nerf_datasets/cuboids_and_joystick_360/transforms_base.msgpack
#   /mnt/datasets/xin_rosbag/transforms_base_aabb_4.msgpack
SNAPSHOT_PATH = '/mnt/datasets/xin_rosbag/transforms_.msgpack'

# Deserialize the whole snapshot into nested Python containers.
with open(SNAPSHOT_PATH, 'rb') as f:
    pretrained_data = msgpack.loads(f.read())

# Camera metadata of the first image in the dataset; the same intrinsics are
# used for every ray generated in this notebook.
camera_metadata = pretrained_data["snapshot"]["nerf"]["dataset"]["metadata"][0]

img_w, img_h = camera_metadata["resolution"]
fx, fy = camera_metadata["focal_length"]
cx, cy = camera_metadata["principal_point"]

# 3x3 pinhole intrinsics matrix. The principal point is scaled by the image
# size here, i.e. it is treated as a fraction of the resolution
# (presumably how the snapshot stores it -- confirm against the snapshot format).
camera_intrinsics = torch.tensor([[fx,  0.0, cx * img_w],
                                  [0.0, fy,  cy * img_h],
                                  [0.0, 0.0, 1.0]], dtype=torch.float32)

# Per-pixel ray directions in the camera frame; combined with a pose in the
# rendering cell to obtain world-space rays.
directions = get_ray_directions(img_h, img_w, camera_intrinsics)

# Scratchpad

* `scale`=0.33
* `aabb_scale`=4

In [None]:
# print(model.density_grid.dtype)
# print(128**3)
# print(3*128**3//8)
# print("snapshot_density_grid_size", pretrained_data['snapshot']["density_grid_size"])
# print("3 * (128 ** 3) =", 3 * (128 ** 3))
# density_grid = np.frombuffer(pretrained_data['snapshot']['density_grid_binary'], dtype=np.float16).reshape((3, 128 ** 3))
# print("read data size as float32: ", density_grid.shape)

# print(pretrained_data.keys())

# print(pretrained_data["rgb_network"])
# print(pretrained_data["snapshot"].keys())
# print(pretrained_data["snapshot"]["nerf"].keys())
# print(pretrained_data["snapshot"]["nerf"]["dataset"]["metadata"][0]["resolution"])
# print(torch.FloatTensor([pretrained_data["snapshot"]["nerf"]["dataset"]["xforms"][0]["start"]]))

# print(np.frombuffer(pretrained_data['snapshot']['params_binary'], dtype=np.uint8).shape[0]/13085152)
# print(pretrained_data["encoding"])
# print(pretrained_data["snapshot"]["nerf"]["dataset"].keys())
# print(pretrained_data["snapshot"]["nerf"]["dataset"]["scale"]) # 0.33
# print(pretrained_data["snapshot"]["nerf"]["dataset"]["aabb_scale"]) # 4
# print(pretrained_data["snapshot"]["nerf"]["dataset"]["render_aabb"])
# print(pretrained_data["snapshot"]["nerf"]["dataset"]["render_aabb_to_local"])
# print(pretrained_data["snapshot"]["nerf"]["dataset"]["offset"])

# Quick look at the snapshot: top-level keys first, then the encoding and
# network configuration sections.
print(pretrained_data.keys())
for section_name in ("encoding", "network"):
    print(pretrained_data[section_name])

# Activation settings of the colour (RGB) network.
rgb_network_config = pretrained_data["rgb_network"]
for field_name in ("activation", "output_activation"):
    print(rgb_network_config[field_name])

# Axis-aligned bounding box stored with the snapshot.
print(pretrained_data["snapshot"]["aabb"])

# Create model

In [None]:
snapshot = pretrained_data['snapshot']
dataset_info = snapshot["nerf"]["dataset"]

scale = float(dataset_info["scale"])
aabb_scale = float(dataset_info["aabb_scale"])
offset = torch.tensor([dataset_info["offset"]])

# The model is built from `aabb_scale`; `scale` is only read for reference.
model = NGP(scale = aabb_scale, rgb_act='None', offset = offset).cuda()

# Fetch the state dict once. The slice assignments below rely on its tensors
# sharing storage with the live model parameters, so writing into them
# updates the model in place.
model_state = model.state_dict()
xyz_encoder_size = model_state["xyz_encoder.params"].shape[0]
rgb_net_size = model_state["rgb_net.params"].shape[0]

# np.frombuffer returns a read-only view over the msgpack bytes; copy it so
# torch.from_numpy receives a writable array and does not emit a warning.
network_params = np.frombuffer(snapshot['params_binary'], dtype=np.float16).copy()

# Fail early with a readable message instead of a cryptic shape error from
# the slice assignment when the snapshot does not match the model.
expected_n_params = xyz_encoder_size + rgb_net_size
if network_params.shape[0] != expected_n_params:
    raise ValueError(
        f"Snapshot holds {network_params.shape[0]} parameters but the model "
        f"expects {expected_n_params} "
        f"(xyz_encoder: {xyz_encoder_size}, rgb_net: {rgb_net_size})"
    )

# Set network weights to the pre-trained weights from Instant-NGP.
# NOTE(review): this assumes the flat parameter buffer is laid out as
# [xyz_encoder params | rgb_net params] -- confirm against the snapshot layout.
model_state["xyz_encoder.params"][:] = torch.from_numpy(network_params[:xyz_encoder_size]).cuda()
model_state["rgb_net.params"][:] = torch.from_numpy(network_params[xyz_encoder_size:]).cuda()

# Resolution of the occupancy grid along each axis.
grid_res = 128

# The stored grid is interpreted as 3 rows of grid_res**3 half-precision
# values (presumably one row per cascade -- TODO confirm).
density_grid = np.frombuffer(snapshot['density_grid_binary'], dtype=np.float16).reshape((3, grid_res ** 3)).copy()
model.register_buffer('density_grid', torch.from_numpy(density_grid).cuda())

# Integer (x, y, z) coordinates of every grid cell, flattened to (N, 3).
# NOTE(review): registered after .cuda(), so this buffer stays on the CPU.
model.register_buffer('grid_coords', create_meshgrid3d(grid_res, grid_res, grid_res, False, dtype=torch.int32).reshape(-1, 3))

# Refresh the model's occupancy data from the loaded weights and grid: first
# with a positive density threshold, then with a negative one.
model.update_density_grid(0.01*1024/3**0.5)
model.update_density_grid(-0.01)

# Handy values to inspect while debugging:
# model.center
# print(model.xyz_min)
# print(model.xyz_max)

# Rendering

In [None]:
# Camera pose to render from, given as a 3x4 matrix ([rotation | translation]).
# To render from one of the dataset's own poses instead, use e.g.:
# pose = torch.FloatTensor([pretrained_data["snapshot"]["nerf"]["dataset"]["xforms"][13]["start"]])
pose = torch.tensor([[1.0, 0.0, 0.0, 0.5],
                     [0.0, -1.0, 0.0, 0.5],
                     [0.0, 0.0, -1.0, 0.5]], dtype=torch.float32)

# Turn the per-pixel camera-frame directions into ray origins and directions
# for the chosen pose.
rays_o, rays_d = get_rays(directions.cuda(), pose.cuda())

results = render(model, rays_o, rays_d,
                 test_time=True,
                 T_threshold=1e-2,
                 exp_step_factor=1/256)

# Bring the flat per-ray outputs back to image layout on the CPU.
pred = results['rgb'].reshape(img_h, img_w, 3).cpu().numpy()
depth = results['depth'].reshape(img_h, img_w).cpu().numpy()
depth_ = depth2img(depth)
opacity = results['opacity'].reshape(img_h, img_w).cpu().numpy()

# Create an empty figure and add the panels explicitly. plt.subplots() would
# add a full-size Axes that the 2x2 panels overlap and that newer Matplotlib
# versions no longer remove automatically.
fig = plt.figure(figsize=(15, 12))

ax_pred = fig.add_subplot(221)
ax_pred.set_title('pred')
ax_pred.imshow(pred)

ax_depth = fig.add_subplot(222)
ax_depth.set_title('depth')
ax_depth.imshow(depth_)

ax_opacity = fig.add_subplot(223)
ax_opacity.set_title('opacity')
ax_opacity.imshow(opacity, cmap='bone')

# tight_layout only has an effect once the panels exist, so call it last.
fig.tight_layout()
plt.show()