In [None]:
import os
import sys
# Make the parent directory and ../mipnerf importable so the notebook can find
# the project packages; `module_path` ends up holding the last entry checked.
for module_path in (os.path.abspath('..'), os.path.abspath('../mipnerf')):
    if module_path not in sys.path:
        sys.path.append(module_path)

In [None]:
import collections
import flax
import jax
import jax.numpy as jnp
from jax import random
import gin
from absl import flags
import functools
from flax.training import checkpoints
from mipnerf.internal import utils
from mipnerf.internal import datasets
from mipnerf.internal import c2f_obb_dataset
from mipnerf.internal import obbpose_dataset
from mipnerf.internal import models
from mipnerf.internal import obbpose_model
from mipnerf.internal import d_models
from mipnerf.internal import d_models360

import numpy as np
import open3d as o3d
import matplotlib.pyplot as plt
from scipy.spatial.transform import Rotation as R

In [None]:
class Config:
  """Configuration flags for everything.

  This is a plain class (not a dataclass): every field is a class attribute, so
  `Config()` simply exposes these defaults. The annotations are documentation
  only and are not enforced, but they should match the default values.
  """
  dataset_loader: str = 'waymo'  # The type of dataset loader to use.
  # Batch composition; upstream options are single_image / all_images. This
  # notebook uses 'timestep' (TODO confirm the accepted values in the loaders).
  batching: str = 'timestep'
  batch_size: int = 512  # The number of rays/pixels in each batch.
  factor: int = 4  # The downsample factor of images, 0 for no downsampling.
  spherify: bool = True  # Set to True for spherical 360 scenes.
  centering: bool = True  # Whether poses are centered around zero.
  render_path: bool = False  # If True, render a path. Used only by LLFF.
  llffhold: int = 11  # Use every Nth image for the test set. Used only by LLFF.
  timesteps: int = 5  # How many timesteps the current scene has (a bit of a crutch right now, integrate to dataset?)
  lr_init: float = 5e-4  # The initial learning rate.
  lr_final: float = 5e-6  # The final learning rate.
  lr_delay_steps: int = 2500  # The number of "warmup" learning steps.
  eps_delay_steps: int = 0  # Delay steps for the near-loss interval schedule (presumably; TODO confirm).
  eps_init: float = 3  # Initial interval for near loss
  eps_final: float = 0.2  # Final interval for near loss
  l2_init: float = 1  # Initial value for l2reg for deformation model
  l2_final: float = 0  # Final value for l2reg for deformation model
  l2_delay_steps: int = 5000  # reduce l2reg after this many steps
  psreg_init: float = 10e5  # start value for pose regularization (10e5 == 1e6)
  psreg_final: float = 10e-1  # end value for pose regularization (10e-1 == 1.0)
  psreg_delay_steps: int = 5000  # after this many steps, start decreasing psreg
  psreg_delay_mult: float = 1.0
  random_box: bool = False
  random_yaw: bool = False
  box_noise: float = 0.5
  yaw_noise: float = 5.  # rotational noise in degrees added to box yaw / heading angle
  c2f_steps: tuple = (5000, 10000, 15000)  # The number of steps after which rays of higher resolutions should be loaded
  lr_delay_mult: float = 0.01  # How severe the "warmup" should be.
  grad_max_norm: float = 0.  # Gradient clipping magnitude, disabled if == 0.
  grad_max_val: float = 0.  # Gradient clipping value, disabled if == 0.
  max_steps: int = 200000  # The number of optimization steps.
  save_every: int = 50000  # The number of steps to save a checkpoint.
  print_every: int = 100  # The number of steps between reports to tensorboard.
  gc_every: int = 10000  # The number of steps between garbage collections.
  test_render_interval: int = 1  # The interval between images saved to disk.
  disable_multiscale_loss: bool = False  # If True, disable multiscale loss.
  randomized: bool = True  # Use randomized stratified sampling.
  near: float = 0.0  # Near plane distance.
  far: float = 40.  # Far plane distance.
  coarse_loss_mult: float = 0.1  # How much to downweight the coarse loss(es).
  weight_decay_mult: float = 0.  # The multiplier on weight decay.
  white_bkgd: bool = False  # If True, use white as the background (black o.w.).
  rand_bkgd: bool = False  # If True, use random color as background

