<a href="https://colab.research.google.com/github/JA4S/JANC/blob/main/examples/janc_basic_example4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install JANC and import relevant libraries

In [None]:
# Copyright © 2025 Haocheng Wen, Faxuan Luo
# SPDX-License-Identifier: MIT

!pip install git+https://github.com/JA4S/JANC.git

In [None]:
from janc.thermodynamics import thermo
from janc.solver import solver
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# JAX configuration: run in double precision (JAX defaults to float32) and make
# the GPU the default backend. This assumes a GPU runtime (e.g. a Colab GPU instance).
# NOTE(review): x64 is enabled after `janc` is imported above; if janc creates
# jnp arrays at import time they would stay float32 -- confirm.
jax.config.update('jax_enable_x64', True)
jax.config.update("jax_platform_name", "gpu")

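# Example: 1D Shock-Tube Test

A Riemann problem: a high-pressure, high-density gas fills the left half of the domain and a low-pressure, low-density gas fills the right half, both initially at rest. The boundaries are periodic in y, so the solution is effectively one-dimensional.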

In [None]:
# Domain size and grid resolution. The initial state is identical along y and the
# y-boundaries are periodic, so only a handful of cells are used in y.
Lx = 20.0
Ly = 1.0

nx = 400
ny = 5

dx = Lx/nx
dy = Ly/ny

# Fixed time step from a CFL-type condition: dt = CFL * dx / MAX_WAVE_SPEED.
# NOTE(review): MAX_WAVE_SPEED appears to be a conservative upper bound on the
# signal speed |u| + c for this problem (SI units, m/s, presumably) -- confirm
# it still bounds the wave speeds if the initial states are changed.
CFL = 0.2
MAX_WAVE_SPEED = 800
dt = CFL*dx/MAX_WAVE_SPEED

# Two-species (O2/N2) gas with a constant ratio of specific heats; detailed
# chemistry is disabled.
thermo_config = {'is_detailed_chemistry':False,
          'species':['O2','N2'],
          'thermo_model':'constant_gamma',
          'gamma':1.4}

# 'neumann' (presumably zero-gradient) boundaries in x, periodic in y.
boundary_config = {'left_boundary':'neumann',
           'right_boundary':'neumann',
           'bottom_boundary':'periodic',
           'up_boundary':'periodic'}

# No user-defined source terms.
source_config = {'self_defined_source_terms':None}

# set_solver returns a time-step function (used below) and a second value that
# this notebook does not need.
advance_one_step, _ = solver.set_solver(thermo_config,boundary_config,source_config)

In [None]:
def initial_conditions(UL, UR, gamma=None):
    """Build the initial conservative state and auxiliary fields for the shock tube.

    Cells with x-index < ``nx // 2`` take the left state ``UL``; the remaining
    cells take the right state ``UR``. Every y-column is identical. Reads the
    notebook globals ``nx``, ``ny`` and ``thermo_config``.

    Args:
        UL, UR: ``(rho, rhou, p)`` tuples giving the density, x-momentum and
            pressure of the left and right states.
        gamma: ratio of specific heats. Defaults to ``thermo_config['gamma']`` so
            the initial energy is consistent with the gas model given to the solver.

    Returns:
        ``(U_init, aux_init)``: ``U_init`` stacks ``[rho, rhou, rhov, E, rhoY]``
        along axis 0 (each an ``(nx, ny)`` field), where ``E`` is the total
        energy per unit volume. ``aux_init`` stacks ``[gamma, T]`` along axis 0.
    """
    if gamma is None:
        gamma = thermo_config['gamma']
    rhoL, rhouL, pL = UL
    rhoR, rhouR, pR = UR
    # Integer split index. nx // 2 behaves consistently for odd nx, whereas
    # round(nx/2) rounds half to even.
    mid = nx // 2

    # Total energy = internal p/(gamma-1) + kinetic 0.5*rhou^2/rho. The kinetic
    # term is zero for states at rest but required for moving states.
    rhoE_L = pL/(gamma - 1) + 0.5*rhouL**2/rhoL
    rhoE_R = pR/(gamma - 1) + 0.5*rhouR**2/rhoR

    def _two_state(left_value, right_value):
        # (nx, ny) field equal to left_value for x-index < mid, right_value otherwise.
        return (left_value*jnp.ones((nx, ny))).at[mid:, :].set(right_value)

    rho_init = _two_state(rhoL, rhoR)
    rhou_init = _two_state(rhouL, rhouR)
    rhov_init = jnp.zeros((nx, ny))
    E_init = _two_state(rhoE_L, rhoE_R)

    # Species mass fractions, shape (n, 1, 1) so they broadcast over the grid.
    # NOTE(review): a single value is given although thermo_config lists two
    # species ('O2', 'N2'); presumably this is the O2 mass fraction of air and
    # N2 is the remainder -- confirm against JANC's state layout.
    Y = jnp.array([0.232])[:, None, None]
    rhoY_init = rho_init[None, :, :]*Y

    U_init = jnp.concatenate(
        [rho_init[None, :, :], rhou_init[None, :, :], rhov_init[None, :, :],
         E_init[None, :, :], rhoY_init],
        axis=0)

    # Temperature from T = p / (rho * R), evaluated separately for each half.
    # NOTE(review): assumes thermo.get_R returns an array shaped (1, nx, ny).
    R = thermo.get_R(jnp.tile(Y, (1, nx, ny)))
    T = pL/(rhoL*R)
    T = T.at[:, mid:, :].set(pR/(rhoR*R[:, mid:, :]))
    gamma_field = jnp.full_like(T, gamma)
    aux_init = jnp.concatenate([gamma_field, T], axis=0)
    return U_init, aux_init

# Left/right states as (rho, rhou, p): a dense, high-pressure gas on the left and
# a light, low-pressure gas on the right, both at rest (zero momentum).
UL = (1, 0, 1e5)
UR = (0.125, 0, 1e4)
U, aux = initial_conditions(UL, UR)

# The solver advances one stacked array: conservative variables, then aux fields.
field = jnp.concatenate((U, aux))

In [None]:
# Step count scales with resolution (200 steps per 100 cells). Because dt is
# proportional to dx = Lx/nx, the simulated end time nt*dt stays roughly the
# same when nx changes.
nt = 200*round(nx/100)
for step in tqdm(range(nt), desc="progress", unit="step"):
    field = advance_one_step(field, dx, dy, dt)

In [None]:
# Split the stacked array back into conservative variables and the two aux fields,
# then take the first y-row (the initial state is identical along y).
U, aux = field[:-2], field[-2:]
rho = U[0, :, 0]
u = U[1, :, 0]/rho
v = U[2, :, 0]/rho
# U[3] is total energy per unit volume: subtract the full kinetic energy and scale
# by (gamma - 1), using the same gamma that was given to the solver.
p = (U[3, :, 0] - 0.5*rho*(u**2 + v**2))*(thermo_config['gamma'] - 1)
# Cell-centre coordinates on [-Lx/2, Lx/2]; for even nx the initial
# discontinuity sits at x = 0.
x = (jnp.arange(nx) + 0.5)*dx - Lx/2

for values, label in ((rho, 'rho'), (u, 'u'), (p, 'p')):
    fig, ax = plt.subplots()
    ax.plot(x, values, '-o', markersize=4)
    ax.set(xlabel='x', ylabel=label, title=f'Shock tube: {label} at final time')
    plt.show()