In [None]:
# Enable the autoreload extension in mode 2 so that edited project modules
# are re-imported automatically before each cell is executed.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import sys
from logging import DEBUG, INFO, WARNING, StreamHandler, getLogger

# Configure the root logger: attach a stdout handler only when no handler
# whose text representation mentions "StreamHandler" is present, so re-running
# this cell does not duplicate log output.
logger = getLogger()
_has_stream_handler = any(
    "StreamHandler" in str(handler) for handler in logger.handlers
)
if not _has_stream_handler:
    logger.addHandler(StreamHandler(sys.stdout))
logger.setLevel(INFO)

# Import libraries

In [None]:
import glob
import os
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
from cfd_model.cfd.periodic_channel_domain import TorchSpectralModel2D
from cfd_model.enkf.sr_enkf import (
    assimilate_with_existing_data,
    calc_localization_matrix,
    get_multivariate_normal_sampler,
)
from cfd_model.initialization.periodic_channel_jet_initializer import calc_jet_forcing
from cfd_model.interpolator.torch_interpolator import (
    interpolate,
    interpolate_time_series,
)
from src.sr_da_helper import get_testdataset
from src.utils import read_pickle, set_seeds, write_pickle
from tqdm.notebook import tqdm

In [None]:
# Export the cuBLAS workspace setting (to make calculations deterministic)
# before seeding with deterministic mode switched on.
os.environ.update(CUBLAS_WORKSPACE_CONFIG=r":4096:8")
set_seeds(42, use_deterministic=True)

# Define constants

In [None]:
# The repository root is the parent of the directory named by PYTHONPATH.
# NOTE(review): assumes PYTHONPATH is set and holds a single directory - confirm.
_python_path = pathlib.Path(os.environ["PYTHONPATH"])
ROOT_DIR = str(_python_path.joinpath("..").resolve())
ROOT_DIR

In [None]:
# This notebook requires a GPU: without CUDA the device falls back to "cpu",
# the problem is logged, and the cell aborts by raising.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
    logger.info("GPU is used.")
else:
    DEVICE = "cpu"
    logger.error("No GPU. CPU is used.")
    raise Exception("No GPU. CPU is used.")

In [None]:
# EnKF settings keyed by observation grid interval. Each row of the table is
# (interval, (N_ENS, SYS_NOISE_STD, SYS_NOISE_DX, LOCALIZE_DX)).
ENKF_PREFERENCES = {
    interval: dict(
        zip(("N_ENS", "SYS_NOISE_STD", "SYS_NOISE_DX", "LOCALIZE_DX"), values)
    )
    for interval, values in (
        (4.0, (100, 0.5, 0.5, 0.3)),
        (6.0, (100, 0.5, 0.5, 0.3)),
        (8.0, (100, 0.5, 0.5, 0.3)),
        (10.0, (150, 0.7, 0.4, 0.7)),
        (12.0, (150, 0.5, 0.4, 0.7)),
    )
}

In [None]:
# The observation grid interval is embedded (zero-padded to two digits) in the
# experiment config name, which in turn names the YAML file to load.
OBS_GRID_INTEVAL = 12
CONFIG_NAME = (
    f"lt4og{OBS_GRID_INTEVAL:02}"
    "_on1e-01_ep1000_lr1e-04_scT_muT_a02_b02"
)
CONFIG_PATH = "/".join(
    [ROOT_DIR, "pytorch", "config", "prototype_experiment_14", f"{CONFIG_NAME}.yml"]
)

In [None]:
# Read the experiment configuration.
with open(CONFIG_PATH) as config_file:
    CONFIG = yaml.safe_load(config_file)

# Working directories, relative to the notebook's location.
TMP_DATA_DIR, CSV_DATA_DIR, FIG_DIR = "./data", "./csv", "./fig"

# EnKF outputs are kept in a per-configuration folder, created on demand.
ENKF_DIR = f"{TMP_DATA_DIR}/EnKF/{CONFIG_NAME}"
os.makedirs(ENKF_DIR, exist_ok=True)

