In [None]:
# Load the autoreload extension and use mode 2, so that edited modules
# (e.g. the project's `src` package) are re-imported before each cell runs.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import sys
from logging import INFO, WARNING, StreamHandler, getLogger

# Root logger shared by the notebook and the imported project modules.
logger = getLogger()
logger.setLevel(INFO)

# Attach a stdout handler only once, so re-running the cell does not
# duplicate every log line.
if not logger.hasHandlers():
    logger.addHandler(StreamHandler(sys.stdout))

In [None]:
# !pip install xarray==2022.10.0

# Import libraries

In [None]:
import copy
import datetime
import gc
import glob
import os
import pathlib
from collections import OrderedDict

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

# import xarray as xr
import yaml
from numpy.testing import assert_array_equal
from scipy import interpolate, signal
from src.dataloader import (
    get_all_new_lr_data_dir_paths,
    make_dataloaders,
    make_evaluation_dataloader_without_random_cropping,
    split_into_train_valid_test_dirs,
)
from src.loss_maker import (
    AbsDiffDivergence,
    AbsDiffTemperature,
    ChannelwiseMse,
    DiffOmegaVectorNorm,
    DiffVelocityVectorNorm,
    MaskedL1Loss,
    MaskedL1LossNearWall,
    MaskedL2Loss,
    MaskedL2LossNearWall,
    MyL1Loss,
    MyL2Loss,
    ResidualContinuity,
    Ssim3dLoss,
    calc_mask_near_build_wall,
)
from src.model_maker import make_model
from src.optim_helper import evaluate
from src.ssim import SSIM3D
from src.utils import calc_early_stopping_patience, read_pickle, set_seeds, write_pickle
from tqdm.notebook import tqdm

# Let pandas display up to 500 columns / rows before truncating.
pd.options.display.max_columns = 500
pd.options.display.max_rows = 500

# Default font family for every figure in this notebook.
plt.rcParams.update({"font.family": "Times New Roman"})

In [None]:
# cuBLAS workspace setting, exported to make calculations deterministic.
os.environ.update({"CUBLAS_WORKSPACE_CONFIG": ":4096:8"})

# Seed the random number generators through the project helper.
set_seeds(42, use_deterministic=True)

# Define constants

In [None]:
# ROOT_DIR is the parent directory of the path stored in PYTHONPATH
# (a KeyError is raised when PYTHONPATH is not set).
ROOT_DIR = str(pathlib.Path(os.environ["PYTHONPATH"]).joinpath("..").resolve())

# Input data and stored inferences of the deep-learning experiments.
DL_DATA_DIR = pathlib.Path(ROOT_DIR) / "data" / "DL_data"
DL_INFERENCE_DIR = pathlib.Path(ROOT_DIR) / "data" / "DL_inferences"

# Flag read by the (currently commented-out) figure-saving code.
MAKE_EPS_FILES = False

DATA_DIR = "./data"

# Output directory for the figures; created on first use.
FIG_DIR = f"{ROOT_DIR}/doc/paper_fig"
pathlib.Path(FIG_DIR).mkdir(parents=True, exist_ok=True)

In [None]:
# Experiment whose YAML configs and inferences are visualised below.
EXPERIMENT_NAME = "new_lr_unet_gconv_change_datashape_ddp"

# All config files of the experiment, in a reproducible (sorted) order.
CONFIG_PATHS = glob.glob(f"{ROOT_DIR}/pytorch/config/{EXPERIMENT_NAME}/*.yml")
CONFIG_PATHS.sort()

In [None]:
# Building masks on the high- (HR) and low-resolution (LR) grids, loaded as
# torch tensors. The plotting cells treat non-zero entries as "inside a
# building".
HR_IS_IN_BUILD, LR_IS_IN_BUILD = (
    torch.from_numpy(np.load(f"{ROOT_DIR}/data/DL_data/10/{res}_is_in_build.npy"))
    for res in ("hr", "lr")
)

# Topography text files in the format parsed by `read_building_height`.
HR_BUILDING_HEIGHT_PATH = f"{ROOT_DIR}/datascience/script/EleTopoZ_HR.txt"
LR_BUILDING_HEIGHT_PATH = f"{ROOT_DIR}/datascience/script/EleTopoZ_LR.txt"

In [None]:
# Map each config name to its parsed YAML and to the paths of the artefacts
# written for that config under data/DL_results.
CONFIGS = OrderedDict()

for config_path in CONFIG_PATHS:
    # Skip tutorial configs. Only the file name is tested: checking the whole
    # path would drop every config whenever the repository itself lives under
    # a directory whose name contains "tutorial".
    if "tutorial" in os.path.basename(config_path):
        continue

    with open(config_path, encoding="utf-8") as file:
        config = yaml.safe_load(file)

    # File name up to the first dot, e.g. ".../foo.yml" -> "foo".
    config_name = os.path.basename(config_path).split(".")[0]
    assert config_name not in CONFIGS

    # The experiment name is the directory that contains the config file.
    experiment_name = os.path.basename(os.path.dirname(config_path))

    _dir = f"{ROOT_DIR}/data/DL_results/{experiment_name}/{config_name}"

    CONFIGS[config_name] = {
        "config": config,
        "model_name": config["model"]["model_name"],
        "experiment_name": experiment_name,
        "weight_path": f"{_dir}/weights.pth",
        "learning_history_path": f"{_dir}/learning_history.csv",
    }

# Define methods

In [None]:
def dimensionalize(data, means, scales):
    """Undo the normalisation: return ``data * scales + means``.

    All three arguments must have the same number of dimensions, and
    ``data`` must have exactly 4 entries along axis 1.
    """
    same_rank = data.ndim == means.ndim == scales.ndim
    assert same_rank
    n_channels = data.shape[1]
    assert n_channels == 4

    rescaled = data * scales
    return rescaled + means


