In [None]:
# Reload edited Python modules automatically before each cell executes,
# so changes under src/ are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the cell output.
%matplotlib inline

# Import

In [None]:
import sys
from logging import INFO, StreamHandler, getLogger

# Root logger (no name) writing to stdout at INFO level. The hasHandlers()
# guard stops a re-run of this cell from attaching a second handler, which
# would print every log line twice.
logger = getLogger()
logger.setLevel(INFO)
if not logger.hasHandlers():
    logger.addHandler(StreamHandler(sys.stdout))

In [None]:
import os
import pathlib

import matplotlib.pyplot as plt
import torch
from src.data_assimilation.enkf.enkf import EnKFPO
from src.data_assimilation.enkf.utils.enkf_config import EnKFPOConfig
from src.lorenz63_model.lorenz63_model import Lorenz63
from src.lorenz63_model.utils.lorenz63_config import Lorenz63Config
from src.utils.random_seed_helper import set_seeds
from tqdm.notebook import tqdm

# Deterministic cuBLAS kernels require a fixed workspace configuration; it is
# exported here so it is in the environment before any CUDA work runs below.
os.environ.update({"CUBLAS_WORKSPACE_CONFIG": r":4096:8"})
# Use a serif font family for every figure in this notebook.
plt.rcParams.update({"font.family": "serif"})

# Define constants

In [None]:
# Project root = parent directory of the first PYTHONPATH entry.
# PYTHONPATH may list several directories separated by os.pathsep; passing the
# whole string to Path() would yield a nonsense root, so only the first is used.
# NOTE(review): assumes the first entry is the repository's source dir — confirm.
ROOT_DIR = pathlib.Path(os.environ["PYTHONPATH"].split(os.pathsep)[0]).parent.resolve()

# Output directory for figures (created eagerly so savefig calls cannot fail on it).
fig_dir = f"{ROOT_DIR}/docs/data_assimilation/fig"
os.makedirs(fig_dir, exist_ok=True)

# Device name string passed to the model / DA configs.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Define helper functions

