In [None]:
# Reload imported modules automatically before executing each cell, so edits
# to the project sources are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import sys
from logging import INFO, StreamHandler, getLogger

# Root logger echoing INFO-level (and above) records to stdout.
# A handler is attached only when none is registered yet, so re-running
# this cell does not duplicate every log line.
logger = getLogger()
logger.setLevel(INFO)
if not logger.hasHandlers():
    logger.addHandler(StreamHandler(sys.stdout))

# Import libraries

In [None]:
import glob
import os
import pathlib
from collections import OrderedDict

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import yaml
from cfd_model.interpolator.torch_interpolator import (
    interpolate,
    interpolate_time_series,
)
from IPython.display import display
from src.dataloader import get_hr_file_paths
from src.model_maker import make_model
from src.sr_da_helper_2 import get_testdataset
from src.utils import read_pickle, set_seeds
from tqdm.notebook import tqdm

# Allow up to 500 columns and 500 rows when displaying DataFrames.
for _display_option in ("display.max_columns", "display.max_rows"):
    pd.set_option(_display_option, 500)

In [None]:
# Configure the cuBLAS workspace (to make calculations deterministic) and then
# seed the random number generators in deterministic mode.
os.environ.update({"CUBLAS_WORKSPACE_CONFIG": r":4096:8"})
set_seeds(42, use_deterministic=True)

# Define constants

In [None]:
# When True, every figure saved as .jpg below is additionally saved as .eps.
WRITE_EPS = True

In [None]:
# Repository root: the parent of the directory stored in PYTHONPATH.
ROOT_DIR = str(pathlib.Path(os.environ["PYTHONPATH"], "..").resolve())

# Cached arrays live on the shared workspace volume when it exists;
# otherwise a folder next to this notebook is used.
_WORKSPACE_TMP_DATA_DIR = "/workspace/all_data/notebook/paper_experiment_01/data"
TMP_DATA_DIR = (
    _WORKSPACE_TMP_DATA_DIR if os.path.exists(_WORKSPACE_TMP_DATA_DIR) else "./data"
)

CSV_DATA_DIR = "./csv"

# Output directory for the paper figures (created on demand).
FIG_DIR = f"{ROOT_DIR}/doc/fig_paper"
os.makedirs(FIG_DIR, exist_ok=True)

In [None]:
CONFIG_DIR = f"{ROOT_DIR}/pytorch/config/paper_experiment_01"
CONFIG_PATHS = sorted(glob.glob(f"{CONFIG_DIR}/*.yml"))

# Maps a config name (file name up to its first ".") to the parsed YAML and to
# the paths of the artifacts produced by training with that config.
CONFIGS = OrderedDict()

for config_path in CONFIG_PATHS:
    with open(config_path) as file:
        config = yaml.safe_load(file)

    config_name = os.path.basename(config_path).split(".")[0]
    assert config_name not in CONFIGS

    # The experiment name is the directory that contains the config file.
    experiment_name = config_path.split("/")[-2]

    # Results are read from the shared workspace volume when present,
    # otherwise from the repository's own data directory.
    _result_subdir = f"data/pytorch/DL_results/{experiment_name}/{config_name}"
    _dir = f"/workspace/all_data/{_result_subdir}"
    if not os.path.exists(_dir):
        _dir = f"{ROOT_DIR}/{_result_subdir}"

    CONFIGS[config_name] = dict(
        config=config,
        model_name=config["model"]["model_name"],
        experiment_name=experiment_name,
        weight_path=f"{_dir}/weights.pth",
        learning_history_path=f"{_dir}/learning_history.csv",
    )

In [None]:
# Name of the CFD data directory (used below under data/pytorch/CFD/).
CFD_DIR_NAME = "jet12"
# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): not referenced in the visible cells; presumably expressed in
# units of DT — confirm.
ASSIMILATION_PERIOD = 4
# Offset added to a cycle index before converting it to time (see `t = ...`).
START_TIME_INDEX = 16

# Low-resolution (LR) grid: LR_NX points in x, LR_NY points in y.
LR_NX = 32
LR_NY = 17
# Presumably the LR time step and the number of steps per stored snapshot —
# confirm against the CFD model.
LR_DT = 5e-4
LR_NT = 500

# High-resolution (HR) grid: HR_NX points in x, HR_NY points in y.
HR_NX = 128
HR_NY = 65

# Time elapsed per time/cycle index.
DT = LR_DT * LR_NT
# Time corresponding to START_TIME_INDEX.
T0 = START_TIME_INDEX * LR_DT * LR_NT

In [None]:
# Fraction of grid points that are observed, keyed by the observation grid
# interval. The cells below multiply these values by 100 to report percentages.
# The non-zero values are approximately 1 / interval**2; interval 0 maps to a
# ratio of 0.0 (presumably the "no observation" configuration — confirm).
OBS_GRID_RATIO = {
    0: 0.0,
    4: 0.06250000093132257,
    5: 0.03999999910593033,
    6: 0.027777777363856632,
    7: 0.02040816326530612,
    8: 0.015625000116415322,
    9: 0.012345679127323775,
    10: 0.010000000149011612,
    11: 0.008264463206306716,
    12: 0.006944444625534945,
    13: 0.005917159876284691,
    14: 0.005102040977882487,
    15: 0.004444444572759999,
    16: 0.003906250014551915,
}

