Based on the Mitsuba 3 caustics optimization tutorial: https://mitsuba.readthedocs.io/en/latest/src/inverse_rendering/caustics_optimization.html

In [1]:
import os
from os.path import realpath, join

import drjit as dr
import mitsuba as mi

import numpy as np

# List the variants available in this Mitsuba installation; the variant
# selected in the next cell must be one of them.
print(mi.variants())

['scalar_rgb', 'scalar_spectral', 'cuda_ad_rgb', 'llvm_ad_rgb']


In [2]:
# Make sure you are not using a scalar variant: scalar variants only permit
# scalar types, which caused me trouble for a while.
# For clarification see: https://github.com/mitsuba-renderer/mitsuba3/issues/775
mi.set_variant('llvm_ad_rgb')

setup configs:

In [3]:
# Root folder of the scene assets, resolved against the current working
# directory (i.e. the folder the notebook is run from).
SCENE_DIR = realpath('../scenes')

# One entry per target caustic image. Each path component is passed to
# `join` separately so the resulting path uses the platform separator
# throughout (a literal 'references/...' yields mixed '\' and '/' on Windows).
# NOTE(review): the 'emitter' key is not read by any cell of this notebook.
CONFIGS = {
    'wave': {
        'emitter': 'gray',
        'reference': join(SCENE_DIR, 'references', 'wave-1024.jpg'),
    },
    'qureshi': {
        'emitter': 'gray',
        'reference': join(SCENE_DIR, 'references', 'f_qureshi-512.jpg'),
    },
    'green': {
        'emitter': 'gray',
        'reference': join(SCENE_DIR, 'references', 'm_green-512.jpg'),
    },
}

# Name of the CONFIGS entry to optimize for; also used in output file names.
config_name = 'green'

config = CONFIGS[config_name]
print('selected:', config['reference'])
# Last expression of the cell: displays the reference image in the notebook.
mi.Bitmap(config['reference'])

selected: C:\Users\Alex\Desktop\exploring-inverse-rendering\scenes\references/m_green-512.jpg


setup hparams:

In [4]:
# Merge the hyperparameters into the selected config so that later cells can
# read everything from a single dict. `config` is the same object as
# CONFIGS[config_name], so this mutates that entry in place.
config.update({
    # Width and height of the sensor film used for rendering.
    'render_resolution': (128, 128),
    # Final heightmap resolution; also the default lens mesh resolution.
    'heightmap_resolution': (512, 512),
    # Number of times the heightmap resolution is doubled during optimization;
    # the initial heightmap is heightmap_resolution // 2**n_upsampling_steps.
    'n_upsampling_steps': 4,
    # Passed to mi.render as `spp_grad`; `spp` itself is set to twice this value.
    'spp': 32,
    # Number of iterations of the optimization loop.
    'max_iterations': 1000,
    # Learning rate handed to the Adam optimizer.
    'learning_rate': 3e-5,
})

# All results of this run go to ./4_outputs/<config_name>, created on demand.
output_dir = realpath(join('.', '4_outputs', config_name))
os.makedirs(output_dir, exist_ok=True)
print('results will be saved to:', output_dir)

results will be saved to: C:\Users\Alex\Desktop\exploring-inverse-rendering\tutorials\4_outputs\green


function to generate the lens mesh:

In [5]:
def create_flat_lens_mesh(resolution):
    """Build a flat, triangulated grid mesh with texture coordinates.

    The grid has ``resolution[0] * resolution[1]`` vertices. Texture
    coordinates are linearly spaced over [0, 1] along both axes, and the
    vertex positions are those coordinates remapped to [-1, 1] in X and Y,
    with Z = 0. Every grid cell is split into two triangles.

    Args:
        resolution: indexable pair giving the number of vertices along the
            first and second grid axis.

    Returns:
        A ``mi.Mesh`` named "lens-mesh" with positions, texcoords and faces set.
    """
    n_rows, n_cols = resolution[0], resolution[1]

    # Texture coordinates: regular grid, 'ij' indexing (first axis varies slowest)
    grid_u, grid_v = dr.meshgrid(
        dr.linspace(mi.Float, 0, 1, n_rows),
        dr.linspace(mi.Float, 0, 1, n_cols),
        indexing='ij'
    )
    uv = mi.Vector2f(grid_u, grid_v)

    # Vertex positions: remap the UVs from [0, 1] to [-1, 1], in the Z = 0 plane
    positions = mi.Vector3f(2.0 * (grid_u - 0.5), 2.0 * (grid_v - 0.5), 0.0)

    # For every grid cell, record the flat indices of its (i, j) and (i+1, j)
    # corners; the remaining two corners are these indices plus one.
    quads = [(row * n_cols + col, (row + 1) * n_cols + col)
             for row in range(n_rows - 1)
             for col in range(n_cols - 1)]

    # Two triangles per cell: (v00, v10, v01) and (v01, v10, v11),
    # stored as one list per triangle corner.
    corner_a = [idx for v00, v10 in quads for idx in (v00, v00 + 1)]
    corner_b = [idx for v00, v10 in quads for idx in (v10, v10)]
    corner_c = [idx for v00, v10 in quads for idx in (v00 + 1, v10 + 1)]
    triangles = mi.Vector3u(corner_a, corner_b, corner_c)

    # Allocate the mesh, then fill its buffers through the parameter interface
    lens = mi.Mesh("lens-mesh", n_rows * n_cols, len(corner_a), has_vertex_texcoords=True)

    buffers = mi.traverse(lens)
    buffers['vertex_positions'] = dr.ravel(positions)
    buffers['vertex_texcoords'] = dr.ravel(uv)
    buffers['faces'] = dr.ravel(triangles)
    buffers.update()

    return lens

In [6]:
# The lens mesh resolution defaults to the final heightmap resolution unless
# the config provides an explicit 'lens_res' entry.
lens_res = config.get('lens_res', config['heightmap_resolution'])
# The resolution is part of the file name, so each resolution gets its own file.
lens_fname = join(output_dir, 'lens_{}_{}.ply'.format(*lens_res))

# Generate the flat mesh only if no file exists yet for this resolution;
# an existing file is reused as-is.
if not os.path.isfile(lens_fname):
    m = create_flat_lens_mesh(lens_res)
    m.write_ply(lens_fname)
    print('wrote lens mesh ({}x{}) file to: {}'.format(*lens_res, lens_fname))

make emitter:

In [7]:
# Emitter description ('directionalarea' plugin with a uniform radiance
# spectrum of 0.8). It is attached to the emitter shape when the scene
# dict is assembled.
emitter = {
    'type': 'directionalarea',
    'radiance': {
        'type': 'spectrum',
        'value': 0.8
    }
}

make integrator:

In [8]:
# Integrator description ('ptracer' plugin) that is inserted into the scene
# dict under the 'integrator' key.
integrator = {
    'type': 'ptracer',
    'samples_per_pass': 256,
    'max_depth': 4,
    'hide_emitters': False
}

sensor setup:

In [9]:
# Camera placement: located at y = -4.65 and looking further along -Y
# (toward y = -20), with +Z as the up direction.
sensor_to_world = mi.ScalarTransform4f.look_at(
    target=[0, -20, 0],
    origin = [0, -4.65, 0],
    up = [0, 0, 1]
)
# The film size comes from the 'render_resolution' hyperparameter.
resx, resy = config['render_resolution']
# Sensor description that is inserted into the scene dict under the 'sensor' key.
sensor = {
    'type': 'perspective',
    'near_clip': 1,
    'far_clip': 1000,
    'fov': 45,
    'to_world': sensor_to_world,

    'sampler': {
        'type': 'independent',
        'sample_count': 512  # Not really used
    },
    'film': {
        'type': 'hdrfilm',
        'width': resx,
        'height': resy,
        'pixel_format': 'rgb',
        'rfilter': {
            # Important: smooth reconstruction filter with a footprint larger than 1 pixel.
            'type': 'gaussian'
        }
    }
}