def undimensionalize(data, means, scales):
    """Normalise the data: return ``(data - means) / scales``.

    All three arguments must have the same number of dimensions, and
    ``data`` must have exactly 4 entries along axis 1.
    """
    same_rank = data.ndim == means.ndim == scales.ndim
    assert same_rank
    n_channels = data.shape[1]
    assert n_channels == 4

    centered = data - means
    return centered / scales


def fill_na(xs: torch.Tensor):
    """Upsample ``xs`` by a factor of 4 and re-fill the cells inside buildings.

    The input is upsampled trilinearly, every cell that the upsampled
    low-resolution building mask marks as non-zero is set to NaN, and the NaNs
    are then filled by linear interpolation (with extrapolation) along
    ``dim_1``, ``dim_2`` and ``dim_3`` in turn.

    Returns the filled data as an ``xr.DataArray``.

    NOTE(review): ``xr`` is not imported in the visible part of this notebook
    (``import xarray as xr`` is commented out in the import cell), so calling
    this function raises ``NameError`` unless xarray is imported elsewhere.
    """

    # Insert an axis at position 1 before interpolating, then drop all
    # singleton axes again.
    _xs = F.interpolate(xs[:, None], scale_factor=4, mode="trilinear").squeeze()
    # The LR building mask is upsampled with nearest neighbour so that it
    # keeps its original values.
    _is_in_build = F.interpolate(
        LR_IS_IN_BUILD[:, None], scale_factor=4, mode="nearest"
    ).squeeze()
    assert _xs.shape == _is_in_build.shape == HR_IS_IN_BUILD.shape

    # Keep values where the mask is zero; everything else becomes NaN.
    _xs = torch.where(_is_in_build == 0, _xs, torch.full_like(_xs, torch.nan))

    _xs = xr.DataArray(_xs)
    assert (np.isnan(_xs.values)).sum() > 0, "No Nan values"

    # Fill the NaNs one axis at a time.
    _xs = (
        _xs.interpolate_na(dim="dim_1", method="linear", fill_value="extrapolate")
        .interpolate_na(dim="dim_2", method="linear", fill_value="extrapolate")
        .interpolate_na(dim="dim_3", method="linear", fill_value="extrapolate")
    )

    # After the three passes no NaN may be left.
    assert (np.isnan(_xs.values)).sum() == 0

    return _xs


def read_building_height(
    building_path: str, target_col: str, margin: int = 0
) -> np.ndarray:
    """Read a topography text file and return one column as a 2-D array.

    The file has one header line followed by whitespace-separated rows with
    the columns ``i j Ez Tz Tzl``. Blank lines (for example a trailing empty
    line) are ignored.

    Parameters
    ----------
    building_path:
        Path of the text file.
    target_col:
        Column to return, one of ``"i"``, ``"j"``, ``"Ez"``, ``"Tz"``, ``"Tzl"``.
    margin:
        Number of rows / columns trimmed from every side of the result.
        ``0`` returns the full array.

    Returns
    -------
    np.ndarray
        Values of ``target_col`` with ``i`` along axis 0 and ``j`` along
        axis 1; duplicated ``(i, j)`` pairs are reduced with ``max``.
    """
    cols = ["i", "j", "Ez", "Tz", "Tzl"]

    with open(building_path, "r") as file:
        next(file, None)  # skip header
        # `split()` handles any run of blanks or tabs; `zip` keeps only the
        # first five tokens of a row, as before.
        records = [dict(zip(cols, line.split())) for line in file if line.strip()]

    # The tokens are strings: convert the grid indices to int, the rest to float.
    df_topography = pd.DataFrame(records, columns=cols).astype(
        {"i": int, "j": int, "Ez": float, "Tz": float, "Tzl": float}
    )

    ret = pd.pivot_table(
        data=df_topography[["i", "j", target_col]],
        values=target_col,
        index="i",
        columns="j",
        aggfunc="max",
    ).values

    if margin == 0:
        return ret
    return ret[margin:-margin, margin:-margin]


def calc_is_in_building(
    tz: np.ndarray, ez: np.ndarray, actual_levs: np.ndarray
) -> np.ndarray:
    """Build a 3-D mask of the grid cells that lie inside a building.

    Parameters
    ----------
    tz:
        Building height measured from the sea surface, shape ``(y, x)``.
    ez:
        Ground height measured from the sea surface, shape ``(y, x)``.
    actual_levs:
        Heights of the vertical levels, shape ``(z,)``.

    Returns
    -------
    np.ndarray
        Float array of shape ``(z, y, x)``: 1 for every level of a column, from
        the lowest one upwards, that is strictly below the building top, and 0
        elsewhere. Columns whose building height is not above the ground are
        all zero. A column whose building is taller than every level is filled
        completely (the previous ``argmin``-based version left it empty).
    """
    # tz = build height, ez = ground height, both from sea surface

    assert tz.shape == ez.shape
    assert len(tz.shape) == 2  # y and x
    assert len(actual_levs.shape) == 1  # z

    # BH lower than or equal to the ground means there is no building.
    # Written as a negation so that NaN compares exactly like `t <= e`.
    has_building = ~(tz <= ez)

    # below_top[z, y, x] is True while level z is lower than the roof ...
    below_top = actual_levs[:, None, None] < tz[None, :, :]
    # ... and only the contiguous run that starts at the lowest level counts.
    below_top = np.logical_and.accumulate(below_top, axis=0)

    is_in_building = below_top & has_building[None, :, :]
    return is_in_building.astype(np.float64)


def calc_mean_and_errors(values):
    """Return the mean over axis 0 and the distances to the extremes.

    ``values`` must be a tensor of shape ``(5, 32)``.

    Returns ``(means, errs)``: ``means`` has shape ``(32,)`` and ``errs`` has
    shape ``(2, 32)``, where ``errs[0]`` is ``mean - min`` and ``errs[1]`` is
    ``max - mean`` (both taken over axis 0).
    """
    assert values.shape == (5, 32)

    means = values.mean(dim=0)  # average over batch dim
    lower = means - values.min(dim=0).values
    upper = values.max(dim=0).values - means
    errs = torch.stack([lower, upper])

    logger.info(f"mean shape = {means.shape}")
    logger.info(f"error shape = {errs.shape}")
    return means, errs

