Skip to content

Commit

Permalink
Update learning_stable_baselines3.py example
Browse files Browse the repository at this point in the history
  • Loading branch information
mwydmuch committed Dec 30, 2023
1 parent b31b75a commit 499b5e5
Showing 1 changed file with 12 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@

#####################################################################
# Example script of training agents with stable-baselines3
# on ViZDoom using the Gym API
# on ViZDoom using the Gymnasium API
#
# Note: ViZDoom must be installed with optional gym dependencies:
# pip install vizdoom[gym]
# You also need stable-baselines3:
# pip install stable-baselines3
# Note: For this example to work, you need to install stable-baselines3 and opencv:
# pip install stable-baselines3 opencv-python
#
# See more stable-baselines3 documentation here:
# https://stable-baselines3.readthedocs.io/en/master/index.html
Expand All @@ -16,20 +14,16 @@
from argparse import ArgumentParser

import cv2
import gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

import vizdoom.gym_wrapper # noqa
import gymnasium
import vizdoom.gymnasium_wrapper # noqa


DEFAULT_ENV = "VizdoomBasic-v0"
AVAILABLE_ENVS = [
env
for env in [env_spec.id for env_spec in gym.envs.registry.all()]
if "Vizdoom" in env
]
AVAILABLE_ENVS = [env for env in gymnasium.envs.registry.keys() if "Vizdoom" in env]
# Height and width of the resized image
IMAGE_SHAPE = (60, 80)

Expand All @@ -40,7 +34,7 @@
FRAME_SKIP = 4


class ObservationWrapper(gym.ObservationWrapper):
class ObservationWrapper(gymnasium.ObservationWrapper):
"""
ViZDoom environments return dictionaries as observations, containing
the main image as well other info.
Expand All @@ -62,12 +56,13 @@ def __init__(self, env, shape=IMAGE_SHAPE):
self.env.frame_skip = FRAME_SKIP

# Create new observation space with the new shape
num_channels = env.observation_space["rgb"].shape[-1]
print(env.observation_space)
num_channels = env.observation_space["screen"].shape[-1]
new_shape = (shape[0], shape[1], num_channels)
self.observation_space = gym.spaces.Box(0, 255, shape=new_shape, dtype=np.uint8)
self.observation_space = gymnasium.spaces.Box(0, 255, shape=new_shape, dtype=np.uint8)

def observation(self, observation):
observation = cv2.resize(observation["rgb"], self.image_shape_reverse)
observation = cv2.resize(observation["screen"], self.image_shape_reverse)
return observation


Expand All @@ -79,7 +74,7 @@ def main(args):
# This may lead to unstable learning, and we scale the rewards by 1/100
def wrap_env(env):
env = ObservationWrapper(env)
env = gym.wrappers.TransformReward(env, lambda r: r * 0.01)
env = gymnasium.wrappers.TransformReward(env, lambda r: r * 0.01)
return env

envs = make_vec_env(args.env, n_envs=N_ENVS, wrapper_class=wrap_env)
Expand Down

0 comments on commit 499b5e5

Please sign in to comment.