In [29]:
import numpy as np 
import meshplot as mp
import mitsuba as mi
import drjit as dr

In [30]:
mi.set_variant("cuda_ad_rgb")

In [31]:
scene_description = {
    'type': 'scene',
    'heightfield': {
        'type': 'heightfield',
        'filename': 'data/depth.bmp',
        'max_height': 1.0,
        'bsdf': {
            'type': 'diffuse',
            'reflectance': {
                'type': 'rgb',
                'value': [0.5, 0.5, 0.5]
            }
        }
    },
    'sensor' :{
        'type': 'perspective',
        'to_world': mi.ScalarTransform4f.look_at(
            #origin=[0, 0, 1], target=[0, 0, 0], up=[0, 1, 0]
            origin=[0, -5, 1], target=[0, 0, 0], up=[0, 0, 1]
            #origin=[0, 0, 1], target=[0, 0, 0], up=[0, 1, 0]
        ),
#         'near_clip': 10,
#         'far_clip':2800,
        'film': {
            'type': 'hdrfilm',
            'width': 500,
            'height': 500,
            'sample_border': True
        },
        'sampler': {
            'type': 'independent',
            'sample_count': 2048,
        }
    },
    'integrator': {
        'type': 'direct_reparam',
        'reparam_rays': 128,
        'reparam_antithetic': True,
        'reparam_kappa': 10**6,
    },
    
        
       'sphere_2': {
        'type': 'sphere',
        'center': [0, -7, 1],
        'radius': 1,
       'emitter': {
            'type': 'area',
            'radiance': {
                'type': 'rgb',
                'value': 1.0,
            }
        }
    }    
  
#     'emitter': {
#         'type': 'constant',
#     }
}

In [32]:
scene = mi.load_dict(scene_description)
params = mi.traverse(scene)

In [33]:
heightfield = {
    'res_x': params['heightfield.res_x'],
    'res_y': params['heightfield.res_y'],
    'heightfield': params['heightfield.heightfield'],
    'to_world': params['heightfield.to_world'],
    'max_height': params['heightfield.max_height'],
    'vertex_normals': params['heightfield.per_vertex_normals']
}

In [34]:
def query_normal(idx, heightfield):
    return dr.gather(mi.Normal3f, heightfield['vertex_normals'], idx)

In [35]:
"""
 Given 3 triangle vertices a, b and c in world-space, and point p_x, this function determines the barycentric
 UV coordinates (u,v,w) of point p_x, w.r.t (a,b,c)
"""
def compute_barycentric(a, b, c, p_x):
    v0 = b - a 
    v1 = c - a
    v2 = p_x - a

    d00 = dr.dot(v0, v0)
    d01 = dr.dot(v0, v1)
    d11 = dr.dot(v1, v1)
    d20 = dr.dot(v2, v0)
    d21 = dr.dot(v2, v1)
    denom = d00 * d11 - d01 * d01

    v = (d11 * d20 - d01 * d21) / denom
    w = (d00 * d21 - d01 * d20) / denom
    u = 1.0 - v - w

    return u,v,w

