In [None]:
# Reload edited project modules automatically and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

# Import libraries

In [None]:
import sys
from logging import INFO, StreamHandler, getLogger

# Configure the root logger: emit INFO and above, and attach a stdout handler
# only when none exists yet so that re-running this cell does not duplicate lines.
logger = getLogger()
logger.setLevel(INFO)
if not logger.hasHandlers():
    logger.addHandler(StreamHandler(stream=sys.stdout))

In [None]:
import os
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import torch
from src.four_dim_srda.config.config_loader import load_config
from src.four_dim_srda.config.experiment_config import CFDConfig
from src.four_dim_srda.utils.calc_statistics import (
    calc_maer,
    calc_maer_averaging_over_selected_iz,
    calc_mssim,
)
from src.four_dim_srda.utils.ssim import MSSIM
from src.four_dim_srda.utils.torch_interpolator import (
    interpolate_2d,
    interpolate_along_z,
)
from src.qg_model.low_pass_filter import LowPassFilter
from src.qg_model.qg_model import QGModel

# Use a serif font family for every figure produced by this notebook.
plt.rcParams["font.family"] = "serif"

# Define constants

In [None]:
# Repository root = parent of the directory on PYTHONPATH (paths below are built
# as f"{ROOT_DIR}/python/..."). PYTHONPATH may list several entries separated by
# os.pathsep; only the first one is used — assumes the project's python/ dir
# comes first (TODO confirm if more entries are ever set).
ROOT_DIR = pathlib.Path(os.environ["PYTHONPATH"].split(os.pathsep)[0]).parent.resolve()

In [None]:
# Experiment identifier; selects the config, data and result sub-directories.
experiment_name = "experiment7"

In [None]:
# Architectures to compare. "UNetVitVer02" also has a registered config name
# below and can be appended to this list to include it.
model_names = ["ConvTransNetVer01", "UNetMaxVitVer01"]

CFG_DIR = f"{ROOT_DIR}/python/configs/four_dim_srda/{experiment_name}"

# srda: config file stem (".yml" is appended below) for each architecture.
SRDA_CFG_NAME_BY_MODEL = {
    "ConvTransNetVer01": "bea2_bed2_dspe360_nsls100_ogx08_ogy08_bias1_bs12_lr1e-04",
    "UNetMaxVitVer01": "bea2_bed2_dspe360_nsls100_ogx08_ogy08_n3drb3_nmb6_bias0_bs12_lr1e-04",
    "UNetVitVer02": "bea2_bed2_dspe360_nsls100_ogx08_ogy08_n3drb3_nvb4_bias0_vits0_bs12_lr1e-04",
}

# Fail fast on a misspelled / unregistered model instead of silently leaving
# it out of dict_cfg_srda_name.
_unknown_models = [m for m in model_names if m not in SRDA_CFG_NAME_BY_MODEL]
if _unknown_models:
    raise KeyError(f"No SRDA config name registered for: {_unknown_models}")

dict_cfg_srda_name = {m_name: SRDA_CFG_NAME_BY_MODEL[m_name] for m_name in model_names}

# config srda: only the first model's config is loaded; this notebook reads its
# dataset_config (PV range, max_start_time_index) further below.
CFG_SRDA_PATH = f"{CFG_DIR}/perform_4D_SRDA/{model_names[0]}/{dict_cfg_srda_name[model_names[0]]}.yml"

cfg_srda = load_config(model_name=model_names[0], config_path=CFG_SRDA_PATH)

# config cfd
CFG_CFD_PATH = f"{CFG_DIR}/cfd_simulation/qg_model/gpu_evaluation_config.yml"

cfg_cfd = CFDConfig.load(pathlib.Path(CFG_CFD_PATH))

DEVICE_CPU = "cpu"
DEVICE_GPU = torch.device("cuda") if torch.cuda.is_available() else None

# Set the device of the LR / HR / UHR CFD base configs to CPU.
for _base_cfg in (
    cfg_cfd.lr_base_config,
    cfg_cfd.hr_base_config,
    cfg_cfd.uhr_base_config,
):
    _base_cfg.device = DEVICE_CPU

dict_cfg_srda_name