In [None]:
# Line color and line style per method. The colors are the first four entries
# of matplotlib's default color cycle:
# https://matplotlib.org/stable/users/prev_whats_new/dflt_style_changes.html
_METHOD_STYLES = (
    # (label, color, line style)
    ("LR", "#1f77b4", ":"),
    ("SRDA", "#ff7f0e", "-"),
    ("EnKF", "#2ca02c", "--"),
    ("SRDA (no obs.)", "#d62728", "-."),
)
DICT_COLORS = {label: color for label, color, _ in _METHOD_STYLES}
DICT_LINE_STYLES = {label: line_style for label, _, line_style in _METHOD_STYLES}

# Define methods

In [None]:
def plot(
    dict_data: dict,
    obs: np.ndarray,
    obs_grid_interval: int = None,
    t: float = None,
    figsize: list = [20, 2],
    write_out: bool = False,
    ttl_header: str = "",
    fig_file_name: str = "",
    use_hr_space: bool = True,
    vmin_omega: float = -10,
    vmax_omega: float = 10,
    font_size: int = 22,
):
    """Draw one pseudocolor panel per entry of ``dict_data`` with its MAE ratio.

    The entry labelled ``"HR"`` is treated as the ground truth. It must come
    before every other entry, because the title of each other panel reports
    the mean absolute error against that ground truth, normalized by the mean
    absolute value of the ground truth.

    Args:
        dict_data: Ordered mapping from panel label to the field to draw.
            Fields whose label contains ``"LR"`` are interpolated (nearest)
            onto the HR grid when ``use_hr_space`` is true; all other fields
            are interpolated (bicubic) onto the LR grid when it is false.
            After squeezing, every field must have shape ``(HR_NX, HR_NY)``
            or ``(LR_NX, LR_NY)``, depending on ``use_hr_space``.
        obs: Observation array on the plotting grid; NaN marks points without
            an observation. Only used for the ``"HR"`` panel in HR space.
        obs_grid_interval: Key into ``OBS_GRID_RATIO`` used to print the
            percentage of observed points on the ``"HR"`` panel in HR space.
        t: Time shown in the figure title; omitted from the title when None.
        figsize: Figure size passed to ``plt.subplots``.
        write_out: When true, save the figure as .jpg (and as .eps if
            ``WRITE_EPS`` is set) under ``FIG_DIR``.
        ttl_header: Text placed at the start of the figure title.
        fig_file_name: File name (without extension) used when saving.
        use_hr_space: Plot on the HR grid if true, otherwise on the LR grid.
        vmin_omega: Lower bound of the color scale.
        vmax_omega: Upper bound of the color scale.
        font_size: Value assigned to the global ``font.size`` rcParam.

    Raises:
        ValueError: If a panel other than ``"HR"`` is encountered before the
            ``"HR"`` ground truth.
    """
    # Only the grid of the requested resolution is needed.
    if use_hr_space:
        nx, ny = HR_NX, HR_NY
    else:
        nx, ny = LR_NX, LR_NY
    x, y = np.meshgrid(
        np.linspace(0, 2 * np.pi, num=nx, endpoint=False),
        np.linspace(0, np.pi, num=ny, endpoint=True),
        indexing="ij",
    )

    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.size"] = font_size
    fig, axes = plt.subplots(
        1, len(dict_data), figsize=figsize, sharex=True, sharey=True
    )
    # plt.subplots returns a bare Axes (not an array) for a single panel;
    # normalize so that the loop below also works in that case.
    axes = np.atleast_1d(axes)

    gt = None
    for ax, (label, data) in zip(axes, dict_data.items()):
        is_lr = "LR" in label
        if is_lr and use_hr_space:
            data = interpolate(
                torch.from_numpy(data[None, :]), nx=HR_NX, ny=HR_NY, mode="nearest"
            ).numpy()
        elif not is_lr and not use_hr_space:
            data = interpolate(
                torch.from_numpy(data[None, :]), nx=LR_NX, ny=LR_NY, mode="bicubic"
            ).numpy()

        d = np.squeeze(data)
        assert d.shape == (nx, ny)

        if label == "HR":
            gt = d
            ttl = "Ground truth"
        else:
            if gt is None:
                raise ValueError(
                    f"The 'HR' ground truth must precede '{label}' in dict_data."
                )
            mae = np.mean(np.abs(gt - d)) / np.mean(np.abs(gt))
            ttl = f"{label}\n(MAE ratio = {mae:.2f})"

        cnts = ax.pcolormesh(
            x, y, d, cmap="twilight_shifted", vmin=vmin_omega, vmax=vmax_omega
        )

        fig.colorbar(
            cnts,
            ax=ax,
            ticks=[vmin_omega, vmin_omega / 2, 0, vmax_omega / 2, vmax_omega],
            extend="both",
        )

        ax.set_xlim([0, 2 * np.pi])
        ax.set_ylim([0, np.pi])

        if label == "HR" and use_hr_space:
            # Mark every grid point that carries a (non-NaN) observation.
            obs_flat = np.squeeze(obs).flatten()
            has_obs = ~np.isnan(obs_flat)
            ax.scatter(
                x.flatten()[has_obs], y.flatten()[has_obs], marker=".", s=2, c="k"
            )
            prob = OBS_GRID_RATIO[obs_grid_interval] * 100
            ttl = f"{ttl}\n(obs. points: {prob:.2f} %)"

        ax.set_title(ttl)
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.axes.xaxis.set_visible(False)
        ax.axes.yaxis.set_visible(False)

    if t is None:
        fig.suptitle(ttl_header)
    else:
        fig.suptitle(f"{ttl_header}Time = {t}")
    fig.tight_layout()

    if write_out:
        fig.savefig(f"{FIG_DIR}/{fig_file_name}.jpg", dpi=300)
        if WRITE_EPS:
            fig.savefig(f"{FIG_DIR}/{fig_file_name}.eps", dpi=300)

    plt.show()


