In [None]:
import functools
import numpy as np
import jax.numpy as jnp
import jax
from jax.typing import ArrayLike
import matplotlib.pyplot as plt

import toflux.src.geometry as _geom
import toflux.src.utils as _utils
import toflux.src.mesher as _mesher
import toflux.src.material as _mat
import toflux.src.bc as _bc

import toflux.src.fe_thermal as _fea_thermal
import toflux.src.fe_fluid as _fea_fluid
import toflux.src.solver as _solv
import toflux.src.mma as _mma
import toflux.src.viz as _viz


# Use 64-bit floats by default for all JAX arrays created from here on.
# NOTE(review): this runs after the toflux imports above; arrays those modules may
# create at import time would not be promoted. Consider enabling x64 right after
# `import jax`.
jax.config.update("jax_enable_x64", True)

# Short aliases used throughout the notebook.
_Ext = _utils.Extent
_FluidField = _fea_fluid.FluidField
_TempField = _fea_thermal.ThermalField
_cmap = _viz.fluid_cmap


## Mesh

In [None]:
geom = _geom.BrepGeometry("toflux/brep/single_pipe.json")

# Both physics use the same 50x50 grid over the same B-rep; they differ only in
# the number of DOFs per node (3 for the flow problem, 1 for temperature).
_grid_mesh = functools.partial(
  _mesher.grid_mesh_brep,
  brep=geom,
  nelx_desired=50,
  nely_desired=50,
  gauss_order=2,
)
fluid_mesh = _grid_mesh(dofs_per_node=3)
thermal_mesh = _grid_mesh(dofs_per_node=1)

# Preview the (empty) grid together with the B-rep outline.
density = jnp.zeros(fluid_mesh.num_elems)
_viz.plot_grid_mesh(fluid_mesh, density)
_viz.plot_brep(geom)

In [None]:
# Circular density filter whose cutoff distance is 2% of the bounding-box diagonal.
# NOTE(review): dens_filter is not applied anywhere below; the filtering lines in
# the objective/constraint are commented out.
_filter_cutoff = 0.02 * thermal_mesh.bounding_box.diag_length
dens_filter = _utils.create_density_filter(
  thermal_mesh.elem_centers,
  cutoff_distance=_filter_cutoff,
  filter_type=_utils.Filters.CIRCULAR,
)

## Material

In [None]:
# Fluid and thermal (solid-phase) material properties.
# NOTE(review): units are not stated in this notebook; the values look SI-like
# (e.g. mass density 1000) -- confirm.
fluid_mat = _mat.FluidMaterial(
  mass_density=1.0e3,
  dynamic_viscosity=6.6e-3,
  specific_heat=5.0e3,
  thermal_conductivity=1.0e-5,
)

thermal_mat = _mat.ThermalMaterial(
  mass_density=1.0e3,
  specific_heat=5.0e3,
  thermal_conductivity=1.0e-4,
)

In [None]:
# Brinkman inverse-permeability bounds, each derived from the fluid viscosity and
# a length scale expressed as a multiple of the bounding-box length lx.
_mu = fluid_mat.dynamic_viscosity
_lx = fluid_mesh.bounding_box.lx

min_inv_permeability = _mat.brinkman_bound(_mu, 100 * _lx)
max_inv_permeability = _mat.brinkman_bound(_mu, 1.0e-2 * _lx)
init_inv_permeability = _mat.brinkman_bound(_mu, 1.0e-1 * _lx)

inv_permeability_ext = _Ext(min=min_inv_permeability, max=max_inv_permeability)

In [None]:
# Target material fraction: used for the uniform initial design and as the
# volume-constraint bound passed to optimize_design.
desired_mat_frac = 0.6

# RAMP factor for the Brinkman interpolation. NOTE(review): presumably chosen so
# that the interpolated inverse permeability at desired_mat_frac equals
# init_inv_permeability -- confirm against _mat.calculate_interpolation_factor.
brink_inter_factor = _mat.calculate_interpolation_factor(
  inv_permeability_ext,
  init_inv_permeability,
  desired_mat_frac,
)


## Fluid boundary conditions

In [None]:
# Inlet velocity from the Reynolds number, Re = U * L / nu, with L the inlet width.
reynolds_num = 0.9
inlet_fraction = 1.0 / 5.0  # the inlet spans one fifth of the domain height
char_length = inlet_fraction * fluid_mesh.bounding_box.ly
char_velocity = reynolds_num * fluid_mat.kinematic_viscosity / char_length

In [None]:
# Tolerance for matching element faces to B-rep edges: half an element width.
face_tol = 0.5 * fluid_mesh.elem_size[0]

