### Double Pipe Boundary Conditions
To illustrate fluid topology optimization (TO), we consider a simple double-pipe problem, shown in the figure below:

<img src="../../../figures/double_pipe.png" alt="Channel" style="width: 500px;"/>


In [None]:
import functools
import jax.numpy as jnp
import jax
from jax.typing import ArrayLike
from flax import nnx
import optax
import matplotlib.pyplot as plt

import toflux.src.geometry as _geom
import toflux.src.mesher as _mesher
import toflux.src.material as _mat
import toflux.src.utils as _utils
import toflux.src.bc as _bc
import toflux.experiments.fluid.tounn.network as _network
import toflux.src.fe_fluid as _fea
import toflux.src.solver as _solv
import toflux.src.constrained_loss as _cons_loss
import toflux.src.viz as _viz

# Switch JAX to 64-bit floats; this must run before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)
# Notebook-wide plotting defaults provided by the viz module.
plt.rcParams.update(_viz.high_res_plot_settings)

# Short aliases used by the cells below.
_Ext, _Field, _cmap = _utils.Extent, _fea.FluidField, _viz.fluid_cmap

## Mesh

In [None]:
# Boundary representation (B-rep) of the double-pipe domain, loaded from JSON.
geom = _geom.BrepGeometry(brep_file="toflux/brep/double_pipe_long.json")

# Structured grid over the B-rep. The element counts are "desired" values
# (the mesher may adjust them). dofs_per_node=3 presumably holds (u, v, p),
# matching the FluidField components used for the boundary conditions;
# TODO confirm.
mesh = _mesher.grid_mesh_brep(
  brep=geom, nelx_desired=150, nely_desired=100, dofs_per_node=3, gauss_order=2
)

_viz.plot_brep(geom)

## Material

In [None]:
# Fluid properties (unit density and unit dynamic viscosity).
mat_params = _mat.FluidMaterial(mass_density=1.0, dynamic_viscosity=1.0)

# Uniform material-fraction field at the minimum allowed fraction.
min_mat_frac = 0.666
mat_frac = min_mat_frac * jnp.ones((mesh.num_elems,))

# Brinkman inverse-permeability values. Each comes from `brinkman_bound`
# evaluated at a length scale of 100x, 0.01x and 0.1x the domain length
# `lx`, in that order: (min, max, init).
min_inv_permeability, max_inv_permeability, init_inv_permeability = (
  _mat.brinkman_bound(mat_params.dynamic_viscosity, scale * mesh.bounding_box.lx)
  for scale in (100.0, 1.0e-2, 1.0e-1)
)

# Range used by the RAMP interpolation, plus the interpolation factor that
# ties the initial inverse permeability to the minimum material fraction.
inv_permeability_ext = _Ext(min=min_inv_permeability, max=max_inv_permeability)
brink_inter_factor = _mat.calculate_interpolation_factor(
  inv_permeability_ext, init_inv_permeability, min_mat_frac
)

## Boundary conditions

In [None]:
# Characteristic length: a fraction of the domain height (presumably the
# inlet width, given the name). The characteristic velocity then follows
# from Re = U * L / nu.
inlet_fraction = 1.0 / 6.0
char_length = inlet_fraction * mesh.bounding_box.ly
reynolds_num = 0.1667
char_velocity = reynolds_num * mat_params.kinematic_viscosity / char_length

In [None]:
# Faces within half an element width of an edge are attributed to that edge.
face_tol = mesh.elem_size[0] * 0.5


def _parabolic_inlet_values(faces):
  """Build Dirichlet values for a parabolic u-velocity inflow on `faces`.

  The profile u(s) = char_velocity * (1 - s**2) is sampled at len(faces) + 1
  equispaced points s in [-1, 1]. Face k receives the pair (u[k + 1], u[k]).
  The v-velocity is set to zero with the same (n_faces, 2) layout.
  Assumes `faces` are ordered along the edge; TODO confirm against
  `_bc.identify_faces`.
  """
  n_faces = len(faces)
  s = jnp.linspace(-1.0, 1.0, n_faces + 1)
  u_nodes = char_velocity * (1.0 - s**2)
  face_node_vals = jnp.stack([u_nodes[1:], u_nodes[:-1]], axis=1)
  return [
    (_Field.U_VEL, face_node_vals),
    (_Field.V_VEL, jnp.zeros_like(face_node_vals)),
  ]


def _zero_values(faces, fields):
  """Return one all-zero value array (length len(faces)) per field, in order."""
  n = len(faces)
  return [(field, jnp.zeros(n)) for field in fields]


# inlet-top condition: parabolic u-velocity, zero v-velocity
inlet_top_faces = _bc.identify_faces(mesh, edges=[geom.edges[3]], tol=face_tol)
inlet_top_face_val = _parabolic_inlet_values(inlet_top_faces)

