In [3]:
%pip install stable-baselines3[extra]

Collecting stable-baselines3[extra]
  Downloading stable_baselines3-2.7.1-py3-none-any.whl.metadata (4.8 kB)
Collecting matplotlib (from stable-baselines3[extra])
  Downloading matplotlib-3.10.8-cp311-cp311-win_amd64.whl.metadata (52 kB)
Collecting opencv-python (from stable-baselines3[extra])
  Using cached opencv_python-4.12.0.88-cp37-abi3-win_amd64.whl.metadata (19 kB)
Collecting pygame (from stable-baselines3[extra])
  Downloading pygame-2.6.1-cp311-cp311-win_amd64.whl.metadata (13 kB)
Collecting tensorboard>=2.9.1 (from stable-baselines3[extra])
  Downloading tensorboard-2.20.0-py3-none-any.whl.metadata (1.8 kB)
Collecting rich (from stable-baselines3[extra])
  Downloading rich-14.2.0-py3-none-any.whl.metadata (18 kB)
Collecting ale-py>=0.9.0 (from stable-baselines3[extra])
  Downloading ale_py-0.11.2-cp311-cp311-win_amd64.whl.metadata (9.2 kB)
Collecting absl-py>=0.4 (from tensorboard>=2.9.1->stable-baselines3[extra])
  Downloading absl_py-2.3.1-py3-none-any.whl.metadata (3.3 kB)

  You can safely remove it manually.
  You can safely remove it manually.


In [1]:
from pyboy import PyBoy
pyboy = PyBoy('ROMs/Pokemon_Blue.gb', window='SDL2')
pyboy.set_emulation_speed(2)
for step in range(1000):
    pyboy.tick()
    pass
pyboy.button('START')
for step in range(1000):
    pyboy.tick()
    pyboy.button('a')
    pass
pyboy.stop()




In [None]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from pyboy import PyBoy

class PyBoyPokemonEnv(gym.Env):
    metadata = {"render_modes": ["human", "rgb_array"]}

    def __init__(self, rom_path, render_mode="rgb_array"):
        super().__init__()
        self.p_p_action = None
        self.p_action = None
        self.pyboy = PyBoy(
            rom_path,
            window="null" if render_mode != "human" else "SDL2",
            debug=False
        )
        with open("roms/state_file.state", "rb") as f:
            self.pyboy.load_state(f)
        self.pyboy.set_emulation_speed(0)
        # Action space: Up, Down, Left, Right, A, B, Start, Select
        self.buttons = [
            "up", "down", "left", "right",
            "a", "b", "start", "select"
        ]
        self.action_space = spaces.Discrete(len(self.buttons))

        # Observation: raw Game Boy screen (160Ã—144 RGB
        frame = self.pyboy.screen.ndarray 
        H, W, C = 144, 160, 3  # after dropping alpha, C=3
        self.observation_space = spaces.Box(low=0, high=255, shape=(H, W, C), dtype=np.uint8)
        self.render_mode = render_mode

    def step(self, action):
        # Press selected button for a few frames
        screen = self.pyboy.game_area()

        button = self.buttons[action]
        self.pyboy.button_press(button)
        self.pyboy.tick(4)
        self.pyboy.button_release(button)
        self.pyboy.tick(16)
        
        # TODO: Define a meaningful reward function
        # if some change on the screen:
        
        try:
            test = np.abs(np.sum(screen) - np.sum(self.pyboy.game_area())) > 1000
        except OverflowError:
            test = True

        if test:
            reward = 1.0  # reward for screen change
        elif action == self.p_p_action and action == self.p_action:
                reward = -1.0  # small reward for repeating the same action
        else:
            reward = -0.1  # default reward

        terminated = False
        truncated = False
        self.p_p_action = self.p_action
        self.p_action = action
        return self._get_frame(), reward, terminated, truncated, {}
    
    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        # Reset emulator
        obs = self._get_frame()
        return obs, {}

    def render(self):
        if self.render_mode == "human":
            pass  # PyBoy manages its own SDL window

    def _get_frame(self):
        frame = self.pyboy.screen.ndarray[..., :3]  # keep only RGB
        return frame.astype(np.uint8)



In [None]:
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecTransposeImage

env = PyBoyPokemonEnv("roms/Pokemon_Blue.gb", render_mode="human")
# vec_env = DummyVecEnv([lambda: env])
def make_env():
    return PyBoyPokemonEnv("roms/Pokemon_Blue.gb", render_mode="rgb_array")
env = DummyVecEnv([make_env for _ in range(4)])
# env = VecTransposeImage(env)

model = PPO(
    "CnnPolicy",
    env,
    verbose=1,
    tensorboard_log="./ppo_pokemon/"
)

model.learn(total_timesteps=10000)
model.save("ppo_pokemon_blue")
pyboy.stop()


Using cuda device
Wrapping the env in a VecTransposeImage.
Logging to ./ppo_pokemon/PPO_13


  if np.abs(np.sum(screen) - np.sum(self.pyboy.game_area())) > 1000:


-----------------------------
| time/              |      |
|    fps             | 264  |
|    iterations      | 1    |
|    time_elapsed    | 30   |
|    total_timesteps | 8192 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 205         |
|    iterations           | 2           |
|    time_elapsed         | 79          |
|    total_timesteps      | 16384       |
| train/                  |             |
|    approx_kl            | 0.017075565 |
|    clip_fraction        | 0.271       |
|    clip_range           | 0.2         |
|    entropy_loss         | -2.06       |
|    explained_variance   | -0.0527     |
|    learning_rate        | 0.0003      |
|    loss                 | 1.79        |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0282     |
|    value_loss           | 3.44        |
-----------------------------------------


In [4]:
# Run a single environment for visualization
from pyboy import PyBoy
from stable_baselines3 import PPO

model = PPO.load("./ppo_pokemon_blue.zip") 
visual_env = PyBoyPokemonEnv("roms/Pokemon_Blue.gb", render_mode="human")
obs, info = visual_env.reset()
print(f"Starting....")
for _ in range(10000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = visual_env.step(action)
    visual_env.render()  # SDL2 window updates
    if terminated or truncated:
        obs, info = visual_env.reset()

visual_env.pyboy.stop()


Starting....