# Inlet (edge 1): parabolic u-profile peaking at char_velocity, zero v.
inlet_faces = _bc.identify_faces(fluid_mesh, edges=[geom.edges[1]], tol=face_tol)
profile_x = jnp.linspace(-1.0, 1.0, len(inlet_faces) + 1)
profile_u = char_velocity * (1.0 - profile_x**2)
# One row per face: [profile value at point i+1, profile value at point i].
face_node_vals = jnp.stack([profile_u[1:], profile_u[:-1]], axis=1)
inlet_face_val = [
  (_FluidField.U_VEL, face_node_vals),
  (_FluidField.V_VEL, jnp.zeros_like(face_node_vals)),
]

# Outlet (edge 5): zero v-velocity and zero pressure.
outlet_faces = _bc.identify_faces(fluid_mesh, edges=[geom.edges[5]], tol=face_tol)
num_outlet_faces = len(outlet_faces)
outlet_face_val = [
  (_FluidField.V_VEL, jnp.zeros(num_outlet_faces)),
  (_FluidField.PRESSURE, jnp.zeros(num_outlet_faces)),
]

# Walls: no-slip (u = v = 0) on the listed wall edges.
wall_edge_ids = (0, 2, 3, 4, 6, 7)
wall_faces = _bc.identify_faces(
  fluid_mesh,
  edges=[geom.edges[i] for i in wall_edge_ids],
  tol=face_tol,
)
num_wall_faces = len(wall_faces)
wall_face_val = [
  (_FluidField.U_VEL, jnp.zeros(num_wall_faces)),
  (_FluidField.V_VEL, jnp.zeros(num_wall_faces)),
]

fluid_bc_list = [
  _bc.DirichletBC(elem_faces=inlet_faces, values=inlet_face_val, name="inlet"),
  _bc.DirichletBC(elem_faces=outlet_faces, values=outlet_face_val, name="outlet"),
  _bc.DirichletBC(elem_faces=wall_faces, values=wall_face_val, name="wall"),
]
inlet_bc, outlet_bc, wall_bc = fluid_bc_list
fluid_bc = _bc.process_boundary_conditions(fluid_bc_list, fluid_mesh)

_viz.plot_bc(fluid_bc_list, fluid_mesh)

## Thermal boundary conditions

In [None]:
def inlet_cond(node_coords):
  """Return whether every given node lies on the inlet segment of the left edge.

  The inlet band is x <= dx/2 and 2/5*ly - dy <= y <= 3/5*ly + dy, where dx, dy are
  the thermal mesh element sizes and ly its bounding-box height.
  NOTE(review): appears unused in this notebook (BCs use _bc.identify_faces).

  Args:
    node_coords: Array of (num_nodes, 2) node coordinates, x in column 0 and y in
      column 1.

  Returns: A boolean scalar array.
  """
  # Element sizes of the thermal mesh; assumes elem_size is ordered (dx, dy), as
  # elem_size[0] is used as the x-size elsewhere in the notebook.
  dx, dy = thermal_mesh.elem_size[0], thermal_mesh.elem_size[1]
  ly = thermal_mesh.bounding_box.ly
  left = jnp.all(node_coords[:, 0] <= dx / 2.0)
  start = jnp.all(node_coords[:, 1] >= 2.0 * ly / 5.0 - dy)
  end = jnp.all(node_coords[:, 1] <= 3.0 * ly / 5.0 + dy)
  # Element-wise `&` (not Python `and`) so the predicate also works under jit/vmap.
  return left & start & end


def top_face_cond(node_coords):
  """Return whether every given node lies on the top edge of the thermal domain.

  A node counts as "on top" when y >= ly - dy/2 (ly: bounding-box height,
  dy: element height). NOTE(review): appears unused in this notebook.

  Args:
    node_coords: Array of (num_nodes, 2) node coordinates, y in column 1.

  Returns: A boolean scalar array.
  """
  # Element height; assumes elem_size is ordered (dx, dy).
  dy = thermal_mesh.elem_size[1]
  return jnp.all(node_coords[:, 1] >= thermal_mesh.bounding_box.ly - dy / 2.0)


def btm_face_cond(node_coords):
  """Return whether every given node lies on the bottom edge (y <= dy/2).

  dy is the element height of the thermal mesh.
  NOTE(review): appears unused in this notebook.

  Args:
    node_coords: Array of (num_nodes, 2) node coordinates, y in column 1.

  Returns: A boolean scalar array.
  """
  # Element height; assumes elem_size is ordered (dx, dy).
  dy = thermal_mesh.elem_size[1]
  return jnp.all(node_coords[:, 1] <= dy / 2.0)

In [None]:
def _uniform_temperature(faces, value):
  """Dirichlet value list fixing TEMPERATURE to `value` on every face in `faces`."""
  return [(_TempField.TEMPERATURE, value * jnp.ones(len(faces)))]


