In [None]:
import sys
sys.path.append('../')

import yaml
from typing import Union
import jax
import jax.numpy as jnp
import optax

import matplotlib.pyplot as plt
import numpy as onp
from skimage import measure
import gifcm

import mesher
import utils
# import losses
import projections

import invrs_gym
from totypes import types
# Short alias for the project's extent type; used below to build the bounding box.
_Ext = utils.Extent

In [None]:
# Read the run settings from the YAML file that sits next to this notebook.
with open("./settings.yaml", "r") as file:
  config = yaml.safe_load(file)

# Pull out the sections that the rest of the notebook refers to.
mesh_cfg, bbox_cfg, cons_cfg = (
    config['DOM_MESH'],
    config['DOM_BBOX'],
    config['CONSTRAINTS'],
)

In [None]:
# Bounding box built from the x/y limits given in the settings.
dom_bbox = mesher.BoundingBox(
    x=_Ext(bbox_cfg['x_min'], bbox_cfg['x_max']),
    y=_Ext(bbox_cfg['y_min'], bbox_cfg['y_max']),
)

# Mesh over that bounding box with the configured element counts.
dom_mesh = mesher.Mesher(
    nelx=mesh_cfg['nelx'],
    nely=mesh_cfg['nely'],
    bounding_box=dom_bbox,
)

In [None]:
# Size of one element along x: bounding-box length over the element count.
# NOTE(review): the name implies nanometres — confirm the settings use nm.
mesh_resolution_nm = dom_bbox.lx / dom_mesh.nelx

In [None]:
# Display the computed resolution as the cell output.
mesh_resolution_nm

In [None]:
# Build the lightweight ceviche waveguide-bend challenge at the mesh resolution.
# NOTE(review): mesh_resolution_nm comes from a true division and is a float —
# confirm the challenge accepts a non-integer resolution.
challenge = invrs_gym.challenges.ceviche_lightweight_waveguide_bend(
    resolution_nm=mesh_resolution_nm,
)

In [None]:
# Fixed seed so the random initial design is reproducible.
key = jax.random.PRNGKey(seed=27)

# One design variable per mesh element, drawn close to zero so that the
# sigmoid below maps them to densities close to 0.5.
# NOTE(review): the bounds are asymmetric (-1e-3 vs +1e-4) — confirm intended.
init_design = jax.random.uniform(
    key,
    shape=(dom_mesh.num_elems,),
    minval=-0.001,
    maxval=0.0001,
)

# Map the design variables to (0, 1) and lay them out on the (nelx, nely) grid.
init_density = jnp.reshape(jax.nn.sigmoid(init_design),
                           (dom_mesh.nelx, dom_mesh.nely))

# Wrap the density in the container type that is passed to the challenge.
dens_array = types.Density2DArray(
    array=init_density,
    lower_bound=0.,
    upper_bound=1.,
)

In [None]:
# Optimisation hyper-parameters passed to optimize_design below.
num_epochs = 25  # number of Adam steps
lr = 2.e-1       # Adam learning rate

In [None]:
def optimize_design(init_guess: jnp.ndarray,
                    num_epochs: int,
                    lr: float,
                    plot_interval: int = 5) -> tuple[jnp.ndarray, dict]:
  """Optimize the design variables with Adam and record the convergence history.

  Each epoch maps the design variables to a density (sigmoid followed by
  `projections.threshold_filter`), evaluates the module-level `challenge` on
  it, and applies one Adam update using the gradient of the challenge loss.

  Args:
    init_guess: Initial design variables; must hold `nelx * nely` entries of
      the module-level `dom_mesh` so they can be reshaped onto the grid.
    num_epochs: Number of optimisation steps; must be at least 1.
    lr: Learning rate of the Adam optimizer.
    plot_interval: Plot the density every `plot_interval` epochs. A value of
      0 disables plotting.

  Returns:
    A tuple `(density, convg_history)`. `density` is the `(nelx, nely)`
    density evaluated in the last epoch, i.e. before that epoch's parameter
    update was applied. `convg_history` is a dict of per-epoch lists with keys
    'epoch', 'loss', 'density', 'aux', 'response', 'metrics' and 'distance'.

  Raises:
    ValueError: If `num_epochs` is smaller than 1.
  """
  import dataclasses  # local import keeps this cell self-contained

  if num_epochs < 1:
    # Without at least one epoch there is no density to return.
    raise ValueError(f'num_epochs must be at least 1, got {num_epochs}')

  optimizer = optax.adam(learning_rate=lr)
  opt_state = optimizer.init(init_guess)
  opt_params = init_guess

  def loss_wrapper(opt_params: jnp.ndarray):
    """Returns the challenge loss and auxiliary outputs for `opt_params`."""
    density = jax.nn.sigmoid(opt_params)
    density = projections.threshold_filter(density).reshape((dom_mesh.nelx,
                                                             dom_mesh.nely))
    # Build a fresh container instead of mutating the global `dens_array`:
    # writing a traced array into module state would leak a JAX tracer.
    params = dataclasses.replace(dens_array, array=density)

    response, aux = challenge.component.response(params)
    loss = challenge.loss(response)
    distance = challenge.distance_to_target(response)
    metrics = challenge.metrics(response, params, aux)
    return loss, (density, response, distance, metrics, aux)

  # The transformed function does not change between epochs; build it once.
  loss_and_grad = jax.value_and_grad(loss_wrapper, has_aux=True)

  convg_history = {'epoch': [], 'loss': [], 'density': [], 'aux': [],
                   'response': [], 'metrics': [], 'distance': []}
  for epoch in range(num_epochs):
    (loss, (density, response, distance, metrics, aux)), grad_loss = (
        loss_and_grad(opt_params))

    updates, opt_state = optimizer.update(grad_loss, opt_state)
    opt_params = optax.apply_updates(opt_params, updates)

    convg_history['epoch'].append(epoch)
    convg_history['loss'].append(loss)
    convg_history['density'].append(density)
    convg_history['aux'].append(aux)
    convg_history['response'].append(response)
    convg_history['metrics'].append(metrics)
    convg_history['distance'].append(distance)

    status = f'epoch {epoch}, J = {loss:.2E}'
    print(status)
    # `plot_interval == 0` short-circuits here, which avoids a modulo by zero.
    if plot_interval and epoch % plot_interval == 0:
      fig = plt.figure()
      # `density` already has shape (nelx, nely); transpose so x runs horizontally.
      img = plt.imshow(density.T, cmap='coolwarm', origin='lower')
      plt.colorbar(img)
      plt.title(status)
      plt.show()
      plt.pause(1.e-3)
      # Release the figure so repeated plotting does not accumulate open figures.
      plt.close(fig)

  return density, convg_history

