Skip to content

Commit

Permalink
Fixed gen of traces with non-image vecenv (#913) (#914)
Browse files Browse the repository at this point in the history
* Only obs of the first env is added to the list when using vecenv without images (#913)

* Fixed gen of traces with non-image vecenv (#913)

* Fixed gen of traces with non-image vecenv (#913)

* Fixed gen of traces with non-image vecenv (#913)

* Added vecenv non img expert traj test (#913)
  • Loading branch information
jbarsce committed Jul 1, 2020
1 parent b3b217c commit 36df0ef
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 4 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Bug Fixes:
- Fixed ``seed()`` method for ``SubprocVecEnv``
- Fixed a bug where ``callback.locals`` did not have the correct values (@PartiallyTyped)
- Fixed a bug in the ``close()`` method of ``SubprocVecEnv``, causing wrappers further down in the wrapper stack to not be closed. (@NeoExtended)
- Fixed a bug in the ``generate_expert_traj()`` function in ``record_expert.py`` when using a non-image vectorized environment (@jbarsce)


Deprecations:
Expand Down Expand Up @@ -714,4 +715,4 @@ Thanks to @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk
@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @tirafesi @caburu @johannes-dornheim @kvenkman @aakash94
@enderdead @hardmaru
@enderdead @hardmaru @jbarsce
4 changes: 2 additions & 2 deletions stable_baselines/gail/dataset/record_expert.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,17 @@ def generate_expert_traj(model, save_path=None, env=None, n_timesteps=0,
mask = [True for _ in range(env.num_envs)]

while ep_idx < n_episodes:
obs_ = obs[0] if is_vec_env else obs
if record_images:
image_path = os.path.join(image_folder, "{}.{}".format(idx, image_ext))
obs_ = obs[0] if is_vec_env else obs
# Convert from RGB to BGR
# which is the format OpenCV expects
if obs_.shape[-1] == 3:
obs_ = cv2.cvtColor(obs_, cv2.COLOR_RGB2BGR)
cv2.imwrite(image_path, obs_)
observations.append(image_path)
else:
observations.append(obs)
observations.append(obs_)

if isinstance(model, BaseRLModel):
action, state = model.predict(obs, state=state, mask=mask)
Expand Down
11 changes: 10 additions & 1 deletion tests/test_gail.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from stable_baselines import (A2C, ACER, ACKTR, GAIL, DDPG, DQN, PPO1, PPO2,
TD3, TRPO, SAC)
from stable_baselines.common.cmd_util import make_atari_env
from stable_baselines.common.vec_env import VecFrameStack
from stable_baselines.common.vec_env import VecFrameStack, DummyVecEnv
from stable_baselines.common.evaluation import evaluate_policy
from stable_baselines.common.callbacks import CheckpointCallback
from stable_baselines.gail import ExpertDataset, generate_expert_traj
Expand Down Expand Up @@ -148,3 +148,12 @@ def test_dataset_param_validation():
traj_data = np.load(EXPERT_PATH_PENDULUM)
with pytest.raises(ValueError):
ExpertDataset(traj_data=traj_data, expert_path=EXPERT_PATH_PENDULUM)


def test_generate_vec_env_non_image_observation():
    """Generate expert trajectories from a vectorized env with non-image observations.

    Builds a two-environment ``DummyVecEnv`` of CartPole-v1, trains a PPO2
    model with an ``MlpPolicy`` on it for 5000 timesteps, then asks
    ``generate_expert_traj`` for five episodes. The test has no explicit
    assertions; it passes as long as none of these calls raises.
    """
    def make_cartpole():
        return gym.make('CartPole-v1')

    # The same factory is handed over twice, so the vectorized env wraps
    # two sub-environments.
    vec_env = DummyVecEnv([make_cartpole, make_cartpole])

    expert = PPO2('MlpPolicy', vec_env)
    expert.learn(total_timesteps=5000)

    # ``env`` is left at its default, so only the model is supplied here.
    generate_expert_traj(expert, save_path='.', n_timesteps=0, n_episodes=5)

0 comments on commit 36df0ef

Please sign in to comment.