Skip to content

Commit

Permalink
Merge pull request google#19 from wes-turner/fix_thread_pool
Browse files Browse the repository at this point in the history
Fix thread pool used to run iterations in async runner.
  • Loading branch information
arthurarg committed Apr 16, 2019
2 parents 33a2de5 + f4c4e1b commit 1ce0d6c
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 56 deletions.
2 changes: 1 addition & 1 deletion dopamine/discrete_domains/configs/async_training.gin
Expand Up @@ -4,5 +4,5 @@ import dopamine.replay_memory.circular_replay_buffer

create_runner.schedule = 'async_train'
AsyncRunner.create_environment_fn = @gym_lib.create_gym_environment
AsyncRunner.max_simultaneous_iterations = 3
AsyncRunner.num_simultaneous_iterations = 3
WrappedReplayBuffer.use_contiguous_trajectories = True
68 changes: 27 additions & 41 deletions dopamine/discrete_domains/run_experiment.py
Expand Up @@ -561,7 +561,7 @@ class AsyncRunner(Runner):
def __init__(
self, base_dir, create_agent_fn,
create_environment_fn=atari_lib.create_atari_environment,
max_simultaneous_iterations=1, **kwargs):
num_simultaneous_iterations=1, **kwargs):
"""Creates an asynchronous runner.
Args:
Expand All @@ -570,64 +570,62 @@ def __init__(
environment, and returns an agent.
create_environment_fn: A function which receives a problem name and
creates a Gym environment for that problem (e.g. an Atari 2600 game).
max_simultaneous_iterations: int, maximum number of iterations running
num_simultaneous_iterations: int, number of iterations running
simultaneously in separate threads.
**kwargs: Additional keyword arguments.
"""
threading_utils.initialize_local_attributes(
self, _environment=create_environment_fn)
self._eval_period = max_simultaneous_iterations
self._eval_period = num_simultaneous_iterations
self._num_simultaneous_iterations = num_simultaneous_iterations
self._output_lock = threading.Lock()
self._experience_queue = queue.Queue(max_simultaneous_iterations)
self._training_queue = queue.Queue(max_simultaneous_iterations)
self._training_queue = queue.Queue(num_simultaneous_iterations)

super(AsyncRunner, self).__init__(
base_dir=base_dir, create_agent_fn=create_agent_fn,
create_environment_fn=create_environment_fn, **kwargs)

# TODO(aarg): Decouple experience generation from training.
def _run_iterations(self):
"""Runs required number of training iterations sequentially.
Statistics from each iteration are logged and exported for tensorboard.
Iterations are run in multiple threads simultaneously (number of
simultaneous threads is specified by `max_simultaneous_iterations`). Each
simultaneous threads is specified by `num_simultaneous_iterations`). Each
time an iteration completes a new one starts until the right number of
iterations is run.
"""
# TODO(aarg): Change the thread management to an implementation with queues
# and with a fix number of thread workers. With the revamp, training might
# go inline to this method.
def _start_iteration(*args):
thread = threading.Thread(target=self._run_one_iteration, args=args)
thread.start()
return thread

training_worker = threading.Thread(target=self._run_training_steps)
training_worker.start()
experience_threads = []
experience_queue = queue.Queue()
worker_threads = []

# TODO(aarg): Consider refactoring to use step level tasks.
for _ in range(self._num_simultaneous_iterations):
worker_threads.append(
threading_utils.start_worker_thread(experience_queue))
worker_threads.append(
threading_utils.start_worker_thread(self._training_queue))

# TODO(westurner): See how to refactor the code to avoid setting an internal
# attribute.
self._completed_iteration = self._start_iteration
for iteration in range(self._start_iteration, self._num_iterations):
if (iteration + 1) % self._eval_period == 0:
# TODO(aarg): As part of the revamp indicated above, set the eval mode
# directly in the queue `push`.
self._experience_queue.put('train')
experience_threads.append(_start_iteration(iteration, True))
self._experience_queue.put('eval')
experience_threads.append(_start_iteration(iteration, False))
# TODO(aarg): Replace with ModeKeys.
experience_queue.put((self._run_one_iteration, (iteration, True)))
experience_queue.put((self._run_one_iteration, (iteration, False)))

