In [None]:
import dataclasses
import time
from datetime import datetime

import flax
import jax
import orbax
import xminigrid
from flax.training import orbax_utils
from jax.tree_util import Partial
from xminigrid.environment import EnvParams

from ulee_repo.DIAYN.config import TrainConfig as DIAYNTrainConfig
from ulee_repo.DIAYN.main_loop import full_training as diayn_full_training
from ulee_repo.DIAYN.setups import set_up_for_training as diayn_set_up_for_training
from ulee_repo.experiments.paths import build_trained_weights_path
from ulee_repo.PPO.config import TrainConfig as PPOTrainConfig
from ulee_repo.PPO.main_loop import full_training as ppo_full_training
from ulee_repo.PPO.main_loop import full_training_on_fixed_envs as ppo_full_training_on_fixed_envs
from ulee_repo.PPO.setups import set_up_for_training as ppo_set_up_for_training
from ulee_repo.RL2.config import TrainConfig as RL2TrainConfig
from ulee_repo.RL2.main_loop import full_training as rl2_full_training
from ulee_repo.RL2.setups import set_up_for_training as rl2_set_up_for_training
from ulee_repo.shared_code.logging import (
    generate_run_name,
    wandb_log_training_metrics,
    wandb_log_ulee_training_metrics,
)
from ulee_repo.ULEE.config import TrainConfig as ULEETrainConfig
from ulee_repo.ULEE.main_loop import full_training as ulee_full_training
from ulee_repo.ULEE.setups import set_up_for_training as ulee_set_up_for_training


In [None]:
import os

# WANDB_SILENT=true asks wandb to suppress its console log statements during the runs below.
os.environ.update(WANDB_SILENT="true")

In [None]:
# Show what xminigrid has registered, so valid `benchmark_id` / `env_id` values are visible.
separator = "-----------------------------------"
print(xminigrid.registered_benchmarks())
print(separator)
print(xminigrid.registered_environments())

## ULEE

In [None]:
def complete_config(config: ULEETrainConfig, env_params: EnvParams) -> ULEETrainConfig:
    """Fill in ULEE config fields that can only be derived once env params are known.

    Sets ``config.goal_search.goal_searching_steps_per_env`` to
    ``goal_searching_episodes_per_env * env_params.max_steps``.

    The config is modified in place and also returned for convenience.
    """
    goal_search_cfg = config.goal_search
    episodes_per_env = goal_search_cfg.goal_searching_episodes_per_env
    goal_search_cfg.goal_searching_steps_per_env = episodes_per_env * env_params.max_steps
    return config


def _ulee_goal_search_params(config: ULEETrainConfig, train_info):
    """Pick the final and best goal-search parameters out of ULEE training output.

    Returns:
        A ``(goal_search_params, best_goal_search_params)`` pair laid out according to
        ``config.goal_search_algorithm``. Both entries are ``None`` for ``"random"``.

    Raises:
        ValueError: if ``config.goal_search_algorithm`` is not ``"ppo"``, ``"diayn"``
            or ``"random"``. Without this check, an unknown value left the
            parameters unbound and surfaced as an opaque UnboundLocalError.
    """
    algorithm = config.goal_search_algorithm
    # NOTE(review): train_info["best"] is indexed as (?, meta_learner_params, goal_search_params);
    # confirm the tuple layout against ULEE's main_loop.
    if algorithm == "ppo":
        return train_info["goal_search_train_state"].params, train_info["best"][2]
    if algorithm == "diayn":
        # DIAYN goal search carries two train states: the policy and the discriminator.
        policy_train_state, discriminator_train_state = train_info["goal_search_train_state"]
        # The policy params are stored together with the policy state's `constants`;
        # the best snapshot reuses the final state's constants.
        goal_search_params = (
            {"params": policy_train_state.params, "constants": policy_train_state.constants},
            discriminator_train_state.params,
        )
        best_goal_search_params = (
            {"params": train_info["best"][2][0], "constants": policy_train_state.constants},
            train_info["best"][2][1],
        )
        return goal_search_params, best_goal_search_params
    if algorithm == "random":
        # Random goal search has no learnable parameters.
        return None, None
    raise ValueError(f"Unknown goal_search_algorithm: {algorithm!r}")


