Refining task api, etc #63
Changes from all commits: 1a578a6, e16995a, 977f6b0, a7975e5, 62f3e0b, ad1bc2b, 58dd1d7, 153a1ed, fc3ae27, b40bf91, 85b699f, 1d2d46e, 25bab85, 20db562, 18b04de, 4f0c94c, 44b47c1, d9aa89b, decb0a7, 332f843, 44eab79, 79a66bd
nmmo/core/env.py

```diff
@@ -1,7 +1,7 @@
 import functools
 import random
 import copy
-from typing import Any, Dict, List, Optional, Union, Tuple
+from typing import Any, Dict, List, Callable
 from collections import defaultdict
 from ordered_set import OrderedSet

 import gym
```
```diff
@@ -16,8 +16,7 @@
 from nmmo.entity.entity import Entity
 from nmmo.systems.item import Item
 from nmmo.task.game_state import GameStateGenerator
-from nmmo.task.task_api import Task
-from nmmo.task.scenario import default_task
+from nmmo.task import task_api
 from scripted.baselines import Scripted

 class Env(ParallelEnv):
```
```diff
@@ -41,15 +40,7 @@ def __init__(self,

     self._gamestate_generator = GameStateGenerator(self.realm, self.config)
     self.game_state = None
-    # Default task: rewards 1 each turn agent is alive
-    self.tasks: List[Tuple[Task,float]] = None
-    self._task_encoding = None
-    self._task_embedding_size = -1
-    t = default_task(self.possible_agents)
-    self.change_task(t,
-                     embedding_size=self._task_embedding_size,
-                     task_encoding=self._task_encoding,
-                     reset=False)
+    self.tasks = task_api.nmmo_default_task(self.possible_agents)

   # pylint: disable=method-cache-max-size-none
   @functools.lru_cache(maxsize=None)
```
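The removed comment documented the default as "rewards 1 each turn agent is alive." As a rough illustration of the shape a factory like `task_api.nmmo_default_task` has to return, here is a hypothetical stand-in, not the PR's implementation; `alive_agents` is an assumed attribute of the game state:

```python
# Hypothetical stand-in for task_api.nmmo_default_task, NOT the PR's code.
from typing import Dict, List, Tuple


class StayAliveTask:
  def __init__(self, agent_id: int):
    self.agent_id = agent_id
    self.name = f"stay_alive_{agent_id}"

  def reset(self) -> None:
    pass  # this sketch keeps no per-episode state

  def compute_rewards(self, game_state) -> Tuple[Dict[int, float], Dict[int, dict]]:
    # Pay 1 for every tick the agent is still alive; `alive_agents` is assumed.
    alive = self.agent_id in game_state.alive_agents
    rewards = {self.agent_id: 1.0 if alive else 0.0}
    infos = {self.agent_id: {'progress': 1.0 if alive else 0.0}}
    return rewards, infos


def default_task_sketch(agent_ids: List[int]) -> List[StayAliveTask]:
  # One task per agent, mirroring nmmo_default_task(self.possible_agents).
  return [StayAliveTask(i) for i in agent_ids]
```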
```diff
@@ -88,12 +79,6 @@ def box(rows, cols):
     if self.config.PROVIDE_ACTION_TARGETS:
       obs_space['ActionTargets'] = self.action_space(None)

-    if self._task_encoding:
-      obs_space['Task'] = gym.spaces.Box(
-        low=-2**20, high=2**20,
-        shape=(self._task_embedding_size,),
-        dtype=np.float32)
-
     return gym.spaces.Dict(obs_space)

   def _init_random(self, seed):
```
```diff
@@ -131,38 +116,18 @@ def action_space(self, agent):
   ############################################################################
   # Core API

-  def change_task(self,
-                  new_tasks: List[Union[Tuple[Task, float], Task]],
-                  task_encoding: Optional[Dict[int, np.ndarray]] = None,
-                  embedding_size: int=16,
-                  reset: bool=True,
-                  map_id=None,
-                  seed=None,
-                  options=None):
-    """ Changes the task given to each agent
-
-    Args:
-      new_task: The task to complete and calculate rewards
-      task_encoding: A mapping from eid to encoded task
-      embedding_size: The size of each embedding
-      reset: Resets the environment
-    """
-    self._tasks = [t if isinstance(t, Tuple) else (t,1) for t in new_tasks]
-    self._task_encoding = task_encoding
-    self._task_embedding_size = embedding_size
-    if reset:
-      self.reset(map_id=map_id, seed=seed, options=options)
-
   # TODO: This doesn't conform to the PettingZoo API
   # pylint: disable=arguments-renamed
-  def reset(self, map_id=None, seed=None, options=None):
+  def reset(self, map_id=None, seed=None, options=None,
+            make_task_fn: Callable=None):
     '''OpenAI Gym API reset function

     Loads a new game map and returns initial observations

     Args:
-      idx: Map index to load. Selects a random map by default
+      map_id: Map index to load. Selects a random map by default
+      seed: random seed to use
+      make_task_fn: A function to make tasks

     Returns:
       observations, as documented by _compute_observations()
```
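Removing `change_task` also drops the `(Task, weight)` tuples, so per-task reward weights are gone from the new API. If weighting is still wanted, one hypothetical route is a wrapper that scales a task's rewards while keeping the `reset`/`compute_rewards`/`name` interface this diff relies on:

```python
# Hypothetical wrapper restoring the removed per-task weights; `task` is any
# object with the reset/compute_rewards/name interface used by the env.
class WeightedTask:
  def __init__(self, task, weight: float):
    self.task = task
    self.weight = weight
    self.name = f"weighted_{task.name}"

  def reset(self):
    self.task.reset()

  def compute_rewards(self, game_state):
    rewards, infos = self.task.compute_rewards(game_state)
    # Scale rewards, as change_task's weight multiplication used to do.
    return {a: r * self.weight for a, r in rewards.items()}, infos
```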
```diff
@@ -186,16 +151,16 @@ def reset(self, map_id=None, seed=None, options=None,
       if isinstance(ent.agent, Scripted):
         self.scripted_agents.add(eid)

-    self.tasks = copy.deepcopy(self._tasks)
     self.obs = self._compute_observations()
     self._gamestate_generator = GameStateGenerator(self.realm, self.config)

-    gym_obs = {}
-    for a, o in self.obs.items():
-      gym_obs[a] = o.to_gym()
-      if self._task_encoding:
-        gym_obs[a]['Task'] = self._encode_goal().get(a,np.zeros(self._task_embedding_size))
-    return gym_obs
+    if make_task_fn is not None:
+      self.tasks = make_task_fn()
+    else:
+      for task in self.tasks:
+        task.reset()
+
+    return {a: o.to_gym() for a,o in self.obs.items()}

   def step(self, actions: Dict[int, Dict[str, Dict[str, Any]]]):
     '''Simulates one game tick or timestep
```
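A usage sketch of the revised `reset`: only the signature and the task handling come from this diff; the construction call is assumed, and `my_task_factory` stands in for any task-list factory:

```python
import nmmo

env = nmmo.Env()

# Default path: the tasks built in __init__ are kept, and task.reset() is
# called on each of them before the new episode starts.
obs = env.reset(seed=1)

# Custom path: make_task_fn replaces the task list for the new episode.
obs = env.reset(seed=2, make_task_fn=lambda: my_task_factory(env.possible_agents))
```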
```diff
@@ -308,11 +273,7 @@ def step(self, actions: Dict[int, Dict[str, Dict[str, Any]]]):

     # Store the observations, since actions reference them
     self.obs = self._compute_observations()
-    gym_obs = {}
-    for a, o in self.obs.items():
-      gym_obs[a] = o.to_gym()
-      if self._task_encoding:
-        gym_obs[a]['Task'] = self._encode_goal()[a]
+    gym_obs = {a: o.to_gym() for a,o in self.obs.items()}

     rewards, infos = self._compute_rewards(self.obs.keys(), dones)
```

> **Review comment:** don't we still need this?
>
> **Reply:** Previously, this was [...]. Currently, I'm not sure about the exact form of the task encoding. Hoping to get some input/specs as we start task-conditioned learning very soon.
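For reference in this thread, the deleted encoding path can be pieced together from the removed lines; a minimal reconstruction, with the embedding size taken from the old `change_task` default and the agent ids and values purely illustrative:

```python
# Reconstruction of the deleted task-encoding plumbing (values illustrative).
import numpy as np

embedding_size = 16                 # default from the removed change_task()
task_encoding = {1: np.ones(embedding_size, dtype=np.float32)}  # eid -> vector

gym_obs = {}
for agent_id in (1, 2):
  gym_obs[agent_id] = {}
  # The old reset() zero-filled agents missing from the encoding map:
  gym_obs[agent_id]['Task'] = task_encoding.get(
      agent_id, np.zeros(embedding_size, dtype=np.float32))
```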
```diff
@@ -321,8 +282,6 @@ def step(self, actions: Dict[int, Dict[str, Dict[str, Any]]]):
   def _validate_actions(self, actions: Dict[int, Dict[str, Dict[str, Any]]]):
     '''Deserialize action arg values and validate actions
     For now, it does a basic validation (e.g., value is not none).
-
-    TODO(kywch): add sophisticated validation like use/sell/give on the same item
     '''
     validated_actions = {}
```
```diff
@@ -423,9 +382,6 @@ def _compute_observations(self):
                               inventory, market)
     return obs

-  def _encode_goal(self):
-    return self._task_encoding
-
   def _compute_rewards(self, agents: List[AgentID], dones: Dict[AgentID, bool]):
     '''Computes the reward for the specified agent
```
```diff
@@ -442,27 +398,23 @@ def _compute_rewards(self, agents: List[AgentID], dones: Dict[AgentID, bool]):
       entity identified by ent_id.
     '''
     # Initialization
-    self.game_state = self._gamestate_generator.generate(self.realm, self.obs)
-    infos = {}
-    for eid in agents:
-      infos[eid] = {}
-      infos[eid]['task'] = {}
-    rewards = {eid: 0 for eid in agents}
+    infos = {agent_id: {'task': {}} for agent_id in agents}
+    rewards = defaultdict(int)
+    agents = set(agents)
+    reward_cache = {}
+    self.game_state = self._gamestate_generator.generate(self.realm, self.obs)

     # Compute Rewards and infos
-    for task, weight in self.tasks:
-      task_rewards, task_infos = task.compute_rewards(self.game_state)
-      for eid, reward in task_rewards.items():
-        # Rewards, weighted
-        rewards[eid] = rewards.get(eid,0) + reward * weight
-      # Infos
-      for eid, info in task_infos.items():
-        if eid in infos:
-          infos[eid]['task'] = {**infos[eid]['task'], **info}
-
-    # Remove rewards for dead agents (?)
-    for eid in dones:
-      rewards[eid] = 0
+    for task in self.tasks:
+      if task in reward_cache:
+        task_rewards, task_infos = reward_cache[task]
+      else:
+        task_rewards, task_infos = task.compute_rewards(self.game_state)
+        reward_cache[task] = (task_rewards, task_infos)
+      for agent_id, reward in task_rewards.items():
+        if agent_id in agents and agent_id not in dones:
+          rewards[agent_id] = rewards.get(agent_id,0) + reward
+          infos[agent_id]['task'][task.name] = task_infos[agent_id] # progress

     return rewards, infos
```
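The `reward_cache` above exists so that a task instance shared by several entries in `self.tasks` is evaluated only once per tick. The pattern in isolation, as a sketch assuming task objects are hashable by identity:

```python
# Memoize compute_rewards per task object for one tick; rebuilt every call.
reward_cache = {}

def evaluate(task, game_state):
  if task not in reward_cache:
    reward_cache[task] = task.compute_rewards(game_state)
  return reward_cache[task]
```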
nmmo/systems/skill.py

```diff
@@ -265,13 +265,13 @@ def update(self):
     if not config.RESOURCE_SYSTEM_ENABLED:
       return

+    if config.IMMORTAL:
+      return
+
     depletion = config.RESOURCE_DEPLETION_RATE
     water = self.entity.resources.water
     water.decrement(depletion)

-    if self.config.IMMORTAL:
-      return
-
     if not self.harvest_adjacent(material.Water, deplete=False):
       return
```

> **Review comment:** i don't love IMMORTAL, maybe we can just get rid of it?
>
> **Reply:** IMMORTAL is pretty much for performance testing. Perhaps change to PERFORMANCE_TEST?
```diff
@@ -288,6 +288,9 @@ def update(self):
     if not config.RESOURCE_SYSTEM_ENABLED:
       return

+    if config.IMMORTAL:
+      return
+
     depletion = config.RESOURCE_DEPLETION_RATE
     food = self.entity.resources.food
     food.decrement(depletion)
```
nmmo/task/__init__.py

```diff
@@ -1,4 +1,3 @@
 from .game_state import *
+from .predicate_api import *
 from .task_api import *
-from .scenario import *
-from .team_helper import *
```
> **Review comment:** don't we need this somewhere?
>
> **Reply:** We need it, but not necessarily in this format. Let me check back after seeing the syllabus integration.
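If the wildcard re-exports stay removed, call sites can import the few names this PR actually touches directly. An illustrative sketch only; per the comment above, the intended public surface of `nmmo.task` is still open:

```python
# Explicit imports replacing the removed star-imports (both names appear
# elsewhere in this PR's diff).
from nmmo.task.task_api import nmmo_default_task
from nmmo.task.game_state import GameStateGenerator
```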