In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

# Import libraries

In [None]:
import sys
from logging import INFO, StreamHandler, getLogger

# Route INFO-level log records of the root logger to the notebook's stdout.
logger = getLogger()
logger.setLevel(INFO)
# Guard so that re-running this cell does not attach a duplicate handler.
if not logger.hasHandlers():
    logger.addHandler(StreamHandler(sys.stdout))

In [None]:
import os
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import torch
from src.four_dim_srda.config.experiment_config import CFDConfig
from src.four_dim_srda.utils.io_pickle import read_pickle

# Use serif fonts for every figure drawn in this notebook.
plt.rcParams.update({"font.family": "serif"})

# Define constants

In [None]:
# Project root = parent of the directory named by PYTHONPATH.
# NOTE(review): assumes PYTHONPATH holds a single path (no os.pathsep) — confirm.
_python_path = pathlib.Path(os.environ["PYTHONPATH"])
ROOT_DIR = _python_path.parent.resolve()

In [None]:
# Experiment whose configs and results are loaded below.
experiment_name: str = "experiment7"

In [None]:
# Models to compare. "UNetVitVer02" also has a config name below and can be
# appended to this list to include it.
model_names = ["ConvTransNetVer01", "UNetMaxVitVer01"]

CFG_DIR = f"{ROOT_DIR}/python/configs/four_dim_srda/{experiment_name}"

# SRDA config names per model; a model without an entry here is skipped.
_srda_cfg_name_by_model = {
    "ConvTransNetVer01": "bea2_bed2_dspe360_nsls100_ogx08_ogy08_bias1_bs12_lr1e-04",
    "UNetMaxVitVer01": "bea2_bed2_dspe360_nsls100_ogx08_ogy08_n3drb3_nmb6_bias0_bs12_lr1e-04",
    "UNetVitVer02": "bea2_bed2_dspe360_nsls100_ogx08_ogy08_n3drb3_nvb4_bias0_vits0_bs12_lr1e-04",
}

dict_cfg_srda_name = {}
for m_name in model_names:
    if m_name in _srda_cfg_name_by_model:
        dict_cfg_srda_name[m_name] = _srda_cfg_name_by_model[m_name]

# CFD simulation config. It is loaded from the GPU evaluation file, but every
# resolution (LR / HR / UHR) is forced onto the CPU for this analysis.
CFG_CFD_PATH = f"{CFG_DIR}/cfd_simulation/qg_model/gpu_evaluation_config.yml"

cfg_cfd = CFDConfig.load(pathlib.Path(CFG_CFD_PATH))

DEVICE_CPU = "cpu"

for _base_cfg in (
    cfg_cfd.lr_base_config,
    cfg_cfd.hr_base_config,
    cfg_cfd.uhr_base_config,
):
    _base_cfg.device = DEVICE_CPU

dict_cfg_srda_name

In [None]:
# Inputs of this analysis are read from RESULT_DIR; figures go to its fig/
# subdirectory, which is created (with any missing parents) if absent.
_result_dir = f"{ROOT_DIR}/python/results/four_dim_srda/{experiment_name}"
RESULT_DIR = f"{_result_dir}/analysis/use_narrow_jet"
FIG_DIR = f"{RESULT_DIR}/fig"
pathlib.Path(FIG_DIR).mkdir(parents=True, exist_ok=True)

In [None]:
# Step between analysis times, derived from the DA config. It is later used as
# the slice step (t_slice) that keeps only analysis-time outputs.
ASSIMILATION_PERIOD = (
    cfg_cfd.da_config.segment_length - cfg_cfd.da_config.forecast_span - 1
)
# A zero slice step raises ValueError and a negative one silently reverses the
# time axis, so reject an inconsistent DA config up front.
if ASSIMILATION_PERIOD <= 0:
    raise ValueError(
        "ASSIMILATION_PERIOD must be positive "
        f"(segment_length - forecast_span - 1), got {ASSIMILATION_PERIOD}"
    )

# Define plotting functions

