In [None]:
import sys
sys.path.append('../')

import pickle
import yaml
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

import invrs_gym
from totypes import types
from skimage import measure
import gifcm


import mesher
import utils
import projections
import transforms
import sdf_ops
import opt
import network
import mma

import plot_utils

# Short alias for the project's 1-D extent type, used for every bounding box below.
_Ext = utils.Extent

# Apply the project's plot defaults (presumably high-resolution settings, per
# the name) to every figure produced in this notebook.
plt.rcParams.update(plot_utils.high_res_plot_settings)

In [None]:
# Read the experiment configuration shared with the optimization notebook.
with open("./settings.yaml", "r") as settings_file:
  config = yaml.safe_load(settings_file)

# Pull out the config sections used throughout this notebook.
(dom_mesh_cfg, dom_bbox_cfg, nn_cfg, stamp_bbox_cfg,
 stamp_mesh_cfg, cons_cfg, topopt_cfg) = (
    config[section] for section in ('DOM_MESH', 'DOM_BBOX', 'IMPLICIT_NN',
                                    'STAMP_BBOX', 'STAMP_MESH', 'CONSTRAINTS',
                                    'TOPOPT'))

# Define mesh

In [None]:
def _bbox_from_cfg(bbox_cfg):
  """Build a `mesher.BoundingBox` from a config section with x/y min/max keys."""
  return mesher.BoundingBox(x=_Ext(bbox_cfg['x_min'], bbox_cfg['x_max']),
                            y=_Ext(bbox_cfg['y_min'], bbox_cfg['y_max']))


# Design domain: its bounding box and a structured nelx x nely element mesh.
dom_bbox = _bbox_from_cfg(dom_bbox_cfg)
dom_mesh = mesher.Mesher(nelx=dom_mesh_cfg['nelx'],
                         nely=dom_mesh_cfg['nely'],
                         bounding_box=dom_bbox)

# Stamp frame: the reference box/mesh on which library shapes are evaluated.
stamp_bbox = _bbox_from_cfg(stamp_bbox_cfg)
default_stamp_mesh = mesher.Mesher(nelx=stamp_mesh_cfg['nelx'],
                                   nely=stamp_mesh_cfg['nely'],
                                   bounding_box=stamp_bbox)

# Load stamps

In [None]:
# `stamp_bbox` and `default_stamp_mesh` are built in the "Define mesh" cell; they
# are intentionally not rebuilt here so the two definitions cannot drift apart.

# Library of precomputed stamp SDF images (presumably the VAE training set, per
# the file name); these are encoded into latent vectors further below.
library_stamp_sdfs = np.load('../data/train_sdf_images.npy')

# Define the transform extents

In [None]:
# Fraction of the (inset domain / stamp) diagonal ratio allowed as the largest stamp scale.
scale_factor = 0.5
min_feature_size = cons_cfg['mfs']
lib_mfs = cons_cfg['library_mfs']

# Negative pad amount applied to both axes of the domain extent; presumably
# this shrinks it inward so stamp centres stay away from the border — confirm
# against `utils.Extent.pad`.
domain_inset = -80
dom_bbox_padded = mesher.BoundingBox(
    x=_Ext(dom_bbox_cfg['x_min'], dom_bbox_cfg['x_max']).pad(pad_amount=domain_inset),
    y=_Ext(dom_bbox_cfg['y_min'], dom_bbox_cfg['y_max']).pad(pad_amount=domain_inset),
)

# Scale range: the smallest scale maps the library's minimum feature size onto
# the target one; the largest is a fraction of the domain/stamp diagonal ratio.
diag_ratio = dom_bbox_padded.diag_length / stamp_bbox.diag_length
max_scale = scale_factor * diag_ratio
min_scale = min_feature_size / lib_mfs

# Bounds for translation (inset domain), rotation (full turn) and scale.
transform_extent = transforms.TransformExtent(
    trans_x=dom_bbox_padded.x,
    trans_y=dom_bbox_padded.y,
    rot_rad=_Ext(0., 2*np.pi),
    scale=_Ext(min_scale, max_scale))

# Load the VAE network

In [None]:
# Architecture hyper-parameters of the implicit SDF autoencoder.
latent_dim, implicit_hidden_dim, implicit_num_layers, implicit_siren_freq = (
    nn_cfg[key] for key in ('latent_dim', 'hidden_dim', 'num_layers', 'siren_freq'))

sdf_net = network.ConvoImplicitAutoEncoder(
    latent_dim=latent_dim,
    implicit_hidden_dim=implicit_hidden_dim,
    implicit_num_layers=implicit_num_layers,
    implicit_siren_freq=implicit_siren_freq)

