In [6]:
from time import time
import jax
import numpy as np
import pandas as pd
import seaborn as sns
import jax.numpy as jnp
import matplotlib.pyplot as plt
from functools import partial
from datetime import datetime, timedelta
import ensemble_kalman_filter as enkf
from tqdm import tqdm
from bayes_opt import BayesianOptimization
from rebayes_mini.methods import robust_filter as rkf
from rebayes_mini import callbacks
import os
# Short aliases for JAX's elementwise trig functions (not referenced in the cells shown here).
sin, cos = jnp.sin, jnp.cos

In [7]:
# Re-import edited modules (e.g. ensemble_kalman_filter, data_generation) before each cell runs.
# Re-running this cell in the same kernel prints an "already loaded" notice, which is harmless.
%load_ext autoreload
%autoreload 2
# Render inline figures at high DPI and use a uniform matplotlib style and font size.
%config InlineBackend.figure_format = "retina"
plt.style.use("default")
plt.rcParams["font.size"] = 18

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
# Simulation Parameters
M = 100  # Number of Monte Carlo samples (trajectories)
N = 200  # Number of time steps
dt = 0.01  # Time step size
dim_state = 2  # Dimensionality of the state variables
dim_obs = 2  # Dimensionality of the observation variables
mu = 0.5  # Model parameter (not used in this specific oscillator)
omega = 1.0  # Oscillator frequency (not used directly here)
A = 1.0  # Oscillator amplitude factor (not used directly here)
s_var = 1.0  # Variance for state noise
# NOTE: despite the `_var` suffix, o_var holds a standard deviation (the sqrt of a unit variance).
# Neither s_var nor o_var is referenced by the cells shown here -- confirm how they are consumed.
o_var = np.sqrt(1.0)
model_name = 'oscillator'  # Model name, also used to build save_dir
save_flag = False  # True: generate and save the dataset; False: load it from save_dir
save_dir = '../dataset/' + model_name + '/'  # Directory holding the dataset files
corrupted = False  # Evaluate on the corrupted observation set instead of the clean one
f_inacc = False  # Flag for inaccurate dynamics (forwarded to generate_and_save_data)
ALPHA = 2.0  # Half-width of the uniform coupling perturbation drawn in fC
BETA = 2.0  # Half-width of the uniform bias drawn in fC

# Create the save directory if needed. exist_ok=True replaces the racy
# "check then create" pair and keeps this cell safe to re-run.
os.makedirs(save_dir, exist_ok=True)

# Random Key Initialization
key = jax.random.PRNGKey(31415)
# key_init -> initial-state draw below; key_sim -> (key_state, key_measurement), read by f/fC and h;
# key_obs -> passed to generate_and_save_data; key_eval -> passed to run_filtering.
key_init, key_sim, key_eval, key_obs = jax.random.split(key, 4)
key_state, key_measurement = jax.random.split(key_sim)

# Initial Conditions
# NOTE(review): jnp.radians treats [0, -1] as degrees, giving roughly [0, -0.0175];
# confirm the oscillator state is really meant to be an angle.
X0 = jnp.radians(jnp.array([0.0, -1.0]))
mu0, cov0 = X0, jnp.eye(dim_state) * 0.01**2  # Prior mean and covariance (std 0.01 per component)
x0 = jax.random.multivariate_normal(key_init, mean=mu0, cov=cov0)

# Nonlinear Oscillator Dynamics
def fcoord(x, t):
    """
    Noise-free drift of the cubic oscillator.

    With p = x[..., 0]**3 and q = x[..., 1]**3 the drift is
        d/dt x0 = -0.1 * p + 2 * q
        d/dt x1 = -2 * p - 0.1 * q

    - x: State; the last axis holds the two components.
    - t: Time (ignored; kept for a uniform signature).

    The two components are stacked along axis 0. A 1-D state therefore yields a
    (2,) result, and a batched (B, 2) state yields (2, B), not (B, 2).
    """
    p = x[..., 0] ** 3
    q = x[..., 1] ** 3
    return jnp.stack([-0.1 * p + 2 * q, -2 * p - 0.1 * q], axis=0)