In [36]:
"""
    Given the UV coordinates of a point and the UV-space vertices of the 
    tile for which we are checking, this function returns the UV-space vertices of the triangle in which the point
    is located.
   0-------3
    | \pos|
    |  \  |    Outcome mapping
    |neg\ |
   1-------2
"""
def determine_containing_triangle(uv, tile_vertices, tile_vertices_world, tile_normals):
    # We can determine this in 2D, only x and y matters 
    # Sign of determinant: positive: right, negative: left 
    side = dr.sign((tile_vertices[2][0] - tile_vertices[0][0]) * (uv[1] - tile_vertices[0][1]) - (tile_vertices[2][1] - tile_vertices[0][1]) * (uv[0] - tile_vertices[0][0]))
    
    # UV-space vertices
    hit_tri_uv_0_0 = dr.select(side < 0, tile_vertices[0][0], tile_vertices[3][0])
    hit_tri_uv_0_1 = dr.select(side < 0, tile_vertices[0][1], tile_vertices[3][1])

    hit_tri_uv_1_0 = dr.select(side < 0, tile_vertices[1][0], tile_vertices[0][0])
    hit_tri_uv_1_1 = dr.select(side < 0, tile_vertices[1][1], tile_vertices[0][1])
    
    hit_tri_uv_2 = tile_vertices[2]
    
    # World-space vertices
    hit_tri_world_0_0 = dr.select(side < 0, tile_vertices_world[0][0], tile_vertices_world[3][0])
    hit_tri_world_0_1 = dr.select(side < 0, tile_vertices_world[0][1], tile_vertices_world[3][1])
    hit_tri_world_0_2 = dr.select(side < 0, tile_vertices_world[0][2], tile_vertices_world[3][2])

    hit_tri_world_1_0 = dr.select(side < 0, tile_vertices_world[1][0], tile_vertices_world[0][0])
    hit_tri_world_1_1 = dr.select(side < 0, tile_vertices_world[1][1], tile_vertices_world[0][1])
    hit_tri_world_1_2 = dr.select(side < 0, tile_vertices_world[1][2], tile_vertices_world[0][2])
    
    hit_tri_world_2 = tile_vertices_world[2]
    
    # Per-vertex normals
    hit_tri_normals_0_0 = dr.select(side < 0, tile_normals[0][0], tile_normals[3][0])
    hit_tri_normals_0_1 = dr.select(side < 0, tile_normals[0][1], tile_normals[3][1])
    hit_tri_normals_0_2 = dr.select(side < 0, tile_normals[0][2], tile_normals[3][2])

    hit_tri_normals_1_0 = dr.select(side < 0, tile_normals[1][0], tile_normals[0][0])
    hit_tri_normals_1_1 = dr.select(side < 0, tile_normals[1][1], tile_normals[0][1])
    hit_tri_normals_1_2 = dr.select(side < 0, tile_normals[1][2], tile_normals[0][2])
    
    hit_tri_normals_2 = tile_normals[2]
    
    
    return (mi.Vector2f(hit_tri_uv_0_0, hit_tri_uv_0_1), mi.Vector2f(hit_tri_uv_1_0, hit_tri_uv_1_1), mi.Vector2f(hit_tri_uv_2)), \
    (mi.Vector3f(hit_tri_world_0_0, hit_tri_world_0_1, hit_tri_world_0_2), mi.Vector3f(hit_tri_world_1_0, hit_tri_world_1_1, hit_tri_world_1_2), mi.Vector3f(hit_tri_world_2)), \
    (mi.Vector3f(hit_tri_normals_0_0, hit_tri_normals_0_1, hit_tri_normals_0_2), mi.Vector3f(hit_tri_normals_1_0, hit_tri_normals_1_1, hit_tri_normals_1_2), mi.Vector3f(hit_tri_normals_2))


