In [1]:
import contextlib
import tempfile
import threading
from functools import partial
from pathlib import Path
from typing import Callable, Dict, Iterator, Optional
import os

import gymnasium as gym
import numpy as np
from flax.training.train_state import TrainState
from gymnasium import spaces
from gymnasium.envs.classic_control.cartpole import CartPoleEnv
from gymnasium.wrappers import NormalizeObservation

from cleanba.cleanba_impala import WandbWriter, load_train_state, train
from cleanba.config import Args
from cleanba.convlstm import ConvConfig, ConvLSTMConfig
from cleanba.network import GuezResNetConfig
from cleanba.environments import EnvConfig
from cleanba.evaluate import EvalConfig

In [2]:
# TODO: use generic Writer interface, this is not correct inheritance
class CheckingWriter(WandbWriter):
    """Writer that records logged scalars and sanity-checks every checkpoint save.

    It does not call ``WandbWriter.__init__``; instead it sets the attributes used here
    (``_save_dir`` and ``step_digits`` are presumably read by the inherited ``save_dir`` —
    TODO confirm against ``WandbWriter``).

    Args:
        cfg: training configuration; every saved checkpoint must load back equal to it.
        save_dir: root directory for checkpoints.
        eval_keys: metric names that must all be logged for a step before it is saved.
    """

    def __init__(self, cfg: Args, save_dir: Path, eval_keys):
        self.last_global_step = -1
        # Metrics logged at `last_global_step` only.
        self.metrics = {}
        self._save_dir = save_dir

        self.eval_keys = set(eval_keys)
        assert len(self.eval_keys) > 0
        # One event per eval key; set once that key has been logged for the current eval step.
        self.eval_events = {k: threading.Event() for k in self.eval_keys}

        self._args = cfg
        self.step_digits = 4
        self.eval_metrics = {}
        self.eval_global_step = -1
        # Cleared while a save is in progress, set otherwise.
        self.done_saving = threading.Event()
        self.done_saving.set()

    def add_scalar(self, name: str, value: int | float, global_step: int):
        """Record ``value`` under ``name`` for ``global_step``.

        ``self.metrics`` keeps only the metrics of the most recent step. Metrics whose name
        is one of the eval keys are also stored in ``self.eval_metrics`` and signal their
        event so that ``save_dir`` knows they have arrived.
        """
        if global_step != self.last_global_step:
            # A new step has started: forget the metrics of the previous step.
            self.metrics.clear()

        self.last_global_step = global_step
        self.metrics[name] = value

        if name in self.eval_events:
            if self.eval_global_step != global_step:
                # First eval metric of a new eval step: wait (at most 10 s) for a save of
                # the previous eval step to finish before discarding its metrics.
                self.done_saving.wait(10)
                self.eval_metrics.clear()

            self.eval_global_step = global_step
            self.eval_events[name].set()
            self.eval_metrics[name] = value

    @contextlib.contextmanager
    def save_dir(self, global_step: int) -> Iterator[Path]:
        """Wrap the parent's ``save_dir`` and verify the checkpoint written inside it.

        Waits (up to 5 s per key) for the eval metrics to be logged. After the caller has
        written into the yielded directory, it asserts that the save step matches the last
        logged step and that every eval key was logged. It then reloads the checkpoint and
        checks that it matches the config and contains a ``TrainState``.
        """
        for event in self.eval_events.values():
            event.wait(timeout=5)

        # Mark a save as in progress so `add_scalar` holds off clearing eval metrics.
        self.done_saving.clear()
        try:
            with super().save_dir(global_step) as checkpoint_dir:
                yield checkpoint_dir

                assert self.last_global_step == global_step, "we want to save with the same step as last metrics"
                assert all(
                    k in self.eval_metrics for k in self.eval_keys
                ), f"One of {self.eval_keys=} not present in {list(self.eval_metrics.keys())=}"
        finally:
            # Reset for the next save even if the save or the checks above failed, so that
            # waiters in `add_scalar` are released immediately.
            for event in self.eval_events.values():
                event.clear()
            self.done_saving.set()

        args, train_state = load_train_state(checkpoint_dir)
        assert args == self._args
        assert isinstance(train_state, TrainState)


