In [86]:
from time import time
import jax
import numpy as np
import pandas as pd
import seaborn as sns
import jax.numpy as jnp
import matplotlib.pyplot as plt
from functools import partial
from datetime import datetime, timedelta
import ensemble_kalman_filter as enkf
from tqdm import tqdm
from bayes_opt import BayesianOptimization
from rebayes_mini.methods import robust_filter as rkf
from rebayes_mini import callbacks
import os
sin = jnp.sin
cos = jnp.cos
from data_generation import generate_and_save_data, load_data
from filtering_methods import run_filtering

In [87]:
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
plt.style.use("default")
plt.rcParams["font.size"] = 18

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [88]:
# Define simulation parameters
M = 100  # Number of Monte Carlo samples (independent trajectories)
N = 200  # Number of time steps per trajectory
dt = 0.01  # Time step size
dim_state = 2  # Dimension of state variables
dim_obs = 2  # Dimension of observation variables
a = 0.08  # Parameter `a` of the Selkov dynamics (used by fcoord)
b = 0.6  # Parameter `b` of the Selkov dynamics (used by fcoord)
model_name = 'selkov'  # Model name used in saved file names
kain_var = np.sqrt(1)  # Standard deviation (not variance) of the measurement noise
save_flag = False  # True: generate and save a new dataset; False: load the saved one
save_dir = '../dataset/' + model_name + '/'  # Directory for saved datasets
corrupted = False  # Flag to indicate if corrupted observations should be generated/used
f_inacc = False  # Flag to indicate if an inaccurate state transition function is used
ALPHA = 2.0  # Half-width of the uniform perturbation range U(-ALPHA, ALPHA) used in fC
BETA = 2.0  # Half-width of the uniform perturbation range U(-BETA, BETA) used in fC

# Create the directory if needed; exist_ok avoids the check-then-create race.
os.makedirs(save_dir, exist_ok=True)

# Generate random keys for reproducibility (fixed seed -> identical runs)
key = jax.random.PRNGKey(31415)
key_init, key_sim, key_eval, key_obs = jax.random.split(key, 4)
key_state, key_measurement = jax.random.split(key_sim)  # process noise / measurement noise

# Initial condition for the state variables.
# NOTE(review): jnp.radians maps (0.7, 1.25) to about (0.0122, 0.0218). fcoord has no
# trigonometric terms, so this conversion looks carried over from another model.
# Confirm it is intended before regenerating data, because it changes mu0 and thus the dataset.
X0 = jnp.radians(jnp.array([0.7, 1.25]))
mu0, cov0 = X0, jnp.eye(dim_state) * 0.01**2  # Mean and covariance (std 0.01) of initial state
# NOTE(review): x0 is not referenced by the later cells shown here.
x0 = jax.random.multivariate_normal(key_init, mean=mu0, cov=cov0)

# Define the system dynamics (state transition function)
def fcoord(x, t):
    """Deterministic Selkov vector field.

    With state components u = x[..., 0] and v = x[..., 1]:
        du/dt = -u + a*v + u**2 * v
        dv/dt =  b - a*v - u**2 * v
    where ``a`` and ``b`` are the module-level model parameters.

    Args:
        x: State array whose last axis has length 2. Any leading axes are
            treated as batch axes.
        t: Time. It is unused because the system is autonomous; it is kept for a
            uniform ``(x, t)`` signature.

    Returns:
        Array of shape ``x.shape[:-1] + (2,)`` holding the time derivatives.
    """
    u, v = x[..., 0], x[..., 1]
    u2v = u ** 2 * v
    du = -u + a * v + u2v
    dv = b - a * v - u2v
    # Stack on the last axis so a batched input (..., 2) yields (..., 2).
    # Stacking on axis 0 would return (2, ...), which breaks broadcasting with
    # per-state noise of shape (2,). For a 1-D state both give shape (2,).
    return jnp.stack([du, dv], axis=-1)

