In [None]:
# Re-import edited modules automatically before each cell runs, and render
# matplotlib figures inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import sys
from logging import DEBUG, INFO, WARNING, StreamHandler, getLogger

# Configure the root logger to print to stdout. The handler is attached only
# when none exists yet, so re-running this cell does not duplicate log lines.
logger = getLogger()
if not logger.hasHandlers():
    logger.addHandler(StreamHandler(stream=sys.stdout))
logger.setLevel(level=INFO)

# Import libraries

In [None]:
import gc
import glob
import os
import pathlib
import random
import typing

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import yaml
from cfd_model.cfd.periodic_channel_domain import TorchSpectralModel2D
from cfd_model.enkf.sr_enkf import (
    assimilate_with_existing_data,
    calc_localization_matrix,
)
from cfd_model.filter.low_pass_periodic_channel_domain import LowPassFilter
from cfd_model.initialization.periodic_channel_jet_initializer import (
    calc_init_omega,
    calc_init_perturbation_hr_omegas,
    calc_jet_forcing,
)
from cfd_model.interpolator.torch_interpolator import (
    interpolate,
    interpolate_time_series,
)
from src.dataloader import split_file_paths
from src.dataset import generate_is_obs_and_obs_matrix
from src.ssim import SSIM
from src.utils import read_pickle, set_seeds, write_pickle
from torch.distributions.multivariate_normal import MultivariateNormal
from tqdm.notebook import tqdm

In [None]:
# Deterministic cuBLAS kernels need a fixed workspace configuration; combined
# with set_seeds(..., use_deterministic=True) this makes GPU runs repeatable.
os.environ.update(CUBLAS_WORKSPACE_CONFIG=":4096:8")
set_seeds(42, use_deterministic=True)

# Define constants

In [None]:
# Project root used for every data/config path below: the parent directory of
# the PYTHONPATH entry. PYTHONPATH may list several directories separated by
# os.pathsep; only the first entry is used, because the whole string is not a
# valid path when several are present.
ROOT_DIR = str(
    (pathlib.Path(os.environ["PYTHONPATH"].split(os.pathsep)[0]) / "..").resolve()
)
ROOT_DIR

In [None]:
# All tensors and models below are placed on a CUDA device; fail fast when no
# GPU is available instead of silently running somewhere else.
DEVICE = "cuda"
if not torch.cuda.is_available():
    # RuntimeError subclasses Exception, so existing `except Exception` still works.
    raise RuntimeError(
        "No GPU is available; this notebook requires CUDA (DEVICE='cuda')."
    )

In [None]:
# EnKF hyper-parameters keyed by observation grid interval.
#   INIT_SYS_NOISE_FACTOR: scale of the additive system noise at cycle 0.
#   LOCALIZE_DX:           localization scale (also used for the y direction).
#   N_ENS:                 ensemble size.
#   OBS_PERTURB_STD:       observation-noise std passed to the assimilation step.
#   SYS_NOISE_FACTOR:      scale of the additive system noise at later cycles.
ENKF_PREFERENCES = {
    4: dict(
        INIT_SYS_NOISE_FACTOR=0.0,
        LOCALIZE_DX=0.5,
        N_ENS=100,
        OBS_PERTURB_STD=0.06,
        SYS_NOISE_FACTOR=0.2,
    ),
    6: dict(
        INIT_SYS_NOISE_FACTOR=0.1,
        LOCALIZE_DX=0.3,
        N_ENS=100,
        OBS_PERTURB_STD=0.16,
        SYS_NOISE_FACTOR=0.2,
    ),
    8: dict(
        INIT_SYS_NOISE_FACTOR=0.2,
        LOCALIZE_DX=0.3,
        N_ENS=100,
        OBS_PERTURB_STD=0.26,
        SYS_NOISE_FACTOR=0.2,
    ),
    10: dict(
        INIT_SYS_NOISE_FACTOR=0.2,
        LOCALIZE_DX=0.7,
        N_ENS=100,
        OBS_PERTURB_STD=0.41,
        SYS_NOISE_FACTOR=0.2,
    ),
    12: dict(
        INIT_SYS_NOISE_FACTOR=0.2,
        LOCALIZE_DX=0.7,
        N_ENS=100,
        OBS_PERTURB_STD=0.41,
        SYS_NOISE_FACTOR=0.2,
    ),
}