# Horizontal sections

## Case: Z0

In [None]:
# Select the experiment / config and the test sample to visualise.
experiment_name = EXPERIMENT_NAME
config_name = "z0_2475_825"
index_sample = 300

#################

# Stored super-resolution inferences, sorted so that they line up with the
# order of the test dataset (verified by the file-name assertion below).
inference_dir = DL_INFERENCE_DIR / experiment_name / config_name
inference_paths = sorted(glob.glob(str(inference_dir / "*.npy")))
assert index_sample < len(inference_paths), (
    f"index_sample={index_sample} is out of range: "
    f"{len(inference_paths)} inference files found in {inference_dir}"
)

config = CONFIGS[config_name]["config"]

# Silence INFO messages while the dataloader is built; the level is restored
# even if the construction fails.
logger.setLevel(WARNING)
try:
    test_loader = make_evaluation_dataloader_without_random_cropping(
        config, DL_DATA_DIR, batch_size=1
    )
finally:
    logger.setLevel(INFO)
test_dataset = test_loader.dataset

# Per-channel normalisation constants, broadcastable against 5-D tensors
# that carry the channel on axis 1.
means = torch.Tensor(config["data"]["means"])[None, :, None, None, None]
scales = torch.Tensor(config["data"]["stds"])[None, :, None, None, None]

# The inference file must belong to the selected LR input file.
assert os.path.basename(inference_paths[index_sample]) == os.path.basename(
    test_dataset.lr_files[index_sample]
).replace("LR", "SR")


Xs, bs, ys = test_dataset[index_sample]

# Back to physical units; the temporary leading axis is squeezed out again.
Xs_scaled = dimensionalize(Xs[None, ...], means, scales).squeeze()
bs = bs[None, ...]
ys_scaled = dimensionalize(ys[None, ...], means, scales).squeeze()

preds_scaled = torch.from_numpy(np.load(inference_paths[index_sample])).squeeze()

In [None]:
# Colormap used for each variable.
dict_cmap = {
    "tm": "turbo",
    "vl": "YlOrRd",
    "vp": "YlGnBu",
    "vr": "viridis",
}

# Window (array indices) cut out of each horizontal plane; the assertion
# keeps the window square.
sx, ex = 170, 230
sy, ey = 70, 130
assert ey - sy == ex - sx

# True: the color limits are the 5%-95% quantiles of the HR panel of each row.
# False: the fixed limits in dict_vmin / dict_vmax are used instead.
use_data_vim_vmax = True
dict_vmin = {"tm": 32.00, "vl": -4.0, "vp": -3.0, "vr": -1.0}
dict_vmax = {"tm": 34.00, "vl": 2.0, "vp": 3.0, "vr": 1.0}

tex_labels = {
    "tm": r"$T$",
    "vl": r"$u$",
    "vp": r"$v$",
    "vr": r"$w$",
}

# Upsampling factors that bring each dataset onto the HR grid.
lr_scale_factor = 4
hr_scale_factor = 1

var_names = {0: "tm", 1: "vl", 2: "vp", 3: "vr"}
# Raw string: "\c" inside a normal string literal is an invalid escape.
var_units = {0: r"[$^\circ$C]", 1: "[m/s]", 2: "[m/s]", 3: "[m/s]"}

# One figure column per entry, in this order.
var_dataset = OrderedDict(
    {
        "HR": ys_scaled,
        "LR": Xs_scaled,
        "SR": preds_scaled,
        # "HR-SR": ys_scaled - preds_scaled,
    }
)

ttl_labels = {
    "HR": "Ground Truth",
    "LR": "LR Input",
    "SR": "SR Inference",
}


