In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import sys
from logging import DEBUG, INFO, WARNING, StreamHandler, getLogger

# Root logger for the whole notebook. The stdout handler is attached only when
# no handler is registered yet, so re-running this cell does not add another one.
logger = getLogger()
logger.setLevel(INFO)
if not logger.hasHandlers():
    logger.addHandler(StreamHandler(sys.stdout))

# Import libraries

In [None]:
import gc
import glob
import os
import pathlib
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import yaml
from cfd_model.filter.low_pass_periodic_channel_domain import LowPassFilter
from cfd_model.initialization.periodic_channel_jet_initializer import (
    calc_init_omega,
    calc_init_perturbation_hr_omegas,
    calc_jet_forcing,
)
from cfd_model.interpolator.torch_interpolator import (
    interpolate,
    interpolate_time_series,
)
from IPython.display import display
from src.sr_da_helper_2 import (
    get_observation_with_noise,
    get_testdataset,
    initialize_and_itegrate_srda_cfd_model_for_forecast,
    initialize_models,
    make_invprocessed_sr_for_forecast,
    make_models,
    make_preprocessed_lr_for_forecast,
    make_preprocessed_obs_for_forecast,
    read_all_hr_omegas_with_combining_for_forecast,
)
from src.ssim import SSIM
from src.utils import set_seeds
from tqdm.notebook import tqdm

# Notebook-wide display defaults: serif font for figures and up to 500
# columns/rows shown when a pandas object is displayed.
plt.rcParams["font.family"] = "serif"
pd.set_option("display.max_columns", 500)
pd.set_option("display.max_rows", 500)

In [None]:
# The environment variable is set before seeding. The same seed (42) is applied
# again for every config inside the SR-DA loop.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = r":4096:8"  # to make calculations deterministic
set_seeds(42, use_deterministic=True)

# Define constants

In [None]:
# Root directory = resolved parent of the path stored in PYTHONPATH.
# NOTE(review): raises KeyError when PYTHONPATH is unset and presumably expects a
# single path (not an os.pathsep-separated list) — confirm for your environment.
ROOT_DIR = str((pathlib.Path(os.environ["PYTHONPATH"]) / "..").resolve())
ROOT_DIR

In [None]:
# Directory holding the cached SR-DA results (one .npz file per config); created on demand.
TMP_DATA_DIR = "./data"
os.makedirs(TMP_DATA_DIR, exist_ok=True)

In [None]:
# Directory holding the per-config error time-series CSV files; created on demand.
CSV_DATA_DIR = "./csv"
os.makedirs(CSV_DATA_DIR, exist_ok=True)

In [None]:
# Directory where figures are written by fig.savefig; created on demand.
FIG_DIR = "./tmp/fig"
os.makedirs(FIG_DIR, exist_ok=True)

In [None]:
# All YAML experiment configs of this experiment, in sorted path order.
CONFIG_DIR = f"{ROOT_DIR}/pytorch/config/paper_experiment_06"
CONFIG_PATHS = sorted(glob.glob(f"{CONFIG_DIR}/*.yml"))

In [None]:
# Cycle layout, counted in LR output steps (one step = LR_DT * LR_NT of model time).
ASSIMILATION_PERIOD = 4  # observations are used every ASSIMILATION_PERIOD steps
FORECAST_SPAN = 4  # number of forecast steps kept from each SR-DA inference
NUM_SIMULATIONS = 500  # number of ensemble members (first axis of the HR arrays)

# Range of start time indices handed to get_testdataset.
MIN_START_TIME_INDEX = -1
MAX_START_TIME_INDEX = 88
START_TIME_INDEX = 0
# Number of time indices per simulation (88 + 4 + 4 = 96).
NUM_TIMES = MAX_START_TIME_INDEX + ASSIMILATION_PERIOD + FORECAST_SPAN

# Low-resolution (LR) grid and time stepping.
LR_NX = 32
LR_NY = 17
LR_DT = 5e-4
LR_NT = 500

# High-resolution (HR) grid.
HR_NX = 128
HR_NY = 65

# Parameters passed to the jet forcing / initial-condition helpers.
Y0 = np.pi / 2.0
SIGMA = 0.4
U0 = 3.0
TAU0 = 0.3
PERTUB_NOISE = 0.0025  # passed as noise_amp of the initial perturbation

# Coefficients forwarded to the LR CFD model through LR_CFD_CONFIG.
BETA = 0.1
COEFF_LINEAR_DRAG = 1e-2
ORDER_DIFFUSION = 2
HR_COEFF_DIFFUSION = 1e-5  # NOTE(review): not referenced elsewhere in this notebook — confirm it is still needed
LR_COEFF_DIFFUSION = 5e-5

# Model time covered by one output step, and model time of START_TIME_INDEX.
DT = LR_DT * LR_NT
T0 = START_TIME_INDEX * LR_DT * LR_NT

# Ensemble members processed together; NUM_SIMULATIONS is split into chunks of this size.
N_ENS_PER_CHUNK = 125

In [None]:
# GPU on which the LR CFD model, the low-pass filter and the SSIM evaluation run.
DEVICE = "cuda:1"

# The notebook has no CPU fallback, so fail early with an accurate message.
if not torch.cuda.is_available():
    raise RuntimeError("No GPU is available, but this notebook requires CUDA.")

# Also reject a device index that does not exist on this machine; otherwise the
# failure only surfaces later when the first tensor is moved to DEVICE.
if (torch.device(DEVICE).index or 0) >= torch.cuda.device_count():
    raise RuntimeError(
        f"DEVICE = {DEVICE!r} is not available: only "
        f"{torch.cuda.device_count()} CUDA device(s) are visible."
    )

In [None]:
# Settings of the low-resolution CFD model; passed to make_models and to the
# forecast integration helper (as cfd_config).
LR_CFD_CONFIG = {
    "nx": LR_NX,
    "ny": LR_NY,
    "coeff_linear_drag": COEFF_LINEAR_DRAG,
    "coeff_diffusion": LR_COEFF_DIFFUSION,
    "order_diffusion": ORDER_DIFFUSION,
    "beta": BETA,
    "device": DEVICE,
    "dt": LR_DT,
    "nt": LR_NT,
}

# Keyword arguments shared by the sr_da_helper_2 functions (expanded with **INDEX_CONFIG).
# "n_ens" is the per-chunk ensemble size, not NUM_SIMULATIONS.
INDEX_CONFIG = {
    "assimilation_period": ASSIMILATION_PERIOD,
    "forecast_span": FORECAST_SPAN,
    "n_ens": N_ENS_PER_CHUNK,
    "lr_nx": LR_NX,
    "lr_ny": LR_NY,
    "hr_nx": HR_NX,
    "hr_ny": HR_NY,
    "device": DEVICE,
}