In [None]:
CFD_DIR_NAME = "jet11"

# EnKF hyper-parameters selected by the observation grid interval.
_enkf_preference = ENKF_PREFERENCES[OBS_GRID_INTEVAL]
N_ENS = _enkf_preference["N_ENS"]
SYS_NOISE_STD = _enkf_preference["SYS_NOISE_STD"]
SYS_NOISE_DX = _enkf_preference["SYS_NOISE_DX"]
LOCALIZE_DX = _enkf_preference["LOCALIZE_DX"]

# The y-direction scales reuse the x-direction values.
SYS_NOISE_DY, LOCALIZE_DY = SYS_NOISE_DX, LOCALIZE_DX

ASSIMILATION_PERIOD = 4
START_TIME_INDEX = 16
SEED = 42

# Low-resolution (LR) grid size and time stepping.
LR_NX, LR_NY = 32, 17
LR_DT, LR_NT = 5e-4, 500

# High-resolution (HR) grid size.
HR_NX, HR_NY = 128, 65

# Parameters forwarded as "y0", "sigma" and "tau0" in LR_CFD_CONFIG.
Y0 = np.pi / 2.0
SIGMA = 0.4
TAU0 = 0.3

# Parameters forwarded to the LR model configuration.
BETA = 0.1
COEFF_LINEAR_DRAG = 1e-2
ORDER_DIFFUSION = 2
LR_COEFF_DIFFUSION = 5e-5

# Time covered by one cycle, and the time at START_TIME_INDEX.
DT = LR_DT * LR_NT
T0 = START_TIME_INDEX * LR_DT * LR_NT

# Keyword arguments shared by the LR model, the forcing and the initializer.
LR_CFD_CONFIG = dict(
    nx=LR_NX,
    ny=LR_NY,
    ne=N_ENS,
    lr_nx=LR_NX,
    lr_ny=LR_NY,
    hr_nx=HR_NX,
    hr_ny=HR_NY,
    n_ens=N_ENS,
    assimilation_period=ASSIMILATION_PERIOD,
    coeff_linear_drag=COEFF_LINEAR_DRAG,
    coeff_diffusion=LR_COEFF_DIFFUSION,
    order_diffusion=ORDER_DIFFUSION,
    beta=BETA,
    device=DEVICE,
    y0=Y0,
    sigma=SIGMA,
    tau0=TAU0,
    t0=T0,
)

In [None]:
# Record the EnKF settings selected for this run.
logger.info(
    f"N_ENS = {N_ENS}, SYS_NOISE_STD = {SYS_NOISE_STD}, "
    f"SYS_NOISE_DX = {SYS_NOISE_DX}, LOCALIZE_DX = {LOCALIZE_DX}"
)

In [None]:
# Lookup table keyed by an integer grid interval. Each non-zero value is
# approximately 1 / interval**2 (e.g. 4 -> 0.0625, 10 -> 0.01); key 0 maps to 0.0.
# NOTE(review): presumably the fraction of HR grid points observed for a given
# observation grid interval - confirm. Not referenced elsewhere in the visible cells.
OBS_GRID_RATIO = {
    0: 0.0,
    4: 0.06250000093132257,
    5: 0.03999999910593033,
    6: 0.027777777363856632,
    7: 0.02040816326530612,
    8: 0.015625000116415322,
    9: 0.012345679127323775,
    10: 0.010000000149011612,
    11: 0.008264463206306716,
    12: 0.006944444625534945,
    13: 0.005917159876284691,
    14: 0.005102040977882487,
    15: 0.004444444572759999,
    16: 0.003906250014551915,
}

# Define methods

