In [1]:
from time import time
import jax
import numpy as np
import pandas as pd
import seaborn as sns
import jax.numpy as jnp
import matplotlib.pyplot as plt
from functools import partial
from datetime import datetime, timedelta
import ensemble_kalman_filter as enkf
from tqdm import tqdm
from bayes_opt import BayesianOptimization
from rebayes_mini.methods import robust_filter as rkf
from rebayes_mini import callbacks
import os
sin = jnp.sin
cos = jnp.cos

In [2]:
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
plt.style.use("default")
plt.rcParams["font.size"] = 18

In [3]:
M = 100
N = 200
dt = 0.01
dim_state = 3
dim_obs = 3
# parameters
sigma = 10.0
beta = 8.0 / 3.0
rho = 28.0
model_name = 'lorenz63'
save_flag = False
save_dir ='../dataset/'+ model_name +'/'
corrupted = False
f_inacc = False
ALPHA=2.0
BETA=2.0
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

In [4]:
def rk4_step(y, i, dt, f):
    h = dt
    t = dt * i
    k1 = h * f(y, t)
    k2 = h * f(y + k1 / 2, dt * i + h / 2)
    k3 = h * f(y + k2 / 2, t + h / 2)
    k4 = h * f(y + k3, t + h)

    y_next = y + 1 / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return y_next

    

@partial(jax.jit, static_argnames=("f",))
def rk4(ys, dt, N, f):
    """
    Based on
    https://colab.research.google.com
    github/google/jax/blob/master/cloud_tpu_colabs/Lorentz_ODE_Solver
    """
    @jax.jit
    def step(i, ys):
        ysi = rk4_step(ys[i - 1], i, dt, f)
        return ys.at[i].set(ysi)
    return jax.lax.fori_loop(1, N, step, ys)

In [5]:

key = jax.random.PRNGKey(31415)
key_init, key_sim, key_eval, key_obs = jax.random.split(key, 4)
key_state, key_measurement = jax.random.split(key_sim)
    
# Initial condition: th1, w1, th2, w2
X0 = jnp.radians(jnp.array([-8.0, 7.0, 27.0]))
mu0, cov0 = X0, jnp.eye(dim_state) * 0.1**2
x0 = jax.random.multivariate_normal(key_init, mean=mu0, cov=cov0)

def fcoord(x, t):
    x_dot = sigma * (x[..., 1] - x[..., 0])
    y_dot = x[..., 0] * (rho - x[..., 2]) - x[..., 1]
    z_dot = x[..., 0] * x[..., 1] - beta * x[..., 2]

    dx = jnp.stack([x_dot, y_dot, z_dot], axis=-1)
    return dx

def f(x, t, D, *args):
    keyt = jax.random.fold_in(key_state, t)
    err = jax.random.normal(keyt, shape=(D,))
    xdot = fcoord(x, t) + err
    return xdot

def fC(x, t, D, *args):
    keyt = jax.random.fold_in(key_state, t)
    err = jax.random.normal(keyt, shape=(D,))
    key_a, key_b = jax.random.split(keyt)
    dx_mod = fcoord(x, t) + err
    a = jax.random.uniform(key_a, x.shape, minval=-ALPHA, maxval=ALPHA)
    b = jax.random.uniform(key_b, x.shape, minval=-BETA, maxval=BETA)
    
    # 应用调整，注意这里的 dx_mod 是不可变的操作
    dx_mod = dx_mod.at[0].set(dx_mod[0] - a[0] * x[1] + b[0])
    return dx_mod

H = jnp.array([[1, 0, 0],
               [0, 1, 0],
               [0, 0, 1],
                ])
def hcoord(x, t):
    return H @ x

def h(x, t, D, *args):
    keyt = jax.random.fold_in(key_measurement, t)
    err = jax.random.normal(keyt, shape=(D,))
    ydot = hcoord(x, t) + err
    return ydot

# @partial(jax.jit, static_argnames=("h",))
def h_step(ys, dt, N, h):
    yss = jnp.zeros((N, dim_obs))
    # @jax.jit
    def step(i, carry):
        yss, ys = carry
        ysi = h(ys[i - 1], dt * i)
        yss = yss.at[i].set(ysi)
        return (yss, ys)

    yss, ys = jax.lax.fori_loop(1, N, step, (yss, ys))
    return yss


An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [6]:

dt = 0.01
p_err = 0.01
xs_mc = []
ys_mc = []
ys_corrupted_mc = []
key_obs, key_mc = jax.random.split(key_obs)
for i in range(M):
    key_mc, subkey = jax.random.split(key_mc)  # 生成新的子随机数生成器
    x0 = jax.random.multivariate_normal(key_init, mean=mu0, cov=cov0)
    xs1 = jnp.zeros((N,) + x0.shape)
    xs1 = xs1.at[0].set(x0)
    if f_inacc:
        fpart = partial(fC, D=dim_state)
    else:
        fpart = partial(f, D=dim_state)
    xs1 = rk4(xs1, dt, N, fpart)
    hpart = partial(h, D=dim_obs)
    ys1 = h_step(xs1, dt, N, hpart)   
    ys1_corrupted = ys1.copy()
    if corrupted:
        errs_map = jax.random.bernoulli(key_init, p=p_err, shape=ys1_corrupted.shape)
        ys1_corrupted = ys1_corrupted * (~errs_map) + 100.0 * errs_map
        err_where = np.where(errs_map)
    xs_mc.append(xs1)
    ys_mc.append(ys1)
    ys_corrupted_mc.append(ys1_corrupted)
