In [None]:
# Reload imported modules automatically before each cell runs, so edits to the
# project's source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Draw matplotlib figures directly in the notebook output.
%matplotlib inline

In [None]:
import sys
from logging import INFO, WARNING, StreamHandler, getLogger

# Root logger: messages at INFO and above are emitted.
logger = getLogger()
logger.setLevel(INFO)

# Attach a stdout handler only once, so re-running this cell does not duplicate output.
if not logger.hasHandlers():
    logger.addHandler(StreamHandler(sys.stdout))

# Import libraries

In [None]:
import copy
import datetime
import glob
import os
import pathlib
import time
from collections import OrderedDict

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import yaml
from src.data.dataloader import make_dataloaders_and_samplers
from src.models.evaluation_helper import (
    AveSsimLoss,
    TemperatureErrorNorm,
    VelocityComponentErrorNorm,
    VelocityErrorNorm,
    evaluate,
)
from src.models.model_maker import make_model
from src.utils.io_pickle import read_pickle
from src.utils.random_seed_helper import set_seeds
from tqdm.notebook import tqdm

# Use a serif font family for all figures created after this point.
plt.rcParams["font.family"] = "serif"

In [None]:
# Seed the random number generators through the project helper.
# NOTE(review): presumably covers Python, NumPy and PyTorch; verify in
# src.utils.random_seed_helper.
set_seeds(42)
# NOTE(review): this variable is read by cuBLAS, so it presumably has to be set before
# the first CUDA call of the session to have any effect; confirm set_seeds does not
# already initialize CUDA.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = r":4096:8"  # to make calculations deterministic

# Define constants

In [None]:
# Project root: the parent directory of the path stored in PYTHONPATH. Config and
# model paths below are built from it.
# NOTE(review): raises KeyError when PYTHONPATH is unset and assumes it holds a single
# directory rather than an os.pathsep-separated list; confirm for the run environment.
ROOT_DIR = str((pathlib.Path(os.environ["PYTHONPATH"]) / "..").resolve())

In [None]:
# Device onto which models are moved, weights are mapped and inputs are sent.
DEVICE = "cuda"

In [None]:
# Scratch directory relative to the current working directory; created if missing.
TMP_DATA_DIR = "./tmp"
os.makedirs(TMP_DATA_DIR, exist_ok=True)

In [None]:
# Experiment whose YAML configs are analysed in this notebook.
EXPERIMENT_NAME = "lr-inference"
CONFIG_DIR = f"{ROOT_DIR}/python/configs/{EXPERIMENT_NAME}"
# All "*.yml" files of the experiment, in lexicographic order.
CONFIG_PATHS = sorted(glob.glob(f"{CONFIG_DIR}/*.yml"))

In [None]:
# Maps each config name (file name up to its first ".") to the parsed YAML config and
# to the paths of the artifacts stored under data/models/<experiment>/<config name>.
CONFIGS = OrderedDict()

for config_path in CONFIG_PATHS:
    # Config files with "test" in their file name are not analysed.
    if "test" in os.path.basename(config_path):
        continue

    with open(config_path) as file:
        config = yaml.safe_load(file)

    config_name = os.path.basename(config_path).split(".")[0]
    assert config_name not in CONFIGS, f"Duplicated config name: {config_name}"

    # The experiment name is the name of the directory that contains the config file.
    # os.path is used instead of splitting on "/" because glob joins the file name
    # with the OS-specific separator, which is not "/" on every platform.
    experiment_name = os.path.basename(os.path.dirname(config_path))
    _dir = f"{ROOT_DIR}/data/models/{experiment_name}/{config_name}"

    CONFIGS[config_name] = {
        "config": config,
        "experiment_name": experiment_name,
        "weight_path": f"{_dir}/model_weight.pth",
        "learning_history_path": f"{_dir}/model_loss_history.csv",
        "log_path": f"{_dir}/log.txt",
        "test_score_path": f"{_dir}/test_errors.pickle",
    }

# Define methods

