In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

# Import libraries

In [None]:
import sys
from logging import INFO, StreamHandler, getLogger

# Root logger for the notebook: INFO level, echoing to stdout.
logger = getLogger()
logger.setLevel(INFO)
if not logger.hasHandlers():
    # Only attach a handler when none exists, so re-running this cell does
    # not add a second stdout handler.
    logger.addHandler(StreamHandler(sys.stdout))

In [None]:
import os
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import torch
from src.four_dim_srda.config.experiment_config import CFDConfig
from src.four_dim_srda.utils.io_pickle import read_pickle

# Global matplotlib setting: use a serif font family for every figure drawn
# after this cell runs.
plt.rcParams["font.family"] = "serif"

# Define constants

In [None]:
# Project root: the resolved parent of the directory named by the PYTHONPATH
# environment variable. Raises KeyError when PYTHONPATH is not set.
# NOTE(review): assumes PYTHONPATH holds a single path (no os.pathsep-separated
# list) - confirm for the environment this notebook runs in.
ROOT_DIR = pathlib.Path(os.environ["PYTHONPATH"]).parent.resolve()

In [None]:
# Experiment identifier; it is interpolated into the config and result
# directory paths built in the following cells.
experiment_name = "experiment7"

In [None]:
# model_names = ["ConvTransNetVer01", "UNetMaxVitVer01", "UNetVitVer02"]
model_names = ["ConvTransNetVer01", "UNetMaxVitVer01"]

CFG_DIR = f"{ROOT_DIR}/python/configs/four_dim_srda/{experiment_name}"

# srda: config name registered for each known model
_known_cfg_srda_names = {
    "ConvTransNetVer01": "bea2_bed2_dspe360_nsls100_ogx08_ogy08_bias1_bs12_lr1e-04",
    "UNetMaxVitVer01": "bea2_bed2_dspe360_nsls100_ogx08_ogy08_n3drb3_nmb6_bias0_bs12_lr1e-04",
    "UNetVitVer02": "bea2_bed2_dspe360_nsls100_ogx08_ogy08_n3drb3_nvb4_bias0_vits0_bs12_lr1e-04",
}

# Keep only the models selected above; names without a registered config are
# skipped without an error.
dict_cfg_srda_name = {}
for m_name in model_names:
    if m_name in _known_cfg_srda_names:
        dict_cfg_srda_name[m_name] = _known_cfg_srda_names[m_name]

# config cfd
CFG_CFD_PATH = f"{CFG_DIR}/cfd_simulation/qg_model/gpu_evaluation_config.yml"
cfg_cfd = CFDConfig.load(pathlib.Path(CFG_CFD_PATH))

# Override the device of every resolution-specific base config with the CPU.
DEVICE_CPU = "cpu"
cfg_cfd.lr_base_config.device = DEVICE_CPU
cfg_cfd.hr_base_config.device = DEVICE_CPU
cfg_cfd.uhr_base_config.device = DEVICE_CPU

# Last expression of the cell: display the selected config names.
dict_cfg_srda_name

In [None]:
# Directory holding the analysis results of this experiment, and the figure
# directory beneath it.
_result_dir = f"{ROOT_DIR}/python/results/four_dim_srda/{experiment_name}"
RESULT_DIR = f"{_result_dir}/analysis/use_narrow_jet/store_only_forecast"
FIG_DIR = f"{RESULT_DIR}/fig"
# Create the figure directory (and any missing parents); no error if it exists.
os.makedirs(FIG_DIR, exist_ok=True)

In [None]:
# Forecast span taken from the data-assimilation config.
FORECAST_SPAN = cfg_cfd.da_config.forecast_span
# Segment length minus the forecast span, minus one more step.
ASSIMILATION_PERIOD = cfg_cfd.da_config.segment_length - FORECAST_SPAN - 1

# Define methods

In [None]:
def plot_metric(ax, x, y, label, color, linestyle, linewidth=2.2):
    """Draw one series on ``ax`` and return the result of ``ax.plot``.

    Args:
        ax: Object exposing a matplotlib-style ``plot`` method.
        x: Values for the horizontal axis.
        y: Values for the vertical axis.
        label: Legend label of the series.
        color: Line color.
        linestyle: Line style.
        linewidth: Line width (defaults to 2.2).

    Returns:
        Whatever ``ax.plot`` returns for this series.
    """
    line_kwargs = {
        "label": label,
        "color": color,
        "linestyle": linestyle,
        "linewidth": linewidth,
    }
    return ax.plot(x, y, **line_kwargs)


