In [None]:
import dataclasses
import os
import pathlib
import pickle
import queue
from functools import partial

import jax
import jax.numpy as jnp
import matplotlib.animation as animation
import numpy as np
import wandb
from matplotlib import pyplot as plt

from cleanba import cleanba_impala
from cleanba.environments import BoxobanConfig, EnvConfig
from cleanba.network import Policy

# Start wandb in "disabled" mode so that wandb calls made while loading/evaluating the
# checkpoint become no-ops. NOTE(review): presumably needed because cleanba code logs to
# wandb and expects an initialised run — confirm.
wandb.init(mode="disabled")

In [None]:
@dataclasses.dataclass
class EvalConfig:
    """Evaluate a recurrent policy on an environment under several "thinking" budgets.

    For each entry of ``steps_to_think`` the policy's carry is updated that many times on
    the initial observation (without stepping the environment) before episodes are rolled
    out with ``get_action_fn``.
    """

    env: EnvConfig
    # Number of batches (of ``env.num_envs`` episodes each) run per thinking budget.
    n_episode_multiple: int = 1
    # Numbers of extra carry updates on the first observation to evaluate.
    steps_to_think: list[int] = dataclasses.field(default_factory=lambda: [0])
    # Forwarded to ``get_action_fn`` as ``temperature=``.
    temperature: float = 0.0

    # Hard cap on environment steps per batch, in case episodes never finish.
    safeguard_max_episode_steps: int = 30000

    def run(self, policy: Policy, get_action_fn, params, *, key: jnp.ndarray) -> dict[str, float]:
        """Roll out ``n_episode_multiple`` batches of episodes for every thinking budget.

        Args:
            policy: Network whose ``initialize_carry`` method creates the recurrent state.
            get_action_fn: Called as ``get_action_fn(params, carry, obs, episode_starts, key,
                temperature=...)``; must return ``(carry, action, _, key)``.
            params: Parameters passed to ``policy.apply`` and ``get_action_fn``.
            key: JAX PRNG key, split into env-seed, carry-init and reset-seed keys.

        Returns:
            A flat dict keyed by ``f"{steps:02d}_..."`` holding mean episode returns, lengths
            and successes, the number of episodes, and (under ``_all_episode_info``) a dict
            of raw per-episode data: arrays of returns/lengths/successes, lists with one
            obs/action/reward array per episode (trimmed to the episode length), and the
            concatenated level infos. Note: not every value is a float despite the annotation.
        """
        key, env_key, carry_key, obs_reset_key = jax.random.split(key, 4)
        env_seed = int(jax.random.randint(env_key, (), minval=0, maxval=2**31 - 2))
        # Rollout cap per batch; also the size of the per-batch buffers allocated below.
        max_steps = min(self.safeguard_max_episode_steps, self.env.max_episode_steps)

        metrics = {}
        envs = None
        try:
            for steps_to_think in self.steps_to_think:
                # Fresh envs with the same seed (and, below, the same sequence of reset seeds)
                # for every budget. Close the previous envs first so they are not leaked.
                if envs is not None:
                    envs.close()
                envs = dataclasses.replace(self.env, seed=env_seed).make()
                num_envs = envs.num_envs
                # Always-False episode-start flags: the carry is instead re-initialised
                # explicitly at the start of every batch.
                episode_starts_no = jnp.zeros(num_envs, dtype=jnp.bool_)

                all_episode_returns = []
                all_episode_lengths = []
                all_episode_successes = []
                all_obs = []
                all_acts = []
                all_rewards = []
                all_level_infos = []
                reset_key = obs_reset_key
                for _ in range(self.n_episode_multiple):
                    reset_key, sub_reset_key = jax.random.split(reset_key)
                    reset_seed = int(jax.random.randint(sub_reset_key, (), minval=0, maxval=2**31 - 2))
                    obs, level_infos = envs.reset(seed=reset_seed)
                    # Reset the carry here so we can keep passing `episode_starts_no` later.
                    carry = policy.apply(params, carry_key, obs.shape, method=policy.initialize_carry)

                    # "Think": update the carry on the initial observation many times.
                    for _ in range(steps_to_think):
                        carry, _, _, key = get_action_fn(
                            params, carry, obs, episode_starts_no, key, temperature=self.temperature
                        )

                    eps_done = np.zeros(num_envs, dtype=np.bool_)
                    episode_success = np.zeros(num_envs, dtype=np.bool_)
                    episode_returns = np.zeros(num_envs, dtype=np.float64)
                    episode_lengths = np.zeros(num_envs, dtype=np.int64)
                    episode_obs = np.zeros((max_steps + 1, *obs.shape), dtype=np.int64)
                    episode_acts = np.zeros((max_steps, num_envs), dtype=np.int64)
                    episode_rewards = np.zeros((max_steps, num_envs), dtype=np.float64)

                    episode_obs[0] = obs
                    step = 0
                    # Stop at `max_steps` at the latest: the buffers have no room for more.
                    while not np.all(eps_done) and step < max_steps:
                        carry, action, _, key = get_action_fn(
                            params, carry, obs, episode_starts_no, key, temperature=self.temperature
                        )

                        cpu_action = np.asarray(action)
                        obs, rewards, terminated, truncated, _ = envs.step(cpu_action)
                        active = ~eps_done
                        episode_returns[active] += rewards[active]
                        episode_lengths[active] += 1
                        # If an episode terminates (rather than truncates) it's a success.
                        episode_success[active] |= terminated[active]

                        episode_obs[step + 1, active] = obs[active]
                        episode_acts[step, active] = cpu_action[active]
                        episode_rewards[step, active] = rewards[active]

                        # Mark finished episodes so later steps no longer touch them.
                        eps_done |= truncated | terminated
                        step += 1

                    all_episode_returns.append(episode_returns)
                    all_episode_lengths.append(episode_lengths)
                    all_episode_successes.append(episode_success)

                    # Exactly one entry per episode, trimmed to its length, so that list
                    # index == episode index. `.copy()` releases the padded batch buffers.
                    lengths = episode_lengths.tolist()
                    all_obs += [episode_obs[:n, env_idx].copy() for env_idx, n in enumerate(lengths)]
                    all_acts += [episode_acts[:n, env_idx].copy() for env_idx, n in enumerate(lengths)]
                    all_rewards += [episode_rewards[:n, env_idx].copy() for env_idx, n in enumerate(lengths)]
                    all_level_infos.append(level_infos)

                all_episode_returns = np.concatenate(all_episode_returns)
                all_episode_lengths = np.concatenate(all_episode_lengths)
                all_episode_successes = np.concatenate(all_episode_successes)
                # Skip "_"-prefixed keys when merging the per-batch info dicts.
                all_level_infos = {
                    k: np.concatenate([d[k] for d in all_level_infos])
                    for k in all_level_infos[0].keys()
                    if not k.startswith("_")
                }

                metrics.update(
                    {
                        f"{steps_to_think:02d}_episode_returns": float(np.mean(all_episode_returns)),
                        f"{steps_to_think:02d}_episode_lengths": float(np.mean(all_episode_lengths)),
                        f"{steps_to_think:02d}_episode_successes": float(np.mean(all_episode_successes)),
                        f"{steps_to_think:02d}_num_episodes": len(all_episode_returns),
                        f"{steps_to_think:02d}_all_episode_info": dict(
                            episode_returns=all_episode_returns,
                            episode_lengths=all_episode_lengths,
                            episode_successes=all_episode_successes,
                            episode_obs=all_obs,
                            episode_acts=all_acts,
                            episode_rewards=all_rewards,
                            level_infos=all_level_infos,
                        ),
                    }
                )
        finally:
            if envs is not None:
                envs.close()
        return metrics