In [None]:
def plot(
    dict_data: dict,
    t: float,
    obs: np.ndarray,
    figsize: list = [20, 2],
    write_out: bool = False,
    ttl_header: str = "",
    fig_file_name: str = "",
    use_hr_space: bool = True,
    vmin_omega: float = -9,
    vmax_omega: float = 9,
):
    """Draw one pseudocolor panel per entry of ``dict_data`` in a single row.

    Args:
        dict_data: Mapping from panel label to a 2D field. Labels containing
            "LR" are treated as fields on the (LR_NX, LR_NY) grid, all others
            as fields on the (HR_NX, HR_NY) grid. The entry labelled exactly
            "HR" is the reference; every panel drawn after it shows its mean
            absolute error (MAE) against that reference in the title.
        t: Time shown in the figure title.
        obs: Array with NaN where there is no observation. The non-NaN
            locations are marked on the "HR" panel when ``use_hr_space`` is True.
        figsize: Figure size passed to matplotlib.
        write_out: If True, save the figure as .jpg and .eps under FIG_DIR.
        ttl_header: Text placed in front of "Time = ..." in the figure title.
        fig_file_name: File name (without extension) used when ``write_out``.
        use_hr_space: If True, LR fields are interpolated to the HR grid with
            mode "nearest"; otherwise HR fields are interpolated to the LR grid.
        vmin_omega: Lower limit of the colour scale.
        vmax_omega: Upper limit of the colour scale.
    """
    # All panels share one grid, chosen by the plotting space.
    if use_hr_space:
        nx, ny = HR_NX, HR_NY
    else:
        nx, ny = LR_NX, LR_NY
    xs = np.linspace(0, 2 * np.pi, num=nx, endpoint=False)
    ys = np.linspace(0, np.pi, num=ny, endpoint=True)
    x, y = np.meshgrid(xs, ys, indexing="ij")

    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.size"] = 18

    # squeeze=False always yields a 2D array of Axes, so a single-panel figure
    # can be iterated just like a multi-panel one.
    fig, axes = plt.subplots(
        1,
        len(dict_data),
        figsize=figsize,
        sharex=True,
        sharey=False,
        squeeze=False,
    )
    axes = axes[0]

    gt = None
    for ax, (label, data) in zip(axes, dict_data.items()):
        is_lr = "LR" in label
        if is_lr and use_hr_space:
            data = interpolate(
                torch.from_numpy(data[None, :]), nx=HR_NX, ny=HR_NY, mode="nearest"
            ).numpy()
        elif not is_lr and not use_hr_space:
            data = interpolate(
                torch.from_numpy(data[None, :]), nx=LR_NX, ny=LR_NY
            ).numpy()

        d = np.squeeze(data)
        assert d.shape == (nx, ny), f"{label}: got {d.shape}, expected {(nx, ny)}"

        if label == "HR":
            gt = d
            ttl = "HR Ground Truth"
        elif gt is None:
            # No reference has been seen yet, so no MAE can be reported.
            ttl = label
        else:
            mae = np.mean(np.abs(gt - d))
            ttl = f"{label}\nMAE={mae:.2f}"

        cnts = ax.pcolormesh(
            x, y, d, cmap="twilight_shifted", vmin=vmin_omega, vmax=vmax_omega
        )
        ax.set_title(ttl)

        fig.colorbar(
            cnts,
            ax=ax,
            ticks=[vmin_omega, vmin_omega / 2, 0, vmax_omega / 2, vmax_omega],
            extend="both",
        )

        ax.set_xlim([0, 2 * np.pi])
        ax.set_ylim([0, np.pi])

        if label == "HR" and use_hr_space:
            # Mark the observed grid points (non-NaN entries of obs).
            is_observed = ~np.isnan(np.squeeze(obs).flatten())
            ax.scatter(
                x.flatten()[is_observed],
                y.flatten()[is_observed],
                marker=".",
                s=1,
                c="k",
            )

        # Panels are shown without tick labels or axes.
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

    fig.suptitle(f"{ttl_header}Time = {t}")
    fig.tight_layout()

    if write_out:
        # Make sure the output directory exists before saving.
        os.makedirs(FIG_DIR, exist_ok=True)
        fig.savefig(f"{FIG_DIR}/{fig_file_name}.jpg", dpi=300)
        fig.savefig(f"{FIG_DIR}/{fig_file_name}.eps", dpi=300)

    plt.show()


