In [None]:
import sys
sys.path.append('../')

import os
import pickle
from datetime import datetime

import yaml
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import invrs_gym
from totypes import types
from skimage import measure

import mma
import mesher
import utils
import network
import transforms
import opt
import projections
import plot_utils

# Short alias for the project's 1-D extent type; used below to build bounding
# boxes and transform ranges.
_Ext = utils.Extent

# Debugging aid: make JAX raise as soon as a computation produces a NaN.
# This adds runtime overhead; set to False for production runs.
DEBUG_NANS = True
jax.config.update("jax_debug_nans", DEBUG_NANS)

plt.rcParams.update(plot_utils.high_res_plot_settings)

# Settings

In [None]:
# Read the run configuration and split it into its individual sections.
with open("./settings.yaml", "r") as file:
  config = yaml.safe_load(file)

# Design-domain mesh and bounding box.
dom_mesh_cfg, dom_bbox_cfg = config['DOM_MESH'], config['DOM_BBOX']
# Stamp mesh and bounding box.
stamp_mesh_cfg, stamp_bbox_cfg = config['STAMP_MESH'], config['STAMP_BBOX']
# Hyper-parameters of the implicit neural network.
nn_cfg = config['IMPLICIT_NN']
# Constraint settings (feature size, separation) and optimizer settings.
cons_cfg, topopt_cfg = config['CONSTRAINTS'], config['TOPOPT']

# Load the VAE network

In [None]:
# Architecture hyper-parameters from the IMPLICIT_NN section of the settings.
latent_dim = nn_cfg['latent_dim']
implicit_hidden_dim = nn_cfg['hidden_dim']
implicit_num_layers = nn_cfg['num_layers']
implicit_siren_freq = nn_cfg['siren_freq']

sdf_net = network.ConvoImplicitAutoEncoder(
    latent_dim=latent_dim,
    implicit_hidden_dim=implicit_hidden_dim,
    implicit_num_layers=implicit_num_layers,
    implicit_siren_freq=implicit_siren_freq,
)

# Pre-trained weights. NOTE: pickle.load can execute arbitrary code, so only
# load weight files from a trusted source.
with open('../data/sdf_vae_net_weights.pkl', 'rb') as weights_file:
  sdf_net_params = pickle.load(weights_file)

# Define the stamp parameters

In [None]:
# Bounding box of a single stamp in its own local frame.
stamp_bbox = mesher.BoundingBox(
    x=_Ext(stamp_bbox_cfg['x_min'], stamp_bbox_cfg['x_max']),
    y=_Ext(stamp_bbox_cfg['y_min'], stamp_bbox_cfg['y_max']),
)

# Stamp mesh; its element centres are the query points passed to sdf_net.apply.
default_stamp_mesh = mesher.Mesher(
    nelx=stamp_mesh_cfg['nelx'],
    nely=stamp_mesh_cfg['nely'],
    bounding_box=stamp_bbox,
)

# Library of SDF images the VAE was trained on.
library_stamp_sdfs = np.load('../data/train_sdf_images.npy')

## Encode the stamp library into latent vectors (z)

In [None]:
# Run the whole stamp library through the VAE. The fourth return value holds the
# latent codes, which define the region of latent space the optimizer explores.
# NOTE(review): the trailing positional False is forwarded unchanged to the
# network; its meaning (e.g. a training/sampling flag) should be confirmed.
vae_variables = {'params': sdf_net_params}
pred_enc_stamps, _, _, encoded_z = sdf_net.apply(
    vae_variables,
    library_stamp_sdfs,
    default_stamp_mesh.elem_centers,
    False,
)

In [None]:
# encoded_z is 2-D: (number of library stamps, latent dimension).
# NOTE(review): this overwrites the `latent_dim` read from nn_cfg['latent_dim'];
# presumably the two agree — consider asserting it.
num_stamps_library, latent_dim = encoded_z.shape

In [None]:
# Per-dimension bounds of the library's latent codes. These are passed to
# opt.normalize_latent_coordns and to the objective/constraint evaluations.
min_encoded_coordn, max_encoded_coordn = (
    jnp.amin(encoded_z, axis=0),
    jnp.amax(encoded_z, axis=0),
)

In [None]:
# Bounding box and mesh of the full design domain.
dom_bbox = mesher.BoundingBox(
    x=_Ext(dom_bbox_cfg['x_min'], dom_bbox_cfg['x_max']),
    y=_Ext(dom_bbox_cfg['y_min'], dom_bbox_cfg['y_max']),
)