# Reuses the fluid-mesh face tolerance; both meshes share the same 50x50 grid.
# Temperatures are presumably in kelvin.

# inlet condition (edge 1)
inlet_faces = _bc.identify_faces(thermal_mesh, edges=[geom.edges[1]], tol=face_tol)
inlet_face_val = _uniform_temperature(inlet_faces, 273.0)

# top condition (edge 3)
top_faces = _bc.identify_faces(thermal_mesh, edges=[geom.edges[3]], tol=face_tol)
top_face_val = _uniform_temperature(top_faces, 283.0)

# bottom condition (edge 7)
bottom_faces = _bc.identify_faces(thermal_mesh, edges=[geom.edges[7]], tol=face_tol)
btm_face_val = _uniform_temperature(bottom_faces, 283.0)

inlet_bc = _bc.DirichletBC(elem_faces=inlet_faces, values=inlet_face_val, name="Inlet")
top_bc = _bc.DirichletBC(elem_faces=top_faces, values=top_face_val, name="Top")
bottom_bc = _bc.DirichletBC(elem_faces=bottom_faces, values=btm_face_val, name="Btm")

thermal_bcs_list = [inlet_bc, top_bc, bottom_bc]
_viz.plot_bc(thermal_bcs_list, thermal_mesh)

thermal_bc = _bc.process_boundary_conditions(thermal_bcs_list, thermal_mesh)

In [None]:
# One-element-wide strips at the left (inlet) and right (outlet) ends of the domain,
# spanning the band 2/5*ly .. 3/5*ly. Elements whose centers fall inside them are
# used to sample inflow/outflow temperature and velocity.
_bbox = fluid_mesh.bounding_box
_band_lo = _bbox.ly * 2.0 / 5.0
_band_hi = _bbox.ly * 3.0 / 5.0
_strip_width = fluid_mesh.elem_size[0]

in_box = _mesher.BoundingBox(
  x=_Ext(min=0.0, max=_strip_width),
  y=_Ext(min=_band_lo, max=_band_hi),
)
out_box = _mesher.BoundingBox(
  x=_Ext(min=_bbox.lx - _strip_width, max=_bbox.lx),
  y=_Ext(min=_band_lo, max=_band_hi),
)

inlet_elems, outlet_elems = (
  _mesher.compute_point_indices_in_box(thermal_mesh.elem_centers, box)
  for box in (in_box, out_box)
)


In [None]:
# Sanity check: indices of the elements whose centers fall inside the inlet strip.
inlet_elems

In [None]:
# Visual check of the sampling strips: 0 = interior, 1 = inlet strip, 2 = outlet strip.
density = np.zeros(thermal_mesh.num_elems)
for strip_elems, marker in ((inlet_elems, 1.0), (outlet_elems, 2.0)):
  density[strip_elems] = marker
_viz.plot_grid_mesh(thermal_mesh, density.reshape(-1))


## Solver

In [None]:
# One settings dict shared by the flow and thermal solvers.
solver_settings = dict(
  linear=dict(
    solver=_solv.LinearSolvers.SCIPY_SPARSE,
    rtol=1.0e-5,
    petsc_solver={},
  ),
  nonlinear=dict(max_iter=10, threshold=1.0e-12),
)

flow_solver = _fea_fluid.FluidSolver(
  fluid_mesh, fluid_bc, fluid_mat, solver_settings=solver_settings
)
thermal_solver = _fea_thermal.FEA(
  mesh=thermal_mesh,
  bc=thermal_bc,
  material=thermal_mat,
  solver_settings=solver_settings,
)

## Fluid objective extents (initial-design response used to scale the objective)

In [None]:
# Initial guess: zeros with the Dirichlet values imposed on the fixed DOFs.
press_vel = (
  jnp.zeros((fluid_mesh.num_dofs,))
  .at[fluid_bc["fixed_dofs"]]
  .set(fluid_bc["dirichlet_values"])
)

# Uniform initial design at the target material fraction, mapped to a Brinkman penalty.
density = desired_mat_frac * np.ones((fluid_mesh.num_elems,))
brinkman_penalty = _mat.compute_ramp_interpolation(
  prop=density,
  ramp_penalty=brink_inter_factor,
  prop_ext=inv_permeability_ext,
  mode="convex",
)

press_vel = _solv.modified_newton_raphson_solve(flow_solver, press_vel, brinkman_penalty)

# Total dissipated power of the initial design (sum of per-element contributions).
elem_dissip_pow = jax.vmap(flow_solver.compute_elem_dissipated_power)(
  brinkman_penalty, press_vel[fluid_mesh.elem_dof_mat], fluid_mesh.elem_node_coords
)
dissip_pow = jnp.sum(elem_dissip_pow)

# Hand-picked normalisation for the dissipated-power term of the objective.
fluid_obj_scale = 1.0e-9
print("dissip_pow", dissip_pow)

