In [None]:
import os

from jax.config import config
config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp
import jax_cfd.base as cfd
from jax_cfd.base.grids import AlignedArray
import numpy as np
import seaborn
import xarray
import pickle
from tqdm import tqdm
from glob import glob
from pathlib import Path

In [None]:
# Simulation parameters shared by every cell below.
config = {
    're': 4000,                   # Reynolds number; viscosity is derived as 1 / re
    'size': 256,                  # grid points along each axis
    'density': 1,
    'burn_in': 60 * 25,           # single solver steps run (and not saved) before recording
    'max_velocity': 10,           # used for the initial field and for the stable dt
    'peak_wavenumber': 4,         # passed to filtered_velocity_field
    'cfl_safety_factor': 0.5,     # passed to stable_time_step
    'constant_magnitude': 4,      # Kolmogorov forcing parameters,
    'constant_wavenumber': 4,     #   all three passed to
    'linear_coefficient': -0.1,   #   simple_turbulence_forcing
}
config['viscosity'] = 1 / config['re']

# Square grid of size x size cells on the domain [0, 2*pi] x [0, 2*pi].
config['grid'] = cfd.grids.Grid(
    (config['size'],) * 2,
    domain=((0, 2 * jnp.pi),) * 2,
)

# Stable time step computed by jax_cfd from the max velocity, CFL safety
# factor, viscosity and grid.
config['dt'] = cfd.equations.stable_time_step(
    config['max_velocity'],
    config['cfl_safety_factor'],
    config['viscosity'],
    config['grid'],
)

def vorticity(ds):
    """Return the scalar vorticity dv/dx - du/dy of a velocity dataset.

    Args:
        ds: dataset exposing velocity components ``u`` and ``v`` that support
            ``.differentiate(dim)`` along dimensions named 'x' and 'y'.

    Returns:
        The difference of the two derivatives, renamed to 'vorticity'.
    """
    dv_dx = ds.v.differentiate('x')
    du_dy = ds.u.differentiate('y')
    return (dv_dx - du_dy).rename('vorticity')

In [None]:
def run_and_save_simulation(seed, config, folder=None):
    """Simulate one forced-turbulence trajectory and save every snapshot to disk.

    The initial velocity field is cached under ``inputs/`` as a pickle keyed on
    size, max velocity, peak wavenumber and seed. When ``USE_DP`` is True, a
    missing or unreadable cache entry is generated and written. Otherwise it is
    an error. The field is then advanced ``burn_in`` single solver steps. In SP
    mode (``USE_DP`` False) it is cast to bfloat16. Finally it is advanced
    ``outer_steps`` times by ``inner_steps`` solver steps each. After every
    outer step, the (u, v) pair is saved as a float32 ``.npy`` stacked along the
    last axis.

    Reads the notebook globals ``USE_DP``, ``T`` and ``N`` (``T``/``N`` only
    name the output root directory).

    Args:
        seed: integer seed for the initial condition; also names the output
            sub-folder ``sim_{seed:03d}``.
        config: configuration dict (see the configuration cell) that must also
            contain ``inner_steps`` and ``outer_steps``.
        folder: split name used in the output path; defaults to 'temp'.

    Returns:
        ``(saved_u, saved_v)``: lists with one float32 host array per outer step.

    Raises:
        ValueError: if ``USE_DP`` is False and no usable cached initial
            condition exists.
    """
    key = jax.random.PRNGKey(seed)

    # SET PARAMETERS
    size = config['size']
    density = config['density']
    viscosity = config['viscosity']
    inner_steps = config['inner_steps']
    outer_steps = config['outer_steps']
    burn_in = config['burn_in']
    max_velocity = config['max_velocity']
    peak_wavenumber = config['peak_wavenumber']
    grid = config['grid']
    dt = config['dt']

    # INIT/LOAD VELOCITY
    filename = f"inputs/v0_{size}_{max_velocity}_{peak_wavenumber}_{seed}.pkl"
    try:
        # NOTE(review): pickle.load can execute arbitrary code; only load caches
        # this notebook wrote itself.
        with open(filename, "rb") as f:
            v0 = pickle.load(f)
        print("Loaded from file")
    except (OSError, EOFError, pickle.UnpicklingError) as err:
        # Missing or truncated/corrupt cache. Only the DP run may (re)create it,
        # so SP runs always start from the exact same initial conditions.
        if not USE_DP:
            raise ValueError(
                f"No usable cached initial condition at {filename!r}; "
                "generate it first with USE_DP = True."
            ) from err
        v0 = cfd.initial_conditions.filtered_velocity_field(
            key,
            grid,
            max_velocity,
            peak_wavenumber
        )
        Path(filename).parent.mkdir(parents=True, exist_ok=True)
        with open(filename, "wb") as f:
            pickle.dump(v0, f)
        print("Initialized and saved to file")

    # DEFINE STEP FUNCTION
    forcing_fn = cfd.forcings.simple_turbulence_forcing(
        grid=grid,
        constant_magnitude=config['constant_magnitude'],
        constant_wavenumber=config['constant_wavenumber'],
        linear_coefficient=config['linear_coefficient'],
        forcing_type='kolmogorov'
    )
    step_fn_single = cfd.equations.semi_implicit_navier_stokes(
        density=density, viscosity=viscosity, dt=dt, grid=grid, forcing=forcing_fn)
    # Build the repeated step from the un-jitted single step, then jit both.
    step_fn = jax.jit(cfd.funcutils.repeated(step_fn_single, steps=inner_steps))
    step_fn_single = jax.jit(step_fn_single)

    # BURN IN
    for _ in range(burn_in):
        v0 = step_fn_single(v0)

    # SP mode: continue from the burned-in state in bfloat16.
    if not USE_DP:
        for component in v0:
            component.data = component.data.astype('bfloat16')

    # RUN
    saved_u = []
    saved_v = []
    output = v0
    for _ in tqdm(range(outer_steps)):
        output = step_fn(output)
        saved_u.append(output[0].data.astype('float32'))
        saved_v.append(output[1].data.astype('float32'))

    # SAVE
    saved_u = jax.device_get(saved_u)
    saved_v = jax.device_get(saved_v)

    savedir = os.path.join(f'/data/akash/decay_turbulence_T{T}_N{N}',
                           'DP' if USE_DP else 'SP',
                           folder if folder else 'temp',
                           f'sim_{seed:03d}'
                          )
    Path(savedir).mkdir(parents=True, exist_ok=True)
    for i, (u, v) in enumerate(zip(saved_u, saved_v)):
        jnp.save(os.path.join(savedir, f'{i:04d}.npy'), jnp.stack([u, v], axis=-1))

    return saved_u, saved_v

