In [None]:
# Auto-reload imported modules (e.g. the local `src` package) when their source
# files change, and render matplotlib figures inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import sys
from logging import DEBUG, INFO, WARNING, StreamHandler, getLogger

# Route root-logger records to stdout so they appear in cell output. The
# name-based check keeps a re-run of this cell from stacking duplicate handlers.
logger = getLogger()
if all("StreamHandler" not in str(h) for h in logger.handlers):
    logger.addHandler(StreamHandler(sys.stdout))
logger.setLevel(INFO)

# Import libraries

In [None]:
import gc
import glob
import math
import os
import pathlib
import typing

import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
from cfd_model.interpolator.torch_interpolator import interpolate
from IPython.display import HTML, display
from src.dataloader import (
    _make_dataloaders_vorticity_making_observation_inside_time_series_splitted_with_mixup,
    make_dataloaders_vorticity_making_observation_inside_time_series_splitted,
    split_file_paths,
)
from src.dataset import DatasetMakingObsInsideTimeseriesSplittedWithMixupRandomSampling
from src.utils import set_seeds
from tqdm.notebook import tqdm

In [None]:
# Fixed seed for reproducible runs (e.g. the np.random.beta draws below);
# set_seeds is the project helper imported from src.utils.
SEED = 42
set_seeds(SEED)

# Define constants

In [None]:
ROOT_DIR = str((pathlib.Path(os.environ["PYTHONPATH"]) / "..").resolve())
ROOT_DIR

In [None]:
# Where the CFD tensors for the run named CFD_NAME are stored.
PYTORCH_DATA_DIR = f"{ROOT_DIR}/data/pytorch"
CFD_NAME = "jet26"
CFD_DATA_DIR = f"{PYTORCH_DATA_DIR}/CFD/{CFD_NAME}"

# Train / valid / test split fractions and dataloader batch size.
TRAIN_VALID_TEST_RATIOS = [0.7, 0.2, 0.1]
BATCH_SIZE = 32

# Observation / LR sampling settings passed to the dataset builders
# (obs_grid_interval, obs_time_interval, obs_noise_std, lr_time_interval).
OBS_GRID_INTERVAL = 8
OBS_TIME_INTERVAL = 4
OBS_NOISE_STD = 0.1
LR_TIME_INTERVAL = 2

# Passed as vorticity_bias / vorticity_scale (presumably a normalisation of the
# vorticity field — confirm in src.dataset).
BIAS = -14.5
SCALE = 29

# Define helper functions

