Merged after step and task restart, added options for enabling parallelization of sub_steps in AbstractTask
jordis-ai2 committed Jul 12, 2024
1 parent a038b78 commit fd600c2
Showing 1 changed file with 74 additions and 73 deletions.
147 changes: 74 additions & 73 deletions allenact/base_abstractions/task.py
@@ -406,6 +406,9 @@ def __init__(
         task_sampler: TaskSampler,
         task_class: type(Task),
         callback_sensor_suite: Optional[SensorSuite],
+        parallel_before_step: bool = False,
+        parallel_after_step: bool = True,
+        parallel_get_observations: bool = True,
         **kwargs,
     ) -> None:
         assert hasattr(
@@ -419,6 +422,11 @@ def __init__(

         self.env = env

+        self.parallel_before_step = parallel_before_step
+        self.parallel_after_step = parallel_after_step
+        self.parallel_get_observations = parallel_get_observations
+        self.any_parallel = parallel_before_step or parallel_after_step or parallel_get_observations
+
         # Instantiate the first actual task from the currently sampled info
         self.tasks = [
             task_class(
@@ -437,15 +445,16 @@ def __init__(
         for it in range(1, self.task_sampler.task_batch_size):
             self.tasks.append(self.make_new_task(it))

-        # Also, a ThreadPoolExecutor to collect all data (possibly) under IO bottlenecks
-        self.executor = ThreadPoolExecutor(
-            max_workers=min(10, self.task_sampler.task_batch_size)
-        )
+        if self.any_parallel:
+            # Also, a ThreadPoolExecutor to collect all data (possibly) under IO bottlenecks
+            self.executor = ThreadPoolExecutor(
+                max_workers=min(10, self.task_sampler.task_batch_size)
+            )

-        # Also, a mutex to enable underlying task sampler implementations to ensure e.g. only one process
-        # resets the sampler when called from a ThreadPoolExecutor (next_task must be thread safe, possibly
-        # acquiring/releasing the mutex as needed).
-        self.task_sampler.batch_mutex = threading.Lock()
+            # Also, a mutex to enable underlying task sampler implementations to ensure e.g. only one process
+            # resets the sampler when called from a ThreadPoolExecutor (next_task must be thread safe, possibly
+            # acquiring/releasing the mutex as needed).
+            self.task_sampler.batch_mutex = threading.Lock()

     def make_new_task(self, batch_index):
         task_batch_size = self.task_sampler.task_batch_size
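
The mutex comment above puts a thread-safety requirement on next_task. As a minimal sketch (not part of this commit), a task sampler implementation might honor the injected batch_mutex as follows; LockingTaskSampler and its _advance_sampler helper are hypothetical names:

import threading
from typing import Optional

from allenact.base_abstractions.task import Task
from allenact.base_abstractions.task_sampler import TaskSampler


class LockingTaskSampler(TaskSampler):  # hypothetical subclass for illustration
    def __init__(self, **kwargs) -> None:
        super().__init__()
        # BatchedTask replaces this attribute with a shared lock when any
        # parallel option is enabled; a private default keeps non-batched
        # use working unchanged.
        self.batch_mutex = threading.Lock()

    def next_task(self, force_advance_scene: bool = False) -> Optional[Task]:
        # Serialize mutation of shared sampler state (e.g. resetting the
        # sampler) so only one executor thread advances it at a time.
        with self.batch_mutex:
            return self._advance_sampler(force_advance_scene)  # hypothetical helper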
@@ -472,14 +481,16 @@ def get_observations(self, **kwargs) -> List[Any]: # -> Dict[str, Any]:
         def obs_extract(it, task):
             res[it] = task.get_observations()

-        wait(
-            [
-                self.executor.submit(obs_extract, it, task)
-                for it, task in enumerate(self.tasks)
-            ]
-        )
-        # for it, task in enumerate(self.tasks):
-        #     obs_extract(it, task)
+        if self.parallel_get_observations:
+            wait(
+                [
+                    self.executor.submit(obs_extract, it, task)
+                    for it, task in enumerate(self.tasks)
+                ]
+            )
+        else:
+            for it, task in enumerate(self.tasks):
+                obs_extract(it, task)

         return res

@@ -508,14 +519,45 @@ def render(self, mode: str = "rgb", *args, **kwargs) -> np.ndarray:
         raise NotImplementedError()

     def step(self, action: Any) -> RLStepResult:
-        srs = self._step(action=action)
+        rewards, dones, infos = self._step(action=action)

+        return RLStepResult(
+            observation=self.get_observations(),
+            reward=rewards,  # type:ignore
+            done=dones,  # type:ignore
+            info=infos,  # type:ignore
+        )
+
+    @final
+    def _step(self, action: Any) -> List[RLStepResult]:
+        # Prepare all actions
+        actions = [None] * len(self.tasks)
+        intermediates = [None] * len(self.tasks)
+
+        def before_step(it, task):
+            actions[it], intermediates[it] = task._before_env_step(action[it])
+
+        if self.parallel_before_step:
+            wait(
+                [
+                    self.executor.submit(before_step, it, task)
+                    for it, task in enumerate(self.tasks)
+                ]
+            )
+        else:
+            for it, task in enumerate(self.tasks):
+                before_step(it, task)
+
+        # Step over all tasks
+        self.env.step(actions)
+
+        # Prepare all results (excluding observations)
         rewards = [None] * len(self.tasks)
         dones = [None] * len(self.tasks)
         infos = [None] * len(self.tasks)

-        def update_after_step(it, current_task):
-            sr = srs[it]
+        def after_step(it, current_task):
+            sr = current_task._after_env_step(action[it], actions[it], intermediates[it])

             info = sr.info or {}

@@ -552,60 +594,18 @@ def update_after_step(it, current_task):
             dones[it] = done
             infos[it] = info

-        # Ensure completion with wait():
-        wait(
-            [
-                self.executor.submit(update_after_step, it, current_task)
-                for it, current_task in enumerate(self.tasks)
-            ]
-        )
-        # for it, current_task in enumerate(self.tasks):
-        #     update_after_step(it, current_task)
-
-        return RLStepResult(
-            observation=self.get_observations(),
-            reward=rewards,  # type:ignore
-            done=dones,  # type:ignore
-            info=infos,  # type:ignore
-        )
-
-    @final
-    def _step(self, action: Any) -> List[RLStepResult]:
-        # Prepare all actions
-        actions = [None] * len(self.tasks)
-        intermediates = [None] * len(self.tasks)
-
-        def before_step(it, task):
-            actions[it], intermediates[it] = task._before_env_step(action[it])
-
-        wait(
-            [
-                self.executor.submit(before_step, it, task)
-                for it, task in enumerate(self.tasks)
-            ]
-        )
-        # for it, task in enumerate(self.tasks):
-        #     before_step(it, task)
-
-        # Step over all tasks
-        self.env.step(actions)
-
-        # Prepare all results (excluding observations)
-        srs: List[Optional[RLStepResult]] = [None] * len(self.tasks)
-
-        def after_step(it, task):
-            srs[it] = task._after_env_step(action[it], actions[it], intermediates[it])
-
-        wait(
-            [
-                self.executor.submit(after_step, it, task)
-                for it, task in enumerate(self.tasks)
-            ]
-        )
-        # for it, task in enumerate(self.tasks):
-        #     after_step(it, task)
+        if self.parallel_after_step:
+            wait(
+                [
+                    self.executor.submit(after_step, it, task)
+                    for it, task in enumerate(self.tasks)
+                ]
+            )
+        else:
+            for it, task in enumerate(self.tasks):
+                after_step(it, task)

-        return srs
+        return rewards, dones, infos

     def reached_max_steps(self) -> bool:
         get_logger().warning("Unexpected call to `reached_max_steps` in BatchedTask")
@@ -623,7 +623,8 @@ def num_steps_taken(self) -> int:
         return -1

     def close(self) -> None:
-        self.executor.shutdown(cancel_futures=True)
+        if self.any_parallel:
+            self.executor.shutdown(cancel_futures=True)
         self.tasks[0].close()

     def metrics(self) -> Dict[str, Any]:
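Taken together, the new flags let callers trade thread-level parallelism for fully serial execution. A minimal usage sketch, assuming BatchedTask is importable from allenact.base_abstractions.task and that the constructor arguments not shown in this diff (env, sensors, task_info, max_steps) follow the usual Task signature:

from allenact.base_abstractions.task import BatchedTask


def make_serial_batched_task(env, sensors, task_info, max_steps, task_sampler, task_class):
    # All-serial configuration: with all three flags False, any_parallel is
    # False, so no ThreadPoolExecutor or batch_mutex is created and close()
    # skips executor.shutdown(). Handy for single-threaded debugging.
    return BatchedTask(
        env=env,
        sensors=sensors,  # assumed pre-existing argument
        task_info=task_info,  # assumed pre-existing argument
        max_steps=max_steps,  # assumed pre-existing argument
        task_sampler=task_sampler,
        task_class=task_class,
        callback_sensor_suite=None,
        parallel_before_step=False,  # already the default
        parallel_after_step=False,  # default is True
        parallel_get_observations=False,  # default is True
    )

Note the defaults leave _after_env_step and get_observations parallelized (where IO bottlenecks are most likely) while _before_env_step stays serial unless explicitly enabled.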
