In [None]:
# Adapted from Mitsuba 3's Official Tutorials:
# https://mitsuba.readthedocs.io/en/latest/src/inverse_rendering/shape_optimization.html
# https://mitsuba.readthedocs.io/en/stable/src/how_to_guides/mesh_io_and_manipulation.html

import drjit as dr
import mitsuba as mi
import matplotlib.pyplot as plt
import os

# Select the Mitsuba variant before anything else touches `mi`.
# 'llvm_ad_rgb' is the LLVM (CPU) backend with automatic differentiation and
# RGB colour, as described in the Mitsuba variant documentation.
mi.set_variant('llvm_ad_rgb')


In [None]:
# Dependencies:
# !pip install cholespy
# !pip install gpytoolbox

In [None]:
# Inspect the XML scene description; as the last expression of the cell, the
# returned value is shown as the cell output.
# NOTE(review): the result is not stored, and none of the later cells visible
# here read it -- the scene is built from Python dicts instead.
mi.xml_to_props("banana_updated.xml")

In [None]:
from mitsuba import ScalarTransform4f as T

# Positions (into `to_world_list` and into the loaded reference-image list) of
# the views that take part in the optimization. Shrink this list to optimize
# against a subset of the views.
active_indices = [0, 1, 2, 3, 4, 5]

# One 4x4 'to_world' matrix per candidate camera. The trailing "#n" comment is
# the number of the matching reference image (banana<n>_600x400.png).
to_world_list = [
    T([
        [-0.40831, -0.629649, -0.660927, 3.35112],
        [-0.752151, -0.178205, 0.634439, -0.77301],
        [-0.517254, 0.756165, -0.400828, -0.591697],
        [0, 0, 0, 1]
    ]), #2
    T([
        [0.313648, -0.580312, -0.751573, 4.23076],
        [-0.829941, -0.552095, 0.0799364, 1.81866],
        [-0.461328, 0.598689, -0.654788, 0.998348],
        [0, 0, 0, 1]
    ]), #4
    T([
        [0.999562, -0.0260534, 0.0140109, 0.290532],
        [-0.0194337, -0.935423, -0.352996, 3.80493],
        [0.0223028, 0.352569, -0.93552, 2.18121],
        [0, 0, 0, 1]
    ]), #6
    T([
        [0.7136, 0.377279, 0.590284, -2.54665],
        [0.570557, -0.801911, -0.177211, 2.72416],
        [0.406497, 0.463248, -0.787503, 1.21879],
        [0, 0, 0, 1]
    ]), #8
    T([
        [-0.886821, 0.229873, 0.400883, -1.07143],
        [0.397902, -0.0612994, 0.915378, -2.60014],
        [0.234995, 0.971288, -0.0371054, -2.38561],
        [0, 0, 0, 1]
    ]), #11
    T([
        [-0.994695, -0.00129575, 0.102857, 0.36554],
        [0.10286, -0.00344105, 0.99469, -2.86139],
        [-0.000934934, 0.999993, 0.00355608, -2.3446],
        [0, 0, 0, 1]
    ])  #13
]

sensor_count = len(active_indices)
sensors = []

# Build sensors for the ACTIVE views only, in `active_indices` order, so that
# sensors[k] is the camera for the k-th active view. The reference images are
# selected through the same `active_indices`, and later cells pair
# sensors[k] with ref_images[k]; building a sensor for every matrix instead
# would pair the wrong camera with the wrong reference whenever
# `active_indices` is not simply 0..N-1.
for i in active_indices:
    sensors.append(mi.load_dict({
        'type': 'perspective',
        'fov_axis' : 'x',
        'fov': 64.022150,
        'to_world': to_world_list[i],
        'film': {
            'type': 'hdrfilm',
            # Same resolution as the 600x400 reference images.
            'width': 600, 'height': 400,
            'filter': {'type': 'gaussian'},
            'sample_border': True
        },
        'sampler': {
            'type': 'independent',
            'sample_count': 128
        }
    }))


In [None]:
# If exporting PLY from Blender:
# Format        : Check "ASCII" (for later edits)
# Limit to      : Check "Selected Only"
# Scale         : 1.000
# Forward Axis  : -Z
# Up Axis       : Y
# Objects       : Check Apply Modifiers (if needed)
# Geometry      : Don't Check "UV Coordinates" (not needed)
#                 Don't Check "Vertex Normals" (let Mitsuba calculate)
# Vertex Colors : Select "Linear"
#                 Check "Triangulated Mesh"
# Then, manually edit the PLY file:
# 1. In the header:
#   change "uchar color" to "float color" where "color" is "red"/"green"/"blue"
#   remove the line "property uchar alpha"
# 2. After the header:
#   replace all (0-255, 0-255, 0-255, 0-255) rgba values with (0-1, 0-1, 0-1) rgb values
# 3. When loading the PLY file in Mitsuba:
#   need 'flip_normals' : True

