In [17]:
from time import time
import jax
import numpy as np
import pandas as pd
import seaborn as sns
import jax.numpy as jnp
import matplotlib.pyplot as plt
from functools import partial
from datetime import datetime, timedelta
import ensemble_kalman_filter as enkf
from tqdm import tqdm
from bayes_opt import BayesianOptimization
from rebayes_mini.methods import robust_filter as rkf
from rebayes_mini import callbacks
import os

In [18]:
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
plt.style.use("default")
plt.rcParams["font.size"] = 18
XLA_PYTHON_CLIENT_ALLOCATOR=0.5

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [19]:
model_name = 'VL20mc'
save_flag = False
save_dir ='../dataset/'+ model_name +'/'
corrupted = False
f_inacc = False
kain_var = np.sqrt(1)  # Standard deviation of measurement 
ALPHA=2.0
BETA=2.0
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
M=100
N=200    
D = 72
F = 8.0
F = 10
G = 0
alpha = 1
gamma = 1
dt = 0.01
nX = 36
dim_state = 2 * nX
dim_obs = dim_state
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

key = jax.random.PRNGKey(31415)
key_init, key_sim, key_eval, key_obs = jax.random.split(key, 4)
key_state, key_measurement = jax.random.split(key_sim)
key_obs, key_mc = jax.random.split(key_obs)
X0 = jax.random.normal(key_mc, (D,)) + F
mu0, cov0 = X0, jnp.eye(dim_state)  # Mean and covariance of initial state

def shift(x, n):
    return jnp.roll(x, -n, axis=-1)

def fcoord(x, t):
    X = x[:nX]
    theta = x[nX:]
    d = jnp.zeros_like(x)
    d.at[:nX].set((shift(X, 1)-shift(X, -2))*shift(X, -1))
    d.at[:nX].set(d[:nX] - gamma*shift(X, 0)) 
    d.at[:nX].set(d[:nX] - alpha*shift(theta, 0) + F)
    d.at[nX:].set(shift(X, 1)*shift(theta, 2) \
        - shift(X, -1)*shift(theta, -2))
    d.at[nX:].set(d[nX:] - gamma*shift(theta, 0))
    d.at[nX:].set(d[nX:] + alpha*shift(X, 0) + G) 
    return d

def f(x, t, D, *args):
    keyt = jax.random.fold_in(key_state, t)
    err = jax.random.normal(keyt, shape=(D,))
    xdot = fcoord(x, t) + F + err
    return xdot

def fC(x, t, D, *args):
    keyt = jax.random.fold_in(key_state, t)
    err = jax.random.normal(keyt, shape=(D,))
    key_a, key_b = jax.random.split(keyt)
    ixs = jnp.arange(D)
    dx_mod = fcoord(x, ixs, D) + F + err
    a = jax.random.uniform(key_a, x.shape, minval=-ALPHA, maxval=ALPHA)
    b = jax.random.uniform(key_b, x.shape, minval=-BETA, maxval=BETA)
    dx_mod = dx_mod.at[0].set(dx_mod[0] - a[0] * x[1] + b[0])
    return dx_mod

def hcoord(x, t):
    """Linear observation function."""
    return x

def h(x, t, D, *args):
    """Observation function with added measurement noise."""
    keyt = jax.random.fold_in(key_measurement, t)
    err = jax.random.normal(keyt, shape=(D,))
    ydot = hcoord(x, t) + kain_var * err
    return ydot

# Time-stepping function for observations
def h_step(ys, dt, N, h):
    """Iteratively compute observations over time."""
    yss = jnp.zeros((N, dim_obs))

    def step(i, carry):
        """Single step update for observation computation."""
        yss, ys = carry
        ysi = h(ys[i - 1], dt * i)
        yss = yss.at[i].set(ysi)
        return (yss, ys)

    yss, ys = jax.lax.fori_loop(1, N, step, (yss, ys))
    return yss

# Helper functions for dynamics and observations
def ff(x):
    """Wrapper for state dynamics."""
    return fcoord(x, dt)