In [None]:
# Settings shared by both evaluations: the thinking budgets to try, how many batches to run
# per budget, and how many envs per batch (n_episode_multiple * num_envs episodes per budget).
steps_to_think = [0, 1, 2, 3, 4, 6, 8, 10]
n_episode_multiple = 20
num_envs = 100


def _boxoban_eval_cfg(split: str, difficulty: str) -> EvalConfig:
    """Build an EvalConfig for one Boxoban split/difficulty using the shared settings."""
    level_set = BoxobanConfig(
        split=split,
        difficulty=difficulty,
        min_episode_steps=240,
        max_episode_steps=240,
        num_envs=num_envs,
        tinyworld_obs=True,
        seed=42,
    )
    return EvalConfig(
        level_set,
        n_episode_multiple=n_episode_multiple,
        steps_to_think=steps_to_think,
    )


unfil_env_cfg = _boxoban_eval_cfg("test", "unfiltered")
val_med_env_cfg = _boxoban_eval_cfg("valid", "medium")

In [None]:
# Display the validation/medium evaluation config as this cell's output.
val_med_env_cfg

In [None]:
# Evaluate the checkpoint on both level sets once and cache the results as pickles; when
# every cache file exists, load them instead of re-running the (slow) evaluation.
base_path = pathlib.Path("/training/cleanba/logs/data")
cache_files = {
    name: base_path / f"{name}.pkl"
    for name in ("unfil_log_dict", "unfil_all_episode_info", "val_log_dict", "val_all_episode_info")
}

