In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import sys
from logging import DEBUG, INFO, WARNING, StreamHandler, getLogger

# Root logger at INFO level. A stdout handler is attached only when no handler
# is configured yet, so re-running this cell does not add duplicate handlers.
logger = getLogger()
logger.setLevel(INFO)
if not logger.hasHandlers():
    logger.addHandler(StreamHandler(sys.stdout))

# Import libraries

In [None]:
import glob
import os
import pathlib
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import yaml
from cfd_model.filter.low_pass_periodic_channel_domain import LowPassFilter
from cfd_model.interpolator.torch_interpolator import interpolate
from scipy import stats
from scipy.ndimage import sobel
from src.dataloader import (
    make_dataloaders_vorticity_making_observation_inside_time_series_splitted,
)
from src.model_maker import make_model
from src.sr_da_helper_2 import get_testdataset
from src.utils import read_pickle, set_seeds
from tqdm.notebook import tqdm

# Use a serif font family for every figure drawn in this notebook.
plt.rcParams.update({"font.family": "serif"})

# Let pandas display up to 500 columns and 500 rows before truncating.
for _display_option in ("display.max_columns", "display.max_rows"):
    pd.set_option(_display_option, 500)

# Define constants

In [None]:
# Figure-export defaults: these are the default values of the `write_eps` and
# `dpi` arguments of `plot`.
WRITE_EPS = True
DPI = 300

In [None]:
# Torch device string.
# NOTE(review): not referenced by any cell visible in this notebook (the
# low-pass filter is built with device="cpu") — confirm it is still needed.
DEVICE = "cuda:0"

In [None]:
# Repository root: the parent directory of the path stored in PYTHONPATH.
# A missing PYTHONPATH raises KeyError.
_pythonpath_dir = pathlib.Path(os.environ["PYTHONPATH"])
ROOT_DIR = str(_pythonpath_dir.joinpath("..").resolve())

# Output directory for saved figures; created on first run.
FIG_DIR = "./fig"
pathlib.Path(FIG_DIR).mkdir(parents=True, exist_ok=True)

In [None]:
# Lookup from `obs_grid_interval` to a fraction of observed grid points.
# `plot` multiplies the value by 100 and prints it in the ground-truth panel
# title as "obs. points: ... %".
# NOTE(review): the values look like 1 / interval**2 (several at single
# precision), with 0 mapped to 0.0 — confirm how they were generated.
OBS_GRID_RATIO = {
    0: 0.0,
    4: 0.06250000093132257,
    5: 0.03999999910593033,
    6: 0.027777777363856632,
    7: 0.02040816326530612,
    8: 0.015625000116415322,
    9: 0.012345679127323775,
    10: 0.010000000149011612,
    11: 0.008264463206306716,
    12: 0.006944444625534945,
    13: 0.005917159876284691,
    14: 0.005102040977882487,
    15: 0.004444444572759999,
    16: 0.003906250014551915,
}

In [None]:
# NOTE(review): not referenced in the visible cells; presumably matches the
# stride of 4 used by the time loops further down — confirm.
ASSIMILATION_PERIOD = 4

# Low-resolution (LR) grid size and time-stepping parameters.
LR_NX = 32
LR_NY = 17
LR_DT = 5e-4
LR_NT = 500

# High-resolution (HR) grid size.
HR_NX = 128
HR_NY = 65

# Ultra-high-resolution (UHR) grid size.
UHR_NX = 1024
UHR_NY = 513

# Time per stored index: the cells below compute a time as `it * DT`.
DT = LR_DT * LR_NT