# Accurate State Transition Function
def f(x, t, D, kain_var=1, *args):
    """
    Accurate state transition function with additive Gaussian noise.

    Returns ``fcoord(x, t) + kain_var * eps``, where ``eps ~ N(0, I_D)``. The noise
    is drawn from a key derived deterministically from the module-level
    ``key_state`` and ``t``.

    - x: Current state
    - t: Integer step index or float time; used to derive the noise key
    - D: Dimensionality of the noise vector
    - kain_var: Multiplier on the standard-normal noise (so it acts as a standard
      deviation, not a variance)
    - *args: Accepted and ignored
    """
    t_arr = jnp.asarray(t)
    if jnp.issubdtype(t_arr.dtype, jnp.integer):
        fold_data = t_arr
    else:
        # jax.random.fold_in casts its data to uint32, which truncates float times.
        # With t = dt * i and dt = 0.01, every step before t = 1 would fold in 0 and
        # reuse the same noise. The float32 bit pattern keeps distinct times distinct.
        fold_data = jax.lax.bitcast_convert_type(t_arr.astype(jnp.float32), jnp.uint32)
    keyt = jax.random.fold_in(key_state, fold_data)
    err = jax.random.normal(keyt, shape=(D,))
    xdot = fcoord(x, t) + kain_var * err
    return xdot

# Inaccurate State Transition Function
def fC(x, t, D, kain_var=1, *args):
    """
    Inaccurate (model-mismatched) state transition function.

    It first computes ``fcoord(x, t) + kain_var * eps`` with ``eps ~ N(0, I_D)``.
    It then perturbs only the first component:
        dx[0] <- dx[0] - a0 * x[1] + b0
    where ``a0 ~ U(-ALPHA, ALPHA)`` and ``b0 ~ U(-BETA, BETA)`` are drawn from
    keys derived from the module-level ``key_state`` and ``t``.

    - x: Current state (indexed as x[0], x[1]; presumably a length-2 vector)
    - t: Integer step index or float time; used to derive the random keys
    - D: Dimensionality of the Gaussian noise vector
    - kain_var: Multiplier on the standard-normal noise (acts as a standard deviation)
    - *args: Accepted and ignored
    """
    t_arr = jnp.asarray(t)
    if jnp.issubdtype(t_arr.dtype, jnp.integer):
        fold_data = t_arr
    else:
        # fold_in casts its data to uint32, which truncates float times
        # (t = 0.01, 0.02, ... would all become 0). The float32 bit pattern
        # keeps distinct times distinct.
        fold_data = jax.lax.bitcast_convert_type(t_arr.astype(jnp.float32), jnp.uint32)
    keyt = jax.random.fold_in(key_state, fold_data)
    err = jax.random.normal(keyt, shape=(D,))
    # NOTE(review): keyt has already been consumed by `normal` and is split again here.
    # JAX advises against reusing keys; left unchanged to preserve the random stream.
    key_a, key_b = jax.random.split(keyt)
    dx_mod = fcoord(x, t) + kain_var * err
    # Drawn with x's shape, but only the first entry of each is used below.
    a = jax.random.uniform(key_a, x.shape, minval=-ALPHA, maxval=ALPHA)
    b = jax.random.uniform(key_b, x.shape, minval=-BETA, maxval=BETA)

    # Perturb the first component only: a random coupling to x[1] plus a random bias.
    dx_mod = dx_mod.at[0].set(dx_mod[0] - a[0] * x[1] + b[0])
    return dx_mod

# Observation Matrix (Identity Mapping): the 2x2 integer identity, so both
# state components are observed directly.
H = jnp.eye(2, dtype=int)

# Linear Observation Function
def hcoord(x, t):
    """
    Noise-free linear observation of a state.

    - x: State vector
    - t: Time (ignored; kept for a uniform signature)

    Returns the matrix product ``H @ x``.
    """
    observed = jnp.matmul(H, x)
    return observed

# Observation Function with Noise
def h(x, t, D, kain_var=1, *args):
    """
    Noisy observation: ``hcoord(x, t) + kain_var * eps`` with ``eps ~ N(0, I_D)``.

    The noise is drawn from a key derived deterministically from the module-level
    ``key_measurement`` and ``t``.

    - x: State vector
    - t: Integer step index or float time; used to derive the noise key
    - D: Dimensionality of the noise vector
    - kain_var: Multiplier on the standard-normal noise (acts as a standard deviation)
    - *args: Accepted and ignored
    """
    t_arr = jnp.asarray(t)
    if jnp.issubdtype(t_arr.dtype, jnp.integer):
        fold_data = t_arr
    else:
        # h_step calls this with t = dt * i (a float). fold_in casts its data to uint32,
        # so every t < 1 would fold in 0 and all those observations would share a single
        # noise draw. The float32 bit pattern gives each distinct time its own key.
        fold_data = jax.lax.bitcast_convert_type(t_arr.astype(jnp.float32), jnp.uint32)
    keyt = jax.random.fold_in(key_measurement, fold_data)
    err = jax.random.normal(keyt, shape=(D,))
    y_obs = hcoord(x, t) + kain_var * err
    return y_obs