if not all(cache_file.exists() for cache_file in cache_files.values()):
    path = pathlib.Path("/training/cleanba/044-more-planners/wandb/run-20240506_043059-6zhw6cw1/local-files/cp_208000000")
    args, train_state = cleanba_impala.load_train_state(path)
    prng_key = jax.random.PRNGKey(0)
    policy, carry_t, _ = args.net.init_params(args.train_env.make(), prng_key)
    get_action_fn = jax.jit(partial(policy.apply, method=policy.get_action), static_argnames="temperature")
    params = train_state.params

    # Both evaluations are run with the same PRNG key.
    unfil_log_dict = unfil_env_cfg.run(policy, get_action_fn, params, key=prng_key)
    # Move the bulky per-episode data out of the scalar-metrics dict (one entry per budget).
    unfil_all_episode_info = [unfil_log_dict.pop(f"{n:02d}_all_episode_info") for n in steps_to_think]

    print("finished unfiltered")

    val_log_dict = val_med_env_cfg.run(policy, get_action_fn, params, key=prng_key)
    val_all_episode_info = [val_log_dict.pop(f"{n:02d}_all_episode_info") for n in steps_to_think]

    base_path.mkdir(parents=True, exist_ok=True)
    to_cache = {
        "unfil_log_dict": unfil_log_dict,
        "unfil_all_episode_info": unfil_all_episode_info,
        "val_log_dict": val_log_dict,
        "val_all_episode_info": val_all_episode_info,
    }
    for name, obj in to_cache.items():
        with open(cache_files[name], "wb") as f:
            pickle.dump(obj, f)

else:
    print("loading logs")
    # pickle.load can execute arbitrary code: only load caches this notebook wrote itself.
    loaded = {}
    for name, cache_file in cache_files.items():
        with open(cache_file, "rb") as f:
            loaded[name] = pickle.load(f)
    unfil_log_dict = loaded["unfil_log_dict"]
    unfil_all_episode_info = loaded["unfil_all_episode_info"]
    val_log_dict = loaded["val_log_dict"]
    val_all_episode_info = loaded["val_all_episode_info"]

In [None]:
# Display the scalar validation metrics (the per-episode info was popped off into
# `val_all_episode_info` when the evaluation was computed).
val_log_dict

In [None]:
# One figure per evaluated level set: success rate (left axis, red) and mean return
# (right axis, blue) as a function of the number of thinking steps.
env_names = ["unfiltered_test", "valid_medium"]
for env_name, log_dict in zip(env_names, [unfil_log_dict, val_log_dict]):
    episode_successes = [log_dict[f"{n:02d}_episode_successes"] for n in steps_to_think]
    episode_returns = [log_dict[f"{n:02d}_episode_returns"] for n in steps_to_think]

    fig, ax1 = plt.subplots()
    # Set the title before tight_layout() so the layout reserves room for it.
    ax1.set_title(env_name)

    color = "tab:red"
    ax1.set_xlabel("steps_to_think")
    ax1.set_ylabel("episode_successes", color=color)
    ax1.plot(steps_to_think, episode_successes, color=color)
    ax1.tick_params(axis="y", labelcolor=color)

    ax2 = ax1.twinx()  # second y-axis sharing ax1's x-axis

    color = "tab:blue"
    ax2.set_ylabel("episode_returns", color=color)  # the x-label is already set on ax1
    ax2.plot(steps_to_think, episode_returns, color=color)
    ax2.tick_params(axis="y", labelcolor=color)

    fig.tight_layout()  # otherwise the right y-label is slightly clipped
    plt.show()

In [None]:
# Bucket every validation level by comparing the baseline budget (index `baseline_steps`
# into `steps_to_think`) with the 6-thinking-step budget: did it become (un)solved, and
# if not, did its return go up, go down, or stay the same?
num_levels = len(val_all_episode_info[0]["episode_successes"])
improved_level_list = []
impaired_level_list = []
solved_better_returns = []
solved_worse_returns = []
unsolved_better_returns = []
unsolved_worse_returns = []
same_return_and_solve = []
baseline_steps = 0
best_steps = steps_to_think.index(6)

baseline_info = val_all_episode_info[baseline_steps]
best_info = val_all_episode_info[best_steps]