In [None]:
# Per-method plot styling tables, keyed by method label.
# Colour source:
# https://matplotlib.org/stable/users/prev_whats_new/dflt_style_changes.html
# NOTE(review): none of the three tables is referenced by a cell visible in
# this notebook — confirm they are still used.
DICT_COLORS = {
    "LR": "#1f77b4",
    "EnKF": "#2ca02c",
    "EnKF(bicubic)": "#2ca02c",
    "EnKF (bicubic)": "#2ca02c",
    "SRDA": "#ff7f0e",
    "ST-SRDA": "#ff7f0e",
    "EnKF(HR)": "#1f77b4",
    "EnKF (HR)": "#1f77b4",
    "SRDA (mixup)": "#ff7f0e",
    "SRDA (no mixup)": "#d62728",
}
# Matplotlib line-style strings for the same labels.
DICT_LINE_STYLES = {
    "LR": ":",
    "EnKF": "--",
    "EnKF(bicubic)": "--",
    "EnKF (bicubic)": "--",
    "SRDA": "-",
    "ST-SRDA": "-",
    "EnKF(HR)": ":",
    "EnKF (HR)": ":",
    "SRDA (mixup)": "-",
    "SRDA (no mixup)": "-.",
}
# Legend text shown for each label.
# NOTE(review): the "EnKF(Bicubic)" / "EnKF (Bicubic)" keys exist only here;
# DICT_COLORS and DICT_LINE_STYLES have no matching entries — confirm whether
# that is intentional.
DICT_LEGEND = {
    "LR": "LR (no SR/DA)",
    "EnKF": "EnKF-SR",
    "EnKF(bicubic)": "EnKF-SR",
    "EnKF (bicubic)": "EnKF-SR",
    "EnKF(Bicubic)": "EnKF-SR",
    "EnKF (Bicubic)": "EnKF-SR",
    "SRDA": "ST-SRDA",
    "ST-SRDA": "ST-SRDA",
    "EnKF(HR)": "EnKF-HR",
    "EnKF (HR)": "EnKF-HR",
    "SRDA (mixup)": "ST-SRDA (mixup)",
    "SRDA (no mixup)": "ST-SRDA (no mixup)",
}

# Define methods

