In [None]:
# Reload edited project modules (e.g. src.utils, cfd_model) automatically
# before executing code, so changes on disk are picked up without a restart.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import sys
from logging import INFO, StreamHandler, getLogger

# Route INFO-level messages from the root logger to the notebook's stdout.
# The handler guard keeps re-runs of this cell from attaching duplicates.
logger = getLogger()
logger.setLevel(INFO)
if not logger.hasHandlers():
    logger.addHandler(StreamHandler(sys.stdout))

# Import libraries

In [None]:
import os
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from cfd_model.interpolator.torch_interpolator import interpolate
from src.utils import read_pickle, set_seeds

# Serif fonts for all figures; show up to 500 columns/rows when displaying frames.
plt.rcParams["font.family"] = "serif"
pd.set_option("display.max_columns", 500, "display.max_rows", 500)

In [None]:
# Configure determinism before any GPU work runs. Per the PyTorch
# reproducibility notes, cuBLAS needs CUBLAS_WORKSPACE_CONFIG set to obtain
# deterministic results when deterministic algorithms are enforced.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = r":4096:8"  # to make calculations deterministic
# NOTE(review): set_seeds is project-local (src.utils); presumably it seeds the
# RNGs with 42 and, with use_deterministic=True, enables deterministic torch
# algorithms — confirm against its implementation.
set_seeds(42, use_deterministic=True)

# Define constants

In [None]:
# The project root is the parent of the directory on PYTHONPATH. PYTHONPATH
# may list several directories separated by os.pathsep, so only the first
# non-empty entry is used; treating the whole variable as one path would
# silently produce a nonexistent directory.
_pythonpath_entries = [
    entry for entry in os.environ["PYTHONPATH"].split(os.pathsep) if entry
]
if not _pythonpath_entries:
    raise RuntimeError(
        "PYTHONPATH is empty; it must contain a directory whose parent is the project root."
    )
ROOT_DIR = str((pathlib.Path(_pythonpath_entries[0]) / "..").resolve())
SRDA_DATA_DIR = f"{ROOT_DIR}/data/SRDA"
ENKF_DATA_DIR = f"{ROOT_DIR}/data/EnKF"

In [None]:
# Cycle stepping: snapshots are plotted every ASSIMILATION_PERIOD cycle indices;
# START_TIME_INDEX offsets the cycle index when converting to physical time.
ASSIMILATION_PERIOD = 4
START_TIME_INDEX = 16
# Used in the EnKF result file names (presumably the observation grid spacing).
OBS_GRID_INTERVAL = 8

# Grid sizes (nx, ny) of the low- and high-resolution fields.
LR_NX, LR_NY = 32, 17
HR_NX, HR_NY = 128, 65

# Time step and the physical time corresponding to START_TIME_INDEX.
DT = 0.25
T0 = DT * START_TIME_INDEX

# Define helper functions

In [None]:
def _make_grid(nx: int, ny: int):
    """Return (x, y) meshgrids (indexing="ij") over x in [0, 2*pi) and y in [0, pi]."""
    xs = np.linspace(0, 2 * np.pi, num=nx, endpoint=False)  # 2*pi excluded
    ys = np.linspace(0, np.pi, num=ny, endpoint=True)  # pi included
    return np.meshgrid(xs, ys, indexing="ij")


def _to_plot_space(label: str, data: np.ndarray, use_hr_space: bool) -> np.ndarray:
    """Resample one panel's field onto the plotting grid and drop singleton axes.

    Labels containing "LR" mark low-resolution fields, which are upsampled with
    nearest-neighbour interpolation when plotting in HR space. All other fields
    are treated as high-resolution and are downsampled (interpolator default
    mode) when plotting in LR space.
    """
    is_lr = "LR" in label
    if is_lr and use_hr_space:
        data = interpolate(
            torch.from_numpy(data[None, :]), nx=HR_NX, ny=HR_NY, mode="nearest"
        ).numpy()
    elif not is_lr and not use_hr_space:
        data = interpolate(
            torch.from_numpy(data[None, :]), nx=LR_NX, ny=LR_NY
        ).numpy()
    return np.squeeze(data)