In [None]:
# Registry of all experiment configs, keyed by the config file name (text before
# the first "."), filled in sorted path order. Each entry stores the parsed YAML
# and the paths of the training artifacts belonging to that config.
CONFIGS = OrderedDict()

for num, config_path in enumerate(sorted(CONFIG_PATHS)):
    with open(config_path) as yaml_file:
        config = yaml.safe_load(yaml_file)

    config_name = os.path.basename(config_path).split(".")[0]
    # Two files must never map to the same key.
    assert config_name not in CONFIGS

    # The experiment name is the directory that contains the config file.
    experiment_name = config_path.split("/")[-2]
    result_dir = f"{ROOT_DIR}/data/pytorch/DL_results/{experiment_name}/{config_name}"

    CONFIGS[config_name] = dict(
        config=config,
        model_name=config["model"]["model_name"],
        experiment_name=experiment_name,
        weight_path=f"{result_dir}/weights.pth",
        learning_history_path=f"{result_dir}/learning_history.csv",
        log=f"{result_dir}/log.txt",
        number=num,
    )

In [None]:
# Lookup table: observation grid interval -> fraction of grid points observed.
# The values match 1 / interval**2 up to rounding (e.g. 4 -> 1/16); an interval
# of 0 maps to 0.0. Multiply by 100 to obtain a percentage.
OBS_GRID_RATIO = {
    0: 0.0,
    4: 0.06250000093132257,
    5: 0.03999999910593033,
    6: 0.027777777363856632,
    7: 0.02040816326530612,
    8: 0.015625000116415322,
    9: 0.012345679127323775,
    10: 0.010000000149011612,
    11: 0.008264463206306716,
    12: 0.006944444625534945,
    13: 0.005917159876284691,
    14: 0.005102040977882487,
    15: 0.004444444572759999,
    16: 0.003906250014551915,
}

# Define methods

In [None]:
def get_initial_hr_omega():
    """Build the initial HR vorticity fields for all ensemble members.

    The jet returned by ``calc_jet_forcing`` is combined with a perturbation
    generated with a fixed seed (2718) through ``calc_init_omega``.

    Returns:
        The result of ``calc_init_omega``. The SR-DA cell asserts that its
        shape is ``(NUM_SIMULATIONS, HR_NX, HR_NY)``.
    """
    hr_grid = dict(nx=HR_NX, ny=HR_NY, ne=NUM_SIMULATIONS)

    # Only the jet is needed here; the forcing (second return value) is ignored.
    jet, _ = calc_jet_forcing(y0=Y0, sigma=SIGMA, tau0=TAU0, **hr_grid)
    perturbation = calc_init_perturbation_hr_omegas(
        noise_amp=PERTUB_NOISE, seed=2718, **hr_grid
    )

    return calc_init_omega(perturb_omega=perturbation, jet=jet, u0=U0)

In [None]:
def plot(
    dict_data: dict,
    t: float,
    obs: np.ndarray,
    figsize: list = [20, 2],
    write_out: bool = False,
    ttl_header: str = "",
    fig_file_name: str = "",
    use_hr_space: bool = True,
    vmin_omega: float = -9,
    vmax_omega: float = 9,
):
    """Draw one vorticity panel per entry of ``dict_data`` in a single row.

    Args:
        dict_data: Mapping ``label -> 2D field``. The entry labelled ``"HR"`` is
            the ground truth and must come before every other entry, because
            the other panels report their MAE ratio (MAER) against it. The
            entry labelled ``"LR"`` is treated as an LR-grid field; all other
            entries are treated as HR-grid fields.
        t: Time shown in the figure title.
        obs: Observation field; grid points that are not NaN are marked on the
            ``"HR"`` panel when ``use_hr_space`` is True.
        figsize: Figure size forwarded to matplotlib (never modified here).
        write_out: If True, save the figure to ``FIG_DIR/<fig_file_name>.jpg``.
        ttl_header: Text put in front of ``"Time = ..."`` in the title.
        fig_file_name: File name (without extension) used when ``write_out``.
        use_hr_space: If True, plot on the HR grid (the LR field is upsampled
            with ``mode="nearest"``); otherwise plot on the LR grid (every
            field except ``"LR"`` is passed through ``interpolate``).
        vmin_omega: Lower limit of the color scale.
        vmax_omega: Upper limit of the color scale.

    Raises:
        ValueError: If a non-"HR" entry is encountered before the "HR" entry.
    """
    # All panels share one grid, chosen by `use_hr_space`.
    if use_hr_space:
        nx, ny = HR_NX, HR_NY
    else:
        nx, ny = LR_NX, LR_NY
    xs = np.linspace(0, 2 * np.pi, num=nx, endpoint=False)
    ys = np.linspace(0, np.pi, num=ny, endpoint=True)
    x, y = np.meshgrid(xs, ys, indexing="ij")

    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.size"] = 18
    # squeeze=False keeps `axes` a 2D array even for a single panel, so the
    # loop below also works when dict_data has exactly one entry.
    fig, axes = plt.subplots(
        1,
        len(dict_data),
        figsize=figsize,
        sharex=True,
        sharey=False,
        squeeze=False,
    )

    gt = None
    for ax, (label, data) in zip(axes.ravel(), dict_data.items()):
        # Bring the field onto the grid used for plotting.
        if label == "LR":
            if use_hr_space:
                data = interpolate(
                    torch.from_numpy(data[None, :]), nx=HR_NX, ny=HR_NY, mode="nearest"
                ).numpy()
        elif not use_hr_space:
            data = interpolate(
                torch.from_numpy(data[None, :]), nx=LR_NX, ny=LR_NY
            ).numpy()

        d = np.squeeze(data)
        if label == "HR":
            gt = d
            ttl = "HR Ground Truth"
        else:
            if gt is None:
                raise ValueError(
                    'dict_data must contain the "HR" entry before any other entry '
                    f'(got "{label}" first).'
                )
            # Mean absolute error relative to the mean magnitude of the truth.
            maer = np.mean(np.abs(gt - d)) / np.mean(np.abs(gt))
            ttl = f"{label}\nMAER={maer:.2f}"

        assert d.shape == (nx, ny)

        cnts = ax.pcolormesh(
            x, y, d, cmap="twilight_shifted", vmin=vmin_omega, vmax=vmax_omega
        )
        ax.set_title(ttl)

        fig.colorbar(
            cnts,
            ax=ax,
            ticks=[vmin_omega, vmin_omega / 2, 0, vmax_omega / 2, vmax_omega],
            extend="both",
        )

        ax.set_xlim([0, 2 * np.pi])
        ax.set_ylim([0, np.pi])

        if label == "HR" and use_hr_space:
            # Mark the observed grid points, i.e. those whose value is not NaN.
            obs_flat = np.squeeze(obs).flatten()
            is_observed = ~np.isnan(obs_flat)
            ax.scatter(
                x.flatten()[is_observed],
                y.flatten()[is_observed],
                marker=".",
                s=1,
                c="k",
            )

        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.axes.xaxis.set_visible(False)
        ax.axes.yaxis.set_visible(False)

    plt.suptitle(f"{ttl_header}Time = {t}")
    plt.tight_layout()

    if write_out:
        fig.savefig(f"{FIG_DIR}/{fig_file_name}.jpg")

    plt.show()
    # This function is called in loops; release the figure explicitly so that
    # figures do not pile up in memory.
    plt.close(fig)