In [3]:
class CartPoleCHWEnv(CartPoleEnv):
    """CartPoleEnv whose full 4-dim observation is reshaped to CHW, i.e. shape ``(4, 1, 1)``.

    Velocity information is kept (unlike ``CartPoleNoVelEnv``).
    """

    def __init__(self):
        super().__init__()
        # Velocities are unbounded; use the largest float32 like gymnasium's CartPoleEnv does.
        float32_max = np.finfo(np.float32).max
        high = np.array(
            [
                self.x_threshold * 2,
                float32_max,
                self.theta_threshold_radians * 2,
                float32_max,
            ],
            dtype=np.float32,
        )[:, None, None]
        self.observation_space = spaces.Box(-high, high, dtype=np.float32)

    @staticmethod
    def _pos_obs(full_obs):
        # Copy and reshape (4,) -> (4, 1, 1); dtype pinned to match observation_space.
        return np.array(full_obs, dtype=np.float32)[:, None, None]

    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
        full_obs, info = super().reset(seed=seed, options=options)
        return CartPoleCHWEnv._pos_obs(full_obs), info

    def step(self, action):
        full_obs, rew, terminated, truncated, info = super().step(action)
        return CartPoleCHWEnv._pos_obs(full_obs), rew, terminated, truncated, info


class CartPoleNoVelEnv(CartPoleEnv):
    """Variant of CartPoleEnv with velocity information removed, and CHW-shaped observations.

    Observations are ``[x, theta]`` reshaped to ``(2, 1, 1)``. This task requires memory to solve.
    """

    def __init__(self):
        super().__init__()
        high = np.array(
            [
                self.x_threshold * 2,
                self.theta_threshold_radians * 2,
            ],
            dtype=np.float32,
        )[:, None, None]
        self.observation_space = spaces.Box(-high, high, dtype=np.float32)

    @staticmethod
    def _pos_obs(full_obs):
        # Drop the two velocity components, keep positions; dtype pinned to observation_space.
        xpos, _xvel, thetapos, _thetavel = full_obs
        return np.array([xpos, thetapos], dtype=np.float32)[:, None, None]

    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
        full_obs, info = super().reset(seed=seed, options=options)
        return CartPoleNoVelEnv._pos_obs(full_obs), info

    def step(self, action):
        full_obs, rew, terminated, truncated, info = super().step(action)
        return CartPoleNoVelEnv._pos_obs(full_obs), rew, terminated, truncated, info


# The classes above are always defined (the config classes below reference them by name);
# only registration is guarded, per id, so re-running this cell does not re-register.
if "CartPoleNoVel-v0" not in gym.registry:
    gym.register(
        id="CartPoleNoVel-v0",
        entry_point=CartPoleNoVelEnv,
        max_episode_steps=500,
    )

if "CartPoleCHW-v0" not in gym.registry:
    gym.register(
        id="CartPoleCHW-v0",
        entry_point=CartPoleCHWEnv,
        max_episode_steps=500,
    )

def _make_time_limited_env(env_cls, max_episode_steps):
    """Instantiate ``env_cls`` and wrap it in a ``TimeLimit`` unless ``max_episode_steps`` is None.

    The env classes are constructed directly rather than through ``gym.make``, so the
    ``max_episode_steps`` passed to ``gym.register`` is not applied automatically.
    """
    env = env_cls()
    if max_episode_steps is None:
        return env
    return gym.wrappers.TimeLimit(env, max_episode_steps=max_episode_steps)


class CartPoleNoVelConfig(EnvConfig):
    """Vectorized ``CartPoleNoVelEnv``: ``num_envs`` copies, each truncated after ``max_episode_steps``."""

    @property
    def make(self) -> Callable[[], gym.vector.VectorEnv]:
        env_fn = partial(_make_time_limited_env, CartPoleNoVelEnv, self.max_episode_steps)
        return partial(gym.vector.AsyncVectorEnv, env_fns=[env_fn] * self.num_envs)


class CartPoleConfig(EnvConfig):
    """Vectorized ``CartPoleCHWEnv``: ``num_envs`` copies, each truncated after ``max_episode_steps``."""

    @property
    def make(self) -> Callable[[], gym.vector.VectorEnv]:
        env_fn = partial(_make_time_limited_env, CartPoleCHWEnv, self.max_episode_steps)
        return partial(gym.vector.AsyncVectorEnv, env_fns=[env_fn] * self.num_envs)


In [4]:
import wandb