# Pre-trained weights. NOTE: pickle can execute arbitrary code; only load
# weight files produced locally by this project.
with open('../data/sdf_vae_net_weights.pkl', 'rb') as weights_file:
  sdf_net_params = pickle.load(weights_file)

## Get the encoded Zs

In [None]:
# Run the autoencoder over the whole stamp library. The trailing positional
# `False` is presumably a train/stochastic-mode flag — confirm in network.py.
_enc_variables = {'params': sdf_net_params}
(pred_enc_stamps, _, _,
 encoded_z) = sdf_net.apply(_enc_variables,
                            library_stamp_sdfs,
                            default_stamp_mesh.elem_centers,
                            False)

# Per-dimension bounds of the library's latent codes; these are later passed to
# opt.compute_transforms_and_latent_coordn_from_opt_params.
min_encoded_coordn, max_encoded_coordn = (jnp.amin(encoded_z, axis=0),
                                          jnp.amax(encoded_z, axis=0))

## Load the history

In [None]:
# Timestamp of the optimization run being analysed. Both artifacts below must
# come from the same run, so the identifier is defined exactly once.
run_id = "2024-07-01-18-02"

# Convergence history (per-epoch objective, constraints, densities, responses,
# aux fields). NOTE: pickle is only safe for trusted, locally produced files.
save_file = f"../results/convergence_mode_convertor_{run_id}.pkl"
with open(save_file, 'rb') as f:
  convg_history = pickle.load(f)

# Final MMA optimizer state, serialized as a flat array.
mma_state_array = np.load(f'../results/mma_mode_convertor_{run_id}.npy')

# TODO: Save the num_design_var to yaml in vae_topopt and read it from there here
num_stamp_x, num_stamp_y = topopt_cfg['num_stamps_x'], topopt_cfg['num_stamps_y']
num_stamps = num_stamp_x * num_stamp_y
# Latent parameters for all stamps: one latent vector of size latent_dim per stamp.
num_latent_params = latent_dim*num_stamps
mma_state = mma.MMAState.from_array(mma_state_array,
                                    num_design_var=topopt_cfg['num_design_var'])


## Convergence

In [None]:
# Create the figure and primary axis
fig, ax1 = plt.subplots()

# Plot field_1 on the primary y-axis
ax1.plot(convg_history['epoch'],
          convg_history['objective'],
          'k-', label='Objective')
ax1.set_ylabel('Objective')

# Create a twin axis for field2 and field3
ax2 = ax1.twinx()

# Plot field2 on the secondary y-axis
ax2.plot(convg_history['epoch'],
         convg_history['sep_cons'],
         'g:.', label='Seperation cons')
ax2.set_ylabel('Constraints', color='red')

# Plot field3 on the secondary y-axis with a different linestyle
ax2.plot(convg_history['epoch'],
         convg_history['lat_cons'],
         'r--', label='Latent cons')
ax2.tick_params(axis='y', labelcolor='red')

# Set labels and title
ax1.set_xlabel('Epoch')


# plt.title('Two Y-Axis Plot')

# Add legend
lines1, labels1 = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
plt.legend(lines1 + lines2, labels1 + labels2, loc='upper left')

plt.show()

## Recompute the S-parameters at selected iterations over more wavelengths

In [None]:
# Element size of the design mesh, passed to the challenge as its resolution
# (bbox units assumed to be nm, per the variable name).
mesh_resolution_nm = dom_bbox.lx/dom_mesh.nelx

# Challenge used to re-simulate stored designs: 'ceviche_waveguide_bend' (default)
# or 'ceviche_mode_converter'.
# NOTE(review): the history/MMA files loaded above are named "mode_convertor",
# but this defaults to the waveguide bend. Confirm which device those results
# come from before trusting the re-simulated S-parameters.
challenge_name = 'ceviche_waveguide_bend'
challenge = getattr(invrs_gym.challenges, challenge_name)(
    resolution_nm=mesh_resolution_nm)

# Density container in the challenge's expected type. The random values are
# only placeholders: later cells overwrite `dens_array.array` before use.
key = jax.random.PRNGKey(seed=1)
init_density = jax.random.uniform(key, (dom_mesh.nelx, dom_mesh.nely),
                                  minval=0., maxval=1.)
dens_array = types.Density2DArray(array=init_density, lower_bound=0.,
                                  upper_bound=1.)