In [None]:
import numpy as np
import cv2
from PIL import Image

def visualize_depth(depth, cmap=cv2.COLORMAP_TWILIGHT_SHIFTED):
    """Color-map a depth map for display / saving.

    Args:
        depth: (H, W) array of depths. NaNs are treated as 0.
        cmap: OpenCV colormap id passed to ``cv2.applyColorMap``.

    Returns:
        PIL.Image in RGB channel order. The depth is min-max normalized, so
        the full colormap range is always used.
    """
    x = np.nan_to_num(np.asarray(depth, dtype=np.float64))  # change nan to 0
    mi = np.min(x)  # get minimum depth
    ma = np.max(x)
    x = (x - mi) / (ma - mi + 1e-8)  # normalize to 0~1 (epsilon guards a constant map)
    x = (255 * x).astype(np.uint8)
    colored_bgr = cv2.applyColorMap(x, cmap)
    # OpenCV colormaps come back in BGR order; PIL interprets arrays as RGB,
    # so convert first or the red/blue channels end up swapped.
    return Image.fromarray(cv2.cvtColor(colored_bgr, cv2.COLOR_BGR2RGB))

In [None]:
def _pixel_grid(dataset, img):
    """Return float32 (x, y) pixel-index grids, each of shape (h, w), for image `img`."""
    return np.meshgrid(  # pylint: disable=unbalanced-tuple-unpacking
        np.arange(dataset.w[img], dtype=np.float32),  # X-Axis (columns)
        np.arange(dataset.h[img], dtype=np.float32),  # Y-Axis (rows)
        indexing='xy')


def _camera_dirs_to_rays(camera_dirs, camtoworlds, dataset):
    """Rotate (h, w, 3) camera-frame directions into the world and pack a BoxRays.

    Only camtoworlds[:3, :3] (rotation) and camtoworlds[:3, -1] (origin) are
    read, so both 3x4 and 4x4 matrices work. All outputs are `.squeeze()`d.
    """
    # (1, h, w, 1, 3) * (3, 3) summed over the last axis -> (1, h, w, 3).
    directions = ((camera_dirs[None, ..., None, :] *
                   camtoworlds[:3, :3]).sum(axis=-1))
    origins = np.broadcast_to(camtoworlds[:3, -1], directions.shape)
    viewdirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True)

    # Distance from each (unnormalized) direction vector to its neighbour
    # along axis 1, i.e. the next image row; the last row reuses the previous
    # distance.
    dx = np.sqrt(np.sum((directions[:, :-1, :, :] - directions[:, 1:, :, :]) ** 2, -1))
    dx = np.concatenate([dx, dx[:, -2:-1, :]], 1)
    # Cut the distance in half, and then round it out so that it's
    # halfway between inscribed by / circumscribed about the pixel.
    radii = dx[..., None] * 2 / np.sqrt(12)

    ones = np.ones_like(origins[..., :1]).squeeze()
    return BoxRays(
        origins=origins.squeeze(),
        directions=directions.squeeze(),
        viewdirs=viewdirs.squeeze(),
        radii=radii.squeeze(),
        lossmult=ones,
        near=ones * dataset.near,
        far=ones * dataset.far,
    )


def generate_rays_original(camtoworlds, dataset, img):
    """Generate one ray per pixel of image `img`, principal point at the image center.

    Args:
        camtoworlds: camera-to-world matrix (3x4 or 4x4).
        dataset: object with per-image `w`, `h`, `focal` sequences and
            `near`/`far` values (presumably scalars - TODO confirm).
        img: index into the per-image sequences.

    Returns:
        BoxRays whose arrays are squeezed: (h, w, 3) origins/directions/
        viewdirs and (h, w) radii/lossmult/near/far (for h, w > 1).
    """
    x, y = _pixel_grid(dataset, img)
    focal = dataset.focal[img]
    # NOTE(review): no +0.5 pixel-center offset here (original NeRF style),
    # unlike generate_rays_waymo.
    camera_dirs = np.stack(
        [(x - dataset.w[img] * 0.5) / focal,
         -(y - dataset.h[img] * 0.5) / focal,
         -np.ones_like(x)], axis=-1)
    return _camera_dirs_to_rays(camera_dirs, camtoworlds, dataset)


