In [20]:
import warnings
warnings.filterwarnings('ignore')
import ale_py
import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecTransposeImage
from stable_baselines3.common.atari_wrappers import MaxAndSkipEnv
import torch
import numpy as np
import wandb
from wandb.integration.sb3 import WandbCallback

from gymnasium.wrappers import MaxAndSkipObservation, ResizeObservation, GrayscaleObservation, FrameStackObservation, ReshapeObservation
from stable_baselines3.common.monitor import Monitor
import matplotlib.pyplot as plt
import os
from stable_baselines3 import DQN
from stable_baselines3.common.callbacks import CheckpointCallback
gym.register_envs(ale_py)
from datetime import datetime
from stable_baselines3 import A2C
from stable_baselines3.ppo.policies import MlpPolicy
from wandb.integration.sb3 import WandbCallback

In [21]:
# ENV_NAME = 'ALE/MarioBros-v5'

# configuration file
config = {
    "policy_type": "CnnPolicy",
    "total_timesteps": 1000000,
    "env_name": "ALE/Breakout-v5", # ALE/MarioBros-v5
    "model_name": "A2C_MarioBros",
    "export_path": "./exports/",
    "videos_path": "./videos/",
}

In [23]:

class ScaledFloatFrame(gym.ObservationWrapper):
    def observation(self, obs):
        return np.array(obs).astype(np.float32) / 255.0

class FireResetEnv(gym.Wrapper):
    def __init__(self, env=None):
        super().__init__(env)
        # Check that 'FIRE' is a valid action in the environment
        assert 'FIRE' in env.unwrapped.get_action_meanings(), "Environment does not support 'FIRE' action"
        assert len(env.unwrapped.get_action_meanings()) >= 3, "Action space too small for expected actions"

    def step(self, action):
        return self.env.step(action)

    def reset(self, **kwargs):
        # Reset the environment
        obs, info = self.env.reset(**kwargs)

        # Perform the FIRE action
        obs, _, terminated, truncated, _ = self.env.step(1)
        if terminated or truncated:  # If game ends after FIRE, reset again
            obs, info = self.env.reset(**kwargs)

        return obs, info


def make_env(env_name, render_mode=None):
    def _init():
        env = gym.make(env_name, render_mode=render_mode)
        print("Standard Env.        : {}".format(env.observation_space.shape))
        env = FireResetEnv(env)
        print("FireResetEnv          : {}".format(env.observation_space.shape))
        env = ResizeObservation(env, (84, 84))
        print("ResizeObservation    : {}".format(env.observation_space.shape))
        env = GrayscaleObservation(env, keep_dim=True)
        print("GrayscaleObservation : {}".format(env.observation_space.shape))
        env = ScaledFloatFrame(env)
        print("ScaledFloatFrame     : {}".format(env.observation_space.shape))
        env = Monitor(env, allow_early_resets=True) # from stable baselines
        print("Monitor               : {}".format(env.observation_space.shape))

        return env
    return _init


# env = DummyVecEnv([make_env(config["env_name"])])
# # stack 4 frames
# env = VecFrameStack(env, n_stack=4)
# print("Post VecFrameStack Shape: {}".format(env.observation_space.shape))

# # convert back to PyTorch format (channel-first)
# env = VecTransposeImage(env)
# print("Final Observation Space: {}".format(env.observation_space.shape))

In [24]:
env = DummyVecEnv([make_env(config["env_name"], render_mode="rgb_array")])
# stack 4 frames
env = VecFrameStack(env, n_stack=4)
print("Post VecFrameStack Shape: {}".format(env.observation_space.shape))

# convert back to PyTorch format (channel-first)
env = VecTransposeImage(env)
print("Final Observation Space: {}".format(env.observation_space.shape))

print("Render mode after wrapping:", env.render_mode)

Standard Env.        : (210, 160, 3)
FireResetEnv          : (210, 160, 3)
ResizeObservation    : (84, 84, 3)
GrayscaleObservation : (84, 84, 1)
ScaledFloatFrame     : (84, 84, 1)
Monitor               : (84, 84, 1)
Post VecFrameStack Shape: (84, 84, 4)
Final Observation Space: (4, 84, 84)
Render mode after wrapping: rgb_array


In [25]:
# load model
model = A2C.load("Breakout-v5")

In [30]:
import imageio
from PIL import Image
import PIL.ImageDraw as ImageDraw

rewards_glb = []
num_episodes = 2

for i in range(num_episodes):
    frames = []
    rewards_episode = []
    done = False
    obs = env.reset()

    while not done:
        action, _ = model.predict(obs)
        obs, reward, done, info = env.step(action)
        # done = terminated or truncated
        rewards_episode.append(reward)

        frames.append(env.render())

    rewards_glb.append(sum(rewards_episode))
    # e.g. fps=50 == duration=20 (1000 * 1/50)
    imageio.mimwrite("model_name" +'_'+ str(i) +'.gif', frames, duration=20)

print("Rewards:", rewards_glb)

Rewards: [array([3.], dtype=float32), array([0.], dtype=float32)]


In [None]:
# make code for elimating ok.txt
# os.remove("ok.txt")