In [37]:
"""
 Given a uv point coordinate (e.g. global heightfield UV of an intersection point), this function returns the
 vertices of the heightfield tile in which the point is located in UV-space, the tile vertices in world
 space, as well the tile normals. 
"""
def uv_to_tile(uv, res_x, res_y, heightfield):
    cell_size = (2.0 / (res_x - 1), 2.0 / (res_y - 1))
    cell_size_uv_space = (1.0 / (res_x - 1), 1.0 / (res_y - 1))
    amount_rows = res_x - 1
    amount_tiles_per_row = res_y - 1
    
    tile_x = mi.UInt32(dr.clamp(dr.floor(uv[0] / cell_size_uv_space[0]), 0, amount_tiles_per_row - 1))
    tile_y = mi.UInt32(dr.clamp(dr.floor((1 - uv[1]) / cell_size_uv_space[1]), 0 , amount_rows - 1))
    
    lt_idx = tile_y * res_x + tile_x
    rt_idx = tile_y * res_x + tile_x + 1
    lb_idx = (tile_y + 1) * res_x + tile_x
    rb_idx = (tile_y + 1) * res_x + tile_x + 1
    
    local_min_bounds = mi.Point2f(-1.0 + tile_x * cell_size[0], 1.0 - (tile_y + 1) * cell_size[1])
    local_max_bounds = mi.Point2f(-1.0 + (tile_x + 1) * cell_size[0], 1.0 - tile_y * cell_size[1])
    
    local_min_bounds_uv_space = mi.Point2f(tile_x * cell_size_uv_space[0], 1.0 - (tile_y + 1) * cell_size_uv_space[1])
    local_max_bounds_uv_space = mi.Point2f((tile_x + 1) * cell_size_uv_space[0], 1.0 - tile_y * cell_size_uv_space[1])
    
    v0 = (local_min_bounds_uv_space[0], local_max_bounds_uv_space[1])
    v1 = (local_min_bounds_uv_space[0], local_min_bounds_uv_space[1])
    v2 = (local_max_bounds_uv_space[0], local_min_bounds_uv_space[1])
    v3 = (local_max_bounds_uv_space[0], local_max_bounds_uv_space[1])

    v0_world = heightfield['to_world'].transform_affine(mi.Point3f(local_min_bounds[0], local_max_bounds[1], dr.gather(mi.Float, heightfield['heightfield'].array, lt_idx) * heightfield['max_height']))
    v1_world = heightfield['to_world'].transform_affine(mi.Point3f(local_min_bounds[0], local_min_bounds[1], dr.gather(mi.Float, heightfield['heightfield'].array, lb_idx) * heightfield['max_height']))
    v2_world = heightfield['to_world'].transform_affine(mi.Point3f(local_max_bounds[0], local_min_bounds[1], dr.gather(mi.Float, heightfield['heightfield'].array, rb_idx) * heightfield['max_height']))
    v3_world = heightfield['to_world'].transform_affine(mi.Point3f(local_max_bounds[0], local_max_bounds[1], dr.gather(mi.Float, heightfield['heightfield'].array, rt_idx) * heightfield['max_height']))
  
    n0 = query_normal(lt_idx, heightfield)
    n1 = query_normal(lb_idx, heightfield)
    n2 = query_normal(rb_idx, heightfield)
    n3 = query_normal(rt_idx, heightfield)

    # print(f"uv: {uv} tile_vs: {(v0, v1, v2, v3)}")

    #    - tile vertices (UV) -     ---- tile vertices (world) -----     -- tile normals --
    return (v0, v1, v2, v3),    (v0_world, v1_world, v2_world, v3_world), (n0, n1, n2, n3)

In [38]:
def project_uv_on_heightfield(uv, heightfield):
    tile_vertices, tile_vertices_world, tile_normals = uv_to_tile(uv, heightfield['res_x'], heightfield['res_y'], heightfield)
    tri_vertices, tri_vertices_world, tri_normals = determine_containing_triangle(uv, tile_vertices, tile_vertices_world, tile_normals)
    b0, b1, b2 = compute_barycentric(tri_vertices[0], tri_vertices[1], tri_vertices[2], uv)

    p_x = b0 * tri_vertices_world[0] + b1 * tri_vertices_world[1] + b2 * tri_vertices_world[2]
    n_x = b0 * tri_normals[0] + b1 * tri_normals[1] + b2 * tri_normals[2]
    return p_x, n_x

In [39]:
def normal_component_flipped_sign(n1, n2):
    result = mi.Mask(False)
    result |= ((dr.sign(n1[0]) != dr.sign(n2[0])) | (dr.sign(n1[1]) != dr.sign(n2[1])) | (dr.sign(n1[2]) != dr.sign(n2[2])))
    return result 

In [40]:
def compute_G(T0, d, t, heightfield):
    UV_p = T0 + t * d
    UV_p = dr.clamp(UV_p, mi.Point2f(0.0, 0.0), mi.Point2f(1.0, 1.0))
    P, _ = project_uv_on_heightfield(UV_p, heightfield) 
    return P.z

In [41]:
def compute_F(T0, d, t, view_O, heightfield):
    UV0 = T0 + t * d
    UV0 = dr.clamp(UV0, mi.Point2f(0.0, 0.0), mi.Point2f(1.0, 1.0))
    P, N = project_uv_on_heightfield(UV0, heightfield) 
    V = dr.normalize(view_O - P)
    F = dr.dot(N, V)
    return F

In [42]:
def F_predicate(T0, d, t, view_O, heightfield):
    val = mi.Mask(True)
    normal_view_dot = compute_F(T0, d, t, view_O, heightfield)
    val &= normal_view_dot >= 0
    return val

