In [17]:
from time import time
import jax
import numpy as np
import pandas as pd
import seaborn as sns
import jax.numpy as jnp
import matplotlib.pyplot as plt
from functools import partial
from datetime import datetime, timedelta
import ensemble_kalman_filter as enkf
from tqdm import tqdm
from bayes_opt import BayesianOptimization
from rebayes_mini.methods import robust_filter as rkf
from rebayes_mini import callbacks
import os
# Short module-level aliases for the JAX trigonometric functions.
sin, cos = jnp.sin, jnp.cos

In [18]:
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
plt.style.use("default")
plt.rcParams["font.size"] = 18

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [19]:
# --- Physical parameters of the double pendulum ---------------------------
G = 9.8          # gravitational acceleration (m/s^2)
L1 = L2 = 1.0    # rod lengths of pendulum 1 and pendulum 2 (m)
M1 = M2 = 1.0    # bob masses of pendulum 1 and pendulum 2 (kg)

# --- Experiment configuration ---------------------------------------------
M, N = 100, 200            # Monte Carlo samples, time steps per sample
dt = 0.01                  # time step size
dim_state, dim_obs = 4, 2  # state is [θ1, ω1, θ2, ω2]; observation is [θ1, θ2]
model_name = 'pendulum'
save_flag = False   # True: generate and save a dataset; False: load an existing one
corrupted = False   # True: evaluate on the corrupted observations
f_inacc = False     # forwarded to generate_and_save_data (meaning defined there)
save_dir = '../dataset/pendulum/'  # where datasets are written / read

# Ensure the directory exists. exist_ok=True makes this one idempotent call
# and avoids the check-then-create race of testing os.path.exists first.
os.makedirs(save_dir, exist_ok=True)

# Reproducible randomness: a single root key is split into one stream per use.
key = jax.random.PRNGKey(31415)
key_init, key_sim, key_eval, key_obs = jax.random.split(key, num=4)
# The simulation stream is split again: process noise vs. measurement noise.
key_state, key_measurement = jax.random.split(key_sim)

# Gaussian prior over the initial state [θ1, ω1, θ2, ω2].
# The nominal values are written in degrees and converted to radians.
X0 = jnp.radians(jnp.array([130, 0, -10, 0]))
mu0 = X0
cov0 = jnp.eye(4) * 0.01**2   # isotropic covariance, standard deviation 0.01
# One draw from that prior.
x0 = jax.random.multivariate_normal(key_init, mean=mu0, cov=cov0)

# Dynamics function for the double pendulum
def fcoord(x, t):
    """
    Deterministic double-pendulum vector field.

    Parameters:
        x: State vector [θ1, ω1, θ2, ω2].
        t: Time. Accepted for interface uniformity; the dynamics do not
           depend on it.

    Returns:
        Array [ω1, dω1/dt, ω2, dω2/dt], i.e. the time derivative of x.
    """
    th1, w1, th2, w2 = x

    # Trigonometric terms shared by both angular accelerations.
    delta = th2 - th1
    sin_d, cos_d = jnp.sin(delta), jnp.cos(delta)
    sin_1, sin_2 = jnp.sin(th1), jnp.sin(th2)

    # Angular acceleration of the first pendulum.
    den1 = (M1 + M2)*L1 - M2*L1*cos_d*cos_d
    acc1 = (M2*L1*w1*w1*sin_d*cos_d
            + M2*G*sin_2*cos_d
            + M2*L2*w2*w2*sin_d
            - (M1 + M2)*G*sin_1) / den1

    # Angular acceleration of the second pendulum.
    den2 = (L2 / L1) * den1
    acc2 = (-M2*L2*w2*w2*sin_d*cos_d
            + (M1 + M2)*G*sin_1*cos_d
            - (M1 + M2)*L1*w1*w1*sin_d
            - (M1 + M2)*G*sin_2) / den2

    # Assemble in state order; the angle derivatives are the angular velocities.
    return jnp.stack([w1, acc1, w2, acc2]).astype(jnp.result_type(x))