In [None]:
# Output locations, relative to the notebook's working directory.
TMP_DATA_DIR, CSV_DATA_DIR, FIG_DIR = "./data", "./csv", "./fig"

In [None]:
# Training dataset used to estimate the system-noise covariances, and its split.
CFD_DIR_NAME = "jet02"
TRAIN_VALID_TEST_RATIOS = [0.7, 0.2, 0.1]

INFLATION, SEED = 1.0, 777

# Cycle bookkeeping: observations are assimilated every ASSIMILATION_PERIOD cycles.
ASSIMILATION_PERIOD, START_TIME_INDEX = 4, 0
MAX_TIME_INDEX_FOR_INTEGRATION = 96
NUM_TIMES = MAX_TIME_INDEX_FOR_INTEGRATION

# Grid sizes (LR: ensemble model, HR: observation grid, UHR) and LR time stepping.
LR_NX, LR_NY = 32, 17
LR_DT, LR_NT = 5e-4, 500
HR_NX, HR_NY = 128, 65
UHR_NX, UHR_NY = 1024, 513

# Jet initial condition / forcing parameters.
Y0 = np.pi / 2.0
SIGMA, U0, TAU0 = 0.4, 3.0, 0.3
PERTUB_NOISE = 0.0025  # (sic) amplitude passed as noise_amp for the initial perturbation

# Dynamics of the LR model.
BETA = 0.1
COEFF_LINEAR_DRAG = 1e-2
ORDER_DIFFUSION = 2
LR_COEFF_DIFFUSION = 5e-5

# Model time covered by one cycle (LR_NT steps of size LR_DT).
DT = LR_DT * LR_NT

# Keyword arguments shared by the LR spectral model, the jet forcing and
# initialize_lr_model. The ensemble-size keys ("ne", "n_ens") are added later,
# per grid interval, in the experiment loop.
LR_CFD_CONFIG = dict(
    nx=LR_NX,
    ny=LR_NY,
    lr_nx=LR_NX,
    lr_ny=LR_NY,
    hr_nx=HR_NX,
    hr_ny=HR_NY,
    assimilation_period=ASSIMILATION_PERIOD,
    coeff_linear_drag=COEFF_LINEAR_DRAG,
    coeff_diffusion=LR_COEFF_DIFFUSION,
    order_diffusion=ORDER_DIFFUSION,
    beta=BETA,
    device=DEVICE,
    y0=Y0,
    sigma=SIGMA,
    tau0=TAU0,
    t0=0.0,
)

In [None]:
# Observation ratio keyed by observation grid interval; the values are close to
# 1 / interval**2 (e.g. 4 -> 0.0625), presumably the fraction of observed HR
# grid points. Values are kept exactly as recorded.
OBS_GRID_RATIO = dict(
    [
        (0, 0.0),
        (4, 0.06250000093132257),
        (5, 0.03999999910593033),
        (6, 0.027777777363856632),
        (7, 0.02040816326530612),
        (8, 0.015625000116415322),
        (9, 0.012345679127323775),
        (10, 0.010000000149011612),
        (11, 0.008264463206306716),
        (12, 0.006944444625534945),
        (13, 0.005917159876284691),
        (14, 0.005102040977882487),
        (15, 0.004444444572759999),
        (16, 0.003906250014551915),
    ]
)

# Define helper functions

In [None]:
def get_initial_hr_omega(ne: int, seed: int):
    """Build the HR initial vorticity from the jet profile and a seeded perturbation.

    Args:
        ne: Number of ensemble members to generate.
        seed: Seed forwarded to ``calc_init_perturbation_hr_omegas``.

    Returns:
        Whatever ``calc_init_omega`` produces for the HR grid (``HR_NX`` x ``HR_NY``).
    """
    # Only the first element of calc_jet_forcing's result (the jet) is needed here.
    jet_profile = calc_jet_forcing(
        nx=HR_NX, ny=HR_NY, ne=ne, y0=Y0, sigma=SIGMA, tau0=TAU0
    )[0]

    perturbation = calc_init_perturbation_hr_omegas(
        nx=HR_NX,
        ny=HR_NY,
        ne=ne,
        noise_amp=PERTUB_NOISE,
        seed=seed,
    )

    return calc_init_omega(perturb_omega=perturbation, jet=jet_profile, u0=U0)