In [None]:
# LETKF run (config directory) names whose saved forecasts are evaluated below.
list_cfg_letkf_name = [
    "na3e-03_letkf_cfg_ogx08_ogy08_ne100_ch16e-04_cr6e+00_if12e-01_lr57e-01_bs6"
]
list_cfg_letkf_name

In [None]:
# Build the UHR QG model from the UHR base config.
# NOTE(review): `uhr_model` is not referenced anywhere else in the visible
# notebook — confirm it is still needed before keeping this construction.
uhr_model = QGModel(cfg_cfd.uhr_base_config, show_input_cfg_info=False)

In [None]:
DATA_DIR = f"{ROOT_DIR}/data/four_dim_srda"

# LR and UHR narrow-jet PV simulations share one QG-model output directory.
_qg_sim_dir = f"{DATA_DIR}/{experiment_name}/cfd_simulation/qg_model"
LR_DATA_DIR = f"{_qg_sim_dir}/lr_pv_narrow_jet"
UHR_DATA_DIR = f"{_qg_sim_dir}/uhr_pv_narrow_jet"

In [None]:
# Results root of this experiment (LETKF outputs are read from below it) and
# the narrow-jet analysis output directories.
_result_dir = f"{ROOT_DIR}/python/results/four_dim_srda/{experiment_name}"
RESULT_DIR = f"{_result_dir}/analysis/use_narrow_jet"
FIG_DIR = f"{RESULT_DIR}/fig"

# Create the figure directory (and any missing parents) up front.
pathlib.Path(FIG_DIR).mkdir(parents=True, exist_ok=True)

In [None]:
# Forecast span and assimilation period, both taken from the CFD DA config.
# ASSIMILATION_PERIOD is also used below as the stride selecting analysis times.
FORECAST_SPAN = cfg_cfd.da_config.forecast_span
ASSIMILATION_PERIOD = cfg_cfd.da_config.segment_length - FORECAST_SPAN - 1

# Number of UHR time steps loaded for the ground truth, and of LR steps loaded.
NUM_TIMES = (
    cfg_srda.dataset_config.max_start_time_index
    + ASSIMILATION_PERIOD
    + FORECAST_SPAN
)
NUM_TIMES_LR = cfg_cfd.time_config.end_time

# Define methods

In [None]:
def _preprocess(
    data: torch.Tensor, pv_min: float, pv_max: float, use_clipping: bool = False
) -> torch.Tensor:
    #
    # batch, time, z, y, x dims
    assert data.ndim == 5

    # normalization
    data = (data - pv_min) / (pv_max - pv_min)

    if use_clipping:
        data = torch.clamp(data, min=0.0, max=1.0)

    return data