def get_all_tmp_file_paths(config_name):
    """Return the paths of the four cached .npy files of a config.

    The tuple is ordered as (SR analysis, LR omega, HR omega, HR observation);
    every path lies directly under ``TMP_DATA_DIR``.
    """
    prefixes = ("sr_analysis", "lr_omega", "hr_omega", "hr_obsrv")
    return tuple(
        f"{TMP_DATA_DIR}/{prefix}_{config_name}.npy" for prefix in prefixes
    )


def read_all_tmp_files(config_name):
    """Load the four cached arrays of a config with ``np.load``.

    The arrays are returned in the order given by ``get_all_tmp_file_paths``:
    (SR analysis, LR omega, HR omega, HR observation).
    """
    return tuple(np.load(path) for path in get_all_tmp_file_paths(config_name))


def calc_pred(config_name: str, i_ensemble: int, i_cycle: int):
    """Run the trained model on one test sample and return truth and prediction.

    The test dataset and the model weights of ``config_name`` are loaded on
    every call. Both returned arrays have shape ``(HR_NX, HR_NY)``: the last
    time index of the sample is taken, the two spatial axes are swapped, and
    a final all-zero column is appended along y.

    Returns:
        Tuple ``(ground_truth, prediction)`` of NumPy arrays.
    """
    entry = CONFIGS[config_name]
    config = entry["config"]

    # Read the data from the shared workspace volume when it is mounted.
    if os.path.exists("/workspace/all_data"):
        data_root = "/workspace/all_data"
    else:
        data_root = ROOT_DIR
    test_dataset = get_testdataset(data_root, config)

    sr_model = make_model(config).to(DEVICE)
    state_dict = torch.load(entry["weight_path"], map_location=DEVICE)
    sr_model.load_state_dict(state_dict)
    sr_model.eval()

    lr, obs, gt = test_dataset.get_specified_item(i_ensemble, i_cycle)
    # Add a leading batch dimension before feeding the model.
    lr = lr[None, ...].to(DEVICE)
    obs = obs[None, ...].to(DEVICE)

    # Undo the dataset's normalization of the vorticity.
    bias = test_dataset.vorticity_bias
    scale = test_dataset.vorticity_scale
    gt = gt * scale + bias

    with torch.no_grad():
        pred = (sr_model(lr, obs) * scale + bias).detach().cpu()

    def _last_time_xy(tensor):
        # Keep the last time index, then exchange the x and y axes.
        return tensor.squeeze().numpy()[-1].transpose()

    gt = _last_time_xy(gt)
    pred = _last_time_xy(pred)
    assert gt.shape == pred.shape == (HR_NX, HR_NY - 1)

    def _append_zero_column(field):
        # The last y index holds zeros only.
        padded = np.zeros((HR_NX, HR_NY))
        padded[:, :-1] = field
        return padded

    return _append_zero_column(gt), _append_zero_column(pred)

# Results

In [None]:
# Observation grid interval and trained model (seed 221958) used for the
# snapshot figures of this section.
obs_grid_interval = 8
config_name = f"lt4og{obs_grid_interval:02}_on1e-01_ep1000_lr1e-04_scT_muT_a02_b02_sd221958"

## Vorticity evolution

In [None]:
# Target path = /home/yuki_yasuda/workspace_lab/spatio-temporal-sr-da/data/pytorch/CFD/jet11/seed00234/seed00234_start52_end56_hr_omega.npy
# i_path = 189, i_ensemble = 5

In [None]:
seed = 234
i_ensemble = 5

# Read the CFD output from the shared workspace volume when it is mounted.
_cfd_root = "/workspace/all_data" if os.path.exists("/workspace/all_data") else ROOT_DIR
cfd_dir_path = f"{_cfd_root}/data/pytorch/CFD/{CFD_DIR_NAME}/seed{seed:05}"

# Sorted by file name. NOTE(review): assumes lexicographic order equals
# chronological order — confirm the file naming is zero-padded.
cfd_file_paths = sorted(glob.glob(f"{cfd_dir_path}/*_hr_omega_{i_ensemble:02}.npy"))

In [None]:
# Join all files along the first axis. The first entry of every file except
# the first one is dropped (presumably it repeats the last entry of the
# previous file — verify against the CFD output).
hr_omegas = np.concatenate(
    [
        chunk if i_file == 0 else chunk[1:]
        for i_file, chunk in enumerate(map(np.load, cfd_file_paths))
    ],
    axis=0,
)

In [None]:
plt.rcParams.update({"font.family": "serif", "font.size": 22})

fig, axes = plt.subplots(2, 4, figsize=[20, 5], sharex=True, sharey=True)

# HR grid: x in [0, 2*pi) and y in [0, pi].
xs, ys = np.meshgrid(
    np.linspace(0, 2 * np.pi, num=HR_NX, endpoint=False),
    np.linspace(0, np.pi, num=HR_NY, endpoint=True),
    indexing="ij",
)
vmin, vmax = -10, 10
colorbar_ticks = [vmin, vmin / 2, 0, vmax / 2, vmax]

