Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refining task api, etc #63

Merged
merged 22 commits into from
Jun 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 32 additions & 80 deletions nmmo/core/env.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import functools
import random
import copy
from typing import Any, Dict, List, Optional, Union, Tuple
from typing import Any, Dict, List, Callable
from collections import defaultdict
from ordered_set import OrderedSet

import gym
Expand All @@ -16,8 +16,7 @@
from nmmo.entity.entity import Entity
from nmmo.systems.item import Item
from nmmo.task.game_state import GameStateGenerator
from nmmo.task.task_api import Task
from nmmo.task.scenario import default_task
from nmmo.task import task_api
from scripted.baselines import Scripted

class Env(ParallelEnv):
Expand All @@ -41,15 +40,7 @@ def __init__(self,

self._gamestate_generator = GameStateGenerator(self.realm, self.config)
self.game_state = None
# Default task: rewards 1 each turn agent is alive
self.tasks: List[Tuple[Task,float]] = None
self._task_encoding = None
self._task_embedding_size = -1
t = default_task(self.possible_agents)
self.change_task(t,
embedding_size=self._task_embedding_size,
task_encoding=self._task_encoding,
reset=False)
self.tasks = task_api.nmmo_default_task(self.possible_agents)

# pylint: disable=method-cache-max-size-none
@functools.lru_cache(maxsize=None)
Expand Down Expand Up @@ -88,12 +79,6 @@ def box(rows, cols):
if self.config.PROVIDE_ACTION_TARGETS:
obs_space['ActionTargets'] = self.action_space(None)

if self._task_encoding:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't we need this somewhere?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need it, but not necessarily in this format. Let me check back after seeing the syllabus integration.

obs_space['Task'] = gym.spaces.Box(
low=-2**20, high=2**20,
shape=(self._task_embedding_size,),
dtype=np.float32)

return gym.spaces.Dict(obs_space)

def _init_random(self, seed):
Expand Down Expand Up @@ -131,38 +116,18 @@ def action_space(self, agent):
############################################################################
# Core API

def change_task(self,
new_tasks: List[Union[Tuple[Task, float], Task]],
task_encoding: Optional[Dict[int, np.ndarray]] = None,
embedding_size: int=16,
reset: bool=True,
map_id=None,
seed=None,
options=None):
""" Changes the task given to each agent

Args:
new_task: The task to complete and calculate rewards
task_encoding: A mapping from eid to encoded task
embedding_size: The size of each embedding
reset: Resets the environment
"""
self._tasks = [t if isinstance(t, Tuple) else (t,1) for t in new_tasks]
self._task_encoding = task_encoding
self._task_embedding_size = embedding_size
if reset:
self.reset(map_id=map_id, seed=seed, options=options)

# TODO: This doesn't conform to the PettingZoo API
# pylint: disable=arguments-renamed
def reset(self, map_id=None, seed=None, options=None):
def reset(self, map_id=None, seed=None, options=None,
make_task_fn: Callable=None):
'''OpenAI Gym API reset function

Loads a new game map and returns initial observations

Args:
idx: Map index to load. Selects a random map by default

map_id: Map index to load. Selects a random map by default
seed: random seed to use
make_task_fn: A function to make tasks

Returns:
observations, as documented by _compute_observations()
Expand All @@ -186,16 +151,16 @@ def reset(self, map_id=None, seed=None, options=None):
if isinstance(ent.agent, Scripted):
self.scripted_agents.add(eid)

self.tasks = copy.deepcopy(self._tasks)
self.obs = self._compute_observations()
self._gamestate_generator = GameStateGenerator(self.realm, self.config)

gym_obs = {}
for a, o in self.obs.items():
gym_obs[a] = o.to_gym()
if self._task_encoding:
gym_obs[a]['Task'] = self._encode_goal().get(a,np.zeros(self._task_embedding_size))
return gym_obs
if make_task_fn is not None:
self.tasks = make_task_fn()
else:
for task in self.tasks:
task.reset()

return {a: o.to_gym() for a,o in self.obs.items()}

def step(self, actions: Dict[int, Dict[str, Dict[str, Any]]]):
'''Simulates one game tick or timestep
Expand Down Expand Up @@ -308,11 +273,7 @@ def step(self, actions: Dict[int, Dict[str, Dict[str, Any]]]):

# Store the observations, since actions reference them
self.obs = self._compute_observations()
gym_obs = {}
for a, o in self.obs.items():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't we still need this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously, this was gym_obs = {a: o.to_gym() for a,o in self.obs.items()}, but it was changed to the explicit loop in order to add:

      if self._task_encoding:
        gym_obs[a]['Task'] = self._encode_goal().get(a,np.zeros(self._task_embedding_size))

Currently, I'm not sure about the exact form of the task encoding. I'm hoping to get some input/specs, as we are starting task-conditioned learning very soon.

gym_obs[a] = o.to_gym()
if self._task_encoding:
gym_obs[a]['Task'] = self._encode_goal()[a]
gym_obs = {a: o.to_gym() for a,o in self.obs.items()}

rewards, infos = self._compute_rewards(self.obs.keys(), dones)

Expand All @@ -321,8 +282,6 @@ def step(self, actions: Dict[int, Dict[str, Dict[str, Any]]]):
def _validate_actions(self, actions: Dict[int, Dict[str, Dict[str, Any]]]):
'''Deserialize action arg values and validate actions
For now, it does a basic validation (e.g., value is not none).

TODO(kywch): add sophisticated validation like use/sell/give on the same item
'''
validated_actions = {}

Expand Down Expand Up @@ -423,9 +382,6 @@ def _compute_observations(self):
inventory, market)
return obs