for i in range(num_levels):
    base_ok = baseline_info["episode_successes"][i]
    best_ok = best_info["episode_successes"][i]
    base_ret = baseline_info["episode_returns"][i]
    best_ret = best_info["episode_returns"][i]
    both_solved = base_ok and best_ok
    neither_solved = not (base_ok or best_ok)

    if base_ok < best_ok:
        bucket = improved_level_list
    elif base_ok > best_ok:
        bucket = impaired_level_list
    elif both_solved and best_ret > base_ret:
        bucket = solved_better_returns
    elif both_solved and best_ret < base_ret:
        bucket = solved_worse_returns
    elif neither_solved and best_ret > base_ret:
        bucket = unsolved_better_returns
    elif neither_solved and best_ret < base_ret:
        bucket = unsolved_worse_returns
    else:
        bucket = same_return_and_solve
    bucket.append(i)


def _pct(levels):
    """Percentage of all validation levels that fall in `levels`."""
    return len(levels) / num_levels * 100


improved_pc = _pct(improved_level_list)
impaired_pc = _pct(impaired_level_list)
solved_better_pc = _pct(solved_better_returns)
solved_worse_pc = _pct(solved_worse_returns)
unsolved_better_pc = _pct(unsolved_better_returns)
unsolved_worse_pc = _pct(unsolved_worse_returns)
same_return_and_solve_pc = _pct(same_return_and_solve)

# print all fractions
for label, pct in [
    ("Solved previously unsolved:\t", improved_pc),
    ("Unsolved previously solved:\t", impaired_pc),
    ("Solved w/ better returns:\t", solved_better_pc),
    ("Solved but worse returns:\t", solved_worse_pc),
    ("Unsolved but better returns:\t", unsolved_better_pc),
    ("Unsolved w/ worse returns:\t", unsolved_worse_pc),
    ("Same return & solve:\t\t", same_return_and_solve_pc),
]:
    print(f"{label}{pct:.2f}%")

