# Diffractive splitter

The diffractive splitter challenge entails designing a metasurface that evenly splits a normally incident plane wave into a 7x7 array of beams. Light is incident from the ambient, and both the substrate and the metasurface pattern are silicon oxide. The operating wavelength is 632.8 nm and the unit cell pitch is 7.2 microns, which places the outermost on-axis orders (±3) at diffraction angles of approximately ±15 degrees. The challenge is based on the slide deck "[Design and rigorous analysis of a non-paraxial diffractive beamsplitter](https://www.lighttrans.com/fileadmin/shared/UseCases/Application_UC_Rigorous%20Analysis%20of%20Non-paraxial%20Diffractive%20Beam%20Splitter.pdf)" retrieved from the LightTrans website.
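
The quoted angle can be cross-checked with the grating equation for normal incidence, sin(θ) = sqrt(i² + j²) · λ / Λ, where (i, j) is the diffraction order, λ the wavelength, and Λ the pitch. The sketch below assumes a helium-neon wavelength of 632.8 nm and that angles are measured in air (refractive index 1); the helper `order_angle_deg` is illustrative and not part of the gym.

```python
import math

WAVELENGTH_UM = 0.6328  # Assumed helium-neon wavelength, in microns.
PITCH_UM = 7.2  # Unit cell pitch from the challenge description, in microns.


def order_angle_deg(i, j, wavelength=WAVELENGTH_UM, pitch=PITCH_UM):
    """Polar angle in air (index 1, assumed) of diffraction order (i, j)."""
    sin_theta = math.hypot(i, j) * wavelength / pitch
    if sin_theta > 1:
        raise ValueError(f"Order ({i}, {j}) is evanescent.")
    return math.degrees(math.asin(sin_theta))


# A 7x7 splitter targets orders -3..3 along each axis.
print(f"Outermost on-axis order (3, 0): {order_angle_deg(3, 0):.1f} degrees")
print(f"Corner order (3, 3): {order_angle_deg(3, 3):.1f} degrees")
```

The corner orders propagate at noticeably larger angles than the on-axis ones, which is one reason the design is described as non-paraxial.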
    

## Simulating an existing design

We'll begin by loading, visualizing, and simulating existing designs extracted from LightTrans material (slide 12).

In [None]:
import matplotlib.pyplot as plt
import numpy as onp


def load_design(name):
    path = f"../../reference_designs/diffractive_splitter/{name}.csv"
    coarse_array = onp.genfromtxt(path, delimiter=",")
    return onp.kron(coarse_array, onp.ones((10, 10)))


names = ["device1", "device2", "device3"]
designs = [load_design(name) for name in names]

plt.figure(figsize=(8, 3))
for i, design in enumerate(designs):
    ax = plt.subplot(1, 3, i + 1)
    ax.imshow(design, cmap="gray")
    ax.set_xticks([])
    ax.set_yticks([])
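
The `onp.kron` call in `load_design` upsamples the coarse pattern read from the CSV file: taking the Kronecker product with a block of ones turns every coarse pixel into a 10x10 block of identical values. A small illustration with a made-up 2x2 pattern and a 3x3 block:

```python
import numpy as onp

# Made-up coarse pattern; the real designs are read from CSV files.
toy_coarse = onp.array([[0.0, 1.0], [1.0, 0.0]])
toy_fine = onp.kron(toy_coarse, onp.ones((3, 3)))

# The same result is obtained by repeating each pixel along both axes.
toy_repeated = onp.repeat(onp.repeat(toy_coarse, 3, axis=0), 3, axis=1)

print(toy_fine.shape)
print(onp.array_equal(toy_fine, toy_repeated))
```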

In [None]:
from invrs_gym.challenges.diffract import splitter_challenge

challenge = splitter_challenge.diffractive_splitter()

While several challenges involve only the design of two-dimensional patterns (with a `Density2DArray` being the optimization variable), the diffractive splitter degrees of freedom include both the metasurface pattern and several film thicknesses, in the form of a `BoundedArray`.

In [None]:
import jax

params = challenge.component.init(jax.random.PRNGKey(0))
for key, value in params.items():
    print(f"Variable {key}: {type(value)}")

We'll simulate a reference design by overwriting the `density` entry in the `params` dict while leaving the thicknesses unchanged; their default values match those from the LightTrans example. We then simulate the device using the `component.response` method.

In [None]:
import dataclasses

params["density"] = dataclasses.replace(params["density"], array=load_design("device1"))
response, aux = challenge.component.response(params)
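
`dataclasses.replace` returns a copy of a dataclass instance with the named fields swapped out, leaving the original instance and all other fields untouched. The sketch below shows this behavior on `ToyDensity`, a hypothetical stand-in that is not the gym's density class:

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyDensity:
    """Hypothetical stand-in for a density container; not the gym's class."""

    array: tuple
    lower_bound: float = 0.0
    upper_bound: float = 1.0


toy_original = ToyDensity(array=(0.5, 0.5))
toy_updated = dataclasses.replace(toy_original, array=(0.0, 1.0))

# Only `array` differs; the bounds carry over from the original instance.
print(toy_original.array, toy_updated.array)
print(toy_updated.lower_bound, toy_updated.upper_bound)
```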

Now let's plot the diffraction efficiency for each order. We use the `extract_orders_for_splitting` function, and get the efficiency for a 9x9 array of beams (even though this design is for a 7x7 splitter). This will let us see how the diffraction efficiency drops off for orders beyond those targeted by the design.

In [None]:
plt.figure(figsize=(4, 3))

splitting = splitter_challenge.extract_orders_for_splitting(
    response.transmission_efficiency,
    response.expansion,
    splitting=(9, 9),
    polarization="TM",
)

ax = plt.subplot(111)
im = plt.imshow(splitting * 100, cmap="coolwarm")
ax.set_xticks(onp.arange(9))
ax.set_yticks(onp.arange(9))
ax.set_xticklabels(range(-4, 5))
ax.set_yticklabels(range(-4, 5))
plt.colorbar(im)
im.set_clim([0, onp.amax(splitting * 100)])
ax.set_title("device1\nDiffraction efficiency (%)")
_ = ax.set_ylim(ax.get_ylim()[::-1])
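
Conceptually, the 9x9 view collects the efficiencies of orders -4 to 4 along each axis on a grid whose center is the zeroth order, which is why the tick labels run from -4 to 4. The sketch below illustrates that arrangement with made-up data; `central_orders` is a hypothetical helper and does not reflect how `extract_orders_for_splitting` is implemented.

```python
def central_orders(efficiency_by_order, size):
    """Arrange efficiencies of orders -k..k (k = size // 2) on a square grid.

    `efficiency_by_order` maps (i, j) order tuples to efficiencies. Orders
    missing from the mapping are treated as carrying no power.
    """
    k = size // 2
    return [
        [efficiency_by_order.get((i, j), 0.0) for j in range(-k, k + 1)]
        for i in range(-k, k + 1)
    ]


# Made-up data: a perfectly uniform 3x3 splitter with 90% total efficiency.
toy_orders = {(i, j): 0.1 for i in range(-1, 2) for j in range(-1, 2)}
toy_view = central_orders(toy_orders, size=5)

# Viewing more orders than were targeted shows a ring of zeros at the edge.
for row in toy_view:
    print(" ".join(f"{value:.1f}" for value in row))
```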

This device is not a particularly good one, as most of the power ends up in the zeroth order. This is also reported in the LightTrans material, and it shows up in the metrics computed by the challenge `metrics` method.

In [None]:
print("Challenge metrics:")
for key, value in challenge.metrics(response, params=params, aux=aux).items():
    print(f"    {key} = {value:.4f}")
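
To make the metric names concrete, the sketch below computes a total efficiency and a uniformity error from a grid of per-order efficiencies. It assumes the commonly used definition of uniformity error, (max - min) / (max + min); the gym's `metrics` method remains the authoritative definition, and the `zeroth_order_efficiency` key and the data are made up for illustration.

```python
def toy_splitter_metrics(grid):
    """Summary metrics for a square grid of per-order efficiencies."""
    values = [value for row in grid for value in row]
    center = len(grid) // 2
    return {
        "total_efficiency": sum(values),
        "zeroth_order_efficiency": grid[center][center],
        # Assumed definition: 0 for a perfectly uniform splitter.
        "uniformity_error": (max(values) - min(values))
        / (max(values) + min(values)),
    }


# Made-up 3x3 example in which the zeroth order dominates.
toy_grid = [
    [0.05, 0.05, 0.05],
    [0.05, 0.30, 0.05],
    [0.05, 0.05, 0.05],
]
for key, value in toy_splitter_metrics(toy_grid).items():
    print(f"    {key} = {value:.4f}")
```

A dominant zeroth order inflates the maximum efficiency and therefore the uniformity error, even when the total efficiency looks reasonable.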

Let's take a look at the remaining devices, which have higher reported performance.

In [None]:
plt.figure(figsize=(8, 3))
for i, name in enumerate(["device2", "device3"]):
    params["density"] = dataclasses.replace(params["density"], array=load_design(name))
    response, aux = challenge.component.response(params)

    splitting = splitter_challenge.extract_orders_for_splitting(
        response.transmission_efficiency,
        response.expansion,
        splitting=(9, 9),
        polarization="TM",
    )

    ax = plt.subplot(1, 2, i + 1)
    im = plt.imshow(splitting * 100, cmap="coolwarm")
    ax.set_xticks(onp.arange(9))
    ax.set_yticks(onp.arange(9))
    ax.set_xticklabels(range(-4, 5))
    ax.set_yticklabels(range(-4, 5))
    plt.colorbar(im)
    im.set_clim([0, onp.amax(splitting * 100)])
    ax.set_title(f"{name}\nDiffraction efficiency (%)")
    ax.set_ylim(ax.get_ylim()[::-1])

## Diffractive splitter optimization

Now let's optimize a diffractive splitter. Again we obtain initial random parameters and define the loss function. The loss function will also return the response and the metrics, which will let us see how performance improves as we optimize.

In [None]:
params = challenge.component.init(jax.random.PRNGKey(0))


def loss_fn(params):
    response, aux = challenge.component.response(params)
    loss = challenge.loss(response)
    metrics = challenge.metrics(response, params=params, aux=aux)
    return loss, (response, metrics)

Before optimizing, let's investigate the gradients of the loss with respect to the optimizable parameters. Many optimization methods succeed only if these gradients are comparable in magnitude; this holds in particular for the L-BFGS-B scheme used in other challenges.

In [None]:
grad, _ = jax.grad(loss_fn, has_aux=True)(params)

plt.figure(figsize=(4, 3))
ax = plt.subplot(111)
im = ax.imshow(grad["density"].array)
plt.colorbar(im)
ax.set_title("Gradient wrt density")

print(f"Gradient wrt grating thickness: {grad['thickness_grating'].array}")

The gradients with respect to the density and the grating thickness differ by roughly a factor of 1e4. To bring them to the same scale, we can change the scale of the `density`. By default, a value of `0` corresponds to the absence of material and `1` to the presence of material, as stored in the bounds attributes of the `Density2DArray`.

In [None]:
print(
    f"Original density: "
    f"lower bound (void) = {params['density'].lower_bound:.4f}, "
    f"upper bound (solid) = {params['density'].upper_bound:.4f}, "
    f"mean value = {onp.mean(params['density'].array):.4f}"
)

To change the density scale, we'll simply modify its upper and lower bounds, and also rescale the `array` attribute accordingly. Note that regardless of the scale we choose, the simulation will map the upper bound value to material presence, and the lower bound value to material absence.

In [None]:
def rescale_density(density, scale):
    # Shift so that void maps to zero, then stretch the range by `scale` so
    # that the array and the new upper bound stay consistent for any bounds.
    rescaled_array = (density.array - density.lower_bound) * scale
    return dataclasses.replace(
        density,
        array=rescaled_array,
        lower_bound=0,
        upper_bound=(density.upper_bound - density.lower_bound) * scale,
    )


params["density"] = rescale_density(params["density"], scale=1e-2)
print(
    f"Rescaled density: "
    f"lower bound (void) = {params['density'].lower_bound:.4e}, "
    f"upper bound (solid) = {params['density'].upper_bound:.4e}, "
    f"mean value = {onp.mean(params['density'].array):.4e}"
)
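
The effect of rescaling on the gradient follows from the chain rule: if the stored value p spans [0, s] and represents a physical density d = p / s, then dL/dp = (1 / s) · dL/dd. Shrinking the range with a scale of 1e-2 therefore magnifies the density gradient by a factor of 100, while the thickness gradient is unaffected. The sketch below checks this with finite differences on a made-up scalar loss; all names are illustrative.

```python
def toy_loss(density):
    """Made-up scalar loss of a single physical density value in [0, 1]."""
    return (density - 0.3) ** 2


def grad_wrt_stored_value(density, scale):
    """Central finite-difference gradient with respect to the stored value.

    The stored value lies in [0, scale] and maps to density = stored / scale.
    """
    stored = density * scale
    eps = 1e-6 * scale
    upper = toy_loss((stored + eps) / scale)
    lower = toy_loss((stored - eps) / scale)
    return (upper - lower) / (2 * eps)


toy_grad_original = grad_wrt_stored_value(0.8, scale=1.0)
toy_grad_rescaled = grad_wrt_stored_value(0.8, scale=1e-2)
print(f"Gradient ratio: {toy_grad_rescaled / toy_grad_original:.2f}")
```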

Then, we can take another look at the gradients.

In [None]:
grad, _ = jax.grad(loss_fn, has_aux=True)(params)

plt.figure(figsize=(4, 3))
ax = plt.subplot(111)
im = ax.imshow(grad["density"].array)
plt.colorbar(im)
ax.set_title("Gradient wrt density")

print(f"Gradient wrt grating thickness: {grad['thickness_grating'].array}")

The values are now a bit more comparable. Note that the choice of scale can dramatically impact the optimization result, and it may be worthwhile to experiment with several values to find one that works well.

To design the diffractive splitter we'll use the `density_lbfgsb` optimizer from the [invrs-opt](https://github.com/invrs-io/opt) package. We first initialize the optimizer state and define the `step_fn` that performs a single optimization step, and then call it repeatedly to obtain an optimized design.

In [None]:
import invrs_opt

opt = invrs_opt.density_lbfgsb(beta=4)
state = opt.init(params)  # Initialize optimizer state using the initial parameters.


@jax.jit
def step_fn(state):
    params = opt.params(state)
    (value, (response, metrics)), grad = jax.value_and_grad(loss_fn, has_aux=True)(
        params
    )
    state = opt.update(grad=grad, value=value, params=params, state=state)
    return state, (params, value, response, metrics)


# Call `step_fn` repeatedly to optimize, and store the results of each evaluation.
metrics_values = []
for _ in range(60):
    state, (params, value, response, metrics) = step_fn(state)
    metrics_values.append(metrics)
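
The loop above follows an `init` / `params` / `update` pattern in which all optimizer bookkeeping lives in an explicit `state`. To show the pattern in isolation, the sketch below implements a toy optimizer with the same three methods. Plain gradient descent is swapped in for L-BFGS-B to keep it short; `ToyGradientDescent` is hypothetical and says nothing about how invrs-opt is implemented.

```python
class ToyGradientDescent:
    """Hypothetical optimizer exposing `init`, `params`, and `update`."""

    def __init__(self, learning_rate):
        self.learning_rate = learning_rate

    def init(self, params):
        # The state bundles everything needed to continue the optimization.
        return {"step": 0, "params": params}

    def params(self, state):
        return state["params"]

    def update(self, *, grad, value, params, state):
        del value  # Unused by plain gradient descent.
        new_params = params - self.learning_rate * grad
        return {"step": state["step"] + 1, "params": new_params}


def toy_value_and_grad(x):
    """Made-up quadratic loss with its analytic gradient."""
    return (x - 2.0) ** 2, 2.0 * (x - 2.0)


toy_opt = ToyGradientDescent(learning_rate=0.25)
toy_state = toy_opt.init(0.0)
for _ in range(20):
    toy_x = toy_opt.params(toy_state)
    toy_value, toy_grad = toy_value_and_grad(toy_x)
    toy_state = toy_opt.update(
        grad=toy_grad, value=toy_value, params=toy_x, state=toy_state
    )

print(f"Final parameter: {toy_opt.params(toy_state):.4f}")
```

Keeping the state explicit is what allows a step function like `step_fn` to be a pure function of its input, which is convenient when compiling it with `jax.jit`.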

Now let's visualize the optimization trajectory, the final design, and its performance.

In [None]:
from skimage import measure

plt.figure(figsize=(12, 3))

ax = plt.subplot(131)
ax.plot(
    [m["total_efficiency"] * 100 for m in metrics_values], label="Total efficiency (%)"
)
ax.plot(
    [m["uniformity_error"] * 100 for m in metrics_values], label="Uniformity error (%)"
)
ax.set_xlabel("Optimization step")
ax.legend()

ax = plt.subplot(132)
im = plt.imshow(params["density"].array, cmap="gray")
im.set_clim([params["density"].lower_bound, params["density"].upper_bound])

contours = measure.find_contours(onp.asarray(params["density"].array))
for c in contours:
    ax.plot(c[:, 1], c[:, 0], "r")
ax.set_xticks([])
ax.set_yticks([])
ax.set_title("Optimized design")

splitting = splitter_challenge.extract_orders_for_splitting(
    response.transmission_efficiency,
    response.expansion,
    splitting=(9, 9),
    polarization="TM",
)

ax = plt.subplot(133)
im = plt.imshow(splitting * 100, cmap="coolwarm")
ax.set_xticks(onp.arange(9))
ax.set_yticks(onp.arange(9))
ax.set_xticklabels(range(-4, 5))
ax.set_yticklabels(range(-4, 5))
plt.colorbar(im)
im.set_clim([0, onp.amax(splitting * 100)])
ax.set_title("Diffraction efficiency (%)")
_ = ax.set_ylim(ax.get_ylim()[::-1])