In [None]:
# Element-averaged pressure and velocities; the per-node flow DOFs are (p, u, v).
press_vel_elem = press_vel[fluid_mesh.elem_dof_mat]
press_elem, u_vel_elem, v_vel_elem = (
  np.mean(press_vel_elem[:, offset::3], axis=1) for offset in range(3)
)
vel_elem_mag = np.sqrt(u_vel_elem**2 + v_vel_elem**2)

_, ax = plt.subplots(1, 1)
ax = _viz.plot_grid_mesh(mesh=fluid_mesh, field=vel_elem_mag, ax=ax, colorbar=True)
ax.set_title("Velocity Magnitude")

In [None]:
# Pack the element nodal velocities as (u1, v1, u2, v2, ...) for the thermal solver.
# The per-node flow DOFs are (p, u, v): u sits at offset 1 and v at offset 2
# (offset 0 is pressure and must not be used here).
num_vel_dofs_per_elem = fluid_mesh.num_dim * fluid_mesh.elem_template.num_nodes
elem_vel = jnp.zeros((fluid_mesh.num_elems, num_vel_dofs_per_elem))
elem_vel = elem_vel.at[:, 0 :: fluid_mesh.num_dim].set(press_vel_elem[:, 1::3])
elem_vel = elem_vel.at[:, 1 :: fluid_mesh.num_dim].set(press_vel_elem[:, 2::3])

In [None]:
# Initial temperature guess with the Dirichlet values imposed.
temp = jnp.zeros((thermal_mesh.num_dofs,))
temp = temp.at[thermal_bc["fixed_dofs"]].set(thermal_bc["dirichlet_values"])

# Diffusivity interpolated between the fluid material (min) and thermal_mat (max).
diffusivity_ext = _utils.Extent(min=fluid_mat.diffusivity, max=thermal_mat.diffusivity)

eff_diffusivity = _mat.compute_ramp_interpolation(
  prop=density,
  ramp_penalty=1.0,
  prop_ext=diffusivity_ext,
  mode="concave",
)

temp = _solv.modified_newton_raphson_solve(
  thermal_solver, temp, elem_vel, eff_diffusivity
)

# Element-averaged temperature and u-velocity in the inlet/outlet sampling strips.
temp_elem = jnp.mean(temp[thermal_mesh.elem_dof_mat], axis=1)
elem_inlet_temp = temp_elem[inlet_elems]
elem_outlet_temp = temp_elem[outlet_elems]

elem_inlet_vel = u_vel_elem[inlet_elems]
elem_outlet_vel = u_vel_elem[outlet_elems]

therm_power_in = elem_inlet_temp * elem_inlet_vel
therm_power_out = elem_outlet_temp * elem_outlet_vel

# Heat picked up by the fluid: outflow minus inflow. This is the same sign
# convention as objective_function, where this quantity enters with a negative
# weight (i.e. it is rewarded).
therm_power = (
  jnp.sum(therm_power_out - therm_power_in)
  * fluid_mat.specific_heat
  * fluid_mat.mass_density
)
print("Thermal power", therm_power)

# Normalisation for the thermal term of the objective (read as a global by
# objective_function): the magnitude at this uniform initial design, falling
# back to 1.0 if it is exactly zero.
thermal_obj_scale = abs(float(therm_power)) or 1.0


_, ax = plt.subplots(1, 1)
ax = _viz.plot_grid_mesh(mesh=thermal_mesh, field=temp_elem, ax=ax, colorbar=True)
ax.set_title("Temperature")

