In [None]:
# Re-import edited modules automatically before each cell runs, and render
# matplotlib figures inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import sys
from logging import INFO, WARNING, StreamHandler, getLogger

# Root logger: INFO level, messages printed to stdout. The hasHandlers() guard
# keeps a re-run of this cell from attaching a second stdout handler.
logger = getLogger()
logger.setLevel(INFO)
if not logger.hasHandlers():
    logger.addHandler(StreamHandler(stream=sys.stdout))

# Import libraries

In [None]:
import glob
import os
import pathlib
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import yaml
from cfd_model.interpolator.torch_interpolator import interpolate
from src.dataloader import (
    make_dataloaders_vorticity_making_observation_inside_time_series_splitted,
)
from src.model_maker import make_model
from src.sr_da_helper_2 import get_testdataset
from src.utils import read_pickle, set_seeds
from tqdm.notebook import tqdm

# Notebook-wide display defaults: serif fonts in figures and up to 500
# rows/columns when pandas renders a DataFrame.
plt.rcParams.update({"font.family": "serif"})
for _pd_option in ("display.max_columns", "display.max_rows"):
    pd.set_option(_pd_option, 500)
del _pd_option

# Define constants

In [None]:
# Export defaults picked up by `plot`: no EPS output, raster figures at 100 DPI.
WRITE_EPS, DPI = False, 100

In [None]:
# The experiments were run on GPU index 2. When that device does not exist
# (no CUDA, or fewer than three GPUs), fall back to the CPU so the notebook
# still runs; torch.load(map_location=DEVICE) and .to(DEVICE) then target CPU.
DEVICE = (
    "cuda:2"
    if torch.cuda.is_available() and torch.cuda.device_count() > 2
    else "cpu"
)

In [None]:
# Project root = parent directory of the first PYTHONPATH entry. PYTHONPATH may
# list several directories separated by os.pathsep; using the raw string as a
# single path would produce a nonexistent root in that case.
# Raises KeyError if PYTHONPATH is not set.
_first_pythonpath_entry = os.environ["PYTHONPATH"].split(os.pathsep)[0]
ROOT_DIR = str((pathlib.Path(_first_pythonpath_entry) / "..").resolve())

# Output directory for saved figures (created if missing).
FIG_DIR = "./fig"
os.makedirs(FIG_DIR, exist_ok=True)

In [None]:
# Observation grid interval -> observed-point ratio (a fraction, not a percent;
# `plot` multiplies by 100 for the "obs. points" title). The values are
# approximately 1 / interval**2, e.g. 4 -> 1/16, 5 -> 1/25; most carry float32
# rounding. Presumably the fraction of HR grid points observed when observations
# sit every `interval` points in x and y. TODO confirm.
# Interval 0 maps to 0.0 (no observations).
OBS_GRID_RATIO = {
    0: 0.0,
    4: 0.06250000093132257,
    5: 0.03999999910593033,
    6: 0.027777777363856632,
    7: 0.02040816326530612,
    8: 0.015625000116415322,
    9: 0.012345679127323775,
    10: 0.010000000149011612,
    11: 0.008264463206306716,
    12: 0.006944444625534945,
    13: 0.005917159876284691,
    14: 0.005102040977882487,
    15: 0.004444444572759999,
    16: 0.003906250014551915,
}

In [None]:
# Time index picked out of each model output sequence (see `calc_pred`).
ASSIMILATION_PERIOD = 4

# Low-resolution (LR) grid size and time stepping.
LR_NX, LR_NY = 32, 17
LR_DT, LR_NT = 5e-4, 500

# High-resolution (HR) grid size.
HR_NX, HR_NY = 128, 65

# Ultra-high-resolution (UHR) grid size.
UHR_NX, UHR_NY = 1024, 513

# Time spanned by LR_NT steps of size LR_DT.
DT = LR_DT * LR_NT

