In [5]:
import os
# Restrict the process to one physical GPU and make XLA allocate GPU memory
# on demand instead of reserving most of it at startup.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '7',
    'XLA_PYTHON_CLIENT_PREALLOCATE': 'false',
})

# With only GPU 7 visible, it is addressed as device index 0.
device = 'cuda:0'

In [6]:
import trimesh
import jax
from jax import jit, vmap, numpy as np
import os
import numpy as onp

In [7]:
def as_mesh(scene_or_mesh):
    """
    Collapse a trimesh Scene into a single Trimesh; pass a Trimesh through.

    A Scene with no geometry yields None. Otherwise every geometry in the
    scene is rebuilt from its raw vertices and faces only (texture and other
    visual data are dropped) and the pieces are concatenated into one mesh.
    Any input that is neither a Scene nor a Trimesh fails an assertion.

    NOTE(review): only each geometry's own vertices are used, so scene-graph
    node transforms are presumably not applied — confirm for posed scenes.
    """
    if not isinstance(scene_or_mesh, trimesh.Scene):
        assert isinstance(scene_or_mesh, trimesh.Trimesh)
        return scene_or_mesh

    geoms = scene_or_mesh.geometry
    if len(geoms) == 0:
        # empty scene
        return None

    # Rebuild from vertex/face data only; texture information is lost here.
    parts = tuple(trimesh.Trimesh(vertices=g.vertices, faces=g.faces)
                  for g in geoms.values())
    return trimesh.util.concatenate(parts)


def recenter_mesh(mesh):
    """Normalize `mesh` in place so its vertices lie in the unit cube [0, 1]^3.

    Vertices are shifted so their mean is at the origin, scaled uniformly by
    the largest absolute coordinate (landing in [-1, 1]^3), then mapped to
    [0, 1]^3 via .5 * (v + 1). The scaling is uniform, so the aspect ratio is
    preserved and at least one vertex coordinate reaches 0 or 1.

    Args:
      mesh: object with an assignable ``vertices`` array of shape (n, 3)
        (e.g. a trimesh.Trimesh). Modified in place; returns None.
    """
    # Work in host numpy: the vertices are a numpy array, and routing the
    # scale through jax.numpy round-trips it via the device as float32.
    v = onp.asarray(mesh.vertices, dtype=float)
    v = v - v.mean(0)
    scale = onp.max(onp.abs(v))
    # Guard degenerate meshes (all vertices identical) against 0/0 -> NaN.
    if scale > 0:
        v = v / scale
    # Single assignment so trimesh's setter invalidates its caches once.
    mesh.vertices = .5 * (v + 1.)

@jit
def make_normals(rays, depth_map):
  """Per-pixel unit normals of the surface hit by a grid of rays.

  Surface points are origin + depth * direction. Tangents are finite
  differences against the next pixel along axis 0 and along axis 1; np.roll
  wraps around, so the last row/column is differenced against the first and
  its normals are unreliable. The cross product of the two tangents is
  normalized, with the length clamped at 1e-5 to avoid division by zero.
  """
  origins, directions = rays
  surface = origins + directions * depth_map[..., None]
  d_axis0 = surface - np.roll(surface, -1, axis=0)
  d_axis1 = surface - np.roll(surface, -1, axis=1)
  n = np.cross(d_axis0, d_axis1)
  length = np.linalg.norm(n, axis=-1, keepdims=True)
  return n / np.maximum(length, 1e-5)


def render_mesh_normals(mesh, rays):
  """Ray-trace `mesh` and return the face normal of the first hit per ray.

  `rays` is reshaped to (2, N, 3) as [origins, directions]. Origins are mapped
  with o * .5 + .5 and directions scaled by .5 (world [-1, 1] to the [0, 1]
  frame produced by recenter_mesh). Rays with no hit (face index -1) get a
  zero vector. The result is reshaped to rays.shape[1:].
  """
  flat = rays.reshape([2, -1, 3])
  o_mesh = flat[0] * .5 + .5
  d_mesh = flat[1] * .5
  face_idx = mesh.ray.intersects_first(o_mesh, d_mesh)
  hit = face_idx != -1
  out = onp.zeros([o_mesh.shape[0], 3])
  out[hit] = mesh.face_normals[face_idx[hit]]
  return np.reshape(out, rays.shape[1:])

