In [None]:
# Reload edited project modules automatically before each cell runs, and
# render matplotlib figures inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline

# Imports

In [None]:
import sys
from logging import INFO, StreamHandler, getLogger

# Route log records to the notebook's stdout at INFO level. The handler is
# attached only when the root logger has none, so re-running this cell does
# not print every log line twice.
logger = getLogger()
logger.setLevel(INFO)
if not logger.hasHandlers():
    logger.addHandler(StreamHandler(stream=sys.stdout))

In [None]:
import os
import pathlib

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
from src.four_dim_srda.config.config_loader import load_config
from src.four_dim_srda.config.experiment_config import CFDConfig
from src.four_dim_srda.utils.io_pickle import read_pickle
from src.four_dim_srda.utils.torch_interpolator import (
    interpolate_2d,
    interpolate_along_z,
)
from src.qg_model.qg_model import QGModel

# Use a serif font for every figure produced in this notebook (paper figures).
plt.rcParams.update({"font.family": "serif"})

# Define constants

In [None]:
ROOT_DIR = pathlib.Path(os.environ["PYTHONPATH"]).parent.resolve()

In [None]:
# Experiment identifier; selects the experiment sub-directory under the
# configs, data and results trees used throughout this notebook.
experiment_name: str = "experiment7"

In [None]:
# SRDA model variants compared in this notebook.
# model_names = ["ConvTransNetVer01", "UNetMaxVitVer01", "UNetVitVer02"]
model_names = ["ConvTransNetVer01", "UNetMaxVitVer01"]

CFG_DIR = f"{ROOT_DIR}/python/configs/four_dim_srda/{experiment_name}"

# srda
# Name of the trained SRDA run for each model. It is used both as the config
# file stem and as the result sub-directory name.
SRDA_CFG_NAMES = {
    "ConvTransNetVer01": "bea2_bed2_dspe360_nsls100_ogx08_ogy08_bias1_bs12_lr1e-04",
    "UNetMaxVitVer01": "bea2_bed2_dspe360_nsls100_ogx08_ogy08_n3drb3_nmb6_bias0_bs12_lr1e-04",
    "UNetVitVer02": "bea2_bed2_dspe360_nsls100_ogx08_ogy08_n3drb3_nvb4_bias0_vits0_bs12_lr1e-04",
}

# Fail loudly on a model without a registered run name, instead of silently
# leaving it out of the comparison.
_unknown_models = [m for m in model_names if m not in SRDA_CFG_NAMES]
if _unknown_models:
    raise KeyError(
        f"No SRDA config name registered for {_unknown_models}; "
        f"known models: {sorted(SRDA_CFG_NAMES)}"
    )

dict_cfg_srda_name = {m_name: SRDA_CFG_NAMES[m_name] for m_name in model_names}

# config srda: load the SRDA config of the first model in `model_names`.
_first_model = model_names[0]
CFG_SRDA_PATH = (
    f"{CFG_DIR}/perform_4D_SRDA/{_first_model}/{dict_cfg_srda_name[_first_model]}.yml"
)
cfg_srda = load_config(model_name=_first_model, config_path=CFG_SRDA_PATH)

# config cfd: the QG-model evaluation config.
CFG_CFD_PATH = f"{CFG_DIR}/cfd_simulation/qg_model/gpu_evaluation_config.yml"
cfg_cfd = CFDConfig.load(pathlib.Path(CFG_CFD_PATH))

# Override the device of all three resolutions (LR, HR, UHR) to the CPU,
# despite the config file's "gpu" name.
DEVICE = "cpu"
cfg_cfd.lr_base_config.device = DEVICE
cfg_cfd.hr_base_config.device = DEVICE
cfg_cfd.uhr_base_config.device = DEVICE

dict_cfg_srda_name

In [None]:
# Name of the LETKF run whose HR forecasts are loaded and compared below.
cfg_letkf_name = "na3e-03_letkf_cfg_ogx08_ogy08_ne100_ch16e-04_cr6e+00_if12e-01_lr57e-01_bs6"
cfg_letkf_name

In [None]:
# UHR QG model built from the UHR base config. In this notebook it is used
# to obtain the UHR grid coordinates (`get_grids`) for plotting.
uhr_model = QGModel(cfg_cfd.uhr_base_config, show_input_cfg_info=False)

In [None]:
DATA_DIR = "/".join([str(ROOT_DIR), "data", "four_dim_srda"])

# Per-seed potential-vorticity outputs of the QG simulation (narrow-jet
# setting) at LR and UHR resolution.
LR_DATA_DIR, UHR_DATA_DIR = (
    f"{DATA_DIR}/{experiment_name}/cfd_simulation/qg_model/{resolution}_pv_narrow_jet"
    for resolution in ("lr", "uhr")
)