# Wait for all tasks to complete.
self._experience_queue.join()
experience_queue.join()
self._training_queue.join()
# Indicate training step thread to stop.

# Indicate workers to stop.
for _ in range(self._num_simultaneous_iterations):
experience_queue.put(None)
self._training_queue.put(None)

# Wait for all running threads to complete.
for thread in experience_threads:
for thread in worker_threads:
thread.join()
training_worker.join()

def _begin_episode(self, observation):
# Increments training steps and blocks if training is too slow.
Expand All @@ -648,17 +646,7 @@ def _enqueue_training_step(self):
"""
if self._agent.eval_mode:
return
self._training_queue.put(0) # Value doesn't matter.

def _run_training_steps(self):
"""Runs training steps until iterations and training queues are empty."""
while True:
item = self._training_queue.get()
if item is None:
self._training_queue.task_done()
return
self._agent.train_step()
self._training_queue.task_done()
self._training_queue.put((self._agent.train_step, tuple([])))

def _run_one_iteration(self, iteration, eval_mode):
"""Runs one iteration in separate thread, logs and checkpoints results.
Expand Down Expand Up @@ -688,5 +676,3 @@ def _run_one_iteration(self, iteration, eval_mode):
self._checkpoint_experiment(self._completed_iteration)
self._completed_iteration += 1
tf.logging.info('Completed %s.', iteration_name)
self._experience_queue.get()
self._experience_queue.task_done()
34 changes: 34 additions & 0 deletions dopamine/utils/threading_utils.py
Expand Up @@ -214,3 +214,37 @@ def initialize_local_attributes(obj, **kwargs):
raise AttributeError(
'Object `{}` already has a `{}` attribute.'.format(obj, default_attr))
setattr(obj, default_attr, val)


def _queue_worker(task_queue):
  """Reads and executes tasks in given queue until `None` is read.

  Each item other than `None` must be a `(function, args)` pair and is
  executed as `function(*args)`. Every item taken from the queue, including
  the `None` stop marker, is acknowledged with `task_done()`.

  An exception raised by a task is not swallowed: it propagates out of this
  function after the item has been acknowledged.

  Args:
    task_queue: `queue.Queue` object holding `(function, args)` tuples, with
      `None` used as the stop marker.
  """
  while True:
    item = task_queue.get()
    try:
      if item is None:
        break
      function, task = item
      function(*task)
    finally:
      # Acknowledge the item on every path (stop marker, success, malformed
      # item or failing task). Otherwise the queue's unfinished-task count
      # never reaches zero and a caller blocked in `task_queue.join()` would
      # wait forever.
      task_queue.task_done()


def start_worker_thread(task_queue):
  """Launches a thread that consumes tasks from `task_queue`.

  Every task placed in `task_queue` must be a tuple of:
    - function: callable taking positional arguments and returning None.
    - task: tuple of positional arguments to pass to the function.
  A task is executed by calling `function(*task)`.

  Putting `None` in the queue acts as a stop signal: the worker thread
  terminates once it reads that item.

  Args:
    task_queue: `queue.Queue` object containing tasks to perform.

  Returns:
    The already started `threading.Thread` processing `task_queue`.
  """
  worker = threading.Thread(target=_queue_worker, args=(task_queue,))
  worker.start()
  return worker
37 changes: 23 additions & 14 deletions tests/dopamine/discrete_domains/run_async_training_test.py
Expand Up @@ -61,7 +61,7 @@ def testEnvironmentInitializationPerThread(self):
runner = self._get_runner(
create_agent_fn=test.mock.MagicMock(),
create_environment_fn=environment_fn, num_iterations=1,
training_steps=1, evaluation_steps=0, max_simultaneous_iterations=1)
training_steps=1, evaluation_steps=0, num_simultaneous_iterations=1)

# Environment called once in init.
environment_fn.assert_called_once()
Expand All @@ -76,7 +76,7 @@ def testNumIterations(self):
runner = self._get_runner(
create_agent_fn=agent_fn,
create_environment_fn=_get_mock_environment_fn(), num_iterations=18,
training_steps=1, evaluation_steps=0, max_simultaneous_iterations=1)
training_steps=1, evaluation_steps=0, num_simultaneous_iterations=1)
runner.run_experiment()
self.assertEqual(mock_agent.begin_episode.call_count, 18)

