fix the env resource issue for async training (#221)
* fix the env resource issue for async training

* make env management consistent among trainers

* remove old observation_transformer in drivers
hnyu authored Oct 15, 2019
1 parent 07a451b commit 3c9b20f
Showing 9 changed files with 40 additions and 43 deletions.
1 change: 1 addition & 0 deletions alf/bin/play.py
@@ -75,6 +75,7 @@ def main(_):
num_episodes=FLAGS.num_episodes,
sleep_time_per_step=FLAGS.sleep_time_per_step,
record_file=FLAGS.record_file)
env.pyenv.close()


if __name__ == '__main__':
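
For context, a placeholder sketch of the intent behind this one-line change (the `run_episodes` helper below is hypothetical, not the actual play.py API): the parallel environment is backed by worker subprocesses, so it must be closed explicitly once playing finishes.

```python
# Placeholder sketch (not the actual play.py code): the parallel env spawns
# worker subprocesses, so release them explicitly when playing is done.
from alf.environments.utils import create_environment


def run_episodes(env):
    """Hypothetical stand-in for the rollout/visualization loop in play.py."""


env = create_environment()
try:
    run_episodes(env)
finally:
    env.pyenv.close()  # the call added by this commit
```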
1 change: 0 additions & 1 deletion alf/bin/train.py
@@ -47,7 +47,6 @@
from absl import logging

import gin.tf.external_configurables
from alf.environments.utils import create_environment
from alf.utils import common
import alf.utils.external_configurables
from alf.trainers import policy_trainer
5 changes: 1 addition & 4 deletions alf/drivers/async_off_policy_driver.py
@@ -26,7 +26,6 @@
from alf.drivers.off_policy_driver import OffPolicyDriver
from alf.drivers.threads import TFQueues, ActorThread, EnvThread, LogThread
from alf.experience_replayers.experience_replay import OnetimeExperienceReplayer
from alf.environments.utils import create_environment


@gin.configurable
@@ -121,9 +120,7 @@ def __init__(self,
coord=self._coord,
algorithm=self._algorithm,
tf_queues=self._tfq,
id=i,
observation_transformer=self._observation_transformer)
for i in range(num_actor_queues)
id=i) for i in range(num_actor_queues)
]
env_threads = [
EnvThread(
6 changes: 0 additions & 6 deletions alf/drivers/policy_driver.py
@@ -35,7 +35,6 @@ class PolicyDriver(driver.Driver):
def __init__(self,
env,
algorithm,
observation_transformer: Callable = None,
observers=[],
use_rollout_state=False,
metrics=[],
@@ -101,7 +100,6 @@ def __init__(self,
self._debug_summaries = debug_summaries
self._summarize_grads_and_vars = summarize_grads_and_vars
self._summarize_action_distributions = summarize_action_distributions
self._observation_transformer = observation_transformer
self._train_step_counter = common.get_global_counter(
train_step_counter)
self._proc = psutil.Process(os.getpid())
@@ -157,10 +155,6 @@ def _training_summary(self, training_info, loss_info, grads_and_vars):
def algorithm(self):
return self._algorithm

@property
def observation_transformer(self):
return self._observation_transformer

@abc.abstractmethod
def _prepare_specs(self, algorithm):
pass
11 changes: 1 addition & 10 deletions alf/drivers/threads.py
@@ -262,13 +262,7 @@ class ActorThread(Thread):
stop (from another thread).
"""

def __init__(self,
name,
coord,
algorithm,
tf_queues,
id,
observation_transformer: Callable = None):
def __init__(self, name, coord, algorithm, tf_queues, id):
"""
Args:
name (str): the name of the actor thread
@@ -277,14 +271,11 @@ def __init__(self,
tf_queues (TFQueues): for storing all the tf.FIFOQueues for
communicating between threads
id (int): thread id
observation_transformer (Callable): transformation applied to
`time_step.observation`
"""
super().__init__(name=name, target=self._run, args=(coord, algorithm))
self._tfq = tf_queues
self._id = id
self._actor_q = self._tfq.actor_queues[id]
self._ob_transformer = observation_transformer

@tf.function
def _enqueue_actions(self, policy_step, action_dist_param, i, return_q):
12 changes: 5 additions & 7 deletions alf/trainers/off_policy_trainer.py
@@ -20,7 +20,6 @@

from alf.drivers.async_off_policy_driver import AsyncOffPolicyDriver
from alf.drivers.sync_off_policy_driver import SyncOffPolicyDriver
from alf.environments.utils import create_environment
from alf.trainers.policy_trainer import Trainer


@@ -65,14 +64,14 @@ class SyncOffPolicyTrainer(OffPolicyTrainer):

def init_driver(self):
return SyncOffPolicyDriver(
env=self._env,
env=self._envs[0],
use_rollout_state=self._config.use_rollout_state,
algorithm=self._algorithm,
debug_summaries=self._debug_summaries,
summarize_grads_and_vars=self._summarize_grads_and_vars)

def train_iter(self, iter_num, policy_state, time_step):
max_num_steps = self._unroll_length * self._env.batch_size
max_num_steps = self._unroll_length * self._envs[0].batch_size
if iter_num == 0 and self._initial_collect_steps != 0:
max_num_steps = self._initial_collect_steps
time_step, policy_state = self._driver.run(
@@ -97,11 +96,10 @@ def __init__(self, config):
self._driver_started = False

def init_driver(self):
envs = [self._env]
for i in range(1, self._config.num_envs):
envs.append(create_environment())
for _ in range(1, self._config.num_envs):
self._create_environment()
driver = AsyncOffPolicyDriver(
envs=envs,
envs=self._envs,
algorithm=self._algorithm,
use_rollout_state=self._config.use_rollout_state,
unroll_length=self._unroll_length,
2 changes: 1 addition & 1 deletion alf/trainers/on_policy_trainer.py
@@ -29,7 +29,7 @@ def __init__(self, config):

def init_driver(self):
return OnPolicyDriver(
env=self._env,
env=self._envs[0],
algorithm=self._algorithm,
train_interval=self._unroll_length,
debug_summaries=self._debug_summaries,
43 changes: 29 additions & 14 deletions alf/trainers/policy_trainer.py
@@ -174,7 +174,7 @@ def __init__(self, config: TrainerConfig):
self._train_dir = os.path.join(root_dir, 'train')
self._eval_dir = os.path.join(root_dir, 'eval')

self._env = None
self._envs = []
self._algorithm_ctor = config.algorithm_ctor
self._algorithm = None
self._driver = None
@@ -222,20 +222,36 @@ def initialize(self):

tf.config.experimental_run_functions_eagerly(
not self._use_tf_functions)
self._env = create_environment()
common.set_global_env(self._env)
# Create an unwrapped env to expose subprocess gin confs which otherwise
# will be marked as "inoperative"
unwrapped_env = create_environment(nonparallel=True)
if self._evaluate:
self._eval_env = unwrapped_env
env = self._create_environment()
common.set_global_env(env)

self._algorithm = self._algorithm_ctor(
debug_summaries=self._debug_summaries)
self._algorithm.use_rollout_state = self._config.use_rollout_state

self._driver = self.init_driver()

# Create an unwrapped env to expose subprocess gin confs which otherwise
# will be marked as "inoperative". This env should be created last.
unwrapped_env = self._create_environment(nonparallel=True)
if self._evaluate:
self._eval_env = unwrapped_env

def _create_environment(self, nonparallel=False):
"""Create and register an env"""
env = create_environment(nonparallel=nonparallel)
self._register_env(env)
return env

def _register_env(self, env):
"""Register env so that later its resource will be recycled"""
self._envs.append(env)

def _close_envs(self):
"""Close all envs to release their resources"""
for env in self._envs:
env.pyenv.close()

@abc.abstractmethod
def init_driver(self):
"""Initialize driver
@@ -249,8 +265,8 @@ def init_driver(self):

def train(self):
"""Perform training."""
assert None not in (self._env, self._algorithm,
self._driver), "Trainer not initialized"
assert (None not in (self._algorithm, self._driver)) and self._envs, \
"Trainer not initialized"
self._restore_checkpoint()
run_under_record_context(
self._train,
@@ -259,9 +275,7 @@ def train_iter(self, iter_num, policy_state, time_step):
flush_millis=self._summaries_flush_mills,
summary_max_queue=self._summary_max_queue)
self._save_checkpoint()
if self._evaluate:
self._eval_env.pyenv.close()
self._env.pyenv.close()
self._close_envs()

@abc.abstractmethod
def train_iter(self, iter_num, policy_state, time_step):
Expand All @@ -279,7 +293,8 @@ def train_iter(self, iter_num, policy_state, time_step):
pass

def _train(self):
self._env.reset()
for env in self._envs:
env.reset()
time_step = self._driver.get_initial_time_step()
policy_state = self._driver.get_initial_policy_state()
for iter_num in range(self._num_iterations):
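
Taken together, the hunks above move all env bookkeeping into the base `Trainer`: every environment is created through `_create_environment`, recorded in `self._envs`, and released in one place. A condensed sketch of the resulting lifecycle, abbreviated from the diff (checkpointing, summaries, and other unrelated details are elided):

```python
# Condensed sketch (not the full class) of the env management added to
# alf/trainers/policy_trainer.py by this commit.
from alf.environments.utils import create_environment


class Trainer(object):
    def __init__(self, config):
        self._envs = []  # every env the trainer creates is tracked here

    def _create_environment(self, nonparallel=False):
        """Create an env and register it for later cleanup."""
        env = create_environment(nonparallel=nonparallel)
        self._register_env(env)
        return env

    def _register_env(self, env):
        """Register env so that its resources can be recycled later."""
        self._envs.append(env)

    def _close_envs(self):
        """Close all registered envs, shutting down their worker processes."""
        for env in self._envs:
            env.pyenv.close()

    def train(self):
        # ... training loop elided ...
        self._close_envs()  # replaces the scattered env.pyenv.close() calls
```

Subclasses then read from this registry: `SyncOffPolicyTrainer` and `OnPolicyTrainer` drive `self._envs[0]`, while `AsyncOffPolicyTrainer.init_driver` calls `self._create_environment()` once per extra env and hands the whole `self._envs` list to `AsyncOffPolicyDriver`.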
2 changes: 2 additions & 0 deletions docs/async_training.md
@@ -112,3 +112,5 @@ If possible, we want to minimize the time interval T between every two training
Assuming abundant CPU resources, we can imagine that async training is most suitable for problems with simple neural models but complex environment simulations (3D rendering, physics, etc.) where the bottleneck is simulation speed. In this case the rollout time is much greater than the training time, and having multiple actors (ideally without compromising each actor's speed) in the data pipeline can decrease the waiting time between two training iterations.

Because async training generally is less sample efficient than sync training, it’s recommended to use it for cases where sample efficiency is not the main metric, e.g., to have faster turn-around times for tweaking model hyperparameters.

Another benefit of async training appears when we want to train on a large number of parallel environments with a large unroll length per training iteration. For both on-policy and sync off-policy training, we have to maintain a huge computational graph during each training update, which can cause out-of-memory errors on a GPU. With async off-policy training, we can effectively keep the same total environment batch size and unroll length by splitting the environments across actors: with K actors, each actor owns B/K environments, so each training update only needs to maintain a computational graph of 1/K the size.
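
To make the memory argument concrete, here is a small back-of-the-envelope calculation (the numbers B = 64, K = 4, and unroll length 8 are hypothetical, chosen only for illustration):

```python
# Hypothetical numbers illustrating the B/K split described above.
B = 64                  # total number of parallel environments we want
unroll_length = 8       # env steps unrolled per training iteration
K = 4                   # number of async actor threads

# Rough proxy for the size of the computational graph built per update:
sync_graph = B * unroll_length            # sync training: 512 env-steps
async_graph = (B // K) * unroll_length    # async, per actor: 128 env-steps

# Each actor owns B/K = 16 environments, so overall data throughput is
# comparable, but each training update only needs a graph ~1/K the size.
print(sync_graph, async_graph)  # 512 128
```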