assemble scene:

In [10]:
# Scene description. Mesh files are located through SCENE_DIR (rather than a
# repeated, hard-coded '../scenes' prefix) so that the meshes and the
# reference images always come from the same asset folder.
# The entry order is kept as-is: the BSDFs are declared before the shapes
# that reference them by id.
scene = {
    'type': 'scene',
    'sensor': sensor,
    'integrator': integrator,
    # Glass BSDF
    'simple-glass': {
        'type': 'dielectric',
        'id': 'simple-glass-bsdf',
        'ext_ior': 'air',
        'int_ior': 1.5,
        'specular_reflectance': { 'type': 'spectrum', 'value': 0 },
    },
    'white-bsdf': {
        'type': 'diffuse',
        'id': 'white-bsdf',
        'reflectance': { 'type': 'rgb', 'value': (1, 1, 1) },
    },
    'black-bsdf': {
        'type': 'diffuse',
        'id': 'black-bsdf',
        'reflectance': { 'type': 'spectrum', 'value': 0 },
    },
    # Receiving plane
    'receiving-plane': {
        'type': 'obj',
        'id': 'receiving-plane',
        'filename': join(SCENE_DIR, 'meshes', 'rectangle.obj'),
        'to_world': mi.ScalarTransform4f.look_at(
            target=[0, 1, 0],
            origin=[0, -7, 0],
            up=[0, 0, 1]
        ).scale((5, 5, 5)),
        'bsdf': {'type': 'ref', 'id': 'white-bsdf'},
    },
    # Glass slab, excluding the 'exit' face (added separately below)
    'slab': {
        'type': 'obj',
        'id': 'slab',
        'filename': join(SCENE_DIR, 'meshes', 'slab.obj'),
        'to_world': mi.ScalarTransform4f.rotate(axis=(1, 0, 0), angle=90),
        'bsdf': {'type': 'ref', 'id': 'simple-glass'},
    },
    # Glass rectangle, to be optimized
    'lens': {
        'type': 'ply',
        'id': 'lens',
        'filename': lens_fname,
        'to_world': mi.ScalarTransform4f.rotate(axis=(1, 0, 0), angle=90),
        'bsdf': {'type': 'ref', 'id': 'simple-glass'},
    },

    # Directional area emitter placed behind the glass slab
    'focused-emitter-shape': {
        'type': 'obj',
        'filename': join(SCENE_DIR, 'meshes', 'rectangle.obj'),
        'to_world': mi.ScalarTransform4f.look_at(
            target=[0, 0, 0],
            origin=[0, 5, 0],
            up=[0, 0, 1]
        ),
        'bsdf': {'type': 'ref', 'id': 'black-bsdf'},
        'focused-emitter': emitter,
    },
}

In [11]:
# Instantiate the scene from its dict description.
# NOTE(review): this rebinds `scene` from the description dict to the loaded
# scene object, so this cell is not idempotent. Re-run the cell that builds
# the dict before re-running this one.
scene = mi.load_dict(scene)

load reference image

In [12]:
def load_ref_image(config, resolution, output_dir):
    """Load the reference image as a float32 RGB tensor at `resolution`.

    The image is resampled only if its size differs from `resolution`.
    A copy of the (possibly resampled) image is written to
    ``<output_dir>/out_ref.exr`` in every case.

    Args:
        config: dict whose 'reference' entry is the path of the image file.
        resolution: target image size, compared against ``Bitmap.size()``.
        output_dir: folder receiving the ``out_ref.exr`` copy.

    Returns:
        ``mi.TensorXf`` holding the reference image.
    """
    b = mi.Bitmap(config['reference'])
    b = b.convert(mi.Bitmap.PixelFormat.RGB, mi.Bitmap.Float32, False)
    if b.size() != resolution:
        b = b.resample(resolution)

    # These statements must run whether or not the image was resampled.
    # They used to sit inside the `if`, so the function returned None for a
    # reference that already had the target size.
    mi.util.write_bitmap(join(output_dir, 'out_ref.exr'), b)

    print('loaded reference image from:', config['reference'])
    return mi.TensorXf(b)