def _encode_goal(self):
return self._task_encoding

def _compute_rewards(self, agents: List[AgentID], dones: Dict[AgentID, bool]):
'''Computes the reward for the specified agent

Expand All @@ -442,27 +398,23 @@ def _compute_rewards(self, agents: List[AgentID], dones: Dict[AgentID, bool]):
entity identified by ent_id.
'''
# Initialization
self.game_state = self._gamestate_generator.generate(self.realm, self.obs)
infos = {}
for eid in agents:
infos[eid] = {}
infos[eid]['task'] = {}
rewards = {eid: 0 for eid in agents}
infos = {agent_id: {'task': {}} for agent_id in agents}
rewards = defaultdict(int)
agents = set(agents)
reward_cache = {}

# Compute Rewards and infos
for task, weight in self.tasks:
task_rewards, task_infos = task.compute_rewards(self.game_state)
for eid, reward in task_rewards.items():
# Rewards, weighted
rewards[eid] = rewards.get(eid,0) + reward * weight
# Infos
for eid, info in task_infos.items():
if eid in infos:
infos[eid]['task'] = {**infos[eid]['task'], **info}

# Remove rewards for dead agents (?)
for eid in dones:
rewards[eid] = 0
self.game_state = self._gamestate_generator.generate(self.realm, self.obs)
for task in self.tasks:
if task in reward_cache:
task_rewards, task_infos = reward_cache[task]
else:
task_rewards, task_infos = task.compute_rewards(self.game_state)
reward_cache[task] = (task_rewards, task_infos)
for agent_id, reward in task_rewards.items():
if agent_id in agents and agent_id not in dones:
rewards[agent_id] = rewards.get(agent_id,0) + reward
infos[agent_id]['task'][task.name] = task_infos[agent_id] # progress

return rewards, infos

Expand Down
6 changes: 3 additions & 3 deletions nmmo/core/realm.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,6 @@ def reset(self, map_id: int = None):
self.log_helper.reset()
self.event_log.reset()

if self._replay_helper is not None:
self._replay_helper.reset()

self.map.reset(map_id or np.random.randint(self.config.MAP_N) + 1)

