In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

# Import libraries

In [None]:
import sys
from logging import INFO, StreamHandler, getLogger

# Root logger at INFO. A stdout handler is attached only when the logger has none yet,
# so re-running this cell does not print every record twice.
logger = getLogger()
_logger_already_has_handler = logger.hasHandlers()
logger.setLevel(INFO)
if not _logger_already_has_handler:
    logger.addHandler(StreamHandler(sys.stdout))

In [None]:
import os
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import torch
from src.four_dim_srda.config.config_loader import load_config
from src.four_dim_srda.config.experiment_config import CFDConfig
from src.four_dim_srda.utils.calc_statistics import (
    calc_maer,
    calc_maer_averaging_over_selected_iz,
    calc_mssim,
)
from src.four_dim_srda.utils.ssim import MSSIM
from src.four_dim_srda.utils.torch_interpolator import (
    interpolate_2d,
    interpolate_along_z,
)
from src.qg_model.qg_model import QGModel

plt.rcParams.update({"font.family": "serif"})  # serif font for every figure in this notebook

# Define constants

In [None]:
ROOT_DIR = pathlib.Path(os.environ["PYTHONPATH"]).parent.resolve()

In [None]:
# Experiment identifier; it selects the config, data and result sub-directories below.
experiment_name: str = "experiment7"

In [None]:
# SRDA models to compare. "UNetVitVer02" also has a registered config below and can
# be appended to this list to include it.
model_names = ["ConvTransNetVer01", "UNetMaxVitVer01"]

CFG_DIR = f"{ROOT_DIR}/python/configs/four_dim_srda/{experiment_name}"

# srda: trained-config name of each supported model.
_SRDA_CFG_NAMES = {
    "ConvTransNetVer01": "bea2_bed2_dspe360_nsls100_ogx08_ogy08_bias1_bs12_lr1e-04",
    "UNetMaxVitVer01": "bea2_bed2_dspe360_nsls100_ogx08_ogy08_n3drb3_nmb6_bias0_bs12_lr1e-04",
    "UNetVitVer02": "bea2_bed2_dspe360_nsls100_ogx08_ogy08_n3drb3_nvb4_bias0_vits0_bs12_lr1e-04",
}
# Fail fast on an unregistered (e.g. misspelled) model name instead of silently
# leaving it out of the comparison.
_unknown_models = [m for m in model_names if m not in _SRDA_CFG_NAMES]
if _unknown_models:
    raise KeyError(f"No SRDA config name registered for model(s): {_unknown_models}")
dict_cfg_srda_name = {m_name: _SRDA_CFG_NAMES[m_name] for m_name in model_names}

# config srda: the first model's config provides the dataset settings
# (pv_min / pv_max, max_start_time_index) used in later cells.
CFG_SRDA_PATH = f"{CFG_DIR}/perform_4D_SRDA/{model_names[0]}/{dict_cfg_srda_name[model_names[0]]}.yml"

cfg_srda = load_config(model_name=model_names[0], config_path=CFG_SRDA_PATH)

# config cfd
CFG_CFD_PATH = f"{CFG_DIR}/cfd_simulation/qg_model/gpu_evaluation_config.yml"

cfg_cfd = CFDConfig.load(pathlib.Path(CFG_CFD_PATH))

DEVICE_CPU = "cpu"
# Device for the MSSIM computation. It falls back to the CPU when CUDA is unavailable,
# so `.to(DEVICE_GPU)` always gets a real device rather than None.
DEVICE_GPU = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Place the LR / HR / UHR QG model configs on the CPU.
cfg_cfd.lr_base_config.device = DEVICE_CPU
cfg_cfd.hr_base_config.device = DEVICE_CPU
cfg_cfd.uhr_base_config.device = DEVICE_CPU

dict_cfg_srda_name

In [None]:
# Single LETKF experiment (config name) evaluated against the SRDA models.
_letkf_cfg_name = (
    "na3e-03_letkf_cfg_ogx08_ogy08_ne100_ch16e-04_cr6e+00_if12e-01_lr57e-01_bs6"
)
list_cfg_letkf_name = [_letkf_cfg_name]
list_cfg_letkf_name

