# Approximating the **Lyapunov Spectrum** of the Lorenz system using JAX's autodiff

The Lorenz equations are a prototypical example of **deterministic chaos**. They
form a system of three coupled **nonlinear** ODEs:

$$
\begin{aligned}
\frac{dx}{dt} &= \sigma(y - x), \\
\frac{dy}{dt} &= x(\rho - z) - y, \\
\frac{dz}{dt} &= xy - \beta z.
\end{aligned}
$$

The three variables can be combined into the state vector $u = (x, y, z) \in
\mathbb{R}^3$. A system with $N$ degrees of freedom has $N$ Lyapunov exponents,
which together form the **Lyapunov Spectrum**. To assess the growth of chaotic
perturbations, typically only the **largest** Lyapunov exponent, $\max_i
\lambda_i$, is relevant: a generic infinitesimal perturbation $\delta u$ grows as

$$
\| \delta u(t) \| \approx \| \delta u(0) \| \exp(\max(\lambda_i) t).
$$
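When only $\max(\lambda_i)$ is needed, a single tangent vector is enough: renormalizing it after every step plays the role that the QR decomposition plays for the full spectrum. The following is a minimal sketch that assumes the notebook's setup ($\sigma = 10$, $\rho = 28$, $\beta = 8/3$, $\Delta t = 0.01$, $5000$ warm-up steps, $50000$ measurement steps); the helper names `rhs`, `rk4_step` and `largest_lyapunov_exponent` and the initial tangent $(1, 0, 0)$ are illustrative choices, not part of the original code.

```python
import jax
import jax.numpy as jnp

# Assumed setup, mirroring the notebook: classic parameters and dt = 0.01
SIGMA, RHO, BETA, DT = 10.0, 28.0, 8.0 / 3.0, 0.01


def rhs(u):
    # Lorenz vector field du/dt for u = (x, y, z)
    x, y, z = u
    return jnp.array([SIGMA * (y - x), x * (RHO - z) - y, x * y - BETA * z])


def rk4_step(u):
    # One classic Runge-Kutta 4 step, playing the role of the stepper P
    k1 = rhs(u)
    k2 = rhs(u + 0.5 * DT * k1)
    k3 = rhs(u + 0.5 * DT * k2)
    k4 = rhs(u + DT * k3)
    return u + DT * (k1 + 2 * k2 + 2 * k3 + k4) / 6


def largest_lyapunov_exponent(u0, v0, n_warmup=5_000, n_steps=50_000):
    # Warm up so that the state lies on the attractor
    u = jax.lax.fori_loop(0, n_warmup, lambda _, u: rk4_step(u), u0)

    def scan_fn(carry, _):
        u, v = carry
        # Advance the state and push the tangent vector v through the JVP of P
        u_next, v_next = jax.jvp(rk4_step, (u,), (v,))
        growth = jnp.linalg.norm(v_next)
        # Renormalize to avoid overflow and keep only the log growth factor
        return (u_next, v_next / growth), jnp.log(growth)

    v0 = v0 / jnp.linalg.norm(v0)
    _, log_growth = jax.lax.scan(scan_fn, (u, v0), None, length=n_steps)
    # Average log growth per step, converted to a rate per unit time
    return jnp.mean(log_growth) / DT


lam_max = largest_lyapunov_exponent(
    jnp.array([1.0, 1.0, 1.0]), jnp.array([1.0, 0.0, 0.0])
)
print(lam_max)  # expected to lie close to the first entry of the full spectrum
```

Because the first column of $Q$ in a QR decomposition is just the normalized first column of the input, this single-vector estimate tracks the same quantity as the first exponent of the full spectrum, only without carrying the other two directions along.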

The entire spectrum $\{\lambda_i\}$ can be used, e.g., to compute the
Kaplan-Yorke dimension or to analyze other properties of the system.
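With the exponents sorted as $\lambda_1 \geq \lambda_2 \geq \dots \geq \lambda_N$, the Kaplan-Yorke dimension is $D_{KY} = k + \sum_{i=1}^{k} \lambda_i / |\lambda_{k+1}|$, where $k$ is the largest index for which the partial sum $\sum_{i=1}^{k} \lambda_i$ is still non-negative. A small sketch (the function name `kaplan_yorke_dimension` is hypothetical), applied to the rounded spectrum obtained at the end of this notebook:

```python
import jax.numpy as jnp


def kaplan_yorke_dimension(spectrum):
    # Sort the exponents in descending order
    lam = jnp.flip(jnp.sort(jnp.asarray(spectrum, dtype=jnp.float32)))
    partial_sums = jnp.cumsum(lam)
    # k = number of leading exponents whose partial sum is still non-negative
    k = int(jnp.sum(partial_sums >= 0))
    if k == 0:
        return 0.0
    if k == lam.shape[0]:
        # All partial sums non-negative: the dimension saturates at N
        return float(k)
    return k + float(partial_sums[k - 1]) / abs(float(lam[k]))


d_ky = kaplan_yorke_dimension([0.904, -0.003, -14.567])
print(d_ky)  # 2 + 0.901 / 14.567, i.e. roughly 2.06
```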

In this notebook, we will approximate the spectrum for the Lorenz system under
the original configuration $\sigma = 10$, $\rho = 28$ and $\beta = 8/3$ (1)
using a [Runge-Kutta 4
simulator](https://github.com/Ceyron/machine-learning-and-simulation/blob/main/english/simulation_scripts/lorenz_simulator_numpy.ipynb)
with time step size $\Delta t = 0.01$. Let $\mathcal{P}$ denote the discrete time
stepper that advances the state from one time level $u^{[t]}$ to the next,
$u^{[t+1]} = \mathcal{P}(u^{[t]})$.

Then, we can approximate the spectrum $\{\lambda_0, \lambda_1, \lambda_2\}$ by the following strategy:

1. Draw a reasonable initial condition, e.g., $u^{[0]} = (1, 1, 1)$.
2. Evolve the initial condition until it enters the chaotic attractor, e.g., by
   using $5000$ time steps to get $u^{[5000]}$. Use the last state $u^{[5000]}$
   as the "warmed-up" initial state $u^{[0]} \leftarrow u^{[5000]}$.
3. Introduce a perturbation **matrix** $Y^{[0]} \in \mathbb{R}^{3 \times
   3}$:
   1. For example, draw it randomly from a normal distribution, $Y^{[0]}_{ij}
      \sim \mathcal{N}(0, 1)$.
   2. Orthonormalize it using a **QR decomposition**:
      1. $Q, R = \text{QR}(Y^{[0]})$
      2. $Y^{[0]} \leftarrow Q$
4. Evolve $u^{[t]}$ via the Runge-Kutta 4 stepper $u^{[t+1]} =
   \mathcal{P}(u^{[t]})$. At the same time, at each time step $t$:
   1. Compute the Jacobian of the time stepper evaluated at the current state,
      $J_\mathcal{P}(u^{[t]})$.
   2. Evolve the perturbation $Y^{[t]}$ via $Y^{[t+1]} =
      J_\mathcal{P}(u^{[t]}) Y^{[t]}$. (*)
   3. Re-orthonormalize the perturbation matrix $Y^{[t+1]}$ using a QR
      decomposition and record the diagonal of $R$:
      1. $Q, R = \text{QR}(Y^{[t+1]})$
      2. $Y^{[t+1]} \leftarrow Q$
      3. $\epsilon^{[t+1]} = \text{diag}(R)$
5. Repeat this for $T$ time steps, e.g., $T = 50000$, and record the
   growth factors $\epsilon^{[t+1]} \in \mathbb{R}^{3}$.
6. Approximate the Lyapunov spectrum via

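$$
\lambda_i = \frac{1}{\Delta t}\frac{1}{T} \sum_{t=0}^{T-1} \log |\epsilon^{[t+1]}_i|.
$$

The absolute value is needed because the QR decomposition determines the
diagonal of $R$ only up to sign.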
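Before applying the procedure to the chaotic Lorenz system, it can be sanity-checked on a case with a known answer. For a linear map $u^{[t+1]} = M u^{[t]}$ the Jacobian is the constant matrix $M$, and the exponents (with $\Delta t = 1$) must equal $\log|\mu_i|$ for the eigenvalues $\mu_i$ of $M$. The upper-triangular test matrix `M` below, with eigenvalues $2$, $1$ and $0.5$, and the number of steps are assumptions chosen purely for illustration.

```python
import jax
import jax.numpy as jnp

# Assumed linear test map u^{[t+1]} = M u^{[t]}; upper triangular, so its
# eigenvalues 2, 1 and 0.5 can be read off the diagonal
M = jnp.array([[2.0, 1.0, 0.0],
               [0.0, 1.0, 1.0],
               [0.0, 0.0, 0.5]])


def qr_step(Y, _):
    # Push the orthonormal perturbation basis through the constant Jacobian M,
    # then re-orthonormalize; diag(R) holds the growth factors epsilon^{[t+1]}
    Q, R = jnp.linalg.qr(M @ Y)
    return Q, jnp.diag(R)


# Random initial perturbation matrix, orthonormalized with a QR decomposition
Y0 = jax.random.normal(jax.random.key(0), (3, 3))
Q0, _ = jnp.linalg.qr(Y0)

T = 2_000
_, eps = jax.lax.scan(qr_step, Q0, None, length=T)

# Delta t = 1 for a map, so lambda_i is just the mean of log|epsilon_i|
exponents = jnp.mean(jnp.log(jnp.abs(eps)), axis=0)
print(exponents)                              # estimated exponents
print(jnp.log(jnp.array([2.0, 1.0, 0.5])))    # reference values log|mu_i|
```

The first few steps are dominated by the transient alignment of $Q$ with the dominant directions; averaging over $T$ steps lets this error decay like $1/T$.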

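(*) Instead of instantiating the full (and oftentimes dense) Jacobian matrix
$J_\mathcal{P}(u^{[t]})$ at each time step, we can use `jax.linearize` to obtain
the Jacobian-vector product (JVP) of $\mathcal{P}$ at $u^{[t]}$ and apply it to
every column of $Y^{[t]}$ with `jax.vmap`.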
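A quick check that the matrix-free route gives the same result as forming the dense Jacobian with `jax.jacfwd` and multiplying. The nonlinear map `toy_step` is a hypothetical stand-in for $\mathcal{P}$, chosen only to keep the example short.

```python
import jax
import jax.numpy as jnp


def toy_step(u):
    # Hypothetical nonlinear map R^3 -> R^3, standing in for the stepper P
    x, y, z = u
    return jnp.array([jnp.sin(x) + y * z, x * x - z, jnp.exp(-y) + x])


u = jnp.array([0.3, -1.2, 0.7])
Y = jax.random.normal(jax.random.key(1), (3, 3))

# Dense route: build the full 3x3 Jacobian, then multiply
J = jax.jacfwd(toy_step)(u)
Y_dense = J @ Y

# Matrix-free route: linearize once, then apply the JVP to every column of Y
u_next, jvp_fn = jax.linearize(toy_step, u)
Y_matfree = jax.vmap(jvp_fn, in_axes=-1, out_axes=-1)(Y)

print(jnp.max(jnp.abs(Y_dense - Y_matfree)))  # agrees up to float32 round-off
```

For the three-dimensional Lorenz system both routes are cheap; the matrix-free variant pays off mainly for high-dimensional states where only the leading $k \ll N$ exponents are wanted, so that $Y^{[t]}$ has just $k$ columns and the $N \times N$ Jacobian is never stored.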

---

(1) E. N. Lorenz, "Deterministic Nonperiodic Flow", Journal of the Atmospheric
Sciences, 1963,
https://journals.ametsoc.org/view/journals/atsc/20/2/1520-0469_1963_020_0130_dnf_2_0_co_2.xml

In [1]:
import jax
import jax.numpy as jnp

In [2]:
def lorenz_rhs(u, *, sigma, rho, beta):
    """Right-hand side du/dt of the Lorenz system for the state u = (x, y, z)."""
    x, y, z = u
    x_dot = sigma * (y - x)
    y_dot = x * (rho - z) - y
    z_dot = x * y - beta * z
    u_dot = jnp.array([x_dot, y_dot, z_dot])
    return u_dot

In [3]:
class LorenzStepperRK4:
    """Classic Runge-Kutta 4 time stepper P for the Lorenz system: u^[t] -> u^[t+1]."""

    def __init__(self, dt=0.01, *, sigma=10, rho=28, beta=8/3):
        self.dt = dt
        self.sigma = sigma
        self.rho = rho
        self.beta = beta
    
    def __call__(self, u_prev):
        lorenz_rhs_fixed = lambda u: lorenz_rhs(
            u,
            sigma=self.sigma,
            rho=self.rho,
            beta=self.beta,
        )
        k_1 = lorenz_rhs_fixed(u_prev)
        k_2 = lorenz_rhs_fixed(u_prev + 0.5 * self.dt * k_1)
        k_3 = lorenz_rhs_fixed(u_prev + 0.5 * self.dt * k_2)
        k_4 = lorenz_rhs_fixed(u_prev + self.dt * k_3)
        u_next = u_prev + self.dt * (k_1 + 2*k_2 + 2*k_3 + k_4)/6
        return u_next

In [4]:
lorenz_stepper = LorenzStepperRK4()

In [5]:
u_0 = jnp.array([1.0, 1.0, 1.0])

In [6]:
lorenz_stepper(u_0)

Array([1.0125672, 1.2599177, 0.984891 ], dtype=float32)

In [7]:
def rollout(stepper, n, *, include_init: bool = False):
    """Return a function that applies `stepper` n times and stacks the states."""
    def scan_fn(u, _):
        u_next = stepper(u)
        return u_next, u_next

    def rollout_fn(u_0):
        _, trj = jax.lax.scan(scan_fn, u_0, None, length=n)

        if include_init:
            return jnp.concatenate([jnp.expand_dims(u_0, axis=0), trj], axis=0)

        return trj

    return rollout_fn

In [8]:
trj = rollout(lorenz_stepper, 5000, include_init=True)(u_0)

In [9]:
u_warmed = trj[-1]

In [19]:
def push_orthonormal_matrix_variation(stepper, u_0, Y_0, n):
    def scan_fn(carry, _):
        u, Y = carry

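        # Option 1: instantiate the full (dense) Jacobian of the stepper
        # u_next = stepper(u)
        # jac = jax.jacfwd(stepper)(u)
        # Y_next = jac @ Y

        # Option 2 (more efficient): linearize the stepper once and push every
        # column of Y through the resulting Jacobian-vector product
        u_next, jvp_fn = jax.linearize(stepper, u)
        Y_next = jax.vmap(jvp_fn, in_axes=-1, out_axes=-1)(Y)

        # Re-orthonormalize; the diagonal of R holds the growth factors
        # (its sign is arbitrary, hence the absolute value in the final average)
        Q, R = jnp.linalg.qr(Y_next)
        Y_next = Q
        growth = jnp.diag(R)

        carry_next = (u_next, Y_next)

        return carry_next, growth

    # Orthonormalize the initial perturbation matrix
    Q, _ = jnp.linalg.qr(Y_0)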

    initial_carry = (u_0, Q)

    _, growth_trj = jax.lax.scan(
        scan_fn,
        initial_carry,
        None,
        length=n,
    )

    return growth_trj

In [20]:
Y_0 = jax.random.normal(jax.random.key(0), (3, 3))
Y_0

Array([[ 1.6226422 ,  2.0252647 , -0.43359444],
       [-0.07861735,  0.1760909 , -0.97208923],
       [-0.49529874,  0.4943786 ,  0.6643493 ]], dtype=float32)

In [21]:
growth_trj = push_orthonormal_matrix_variation(
    lorenz_stepper,
    u_warmed,
    Y_0,
    50_000
)

In [22]:
growth_trj.shape

(50000, 3)

In [23]:
unscaled_lyapunov_spectrum = jnp.mean(
    jnp.log(jnp.abs(growth_trj)),
    axis=0,
)
unscaled_lyapunov_spectrum.shape

(3,)

In [24]:
unscaled_lyapunov_spectrum

Array([ 9.039051e-03, -3.169071e-05, -1.456730e-01], dtype=float32)

In [25]:
lyapunov_spectrum = unscaled_lyapunov_spectrum / lorenz_stepper.dt

In [26]:
lyapunov_spectrum.round(3)

Array([ 9.0400004e-01, -3.0000000e-03, -1.4567000e+01], dtype=float32)
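A useful consistency check: the sum of all Lyapunov exponents equals the time average of the divergence of the vector field. For the Lorenz system the divergence is the constant $-(\sigma + 1 + \beta) \approx -13.667$, so the three exponents should add up to approximately this value (up to the discretization error of the RK4 stepper). The sketch below computes the divergence as the trace of the Jacobian with `jax.jacfwd` and compares it to the rounded spectrum printed above; the name `lorenz_vector_field` is chosen only to avoid shadowing the notebook's `lorenz_rhs`.

```python
import jax
import jax.numpy as jnp

sigma, rho, beta = 10.0, 28.0, 8.0 / 3.0


def lorenz_vector_field(u):
    # Same right-hand side as lorenz_rhs, with the classic parameters baked in
    x, y, z = u
    return jnp.array([sigma * (y - x), x * (rho - z) - y, x * y - beta * z])


# The divergence (trace of the Jacobian) of the Lorenz vector field does not
# depend on the state, so any point works
u_any = jnp.array([1.0, 2.0, 3.0])
divergence = jnp.trace(jax.jacfwd(lorenz_vector_field)(u_any))

# Rounded spectrum reported in the last cell of this notebook
spectrum = jnp.array([0.904, -0.003, -14.567])

print(divergence)         # -(sigma + 1 + beta), roughly -13.667
print(jnp.sum(spectrum))  # roughly -13.666
```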