def generate_rays_waymo(camtoworlds, dataset, img, pp=0, factor=4):
    """Generate per-pixel rays for image `img` using an explicit principal point.

    Args:
        camtoworlds: camera-to-world matrix (3x4 or 4x4).
        dataset: as for `generate_rays_original`.
        img: index into the per-image `w`, `h`, `focal` sequences.
        pp: principal point (x, y) before downsampling, or a scalar applied to
            both axes. The default 0 previously raised because a float cannot
            be indexed.
        factor: downsample factor that `pp` is divided by (4 was hard-coded;
            matches Config.factor).

    Returns:
        BoxRays, shaped as in `generate_rays_original`.
    """
    x, y = _pixel_grid(dataset, img)
    pp_scale = np.asarray(pp) / factor
    if pp_scale.ndim == 0:
        pp_scale = np.full(2, pp_scale)
    focal = dataset.focal[img]
    # Pixel centers (+0.5) measured from the principal point; y flipped and
    # z = -1 as in generate_rays_original.
    camera_dirs = np.stack(
        [(x - pp_scale[0] + 0.5) / focal,
         -(y - pp_scale[1] + 0.5) / focal,
         -np.ones_like(x)], axis=-1)
    return _camera_dirs_to_rays(camera_dirs, camtoworlds, dataset)

In [None]:
def convert_to_ndc(origins, directions, focal, w, h, near=1.):
    """Map rays into normalized device coordinates (LLFF-style NDC).

    Origins are first slid along their rays onto the plane z = -near, then the
    standard NeRF NDC projection is applied.

    Args:
        origins: (..., 3) ray origins.
        directions: (..., 3) ray directions.
        focal, w, h: focal length and image width/height in pixels.
        near: near-plane distance.

    Returns:
        (ndc_origins, ndc_directions), each (..., 3).
    """
    shift = -(near + origins[..., 2]) / directions[..., 2]
    shifted = origins + shift[..., None] * directions

    ox, oy, oz = shifted[..., 0], shifted[..., 1], shifted[..., 2]
    dx, dy, dz = directions[..., 0], directions[..., 1], directions[..., 2]

    scale_x = -((2 * focal) / w)
    scale_y = -((2 * focal) / h)

    ndc_origins = np.stack([
        scale_x * (ox / oz),
        scale_y * (oy / oz),
        1 + 2 * near / oz,
    ], -1)
    ndc_directions = np.stack([
        scale_x * (dx / dz - ox / oz),
        scale_y * (dy / dz - oy / oz),
        -2 * near / oz,
    ], -1)
    return ndc_origins, ndc_directions

In [None]:
def generate_ndc_rays(camtoworlds, dataset, img=5):
    """Generate normalized device coordinate rays for llff.

    Args:
        camtoworlds: camera-to-world matrix (3x4 or 4x4).
        dataset: object with per-image `w`, `h`, `focal` and `near`/`far`.
        img: image whose intrinsics are used for both ray generation and the
            NDC projection. It was previously hard-coded to 5 for the
            projection and missing for ray generation, which raised a
            TypeError.

    Returns:
        Rays in NDC space; `viewdirs` keeps the world-space ray directions.
    """
    rays = generate_rays_original(camtoworlds, dataset, img)
    ndc_origins, ndc_directions = convert_to_ndc(rays.origins,
                                                 rays.directions,
                                                 dataset.focal[img],
                                                 dataset.w[img],
                                                 dataset.h[img])

    # The rays above are squeezed to (H, W, 3); restore the leading batch axis
    # so the 4-D neighbour differences below index correctly.
    mat = np.asarray(ndc_origins)
    if mat.ndim == 3:
        mat = mat[None]
    # Distance to the neighbouring ray along axis 1 (rows) and axis 2 (columns).
    dx = np.sqrt(np.sum((mat[:, :-1, :, :] - mat[:, 1:, :, :]) ** 2, -1))
    dx = np.concatenate([dx, dx[:, -2:-1, :]], 1)

    dy = np.sqrt(np.sum((mat[:, :, :-1, :] - mat[:, :, 1:, :]) ** 2, -1))
    dy = np.concatenate([dy, dy[:, :, -2:-1]], 2)
    # Cut the distance in half, and then round it out so that it's
    # halfway between inscribed by / circumscribed about the pixel.
    radii = (0.5 * (dx + dy))[..., None] * 2 / np.sqrt(12)

    ones = np.ones_like(ndc_origins[..., :1].squeeze())
    return Rays(
        origins=ndc_origins.squeeze(),
        directions=ndc_directions.squeeze(),
        viewdirs=rays.directions.squeeze(),
        radii=radii.squeeze(),
        lossmult=ones,
        near=ones * dataset.near,
        far=ones * dataset.far)