# One panel per selected time index, filled row by row.
snapshot_indices = [16, 20, 24, 40, 56, 68, 76, 96]
for ax, i_time in zip(axes.ravel(), snapshot_indices):
    mesh = ax.pcolormesh(
        xs, ys, hr_omegas[i_time], vmin=vmin, vmax=vmax, cmap="twilight_shifted"
    )
    fig.colorbar(mesh, ax=ax, ticks=colorbar_ticks, extend="both")

    ax.set_title(f"Time = {i_time * LR_DT * LR_NT:.0f}")

    # Hide both axes entirely: only the field itself is of interest.
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_ticklabels([])
    ax.axes.xaxis.set_visible(False)
    ax.axes.yaxis.set_visible(False)

fig.tight_layout()

fig.savefig(f"{FIG_DIR}/vorticity_evolution.jpg", dpi=300)
if WRITE_EPS:
    fig.savefig(f"{FIG_DIR}/vorticity_evolution.eps", dpi=300)

plt.show()

## Vorticity snapshot

In [None]:
# Cached arrays of the selected config.
all_sr_analysis, all_lr_omega, all_hr_omega, all_hr_obsrv = read_all_tmp_files(
    config_name
)

# EnKF results are stored per config without the trailing "_sd<seed>" part.
enkf_dir = f"{TMP_DATA_DIR}/EnKF/{config_name.split('_sd')[0]}"
assert os.path.exists(enkf_dir)

In [None]:
# i_ensembles = [43, 46, 48, 70, 91, 107, 154, 156, 163, 185]

for i_ensemble in [185]:

    # Ensemble mean of the LR EnKF states (only shape-checked here).
    enkf_lr = np.load(f"{enkf_dir}/ens_all_lr_{i_ensemble:04}.npy").mean(axis=0)
    assert enkf_lr.shape == (81, LR_NX, LR_NY)

    enkf_hr = read_pickle(f"{enkf_dir}/ens_mean_hr_{i_ensemble:04}.pickle")

    # One figure per cycle; the header is the panel tag used in the paper.
    for i_cycle, panel_tag in ((40, "(a) "), (80, "(b) ")):
        t = (i_cycle + START_TIME_INDEX) * DT

        dict_data = OrderedDict(
            [
                ("HR", all_hr_omega[i_ensemble, i_cycle]),
                ("LR (no SR/DA)", all_lr_omega[i_ensemble, i_cycle]),
                ("EnKF", enkf_hr[i_cycle]),
                ("SRDA", all_sr_analysis[i_ensemble, i_cycle]),
            ]
        )

        plot(
            dict_data,
            obs=all_hr_obsrv[i_ensemble, i_cycle],
            obs_grid_interval=obs_grid_interval,
            t=t,
            figsize=[20, 4],
            write_out=True,
            ttl_header=panel_tag,
            fig_file_name=f"snapshots_enkf_srda_t{int(t):02}",
            use_hr_space=True,
            font_size=22,
        )

## Error time series

In [None]:
# Error time series (one CSV per training seed) for the selected
# observation grid interval.
dict_dfs = {}
for seed in [221958, 771155, 832180, 465838, 359178]:
    config_name = f"lt4og{obs_grid_interval:02}_on1e-01_ep1000_lr1e-04_scT_muT_a02_b02_sd{seed:06}"
    csv_path = f"{CSV_DATA_DIR}/enkf_srda_err_time_series_{config_name}.csv"
    dict_dfs[seed] = pd.read_csv(csv_path)

resolution = "HR"
grid_ratio = OBS_GRID_RATIO[obs_grid_interval] * 100

# Config name of the last seed loaded above.
print(config_name)

plt.rcParams["font.family"] = "serif"
plt.rcParams["font.size"] = 22
fig, axes = plt.subplots(1, 2, sharex=True, figsize=[15, 4.5])

for ax, ycol in zip(axes, ["MAER", "SSIM_Gauss"]):
    xs = None
    for label in ["LR", "EnKF", "SRDA"]:
        all_ys = []
        for df in dict_dfs.values():
            all_ys.append(df[f"{ycol}_{resolution}_{label}"].values)
            # Every seed must share the same time axis.
            if xs is None:
                xs = df["Time"].values
            else:
                np.testing.assert_array_equal(xs, df["Time"].values)
        all_ys = np.stack(all_ys)
        assert all_ys.shape == (5, 81)  # axis = batch, time

        # Mean over seeds; the min/max spread is computed as a sanity check
        # but is not drawn (no error bars in this figure).
        ys = np.mean(all_ys, axis=0)
        diffs = all_ys - ys
        min_diffs = -np.min(diffs, axis=0)
        max_diffs = np.max(diffs, axis=0)
        errs = np.stack([min_diffs, max_diffs])
        assert errs.shape == (2, 81)
        assert np.max(errs) >= 0

        ls = DICT_LINE_STYLES[label]
        c = DICT_COLORS[label]

        # Use a separate name for the legend text instead of rebinding the
        # loop variable.
        legend_label = "LR (no SR/DA)" if label == "LR" else label

        if ycol == "SSIM_Gauss":
            # Plot the similarity as a loss so that lower is better.
            ys = 1 - ys
        ax.plot(xs, ys, label=legend_label, ls=ls, color=c, lw=2.0)
    ax.set_xlabel("Time")

    ax.set_xticks(np.linspace(4, 24, 6))

    if ycol == "MAER":
        ax.set_ylim(0, 1)
        ax.set_yticks(np.linspace(0, 1, 6))
        ax.set_title(f"(a) MAE ratio in {resolution} space")
        ax.set_ylabel("MAE ratio")
    else:
        ax.set_ylim(0, 0.4)
        ax.set_yticks(np.linspace(0, 0.4, 5))
        ax.set_title(f"(b) MSSIM loss in {resolution} space")
        ax.set_ylabel("MSSIM loss")