def get_all_tmp_file_paths(config_name):
    """Return the paths of the four cached ``.npy`` files for ``config_name``.

    The tuple is ordered as: SR analysis, LR omega, HR omega, HR observations.
    All paths are located directly under TMP_DATA_DIR.
    """
    prefixes = ("sr_analysis", "lr_omega", "hr_omega", "hr_obsrv")
    return tuple(
        f"{TMP_DATA_DIR}/{prefix}_{config_name}.npy" for prefix in prefixes
    )


def read_all_tmp_files(config_name):
    """Load the four cached arrays belonging to ``config_name``.

    Returns a tuple of NumPy arrays in the order given by
    ``get_all_tmp_file_paths``: SR analysis, LR omega, HR omega, HR observations.
    """
    return tuple(np.load(path) for path in get_all_tmp_file_paths(config_name))


def initialize_lr_model(
    *,
    t0: float,
    hr_omega0: torch.Tensor,
    lr_forcing: torch.Tensor,
    lr_model: TorchSpectralModel2D,
    n_ens: int,
    hr_nx: int,
    hr_ny: int,
    lr_nx: int,
    lr_ny: int,
    **kwargs,
):
    """Initialize ``lr_model`` in place from a single HR field.

    ``hr_omega0`` (shape ``(hr_nx, hr_ny)``) is interpolated to ``lr_nx`` x
    ``lr_ny`` (passing "bicubic" to ``interpolate``) and replicated ``n_ens``
    times, so every ensemble member starts from the same field. Extra keyword
    arguments are accepted and ignored, which lets callers pass a whole
    configuration dictionary.
    """
    expected_hr_shape = (hr_nx, hr_ny)
    assert hr_omega0.shape == expected_hr_shape

    lr_omega0 = interpolate(hr_omega0[None, ...], lr_nx, lr_ny, "bicubic")
    ensemble_omega0 = torch.broadcast_to(lr_omega0, (n_ens, lr_nx, lr_ny))

    lr_model.initialize(t0=t0, omega0=ensemble_omega0, forcing=lr_forcing)
    lr_model.calc_grid_data()


def make_lr_model(hr_omega0: torch.Tensor) -> TorchSpectralModel2D:
    """Build an LR ensemble model initialized from one HR field.

    The model and the jet forcing are both created from LR_CFD_CONFIG, and the
    model is then initialized from ``hr_omega0`` via ``initialize_lr_model``.

    Logging is limited to WARNING while the model is built. The level that was
    active on entry is restored afterwards, including when construction raises,
    so a failure cannot leave the notebook's logger silenced.

    Args:
        hr_omega0: HR field used as the initial condition.

    Returns:
        The initialized LR model.
    """
    previous_level = logger.level
    logger.setLevel(WARNING)
    try:
        lr_model = TorchSpectralModel2D(**LR_CFD_CONFIG)
        _, lr_forcing = calc_jet_forcing(**LR_CFD_CONFIG)

        initialize_lr_model(
            hr_omega0=hr_omega0,
            lr_forcing=lr_forcing,
            lr_model=lr_model,
            **LR_CFD_CONFIG,
        )
    finally:
        logger.setLevel(previous_level)

    return lr_model


