In [1]:
# git checkout e0bf2ef4532c2e487e7f2cc3f19fc9656c80398c

In [2]:
import dataclasses
import os
import pathlib
import pickle
import queue
from functools import partial

import jax
import jax.numpy as jnp
import matplotlib.animation as animation
import numpy as np
import wandb
from matplotlib import pyplot as plt

from cleanba import cleanba_impala
from cleanba.cleanba_impala import make_optimizer, unreplicate
from cleanba.environments import BoxobanConfig, EnvConfig
import json
import flax
import farconf
from flax.training.train_state import TrainState
from cleanba.config import Args


wandb.init(mode="disabled")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.




In [3]:
@dataclasses.dataclass
class EvalConfig:
    env: EnvConfig
    n_episode_multiple: int = 1
    steps_to_think: list[int] = dataclasses.field(default_factory=lambda: [0])
    temperature: float = 0.0

    safeguard_max_episode_steps: int = 30000

    def run(self, get_action, agent_state, *, key: jnp.ndarray) -> dict[str, float]:
        key, env_key, carry_key, obs_reset_key = jax.random.split(key, 4)
        env_seed = int(jax.random.randint(env_key, (), minval=0, maxval=2**31 - 2))
        envs = dataclasses.replace(self.env, seed=env_seed).make()
        max_steps = min(self.safeguard_max_episode_steps, self.env.max_episode_steps)

        episode_starts_no = jnp.zeros(envs.num_envs, dtype=jnp.bool_)

        metrics = {}
        try:
            for steps_to_think in self.steps_to_think:
                all_episode_returns = []
                all_episode_lengths = []
                all_episode_successes = []
                all_obs = []
                all_acts = []
                all_rewards = []
                all_level_infos = []
                envs = dataclasses.replace(self.env, seed=env_seed).make()
                reset_key = None
                for _ in range(self.n_episode_multiple):
                    reset_key, sub_reset_key = jax.random.split(obs_reset_key if reset_key is None else reset_key)
                    reset_seed = int(jax.random.randint(sub_reset_key, (), minval=0, maxval=2**31 - 2))
                    obs, level_infos = envs.reset(seed=reset_seed)
                    # reset the carry here so we can use `episode_starts_no` later

                    eps_done = np.zeros(envs.num_envs, dtype=np.bool_)
                    episode_success = np.zeros(envs.num_envs, dtype=np.bool_)
                    episode_returns = np.zeros(envs.num_envs, dtype=np.float64)
                    episode_lengths = np.zeros(envs.num_envs, dtype=np.int64)
                    episode_obs = np.zeros((max_steps+1, *obs.shape), dtype=np.int64)
                    episode_acts = np.zeros((max_steps, envs.num_envs), dtype=np.int64)
                    episode_rewards = np.zeros((max_steps, envs.num_envs), dtype=np.float64)
                    
                    episode_obs[0] = obs
                    i = 0
                    while not np.all(eps_done):
                        if i >= self.safeguard_max_episode_steps:
                            break
                        action, _, key = get_action(
                            params=agent_state,
                            next_obs=obs,
                            key=key,
                            temperature=self.temperature,
                        )

                        cpu_action = np.asarray(action)
                        obs, rewards, terminated, truncated, infos = envs.step(cpu_action)
                        episode_returns[~eps_done] += rewards[~eps_done]
                        episode_lengths[~eps_done] += 1
                        episode_success[~eps_done] |= terminated[~eps_done]  # If episode terminates it's a success

                        episode_obs[i+1, ~eps_done] = obs[~eps_done]
                        episode_acts[i, ~eps_done] = cpu_action[~eps_done]
                        episode_rewards[i, ~eps_done] = rewards[~eps_done]

                        # Set as done the episodes which are done
                        eps_done |= truncated | terminated
                        i += 1

                    all_episode_returns.append(episode_returns)
                    all_episode_lengths.append(episode_lengths)
                    all_episode_successes.append(episode_success)

                    all_obs += [episode_obs[:episode_lengths[i], i] for i in range(envs.num_envs)]
                    all_acts += [episode_acts[:episode_lengths[i], i] for i in range(envs.num_envs)]
                    all_rewards += [episode_rewards[:episode_lengths[i], i] for i in range(envs.num_envs)]

                    all_obs.append(episode_obs)
                    all_acts.append(episode_acts)
                    all_level_infos.append(level_infos)

                all_episode_returns = np.concatenate(all_episode_returns)
                all_episode_lengths = np.concatenate(all_episode_lengths)
                all_episode_successes = np.concatenate(all_episode_successes)
                all_level_infos = {k: np.concatenate([d[k] for d in all_level_infos])
                                    for k in all_level_infos[0].keys() if not k.startswith("_")}

                metrics.update(
                    {
                        f"{steps_to_think:02d}_episode_returns": float(np.mean(all_episode_returns)),
                        f"{steps_to_think:02d}_episode_lengths": float(np.mean(all_episode_lengths)),
                        f"{steps_to_think:02d}_episode_successes": float(np.mean(all_episode_successes)),
                        f"{steps_to_think:02d}_num_episodes": len(all_episode_returns),
                        f"{steps_to_think:02d}_all_episode_info": dict(
                            episode_returns=all_episode_returns,
                            episode_lengths=all_episode_lengths,
                            episode_successes=all_episode_successes,
                            episode_obs=all_obs,
                            episode_acts=all_acts,
                            episode_rewards=all_rewards,
                            level_infos=all_level_infos,
                        ),
                    }
                )
        finally:
            envs.close()
        return metrics