# inlet-bottom condition: parabolic u-velocity, zero v-velocity
inlet_btm_faces = _bc.identify_faces(mesh, edges=[geom.edges[1]], tol=face_tol)
inlet_btm_face_val = _parabolic_inlet_values(inlet_btm_faces)

# wall condition: zero u- and v-velocity (no-slip).
# The edge order is kept as in the original selection.
wall_faces = _bc.identify_faces(
  mesh,
  edges=[geom.edges[i] for i in (0, 2, 4, 6, 8, 10, 5, 11)],
  tol=face_tol,
)
wall_face_val = _zero_values(wall_faces, (_Field.U_VEL, _Field.V_VEL))

# outflow condition: zero v-velocity and zero pressure
outflow_faces = _bc.identify_faces(
  mesh, edges=[geom.edges[7], geom.edges[9]], tol=face_tol
)
outflow_face_val = _zero_values(outflow_faces, (_Field.V_VEL, _Field.PRESSURE))


inlet_top_bc = _bc.DirichletBC(
  elem_faces=inlet_top_faces, values=inlet_top_face_val, name="Inlet Top"
)
inlet_btm_bc = _bc.DirichletBC(
  elem_faces=inlet_btm_faces, values=inlet_btm_face_val, name="Inlet Bottom"
)
outflow_bc = _bc.DirichletBC(
  elem_faces=outflow_faces, values=outflow_face_val, name="Outflow"
)
wall_bc = _bc.DirichletBC(elem_faces=wall_faces, values=wall_face_val, name="Wall")


bcs_list = [
  wall_bc,
  inlet_top_bc,
  inlet_btm_bc,
  outflow_bc,
]

bc = _bc.process_boundary_conditions(bcs_list, mesh)

## Solver and loss

In [None]:
# Linear solves go through PETSc. The outer nonlinear iteration is capped at
# 10 iterations with a 1e-6 threshold.
solver_settings = dict(
  linear=dict(
    solver=_solv.LinearSolvers.PETSC,
    rtol=1.0e-5,
    petsc_solver={},
  ),
  nonlinear=dict(max_iter=10, threshold=1.0e-6),
)
flow_solver = _fea.FluidSolver(mesh, bc, mat_params, solver_settings=solver_settings)

# Log-barrier parameters for the volume constraint in the combined loss.
loss_params = _cons_loss.LogBarrierParams(t0=3.0, mu=1.04)

## Network

In [None]:
# Symmetry map about the mesh's vertical center line (presumably a
# mirror of the y-coordinate across it; TODO confirm in _network.Symmetry).
symm_map = _network.Symmetry(sym_xz_coord=mesh.bounding_box.y.center)

# Fourier feature projection of the 2-D (symmetry-mapped) coordinates.
fourier_proj = _network.FourierProjection(
  num_input_dim=2, num_terms=100, min_radius=0.25, max_radius=5.0
)

# MLP: input width is twice the number of Fourier terms, then one hidden
# layer of 20 neurons and a single output (the material-fraction logit).
topnet = _network.TopNet(
  num_neurons=[fourier_proj.num_terms * 2, 20, 1],
  rngs=nnx.Rngs(0),
  use_batch_norm=False,
)

# Optimize
Here, we define the optimization loop. The network weights are trained with the Adam optimizer (from `optax`, combined with global-norm gradient clipping). We begin the optimization with a uniform design corresponding to the minimum allowed material/solid fraction. The optimization flowchart is illustrated below:
<img src="../../../figures/flowchart_tounn.png" alt="flowchart" style="width: 800px;"/>

In [None]:
@nnx.jit
def loss_fn(
  net: _network.TopNet,
  mesh_xy: jnp.ndarray,
  epoch: int,
  max_vol_frac: float,
  obj_0: float,
  ramp_penalty: float,
  press_vel_guess: ArrayLike,
):
  """Evaluate the constrained topology-optimization loss for `net`.

  Pipeline: network output -> sigmoid material fraction -> Brinkman penalty
  (convex `compute_ramp_interpolation`) -> modified Newton-Raphson flow solve
  -> total dissipated power -> log-barrier combined loss with a single
  inequality volume constraint.

  Args:
    net: Network mapping projected element coordinates to a logit per element.
    mesh_xy: Network input features for the element centers.
    epoch: Iteration counter forwarded to `combined_loss`.
    max_vol_frac: Volume-fraction limit used in the constraint.
    obj_0: Normalization for the dissipated-power objective.
    ramp_penalty: RAMP interpolation parameter.
    press_vel_guess: Initial guess for the flow solve.

  Returns:
    `(loss, (mat_frac, press_vel, dissipated_power, volcons))`, where
    `volcons` is the 1-element array `1 - mean(mat_frac) / max_vol_frac`.
  """
  logits = net(mesh_xy)
  mat_frac = jax.nn.sigmoid(logits).reshape((-1,))

  brinkman_penalty = _mat.compute_ramp_interpolation(
    prop=mat_frac,
    ramp_penalty=ramp_penalty,
    prop_ext=inv_permeability_ext,
    mode="convex",
  )

  press_vel = _solv.modified_newton_raphson_solve(
    flow_solver, press_vel_guess, brinkman_penalty
  )

  # Per-element dissipated power, vectorized over elements.
  elem_power_fn = jax.vmap(flow_solver.compute_elem_dissipated_power)
  elem_diss_pow = elem_power_fn(
    brinkman_penalty, press_vel[mesh.elem_dof_mat], mesh.elem_node_coords
  )
  dissipated_power = elem_diss_pow.sum()

  volcons = jnp.array([1.0 - mat_frac.mean() / max_vol_frac])

  normalized_obj = dissipated_power / obj_0
  loss = _cons_loss.combined_loss(
    normalized_obj,
    volcons,
    [_cons_loss.ConstraintTypes.INEQUALITY],
    [loss_params],
    epoch,
  )

  aux = (mat_frac, press_vel, dissipated_power, volcons)
  return loss, aux

