In [None]:
# Automatically reload edited project modules (e.g. cfd_model, src) before executing code.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [None]:
import sys
from logging import INFO, WARNING, StreamHandler, getLogger

# Root logger: attach a stdout handler only if none is present yet, so re-running
# this cell does not duplicate every log line.
logger = getLogger()
if not logger.hasHandlers():
    logger.addHandler(StreamHandler(sys.stdout))
logger.setLevel(INFO)

# Import libraries

In [None]:
import gc
import glob
import os
import pathlib
import typing

import numpy as np
import torch
import yaml
from cfd_model.cfd.periodic_channel_domain import TorchSpectralModel2D
from cfd_model.enkf.sr_enkf import (
    assimilate_with_existing_data,
    calc_localization_matrix,
)
from cfd_model.filter.low_pass_periodic_channel_domain import LowPassFilter
from cfd_model.initialization.periodic_channel_jet_initializer import calc_jet_forcing
from cfd_model.interpolator.torch_interpolator import interpolate_time_series
from src.dataloader import split_file_paths
from src.utils import set_seeds, write_pickle
from torch.distributions.multivariate_normal import MultivariateNormal
from tqdm.notebook import tqdm

In [None]:
os.environ["CUBLAS_WORKSPACE_CONFIG"] = r":4096:8"  # to make calculations deterministic
set_seeds(42, use_deterministic=True)

# ROOT_DIR is the parent of the (first) PYTHONPATH entry. PYTHONPATH may hold several
# os.pathsep-separated entries; joining ".." onto the whole string would yield a
# nonsensical path, and a missing variable used to surface as a bare KeyError.
# NOTE(review): assumes the first entry is the project's source dir directly under
# the repo root (ROOT_DIR/pytorch/config is read below) — confirm.
_pythonpath_entries = [
    entry for entry in os.environ.get("PYTHONPATH", "").split(os.pathsep) if entry
]
if not _pythonpath_entries:
    raise RuntimeError(
        "PYTHONPATH must be set: ROOT_DIR is derived as the parent of its first entry."
    )
ROOT_DIR = str((pathlib.Path(_pythonpath_entries[0]) / "..").resolve())

# Define constants

In [None]:
# Device used for the LR model, the low-pass filter and the EnKF tensors below.
DEVICE = "cuda:0"

In [None]:
# Name of the SRDA run; it selects both the YAML config and the SRDA result files.
SRDA_CONFIG_NAME = "default_neural_nets"

_srda_config_path = (
    pathlib.Path(ROOT_DIR) / "pytorch" / "config" / f"{SRDA_CONFIG_NAME}.yml"
)
with _srda_config_path.open() as file:
    CONFIG = yaml.safe_load(file)

# Directory holding the SRDA result arrays (ground truth, observations, ...).
SRDA_DATA_DIR = f"{ROOT_DIR}/data/SRDA"

In [None]:
OBS_GRID_INTERVAL = 8

# EnKF hyper-parameters, keyed by the observation grid interval.
ENKF_PREFERENCES = {
    8: dict(
        INIT_SYS_NOISE_FACTOR=0.4,
        LOCALIZE_DX=0.5,
        N_ENS=300.0,
        OBS_PERTURB_STD=0.65,
        SYS_NOISE_FACTOR=0.2,
    ),
}

# Output directory for the EnKF results; created up front if it is missing.
ENKF_DATA_DIR = f"{ROOT_DIR}/data/EnKF"
os.makedirs(ENKF_DATA_DIR, exist_ok=True)

In [None]:
# Train / valid / test split ratios of the TrainingData directories
# (passed to split_file_paths when estimating the system-noise statistics).
TRAIN_VALID_TEST_RATIOS = [0.7, 0.2, 0.1]

# Leading time steps dropped from the training series in get_sys_noise_generator;
# also used to compute the model start time "t0" below.
INITIAL_TIME_INDEX = 16
SEED = 42
# Passed to the EnKF analysis; when it equals 1.0, additive system noise is also
# injected at every assimilation cycle (see the main loop).
INFLATION = 1.0

