In [None]:
# Reload edited project modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import os
import pathlib
import re
from copy import deepcopy
from glob import glob

import yaml
from tqdm.notebook import tqdm

In [None]:
ROOT_DIR = str(pathlib.Path.resolve(pathlib.Path(os.environ["PYTHONPATH"]) / ".."))
EXPERIMENT_NAME = "paper_experiment_01"

In [None]:
# Consistency check for every YAML config of the experiment:
#   * the hyper-parameters encoded in the file name (see CONFIG_NAME_PATTERN) must
#     equal the corresponding values stored inside the file;
#   * every other setting must be identical to the first config in sorted order
#     (`reference`), so configs differ only in the intended hyper-parameters.
# Any mismatch raises AssertionError naming the offending config and key.
config_dir = f"{ROOT_DIR}/pytorch/config/{EXPERIMENT_NAME}"

# File-name scheme:
#   lt<lr_time_interval>og<obs_grid_interval>_on<obs_noise>_ep<num_epochs>_lr<lr>
#   _sc<T|F>_mu<T|F>_a<alpha>_b<beta>_sd<seed>
# re.match does not anchor the end, so suffixes such as "_noLR" are allowed.
# Raw string: `\d`, `\w`, `\-` must reach the regex engine, not be parsed as
# (invalid) Python string escapes.
CONFIG_NAME_PATTERN = re.compile(
    r"lt(\d+)og(\d+)_on([\-\w]+)_ep(\d+)_lr([\-\w]+)_sc([TF])_mu([TF])_a(\d+)_b(\d+)_sd(\d+)"
)

# Settings that are fixed for every config of this experiment.
EXPECTED_LR_KIND_NAMES = ["lr_omega_no-noise"]
EXPECTED_MODEL_NAME = "ConvTransformerSrDaNetVer03"

config_paths = sorted(glob(f"{config_dir}/*.yml"))
# Without this guard, a wrong path or empty directory would let every check pass vacuously.
assert config_paths, f"No *.yml config files found in {config_dir}"

reference = None

for config_path in config_paths:
    print(f"config_file = {os.path.basename(config_path)}")

    config_name = os.path.basename(config_path).split(".")[0]

    m = CONFIG_NAME_PATTERN.match(config_name)
    assert m is not None, f"{config_name}: file name does not follow the naming scheme"
    (
        lr_time_interval,
        ob_grid_interval,
        obs_noise,
        num_epochs,
        lr,
        use_sc,
        use_mixup,
        alpha,
        beta,
        seed,
    ) = m.groups()

    lr_time_interval = int(lr_time_interval)
    ob_grid_interval = int(ob_grid_interval)
    obs_noise = float(obs_noise)
    num_epochs = int(num_epochs)
    lr = float(lr)
    use_sc = use_sc == "T"
    use_mixup = use_mixup == "T"
    alpha = float(alpha)
    beta = float(beta)
    seed = int(seed)

    with open(config_path) as file:
        config = yaml.safe_load(file)

    # The first config (in sorted order) is the baseline all others are compared with.
    if reference is None:
        reference = deepcopy(config)

    data_config = config["data"]
    assert "lr_time_interval" in data_config, config_name
    assert "use_mixup" in data_config, config_name
    assert data_config["data_dir_name"] == "jet12", config_name

    # "_noLR" configs must set `use_lr_forecast: false` explicitly; all others omit the key.
    if config_name.endswith("_noLR"):
        assert data_config["use_lr_forecast"] is False, config_name
    else:
        assert "use_lr_forecast" not in data_config, config_name

    # og0 in the file name <=> observations are disabled.
    if ob_grid_interval == 0:
        assert data_config["use_observation"] is False, config_name
    else:
        assert data_config["use_observation"] is True, config_name

    # Second-level keys whose value is dictated by the file name (or fixed for the
    # experiment) instead of being copied from the reference config.
    expected_values = {
        "lr_time_interval": lr_time_interval,
        "obs_grid_interval": ob_grid_interval,
        "num_epochs": num_epochs,
        "obs_noise_std": obs_noise,
        "lr_kind_names": EXPECTED_LR_KIND_NAMES,
        "use_mixup": use_mixup,
        "beta_dist_alpha": alpha,
        "beta_dist_beta": beta,
        "lr": lr,
        "model_name": EXPECTED_MODEL_NAME,
        "use_global_skip_connection": use_sc,
        "seed": seed,
    }

    for k1, v1 in config.items():
        if not isinstance(v1, dict):
            assert v1 == reference[k1], f"{config_name}: {k1}"
            continue

        for k2, v2 in v1.items():
            if isinstance(v2, dict):
                # Third nesting level: must be flat and identical to the reference.
                for k3, v3 in v2.items():
                    assert not isinstance(v3, dict), f"{config_name}: {k1},{k2},{k3} is nested"
                    assert v3 == reference[k1][k2][k3], f"{config_name}: {k1},{k2},{k3}"
            elif k2 == "use_observation":
                # Checked against og<N> for the `data` section above; skipped elsewhere.
                pass
            elif k2 == "use_lr_forecast":
                # Identity check: only a literal `false` is accepted (not 0 / None).
                assert v2 is False, f"{config_name}: {k1},{k2}"
            elif k2 in expected_values:
                assert v2 == expected_values[k2], (
                    f"{config_name}: {k1},{k2}={v2!r}, expected {expected_values[k2]!r}"
                )
            else:
                assert v2 == reference[k1][k2], f"{config_name}: {k1},{k2}"

In [None]:
# Number of *.yml configs in config_dir (the same glob the validation loop iterates);
# sorting is irrelevant for a count.
len(glob(f"{config_dir}/*.yml"))