In [None]:
# Load IPython's autoreload extension and set it to mode 2, so that edits to
# imported modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the cell output.
%matplotlib inline

In [None]:
import sys
from logging import INFO, StreamHandler, getLogger

# Configure the root logger so that INFO-and-above records are written to stdout.
logger = getLogger()

# Re-running this cell must not stack another stdout handler on the root logger:
# every extra handler would print each record one more time. Reuse the handler
# that is already attached to the current sys.stdout, if there is one.
log_handler = next(
    (
        _h
        for _h in logger.handlers
        if isinstance(_h, StreamHandler) and getattr(_h, "stream", None) is sys.stdout
    ),
    None,
)
if log_handler is None:
    log_handler = StreamHandler(sys.stdout)
    logger.addHandler(log_handler)
logger.setLevel(INFO)

# Import libraries

In [None]:
import os
from glob import glob
from logging import INFO, StreamHandler, getLogger

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import HTML, display
from tqdm.notebook import tqdm

# Project root; the input/output data directories and the local source path
# used by this notebook are all derived from it.
ROOT_DIR = "/workspace"

# Make the project's own source directory importable from this notebook.
sys.path.append(ROOT_DIR + "/pytorch/src")
from io_fortran_data import NX_T10, NX_T21, NX_T42, NY_T10, NY_T21, NY_T42, read_simulation_results

# Define constants

In [None]:
# Directory the simulation results are read from, and directory the per-snapshot
# .npy files are written to.
INPUT_DIR = f"{ROOT_DIR}/data/fortran/barotropic_instability_spectral_nudging"
OUTPUT_DIR = f"{ROOT_DIR}/data/pytorch/barotropic_instability_spectral_nudging/DL_data/vorticity"

# Per data kind: which simulation configurations to export and which snapshot
# indices to take (time_start .. time_end inclusive, every time_interval).
# "train_valid" additionally carries the fraction of seeds that goes to training.
DICT_DL_DATA_INFO = dict(
    train_valid=dict(
        fortran_configs=["default_positive", "default_negative"],
        time_start=250,
        time_end=510,
        time_interval=20,
        train_size_ratio=0.7,
    ),
    test=dict(
        fortran_configs=["shear_with_0p40_positive", "shear_with_0p40_negative"],
        time_start=250,
        time_end=510,
        time_interval=20,
    ),
)

# Split train-valid-test data

In [None]:
# Export every selected snapshot of every simulation as its own .npy file under
# OUTPUT_DIR/<split>/<resolution>/, where <split> is "train", "valid" or "test".
for data_kind, data_info in DICT_DL_DATA_INFO.items():
    for config_name in data_info["fortran_configs"]:
        # Seeds already exported for this configuration; used to verify that
        # no seed is handled twice (i.e. ends up in two splits).
        _seeds = set()

        dat_paths = pd.Series(glob(f"{INPUT_DIR}/{config_name}/*.dat"))

        # The seed is taken from the last "_"-separated token of each file name
        # ("seed<N>.dat"); duplicates are dropped before converting to integers.
        seed_tokens = dat_paths.apply(
            lambda x: os.path.basename(x).split("_")[-1].replace("seed", "").replace(".dat", "")
        )
        all_seeds = seed_tokens.drop_duplicates().astype(np.uint64).sort_values().to_list()
        print(f"{config_name} has {len(all_seeds)} simulation sets.")

        # Pair each split name with the seeds assigned to it.
        if data_kind == "train_valid":
            # The lowest seeds go to training, the remaining ones to validation.
            n_train = int(len(all_seeds) * data_info["train_size_ratio"])
            splits = [("train", all_seeds[:n_train]), ("valid", all_seeds[n_train:])]
        else:
            splits = [(data_kind, all_seeds)]

        for kind, seeds in splits:
            print(f"{config_name} {kind}: num simulations = {len(seeds)}")

            for seed in tqdm(seeds):
                assert seed not in _seeds
                _seeds.add(seed)

                # The fourth returned item is not used in this notebook.
                field_T10, field_T21, field_T42, _ = read_simulation_results(
                    INPUT_DIR, config_name, seed, use_T85_data=False
                )

                fields_by_resolution = (
                    ("T10", field_T10),
                    ("T21", field_T21),
                    ("T42", field_T42),
                )
                for resolution, field in fields_by_resolution:
                    out_dir = os.path.join(OUTPUT_DIR, kind, resolution)
                    os.makedirs(out_dir, exist_ok=True)

                    # Stop value is time_end + time_interval so that time_end is included.
                    snapshot_indices = range(
                        data_info["time_start"],
                        data_info["time_end"] + data_info["time_interval"],
                        data_info["time_interval"],
                    )
                    for it in snapshot_indices:
                        # Slice at index `it` along the first axis of the field.
                        snapshot = field[it, :, :]
                        np.save(
                            os.path.join(out_dir, f"seed_{seed}_time_{it}_{config_name}.npy"),
                            snapshot,
                        )
        print(f"Num _seeds = {len(_seeds)}")
        # Every seed found on disk must have been exported exactly once.
        assert len(_seeds) == len(all_seeds)
        print("")

