In [1]:
import warnings
warnings.filterwarnings('ignore')
import ale_py
import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecTransposeImage
from stable_baselines3.common.atari_wrappers import MaxAndSkipEnv
import torch
import numpy as np
import wandb
from wandb.integration.sb3 import WandbCallback

from gymnasium.wrappers import MaxAndSkipObservation, ResizeObservation, GrayscaleObservation, FrameStackObservation, ReshapeObservation
from stable_baselines3.common.monitor import Monitor
import matplotlib.pyplot as plt
import os
from stable_baselines3 import DQN
from stable_baselines3.common.callbacks import CheckpointCallback
gym.register_envs(ale_py)
from datetime import datetime
from stable_baselines3 import A2C
from stable_baselines3.ppo.policies import MlpPolicy
from wandb.integration.sb3 import WandbCallback
import collections
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecEnvWrapper

ModuleNotFoundError: No module named 'stable_baselines3'

In [None]:
# ENV_NAME = 'ALE/MarioBros-v5'

# configuration file
config = {
    "policy_type": "CnnPolicy",
    "total_timesteps": 1000000,
    "env_name": "ALE/DonkeyKong-v5", 
    "model_name": "ALE/DonkeyKong-v5",
    "export_path": "./exports/",
    "videos_path": "./videos/",
}

In [None]:
class MyVecTransposeImage(VecEnvWrapper):
    def __init__(self, venv, skip=False):
        super().__init__(venv)
        self.skip = skip

        # Get original shape: e.g., (84, 84, 4)
        old_shape = self.observation_space.shape
        # Transpose shape to (C, H, W)
        new_shape = (old_shape[2], old_shape[0], old_shape[1])  # (4, 84, 84)

        # Use the original low/high if they are uniform; if not, use min/max appropriately
        low_val = self.observation_space.low.min()
        high_val = self.observation_space.high.max()

        self.observation_space = gym.spaces.Box(
            low=low_val,
            high=high_val,
            shape=new_shape,
            dtype=self.observation_space.dtype
        )

    def reset(self):
        obs = self.venv.reset()
        return self.transpose_observations(obs)

    def step_async(self, actions):
        self.venv.step_async(actions)

    def step_wait(self):
        obs, rewards, dones, infos = self.venv.step_wait()
        return self.transpose_observations(obs), rewards, dones, infos

    def transpose_observations(self, obs):
        if self.skip:
            return obs
        if isinstance(obs, dict):
            for key, val in obs.items():
                obs[key] = self._transpose(val)
            return obs
        else:
            return self._transpose(obs)

    def _transpose(self, obs):
        # obs shape is (n_envs, H, W, C) -> transpose to (n_envs, C, H, W)
        return obs.transpose(0, 3, 1, 2)