In [None]:
def plot(
    gt,
    obs,
    lr,
    dt: float = 0.25,
    input_interval: int = 1,
    missing_value: float = np.nan,
):
    """Draw HR / observed-HR / interpolated-LR panels for a time series.

    One figure (a row of three panels) is shown for every ``input_interval``-th
    time index of ``gt``; the k-th figure uses ``lr[k]`` as its LR frame.

    Args:
        gt: ground-truth HR tensor indexed by time; after ``squeeze()`` each
            frame's first two axes are treated as ``nx`` and ``ny``.
        obs: observed HR tensor, indexed by time like ``gt``.
        lr: LR tensor indexed by LR sample; ``lr[k:k+1].squeeze(1)`` is passed
            to ``interpolate`` (assumes axis 1 is a size-1 channel — TODO confirm).
        dt: time between consecutive ``gt`` indices, used in the HR title.
        input_interval: stride over ``gt`` time indices between LR samples.
        missing_value: value marking unobserved points in ``obs``; NaN is
            detected with ``torch.isnan``.

    Side effects:
        Prints each panel's label and array shape (plus, when ``missing_value``
        is NaN, the fraction of observed grid points) and calls ``plt.show()``.
    """
    n_times = gt.shape[0]

    for k_lr, i_time in enumerate(range(0, n_times, input_interval)):
        t = dt * i_time

        hr_frame = gt[i_time].squeeze()
        nx = hr_frame.shape[0]
        ny = hr_frame.shape[1]

        obs_frame = obs[i_time].squeeze()
        # LR frame brought onto the HR grid so it can be compared point-wise.
        lr_frame = interpolate(lr[k_lr : k_lr + 1].squeeze(1), nx=nx, ny=ny).squeeze()
        panels = (("HR", hr_frame), ("HR(obs)", obs_frame), ("LR", lr_frame))

        # Shared colour scale for all panels: 1st-99th percentile of the HR frame.
        hr_values = hr_frame.numpy().flatten()
        vmin_omega = np.quantile(hr_values, 0.01)
        vmax_omega = np.quantile(hr_values, 0.99)

        _, axes = plt.subplots(
            1, len(panels), figsize=[10, 2], sharex=True, sharey=True
        )

        for ax, (label, field) in zip(axes, panels):
            if label == "HR":
                title = f"t = {t:.2f}\n{label}"
            elif label.startswith("LR"):
                mae = torch.mean(torch.abs(hr_frame - field)).item()
                title = f"MAE = {mae:.4f}\n{label} "
            else:
                # Observation panel: zero the error at missing points and report
                # the worst absolute deviation from HR at the observed ones.
                missing_is_nan = np.isnan(missing_value)
                is_missing = (
                    torch.isnan(field) if missing_is_nan else field == missing_value
                )
                diff = torch.where(
                    is_missing, torch.zeros_like(field), field - hr_frame
                )
                if missing_is_nan:
                    print(
                        "non-missing grid ratio",
                        1.0 - torch.sum(is_missing) / (nx * ny),
                    )
                max_abs_diff = torch.max(torch.abs(diff)).item()
                title = f"max diff = {max_abs_diff:.4f}\n{label}"

            image = np.squeeze(field.numpy()).transpose()
            print(label, image.shape)

            grid_x, grid_y = np.meshgrid(
                np.linspace(0, 2 * np.pi, num=image.shape[0], endpoint=False),
                np.linspace(0, np.pi, num=image.shape[1], endpoint=False),
                indexing="ij",
            )
            ax.pcolormesh(
                grid_x, grid_y, image, cmap="bwr", vmin=vmin_omega, vmax=vmax_omega
            )
            ax.set_title(title)

            # Axes carry no useful ticks for these field snapshots.
            ax.xaxis.set_ticklabels([])
            ax.yaxis.set_ticklabels([])
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)

        plt.tight_layout()
        plt.show()

# Get data dirs

In [None]:
# Sorted list of the subdirectories of CFD_DATA_DIR (plain files are skipped).
data_dirs = sorted(filter(os.path.isdir, glob.glob(f"{CFD_DATA_DIR}/*")))

In [None]:
# Number of subdirectories found under CFD_DATA_DIR.
len(data_dirs)

In [None]:
# NOTE(review): the factor 20 is unexplained here — presumably a fixed number
# of items per directory; confirm and give it a named constant.
len(data_dirs) * 20

# Debug dataset

In [None]:
def beta_dist_pdf(alpha, beta, num=1000):
    """Evaluate the Beta(alpha, beta) density on a grid over [0, 1).

    Also prints the mode and variance of the distribution.

    Args:
        alpha: first shape parameter, must be > 0.
        beta: second shape parameter, must be > 0.
        num: number of grid points; the grid is
            ``np.linspace(0, 1, num=num, endpoint=False)``.

    Returns:
        ``(xs, pdfs)``: the grid points and the density evaluated at them.

    Raises:
        ValueError: if ``alpha`` or ``beta`` is not positive.
    """
    if alpha <= 0 or beta <= 0:
        raise ValueError(f"alpha and beta must be positive, got {alpha}, {beta}")

    xs = np.linspace(0, 1, num=num, endpoint=False)
    # For alpha < 1 the density is genuinely infinite at x = 0 (which is on the
    # grid), so silence numpy's divide-by-zero warning there.
    with np.errstate(divide="ignore"):
        unnormalized = np.power(xs, alpha - 1) * np.power(1 - xs, beta - 1)

    # Exact normaliser B(alpha, beta) = Γ(α)Γ(β)/Γ(α+β), computed in log space.
    # A Riemann sum over the grid would be inf (-> nan pdf) when alpha < 1.
    log_beta_fn = math.lgamma(alpha) + math.lgamma(beta) - math.lgamma(alpha + beta)
    pdfs = unnormalized / math.exp(log_beta_fn)

    # (alpha-1)/(alpha+beta-2) is the mode only when both parameters are >= 1 and
    # not both equal to 1 (alpha = beta = 1 is uniform, which made the old code
    # divide by zero); otherwise report nan.
    if alpha >= 1 and beta >= 1 and alpha + beta > 2:
        mode = (alpha - 1) / (alpha + beta - 2)
    else:
        mode = float("nan")
    var = (alpha * beta) / (alpha + beta + 1) / (alpha + beta) ** 2
    print(f"mode = {mode}, var = {var}")

    return xs, pdfs