In [None]:
# For a few selected epochs: re-simulate the stored density with the challenge
# component to plot its S-parameters, then plot the stored field with the
# design outline overlaid.
for epoch in [0, 20, 110]:
  # NOTE(review): `epoch` is used as a positional index into the history lists,
  # which assumes one entry per epoch starting at 0 — confirm.
  curr_dens = convg_history['density'][epoch]
  dens_array.array = curr_dens
  response, aux = challenge.component.response(dens_array)

  # --- S-parameters in dB (20*log10|S|) over the component's wavelengths ---
  fig, ax = plt.subplots(1, 1)
  wavelengths = response.wavelengths_nm
  ax.plot(
      wavelengths,
      20 * np.log10(np.abs(response.s_parameters[:, 0, 0])),
      "r-",
      label="$|S_{11}|^2$",
  )
  ax.plot(
      wavelengths,
      20 * np.log10(np.abs(response.s_parameters[:, 0, 1])),
      "b--",
      label="$|S_{21}|^2$",
  )
  ax.legend()
  ax.set_xlabel('wavelength')
  ax.set_ylabel('scattering param')
  ax.set_xlim(np.amin(wavelengths), np.amax(wavelengths))
  # Same final limits as set_ylim([1., -40.]) followed by invert_yaxis():
  # -40 at the bottom, 1 at the top.
  ax.set_ylim(-40., 1.)
  ax.set_title(f'epoch {epoch}')

  # --- Field overlaid with contours of the model density ---
  full_density = challenge.component.ceviche_model.density(curr_dens)

  # NOTE(review): this plots the field stored in the optimization history, not
  # the `aux` just recomputed above — confirm which one is intended.
  field = np.real(convg_history['aux'][epoch]["fields"])
  field = field[0, 0, :, :]  # First wavelength, first excitation port.
  contours = measure.find_contours(full_density)

  fig, ax = plt.subplots(1, 1)
  im = ax.imshow(field, cmap="bwr")
  # Symmetric limits from the largest magnitude, so the diverging colormap is
  # centred on zero and negative lobes larger than the positive peak do not saturate.
  field_lim = np.amax(np.abs(field))
  im.set_clim([-field_lim, field_lim])
  for c in contours:
    ax.plot(c[:, 1], c[:, 0], "k", lw=1)
  ax.axis(False)
  ax.set_xlim(ax.get_xlim()[::-1])
  ax.set_ylim(ax.get_ylim()[::-1])
  plt.show()

## High resolution plot

In [None]:
# Finer copy of the design mesh over the same domain (hi_res_factor times more
# elements per axis), used for the high-resolution renders below.
hi_res_factor = 5
high_res_dom_mesh = mesher.Mesher(nelx=hi_res_factor * dom_mesh_cfg['nelx'],
                                  nely=hi_res_factor * dom_mesh_cfg['nely'],
                                  bounding_box=dom_bbox)

In [None]:
# Final design variables from the MMA optimizer state.
opt_params = mma_state.x

# Decode the flat optimizer vector into per-stamp latent codes and placement transforms.
(pred_stamp_latent_coordns,
 shape_transforms) = opt.compute_transforms_and_latent_coordn_from_opt_params(
                                                     opt_params.flatten(),
                                                     num_stamps,
                                                     transform_extent,
                                                     num_latent_params,
                                                     latent_dim,
                                                     min_encoded_coordn,
                                                     max_encoded_coordn
                                                     )

# Evaluate every stamp's SDF on the high-res mesh, project each to a density,
# take their union, and sharpen it with the threshold filter.
shape_sdfs = opt.compute_shape_sdfs(sdf_net,
                                    sdf_net_params,
                                    high_res_dom_mesh,
                                    shape_transforms,
                                    pred_stamp_latent_coordns,
                                    stamp_bbox)
shape_densities = sdf_ops.project_sdf_to_density(shape_sdfs, high_res_dom_mesh)
density = sdf_ops.compute_union_density_fields(shape_densities).reshape((
                                                    high_res_dom_mesh.nelx,
                                                    high_res_dom_mesh.nely))
density = projections.threshold_filter(density)
# NOTE(review): this leaves `dens_array` holding a high-res array whose shape no
# longer matches the challenge mesh; no later cell in this notebook reads it.
dens_array.array = density

# Density is laid out as (nelx, nely); transpose so rows follow y, origin at bottom.
fig, ax = plt.subplots()
img = ax.imshow(density.reshape((high_res_dom_mesh.nelx,
                                 high_res_dom_mesh.nely)).T,
                cmap='coolwarm',
                origin='lower')
fig.colorbar(img, ax=ax)
ax.set_title('High-resolution design density')
plt.show()

# Colored shapes plot

In [None]:
# Match each designed stamp to its nearest library shape in latent space, so
# every stamp can be coloured by the library entry it most resembles.
# Broadcasting gives differences of shape
# (num designed stamps, num library shapes, latent dim).
delta_des_avail_latent = (pred_stamp_latent_coordns[:, np.newaxis, :] -
                          encoded_z[np.newaxis, :, :])