In [None]:
@functools.partial(jax.jit, static_argnames=("flow_solver", "thermal_solver"))
def objective_function(
  mat_frac: jnp.ndarray,
  flow_solver: _fea_fluid.FluidSolver,
  thermal_solver: _fea_thermal.FEA,
  fluid_ramp_penalty: float,
  thermal_ramp_penalty: float,
  press_vel_guess=None,
  temp_guess=None,
  thermal_weight: float = 0.4,
):
  """Weighted heat-pickup / dissipated-power objective and its gradient.

  Solves the Brinkman-penalised flow problem, feeds the resulting element
  velocities into the convection-diffusion solve, and combines

    obj = -w * (thermal power / thermal_obj_scale)
          + (1 - w) * (dissipated power / fluid_obj_scale)

  where the thermal power is (outflow - inflow) of temperature * u-velocity over
  the inlet/outlet sampling strips, times specific heat and mass density.

  Reads the notebook globals `fluid_mesh`, `thermal_mesh`, `inlet_elems`,
  `outlet_elems`, `inv_permeability_ext`, `diffusivity_ext`, `fluid_mat`,
  `fluid_obj_scale` and `thermal_obj_scale`. They must exist before the first call
  and are read when the function is traced by jit.

  Args:
    mat_frac: Design variables (material fraction), one per element.
    flow_solver: Flow solver (static for jit).
    thermal_solver: Thermal solver (static for jit).
    fluid_ramp_penalty: RAMP penalty of the Brinkman (convex) interpolation.
    thermal_ramp_penalty: RAMP penalty of the diffusivity (concave) interpolation.
    press_vel_guess: Initial guess for the flow DOFs. If None, zeros with the
      Dirichlet values of `flow_solver.bc` imposed.
    temp_guess: Initial guess for the temperature DOFs. If None, zeros with the
      Dirichlet values of `thermal_solver.bc` imposed.
    thermal_weight: Weight w of the thermal term (0.4 matches the former
      hard-coded value).

  Returns:
    (obj, d_obj, press_vel, temp, mat_frac): the objective, its gradient w.r.t.
    `mat_frac` reshaped to (-1, 1), the flow and temperature solutions, and the
    design used.
  """
  # Velocity DOFs per element, computed here rather than read from a global.
  num_vel_dofs = fluid_mesh.num_dim * fluid_mesh.elem_template.num_nodes

  # Without a guess, start from zeros with the boundary values imposed.
  if press_vel_guess is None:
    press_vel_guess = (
      jnp.zeros((flow_solver.mesh.num_dofs,))
      .at[flow_solver.bc["fixed_dofs"]]
      .set(flow_solver.bc["dirichlet_values"])
    )
  if temp_guess is None:
    temp_guess = (
      jnp.zeros((thermal_solver.mesh.num_dofs,))
      .at[thermal_solver.bc["fixed_dofs"]]
      .set(thermal_solver.bc["dirichlet_values"])
    )

  def objective_wrapper(design):
    # Density filtering (dens_filter + threshold) is currently disabled.

    # Solve the flow problem with the same interpolation API as the notebook cells above.
    brinkman_penalty = _mat.compute_ramp_interpolation(
      prop=design,
      ramp_penalty=fluid_ramp_penalty,
      prop_ext=inv_permeability_ext,
      mode="convex",
    )
    press_vel = _solv.modified_newton_raphson_solve(
      flow_solver, press_vel_guess, brinkman_penalty
    )

    # Element velocities; the per-node flow DOFs are (p, u, v).
    dofs_per_node = fluid_mesh.nodes.dof_per_node
    elem_press_vel = press_vel[fluid_mesh.elem_dof_mat]  # (elems, num_dofs_per_elem)
    elem_u_vel = elem_press_vel[:, 1::dofs_per_node]
    elem_v_vel = elem_press_vel[:, 2::dofs_per_node]

    elem_vel = jnp.zeros((fluid_mesh.num_elems, num_vel_dofs))
    elem_vel = elem_vel.at[:, 0 :: fluid_mesh.num_dim].set(elem_u_vel)
    elem_vel = elem_vel.at[:, 1 :: fluid_mesh.num_dim].set(elem_v_vel)

    # Solve the thermal problem.
    eff_diffusivity = _mat.compute_ramp_interpolation(
      prop=design,
      ramp_penalty=thermal_ramp_penalty,
      prop_ext=diffusivity_ext,
      mode="concave",
    )
    temp = _solv.modified_newton_raphson_solve(
      thermal_solver, temp_guess, elem_vel, eff_diffusivity
    )

    # Dissipated power of the flow.
    dissip_pow = jnp.sum(
      jax.vmap(flow_solver.compute_elem_dissipated_power)(
        brinkman_penalty,
        elem_press_vel,
        fluid_mesh.elem_node_coords,
      )
    )

    # Heat picked up between the inlet and outlet strips.
    temp_elem = temp[thermal_mesh.elem_dof_mat]
    inlet_temp = jnp.mean(temp_elem[inlet_elems], axis=1)
    outlet_temp = jnp.mean(temp_elem[outlet_elems], axis=1)
    inlet_vel = jnp.mean(elem_u_vel[inlet_elems], axis=1)
    outlet_vel = jnp.mean(elem_u_vel[outlet_elems], axis=1)

    therm_power = (
      jnp.sum(outlet_temp * outlet_vel - inlet_temp * inlet_vel)
      * fluid_mat.specific_heat
      * fluid_mat.mass_density
    )

    # Weighted objective: reward heat pickup, penalise dissipated power.
    scaled_therm_power = therm_power / thermal_obj_scale
    scaled_dissip_pow = dissip_pow / fluid_obj_scale
    obj = -(thermal_weight * scaled_therm_power) + (
      (1.0 - thermal_weight) * scaled_dissip_pow
    )
    jax.debug.print("scaled_therm_power: {x}", x=scaled_therm_power)
    jax.debug.print("scaled_dissip_pow: {x}", x=scaled_dissip_pow)

    return obj, (press_vel, temp, design)

  (obj, (press_vel, temp, mat_frac)), d_obj = jax.value_and_grad(
    objective_wrapper, has_aux=True
  )(mat_frac)
  return obj, d_obj.reshape((-1, 1)), press_vel, temp, mat_frac

