In [1]:
# Train an SB3 A2C agent on ALE/MarioBros-v5 and log the run to Weights & Biases.

In [2]:
import warnings
warnings.filterwarnings('ignore')

In [3]:
import ale_py

In [4]:
import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecTransposeImage
from stable_baselines3.common.atari_wrappers import MaxAndSkipEnv
import torch
import numpy as np
import wandb
from wandb.integration.sb3 import WandbCallback

from gymnasium.wrappers import MaxAndSkipObservation, ResizeObservation, GrayscaleObservation, FrameStackObservation, ReshapeObservation

In [5]:
from stable_baselines3.common.monitor import Monitor

In [6]:
import matplotlib.pyplot as plt

In [7]:
import os

In [8]:
from stable_baselines3 import DQN
from stable_baselines3.common.callbacks import CheckpointCallback

In [9]:
# Make the ALE environments (the "ALE/..." ids used by gym.make below) known to Gymnasium.
# NOTE(review): per the Gymnasium/ale_py docs, importing ale_py is what registers the ids
# and this call mainly marks the import as used — confirm for the installed versions.
gym.register_envs(ale_py)

In [10]:
from datetime import datetime
from stable_baselines3 import A2C
from stable_baselines3.ppo.policies import MlpPolicy
from wandb.integration.sb3 import WandbCallback

In [11]:
# Experiment settings. The whole dict is handed to wandb.init(config=...) so it is
# recorded with the run.
config = dict(
    policy_type="CnnPolicy",        # SB3 policy name passed to the A2C constructor
    total_timesteps=1_000_000,      # environment steps passed to model.learn
    env_name="ALE/MarioBros-v5",    # Gymnasium / ALE environment id
    model_name="A2C_MarioBros",     # file stem used when saving the trained model
    export_path="./exports/",       # directory the saved model goes into
    videos_path="./videos/",        # NOTE(review): not referenced by any cell shown here
)

In [12]:
# Open a Weights & Biases run in project "Project_1":
#   - `config` is stored as the run's hyper-parameters,
#   - sync_tensorboard mirrors the TensorBoard scalars SB3 writes,
#   - save_code (optional) also uploads the notebook source.
run = wandb.init(
    config=config,
    project="Project_1",
    save_code=True,
    sync_tensorboard=True,
)

In [13]:

class ScaledFloatFrame(gym.ObservationWrapper):
    """Rescale pixel observations from [0, 255] to float32 values in [0, 1].

    The wrapper also re-declares ``observation_space`` as ``Box(0, 1, float32)``
    so anything that sizes buffers from the space sees the real dtype/range.

    NOTE: Stable-Baselines3's ``CnnPolicy`` already divides uint8 image
    observations by 255, and SB3's image helpers (e.g. ``VecTransposeImage``)
    expect a uint8 image space by default, so this wrapper should not be placed
    in front of them.
    """

    def __init__(self, env):
        super().__init__(env)
        # Without this the space keeps advertising uint8 [0, 255]; consumers that
        # allocate from the declared dtype (such as SB3's DummyVecEnv buffers)
        # would cast the scaled floats back to uint8, i.e. to (almost) all zeros.
        self.observation_space = gym.spaces.Box(
            low=0.0,
            high=1.0,
            shape=env.observation_space.shape,
            dtype=np.float32,
        )

    def observation(self, obs):
        """Return ``obs`` as a float32 array divided by 255."""
        return np.asarray(obs, dtype=np.float32) / 255.0

class FireResetEnv(gym.Wrapper):
    """Take the FIRE action once immediately after every reset.

    Presumably meant for Atari games that wait for FIRE before play starts
    (TODO confirm this is needed for MarioBros). If that single FIRE step ends
    the episode, the environment is reset again and that observation is used.
    """

    def __init__(self, env=None):
        super().__init__(env)
        action_meanings = env.unwrapped.get_action_meanings()
        # Check that 'FIRE' is a valid action in the environment
        assert 'FIRE' in action_meanings, "Environment does not support 'FIRE' action"
        assert len(action_meanings) >= 3, "Action space too small for expected actions"
        # Step the action whose name was validated above instead of assuming it
        # sits at index 1 of the action set.
        self._fire_action = action_meanings.index('FIRE')

    def step(self, action):
        """Forward ``action`` to the wrapped environment unchanged."""
        return self.env.step(action)

    def reset(self, **kwargs):
        """Reset, press FIRE once, and return ``(obs, info)``.

        The reward and info of the FIRE step are discarded.
        """
        obs, info = self.env.reset(**kwargs)

        obs, _, terminated, truncated, _ = self.env.step(self._fire_action)
        if terminated or truncated:  # If game ends after FIRE, reset again
            obs, info = self.env.reset(**kwargs)

        return obs, info


def make_env(env_name, render_mode=None):
    """Return a zero-argument factory that builds one preprocessed env.

    Wrapper order: gym.make -> FireResetEnv -> resize to 84x84 ->
    grayscale (channel axis kept) -> Monitor. The shape after each step is
    printed.

    Frames are deliberately left as uint8 in [0, 255]. SB3's CnnPolicy
    normalizes uint8 image observations itself, and DummyVecEnv /
    VecTransposeImage work from the declared uint8 observation space.
    Float-scaling here (ScaledFloatFrame) left the space declared as uint8,
    so the vectorized env cast the scaled pixels back to 0/1 and the agent
    saw near-black frames.

    Args:
        env_name: Gymnasium environment id passed to ``gym.make``.
        render_mode: forwarded to ``gym.make``; None means no rendering.

    Returns:
        A callable suitable for ``DummyVecEnv([...])``.
    """
    def _init():
        env = gym.make(env_name, render_mode=render_mode)
        print("Standard Env.        : {}".format(env.observation_space.shape))
        env = FireResetEnv(env)
        print("FireResetEnv          : {}".format(env.observation_space.shape))
        env = ResizeObservation(env, (84, 84))
        print("ResizeObservation    : {}".format(env.observation_space.shape))
        env = GrayscaleObservation(env, keep_dim=True)
        print("GrayscaleObservation : {}".format(env.observation_space.shape))
        env = Monitor(env, allow_early_resets=False)  # SB3 episode reward/length bookkeeping
        print("Monitor               : {}".format(env.observation_space.shape))

        return env
    return _init



# One env wrapped as a vectorized env, with the last 4 frames stacked along the
# trailing (channel) axis.
env = VecFrameStack(DummyVecEnv([make_env(config["env_name"])]), n_stack=4)
print("Post VecFrameStack Shape: {}".format(env.observation_space.shape))

# Move channels to the front (channel-first layout for PyTorch).
env = VecTransposeImage(env)
print("Final Observation Space: {}".format(env.observation_space.shape))

In [14]:
# Build the A2C agent. device="auto" uses CUDA when it is available and falls
# back to CPU instead of failing on machines without a GPU.
# seed stays None (unseeded) unless a "seed" key is added to `config`.
model = A2C(
    config["policy_type"],
    env,
    verbose=0,
    tensorboard_log=f"runs/{run.id}",
    device="auto",
    seed=config.get("seed"),
)

# Train, streaming metrics to W&B, and report wall-clock training time.
t0 = datetime.now()
model.learn(total_timesteps=config["total_timesteps"], callback=WandbCallback(verbose=2))
t1 = datetime.now()
print('>>> Training time (hh:mm:ss.ms): {}'.format(t1 - t0))

# Save the trained model; os.path.join works whether or not export_path ends
# with a path separator.
model.save(os.path.join(config["export_path"], config["model_name"]))

In [15]:
# Close the W&B run opened by wandb.init so all pending logs are flushed.
run.finish()