# Compute Observations Over Time
def h_step(ys, dt, N, h):
    """
    Build an observation sequence by applying ``h`` along a state trajectory.

    - ys: Sequence of states indexed by time step (these are the inputs fed to ``h``,
      not observations)
    - dt: Time step size; step i is observed at time ``dt * i``
    - N: Number of time steps (rows in the result)
    - h: Observation function called as ``h(state, t)``; presumably ``h`` with D and
      kain_var bound via ``partial`` -- confirm at the call site

    Returns an (N, dim_obs) array. For i >= 1, row i is ``h(ys[i - 1], dt * i)``.
    Row 0 is never written and stays zero.
    """
    def body(step_idx, carry):
        obs_seq, states = carry
        obs = h(states[step_idx - 1], dt * step_idx)
        return obs_seq.at[step_idx].set(obs), states

    init = (jnp.zeros((N, dim_obs)), ys)
    obs_seq, _ = jax.lax.fori_loop(1, N, body, init)
    return obs_seq

# Helper Functions
def ff(x):
    """
    Noise-free state-dynamics wrapper with the single-argument signature handed
    to ``run_filtering``.

    Returns ``fcoord(x, dt)``, i.e. the oscillator drift (state derivative). The
    module-level ``dt`` fills the time slot, which ``fcoord`` ignores.
    NOTE(review): this returns a derivative, not an integrated next state --
    confirm that run_filtering performs the time integration.
    """
    drift = fcoord(x, t=dt)
    return drift

def hh(x, t):
    """
    Noise-free observation wrapper with the ``(x, t)`` signature handed to
    ``run_filtering``.

    Returns the same value as ``hcoord(x, t)``, namely ``H @ x``.
    """
    return hcoord(x, t)


In [9]:
from data_generation import generate_and_save_data, load_data
from filtering_methods import run_filtering
# save_flag=True: build the dataset with generate_and_save_data (written under save_dir);
# otherwise read a previously stored copy back with load_data.
if not save_flag:
    statev, yv, yv_corrupted = load_data(
        save_dir, model_name, M, N, dim_state, dim_obs, corrupted
    )
else:
    statev, yv, yv_corrupted = generate_and_save_data(
        M, N, dt, dim_state, dim_obs, mu0, cov0, key_init, key_obs,
        f, fC, h, h_step, save_dir, model_name, save_flag, corrupted, f_inacc,
    )

In [10]:
print('System', model_name)
print('corrupted', corrupted)
print('f_inacc', f_inacc)

# Filters to benchmark (names are passed through to run_filtering).
methods = ['EKF', 'EnKF', 'EnKFI', 'HubEnKF', 'EnKFS', 'PF']

# The first 80% of the M Monte Carlo trajectories are skipped; the remainder is the test set.
# A single split index keeps the observation and ground-truth slices aligned.
TRAIN_FRACTION = 0.8
NUM_PARTICLES = 1000  # Forwarded as num_particle (presumably ensemble/particle count)
test_start = int(M * TRAIN_FRACTION)

yv_test = (yv_corrupted if corrupted else yv)[test_start:]
statev_test = statev[test_start:]
_, _, _ = run_filtering(methods, yv_test, statev_test, key_eval, ff, hh, num_particle=NUM_PARTICLES)

System oscillator
corrupted False
f_inacc False
EKF


  0%|          | 0/20 [00:00<?, ?it/s]

100%|██████████| 20/20 [00:04<00:00,  4.66it/s]


EnKF


100%|██████████| 20/20 [00:10<00:00,  1.87it/s]


EnKFI


100%|██████████| 20/20 [00:11<00:00,  1.76it/s]


HubEnKF


100%|██████████| 20/20 [00:10<00:00,  1.92it/s]


EnKFS


100%|██████████| 20/20 [00:10<00:00,  1.97it/s]


PF


100%|██████████| 20/20 [00:10<00:00,  1.90it/s]

Done
RMSE
{'EKF': Array(nan, dtype=float32),
 'EnKF': Array(0.5541, dtype=float32),
 'EnKFI': Array(0.5445, dtype=float32),
 'EnKFS': Array(0.5586, dtype=float32),
 'HubEnKF': Array(0.544, dtype=float32),
 'PF': Array(0.4185, dtype=float32)}
Time
{'EKF': np.float64(1.0725721120834348),
 'EnKF': np.float64(2.671838879585266),
 'EnKFI': np.float64(2.8441840410232544),
 'EnKFS': np.float64(2.544935822486877),
 'HubEnKF': np.float64(2.611142098903656),
 'PF': np.float64(2.6319159865379333)}



