# Working with Jaxfluids

An introductory notebook on working with [Jaxfluids](https://github.com/adopt-opt/jaxfluids) in the context of design optimization. For installation instructions for the underlying JAX package, see the [JAX GitHub repository](https://github.com/google/jax).

## Manipulation of JSON

To manipulate the JSON setup files we use Python's built-in `json` module. It reads a file into a nested dictionary whose entries we can then alter and write back to the file.

```python
import json
```

Store the paths to the two most important JSON files, the case setup and the numerical setup, in variables:

```python
case_setup = './cylinderflow.json'
num_setup = './numerical_setup.json'
```

Load the case setup, alter the extent and resolution of the computational domain, and write the changes back to the file:

```python
import json

# Read the case setup into a Python dictionary; the file is closed automatically
with open(case_setup, "r") as json_file:
    setup = json.load(json_file)

# Work with the buffered content: resolution and extent of the domain
setup["domain"]["x"]["cells"] = 75
setup["domain"]["x"]["range"] = [-2, 4]
setup["domain"]["y"]["cells"] = 50
setup["domain"]["y"]["range"] = [-2, 2]

# Save the changes back to the JSON file (indented for readability)
with open(case_setup, "w") as json_file:
    json.dump(setup, json_file, indent=4)
```
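In a design optimization loop, this load-modify-save pattern is repeated for every candidate, so it can be convenient to wrap it in a helper. The following is a minimal sketch; `update_json_setup` and its dotted-key convention (e.g. `"domain.x.cells"`) are hypothetical and not part of JAX-Fluids.

```python
import json
from typing import Any, Dict


def update_json_setup(path: str, updates: Dict[str, Any]) -> dict:
    """Hypothetical helper: set nested entries of a JSON setup file.

    Keys of `updates` are dotted paths such as "domain.x.cells".
    Returns the updated dictionary after writing it back to `path`.
    """
    with open(path, "r") as json_file:
        setup = json.load(json_file)

    for dotted_key, value in updates.items():
        *parents, leaf = dotted_key.split(".")
        node = setup
        for key in parents:
            node = node[key]  # raises KeyError if the path does not exist
        node[leaf] = value

    with open(path, "w") as json_file:
        json.dump(setup, json_file, indent=4)
    return setup
```

With it, the cell above shrinks to a single call such as `update_json_setup(case_setup, {"domain.x.cells": 75, "domain.x.range": [-2, 4]})`. Raising a `KeyError` on unknown paths, instead of silently creating new entries, guards against typos in the setup keys.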

The same approach applies to other tasks, such as changing the lambda function that determines the shape of the body inside the domain for design optimization purposes. In the case setup, the body is defined by the `levelset` entry of the initial condition:

```json
{
    "initial_condition": {
        "primes": {
            "rho": 1.0,
            "u": 0.0,
            "v": 0.0,
            "w": 0.0,
            "p": 1.0
        },
        "levelset": "lambda x,y: - 0.1 + jnp.sqrt(x**2 + y**2)"
    }
}
```
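To vary the body between design iterations, the level-set string can be generated from the design parameters and written into the case setup. The sketch below assumes a circular body whose radius and center are the design variables. `circle_levelset` and `set_body_shape` are hypothetical helpers, not part of JAX-Fluids. Like the entry above, the generated expression is a signed distance: negative inside the circle, zero on its boundary, and positive outside.

```python
import json


def circle_levelset(radius: float, x0: float = 0.0, y0: float = 0.0) -> str:
    """Hypothetical helper: level-set lambda string for a circle of given radius and center."""
    return f"lambda x,y: - {radius} + jnp.sqrt((x - {x0})**2 + (y - {y0})**2)"


def set_body_shape(path: str, levelset: str) -> None:
    """Hypothetical helper: write a new level-set lambda string into a case setup file."""
    with open(path, "r") as json_file:
        setup = json.load(json_file)
    setup["initial_condition"]["levelset"] = levelset
    with open(path, "w") as json_file:
        json.dump(setup, json_file, indent=4)


print(circle_levelset(0.2))
# lambda x,y: - 0.2 + jnp.sqrt((x - 0.0)**2 + (y - 0.0)**2)
```

Other shapes follow the same pattern, as long as the string remains a valid Python lambda of `x` and `y` written with `jnp` functions, like the original entry.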

## Ahead-of-Time (AOT) Compilation with JAX
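As a minimal, self-contained illustration of the AOT workflow, independent of JAX-Fluids: `jax.jit(f).lower(*args)` traces `f` for the shapes and dtypes of the example arguments and lowers it to an intermediate representation (IR), and `.compile()` turns the result into an executable. The function `toy_objective` is a hypothetical stand-in for an expensive simulation.

```python
import jax
import jax.numpy as jnp


def toy_objective(m):
    """Hypothetical stand-in for an expensive simulation: scalar parameter in, scalar out."""
    return jnp.sum(jnp.sin(m * jnp.linspace(0.0, 1.0, 1000)) ** 2)


# The example argument fixes the shape and dtype the executable is specialized to
example_arg = jnp.float32(2.0)

lowered = jax.jit(toy_objective).lower(example_arg)  # trace and lower to IR
compiled = lowered.compile()                         # compile the IR into an executable

print(lowered.as_text()[:200])    # inspect the beginning of the lowered IR
print(compiled(jnp.float32(3.0)))  # call without any further tracing
```

The compiled executable only accepts arguments with the same shape and dtype as the example argument; unlike a plain `jax.jit` function, it does not retrace or recompile for new ones.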



```python
from jaxfluids import InputReader, Initializer, SimulationManager
```

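Set up the simulation: the `InputReader` reads the case and numerical setups, and the `Initializer` and `SimulationManager` are created from it.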

```python
input_reader = InputReader("cylinderflow.json", "numerical_setup.json")
initializer  = Initializer(input_reader)
sim_manager  = SimulationManager(input_reader)
```

To compile the simulation ahead of time, we additionally import JAX itself:

```python
import jax  # for jit and ahead-of-time (AOT) compilation
```


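The helper below serves as a template. It wraps a complete forward simulation into a function of a single scalar parameter that returns a scalar output quantity, so that only the design parameter changes between evaluations. As written, it sets up a one-dimensional shock problem with the shock Mach number `M_S` as the parameter and no immersed body (`levelset_init = None`). For the cylinder flow, it has to be specialized so that the parameter controls the shape of the body instead, and the initial buffer has to match the case setup that `sim_manager` was built from.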

```python
import numpy as np
import jax.numpy as jnp

# Pre-shock (right) conditions: quiescent ideal gas
gamma_L = gamma_R = 1.4
rho_R = p_R = 1.0
a_R   = np.sqrt(gamma_R * p_R / rho_R)
u_R   = 0.0
M_R   = u_R / a_R

def wrapper_fun(M_S: float = 2.0):  # needs to be specialized to the purpose at hand
    traj_length = 5      # length of the simulated trajectory
    time_step   = 1e-2
    # Number of cells in x; must match the case setup that sim_manager was built from
    res = setup["domain"]["x"]["cells"]

    x_cf = jnp.linspace(0, 1, num=res+1)  # cell faces
    x_cc = 0.5 * (x_cf[1:] + x_cf[:-1])   # cell centers

    # POST-SHOCK RANKINE-HUGONIOT CONDITIONS
    p_L   = p_R * ( 1/(gamma_L + 1) * (gamma_R * (M_R - M_S)**2 + 1) + jnp.sqrt( (1/(gamma_L + 1) * (gamma_R * (M_R - M_S)**2 + 1))**2 - (gamma_L-1)/(gamma_L+1) * ((M_R-M_S)**2 * 2 * gamma_R/(gamma_R - 1) - 1) ))
    rho_L = rho_R *  (gamma_R - 1)/(gamma_L - 1) * ( p_L / p_R + (gamma_L - 1)/ (gamma_L + 1) ) / ( p_L / p_R * (gamma_R - 1) / (gamma_L + 1) + (gamma_R + 1) / (gamma_L + 1) )
    u_L   = a_R * ( rho_R/rho_L * (M_R - M_S) + M_S )

    # INITIAL BUFFER: shape (batch, primitive, x, y, z), primitives ordered (rho, u, v, w, p)
    prime_init    = jnp.zeros((1, 5, res, 1, 1))
    prime_init    = prime_init.at[0,0,:,0,0].set(jnp.where(x_cc > 0.5, rho_R, rho_L))
    prime_init    = prime_init.at[0,1,:,0,0].set(jnp.where(x_cc > 0.5, u_R, u_L))
    prime_init    = prime_init.at[0,4,:,0,0].set(jnp.where(x_cc > 0.5, p_R, p_L))
    levelset_init = None  # no immersed body in this template

    # FORWARD SIMULATION
    data_series, _ = sim_manager.feed_forward(
        prime_init,
        levelset_init,
        traj_length,
        time_step,
        0.0, 1, None, None)
    data_series = data_series[0]  # select the single batch entry

    # COMPUTE SCALAR OUTPUT QUANTITY: change of rho * p / rho^gamma between first and last step
    entropy = data_series[:,4] / data_series[:,0]**gamma_L
    total_entropy = jnp.mean(data_series[-1,0] * entropy[-1] - data_series[0,0] * entropy[0])
    return total_entropy
```
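The Rankine-Hugoniot expressions in `wrapper_fun` allow different ratios of specific heats on both sides of the shock. As a sanity check, for $\gamma_L = \gamma_R = \gamma$ and gas at rest ahead of the shock they should reduce to the textbook normal-shock relations, e.g. $p_L/p_R = 1 + \frac{2\gamma}{\gamma+1}(M_S^2 - 1)$. The sketch below compares both in plain Python; `post_shock_state` and `normal_shock_textbook` are hypothetical names, and the first one copies the formulas of `wrapper_fun` with equal $\gamma$.

```python
import math


def post_shock_state(M_S, gamma=1.4, rho_R=1.0, p_R=1.0):
    """Post-shock state from the wrapper_fun formulas with gamma_L = gamma_R and u_R = 0."""
    a_R = math.sqrt(gamma * p_R / rho_R)
    M_R = 0.0
    A = 1 / (gamma + 1) * (gamma * (M_R - M_S)**2 + 1)
    p_L = p_R * (A + math.sqrt(A**2 - (gamma - 1) / (gamma + 1)
                               * ((M_R - M_S)**2 * 2 * gamma / (gamma - 1) - 1)))
    rho_L = rho_R * (p_L / p_R + (gamma - 1) / (gamma + 1)) \
        / (p_L / p_R * (gamma - 1) / (gamma + 1) + 1)
    u_L = a_R * (rho_R / rho_L * (M_R - M_S) + M_S)
    return rho_L, u_L, p_L


def normal_shock_textbook(M_S, gamma=1.4, rho_R=1.0, p_R=1.0):
    """Classical normal-shock relations for a shock moving at Mach M_S into gas at rest."""
    a_R = math.sqrt(gamma * p_R / rho_R)
    p_L = p_R * (1 + 2 * gamma / (gamma + 1) * (M_S**2 - 1))
    rho_L = rho_R * (gamma + 1) * M_S**2 / ((gamma - 1) * M_S**2 + 2)
    u_L = M_S * a_R * (1 - rho_R / rho_L)  # lab-frame velocity behind the shock
    return rho_L, u_L, p_L


for M in (1.5, 2.0, 3.0):
    ours, ref = post_shock_state(M), normal_shock_textbook(M)
    assert all(math.isclose(a, b, rel_tol=1e-9) for a, b in zip(ours, ref)), M

print(post_shock_state(2.0))
```

For `M_S = 2` both give a pressure ratio of 4.5 and a density ratio of 8/3. The check only covers the equal-$\gamma$ case; the general two-$\gamma$ form is taken from the template as-is.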

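```python
import jax.numpy as jnp

# Example argument: fixes the shape and dtype (float32 scalar) the executable is specialized to
M_S = jnp.float32(2.0)

# Trace the function and lower it to an intermediate representation (IR)
lowered_wrapper = jax.jit(wrapper_fun).lower(M_S)

# Compile the lowered IR into an executable
compiled_wrapper = lowered_wrapper.compile()
```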

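We can then call the ahead-of-time compiled wrapper with a new value of the design parameter. The argument must have the same shape and dtype as the one used for lowering; an executable compiled ahead of time does not recompile for other input types.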


```python
# The argument must have the same shape and dtype as the one passed to .lower()
total_entropy = compiled_wrapper(jnp.float32(2.5))
```

## Batch Evaluation with the AOT-Compiled Wrapper

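We can now generate a large batch of candidate inputs and vectorize the wrapper over the batch axis with `jax.vmap`. A compiled executable cannot itself be transformed, so `jax.vmap` has to be applied before lowering and compiling, and the resulting executable is specialized to that batch size. Because the batched function is compiled only once, every subsequent evaluation of a batch of that size skips tracing and compilation and evaluates all candidates in a single call.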

```python
# Batch execution to go here
```
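A minimal sketch of the batched AOT workflow, using the hypothetical stand-in `toy_objective` in place of `wrapper_fun`. Whether `wrapper_fun` itself can be vmapped depends on `feed_forward` being compatible with `jax.vmap`, which is an assumption here. The order matters: `jax.vmap` is applied first, and the vmapped function is then lowered and compiled for one fixed batch size.

```python
import jax
import jax.numpy as jnp


def toy_objective(m):
    """Hypothetical stand-in for wrapper_fun: scalar parameter in, scalar out."""
    return jnp.sum(jnp.sin(m * jnp.linspace(0.0, 1.0, 1000)) ** 2)


# Hypothetical batch of 256 candidate design parameters
batch = jnp.linspace(1.5, 3.0, 256, dtype=jnp.float32)

batched = jax.jit(jax.vmap(toy_objective))          # vectorize first, then jit
compiled_batched = batched.lower(batch).compile()   # specialized to shape (256,), float32

results = compiled_batched(batch)  # one call evaluates all 256 candidates
```

A batch of a different size needs its own `.lower(...).compile()`, so in practice it pays off to fix the batch size once and pad the last batch if necessary.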

To quantify the savings, we can evaluate the batched function for varying batch sizes (vector lengths) and measure how the time per candidate changes as more of the computation is placed on the GPU at once.

```python
# Plotting code which takes the above cell as a function, evaluates it for different batch sizes, and then plots them
```
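A sketch of such a measurement, again with the hypothetical stand-in `toy_objective`. The batch sizes are arbitrary assumptions, and absolute timings depend entirely on the hardware. The calls to `block_until_ready()` are needed because JAX dispatches work asynchronously; without them the timer would stop before the computation has finished.

```python
import time

import jax
import jax.numpy as jnp


def toy_objective(m):
    """Hypothetical stand-in for wrapper_fun: scalar parameter in, scalar out."""
    return jnp.sum(jnp.sin(m * jnp.linspace(0.0, 1.0, 1000)) ** 2)


def time_batched(batch_sizes, n_repeats=10):
    """Return {batch_size: (compile_seconds, seconds_per_candidate)} for an AOT-compiled vmapped function."""
    batched = jax.jit(jax.vmap(toy_objective))
    timings = {}
    for n in batch_sizes:
        inputs = jnp.linspace(1.5, 3.0, n, dtype=jnp.float32)

        t0 = time.perf_counter()
        compiled = batched.lower(inputs).compile()  # one compilation per batch size
        t_compile = time.perf_counter() - t0

        compiled(inputs).block_until_ready()  # warm-up call
        t0 = time.perf_counter()
        for _ in range(n_repeats):
            compiled(inputs).block_until_ready()
        t_run = (time.perf_counter() - t0) / n_repeats

        timings[n] = (t_compile, t_run / n)
    return timings


for n, (t_c, t_per) in time_batched([1, 16, 256, 4096]).items():
    print(f"batch {n:5d}: compile {t_c:.3f} s, {t_per * 1e6:.2f} us per candidate")
```

The returned dictionary can then be plotted, for example the seconds per candidate over the batch size on log-log axes with matplotlib.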

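In this context we can also compare just-in-time (JIT) and ahead-of-time (AOT) compilation. With `jax.jit`, the first call for a given input shape and dtype includes tracing and compilation, and later calls reuse the cached executable; with AOT compilation, this cost is paid explicitly in `.lower(...).compile()` before the first call.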

```python
# Plotting code which takes the above cell as a function, evaluates it for different batch sizes, and then plots them JUST FOR THE JIT
```
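A minimal sketch of this comparison for a single batch size, with the hypothetical stand-in `toy_objective` in place of `wrapper_fun`. It separates the first JIT call (which includes compilation), a warm JIT call, the explicit AOT compilation, and the first AOT call. The numbers are machine-dependent and only meant to show where the compilation cost appears.

```python
import time

import jax
import jax.numpy as jnp


def toy_objective(m):
    """Hypothetical stand-in for wrapper_fun: scalar parameter in, scalar out."""
    return jnp.sum(jnp.sin(m * jnp.linspace(0.0, 1.0, 1000)) ** 2)


batch = jnp.linspace(1.5, 3.0, 1024, dtype=jnp.float32)


def timed(fn, *args):
    """Wall-clock time of one call, waiting for the asynchronous result."""
    t0 = time.perf_counter()
    out = fn(*args)
    out.block_until_ready()
    return out, time.perf_counter() - t0


# JIT: the first call traces and compiles, later calls reuse the cached executable
jitted = jax.jit(jax.vmap(toy_objective))
out_jit_first, t_jit_first = timed(jitted, batch)
out_jit_warm, t_jit_warm = timed(jitted, batch)

# AOT: compilation happens explicitly before the first call
t0 = time.perf_counter()
compiled = jax.jit(jax.vmap(toy_objective)).lower(batch).compile()
t_aot_compile = time.perf_counter() - t0
out_aot, t_aot_call = timed(compiled, batch)

print(f"JIT first call (incl. compile): {t_jit_first:.4f} s, warm call: {t_jit_warm:.4f} s")
print(f"AOT compile: {t_aot_compile:.4f} s, first call: {t_aot_call:.4f} s")
```

After warm-up, both variants run the same compiled computation, so the main practical difference lies in when the compilation cost is paid, not in the steady-state runtime.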