<a href="https://colab.research.google.com/github/aadityacs/simple_TO/blob/main/structural/density_TO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab setup: fetch the repository and move into its structural examples
# directory so the local modules (utils, mesher, bcs, ...) are importable.
!git clone https://github.com/aadityacs/simple_TO
%cd simple_TO/structural/

In [None]:
from typing import Tuple
import functools
import jax
import numpy as np
import jax.numpy as jnp
from jax import value_and_grad

import utils
import mesher
from material import StructuralMaterial
import bcs
import FE_routines as fe
import mma
import matplotlib.pyplot as plt
import matplotlib as mpl
_Ext = utils.Extent  # shorthand used when building bounding-box extents
# Render all figures at 150 dpi.
mpl.rcParams.update({'figure.dpi': 150})

In [None]:
# 60 x 30 rectangular domain discretised by a 40 x 20 grid of bilinear elements.
nelx, nely = 40, 20  # number of elements along the X and Y axes
bounding_box = mesher.BoundingBox(x=_Ext(0., 60.), y=_Ext(0., 30.))
mesh = mesher.BilinearStructMesher(
    nelx=nelx, nely=nely, bounding_box=bounding_box)

In [None]:
# Random initial densities in [0, 1), one per element, shaped
# (num_elems, 1) to match the optimizer's column of design variables.
# A seeded generator makes runs reproducible; change the seed to try a
# different starting design.
init_design = np.random.default_rng(seed=0).uniform(
    low=0., high=1., size=(mesh.num_elems, 1))

In [None]:
# Normalised material: unit Young's modulus and mass density, Poisson ratio 0.3.
material = StructuralMaterial(
    youngs_modulus=1.,
    poisson_ratio=0.3,
    mass_density=1.,
)

In [None]:
# Sample "mid cantilever beam" boundary conditions on the mesh.
bc = bcs.get_sample_struct_bc(
    mesh,
    bcs.SturctBCs.MID_CANT_BEAM,  # (sic) enum is named SturctBCs in bcs
)

In [None]:
fe_solver = fe.FEA(mesh, material, bc)
# Loss used by the optimizer; it takes a flat density array of shape
# (num_elems,).
lossfn = fe_solver.loss_function