def uniform_bary(u):
    """Map uniform samples in [0, 1)^2 to barycentric weights uniform over a triangle.

    Square-root warp: with s = sqrt(u0) the weights are (1 - s, u1 * s,
    remainder), which are non-negative and sum to one.

    Args:
      u: array whose last axis holds (u0, u1).
    Returns:
      array of shape u.shape[:-1] + (3,).
    """
    s = np.sqrt(u[..., 0])
    w0 = 1. - s
    w1 = u[..., 1] * s
    w2 = 1. - w0 - w1
    return np.stack([w0, w1, w2], -1)

def get_normal_batch(mesh, bsize):
    """Sample `bsize` random surface points together with their face normals.

    A face is chosen uniformly *by index* (not weighted by area, so small
    triangles are over-represented relative to area-uniform sampling), and a
    point is drawn uniformly inside it using `uniform_bary`.

    Args:
      mesh: trimesh-like object with `faces`, `vertices` and `face_normals`.
      bsize: number of samples.
    Returns:
      (batch_pts, batch_normals): host numpy arrays of shape (bsize, 3).
    """
    # Stay in host numpy: the mesh arrays are numpy, so wrapping the indices
    # and barycentrics in jax arrays only forces device round trips and yields
    # a jax point array next to a numpy normal array.
    # Keep randint before uniform so seeded runs draw the same random stream.
    face_inds = onp.random.randint(0, mesh.faces.shape[0], [bsize])
    barys = onp.asarray(uniform_bary(onp.random.uniform(size=[bsize, 2])))
    tri_verts = mesh.vertices[mesh.faces[face_inds]]  # (bsize, 3 corners, 3)
    batch_pts = onp.sum(tri_verts * barys[..., None], 1)
    batch_normals = mesh.face_normals[face_inds]
    return batch_pts, batch_normals

def make_test_pts(mesh, corners, test_size=2**18, noise_std=.01):
  """Build two evaluation point sets of shape (test_size, 3) each.

  Args:
    mesh: mesh passed to get_normal_batch for surface samples.
    corners: (c0, c1) opposite corners of the axis-aligned sampling box.
    test_size: number of points per set.
    noise_std: standard deviation of the isotropic Gaussian offset applied to
      the surface samples (default .01, the previously hard-coded value).
  Returns:
    (test_easy, test_hard): test_easy is uniform in the box; test_hard is
    surface samples perturbed by Gaussian noise, i.e. points near the surface.
  """
  c0, c1 = corners
  test_easy = onp.random.uniform(size=[test_size, 3]) * (c1-c0) + c0
  surface_pts, _ = get_normal_batch(mesh, test_size)
  test_hard = surface_pts + onp.random.normal(size=[test_size, 3]) * noise_std
  return test_easy, test_hard


def load_mesh(mesh_name, logdir, verbose=True):
    """Load a mesh, normalize it into [0, 1]^3, and load or create cached test points.

    Args:
      mesh_name: path to a mesh file readable by ``trimesh.load``.
      logdir: directory holding the cached test-point file (created if missing).
      verbose: print diagnostic information.

    Returns:
      (mesh, corners, test_pts):
        mesh: the recentered mesh.
        corners: [c0, c1], the vertex bounding box padded by 1e-3 per side.
        test_pts: jax array built from two independent make_test_pts draws,
          shape (2, 2, test_size, 3) when freshly generated.
    """
    mesh = as_mesh(trimesh.load(mesh_name))
    if verbose:
        print(mesh.vertices.shape)
    recenter_mesh(mesh)

    # Axis-aligned bounding box, padded slightly so surface points lie strictly inside.
    c0, c1 = mesh.vertices.min(0) - 1e-3, mesh.vertices.max(0) + 1e-3
    corners = [c0, c1]
    if verbose:
        print(c0, c1)
        print(c1-c0)
        print(np.prod(c1-c0))          # bounding-box volume
        print(.5 * (c0+c1) * 2 - 1)    # box center mapped back to [-1, 1]

    # Key the cache on the file name only: os.path.join(logdir, mesh_name)
    # discards logdir when mesh_name is absolute, and points into
    # non-existent sub-directories when mesh_name contains directories.
    if logdir:
        os.makedirs(logdir, exist_ok=True)
    test_pt_file = os.path.join(logdir, os.path.basename(mesh_name) + '_test_pts.npy')
    if not os.path.exists(test_pt_file):
        if verbose: print('regen pts')
        test_pts = np.array([make_test_pts(mesh, corners), make_test_pts(mesh, corners)])
        onp.save(test_pt_file, onp.asarray(test_pts))
    else:
        if verbose: print('load pts')
        test_pts = np.asarray(onp.load(test_pt_file))

    if verbose: print(test_pts.shape)

    return mesh, corners, test_pts