def save_level_video(level_idx, base_dir="./", force=False):
    base_dir = pathlib.Path(base_dir)
    base_dir.mkdir(parents=True, exist_ok=True)
    file_path = base_dir / f'{level_idx}.mp4'
    if file_path.exists() and not force:
        return
    obs_baseline = np.moveaxis(val_all_episode_info[baseline_steps]["episode_obs"][level_idx], 1, 3)
    obs_best = np.moveaxis(val_all_episode_info[best_steps]["episode_obs"][level_idx], 1, 3)
    num_obs_baseline = len(obs_baseline)
    num_obs_best = len(obs_best)
    max_obs = max(num_obs_baseline, num_obs_best)
    fig, axs = plt.subplots(1, 2)
    ax1, ax2 = axs
    ax1.set_title(f"{steps_to_think[baseline_steps]} think steps")
    ax2.set_title(f"{steps_to_think[best_steps]} think steps")
    im1 = ax1.imshow(obs_baseline[0])
    im2 = ax2.imshow(obs_best[0])
    title = fig.suptitle(f"Level {level_idx}: Step 0")

    def update_frame(j):
        baseline_img = obs_baseline[min(len(obs_baseline)-1, j)]
        # ax1.imshow(baseline_img)
        im1.set(data=baseline_img)
        best_img = obs_best[min(len(obs_best)-1, j)]
        # ax2.imshow(best_img)
        im2.set(data=best_img)
        title.set_text(f"Level {level_idx}: Step {j}")
        return (im1, im2, title)
        

    anim = animation.FuncAnimation(
        fig,
        update_frame,  # type: ignore
        frames=max_obs,
        interval=1,
        repeat=False,
    )
    plt.tight_layout()
    anim.save(file_path, fps=3)
    print(f"Level {level_idx} saved")




In [4]:
steps_to_think=[0]
n_episode_multiple = 50
num_envs = 100
episode_steps = 120
unfil = False
if unfil:
    unfil_env_cfg = EvalConfig(
        BoxobanConfig(
            split="test",
            difficulty="unfiltered",
            min_episode_steps=episode_steps,
            max_episode_steps=episode_steps,
            num_envs=num_envs,
            tinyworld_obs=True,
            seed=42,
        ),
        n_episode_multiple=n_episode_multiple,
        steps_to_think=steps_to_think,

    )

val_med_env_cfg = EvalConfig(
    BoxobanConfig(
        split="valid",
        difficulty="medium",
        min_episode_steps=episode_steps,
        max_episode_steps=episode_steps,
        num_envs=num_envs,
        tinyworld_obs=True,
        seed=42,
    ),
    n_episode_multiple=n_episode_multiple,
    steps_to_think=steps_to_think,
)

In [5]:
val_med_env_cfg

EvalConfig(env=BoxobanConfig(max_episode_steps=120, num_envs=100, seed=42, min_episode_steps=120, tinyworld_obs=True, tinyworld_render=False, terminate_on_first_box=False, reward_finished=10.0, reward_box=1.0, reward_step=-0.1, reset=False, asynchronous=True, cache_path=PosixPath('/opt/sokoban_cache'), split='valid', difficulty='medium'), n_episode_multiple=50, steps_to_think=[0], temperature=0.0, safeguard_max_episode_steps=30000)

In [6]:
base_checkpoint_path = pathlib.Path("/training/cleanba/029-long/wandb/run-20240427_080424-jojfc9yt/local-files/")
models_path = {n: base_checkpoint_path / d for n, d in [("250M", "cp_0250368000"), ("1T", "cp_1001472000")]}


