In [1]:
import ceviche_challenges
from ceviche_challenges import units as u
from ceviche_challenges.model_base import _wavelengths_nm_to_omegas

from ceviche import viz, fdfd_ez
from ceviche import jacobian

import autograd
import autograd.numpy as npa
import jax
import jax.numpy as jnp

import numpy as np
import matplotlib.pyplot as plt

In [2]:
from inverse_design.brushes import notched_square_brush, circular_brush
from inverse_design.conditional_generator import (
    new_latent_design, transform
)
from tqdm.notebook import trange

from javiche import jaxit

from inverse_design.local_generator import generate_feasible_design_mask
from jax.example_libraries.optimizers import adam

# Define the problem using ceviche_challenges

In [3]:
spec = ceviche_challenges.waveguide_bend.prefabs.waveguide_bend_2umx2um_spec(
    wg_width=400*u.nm, variable_region_size=(1600*u.nm, 1600*u.nm), cladding_permittivity=2.25
)
params = ceviche_challenges.waveguide_bend.prefabs.waveguide_bend_sim_params(resolution = 25 * u.nm,
                                                                             wavelengths=u.Array([1270], u.nm))
model = ceviche_challenges.waveguide_bend.model.WaveguideBendModel(params, spec)

# Define the optimization function

In [4]:
def db_to_au(x):
    return npa.power(10, x / 10)

S_cutoff_dB = npa.array([-20., -0.5])
S_cutoff = db_to_au(S_cutoff_dB)
g = npa.array([-1 if x < S_cutoff.max() else +1 for x in S_cutoff])
w_valid = npa.array([x if x < S_cutoff.max() else 1-x for x in S_cutoff])

In [5]:
print(S_cutoff)
print(g)
print(w_valid)

[0.01       0.89125094]
[-1  1]
[0.01       0.10874906]


In [6]:
np.random.seed(42)
S = np.random.rand(2)
S = S / (1.1 * np.sum(S))

print(jnp.abs(S) ** 2, jnp.abs(S_cutoff) ** 2)
print(jnp.abs(S) ** 2 - jnp.abs(S_cutoff) ** 2)
print(g * (jnp.abs(S) ** 2 - jnp.abs(S_cutoff) ** 2))
print(g * (jnp.abs(S) ** 2 - jnp.abs(S_cutoff) ** 2) / jnp.min(w_valid))
print(jax.nn.softplus(g * (jnp.abs(S) ** 2 - jnp.abs(S_cutoff) ** 2) / jnp.min(w_valid)))

[0.06601047 0.42532036] [9.9999997e-05 7.9432815e-01]
[ 0.06591047 -0.3690078 ]
[-0.06591047 -0.3690078 ]
[ -6.591047 -36.90078 ]
[1.3716612e-03 9.4231264e-17]


In [7]:
min_w = np.min(w_valid)

# @jaxit()
def objective_S(rho):
    s_params, _ = model.simulate(rho)

    s11 = npa.abs(s_params[:, 0, 0])
    s21 = npa.abs(s_params[:, 0, 1])

    return s11, s21

def loss_function(S):
    temp = g * (jnp.abs(S) ** 2 - jnp.abs(S_cutoff) ** 2) / jnp.min(w_valid)
    return jnp.linalg.norm(jax.nn.softplus(temp)) ** 2

def ad_softplus(x):
    return npa.log(1 + npa.exp(x))

def ad_loss_function(S):
    temp = g * (npa.square(npa.abs(S)) - npa.square(npa.abs(S_cutoff))) / min_w
    return npa.linalg.norm(ad_softplus(temp)) ** 2

@jaxit()
def ad_jaxed(rho):
    S = objective_S(rho)
    return ad_loss_function(S)

In [8]:
# objective_S(latent)
autograd.elementwise_grad(objective_S)(latent).max()
# autograd.value_and_grad(objective_S, latent)

NameError: name 'latent' is not defined

# Then, the inverse design aka the conditional generator part

In [None]:
def forward(latent_weights, brush):
    latent_t = transform(latent_weights, brush) #.reshape((Nx, Ny))
    design_mask = generate_feasible_design_mask(latent_t, 
      brush, verbose=False)
    design = (design_mask+1.0)/2.0
    return design

In [None]:
brush = circular_brush(5)
latent = new_latent_design(model.design_variable_shape, bias=0.1, r=1, r_scale=1e-3)

In [None]:
def loss_fn(latent):
    design = forward(latent, brush)
    # S = objective_S(design)
    # return loss_function(S)
    loss = ad_jaxed(design)
    return loss

# Now we define the optimization

In [None]:
# Number of epochs in the optimization
Nsteps = 10
# Parameters for the Adam optimizer
step_size = 0.01
beta1 = 0.667
beta2 = 0.9

In [None]:
grad_fn = jax.grad(loss_fn)

init_fn, update_fn, params_fn = adam(step_size)
state = init_fn(latent) #.flatten()
#value_and_grad seems to have a problem. Figure out why!

def step_fn(step, state):
    latent = params_fn(state) # we need autograd arrays here...
    grads = grad_fn(latent)
    loss = loss_fn(latent)
    #loss = loss_fn(latent)

    optim_state = update_fn(step, grads, state)
    # optim_latent = params_fn(optim_state)
    # optim_latent = optim_latent/optim_latent.std()

    # visualize_debug()
    # visualize_all(latent)
    return loss, optim_state

In [None]:
init_fn, update_fn, params_fn = adam(step_size)
state = init_fn(npa.array(latent))

latent = params_fn(state)
grads = grad_fn(latent)

#|eval:false
range_ = trange(Nsteps)
losses = np.ndarray(Nsteps)
for step in range_:
    print(step)
    loss, state = step_fn(step, state)
    losses[step] = loss
    range_.set_postfix(loss=float(loss))