lg = axes[-1].legend(
    bbox_to_anchor=(1.05, 1.0),
    loc="upper left",
    ncol=1,
    fontsize=20,
    framealpha=1,
    edgecolor="k",
)
plt.suptitle(f"Observation point ratio: {grid_ratio:.2f} %")
plt.tight_layout()


fig.savefig(
    f"{FIG_DIR}/err_time_series_enkf_srda.jpg",
    bbox_extra_artists=(lg,),
    # bbox_inches="tight",
    dpi=300,
)

if WRITE_EPS:
    fig.savefig(
        f"{FIG_DIR}/err_time_series_enkf_srda.eps",
        bbox_extra_artists=(lg,),
        # bbox_inches="tight",
        dpi=300,
    )

plt.show()

## Error dependency on obs grid ratio

In [None]:
resolution = "HR"

# Collected per observation grid interval:
#   xs          observation point ratio [%]
#   ys          mean over seeds of the time-averaged metric
#   y_min_errs  distance from that mean down to the smallest seed value
#   y_max_errs  distance from that mean up to the largest seed value
xs = []
ys, y_min_errs, y_max_errs = (
    {ycol: {"SRDA": [], "EnKF": []} for ycol in ("MAER", "SSIM_Gauss")}
    for _ in range(3)
)

intervals = np.arange(4, 14, 2)

for interval in tqdm(intervals, total=len(intervals)):
    grid_ratio = OBS_GRID_RATIO[interval] * 100
    xs.append(grid_ratio)

    # One error time series per training seed.
    dict_dfs = {}
    for seed in [221958, 771155, 832180, 465838, 359178]:
        config_name = (
            f"lt4og{interval:02}_on1e-01_ep1000_lr1e-04_scT_muT_a02_b02_sd{seed:06}"
        )
        csv_path = f"{CSV_DATA_DIR}/enkf_srda_err_time_series_{config_name}.csv"
        dict_dfs[seed] = pd.read_csv(csv_path)

    for ycol in ["MAER", "SSIM_Gauss"]:
        for label in ["SRDA", "EnKF"]:
            column = f"{ycol}_{resolution}_{label}"
            # Time average of the metric for each seed.
            _ys = np.array([seed_df[column].mean() for seed_df in dict_dfs.values()])
            assert _ys.shape == (5,)  # number of seeds

            y = np.mean(_ys)
            deviations = _ys - y
            ys[ycol][label].append(y)
            y_min_errs[ycol][label].append(-np.min(deviations))
            y_max_errs[ycol][label].append(np.max(deviations))

In [None]:
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.size"] = 22

fig, axes = plt.subplots(1, 2, sharex=True, figsize=[14, 4])

for ax, ycol in zip(axes, ["MAER", "SSIM_Gauss"]):
    for label, _ys in ys[ycol].items():
        if ycol == "SSIM_Gauss":
            # Plot the similarity as a loss so that lower is better.
            _ys = 1.0 - np.array(_ys)
            print(_ys)

        # Only the seed-averaged values are drawn: no `yerr` is passed, so
        # the seed ranges in y_min_errs / y_max_errs do not appear.
        ax.errorbar(
            xs,
            _ys,
            marker="o",
            label=label,
            ls=DICT_LINE_STYLES[label],
            color=DICT_COLORS[label],
            lw=2,
            capsize=7,
        )

    ax.set_xlabel("Observation point ratio [%]")
    ax.set_xlim(0.5, 6.5)
    ax.set_xticks(np.linspace(0.5, 6.5, 7))

    if ycol == "MAER":
        ax.set_ylim(0.08, 0.4)
        ax.set_yticks(np.linspace(0.08, 0.4, 5))
        ax.set_title(f"(a) MAE ratio in {resolution} space")
        ax.set_ylabel("MAE ratio")
    else:
        ax.set_ylim(0, 0.16)
        ax.set_yticks(np.linspace(0, 0.16, 5))
        ax.set_title(f"(b) MSSIM loss in {resolution} space")
        ax.set_ylabel("MSSIM loss")

    ax.legend(loc="upper right", ncol=2, fontsize=20)


plt.tight_layout()

fig.savefig(f"{FIG_DIR}/err_obs_grid_dependency_enkf_srda.jpg", dpi=300)

if WRITE_EPS:
    fig.savefig(f"{FIG_DIR}/err_obs_grid_dependency_enkf_srda.eps", dpi=300)

plt.show()

# Discussion

In [None]:
# Aggregated MAE scores, indexed by the CSV's unnamed first column.
df_agg_results = pd.read_csv(
    f"{CSV_DATA_DIR}/mae_scores_using_testdataset.csv"
).set_index("Unnamed: 0")
len(df_agg_results)

In [None]:
n_loops = 100  # NOTE(review): not referenced in this cell — confirm it is still needed