def trans_t(t):
    """4x4 float32 homogeneous transform translating by t along +z."""
    return np.array([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, t],
        [0, 0, 0, 1],
    ], dtype=np.float32)


def rot_phi(phi):
    """4x4 float32 rotation by phi radians about the x axis."""
    c, s = np.cos(phi), np.sin(phi)
    return np.array([
        [1, 0, 0, 0],
        [0, c, -s, 0],
        [0, s, c, 0],
        [0, 0, 0, 1],
    ], dtype=np.float32)


def rot_theta(th):
    """4x4 float32 rotation about the y axis; note entry (0, 2) is -sin(th)."""
    c, s = np.cos(th), np.sin(th)
    return np.array([
        [c, 0, -s, 0],
        [0, 1, 0, 0],
        [s, 0, c, 0],
        [0, 0, 0, 1],
    ], dtype=np.float32)


def pose_spherical(theta, phi, radius):
    """Camera-to-world matrix for a camera on a sphere of the given radius.

    The camera is pushed out `radius` along +z, then rotated by `phi` degrees
    about x, then by `theta` degrees about y.
    """
    tilted = rot_phi(phi/180.*np.pi) @ trans_t(radius)
    # Disabled alternative: left-multiply by
    # [[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]] to negate x and swap y/z.
    return rot_theta(theta/180.*np.pi) @ tilted



def get_rays(H, W, focal, c2w):
    """Pinhole-camera rays for an H x W image.

    Pixel (row j, column i) looks along camera-space direction
    ((i - W/2) / focal, -(j - H/2) / focal, -1). Directions are rotated into
    world space by c2w[:3, :3] (not normalized), and every ray starts at
    c2w[:3, -1].

    Returns:
      array of shape (2, H, W, 3) stacking [origins, directions].
    """
    cols = np.broadcast_to(np.arange(W)[None, :], (H, W))
    rows = np.broadcast_to(np.arange(H)[:, None], (H, W))
    cam_dirs = np.stack([(cols - W * .5) / focal,
                         -(rows - H * .5) / focal,
                         -np.ones_like(cols)], axis=-1)
    rot = c2w[:3, :3]
    world_dirs = np.sum(cam_dirs[..., None, :] * rot, axis=-1)
    world_origins = np.broadcast_to(c2w[:3, -1], world_dirs.shape)
    return np.stack([world_origins, world_dirs], axis=0)

# H and W fix the output shape and must be static; focal is static too, so
# each distinct (H, W, focal) triple compiles separately.
get_rays = jit(get_rays, static_argnums=(0, 1, 2,))



In [8]:
# Camera on a sphere of radius R looking at the origin; near/far = R -/+ 1.
R = 2.
c2w = pose_spherical(90. + 10 + 45, -30., R)


def _render_spec(H, W, focal, N_samples, N_samples_2):
    """Rays for an H x W view from c2w, plus the positional render-args list.

    List order: (rays, corners, near, far, N_samples, N_samples_2, clip).
    corners is a None placeholder; NOTE(review): presumably replaced with the
    `corners` returned by load_mesh before rendering with clip=True.
    """
    rays = get_rays(H, W, focal, c2w[:3,:4])
    return rays, [rays, None, R-1, R+1, N_samples, N_samples_2, True]


# low res rendering specification
N_samples = 64
N_samples_2 = 64
H = 180
W = H
focal = H * .9
rays, render_args_lr = _render_spec(H, W, focal, N_samples, N_samples_2)

