<a href="https://colab.research.google.com/github/JA4S/JAX-AMR/blob/main/examples/jax_eb_basic_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install JAX-AMR and import the required libraries

In [None]:
# Copyright © 2025 Haocheng Wen
# SPDX-License-Identifier: MIT

!pip install git+https://github.com/JA4S/JAX-AMR.git
!wget https://raw.githubusercontent.com/JA4S/JAX-AMR/main/examples/simple_solver_eb.py

In [2]:
import jax

# JAX global configuration must run before any JAX array is created. Otherwise
# arrays built earlier (e.g. at import time by the project modules below) stay float32.
jax.config.update("jax_enable_x64", True)      # use 64-bit floats by default
jax.config.update("jax_platform_name", "cpu")  # run on the CPU backend

import jax.numpy as jnp
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Project modules: the EB geometry package and the solver downloaded above.
from jaxeb import eb
import simple_solver_eb as solver

# Set computation parameters

In [3]:
# Rectangular domain [x_min, x_min + Lx] x [y_min, y_min + Ly].
x_min, y_min = 0.0, 0.0
Lx = 2.0
Ly = 1.0

# Number of grid cells in x and y.
nx = 400
ny = 200

# Uniform cell sizes.
dx = Lx / nx
dy = Ly / ny

# Obstacle: a circle approximated by a polygon with `num_points` vertices.
center_x = 1.25
center_y = 0.5
radius = 0.2
num_points = 72

# Angles run from 2*pi down towards 0 with the endpoint excluded. The vertices
# are therefore listed clockwise, and the start point (2*pi == 0) is not repeated.
angles = jnp.linspace(2 * jnp.pi, 0, num_points, endpoint=False)
x_points = center_x + radius * jnp.cos(angles)
y_points = center_y + radius * jnp.sin(angles)

polygon_vertices = jnp.stack([x_points, y_points], axis=1)  # shape (num_points, 2)

# Drop an explicit closing vertex, if present (e.g. when substituting user-supplied
# vertices). eb.initialize presumably expects an open vertex list — confirm with jaxeb.
# The length guard keeps a degenerate 1-vertex input from being emptied.
if len(polygon_vertices) > 1 and jnp.allclose(polygon_vertices[0], polygon_vertices[-1]):
    polygon_vertices = polygon_vertices[:-1]
    print('Warning: The last vertex coincides with the first one, so it has been removed.')

assert len(polygon_vertices) >= 3, "The polygon needs at least 3 distinct vertices."

# Initialize EB

In [None]:
# Build the EB description of the obstacle polygon on the nx-by-ny grid anchored at (x_min, y_min).
# The result is read below via cell_info['cell_type'] and is passed whole to solver.rk2.
# NOTE(review): visual=True presumably draws the geometry/cell classification — confirm in jaxeb.
cell_info = eb.initialize(polygon_vertices, x_min, y_min, nx, ny, dx, dy, visual=True)

# Initialize the flow field and run the computation

In [None]:
# Grid coordinates (X, Y) and the initial solution U. The solver receives the EB
# cell classification so it can treat cells according to their type.
X, Y, U = solver.initialize(
    x_min, y_min, Lx, Ly, nx, ny, cell_info['cell_type']
)

# Time-integration settings.
cfl = 0.3   # CFL number handed to solver.rk2
nt = 300    # number of RK2 steps to take

# Advance the solution nt steps; only the latest state is kept in U.
for _ in tqdm(range(nt), desc="Progress", unit="step"):
    U = solver.rk2(U, dx, dy, cfl, cell_info)

# Visualize the results

In [None]:
# Plot the first component of U over the grid, with the obstacle polygon filled in black.
fig, ax = plt.subplots(figsize=(16, 8))
ax.set_aspect('equal', adjustable='box')
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_title(f"U[0] after {nt} steps")

# shading='auto' picks the right mode whether X/Y have the same shape as U[0]
# (cell centres) or one more point per direction (cell corners).
mesh = ax.pcolormesh(X, Y, U[0, :, :], cmap='jet', shading='auto')

# polygon_vertices is already an (N, 2) array, so no re-conversion is needed.
ax.fill(polygon_vertices[:, 0], polygon_vertices[:, 1], color='black', alpha=1)

# Attach the colorbar to this mesh explicitly, rather than to pyplot's "current image".
fig.colorbar(mesh, ax=ax)
plt.show()