In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import sys
from logging import DEBUG, INFO, WARNING, StreamHandler, getLogger

# Route root-logger output to stdout exactly once, even when this cell is re-run.
logger = getLogger()
if all("StreamHandler" not in str(handler) for handler in logger.handlers):
    logger.addHandler(StreamHandler(sys.stdout))
logger.setLevel(INFO)

# Import libraries

In [None]:
import gc
import glob
import os
import pathlib
import random
from collections import OrderedDict

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import yaml
from cfd_model.interpolator.torch_interpolator import (
    interpolate,
    interpolate_time_series,
)
from IPython.display import display
from src.model_maker import make_model
from src.sr_da_helper import get_testdataloader
from src.utils import AverageMeter, set_seeds
from tqdm.notebook import tqdm

# Show up to 500 columns and 500 rows when displaying DataFrames.
pd.set_option("display.max_columns", 500, "display.max_rows", 500)

In [None]:
# cuBLAS workspace setting needed to make calculations deterministic; set it
# before seeding with deterministic algorithms enabled.
os.environ.update({"CUBLAS_WORKSPACE_CONFIG": ":4096:8"})
set_seeds(42, use_deterministic=True)

# Define constants

In [None]:
# Repository root = parent directory of the PYTHONPATH entry.
# PYTHONPATH may list several directories separated by os.pathsep; using the
# whole string as one path would produce a bogus directory, so only the first
# entry is used.
# NOTE(review): assumes the first entry is this project's source dir — confirm.
ROOT_DIR = str(
    (pathlib.Path(os.environ["PYTHONPATH"].split(os.pathsep)[0]) / "..").resolve()
)
ROOT_DIR

In [None]:
# Use the shared workspace data directory when it exists, otherwise ./data;
# the chosen directory is created if missing.
TMP_DATA_DIR = next(
    (
        candidate
        for candidate in ("/workspace/all_data/notebook/paper_experiment_01/data",)
        if os.path.exists(candidate)
    ),
    "./data",
)
os.makedirs(TMP_DATA_DIR, exist_ok=True)

In [None]:
# Output directory for CSV result files (created if missing).
CSV_DATA_DIR = "./csv"
pathlib.Path(CSV_DATA_DIR).mkdir(parents=True, exist_ok=True)

In [None]:
# Output directory for saved figures (created if missing).
FIG_DIR = "./fig"
pathlib.Path(FIG_DIR).mkdir(parents=True, exist_ok=True)

In [None]:
# Every *.yml file in this directory is loaded as one experiment config below;
# the list is sorted so configs are processed in a stable order.
CONFIG_DIR = "/".join([ROOT_DIR, "pytorch", "config", "paper_experiment_01"])
CONFIG_PATHS = glob.glob(f"{CONFIG_DIR}/*.yml")
CONFIG_PATHS.sort()

In [None]:
# Simulation case directory name.
CFD_DIR_NAME = "jet12"

ASSIMILATION_PERIOD, START_TIME_INDEX = 4, 16

# Low-resolution (LR) grid size and time stepping.
# plot() uses NX points in x and NY - 1 points in y.
LR_NX, LR_NY = 32, 17
LR_DT, LR_NT = 5e-4, 500

# High-resolution (HR) grid size (same NX / NY - 1 convention as LR).
HR_NX, HR_NY = 128, 65

Y0_MEAN, SIGMA_MEAN, TAU0_MEAN = np.pi / 2.0, 0.4, 0.3

# Coefficients, presumably of the CFD model (not used by the cells shown here).
BETA = 0.1
COEFF_LINEAR_DRAG = 1e-2
ORDER_DIFFUSION = 2
HR_COEFF_DIFFUSION, LR_COEFF_DIFFUSION = 1e-5, 5e-5

# DT spans LR_NT steps of size LR_DT; T0 equals START_TIME_INDEX such spans.
DT = LR_DT * LR_NT
T0 = START_TIME_INDEX * LR_DT * LR_NT

N_ENS_PER_CHUNK = 125

In [None]:
# This notebook is meant to run on a GPU only: fail fast instead of continuing
# on CPU. (The old message claimed "CPU is used" while aborting.)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if DEVICE != "cuda":
    logger.error("No GPU found. This notebook requires CUDA.")
    raise RuntimeError("No GPU found. This notebook requires CUDA.")
logger.info("GPU is used.")