def read_all_tmp_files(config_name):
    """Load the cached SR-DA results of one config.

    Args:
        config_name: Name of the config; the file read is
            ``{TMP_DATA_DIR}/{config_name}_with_init_noise.npz``.

    Returns:
        Tuple ``(sr_frcst, lr_omega, hr_omega, hr_obsrv)`` of numpy arrays, in
        this order.

    Raises:
        FileNotFoundError: If the npz file does not exist.
    """
    npz_file_path = f"{TMP_DATA_DIR}/{config_name}_with_init_noise.npz"

    # The NpzFile keeps the underlying file open until it is closed, so read all
    # arrays inside the context manager and release the handle afterwards.
    with np.load(npz_file_path) as data:
        return (
            data["sr_frcst"],
            data["lr_omega"],
            data["hr_omega"],
            data["hr_obsrv"],
        )

# Plot learning curves

In [None]:
# Switch for drawing the learning curves: nothing is plotted while this is False.
# The loop is still needed because it sets config_info["is_not_finished"], which
# the SR-DA cell reads to skip configs whose training is incomplete.
is_plotted = False

for config_name, config_info in CONFIGS.items():
    try:
        # Treated as unfinished until the history shows the full number of epochs.
        config_info["is_not_finished"] = True

        if not os.path.exists(config_info["learning_history_path"]):
            logger.info(f"Train is not started yet: {config_name}")
            continue

        df = pd.read_csv(config_info["learning_history_path"])

        if len(df) != config_info["config"]["train"]["num_epochs"]:
            logger.info(f"Not finished at {len(df)} epoch: {config_name}")
            continue

        config_info["is_not_finished"] = False

        # Only the runs with seed 832180 are plotted, and only when enabled.
        if (not is_plotted) or config_info["config"]["train"]["seed"] != 832180:
            continue

        plt.rcParams["font.size"] = 15
        fig = plt.figure(figsize=[7, 5])
        ax = plt.subplot(111)

        df.plot(
            ax=ax,
            xlabel="Epochs",
            ylabel=config_info["config"]["train"]["loss"]["name"],
        )
        ax.set_title(config_name)
        plt.yscale("log")

        plt.show()
        plt.close(fig)
    except Exception as e:
        # Best effort: a broken history file must not stop the other configs,
        # but report it through the logger like the rest of the notebook.
        logger.error(f"Failed to inspect the learning history of {config_name}: {e}")

# Perform SR-DA

In [None]:
# Cache for the initial HR vorticity; filled by the SR-DA loop on first use.
all_init_hr_omegas = None