def f(x, t, D, *args):
    """Accurate state transition: Selkov drift plus standard Gaussian noise.

    Args:
        x: State passed to ``fcoord``.
        t: Time/step value folded into ``key_state`` to select the noise draw.
        D: Length of the noise vector.
        *args: Ignored; accepted for signature compatibility.

    Returns:
        ``fcoord(x, t) + eps`` with ``eps ~ N(0, I_D)``.
    """
    # NOTE(review): jax.random.fold_in casts its data to uint32. A fractional time
    # such as dt * i is therefore truncated, and nearby steps would share one key.
    # Confirm that callers pass an integer step index here.
    step_key = jax.random.fold_in(key_state, t)
    return fcoord(x, t) + jax.random.normal(step_key, shape=(D,))

def fC(x, t, D, *args):
    """Inaccurate state transition: the noisy dynamics of ``f`` plus a random perturbation.

    The base term uses the same folded key as ``f``, so for a given ``t`` its noise
    matches ``f``'s exactly. Only the first component is then perturbed:
        dx[0] <- dx[0] - da[0] * x[1] + db[0]
    with da ~ U(-ALPHA, ALPHA) and db ~ U(-BETA, BETA), both drawn with shape x.shape.

    Args:
        x: State vector. It is assumed to be 1-D (x[1] and dx[0] index state
            components) — TODO confirm.
        t: Time/step value folded into ``key_state`` (same fold_in truncation
            caveat as ``f`` if a fractional time is passed).
        D: Length of the noise vector.
        *args: Ignored; accepted for signature compatibility.

    Returns:
        The perturbed time derivative.
    """
    step_key = jax.random.fold_in(key_state, t)
    noise = jax.random.normal(step_key, shape=(D,))
    # NOTE(review): step_key is used for `noise` and also split for the perturbation
    # keys. JAX advises against reusing a key after splitting it; it is kept as-is so
    # `noise` stays identical to f's draw. Confirm that coupling is intended.
    key_a, key_b = jax.random.split(step_key)
    drift = fcoord(x, t) + noise
    # Named delta_* so the perturbations do not shadow the model parameters a and b.
    delta_a = jax.random.uniform(key_a, x.shape, minval=-ALPHA, maxval=ALPHA)
    delta_b = jax.random.uniform(key_b, x.shape, minval=-BETA, maxval=BETA)

    # Only the first component of the dynamics is perturbed.
    return drift.at[0].set(drift[0] - delta_a[0] * x[1] + delta_b[0])

# Define the observation function
# Observation matrix: the 2x2 identity, so both state components are observed directly.
H = jnp.array([[1, 0],
               [0, 1]])

def hcoord(x, t):
    """Noise-free linear observation y = H @ x; ``t`` is accepted but unused."""
    return jnp.matmul(H, x)

def h(x, t, D, *args):
    """Observation function with additive Gaussian measurement noise.

    Args:
        x: State passed to ``hcoord``.
        t: Time at which the observation is taken. It may be a float such as
            ``dt * i``; it selects the noise draw for this call.
        D: Length of the noise vector.
        *args: Ignored; accepted for signature compatibility.

    Returns:
        ``hcoord(x, t) + kain_var * eps`` with ``eps ~ N(0, I_D)``, where
        ``kain_var`` is the measurement-noise standard deviation.
    """
    # jax.random.fold_in casts its data to uint32, so passing a float time directly
    # truncates it: every t in [0, 1) folds to 0, and with dt = 0.01 roughly the
    # first 100 observations would share the same noise. Folding in the float32
    # bit pattern of t gives a distinct key for every distinct time value.
    t_bits = jax.lax.bitcast_convert_type(jnp.asarray(t, dtype=jnp.float32), jnp.uint32)
    keyt = jax.random.fold_in(key_measurement, t_bits)
    err = jax.random.normal(keyt, shape=(D,))
    return hcoord(x, t) + kain_var * err

