In [2]:
from time import time
import jax
import numpy as np
import pandas as pd
import seaborn as sns
import jax.numpy as jnp
import matplotlib.pyplot as plt
from functools import partial
from datetime import datetime, timedelta
import ensemble_kalman_filter as enkf
from tqdm import tqdm
from bayes_opt import BayesianOptimization
from rebayes_mini.methods import robust_filter as rkf
from rebayes_mini import callbacks
import os

In [3]:
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
plt.style.use("default")
plt.rcParams["font.size"] = 18

In [4]:
model_name = 'lorenzUV'
save_flag = True
corrupted = False
f_inacc = False
ALPHA=2.0
BETA=2.0
save_dir ='../dataset/'+ model_name +'/'
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

In [5]:
def rk4_step(y, i, dt, f):
    h = dt
    t = dt * i
    k1 = h * f(y, t)
    k2 = h * f(y + k1 / 2, dt * i + h / 2)
    k3 = h * f(y + k2 / 2, t + h / 2)
    k4 = h * f(y + k3, t + h)

    y_next = y + 1 / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return y_next

    

@partial(jax.jit, static_argnames=("f",))
def rk4(ys, dt, N, f):
    """
    Based on
    https://colab.research.google.com
    github/google/jax/blob/master/cloud_tpu_colabs/Lorentz_ODE_Solver
    """
    @jax.jit
    def step(i, ys):
        ysi = rk4_step(ys[i - 1], 1, dt, f)
        return ys.at[i].set(ysi)
    return jax.lax.fori_loop(1, N, step, ys)

$$
    \dot{\bf x}_k = \Big({\bf x}_{k+1} - {\bf x}_{k-2}\Big) {\bf x}_{j-1} - {\bf x}(j) + {F}_k
$$


In [None]:
nU=8
J=32
nV=J*nU
F=20
h=0.01
b=10
c=10
D = (J+1)*nU 
iiU = (np.arange(J*nU)/J).astype(int)
iiV = np.arange(J*nU)
# ixs = jnp.arange(D)

key = jax.random.PRNGKey(31415)
key_init, key_sim, key_eval, key_obs = jax.random.split(key, 4)

key_state, key_measurement = jax.random.split(key_sim)

@partial(jax.vmap, in_axes=(None, 0, None))
def du_dt(u,iu,nU):
    udot = (u[(iu + 1) % nU] - u[iu - 2]) * u[iu - 1] - u[iu]+F
    return udot

@partial(jax.vmap, in_axes=(None, 0, None))
def dv_dt(v,iv,nV):
    v=b*jnp.flip(v)
    vdot = (v[(iv + 1) % nV] - v[iv - 2]) * v[iv - 1] - v[iv]
    vdot*=c/b
    vdot=jnp.flip(vdot)
    return vdot

def f(x, t, D, *args):
    u,v=x[..., :nU],x[..., nU:]
    ius=jnp.arange(nU)
    udot=du_dt(u,ius,nU)
    udot+=-h*c/b * v.reshape(v.shape[:-1]+(nU, J)).sum(-1)
    ivs=jnp.arange(nV)
    vdot=dv_dt(v,ivs,nV)
    vdot+= h*c/b * u[..., iiU]
    keyt = jax.random.fold_in(key_state, t)
    err = jax.random.normal(keyt, shape=(D,))
    xdot = jnp.concatenate([udot,vdot]) + err
    return xdot

def fC(x, t, D, *args):
    u,v=x[..., :nU],x[..., nU:]
    ius=jnp.arange(nU)
    udot=du_dt(u,ius,nU)
    udot+=-h*c/b * v.reshape(v.shape[:-1]+(nU, J)).sum(-1)
    ivs=jnp.arange(nV)
    vdot=dv_dt(v,ivs,nV)
    vdot+= h*c/b * u[..., iiU]
    keyt = jax.random.fold_in(key_state, t)
    err = jax.random.normal(keyt, shape=(D,))
    xdot = jnp.concatenate([udot,vdot]) + err
    key_a, key_b = jax.random.split(keyt)
    a = jax.random.uniform(key_a, x.shape, minval=-ALPHA, maxval=ALPHA)
    b = jax.random.uniform(key_b, x.shape, minval=-BETA, maxval=BETA)
    # 应用调整，注意这里的 dx_mod 是不可变的操作
    xdot = xdot.at[0].set(xdot[0] - a[0] * x[1] + b[0])
    return xdot

    # ixs = jnp.arange(D)
    # xdot = fcoord(x, ixs, D) + F + err
    # return xdot

# def make_obs_matrix(dim, rank, random_seed):
#     """
#     Generates a random low rank matrix