In [None]:
# Precision mode for run_and_save_simulation.
# True  -> missing initial conditions are generated and cached, the state is not
#          cast (x64 is enabled at import), and outputs go under 'DP/'.
# False -> a cached initial condition is required, the state is cast to bfloat16
#          after burn-in, and outputs go under 'SP/'.
USE_DP = True

In [None]:
# T: solver steps per saved snapshot (becomes config['inner_steps']).
# N: recorded steps per short trajectory (train/test save N + 1 snapshots).
# Both also appear in the output root directory name.
T, N = 6, 50

In [None]:
# Generate every dataset split. Each entry is (folder, seeds, outer_steps);
# all splits use T solver steps per saved snapshot.
# NOTE(review): seed 210 is not assigned to any split — confirm the gap is intended.
config['inner_steps'] = T
for split, seeds, n_outer in (
    ('train', range(200), N + 1),
    ('test', range(200, 210), N + 1),
    ('val', range(211, 213), 1500),
    ('longtest', range(213, 218), 1000),
):
    config['outer_steps'] = n_outer
    for seed in seeds:
        run_and_save_simulation(seed, config, folder=split)

In [None]:
# Generate an SP dataset forked from every DP point.
# The DP trajectories must already exist (cells above). Each DP snapshot is
# reloaded, cast to bfloat16, and re-simulated from there.
#
# Forking happens only from DP snapshots in val/ and test/ that have at least
# `timespan` later snapshots. From each one, the bfloat16 state is rolled
# forward `timespan` outer steps, and the whole rollout is saved as a single
# .npy at the mirrored path under SP/.

config['inner_steps'] = T
config['timespan'] = N

seed = 0
key = jax.random.PRNGKey(seed)

# Only used as a template: each component's `.data` is overwritten with a DP
# snapshot below, so this random field is never itself simulated.
v0 = cfd.initial_conditions.filtered_velocity_field(
    key,
    config['grid'],
    config['max_velocity'],
    config['peak_wavenumber']
)

forcing_fn = cfd.forcings.simple_turbulence_forcing(
    grid=config['grid'],
    constant_magnitude=config['constant_magnitude'],
    constant_wavenumber=config['constant_wavenumber'],
    linear_coefficient=config['linear_coefficient'],
    forcing_type='kolmogorov'
)

step_fn = cfd.funcutils.repeated(
    cfd.equations.semi_implicit_navier_stokes(
        density=config['density'], viscosity=config['viscosity'], dt=config['dt'], grid=config['grid'], forcing=forcing_fn),
    steps=config['inner_steps']
)
step_fn = jax.jit(step_fn)

dp_root = f'/data/akash/decay_turbulence_T{T}_N{N}/DP'
dp_sims = sorted(glob(f'{dp_root}/val/sim_*')) + sorted(glob(f'{dp_root}/test/sim_*'))