In [None]:
# Root of this experiment's results, and the output location for this analysis.
_result_dir = f"{ROOT_DIR}/python/results/four_dim_srda/{experiment_name}"
RESULT_DIR = _result_dir + "/analysis/use_narrow_jet"
FIG_DIR = RESULT_DIR + "/fig"

# Create the figure directory (and any missing parents) if needed.
pathlib.Path(FIG_DIR).mkdir(parents=True, exist_ok=True)

In [None]:
FORECAST_SPAN = cfg_cfd.da_config.forecast_span

# A DA segment spans ASSIMILATION_PERIOD + FORECAST_SPAN + 1 time steps.
ASSIMILATION_PERIOD = cfg_cfd.da_config.segment_length - (FORECAST_SPAN + 1)

# Target time range: 0 <= t <= (NUM_TIMES - 1) * cfg_cfd.time_config.output_hr_dt
# NOTE(review): the time axis built for plotting uses output_uhr_dt instead;
# confirm which interval is meant here.
NUM_TIMES = (
    cfg_srda.dataset_config.max_start_time_index + ASSIMILATION_PERIOD + FORECAST_SPAN
)

# Number of LR time entries kept; taken directly from time_config.end_time
# (presumably one LR output per time unit; TODO confirm).
NUM_TIMES_LR = cfg_cfd.time_config.end_time

# Prepare data

In [None]:
# Keep only the analysis times: every ASSIMILATION_PERIOD-th time step is
# selected, and this value is used as the slicing stride for all datasets.
t_slice = ASSIMILATION_PERIOD

In [None]:
# Load the UHR ground truth for every UHR seed. Only the first NUM_TIMES steps
# are kept, sampled every t_slice steps (analysis times).
# Each file is memory-mapped, so only the selected time steps are read from
# disk, rather than loading the whole time series of every seed into RAM.
all_gt = []
for i_seed_uhr in range(
    cfg_cfd.seed_config.uhr_seed_start, cfg_cfd.seed_config.uhr_seed_end + 1
):
    uhr_data_path = f"{UHR_DATA_DIR}/seed{i_seed_uhr:05}/seed{i_seed_uhr:05}_start000_end800_uhr_pv.npy"
    # np.array(..., copy=True) turns the strided read-only memmap view into a
    # regular, writable, contiguous in-memory array before it goes to torch.
    gt = torch.from_numpy(
        np.array(
            np.load(uhr_data_path, mmap_mode="r")[:NUM_TIMES:t_slice],
            copy=True,
        )
    )
    all_gt.append(gt)

# Stacked along a new leading seed axis; the per-seed layout is presumably
# (time, z, y, x). TODO confirm.
all_gt = torch.stack(all_gt, dim=0)
all_gt.shape, all_gt.dtype

In [None]:
# Load the LR forecasts (no data assimilation) for every UHR seed, truncated
# to the first NUM_TIMES_LR time entries, and stack them along a seed axis.
_uhr_seeds = range(
    cfg_cfd.seed_config.uhr_seed_start, cfg_cfd.seed_config.uhr_seed_end + 1
)
all_lr_fcst = torch.stack(
    [
        torch.from_numpy(
            np.load(
                f"{LR_DATA_DIR}/seed{seed:05}/seed{seed:05}_start000_end200_lr_pv.npy"
            )
        )[:NUM_TIMES_LR]
        for seed in _uhr_seeds
    ],
    dim=0,
)
all_lr_fcst.shape, all_lr_fcst.dtype

In [None]:
# Load each SRDA model's forecasts for every UHR seed, keep every t_slice-th
# time step (analysis times), and stack them along a leading seed axis.
dict_srda_fcsts = {}
for m_name, srda_run_name in dict_cfg_srda_name.items():
    srda_hr_fcst = []
    #
    for i_seed_uhr in range(
        cfg_cfd.seed_config.uhr_seed_start, cfg_cfd.seed_config.uhr_seed_end + 1
    ):
        _path = f"{_result_dir}/srda/{m_name}/use_narrow_jet/{srda_run_name}/UHR_seed_{i_seed_uhr:05}.npz"
        # np.load on an .npz returns an NpzFile that keeps the file open.
        # Indexing it reads the array into memory, so the handle can be
        # closed immediately afterwards.
        with np.load(_path) as _result_npz:
            _srda_hr_fcst = torch.from_numpy(_result_npz["srda_forecast"])
        srda_hr_fcst.append(_srda_hr_fcst[::t_slice])
    #
    dict_srda_fcsts[m_name] = torch.stack(srda_hr_fcst, dim=0)