In [None]:
def _recenter_poses(poses):
    """Recenter poses according to the original NeRF code."""
    poses_ = poses.copy()
    bottom = np.reshape([0, 0, 0, 1.], [1, 4])
    c2w = _poses_avg(poses)
    c2w = np.concatenate([c2w[:3, :4], bottom], -2)
    bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1])
    poses = np.concatenate([poses[:, :3, :4], bottom], -2)
    poses = np.linalg.inv(c2w) @ poses
    poses_[:, :3, :4] = poses[:, :3, :4]
    poses = poses_
    return poses, c2w

def _poses_avg(poses):
    """Average poses according to the original NeRF code."""
    hwf = poses[0, :3, -1:]
    center = poses[:, :3, 3].mean(0)
    vec2 = _normalize(poses[:, :3, 2].sum(0))
    up = poses[:, :3, 1].sum(0)
    c2w = np.concatenate([_viewmatrix(vec2, up, center), hwf], 1)
    return c2w
    
def _normalize(x):
    """Normalization helper function."""
    return x / np.linalg.norm(x)
    
def _viewmatrix(z, up, pos):
    """Construct lookat view matrix."""
    vec2 = _normalize(z)
    vec1_avg = up
    vec0 = _normalize(np.cross(vec1_avg, vec2))
    vec1 = _normalize(np.cross(vec2, vec0))
    m = np.stack([vec0, vec1, vec2, pos], 1)
    return m

# Load camera poses and 3D boxes for the scene, recenter/rescale them, and
# build per-car box poses, extents and relative transforms for the renderer.
DATA_DIR = '/home/tristram/data/waymo/seg1_5_center'
SCENE_SCALE = 5.0    # translations (and box extents) are divided by this
DOWNSAMPLE = 4.0     # focal-length downsample factor (matches Config.factor)
IMG_HW = (320, 480)  # image height / width written into the hwf column

poses_arr = np.load(os.path.join(DATA_DIR, 'poses_bounds.npy'))
# Row layout: 3x5 pose values, then 2 bounds, then a 2-D principal point.
poses = poses_arr[:, :-4].reshape([-1, 3, 5]).transpose([1, 2, 0])
bds = poses_arr[:, -4:-2].transpose([1, 0])
princip_point = poses_arr[:, -2:]

# Update poses according to downsampling.
poses[:2, 4, :] = np.array(IMG_HW).reshape([2, 1])
poses[2, 4, :] = poses[2, 4, :] * 1. / DOWNSAMPLE

poses = np.moveaxis(poses, -1, 0).astype(np.float32)
poses, c2w = _recenter_poses(poses)
poses[:, :3, 3] /= SCENE_SCALE

# Keys have the form '<ts>_<car>_center' (4x4 box pose) and '<ts>_<car>_ext'.
box_dict = np.load(os.path.join(DATA_DIR, '3D_boxes.npy'), allow_pickle=True).item()
box_pose = []
box_ext = []
for key in box_dict:
    if 'center' in key:
        box_pose.append(box_dict[key])
    elif 'ext' in key:
        box_ext.append(box_dict[key])
box_pose = np.array(box_pose)
box_ext = np.array(box_ext)
# Centers and extents are paired by position, which relies on dict order.
assert len(box_pose) == len(box_ext), 'every box center needs a matching extent'

box_pose = np.linalg.inv(c2w) @ box_pose
box_pose[:, :3, 3] /= SCENE_SCALE
box_ext /= SCENE_SCALE * 2.0

# Rotation vector of the world->object rotation (inverse of the box rotation).
yaw = np.array(R.from_matrix(np.linalg.inv(box_pose[:, :3, :3])).as_rotvec())