# Appearance of the mesh: a diffuse BSDF whose reflectance is read per vertex
# from the mesh attribute called 'vertex_color'. (A constant 'rgb' reflectance
# such as [0.2, 0.2, 0.2] is the non-optimizable alternative.)
vertex_color_bsdf = {
    'type': 'diffuse',
    'reflectance': {
        'type': 'mesh_attribute',
        'name': 'vertex_color',
    },
}

# Starting geometry. Add 'flip_normals': True here if the PLY was exported as
# described in the notes above and the normals come out inverted.
mesh = mi.load_dict({
    'type': 'ply',
    'filename': "./meshes/banana_init.ply",
    'bsdf': vertex_color_bsdf,
})

# Attach the colour attribute the BSDF looks up: three floats per vertex,
# every component initialised to the same grey value.
attribute_size = 3 * mesh.vertex_count()
mesh.add_attribute("vertex_color", 3, [0.20] * attribute_size)

# Show which parameters the stand-alone mesh exposes.
mesh_params = mi.traverse(mesh)
print(mesh_params)

# Scene = the mesh under a constant environment light, rendered with the
# projective direct-illumination integrator.
# NOTE(review): 'sppi': 0 follows the Mitsuba shape-optimization tutorial,
# which uses it to skip indirect visibility sampling -- confirm against the docs.
scene_dict = {
    'type': 'scene',
    'integrator': {'type': 'direct_projective', 'sppi': 0},
    'emitter': {'type': 'constant', 'radiance': 1.0},
    'shape': mesh,
}

scene = mi.load_dict(scene_dict)
params = mi.traverse(scene)
print(params)

# Sanity check: first nine entries of the colour and normal buffers.
for key in ("shape.vertex_color", "shape.vertex_normals"):
    print(list(params[key])[:9])


In [None]:
def plot_images(images):
    """Show `images` side by side in one row, one axes per image, axes hidden.

    Args:
        images: Sequence of images accepted by ``mi.util.convert_to_bitmap``.
            An empty sequence is a no-op.
    """
    images_count = len(images)
    if images_count == 0:
        # Nothing to draw; a subplot grid with zero columns cannot be created.
        return

    _, axs = plt.subplots(1, images_count, figsize=(images_count*5, 5))

    # With a single column, subplots() hands back one bare Axes instead of an
    # array of them; wrap it so both cases share the same loop. (The previous
    # single-image branch indexed `images[i]` with an undefined `i`.)
    axes = [axs] if images_count == 1 else axs

    for ax, image in zip(axes, images):
        ax.imshow(mi.util.convert_to_bitmap(image))
        ax.axis('off')

def plot_images_each(images):
    """Display every image in its own figure, one after another, axes hidden."""
    # map() is lazy, so each image is converted right before it is drawn.
    for bitmap in map(mi.util.convert_to_bitmap, images):
        plt.imshow(bitmap)
        plt.axis("off")
        plt.show()


In [None]:
# Image numbers of the reference pictures on disk (banana<n>_600x400.png).
ref_indices = [2, 4, 6, 8, 11, 13]

# Load every reference as 32-bit float RGB. srgb_gamma=False asks the
# conversion for values without sRGB gamma encoding.
ref_images_all = [
    mi.Bitmap(f"./refs/banana/banana{i}_600x400.png").convert(
        pixel_format=mi.Bitmap.PixelFormat.RGB,
        component_format=mi.Struct.Type.Float32,
        srgb_gamma=False,
    )
    for i in ref_indices
]

# Keep only the views selected for this run, in `active_indices` order.
ref_images = [ref_images_all[i] for i in active_indices]

# Row overview; call plot_images_each(ref_images) for one figure per image.
plot_images(ref_images)


In [None]:
# Render the un-optimized scene once per active view (128 samples per pixel)
# so it can be compared with the references.
init_images = []
for view in range(sensor_count):
    init_images.append(mi.render(scene, sensor=sensors[view], spp=128))

plot_images(init_images)