def train_cartpole_no_vel(policy="resnet", env="cartpole", num_envs: int = 64, total_timesteps: int = 1000000):
    """Train a cleanba agent on a CartPole variant and log the run to W&B.

    Args:
        policy: ``"resnet"`` selects a GuezResNet; any other value selects a ConvLSTM.
        env: ``"cartpole"`` trains on the full-observation (with velocity) CHW variant; any
            other value trains on the velocity-free variant. Note that, despite this
            function's name, the default is the variant *with* velocities.
        num_envs: number of parallel environments (also used as ``local_num_envs``).
        total_timesteps: total number of environment steps to train for.

    Returns:
        The ``WandbWriter`` used for logging. The W&B run is finished before returning,
        and also if training raises.
    """
    if policy == "resnet":
        net = GuezResNetConfig(
            channels=(64,),
            strides=(1,),
            kernel_sizes=(1,),
            mlp_hiddens=(64,),
        )
    else:
        net = ConvLSTMConfig(
            embed=[],
            recurrent=[ConvConfig(64, (1, 1), (1, 1), "SAME", True)],
            repeats_per_step=1,
            pool_and_inject=False,
            add_one_to_forget=True,
        )

    with_velocity = env == "cartpole"
    env_cfg_cls = CartPoleConfig if with_velocity else CartPoleNoVelConfig
    # 500 matches the max_episode_steps used when registering the CartPole variants.
    env_cfg = env_cfg_cls(num_envs=num_envs, max_episode_steps=500)

    args = Args(
        train_env=env_cfg,
        eval_envs=dict(eval0=EvalConfig(env_cfg, steps_to_think=[0])),
        net=net,
        eval_frequency=100,
        save_model=False,
        log_frequency=100,
        local_num_envs=num_envs,
        num_actor_threads=1,
        num_minibatches=4,
        # If the whole thing deadlocks, give up after roughly this many seconds
        # (assumed semantics of `queue_timeout` — confirm in cleanba.config).
        queue_timeout=60,
        train_epochs=1,
        learning_rate=0.0001,
        total_timesteps=total_timesteps,
        max_grad_norm=1e-3,
        base_fan_in=1,
        optimizer="adam",
    )

    # WandbWriter is expected to pick these up from the environment — TODO confirm.
    run_name = "cartpole_vel" if with_velocity else "cartpole_no_vel"
    os.environ["WANDB_ENTITY"] = "farai"
    os.environ["WANDB_JOB_NAME"] = run_name
    os.environ["WANDB_PROJECT"] = "lp-cleanba"
    os.environ["WANDB_RUN_GROUP"] = f"{run_name}_grp"

    try:
        writer = WandbWriter(args)
        train(args, writer=writer)
        print("Done training")
    finally:
        # Close the W&B run even if training raised or was interrupted.
        wandb.finish()
    return writer

writer = train_cartpole_no_vel()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mtaufeeque9[0m ([33mfarai[0m). Use [1m`wandb login --relogin`[0m to force relogin


device_thread_id=0, SPS=5769.46, global_step=12800, avg_episode_returns=21.52, avg_episode_length=21.52, avg_rollout_time=0.09033
Evaluating  eval0
12800 actor_policy_version=8, actor_update=10, learner_policy_version=10, training time: 5.4152562618255615s
device_thread_id=0, SPS=3266.82, global_step=25600, avg_episode_returns=22.94, avg_episode_length=22.94, avg_rollout_time=0.01953
Evaluating  eval0
25600 actor_policy_version=18, actor_update=20, learner_policy_version=20, training time: 4.834189176559448s
device_thread_id=0, SPS=2984.60, global_step=38400, avg_episode_returns=22.72, avg_episode_length=22.72, avg_rollout_time=0.01871
Evaluating  eval0
38400 actor_policy_version=28, actor_update=30, learner_policy_version=30, training time: 4.80699896812439s
device_thread_id=0, SPS=2868.34, global_step=51200, avg_episode_returns=22.39, avg_episode_length=22.39, avg_rollout_time=0.01694
Evaluating  eval0
51200 actor_policy_version=38, actor_update=40, learner_policy_version=40, trainin

In [None]:
# Cleanup cell: close any W&B run that is still open, e.g. when the training
# cell above was interrupted or failed before it could finish the run itself.
import wandb

wandb.finish()