In [20]:
from time import time
import jax
import numpy as np
import pandas as pd
import seaborn as sns
import jax.numpy as jnp
import matplotlib.pyplot as plt
from functools import partial
from datetime import datetime, timedelta
import ensemble_kalman_filter as enkf
from tqdm import tqdm
from bayes_opt import BayesianOptimization
from rebayes_mini.methods import robust_filter as rkf
from rebayes_mini import callbacks
import os
# To & from time/Fourier domain -- use reals-only fft
def fft(u): return jnp. fft. rfft(u, axis=-1)  # F
def ifft(v): return jnp.fft.irfft(v, axis=-1)  # F^{-1}

In [21]:
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
plt.style.use("default")
plt.rcParams["font.size"] = 18

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [22]:
model_name = 'KSmc'
save_flag = True
save_dir ='../dataset/'+ model_name +'/'
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

In [23]:
def rk4_step(y, i, dt, f):
    h = dt
    t = dt * i
    k1 = h * f(y, t)
    k2 = h * f(y + k1 / 2, dt * i + h / 2)
    k3 = h * f(y + k2 / 2, t + h / 2)
    k4 = h * f(y + k3, t + h)

    y_next = y + 1 / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return y_next

@partial(jax.jit, static_argnames=("f",))
def rk4(ys, dt, N, f):
    """
    Based on
    https://colab.research.google.com
    github/google/jax/blob/master/cloud_tpu_colabs/Lorentz_ODE_Solver
    """
    @jax.jit
    def step(i, ys):
        ysi = rk4_step(ys[i - 1], i, dt, f)
        return ys.at[i].set(ysi)
    return jax.lax.fori_loop(1, N, step, ys)

In [24]:
D = 72
DL = 32
N = 200
n_mc = 100
dt = 0.01
dim_state = D
dim_obs = D

kk = jnp.arange(0, dim_state) * DL
DD = 1j * kk                   # Differentiation to compute: F[ u_x ]
L = kk**2 - kk**4             # Linear operator for KS eqn: F[ - u_xx - u_xxxx]
# Precompute ETDRK4 scalar quantities
E  = jnp.exp(dt * L)           # Integrating factor, eval at dt
E2 = jnp.exp(dt * L / 2)       # Integrating factor, eval at dt/2
# Roots of unity are used to discretize a circular contour...
nRoots = 16
roots = jnp.exp(1j * jnp.pi * (0.5 + jnp.arange(nRoots)) / nRoots)
# ... the associated integral then reduces to the mean,
# g(CL).mean(axis=-1) ~= g(L), whose computation is more stable.
CL = dt * L[:, None] + roots  # Contour for (each element of) L
# E * exact_integral of integrating factor:
Q  = dt * ((jnp.exp(CL / 2) - 1) / CL).mean(axis=-1).real
# RK4 coefficients (modified by Cox-Matthews):
f1 = dt * ((-4 - CL + jnp.exp(CL) * (4 - 3 * CL + CL**2)) / CL**3).mean(axis=-1).real
f2 = dt * ((2 + CL + jnp.exp(CL) * (-2 + CL)) / CL**3).mean(axis=-1).real
f3 = dt * ((-4 - 3 * CL - CL**2 + jnp.exp(CL) * (4 - CL)) / CL**3).mean(axis=-1).real

# NonLinear term (-u*u_x) in Fourier domain via time domain
def NL(v):
    return -0.5 * DD * fft(ifft(v).real ** 2)


key = jax.random.PRNGKey(31415)
key_init, key_sim, key_eval, key_obs = jax.random.split(key, 4)
key_state, key_measurement = jax.random.split(key_sim)
    