# State dynamics with noise
def f(x, t, D, *args):
    """
    Noisy state dynamics: deterministic derivative plus Gaussian noise.

    Parameters:
        x: State vector [θ1, ω1, θ2, ω2].
        t: Time at which the dynamics are evaluated; also selects the noise
           draw, so the same t always yields the same noise.
        D: Length of the noise vector; must be broadcast-compatible with
           the output of fcoord (the state dimension).
        *args: Ignored; accepted for interface flexibility.

    Returns:
        xdot: fcoord(x, t) plus a standard-normal noise vector of shape (D,).
    """
    # jax.random.fold_in expects 32-bit integer data and casts its argument
    # to uint32. Passing a fractional time directly would therefore collapse
    # every t in [0, 1) onto the same key and repeat the same noise draw.
    # Folding in the float32 bit pattern of t gives a distinct key for each
    # distinct time value.
    # NOTE: this changes the noise realisations compared with the truncating
    # behaviour; regenerate saved datasets if exact reproducibility with
    # older runs matters.
    t_bits = jax.lax.bitcast_convert_type(
        jnp.asarray(t, dtype=jnp.float32), jnp.uint32
    )
    keyt = jax.random.fold_in(key_state, t_bits)
    err = jax.random.normal(keyt, shape=(D,))  # Gaussian noise
    xdot = fcoord(x, t) + err  # Add noise to the state dynamics
    return xdot

fC = f  # For compatibility, fC is the same as f here

# Observation matrix: picks the two angles out of [θ1, ω1, θ2, ω2].
H = jnp.array([
    [1, 0, 0, 0],  # row 0 -> θ1
    [0, 0, 1, 0],  # row 1 -> θ2
])

def hcoord(x, t):
    """
    Noise-free linear observation of the state.

    Parameters:
        x: State vector [θ1, ω1, θ2, ω2].
        t: Time. Accepted for interface uniformity; not used.

    Returns:
        The product of the observation matrix H with x, i.e. [θ1, θ2].
    """
    return jnp.matmul(H, x)

# Observation function with noise
def h(x, t, D, *args):
    """
    Noisy observation: linear measurement plus Gaussian noise.

    Parameters:
        x: State vector [θ1, ω1, θ2, ω2].
        t: Time of the observation; also selects the noise draw, so the
           same t always yields the same noise.
        D: Length of the noise vector; must be broadcast-compatible with
           the output of hcoord (the observation dimension).
        *args: Ignored; accepted for interface flexibility.

    Returns:
        ydot: hcoord(x, t) plus a standard-normal noise vector of shape (D,).
    """
    # jax.random.fold_in expects 32-bit integer data and casts its argument
    # to uint32. Passing a fractional time directly would therefore collapse
    # every t in [0, 1) onto the same key and repeat the same noise draw.
    # Folding in the float32 bit pattern of t gives a distinct key for each
    # distinct time value.
    # NOTE: this changes the noise realisations compared with the truncating
    # behaviour; regenerate saved datasets if exact reproducibility with
    # older runs matters.
    t_bits = jax.lax.bitcast_convert_type(
        jnp.asarray(t, dtype=jnp.float32), jnp.uint32
    )
    keyt = jax.random.fold_in(key_measurement, t_bits)
    err = jax.random.normal(keyt, shape=(D,))  # Gaussian noise
    ydot = hcoord(x, t) + err  # Add noise to the observations
    return ydot

