In [1]:
%%bash

pip install gymnasium==0.29.1
pip install gymnasium[atari]
pip install gymnasium[accept-rom-license]



In [2]:
import gymnasium as gym; print(f'Gymnasium v{gym.__version__}')

Gymnasium v0.29.1


In [3]:
%%bash

pip install "stable-baselines3[extra]>=2.0.0a4"
pip install sb3-contrib

Collecting ale-py>=0.9.0 (from stable-baselines3[extra]>=2.0.0a4)
  Using cached ale_py-0.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.6 kB)
Using cached ale_py-0.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.1 MB)
Installing collected packages: ale-py
  Attempting uninstall: ale-py
    Found existing installation: ale-py 0.8.1
    Uninstalling ale-py-0.8.1:
      Successfully uninstalled ale-py-0.8.1
Successfully installed ale-py-0.10.1


In [4]:
import gymnasium as gym
import numpy as np
from gymnasium import spaces

In [5]:
class NotebookAtariEnvWrapper(gym.Env):
    '''
    A wrapper that modifies an environment to contain an external memory for
    the agent to store and access information. Observation space and action
    space is expanded to allow these actions.
    '''
    def __init__(self, env, notebook_size=8):
        self.env = env

        if type(env.action_space) != spaces.Discrete:
            raise ValueError("Only Discrete action spaces are supported for now")

        self.notebook_size = notebook_size
        self.notebook = np.zeros(self.notebook_size, dtype=np.int32)

        self.observation_space = spaces.Dict( {
            "env": self.env.observation_space,  # Original Observation Space
            "notebook": spaces.MultiDiscrete([2] * self.notebook_size)
        })
        self.action_space = spaces.MultiDiscrete([env.action_space.n, self.notebook_size, 2])

    def reset(self, seed=None, options=None):
        env_obs = self.env.reset()
        self.notebook = np.zeros(self.notebook_size, dtype=np.int32)
        obs = {
            "env": env_obs,  # Observations from the original environment
            "notebook": self.notebook
        }
        return obs, {}

    def step(self, action):
        env_action, notebook_index, notebook_value = action
        # Pass the action as array to match VecEnv structure
        # Not an ideal solution, might need to change this if not using
        env_obs, reward, done, info = self.env.step([env_action])
        env_obs = env_obs[0]
        info = info[0]
        self.notebook[notebook_index] = notebook_value
        obs = {
            "env": env_obs,  # Observations from the original environment
            "notebook": self.notebook  # Current notebook state
        }
        return obs, reward, done, False, info

    def render(self):
        return self.env.render()

    def close(self):
        return self.env.close()

    def seed(self, seed):
        return self.env.seed(seed)

    def __getattr__(self, attr):
        return getattr(self.env, attr)

    def __str__(self):
        return str(self.env)

    def __repr__(self):
        return repr(self.env)

In [6]:
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3 import PPO
from sb3_contrib import RecurrentPPO

  from jax import xla_computation as _xla_computation


In [7]:
# Environment for PPO and RPPO models
orig_env = make_atari_env("PongNoFrameskip-v4", n_envs=1, seed=0)

# Environment for PPO w/ Notebook models
pponb_env = NotebookAtariEnvWrapper(orig_env, notebook_size=16)

In [8]:
ppo_model = PPO("CnnPolicy", orig_env, verbose=1)
ppo_model.learn(total_timesteps=int(4096))

Using cpu device
Wrapping the env in a VecTransposeImage.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 3.42e+03 |
|    ep_rew_mean     | -21      |
| time/              |          |
|    fps             | 229      |
|    iterations      | 1        |
|    time_elapsed    | 8        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 3.41e+03    |
|    ep_rew_mean          | -21         |
| time/                   |             |
|    fps                  | 79          |
|    iterations           | 2           |
|    time_elapsed         | 51          |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.008180365 |
|    clip_fraction        | 0.0106      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.79       |
|    explained

<stable_baselines3.ppo.ppo.PPO at 0x7c77c54ba530>

In [12]:
rpponb_model = RecurrentPPO("CnnLstmPolicy", orig_env, verbose=1)
rpponb_model.learn(total_timesteps=int(4096))

Using cpu device
Wrapping the env in a VecTransposeImage.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 3.06e+03 |
|    ep_rew_mean     | -21      |
| time/              |          |
|    fps             | 142      |
|    iterations      | 1        |
|    time_elapsed    | 0        |
|    total_timesteps | 128      |
---------------------------------
--------------------------------------
| rollout/                |          |
|    ep_len_mean          | 3.06e+03 |
|    ep_rew_mean          | -21      |
| time/                   |          |
|    fps                  | 23       |
|    iterations           | 2        |
|    time_elapsed         | 10       |
|    total_timesteps      | 256      |
| train/                  |          |
|    approx_kl            | 0.00202  |
|    clip_fraction        | 0        |
|    clip_range           | 0.2      |
|    entropy_loss         | -1.79    |
|    explained_variance   | -0.00118 |
|    learning_rat

<sb3_contrib.ppo_recurrent.ppo_recurrent.RecurrentPPO at 0x7c7765d98670>

In [13]:
pponb_model = PPO("MultiInputPolicy", pponb_env, verbose=1)
policy = pponb_model.policy
# print(policy) # MultiInputPolicy automatically uses CNN as extractor for image input in a Dict observation space
pponb_model.learn(total_timesteps=int(4096))

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env in a VecTransposeImage.


  self.rewards.append(float(reward))
  obs, self.buf_rews[env_idx], terminated, truncated, self.buf_infos[env_idx] = self.envs[env_idx].step(


---------------------------------
| rollout/           |          |
|    ep_len_mean     | 655      |
|    ep_rew_mean     | -15.7    |
| time/              |          |
|    fps             | 185      |
|    iterations      | 1        |
|    time_elapsed    | 11       |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 768         |
|    ep_rew_mean          | -17.4       |
| time/                   |             |
|    fps                  | 73          |
|    iterations           | 2           |
|    time_elapsed         | 55          |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.009858102 |
|    clip_fraction        | 0.0883      |
|    clip_range           | 0.2         |
|    entropy_loss         | -5.25       |
|    explained_variance   | -0.0231     |
|    learning_rate        | 0.

<stable_baselines3.ppo.ppo.PPO at 0x7c7765d98160>