obbpose = np.concatenate([box_pose[:, :3, 3], yaw], axis=-1)
bpose = [k for k in box_dict if 'center' in k]
parsed = [key.split('_') for key in bpose]  # [ts, car, 'center']

# Canonical (timestep 1) pose of *each* car. The old `'1_' in key` test also
# matched car id 1 at every timestep (and timesteps such as 11), and reused
# whichever car's canonical pose happened to come last.
canonical_pose = {car: box_pose[i]
                  for i, (ts, car, _) in enumerate(parsed) if ts == '1'}

rel_pose = {}
cars = []
for i, (key, (ts, car, _)) in enumerate(zip(bpose, parsed)):
    cars.append(int(car))
    if ts == '1':
        rel_pose[ts + '_' + car + '_rel'] = np.eye(4)
    else:
        rel_pose[ts + '_' + car + '_rel'] = np.matmul(canonical_pose[car],
                                                      np.linalg.inv(box_pose[i]))
    box_dict[key] = obbpose[i]
    box_dict[ts + '_' + car + '_ext'] = box_ext[i]

cars = np.unique(np.array(cars))

In [None]:
BoxRays = collections.namedtuple(
    'BoxRays',
    ('origins', 'directions',
     'viewdirs', 'radii', 'lossmult',
     'near', 'far'))

config = Config()

dataset = obbpose_dataset.get_dataset('render', '/home/tristram/data/waymo/seg1_5_center', config)

model, init_variables = obbpose_model.construct_mipnerf(
      random.PRNGKey(20200823), dataset.peek())
optimizer = flax.optim.Adam(config.lr_init).create(init_variables)
state = utils.TrainState(optimizer=optimizer)
del optimizer, init_variables

# Because this is only used for test set rendering, we disable randomization.
def render_eval_fn(variables, _, batch):
    """Run the model on one batch dict (rays/init/ext/ts/alpha) and gather across devices."""
    return jax.lax.all_gather(
        model.apply(
                variables,
                random.PRNGKey(0),  # Unused.
                batch['rays'],
                batch['init'],
                batch['ext'],
                batch['ts'],
                randomized=False,
                white_bkgd=config.white_bkgd,
                rand_bkgd=False,
                alpha=batch['alpha']),
            axis_name='batch')

render_eval_pfn = jax.pmap(
    render_eval_fn,
    in_axes=(None, None, 0),  # Only distribute the data input.
    donate_argnums=(2,),
    axis_name='batch',
)
exp_name = 'Waymo_dend_cntr02_0_40_8_256_10fipe_noopt_noboxcntr'
ckpt_dir = '/home/tristram/nerf_results/' + exp_name + '/seg1_5_center'
# flax's restore_checkpoint silently returns the passed-in (untrained) state
# when no checkpoint is found, which would render with random weights.
if not (os.path.isfile(ckpt_dir) or
        (os.path.isdir(ckpt_dir) and
         any(name.startswith('checkpoint_') for name in os.listdir(ckpt_dir)))):
    raise FileNotFoundError(f'No checkpoint found at {ckpt_dir}')
state = checkpoints.restore_checkpoint(ckpt_dir, state)

In [None]:
    # Render every (c2w, ts) pose of a saved trajectory, save RGB/depth frames and assemble a GIF.
import imageio

traj = np.load('/home/tristram/exp_results/giftest/seg1_5_traj_opt.npz', allow_pickle=True)['arr_0']
save_dir = '/home/tristram/Pictures/giftemp4/'
gif_path = '/home/tristram/exp_results/giftest/test_carla.gif'
render_cam = 2  # camera whose focal / size / principal point define the rays
os.makedirs(save_dir, exist_ok=True)

# Box inputs do not depend on the trajectory pose, so build them once.
# NOTE(review): assumes 5 images per timestep in `poses` - TODO confirm.
n_timesteps = int(poses.shape[0] / 5)
init = np.array([[box_dict[f'{t + 1}_{c}_center'] for c in cars]
                 for t in range(n_timesteps)]).reshape(n_timesteps, -1, 6)
ext = np.array([box_dict[f'1_{c}_ext'] for c in cars]).reshape(-1, 3)