In [None]:
# Map each config name to its parsed YAML config and the paths of its training
# artifacts (weights and learning history).
CONFIGS = OrderedDict()

for config_path in CONFIG_PATHS:
    with open(config_path, encoding="utf-8") as file:
        config = yaml.safe_load(file)

    config_file = pathlib.Path(config_path)
    # File name without its ".yml" suffix. `stem` keeps any other dots in the
    # name; splitting on "." would truncate it and could make configs collide.
    config_name = config_file.stem
    assert config_name not in CONFIGS, f"duplicate config name: {config_name}"

    # Name of the directory containing the config file.
    experiment_name = config_file.parent.name
    _dir = (
        f"/workspace/all_data/data/pytorch/DL_results/{experiment_name}/{config_name}"
    )
    if not os.path.exists(_dir):
        _dir = f"{ROOT_DIR}/data/pytorch/DL_results/{experiment_name}/{config_name}"

    CONFIGS[config_name] = {
        "config": config,
        "model_name": config["model"]["model_name"],
        "experiment_name": experiment_name,
        "weight_path": f"{_dir}/weights.pth",
        "learning_history_path": f"{_dir}/learning_history.csv",
    }

In [None]:
# Observation ratio for each `obs_grid_interval` (the keys).
# - Interval 0 maps to 0.0.
# - Every other value is within ~1e-9 of 1 / interval**2.
# - Presumably this is the fraction of observed HR grid points (TODO confirm).
# The metadata cell multiplies these values by 100 to fill "ObsGridRatio".
OBS_GRID_RATIO = dict(
    [
        (0, 0.0),
        (4, 0.06250000093132257),
        (5, 0.03999999910593033),
        (6, 0.027777777363856632),
        (7, 0.02040816326530612),
        (8, 0.015625000116415322),
        (9, 0.012345679127323775),
        (10, 0.010000000149011612),
        (11, 0.008264463206306716),
        (12, 0.006944444625534945),
        (13, 0.005917159876284691),
        (14, 0.005102040977882487),
        (15, 0.004444444572759999),
        (16, 0.003906250014551915),
    ]
)

# Define methods

In [None]:
def _to_plot_space(label: str, data: np.ndarray, use_hr_space: bool) -> np.ndarray:
    """Resample one field onto the plotting grid and drop singleton axes.

    - "LR" is upsampled with nearest-neighbour interpolation onto the HR grid
      when ``use_hr_space`` is True.
    - Every other field is interpolated onto the LR grid when ``use_hr_space``
      is False.
    - In the remaining cases the data is used as-is.
    """
    if label == "LR":
        if use_hr_space:
            data = interpolate(
                torch.from_numpy(data[None, :]),
                nx=HR_NX,
                ny=HR_NY - 1,
                mode="nearest",
            ).numpy()
    elif not use_hr_space:
        data = interpolate(
            torch.from_numpy(data[None, :]), nx=LR_NX, ny=LR_NY - 1
        ).numpy()
    return np.squeeze(data)