for idx_height in [0, 8]:

    # Building outlines are overlaid on every level except the lowest one.
    enhance_edges = idx_height != 0

    plt.rcParams["font.size"] = 18
    plt.rcParams["font.family"] = "Times New Roman"
    fig, axes = plt.subplots(nrows=4, ncols=len(var_dataset), figsize=[8, 10])
    Gx, Gy = None, None

    # Rows are the variables, columns are the entries of var_dataset.
    for i, idx_var in enumerate([0, 1, 2, 3]):
        var_name = var_names[idx_var]
        var_unit = var_units[idx_var]

        for j, var_kind in enumerate(var_dataset.keys()):
            ax = axes[i, j]
            ax.set_aspect("equal")
            ax.xaxis.set_ticklabels([])
            ax.yaxis.set_ticklabels([])
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)

            # HR and SR are already on the HR grid; LR is upsampled by 4.
            if "HR" in var_kind or "SR" in var_kind:
                scale_factor = hr_scale_factor
                assert scale_factor == 1
                in_build = HR_IS_IN_BUILD[:, :32, :, :]
            else:
                scale_factor = lr_scale_factor
                assert scale_factor == 4
                in_build = LR_IS_IN_BUILD[:, :8, :, :]

            is_in_bldg = in_build.clone().unsqueeze(1)
            # Non-zero outside buildings. `.to()` replaces `torch.tensor(tensor)`,
            # which emits a copy-construct UserWarning.
            bldg = (1 - in_build).to(torch.float32).unsqueeze(1)

            bldg = (
                F.interpolate(bldg, scale_factor=scale_factor, mode="nearest")
                .squeeze()
                .numpy()
            )
            is_in_bldg = (
                F.interpolate(
                    is_in_bldg, scale_factor=scale_factor, mode="nearest"
                )
                .squeeze()
                .numpy()
            )

            assert bldg.shape == is_in_bldg.shape == (4, 32, 320, 320)

            bldg = bldg[0, idx_height].transpose()[sx:ex, sy:ey]
            is_in_bldg = is_in_bldg[0, idx_height].transpose()[sx:ex, sy:ey]

            if enhance_edges:
                # Non-building cells whose 3 x 3 neighbourhood touches a
                # building are set to 1; every other cell is NaN (not drawn).
                edge = signal.convolve2d(is_in_bldg, np.ones((3, 3)), "same")
                edge = np.where((edge > 0) & (bldg > 0), 1.0, np.nan)

            source = var_dataset[var_kind]
            if var_kind == "LR":
                assert source.shape == (4, 8, 80, 80)
            else:
                assert source.shape == (4, 32, 320, 320)

            data = F.interpolate(
                source[None, ...], scale_factor=scale_factor, mode="nearest"
            ).squeeze()
            assert data.shape == (4, 32, 320, 320)

            data = data[idx_var, idx_height].numpy().transpose()[sx:ex, sy:ey]

            # Blank out the cells inside buildings.
            data = np.where(bldg, data, np.nan)

            if Gx is None or Gy is None:
                # Plot coordinates with a step of 5 per grid cell
                # (presumably metres - TODO confirm).
                Gx = 5 * np.arange(data.shape[0])
                Gy = 5 * np.arange(data.shape[1])
                Gx, Gy = np.meshgrid(Gx, Gy, indexing="ij")

            if var_kind != "HR-SR" and var_name == "tm":
                data -= 273.15  # the panel is labelled in degrees Celsius

            if var_kind == "HR":
                # The HR panel comes first and fixes the color limits that
                # the other panels of the row reuse.
                vmin = np.nanquantile(data.flatten(), 0.05)
                vmax = np.nanquantile(data.flatten(), 0.95)

            if var_kind == "HR-SR":
                # Difference panel: symmetric limits and a diverging colormap.
                vmin = np.nanquantile(data.flatten(), 0.05)
                vmax = np.nanquantile(data.flatten(), 0.95)
                abs_max = max([np.abs(vmin), np.abs(vmax)])
                vmax = abs_max
                vmin = -abs_max
                # plt.get_cmap: matplotlib.cm.get_cmap was removed in 3.9.
                my_cmap = plt.get_cmap("seismic").copy()
                my_cmap.set_bad("white")
                ttl = f"{var_kind} {var_unit}"
            else:
                my_cmap = plt.get_cmap(dict_cmap[var_name]).copy()
                my_cmap.set_bad("whitesmoke")

                ttl_label = ttl_labels[var_kind]
                ttl = f"{ttl_label}\n{tex_labels[var_name]} {var_unit}"

            if not use_data_vim_vmax:
                vmin = dict_vmin[var_name]
                vmax = dict_vmax[var_name]

            contours = ax.pcolormesh(Gx, Gy, data, cmap=my_cmap, vmin=vmin, vmax=vmax)
            fig.colorbar(
                contours,
                ax=ax,
                format="%.1f",
                ticks=np.linspace(vmin, vmax, 4, endpoint=True),
                extend="both",
            )
            ax.set_title(ttl)

            if enhance_edges:
                ax.pcolormesh(
                    Gx, Gy, edge, cmap="binary", vmin=0.9, vmax=1.01, alpha=0.2
                )

    # The height index zero of the original HR array is located at 17.5 m
    # The bottom of this cell is located at 15.0 m
    height_meter = 5 * (idx_height) + 17.5

    header = "(a)" if idx_height == 0 else "(b)"

    fig.suptitle(f"{header} Height = {height_meter - 14.5} [m]")
    fig.tight_layout()

    # fig.savefig(
    #     f"{FIG_DIR}/SR_fields_{int(height_meter):03}m_{config_name}.jpg",
    #     dpi=300,
    # )
    # if MAKE_EPS_FILES:
    #     fig.savefig(
    #         f"{FIG_DIR}/SR_fields_{int(height_meter):03}m_{config_name}.eps",
    #         dpi=300,
    #     )
    plt.show()

In [None]:
# Alternative colormaps for the second rendering of the same fields.
dict_cmap = {
    "tm": "turbo",
    "vl": "Spectral",
    "vp": "inferno",
    "vr": "coolwarm",
}

# Window (array indices) cut out of each horizontal plane; the assertion
# keeps the window square.
sx, ex = 170, 230
sy, ey = 70, 130
assert ey - sy == ex - sx

# True: the color limits are the 5%-95% quantiles of the HR panel of each row.
# False: the fixed limits in dict_vmin / dict_vmax are used instead.
use_data_vim_vmax = True
dict_vmin = {"tm": 32.00, "vl": -4.0, "vp": -3.0, "vr": -1.0}
dict_vmax = {"tm": 34.00, "vl": 2.0, "vp": 3.0, "vr": 1.0}

tex_labels = {
    "tm": r"$T$",
    "vl": r"$u$",
    "vp": r"$v$",
    "vr": r"$w$",
}

# Upsampling factors that bring each dataset onto the HR grid.
lr_scale_factor = 4
hr_scale_factor = 1

var_names = {0: "tm", 1: "vl", 2: "vp", 3: "vr"}
# Raw string: "\c" inside a normal string literal is an invalid escape.
var_units = {0: r"[$^\circ$C]", 1: "[m/s]", 2: "[m/s]", 3: "[m/s]"}

# One figure column per entry, in this order.
var_dataset = OrderedDict(
    {
        "HR": ys_scaled,
        "LR": Xs_scaled,
        "SR": preds_scaled,
        # "HR-SR": ys_scaled - preds_scaled,
    }
)

ttl_labels = {
    "HR": "Ground Truth",
    "LR": "LR Input",
    "SR": "SR Inference",
}