# Add error-ratio statistics and the relevant hyperparameters of every config
# as columns of df_agg_results (one row per config name).
for c_name in tqdm(CONFIGS.keys(), total=len(CONFIGS)):
    try:
        config = CONFIGS[c_name]["config"]
        csv_file = f"{CSV_DATA_DIR}/hr_err_time_series_{c_name}_with_mae_ratio.csv"
        if not os.path.exists(csv_file):
            logger.error(f"{csv_file} does not exist!")
            continue

        df = pd.read_csv(csv_file)

        # Statistics of the SR error ratio over the time series.
        err_ratio = df["ErrRatioSR"]
        df_agg_results.loc[c_name, "AveErrRatioSR"] = err_ratio.mean()
        df_agg_results.loc[c_name, "MaxErrRatioSR"] = err_ratio.max()
        df_agg_results.loc[c_name, "MinErrRatioSR"] = err_ratio.min()
        df_agg_results.loc[c_name, "StdErrRatioSR"] = err_ratio.std()

        data_config = config["data"]
        df_agg_results.loc[c_name, "UseObs"] = data_config["use_observation"]
        df_agg_results.loc[c_name, "ObsGridInterval"] = data_config[
            "obs_grid_interval"
        ]
        df_agg_results.loc[c_name, "ObsGridRatio"] = (
            OBS_GRID_RATIO[data_config["obs_grid_interval"]] * 100
        )
        df_agg_results.loc[c_name, "ObsNoiseStd"] = data_config["obs_noise_std"]

        df_agg_results.loc[c_name, "LrTimeInterval"] = data_config[
            "lr_time_interval"
        ]
        df_agg_results.loc[c_name, "UseSkipConn"] = config["model"][
            "use_global_skip_connection"
        ]
        df_agg_results.loc[c_name, "UseMixup"] = data_config["use_mixup"]
        df_agg_results.loc[c_name, "alpha"] = data_config["beta_dist_alpha"]
        df_agg_results.loc[c_name, "beta"] = data_config["beta_dist_beta"]
        # Configs without this key are treated as using the LR forecast.
        df_agg_results.loc[c_name, "UseLrForecast"] = data_config.get(
            "use_lr_forecast", True
        )
    except Exception:
        # Keep processing the remaining configs, but record which config
        # failed together with the full traceback.
        logger.exception(f"Failed to aggregate results for {c_name}")

## Necessity of Observations

In [None]:
# Observation grid interval used for the "necessity of observations" study.
obs_grid_interval = 8
grid_ratio = 100 * OBS_GRID_RATIO[obs_grid_interval]  # in percent
# Guard against accidental edits of the ratio table.
assert grid_ratio == 1.5625000116415322

In [None]:
# Error time series per training seed, with and without observations.
dict_dfs = {}
dict_dfs_no_obs = {}

# Tag of the observation grid interval inside the config name; derived from
# obs_grid_interval so that the two cannot drift apart.
obs_tag = f"og{obs_grid_interval:02}"

for seed in [221958, 771155, 832180, 465838, 359178]:
    config_name = f"lt4og{obs_grid_interval:02}_on1e-01_ep1000_lr1e-04_scT_muT_a02_b02_sd{seed:06}"
    # The counterpart without observations uses the interval tag "og00".
    config_name_no_obs = config_name.replace(obs_tag, "og00")
    assert "og00" in config_name_no_obs

    dict_dfs[seed] = pd.read_csv(
        f"{CSV_DATA_DIR}/hr_err_time_series_{config_name}_with_mae_ratio.csv"
    )
    dict_dfs_no_obs[seed] = pd.read_csv(
        f"{CSV_DATA_DIR}/hr_err_time_series_{config_name_no_obs}_with_mae_ratio.csv"
    )

### Error time series

In [None]:
plt.rcParams.update({"font.family": "serif", "font.size": 22})

fig, ax = plt.subplots(figsize=[10, 3.5])

for label in ["LR", "SRDA (no obs.)", "SRDA"]:
    # The no-observation curve comes from its own set of per-seed tables.
    _dict = dict_dfs_no_obs if label == "SRDA (no obs.)" else dict_dfs
    key = "ErrRatioLR(bicubic)" if label == "LR" else "ErrRatioSR"

    xs = None
    all_ys = []
    for df in _dict.values():
        all_ys.append(df[key].values)
        # Every seed must share the same time axis.
        if xs is None:
            xs = df["Time"].values
        else:
            np.testing.assert_array_equal(xs, df["Time"].values)
    all_ys = np.stack(all_ys)
    assert all_ys.shape == (5, 81)  # axis = batch, time

    # Mean over seeds; the min/max spread is only sanity-checked, not drawn.
    ys = all_ys.mean(axis=0)
    diffs = all_ys - ys
    errs = np.stack([-np.min(diffs, axis=0), np.max(diffs, axis=0)])
    assert errs.shape == (2, 81)
    assert np.max(errs) >= 0

    legend_label = "LR (no SR/DA)" if label == "LR" else label
    ax.plot(
        xs,
        ys,
        label=legend_label,
        color=DICT_COLORS[label],
        ls=DICT_LINE_STYLES[label],
        lw=2,
    )

ax.set_xlabel("Time")
ax.set_xticks(np.linspace(4, 24, 6))

ax.set_ylim(0, 1)
ax.set_yticks(np.linspace(0, 1, 6))
ax.set_title("MAE ratio in HR space")
ax.set_ylabel("MAE ratio")