images = []
for i, elem in enumerate(traj):
    # Separate names so the global recentering `c2w` from the load cell is not clobbered.
    frame_c2w, frame_ts = elem[0], elem[1]
    rays = generate_rays_waymo(frame_c2w, dataset, img=render_cam,
                               pp=princip_point[render_cam])

    pred_color, pred_distance, pred_acc = obbpose_model.render_image(
            functools.partial(render_eval_pfn, state.optimizer.target),
            rays,
            init,
            ext,
            frame_ts,
            None,
            alpha=10,
            chunk=2048)
    print(f'{i + 1}/{len(traj)}')

    plt.figure(2, figsize=(20, 6))
    plt.imshow(pred_color)
    plt.show()

    img_d = visualize_depth(pred_distance)
    plt.figure(2, figsize=(20, 6))
    plt.imshow(img_d)
    plt.show()

    imgsave = np.clip(pred_color, 0, 1)
    plt.imsave(os.path.join(save_dir, f'{exp_name}_rgb_ts_{i}.png'), imgsave)
    img_d.save(os.path.join(save_dir, f'{exp_name}_depth_ts_{i}.png'))

    # GIF writers expect uint8 frames; convert explicitly instead of relying on
    # imageio's lossy float conversion.
    images.append((imgsave * 255).astype(np.uint8))

imageio.mimsave(gif_path, images, fps=10)

In [None]:
%matplotlib inline
from ipywidgets import interactive, widgets
import matplotlib.pyplot as plt
import numpy as np

save_dir = '/home/tristram/Pictures/eval_traj/'
REF_IMG = 20  # index into `poses` of the camera the slider offsets are applied to
RENDER_PREVIEW = False  # True: render + save each slider pose (slow, full image per change)

# Every slider change appends one [c2w (3x4), ts] entry; the next cell saves it.
traj = []


def _rotation_4x4(axis, degrees):
    """Homogeneous 4x4 rotation by `degrees` about the world 'x', 'y' or 'z' axis."""
    angle = np.deg2rad(degrees)  # was approximated with 3.1415
    c, s = np.cos(angle), np.sin(angle)
    i, j = {'x': (1, 2), 'y': (2, 0), 'z': (0, 1)}[axis]
    m = np.eye(4)
    m[i, i] = c
    m[j, j] = c
    m[i, j] = -s
    m[j, i] = s
    return m


def _render_preview(c2w, ts):
    """Render pose `c2w` at timestep `ts`, show RGB + depth and save both under `save_dir`."""
    rays = generate_rays_original(c2w, dataset, img=REF_IMG)
    # NOTE(review): assumes 5 images per timestep in `poses` - TODO confirm.
    n_timesteps = int(poses.shape[0] / 5)
    init = np.array([[box_dict[f'{t + 1}_{c}_center'] for c in cars]
                     for t in range(n_timesteps)]).reshape(n_timesteps, -1, 6)
    ext = np.array([box_dict[f'1_{c}_ext'] for c in cars]).reshape(-1, 3)
    pred_color, pred_distance, _ = obbpose_model.render_image(
        functools.partial(render_eval_pfn, state.optimizer.target),
        rays, init, ext, ts, None, alpha=10, chunk=2048)

    plt.figure(2, figsize=(20, 6))
    plt.imshow(pred_color)
    plt.show()

    img_d = visualize_depth(pred_distance)
    plt.figure(2, figsize=(20, 6))
    plt.imshow(img_d)
    plt.show()

    os.makedirs(save_dir, exist_ok=True)
    plt.imsave(save_dir + exp_name + '_rgb_ts_' + str(ts) + '.png', np.clip(pred_color, 0, 1))
    img_d.save(save_dir + exp_name + '_depth_ts_' + str(ts) + '.png')


def f(x, y, z, rx, ry, rz, ts):
    """Record the camera pose chosen with the sliders.

    The translation (x, y, z) and then the rotations rz, ry, rx (degrees) are
    left-multiplied onto camera REF_IMG's camera-to-world pose. The resulting
    3x4 matrix and `ts` are appended to `traj`. When RENDER_PREVIEW is set, the
    pose is also rendered.
    """
    offset = np.eye(4)
    offset[:3, 3] = (x, y, z)
    bottom = np.reshape([0, 0, 0, 1.], [1, 4])
    pose = np.concatenate([poses[REF_IMG, :3, :4], bottom], -2)
    c2w = (_rotation_4x4('x', rx) @ _rotation_4x4('y', ry) @ _rotation_4x4('z', rz)
           @ offset @ pose)[:3, :]
    traj.append([c2w, ts])
    if RENDER_PREVIEW:
        _render_preview(c2w, ts)


