# Imports

In [None]:
from typing import Tuple, Dict, Callable
import yaml
import pickle
import copy
from datetime import datetime
import numpy as np
import jax.numpy as jnp
from jax import value_and_grad
import jax
import matplotlib.pyplot as plt

import mesher
import bound_cond
import material
import fea
import polygon
import tree
import mma
import utils
import plot_utils


import matplotlib as mpl

# Debugging aid: have JAX raise an error as soon as a computation produces a NaN.
jax.config.update("jax_debug_nans", True)

# Project-wide plotting defaults first, then pin the notebook display DPI on top.
plt.rcParams.update(plot_utils.high_res_plot_settings)
mpl.rcParams['figure.dpi'] = 100

# Short aliases used by the cells below.
_Ext = utils.Extent
poly_cmap = plot_utils.poly_cmap

# Configs

In [None]:
# Load the run configuration. An explicit encoding makes parsing independent
# of the platform's default locale encoding; `safe_load` only constructs plain
# Python objects.
with open('settings.yaml', 'r', encoding='utf-8') as file:
  config = yaml.safe_load(file)

# Per-subsystem sections. A missing section fails here with a KeyError,
# before any expensive setup runs.
cfg_bbox = config['BBOX']
cfg_mesh = config['MESH']
cfg_mat = config['MATERIAL']
cfg_bc = config['BOUNDARY_COND']
cfg_opt = config['OPTIMIZATION']
cfg_quad = config['QUADRIC']
cfg_tree = config['TREE']

# Mesh

In [None]:
# Rectangular design domain from the BBOX section, discretized with a
# structured mesh of nelx x nely elements (MESH section).
bbox = mesher.BoundingBox(
    x=_Ext(cfg_bbox['x_min'], cfg_bbox['x_max']),
    y=_Ext(cfg_bbox['y_min'], cfg_bbox['y_max']),
)
mesh = mesher.BilinearStructMesher(
    bounding_box=bbox,
    nelx=cfg_mesh['nelx'],
    nely=cfg_mesh['nely'],
)

# Material

In [None]:
# Material defined by Young's modulus and Poisson's ratio (MATERIAL section).
mat = material.Material(
    poissons_ratio=cfg_mat['poissons_ratio'],
    youngs_modulus=cfg_mat['youngs_modulus'],
)

# Boundary condition

In [None]:
# Select the sample boundary-condition case named by BOUNDARY_COND.example
# (`SturctBCs` is spelled this way in bound_cond) and build it on the mesh.
bc = bound_cond.get_sample_struct_bc(
    mesh,
    bound_cond.SturctBCs[cfg_bc['example']],
)

# Solver

In [None]:
# Finite-element solver built from the mesh, material and boundary conditions;
# its `loss_function` is the objective handed to the optimizer further down.
solver = fea.FEA(mesh, mat, bc)

# Init tree

In [None]:
# Depth of the binary CSG tree, e.g. num_levels = 6 gives 2**6 = 64 leaf polygons.
num_levels = cfg_tree['num_levels']
# One primitive per leaf; a binary tree with N leaves has N - 1 internal nodes,
# each holding one boolean operator.
num_shapes = 2**num_levels
num_operators = num_shapes - 1

# Init poly params

In [None]:
# Parameter ranges for the polygons (one per tree leaf, 6 planes each), taken
# from the QUADRIC section; also used to (de)normalize the design variables.
poly_ext = polygon.PolygonExtents(
  num_polys=num_shapes,
  num_planes_in_a_poly=6,
  center_x=_Ext(cfg_quad['cx_min'], cfg_quad['cx_max']),
  center_y=_Ext(cfg_quad['cy_min'], cfg_quad['cy_max']),
  face_offset=_Ext(cfg_quad['a_min'], cfg_quad['a_max']),
  angle_offset=_Ext(cfg_quad['th_min'], cfg_quad['th_max']),
)

# Init design var

In [None]:
seed = 0