#     Args:
#         dim (int): dimension of matrix
#         rank (int): rank of matrix
#         random_seed (int, optional): random seed of matrix. Defaults to None.

#     Returns:
#         numpy.ndarray: matrix of rank and dim
#     """
#     # if random_seed is not None:
#     #     np.random.seed(random_seed)
#     assert dim >= rank, "rank cannot be greater than dim."
#     u = np.random.randn(rank, dim, 1)
#     return (u @ u.transpose((0, 2, 1))).sum(axis=0) / rank

# H = make_obs_matrix(D, D, key_obs)

# def hcoord(x, t):
#     return H @ x

# def h(x, t, D, *args):
#     keyt = jax.random.fold_in(key_measurement, t)
#     err = jax.random.normal(keyt, shape=(D,))
#     ydot = hcoord(x, t) + err
#     return ydot

# @partial(jax.jit, static_argnames=("h",))
# def h_step(ys, dt, N, h):
#     @jax.jit
#     def step(i, ys):
#         ysi = h(ys[i - 1], dt * i)
#         return ys.at[i].set(ysi)
#     return jax.lax.fori_loop(1, N, step, ys)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [None]:
n_mc = 100
N = 200
dt = 0.01
p_err = 0.01
xs_mc = []
ys_mc = []
ys_corrupted_mc = []
key_obs, key_mc = jax.random.split(key_obs)
for i in range(n_mc):
    key_mc, subkey = jax.random.split(key_mc)  # 生成新的子随机数生成器
    x0 = np.eye(D)[0]
    u,v=x0[..., :nU],x0[..., nU:]
    xs1 = jnp.zeros((N,) + x0.shape)
    xs1 = xs1.at[0].set(x0)
    if f_inacc:
        fpart = partial(fC, D=D)
    else:
        fpart = partial(f, D=D)
    xs1 = rk4(xs1, dt, N, fpart)
    ys1 = xs1 + jax.random.normal(key_measurement, xs1.shape)
    ys1_corrupted = ys1.copy()
    if corrupted:
        errs_map = jax.random.bernoulli(key_init, p=p_err, shape=ys1_corrupted.shape)
        ys1_corrupted = ys1_corrupted * (~errs_map) + 100.0 * errs_map
        err_where = np.where(errs_map)
    xs_mc.append(xs1)
    ys_mc.append(ys1)
    ys_corrupted_mc.append(ys1_corrupted)
statev = jnp.array(xs_mc)
yv = jnp.array(ys_mc)
yv_corrupted = jnp.array(ys_corrupted_mc)
print(statev.shape)
print(yv.shape)
print(yv_corrupted.shape)

(100, 200, 264)
(100, 200, 264)
(100, 200, 264)


In [8]:
ys = yv.reshape(-1, D)
ys_corrupted = yv_corrupted.reshape(-1, D)
xs = statev.reshape(-1, D)
print(xs.shape,ys.shape)

(20000, 264) (20000, 264)


In [9]:
# 生成时间序列
start_time = datetime(2024, 6, 22, 13, 0, 0)
time_interval = timedelta(seconds=1)
time_series = [start_time + i * time_interval for i in range(xs.shape[0])]

# 创建 DataFrame
column_names = [f'x_{i}' for i in range(1, xs.shape[1] + 1)]
df = pd.DataFrame(xs, columns=column_names)
df.insert(0, 'date', time_series)

print(df.head())

                 date       x_1       x_2       x_3       x_4       x_5  \
0 2024-06-22 13:00:00 -0.508842 -0.210080  0.165664 -0.781829  0.151676   
1 2024-06-22 13:00:01  0.497520  2.089623  1.454283  0.150064  0.274227   
2 2024-06-22 13:00:02 -0.431483  0.207251  0.291909 -0.617165  1.073409   
3 2024-06-22 13:00:03  1.641699  1.697483  0.374305  0.683262 -0.283825   
4 2024-06-22 13:00:04  3.825752  3.042590  0.009009 -0.853737  0.726204   

        x_6       x_7       x_8       x_9  ...     x_255     x_256     x_257  \
0  0.282983 -0.211611  1.484356  0.351793  ...  1.475593  0.076225 -1.395732   
1  1.860394 -0.778327  1.001215 -1.619408  ... -0.984250  0.725734  0.748291   
2  1.405412 -0.633494  0.197610  0.832576  ... -0.200857  0.207517  1.013408   
3 -1.340437  1.048515  0.441048  0.796797  ... -0.668731 -0.583319 -0.445682   
4  1.359205  0.252431  0.579456 -1.066996  ...  0.745839 -0.612405  0.881337   

      x_258     x_259     x_260     x_261     x_262     x_263     x_