# The legend is placed outside the axes, to the right.
lg = ax.legend(
    bbox_to_anchor=(1.05, 1.0),
    loc="upper left",
    ncol=1,
    fontsize=20,
    framealpha=1,
    edgecolor="k",
)

plt.tight_layout()

# bbox_inches="tight" together with bbox_extra_artists keeps the outside
# legend inside the saved image.
save_options = dict(bbox_extra_artists=(lg,), bbox_inches="tight", dpi=300)
fig.savefig(f"{FIG_DIR}/err_time_series_srda_no_obs.jpg", **save_options)
if WRITE_EPS:
    fig.savefig(f"{FIG_DIR}/err_time_series_srda_no_obs.eps", **save_options)
plt.show()

## Necessity of SR mixup

### Vorticity snapshots (SRDA cycle)

In [None]:
# Observation grid interval and trained model (seed 221958) used for the
# mixup comparison.
obs_grid_interval = 10
config_name = f"lt4og{obs_grid_interval:02}_on1e-01_ep1000_lr1e-04_scT_muT_a02_b02_sd221958"

In [None]:
grid_ratio = 100 * OBS_GRID_RATIO[obs_grid_interval]  # in percent
assert grid_ratio == 1.0000000149011612

# The model trained without mixup differs only in the "muT" -> "muF" tag.
config_name_no_mixup = config_name.replace("muT", "muF")
assert "muF" in config_name_no_mixup
assert "muT" not in config_name_no_mixup

In [None]:
# obs_grid_interval = 4
# config_name = f"lt4og{obs_grid_interval:02}_on1e-01_ep1000_lr1e-04_scT_muT_a02_b02"
# config_name_no_mixup = config_name.replace("muT", "muF")

In [None]:
# Cached arrays of the model trained with mixup.
all_sr_analysis, all_lr_omega, all_hr_omega, all_hr_obsrv = read_all_tmp_files(
    config_name
)

# Only the SR analysis is needed from the model trained without mixup.
all_sr_analysis_no_mixup, _, _, _ = read_all_tmp_files(config_name_no_mixup)

In [None]:
# Compare the SRDA analyses with and without mixup against the ground truth.
# EnKF results are not part of this figure, so nothing is read from the EnKF
# directory here.
for i_ensemble in [185]:
    for i_cycle in [40]:
        # Time of the snapshot; only used in the output file name.
        t = (i_cycle + START_TIME_INDEX) * DT

        dict_data = OrderedDict({})

        dict_data["HR"] = all_hr_omega[i_ensemble, i_cycle]
        dict_data["SRDA (no mixup)"] = all_sr_analysis_no_mixup[i_ensemble, i_cycle]
        dict_data["SRDA (mixup)"] = all_sr_analysis[i_ensemble, i_cycle]

        plot(
            dict_data,
            t=None,
            obs=all_hr_obsrv[i_ensemble, i_cycle],
            obs_grid_interval=obs_grid_interval,
            figsize=[15, 4],
            ttl_header="(b) Case with feedback cycles\n(domain shift occurs)",
            fig_file_name=f"snapshots_srda_no_mixup_t{int(t):02}",
            write_out=True,
            use_hr_space=True,
            font_size=22,
        )

### Error dependency on obs grid ratio

In [None]:
df = df_agg_results[df_agg_results["UseLrForecast"] == True].sort_values("ObsGridRatio")

plt.rcParams.update({"font.family": "serif", "font.size": 22})

fig, axes = plt.subplots(1, 2, figsize=[14, 4], sharex=True)

# Rows of the two compared model families (5 obs ratios x 5 seeds each),
# with the color, errorbar format and legend label used to draw them.
uses_obs = df["UseObs"] == True
series_specs = [
    (
        df[(df["UseMixup"] == True) & uses_obs & (df["beta"] == 2)],
        DICT_COLORS["SRDA"],
        "o-",
        "SRDA (mixup)",
    ),
    (
        df[(df["UseMixup"] == False) & uses_obs],
        "#9467bd",
        "o--",
        "SRDA (no mixup)",
    ),
]
panel_titles = {
    "MAER_n100": "(a) Case without feedback cycles\n(no domain shift occurs)",
    "AveErrRatioSR": "(b) Case with feedback cycles\n(domain shift occurs)",
}

for ax, ycol in zip(axes, ["MAER_n100", "AveErrRatioSR"]):

    for data, c, fmt, label in series_specs:
        assert len(data) == 25
        data = data.groupby("ObsGridRatio")[ycol].agg(
            func=["mean", "count", "min", "max"]
        )

        # Marker: mean over seeds; error bar: min-to-max range over seeds.
        ys = data["mean"]
        errs = np.array([ys - data["min"], data["max"] - ys])
        assert len(ys) == 5 and set(data["count"]) == {5}

        ax.errorbar(
            data.index,
            ys,
            yerr=errs,
            color=c,
            fmt=fmt,
            label=label,
            capsize=5,
            lw=2.0,
        )

    ax.set_ylim(0.06, 0.45)
    # NOTE(review): these ticks reach 0.68, beyond the upper limit of 0.45 set
    # just above. matplotlib expands the view limits to include all ticks, so
    # the effective upper limit becomes 0.68 — confirm this is intended.
    ax.set_yticks(np.linspace(0.08, 0.68, 5))
    ax.set_ylabel("MAE ratio")

    ax.set_xlabel("Observation point ratio [%]")
    ax.set_xlim([0.5, 6.5])
    ax.set_xticks(np.linspace(0.5, 6.5, 4))

    ax.set_title(panel_titles[ycol])
    ax.legend(loc="upper right", fontsize=20)