# NOTE(review): `sensor` is rebound here from the sensor description dict to
# the sensor object of the loaded scene.
sensor = scene.sensors()[0]
# The reference image is brought to the film's crop size so it can be compared
# against the rendered image.
crop_size = sensor.film().crop_size()
image_ref = load_ref_image(config, crop_size, output_dir=output_dir)

loaded reference image from: C:\Users\Alex\Desktop\exploring-inverse-rendering\scenes\references/m_green-512.jpg


create displacement texture

In [13]:
# Start from a coarse heightmap: the final resolution divided by
# 2**n_upsampling_steps along each axis.
initial_heightmap_resolution = [r // (2 ** config['n_upsampling_steps'])
                                for r in config['heightmap_resolution']]
# Iterations at which the heightmap is upsampled: evenly spaced fractions in
# (0, 1), squared, then scaled by the iteration count. With 4 steps and 1000
# iterations this gives [40, 160, 360, 640].
upsampling_steps = dr.sqr(dr.linspace(mi.Float, 0, 1, config['n_upsampling_steps']+1, endpoint=False).numpy()[1:])
upsampling_steps = (config['max_iterations'] * upsampling_steps).astype(int)
print('The resolution of the heightfield will be doubled at iterations:', upsampling_steps)

# Bitmap texture holding the heightmap, initialized to zeros.
heightmap_texture = mi.load_dict({
    'type': 'bitmap',
    'id': 'heightmap_texture',
    'bitmap': mi.Bitmap(dr.zeros(mi.TensorXf, initial_heightmap_resolution)),
    'raw': True,
})

# Only the texture's 'data' parameter is kept and handed to the Adam optimizer.
params = mi.traverse(heightmap_texture)
params.keep(['data'])
opt = mi.ad.Adam(lr=config['learning_rate'], params=params)

The resolution of the heightfield will be doubled at iterations: [ 40 160 360 640]


applying displacement texture

In [14]:
# Parameter interface of the loaded scene; used to overwrite the lens vertices.
params_scene = mi.traverse(scene)

# Always displace along the original normals: snapshot the lens positions and
# normals once, before any displacement has been applied.
positions_initial = dr.unravel(mi.Vector3f, params_scene['lens.vertex_positions'])
normals_initial   = dr.unravel(mi.Vector3f, params_scene['lens.vertex_normals'])

# Surface interaction record with one entry per lens vertex. Only its `uv`
# field is filled in (with the vertex texcoords); it is later passed to the
# heightmap texture to look up one height per vertex.
lens_si = dr.zeros(mi.SurfaceInteraction3f, dr.width(positions_initial))
lens_si.uv = dr.unravel(type(lens_si.uv), params_scene['lens.vertex_texcoords'])

def apply_displacement(amplitude = 1.):
    """Displace the lens vertices according to the current heightmap.

    The heightmap values in ``params['data']`` are first clamped to
    [-0.01, 0.01] and gradient tracking is enabled on them. The heightmap is
    then evaluated at every lens vertex, and each vertex is moved from its
    initial position along its initial normal by ``height * amplitude``.
    The result is written to the scene's 'lens.vertex_positions' parameter
    and the scene parameters are updated.

    Args:
        amplitude: scale factor applied to the looked-up heights.
    """
    # Enforce reasonable range. For reference, the receiving plane
    # is 7 scene units away from the lens.
    height_limit = 1 / 100.
    params['data'] = dr.clamp(params['data'], -height_limit, height_limit)
    dr.enable_grad(params['data'])

    # One height per lens vertex, looked up through the vertex UVs
    heights = heightmap_texture.eval_1(lens_si)

    # Offset along the original normals, added to the original positions
    offsets = heights * normals_initial * amplitude
    displaced = offsets + positions_initial

    params_scene['lens.vertex_positions'] = dr.ravel(displaced)
    params_scene.update()

running the optimization:

In [15]:
def scale_independent_loss(image, ref):
    """Mean squared error between two images after dividing each by its mean.

    The mean used to normalize `image` is computed on a detached copy, so no
    gradient flows through that normalization factor.

    Args:
        image: rendered image.
        ref: reference image.

    Returns:
        Mean of the squared difference of the two normalized images.
    """
    image_normalized = image / dr.mean(dr.detach(image))
    ref_normalized = ref / dr.mean(ref)
    residual = image_normalized - ref_normalized
    return dr.mean(dr.sqr(residual))

In [16]:
import time

# Intermediate renders are stored in a sub-folder of this run's output
# directory, created up front. The previous hard-coded '/4_outputs/images/'
# pointed at the filesystem root and was never created.
frames_dir = join(output_dir, 'images')
os.makedirs(frames_dir, exist_ok=True)

start_time = time.time()
# Silence per-render log messages during the loop; restored afterwards
mi.set_log_level(mi.LogLevel.Warn)
iterations = config['max_iterations']
loss_values = []
spp = config['spp']

for it in range(iterations):
    t0 = time.time()

    # Apply displacement and update the scene BVH accordingly
    apply_displacement()

    # Perform a differentiable rendering of the scene
    image = mi.render(scene, params, seed=it, spp=2 * spp, spp_grad=spp)

    # Save the current render (to see the optimizer working)
    mi.util.write_bitmap(join(frames_dir, f"{it}-{config_name}.png"), image)

    # Scale-independent L2 function
    loss = scale_independent_loss(image, image_ref)

    # Back-propagate errors to input parameters
    dr.backward(loss)

    # Take a gradient step
    opt.step()

    # Increase resolution of the heightmap
    if it in upsampling_steps:
        opt['data'] = dr.upsample(opt['data'], scale_factor=(2, 2, 1))

    # Carry over the update to our "latent variable" (the heightmap values)
    params.update(opt)

    # Log progress
    elapsed_ms = 1000. * (time.time() - t0)
    current_loss = loss[0]
    loss_values.append(current_loss)
    mi.Thread.thread().logger().log_progress(
        # max(..., 1) avoids a division by zero when only one iteration is run
        it / max(iterations - 1, 1),
        f'Iteration {it:03d}: loss={current_loss:g} (took {elapsed_ms:.0f}ms)',
        'Caustic Optimization', '')

    # Increase rendering quality toward the end of the optimization:
    # double the sample count and halve the learning rate at 70% and 90%
    if it in (int(0.7 * iterations), int(0.9 * iterations)):
        spp *= 2
        opt.set_learning_rate(0.5 * opt.lr['data'])


end_time = time.time()
# max(..., 1) keeps the report well-defined if no iteration was run
print(((end_time - start_time) * 1000) / max(iterations, 1), ' ms per iteration on average')
mi.set_log_level(mi.LogLevel.Info)

VBox(children=(HTML(value=''), FloatProgress(value=0.0, bar_style='info', layout=Layout(width='100%'), max=1.0…

1890.5548701286316  ms per iteration on average


In [17]:
# Save the optimized heightmap values.
fname = join(output_dir, 'heightmap_final.exr')
mi.util.write_bitmap(fname, params['data'])
print('[+] Saved final heightmap state to:', fname)

# Re-apply the displacement so the lens mesh reflects the final heightmap,
# then export the shape whose id is 'lens'.
# Note: `fname` is reused for the second output file.
fname = join(output_dir, 'lens_displaced.ply')
apply_displacement()
lens_mesh = [m for m in scene.shapes() if m.id() == 'lens'][0]
lens_mesh.write_ply(fname)
print('[+] Saved displaced lens to:', fname)

[+] Saved final heightmap state to: C:\Users\Alex\Desktop\exploring-inverse-rendering\tutorials\4_outputs\green\heightmap_final.exr


[+] Saved displaced lens to: C:\Users\Alex\Desktop\exploring-inverse-rendering\tutorials\4_outputs\green\lens_displaced.ply