for idx_height in [0, 8]:

    # In this variant the building outlines are overlaid on every level.
    enhance_edges = True  # False if idx_height == 0 else True

    plt.rcParams["font.size"] = 18
    plt.rcParams["font.family"] = "Times New Roman"
    fig, axes = plt.subplots(nrows=4, ncols=len(var_dataset), figsize=[8, 10])
    Gx, Gy = None, None

    # Rows are the variables, columns are the entries of var_dataset.
    for i, idx_var in enumerate([0, 1, 2, 3]):
        var_name = var_names[idx_var]
        var_unit = var_units[idx_var]

        for j, var_kind in enumerate(var_dataset.keys()):
            ax = axes[i, j]
            ax.set_aspect("equal")
            ax.xaxis.set_ticklabels([])
            ax.yaxis.set_ticklabels([])
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)

            # HR and SR are already on the HR grid; LR is upsampled by 4.
            if "HR" in var_kind or "SR" in var_kind:
                scale_factor = hr_scale_factor
                assert scale_factor == 1
                in_build = HR_IS_IN_BUILD[:, :32, :, :]
            else:
                scale_factor = lr_scale_factor
                assert scale_factor == 4
                in_build = LR_IS_IN_BUILD[:, :8, :, :]

            is_in_bldg = in_build.clone().unsqueeze(1)
            # Non-zero outside buildings. `.to()` replaces `torch.tensor(tensor)`,
            # which emits a copy-construct UserWarning.
            bldg = (1 - in_build).to(torch.float32).unsqueeze(1)

            bldg = (
                F.interpolate(bldg, scale_factor=scale_factor, mode="nearest")
                .squeeze()
                .numpy()
            )
            is_in_bldg = (
                F.interpolate(
                    is_in_bldg, scale_factor=scale_factor, mode="nearest"
                )
                .squeeze()
                .numpy()
            )

            assert bldg.shape == is_in_bldg.shape == (4, 32, 320, 320)

            bldg = bldg[0, idx_height].transpose()[sx:ex, sy:ey]
            is_in_bldg = is_in_bldg[0, idx_height].transpose()[sx:ex, sy:ey]

            if enhance_edges:
                # Non-building cells whose 3 x 3 neighbourhood touches a
                # building are set to 1; every other cell is NaN (not drawn).
                edge = signal.convolve2d(is_in_bldg, np.ones((3, 3)), "same")
                edge = np.where((edge > 0) & (bldg > 0), 1.0, np.nan)

            source = var_dataset[var_kind]
            if var_kind == "LR":
                assert source.shape == (4, 8, 80, 80)
            else:
                assert source.shape == (4, 32, 320, 320)

            data = F.interpolate(
                source[None, ...], scale_factor=scale_factor, mode="nearest"
            ).squeeze()
            assert data.shape == (4, 32, 320, 320)

            data = data[idx_var, idx_height].numpy().transpose()[sx:ex, sy:ey]

            # Blank out the cells inside buildings.
            data = np.where(bldg, data, np.nan)

            if Gx is None or Gy is None:
                # Plot coordinates with a step of 5 per grid cell
                # (presumably metres - TODO confirm).
                Gx = 5 * np.arange(data.shape[0])
                Gy = 5 * np.arange(data.shape[1])
                Gx, Gy = np.meshgrid(Gx, Gy, indexing="ij")

            if var_kind != "HR-SR" and var_name == "tm":
                data -= 273.15  # the panel is labelled in degrees Celsius

            if var_kind == "HR":
                # The HR panel comes first and fixes the color limits that
                # the other panels of the row reuse.
                vmin = np.nanquantile(data.flatten(), 0.05)
                vmax = np.nanquantile(data.flatten(), 0.95)

            if var_kind == "HR-SR":
                # Difference panel: symmetric limits and a diverging colormap.
                vmin = np.nanquantile(data.flatten(), 0.05)
                vmax = np.nanquantile(data.flatten(), 0.95)
                abs_max = max([np.abs(vmin), np.abs(vmax)])
                vmax = abs_max
                vmin = -abs_max
                # plt.get_cmap: matplotlib.cm.get_cmap was removed in 3.9.
                my_cmap = plt.get_cmap("seismic").copy()
                my_cmap.set_bad("white")
                ttl = f"{var_kind} {var_unit}"
            else:
                my_cmap = plt.get_cmap(dict_cmap[var_name]).copy()
                my_cmap.set_bad("whitesmoke")

                ttl_label = ttl_labels[var_kind]
                ttl = f"{ttl_label}\n{tex_labels[var_name]} {var_unit}"

            if not use_data_vim_vmax:
                vmin = dict_vmin[var_name]
                vmax = dict_vmax[var_name]

            contours = ax.pcolormesh(Gx, Gy, data, cmap=my_cmap, vmin=vmin, vmax=vmax)
            fig.colorbar(
                contours,
                ax=ax,
                format="%.1f",
                ticks=np.linspace(vmin, vmax, 4, endpoint=True),
                extend="both",
            )
            ax.set_title(ttl)

            if enhance_edges:
                ax.pcolormesh(
                    Gx, Gy, edge, cmap="binary", vmin=0.5, vmax=1.0, alpha=0.4
                )

    # The height index zero of the original HR array is located at 17.5 m
    # The bottom of this cell is located at 15.0 m
    height_meter = 5 * (idx_height) + 17.5

    header = "(a)" if idx_height == 0 else "(b)"

    fig.suptitle(f"{header} Height = {height_meter - 14.5} [m]")
    fig.tight_layout()

    # fig.savefig(
    #     f"{FIG_DIR}/SR_fields_{int(height_meter):03}m_{config_name}_news_paper.jpg",
    #     dpi=300,
    # )
    # if MAKE_EPS_FILES:
    #     fig.savefig(
    #         f"{FIG_DIR}/SR_fields_{int(height_meter):03}m_{config_name}_news_paper.eps",
    #         dpi=300,
    #     )
    plt.show()

## Case: Z2

In [None]:
# Select the experiment / config and the test sample to visualise.
experiment_name = EXPERIMENT_NAME
config_name = "z2_2475_825"
index_sample = 300

#################

