# JAX-CFD demo

This demonstration shows how to use JAX-CFD to simulate forced 2D turbulence, and then compares ordinary viscous diffusion against hyperviscosity.

In [12]:
import jax
import jax.numpy as jnp
import numpy as np
import seaborn
import xarray
import matplotlib.pyplot as plt
from scipy import stats

import sys
sys.path.insert(1, '/home/alam/Desktop/Code/jax-cfd/jax_cfd/')
import jax_cfd.sb as cfd
import utils

In [13]:
# --- Grid and fluid properties -------------------------------------------
size = 256                 # grid points along each of the two axes
density = 1.0
viscosity = 2e-3
hyperviscosity = 8e-3      # earlier formula (unused): viscosity * (80**2/40**4)
seed = 0                   # PRNG seed for the random initial velocity

# --- Trajectory layout ---------------------------------------------------
# Each saved frame is `inner_steps` solver steps apart; `outer_steps` frames
# are recorded in total.
inner_steps = 25
outer_steps = 200

# --- Initial-condition scale and time-step control -----------------------
max_velocity = 7.0
cfl_safety_factor = 0.5

# Simulated time to integrate before recording (converted to a step count
# as burn_in / dt in the simulation cells).
burn_in = 40

# --- Forcing parameters --------------------------------------------------
constant_magnitude = 1.0
constant_wavenumber = 4.0
linear_coefficient = -0.1

In [14]:
# Define the physical dimensions of the simulation: a `size` x `size` grid
# spanning [0, 2*pi] along each axis.
grid = cfd.grids.Grid((size, size), domain=((0, 2 * jnp.pi), (0, 2 * jnp.pi)))

# Construct a random initial velocity. The `filtered_velocity_field` function
# ensures that the initial velocity is divergence free and it filters out
# high frequency fluctuations.
v0 = cfd.initial_conditions.filtered_velocity_field(
    jax.random.PRNGKey(seed), grid, max_velocity)

# Choose a time step from the velocity scale, CFL safety factor and viscosity.
dt = cfd.equations.stable_time_step(
    max_velocity, cfl_safety_factor, viscosity, grid)

# Add in forcing. Use the parameters from the configuration cell instead of
# re-typing their values, so that editing them there actually takes effect.
forcing = cfd.forcings.simple_turbulence_forcing(
    grid,
    constant_magnitude=constant_magnitude,
    constant_wavenumber=constant_wavenumber,
    linear_coefficient=linear_coefficient,
)

# Diffusion operator, kept as a named variable so it can be swapped out
# (e.g. for hyperdiffusion) without changing the solver construction below.
diffuse = cfd.diffusion.diffuse

step_fn = cfd.equations.semi_implicit_navier_stokes(
    density=density, viscosity=viscosity, dt=dt, grid=grid,
    diffuse=diffuse, forcing=forcing)

# Burn in: advance `burn_in` units of simulated time before recording,
# presumably so that initial transients decay before the trajectory starts.
v0 = cfd.funcutils.repeated(step_fn, steps=int(burn_in / dt))(v0)

# Group `inner_steps` solver steps per saved frame and record `outer_steps`
# frames with a jitted rollout.
step_fn = cfd.funcutils.repeated(step_fn, steps=inner_steps)
rollout_fn = jax.jit(cfd.funcutils.trajectory(step_fn, outer_steps))
_, trajectory = jax.device_get(rollout_fn(v0))

In [15]:
# Load into xarray for visualization and analysis.
# Assuming the upstream jax_cfd default `start_with_input=False` for
# `funcutils.trajectory` (verify for this fork), frame i is the state *after*
# i + 1 outer steps. It therefore sits at time (i + 1) * inner_steps * dt,
# measured from the end of the burn-in, rather than at i * inner_steps * dt.
ds = xarray.Dataset(
    {
        'u': (('time', 'x', 'y'), trajectory[0].data),
        'v': (('time', 'x', 'y'), trajectory[1].data),
    },
    coords={
        'x': grid.axes()[0],
        'y': grid.axes()[1],
        'time': dt * inner_steps * np.arange(1, outer_steps + 1),
    },
)

In [None]:
# Show the dataset (dimensions, coordinates, variables) as this cell's output.
ds

In [None]:
def vorticity(ds):
  """Return the scalar vorticity dv/dx - du/dy of `ds`, named 'vorticity'.

  `ds` must expose velocity components `u` and `v` whose `differentiate`
  method accepts the dimension names 'x' and 'y' (as the xarray Dataset
  built in this notebook does).
  """
  dv_dx = ds.v.differentiate('x')
  du_dy = ds.u.differentiate('y')
  return (dv_dx - du_dy).rename('vorticity')

# Plot every 20th saved frame of the vorticity, five panels per row; the
# trailing semicolon suppresses the notebook's echo of the returned plot object.
ds.pipe(vorticity).thin(time=20).plot.imshow(
    col='time', col_wrap=5, cmap=seaborn.cm.icefire, robust=True);

In [18]:
# Stack the two velocity components into a single array of shape
# (time, component, x, y).
v = np.stack([trajectory[0].data, trajectory[1].data], axis=1)

In [None]:
# Kinetic-energy history of the viscous run (per `utils.compute_average_ke_over_time`;
# presumably one value per saved frame, so the x axis is the frame index).
energy = utils.compute_average_ke_over_time(v)
plt.plot(np.arange(len(energy)), energy)

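## Testing hyperviscosity

Rerun the same forced simulation from the same initial condition, replacing ordinary diffusion with hyperdiffusion.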

In [20]:
# Regenerate the initial velocity with the same seed as the viscous run, so
# both simulations start from an identical field. `filtered_velocity_field`
# makes it divergence free and filters out high frequency fluctuations.
v0 = cfd.initial_conditions.filtered_velocity_field(
    jax.random.PRNGKey(seed), grid, max_velocity)
hyperdiffuse = cfd.diffusion.hyperdiffuse

# NOTE(review): `dt` and `forcing` are reused from the viscous run. `dt` was
# chosen by `stable_time_step` using `viscosity`, not `hyperviscosity`; confirm
# it is also stable for the hyperdiffusive operator.
step_fn = cfd.equations.semi_implicit_navier_stokes(
    density=density,
    viscosity=hyperviscosity,
    dt=dt,
    grid=grid,
    diffuse=hyperdiffuse,
    forcing=forcing,
)

# Burn in for the same simulated time as the viscous run.
v0 = cfd.funcutils.repeated(step_fn, steps=int(burn_in / dt))(v0)

# Record `outer_steps` frames, each `inner_steps` solver steps apart.
step_fn = cfd.funcutils.repeated(step_fn, steps=inner_steps)
rollout_fn = jax.jit(cfd.funcutils.trajectory(step_fn, outer_steps))
_, hp_trajectory = jax.device_get(rollout_fn(v0))

# Stack velocity components into shape (time, component, x, y).
hp_v = np.stack([hp_trajectory[0].data, hp_trajectory[1].data], axis=1)

In [None]:
# Load the hyperviscous run into xarray for visualization and analysis.
# Assuming the upstream jax_cfd default `start_with_input=False` for
# `funcutils.trajectory` (verify for this fork), frame i is the state *after*
# i + 1 outer steps. It therefore sits at time (i + 1) * inner_steps * dt,
# measured from the end of the burn-in, rather than at i * inner_steps * dt.
ds = xarray.Dataset(
    {
        'u': (('time', 'x', 'y'), hp_trajectory[0].data),
        'v': (('time', 'x', 'y'), hp_trajectory[1].data),
    },
    coords={
        'x': grid.axes()[0],
        'y': grid.axes()[1],
        'time': dt * inner_steps * np.arange(1, outer_steps + 1),
    },
)

def vorticity(ds):
  """Return the scalar vorticity dv/dx - du/dy of `ds`, named 'vorticity'.

  Same as the earlier definition in this notebook; presumably repeated so
  this cell can run on its own. `ds` must expose `u` and `v` whose
  `differentiate` accepts the dimension names 'x' and 'y'.
  """
  dv_dx = ds.v.differentiate('x')
  du_dy = ds.u.differentiate('y')
  return (dv_dx - du_dy).rename('vorticity')

# Plot every 20th saved frame of the hyperviscous vorticity, five panels per
# row; the trailing semicolon suppresses the notebook's echo of the plot object.
ds.pipe(vorticity).thin(time=20).plot.imshow(
    col='time', col_wrap=5, cmap=seaborn.cm.icefire, robust=True);

In [None]:
# Energy spectra of both runs (the `-1` argument presumably selects the last
# saved frame; confirm in `utils.compute_energy_spectrum`).
spectrum = utils.compute_energy_spectrum(v, -1)
hp_spectrum = utils.compute_energy_spectrum(hp_v, -1)

# Fit a power law E(k) ~ k**slope to the viscous spectrum over k in
# [k_lo, k_hi). This assumes spectrum index == integer wavenumber, which the
# original plotting against index implies.
k_lo, k_hi = 4, 11
fit_k = np.arange(k_lo, k_hi)
slope, intercept, _, _, _ = stats.linregress(
    np.log(fit_k), np.log(spectrum[k_lo:k_hi]))

# Reference line, scaled by 10 (presumably to offset it from the data for
# visibility). `fit` is still evaluated for every k in case later cells use it.
# At k = 0 a negative slope yields inf; silence that expected divide-by-zero
# RuntimeWarning rather than letting it surface as noise.
with np.errstate(divide='ignore'):
    fit = 10 * np.exp(intercept) * np.arange(size) ** slope

# Plot wavenumbers below size // 3 (exact integer form of the original
# int(1/3 * size)).
cutoff = size // 3
plt.plot(spectrum[:cutoff], label='viscosity')
plt.plot(hp_spectrum[:cutoff], label='hyperviscosity')
plt.plot(fit_k, fit[k_lo:k_hi], 'k--', label=f'k^{slope:.2f}')
plt.yscale('log')
plt.xscale('log')
plt.ylim(bottom=1e-10)
plt.legend()