In [None]:
import os
import sys
# Make the parent directory and ../mipnerf importable from this notebook.
# Order matters: '..' is appended before '../mipnerf'; paths already present
# in sys.path are left alone. `module_path` ends as the last candidate.
for module_path in (os.path.abspath(rel_dir) for rel_dir in ('..', '../mipnerf')):
    if module_path in sys.path:
        continue
    sys.path.append(module_path)

In [None]:
import collections
import flax
import jax
import jax.numpy as jnp
from jax import random
import gin
from absl import flags
import functools
from flax.training import checkpoints
from mipnerf.internal import utils
from mipnerf.internal import datasets
from mipnerf.internal import c2f_obb_dataset
from mipnerf.internal import obbpose_dataset
from mipnerf.internal import models
from mipnerf.internal import obbpose_model
from mipnerf.internal import d_models
from mipnerf.internal import d_models360
from mipnerf.internal import math

import numpy as np
import open3d as o3d
import matplotlib.pyplot as plt
from scipy.spatial.transform import Rotation as R

import lpips
import torch
from torchvision import transforms as T

In [None]:
import numpy as np
import cv2
from PIL import Image

def visualize_depth(depth, cmap=cv2.COLORMAP_TWILIGHT_SHIFTED):
    """Colorize a depth/distance map as an RGB PIL image.

    Args:
        depth: 2-D array-like of depth values, presumably (H, W). Anything
            `np.asarray` accepts works (e.g. a JAX array).
        cmap: OpenCV colormap id passed to `cv2.applyColorMap`.

    Returns:
        A `PIL.Image.Image` with channels in RGB order.

    Values are min-max normalized over the *finite* entries only; NaN/inf
    pixels are painted with the lowest colormap color.
    """
    x = np.asarray(depth, dtype=np.float64)
    finite = np.isfinite(x)
    if finite.any():
        mi = x[finite].min()  # get minimum depth (ignoring NaN/inf)
        ma = x[finite].max()
    else:
        mi = ma = 0.0  # nothing finite: render a flat image
    # Fill non-finite pixels with the minimum. Filling NaN with 0 (as
    # nan_to_num does) would drag the minimum to 0 and compress the range, and
    # inf would become ~1.8e308 and crush every other value to 0.
    x = np.where(finite, x, mi)
    x = (x - mi) / (ma - mi + 1e-8)  # normalize to [0, 1)
    x = (255 * x).astype(np.uint8)
    colored = cv2.applyColorMap(x, cmap)
    # OpenCV colormaps are produced in BGR order; PIL interprets arrays as RGB.
    return Image.fromarray(cv2.cvtColor(colored, cv2.COLOR_BGR2RGB))

In [None]:
class Config:
  """Configuration flags for everything.

  This is a plain class (not a dataclass): every field is a class attribute,
  so `Config()` simply exposes these defaults. The annotations document the
  intended types only; nothing enforces them at runtime.
  """
  dataset_loader: str = 'waymo'  # The type of dataset loader to use.
  # Batch composition (e.g. single_image, all_images); 'timestep' is used here.
  # TODO confirm the full set of values the obbpose loader accepts.
  batching: str = 'timestep'
  batch_size: int = 512  # The number of rays/pixels in each batch.
  factor: int = 4  # The downsample factor of images, 0 for no downsampling.
  spherify: bool = True  # Set to True for spherical 360 scenes.
  centering: bool = True  # this determines if poses are centered around zero or not
  render_path: bool = False  # If True, render a path. Used only by LLFF.
  llffhold: int = 11  # Use every Nth image for the test set. Used only by LLFF.
  timesteps: int = 5  # How many timesteps the current scene has (a bit of a crutch right now, integrate to dataset?)
  lr_init: float = 5e-4  # The initial learning rate.
  lr_final: float = 5e-6  # The final learning rate.
  lr_delay_steps: int = 2500  # The number of "warmup" learning steps.
  # Delay ("warmup") steps for the near-loss interval (eps) schedule,
  # presumably analogous to lr_delay_steps.
  eps_delay_steps: int = 0
  eps_init: float = 3  # Initial interval for near loss
  eps_final: float = 0.2  # Final interval for near loss
  l2_init: int = 1  # Initial value for l2reg for deformation model
  l2_final: int = 0  # Final value for l2reg for deformation model
  l2_delay_steps: int = 5000  # reduce l2reg after this many steps
  psreg_init: float = 10e5  # start value for pose regularization (10e5 == 1e6)
  psreg_final: float = 10e-1  # end value for pose regularization (10e-1 == 1.0)
  psreg_delay_steps: int = 5000  # after this many steps, start decreasing psreg
  psreg_delay_mult: float = 1.0  # presumably analogous to lr_delay_mult, for psreg
  # Presumably toggle the noise given by box_noise / yaw_noise below.
  random_box: bool = False
  random_yaw: bool = False
  box_noise: float = 0.5
  yaw_noise: float = 5.  # rotational noise in degrees added to box yaw / heading angle
  c2f_steps: tuple = (5000, 10000, 15000)  # The number of steps after which rays of higher resolutions should be loaded
  lr_delay_mult: float = 0.01  # How severe the "warmup" should be.
  grad_max_norm: float = 0.  # Gradient clipping magnitude, disabled if == 0.
  grad_max_val: float = 0.  # Gradient clipping value, disabled if == 0.
  max_steps: int = 200000  # The number of optimization steps.
  save_every: int = 50000  # The number of steps to save a checkpoint.
  print_every: int = 100  # The number of steps between reports to tensorboard.
  gc_every: int = 10000  # The number of steps between garbage collections.
  test_render_interval: int = 1  # The interval between images saved to disk.
  disable_multiscale_loss: bool = False  # If True, disable multiscale loss.
  randomized: bool = True  # Use randomized stratified sampling.
  near: float = 0.0  # Near plane distance.
  far: float = 40.  # Far plane distance.
  coarse_loss_mult: float = 0.1  # How much to downweight the coarse loss(es).
  weight_decay_mult: float = 0.  # The multiplier on weight decay.
  white_bkgd: bool = False  # If True, use white as the background (black o.w.).
  rand_bkgd: bool = False  # If True, use random color as background

