In [None]:

from typing import Tuple
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

import truss_examples as trex
import truss_structure as trs
import solver
import mma
import utils

In [None]:
# Load the built-in grid-truss sample. It returns the truss skeleton and `bc`
# (presumably the boundary conditions), which the solver and plots use below.
truss_skel, bc = trex.get_sample_truss(trex.Samples.GRID_TRUSS)

In [None]:
# Uniform starting design: one area entry and one Young's modulus entry per bar.
# dtype=float keeps the default (strong) float dtype, as `c * jnp.ones(...)` does.
init_truss_areas = jnp.full((truss_skel.num_bars,), 2e-3, dtype=float)
youngs_modulus = jnp.full((truss_skel.num_bars,), 1e9, dtype=float)

In [None]:
# Solver bound to this truss skeleton and `bc`; every analysis, volume and
# compliance evaluation below goes through this instance.
truss_solver = solver.TrussSolver(truss_skel, bc)

In [None]:
# Analyse the uniform starting design and draw it with its displacements.
# `init_vol` is later passed as the optimizer's volume budget (`max_vol`).
u = truss_solver.solve(youngs_modulus, init_truss_areas)
init_vol = truss_solver.get_volume(init_truss_areas)
title_str = 'volume: {:.2E}'.format(init_vol)
fig, ax = plt.subplots(nrows=1, ncols=1)
trs.plot_truss(
    truss_skel,
    bc,
    area=init_truss_areas,
    node_displacements=u,
    ax=ax,
    title_str=title_str,
)

In [None]:
# Bar areas are optimized in normalized form. This extent is the physical
# area range used by utils.normalize / utils.unnormalize, presumably mapping
# onto the [0, 1] box that the optimizer uses as design-variable bounds.
area_extent = utils.Extent(max=5e-2, min=5e-8)
init_des_var = utils.normalize(init_truss_areas, area_extent)