# EntityState and ItemState tables must be empty after players/npcs.reset()
Expand All @@ -104,6 +101,9 @@ def reset(self, map_id: int = None):
Item.INSTANCE_ID = 0
self.items = {}

if self._replay_helper is not None:
self._replay_helper.reset()

def packet(self):
"""Client packet"""
return {
Expand Down
1 change: 0 additions & 1 deletion nmmo/lib/spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,5 +135,4 @@ def get_team_spawn_positions(config, num_teams):
idx = int(len(side)*(i+1)/(teams_per_sides + 1))
team_spawn_positions.append(side[idx])

np.random.shuffle(team_spawn_positions)
return team_spawn_positions
14 changes: 13 additions & 1 deletion nmmo/lib/team_helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Dict, List


class TeamHelper():
def __init__(self, teams: Dict[int, List[int]]):
self.teams = teams
Expand All @@ -23,3 +22,16 @@ def agent_id(self, team_id: int, position: int) -> int:

def is_agent_in_team(self, agent_id:int , team_id: int) -> bool:
  """Tell whether agent_id appears in the member list of team team_id.

  Looks the team up in self.teams, so an unknown team_id raises whatever
  that lookup raises (KeyError for a plain dict).
  """
  members = self.teams[team_id]
  return agent_id in members

def get_target_agent(self, team_id: int, target: str):
  """Resolve a symbolic target name relative to the given team.

  The "left" team is the one at (team_id + 1) and the "right" team the one
  at (team_id - 1), both wrapped modulo self.num_teams. A team's "leader"
  is the first entry of its member list.

  Args:
    team_id: Id of the team the target is relative to.
    target: One of 'left_team', 'left_team_leader', 'right_team',
      'right_team_leader' or 'my_team_leader'.

  Returns:
    The whole self.teams entry for the '*_team' targets, its first element
    for the '*_leader' targets, or None when target is not recognized.
  """
  # NOTE(review): self.num_teams is not set in the visible part of the class;
  # presumably it is defined in the hidden portion -- confirm.
  if target == 'my_team_leader':
    return self.teams[team_id][0]

  # Step added to team_id to reach the neighbouring team.
  neighbor_step = {
    'left_team': 1,
    'left_team_leader': 1,
    'right_team': -1,
    'right_team_leader': -1,
  }
  step = neighbor_step.get(target)
  if step is None:
    return None

  # Python's % always yields a non-negative result for a positive modulus,
  # so stepping "right" from team 0 wraps around to the last team.
  members = self.teams[(team_id + step) % self.num_teams]
  if target.endswith('_leader'):
    return members[0]
  return members
1 change: 1 addition & 0 deletions nmmo/render/replay_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def reset(self):
self.packets = []
self.map = None
self._i = 0
self.update() # to capture the initial packet

def __len__(self):
return len(self.packets)
Expand Down
9 changes: 6 additions & 3 deletions nmmo/systems/skill.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,13 +265,13 @@ def update(self):
if not config.RESOURCE_SYSTEM_ENABLED:
return

if config.IMMORTAL:
return

depletion = config.RESOURCE_DEPLETION_RATE
water = self.entity.resources.water
water.decrement(depletion)

if self.config.IMMORTAL:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't love IMMORTAL; maybe we can just get rid of it?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMMORTAL is pretty much just for performance testing. Perhaps rename it to PERFORMANCE_TEST?

return

if not self.harvest_adjacent(material.Water, deplete=False):
return

Expand All @@ -288,6 +288,9 @@ def update(self):
if not config.RESOURCE_SYSTEM_ENABLED:
return

if config.IMMORTAL:
return

depletion = config.RESOURCE_DEPLETION_RATE
food = self.entity.resources.food
food.decrement(depletion)
Expand Down
3 changes: 1 addition & 2 deletions nmmo/task/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .game_state import *
from .predicate_api import *
from .task_api import *
from .scenario import *
from .team_helper import *
Loading