In [None]:
# Shared HR -> LR low-pass filter, used by initialize_lr_model.
low_pass_filter = LowPassFilter(
    nx_lr=LR_NX,
    ny_lr=LR_NY,
    nx_hr=HR_NX,
    ny_hr=HR_NY,
    device=DEVICE,
)


def initialize_lr_model(
    *,
    t0: float,
    hr_omega0: torch.Tensor,
    lr_forcing: torch.Tensor,
    lr_model: TorchSpectralModel2D,
    n_ens: int,
    hr_nx: int,
    hr_ny: int,
    lr_nx: int,
    lr_ny: int,
    **kwargs,
):
    """Initialise all LR ensemble members from a single HR vorticity field.

    ``hr_omega0`` (shape ``(hr_nx, hr_ny)``) is passed through the module-level
    ``low_pass_filter`` and the filtered field is broadcast to
    ``(n_ens, lr_nx, lr_ny)``, so every member starts from the same state.
    Unused keyword arguments are accepted so a full config dict can be splatted in.
    """
    assert hr_omega0.shape == (hr_nx, hr_ny)

    filtered = low_pass_filter.apply(hr_omega0.unsqueeze(0))
    lr_omega0 = torch.broadcast_to(filtered, (n_ens, lr_nx, lr_ny))

    lr_model.initialize(t0=t0, omega0=lr_omega0, forcing=lr_forcing)
    lr_model.calc_grid_data()


def get_obs_matrices(obs_grid_interval: int):
    """Return one HR observation matrix per offset of a regular observation grid.

    Offsets ``(init_x, init_y)`` run over ``range(obs_grid_interval)`` in both
    directions (x-major order), and the spacing is ``obs_grid_interval`` along x and y.
    Only the matrix (second return value of ``generate_is_obs_and_obs_matrix``)
    is kept.
    """
    return [
        generate_is_obs_and_obs_matrix(
            nx=HR_NX,
            ny=HR_NY,
            init_index_x=offset_x,
            init_index_y=offset_y,
            interval_x=obs_grid_interval,
            interval_y=obs_grid_interval,
            dtype=torch.float64,
        )[1]
        for offset_x in tqdm(range(obs_grid_interval))
        for offset_y in range(obs_grid_interval)
    ]