def plot(
    dict_data: dict,
    t: float,
    obs: np.ndarray,
    figsize: tuple = (20, 2),
    write_out: bool = False,
    ttl_header: str = "",
    fig_file_name: str = "",
    use_hr_space: bool = True,
    vmin_omega: float = -9,
    vmax_omega: float = 9,
):
    """Draw one vorticity panel per entry of ``dict_data`` in a single row.

    Args:
        dict_data: Ordered mapping from panel label to a field laid out as
            (x, y); singleton axes are squeezed away.
            - "HR" is the ground truth; every other panel is titled with its
              MAE against it (no MAE when "HR" is absent).
            - "LR" is upsampled to the HR grid when ``use_hr_space`` is True.
            - All other fields are downsampled to the LR grid when it is False.
        t: Time shown in the figure title.
        obs: Observation field, NaN where nothing is observed. It is flattened
            against the (x, y) plotting grid, so it must use the same (x, y)
            layout as the fields. It is drawn as black dots on the "HR" panel,
            and only when ``use_hr_space`` is True.
        figsize: Matplotlib figure size.
        write_out: If True, save the figure to ``{FIG_DIR}/{fig_file_name}``
            with ".jpg" and ".eps" extensions.
        ttl_header: Prefix of the figure title.
        fig_file_name: Base file name used when ``write_out`` is True.
        use_hr_space: Plot on the HR grid (True) or on the LR grid (False).
        vmin_omega: Lower bound of the colour range shared by all panels.
        vmax_omega: Upper bound of the colour range shared by all panels.

    Raises:
        ValueError: If ``dict_data`` is empty.
        AssertionError: If a field does not match the plotting grid shape.
    """
    if not dict_data:
        raise ValueError("dict_data must contain at least one field")

    if use_hr_space:
        nx, ny = HR_NX, HR_NY - 1
    else:
        nx, ny = LR_NX, LR_NY - 1
    # Grid coordinates covering [0, 2*pi) in x and [0, pi) in y, shape (nx, ny).
    x, y = np.meshgrid(
        np.linspace(0, 2 * np.pi, num=nx, endpoint=False),
        np.linspace(0, np.pi, num=ny, endpoint=False),
        indexing="ij",
    )

    # Resample every field first so the ground truth is available no matter
    # where "HR" appears in dict_data.
    fields = {
        label: _to_plot_space(label, data, use_hr_space)
        for label, data in dict_data.items()
    }
    gt = fields.get("HR")

    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.size"] = 18
    # squeeze=False keeps `axes` 2-D, so a single panel works too.
    fig, axes = plt.subplots(
        1, len(fields), figsize=figsize, sharex=True, sharey=False, squeeze=False
    )

    for ax, (label, d) in zip(axes[0], fields.items()):
        assert d.shape == (nx, ny), f"{label}: shape {d.shape} != {(nx, ny)}"

        if label == "HR":
            ttl = "HR Ground Truth"
        elif gt is None:
            ttl = label
        else:
            mae = np.mean(np.abs(gt - d))
            ttl = f"{label}\nMAE={mae:.2f}"

        cnts = ax.pcolormesh(
            x, y, d, cmap="twilight_shifted", vmin=vmin_omega, vmax=vmax_omega
        )
        ax.set_title(ttl)

        fig.colorbar(
            cnts,
            ax=ax,
            ticks=[vmin_omega, vmin_omega / 2, 0, vmax_omega / 2, vmax_omega],
            extend="both",
        )

        ax.set_xlim([0, 2 * np.pi])
        ax.set_ylim([0, np.pi])

        if label == "HR" and use_hr_space:
            obs_flat = np.squeeze(obs).flatten()
            observed = ~np.isnan(obs_flat)
            ax.scatter(
                x.flatten()[observed], y.flatten()[observed], marker=".", s=1, c="k"
            )

        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

    fig.suptitle(f"{ttl_header}Time = {t}")
    fig.tight_layout()

    if write_out:
        fig.savefig(f"{FIG_DIR}/{fig_file_name}.jpg", dpi=300)
        fig.savefig(f"{FIG_DIR}/{fig_file_name}.eps", dpi=300)

    plt.show()

# Plot learning curves

In [None]:
# Set to True to draw one learning-curve figure per config. By default the cell
# only checks that every learning history has one row per configured epoch.
PLOT_LEARNING_CURVES = False

for config_name, config_info in CONFIGS.items():
    try:
        df = pd.read_csv(config_info["learning_history_path"])
        n_epochs = config_info["config"]["train"]["num_epochs"]
        assert len(df) == n_epochs, (
            f"learning history has {len(df)} rows, expected {n_epochs} epochs"
        )

        if not PLOT_LEARNING_CURVES:
            continue

        plt.rcParams["font.size"] = 15
        fig, ax = plt.subplots(figsize=[7, 5])

        df.plot(
            ax=ax,
            xlabel="Epochs",
            ylabel=config_info["config"]["train"]["loss"]["name"],
        )
        ax.set_title(config_name)
        ax.set_yscale("log")
        plt.show()
    except Exception as e:
        # Best effort: report the failing config and keep checking the others.
        print(f"{config_name}: {type(e).__name__}: {e}")

# Evaluate models

In [None]:
csv_result = f"{CSV_DATA_DIR}/mae_scores_using_testdataset.csv"

# Reuse scores from a previous run when the CSV exists. The evaluation loop
# only evaluates configs that are not yet in the index.
if os.path.exists(csv_result):
    # The CSV is written with `index=True`, so its first column holds the
    # config names.
    df_results = pd.read_csv(csv_result, index_col=0)
    print("DF is read from csv.")
else:
    df_results = pd.DataFrame()
    print("DF is created.")

In [None]:
n_loops = 100  # number of passes over the test set per config