def hh(x, t):
    """Wrapper for linear observation function."""
    return hcoord(x, t)

In [20]:
from data_generation import generate_and_save_data, load_data
from filtering_methods import run_filtering
if save_flag:
    statev, yv, yv_corrupted = generate_and_save_data(
        M, N, dt, dim_state, dim_obs, mu0, cov0, key_init, key_obs,
        ff, fC, hh, h_step, save_dir, model_name, save_flag, corrupted, f_inacc
    )
else:
    statev, yv, yv_corrupted = load_data(save_dir, model_name, M, N, dim_state, dim_obs, corrupted)

In [21]:
print('System', model_name)
print('corrupted', corrupted)
print('f_inacc', f_inacc)
methods = ['EKF', 'EnKF', 'EnKFI', 'HubEnKF', 'EnKFS', 'PF']
# methods = ['EnKFS']
if corrupted:
    yv_test = yv_corrupted[int(M * 0.8):]
else:
    yv_test = yv[int(M * 0.8):]
statev_test = statev[int(M * 0.8):]
_, _, _ = run_filtering(methods, yv_test, statev_test, key_eval, ff, hh, num_particle=1000)

System VL20mc
corrupted False
f_inacc False
EKF


  0%|          | 0/20 [00:00<?, ?it/s]

100%|██████████| 20/20 [00:13<00:00,  1.51it/s]


EnKF


100%|██████████| 20/20 [01:04<00:00,  3.25s/it]


EnKFI


100%|██████████| 20/20 [00:54<00:00,  2.74s/it]


HubEnKF


100%|██████████| 20/20 [01:12<00:00,  3.63s/it]


EnKFS


100%|██████████| 20/20 [01:02<00:00,  3.13s/it]


PF


100%|██████████| 20/20 [00:32<00:00,  1.60s/it]


Done
RMSE
{'EKF': Array(10.4698, dtype=float32),
 'EnKF': Array(9.8552, dtype=float32),
 'EnKFI': Array(9.6644, dtype=float32),
 'EnKFS': Array(17.752, dtype=float32),
 'HubEnKF': Array(8.2927, dtype=float32),
 'PF': Array(20.9228, dtype=float32)}
Time
{'EKF': np.float64(3.31871634721756),
 'EnKF': np.float64(16.233397603034973),
 'EnKFI': np.float64(13.690877616405489),
 'EnKFS': np.float64(15.6501869559288),
 'HubEnKF': np.float64(18.13212239742279),
 'PF': np.float64(8.008211016654968)}


In [22]:
n_mc = 100
N = 200
dt = 0.01
p_err = 0.01
xs_mc = []
ys_mc = []
ys_corrupted_mc = []
key_obs, key_mc = jax.random.split(key_obs)
for i in range(n_mc):
    key_mc, subkey = jax.random.split(key_mc)  # 生成新的子随机数生成器
    x0 = jax.random.normal(key_mc, (D,)) + F
    xs1 = jnp.zeros((N,) + x0.shape)
    xs1 = xs1.at[0].set(x0)
    fpart = partial(f, D=D)
    xs1 = rk4(xs1, dt, N, fpart)
    ys1 = xs1 + jax.random.normal(key_measurement, xs1.shape)
    ys1_corrupted = ys1.copy()
    if corrupted:
        errs_map = jax.random.bernoulli(key_init, p=p_err, shape=ys1_corrupted.shape)
        ys1_corrupted = ys1_corrupted * (~errs_map) + 100.0 * errs_map
        err_where = np.where(errs_map)
    xs_mc.append(xs1)
    ys_mc.append(ys1)
    ys_corrupted_mc.append(ys1_corrupted)
statev = jnp.array(xs_mc)
yv = jnp.array(ys_mc)
yv_corrupted = jnp.array(ys_corrupted_mc)
print(statev.shape)
print(yv.shape)
print(yv_corrupted.shape)

NameError: name 'rk4' is not defined

