<a href="https://colab.research.google.com/github/JA4S/JAX-AMR/blob/main/examples/jax_amr_basic_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install JAX-AMR and import relevant libraries

In [None]:
# Copyright © 2025 Haocheng Wen, Faxuan Luo
# SPDX-License-Identifier: MIT

!pip install git+https://github.com/JA4S/JAX-AMR.git
!wget https://raw.githubusercontent.com/JA4S/JAX-AMR/main/examples/simple_solver.py

In [None]:
from jaxamr import amr, amraux
import simple_solver as solver

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Global JAX settings for this notebook, applied in order:
#   - 64-bit floating point by default,
#   - run on the CPU platform.
_JAX_SETTINGS = {
    "jax_enable_x64": True,
    "jax_platform_name": "cpu",
}
for _setting, _value in _JAX_SETTINGS.items():
    jax.config.update(_setting, _value)

# Set computation parameters

In [None]:
# Domain extent along x and y.
Lx, Ly = 1.0, 1.0

# Cell counts of the level-0 (base) grid along x and y.
nx, ny = 200, 200

# Level-0 cell spacing.
dx, dy = Lx / nx, Ly / ny

base_grid = {"Lx": Lx, "Ly": Ly, "Nx": nx, "Ny": ny}

# Block layout of each refinement level as [x-direction, y-direction];
# row k belongs to level k.
n_block = [
    [1, 1],    # Level 0
    [20, 20],  # Level 1
    [2, 2],    # Level 2
    [2, 2],    # Level 3
    [2, 2],    # Level 4
]

# Passed straight through to jaxamr; their meaning is defined by that library.
template_node_num = 1
buffer_num = 2

# Per-criterion refinement thresholds handed to jaxamr.
refinement_tolerance = {"density": 5.0, "velocity": 0.5}

# Everything above, bundled into the configuration dict consumed by amr.set_amr.
amr_config = {
    "base_grid": base_grid,
    "n_block": n_block,
    "template_node_num": template_node_num,
    "buffer_num": buffer_num,
    "refinement_tolerance": refinement_tolerance,
}

amr.set_amr(amr_config)

# Cell spacing for every level listed in n_block: level 0 uses the base-grid
# spacing Lx/nx (resp. Ly/ny), and each further level halves its parent's
# spacing, i.e. dx[k] = Lx / nx / 2**k.
dx = [Lx / nx / 2.0**level for level in range(len(n_block))]
dy = [Ly / ny / 2.0**level for level in range(len(n_block))]

# Initialization

In [None]:
# Initial condition on the level-0 (base) grid, produced by the example solver.
X, Y, U = solver.initialize(nx, ny)

# Level 0 holds a single block: wrap U in a leading block axis
# (later cells index it as blk_data0[:, component]).
blk_data0 = jnp.array([U])

# Metadata for that single level-0 block. The field meanings are defined by
# jaxamr; NOTE(review): the -1 entries presumably mean "no neighbour" — confirm.
blk_info0 = {
    "number": 1,
    "index": jnp.array([0, 0, 0]),
    "glob_index": jnp.array([[0, 0]]),
    "neighbor_index": jnp.array([[-1, -1, -1, -1]]),
}

# AMR main loop

In [None]:
dt = 0.00006 * 8  # level-0 time step (finer levels subcycle, see below)

nt = 30  # number of level-0 time steps

amr_update_step = 2  # regrid the refinement levels every this many steps

amr_initialized = False