outer_steps = config['timespan']
for sim_folder in dp_sims:
    dp_files = sorted(glob(os.path.join(sim_folder, '*.npy')))
    # Swap only the '/DP/' path component; a bare 'DP' replace would also hit
    # any other 'DP' in the path. Created once per folder, not once per file.
    Path(sim_folder.replace('/DP/', '/SP/')).mkdir(parents=True, exist_ok=True)

    # Fork only from snapshots that still have `outer_steps` DP successors.
    # Clamp at 0 so a short trajectory yields no forks instead of a
    # negative-index slice that silently keeps some files.
    n_forks = max(len(dp_files) - outer_steps, 0)
    for f in tqdm(dp_files[:n_forks]):
        dp = np.load(f)
        for c in range(len(v0)):
            v0[c].data = dp[:, :, c].astype('bfloat16')

        output = v0
        saved_u = []
        saved_v = []
        for _ in range(outer_steps):
            output = step_fn(output)
            saved_u.append(output[0].data.astype('float32'))
            saved_v.append(output[1].data.astype('float32'))

        # Each step is stacked to (..., 2) = (u, v). Concatenating along the
        # last axis gives (..., 2 * outer_steps), with u/v interleaved per step.
        saved_sim = [jnp.stack([u, v], axis=-1) for u, v in zip(saved_u, saved_v)]
        snapshots = jnp.concatenate(saved_sim, axis=-1)
        jnp.save(f.replace('/DP/', '/SP/'), snapshots)

# Long run

In [None]:
# Re-simulate each DP long-test trajectory in reduced precision. Start from the
# first DP snapshot cast to bfloat16, roll out as many outer steps as the DP run
# saved, and write each step to the mirrored path under SP/.
config['inner_steps'] = T

forcing_fn = cfd.forcings.simple_turbulence_forcing(
    grid=config['grid'],
    constant_magnitude=config['constant_magnitude'],
    constant_wavenumber=config['constant_wavenumber'],
    linear_coefficient=config['linear_coefficient'],
    forcing_type='kolmogorov'
)

step_fn = cfd.funcutils.repeated(
    cfd.equations.semi_implicit_navier_stokes(
        density=config['density'], viscosity=config['viscosity'], dt=config['dt'], grid=config['grid'], forcing=forcing_fn),
    steps=config['inner_steps']
)
step_fn = jax.jit(step_fn)

sim_folders = sorted(glob(f'/data/akash/decay_turbulence_T{T}_N{N}/DP/longtest/sim_*/'))
for sim_folder in sim_folders:
    files = sorted(glob(sim_folder + '*.npy'))
    if not files:
        # No DP snapshots (e.g. an interrupted DP run): nothing to start from.
        continue
    # Swap only the '/DP/' path component, not every 'DP' substring.
    Path(sim_folder.replace('/DP/', '/SP/')).mkdir(parents=True, exist_ok=True)

    v0_arr = jnp.load(files[0]).astype(jnp.bfloat16)
    assert v0_arr.dtype == jnp.bfloat16

    # The SP trajectory starts from the bfloat16-rounded DP initial state.
    jnp.save(files[0].replace('/DP/', '/SP/'), v0_arr.astype('float32'))
    # NOTE(review): offsets are hard-coded here (u at (1.0, 0.5), v at (0.5, 1.0));
    # assumed to match the staggering used by the solver — confirm.
    v0 = (AlignedArray(data=v0_arr[:, :, 0], offset=(1.0, 0.5)),
          AlignedArray(data=v0_arr[:, :, 1], offset=(0.5, 1.0)))

    outer_steps = len(files) - 1
    output = v0

    for i in tqdm(range(outer_steps)):
        output = step_fn(output)
        u = output[0].data.astype('float32')
        v = output[1].data.astype('float32')
        snapshot = jnp.stack([u, v], axis=-1)
        jnp.save(files[i + 1].replace('/DP/', '/SP/'), snapshot)

# Visualize

In [None]:
# Run one trajectory (seed 0, saved under 'temp'). This cell expects
# run_and_save_simulation to return the per-step (u, v) lists, and stacks each
# list into an array with a leading time axis.
saved_u, saved_v = map(jnp.stack, run_and_save_simulation(0, config))

In [None]:
# Load into xarray for visualization and analysis.
# Physical coordinates are attached so that `differentiate` (used by
# `vorticity`) uses the grid's coordinate spacing instead of unit index steps.
x_coords, y_coords = config['grid'].axes()
ds = xarray.Dataset(
    {
        'u': (('time', 'x', 'y'), saved_u),
        'v': (('time', 'x', 'y'), saved_v),
    },
    coords={
        'x': x_coords,
        'y': y_coords,
        # Snapshot k (1-based) is recorded after k * inner_steps solver steps
        # past the burn-in.
        'time': config['dt'] * config['inner_steps'] * np.arange(1, saved_u.shape[0] + 1),
    },
)

In [None]:
# SP
# Vorticity panels from the tail of the trajectory: take the last (up to) 2000
# saved steps, keep every (2000 // 10)-th one, and lay them out 5 per row.
frames = ds.tail(time=2000).thin(time=2000 // 10)
vorticity(frames).plot.imshow(
    col='time', cmap=seaborn.cm.icefire, robust=True, col_wrap=5
);