In [None]:
# Adapted from Mitsuba 3's Official Tutorials:
# https://mitsuba.readthedocs.io/en/latest/src/inverse_rendering/shape_optimization.html
# https://mitsuba.readthedocs.io/en/stable/src/how_to_guides/mesh_io_and_manipulation.html

import drjit as dr
import mitsuba as mi
import matplotlib.pyplot as plt
import os

# Select Mitsuba's LLVM (CPU) variant with automatic differentiation and RGB spectra.
mi.set_variant("llvm_ad_rgb")


In [None]:
# Dependencies (required by mi.ad.LargeSteps, used below):
# %pip install cholespy
# %pip install gpytoolbox

In [None]:
from mitsuba import ScalarTransform4f as T

# Which of the four candidate views (indices into the *_all lists) to optimize against.
active_indices = [1, 3]

# Camera positions and look-at targets, one entry per candidate view.
origins_all = [
    [-0.08, 0.02, -1.25],
    [-0.1, 0.02, 1.33],
    [1.35, 0.01, -0.05],
    [-1.35, -0.05, -0.01]
]

targets_all = [
    [-0.08, 0.02, 0],
    [-0.1, 0.02, 0],
    [0, 0.01, -0.05],
    [0, -0.05, -0.01]
]

# Keep only the active views, preserving the order of active_indices.
origins = [origins_all[view] for view in active_indices]
targets = [targets_all[view] for view in active_indices]

sensor_count = len(active_indices)


def _make_sensor(origin, target):
    """Build a 45-degree-FOV perspective sensor looking from `origin` at `target` (y-up).

    The film is a 400x400 HDR film with a Gaussian filter and border sampling;
    the sampler is an independent sampler with 128 samples per pixel.
    """
    return mi.load_dict({
        'type': 'perspective',
        'fov': 45,
        'to_world': T().look_at(target=target, origin=origin, up=[0, 1, 0]),
        'film': {
            'type': 'hdrfilm',
            'width': 400, 'height': 400,
            'filter': {'type': 'gaussian'},
            'sample_border': True
        },
        'sampler': {
            'type': 'independent',
            'sample_count': 128
        }
    })


sensors = [_make_sensor(cam_origin, cam_target) for cam_origin, cam_target in zip(origins, targets)]


In [None]:
# If exporting PLY from Blender:
# Format        : Check "ASCII" (for later edits)
# Limit to      : Check "Selected Only"
# Scale         : 1.000
# Forward Axis  : -Z
# Up Axis       : Y
# Objects       : Check "Apply Modifiers" (if needed)
# Geometry      : Don't check "UV Coordinates" (not needed)
#                 Don't check "Vertex Normals" (let Mitsuba compute them)
# Vertex Colors : Select "Linear"
#                 Check "Triangulated Mesh"
# Then, manually edit the PLY file:
# 1. In the header:
#   change "uchar <color>" to "float <color>" for each of "red", "green", "blue"
#   remove the line "property uchar alpha"
# 2. After the header:
#   replace every (0-255, 0-255, 0-255, 0-255) RGBA value with a (0-1, 0-1, 0-1) RGB value
#   (divide each color channel by 255 and drop the alpha column)
# 3. When loading the PLY file in Mitsuba:
#   set 'flip_normals': True

# Load the mesh; its diffuse reflectance is driven by the per-vertex colors stored
# in the PLY file. (For a uniform grey material instead, use
# 'reflectance': {'type': 'rgb', 'value': [0.2, 0.2, 0.2]}.)
mesh = mi.load_dict({
    "type": "ply",
    'filename': "./meshes/HandsomeDan_Updated.ply",
    # Needed for this Blender export (step 3 of the export notes).
    'flip_normals': True,
    "bsdf": {
        "type": "diffuse",
        'reflectance': {
            'type': 'mesh_attribute',
            'name': 'vertex_color'
        }
    }
})

# Inspect the mesh's exposed parameters.
mesh_params = mi.traverse(mesh)
print(mesh_params)

scene_dict = {
    'type': 'scene',
    'integrator': {
        'type': 'direct_projective',
        # 0 samples for interior visibility discontinuities (as in the Mitsuba
        # shape-optimization tutorial this notebook adapts).
        'sppi': 0,
    },
    'emitter': {
        'type': 'constant',
        'radiance': 1.0,
    },
    'shape': mesh
}

scene = mi.load_dict(scene_dict)
params = mi.traverse(scene)
print(params)
# Peek at the first 9 floats (3 vertices). .numpy() copies the array once, whereas
# list() would read every element of the full array individually just to keep 9.
print(params["shape.vertex_color"].numpy()[:9])
print(params["shape.vertex_normals"].numpy()[:9])


In [None]:
def plot_images(images):
    """Show `images` side by side in one row of matplotlib subplots, axes hidden.

    Each image goes through mi.util.convert_to_bitmap before display; in this
    notebook it is called with both rendered images and reference images.

    Args:
        images: sequence of images; one subplot (5x5 inches) is created per image.
            An empty sequence draws nothing.
    """
    images_count = len(images)
    if images_count == 0:
        # plt.subplots(1, 0) raises, so there is nothing sensible to draw.
        return
    # squeeze=False always returns a 2-D array of axes, so the single-image case
    # needs no special branch (the old one indexed images with an unbound `i`).
    _, axs = plt.subplots(1, images_count, figsize=(images_count * 5, 5), squeeze=False)
    for ax, image in zip(axs[0], images):
        ax.imshow(mi.util.convert_to_bitmap(image))
        ax.axis('off')


