In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import sys
from logging import DEBUG, INFO, WARNING, StreamHandler, getLogger

# Configure the root logger (getLogger() without a name returns the root logger).
logger = getLogger()
# Attach a stdout handler only when none is present, so that re-running this
# cell does not register duplicate handlers (and duplicate log lines).
if not logger.hasHandlers():
    logger.addHandler(StreamHandler(sys.stdout))
# Emit INFO and above; DEBUG / WARNING are imported so the level is easy to switch.
logger.setLevel(INFO)

# Import libraries

In [None]:
import gc
import glob
import os
import pathlib
import random
import time
import typing

import numpy as np
import torch
from src.dataloader import split_file_paths
from src.utils import read_pickle, set_seeds, write_pickle

if "ipykernel" in sys.modules:
    from tqdm.notebook import tqdm
else:
    from tqdm import tqdm

In [None]:
# The environment variable is set before seeding so that it is already in place
# when deterministic mode is requested.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = r":4096:8"  # to make calculations deterministic
# Project helper from src.utils; presumably seeds the RNGs in use and enables
# deterministic algorithms -- confirm against its implementation.
set_seeds(42, use_deterministic=True)

# Define constants

In [None]:
# Project root, taken as the resolved parent directory of the path stored in the
# PYTHONPATH environment variable. Raises KeyError when PYTHONPATH is not set.
# NOTE(review): assumes PYTHONPATH holds a single path (no os.pathsep-separated
# list) -- confirm against the project's environment setup.
ROOT_DIR = str((pathlib.Path(os.environ["PYTHONPATH"]) / "..").resolve())
# Bare expression: shows the resolved path as the cell output.
ROOT_DIR

In [None]:
# Directory that holds generated artifacts (the covariance pickle is written here).
TMP_DATA_DIR = f"{ROOT_DIR}/data/pytorch"
# Create the directory if needed; exist_ok makes the cell safe to re-run.
os.makedirs(TMP_DATA_DIR, exist_ok=True)

In [None]:
# Device identifier. Not referenced by the cells visible here -- presumably kept
# for consistency with other notebooks; TODO confirm whether it is still needed.
DEVICE = "cpu"

In [None]:
# Sub-directory of `{ROOT_DIR}/data/pytorch/CFD` that holds the simulation output.
CFD_DIR_NAME = "jet02"
# Ratios handed to `split_file_paths` to split the simulation directories into
# train / valid / test (in that order).
TRAIN_VALID_TEST_RATIOS = [0.7, 0.2, 0.1]
# NOTE(review): not referenced by the code visible here --
# `get_cov_for_sys_noise_generator` passes kind="train" literally. Confirm intent.
USED_DATA_KIND = "train"
# Stride used to subsample the time axis before the covariances are computed.
ASSIMILATION_PERIOD = 4

# Define methods

In [None]:
def load_hr_data(
    root_dir: str,
    cfd_dir_name: str,
    train_valid_test_ratios: typing.List[float],
    kind: str,
    num_hr_omega_sets: int,
    max_ens_index: int = 20,
) -> torch.Tensor:
    """Load high-resolution omega time series and stack them into one tensor.

    The sub-directories of ``{root_dir}/data/pytorch/CFD/{cfd_dir_name}`` are
    split into train/valid/test by ``split_file_paths``. Within the split
    selected by ``kind``, every ensemble index ``i`` in ``range(max_ens_index)``
    contributes one series, built by concatenating all ``*_hr_omega_{i:02}.npy``
    files of a directory (in sorted order) along axis 0.

    Args:
        root_dir: Project root directory.
        cfd_dir_name: Name of the CFD result directory to read.
        train_valid_test_ratios: Split ratios forwarded to ``split_file_paths``.
        kind: Which split to load: ``"train"``, ``"valid"`` or ``"test"``.
        num_hr_omega_sets: Loading stops as soon as this many series have been
            collected. If that count is never reached (e.g. it exceeds the
            available data, or is not positive), all available series are
            returned.
        max_ens_index: Number of ensemble indices (0 .. max_ens_index - 1)
            looked up in each directory.

    Returns:
        A ``torch.float64`` tensor of the collected series stacked along a new
        leading axis, i.e. (series, concatenated axis 0 of the files, ...).

    Raises:
        ValueError: If ``kind`` is not supported, or if no matching file exists
            in the selected split. ``np.stack`` also raises ``ValueError`` when
            the collected series do not all have the same shape.
    """
    # Validate before touching the file system so a typo fails immediately.
    if kind not in ("train", "valid", "test"):
        raise ValueError(f"{kind} is not supported.")

    cfd_dir_path = f"{root_dir}/data/pytorch/CFD/{cfd_dir_name}"
    logger.info(f"CFD dir path = {cfd_dir_path}")

    data_dirs = sorted([p for p in glob.glob(f"{cfd_dir_path}/*") if os.path.isdir(p)])

    train_dirs, valid_dirs, test_dirs = split_file_paths(
        data_dirs, train_valid_test_ratios
    )
    target_dirs = {"train": train_dirs, "valid": valid_dirs, "test": test_dirs}[kind]

    logger.info(f"Kind = {kind}, Num of dirs = {len(target_dirs)}")

    all_hr_omegas = []
    has_enough_sets = False

    for dir_path in sorted(target_dirs):
        for i in range(max_ens_index):
            file_paths = sorted(glob.glob(f"{dir_path}/*_hr_omega_{i:02}.npy"))

            if not file_paths:
                # np.concatenate raises on an empty list, so skip ensemble
                # members that are absent from this directory.
                logger.warning(
                    f"No hr_omega file for ensemble index {i:02} in {dir_path}; skipped."
                )
                continue

            hr_omegas = []
            for file_index, file_path in enumerate(file_paths):
                data = np.load(file_path)

                # Consecutive files share their boundary sample, so the first
                # sample of every file but the first is dropped.
                if file_index > 0:
                    data = data[1:]
                hr_omegas.append(data)

            # Concat along time axis
            all_hr_omegas.append(np.concatenate(hr_omegas, axis=0))

            if len(all_hr_omegas) == num_hr_omega_sets:
                has_enough_sets = True
                break

        if has_enough_sets:
            break

    if not all_hr_omegas:
        raise ValueError(
            f"No hr_omega data was found under {cfd_dir_path} for kind = {kind}."
        )

    # Stack along a new leading (batch) axis
    ret = np.stack(all_hr_omegas, axis=0)
    return torch.from_numpy(ret).to(torch.float64)


