In [None]:
import os
import sys

# Make the parent directory importable (presumably where the local `sdrf` and
# `util` modules imported below live). The membership check keeps re-running
# this cell from appending duplicate entries.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
%env JAX_DEBUG_NANS=False

from collections import namedtuple

import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, vmap, grad, value_and_grad
from jax.experimental.optimizers import adam, clip_grads

import matplotlib.pyplot as plt

from sdrf import (
    render,
    render_img,
    run_one_iter_of_sdrf,
    eikonal_loss,
    manifold_loss,
    SDRF,
    SDRFParams
)
from util import get_ray_bundle, look_at

In [None]:
def analytic_geometry(x, params):
    """Signed distance from ``x`` to a sphere.

    Args:
        x: query point. NOTE(review): presumably a single 3-D point — the
            squared offsets are summed over *every* element, so a batch of
            points would collapse into one value.
        params: ``(origin, radius)`` tuple describing the sphere.

    Returns:
        ``||x - origin||_2 - radius``: negative inside the sphere, zero on its
        surface, positive outside. When ``radius`` is an array (e.g. shape
        ``(1,)``) the result takes that shape.
    """
    centre, r = params
    squared_dist = ((x - centre) ** 2).sum()
    # Square root written as `** (1 / 2)` on purpose, to keep the exact
    # arithmetic (and gradient) of the original formulation.
    return squared_dist ** (1 / 2) - r

In [None]:
def analytic_appearance(pt, rd, params):
    """Constant (flat white) appearance model.

    Every query gets the same colour. The point, the ray direction and the
    appearance parameters are all ignored, so no shading depends on the
    geometry.

    Args:
        pt: query/surface point (unused).
        rd: ray direction (unused).
        params: appearance parameters (unused). In this notebook they are the
            same ``(origin, radius)`` tuple passed as geometry parameters.

    Returns:
        ``jnp.array([1.0, 1.0, 1.0])`` — an RGB triple of shape ``(3,)``.
    """
    return jnp.array([1.0, 1.0, 1.0])

In [None]:
# Immutable option records. They are bundled into an SDRFOptions and passed
# to `run_one_iter_of_sdrf` by the experiment below.

# Rendering settings; `phi` holds a PhiOptions record.
RenderOptions = namedtuple(
    "RenderOptions",
    "phi num_samples truncation_distance additive chunksize debug",
)
# Options stored in RenderOptions.phi.
PhiOptions = namedtuple("PhiOptions", "initial_sigma lr_decay lr_decay_factor")
# Top-level bundle: eikonal/manifold auxiliary-loss settings, sampler, render.
SDRFOptions = namedtuple("SDRFOptions", "eikonal manifold sampler render")
# Scale and sample count for one auxiliary (eikonal or manifold) loss.
AuxLossOptions = namedtuple("AuxLossOptions", "scale num_samples")
# Sampler choice: `kind` names the sampler, `gaussian` carries its settings.
# NOTE(review): there is no `linear` field, so a LinearSamplerOptions record
# cannot be attached without redefining this tuple.
SamplerOptions = namedtuple("SamplerOptions", "kind gaussian")
GaussianSamplerOptions = namedtuple("GaussianSamplerOptions", "sigma")
LinearSamplerOptions = namedtuple("LinearSamplerOptions", "support")

In [None]:
%matplotlib inline
def _show_image(img, height, width, title):
    """Show a flat RGB buffer as a ``(height, width, 3)`` image with a title."""
    plt.imshow(np.asarray(img).reshape(height, width, 3))
    plt.title(title)
    plt.show()


def _show_debug_samples(debug, height, width, num_samples, title_prefix):
    """Show each of the ``num_samples`` per-pixel debug channels in grayscale."""
    # Assumes `debug` holds `num_samples` values per ray — TODO confirm
    # against sdrf.render.
    channels = np.asarray(debug).reshape(height, width, num_samples)
    for j in range(num_samples):
        plt.imshow(channels[:, :, j], cmap="gray")
        plt.title(f"{title_prefix}: debug channel {j}")
        plt.show()