def run_ulee_training(config: ULEETrainConfig):
    """Train ULEE for one config, checkpoint the results to disk, and log metrics to wandb.

    Args:
        config: ULEE training configuration. It is completed in place by
            ``complete_config`` once ``env_params`` are known.

    Saving is best-effort. If writing the checkpoint fails, the error and its
    traceback are printed, and wandb logging still runs afterwards.
    """
    import traceback

    # setup
    (
        rng,
        env_no_goals,
        env_unsup_goals,
        env_real_goals,
        env_params,
        benchmark,
        meta_learner_train_state,
        judge_train_state,
        goal_search_train_state,
        judge_replay_buffer,
    ) = ulee_set_up_for_training(config)
    config = complete_config(config, env_params)

    # train
    print(f"Training with seed {config.train_seed}")
    t = time.time()
    # Bind everything except the rng and the train states, so the jitted function only takes those.
    full_training_partial = Partial(
        ulee_full_training,
        env_no_goals=env_no_goals,
        env_unsup_goals=env_unsup_goals,
        env_real_goals=env_real_goals,
        env_params=env_params,
        benchmark=benchmark,
        config=config,
    )
    jitted_full_training = jax.jit(full_training_partial)
    # block_until_ready so the timing below covers the actual computation.
    train_info = jax.block_until_ready(
        jitted_full_training(
            rng=rng,
            meta_learner_train_state=meta_learner_train_state,
            judge_train_state=judge_train_state,
            goal_search_train_state=goal_search_train_state,
            judge_replay_buffer=judge_replay_buffer,
        )
    )
    elapsed_time = time.time() - t
    print(f"Done in {elapsed_time / 60:.2f}min")

    try:
        # store results on disk
        save_path = build_trained_weights_path(
            algorithm_id="ulee",
            env_id=config.env_id,
            benchmark_id=config.benchmark_id,
            seed=config.train_seed,
            goal_search_algorithm=config.goal_search_algorithm,
            goal_sampling_method=config.goal_sampling_method,
        )
        # NOTE(review): the path depends only on the config, so re-running a seed targets the
        # same directory. Orbax's Checkpointer.save is expected to refuse existing paths unless
        # force=True (confirm for the installed version). To keep both runs, uncomment:
        # timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # save_path = save_path.parent / f"{save_path.name}_{timestamp}"
        save_path.parent.mkdir(parents=True, exist_ok=True)

        train_config = flax.core.freeze(dataclasses.asdict(config))
        goal_search_params, best_goal_search_params = _ulee_goal_search_params(config, train_info)

        training_results = {
            "config": train_config,
            "meta_learner_params": train_info["meta_learner_state"].params,
            "judge_params": train_info["judge_train_state"].params,
            "goal_search_params": goal_search_params,
            "best_meta_learner_params": train_info["best"][1],
            "best_goal_search_params": best_goal_search_params,
            "metrics": train_info["metrics"],
        }

        orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
        save_args = orbax_utils.save_args_from_target(training_results)
        orbax_checkpointer.save(save_path, training_results, save_args=save_args)
        print("saved training results to", save_path)
    except Exception as e:
        # Best-effort: a failed save must not block the wandb logging below,
        # but keep the full traceback so the failure can be diagnosed.
        print(f"Error while saving training results to disk: {e}")
        traceback.print_exc()

    # save logs to wandb
    run_name = generate_run_name(algorithm_name="ULEE", config=config, prefix="")
    tags = ["ulee", "train"]
    wandb_log_ulee_training_metrics(train_info["metrics"], config, run_name=run_name, tags=tags, num_final_episodes_for_evaluating_performance=10)


