In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import itertools
import os
import pathlib
from copy import deepcopy

import numpy as np
import yaml

In [None]:
ROOT_DIR = str(pathlib.Path.resolve(pathlib.Path(os.environ["PYTHONPATH"]) / ".."))
EXPERIMENT_NAME = "paper_experiment_01"

In [None]:
# Reproducibly draw eight six-digit run seeds from NumPy's legacy global RNG.
# `high` is exclusive, so every value lies in [100000, 999998].
np.random.seed(seed=42)
seeds = np.random.randint(low=100000, high=999999, size=8)

In [None]:
# Show the generated seed array as this cell's output.
seeds

In [None]:
# Pin the RNG output: the sweep grid below hard-codes seeds taken from this
# list, so a change in NumPy's legacy random stream must fail loudly here.
assert [int(s) for s in seeds] == [
    221958, 771155, 231932, 465838, 359178, 744167, 210268, 832180,
]

In [None]:
# Sweep grid: each list holds the values tried for one config field.
lr_time_intervals = [4]  # -> data.lr_time_interval
obs_grid_intervals = [0, 4, 6, 8, 10, 12]  # -> data.obs_grid_interval (<= 0 disables observations)
obs_noises = [0.1]  # -> data.obs_noise_std
use_scs = [True]  # -> model.use_global_skip_connection
use_mus = [True, False]  # -> data.use_mixup
alphas = [2.0]  # -> data.beta_dist_alpha
betas = [2.0]  # -> data.beta_dist_beta
use_lrs = [True, False]  # -> data.use_lr_forecast
# NOTE(review): this rebinds `seeds` (previously the 8-element ndarray drawn
# above) to a hand-picked 5-element subset in a different order. Re-running
# the seed assert cell after this one will therefore fail.
seeds = [221958, 771155, 832180, 465838, 359178]


def generate_preference_params():
    """Yield every combination of the sweep grid as a 9-tuple.

    Tuples are ``(lt, og, on, sc, mu, a, b, use_lr, seed)``, produced in the
    same order as nested loops with ``lt`` outermost and ``seed`` innermost.
    The module-level grid lists are read when iteration starts.
    """
    yield from itertools.product(
        lr_time_intervals,
        obs_grid_intervals,
        obs_noises,
        use_scs,
        use_mus,
        alphas,
        betas,
        use_lrs,
        seeds,
    )

In [None]:
config_dir = f"{ROOT_DIR}/pytorch/config/{EXPERIMENT_NAME}"

# Every generated config is a modified copy of this reference run. Its file
# stem doubles as the naming template for the generated configs.
ref_config_name = "lt4og12_on1e-01_ep1000_lr1e-04_scT_muT_a02_b02_sd221958"

# The reference config is identical for every combination, so load it once
# instead of re-reading it from disk on each loop iteration.
with open(f"{config_dir}/{ref_config_name}.yml") as file:
    ref_config = yaml.safe_load(file)


def is_excluded_combination(og: int, mu: bool, a: float, b: float, use_lr: bool) -> bool:
    """Return True for sweep combinations that should not get a config file.

    * Without mixup, only the a=2, b=2 Beta parameters are kept (presumably
      because alpha/beta only matter when mixup is on -- confirm).
    * Runs without the LR forecast are kept only for a=2, b=2, no mixup, and
      og > 0. With og <= 0 observations are disabled too, so such a run would
      have both inputs switched off.
    """
    if not mu and (a != 2 or b != 2):
        return True
    if not use_lr:
        return a != 2 or b != 2 or mu or og <= 0
    return False


def make_config_name(
    template: str, lt: int, og: int, on: float, sc: bool, mu: bool,
    a: float, b: float, use_lr: bool, seed: int,
) -> str:
    """Build the config file stem for one combination from ``template``.

    Each placeholder token of the reference name is replaced by the formatted
    value of the matching parameter; "_noLR" is appended when the LR forecast
    is disabled.

    Raises:
        AssertionError: if a placeholder does not occur exactly once in
            ``template`` (str.replace would otherwise rewrite the wrong part).
    """
    substitutions = {
        "lt4": f"lt{lt:01}",
        "og12": f"og{og:02}",
        "on1e-01": f"on{on:.0e}",
        "a02": f"a{int(a):02}",
        "b02": f"b{int(b):02}",
        "sd221958": f"sd{int(seed):06}",
        "scT": "scT" if sc else "scF",
        "muT": "muT" if mu else "muF",
    }
    name = template
    for placeholder, value in substitutions.items():
        assert template.count(placeholder) == 1, f"ambiguous placeholder {placeholder!r}"
        name = name.replace(placeholder, value)
    return name if use_lr else name + "_noLR"


def make_config(
    base: dict, lt: int, og: int, on: float, sc: bool, mu: bool,
    a: float, b: float, use_lr: bool, seed: int,
) -> dict:
    """Return a deep copy of ``base`` with one combination's parameters set."""
    config = deepcopy(base)
    data = config["data"]
    data["lr_time_interval"] = lt
    data["obs_grid_interval"] = og
    data["obs_noise_std"] = on
    data["use_mixup"] = mu
    data["beta_dist_alpha"] = float(a)
    data["beta_dist_beta"] = float(b)
    # Set both flags explicitly in either direction rather than only turning
    # them off, so a generated file never silently inherits the reference value.
    data["use_observation"] = og > 0
    data["use_lr_forecast"] = bool(use_lr)
    config["model"]["use_global_skip_connection"] = sc
    config["train"]["seed"] = int(seed)
    return config


for lt, og, on, sc, mu, a, b, use_lr, seed in generate_preference_params():
    if is_excluded_combination(og, mu, a, b, use_lr):
        continue

    params = (lt, og, on, sc, mu, a, b, use_lr, seed)
    new_config_name = make_config_name(ref_config_name, *params)
    if new_config_name == ref_config_name:
        # This combination is the reference run itself; never overwrite it.
        print(f"Same name is detected. {ref_config_name}")
        continue

    new_config = make_config(ref_config, *params)
    with open(f"{config_dir}/{new_config_name}.yml", "w") as file:
        yaml.safe_dump(new_config, file)