In [None]:
import os
import sys

# Put the parent directory on the import path; presumably that is where the
# local ``sdrf`` and ``util`` modules imported below live.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
%env XLA_PYTHON_CLIENT_ALLOCATOR=platform
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, vmap, grad
from jax.experimental.optimizers import adam, clip_grads
import haiku as hk
import seaborn as sns
import matplotlib.pyplot as plt
sns.set(font_scale=0.5)

from sdrf import (
    Siren, IGR
)
from util import get_ray_bundle, look_at

In [None]:
def create_sphere(pt, origin=jnp.array([0.0, 0.0, 0.0]), radius=0.5):
    """Signed distance from ``pt`` to a sphere of ``radius`` centred at ``origin``.

    Negative inside, zero on the surface, positive outside.

    The Euclidean norm is taken over the last axis, so ``pt`` may be a single
    point of shape ``(3,)`` (returns a scalar) or a batch of shape ``(..., 3)``
    (returns shape ``(...)``). Without ``axis=-1``, ``ord=2`` on a 2-D input
    would compute a matrix spectral norm rather than per-point distances.
    """
    return jnp.linalg.norm(pt - origin, ord=2, axis=-1) - radius

def get_on_surface_points(num_pts, radius=0.5):
    """Return ``num_pts`` points spread over a sphere of ``radius`` centred at the origin.

    Uses a Fibonacci (golden-angle) lattice. Heights ``z`` are evenly spaced
    in (-1, 1), which gives equal-area bands on a sphere, and each successive
    point is rotated by the golden angle. The result is deterministic,
    near-uniform coverage of the whole surface. Pairing two linspaces
    element-wise for the polar and azimuthal angles would instead place every
    point on a single pole-to-pole curve.

    Returns:
        Array of shape ``(num_pts, 3)``.
    """
    idx = jnp.arange(num_pts) + 0.5
    z = 1.0 - 2.0 * idx / num_pts
    # Radius of the horizontal circle at height z; clamp guards float round-off.
    ring_radius = jnp.sqrt(jnp.maximum(1.0 - z * z, 0.0))
    golden_angle = jnp.pi * (3.0 - jnp.sqrt(5.0))
    theta = golden_angle * idx
    return radius * jnp.stack(
        (ring_radius * jnp.cos(theta), ring_radius * jnp.sin(theta), z), axis=-1
    )

In [None]:
%matplotlib inline

num_epochs = 1000
validation_skip = 100
batch_size = 2 ** 14
#batch_size = 2 ** 5
#batch_size = 1

rng = jax.random.PRNGKey(42)
model_fn = hk.transform(lambda x: Siren(3, 1, 4, 256, True)(x))
#model_fn = hk.transform(lambda x: IGR([32, 32, 32, 32, 32, 32])(x))
params = model_fn.init(rng, jnp.ones([3,]))
model_fn = hk.without_apply_rng(model_fn)

optimizer = init_adam, update, get_params = adam(lambda _: 1e-4)

optimizer_state = init_adam((params))

model_jit_fn = jit(lambda params, pts: vmap(lambda pt: model_fn.apply(params, pt)[0])(pts))
grad_model_fn = grad(lambda params, pt: model_fn.apply(params, pt)[0], argnums=(1,))

def compute_loss(params, pts):
    """Weighted SDF training loss against the analytic sphere.

    Per point ``x`` with network output ``f`` and true SDF ``d = create_sphere(x)``:

    * reconstruction: ``(f - d)^2``;
    * eikonal: ``(1 - |grad_x f|)^2``, pushing the spatial gradient to unit norm;
    * inter: ``exp(-100 |f|)``, penalising near-zero outputs away from the
      surface. It is applied only where ``|d| > on_surface_eps``. On
      on-surface samples the reconstruction target is exactly 0, so the
      penalty would fight it there.

    Each term is averaged over the batch; masked points contribute 0 to the
    inter term's mean.

    Returns:
        ``(total, losses)`` where ``total = 3e3*recon + 5e1*eikonal + 1e2*inter``
        and ``losses`` is the array of the three unweighted batch means. The
        second element is the auxiliary output for ``grad(..., has_aux=True)``.
    """
    # On-surface samples have |d| at float32 round-off level (~1e-7). Anything
    # below this threshold is treated as lying on the surface.
    on_surface_eps = 1e-5

    def loss_fn(pt):
        dist = create_sphere(pt)
        model_output = model_fn.apply(params, pt)

        # grad_model_fn uses argnums=(1,), so it returns a 1-tuple.
        grad_output = grad_model_fn(params, pt)[0]

        reconstruction_loss = ((model_output - dist) ** 2).sum()
        eikonal_loss = (1.0 - jnp.linalg.norm(grad_output)) ** 2.0
        off_surface = jnp.abs(dist) > on_surface_eps
        inter_loss = jnp.where(
            off_surface, jnp.exp(-1e2 * jnp.abs(model_output)).sum(), 0.0
        )

        return jnp.array([reconstruction_loss, eikonal_loss, inter_loss])

    losses = jnp.mean(vmap(loss_fn)(pts), axis=0)
    return losses[0] * 3e3 + losses[1] * 5e1 + losses[2] * 1e2, losses