In [None]:
# NOTE(review): `uhr_model` is not referenced by any later cell of this notebook;
# drop this construction if nothing else relies on it.
_uhr_base_cfg = cfg_cfd.uhr_base_config
uhr_model = QGModel(_uhr_base_cfg, show_input_cfg_info=False)

In [None]:
DATA_DIR = f"{ROOT_DIR}/data/four_dim_srda"

# LR and UHR narrow-jet PV simulations share one parent directory.
_qg_sim_dir = f"{DATA_DIR}/{experiment_name}/cfd_simulation/qg_model"
LR_DATA_DIR = f"{_qg_sim_dir}/lr_pv_narrow_jet"
UHR_DATA_DIR = f"{_qg_sim_dir}/uhr_pv_narrow_jet"

In [None]:
_result_dir = f"{ROOT_DIR}/python/results/four_dim_srda/{experiment_name}"
RESULT_DIR = f"{_result_dir}/analysis/use_narrow_jet"
FIG_DIR = f"{RESULT_DIR}/fig"
# Create the figure directory (and parents) up front; no-op if it already exists.
pathlib.Path(FIG_DIR).mkdir(parents=True, exist_ok=True)

In [None]:
FORECAST_SPAN = cfg_cfd.da_config.forecast_span
ASSIMILATION_PERIOD = cfg_cfd.da_config.segment_length - FORECAST_SPAN - 1

# The times considered are 0 <= t <= (NUM_TIMES - 1) * cfg_cfd.time_config.output_hr_dt
NUM_TIMES = (
    cfg_srda.dataset_config.max_start_time_index + ASSIMILATION_PERIOD + FORECAST_SPAN
)
NUM_TIMES_LR = cfg_cfd.time_config.end_time

# Define functions

In [None]:
def _preprocess(
    data: torch.Tensor, pv_min: float, pv_max: float, use_clipping: bool = False
) -> torch.Tensor:
    #
    # batch, time, z, y, x dims
    assert data.ndim == 5

    # normalization
    data = (data - pv_min) / (pv_max - pv_min)

    if use_clipping:
        data = torch.clamp(data, min=0.0, max=1.0)

    return data