# Stored super-resolution inferences, sorted so that they line up with the
# order of the test dataset (verified by the file-name assertion below).
inference_dir = DL_INFERENCE_DIR / experiment_name / config_name
inference_paths = sorted(glob.glob(str(inference_dir / "*.npy")))
assert index_sample < len(inference_paths), (
    f"index_sample={index_sample} is out of range: "
    f"{len(inference_paths)} inference files found in {inference_dir}"
)

config = CONFIGS[config_name]["config"]

# Silence INFO messages while the dataloader is built; the level is restored
# even if the construction fails.
logger.setLevel(WARNING)
try:
    test_loader = make_evaluation_dataloader_without_random_cropping(
        config, DL_DATA_DIR, batch_size=1
    )
finally:
    logger.setLevel(INFO)
test_dataset = test_loader.dataset

# Per-channel normalisation constants, broadcastable against 5-D tensors
# that carry the channel on axis 1.
means = torch.Tensor(config["data"]["means"])[None, :, None, None, None]
scales = torch.Tensor(config["data"]["stds"])[None, :, None, None, None]

# The inference file must belong to the selected LR input file.
assert os.path.basename(inference_paths[index_sample]) == os.path.basename(
    test_dataset.lr_files[index_sample]
).replace("LR", "SR")


Xs, bs, ys = test_dataset[index_sample]

# Back to physical units; the temporary leading axis is squeezed out again.
Xs_scaled = dimensionalize(Xs[None, ...], means, scales).squeeze()
bs = bs[None, ...]
ys_scaled = dimensionalize(ys[None, ...], means, scales).squeeze()

preds_scaled = torch.from_numpy(np.load(inference_paths[index_sample])).squeeze()

In [None]:
# Colormap used for each variable.
dict_cmap = {
    "tm": "turbo",
    "vl": "YlOrRd",
    "vp": "YlGnBu",
    "vr": "viridis",
}

# Window (array indices) cut out of each horizontal plane; the assertion
# keeps the window square.
sx, ex = 170, 230
sy, ey = 70, 130
assert ey - sy == ex - sx

# True: the color limits are the 5%-95% quantiles of the HR panel of each row.
# False: the fixed limits in dict_vmin / dict_vmax are used instead.
use_data_vim_vmax = True
dict_vmin = {"tm": 32.00, "vl": -4.0, "vp": -3.0, "vr": -1.0}
dict_vmax = {"tm": 34.00, "vl": 2.0, "vp": 3.0, "vr": 1.0}

tex_labels = {
    "tm": r"$T$",
    "vl": r"$u$",
    "vp": r"$v$",
    "vr": r"$w$",
}

# Upsampling factors that bring each dataset onto the HR grid.
lr_scale_factor = 4
hr_scale_factor = 1

var_names = {0: "tm", 1: "vl", 2: "vp", 3: "vr"}
# Raw string: "\c" inside a normal string literal is an invalid escape.
var_units = {0: r"[$^\circ$C]", 1: "[m/s]", 2: "[m/s]", 3: "[m/s]"}

# One figure column per entry, in this order.
var_dataset = OrderedDict(
    {
        "HR": ys_scaled,
        "LR": Xs_scaled,
        "SR": preds_scaled,
        # "HR-SR": ys_scaled - preds_scaled,
    }
)

ttl_labels = {
    "HR": "Ground Truth",
    "LR": "LR Input",
    "SR": "SR Inference",
}


