# Detailed look at `Optimizer`

For the sake of this tutorial, let's first set an AD-compatible variant. A scene will be loaded later, in the section on scene parameters.

In [1]:
import mitsuba as mi
import drjit as dr 

mi.set_variant('llvm_ad_rgb')

## The basics

To perform [gradient-based optimization][1], Mitsuba ships with standard optimizers, including *Stochastic Gradient Descent* ([<code>SGD</code>][2]) with and without momentum, as well as [<code>Adam</code>][3]. Both inherit from the [<code>Optimizer</code>][4] base class and can be found in the `mitsuba.ad` submodule.

The `Optimizer` class behaves like a Python `dict` with some extra logic and methods. This how-to guide takes you through its API and highlights the best practices and pitfalls related to this class.

Let's first construct a simple `SGD` optimizer with a learning rate of `0.25`. 

[1]: https://en.wikipedia.org/wiki/Gradient_descent
[2]: https://mitsuba3.readthedocs.io/en/latest/src/api_reference.html#mitsuba.ad.SGD
[3]: https://mitsuba3.readthedocs.io/en/latest/src/api_reference.html#mitsuba.ad.Adam
[4]: https://mitsuba3.readthedocs.io/en/latest/src/api_reference.html#mitsuba.ad.Optimizer

In [2]:
opt = mi.ad.SGD(lr=0.25)

We can now specify a variable to be optimized. The `Optimizer` automatically enables gradient tracking on the stored variable, since subsequent computations can only produce gradients for variables on which it is enabled.

In [3]:
opt['x'] = mi.Float(1.0)
opt

SGD[
  variables = ['x'],
  lr = {'default': 0.25},
  momentum = 0
]

The optimizer also exposes a `dict`-like API for basic manipulations, such as iterating over its items.

In [4]:
for k, v in opt.items():
    print(f"{k}: {v}")

x: [1.0]


⚠️ It is important to note that a copy of the variable is made when it is assigned to the optimizer via the `__setitem__` method. For instance, in the following code, the original variable won't be attached to the AD graph, and its value will remain unchanged.

In [5]:
y = mi.Float(2.0)
opt['y'] = y
opt['y'] += 1.0

print(f"original:  {y}, grad_enabled={dr.grad_enabled(y)}")
print(f"optimizer: {opt['y']}: grad_enabled={dr.grad_enabled(opt['y'])}")

original:  [2.0], grad_enabled=False
optimizer: [3.0]: grad_enabled=True
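
The copy-on-assignment behavior can be mimicked in plain Python. The following is a minimal sketch, not Mitsuba's implementation: the hypothetical `CopyingOptimizer` class stores a deep copy of every assigned value, and a Python list stands in for `mi.Float`. It does not model gradient tracking, only why the original object keeps its value.

```python
import copy


class CopyingOptimizer(dict):
    """Hypothetical dict-like container that stores copies of assigned values."""

    def __setitem__(self, key, value):
        # Store an independent copy so that later changes made through the
        # container never affect the caller's original object
        super().__setitem__(key, copy.deepcopy(value))


y = [2.0]               # stand-in for mi.Float(2.0)
opt = CopyingOptimizer()
opt['y'] = y
opt['y'][0] += 1.0      # modify the stored copy in place

print(f"original:  {y}")          # original:  [2.0]
print(f"optimizer: {opt['y']}")   # optimizer: [3.0]
```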


It is therefore crucial to use the variables held by the optimizer in the differentiable computation in order to produce the proper gradients. Here is a simple example where `x` and `y` are used in a calculation whose gradients are then backpropagated. Since $z = x + 2y$, we expect $\partial z / \partial x = 1$ and $\partial z / \partial y = 2$, and we can verify that these gradients are indeed propagated to the optimizer variables.

In [6]:
z = opt['x'] + 2.0 * opt['y']
dr.backward(z)

print(f"x grad={dr.grad(opt['x'])}")
print(f"y grad={dr.grad(opt['y'])}")

x grad=[1.0]
y grad=[2.0]


During the optimization, the role of the optimizer is to take a gradient step according to its update rule. For a simple SGD optimizer without momentum, the update rule is:
$$
x_{i+1} = x_i - \texttt{grad}(x_i) \times \texttt{lr} 
$$

The `Optimizer` method [<code>step()</code>][1] will apply its update rule to all the variables.

[1]: https://mitsuba3.readthedocs.io/en/latest/src/api_reference.html#mitsuba.ad.Optimizer.step

In [7]:
print(f"Before the gradient step: x={opt['x']}, y={opt['y']}")
opt.step()
print(f"After the gradient step:  x={opt['x']}, y={opt['y']}")

Before the gradient step: x=[1.0], y=[3.0]
After the gradient step:  x=[0.75], y=[2.5]


After performing the update rule, the `Optimizer` also resets the gradient values of its variables to `0.0` and ensures gradient computations are still enabled on all its variables. This guarantees that everything is ready for the next iteration of the optimization loop.
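The update rule and the gradient reset described above can be sketched in plain Python. This is an illustration under simplifying assumptions, not Mitsuba's implementation: values and gradients are plain floats stored in two dictionaries, the hypothetical `sgd_step` function stands in for `step()`, and the gradients are the ones obtained above for `z = x + 2 * y`.

```python
def sgd_step(values, grads, lr):
    """Apply x <- x - lr * grad(x) to every variable, then reset its gradient."""
    for key in values:
        values[key] = values[key] - lr * grads[key]
        # Clear the gradient so the next iteration starts from zero
        grads[key] = 0.0


values = {'x': 1.0, 'y': 3.0}
grads = {'x': 1.0, 'y': 2.0}   # dz/dx and dz/dy for z = x + 2 * y
sgd_step(values, grads, lr=0.25)
print(values)  # {'x': 0.75, 'y': 2.5}
print(grads)   # {'x': 0.0, 'y': 0.0}
```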

## Optimize scene parameters

In the context of differentiable rendering, we are interested in optimizing scene parameters exposed via the [<code>traverse()</code>][1] mechanism. While this could be done using the interface showcased above, the [<code>Optimizer</code>][2] class was specifically designed to handle the sharp bits that arise in this setting. For instance, the scene must always be notified of changed parameters, as it might need to update its internal state (e.g. rebuilding the BVH or recomputing shading normals).

For the sake of this How-to Guide, let's load a simple scene and perform the traversal to get access to the scene parameters.

[1]: https://mitsuba3.readthedocs.io/en/latest/src/api_reference.html#mitsuba.traverse
[2]: https://mitsuba3.readthedocs.io/en/latest/src/api_reference.html#mitsuba.ad.Optimizer

In [8]:
scene = mi.load_file('../scenes/cbox.xml')
params = mi.traverse(scene)

At construction, the `Optimizer` can be passed a `SceneParameters` object to work with.

In [9]:
opt = mi.ad.SGD(lr=0.25, params=params)
opt

SGD[
  variables = [],
  lr = {'default': 0.25},
  momentum = 0
]

At this point, the optimizer is empty, as we haven't indicated which scene parameters we would like to optimize. This can be done with the [<code>load()</code>][1] method, which accepts parameter names and/or regular expressions, making it possible to load many parameters at once.

[1]: https://mitsuba3.readthedocs.io/en/latest/src/api_reference.html#mitsuba.ad.Optimizer.load

In [10]:
opt.load('red.reflectance.value')

# Or using a regex to load all reflectance values from the different BSDFs
opt.load(r'.+\.reflectance\.value')

opt

SGD[
  variables = ['red.reflectance.value', 'gray.reflectance.value', 'white.reflectance.value', 'green.reflectance.value'],
  lr = {'default': 0.25},
  momentum = 0
]
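
Conceptually, loading parameters by regular expression amounts to filtering the keys of `params`. The sketch below applies Python's `re` module to a hypothetical subset of the parameter names used in this guide. Whether `load()` requires the pattern to match the whole key (as `re.fullmatch` does here) or only part of it is an assumption of this illustration.

```python
import re

# Hypothetical subset of the keys exposed by mi.traverse(scene)
keys = [
    'red.reflectance.value',
    'gray.reflectance.value',
    'white.reflectance.value',
    'green.reflectance.value',
    'light.emitter.radiance.value',
    'redwall.vertex_positions',
]

pattern = re.compile(r'.+\.reflectance\.value')
# Keep only the keys that the pattern matches in their entirety
selected = [k for k in keys if pattern.fullmatch(k)]
print(selected)
```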

As explained above, the loaded parameters are copied internally, so changing their values in the optimizer is not directly reflected in `params`.

In [11]:
opt['red.reflectance.value'] *= 0.5

print(f"params:   {params['red.reflectance.value']}")
print(f"optimize: {opt['red.reflectance.value']}")

params:   [[0.5700680017471313, 0.043013498187065125, 0.04437059909105301]]
optimize: [[0.2850340008735657, 0.021506749093532562, 0.022185299545526505]]


In order to propagate those changes to `params` (and to the `Scene` itself), the `Optimizer` class provides the [<code>update()</code>][1] method.

[1]: https://mitsuba3.readthedocs.io/en/latest/src/api_reference.html#mitsuba.ad.Optimizer.update

In [12]:
opt.update()

print(f"params:   {params['red.reflectance.value']}")
print(f"optimize: {opt['red.reflectance.value']}")

params:   [[0.2850340008735657, 0.021506749093532562, 0.022185299545526505]]
optimize: [[0.2850340008735657, 0.021506749093532562, 0.022185299545526505]]


Note that partial updates are also possible: pass a list of strings and/or a regex to specify which parameters should be updated.

It is also possible to create an optimizer variable whose key is the name of a scene parameter. In this case, during the call to `update()`, the `Optimizer` recognizes that this variable's value should be propagated to the `params` data structure.

In [13]:
opt['light.emitter.radiance.value'] = mi.Color3f(3.0)

print('Before the update:')
print(f"  params:   {params['light.emitter.radiance.value']}")
print(f"  optimize: {opt['light.emitter.radiance.value']}")
opt.update()
print('After the update:')
print(f"  params:   {params['light.emitter.radiance.value']}")
print(f"  optimize: {opt['light.emitter.radiance.value']}")

Before the update:
  params:   [[18.386999130249023, 13.987299919128418, 6.753570079803467]]
  optimize: [[3.0, 3.0, 3.0]]
After the update:
  params:   [[3.0, 3.0, 3.0]]
  optimize: [[3.0, 3.0, 3.0]]
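
The propagation performed by `update()`, including partial updates, can be pictured as copying selected optimizer entries back into `params` whenever their key names an existing scene parameter. The plain-Python sketch below illustrates that idea with dictionaries standing in for `SceneParameters` and the optimizer. The hypothetical `propagate` function and its `keys` argument are not Mitsuba's signature, the numbers are illustrative, and the scene notification performed by the real method is omitted.

```python
import re


def propagate(opt_vars, params, keys=None):
    """Copy optimizer values into `params` for keys that `params` knows.

    `keys` selects what to propagate: None (everything), a list of names,
    or a regular expression string. Returns the list of updated keys.
    """
    if keys is None:
        selected = list(opt_vars)
    elif isinstance(keys, str):
        selected = [k for k in opt_vars if re.fullmatch(keys, k)]
    else:
        selected = [k for k in keys if k in opt_vars]

    updated = []
    for k in selected:
        # Latent variables (keys unknown to `params`) are not propagated
        if k in params:
            params[k] = list(opt_vars[k])
            updated.append(k)
    return updated


params = {'red.reflectance.value': [0.5],
          'light.emitter.radiance.value': [18.0]}
opt_vars = {'red.reflectance.value': [0.25],
            'light.emitter.radiance.value': [3.0],
            'trans': [0.0, 0.0, 0.0]}

# Partial update selected by a regular expression
print(propagate(opt_vars, params, keys=r'.+\.reflectance\.value'))
# Full update: the latent variable 'trans' is skipped
print(propagate(opt_vars, params))
print(params)
```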


## Optimize latent variables

In more complex optimization scenarios, scene parameters might be described as a **function** of some other parameters. In such a scenario, we are interested in optimizing those other parameters instead of the scene parameters directly. For example, this is desirable when generating the vertex positions of a mesh using a neural network: we want to optimize the weights of the neural network, not the vertex positions themselves. Another example is a procedurally generated texture, for instance one produced by a physically based model that can be tweaked with a few parameters.

Such external parameters are called latent variables, and supporting them is at the core of the [<code>Optimizer</code>][1]'s design.

For a simpler example, let's consider the case where we aim to optimize the translation vector of a 3D mesh in our scene. Even from a convexity standpoint alone, optimizing those three translation values will be much easier than optimizing all the vertex positions simultaneously and hoping for the best.

Let's initialize our optimizer one more time, still passing the `params` object as an argument.

[1]: https://mitsuba3.readthedocs.io/en/latest/src/api_reference.html#mitsuba.ad.Optimizer

In [14]:
opt = mi.ad.SGD(lr=0.25, params=params)

We can then add a latent variable to the optimizer, similar to what we did in the first section of this How-to Guide.

In [15]:
opt['trans'] = mi.Vector3f(0, 0, 0)

In this scenario, it is the user's responsibility to propagate the changes of the latent variable to the scene parameters. For this, we define a dedicated update function in the following cell. Note that Mitsuba stores the vertex positions of a mesh as a flat (linear) array, so they must first be converted into 3D points to apply the translation, and then converted back into a flat array before being assigned to the scene parameters. For these operations, we use the `dr.unravel` and `dr.ravel` functions of Dr.Jit.

In [16]:
# Copy of our initial vertex positions (converted to 3D points)
initial_vertex_pos = dr.unravel(mi.Point3f, params['redwall.vertex_positions'])

# Now we define the update rule
def update_vertex_pos():
    # Create the translation transformation
    T = mi.Transform4f.translate(opt['trans'])
    
    # Apply the transformation to the vertex positions
    new_vertex_pos = T @ initial_vertex_pos

    # Flatten the vertex position array before assigning it to `params`
    params['redwall.vertex_positions'] = dr.ravel(new_vertex_pos)

With this function implemented, all we need to do is call it before every call to `opt.update()` in order to properly propagate the new translation value to the scene.

In [17]:
update_vertex_pos()
opt.update()

<div class="admonition important alert alert-block alert-info">
💭 Of course, in this simple example we could directly call the update method on <tt>params</tt> rather than on <tt>opt</tt>, but we want to support the case where other scene parameters are optimized at the same time.
</div>

<div class="admonition important alert alert-block alert-info">
💭 On top of propagating the changed values, performing this update also builds the computational graph that the automatic differentiation layer needs to later compute the gradients of the scene parameters with respect to the latent variable.
</div>
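
The flat memory layout handled by `dr.unravel` and `dr.ravel` in `update_vertex_pos()` can be illustrated without Dr.Jit. In the plain-Python sketch below, the hypothetical helpers `unravel3` and `ravel3` convert between a flat `[x0, y0, z0, x1, y1, z1, ...]` list and a list of 3D points, and a translation is applied in between. It mirrors the logic of the update function defined earlier, not the Dr.Jit implementation.

```python
def unravel3(flat):
    """Group a flat [x0, y0, z0, x1, ...] list into (x, y, z) tuples."""
    assert len(flat) % 3 == 0, "flat array length must be a multiple of 3"
    return [tuple(flat[i:i + 3]) for i in range(0, len(flat), 3)]


def ravel3(points):
    """Flatten a list of (x, y, z) tuples back into a single list."""
    return [c for p in points for c in p]


def translate(points, t):
    """Add the translation vector t to every point."""
    return [(x + t[0], y + t[1], z + t[2]) for (x, y, z) in points]


flat_positions = [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0]  # 3 vertices
trans = (0.5, -0.25, 2.0)

new_flat = ravel3(translate(unravel3(flat_positions), trans))
print(new_flat)  # [0.5, -0.25, 2.0, 1.5, -0.25, 2.0, 1.5, 0.75, 2.0]
```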

## Optimizer state

On top of carrying the variables to optimize, the [<code>Optimizer</code>][1] also holds an internal state used in its update rule. For instance, the momentum-based [<code>SGD</code>][2] optimizer tracks the velocity of the previous iteration to apply momentum in its [<code>step()</code>][3] method. In most cases, this state is stored on a per-parameter basis.

In some cases, it is useful to reset the state of an `Optimizer` (e.g. when an optimization schedule resizes the optimized volume grid). For this, we can use the [<code>reset()</code>][4] method, which zero-initializes the internal state associated with a specific parameter.

[1]: https://mitsuba3.readthedocs.io/en/latest/src/api_reference.html#mitsuba.ad.Optimizer
[2]: https://mitsuba3.readthedocs.io/en/latest/src/api_reference.html#mitsuba.ad.SGD
[3]: https://mitsuba3.readthedocs.io/en/latest/src/api_reference.html#mitsuba.ad.Optimizer.step
[4]: https://mitsuba3.readthedocs.io/en/latest/src/api_reference.html#mitsuba.ad.Optimizer.reset

In [18]:
opt.reset('trans')
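
To illustrate what per-parameter state and `reset()` mean, here is a plain-Python sketch of SGD with momentum. It assumes one common heavy-ball formulation, $v \leftarrow \mu v + \texttt{lr} \cdot g$ followed by $x \leftarrow x - v$, which may differ in detail from Mitsuba's `SGD`. The hypothetical `MomentumSGD` class keeps one velocity per key, and its `reset(key)` zeroes that velocity while leaving the value untouched.

```python
class MomentumSGD:
    """Hypothetical SGD optimizer with momentum and per-parameter state."""

    def __init__(self, lr, momentum):
        self.lr = lr
        self.momentum = momentum
        self.values = {}
        self.velocity = {}  # internal state, one entry per parameter

    def __setitem__(self, key, value):
        # A newly assigned variable starts with a zero velocity
        self.values[key] = value
        self.velocity[key] = 0.0

    def __getitem__(self, key):
        return self.values[key]

    def step(self, grads):
        for key, g in grads.items():
            # Heavy-ball update: accumulate a velocity, then move along it
            self.velocity[key] = self.momentum * self.velocity[key] + self.lr * g
            self.values[key] -= self.velocity[key]

    def reset(self, key):
        # Zero-initialize the state associated with a single parameter
        self.velocity[key] = 0.0


opt = MomentumSGD(lr=0.25, momentum=0.5)
opt['trans'] = 0.0
opt.step({'trans': 1.0})  # v = 0.25,  trans = -0.25
opt.step({'trans': 1.0})  # v = 0.375, trans = -0.625
opt.reset('trans')        # velocity back to 0, value unchanged
print(opt['trans'], opt.velocity['trans'])  # -0.625 0.0
```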

Another useful feature of the Mitsuba `Optimizer` is its ability to mask the state update depending on the presence of gradients in the variables. This can be useful with Monte Carlo simulations, where some optimized parameters might not receive any gradients at a given iteration. We found that in this situation, updating the optimizer state for those parameters degrades the optimization, as this iteration should rather be discarded for them. When constructing the `Optimizer`, the `mask_updates` argument (default: `False`) specifies whether to discard such updates when the gradients are zero.

In [19]:
opt = mi.ad.SGD(lr=0.25, mask_updates=True)
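
The effect of `mask_updates` can be sketched in the same plain-Python setting. Assuming, for this illustration, that masking is applied element-wise, entries whose gradient is exactly zero keep both their value and their momentum state from the previous iteration. The hypothetical `masked_momentum_step` function below compares both settings on a two-element parameter whose second entry received no gradient.

```python
def masked_momentum_step(values, velocity, grads, lr, momentum, mask_updates):
    """One heavy-ball SGD step over lists; optionally skip zero-gradient entries."""
    for i, g in enumerate(grads):
        if mask_updates and g == 0.0:
            # No gradient this iteration: discard the update for this entry
            continue
        velocity[i] = momentum * velocity[i] + lr * g
        values[i] -= velocity[i]


for mask in (False, True):
    values = [1.0, 1.0]
    velocity = [0.5, 0.5]   # state accumulated in earlier iterations
    grads = [1.0, 0.0]      # the second entry received no gradient
    masked_momentum_step(values, velocity, grads, lr=0.25, momentum=0.5,
                         mask_updates=mask)
    print(f"mask_updates={mask}: values={values}, velocity={velocity}")

# mask_updates=False: values=[0.5, 0.75], velocity=[0.5, 0.25]
# mask_updates=True: values=[0.5, 1.0], velocity=[0.5, 0.5]
```

Without masking, the second entry still moves (from 1.0 to 0.75) because of the accumulated momentum, and its velocity decays, even though it received no gradient in this iteration.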

Finally, the Mitsuba implementation of the `SGD` optimizer also supports per-parameter learning rates, which is useful to control the magnitude of the gradient step taken on individual parameters. This is achieved with the `set_learning_rate()` method, which takes an optional `key` argument to specify the parameter whose learning rate should be set. As in the following cell, it can also be given a dictionary mapping parameter names to learning rates.

In [20]:
opt['x'] = mi.Float(1.0)
opt['y'] = mi.Float(1.0)

opt.set_learning_rate({'x': 0.125, 'y': 0.25})
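
A per-parameter learning rate amounts to looking up the step size by parameter name and falling back to a default value. The sketch below mirrors the `{'default': 0.25}` entry shown in the optimizer's printed representation, but the hypothetical `learning_rate_for` helper and its fallback logic are assumptions, not Mitsuba's code.

```python
lr = {'default': 0.25}
# Similar in spirit to opt.set_learning_rate({'x': 0.125, 'y': 0.25})
lr.update({'x': 0.125, 'y': 0.25})


def learning_rate_for(key):
    """Return the key-specific learning rate, or the default one."""
    return lr.get(key, lr['default'])


values = {'x': 1.0, 'y': 1.0, 'z': 1.0}
grads = {'x': 1.0, 'y': 1.0, 'z': 1.0}
for k in values:
    # Plain SGD step using the per-parameter learning rate
    values[k] -= learning_rate_for(k) * grads[k]
print(values)  # {'x': 0.875, 'y': 0.75, 'z': 0.75}
```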