In [1]:
from time import time
import jax
import numpy as np
import pandas as pd
import seaborn as sns
import jax.numpy as jnp
import matplotlib.pyplot as plt
from functools import partial
from datetime import datetime, timedelta
import ensemble_kalman_filter as enkf
import os
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
plt.style.use("default")
plt.rcParams["font.size"] = 18

In [2]:
model_name = 'VL20'
save_flag = True
save_dir ='../dataset/'+ model_name +'/'
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

In [3]:
def rk4_step(y, i, dt, f):
    h = dt
    t = dt * i
    k1 = h * f(y, t)
    k2 = h * f(y + k1 / 2, dt * i + h / 2)
    k3 = h * f(y + k2 / 2, t + h / 2)
    k4 = h * f(y + k3, t + h)

    y_next = y + 1 / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return y_next

    

@partial(jax.jit, static_argnames=("f",))
def rk4(ys, dt, N, f):
    """
    Based on
    https://colab.research.google.com
    github/google/jax/blob/master/cloud_tpu_colabs/Lorentz_ODE_Solver
    """
    @jax.jit
    def step(i, ys):
        ysi = rk4_step(ys[i - 1], i, dt, f)
        return ys.at[i].set(ysi)
    return jax.lax.fori_loop(1, N, step, ys)

In [4]:
F = 10
G = 0
alpha = 1
gamma = 1
N = 20000
dt = 0.01
nX = 36
dim_state = 2 * nX
dim_obs = dim_state
model_name = 'VL20'
key = jax.random.PRNGKey(31415)
key_init, key_sim, key_eval, key_obs = jax.random.split(key, 4)
key_state, key_measurement = jax.random.split(key_sim)

def shift(x, n):
    return jnp.roll(x, -n, axis=-1)

def fcoord(x, t):
    X = x[:nX]
    theta = x[nX:]
    d = jnp.zeros_like(x)
    d.at[:nX].set((shift(X, 1)-shift(X, -2))*shift(X, -1))
    d.at[:nX].set(d[:nX] - gamma*shift(X, 0)) 
    d.at[:nX].set(d[:nX] - alpha*shift(theta, 0) + F)
    d.at[nX:].set(shift(X, 1)*shift(theta, 2) \
        - shift(X, -1)*shift(theta, -2))
    d.at[nX:].set(d[nX:] - gamma*shift(theta, 0))
    d.at[nX:].set(d[nX:] + alpha*shift(X, 0) + G) 
    return d

def f(x, t, D, *args):
    keyt = jax.random.fold_in(key_state, t)
    err = jax.random.normal(keyt, shape=(D,))
    xdot = fcoord(x, t) + F + err
    return xdot

x0 = jax.random.normal(key_init, (dim_state,)) + F
xs = jnp.zeros((N,) + x0.shape)
xs = xs.at[0].set(x0)
fpart = partial(f, D=dim_state)
xs = rk4(xs, dt, N, fpart)
ys = xs + jax.random.normal(key_measurement, xs.shape)
# hpart = partial(h, D=dim_state)
# ys = h_step(xs, dt, N, hpart)

print(xs.shape)
print(ys.shape)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


(20000, 72)
(20000, 72)


In [5]:
corrupted = True
ys_corrupted = ys.copy()
sigma = 5
value_corrupted = ys = xs + sigma * jax.random.normal(key_measurement, xs.shape)
if corrupted:
    p_err = 0.01
    errs_map = jax.random.bernoulli(key_init, p=p_err, shape=ys_corrupted.shape)
    ys_corrupted = ys_corrupted * (~errs_map) + value_corrupted * errs_map
    err_where = np.where(errs_map)
# print(ys_corrupted.shape)
# print(ys_corrupted[0, 0:3])
# print(errs_map)

In [6]:
# 生成时间序列
D = dim_state
start_time = datetime(2024, 6, 14, 13, 0, 0)
time_interval = timedelta(seconds=1)
time_series = [start_time + i * time_interval for i in range(xs.shape[0])]

# 创建 DataFrame
column_names = [f'x_{i}' for i in range(1, xs.shape[1] + 1)]
df = pd.DataFrame(xs, columns=column_names)
df.insert(0, 'date', time_series)

# print(df.head())

In [7]:
if save_flag:
    # # 保存为 CSV 文件
    csv_path = save_dir + model_name + '.csv'
    df.to_csv(csv_path, index=False)
    # 创建 DataFrame
    column_names = [f'y_{i}' for i in range(1, ys.shape[1] + 1)]
    dfy = pd.DataFrame(ys, columns=column_names)
    dfy.insert(0, 'date', time_series)

    print(dfy.head())
    csv_path = save_dir + model_name + '_obs.csv'
    dfy.to_csv(csv_path, index=False)

    column_names = [f'y_{i}' for i in range(1, ys_corrupted.shape[1] + 1)]
    dfy = pd.DataFrame(ys_corrupted, columns=column_names)
    dfy.insert(0, 'date', time_series)

    print(dfy.head())
    csv_path = save_dir + model_name + '_obs_corrupted.csv'
    dfy.to_csv(csv_path, index=False)

                 date        y_1        y_2        y_3        y_4        y_5  \