interactive_plot = interactive(f,
                               x=(-1.0, 1.0, 0.025),
                               y=(-1.0, 1.0, 0.025),
                               z=(-1.0, 1.0, 0.025),
                               rx=(-20.0, 20.0, 1),
                               ry=(-90.0, 90.0, 2),
                               rz=(-20.0, 20.0, 1),
                               ts=(-4, 4, 1))
interactive_plot

In [None]:
print(len(traj))
# `traj` holds ragged [c2w (3x4 array), ts] pairs. NumPy >= 1.24 refuses to
# build an array from such a list implicitly, so store an explicit (N, 2)
# object array, which is the row layout the trajectory-render cell iterates over.
traj_arr = np.empty((len(traj), 2), dtype=object)
for row, (pose_c2w, pose_ts) in enumerate(traj):
    traj_arr[row, 0] = pose_c2w
    traj_arr[row, 1] = pose_ts
np.savez('/home/tristram/exp_results/giftest/carla_traj.npz', traj_arr)

In [None]:
# Ray bundle type built by generate_ndc_rays.
Rays = collections.namedtuple(
    'Rays', 'origins directions viewdirs radii lossmult near far')

# NOTE(review): render_traj is defined in the *next* cell; on a fresh kernel
# run that cell before this one.
ckpt_path = '/home/tristram/nerf_results/KITTI10_ds01_combinedD/nerf10/'
eval_folder = '/home/tristram/exp_results/KITTI10_ds01_combinedD/'
ref_folder = '/home/tristram/data/nerf10'
imgs = render_traj(ckpt_path=ckpt_path, eval_folder=eval_folder, ref_folder=ref_folder)

In [None]:
def render_traj(ckpt_path, eval_folder, ref_folder):
    """Render every pose in `<eval_folder>/eval_traj.npy` with a trained mip-NeRF.

    Args:
        ckpt_path: flax checkpoint file or directory to restore.
        eval_folder: folder containing `eval_traj.npy`; frames are written to
            `<eval_folder>/eval_images/<i>.png`.
        ref_folder: dataset folder used to build the model and ray intrinsics.

    Returns:
        List of predicted RGB images, one per pose (this was previously always
        empty because nothing was appended).

    Raises:
        FileNotFoundError: if no checkpoint exists at `ckpt_path`. flax would
            otherwise silently keep the random initialization.
    """
    config = Config()

    dataset = datasets.get_dataset('test', ref_folder, config)

    model, init_variables = models.construct_mipnerf(
          random.PRNGKey(20200823), dataset.peek())
    optimizer = flax.optim.Adam(config.lr_init).create(init_variables)
    state = utils.TrainState(optimizer=optimizer)
    del optimizer, init_variables

    if not (os.path.isfile(ckpt_path) or
            (os.path.isdir(ckpt_path) and
             any(name.startswith('checkpoint_') for name in os.listdir(ckpt_path)))):
        raise FileNotFoundError(f'No checkpoint found at {ckpt_path}')
    state = checkpoints.restore_checkpoint(ckpt_path, state)

    def render_eval_fn(variables, _, rays):
        return jax.lax.all_gather(
            model.apply(
                variables,
                random.PRNGKey(0),  # Unused.
                rays,
                randomized=False,
                white_bkgd=config.white_bkgd),
            axis_name='batch')

    # pmap over only the data input.
    render_eval_pfn = jax.pmap(
          render_eval_fn,
          in_axes=(None, None, 0),
          donate_argnums=2,
          axis_name='batch',
      )

    image_dir = os.path.join(eval_folder, 'eval_images')
    os.makedirs(image_dir, exist_ok=True)

    eval_traj = np.load(os.path.join(eval_folder, 'eval_traj.npy'))
    eval_img = []
    for i, elem in enumerate(eval_traj):
        rays = generate_ndc_rays(elem, dataset)

        pred_color, _, _ = models.render_image(
              functools.partial(render_eval_pfn, state.optimizer.target),
              rays,
              None,
              chunk=8192)
        print(f'{i}/{eval_traj.shape[0]}')
        plt.imsave(os.path.join(image_dir, f'{i}.png'), np.clip(pred_color, 0, 1))
        eval_img.append(pred_color)
    return eval_img

