In [None]:
import dataclasses
import time
import traceback

import flax
import jax
import orbax
from flax.training import orbax_utils
from jax.tree_util import Partial

from ulee_repo.experiments.paths import build_trained_weights_path
from ulee_repo.RND.config import TrainConfig
from ulee_repo.RND.main_loop import full_training_on_fixed_envs
from ulee_repo.RND.setups import set_up_for_training
from ulee_repo.shared_code.logging import (
    generate_run_name,
    wandb_log_training_metrics,
)

In [None]:
def run_rnd_training(config: TrainConfig):
    """Train an RND agent for one seed, save the results to disk, and log them to wandb.

    Steps:
      1. Build the env, benchmark and agent/predictor/target train states with
         ``set_up_for_training``.
      2. JIT-compile ``full_training_on_fixed_envs`` and run it to completion,
         blocking until device work finishes so the reported time is real.
      3. Save config, agent/predictor params and metrics with orbax. This step is
         best-effort: failures are reported but do not abort the run.
      4. Log the training metrics to the "ULEE" wandb project.

    Args:
        config: Training configuration. ``train_seed``, ``env_id`` and
            ``benchmark_id`` also determine the checkpoint path.
    """
    # setup
    (
        rng,
        env,
        env_params,
        benchmark,
        agent_train_state,
        predictor_train_state,
        target_train_state,
    ) = set_up_for_training(config)

    # train
    print(f"Training with seed {config.train_seed}")
    # perf_counter is monotonic, so the elapsed time is unaffected by wall-clock jumps.
    # NOTE: the measured time includes XLA compilation of the jitted function.
    t_start = time.perf_counter()

    full_training_partial = Partial(full_training_on_fixed_envs, env=env, env_params=env_params, benchmark=benchmark, config=config)
    jitted_full_training = jax.jit(full_training_partial)
    train_info = jax.block_until_ready(
        jitted_full_training(
            rng=rng,
            agent_train_state=agent_train_state,
            predictor_train_state=predictor_train_state,
            target_train_state=target_train_state,
        )
    )

    elapsed_time = time.perf_counter() - t_start
    print(f"Done in {elapsed_time / 60:.2f}min")

    _save_training_results(config, train_info)
    _log_training_to_wandb(config, train_info)
    # NOTE(review): presumably gives wandb's background sync time to finish before
    # the next run starts — confirm whether this wait is still needed.
    time.sleep(15)


def _save_training_results(config: TrainConfig, train_info) -> None:
    """Best-effort: write config, agent/predictor params and metrics as an orbax checkpoint.

    Any exception is reported (message + traceback) and swallowed so that the
    caller can still log the run to wandb.
    """
    try:
        # store results on disk
        save_path = build_trained_weights_path(
            algorithm_id="rnd",
            env_id=config.env_id,
            benchmark_id=config.benchmark_id,
            seed=config.train_seed,
        )
        # Recent orbax versions reject relative checkpoint paths; resolving is a
        # no-op when the path is already absolute.
        save_path = save_path.resolve()
        # To keep several runs of the same seed, a timestamp suffix could be
        # appended to save_path.name here.
        save_path.parent.mkdir(parents=True, exist_ok=True)

        # The frozen dataclass dict stores the config alongside the weights.
        train_config = flax.core.freeze(dataclasses.asdict(config))
        training_results = {
            "config": train_config,
            "agent_params": train_info["agent_state"].params,
            "predictor_params": train_info["predictor_state"].params,
            "metrics": train_info["metrics"],
        }

        orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
        save_args = orbax_utils.save_args_from_target(training_results)
        # NOTE(review): orbax's Checkpointer.save raises when the destination
        # already exists (no force=True here), so rerunning a seed is reported
        # below instead of silently overwriting earlier results.
        orbax_checkpointer.save(save_path, training_results, save_args=save_args)
        print("saved training results to", save_path)
    except Exception as e:
        # Deliberately best-effort, but keep the traceback: the one-line message
        # alone is often not enough to diagnose a failed save.
        print(f"Error while saving training results to disk: {e}")
        traceback.print_exc()


def _log_training_to_wandb(config: TrainConfig, train_info) -> None:
    """Log the training metrics (plus lr and RND predictor loss) to the "ULEE" wandb project."""
    extra_logs = {
        "training/lr": train_info["metrics"]["lr"],
        "predictor/loss": train_info["metrics"]["rnd_predictor_loss"],
    }
    run_name = generate_run_name(algorithm_name="RND", config=config, prefix="")
    # "prueba" is Spanish for "test/trial".
    tags = ["rnd", "train", "prueba"]
    wandb_log_training_metrics(train_info["metrics"], config, run_name, project_name="ULEE", tags=tags, extra_batch_metrics=extra_logs)


In [None]:
# Sweep: one full RND training run (train -> save -> wandb log) per seed, run sequentially.
train_seeds = [10, 20, 30, 40]
total_timesteps = 1_000_000_000
env_id = "XLand-MiniGrid-R4-13x13"
benchmark_id = "small-1m"
# Each run's benchmark split seed is derived from its train seed by this fixed
# offset, so it is reproducible and differs from the train seed itself.
BENCHMARK_SPLIT_SEED_OFFSET = 100

for seed in train_seeds:
    config = TrainConfig(
        train_seed=seed,
        benchmark_split_seed=seed + BENCHMARK_SPLIT_SEED_OFFSET,
        total_timesteps=total_timesteps,
        env_id=env_id,
        benchmark_id=benchmark_id,
    )

    run_rnd_training(config)