In [7]:
ys = yv.reshape(-1, D)
ys_corrupted = yv_corrupted.reshape(-1, D)
xs = statev.reshape(-1, D)
print(xs.shape,ys.shape)

(20000, 72) (20000, 72)


In [8]:
# 生成时间序列
start_time = datetime(2024, 6, 22, 13, 0, 0)
time_interval = timedelta(seconds=1)
time_series = [start_time + i * time_interval for i in range(xs.shape[0])]

# 创建 DataFrame
column_names = [f'x_{i}' for i in range(1, xs.shape[1] + 1)]
df = pd.DataFrame(xs, columns=column_names)
df.insert(0, 'date', time_series)

print(df.head())

                 date        x_1        x_2       x_3        x_4        x_5  \
0 2024-06-22 13:00:00   9.930324  10.413463  9.601624  11.190229  10.475810   
1 2024-06-22 13:00:01  10.026490  10.512879  9.690468  11.293329  10.577833   
2 2024-06-22 13:00:02  10.122657  10.612296  9.779312  11.396429  10.679856   
3 2024-06-22 13:00:03  10.218823  10.711713  9.868156  11.499529  10.781879   
4 2024-06-22 13:00:04  10.314990  10.811130  9.957001  11.602629  10.883903   

        x_6        x_7        x_8       x_9  ...       x_63       x_64  \
0  9.100482   9.611787  10.599230  7.906620  ...  10.527282  10.186784   
1  9.216859   9.716647  10.709743  8.002557  ...  10.633910  10.273329   
2  9.333236   9.821507  10.820255  8.098494  ...  10.740539  10.359874   
3  9.449613   9.926368  10.930768  8.194430  ...  10.847167  10.446419   
4  9.565989  10.031228  11.041281  8.290367  ...  10.953795  10.532964   

        x_65       x_66      x_67      x_68      x_69       x_70       x_71  \
0

In [9]:
if save_flag:
    # # 保存为 CSV 文件
    csv_path = save_dir + model_name + '.csv'
    df.to_csv(csv_path, index=False)
    # 创建 DataFrame
    column_names = [f'y_{i}' for i in range(1, ys.shape[1] + 1)]
    dfy = pd.DataFrame(ys, columns=column_names)
    dfy.insert(0, 'date', time_series)

    print(dfy.head())
    csv_path = save_dir + model_name + '_obs.csv'
    dfy.to_csv(csv_path, index=False)

    if corrupted:
        column_names = [f'y_{i}' for i in range(1, ys_corrupted.shape[1] + 1)]
        dfy = pd.DataFrame(ys_corrupted, columns=column_names)
        dfy.insert(0, 'date', time_series)

        print(dfy.head())
        csv_path = save_dir + model_name + '_obs_corrupted.csv'
        dfy.to_csv(csv_path, index=False)

In [10]:
range_time = np.arange(N) * dt
np.set_printoptions(precision=4)

def callback_fn(particles, particles_pred, y, i):
    return jnp.sqrt(jnp.power(particles.mean(axis=0) - xs[i], 2).mean()), particles.mean(axis=0)
def latent_fn(x, key, i):
    """
    State function
    """
    err = jax.random.normal(key, (D,))
    @jax.jit
    def ff(x, t):
        return fcoord(x, D) + F + err
    
    return rk4_step(x, i, dt, ff)

def obs_fn(x, key, i):
    """
    Measurement function
    """
    err = jax.random.normal(key, (D,))
    return x + err

def f_step(x):
    return fcoord(x, D) + F

def h_step(x, _):
    return x

def obs_fn(x, key, i):
    """
    Measurement function
    """
    err = jax.random.normal(key, (D,))
    return x + err

def calculate_mse(errs):
    return jnp.mean(jnp.square(errs))
def calculate_mae(errs):
    return jnp.mean(jnp.abs(errs))

def calculate_rmse(errs):
    return jnp.sqrt(jnp.mean(jnp.square(errs)))