In [None]:
# Run the optimisation from the random initial design with the settings above.
density, convg_history = optimize_design(init_design,
                                         num_epochs=num_epochs,
                                         lr=lr)

# Plotting the S params

The S-parameter array has shape `(num_wavelengths, num_input_ports, num_output_ports)`.

Its entries are complex numbers, because a wave is described by a complex amplitude and each S-parameter measures the strength of that wave.

We want the strength of the wave, so we first compute the magnitude of each
S-parameter.

S-parameters are usually plotted on a decibel scale, which we obtain by taking
$20\log_{10}(\cdot)$ of the magnitude of the response.


Below we build a GIF with one frame per optimization epoch, each showing the S-parameters collected at that epoch.

In [None]:
# Animate the S-parameter spectra (in dB) over the optimisation epochs.
anim = gifcm.AnimatedFigure(figure=plt.figure(figsize=(8, 4)))

# (index on the last S-parameter axis, legend label) for each plotted curve.
curves = ((0, "$|S_{11}|^2$"), (1, "$|S_{21}|^2$"))

for (i, response) in zip(convg_history['epoch'], convg_history['response']):
  with anim.frame():
    ax = plt.subplot(111)
    wavelengths = response.wavelengths_nm
    for port, label in curves:
      # Magnitude of the complex S-parameter, converted to decibels.
      magnitude = onp.abs(response.s_parameters[:, 0, port])
      ax.plot(wavelengths, 20 * onp.log10(magnitude), "o-", label=label)
    ax.legend()
    ax.set_xlabel('wavelength')
    ax.set_ylabel('scattering param')
    ax.set_xlim(onp.amin(wavelengths), onp.amax(wavelengths))
    # The limits are given top-first and then flipped by invert_yaxis(), so
    # the axis ends up running from -40 at the bottom to 1 at the top.
    ax.set_ylim([1., -40.])
    ax.set_title(f'epoch {i}')
    ax.invert_yaxis()

anim.save_gif("s_param.gif", duration=400)

In [None]:
# Animate the design density and the simulated field over the optimisation epochs.
anim = gifcm.AnimatedFigure(figure=plt.figure(figsize=(8, 4)))

for (i, rho, aux) in zip(convg_history['epoch'],
                         convg_history['density'],
                         convg_history['aux']):
  with anim.frame():
    # Plot fields, using some of the methods specific to the underlying ceviche model.
    # Stored under its own name so the `density` returned by optimize_design
    # is not overwritten by this plotting loop.
    model_density = challenge.component.ceviche_model.density(
        rho.reshape((dom_mesh.nelx, dom_mesh.nely)))

    ax = plt.subplot(121)
    img = ax.imshow(model_density, cmap="gray")
    plt.text(100, 90, f"step {i:02}", color="w", fontsize=20)
    ax.axis(False)
    plt.colorbar(img)
    # Reverse both axes so the image is shown rotated by 180 degrees.
    ax.set_xlim(ax.get_xlim()[::-1])
    ax.set_ylim(ax.get_ylim()[::-1])

    # Plot the field, which is a part of the `aux` returned with the challenge response.
    # The field will be overlaid with contours of the binarized design.
    field = onp.real(aux["fields"])
    field = field[0, 0, :, :]  # First wavelength, first excitation port.
    contours = measure.find_contours(model_density)

    ax = plt.subplot(122)
    im = ax.imshow(field, cmap="bwr")
    # Use the largest absolute value so the limits are symmetric about zero
    # and valid even when the largest field value is not positive.
    field_max = onp.amax(onp.abs(field))
    im.set_clim([-field_max, field_max])
    for c in contours:
      plt.plot(c[:, 1], c[:, 0], "k", lw=1)
    ax.axis(False)
    ax.set_xlim(ax.get_xlim()[::-1])
    ax.set_ylim(ax.get_ylim()[::-1])

anim.save_gif("waveguide_bend.gif", duration=200)