In [1]:
import os
import pickle
from math import floor, log10
from typing import List, Optional, Tuple

import h5py
import jax
import jax.numpy as jnp
import jax_cfd.base as cfd
import numpy as np
from jax_cfd.base.forcings import ForcingFn
from tqdm import tqdm


def generate_data(
    density: float,
    viscosity: float,
    seeds: List[int],
    domain: Tuple[Tuple[float, float], Tuple[float, float]] = (
        (0, 2 * jnp.pi), (0, 2 * jnp.pi)),
    size: Tuple[int, int] = (64, 64),
    time_steps: int = 100,
    max_time: int = 25,
    max_velocity: float = 2.5,
    cfl_safety_factor: float = 0.5,
    force: Optional[str] = None,
) -> None:
    """Simulate one Navier-Stokes trajectory per seed and write them to disk.

    Two files are written under ``dataset/`` (the directory is created if it
    does not exist), both named after the forcing, ``max_velocity``,
    ``viscosity`` and ``density``:

    * ``<name>.pkl`` -- a pickled dict with the simulation ``details`` (forcing
      label, density, viscosity, step counts, ``dt``, grid spacing) and the
      ``coords`` (``x``, ``y`` and the ``time`` of every saved frame).
    * ``<name>.h5`` -- a float32 dataset called ``"dataset"`` of shape
      ``(len(seeds), 2, time_steps + 1, size[0], size[1])``; axis 1 holds the
      two velocity components.

    Args:
        density: Fluid density passed to the solver.
        viscosity: Fluid viscosity passed to the solver and to the stable
            time step estimate.
        seeds: PRNG seeds; one random initial velocity field and one
            trajectory is generated per seed, in order.
        domain: Physical extent of the grid along each axis.
        size: Number of grid cells along each axis.
        time_steps: Number of saved intervals; ``time_steps + 1`` frames are
            stored per trajectory.
        max_time: Simulation time that must be reached exactly after
            ``time_steps`` saved intervals. Must be an ``int`` because its
            parity is inspected with ``&``.
        max_velocity: Maximum speed of the initial velocity field, also used
            for the stable time step estimate.
        cfl_safety_factor: CFL safety factor for the stable time step.
        force: ``"kolmogorov"`` for Kolmogorov forcing, or ``None`` for an
            unforced simulation (labelled ``"nonforced"`` in the outputs).

    Raises:
        AssertionError: If ``force`` is not recognised, if the rounded time
            step is not strictly between 0 and the stable time step, or if
            ``max_time`` is not reached exactly by the chosen stepping.
    """

    # Define the physical dimensions of the simulation
    grid = cfd.grids.Grid(size, domain=domain)

    # Choose a stable time step
    stable_dt = cfd.equations.stable_time_step(
        max_velocity, cfl_safety_factor, viscosity, grid)

    # Round down the time step to its leading decimal digit to get cleaner
    # evaluation times while maintaining stability.
    decimals = -int(floor(log10(abs(stable_dt))))
    dt = floor(stable_dt * (10 ** decimals))
    # Lower the leading digit by one when its parity differs from that of
    # max_time (presumably so that max_time is a whole multiple of the step;
    # the equality assertion further down verifies the outcome).
    dt = dt - ((max_time & 1) != (dt & 1))
    dt = dt * (10 ** -decimals)

    # The parity adjustment can reduce a leading digit of 1 to 0; fail with a
    # clear message instead of dividing by zero below.
    assert dt > 0, "Chosen time step is not positive."
    assert dt < stable_dt, "Chosen time step is larger than stable time step."

    # Compute the correct number of inner steps given the stable time step.
    # This code should ensure that the final evaluation time is max_time.
    outer_steps = time_steps + 1
    end_time = max_time + (max_time / (outer_steps - 1))
    # Cast to int: this is a step count, not a float quantity.
    inner_steps = int((end_time / dt) // outer_steps)

    assert float(max_time) == dt * inner_steps * time_steps, \
        "Computed time steps do not match the expected maximum time."

    # Determine forcing function
    if force == "kolmogorov":
        forcing_fn = cfd.forcings.kolmogorov_forcing(grid, k=4)
    elif force is None:
        # Unforced run: the solver must receive an explicit None, otherwise
        # `forcing_fn` would be unbound when it is used below.
        forcing_fn = None
        force = "nonforced"
    else:
        raise AssertionError("Invalid forcing function.")

    # Save simulation details
    file_name = f"dataset/{force}_max_velocity_{max_velocity}_viscosity_{viscosity}_density_{density}"
    os.makedirs(os.path.dirname(file_name), exist_ok=True)

    simulation_details = dict(
        details={
            "force": force,
            "density": density,
            "viscosity": viscosity,
            # NOTE(review): these two entries describe the seeds only when
            # `seeds` is range(len(seeds)) -- confirm with consumers before
            # passing any other seed list.
            "min_seed": 0,
            "max_seed": len(seeds),
            "max_velocity": max_velocity,
            "inner_steps": inner_steps,
            "outer_steps": outer_steps,
            "dt": dt,
            "dx": grid.step[0],
            "dy": grid.step[1],
        },
        coords={
            "x": grid.axes()[0],
            "y": grid.axes()[1],
            "time": dt * inner_steps * np.arange(outer_steps)
        }
    )

    with open(f"{file_name}.pkl", "wb") as handle:
        pickle.dump(simulation_details, handle,
                    protocol=pickle.HIGHEST_PROTOCOL)

    # Build the solver once. Neither the step function nor the rollout
    # depends on the seed, so creating and jitting them inside the loop would
    # only force a fresh trace/compile for every trajectory.
    step_fn = cfd.funcutils.repeated(
        cfd.equations.semi_implicit_navier_stokes(
            density=density, viscosity=viscosity, dt=dt, grid=grid, forcing=forcing_fn
        ),
        steps=inner_steps
    )
    rollout_fn = jax.jit(
        cfd.funcutils.trajectory(step_fn, outer_steps)
    )

    # Run simulations
    with h5py.File(f"{file_name}.h5", "w") as hf:

        # Start empty and grow by one trajectory at a time; each chunk holds
        # exactly one trajectory.
        dataset = hf.create_dataset(
            "dataset", (0, 2, outer_steps, size[0], size[1]),
            maxshape=(None, 2, outer_steps, size[0], size[1]),
            chunks=(1, 2, outer_steps, size[0], size[1]),
            dtype=np.float32,
        )

        for seed in tqdm(seeds):

            # Construct a random initial velocity using `filtered_velocity_field`
            # to ensure the initial velocity is divergence free and high frequency
            # fluctuations are filtered
            v0 = cfd.initial_conditions.filtered_velocity_field(
                jax.random.PRNGKey(seed), grid, max_velocity
            )

            # Roll the initial state forward and copy the result to the host
            _, trajectory = jax.device_get(rollout_fn(v0))

            # Save the trajectory data
            data = np.stack([trajectory[0].data, trajectory[1].data])
            dataset.resize(dataset.shape[0]+1, axis=0)
            dataset[-1] = data

In [2]:
# Parameters for the dataset: Kolmogorov-forced flow, one trajectory for
# each seed in 0..1999.
density, viscosity = 1., 1e-3
force = "kolmogorov"
seeds = [*range(2000)]

generate_data(density, viscosity, seeds, force=force)

100%|██████████| 2000/2000 [28:11<00:00,  1.18it/s]
