## Setup

In this notebook, we continue the L-bracket example, now extending it to topology optimization. For details on the structural solver, see `solve_struct.ipynb`.

\begin{gather*}
\underset{\rho}{\text{minimize}} \quad J = \mathbf{u}^\mathsf{T} \mathbf{K}(\rho)\mathbf{u} \\
\text{subject to} \quad \mathbf{R}(\rho, \mathbf{u}) = \mathbf{0} \\
g \equiv \frac{\sum_e \rho_e v_e}{V^*} - 1 \leq 0
\end{gather*}


where $\rho$ are the element densities, $\mathbf{K}(\rho)$ is the global stiffness matrix, $\mathbf{u}$ is the displacement vector, $\mathbf{R}$ is the residual of the structural finite element analysis, $v_e$ are the element volumes (areas in 2D), and $V^*$ is the maximum allowed volume.

We consider a simple compliance minimization problem under a volume constraint, and employ the standard element-based design parameterization for illustration.


The setup of the FEA problem is described in detail in the previous notebook; here we simply extend the FEA towards optimization.

We use the Method of Moving Asymptotes (MMA) optimizer. The optimizer requires the gradients of the objective and the constraints with respect to the design variables; these are computed automatically with JAX.

In [None]:
import functools
import numpy as np
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt

import toflux.src.utils as _utils
import toflux.src.geometry as _geom
import toflux.src.mesher as _mesher
import toflux.src.material as _mat
import toflux.src.bc as _bc

import toflux.src.fe_struct as _fea
import toflux.src.solver as _solv
import toflux.src.mma as _mma
import toflux.src.viz as _viz

# Short alias for the displacement-component identifiers (`U`, `V`) used when
# specifying boundary-condition values below.
_Disp = _fea.DisplacementField

# Switch JAX to 64-bit floats (JAX defaults to 32-bit). NOTE(review): this is
# safest before any JAX arrays are created; confirm the toflux imports above do
# not allocate arrays at import time.
jax.config.update("jax_enable_x64", True)

## Mesh

In [None]:
# Load the L-bracket geometry (boundary representation) from its JSON file.
geom = _geom.BrepGeometry("toflux/brep/Lbeam.json")

# Build a structured grid mesh over the B-rep. The element counts are
# "desired" values (presumably adjusted by the mesher to fit the geometry).
# Two DOFs per node (x/y displacement) and Gauss quadrature of order 3.
mesh = _mesher.grid_mesh_brep(
  brep=geom,
  nelx_desired=80,
  nely_desired=80,
  dofs_per_node=2,
  gauss_order=3,
)

# All-zero element field used only to visualize the bare mesh; the B-rep
# outline is then drawn (presumably onto the same axes) for reference.
density = jnp.zeros(mesh.num_elems)
_viz.plot_grid_mesh(mesh, density)
_viz.plot_brep(geom)

## Material

In [None]:
# Material described by a (non-dimensional) unit Young's modulus and a
# Poisson's ratio of 0.3.
material_params = _mat.StructuralMaterial(poissons_ratio=0.30, youngs_modulus=1.0)

# Kinematics choice for the FE model: the SMALL deformation model.
deformation_model = _fea.DeformationModel.SMALL

## Boundary conditions



In [None]:
# Load: a uniform Neumann value of -5e-4 on the V displacement component along
# B-rep edge 2 (presumably a downward traction on the hanging edge).
hang_faces = _bc.identify_faces(mesh, edges=[geom.edges[2]])
n = len(hang_faces)
hang_face_val = [(_Disp.V, -5.e-4 * jnp.ones(n))]


# Support: both displacement components (U and V) fixed to zero along B-rep
# edge 5 (the top edge).
top_faces = _bc.identify_faces(mesh, edges=[geom.edges[5]])
n = len(top_faces)
xv = (_Disp.U, jnp.zeros(n))
yv = (_Disp.V, jnp.zeros(n))

top_face_val = [xv, yv]

fixed_bc = _bc.DirichletBC(elem_faces=top_faces, values=top_face_val, name="fix")
load_bc = _bc.NeumannBC(elem_faces=hang_faces, values=hang_face_val, name="load")

# Convert the BC objects into the dict consumed by the FEA/optimizer
# (read later via keys such as "fixed_dofs", "dirichlet_values", "elem_forces").
bc_list = [fixed_bc, load_bc]
bc = _bc.process_boundary_conditions(bc_list, mesh)

_viz.plot_bc(bc_list, mesh)

## Solver

In [None]:
# Linear systems go to SciPy's sparse solver with default options. The
# Newton-Raphson loop is capped at one iteration, presumably sufficient because
# the small-deformation problem is linear.
solver_settings = dict(
  linear=dict(solver=_solv.LinearSolvers.SCIPY_SPARSE, scipy_solver=dict()),
  nonlinear=dict(max_iter=1, threshold=1.0e-4),
)

In [None]:
# Structural FE model combining the mesh, material, kinematics, processed
# boundary conditions and solver settings defined in the cells above.
fea = _fea.FEA(
  mesh=mesh,
  material=material_params,
  deformation_model=deformation_model,
  bc=bc,
  solver_settings=solver_settings,
)

