Skip to content

Commit

Permalink
Add VecVideoRecorder to record videos (#82)
Browse files Browse the repository at this point in the history
* Add VecVideoRecorder

* [ci skip] Update doc + potential bug fix in recorder

* [ci skip] Style fixes
  • Loading branch information
araffin committed Nov 18, 2018
1 parent 9f36c9a commit d36f1df
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 1 deletion.
34 changes: 34 additions & 0 deletions docs/guide/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,40 @@ You can also move from learning on one environment to another for `continual lea
env.render()
Record a Video
--------------

Record a mp4 video (here using a random agent).

.. note::

It requires ffmpeg or avconv to be installed on the machine.

.. code-block:: python
import gym
from stable_baselines.common.vec_env import VecVideoRecorder, DummyVecEnv
env_id = 'CartPole-v1'
video_folder = 'logs/videos/'
video_length = 100
env = DummyVecEnv([lambda: gym.make(env_id)])
obs = env.reset()
# Record the video starting at the first step
env = VecVideoRecorder(env, video_folder,
record_video_trigger=lambda x: x == 0, video_length=video_length,
name_prefix="random-agent-{}".format(env_id))
env.reset()
for _ in range(video_length + 1):
action = [env.action_space.sample()]
obs, _, _, _ = env.step(action)
env.close()
Bonus: Make a GIF of a Trained Agent
------------------------------------

Expand Down
7 changes: 7 additions & 0 deletions docs/guide/vec_envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,10 @@ VecNormalize

.. autoclass:: VecNormalize
:members:


VecVideoRecorder
~~~~~~~~~~~~

.. autoclass:: VecVideoRecorder
:members:
6 changes: 6 additions & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ Changelog

For download links, please look at `Github release page <https://github.com/hill-a/stable-baselines/releases>`_.

Pre-Release 2.2.1a (WIP)
--------------------------

- added VecVideoRecorder to record mp4 videos from environment.


Release 2.2.0 (2018-11-07)
--------------------------

Expand Down
2 changes: 1 addition & 1 deletion stable_baselines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from stable_baselines.ppo2 import PPO2
from stable_baselines.trpo_mpi import TRPO

__version__ = "2.2.0"
__version__ = "2.2.1a"


# patch Gym spaces to add equality functions, if not implemented
Expand Down
1 change: 1 addition & 0 deletions stable_baselines/common/vec_env/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
from stable_baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines.common.vec_env.vec_frame_stack import VecFrameStack
from stable_baselines.common.vec_env.vec_normalize import VecNormalize
from stable_baselines.common.vec_env.vec_video_recorder import VecVideoRecorder
108 changes: 108 additions & 0 deletions stable_baselines/common/vec_env/vec_video_recorder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import os

from gym.wrappers.monitoring import video_recorder

from stable_baselines import logger
from stable_baselines.common.vec_env import VecEnvWrapper, DummyVecEnv, VecNormalize, VecFrameStack, SubprocVecEnv


class VecVideoRecorder(VecEnvWrapper):

def __init__(self, venv, video_folder, record_video_trigger,
video_length=200, name_prefix='rl-video'):
"""
Wraps a VecEnv or VecEnvWrapper object to record rendered image as mp4 video.
It requires ffmpeg or avconv to be installed on the machine.
:param venv: (VecEnv or VecEnvWrapper)
:param video_folder: (str) Where to save videos
:param record_video_trigger: (func) Function that defines when to start recording.
The function takes the current number of step,
and returns whether we should start recording or not.
:param video_length: (int) Length of recorded videos
:param name_prefix: (str) Prefix to the video name
"""

VecEnvWrapper.__init__(self, venv)

self.env = venv
# Temp variable to retrieve metadata
temp_env = venv

# Unwrap to retrieve metadata dict
# that will be used by gym recorder
while isinstance(temp_env, VecNormalize) or isinstance(temp_env, VecFrameStack):
temp_env = temp_env.venv

if isinstance(temp_env, DummyVecEnv) or isinstance(temp_env, SubprocVecEnv):
metadata = temp_env.get_attr('metadata')[0]
else:
metadata = temp_env.metadata

self.env.metadata = metadata

self.record_video_trigger = record_video_trigger
self.video_recorder = None

self.video_folder = os.path.abspath(video_folder)
# Create output folder if needed
os.makedirs(self.video_folder, exist_ok=True)

self.name_prefix = name_prefix
self.step_id = 0
self.video_length = video_length

self.recording = False
self.recorded_frames = 0

def reset(self):
obs = self.venv.reset()
self.start_video_recorder()
return obs

def start_video_recorder(self):
self.close_video_recorder()

video_name = '{}-step-{}-to-step-{}'.format(self.name_prefix, self.step_id,
self.step_id + self.video_length)
base_path = os.path.join(self.video_folder, video_name)
self.video_recorder = video_recorder.VideoRecorder(
env=self.env,
base_path=base_path,
metadata={'step_id': self.step_id}
)

self.video_recorder.capture_frame()
self.recorded_frames = 1
self.recording = True

def _video_enabled(self):
return self.record_video_trigger(self.step_id)

def step_wait(self):
obs, rews, dones, infos = self.venv.step_wait()

self.step_id += 1
if self.recording:
self.video_recorder.capture_frame()
self.recorded_frames += 1
if self.recorded_frames > self.video_length:
logger.info("Saving video to ", self.video_recorder.path)
self.close_video_recorder()
elif self._video_enabled():
self.start_video_recorder()

return obs, rews, dones, infos

def close_video_recorder(self):
if self.recording:
self.video_recorder.close()
self.recording = False
self.recorded_frames = 1

def close(self):
VecEnvWrapper.close(self)
self.close_video_recorder()

def __del__(self):
self.close()

0 comments on commit d36f1df

Please sign in to comment.