dom_mesh = mesher.Mesher(
    nelx=dom_mesh_cfg['nelx'],
    nely=dom_mesh_cfg['nely'],
    bounding_box=dom_bbox,
)

# Define the transform extents

In [None]:
# Upper bound on the stamp scale, as a fraction of the ratio between the padded
# domain diagonal and the stamp diagonal.
scale_factor = 0.5
min_feature_size = cons_cfg['mfs']
lib_mfs = cons_cfg['library_mfs']

# Margin given to Extent.pad when building the stamps' translation range.
# It is negative, so presumably it shrinks the domain extent (same units as the
# bounding-box coordinates in settings.yaml) — TODO confirm.
dom_pad_amount = -80

dom_bbox_padded = mesher.BoundingBox(
    x=_Ext(dom_bbox_cfg['x_min'], dom_bbox_cfg['x_max']).pad(pad_amount=dom_pad_amount),
    y=_Ext(dom_bbox_cfg['y_min'], dom_bbox_cfg['y_max']).pad(pad_amount=dom_pad_amount),
)

# Lower scale bound: ratio of the target minimum feature size to the library's
# minimum feature size (presumably so scaled stamps respect the target mfs).
min_scale = min_feature_size / lib_mfs
max_scale = scale_factor * (dom_bbox_padded.diag_length / stamp_bbox.diag_length)
# Guard against an inverted/empty scale range caused by an inconsistent config.
assert min_scale < max_scale, (
    f'Empty stamp scale range [{min_scale}, {max_scale}]; check CONSTRAINTS '
    f'mfs/library_mfs and scale_factor.')

transform_extent = transforms.TransformExtent(trans_x=dom_bbox_padded.x,
                                              trans_y=dom_bbox_padded.y,
                                              rot_rad=_Ext(0., 2*np.pi),
                                              scale=_Ext(min_scale, max_scale))

# Initialize the optimization design variables

## Init transform

In [None]:
# Number of stamps along each axis; the total number of stamps is their product.
num_stamp_x = topopt_cfg['num_stamps_x']
num_stamp_y = topopt_cfg['num_stamps_y']
num_stamps = num_stamp_x * num_stamp_y

# Initial stamp transforms, presumably laid out on a num_stamp_x x num_stamp_y
# grid within transform_extent.
init_transform = transforms.init_grid_transforms(
    num_stamp_x, num_stamp_y, transform_extent)

## Initialize the latent coordinates

In [None]:
np_rng = np.random.default_rng(0)

# Draw raw (un-normalized) latent samples, one row per stamp, then map them into
# the optimizer's normalized coordinates using the library's latent bounds.
latent_sample_shape = (num_stamps, latent_dim)
init_latent_var = np_rng.normal(scale=2., size=latent_sample_shape)
init_latent_guess = opt.normalize_latent_coordns(
    init_latent_var, min_encoded_coordn, max_encoded_coordn)

In [None]:
# Flat design vector: all latent coordinates first, then the normalized transforms.
# NOTE(review): this layout must match what opt.compute_objective expects.
latent_block = init_latent_guess.reshape(-1)
transform_block = init_transform.to_normalized_array(transform_extent)
init_guess = jnp.concatenate((latent_block, transform_block), axis=-1)

# Optimization constraint parameters

In [None]:
# Minimum separation distance ('msd' in the CONSTRAINTS settings); passed to
# optimize_design and from there to opt.compute_constraint.
min_separation = cons_cfg['msd']

# Define the solver and challenge

In [None]:
# Element size along x (domain width / number of x elements), used as the
# challenge resolution.
mesh_resolution_nm = dom_bbox.lx / dom_mesh.nelx

challenge_factory = invrs_gym.challenges.ceviche_lightweight_waveguide_bend
challenge = challenge_factory(resolution_nm=mesh_resolution_nm)

# Alternative: mode converter challenge.
# challenge = invrs_gym.challenges.ceviche_lightweight_mode_converter(
#     resolution_nm=mesh_resolution_nm)

In [None]:
# Random density field in [0, 1], wrapped as a bounded Density2DArray; it is
# passed to opt.compute_objective inside optimize_design.
key = jax.random.PRNGKey(seed=1)
density_shape = (dom_mesh.nelx, dom_mesh.nely)
init_density = jax.random.uniform(key, density_shape, minval=0., maxval=1.)
dens_array = types.Density2DArray(
    array=init_density,
    lower_bound=0.,
    upper_bound=1.,
)