# Observations are assimilated every ASSIMILATION_PERIOD cycles.
ASSIMILATION_PERIOD = 4
# NOTE(review): not referenced in the cells shown; the EnKF is given OBS_PERTURB_STD.
OBS_NOISE_STD = 0.1

# Low-resolution (LR) model grid; each cycle integrates LR_NT steps of size LR_DT.
LR_NX = 32
LR_NY = 17
LR_DT = 5e-4
LR_NT = 500

# High-resolution (HR) grid of the ground truth and the observations.
HR_NX = 128
HR_NY = 65

# Jet / forcing parameters; Y0, SIGMA and TAU0 are forwarded via LR_CFD_CONFIG.
# NOTE(review): U0 and PERTUB_NOISE are not referenced in the cells shown.
Y0 = np.pi / 2.0
SIGMA = 0.4
U0 = 3.0
TAU0 = 0.3
PERTUB_NOISE = 0.0025

# Model coefficients forwarded to the LR model via LR_CFD_CONFIG.
BETA = 0.1
COEFF_LINEAR_DRAG = 1e-2
ORDER_DIFFUSION = 2
LR_COEFF_DIFFUSION = 5e-5

# Model time spanned by one cycle (LR_NT steps of LR_DT).
DT = LR_DT * LR_NT

# Keyword configuration splatted into TorchSpectralModel2D, calc_jet_forcing and
# initialize_lr_model (the latter ignores keys it does not declare).
LR_CFD_CONFIG = {
    "nx": LR_NX,
    "ny": LR_NY,
    "lr_nx": LR_NX,
    "lr_ny": LR_NY,
    "hr_nx": HR_NX,
    "hr_ny": HR_NY,
    "assimilation_period": ASSIMILATION_PERIOD,
    "coeff_linear_drag": COEFF_LINEAR_DRAG,
    "coeff_diffusion": LR_COEFF_DIFFUSION,
    "order_diffusion": ORDER_DIFFUSION,
    "beta": BETA,
    "device": DEVICE,
    "y0": Y0,
    "sigma": SIGMA,
    "tau0": TAU0,
    # Model time after INITIAL_TIME_INDEX cycles of LR_DT * LR_NT each.
    "t0": INITIAL_TIME_INDEX * LR_DT * LR_NT,
}

In [None]:
# Unpack the EnKF settings for the chosen observation grid interval.
_enkf_prefs = ENKF_PREFERENCES[OBS_GRID_INTERVAL]
N_ENS = int(_enkf_prefs["N_ENS"])
LOCALIZE_DX = _enkf_prefs["LOCALIZE_DX"]
SYS_NOISE_FACTOR = _enkf_prefs["SYS_NOISE_FACTOR"]
INIT_SYS_NOISE_FACTOR = _enkf_prefs["INIT_SYS_NOISE_FACTOR"]
OBS_PERTURB_STD = _enkf_prefs["OBS_PERTURB_STD"]
del _enkf_prefs

# Same localization length scale in both directions.
LOCALIZE_DY = LOCALIZE_DX
# The ensemble size is exposed to the LR model config under both keys.
LR_CFD_CONFIG.update(ne=N_ENS, n_ens=N_ENS)

# Define helper functions

In [None]:
# Shared filter that maps HR fields onto the LR grid (used by initialize_lr_model).
low_pass_filter = LowPassFilter(
    device=DEVICE,
    nx_hr=HR_NX,
    ny_hr=HR_NY,
    nx_lr=LR_NX,
    ny_lr=LR_NY,
)