In [None]:
def get_dict_loss_fns(config: dict):
    """Create the named metrics that are passed to `evaluate` as `loss_fns`.

    Args:
        config: Experiment config. `config["data"]["scales"]` must be indexable with
            at least four entries: index 0 is handed to the temperature metric and
            indices 1-3 to the metrics of channels 1-3 (labelled u, v and w).

    Returns:
        dict mapping a metric label to the metric object, in this order:
        temperature, u, v, w, combined velocity, SSIM.
    """
    scales = config["data"]["scales"]

    loss_fns = {}
    loss_fns["T-ErrorNorm[K]"] = TemperatureErrorNorm(scale=scales[0])
    loss_fns["u-ErrorNorm[m/s]"] = VelocityComponentErrorNorm(
        scale=scales[1], idx_channel=1
    )
    loss_fns["v-ErrorNorm[m/s]"] = VelocityComponentErrorNorm(
        scale=scales[2], idx_channel=2
    )
    loss_fns["w-ErrorNorm[m/s]"] = VelocityComponentErrorNorm(
        scale=scales[3], idx_channel=3
    )
    # The combined velocity metric receives all scales after the temperature one.
    loss_fns["VelocityErrorNorm"] = VelocityErrorNorm(scales=scales[1:], device=DEVICE)
    loss_fns["AveSsimLoss"] = AveSsimLoss()

    return loss_fns


def aggregate_err_by_time(timestamps: list, all_errors: dict):
    """Group per-sample error profiles by the minute of their timestamp.

    Args:
        timestamps: Objects with a `minute` attribute (e.g. `datetime.datetime`), one
            per row of every error array.
        all_errors: Maps a metric name to an array of shape (len(timestamps), 10),
            i.e. (time, height).

    Returns:
        dict mapping each metric name to a dict keyed by the minutes 11..60. A value
        is the stacked rows (n_samples, 10) that fall on that minute, or `np.nan`
        when no sample does. Minute 0 is stored under key 60.

    Raises:
        AssertionError: If an error array does not have the expected shape.
        KeyError: If a timestamp has a minute between 1 and 10.
    """
    aggregated = {}

    for name, err in all_errors.items():
        assert len(err.shape) == 2  # time and height dims
        assert err.shape[0] == len(timestamps)  # time dim
        assert err.shape[1] == 10  # height dim

        buckets = {minute: [] for minute in range(11, 61)}
        for row_idx, stamp in enumerate(timestamps):
            # The top of the hour (minute 0) is filed under 60.
            key = int(stamp.minute) or 60
            buckets[key].append(err[row_idx])

        aggregated[name] = {
            minute: (np.stack(rows) if len(rows) > 0 else np.nan)
            for minute, rows in buckets.items()
        }

    return aggregated

# Plot learning curves

In [None]:
# Set to False to refresh the "is_not_ended" flags without drawing any figure.
is_plotted = True

for config_name, config_info in CONFIGS.items():
    # Stays True unless the log exists and the loss history could be read. The
    # evaluation cell skips every config whose flag is still True.
    config_info["is_not_ended"] = True
    config = config_info["config"]

    if not os.path.exists(config_info["log_path"]):
        logger.info(f"Log does not exist: {config_name}")
        continue

    with open(config_info["log_path"], "r") as f:
        lines = f.readlines()
    # The "End DDP:" marker is looked up on the third line from the end. A log with
    # fewer than three lines cannot hold it, so it is reported as unfinished instead
    # of failing with an IndexError.
    if len(lines) < 3 or not lines[-3].startswith("End DDP:"):
        logger.warning(f"Training is not finished: {config_name}")
    # NOTE(review): an unfinished run is only warned about; it is still plotted and
    # flagged as ended below. Confirm that this is intended.

    df = pd.read_csv(config_info["learning_history_path"])
    # Assuming one history row per epoch, a run that used every configured epoch has
    # exactly `epochs` rows, so equality must be accepted.
    assert len(df) <= config["train"]["epochs"], (
        f"{config_name}: loss history has {len(df)} rows, "
        f"more than the configured {config['train']['epochs']} epochs"
    )

    config_info["is_not_ended"] = False

    if not is_plotted:
        continue

    plt.rcParams["font.size"] = 15
    fig, ax = plt.subplots(figsize=[5, 3])

    # One line per column of the loss-history CSV, on a logarithmic loss axis.
    df.plot(
        ax=ax,
        xlabel="Epochs",
        ylabel="Loss",
    )
    ax.set_title(f'{config_name}\n{config["loss"]["name"]}')
    ax.set_yscale("log")

    plt.tight_layout()
    # fig.savefig(f"{FIG_DIR}/learning_curve_{config_name}.webp", bbox_inches="tight")
    plt.show()

# Evaluate models

In [None]:
# Test-set results per config name: the sample timestamps and the output of `evaluate`.
dict_results = {}