In [None]:
def plot_maer_and_mssim_loss(
    *,
    dict_maer: dict[str, torch.Tensor],
    dict_mssim_loss: dict[str, torch.Tensor],
    time: torch.Tensor,
    num_batch: int,
    first_peak_it: int,
    second_peak_it: int,
    list_ylim_maer: tuple[float, float],
    list_ylim_mssim_loss: tuple[float, float],
    base_font_size: int,
    list_fig_size_xy: tuple[float, float],
    num_xticks: int = 5,
    num_yticks: int = 5,
    save_fig: bool = False,
    xtick_max: float = 200.0,
):
    """Plot MAE ratio (left panel) and MSSIM loss (right panel) over time.

    One figure is produced per batch member ``ib``. Only the keys
    "ConvTransNetVer01" (YO23), "UNetMaxVitVer01" (4D-SRDA) and "letkf" (HR-LETKF)
    are drawn; any other key of ``dict_maer`` (e.g. "lr") is skipped.

    Args:
        dict_maer: Per-method MAE ratio indexed as ``[ib, it]``. Entry ``it = 0`` is
            not drawn because the SRDA results are NaN there.
        dict_mssim_loss: Per-method MSSIM loss indexed as ``[ib, it]``. It was
            computed without the first time step, so it aligns with ``time[1:]``.
        time: Time axis shared by all methods.
        num_batch: Number of batch members, i.e. number of figures.
        first_peak_it, second_peak_it: Indices into ``time`` where gray vertical
            reference lines are drawn on both panels.
        list_ylim_maer, list_ylim_mssim_loss: ``(ymin, ymax)`` of each panel; any
            length-2 sequence works.
        base_font_size: Base font size; titles, labels and the legend scale from it.
        list_fig_size_xy: Figure size ``(width, height)`` in inches.
        num_xticks, num_yticks: Number of evenly spaced ticks per axis.
        save_fig: If True, save each figure as ``{FIG_DIR}/<name>.jpg`` at 300 dpi.
        xtick_max: Upper end of the x tick range; ticks span ``[0, xtick_max]``.
    """
    # key -> (color, linestyle, legend label)
    line_styles = {
        "ConvTransNetVer01": ("tab:blue", "-.", "YO23"),
        "UNetMaxVitVer01": ("tab:red", "-", "4D-SRDA"),
        "letkf": ("tab:green", "--", "HR-LETKF"),
    }

    title_fs = 1.4 * base_font_size
    label_fs = 1.2 * base_font_size
    legend_fs = 1.1 * base_font_size
    tick_label_fs = 1.0 * base_font_size

    lw = 2.2
    vline_lw = 0.8 * lw

    # SRDA's results have NaN at it = 0, so every curve starts at it = 1.
    i_s = 1
    xticks = np.linspace(0, xtick_max, num_xticks)

    # Scope the font size to this call instead of mutating the global rcParams.
    with plt.rc_context({"font.size": base_font_size}):
        for ib in range(num_batch):
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=list_fig_size_xy)
            fig_name = f"maer_and_mssim_loss_ib{ib}_plot_for_paper"

            for key in dict_maer.keys():
                if key not in line_styles:
                    continue
                c, ls, label = line_styles[key]
                ax1.plot(
                    time[i_s:],
                    dict_maer[key][ib, i_s:],
                    c=c,
                    lw=lw,
                    linestyle=ls,
                    label=label,
                )
                # The MSSIM loss already excludes it = 0, so only the time axis is
                # sliced to keep both arrays aligned.
                ax2.plot(
                    time[i_s:],
                    dict_mssim_loss[key][ib],
                    c=c,
                    lw=lw,
                    linestyle=ls,
                    label=label,
                )

            panels = (
                (ax1, list_ylim_maer, f"(a) MAE Ratio in UHR space (ib={ib})", "MAE Ratio"),
                (ax2, list_ylim_mssim_loss, f"(b) MSSIM Loss in UHR space (ib={ib})", "MSSIM Loss"),
            )
            for ax, ylim, title, ylabel in panels:
                ax.set_xticks(xticks)
                ax.set_ylim(ylim[0], ylim[1])
                ax.set_yticks(np.linspace(ylim[0], ylim[1], num_yticks))
                ax.set_title(title, fontsize=title_fs, loc="left", pad=20)
                ax.set_xlabel("Time", fontsize=label_fs)
                ax.set_ylabel(ylabel, fontsize=label_fs, labelpad=15)
                ax.tick_params(axis="both", which="major", labelsize=tick_label_fs)
                # Vertical reference lines at first_peak_it and second_peak_it
                for it in (first_peak_it, second_peak_it):
                    ax.axvline(x=time[it], color="gray", linestyle="-", linewidth=vline_lw)

            # Only build a legend when at least one known method was plotted.
            handles, _ = ax2.get_legend_handles_labels()
            if handles:
                legend = ax2.legend(fontsize=legend_fs, edgecolor="black")
                legend.get_frame().set_alpha(1.0)  # opaque legend background
                legend.get_frame().set_edgecolor("black")  # legend frame color

            fig.tight_layout()

            if save_fig:
                fig.savefig(
                    f"{FIG_DIR}/{fig_name}.jpg",
                    dpi=300,
                    bbox_inches="tight",
                )

            plt.show()
            # Release the figure so looping over many batch members does not pile up memory.
            plt.close(fig)

# Prepare data

In [None]:
# Temporal stride: keep one sample every ASSIMILATION_PERIOD output steps, so that
# only the data at the analysis (assimilation) times are used.
t_slice = ASSIMILATION_PERIOD

In [None]:
# Time axis at UHR output resolution, thinned to the analysis times by t_slice.
_time_cfg = cfg_cfd.time_config
_full_time = torch.arange(
    _time_cfg.start_time, _time_cfg.end_time, _time_cfg.output_uhr_dt
)
time = _full_time[::t_slice]

In [None]:
# Ground-truth UHR PV for every UHR seed: truncated to NUM_TIMES, then thinned by t_slice.
_uhr_seeds = range(
    cfg_cfd.seed_config.uhr_seed_start, cfg_cfd.seed_config.uhr_seed_end + 1
)
_gt_per_seed = [
    torch.from_numpy(
        np.load(f"{UHR_DATA_DIR}/seed{s:05}/seed{s:05}_start000_end800_uhr_pv.npy")
    )[:NUM_TIMES][::t_slice]
    for s in _uhr_seeds
]
all_gt = torch.stack(_gt_per_seed, dim=0)
all_gt.shape, all_gt.dtype