In [None]:
class ScaledFloatFrame(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        # The original shape remains (84,84,1), but the dtype and range change
        self.observation_space = gym.spaces.Box(
            low=0.0,
            high=1.0,
            shape=self.observation_space.shape,
            dtype=np.float32
        )
        
    def observation(self, obs):
        return np.array(obs).astype(np.float32) / 255.0


# class ScaledFloatFrame(gym.ObservationWrapper):
#     def observation(self, obs):
#         return np.array(obs).astype(np.float32) / 255.0



class FireResetEnv(gym.Wrapper):
    def __init__(self, env=None):
        super().__init__(env)
        # Check that 'FIRE' is a valid action in the environment
        assert 'FIRE' in env.unwrapped.get_action_meanings(), "Environment does not support 'FIRE' action"
        assert len(env.unwrapped.get_action_meanings()) >= 3, "Action space too small for expected actions"

    def step(self, action):
        return self.env.step(action)

    def reset(self, **kwargs):
        # Reset the environment
        obs, info = self.env.reset(**kwargs)

        # Perform the FIRE action
        obs, _, terminated, truncated, _ = self.env.step(1)
        if terminated or truncated:  # If game ends after FIRE, reset again
            obs, info = self.env.reset(**kwargs)

        return obs, info
        
# Custom wrapper to add channel dimension
class AddChannelDimension(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        obs_shape = self.observation_space.shape
        # Update the observation space to include a channel dimension
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(obs_shape[0], obs_shape[1], 1),
            dtype=np.uint8,
        )

    def observation(self, observation):
        # Add a channel dimension
        return np.expand_dims(observation, axis=-1)

# class MaxAndSkipEnv(gym.Wrapper):
#     def __init__(self, env=None, skip=4):
#         super(MaxAndSkipEnv, self).__init__(env)
#         self._obs_buffer = collections.deque(maxlen=2)
#         self._skip = skip

#     def step(self, action):
#         total_reward = 0.0
#         done = None
#         for _ in range(self._skip):
#             obs, reward, terminated,truncated, info = self.env.step(action)
#             done = terminated or truncated
#             self._obs_buffer.append(obs)
#             total_reward += reward
#             if done:
#                 break
#         max_frame = np.max(np.stack(self._obs_buffer), axis=0)
#         return max_frame, total_reward, terminated, truncated, info

    # def reset(self, *, seed=None, options=None):
    #     self._obs_buffer.clear()
    #     obs, info = self.env.reset(seed=seed, options=options)
    #     self._obs_buffer.append(obs)
    #     return obs, info


def make_env(env_name, obs_type="grayscale", render_mode=None,):
    def _init():
        env = gym.make(env_name, obs_type="grayscale", render_mode=render_mode)
        print("Standard Env.        : {}".format(env.observation_space.shape))
        env = FireResetEnv(env)
        print("FireResetEnv          : {}".format(env.observation_space.shape))
        # env = MaxAndSkipEnv(env)
        # print("MaxAndSkipEnv        : {}".format(env.observation_space.shape))
        env = ResizeObservation(env, (84, 84))
        print("ResizeObservation    : {}".format(env.observation_space.shape))
        env = AddChannelDimension(env)  # Add channel dimension here
        print("AddChannelDimension  : {}".format(env.observation_space.shape))
        # env = GrayscaleObservation(env, keep_dim=True)
        # print("GrayscaleObservation : {}".format(env.observation_space.shape))
        
        env = ScaledFloatFrame(env)
        print("ScaledFloatFrame     : {}".format(env.observation_space.shape))
        
        # env = Monitor(env, allow_early_resets=False) # from stable baselines
        # print("Monitor               : {}".format(env.observation_space.shape))

        return env
    return _init



In [None]:
# env = DummyVecEnv([make_env(config["env_name"], render_mode="rgb_array")])
env = make_vec_env(env_id=make_env(config["env_name"], render_mode="rgb_array"), n_envs=1)
# stack 4 frames
env = VecFrameStack(env, n_stack=4)
print("Post VecFrameStack Shape: {}".format(env.observation_space.shape))

# convert back to PyTorch format (channel-first)
env = MyVecTransposeImage(env)
print("Final Observation Space: {}".format(env.observation_space.shape))

print("Render mode after wrapping:", env.render_mode)

Standard Env.        : (210, 160)
FireResetEnv          : (210, 160)
ResizeObservation    : (84, 84)
AddChannelDimension  : (84, 84, 1)
ScaledFloatFrame     : (84, 84, 1)
Post VecFrameStack Shape: (84, 84, 4)
Final Observation Space: (4, 84, 84)
Render mode after wrapping: rgb_array


In [None]:
# load model
model = A2C.load("../models/best_model")

In [None]:
import imageio
from PIL import Image
import PIL.ImageDraw as ImageDraw

rewards_glb = []
num_episodes = 2

for i in range(num_episodes):
    frames = []
    rewards_episode = []
    done = False
    obs = env.reset()

    while not done:
        action, _ = model.predict(obs)
        obs, reward, done, info = env.step(action)
        # done = terminated or truncated
        rewards_episode.append(reward)

        frames.append(env.render())

    rewards_glb.append(sum(rewards_episode))
    # e.g. fps=50 == duration=20 (1000 * 1/50)
    imageio.mimwrite("model_name" +'_'+ str(i) +'.gif', frames, duration=20)

print("Rewards:", rewards_glb)

Rewards: [array([200.], dtype=float32), array([0.], dtype=float32)]


In [None]:
# make code for elimating ok.txt
# os.remove("ok.txt")