# Jitted loss: returns (total, per-term losses).
value_fn = jax.jit(compute_loss)
# Despite the name, this is grad only, not value_and_grad. With has_aux=True it
# returns (grads, losses), where grads is a 1-tuple because argnums=(0,).
value_and_grad_loss_fn = jax.jit(jax.grad(compute_loss, argnums=(0,), has_aux=True))
# Compact numeric printing for the logged loss arrays.
jnp.set_printoptions(suppress=True, precision=4)

def generate_samples(rng, num_pts=batch_size):
    """Draw a training batch of exactly ``num_pts`` 3-D points.

    The first ``num_pts // 2`` points lie on the sphere surface
    (``get_on_surface_points``). The remainder are drawn uniformly from the
    cube [-1, 1]^3 using ``rng``. For odd ``num_pts`` the extra point goes to
    the off-surface half, so the batch always has ``num_pts`` rows.

    Returns:
        Array of shape ``(num_pts, 3)``, on-surface points first.
    """
    num_on_surface = num_pts // 2
    num_off_surface = num_pts - num_on_surface
    off_surface_pts = jax.random.uniform(
        rng, (num_off_surface, 3), minval=-1.0, maxval=1.0
    )
    on_surface_pts = get_on_surface_points(num_on_surface)
    return jnp.concatenate((on_surface_pts, off_surface_pts), axis=0)

# Validation slice: a val_res x val_res grid over [-1, 1]^2 in the z = 0 plane,
# i.e. through the sphere's centre. An even-sized linspace never contains 0,
# so the plane is built explicitly rather than sliced out of a 3-D meshgrid.
# Loop-invariant, so it is built once.
val_res = 8
log_every = 10  # print the loss terms every this many epochs

val_axis = jnp.linspace(-1.0, 1.0, val_res)
val_x, val_y = jnp.meshgrid(val_axis, val_axis)
grid = jnp.stack((val_x, val_y, jnp.zeros_like(val_x)), axis=-1)

for epoch in range(num_epochs):
    params = get_params(optimizer_state)

    rng, subrng = jax.random.split(rng)
    pts = generate_samples(subrng)

    # value_and_grad_loss_fn returns (grads, aux). grads is a 1-tuple of
    # parameter gradients; aux holds the three unweighted loss terms.
    gradient, losses = value_and_grad_loss_fn(params, pts)
    gradient = clip_grads(gradient, 1.0)

    if epoch % log_every == 0 or epoch == num_epochs - 1:
        recon, eikonal, inter = np.asarray(losses)
        print(f"epoch {epoch}: recon {recon:.4f}  eikonal {eikonal:.4f}  inter {inter:.4f}")

    if epoch % validation_skip == 0:
        dists = np.asarray(
            model_jit_fn(params, grid.reshape(-1, 3)).reshape(val_res, val_res)
        )

        heat_fig, heat_ax = plt.subplots()
        sns.heatmap(
            dists,
            annot=True,
            fmt=".1f",
            vmin=-0.8,
            vmax=0.8,
            center=0,
            cmap="RdBu_r",
            ax=heat_ax,
        )
        heat_ax.set_aspect('equal')
        heat_ax.set_title(f"predicted SDF, z = 0 slice (epoch {epoch})")

        contour_fig, contour_ax = plt.subplots()
        contour_ax.set_aspect('equal')
        cs = contour_ax.contour(np.asarray(grid[:, :, 0]), np.asarray(grid[:, :, 1]), dists)
        contour_ax.set_title(f"SDF contours, z = 0 slice (epoch {epoch})")
        plt.show()
        # Release the figures so repeated validation plots don't accumulate.
        plt.close(heat_fig)
        plt.close(contour_fig)

    optimizer_state = update(epoch, gradient[0], optimizer_state)

In [None]:
from jax import jit, vmap, grad

from sdrf import (
    additive_integrator,
    sphere_trace_naive,
    sphere_trace,
    render,
    render_img,
    gaussian_pdf,
    ExponentialSampler,
    GaussianSampler,
    LinearSampler,
    StratifiedSampler,
)
from util import get_ray_bundle, look_at

# Render with the trained network weights.
params = get_params(optimizer_state)
# params = (jnp.array([0.0, 0.0, 0.0]), jnp.array([0.5]))  # analytic sphere: (origin, radius)

# Camera at (-2, 0, 0) looking at the origin with +y up. The look_at result is
# inverted before use. NOTE(review): presumably look_at yields world-to-camera
# and get_ray_bundle wants camera-to-world -- TODO confirm in util.
cam_eye = jnp.array([-2.0, 0.0, 0.0])
cam_target = jnp.array([0.0, 0.0, 0.0])
cam_up = jnp.array([0.0, 1.0, 0.0])
view_matrix = jnp.array(np.linalg.inv(np.array(look_at(cam_eye, cam_target, cam_up))))