def setup_axis(
    ax,
    x_ticks,
    y_limits,
    y_ticks,
    title,
    x_label,
    y_label,
    title_fontsize,
    label_fontsize,
    tick_label_fontsize,
):
    """Apply ticks, y-range, title, axis labels and tick label size to ``ax``.

    Args:
        ax: Axis-like object to configure.
        x_ticks: Tick positions for the x axis.
        y_limits: Values unpacked into ``ax.set_ylim``.
        y_ticks: Tick positions for the y axis.
        title: Title text, placed on the left with a padding of 20.
        x_label: Label of the x axis.
        y_label: Label of the y axis (label padding of 15).
        title_fontsize: Font size of the title.
        label_fontsize: Font size shared by both axis labels.
        tick_label_fontsize: Font size of the major tick labels.
    """
    # Ticks and range. The y-range is set before the y-ticks, as in the
    # established call order.
    ax.set_xticks(x_ticks)
    ax.set_ylim(*y_limits)
    ax.set_yticks(y_ticks)

    # Title and axis labels.
    title_options = {"fontsize": title_fontsize, "loc": "left", "pad": 20}
    ax.set_title(title, **title_options)
    ax.set_xlabel(x_label, fontsize=label_fontsize)
    ax.set_ylabel(y_label, fontsize=label_fontsize, labelpad=15)

    # Tick label size for the major ticks of both axes.
    tick_options = {
        "axis": "both",
        "which": "major",
        "labelsize": tick_label_fontsize,
    }
    ax.tick_params(**tick_options)


def plot_maer_and_mssim_loss(
    *,
    dict_maer: dict,
    dict_mssim_loss: dict,
    time: "np.ndarray | torch.Tensor",
    list_ylim_maer: tuple[float, float],
    list_ylim_mssim_loss: tuple[float, float],
    base_font_size: int,
    list_fig_size_xy: tuple[float, float],
    num_xticks: int = 5,
    num_yticks: int = 5,
    save_fig: bool = False,
    fig_name: str = "maer_and_mssim_loss_only_forecast_plots_for_paper",
    model_name: str = "UNetMaxVitVer01",
    x_tick_range: tuple[float, float] = (0, 200),
):
    """Plot the MAE ratio (left) and the MSSIM loss (right) against time.

    Both dictionaries are read with the same layout: the ``"letkf"`` entry is
    plotted directly, while the ``"srda_lr"`` and ``"srda"`` entries are
    indexed by model name first. Datasets missing from ``dict_maer`` are
    skipped, and so are ``"srda_lr"``/``"srda"`` entries that do not contain
    ``model_name``; skipped datasets get no legend entry.

    Args:
        dict_maer: MAE-ratio series per dataset key.
        dict_mssim_loss: MSSIM-loss series per dataset key (same layout).
        time: Values for the horizontal axis, shared by both panels.
        list_ylim_maer: (lower, upper) y-range of the MAE-ratio panel.
        list_ylim_mssim_loss: (lower, upper) y-range of the MSSIM-loss panel.
        base_font_size: Base font size; title, label, legend and tick label
            sizes are multiples of it.
        list_fig_size_xy: Figure size passed to ``plt.subplots``.
        num_xticks: Number of evenly spaced x ticks.
        num_yticks: Number of evenly spaced y ticks in each panel.
        save_fig: When true, save the figure as ``{FIG_DIR}/{fig_name}.jpg``
            (``FIG_DIR`` is the notebook-level constant).
        fig_name: File name of the saved figure, without extension.
        model_name: Model whose series is drawn for ``"srda_lr"`` and
            ``"srda"``.
        x_tick_range: (first, last) x-tick positions; ``num_xticks`` ticks are
            spread evenly between them.

    Raises:
        KeyError: If a dataset that is plotted from ``dict_maer`` has no
            matching entry in ``dict_mssim_loss``.
    """
    # NOTE: this changes the global matplotlib font size and is not restored
    # afterwards, so later figures inherit it.
    plt.rcParams["font.size"] = base_font_size

    # Font size scales
    title_fs_scale = 1.4
    label_fs_scale = 1.2
    legend_fs_scale = 1.0
    tick_label_fs_scale = 1.0

    lw = 2.2  # Line width

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=list_fig_size_xy)

    # Define styles for each dataset
    styles = {
        "srda_lr": {
            "color": "tab:blue",
            "linestyle": "-.",
            "label": "LR fluid simulation",
        },
        "srda": {"color": "tab:red", "linestyle": "-", "label": "4D-SRDA"},
        "letkf": {"color": "tab:green", "linestyle": "--", "label": "HR-LETKF"},
    }

    handles, labels = [], []

    # Plot data for each dataset
    for dataset_key, style in styles.items():
        if dataset_key not in dict_maer:
            continue

        if dataset_key == "letkf":
            # The LETKF series is stored directly under its key.
            maer_series = dict_maer[dataset_key]
            mssim_series = dict_mssim_loss[dataset_key]
        else:
            # "srda_lr" and "srda" are nested by model name.
            if model_name not in dict_maer[dataset_key]:
                # Nothing to draw for this dataset. Skipping here also keeps
                # the legend consistent: a handle is registered only for a
                # line that was drawn in this iteration, never a stale or
                # undefined one.
                continue
            maer_series = dict_maer[dataset_key][model_name]
            mssim_series = dict_mssim_loss[dataset_key][model_name]

        plot_metric(ax1, time, maer_series, **style, linewidth=lw)
        mssim_lines = plot_metric(ax2, time, mssim_series, **style, linewidth=lw)

        # One legend entry per dataset, taken from the right-hand panel.
        handles.append(mssim_lines[0])
        labels.append(style["label"])

    # Configure the x-ticks for both axes
    x_ticks = np.linspace(*x_tick_range, num_xticks)

    # Configure ax1 (MAE Ratio)
    setup_axis(
        ax1,
        x_ticks,
        list_ylim_maer,
        np.linspace(*list_ylim_maer, num_yticks),
        "(a) MAE Ratio in UHR space",
        "Time",
        "MAE Ratio",
        title_fs_scale * base_font_size,
        label_fs_scale * base_font_size,
        tick_label_fs_scale * base_font_size,
    )

    # Configure ax2 (MSSIM Loss)
    setup_axis(
        ax2,
        x_ticks,
        list_ylim_mssim_loss,
        np.linspace(*list_ylim_mssim_loss, num_yticks),
        "(b) MSSIM Loss in UHR space",
        "Time",
        "MSSIM Loss",
        title_fs_scale * base_font_size,
        label_fs_scale * base_font_size,
        tick_label_fs_scale * base_font_size,
    )

    # Add legend to the second axis
    legend = ax2.legend(
        handles=handles,
        labels=labels,
        fontsize=legend_fs_scale * base_font_size,
        edgecolor="black",
    )
    legend.get_frame().set_alpha(1.0)
    legend.get_frame().set_edgecolor("black")

    # Adjust layout
    fig.tight_layout()

    # Save figure if requested
    if save_fig:
        fig.savefig(
            f"{FIG_DIR}/{fig_name}.jpg",
            dpi=300,
            bbox_inches="tight",
        )

    # Display the plot
    plt.show()

