Fixed task num_steps_taken in BatchedTask, fixed metrics for batched tasks, default task_batch_size 0 for no batching (1 for batching with 1 task per sampler)
jordis-ai2 committed Jul 8, 2024
1 parent e204377 commit 992629e
Showing 6 changed files with 96 additions and 45 deletions.
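
The convention adopted here is that task_batch_size == 0 disables batching entirely (each sampler owns one ordinary task), while any value >= 1 routes tasks through BatchedTask with that many tasks per sampler. A minimal sketch of that convention; tasks_per_sampler is a hypothetical helper for illustration only, not part of the repository:

# Sketch of the task_batch_size convention introduced by this commit.
# tasks_per_sampler is a hypothetical helper, not an allenact function.
def tasks_per_sampler(task_batch_size: int) -> int:
    # 0 means no batching (one plain Task per sampler); N >= 1 means a BatchedTask of N tasks.
    return task_batch_size if task_batch_size > 0 else 1

assert tasks_per_sampler(0) == 1  # unbatched path
assert tasks_per_sampler(1) == 1  # batched path with a single task per sampler
assert tasks_per_sampler(4) == 4  # four tasks stepped together per sampler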
41 changes: 20 additions & 21 deletions allenact/algorithms/onpolicy_sync/engine.py
@@ -122,7 +122,7 @@ def __init__(
max_sampler_processes_per_worker: Optional[int] = None,
initial_model_state_dict: Optional[Union[Dict[str, Any], int]] = None,
try_restart_after_task_error: bool = False,
task_batch_size: int = 1,
task_batch_size: int = 0,
**kwargs,
):
"""Initializer.
@@ -485,8 +485,7 @@ def initialize_storage_and_viz(
):
# No rollout storage, thus we are not
observations = self.vector_tasks.get_observations()
if self.task_batch_size > 1:
# observations = sum([obs["batch_observations"] for obs in observations], [])
if self.task_batch_size > 0:
observations = sum(observations, [])

npaused, keep, batch = self.remove_paused(observations)
@@ -526,7 +525,7 @@ def initialize_storage_and_viz(
def num_active_samplers(self):
if self.vector_tasks is None:
return 0
return self.vector_tasks.num_unpaused_tasks * self.task_batch_size
return self.vector_tasks.num_unpaused_tasks * max(self.task_batch_size, 1)

def act(
self,
@@ -673,7 +672,7 @@ def collect_step_across_all_task_samplers(

# Convert flattened actions into list of actions and send them
action_space = self.actor_critic.action_space
if self.task_batch_size > 1:
if self.task_batch_size > 0:
action_space = gym.spaces.Tuple((action_space,) * self.task_batch_size)
new_shape = tuple(flat_actions.shape)[:-2] + (flat_actions.shape[-2] // self.task_batch_size, flat_actions.shape[-1] * self.task_batch_size)
flat_actions = flat_actions.view(new_shape)
@@ -682,32 +681,32 @@
su.action_list(action_space, flat_actions)
)

# Save after task completion metrics
for step_result in outputs:
if step_result.info is not None:
if COMPLETE_TASK_METRICS_KEY in step_result.info:
self.single_process_metrics.append(
step_result.info[COMPLETE_TASK_METRICS_KEY]
)
del step_result.info[COMPLETE_TASK_METRICS_KEY]
if COMPLETE_TASK_CALLBACK_KEY in step_result.info:
self.single_process_task_callback_data.append(
step_result.info[COMPLETE_TASK_CALLBACK_KEY]
)
del step_result.info[COMPLETE_TASK_CALLBACK_KEY]

rewards: Union[List, torch.Tensor]
observations, rewards, dones, infos = [list(x) for x in zip(*outputs)]

if self.task_batch_size > 1:
if self.task_batch_size > 0:
# Each observation, reward, done, info is actually a list of task_batch_size units
observations = sum(observations, [])
rewards = sum(rewards, [])
dones = sum(dones, [])
# infos = sum(infos, []) # unused
infos = sum(infos, []) # unused
new_shape = tuple(flat_actions.shape)[:-2] + (flat_actions.shape[-2] * self.task_batch_size, flat_actions.shape[-1] // self.task_batch_size)
flat_actions = flat_actions.view(new_shape)

# Save after task completion metrics
for info in infos:
if info is not None:
if COMPLETE_TASK_METRICS_KEY in info:
self.single_process_metrics.append(
info[COMPLETE_TASK_METRICS_KEY]
)
del info[COMPLETE_TASK_METRICS_KEY]
if COMPLETE_TASK_CALLBACK_KEY in info:
self.single_process_task_callback_data.append(
info[COMPLETE_TASK_CALLBACK_KEY]
)
del info[COMPLETE_TASK_CALLBACK_KEY]

rewards = torch.tensor(
rewards, dtype=torch.float, device=self.device, # type:ignore
)
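For batched samplers the engine above now flattens per-sampler result lists with plain list concatenation and regroups actions per sampler before stepping. A small, self-contained sketch of both operations, using made-up shapes rather than anything taken from the diff:

import torch

# Each of 2 samplers returns a list of task_batch_size = 2 results;
# sum(nested, []) concatenates them into one flat list.
per_sampler_rewards = [[0.1, 0.2], [0.3, 0.4]]
flat_rewards = sum(per_sampler_rewards, [])
assert flat_rewards == [0.1, 0.2, 0.3, 0.4]

# Actions arrive with one row per flat task slot (samplers * batch) and are
# viewed as one row per sampler with the batch folded into the last dimension.
task_batch_size = 2
flat_actions = torch.arange(4).view(1, 4, 1)  # (steps, samplers * batch, action_dim)
grouped = flat_actions.view(1, 4 // task_batch_size, 1 * task_batch_size)
assert grouped.shape == (1, 2, 2)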
2 changes: 1 addition & 1 deletion allenact/algorithms/onpolicy_sync/runner.py
@@ -499,7 +499,7 @@ def start_train(
collect_valid_results: bool = False,
valid_on_initial_weights: bool = False,
try_restart_after_task_error: bool = False,
task_batch_size: int = 1,
task_batch_size: int = 0,
):
self._initialize_start_train_or_start_test()

12 changes: 6 additions & 6 deletions allenact/algorithms/onpolicy_sync/vector_sampled_tasks.py
@@ -159,7 +159,7 @@ def __init__(
read_timeout: Optional[
float
] = 60, # Seconds to wait for a task to return a response before timing out
task_batch_size: int = 1,
task_batch_size: int = 0,
) -> None:

self._is_waiting = False
@@ -337,7 +337,7 @@ def _task_sampling_loop_worker(
should_log: bool,
child_pipe: Optional[Connection] = None,
parent_pipe: Optional[Connection] = None,
task_batch_size: int = 1,
task_batch_size: int = 0,
) -> None:
"""process worker for creating and interacting with the
Tasks/TaskSampler."""
@@ -884,7 +884,7 @@ def __init__(
callback_sensor_suite: Optional[SensorSuite] = None,
auto_resample_when_done: bool = True,
should_log: bool = True,
task_batch_size: int = 1,
task_batch_size: int = 0,
) -> None:

self._is_closed = True
@@ -960,12 +960,12 @@ def _task_sampling_loop_generator_fn(
callback_sensor_suite: Optional[SensorSuite],
auto_resample_when_done: bool,
should_log: bool,
task_batch_size: int = 1,
task_batch_size: int = 0,
) -> Generator:
"""Generator for working with Tasks/TaskSampler."""

task_sampler_args = {**sampler_fn_args}
if task_batch_size > 1:
if task_batch_size > 0:
task_sampler_args["task_batch_size"] = task_batch_size
task_sampler_args["callback_sensor_suite"] = callback_sensor_suite
assert auto_resample_when_done, "auto resample should be the expected usage with batched tasks"
@@ -1106,7 +1106,7 @@ def _create_generators(
make_sampler_fn: Callable[..., TaskSampler],
sampler_fn_args: Sequence[Dict[str, Any]],
callback_sensor_suite: Optional[SensorSuite],
task_batch_size: int = 1,
task_batch_size: int = 0,
) -> List[Generator]:

generators = []
18 changes: 18 additions & 0 deletions allenact/base_abstractions/task.py
@@ -431,6 +431,7 @@ def get_observations(self, **kwargs) -> List[Any]: #-> Dict[str, Any]:
self.tasks[0].env.render() # assume this is stored locally in the env class

# return {"batch_observations": [task.get_observations() for task in self.tasks]}
# TODO this could be executed in a thread pool for very large batch sizes
return [task.get_observations() for task in self.tasks]

@property
@@ -461,11 +462,26 @@ def render(self, mode: str = "rgb", *args, **kwargs) -> np.ndarray:
def step(self, action: Any) -> RLStepResult:
srs = self._step(action=action)

# TODO this could be executed in a thread pool for very large batch sizes
for it, current_task in enumerate(self.tasks):
if srs[it].info is None:
srs[it] = srs[it].clone({"info": {}})

# If reward is Sequence, it's assumed to follow the same order imposed by spaces' flatten operation
if isinstance(srs[it].reward, Sequence):
if isinstance(current_task._total_reward, Sequence):
for it, rew in enumerate(srs[it].reward):
current_task._total_reward[it] += float(rew)
else:
current_task._total_reward = [float(r) for r in srs[it].reward]
else:
current_task._total_reward += float(srs[it].reward) # type:ignore

current_task._increment_num_steps_taken()

if current_task.is_done():
srs[it] = srs[it].clone({"done": True})

metrics = current_task.metrics()
if metrics is not None and len(metrics) != 0:
srs[it].info[COMPLETE_TASK_METRICS_KEY] = metrics
@@ -492,6 +508,7 @@ def _step(self, action: Any) -> List[RLStepResult]:
# Prepare all actions
actions = []
intermediates = []
# TODO this could be executed in a thread pool for very large batch sizes
for it, task in enumerate(self.tasks):
action_str, intermediate = task._before_env_step(action[it])
actions.append(action_str)
@@ -502,6 +519,7 @@

# Prepare all results (excluding observations)
srs = []
# TODO this could be executed in a thread pool for very large batch sizes
for it, task in enumerate(self.tasks):
sr = task._after_env_step(action[it], actions[it], intermediates[it])
srs.append(sr)
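The repeated TODO about a thread pool refers to fanning the per-task loops in BatchedTask out across worker threads. One way that could look, purely as a sketch (the commit keeps these loops sequential, and step_one below glosses over the actual split between _before_env_step and _after_env_step):

from concurrent.futures import ThreadPoolExecutor

def step_all(tasks, actions, max_workers=8):
    # Step every task in the batch concurrently; tasks and actions are
    # stand-ins for BatchedTask.tasks and the per-task action slices.
    def step_one(indexed):
        it, task = indexed
        return task.step(actions[it])
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(step_one, enumerate(tasks)))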
2 changes: 1 addition & 1 deletion allenact/main.py
@@ -290,7 +290,7 @@ def get_argument_parser():
"--task_batch_size",
dest="task_batch_size",
type=int,
default=1,
default=0,
help="Makes task_batch_size training tasks be processed as a batch for each instantiated env",
)

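With the new default, the command-line flag leaves batching off unless it is passed explicitly; a minimal argparse check of that behaviour (a standalone sketch mirroring only the one argument shown above):

from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--task_batch_size", dest="task_batch_size", type=int, default=0)

assert parser.parse_args([]).task_batch_size == 0  # batching off by default
assert parser.parse_args(["--task_batch_size", "2"]).task_batch_size == 2  # opt in to batching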
66 changes: 50 additions & 16 deletions tests/make_it_batch/experiment.py
@@ -39,11 +39,11 @@ def __init__(
**kwargs,
):
self.task_batch_size = task_batch_size
self.controllers = [IThorEnvironment(**kwargs) for _ in range(task_batch_size)]
self.controllers = [IThorEnvironment(**kwargs) for _ in range(max(1, task_batch_size))]
self._frames = []

def step(self, actions: List[str]):
assert len(actions) == self.task_batch_size
assert len(actions) == self.task_batch_size or len(actions) == self.task_batch_size + 1
for controller, action in zip(self.controllers, actions):
controller.step(action=action if action != "End" else "Pass")
self._frames = []
@@ -83,8 +83,30 @@ def render(self):


class BatchableObjectNaviThorGridTask(ObjectNaviThorGridTask):
def _step(self, action):
raise NotImplementedError()
# # TODO BEGIN For compatibility with batch_task_size = 0
#
# batch_index = 0
#
# def get_observations(self, **kwargs) -> List[Any]: #-> Dict[str, Any]:
# # Render all tasks in batch
# self._frames = []
# self.env.render()
# obs = super().get_observations()
# self.env._frames = []
# return obs
#
# def _step(self, action):
# # raise NotImplementedError()
# action_str, interm = self._before_env_step(action)
# self.env.step([action_str])
# self._after_env_step(action, action_str, interm)
# return RLStepResult(
# observation=self.get_observations(),
# reward=self.judge(),
# done=self.is_done(),
# info={"last_action_success": self.last_action_success},
# )
# # TODO END For compatibility with batch_task_size = 0

def is_goal_object_visible(self) -> bool:
"""Is the goal object currently visible?"""
@@ -126,8 +148,11 @@ def __init__(self, **kwargs):
if "task_batch_size" in kwargs:
self.task_batch_size = kwargs["task_batch_size"]
self.callback_sensor_suite = kwargs["callback_sensor_suite"]
kwargs.pop("task_batch_size")
kwargs.pop("callback_sensor_suite")
kwargs.pop("task_batch_size")
kwargs.pop("callback_sensor_suite")
else:
self.task_batch_size = 0
self.callback_sensor_suite = None
super().__init__(**kwargs)

def _create_environment(self):
@@ -175,16 +200,25 @@ def next_task(
"id"
] = f"{scene}__{'_'.join(list(map(str, self.env.controllers[idx].get_key(pose))))}__{task_info['object_type']}"

self._last_sampled_task = BatchedTask(
env=self.env,
sensors=self.sensors,
task_info=task_info,
max_steps=self.max_steps,
action_space=self._action_space,
task_sampler=self,
task_classes=[BatchableObjectNaviThorGridTask],
callback_sensor_suite=self.callback_sensor_suite,
)
if self.task_batch_size > 0:
self._last_sampled_task = BatchedTask(
env=self.env,
sensors=self.sensors,
task_info=task_info,
max_steps=self.max_steps,
action_space=self._action_space,
task_sampler=self,
task_classes=[BatchableObjectNaviThorGridTask],
callback_sensor_suite=self.callback_sensor_suite,
)
else:
self._last_sampled_task = BatchableObjectNaviThorGridTask(
env=self.env,
sensors=self.sensors,
task_info=task_info,
max_steps=self.max_steps,
action_space=self._action_space,
)
return self._last_sampled_task


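The new kwargs handling in the experiment's task sampler (popping task_batch_size and callback_sensor_suite when present, otherwise falling back to 0 and None) could equivalently be written with pop defaults; the snippet below only illustrates that pattern and is not the committed code:

def split_batching_kwargs(kwargs):
    # Remove the batching-related keys, supplying the unbatched defaults when absent.
    task_batch_size = kwargs.pop("task_batch_size", 0)
    callback_sensor_suite = kwargs.pop("callback_sensor_suite", None)
    return task_batch_size, callback_sensor_suite

sampler_kwargs = {"scenes": ["FloorPlan1"], "task_batch_size": 2, "callback_sensor_suite": None}
assert split_batching_kwargs(sampler_kwargs) == (2, None)
assert "task_batch_size" not in sampler_kwargs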