def _run_test_pass(sr_model, dataloader, scale, bias, maes, maers, n):
    """Run one pass over ``dataloader`` and accumulate errors into the meters.

    Predictions and targets are de-normalized as ``x * scale + bias``. Both
    meters are weighted by batch size.

    - ``maes`` receives each batch's mean absolute error.
    - ``maers`` receives the MAE divided by the mean absolute target, computed
      per (batch, time) entry over (channel, y, x) and then averaged.

    ``n`` is the pass index, used only for debug logging.
    """
    with torch.no_grad():
        for lr, obs, gt in dataloader:
            batch_size = lr.shape[0]
            pred = sr_model(lr.to(DEVICE), obs.to(DEVICE)) * scale + bias
            # Move the target once and reuse it for both error measures.
            gt = gt.to(DEVICE) * scale + bias

            abs_diffs = torch.abs(pred - gt)
            mae = torch.mean(abs_diffs).item()
            maes.update(mae, n=batch_size)

            # mean over channel, y, x
            err = torch.mean(abs_diffs, dim=(-3, -2, -1))
            ref = torch.mean(torch.abs(gt), dim=(-3, -2, -1))
            # mean over batch, time
            maers.update(torch.mean(err / ref).item(), n=batch_size)

            logger.debug(f"n = {n}, mae = {mae:.10f}")


for config_name in tqdm(CONFIGS.keys(), total=len(CONFIGS)):
    if config_name in df_results.index:
        logger.info(f"Result of {config_name} already exists. So skip it.")
        continue

    set_seeds(42, use_deterministic=True)

    config_info = CONFIGS[config_name]
    config = config_info["config"]
    weight_path = config_info["weight_path"]

    if not os.path.exists(weight_path):
        logger.warning(f"Weight of {config_name} does not exist. So skip it.")
        continue
    logger.info(f"{config_name} is being evaluated")

    sr_model = make_model(config).to(DEVICE)
    sr_model.load_state_dict(torch.load(weight_path, map_location=DEVICE))
    _ = sr_model.eval()

    test_dataloader = get_testdataloader(ROOT_DIR, config)
    bias = test_dataloader.dataset.vorticity_bias
    scale = test_dataloader.dataset.vorticity_scale

    # The meters are shared by all passes, so "MAE_nXX" / "MAER_nXX" are
    # running averages over passes 1..XX.
    maes, maers = AverageMeter(), AverageMeter()
    for n in tqdm(range(n_loops)):
        # Reseed Python/NumPy RNGs per pass so each pass is reproducible,
        # presumably because the test dataloader draws from them.
        random.seed(n)
        np.random.seed(n)
        _run_test_pass(sr_model, test_dataloader, scale, bias, maes, maers, n)

        df_results.loc[config_name, f"MAE_n{n+1:02}"] = maes.avg
        df_results.loc[config_name, f"MAER_n{n+1:02}"] = maers.avg

    del test_dataloader, sr_model
    gc.collect()
    torch.cuda.empty_cache()

    # Save after every config so an interrupted run keeps finished results.
    df_results.to_csv(csv_result, index=True)

In [None]:
# Copy the experiment settings of every config into df_results so results can
# be filtered and grouped by them.
for config_name in tqdm(CONFIGS.keys(), total=len(CONFIGS)):
    config = CONFIGS[config_name]["config"]
    try:
        data_cfg = config["data"]
        metadata = {
            "UseObs": data_cfg["use_observation"],
            "ObsGridInterval": data_cfg["obs_grid_interval"],
            # Stored as a percentage.
            "ObsGridRatio": OBS_GRID_RATIO[data_cfg["obs_grid_interval"]] * 100,
            "ObsNoiseStd": data_cfg["obs_noise_std"],
            "LrTimeInterval": data_cfg["lr_time_interval"],
            "UseSkipConn": config["model"]["use_global_skip_connection"],
            "UseMixup": data_cfg["use_mixup"],
            "alpha": data_cfg["beta_dist_alpha"],
            "beta": data_cfg["beta_dist_beta"],
            # Configs without this key are treated as using the LR forecast.
            "UseLrForecast": data_cfg.get("use_lr_forecast", True),
            "Seed": config["train"]["seed"],
        }
        for column, value in metadata.items():
            df_results.loc[config_name, column] = value
    except Exception as e:
        # Best effort: report which config failed and continue with the rest.
        print(f"{config_name}: {type(e).__name__}: {e}")

# Analyze errors

## Convergence check

In [None]:
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.size"] = 20
n_loops = 100