# high res rendering specification; these are the values the module-level
# H, W, focal, rays and N_samples* names keep after this cell.
N_samples = 256
N_samples_2 = 256
H = 512
W = H
focal = H * .9
rays, render_args_hr = _render_spec(H, W, focal, N_samples, N_samples_2)

In [None]:
def _march_occupancy(params, ab, rays_o, rays_d, z_vals, corners, clip, th):
    """Query occupancy at depths z_vals along each ray; return (depth, acc) of the first hit.

    Points are rays_o + z * rays_d. z_vals is either one shared 1-D vector or
    a per-ray (..., n) array. The network is queried at .5 * (p + 1). When clip
    is set, samples outside the box (c0, c1) in that frame are forced empty.
    Occupancy is binarized at `th`; depth is ~z of the first occupied sample
    (0 if none) and acc is ~1 for a hit, 0 for a miss.
    """
    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
    pts01 = .5 * (pts + 1)
    # NOTE(review): assumes apply_fn returns a trailing singleton channel that
    # np.squeeze removes — confirm against the network definition.
    alpha = jax.nn.sigmoid(np.squeeze(apply_fn(params, input_encoder(pts01, *ab))))
    if clip:
        c0, c1 = (np.asarray(c) for c in corners)
        outside = np.logical_or(np.any(pts01 < c0, -1), np.any(pts01 > c1, -1))
        alpha = np.where(outside, 0., alpha)

    alpha = np.where(alpha > th, 1., 0)

    # Exclusive cumulative transmittance: weight ~1 on the first occupied sample.
    trans = 1. - alpha + 1e-10
    trans = np.concatenate([np.ones_like(trans[..., :1]), trans[..., :-1]], -1)
    weights = alpha * np.cumprod(trans, -1)
    return np.sum(weights * z_vals, -1), np.sum(weights, -1)


def render_rays_native_hier(params, ab, rays, corners, near, far, N_samples, N_samples_2, clip,
                            th=.5, refine_width=.01):
    """Render depth / accumulation maps of the network's thresholded iso-surface.

    Pass 1 takes N_samples uniform depths in [near, far] to find a coarse first
    hit per ray. Pass 2 re-locates it with N_samples_2 depths in
    coarse_depth +/- refine_width.

    Args:
      params: network parameters passed to apply_fn (defined elsewhere).
      ab: extra arguments unpacked into input_encoder (defined elsewhere).
      rays: stacked [origins, directions], e.g. the output of get_rays.
      corners: (c0, c1) clip box in the .5 * (p + 1) frame; only read when clip.
      near, far: depth range of the coarse pass.
      N_samples, N_samples_2: sample counts of the coarse / refine passes.
      clip: zero out occupancy outside the corners box.
      th: occupancy threshold (default .5).
      refine_width: half-width of the refine window (default .01).
    Returns:
      (depth_map, acc_map), presumably of shape rays.shape[1:-1].
    """
    rays_o, rays_d = rays[0], rays[1]

    z_coarse = np.linspace(near, far, N_samples)
    depth_map, _ = _march_occupancy(params, ab, rays_o, rays_d, z_coarse, corners, clip, th)

    # Second pass to refine isosurface
    z_fine = np.linspace(-1., 1., N_samples_2) * refine_width + depth_map[..., None]
    return _march_occupancy(params, ab, rays_o, rays_d, z_fine, corners, clip, th)

# corners (index 3) is traced rather than static: it is a list of arrays,
# which is not hashable, and jit requires static arguments to be hashable.
# near/far/N_samples/N_samples_2/clip stay static (linspace needs static counts).
render_rays = jit(render_rays_native_hier, static_argnums=(4,5,6,7,8))

def gt_fn(queries, mesh):
    """Ground-truth occupancy: True where a query point lies inside `mesh`.

    queries has shape (..., 3). It is flattened to (N, 3) for
    mesh.ray.contains_points, and the result is reshaped to queries.shape[:-1].
    """
    flat = queries.reshape([-1, 3])
    inside = mesh.ray.contains_points(flat)
    return inside.reshape(queries.shape[:-1])