In [None]:
def save_level_video(level_idx, base_dir="./"):
    """Save a side-by-side MP4 of one validation level at the baseline and best budgets.

    Reads the notebook globals ``val_all_episode_info``, ``baseline_steps``, ``best_steps``
    and ``steps_to_think``. The video is written to ``<base_dir>/<level_idx>.mp4`` at 3 fps.

    Args:
        level_idx: Index into ``val_all_episode_info[k]["episode_obs"]``.
        base_dir: Output directory; created if it does not exist.
    """
    # NOTE(review): assumes stored obs are (T, C, H, W); moveaxis makes them (T, H, W, C)
    # for imshow — confirm against the env's observation layout.
    obs_baseline = np.moveaxis(val_all_episode_info[baseline_steps]["episode_obs"][level_idx], 1, 3)
    obs_best = np.moveaxis(val_all_episode_info[best_steps]["episode_obs"][level_idx], 1, 3)
    num_frames = max(len(obs_baseline), len(obs_best))

    fig, (ax1, ax2) = plt.subplots(1, 2)
    ax1.set_title(f"{steps_to_think[baseline_steps]} think steps")
    ax2.set_title(f"{steps_to_think[best_steps]} think steps")
    im1 = ax1.imshow(obs_baseline[0])
    im2 = ax2.imshow(obs_best[0])
    title = fig.suptitle(f"Level {level_idx}: Step 0")

    def update_frame(j):
        # The shorter episode stays frozen while the longer one keeps playing.
        # NOTE(review): frames are clamped to the second-to-last stored observation;
        # confirm why the last one is skipped.
        im1.set(data=obs_baseline[min(len(obs_baseline) - 2, j)])
        im2.set(data=obs_best[min(len(obs_best) - 2, j)])
        title.set_text(f"Level {level_idx}: Step {j}")
        return (im1, im2, title)

    anim = animation.FuncAnimation(
        fig,
        update_frame,  # type: ignore
        frames=num_frames,
        interval=1,
        repeat=False,
    )
    fig.tight_layout()
    out_dir = pathlib.Path(base_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    try:
        anim.save(out_dir / f"{level_idx}.mp4", fps=3)
    finally:
        # Free the figure; otherwise calling this for many levels accumulates open figures.
        plt.close(fig)
    print(f"Level {level_idx} saved")

### Time to place each box vs. number of thinking steps

In [None]:
# Average environment step at which each box is placed, for every thinking budget.
# Placement events are detected from the per-step reward.
reward_for_placing_box = 0.9  # presumably -0.1 step penalty + 1.0 box-on-target bonus
reward_for_placing_last_box = -0.1 + 1.0 + 10.0
num_boxes = 4  # matches the "Box 1".."Box 4" legend of the plots below


def _event_times(reward_arrays, target):
    """Per episode, the step indices whose reward matches `target`.

    Uses np.isclose rather than `==` in case rewards went through float32, where values
    such as 0.9 are not exactly representable.
    """
    return [np.nonzero(np.isclose(rewards, target))[0] for rewards in reward_arrays]


def _mean_or_nan(values):
    """Mean of `values`, or NaN (drawn as a gap in the plot) when there are none."""
    return float(np.mean(values)) if len(values) > 0 else float("nan")


time_across_think_steps = []
condition_on_improved_levels = True
for j in range(len(steps_to_think)):
    level_rewards = val_all_episode_info[j]["episode_rewards"]
    if condition_on_improved_levels:
        # Restrict *every* box (including the last one) to the same subset of levels, so all
        # curves describe the levels named in the plot title. Improved levels are unsolved at
        # the baseline budget, so the last-box entry there is NaN.
        level_rewards = [level_rewards[level_idx] for level_idx in improved_level_list]

    box_times = _event_times(level_rewards, reward_for_placing_box)
    avg_time_box_placed = [
        _mean_or_nan([t[box_idx] for t in box_times if len(t) > box_idx])
        for box_idx in range(num_boxes - 1)
    ]
    # Placing the last box ends the episode, so take its first (only) occurrence.
    last_box_times = [
        t[0] for t in _event_times(level_rewards, reward_for_placing_last_box) if len(t) > 0
    ]
    avg_time_box_placed.append(_mean_or_nan(last_box_times))
    time_across_think_steps.append(avg_time_box_placed)


In [None]:
# One line per box: average step at which it is placed, against the thinking budget.
fig, ax = plt.subplots()
ax.plot(steps_to_think, time_across_think_steps)
ax.legend(["Box 1", "Box 2", "Box 3", "Box 4"])
ax.set_xlabel("Steps to think")
ax.set_ylabel("Avg timesteps to place the box")
ax.set_title(
    "On levels where thinking more solves an unsolved level"
    if condition_on_improved_levels
    else "On all levels"
)
plt.show()


In [None]:
# Average step at which each box is placed vs. thinking budget. The title follows
# `condition_on_improved_levels` instead of always claiming the improved-level subset.
plt.plot(steps_to_think, time_across_think_steps)
plt.legend(["Box 1", "Box 2", "Box 3", "Box 4"])
plt.xlabel("Steps to think")
plt.ylabel("Avg timesteps to place the box")
if condition_on_improved_levels:
    plt.title("On levels where thinking more solves an unsolved level")
else:
    plt.title("On all levels")
plt.show()


### Levels solved with and without thinking, but with a better return after thinking


In [None]:
# Sanity check (replaces a no-op self-assignment): every per-episode list must have exactly
# one entry per level; otherwise level indices such as those passed to `save_level_video`
# select the wrong episode.
for info in val_all_episode_info:
    num_levels_in_info = len(info["episode_successes"])
    for field in ("episode_obs", "episode_acts", "episode_rewards"):
        assert len(info[field]) == num_levels_in_info, (
            f"{field} has {len(info[field])} entries for {num_levels_in_info} levels; "
            "the cached evaluation is misaligned, regenerate it"
        )

In [None]:
# Render a video only for the first level (if any) that was solved with and without
# thinking but got a better return after thinking.
for level_idx in solved_better_returns[:1]:
    save_level_video(level_idx, base_dir="solved_better_returns/")

### Collapse

In [None]:
# Disabled: runs a single cleanba_impala.rollout with the loaded parameters, kept for
# reference. It needs `args` and `train_state`, which are only defined when the evaluation
# cell actually loads the checkpoint (not when it reads the cached pickles).
# writer = cleanba_impala.WandbWriter(args)
# param_queue = queue.Queue(maxsize=1)
# rollout_queue = queue.Queue(maxsize=1)
# learner_policy_version = 0
# unreplicated_params = train_state.params
# with cleanba_impala.initialize_multi_device(args) as runtime_info:
#     device_params = jax.device_put(unreplicated_params, runtime_info.local_devices[0])
#     param_queue.put((device_params, learner_policy_version))
#     prng_key = jax.random.PRNGKey(0)
#     cleanba_impala.rollout(
#         prng_key,
#         args,
#         runtime_info,
#         rollout_queue,
#         param_queue,
#         writer,
#         runtime_info.learner_devices,
#         0,
#         runtime_info.local_devices[0],
#     )