In [None]:
# Same initial renders again, one figure per view for a closer look.
plot_images_each(init_images)

In [None]:
# Set up the "large steps" reparameterization of the vertex positions from the
# current vertex positions and face indices of the scene's shape.
# NOTE(review): per the Mitsuba shape-optimization tutorial, lambda_ controls
# how strongly updates are smoothed across the mesh -- confirm the chosen value.
lambda_ = 25
ls = mi.ad.LargeSteps(params['shape.vertex_positions'], params['shape.faces'], lambda_)

In [None]:
# Adam optimizer shared by all optimized variables; `lr` is also embedded in
# the name of the PLY file written at the end of the notebook.
# NOTE(review): uniform=True is the setting used in the Mitsuba
# shape-optimization tutorial together with LargeSteps -- see the Adam docs.
lr = 1e-1
opt = mi.ad.Adam(lr=lr, uniform=True)

In [None]:
# Register the variables to optimize.
# CAUTION: under the key 'shape.vertex_positions' the optimizer holds the
# DIFFERENTIAL (latent) form of the positions produced by `ls`, not the
# positions themselves. It must go through ls.from_differential() before it is
# written back into `params`.
opt['shape.vertex_positions'] = ls.to_differential(params['shape.vertex_positions'])
# Per-vertex colours are optimized directly.
opt['shape.vertex_color'] = params['shape.vertex_color']

In [None]:
# Number of passes over all active views; also used in the output file name.
iterations = 100
# One entry per iteration: the L1 loss summed over every active view.
loss_vec = []
for it in range(iterations):
    total_loss = mi.Float(0.0)

    for sensor_idx in range(sensor_count):
        # Push the current optimizer state into the scene. The optimizer holds
        # the vertex positions in differential (latent) form, so they must be
        # converted back to real positions first; colours are copied as-is.
        params['shape.vertex_positions'] = ls.from_differential(opt['shape.vertex_positions'])
        params['shape.vertex_color'] = opt['shape.vertex_color']
        params.update()

        # Same seed for every view within one iteration, new seed per iteration.
        img = mi.render(scene, params, sensor=sensors[sensor_idx], seed=it)

        # L1 Loss
        loss = dr.mean(dr.abs(img - ref_images[sensor_idx]))

        dr.backward(loss)
        opt.step()

        # Deliberately NO `params.update(opt)` here: it would copy
        # opt['shape.vertex_positions'] -- the latent coordinates, not actual
        # positions -- straight into the mesh and trigger a scene update with
        # wrong geometry. The sync at the top of the next pass does it correctly.
        # NOTE(review): 'shape.vertex_color' is optimized unconstrained;
        # consider clamping it to [0, 1] after each step if values drift out
        # of the valid reflectance range.

        total_loss += loss

    loss_vec.append(total_loss)
    print(f"Iter: {1+it:03d}; Loss: {total_loss[0]}")

In [None]:
# Loss curve, one point per iteration. `loss_vec` holds the accumulated
# Mitsuba Float values as appended by the optimization loop; the original
# author noted that appending total_loss[0] instead did not plot correctly.
fig, ax = plt.subplots()
ax.plot(loss_vec)
plt.show()


In [None]:
# Bring the scene in sync with the optimizer's final state. The optimizer keeps
# the vertex positions in differential form, so they are converted back to
# real positions before being written into the scene; colours are copied as-is.
params['shape.vertex_positions'] = ls.from_differential(opt['shape.vertex_positions'])
params['shape.vertex_color'] = opt['shape.vertex_color']
# Propagate both changes to the scene.
params.update()

# First nine entries of the vertex-normal buffer after the update, for
# comparison with the values printed when the scene was created.
print(list(params["shape.vertex_normals"])[:9])


In [None]:
# Render the optimized scene once per active view (128 samples per pixel),
# mirroring the initial renders.
final_images = []
for view in range(sensor_count):
    final_images.append(mi.render(scene, sensor=sensors[view], spp=128))

plot_images(final_images)


In [None]:
# Same final renders again, one figure per view for a closer look.
plot_images_each(final_images)


In [None]:
# Save the optimized mesh; the file name records the learning rate and the
# iteration count used for this run.
output_dir = "./outputs"
# Make sure the target folder exists first, so that a fresh checkout does not
# lose the whole optimization run to a failed write at the very end.
os.makedirs(output_dir, exist_ok=True)
mesh.write_ply(f"{output_dir}/banana_optimized_lr{lr}_it{iterations}.ply")