Expand All @@ -88,16 +88,27 @@ def testNumberTrainingSteps(self,):
runner = self._get_runner(
create_agent_fn=test.mock.MagicMock(),
create_environment_fn=_get_mock_environment_fn(), num_iterations=3,
training_steps=2, evaluation_steps=6, max_simultaneous_iterations=1)
training_steps=2, evaluation_steps=6, num_simultaneous_iterations=1)
runner.run_experiment()

def _put_call_cnt(v):
return sum([list(call)[0] == v for call in mock_put.call_args_list])

self.assertEqual(_put_call_cnt(('train',)), 3)
self.assertEqual(_put_call_cnt(('eval',)), 3)
self.assertEqual(_put_call_cnt((0,)), 6)
self.assertEqual(_put_call_cnt((None,)), 1)
cnt = 0
for call in mock_put.call_args_list:
item = list(call)[0][0]
if isinstance(item, tuple):
cnt += item[1] == v
else:
cnt += item == v
return cnt

self.assertEqual(_put_call_cnt((0, False)), 1) # Train task.
self.assertEqual(_put_call_cnt((1, False,)), 1) # Train task.
self.assertEqual(_put_call_cnt((2, False,)), 1) # Train task.
self.assertEqual(_put_call_cnt((0, True,)), 1) # Eval task.
self.assertEqual(_put_call_cnt((1, True,)), 1) # Eval task.
self.assertEqual(_put_call_cnt((2, True,)), 1) # Eval task.
self.assertEqual(_put_call_cnt(tuple([])), 6) # Training step.
self.assertEqual(_put_call_cnt(None), 2) # Stop task.

def testNumberSteps(self):
"""Tests that the right number of agent steps are ran."""
Expand All @@ -106,7 +117,7 @@ def testNumberSteps(self):
runner = self._get_runner(
create_agent_fn=agent_fn,
create_environment_fn=_get_mock_environment_fn(), num_iterations=3,
training_steps=2, evaluation_steps=6, max_simultaneous_iterations=1)
training_steps=2, evaluation_steps=6, num_simultaneous_iterations=1)
runner.run_experiment()
self.assertEqual(agent.begin_episode.call_count, 24)

Expand All @@ -116,7 +127,7 @@ def testSummariesExportedWithProperTags(self, summary):
base_dir=self.get_temp_dir(), create_agent_fn=test.mock.MagicMock(),
create_environment_fn=_get_mock_environment_fn(),
num_iterations=2, training_steps=1, evaluation_steps=0,
max_simultaneous_iterations=2)
num_simultaneous_iterations=2)
runner._checkpoint_experiment = test.mock.Mock()
runner._log_experiment = test.mock.Mock()
runner._summary_writer = test.mock.Mock()
Expand All @@ -138,7 +149,7 @@ def setUp(self):
runner = run_experiment.AsyncRunner(
base_dir=self.get_temp_dir(), create_agent_fn=test.mock.MagicMock(),
create_environment_fn=_get_mock_environment_fn(), num_iterations=1,
training_steps=1, evaluation_steps=0, max_simultaneous_iterations=1)
training_steps=1, evaluation_steps=0, num_simultaneous_iterations=1)
runner._checkpoint_experiment = test.mock.Mock()
runner._log_experiment = test.mock.Mock()
runner._save_tensorboard_summaries = test.mock.Mock()
Expand All @@ -147,7 +158,6 @@ def setUp(self):

def testCompletedIterationCounterIsUsed(self,):
self.runner._completed_iteration = 20
self.runner._experience_queue.put(1)
self.runner._run_one_iteration(iteration=36, eval_mode=False)
self.runner._checkpoint_experiment.assert_called_once_with(20)

Expand All @@ -157,7 +167,6 @@ def testCompletedIterationCounterIsInitialized(self):

def testCompletedIterationCounterIsIncremented(self):
self.runner._completed_iteration = 20
self.runner._experience_queue.put(1)
self.runner._run_one_iteration(iteration=36, eval_mode=False)
self.assertEqual(self.runner._completed_iteration, 21)

Expand Down

0 comments on commit 1ce0d6c

Please sign in to comment.