In [None]:
# Run the SR-DA forecast experiment for every selected config and cache the
# result as "{TMP_DATA_DIR}/{config_name}_with_init_noise.npz".
for config_name in tqdm(CONFIGS.keys(), total=len(CONFIGS)):
    # if CONFIGS[config_name]["number"] % 2 != 0:
    #     continue

    # Config selection by name: skip names containing "og00" or "noLR" ...
    if "og00" in config_name or "noLR" in config_name:
        continue

    # ... keep only the "_bT_muT_a02_b02_" and "_bT_muF_a02_b02_" variants ...
    if "_bT_muT_a02_b02_" in config_name or "_bT_muF_a02_b02_" in config_name:
        pass
    else:
        continue

    # ... and skip names containing "test".
    if "test" in config_name:
        continue

    # Re-seed for every config.
    set_seeds(42, use_deterministic=True)

    # Results are cached on disk; delete the npz file to force a re-run.
    output_npz_file_path = f"{TMP_DATA_DIR}/{config_name}_with_init_noise.npz"
    if os.path.exists(output_npz_file_path):
        logger.info(f"Results already exist. So skip {config_name}")
        continue

    config_info = CONFIGS[config_name]
    weight_path = config_info["weight_path"]

    # "is_not_finished" is set by the learning-curve cell, which must run first.
    if config_info["is_not_finished"]:
        logger.info(f"Training is not finished yet: {config_name}")
        continue

    if not os.path.exists(weight_path):
        logger.info(f"No weight file: {config_name}")
        continue

    logger.info(f"{config_name} is being evaluated")

    config = config_info["config"]

    # The initial HR fields are computed once and reused for all configs.
    if all_init_hr_omegas is None:
        all_init_hr_omegas = get_initial_hr_omega()
    assert all_init_hr_omegas.shape == (NUM_SIMULATIONS, HR_NX, HR_NY)

    test_dataset = get_testdataset(
        ROOT_DIR,
        config,
        min_start_time_index=MIN_START_TIME_INDEX,
        max_start_time_index=MAX_START_TIME_INDEX + 1,
    )

    # The dataset settings must agree with the constants used by this notebook.
    assert test_dataset.obs_time_interval == ASSIMILATION_PERIOD
    assert test_dataset.is_output_only_last == False
    assert test_dataset.is_last_obs_missing == True

    all_hr_omegas = read_all_hr_omegas_with_combining_for_forecast(
        test_dataset.hr_file_paths,
        assim_period=ASSIMILATION_PERIOD,
        forecast_span=FORECAST_SPAN,
    )

    # Layout: (ensemble, time, x, y); the ensemble must split evenly into chunks.
    assert all_hr_omegas.shape == (NUM_SIMULATIONS, NUM_TIMES, HR_NX, HR_NY)
    assert all_hr_omegas.shape[0] % N_ENS_PER_CHUNK == 0

    # Only the forcing (second return value) is used; ne=1 because it is
    # broadcast to the chunk size below.
    _, lr_forcing = calc_jet_forcing(
        nx=LR_NX,
        ny=LR_NY,
        ne=1,
        y0=Y0,
        sigma=SIGMA,
        tau0=TAU0,
    )

    low_pass_filter = LowPassFilter(
        nx_lr=LR_NX, ny_lr=LR_NY, nx_hr=HR_NX, ny_hr=HR_NY, device=DEVICE
    )

    # Per-chunk results, concatenated along the ensemble dim after the loop.
    all_hr_obsrv, all_sr_forecast, all_lr_omega = [], [], []

    # Process the ensemble in chunks of N_ENS_PER_CHUNK members.
    for hr_omegas, init_hr_omegas in tqdm(
        zip(
            torch.split(all_hr_omegas, N_ENS_PER_CHUNK),
            torch.split(all_init_hr_omegas, N_ENS_PER_CHUNK),
        ),
        total=(all_hr_omegas.shape[0] // N_ENS_PER_CHUNK),
    ):
        # Fresh models for every chunk; they are deleted at the end of the chunk.
        sr_model, lr_model, _ = make_models(config, weight_path, LR_CFD_CONFIG)
        sr_model.eval()

        assert init_hr_omegas.shape == (N_ENS_PER_CHUNK, HR_NX, HR_NY)

        initialize_models(
            t0=T0,
            hr_omega0=init_hr_omegas,  # first time index
            lr_forcing=lr_forcing,
            lr_model=lr_model,
            srda_model=None,
            **INDEX_CONFIG,
        )

        hr_obsrv = get_observation_with_noise(hr_omegas, test_dataset, **INDEX_CONFIG)
        assert hr_obsrv.shape == hr_omegas.shape

        # Per-time-step records of this chunk.
        ts, hr_obs, lr_omega, sr_forecast = [], [], [], []

        # Start time and initial HR state of the next forecast integration.
        last_t0 = T0
        last_hr_omega0 = init_hr_omegas
        lr_ens_forcing = torch.broadcast_to(lr_forcing, (N_ENS_PER_CHUNK, LR_NX, LR_NY))

        max_time_index = hr_omegas.shape[1]

        for i_cycle in tqdm(range(max_time_index)):
            logger.debug(f"Start: i_cycle = {i_cycle}, t = {lr_model.t:.2f}")

            # Record the current state of the free-running LR model.
            ts.append(lr_model.t)
            lr_omega.append(lr_model.omega.cpu().clone())

            # Observations exist only every ASSIMILATION_PERIOD steps (and only
            # if the config uses them); all other steps are filled with NaN.
            if config["data"]["use_observation"] and i_cycle % ASSIMILATION_PERIOD == 0:
                hr_obs.append(hr_obsrv[:, i_cycle])
            else:
                hr_obs.append(torch.full_like(hr_obsrv[:, i_cycle], torch.nan))

            # Run the SR-DA inference at every assimilation step except the first.
            if i_cycle > 0 and i_cycle % ASSIMILATION_PERIOD == 0:
                # `lr_forecast` is filled in place by the helper below.
                lr_forecast = []
                initialize_and_itegrate_srda_cfd_model_for_forecast(
                    lr_forecast=lr_forecast,
                    num_integrate_steps=ASSIMILATION_PERIOD + FORECAST_SPAN,
                    last_t0=last_t0,
                    last_hr_omega0=last_hr_omega0,
                    lr_ens_forcing=lr_ens_forcing,
                    cfd_config=LR_CFD_CONFIG,
                    low_pass_filter=low_pass_filter,
                )
                assert len(lr_forecast) == ASSIMILATION_PERIOD + FORECAST_SPAN + 1

                # Network inputs: preprocessed LR forecast (x) and observations (o).
                x = make_preprocessed_lr_for_forecast(
                    lr_forecast,
                    test_dataset,
                    **INDEX_CONFIG,
                )
                o = make_preprocessed_obs_for_forecast(
                    hr_obs,
                    test_dataset,
                    **INDEX_CONFIG,
                )

                # Check num of time dims
                _sum = ASSIMILATION_PERIOD + FORECAST_SPAN
                _nt = int(_sum / config["data"]["lr_time_interval"] + 1)
                assert x.shape[1] == _nt
                assert o.shape[1] == _sum + 1

                # Configs without LR forecast get an input filled with the
                # dataset's missing value instead.
                if not config["data"].get("use_lr_forecast", True):
                    x = torch.full_like(x, test_dataset.missing_value)

                with torch.no_grad():
                    sr = sr_model(x, o).detach().cpu().clone()
                sr = make_invprocessed_sr_for_forecast(
                    sr,
                    test_dataset,
                    **INDEX_CONFIG,
                )

                # State and start time for the next forecast integration.
                # NOTE(review): the state is taken at index ASSIMILATION_PERIOD + 1
                # while last_t0 advances by ASSIMILATION_PERIOD steps — confirm
                # that this offset is intended.
                last_hr_omega0 = sr[ASSIMILATION_PERIOD + 1].clone()
                last_t0 += ASSIMILATION_PERIOD * LR_DT * LR_NT

                # The indices between 0 to ASSIMILATION_PERIOD are past
                # So NaN values are substituted for the forecast.
                if len(sr_forecast) == 0:
                    dummy = torch.full(
                        size=(ASSIMILATION_PERIOD,) + sr.shape[1:],
                        fill_value=torch.nan,
                        dtype=sr.dtype,
                    )
                    # `+=` on a list iterates the tensor along its first dim,
                    # i.e. it appends one entry per time step.
                    sr_forecast += dummy

                # Keep only the forecast part of the SR output.
                i_start = ASSIMILATION_PERIOD
                i_end = ASSIMILATION_PERIOD + FORECAST_SPAN
                sr_forecast += sr[i_start:i_end]

                logger.debug(f"Assimilation at i = {i_cycle}")

            # Advance the free-running LR model by one output step.
            lr_model.time_integrate(dt=LR_DT, nt=LR_NT, hide_progress_bar=True)
            lr_model.calc_grid_data()

        # Stack along time dim
        all_hr_obsrv.append(torch.stack(hr_obs, dim=1))
        all_lr_omega.append(torch.stack(lr_omega, dim=1))
        all_sr_forecast.append(torch.stack(sr_forecast, dim=1))

        # Release the models of this chunk before the next one is created.
        del sr_model, lr_model
        torch.cuda.empty_cache()
        _ = gc.collect()

    # Stack along ensemble (batch) dim
    all_hr_obsrv = torch.cat(all_hr_obsrv, dim=0).to(torch.float32).numpy()
    all_lr_omega = torch.cat(all_lr_omega, dim=0).to(torch.float32).numpy()
    all_sr_forecast = torch.cat(all_sr_forecast, dim=0).to(torch.float32).numpy()
    all_hr_omegas = all_hr_omegas.to(torch.float32).numpy()

    assert all_sr_forecast.shape == all_hr_omegas.shape == all_hr_obsrv.shape
    assert all_sr_forecast.shape[:2] == all_lr_omega.shape[:2]

    # The keys written here are the ones read back by read_all_tmp_files.
    np.savez(
        output_npz_file_path,
        hr_omega=all_hr_omegas,
        hr_obsrv=all_hr_obsrv,
        lr_omega=all_lr_omega,
        sr_frcst=all_sr_forecast,
    )

# Analyze errors

## Make csv

In [None]:
# SSIM evaluator used for the CSV files; size_average=False so that the result
# can be averaged over selected dims afterwards.
ssim_func_gauss = SSIM(size_average=False, use_gauss=True)

In [None]:
# For every config with cached SR-DA results, compute error time series against
# the HR ground truth and write them to one CSV file per config.
for config_name in tqdm(CONFIGS.keys(), total=len(CONFIGS)):
    try:
        csv_file = f"{CSV_DATA_DIR}/hr_err_time_series_{config_name}_with_mae_ratio_with_init_noise.csv"

        # CSV files are cached; delete a file to recompute it.
        if os.path.exists(csv_file):
            continue

        # The observations (fourth return value) are not needed here.
        (
            all_sr_analysis,
            all_lr_omega,
            all_hr_omega,
            _,
        ) = read_all_tmp_files(config_name)

        max_time_index = (
            MAX_START_TIME_INDEX + FORECAST_SPAN + ASSIMILATION_PERIOD
        ) - START_TIME_INDEX

        # Truncate all arrays to the same number of time indices.
        all_sr_analysis = all_sr_analysis[:, :max_time_index]
        all_lr_omega = all_lr_omega[:, :max_time_index]
        all_hr_omega = all_hr_omega[:, :max_time_index]

        ts = [LR_DT * LR_NT * i_cycle + T0 for i_cycle in range(max_time_index)]

        # Bring the LR fields onto the HR grid before comparing with the truth.
        _lr = interpolate_time_series(
            torch.from_numpy(all_lr_omega), HR_NX, HR_NY
        ).numpy()
        assert (
            _lr.shape
            == all_hr_omega.shape
            == (NUM_SIMULATIONS, max_time_index, HR_NX, HR_NY)
        )

        # mean over x and y dims
        _lr_err1 = np.mean(np.abs(_lr - all_hr_omega), axis=(-2, -1))
        _lr_err2 = np.mean(np.abs(all_hr_omega), axis=(-2, -1))
        # mean over ensemble dim
        lr = np.mean(_lr_err1, axis=(0,))
        lr_ratio = np.mean(_lr_err1 / _lr_err2, axis=(0,))

        assert (
            all_sr_analysis.shape
            == all_hr_omega.shape
            == (NUM_SIMULATIONS, max_time_index, HR_NX, HR_NY)
        )

        # Same metrics for the SR-DA forecast (_sr_err2 equals _lr_err2).
        _sr_err1 = np.mean(np.abs(all_sr_analysis - all_hr_omega), axis=(-2, -1))
        _sr_err2 = np.mean(np.abs(all_hr_omega), axis=(-2, -1))
        sr = np.mean(_sr_err1, axis=(0,))
        sr_ratio = np.mean(_sr_err1 / _sr_err2, axis=(0,))

        ssim = ssim_func_gauss(
            img1=torch.from_numpy(all_sr_analysis).to(DEVICE),
            img2=torch.from_numpy(all_hr_omega).to(DEVICE),
        )
        ssim = torch.mean(ssim, dim=(0, 2, 3))  # mean over batch, x, and y
        ssim = ssim.detach().cpu()

        assert ssim.shape == (len(ts),) == (max_time_index,)

        df = pd.DataFrame()
        df["Time"] = ts
        df["ErrLR(bicubic)"] = lr
        df["ErrRatioLR(bicubic)"] = lr_ratio
        df["ErrSR"] = sr
        df["ErrRatioSR"] = sr_ratio
        # NOTE(review): this column stores the SSIM output as is, whereas the
        # mixup comparison cell later stores `1 - ssim` — confirm which of the
        # two the name "SsimLoss" is meant to hold.
        df["SsimLoss"] = ssim
        df.to_csv(csv_file, index=False)

    except Exception as e:
        # Best effort: configs without cached results are reported and skipped.
        logger.error(config_name)
        logger.error(e)

## Analyze csv

In [None]:
# Summary table with one row per config: statistics of the SR error time series
# plus the config settings used to group the figures below. Configs without a
# CSV file keep an all-NaN row.
df_results = pd.DataFrame(index=CONFIGS.keys())

for config_name in tqdm(CONFIGS.keys(), total=len(CONFIGS)):
    csv_file = f"{CSV_DATA_DIR}/hr_err_time_series_{config_name}_with_mae_ratio_with_init_noise.csv"
    if not os.path.exists(csv_file):
        logger.error(f"{csv_file} does not exist!")
        continue

    df = pd.read_csv(csv_file)

    config = CONFIGS[config_name]["config"]
    data_config = config["data"]
    model_config = config["model"]

    err_sr = df["ErrSR"]
    max_err_sr = err_sr.max()
    min_err_sr = err_sr.min()

    # Column name -> value for this config, in the order the columns are created.
    summary_row = {
        "AveErrSR": err_sr.mean(),
        "MaxErrSR": max_err_sr,
        "MinErrSR": min_err_sr,
        "StdErrSR": err_sr.std(),
        "Max-MinErrSR": max_err_sr - min_err_sr,
        "UseObs": data_config["use_observation"],
        "ObsGridInterval": data_config["obs_grid_interval"],
        # Stored as a percentage.
        "ObsGridRatio": OBS_GRID_RATIO[data_config["obs_grid_interval"]] * 100,
        "ObsNoiseStd": data_config["obs_noise_std"],
        "LrTimeInterval": data_config["lr_time_interval"],
        "UseSkipConn": model_config["use_global_skip_connection"],
        "UseMixup": data_config["use_mixup"],
        "alpha": data_config["beta_dist_alpha"],
        "beta": data_config["beta_dist_beta"],
        "UseLrForecast": data_config.get("use_lr_forecast", True),
        "Seed": config["train"]["seed"],
        "Bias": model_config["bias"],
    }

    for column, value in summary_row.items():
        df_results.loc[config_name, column] = value

In [None]:
# for grid_ratio, grp in df_results[(df_results["LrTimeInterval"] == 4)].groupby(
#     "ObsGridRatio"
# ):
#     display(f"Grid Ratio = {grid_ratio*100}")
#     display(grp.sort_values("MaxErrSR"))

# Make figures

## Obs grid ratio, groupby use mixup

In [None]:
# Time-averaged MAE versus observation point ratio, with and without mixup.
# Markers show the mean over seeds; error bars span the min and max.
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.size"] = 20
ycol = "AveErrSR"

fig = plt.figure()
ax = plt.subplot(111)

# Every mask is built from `df` itself, so no mask depends on index alignment
# with the unsorted `df_results`. The explicit `== True` comparisons are kept
# because rows of configs without results hold NaN.
df = df_results.sort_values("ObsGridRatio")
df = df[
    (df["UseObs"] == True)
    & (df["Bias"] == True)
    & (df["UseLrForecast"] == True)
    & (df["LrTimeInterval"] == 4)
]

subsets = [
    ("No Mixup", df[df["UseMixup"] == False]),
    ("Use Mixup", df[(df["UseMixup"] == True) & (df["beta"] == 2)]),
]

for label, data in subsets:
    data = data.groupby("ObsGridRatio")[ycol].agg(func=["mean", "count", "min", "max"])
    ys = data["mean"]
    errs = np.array([ys - data["min"], data["max"] - ys])
    # Expect 5 observation ratios with 5 seeds each.
    assert len(ys) == 5 and set(data["count"]) == {5}

    ax.errorbar(data.index, ys, yerr=errs, fmt="o-", label=label, capsize=5)

ax.set_ylim([0.2, 3.2])
# ax.axhline(0.5, color="k", ls="-")


ax.set_ylabel("Ave MAE")
ax.set_xlabel("Observation point ratio [%]")
ax.legend()

plt.show()

## Using bias

In [None]:
# Time-averaged MAE versus observation point ratio, with and without bias terms.
# Markers show the mean over seeds; error bars span the min and max.
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.size"] = 20
ycol = "AveErrSR"

fig = plt.figure()
ax = plt.subplot(111)

# Every mask is built from `df` itself, so no mask depends on index alignment
# with the unsorted `df_results`. The explicit `== True` comparisons are kept
# because rows of configs without results hold NaN.
df = df_results.sort_values("ObsGridRatio")
df = df[
    (df["UseObs"] == True)
    & (df["UseMixup"] == True)
    & (df["UseLrForecast"] == True)
    & (df["LrTimeInterval"] == 4)
    & (df["Seed"].isin({359178.0, 465838.0, 832180.0}))
]

subsets = [
    ("No Bias", df[df["Bias"] == False]),
    ("Use Bias", df[df["Bias"] == True]),
]

for label, data in subsets:
    data = data.groupby("ObsGridRatio")[ycol].agg(func=["mean", "count", "min", "max"])
    ys = data["mean"]
    errs = np.array([ys - data["min"], data["max"] - ys])
    # Expect 5 observation ratios with 3 seeds each.
    assert len(ys) == 5 and set(data["count"]) == {3}

    ax.errorbar(data.index, ys, yerr=errs, fmt="o-", label=label, capsize=5)


ax.set_ylabel("Ave MAE")
ax.set_xlabel("Observation point ratio [%]")
ax.legend()

plt.tight_layout()
fig.savefig(f"{FIG_DIR}/effect_bias.jpg")

plt.show()

## Using only Observation or not (no LR or not)

In [None]:
# Time-averaged MAE of the model that uses observations and the LR forecast
# versus the model that uses observations only ("_noLR" configs).
ycol = "AveErrSR"

for seed in [221958]:
    xs, ys, ys_no_lr = [], [], []
    for obs_grid_interval in range(4, 14, 2):
        config_name_using_lr = f"lt4og{obs_grid_interval:02}_on1e-01_ep1000_lr1e-04_scT_bT_muT_a02_b02_sd{seed:06}"
        config_name_no_lr = config_name_using_lr.replace("_muT_", "_muF_") + "_noLR"

        try:
            err = df_results.loc[config_name_using_lr, ycol]
            err_no_lr = df_results.loc[config_name_no_lr, ycol]
        except KeyError as e:
            # A point is plotted only when both configs have results.
            logger.warning(f"No result for obs_grid_interval = {obs_grid_interval}: {e}")
            continue

        # OBS_GRID_RATIO holds fractions; the x axis is labelled in percent.
        xs.append(OBS_GRID_RATIO[obs_grid_interval] * 100)
        ys.append(err)
        ys_no_lr.append(err_no_lr)

    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.size"] = 20

    fig = plt.figure()
    ax = plt.subplot(111)

    ax.plot(xs, ys, "o-", label="Using Obs and LR")
    ax.plot(xs, ys_no_lr, "o-", label="Using Obs but No LR")

    ax.set_ylabel("MAE (time ave.)")
    ax.set_xlabel("Observation grid ratio [%]")
    ax.set_title(f"Seed = {seed}")

    ax.legend()
    plt.show()

## Error time series (observation)

In [None]:
# MAE time series per seed: SRDA, SR without observations, observations only,
# and the interpolated LR simulation.
obs_grid_interval = 8

for seed in [832180, 465838, 359178]:
    config_name_without_obs = (
        f"lt4og00_on1e-01_ep1000_lr1e-04_scT_bF_muT_a02_b02_sd{seed:06}"
    )
    config_name_using_obs = config_name_without_obs.replace(
        "og00", f"og{obs_grid_interval:02}"
    )
    config_name_only_obs = config_name_using_obs.replace("muT", "muF") + "_noLR"

    df_without_obs = pd.read_csv(
        f"{CSV_DATA_DIR}/hr_err_time_series_{config_name_without_obs}_with_mae_ratio_with_init_noise.csv"
    )
    df_using_obs = pd.read_csv(
        f"{CSV_DATA_DIR}/hr_err_time_series_{config_name_using_obs}_with_mae_ratio_with_init_noise.csv"
    )
    df_only_obs = pd.read_csv(
        f"{CSV_DATA_DIR}/hr_err_time_series_{config_name_only_obs}_with_mae_ratio_with_init_noise.csv"
    )

    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.size"] = 20

    fig = plt.figure(figsize=[8, 4])
    ax = plt.subplot(111)

    ax.plot(df_using_obs["Time"], df_using_obs["ErrSR"], "-", label="SRDA")
    ax.plot(df_without_obs["Time"], df_without_obs["ErrSR"], "--", label="Only SR")
    ax.plot(df_only_obs["Time"], df_only_obs["ErrSR"], "--", label="Only Obs")

    ax.plot(
        df_without_obs["Time"],
        df_without_obs["ErrLR(bicubic)"],
        "-.",
        label="No SR or DA",
    )

    # One title carrying both pieces of information; a second set_title call
    # would overwrite the first.
    ax.set_title(
        f"Observation point ratio = {OBS_GRID_RATIO[obs_grid_interval]*100:.2f} %\n"
        f"Seed = {seed}"
    )
    ax.set_xlabel("Time")
    ax.set_ylabel("MAE")

    lg = ax.legend(
        bbox_to_anchor=(1.05, 1.0),
        loc="upper left",
        ncol=1,
        fontsize=16,
        framealpha=1,
        edgecolor="k",
    )
    plt.tight_layout()
    plt.show()

In [None]:
# MAE time series per seed: SRDA versus the model that uses observations only.
obs_grid_interval = 8

for seed in [221958]:
    config_name = f"lt4og{obs_grid_interval:02}_on1e-01_ep1000_lr1e-04_scT_bT_muT_a02_b02_sd{seed:06}"
    config_name_only_obs = f"lt4og{obs_grid_interval:02}_on1e-01_ep1000_lr1e-04_scT_bT_muF_a02_b02_sd{seed:06}_noLR"

    df = pd.read_csv(
        f"{CSV_DATA_DIR}/hr_err_time_series_{config_name}_with_mae_ratio_with_init_noise.csv"
    )
    df_only_obs = pd.read_csv(
        f"{CSV_DATA_DIR}/hr_err_time_series_{config_name_only_obs}_with_mae_ratio_with_init_noise.csv"
    )

    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.size"] = 20

    fig = plt.figure(figsize=[12, 4])
    ax = plt.subplot(111)

    ax.plot(df["Time"], df["ErrSR"], "-", label="SRDA")
    ax.plot(df_only_obs["Time"], df_only_obs["ErrSR"], "--", label="Only Obs")

    # One title carrying both pieces of information; a second set_title call
    # would overwrite the first.
    ax.set_title(
        f"Observation point ratio = {OBS_GRID_RATIO[obs_grid_interval]*100:.2f} %\n"
        f"Seed = {seed}"
    )
    ax.set_xlabel("Time")
    ax.set_ylabel("MAE")

    lg = ax.legend(
        bbox_to_anchor=(1.05, 1.0),
        loc="upper left",
        ncol=1,
        fontsize=16,
        framealpha=1,
        edgecolor="k",
    )
    plt.tight_layout()
    plt.show()

## Error time series (mixup)

In [None]:
# MAE time series per seed: SRDA trained without and with mixup, together with
# the interpolated LR simulation.
obs_grid_interval = 8

for seed in [221958, 771155, 832180, 465838, 359178]:
    config_name_without_mixup = f"lt4og{obs_grid_interval:02}_on1e-01_ep1000_lr1e-04_scT_bT_muF_a02_b02_sd{seed:06}"

    config_name_with_mixup = config_name_without_mixup.replace("_muF_", "_muT_")

    df_without_mixup = pd.read_csv(
        f"{CSV_DATA_DIR}/hr_err_time_series_{config_name_without_mixup}_with_mae_ratio_with_init_noise.csv"
    )

    df_with_mixup = pd.read_csv(
        f"{CSV_DATA_DIR}/hr_err_time_series_{config_name_with_mixup}_with_mae_ratio_with_init_noise.csv"
    )

    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.size"] = 20

    fig = plt.figure(figsize=[8, 4])
    ax = plt.subplot(111)

    ax.plot(
        df_without_mixup["Time"],
        df_without_mixup["ErrSR"],
        "-",
        label="SRDA (no mixup)",
    )
    ax.plot(df_with_mixup["Time"], df_with_mixup["ErrSR"], "--", label="SRDA (mixup)")

    ax.plot(
        df_without_mixup["Time"],
        df_without_mixup["ErrLR(bicubic)"],
        "-.",
        label="No SR or DA",
    )

    # One title carrying both pieces of information; a second set_title call
    # would overwrite the first.
    ax.set_title(
        f"Observation point ratio = {OBS_GRID_RATIO[obs_grid_interval]*100:.2f} %\n"
        f"Seed = {seed}"
    )
    ax.set_xlabel("Time")
    ax.set_ylabel("MAE")

    lg = ax.legend(
        bbox_to_anchor=(1.05, 1.0),
        loc="upper left",
        ncol=1,
        fontsize=16,
        framealpha=1,
        edgecolor="k",
    )
    plt.tight_layout()
    plt.show()

## Dependency on beta

In [None]:
# Time-averaged MAE as a function of the "beta" setting, one line per
# observation grid interval (seed 221958 only).
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.size"] = 20
ycol = "AveErrSR"

# Every mask is built from `df` itself, so no mask depends on index alignment
# with the unsorted `df_results`.
df = df_results.sort_values("ObsGridRatio")
df = df[
    (df["UseObs"] == True)
    & (df["Bias"] == True)
    & (df["UseMixup"] == True)
    & (df["UseLrForecast"] == True)
    & (df["LrTimeInterval"] == 4)
    & (df["Seed"] == 221958)
]

fig = plt.figure(figsize=[10, 6])
ax = plt.subplot(111)

for grid_interval, grp in df.groupby("ObsGridInterval"):
    # Expect 6 beta values per interval; the group is displayed on failure.
    assert len(grp) == 6, display(grp)
    grp = grp.sort_values("beta")

    ax.plot(grp["beta"], grp[ycol], "o-", label=f"interval:{grid_interval:02}")

ax.set_xlabel("beta")
ax.set_ylabel(ycol)
ax.legend()
plt.show()

In [None]:
# Time-averaged MAE versus observation grid interval, with and without the LR
# forecast as network input (seed 221958 only).
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.size"] = 20
ycol = "AveErrSR"

# Every mask is built from `df` itself, so no mask depends on index alignment
# with the unsorted `df_results`.
df = df_results.sort_values("ObsGridRatio")
df = df[
    (df["UseObs"] == True)
    & (df["beta"] == 2)
    & (df["Bias"] == True)
    & (df["LrTimeInterval"] == 4)
    & (df["Seed"] == 221958)
]

fig = plt.figure(figsize=[10, 6])
ax = plt.subplot(111)

# NOTE(review): the "UseLrFrcst" line is selected through UseMixup == True
# rather than UseLrForecast == True — confirm that this is the intended filter.
data = df[df["UseMixup"] == True].sort_values("ObsGridInterval")
assert len(data) == 5
ax.plot(data["ObsGridInterval"], data[ycol], "o-", label="UseLrFrcst")

data = df[df["UseLrForecast"] == False].sort_values("ObsGridInterval")
assert len(data) == 5
ax.plot(data["ObsGridInterval"], data[ycol], "o-", label="NoUseLrFrcst")

ax.set_xlabel("ObsGridInterval")
ax.set_ylabel(ycol)
ax.legend()
plt.show()

## Vorticity snapshots

In [None]:
# Load the cached SR-DA results of the config shown in the snapshots below.
target_config_name = "lt4og04_on1e-01_ep1000_lr1e-04_scT_bT_muT_a02_b02_sd832180"

# Order of the returned arrays: SR forecast, LR omega, HR omega, HR observations.
(
    all_sr_analysis,
    all_lr_omega,
    all_hr_omega,
    all_hr_obsrv,
) = read_all_tmp_files(target_config_name)

In [None]:
# Display the array shape; the CSV cell asserts the layout (ensemble, time, x, y).
all_sr_analysis.shape

In [None]:
# Ensemble members worth inspecting: 43, 46, 48, 70, 91, 107, 154, 156, 163, 185
# Only one member is drawn; change `i_ensemble` to look at another one.
i_ensemble = 185

# Snapshots every 4th time index of the HR truth, the LR simulation and the
# SR-DA forecast of the selected member.
for i_cycle in np.arange(0, 96, 4):
    t = (i_cycle + START_TIME_INDEX) * DT

    snapshot_fields = {}
    snapshot_fields["HR"] = all_hr_omega[i_ensemble, i_cycle]
    snapshot_fields["LR"] = all_lr_omega[i_ensemble, i_cycle]
    snapshot_fields["SRDA"] = all_sr_analysis[i_ensemble, i_cycle]
    dict_data = snapshot_fields

    plot(
        dict_data,
        t,
        obs=all_hr_obsrv[i_ensemble, i_cycle],
        figsize=[12, 3.5],
        ttl_header=f"i_ens = {i_ensemble}, ",
        fig_file_name=f"snapshot_without_mixup_t{int(t):03}",
        write_out=False,
        use_hr_space=True,
        vmin_omega=-10,
        vmax_omega=10,
    )

### Error time series

In [None]:
# Same SSIM settings as in the "Make csv" section; re-created here so that this
# section can be run on its own.
ssim_func_gauss = SSIM(size_average=False, use_gauss=True)

In [None]:
# Error time series (MAE, MAE ratio and 1 - SSIM) of the LR simulation and of
# SRDA with and without mixup, evaluated on the HR grid and on the LR grid.
# The array sizes come from the data loaded in the snapshot section above.
num_batch_indices = all_lr_omega.shape[0]
num_time_indices = all_lr_omega.shape[1]

# NOTE: this rebinds `df_results`, the name that the summary-figure cells use
# for the per-config table. Re-run the "Analyze csv" cell before re-running
# those figure cells.
df_results = pd.DataFrame()
ts = [LR_DT * LR_NT * i_cycle + T0 for i_cycle in range(num_time_indices)]
assert len(ts) == num_time_indices

df_results["Time"] = ts

for use_mixup in [True, False]:

    target_config_name = "lt4og04_on1e-01_ep1000_lr1e-04_scT_bT_muF_a02_b02_sd832180"

    if use_mixup:
        target_config_name = target_config_name.replace("_muF_", "_muT_")

    (
        all_sr_analysis,
        all_lr_omega,
        all_hr_omega,
        all_hr_obsrv,
    ) = read_all_tmp_files(target_config_name)

    for kind in ["HR", "LR"]:
        gt = all_hr_omega
        lr = all_lr_omega
        srda = all_sr_analysis

        # Bring all fields onto the grid on which the errors are evaluated.
        if kind == "HR":
            lr = interpolate_time_series(torch.from_numpy(lr), HR_NX, HR_NY).numpy()
        elif kind == "LR":
            gt = interpolate_time_series(torch.from_numpy(gt), LR_NX, LR_NY).numpy()
            srda = interpolate_time_series(torch.from_numpy(srda), LR_NX, LR_NY).numpy()
        else:
            raise NotImplementedError()

        for label, data in zip(["LR", "SRDA"], [lr, srda]):
            assert gt.shape == data.shape

            errs = np.mean(np.abs(gt - data), axis=(-2, -1))  # mean over x and y
            nrms = np.mean(np.abs(gt), axis=(-2, -1))

            mae = np.mean(errs, axis=0)  # mean over batch
            maer = np.mean(errs / nrms, axis=0)

            # Only time dim remains.
            assert mae.shape == maer.shape == (num_time_indices,)

            suffix = "with_mixup" if use_mixup else "without_mixup"
            df_results[f"MAE_{kind}_{label}_{suffix}"] = mae
            df_results[f"MAER_{kind}_{label}_{suffix}"] = maer

            for ssim_kind, ssim_func in zip(["Gauss"], [ssim_func_gauss]):

                # Use the configured DEVICE, as the CSV cell does; `.cuda()`
                # would pick the default GPU instead.
                ssim = ssim_func(
                    img1=torch.from_numpy(gt).to(DEVICE),
                    img2=torch.from_numpy(data).to(DEVICE),
                )
                ssim = torch.mean(ssim, dim=(0, 2, 3))  # mean over batch, x, and y
                ssim = ssim.detach().cpu()

                # Only time dim remains.
                assert ssim.shape == (num_time_indices,)

                # The column stores 1 - SSIM.
                df_results[f"SSIM_{ssim_kind}_{kind}_{label}_{suffix}"] = 1 - ssim

In [None]:
# Curves drawn in every panel, in plotting order:
# (label, suffix of the df_results column, legend text, line style).
# The LR curve is drawn once (from the "with_mixup" columns). Each curve has its
# own line style so that the two SRDA variants can be told apart.
curve_specs = [
    ("LR", "with_mixup", "LR", "-."),
    ("SRDA", "with_mixup", "SRDA (mixup)", "-"),
    ("SRDA", "without_mixup", "SRDA (not mixup)", "--"),
]

for resolution in ["HR", "LR"]:
    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.size"] = 14
    fig, axes = plt.subplots(1, 2, sharex=True, figsize=[16, 3])

    for ax, ycol in zip(axes, ["MAER", "SSIM_Gauss"]):
        xs = df_results["Time"].values

        for label, suffix, legend, ls in curve_specs:
            ys = df_results[f"{ycol}_{resolution}_{label}_{suffix}"].values
            ax.plot(xs, ys, label=legend, ls=ls)

        ax.set_xlabel("Time")
        ax.set_ylabel(f"{ycol} ({resolution})")
        ax.set_title(f"{ycol} ({resolution})")

        ax.set_xticks(np.linspace(4, 24, 6))

        if ycol == "MAER":
            # ax.axhline(0.2, ls="--", color="k")
            ax.set_ylim(0, 1.5)
        else:
            # ax.axhline(0.1, ls="--", color="k")
            ax.set_ylim(0.0, 0.5)

    # One legend for the whole figure, attached to the last (right-hand) panel.
    lg = ax.legend(
        bbox_to_anchor=(1.8, 0.0),
        loc="lower right",
        ncol=1,
        fontsize=16,
        framealpha=1,
        edgecolor="k",
    )
    plt.tight_layout()

    fig.savefig(f"{FIG_DIR}/err_time_series_{resolution}_mixup_vs_no_mixup.jpg")
    plt.show()