Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update learning_stable_baselines3.py example #581

Merged
merged 2 commits into from
Dec 30, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@

#####################################################################
# Example script of training agents with stable-baselines3
# on ViZDoom using the Gym API
# on ViZDoom using the Gymnasium API
#
# Note: ViZDoom must be installed with optional gym dependencies:
# pip install vizdoom[gym]
# You also need stable-baselines3:
# pip install stable-baselines3
# Note: For this example to work, you need to install stable-baselines3 and opencv:
# pip install stable-baselines3 opencv-python
#
# See more stable-baselines3 documentation here:
# https://stable-baselines3.readthedocs.io/en/master/index.html
Expand All @@ -16,20 +14,16 @@
from argparse import ArgumentParser

import cv2
import gym
import gymnasium
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

import vizdoom.gym_wrapper # noqa
import vizdoom.gymnasium_wrapper # noqa


DEFAULT_ENV = "VizdoomBasic-v0"
AVAILABLE_ENVS = [
env
for env in [env_spec.id for env_spec in gym.envs.registry.all()]
if "Vizdoom" in env
]
AVAILABLE_ENVS = [env for env in gymnasium.envs.registry.keys() if "Vizdoom" in env]
# Height and width of the resized image
IMAGE_SHAPE = (60, 80)

Expand All @@ -40,7 +34,7 @@
FRAME_SKIP = 4


class ObservationWrapper(gym.ObservationWrapper):
class ObservationWrapper(gymnasium.ObservationWrapper):
"""
ViZDoom environments return dictionaries as observations, containing
the main image as well as other info.
Expand All @@ -62,12 +56,15 @@ def __init__(self, env, shape=IMAGE_SHAPE):
self.env.frame_skip = FRAME_SKIP

# Create new observation space with the new shape
num_channels = env.observation_space["rgb"].shape[-1]
print(env.observation_space)
num_channels = env.observation_space["screen"].shape[-1]
new_shape = (shape[0], shape[1], num_channels)
self.observation_space = gym.spaces.Box(0, 255, shape=new_shape, dtype=np.uint8)
self.observation_space = gymnasium.spaces.Box(
0, 255, shape=new_shape, dtype=np.uint8
)

def observation(self, observation):
observation = cv2.resize(observation["rgb"], self.image_shape_reverse)
observation = cv2.resize(observation["screen"], self.image_shape_reverse)
return observation


Expand All @@ -79,7 +76,7 @@ def main(args):
# This may lead to unstable learning, so we scale the rewards by 1/100
def wrap_env(env):
env = ObservationWrapper(env)
env = gym.wrappers.TransformReward(env, lambda r: r * 0.01)
env = gymnasium.wrappers.TransformReward(env, lambda r: r * 0.01)
return env

envs = make_vec_env(args.env, n_envs=N_ENVS, wrapper_class=wrap_env)
Expand Down
Loading