In [None]:
# LR forecast PV for every UHR seed, truncated to NUM_TIMES_LR. Unlike all_gt it is
# not thinned by t_slice; presumably the LR files are already at analysis times
# (the shape is checked against all_gt after interpolation) — TODO confirm.
_lr_fcst_per_seed = []
for _seed in range(
    cfg_cfd.seed_config.uhr_seed_start, cfg_cfd.seed_config.uhr_seed_end + 1
):
    _lr_path = f"{LR_DATA_DIR}/seed{_seed:05}/seed{_seed:05}_start000_end200_lr_pv.npy"
    _lr_fcst_per_seed.append(torch.from_numpy(np.load(_lr_path))[:NUM_TIMES_LR])

all_lr_fcst = torch.stack(_lr_fcst_per_seed, dim=0)
all_lr_fcst.shape, all_lr_fcst.dtype

In [None]:
# SRDA forecasts of each model for every UHR seed, thinned by t_slice along time.
dict_srda_fcsts = {}
for m_name, cfg_srda_name in dict_cfg_srda_name.items():
    srda_hr_fcst = []
    #
    for i_seed_uhr in range(
        cfg_cfd.seed_config.uhr_seed_start, cfg_cfd.seed_config.uhr_seed_end + 1
    ):
        _path = f"{_result_dir}/srda/{m_name}/use_narrow_jet/{cfg_srda_name}/UHR_seed_{i_seed_uhr:05}.npz"
        # np.load on an .npz returns an NpzFile that keeps the archive open; the
        # context manager closes it once the array has been read into memory.
        with np.load(_path) as _result_npz:
            _srda_hr_fcst = torch.from_numpy(_result_npz["srda_forecast"])
        srda_hr_fcst.append(_srda_hr_fcst[::t_slice])
    #
    dict_srda_fcsts[m_name] = torch.stack(srda_hr_fcst, dim=0)

dict_srda_fcsts[m_name].shape, dict_srda_fcsts[m_name].dtype, dict_srda_fcsts.keys()

In [None]:
# LETKF forecasts per config. The time axis is assumed to be dim 1 (batch first),
# as the [:, ::t_slice] thinning implies — TODO confirm against the writer.
dict_letkf_fcsts = {}
for _key in list_cfg_letkf_name:
    _letkf_path = f"{_result_dir}/letkf/perform_letkf_hr_using_uhr/use_narrow_jet/{_key}/all_letkf_fcst.npy"
    dict_letkf_fcsts[_key] = torch.from_numpy(np.load(_letkf_path))[:, ::t_slice, ...]

dict_letkf_fcsts[_key].shape, dict_letkf_fcsts[_key].dtype, dict_letkf_fcsts.keys()

In [None]:
# interpolation: nearest-exact upsampling of each LR forecast to the UHR grid,
# horizontally (nx, ny) first, then along z.

_uhr_cfg = cfg_cfd.uhr_base_config
all_lr_fcst_uhr = torch.stack(
    [
        interpolate_along_z(
            data=interpolate_2d(
                data=all_lr_fcst[ib],
                nx=_uhr_cfg.nx,
                ny=_uhr_cfg.ny,
                mode="nearest-exact",
            ),
            nz=_uhr_cfg.nz,
            mode="nearest-exact",
        )
        for ib in range(len(all_lr_fcst))
    ],
    dim=0,
)
assert all_gt.shape == all_lr_fcst_uhr.shape

all_lr_fcst_uhr.shape

In [None]:
# interpolation: the same nearest-exact upsampling to the UHR grid, per SRDA model.