# Time-stepping function for observations
def h_step(ys, dt, N, h):
    """Map a state trajectory to an (N, dim_obs) array of observations.

    Row 0 is left as zeros. For i = 1, ..., N - 1, row i is ``h(ys[i - 1], dt * i)``.

    Args:
        ys: State trajectory, indexed along its first axis.
        dt: Time step; observation i is evaluated at time ``dt * i``.
        N: Number of output rows.
        h: Callable ``h(x, t)`` whose result is written into one output row.

    Returns:
        jnp array of shape (N, dim_obs).
    """
    # NOTE(review): row i pairs state ys[i - 1] with time dt * i, and row 0 is never
    # filled. Confirm that this one-step offset is what the caller expects.
    def body(i, state):
        obs, traj = state
        return obs.at[i].set(h(traj[i - 1], dt * i)), traj

    obs, _ = jax.lax.fori_loop(1, N, body, (jnp.zeros((N, dim_obs)), ys))
    return obs

# Helper functions for dynamics and observations
def ff(x):
    """Single-argument dynamics wrapper for the filters: ``fcoord`` evaluated at t = dt."""
    return fcoord(x, t=dt)

def hh(x, t):
    """Noise-free linear observation y = H @ x for the filters; ``t`` is unused."""
    return jnp.matmul(H, x)


In [89]:
from data_generation import generate_and_save_data, load_data
from filtering_methods import run_filtering
# Either load the previously saved dataset, or regenerate it (and save it).
if not save_flag:
    statev, yv, yv_corrupted = load_data(save_dir, model_name, M, N, dim_state, dim_obs, corrupted)
else:
    statev, yv, yv_corrupted = generate_and_save_data(
        M, N, dt, dim_state, dim_obs, mu0, cov0, key_init, key_obs,
        f, fC, h, h_step, save_dir, model_name, save_flag, corrupted, f_inacc,
    )

In [90]:
print('System', model_name)
print('corrupted', corrupted)
print('f_inacc', f_inacc)
methods = ['EKF', 'EnKF', 'EnKFI', 'HubEnKF', 'EnKFS', 'PF']

# Evaluate on the last 20% of the Monte Carlo trajectories. The index is computed
# once so that observations and states are always sliced identically.
test_start = int(M * 0.8)
yv_test = (yv_corrupted if corrupted else yv)[test_start:]
statev_test = statev[test_start:]
_, _, _ = run_filtering(methods, yv_test, statev_test, key_eval, ff, hh, num_particle=1000)

System selkov
corrupted False
f_inacc False
EKF


  0%|          | 0/20 [00:00<?, ?it/s]

100%|██████████| 20/20 [00:04<00:00,  4.32it/s]


EnKF


100%|██████████| 20/20 [00:10<00:00,  1.97it/s]


EnKFI


100%|██████████| 20/20 [00:11<00:00,  1.81it/s]


HubEnKF


100%|██████████| 20/20 [00:11<00:00,  1.77it/s]


EnKFS


100%|██████████| 20/20 [00:10<00:00,  1.98it/s]


PF


100%|██████████| 20/20 [00:10<00:00,  2.00it/s]

Done
RMSE
{'EKF': Array(0.5272, dtype=float32),
 'EnKF': Array(0.6015, dtype=float32),
 'EnKFI': Array(0.665, dtype=float32),
 'EnKFS': Array(0.6014, dtype=float32),
 'HubEnKF': Array(0.6639, dtype=float32),
 'PF': Array(0.3287, dtype=float32)}
Time
{'EKF': np.float64(1.1578254103660583),
 'EnKF': np.float64(2.5324233174324036),
 'EnKFI': np.float64(2.765438914299011),
 'EnKFS': np.float64(2.53198516368866),
 'HubEnKF': np.float64(2.824265599250794),
 'PF': np.float64(2.504407525062561)}



