diff --git a/allenact/base_abstractions/task.py b/allenact/base_abstractions/task.py
index 9c18bb78..eb1e5eed 100644
--- a/allenact/base_abstractions/task.py
+++ b/allenact/base_abstractions/task.py
@@ -406,6 +406,9 @@ def __init__(
         task_sampler: TaskSampler,
         task_class: type(Task),
         callback_sensor_suite: Optional[SensorSuite],
+        parallel_before_step: bool = False,
+        parallel_after_step: bool = True,
+        parallel_get_observations: bool = True,
         **kwargs,
     ) -> None:
         assert hasattr(
@@ -419,6 +422,11 @@ def __init__(
 
         self.env = env
 
+        self.parallel_before_step = parallel_before_step
+        self.parallel_after_step = parallel_after_step
+        self.parallel_get_observations = parallel_get_observations
+        self.any_parallel = parallel_before_step or parallel_after_step or parallel_get_observations
+
         # Instantiate the first actual task from the currently sampled info
         self.tasks = [
             task_class(
@@ -437,15 +445,16 @@ def __init__(
         for it in range(1, self.task_sampler.task_batch_size):
             self.tasks.append(self.make_new_task(it))
 
-        # Also, a ThreadPoolExecutor to collect all data (possibly) under IO bottlenecks
-        self.executor = ThreadPoolExecutor(
-            max_workers=min(10, self.task_sampler.task_batch_size)
-        )
+        if self.any_parallel:
+            # Also, a ThreadPoolExecutor to collect all data (possibly) under IO bottlenecks
+            self.executor = ThreadPoolExecutor(
+                max_workers=min(10, self.task_sampler.task_batch_size)
+            )
 
-        # Also, a mutex to enable underlying task sampler implementations to ensure e.g. only one process
-        # resets the sampler when called from a ThreadPoolExecutor (next_task must be thread safe, possibly
-        # acquiring/releasing the mutex as needed).
-        self.task_sampler.batch_mutex = threading.Lock()
+            # Also, a mutex to enable underlying task sampler implementations to ensure e.g. only one process
+            # resets the sampler when called from a ThreadPoolExecutor (next_task must be thread safe, possibly
+            # acquiring/releasing the mutex as needed).
+            self.task_sampler.batch_mutex = threading.Lock()
 
     def make_new_task(self, batch_index):
         task_batch_size = self.task_sampler.task_batch_size
@@ -472,14 +481,16 @@ def get_observations(self, **kwargs) -> List[Any]:  # -> Dict[str, Any]:
         def obs_extract(it, task):
             res[it] = task.get_observations()
 
-        wait(
-            [
-                self.executor.submit(obs_extract, it, task)
-                for it, task in enumerate(self.tasks)
-            ]
-        )
-        # for it, task in enumerate(self.tasks):
-        #     obs_extract(it, task)
+        if self.parallel_get_observations:
+            wait(
+                [
+                    self.executor.submit(obs_extract, it, task)
+                    for it, task in enumerate(self.tasks)
+                ]
+            )
+        else:
+            for it, task in enumerate(self.tasks):
+                obs_extract(it, task)
 
         return res
 
@@ -508,14 +519,45 @@ def render(self, mode: str = "rgb", *args, **kwargs) -> np.ndarray:
         raise NotImplementedError()
 
     def step(self, action: Any) -> RLStepResult:
-        srs = self._step(action=action)
+        rewards, dones, infos = self._step(action=action)
+        return RLStepResult(
+            observation=self.get_observations(),
+            reward=rewards,  # type:ignore
+            done=dones,  # type:ignore
+            info=infos,  # type:ignore
+        )
+
+    @final
+    def _step(self, action: Any) -> List[RLStepResult]:
+        # Prepare all actions
+        actions = [None] * len(self.tasks)
+        intermediates = [None] * len(self.tasks)
+
+        def before_step(it, task):
+            actions[it], intermediates[it] = task._before_env_step(action[it])
+
+        if self.parallel_before_step:
+            wait(
+                [
+                    self.executor.submit(before_step, it, task)
+                    for it, task in enumerate(self.tasks)
+                ]
+            )
+        else:
+            for it, task in enumerate(self.tasks):
+                before_step(it, task)
+
+        # Step over all tasks
+        self.env.step(actions)
 
+        # Prepare all results (excluding observations)
         rewards = [None] * len(self.tasks)
         dones = [None] * len(self.tasks)
         infos = [None] * len(self.tasks)
 
-        def update_after_step(it, current_task):
-            sr = srs[it]
+        def after_step(it, current_task):
+            sr = current_task._after_env_step(action[it], actions[it], intermediates[it])
 
             info = sr.info or {}
 
@@ -552,60 +594,18 @@ def update_after_step(it, current_task):
             dones[it] = done
             infos[it] = info
 
-        # Ensure completion with wait():
-        wait(
-            [
-                self.executor.submit(update_after_step, it, current_task)
-                for it, current_task in enumerate(self.tasks)
-            ]
-        )
-        # for it, current_task in enumerate(self.tasks):
-        #     update_after_step(it, current_task)
-
-        return RLStepResult(
-            observation=self.get_observations(),
-            reward=rewards,  # type:ignore
-            done=dones,  # type:ignore
-            info=infos,  # type:ignore
-        )
-
-    @final
-    def _step(self, action: Any) -> List[RLStepResult]:
-        # Prepare all actions
-        actions = [None] * len(self.tasks)
-        intermediates = [None] * len(self.tasks)
-
-        def before_step(it, task):
-            actions[it], intermediates[it] = task._before_env_step(action[it])
-
-        wait(
-            [
-                self.executor.submit(before_step, it, task)
-                for it, task in enumerate(self.tasks)
-            ]
-        )
-        # for it, task in enumerate(self.tasks):
-        #     before_step(it, task)
-
-        # Step over all tasks
-        self.env.step(actions)
-
-        # Prepare all results (excluding observations)
-        srs: List[Optional[RLStepResult]] = [None] * len(self.tasks)
-
-        def after_step(it, task):
-            srs[it] = task._after_env_step(action[it], actions[it], intermediates[it])
-
-        wait(
-            [
-                self.executor.submit(after_step, it, task)
-                for it, task in enumerate(self.tasks)
-            ]
-        )
-        # for it, task in enumerate(self.tasks):
-        #     after_step(it, task)
+        if self.parallel_after_step:
+            wait(
+                [
+                    self.executor.submit(after_step, it, task)
+                    for it, task in enumerate(self.tasks)
+                ]
+            )
+        else:
+            for it, task in enumerate(self.tasks):
+                after_step(it, task)
 
-        return srs
+        return rewards, dones, infos
 
     def reached_max_steps(self) -> bool:
         get_logger().warning("Unexpected call to `reached_max_steps` in BatchedTask")
@@ -623,7 +623,8 @@ def num_steps_taken(self) -> int:
         return -1
 
     def close(self) -> None:
-        self.executor.shutdown(cancel_futures=True)
+        if self.any_parallel:
+            self.executor.shutdown(cancel_futures=True)
         self.tasks[0].close()
 
     def metrics(self) -> Dict[str, Any]:
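
Standalone sketch (not part of the patch) of the fan-out pattern that the new parallel_before_step / parallel_after_step / parallel_get_observations flags toggle: per-task work is either submitted to a ThreadPoolExecutor and gathered with wait(), or run in a plain loop when the corresponding flag is off. The fan_out helper and the toy tasks below are hypothetical and only illustrate the control flow; in BatchedTask the per-task work is get_observations, _before_env_step, and _after_env_step.

# sketch_parallel_fanout.py -- illustrative only; `fan_out` is a hypothetical helper,
# not an allenact API. It mirrors the branch added in BatchedTask: executor + wait()
# when a parallel_* flag is True, a plain Python loop otherwise.
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Any, Callable, List, Optional, Sequence


def fan_out(
    tasks: Sequence[Any],
    fn: Callable[[int, Any], Any],
    parallel: bool,
    executor: Optional[ThreadPoolExecutor] = None,
) -> List[Any]:
    # Collect fn(it, task) for every task, preserving order by index.
    results: List[Any] = [None] * len(tasks)

    def work(it: int, task: Any) -> None:
        results[it] = fn(it, task)

    if parallel:
        assert executor is not None
        # Same shape as the patch: submit one job per task and block until all finish.
        wait([executor.submit(work, it, task) for it, task in enumerate(tasks)])
    else:
        # Sequential fallback used when the corresponding parallel_* flag is False.
        for it, task in enumerate(tasks):
            work(it, task)
    return results


if __name__ == "__main__":
    # Toy stand-ins for tasks; real per-task work may block on IO, which is what
    # makes the thread pool worthwhile.
    toy_tasks = list(range(8))
    with ThreadPoolExecutor(max_workers=min(10, len(toy_tasks))) as ex:
        print(fan_out(toy_tasks, lambda it, t: t * t, parallel=True, executor=ex))
    print(fan_out(toy_tasks, lambda it, t: t * t, parallel=False))

Design note, grounded in the diff above: the executor and the task sampler's batch_mutex are only created when at least one of the three flags is set (self.any_parallel), and close() only shuts the executor down in that case, so a fully sequential BatchedTask (all three flags False) avoids thread-pool overhead entirely.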