_uhr_cfg = cfg_cfd.uhr_base_config
dict_srda_uhr_fcsts = {}
for m_model, _batch_fcsts in dict_srda_fcsts.items():
    srda_uhr_fcsts = torch.stack(
        [
            interpolate_along_z(
                data=interpolate_2d(
                    data=_member,
                    nx=_uhr_cfg.nx,
                    ny=_uhr_cfg.ny,
                    mode="nearest-exact",
                ),
                nz=_uhr_cfg.nz,
                mode="nearest-exact",
            )
            for _member in _batch_fcsts
        ],
        dim=0,
    )
    assert all_gt.shape == srda_uhr_fcsts.shape
    dict_srda_uhr_fcsts[m_model] = srda_uhr_fcsts

dict_srda_uhr_fcsts[m_model].shape

In [None]:
# interpolation: the same nearest-exact upsampling to the UHR grid, per LETKF config.

_uhr_cfg = cfg_cfd.uhr_base_config
dict_letkf_uhr_fcsts = {}
for key in dict_letkf_fcsts.keys():
    letkf_uhr_fcsts = torch.stack(
        [
            interpolate_along_z(
                data=interpolate_2d(
                    data=_member,
                    nx=_uhr_cfg.nx,
                    ny=_uhr_cfg.ny,
                    mode="nearest-exact",
                ),
                nz=_uhr_cfg.nz,
                mode="nearest-exact",
            )
            for _member in dict_letkf_fcsts[key]
        ],
        dim=0,
    )
    assert all_gt.shape == letkf_uhr_fcsts.shape
    dict_letkf_uhr_fcsts[key] = letkf_uhr_fcsts

dict_letkf_uhr_fcsts[key].shape

# Calculate MAE ratio

In [None]:
# z-level indices over which the MAE ratio and MSSIM loss are averaged.
selected_iz = list(range(18, 23))
all_maer = {}

In [None]:
# MAE ratio of the upsampled LR forecast, one batch member (seed) at a time.
all_maer["lr"] = torch.stack(
    [
        calc_maer_averaging_over_selected_iz(
            all_gt=all_gt[ib].unsqueeze(dim=0),
            all_fcst=all_lr_fcst_uhr[ib].unsqueeze(dim=0),
            selected_iz=selected_iz,
        )
        for ib in range(len(all_gt))
    ],
    dim=0,
)
all_maer["lr"].shape

In [None]:
# MAE ratio of each SRDA model, one batch member (seed) at a time.
for m_model, _srda_uhr in dict_srda_uhr_fcsts.items():
    assert all_gt.shape == _srda_uhr.shape
    all_maer[m_model] = torch.stack(
        [
            calc_maer_averaging_over_selected_iz(
                all_gt=all_gt[ib].unsqueeze(dim=0),
                all_fcst=_srda_uhr[ib].unsqueeze(dim=0),
                selected_iz=selected_iz,
            )
            for ib in range(len(all_gt))
        ],
        dim=0,
    )

all_maer.keys(), all_maer[m_model].shape

In [None]:
# MAE ratio of the LETKF forecast. Results are stored under the single key "letkf",
# and `key_letkf` selects the config normalized later. With more than one LETKF
# config, earlier results would be silently overwritten, so reject that case.
if len(dict_letkf_uhr_fcsts) != 1:
    raise ValueError(
        f"Expected exactly one LETKF config, got {list(dict_letkf_uhr_fcsts)}"
    )

for key, _letkf_uhr in dict_letkf_uhr_fcsts.items():
    #
    key_letkf = key
    assert all_gt.shape == _letkf_uhr.shape

    letkf_maer = []
    for ib in range(len(all_gt)):
        _maer = calc_maer_averaging_over_selected_iz(
            all_gt=all_gt[ib].unsqueeze(dim=0),
            all_fcst=_letkf_uhr[ib].unsqueeze(dim=0),
            selected_iz=selected_iz,
        )
        letkf_maer.append(_maer)
    all_maer["letkf"] = torch.stack(letkf_maer, dim=0)

all_maer.keys(), all_maer["letkf"].shape

# Calculate MSSIM loss

## Normalization

In [None]:
# Normalize every field to [0, 1] using the PV range of the SRDA training dataset.
pv_min = cfg_srda.dataset_config.pv_min
pv_max = cfg_srda.dataset_config.pv_max

all_gt_norm = _preprocess(data=all_gt, pv_min=pv_min, pv_max=pv_max, use_clipping=False)
assert all_gt_norm.min() >= 0 and all_gt_norm.max() <= 1