In [None]:
def plot_maer_and_mssim_loss(
    *,
    dict_maer: dict[str, torch.Tensor],
    dict_mssim_loss: dict[str, torch.Tensor],
    time: torch.Tensor,
    list_ylim_maer: list[float],
    list_ylim_mssim_loss: list[float],
    base_font_size: int,
    list_fig_size_xy: list[float],
    num_xticks: int = 5,
    num_yticks: int = 5,
    save_fig: bool = False,
    fig_name: str = "maer_and_mssim_loss_lr_space_plots_for_paper",
    xtick_range: tuple[float, float] = (0, 200),
):
    """Plot MAE ratio (left panel) and MSSIM loss (right panel) in LR space.

    Only the keys "lr_fcst" (LR-Forecast) and "letkf" (HR-LETKF) of
    ``dict_maer`` are plotted, in the dict's iteration order; any other key is
    skipped. Every plotted key must also be present in ``dict_mssim_loss``.

    Args:
        dict_maer: MAE-ratio series per method. Index 0 is dropped before
            plotting, so each series must have the same length as ``time``.
        dict_mssim_loss: MSSIM-loss series per method, plotted as is against
            ``time[1:]``, so each must have length ``len(time) - 1``.
        time: Time-axis values.
        list_ylim_maer: (min, max) y-range of the MAE-ratio panel.
        list_ylim_mssim_loss: (min, max) y-range of the MSSIM-loss panel.
        base_font_size: Base font size; titles, labels, ticks and the legend
            are scaled from it. Also sets the global
            ``plt.rcParams["font.size"]``.
        list_fig_size_xy: Figure size (width, height) in inches.
        num_xticks: Number of evenly spaced x ticks over ``xtick_range``.
        num_yticks: Number of evenly spaced y ticks over each y-range.
        save_fig: If True, save the figure to ``{FIG_DIR}/{fig_name}.jpg``
            (300 dpi).
        fig_name: File stem used when ``save_fig`` is True.
        xtick_range: (first, last) x-tick positions; default (0, 200).
    """
    # NOTE: mutates global matplotlib state; later figures inherit this size.
    plt.rcParams["font.size"] = base_font_size

    title_fs = 1.4 * base_font_size
    label_fs = 1.2 * base_font_size
    legend_fs = 1.1 * base_font_size
    tick_label_fs = 1.0 * base_font_size
    lw = 2.2

    # Line color / line style / legend label per method key.
    line_styles = {
        "lr_fcst": {"c": "tab:red", "linestyle": "-", "label": "LR-Forecast"},
        "letkf": {"c": "tab:green", "linestyle": "--", "label": "HR-LETKF"},
    }

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=list_fig_size_xy)

    # SRDA's results have NaN in i_s = 0, so skip it
    i_s = 1

    plotted_any = False
    for key in dict_maer:
        style = line_styles.get(key)
        if style is None:
            continue
        # The MAE-ratio series still contains index 0: slice it like `time`.
        ax1.plot(time[i_s:], dict_maer[key][i_s:], lw=lw, **style)
        # The MSSIM-loss series is expected to exclude index 0 already
        # (the caller computes it from index 1 on), so only `time` is sliced.
        ax2.plot(time[i_s:], dict_mssim_loss[key], lw=lw, **style)
        plotted_any = True

    # Axis formatting is identical for both panels except for range and text,
    # so it is applied once after all lines are drawn.
    xticks = np.linspace(xtick_range[0], xtick_range[1], num_xticks)
    panels = (
        (ax1, list_ylim_maer, "(a) MAE Ratio in LR space", "MAE Ratio"),
        (ax2, list_ylim_mssim_loss, "(b) MSSIM Loss in LR space", "MSSIM Loss"),
    )
    for ax, ylim, title, ylabel in panels:
        ax.set_xticks(xticks)
        ax.set_ylim(ylim[0], ylim[1])
        ax.set_yticks(np.linspace(ylim[0], ylim[1], num_yticks))
        ax.set_title(title, fontsize=title_fs, loc="left", pad=20)
        ax.set_xlabel("Time", fontsize=label_fs)
        ax.set_ylabel(ylabel, fontsize=label_fs, labelpad=15)
        ax.tick_params(axis="both", which="major", labelsize=tick_label_fs)

    # Only build a legend when at least one labelled line exists.
    if plotted_any:
        legend = ax2.legend(fontsize=legend_fs, edgecolor="black")
        legend.get_frame().set_alpha(1.0)  # opaque legend background
        legend.get_frame().set_edgecolor("black")  # legend frame color

    fig.tight_layout()

    if save_fig:
        fig.savefig(
            f"{FIG_DIR}/{fig_name}.jpg",
            dpi=300,
            bbox_inches="tight",
        )

    plt.show()

# Prepare data

In [None]:
# Set it to use only data at analytical time
# Keep only data at analysis times: every dataset below is subsampled with
# stride ASSIMILATION_PERIOD along its time dimension.
t_slice = ASSIMILATION_PERIOD

In [None]:
# UHR output times from start_time up to (excluding) end_time, subsampled
# with the same stride as the data.
_time_cfg = cfg_cfd.time_config
_all_output_times = torch.arange(
    _time_cfg.start_time, _time_cfg.end_time, _time_cfg.output_uhr_dt
)
time = _all_output_times[::t_slice]

In [None]:
# Ground truth: UHR PV for every UHR seed, truncated to NUM_TIMES steps and
# subsampled with stride t_slice, stacked along a new leading (seed) dim.
all_gt = torch.stack(
    [
        torch.from_numpy(
            np.load(
                f"{UHR_DATA_DIR}/seed{seed:05}/seed{seed:05}_start000_end800_uhr_pv.npy"
            )
        )[:NUM_TIMES][::t_slice]
        for seed in range(
            cfg_cfd.seed_config.uhr_seed_start, cfg_cfd.seed_config.uhr_seed_end + 1
        )
    ],
    dim=0,
)