# `m_name` is the last model iterated above.
dict_srda_fcsts[m_name].shape, dict_srda_fcsts[m_name].dtype, dict_srda_fcsts.keys()

In [None]:
# Load the HR-LETKF forecasts (all seeds in one file; the leading axis
# presumably indexes seeds) and keep every t_slice-th time step.
# The file is memory-mapped so only the selected time steps are read.
# np.array(..., copy=True) then copies them into a regular in-memory array.
path = f"{_result_dir}/letkf/perform_letkf_hr_using_uhr/use_narrow_jet/{cfg_letkf_name}/all_letkf_fcst.npy"
letkf_hr_fcsts = torch.from_numpy(
    np.array(np.load(path, mmap_mode="r")[:, ::t_slice, ...], copy=True)
)
letkf_hr_fcsts.shape, letkf_hr_fcsts.dtype

In [None]:
# interpolation: map every LR forecast onto the UHR grid (horizontal, then
# vertical, nearest-exact) so it can be compared point-wise with the UHR
# ground truth.
all_lr_fcst_uhr = torch.stack(
    [
        interpolate_along_z(
            data=interpolate_2d(
                data=lr_member,
                nx=cfg_cfd.uhr_base_config.nx,
                ny=cfg_cfd.uhr_base_config.ny,
                mode="nearest-exact",
            ),
            nz=cfg_cfd.uhr_base_config.nz,
            mode="nearest-exact",
        )
        for lr_member in all_lr_fcst
    ],
    dim=0,
)
assert all_gt.shape == all_lr_fcst_uhr.shape

all_lr_fcst_uhr.shape

In [None]:
# interpolation: map every SRDA model's forecasts onto the UHR grid
# (horizontal, then vertical, nearest-exact) and check alignment with the
# UHR ground truth.
dict_srda_uhr_fcsts = {}
for m_model, srda_fcsts_of_model in dict_srda_fcsts.items():
    srda_uhr_fcsts = torch.stack(
        [
            interpolate_along_z(
                data=interpolate_2d(
                    data=srda_member,
                    nx=cfg_cfd.uhr_base_config.nx,
                    ny=cfg_cfd.uhr_base_config.ny,
                    mode="nearest-exact",
                ),
                nz=cfg_cfd.uhr_base_config.nz,
                mode="nearest-exact",
            )
            for srda_member in srda_fcsts_of_model
        ],
        dim=0,
    )
    assert all_gt.shape == srda_uhr_fcsts.shape
    dict_srda_uhr_fcsts[m_model] = srda_uhr_fcsts

dict_srda_uhr_fcsts[m_model].shape

In [None]:
# interpolation: map the HR-LETKF forecasts onto the UHR grid (horizontal,
# then vertical, nearest-exact) and check alignment with the UHR ground truth.
letkf_uhr_fcsts = torch.stack(
    [
        interpolate_along_z(
            data=interpolate_2d(
                data=letkf_member,
                nx=cfg_cfd.uhr_base_config.nx,
                ny=cfg_cfd.uhr_base_config.ny,
                mode="nearest-exact",
            ),
            nz=cfg_cfd.uhr_base_config.nz,
            mode="nearest-exact",
        )
        for letkf_member in letkf_hr_fcsts
    ],
    dim=0,
)
assert all_gt.shape == letkf_uhr_fcsts.shape

letkf_uhr_fcsts.shape

In [None]:
# Precomputed MAE-ratio results. The per-model values under
# "maer_selected_iz" annotate the snapshot figure. This is a locally produced
# pickle; only load trusted files.
all_maer = read_pickle(f"{RESULT_DIR}/all_maer_result.pkl")
(
    all_maer.keys(),
    all_maer["maer"].keys(),
    all_maer["maer_selected_iz"]["ConvTransNetVer01"].shape,
)

# Plot

In [None]:
# Physical time of each retained (analysis) step: the UHR output times
# subsampled with the same stride as the data.
time = torch.arange(
    cfg_cfd.time_config.start_time,
    cfg_cfd.time_config.end_time,
    cfg_cfd.time_config.output_uhr_dt,
)[::t_slice]

# UHR grid coordinates, used as the pcolormesh axes.
uhr_grids = uhr_model.get_grids()

In [None]:
# Two-row snapshot figure. Rows are vertical layers (top, bottom); columns are
# the UHR ground truth, the LR forecast without DA, and each DA method.
#
# Fix: the symmetric colour limits now come from the ground truth of the
# sample actually plotted (`i_data`). Previously they were taken from
# `all_gt[-1]` (the last seed), so every panel was scaled by an unrelated
# sample.

# Column title -> key into `dict_srda_uhr_fcsts` and `all_maer["maer_selected_iz"]`.
key_mapping = {
    "YO23": "ConvTransNetVer01",
    "HR-LETKF": "letkf",
    "4D-SRDA": "UNetMaxVitVer01",
}