# Image size and the chunk size later passed to render_img
# (32, 32, 8 gives a quick low-resolution preview).
height, width, chunk_size = 256, 256, 32
# NOTE(review): 300.0 is presumably the focal length in pixels -- TODO confirm.
ro, rd = get_ray_bundle(height, width, 300.0, view_matrix)

rng = jax.random.PRNGKey(42)

# Width of the importance-sampling distribution handed to the ray sampler.
importance_sigma = 1e-1
# Width passed to gaussian_pdf inside the density function ``phi``.
phi_sigma = 1e-1

# Sample count handed to the sampler and to ``render``.
num_samples = 8


def geometry(x, params):
    """Scene geometry queried by the renderer: the trained SDF network evaluated at ``x``."""
    return model_fn.apply(params, x)
# Analytic check instead: geometry = lambda x, params: create_sphere(x, *params)

# Shading model. For a flat, solid-white surface use instead:
# appearance = lambda pt, rd: jnp.array([1.0, 1.0, 1.0])

# Blinn-Phong lighting from a single point light with inverse-square falloff.
light_pos = jnp.array([-8.0, -4.0, 0.0])
diffuse_power = 30.0
specular_power = 50.0
specular_hardness = 16


def normalize(vec):
    """Scale ``vec`` to unit Euclidean length."""
    return vec / jnp.linalg.norm(vec, ord=2)


def normals(pt):
    """Unit surface normal at ``pt``: the normalised spatial gradient of the SDF.

    A trained network only approximately satisfies |grad f| = 1, while
    Lambertian and Blinn-Phong terms assume unit normals, so the raw gradient
    is rescaled. The norm is floored at 1e-8 to avoid NaNs where the gradient
    vanishes. Reads the module-level ``params`` at call time.
    """
    sdf_grad = grad(lambda p: geometry(p, params)[0])(pt)
    return sdf_grad / jnp.maximum(jnp.linalg.norm(sdf_grad, ord=2), 1e-8)


def distance(pt):
    """*Squared* Euclidean distance from ``pt`` to the light (for inverse-square falloff)."""
    return jnp.square(jnp.linalg.norm(light_pos - pt, ord=2))


def light_dir(pt):
    """Unit vector from ``pt`` towards the light."""
    return normalize(light_pos - pt)


def diffuse(pt):
    """Lambertian term clipped to [0, 1], broadcast to an RGB triple."""
    lambert = jnp.dot(light_dir(pt), normals(pt)) * diffuse_power / distance(pt)
    return jnp.broadcast_to(jnp.clip(lambert, 0.0, 1.0), (3,))


def h(pt, rd):
    """Blinn-Phong half vector between the light direction and the view direction ``-rd``."""
    return normalize(light_dir(pt) - rd)


def ndoth(pt, rd):
    """Clamped cosine between the surface normal and the half vector."""
    return jnp.clip(jnp.dot(normals(pt), h(pt, rd)), 0.0, 1.0)


def specular(pt, rd):
    """Blinn-Phong specular term (scalar, not clipped)."""
    return jnp.power(ndoth(pt, rd), specular_hardness) * specular_power / distance(pt)


def appearance(pt, rd):
    """RGB radiance at surface point ``pt`` seen along ray direction ``rd``."""
    return diffuse(pt) + specular(pt, rd)
def phi(dist):
    """Density as a function of signed distance ``dist``.

    Negative distances (inside the surface) are clamped to 0 before calling
    ``gaussian_pdf(..., 0.0, phi_sigma)``. NOTE(review): presumably that is a
    zero-mean Gaussian of width phi_sigma, so the interior gets the peak
    value -- TODO confirm gaussian_pdf's (x, mean, sigma) signature in sdrf.
    """
    clamped = jnp.maximum(dist, jnp.zeros_like(dist))
    return gaussian_pdf(clamped, 0.0, phi_sigma)


# Ray sampler. Alternatives from sdrf with the same constructor argument:
#   GaussianSampler, ExponentialSampler, StratifiedSampler
sampler = LinearSampler(importance_sigma)
# Quick sanity check: print what the sampler returns for num_samples samples.
print(sampler.sample(None, num_samples))

def render_fn(ro, rd, rng):
    """Render rays ``(ro, rd)`` with the scene configured above.

    Thin wrapper around ``render`` that fixes the sampler, geometry,
    appearance, params, density ``phi`` and ``num_samples``. The trailing
    positional ``True`` is passed through unchanged; its meaning is defined
    by ``render``.
    """
    return render(
        sampler, geometry, appearance,
        ro, rd, params, rng,
        phi, num_samples, True,
    )


# render_fn (argument 0) and chunk_size (argument 3) are static for jit: they
# are hashed rather than traced. Wrap in `with jax.disable_jit():` to debug eagerly.
(rgb, depth), rng = jit(render_img, static_argnums=(0, 3))(
    render_fn,
    rng,
    (ro, rd),
    chunk_size,
)


In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
plt.grid(b=None)
plt.imshow(np.array(rgb))
#plt.imshow(np.array(depth))