In [None]:
def _plot_density(density: jnp.ndarray) -> None:
  """Show the element-wise material fraction on the mesh, without axes decoration."""
  _, ax = plt.subplots(1, 1)
  ax = _viz.plot_grid_mesh(
    mesh=mesh,
    field=density.reshape(-1),
    ax=ax,
    colorbar=False,
    cmap=_cmap,
    val_range=(0.0, 1.0),
  )
  ax.set_aspect("equal")
  ax.spines[["top", "right", "left", "bottom"]].set_visible(False)
  ax.set_xticks([])
  ax.set_yticks([])
  plt.show()
  plt.pause(1e-6)


def _apply_dirichlet(press_vel: jnp.ndarray) -> jnp.ndarray:
  """Return `press_vel` with the fixed DOFs overwritten by their Dirichlet values."""
  return press_vel.at[bc["fixed_dofs"]].set(bc["dirichlet_values"])


def optimize_design(
  net: _network.TopNet,
  max_vol_frac: float = 0.6,
  lr: float = 1.0e-2,
  max_iter: int = 100,
) -> dict:
  """Train `net` so that its material-fraction field minimizes dissipated power.

  Args:
    net: Network to optimize. It is updated in place by the optimizer.
    max_vol_frac: Volume-fraction limit passed to `loss_fn`.
    lr: Adam learning rate. Gradients are clipped to global norm 1.0 first.
    max_iter: Number of optimization iterations.

  Returns:
    Convergence history with one entry per iteration:
      "diss_pow": total dissipated power,
      "vol_frac": mean material fraction of the design,
      "vol_cons": volume-constraint value returned by `loss_fn`.
  """
  # Network inputs: symmetry-mapped element centers lifted to Fourier features.
  mesh_xy = symm_map.apply(jnp.array(mesh.elem_centers))
  mesh_xy = fourier_proj.apply(mesh_xy)

  opt = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(lr),
  )
  # Optimize the network that was passed in (previously the global `topnet`
  # was used, so any other `net` argument was never updated).
  optimizer = nnx.Optimizer(net, opt)
  grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)

  obj_0 = 1.0
  convg_hist = {"diss_pow": [], "vol_frac": [], "vol_cons": []}

  # Solver warm start: zeros everywhere except the prescribed Dirichlet DOFs.
  press_vel = _apply_dirichlet(jnp.zeros((mesh.num_dofs,)))

  for epoch in range(max_iter):
    # Continuation: the RAMP-penalty divisor grows quadratically from 1 to 20
    # over the first 75 iterations, then stays at 20.
    cont_param = min(20.0, 1.0 + 19.0 * (epoch / 75.0) ** 2)
    ramp_penalty = brink_inter_factor / cont_param

    (loss, (density, press_vel, diss_pow, vol_cons)), grad_loss = grad_fn(
      net, mesh_xy, epoch, max_vol_frac, obj_0, ramp_penalty, press_vel
    )

    # Reuse this solution as the next warm start (no gradient through it).
    press_vel = _apply_dirichlet(jax.lax.stop_gradient(press_vel))

    optimizer.update(grad_loss)

    # The aux constraint value is 1 - mean/max, not the volume fraction itself.
    vol_frac = jnp.mean(density)

    # save and print
    print(f"iter {epoch}, loss {loss:.2E}")
    print(
      f"diss_pow {diss_pow:.2E}, vol_frac {vol_frac:.2F}, vol_cons {vol_cons[0]:.2E}"
    )

    convg_hist["diss_pow"].append(diss_pow)
    convg_hist["vol_frac"].append(vol_frac)
    convg_hist["vol_cons"].append(vol_cons)

    # Renormalize the objective by the dissipated power at iterations 0 and 10.
    if epoch in (0, 10):
      obj_0 = jax.lax.stop_gradient(diss_pow)

    if epoch % 10 == 0:
      _plot_density(density)

  return convg_hist

In [None]:
# Run the optimization on the module-level network with a 0.67 volume limit.
convg_hist = optimize_design(net=topnet, max_vol_frac=0.67)