all_iz = [-1, 0]  # vertical layer indices, one per row
i_data = 0  # sample index to plot: 0 goes with its=130, 5 with its=70

# Time indices (into the t_slice-subsampled axis) to plot.
# Use range(70, 71) together with i_data = 5 for the alternative snapshot.
TARGET_TIME_INDICES = range(130, 131)

# Green highlight box drawn on the bottom-layer row for a given time index:
# (xy, width, height) in data coordinates.
HIGHLIGHT_BOXES = {
    130: ((0.1, 0.1), 2.9, 2.2),
    70: ((3.0, 0.1), 2.9, 2.5),
}

LAYER_LABELS = {-1: "Top\nLayer", 0: "Bottom\nLayer"}
LR_TITLE = "LR-Forecast\n(Without DA)"

base_font_size = 20
plt.rcParams["font.size"] = base_font_size
title_fs_scale = 1.6
label_fs_scale = 1.6
bar_fs_scale = 1.2

all_data = [
    (all_gt[i_data], uhr_grids, "UHR\nGround Truth"),
    (all_lr_fcst_uhr[i_data], uhr_grids, LR_TITLE),
]
for title, key in key_mapping.items():
    if title == "HR-LETKF":
        all_data.append((letkf_uhr_fcsts[i_data], uhr_grids, title))
    else:
        all_data.append((dict_srda_uhr_fcsts[key][i_data], uhr_grids, title))

for its in TARGET_TIME_INDICES:
    fig, axes = plt.subplots(
        len(all_iz),  # rows: layers
        len(all_data),  # columns: datasets
        sharex=True,
        sharey=True,
        squeeze=False,  # always a 2-D axes array, whatever the grid size
        figsize=(5.4 * len(all_data), 3.85 * len(all_iz)),
    )
    # fig.suptitle(f"Potential Vorticity, $t$ = {time[its]:.2f}", fontsize=20)
    fig.subplots_adjust(wspace=0.005)

    for row_idx, iz in enumerate(all_iz):
        layer_label = LAYER_LABELS.get(iz, f"iz = {iz:02}")

        # Symmetric colour range from the plotted sample's ground truth. It is
        # shared by every panel in the row so they are directly comparable.
        _max = torch.max(torch.abs(all_gt[i_data, its, iz])).item()
        vmin, vmax = -_max, _max

        for col_idx, (data, grids, title) in enumerate(all_data):
            ax = axes[row_idx, col_idx]

            cnt = ax.pcolormesh(
                grids.x[iz],
                grids.y[iz],
                data[its, iz],
                vmin=vmin,
                vmax=vmax,
                cmap="twilight_shifted",
            )

            if row_idx == 0:
                if title in key_mapping:
                    # DA methods: title plus MAE-ratio annotation underneath.
                    ax.set_title(
                        title, fontsize=title_fs_scale * base_font_size, pad=45
                    )
                    maer_value = all_maer["maer_selected_iz"][key_mapping[title]][its]
                    ax.text(
                        0.5,
                        1.15,
                        f"(MAE Ratio = {maer_value:.2f})",
                        fontsize=0.85 * title_fs_scale * base_font_size,
                        ha="center",
                        va="top",
                        transform=ax.transAxes,
                    )
                else:
                    ax.set_title(
                        title, fontsize=title_fs_scale * base_font_size, pad=10
                    )

            if col_idx == 0:
                ax.set_ylabel(layer_label, fontsize=label_fs_scale * base_font_size)
                ax.yaxis.set_label_coords(-0.05, 0.5)

            ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
            cbar = fig.colorbar(cnt, ax=ax, extend="both")
            # Full-size tick labels on the rightmost colour bar, small elsewhere.
            if col_idx == len(all_data) - 1:
                cbar.ax.tick_params(labelsize=bar_fs_scale * base_font_size)
            else:
                cbar.ax.tick_params(labelsize=0.5 * base_font_size)

            # Highlight box on the bottom-layer row for every dataset except
            # the LR forecast.
            box = HIGHLIGHT_BOXES.get(its)
            if box is not None and title != LR_TITLE and row_idx == 1:
                xy, width, height = box
                ax.add_patch(
                    patches.Rectangle(
                        xy=xy,
                        width=width,
                        height=height,
                        edgecolor="g",
                        fill=False,
                        linewidth=6,
                    )
                )

    fig.savefig(
        f"{FIG_DIR}/result_snapshot_two_rows_ts{its}_for_paper.jpg",
        bbox_inches="tight",
        dpi=300,
    )

    plt.show()
    plt.close(fig)  # release memory when several time indices are plotted