In [None]:
# https://matplotlib.org/stable/users/prev_whats_new/dflt_style_changes.html
# One row per method label: (line color, line style, legend text).
# All three lookup dicts are derived from this single table so they always
# share the same keys. "EnKF(Bicubic)"/"EnKF (Bicubic)" used to exist only in
# the legend dict, so the color/style lookups raised KeyError for them.
_PLOT_STYLES = {
    "LR": ("#1f77b4", ":", "LR (no SR/DA)"),
    "EnKF": ("#2ca02c", "--", "EnKF-SR"),
    "EnKF(bicubic)": ("#2ca02c", "--", "EnKF-SR"),
    "EnKF (bicubic)": ("#2ca02c", "--", "EnKF-SR"),
    "EnKF(Bicubic)": ("#2ca02c", "--", "EnKF-SR"),
    "EnKF (Bicubic)": ("#2ca02c", "--", "EnKF-SR"),
    "SRDA": ("#ff7f0e", "-", "ST-SRDA"),
    "ST-SRDA": ("#ff7f0e", "-", "ST-SRDA"),
    "EnKF(HR)": ("#1f77b4", ":", "EnKF-HR"),
    "EnKF (HR)": ("#1f77b4", ":", "EnKF-HR"),
    "SRDA (mixup)": ("#ff7f0e", "-", "ST-SRDA (mixup)"),
    "SRDA (no mixup)": ("#d62728", "-.", "ST-SRDA (no mixup)"),
}
DICT_COLORS = {label: style[0] for label, style in _PLOT_STYLES.items()}
DICT_LINE_STYLES = {label: style[1] for label, style in _PLOT_STYLES.items()}
DICT_LEGEND = {label: style[2] for label, style in _PLOT_STYLES.items()}

# Define helper functions

In [None]:
def get_uhr_and_hr_omegas(uhr_result_dir: str, num_times: int = 96):
    """Load UHR vorticity snapshots and derive their HR counterparts by pooling.

    The ``*.npy`` files in ``uhr_result_dir`` are sorted by name, and each file is
    stacked as one time step. Only the first ``num_times`` files are read. After
    squeezing, each snapshot must have shape ``(UHR_NX, UHR_NY)``.

    The HR fields are 8x8 average pools of the UHR fields with the first y
    index dropped (1024 x 512 -> 128 x 64). Index 0 along y of the HR output
    is left at zero.

    Args:
        uhr_result_dir: Directory containing the UHR ``.npy`` snapshots.
        num_times: Number of leading snapshots to load (must be >= 1).

    Returns:
        ``(all_uhr_omegas, all_hr_omegas)`` tensors of shapes
        ``(num_times, UHR_NX, UHR_NY)`` and ``(num_times, HR_NX, HR_NY)``.

    Raises:
        ValueError: If ``num_times`` < 1.
        FileNotFoundError: If fewer than ``num_times`` snapshot files exist.
        AssertionError: If a snapshot has an unexpected shape.
    """
    if num_times < 1:
        raise ValueError(f"num_times must be >= 1, got {num_times}")

    # glob.escape protects directory names containing glob metacharacters.
    # Slicing before loading avoids reading files that would be discarded.
    paths = sorted(glob.glob(f"{glob.escape(uhr_result_dir)}/*.npy"))[:num_times]
    if len(paths) < num_times:
        raise FileNotFoundError(
            f"Expected at least {num_times} .npy files in {uhr_result_dir}, "
            f"found {len(paths)}"
        )

    uhr_snapshots = []
    for path in paths:
        uhr = torch.from_numpy(np.load(path)).squeeze()
        assert uhr.shape == (UHR_NX, UHR_NY), f"{path}: shape {tuple(uhr.shape)}"
        uhr_snapshots.append(uhr)
    # Stack along time dim
    all_uhr_omegas = torch.stack(uhr_snapshots)
    assert all_uhr_omegas.shape == (num_times, UHR_NX, UHR_NY)

    # Add a channel dim for avg_pool2d and drop the first y index so the y
    # extent (UHR_NY - 1) is divisible by the pooling size.
    pooled = F.avg_pool2d(all_uhr_omegas[:, None, :, 1:], kernel_size=8)
    # Remove only the channel dim, so the time axis survives when num_times == 1.
    pooled = pooled.squeeze(1)

    all_hr_omegas = torch.zeros((num_times, HR_NX, HR_NY), dtype=pooled.dtype)
    all_hr_omegas[:, :, 1:] = pooled

    return all_uhr_omegas, all_hr_omegas


