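# Experiment for Shearing Box Simulation

Runs a 2D shearing-box flow with `jax_cfd.sb`: a linear background shear $v_y = -q\Omega x$ with $q = 3/2$, turbulence forcing, and semi-implicit Navier–Stokes time stepping.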

In [1]:
import os
import sys

# Select the GPU *before* importing JAX or anything that imports it: CUDA reads
# CUDA_VISIBLE_DEVICES when it initializes, so setting it after `import jax`
# (or after a module that touches the backend at import time) may be ignored.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import jax
import jax.numpy as jnp
import numpy as np
import seaborn as sns
import xarray
import matplotlib.pyplot as plt
from scipy import stats

# Local jax-cfd checkout, presumably the source of the `utils` module imported
# below. Set JAX_CFD_DIR to point at a different checkout; the default is the
# original hardcoded location.
JAX_CFD_DIR = os.environ.get("JAX_CFD_DIR", "/home/alam/Desktop/Code/jax-cfd/jax_cfd/")
sys.path.insert(1, JAX_CFD_DIR)
import jax_cfd.sb as cfd
import utils

In [2]:
# Grid resolution and fluid properties.
size = 256
density = 1.
viscosity = 1e-3
total_time = 20.0
outer_steps = 200

# Shearing-box rotation: shear_rate = q * Omega with q = 3/2.
angular_velocity = 1.0
shear_rate = 3/2 * angular_velocity

# Side length of the square domain; matches the (0, 2*pi) grid domain below.
domain_length = 2 * np.pi

# max_velocity feeds the CFL time-step estimate, so it must bound |v| everywhere.
# The background shear v_y = -shear_rate * x alone reaches
# shear_rate * domain_length (~9.42) at x = 2*pi, which exceeds the original
# hand-picked 7.0; keep 7.0 only as a lower bound.
max_velocity = max(7.0, shear_rate * domain_length)
cfl_safety_factor = 0.5

In [3]:
# Physical dimensions: a size-by-size grid spanning [0, 2*pi] x [0, 2*pi].
grid = cfd.grids.Grid((size, size), domain=((0, 2 * jnp.pi), (0, 2 * jnp.pi)))


# Initial velocity is the linear background shear flow: v_x = 0 and
# v_y = -shear_rate * x (the y velocity varies linearly with x).
def x_velocity_fn(x, y):
    return jnp.zeros_like(x)


def y_velocity_fn(x, y):
    return -shear_rate * x


# One shearing-box boundary condition per velocity component. The 0.0 argument
# is presumably the initial time — TODO confirm against cfd.boundaries.
velocity_bc = tuple(
    cfd.boundaries.shearingbox_boundary_conditions(grid.ndim, shear_rate, 0.0, component)
    for component in ('vx', 'vy'))

v0 = cfd.initial_conditions.initial_velocity_field(
    (x_velocity_fn, y_velocity_fn), grid, velocity_bc=velocity_bc)

In [None]:
# Initial velocity components side by side, on a shared fixed color scale.
# The data is transposed so the grid's first axis (x) runs horizontally.
# +/-10 covers the peak background shear speed |v_y| = shear_rate * 2*pi.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
imshow_kwargs = dict(origin='lower', cmap=sns.cm.icefire, vmin=-10, vmax=10,
                     extent=(0, 2 * jnp.pi, 0, 2 * jnp.pi))
v0x_plot = ax[0].imshow(v0[0].data.transpose(), **imshow_kwargs)
v0y_plot = ax[1].imshow(v0[1].data.transpose(), **imshow_kwargs)
plt.colorbar(v0x_plot, ax=ax[0])
plt.colorbar(v0y_plot, ax=ax[1])
ax[0].set_title(r'$v_{0,x}$')
ax[1].set_title(r'$v_{0,y}$')
for panel in ax:
    panel.set(xlabel=r'$x$', ylabel=r'$y$')
# Render the figure instead of echoing the last Text object's repr.
plt.show()

In [5]:
# Time step from the CFL-style estimate, using the velocity bound, safety
# factor and viscosity from the config cell.
dt = cfd.equations.stable_time_step(max_velocity, cfl_safety_factor, viscosity, grid)

# Shearing-box turbulence forcing; it takes the rotation parameters and dt.
forcing = cfd.forcings.shearingbox_turbulence_forcing(
    grid,
    constant_magnitude=1.0,
    constant_wavenumber=4.0,
    linear_coefficient=-0.1,
    angular_velocity=angular_velocity,
    shear_rate=shear_rate,
    dt=dt,
)

# Diffusion operator is passed explicitly so the viscous term can be swapped
# out later without touching the solver call.
diffuse = cfd.diffusion.diffuse

# Single-step update function for the semi-implicit Navier-Stokes solver.
step_fn = cfd.equations.semi_implicit_navier_stokes(
    density=density,
    viscosity=viscosity,
    dt=dt,
    grid=grid,
    diffuse=diffuse,
    forcing=forcing,
)

In [None]:
# Apply the step function built above once to the initial state v0
# (presumably one time step of size dt).
# NOTE(review): total_time and outer_steps are defined in the config cell but
# not used in the visible cells — presumably a full trajectory run follows; confirm.
v = step_fn(v0)