def calculate_test_error(errs):
    mse = calculate_mse(errs)
    mae = calculate_mae(errs)
    rmse = calculate_rmse(errs) 
    print(f'mse:{mse}, mae:{mae}, rmse:{rmse}')
    return mse, mae, rmse
M = n_mc
ff = f_step
hh = h_step

In [11]:
from filtering_methods import run_filtering
methods = ['EKF', 'EnKF', 'EnKFI', 'HubEnKF', 'EnKFS', 'PF']
if corrupted:
    yv_test = yv_corrupted[int(M * 0.8):]
else:
    yv_test = yv[int(M * 0.8):]
statev_test = statev[int(M * 0.8):]
parameter_range = (1, 20)
_, _, _ = run_filtering(methods, yv_test, statev_test, key_eval, ff, hh, parameter_range, num_particle=100)

EKF


  0%|          | 0/20 [00:00<?, ?it/s]

100%|██████████| 20/20 [00:04<00:00,  4.21it/s]


EnKF


100%|██████████| 20/20 [00:13<00:00,  1.52it/s]


EnKFI


100%|██████████| 20/20 [00:13<00:00,  1.52it/s]


HubEnKF


100%|██████████| 20/20 [00:12<00:00,  1.57it/s]


EnKFS


100%|██████████| 20/20 [00:13<00:00,  1.44it/s]


PF


100%|██████████| 20/20 [00:17<00:00,  1.16it/s]


Done
RMSE
{'EKF': Array(5.8633, dtype=float32),
 'EnKF': Array(nan, dtype=float32),
 'EnKFI': Array(1.4246, dtype=float32),
 'EnKFS': Array(10.4608, dtype=float32),
 'HubEnKF': Array(0.8825, dtype=float32),
 'PF': Array(11.6927, dtype=float32)}
Time
{'EKF': np.float64(1.1888734102249146),
 'EnKF': np.float64(3.2847465872764587),
 'EnKFI': np.float64(3.2846200466156006),
 'EnKFS': np.float64(3.4689825177192692),
 'HubEnKF': np.float64(3.1928635835647583),
 'PF': np.float64(4.322352588176727)}


In [12]:
# y_index = int(0.8 * ys.shape[0])
# print(ys.shape)
# # 截取最后 20% 的数据
# ys = ys[y_index:]
# ys_corrupted = ys_corrupted[y_index:]
# print(ys.shape)
time_methods = {}
hist_methods = {}
errs_methods = {}

In [13]:
def filter_enkf(x0, n_particles, measurements, state, key_eval):
    agent = enkf.EnsembleKalmanFilter(latent_fn, obs_fn, n_particles)
    key_init_particles, key_scan = jax.random.split(key_eval, 2)
    X0 = agent.init_bel(key_init_particles, D)
    particles_end, (errs, particles_hist_mean) = agent.scan(X0, key_scan, measurements, callback_fn=callback_fn)
    return errs, particles_hist_mean

def filter_enkfi(x0, inflation_factor, n_particles, measurements, state, key_eval):
    agent = enkf.EnsembleKalmanFilterInflation(latent_fn, obs_fn, n_particles, inflation_factor=inflation_factor)
    key_init_particles, key_scan = jax.random.split(key_eval, 2)
    X0 = agent.init_bel(key_init_particles, D)
    particles_end, (errs, particles_hist_mean) = agent.scan(X0, key_scan, measurements, callback_fn=callback_fn)
    return errs, particles_hist_mean

def filter_ekf(x0, measurements, state, key_eval):
    nsteps = len(measurements)
    agent_imq = rkf.ExtendedKalmanFilterIMQ(
        f_step, h_step,
        dynamics_covariance=jnp.eye(D),
        observation_covariance=jnp.eye(D),
        soft_threshold=1e8,
    )
    init_bel = agent_imq.init_bel(x0, cov=1.0)
    filterfn = partial(agent_imq.scan, callback_fn=callbacks.get_updated_mean)
    _, hist = filterfn(init_bel, measurements, jnp.ones(N))

    err = jnp.sqrt(jnp.power(hist - state, 2).sum(axis=0))
    return err, hist