# Step function to compute observations over time
def h_step(ys, dt, N, h):
    """
    Builds an (N, dim_obs) array of observations from a state sequence.

    Parameters:
        ys: Indexable sequence of states; entry i - 1 is observed at step i.
        dt: Time step size; step i is observed at time dt * i.
        N: Number of rows in the output.
        h: Observation function called as h(state, time). It is invoked
           with exactly two arguments, so any extra parameters (such as a
           noise dimension) must already be bound by the caller.

    Returns:
        Array of shape (N, dim_obs). Rows 1..N-1 hold the observations.
        Row 0 is never written and stays zero.
        NOTE(review): confirm the zero first row and the i - 1 state offset
        are what downstream consumers expect.
    """
    def write_row(k, buf):
        # Row k comes from state k - 1, evaluated at time dt * k.
        return buf.at[k].set(h(ys[k - 1], dt * k))

    return jax.lax.fori_loop(1, N, write_row, jnp.zeros((N, dim_obs)))

# Helper functions
def ff(x):
    """
    Single-argument form of the deterministic dynamics, as passed to the filters.

    Evaluates fcoord with the global step size dt in the (unused) time slot.
    Returns the time derivative of x, not an integrated next state.
    NOTE(review): confirm the consumer of ff performs the time integration.
    """
    return fcoord(x, t=dt)

def hh(x, t):
    """Noise-free linear observation H @ x; t is accepted but not used."""
    return hcoord(x, t)


In [22]:
from data_generation import generate_and_save_data, load_data
from filtering_methods import run_filtering
# Obtain states, observations and corrupted observations: read a previously
# saved dataset unless save_flag asks for a fresh one to be generated.
if not save_flag:
    statev, yv, yv_corrupted = load_data(
        save_dir, model_name, M, N, dim_state, dim_obs, corrupted
    )
else:
    statev, yv, yv_corrupted = generate_and_save_data(
        M, N, dt, dim_state, dim_obs, mu0, cov0, key_init, key_obs,
        f, fC, h, h_step, save_dir, model_name, save_flag, corrupted, f_inacc
    )

In [25]:
# Echo the experiment configuration next to the results.
print('System', model_name)
print('corrupted', corrupted)
print('f_inacc', f_inacc)

# Filters to benchmark. Alternative (EnKF-family) selection:
#   ['EKF', 'EnKF', 'EnKFI', 'HubEnKF', 'EnKFS', 'PF']
methods = ['EKF', 'WLFIMQ', 'WLFSQ', 'WLFMD', 'PF']

# Evaluate on the last 20% of the Monte Carlo samples.
test_start = int(M * 0.8)
yv_test = (yv_corrupted if corrupted else yv)[test_start:]
statev_test = statev[test_start:]
parameter_range = (1, 5)  # not used in this cell; presumably consumed elsewhere (TODO confirm)

# NOTE(review): the recorded output of this cell reports RMSE = nan for every
# method. Check the loaded data and the filter inputs before relying on it.
_, _, _ = run_filtering(methods, yv_test, statev_test, key_eval, ff, hh, num_particle=1000)

System pendulum
corrupted False
f_inacc False
EKF


  0%|          | 0/20 [00:00<?, ?it/s]

100%|██████████| 20/20 [00:13<00:00,  1.48it/s]


WLFIMQ


100%|██████████| 20/20 [00:23<00:00,  1.16s/it]


WLFSQ


100%|██████████| 20/20 [00:14<00:00,  1.40it/s]


WLFMD


100%|██████████| 20/20 [00:13<00:00,  1.49it/s]


PF


100%|██████████| 20/20 [01:40<00:00,  5.02s/it]


Done
RMSE
{'EKF': Array(nan, dtype=float32),
 'PF': Array(nan, dtype=float32),
 'WLFIMQ': Array(nan, dtype=float32),
 'WLFMD': Array(nan, dtype=float32),
 'WLFSQ': Array(nan, dtype=float32)}
Time
{'EKF': np.float64(3.389719605445862),
 'PF': np.float64(25.11091560125351),
 'WLFIMQ': np.float64(5.805581152439117),
 'WLFMD': np.float64(3.3533756732940674),
 'WLFSQ': np.float64(3.566981673240662)}