def initialize_lr_model(
    *,
    t0: float,
    hr_omega0: torch.Tensor,
    lr_forcing: torch.Tensor,
    lr_model: TorchSpectralModel2D,
    n_ens: int,
    hr_nx: int,
    hr_ny: int,
    lr_nx: int,
    lr_ny: int,
    **kwargs,
):
    """Initialise ``lr_model`` from a single high-resolution vorticity snapshot.

    The HR field is passed through the module-level ``low_pass_filter`` and the
    resulting LR field is given, unperturbed, to every ensemble member.

    Args:
        t0: start time handed to ``lr_model.initialize``.
        hr_omega0: HR field; must have shape ``(hr_nx, hr_ny)``.
        lr_forcing: forcing handed to ``lr_model.initialize``.
        lr_model: model whose ``initialize`` and ``calc_grid_data`` are called.
        n_ens, lr_nx, lr_ny: target ensemble shape ``(n_ens, lr_nx, lr_ny)``.
        hr_nx, hr_ny: expected HR shape (only used for the shape check).
        **kwargs: ignored; lets callers splat a full config dict.
    """
    assert hr_omega0.shape == (hr_nx, hr_ny)

    # Add a leading batch axis before filtering onto the LR grid.
    lr_omega0 = low_pass_filter.apply(hr_omega0[None, ...])

    ensemble_shape = (n_ens, lr_nx, lr_ny)
    # broadcast_to returns an expanded view: every member shares the same storage.
    ens_omega0 = torch.broadcast_to(lr_omega0, ensemble_shape)

    lr_model.initialize(t0=t0, omega0=ens_omega0, forcing=lr_forcing)
    lr_model.calc_grid_data()


def _iter_hr_omega_series(dir_paths, max_ens_index: int):
    """Yield one time-concatenated HR omega array per (directory, ensemble index).

    Directories are visited in sorted order and, inside each, ensemble indices
    ``0 .. max_ens_index - 1``. Indices without any matching file are skipped
    (previously ``np.concatenate([])`` raised on them).
    """
    for dir_path in sorted(dir_paths):
        for i in range(max_ens_index):
            file_paths = sorted(glob.glob(f"{dir_path}/*_hr_omega_{i:02}.npy"))
            if not file_paths:
                logger.debug(f"No hr_omega files for index {i:02} in {dir_path}")
                continue

            # Drop the first time step of every file after the first one, to
            # avoid overlapping at the start/end point of consecutive files.
            chunks = [np.load(file_paths[0])]
            chunks.extend(np.load(path)[1:] for path in file_paths[1:])

            # Concat along time axis
            yield np.concatenate(chunks, axis=0)


def load_hr_data(
    root_dir: str,
    train_valid_test_ratios: typing.List[float],
    kind: str,
    num_hr_omega_sets: int,
    max_ens_index: int = 20,
) -> torch.Tensor:
    """Load high-resolution vorticity (omega) time series from the training data.

    Sub-directories of ``<root_dir>/data/TrainingData`` are sorted and split with
    ``split_file_paths``; the split selected by ``kind`` is scanned and, for each
    directory and ensemble index ``i < max_ens_index``, the files matching
    ``*_hr_omega_{i:02}.npy`` are concatenated along axis 0 (time). Loading stops
    as soon as ``num_hr_omega_sets`` series have been collected.

    Args:
        root_dir: project root containing ``data/TrainingData``.
        train_valid_test_ratios: split ratios forwarded to ``split_file_paths``.
        kind: one of ``"train"``, ``"valid"``, ``"test"``.
        num_hr_omega_sets: maximum number of series to return.
        max_ens_index: number of ensemble indices probed per directory.

    Returns:
        float64 tensor stacking the series along a new leading (batch) axis.

    Raises:
        ValueError: if ``kind`` is unsupported, if no series were found, or
            (from ``np.stack``) if the series differ in shape.
    """
    training_data_dir_path = f"{root_dir}/data/TrainingData"
    logger.info(f"Training data dir path = {training_data_dir_path}")

    data_dirs = sorted(
        p for p in glob.glob(f"{training_data_dir_path}/*") if os.path.isdir(p)
    )
    train_dirs, valid_dirs, test_dirs = split_file_paths(
        data_dirs, train_valid_test_ratios
    )

    dirs_by_kind = {"train": train_dirs, "valid": valid_dirs, "test": test_dirs}
    if kind not in dirs_by_kind:
        # ValueError subclasses Exception, so existing `except Exception` still works.
        raise ValueError(f"{kind} is not supported.")
    target_dirs = dirs_by_kind[kind]

    logger.info(f"Kind = {kind}, Num of dirs = {len(target_dirs)}")

    all_hr_omegas = []
    for hr_omega_series in _iter_hr_omega_series(target_dirs, max_ens_index):
        all_hr_omegas.append(hr_omega_series)
        if len(all_hr_omegas) == num_hr_omega_sets:
            break

    if not all_hr_omegas:
        raise ValueError(
            f"No *_hr_omega_XX.npy files found for kind={kind} "
            f"under {training_data_dir_path}"
        )

    # Concat along batch axis
    return torch.from_numpy(np.stack(all_hr_omegas, axis=0)).to(torch.float64)


