In [2]:
import gymnasium as gym
import ale_py
from gymnasium.utils import play

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
import time

In [3]:
# Configuration for the Riverraid environment. `max_episode_steps` is passed to
# gym.make to cap episode length (Gymnasium applies it via a TimeLimit wrapper);
# 'rgb_array' rendering returns frames as arrays instead of opening a window.
ENV_ID = 'ALE/Riverraid-v5'
MAX_EPISODE_STEPS = 1000
env = gym.make(ENV_ID, render_mode='rgb_array', max_episode_steps=MAX_EPISODE_STEPS)

In [4]:
# Show the observation and action spaces of `env`, framed by 40-char rule lines.
print(
    '=' * 40,
    'Environment Spaces:\n',
    f'Observation Space: {env.observation_space}',
    f'Action Space:      {env.action_space} ',
    '=' * 40,
    sep='\n',
)

Environment Spaces:

Observation Space: Box(0, 255, (210, 160, 3), uint8)
Action Space:      Discrete(18) 


In [5]:
class Riverraid:
    """Survival-bonus reward shaping around a Gymnasium-style environment.

    Once the agent has survived ``bonus_after`` consecutive non-terminal steps,
    every further non-terminal step adds ``bonus`` to the environment reward.
    The survival counter restarts on episode end and on ``reset``.

    NOTE(review): unlike the wrapped env (whose ``step`` returns the 5-tuple
    ``obs, reward, terminated, truncated, info``), this class exposes the legacy
    API: ``step`` returns ``(obs, reward, done, info)`` with
    ``done = terminated or truncated``, and ``reset`` returns only the
    observation. It presumably cannot be handed to Stable-Baselines3 as-is;
    port it to ``gym.Wrapper`` if it is meant for training.
    """

    def __init__(self, env, bonus_after=10, bonus=1):
        """Wrap ``env`` and reset it immediately.

        Args:
            env: environment whose ``reset()`` returns ``(obs, info)`` and whose
                ``step()`` returns ``(obs, reward, terminated, truncated, info)``.
            bonus_after: number of consecutive surviving steps before the bonus
                starts being paid (default 10, the original hard-coded value).
            bonus: extra reward added per surviving step once the threshold is
                reached (default 1, the original hard-coded value).
        """
        self.env = env
        self.bonus_after = bonus_after
        self.bonus = bonus
        self.time_alive = 0
        self.action_space = env.action_space
        self.observation_space = env.observation_space
        # Reset eagerly so `self.state` holds a valid first observation.
        self.state, _ = self.env.reset()

    def step(self, action):
        """Advance the env one step and apply the survival bonus.

        Returns:
            ``(obs, reward, done, info)`` where ``done`` merges the env's
            ``terminated`` and ``truncated`` flags.
        """
        self.state, reward, terminated, truncated, info = self.env.step(action)
        # Time-limit truncation and real termination are treated the same here.
        done = terminated or truncated

        if done:
            self.time_alive = 0
        else:
            self.time_alive += 1
            if self.time_alive >= self.bonus_after:
                reward += self.bonus

        return self.state, reward, done, info

    def reset(self, seed=None, options=None):
        """Reset the survival counter and the env; return the first observation.

        ``seed`` and ``options`` are forwarded to the wrapped env's ``reset``
        (both default to None, matching a bare ``env.reset()`` call). The env's
        ``info`` dict is discarded.
        """
        self.time_alive = 0
        self.state, _ = self.env.reset(seed=seed, options=options)
        return self.state

In [6]:
# Train PPO with a CNN policy on the raw RGB frames of `env`. SB3 wraps the env
# itself (Monitor -> DummyVecEnv -> VecTransposeImage, per the log below).
# NOTE(review): the `Riverraid` survival-bonus class is NOT used here, so
# training sees only the unshaped ALE reward. Its step() returns a 4-tuple, so
# it presumably must become a `gym.Wrapper` before it can be passed in.
# NOTE(review): the logged first update shows approx_kl ~0.33 and
# clip_fraction ~0.67, i.e. very large policy updates; consider a lower
# learning_rate or setting target_kl.
SEED = 0                   # fixes SB3/torch/env RNGs so runs are reproducible
TOTAL_TIMESTEPS = 120_000
model = PPO('CnnPolicy', env, verbose=1, seed=SEED)
model.learn(total_timesteps=TOTAL_TIMESTEPS);  # `;` hides the PPO object repr

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env in a VecTransposeImage.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 716      |
|    ep_rew_mean     | 1.5e+03  |
| time/              |          |
|    fps             | 120      |
|    iterations      | 1        |
|    time_elapsed    | 17       |
|    total_timesteps | 2048     |
---------------------------------
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 636        |
|    ep_rew_mean          | 1.11e+03   |
| time/                   |            |
|    fps                  | 35         |
|    iterations           | 2          |
|    time_elapsed         | 115        |
|    total_timesteps      | 4096       |
| train/                  |            |
|    approx_kl            | 0.32946146 |
|    clip_fraction        | 0.669      |
|    clip_range           | 0.2  

<stable_baselines3.ppo.ppo.PPO at 0x117fdb127d0>

In [7]:
# Persist the trained PPO policy under the base name 'ppo_riverraid'
# (NOTE(review): SB3 presumably appends '.zip'; reload with PPO.load).
model.save(path='ppo_riverraid')