def get_obs_matrix(obs: torch.Tensor) -> torch.Tensor:
    """Build the observation (selection) matrix for one HR snapshot.

    Args:
        obs: Tensor of shape ``(HR_NX, HR_NY)`` with NaN at unobserved points.

    Returns:
        A float64 tensor on DEVICE of shape ``(num_obs, HR_NX * HR_NY)``.
        Row ``i`` holds a single 1 at the flattened index of the ``i``-th
        non-NaN entry of ``obs`` and zeros elsewhere.
    """
    assert obs.shape == (HR_NX, HR_NY)

    # Flattened positions of the observed (non-NaN) grid points.
    obs_indices = torch.nonzero(~torch.isnan(obs).reshape(-1), as_tuple=True)[0]
    num_obs = len(obs_indices)

    obs_matrix = torch.zeros(num_obs, HR_NX * HR_NY, dtype=torch.float64, device=DEVICE)

    # Set all ones in a single indexed assignment instead of one write per row.
    rows = torch.arange(num_obs, device=obs_matrix.device)
    obs_matrix[rows, obs_indices.to(obs_matrix.device)] = 1.0

    # Each row contains exactly one 1, so the row count is the observation count.
    p = 100 * num_obs / (HR_NX * HR_NY)
    logger.debug(f"observatio prob = {p} [%]")

    return obs_matrix

# Perform EnKF

In [None]:
# Only the HR fields and HR observations (the last two cached arrays) are
# needed for the EnKF run; both are converted to torch tensors.
all_hr_omega, all_hr_obsrv = map(
    torch.from_numpy, read_all_tmp_files(CONFIG_NAME)[2:]
)

test_dataset = get_testdataset(ROOT_DIR, CONFIG)

In [None]:
# Seeded generator handed to the assimilation routine.
torch_rand_generator = torch.Generator()
torch_rand_generator.manual_seed(SEED)

# Sampler for the system noise, defined on the LR grid.
sys_noise_generator = get_multivariate_normal_sampler(
    nx=LR_NX,
    ny=LR_NY,
    sigma=SYS_NOISE_STD,
    d_x=SYS_NOISE_DX,
    d_y=SYS_NOISE_DY,
)

# Localization matrix, defined on the HR grid and moved to the compute device.
localization_matrix = calc_localization_matrix(
    nx=HR_NX,
    ny=HR_NY,
    d_x=LOCALIZE_DX,
    d_y=LOCALIZE_DY,
)
localization_matrix = localization_matrix.to(DEVICE)

In [None]:
# Run the EnKF once for every entry along the first axis of all_hr_omega.
# NOTE(review): despite its name, i_ens indexes these entries, not the members
# of the N_ENS-sized ensemble held inside lr_model.
# NOTE(review): torch_rand_generator is shared by all iterations, so skipping
# already-finished entries on a re-run presumably changes the random numbers
# drawn for the remaining ones - confirm if exact reproducibility matters.
for i_ens in tqdm(range(all_hr_omega.shape[0])):
    output_lr_file_path = f"{ENKF_DIR}/ens_all_lr_{i_ens:04}.npy"
    output_hr_file_path = f"{ENKF_DIR}/ens_mean_hr_{i_ens:04}.pickle"
    # The .npy file is written last, so its presence marks a finished entry.
    if os.path.exists(output_lr_file_path):
        continue

    hr_omega = all_hr_omega[i_ens]
    hr_obs = all_hr_obsrv[i_ens]

    lr_model = make_lr_model(hr_omega[0])  # 0 means initial time

    # LR ensemble state recorded at every cycle.
    lr_enkfs = []
    # Cycle index -> ensemble-mean HR analysis (assimilation cycles only).
    dict_hr_analysis = {}

    for i_cycle in tqdm(range(all_hr_omega.shape[1])):
        # Assimilate every ASSIMILATION_PERIOD cycles, but not at the initial cycle.
        if i_cycle > 0 and i_cycle % ASSIMILATION_PERIOD == 0:
            obs = hr_obs[i_cycle]
            obs_matrix = get_obs_matrix(obs)

            # NOTE(review): presumably this updates lr_model in place with the
            # analysis; only the returned HR analysis is used here - confirm.
            all_hr_analysis = assimilate_with_existing_data(
                hr_omega=hr_omega[i_cycle].to(torch.float64).to(DEVICE),
                lr_ens_model=lr_model,
                obs_matrix=obs_matrix,
                obs_noise_std=test_dataset.obs_noise_std,
                inflation=1.0,
                rand_generator=torch_rand_generator,
                localization_matrix=localization_matrix,
                return_hr_analysis=True,
            )
            # Mean over batch (ensemble dim)
            hr_analysis = torch.mean(all_hr_analysis, axis=0)
            assert hr_analysis.shape == (HR_NX, HR_NY)
            dict_hr_analysis[i_cycle] = hr_analysis.cpu().numpy()

        # Record the ensemble before system noise is added and before integration.
        lr_enkfs.append(lr_model.omega.cpu().clone())

        # On every ASSIMILATION_PERIOD-th cycle (including cycle 0), perturb the
        # ensemble with system noise whose mean across members has been removed,
        # then re-initialize the model from the perturbed state.
        if i_cycle % ASSIMILATION_PERIOD == 0:
            noise = sys_noise_generator.sample([N_ENS]).reshape(N_ENS, LR_NX, LR_NY)
            noise = noise - torch.mean(noise, dim=0, keepdims=True)
            omega = lr_model.omega + noise.to(DEVICE)
            # NOTE(review): no forcing is passed here, unlike in initialize_lr_model;
            # presumably the model keeps its existing forcing - confirm.
            lr_model.initialize(t0=lr_model.t, omega0=omega)
            lr_model.calc_grid_data()

        # Advance the ensemble by LR_NT steps of size LR_DT (one cycle).
        lr_model.time_integrate(dt=LR_DT, nt=LR_NT, hide_progress_bar=True)
        lr_model.calc_grid_data()

    write_pickle(dict_hr_analysis, output_hr_file_path)

    # Stack along time dim
    lr_enkfs = torch.stack(lr_enkfs, dim=1).to(torch.float32).numpy()
    np.save(output_lr_file_path, lr_enkfs)