def fcoord(x, t):
    # N1 = NL(x)
    # v1 = E2 * x + Q * N1
    # N2a = NL(v1)
    # v2a = E2 * x + Q * N2a
    # N2b = NL(v2a)
    # v2b = E2 * v1 + Q * (2 * N2b - N1)
    # N3 = NL(v2b)
    # jax.debug.print('{}-{}-{}-{}', N1, N2a, N2b, N3)
    # states_next = E * x + N1 * f1 + 2 * (N2a + N2b) * f2 + N3 * f3
    # states_next = states_next.real
    # jax.debug.print('{}', states_next)
    # return states_next
    return NL(x) + L*x

def f(x, t, D, *args):
    keyt = jax.random.fold_in(key_state, t)
    err = jax.random.normal(keyt, shape=(D,))
    xdot = fcoord(x, t) + err
    return xdot




In [25]:
p_err = 0.01
corrupted = True
xs_mc = []
ys_mc = []
ys_corrupted_mc = []
key_obs, key_mc = jax.random.split(key_obs)
for i in range(n_mc):
    key_mc, subkey = jax.random.split(key_mc)  # 生成新的子随机数生成器
    x0 = jax.random.normal(key_mc, (D,))
    xs1 = jnp.zeros((N,) + x0.shape)
    xs1 = xs1.at[0].set(x0)
    fpart = partial(f, D=dim_state)
    xs1 = rk4(xs1, dt, N, fpart)
    ys1 = xs1 + jax.random.normal(key_measurement, xs1.shape)  
    ys1_corrupted = ys1.copy()
    if corrupted:
        errs_map = jax.random.bernoulli(key_init, p=p_err, shape=ys1_corrupted.shape)
        ys1_corrupted = ys1_corrupted * (~errs_map) + 100.0 * errs_map
        err_where = np.where(errs_map)
    xs_mc.append(xs1)
    ys_mc.append(ys1)
    ys_corrupted_mc.append(ys1_corrupted)