def load_hr_data(
    root_dir: str,
    cfd_dir_name: str,
    train_valid_test_ratios: typing.List[float],
    kind: str,
    num_hr_omega_sets: int,
    max_ens_index: int = 20,
) -> torch.Tensor:
    """Load HR vorticity time series of one data split as a float64 tensor.

    The sub-directories of ``{root_dir}/data/pytorch/CFD/{cfd_dir_name}`` are
    sorted and split with ``split_file_paths``. Then, for every directory of the
    requested split and every ensemble index ``0 <= i < max_ens_index``, the
    files ``*_hr_omega_{i:02}.npy`` are loaded in sorted order and concatenated
    along axis 0 (time). Loading stops once ``num_hr_omega_sets`` series have
    been collected.

    Args:
        root_dir: Root directory containing ``data/pytorch/CFD``.
        cfd_dir_name: Name of the CFD dataset directory.
        train_valid_test_ratios: Split ratios forwarded to ``split_file_paths``.
        kind: ``"train"``, ``"valid"`` or ``"test"``.
        num_hr_omega_sets: Maximum number of series to load.
        max_ens_index: Number of ensemble indices probed per directory.
            Indices without any matching file are skipped.

    Returns:
        float64 tensor with the series stacked along a new leading (batch) axis.

    Raises:
        Exception: If ``kind`` is not supported.
        FileNotFoundError: If no matching file was found at all.
    """
    cfd_dir_path = f"{root_dir}/data/pytorch/CFD/{cfd_dir_name}"
    logger.info(f"CFD dir path = {cfd_dir_path}")

    data_dirs = sorted(p for p in glob.glob(f"{cfd_dir_path}/*") if os.path.isdir(p))

    train_dirs, valid_dirs, test_dirs = split_file_paths(
        data_dirs, train_valid_test_ratios
    )

    dirs_by_kind = {"train": train_dirs, "valid": valid_dirs, "test": test_dirs}
    if kind not in dirs_by_kind:
        raise Exception(f"{kind} is not supported.")
    target_dirs = dirs_by_kind[kind]

    logger.info(f"Kind = {kind}, Num of dirs = {len(target_dirs)}")

    all_hr_omegas = []
    for dir_path in sorted(target_dirs):
        for i in range(max_ens_index):
            file_paths = sorted(glob.glob(f"{dir_path}/*_hr_omega_{i:02}.npy"))
            if not file_paths:
                # A directory may hold fewer than `max_ens_index` members;
                # np.concatenate([]) would raise, so skip the missing index.
                logger.debug(f"No hr_omega files for ens index {i:02} in {dir_path}")
                continue

            hr_omegas = []
            for file_path in file_paths:
                data = np.load(file_path)
                # This is to avoid overlapping at the start/end point: every
                # file after the first drops its first time step.
                if hr_omegas:
                    data = data[1:]
                hr_omegas.append(data)

            # Concat along time axis
            all_hr_omegas.append(np.concatenate(hr_omegas, axis=0))
            if len(all_hr_omegas) == num_hr_omega_sets:
                break
        if len(all_hr_omegas) == num_hr_omega_sets:
            break

    if not all_hr_omegas:
        raise FileNotFoundError(
            f"No *_hr_omega_XX.npy files found for kind={kind} under {cfd_dir_path}"
        )

    # Stack along a new batch axis
    return torch.from_numpy(np.stack(all_hr_omegas, axis=0)).to(torch.float64)


def get_obs_matrix(obs: torch.Tensor) -> torch.Tensor:
    """Build the observation operator that selects the non-NaN points of ``obs``.

    Args:
        obs: HR field of shape ``(HR_NX, HR_NY)``; NaN marks unobserved points.

    Returns:
        float64 matrix of shape ``(num_obs, HR_NX * HR_NY)`` on ``DEVICE``.
        Row ``k`` has a single 1.0 at the flattened (``reshape(-1)``) index of
        the k-th observed point, and indices appear in ascending order.
    """
    assert obs.shape == (HR_NX, HR_NY)

    # Flattened indices of observed (non-NaN) grid points, ascending.
    obs_indices = torch.nonzero(~torch.isnan(obs).reshape(-1), as_tuple=True)[0]
    num_obs = obs_indices.numel()

    obs_matrix = torch.zeros(num_obs, HR_NX * HR_NY, dtype=torch.float64, device=DEVICE)
    # A single vectorised write instead of `num_obs` Python-level element writes,
    # which are slow on a GPU tensor.
    obs_matrix[torch.arange(num_obs, device=DEVICE), obs_indices.to(DEVICE)] = 1.0

    p = 100 * num_obs / (HR_NX * HR_NY)
    logger.debug(f"observation prob = {p} [%]")

    return obs_matrix


