In [27]:
import stable_baselines3
import gymnasium as gym
import numpy as np

from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
#from stable_baselines.common.atari_wrappers import make_atari
from stable_baselines3.dqn import CnnPolicy
from stable_baselines3.common.evaluation import evaluate_policy

#for ploting
from stable_baselines3.common import results_plotter

hyperparams = {
    "buffer_size": 1000, #old buffer_size=100000 had memory problems
    "learning_rate": 1e-4,
    "batch_size": 32,
    "learning_starts": 100000,
    "target_update_interval": 1000,
    "train_freq": 4,
    "gradient_steps": 1,
    "exploration_fraction": 0.1,
    "exploration_final_eps": 0.01,
    "optimize_memory_usage": False
}

# atari:
#   env_wrapper:
#     - stable_baselines3.common.atari_wrappers.AtariWrapper
#   frame_stack: 4
#   policy: 'CnnPolicy'
#   n_timesteps: !!float 1e7
#   buffer_size: 100000
#   learning_rate: !!float 1e-4
#   batch_size: 32
#   learning_starts: 100000
#   target_update_interval: 1000
#   train_freq: 4
#   gradient_steps: 1
#   exploration_fraction: 0.1
#   exploration_final_eps: 0.01
#   # If True, you need to deactivate handle_timeout_termination
#   # in the replay_buffer_kwargs
#   optimize_memory_usage: False


env = make_atari_env("BreakoutNoFrameskip-v4", n_envs=4, seed=0, env_kwargs={"render_mode": "rgb_array"})

# Get observation space information
height, width, channels = env.observation_space.shape
actions = env.action_space.n
print(actions)#nbr of actions

# Stack 4 frames HERE'S THE PROBLEM!!!!!!!!!!
#env = VecFrameStack(env, n_stack=4)
obs = env.reset()
print(obs.shape)

# Create the DQN model with the provided hyperparameters
model = DQN("CnnPolicy", env, **hyperparams, verbose=1)  # CNN neural network for images, verbose for outputs

# Train the model
model.learn(total_timesteps=10_000)

# Create an evaluation environment
eval_env = make_atari_env("BreakoutNoFrameskip-v4", n_envs=4, seed=0, env_kwargs={"render_mode": "rgb_array"})
print(obs.shape)
# Evaluate the trained agent
mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=100)

# Print the evaluation results
print(f"Mean reward: {mean_reward:.2f}, Std reward: {std_reward:.2f}")


4
(4, 84, 84, 1)
Using cpu device
Wrapping the env in a VecTransposeImage.
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.747    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 620      |
|    time_elapsed     | 0        |
|    total_timesteps  | 256      |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.541    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 657      |
|    time_elapsed     | 0        |
|    total_timesteps  | 464      |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 640      |
|    ep_rew_mean      | 1        |
|    exploration_rate | 0.45     |
| time/               |          |
|    episodes         | 12       |
|    fps              | 638      |
|    time_elaps

In [30]:
# ploting with Helper from the library
results_plotter.plot_results(
    [env.log_dir()], 1e5, results_plotter.X_TIMESTEPS, "DQN Breakout"
)

AttributeError: 'DummyVecEnv' object has no attribute 'log_dir'

In [None]:
# Set up fake display; otherwise rendering will fail
import os
os.system("Xvfb :1 -screen 0 1024x768x24 &")
os.environ['DISPLAY'] = ':1'

import base64
from pathlib import Path

from IPython import display as ipythondisplay


def show_videos(video_path="", prefix=""):
    """
    Taken from https://github.com/eleurent/highway-env

    :param video_path: (str) Path to the folder containing videos
    :param prefix: (str) Filter the video, showing only the only starting with this prefix
    """
    html = []
    for mp4 in Path(video_path).glob("{}*.mp4".format(prefix)):
        video_b64 = base64.b64encode(mp4.read_bytes())
        html.append(
            """<video alt="{}" autoplay 
                    loop controls style="height: 400px;">
                    <source src="data:video/mp4;base64,{}" type="video/mp4" />
                </video>""".format(
                mp4, video_b64.decode("ascii")
            )
        )
    ipythondisplay.display(ipythondisplay.HTML(data="<br>".join(html)))


from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv


def record_video(env_id, model, video_length=500, prefix="", video_folder="videos/"):
    """
    :param env_id: (str)
    :param model: (RL model)
    :param video_length: (int)
    :param prefix: (str)
    :param video_folder: (str)
    """
    eval_env = DummyVecEnv([lambda: gym.make("CartPole-v1", render_mode="rgb_array")])
    # Start the video at step=0 and record 500 steps
    eval_env = VecVideoRecorder(
        eval_env,
        video_folder=video_folder,
        record_video_trigger=lambda step: step == 0,
        video_length=video_length,
        name_prefix=prefix,
    )

    obs = eval_env.reset()
    for _ in range(video_length):
        action, _ = model.predict(obs)
        obs, _, _, _ = eval_env.step(action)

    # Close the video recorder
    eval_env.close()

record_video("BreakoutNoFrameskip-v4", model, video_length=500, prefix="dqn-Breakout")

show_videos("videos", prefix="dqn")


ValueError: axes don't match array