Skip to content

Commit

Permalink
Merge pull request #116 from CarperAI/no-zero-mask
Browse files: browse the repository at this point in the history
Removed all all-zero masks
  • Loading branch information
jsuarez5341 committed Aug 29, 2023
2 parents b3446e7 + 7977437 commit eca6938
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 11 deletions.
15 changes: 7 additions & 8 deletions nmmo/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,7 @@ def _make_action_targets(self):
# Test below. see tests/core/test_observation_tile.py, test_action_target_consts()
# assert len(action.Style.edges) == 3
masks["Attack"] = {
"Style": np.zeros(3, dtype=np.int8) if self.dummy_obs\
else np.ones(3, dtype=np.int8),
"Style": np.ones(3, dtype=np.int8),
"Target": self._make_attack_mask()
}

Expand All @@ -194,8 +193,7 @@ def _make_action_targets(self):
if self.config.EXCHANGE_SYSTEM_ENABLED:
masks["Sell"] = {
"InventoryItem": self._make_sell_mask(),
"Price": np.zeros(self.config.PRICE_N_OBS, dtype=np.int8) if self.dummy_obs\
else np.ones(self.config.PRICE_N_OBS, dtype=np.int8)
"Price": np.ones(self.config.PRICE_N_OBS, dtype=np.int8)
}
masks["Buy"] = {
"MarketItem": self._make_buy_mask()
Expand All @@ -207,16 +205,16 @@ def _make_action_targets(self):

if self.config.COMMUNICATION_SYSTEM_ENABLED:
masks["Comm"] = {
"Token":\
np.zeros(self.config.COMMUNICATION_NUM_TOKENS, dtype=np.int8) if self.dummy_obs\
else np.ones(self.config.COMMUNICATION_NUM_TOKENS, dtype=np.int8)
"Token":np.ones(self.config.COMMUNICATION_NUM_TOKENS, dtype=np.int8)
}

return masks

def _make_move_mask(self):
if self.dummy_obs:
return np.zeros(len(action.Direction.edges), dtype=np.int8)
mask = np.zeros(len(action.Direction.edges), dtype=np.int8)
mask[-1] = 1 # make sure the noop action is available
return mask
# pylint: disable=not-an-iterable
return np.array([self.tile(*d.delta).material_id in material.Habitable.indices
for d in action.Direction.edges], dtype=np.int8)
Expand Down Expand Up @@ -344,6 +342,7 @@ def _make_give_target_mask(self):

def _make_give_gold_mask(self):
mask = np.zeros(self.config.PRICE_N_OBS, dtype=np.int8)
mask[0] = 1 # To avoid all-0 masks. If the agent has no gold, this action will be ignored.
if self.dummy_obs:
return mask

Expand Down
2 changes: 1 addition & 1 deletion tests/action/test_ammo_use.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def _assert_action_targets_zero(self, gym_obs):
for atn in [action.Use, action.Give, action.Destroy, action.Sell]:
mask += np.sum(gym_obs["ActionTargets"][atn.__name__]["InventoryItem"])
# If MarketItem and InventoryTarget have no-action flags, these sum up to 5
self.assertEqual(mask, 5*int(self.config.PROVIDE_NOOP_ACTION_TARGET))
self.assertEqual(mask, 1 + 5*int(self.config.PROVIDE_NOOP_ACTION_TARGET))

def test_spawn_immunity(self):
env = self._setup_env(random_seed=RANDOM_SEED)
Expand Down
4 changes: 2 additions & 2 deletions tests/core/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def test_observations(self):
self.assertEqual(np.sum(player_obs["Entity"]), 0)
self.assertEqual(np.sum(player_obs["Inventory"]), 0)
self.assertEqual(np.sum(player_obs["Market"]), 0)
self.assertEqual(np.sum(player_obs["ActionTargets"]["Move"]["Direction"]), 0)
self.assertEqual(np.sum(player_obs["ActionTargets"]["Attack"]["Style"]), 0)
self.assertEqual(np.sum(player_obs["ActionTargets"]["Move"]["Direction"]), 1)
self.assertEqual(np.sum(player_obs["ActionTargets"]["Attack"]["Style"]), 3)

obs, rewards, dones, infos = self.env.step({})

Expand Down

0 comments on commit eca6938

Please sign in to comment.