In [None]:
def plot_maer_and_mssim_loss(
    *,
    dict_maer: dict[str, torch.Tensor],
    dict_mssim_loss: dict[str, torch.Tensor],
    time: torch.Tensor,
    first_peak_it: int,
    second_peak_it: int,
    list_ylim_maer: tuple[float, float],
    list_ylim_mssim_loss: tuple[float, float],
    base_font_size: int,
    list_fig_size_xy: tuple[float, float],
    num_xticks: int = 5,
    num_yticks: int = 5,
    save_fig: bool = False,
    fig_name: str = "maer_and_mssim_loss_plots_for_paper",
    xtick_max: float = 200.0,
):
    """Plot MAE ratio (left) and MSSIM loss (right) time series side by side.

    Only the keys "ConvTransNetVer01" (legend "YO23"), "UNetMaxVitVer01"
    ("4D-SRDA") and "letkf" ("HR-LETKF") are drawn; any other key of
    ``dict_maer`` is ignored. Gray vertical lines mark ``time[first_peak_it]``
    and ``time[second_peak_it]`` on both panels.

    Side effect: sets ``plt.rcParams["font.size"]`` globally.

    Args:
        dict_maer: MAE-ratio series per method, same length as ``time``.
            Index 0 is skipped because the SRDA results are NaN there.
        dict_mssim_loss: MSSIM-loss series per method with index 0 already
            removed, i.e. aligned with ``time[1:]``. Must contain every drawn
            key of ``dict_maer``.
        time: Time axis of the series.
        first_peak_it: Index into ``time`` of the first vertical marker line.
        second_peak_it: Index into ``time`` of the second vertical marker line.
        list_ylim_maer: (lower, upper) y-limits of the MAE-ratio panel.
        list_ylim_mssim_loss: (lower, upper) y-limits of the MSSIM-loss panel.
        base_font_size: Base font size; titles, labels, legend and tick labels
            are scaled from it.
        list_fig_size_xy: Figure size (width, height) in inches.
        num_xticks: Number of evenly spaced x ticks over ``[0, xtick_max]``.
        num_yticks: Number of evenly spaced y ticks over each panel's y-limits.
        save_fig: If True, save the figure to ``{FIG_DIR}/{fig_name}.jpg``.
        fig_name: File stem used when saving.
        xtick_max: Upper end of the x ticks (default 200, the previous
            hard-coded value).
    """
    plt.rcParams["font.size"] = base_font_size

    title_fs = 1.4 * base_font_size
    label_fs = 1.2 * base_font_size
    legend_fs = 1.1 * base_font_size
    tick_label_fs = 1.0 * base_font_size

    lw = 2.2
    marker_lw = 0.8 * lw  # peak marker lines are slightly thinner than data

    # (colour, line style, legend label) for each method that is drawn.
    line_styles = {
        "ConvTransNetVer01": ("tab:blue", "-.", "YO23"),
        "UNetMaxVitVer01": ("tab:red", "-", "4D-SRDA"),
        "letkf": ("tab:green", "--", "HR-LETKF"),
    }

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=list_fig_size_xy)

    # SRDA's results have NaN at i_s = 0, so skip it.
    i_s = 1

    n_plotted = 0
    for key, maer in dict_maer.items():
        if key not in line_styles:
            continue
        c, ls, label = line_styles[key]

        ax1.plot(time[i_s:], maer[i_s:], c=c, lw=lw, linestyle=ls, label=label)
        # NaN in the SRDA MSSIM results has already been removed upstream,
        # so only the time axis is sliced here.
        ax2.plot(
            time[i_s:],
            dict_mssim_loss[key],
            c=c,
            lw=lw,
            linestyle=ls,
            label=label,
        )
        n_plotted += 1

    # Axis decoration does not depend on the method, so apply it once.
    panels = (
        (ax1, list_ylim_maer, "(a) MAE Ratio in UHR space", "MAE Ratio"),
        (ax2, list_ylim_mssim_loss, "(b) MSSIM Loss in UHR space", "MSSIM Loss"),
    )
    for ax, ylim, title, ylabel in panels:
        ax.set_xticks(np.linspace(0, xtick_max, num_xticks))
        ax.set_ylim(ylim[0], ylim[1])
        ax.set_yticks(np.linspace(ylim[0], ylim[1], num_yticks))

        ax.set_title(title, fontsize=title_fs, loc="left", pad=20)
        ax.set_xlabel("Time", fontsize=label_fs)
        ax.set_ylabel(ylabel, fontsize=label_fs, labelpad=15)
        ax.tick_params(axis="both", which="major", labelsize=tick_label_fs)

        # Vertical lines at the two peak times.
        for peak_it in (first_peak_it, second_peak_it):
            ax.axvline(
                x=time[peak_it], color="gray", linestyle="-", linewidth=marker_lw
            )

    # Only add a legend when at least one method was drawn.
    if n_plotted:
        ax2.legend(fontsize=legend_fs, edgecolor="black", framealpha=1.0)

    fig.tight_layout()

    if save_fig:
        fig.savefig(
            f"{FIG_DIR}/{fig_name}.jpg",
            dpi=300,
            bbox_inches="tight",
        )

    plt.show()