# Prepare data

In [None]:
# Sampling stride for the time axis: one value every FORECAST_SPAN steps, so
# that only the data at the last forecast time of each cycle is used.
t_slice = FORECAST_SPAN

In [None]:
# SRDA's results are NaN at it = 0, 1, ..., FORECAST_SPAN, so those steps are
# skipped. Only the last forecast of each cycle is needed, hence the series
# starts at index 2 * FORECAST_SPAN and is then sampled every t_slice steps.
time = np.arange(
    cfg_cfd.time_config.start_time,
    cfg_cfd.time_config.end_time,
    cfg_cfd.time_config.output_uhr_dt,
)[slice(2 * FORECAST_SPAN, None, t_slice)]

In [None]:
# Load the stored forecast-only MAE-ratio results from RESULT_DIR.
# NOTE(review): read_pickle presumably unpickles the file - only load result
# files that come from a trusted source.
all_maer = read_pickle(f"{RESULT_DIR}/all_maer_only_forecast_result.pkl")

# Log the top-level keys, the keys under "maer_selected_iz", and the shape of
# its "letkf" entry as a quick check of the loaded structure.
logger.info(
    f"Keys in all_maer:\n {all_maer.keys()}\n"
    f"Keys in maer_selected_iz:\n {all_maer['maer_selected_iz'].keys()}\n"
    f"Shape: {all_maer['maer_selected_iz']['letkf'].shape}"
)

In [None]:
# Load the stored forecast-only MSSIM-loss results from RESULT_DIR.
# Uniform window (window_size of z is 5)
# NOTE(review): read_pickle presumably unpickles the file - only load result
# files that come from a trusted source.
all_mssim_loss_uniform_wsz5 = read_pickle(
    f"{RESULT_DIR}/all_mssim_loss_uniform_wsz5_only_forecast_result.pkl"
)

# Log the top-level keys, the keys under "mssim_loss_selected_iz", and the
# shape of its "letkf" entry as a quick check of the loaded structure.
logger.info(
    f"Keys in all_mssim_loss_uniform_wsz5:\n {all_mssim_loss_uniform_wsz5.keys()}\n"
    f"mssim_loss_selected_iz keys:\n {all_mssim_loss_uniform_wsz5['mssim_loss_selected_iz'].keys()}\n"
    f"Shape: {all_mssim_loss_uniform_wsz5['mssim_loss_selected_iz']['letkf'].shape}"
)

# Plot

In [None]:
# Draw the two-panel figure (MAE ratio and MSSIM loss) from the loaded
# results. Set save_fig=True to also write it to FIG_DIR under the name given
# by fig_name.
plot_maer_and_mssim_loss(
    dict_maer=all_maer["maer_selected_iz"],
    dict_mssim_loss=all_mssim_loss_uniform_wsz5["mssim_loss_selected_iz"],
    time=time,
    list_ylim_maer=[0.00, 0.40],
    list_ylim_mssim_loss=[0.00, 0.40],
    base_font_size=20,
    list_fig_size_xy=[16, 8],
    num_xticks=6,
    num_yticks=6,
    save_fig=False,
    fig_name="fig5",
)