77 changes: 41 additions & 36 deletions gym-unity/gym_unity/envs/__init__.py
@@ -75,7 +75,8 @@ def __init__(
         self.visual_obs = None
         self._n_agents = -1
         self._done_agents: Set[int] = set()
-        self._stored_info: BatchedStepResult = None
+        # Save the step result from the last time all Agents requested decisions.
+        self._previous_step_result: BatchedStepResult = None
         self._multiagent = multiagent
         self._flattener = None
         # Hidden flag used by Atari environments to determine if the game is over
@@ -119,7 +120,7 @@ def __init__(
         self._env.reset()
         step_result = self._env.get_step_result(self.brain_name)
         self._check_agents(step_result.n_agents())
-        self._stored_info = step_result
+        self._previous_step_result = step_result
 
         # Set observation and action spaces
         if self.group_spec.is_action_discrete():
@@ -162,15 +163,15 @@ def reset(self) -> Union[List[np.ndarray], np.ndarray]:
         Returns: observation (object/list): the initial observation of the
             space.
         """
-        info = self._step(True)
-        n_agents = info.n_agents()
+        step_result = self._step(True)
+        n_agents = step_result.n_agents()
         self._check_agents(n_agents)
         self.game_over = False
 
         if not self._multiagent:
-            res: GymStepResult = self._single_step(info)
+            res: GymStepResult = self._single_step(step_result)
         else:
-            res = self._multi_step(info)
+            res = self._multi_step(step_result)
         return res[0]
 
     def step(self, action: List[Any]) -> GymStepResult:
@@ -215,17 +216,17 @@ def step(self, action: List[Any]) -> GymStepResult:
         action = self._sanitize_action(action)
         self._env.set_actions(self.brain_name, action)
 
-        info = self._step()
+        step_result = self._step()
 
-        n_agents = info.n_agents()
+        n_agents = step_result.n_agents()
         self._check_agents(n_agents)
 
         if not self._multiagent:
-            single_res = self._single_step(info)
+            single_res = self._single_step(step_result)
             self.game_over = single_res[2]
             return single_res
         else:
-            multi_res = self._multi_step(info)
+            multi_res = self._multi_step(step_result)
             self.game_over = all(multi_res[2])
             return multi_res
 
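For orientation, a minimal usage sketch of the wrapper around reset() and step() (not part of the diff; the build path is hypothetical, and number_agents is assumed to expose the wrapper's agent count):

    from gym_unity.envs import UnityEnv

    # Hypothetical Unity build path; multiagent=True matches the flag stored in __init__.
    env = UnityEnv("path/to/UnityBuild", worker_id=0, multiagent=True)
    obs = env.reset()
    actions = [env.action_space.sample() for _ in range(env.number_agents)]
    obs, rewards, dones, info = env.step(actions)  # game_over becomes all(dones)
    env.close()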
@@ -358,8 +359,8 @@ def _check_agents(self, n_agents: int) -> None:
"initialization. This is not supported."
)

def _sanitize_info(self, info: BatchedStepResult) -> BatchedStepResult:
n_extra_agents = info.n_agents() - self._n_agents
def _sanitize_info(self, step_result: BatchedStepResult) -> BatchedStepResult:
n_extra_agents = step_result.n_agents() - self._n_agents
if n_extra_agents < 0 or n_extra_agents > self._n_agents:
# In this case, some Agents did not request a decision when expected
# or too many requested a decision
@@ -369,46 +370,48 @@ def _sanitize_info(self, info: BatchedStepResult) -> BatchedStepResult:

         # remove the done Agents
         indices_to_keep: List[int] = []
-        for index in range(len(info.agent_id)):
-            if not info.done[index]:
+        for index, is_done in enumerate(step_result.done):
+            if not is_done:
                 indices_to_keep.append(index)
 
-        # set the new AgentDone flags to True
-        for index in range(len(info.agent_id)):
-            agent_id = info.agent_id[index]
-            if not self._stored_info.contains_agent(agent_id):
-                info.done[index] = True
+        # Set the new AgentDone flags to True
+        # Note that the corresponding agent_id that gets marked done will be different
+        # than the original agent that was done, but this is OK since the gym interface
+        # only cares about the ordering.
+        for index, agent_id in enumerate(step_result.agent_id):
+            if not self._previous_step_result.contains_agent(agent_id):
+                step_result.done[index] = True
             if agent_id in self._done_agents:
-                info.done[index] = True
+                step_result.done[index] = True
         self._done_agents = set()
-        self._stored_info = info  # store the new original
+        self._previous_step_result = step_result  # store the new original
 
         _mask: Optional[List[np.array]] = None
-        if info.action_mask is not None:
+        if step_result.action_mask is not None:
             _mask = []
-            for mask_index in range(len(info.action_mask)):
-                _mask.append(info.action_mask[mask_index][indices_to_keep])
+            for mask_index in range(len(step_result.action_mask)):
+                _mask.append(step_result.action_mask[mask_index][indices_to_keep])
         new_obs: List[np.array] = []
-        for obs_index in range(len(info.obs)):
-            new_obs.append(info.obs[obs_index][indices_to_keep])
+        for obs_index in range(len(step_result.obs)):
+            new_obs.append(step_result.obs[obs_index][indices_to_keep])
         return BatchedStepResult(
             obs=new_obs,
-            reward=info.reward[indices_to_keep],
-            done=info.done[indices_to_keep],
-            max_step=info.max_step[indices_to_keep],
-            agent_id=info.agent_id[indices_to_keep],
+            reward=step_result.reward[indices_to_keep],
+            done=step_result.done[indices_to_keep],
+            max_step=step_result.max_step[indices_to_keep],
+            agent_id=step_result.agent_id[indices_to_keep],
             action_mask=_mask,
         )

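As a standalone illustration of the filtering above, here is the same index-keeping idea in plain numpy (illustrative values only, no ML-Agents types):

    import numpy as np

    done = np.array([False, True, False, True])  # done flags from a step result
    reward = np.array([1.0, 0.5, 2.0, 0.0])
    indices_to_keep = [i for i, d in enumerate(done) if not d]  # [0, 2]
    reward[indices_to_keep]  # array([1., 2.]) -- entries for done agents dropped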
     def _sanitize_action(self, action: np.array) -> np.array:
-        if self._stored_info.n_agents() == self._n_agents:
+        if self._previous_step_result.n_agents() == self._n_agents:
             return action
         sanitized_action = np.zeros(
-            (self._stored_info.n_agents(), self.group_spec.action_size)
+            (self._previous_step_result.n_agents(), self.group_spec.action_size)
         )
         input_index = 0
-        for index in range(self._stored_info.n_agents()):
-            if not self._stored_info.done[index]:
+        for index in range(self._previous_step_result.n_agents()):
+            if not self._previous_step_result.done[index]:
                 sanitized_action[index, :] = action[input_index, :]
                 input_index = input_index + 1
         return sanitized_action
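The zero-padding in _sanitize_action can be pictured with a small numpy sketch (illustrative shapes; the boolean-mask assignment is an equivalent way to write the loop above):

    import numpy as np

    previous_done = np.array([False, True, False])  # one of three agents was done
    incoming = np.array([[0.1, 0.2], [0.3, 0.4]])  # actions for the two live agents
    sanitized = np.zeros((3, 2))  # (n_agents from the previous step, action_size)
    sanitized[~previous_done] = incoming  # the done agent keeps a zero action
    # sanitized -> [[0.1, 0.2], [0.0, 0.0], [0.3, 0.4]]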
@@ -419,8 +422,10 @@ def _step(self, needs_reset: bool = False) -> BatchedStepResult:
         else:
             self._env.step()
         info = self._env.get_step_result(self.brain_name)
-        # In case some Agents raised a Done flag between steps, we re-request
-        # decisions until all agents request a real decision.
+        # Two possible cases here:
+        # 1) all agents requested decisions (some of which might be done)
+        # 2) some Agents were marked Done in between steps.
+        # In case 2, we re-request decisions until all agents request a real decision.
         while info.n_agents() - sum(info.done) < self._n_agents:
             if not info.done.all():
                 raise UnityGymException(
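A small numeric check of the while condition above (illustrative values): with four expected agents, a result carrying three requests of which one is done yields only two real decisions, so the environment must be stepped again:

    import numpy as np

    n_expected = 4  # self._n_agents
    done = np.array([False, False, True])  # 3 requests arrived, 1 of them done
    n_real_decisions = len(done) - int(done.sum())  # info.n_agents() - sum(info.done)
    n_real_decisions < n_expected  # True -> re-request decisions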