# Check train-valid-test data

In [None]:
# Coordinate grids per resolution: x spans [0, 2*pi] with NX points and
# y spans [0, pi] with NY + 1 points; meshgrid uses "ij" indexing, so the
# first array axis follows x.
dict_grid_fields = {}

_grid_sizes = {
    "T10": (NX_T10, NY_T10),
    "T21": (NX_T21, NY_T21),
    "T42": (NX_T42, NY_T42),
}
for _name, (_nx, _ny) in _grid_sizes.items():
    xs = np.linspace(0, 2 * np.pi, _nx)
    ys = np.linspace(0, np.pi, _ny + 1)
    dict_grid_fields[_name] = np.meshgrid(xs, ys, indexing="ij")

In [None]:
resolutions = ["T10", "T21", "T42"]

# Figure layout: one column per snapshot (at most `ncols` per figure),
# one row per resolution.
ncols = 10
nrows = len(resolutions)


def _first_seed_for_config(file_paths, config_name):
    """Return the numerically smallest seed that has snapshots for ``config_name``.

    File names are expected to follow the pattern used when the snapshots were
    saved in this notebook: ``seed_<seed>_time_<it>_<config_name>.npy``.

    Args:
        file_paths: Paths of the ``.npy`` snapshot files of one split/resolution.
        config_name: Configuration name the file name must carry.

    Returns:
        The seed exactly as it appears in the file name (a string of digits).

    Raises:
        FileNotFoundError: If none of ``file_paths`` belongs to ``config_name``.
    """
    expected_tail = f"{config_name}.npy"
    seeds = set()
    for path in file_paths:
        # maxsplit=4 keeps a configuration name containing "_" in one piece.
        parts = os.path.basename(path).split("_", 4)
        if (
            len(parts) == 5
            and parts[0] == "seed"
            and parts[1].isdigit()
            and parts[2] == "time"
            and parts[4] == expected_tail
        ):
            seeds.add(parts[1])
    if not seeds:
        raise FileNotFoundError(
            f"No snapshot files found for config '{config_name}' "
            f"among {len(file_paths)} candidate files."
        )
    # Compare as integers (a plain string sort would put "10" before "9");
    # the string itself breaks ties deterministically.
    return min(seeds, key=lambda s: (int(s), s))


for data_kind in ["train", "valid", "test"]:
    # "train" and "valid" both use the settings of the "train_valid" entry.
    data_info = (
        DICT_DL_DATA_INFO["test"] if data_kind == "test" else DICT_DL_DATA_INFO["train_valid"]
    )
    display(HTML(f"<h2>{data_kind}</h2>"))

    # List the first resolution's directory once; the same file names are
    # written for every resolution when the snapshots are saved.
    file_paths = glob(os.path.join(OUTPUT_DIR, data_kind, resolutions[0], "*.npy"))
    print(f"all file num = {len(file_paths)}")

    for config_name in data_info["fortran_configs"]:
        display(HTML(f"<h3>{config_name}</h3>"))

        # Choose the seed per configuration: a seed present for one
        # configuration is not guaranteed to exist for another one.
        seed = _first_seed_for_config(file_paths, config_name)
        print(f"seed = {seed}")

        # One figure per group of `ncols` consecutive snapshots.
        for it_start in range(
            data_info["time_start"],
            data_info["time_end"] + data_info["time_interval"],
            data_info["time_interval"] * ncols,
        ):
            fig, axes = plt.subplots(nrows, ncols, figsize=[25, 7])
            # Hide all axes decorations, including those of panels left empty.
            for ax in np.ravel(axes):
                ax.axis("off")
            # axes.transpose() yields one column of axes (all resolutions) per snapshot.
            for it, _axes in zip(
                range(
                    it_start,
                    it_start + data_info["time_interval"] * ncols,
                    data_info["time_interval"],
                ),
                axes.transpose(),
            ):
                # The last figure may have fewer than `ncols` snapshots.
                if it > data_info["time_end"]:
                    break
                for resolution, ax in zip(resolutions, _axes):
                    # The grid coordinates are only needed by the contourf
                    # alternative kept below.
                    X, Y = dict_grid_fields[resolution]
                    Z = np.load(
                        os.path.join(
                            OUTPUT_DIR,
                            data_kind,
                            resolution,
                            f"seed_{seed}_time_{it}_{config_name}.npy",
                        )
                    )
                    # ax.contourf(X, Y, Z, cmap="rainbow")
                    # NOTE(review): interpolation=None falls back to the rcParams
                    # default; the string "none" would disable interpolation.
                    # Confirm which one is intended.
                    ax.imshow(Z.transpose(), cmap="rainbow", vmin=-2, vmax=2, interpolation=None)
                    ax.set_title(f"{resolution}, it = {it}")
            plt.tight_layout()
            plt.show()
            # Release the figure explicitly so figures do not accumulate in memory.
            plt.close(fig)