### Conjugate Heat Transfer (CHT) Setup
To illustrate CHT topology optimisation (TO), we consider a **conjugate-heat-transfer optimisation benchmark** in a 2D pipe. The entire rectangular domain, denoted by $\Omega$, serves as the design region.

* **$\Omega$ (the design region)** – A design variable field $\gamma(\mathbf x)\in[0,1]$ is optimised, where $\gamma=0$ represents fluid and $\gamma=1$ represents solid material.

The coupled steady-state incompressible Navier-Stokes and energy equations are solved using a Brinkman formulation, in which a friction term suppresses flow in solid regions, to obtain the fluid velocity $\mathbf u$, pressure $p$, and temperature $T$.

---

## Optimization problem

$$
\begin{aligned}
\min_{\gamma}\quad
& F
    =
     (1-W)\frac{J_f}{J_f^*} - W\frac{J_{th}}{J_{th}^*},
\\[6pt]
\text{s.t.}\quad
& \mathbf R\!\bigl(\gamma,\mathbf u,p,T\bigr)=\mathbf 0, \\[4pt]
& g(\gamma)=
  \Phi_{min} - \frac{\displaystyle\sum_{j} \gamma_j\,v_j}{V}  \;\le 0, \\[4pt]
& 0\le\gamma_j\le1\qquad\forall\,j .
\end{aligned}
$$

The formulation seeks to simultaneously **minimize fluid power dissipation** ($J_f$) and **maximize recoverable thermal power** ($J_{th}$). The negative sign on the thermal objective reframes its maximization as a minimization problem.
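A minimal sketch of the weighted objective $F$, assuming the normalisation is a plain division by the reference values $J_f^*$ and $J_{th}^*$ as written above. The function name is hypothetical; the notebook's objective cell further down normalises with the extents `fluid_obj_ext` and `thermal_obj_ext`, which may not be identical to dividing by a single reference value.

```python
def weighted_objective(j_f: float, j_th: float, j_f_ref: float, j_th_ref: float,
                       w: float) -> float:
    """Return F = (1 - W) * J_f / J_f^* - W * J_th / J_th^*.

    A larger recoverable thermal power J_th lowers F, so minimising F
    maximises J_th while penalising the dissipated fluid power J_f.
    """
    if not 0.0 <= w <= 1.0:
        raise ValueError("the weight W must lie in [0, 1]")
    return (1.0 - w) * j_f / j_f_ref - w * j_th / j_th_ref


# W = 0 optimises only the flow; W = 1 optimises only the heat transfer.
print(weighted_objective(2.0, 3.0, 1.0, 1.0, 0.5))  # -0.5
```

With $W = 0.5$ (the `obj_weight` used later in the notebook) both normalised terms carry equal weight.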

---

### Symbols

| Symbol | Description |
|---|---|
| $\gamma_j$ | Element-wise design variable (0 for fluid, 1 for solid) |
| $v_j$ | Element volume |
| $\mathbf u,\;p,\;T$ | Fluid velocity, pressure, and temperature|
| $F$ | Aggregated, normalized multi-objective function  |
| $J_f$ | Dissipated fluid power objective  |
| $J_{th}$ | Recoverable thermal power objective  |
| $W$ | Scalar weighting factor for the multi-objective function  |
| $\mathbf R$ | CHT residual (Navier-Stokes + Energy)  |
| $\Phi_{min}$ | Minimum required solid (material) volume fraction |
| $J_f^*,\;J_{th}^*$ | Reference values used to normalise $J_f$ and $J_{th}$ |
| $V$ | Total volume of the domain |

---

### Numerical details

* **Element-based parametrisation** using a density-based approach with RAMP-type interpolation for the Brinkman friction term and thermal diffusivity.
* The **Method of Moving Asymptotes (MMA)** algorithm updates the design variable $\gamma$.
* Gradients (sensitivities) $\partial F/\partial\gamma$ are computed by **automatic differentiation** with JAX.
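RAMP-type interpolation can be sketched as follows, assuming the common RAMP form for the convex case and a concave counterpart from the same family. The helpers `ramp_convex`, `ramp_concave` and `convex_penalty_for_target` are hypothetical and are not the toflux functions `compute_ramp_interpolation` or `calculate_initial_ramp_penalty`, whose exact parametrisation may differ. With a large penalty $q$, the convex map keeps the Brinkman term close to its minimum until $\gamma$ approaches 1, which lets fluid flow through most of the domain early in the optimisation.

```python
import numpy as np


def ramp_convex(gamma, p_min, p_max, q):
    """Convex RAMP: p_min + (p_max - p_min) * gamma / (1 + q * (1 - gamma)), q > 0."""
    gamma = np.asarray(gamma, dtype=float)
    return p_min + (p_max - p_min) * gamma / (1.0 + q * (1.0 - gamma))


def ramp_concave(gamma, p_min, p_max, q):
    """Concave variant: p_min + (p_max - p_min) * gamma * (1 + q) / (gamma + q), q > 0."""
    gamma = np.asarray(gamma, dtype=float)
    return p_min + (p_max - p_min) * gamma * (1.0 + q) / (gamma + q)


def convex_penalty_for_target(p_min, p_max, p_target, gamma0):
    """Solve ramp_convex(gamma0, p_min, p_max, q) == p_target for q."""
    t = (p_target - p_min) / (p_max - p_min)  # normalised target, must be < gamma0
    return (gamma0 / t - 1.0) / (1.0 - gamma0)


# Illustrative Brinkman bounds: choose q so that a uniform design gamma0 = 0.6
# starts at a prescribed intermediate Brinkman value.
q0 = convex_penalty_for_target(p_min=1e-4, p_max=1e4, p_target=1e2, gamma0=0.6)
print(q0, ramp_convex(0.6, 1e-4, 1e4, q0))
```

Both maps return `p_min` at $\gamma=0$ and `p_max` at $\gamma=1$; only the curvature in between differs.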

**Boundary conditions** 

* **Inlet $\Gamma_{in}$**: A prescribed parabolic velocity profile and a constant temperature $T_{in}$.
* **Outlet $\Gamma_{out}$**: Zero pressure ($p=0$) and zero transverse velocity ($v=0$); natural (zero-gradient) conditions for the streamwise velocity and the temperature.
* **Top and bottom walls $\Gamma_{w}$**: No-slip velocity and a constant wall temperature $T_{w}$.
* **Side walls $\Gamma_{ad}$**: Adiabatic and no-slip velocity conditions.

<img src="../../figures/cht_to_bc.png" alt="Channel" style="width: 500px;"/>

In [None]:
import numpy as np
import jax.numpy as jnp
import jax
import functools
from jax.typing import ArrayLike

import toflux.src.geometry as _geom
import toflux.src.utils as _utils
import toflux.src.mesher as _mesher
import toflux.src.material as _mat
import toflux.src.bc as _bc
import toflux.src.fe_thermal as _fea_thermal
import toflux.src.fe_fluid as _fea_fluid
import toflux.src.viz as _viz
import toflux.src.solver as _solv
import toflux.src.mma as _mma

import matplotlib.pyplot as plt


_Ext = _utils.Extent
_FluidField = _fea_fluid.FluidField
_TempField = _fea_thermal.ThermalField
_cmap = _viz.fluid_cmap


jax.config.update("jax_enable_x64", True)
plt.rcParams.update(_viz.high_res_plot_settings)

## Define Geometry and Mesh

### Mesh

Below we build the mesh for the CHT pipe optimization benchmark.

#### Geometry  
The pipe geometry (defined in a JSON file) is passed to both meshers (fluid and thermal).

#### Discretisation  
* The domain is partitioned into **bilinear quadrilateral (Q1) elements**; the fluid and thermal meshes use the same number of elements, so a single design field serves both solvers.

#### Numerical integration  
Element contributions are evaluated with **$2\times2$** Gauss quadrature.

This mesh serves as the baseline for the finite-element **CHT** solve (advection–diffusion in the fluid region, pure diffusion in the solid region) and for the subsequent **density-based topology-optimisation** runs involving fluid and thermal objectives.
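The $2\times2$ Gauss rule can be illustrated with a standalone NumPy sketch that integrates a function over one Q1 element. The counter-clockwise node ordering and the helper names are assumptions for illustration, not toflux internals.

```python
import numpy as np


def q1_shape(xi: float, eta: float):
    """Bilinear shape functions and reference derivatives (counter-clockwise nodes)."""
    xi_n = np.array([-1.0, 1.0, 1.0, -1.0])
    eta_n = np.array([-1.0, -1.0, 1.0, 1.0])
    n = 0.25 * (1.0 + xi * xi_n) * (1.0 + eta * eta_n)
    dn_dxi = 0.25 * xi_n * (1.0 + eta * eta_n)
    dn_deta = 0.25 * eta_n * (1.0 + xi * xi_n)
    return n, np.vstack([dn_dxi, dn_deta])  # shapes (4,) and (2, 4)


def integrate_q1(node_coords: np.ndarray, f) -> float:
    """Integrate f(x, y) over one Q1 element with 2x2 Gauss-Legendre quadrature."""
    pts, wts = np.polynomial.legendre.leggauss(2)  # points +-1/sqrt(3), weights 1
    total = 0.0
    for xi, w_xi in zip(pts, wts):
        for eta, w_eta in zip(pts, wts):
            n, dn = q1_shape(xi, eta)
            jac = dn @ node_coords       # (2, 4) @ (4, 2) -> 2x2 Jacobian
            x, y = n @ node_coords       # physical coordinates of the Gauss point
            total += f(x, y) * np.linalg.det(jac) * w_xi * w_eta
    return float(total)


unit_square = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
print(integrate_q1(unit_square, lambda x, y: x * y))  # approximately 1/4
```

Two points per direction integrate polynomials up to cubic in each reference coordinate exactly, which covers the Q1 mass and stiffness integrands on rectangular elements.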


In [None]:
geom = _geom.BrepGeometry("toflux/brep/single_pipe.json")

nelx_desired, nely_desired, gauss_order = 120, 120, 2
fluid_mesh = _mesher.grid_mesh_brep(
  brep=geom,
  nelx_desired=nelx_desired,
  nely_desired=nely_desired,
  dofs_per_node=3,  # pressure, u-velocity, v-velocity
  gauss_order=gauss_order,
)

thermal_mesh = _mesher.grid_mesh_brep(
  brep=geom,
  nelx_desired=nelx_desired,
  nely_desired=nely_desired,
  dofs_per_node=1,  # temperature
  gauss_order=gauss_order,
)

_viz.plot_brep(geom)

## Material

For this CHT setup, we define distinct material properties for the **fluid** and **solid** phases, which are used in the interpolation scheme for topology optimization.

---

#### Fluid Phase Properties
| Symbol | Meaning | Value |
|---|---|---|
| $ \rho_f $ | Mass density | 1000 |
| $ \mu_f $ | Dynamic viscosity | $6.6\times10^{-3}$ |
| $ k_f $ | Thermal conductivity | 0.05 |
| $ c_{p,f} $ | Specific heat | 5000 |

---

#### Solid Phase Properties
| Symbol | Meaning | Value |
|---|---|---|
| $ \rho_s $ | Mass density | 1000 |
| $ k_s $ | Thermal conductivity | 0.5 |
| $ c_{p,s} $ | Specific heat | 5000 |

- **Material Interpolation:** In topology optimization, the material properties within the design domain are interpolated based on the design variable $\gamma \in [0,1]$. 
  - The effective thermal diffusivity **$\kappa(\gamma)$**, with $\kappa = k/(\rho c_p)$, is interpolated between its fluid and solid values. Since both phases share the same $\rho$ and $c_p$, this amounts to interpolating the thermal conductivity $k$.
  - A **Brinkman penalization** term **$\alpha(\gamma)$** (inverse permeability) is interpolated from a very large value in solid regions ($\gamma=1$), where it damps the velocity, to a small lower bound in fluid regions ($\gamma=0$), where the pure Navier-Stokes equations are approximately recovered.

For the nondimensional energy equation, the Péclet number is $ \mathrm{Pe}=\dfrac{U\,L}{\kappa} $, where the thermal diffusivity is $\kappa = k / (\rho c_p)$.
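With the property values from the tables, the diffusivities and the resulting Péclet numbers can be checked by hand. Because $\rho$ and $c_p$ are identical in both phases, the solid diffusivity is ten times the fluid one, so for the same $U$ and $L$ the solid Péclet number is ten times smaller. The velocity and length below are placeholders, not values taken from the notebook.

```python
def thermal_diffusivity(k: float, rho: float, cp: float) -> float:
    """kappa = k / (rho * c_p)."""
    return k / (rho * cp)


def peclet(u: float, length: float, kappa: float) -> float:
    """Pe = U * L / kappa."""
    return u * length / kappa


kappa_f = thermal_diffusivity(k=0.05, rho=1.0e3, cp=5.0e3)  # fluid, about 1e-8
kappa_s = thermal_diffusivity(k=0.5, rho=1.0e3, cp=5.0e3)   # solid, about 1e-7

u_placeholder, length_placeholder = 1.0e-4, 0.2  # hypothetical U and L
pe_f = peclet(u_placeholder, length_placeholder, kappa_f)
pe_s = peclet(u_placeholder, length_placeholder, kappa_s)
print(pe_f, pe_s)
```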


In [None]:
desired_mat_frac = 0.6

In [None]:
fluid_mat_prop = _mat.FluidMaterial(
  mass_density=1e3,
  dynamic_viscosity=6.6e-3,
  thermal_conductivity=0.05,
  specific_heat=5e3,
)

solid_mat_prop = _mat.FluidMaterial(
  mass_density=1e3, thermal_conductivity=0.5, specific_heat=5e3
)

In [None]:
min_inv_permeability = _mat.brinkman_bound(
  fluid_mat_prop.dynamic_viscosity, 100 * fluid_mesh.bounding_box.lx
)
max_inv_permeability = _mat.brinkman_bound(
  fluid_mat_prop.dynamic_viscosity, 1.0e-2 * fluid_mesh.bounding_box.lx
)
init_inv_permeability = _mat.brinkman_bound(
  fluid_mat_prop.dynamic_viscosity, 1.0e-1 * fluid_mesh.bounding_box.lx
)

inv_permeability_ext = _Ext(min=min_inv_permeability, max=max_inv_permeability)

diffusivity_ext = _Ext(min=fluid_mat_prop.diffusivity, max=solid_mat_prop.diffusivity)

In [None]:
init_ramp_penalty = _mat.calculate_initial_ramp_penalty(
  inv_permeability_ext, init_inv_permeability, desired_mat_frac
)

## Fluid Boundary Conditions

| Region | Type | Imposed values |
|---|---|---|
| **Inlet** | Dirichlet | Parabolic profile $u=U_c(1-s^2)$, $v=0$, with $s\in[-1,1]$ the normalised coordinate across the inlet |
| **Outlet** | Dirichlet (p, v) | $p = 0$, $v = 0$ |
| **All Walls** | No-slip (Dirichlet) | $u = 0$, $v = 0$ |
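The parabolic inlet profile can be sampled the same way as in the fluid boundary-condition cell, which evaluates it at `n + 1` equally spaced points on $[-1, 1]$ for `n` inlet faces. The NumPy sketch also checks that the mean inlet velocity is $\tfrac{2}{3}U_c$, a useful relation when converting $U_c$ into a flow rate. The helper names are hypothetical.

```python
import numpy as np


def parabolic_inlet(u_c: float, num_faces: int):
    """Sample u = U_c * (1 - s**2) at the num_faces + 1 inlet nodes, s in [-1, 1]."""
    s = np.linspace(-1.0, 1.0, num_faces + 1)
    return s, u_c * (1.0 - s**2)


def mean_velocity(s: np.ndarray, u: np.ndarray) -> float:
    """Trapezoidal average of u over the sampled interval."""
    integral = np.sum(0.5 * (u[1:] + u[:-1]) * np.diff(s))
    return float(integral / (s[-1] - s[0]))


s, u = parabolic_inlet(u_c=1.0, num_faces=200)
print(u[0], u[-1], mean_velocity(s, u))  # zero at the walls, mean close to 2/3
```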

**Characteristic Velocity**

The inlet velocity is set to achieve a target Reynolds number:

$$
U_c \;=\; \frac{\mathrm{Re}\,\nu}{L_{char}},
\qquad
\text{Re} = 3.0,\;
\quad
\nu = \frac{\mu}{\rho},
\quad
L_{char} = \text{inlet width}.
$$
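Evaluating this formula with the fluid properties from the material table gives the magnitude of the inlet velocity. The domain height `ly = 1.0` is an assumption made only for this sketch (the notebook reads the geometry from `single_pipe.json`); the inlet width is one fifth of the height, matching `inlet_channel_width` in the code.

```python
def characteristic_velocity(re: float, mu: float, rho: float, char_length: float) -> float:
    """U_c = Re * nu / L_char with the kinematic viscosity nu = mu / rho."""
    nu = mu / rho
    return re * nu / char_length


ly = 1.0                 # assumed domain height, for illustration only
char_length = ly / 5.0   # inlet width, as in the notebook
u_c = characteristic_velocity(re=3.0, mu=6.6e-3, rho=1.0e3, char_length=char_length)
print(u_c)  # about 9.9e-5 for ly = 1
```

The small value follows from the low Reynolds number ($\mathrm{Re}=3$) chosen for the benchmark.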

---
## Thermal Boundary Conditions

| Region | Type | Imposed values |
|---|---|---|
| **Inlet** | Dirichlet | $T = 0$ |
| **Outlet** | Neumann (default) | $\nabla T \cdot \mathbf{n} = 0$ |
| **Heated Walls** | Dirichlet | $T = 10$ |
| **Adiabatic Walls** | Neumann (default) | $\nabla T \cdot \mathbf{n} = 0$ |
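The thermal solve works with a nondimensional temperature: the boundary-condition cell divides the prescribed temperatures by `temp_ext.range`, and the objective maps the solution back with `temperature_non_dim * temp_ext.range + temp_ext.min`. A minimal NumPy sketch of that scaling, assuming the usual form $\theta = (T - T_{in})/(T_w - T_{in})$ (which coincides with the notebook's scaling because $T_{in}=0$); the helper names are hypothetical:

```python
import numpy as np


def to_nondim(temp, t_in: float, t_wall: float) -> np.ndarray:
    """theta = (T - T_in) / (T_wall - T_in): the inlet maps to 0, heated walls to 1."""
    return (np.asarray(temp, dtype=float) - t_in) / (t_wall - t_in)


def to_dim(theta, t_in: float, t_wall: float) -> np.ndarray:
    """Inverse map T = theta * (T_wall - T_in) + T_in."""
    return np.asarray(theta, dtype=float) * (t_wall - t_in) + t_in


theta = to_nondim([0.0, 5.0, 10.0], t_in=0.0, t_wall=10.0)
print(theta, to_dim(theta, 0.0, 10.0))
```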

In [None]:
reynolds_num = 3.0
inlet_channel_width = 1.0 / 5.0
char_length = inlet_channel_width * fluid_mesh.bounding_box.ly
char_velocity = reynolds_num * fluid_mat_prop.kinematic_viscosity / char_length


face_tol = fluid_mesh.elem_size[0] * 0.5

# inlet condition
inlet_faces = _bc.identify_faces(fluid_mesh, edges=[geom.edges[1]], tol=face_tol)
n = len(inlet_faces)
x_nodes = jnp.linspace(-1.0, 1.0, n + 1)
u_nodes = char_velocity * (1.0 - x_nodes**2)
face_node_vals = jnp.stack([u_nodes[1:], u_nodes[:-1]], axis=1)
u_vel = (_FluidField.U_VEL, face_node_vals)
v_vel = (_FluidField.V_VEL, jnp.zeros_like(face_node_vals))
inlet_face_val = [u_vel, v_vel]

# outlet condition
outlet_faces = _bc.identify_faces(fluid_mesh, edges=[geom.edges[5]], tol=face_tol)
n = len(outlet_faces)
v_vel = (_FluidField.V_VEL, jnp.zeros(n))
pres = (_FluidField.PRESSURE, jnp.zeros(n))
outlet_face_val = [v_vel, pres]

# wall condition
wall_faces = _bc.identify_faces(
  fluid_mesh,
  edges=[
    geom.edges[0],
    geom.edges[2],
    geom.edges[3],
    geom.edges[4],
    geom.edges[6],
    geom.edges[7],
  ],
  tol=face_tol,
)
n = len(wall_faces)
u_vel = (_FluidField.U_VEL, jnp.zeros(n))
v_vel = (_FluidField.V_VEL, jnp.zeros(n))
wall_face_val = [u_vel, v_vel]

inlet_bc = _bc.DirichletBC(elem_faces=inlet_faces, values=inlet_face_val, name="Inlet")
outlet_bc = _bc.DirichletBC(
  elem_faces=outlet_faces, values=outlet_face_val, name="Outlet"
)
wall_bc = _bc.DirichletBC(elem_faces=wall_faces, values=wall_face_val, name="Wall")

fluid_bc_list = [wall_bc, inlet_bc, outlet_bc]
fluid_bc = _bc.process_boundary_conditions(fluid_bc_list, fluid_mesh)

_viz.plot_bc(fluid_bc_list, fluid_mesh)


In [None]:
temp_inlet = 0.0
temp_wall = 10.0
temp_ext = _Ext(min=temp_inlet, max=temp_wall)
# inlet condition
inlet_faces = _bc.identify_faces(thermal_mesh, edges=[geom.edges[1]], tol=face_tol)
n = len(inlet_faces)
tv = (_TempField.TEMPERATURE, ((temp_inlet - temp_ext.min) / temp_ext.range) * jnp.ones(n))
inlet_face_val = [tv]

# wall condition
wall_faces = _bc.identify_faces(
  thermal_mesh, edges=[geom.edges[3], geom.edges[7]], tol=face_tol
)
n = len(wall_faces)
tv = (_TempField.TEMPERATURE, ((temp_wall - temp_ext.min) / temp_ext.range) * jnp.ones(n))
wall_face_val = [tv]

inlet_bc = _bc.DirichletBC(elem_faces=inlet_faces, values=inlet_face_val, name="Inlet")
wall_bc = _bc.DirichletBC(elem_faces=wall_faces, values=wall_face_val, name="Wall")


thermal_bcs_list = [inlet_bc, wall_bc]
thermal_bc = _bc.process_boundary_conditions(thermal_bcs_list, thermal_mesh)

_viz.plot_bc(thermal_bcs_list, thermal_mesh)

# Inlet and outlet domains

Here we select the elements at the inlet and outlet, which are used to compute the thermal power entering and leaving the domain.
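The element selection can be sketched in plain NumPy as a boolean mask on element centres. This is an assumed re-implementation with inclusive bounds, not the toflux function `_mesher.compute_point_indices_in_box`; the notebook also pads its boxes with `face_tol`, which this toy example does not need.

```python
import numpy as np


def indices_in_box(points, x_range, y_range) -> np.ndarray:
    """Indices of 2D points inside the closed box x_range x y_range."""
    points = np.asarray(points, dtype=float)
    inside = (
        (points[:, 0] >= x_range[0]) & (points[:, 0] <= x_range[1])
        & (points[:, 1] >= y_range[0]) & (points[:, 1] <= y_range[1])
    )
    return np.nonzero(inside)[0]


# Element centres of a hypothetical 10 x 10 grid on the unit square.
h = 0.1
centres_1d = (np.arange(10) + 0.5) * h
xc, yc = np.meshgrid(centres_1d, centres_1d, indexing="ij")
centres = np.column_stack([xc.ravel(), yc.ravel()])

# First element column, restricted to the middle fifth of the height (the inlet band).
inlet_idx = indices_in_box(centres, x_range=(0.0, h), y_range=(0.4, 0.6))
print(inlet_idx, centres[inlet_idx])
```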

In [None]:
in_box = _mesher.BoundingBox(
  x=_Ext(min=-face_tol, max=face_tol),
  y=_Ext(
    min=fluid_mesh.bounding_box.ly * 2.0 / 5.0 - 0.5 * face_tol,
    max=fluid_mesh.bounding_box.ly * 3.0 / 5.0 + 0.5 * face_tol,
  ),
)

out_box = _mesher.BoundingBox(
  x=_Ext(
    min=fluid_mesh.bounding_box.lx - face_tol,
    max=fluid_mesh.bounding_box.lx + face_tol,
  ),
  y=_Ext(
    min=fluid_mesh.bounding_box.ly * 2.0 / 5.0 - 0.5 * face_tol,
    max=fluid_mesh.bounding_box.ly * 3.0 / 5.0 + 0.5 * face_tol,
  ),
)

inlet_elems = _mesher.compute_point_indices_in_box(thermal_mesh.elem_centers, in_box)
outlet_elems = _mesher.compute_point_indices_in_box(thermal_mesh.elem_centers, out_box)


# Visualize the inlet and outlet domains

In [None]:
density_test = np.zeros((thermal_mesh.num_elems,))
density_test[inlet_elems] = 1.0
density_test[outlet_elems] = 2.0
_viz.plot_grid_mesh(thermal_mesh, density_test.reshape(-1))

# Solver

The incompressible Navier–Stokes–Brinkman system is **intrinsically non-linear** because of the convective term $(\mathbf u\cdot\nabla)\mathbf u$.
Our solver therefore employs a *damped/modified Newton–Raphson* loop; for details see `solver.py`. For a given velocity field, the energy equation is linear in temperature. By default, all our FEA solvers are implemented as nonlinear problems, so the solver expects settings for both the linear and the nonlinear solver.
Even when the physics is linear, the solver still performs a Newton–Raphson step; for a linear model this step converges in one iteration. 
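The one-iteration behaviour for linear physics can be seen in a generic damped Newton–Raphson loop. This NumPy sketch is a simplified stand-in, not the toflux `modified_newton_raphson_solve`: it rebuilds the full Jacobian at every step and uses a fixed damping factor and a residual-norm stopping test, all of which are assumptions. The names `max_iter` and `threshold` mirror the keys of the `nonlinear` settings, but their exact meaning in toflux may differ.

```python
import numpy as np


def damped_newton(residual, jacobian, x0, damping=1.0, max_iter=15, threshold=1e-9):
    """Solve residual(x) = 0 with x <- x + damping * dx, where J(x) dx = -r(x).

    Returns the final iterate and the number of iterations performed.
    """
    x = np.asarray(x0, dtype=float).copy()
    for iteration in range(1, max_iter + 1):
        dx = np.linalg.solve(jacobian(x), -residual(x))
        x = x + damping * dx
        if np.linalg.norm(residual(x)) < threshold:
            return x, iteration
    return x, max_iter


# Linear problem A x = b: one undamped Newton step gives the exact solution.
a_mat = np.array([[4.0, 1.0], [1.0, 3.0]])
b_vec = np.array([1.0, 2.0])
x_lin, it_lin = damped_newton(lambda x: a_mat @ x - b_vec, lambda x: a_mat, np.zeros(2))

# Nonlinear problem x**2 = 2: several iterations are needed.
x_nl, it_nl = damped_newton(lambda x: x**2 - 2.0, lambda x: np.diag(2.0 * x), np.ones(1))
print(it_lin, it_nl, x_nl)
```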


In [None]:
solver_settings = {
  "linear": {
    "solver": _solv.LinearSolvers.PETSC,
    "rtol": 1.0e-6,
    "petsc_solver": {},
  },
  "nonlinear": {"max_iter": 15, "threshold": 1.0e-9},
}
flow_solver = _fea_fluid.FluidSolver(
  fluid_mesh, fluid_bc, fluid_mat_prop, solver_settings=solver_settings
)


thermal_solve = _fea_thermal.FEA(
  mesh=thermal_mesh,
  material=solid_mat_prop,
  bc=thermal_bc,
  solver_settings=solver_settings,
)

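# Objective Function

The objective couples the two solves: the flow problem is solved first, and its velocity field drives the advection–diffusion (thermal) problem. The extents `fluid_obj_ext` and `thermal_obj_ext` define the ranges used to normalise $J_f$ and $J_{th}$, and `obj_weight` is the weight $W$.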

In [None]:
fluid_obj_ext = _Ext(min=1.9e-10, max=5e-09)
thermal_obj_ext = _Ext(min=500.0, max=6000.0)
obj_weight = 0.5

In [None]:
@functools.partial(jax.jit, static_argnames=("flow_solver", "thermal_solver"))
def objective_function(
  mat_frac: jnp.ndarray,
  flow_solver: _fea_fluid.FluidSolver,
  thermal_solver: _fea_thermal.FEA,
  fluid_ramp_penalty: float,
  thermal_ramp_penalty: float,
  press_vel_guess=None,
  temp_guess=None,
):
  """Computes the objective function and its gradient.

  This function computes the weighted objective which is a combination of the
  normalized dissipated power in the fluid and the normalized recoverable thermal
  power in the thermal field. The dissipated power is computed using the Brinkman
  interpolation of the material fraction, and the thermal power is computed using the
  Peclet number.

  Args:
    mat_frac: Material fraction array of shape (num_elems,). The array should contain
      values between 0 and 1, representing the material fraction at each element. Where
      0 indicates fluid and 1 indicates material. The material fraction can assume
      intermediate values between 0 and 1 during optimization.
    flow_solver: Fluid solver instance.
    thermal_solver: Thermal solver instance.
    fluid_ramp_penalty: Penalty factor for the Brinkman interpolation using a convex ramp
      function. The ramp penalty is  updated using a continuation scheme during
      optimization. In the beginning, the ramp penalty is set to a large value (which
      makes the mat frac vs Brinkman penalty convex which can be determined by
      brink_iter_factor) to allow fluid flow through the entire domain. As the
      optimization progresses, the ramp penalty is reduced to allow the material
      fraction to converge to material or fluid.
    press_vel_guess: Initial pressure-velocity field of shape (num_dofs,). This is used
      as the initial guess for the pressure-velocity field to Newton-Raphson
      iterations. It contains the Dirichlet boundary conditions applied to the
      pressure-velocity field.
    temp_guess: Initial temperature field of shape (num_dofs,). This is used as the
      initial guess for the temperature field to Newton-Raphson iterations. It contains
      the Dirichlet boundary conditions applied to the temperature field.

  Returns:
    A tuple containing the objective value and a tuple of element-wise
    pressure-velocity, temperature, pressure-velocity vector, and temperature vector.
  """

  def objective_wrapper(mat_frac):
    # solve the flow problem
    brinkman_penalty = _mat.compute_ramp_interpolation(
      prop=mat_frac,
      ramp_penalty=fluid_ramp_penalty,
      prop_ext=inv_permeability_ext,
      mode="convex",
    )
    press_vel = _solv.modified_newton_raphson_solve(
      flow_solver, press_vel_guess, brinkman_penalty
    )

    # get elem velocities
    elem_press_vel = press_vel[fluid_mesh.elem_dof_mat]  # (elems, num_dofs_per_elem)
    elem_u_vel = elem_press_vel[:, 1 :: fluid_mesh.nodes.dof_per_node]
    elem_v_vel = elem_press_vel[:, 2 :: fluid_mesh.nodes.dof_per_node]

    num_vel_dofs_per_elem = fluid_mesh.num_dim * fluid_mesh.elem_template.num_nodes
    elem_vel = jnp.zeros((fluid_mesh.num_elems, num_vel_dofs_per_elem))
    elem_vel = elem_vel.at[:, 0 :: fluid_mesh.num_dim].set(elem_u_vel)
    elem_vel = elem_vel.at[:, 1 :: fluid_mesh.num_dim].set(elem_v_vel)

    conv_velocity = elem_vel / char_velocity

    # solve the thermal problem
    eff_diffusivity = _mat.compute_ramp_interpolation(
      prop=mat_frac,
      ramp_penalty=thermal_ramp_penalty,
      prop_ext=diffusivity_ext,
      mode="concave",
    )
    peclet_number = char_length * char_velocity / eff_diffusivity
    temperature_non_dim = _solv.modified_newton_raphson_solve(
      thermal_solver,
      temp_guess,
      conv_velocity,
      peclet_number,
    )
    temperature_elem = (
      temperature_non_dim[thermal_mesh.elem_dof_mat].mean(axis=1) * temp_ext.range
      + temp_ext.min
    )
    therm_power = thermal_solver.thermal_power(
      temperature_elem, elem_u_vel, inlet_elems, outlet_elems
    )

    normalized_thermal_power = thermal_obj_ext.normalize_array(therm_power)

    elem_dissipated_power = jax.vmap(flow_solver.compute_elem_dissipated_power)(
      brinkman_penalty, elem_press_vel, fluid_mesh.elem_node_coords
    )
    dissipated_power = jnp.sum(elem_dissipated_power)
    normalized_dissipated_power = fluid_obj_ext.normalize_array(dissipated_power)
    # F = (1 - W) * normalized J_f - W * normalized J_th (thermal power is maximized)
    weighted_obj = (
      1.0 - obj_weight
    ) * normalized_dissipated_power - obj_weight * normalized_thermal_power

    return weighted_obj, (
      elem_press_vel,
      temperature_elem,
      press_vel,
      temperature_non_dim,
    )

  (obj, (elem_press_vel, temperature_elem, press_vel, temp)), d_obj = (
    jax.value_and_grad(objective_wrapper, has_aux=True)(mat_frac)
  )
  return obj, d_obj.reshape((-1, 1)), elem_press_vel, temperature_elem, press_vel, temp

# Constraint

We impose a minimum material fraction: the mean design variable (the solid fraction, for elements of equal volume) must not fall below the user-defined `min_mat_frac`. The constraint is written in the normalised form $g = 1 - \bar\gamma/\Phi_{min} \le 0$, which is equivalent to $\Phi_{min} - \bar\gamma \le 0$. As for the objective, the gradient is computed automatically with JAX's `value_and_grad`.
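Because the constraint is linear in $\gamma$, its gradient is available in closed form, which gives a quick check on the automatic-differentiation result: for $N$ elements of equal volume, $\partial g/\partial\gamma_j = -1/(N\,\Phi_{min})$ for every element. A minimal NumPy sketch with a hypothetical function name:

```python
import numpy as np


def min_fraction_constraint(gamma, min_frac: float):
    """g = 1 - mean(gamma) / min_frac (feasible when g <= 0) and its analytic gradient.

    Assumes elements of equal volume, so the solid fraction is the plain mean of gamma.
    """
    gamma = np.asarray(gamma, dtype=float)
    g = 1.0 - gamma.mean() / min_frac
    grad = np.full(gamma.shape, -1.0 / (gamma.size * min_frac))
    return g, grad


g_feasible, _ = min_fraction_constraint(np.full(100, 0.7), min_frac=0.6)    # g < 0
g_violated, grad = min_fraction_constraint(np.full(100, 0.3), min_frac=0.6)  # g > 0

# Finite-difference check of one gradient entry.
eps = 1e-6
gamma_pert = np.full(100, 0.3)
gamma_pert[0] += eps
fd = (min_fraction_constraint(gamma_pert, 0.6)[0] - g_violated) / eps
print(g_feasible, g_violated, grad[0], fd)
```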

In [None]:
def constraint_function(mat_frac: ArrayLike, min_mat_frac: float):
  """Computes the constraint for the material fraction.
  This constraint ensures that the mean material fraction is above a minimum threshold.

  Args:
    mat_frac: Material fraction array of shape (num_elems,). The array should contain
      values between 0 and 1, representing the material fraction at each element. Where
      0 indicates fluid and 1 indicates material. The material fraction can assume
      intermediate values between 0 and 1 during optimization.
    min_mat_frac: Minimum material fraction threshold ensures material presence in the
      design.

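  Returns:
    A tuple (cons, d_cons): the constraint value g = 1 - mean(mat_frac) / min_mat_frac
    as an array of shape (1,), feasible when g <= 0, and its gradient transposed for
    the MMA update.
  """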

  def constraint_wrapper(mat_frac):
    mean_mat_frac = jnp.mean(mat_frac)
    return 1.0 - (mean_mat_frac / min_mat_frac)

  mat_cons, d_mat_cons = jax.value_and_grad(constraint_wrapper)(mat_frac)
  return jnp.array([mat_cons]), d_mat_cons.T

In [None]:
def optimize_design(
  fluid_fe: _fea_fluid.FluidSolver,
  thermal_fe: _fea_thermal.FEA,
  min_mat_frac: float,
  max_iter: int,
  move_limit: float = 1e-2,
  kkt_tol: float = 1e-5,
  step_tol: float = 1e-5,
  plot_interval: int = 5,
):
  """Optimizes the design using the Method of Moving Asymptotes (MMA) optimization.

  The function initializes the design variables with the minimum material fraction and
    then iteratively updates them via the MMA approach. At each iteration, the objective
    function and its gradient are computed using JAX `value_and_grad` function based
    on the current material distribution. Additionally, a material constraint is enforced
    to ensure the design maintains at least the specified minimum material fraction.

  Args:
    fluid_fe: Fluid solver instance containing mesh, boundary conditions, and material
      properties.
    thermal_fe: Thermal solver instance containing mesh, boundary conditions, and material
      properties.
    min_mat_frac: Minimum allowed material fraction to enforce material presence in the
      design.
    max_iter: Maximum number of iterations for the MMA optimization loop.
    move_limit: Maximum allowable change in the design variable per iteration.
    kkt_tol: Tolerance for the Karush-Kuhn-Tucker (KKT) optimality criteria.
    step_tol: Tolerance for the optimization step size.
    plot_interval: Number of iterations between plot updates.

  Returns:
    mma_state: The final state of the MMA optimization, including the optimized design
      variables, iteration count, and convergence flag.
    history: A dictionary recording the history of the objective function values and
      constraint violations with keys:
      'obj'      - list of objective function values,
      'vol_cons' - list of volume constraint values.
  """
  design_var = min_mat_frac * np.ones((thermal_fe.mesh.num_elems, 1))
  num_design_var = design_var.shape[0]
  num_cons = 1
  lower_bound = np.zeros((num_design_var, 1))
  upper_bound = np.ones((num_design_var, 1))
  mma_params = _mma.MMAParams(
    max_iter=max_iter,
    kkt_tol=kkt_tol,
    step_tol=step_tol,
    move_limit=move_limit,
    num_design_var=num_design_var,
    num_cons=num_cons,
    lower_bound=lower_bound,
    upper_bound=upper_bound,
  )
  mma_state = _mma.init_mma(design_var, mma_params)

  history = {"obj": [], "vol_cons": []}

  press_vel = jnp.zeros((fluid_fe.mesh.num_dofs,))
  press_vel = press_vel.at[fluid_fe.bc["fixed_dofs"]].set(
    fluid_fe.bc["dirichlet_values"]
  )

  temp = jnp.zeros((thermal_fe.mesh.num_dofs,))
  temp = temp.at[thermal_fe.bc["fixed_dofs"]].set(thermal_fe.bc["dirichlet_values"])

  while not mma_state.is_converged:
    cont_param = min(20.0, 1.0 + 19.0 * (mma_state.epoch / 75.0) ** 2)
    fluid_ramp_penalty = init_ramp_penalty / cont_param
    thermal_ramp_penalty = min(0.3, 0.01 + mma_state.epoch * 8e-3)

    (
      objective,
      grad_obj,
      elem_press_vel,
      temperature_elem,
      press_vel_solved,
      temp_solved,
    ) = objective_function(
      mma_state.x,
      fluid_fe,
      thermal_fe,
      fluid_ramp_penalty,
      thermal_ramp_penalty,
      press_vel,
      temp,
    )
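    constr, grad_cons = constraint_function(mma_state.x, min_mat_frac)
    # record the convergence history returned to the caller
    history["obj"].append(float(objective))
    history["vol_cons"].append(float(constr[0]))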
    press_vel = jax.lax.stop_gradient(press_vel_solved)
    press_vel = press_vel.at[fluid_fe.bc["fixed_dofs"]].set(
      fluid_fe.bc["dirichlet_values"]
    )

    temp = jax.lax.stop_gradient(temp_solved)
    temp = temp.at[thermal_fe.bc["fixed_dofs"]].set(thermal_fe.bc["dirichlet_values"])

    mma_state = _mma.update_mma(
      mma_state, mma_params, objective, grad_obj, constr, grad_cons
    )

    u_vel_elem = np.mean(elem_press_vel[:, 1::3], axis=1)
    v_vel_elem = np.mean(elem_press_vel[:, 2::3], axis=1)
    vel_elem_mag = np.sqrt(u_vel_elem**2 + v_vel_elem**2)

    status = f"epoch {mma_state.epoch} J {objective:.2E} mc {constr[0]:.2E}"

    print(status)
    if mma_state.epoch % plot_interval == 0 or mma_state.epoch == 1:
      _viz.plot_grid_mesh(
        mesh=thermal_fe.mesh,
        field=mma_state.x.reshape(-1),
        cmap=_cmap,
      )

      _, ax = plt.subplots(1, 1)
      ax = _viz.plot_grid_mesh(
        mesh=fluid_fe.mesh, field=vel_elem_mag, ax=ax, colorbar=True
      )
      ax.set_title("Velocity Magnitude")
      plt.show()

  return mma_state, history
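The continuation in the loop above divides the Brinkman RAMP penalty by a factor that grows quadratically to 20 at epoch 75, while the thermal RAMP penalty grows linearly until it is capped at 0.3. Tabulating the schedule outside the optimiser makes it easier to inspect; the formulas are copied from `optimize_design`, and `init_ramp_penalty=100.0` below is a placeholder, not the value computed in the notebook.

```python
def continuation_schedule(epoch: int, init_ramp_penalty: float):
    """Fluid and thermal RAMP penalties used at a given MMA epoch."""
    cont_param = min(20.0, 1.0 + 19.0 * (epoch / 75.0) ** 2)
    fluid_ramp_penalty = init_ramp_penalty / cont_param
    thermal_ramp_penalty = min(0.3, 0.01 + epoch * 8e-3)
    return fluid_ramp_penalty, thermal_ramp_penalty


for epoch in (0, 25, 50, 75, 100):
    fluid_q, thermal_q = continuation_schedule(epoch, init_ramp_penalty=100.0)
    print(f"epoch {epoch:3d}: fluid q = {fluid_q:8.3f}, thermal q = {thermal_q:.3f}")
```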

In [None]:
mma_state, history = optimize_design(
  flow_solver,
  thermal_solve,
  min_mat_frac=desired_mat_frac,
  max_iter=100,
  move_limit=1.5e-1,
)