# Title

**TODO:** Add a description of the notebook and a download link, and state whether the vertical or the horizontal method of lines is used.

In [3]:
## example 17: navier stokes with heat conduction and convection and temperature-dependent viscosity
# FIXME: Is the pressure correct? It shows strong oscillations, but the simulation remains stable...

# Tutorial Example: Navier–Stokes with Heat Conduction, Convection, and Temperature-Dependent Viscosity

In [4]:
import os

import jax
import jax.numpy as jnp
import meshio
from flax.core import FrozenDict

from autopdex import utility, spaces, seeder, models, dae

# Switch JAX to 64-bit precision; the Dirichlet values below are created as float64.
jax.config.update("jax_enable_x64", True)


SyntaxError: (unicode error) 'unicodeescape' codec can't decode bytes in position 1550-1551: malformed \N character escape (dae.py, line 25)

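Load the meshes that were generated with Gmsh: a quadratic (6-node) triangle mesh for velocity and temperature, and a linear (3-node) triangle mesh for pressure.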

In [None]:
# Load meshes generated with gmsh.
# The mesh directory defaults to the original author's location but can be overridden
# with the AUTOPDEX_MESH_DIR environment variable, so the notebook also runs on
# machines where the absolute Windows path does not exist.
mesh_dir = os.environ.get("AUTOPDEX_MESH_DIR", r"C:\JAX\autopdex_related\autopdex\examples\meshes")

# Velocity mesh: 6-node triangles; only the first two coordinate columns (x, y) are kept.
mesh = meshio.read(os.path.join(mesh_dir, "navier_stokes_mesh_v.msh"))
coords_v = jnp.asarray(mesh.points[:, :2])
cells_v = jnp.asarray(mesh.cells_dict["triangle6"])

# Pressure mesh: 3-node triangles; only the first two coordinate columns (x, y) are kept.
mesh = meshio.read(os.path.join(mesh_dir, "navier_stokes_mesh_p.msh"))
coords_p = jnp.asarray(mesh.points[:, :2])
cells_p = jnp.asarray(mesh.cells_dict["triangle"])

# The temperature field reuses the nodes and connectivity of the velocity mesh.
coords_T = coords_v
cells_T = cells_v

# Node coordinates of all three fields (the temperature entry uses the velocity nodes)
node_coordinates = {
  '1velocity': coords_v,
  '2pressure': coords_p,
  '3temperature': coords_T,
}






## Boundary conditions:
 - Parabolic inflow at x = 0,
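 - p = 0 is prescribed at the single outflow corner node (x, y) = (2.2, 0); the remaining pressure nodes on the outflow boundary x = L = 2.2 are not constrained,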
 - No-slip on the cylinder and top/bottom walls.

 - Inflow (x = 0): T = 0 (Dirichlet)
 - Cylinder: T ramps from 0 to 20 over the first 3.5 seconds (Dirichlet)
 - Top and bottom walls: insulated (Neumann, natural condition)

In [None]:
def on_left(x):
  """Return a boolean mask marking the rows of `x` whose first coordinate is close to 0 (inflow edge)."""
  x_coordinates = x[:, 0]
  return jnp.isclose(x_coordinates, 0.0)

# Velocity dof selection on the inflow edge: nodes picked by on_left, with the
# per-component mask [True, True] (i.e. both velocity components).
selection_v_left = utility.dof_select(on_left(coords_v), jnp.asarray([True, True]))

def on_top_bottom_circle(x):
  """Return a boolean mask of the no-slip points for the velocity.

  A row of `x` is selected if it lies on the bottom wall (y = 0), on the top
  wall (y = 0.41), or on the circle of radius 0.05 centred at (0.2, 0.2).
  """
  y = x[:, 1]
  radius_squared = (x[:, 0] - 0.2)**2 + (y - 0.2)**2

  on_bottom = jnp.isclose(y, 0.0)
  on_top = jnp.isclose(y, 0.41)
  on_circle = jnp.isclose(radius_squared, 0.05**2)
  return on_bottom | on_top | on_circle

# Velocity dof selection on the walls and the cylinder: nodes picked by
# on_top_bottom_circle, with the per-component mask [True, True].
selection_v_remaining = utility.dof_select(on_top_bottom_circle(coords_v), jnp.asarray([True, True]))

def on_right(x):
  """Return a boolean mask that is True only for rows of `x` close to the point (2.2, 0).

  Note that this selects the single corner point, not the whole edge x = 2.2.
  """
  corner_point = jnp.array([2.2, 0.])
  return jnp.all(jnp.isclose(x, corner_point), axis=1)

# Pressure dof selection: nodes picked by on_right (the corner point (2.2, 0)),
# with the one-component mask [True].
selection_p = utility.dof_select(on_right(coords_p), jnp.asarray([True]))

def on_left_temp(x):
  """Return a boolean mask of the rows of `x` on the inflow edge (first coordinate close to 0)."""
  inflow_x = 0.0
  return jnp.isclose(x[:, 0], inflow_x)

# Temperature dof selection on the inflow edge: nodes picked by on_left_temp,
# with the one-component mask [True].
selection_temp_left = utility.dof_select(on_left_temp(coords_T), jnp.asarray([True]))

def on_cylinder(x):
  """Return a boolean mask of the rows of `x` on the circle of radius 0.05 centred at (0.2, 0.2).

  The comparison uses the default tolerances of jnp.isclose on the squared radius.
  """
  dx = x[:, 0] - 0.2
  dy = x[:, 1] - 0.2
  return jnp.isclose(dx**2 + dy**2, 0.05**2)

# Temperature dof selection on the cylinder: nodes picked by on_cylinder,
# with the one-component mask [True].
selection_temp_cylinder = utility.dof_select(on_cylinder(coords_T), jnp.asarray([True]))

# Temperature Dirichlet dofs: inflow edge together with the cylinder surface
selection_temp = selection_temp_left + selection_temp_cylinder

# Dirichlet dof selections of all fields, keyed by field name
dirichlet_dofs = {
  '1velocity': selection_v_left + selection_v_remaining,
  '2pressure': selection_p,
  '3temperature': selection_temp,
}

# Initial Dirichlet values: float64 zeros with the same structure as dirichlet_dofs.
# The velocity and temperature entries are overwritten in pre_step_updates.
dirichlet_conditions = utility.dict_zeros_like(dirichlet_dofs, dtype=jnp.float64)


In [None]:

# Time-dependent Dirichlet condition for the velocity inflow
def get_dirichlet_conditions(t):
  """Return the prescribed velocity values at time `t` for all velocity nodes.

  The x-velocity follows a parabolic profile over the channel height 0.41 with
  peak value 1.5, scaled by sin(pi / 7 * t) while t <= 3.5 and held constant
  afterwards. The y-velocity is zero. Entries not selected by
  `selection_v_left` are set to zero.
  """
  peak_velocity = 1.5
  y = coords_v[:, 1]
  profile = 4 * peak_velocity * y * (0.41 - y) / 0.41**2

  # 3.5 seconds sinusoidal ramp-up, then constant
  ramp = jnp.where(t <= 3.5, jnp.sin(jnp.pi / 7 * t), 1.0)

  inflow_values = jnp.stack([profile * ramp, jnp.zeros_like(profile)], axis=1)
  return jnp.where(selection_v_left, inflow_values, 0.0)

def pre_step_updates(t, settings):
  """Update `settings` for time `t` and return it.

  Writes the current time, the velocity Dirichlet values from
  `get_dirichlet_conditions(t)` and the temperature Dirichlet values into
  `settings` (the nested 'dirichlet conditions' dict is modified in place).
  The cylinder temperature follows 20 * sin(pi / 7 * t) while t <= 3.5 and is
  20 afterwards; all other temperature entries are zero.
  """
  settings['current time'] = t
  boundary_values = settings['dirichlet conditions']

  # Velocity boundary values at time t
  boundary_values['1velocity'] = get_dirichlet_conditions(t)

  # Temperature boundary values at time t: zero everywhere except on the cylinder
  ramped_temperature = 20 * jnp.sin(jnp.pi / 7 * t)
  cylinder_temperature = jnp.where(t <= 3.5, ramped_temperature, 20.0)
  temperature_values = jnp.zeros_like(dirichlet_dofs['3temperature'], dtype=jnp.float64)
  temperature_values = temperature_values.at[selection_temp_cylinder].set(cylinder_temperature)
  boundary_values['3temperature'] = temperature_values

  return settings


### Weak formulation: Coupled equations (Navier–Stokes and Energy)


In [None]:
def weak_form_integrand(x_int, trial_ansatz, test_ansatz, settings, static_settings, elem_number, set):
  """
  Integrand of the coupled weak form, evaluated at one integration point.

  Navier–Stokes:
    ∂ₜ u + (u · ∇) u - ν Δ u + ∇ p = f,   and   div(u) = 0

  Energy (temperature) equation:
    ∂ₜ T + u · ∇ T - α Δ T = 0

  The fields are coupled through the convection term (u · ∇ T) and the
  temperature-dependent viscosity ν = ν₀ (1 + β T).

  Args:
    x_int: Coordinates of the integration point.
    trial_ansatz: Dict of trial functions per field, each called as f(x, t).
    test_ansatz: Dict of test functions per field, each called as f(x).
    settings: Dict; only the entry 'current time' is read.
    static_settings: Not used in this integrand.
    elem_number: Not used in this integrand.
    set: Not used in this integrand.

  Returns:
    The scalar sum of the momentum, continuity and energy contributions.
  """
  # Material parameters
  base_viscosity = 0.001       # ν₀
  viscosity_coefficient = 0.05  # β
  thermal_diffusivity = 0.001  # α

  time = settings['current time']

  velocity_trial = trial_ansatz['1velocity']
  temperature_trial = trial_ansatz['3temperature']
  velocity_test = test_ansatz['1velocity']
  temperature_test = test_ansatz['3temperature']

  # Trial fields with their spatial gradients and time derivatives
  u = velocity_trial(x_int, time)
  grad_u, u_rate = jax.jacfwd(velocity_trial, (0, 1))(x_int, time)
  u_rate = u_rate.flatten()

  pressure = trial_ansatz['2pressure'](x_int, time)[0]

  temperature = temperature_trial(x_int, time)[0]
  grad_temperature, temperature_rate = jax.jacfwd(temperature_trial, (0, 1))(x_int, time)
  grad_temperature = grad_temperature.flatten()
  temperature_rate = temperature_rate[0]

  # Test functions and their spatial gradients
  v = velocity_test(x_int)
  grad_v = jax.jacfwd(velocity_test)(x_int)
  q = test_ansatz['2pressure'](x_int)[0]
  phi = temperature_test(x_int)[0]
  grad_phi = jax.jacfwd(temperature_test)(x_int).flatten()

  # Divergences of the trial and test velocity (two space dimensions)
  div_u = grad_u[0, 0] + grad_u[1, 1]
  div_v = grad_v[0, 0] + grad_v[1, 1]

  # Convective acceleration (u · ∇) u, i.e. component i is sum_j u_j ∂u_i/∂x_j
  convection = grad_u @ u

  # Temperature-dependent viscosity
  viscosity = base_viscosity * (1 + viscosity_coefficient * temperature)

  # Momentum balance
  inertia_term = u_rate @ v
  convection_term = convection @ v
  viscous_term = viscosity * jnp.einsum('ij,ij->', grad_u, grad_v)
  pressure_term = pressure * div_v
  momentum = inertia_term + convection_term + viscous_term - pressure_term

  # Incompressibility constraint
  continuity = -q * div_u

  # Energy balance
  storage_term = temperature_rate * phi
  advection_term = (u @ grad_temperature) * phi
  diffusion_term = thermal_diffusivity * jnp.dot(grad_temperature, grad_phi)
  energy = storage_term + advection_term + diffusion_term

  return momentum + continuity + energy


### Set up the element and the time-stepping procedure

In [None]:
# Ansatz functions for the fields (the same function is used for every field;
# the element order follows from the connectivity of the respective mesh)
ansatz_fun = {
  '1velocity': spaces.fem_iso_line_tri_tet,  # quadratic (6 nodes per triangle)
  '2pressure': spaces.fem_iso_line_tri_tet,  # linear (3 nodes per triangle)
  '3temperature': spaces.fem_iso_line_tri_tet,  # shares the velocity mesh (cells_T = cells_v)
}

# Numerical integration: points and weights on the reference triangle
ref_int_coor, ref_int_weights = seeder.int_pts_ref_tri(order=4)

# Set up the 'user residual' based on a weak-in-space formulation of the PDEs.
# The velocity field is passed as the mapping key.
user_residual = models.mixed_reference_domain_residual_time(
    weak_form_integrand,
    ansatz_fun,
    ref_int_coor,
    ref_int_weights,
    mapping_key='1velocity',
)

# The settings are analogous to those of static analyses. Additionally, the 'current time' has to be initialized.
settings = {
  'connectivity': ({
    '1velocity': cells_v,
    '2pressure': cells_p,
    '3temperature': cells_T,
  },),
  'node coordinates': node_coordinates,
  'dirichlet dofs': dirichlet_dofs,
  'dirichlet conditions': dirichlet_conditions,
  'current time': 0.0,
}

# The static settings are also passed to the PDE functions.
# When used through the time stepping manager, the 'dae' and 'time integrators' entries have to be specified as well.
static_settings = FrozenDict({
  'assembling mode': ('user residual',),
  'solution structure': ('nodal imposition',),
  'model': (user_residual,),
  'solver type': 'newton',
  'solver backend': 'pardiso',
  'solver': 'lu',
  'verbose': 2,
  'dae': 'call pde',    # Here we need the keyword 'call pde' in order to call the assembling routines for the PDE specified by the 'model'.
  'time integrators': {
    '1velocity': dae.BackwardDiffFormula(2),
    '2pressure': dae.BackwardEuler(),     # The time derivative is not used, but an integrator has to be specified
                                          #   that is compatible with the stages of the other fields
    '3temperature': dae.AdamsMoulton(1),
  },
})

# Define the postprocessing policy (save all steps) and pass the function for pre-step updates
manager = dae.TimeSteppingManager(static_settings,
                              save_policy=dae.SaveAllPolicy(),
                              pre_step_updates=pre_step_updates)

# Start with zero velocity, pressure and temperature (one row per node of the respective mesh)
initial_conditions = {
  '1velocity': jnp.zeros((coords_v.shape[0], 2)),
  '2pressure': jnp.zeros((coords_p.shape[0], 1)),
  '3temperature': jnp.zeros((coords_T.shape[0], 1)),
}

### Run the time stepping loop

In [None]:
# Final time and number of time steps; the step size passed to the manager is
# t_final / num_time_steps = 0.02.
t_final = 10
num_time_steps = 500
# NOTE(review): the argument order (initial conditions, step size, final time, number of steps, settings)
# is presumed from the names used here - confirm against the autopdex TimeSteppingManager.run signature.
result = manager.run(initial_conditions, t_final / num_time_steps, t_final, num_time_steps, settings)

Iteration 0, Residual norm: 0.017871636935677313
Iteration 1, Residual norm: 2.7893455422508244e-05
Iteration 2, Residual norm: 2.958197350170137e-09
 
Iteration 0, Residual norm: 0.0206150592227377
Iteration 1, Residual norm: 2.966663907177073e-05
Iteration 2, Residual norm: 3.326061805516559e-09
 
Iteration 0, Residual norm: 0.019308199656803162
Iteration 1, Residual norm: 3.102969915275309e-05
Iteration 2, Residual norm: 3.591165225103608e-09
 
Iteration 0, Residual norm: 0.019598169197754815
Iteration 1, Residual norm: 3.2099560389401873e-05
Iteration 2, Residual norm: 3.632903549363982e-09
 
Iteration 0, Residual norm: 0.019287899062661346
Iteration 1, Residual norm: 3.31972707202325e-05
Iteration 2, Residual norm: 3.6187549147290605e-09
 
Iteration 0, Residual norm: 0.019522235727412112
Iteration 1, Residual norm: 3.4250630441905934e-05
Iteration 2, Residual norm: 3.697037020571628e-09
 
Iteration 0, Residual norm: 0.019217692946893274
Iteration 1, Residual norm: 3.53379305296537

### Postprocessing

In [None]:
output_dir = "./output_navier_stokes"
os.makedirs(output_dir, exist_ok=True)

# Saved time points and the per-field solution histories (indexed by time step)
history = result.history
timesteps = history.t
velocity_history = history.q['1velocity']
pressure_history = history.q['2pressure']
temperature_history = history.q['3temperature']

# The meshes do not change over time, so the 2D node coordinates are padded with a
# zero z-column once here instead of once per time step inside the loop.
points_v = jnp.pad(coords_v, ((0, 0), (0, 1)), constant_values=0)
points_p = jnp.pad(coords_p, ((0, 0), (0, 1)), constant_values=0)

# Export two VTK files for each time step: velocity/temperature on the quadratic
# mesh and pressure on the linear mesh
for i, t_val in enumerate(timesteps):
  # Export velocity (padded with a zero z-component) and temperature
  point_data = {
      'velocity': jnp.pad(velocity_history[i], ((0, 0), (0, 1)), constant_values=0),
      'temperature': temperature_history[i],
  }
  meshio.Mesh(
      points=points_v,
      cells={
          'triangle6': cells_v
      },
      point_data=point_data,
  ).write(f"{output_dir}/navier_stokes_t{i}.vtk")

  # Export pressure
  point_data = {
      'pressure': pressure_history[i],
  }
  meshio.Mesh(
      points=points_p,
      cells={
          'triangle': cells_p
      },
      point_data=point_data,
  ).write(f"{output_dir}/navier_stokes_p_t{i}.vtk")