def get_sys_noise_generator(num_hr_omega_sets: int = 250, eps: float = 1e-12):
    """Build one zero-mean Gaussian system-noise distribution per time step.

    Training HR vorticity series are interpolated (bicubic) onto the LR grid,
    centred by their mean over the batch axis, and used to estimate a spatial
    covariance matrix (1/N normalisation) for every time step from
    ``INITIAL_TIME_INDEX`` on. ``eps`` is added to each diagonal so that
    ``MultivariateNormal`` can factorise the matrix.

    Args:
        num_hr_omega_sets: number of training series used for the estimate.
        eps: diagonal jitter added to every covariance matrix.

    Returns:
        list of ``MultivariateNormal`` distributions, one per time step, each
        over a flattened LR field (samples are reshaped to (LR_NX, LR_NY) by
        the caller).
    """
    hr_omegas = load_hr_data(
        root_dir=ROOT_DIR,
        train_valid_test_ratios=TRAIN_VALID_TEST_RATIOS,
        kind="train",
        num_hr_omega_sets=num_hr_omega_sets,
    )
    hr_omegas = hr_omegas[:, INITIAL_TIME_INDEX:]
    # dims = batch, time, x, and y

    lr_omegas = interpolate_time_series(hr_omegas, LR_NX, LR_NY, "bicubic")

    # The HR data is no longer needed; free it before the covariance step.
    del hr_omegas
    gc.collect()

    lr_omegas = lr_omegas - torch.mean(lr_omegas, dim=0, keepdim=True)
    lr_omegas = lr_omegas.reshape(lr_omegas.shape[:2] + (-1,))
    # dims = batch, time, and space

    # Mean over the batch dim of the outer products x x^T, per time step.
    # einsum contracts the batch axis directly instead of first materialising a
    # (batch, time, space, space) tensor, which is tens of GB in float64.
    n_batch = lr_omegas.shape[0]
    all_covs = torch.einsum("bti,btj->tij", lr_omegas, lr_omegas) / n_batch

    # Assure cov is symmetric.
    all_covs = (all_covs + all_covs.transpose(-1, -2)) / 2.0

    # Assure positive definiteness (jitter matches the covariance dtype/device).
    n_space = all_covs.shape[-1]
    all_covs = all_covs + eps * torch.eye(
        n_space, dtype=all_covs.dtype, device=all_covs.device
    )

    loc = torch.zeros(n_space, dtype=torch.float64, device=all_covs.device)
    return [MultivariateNormal(loc, cov) for cov in all_covs]


def get_obs_matrix(obs: torch.Tensor) -> torch.Tensor:
    """Build the 0/1 observation operator selecting the non-NaN points of ``obs``.

    Args:
        obs: HR field of shape (HR_NX, HR_NY); NaN marks unobserved points.

    Returns:
        float64 matrix of shape (num_obs, HR_NX * HR_NY) on DEVICE. Row ``k``
        has a single 1 at the flattened (row-major) index of the k-th observed
        point, in ascending index order.
    """
    assert obs.shape == (HR_NX, HR_NY)

    # Flattened indices of the observed (non-NaN) points, ascending.
    obs_indices = torch.where(~torch.isnan(obs.reshape(-1)))[0]
    num_obs = obs_indices.numel()

    obs_matrix = torch.zeros(num_obs, HR_NX * HR_NY, dtype=torch.float64, device=DEVICE)

    # One vectorised scatter instead of a Python loop issuing one device write
    # per observation.
    rows = torch.arange(num_obs, device=obs_matrix.device)
    obs_matrix[rows, obs_indices.to(obs_matrix.device)] = 1.0

    # Each row holds exactly one 1, so the count is num_obs (no device sync needed).
    p = 100 * num_obs / (HR_NX * HR_NY)
    logger.debug(f"observation prob = {p} [%]")

    return obs_matrix