0 2024-06-14 13:00:00  13.531202  12.347681  14.767914   0.894815   1.283921   
1 2024-06-14 13:00:01  17.423916  13.602861  11.208827  15.064735  13.979038   
2 2024-06-14 13:00:02   5.704052  11.692755  12.551895  15.514514   3.720289   
3 2024-06-14 13:00:03   4.390932   2.249312   8.807728   8.446161  17.210573   
4 2024-06-14 13:00:04  15.862354  26.063387   7.261002   9.383172   6.261335   

         y_6        y_7        y_8        y_9  ...       y_63       y_64  \
0   8.614499  11.102380   9.019107   9.132580  ...   7.562762   6.653304   
1  16.530704  11.876959   3.987224  12.963837  ...   2.782481   7.516932   
2  16.369982   9.797428  10.435306  12.086037  ...   7.882241   9.108672   
3   4.871646  12.198980  11.974627  10.558963  ...  11.322978  16.244171   
4  12.507937  15.747322   2.190125  14.218163  ...  13.422709  10.404652   

        y_65       y_66       y_67       y_68       y_69      

In [8]:
range_time = np.arange(N) * dt
np.set_printoptions(precision=4)
def callback_fn(particles, particles_pred, y, i):
    return jnp.sqrt(jnp.power(particles.mean(axis=0) - xs[i], 2).mean()), particles.mean(axis=0)
def latent_fn(x, key, i):
    """
    State function
    """
    err = jax.random.normal(key, (D,))
    @jax.jit
    def f(x, t):
        return fcoord(x, D) + F + err
    
    return rk4_step(x, i, dt, f)


def obs_fn(x, key, i):
    """
    Measurement function
    """
    err = jax.random.normal(key, (D,))
    return x + err

def calculate_mse(errs):
    return jnp.mean(jnp.square(errs))
def calculate_mae(errs):
    return jnp.mean(jnp.abs(errs))

def calculate_rmse(errs):
    return jnp.sqrt(jnp.mean(jnp.square(errs)))

def calculate_test_error(errs):
    mse = calculate_mse(errs)
    mae = calculate_mae(errs)
    rmse = calculate_rmse(errs) 
    print(f'mse:{mse}, mae:{mae}, rmse:{rmse}')
    return mse, mae, rmse

In [9]:
# y_index = int(0.8 * ys.shape[0])
# print(ys.shape)
# # 截取最后 20% 的数据
# ys = ys[y_index:]
# ys_corrupted = ys_corrupted[y_index:]
# print(ys.shape)

In [10]:
method = 'EnKF_20'
n_particles = 20
agent = enkf.EnsembleKalmanFilter(latent_fn, obs_fn, n_particles)
key_init_particles, key_scan = jax.random.split(key_eval, 2)
X0 = agent.init_bel(key_init_particles, D)
tinit = time()
particles_end, (errs, particles_hist_mean) = agent.scan(X0, key_scan, ys, callback_fn=callback_fn)
tend = time()
time_cost = tend - tinit
errs_index = int(0.8 * errs.shape[0])
calculate_test_error(errs[errs_index:])
print('time cost (ms):', time_cost)
tinit = time()
particles_end, (errs, particles_hist_mean) = agent.scan(X0, key_scan, ys_corrupted, callback_fn=callback_fn)
tend = time()
time_cost = (tend - tinit) / 0.2 / N * 1000
print('corrupted')
calculate_test_error(errs[errs_index:])
print('time cost (ms):', time_cost)

mse:nan, mae:nan, rmse:nan
time cost (ms): 0.5698678493499756
corrupted
mse:nan, mae:nan, rmse:nan
time cost (ms): 0.13200724124908447


In [11]:
method = 'EnKF_1000'
n_particles = 1000
agent = enkf.EnsembleKalmanFilter(latent_fn, obs_fn, n_particles)
key_init_particles, key_scan = jax.random.split(key_eval, 2)
X0 = agent.init_bel(key_init_particles, D)
tinit = time()
particles_end, (errs, particles_hist_mean) = agent.scan(X0, key_scan, ys, callback_fn=callback_fn)
tend = time()
time_cost = tend - tinit
calculate_test_error(errs[errs_index:])
print('time cost (ms):', time_cost)
tinit = time()
particles_end, (errs, particles_hist_mean) = agent.scan(X0, key_scan, ys_corrupted, callback_fn=callback_fn)
tend = time()
time_cost = tend - tinit
print('corrupted')
calculate_test_error(errs[errs_index:])
print('time cost (ms):', time_cost)