for config_name, config_info in tqdm(CONFIGS.items(), total=len(CONFIGS)):
    # Configs with "test" in their name are never evaluated.
    if "test" in config_name:
        continue

    # "is_not_ended" is written by the learning-curve cell, which has to run first.
    if config_info["is_not_ended"]:
        logger.info(f"Training is not finished. {config_name}")
        continue

    logger.info(f"\n{config_name} is being evaluated.")

    # Work on a copy so the overrides below do not leak into CONFIGS.
    config = copy.deepcopy(config_info["config"])
    config["data"]["batch_size"] = 1  # This must be 1 to easily obtain timestamps.
    assert "use_clipping_ground_truth" in config["data"]
    config["data"]["use_clipping_ground_truth"] = False

    # Silence INFO logs while the data loader and the model are built. The level is
    # restored in `finally` so that a failure here does not leave the root logger
    # muted for the rest of the session.
    logger.setLevel(WARNING)
    try:
        dataloaders, _ = make_dataloaders_and_samplers(
            root_dir=ROOT_DIR, config=config["data"], train_valid_test_kinds=["test"]
        )

        model = make_model(config["model"]).to(DEVICE)
        model.load_state_dict(
            torch.load(config_info["weight_path"], map_location=DEVICE)
        )
        _ = model.eval()
    finally:
        logger.setLevel(INFO)

    loss_fns = get_dict_loss_fns(config)

    test_errors = evaluate(
        dataloader=dataloaders["test"], model=model, loss_fns=loss_fns, device=DEVICE
    )

    # Each entry of `truth_all_file_paths` is indexed with `idx` to pick the file whose
    # name carries the timestamp: the second path for 3 input snapshots, else the first.
    idx = 1 if dataloaders["test"].dataset.n_input_snapshots == 3 else 0
    logger.info(
        f"n_input = {dataloaders['test'].dataset.n_input_snapshots}, so idx = {idx}"
    )

    # e.g., lr_tokyo_05m_20130709T040100.npy --> 20130709T040100
    timestamps = [
        datetime.datetime.strptime(
            os.path.basename(ps[idx]).split("_")[-1].replace(".npy", ""),
            "%Y%m%dT%H%M%S",
        )
        for ps in dataloaders["test"].dataset.truth_all_file_paths
    ]

    dict_results[config_name] = {
        "timestamps": timestamps,
        "errors": test_errors,
    }

# Plot snapshots

In [None]:
# Key of CONFIGS selecting the model whose predictions are plotted below.
config_name = "default_lr"
# Number of test samples to load and plot (dataset indices 0 .. max_idx - 1).
max_idx = 1

In [None]:
config_info = CONFIGS[config_name]
# Work on a copy so the override below does not leak into CONFIGS.
config = copy.deepcopy(config_info["config"])
assert "use_clipping_ground_truth" in config["data"]
config["data"]["use_clipping_ground_truth"] = False

# Silence INFO logs while the data loader and the model are built. The level is
# restored in `finally` so that a failure here does not leave the root logger muted.
logger.setLevel(WARNING)
try:
    dataloaders, _ = make_dataloaders_and_samplers(
        root_dir=ROOT_DIR, config=config["data"], train_valid_test_kinds=["test"]
    )

    model = make_model(config["model"]).to(DEVICE)
    model.load_state_dict(torch.load(config_info["weight_path"], map_location=DEVICE))
    _ = model.eval()
finally:
    logger.setLevel(INFO)

In [None]:
# Fetch the first `max_idx` test samples, then batch each field along a new axis 0.
test_dataset = dataloaders["test"].dataset

samples = []
for i in range(max_idx):
    # __getitem__ is called explicitly because subscript syntax cannot pass the
    # extra `return_hr_path` keyword.
    t, X, b, g, path = test_dataset.__getitem__(idx=i, return_hr_path=True)
    samples.append((t, X, b, g, path))

ts = torch.stack([sample[0] for sample in samples])
Xs = torch.stack([sample[1] for sample in samples])
bs = torch.stack([sample[2] for sample in samples])
gt = torch.stack([sample[3] for sample in samples])

# Keep the last "_"-separated token of each file name, without the ".npy" suffix.
timestamps = [
    os.path.basename(sample[4]).split("_")[-1].replace(".npy", "")
    for sample in samples
]

In [None]:
# Check the input layout before spending time on the forward pass.
assert Xs.shape[1:] == (4, 10, 80, 80)  # channel, z, y, x

# Inference only: disable autograd so no computation graph is built or kept in memory.
with torch.no_grad():
    preds = model(t=ts.to(DEVICE), x=Xs.to(DEVICE), b=bs.to(DEVICE)).detach().cpu()

# Convert inputs, ground truth and predictions back with the dataset's inverse scaling.
# NOTE(review): `Xs` and `gt` are overwritten here, so re-running this cell alone would
# feed already rescaled inputs to the model and rescale them again; re-run the
# sample-gathering cell first.
Xs = dataloaders["test"].dataset._scale_inversely(Xs)
gt = dataloaders["test"].dataset._scale_inversely(gt)
preds = dataloaders["test"].dataset._scale_inversely(preds)