In [None]:
class Optimizer:
  """Density-based topology optimizer driven by MMA.

  Minimises a scalar loss of the element densities subject to a single
  volume-fraction constraint, using the `mma` module's update loop.
  """

  def __init__(self, lossfn, mesh):
    """Store the loss function and mesh.

    Args:
      lossfn: Callable mapping a flat density array to a scalar loss. It is
        differentiated with `jax.value_and_grad` and traced by `jax.jit`.
      mesh: Mesh object. This class reads `elem_area`, `domain_volume`,
        `nelx` and `nely` from it.
    """
    self.lossfn, self.mesh = lossfn, mesh
  #-----------------------#
  def constraint_fn(self, density: jnp.ndarray,
                        max_vol_frac: float) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Evaluate the normalised volume constraint and its gradient.

    The constraint value is ``V / (max_vol_frac * V_domain) - 1``, where
    ``V = sum_i elem_area[i] * density[i]``. It is non-positive exactly when
    the occupied volume is within the allowed fraction of the domain.

    Args:
      density: Element densities; flattened before evaluation.
      max_vol_frac: Allowed fraction of the domain volume.

    Returns:
      ``(value, grad)`` with shapes ``(1, 1)`` and ``(1, density.size)``.
    """
    def volume_constraint(density: jnp.ndarray) -> jnp.ndarray:
      occupied_volume = jnp.dot(self.mesh.elem_area, density)
      return occupied_volume / (max_vol_frac * self.mesh.domain_volume) - 1.

    vc, dvc = value_and_grad(volume_constraint)(density.reshape((-1)))
    return jnp.array([vc]).reshape((-1, 1)), dvc.reshape((1, -1))
  #-----------------------#
  @functools.partial(jax.jit, static_argnums=(0,))
  def objective_fn(self, density: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Evaluate the loss and its gradient with respect to the densities.

    JIT-compiled with ``self`` treated as static, so the instance is hashed
    (default identity hash) and `lossfn` must be JAX-traceable.

    Returns:
      ``(obj, grad)`` with shapes ``(1,)`` and ``(density.size, 1)``.
    """
    obj, grad_obj = value_and_grad(self.lossfn)(density.reshape((-1)))
    return jnp.array([obj]), grad_obj.reshape((-1, 1))
  #-----------------------#
  def optimize(self, init_geom, max_vol_frac: float,
     max_iter: int, kkt_tol: float=1e-6, step_tol: float=1e-6,
      move_limit: float=5e-2, plot_interval: int=5):
    """Run MMA from `init_geom` until ``mma_state.is_converged`` is set.

    Args:
      init_geom: Initial densities of shape ``(num_design_var, 1)``. Design
        variables are bounded to [0, 1].
      max_vol_frac: Allowed volume fraction; must be positive.
      max_iter, kkt_tol, step_tol, move_limit: Forwarded to `mma.MMAParams`.
      plot_interval: Plot the design every `plot_interval` epochs, and always
        at epoch 1. Non-positive values disable the periodic plots.

    Returns:
      ``(mma_state, history)``. ``history['obj']`` and ``history['vol_cons']``
      are lists of Python floats, one entry per epoch.

    Raises:
      ValueError: If `max_vol_frac` is not positive.
    """
    if max_vol_frac <= 0.:
      # A zero or negative fraction would divide by zero or flip the
      # constraint's sign.
      raise ValueError(f'max_vol_frac must be positive, got {max_vol_frac}')

    self.num_design_var = init_geom.shape[0]
    self.design_var = init_geom
    mma_params = mma.MMAParams(
        max_iter=max_iter,
        kkt_tol=kkt_tol,
        step_tol=step_tol,
        move_limit=move_limit,
        num_design_var=self.num_design_var,
        num_cons=1,
        lower_bound=np.zeros((self.num_design_var, 1)),
        upper_bound=np.ones((self.num_design_var, 1)),
    )
    mma_state = mma.init_mma(self.design_var, mma_params)

    def plotfun(x, status=''):
      plt.figure()
      # Reshape to (nelx, nely) and transpose so X runs horizontally.
      plt.imshow(x.reshape((self.mesh.nelx, self.mesh.nely)).T, cmap='rainbow')
      plt.title(status)
      plt.show()

    history = {'obj': [], 'vol_cons': []}
    # MMA loop; termination is decided by the state returned from
    # mma.update_mma (its is_converged flag).
    while not mma_state.is_converged:
      objective, grad_obj = self.objective_fn(mma_state.x)
      constr, grad_cons = self.constraint_fn(mma_state.x, max_vol_frac)
      mma_state = mma.update_mma(
        mma_state, mma_params, objective, grad_obj, constr, grad_cons
      )
      obj_val = float(objective[0])
      vol_cons_val = float(constr[0, 0])
      status = (f'epoch {mma_state.epoch:d} obj {obj_val:.2E} '
                f'vol cons {vol_cons_val:.2E} ')
      # Store plain floats so both histories hold the same scalar type.
      history['obj'].append(obj_val)
      history['vol_cons'].append(vol_cons_val)

      print(status)
      periodic = plot_interval > 0 and mma_state.epoch % plot_interval == 0
      if periodic or mma_state.epoch == 1:
        plotfun(mma_state.x, status)

    return mma_state, history

In [None]:
optim = Optimizer(lossfn, mesh)

# Optimise with the occupied volume capped at 50% of the domain.
mma_state, history = optim.optimize(
    init_design,
    max_vol_frac=0.5,
    max_iter=151,
    plot_interval=10,
)

plt.show(block=True)

In [None]:
# One convergence plot per tracked quantity.
for k, values in history.items():
  plt.figure()
  plt.plot(values)
  plt.xlabel('iter')
  plt.ylabel(k)