statev = jnp.array(xs_mc)
yv = jnp.array(ys_mc)
yv_corrupted = jnp.array(ys_corrupted_mc)
print(statev.shape, statev[0][1][-5:])
print(yv.shape)
print(yv_corrupted.shape)

  return lax_internal._convert_element_type(out, dtype, weak_type)
  return lax_internal._convert_element_type(out, dtype, weak_type)
  return lax_internal._convert_element_type(out, dtype, weak_type)
  return lax_internal._convert_element_type(out, dtype, weak_type)
  return lax_internal._convert_element_type(out, dtype, weak_type)
  return lax_internal._convert_element_type(out, dtype, weak_type)
  return lax_internal._convert_element_type(out, dtype, weak_type)
  return lax_internal._convert_element_type(out, dtype, weak_type)
  return lax_internal._convert_element_type(out, dtype, weak_type)
  return lax_internal._convert_element_type(out, dtype, weak_type)
  return lax_internal._convert_element_type(out, dtype, weak_type)
  return lax_internal._convert_element_type(out, dtype, weak_type)
  return lax_internal._convert_element_type(out, dtype, weak_type)
  return lax_internal._convert_element_type(out, dtype, weak_type)
  return lax_internal._convert_element_type(out, dtype, weak_t

(100, 200, 72) [nan nan nan nan nan]
(100, 200, 72)
(100, 200, 72)


In [26]:
ys = yv.reshape(-1, dim_obs)
ys_corrupted = yv_corrupted.reshape(-1, dim_obs)
xs = statev.reshape(-1, dim_state)
print(xs.shape,ys.shape)

(20000, 72) (20000, 72)


In [14]:
# 生成时间序列
start_time = datetime(2024, 7, 2, 13, 0, 0)
time_interval = timedelta(seconds=1)
time_series = [start_time + i * time_interval for i in range(xs.shape[0])]

# 创建 DataFrame
column_names = [f'x_{i}' for i in range(1, xs.shape[1] + 1)]
df = pd.DataFrame(xs, columns=column_names)
df.insert(0, 'date', time_series)

print(df.head())

                 date       x_1       x_2       x_3       x_4       x_5  \
0 2024-07-02 13:00:00 -0.770176 -1.762404  1.182988 -0.440464 -0.250118   
1 2024-07-02 13:00:01       NaN       NaN       NaN       NaN       NaN   
2 2024-07-02 13:00:02       NaN       NaN       NaN       NaN       NaN   
3 2024-07-02 13:00:03       NaN       NaN       NaN       NaN       NaN   
4 2024-07-02 13:00:04       NaN       NaN       NaN       NaN       NaN   

        x_6      x_7       x_8       x_9  ...      x_23      x_24      x_25  \
0  0.875375  3.20674 -0.712694  0.565648  ...  0.397486 -1.181172  0.197003   
1       NaN      NaN       NaN       NaN  ...       NaN       NaN       NaN   
2       NaN      NaN       NaN       NaN  ...       NaN       NaN       NaN   
3       NaN      NaN       NaN       NaN  ...       NaN       NaN       NaN   
4       NaN      NaN       NaN       NaN  ...       NaN       NaN       NaN   

       x_26      x_27      x_28      x_29      x_30      x_31      x_32  


In [15]:
# # # 保存为 CSV 文件
# csv_path = save_dir + 'KS.csv'
# df.to_csv(csv_path, index=False)
# # 创建 DataFrame
# column_names = [f'y_{i}' for i in range(1, ys.shape[1] + 1)]
# dfy = pd.DataFrame(ys, columns=column_names)
# dfy.insert(0, 'date', time_series)

# print(dfy.head())
# csv_path = save_dir + 'KS_obs.csv'
# dfy.to_csv(csv_path, index=False)

# column_names = [f'y_{i}' for i in range(1, ys_corrupted.shape[1] + 1)]
# dfy = pd.DataFrame(ys_corrupted, columns=column_names)
# dfy.insert(0, 'date', time_series)

# print(dfy.head())
# csv_path = save_dir + 'KS_obs_corrupted.csv'
# dfy.to_csv(csv_path, index=False)

In [16]:
range_time = np.arange(N) * dt
np.set_printoptions(precision=4)

def callback_fn(particles, particles_pred, y, i):
    return jnp.sqrt(jnp.power(particles.mean(axis=0) - xs[i], 2).mean()), particles.mean(axis=0)

def latent_fn(x, key, i):
    """
    State function
    """
    err = jax.random.normal(key, (dim_state,))
    @jax.jit
    def f(x, t):
        return fcoord(x, t)  + err
    
    return rk4_step(x, i, dt, f)

def ff(x):
    return fcoord(x, dt)

def hh(x, t):
    return H @ x

def obs_fn(x, key, i):
    """
    Measurement function
    """
    err = jax.random.normal(key, (dim_obs,))
    return hh(x, dt) + err

def calculate_mse(errs):
    return jnp.mean(jnp.square(errs))
def calculate_mae(errs):
    return jnp.mean(jnp.abs(errs))

def calculate_rmse(errs):
    return jnp.sqrt(jnp.mean(jnp.square(errs)))

def calculate_test_error(errs):
    mse = calculate_mse(errs)
    mae = calculate_mae(errs)
    rmse = calculate_rmse(errs) 
    print(f'mse:{mse}, mae:{mae}, rmse:{rmse}')
    return mse, mae, rmse


In [17]:
# y_index = int(0.8 * ys.shape[0])
# print(ys.shape)
# # 截取最后 20% 的数据
# ys = ys[y_index:]
# ys_corrupted = ys_corrupted[y_index:]
# print(ys.shape)
time_methods = {}
hist_methods = {}
errs_methods = {}
configs_methods = {}

In [18]:
def filter_enkf(x0, n_particles, measurements, state, key_eval):
    agent = enkf.EnsembleKalmanFilter(latent_fn, obs_fn, n_particles)
    key_init_particles, key_scan = jax.random.split(key_eval, 2)
    X0 = agent.init_bel(key_init_particles, dim_state)
    particles_end, (errs, particles_hist_mean) = agent.scan(X0, key_scan, measurements, callback_fn=callback_fn)
    return errs, particles_hist_mean

def filter_enkfi(x0, inflation_factor, n_particles, measurements, state, key_eval):
    agent = enkf.EnsembleKalmanFilterInflation(latent_fn, obs_fn, n_particles, inflation_factor=inflation_factor)
    key_init_particles, key_scan = jax.random.split(key_eval, 2)
    X0 = agent.init_bel(key_init_particles, dim_state)
    particles_end, (errs, particles_hist_mean) = agent.scan(X0, key_scan, measurements, callback_fn=callback_fn)
    return errs, particles_hist_mean

def filter_ekf(x0, measurements, state, key_eval):
    nsteps = len(measurements)
    agent_imq = rkf.ExtendedKalmanFilterIMQ(
        ff, hh,
        dynamics_covariance=jnp.eye(dim_state),
        observation_covariance=jnp.eye(dim_obs),
        soft_threshold=1e8,
    )
    init_bel = agent_imq.init_bel(x0, cov=1.0)
    filterfn = partial(agent_imq.scan, callback_fn=callbacks.get_updated_mean)
    _, hist = filterfn(init_bel, measurements, jnp.ones(N))

    err = jnp.sqrt(jnp.power(hist - state, 2).sum(axis=0))
    return err, hist

@jax.jit
def filter_ekf(x0, measurements, state, key_eval):
    nsteps = len(measurements)
    agent_imq = rkf.ExtendedKalmanFilterIMQ(
        ff, hh,
        dynamics_covariance=jnp.eye(dim_state),
        observation_covariance=jnp.eye(dim_obs),
        soft_threshold=1e8,
    )
    init_bel = agent_imq.init_bel(x0, cov=1.0)
    filterfn = partial(agent_imq.scan, callback_fn=callbacks.get_updated_mean)
    _, hist = filterfn(init_bel, measurements, jnp.ones(N))

    err = jnp.sqrt(jnp.power(hist - state, 2).sum(axis=0))
    return err, hist

@jax.jit
def filter_wlfmd(x0, threshold, measurements, state, key_eval):
    agent = rkf.ExtendedKalmanFilterMD(
        ff, hh,
        dynamics_covariance=jnp.eye(dim_state),
        observation_covariance=jnp.eye(dim_obs),
        threshold=threshold
    )
    
    init_bel = agent.init_bel(x0, cov=1e-8)
    
    _, hist = agent.scan(
        init_bel, measurements, jnp.ones(N), callback_fn=callbacks.get_updated_mean
    )

    err = jnp.sqrt(jnp.power(hist - state, 2).sum(axis=0))
    return err, hist

In [19]:
method = 'EKF'
hist_bel = []
times = []
errs = 0
for y, state in tqdm(zip(yv, statev), total=n_mc): 
    tinit = time()
    key_eval, key_scan = jax.random.split(key_eval)
    errs, _ = filter_ekf(state[0], y, state, key_eval)
    tend = time()
    
    hist_bel.append(errs)
    times.append(tend - tinit)

errs = np.stack(hist_bel)
time_methods[method] = np.sum(times)
print(method)
errs_test = errs[int(errs.shape[0] * 0.8):]
_, _, errs_methods[method] = calculate_test_error(errs_test)

  0%|          | 0/1 [00:00<?, ?it/s]


TypeError: jacrev requires real-valued outputs (output dtype that is a sub-dtype of np.floating), but got complex64. For holomorphic differentiation, pass holomorphic=True. For differentiation of non-holomorphic functions involving complex outputs, use jax.vjp directly.

In [None]:
@jax.jit
def bo_filter_wlfmd(threshold):
    err, _ = filter_wlfmd(statev[1][0], threshold, yv[1], statev[1], key_eval)
    return -err.max()

bo = BayesianOptimization(
    bo_filter_wlfmd,
    pbounds={
        "threshold": (1e-6, 50)
    },
    random_state=314,
    verbose=1
)
bo.maximize(init_points=10, n_iter=20)
print(bo.max)

|   iter    |  target   | threshold |
-------------------------------------


TypeError: jacrev requires real-valued outputs (output dtype that is a sub-dtype of np.floating), but got complex64. For holomorphic differentiation, pass holomorphic=True. For differentiation of non-holomorphic functions involving complex outputs, use jax.vjp directly.

In [None]:
method = 'WLFMD'
threshold = bo.max["params"]["threshold"]
configs_methods[method] = bo.max["params"]
hist_bel = []
times = []
errs = 0
for y, state in tqdm(zip(yv, statev), total=n_mc): 
    tinit = time()
    key_eval, key_scan = jax.random.split(key_eval)
    errs, _ = filter_wlfmd(state[0], threshold, y, state, key_eval)
    tend = time()
    
    hist_bel.append(errs)
    times.append(tend - tinit)

errs = np.stack(hist_bel)
time_methods[method] = np.sum(times)
print(method)
errs_test = errs[int(errs.shape[0] * 0.8):]
_, _, errs_methods[method] = calculate_test_error(errs_test)

  0%|          | 0/128 [00:00<?, ?it/s]

100%|██████████| 128/128 [00:03<00:00, 39.66it/s]

WLFMD
mse:1324.28369140625, mae:33.05164337158203, rmse:36.390708923339844





In [None]:
print('rmse', errs_methods)
print('time', time_methods)

rmse {'EKF': Array(36.3907, dtype=float32), 'WLFMD': Array(36.3907, dtype=float32)}
time {'EKF': 3.208562135696411, 'WLFMD': 3.20003080368042}


In [None]:
# method = 'EnKF_1000'
# n_particles = 1000
# hist_bel = []
# times = []

# for y, state in tqdm(zip(yv, statev), total=n_mc): 
#     print(y.shape)
#     print(state.shape)
#     tinit = time()
#     key_eval, key_scan = jax.random.split(key_eval)
#     # print(xs[0])
#     errs, _ = filter_enkf(state[0], n_particles, y, state, key_eval)
#     tend = time()
    
#     hist_bel.append(errs)
#     times.append(tend - tinit)

# errs = np.stack(hist_bel)
# time_methods[method] = times
# print(method)
# errs_test = errs[int(errs.shape[0] * 0.8):]
# _, _, errs_methods[method] = calculate_test_error(errs_test)


In [None]:
# method = 'EnKFI_1000'
# n_particles = 1000
# inflation_factor = 3.0
# hist_bel = []
# times = []

# for y, state in tqdm(zip(yv, statev), total=n_mc): 
#     tinit = time()
#     key_eval, key_scan = jax.random.split(key_eval)
#     errs, _ = filter_enkfi(state[0], inflation_factor, n_particles, y, state, key_eval)
#     tend = time()
    
#     hist_bel.append(errs)
#     times.append(tend - tinit)

# errs = np.stack(hist_bel)
# time_methods[method] = times
# print(method)
# errs_test = errs[int(errs.shape[0] * 0.8):]
# _, _, errs_methods[method] = calculate_test_error(errs_test)

In [None]:
print(errs.shape)
print(errs_test.shape)

(128, 4)
(26, 4)


OSError: model/Qwen2-7B-Instruct is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `huggingface-cli login` or by passing `token=<your_token>`

{'EKF': Array(36.3907, dtype=float32), 'WLFMD': Array(36.3907, dtype=float32)}


In [None]:
# # 将 JAX 数组转换为 Python 的 float 类型
# def convert_jax_to_python(dictionary):
#     return {key: float(value) for key, value in dictionary.items()}

# # 转换字典
# converted_errs_methods = {method: convert_jax_to_python(metrics) for method, metrics in errs_methods.items()}

# for keys, values in converted_errs_methods.items():
#     print(keys, values)