In [14]:
method = 'EnKF_1000'
n_particles = 1000
hist_bel = []
times = []

for y, state in tqdm(zip(yv, statev), total=n_mc): 
    tinit = time()
    key_eval, key_scan = jax.random.split(key_eval)
    # print(xs[0])
    errs, _ = filter_enkf(state[0], n_particles, y, state, key_eval)
    tend = time()
    
    hist_bel.append(errs)
    times.append(tend - tinit)

errs = np.stack(hist_bel)
time_methods[method] = times
print(method)
errs_test = errs[int(errs.shape[0] * 0.8):]
# _, _, errs_methods[method] = calculate_test_error(errs_test)
errs_methods[method] = errs_test


 13%|█▎        | 13/100 [00:13<01:32,  1.06s/it]

100%|██████████| 100/100 [01:41<00:00,  1.01s/it]


EnKF_1000


In [15]:
method = 'EnKFI_1000'
n_particles = 1000
inflation_factor = 3.0
hist_bel = []
times = []

for y, state in tqdm(zip(yv, statev), total=n_mc): 
    tinit = time()
    key_eval, key_scan = jax.random.split(key_eval)
    errs, _ = filter_enkfi(state[0], inflation_factor, n_particles, y, state, key_eval)
    tend = time()
    
    hist_bel.append(errs)
    times.append(tend - tinit)

errs = np.stack(hist_bel)
time_methods[method] = times
print(method)
errs_test = errs[int(errs.shape[0] * 0.8):]
_, _, errs_methods[method] = calculate_test_error(errs_test)

  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [01:31<00:00,  1.10it/s]


EnKFI_1000
mse:1.9698152542114258, mae:1.3963477611541748, rmse:1.403501033782959


In [16]:
print(errs.shape)
print(errs_test.shape)

(100, 200)
(20, 200)


In [17]:
method = 'EKF'
hist_bel = []
times = []
errs = 0
for y, state in tqdm(zip(yv, statev), total=n_mc): 
    tinit = time()
    key_eval, key_scan = jax.random.split(key_eval)
    errs, _ = filter_ekf(state[0], y, state, key_eval)
    tend = time()
    
    hist_bel.append(errs)
    times.append(tend - tinit)

errs = np.stack(hist_bel)
time_methods[method] = times
print(method)
errs_test = errs[int(errs.shape[0] * 0.8):]
_, _, errs_methods[method] = calculate_test_error(errs_test)

  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [00:23<00:00,  4.22it/s]

EKF
mse:6875.650390625, mae:82.53736114501953, rmse:82.91954040527344





In [18]:
print(errs.shape)
print(errs_test.shape)

(100, 72)
(20, 72)


In [19]:
print(errs_methods)

{'EnKF_1000': array([[5.2713, 3.8535, 3.0845, ..., 1.7313, 1.7358, 1.7356],
       [5.4065, 3.9168, 3.2378, ..., 1.7031, 1.7095, 1.7097],
       [5.3494, 3.8693, 3.1638, ..., 1.8859, 1.8882, 1.8908],
       ...,
       [5.7314, 4.0663, 3.2847, ..., 1.7853, 1.7857, 1.7888],
       [5.4255, 3.7086, 2.9593, ..., 1.7975, 1.8038, 1.8037],
       [5.2083, 3.801 , 3.117 , ..., 1.5761, 1.5803, 1.5796]],
      dtype=float32), 'EnKFI_1000': Array(1.4035, dtype=float32), 'EKF': Array(82.9195, dtype=float32)}


In [20]:
# # 将 JAX 数组转换为 Python 的 float 类型
# def convert_jax_to_python(dictionary):
#     return {key: float(value) for key, value in dictionary.items()}

# # 转换字典
# converted_errs_methods = {method: convert_jax_to_python(metrics) for method, metrics in errs_methods.items()}

# for keys, values in converted_errs_methods.items():
#     print(keys, values)