# Euclidean latent distance for every (designed stamp, library shape) pair.
dist_des_avail_latent = jnp.linalg.norm(delta_des_avail_latent, axis=-1)
nearest_shape_idx = np.argmin(dist_des_avail_latent, axis=1)

shape_colors = plot_utils.shape_lib_color_palette[nearest_shape_idx, :]
# NOTE(review): 100 appears to be a sharpness parameter of threshold_filter — confirm.
shape_densities_filtered = projections.threshold_filter(shape_densities, 100)
# Per-pixel RGB = 1 - sum over shapes of (shape colour * shape density).
# Where shapes overlap the sum can leave [0, 1], so clip to the valid RGB range
# (imshow would otherwise clip with a warning).
colored_shapes = np.clip(
    1. - np.einsum('sc, sp -> pc', shape_colors, shape_densities_filtered),
    0., 1.)

In [None]:
# Pixels are ordered as (nelx, nely), the same layout used to reshape `density`
# in the high-resolution cell. Reshape accordingly, then swap the first two axes
# so rows follow y, matching that plot's `.T` + origin='lower' view. (Reshaping
# to (nely, nelx, 3) only coincides with this when nelx == nely.)
rgb_xy = colored_shapes.reshape(high_res_dom_mesh.nelx, high_res_dom_mesh.nely, 3)
transpose_image = np.transpose(rgb_xy, (1, 0, 2))
plt.imshow(transpose_image, origin='lower')
plt.show()

# Scattering param Convergence GIF

In [None]:
anim = gifcm.AnimatedFigure(figure=plt.figure(figsize=(8, 4)))

# One GIF frame per recorded epoch: |S11| and |S21| in dB (20*log10|S|) versus
# wavelength, taken from the responses stored during optimization.
history_pairs = zip(convg_history['epoch'], convg_history['response'])
for step, response in history_pairs:
  with anim.frame():
    ax = plt.subplot(111)
    wl = response.wavelengths_nm
    s_params = response.s_parameters
    for port_idx, lbl in ((0, "$|S_{11}|^2$"), (1, "$|S_{21}|^2$")):
      ax.plot(wl,
              20 * np.log10(np.abs(s_params[:, 0, port_idx])),
              "o-",
              label=lbl)
    ax.legend()
    ax.set_xlabel('wavelength')
    ax.set_ylabel('scattering param')
    ax.set_xlim(np.amin(wl), np.amax(wl))
    # Same final limits as set_ylim([1., -40.]) + invert_yaxis().
    ax.set_ylim(-40., 1.)
    ax.set_title(f'epoch {step}')

anim.save_gif("s_param.gif", duration=400)

# Design Convergence GIF

In [None]:
def animate(gif_path="waveguide_bend.gif", duration=200):
  """Render the design/field optimization history as an animated GIF.

  Each frame corresponds to one entry of `convg_history`. The left panel shows
  the ceviche model density. The right panel shows the real part of the stored
  field with contours of that density overlaid.

  Args:
    gif_path: Output file name; defaults to the original "waveguide_bend.gif".
    duration: Per-frame duration forwarded to `AnimatedFigure.save_gif`.
  """
  anim = gifcm.AnimatedFigure(figure=plt.figure(figsize=(8, 4)))

  for (i, rho, aux) in zip(convg_history['epoch'],
                           convg_history['density'],
                           convg_history['aux']):
    with anim.frame():
      # Density as seen by the underlying ceviche model.
      density = challenge.component.ceviche_model.density(
          rho.reshape((dom_mesh.nelx, dom_mesh.nely)))

      ax = plt.subplot(121)
      img = ax.imshow(density, cmap="gray")
      # Step label at fixed data coordinates (100, 90).
      ax.text(100, 90, f"step {i:02}", color="w", fontsize=20)
      ax.axis(False)
      plt.colorbar(img, ax=ax)
      ax.set_xlim(ax.get_xlim()[::-1])
      ax.set_ylim(ax.get_ylim()[::-1])

      # Field from the `aux` returned with the challenge response, overlaid
      # with contours of the design density.
      field = np.real(aux["fields"])
      field = field[0, 0, :, :]  # First wavelength, first excitation port.
      contours = measure.find_contours(density)

      ax = plt.subplot(122)
      im = ax.imshow(field, cmap="bwr")
      # Symmetric limits from the largest magnitude keep zero at the colormap
      # centre, even when the negative lobe exceeds the positive peak.
      field_lim = np.amax(np.abs(field))
      im.set_clim([-field_lim, field_lim])
      for c in contours:
        ax.plot(c[:, 1], c[:, 0], "k", lw=1)
      ax.axis(False)
      ax.set_xlim(ax.get_xlim()[::-1])
      ax.set_ylim(ax.get_ylim()[::-1])

  anim.save_gif(gif_path, duration=duration)


animate()