## Setup  

In this notebook we extend the **channel-flow** example to **density‑based topology optimisation (TO)** of a **bifurcated (Y‑shaped) artery** carrying shear‑thinning blood.  
For the forward Navier–Stokes–Brinkman implementation, refer to `solve_fluid.ipynb`.

$$
\begin{aligned}
\min_{\gamma}\;\;J &= 
\frac12 \int_{\Omega}
      \underbrace{\eta(\dot\gamma)\,
      \bigl(\nabla\mathbf u+\nabla\mathbf u^{\mathsf T}\bigr)
      :\bigl(\nabla\mathbf u+\nabla\mathbf u^{\mathsf T}\bigr)}_{\text{viscous dissipation}}
      \;d\Omega  \\
&\quad + \frac12\int_{\Omega}\alpha(\gamma)\,\mathbf u\!\cdot\!\mathbf u \;d\Omega
\\[6pt]
\text{s.\,t.}\;& 
\mathbf R(\gamma,\mathbf u,p)=\mathbf 0 \qquad\quad\;\; \text{(steady NS–Brinkman)}\\[2pt]
& g(\gamma)=1-\dfrac{\displaystyle\sum_e\gamma_e\,v_e}{V^{\ast}}\le 0 
\quad\;\;\;\; \text{(minimum solid-volume constraint)}
\end{aligned}
$$

where $\gamma$ is the element-wise solid density (design variable), $\dot\gamma$ the local shear rate on which the viscosity $\eta$ depends, $\alpha(\gamma)$ the Brinkman inverse-permeability coefficient, $v_e$ the volume of element $e$, and $V^{\ast}$ the minimum admissible solid volume.
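
To make the objective concrete, the sketch below evaluates its integrand at a single point and sums it over quadrature points. The helper names (`dissipation_density`, `quadrature_objective`) are hypothetical; this is not toflux's `compute_elem_dissipated_power`, whose implementation is not shown here, and the quadrature weights are assumed to already include the Jacobian determinant.

```python
import jax
import jax.numpy as jnp


def dissipation_density(grad_u, u, eta, alpha):
  """Integrand of J at one point: viscous plus Brinkman dissipation.

  grad_u: (2, 2) velocity gradient, grad_u[i, j] = d u_i / d x_j.
  u: (2,) velocity; eta: local viscosity; alpha: Brinkman coefficient.
  """
  sym = grad_u + grad_u.T                   # (grad u + grad u^T)
  viscous = 0.5 * eta * jnp.sum(sym * sym)  # 1/2 eta (A : A)
  brinkman = 0.5 * alpha * jnp.dot(u, u)    # 1/2 alpha u . u
  return viscous + brinkman


def quadrature_objective(grad_u_q, u_q, eta_q, alpha_q, weights):
  """Approximate J as sum_q w_q * integrand(x_q); weights include |det J|."""
  values = jax.vmap(dissipation_density)(grad_u_q, u_q, eta_q, alpha_q)
  return jnp.sum(weights * values)


# Usage: point 1 is simple shear (du/dy = 2, eta = 1), point 2 a pure Brinkman term
grad_u_q = jnp.stack([jnp.array([[0.0, 2.0], [0.0, 0.0]]), jnp.zeros((2, 2))])
u_q = jnp.array([[0.0, 0.0], [1.0, 2.0]])
eta_q = jnp.array([1.0, 0.0])
alpha_q = jnp.array([0.0, 3.0])
J = quadrature_objective(grad_u_q, u_q, eta_q, alpha_q, jnp.ones(2))
```

For simple shear the viscous part reduces to $\eta\,s^2$ with $s=\partial u_x/\partial y$, which is a quick sanity check on the $\tfrac12$ factor and the double contraction.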

### Rheology — Carreau–Yasuda model  

Blood is modelled as an incompressible non‑Newtonian fluid with viscosity  

$$
\eta(\dot\gamma)=\eta_{\infty}+(\eta_0-\eta_{\infty})
\left[1+(\lambda\dot\gamma)^{a}\right]^{\tfrac{\,n-1\,}{a}},
$$

where $\dot\gamma$ is the local shear rate and typical arterial‑blood parameters are
$\eta_0=0.56$ P, $\eta_{\infty}=0.0345$ P (i.e. $0.056$ and $0.00345$ Pa·s),
$n=0.3568$, $\lambda=3.313\;\text{s}$ and $a=2.0$.
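
A direct transcription of the Carreau–Yasuda law, as a minimal sketch: `carreau_yasuda_viscosity` is a hypothetical helper (not the `CarreauYasudaNonNewtonianFluid` class used later), with defaults taken from the blood parameters in this notebook ($\lambda=3.313$ s, as in the material table and code).

```python
import jax.numpy as jnp


def carreau_yasuda_viscosity(shear_rate, eta_0=0.56, eta_inf=0.0345,
                             lam=3.313, a=2.0, n=0.3568):
  """eta = eta_inf + (eta_0 - eta_inf) * [1 + (lam * rate)^a]^((n - 1) / a)."""
  factor = (1.0 + (lam * shear_rate) ** a) ** ((n - 1.0) / a)
  return eta_inf + (eta_0 - eta_inf) * factor


# Viscosity at a few shear rates (same units as the parameters)
rates = jnp.array([0.0, 1.0, 10.0, 100.0])
etas = carreau_yasuda_viscosity(rates)
```

Because $n<1$, the exponent $(n-1)/a$ is negative, so the viscosity decreases monotonically from $\eta_0$ at rest towards $\eta_{\infty}$ at high shear rates.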

### Bifurcated‑Artery Boundary Conditions  

* **Inlets (top and bottom branches):** fully‑developed parabolic velocity profile with peak speed $U_{\text{max}}$, which sets the Reynolds number $Re=\rho U_{\text{max}}D/\eta_{\infty}$.  
* **Outlet:** zero gauge pressure ($p=0$) and zero transverse velocity ($v=0$).  
* **Walls:** no‑slip ($\mathbf u=\mathbf 0$).  
* The geometry follows a symmetric Y‑branch with half‑angle $\theta$, in which the branch diameters equal the parent diameter (see `figures/artery.png`).  
<img src="../../figures/artery.png" alt="Bifurcated artery" style="width: 500px;"/>

### Design parametrisation  

We adopt the standard **element-wise density method**: the design vector
$\gamma$ is mapped to the Brinkman damping coefficient
$\alpha(\gamma)$ via the RAMP scheme.
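
One common convex RAMP form is sketched below, under the assumption that `_mat.compute_ramp_interpolation(..., mode="convex")` behaves similarly (its exact expression is not shown in this notebook): with penalty $q$, $\alpha(\gamma)=\alpha_{\min}+(\alpha_{\max}-\alpha_{\min})\,\gamma/\bigl(1+q(1-\gamma)\bigr)$. The bounds 0 and 2500 mirror `inv_permeability_ext`; `ramp_brinkman` is a hypothetical name.

```python
import jax.numpy as jnp


def ramp_brinkman(gamma, q, alpha_min=0.0, alpha_max=2500.0):
  """Convex RAMP map from solid density gamma in [0, 1] to Brinkman alpha.

  gamma = 0 (fluid) gives alpha_min and gamma = 1 (solid) gives alpha_max.
  For q > 0 the curve lies below the straight line between them, so
  intermediate densities stay comparatively permeable.
  """
  return alpha_min + (alpha_max - alpha_min) * gamma / (1.0 + q * (1.0 - gamma))


alpha = ramp_brinkman(jnp.linspace(0.0, 1.0, 5), q=50.0)
```

For $q=50$, a half-solid element gets $\alpha = 2500\cdot 0.5/26 \approx 48$, far below the linear value of 1250, which is what lets fluid pass through intermediate densities.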

### Optimiser  

We employ the **Method of Moving Asymptotes (MMA)**.  
Gradients $\partial J/\partial\gamma$ and
$\partial g/\partial\gamma$ are produced automatically with JAX’s AD.
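
The gradients rely on `jax.value_and_grad` with `has_aux=True`, which returns the scalar objective, its gradient, and auxiliary outputs that are carried along but not differentiated. The toy objective below is hypothetical; in `objective_function` the auxiliary outputs are the flow solution and the thresholded design.

```python
import jax
import jax.numpy as jnp


def toy_objective(gamma):
  """Stand-in for J(gamma): returns (scalar value, auxiliary data)."""
  state = jnp.tanh(gamma)      # pretend 'state' computed from the design
  value = jnp.sum(state ** 2)
  return value, state          # the aux output is returned, not differentiated


design = jnp.array([0.0, 0.5, 1.0])
(value, state), grad = jax.value_and_grad(toy_objective, has_aux=True)(design)
```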

NOTE: Following the convention of the solid-mechanics topology-optimisation literature, solid regions are assigned a pseudodensity of 1 and void/fluid regions a pseudodensity of 0. Much of the fluid-mechanics literature uses the opposite convention.





In [None]:
import functools
import jax.numpy as jnp
import jax
from jax.typing import ArrayLike
import matplotlib.pyplot as plt
import numpy as np

import toflux.src.geometry as _geom
import toflux.src.mesher as _mesher
import toflux.src.material as _mat
import toflux.src.utils as _utils
import toflux.src.bc as _bc
import toflux.experiments.non_newtonian.fe_non_newtonian as _fea
import toflux.src.solver as _solv
import toflux.src.mma as _mma
import toflux.src.viz as _viz

jax.config.update("jax_enable_x64", True)
plt.rcParams.update(_viz.high_res_plot_settings)

_Field = _fea.FluidField
_cmap = _viz.fluid_cmap

# Define Geometry and Mesh 

Below we build the **artery** benchmark commonly used in fluid TO.

#### Geometry  
The artery is described as a boundary representation (B-rep) in a JSON file (`toflux/brep/artery.json`), loaded with `BrepGeometry` and passed to the mesher.

#### Discretisation  
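* The domain $\Omega$ is partitioned into **bilinear quadrilateral (Q1) elements**.  
* Both primary fields, velocity $\mathbf u=(u_x,u_y)$ **and** pressure $p$, are interpolated with the same Q1 shape functions (equal-order Q1/Q1 formulation). Since equal-order velocity–pressure pairs do not satisfy the inf-sup (LBB) condition, stabilisation terms are added to the weak form.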

#### Numerical integration  
Element contributions (mass, convection, Brinkman, viscous and stabilisation terms) are evaluated with $2 \times 2$ Gauss quadrature.
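
A minimal NumPy sketch of the Q1 shape functions and the $2\times2$ Gauss–Legendre rule on the reference square $[-1,1]^2$ with coordinates $(\xi,\eta)$. The counter-clockwise node ordering and the helper names are assumptions, not taken from toflux. The rule uses points $\pm1/\sqrt3$ with unit weights and integrates polynomials up to degree 3 in each direction exactly.

```python
import numpy as np

# 2x2 Gauss-Legendre rule on the reference square [-1, 1]^2
_G = 1.0 / np.sqrt(3.0)
GAUSS_POINTS = np.array([[-_G, -_G], [_G, -_G], [_G, _G], [-_G, _G]])
GAUSS_WEIGHTS = np.ones(4)


def q1_shape_functions(xi, eta):
  """Bilinear shape functions at (xi, eta); nodes counter-clockwise from (-1, -1)."""
  return 0.25 * np.array([
    (1.0 - xi) * (1.0 - eta),
    (1.0 + xi) * (1.0 - eta),
    (1.0 + xi) * (1.0 + eta),
    (1.0 - xi) * (1.0 + eta),
  ])


def integrate_reference(f):
  """Integrate f(xi, eta) over [-1, 1]^2 with the 2x2 Gauss rule."""
  return sum(w * f(xi, eta) for (xi, eta), w in zip(GAUSS_POINTS, GAUSS_WEIGHTS))


# Each shape function integrates to 1 (area 4 shared by 4 nodes)
element_integrals = integrate_reference(q1_shape_functions)
```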

This mesh serves as the baseline for the finite-element Navier–Stokes solve and for any subsequent density-based topology-optimisation runs.

In [None]:
geom = _geom.BrepGeometry("toflux/brep/artery.json")

mesh = _mesher.grid_mesh_brep(
  brep=geom,
  nelx_desired=140,
  nely_desired=100,
  dofs_per_node=3,
  gauss_order=2,
)

density = jnp.zeros(mesh.num_elems)
_viz.plot_grid_mesh(mesh, density)
_viz.plot_brep(geom)

### Material — Carreau–Yasuda blood model  

In this example we model blood as an incompressible, shear‑thinning fluid
whose viscosity obeys the Carreau–Yasuda law

$$
\eta(\dot\gamma) =
\eta_{\infty} + (\eta_0 - \eta_{\infty})
\left[1 + (\lambda\dot\gamma)^{a}\right]^{\tfrac{\,n-1\,}{a}} .
$$

| Symbol | Meaning | Value (unit) |
|--------|---------|--------------|
| $\eta_{\infty}$ | Infinite‑shear viscosity | **$3.45\times10^{-2}$** P |
| $\eta_{0}$ | Zero‑shear viscosity | **$5.6\times10^{-1}$** P |
| $\lambda$ | Time constant | **$3.313$** s |
| $a$ | Yasuda exponent | **2.0** (dimensionless) |
| $n$ | Power‑law index | **0.3568** (dimensionless) |

Viscosities are given in poise (1 P = 0.1 Pa·s).

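
The Carreau–Yasuda law needs a scalar shear rate. A common definition for generalised Newtonian fluids is $\dot\gamma=\sqrt{2\,\mathbf D:\mathbf D}$ with $\mathbf D=\tfrac12(\nabla\mathbf u+\nabla\mathbf u^{\mathsf T})$; whether toflux uses exactly this definition is an assumption. The small `eps` under the square root is also an assumption, added so that the gradient stays finite where the flow is at rest.

```python
import jax
import jax.numpy as jnp


def shear_rate(grad_u, eps=1e-12):
  """Scalar shear rate sqrt(2 D:D + eps), D = (grad u + grad u^T) / 2."""
  strain_rate = 0.5 * (grad_u + grad_u.T)
  return jnp.sqrt(2.0 * jnp.sum(strain_rate * strain_rate) + eps)


# Simple shear du/dy = 2 gives a shear rate of 2
rate = shear_rate(jnp.array([[0.0, 2.0], [0.0, 0.0]]))
# The gradient at rest stays finite thanks to eps
d_rate = jax.grad(shear_rate)(jnp.zeros((2, 2)))
```

A pure rotation, $\nabla\mathbf u$ antisymmetric, has $\mathbf D=\mathbf 0$ and therefore no shear rate, which is why the symmetric part is used.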


In [None]:
# Carreau–Yasuda parameters for blood, matching the table above
blood_mat = _mat.CarreauYasudaNonNewtonianFluid(
  eta_inf=3.45e-2, eta_0=0.56, lam=3.313, a=2.0, n=0.3568
)

## Cross non-Newtonian material model

In [None]:
# blood_mat = _mat.CrossNonNewtonianFluid(
#   eta_inf=0.1, eta_0=0.01, lam=0.025, n=2.0
# )

## Other fluid material properties

In [None]:

mat_params = _mat.FluidMaterial(
  mass_density=1.058,
  dynamic_viscosity=blood_mat.eta_inf,
)

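# Minimum admissible solid (material) fraction of the design
min_mat_frac = 0.76
# Range of the Brinkman inverse permeability alpha: min for fluid, max for solid
inv_permeability_ext = _utils.Extent(min=0.0, max=2500.0)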

# Boundary Conditions



| Region | Type | Imposed values |
|--------|------|----------------|
| **Inlets, top and bottom (left vertical edges)** | Dirichlet | $u$ = parabolic profile, $v = 0$ |
| **Outlet (right vertical edge)** | Dirichlet on $v$, $p$ | $p = 0$, $v = 0$ |
| **Walls (all remaining edges)** | No-slip (Dirichlet) | $u = 0$, $v = 0$ |

**Characteristic velocity**

The inlet velocity is chosen to give a Reynolds number of 7.68:

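$$
U_c \;=\; \frac{\mathrm{Re}\,\nu}{H},
\qquad
\mathrm{Re} = 7.68,
\quad
\nu = \frac{\eta_{\infty}}{\rho},
\quad
H = \text{characteristic length}.
$$

Here $U_c$ is the peak speed of the parabolic inlet profile, $\eta_{\infty}$ is the viscosity passed as `dynamic_viscosity` to the material parameters, and $H$ is taken as one sixth (`inlet_fraction`) of the domain height.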
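
A quick arithmetic check of the characteristic velocity, using $\eta_{\infty}=0.0345$ and $\rho=1.058$ from the material parameters and a hypothetical $H=1$ (the notebook derives $H$ from the mesh). Because $U_c$ is the peak of a parabolic profile, the mean inlet speed is $2U_c/3$, so a Reynolds number based on the mean speed would be $2/3$ of 7.68.

```python
def characteristic_velocity(reynolds, eta_inf, rho, length):
  """U_c = Re * nu / H with kinematic viscosity nu = eta_inf / rho."""
  nu = eta_inf / rho
  return reynolds * nu / length


def mean_parabolic_speed(u_peak, num_samples=2000):
  """Mean of u(s) = u_peak * (1 - s^2) over s in [-1, 1] (midpoint rule)."""
  total = 0.0
  for i in range(num_samples):
    s = -1.0 + (2 * i + 1) / num_samples  # midpoint of the i-th sub-interval
    total += u_peak * (1.0 - s * s)
  return total / num_samples


# Hypothetical H = 1; eta_inf and rho as in the material parameters
u_c = characteristic_velocity(7.68, 0.0345, 1.058, 1.0)
u_mean = mean_parabolic_speed(u_c)  # close to 2 * u_c / 3
```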


In [None]:
inlet_fraction = 1.0 / 6.0
char_length = inlet_fraction * mesh.bounding_box.ly
reynolds_num = 7.68
char_velocity = reynolds_num * mat_params.kinematic_viscosity / (char_length)

In [None]:
face_tol = mesh.elem_size[0] * 0.5

# inlet top condition
inlet_top_faces = _bc.identify_faces(mesh, edges=[geom.edges[6]], tol=face_tol)
n = len(inlet_top_faces)
x_nodes = jnp.linspace(-1.0, 1.0, n + 1)
u_profile = char_velocity * (1.0 - x_nodes**2)
face_node_vals = jnp.stack([u_profile[1:], u_profile[:-1]], axis=1)
u_vel = (_Field.U_VEL, face_node_vals)
v_vel = (_Field.V_VEL, jnp.zeros_like(face_node_vals))
inlet_top_face_val = [u_vel, v_vel]

# inlet bottom condition
inlet_btm_faces = _bc.identify_faces(mesh, edges=[geom.edges[2]], tol=face_tol)
n = len(inlet_btm_faces)
x_nodes = jnp.linspace(-1.0, 1.0, n + 1)
u_profile = char_velocity * (1.0 - x_nodes**2)
face_node_vals = jnp.stack([u_profile[1:], u_profile[:-1]], axis=1)
u_vel = (_Field.U_VEL, face_node_vals)
v_vel = (_Field.V_VEL, jnp.zeros_like(face_node_vals))
inlet_btm_face_val = [u_vel, v_vel]

# wall condition
wall_edge_nos = [0, 4, 8, 1, 3, 5, 7, 11, 13, 14, 10, 9, 15]
wall_edges = [geom.edges[i] for i in wall_edge_nos]
wall_faces = _bc.identify_faces(mesh, edges=wall_edges, tol=face_tol)
n = len(wall_faces)
u_vel = (_Field.U_VEL, jnp.zeros(n))
v_vel = (_Field.V_VEL, jnp.zeros(n))
wall_face_val = [u_vel, v_vel]


# outlet condition
outlet_faces = _bc.identify_faces(mesh, edges=[geom.edges[12]], tol=face_tol)
n = len(outlet_faces)
v_vel = (_Field.V_VEL, jnp.zeros(n))
pres = (_Field.PRESSURE, jnp.zeros(n))
outlet_face_val = [v_vel, pres]


inlet_top_bc = _bc.DirichletBC(
  elem_faces=inlet_top_faces, values=inlet_top_face_val, name="inlet_top"
)
inlet_btm_bc = _bc.DirichletBC(
  elem_faces=inlet_btm_faces, values=inlet_btm_face_val, name="inlet_btm"
)
outlet_bc = _bc.DirichletBC(
  elem_faces=outlet_faces, values=outlet_face_val, name="outlet"
)

wall_bc = _bc.DirichletBC(
  elem_faces=wall_faces, values=wall_face_val, name="wall"
)

bcs_list = [
  wall_bc,
  inlet_top_bc,
  inlet_btm_bc,
  outlet_bc,
]

bc = _bc.process_boundary_conditions(
  bcs_list,
  mesh,
)

In [None]:
fig, ax = plt.subplots(figsize=(12, 6))
_viz.plot_grid_mesh(mesh, ax=ax, colorbar=False)
_viz.plot_bc(bcs_list, mesh, ax=ax)
plt.show()

# Solver

The incompressible Navier–Stokes–Brinkman system is **intrinsically non-linear** because of the convective term and the shear-rate-dependent (non-Newtonian) viscosity.
The solver therefore employs a *damped/modified Newton–Raphson* loop; see `solver.py` for details.
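
For intuition, a generic damped Newton–Raphson iteration in JAX on a small algebraic system. It is not `_solv.modified_newton_raphson_solve`, whose damping and convergence logic are not shown here; the fixed `damping` factor and the residual-norm stopping test are assumptions.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def damped_newton(residual, x0, damping=1.0, max_iter=50, tol=1e-10):
  """Solve residual(x) = 0 with a damped Newton-Raphson iteration.

  Each step solves J(x) dx = -R(x) with the Jacobian from forward-mode AD
  and updates x <- x + damping * dx. A 'modified' Newton method would
  instead reuse an older Jacobian for several steps.
  """
  x = x0
  for _ in range(max_iter):
    r = residual(x)
    if jnp.linalg.norm(r) < tol:  # converged on the residual norm
      break
    jac = jax.jacfwd(residual)(x)
    dx = jnp.linalg.solve(jac, -r)
    x = x + damping * dx
  return x


# Usage: a small system with root (sqrt(2), 3)
root = damped_newton(
  lambda v: jnp.array([v[0] ** 2 - 2.0, v[1] - 3.0]), jnp.array([1.0, 0.0])
)
```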

In [None]:
solver_settings = {
  "linear": {
    "solver": _solv.LinearSolvers.PETSC,
    "petsc_solver": {},
  },
  "nonlinear": {"max_iter": 10, "threshold": 1.0e-8},
}
flow_solver = _fea.FluidSolver(
  mesh=mesh,
  bc=bc,
  material=mat_params,
  non_newtonian_mat=blood_mat,
  solver_settings=solver_settings,
)

# Initialize

Having set up the FEA solver with the mesh, material, boundary conditions and solver settings, we provide an initial guess of the solution: a zero vector in which the Dirichlet values are imposed on the fixed DOFs.

In [None]:
init_press_vel = jnp.zeros((mesh.num_dofs,))
init_press_vel = init_press_vel.at[bc["fixed_dofs"]].set(bc["dirichlet_values"])

In [None]:
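# DOFs are interleaved per node as (p, u, v): extract the u-velocity
u_velocity = init_press_vel[1 : mesh.num_dofs : 3]
# Average the nodal values over each element for plotting
u_vel_elem = np.mean(u_velocity[mesh.elem_nodes], axis=1)
_viz.plot_grid_mesh(mesh, u_vel_elem)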

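# Objective

The objective function evaluates the total dissipated power of a design: the material fraction is passed through a threshold filter, mapped to the Brinkman coefficient with a convex RAMP interpolation, the non-linear flow problem is solved, and the element-wise dissipated power is summed. Its gradient with respect to the design is computed with JAX's `value_and_grad`. The optimization flowchart is illustrated below: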
<img src="../../figures/flowchart_non_newtonian.png" alt="flowchart" style="width: 800px;"/>
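
The design first passes through `_utils.threshold_filter(mat_frac, beta=thresh_beta)`. Its implementation is not shown; a common choice is the smoothed Heaviside (tanh) projection below, assumed here with projection threshold $0.5$. `tanh_projection` is a hypothetical name.

```python
import jax.numpy as jnp


def tanh_projection(x, beta, threshold=0.5):
  """Smoothed Heaviside projection of densities x in [0, 1].

  Maps 0 -> 0 and 1 -> 1; values above the threshold are pushed towards 1
  and values below towards 0, more sharply as beta grows.
  """
  num = jnp.tanh(beta * threshold) + jnp.tanh(beta * (x - threshold))
  den = jnp.tanh(beta * threshold) + jnp.tanh(beta * (1.0 - threshold))
  return num / den


projected = tanh_projection(jnp.array([0.0, 0.5, 0.76, 1.0]), beta=4.0)
```

As $\beta$ grows the map approaches a step function, which is why $\beta$ is a natural candidate for continuation.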

In [None]:
@functools.partial(jax.jit, static_argnames=("flow_solver",))
def objective_function(
  mat_frac: jnp.ndarray,
  flow_solver: _fea.FluidSolver,
  ramp_penalty: float,
  thresh_beta: float,
  init_press_vel=None,
):
  """Computes the objective function and its gradient.

  The objective is the total dissipated power in the flow field, i.e. the sum of
  the dissipated power in each element. Each element contribution combines the
  Brinkman term, whose coefficient is obtained from the material fraction through
  a convex RAMP interpolation, and the viscous term, whose viscosity is computed
  from the element shear rate with the non-Newtonian fluid model.

  Args:
    mat_frac: Material fraction per element (num_elems values), in [0, 1]. A value
      of 0 indicates fluid and 1 indicates solid material; intermediate values may
      occur during optimization.
    flow_solver: Fluid solver instance (a static argument of `jax.jit`).
    ramp_penalty: Penalty factor of the convex RAMP interpolation between the
      material fraction and the Brinkman coefficient. It can be updated with a
      continuation scheme: a large initial value makes the interpolation strongly
      convex, so intermediate material fractions stay permeable and fluid can flow
      through the entire domain; reducing it as the optimization progresses drives
      the material fraction towards pure fluid or solid.
    thresh_beta: Sharpness of the threshold filter applied to the material
      fraction before the interpolation.
    init_press_vel: Initial pressure-velocity field of shape (num_dofs,), used as
      the initial guess for the Newton-Raphson iterations. It must contain the
      Dirichlet boundary values.

  Returns:
    A tuple (obj, d_obj, press_vel, mat_frac_thresh): the objective value, its
    gradient with respect to `mat_frac` reshaped to (num_elems, 1), the
    pressure-velocity field and the thresholded material fraction.
  """
  mesh = flow_solver.mesh

  def objective_wrapper(mat_frac):
    # Push intermediate densities towards 0 or 1
    mat_frac_thresh = _utils.threshold_filter(mat_frac, beta=thresh_beta)
    # Material fraction -> Brinkman inverse permeability (convex RAMP)
    brinkman_penalty = _mat.compute_ramp_interpolation(
      prop=mat_frac_thresh,
      ramp_penalty=ramp_penalty,
      prop_ext=inv_permeability_ext,
      mode="convex",
    )
    # Non-linear Navier-Stokes-Brinkman solve for the (p, u, v) DOFs
    press_vel = _solv.modified_newton_raphson_solve(
      flow_solver, init_press_vel, brinkman_penalty
    )
    # Element-wise dissipated power, summed into the scalar objective
    obj_args = (
      brinkman_penalty,
      press_vel[mesh.elem_dof_mat],
      mesh.elem_node_coords,
    )
    elem_obj = jax.vmap(flow_solver.compute_elem_dissipated_power)(*obj_args)
    obj = jnp.sum(elem_obj)
    return obj, (press_vel, mat_frac_thresh)

  # Objective, its gradient w.r.t. the design, and auxiliary outputs
  (obj, (press_vel, mat_frac_thresh)), d_obj = jax.value_and_grad(
    objective_wrapper, has_aux=True
  )(mat_frac)
  return obj, d_obj.reshape((-1, 1)), press_vel, mat_frac_thresh

# Constraint

We impose a single material constraint that enforces a minimum solid (material) fraction in the design. With the user-defined `min_mat_frac` $=\gamma_{\min}$, the constraint reads $g = 1 - \bar\gamma/\gamma_{\min} \le 0$, where $\bar\gamma$ is the mean of the thresholded material fraction. As for the objective, the gradient is computed automatically with JAX's `value_and_grad`.
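
Without the threshold filter, the constraint gradient is easy to check by hand: for $N$ elements and $\gamma_{\min}$ = `min_mat_frac`, $\partial g/\partial\gamma_e=-1/(N\,\gamma_{\min})$. The sketch below verifies this with `jax.value_and_grad`; `min_fraction_constraint` is a hypothetical simplification of `constraint_function` that omits the filter.

```python
import jax
import jax.numpy as jnp


def min_fraction_constraint(x, min_frac):
  """g(x) = 1 - mean(x) / min_frac; feasible when g <= 0."""
  return 1.0 - jnp.mean(x) / min_frac


x = jnp.full((10,), 0.8)  # uniform design above the 0.76 minimum
g, dg = jax.value_and_grad(min_fraction_constraint)(x, 0.76)
```

Since $g<0$ here, this design is feasible, and every element has the same sensitivity because the mean weights all elements equally.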

In [None]:
def constraint_function(
  mat_frac: ArrayLike, min_mat_frac: float, thresh_beta: float
):
  """Computes the minimum-material-fraction constraint and its gradient.

  The constraint g = 1 - mean(mat_frac_thresh) / min_mat_frac is satisfied
  (g <= 0) when the mean thresholded material fraction is at least
  `min_mat_frac`.

  Args:
    mat_frac: Material fraction per element, in [0, 1]. A value of 0 indicates
      fluid and 1 indicates solid material; intermediate values may occur during
      optimization.
    min_mat_frac: Minimum material fraction, which ensures material presence in
      the design.
    thresh_beta: Sharpness of the threshold filter applied before averaging.

  Returns:
    A tuple (g, dg): the constraint value as an array of shape (1,) and its
    gradient with respect to `mat_frac`, transposed (shape (1, num_elems) when
    `mat_frac` has shape (num_elems, 1)).
  """
  def constraint_wrapper(mat_frac):
    mat_frac = _utils.threshold_filter(mat_frac, beta=thresh_beta)
    vol_frac = jnp.mean(mat_frac)
    return 1.0 - (vol_frac / min_mat_frac)

  vol_cons, d_vol_cons = jax.value_and_grad(constraint_wrapper)(mat_frac)
  return jnp.array([vol_cons]), d_vol_cons.T

# Optimize
Here, we define the optimization loop. We use the MMA optimizer. We begin the optimization with a uniform design corresponding to the minimum allowed material/solid fraction. 
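
The loop below keeps `ramp_penalty` and `thresh_beta` fixed, while continuation schedules are left in comments. A sketch of how those commented formulas could be packaged, assuming they are evaluated once per MMA epoch; `continuation_schedule` is a hypothetical helper.

```python
def continuation_schedule(epoch):
  """RAMP penalty q and projection sharpness beta for a given MMA epoch.

  q decreases linearly from 100 to a floor of 10; beta increases linearly
  from 1 to a cap of 64 (the formulas commented out in optimize_design).
  """
  ramp_penalty = max(10.0, 100.0 - 0.1 * epoch)
  thresh_beta = min(64.0, 1.0 + 0.3 * epoch)
  return ramp_penalty, thresh_beta


ramp_penalty, thresh_beta = continuation_schedule(100)
```

Calling `continuation_schedule(mma_state.epoch)` at the top of the loop would replace the two fixed values.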

In [None]:
def optimize_design(
  fe: _fea.FluidSolver,
  min_mat_frac: float,
  max_iter: int,
  move_limit: float = 1e-2,
  kkt_tol: float = 1e-3,
  step_tol: float = 1e-3,
  plot_interval: int = 5,
):
  """Optimizes the design using the Method of Moving Asymptotes (MMA) optimization.

  The function initializes the design variables with the minimum material fraction and 
    then iteratively updates them via the MMA approach. At each iteration, the objective 
    function and its gradient are computed using JAX `value_and_grad` function based 
    on the current material distribution. Additionally, a material constraint is enforced
    to ensure the design maintains at least the specified minimum material fraction. 

  Args:
    fe: Fluid solver instance containing mesh, boundary conditions, and material 
      properties.
    min_mat_frac: Minimum allowed material fraction to enforce material presence in the 
      design.
    max_iter: Maximum number of iterations for the MMA optimization loop.
    move_limit: Maximum allowable change in the design variable per iteration.
    kkt_tol: Tolerance for the Karush-Kuhn-Tucker (KKT) optimality criteria.
    step_tol: Tolerance for the optimization step size.
    plot_interval: Number of iterations between plot updates.

  Returns:
    mma_state: The final state of the MMA optimization, including the optimized design 
      variables, iteration count, and convergence flag.
    history: A dictionary recording the history of the objective function values and 
      constraint violations with keys:
      'obj'      - list of objective function values,
      'vol_cons' - list of volume constraint values.
  """
  design_var = min_mat_frac * np.ones((fe.mesh.num_elems, 1))
  num_design_var = design_var.shape[0]
  num_cons = 1
  lower_bound = np.zeros((num_design_var, 1))
  upper_bound = np.ones((num_design_var, 1))
  history = {"obj": [], "vol_cons": []}

  mma_params = _mma.MMAParams(
    max_iter=max_iter,
    kkt_tol=kkt_tol,
    step_tol=step_tol,
    move_limit=move_limit,
    num_design_var=num_design_var,
    num_cons=num_cons,
    lower_bound=lower_bound,
    upper_bound=upper_bound,
  )
  mma_state = _mma.init_mma(design_var, mma_params)

  press_vel = jnp.zeros((fe.mesh.num_dofs,))
  press_vel = press_vel.at[fe.bc["fixed_dofs"]].set(fe.bc["dirichlet_values"])

  while not mma_state.is_converged:
    # Fixed continuation parameters for this run; schedules that could be used instead:
    #   ramp_penalty = max(10., 100.0 - 0.1 * mma_state.epoch)
    #   thresh_beta = min(64., 1. + mma_state.epoch * 0.3)
    ramp_penalty = 50.0
    thresh_beta = 4.0
    objective, grad_obj, press_vel, mat_frac = objective_function(
      mma_state.x, fe, ramp_penalty, thresh_beta, press_vel
    )
    constr, grad_cons = constraint_function(mma_state.x, min_mat_frac, thresh_beta)

    press_vel = jax.lax.stop_gradient(press_vel)
    press_vel = press_vel.at[fe.bc["fixed_dofs"]].set(fe.bc["dirichlet_values"])

    mma_state = _mma.update_mma(
      mma_state, mma_params, objective, grad_obj, constr, grad_cons
    )

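    # Record the convergence history returned to the caller
    history["obj"].append(float(objective))
    history["vol_cons"].append(float(constr[0]))
    status = f"epoch {mma_state.epoch} J {objective:.2E} mc {constr[0]:.2F}"
    print(status)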

    if mma_state.epoch % plot_interval == 0 or mma_state.epoch == 1:
      # Design, rounded to 0/1 for display
      _, ax = plt.subplots(1, 1)
      ax = _viz.plot_grid_mesh(
        mesh=fe.mesh,
        field=np.round(mat_frac).reshape(-1),
        ax=ax,
        colorbar=False,
        cmap=_cmap,
        val_range=(0.0, 1.0),
      )
      ax.set_xticks([])
      ax.set_yticks([])
      for spine in ax.spines.values():
        spine.set_visible(False)

      # Element averages of the interleaved (p, u, v) nodal DOFs
      press_vel_elem = press_vel[fe.mesh.elem_dof_mat]
      press_elem = np.mean(press_vel_elem[:, 0::3], axis=1)
      u_vel_elem = np.mean(press_vel_elem[:, 1::3], axis=1)
      v_vel_elem = np.mean(press_vel_elem[:, 2::3], axis=1)
      vel_elem_mag = np.sqrt(u_vel_elem**2 + v_vel_elem**2)
      _, ax = plt.subplots(1, 1)
      ax = _viz.plot_grid_mesh(mesh=fe.mesh, field=vel_elem_mag, ax=ax, colorbar=True)
      ax.set_title("Velocity Magnitude")

      _, ax = plt.subplots(1, 1)
      ax = _viz.plot_grid_mesh(mesh=fe.mesh, field=press_elem, ax=ax, colorbar=True)
      ax.set_title("Pressure")
      plt.show()

  return mma_state, history

In [None]:
mma_state, history = optimize_design(
  flow_solver, min_mat_frac=min_mat_frac, max_iter=300, move_limit=3e-2
)