In [None]:
train_seeds = [10, 20, 30, 40]
goal_search_algorithms = ["ppo"] * len(train_seeds)
goal_sampling_methods = ["bounded_uniform"] * len(train_seeds)
total_timesteps = 5_000_000_000
env_id = "XLand-MiniGrid-R4-13x13"
benchmark_id = "small-1m"

# Settings shared by every run in this sweep.
shared_config_kwargs = dict(total_timesteps=total_timesteps, env_id=env_id, benchmark_id=benchmark_id)

# One ULEE run per (seed, goal-search algorithm, goal-sampling method) triple.
# The benchmark split seed is the train seed offset by 100.
for seed, search_algo, sampling_method in zip(train_seeds, goal_search_algorithms, goal_sampling_methods):
    config = ULEETrainConfig(
        train_seed=seed,
        benchmark_split_seed=seed + 100,
        goal_search_algorithm=search_algo,
        goal_sampling_method=sampling_method,
        **shared_config_kwargs,
    )
    run_ulee_training(config)

## PPO

In [None]:
def run_ppo_training(config: PPOTrainConfig, fixed_envs: bool):
    """Train a PPO agent for one config, checkpoint the results, and log metrics to wandb.

    Args:
        config: PPO training configuration.
        fixed_envs: if True, train with ``ppo_full_training_on_fixed_envs``; otherwise
            with ``ppo_full_training``. A ``"best"`` entry is read from the training
            output only when ``fixed_envs`` is False, so ``best_agent_params`` is
            checkpointed only in that case.

    Saving is best-effort. If writing the checkpoint fails, the error and its
    traceback are printed, and wandb logging still runs afterwards.
    """
    import traceback

    # setup
    rng, env, env_params, benchmark, train_state = ppo_set_up_for_training(config)

    # train
    print(f"Training with seed {config.train_seed}")
    t = time.time()
    # Both entry points are called with the same arguments, so pick one and jit it once
    # instead of duplicating the Partial / jit / block_until_ready sequence per branch.
    training_fn = ppo_full_training_on_fixed_envs if fixed_envs else ppo_full_training
    full_training_partial = Partial(training_fn, env=env, env_params=env_params, benchmark=benchmark, config=config)
    jitted_full_training = jax.jit(full_training_partial)
    train_info = jax.block_until_ready(jitted_full_training(rng=rng, train_state=train_state))
    elapsed_time = time.time() - t
    print(f"Done in {elapsed_time / 60:.2f}min")

    try:
        # store results on disk
        # NOTE(review): the path does not depend on `fixed_envs`, so fixed-env and regular
        # PPO runs with the same seed target the same checkpoint location.
        save_path = build_trained_weights_path(
            algorithm_id="ppo",
            env_id=config.env_id,
            benchmark_id=config.benchmark_id,
            seed=config.train_seed,
        )
        save_path.parent.mkdir(parents=True, exist_ok=True)

        train_config = flax.core.freeze(dataclasses.asdict(config))
        training_results = {
            "config": train_config,
            "agent_params": train_info["agent_state"].params,
        }
        if not fixed_envs:
            training_results["best_agent_params"] = train_info["best"][1]
        training_results["metrics"] = train_info["metrics"]

        orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
        save_args = orbax_utils.save_args_from_target(training_results)
        orbax_checkpointer.save(save_path, training_results, save_args=save_args)
        print("saved training results to", save_path)
    except Exception as e:
        # Best-effort: a failed save must not block the wandb logging below,
        # but keep the full traceback so the failure can be diagnosed.
        print(f"Error while saving training results to disk: {e}")
        traceback.print_exc()

    # log metrics to wandb
    extra_logs = {
        "training/lr": train_info["metrics"]["lr"],
    }
    run_name = generate_run_name(algorithm_name="PPO", config=config, prefix="")
    tags = ["ppo", "train"]
    wandb_log_training_metrics(train_info["metrics"], config, run_name, project_name="ULEE", tags=tags, extra_batch_metrics=extra_logs)