for idx_height in [0, 8]:

    # Building outlines are overlaid on every level except the lowest one.
    enhance_edges = idx_height != 0

    plt.rcParams["font.size"] = 18
    plt.rcParams["font.family"] = "Times New Roman"
    fig, axes = plt.subplots(nrows=4, ncols=len(var_dataset), figsize=[8, 10])
    Gx, Gy = None, None

    # Rows are the variables, columns are the entries of var_dataset.
    for i, idx_var in enumerate([0, 1, 2, 3]):
        var_name = var_names[idx_var]
        var_unit = var_units[idx_var]

        for j, var_kind in enumerate(var_dataset.keys()):
            ax = axes[i, j]
            ax.set_aspect("equal")
            ax.xaxis.set_ticklabels([])
            ax.yaxis.set_ticklabels([])
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)

            # HR and SR are already on the HR grid; LR is upsampled by 4.
            if "HR" in var_kind or "SR" in var_kind:
                scale_factor = hr_scale_factor
                assert scale_factor == 1
                in_build = HR_IS_IN_BUILD[:, :32, :, :]
            else:
                scale_factor = lr_scale_factor
                assert scale_factor == 4
                in_build = LR_IS_IN_BUILD[:, :8, :, :]

            is_in_bldg = in_build.clone().unsqueeze(1)
            # Non-zero outside buildings. `.to()` replaces `torch.tensor(tensor)`,
            # which emits a copy-construct UserWarning.
            bldg = (1 - in_build).to(torch.float32).unsqueeze(1)

            bldg = (
                F.interpolate(bldg, scale_factor=scale_factor, mode="nearest")
                .squeeze()
                .numpy()
            )
            is_in_bldg = (
                F.interpolate(
                    is_in_bldg, scale_factor=scale_factor, mode="nearest"
                )
                .squeeze()
                .numpy()
            )

            assert bldg.shape == is_in_bldg.shape == (4, 32, 320, 320)

            bldg = bldg[0, idx_height].transpose()[sx:ex, sy:ey]
            is_in_bldg = is_in_bldg[0, idx_height].transpose()[sx:ex, sy:ey]

            if enhance_edges:
                # Non-building cells whose 3 x 3 neighbourhood touches a
                # building are set to 1; every other cell is NaN (not drawn).
                edge = signal.convolve2d(is_in_bldg, np.ones((3, 3)), "same")
                edge = np.where((edge > 0) & (bldg > 0), 1.0, np.nan)

            source = var_dataset[var_kind]
            if var_kind == "LR":
                assert source.shape == (4, 8, 80, 80)
            else:
                assert source.shape == (4, 32, 320, 320)

            data = F.interpolate(
                source[None, ...], scale_factor=scale_factor, mode="nearest"
            ).squeeze()
            assert data.shape == (4, 32, 320, 320)

            data = data[idx_var, idx_height].numpy().transpose()[sx:ex, sy:ey]

            # Blank out the cells inside buildings.
            data = np.where(bldg, data, np.nan)

            if Gx is None or Gy is None:
                # Plot coordinates with a step of 5 per grid cell
                # (presumably metres - TODO confirm).
                Gx = 5 * np.arange(data.shape[0])
                Gy = 5 * np.arange(data.shape[1])
                Gx, Gy = np.meshgrid(Gx, Gy, indexing="ij")

            if var_kind != "HR-SR" and var_name == "tm":
                data -= 273.15  # the panel is labelled in degrees Celsius

            if var_kind == "HR":
                # The HR panel comes first and fixes the color limits that
                # the other panels of the row reuse.
                vmin = np.nanquantile(data.flatten(), 0.05)
                vmax = np.nanquantile(data.flatten(), 0.95)

            if var_kind == "HR-SR":
                # Difference panel: symmetric limits and a diverging colormap.
                vmin = np.nanquantile(data.flatten(), 0.05)
                vmax = np.nanquantile(data.flatten(), 0.95)
                abs_max = max([np.abs(vmin), np.abs(vmax)])
                vmax = abs_max
                vmin = -abs_max
                # plt.get_cmap: matplotlib.cm.get_cmap was removed in 3.9.
                my_cmap = plt.get_cmap("seismic").copy()
                my_cmap.set_bad("white")
                ttl = f"{var_kind} {var_unit}"
            else:
                my_cmap = plt.get_cmap(dict_cmap[var_name]).copy()
                my_cmap.set_bad("whitesmoke")

                ttl_label = ttl_labels[var_kind]
                ttl = f"{ttl_label}\n{tex_labels[var_name]} {var_unit}"

            if not use_data_vim_vmax:
                vmin = dict_vmin[var_name]
                vmax = dict_vmax[var_name]

            contours = ax.pcolormesh(Gx, Gy, data, cmap=my_cmap, vmin=vmin, vmax=vmax)
            fig.colorbar(
                contours,
                ax=ax,
                format="%.1f",
                ticks=np.linspace(vmin, vmax, 4, endpoint=True),
                extend="both",
            )
            ax.set_title(ttl)

            if enhance_edges:
                ax.pcolormesh(
                    Gx, Gy, edge, cmap="binary", vmin=0.9, vmax=1.01, alpha=0.2
                )

    # The height index zero of the original HR array is located at 17.5 m
    # The bottom of this cell is located at 15.0 m
    height_meter = 5 * (idx_height) + 17.5

    header = "(a)" if idx_height == 0 else "(b)"

    fig.suptitle(f"{header} Height = {height_meter - 14.5} [m]")
    fig.tight_layout()

    # fig.savefig(
    #     f"{FIG_DIR}/SR_fields_{int(height_meter):03}m_{config_name}.jpg",
    #     dpi=300,
    # )
    # if MAKE_EPS_FILES:
    #     fig.savefig(
    #         f"{FIG_DIR}/SR_fields_{int(height_meter):03}m_{config_name}.eps",
    #         dpi=300,
    #     )
    plt.show()

# Vertical sections

## Case: Z0

In [None]:
# Select the experiment / config and the test sample to visualise.
experiment_name = EXPERIMENT_NAME
config_name = "z0_2475_825"
index_sample = 300

#################

# Stored super-resolution inferences, sorted so that they line up with the
# order of the test dataset (verified by the file-name assertion below).
inference_dir = DL_INFERENCE_DIR / experiment_name / config_name
inference_paths = sorted(glob.glob(str(inference_dir / "*.npy")))
assert index_sample < len(inference_paths), (
    f"index_sample={index_sample} is out of range: "
    f"{len(inference_paths)} inference files found in {inference_dir}"
)

config = CONFIGS[config_name]["config"]

# Silence INFO messages while the dataloader is built; the level is restored
# even if the construction fails.
logger.setLevel(WARNING)
try:
    test_loader = make_evaluation_dataloader_without_random_cropping(
        config, DL_DATA_DIR, batch_size=1
    )
finally:
    logger.setLevel(INFO)
test_dataset = test_loader.dataset

# Per-channel normalisation constants, broadcastable against 5-D tensors
# that carry the channel on axis 1.
means = torch.Tensor(config["data"]["means"])[None, :, None, None, None]
scales = torch.Tensor(config["data"]["stds"])[None, :, None, None, None]

# The inference file must belong to the selected LR input file.
assert os.path.basename(inference_paths[index_sample]) == os.path.basename(
    test_dataset.lr_files[index_sample]
).replace("LR", "SR")


Xs, bs, ys = test_dataset[index_sample]

# Back to physical units; the temporary leading axis is squeezed out again.
Xs_scaled = dimensionalize(Xs[None, ...], means, scales).squeeze()
bs = bs[None, ...]
ys_scaled = dimensionalize(ys[None, ...], means, scales).squeeze()

preds_scaled = torch.from_numpy(np.load(inference_paths[index_sample])).squeeze()

In [None]:
# Colormap used for each variable.
dict_cmap = {
    "tm": "turbo",
    "vl": "YlOrRd",
    "vp": "YlGnBu",
    "vr": "viridis",
}

# Window (array indices) cut out of each vertical section: sx:ex along the
# horizontal axis, sz:ez along the vertical axis.
sx, ex = 136, 180
sz, ez = 0, 12

# True: the color limits are the 5%-95% quantiles of the HR panel of each row.
# False: the fixed limits in dict_vmin / dict_vmax are used instead.
use_data_vim_vmax = True
dict_vmin = {"tm": 32.00, "vl": -4.0, "vp": -3.0, "vr": -1.0}
dict_vmax = {"tm": 34.00, "vl": 2.0, "vp": 3.0, "vr": 1.0}

tex_labels = {
    "tm": r"$T$",
    "vl": r"$u$",
    "vp": r"$v$",
    "vr": r"$w$",
}