In [None]:
def constraint_function(mat_frac: ArrayLike, max_vol_frac: float):
  """Volume-fraction constraint value and gradient in the form passed to MMA.

  Computes g = 1 - mean(mat_frac) / max_vol_frac.

  NOTE(review): g <= 0 exactly when mean(mat_frac) >= max_vol_frac. If `_mma`
  uses the usual g <= 0 convention, this is a *lower* bound on the material
  fraction (optimize_design passes `min_mat_frac` here). Confirm the convention.

  Returns:
    (g, dg): g packed as a length-1 array and the transpose of dg/d(mat_frac).
  """
  # Density filtering is currently disabled here as well.
  def _g(design):
    return 1.0 - jnp.mean(design) / max_vol_frac

  g_value, g_grad = jax.value_and_grad(_g)(mat_frac)
  return jnp.array([g_value]), g_grad.T

In [None]:
def optimize_design(
  fluid_fe: _fea_fluid.FluidSolver,
  thermal_fe: _fea_thermal.FEA,
  min_mat_frac: float,
  max_iter: int,
  move_limit: float = 1e-2,
  kkt_tol: float = 1e-5,
  step_tol: float = 1e-5,
  plot_interval: int = 5,
):
  """Run the MMA optimisation loop for the coupled flow/thermal problem.

  Starts from a uniform design at `min_mat_frac`. Each epoch does three things:
    - evaluates `objective_function`, warm-started from the previous
      flow/temperature fields with the BCs re-imposed;
    - evaluates the volume constraint;
    - takes one MMA step.

  The penalties follow continuation schedules:
    - Brinkman RAMP penalty: brink_inter_factor / c, with
      c = min(20, 1 + 19 * (epoch / 100)^2).
    - Thermal RAMP penalty: min(1, 0.01 + 0.02 * epoch).

  Args:
    fluid_fe: Flow solver.
    thermal_fe: Thermal solver; its mesh defines the number of design variables.
    min_mat_frac: Initial uniform design value and volume-constraint bound.
    max_iter: Maximum number of MMA iterations.
    move_limit: MMA move limit.
    kkt_tol: MMA KKT convergence tolerance.
    step_tol: MMA step convergence tolerance.
    plot_interval: Plot the design every this many epochs (and at epoch 1).

  Returns:
    (mma_state, history): the final MMA state and a dict whose "obj" and
    "vol_cons" lists hold the objective and constraint value of every epoch.
  """
  design_var = min_mat_frac * np.ones((thermal_fe.mesh.num_elems, 1))
  num_design_var = design_var.shape[0]
  mma_params = _mma.MMAParams(
    max_iter=max_iter,
    kkt_tol=kkt_tol,
    step_tol=step_tol,
    move_limit=move_limit,
    num_design_var=num_design_var,
    num_cons=1,
    lower_bound=np.zeros((num_design_var, 1)),
    upper_bound=np.ones((num_design_var, 1)),
  )
  mma_state = _mma.init_mma(design_var, mma_params)

  # Plot on the solver's own mesh rather than a notebook global.
  plot_mesh = fluid_fe.mesh

  def plotfun(x, status=""):
    """Show the design field `x` as an image on the (nelx, nely) grid."""
    plt.figure()
    plt.imshow(x.reshape((plot_mesh.nelx, plot_mesh.nely)).T, cmap=_cmap)
    plt.title(status)
    plt.colorbar()
    plt.show()

  def with_dirichlet(values, bc):
    """Return `values` with the Dirichlet values of `bc` set on its fixed DOFs."""
    return values.at[bc["fixed_dofs"]].set(bc["dirichlet_values"])

  history = {"obj": [], "vol_cons": []}

  press_vel = with_dirichlet(jnp.zeros((fluid_fe.mesh.num_dofs,)), fluid_fe.bc)
  temp = with_dirichlet(jnp.zeros((thermal_fe.mesh.num_dofs,)), thermal_fe.bc)

  while not mma_state.is_converged:
    cont_param = min(20.0, 1.0 + 19.0 * (mma_state.epoch / 100.0) ** 2)
    fluid_ramp_penalty = brink_inter_factor / cont_param
    thermal_ramp_penalty = min(1.0, 0.01 + mma_state.epoch * 0.02)

    objective, grad_obj, press_vel, temp, mat_frac = objective_function(
      mma_state.x,
      fluid_fe,
      thermal_fe,
      fluid_ramp_penalty,
      thermal_ramp_penalty,
      press_vel,
      temp,
    )
    # Warm-start the next epoch from these fields, re-imposing the BC values.
    press_vel = with_dirichlet(jax.lax.stop_gradient(press_vel), fluid_fe.bc)
    temp = with_dirichlet(jax.lax.stop_gradient(temp), thermal_fe.bc)

    constr, grad_cons = constraint_function(mma_state.x, min_mat_frac)
    mma_state = _mma.update_mma(
      mma_state, mma_params, objective, grad_obj, constr, grad_cons
    )

    history["obj"].append(float(objective))
    history["vol_cons"].append(float(constr[0]))

    status = f"epoch {mma_state.epoch} J {objective:.2E} mc {constr[0]:.2F}"
    print(status)
    if mma_state.epoch % plot_interval == 0 or mma_state.epoch == 1:
      plotfun(mat_frac, status)

  return mma_state, history