# Analyze results

In [None]:
# Reload all four cached arrays as NumPy arrays for the analysis below
# (this replaces the torch versions of the HR arrays used during the EnKF run).
all_sr_analysis, all_lr_omega, all_hr_omega, all_hr_obsrv = read_all_tmp_files(
    CONFIG_NAME
)

In [None]:
# Gather every saved EnKF result (one .npy file per processed entry, in
# file-name order) and average over axis 1, the ensemble dimension.
enkf_file_paths = sorted(glob.glob(f"{ENKF_DIR}/*.npy"))
all_enkfs = np.stack([np.load(path) for path in enkf_file_paths], axis=0)
all_enkf_means = all_enkfs.mean(axis=1)
all_enkf_means.shape

## Time series of error

In [None]:
# Keep only the leading entries for which EnKF results exist, so the arrays
# line up with all_enkf_means along the first axis.
n_processed = all_enkf_means.shape[0]
all_hr_omega = all_hr_omega[:n_processed]
all_sr_analysis = all_sr_analysis[:n_processed]

In [None]:
max_time_index = all_hr_omega.shape[1]
# Time of each cycle, starting at T0.
ts = [LR_DT * LR_NT * k + T0 for k in range(max_time_index)]

# EnKF: bring the LR ensemble mean onto the HR grid before comparing with HR.
_lr = interpolate_time_series(
    torch.from_numpy(all_enkf_means), HR_NX, HR_NY
).numpy()
assert _lr.shape == all_hr_omega.shape

# Spatial means (last two axes): absolute error and absolute HR value.
_lr_err1 = np.abs(_lr - all_hr_omega).mean(axis=(-2, -1))
_lr_err2 = np.abs(all_hr_omega).mean(axis=(-2, -1))
# Mean over the first axis gives one value per time step.
lr = _lr_err1.mean(axis=0)
lr_ratio = (_lr_err1 / _lr_err2).mean(axis=0)

# SRDA: the analysis is compared with HR directly, so shapes must match.
assert all_sr_analysis.shape == all_hr_omega.shape