In [None]:
train_seeds = [10, 20, 30, 40]
total_timesteps = 5_000_000_000
env_id = "XLand-MiniGrid-R4-13x13"
benchmark_id = "small-1m"

# Settings shared by every seed in this sweep.
shared_config_kwargs = dict(total_timesteps=total_timesteps, env_id=env_id, benchmark_id=benchmark_id)

# One PPO run per seed, using the fixed-environment training loop (fixed_envs=True).
# The benchmark split seed is the train seed offset by 100.
for seed in train_seeds:
    config = PPOTrainConfig(train_seed=seed, benchmark_split_seed=seed + 100, **shared_config_kwargs)
    run_ppo_training(config, fixed_envs=True)


## DIAYN

In [None]:
def run_diayn_training(config: DIAYNTrainConfig, wandb_wait_seconds: float = 15):
    """Train DIAYN for one config, checkpoint the results, and log metrics to wandb.

    Args:
        config: DIAYN training configuration.
        wandb_wait_seconds: seconds to sleep after logging to wandb (default 15,
            the previously hard-coded value). NOTE(review): presumably gives the
            wandb run time to finish syncing before the next seed starts; confirm
            whether it is still needed.

    Saving is best-effort. If writing the checkpoint fails, the error and its
    traceback are printed, and wandb logging still runs afterwards.
    """
    import traceback

    # setup
    rng, env_no_goals, env_real_goals, env_params, benchmark, agent_train_state, discriminator_train_state = diayn_set_up_for_training(config)

    # train
    print(f"Training with seed {config.train_seed}")
    t = time.time()
    full_training_partial = Partial(diayn_full_training, env_no_goals=env_no_goals, env_real_goals=env_real_goals, env_params=env_params, benchmark=benchmark, config=config)
    jitted_full_training = jax.jit(full_training_partial)
    train_info = jax.block_until_ready(
        jitted_full_training(
            rng=rng,
            agent_train_state=agent_train_state,
            discriminator_train_state=discriminator_train_state,
        )
    )
    elapsed_time = time.time() - t
    print(f"Done in {elapsed_time / 60:.2f}min")

    try:
        # store results on disk
        save_path = build_trained_weights_path(
            algorithm_id="diayn",
            env_id=config.env_id,
            benchmark_id=config.benchmark_id,
            seed=config.train_seed,
        )
        save_path.parent.mkdir(parents=True, exist_ok=True)
        train_config = flax.core.freeze(dataclasses.asdict(config))
        # Agent params are stored together with the agent state's `constants`;
        # the best snapshot reuses the final state's constants.
        agent_params = {"params": train_info["agent_state"].params, "constants": train_info["agent_state"].constants}
        best_agent_params = {"params": train_info["best"][1], "constants": train_info["agent_state"].constants}
        training_results = {
            "config": train_config,
            "agent_params": agent_params,
            "best_agent_params": best_agent_params,
            "metrics": train_info["metrics"],
        }

        orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
        save_args = orbax_utils.save_args_from_target(training_results)
        orbax_checkpointer.save(save_path, training_results, save_args=save_args)
        print("saved training results to", save_path)
    except Exception as e:
        # Best-effort: a failed save must not block the wandb logging below,
        # but keep the full traceback so the failure can be diagnosed.
        print(f"Error while saving training results to disk: {e}")
        traceback.print_exc()

    # log metrics to wandb
    extra_logs = {
        "training/lr": train_info["metrics"]["lr"],
        "discriminator/discriminator_loss": train_info["metrics"]["discriminator_loss"],
        "discriminator/skills_logprob": train_info["metrics"]["skills_log_prob"],
    }
    run_name = generate_run_name(algorithm_name="DIAYN", config=config, prefix="")
    tags = ["diayn", "train"]
    wandb_log_training_metrics(train_info["metrics"], config, run_name, project_name="ULEE", tags=tags, extra_batch_metrics=extra_logs)
    time.sleep(wandb_wait_seconds)