for step in tqdm(range(nt), desc="Progress", unit="step"):

    if not amr_initialized:
        # Build refinement levels 1-3 top-down: each level is created from its
        # parent's data and block info, using the 'density' criterion.
        blk_data1, blk_info1, max_blk_num1 = amr.initialize(1, blk_data0, blk_info0, 'density', dx[1], dy[1])
        blk_data2, blk_info2, max_blk_num2 = amr.initialize(2, blk_data1, blk_info1, 'density', dx[2], dy[2])
        blk_data3, blk_info3, max_blk_num3 = amr.initialize(3, blk_data2, blk_info2, 'density', dx[3], dy[3])

        amr_initialized = True

    elif step % amr_update_step == 0:
        # Periodic regrid, again top-down; each level's current data, block
        # info and max_blk_num (as returned by initialize/update) are passed in.
        blk_data1, blk_info1, max_blk_num1 = amr.update(1, blk_data0, blk_info0, 'density', dx[1], dy[1], blk_data1, blk_info1, max_blk_num1)
        blk_data2, blk_info2, max_blk_num2 = amr.update(2, blk_data1, blk_info1, 'density', dx[2], dy[2], blk_data2, blk_info2, max_blk_num2)
        blk_data3, blk_info3, max_blk_num3 = amr.update(3, blk_data2, blk_info2, 'density', dx[3], dy[3], blk_data3, blk_info3, max_blk_num3)

    # Crossover (subcycled) advance: level k takes 2**k steps of size dt / 2**k,
    # so levels 0-3 all end this iteration advanced by dt. A finer level finishes
    # its sub-steps before its parent takes the matching step, so it is given
    # the parent's data from the start of that parent step (presumably used for
    # ghost/boundary values — confirm in simple_solver.rk2).
    for _ in range(2):
        for _ in range(2):
            for _ in range(2):
                blk_data3 = solver.rk2(3, blk_data2, dx[3], dy[3], dt/8.0, blk_data3, blk_info3)
            blk_data2 = solver.rk2(2, blk_data1, dx[2], dy[2], dt/4.0, blk_data2, blk_info2)
        blk_data1 = solver.rk2(1, blk_data0, dx[1], dy[1], dt/2.0, blk_data1, blk_info1)
    blk_data0 = solver.rk2_L0(blk_data0, dx[0], dy[0], dt)

    # Synchronous advance (disabled alternative to the loop above): every level
    # takes a single step of the finest size dt/8.0, so each iteration would
    # advance the whole hierarchy by dt/8 instead of dt.
    #blk_data3 = solver.rk2(3, blk_data2, dx[3], dy[3], dt/8.0, blk_data3, blk_info3)
    #blk_data2 = solver.rk2(2, blk_data1, dx[2], dy[2], dt/8.0, blk_data2, blk_info2)
    #blk_data1 = solver.rk2(1, blk_data0, dx[1], dy[1], dt/8.0, blk_data1, blk_info1)
    #blk_data0 = solver.rk2_L0(blk_data0, dx[0], dy[0], dt/8.0)

    # Transfer fine-level data onto the parent levels, finest first, so that
    # each coarse level receives children that have already been updated.
    blk_data2 = amr.interpolate_fine_to_coarse(3, blk_data2, blk_data3, blk_info3)
    blk_data1 = amr.interpolate_fine_to_coarse(2, blk_data1, blk_data2, blk_info2)
    blk_data0 = amr.interpolate_fine_to_coarse(1, blk_data0, blk_data1, blk_info1)

# Result Visualization

In [None]:
# (block data, block info) for every level, ordered coarse (0) to fine (3).
level_blocks = [
    (blk_data0, blk_info0),
    (blk_data1, blk_info1),
    (blk_data2, blk_info2),
    (blk_data3, blk_info3),
]


def _finish_figure(mappable, axes, label):
    """Add a labelled colorbar, label both axes, equalize the aspect, and show."""
    plt.colorbar(mappable, ax=axes, label=label)
    axes.set_xlabel('X')
    axes.set_ylabel('Y')
    plt.axis('equal')
    plt.show()


# Density Contour: all levels on one set of axes, drawn coarse to fine.
plt.figure(figsize=(10, 8))
ax = plt.gca()

component = 0
vrange = (0, 1)
for blk_data, blk_info in level_blocks:
    fig = amraux.plot_block_data(blk_data[:, component], blk_info, ax, vrange)

_finish_figure(fig, ax, 'Density')

# Refinement Level: fill each level's blocks with a constant equal to its level index.
plt.figure(figsize=(10, 8))
ax = plt.gca()

component = 0
vrange = (0, 3)
for level, (blk_data, blk_info) in enumerate(level_blocks):
    fig = amraux.plot_block_data(level*jnp.ones_like(blk_data[:, component]), blk_info, ax, vrange)

_finish_figure(fig, ax, 'Refinement Level')