This repository has been archived by the owner on Dec 11, 2022. It is now read-only.

Commit 5a8da90
bug-fix for dumping movies (+ small refactoring and rename 'VideoDumpMethod -> 'VideoDumpFilter')
Gal Leibovich committed Oct 21, 2018
1 parent 3641684 commit 5a8da90
Showing 4 changed files with 98 additions and 90 deletions.
11 changes: 6 additions & 5 deletions rl_coach/base_parameters.py
@@ -22,8 +22,8 @@
 from enum import Enum
 from typing import Dict, List, Union

-from rl_coach.core_types import TrainingSteps, EnvironmentSteps, GradientClippingMethod, RunPhase
-# from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
+from rl_coach.core_types import TrainingSteps, EnvironmentSteps, GradientClippingMethod, RunPhase, \
+    SelectedPhaseOnlyDumpFilter, MaxDumpFilter
 from rl_coach.filters.filter import NoInputFilter


@@ -285,15 +285,14 @@ def __init__(self, dense_layer):
         self.dense_layer = dense_layer


-
 class VisualizationParameters(Parameters):
     def __init__(self,
                  print_networks_summary=False,
                  dump_csv=True,
                  dump_signals_to_csv_every_x_episodes=5,
                  dump_gifs=False,
                  dump_mp4=False,
-                 video_dump_methods=[],
+                 video_dump_methods=None,
                  dump_in_episode_signals=False,
                  dump_parameters_documentation=True,
                  render=False,
@@ -352,6 +351,8 @@ def __init__(self,
                                            which will be passed to the agent and allow using those images.
         """
         super().__init__()
+        if video_dump_methods is None:
+            video_dump_methods = [SelectedPhaseOnlyDumpFilter(RunPhase.TEST), MaxDumpFilter()]
         self.print_networks_summary = print_networks_summary
         self.dump_csv = dump_csv
         self.dump_gifs = dump_gifs
@@ -363,7 +364,7 @@ def __init__(self,
         self.native_rendering = native_rendering
         self.max_fps_for_human_control = max_fps_for_human_control
         self.tensorboard = tensorboard
-        self.video_dump_methods = video_dump_methods
+        self.video_dump_filters = video_dump_methods
         self.add_rendered_image_to_env_response = add_rendered_image_to_env_response


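Note on the base_parameters.py change: the default for video_dump_methods moves from an empty list in the signature to None, with the actual default ([SelectedPhaseOnlyDumpFilter(RunPhase.TEST), MaxDumpFilter()]) assigned inside __init__. Besides giving the parameter a useful default, this pattern also side-steps Python's shared-mutable-default pitfall. A minimal sketch of that pitfall, using hypothetical Config classes rather than rl_coach code:

# Minimal sketch (hypothetical classes, not rl_coach code): a mutable default
# argument is created once and then shared by every call that relies on it.
class Config:
    def __init__(self, filters=[]):       # one list object shared across instances
        self.filters = filters


class SafeConfig:
    def __init__(self, filters=None):
        # build a fresh default per instance, as the commit now does
        self.filters = filters if filters is not None else []


a, b = Config(), Config()
a.filters.append("dump-every-episode")
print(b.filters)        # ['dump-every-episode'] -- b was silently affected

c, d = SafeConfig(), SafeConfig()
c.filters.append("dump-every-episode")
print(d.filters)        # [] -- instances stay independent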
78 changes: 78 additions & 0 deletions rl_coach/core_types.py
@@ -22,6 +22,8 @@

 import numpy as np

+from rl_coach.utils import force_list
+
 ActionType = Union[int, float, np.ndarray, List]
 GoalType = Union[None, np.ndarray]
 ObservationType = np.ndarray
@@ -692,3 +694,79 @@ def to_batch(self):

     def __getitem__(self, sliced):
         return self.transitions[sliced]
+
+
+"""
+Video Dumping Methods
+"""
+
+
+class VideoDumpFilter(object):
+    """
+    Method used to decide when to dump videos
+    """
+    def should_dump(self, episode_terminated=False, **kwargs):
+        raise NotImplementedError("")
+
+
+class AlwaysDumpFilter(VideoDumpFilter):
+    """
+    Dump video for every episode
+    """
+    def __init__(self):
+        super().__init__()
+
+    def should_dump(self, episode_terminated=False, **kwargs):
+        return True
+
+
+class MaxDumpFilter(VideoDumpFilter):
+    """
+    Dump video every time a new max total reward has been achieved
+    """
+    def __init__(self):
+        super().__init__()
+        self.max_reward_achieved = -np.inf
+
+    def should_dump(self, episode_terminated=False, **kwargs):
+        # if the episode has not finished yet we want to be prepared for dumping a video
+        if not episode_terminated:
+            return True
+        if kwargs['total_reward_in_current_episode'] > self.max_reward_achieved:
+            self.max_reward_achieved = kwargs['total_reward_in_current_episode']
+            return True
+        else:
+            return False
+
+
+class EveryNEpisodesDumpFilter(object):
+    """
+    Dump videos once in every N episodes
+    """
+    def __init__(self, num_episodes_between_dumps: int):
+        super().__init__()
+        self.num_episodes_between_dumps = num_episodes_between_dumps
+        self.last_dumped_episode = 0
+        if num_episodes_between_dumps < 1:
+            raise ValueError("the number of episodes between dumps should be a positive number")
+
+    def should_dump(self, episode_terminated=False, **kwargs):
+        if kwargs['episode_idx'] >= self.last_dumped_episode + self.num_episodes_between_dumps - 1:
+            self.last_dumped_episode = kwargs['episode_idx']
+            return True
+        else:
+            return False
+
+
+class SelectedPhaseOnlyDumpFilter(object):
+    """
+    Dump videos when the phase of the environment matches a predefined phase
+    """
+    def __init__(self, run_phases: Union[RunPhase, List[RunPhase]]):
+        self.run_phases = force_list(run_phases)
+
+    def should_dump(self, episode_terminated=False, **kwargs):
+        if kwargs['_phase'] in self.run_phases:
+            return True
+        else:
+            return False
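For context, a hypothetical configuration sketch using these relocated classes (the keyword argument keeps its old name, video_dump_methods, even though the stored attribute is renamed to video_dump_filters). Per should_dump_video_of_the_current_episode in environment.py below, every configured filter must return True before an episode is dumped:

# Hypothetical usage sketch, assuming this revision of rl_coach is installed:
# dump MP4s only during evaluation (TEST) episodes that set a new best total
# reward -- the same combination this commit installs as the default.
from rl_coach.base_parameters import VisualizationParameters
from rl_coach.core_types import RunPhase, SelectedPhaseOnlyDumpFilter, MaxDumpFilter

vis_params = VisualizationParameters(
    dump_mp4=True,
    # all filters must agree before an episode is written to disk
    video_dump_methods=[SelectedPhaseOnlyDumpFilter(RunPhase.TEST), MaxDumpFilter()],
)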
96 changes: 13 additions & 83 deletions rl_coach/environments/environment.py
@@ -317,14 +317,20 @@ def render(self) -> None:
         else:
             self.renderer.render_image(self.get_rendered_image())

+    def handle_episode_ended(self) -> None:
+        """
+        End an episode
+        :return: None
+        """
+        self.dump_video_of_last_episode_if_needed()
+
     def reset_internal_state(self, force_environment_reset=False) -> EnvResponse:
         """
         Reset the environment and all the variable of the wrapper
         :param force_environment_reset: forces environment reset even when the game did not end
         :return: A dictionary containing the observation, reward, done flag, action and measurements
         """

-        self.dump_video_of_last_episode_if_needed()
         self._restart_environment_episode(force_environment_reset)
         self.last_episode_time = time.time()

@@ -392,17 +398,16 @@ def set_goal(self, goal: GoalType) -> None:
         self.goal = goal

     def should_dump_video_of_the_current_episode(self, episode_terminated=False):
-        if self.visualization_parameters.video_dump_methods:
-            for video_dump_method in force_list(self.visualization_parameters.video_dump_methods):
-                if not video_dump_method.should_dump(episode_terminated, **self.__dict__):
+        if self.visualization_parameters.video_dump_filters:
+            for video_dump_filter in force_list(self.visualization_parameters.video_dump_filters):
+                if not video_dump_filter.should_dump(episode_terminated, **self.__dict__):
                     return False
             return True
-        return False
+        return True

     def dump_video_of_last_episode_if_needed(self):
-        if self.visualization_parameters.video_dump_methods and self.last_episode_images != []:
-            if self.should_dump_video_of_the_current_episode(episode_terminated=True):
-                self.dump_video_of_last_episode()
+        if self.last_episode_images != [] and self.should_dump_video_of_the_current_episode(episode_terminated=True):
+            self.dump_video_of_last_episode()

     def dump_video_of_last_episode(self):
         frame_skipping = max(1, int(5 / self.frame_skip))
@@ -464,78 +469,3 @@ def get_rendered_image(self) -> np.ndarray:
         """
         return np.transpose(self.state['observation'], [1, 2, 0])

-
-"""
-Video Dumping Methods
-"""
-
-
-class VideoDumpMethod(object):
-    """
-    Method used to decide when to dump videos
-    """
-    def should_dump(self, episode_terminated=False, **kwargs):
-        raise NotImplementedError("")
-
-
-class AlwaysDumpMethod(VideoDumpMethod):
-    """
-    Dump video for every episode
-    """
-    def __init__(self):
-        super().__init__()
-
-    def should_dump(self, episode_terminated=False, **kwargs):
-        return True
-
-
-class MaxDumpMethod(VideoDumpMethod):
-    """
-    Dump video every time a new max total reward has been achieved
-    """
-    def __init__(self):
-        super().__init__()
-        self.max_reward_achieved = -np.inf
-
-    def should_dump(self, episode_terminated=False, **kwargs):
-        # if the episode has not finished yet we want to be prepared for dumping a video
-        if not episode_terminated:
-            return True
-        if kwargs['total_reward_in_current_episode'] > self.max_reward_achieved:
-            self.max_reward_achieved = kwargs['total_reward_in_current_episode']
-            return True
-        else:
-            return False
-
-
-class EveryNEpisodesDumpMethod(object):
-    """
-    Dump videos once in every N episodes
-    """
-    def __init__(self, num_episodes_between_dumps: int):
-        super().__init__()
-        self.num_episodes_between_dumps = num_episodes_between_dumps
-        self.last_dumped_episode = 0
-        if num_episodes_between_dumps < 1:
-            raise ValueError("the number of episodes between dumps should be a positive number")
-
-    def should_dump(self, episode_terminated=False, **kwargs):
-        if kwargs['episode_idx'] >= self.last_dumped_episode + self.num_episodes_between_dumps - 1:
-            self.last_dumped_episode = kwargs['episode_idx']
-            return True
-        else:
-            return False
-
-
-class SelectedPhaseOnlyDumpMethod(object):
-    """
-    Dump videos when the phase of the environment matches a predefined phase
-    """
-    def __init__(self, run_phases: Union[RunPhase, List[RunPhase]]):
-        self.run_phases = force_list(run_phases)
-
-    def should_dump(self, episode_terminated=False, **kwargs):
-        if kwargs['_phase'] in self.run_phases:
-            return True
-        else:
-            return False
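Because should_dump receives the environment's self.__dict__ as keyword arguments (see should_dump_video_of_the_current_episode above), a filter can key off any environment attribute; the built-in filters read total_reward_in_current_episode, episode_idx and _phase. A hypothetical custom filter, sketched under the assumption that total_reward_in_current_episode is exposed on the environment exactly as MaxDumpFilter assumes:

# Hypothetical sketch of a custom filter: dump only episodes whose total reward
# clears a fixed threshold. It follows the same pattern as the built-in filters:
# return True while the episode is still running so frames keep being collected,
# then decide once the episode has terminated.
from rl_coach.core_types import VideoDumpFilter


class RewardThresholdDumpFilter(VideoDumpFilter):
    def __init__(self, reward_threshold: float):
        super().__init__()
        self.reward_threshold = reward_threshold

    def should_dump(self, episode_terminated=False, **kwargs):
        if not episode_terminated:
            # keep collecting frames until the episode ends
            return True
        return kwargs['total_reward_in_current_episode'] >= self.reward_threshold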
3 changes: 1 addition & 2 deletions rl_coach/graph_managers/graph_manager.py
@@ -325,8 +325,7 @@ def handle_episode_ended(self) -> None:
         """
         self.total_steps_counters[self.phase][EnvironmentEpisodes] += 1

-        # TODO: we should disentangle ending the episode from resetting the internal state
-        # self.reset_internal_state()
+        [environment.handle_episode_ended() for environment in self.environments]

     def train(self, steps: TrainingSteps) -> None:
         """
