Skip to content

Commit

Permalink
Merge pull request #71 from CarperAI/env-patch
Browse files Browse the repository at this point in the history
added end-of-episode stat, map size check, checks spawn error
  • Loading branch information
kywch committed Jun 8, 2023
2 parents 7386079 + fc737cb commit a670d02
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 6 deletions.
20 changes: 17 additions & 3 deletions nmmo/core/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from nmmo.core.tile import Tile
from nmmo.entity.entity import Entity
from nmmo.systems.item import Item
from nmmo.task.game_state import GameStateGenerator
from nmmo.task import task_api
from nmmo.task.game_state import GameStateGenerator
from scripted.baselines import Scripted

class Env(ParallelEnv):
Expand All @@ -35,11 +35,13 @@ def __init__(self,
self.obs = None

self.possible_agents = list(range(1, config.PLAYER_N + 1))
self._dead_agents = OrderedSet()
self._dead_agents = set()
self._episode_stats = defaultdict(lambda: defaultdict(float))
self.scripted_agents = OrderedSet()

self._gamestate_generator = GameStateGenerator(self.realm, self.config)
self.game_state = None
# Default task: rewards 1 each turn agent is alive
self.tasks = task_api.nmmo_default_task(self.possible_agents)

# pylint: disable=method-cache-max-size-none
Expand Down Expand Up @@ -144,7 +146,8 @@ def reset(self, map_id=None, seed=None, options=None,

self._init_random(seed)
self.realm.reset(map_id)
self._dead_agents = OrderedSet()
self._dead_agents = set()
self._episode_stats.clear()

# check if there are scripted agents
for eid, ent in self.realm.players.items():
Expand Down Expand Up @@ -269,13 +272,24 @@ def step(self, actions: Dict[int, Dict[str, Dict[str, Any]]]):
if eid not in self.realm.players or self.realm.tick >= self.config.HORIZON:
if eid not in self._dead_agents:
self._dead_agents.add(eid)
self._episode_stats[eid]["death_tick"] = self.realm.tick
dones[eid] = True

# Store the observations, since actions reference them
self.obs = self._compute_observations()
gym_obs = {a: o.to_gym() for a,o in self.obs.items()}

rewards, infos = self._compute_rewards(self.obs.keys(), dones)
for k,r in rewards.items():
self._episode_stats[k]['reward'] += r

# When the episode ends (every possible agent is done), add each agent's
# episode stats to that agent's info
if len(self._dead_agents) == len(self.possible_agents):
for agent_id, stats in self._episode_stats.items():
if agent_id not in infos:
infos[agent_id] = {}
infos[agent_id]["episode_stats"] = stats

return gym_obs, rewards, dones, infos

Expand Down
5 changes: 5 additions & 0 deletions nmmo/core/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,16 @@ def reset(self, map_id):
raise

materials = {mat.index: mat for mat in material.All}
r, c = 0, 0
for r, row in enumerate(map_file):
for c, idx in enumerate(row):
mat = materials[idx]
tile = self.tiles[r, c]
tile.reset(mat, config)

assert c == config.MAP_SIZE - 1
assert r == config.MAP_SIZE - 1

self._repr = None

def step(self):
Expand Down
5 changes: 3 additions & 2 deletions nmmo/core/realm.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def reset(self, map_id: int = None):
self.log_helper.reset()
self.event_log.reset()

self.map.reset(map_id or np.random.randint(self.config.MAP_N) + 1)
map_id = map_id or np.random.randint(self.config.MAP_N) + 1
self.map.reset(map_id)
self.tick = 0

# EntityState and ItemState tables must be empty after players/npcs.reset()
self.players.reset()
Expand All @@ -92,7 +94,6 @@ def reset(self, map_id: int = None):

self.players.spawn()
self.npcs.spawn()
self.tick = 0

# Global item exchange
self.exchange = Exchange(self)
Expand Down
2 changes: 1 addition & 1 deletion nmmo/lib/spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def get_team_spawn_positions(config, num_teams):
teams_per_sides = (num_teams + 3) // 4 # 1-4 -> 1, 5-8 -> 2, etc.

sides = get_edge_tiles(config)
assert len(sides[0]) > 4*teams_per_sides, 'Map too small for teams'
assert len(sides[0]) >= 4*teams_per_sides, 'Map too small for teams'

team_spawn_positions = []
for side in sides:
Expand Down

0 comments on commit a670d02

Please sign in to comment.