Skip to content

Commit

Permalink
Next task class for BatchedTask directly given by the TaskSampler
Browse files Browse the repository at this point in the history
  • Loading branch information
jordis-ai2 committed Jul 12, 2024
1 parent f5afde3 commit 3363fee
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 5 deletions.
6 changes: 2 additions & 4 deletions allenact/base_abstractions/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,6 @@ class BatchedTask(Generic[EnvType]):

task_sampler: TaskSampler
tasks: List[Task]
task_classes: List[type(Task)]
callback_sensor_suite: Optional[SensorSuite]

def __init__(
Expand All @@ -404,7 +403,7 @@ def __init__(
task_info: Dict[str, Any],
max_steps: int,
task_sampler: TaskSampler,
task_classes: List[type(Task)],
task_class: type(Task),
callback_sensor_suite: Optional[SensorSuite],
**kwargs,
) -> None:
Expand All @@ -415,14 +414,13 @@ def __init__(
# Keep a reference to the task sampler
self.task_sampler = task_sampler

self.task_classes = task_classes
self.callback_sensor_suite = callback_sensor_suite

self.env = env

# Instantiate the first actual task from the currently sampled info
self.tasks = [
task_classes[0](
task_class(
env=env,
sensors=sensors,
task_info=task_info,
Expand Down
2 changes: 1 addition & 1 deletion tests/make_it_batch/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def next_task(
max_steps=self.max_steps,
action_space=self._action_space,
task_sampler=self,
task_classes=[BatchableObjectNaviThorGridTask],
task_class=BatchableObjectNaviThorGridTask,
callback_sensor_suite=self.callback_sensor_suite,
)
else:
Expand Down

0 comments on commit 3363fee

Please sign in to comment.