# For each config, plot the running-average MAE after every pass as a
# percentage of its mean over all passes. This shows how much the estimate
# still moves as passes are added.
loop_counts = np.arange(1, n_loops + 1)

for config_name, row in df_results.iterrows():
    mae_curve = np.array(
        [row[f"MAE_n{count:02}"] for count in range(1, n_loops + 1)]
    )
    mae_curve = mae_curve / np.mean(mae_curve) * 100

    fig, ax = plt.subplots()
    ax.plot(loop_counts, mae_curve, "o-")
    ax.set_xlabel("loop count")
    ax.set_ylabel("MAE variation [%]")
    ax.set_title(config_name)
    plt.show()



## Compare with/without observations

In [None]:
df = df_results[df_results["UseLrForecast"] == True].sort_values("ObsGridRatio")
ycol = "MAE_n100"


def _stats_by_obs_ratio(subset: pd.DataFrame, col: str, n_ratios: int, n_seeds: int = 5):
    """Aggregate ``col`` per observation ratio.

    Asserts that exactly ``n_ratios`` ratios are present, each with ``n_seeds``
    runs. Returns ``(means, errs)``: the per-ratio mean Series, and the
    [mean - min, max - mean] array used as asymmetric error bars.
    """
    stats = subset.groupby("ObsGridRatio")[col].agg(func=["mean", "count", "min", "max"])
    assert len(stats) == n_ratios and set(stats["count"]) == {n_seeds}, stats
    means = stats["mean"]
    errs = np.array([means - stats["min"], stats["max"] - means])
    return means, errs


plt.rcParams["font.family"] = "serif"
plt.rcParams["font.size"] = 20

fig, ax = plt.subplots()

ys, errs = _stats_by_obs_ratio(
    df[(df["UseMixup"] == False) & (df["UseObs"] == True)], ycol, n_ratios=5
)
ax.errorbar(ys.index, ys, yerr=errs, fmt="o-", label="No Mixup (with obs)", capsize=5)

ys, _ = _stats_by_obs_ratio(
    df[(df["UseMixup"] == False) & (df["UseObs"] == False) & (df["beta"] == 2)],
    ycol,
    n_ratios=1,
)
# `iloc` is positional; `ys[0]` would be a label lookup on the float
# ObsGridRatio index and fail if that single ratio were not 0.
ax.axhline(ys.iloc[0], color="k", ls="--", label="No Mixup (without obs)")

ys, errs = _stats_by_obs_ratio(
    df[(df["UseMixup"] == True) & (df["UseObs"] == True) & (df["beta"] == 2)],
    ycol,
    n_ratios=5,
)
ax.errorbar(ys.index, ys, yerr=errs, fmt="o-", label="Use Mixup (with obs)", capsize=5)

ax.legend(
    bbox_to_anchor=(1.05, 1.0),
    loc="upper left",
    ncol=1,
    fontsize=16,
    framealpha=1,
    edgecolor="k",
)

ax.set_xlabel("Observation point ratio [%]")
ax.set_ylabel("MAE")

plt.show()

# Merge SRDA scores

In [None]:
# Attach summary statistics of each config's SRDA error time series ("ErrSR").
for config_name in tqdm(CONFIGS.keys(), total=len(CONFIGS)):
    csv_file = f"{CSV_DATA_DIR}/hr_err_time_series_{config_name}_with_mae_ratio.csv"
    if os.path.exists(csv_file):
        df = pd.read_csv(csv_file)
        err_sr = df["ErrSR"]
        summary = {
            "AveErrSR": err_sr.mean(),
            "MaxErrSR": err_sr.max(),
            "MinErrSR": err_sr.min(),
            "StdErrSR": err_sr.std(),
        }
        for column, value in summary.items():
            df_results.loc[config_name, column] = value
        df_results.loc[config_name, "Max-MinErrSR"] = (
            df_results.loc[config_name, "MaxErrSR"]
            - df_results.loc[config_name, "MinErrSR"]
        )
    else:
        logger.error(f"{csv_file} does not exist!")

## Compare errors

In [None]:
df = df_results[df_results["UseLrForecast"] == True].sort_values("ObsGridRatio")

plt.rcParams["font.family"] = "serif"
plt.rcParams["font.size"] = 20