In [None]:
def _load_reference(view_number):
    """Load ./refs/HandsomeDan{view_number}_400px.png as a float32 RGB bitmap
    without sRGB gamma (srgb_gamma=False)."""
    bitmap = mi.Bitmap(f"./refs/HandsomeDan{view_number}_400px.png")
    return bitmap.convert(
        pixel_format=mi.Bitmap.PixelFormat.RGB,
        component_format=mi.Struct.Type.Float32,
        srgb_gamma=False,
    )


# All four reference photos (files are numbered 1..4; view index i uses file i+1).
ref_images_all = [_load_reference(n) for n in range(1, 5)]

# Convert the active references to tensors once, up front, so the loss in the
# optimization loop subtracts tensor from tensor instead of relying on an implicit
# Bitmap -> tensor conversion on every iteration.
ref_images = [mi.TensorXf(ref_images_all[view]) for view in active_indices]

plot_images(ref_images)


In [None]:
# Render each active view with the initial (not yet optimized) mesh at 128 spp.
init_imgs = [mi.render(scene, sensor=cam, spp=128) for cam in sensors]
plot_images(init_imgs)

In [None]:
# Regularization weight for the LargeSteps vertex parameterization
# (larger values favor smoother updates).
lambda_ = 25
ls = mi.ad.LargeSteps(
    params['shape.vertex_positions'],
    params['shape.faces'],
    lambda_,
)

In [None]:
# Learning rate (also embedded in the output file name at the end).
lr = 0.1
# uniform=True: the Adam variant paired with LargeSteps in the Mitsuba tutorial.
opt = mi.ad.Adam(uniform=True, lr=lr)

In [None]:
# Variables handed to the optimizer: vertex positions in LargeSteps' differential
# coordinates, and the per-vertex colors unchanged.
initial_variables = {
    'shape.vertex_positions': ls.to_differential(params['shape.vertex_positions']),
    'shape.vertex_color': params['shape.vertex_color'],
}
for key, value in initial_variables.items():
    opt[key] = value

In [None]:
iterations = 100
loss_vec = []
for it in range(iterations):
    # Sum of this iteration's per-view losses; a detached metric for logging only.
    total_loss = mi.Float(0.0)

    for sensor_idx in range(sensor_count):
        # Map the optimizer's differential coordinates back to vertex positions and
        # push the current colors into the scene before rendering this view.
        params['shape.vertex_positions'] = ls.from_differential(opt['shape.vertex_positions'])
        params['shape.vertex_color'] = opt['shape.vertex_color']
        params.update()

        img = mi.render(scene, params, sensor=sensors[sensor_idx], seed=it)

        # L1 loss against this view's reference image
        loss = dr.mean(dr.abs(img - ref_images[sensor_idx]))

        # One Adam step per view.
        dr.backward(loss)
        opt.step()

        # Keep the per-vertex colors a valid diffuse reflectance in [0, 1].
        opt['shape.vertex_color'] = dr.clamp(opt['shape.vertex_color'], 0.0, 1.0)

        # No params.update(opt) here: opt stores vertex positions in differential
        # coordinates, which must go through ls.from_differential (done at the top
        # of the next view) rather than be copied into the mesh verbatim.

        # Detach so the stored metric does not keep the AD graph alive.
        total_loss += dr.detach(loss)

    loss_vec.append(total_loss)
    print(f"Iter: {1+it:03d}; Loss: {total_loss[0]}")

In [None]:
# loss_vec holds one single-element mi.Float per iteration, which matplotlib plots
# directly (appending total_loss[0] instead was reported not to work).
fig, ax = plt.subplots()
ax.plot(loss_vec)
ax.set(xlabel='Iteration', ylabel='Summed L1 loss over active views', title='Optimization loss')
plt.show()


In [None]:
# Write the final optimized state into the scene: positions are mapped out of
# LargeSteps' differential coordinates, colors are copied as-is.
params['shape.vertex_positions'] = ls.from_differential(opt['shape.vertex_positions'])
params['shape.vertex_color'] = opt['shape.vertex_color']
params.update()

# First 9 floats (3 vertices) of the vertex normals after the update. .numpy()
# copies the array once instead of list() reading every element individually.
print(params["shape.vertex_normals"].numpy()[:9])


In [None]:
# Re-render every active view at 128 spp with the optimized mesh.
final_imgs = [mi.render(scene, sensor=cam, spp=128) for cam in sensors]
plot_images(final_imgs)


In [None]:
# Save the mesh (it was passed into the scene dict, so the updates above are
# expected to apply to it). Create the output folder first so the write does not
# fail when ./outputs does not exist yet.
os.makedirs("./outputs", exist_ok=True)
mesh.write_ply(f"./outputs/HandsomeDan_Optimized_lr{lr}.ply")