all_lr_fcst_norm = _preprocess(
    data=all_lr_fcst_uhr, pv_min=pv_min, pv_max=pv_max, use_clipping=False
)
assert all_lr_fcst_norm.min() >= 0 and all_lr_fcst_norm.max() <= 1

dict_srda_uhr_fcsts_norm = {}

for m_model in dict_srda_uhr_fcsts.keys():
    dict_srda_uhr_fcsts_norm[m_model] = _preprocess(
        dict_srda_uhr_fcsts[m_model], pv_min=pv_min, pv_max=pv_max, use_clipping=False
    )

    # SRDA forecasts contain NaN (at it = 0), so check min and max on the
    # non-NaN values only.
    valid_values = dict_srda_uhr_fcsts_norm[m_model][
        ~torch.isnan(dict_srda_uhr_fcsts_norm[m_model])
    ]

    assert valid_values.min() >= 0 and valid_values.max() <= 1

letkf_uhr_fcsts_norm = _preprocess(
    data=dict_letkf_uhr_fcsts[key_letkf],
    pv_min=pv_min,
    pv_max=pv_max,
    use_clipping=False,
)
assert letkf_uhr_fcsts_norm.min() >= 0 and letkf_uhr_fcsts_norm.max() <= 1

## MSSIM loss with a uniform (non-Gaussian) window

In [None]:
# Number of batch members (one per UHR seed) and the per-method MSSIM-loss container.
all_mssim_loss = {}
num_batch = all_gt_norm.shape[0]

In [None]:
# MSSIM with a uniform (non-Gaussian) 3-D window.
mssim_loss_params = {
    "window_3d_size": (5, 11, 11),
    "sigma_3d": (0.7, 1.5, 1.5),
    "value_magnitude": 1.0,
    "use_gaussian": False,
}


def _calc_mssim_loss_over_selected_iz(fcst_norm: torch.Tensor) -> torch.Tensor:
    """Return ``1 - MSSIM`` between each batch member and the ground truth.

    The result is averaged over ``selected_iz``. The first time step is dropped
    for every method (SRDA forecasts are NaN at it = 0), so all curves share the
    same time axis.
    """
    mssims = []
    for ib in range(num_batch):
        ssim = MSSIM(**mssim_loss_params)
        _r = calc_mssim(
            all_gt=all_gt_norm[ib, 1:].unsqueeze(0).to(DEVICE_GPU),
            all_fcst=fcst_norm[ib, 1:].unsqueeze(0).to(DEVICE_GPU),
            mssim=ssim,
        )
        mssims.append(_r.cpu())
        # Free device memory between batch members.
        del _r
        torch.cuda.empty_cache()

    mssim_loss = 1.0 - torch.stack(mssims)
    # assumes dim 2 of the stacked result is the z level — TODO confirm with calc_mssim
    return torch.mean(mssim_loss[:, :, selected_iz], dim=2)


# lr
all_mssim_loss["lr"] = _calc_mssim_loss_over_selected_iz(all_lr_fcst_norm)

# srda
for m_model, _srda_fcst_norm in dict_srda_uhr_fcsts_norm.items():
    all_mssim_loss[m_model] = _calc_mssim_loss_over_selected_iz(_srda_fcst_norm)

# letkf
all_mssim_loss["letkf"] = _calc_mssim_loss_over_selected_iz(letkf_uhr_fcsts_norm)

# Every method (not just the last SRDA model) must share one shape for plotting.
assert len({v.shape for v in all_mssim_loss.values()}) == 1, {
    k: tuple(v.shape) for k, v in all_mssim_loss.items()
}

all_mssim_loss["lr"].shape

# Plot

In [None]:
plot_maer_and_mssim_loss(
    dict_maer=all_maer,
    dict_mssim_loss=all_mssim_loss,
    time=time,
    num_batch=num_batch,
    # indices into `time` where gray vertical reference lines are drawn
    first_peak_it=130,
    second_peak_it=70,
    list_ylim_maer=(0.04, 0.20),
    list_ylim_mssim_loss=(0.01, 0.10),
    list_fig_size_xy=(16, 6),
    base_font_size=18,
    num_xticks=6,
    num_yticks=6,
    save_fig=False,
)