In [30]:
def load_train_state(dir: pathlib.Path, env):
    with open(dir / "cfg.json", "r") as f:
        args_dict = json.load(f)
    # args_dict.pop("train_env")
    args_dict.pop("eval_envs")
    args = farconf.from_dict(args_dict, Args)

    _, _, params = args.net.init_params(env, jax.random.PRNGKey(1234))
    print(params["params"].keys())
    print(['.'.join([str(k)[2:-2] for k in p]) for p,l in jax.tree_util.tree_leaves_with_path(params["params"])])

    local_batch_size = int(args.local_num_envs * args.num_steps * args.num_actor_threads * len(args.actor_device_ids))

    target_state = TrainState.create(
        apply_fn=None,
        params=params["params"],
        tx=make_optimizer(args, params, total_updates=args.total_timesteps // local_batch_size),
    )

    with open(dir / "model", "rb") as f:
        train_state = flax.serialization.from_bytes(target_state, f.read())
    assert isinstance(train_state, TrainState)
    return args, train_state


import os
# os.environ["XLA_FLAGS"]="--xla_gpu_strict_conv_algorithm_picker=false"
os.environ["XLA_FLAGS"]="--xla_gpu_autotune_level=0"

base_path = pathlib.Path("/training/cleanba/logs/data/latest_resnet/")
for name, model_path in models_path.items():
    base_cache_path = base_path / name
    if not (base_cache_path / "val_log_dict.pkl").exists() or True:
        base_cache_path.mkdir(parents=True, exist_ok=True)
        args, train_state = load_train_state(model_path, val_med_env_cfg.env.make())
        prng_key = jax.random.PRNGKey(0)
        # policy, carry_t, _ = args.net.init_params(val_med_env_cfg.env.make(), prng_key)
        params = train_state.params
        params = jax.tree_util.tree_map((lambda x: x.squeeze(0) if x.shape[0] == 1 else x), params)

        if unfil:
            unfil_log_dict = unfil_env_cfg.run(args.net.get_action, params, key=prng_key)
            unfil_all_episode_info = unfil_log_dict.pop(f"{0:02d}_all_episode_info")

            print("finished unfiltered")
            with open(base_cache_path / "unfil_log_dict.pkl", "wb") as f:
                pickle.dump(unfil_log_dict, f)
            with open(base_cache_path / "unfil_all_episode_info.pkl", "wb") as f:
                pickle.dump(unfil_all_episode_info, f)

        val_log_dict = val_med_env_cfg.run(args.net.get_action, params, key=prng_key)
        val_all_episode_info = val_log_dict.pop(f"{0:02d}_all_episode_info")

        with open(base_cache_path / "val_log_dict.pkl", "wb") as f:
            pickle.dump(val_log_dict, f)
        with open(base_cache_path / "val_all_episode_info.pkl", "wb") as f:
            pickle.dump(val_all_episode_info, f)

    else:
        print("loading logs")
        if unfil:
            with open(base_cache_path / "unfil_log_dict.pkl", "rb") as f:
                unfil_log_dict = pickle.load(f)
            with open(base_cache_path / "unfil_all_episode_info.pkl", "rb") as f:
                unfil_all_episode_info = pickle.load(f)
        with open(base_cache_path / "val_log_dict.pkl", "rb") as f:
            val_log_dict = pickle.load(f)
        with open(base_cache_path / "val_all_episode_info.pkl", "rb") as f:
            val_all_episode_info = pickle.load(f)
    print(name)
    print("Success rate:", val_log_dict["00_episode_successes"])
    print("Mean episode length:", val_log_dict["00_episode_lengths"])
    print("Mean episode return:", val_log_dict["00_episode_returns"])

  gym.logger.warn(
  self.pid = os.fork()


dict_keys(['network_params', 'actor_params', 'critic_params'])
['actor_params.Output.bias', 'actor_params.Output.kernel', 'critic_params.Output.bias', 'critic_params.Output.kernel', 'network_params.Dense_0.bias', 'network_params.Dense_0.kernel', 'network_params.GuezConvSequence_0.GuezResidualBlock_0.Conv_0.bias', 'network_params.GuezConvSequence_0.GuezResidualBlock_0.Conv_0.kernel', 'network_params.GuezConvSequence_0.GuezResidualBlock_1.Conv_0.bias', 'network_params.GuezConvSequence_0.GuezResidualBlock_1.Conv_0.kernel', 'network_params.GuezConvSequence_0.xXx_Input_xXx.bias', 'network_params.GuezConvSequence_0.xXx_Input_xXx.kernel', 'network_params.GuezConvSequence_1.Conv_0.bias', 'network_params.GuezConvSequence_1.Conv_0.kernel', 'network_params.GuezConvSequence_1.GuezResidualBlock_0.Conv_0.bias', 'network_params.GuezConvSequence_1.GuezResidualBlock_0.Conv_0.kernel', 'network_params.GuezConvSequence_1.GuezResidualBlock_1.Conv_0.bias', 'network_params.GuezConvSequence_1.GuezResidualBloc

ValueError: The target dict keys and state dict keys do not match, target dict contains keys {'GuezConvSequence_2', 'GuezConvSequence_8', 'GuezConvSequence_1', 'GuezConvSequence_5', 'GuezConvSequence_0', 'GuezConvSequence_6', 'Dense_0', 'GuezConvSequence_7', 'GuezConvSequence_4', 'GuezConvSequence_3'} which are not present in state dict at path ./params/network_params

In [18]:
with open(model_path / "model", "rb") as f:
    sd = flax.serialization.msgpack_restore(f.read())

In [None]:
'network_params.GuezConvSequence_0.GuezResidualBlock_0.Conv_0.kernel',
 'network_params.params.GuezConvSequence_0.GuezResidualBlock_0.Conv_0.kernel',


In [29]:
['.'.join([str(k)[2:-2] for k in p]) for p,l in jax.tree_util.tree_leaves_with_path(sd["params"])]

['actor_params.params.Output.bias',
 'actor_params.params.Output.kernel',
 'critic_params.params.Output.bias',
 'critic_params.params.Output.kernel',
 'network_params.params.Dense_0.bias',
 'network_params.params.Dense_0.kernel',
 'network_params.params.GuezConvSequence_0.GuezResidualBlock_0.Conv_0.bias',
 'network_params.params.GuezConvSequence_0.GuezResidualBlock_0.Conv_0.kernel',
 'network_params.params.GuezConvSequence_0.GuezResidualBlock_1.Conv_0.bias',
 'network_params.params.GuezConvSequence_0.GuezResidualBlock_1.Conv_0.kernel',
 'network_params.params.GuezConvSequence_0.xXx_Input_xXx.bias',
 'network_params.params.GuezConvSequence_0.xXx_Input_xXx.kernel',
 'network_params.params.GuezConvSequence_1.Conv_0.bias',
 'network_params.params.GuezConvSequence_1.Conv_0.kernel',
 'network_params.params.GuezConvSequence_1.GuezResidualBlock_0.Conv_0.bias',
 'network_params.params.GuezConvSequence_1.GuezResidualBlock_0.Conv_0.kernel',
 'network_params.params.GuezConvSequence_1.GuezResidualB

In [13]:
params.keys()

NameError: name 'params' is not defined

### Plots

### Videos

In [8]:
# # solved but better returns
# do_save = False
# if do_save:
#     saved = 0
#     for level_idx in solved_better_returns:
#         if levels_with_same_obs[level_idx]:
#             continue
#         save_level_video(level_idx, base_dir="resnet/solved_but_better_returns/")
#         saved += 1
#         if saved >= 10:
#             break

In [9]:
# # solved with thinking more
# if do_save:
#     saved = 0
#     for level_idx in improved_level_list:
#         save_level_video(level_idx, base_dir="thinking_solves_unsolved/")
#         saved += 1
#         if saved >= 10:
#             break

### Collapse

In [10]:
# writer = cleanba_impala.WandbWriter(args)
# param_queue = queue.Queue(maxsize=1)
# rollout_queue = queue.Queue(maxsize=1)
# learner_policy_version = 0
# unreplicated_params = train_state.params
# with cleanba_impala.initialize_multi_device(args) as runtime_info:
#     device_params = jax.device_put(unreplicated_params, runtime_info.local_devices[0])
#     param_queue.put((device_params, learner_policy_version))
#     prng_key = jax.random.PRNGKey(0)
#     cleanba_impala.rollout(
#         prng_key,
#         args,
#         runtime_info,
#         rollout_queue,
#         param_queue,
#         writer,
#         runtime_info.learner_devices,
#         0,
#         runtime_info.local_devices[0],
#     )

# import glob

# all_log_levels = []
# for filename in glob.glob("/training/.sokoban_cache/boxoban-levels-master/medium/valid/logs/*"):
#     try:
#         file_idx, lev_idx = (int(c) for c in filename.split("/")[-1].split(".")[0].split("_")[1:])
#     except ValueError:
#         continue
#     all_log_levels.append((file_idx, lev_idx))

# # find file_idx, lev_idx not in all_log_levels
# not_present = []
# for file_idx in range(50):
#     for lev_idx in range(1000):
#         if (file_idx, lev_idx) not in all_log_levels:
#             not_present.append((file_idx, lev_idx))
# len(not_present)