_sr_err1 = np.abs(all_sr_analysis - all_hr_omega).mean(axis=(-2, -1))
_sr_err2 = np.abs(all_hr_omega).mean(axis=(-2, -1))
sr = _sr_err1.mean(axis=0)
sr_ratio = (_sr_err1 / _sr_err2).mean(axis=0)

In [None]:
# Mean absolute error against the HR fields over time, for EnKF and SRDA.
# Axis labels and a title are set so the figure can be read on its own.
fig, ax = plt.subplots()
ax.plot(ts, lr, "o-", label="EnKF")
ax.plot(ts, sr, "o-", label="SRDA")
ax.set_xlabel("Time")
ax.set_ylabel("Mean absolute error")
ax.set_title("Time series of error")
ax.legend()
plt.show()

## Vorticity evolution

In [None]:
i_ensemble = 43

dict_hr_enkf_analysis = read_pickle(f"{ENKF_DIR}/ens_mean_hr_{i_ensemble:04}.pickle")

# HR EnKF analyses are stored only for assimilation cycles, so step through
# exactly those cycles (the initial cycle has no analysis and is not plotted).
for i_cycle in range(
    ASSIMILATION_PERIOD, all_sr_analysis.shape[1], ASSIMILATION_PERIOD
):
    t = (i_cycle + START_TIME_INDEX) * DT

    # Panels appear in insertion order; "HR" must come first because it is the
    # reference for the MAE shown on the other panels.
    dict_data = {
        "HR": all_hr_omega[i_ensemble, i_cycle],
        "LR(no-DA)": all_lr_omega[i_ensemble, i_cycle],
        "LR(EnKF)": all_enkf_means[i_ensemble, i_cycle],
    }
    hr_analysis = dict_hr_enkf_analysis.get(i_cycle, None)
    if hr_analysis is not None:
        # Show the HR EnKF analysis only when one was stored for this cycle.
        dict_data["HR(EnKF)"] = hr_analysis
    dict_data["SRDA(no noise)"] = all_sr_analysis[i_ensemble, i_cycle]

    plot(
        dict_data,
        t,
        obs=all_hr_obsrv[i_ensemble, i_cycle],
        figsize=[16, 4.0],
        ttl_header=f"i_ens = {i_ensemble}, ",
        fig_file_name=f"snapshot_use_obs_{i_ensemble:03}_t_{int(t):03}",
        write_out=False,
        use_hr_space=True,
        vmin_omega=-10,
        vmax_omega=10,
    )

In [None]:
for i_ensemble in [43, 46, 48, 70, 91, 107]:
    # The stored analyses depend only on i_ensemble, so load them once per
    # entry rather than once per plotted cycle.
    dict_hr_enkf_analysis = read_pickle(
        f"{ENKF_DIR}/ens_mean_hr_{i_ensemble:04}.pickle"
    )

    for i_cycle in [40, 80]:
        t = (i_cycle + START_TIME_INDEX) * DT

        # Panels appear in insertion order; "HR" must come first because it is
        # the reference for the MAE shown on the other panels.
        dict_data = {
            "HR": all_hr_omega[i_ensemble, i_cycle],
            "LR(no-DA)": all_lr_omega[i_ensemble, i_cycle],
            "LR(EnKF)": all_enkf_means[i_ensemble, i_cycle],
        }
        hr_analysis = dict_hr_enkf_analysis.get(i_cycle, None)
        if hr_analysis is not None:
            # Show the HR EnKF analysis only when one was stored for this cycle.
            dict_data["HR(EnKF)"] = hr_analysis
        dict_data["SRDA(no noise)"] = all_sr_analysis[i_ensemble, i_cycle]

        plot(
            dict_data,
            t,
            obs=all_hr_obsrv[i_ensemble, i_cycle],
            figsize=[16, 4.0],
            ttl_header=f"i_ens = {i_ensemble}, ",
            fig_file_name=f"snapshot_use_obs_{i_ensemble:03}_t_{int(t):03}",
            write_out=False,
            use_hr_space=True,
            vmin_omega=-10,
            vmax_omega=10,
        )