## Filter and projection

A common issue in density-based TO is the appearance of checkerboard-like patterns and the failure of the design to converge towards binary (0/1) values. We overcome these by filtering the densities and then projecting them. Here, we employ a simple radial (circular) filter.

In [None]:
# Circular density filter over the element centers, with a cutoff distance of
# 2% of the mesh bounding-box diagonal. It is applied as a linear operator
# (`dens_filter @ x`) in both the objective and the constraint below.
dens_filter = _utils.create_density_filter(
  mesh.elem_centers,
  cutoff_distance=0.02 * mesh.bounding_box.diag_length,
  filter_type=_utils.Filters.CIRCULAR,
)

## Objective

In [None]:
@functools.partial(jax.jit, static_argnames=("fe",))
def objective_function(density: jnp.ndarray, fe: _fea.FEA, penal: float, u0=None):
  """Compute the structural compliance of a design and its gradient.

  The design variables are filtered with the module-level `dens_filter`,
  projected with `_utils.threshold_filter` (beta = 4), and penalized as
  `rho**penal + 1e-2`. The penalized densities scale the base Lame parameters
  of `material_params`, the FE problem of `fe` is solved, and the compliance
  `u . f` is returned.

  Args:
    density: Design variables, one per element. Presumably (num_elems, 1) as
      produced by the MMA state; TODO confirm.
    fe: FEA model. It is a static argument to `jax.jit`, so it must be hashable
      and a different instance triggers re-tracing.
    penal: SIMP penalization exponent. It is traced, so changing its value
      between calls does not recompile.
    u0: Initial displacement guess forwarded to the Newton-Raphson solver.

  Returns:
    Tuple `(compliance, d_compliance, displacements, penal_dens)`, where
    `d_compliance` is reshaped to a column of shape (num_design_vars, 1).

  Note:
    `dens_filter` and `material_params` are module globals captured at trace
    time. Because the jit cache is keyed on `fe`, rebinding them later has no
    effect on already-compiled calls.
  """

  def objective_wrapper(design):
    # Filter (suppresses checkerboarding), then project towards 0/1.
    filt_dens = dens_filter @ design
    proj_dens = _utils.threshold_filter(filt_dens, beta=4.0)
    # SIMP-style penalization; the 1e-2 offset presumably keeps void elements
    # from producing a singular stiffness matrix.
    penal_dens = proj_dens**penal + 1e-2

    # Element-wise scaling of the base Lame parameters.
    lam, mu = material_params.lame_parameters
    lam, mu = lam * penal_dens, mu * penal_dens

    # Solve the FE problem.
    u = _solv.newton_raphson_solve(fe, u0, lam, mu)

    # Assemble the global force vector from `fe`'s own mesh and BCs (not the
    # notebook globals), so the objective always matches the model passed in.
    force = jnp.zeros((fe.mesh.num_dofs,))
    force = force.at[fe.mesh.elem_dof_mat].add(fe.bc["elem_forces"])
    compliance = jnp.einsum("i, i -> ", u, force)
    return compliance, (u, penal_dens)

  (compliance, (displacements, penal_dens)), d_compliance = jax.value_and_grad(
    objective_wrapper, has_aux=True
  )(density)
  return compliance, d_compliance.reshape((-1, 1)), displacements, penal_dens

## Constraint


We impose a simple volume constraint that restricts the maximum volume the design can occupy. Given the user-defined `max_vol_frac`, we compute the normalized volume constraint. Once again, the gradients are computed automatically via JAX's `value_and_grad`.



In [None]:
def constraint_function(density: jnp.ndarray, max_vol_frac: float):
  """Evaluate the normalized volume constraint g = vol_frac / max_vol_frac - 1.

  The design is filtered and projected exactly as in the objective
  (`dens_filter`, then threshold projection with beta = 4) before taking the
  volume fraction. The design is feasible when g <= 0.

  NOTE(review): the volume fraction is the plain mean of the projected
  densities, i.e. every element is assumed to have the same volume. The
  notebook's formula weights by the element volumes v_e. Confirm the grid mesh
  has uniform element areas.

  Args:
    density: Design variables, one per element.
    max_vol_frac: Maximum allowed volume fraction (must be non-zero).

  Returns:
    `(g, dg)`: `g` is a length-1 array and `dg` is the transposed gradient. For a
    (n, 1) `density`, `dg` has shape (1, n); for a 1-D input, `.T` is a no-op.
  """

  def _normalized_volume_violation(design):
    projected = _utils.threshold_filter(dens_filter @ design, beta=4.0)
    return jnp.mean(projected) / max_vol_frac - 1.0

  value, gradient = jax.value_and_grad(_normalized_volume_violation)(density)
  return jnp.array([value]), gradient.T

## Optimize

Here, we define the optimization loop using the MMA optimizer. We begin the optimization with a uniform design corresponding to the maximum allowed material fraction. The optimization steps are illustrated in the following flowchart:

<img src="../../figures/struct_flowchart.png" alt="flowchart" style="width: 800px;"/>

In [None]:
def _plot_deformed_design(fe: _fea.FEA, design, u) -> None:
  """Draw the element design values on the mesh deformed by displacements `u`."""
  _, ax = plt.subplots(1, 1)
  # DOFs are taken as interleaved per node, (u_x, u_y), matching
  # dofs_per_node=2. TODO confirm against the mesher's DOF numbering.
  node_deformation = np.stack((u[0::2], u[1::2]), axis=1)
  deformed_mesh = _mesher.deform_mesh(fe.mesh, node_deformation)
  ax = _viz.plot_grid_mesh(
    deformed_mesh,
    design.reshape(-1),
    ax=ax,
    val_range=(0.0, 1.0),
    colorbar=False,
  )
  ax.set_aspect("equal")
  ax.spines[["top", "right", "left", "bottom"]].set_visible(False)
  ax.set_xticks([])
  ax.set_yticks([])
  plt.show()
  plt.pause(1e-6)


def optimize_design(
  fe: _fea.FEA,
  max_vol_frac: float,
  max_iter: int,
  move_limit: float = 1e-2,
  kkt_tol: float = 1e-5,
  step_tol: float = 1e-5,
  plot_interval: int = 5,
):
  """Minimize compliance under a volume constraint with MMA.

  The design starts uniform at `max_vol_frac`. Each iteration:
  1. Ramps the SIMP exponent as min(8, 1 + 0.05 * epoch).
  2. Evaluates the objective/constraint and their gradients.
  3. Optionally plots the deformed design.
  4. Takes one MMA step.

  The displacement field from each solve warm-starts the next one.

  Args:
    fe: FEA model to optimize.
    max_vol_frac: Allowed volume fraction, in (0, 1].
    max_iter: Maximum number of MMA iterations.
    move_limit: MMA move limit on the design variables per iteration.
    kkt_tol: KKT-residual tolerance forwarded to MMA.
    step_tol: Design-change tolerance forwarded to MMA.
    plot_interval: Plot every `plot_interval` epochs; 0 disables plotting.

  Returns:
    `(mma_state, u)`: the final MMA state and the last displacement field,
    with Dirichlet values imposed.

  Raises:
    ValueError: If `max_vol_frac` is not in (0, 1].

  The loop runs until `mma_state.is_converged`, which is presumably set by
  `_mma.update_mma` from `max_iter` / `kkt_tol` / `step_tol`. TODO confirm.
  """
  # 0 would divide by zero in the constraint; > 1 would start above the upper
  # bound of the design variables.
  if not 0.0 < max_vol_frac <= 1.0:
    raise ValueError(f"max_vol_frac must lie in (0, 1], got {max_vol_frac}")

  num_design_var = fe.mesh.num_elems
  num_cons = 1
  # Start from a uniform design at the allowed volume fraction instead of a
  # fully solid design that grossly violates the volume constraint.
  design_var = np.full((num_design_var, 1), max_vol_frac)
  mma_params = _mma.MMAParams(
    max_iter=max_iter,
    kkt_tol=kkt_tol,
    step_tol=step_tol,
    move_limit=move_limit,
    num_design_var=num_design_var,
    num_cons=num_cons,
    lower_bound=np.zeros((num_design_var, 1)),
    upper_bound=np.ones((num_design_var, 1)),
  )
  mma_state = _mma.init_mma(design_var, mma_params)

  fixed_dofs = fe.bc["fixed_dofs"]
  dirichlet_values = fe.bc["dirichlet_values"]

  # Tiny non-zero initial guess, with the prescribed Dirichlet values imposed.
  u = jnp.zeros((fe.mesh.num_dofs,)) + 1.0e-8
  u = u.at[fixed_dofs].set(dirichlet_values)

  while not mma_state.is_converged:
    # Continuation on the SIMP exponent: start at 1, grow linearly, cap at 8.
    penal = min(8.0, 1.0 + 0.05 * mma_state.epoch)

    objective, grad_obj, u, _ = objective_function(mma_state.x, fe, penal, u)
    # Reuse the solution as the next warm start, detached from autodiff and
    # with the Dirichlet values re-imposed.
    u = jax.lax.stop_gradient(u)
    u = u.at[fixed_dofs].set(dirichlet_values)

    constr, grad_cons = constraint_function(mma_state.x, max_vol_frac)

    print(f"epoch {mma_state.epoch} J {objective:.2E} mc {constr[0]:.2F}")

    if plot_interval and mma_state.epoch % plot_interval == 0:
      _plot_deformed_design(fe, mma_state.x, u)

    mma_state = _mma.update_mma(
      mma_state, mma_params, objective, grad_obj, constr, grad_cons
    )

  return mma_state, u

In [None]:
# Run the optimization with a 25% volume budget, at most 250 MMA iterations and
# an MMA move limit of 0.2.
mma_state, u = optimize_design(fea, max_vol_frac=0.25, max_iter=250, move_limit=2.0e-1)