In [None]:
def get_uhr_and_hr_omegas(uhr_result_dir: str, num_times: int = 96):
    """Load UHR snapshots and coarse-grain them onto the HR grid.

    Reads the first ``num_times`` ``*.npy`` files of ``uhr_result_dir`` in
    sorted file-name order. Each file, once squeezed, must have shape
    ``(UHR_NX, UHR_NY)``.

    The HR fields are built by dropping the first y-column of the UHR field,
    average-pooling the rest by the UHR/HR grid ratio, and writing the result
    into columns ``1:`` of a zero array, so the first y-column of every HR
    field stays zero.

    Args:
        uhr_result_dir: Directory containing the ``*.npy`` snapshot files.
        num_times: Number of snapshots to load.

    Returns:
        Tuple ``(all_uhr_omegas, all_hr_omegas)`` of tensors with shapes
        ``(num_times, UHR_NX, UHR_NY)`` and ``(num_times, HR_NX, HR_NY)``.

    Raises:
        AssertionError: If fewer than ``num_times`` files are found or a
            snapshot has an unexpected shape.
    """
    # Only read the files that are actually needed, not the whole directory.
    paths = sorted(glob.glob(f"{uhr_result_dir}/*.npy"))[:num_times]
    assert (
        len(paths) == num_times
    ), f"Expected {num_times} snapshots in {uhr_result_dir}, found {len(paths)}"

    all_uhr_omegas = []
    for path in paths:
        uhr = torch.from_numpy(np.load(path)).squeeze()
        assert uhr.shape == (UHR_NX, UHR_NY)
        all_uhr_omegas.append(uhr)
    # Stack along time dim
    all_uhr_omegas = torch.stack(all_uhr_omegas)
    assert all_uhr_omegas.shape == (num_times, UHR_NX, UHR_NY)

    # Pooling factors follow from the grid constants (8 x 8 for 1024x513 -> 128x65).
    kernel_size = (UHR_NX // HR_NX, (UHR_NY - 1) // (HR_NY - 1))

    # Add a channel axis for avg_pool2d and skip the first y-column.
    tmp = all_uhr_omegas[:, None, :, 1:]
    # Squeeze only the channel axis so num_times == 1 keeps its time axis.
    _omegas = F.avg_pool2d(tmp, kernel_size=kernel_size).squeeze(1)

    all_hr_omegas = torch.zeros((num_times, HR_NX, HR_NY), dtype=_omegas.dtype)
    all_hr_omegas[:, :, 1:] = _omegas

    return all_uhr_omegas, all_hr_omegas


def plot(
    dict_data: dict,
    t: float,
    obs: np.ndarray,
    gt_label: str,
    figsize: list = [20, 2],
    write_out: bool = False,
    ttl_header: str = "",
    fig_file_name: str = "",
    vmin_omega: float = -10,
    vmax_omega: float = 10,
    vmin_diff: float = -1,
    vmax_diff: float = 1,
    font_size: int = 22,
    obs_grid_interval: int = 8,
    dot_size: float = 2,
    dpi: int = DPI,
    draw_pdf: bool = False,
    write_eps: bool = WRITE_EPS,
):
    """Draw the fields of ``dict_data`` side by side as pseudocolour maps.

    Each entry becomes one panel with its own colour bar. Entries other than
    ``gt_label`` are bicubically interpolated to the ground-truth grid, and
    their title reports the MAE ratio ``mean|gt - x| / mean|gt|``.

    Args:
        dict_data: Mapping from panel label to a 2-D field (after squeezing).
            Fields shaped ``(UHR_NX, UHR_NY)`` or ``(HR_NX, HR_NY)`` are drawn
            on the matching grid; anything else is drawn on the LR grid.
        t: Time shown in the figure title; ``None`` omits it.
        obs: Observation field of shape ``(HR_NX, HR_NY)`` after squeezing.
            Non-NaN points are scattered on the ground-truth panel. ``None``
            disables the overlay.
        gt_label: Key of ``dict_data`` holding the ground truth. It may be at
            any position. If it is absent, no MAE ratio is reported.
        figsize: Figure size passed to ``plt.subplots``.
        write_out: Save the figure under ``FIG_DIR`` as JPEG when true.
        ttl_header: Text placed before the time in the figure title.
        fig_file_name: File stem used when saving.
        vmin_omega, vmax_omega: Colour limits for ordinary panels.
        vmin_diff, vmax_diff: Colour limits for panels whose label contains
            "Diff" or "diff".
        font_size: Written to ``plt.rcParams["font.size"]`` (a global change).
        obs_grid_interval: Key into ``OBS_GRID_RATIO`` for the observation
            percentage in the title; ``None`` skips that line.
        dot_size: Marker size of the observation points.
        dpi: Resolution used when saving.
        draw_pdf: Also save a PDF when ``write_out`` is true.
        write_eps: Also save an EPS when ``write_out`` is true.
    """

    def _mesh(nx, ny):
        # Grid over [0, 2*pi) x [0, pi]: x endpoint excluded, y endpoint included.
        xs = np.linspace(0, 2 * np.pi, num=nx, endpoint=False)
        ys = np.linspace(0, np.pi, num=ny, endpoint=True)
        return np.meshgrid(xs, ys, indexing="ij")

    uhr_x, uhr_y = _mesh(UHR_NX, UHR_NY)
    hr_x, hr_y = _mesh(HR_NX, HR_NY)
    lr_x, lr_y = _mesh(LR_NX, LR_NY)

    plt.rcParams["font.size"] = font_size
    # squeeze=False always returns an array of axes, so a one-entry dict works.
    fig, axes = plt.subplots(
        1,
        len(dict_data),
        figsize=figsize,
        sharex=True,
        sharey=False,
        squeeze=False,
    )
    axes = axes.ravel()

    # Resolve the ground truth up front so entries listed before it can still
    # be compared against it.
    gt = np.squeeze(dict_data[gt_label]) if gt_label in dict_data else None

    for ax, (label, data) in zip(axes, dict_data.items()):
        d = np.squeeze(data)

        if d.shape == (UHR_NX, UHR_NY):
            x, y = uhr_x, uhr_y
        elif d.shape == (HR_NX, HR_NY):
            x, y = hr_x, hr_y
        else:
            x, y = lr_x, lr_y

        ttl = label
        if label != gt_label and gt is not None:
            _d = interpolate(
                torch.from_numpy(d[None, ...]),
                nx=gt.shape[0],
                ny=gt.shape[1],
                mode="bicubic",
            )
            _d = _d.squeeze().numpy()
            assert _d.shape == gt.shape
            maer = np.mean(np.abs(gt - _d)) / np.mean(np.abs(gt))
            ttl = f"{label}\n(MAE ratio = {maer:.2f})"

        # Difference panels use their own, usually narrower, colour range.
        if "Diff" in label or "diff" in label:
            vmin, vmax = vmin_diff, vmax_diff
        else:
            vmin, vmax = vmin_omega, vmax_omega

        cnts = ax.pcolormesh(x, y, d, cmap="twilight_shifted", vmin=vmin, vmax=vmax)
        fig.colorbar(
            cnts,
            ax=ax,
            ticks=[vmin, vmin / 2, 0, vmax / 2, vmax],
            extend="both",
        )

        ax.set_xlim([0, 2 * np.pi])
        ax.set_ylim([0, np.pi])

        if label == gt_label and obs is not None:
            # Squeeze before the shape check so singleton axes are tolerated.
            o = np.squeeze(obs)
            assert o.shape == (HR_NX, HR_NY)
            observed = ~np.isnan(o.flatten())
            obs_x = hr_x.flatten()[observed]
            obs_y = hr_y.flatten()[observed]
            # Percentage of grid points that carry an observation.
            print(np.sum(observed) / observed.size * 100)
            ax.scatter(obs_x, obs_y, marker=".", s=dot_size, c="k")
            if obs_grid_interval is not None:
                prob = OBS_GRID_RATIO[obs_grid_interval] * 100
                ttl = f"{ttl}\n(obs. points: {prob:.2f} %)"

        ax.set_title(ttl)
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.axes.xaxis.set_visible(False)
        ax.axes.yaxis.set_visible(False)

    if t is None:
        fig.suptitle(ttl_header)
    else:
        fig.suptitle(f"{ttl_header}Time = {np.round(t, 2)}")
    fig.tight_layout()

    if write_out:
        fig.savefig(f"{FIG_DIR}/{fig_file_name}.jpg", dpi=dpi)
        if write_eps:
            fig.savefig(f"{FIG_DIR}/{fig_file_name}.eps", dpi=dpi)
        if draw_pdf:
            fig.savefig(f"{FIG_DIR}/{fig_file_name}.pdf", dpi=dpi)

    plt.show()

# Plot diff

In [None]:
obs_grid_interval = 8
# 221958, 771155, 832180, 465838, 359178
# NOTE(review): presumably the available seeds; the "sd..." suffix of the
# config name below uses 771155 — confirm.

# Name shared by the experiment config file and its result archive.
srda_config_name = (
    f"lt4og{obs_grid_interval:02}_on1e-01_ep1000_lr1e-04_scT_bT_muT_a02_b02_sd771155"
)
srda_data_path = f"{ROOT_DIR}/pytorch/notebook/paper_experiment_06/data/{srda_config_name}_with_init_noise.npz"

# Parse the YAML experiment configuration that belongs to that name.
srda_config_path = pathlib.Path(
    f"{ROOT_DIR}/pytorch/config/paper_experiment_06/{srda_config_name}.yml"
)
config = yaml.safe_load(srda_config_path.read_text())

In [None]:
# Read the three arrays used below and close the archive afterwards; the
# context manager releases the file handle that a bare np.load keeps open.
with np.load(srda_data_path) as srda_npz:
    hr_omega = srda_npz["hr_omega"]  # ground truth
    hr_obsrv = srda_npz["hr_obsrv"]  # stored under the key "hr_obsrv"
    sr_forecast = srda_npz["sr_frcst"]  # forecast compared against hr_omega

In [None]:
# Filter configured with the LR and HR grid sizes, built on the CPU. Later
# cells call `low_pass_filter.apply(...)` on torch tensors made from the HR
# arrays.
# NOTE(review): presumably maps HR fields to the LR grid — confirm against
# cfd_model.filter.low_pass_periodic_channel_domain.
low_pass_filter = LowPassFilter(
    nx_lr=LR_NX, ny_lr=LR_NY, nx_hr=HR_NX, ny_hr=HR_NY, device="cpu"
)

In [None]:
# Compare the ground truth and the ST-SRDA forecast of one ensemble member
# after low-pass filtering, at three time indices.
i_ens = 185
plt.rcParams["font.size"] = 22

for it in [24, 56, 84]:
    t = it * DT

    dict_data = OrderedDict({})

    gt = hr_omega[i_ens, it]

    # Sobel gradient magnitude rescaled by its 1 % / 99 % quantiles. It is only
    # drawn when the optional "Vortex edge" panel below is enabled. The
    # quantiles have their own names so they cannot leak into the colour
    # limits chosen per panel further down.
    sobel_x = sobel(gt, 0)
    sobel_y = sobel(gt, 1)
    sobel_norm = np.sqrt(sobel_x**2 + sobel_y**2)
    edge_lo = np.quantile(sobel_norm.flatten(), 0.01)
    edge_hi = np.quantile(sobel_norm.flatten(), 0.99)
    sobel_norm = (sobel_norm - edge_lo) / (edge_hi - edge_lo)

    diffs = hr_omega[i_ens, it] - sr_forecast[i_ens, it]

    dict_data["Ground truth"] = hr_omega[i_ens, it]
    # Optional panels; the figure below sizes itself to the number of entries.
    # dict_data["ST-SRDA"] = sr_forecast[i_ens, it]
    # dict_data[r"Diff: GT $-$ SRDA"] = diffs
    # dict_data["Vortex edge"] = sobel_norm

    lr_gt = (
        low_pass_filter.apply(torch.from_numpy(hr_omega[i_ens, it][None, :]))
        .squeeze()
        .numpy()
    )
    lr_pred = (
        low_pass_filter.apply(torch.from_numpy(sr_forecast[i_ens, it][None, :]))
        .squeeze()
        .numpy()
    )

    dict_data["Ground truth\n(low-pass filtered)"] = lr_gt
    dict_data["ST-SRDA\n(low-pass filtered)"] = lr_pred
    dict_data[r"Diff (LR): GT $-$ SRDA"] = lr_gt - lr_pred

    # One column per panel (5 inches each), so no entry is silently dropped.
    n_panels = len(dict_data)
    fig, axes = plt.subplots(1, n_panels, figsize=[5 * n_panels, 4], squeeze=False)

    for ax, (label, field) in zip(axes.ravel(), dict_data.items()):
        if ("Ground truth" in label) or ("ST-SRDA" in label):
            cmap, vmin, vmax = "twilight_shifted", -10, 10
        elif label == "Vortex edge":
            cmap, vmin, vmax = "binary", 0.3, 1
        elif "Diff" in label:
            cmap, vmin, vmax = "bwr", -4, 4
        else:
            raise ValueError(f"No colour scale defined for panel {label!r}")

        # Same grid convention as `plot`: x endpoint excluded, y endpoint included.
        xs = np.linspace(0, 2 * np.pi, num=field.shape[0], endpoint=False)
        ys = np.linspace(0, np.pi, num=field.shape[1], endpoint=True)
        xs, ys = np.meshgrid(xs, ys, indexing="ij")

        cnt = ax.pcolormesh(xs, ys, field, vmin=vmin, vmax=vmax, cmap=cmap)
        fig.colorbar(
            cnt,
            ax=ax,
            extend="both",
            ticks=[vmin, vmin / 2, 0, vmax / 2, vmax],
        )

        ax.set_title(label)

        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.axes.xaxis.set_visible(False)
        ax.axes.yaxis.set_visible(False)

    fig.suptitle(f"Time = {t}")
    fig.tight_layout()
    plt.show()
    plt.close(fig)

# Plot diff in distributions

In [None]:
# Show the shapes of the ground-truth and forecast arrays as a quick sanity check.
hr_omega.shape, sr_forecast.shape

In [None]:
# Two-sample Kolmogorov-Smirnov test between the low-pass-filtered ground
# truth and forecast, on a random subsample of grid values, every 4th time
# index starting at 12.
pvalues = []
# Set to True to also draw histograms for every time step; when False only
# the p-values are collected.
is_plotted = False

# Seeded generator so the subsample, and therefore the p-values, are the same
# on every run of the notebook.
subsample_size = 2000
subsample_rng = np.random.default_rng(42)

for it in range(12, hr_omega.shape[1], 4):
    t = it * DT

    gt = torch.from_numpy(hr_omega[:, it])
    pred = torch.from_numpy(sr_forecast[:, it])

    gt = low_pass_filter.apply(gt).numpy().flatten()
    pred = low_pass_filter.apply(pred).numpy().flatten()
    assert len(gt) == len(pred)

    # The same indices are used for both samples, so the values stay paired.
    idx = subsample_rng.permutation(len(gt))[:subsample_size]
    gt = gt[idx]
    pred = pred[idx]

    results = stats.ks_2samp(gt, pred)
    print(f"t = {t}, results = {results}")
    pvalues.append(results.pvalue)

    if not is_plotted:
        continue

    diff = gt - pred

    fig, axes = plt.subplots(1, 3, figsize=[15, 4])

    for ax, sample, label in zip(axes, [gt, pred, diff], ["GT", "Pred", "Diff"]):
        kappa = stats.kurtosis(sample)
        rng = (-3, 3) if label == "Diff" else (-10, 10)
        ax.hist(sample, bins=51, density=True, range=rng)
        ax.set_title(f"{label}, k={kappa:.5f}")

    fig.suptitle(f"Time = {t}")
    fig.tight_layout()
    plt.show()
    plt.close(fig)

In [None]:
# Smallest and largest KS-test p-value collected over the evaluated time steps.
min(pvalues), max(pvalues)

# Time series of kurtosis

In [None]:
# Excess kurtosis (Fisher definition) of the low-pass-filtered ground truth
# and forecast, using all grid values of all ensemble members, every 4th time
# index starting at 12.
ts, ps, ks_gt, ks_pred = [], [], [], []
for it in range(12, hr_omega.shape[1], 4):
    t = it * DT

    gt = torch.from_numpy(hr_omega[:, it])
    pred = torch.from_numpy(sr_forecast[:, it])

    gt = low_pass_filter.apply(gt).numpy().flatten()
    pred = low_pass_filter.apply(pred).numpy().flatten()

    results = stats.ks_2samp(gt, pred, alternative="two-sided")
    # On the full sample the two distributions are expected to differ
    # decisively. Compare against a tiny threshold rather than exactly 0.0,
    # which only holds when the p-value underflows.
    assert results.pvalue < 1e-12, f"t = {t}: KS p-value = {results.pvalue}"

    ts.append(t)
    ps.append(results.pvalue)
    ks_gt.append(stats.kurtosis(gt, fisher=True))
    ks_pred.append(stats.kurtosis(pred, fisher=True))

In [None]:
# Kurtosis of ground truth and prediction over time. The legend call is what
# makes the "gt" / "pred" labels visible.
fig, ax = plt.subplots()
ax.plot(ts, ks_gt, label="gt")
ax.plot(ts, ks_pred, label="pred")
ax.set_xlabel("Time")
ax.set_ylabel("Kurtosis")
ax.legend()
plt.show()

In [None]:
# Difference in kurtosis (prediction minus ground truth) over time.
fig, ax = plt.subplots()
ax.plot(ts, np.array(ks_pred) - np.array(ks_gt))
ax.set_xlabel("Time")
ax.set_ylabel("Kurtosis difference (pred - gt)")
plt.show()