mse:3.9054126739501953, mae:1.9699300527572632, rmse:1.976211667060852
time cost (ms): 0.446089506149292
corrupted
mse:2.726518154144287, mae:1.6457977294921875, rmse:1.6512172222137451
time cost (ms): 0.7171604633331299


In [12]:
method = 'EnKFI_20_3'
n_particles = 20
inflation_factor = 3.0
agent = enkf.EnsembleKalmanFilterInflation(latent_fn, obs_fn, n_particles, inflation_factor=inflation_factor)
# agent = enkf.EnsembleKalmanFilter(latent_fn, obs_fn, n_particles)
key_init_particles, key_scan = jax.random.split(key_eval, 2)
X0 = agent.init_bel(key_init_particles, D)
tinit = time()
particles_end, (errs, particles_hist_mean) = agent.scan(X0, key_scan, ys, callback_fn=callback_fn)
tend = time()
time_cost = tend - tinit
calculate_test_error(errs[errs_index:])
print('time cost (ms):', time_cost)
tinit = time()
particles_end, (errs, particles_hist_mean) = agent.scan(X0, key_scan, ys_corrupted, callback_fn=callback_fn)
tend = time()
time_cost = tend - tinit
print('corrupted')
calculate_test_error(errs[errs_index:])
print('time cost (ms):', time_cost)

mse:10.849430084228516, mae:3.275829315185547, rmse:3.2938473224639893
time cost (ms): 0.6817166805267334
corrupted
mse:0.5596656203269958, mae:0.7391964197158813, rmse:0.7481080293655396
time cost (ms): 0.6094875335693359


In [13]:
method = 'EnKFI_1000_3'
n_particles = 1000
inflation_factor = 3.0
agent = enkf.EnsembleKalmanFilterInflation(latent_fn, obs_fn, n_particles, inflation_factor=inflation_factor)
# agent = enkf.EnsembleKalmanFilter(latent_fn, obs_fn, n_particles)
key_init_particles, key_scan = jax.random.split(key_eval, 2)
X0 = agent.init_bel(key_init_particles, D)
tinit = time()
particles_end, (errs, particles_hist_mean) = agent.scan(X0, key_scan, ys, callback_fn=callback_fn)
tend = time()
time_cost = tend - tinit
calculate_test_error(errs[errs_index:])
print('time cost (ms):', time_cost)
tinit = time()
particles_end, (errs, particles_hist_mean) = agent.scan(X0, key_scan, ys_corrupted, callback_fn=callback_fn)
tend = time()
time_cost = tend - tinit
print('corrupted')
calculate_test_error(errs[errs_index:])
print('time cost (ms):', time_cost)

mse:8.013690948486328, mae:2.820892810821533, rmse:2.8308463096618652
time cost (ms): 0.7259926795959473
corrupted
mse:0.3973667323589325, mae:0.6230413317680359, rmse:0.6303703188896179
time cost (ms): 0.7065515518188477


In [14]:
print(errs[:20])

[2.6445 1.1461 0.6749 0.6959 0.6041 0.6148 0.5518 0.5083 0.5671 0.7428
 0.7839 0.6182 0.593  0.663  0.6123 0.5986 0.5182 0.7085 0.567  0.5363]


In [15]:
# n_particles = 20
# inflation_factor = 3.0
# agent = enkf.EnsembleKalmanFilterInflation(latent_fn, obs_fn, n_particles, inflation_factor=inflation_factor)
# # agent = enkf.EnsembleKalmanFilter(latent_fn, obs_fn, n_particles)
# key_init_particles, key_scan = jax.random.split(key_eval, 2)
# X0 = agent.init_bel(key_init_particles, D)
# tinit = time()
# particles_end, (errs, particles_hist_mean) = agent.scan(X0, key_scan, ys, callback_fn=callback_fn)
# tend = time()
# time_cost = tend - tinit
# calculate_test_error(errs)

In [16]:
# n_particles = 1000
# inflation_factor = 3.0
# agent = enkf.EnsembleKalmanFilterInflation(latent_fn, obs_fn, n_particles, inflation_factor=inflation_factor)
# # agent = enkf.EnsembleKalmanFilter(latent_fn, obs_fn, n_particles)
# key_init_particles, key_scan = jax.random.split(key_eval, 2)
# X0 = agent.init_bel(key_init_particles, D)
# tinit = time()
# particles_end, (errs, particles_hist_mean) = agent.scan(X0, key_scan, ys, callback_fn=callback_fn)
# tend = time()
# time_cost = tend - tinit
# calculate_test_error(errs)