# Upsampling factors that bring each dataset onto the HR grid.
lr_scale_factor = 4
hr_scale_factor = 1

var_names = {0: "tm", 1: "vl", 2: "vp", 3: "vr"}
# Raw string: "\c" inside a normal string literal is an invalid escape.
var_units = {0: r"[$^\circ$C]", 1: "[m/s]", 2: "[m/s]", 3: "[m/s]"}

# One figure column per entry, in this order.
var_dataset = OrderedDict(
    {
        "HR": ys_scaled,
        "LR": Xs_scaled,
        "SR": preds_scaled,
        # "HR-SR": ys_scaled - preds_scaled,
    }
)

ttl_labels = {
    "HR": "Ground Truth",
    "LR": "LR Input",
    "SR": "SR Inference",
}


# idx_lat selects the index along the third array axis at which the
# vertical section is taken.
for idx_lat in [160]:

    plt.rcParams["font.size"] = 18
    plt.rcParams["font.family"] = "Times New Roman"
    fig, axes = plt.subplots(nrows=4, ncols=len(var_dataset), figsize=[20, 9])
    Gx, Gz = None, None

    # Rows are the variables, columns are the entries of var_dataset.
    for i, idx_var in enumerate([0, 1, 2, 3]):
        var_name = var_names[idx_var]
        var_unit = var_units[idx_var]

        for j, var_kind in enumerate(var_dataset.keys()):
            ax = axes[i, j]
            ax.set_aspect("equal")
            ax.xaxis.set_ticklabels([])
            ax.yaxis.set_ticklabels([])
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)

            # HR and SR are already on the HR grid; LR is upsampled by 4.
            if "HR" in var_kind or "SR" in var_kind:
                scale_factor = hr_scale_factor
                assert scale_factor == 1
                in_build = HR_IS_IN_BUILD[:, :32, :, :]
            else:
                scale_factor = lr_scale_factor
                assert scale_factor == 4
                in_build = LR_IS_IN_BUILD[:, :8, :, :]

            # Non-zero outside buildings. `.to()` replaces `torch.tensor(tensor)`,
            # which emits a copy-construct UserWarning. The separate
            # "inside building" mask is not needed here because no building
            # outline is drawn on the vertical sections.
            bldg = (1 - in_build).to(torch.float32).unsqueeze(1)
            bldg = (
                F.interpolate(bldg, scale_factor=scale_factor, mode="nearest")
                .squeeze()
                .numpy()
            )
            assert bldg.shape == (4, 32, 320, 320)

            bldg = bldg[0, :, idx_lat].transpose()[sx:ex, sz:ez]

            source = var_dataset[var_kind]
            if var_kind == "LR":
                assert source.shape == (4, 8, 80, 80)
            else:
                assert source.shape == (4, 32, 320, 320)

            data = F.interpolate(
                source[None, ...], scale_factor=scale_factor, mode="nearest"
            ).squeeze()
            assert data.shape == (4, 32, 320, 320)

            data = data[idx_var, :, idx_lat].numpy().transpose()[sx:ex, sz:ez]

            # Blank out the cells inside buildings.
            data = np.where(bldg, data, np.nan)

            if Gx is None or Gz is None:
                # Plot coordinates with a step of 5 per grid cell
                # (presumably metres - TODO confirm).
                Gx = 5 * np.arange(data.shape[0])
                Gz = 5 * np.arange(data.shape[1])
                Gx, Gz = np.meshgrid(Gx, Gz, indexing="ij")

            if var_kind != "HR-SR" and var_name == "tm":
                data -= 273.15  # the panel is labelled in degrees Celsius

            if var_kind == "HR":
                # The HR panel comes first and fixes the color limits that
                # the other panels of the row reuse.
                vmin = np.nanquantile(data.flatten(), 0.05)
                vmax = np.nanquantile(data.flatten(), 0.95)

            if var_kind == "HR-SR":
                # Difference panel: symmetric limits and a diverging colormap.
                vmin = np.nanquantile(data.flatten(), 0.05)
                vmax = np.nanquantile(data.flatten(), 0.95)
                abs_max = max([np.abs(vmin), np.abs(vmax)])
                vmax = abs_max
                vmin = -abs_max
                # plt.get_cmap: matplotlib.cm.get_cmap was removed in 3.9.
                my_cmap = plt.get_cmap("seismic").copy()
                my_cmap.set_bad("white")
                ttl = f"{var_kind} {var_unit}"
            else:
                my_cmap = plt.get_cmap(dict_cmap[var_name]).copy()
                my_cmap.set_bad("lightgray")

                ttl_label = ttl_labels[var_kind]
                ttl = f"{ttl_label}\n{tex_labels[var_name]} {var_unit}"

            if not use_data_vim_vmax:
                vmin = dict_vmin[var_name]
                vmax = dict_vmax[var_name]

            contours = ax.pcolormesh(Gx, Gz, data, cmap=my_cmap, vmin=vmin, vmax=vmax)
            fig.colorbar(
                contours,
                ax=ax,
                format="%.1f",
                ticks=np.linspace(vmin, vmax, 4, endpoint=True),
                extend="both",
            )
            ax.set_title(ttl)

    # fig.suptitle(f"idx_lat = {idx_lat}")
    fig.tight_layout()

    # NOTE(review): the commented-out save below still names the file after
    # `height_meter`, which this cell does not define (it would reuse the
    # value left over from a horizontal-section cell). Pick a file name based
    # on `idx_lat` before re-enabling it.
    # fig.savefig(
    #     f"{FIG_DIR}/SR_fields_{int(height_meter):03}m_{config_name}.jpg",
    #     dpi=300,
    # )
    # if MAKE_EPS_FILES:
    #     fig.savefig(
    #         f"{FIG_DIR}/SR_fields_{int(height_meter):03}m_{config_name}.eps",
    #         dpi=300,
    #     )
    plt.show()