In [43]:
def G_grad(func, T0, d, t, heightfield):
    dr.enable_grad(t)
    val = func(T0, d, t, heightfield)
    dr.forward_from(t)
    grad = dr.grad(val)
    dr.disable_grad(t)
    return grad

In [44]:
def find_nearest_ridge(T0, d, heightfield, active, max_nr_steps = 5, step_size = 0.002):
    active_ridge_finder = mi.Mask(active)
    found_ridge = active & False
    i = mi.UInt32(0)

    gradient_prev = G_grad(compute_G, T0, d, i*step_size, heightfield)
    loop = mi.Loop("ridge_finder", lambda: (i,found_ridge, gradient_prev, active_ridge_finder))
    while loop(active_ridge_finder):
        uv_curr = T0 + (i * step_size * d)
        gradient_curr = G_grad(compute_G, T0, d, i * step_size, heightfield)

        found_ridge |= active_ridge_finder & (gradient_prev * gradient_curr < 0)
        gradient_prev = gradient_curr
        i += 1

        active_ridge_finder &= (~found_ridge) & (i < max_nr_steps)

    return (i-1) * step_size, found_ridge

In [45]:
# See mitsuba/core/math.h
def bisection(a, b, max_iter, accuracy, pred, T0, d, view_O, heightfield):
    active_bisection = mi.Mask(True)
    midpoint = (a + b) / 2

    i = mi.UInt32(0)
    loop = mi.Loop("bisection", lambda: (i, a, b, active_bisection, midpoint))
    while loop(active_bisection):
        predicate_val = pred(T0, d, midpoint, view_O, heightfield)
        a = dr.select(active_bisection & predicate_val, midpoint, a)
        b = dr.select(active_bisection & predicate_val, b, midpoint)
        midpoint = (a + b) / 2

        i += 1
        active_bisection &= i < max_iter & (a < midpoint) & (midpoint < b)

    found_F_value = compute_F(T0, d, a, view_O, heightfield)
    root_found = dr.abs(found_F_value) < accuracy

    return a, root_found

In [46]:
def get_colormap():
    import matplotlib
    cmap_base = matplotlib.colormaps['Spectral']
    N = 10
    t1 = 0.42
    cmap_new = cmap_base(np.arange(0, t1+(t1-0)/N, (t1-0)/N))
    cmap_new = np.concatenate((cmap_new, [[0.85, 0.85, 0.85, 1]]), axis=0)  # Gray
    t2 = 0.79;  # t2 = 0.85;
    cmap_new = np.concatenate((cmap_new, cmap_base(np.arange(t2, 1+(1-t2)/N, (1-t2)/N)) ), axis=0)  # Gray
    cmap_colors = np.flipud(cmap_new)
    return matplotlib.colors.LinearSegmentedColormap.from_list("custom_spectral", cmap_colors)

