# 2D Incompressible Forced Homogeneous Isotropic Turbulence (HIT)

This initial demonstration shows how to use JAX-CFD to simulate forced turbulence in 2D.

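**Note:** remember to switch on the pressure projection in the solver; it is the step that keeps the incompressible velocity field divergence-free.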

In [1]:
import os
import sys

# Pin the process to GPU 0. CUDA reads this variable when it initialises, so
# it is set before importing JAX (or any module that might touch a device)
# rather than after the imports.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import jax
import jax.numpy as jnp
import numpy as np
import seaborn
import xarray
import matplotlib.pyplot as plt
from scipy import stats

# Add the local jax-cfd source directory to the import path (presumably so
# that `utils` below resolves from there — confirm).
# NOTE(review): hard-coded, machine-specific path; adjust for your setup.
sys.path.insert(1, '/home/alam/Desktop/Code/jax-cfd/jax_cfd/')
import jax_cfd.sb as cfd
import utils

In [2]:
# --- Grid and fluid properties -------------------------------------------
size = 256             # grid points along each axis (size x size mesh)
density = 1.0
viscosity = 1e-3
# Set to a float to enable the hyperviscosity experiment further down.
# Previously tried values:
#   viscosity * ((size/2)**2/(12/128*size/2)**4)
#   viscosity * (10**2/10**4)
hyperviscosity = None

# --- Run length ----------------------------------------------------------
seed = 0               # PRNG seed for the random initial velocity
total_time = 20.0      # simulated time spanned by the saved frames
outer_steps = 200      # number of saved frames
burn_in = 40           # simulated time advanced (and discarded) before recording

# --- Time-step control ---------------------------------------------------
max_velocity = 7.0
cfl_safety_factor = 0.5

# --- Forcing -------------------------------------------------------------
constant_magnitude = 1.0
constant_wavenumber = 4.0
linear_coefficient = -0.1

hyperviscosity

In [None]:
# Define the physical dimensions of the simulation.
grid = cfd.grids.Grid((size, size), domain=((0, 2 * jnp.pi), (0, 2 * jnp.pi)))

# Construct a random initial velocity. The `filtered_velocity_field` function
# ensures that the initial velocity is divergence free and it filters out
# high frequency fluctuations.
v0 = cfd.initial_conditions.filtered_velocity_field(
    jax.random.PRNGKey(seed), grid, max_velocity)

# Choose a time step.
dt = cfd.equations.stable_time_step(
    max_velocity, cfl_safety_factor, viscosity, grid)
inner_steps = total_time // (dt * outer_steps)

# Add in forcing
forcing = cfd.forcings.simple_turbulence_forcing(grid, 
                                                 constant_magnitude=constant_magnitude,
                                                 constant_wavenumber=constant_wavenumber,
                                                 linear_coefficient=linear_coefficient,)

# In case for a change in viscosity term
diffuse = cfd.diffusion.diffuse

step_fn = cfd.equations.semi_implicit_navier_stokes(
        density=density, viscosity=viscosity, dt=dt, grid=grid,
        diffuse=diffuse, forcing=forcing)

v0 = cfd.funcutils.repeated(step_fn, steps=int(burn_in/dt))(v0)

# Define a step function and use it to compute a trajectory.
step_fn = cfd.funcutils.repeated(step_fn,steps=inner_steps)
rollout_fn = jax.jit(cfd.funcutils.trajectory(step_fn, outer_steps))
%time _, trajectory = jax.device_get(rollout_fn(v0))

In [None]:
# Inspect the velocity state reached after the burn-in.
v0

In [5]:
# load into xarray for visualization and analysis
ds = xarray.Dataset(
    {
        'u': (('time', 'x', 'y'), trajectory[0].data),
        'v': (('time', 'x', 'y'), trajectory[1].data),
    },
    coords={
        'x': grid.axes()[0],
        'y': grid.axes()[1],
        'time': dt * inner_steps * np.arange(outer_steps)
    }
)

In [None]:
# Inspect the assembled velocity dataset.
ds

In [None]:
def vorticity(ds):
    """Return the scalar vorticity dv/dx - du/dy of a velocity dataset.

    `ds` must provide `u` and `v` variables with 'x' and 'y' coordinates;
    the result is a DataArray named 'vorticity'.
    """
    dv_dx = ds.v.differentiate('x')
    du_dy = ds.u.differentiate('y')
    return (dv_dx - du_dy).rename('vorticity')

# Every 20th saved frame, laid out five panels per row.
snapshots = ds.pipe(vorticity).thin(time=20).transpose()
snapshots.plot.imshow(col='time', cmap=seaborn.cm.icefire, robust=True, col_wrap=5);

In [8]:
# Combine both velocity components into a single array with the component
# axis second: (time, component, x, y). Each component is (time, x, y).
v = np.stack([trajectory[0].data, trajectory[1].data], axis=1)

In [None]:
# Average kinetic energy of the viscous run, presumably one value per saved
# frame (see utils). The x axis of the plot is the array index, not time.
energy = utils.compute_average_ke_over_time(v)
plt.plot(energy)

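## Testing hyperviscosity

The cells below rerun the simulation with the hyperviscous diffusion operator (`cfd.diffusion.hyperdiffuse`). Each cell is skipped when `hyperviscosity` is `None`.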

In [10]:
if hyperviscosity is not None:
    # Rerun the experiment with hyperviscous diffusion, reusing the grid,
    # forcing and time step from the viscous run.

    # Construct a random initial velocity. The `filtered_velocity_field` function
    # ensures that the initial velocity is divergence free and it filters out
    # high frequency fluctuations.
    v0 = cfd.initial_conditions.filtered_velocity_field(
        jax.random.PRNGKey(seed), grid, max_velocity)

    hyperdiffuse = cfd.diffusion.hyperdiffuse

    # NOTE(review): dt was chosen by stable_time_step(..., viscosity, ...),
    # not from `hyperviscosity`; confirm it is also stable for hyperdiffusion.
    step_fn = cfd.equations.semi_implicit_navier_stokes(
            density=density, viscosity=hyperviscosity, dt=dt, grid=grid,
            diffuse=hyperdiffuse, forcing=forcing)

    # Burn-in of `burn_in` units of simulated time; only the final state is kept.
    v0 = cfd.funcutils.repeated(step_fn, steps=int(burn_in / dt))(v0)

    # Define a step function and use it to compute a trajectory.
    # `inner_steps` may be a float (it comes from a floating-point `//`), and a
    # step count must be an integer, so cast it here.
    step_fn = cfd.funcutils.repeated(step_fn, steps=int(inner_steps))
    rollout_fn = jax.jit(cfd.funcutils.trajectory(step_fn, outer_steps))
    _, hp_trajectory = jax.device_get(rollout_fn(v0))

    # (time, component, x, y) array, the same layout as `v` for the viscous run.
    hp_v = np.transpose(np.array([hp_trajectory[0].data, hp_trajectory[1].data]), (1, 0, 2, 3))

In [11]:
# Average kinetic energy of the hyperviscous run, presumably one value per
# saved frame (see utils); only computed when the hyperviscous run exists.
if hyperviscosity is not None:
    hp_energy = utils.compute_average_ke_over_time(hp_v)
    plt.plot(hp_energy)

In [12]:
if hyperviscosity is not None:
    # Load the hyperviscous trajectory into xarray (this replaces `ds`).
    hp_x, hp_y = grid.axes()
    hp_frame_times = dt * inner_steps * np.arange(outer_steps)
    ds = xarray.Dataset(
        data_vars={
            'u': (('time', 'x', 'y'), hp_trajectory[0].data),
            'v': (('time', 'x', 'y'), hp_trajectory[1].data),
        },
        coords={'x': hp_x, 'y': hp_y, 'time': hp_frame_times},
    )

    def vorticity(ds):
        """Return the scalar vorticity dv/dx - du/dy of a velocity dataset."""
        dv_dx = ds.v.differentiate('x')
        du_dy = ds.u.differentiate('y')
        return (dv_dx - du_dy).rename('vorticity')

    # Every 20th saved frame, five panels per row.
    hp_snapshots = ds.pipe(vorticity).thin(time=20).transpose()
    hp_snapshots.plot.imshow(col='time', cmap=seaborn.cm.icefire, robust=True, col_wrap=5);

In [None]:
# Energy spectrum of the viscous run, with a power-law fit E(k) ~ k**slope
# over the wavenumber band [k_lo, k_hi).
# NOTE(review): assumes spectrum[k] is the energy at integer wavenumber k —
# confirm against utils.compute_energy_spectrum.
k_lo, k_hi = 4, 9
fit_k = np.arange(k_lo, k_hi)

spectrum = utils.compute_energy_spectrum(v)
slope, intercept, _, _, _ = stats.linregress(np.log(fit_k), np.log(spectrum[k_lo:k_hi]))
# The leading factor (10) presumably shifts the fit line above the data so both
# stay visible. The k = 0 entry is 0**slope (inf for a negative slope); it is
# never plotted, so the divide-by-zero warning it triggers is suppressed.
with np.errstate(divide='ignore'):
    fit = 10 * np.exp(intercept) * np.arange(size) ** slope
plt.plot(spectrum[:int(1/2*size)], '0.8')
plt.plot(spectrum[:int(1/3*size)], 'b', label='viscosity')
plt.title(fr'$\nu={viscosity}$')
plt.plot(fit_k, fit[k_lo:k_hi], 'b--', label=f'k^{slope:.2f}')

if hyperviscosity is not None:
    hp_spectrum = utils.compute_energy_spectrum(hp_v)
    hp_slope, hp_intercept, _, _, _ = stats.linregress(np.log(fit_k), np.log(hp_spectrum[k_lo:k_hi]))
    with np.errstate(divide='ignore'):
        hp_fit = 20 * np.exp(hp_intercept) * np.arange(size) ** hp_slope
    plt.plot(hp_spectrum[:int(1/2*size)], '0.8')
    plt.plot(hp_spectrum[:int(1/3*size)], 'r', label='hyperviscosity')
    plt.plot(fit_k, hp_fit[k_lo:k_hi], 'r--', label=f'k^{hp_slope:.2f}')
    # `.3g` instead of `.4f`: small hyperviscosities (e.g. 1e-5) would
    # otherwise be shown as 0.0000.
    plt.title(fr'$\nu={viscosity}$, $\nu_h={hyperviscosity:.3g}$')

plt.yscale('log')
plt.xscale('log')
plt.ylim(1e-12)
plt.legend()
#plt.savefig(f'spectrum{size}.png', dpi=300)