Skip to content

Commit

Permalink
max_workers in thread pool executor
Browse files Browse the repository at this point in the history
  • Loading branch information
jordis-ai2 committed Jul 9, 2024
1 parent 0700aa3 commit 3d3b9b1
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion allenact/base_abstractions/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def __init__(
self.tasks.append(self.make_new_task(it))

# Also, a ThreadPoolExecutor to collect all data (possibly) under IO bottlenecks
self.executor = ThreadPoolExecutor()
self.executor = ThreadPoolExecutor(max_workers=min(10, self.task_sampler.task_batch_size))

# Also, a mutex to enable underlying task sampler implementations to ensure e.g. only one process
# resets the sampler when called from a ThreadPoolExecutor (next_task must be thread safe, possibly
Expand Down Expand Up @@ -472,6 +472,8 @@ def obs_extract(it, task):
res[it] = task.get_observations()

wait([self.executor.submit(obs_extract, it, task) for it, task in enumerate(self.tasks)])
# for it, task in enumerate(self.tasks):
# obs_extract(it, task)

return res

Expand Down Expand Up @@ -546,6 +548,8 @@ def update_after_step(it, current_task):

# Ensure completion with wait():
wait([self.executor.submit(update_after_step, it, current_task) for it, current_task in enumerate(self.tasks)])
# for it, current_task in enumerate(self.tasks):
# update_after_step(it, current_task)

return RLStepResult(
observation=self.get_observations(),
Expand All @@ -564,6 +568,8 @@ def before_step(it, task):
actions[it], intermediates[it] = task._before_env_step(action[it])

wait([self.executor.submit(before_step, it, task) for it, task in enumerate(self.tasks)])
# for it, task in enumerate(self.tasks):
# before_step(it, task)

# Step over all tasks
self.env.step(actions)
Expand All @@ -575,6 +581,8 @@ def after_step(it, task):
srs[it] = task._after_env_step(action[it], actions[it], intermediates[it])

wait([self.executor.submit(after_step, it, task) for it, task in enumerate(self.tasks)])
# for it, task in enumerate(self.tasks):
# after_step(it, task)

return srs

Expand Down

0 comments on commit 3d3b9b1

Please sign in to comment.