In [10]:
if save_flag:
    # # 保存为 CSV 文件
    csv_path = save_dir + model_name + '.csv'
    df.to_csv(csv_path, index=False)
    # 创建 DataFrame
    column_names = [f'y_{i}' for i in range(1, ys.shape[1] + 1)]
    dfy = pd.DataFrame(ys, columns=column_names)
    dfy.insert(0, 'date', time_series)

    print(dfy.head())
    csv_path = save_dir + model_name + '_obs.csv'
    dfy.to_csv(csv_path, index=False)

    if corrupted:
        column_names = [f'y_{i}' for i in range(1, ys_corrupted.shape[1] + 1)]
        dfy = pd.DataFrame(ys_corrupted, columns=column_names)
        dfy.insert(0, 'date', time_series)

        print(dfy.head())
        csv_path = save_dir + model_name + '_obs_corrupted.csv'
        dfy.to_csv(csv_path, index=False)

                 date       y_1       y_2       y_3       y_4       y_5  \
0 2024-06-22 13:00:00  0.298052  0.136882 -2.437877 -0.532799  2.394890   
1 2024-06-22 13:00:01  0.741885  2.168654  0.897008 -0.355968  0.097677   
2 2024-06-22 13:00:02 -0.060917  2.127455  0.123571 -1.701928  0.373233   
3 2024-06-22 13:00:03  0.137413  2.321258  0.567308  2.484061 -1.039487   
4 2024-06-22 13:00:04  4.114027  4.554846  0.547699 -0.989031  1.660715   

        y_6       y_7       y_8       y_9  ...     y_255     y_256     y_257  \
0  0.305981  0.591586  1.393537  0.514840  ...  2.847121  1.264062 -0.989506   
1  2.379422 -2.312613  1.100919 -0.947449  ... -0.077087  0.715438  1.846685   
2  1.640064 -1.177835 -1.130167  1.089051  ... -1.087193  1.008759  1.834841   
3 -1.841616  0.704126  1.980245  0.658309  ... -2.665695 -1.934356 -0.648119   
4  0.565213  0.241693  0.006319 -1.315167  ... -0.293478 -0.144189  0.515352   

      y_258     y_259     y_260     y_261     y_262     y_263     y_

In [11]:
range_time = np.arange(N) * dt
np.set_printoptions(precision=4)

def callback_fn(particles, particles_pred, y, i):
    return jnp.sqrt(jnp.power(particles.mean(axis=0) - xs[i], 2).mean()), particles.mean(axis=0)


def f_step(x):
    @jax.jit
    def f_x(x, t):
        return f(x)
    return rk4_step(x, 1, dt, f_x)

def h_step(x, _):
    return x

def obs_fn(x, key, i):
    """
    Measurement function
    """
    err = jax.random.normal(key, (D,))
    return x + err

def calculate_mse(errs):
    return jnp.mean(jnp.square(errs))
def calculate_mae(errs):
    return jnp.mean(jnp.abs(errs))

def calculate_rmse(errs):
    return jnp.sqrt(jnp.mean(jnp.square(errs)))

def calculate_test_error(errs):
    mse = calculate_mse(errs)
    mae = calculate_mae(errs)
    rmse = calculate_rmse(errs) 
    print(f'mse:{mse}, mae:{mae}, rmse:{rmse}')
    return mse, mae, rmse
M = n_mc
ff = f_step
hh = h_step

In [12]:
print(statev.shape)
print(yv.shape)
print(yv_corrupted.shape)

(100, 200, 264)
(100, 200, 264)
(100, 200, 264)


In [None]:
from filtering_methods import run_filtering
print('System', model_name)
print('corrupted', corrupted)
print('f_inacc', f_inacc)
methods = ['EKF', 'EnKF', 'EnKFI', 'HubEnKF', 'EnKFS', 'PF']
# methods = ['PF']
if corrupted:
    yv_test = yv_corrupted[int(M * 0.8):]
else:
    yv_test = yv[int(M * 0.8):]
statev_test = statev[int(M * 0.8):]
parameter_range = (1, 20)
_, _, _ = run_filtering(methods, yv_test, statev_test, key_eval, ff, hh, parameter_range, num_particle=1000)

System lorenzUV
corrupted False
f_inacc False
EKF


100%|██████████| 20/20 [19:02<00:00, 57.15s/it]


EnKF


100%|██████████| 20/20 [17:40<00:00, 53.00s/it]


EnKFI


100%|██████████| 20/20 [20:22<00:00, 61.14s/it]


HubEnKF