In [None]:
# Sanity check: the density from beta_dist_pdf should track a density-normalised
# histogram of 10,000 np.random.beta draws, for alpha = 2 and beta = 2, 4, ..., 24.
for alpha in [2.0]:
    for beta in np.arange(2.0, 25.0, 2):
        xs, pdfs = beta_dist_pdf(alpha, beta)
        samples = np.random.beta(a=alpha, b=beta, size=10000)

        plt.plot(xs, pdfs)
        plt.hist(samples, density=True, range=(0, 1), bins=51)
        plt.title(f"alpha = {alpha}, beta = {beta}")
        plt.show()

In [None]:
# Single LR input kind used for the dataset debug below.
lst_lr_names = ["lr_omega_no-noise"]

In [None]:
# First entry of the sorted data_dirs list.
data_dirs[0]

In [None]:
# Build the dataset with DEBUG logging so any debug messages emitted during
# construction are shown; `finally` restores INFO even if construction raises,
# so a failure does not leave the whole notebook logging at DEBUG.
logger.setLevel(DEBUG)
try:
    dataset = DatasetMakingObsInsideTimeseriesSplittedWithMixupRandomSampling(
        data_dirs=data_dirs,
        lr_kind_names=lst_lr_names,
        lr_time_interval=LR_TIME_INTERVAL,
        obs_grid_interval=OBS_GRID_INTERVAL,
        obs_time_interval=OBS_TIME_INTERVAL,
        obs_noise_std=OBS_NOISE_STD,
        use_observation=True,
        vorticity_bias=BIAS,
        vorticity_scale=SCALE,
        use_ground_truth_clamping=True,
        use_mixup=True,
        use_mixup_init_time=False,
        beta_dist_alpha=2.0,
        beta_dist_beta=2.0,
        is_last_obs_missing=True,
        min_start_time_index=-1,
        max_start_time_index=92,
    )
finally:
    logger.setLevel(INFO)

In [None]:
# Number of samples the dataset reports via __len__.
len(dataset)

In [None]:
# HR file paths held by the dataset (presumably the files samples are drawn
# from — confirm in src.dataset).
dataset.hr_file_paths

In [None]:
# Load raw arrays for one index with DEBUG logging. The original cell never
# restored the level, leaving later cells at DEBUG; `finally` resets it to INFO
# whether or not the load succeeds.
logger.setLevel(DEBUG)
try:
    target_lr, source_lr, hr = dataset._load_np_data(777)
finally:
    logger.setLevel(INFO)
target_lr.shape, source_lr.shape, hr.shape

In [None]:
# Fetch one (lr, obs, gt) sample with DEBUG logging; INFO is restored even if
# the lookup raises (e.g. an out-of-range index).
logger.setLevel(DEBUG)
try:
    lr, obs, gt = dataset[3900]
finally:
    logger.setLevel(INFO)
lr.shape, obs.shape, gt.shape

In [None]:
# Visualise the sample fetched above, one figure per LR_TIME_INTERVAL-th HR time
# step (missing_value=0.0: observation gaps presumably zero-filled — confirm).
plot(
    gt=gt,
    obs=obs,
    lr=lr,
    dt=0.25,
    missing_value=0.0,
    input_interval=LR_TIME_INTERVAL,
)

# Debug dataloader

In [None]:
# Re-list the data directories and split them into train / valid / test.
data_dirs = sorted(filter(os.path.isdir, glob.glob(f"{CFD_DATA_DIR}/*")))
train_dirs, valid_dirs, test_dirs = split_file_paths(data_dirs, TRAIN_VALID_TEST_RATIOS)
tuple(map(len, (train_dirs, valid_dirs, test_dirs)))

In [None]:
# Split name -> list of data directories.
dict_data_dirs = dict(train=train_dirs, valid=valid_dirs, test=test_dirs)

In [None]:
# LR input kinds for the dataloader debug below.
lst_lr_names = ["lr_omega_no-noise"]