def calc_errors(all_gt: torch.Tensor, all_enkf: torch.Tensor) -> pd.DataFrame:
    """Compare EnKF results with the ground truth on the LR and on the HR grid.

    ``all_gt`` is compared at its own resolution ("HR") and bicubic-downsampled
    to ``(LR_NX, LR_NY)`` ("LR"). ``all_enkf`` is used as-is for "LR" and
    bicubic-upsampled to ``(HR_NX, HR_NY)`` for "HR". Both inputs are 4-D
    (batch, time, x, y) and must match in shape after resampling.

    Returns:
        DataFrame with one row per time step and the columns ``LR_MAER``,
        ``LR_SSIMLoss``, ``HR_MAER`` and ``HR_SSIMLoss``:

        * MAER: mean |gt - pred| over x, y divided by mean |gt| over x, y,
          averaged over the batch. It is inf/NaN if a gt field is all zeros.
        * SSIMLoss: 1 - SSIM, averaged over batch, x and y.
    """
    ssim_func = SSIM(size_average=False, use_gauss=True)
    df_errors = pd.DataFrame()

    for kind in ["LR", "HR"]:
        if kind == "HR":
            gt = all_gt
            pred = interpolate_time_series(all_enkf, HR_NX, HR_NY, "bicubic")
        else:
            gt = interpolate_time_series(all_gt, LR_NX, LR_NY, "bicubic")
            pred = all_enkf

        # batch, time, x, and y dims
        assert gt.ndim == 4
        assert gt.shape == pred.shape

        mae = torch.mean(torch.abs(gt - pred), dim=(-2, -1))  # mean over x and y
        # Normaliser is the mean absolute value of gt (not an RMS).
        nrms = torch.mean(torch.abs(gt), dim=(-2, -1))
        maer = torch.mean(mae / nrms, dim=0)  # mean over batch dim

        ssim = ssim_func(
            img1=gt.to(DEVICE),
            img2=pred.to(DEVICE),
        )
        ssim = torch.mean(ssim, dim=(0, -2, -1))  # mean over batch, x, and y
        ssim = 1.0 - ssim

        assert ssim.shape == maer.shape

        # .cpu() on both: inputs may live on the GPU, and .numpy() needs host memory.
        df_errors[f"{kind}_MAER"] = maer.cpu().numpy()
        df_errors[f"{kind}_SSIMLoss"] = ssim.cpu().numpy()

    return df_errors


def get_sys_noise_generator(num_hr_omega_sets: int = 250, eps: float = 1e-12):
    """Build one zero-mean Gaussian system-noise distribution per time index.

    Up to ``num_hr_omega_sets`` HR training series are loaded and bicubic-downsampled
    to the LR grid. Their anomalies are taken w.r.t. the mean over the batch.
    For each time index t, the covariance is the batch average of the outer
    product of the flattened LR anomalies at t. It is symmetrised and a diagonal
    ``eps`` is added for numerical positive definiteness.

    Args:
        num_hr_omega_sets: Number of training series used for the statistics.
        eps: Diagonal jitter added to every covariance.

    Returns:
        List of ``MultivariateNormal`` (one per time index) over vectors of
        length ``LR_NX * LR_NY``.
    """
    hr_omegas = load_hr_data(
        root_dir=ROOT_DIR,
        cfd_dir_name=CFD_DIR_NAME,
        train_valid_test_ratios=TRAIN_VALID_TEST_RATIOS,
        kind="train",
        num_hr_omega_sets=num_hr_omega_sets,
    )
    # dims = batch, time, x, and y
    logger.info(f"hr_omega shape = {hr_omegas.shape}")

    lr_omegas = interpolate_time_series(hr_omegas, LR_NX, LR_NY, "bicubic")
    lr_omegas = lr_omegas - torch.mean(lr_omegas, dim=0, keepdim=True)

    lr_omegas = lr_omegas.reshape(lr_omegas.shape[:2] + (-1,))
    # dims = batch, time, and space

    del hr_omegas
    gc.collect()

    # Batch-averaged outer product per time step. einsum contracts the batch dim
    # directly, so no (batch, time, space, space) intermediate is allocated
    # (the broadcast-product form needs tens of GB at the default sizes).
    num_samples = lr_omegas.shape[0]
    all_covs = torch.einsum("bts,btu->tsu", lr_omegas, lr_omegas) / num_samples

    # Assure cov is symmetric.
    all_covs = (all_covs + all_covs.transpose(-2, -1)) / 2.0

    # Assure positive definiteness (jitter in the covariance's own dtype/device).
    jitter = torch.full(
        (all_covs.shape[-1],), eps, dtype=all_covs.dtype, device=all_covs.device
    )
    all_covs = all_covs + torch.diag(jitter)

    loc = torch.zeros(all_covs.shape[-1], dtype=torch.float64)
    return [MultivariateNormal(loc, cov) for cov in all_covs]

# Perform EnKF data assimilation

In [None]:
# Estimate the per-time system-noise distributions once (loads training data).
sys_noise_generators = get_sys_noise_generator()

# Dedicated RNG passed to the assimilation step (rand_generator=...).
torch_rand_generator = torch.Generator()
torch_rand_generator.manual_seed(SEED)

_ = gc.collect()