def get_all_srda_result_paths(config_name: str = SRDA_CONFIG_NAME):
    """Return the five SRDA result ``.npy`` paths for ``config_name``.

    Returns:
        tuple of paths under SRDA_DATA_DIR, in this order: sr_prior,
        sr_analysis, lr_omega, hr_omega, hr_obsrv.
    """
    prefixes = ("sr_prior", "sr_analysis", "lr_omega", "hr_omega", "hr_obsrv")
    return tuple(f"{SRDA_DATA_DIR}/{prefix}_{config_name}.npy" for prefix in prefixes)


def read_all_srda_result_files(config_name: str = SRDA_CONFIG_NAME):
    """Load every SRDA result array for ``config_name``.

    Returns:
        tuple of numpy arrays in the order of ``get_all_srda_result_paths``:
        sr_prior, sr_analysis, lr_omega, hr_omega, hr_obsrv.
    """
    return tuple(np.load(path) for path in get_all_srda_result_paths(config_name))

# Read ground truth and observations

In [None]:
# Only the HR ground truth and the HR observations are needed for the EnKF run.
# Loading all five SRDA arrays read three unused (potentially GB-sized) files, and
# the tuple unpacking kept the unused lr_omega array alive through `_`.
(_, _, _, hr_omega_file_path, hr_obsrv_file_path) = get_all_srda_result_paths(
    SRDA_CONFIG_NAME
)
all_hr_omegas = torch.from_numpy(np.load(hr_omega_file_path))
all_hr_obsrv = torch.from_numpy(np.load(hr_obsrv_file_path))

In [None]:
# Expected layout (500, 81, HR_NX, HR_NY): axis 0 is iterated as `i_ens` and
# axis 1 as the cycle index `i_cycle` in the EnKF loop below.
assert all_hr_omegas.shape == all_hr_obsrv.shape == (500, 81, 128, 65)

# Perform EnKF

In [None]:
# Re-seed before the EnKF run; presumably set_seeds seeds the global torch RNG
# that the noise samplers below draw from.
set_seeds(SEED, use_deterministic=True)

# One system-noise distribution per cycle, estimated from the training data.
sys_noise_generators = get_sys_noise_generator()
_ = gc.collect()

In [None]:
# The noise model must provide one distribution per cycle of the ground truth.
assert len(sys_noise_generators) == all_hr_omegas.shape[1]

In [None]:
# Dedicated (CPU) RNG passed as `rand_generator` to the EnKF analysis step.
torch_rand_generator = torch.Generator().manual_seed(SEED)

# Localization matrix on the HR grid with length scales LOCALIZE_DX / LOCALIZE_DY.
localization_matrix = calc_localization_matrix(
    nx=HR_NX, ny=HR_NY, d_x=LOCALIZE_DX, d_y=LOCALIZE_DY
).to(DEVICE)

# Number of cycles run for every member in the loop below.
max_time_index = all_hr_omegas.shape[1]
assert max_time_index == 81

# Initialize with the hr simulation results at the initial time without noise
init_hr_omegas = all_hr_omegas[:, 0]
assert init_hr_omegas.shape == (500, HR_NX, HR_NY)

In [None]:
# Display-only cell: quick look at the observation array's shape.
all_hr_obsrv.shape