In [None]:
def plot_maer_and_mssim_loss_only_lr_fcst_and_letkf(
    *,
    dict_maer: dict[str, torch.Tensor],
    dict_mssim_loss: dict[str, torch.Tensor],
    time: torch.Tensor,
    list_ylim_maer: tuple[float, float],
    list_ylim_mssim_loss: tuple[float, float],
    base_font_size: int,
    list_fig_size_xy: tuple[float, float],
    num_xticks: int = 5,
    num_yticks: int = 5,
    save_fig: bool = False,
    fig_name: str = "maer_and_mssim_loss_only_lr_fcst_and_letkf_plots_for_paper",
    xtick_max: float = 200.0,
):
    """Plot MAE ratio (left) and MSSIM loss (right) for the LR forecast and HR-LETKF.

    Only the keys "lr_fcst" (legend "LR-Forecast") and "letkf" ("HR-LETKF")
    are drawn; any other key of ``dict_maer`` is ignored. Methods are drawn
    in the iteration order of ``dict_maer``.

    Side effect: sets ``plt.rcParams["font.size"]`` globally.

    Args:
        dict_maer: MAE-ratio series per method, same length as ``time``;
            index 0 is skipped.
        dict_mssim_loss: MSSIM-loss series per method with index 0 already
            removed, i.e. aligned with ``time[1:]``. Must contain every drawn
            key of ``dict_maer``.
        time: Time axis of the series.
        list_ylim_maer: (lower, upper) y-limits of the MAE-ratio panel.
        list_ylim_mssim_loss: (lower, upper) y-limits of the MSSIM-loss panel.
        base_font_size: Base font size; titles, labels, legend and tick labels
            are scaled from it.
        list_fig_size_xy: Figure size (width, height) in inches.
        num_xticks: Number of evenly spaced x ticks over ``[0, xtick_max]``.
        num_yticks: Number of evenly spaced y ticks over each panel's y-limits.
        save_fig: If True, save the figure to ``{FIG_DIR}/{fig_name}.jpg``.
        fig_name: File stem used when saving (default is the previous
            hard-coded name).
        xtick_max: Upper end of the x ticks (default 200, the previous
            hard-coded value).
    """
    plt.rcParams["font.size"] = base_font_size

    title_fs = 1.4 * base_font_size
    label_fs = 1.2 * base_font_size
    legend_fs = 1.1 * base_font_size
    tick_label_fs = 1.0 * base_font_size

    lw = 2.2

    # (colour, line style, legend label) for each method that is drawn.
    line_styles = {
        "lr_fcst": ("tab:red", "-", "LR-Forecast"),
        "letkf": ("tab:green", "--", "HR-LETKF"),
    }

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=list_fig_size_xy)

    # Index 0 is skipped (the SRDA results are NaN there), matching the SRDA plot.
    i_s = 1

    n_plotted = 0
    for key, maer in dict_maer.items():
        if key not in line_styles:
            continue
        c, ls, label = line_styles[key]

        ax1.plot(time[i_s:], maer[i_s:], c=c, lw=lw, linestyle=ls, label=label)
        # The MSSIM series already exclude index 0, so only `time` is sliced.
        ax2.plot(
            time[i_s:],
            dict_mssim_loss[key],
            c=c,
            lw=lw,
            linestyle=ls,
            label=label,
        )
        n_plotted += 1

    # Axis decoration does not depend on the method, so apply it once.
    panels = (
        (ax1, list_ylim_maer, "(a) MAE Ratio in UHR space", "MAE Ratio"),
        (ax2, list_ylim_mssim_loss, "(b) MSSIM Loss in UHR space", "MSSIM Loss"),
    )
    for ax, ylim, title, ylabel in panels:
        ax.set_xticks(np.linspace(0, xtick_max, num_xticks))
        ax.set_ylim(ylim[0], ylim[1])
        ax.set_yticks(np.linspace(ylim[0], ylim[1], num_yticks))

        ax.set_title(title, fontsize=title_fs, loc="left", pad=20)
        ax.set_xlabel("Time", fontsize=label_fs)
        ax.set_ylabel(ylabel, fontsize=label_fs, labelpad=15)
        ax.tick_params(axis="both", which="major", labelsize=tick_label_fs)

    # Only add a legend when at least one method was drawn.
    if n_plotted:
        ax2.legend(fontsize=legend_fs, edgecolor="black", framealpha=1.0)

    fig.tight_layout()

    if save_fig:
        fig.savefig(
            f"{FIG_DIR}/{fig_name}.jpg",
            dpi=300,
            bbox_inches="tight",
        )

    plt.show()