# Primitive (polygon) design variables, normalized with the extents above.
init_polys = polygon.init_random_polys(poly_ext, seed=seed)
init_prim_des_var = init_polys.to_normalized_array(poly_ext)
num_prim_des_var = init_prim_des_var.shape[0]

# Operator design variables: 4 uniform weights in [0, 1) per internal node,
# one per candidate boolean operation. Note: this seeds NumPy's global RNG.
np.random.seed(seed)
init_operations = np.random.uniform(0., 1., num_operators * 4)
num_oper_des_var = init_operations.shape[0]

# Full design vector as a column: primitive variables, then operator variables.
init_design_var = np.concatenate((init_prim_des_var, init_operations))[:, np.newaxis]

# Optimizer

In [None]:
class Optimizer:
  """Optimize a binary CSG tree of convex polygons with MMA.

  The design vector is the concatenation of the normalized polygon parameters
  and the operator weights of the tree's internal nodes (4 per operator). Each
  evaluation maps it to an element density field

    polygon SDF -> density -> projection -> CSG tree -> projection

  which feeds both the loss function and the volume constraint.

  NOTE(review): splitting the design vector reads the notebook-level globals
  `num_prim_des_var`, `num_operators` and `poly_ext` at call time, so the
  cells defining them must run before this class is used.
  """

  def __init__(self,
               lossfn: Callable,
               mesh: mesher.BilinearStructMesher):
    """
    Args:
      lossfn: Maps a flat element-density array to a scalar objective. It is
        differentiated with `jax.value_and_grad`, so it must be JAX-traceable.
      mesh: Mesh on which the density field is evaluated.
    """
    self.lossfn, self.mesh = lossfn, mesh

  def _continuation_params(self) -> Tuple[float, float]:
    """Return (softmax_temperature, proj_sharpness) for the current MMA epoch.

    Both grow linearly with `self.mma_state.epoch` and are capped at 400 and 12
    respectively. This gradually pushes the operators towards one-hot choices
    and sharpens the density projections. It is shared by the objective and
    the constraint so both always follow the same schedule.
    """
    epoch = self.mma_state.epoch
    softmax_temperature = min(400., 10. + epoch*5.)
    proj_sharpness = min(12., 1. + epoch*0.06)
    return softmax_temperature, proj_sharpness

  def get_poly_params_and_operator_from_design_var(self,
                                                    design_var: jnp.ndarray,
                                                    softmax_temperature: float
                                                    ) -> Tuple[polygon.ConvexPolys,
                                                               jnp.ndarray]:
    """Split the design vector into polygon parameters and tree operators.

    Args:
      design_var: Design vector. The first `num_prim_des_var` entries are
        normalized polygon parameters; the rest are reshaped to
        (num_operators, 4) operator weights.
      softmax_temperature: Temperature of the softmax applied to the operator
        weights; larger values make the operators closer to a one-hot encoding.

    Returns:
      (poly_params, operations): the denormalized polygons and the operators.
    """
    poly_des_var = design_var[:num_prim_des_var]
    oper_des_var = design_var[num_prim_des_var:].reshape((num_operators, 4))

    poly_params = polygon.ConvexPolys.from_normalized_array(poly_des_var,
                                                            poly_extents=poly_ext)
    operations = tree.get_operators_from_design_var(oper_des_var,
                                                    t=softmax_temperature)

    return poly_params, operations

  def _design_density(self, design_var: jnp.ndarray
                      ) -> Tuple[jnp.ndarray, polygon.ConvexPolys,
                                 jnp.ndarray, object]:
    """Map a flat design vector to the projected density of the CSG tree root.

    Returns:
      (design_density, poly_params, operations, csg_root), where
      `design_density` is the root value after a final projection.
    """
    softmax_temperature, proj_sharpness = self._continuation_params()
    poly_params, operations = self.get_poly_params_and_operator_from_design_var(
        design_var, softmax_temperature)
    poly_sdf = polygon.compute_poly_sdf(poly_params, self.mesh)
    poly_density = polygon.project_sdf_to_density(poly_sdf, self.mesh)
    proj_poly_density = material.projection_filter(poly_density,
                                                   beta=proj_sharpness)
    csg_root = tree.eval_binary_csg_tree(proj_poly_density, operations)
    design_density = material.projection_filter(csg_root.value,
                                                beta=proj_sharpness)
    return design_density, poly_params, operations, csg_root

  def constraint_fn(self,
                    design_var: jnp.ndarray,
                    max_vol_frac: float
                    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Volume constraint value and gradient for MMA.

    The constraint is  g = V_occupied / (max_vol_frac * V_domain) - 1, so the
    volume bound is satisfied when g <= 0.

    Args:
      design_var: Design vector (primitive params followed by operator
        weights); flattened before differentiation.
      max_vol_frac: Maximum allowed volume fraction.

    Returns:
      (g, dg): `g` with shape (1, 1) and its gradient with shape
      (1, design_var.size).
    """

    def volume_constraint(design_var: jnp.ndarray) -> jnp.ndarray:
      design_density, _, _, _ = self._design_density(design_var)
      occupied_volume = jnp.einsum('i,i->i', self.mesh.elem_area,
                                   design_density).sum()
      return occupied_volume/(max_vol_frac*self.mesh.domain_volume) - 1.

    vc, dvc = value_and_grad(volume_constraint)(design_var.reshape((-1)))

    return jnp.array([vc]).reshape((-1, 1)), dvc.reshape((1, -1))

  def objective_fn(self,
                   design_var: jnp.ndarray
                   ) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Objective (compliance) value and gradient for MMA.

    Args:
      design_var: Design vector (primitive params followed by operator
        weights); flattened before differentiation.

    Returns:
      (obj, grad, aux): `obj` with shape (1,), `grad` with shape
      (design_var.size, 1), and aux = (density, poly_params, operations,
      csg_root) from the same evaluation. `density` includes the additive
      floor below.
    """

    def obj_wrapper(design_var: jnp.ndarray):
      design_density, poly_params, operations, csg_root = self._design_density(
          design_var)
      # Additive density floor, applied to the objective only (not to the
      # volume constraint). Presumably keeps void elements from making the
      # FE system singular; confirm against fea.
      design_density = design_density + 2.e-2
      obj = self.lossfn(design_density.reshape((-1)))
      return obj, (design_density, poly_params, operations, csg_root)

    ((obj, (density, poly_params, operations, csg_root)),
     grad_obj) = value_and_grad(obj_wrapper, has_aux=True)(
         design_var.reshape((-1)))
    return jnp.array([obj]), grad_obj.reshape((-1, 1)), (density, poly_params,
                                                         operations, csg_root)

  def optimize(self,
               init_design_var: jnp.ndarray,
               max_vol_frac: float,
               max_iter: int,
               kkt_tol: float=1e-6,
               step_tol: float=1e-6,
               move_limit: float=5e-2,
               plot_interval: int=5,
               ) -> Tuple[mma.MMAState, Dict]:
    """Optimize the design with MMA.

    Args:
      init_design_var: Normalized initial design with values in [0, 1], as a
        column (n, 1). Every variable gets MMA bounds 0 and 1.
      max_vol_frac: Allowable volume fraction.
      max_iter: Maximum number of MMA iterations (passed to `mma.MMAParams`).
      kkt_tol: Tolerance on the Karush-Kuhn-Tucker condition.
      step_tol: Tolerance on the step size.
      move_limit: MMA move limit (called the learning rate in the config).
      plot_interval: Plot the density every `plot_interval` epochs and at
        epoch 1. A value <= 0 disables plotting.

    Returns:
      (final MMA state, history). history['obj'] holds the per-epoch objective
      arrays; history['vol_cons'] holds the per-epoch constraint values.
    """
    self.num_design_var = init_design_var.shape[0]
    self.design_var = init_design_var
    lower_bound = np.zeros((self.num_design_var, 1))
    upper_bound = np.ones((self.num_design_var, 1))
    mma_params = mma.MMAParams(
        max_iter=max_iter,
        kkt_tol=kkt_tol,
        step_tol=step_tol,
        move_limit=move_limit,
        num_design_var=self.num_design_var,
        num_cons=1,
        lower_bound=lower_bound,
        upper_bound=upper_bound,
    )
    self.mma_state = mma.init_mma(self.design_var, mma_params)

    def plotfun(x, status=''):
      fig, ax = plt.subplots(1, 1)
      ax.imshow(x.reshape((self.mesh.nelx, self.mesh.nely)).T,
                cmap=poly_cmap, origin='lower')
      ax.set_xticks([])
      ax.set_yticks([])
      ax.set_title(status)
      plt.show()

    history = {'obj': [], 'vol_cons': []}

    # NOTE(review): termination relies on mma.update_mma setting
    # `is_converged` (presumably on max_iter or the tolerances). Confirm.
    while not self.mma_state.is_converged:
      objective, grad_obj, (density, _, _, _) = self.objective_fn(
          self.mma_state.x)
      constr, grad_cons = self.constraint_fn(self.mma_state.x, max_vol_frac)

      self.mma_state = mma.update_mma(self.mma_state, mma_params,
                                      objective, grad_obj,
                                      constr, grad_cons)
      epoch = self.mma_state.epoch
      status = (f'epoch {epoch:d} obj {objective[0]:.2E} '
                f'vol cons {constr[0,0]:.2E} ')
      history['obj'].append(objective)
      history['vol_cons'].append(constr[0,0])

      print(status)
      # Check plot_interval first: `epoch % 0` would raise ZeroDivisionError.
      if plot_interval > 0 and (epoch % plot_interval == 0 or epoch == 1):
        plotfun(density, status)

    return self.mma_state, history

In [None]:
# Minimize the FE solver's loss over the CSG design on this mesh.
optim = Optimizer(lossfn=solver.loss_function, mesh=mesh)

In [None]:
# Run MMA from the random initial design; all settings come from OPTIMIZATION.
mma_state, history = optim.optimize(
    init_design_var=np.array(init_design_var),
    max_vol_frac=cfg_opt['desired_vol_frac'],
    max_iter=cfg_opt['num_epochs'],
    move_limit=cfg_opt['learning_rate'],
    plot_interval=cfg_opt['plot_interval'],
)

# Dump/load files

In [None]:
def save(results_dir: str = './results') -> str:
  """Write the convergence history and final MMA state to disk.

  Two files are written, sharing the prefix
  ``<results_dir>/treetop_<YYYY-mm-dd-HH-MM>``:
    * ``_hist.pkl``: the notebook-level ``history`` dict, pickled.
    * ``_mma_state.npy``: ``mma_state.to_array()`` saved with ``np.save``.
  Two saves within the same minute overwrite each other.

  Args:
    results_dir: Output directory; created if it does not exist.

  Returns:
    The common path prefix of the files written.
  """
  import os  # local import keeps this cell self-contained

  # open()/np.save do not create missing directories.
  os.makedirs(results_dir, exist_ok=True)
  save_file = os.path.join(results_dir,
                           f"treetop_{datetime.now():%Y-%m-%d-%H-%M}")

  with open(save_file + "_hist.pkl", 'wb') as f:
    pickle.dump(history, f)
  np.save(save_file + "_mma_state.npy", mma_state.to_array())
  return save_file

# save()

In [None]:
# load_file = "./results/treetop_2024-08-16-22-13"
# mma_state_np = np.load(str(load_file) + "_mma_state.npy")
# mma_state = mma.MMAState.from_array(mma_state_np,
#                                      num_design_var=init_design_var.shape[0])
# with open(str(load_file) + "_hist.pkl", "rb") as f:
#   history = pickle.load(f)

# High resolution plot

In [None]:
# Re-render the optimized design on a 3x finer mesh with fixed, sharper
# settings: softmax temperature 1000 (the optimizer caps it at 400) and
# projection beta 13 (the optimizer caps it at 12).
high_res = 3
high_res_mesh = mesher.BilinearStructMesher(bounding_box=bbox,
                                            nelx=high_res*mesh.nelx,
                                            nely=high_res*mesh.nely)

poly_params, operations = optim.get_poly_params_and_operator_from_design_var(
    mma_state.x, 1000.)

poly_sdf = polygon.compute_poly_sdf(poly_params, high_res_mesh)
poly_density = polygon.project_sdf_to_density(poly_sdf, high_res_mesh,
                                              sharpness=10.)
proj_poly_density = material.projection_filter(poly_density, beta=13.)
csg_root = tree.eval_binary_csg_tree(proj_poly_density, operations)
design_density = material.projection_filter(csg_root.value, beta=13.)
print(design_density.shape)

In [None]:
# Final high-resolution design, rounded to the nearest integer (0/1 for
# densities in [0, 1]), shown and saved as SVG.
plt.figure()
plt.imshow(
    np.round(design_density).reshape((high_res_mesh.nelx, high_res_mesh.nely)).T,
    origin='lower',
    cmap=poly_cmap,
)
plt.axis('off')
plt.savefig('mbb_validation.svg')
plt.show()

# Recover the poly params, operations and tree from the final state

In [None]:
# Prune a deep copy so that `csg_root` itself is left untouched. The return
# value of `tree.prune_tree` is not used, so it is expected to prune in place.
pruned_csg_root = copy.deepcopy(csg_root)
tree.prune_tree(pruned_csg_root)

In [None]:
import os  # local import keeps this cell self-contained

# Display labels for the four candidate boolean operations, indexed by the
# argmax of a node's operator weights.
labels = [r'X \cap Y', r'X \cup Y', r'X \backslash Y', r'Y \backslash X']

plot_induvidual = False  # save one PDF per node of the pruned tree
plot_tree = False  # save one PDF per depth with all nodes at that depth


def _node_image(node):
  """Return a node's rounded density as a 2-D image for `imshow`.

  Redundant nodes are drawn empty (all zeros).
  """
  value = np.zeros_like(node.value) if node.is_redundant else node.value
  return np.round(value).reshape((high_res_mesh.nelx, high_res_mesh.nely)).T


if plot_tree or plot_induvidual:
  # savefig does not create missing directories.
  os.makedirs('./results', exist_ok=True)

for depth in range(1, num_levels+1):
  nodes = tree.breadth_first_search_at_depth(pruned_csg_root, depth)
  if len(nodes) == 0:
    continue

  # Leaf nodes carry no operation and get a blank math space as their label.
  # Raw string: '\;' in a plain literal is an invalid escape sequence
  # (SyntaxWarning since Python 3.12).
  opers = [labels[np.argmax(node.operation)] if node.operation is not None
           else r'\;'
           for node in nodes]

  if plot_tree:
    # squeeze=False keeps `axes` 2-D even when this depth has a single node;
    # plt.subplots(1, 1) would otherwise return a bare Axes, not an array.
    fig, axes = plt.subplots(1, len(nodes), figsize=(len(nodes)*5, 5),
                             squeeze=False)
    for ax, node, oper in zip(axes[0], nodes, opers):
      ax.imshow(_node_image(node), cmap=poly_cmap, origin='lower',
                vmin=0., vmax=1.)
      ax.set_xticks([])
      ax.set_yticks([])
      ax.set_xlabel(f'${oper}$', fontname='Times New Roman')
    fig.savefig(f'./results/depth_{depth}.pdf')

  if plot_induvidual:
    for i, node in enumerate(nodes):
      infig, inax = plt.subplots(1, 1, figsize=(5, 5))
      inax.imshow(_node_image(node), cmap=poly_cmap, origin='lower',
                  vmin=0., vmax=1.)
      inax.set_xticks([])
      inax.set_yticks([])
      infig.savefig(f'./results/depth_{depth}_node_{i}.pdf', bbox_inches='tight',
                    pad_inches=0, transparent=True)
      # Export-only figure: close it so a large tree does not pile up open
      # figures (matplotlib warns beyond 20 by default).
      plt.close(infig)