In [1]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

import jax_cfd.base as cfd
import jax_cfd.base.grids as grids
import jax_cfd.spectral as spectral
from jax_cfd.spectral import utils as spectral_utils

from jax import vmap, jit

# Decaying Turbulence

In [2]:
%%time 
# physical parameters
viscosity = 1e-2
max_velocity = 3
grid = grids.Grid((1024, 1024), domain=((0, 2 * jnp.pi), (0, 2 * jnp.pi)))
# dt = cfd.equations.stable_time_step(max_velocity, .5, viscosity, grid)
dt = 5e-4

# setup step function using crank-nicolson runge-kutta order 4
smooth = True # use anti-aliasing 
step_fn = spectral.time_stepping.crank_nicolson_rk4(
    spectral.equations.NavierStokes2D(viscosity, grid, smooth=smooth), dt)

# run the simulation up until time 25.0 but only save 10 frames for visualization
# final_time = 10.0
outer_steps = 201
inner_steps = 100

trajectory_fn = cfd.funcutils.trajectory(
    cfd.funcutils.repeated(step_fn, inner_steps), outer_steps)

# create an initial velocity field and compute the fft of the vorticity.
# the spectral code assumes an fft'd vorticity for an initial state
v0 = cfd.initial_conditions.filtered_velocity_field(jax.random.PRNGKey(0), grid, max_velocity, 2)
vorticity0 = cfd.finite_differences.curl_2d(v0).data
vorticity_hat0 = jnp.fft.rfftn(vorticity0)

_, trajectory = trajectory_fn(vorticity_hat0)


CPU times: user 28.9 s, sys: 21.2 s, total: 50 s
Wall time: 51.4 s


In [3]:
# Convert the saved spectral trajectory (rfft of the vorticity, one frame per
# outer step) back to physical space, and recover the velocity from it.
# s=grid.shape lets irfftn restore odd-sized last axes as well; without it the
# last axis is assumed to have even length 2 * (n_freq - 1).
w = jnp.fft.irfftn(trajectory, s=grid.shape, axes=(1, 2))

velocity_solve = spectral_utils.vorticity_to_velocity(grid)
u_hat, v_hat = vmap(velocity_solve)(trajectory)
u = jnp.fft.irfftn(u_hat, s=grid.shape, axes=(1, 2))
v = jnp.fft.irfftn(v_hat, s=grid.shape, axes=(1, 2))

# Sample coordinates lo + j * (hi - lo) / N, j = 0..N-1, along each periodic
# axis. They are derived from the grid, so each axis uses its own size and domain.
(x_lo, x_hi), (y_lo, y_hi) = grid.domain
n_x, n_y = grid.shape
x = x_lo + jnp.arange(n_x) * (x_hi - x_lo) / n_x
y = y_lo + jnp.arange(n_y) * (y_hi - y_lo) / n_y

# Frame k is labelled with time k * dt * inner_steps. This assumes the
# trajectory starts with its input (trajectory(..., start_with_input=True));
# otherwise each frame is actually one outer step later than its label.
t = dt * jnp.arange(outer_steps) * inner_steps

# First saved frame (the initial condition when the trajectory starts with its input).
u0, v0, w0 = u[0], v[0], w[0]

# Full-resolution dataset; saving it is optional because the arrays are large.
data_full_res = {'w': w, 'u': u, 'v': v, 'u0': u0, 'v0': v0, 'w0': w0,
                 'x': x, 'y': y, 't': t, 'viscosity': viscosity}
SAVE_FULL_RES = False
if SAVE_FULL_RES:
    # Host NumPy arrays, so the pickled file can be loaded without jax.
    jnp.save('ns_tori_high_res.npy', jax.device_get(data_full_res))

In [4]:
# Downsample every saved field by `res` along both spatial axes (keep every
# res-th grid point) and save the reduced dataset.
res = 8
data = {
    'u': u[:, ::res, ::res],
    'v': v[:, ::res, ::res],
    'w': w[:, ::res, ::res],
    'x': x[::res],
    'y': y[::res],
    't': t,
    'viscosity': viscosity,
}

# Saving a dict stores it as a pickled 0-d object array. Pickled jax Arrays can
# only be unpickled with jax installed, so copy them to host NumPy arrays first.
# Load with: numpy.load('ns_tori.npy', allow_pickle=True).item()
jnp.save('ns_tori.npy', jax.device_get(data))