In [None]:
# Run the optimisation from a uniform design at the target material fraction.
# `u` receives the history dict returned by optimize_design.
mma_state, u = optimize_design(
  fluid_fe=flow_solver,
  thermal_fe=thermal_solver,
  min_mat_frac=desired_mat_frac,
  max_iter=100,
  move_limit=4.0e-1,
)

In [None]:
# Scratch: prototype of the SUPG thermal element residual.
# Index pairs built from thermal_mesh.iK / jK (presumably the sparse row/column
# indices of the assembled matrix). NOTE(review): not used below.
node_id_jac = np.stack((thermal_mesh.iK, thermal_mesh.jK), axis=1).astype(np.int32)
# Shape functions at every Gauss point, shape (num_gauss, nodes_per_elem).
shp_fn = jax.vmap(thermal_mesh.elem_template.shape_functions)(thermal_mesh.gauss_pts)
def _compute_elem_stabilization(
  velocity: jnp.ndarray,
  diffusivity: jnp.ndarray,
  elem_char_length: jnp.ndarray,
) -> jnp.ndarray:
  """Returns the SUPG stabilization parameter τ of one element.

  τ blends two limiting cases with an approximate minimum function:

    τ = ( τ₁⁻² + τ₃⁻² )^(-1/2)

  where
    τ₁ = h / ( 2 |u_c| )    # convective limit
    τ₃ = h² / (12 α)        # diffusive limit
  The transient limit τ₂ is ignored.

  Variables:
    - h: element characteristic length
    - α: thermal diffusivity
    - u_c: velocity interpolated to the parametric origin of the element
      (the centroid for a [-1, 1]^d reference element),
      u_c = Σ_n N_n(0) u_n
  τ is taken constant over the element.

  For details see [Alexandersen 2023 SMO], Appendix B, eq 49
  https://link.springer.com/article/10.1007/s00158-022-03420-9

  Args:
    velocity: Array of shape (num_velocity_dofs_per_elem,) with the nodal
      velocities ordered [u₁, v₁, (w₁,) u₂, v₂, (w₂,) ...].
    diffusivity: Scalar (or size-1) array with the thermal diffusivity.
    elem_char_length: Scalar characteristic length h of the element.

  Returns: The stabilization parameter τ.
  """
  # Parametric origin of the reference element.
  gp_center = jnp.zeros((thermal_mesh.num_dim,))
  shp_center = thermal_mesh.elem_template.shape_functions(gp_center)  # (n,)

  # Velocity at the centre: u_c[d] = Σ_n N_n u_n[d]. Summing over the nodes
  # *before* squaring gives the true squared speed |u_c|²; squaring each nodal
  # term separately would drop all cross terms.
  nodal_vel = velocity.reshape(-1, thermal_mesh.num_dim)  # (n, d)
  u_center = jnp.einsum("n, nd -> d", shp_center, nodal_vel)
  speed_sq = jnp.dot(u_center, u_center)

  inv_sq_tau1 = (4 * speed_sq) / elem_char_length**2  # τ₁⁻² = 4|u_c|² / h²
  tau_3 = (elem_char_length**2) / (12 * diffusivity)

  return (inv_sq_tau1 + tau_3 ** (-2)) ** (-1 / 2)

