Skip to content

Commit

Permalink
env_id consistency in tests (#1224)
Browse files Browse the repository at this point in the history
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
  • Loading branch information
qgallouedec and araffin committed Dec 20, 2022
1 parent 7fb8336 commit 96b1a7c
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 19 deletions.
4 changes: 2 additions & 2 deletions docs/guide/integrations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ The full documentation is available here: https://docs.wandb.ai/guides/integrati
config = {
"policy_type": "MlpPolicy",
"total_timesteps": 25000,
"env_name": "CartPole-v1",
"env_id": "CartPole-v1",
}
run = wandb.init(
project="sb3",
Expand All @@ -32,7 +32,7 @@ The full documentation is available here: https://docs.wandb.ai/guides/integrati
# save_code=True, # optional
)
model = PPO(config["policy_type"], config["env_name"], verbose=1, tensorboard_log=f"runs/{run.id}")
model = PPO(config["policy_type"], config["env_id"], verbose=1, tensorboard_log=f"runs/{run.id}")
model.learn(
total_timesteps=config["total_timesteps"],
callback=WandbCallback(
Expand Down
28 changes: 14 additions & 14 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,26 @@
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize


def select_env(model_class) -> str:
    """Return the Gym environment id used to test ``model_class``.

    :param model_class: The algorithm class under test.
    :return: ``"CartPole-v1"`` when ``model_class`` is ``DQN``,
        ``"Pendulum-v1"`` for any other class.
    """
    return "CartPole-v1" if model_class is DQN else "Pendulum-v1"


@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN, DDPG])
def test_callbacks(tmp_path, model_class):
log_folder = tmp_path / "logs/callbacks/"

# DQN only supports discrete actions
env_name = select_env(model_class)
env_id = select_env(model_class)
# Create RL model
# Small network for fast test
model = model_class("MlpPolicy", env_name, policy_kwargs=dict(net_arch=[32]))
model = model_class("MlpPolicy", env_id, policy_kwargs=dict(net_arch=[32]))

checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_folder)

eval_env = gym.make(env_name)
eval_env = gym.make(env_id)
# Stop training if the performance is good enough
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1)

Expand Down Expand Up @@ -82,7 +89,7 @@ def test_callbacks(tmp_path, model_class):
n_envs = 2
# Pendulum-v1 has a time limit of 200 timesteps
max_episode_length = 200
envs = make_vec_env(env_name, n_envs=n_envs, seed=0)
envs = make_vec_env(env_id, n_envs=n_envs, seed=0)

model = model_class("MlpPolicy", envs, policy_kwargs=dict(net_arch=[32]))

Expand All @@ -100,13 +107,6 @@ def test_callbacks(tmp_path, model_class):
shutil.rmtree(log_folder)


def select_env(model_class) -> str:
    """Return the Gym environment id used to test ``model_class``.

    :param model_class: The algorithm class under test.
    :return: ``"CartPole-v1"`` when ``model_class`` is ``DQN``,
        ``"Pendulum-v1"`` for any other class.
    """
    return "CartPole-v1" if model_class is DQN else "Pendulum-v1"


def test_eval_callback_vec_env():
# tests that eval callback does not crash when given a vector
n_eval_envs = 3
Expand Down Expand Up @@ -153,17 +153,17 @@ def test_eval_callback_logs_are_written_with_the_correct_timestep(tmp_path):
pytest.importorskip("tensorboard")
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

env_name = select_env(DQN)
env_id = select_env(DQN)
model = DQN(
"MlpPolicy",
env_name,
env_id,
policy_kwargs=dict(net_arch=[32]),
tensorboard_log=tmp_path,
verbose=1,
seed=1,
)

eval_env = gym.make(env_name)
eval_env = gym.make(env_id)
eval_freq = 101
eval_callback = EvalCallback(eval_env, eval_freq=eval_freq, warn=False)
model.learn(500, callback=eval_callback)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ def test_auto_wrap(model_class):
"""Test auto wrapping of env into a VecEnv."""
# Use a different environment for DQN
if model_class is DQN:
env_name = "CartPole-v1"
env_id = "CartPole-v1"
else:
env_name = "Pendulum-v1"
env = gym.make(env_name)
env_id = "Pendulum-v1"
env = gym.make(env_id)
model = model_class("MlpPolicy", env)
model.learn(100)

Expand Down

0 comments on commit 96b1a7c

Please sign in to comment.