In [47]:
def visualize_found_roots(roots, vp=None):
    # Surface
    N_res = 900
    x, y = np.meshgrid(np.linspace(0, 1, N_res), np.linspace(0, 1, N_res))
    x = x.flatten()
    y = y.flatten()
    pts, normals = project_uv_on_heightfield(mi.Point2f(x,y), heightfield)
    
    # Strips: densely sampled points along x and y dimension in the unit square to form a grid
    # We have `N_strip` strips along the x-axis, each strip has y-coord [1/(N_strip+1), 2/(N_strip+1), ..., N_strip/(N_strip+1)]
    # We have `N_strip` strips along the y-axis, each strip has x-coord [1/(N_strip+1), 2/(N_strip+1), ..., N_strip/(N_strip+1)]
    # Each strip has `N_res * dense` points
    N_strip = 30
    dense = 5
    t = np.linspace(0, 1, N_res * dense)
    M = N_strip * N_res * dense
    x = np.zeros((2 * M, ))
    y = np.zeros((2 * M, ))
    x[:M] = np.tile(t, N_strip)
    y[M:] = np.tile(t, N_strip)
    y[:M] = np.repeat(np.linspace(1 / (N_strip + 1), N_strip / (N_strip + 1), N_strip), N_res * dense)
    x[M:] = np.repeat(np.linspace(1 / (N_strip + 1), N_strip / (N_strip + 1), N_strip), N_res * dense)
    pts_strip, normals_strip = project_uv_on_heightfield(mi.Point2f(x,y), heightfield)

    v_box = np.array([[-1, -1, 0], [1, -1, 0], [1, 1, 0], [-1, 1, 0],
                    [-1, -1, 1], [1, -1, 1], [1, 1, 1], [-1, 1, 1.]])
    f_box = np.array([[0, 1], [1, 2], [2, 3], [3, 0], [4, 5], [5, 6], [6, 7], 
                    [7, 4], [0, 4], [1, 5], [2, 6], [7, 3]], dtype=int)

    plt = mp.plot(v_box, f_box, shading={"background": "#606060"})
    plt.add_points(pts.numpy(), shading={"point_size": 0.02, "point_color": '#808080'})
    plt.add_points(pts_strip.numpy(), shading={"point_size": 0.018, "point_color": '#303060'})

    # Plot roots
    if roots is not None:
        roots[2] += 0.015
        plt.add_points(roots.numpy(), shading={"point_size": 0.15, "point_color": 'blue'})

    if vp is not None:
        # Compute normal
        viewdir = dr.normalize(vp - pts)

        if False:  # plot normal
            normals_np = normals.numpy()
            plt.add_lines(pts.numpy(), pts.numpy() + normals_np * 0.1, shading={"line_color": "black", "line_width": 0.1})

        # Plot F function (boundary test)
        normal_viewdir_dot = dr.dot(normals, viewdir).numpy()
        normal_viewdir_dot = (normal_viewdir_dot + 1) / 2
        pts_xy = pts.numpy().copy()
        cmap = get_colormap()
        plt.add_points(pts_xy, shading={"point_size": 0.04,}, c=cmap(normal_viewdir_dot)[:, :3])


        # Plot viewpoint and viewing direction 
        plt.add_points(vp.numpy(), shading={"point_size": 0.2, "point_color": "red"},)
        center = np.array([0.0, 0.0, 0.0])
        to_center = center - vp
        to_center /= np.linalg.norm(to_center)
        v0 = vp + to_center * 0.0
        v1 = vp + to_center * 0.25
        plt.add_lines(v0.numpy(), v1.numpy(), shading={"line_color": "black", "line_width": 4})

    return plt

In [71]:
V0 = mi.Point3f(-2, -2, 0.5)
sensor = mi.load_dict({
    "type": "perspective",
    "to_world": mi.ScalarTransform4f.look_at(
            origin=V0.numpy(), target=[0, 0, 0], up=[0, 0, 1]),
    'sampler': {
            'type': 'independent',
            'sample_count': 2048,
        },
        # 'fovAxis': "smaller",
        'fov': 90.0,
})

In [72]:
sampler = sensor.sampler()
sampler.seed(7, 100000)

In [73]:
ray, _ = sensor.sample_ray(0, sampler.next_1d(), sampler.next_2d(), sampler.next_2d())
its = scene.ray_intersect(ray)

In [74]:
idx = dr.compress(its.is_valid())
its_valid = dr.gather(mi.SurfaceInteraction3f, its, idx)
d_3d = its_valid.p - V0 
d = dr.normalize(mi.Vector2f(d_3d.x, d_3d.y))

# Find roots
t_ridge, found_ridge = find_nearest_ridge(its_valid.uv, d, heightfield, mi.Mask(True), 10, 0.03)
t_silhouette, root_found = bisection(mi.Float(0.0), mi.Float(t_ridge), 30, 1e-3, F_predicate, its_valid.uv, d, V0, heightfield)
found_uv = dr.select(root_found, its_valid.uv + t_silhouette * d, mi.Point2f(0.0, 0.0))
root_p, _ = project_uv_on_heightfield(found_uv, heightfield)


In [75]:
starting_p_count = dr.width(idx)
r_found_count = dr.count(root_found)
print(f"Found {r_found_count} roots, starting from {starting_p_count} points.")

Found 3305 roots, starting from 5440 points.


In [76]:
plt = visualize_found_roots(root_p , V0)
plt.add_points(its_valid.p.numpy(), shading={"point_size": 0.08})

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0, 0.0,…

7