def _make_mesh(nx: int, ny: int):
    """Return ``(x, y)`` coordinate arrays of shape ``(nx, ny)`` (``indexing="ij"``).

    x spans ``[0, 2*pi)`` (endpoint excluded) and y spans ``[0, pi]`` (endpoint included).
    """
    xs = np.linspace(0, 2 * np.pi, num=nx, endpoint=False)
    ys = np.linspace(0, np.pi, num=ny, endpoint=True)
    return np.meshgrid(xs, ys, indexing="ij")


def plot(
    dict_data: dict,
    t: float,
    obs: np.ndarray,
    gt_label: str,
    figsize: tuple = (20, 2),
    write_out: bool = False,
    ttl_header: str = "",
    fig_file_name: str = "",
    vmin_omega: float = -10,
    vmax_omega: float = 10,
    font_size: int = 22,
    obs_grid_interval: int = 8,
    dot_size: float = 2,
    dpi: int = DPI,
    draw_pdf: bool = False,
    write_eps: bool = WRITE_EPS,
):
    """Draw vorticity fields side by side, annotating each with its MAE ratio.

    Each field is drawn on the UHR or HR mesh when its (squeezed) shape matches
    that grid. Any other shape is drawn on the LR mesh. Every field other than
    ``gt_label`` is bicubically interpolated onto the reference grid with
    ``interpolate``. Its title then shows
    ``mean|gt - field| / mean|gt|``.

    Args:
        dict_data: Panel label -> 2-D field (squeezable array); panels follow dict order.
        t: Time shown in the suptitle; ``None`` omits it.
        obs: HR-shaped observation array whose non-NaN entries are drawn as dots
            on the ``gt_label`` panel; ``None`` draws no observations.
        gt_label: Key of ``dict_data`` used as the reference field.
        figsize: Matplotlib figure size.
        write_out: If True, save ``{FIG_DIR}/{fig_file_name}.jpg`` (plus eps/pdf if enabled).
        ttl_header: Prefix of the suptitle; ``None`` is treated as ``""``.
        fig_file_name: Base file name for saved figures.
        vmin_omega, vmax_omega: Color scale limits.
        font_size: Value assigned to ``plt.rcParams["font.size"]``.
        obs_grid_interval: Key of ``OBS_GRID_RATIO`` shown in the reference
            title; ``None`` omits it.
        dot_size: Marker size of observation dots.
        dpi: Resolution used when saving.
        draw_pdf: Also save a PDF when ``write_out`` is True.
        write_eps: Also save an EPS when ``write_out`` is True.

    Raises:
        KeyError: If ``gt_label`` is not a key of ``dict_data``.
    """
    if gt_label not in dict_data:
        raise KeyError(f"gt_label {gt_label!r} is not a key of dict_data")

    # Coordinate meshes keyed by field shape; other shapes fall back to LR.
    meshes = {
        (UHR_NX, UHR_NY): _make_mesh(UHR_NX, UHR_NY),
        (HR_NX, HR_NY): _make_mesh(HR_NX, HR_NY),
    }
    lr_mesh = _make_mesh(LR_NX, LR_NY)
    hr_x, hr_y = meshes[(HR_NX, HR_NY)]

    # Resolve the reference up front so MAE ratios do not depend on gt_label
    # being the first entry of dict_data.
    gt = np.squeeze(dict_data[gt_label])
    header = "" if ttl_header is None else ttl_header

    plt.rcParams["font.size"] = font_size
    # squeeze=False keeps `axes` 2-D, so a single panel is iterable too.
    fig, axes = plt.subplots(
        1, len(dict_data), figsize=figsize, sharex=True, sharey=False, squeeze=False
    )

    for ax, (label, data) in zip(axes[0], dict_data.items()):
        d = np.squeeze(data)
        x, y = meshes.get(d.shape, lr_mesh)

        if label == gt_label:
            ttl = label
        else:
            d_on_gt = interpolate(
                torch.from_numpy(d[None, ...]),
                nx=gt.shape[0],
                ny=gt.shape[1],
                mode="bicubic",
            )
            d_on_gt = d_on_gt.squeeze().numpy()
            assert d_on_gt.shape == gt.shape
            # MAE normalized by the mean absolute value of the reference.
            maer = np.mean(np.abs(gt - d_on_gt)) / np.mean(np.abs(gt))
            ttl = f"{label}\n(MAE ratio = {maer:.2f})"

        cnts = ax.pcolormesh(
            x, y, d, cmap="twilight_shifted", vmin=vmin_omega, vmax=vmax_omega
        )
        fig.colorbar(
            cnts,
            ax=ax,
            ticks=[vmin_omega, vmin_omega / 2, 0, vmax_omega / 2, vmax_omega],
            extend="both",
        )

        ax.set_xlim([0, 2 * np.pi])
        ax.set_ylim([0, np.pi])

        if label == gt_label and obs is not None:
            assert obs.shape == (HR_NX, HR_NY)
            observed = ~np.isnan(np.squeeze(obs).flatten())
            print(f"Observed grid points: {np.sum(observed) / observed.size * 100} %")
            ax.scatter(
                hr_x.flatten()[observed],
                hr_y.flatten()[observed],
                marker=".",
                s=dot_size,
                c="k",
            )
            if obs_grid_interval is not None:
                prob = OBS_GRID_RATIO[obs_grid_interval] * 100
                ttl = f"{ttl}\n(obs. points: {prob:.2f} %)"

        ax.set_title(ttl)
        # Hide ticks and tick labels on both axes (the frame stays visible).
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

    fig.suptitle(header if t is None else f"{header}Time = {np.round(t, 2)}")
    fig.tight_layout()

    if write_out:
        for ext, enabled in (("jpg", True), ("eps", write_eps), ("pdf", draw_pdf)):
            if enabled:
                fig.savefig(f"{FIG_DIR}/{fig_file_name}.{ext}", dpi=dpi)

    plt.show()
    # Release the figure; this function may be called in a loop.
    plt.close(fig)