plt.tight_layout()

fig.savefig(f"{FIG_DIR}/err_obs_grid_dependency_mixup.jpg", dpi=300)

if WRITE_EPS:
    fig.savefig(f"{FIG_DIR}/err_obs_grid_dependency_mixup.eps", dpi=300)

plt.show()

### Vorticity snapshots (no SRDA cycle)

In [None]:
# Make sure the notebook state still refers to the intended config before
# running inference (cells above rebind config_name / obs_grid_interval).
assert config_name == "lt4og10_on1e-01_ep1000_lr1e-04_scT_muT_a02_b02_sd221958"
grid_ratio = 100 * OBS_GRID_RATIO[obs_grid_interval]  # in percent
assert grid_ratio == 1.0000000149011612

# The model trained without mixup differs only in the "muT" -> "muF" tag.
config_name_no_mixup = config_name.replace("muT", "muF")
assert "muF" in config_name_no_mixup
assert "muT" not in config_name_no_mixup

In [None]:
# Single-shot inference on a test sample with both models (with / without mixup).
for i_ensemble in [185]:
    for i_cycle in [40]:
        # Time of the snapshot; only used in the output file name.
        t = (i_cycle + START_TIME_INDEX) * DT

        # Re-seed so that the inference below is reproducible.
        set_seeds(42, use_deterministic=True)
        gt, pred = calc_pred(config_name, i_ensemble, i_cycle)
        _, pred_no_mixup = calc_pred(config_name_no_mixup, i_ensemble, i_cycle)

        dict_data = OrderedDict(
            [
                ("HR", gt),
                ("SRDA (no mixup)", pred_no_mixup),
                ("SRDA (mixup)", pred),
            ]
        )

        plot(
            dict_data,
            obs=all_hr_obsrv[i_ensemble, i_cycle],
            obs_grid_interval=obs_grid_interval,
            t=None,
            figsize=[15, 4],
            write_out=True,
            ttl_header="(a) Case without feedback cycles\n(no domain shift occurs)",
            fig_file_name=f"snapshots_srda_no_mixup_no_srda_cycle_t{int(t):02}",
            use_hr_space=True,
            font_size=22,
        )

## Necessity of LR forecast

### Error dependency on obs grid ratio

In [None]:
# Keep the configs trained with observations and beta == 2, ordered by the
# observation point ratio.
_uses_obs = df_agg_results["UseObs"] == True
_beta_is_two = df_agg_results["beta"] == 2
df = df_agg_results[_uses_obs & _beta_is_two].sort_values("ObsGridRatio")

In [None]:
def _plot_seed_mean_with_range(ax, data, ycol, **errorbar_kwargs):
    """Plot the seed-averaged ``ycol`` against the observation ratio.

    ``data`` is grouped by ``"ObsGridRatio"``; every group must hold exactly
    five rows (one per training seed). The marker is the mean over the seeds
    and the error bar spans the minimum to the maximum over the seeds.
    Additional keyword arguments are forwarded to ``ax.errorbar``.
    """
    xs, all_ys = [], []
    for ratio, grp in data.groupby("ObsGridRatio"):
        xs.append(ratio)
        assert len(grp[ycol]) == 5  # number of seeds
        all_ys.append(grp[ycol])
    all_ys = np.stack(all_ys, axis=-1)  # batch, ratio

    ys = np.mean(all_ys, axis=0)
    diffs = all_ys - ys
    errs = np.stack([-np.min(diffs, axis=0), np.max(diffs, axis=0)])

    assert errs.shape == (2, 5)
    assert np.max(errs) >= 0

    ax.errorbar(xs, ys, yerr=errs, **errorbar_kwargs)


plt.rcParams["font.family"] = "serif"
plt.rcParams["font.size"] = 22

fig = plt.figure(figsize=[7, 3.8])
ax = plt.subplot(111)

ycol = "AveErrRatioSR"

_plot_seed_mean_with_range(
    ax,
    df[(df["UseMixup"] == True) & (df["UseLrForecast"] == True)],
    ycol,
    fmt="o-",
    label="SRDA",
    color=DICT_COLORS["SRDA"],
    lw=2,
    capsize=2,
)

_plot_seed_mean_with_range(
    ax,
    df[(df["UseMixup"] == False) & (df["UseLrForecast"] == False)],
    ycol,
    fmt="o-",
    label="SRDA (no LR input)",
    color="#8c564b",
    lw=2,
    capsize=2,
)

ax.set_ylim(0.07, 0.21)
ax.set_yticks(np.linspace(0.08, 0.20, 5))
ax.set_ylabel("MAE ratio")

# Same wording as the x-axis label of the other observation-ratio figures.
ax.set_xlabel("Observation point ratio [%]")
ax.set_xlim([0.5, 6.5])
ax.set_xticks(np.linspace(0.5, 6.5, 4))

ax.legend(loc="upper right", fontsize=20)

plt.tight_layout()

fig.savefig(f"{FIG_DIR}/err_obs_grid_dependency_lr_forecast.jpg", dpi=300)

if WRITE_EPS:
    fig.savefig(f"{FIG_DIR}/err_obs_grid_dependency_lr_forecast.eps", dpi=300)

plt.show()