In [None]:
# Loaders from the mixup-capable builder, but with use_mixup=False, so the
# beta_dist_* values are presumably unused here (confirm in src.dataloader).
mixup_loader_kwargs = dict(
    dict_dir_paths=dict_data_dirs,
    lr_kind_names=lst_lr_names,
    lr_time_interval=LR_TIME_INTERVAL,
    obs_time_interval=OBS_TIME_INTERVAL,
    obs_grid_interval=OBS_GRID_INTERVAL,
    obs_noise_std=OBS_NOISE_STD,
    use_observation=True,
    vorticity_bias=BIAS,
    vorticity_scale=SCALE,
    batch_size=BATCH_SIZE,
    use_mixup=False,
    use_mixup_init_time=False,
    beta_dist_alpha=2.0,
    beta_dist_beta=30.0,
    is_last_obs_missing=True,
    min_start_time_index=-1,
    max_start_time_index=92,
)
dict_dataloaders, _ = _make_dataloaders_vorticity_making_observation_inside_time_series_splitted_with_mixup(
    **mixup_loader_kwargs
)

In [None]:
# For each split, print the shapes of the first batch and plot its i_batch-th sample.
i_batch = 10
for kind in ("train", "valid", "test"):
    display(HTML(f"<h2>{kind}</h2>"))
    lr, obs, hr = next(iter(dict_dataloaders[kind]))
    print(lr.shape, obs.shape, hr.shape)
    plot(
        gt=hr[i_batch],
        obs=obs[i_batch],
        lr=lr[i_batch],
        dt=0.25,
        input_interval=LR_TIME_INTERVAL,
        missing_value=0,
    )

# Check histograms

In [None]:
# Fresh directory listing and train / valid / test split for the histogram check.
data_dirs = sorted(filter(os.path.isdir, glob.glob(f"{CFD_DATA_DIR}/*")))
train_dirs, valid_dirs, test_dirs = split_file_paths(data_dirs, TRAIN_VALID_TEST_RATIOS)
tuple(map(len, (train_dirs, valid_dirs, test_dirs)))

In [None]:
# Split name -> list of data directories.
dict_data_dirs = dict(train=train_dirs, valid=valid_dirs, test=test_dirs)

In [None]:
# LR kinds for each amplitude: a gaussian variant plus +/- sobel variants in y
# and x. The amplitude is written into the name as "0p<int(amp)>".
lst_lr_names = []
for i_amp in [9.0]:
    i = int(i_amp)
    lst_lr_names.extend(
        f"lr_omega_{prefix}0p{i}"
        for prefix in ("gaussian_", "sobel_y_p", "sobel_y_n", "sobel_x_p", "sobel_x_n")
    )
lst_lr_names

In [None]:
# The non-mixup private builder is not imported in the imports cell, so this
# cell previously raised NameError; import it here, where it is used.
from src.dataloader import (
    _make_dataloaders_vorticity_making_observation_inside_time_series_splitted,
)

# Loaders (no mixup) over the degraded LR kinds, used for the histograms below.
(
    dict_dataloaders,
    _,
) = _make_dataloaders_vorticity_making_observation_inside_time_series_splitted(
    dict_dir_paths=dict_data_dirs,
    lr_kind_names=lst_lr_names,
    lr_time_interval=LR_TIME_INTERVAL,
    obs_time_interval=OBS_TIME_INTERVAL,
    obs_grid_interval=OBS_GRID_INTERVAL,
    obs_noise_std=OBS_NOISE_STD,
    use_observation=True,
    vorticity_bias=BIAS,
    vorticity_scale=SCALE,
    batch_size=BATCH_SIZE,
)

In [None]:
# Gather vorticity values from the train + valid loaders for the histograms.
# The last two axes are strided (LR every 2nd point, HR/obs every 8th) to keep
# the sample count manageable; tensors are assumed to be
# (batch, time, var, H, W) — TODO confirm.
# Per-batch arrays are collected and concatenated once: extending a Python list
# value-by-value stores one boxed scalar per element, which is far slower and
# uses many times the memory of a single numpy array.
var_names = ["ω"]
lr_chunks = {name: [] for name in var_names}
obs_chunks = {name: [] for name in var_names}
hr_chunks = {name: [] for name in var_names}

