# Gradient Computation

This example shows how to compute the gradient of an objective function with respect to the layer thicknesses of a thin-film stack, using JAX automatic differentiation (`jax.grad`).

First, we import the packages required for this example.

In [1]:
import jax
import jax.numpy as jnp

import jaxlayerlumos as jll
import jaxlayerlumos.utils_spectra as jll_utils_spectra
import jaxlayerlumos.utils_materials as jll_utils_materials
import jaxlayerlumos.utils_units as jll_utils_units

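We define the layer stack, listed from the incidence medium (Air) to the substrate (FusedSilica), and a single incidence angle of 0, i.e., normal incidence.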

In [2]:
materials = ["Air", "TiO2", "Ag", "TiO2", "FusedSilica"]
angles = jnp.array([0.0])

Next, we retrieve frequencies spanning visible light, convert them to the corresponding wavelengths, and look up the ($n$, $k$) values of each material at these frequencies.

In [3]:
frequencies = jll_utils_spectra.get_frequencies_visible_light()
wavelengths = jll_utils_spectra.convert_frequencies_to_wavelengths(frequencies)

n_k = jll_utils_materials.get_n_k(materials, frequencies)
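The frequency-to-wavelength conversion presumably relies on the vacuum relation $\lambda = c / f$. The sketch below reproduces that conversion in plain JAX under this assumption; the helper names and the 380 to 780 nm sampling are illustrative and are not part of `jaxlayerlumos`.

```python
import jax.numpy as jnp

# Speed of light in vacuum (m/s), exact by SI definition.
SPEED_OF_LIGHT = 299_792_458.0


def wavelengths_to_frequencies(wavelengths_m):
    """Hypothetical helper: f = c / lambda, with lambda in meters."""
    return SPEED_OF_LIGHT / wavelengths_m


def frequencies_to_wavelengths(frequencies_hz):
    """Hypothetical helper: lambda = c / f, with f in hertz."""
    return SPEED_OF_LIGHT / frequencies_hz


# Illustrative visible range sampled in wavelength, from 380 nm to 780 nm.
wavelengths_nm = jnp.linspace(380.0, 780.0, 5)
frequencies = wavelengths_to_frequencies(wavelengths_nm * 1e-9)
recovered_nm = frequencies_to_wavelengths(frequencies) / 1e-9

print(frequencies)   # roughly 7.9e14 Hz down to 3.8e14 Hz
print(recovered_nm)
```

Because frequency and wavelength are inversely proportional, frequencies sampled uniformly are not uniformly spaced in wavelength, which matters when a spectrum is averaged with `jnp.mean`.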

Then, we define the objective function to differentiate. Given `thicknesses` (in nm), it returns the reflectance averaged over both polarizations (TE and TM) and over all visible frequencies. In this example, the gradient is taken with respect to the thicknesses of the three inner layers: TiO$_2$, Ag, and TiO$_2$. The first and last layers are semi-infinite, so the function pads the input with a thickness of 0 for each of them before converting the values from nanometers.

In [4]:
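def objective(thicknesses):
    # Pad with 0 for the semi-infinite first (Air) and last (FusedSilica) layers.
    thicknesses = jnp.concatenate([
        jnp.array([0.0]),
        thicknesses,
        jnp.array([0.0])
    ], axis=0)
    # Scale the thicknesses from nanometers to SI units.
    thicknesses *= jll_utils_units.get_nano()

    # Keep the TE and TM reflectances; the other two outputs are not needed here.
    R_TE, _, R_TM, _ = jll.stackrt(n_k, thicknesses, frequencies, thetas=angles)

    # Select the spectra for the first (and only) incidence angle.
    R_TE = R_TE[0]
    R_TM = R_TM[0]

    # Average the two polarizations, then average over frequency.
    spectrum = (R_TE + R_TM) / 2
    avg_spectrum = jnp.mean(spectrum)

    return avg_spectrum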
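To see what `jax.grad` differentiates without depending on the library, the sketch below builds a much simpler stand-in: a single lossless film on a substrate at normal incidence, whose reflectance follows the standard Airy thin-film formula. The constant refractive indices, the wavelength grid, and the names `film_reflectance` and `toy_objective` are assumptions for illustration; `jll.stackrt` handles multilayer stacks, dispersive complex indices, and oblique angles, which this sketch does not.

```python
import jax
import jax.numpy as jnp

# Use 64-bit floats so the finite-difference comparison is not dominated by rounding.
jax.config.update("jax_enable_x64", True)

# Assumed constant, lossless refractive indices (illustrative values only).
N_AMBIENT = 1.0      # incidence medium, air-like
N_FILM = 2.4         # film, roughly TiO2-like
N_SUBSTRATE = 1.45   # substrate, roughly fused-silica-like

# Assumed wavelength grid over the visible range, in nm.
WAVELENGTHS_NM = jnp.linspace(400.0, 700.0, 31)


def film_reflectance(thickness_nm, wavelength_nm):
    """Normal-incidence reflectance of one film on a substrate (Airy formula)."""
    r01 = (N_AMBIENT - N_FILM) / (N_AMBIENT + N_FILM)
    r12 = (N_FILM - N_SUBSTRATE) / (N_FILM + N_SUBSTRATE)
    # Phase accumulated in a single pass through the film.
    delta = 2.0 * jnp.pi * N_FILM * thickness_nm / wavelength_nm
    phase = jnp.exp(2j * delta)
    r = (r01 + r12 * phase) / (1.0 + r01 * r12 * phase)
    # |r|^2, written so the output is explicitly real.
    return jnp.real(r * jnp.conj(r))


def toy_objective(thickness_nm):
    """Reflectance averaged over the grid (TE and TM coincide at normal incidence)."""
    return jnp.mean(film_reflectance(thickness_nm, WAVELENGTHS_NM))


grad_toy = jax.grad(toy_objective)

# Compare the autodiff derivative with a central finite difference at 100 nm.
d = 100.0
h = 1e-4
finite_difference = (toy_objective(d + h) - toy_objective(d - h)) / (2.0 * h)
print(grad_toy(d), finite_difference)
```

The input and output are real even though the reflection coefficient is complex, which is the situation `jax.grad` supports directly; the notebook's `objective` has the same real-in, real-out structure.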

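The next cell evaluates the gradient with respect to the three thicknesses at TiO$_2$ (100 nm), Ag (30 nm), and TiO$_2$ (40 nm). Because the inputs are in nanometers, each entry is the change in average reflectance per nanometer of the corresponding layer.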

In [5]:
grad_objective = jax.grad(objective)(jnp.array([100.0, 30.0, 40.0]))

print("Gradient over TiO2 (100 nm), Ag (30 nm), TiO2 (40 nm)")
print(grad_objective)

Gradient over TiO2 (100 nm), Ag (30 nm), TiO2 (40 nm)
[-0.00083524  0.01468192  0.00346658]
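A common sanity check for a gradient like the one printed above is a comparison against central finite differences. The helper below is a generic sketch; its name and step size `h` are assumptions, and it is demonstrated on a simple quadratic with a known gradient rather than on `objective`. Calling it as `central_difference_grad(objective, jnp.array([100.0, 30.0, 40.0]))` should give values close to the printed gradient, although in JAX's default 32-bit precision the step `h` may need tuning.

```python
import jax
import jax.numpy as jnp


def central_difference_grad(f, x, h=1e-3):
    """Approximate the gradient of scalar-valued f at x with central differences."""
    # One perturbation direction per coordinate (rows of the identity matrix).
    basis = jnp.eye(x.shape[0], dtype=x.dtype)
    # (f(x + h e_i) - f(x - h e_i)) / (2h) for every coordinate i.
    return jax.vmap(lambda e: (f(x + h * e) - f(x - h * e)) / (2.0 * h))(basis)


def quadratic(x):
    # Stand-in objective with a known gradient: 2 * [1, 2, 3] * x.
    weights = jnp.array([1.0, 2.0, 3.0])
    return jnp.sum(weights * x**2)


x0 = jnp.array([100.0, 30.0, 40.0])
exact = jax.grad(quadratic)(x0)
# For a quadratic, central differences are exact up to rounding, so a large step is fine.
approx = central_difference_grad(quadratic, x0, h=0.5)
print(exact)   # [200. 120. 240.]
print(approx)
```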


The next cell evaluates the same gradient for thinner layers: TiO$_2$ (10 nm), Ag (5 nm), and TiO$_2$ (12 nm).

In [6]:
grad_objective = jax.grad(objective)(jnp.array([10.0, 5.0, 12.0]))

print("Gradient over TiO2 (10 nm), Ag (5 nm), TiO2 (12 nm)")
print(grad_objective)

Gradient over TiO2 (10 nm), Ag (5 nm), TiO2 (12 nm)
[ 0.0019731  -0.00641389  0.00573745]
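The two cells above evaluate the same gradient at two different thickness vectors. `jax.vmap` can batch such evaluations into a single call. Whether every operation inside `jll.stackrt` supports batching is not verified here, so the sketch uses a hypothetical smooth stand-in objective; the function name `stand_in_objective` and its periods are made up for illustration.

```python
import jax
import jax.numpy as jnp


def stand_in_objective(thicknesses_nm):
    """Hypothetical smooth objective of three layer thicknesses (in nm)."""
    # Periodic in each thickness, loosely mimicking thin-film interference.
    periods = jnp.array([115.0, 60.0, 115.0])
    return jnp.mean(jnp.cos(2.0 * jnp.pi * thicknesses_nm / periods) ** 2)


# Each row is one candidate stack, matching the two points evaluated above.
candidates = jnp.array([
    [100.0, 30.0, 40.0],
    [10.0, 5.0, 12.0],
])

# vmap maps jax.grad over the leading (batch) axis: one gradient per row.
batched_grad = jax.vmap(jax.grad(stand_in_objective))
gradients = batched_grad(candidates)
print(gradients.shape)  # (2, 3)
```

Each row of `gradients` matches what `jax.grad(stand_in_objective)` returns for the corresponding row of `candidates`; batching only changes how many points are evaluated per call.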