In [None]:
# Seed identifying the SRDA run whose config name and observation files are used below.
OBS_SRDA_SEED = 221_958

In [None]:
# Main experiment loop: one EnKF run per (observation grid interval, UHR seed).
#
# For each pair, the SRDA observation archive is loaded and an LR ensemble is
# initialised from a low-pass-filtered HR initial state. The ensemble is then
# integrated for MAX_TIME_INDEX_FOR_INTEGRATION cycles (LR_NT steps of LR_DT
# each), and HR observations are assimilated every ASSIMILATION_PERIOD cycles.
# Outputs:
#   * output_lr_file_path: LR ensemble states stacked along dim 1 (time).
#   * output_hr_file_path: pickle of {cycle: ensemble-mean HR analysis}.
# Pairs whose outputs already exist, or whose SRDA archive is missing, are skipped.
#
# NOTE(review): set_seeds(SEED) is called per experiment, but
# `torch_rand_generator` (passed to the assimilation step) is created once in
# an earlier cell and never reseeded here. An experiment's result can therefore
# depend on how many earlier experiments actually ran versus were skipped.
# Reseeding it per experiment would restore per-run reproducibility, but would
# also change previously saved results.

# Loop-invariant check: system noise is drawn from
# sys_noise_generators[START_TIME_INDEX + i_cycle] for every integration cycle.
assert len(sys_noise_generators) == (
    MAX_TIME_INDEX_FOR_INTEGRATION + START_TIME_INDEX + 1
)

# Running experiment counter, presumably used with the commented-out filter
# below to shard runs across processes/GPUs (TODO confirm).
i_simulation_cuda = 0