def optimization_experiment(
    appearance_fn,
    geometry_fn,
    num_iterations=2500,
    plot_every=100,
    learning_rate=1e-3,
    seed=42,
):
    """Fit SDRF parameters to a rendered target image and plot progress.

    A target image is rendered with parameters ``(origin=0, radius=2)``. The
    model starts from ``(origin=0, radius=1)`` and is optimized with Adam on a
    weighted sum of an RGB reconstruction loss, the eikonal loss and the
    manifold loss from ``sdrf``.

    Args:
        appearance_fn: appearance function given to ``SDRF``
            (e.g. ``analytic_appearance(pt, rd, params)``).
        geometry_fn: signed-distance function given to ``SDRF``
            (e.g. ``analytic_geometry(x, params)``).
        num_iterations: number of optimizer steps.
        plot_every: show the current RGB render and debug channels every this
            many iterations; ``0`` or ``None`` disables in-loop plotting.
        learning_rate: constant Adam step size.
        seed: seed for ``jax.random.PRNGKey``.

    Returns:
        None. Progress is reported through matplotlib figures and printed
        losses.
    """
    rng = jax.random.PRNGKey(seed)

    # Initial (origin, radius) as consumed by analytic_geometry. The
    # appearance parameters mirror the geometry parameters.
    origins, radius = jnp.zeros((3,)), 1.0

    sdrf = SDRF(geometry=geometry_fn, appearance=appearance_fn)
    sdrf_params = SDRFParams(geometry=(origins, radius), appearance=(origins, radius))

    init_adam, update, get_params = adam(lambda iteration: learning_rate)
    optimizer_state = init_adam(sdrf_params)

    # Presumably look_at(eye, target, up): eye at (-4, 0, 0), looking at the
    # origin with +y up; its inverse is used as the view matrix.
    # NOTE(review): confirm against util.look_at.
    view_matrix = jnp.array(
        np.linalg.inv(
            np.array(
                look_at(
                    jnp.array([-4.0, 0.0, 0.0]),
                    jnp.array([0.0, 0.0, 0.0]),
                    jnp.array([0.0, 1.0, 0.0]),
                )
            )
        )
    )

    height, width = 256, 256
    # 2**14 rays per chunk; divides height * width = 65536 evenly.
    chunk_size = 2 ** 14
    ro, rd = get_ray_bundle(height, width, 100.0, view_matrix)

    phi_options = PhiOptions(initial_sigma=1e-1, lr_decay_factor=0.1, lr_decay=250)
    render_options = RenderOptions(
        phi=phi_options,
        num_samples=8,
        truncation_distance=1.0,
        additive=False,
        chunksize=chunk_size,
        debug=True,
    )
    sampler_options = SamplerOptions(kind="gaussian", gaussian=GaussianSamplerOptions(sigma=1e-2))
    options = SDRFOptions(
        eikonal=AuxLossOptions(scale=4.0, num_samples=32768),
        manifold=AuxLossOptions(scale=4.0, num_samples=32768),
        render=render_options,
        sampler=sampler_options,
    )

    def get_target_img(target_params, subrng):
        """Render the RGB target image for `target_params` (iteration index 1)."""
        target_s, _, _ = run_one_iter_of_sdrf(sdrf, target_params, ro, rd, 1, options, subrng)
        return target_s

    rng, subrng = jax.random.split(rng)
    target_geometry_params = (origins, jnp.array([2.0]))
    target_s = get_target_img(
        SDRFParams(geometry=target_geometry_params, appearance=target_geometry_params),
        subrng,
    )
    _show_image(target_s, height, width, "Target")

    def loss_fn(params, iteration, subrng):
        """Weighted RGB + eikonal + manifold loss; aux = (losses, rgb, depth, debug)."""
        rgb, depth, debug = run_one_iter_of_sdrf(sdrf, params, ro, rd, iteration, options, subrng[0])

        # NOTE(review): jax.random.uniform defaults to [0, 1), so these samples
        # only cover the positive octant [0, scale)^3. If a box around the
        # origin was intended, pass minval=-scale, maxval=scale instead.
        eikonal_samples = (
            jax.random.uniform(subrng[1], (options.eikonal.num_samples, 3))
            * options.eikonal.scale
        )
        manifold_samples = (
            jax.random.uniform(subrng[2], (options.manifold.num_samples, 3))
            * options.manifold.scale
        )

        rgb_loss = jnp.mean((target_s - rgb) ** 2.0)
        e_loss = eikonal_loss(sdrf.geometry, eikonal_samples, params.geometry)
        m_loss = manifold_loss(sdrf.geometry, manifold_samples, params.geometry).sum()

        losses = jnp.array([rgb_loss, e_loss, m_loss])
        loss_weights = jnp.array([3e3, 5e1, 1e2])

        return jnp.dot(losses, loss_weights), (losses, rgb, depth, debug)

    value_and_grad_fn = jit(value_and_grad(loss_fn, argnums=(0,), has_aux=True))

    for i in range(num_iterations):
        # 4 subkeys are drawn; loss_fn consumes the first three.
        rng, *subrng = jax.random.split(rng, 5)
        params = get_params(optimizer_state)

        (loss, (losses, rgb, depth, debug)), (grads,) = value_and_grad_fn(params, i, subrng)
        grads = clip_grads(grads, 1.0)
        optimizer_state = update(i, grads, optimizer_state)

        if plot_every and i % plot_every == 0:
            _show_image(rgb, height, width, f"Iteration {i}: rendered RGB")
            _show_debug_samples(
                debug, height, width, options.render.num_samples, f"Iteration {i}"
            )

        # `params` is the pre-update value used to compute this step's loss.
        print(f"Iteration {i}: Loss: {loss}\n\tRGB loss: {losses[0]}\n\tEikonal loss: {losses[1]}\n\tManifold loss: {losses[2]}")
        print(f"\tRadius: {params.geometry[1]}\n")
    
# Run the experiment with the analytic sphere SDF and the constant appearance.
optimization_experiment(
    appearance_fn=analytic_appearance,
    geometry_fn=analytic_geometry,
)