In [None]:
def calc_rmse(*, gt: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
    """Root-mean-square error between a reference and a prediction.

    RMSE = sqrt(mean((gt - pred) ** 2)), averaged over every element.

    Args:
        gt: Reference (ground-truth) values.
        pred: Predicted values; must have exactly the same shape as ``gt``.

    Returns:
        0-d tensor holding the RMSE, on the inputs' device.

    Raises:
        ValueError: If ``gt`` and ``pred`` differ in shape. Silent broadcasting
            would average over the wrong set of elements.
    """
    if gt.shape != pred.shape:
        raise ValueError(
            f"shape mismatch: gt {tuple(gt.shape)} vs pred {tuple(pred.shape)}"
        )

    # torch.mean already divides by the element count. The former extra
    # `/ gt.numel()` shrank the result by a factor of sqrt(numel).
    mse = torch.mean((gt - pred) ** 2)

    return torch.sqrt(mse)

# Prepare for simulation

In [None]:
# Fix the random seed before any random draws below. use_deterministic=True
# presumably also requests deterministic torch kernels — see
# src.utils.random_seed_helper.set_seeds.
SEED = 42
set_seeds(use_deterministic=True, seed=SEED)

In [None]:
# Ensemble size: batch size of the ensemble model and n_ensemble of the assimilator.
n_batch = 42

# Settings shared by both Lorenz-63 models and the assimilator.
shared_kwargs = dict(device=DEVICE, precision="double")
lorenz_kwargs = dict(noise_amplitude=1.0, **shared_kwargs)

# "hr" model: the ensemble, one batch member per ensemble member.
cfg_hr_lorenz = Lorenz63Config(n_batch=n_batch, **lorenz_kwargs)

# "uhr" model: the single reference trajectory used as ground truth.
cfg_uhr_lorenz = Lorenz63Config(n_batch=1, **lorenz_kwargs)

# EnKF-PO assimilator. obs_matrix is the 3x3 identity, i.e. x, y and z are all
# observed directly.
cfg_da = EnKFPOConfig(
    n_ensemble=n_batch,
    inflation_coefficient=1.2,
    obs_std=1.0,
    obs_matrix=[[1, 0, 0], [0, 1, 0], [0, 0, 1]],
    **shared_kwargs,
)

In [None]:
# Build both models and the assimilator from the configs above. Construction
# order is kept as-is in case constructors consume RNG state.
# show_input_cfg_info=False presumably suppresses the config printout.
quiet = dict(show_input_cfg_info=False)
hr_model = Lorenz63(cfg_hr_lorenz, **quiet)  # ensemble model
uhr_model = Lorenz63(cfg_uhr_lorenz, **quiet)  # ground-truth model
assimilator = EnKFPO(cfg_da, **quiet)

# Generate ground-truth data

In [None]:
# Initial condition of the reference (ground-truth) trajectory.
X0 = torch.tensor([11.2, 10.2, 33.2], dtype=uhr_model.real_dtype).to(uhr_model.device)
uhr_model.initialize(X0)

Xgt, tgt = [uhr_model.get_state()], [uhr_model.t]

uhr_dt = 0.001  # integration time step of the truth model
output_uhr_dt = 0.001  # interval between stored truth states
end_time = 30

# Integer step counts via round(). int() truncation of a float ratio, or a
# float-step arange grid, can silently drop a step when the quotient lands just
# below an integer (e.g. 0.3 / 0.1 == 2.9999999999999996).
n_steps_per_output = round(output_uhr_dt / uhr_dt)
n_outputs = round(end_time / output_uhr_dt)
assert n_steps_per_output >= 1, "output_uhr_dt must be >= uhr_dt"

for _ in tqdm(range(n_outputs)):
    uhr_model.integrate_n_steps(dt_per_step=uhr_dt, n_steps=n_steps_per_output)
    Xgt.append(uhr_model.get_state())
    tgt.append(uhr_model.t)

# Stack arrays along a new time dim, then squeeze out the size-1 batch dim.
Xgt = torch.stack(Xgt, dim=1).squeeze()

# shape = (time, (x, y, z)) after squeezing the single-member batch dim
logger.info(f"Shape of the result: {Xgt.shape}")

# Data assimilation (データ同化)

In [None]:
# Perturbed initial condition for the ensemble model. The randn call is kept
# exactly as before (CPU, default dtype) so seeded results are unchanged.
# NOTE(review): a single (3,) perturbation is drawn; whether `initialize`
# broadcasts it to all ensemble members (zero initial spread) is not visible
# here — confirm in Lorenz63.initialize.
X0 = torch.tensor([11.2, 10.2, 33.2], dtype=hr_model.real_dtype).to(hr_model.device)
X0 = X0 + torch.randn(X0.shape).to(hr_model.device)

hr_model.initialize(X0)

Xa, ta = [hr_model.get_state()], [hr_model.t]
all_obs = []

hr_dt = 0.01  # integration time step of the ensemble model
output_hr_dt = 0.01  # interval between stored analysis states
end_time = 30
obs_interval = 10  # assimilate once every `obs_interval` output steps

# Xgt is stored every output_uhr_dt, so the stride onto this cell's output grid
# is the ratio of the *output* intervals. round() instead of float `//`, which
# can land one below the intended integer (e.g. 1 // 0.1 == 9.0).
scale_factor = round(output_hr_dt / output_uhr_dt)
Xtrue = Xgt[::scale_factor]

n_steps = round(end_time / output_hr_dt)
n_substeps = round(output_hr_dt / hr_dt)
assert n_substeps >= 1, "output_hr_dt must be >= hr_dt"
# Every stored analysis state (initial + n_steps) must have a truth state at the same time.
assert Xtrue.shape[0] == n_steps + 1, (tuple(Xtrue.shape), n_steps)

for i in tqdm(range(1, n_steps + 1)):
    hr_model.integrate_n_steps(dt_per_step=hr_dt, n_steps=n_substeps)
    if i % obs_interval == 0:
        # Analysis step: the assimilator receives the forecast ensemble and the
        # truth state at this time. It returns the analysis and the observation
        # it used (presumably synthesised from Xtrue with obs_std noise).
        Xa_t, obs = assimilator.apply(Xf=hr_model.get_state(), Xtrue=Xtrue[i])
        Xa.append(Xa_t)
        all_obs.append(obs)
        hr_model.set_state(Xa_t)
    else:
        Xa.append(hr_model.get_state())
    ta.append(hr_model.t)

# Stack arrays along time dim
Xa = torch.stack(Xa, dim=1).squeeze()
all_obs = torch.stack(all_obs, dim=1).squeeze()

# shape = (batch, time, (x, y, z))
logger.info(f"Shape of the result: {Xa.shape}")

# Plot results

In [None]:
# Ensemble mean of the analysis trajectory (average over dim 0, the ensemble/batch dim).
Xa_mean = Xa.mean(dim=0)

# Ensemble mean of the observations returned by the assimilator.
obs_mean = all_obs.mean(dim=0)

# Scalar RMSE of the ensemble-mean trajectory against the ground truth.
rmse = calc_rmse(pred=Xa_mean, gt=Xtrue)

In [None]:
plt.rcParams["font.size"] = 24

# matplotlib cannot consume CUDA tensors, so move everything to host memory.
# float(t) accepts plain floats as well as 0-d tensors; the type of model.t is
# not visible here — NOTE(review): confirm it is a scalar.
t_plot = [float(t) for t in ta]
Xtrue_plot = Xtrue.cpu()
Xa_mean_plot = Xa_mean.cpu()
obs_mean_plot = obs_mean.cpu()

# The DA loop assimilates after steps k, 2k, 3k, ... (1-based), which are
# ta[k], ta[2k], .... The former ta[1::k] put the markers k-1 steps too early.
# k is recovered from the array lengths so this cell stays in sync with the loop.
obs_stride = (len(t_plot) - 1) // len(obs_mean_plot)
t_obs = t_plot[obs_stride::obs_stride]

fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(15, 20))
for dim, (ax, ylabel) in enumerate(zip(axes, ["x", "y", "z"])):
    ax.plot(t_plot, Xtrue_plot[:, dim], label="ground truth")
    ax.plot(t_plot, Xa_mean_plot[:, dim], label="prediction")
    ax.plot(t_obs, obs_mean_plot[:, dim], "*", label="observation")
    ax.set_xlabel(r"$t$")
    ax.set_ylabel(rf"${ylabel}$")
    ax.legend(loc=3)

# fig.savefig(f"{fig_dir}/xyz_trajectory_with_da_plot.png")
plt.show()

In [None]:
plt.rcParams["font.size"] = 18

# float() converts a (possibly CUDA) 0-d tensor into a plain number matplotlib accepts.
rmse_value = float(rmse)

fig, ax = plt.subplots(figsize=(5, 4), facecolor="white")

bars = ax.bar(["EnKF-PO"], [rmse_value])

for b in bars:
    height = b.get_height()
    # Label each bar with its value, centred just above the bar top.
    ax.text(b.get_x() + b.get_width() / 2, height, f"{height:.4f}",
            ha="center", va="bottom")

ax.set_ylabel("RMSE")
# Keep the original 0.03 upper limit for small errors, but never clip the bar or its label.
ax.set_ylim(0, max(0.03, 1.2 * rmse_value))

# fig.savefig(f"{fig_dir}/rmse_enkf_po_method.png")

plt.show()