for kind in ["train", "valid"]:
    for lr, obs, hr in tqdm(dict_dataloaders[kind], total=len(dict_dataloaders[kind])):
        for idx_var, name_var in enumerate(var_names):
            # .flatten() always copies, so the full batch tensors are not kept alive.
            lr_chunks[name_var].append(lr[:, :, idx_var, ::2, ::2].numpy().flatten())
            hr_chunks[name_var].append(hr[:, :, idx_var, ::8, ::8].numpy().flatten())
            obs_chunks[name_var].append(obs[:, :, idx_var, ::8, ::8].numpy().flatten())

# Variable name -> 1-D array of values (empty array if no batches were seen).
lr_data, obs_data, hr_data = (
    {
        name: np.concatenate(chunks) if chunks else np.empty(0, dtype=np.float32)
        for name, chunks in chunks_by_var.items()
    }
    for chunks_by_var in (lr_chunks, obs_chunks, hr_chunks)
)

In [None]:
# Histogram each data source (LR / OBS / HR) for the vorticity variable, with
# the bin range clipped to the [0.001 %, 99.999 %] quantiles so extreme
# outliers do not flatten the bins.
fig, axes = plt.subplots(1, 3, figsize=[15, 5])
name_var = "ω"

for ax, name_data, data in zip(axes, ["LR", "OBS", "HR"], [lr_data, obs_data, hr_data]):
    # Convert once (values may be a list or an array) instead of in every nan* call.
    xs = np.asarray(data[name_var])
    if xs.size == 0:
        ax.set_title(f"{name_data}:{name_var}\n(no data)")
        continue

    vmin = np.nanquantile(xs, 0.00001)
    vmax = np.nanquantile(xs, 0.99999)
    # Mean of |x| (a magnitude), not the signed mean — labelled accordingly.
    mean_abs = np.nanmean(np.abs(xs))

    print(
        f"{name_data}:{name_var}, vmin={vmin:.3f}, vmax={vmax:.3f}, mean|x|={mean_abs:.3f}, scale={vmax-vmin:.3f}, len={len(xs)}"
    )
    ax.set_title(
        f"{name_data}:{name_var}\nvmin={vmin:.3f},vmax={vmax:.3f}\nmean|x|={mean_abs:.3f},len={len(xs)}"
    )
    ax.hist(xs, range=(vmin, vmax), bins=21)

fig.tight_layout()
plt.show()

In [None]:
# Optional: uncomment to free the large histogram buffers once the figure
# above is drawn (gc is also imported in the imports cell).
# import gc

# del lr_data, obs_data, hr_data
# _ = gc.collect()

# Make dataloader from config

In [None]:
# Experiment YAML whose settings drive the dataloaders built below.
CONFIG_PATH = f"{ROOT_DIR}/pytorch/config/paper_experiment_06/lt4og12_on1e-01_ep1000_lr1e-04_scT_bT_muT_a02_b02_sd221958.yml"
config = yaml.safe_load(pathlib.Path(CONFIG_PATH).read_text())

In [None]:
# Inspect the loaded experiment configuration.
config

In [None]:
# Build the split dataloaders from ROOT_DIR + config (second return value unused).
dict_dataloaders, _ = make_dataloaders_vorticity_making_observation_inside_time_series_splitted(
    ROOT_DIR, config
)

In [None]:
# HR file paths held by the test split's dataset.
dict_dataloaders["test"].dataset.hr_file_paths

In [None]:
# Same per-split check as above, for the config-built loaders; the LR time
# interval comes from the config instead of LR_TIME_INTERVAL.
i_batch = 16
for kind in ("train", "valid", "test"):
    display(HTML(f"<h2>{kind}</h2>"))
    lr, obs, hr = next(iter(dict_dataloaders[kind]))
    print(lr.shape, obs.shape, hr.shape)
    plot(
        gt=hr[i_batch],
        obs=obs[i_batch],
        lr=lr[i_batch],
        dt=0.25,
        input_interval=config["data"]["lr_time_interval"],
        missing_value=0.0,
    )

In [None]:
# First training batch, used for the shape and frame checks below.
lr, obs, hr = next(iter(dict_dataloaders["train"]))

In [None]:
# Shapes of the (lr, obs, hr) batch tensors.
lr.shape, obs.shape, hr.shape

In [None]:
# One image per step along axis 0 of the first sample's observation tensor,
# showing index 0 of the next axis (presumably the vorticity channel).
for data in obs[0].unbind(dim=0):
    plt.imshow(data[0])
    plt.show()