# Set `num_hr_omega_sets` = 50 for the paper
def get_cov_for_sys_noise_generator(num_hr_omega_sets: int = 10, eps: float = 1e-10):
    """Estimate one covariance-like matrix per assimilation step from training data.

    Loads up to ``num_hr_omega_sets`` training series via ``load_hr_data``,
    flattens everything after the first two axes into one vector per sample,
    keeps every ``ASSIMILATION_PERIOD``-th step of the second axis, and averages
    the outer product of each vector with itself over the first axis.

    Note that no mean is subtracted, so the result is the uncentered second
    moment rather than a centered covariance.

    Args:
        num_hr_omega_sets: Number of series requested from ``load_hr_data``.
        eps: Value added to every diagonal entry so the matrices are shifted
            towards positive definiteness.

    Returns:
        A tensor of shape (kept steps, n, n), where n is the length of the
        flattened vector; each matrix is symmetric.
    """
    series = load_hr_data(
        ROOT_DIR,
        CFD_DIR_NAME,
        TRAIN_VALID_TEST_RATIOS,
        kind="train",
        num_hr_omega_sets=num_hr_omega_sets,
    )
    # dims = batch, time, x, and y

    logger.info(series.shape)

    # Collapse the trailing (spatial) dims into one vector, then subsample time.
    flattened = series.reshape(*series.shape[:2], -1)
    subsampled = flattened[:, ::ASSIMILATION_PERIOD]

    # Outer product of every vector with itself, averaged over the batch dim.
    outer_products = subsampled.unsqueeze(-2) * subsampled.unsqueeze(-1)
    second_moments = outer_products.mean(dim=0)

    # Remove round-off asymmetry by averaging each matrix with its transpose.
    symmetric = (second_moments + second_moments.transpose(1, 2)) / 2.0

    # Small diagonal shift, broadcast over the leading (time) dim.
    num_features = symmetric.shape[-1]
    diagonal_shift = torch.diag(torch.full(size=(num_features,), fill_value=eps))

    return symmetric + diagonal_shift

# Generate system noise covariances in advance

In [None]:
# Cache for the covariance matrices: computing them is expensive, so the result
# is stored once and re-used on later runs.
# NOTE: the cache is not keyed on `num_hr_omega_sets`, `CFD_DIR_NAME` or
# `ASSIMILATION_PERIOD`; delete the file after changing any of them.
cov_file_path = f"{TMP_DATA_DIR}/sys_noise_covs.pickle"

if os.path.exists(cov_file_path):
    # Re-use the stored result instead of recomputing and overwriting it.
    all_covs = read_pickle(cov_file_path)
else:
    all_covs = get_cov_for_sys_noise_generator()
    write_pickle(all_covs, cov_file_path)

In [None]:
# Sanity check: display the shape of the covariance tensor
# (expected to be (time steps, n, n) -- see get_cov_for_sys_noise_generator).
all_covs.shape