def calc_pred(
    config: dict,
    weight_path: str,
    i_ensemble: int,
    i_cycle: int,
    min_start_time_index: int = 12,
    max_start_time_index: int = 88,
    start_time_index: int = 12,
):
    """Run the trained SR model on one test sample and return (ground truth, prediction).

    Both fields are de-normalized with the dataset's vorticity bias and scale. Each is
    taken at time index ``ASSIMILATION_PERIOD`` and transposed. Each is then
    zero-padded along the last axis from ``HR_NY - 1`` to ``HR_NY`` points, so the
    last y index is zero.

    Args:
        config: Experiment config forwarded to ``get_testdataset`` and ``make_model``.
        weight_path: Path of the model state dict read with ``torch.load``.
        i_ensemble: Ensemble index forwarded to ``get_specified_item``.
        i_cycle: Cycle index forwarded to ``get_specified_item``.
        min_start_time_index: Forwarded to ``get_testdataset``.
        max_start_time_index: Forwarded to ``get_testdataset``.
        start_time_index: Forwarded to ``get_specified_item``.

    Returns:
        ``(ret_gt, ret_pred)``: two float64 arrays of shape ``(HR_NX, HR_NY)``.
    """
    dataset = get_testdataset(
        ROOT_DIR,
        config,
        min_start_time_index=min_start_time_index,
        max_start_time_index=max_start_time_index,
    )

    model = make_model(config).to(DEVICE)
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
    state_dict = torch.load(weight_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model.eval()

    lr, obs, gt = dataset.get_specified_item(
        i_ensemble, i_cycle, start_time_index=start_time_index
    )
    bias, scale = dataset.vorticity_bias, dataset.vorticity_scale

    with torch.no_grad():
        # The leading None adds the batch dimension.
        raw_pred = model(lr[None, ...].to(DEVICE), obs[None, ...].to(DEVICE))
        pred = (raw_pred * scale + bias).detach().cpu()
    gt = gt * scale + bias

    # Keep time index ASSIMILATION_PERIOD (presumably the last step), then
    # exchange the x and y axes.
    gt_snapshot = gt.squeeze().numpy()[ASSIMILATION_PERIOD].transpose()
    pred_snapshot = pred.squeeze().numpy()[ASSIMILATION_PERIOD].transpose()
    assert gt_snapshot.shape == pred_snapshot.shape == (HR_NX, HR_NY - 1)

    def _pad_last_y(field):
        # The last y index is absent from the model grid; fill it with zeros.
        padded = np.zeros((HR_NX, HR_NY))
        padded[:, :-1] = field
        return padded

    return _pad_last_y(gt_snapshot), _pad_last_y(pred_snapshot)

# Vorticity inference: UHR

In [None]:
# Settings for the UHR snapshot figures in this section.
GRID_INTERVAL = 10  # observation grid interval; appears as `og10` in file names
I_SEED_UHR = 9996  # seed of the UHR run whose snapshots are the ground truth
# Presumably alternative values for OBS_SRDA_SEED: 221958, 771155, 832180, 465838, 359178
OBS_SRDA_SEED = 771155
DATA_DIR = f"{ROOT_DIR}/pytorch/notebook/paper_experiment_05/data"
# Experiment config name; encodes GRID_INTERVAL (ogXX) and OBS_SRDA_SEED (sdXXXXXX).
CONFIG_NAME = f"lt4og{GRID_INTERVAL:02}_on1e-01_ep1000_lr1e-04_scT_bT_muT_a02_b02_sd{OBS_SRDA_SEED}"
UHR_RESULT_DIR = f"{ROOT_DIR}/data/pytorch/CFD/jet27/seed{I_SEED_UHR:05}"

In [None]:
# Load the UHR ground-truth vorticity; the pooled HR fields are not needed here.
uhr_omegas, _ = get_uhr_and_hr_omegas(UHR_RESULT_DIR)
uhr_omegas.shape

In [None]:
output_obs_npz_file_path = (
    f"{DATA_DIR}/UHR_seed_{I_SEED_UHR:05}_og{GRID_INTERVAL:02}_{CONFIG_NAME}.npz"
)
# np.load on an .npz returns an NpzFile that keeps the file handle open.
# Reading the arrays inside a `with` block materializes them and closes it.
with np.load(output_obs_npz_file_path) as all_data:
    hr_obs = torch.from_numpy(all_data["hr_obs"])
    sr_forecast = torch.from_numpy(all_data["sr_forecast"])
hr_obs.shape, sr_forecast.shape

In [None]:
# Precomputed EnKF-SR results ("ens_bicubic" files): an LR array and a dict of
# HR analysis fields indexed by time step (used below as
# dict_enkf_bicubic_hr_analysis[i]).
# NOTE(review): read_pickle presumably unpickles the file; only load trusted data.
output_lr_file_path = f"{DATA_DIR}/UHR_seed_{I_SEED_UHR:05}_og{GRID_INTERVAL:02}_SRDA_seed_{OBS_SRDA_SEED}_ens_bicubic_lr.npy"
output_hr_file_path = f"{DATA_DIR}/UHR_seed_{I_SEED_UHR:05}_og{GRID_INTERVAL:02}_SRDA_seed_{OBS_SRDA_SEED}_ens_bicubic_mean_hr.pickle"
enkf_bicubic_lrs = np.load(output_lr_file_path)
dict_enkf_bicubic_hr_analysis = read_pickle(output_hr_file_path)

enkf_bicubic_lrs.shape

In [None]:
# Precomputed EnKF-HR ensemble-mean fields ("enkf_hr_mean"), indexed by time step below.
output_hr_enkf_file_path = f"{DATA_DIR}/UHR_seed_{I_SEED_UHR:05}_og{GRID_INTERVAL:02}_SRDA_seed_{OBS_SRDA_SEED}_enkf_hr_mean.npy"
enkf_hrs = np.load(output_hr_enkf_file_path)
enkf_hrs.shape

In [None]:
# Snapshot panels (a)-(c). Each figure shows the ground truth with observation
# points, then the three reconstructions annotated with their MAE ratios.
for i, ttl_header in {24: "(a) ", 60: "(b) ", 80: "(c) "}.items():
    dict_data = {
        "Ground truth (UHR)": uhr_omegas[i].numpy(),
        "EnKF-SR": dict_enkf_bicubic_hr_analysis[i],
        "EnKF-HR": enkf_hrs[i],
        "ST-SRDA": sr_forecast[i].numpy(),
    }

    # Time of snapshot i; DT is LR_NT steps of size LR_DT per index.
    t = i * DT

    plot(
        dict_data=dict_data,
        t=t,
        obs=hr_obs[i].numpy(),
        figsize=[20, 4],
        gt_label="Ground truth (UHR)",
        dot_size=3,
        obs_grid_interval=None,
        ttl_header=ttl_header,
        # round() avoids int() truncating e.g. 5.9999999 down to 5 in the file name.
        fig_file_name=f"snapshots_uhr_enkf_srda_t{int(round(t)):02}",
        write_out=True,
        font_size=22,
        dpi=25.0,
        draw_pdf=False,
        write_eps=False,
    )

# Error time series: UHR

In [None]:
# Seed-averaged error time series for observation grid interval GRID_INTERVAL.
GRID_INTERVAL = 8
I_SEEDS_UHR = np.arange(9950, 10000)
OBS_SRDA_SEED = 771155
# Presumably alternative values for OBS_SRDA_SEED: 221958, 771155, 832180, 465838, 359178

cols = None
seed_arrays = []

for seed in I_SEEDS_UHR:
    csv_path = f"{ROOT_DIR}/pytorch/notebook/paper_experiment_05/csv/error_time_series_UHR_seed_{seed:05}_og{GRID_INTERVAL:02}_SRDA_seed_{OBS_SRDA_SEED}.csv"
    df_seed = pd.read_csv(csv_path)

    if cols is None:
        cols = list(df_seed.columns)
    else:
        assert cols == list(df_seed.columns), f"Unexpected columns in {csv_path}"
    # Cast to float so integer-typed columns, if any, average correctly. The
    # values are collected rather than accumulated in place, so no DataFrame
    # buffer is modified.
    seed_arrays.append(df_seed.to_numpy(dtype=float))

if not seed_arrays:
    raise ValueError("I_SEEDS_UHR is empty; nothing to average.")

# np.stack also fails loudly if the per-seed tables differ in shape.
df_errs = pd.DataFrame(np.stack(seed_arrays).mean(axis=0), columns=cols)

In [None]:
# MAE ratio and MSSIM loss time series (seed-averaged df_errs), one figure per
# Gaussian SSIM window size.
resolution = "UHR"
window_kind = "Gauss"  # column suffix: {resolution}_SSIMLoss_{label}_{window_kind}_{win:02}
window_name = "Gaussian"  # shown in the panel title
# The first row is dropped (presumably the initial time) — TODO confirm why.
xs = df_errs["Time"].values[1:]

for win in [11, 21, 31, 51, 71]:
    plt.rcParams["font.size"] = 22
    fig, axes = plt.subplots(1, 2, sharex=True, figsize=[15, 4.5])

    for ax, ycol in zip(axes, ["MAER", "SSIMLoss"]):
        for label in ["EnKF", "EnKF(HR)", "SRDA"]:
            col = f"{resolution}_{ycol}_{label}"
            if "SSIM" in ycol:
                col = f"{col}_{window_kind}_{win:02}"
            ys = df_errs[col].values[1:]
            ax.plot(
                xs,
                ys,
                ls=DICT_LINE_STYLES[label],
                c=DICT_COLORS[label],
                label=DICT_LEGEND[label],
                lw=2,
            )

        ax.set_xlabel("Time")
        ax.set_xticks(np.linspace(0, 24, 7))
        # Presumably the snapshot times 24*DT, 60*DT, 80*DT with DT = 0.25.
        for t_mark in (6, 15, 20):
            ax.axvline(t_mark, ls="-", c="k", lw=0.5)

        if ycol == "MAER":
            ax.set_ylim(0, 0.6)
            ax.set_yticks(np.linspace(0, 0.6, 6))
            ax.set_title(f"(a) MAE ratio in {resolution} space")
            ax.set_ylabel("MAE ratio")
        else:
            ax.set_ylim(0, 0.3)
            ax.set_yticks(np.linspace(0, 0.3, 6))
            ax.set_title(
                f"(b) MSSIM loss in {resolution} space\n{window_name} window size = {win:02}"
            )
            ax.set_ylabel("MSSIM loss")

    axes[-1].legend(
        bbox_to_anchor=(1.05, 1.0),
        loc="upper left",
        ncol=1,
        fontsize=20,
        framealpha=1,
        edgecolor="k",
    )

    plt.tight_layout()
    plt.show()

In [None]:
# MAE ratio and MSSIM loss time series (seed-averaged df_errs), one figure per
# rectangular SSIM window size.
resolution = "UHR"
window_kind = "Rect"  # column suffix: {resolution}_SSIMLoss_{label}_{window_kind}_{win:02}
window_name = "Rect"  # shown in the panel title
# The first row is dropped (presumably the initial time) — TODO confirm why.
xs = df_errs["Time"].values[1:]

for win in [11, 21, 31, 51, 71]:
    plt.rcParams["font.size"] = 22
    fig, axes = plt.subplots(1, 2, sharex=True, figsize=[15, 4.5])

    for ax, ycol in zip(axes, ["MAER", "SSIMLoss"]):
        for label in ["EnKF", "EnKF(HR)", "SRDA"]:
            col = f"{resolution}_{ycol}_{label}"
            if "SSIM" in ycol:
                col = f"{col}_{window_kind}_{win:02}"
            ys = df_errs[col].values[1:]
            ax.plot(
                xs,
                ys,
                ls=DICT_LINE_STYLES[label],
                c=DICT_COLORS[label],
                label=DICT_LEGEND[label],
                lw=2,
            )

        ax.set_xlabel("Time")
        ax.set_xticks(np.linspace(0, 24, 7))
        # Presumably the snapshot times 24*DT, 60*DT, 80*DT with DT = 0.25.
        for t_mark in (6, 15, 20):
            ax.axvline(t_mark, ls="-", c="k", lw=0.5)

        if ycol == "MAER":
            ax.set_ylim(0, 0.6)
            ax.set_yticks(np.linspace(0, 0.6, 6))
            ax.set_title(f"(a) MAE ratio in {resolution} space")
            ax.set_ylabel("MAE ratio")
        else:
            ax.set_ylim(0, 0.3)
            ax.set_yticks(np.linspace(0, 0.3, 6))
            ax.set_title(
                f"(b) MSSIM loss in {resolution} space\n{window_name} window size = {win:02}"
            )
            ax.set_ylabel("MSSIM loss")

    axes[-1].legend(
        bbox_to_anchor=(1.05, 1.0),
        loc="upper left",
        ncol=1,
        fontsize=20,
        framealpha=1,
        edgecolor="k",
    )

    plt.tight_layout()
    plt.show()

# Average errors: UHR

In [None]:
# Settings for the seed-averaged error summary below.
resolution = "UHR"
OBS_SRDA_SEED = 771155

In [None]:
# The 50 UHR seeds averaged over: 9950, 9951, ..., 9999.
ALL_SEEDS = np.arange(9950, 9950 + 50)

In [None]:
# Seed-averaged error summary per observation grid interval. For every
# (metric, method) pair, each seed's error time series is averaged over time
# (ignoring NaNs). The mean and sample std (ddof=1) of those per-seed values are
# appended to dict_means / dict_stds, one entry per grid interval.
summary_ycols = ["MAER", "SSIMLoss"]
summary_labels = ["EnKF", "EnKF(HR)", "SRDA"]

dict_means = {ycol: {label: [] for label in summary_labels} for ycol in summary_ycols}
dict_stds = {ycol: {label: [] for label in summary_labels} for ycol in summary_ycols}

# A lowercase loop variable leaves the global GRID_INTERVAL untouched, and
# `seed_frames` avoids clobbering the seed-averaged `df_errs` used above.
for grid_interval in [4, 6, 8, 12]:
    # Read each seed's CSV once per grid interval and reuse it for all columns.
    seed_frames = [
        pd.read_csv(
            f"{ROOT_DIR}/pytorch/notebook/paper_experiment_05/csv/error_time_series_UHR_seed_{seed:05}_og{grid_interval:02}_SRDA_seed_{OBS_SRDA_SEED}.csv"
        )
        for seed in ALL_SEEDS
    ]

    for ycol in summary_ycols:
        for label in summary_labels:
            col = f"{resolution}_{ycol}_{label}"
            if "SSIM" in ycol:
                # SSIM loss computed with the Gaussian window of size 31.
                col = f"{col}_Gauss_31"
            all_ys = np.stack([frame[col].values for frame in seed_frames], axis=0)

            # One time-averaged value per seed.
            means = np.nanmean(all_ys, axis=1)
            assert means.shape == (len(ALL_SEEDS),)

            dict_means[ycol][label].append(np.nanmean(means))
            dict_stds[ycol][label].append(np.nanstd(means, ddof=1))

In [None]:
# Grid intervals summarized above, in the order their entries were appended
# to dict_means / dict_stds.
grids = [4, 6, 8, 12]

plt.rcParams["font.size"] = 22

fig, axes = plt.subplots(1, 2, sharex=True, figsize=[15, 4])

# x positions: observation point ratio in percent for each grid interval.
xs = [OBS_GRID_RATIO[g] * 100 for g in grids]

for ax, ycol in zip(axes, ["MAER", "SSIMLoss"]):
    for label in ["EnKF", "EnKF(HR)", "SRDA"]:
        means = dict_means[ycol][label]
        stds = dict_stds[ycol][label]
        # Catch stale or mismatched summaries (one entry per grid interval).
        assert len(means) == len(stds) == len(xs), (ycol, label, len(means), len(xs))

        ax.errorbar(
            xs,
            means,
            yerr=stds,
            marker="o",
            ls=DICT_LINE_STYLES[label],
            color=DICT_COLORS[label],
            lw=2,
            label=DICT_LEGEND[label],
            capsize=7,
        )

    ax.set_xlabel("Observation point ratio [%]")
    ax.set_xlim(0.5, 6.5)
    ax.set_xticks(np.linspace(0.5, 6.5, 7))
    # Mark grid interval 8, the setting of the error time series figures.
    ax.axvline(OBS_GRID_RATIO[8] * 100, ls="-", c="k", lw=0.5)

    if ycol == "MAER":
        ax.set_ylim(0.1, 0.4)
        ax.set_yticks(np.linspace(0.1, 0.4, 6))
        ax.set_title(f"(a) MAE ratio in {resolution} space")
        ax.set_ylabel("MAE ratio")
    else:
        ax.set_ylim(0.03, 0.1)
        ax.set_yticks(np.linspace(0.03, 0.1, 6))
        ax.set_title(f"(b) MSSIM loss in {resolution} space")
        ax.set_ylabel("MSSIM loss")

axes[-1].legend(
    bbox_to_anchor=(1.05, 1.0),
    loc="upper left",
    ncol=1,
    fontsize=20,
    framealpha=1,
    edgecolor="k",
)

plt.tight_layout()
plt.show()