In [None]:
# For every loaded sample and for two vertical levels, draw a 2 x 6 panel figure:
# each of the four channels gets three adjacent panels (ground truth, input, prediction).
for i_batch in range(max_idx):
    lr_bldg = bs[i_batch]
    hr_dataset = gt[i_batch]
    lr_dataset = Xs[i_batch]
    sr_dataset = preds[i_batch]
    figsize = [24, 6]

    # Panel label -> matplotlib colormap name. The position in this dict is used as
    # the channel index into the data tensors.
    # NOTE(review): "tm" is presumably temperature and "vl"/"vp"/"vr" the three
    # velocity components; confirm the naming.
    dict_cmap = {
        "tm": "turbo",
        "vl": "YlOrRd",
        "vp": "YlGnBu",
        "vr": "viridis",
    }

    # Horizontal crop applied after the transpose; the current values keep the full domain.
    sx, ex = 0, None
    sy, ey = 0, None

    for ilev in [0, 4]:
        plt.rcParams["font.size"] = 14
        fig, axes = plt.subplots(2, 6, figsize=figsize)
        # Flatten row by row so that channel i occupies axes[3 * i : 3 * i + 3].
        axes = np.ravel(axes)

        # Mask taken from channel 0 of `bs`: cells where it is zero are drawn as missing.
        # NOTE(review): presumably nonzero outside buildings; confirm in the dataset.
        is_out_bldg = lr_bldg[0, ilev].numpy().transpose()[sx:ex, sy:ey]
        # NOTE(review): assumes 20 m between levels with the lowest one at 10 m; confirm.
        height = 20.0 * ilev + 10.0

        for i, (v, cmap) in enumerate(dict_cmap.items()):
            # The color range is fixed by the HR panel (drawn first) and reused for LR and SR.
            vmin, vmax = None, None

            hr_gt = None
            for ax, resolution, org_data in zip(
                [axes[3 * i], axes[3 * i + 1], axes[3 * i + 2]],
                ["HR(resized)", "LR", "SR"],
                [hr_dataset[i, ilev], lr_dataset[i, ilev], sr_dataset[i, ilev]],
            ):
                assert org_data.ndim == 2  # y and x dims

                ax.set_aspect("equal")
                ax.set_xticks([])
                ax.set_yticks([])
                # Transpose (y, x) -> (x, y) to match the "ij" meshgrid built below.
                data = org_data.clone().numpy().transpose()[sx:ex, sy:ey]
                assert data.shape[0] == data.shape[1], "Not equal aspect ratio"

                # Replace masked cells by NaN so they are excluded from the statistics
                # and drawn with the colormap's "bad" color.
                data = np.where(is_out_bldg, data, np.nan)

                if v == "tm":
                    # NOTE(review): presumably converts Kelvin to degrees Celsius; confirm units.
                    data -= 273.15

                # Grid spacing used for the plot coordinates.
                # NOTE(review): presumably in metres; confirm.
                dx = 20
                xs = np.arange(data.shape[0]) * dx
                ys = np.arange(data.shape[1]) * dx

                xs, ys = np.meshgrid(xs, ys, indexing="ij")

                if resolution == "HR(resized)":
                    # Color limits: 2nd and 98th percentiles of the ground truth, ignoring NaN.
                    vmin = np.nanquantile(data.flatten(), 0.02)
                    vmax = np.nanquantile(data.flatten(), 0.98)
                    hr_gt = data
                    print(f"{v}, {vmin:.1f}, {vmax:.1f}")
                    ax.set_title(f"{resolution}: {v}\nz = {height:.1f} m")
                else:
                    # Mean absolute difference to the ground truth over non-masked cells.
                    abs_diff = np.abs(data - hr_gt)
                    mae = np.nanmean(abs_diff)
                    ax.set_title(f"{resolution}: {v}\nMAE = {mae:.3f}")

                # Copy the colormap before changing its "bad" (NaN) color to gray.
                my_cmap = copy.deepcopy(matplotlib.colormaps[dict_cmap[v]])
                my_cmap.set_bad("dimgray")

                contours = ax.pcolormesh(
                    xs, ys, data, vmin=vmin, vmax=vmax, cmap=my_cmap
                )
                fig.colorbar(contours, ax=ax, extend="both")

        plt.suptitle(f"{timestamps[i_batch]}, z = {height:03} m")
        plt.tight_layout()
        plt.show()