# Optimization parameters

In [None]:
# Maximum MMA iterations and the move limit (called "lr" in the settings).
num_epochs = topopt_cfg['num_epochs']
lr = topopt_cfg['lr']

In [None]:
# Size of the design vector: one latent vector per stamp plus the transform parameters.
num_latent_params = num_stamps * latent_dim
num_transform_params = init_transform.num_transform_parameters
num_opt_params = num_latent_params + num_transform_params

In [None]:
def _plot_progress(density, latent_coordns, library_latents, nelx, nely, status):
  """Show the current density field and a 2-D scatter of latent coordinates.

  Args:
    density: Current density; reshaped to (nelx, nely) for display.
    latent_coordns: Latent coordinates of the current stamps.
    library_latents: Encoded latent coordinates of the stamp library.
    nelx: Number of domain elements along x.
    nely: Number of domain elements along y.
    status: Title of the density plot.
  """
  plt.figure()
  img = plt.imshow(density.reshape((nelx, nely)).T,
                   cmap='coolwarm',
                   origin='lower')
  plt.colorbar(img); plt.title(status); plt.show(); plt.pause(1.e-3)

  # Only the first two latent dimensions are visualized.
  plt.figure()
  plt.scatter(latent_coordns[:, 0], latent_coordns[:, 1],
              c='red', marker='*')
  plt.scatter(library_latents[:, 0], library_latents[:, 1],
              c='blue', marker='o')
  plt.show(); plt.pause(1e-6)


def optimize_design(init_guess: jnp.ndarray,
                    min_separation: float,
                    dom_mesh: mesher.Mesher,
                    num_stamps: int,
                    transform_extent: transforms.TransformExtent,
                    num_epochs: int,
                    lr: float,
                    mma_state_array: np.ndarray = None,
                    plot_interval: int = 1,
                    checkpoint_interval: int = 10,
                    checkpoint_dir: str = "../results") -> tuple:
  """
  Optimize the design using MMA.

  Besides its arguments, this function reads notebook-level globals that must be
  defined first: `challenge`, `sdf_net`, `sdf_net_params`, `stamp_bbox`,
  `dens_array`, `num_latent_params`, `latent_dim`, `min_encoded_coordn`,
  `max_encoded_coordn` and `encoded_z`.

  Args:
    init_guess: Normalized initial guess for the optimization with values in
      [0, 1]; its length is the number of design variables.
    min_separation: Minimum separation constraint for the design.
    dom_mesh: Mesh object containing mesh information.
    num_stamps: Number of stamps for the optimization.
    transform_extent: Transformation extent object.
    num_epochs: Maximum number of MMA iterations.
    lr: Used as the MMA move limit.
    mma_state_array: Optional array produced by `MMAState.to_array()` (e.g. a
      checkpoint loaded with `np.load`). If given, optimization resumes from it
      and `init_guess` is only used for its length.
    plot_interval: Plot progress every `plot_interval` epochs; 0/None disables
      plotting.
    checkpoint_interval: Save the MMA state every `checkpoint_interval` epochs;
      0/None disables checkpointing.
    checkpoint_dir: Directory for checkpoints; created if it does not exist.

  Returns:
    Tuple (density, opt_params, convg_history, mma_state):
    - density and opt_params belong to the last design that was evaluated,
      i.e. the iterate before the final MMA update.
    - convg_history is a dict of per-epoch lists.
    - mma_state is the final MMA state.
  """
  num_design_var = init_guess.shape[0]
  mma_params = mma.MMAParams(
                              max_iter=num_epochs,
                              kkt_tol=1.e-6,
                              step_tol=1.e-6,
                              move_limit=lr,
                              num_design_var=num_design_var,
                              num_cons=2,  # separation + latent constraints
                              lower_bound=np.zeros((num_design_var, 1)),
                              upper_bound=np.ones((num_design_var, 1)),
                          )
  if mma_state_array is None:
    mma_state = mma.init_mma(init_guess.reshape((-1, 1)), mma_params)
  else:
    mma_state = mma.MMAState.from_array(mma_state_array, num_design_var)
    # A restored state may carry a converged flag; force further iterations.
    mma_state.is_converged = False

  # np.save does not create missing directories.
  if checkpoint_interval:
    os.makedirs(checkpoint_dir, exist_ok=True)

  convg_history = {'epoch': [], 'response': [], 'objective': [],
                   'sep_cons': [], 'lat_cons': [], 'density': [], 'aux': []}
  while not mma_state.is_converged:
    obj, grad_obj, auxs = opt.compute_objective(mma_state.x,
                                                transform_extent,
                                                challenge,
                                                dom_mesh,
                                                sdf_net,
                                                sdf_net_params,
                                                stamp_bbox,
                                                dens_array,
                                                num_stamps,
                                                num_latent_params,
                                                latent_dim,
                                                min_encoded_coordn,
                                                max_encoded_coordn)
    cons, grad_cons = opt.compute_constraint(mma_state.x,
                                             min_separation,
                                             transform_extent,
                                             dom_mesh,
                                             sdf_net,
                                             sdf_net_params,
                                             stamp_bbox,
                                             num_latent_params,
                                             latent_dim,
                                             min_encoded_coordn,
                                             max_encoded_coordn,
                                             encoded_z,
                                             num_stamps,
                                             mma_state.epoch)

    # Parameters that produced obj/auxs, captured before MMA moves on.
    opt_params = mma_state.x

    # update_mma is fed NumPy arrays, so convert the JAX outputs first.
    mma_state.x = np.array(mma_state.x)
    mma_state = mma.update_mma(
                              mma_state,
                              mma_params,
                              np.array(obj),
                              np.array(grad_obj),
                              np.array(cons),
                              np.array(grad_cons)
                            )

    (shape_transforms, pred_stamp_latent_coordns,
     density, response, distance, metrics, aux) = auxs

    # The epoch is recorded after the update; the other entries describe the
    # iterate that was just evaluated.
    convg_history['epoch'].append(mma_state.epoch)
    convg_history['response'].append(response)
    convg_history['objective'].append(obj)
    convg_history['sep_cons'].append(cons[0])
    convg_history['lat_cons'].append(cons[1])
    convg_history['density'].append(density)
    convg_history['aux'].append(aux)

    status = (f'epoch {mma_state.epoch}, J = {obj[0]:.2E}, '
              f'sep_cons {cons[0,0]:.2F}, lat_cons {cons[1,0]:.2F}')
    print(status)
    if plot_interval and mma_state.epoch % plot_interval == 0:
      _plot_progress(density, pred_stamp_latent_coordns, encoded_z,
                     dom_mesh.nelx, dom_mesh.nely, status)

    if checkpoint_interval and mma_state.epoch % checkpoint_interval == 0:
      checkpoint_file = os.path.join(checkpoint_dir,
                                     f"mma_state_{mma_state.epoch}.npy")
      np.save(checkpoint_file, mma_state.to_array())
  return density, opt_params, convg_history, mma_state