# Downscale to low resolution with a UHR -> LR low-pass filter.
lowpass_filter = LowPassFilter(
    nx_lr=cfg_cfd.lr_base_config.nx,
    ny_lr=cfg_cfd.lr_base_config.ny,
    nz_lr=cfg_cfd.lr_base_config.nz,
    nx_hr=cfg_cfd.uhr_base_config.nx,
    ny_hr=cfg_cfd.uhr_base_config.ny,
    nz_hr=cfg_cfd.uhr_base_config.nz,
    dtype=torch.complex128,
    device="cpu",
)

# Filter one time step at a time (all seeds together), then restack on dim 1.
lr_all_gt = torch.stack(
    [
        lowpass_filter.apply(hr_grid_data=all_gt[:, it, :, :, :])
        for it in range(all_gt.shape[1])
    ],
    dim=1,
)

lr_all_gt.dtype, lr_all_gt.shape

In [None]:
# LR forecast PV for the same seeds as the ground truth, truncated to the
# first NUM_TIMES_LR steps and stacked along a new leading (seed) dim.
all_lr_fcst = torch.stack(
    [
        torch.from_numpy(
            np.load(
                f"{LR_DATA_DIR}/seed{seed:05}/seed{seed:05}_start000_end200_lr_pv.npy"
            )
        )[:NUM_TIMES_LR]
        for seed in range(
            cfg_cfd.seed_config.uhr_seed_start, cfg_cfd.seed_config.uhr_seed_end + 1
        )
    ],
    dim=0,
)
all_lr_fcst.shape, all_lr_fcst.dtype

In [None]:
if not list_cfg_letkf_name:
    raise ValueError("list_cfg_letkf_name is empty; no LETKF run to evaluate")

# Load the saved HR-LETKF forecasts of every LETKF config, keeping only
# analysis times (stride t_slice along dim 1).
dict_letkf_fcsts = {}
for cfg_name in list_cfg_letkf_name:
    _path = f"{_result_dir}/letkf/perform_letkf_hr_using_uhr/use_narrow_jet/{cfg_name}/all_letkf_fcst.npy"
    _letkf_hr_fcst = torch.from_numpy(np.load(_path))
    dict_letkf_fcsts[cfg_name] = _letkf_hr_fcst[:, ::t_slice, ...]

# Downscale to low resolution with an HR -> LR low-pass filter.
lowpass_filter = LowPassFilter(
    nx_lr=cfg_cfd.lr_base_config.nx,
    ny_lr=cfg_cfd.lr_base_config.ny,
    nz_lr=cfg_cfd.lr_base_config.nz,
    nx_hr=cfg_cfd.hr_base_config.nx,
    ny_hr=cfg_cfd.hr_base_config.ny,
    nz_hr=cfg_cfd.hr_base_config.nz,
    dtype=torch.complex128,
    device="cpu",
)


def _hr_to_lr(hr_fcst: torch.Tensor) -> torch.Tensor:
    """Low-pass filter ``hr_fcst`` one time step (dim 1) at a time and restack."""
    return torch.stack(
        [
            lowpass_filter.apply(hr_grid_data=hr_fcst[:, it, :, :, :])
            for it in range(hr_fcst.shape[1])
        ],
        dim=1,
    )


# Downscale every LETKF config, keyed by config name.
dict_lr_letkf_fcsts = {
    name: _hr_to_lr(fcst) for name, fcst in dict_letkf_fcsts.items()
}

# The rest of the notebook evaluates a single LETKF run: explicitly select the
# last config in list_cfg_letkf_name.
_key = list_cfg_letkf_name[-1]
lr_letkf_fcst = dict_lr_letkf_fcsts[_key]

(
    dict_letkf_fcsts[_key].shape,
    dict_letkf_fcsts[_key].dtype,
    dict_letkf_fcsts.keys(),
    lr_letkf_fcst.shape,
)

# Calculate MAE ratio

In [None]:
# MAE ratio of the LR forecast against the low-pass-filtered ground truth.
assert all_lr_fcst.shape == lr_all_gt.shape
lr_maer = calc_maer(all_fcst=all_lr_fcst, all_gt=lr_all_gt)
lr_maer.shape