def _compute_elem_residual(
  temperature: jnp.ndarray,
  velocity: jnp.ndarray,
  diffusivity: jnp.ndarray,
  node_coords: jnp.ndarray,
  elem_char_length: float,
) -> jnp.ndarray:
  """Computes the element residual vector of the SUPG-stabilised energy equation.

  The weak form of the steady-state energy equation (dimensional form) with SUPG
  stabilisation can be written as:

      ∫_Ω_e  w * u_j * ∂T/∂x_j dΩ_e                   (convection)
    + ∫_Ω_e  (∂w/∂x_j) * α * ∂T/∂x_j  dΩ_e            (diffusion)
    + ∫_Ω_e  τ_T * u_j * ∂w/∂x_j * R_T(u, T) dΩ_e    (SUPG)
    = 0

  where:
        Ω_e      : element domain
        u_j      : velocity component in direction x_j
        T        : temperature field
        w        : weight / test function
        α        : thermal diffusivity
        τ_T      : SUPG stabilisation parameter
        R_T(u,T) : strong-form residual of the energy equation

  with
        R_T = u_j * (∂T/∂x_j)  - α* (∂²T/∂x_j∂x_j)

  Only the convective part u_j ∂T/∂x_j of R_T is used in the SUPG term below; the
  second-derivative term is not evaluated. NOTE(review): this is exact only where
  the shape functions' Laplacian vanishes (e.g. linear triangles, or axis-aligned
  rectangular bilinear elements). Confirm this is acceptable for the meshes used.

  For more details
    1. see eq (2d) in:
      Subramaniam, V., et al. "Topology optimization of conjugate heat transfer
      systems: A competition between heat transfer enhancement and pressure drop
      reduction." International Journal of Heat and Fluid Flow 75 (2019): 165-184.
    2. eq (A.44) in:
      Alexandersen, Joe "Topology optimisation for coupled convection problems" (2013)

  NOTE: This implementation assumes there is no externally applied surface heat
    flux or volumetric heat source. This simplification is valid only for the
    optimization problems considered herein. For problems with heat generation,
    such as heat sinks, these terms need to be added to the residual.

  Args:
    temperature: Array of (num_dofs_per_elem,) containing the temperature of the
      nodes of an element.
    velocity: Array of (num_nodes_per_elem * num_dim,) containing the velocity at
      the nodes of an element, ordered (u1, v1, w1, u2, v2, w2, ...). The velocity
      drives the convective heat transfer.
    diffusivity: Scalar value of the element's thermal diffusivity.
    node_coords: Array of (num_nodes_per_elem, num_dims) containing the coordinates
      of the nodes of an element.
    elem_char_length: Scalar characteristic length of the element.

  Returns: Array of (num_dofs_per_elem,) containing the residual of the element,
    ordered like the nodal temperatures (t1, t2, t3, ...).
  """
  template = thermal_mesh.elem_template
  gauss_pts = thermal_mesh.gauss_pts

  # (d)(i)m, (g)auss, (n)odes_per_elem = (t)emp_dofs_per_elem
  # Shape functions are evaluated here rather than read from a notebook global,
  # so the result does not depend on cell execution order.
  shp_fn = jax.vmap(template.shape_functions)(gauss_pts)  # (g,n)
  grad_shp_fn = jax.vmap(
    template.get_gradient_shape_function_physical, in_axes=(0, None)
  )(gauss_pts, node_coords)  # (g,n,d)
  _, det_jac = jax.vmap(
    template.compute_jacobian_and_determinant, in_axes=(0, None)
  )(gauss_pts, node_coords)  # (g,)

  stab_param = _compute_elem_stabilization(velocity, diffusivity, elem_char_length)

  vel_gauss = jnp.einsum(
    "gn, nd -> gd", shp_fn, velocity.reshape(-1, thermal_mesh.num_dim)
  )
  dtemp_xy = jnp.einsum("gnd, n -> gd", grad_shp_fn, temperature)

  # Galerkin convection and diffusion terms.
  res_conv = jnp.einsum("gn, gd, gd -> gn", shp_fn, vel_gauss, dtemp_xy)
  res_diff = diffusivity * jnp.einsum("gnd, gd -> gn", grad_shp_fn, dtemp_xy)

  # SUPG term: τ (u·∇w) (u·∇T).
  res_strong_form = jnp.einsum("gd, gd -> g", vel_gauss, dtemp_xy)
  u_dot_grad_w = jnp.einsum("gd, gnd -> gn", vel_gauss, grad_shp_fn)
  res_conv_supg = stab_param * jnp.einsum("gn, g-> gn", u_dot_grad_w, res_strong_form)

  net_res = res_diff + res_conv + res_conv_supg
  return jnp.einsum("gn, g, g -> n", net_res, thermal_mesh.gauss_weights, det_jac)

# Smoke test of the element residual on a unit square with arbitrary nodal values.
# NOTE(review): the nodes below are listed (0,0), (1,0), (0,1), (1,1), which is not
# counter-clockwise. If the element template expects CCW node order, this element
# is self-intersecting -- confirm the template's node ordering.
temperature = jnp.array([2.5, 3.5, 6.5, 7.5])
velocity = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])  # (u, v) per node
diffusivity = jnp.array([1.0])
node_coords = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
elem_char_length = 5.0

_compute_elem_residual(
  temperature=temperature,
  velocity=velocity,
  diffusivity=diffusivity,
  node_coords=node_coords,
  elem_char_length=elem_char_length,
)