In [None]:
# Load the MMA state from a checkpoint if one is available; otherwise start from
# `init_guess`. Set RESUME_CHECKPOINT = None to force a fresh run. The checkpoint
# must come from a run with the same number of design variables.
RESUME_CHECKPOINT = '../results/mma_state_20.npy'

if RESUME_CHECKPOINT is not None and os.path.exists(RESUME_CHECKPOINT):
  mma_state_array = np.load(RESUME_CHECKPOINT)
  print(f'Resuming from checkpoint {RESUME_CHECKPOINT}')
else:
  mma_state_array = None
  print('No checkpoint loaded; starting from the initial guess.')

In [None]:
# Run (or resume, when mma_state_array is not None) the optimization.
density, opt_params, convg_history, mma_state = optimize_design(
    init_guess=init_guess,
    min_separation=min_separation,
    dom_mesh=dom_mesh,
    num_stamps=num_stamps,
    transform_extent=transform_extent,
    num_epochs=num_epochs,
    lr=lr,
    mma_state_array=mma_state_array,
)

In [None]:
# Save the convergence history and the final MMA state, both tagged with the
# same timestamp so they can be paired later.
results_dir = '../results'
# open()/np.save fail if the directory does not exist yet.
os.makedirs(results_dir, exist_ok=True)

now = datetime.now()
timestamp = f"{now:%Y-%m-%d-%H-%M}"

save_file = os.path.join(results_dir, f"convergence_waveguide_bend_{timestamp}.pkl")
with open(save_file, 'wb') as f:
  pickle.dump(convg_history, f)

mma_save_file = os.path.join(results_dir, f"mma_waveguide_bend_{timestamp}.npy")
np.save(mma_save_file, mma_state.to_array())