In [None]:
# Automatically reload edited imported modules (e.g. the project's `src` package)
# before each cell runs, so changes there are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the cell output.
%matplotlib inline

In [None]:
import sys
from logging import INFO, StreamHandler, getLogger

# Root logger that prints to stdout. A StreamHandler is attached only when none
# is present yet, so re-running this cell does not duplicate every log line.
logger = getLogger()
if all("StreamHandler" not in str(handler) for handler in logger.handlers):
    logger.addHandler(StreamHandler(sys.stdout))
logger.setLevel(INFO)

# Import libraries

In [None]:
import copy
import gc
import glob
import os
import pathlib
from collections import OrderedDict

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import yaml
from numpy.testing import assert_array_equal
from scipy import signal
from src.dataloader import make_evaluation_dataloader_without_random_cropping
from src.loss_maker import (
    AbsDiffDivergence,
    AbsDiffTemperature,
    ChannelwiseMse,
    DiffOmegaVectorNorm,
    DiffVelocityVectorNorm,
    MaskedL1Loss,
    MaskedL1LossNearWall,
    MaskedL2Loss,
    MaskedL2LossNearWall,
    MixedDivergenceGradientL2LossDivMse,
    MixedDivergenceGradientL2LossGrdMse,
    MixedDivergenceGradientL2LossMse,
    MyL1Loss,
    MyL2Loss,
    ResidualContinuity,
    Ssim3dLoss,
)
from src.model_maker import make_model
from src.optim_helper import evaluate
from src.utils import calc_early_stopping_patience, set_seeds
from tqdm.notebook import tqdm

# Raise pandas' display limits so the wide/long score tables render untruncated.
for _display_option in ("display.max_columns", "display.max_rows"):
    pd.set_option(_display_option, 500)

In [None]:
# Deterministic runs: cuBLAS needs a fixed workspace configuration (see the
# PyTorch reproducibility notes), then seed everything via the project helper.
os.environ.update(CUBLAS_WORKSPACE_CONFIG=":4096:8")
set_seeds(42, use_deterministic=True)

# Define constants

In [None]:
ROOT_DIR = str((pathlib.Path(os.environ["PYTHONPATH"]) / "..").resolve())
DL_DATA_DIR = pathlib.Path(f"{ROOT_DIR}/data/DL_data")

# FIG_DIR = f"{ROOT_DIR}/doc/report_20220912/fig"
# os.makedirs(FIG_DIR, exist_ok=True)

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU and log
# it at error level so the (much slower) fallback is hard to miss.
if torch.cuda.is_available():
    DEVICE = "cuda"
    logger.info("GPU is used.")
else:
    DEVICE = "cpu"
    logger.error("No GPU. CPU is used.")
    # To make a GPU mandatory, raise an exception here instead of logging.

In [None]:
# Every YAML config of this experiment series, in a stable (sorted) order.
_config_glob_pattern = (
    f"{ROOT_DIR}/pytorch/config/new_lr_unet_gconv_change_datashape_ddp/*.yml"
)
CONFIG_PATHS = sorted(glob.glob(_config_glob_pattern))

In [None]:
# HR and LR "is in build" arrays from data directory "10"; presumably masks of
# grid points inside buildings — TODO confirm.
# NOTE(review): not referenced by the cells shown here.
_is_in_build_dir = f"{ROOT_DIR}/data/DL_data/10"
HR_IS_IN_BUILD, LR_IS_IN_BUILD = (
    np.load(f"{_is_in_build_dir}/{resolution}_is_in_build.npy")
    for resolution in ("hr", "lr")
)

In [None]:
# Map each config name to its parsed YAML plus the paths of its training
# artifacts (weights and learning history) under data/DL_results.
CONFIGS = OrderedDict()

for config_path in CONFIG_PATHS:
    config_file_name = os.path.basename(config_path)
    # Check only the file name: matching against the full path would skip every
    # config whenever ROOT_DIR itself happens to contain "tutorial".
    if "tutorial" in config_file_name:
        continue
    with open(config_path) as file:
        config = yaml.safe_load(file)

    # NOTE(review): split(".")[0] truncates names containing dots. It is kept
    # because the result directories below are looked up by this same name
    # (presumably the training script derives it identically — confirm).
    config_name = config_file_name.split(".")[0]
    assert config_name not in CONFIGS, f"Duplicate config name: {config_name}"

    # The experiment name is the config's parent directory. os.path handles the
    # platform's separators, unlike splitting the path string on "/".
    experiment_name = os.path.basename(os.path.dirname(config_path))

    _dir = f"{ROOT_DIR}/data/DL_results/{experiment_name}/{config_name}"

    CONFIGS[config_name] = {
        "config": config,
        "model_name": config["model"]["model_name"],
        "experiment_name": experiment_name,
        "weight_path": f"{_dir}/weights.pth",
        "learning_history_path": f"{_dir}/learning_history.csv",
    }

# Plot learning curves

In [None]:
# Font size shared by every learning-curve figure below (set once, not per loop).
plt.rcParams["font.size"] = 15

# One log-scale learning-curve figure per config that has a saved history.
for config_name, config_info in CONFIGS.items():
    if not os.path.exists(config_info["learning_history_path"]):
        print(f"{config_name} is skipped because of no result.")
        continue
    df_history = pd.read_csv(config_info["learning_history_path"])
    # The return value is only used by the optional print below.
    cnt = calc_early_stopping_patience(df_history, th_max_cnt=50)
    # print(f"{config_name}: max cnt = {cnt}")

    num_epochs = config_info["config"]["train"]["num_epochs"]
    assert len(df_history) == num_epochs, (
        f"{config_name}: history has {len(df_history)} rows, "
        f"expected num_epochs={num_epochs}"
    )

    fig, ax = plt.subplots(figsize=[7, 5])
    df_history.plot(
        ax=ax,
        xlabel="Epochs",
        ylabel=config_info["config"]["train"]["loss"]["name"],
    )
    ax.set_title(config_name)
    ax.set_yscale("log")

    # fig.savefig(f"{FIG_DIR}/{config_name}_learning_curve.jpg")
    plt.show()
    # Release the figure explicitly so many configs do not accumulate figures.
    plt.close(fig)

# Calculate test scores

In [None]:
if os.path.exists("./test_scores.csv"):
    df_results = pd.read_csv("./test_scores.csv").set_index("Unnamed: 0")
    print("DF is read from csv")
else:
    df_results = pd.DataFrame()
    print("DF is created.")

In [None]:
def _make_loss_fns(config: dict) -> dict:
    """Build the evaluation metrics for one config, keyed by result-column name.

    Metrics that need data scaling receive the config's ``stds``. Index 0 goes
    where a temperature std is expected and ``[1:]`` where velocity stds are
    expected, matching the T/U/V/W channel order used by the ChannelwiseMse
    entries. NOTE(review): confirm against src.loss_maker.
    """
    stds = config["data"]["stds"]
    return {
        "L1": MyL1Loss(),
        "MaskedL1": MaskedL1Loss(),
        "MaskedL1NearWall": MaskedL1LossNearWall(),
        "L2": MyL2Loss(),
        "MaskedL2": MaskedL2Loss(),
        "MaskedL2NearWall": MaskedL2LossNearWall(),
        "ResidualContinuityEq": ResidualContinuity(stds[1:]),
        "AbsDiffTemperature": AbsDiffTemperature(stds[0]),
        "DiffVelocityNorm": DiffVelocityVectorNorm(stds[1:]),
        "AbsDiffTemperatureLevZero": AbsDiffTemperature(stds[0], lev=0),
        "DiffVelocityNormLevZero": DiffVelocityVectorNorm(stds[1:], lev=0),
        "AbsDiffDivergence": AbsDiffDivergence(stds[1:]),
        "DiffOmegaVectorNorm": DiffOmegaVectorNorm(stds[1:]),
        "SSIM3D_1e-7": Ssim3dLoss(eps=1e-7),
        "ChannelwiseMseT": ChannelwiseMse(i_channel=0),
        "ChannelwiseMseU": ChannelwiseMse(i_channel=1),
        "ChannelwiseMseV": ChannelwiseMse(i_channel=2),
        "ChannelwiseMseW": ChannelwiseMse(i_channel=3),
        "MixedDivergenceGradientL2LossDivMse": MixedDivergenceGradientL2LossDivMse(
            stds[1:]
        ),
        "MixedDivergenceGradientL2LossGrdMse": MixedDivergenceGradientL2LossGrdMse(
            stds[1:]
        ),
        "MixedDivergenceGradientL2LossMse": MixedDivergenceGradientL2LossMse(
            stds[1:]
        ),
    }


def _describe_config(config_info: dict) -> dict:
    """Collect the config metadata stored next to the scores, in column order."""
    config = config_info["config"]
    data_cfg, model_cfg, train_cfg = config["data"], config["model"], config["train"]
    datasizes = data_cfg["datasizes"]
    return {
        "ExperimentName": config_info["experiment_name"],
        "ModelName": model_cfg["model_name"],
        "LossName": train_cfg["loss"]["name"],
        "LearningRate": train_cfg["lr"],
        "NumFeat0": model_cfg["num_feat0"],
        "NumFeat1": model_cfg["num_feat1"],
        "NumFeat2": model_cfg["num_feat2"],
        "NumFeat3": model_cfg["num_feat3"],
        "NumLatentLayers": model_cfg["num_latent_layers"],
        "CroppedSizeZ": data_cfg["hr_crop_size"][0],
        "CroppedSizeY": data_cfg["hr_crop_size"][1],
        "CroppedSizeX": data_cfg["hr_crop_size"][2],
        "TrainDatasize": datasizes["train"],
        "ValidDatasize": datasizes["valid"],
        "TestDatasize": datasizes["test"],
        "TotalTrainDatasize": datasizes["train"] + datasizes["valid"],
        "WeightGradLoss": train_cfg["loss"].get("weight_gradient_loss", 0.0),
        "WeightDivLoss": train_cfg["loss"].get("weight_divergence_loss", 0.0),
        "max_discarded_lr_z_index": data_cfg["max_discarded_lr_z_index"],
    }


# Evaluate every trained config that has weights and is not yet in df_results.
# Each result row holds the config metadata followed by one column per metric.
for config_name, config_info in tqdm(CONFIGS.items(), total=len(CONFIGS)):
    if not os.path.exists(config_info["weight_path"]):
        print(f"{config_name} is skipped because of no result.")
        continue

    if config_name in df_results.index:
        logger.info(f"{config_name} already exists. so skip calculation.")
        continue

    logger.info(f"\n{config_name} is being evaluated")
    config = config_info["config"]
    # Read all metadata up front so a malformed config fails before the slow evaluation.
    metadata = _describe_config(config_info)
    loss_fns = _make_loss_fns(config)

    test_loader = make_evaluation_dataloader_without_random_cropping(
        config, DL_DATA_DIR, batch_size=1
    )

    model = make_model(config).to(DEVICE)
    # NOTE(review): torch.load unpickles the file. These are this project's own
    # weights; with torch>=1.13 consider passing weights_only=True.
    model.load_state_dict(torch.load(config_info["weight_path"], map_location=DEVICE))
    model.eval()

    results = evaluate(
        dataloader=test_loader,
        model=model,
        loss_fns=loss_fns,
        device=DEVICE,
        hide_progress_bar=False,
    )

    for column, value in metadata.items():
        df_results.loc[config_name, column] = value
    # Each returned metric object exposes its aggregated value as `.avg`.
    for metric_name, meter in results.items():
        df_results.loc[config_name, metric_name] = meter.avg

In [None]:
# Show the full score table (one row per evaluated config); the pandas display
# limits were raised to 500 rows/columns at the top of the notebook.
df_results

In [None]:
# Persist the scores with config names as the index; the "Calc test scores"
# cells read this file back and skip configs that are already present.
df_results.to_csv(path_or_buf="./test_scores.csv", index=True)

In [None]:
# Free the last evaluated model and data loader, then release cached GPU memory.
# These names only exist if the evaluation loop evaluated at least one config.
# When every config was already cached, `del model` would raise NameError on
# Restart & Run All, so drop them defensively instead.
for _name in ("model", "test_loader"):
    globals().pop(_name, None)
gc.collect()
_ = torch.cuda.empty_cache()

# Analyze test scores

In [None]:
# Reload the saved scores and order rows by total training-data size.
# The first CSV column holds the config names (the index written by to_csv), so
# read it with index_col=0 instead of relying on pandas naming it "Unnamed: 0".
# sort_values returns a new frame, so no inplace mutation is needed.
df_results = pd.read_csv("./test_scores.csv", index_col=0).sort_values(
    "TotalTrainDatasize"
)

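- Near walls and the ground, gated convolutions (gconv) help reduce errors: compare `z0_2475_825_no_gconv` (without gconv) with `z0_2475_825` (with gconv) in the table below.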

In [None]:
# Metrics compared between the configs without and with gated convolutions.
ycols = [
    "MixedDivergenceGradientL2LossMse",
    "MixedDivergenceGradientL2LossGrdMse",
    "MixedDivergenceGradientL2LossDivMse",
    "SSIM3D_1e-7",
    "AbsDiffTemperature",
    "DiffVelocityNorm",
    "AbsDiffTemperatureLevZero",
    "DiffVelocityNormLevZero",
    "MaskedL2NearWall",
]
_gconv_pair = ["z0_2475_825_no_gconv", "z0_2475_825"]
df_results.loc[_gconv_pair, ycols]

In [None]:
# Order by the number of discarded LR z-levels, then drop the no-gconv baseline.
# sort_values returns a new frame, so df_results itself is left untouched.
df = df_results.sort_values(["max_discarded_lr_z_index"])
df = df.loc[df.index != "z0_2475_825_no_gconv"]

In [None]:
# Metrics to plot, one panel each. Uncommenting entries adds panels; the grid
# grows to fit instead of silently dropping metrics beyond a fixed 3x3 layout.
ycols = [
    # "L1",
    "MixedDivergenceGradientL2LossMse",
    "MixedDivergenceGradientL2LossGrdMse",
    "MixedDivergenceGradientL2LossDivMse",
    "SSIM3D_1e-7",
    # "ResidualContinuityEq",
    "AbsDiffTemperature",
    "DiffVelocityNorm",
    # "AbsDiffDivergence",
    # "DiffOmegaVectorNorm",
    # "ChannelwiseMseT",
    # "ChannelwiseMseU",
    # "ChannelwiseMseV",
    # "ChannelwiseMseW",
    "AbsDiffTemperatureLevZero",
    "DiffVelocityNormLevZero",
    "MaskedL2NearWall",
]

plt.rcParams["font.size"] = 15
ncols = 3
nrows = max(1, -(-len(ycols) // ncols))  # ceiling division
# 5 x (10/3) inches per panel, i.e. [15, 10] for the 3x3 case.
fig, axes = plt.subplots(
    nrows, ncols, figsize=[5 * ncols, 10 * nrows / 3], sharex=True, squeeze=False
)
flat_axes = axes.ravel()

for ax, ycol in zip(flat_axes, ycols):
    # One line per total training-data size, plotted against discarded z-levels.
    for datasize, grp in df.groupby("TotalTrainDatasize"):
        xs = grp["max_discarded_lr_z_index"].values
        ys = grp[ycol].values
        ax.plot(xs, ys, "o-", label=f"{int(datasize):04}")

    # Break long metric names before "L2" so they fit above/beside the panel.
    label = ycol.replace("L2", "\nL2")
    ax.set_title(label)
    ax.set_ylabel(label)
    ax.set_xlabel("Discarded grids in z")
    ax.legend()

# Hide panels left empty when len(ycols) is not a multiple of ncols.
for ax in flat_axes[len(ycols):]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()