def plot(
    dict_data: dict,
    t: float,
    obs: np.ndarray,
    figsize: tuple = (20, 2),
    ttl_header: str = "",
    use_hr_space: bool = True,
    vmin_omega: float = -9,
    vmax_omega: float = 9,
):
    """Plot a row of 2-D field snapshots, one panel per entry of ``dict_data``.

    Args:
        dict_data: Mapping from panel label to field, plotted in insertion
            order. Must contain the key "HR ground truth". Every other panel's
            title shows its mean absolute error normalised by the mean
            absolute value of the ground truth.
        t: Time shown in the figure title.
        obs: Observation field. Its non-NaN entries are marked on the
            ground-truth panel (HR space only). It must flatten to the same
            length as the HR grid.
        figsize: Figure size passed to ``plt.subplots``.
        ttl_header: Prefix prepended to the figure title.
        use_hr_space: If True, plot on the HR grid (LR fields are upsampled).
            Otherwise plot on the LR grid (HR fields are downsampled).
        vmin_omega: Lower colour-scale limit shared by all panels.
        vmax_omega: Upper colour-scale limit shared by all panels.

    Raises:
        KeyError: If ``dict_data`` has no "HR ground truth" entry.
    """
    gt_label = "HR ground truth"
    if gt_label not in dict_data:
        raise KeyError(f"dict_data must contain a '{gt_label}' entry")

    if use_hr_space:
        x, y = _make_grid(HR_NX, HR_NY)
        expected_shape = (HR_NX, HR_NY)
    else:
        x, y = _make_grid(LR_NX, LR_NY)
        expected_shape = (LR_NX, LR_NY)

    # Resolve the reference field before drawing, so the error of every panel
    # can be computed regardless of where the ground truth sits in dict_data.
    gt = _to_plot_space(gt_label, dict_data[gt_label], use_hr_space)
    gt_scale = np.mean(np.abs(gt))

    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.size"] = 20
    # squeeze=False keeps `axes` 2-D, so a single-panel figure also iterates.
    fig, axes = plt.subplots(
        1, len(dict_data), figsize=figsize, sharex=True, sharey=False, squeeze=False
    )

    for ax, (label, data) in zip(axes[0], dict_data.items()):
        is_gt = label == gt_label
        d = gt if is_gt else _to_plot_space(label, data, use_hr_space)
        assert d.shape == expected_shape, (label, d.shape)

        if is_gt:
            ttl = "Ground Truth"
        else:
            # Relative MAE: mean |gt - d| normalised by mean |gt|.
            maer = np.mean(np.abs(gt - d)) / gt_scale
            ttl = f"{label}\n(MAE = {maer:.2f})"

        cnts = ax.pcolormesh(
            x, y, d, cmap="twilight_shifted", vmin=vmin_omega, vmax=vmax_omega
        )
        ax.set_title(ttl)

        fig.colorbar(
            cnts,
            ax=ax,
            ticks=[vmin_omega, vmin_omega / 2, 0, vmax_omega / 2, vmax_omega],
            extend="both",
        )

        ax.set_xlim([0, 2 * np.pi])
        ax.set_ylim([0, np.pi])

        if is_gt and use_hr_space:
            # Mark observed grid points (non-NaN entries of obs).
            obs_flat = np.squeeze(obs).flatten()
            observed = ~np.isnan(obs_flat)
            ax.scatter(
                x.flatten()[observed], y.flatten()[observed], marker=".", s=3, c="k"
            )

        # Hide both axes entirely; the panels are images of the domain.
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

    fig.suptitle(f"{ttl_header}Time = {t}")
    fig.tight_layout()

    plt.show()


def get_all_srda_result_paths(config_name: str):
    """Build the .npy paths of the SRDA outputs for one experiment configuration.

    Args:
        config_name: Suffix identifying the configuration in the file names.

    Returns:
        A 5-tuple of paths under SRDA_DATA_DIR, in this order: SR prior,
        SR analysis, LR omega, HR omega, HR observations.
    """
    # Order matters: callers unpack the tuple positionally.
    kinds = ("sr_prior", "sr_analysis", "lr_omega", "hr_omega", "hr_obsrv")
    return tuple(f"{SRDA_DATA_DIR}/{kind}_{config_name}.npy" for kind in kinds)


def read_all_srda_result_files(config_name: str):
    """Load every SRDA output array for ``config_name`` with ``np.load``.

    Returns:
        A 5-tuple of arrays in the same order as
        ``get_all_srda_result_paths``: SR prior, SR analysis, LR omega,
        HR omega, HR observations.
    """
    return tuple(np.load(path) for path in get_all_srda_result_paths(config_name))

# Plot snapshots

In [None]:
# Load all SRDA outputs of the "default_neural_nets" configuration.
all_sr_prior, all_sr_analysis, all_lr_omega, all_hr_omega, all_hr_obsrv = (
    read_all_srda_result_files(config_name="default_neural_nets")
)

In [None]:
targets = [0]

for i_ensemble in targets:
    # NOTE(review): read_pickle is project-local and presumably unpickles the
    # file — only point it at trusted data.
    all_enkf_hr = read_pickle(
        f"{ENKF_DATA_DIR}/ens_mean_hr_og{OBS_GRID_INTERVAL:02}_{i_ensemble:04}.pickle"
    )

    # Plot every ASSIMILATION_PERIOD-th cycle index, starting at
    # ASSIMILATION_PERIOD and staying below 81.
    cycle_stop = 81
    for i_cycle in range(ASSIMILATION_PERIOD, cycle_stop, ASSIMILATION_PERIOD):
        t = DT * (START_TIME_INDEX + i_cycle)

        # Insertion order fixes the left-to-right panel order.
        dict_data = dict(
            [
                ("HR ground truth", all_hr_omega[i_ensemble, i_cycle]),
                ("LR (no SR or DA)", all_lr_omega[i_ensemble, i_cycle]),
                ("EnKF (HR analysis)", all_enkf_hr[i_cycle]),
                ("SRDA (HR forecast)", all_sr_prior[i_ensemble, i_cycle]),
                ("SRDA (HR analysis)", all_sr_analysis[i_ensemble, i_cycle]),
            ]
        )

        plot(
            dict_data,
            t,
            obs=all_hr_obsrv[i_ensemble, i_cycle],
            ttl_header=f"i_ens = {i_ensemble}, ",
            figsize=[25, 3.5],
            vmin_omega=-10,
            vmax_omega=10,
            use_hr_space=True,
        )