In [None]:
class Optimizer:
  """Compliance minimisation of truss bar areas under a volume budget, via MMA.

  Design variables are normalized bar areas: they are mapped to physical
  areas with ``utils.unnormalize(..., area_extent)`` and kept inside the
  ``[0, 1]`` box by MMA. A single volume constraint is imposed.

  NOTE(review): besides the injected solver, the methods still read the
  notebook globals ``area_extent``, ``youngs_modulus``, ``truss_skel`` and
  ``bc``.
  """

  def __init__(self, truss_solver: solver.TrussSolver,):
    """Store the solver used for every analysis and volume evaluation.

    Args:
      truss_solver: Solver exposing ``solve``, ``get_volume`` and
        ``compute_compliance``.
    """
    self.truss_solver = truss_solver


  def constraint_fn(self, design_var: jnp.ndarray,
                    max_vol: float
                    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Evaluate the volume constraint ``vol / max_vol - 1 <= 0`` and its gradient.

    Args:
      design_var: Normalized design variables (flattened before use).
      max_vol: Volume budget.

    Returns:
      ``(constraint, grad, volume)``:
      - ``constraint`` is the constraint value shaped ``(1, 1)``.
      - ``grad`` is its gradient w.r.t. the design variables, shaped ``(1, n)``.
      - ``volume`` is the occupied volume.
    """

    def volume_constraint(
        design_var: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
      bar_area = utils.unnormalize(design_var, area_extent)
      occupied_volume = self.truss_solver.get_volume(bar_area)
      # Second element is auxiliary output (has_aux=True), not differentiated.
      return (occupied_volume/max_vol) - 1., occupied_volume

    (vc, vol), dvc = jax.value_and_grad(volume_constraint, has_aux=True
                                        )(design_var.reshape((-1)))
    return jnp.array([vc]).reshape((-1,1)), dvc.reshape((1,-1)), vol


  def objective_fn(self, design_var: jnp.ndarray
                   ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Evaluate compliance and its gradient for a normalized design.

    Args:
      design_var: Normalized design variables (flattened before use).

    Returns:
      ``(objective, grad, u)``:
      - ``objective`` is the compliance shaped ``(1,)``.
      - ``grad`` is its gradient shaped ``(n, 1)``.
      - ``u`` is the nodal displacements returned by the solver for this design.
    """

    def loss_wrapper(design_var):
      bar_area = utils.unnormalize(design_var, area_extent)
      u = self.truss_solver.solve(youngs_modulus, bar_area)
      # Displacements are returned as auxiliary output for plotting.
      return self.truss_solver.compute_compliance(u), u

    (obj, u), grad_obj = jax.value_and_grad(loss_wrapper, has_aux=True)(
                                                      design_var.reshape((-1)))
    return jnp.array([obj]), grad_obj.reshape((-1, 1)), u


  def optimize(self, init_geom,
               max_vol: float,
               max_iter: int,
               kkt_tol: float=1e-6,
               step_tol: float=1e-6,
               move_limit: float=5e-3,
               plot_interval: int=5):
    """Run the MMA loop until ``mma_state.is_converged`` is set.

    Args:
      init_geom: Initial normalized design, shaped ``(n, 1)``.
      max_vol: Volume budget for the volume constraint.
      max_iter: Iteration limit passed to ``mma.MMAParams``.
      kkt_tol: KKT tolerance passed to ``mma.MMAParams``.
      step_tol: Step tolerance passed to ``mma.MMAParams``.
      move_limit: Move limit passed to ``mma.MMAParams``.
      plot_interval: Plot every ``plot_interval`` epochs (and at epoch 1).

    Returns:
      ``(mma_state, history)``:
      - ``mma_state`` is the final MMA state.
      - ``history['obj']`` holds the objective array of each evaluated design.
      - ``history['vol_cons']`` holds the matching constraint value.
    """
    self.num_design_var = init_geom.shape[0]
    self.design_var = init_geom
    # Design variables are normalized, hence the [0, 1] box bounds.
    lower_bound = np.zeros((self.num_design_var, 1))
    upper_bound = np.ones((self.num_design_var, 1))
    mma_params = mma.MMAParams(
        max_iter=max_iter,
        kkt_tol=kkt_tol,
        step_tol=step_tol,
        move_limit=move_limit,
        num_design_var=self.num_design_var,
        num_cons=1,
        lower_bound=lower_bound,
        upper_bound=upper_bound,
    )
    mma_state = mma.init_mma(self.design_var, mma_params)

    def plotfun(areas, u, status = ''):
      fig, ax = plt.subplots(1, 1)
      trs.plot_truss(truss_skel, bc,
          area = areas,
          node_displacements = u,
          ax = ax,
          title_str = status)
      plt.show()

    history = {'obj':[], 'vol_cons':[]}
    # MMA loop
    while not mma_state.is_converged:
      objective, grad_obj, u = self.objective_fn(mma_state.x)
      constr, grad_cons, vol = self.constraint_fn(mma_state.x, max_vol)
      mma_state.x = np.array(mma_state.x).reshape((-1,1))
      # Snapshot the design that produced `objective`, `vol` and `u`. The
      # update below replaces mma_state.x with the next iterate, and the plot
      # must show areas consistent with the displacements and status text.
      evaluated_x = mma_state.x.copy()
      mma_state = mma.update_mma(
        mma_state, mma_params, objective, grad_obj, constr, grad_cons
      )
      status = (f'epoch {mma_state.epoch:d} obj {objective[0]:.2E} '
                f'vol {vol:.2E} ')
      history['obj'].append(objective)
      history['vol_cons'].append(constr[0,0])

      print(status)
      if mma_state.epoch%plot_interval == 0 or mma_state.epoch==1:
        bar_area = utils.unnormalize(evaluated_x, area_extent)
        plotfun(bar_area, u, status)

    return mma_state, history

In [None]:
# Optimizer that drives MMA using the truss solver constructed above.
opt = Optimizer(truss_solver)

In [None]:
# Minimise compliance, keeping the volume no larger than the starting design's
# (`init_vol`); the design variables are passed as an (n, 1) column.
mma_state, history = opt.optimize(
    init_geom=init_des_var.reshape(-1, 1),
    max_vol=init_vol,
    max_iter=400,
    plot_interval=30,
)