for GRID_INTERVAL in tqdm([4, 6, 8, 12]):
    N_ENS = int(ENKF_PREFERENCES[GRID_INTERVAL]["N_ENS"])
    LOCALIZE_DX = ENKF_PREFERENCES[GRID_INTERVAL]["LOCALIZE_DX"]
    SYS_NOISE_FACTOR = ENKF_PREFERENCES[GRID_INTERVAL]["SYS_NOISE_FACTOR"]
    INIT_SYS_NOISE_FACTOR = ENKF_PREFERENCES[GRID_INTERVAL]["INIT_SYS_NOISE_FACTOR"]
    OBS_PERTURB_STD = ENKF_PREFERENCES[GRID_INTERVAL]["OBS_PERTURB_STD"]

    # Same localization scale in x and y.
    LOCALIZE_DY = LOCALIZE_DX
    LR_CFD_CONFIG["ne"] = N_ENS
    LR_CFD_CONFIG["n_ens"] = N_ENS

    localization_matrix = calc_localization_matrix(
        nx=HR_NX, ny=HR_NY, d_x=LOCALIZE_DX, d_y=LOCALIZE_DY
    ).to(DEVICE)

    # UHR seeds 9999 down to 9950 (50 experiments per grid interval).
    for I_SEED_UHR in tqdm(range(9999, 9949, -1)):
        i_simulation_cuda += 1
        # if i_simulation_cuda % 4 != 0:
        #     continue

        logger.info(f"grid interval = {GRID_INTERVAL}, SEED = {I_SEED_UHR}")

        UHR_RESULT_DIR = f"{ROOT_DIR}/data/pytorch/CFD/jet09/seed{I_SEED_UHR:05}"

        CONFIG_DIR = f"{ROOT_DIR}/pytorch/config/perform_ST_SRDA"
        CONFIG_NAME = f"lt4og{GRID_INTERVAL:02}_on1e-01_ep1000_lr1e-04_scT_bT_muT_a02_b02_sd{OBS_SRDA_SEED}"
        CONFIG_PATH = f"{CONFIG_DIR}/{CONFIG_NAME}.yml"

        output_lr_file_path = f"{TMP_DATA_DIR}/UHR_seed_{I_SEED_UHR:05}_og{GRID_INTERVAL:02}_SRDA_seed_{OBS_SRDA_SEED}_ens_bicubic_lr.npy"
        output_hr_file_path = f"{TMP_DATA_DIR}/UHR_seed_{I_SEED_UHR:05}_og{GRID_INTERVAL:02}_SRDA_seed_{OBS_SRDA_SEED}_ens_bicubic_mean_hr.pickle"
        if os.path.exists(output_lr_file_path) and os.path.exists(output_hr_file_path):
            logger.info("Results already exist. So skip.")
            continue

        output_obs_npz_file_path = f"{TMP_DATA_DIR}/UHR_seed_{I_SEED_UHR:05}_og{GRID_INTERVAL:02}_{CONFIG_NAME}.npz"
        if not os.path.exists(output_obs_npz_file_path):
            logger.info("SRDA result does not exist.")
            continue

        # Close the .npz archive once the HR observations have been read into memory.
        with np.load(output_obs_npz_file_path) as all_data:
            hr_obs = torch.from_numpy(all_data["hr_obs"])

        # Reseed before building the initial state and drawing system noise
        # (set_seeds is project code; presumably it seeds torch's global RNG).
        set_seeds(SEED, use_deterministic=True)
        init_hr_omega = get_initial_hr_omega(
            ne=1, seed=I_SEED_UHR + GRID_INTERVAL * 100
        ).squeeze()

        # Build the LR model and its forcing with INFO logging silenced, and
        # restore the level even if construction raises.
        logger.setLevel(WARNING)
        try:
            lr_model = TorchSpectralModel2D(**LR_CFD_CONFIG)
            _, lr_forcing = calc_jet_forcing(**LR_CFD_CONFIG)
        finally:
            logger.setLevel(INFO)

        assert lr_forcing.shape == (N_ENS, LR_NX, LR_NY)

        initialize_lr_model(
            hr_omega0=init_hr_omega,
            lr_forcing=lr_forcing,
            lr_model=lr_model,
            **LR_CFD_CONFIG,
        )

        lr_enkfs = []
        dict_hr_analysis = {}

        for i_cycle in tqdm(range(MAX_TIME_INDEX_FOR_INTEGRATION)):

            # Data assimilation
            if i_cycle > 0 and i_cycle % ASSIMILATION_PERIOD == 0:
                obs = hr_obs[i_cycle].to(torch.float64)
                obs_matrix = get_obs_matrix(obs)

                # This is to avoid nan when observation operator acts.
                obs = torch.nan_to_num(obs, nan=1e10)

                all_hr_analysis, _ = assimilate_with_existing_data(
                    hr_omega=obs.to(DEVICE),
                    lr_ens_model=lr_model,
                    obs_matrix=obs_matrix,
                    obs_noise_std=OBS_PERTURB_STD,
                    inflation=INFLATION,
                    rand_generator=torch_rand_generator,
                    localization_matrix=localization_matrix,
                    return_hr_analysis=True,
                )

                # Mean over batch (ensemble dim)
                hr_analysis = torch.mean(all_hr_analysis, dim=0)
                assert hr_analysis.shape == (HR_NX, HR_NY)
                dict_hr_analysis[i_cycle] = hr_analysis.cpu().to(torch.float32).numpy()

            lr_enkfs.append(lr_model.omega.cpu().clone())

            # Add additive system noise (at cycle 0, and after each assimilation
            # cycle when no inflation is used).
            if i_cycle == 0 or (
                INFLATION == 1.0 and i_cycle % ASSIMILATION_PERIOD == 0
            ):
                noise = sys_noise_generators[START_TIME_INDEX + i_cycle].sample([N_ENS])
                noise = noise.reshape(N_ENS, LR_NX, LR_NY)
                # Remove the sample mean so the noise does not shift the ensemble mean.
                noise = noise - torch.mean(noise, dim=0, keepdim=True)

                factor = INIT_SYS_NOISE_FACTOR if i_cycle == 0 else SYS_NOISE_FACTOR
                omega = lr_model.omega + factor * noise.to(DEVICE)

                lr_model.initialize(t0=lr_model.t, omega0=omega)
                lr_model.calc_grid_data()

            lr_model.time_integrate(dt=LR_DT, nt=LR_NT, hide_progress_bar=True)
            lr_model.calc_grid_data()

        write_pickle(dict_hr_analysis, output_hr_file_path)

        # Stack along time dim
        lr_enkfs = torch.stack(lr_enkfs, dim=1).to(torch.float32).numpy()
        np.save(output_lr_file_path, lr_enkfs)