In [4]:
# Plotting setup for the rest of the notebook: inline figures rendered as SVG,
# with seaborn's "paper" context and "ticks" style applied globally.
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
# SVG (vector) output keeps figures sharp when zoomed or exported.
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
# The trailing ';' suppresses the echo of the call's return value in the cell output.
sns.set_style("ticks");

# Global Sensitivity Analysis

We will again work with the Duffing oscillator. As a reminder, it is given by the following equation:

$$
\ddot{x} + \delta \dot{x} + \alpha x + \beta x^3 = \gamma \cos(\omega t)
$$

Written as a first-order system (i.e. as a vector field), it reads:

$$
\begin{aligned}
\dot{x} &= v \\
\dot{v} &= -\delta v - \alpha x - \beta x^3 + \gamma \cos(\omega t)
\end{aligned}
$$

We set up and solve it using the same code as before:

In [7]:
import numpy as np
import jax.numpy as jnp
from diffrax import diffeqsolve, Tsit5, ODETerm, SaveAt
from SALib.sample import sobol_sequence
from SALib.analyze import sobol

def vector_field(t, y, theta):
    """Duffing oscillator as a first-order system, in diffrax's ``f(t, y, args)`` form.

    Args:
        t: Current time.
        y: State vector; ``y[0]`` is the position x and ``y[1]`` the velocity v.
        theta: Parameter vector; its first five entries are read as
            (alpha, beta, gamma, delta, omega). Any further entries are ignored.

    Returns:
        ``jnp.array([dx/dt, dv/dt])`` evaluated at ``(t, y)``.
    """
    alpha, beta, gamma, delta, omega = theta[:5]
    position, velocity = y[0], y[1]
    # Same term order as the ODE: restoring forces, damping, then periodic forcing.
    acceleration = (
        - alpha * position
        - beta * position ** 3
        - delta * velocity
        + gamma * jnp.cos(omega * t)
    )
    return jnp.array([velocity, acceleration])

# Nominal Duffing parameters, in the order `vector_field` unpacks them.
theta = jnp.array([
    1.0,  # alpha
    5.0,  # beta
    0.37, # gamma
    0.1,  # delta
    1.0,  # omega
])

# Integration window, output grid and initial state. The save times are built
# from the same T0/T1 passed to the solver, so the two cannot drift apart
# (SaveAt times outside [t0, t1] are rejected by diffrax).
T0 = 0.0                      # Initial time
T1 = 50.0                     # Terminal time
N_SAVE = 2000                 # Number of stored time points
Y0 = jnp.array([0.0, 0.0])    # Initial state (x, v)

# The numerical solver to use.
solver = Tsit5()
# At which timesteps to store the solution.
saveat = SaveAt(ts=jnp.linspace(T0, T1, N_SAVE))
# The differential equation term.
term = ODETerm(vector_field)
# The solution for one theta.
sol = diffeqsolve(
    term,
    solver,
    t0=T0,
    t1=T1,
    # No stepsize_controller is passed, so diffrax's default (ConstantStepSize)
    # applies: dt0 is used as a fixed step and is NOT adapted.
    dt0=0.1,
    y0=Y0,
    args=theta,
    saveat=saveat,
)

