# 2D & 3D Kuramoto-Sivashinsky in JAX using Exponax

In 1D, the Kuramoto-Sivashinsky (KS) equation is a popular example of spatio-temporal chaos, producing
rich dynamics whose character depends on the domain length $L$:

$$
\frac{\partial u}{\partial t}
+
\frac{1}{2} \frac{\partial u^2}{\partial x}
+
\frac{\partial^2 u}{\partial x^2}
+
\frac{\partial^4 u}{\partial x^4}
=
0
$$

It balances energy "production" by the destabilizing second-order term, energy
transfer across scales by the convective nonlinearity, and dissipation by the
fourth-order term.
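This balance becomes concrete when the equation is linearized around $u = 0$: a Fourier mode $e^{ikx + \lambda t}$ grows with rate $\lambda(k) = k^2 - k^4$, so only wavenumbers $0 < |k| < 1$ are unstable, with the fastest growth $\lambda = 1/4$ at $k = 1/\sqrt{2}$. A small sketch in plain JAX (independent of Exponax; `growth_rate` is a hypothetical helper) counts how many unstable modes fit into a periodic domain of length $L = 50$, which is one way to see why the dynamics depend on $L$.

```python
import jax.numpy as jnp


def growth_rate(k):
    # Linearized KS: du/dt = -u_xx - u_xxxx  ->  lambda(k) = k^2 - k^4
    return k**2 - k**4


L = 50.0                               # domain extent used in this notebook
m = jnp.arange(1, 33)                  # positive Fourier mode indices
k = 2.0 * jnp.pi * m / L               # wavenumbers admitted by periodicity on (0, L)
unstable = m[growth_rate(k) > 0]       # modes that grow out of small noise
print(unstable)

k_max = 1.0 / jnp.sqrt(2.0)            # fastest-growing wavenumber
print(growth_rate(k_max))              # approximately 0.25
```

For $L = 50$ this yields modes 1 through 7 per direction; a larger domain admits more unstable modes and hence richer dynamics.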

Extending it to higher dimensions typically uses a different nonlinearity

$$
\frac{1}{2} \frac{\partial u^2}{\partial x}
\to
\frac{1}{2} \left(\frac{\partial u}{\partial x}\right)^2
$$

which yields

$$
\frac{\partial u}{\partial t}
+
\frac{1}{2} \lVert \nabla u \rVert_2^2
+ \Delta u + \Delta \Delta u
=
0
$$
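The change of nonlinearity is less arbitrary than it may look. In 1D, writing the gradient form for an auxiliary field $h$ and differentiating once in $x$ recovers the conservative form for $u = \partial h / \partial x$:

$$
\frac{\partial h}{\partial t}
+
\frac{1}{2} \left(\frac{\partial h}{\partial x}\right)^2
+
\frac{\partial^2 h}{\partial x^2}
+
\frac{\partial^4 h}{\partial x^4}
=
0
\quad \overset{u = \partial h / \partial x}{\Longrightarrow} \quad
\frac{\partial u}{\partial t}
+
\frac{1}{2} \frac{\partial u^2}{\partial x}
+
\frac{\partial^2 u}{\partial x^2}
+
\frac{\partial^4 u}{\partial x^4}
=
0
$$

because

$$
\frac{\partial}{\partial x} \frac{1}{2} \left(\frac{\partial h}{\partial x}\right)^2
=
\frac{\partial h}{\partial x} \frac{\partial^2 h}{\partial x^2}
=
u \frac{\partial u}{\partial x}
=
\frac{1}{2} \frac{\partial u^2}{\partial x}
$$

The gradient form carries over to $d$ dimensions by replacing $(\partial h / \partial x)^2$ with $\lVert \nabla h \rVert_2^2$, which is the equation above with $h$ renamed to $u$.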

We will solve it on either a 2D domain $\Omega = (0, L)^2$ or a 3D domain
$\Omega = (0, L)^3$ using the Fourier pseudo-spectral solver `Exponax`, which
implies *periodic boundary conditions* in each dimension. We fix $L=50$ and
integrate with a time step of $\Delta t = 0.1$. Starting from Gaussian white
noise, the first $500$ time steps are a transient phase that we discard; we only
record trajectories once the dynamics have settled onto the chaotic attractor.

```python
import jax
import jax.numpy as jnp
import exponax as ex
from IPython.display import HTML  # renders the animations inline in Jupyter
```

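```python
# Positional arguments: number of spatial dimensions, domain extent L,
# grid points per dimension, and time step Δt
ks_stepper_2d = ex.stepper.KuramotoSivashinsky(
    2,     # 2D
    50.0,  # L
    256,   # 256 x 256 grid
    0.1,   # Δt
)
```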

```python
# Gaussian white-noise initial condition with shape (channels, x, y)
ic_2d = jax.random.normal(
    jax.random.key(0),
    (1, 256, 256),
)
```

```python
# Discard the transient: advance 500 steps and keep only the final state
warmed_ic_2d = ex.repeat(ks_stepper_2d, 500)(ic_2d)
```

```python
# Record 100 steps; include_init=True prepends the warmed state (101 snapshots)
trj_2d = ex.rollout(ks_stepper_2d, 100, include_init=True)(warmed_ic_2d)
```
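`ex.repeat` and `ex.rollout` turn a single-step stepper into a multi-step function: the former returns only the final state (used for the warm-up), the latter stacks every state along a new leading time axis. The sketch below mimics that behaviour with `jax.lax` primitives. The names `repeat_sketch` and `rollout_sketch` are hypothetical, and this is our reading of the semantics as used in this notebook, not Exponax's source.

```python
import jax
import jax.numpy as jnp


def repeat_sketch(step, n):
    """Hypothetical stand-in for ex.repeat: apply `step` n times, return the last state."""
    def run(u):
        return jax.lax.fori_loop(0, n, lambda _, v: step(v), u)
    return run


def rollout_sketch(step, n, include_init=False):
    """Hypothetical stand-in for ex.rollout: stack n successive states along axis 0."""
    def run(u):
        def body(v, _):
            v_next = step(v)
            return v_next, v_next  # (new carry, value to stack)
        _, trj = jax.lax.scan(body, u, None, length=n)
        if include_init:
            trj = jnp.concatenate([u[None], trj], axis=0)
        return trj
    return run


# Toy linear "stepper" that halves the state in every step
halve = lambda u: 0.5 * u
u0 = jnp.ones((1, 4, 4))                       # (channels, x, y) like ic_2d
warmed = repeat_sketch(halve, 3)(u0)           # every entry equals 1/8
trj = rollout_sketch(halve, 2, include_init=True)(warmed)
print(trj.shape)                               # time axis first
```

With `include_init=True`, $n$ steps give $n + 1$ snapshots, so `trj_2d` holds 101 states whose first entry is the warmed state.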

```python
trj_2d.shape  # (time, channels, x, y) = (101, 1, 256, 256)
```

```python
# Animate the single 2D trajectory; vlim fixes the colour range
ani_2d = ex.viz.animate_state_2d(trj_2d, vlim=(-10, 10))
```

```python
HTML(ani_2d.to_jshtml())
```

```python
# A batch of 20 independent initial conditions: (batch, channels, x, y)
ics_2d = jax.random.normal(
    jax.random.key(1),
    (20, 1, 256, 256),
)
```

```python
# jax.vmap maps the single-state warm-up over the leading batch axis
warmed_ics_2d = jax.vmap(ex.repeat(ks_stepper_2d, 500))(ics_2d)
```

```python
trjs_2d = jax.vmap(ex.rollout(ks_stepper_2d, 100, include_init=True))(warmed_ics_2d)
```

```python
trjs_2d.shape  # (batch, time, channels, x, y) = (20, 101, 1, 256, 256)
```

```python
# Facet over the batch axis (instead of channels) on a 2 x 2 grid of panels
ani_2d_facet = ex.viz.animate_state_2d_facet(
    trjs_2d,
    facet_over_channels=False,
    vlim=(-10, 10),
    grid=(2, 2),
    figsize=(5, 5),
)
```

```python
HTML(ani_2d_facet.to_jshtml())
```

```python
# Store the batch in half precision to halve the file size
jnp.save(
    "ks_trjs_2d.npy",
    trjs_2d.astype(jnp.float16),
)
```
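Saving in `float16` halves the storage compared to `float32` at the cost of precision. A quick check of the worst-case rounding error over the colour range $[-10, 10]$ used in the animations follows; it assumes the field values stay roughly within that range, which the `vlim` setting suggests but does not guarantee.

```python
import jax.numpy as jnp

# Sample the colour range (-10, 10) used by the animations
x = jnp.linspace(-10.0, 10.0, 2001, dtype=jnp.float32)
# Round-trip through float16, as done before jnp.save
x_roundtrip = x.astype(jnp.float16).astype(jnp.float32)
max_abs_err = jnp.max(jnp.abs(x - x_roundtrip))

# float16 has 10 explicit mantissa bits, so values in [8, 16) are spaced
# 2**-7 apart and round-to-nearest moves them by at most 2**-8 (about 0.004)
bound = 2.0**-8
print(float(max_abs_err), bound)
```

Relative to the field amplitude of order 10, this is a quantization error below about 0.05 %.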

```python
# Same setup in 3D, but with only 64 points per dimension (64^3 grid)
ks_stepper_3d = ex.stepper.KuramotoSivashinsky(
    3,
    50.0,
    64,
    0.1,
)
```
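The 3D runs use 64 instead of 256 points per dimension, presumably to keep 20 trajectories on a $64^3$ grid affordable. A rough way to compare the two resolutions is to count grid points per wavelength at the edge of the linear instability band, $k = 1$ (wavelength $2\pi$), since $\lambda(k) = k^2 - k^4$ is positive only for $0 < |k| < 1$. This is a heuristic under that assumption, not a convergence study.

```python
import math

L = 50.0  # domain extent shared by the 2D and 3D runs
resolution = {}
for n in (256, 64):  # points per dimension: 2D run, 3D run
    dx = L / n
    # Wavelength 2*pi corresponds to k = 1, the edge of the unstable band
    points_per_wavelength = 2.0 * math.pi / dx
    resolution[n] = (dx, points_per_wavelength)
    print(f"N = {n}: dx = {dx:.4f}, {points_per_wavelength:.1f} points per wavelength 2*pi")
```

The 2D grid places about 32 points on such a wavelength, the 3D grid about 8.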

```python
# 20 white-noise initial conditions: (batch, channels, x, y, z)
ics_3d = jax.random.normal(
    jax.random.key(2),
    (20, 1, 64, 64, 64),
)
```

```python
warmed_ics_3d = jax.vmap(ex.repeat(ks_stepper_3d, 500))(ics_3d)
```

```python
trjs_3d = jax.vmap(ex.rollout(ks_stepper_3d, 100, include_init=True))(warmed_ics_3d)
```

```python
trjs_3d.shape  # (batch, time, channels, x, y, z) = (20, 101, 1, 64, 64, 64)
```

```python
# Values per snapshot across the batch of 20 3D trajectories
64**3 * 20  # → 5242880
```
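Extending this count to the full trajectories gives a rough size for the saved files. The sketch assumes 101 snapshots per trajectory (100 steps with `include_init=True`) and ignores the small `.npy` header.

```python
n_traj = 20          # trajectories per dataset
n_snapshots = 101    # 100 recorded steps plus the warmed initial state
cells_2d = 256**2
cells_3d = 64**3

values_3d_per_snapshot = n_traj * cells_3d  # the 64**3 * 20 computed in the cell above
values_2d = n_traj * n_snapshots * cells_2d
values_3d = n_traj * n_snapshots * cells_3d

for name, n_values in [("2D", values_2d), ("3D", values_3d)]:
    gb_fp16 = n_values * 2 / 1e9  # 2 bytes per float16 value on disk
    gb_fp32 = n_values * 4 / 1e9  # 4 bytes per float32 value in memory
    print(f"{name}: {n_values} values, {gb_fp16:.2f} GB as float16, {gb_fp32:.2f} GB as float32")
```

This gives roughly 0.26 GB for the 2D file and 1.06 GB for the 3D file, and about twice that while the `float32` arrays are held in memory.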

```python
jnp.save("ks_trjs_3d.npy", trjs_3d.astype(jnp.float16))
```

```python
# Animate the first 4 trajectories over their first 40 snapshots.
# Internally uses vape4d for volume rendering (can also be installed separately).
ani_3d_facet = ex.viz.animate_state_3d_facet(
    trjs_3d[:4, :40],
    facet_over_channels=False,
    vlim=(-10, 10),
    grid=(2, 2),
    figsize=(5, 5),
)
```

```python
HTML(ani_3d_facet.to_jshtml())
```