In [None]:
# Run one EnKF experiment per ground-truth member and save, per member, the LR
# ensemble at every cycle (.npy) and the HR analysis ensemble mean at every
# assimilation cycle (.pickle). Finished members are skipped, so the cell can be resumed.
# NOTE(review): the global torch RNG (used by `.sample`) and torch_rand_generator
# carry their state across members, so a member's result depends on which earlier
# members were skipped in this run.
for i_ens in tqdm(range(all_hr_omegas.shape[0])):
    output_lr_file_path = (
        f"{ENKF_DATA_DIR}/ens_all_lr_forecast_og{OBS_GRID_INTERVAL:02}_{i_ens:04}.npy"
    )
    output_hr_file_path = (
        f"{ENKF_DATA_DIR}/ens_mean_hr_og{OBS_GRID_INTERVAL:02}_{i_ens:04}.pickle"
    )
    # Skip only when BOTH outputs exist. With `or`, a run interrupted after the
    # pickle was written but before the .npy was saved was skipped forever.
    if os.path.exists(output_lr_file_path) and os.path.exists(output_hr_file_path):
        continue

    # Silence INFO logs emitted while building the model and the forcing;
    # restore the level even if construction raises.
    logger.setLevel(WARNING)
    try:
        lr_model = TorchSpectralModel2D(**LR_CFD_CONFIG)
        _, lr_forcing = calc_jet_forcing(**LR_CFD_CONFIG)
    finally:
        logger.setLevel(INFO)

    initialize_lr_model(
        hr_omega0=init_hr_omegas[i_ens].clone(),  # dims = x and y only
        lr_forcing=lr_forcing,
        lr_model=lr_model,
        **LR_CFD_CONFIG,
    )

    lr_enkfs = []
    dict_hr_analysis = {}

    for i_cycle in tqdm(range(max_time_index), leave=False):

        # Record the LR ensemble before any analysis at this cycle.
        lr_enkfs.append(lr_model.omega.cpu().clone())

        # Data assimilation every ASSIMILATION_PERIOD cycles (never at cycle 0).
        if i_cycle > 0 and i_cycle % ASSIMILATION_PERIOD == 0:
            obs = all_hr_obsrv[i_ens, i_cycle].to(torch.float64)
            obs_matrix = get_obs_matrix(obs)

            # NaN marks unobserved points. Replace it so that applying the
            # observation operator (0 * NaN = NaN) does not produce NaN.
            obs = torch.nan_to_num(obs, nan=1e10)

            # NOTE(review): the original comment said this returns the forecast
            # cov; with return_hr_analysis=True the result is used below as the
            # HR analysis ensemble — confirm.
            all_hr_analysis = assimilate_with_existing_data(
                hr_omega=obs.to(DEVICE),
                lr_ens_model=lr_model,
                obs_matrix=obs_matrix,
                obs_noise_std=OBS_PERTURB_STD,
                inflation=INFLATION,
                rand_generator=torch_rand_generator,
                localization_matrix=localization_matrix,
                return_hr_analysis=True,
            )

            # Mean over batch (ensemble dim), stored per cycle as float32.
            assert all_hr_analysis.shape == (N_ENS, HR_NX, HR_NY)
            hr_analysis = torch.mean(all_hr_analysis, dim=0)
            dict_hr_analysis[i_cycle] = hr_analysis.cpu().to(torch.float32).numpy()

        # Add additive system noise at cycle 0 and, when INFLATION == 1.0, at
        # every assimilation cycle.
        if i_cycle == 0 or (INFLATION == 1.0 and i_cycle % ASSIMILATION_PERIOD == 0):
            noise = sys_noise_generators[i_cycle].sample([N_ENS])
            noise = noise.reshape(N_ENS, LR_NX, LR_NY)
            # Remove the ensemble mean so the noise only changes the spread.
            noise = noise - torch.mean(noise, dim=0, keepdim=True)

            factor = INIT_SYS_NOISE_FACTOR if i_cycle == 0 else SYS_NOISE_FACTOR
            omega = lr_model.omega + factor * noise.to(DEVICE)

            lr_model.initialize(t0=lr_model.t, omega0=omega)
            lr_model.calc_grid_data()

        lr_model.time_integrate(dt=LR_DT, nt=LR_NT, hide_progress_bar=True)
        lr_model.calc_grid_data()

    write_pickle(dict_hr_analysis, output_hr_file_path)

    # Stack along time dim -> (N_ENS, time, LR_NX, LR_NY)
    lr_enkfs = torch.stack(lr_enkfs, dim=1).to(torch.float32).numpy()
    assert lr_enkfs.shape == (N_ENS, max_time_index, LR_NX, LR_NY)

    # Write to a temporary file and rename, so an interrupted save never leaves a
    # truncated .npy that the skip check above would treat as finished.
    # (np.save on a file object does not append a ".npy" suffix.)
    tmp_lr_file_path = f"{output_lr_file_path}.tmp"
    with open(tmp_lr_file_path, "wb") as f:
        np.save(f, lr_enkfs)
    os.replace(tmp_lr_file_path, output_lr_file_path)