statev = jnp.array(xs_mc)
yv = jnp.array(ys_mc)
yv_corrupted = jnp.array(ys_corrupted_mc)
print(statev.shape)
print(yv.shape)
print(yv_corrupted.shape)

(100, 200, 3)
(100, 200, 3)
(100, 200, 3)


In [7]:
ys = yv.reshape(-1, dim_obs)
ys_corrupted = yv_corrupted.reshape(-1, dim_obs)
xs = statev.reshape(-1, dim_state)
print(xs.shape,ys.shape)

(20000, 3) (20000, 3)


In [8]:

# 生成时间序列
start_time = datetime(2024, 7, 2, 13, 0, 0)
time_interval = timedelta(seconds=1)
time_series = [start_time + i * time_interval for i in range(xs.shape[0])]

# 创建 DataFrame
column_names = [f'x_{i}' for i in range(1, xs.shape[1] + 1)]
df = pd.DataFrame(xs, columns=column_names)
df.insert(0, 'date', time_series)

print(df.head())

                 date       x_1       x_2       x_3
0 2024-07-02 13:00:00 -0.171118  0.106694  0.384818
1 2024-07-02 13:00:01 -0.140380  0.078005  0.397646
2 2024-07-02 13:00:02 -0.114898  0.057307  0.410193
3 2024-07-02 13:00:03 -0.093480  0.043242  0.422442
4 2024-07-02 13:00:04 -0.075161  0.034759  0.434389


In [9]:
if save_flag:
    # # 保存为 CSV 文件
    csv_path = save_dir + model_name + '.csv'
    df.to_csv(csv_path, index=False)
    # 创建 DataFrame
    column_names = [f'y_{i}' for i in range(1, ys.shape[1] + 1)]
    dfy = pd.DataFrame(ys, columns=column_names)
    dfy.insert(0, 'date', time_series)

    print(dfy.head())
    csv_path = save_dir + model_name + '_obs.csv'
    dfy.to_csv(csv_path, index=False)

    if corrupted:
        column_names = [f'y_{i}' for i in range(1, ys_corrupted.shape[1] + 1)]
        dfy = pd.DataFrame(ys_corrupted, columns=column_names)
        dfy.insert(0, 'date', time_series)

        print(dfy.head())
        csv_path = save_dir + model_name + '_obs_corrupted.csv'
        dfy.to_csv(csv_path, index=False)

In [10]:
range_time = np.arange(N) * dt
np.set_printoptions(precision=4)

def callback_fn(particles, particles_pred, y, i):
    return jnp.sqrt(jnp.power(particles.mean(axis=0) - xs[i], 2).mean()), particles.mean(axis=0)

def latent_fn(x, key, i):
    """
    State function
    """
    err = jax.random.normal(key, (dim_state,))
    @jax.jit
    def f(x, t):
        return fcoord(x, t)  + err
    
    return rk4_step(x, i, dt, f)

def ff(x):
    return fcoord(x, dt)

def hh(x, t):
    return H @ x

def obs_fn(x, key, i):
    """
    Measurement function
    """
    err = jax.random.normal(key, (dim_obs,))
    return hh(x, dt) + err

def calculate_mse(errs):
    return jnp.mean(jnp.square(errs))
def calculate_mae(errs):
    return jnp.mean(jnp.abs(errs))

def calculate_rmse(errs):
    return jnp.sqrt(jnp.mean(jnp.square(errs)))

def calculate_test_error(errs):
    mse = calculate_mse(errs)
    mae = calculate_mae(errs)
    rmse = calculate_rmse(errs) 
    print(f'mse:{mse}, mae:{mae}, rmse:{rmse}')
    return mse, mae, rmse


In [11]:
from filtering_methods import run_filtering
print('System', model_name)
print('corrupted', corrupted)
print('f_inacc', f_inacc)
# methods = ['EKF', 'EnKF', 'EnKFI', 'HubEnKF', 'EnKFS']
methods = ['PF']
if corrupted:
    yv_test = yv_corrupted[int(M * 0.8):]
else:
    yv_test = yv[int(M * 0.8):]
statev_test = statev[int(M * 0.8):]
parameter_range = (10,30)
_, _, _ = run_filtering(methods, yv_test, statev_test, key_eval, ff, hh, parameter_range, num_particle=1000)

System lorenz63
corrupted False
f_inacc False
PF


100%|██████████| 20/20 [06:04<00:00, 18.20s/it]


Done
RMSE
{'PF': Array(nan, dtype=float32)}
Time
{'PF': np.float64(91.0041116476059)}