# Prepare data

In [None]:
# Slice step that keeps only the outputs at analysis (assimilation) times.
t_slice = ASSIMILATION_PERIOD

In [None]:
# Time axis on the UHR output step (output_uhr_dt), subsampled to analysis times.
_time_cfg = cfg_cfd.time_config
time = torch.arange(
    _time_cfg.start_time, _time_cfg.end_time, _time_cfg.output_uhr_dt
)[::t_slice]

In [None]:
# Precomputed MAE-ratio results.
# NOTE(review): read_pickle presumably unpickles the file — only load result
# files you produced yourself.
_maer_pkl_path = f"{RESULT_DIR}/all_maer_result.pkl"
all_maer = read_pickle(_maer_pkl_path)
(all_maer.keys(), all_maer["maer_selected_iz"].keys())

In [None]:
# Precomputed MSSIM-loss results using a uniform window (window size 5 along z).
_mssim_pkl_path = f"{RESULT_DIR}/all_mssim_loss_uniform_wsz5_result.pkl"
all_mssim_loss_uniform_wsz5 = read_pickle(_mssim_pkl_path)

(
    all_mssim_loss_uniform_wsz5.keys(),
    all_mssim_loss_uniform_wsz5["mssim_loss_selected_iz"].keys(),
)

# Plot

In [None]:
# Compare all methods on the "*_selected_iz" results.
plot_maer_and_mssim_loss(
    # data
    dict_maer=all_maer["maer_selected_iz"],
    dict_mssim_loss=all_mssim_loss_uniform_wsz5["mssim_loss_selected_iz"],
    time=time,
    # time indices of the gray vertical marker lines
    first_peak_it=130,
    second_peak_it=70,
    # y ranges of the two panels
    list_ylim_maer=[0.04, 0.18],
    list_ylim_mssim_loss=[0.01, 0.08],
    # layout
    list_fig_size_xy=[16, 6],
    base_font_size=20,
    num_xticks=6,
    num_yticks=6,
    save_fig=False,
)

In [None]:
# Same comparison on the "*_selected_iz_subset" results, saved under its own name.
plot_maer_and_mssim_loss(
    # data
    dict_maer=all_maer["maer_selected_iz_subset"],
    dict_mssim_loss=all_mssim_loss_uniform_wsz5["mssim_loss_selected_iz_subset"],
    time=time,
    # time indices of the gray vertical marker lines
    first_peak_it=130,
    second_peak_it=70,
    # y ranges of the two panels
    list_ylim_maer=[0.04, 0.18],
    list_ylim_mssim_loss=[0.01, 0.08],
    # layout
    list_fig_size_xy=[16, 6],
    base_font_size=20,
    num_xticks=6,
    num_yticks=6,
    # output
    save_fig=False,
    fig_name="maer_and_mssim_loss_subset_plots_for_paper",
)

In [None]:
# Compare the LR forecast with HR-LETKF; the plotting helper draws methods in
# the iteration order of dict_maer, so build both dicts in `ordered_keys` order.
# Keep only methods present in BOTH result dicts: the helper indexes
# dict_mssim_loss with every drawn key of dict_maer, so a method missing from
# the MSSIM results would otherwise raise KeyError mid-plot.
ordered_keys = ["lr_fcst", "letkf"]
_maer_src = all_maer["maer_selected_iz"]
_mssim_src = all_mssim_loss_uniform_wsz5["mssim_loss_selected_iz"]
_plot_keys = [k for k in ordered_keys if k in _maer_src and k in _mssim_src]

dict_maer_ordered = {k: _maer_src[k] for k in _plot_keys}
dict_mssim_loss_ordered = {k: _mssim_src[k] for k in _plot_keys}

plot_maer_and_mssim_loss_only_lr_fcst_and_letkf(
    dict_maer=dict_maer_ordered,
    dict_mssim_loss=dict_mssim_loss_ordered,
    time=time,
    list_fig_size_xy=[16, 6],
    list_ylim_maer=[0.03, 1.2],
    list_ylim_mssim_loss=[0.01, 1.0],
    base_font_size=20,
    num_xticks=6,
    num_yticks=7,
    save_fig=False,
)