# (row mask, line format, legend label) per curve. Every curve must cover 5
# observation ratios with 5 runs each.
curve_specs = [
    ((df["UseMixup"] == False) & (df["UseObs"] == True), "o--", "No Mixup"),
    (
        (df["UseMixup"] == True) & (df["UseObs"] == True) & (df["beta"] == 2),
        "o-",
        "Use Mixup",
    ),
]
# Left panel: single-shot test MAE; right panel: SRDA error averaged over time.
panel_titles = {
    "MAE_n100": "No feedback cycles",
    "AveErrSR": "Repeating feedback cycles",
}

fig, axes = plt.subplots(1, 2, figsize=[13, 4], sharex=True, sharey=True)

for ax, (ycol, title) in zip(axes, panel_titles.items()):
    for mask, fmt, label in curve_specs:
        stats = df[mask].groupby("ObsGridRatio")[ycol].agg(
            func=["mean", "count", "min", "max"]
        )
        assert len(stats) == 5 and set(stats["count"]) == {5}, stats
        ys = stats["mean"]
        errs = np.array([ys - stats["min"], stats["max"] - ys])
        ax.errorbar(stats.index, ys, yerr=errs, fmt=fmt, label=label, capsize=5)

    ax.set_ylabel("MAE")
    ax.set_xlabel("Observation point ratio [%]")
    ax.set_xlim([0.5, 6.5])
    ax.set_xticks(np.linspace(0.5, 6.5, 4))
    ax.set_ylim(0.15, 1.45)
    ax.set_yticks(np.linspace(0.2, 1.4, 7))
    ax.set_title(title)

axes[-1].legend(
    bbox_to_anchor=(1.05, 1.0),
    loc="upper left",
    ncol=1,
    fontsize=16,
    framealpha=1,
    edgecolor="k",
)

plt.tight_layout()
plt.show()

# Vorticity snapshots

In [None]:
target_config_name_muT = "lt4og12_on1e-01_ep1000_lr1e-04_scT_muT_a02_b02_sd221958"
target_config_name_muF = target_config_name_muT.replace("muT", "muF")

# Load the mixup ("muT") and no-mixup ("muF") variants of the same setting, in
# this fixed order, right after seeding. The test dataset is built from the
# no-mixup config.
dict_models = {}
dataset = None
set_seeds(42, use_deterministic=True)

for key, config_name in (
    ("use_mixup", target_config_name_muT),
    ("no_mixup", target_config_name_muF),
):
    config = CONFIGS[config_name]["config"]
    sr_model = make_model(config).to(DEVICE)
    state_dict = torch.load(CONFIGS[config_name]["weight_path"], map_location=DEVICE)
    sr_model.load_state_dict(state_dict)
    _ = sr_model.eval()
    dict_models[key] = sr_model

    if key != "no_mixup":
        continue
    data_root = (
        "/workspace/all_data" if os.path.exists("/workspace/all_data") else ROOT_DIR
    )
    dataset = get_testdataloader(data_root, config).dataset

In [None]:
set_seeds(42, use_deterministic=True)
lr, obs, gt = dataset[8]

# Samples are (time, channel, y, x); plot() expects (x, y) fields, so the last
# time step of channel 0 is transposed.
gt = gt[-1, 0].permute(1, 0).numpy()
gt = gt * dataset.vorticity_scale + dataset.vorticity_bias
dict_data = {"HR": gt}

# Inference only: no autograd graph is needed.
with torch.no_grad():
    for key, model in dict_models.items():
        pred = model(lr[None, ...].to(DEVICE), obs[None, ...].to(DEVICE))
        # pred is (batch, time, channel, y, x): last time step of the single sample.
        pred = pred[0, -1, 0].cpu()
        pred = pred * dataset.vorticity_scale + dataset.vorticity_bias
        dict_data[key] = pred.permute(1, 0).numpy()

# Unobserved points become NaN so plot() only marks the observed ones.
obs = torch.where(obs == dataset.missing_value, torch.full_like(obs, torch.nan), obs)
# Transpose to (x, y) like gt/pred. plot() flattens obs against an (x, y)
# meshgrid, so the untransposed (y, x) layout would misplace the observation
# markers.
obs = obs[-1, 0].permute(1, 0) * dataset.vorticity_scale + dataset.vorticity_bias
obs = obs.numpy()

plot(
    dict_data,
    t=0,
    obs=obs,
    figsize=[12, 4.0],
    ttl_header="",
    fig_file_name="",
    write_out=False,
    use_hr_space=True,
    vmin_omega=-10,
    vmax_omega=10,
)