In [None]:
# Render camera 0 and save a point cloud of (pixel column, pixel row, depth).
IMG_H, IMG_W = 320, 480  # matches the hwf override applied when loading poses

bottom = np.reshape([0, 0, 0, 1.], [1, 4])
pose = np.concatenate([poses[0, :3, :4], bottom], -2)

# img=0 was missing, which raised a TypeError; use the intrinsics of the same camera.
rays = generate_rays_original(pose, dataset, img=0)

# NOTE(review): render_eval_pfn from the obbpose setup cell reads a batch dict
# (rays/init/ext/ts/alpha), while models.render_image presumably passes plain
# ray chunks - confirm this pairing before relying on the output.
pred_color, pred_distance, pred_acc = models.render_image(
          functools.partial(render_eval_pfn, state.optimizer.target),
          rays,
          None,
          chunk=4096)

gt_depth = np.load('/home/tristram/data/waymo/seg1_5/depth_images.npz', allow_pickle=True)['arr_0']
print(gt_depth[4].shape)

# np.linspace(0, 320, num=320, dtype=int) skipped index 319 and produced 320;
# arange gives exactly one coordinate per pixel.
cols, rows = np.meshgrid(np.arange(IMG_W), np.arange(IMG_H))
depth = pred_distance  # swap in gt_depth[0] to export ground-truth depth instead
print(depth.shape)
pts = np.stack([cols, rows, np.asarray(depth).reshape(IMG_H, IMG_W)],
               axis=-1).reshape(-1, 3).astype(np.float64)

pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(pts))

# NOTE(review): the file name says gt_depth0 but it holds *predicted* depth.
o3d.io.write_point_cloud("/home/tristram/nerf_results/WAYMO1_5_dsen_3_200_lin_8x256_8x256_256s_norand_bias01_cone/gt_depth0.ply", pcd)

In [None]:
import jax.numpy as jnp

# 129 log-spaced values from 10**-0.633 up to (but excluding) 10**0 == 1.
space = jnp.logspace(start=-0.633333333, stop=0., num=128 + 1, endpoint=False)
print(space)

In [None]:
import matplotlib.pyplot as plt
print(pred_color.shape)
# NOTE(review): `batch` is not defined anywhere in this notebook - presumably a
# batch from the dataset (e.g. dataset.peek()); define it before running.
print(batch['pixels'].shape)
plt.imshow(batch['pixels'])
plt.show()  # was `plt.show` (never called)

In [None]:
plt.imshow(pred_color)
plt.show()  # was `plt.show` (never called)

In [None]:
# NOTE(review): this cell appears to come from a separate PyTorch NeRF codebase.
# `torch`, `get_rays`, `get_ndc_rays`, `inference`, `directions`, `h`, `w` and
# `focal` are not defined or imported anywhere in this notebook, so it raises
# NameError as-is. `traj` is whatever list/array the earlier trajectory cells left in the kernel.
for i in range(len(traj)):
    rays_o, rays_d = get_rays(directions, traj[i])
    near, far = 0, 1  # fixed near/far bounds (loop-invariant)
    rays_o, rays_d = get_ndc_rays(int(h), int(w), focal, 1.0, rays_o, rays_d)
    
    rays = torch.cat([rays_o, rays_d,
                       near*torch.ones_like(rays_o[:, :1]),
                       far*torch.ones_like(rays_o[:, :1])],1) # (h*w, 8)
    
    results = inference(rays.to('cuda'))
    # Clip the fine RGB prediction to [0, 1] and save it as frame i.
    img_pred = np.clip(results['rgb_fine'].view(h, w, 3).cpu().numpy(), 0, 1)
    print(i)
    plt.imsave('/home/tristram/Pictures/nerf_img/'+str(i)+'.png', img_pred)
    #plt.figure(2, figsize=(20,6))
    #plt.imshow(img_pred)
    #plt.show()
    #cv2.imwrite('/home/tristram/Pictures/nerf_img/'+str(i)+'.png', (img_pred*255).astype(np.int32))