In [None]:
# Field layout for box-ray bundles.
BoxRays = collections.namedtuple(
    'BoxRays',
    ['origins', 'directions', 'viewdirs', 'radii', 'lossmult', 'near', 'far'])

config = Config()

# Test split of the Waymo scene being evaluated.
SCENE_DIR = '/home/tristram/data/waymo/seg1_5_center'
dataset = obbpose_dataset.get_dataset('test', SCENE_DIR, config)

# Build the model from one example batch with a fixed seed, then wrap the
# initial variables in an Adam optimizer inside a TrainState; this state is
# later used as the restore target for the trained checkpoint.
init_key = random.PRNGKey(20200823)
model, init_variables = obbpose_model.construct_mipnerf(init_key, dataset.peek())
state = utils.TrainState(
    optimizer=flax.optim.Adam(config.lr_init).create(init_variables))
del init_key, init_variables

def render_eval_fn(variables, _, batch):
    """Render one device shard deterministically and gather across devices.

    Intended to be wrapped by `jax.pmap` with `axis_name='batch'`. The second
    positional argument is ignored. The PRNG key passed to the model is a fixed
    placeholder because randomization is disabled for test-set rendering.
    """
    rendered = model.apply(
        variables,
        random.PRNGKey(0),  # Unused.
        batch['rays'],
        batch['init'],
        batch['ext'],
        batch['ts'],
        randomized=False,
        white_bkgd=config.white_bkgd,
        rand_bkgd=False,
        alpha=batch['alpha'],
    )
    # Concatenate every device's output along the pmapped 'batch' axis.
    return jax.lax.all_gather(rendered, axis_name='batch')

# Multi-device renderer: `variables` and the ignored slot are not mapped
# (in_axes=None), only the batch is split along its leading axis, and the
# batch buffers are donated so XLA may reuse their memory.
render_eval_pfn = jax.pmap(
    render_eval_fn,
    axis_name='batch',
    in_axes=(None, None, 0),
    donate_argnums=(2,),
)

# Restore this experiment's trained weights into the freshly built `state`.
exp_name = 'Waymo_den_cntr02_0_200_8_256_10fipe_gtbox'
ckpt_dir = f'/home/tristram/nerf_results/{exp_name}/seg1_5_center'
state = checkpoints.restore_checkpoint(ckpt_dir, state)

In [None]:
ssim_fn = jax.jit(functools.partial(math.compute_ssim, max_val=1.))
# LPIPS with an AlexNet backbone. Its forward pass expects RGB in [-1, 1], but
# the images here are in [0, 1] (cf. SSIM's max_val=1. and the clip before
# saving). The call below therefore passes normalize=True so LPIPS rescales them.
loss_fn_alex = lpips.LPIPS(net='alex')

# Output folder for per-view renders; created up front so saving cannot fail.
EVAL_IMG_DIR = '/home/tristram/Pictures/eval_imgs'
os.makedirs(EVAL_IMG_DIR, exist_ok=True)


def _show_image(img):
    """Display a single image inline with axes hidden."""
    plt.figure()
    plt.imshow(img)
    plt.axis('off')
    plt.show()


psnr = []
ssim = []
lpips_score = []
# For every test view: render, display, save to disk, and score against GT.
for step, batch in zip(range(dataset.size), dataset):
    pred_color, pred_distance, _pred_acc = obbpose_model.render_image(
        functools.partial(render_eval_pfn, state.optimizer.target),
        batch['rays'],
        batch['init'],
        batch['ext'],
        batch['ts'],
        None,
        alpha=10,
        chunk=2048)

    _show_image(pred_color)
    pred_depth = visualize_depth(pred_distance)
    _show_image(pred_depth)

    # Only the saved copy is clipped; the metrics below use the raw prediction.
    imgsave = np.clip(pred_color, 0, 1)
    plt.imsave(os.path.join(EVAL_IMG_DIR, f'{exp_name}_rgb_{step}.png'), imgsave)
    pred_depth.save(os.path.join(EVAL_IMG_DIR, f'{exp_name}_depth_{step}.png'))

    psnr.append(math.mse_to_psnr(((pred_color - batch['pixels']) ** 2).mean()))
    ssim.append(ssim_fn(pred_color, batch['pixels']))

    # LPIPS runs in PyTorch: move the trailing channel axis to the front and
    # add a batch axis. np.array copies, because np.asarray of a JAX array can
    # be read-only and torch.from_numpy warns on non-writable input.
    img_0 = torch.from_numpy(np.array(pred_color)).moveaxis(-1, 0).float()
    img_gt = torch.from_numpy(np.array(batch['pixels'])).moveaxis(-1, 0).float()
    with torch.no_grad():  # pure inference; don't build an autograd graph
        score = loss_fn_alex(img_0.unsqueeze(0), img_gt.unsqueeze(0),
                             normalize=True)
    lpips_score.append(score.numpy().squeeze())

psnr = np.array(psnr)
ssim = np.array(ssim)
lpips_score = np.array(lpips_score)
print('PSNR: ', psnr.mean())
print('SSIM: ', ssim.mean())
print('LPIPS:', lpips_score.mean())