In [None]:
# MAE ratio of the (low-pass-filtered) HR-LETKF forecast against the
# low-pass-filtered ground truth.
assert lr_letkf_fcst.shape == lr_all_gt.shape
letkf_maer = calc_maer(all_fcst=lr_letkf_fcst, all_gt=lr_all_gt)
letkf_maer.shape

# Calculate MSSIM loss

## Normalization

In [None]:
# Normalize every field with the PV range from the SRDA dataset config, without
# clipping, and check that all normalized values already lie inside [0, 1].
pv_min = cfg_srda.dataset_config.pv_min
pv_max = cfg_srda.dataset_config.pv_max


def _normalize_checked(data: torch.Tensor, name: str) -> torch.Tensor:
    """Normalize ``data`` with (pv_min, pv_max) and assert the result is in [0, 1]."""
    data_norm = _preprocess(
        data=data, pv_min=pv_min, pv_max=pv_max, use_clipping=False
    )
    assert data_norm.min() >= 0 and data_norm.max() <= 1, (
        f"{name}: normalized values outside [0, 1] "
        f"(min={data_norm.min().item()}, max={data_norm.max().item()})"
    )
    return data_norm


lr_all_gt_norm = _normalize_checked(lr_all_gt, "lr_all_gt")
all_lr_fcst_norm = _normalize_checked(all_lr_fcst, "all_lr_fcst")
lr_letkf_fcst_norm = _normalize_checked(lr_letkf_fcst, "lr_letkf_fcst")

## MSSIM with a uniform (non-Gaussian) window

In [None]:
# Size of the leading (seed) dimension of the normalized ground truth.
num_batch = lr_all_gt_norm.shape[0]

In [None]:
# Parameters passed unchanged to every MSSIM instance (use_gaussian=False).
mssim_loss_params = {
    "window_3d_size": (5, 11, 11),
    "sigma_3d": (0.7, 1.5, 1.5),
    "value_magnitude": 1.0,
    "use_gaussian": False,
}


def _calc_mssim_loss(
    all_gt_norm: torch.Tensor, all_fcst_norm: torch.Tensor
) -> torch.Tensor:
    """Return the MSSIM loss (1 - MSSIM) of forecasts vs. ground truth.

    Each sample along dim 0 is evaluated separately from time index 1 onward
    (index 0 is skipped) on ``DEVICE_GPU``. Per-sample results are moved back
    to the CPU and averaged over samples. The average is turned into a loss
    and then averaged over dim 1.
    """
    assert all_gt_norm.shape == all_fcst_norm.shape
    per_sample = []
    for ib in range(len(all_gt_norm)):
        ssim = MSSIM(**mssim_loss_params)
        _r = calc_mssim(
            all_gt=all_gt_norm[ib, 1:].float().unsqueeze(0).to(DEVICE_GPU),
            all_fcst=all_fcst_norm[ib, 1:].float().unsqueeze(0).to(DEVICE_GPU),
            mssim=ssim,
        )
        per_sample.append(_r.cpu())
        # Release the device result before the next sample (presumably to
        # keep GPU memory usage bounded).
        del _r
        torch.cuda.empty_cache()

    # mean over batch dim and calc loss
    loss = 1.0 - torch.mean(torch.stack(per_sample), dim=0)
    return torch.mean(loss, dim=1)


# lr
lr_mssim_loss = _calc_mssim_loss(lr_all_gt_norm, all_lr_fcst_norm)

# letkf
letkf_mssim_loss = _calc_mssim_loss(lr_all_gt_norm, lr_letkf_fcst_norm)

assert lr_mssim_loss.shape == letkf_mssim_loss.shape
lr_mssim_loss.shape

# Plot

In [None]:
# Metric series per method; the keys select line styles / labels in the plot.
all_maer = dict(lr_fcst=lr_maer, letkf=letkf_maer)
all_mssim_loss = dict(lr_fcst=lr_mssim_loss, letkf=letkf_mssim_loss)

plot_maer_and_mssim_loss(
    time=time,
    dict_maer=all_maer,
    dict_mssim_loss=all_mssim_loss,
    base_font_size=20,
    list_fig_size_xy=[16, 6],
    list_ylim_maer=[0.0, 1.2],
    list_ylim_mssim_loss=[0.0, 0.1],
    num_xticks=6,
    num_yticks=6,
    save_fig=True,
)