In [None]:
train_seeds = [10, 20, 30, 40]
total_timesteps = 5_000_000_000
env_id = "XLand-MiniGrid-R4-13x13"
benchmark_id = "small-1m"

# Settings shared by every seed in this sweep.
shared_config_kwargs = dict(total_timesteps=total_timesteps, env_id=env_id, benchmark_id=benchmark_id)

# One DIAYN run per seed. The benchmark split seed is the train seed offset by 100.
for seed in train_seeds:
    config = DIAYNTrainConfig(train_seed=seed, benchmark_split_seed=seed + 100, **shared_config_kwargs)
    run_diayn_training(config)


## RL2

In [None]:
def run_rl2_training(config: RL2TrainConfig):
    """Train an RL2 agent for one config, checkpoint the results, and log metrics to wandb.

    Args:
        config: RL2 training configuration.

    The checkpoint holds the frozen config, the final and best agent params, and
    the training metrics. Saving is best-effort: if writing fails, the error and
    its traceback are printed, and wandb logging still runs afterwards.
    """
    import traceback

    # setup
    rng, env, env_params, benchmark, train_state = rl2_set_up_for_training(config)

    # train
    print(f"Training with seed {config.train_seed}")
    t = time.time()
    full_training_partial = Partial(rl2_full_training, env=env, env_params=env_params, benchmark=benchmark, config=config)
    jitted_full_training = jax.jit(full_training_partial)
    # block_until_ready so the timing below covers the actual computation.
    train_info = jax.block_until_ready(jitted_full_training(rng=rng, train_state=train_state))
    elapsed_time = time.time() - t
    print(f"Done in {elapsed_time / 60:.2f}min")

    try:
        # store results on disk
        save_path = build_trained_weights_path(
            algorithm_id="rl2",
            env_id=config.env_id,
            benchmark_id=config.benchmark_id,
            seed=config.train_seed,
        )
        save_path.parent.mkdir(parents=True, exist_ok=True)

        train_config = flax.core.freeze(dataclasses.asdict(config))
        training_results = {
            "config": train_config,
            "agent_params": train_info["agent_state"].params,
            "best_agent_params": train_info["best"][1],
            "metrics": train_info["metrics"],
        }

        orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
        save_args = orbax_utils.save_args_from_target(training_results)
        orbax_checkpointer.save(save_path, training_results, save_args=save_args)
        print("saved training results to", save_path)
    except Exception as e:
        # Best-effort: a failed save must not block the wandb logging below,
        # but keep the full traceback so the failure can be diagnosed.
        print(f"Error while saving training results to disk: {e}")
        traceback.print_exc()

    # log metrics to wandb
    extra_logs = {
        "training/lr": train_info["metrics"]["lr"],
    }
    run_name = generate_run_name(algorithm_name="RL2", config=config, prefix="")
    tags = ["rl2", "train"]
    wandb_log_training_metrics(train_info["metrics"], config, run_name, project_name="ULEE", tags=tags, extra_batch_metrics=extra_logs)


In [None]:
train_seeds = [10, 20, 30, 40]
total_timesteps = 5_000_000_000
env_id = "XLand-MiniGrid-R4-13x13"
benchmark_id = "small-1m"

# Settings shared by every seed in this sweep.
shared_config_kwargs = dict(total_timesteps=total_timesteps, env_id=env_id, benchmark_id=benchmark_id)

# One RL2 run per seed. The benchmark split seed is the train seed offset by 100.
for seed in train_seeds:
    config = RL2TrainConfig(train_seed=seed, benchmark_split_seed=seed + 100, **shared_config_kwargs)
    run_rl2_training(config)
