diff --git a/.github/workflows/pylint-test.yml b/.github/workflows/pylint-test.yml new file mode 100644 index 000000000..c80fe2877 --- /dev/null +++ b/.github/workflows/pylint-test.yml @@ -0,0 +1,30 @@ +name: pylint-test + +on: [push, pull_request] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.8", "3.9", "3.10"] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip setuptools wheel + pip install . + - name: Running unit tests + run: pytest + - name: Analysing the code with pylint + run: pylint --recursive=y nmmo tests + - name: Looking for xcxc, just in case + run: | + if grep -r --include='*.py' 'xcxc'; then + echo "Found xcxc in the code. Please check the file." + exit 1 + fi diff --git a/.gitignore b/.gitignore index 8d050b47e..0fa0b8db6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,14 @@ # Game maps maps/ *.swp +runs/* +wandb/* + +# local replay file from tests/test_deterministic_replay.py, test_render_save.py +tests/replay_local*.pickle +replay* + +.vscode # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 000000000..b9fbb3eda --- /dev/null +++ b/.pylintrc @@ -0,0 +1,31 @@ +[MESSAGES CONTROL] + +disable=W0511, # TODO/FIXME + W0105, # string is used as a statement + C0114, # missing module docstring + C0115, # missing class docstring + C0116, # missing function docstring + W0221, # arguments differ from overridden method + C0415, # import outside toplevel + E0611, # no name in module + R0901, # too many ancestors + R0902, # too many instance attributes + R0903, # too few public methods + R0911, # too many return statements + R0912, # too many branches + R0913, # too many arguments + R0914, # too many local variables + R0914, # too many local variables + 
R0915, # too many statements + R0401, # cyclic import + +[INDENTATION] +indent-string=' ' + +[MASTER] +good-names-rgxs=^[_a-zA-Z][_a-z0-9]?$ # whitelist short variables +known-third-party=ordered_set,numpy,gym,pettingzoo,vec_noise,imageio,scipy,tqdm +load-plugins=pylint.extensions.bad_builtin + +[BASIC] +bad-functions=print # checks if these functions are used \ No newline at end of file diff --git a/nmmo/__init__.py b/nmmo/__init__.py index 6e82efabb..d00f0456a 100644 --- a/nmmo/__init__.py +++ b/nmmo/__init__.py @@ -1,38 +1,34 @@ -from .version import __version__ +import logging -import os -motd = r''' ___ ___ ___ ___ - /__/\ /__/\ /__/\ / /\ Version {:<8} - \ \:\ | |::\ | |::\ / /::\ - \ \:\ | |:|:\ | |:|:\ / /:/\:\ An open source - _____\__\:\ __|__|:|\:\ __|__|:|\:\ / /:/ \:\ project originally - /__/::::::::\ /__/::::| \:\ /__/::::| \:\ /__/:/ \__\:\ founded by Joseph Suarez - \ \:\~~\~~\/ \ \:\~~\__\/ \ \:\~~\__\/ \ \:\ / /:/ and formalized at OpenAI - \ \:\ ~~~ \ \:\ \ \:\ \ \:\ /:/ - \ \:\ \ \:\ \ \:\ \ \:\/:/ Now developed and - \ \:\ \ \:\ \ \:\ \ \::/ maintained at MIT in - \__\/ \__\/ \__\/ \__\/ Phillip Isola's lab '''.format(__version__) +from .version import __version__ -from . import scripting from .lib import material, spawn -from .overlay import Overlay, OverlayRegistry +from .render.overlay import Overlay, OverlayRegistry from .io import action -from .io.stimulus import Serialized from .io.action import Action from .core import config, agent from .core.agent import Agent -from .core.env import Env, Replay -from . 
import scripting, emulation, integrations +from .core.env import Env from .systems.achievement import Task from .core.terrain import MapGenerator, Terrain -__all__ = ['Env', 'config', 'scripting', 'emulation', 'integrations', 'agent', 'Agent', 'MapGenerator', 'Terrain', - 'Serialized', 'action', 'Action', 'scripting', 'material', 'spawn', - 'Task', 'Overlay', 'OverlayRegistry', 'Replay'] +MOTD = rf''' ___ ___ ___ ___ + /__/\ /__/\ /__/\ / /\ Version {__version__:<8} + \ \:\ | |::\ | |::\ / /::\ + \ \:\ | |:|:\ | |:|:\ / /:/\:\ An open source + _____\__\:\ __|__|:|\:\ __|__|:|\:\ / /:/ \:\ project originally + /__/::::::::\ /__/::::| \:\ /__/::::| \:\ /__/:/ \__\:\ founded by Joseph Suarez + \ \:\~~\~~\/ \ \:\~~\__\/ \ \:\~~\__\/ \ \:\ / /:/ and formalized at OpenAI + \ \:\ ~~~ \ \:\ \ \:\ \ \:\ /:/ + \ \:\ \ \:\ \ \:\ \ \:\/:/ Now developed and + \ \:\ \ \:\ \ \:\ \ \::/ maintained at MIT in + \__\/ \__\/ \__\/ \__\/ Phillip Isola's lab ''' + +__all__ = ['Env', 'config', 'agent', 'Agent', 'MapGenerator', 'Terrain', + 'action', 'Action', 'material', 'spawn', + 'Task', 'Overlay', 'OverlayRegistry'] try: - import openskill - from .lib.rating import OpenSkillRating - __all__.append('OpenSkillRating') -except: - print('Warning: OpenSkill not installed. Ignore if you do not need this feature') + __all__.append('OpenSkillRating') +except RuntimeError: + logging.error('Warning: OpenSkill not installed. 
Ignore if you do not need this feature') diff --git a/nmmo/core/__init__.py b/nmmo/core/__init__.py index fda9f019f..e69de29bb 100644 --- a/nmmo/core/__init__.py +++ b/nmmo/core/__init__.py @@ -1,4 +0,0 @@ -from nmmo.core.map import Map -from nmmo.core.realm import Realm -from nmmo.core.tile import Tile -from nmmo.core.config import Config diff --git a/nmmo/core/agent.py b/nmmo/core/agent.py index 9332b2c6a..04fdd5500 100644 --- a/nmmo/core/agent.py +++ b/nmmo/core/agent.py @@ -1,34 +1,20 @@ -from pdb import set_trace as T - -from nmmo.lib import colors class Agent: - scripted = False - policy = 'Neural' - - color = colors.Neon.CYAN - pop = 0 - - def __init__(self, config, idx): - '''Base class for agents + policy = 'Neural' - Args: - config: A Config object - idx: Unique AgentID int - ''' - self.config = config - self.iden = idx - self.pop = Agent.pop + def __init__(self, config, idx): + '''Base class for agents - def __call__(self, obs): - '''Used by scripted agents to compute actions. Override in subclasses. + Args: + config: A Config object + idx: Unique AgentID int + ''' + self.config = config + self.iden = idx - Args: - obs: Agent observation provided by the environment - ''' - pass + def __call__(self, obs): + '''Used by scripted agents to compute actions. Override in subclasses. 
-class Random(Agent): - '''Moves randomly, including bumping into things and falling into lava''' - def __call__(self, obs): - return {Action.Move: {Action.Direction: rand.choice(Action.Direction.edges)}} + Args: + obs: Agent observation provided by the environment + ''' diff --git a/nmmo/core/config.py b/nmmo/core/config.py index 85b8d74d0..eca6181f6 100644 --- a/nmmo/core/config.py +++ b/nmmo/core/config.py @@ -1,698 +1,708 @@ -from pdb import set_trace as T -import numpy as np +# pylint: disable=invalid-name +from __future__ import annotations + import os +import sys +import logging import nmmo +from nmmo.core.agent import Agent +from nmmo.core.terrain import MapGenerator from nmmo.lib import utils, material, spawn - class Template(metaclass=utils.StaticIterable): - def __init__(self): - self.data = {} - cls = type(self) - - #Set defaults from static properties - for k, v in cls: - self.set(k, v) - - def override(self, **kwargs): - for k, v in kwargs.items(): - err = 'CLI argument: {} is not a Config property'.format(k) - assert hasattr(self, k), err - self.set(k, v) - - def set(self, k, v): - if type(v) is not property: - try: - setattr(self, k, v) - except: - print('Cannot set attribute: {} to {}'.format(k, v)) - quit() - self.data[k] = v - - def print(self): - keyLen = 0 - for k in self.data.keys(): - keyLen = max(keyLen, len(k)) - - print('Configuration') - for k, v in self.data.items(): - print(' {:{}s}: {}'.format(k, keyLen, v)) - - def items(self): - return self.data.items() - - def __iter__(self): - for k in self.data: - yield k - - def keys(self): - return self.data.keys() - - def values(self): - return self.data.values() + def __init__(self): + self.data = {} + cls = type(self) + + #Set defaults from static properties + for k, v in cls: + self.set(k, v) + + def override(self, **kwargs): + for k, v in kwargs.items(): + err = f'CLI argument: {k} is not a Config property' + assert hasattr(self, k), err + self.set(k, v) + + def set(self, k, v): + if not 
isinstance(v, property): + try: + setattr(self, k, v) + except AttributeError: + logging.error('Cannot set attribute: %s to %s', str(k), str(v)) + sys.exit() + self.data[k] = v + + # pylint: disable=bad-builtin + def print(self): + key_len = 0 + for k in self.data: + key_len = max(key_len, len(k)) + + print('Configuration') + for k, v in self.data.items(): + print(f' {k:{key_len}s}: {v}') + + def items(self): + return self.data.items() + + def __iter__(self): + for k in self.data: + yield k + + def keys(self): + return self.data.keys() + + def values(self): + return self.data.values() def validate(config): - err = 'config.Config is a base class. Use config.{Small, Medium Large}''' - assert type(config) != Config, err - - if not config.TERRAIN_SYSTEM_ENABLED: - err = 'Invalid Config: {} requires Terrain' - assert not config.RESOURCE_SYSTEM_ENABLED, err.format('Resource') - assert not config.PROFESSION_SYSTEM_ENABLED, err.format('Profession') - - if not config.COMBAT_SYSTEM_ENABLED: - err = 'Invalid Config: {} requires Combat' - assert not config.NPC_SYSTEM_ENABLED, err.format('NPC') - - if not config.ITEM_SYSTEM_ENABLED: - err = 'Invalid Config: {} requires Inventory' - assert not config.EQUIPMENT_SYSTEM_ENABLED, err.format('Equipment') - assert not config.PROFESSION_SYSTEM_ENABLED, err.format('Profession') - assert not config.EXCHANGE_SYSTEM_ENABLED, err.format('Exchange') + err = 'config.Config is a base class. 
Use config.{Small, Medium Large}''' + assert isinstance(config, Config), err + + if not config.TERRAIN_SYSTEM_ENABLED: + err = 'Invalid Config: {} requires Terrain' + assert not config.RESOURCE_SYSTEM_ENABLED, err.format('Resource') + assert not config.PROFESSION_SYSTEM_ENABLED, err.format('Profession') + + if not config.COMBAT_SYSTEM_ENABLED: + err = 'Invalid Config: {} requires Combat' + assert not config.NPC_SYSTEM_ENABLED, err.format('NPC') + + if not config.ITEM_SYSTEM_ENABLED: + err = 'Invalid Config: {} requires Inventory' + assert not config.EQUIPMENT_SYSTEM_ENABLED, err.format('Equipment') + assert not config.PROFESSION_SYSTEM_ENABLED, err.format('Profession') + assert not config.EXCHANGE_SYSTEM_ENABLED, err.format('Exchange') class Config(Template): - '''An environment configuration object + '''An environment configuration object + + Global constants are defined as static class variables. You can override + any Config variable using standard CLI syntax (e.g. --NENT=128). - Global constants are defined as static class variables. You can override - any Config variable using standard CLI syntax (e.g. --NENT=128). + The default config as of v1.5 uses 1024x1024 maps with up to 2048 agents + and 1024 NPCs. It is suitable to time horizons of 8192+ steps. For smaller + experiments, consider the SmallMaps config. - The default config as of v1.5 uses 1024x1024 maps with up to 2048 agents - and 1024 NPCs. It is suitable to time horizons of 8192+ steps. For smaller - experiments, consider the SmallMaps config. - - Notes: - We use Google Fire internally to replace standard manual argparse - definitions for each Config property. This means you can subclass - Config to add new static attributes -- CLI definitions will be - generated automatically. - ''' + Notes: + We use Google Fire internally to replace standard manual argparse + definitions for each Config property. 
This means you can subclass + Config to add new static attributes -- CLI definitions will be + generated automatically. + ''' - def __init__(self): - super().__init__() + def __init__(self): + super().__init__() - # TODO: Come up with a better way - # to resolve mixin MRO conflicts - if not hasattr(self, 'TERRAIN_SYSTEM_ENABLED'): - self.TERRAIN_SYSTEM_ENABLED = False + # TODO: Come up with a better way + # to resolve mixin MRO conflicts + if not hasattr(self, 'TERRAIN_SYSTEM_ENABLED'): + self.TERRAIN_SYSTEM_ENABLED = False - if not hasattr(self, 'RESOURCE_SYSTEM_ENABLED'): - self.RESOURCE_SYSTEM_ENABLED = False + if not hasattr(self, 'RESOURCE_SYSTEM_ENABLED'): + self.RESOURCE_SYSTEM_ENABLED = False - if not hasattr(self, 'COMBAT_SYSTEM_ENABLED'): - self.COMBAT_SYSTEM_ENABLED = False + if not hasattr(self, 'COMBAT_SYSTEM_ENABLED'): + self.COMBAT_SYSTEM_ENABLED = False - if not hasattr(self, 'NPC_SYSTEM_ENABLED'): - self.NPC_SYSTEM_ENABLED = False + if not hasattr(self, 'NPC_SYSTEM_ENABLED'): + self.NPC_SYSTEM_ENABLED = False - if not hasattr(self, 'PROGRESSION_SYSTEM_ENABLED'): - self.PROGRESSION_SYSTEM_ENABLED = False + if not hasattr(self, 'PROGRESSION_SYSTEM_ENABLED'): + self.PROGRESSION_SYSTEM_ENABLED = False - if not hasattr(self, 'ITEM_SYSTEM_ENABLED'): - self.ITEM_SYSTEM_ENABLED = False + if not hasattr(self, 'ITEM_SYSTEM_ENABLED'): + self.ITEM_SYSTEM_ENABLED = False - if not hasattr(self, 'EQUIPMENT_SYSTEM_ENABLED'): - self.EQUIPMENT_SYSTEM_ENABLED = False + if not hasattr(self, 'EQUIPMENT_SYSTEM_ENABLED'): + self.EQUIPMENT_SYSTEM_ENABLED = False - if not hasattr(self, 'PROFESSION_SYSTEM_ENABLED'): - self.PROFESSION_SYSTEM_ENABLED = False + if not hasattr(self, 'PROFESSION_SYSTEM_ENABLED'): + self.PROFESSION_SYSTEM_ENABLED = False - if not hasattr(self, 'EXCHANGE_SYSTEM_ENABLED'): - self.EXCHANGE_SYSTEM_ENABLED = False + if not hasattr(self, 'EXCHANGE_SYSTEM_ENABLED'): + self.EXCHANGE_SYSTEM_ENABLED = False - if not hasattr(self, 
'COMMUNICATION_SYSTEM_ENABLED'): - self.COMMUNICATION_SYSTEM_ENABLED = False - - if __debug__: - validate(self) + if not hasattr(self, 'COMMUNICATION_SYSTEM_ENABLED'): + self.COMMUNICATION_SYSTEM_ENABLED = False - deprecated_attrs = [ - 'NENT', 'NPOP', 'AGENTS', 'NMAPS', 'FORCE_MAP_GENERATION', 'SPAWN'] + if __debug__: + validate(self) - for attr in deprecated_attrs: - assert not hasattr(self, attr), f'{attr} has been deprecated or renamed' + deprecated_attrs = [ + 'NENT', 'NPOP', 'AGENTS', 'NMAPS', 'FORCE_MAP_GENERATION', 'SPAWN'] + for attr in deprecated_attrs: + assert not hasattr(self, attr), f'{attr} has been deprecated or renamed' - ############################################################################ - ### Meta-Parameters - def game_system_enabled(self, name) -> bool: - return hasattr(self, name) - def population_mapping_fn(self, idx) -> int: - return idx % self.NPOP + ############################################################################ + ### Meta-Parameters + def game_system_enabled(self, name) -> bool: + return hasattr(self, name) - RENDER = False - '''Flag used by render mode''' - SAVE_REPLAY = False - '''Flag used to save replays''' + SAVE_REPLAY = False + '''Flag used to save replays''' - PLAYERS = [] - '''Player classes from which to spawn''' + PROVIDE_ACTION_TARGETS = False + '''Flag used to provide action targets mask''' - TASKS = [] - '''Tasks for which to compute rewards''' + PLAYERS = [Agent] + '''Player classes from which to spawn''' - ############################################################################ - ### Emulation Parameters - - EMULATE_FLAT_OBS = False - '''Emulate a flat observation space''' + TASKS = [] + '''Tasks for which to compute rewards''' - EMULATE_FLAT_ATN = False - '''Emulate a flat action space''' + ############################################################################ + ### Emulation Parameters - EMULATE_CONST_PLAYER_N = False - '''Emulate a constant number of agents''' + EMULATE_FLAT_OBS = False + 
'''Emulate a flat observation space''' - EMULATE_CONST_HORIZON = False - '''Emulate a constant HORIZON simulations steps''' + EMULATE_FLAT_ATN = False + '''Emulate a flat action space''' + EMULATE_CONST_PLAYER_N = False + '''Emulate a constant number of agents''' - ############################################################################ - ### Population Parameters - LOG_VERBOSE = False - '''Whether to log server messages or just stats''' + EMULATE_CONST_HORIZON = False + '''Emulate a constant HORIZON simulations steps''' - LOG_ENV = False - '''Whether to log env steps (expensive)''' - LOG_MILESTONES = True - '''Whether to log server-firsts (semi-expensive)''' + ############################################################################ + ### Population Parameters + LOG_VERBOSE = False + '''Whether to log server messages or just stats''' - LOG_EVENTS = True - '''Whether to log events (semi-expensive)''' + LOG_ENV = False + '''Whether to log env steps (expensive)''' - LOG_FILE = None - '''Where to write logs (defaults to console)''' + LOG_MILESTONES = True + '''Whether to log server-firsts (semi-expensive)''' - PLAYERS = [] - '''Player classes from which to spawn''' + LOG_EVENTS = True + '''Whether to log events (semi-expensive)''' - TASKS = [] - '''Tasks for which to compute rewards''' + LOG_FILE = None + '''Where to write logs (defaults to console)''' - ############################################################################ - ### Player Parameters - PLAYER_N = None - '''Maximum number of players spawnable in the environment''' + ############################################################################ + ### Player Parameters + PLAYER_N = None + '''Maximum number of players spawnable in the environment''' - PLAYER_N_OBS = 100 - '''Number of distinct agent observations''' + # TODO(kywch): CHECK if there could be 100+ entities within one's vision + PLAYER_N_OBS = 100 + '''Number of distinct agent observations''' - @property - def PLAYER_POLICIES(self): - 
'''Number of player policies''' - return len(self.PLAYERS) + @property + def PLAYER_POLICIES(self): + '''Number of player policies''' + return len(self.PLAYERS) - PLAYER_BASE_HEALTH = 100 - '''Initial agent health''' + PLAYER_BASE_HEALTH = 100 + '''Initial agent health''' - PLAYER_VISION_RADIUS = 7 - '''Number of tiles an agent can see in any direction''' + PLAYER_VISION_RADIUS = 7 + '''Number of tiles an agent can see in any direction''' - @property - def PLAYER_VISION_DIAMETER(self): - '''Size of the square tile crop visible to an agent''' - return 2*self.PLAYER_VISION_RADIUS + 1 + @property + def PLAYER_VISION_DIAMETER(self): + '''Size of the square tile crop visible to an agent''' + return 2*self.PLAYER_VISION_RADIUS + 1 - PLAYER_DEATH_FOG = None - '''How long before spawning death fog. None for no death fog''' + PLAYER_DEATH_FOG = None + '''How long before spawning death fog. None for no death fog''' - PLAYER_DEATH_FOG_SPEED = 1 - '''Number of tiles per tick that the fog moves in''' - PLAYER_DEATH_FOG_FINAL_SIZE = 8 - '''Number of tiles from the center that the fog stops''' + ############################################################################ + ### Agent Parameters + IMMORTAL = False + '''Debug parameter: prevents agents from dying except by lava''' - RESPAWN = False + BASE_HEALTH = 10 + '''Initial Constitution level and agent health''' - PLAYER_LOADER = spawn.SequentialLoader - '''Agent loader class specifying spawn sampling''' + PLAYER_DEATH_FOG_SPEED = 1 + '''Number of tiles per tick that the fog moves in''' - PLAYER_SPAWN_ATTEMPTS = None - '''Number of player spawn attempts per tick + PLAYER_DEATH_FOG_FINAL_SIZE = 8 + '''Number of tiles from the center that the fog stops''' - Note that the env will attempt to spawn agents until success - if the current population size is zero.''' + PLAYER_LOADER = spawn.SequentialLoader + '''Agent loader class specifying spawn sampling''' - PLAYER_SPAWN_TEAMMATE_DISTANCE = 1 - '''Buffer tiles between teammates at 
spawn''' - - @property - def PLAYER_SPAWN_FUNCTION(self): - return spawn.spawn_concurrent + PLAYER_SPAWN_TEAMMATE_DISTANCE = 1 + '''Buffer tiles between teammates at spawn''' - @property - def PLAYER_TEAM_SIZE(self): - if __debug__: - assert not self.PLAYER_N % len(self.PLAYERS) - return self.PLAYER_N // len(self.PLAYERS) + @property + def PLAYER_TEAM_SIZE(self): + if __debug__: + assert not self.PLAYER_N % len(self.PLAYERS) + return self.PLAYER_N // len(self.PLAYERS) - ############################################################################ - ### Map Parameters - MAP_N = 1 - '''Number of maps to generate''' + ############################################################################ + ### Map Parameters + MAP_N = 1 + '''Number of maps to generate''' - MAP_N_TILE = len(material.All.materials) - '''Number of distinct terrain tile types''' + MAP_N_TILE = len(material.All.materials) + '''Number of distinct terrain tile types''' - @property - def MAP_N_OBS(self): - '''Number of distinct tile observations''' - return int(self.PLAYER_VISION_DIAMETER ** 2) + @property + def MAP_N_OBS(self): + '''Number of distinct tile observations''' + return int(self.PLAYER_VISION_DIAMETER ** 2) - MAP_CENTER = None - '''Size of each map (number of tiles along each side)''' + MAP_CENTER = None + '''Size of each map (number of tiles along each side)''' - MAP_BORDER = 16 - '''Number of lava border tiles surrounding each side of the map''' + MAP_BORDER = 16 + '''Number of lava border tiles surrounding each side of the map''' - @property - def MAP_SIZE(self): - return int(self.MAP_CENTER + 2*self.MAP_BORDER) + @property + def MAP_SIZE(self): + return int(self.MAP_CENTER + 2*self.MAP_BORDER) - MAP_GENERATOR = None - '''Specifies a user map generator. Uses default generator if unspecified.''' + MAP_GENERATOR = MapGenerator + '''Specifies a user map generator. 
Uses default generator if unspecified.''' - MAP_FORCE_GENERATION = True - '''Whether to regenerate and overwrite existing maps''' + MAP_FORCE_GENERATION = True + '''Whether to regenerate and overwrite existing maps''' - MAP_GENERATE_PREVIEWS = False - '''Whether map generation should also save .png previews (slow + large file size)''' + MAP_GENERATE_PREVIEWS = False + '''Whether map generation should also save .png previews (slow + large file size)''' - MAP_PREVIEW_DOWNSCALE = 1 - '''Downscaling factor for png previews''' + MAP_PREVIEW_DOWNSCALE = 1 + '''Downscaling factor for png previews''' - ############################################################################ - ### Path Parameters - PATH_ROOT = os.path.dirname(nmmo.__file__) - '''Global repository directory''' + ############################################################################ + ### Path Parameters + PATH_ROOT = os.path.dirname(nmmo.__file__) + '''Global repository directory''' - PATH_CWD = os.getcwd() - '''Working directory''' + PATH_CWD = os.getcwd() + '''Working directory''' - PATH_RESOURCE = os.path.join(PATH_ROOT, 'resource') - '''Resource directory''' + PATH_RESOURCE = os.path.join(PATH_ROOT, 'resource') + '''Resource directory''' - PATH_TILE = os.path.join(PATH_RESOURCE, '{}.png') - '''Tile path -- format me with tile name''' + PATH_TILE = os.path.join(PATH_RESOURCE, '{}.png') + '''Tile path -- format me with tile name''' - PATH_MAPS = None - '''Generated map directory''' + PATH_MAPS = None + '''Generated map directory''' - PATH_MAP_SUFFIX = 'map{}/map.npy' - '''Map file name''' + PATH_MAP_SUFFIX = 'map{}/map.npy' + '''Map file name''' - PATH_MAP_SUFFIX = 'map{}/map.npy' - '''Map file name''' + PATH_MAP_SUFFIX = 'map{}/map.npy' + '''Map file name''' ############################################################################ ### Game Systems (Static Mixins) class Terrain: - '''Terrain Game System''' + '''Terrain Game System''' - TERRAIN_SYSTEM_ENABLED = True - '''Game system flag''' + 
TERRAIN_SYSTEM_ENABLED = True + '''Game system flag''' - TERRAIN_FLIP_SEED = False - '''Whether to negate the seed used for generation (useful for unique heldout maps)''' + TERRAIN_FLIP_SEED = False + '''Whether to negate the seed used for generation (useful for unique heldout maps)''' - TERRAIN_FREQUENCY = -3 - '''Base noise frequency range (log2 space)''' + TERRAIN_FREQUENCY = -3 + '''Base noise frequency range (log2 space)''' - TERRAIN_FREQUENCY_OFFSET = 7 - '''Noise frequency octave offset (log2 space)''' + TERRAIN_FREQUENCY_OFFSET = 7 + '''Noise frequency octave offset (log2 space)''' - TERRAIN_LOG_INTERPOLATE_MIN = -2 - '''Minimum interpolation log-strength for noise frequencies''' + TERRAIN_LOG_INTERPOLATE_MIN = -2 + '''Minimum interpolation log-strength for noise frequencies''' - TERRAIN_LOG_INTERPOLATE_MAX = 0 - '''Maximum interpolation log-strength for noise frequencies''' + TERRAIN_LOG_INTERPOLATE_MAX = 0 + '''Maximum interpolation log-strength for noise frequencies''' - TERRAIN_TILES_PER_OCTAVE = 8 - '''Number of octaves sampled from log2 spaced TERRAIN_FREQUENCY range''' + TERRAIN_TILES_PER_OCTAVE = 8 + '''Number of octaves sampled from log2 spaced TERRAIN_FREQUENCY range''' - TERRAIN_LAVA = 0.0 - '''Noise threshold for lava generation''' + TERRAIN_LAVA = 0.0 + '''Noise threshold for lava generation''' - TERRAIN_WATER = 0.30 - '''Noise threshold for water generation''' + TERRAIN_WATER = 0.30 + '''Noise threshold for water generation''' - TERRAIN_GRASS = 0.70 - '''Noise threshold for grass''' + TERRAIN_GRASS = 0.70 + '''Noise threshold for grass''' - TERRAIN_FOREST = 0.85 - '''Noise threshold for forest''' + TERRAIN_FOREST = 0.85 + '''Noise threshold for forest''' class Resource: - '''Resource Game System''' + '''Resource Game System''' - RESOURCE_SYSTEM_ENABLED = True - '''Game system flag''' + RESOURCE_SYSTEM_ENABLED = True + '''Game system flag''' - RESOURCE_BASE = 100 - '''Initial level and capacity for Hunting + Fishing resource skills''' + 
RESOURCE_BASE = 100 + '''Initial level and capacity for Hunting + Fishing resource skills''' - RESOURCE_DEPLETION_RATE = 5 - '''Depletion rate for food and water''' + RESOURCE_DEPLETION_RATE = 5 + '''Depletion rate for food and water''' - RESOURCE_STARVATION_RATE = 10 - '''Damage per tick without food''' + RESOURCE_STARVATION_RATE = 10 + '''Damage per tick without food''' - RESOURCE_DEHYDRATION_RATE = 10 - '''Damage per tick without water''' + RESOURCE_DEHYDRATION_RATE = 10 + '''Damage per tick without water''' - RESOURCE_FOREST_CAPACITY = 1 - '''Maximum number of harvests before a forest tile decays''' + RESOURCE_FOREST_CAPACITY = 1 + '''Maximum number of harvests before a forest tile decays''' - RESOURCE_FOREST_RESPAWN = 0.025 - '''Probability that a harvested forest tile will regenerate each tick''' + RESOURCE_FOREST_RESPAWN = 0.025 + '''Probability that a harvested forest tile will regenerate each tick''' - RESOURCE_HARVEST_RESTORE_FRACTION = 1.0 - '''Fraction of maximum capacity restored upon collecting a resource''' + RESOURCE_HARVEST_RESTORE_FRACTION = 1.0 + '''Fraction of maximum capacity restored upon collecting a resource''' - RESOURCE_HEALTH_REGEN_THRESHOLD = 0.5 - '''Fraction of maximum resource capacity required to regen health''' + RESOURCE_HEALTH_REGEN_THRESHOLD = 0.5 + '''Fraction of maximum resource capacity required to regen health''' - RESOURCE_HEALTH_RESTORE_FRACTION = 0.1 - '''Fraction of health restored per tick when above half food+water''' + RESOURCE_HEALTH_RESTORE_FRACTION = 0.1 + '''Fraction of health restored per tick when above half food+water''' class Combat: - '''Combat Game System''' + '''Combat Game System''' - COMBAT_SYSTEM_ENABLED = True - '''Game system flag''' + COMBAT_SYSTEM_ENABLED = True + '''Game system flag''' - COMBAT_FRIENDLY_FIRE = True - '''Whether agents with the same population index can hit each other''' + COMBAT_SPAWN_IMMUNITY = 20 + '''Agents older than this many ticks cannot attack agents younger than this many 
ticks''' - COMBAT_SPAWN_IMMUNITY = 20 - '''Agents older than this many ticks cannot attack agents younger than this many ticks''' + COMBAT_STATUS_DURATION = 3 + '''Combat status lasts for this many ticks after the last combat event. + Combat events include both attacking and being attacked.''' - COMBAT_WEAKNESS_MULTIPLIER = 1.5 - '''Multiplier for super-effective attacks''' + COMBAT_WEAKNESS_MULTIPLIER = 1.5 + '''Multiplier for super-effective attacks''' - def COMBAT_DAMAGE_FORMULA(self, offense, defense, multiplier): - '''Damage formula''' - return int(multiplier * (offense * (15 / (15 + defense)))) + def COMBAT_DAMAGE_FORMULA(self, offense, defense, multiplier): + '''Damage formula''' + return int(multiplier * (offense * (15 / (15 + defense)))) - COMBAT_MELEE_DAMAGE = 30 - '''Melee attack damage''' + COMBAT_MELEE_DAMAGE = 30 + '''Melee attack damage''' - COMBAT_MELEE_REACH = 3 - '''Reach of attacks using the Melee skill''' + COMBAT_MELEE_REACH = 3 + '''Reach of attacks using the Melee skill''' - COMBAT_RANGE_DAMAGE = 30 - '''Range attack damage''' + COMBAT_RANGE_DAMAGE = 30 + '''Range attack damage''' - COMBAT_RANGE_REACH = 3 - '''Reach of attacks using the Range skill''' + COMBAT_RANGE_REACH = 3 + '''Reach of attacks using the Range skill''' - COMBAT_MAGE_DAMAGE = 30 - '''Mage attack damage''' + COMBAT_MAGE_DAMAGE = 30 + '''Mage attack damage''' - COMBAT_MAGE_REACH = 3 - '''Reach of attacks using the Mage skill''' + COMBAT_MAGE_REACH = 3 + '''Reach of attacks using the Mage skill''' class Progression: - '''Progression Game System''' + '''Progression Game System''' - PROGRESSION_SYSTEM_ENABLED = True - '''Game system flag''' + PROGRESSION_SYSTEM_ENABLED = True + '''Game system flag''' - PROGRESSION_BASE_XP_SCALE = 1 - '''Base XP awarded for each skill usage -- multiplied by skill level''' + PROGRESSION_BASE_XP_SCALE = 1 + '''Base XP awarded for each skill usage -- multiplied by skill level''' - PROGRESSION_COMBAT_XP_SCALE = 1 - '''Multiplier on top of XP_SCALE 
for Melee, Range, and Mage''' + PROGRESSION_COMBAT_XP_SCALE = 1 + '''Multiplier on top of XP_SCALE for Melee, Range, and Mage''' - PROGRESSION_AMMUNITION_XP_SCALE = 1 - '''Multiplier on top of XP_SCALE for Prospecting, Carving, and Alchemy''' + PROGRESSION_AMMUNITION_XP_SCALE = 1 + '''Multiplier on top of XP_SCALE for Prospecting, Carving, and Alchemy''' - PROGRESSION_CONSUMABLE_XP_SCALE = 5 - '''Multiplier on top of XP_SCALE for Fishing and Herbalism''' + PROGRESSION_CONSUMABLE_XP_SCALE = 5 + '''Multiplier on top of XP_SCALE for Fishing and Herbalism''' - PROGRESSION_LEVEL_MAX = 10 - '''Max skill level''' + PROGRESSION_LEVEL_MAX = 10 + '''Max skill level''' - PROGRESSION_MELEE_BASE_DAMAGE = 0 - '''Base Melee attack damage''' + PROGRESSION_MELEE_BASE_DAMAGE = 0 + '''Base Melee attack damage''' - PROGRESSION_MELEE_LEVEL_DAMAGE = 5 - '''Bonus Melee attack damage per level''' + PROGRESSION_MELEE_LEVEL_DAMAGE = 5 + '''Bonus Melee attack damage per level''' - PROGRESSION_RANGE_BASE_DAMAGE = 0 - '''Base Range attack damage''' + PROGRESSION_RANGE_BASE_DAMAGE = 0 + '''Base Range attack damage''' - PROGRESSION_RANGE_LEVEL_DAMAGE = 5 - '''Bonus Range attack damage per level''' + PROGRESSION_RANGE_LEVEL_DAMAGE = 5 + '''Bonus Range attack damage per level''' - PROGRESSION_MAGE_BASE_DAMAGE = 0 - '''Base Mage attack damage ''' + PROGRESSION_MAGE_BASE_DAMAGE = 0 + '''Base Mage attack damage ''' - PROGRESSION_MAGE_LEVEL_DAMAGE = 5 - '''Bonus Mage attack damage per level''' + PROGRESSION_MAGE_LEVEL_DAMAGE = 5 + '''Bonus Mage attack damage per level''' - PROGRESSION_BASE_DEFENSE = 0 - '''Base defense''' + PROGRESSION_BASE_DEFENSE = 0 + '''Base defense''' - PROGRESSION_LEVEL_DEFENSE = 5 - '''Bonus defense per level''' + PROGRESSION_LEVEL_DEFENSE = 5 + '''Bonus defense per level''' class NPC: - '''NPC Game System''' + '''NPC Game System''' - NPC_SYSTEM_ENABLED = True - '''Game system flag''' + NPC_SYSTEM_ENABLED = True + '''Game system flag''' - NPC_N = None - '''Maximum number of 
NPCs spawnable in the environment''' + NPC_N = None + '''Maximum number of NPCs spawnable in the environment''' - NPC_SPAWN_ATTEMPTS = 25 - '''Number of NPC spawn attempts per tick''' + NPC_SPAWN_ATTEMPTS = 25 + '''Number of NPC spawn attempts per tick''' - NPC_SPAWN_AGGRESSIVE = 0.80 - '''Percentage distance threshold from spawn for aggressive NPCs''' + NPC_SPAWN_AGGRESSIVE = 0.80 + '''Percentage distance threshold from spawn for aggressive NPCs''' - NPC_SPAWN_NEUTRAL = 0.50 - '''Percentage distance threshold from spawn for neutral NPCs''' + NPC_SPAWN_NEUTRAL = 0.50 + '''Percentage distance threshold from spawn for neutral NPCs''' - NPC_SPAWN_PASSIVE = 0.00 - '''Percentage distance threshold from spawn for passive NPCs''' - - NPC_LEVEL_MIN = 1 - '''Minimum NPC level''' + NPC_SPAWN_PASSIVE = 0.00 + '''Percentage distance threshold from spawn for passive NPCs''' - NPC_LEVEL_MAX = 10 - '''Maximum NPC level''' + NPC_LEVEL_MIN = 1 + '''Minimum NPC level''' - NPC_BASE_DEFENSE = 0 - '''Base NPC defense''' + NPC_LEVEL_MAX = 10 + '''Maximum NPC level''' - NPC_LEVEL_DEFENSE = 30 - '''Bonus NPC defense per level''' + NPC_BASE_DEFENSE = 0 + '''Base NPC defense''' - NPC_BASE_DAMAGE = 15 - '''Base NPC damage''' + NPC_LEVEL_DEFENSE = 30 + '''Bonus NPC defense per level''' - NPC_LEVEL_DAMAGE = 30 - '''Bonus NPC damage per level''' + NPC_BASE_DAMAGE = 15 + '''Base NPC damage''' + + NPC_LEVEL_DAMAGE = 30 + '''Bonus NPC damage per level''' class Item: - '''Inventory Game System''' + '''Inventory Game System''' + + ITEM_SYSTEM_ENABLED = True + '''Game system flag''' + + ITEM_N = 17 + '''Number of unique base item classes''' - ITEM_SYSTEM_ENABLED = True - '''Game system flag''' + ITEM_INVENTORY_CAPACITY = 12 + '''Number of inventory spaces''' - ITEM_N = 17 - '''Number of unique base item classes''' + ITEM_ALLOW_GIFT = True + '''Whether agents can give gold/item to each other''' - ITEM_INVENTORY_CAPACITY = 12 - '''Number of inventory spaces''' + @property + def INVENTORY_N_OBS(self): + 
'''Number of distinct item observations''' + return self.ITEM_INVENTORY_CAPACITY - @property - def ITEM_N_OBS(self): - '''Number of distinct item observations''' - return self.ITEM_N * self.NPC_LEVEL_MAX - #return self.INVENTORY_CAPACITY class Equipment: - '''Equipment Game System''' + '''Equipment Game System''' - EQUIPMENT_SYSTEM_ENABLED = True - '''Game system flag''' + EQUIPMENT_SYSTEM_ENABLED = True + '''Game system flag''' - WEAPON_DROP_PROB = 0.025 - '''Chance of getting a weapon while harvesting ammunition''' + WEAPON_DROP_PROB = 0.025 + '''Chance of getting a weapon while harvesting ammunition''' - EQUIPMENT_WEAPON_BASE_DAMAGE = 15 - '''Base weapon damage''' + EQUIPMENT_WEAPON_BASE_DAMAGE = 15 + '''Base weapon damage''' - EQUIPMENT_WEAPON_LEVEL_DAMAGE = 15 - '''Added weapon damage per level''' + EQUIPMENT_WEAPON_LEVEL_DAMAGE = 15 + '''Added weapon damage per level''' - EQUIPMENT_AMMUNITION_BASE_DAMAGE = 15 - '''Base ammunition damage''' + EQUIPMENT_AMMUNITION_BASE_DAMAGE = 15 + '''Base ammunition damage''' - EQUIPMENT_AMMUNITION_LEVEL_DAMAGE = 15 - '''Added ammunition damage per level''' + EQUIPMENT_AMMUNITION_LEVEL_DAMAGE = 15 + '''Added ammunition damage per level''' - EQUIPMENT_TOOL_BASE_DEFENSE = 30 - '''Base tool defense''' + EQUIPMENT_TOOL_BASE_DEFENSE = 30 + '''Base tool defense''' - EQUIPMENT_TOOL_LEVEL_DEFENSE = 0 - '''Added tool defense per level''' + EQUIPMENT_TOOL_LEVEL_DEFENSE = 0 + '''Added tool defense per level''' - EQUIPMENT_ARMOR_BASE_DEFENSE = 0 - '''Base armor defense''' + EQUIPMENT_ARMOR_BASE_DEFENSE = 0 + '''Base armor defense''' - EQUIPMENT_ARMOR_LEVEL_DEFENSE = 10 - '''Base equipment defense''' + EQUIPMENT_ARMOR_LEVEL_DEFENSE = 10 + '''Base equipment defense''' class Profession: - '''Profession Game System''' + '''Profession Game System''' - PROFESSION_SYSTEM_ENABLED = True - '''Game system flag''' + PROFESSION_SYSTEM_ENABLED = True + '''Game system flag''' - PROFESSION_TREE_CAPACITY = 1 - '''Maximum number of harvests before a tree 
tile decays''' + PROFESSION_TREE_CAPACITY = 1 + '''Maximum number of harvests before a tree tile decays''' - PROFESSION_TREE_RESPAWN = 0.105 - '''Probability that a harvested tree tile will regenerate each tick''' + PROFESSION_TREE_RESPAWN = 0.105 + '''Probability that a harvested tree tile will regenerate each tick''' - PROFESSION_ORE_CAPACITY = 1 - '''Maximum number of harvests before an ore tile decays''' + PROFESSION_ORE_CAPACITY = 1 + '''Maximum number of harvests before an ore tile decays''' - PROFESSION_ORE_RESPAWN = 0.10 - '''Probability that a harvested ore tile will regenerate each tick''' + PROFESSION_ORE_RESPAWN = 0.10 + '''Probability that a harvested ore tile will regenerate each tick''' - PROFESSION_CRYSTAL_CAPACITY = 1 - '''Maximum number of harvests before a crystal tile decays''' + PROFESSION_CRYSTAL_CAPACITY = 1 + '''Maximum number of harvests before a crystal tile decays''' - PROFESSION_CRYSTAL_RESPAWN = 0.10 - '''Probability that a harvested crystal tile will regenerate each tick''' + PROFESSION_CRYSTAL_RESPAWN = 0.10 + '''Probability that a harvested crystal tile will regenerate each tick''' - PROFESSION_HERB_CAPACITY = 1 - '''Maximum number of harvests before an herb tile decays''' + PROFESSION_HERB_CAPACITY = 1 + '''Maximum number of harvests before an herb tile decays''' - PROFESSION_HERB_RESPAWN = 0.01 - '''Probability that a harvested herb tile will regenerate each tick''' + PROFESSION_HERB_RESPAWN = 0.01 + '''Probability that a harvested herb tile will regenerate each tick''' - PROFESSION_FISH_CAPACITY = 1 - '''Maximum number of harvests before a fish tile decays''' + PROFESSION_FISH_CAPACITY = 1 + '''Maximum number of harvests before a fish tile decays''' - PROFESSION_FISH_RESPAWN = 0.01 - '''Probability that a harvested fish tile will regenerate each tick''' + PROFESSION_FISH_RESPAWN = 0.01 + '''Probability that a harvested fish tile will regenerate each tick''' - @staticmethod - def PROFESSION_CONSUMABLE_RESTORE(level): - return 50 + 
5*level + @staticmethod + def PROFESSION_CONSUMABLE_RESTORE(level): + return 50 + 5*level class Exchange: - '''Exchange Game System''' + '''Exchange Game System''' + + EXCHANGE_SYSTEM_ENABLED = True + '''Game system flag''' + + EXCHANGE_LISTING_DURATION = 5 + '''The number of ticks, during which the item is listed for sale''' + + @property + def MARKET_N_OBS(self): + # TODO(kywch): This is a hack. Check if the limit is reached + # pylint: disable=no-member + '''Number of distinct item observations''' + return self.PLAYER_N * self.EXCHANGE_LISTING_DURATION - EXCHANGE_SYSTEM_ENABLED = True - '''Game system flag''' + PRICE_N_OBS = 99 # make it different from PLAYER_N_OBS + '''Number of distinct price observations + This also determines the maximum price one can set for an item + ''' - @property - def EXCHANGE_N_OBS(self): - '''Number of distinct item observations''' - return self.ITEM_N * self.NPC_LEVEL_MAX class Communication: - '''Exchange Game System''' + '''Exchange Game System''' - COMMUNICATION_SYSTEM_ENABLED = True - '''Game system flag''' + COMMUNICATION_SYSTEM_ENABLED = True + '''Game system flag''' - @property - def COMMUNICATION_NUM_TOKENS(self): - '''Number of distinct item observations''' - return self.ITEM_N * self.NPC_LEVEL_MAX + # CHECK ME: When do we actually use this? 
+ COMMUNICATION_NUM_TOKENS = 50 + '''Number of distinct COMM tokens''' -class AllGameSystems(Terrain, Resource, Combat, NPC, Progression, Item, Equipment, Profession, Exchange, Communication): pass +class AllGameSystems( + Terrain, Resource, Combat, NPC, Progression, Item, + Equipment, Profession, Exchange, Communication): + pass ############################################################################ ### Config presets class Small(Config): - '''A small config for debugging and experiments with an expensive outer loop''' + '''A small config for debugging and experiments with an expensive outer loop''' - PATH_MAPS = 'maps/small' + PATH_MAPS = 'maps/small' - PLAYER_N = 64 - PLAYER_SPAWN_ATTEMPTS = 1 + PLAYER_N = 64 - MAP_PREVIEW_DOWNSCALE = 4 - MAP_CENTER = 32 + MAP_PREVIEW_DOWNSCALE = 4 + MAP_CENTER = 32 - TERRAIN_LOG_INTERPOLATE_MIN = 0 + TERRAIN_LOG_INTERPOLATE_MIN = 0 - NPC_N = 32 - NPC_LEVEL_MAX = 5 - NPC_LEVEL_SPREAD = 1 + NPC_N = 32 + NPC_LEVEL_MAX = 5 + NPC_LEVEL_SPREAD = 1 - PROGRESSION_SPAWN_CLUSTERS = 4 - PROGRESSION_SPAWN_UNIFORMS = 16 + PROGRESSION_SPAWN_CLUSTERS = 4 + PROGRESSION_SPAWN_UNIFORMS = 16 - HORIZON = 128 + HORIZON = 128 class Medium(Config): - '''A medium config suitable for most academic-scale research''' + '''A medium config suitable for most academic-scale research''' - PATH_MAPS = 'maps/medium' + PATH_MAPS = 'maps/medium' - PLAYER_N = 256 - PLAYER_SPAWN_ATTEMPTS = 2 + PLAYER_N = 128 - MAP_PREVIEW_DOWNSCALE = 16 - MAP_CENTER = 128 + MAP_PREVIEW_DOWNSCALE = 16 + MAP_CENTER = 128 - NPC_N = 128 - NPC_LEVEL_MAX = 10 - NPC_LEVEL_SPREAD = 1 + NPC_N = 128 + NPC_LEVEL_MAX = 10 + NPC_LEVEL_SPREAD = 1 - PROGRESSION_SPAWN_CLUSTERS = 64 - PROGRESSION_SPAWN_UNIFORMS = 256 + PROGRESSION_SPAWN_CLUSTERS = 64 + PROGRESSION_SPAWN_UNIFORMS = 256 - HORIZON = 1024 + HORIZON = 1024 class Large(Config): - '''A large config suitable for large-scale research or fast models''' + '''A large config suitable for large-scale research or fast models''' - PATH_MAPS = 
'maps/large' + PATH_MAPS = 'maps/large' - PLAYER_N = 2048 - PLAYER_SPAWN_ATTEMPTS = 16 + PLAYER_N = 1024 - MAP_PREVIEW_DOWNSCALE = 64 - MAP_CENTER = 1024 + MAP_PREVIEW_DOWNSCALE = 64 + MAP_CENTER = 1024 - NPC_N = 1024 - NPC_LEVEL_MAX = 15 - NPC_LEVEL_SPREAD = 3 + NPC_N = 1024 + NPC_LEVEL_MAX = 15 + NPC_LEVEL_SPREAD = 3 - PROGRESSION_SPAWN_CLUSTERS = 1024 - PROGRESSION_SPAWN_UNIFORMS = 4096 + PROGRESSION_SPAWN_CLUSTERS = 1024 + PROGRESSION_SPAWN_UNIFORMS = 4096 - HORIZON = 8192 + HORIZON = 8192 -class Default(Medium, AllGameSystems): pass +class Default(Medium, AllGameSystems): + pass diff --git a/nmmo/core/env.py b/nmmo/core/env.py index 19cfff797..d7f0b4491 100644 --- a/nmmo/core/env.py +++ b/nmmo/core/env.py @@ -1,822 +1,424 @@ -from pdb import set_trace as T -import numpy as np -import random - import functools -from collections import defaultdict +import random +from typing import Any, Dict, List +from ordered_set import OrderedSet import gym -from pettingzoo import ParallelEnv - -import json -import lzma +import numpy as np +from pettingzoo.utils.env import AgentID, ParallelEnv import nmmo -from nmmo import entity, core, emulation -from nmmo.core import terrain -from nmmo.lib import log -from nmmo.infrastructure import DataType -from nmmo.systems import item as Item +from nmmo.core.config import Default +from nmmo.core.observation import Observation +from nmmo.core.tile import Tile +from nmmo.entity.entity import Entity +from nmmo.systems.item import Item +from nmmo.core import realm + +from scripted.baselines import Scripted -class Replay: - def __init__(self, config): - self.packets = [] - self.map = None +class Env(ParallelEnv): + # Environment wrapper for Neural MMO using the Parallel PettingZoo API - if config is not None: - self.path = config.SAVE_REPLAY + '.lzma' + def __init__(self, + config: Default = nmmo.config.Default(), seed=None): + self._init_random(seed) - self._i = 0 + super().__init__() - def update(self, packet): - data = {} - for key, val 
in packet.items(): - if key == 'environment': - self.map = val - continue - if key == 'config': - continue + self.config = config + self.realm = realm.Realm(config) + self.obs = None - data[key] = val + self.possible_agents = list(range(1, config.PLAYER_N + 1)) + self._dead_agents = OrderedSet() + self.scripted_agents = OrderedSet() - self.packets.append(data) + # pylint: disable=method-cache-max-size-none + @functools.lru_cache(maxsize=None) + def observation_space(self, agent: int): + '''Neural MMO Observation Space - def save(self): - print(f'Saving replay to {self.path} ...') + Args: + agent: Agent ID - data = { - 'map': self.map, - 'packets': self.packets} + Returns: + observation: gym.spaces object contained the structured observation + for the specified agent. Each visible object is represented by + continuous and discrete vectors of attributes. A 2-layer attentional + encoder can be used to convert this structured observation into + a flat vector embedding.''' - data = json.dumps(data).encode('utf8') - data = lzma.compress(data, format=lzma.FORMAT_ALONE) - with open(self.path, 'wb') as out: - out.write(data) + def box(rows, cols): + return gym.spaces.Box( + low=-2**20, high=2**20, + shape=(rows, cols), + dtype=np.float32) - @classmethod - def load(cls, path): - with open(path, 'rb') as fp: - data = fp.read() + obs_space = { + "Tick": gym.spaces.Discrete(1), + "AgentId": gym.spaces.Discrete(1), + "Tile": box(self.config.MAP_N_OBS, Tile.State.num_attributes), + "Entity": box(self.config.PLAYER_N_OBS, Entity.State.num_attributes) + } + + if self.config.ITEM_SYSTEM_ENABLED: + obs_space["Inventory"] = box(self.config.INVENTORY_N_OBS, Item.State.num_attributes) - data = lzma.decompress(data, format=lzma.FORMAT_ALONE) - data = json.loads(data.decode('utf-8')) + if self.config.EXCHANGE_SYSTEM_ENABLED: + obs_space["Market"] = box(self.config.MARKET_N_OBS, Item.State.num_attributes) - replay = Replay(None) - replay.map = data['map'] - replay.packets = data['packets'] 
- return replay + if self.config.PROVIDE_ACTION_TARGETS: + obs_space['ActionTargets'] = self.action_space(None) - def render(self): - from nmmo.websocket import Application - client = Application(realm=None) - for packet in self: - client.update(packet) + return gym.spaces.Dict(obs_space) - def __iter__(self): - self._i = 0 - return self + def _init_random(self, seed): + if seed is not None: + np.random.seed(seed) + random.seed(seed) - def __next__(self): - if self._i >= len(self.packets): - raise StopIteration - packet = self.packets[self._i] - packet['environment'] = self.map - self._i += 1 - return packet + @functools.lru_cache(maxsize=None) + def action_space(self, agent): + '''Neural MMO Action Space + Args: + agent: Agent ID -class Env(ParallelEnv): - '''Environment wrapper for Neural MMO using the Parallel PettingZoo API - - Neural MMO provides complex environments featuring structured observations/actions, - variably sized agent populations, and long time horizons. Usage in conjunction - with RLlib as demonstrated in the /projekt wrapper is highly recommended.''' - - metadata = {'render.modes': ['human'], 'name': 'neural-mmo'} - - def __init__(self, config=None, seed=None): - ''' - Args: - config : A forge.blade.core.Config object or subclass object - ''' - if seed is not None: - np.random.seed(seed) - random.seed(seed) - - super().__init__() - - if config is None: - config = nmmo.config.Default() - - assert isinstance(config, nmmo.config.Config), f'Config {config} is not a config instance (did you pass the class?)' - - if not config.PLAYERS: - from nmmo import agent - config.PLAYERS = [agent.Random] - - if not config.MAP_GENERATOR: - config.MAP_GENERATOR = terrain.MapGenerator - - self.realm = core.Realm(config) - self.registry = nmmo.OverlayRegistry(config, self) - - self.config = config - self.overlay = None - self.overlayPos = [256, 256] - self.client = None - self.obs = None - - self.has_reset = False - - # Populate dummy ob - self.dummy_ob = None - 
self.observation_space(0) - - if self.config.SAVE_REPLAY: - self.replay = Replay(config) - - if config.EMULATE_CONST_PLAYER_N: - self.possible_agents = [i for i in range(1, config.PLAYER_N + 1)] - - # Flat index actions - if config.EMULATE_FLAT_ATN: - self.flat_actions = emulation.pack_atn_space(config) - - @functools.lru_cache(maxsize=None) - def observation_space(self, agent: int): - '''Neural MMO Observation Space - - Args: - agent: Agent ID - - Returns: - observation: gym.spaces object contained the structured observation - for the specified agent. Each visible object is represented by - continuous and discrete vectors of attributes. A 2-layer attentional - encoder can be used to convert this structured observation into - a flat vector embedding.''' - - observation = {} - for entity in sorted(nmmo.Serialized.values()): - if not entity.enabled(self.config): - continue - - rows = entity.N(self.config) - continuous = 0 - discrete = 0 - - for _, attr in entity: - if attr.DISCRETE: - discrete += 1 - if attr.CONTINUOUS: - continuous += 1 - - name = entity.__name__ - observation[name] = { - 'Continuous': gym.spaces.Box( - low=-2**20, high=2**20, - shape=(rows, continuous), - dtype=DataType.CONTINUOUS), - 'Discrete' : gym.spaces.Box( - low=0, high=4096, - shape=(rows, discrete), - dtype=DataType.DISCRETE)} - - #TODO: Find a way to automate this - if name == 'Entity': - observation['Entity']['N'] = gym.spaces.Box( - low=0, high=self.config.PLAYER_N_OBS, - shape=(1,), dtype=DataType.DISCRETE) - elif name == 'Tile': - observation['Tile']['N'] = gym.spaces.Box( - low=0, high=self.config.PLAYER_VISION_DIAMETER, - shape=(1,), dtype=DataType.DISCRETE) - elif name == 'Item': - observation['Item']['N'] = gym.spaces.Box(low=0, high=self.config.ITEM_N_OBS, shape=(1,), dtype=DataType.DISCRETE) - elif name == 'Market': - observation['Market']['N'] = gym.spaces.Box(low=0, high=self.config.EXCHANGE_N_OBS, shape=(1,), dtype=DataType.DISCRETE) - - observation[name] = 
gym.spaces.Dict(observation[name]) - - observation = gym.spaces.Dict(observation) - - if not self.dummy_ob: - self.dummy_ob = observation.sample() - for ent_key, ent_val in self.dummy_ob.items(): - for attr_key, attr_val in ent_val.items(): - self.dummy_ob[ent_key][attr_key] *= 0 - - - if not self.config.EMULATE_FLAT_OBS: - return observation - - return emulation.pack_obs_space(observation) - - @functools.lru_cache(maxsize=None) - def action_space(self, agent): - '''Neural MMO Action Space - - Args: - agent: Agent ID - - Returns: - actions: gym.spaces object contained the structured actions - for the specified agent. Each action is parameterized by a list - of discrete-valued arguments. These consist of both fixed, k-way - choices (such as movement direction) and selections from the - observation space (such as targeting)''' - - if self.config.EMULATE_FLAT_ATN: - lens = [] - for atn in nmmo.Action.edges(self.config): - for arg in atn.edges: - lens.append(arg.N(self.config)) - return gym.spaces.MultiDiscrete(lens) - #return gym.spaces.Discrete(len(self.flat_actions)) - - actions = {} - for atn in sorted(nmmo.Action.edges(self.config)): - actions[atn] = {} - for arg in sorted(atn.edges): - n = arg.N(self.config) - actions[atn][arg] = gym.spaces.Discrete(n) - - actions[atn] = gym.spaces.Dict(actions[atn]) - - return gym.spaces.Dict(actions) - - ############################################################################ - ### Core API - def reset(self, idx=None, step=True): - '''OpenAI Gym API reset function - - Loads a new game map and returns initial observations - - Args: - idx: Map index to load. Selects a random map by default - - step: Whether to step the environment and return initial obs - - Returns: - obs: Initial obs if step=True, None otherwise - - Notes: - Neural MMO simulates a persistent world. Ideally, you should reset - the environment only once, upon creation. 
In practice, this approach - limits the number of parallel environment simulations to the number - of CPU cores available. At small and medium hardware scale, we - therefore recommend the standard approach of resetting after a long - but finite horizon: ~1000 timesteps for small maps and - 5000+ timesteps for large maps - - Returns: - observations, as documented by step() - ''' - self.has_reset = True - - self.actions = {} - self.dead = [] - - if idx is None: - idx = np.random.randint(self.config.MAP_N) + 1 - - self.worldIdx = idx - self.realm.reset(idx) - - # Set up logs - self.register_logs() - - if step: - self.obs, _, _, _ = self.step({}) - - return self.obs - - def close(self): - '''For conformity with the PettingZoo API only; rendering is external''' - pass - - def step(self, actions): - '''Simulates one game tick or timestep - - Args: - actions: A dictionary of agent decisions of format:: - - { - agent_1: { - action_1: [arg_1, arg_2], - action_2: [...], - ... - }, - agent_2: { - ... - }, - ... - } - - Where agent_i is the integer index of the i\'th agent - - The environment only evaluates provided actions for provided - agents. Unprovided action types are interpreted as no-ops and - illegal actions are ignored - - It is also possible to specify invalid combinations of valid - actions, such as two movements or two attacks. In this case, - one will be selected arbitrarily from each incompatible sets. - - A well-formed algorithm should do none of the above. We only - Perform this conditional processing to make batched action - computation easier. - - Returns: - (dict, dict, dict, None): - - observations: - A dictionary of agent observations of format:: - - { - agent_1: obs_1, - agent_2: obs_2, - ... - } - - Where agent_i is the integer index of the i\'th agent and - obs_i is specified by the observation_space function. - - rewards: - A dictionary of agent rewards of format:: - - { - agent_1: reward_1, - agent_2: reward_2, - ... 
- } - - Where agent_i is the integer index of the i\'th agent and - reward_i is the reward of the i\'th' agent. - - By default, agents receive -1 reward for dying and 0 reward for - all other circumstances. Override Env.reward to specify - custom reward functions - - dones: - A dictionary of agent done booleans of format:: - - { - agent_1: done_1, - agent_2: done_2, - ... - } - - Where agent_i is the integer index of the i\'th agent and - done_i is a boolean denoting whether the i\'th agent has died. - - Note that obs_i will be a garbage placeholder if done_i is true. - This is provided only for conformity with PettingZoo. Your - algorithm should not attempt to leverage observations outside of - trajectory bounds. You can omit garbage obs_i values by setting - omitDead=True. - - infos: - A dictionary of agent infos of format: - - { - agent_1: None, - agent_2: None, - ... - } - - Provided for conformity with PettingZoo - ''' - assert self.has_reset, 'step before reset' - - if self.config.RENDER or self.config.SAVE_REPLAY: - packet = { - 'config': self.config, - 'pos': self.overlayPos, - 'wilderness': 0 - } - - packet = {**self.realm.packet(), **packet} - - if self.overlay is not None: - packet['overlay'] = self.overlay - self.overlay = None - - self.packet = packet - - if self.config.SAVE_REPLAY: - self.replay.update(packet) - - #Preprocess actions for neural models - for entID in list(actions.keys()): - #TODO: Should this silently fail? Warning level options? 
- if entID not in self.realm.players: - continue - - ent = self.realm.players[entID] - - # Fix later -- don't allow action inputs for scripted agents - if ent.agent.scripted: - continue - - if not ent.alive: - continue - - if self.config.EMULATE_FLAT_ATN: - ent_action = {} - idx = 0 - for atn in nmmo.Action.edges(self.config): - ent_action[atn] = {} - for arg in atn.edges: - ent_action[atn][arg] = actions[entID][idx] - idx += 1 - actions[entID] = ent_action - - self.actions[entID] = {} - for atn, args in actions[entID].items(): - self.actions[entID][atn] = {} - drop = False - for arg, val in args.items(): - if arg.argType == nmmo.action.Fixed: - self.actions[entID][atn][arg] = arg.edges[val] - elif arg == nmmo.action.Target: - if val >= len(ent.targets): - drop = True - continue - targ = ent.targets[val] - self.actions[entID][atn][arg] = self.realm.entity(targ) - elif atn in (nmmo.action.Sell, nmmo.action.Use, nmmo.action.Give) and arg == nmmo.action.Item: - if val >= len(ent.inventory.dataframeKeys): - drop = True - continue - itm = [e for e in ent.inventory._item_references][val] - if type(itm) == Item.Gold: - drop = True - continue - self.actions[entID][atn][arg] = itm - elif atn == nmmo.action.Buy and arg == nmmo.action.Item: - if val >= len(self.realm.exchange.dataframeKeys): - drop = True - continue - itm = self.realm.exchange.dataframeVals[val] - self.actions[entID][atn][arg] = itm - elif __debug__: #Fix -inf in classifier and assert err on bad atns - assert False, f'Argument {arg} invalid for action {atn}' - - # Cull actions with bad args - if drop and atn in self.actions[entID]: - del self.actions[entID][atn] - - #Step: Realm, Observations, Logs - self.dead = self.realm.step(self.actions) - self.actions = {} - self.obs = {} - infos = {} - - obs, rewards, dones, self.raw = {}, {}, {}, {} - for entID, ent in self.realm.players.items(): - ob = self.realm.dataframe.get(ent) - self.obs[entID] = ob - if ent.agent.scripted: - atns = ent.agent(ob) - for atn, args 
in atns.items(): - for arg, val in args.items(): - atns[atn][arg] = arg.deserialize(self.realm, ent, val) - self.actions[entID] = atns - - else: - obs[entID] = ob - rewards[entID], infos[entID] = self.reward(ent) - dones[entID] = False - - self.log_env() - for entID, ent in self.dead.items(): - self.log_player(ent) - - self.realm.exchange.step() - - for entID, ent in self.dead.items(): - if ent.agent.scripted: - continue - rewards[ent.entID], infos[ent.entID] = self.reward(ent) - - dones[ent.entID] = False #TODO: Is this correct behavior? - if not self.config.EMULATE_CONST_HORIZON and not self.config.RESPAWN: - dones[ent.entID] = True - - obs[ent.entID] = self.dummy_ob - - if self.config.EMULATE_CONST_PLAYER_N: - emulation.pad_const_nent(self.config, self.dummy_ob, obs, rewards, dones, infos) - - if self.config.EMULATE_FLAT_OBS: - obs = nmmo.emulation.pack_obs(obs) - - if self.config.EMULATE_CONST_HORIZON: - assert self.realm.tick <= self.config.HORIZON - if self.realm.tick == self.config.HORIZON: - emulation.const_horizon(dones) - - if not len(self.realm.players.items()): - emulation.const_horizon(dones) - - #Pettingzoo API - self.agents = list(self.realm.players.keys()) - - self.obs = obs - return obs, rewards, dones, infos - - ############################################################################ - ### Logging - def max(self, fn): - return max(fn(player) for player in self.realm.players.values()) - - def max_held(self, policy): - lvls = [player.equipment.held.level.val for player in self.realm.players.values() - if player.equipment.held is not None and player.policy == policy] - - if len(lvls) == 0: - return 0 - - return max(lvls) - - def max_item(self, policy): - lvls = [player.equipment.item_level for player in self.realm.players.values() if player.policy == policy] - - if len(lvls) == 0: - return 0 - - return max(lvls) - - def log_env(self) -> None: - '''Logs player data upon death - - This function is called automatically once per environment step - to 
compute summary stats. You should not call it manually. - Instead, override this method to customize logging. - ''' - - # This fn more or less repeats log_player once per tick - # It was added to support eval-time logging - # It needs to be redone to not duplicate player logging and - # also not slow down training - if not self.config.LOG_ENV: - return - - quill = self.realm.quill - - if len(self.realm.players) == 0: - return - - #Aggregate logs across env - for key, fn in quill.shared.items(): - dat = defaultdict(list) - for _, player in self.realm.players.items(): - name = player.agent.policy - dat[name].append(fn(player)) - for policy, vals in dat.items(): - quill.log_env(f'{key}_{policy}', float(np.mean(vals))) - - if self.config.EXCHANGE_SYSTEM_ENABLED: - for item in nmmo.systems.item.ItemID.item_ids: - for level in range(1, 11): - name = item.__name__ - key = (item, level) - if key in self.realm.exchange.item_listings: - listing = self.realm.exchange.item_listings[key] - quill.log_env(f'Market/{name}-{level}_Price', listing.price if listing.price else 0) - quill.log_env(f'Market/{name}-{level}_Volume', listing.volume if listing.volume else 0) - quill.log_env(f'Market/{name}-{level}_Supply', listing.supply if listing.supply else 0) - else: - quill.log_env(f'Market/{name}-{level}_Price', 0) - quill.log_env(f'Market/{name}-{level}_Volume', 0) - quill.log_env(f'Market/{name}-{level}_Supply', 0) - - def register_logs(self): - config = self.config - quill = self.realm.quill - - quill.register('Basic/Lifetime', lambda player: player.history.timeAlive.val) - - if config.TASKS: - quill.register('Task/Completed', lambda player: player.diary.completed) - quill.register('Task/Reward' , lambda player: player.diary.cumulative_reward) - - else: - quill.register('Task/Completed', lambda player: player.history.timeAlive.val) - - # Skills - if config.PROGRESSION_SYSTEM_ENABLED: - if config.COMBAT_SYSTEM_ENABLED: - quill.register('Skill/Mage', lambda player: 
player.skills.mage.level.val) - quill.register('Skill/Range', lambda player: player.skills.range.level.val) - quill.register('Skill/Melee', lambda player: player.skills.melee.level.val) - if config.PROFESSION_SYSTEM_ENABLED: - quill.register('Skill/Fishing', lambda player: player.skills.fishing.level.val) - quill.register('Skill/Herbalism', lambda player: player.skills.herbalism.level.val) - quill.register('Skill/Prospecting', lambda player: player.skills.prospecting.level.val) - quill.register('Skill/Carving', lambda player: player.skills.carving.level.val) - quill.register('Skill/Alchemy', lambda player: player.skills.alchemy.level.val) - if config.EQUIPMENT_SYSTEM_ENABLED: - quill.register('Item/Held-Level', lambda player: player.inventory.equipment.held.level.val if player.inventory.equipment.held else 0) - quill.register('Item/Equipment-Total', lambda player: player.equipment.total(lambda e: e.level)) - - if config.EXCHANGE_SYSTEM_ENABLED: - quill.register('Item/Wealth', lambda player: player.inventory.gold.quantity.val) - - # Item usage - if config.PROFESSION_SYSTEM_ENABLED: - quill.register('Item/Ration-Consumed', lambda player: player.ration_consumed) - quill.register('Item/Poultice-Consumed', lambda player: player.poultice_consumed) - quill.register('Item/Ration-Level', lambda player: player.ration_level_consumed) - quill.register('Item/Poultice-Level', lambda player: player.poultice_level_consumed) - - # Market - if config.EXCHANGE_SYSTEM_ENABLED: - quill.register('Exchange/Player-Sells', lambda player: player.sells) - quill.register('Exchange/Player-Buys', lambda player: player.buys) - - - def log_player(self, player) -> None: - '''Logs player data upon death - - This function is called automatically when an agent dies - to compute summary stats. You should not call it manually. - Instead, override this method to customize logging. 
- - Args: - player: An agent - ''' - - name = player.agent.policy - config = self.config - quill = self.realm.quill - policy = player.policy - - for key, fn in quill.shared.items(): - quill.log_player(f'{key}_{policy}', fn(player)) - - # Duplicated task reward with/without name for SR calc - if player.diary: - if player.agent.scripted: - player.diary.update(self.realm, player) - - quill.log_player(f'Task_Reward', player.diary.cumulative_reward) - - for achievement in player.diary.achievements: - quill.log_player(achievement.name, float(achievement.completed)) - else: - quill.log_player(f'Task_Reward', player.history.timeAlive.val) - - # Used for SR - quill.log_player('PolicyID', player.agent.policyID) - if player.diary: - quill.log_player(f'Task_Reward', player.diary.cumulative_reward) - - def terminal(self): - '''Logs currently alive agents and returns all collected logs - - Automatic log calls occur only when agents die. To evaluate agent - performance over a fixed horizon, you will need to include logs for - agents that are still alive at the end of that horizon. This function - performs that logging and returns the associated a data structure - containing logs for the entire evaluation - - Args: - ent: An agent - - Returns: - Log datastructure - ''' - - for entID, ent in self.realm.players.entities.items(): - self.log_player(ent) - - if self.config.SAVE_REPLAY: - self.replay.save() - - return self.realm.quill.packet - - ############################################################################ - ### Override hooks - def reward(self, player): - '''Computes the reward for the specified agent - - Override this method to create custom reward functions. You have full - access to the environment state via self.realm. Our baselines do not - modify this method; specify any changes when comparing to baselines - - Args: - player: player object - - Returns: - reward: - The reward for the actions on the previous timestep of the - entity identified by entID. 
- ''' - info = {'population': player.pop} - - if player.entID not in self.realm.players: - return -1, info - - if not player.diary: - return 0, info - - achievement_rewards = player.diary.update(self.realm, player) - reward = sum(achievement_rewards.values()) - - info = {**info, **achievement_rewards} - return reward, info - - - ############################################################################ - ### Client data - def render(self, mode='human') -> None: - '''Data packet used by the renderer - - Returns: - packet: A packet of data for the client - ''' - - assert self.has_reset, 'render before reset' - packet = self.packet - - if not self.client: - from nmmo.websocket import Application - self.client = Application(self) - - pos, cmd = self.client.update(packet) - self.registry.step(self.obs, pos, cmd) - - def register(self, overlay) -> None: - '''Register an overlay to be sent to the client - - The intended use of this function is: User types overlay -> - client sends cmd to server -> server computes overlay update -> - register(overlay) -> overlay is sent to client -> overlay rendered - - Args: - values: A map-sized (self.size) array of floating point values - ''' - err = 'overlay must be a numpy array of dimension (*(env.size), 3)' - assert type(overlay) == np.ndarray, err - self.overlay = overlay.tolist() - - def dense(self): - '''Simulates an agent on every tile and returns observations - - This method is used to compute per-tile visualizations across the - entire map simultaneously. To do so, we spawn agents on each tile - one at a time. We compute the observation for each agent, delete that - agent, and go on to the next one. In this fashion, each agent receives - an observation where it is the only agent alive. This allows us to - isolate potential influences from observations of nearby agents - - This function is slow, and anything you do with it is probably slower. 
- As a concrete example, consider that we would like to visualize a - learned agent value function for the entire map. This would require - computing a forward pass for one agent per tile. To cut down on - computation costs, we omit lava tiles from this method - - Returns: - (dict, dict): - - observations: - A dictionary of agent observations as specified by step() - - ents: - A corresponding dictionary of agents keyed by their entID - ''' - config = self.config - R, C = self.realm.map.tiles.shape - - entID = 100000 - pop = 0 - name = "Value" - color = (255, 255, 255) - - - observations, ents = {}, {} - for r in range(R): - for c in range(C): - tile = self.realm.map.tiles[r, c] - if not tile.habitable: - continue - - current = tile.ents - n = len(current) - if n == 0: - ent = entity.Player(self.realm, (r, c), entID, pop, name, color) - else: - ent = list(current.values())[0] - - obs = self.realm.dataframe.get(ent) - if n == 0: - self.realm.dataframe.remove(nmmo.Serialized.Entity, entID, ent.pos) - - observations[entID] = obs - ents[entID] = ent - entID += 1 - - return observations, ents + Returns: + actions: gym.spaces object contained the structured actions + for the specified agent. Each action is parameterized by a list + of discrete-valued arguments. 
These consist of both fixed, k-way + choices (such as movement direction) and selections from the + observation space (such as targeting)''' + + actions = {} + for atn in sorted(nmmo.Action.edges(self.config)): + if atn.enabled(self.config): + + actions[atn] = {} + for arg in sorted(atn.edges): + n = arg.N(self.config) + actions[atn][arg] = gym.spaces.Discrete(n) + + actions[atn] = gym.spaces.Dict(actions[atn]) + + return gym.spaces.Dict(actions) + + ############################################################################ + # Core API + + # TODO: This doesn't conform to the PettingZoo API + # pylint: disable=arguments-renamed + def reset(self, map_id=None, seed=None, options=None): + '''OpenAI Gym API reset function + + Loads a new game map and returns initial observations + + Args: + idx: Map index to load. Selects a random map by default + + + Returns: + observations, as documented by _compute_observations() + + Notes: + Neural MMO simulates a persistent world. Ideally, you should reset + the environment only once, upon creation. In practice, this approach + limits the number of parallel environment simulations to the number + of CPU cores available. At small and medium hardware scale, we + therefore recommend the standard approach of resetting after a long + but finite horizon: ~1000 timesteps for small maps and + 5000+ timesteps for large maps + ''' + + self._init_random(seed) + self.realm.reset(map_id) + self._dead_agents = OrderedSet() + + # check if there are scripted agents + for eid, ent in self.realm.players.items(): + if isinstance(ent.agent, Scripted): + self.scripted_agents.add(eid) + + self.obs = self._compute_observations() + + return {a: o.to_gym() for a,o in self.obs.items()} + + def step(self, actions: Dict[int, Dict[str, Dict[str, Any]]]): + '''Simulates one game tick or timestep + + Args: + actions: A dictionary of agent decisions of format:: + + { + agent_1: { + action_1: [arg_1, arg_2], + action_2: [...], + ... + }, + agent_2: { + ... 
+ }, + ... + } + + Where agent_i is the integer index of the i\'th agent + + The environment only evaluates provided actions for provided + gents. Unprovided action types are interpreted as no-ops and + illegal actions are ignored + + It is also possible to specify invalid combinations of valid + actions, such as two movements or two attacks. In this case, + one will be selected arbitrarily from each incompatible sets. + + A well-formed algorithm should do none of the above. We only + Perform this conditional processing to make batched action + computation easier. + + Returns: + (dict, dict, dict, None): + + observations: + A dictionary of agent observations of format:: + + { + agent_1: obs_1, + agent_2: obs_2, + ... + } + + Where agent_i is the integer index of the i\'th agent and + obs_i is specified by the observation_space function. + + rewards: + A dictionary of agent rewards of format:: + + { + agent_1: reward_1, + agent_2: reward_2, + ... + } + + Where agent_i is the integer index of the i\'th agent and + reward_i is the reward of the i\'th' agent. + + By default, agents receive -1 reward for dying and 0 reward for + all other circumstances. Override Env.reward to specify + custom reward functions + + dones: + A dictionary of agent done booleans of format:: + + { + agent_1: done_1, + agent_2: done_2, + ... + } + + Where agent_i is the integer index of the i\'th agent and + done_i is a boolean denoting whether the i\'th agent has died. + + Note that obs_i will be a garbage placeholder if done_i is true. + This is provided only for conformity with PettingZoo. Your + algorithm should not attempt to leverage observations outside of + trajectory bounds. You can omit garbage obs_i values by setting + omitDead=True. + + infos: + A dictionary of agent infos of format: + + { + agent_1: None, + agent_2: None, + ... 
+ } + + Provided for conformity with PettingZoo + ''' + assert self.obs is not None, 'step() called before reset' + + # Add in scripted agents' actions, if any + if self.scripted_agents: + actions = self._compute_scripted_agent_actions(actions) + + # Drop invalid actions of BOTH neural and scripted agents + # we don't need _deserialize_scripted_actions() anymore + actions = self._validate_actions(actions) + + # Execute actions + self.realm.step(actions) + + dones = {} + for eid in self.possible_agents: + if eid not in self._dead_agents and ( + eid not in self.realm.players or + self.realm.tick >= self.config.HORIZON): + + self._dead_agents.add(eid) + dones[eid] = True + + # Store the observations, since actions reference them + self.obs = self._compute_observations() + gym_obs = {a: o.to_gym() for a,o in self.obs.items()} + + rewards, infos = self._compute_rewards(self.obs.keys(), dones) + + return gym_obs, rewards, dones, infos + + def _validate_actions(self, actions: Dict[int, Dict[str, Dict[str, Any]]]): + '''Deserialize action arg values and validate actions + For now, it does a basic validation (e.g., value is not none). 
+ + TODO(kywch): add sophisticated validation like use/sell/give on the same item + ''' + validated_actions = {} + + for ent_id, atns in actions.items(): + if ent_id not in self.realm.players: + #assert ent_id in self.realm.players, f'Entity {ent_id} not in realm' + continue # Entity not in the realm -- invalid actions + + entity = self.realm.players[ent_id] + if not entity.alive: + #assert entity.alive, f'Entity {ent_id} is dead' + continue # Entity is dead -- invalid actions + + validated_actions[ent_id] = {} + + for atn, args in sorted(atns.items()): + action_valid = True + deserialized_action = {} + + if not atn.enabled(self.config): + action_valid = False + break + + for arg, val in sorted(args.items()): + obj = arg.deserialize(self.realm, entity, val) + if obj is None: + action_valid = False + break + deserialized_action[arg] = obj + + if action_valid: + validated_actions[ent_id][atn] = deserialized_action + + return validated_actions + + def _compute_scripted_agent_actions(self, actions: Dict[int, Dict[str, Dict[str, Any]]]): + '''Compute actions for scripted agents and add them into the action dict''' + for eid in self.scripted_agents: + # remove the dead scripted agent from the list + if eid not in self.realm.players: + self.scripted_agents.discard(eid) + continue + + # override the provided scripted agents' actions + actions[eid] = self.realm.players[eid].agent(self.obs[eid]) + + return actions + + def _compute_observations(self): + '''Neural MMO Observation API + + Args: + agents: List of agents to return observations for. 
If None, returns + observations for all agents + + Returns: + obs: Dictionary of observations for each agent + obs[agent_id] = { + "Entity": [e1, e2, ...], + "Tile": [t1, t2, ...], + "Inventory": [i1, i2, ...], + "Market": [m1, m2, ...], + "ActionTargets": { + "Attack": [a1, a2, ...], + "Sell": [s1, s2, ...], + "Buy": [b1, b2, ...], + "Move": [m1, m2, ...], + } + ''' + + obs = {} + + market = Item.Query.for_sale(self.realm.datastore) + + for agent in self.realm.players.values(): + agent_id = agent.id.val + agent_r = agent.row.val + agent_c = agent.col.val + + visible_entities = Entity.Query.window( + self.realm.datastore, + agent_r, agent_c, + self.config.PLAYER_VISION_RADIUS + ) + visible_tiles = Tile.Query.window( + self.realm.datastore, + agent_r, agent_c, + self.config.PLAYER_VISION_RADIUS) + + inventory = Item.Query.owned_by(self.realm.datastore, agent_id) + + obs[agent_id] = Observation( + self.config, self.realm.tick, + agent_id, visible_tiles, visible_entities, inventory, market) + + return obs + + def _compute_rewards(self, agents: List[AgentID], dones: Dict[AgentID, bool]): + '''Computes the reward for the specified agent + + Override this method to create custom reward functions. You have full + access to the environment state via self.realm. Our baselines do not + modify this method; specify any changes when comparing to baselines + + Args: + player: player object + + Returns: + reward: + The reward for the actions on the previous timestep of the + entity identified by ent_id. 
+ ''' + infos = {} + rewards = { eid: -1 for eid in dones } + + for agent_id in agents: + infos[agent_id] = {} + agent = self.realm.players.get(agent_id) + assert agent is not None, f'Agent {agent_id} not found' + + if agent.diary is not None: + rewards[agent_id] = sum(agent.diary.rewards.values()) + infos[agent_id].update(agent.diary.rewards) + + return rewards, infos + + + ############################################################################ + # PettingZoo API + ############################################################################ + + def render(self, mode='human'): + '''For conformity with the PettingZoo API only; rendering is external''' + + @property + def agents(self) -> List[AgentID]: + '''For conformity with the PettingZoo API only; rendering is external''' + return list(self.realm.players.keys()) + + def close(self): + '''For conformity with the PettingZoo API only; rendering is external''' + + def seed(self, seed=None): + return self._init_random(seed) + + def state(self) -> np.ndarray: + raise NotImplementedError + + metadata = {'render.modes': ['human'], 'name': 'neural-mmo'} diff --git a/nmmo/core/log_helper.py b/nmmo/core/log_helper.py new file mode 100644 index 000000000..bb118f049 --- /dev/null +++ b/nmmo/core/log_helper.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +from typing import Dict + +from nmmo.core.agent import Agent +from nmmo.entity.player import Player +from nmmo.lib.log import Logger, MilestoneLogger + + +class LogHelper: + @staticmethod + def create(realm) -> LogHelper: + if realm.config.LOG_ENV: + return SimpleLogHelper(realm) + return DummyLogHelper() + +class DummyLogHelper(LogHelper): + def reset(self) -> None: + pass + + def update(self, dead_players: Dict[int, Player]) -> None: + pass + + def log_milestone(self, milestone: str, value: float) -> None: + pass + + def log_event(self, event: str, value: float) -> None: + pass + +class SimpleLogHelper(LogHelper): + def __init__(self, realm) -> None: + 
self.realm = realm + self.config = realm.config + + self.reset() + + def reset(self): + self._env_logger = Logger() + self._player_logger = Logger() + self._event_logger = DummyLogHelper() + self._milestone_logger = DummyLogHelper() + + if self.config.LOG_EVENTS: + self._event_logger = Logger() + + if self.config.LOG_MILESTONES: + self._milestone_logger = MilestoneLogger(self.config.LOG_FILE) + + self._player_stats_funcs = {} + self._register_player_stats() + + def log_milestone(self, milestone: str, value: float) -> None: + if self.config.LOG_MILESTONES: + self._milestone_logger.log(milestone, value) + + def log_event(self, event: str, value: float) -> None: + if self.config.LOG_EVENTS: + self._event_logger.log(event, value) + + @property + def packet(self): + packet = {'Env': self._env_logger.stats, + 'Player': self._player_logger.stats} + + if self.config.LOG_EVENTS: + packet['Event'] = self._event_logger.stats + else: + packet['Event'] = 'Unavailable: config.LOG_EVENTS = False' + + if self.config.LOG_MILESTONES: + packet['Milestone'] = self._event_logger.stats + else: + packet['Milestone'] = 'Unavailable: config.LOG_MILESTONES = False' + + return packet + + def _register_player_stat(self, name: str, func: callable): + assert name not in self._player_stats_funcs + self._player_stats_funcs[name] = func + + def _register_player_stats(self): + self._register_player_stat('Basic/TimeAlive', lambda player: player.history.time_alive.val) + + if self.config.TASKS: + self._register_player_stat('Task/Completed', lambda player: player.diary.completed) + self._register_player_stat('Task/Reward' , lambda player: player.diary.cumulative_reward) + else: + self._register_player_stat('Task/Completed', lambda player: player.history.time_alive.val) + + # Skills + if self.config.PROGRESSION_SYSTEM_ENABLED: + if self.config.COMBAT_SYSTEM_ENABLED: + self._register_player_stat('Skill/Mage', lambda player: player.skills.mage.level.val) + self._register_player_stat('Skill/Range', lambda 
player: player.skills.range.level.val) + self._register_player_stat('Skill/Melee', lambda player: player.skills.melee.level.val) + if self.config.PROFESSION_SYSTEM_ENABLED: + self._register_player_stat('Skill/Fishing', lambda player: player.skills.fishing.level.val) + self._register_player_stat('Skill/Herbalism', + lambda player: player.skills.herbalism.level.val) + self._register_player_stat('Skill/Prospecting', + lambda player: player.skills.prospecting.level.val) + self._register_player_stat('Skill/Carving', + lambda player: player.skills.carving.level.val) + self._register_player_stat('Skill/Alchemy', + lambda player: player.skills.alchemy.level.val) + if self.config.EQUIPMENT_SYSTEM_ENABLED: + self._register_player_stat('Item/Held-Level', + lambda player: player.inventory.equipment.held.item.level.val \ + if player.inventory.equipment.held.item else 0) + self._register_player_stat('Item/Equipment-Total', + lambda player: player.equipment.total(lambda e: e.level)) + + if self.config.EXCHANGE_SYSTEM_ENABLED: + self._register_player_stat('Exchange/Player-Sells', lambda player: player.sells) + self._register_player_stat('Exchange/Player-Buys', lambda player: player.buys) + self._register_player_stat('Exchange/Player-Wealth', lambda player: player.gold.val) + + # Item usage + if self.config.PROFESSION_SYSTEM_ENABLED: + self._register_player_stat('Item/Ration-Consumed', lambda player: player.ration_consumed) + self._register_player_stat('Item/Poultice-Consumed', lambda player: player.poultice_consumed) + self._register_player_stat('Item/Ration-Level', lambda player: player.ration_level_consumed) + self._register_player_stat('Item/Poultice-Level', + lambda player: player.poultice_level_consumed) + + def update(self, dead_players: Dict[int, Player]) -> None: + for player in dead_players.values(): + for key, val in self._player_stats(player).items(): + self._player_logger.log(key, val) + + # TODO: handle env logging + + def _player_stats(self, player: Agent) -> 
Dict[str, float]: + stats = {} + policy = player.policy + + for key, stat_func in self._player_stats_funcs.items(): + stats[f'{key}_{policy}'] = stat_func(player) + + stats['Task_Reward'] = player.history.time_alive.val + + # If diary is enabled, log task and achievement stats + if player.diary: + stats['Task_Reward'] = player.diary.cumulative_reward + + for achievement in player.diary.achievements: + stats["Achievement_{achievement.name}"] = float(achievement.completed) + + return stats diff --git a/nmmo/core/map.py b/nmmo/core/map.py index de3682541..4f02859c9 100644 --- a/nmmo/core/map.py +++ b/nmmo/core/map.py @@ -1,83 +1,83 @@ -from pdb import set_trace as T -import numpy as np +import os import logging +import numpy as np from ordered_set import OrderedSet +from nmmo.core.tile import Tile -from nmmo import core from nmmo.lib import material -import os class Map: - '''Map object representing a list of tiles - - Also tracks a sparse list of tile updates - ''' - def __init__(self, config, realm): - self.config = config - self._repr = None - self.realm = realm - - sz = config.MAP_SIZE - self.tiles = np.zeros((sz, sz), dtype=object) - - for r in range(sz): - for c in range(sz): - self.tiles[r, c] = core.Tile(config, realm, r, c) - - @property - def packet(self): - '''Packet of degenerate resource states''' - missingResources = [] - for e in self.updateList: - missingResources.append(e.pos) - return missingResources - - @property - def repr(self): - '''Flat matrix of tile material indices''' - if not self._repr: - self._repr = [[t.mat.index for t in row] for row in self.tiles] - - return self._repr - - def reset(self, realm, idx): - '''Reuse the current tile objects to load a new map''' - config = self.config - self.updateList = OrderedSet() - - path_map_suffix = config.PATH_MAP_SUFFIX.format(idx) - fPath = os.path.join(config.PATH_CWD, config.PATH_MAPS, path_map_suffix) - - try: - map_file = np.load(fPath) - except FileNotFoundError: - print('Maps not found') - 
raise - - materials = {mat.index: mat for mat in material.All} - for r, row in enumerate(map_file): - for c, idx in enumerate(row): - mat = materials[idx] - tile = self.tiles[r, c] - tile.reset(mat, config) - - def step(self): - '''Evaluate updatable tiles''' - if self.config.LOG_MILESTONES and self.realm.quill.milestone.log_max(f'Resource_Depleted', len(self.updateList)) and self.config.LOG_VERBOSE: - logging.info(f'RESOURCE: Depleted {len(self.updateList)} resource tiles') - - - for e in self.updateList.copy(): - if not e.depleted: - self.updateList.remove(e) - e.step() - - def harvest(self, r, c, deplete=True): - '''Called by actions that harvest a resource tile''' - - if deplete: - self.updateList.add(self.tiles[r, c]) - - return self.tiles[r, c].harvest(deplete) + '''Map object representing a list of tiles + + Also tracks a sparse list of tile updates + ''' + def __init__(self, config, realm): + self.config = config + self._repr = None + self.realm = realm + self.update_list = None + + sz = config.MAP_SIZE + self.tiles = np.zeros((sz, sz), dtype=object) + + for r in range(sz): + for c in range(sz): + self.tiles[r, c] = Tile(realm, r, c) + + @property + def packet(self): + '''Packet of degenerate resource states''' + missing_resources = [] + for e in self.update_list: + missing_resources.append(e.pos) + return missing_resources + + @property + def repr(self): + '''Flat matrix of tile material indices''' + if not self._repr: + self._repr = [[t.material.index for t in row] for row in self.tiles] + + return self._repr + + def reset(self, map_id): + '''Reuse the current tile objects to load a new map''' + config = self.config + self.update_list = OrderedSet() + + path_map_suffix = config.PATH_MAP_SUFFIX.format(map_id) + f_path = os.path.join(config.PATH_CWD, config.PATH_MAPS, path_map_suffix) + + try: + map_file = np.load(f_path) + except FileNotFoundError: + logging.error('Maps not found') + raise + + materials = {mat.index: mat for mat in material.All} + for r, 
row in enumerate(map_file): + for c, idx in enumerate(row): + mat = materials[idx] + tile = self.tiles[r, c] + tile.reset(mat, config) + self._repr = None + + def step(self): + '''Evaluate updatable tiles''' + self.realm.log_milestone('Resource_Depleted', len(self.update_list), + f'RESOURCE: Depleted {len(self.update_list)} resource tiles') + + for e in self.update_list.copy(): + if not e.depleted: + self.update_list.remove(e) + e.step() + + def harvest(self, r, c, deplete=True): + '''Called by actions that harvest a resource tile''' + + if deplete: + self.update_list.add(self.tiles[r, c]) + + return self.tiles[r, c].harvest(deplete) diff --git a/nmmo/core/observation.py b/nmmo/core/observation.py new file mode 100644 index 000000000..2697bd140 --- /dev/null +++ b/nmmo/core/observation.py @@ -0,0 +1,376 @@ +from functools import lru_cache + +import numpy as np + +from nmmo.core.tile import TileState +from nmmo.entity.entity import EntityState +from nmmo.systems.item import ItemState +import nmmo.systems.item as item_system +from nmmo.io import action +from nmmo.lib import material, utils + + +class BasicObs: + def __init__(self, values, id_col): + self.values = values + self.ids = values[:, id_col] + + @property + def len(self): + return len(self.ids) + + def id(self, i): + return self.ids[i] if i < self.len else None + + def index(self, val): + return np.nonzero(self.ids == val)[0][0] if val in self.ids else None + + +class InventoryObs(BasicObs): + def __init__(self, values, id_col): + super().__init__(values, id_col) + self.inv_type = self.values[:,ItemState.State.attr_name_to_col["type_id"]] + self.inv_level = self.values[:,ItemState.State.attr_name_to_col["level"]] + + def sig(self, item: item_system.Item, level: int): + idx = np.nonzero((self.inv_type == item.ITEM_TYPE_ID) & (self.inv_level == level))[0] + return idx[0] if len(idx) else None + + +class Observation: + def __init__(self, + config, + current_tick: int, + agent_id: int, + tiles, + entities, + 
inventory, + market) -> None: + + self.config = config + self.current_tick = current_tick + self.agent_id = agent_id + + self.tiles = tiles[0:config.MAP_N_OBS] + self.entities = BasicObs(entities[0:config.PLAYER_N_OBS], + EntityState.State.attr_name_to_col["id"]) + + if config.COMBAT_SYSTEM_ENABLED: + latest_combat_tick = self.agent().latest_combat_tick + self.agent_in_combat = False if latest_combat_tick == 0 else \ + (current_tick - latest_combat_tick) < config.COMBAT_STATUS_DURATION + else: + self.agent_in_combat = False + + if config.ITEM_SYSTEM_ENABLED: + self.inventory = InventoryObs(inventory[0:config.INVENTORY_N_OBS], + ItemState.State.attr_name_to_col["id"]) + else: + assert inventory.size == 0 + + if config.EXCHANGE_SYSTEM_ENABLED: + self.market = BasicObs(market[0:config.MARKET_N_OBS], + ItemState.State.attr_name_to_col["id"]) + else: + assert market.size == 0 + + # pylint: disable=method-cache-max-size-none + @lru_cache(maxsize=None) + def tile(self, r_delta, c_delta): + '''Return the array object corresponding to a nearby tile + + Args: + r_delta: row offset from current agent + c_delta: col offset from current agent + + Returns: + Vector corresponding to the specified tile + ''' + agent = self.agent() + if (0 <= agent.row + r_delta < self.config.MAP_SIZE) & \ + (0 <= agent.col + c_delta < self.config.MAP_SIZE): + r_cond = (self.tiles[:,TileState.State.attr_name_to_col["row"]] == agent.row + r_delta) + c_cond = (self.tiles[:,TileState.State.attr_name_to_col["col"]] == agent.col + c_delta) + return TileState.parse_array(self.tiles[r_cond & c_cond][0]) + + # return a dummy lava tile at (inf, inf) + return TileState.parse_array([np.inf, np.inf, material.Lava.index]) + + # pylint: disable=method-cache-max-size-none + @lru_cache(maxsize=None) + def entity(self, entity_id): + rows = self.entities.values[self.entities.ids == entity_id] + if rows.size == 0: + return None + return EntityState.parse_array(rows[0]) + + # pylint: disable=method-cache-max-size-none 
+ @lru_cache(maxsize=None) + def agent(self): + return self.entity(self.agent_id) + + def to_gym(self): + '''Convert the observation to a format that can be used by OpenAI Gym''' + + gym_obs = { + "CurrentTick": np.array([self.current_tick]), + "AgentId": np.array([self.agent_id]), + "Tile": np.vstack([ + self.tiles, + np.zeros((self.config.MAP_N_OBS - self.tiles.shape[0], self.tiles.shape[1])) + ]), + "Entity": np.vstack([ + self.entities.values, np.zeros(( + self.config.PLAYER_N_OBS - self.entities.values.shape[0], + self.entities.values.shape[1])) + ]), + } + + if self.config.ITEM_SYSTEM_ENABLED: + gym_obs["Inventory"] = np.vstack([ + self.inventory.values, np.zeros(( + self.config.INVENTORY_N_OBS - self.inventory.values.shape[0], + self.inventory.values.shape[1])) + ]) + + if self.config.EXCHANGE_SYSTEM_ENABLED: + gym_obs["Market"] = np.vstack([ + self.market.values, np.zeros(( + self.config.MARKET_N_OBS - self.market.values.shape[0], + self.market.values.shape[1])) + ]) + + if self.config.PROVIDE_ACTION_TARGETS: + gym_obs["ActionTargets"] = self._make_action_targets() + + return gym_obs + + def _make_action_targets(self): + # TODO(kywch): return all-0 masks for buy/sell/give during combat + + masks = {} + masks[action.Move] = { + action.Direction: self._make_move_mask() + } + + if self.config.COMBAT_SYSTEM_ENABLED: + masks[action.Attack] = { + action.Style: np.ones(len(action.Style.edges), dtype=np.int8), + action.Target: self._make_attack_mask() + } + + if self.config.ITEM_SYSTEM_ENABLED: + masks[action.Use] = { + action.InventoryItem: self._make_use_mask() + } + masks[action.Give] = { + action.InventoryItem: self._make_sell_mask(), + action.Target: self._make_give_target_mask() + } + masks[action.Destroy] = { + action.InventoryItem: self._make_destroy_item_mask() + } + + if self.config.EXCHANGE_SYSTEM_ENABLED: + masks[action.Sell] = { + action.InventoryItem: self._make_sell_mask(), + action.Price: np.ones(len(action.Price.edges), dtype=np.int8) + } + 
masks[action.Buy] = { + action.MarketItem: self._make_buy_mask() + } + masks[action.GiveGold] = { + action.Target: self._make_give_target_mask(), + action.Price: self._make_give_gold_mask() # reusing Price + } + + if self.config.COMMUNICATION_SYSTEM_ENABLED: + masks[action.Comm] = { + action.Token: np.ones(len(action.Token.edges), dtype=np.int8) + } + + return masks + + def _make_move_mask(self): + # pylint: disable=not-an-iterable + return np.array( + [self.tile(*d.delta).material_id in material.Habitable + for d in action.Direction.edges], dtype=np.int8) + + def _make_attack_mask(self): + # NOTE: Currently, all attacks have the same range + # if we choose to make ranges different, the masks + # should be differently generated by attack styles + assert self.config.COMBAT_MELEE_REACH == self.config.COMBAT_RANGE_REACH + assert self.config.COMBAT_MELEE_REACH == self.config.COMBAT_MAGE_REACH + assert self.config.COMBAT_RANGE_REACH == self.config.COMBAT_MAGE_REACH + + attack_range = self.config.COMBAT_MELEE_REACH + + agent = self.agent() + entities_pos = self.entities.values[:, [EntityState.State.attr_name_to_col["row"], + EntityState.State.attr_name_to_col["col"]]] + within_range = utils.linf(entities_pos, (agent.row, agent.col)) <= attack_range + + immunity = self.config.COMBAT_SPAWN_IMMUNITY + if 0 < immunity < agent.time_alive: + # ids > 0 equals entity.is_player + spawn_immunity = (self.entities.ids > 0) & \ + (self.entities.values[:,EntityState.State.attr_name_to_col["time_alive"]] < immunity) + else: + spawn_immunity = np.ones(self.entities.len, dtype=np.int8) + + # allow friendly fire but no self shooting + not_me = np.ones(self.entities.len, dtype=np.int8) + not_me[self.entities.index(agent.id)] = 0 # mask self + + return np.concatenate([within_range & not_me & spawn_immunity, + np.zeros(self.config.PLAYER_N_OBS - self.entities.len, dtype=np.int8)]) + + def _make_use_mask(self): + # empty inventory -- nothing to use + if not (self.config.ITEM_SYSTEM_ENABLED 
and self.inventory.len > 0) or self.agent_in_combat: + return np.zeros(self.config.INVENTORY_N_OBS, dtype=np.int8) + + item_skill = self._item_skill() + + not_listed = self.inventory.values[:,ItemState.State.attr_name_to_col["listed_price"]] == 0 + item_type = self.inventory.values[:,ItemState.State.attr_name_to_col["type_id"]] + item_level = self.inventory.values[:,ItemState.State.attr_name_to_col["level"]] + + # level limits are differently applied depending on item types + type_flt = np.tile ( np.array(list(item_skill.keys())), (self.inventory.len,1) ) + level_flt = np.tile ( np.array(list(item_skill.values())), (self.inventory.len,1) ) + item_type = np.tile( np.transpose(np.atleast_2d(item_type)), (1, len(item_skill))) + item_level = np.tile( np.transpose(np.atleast_2d(item_level)), (1, len(item_skill))) + level_satisfied = np.any((item_type == type_flt) & (item_level <= level_flt), axis=1) + + return np.concatenate([not_listed & level_satisfied, + np.zeros(self.config.INVENTORY_N_OBS - self.inventory.len, dtype=np.int8)]) + + def _item_skill(self): + agent = self.agent() + + # the minimum agent level is 1 + level = max(1, agent.melee_level, agent.range_level, agent.mage_level, + agent.fishing_level, agent.herbalism_level, agent.prospecting_level, + agent.carving_level, agent.alchemy_level) + return { + item_system.Hat.ITEM_TYPE_ID: level, + item_system.Top.ITEM_TYPE_ID: level, + item_system.Bottom.ITEM_TYPE_ID: level, + item_system.Sword.ITEM_TYPE_ID: agent.melee_level, + item_system.Bow.ITEM_TYPE_ID: agent.range_level, + item_system.Wand.ITEM_TYPE_ID: agent.mage_level, + item_system.Rod.ITEM_TYPE_ID: agent.fishing_level, + item_system.Gloves.ITEM_TYPE_ID: agent.herbalism_level, + item_system.Pickaxe.ITEM_TYPE_ID: agent.prospecting_level, + item_system.Chisel.ITEM_TYPE_ID: agent.carving_level, + item_system.Arcane.ITEM_TYPE_ID: agent.alchemy_level, + item_system.Scrap.ITEM_TYPE_ID: agent.melee_level, + item_system.Shaving.ITEM_TYPE_ID: agent.range_level, + 
item_system.Shard.ITEM_TYPE_ID: agent.mage_level, + item_system.Ration.ITEM_TYPE_ID: level, + item_system.Poultice.ITEM_TYPE_ID: level + } + + def _make_destroy_item_mask(self): + # empty inventory -- nothing to destroy + if not (self.config.ITEM_SYSTEM_ENABLED and self.inventory.len > 0) or self.agent_in_combat: + return np.zeros(self.config.INVENTORY_N_OBS, dtype=np.int8) + + not_equipped = self.inventory.values[:,ItemState.State.attr_name_to_col["equipped"]] == 0 + + # not equipped items in the inventory can be destroyed + return np.concatenate([not_equipped, + np.zeros(self.config.INVENTORY_N_OBS - self.inventory.len, dtype=np.int8)]) + + def _make_give_target_mask(self): + # empty inventory -- nothing to give + if not (self.config.ITEM_SYSTEM_ENABLED and self.inventory.len > 0) or self.agent_in_combat: + return np.zeros(self.config.PLAYER_N_OBS, dtype=np.int8) + + agent = self.agent() + entities_pos = self.entities.values[:, [EntityState.State.attr_name_to_col["row"], + EntityState.State.attr_name_to_col["col"]]] + same_tile = utils.linf(entities_pos, (agent.row, agent.col)) == 0 + not_me = self.entities.ids != self.agent_id + player = (self.entities.values[:,EntityState.State.attr_name_to_col["npc_type"]] == 0) + + return np.concatenate([ + (same_tile & player & not_me), + np.zeros(self.config.PLAYER_N_OBS - self.entities.len, dtype=np.int8)]) + + def _make_give_gold_mask(self): + gold = int(self.agent().gold) + mask = np.zeros(self.config.PRICE_N_OBS, dtype=np.int8) + + if gold and not self.agent_in_combat: + mask[:gold] = 1 # NOTE that action.Price starts from Discrete_1 + + return mask + + def _make_sell_mask(self): + # empty inventory -- nothing to sell + if not (self.config.EXCHANGE_SYSTEM_ENABLED and self.inventory.len > 0) \ + or self.agent_in_combat: + return np.zeros(self.config.INVENTORY_N_OBS, dtype=np.int8) + + not_equipped = self.inventory.values[:,ItemState.State.attr_name_to_col["equipped"]] == 0 + not_listed = 
self.inventory.values[:,ItemState.State.attr_name_to_col["listed_price"]] == 0 + + return np.concatenate([not_equipped & not_listed, + np.zeros(self.config.INVENTORY_N_OBS - self.inventory.len, dtype=np.int8)]) + + def _make_buy_mask(self): + if not self.config.EXCHANGE_SYSTEM_ENABLED or self.agent_in_combat: + return np.zeros(self.config.MARKET_N_OBS, dtype=np.int8) + + market_flt = np.ones(self.market.len, dtype=np.int8) + full_inventory = self.inventory.len >= self.config.ITEM_INVENTORY_CAPACITY + + # if the inventory is full, one can only buy existing ammo stack + # otherwise, one can buy anything owned by other, having enough money + if full_inventory: + exist_ammo_listings = self._existing_ammo_listings() + if not np.any(exist_ammo_listings): + return np.zeros(self.config.MARKET_N_OBS, dtype=np.int8) + market_flt = exist_ammo_listings + + agent = self.agent() + market_items = self.market.values + enough_gold = market_items[:,ItemState.State.attr_name_to_col["listed_price"]] <= agent.gold + not_mine = market_items[:,ItemState.State.attr_name_to_col["owner_id"]] != self.agent_id + + return np.concatenate([market_flt & enough_gold & not_mine, + np.zeros(self.config.MARKET_N_OBS - self.market.len, dtype=np.int8)]) + + def _existing_ammo_listings(self): + sig_col = (ItemState.State.attr_name_to_col["type_id"], + ItemState.State.attr_name_to_col["level"]) + ammo_id = [ammo.ITEM_TYPE_ID for ammo in + [item_system.Scrap, item_system.Shaving, item_system.Shard]] + + # search ammo stack from the inventory + type_flt = np.tile( np.array(ammo_id), (self.inventory.len,1)) + item_type = np.tile( + np.transpose(np.atleast_2d(self.inventory.values[:,sig_col[0]])), + (1, len(ammo_id))) + exist_ammo = self.inventory.values[np.any(item_type == type_flt, axis=1)] + + # self does not have ammo + if exist_ammo.shape[0] == 0: + return np.zeros(self.market.len, dtype=np.int8) + + # search the existing ammo stack from the market that's not mine + type_flt = np.tile( 
np.array(exist_ammo[:,sig_col[0]]), (self.market.len,1)) + level_flt = np.tile( np.array(exist_ammo[:,sig_col[1]]), (self.market.len,1)) + item_type = np.tile( np.transpose(np.atleast_2d(self.market.values[:,sig_col[0]])), + (1, exist_ammo.shape[0])) + item_level = np.tile( np.transpose(np.atleast_2d(self.market.values[:,sig_col[1]])), + (1, exist_ammo.shape[0])) + exist_ammo_listings = np.any((item_type == type_flt) & (item_level == level_flt), axis=1) + + not_mine = self.market.values[:,ItemState.State.attr_name_to_col["owner_id"]] != self.agent_id + + return exist_ammo_listings & not_mine diff --git a/nmmo/core/realm.py b/nmmo/core/realm.py index 6c206d207..e967a147c 100644 --- a/nmmo/core/realm.py +++ b/nmmo/core/realm.py @@ -1,309 +1,211 @@ -from pdb import set_trace as T -import numpy as np +from __future__ import annotations -from ordered_set import OrderedSet +import logging from collections import defaultdict -from collections.abc import Mapping -from typing import Dict, Callable +from typing import Dict + +import numpy as np import nmmo -from nmmo import core, infrastructure +from nmmo.core.log_helper import LogHelper +from nmmo.core.map import Map +from nmmo.core.tile import TileState +from nmmo.entity.entity import EntityState +from nmmo.entity.entity_manager import NPCManager, PlayerManager +from nmmo.io.action import Action, Buy +from nmmo.datastore.numpy_datastore import NumpyDatastore from nmmo.systems.exchange import Exchange -from nmmo.systems import combat -from nmmo.entity.npc import NPC -from nmmo.entity import Player -from nmmo.systems.item import Item - -from nmmo.io.action import Action -from nmmo.lib import colors, spawn, log - +from nmmo.systems.item import Item, ItemState +from nmmo.lib.event_log import EventLogger, EventState +from nmmo.render.replay_helper import ReplayHelper def prioritized(entities: Dict, merged: Dict): - '''Sort actions into merged according to priority''' - for idx, actions in entities.items(): - for atn, args in 
actions.items(): - merged[atn.priority].append((idx, (atn, args.values()))) - return merged - + """Sort actions into merged according to priority""" + for idx, actions in entities.items(): + for atn, args in actions.items(): + merged[atn.priority].append((idx, (atn, args.values()))) + return merged -class EntityGroup(Mapping): - def __init__(self, config, realm): - self.dataframe = realm.dataframe - self.config = config - - self.entities = {} - self.dead = {} - - def __len__(self): - return len(self.entities) - - def __contains__(self, e): - return e in self.entities - - def __getitem__(self, key): - return self.entities[key] - - def __iter__(self): - yield from self.entities - - def items(self): - return self.entities.items() - - @property - def corporeal(self): - return {**self.entities, **self.dead} - - @property - def packet(self): - return {k: v.packet() for k, v in self.corporeal.items()} - - def reset(self): - for entID, ent in self.entities.items(): - self.dataframe.remove(nmmo.Serialized.Entity, entID, ent.pos) - - self.entities = {} - self.dead = {} - - def add(iden, entity): - assert iden not in self.entities - self.entities[iden] = entity - - def remove(iden): - assert iden in self.entities - del self.entities[iden] - - def spawn(self, entity): - pos, entID = entity.pos, entity.entID - self.realm.map.tiles[pos].addEnt(entity) - self.entities[entID] = entity - - def cull(self): - self.dead = {} - for entID in list(self.entities): - player = self.entities[entID] - if not player.alive: - r, c = player.base.pos - entID = player.entID - self.dead[entID] = player - - self.realm.map.tiles[r, c].delEnt(entID) - del self.entities[entID] - self.realm.dataframe.remove(nmmo.Serialized.Entity, entID, player.pos) - - return self.dead - - def update(self, actions): - for entID, entity in self.entities.items(): - entity.update(self.realm, actions) - - -class NPCManager(EntityGroup): - def __init__(self, config, realm): - super().__init__(config, realm) - self.realm = 
realm - - self.spawn_dangers = [] - - def reset(self): - super().reset() - self.idx = -1 - - def spawn(self): - config = self.config - - if not config.NPC_SYSTEM_ENABLED: - return - - for _ in range(config.NPC_SPAWN_ATTEMPTS): - if len(self.entities) >= config.NPC_N: - break - - if self.spawn_dangers: - danger = self.spawn_dangers[-1] - r, c = combat.spawn(config, danger) - else: - center = config.MAP_CENTER - border = self.config.MAP_BORDER - r, c = np.random.randint(border, center+border, 2).tolist() - - if self.realm.map.tiles[r, c].occupied: - continue - - npc = NPC.spawn(self.realm, (r, c), self.idx) - if npc: - super().spawn(npc) - self.idx -= 1 - - if self.spawn_dangers: - self.spawn_dangers.pop() - - def cull(self): - for entity in super().cull().values(): - self.spawn_dangers.append(entity.spawn_danger) - - def actions(self, realm): - actions = {} - for idx, entity in self.entities.items(): - actions[idx] = entity.decide(realm) - return actions - -class PlayerManager(EntityGroup): - def __init__(self, config, realm): - super().__init__(config, realm) - self.palette = colors.Palette() - self.loader = config.PLAYER_LOADER - self.realm = realm - - def reset(self): - super().reset() - self.agents = self.loader(self.config) - self.spawned = OrderedSet() - - def spawnIndividual(self, r, c, idx): - pop, agent = next(self.agents) - agent = agent(self.config, idx) - player = Player(self.realm, (r, c), agent, self.palette.color(pop), pop) - super().spawn(player) - - def spawn(self): - #TODO: remove hard check against fixed function - if self.config.PLAYER_SPAWN_FUNCTION == spawn.spawn_concurrent: - idx = 0 - for r, c in self.config.PLAYER_SPAWN_FUNCTION(self.config): - idx += 1 - - if idx in self.entities: - continue - - if idx in self.spawned and not self.config.RESPAWN: - continue - - self.spawned.add(idx) - - if self.realm.map.tiles[r, c].occupied: - continue - - self.spawnIndividual(r, c, idx) - - return - - #MMO-style spawning - for _ in 
range(self.config.PLAYER_SPAWN_ATTEMPTS): - if len(self.entities) >= self.config.PLAYER_N: - break - - r, c = self.config.PLAYER_SPAWN_FUNCTION(self.config) - if self.realm.map.tiles[r, c].occupied: - continue - - self.spawnIndividual(r, c) - - while len(self.entities) == 0: - self.spawn() class Realm: - '''Top-level world object''' - def __init__(self, config): - self.config = config - Action.hook(config) - - # Generate maps if they do not exist - config.MAP_GENERATOR(config).generate_all_maps() - - # Load the world file - self.dataframe = infrastructure.Dataframe(self) - self.map = core.Map(config, self) - - # Entity handlers - self.players = PlayerManager(config, self) - self.npcs = NPCManager(config, self) - - # Global item exchange - self.exchange = Exchange() - - # Global item registry - self.items = {} - - # Initialize actions - nmmo.Action.init(config) - - def reset(self, idx): - '''Reset the environment and load the specified map - - Args: - idx: Map index to load - ''' - Item.INSTANCE_ID = 0 - self.quill = log.Quill(self.config) - self.map.reset(self, idx) - self.players.reset() - self.npcs.reset() - self.tick = 0 - - # Global item exchange - self.exchange = Exchange() - - # Global item registry - self.items = {} - - def packet(self): - '''Client packet''' - return {'environment': self.map.repr, - 'border': self.config.MAP_BORDER, - 'size': self.config.MAP_SIZE, - 'resource': self.map.packet, - 'player': self.players.packet, - 'npc': self.npcs.packet, - 'market': self.exchange.packet} - - @property - def population(self): - '''Number of player agents''' - return len(self.players.entities) - - def entity(self, entID): - '''Get entity by ID''' - if entID < 0: - return self.npcs[entID] + """Top-level world object""" + + def __init__(self, config): + self.config = config + assert isinstance( + config, nmmo.config.Config + ), f"Config {config} is not a config instance (did you pass the class?)" + + Action.hook(config) + + # Generate maps if they do not exist + 
config.MAP_GENERATOR(config).generate_all_maps() + + self.datastore = NumpyDatastore() + for s in [TileState, EntityState, ItemState, EventState]: + self.datastore.register_object_type(s._name, s.State.num_attributes) + + self.tick = None # to use as a "reset" checker + self.exchange = None + + # Load the world file + self.map = Map(config, self) + + self.log_helper = LogHelper.create(self) + self.event_log = EventLogger(self) + + # Entity handlers + self.players = PlayerManager(self) + self.npcs = NPCManager(self) + + # Global item registry + self.items = {} + + # Replay helper + self._replay_helper = ReplayHelper.create(self) + + # Initialize actions + nmmo.Action.init(config) + + def reset(self, map_id: int = None): + """Reset the environment and load the specified map + + Args: + idx: Map index to load + """ + self.log_helper.reset() + self.event_log.reset() + self._replay_helper.reset() + self.map.reset(map_id or np.random.randint(self.config.MAP_N) + 1) + + # EntityState and ItemState tables must be empty after players/npcs.reset() + self.players.reset() + self.npcs.reset() + + # TODO: track down entity/item leaks + EntityState.State.table(self.datastore).reset() + assert EntityState.State.table(self.datastore).is_empty(), \ + "EntityState table is not empty" + + # TODO(kywch): ItemState table is not empty after players/npcs.reset() + # but should be. Will fix this while debugging the item system. 
+ # assert ItemState.State.table(self.datastore).is_empty(), \ + # "ItemState table is not empty" + ItemState.State.table(self.datastore).reset() + + self.players.spawn() + self.npcs.spawn() + self.tick = 0 + + # Global item exchange + self.exchange = Exchange(self) + + # Global item registry + Item.INSTANCE_ID = 0 + self.items = {} + + def packet(self): + """Client packet""" + return { + "environment": self.map.repr, + "border": self.config.MAP_BORDER, + "size": self.config.MAP_SIZE, + "resource": self.map.packet, + "player": self.players.packet, + "npc": self.npcs.packet, + "market": self.exchange.packet, + } + + @property + def num_players(self): + """Number of player agents""" + return len(self.players.entities) + + def entity(self, ent_id): + e = self.entity_or_none(ent_id) + assert e is not None, f"Entity {ent_id} does not exist" + return e + + def entity_or_none(self, ent_id): + if ent_id is None: + return None + + """Get entity by ID""" + if ent_id < 0: + return self.npcs.get(ent_id) + + return self.players.get(ent_id) + + def step(self, actions): + """Run game logic for one tick + + Args: + actions: Dict of agent actions + + Returns: + dead: List of dead agents + """ + # Prioritize actions + npc_actions = self.npcs.actions(self) + merged = defaultdict(list) + prioritized(actions, merged) + prioritized(npc_actions, merged) + + # Update entities and perform actions + self.players.update(actions) + self.npcs.update(npc_actions) + + # Execute actions -- CHECK ME the below priority + # - 10: Use - equip ammo, restore HP, etc. 
+ # - 20: Buy - exchange while sellers, items, buyers are all intact + # - 30: Give, GiveGold - transfer while both are alive and at the same tile + # - 40: Destroy - use with SELL/GIVE, if not gone, destroy and recover space + # - 50: Attack + # - 60: Move + # - 70: Sell - to guarantee the listed items are available to buy + # - 99: Comm + for priority in sorted(merged): + # TODO: we should be randomizing these, otherwise the lower ID agents + # will always go first. --> ONLY SHUFFLE BUY + if priority == Buy.priority: + np.random.shuffle(merged[priority]) + + # CHECK ME: do we need this line? + # ent_id, (atn, args) = merged[priority][0] + for ent_id, (atn, args) in merged[priority]: + ent = self.entity(ent_id) + if ent.alive: + atn.call(self, ent, *args) + + dead = self.players.cull() + self.npcs.cull() + + # Update map + self.map.step() + self.exchange.step(self.tick) + self.log_helper.update(dead) + self._replay_helper.update() + + self.tick += 1 + + return dead + + def log_milestone(self, category: str, value: float, message: str = None, tags: Dict = None): + self.log_helper.log_milestone(category, value) + self.log_helper.log_event(category, value) + + if self.config.LOG_VERBOSE: + # TODO: more general handling of tags, if necessary + if tags and 'player_id' in tags: + logging.info("Milestone (Player %d): %s %s %s", tags['player_id'], category, value, message) else: - return self.players[entID] - - def step(self, actions): - '''Run game logic for one tick - - Args: - actions: Dict of agent actions - ''' - #Prioritize actions - npcActions = self.npcs.actions(self) - merged = defaultdict(list) - prioritized(actions, merged) - prioritized(npcActions, merged) - - #Update entities and perform actions - self.players.update(actions) - self.npcs.update(npcActions) - - #Execute actions - for priority in sorted(merged): - # Buy/sell priority - entID, (atn, args) = merged[priority][0] - if atn in (nmmo.action.Buy, nmmo.action.Sell): - merged[priority] = 
sorted(merged[priority], key=lambda x: x[0]) - for entID, (atn, args) in merged[priority]: - ent = self.entity(entID) - atn.call(self, ent, *args) - - #Spawn new agent and cull dead ones - #TODO: Place cull before spawn once PettingZoo API fixes respawn on same tick as death bug - self.players.spawn() - self.npcs.spawn() - - dead = self.players.cull() - self.npcs.cull() - - #Update map - self.map.step() - self.tick += 1 - - return dead + logging.info("Milestone: %s %s %s", category, value, message) + + def save_replay(self, save_path, compress=True): + self._replay_helper.save(save_path, compress) + + def get_replay(self): + return { + 'map': self._replay_helper.map, + 'packets': self._replay_helper.packets + } diff --git a/nmmo/core/terrain.py b/nmmo/core/terrain.py index d49627c92..e63fd5a9f 100644 --- a/nmmo/core/terrain.py +++ b/nmmo/core/terrain.py @@ -1,44 +1,47 @@ -from pdb import set_trace as T - -import scipy.stats as stats -import numpy as np -import random import os +import random +import logging +import numpy as np import vec_noise from imageio import imread, imsave -from tqdm import tqdm +from scipy import stats from nmmo import material -def sharp(self, noise): - '''Exponential noise sharpener for perlin ridges''' - return 2 * (0.5 - abs(0.5 - noise)); -class Save: - '''Save utility for map files''' - def render(mats, lookup, path): - '''Render tiles to png''' - images = [[lookup[e] for e in l] for l in mats] - image = np.vstack([np.hstack(e) for e in images]) - imsave(path, image) - - def fractal(terrain, path): - '''Render raw noise fractal to png''' - frac = (256*terrain).astype(np.uint8) - imsave(path, frac) - - def np(mats, path): - '''Save map to .npy''' - path = os.path.join(path, 'map.npy') - np.save(path, mats.astype(int)) +def sharp(noise): + '''Exponential noise sharpener for perlin ridges''' + return 2 * (0.5 - abs(0.5 - noise)) +class Save: + '''Save utility for map files''' + @staticmethod + def render(mats, lookup, path): + '''Render 
tiles to png''' + images = [[lookup[e] for e in l] for l in mats] + image = np.vstack([np.hstack(e) for e in images]) + imsave(path, image) + + @staticmethod + def fractal(terrain, path): + '''Render raw noise fractal to png''' + frac = (256*terrain).astype(np.uint8) + imsave(path, frac) + + @staticmethod + def as_numpy(mats, path): + '''Save map to .npy''' + path = os.path.join(path, 'map.npy') + np.save(path, mats.astype(int)) + +# pylint: disable=E1101:no-member +# Terrain uses setattr() class Terrain: - '''Terrain material class; populated at runtime''' - pass - -def generate_terrain(config, idx, interpolaters): + '''Terrain material class; populated at runtime''' + @staticmethod + def generate_terrain(config, map_id, interpolaters): center = config.MAP_CENTER border = config.MAP_BORDER size = config.MAP_SIZE @@ -48,16 +51,16 @@ def generate_terrain(config, idx, interpolaters): #Compute a unique seed based on map index #Flip seed used to ensure train/eval maps are different - seed = idx + 1 + seed = map_id + 1 if config.TERRAIN_FLIP_SEED: - seed = -seed + seed = -seed #Log interpolation factor if not interpolaters: - interpolaters = np.logspace(config.TERRAIN_LOG_INTERPOLATE_MIN, - config.TERRAIN_LOG_INTERPOLATE_MAX, config.MAP_N) + interpolaters = np.logspace(config.TERRAIN_LOG_INTERPOLATE_MIN, + config.TERRAIN_LOG_INTERPOLATE_MAX, config.MAP_N) - interpolate = interpolaters[idx] + interpolate = interpolaters[map_id] #Data buffers val = np.zeros((size, size, octaves)) @@ -69,7 +72,7 @@ def generate_terrain(config, idx, interpolaters): start = frequency end = min(start, start - np.log2(center) + offset) for idx, freq in enumerate(np.logspace(start, end, octaves, base=2)): - val[:, :, idx] = vec_noise.snoise2(seed*size + freq*X, idx*size + freq*Y) + val[:, :, idx] = vec_noise.snoise2(seed*size + freq*X, idx*size + freq*Y) #Compute L1 distance x = np.abs(np.arange(size) - size//2) @@ -89,8 +92,8 @@ def generate_terrain(config, idx, interpolaters): X, Y = 
np.meshgrid(s, s) expand = int(np.log2(center)) - 2 for idx, octave in enumerate(range(expand, 1, -1)): - freq, mag = 1 / 2**octave, 1 / 2**idx - noise += mag * vec_noise.snoise2(seed*size + freq*X, idx*size + freq*Y) + freq, mag = 1 / 2**octave, 1 / 2**idx + noise += mag * vec_noise.snoise2(seed*size + freq*X, idx*size + freq*Y) noise -= np.min(noise) noise = octaves * noise / np.max(noise) - 1e-12 @@ -98,16 +101,16 @@ def generate_terrain(config, idx, interpolaters): #Compute L1 and Perlin scale factor for i in range(octaves): - start = octaves - i - 1 - scale[l1 <= high] = np.arange(start, start + octaves) - high -= delta + start = octaves - i - 1 + scale[l1 <= high] = np.arange(start, start + octaves) + high -= delta start = noise - 1 - l1Scale = np.clip(l1, 0, size//2 - border - 2) - l1Scale = l1Scale / np.max(l1Scale) + l1_scale = np.clip(l1, 0, size//2 - border - 2) + l1_scale = l1_scale / np.max(l1_scale) for i in range(octaves): - idxs = l1Scale*scale[:, :, i] + (1-l1Scale)*(start + i) - scale[:, :, i] = pdf[idxs.astype(int)] + idxs = l1_scale*scale[:, :, i] + (1-l1_scale)*(start + i) + scale[:, :, i] = pdf[idxs.astype(int)] #Blend octaves std = np.std(val) @@ -120,17 +123,17 @@ def generate_terrain(config, idx, interpolaters): #Threshold to materials matl = np.zeros((size, size), dtype=object) for y in range(size): - for x in range(size): - v = val[y, x] - if v <= config.TERRAIN_WATER: - mat = Terrain.WATER - elif v <= config.TERRAIN_GRASS: - mat = Terrain.GRASS - elif v <= config.TERRAIN_FOREST: - mat = Terrain.FOREST - else: - mat = Terrain.STONE - matl[y, x] = mat + for x in range(size): + v = val[y, x] + if v <= config.TERRAIN_WATER: + mat = Terrain.WATER + elif v <= config.TERRAIN_GRASS: + mat = Terrain.GRASS + elif v <= config.TERRAIN_FOREST: + mat = Terrain.FOREST + else: + mat = Terrain.STONE + matl[y, x] = mat #Lava and grass border matl[l1 > size/2 - border] = Terrain.LAVA @@ -142,143 +145,161 @@ def generate_terrain(config, idx, interpolaters): 
return val, matl, interpolaters +def place_fish(tiles): + placed = False + allow = {Terrain.GRASS} -def fish(config, tiles, mat, mmin, mmax): - r = random.randint(mmin, mmax) - c = random.randint(mmin, mmax) - - allow = {Terrain.GRASS} - if (tiles[r, c] not in {Terrain.WATER} or - (tiles[r-1, c] not in allow and tiles[r+1, c] not in allow and - tiles[r, c-1] not in allow and tiles[r, c+1] not in allow)): - fish(config, tiles, mat, mmin, mmax) - else: - tiles[r, c] = mat + water_loc = np.where(tiles == Terrain.WATER) + water_loc = list(zip(water_loc[0], water_loc[1])) + random.shuffle(water_loc) -def uniform(config, tiles, mat, mmin, mmax): - r = random.randint(mmin, mmax) - c = random.randint(mmin, mmax) - - if tiles[r, c] not in {Terrain.GRASS}: - uniform(config, tiles, mat, mmin, mmax) - else: - tiles[r, c] = mat - -def cluster(config, tiles, mat, mmin, mmax): - mmin = mmin + 1 - mmax = mmax - 1 + for r, c in water_loc: + if tiles[r-1, c] in allow or tiles[r+1, c] in allow or \ + tiles[r, c-1] in allow or tiles[r, c+1] in allow: + tiles[r, c] = Terrain.FISH + placed = True + break - r = random.randint(mmin, mmax) - c = random.randint(mmin, mmax) + if not placed: + raise RuntimeError('Could not find the water tile to place fish.') - matls = {Terrain.GRASS} - if tiles[r, c] not in matls: - return cluster(config, tiles, mat, mmin-1, mmax+1) +def uniform(config, tiles, mat, mmin, mmax): + r = random.randint(mmin, mmax) + c = random.randint(mmin, mmax) + if tiles[r, c] not in {Terrain.GRASS}: + uniform(config, tiles, mat, mmin, mmax) + else: tiles[r, c] = mat - if tiles[r-1, c] in matls: - tiles[r-1, c] = mat - if tiles[r+1, c] in matls: - tiles[r+1, c] = mat - if tiles[r, c-1] in matls: - tiles[r, c-1] = mat - if tiles[r, c+1] in matls: - tiles[r, c+1] = mat + +def cluster(config, tiles, mat, mmin, mmax): + mmin = mmin + 1 + mmax = mmax - 1 + + r = random.randint(mmin, mmax) + c = random.randint(mmin, mmax) + + matls = {Terrain.GRASS} + if tiles[r, c] not in matls: + 
cluster(config, tiles, mat, mmin-1, mmax+1) + return + + tiles[r, c] = mat + if tiles[r-1, c] in matls: + tiles[r-1, c] = mat + if tiles[r+1, c] in matls: + tiles[r+1, c] = mat + if tiles[r, c-1] in matls: + tiles[r, c-1] = mat + if tiles[r, c+1] in matls: + tiles[r, c+1] = mat def spawn_profession_resources(config, tiles): - mmin = config.MAP_BORDER + 1 - mmax = config.MAP_SIZE - config.MAP_BORDER - 1 + mmin = config.MAP_BORDER + 1 + mmax = config.MAP_SIZE - config.MAP_BORDER - 1 - for _ in range(config.PROGRESSION_SPAWN_CLUSTERS): - cluster(config, tiles, Terrain.ORE, mmin, mmax) - cluster(config, tiles, Terrain.TREE, mmin, mmax) - cluster(config, tiles, Terrain.CRYSTAL, mmin, mmax) + for _ in range(config.PROGRESSION_SPAWN_CLUSTERS): + cluster(config, tiles, Terrain.ORE, mmin, mmax) + cluster(config, tiles, Terrain.TREE, mmin, mmax) + cluster(config, tiles, Terrain.CRYSTAL, mmin, mmax) - for _ in range(config.PROGRESSION_SPAWN_UNIFORMS): - uniform(config, tiles, Terrain.HERB, mmin, mmax) - fish(config, tiles, Terrain.FISH, mmin, mmax) + for _ in range(config.PROGRESSION_SPAWN_UNIFORMS): + uniform(config, tiles, Terrain.HERB, mmin, mmax) + place_fish(tiles) class MapGenerator: - '''Procedural map generation''' - def __init__(self, config): - self.config = config - self.loadTextures() - - def loadTextures(self): - '''Called during setup; loads and resizes tile pngs''' - lookup = {} - path = self.config.PATH_TILE - scale = self.config.MAP_PREVIEW_DOWNSCALE - for mat in material.All: - key = mat.tex - tex = imread(path.format(key)) - lookup[mat.index] = tex[:, :, :3][::scale, ::scale] - setattr(Terrain, key.upper(), mat.index) - self.textures = lookup - - def generate_all_maps(self): - '''Generates NMAPS maps according to generate_map - - Provides additional utilities for saving to .npy and rendering png previews''' - - config = self.config - - #Only generate if maps are not cached - path_maps = os.path.join(config.PATH_CWD, config.PATH_MAPS) - 
os.makedirs(path_maps, exist_ok=True) - if not config.MAP_FORCE_GENERATION and os.listdir(path_maps): - return - - if __debug__: - print('Generating {} maps'.format(config.MAP_N)) - - for idx in tqdm(range(config.MAP_N)): - path = path_maps + '/map' + str(idx+1) - os.makedirs(path, exist_ok=True) - - terrain, tiles = self.generate_map(idx) - - - #Save/render - Save.np(tiles, path) - if config.MAP_GENERATE_PREVIEWS: - b = config.MAP_BORDER - tiles = [e[b:-b+1] for e in tiles][b:-b+1] - Save.fractal(terrain, path+'/fractal.png') - Save.render(tiles, self.textures, path+'/map.png') - - def generate_map(self, idx): - '''Generate a single map - - The default method is a relatively complex multiscale perlin noise method. - This is not just standard multioctave noise -- we are seeding multioctave noise - itself with perlin noise to create localized deviations in scale, plus additional - biasing to decrease terrain frequency towards the center of the map - - We found that this creates more visually interesting terrain and more deviation in - required planning horizon across different parts of the map. This is by no means a - gold-standard: you are free to override this method and create customized terrain - generation more suitable for your application. 
Simply pass MAP_GENERATOR=YourMapGenClass - as a config argument.''' - config = self.config - if config.TERRAIN_SYSTEM_ENABLED: - if not hasattr(self, 'interpolaters'): - self.interpolaters = None - terrain, tiles, interpolaters = generate_terrain(config, idx, self.interpolaters) - else: - size = config.MAP_SIZE - terrain = np.zeros((size, size)) - tiles = np.zeros((size, size), dtype=object) - - for r in range(size): - for c in range(size): - linf = max(abs(r - size//2), abs(c - size//2)) - if linf <= size//2 - config.MAP_BORDER: - tiles[r, c] = Terrain.GRASS - else: - tiles[r, c] = Terrain.LAVA - - if config.PROFESSION_SYSTEM_ENABLED: - spawn_profession_resources(config, tiles) - - return terrain, tiles + '''Procedural map generation''' + def __init__(self, config): + self.config = config + self.load_textures() + self.interpolaters = None + + def load_textures(self): + '''Called during setup; loads and resizes tile pngs''' + lookup = {} + path = self.config.PATH_TILE + scale = self.config.MAP_PREVIEW_DOWNSCALE + for mat in material.All: + key = mat.tex + tex = imread(path.format(key)) + lookup[mat.index] = tex[:, :, :3][::scale, ::scale] + setattr(Terrain, key.upper(), mat.index) + self.textures = lookup + + def generate_all_maps(self): + '''Generates NMAPS maps according to generate_map + + Provides additional utilities for saving to .npy and rendering png previews''' + + config = self.config + + #Only generate if maps are not cached + path_maps = os.path.join(config.PATH_CWD, config.PATH_MAPS) + os.makedirs(path_maps, exist_ok=True) + + if not config.MAP_FORCE_GENERATION and os.listdir(path_maps): + # check if the folder has all the required maps + all_maps_exist = True + for idx in range(config.MAP_N, -1, -1): + map_file = path_maps + '/map' + str(idx+1) + '/map.npy' + if not os.path.exists(map_file): + # override MAP_FORCE_GENERATION = FALSE and generate maps + all_maps_exist = False + break + + # do not generate maps if all maps exist + if all_maps_exist: + 
return + + if __debug__: + logging.info('Generating %s maps', str(config.MAP_N)) + + for idx in range(config.MAP_N): + path = path_maps + '/map' + str(idx+1) + os.makedirs(path, exist_ok=True) + + terrain, tiles = self.generate_map(idx) + + #Save/render + Save.as_numpy(tiles, path) + if config.MAP_GENERATE_PREVIEWS: + b = config.MAP_BORDER + tiles = [e[b:-b+1] for e in tiles][b:-b+1] + Save.fractal(terrain, path+'/fractal.png') + Save.render(tiles, self.textures, path+'/map.png') + + def generate_map(self, idx): + '''Generate a single map + + The default method is a relatively complex multiscale perlin noise method. + This is not just standard multioctave noise -- we are seeding multioctave noise + itself with perlin noise to create localized deviations in scale, plus additional + biasing to decrease terrain frequency towards the center of the map + + We found that this creates more visually interesting terrain and more deviation in + required planning horizon across different parts of the map. This is by no means a + gold-standard: you are free to override this method and create customized terrain + generation more suitable for your application. 
Simply pass MAP_GENERATOR=YourMapGenClass + as a config argument.''' + config = self.config + if config.TERRAIN_SYSTEM_ENABLED: + if not hasattr(self, 'interpolaters'): + self.interpolaters = None + terrain, tiles, _ = Terrain.generate_terrain(config, idx, self.interpolaters) + else: + size = config.MAP_SIZE + terrain = np.zeros((size, size)) + tiles = np.zeros((size, size), dtype=object) + + for r in range(size): + for c in range(size): + linf = max(abs(r - size//2), abs(c - size//2)) + if linf <= size//2 - config.MAP_BORDER: + tiles[r, c] = Terrain.GRASS + else: + tiles[r, c] = Terrain.LAVA + + if config.PROFESSION_SYSTEM_ENABLED: + spawn_profession_resources(config, tiles) + + return terrain, tiles diff --git a/nmmo/core/tile.py b/nmmo/core/tile.py index 6f1828f7e..a4dc7b19d 100644 --- a/nmmo/core/tile.py +++ b/nmmo/core/tile.py @@ -1,93 +1,99 @@ -from pdb import set_trace as T +from types import SimpleNamespace import numpy as np -import nmmo +from nmmo.datastore.serialized import SerializedState from nmmo.lib import material -class Tile: - def __init__(self, config, realm, r, c): - self.config = config - self.realm = realm - - self.serialized = 'R{}-C{}'.format(r, c) - - self.r = nmmo.Serialized.Tile.R(realm.dataframe, self.serial, r) - self.c = nmmo.Serialized.Tile.C(realm.dataframe, self.serial, c) - self.nEnts = nmmo.Serialized.Tile.NEnts(realm.dataframe, self.serial) - self.index = nmmo.Serialized.Tile.Index(realm.dataframe, self.serial, 0) - - realm.dataframe.init(nmmo.Serialized.Tile, self.serial, (r, c)) - - @property - def serial(self): - return self.serialized - - @property - def repr(self): - return ((self.r, self.c)) - - @property - def pos(self): - return self.r.val, self.c.val - - @property - def habitable(self): - return self.mat in material.Habitable - - @property - def vacant(self): - return len(self.ents) == 0 and self.habitable - - @property - def occupied(self): - return not self.vacant - - @property - def impassible(self): - return self.mat 
in material.Impassible - - @property - def lava(self): - return self.mat == material.Lava - - def reset(self, mat, config): - self.state = mat(config) - self.mat = mat(config) - - self.depleted = False - self.tex = mat.tex - self.ents = {} - - self.nEnts.update(0) - self.index.update(self.state.index) - - def addEnt(self, ent): - assert ent.entID not in self.ents - self.nEnts.update(1) - self.ents[ent.entID] = ent - - def delEnt(self, entID): - assert entID in self.ents - self.nEnts.update(0) - del self.ents[entID] - - def step(self): - if not self.depleted or np.random.rand() > self.mat.respawn: - return - - self.depleted = False - self.state = self.mat - - self.index.update(self.state.index) - - def harvest(self, deplete): - if __debug__: - assert not self.depleted, f'{self.state} is depleted' - assert self.state in material.Harvestable, f'{self.state} not harvestable' - - if deplete: - self.depleted = True - self.state = self.mat.deplete(self.config) - self.index.update(self.state.index) - - return self.mat.harvest() +# pylint: disable=no-member +TileState = SerializedState.subclass( + "Tile", [ + "row", + "col", + "material_id", + ]) + +TileState.Limits = lambda config: { + "row": (0, config.MAP_SIZE-1), + "col": (0, config.MAP_SIZE-1), + "material_id": (0, config.MAP_N_TILE), +} + +TileState.Query = SimpleNamespace( + window=lambda ds, r, c, radius: ds.table("Tile").window( + TileState.State.attr_name_to_col["row"], + TileState.State.attr_name_to_col["col"], + r, c, radius), +) + +class Tile(TileState): + def __init__(self, realm, r, c): + super().__init__(realm.datastore, TileState.Limits(realm.config)) + self.realm = realm + self.config = realm.config + + self.row.update(r) + self.col.update(c) + + self.state = None + self.material = None + self.depleted = False + self.tex = None + + self.entities = {} + + @property + def repr(self): + return ((self.row.val, self.col.val)) + + @property + def pos(self): + return self.row.val, self.col.val + + @property + def 
habitable(self): + return self.material in material.Habitable + + @property + def impassible(self): + return self.material in material.Impassible + + @property + def lava(self): + return self.material == material.Lava + + def reset(self, mat, config): + self.state = mat(config) + self.material = mat(config) + self.material_id.update(self.state.index) + + self.depleted = False + self.tex = self.material.tex + + self.entities = {} + + def add_entity(self, ent): + assert ent.ent_id not in self.entities + self.entities[ent.ent_id] = ent + + def remove_entity(self, ent_id): + assert ent_id in self.entities + del self.entities[ent_id] + + def step(self): + if not self.depleted or np.random.rand() > self.material.respawn: + return + + self.depleted = False + self.state = self.material + self.material_id.update(self.state.index) + + def harvest(self, deplete): + assert not self.depleted, f'{self.state} is depleted' + assert self.state in material.Harvestable, f'{self.state} not harvestable' + + if deplete: + self.depleted = True + self.state = self.material.deplete(self.config) + self.material_id.update(self.state.index) + + return self.material.harvest() diff --git a/nmmo/datastore/__init__.py b/nmmo/datastore/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/nmmo/datastore/datastore.py b/nmmo/datastore/datastore.py new file mode 100644 index 000000000..44e652b72 --- /dev/null +++ b/nmmo/datastore/datastore.py @@ -0,0 +1,92 @@ +from __future__ import annotations +from typing import Dict, List +from nmmo.datastore.id_allocator import IdAllocator + +""" +This code defines a data storage system that allows for the +creation, manipulation, and querying of records. + +The DataTable class serves as the foundation for the data +storage, providing methods for updating and retrieving data, +as well as filtering and querying records. 
+ +The DatastoreRecord class represents a single record within +a table and provides a simple interface for interacting with +the data. The Datastore class serves as the main entry point +for the data storage system, allowing for the creation and +management of tables and records. + +The implementation of the DataTable class is left to the +developer, but the DatastoreRecord and Datastore classes +should be sufficient for most use cases. + +See numpy_datastore.py for an implementation. +""" +class DataTable: + def __init__(self, num_columns: int): + self._num_columns = num_columns + self._id_allocator = IdAllocator(100) + + def reset(self): + self._id_allocator = IdAllocator(100) + + def update(self, row_id: int, col: int, value): + raise NotImplementedError + + def get(self, ids: List[id]): + raise NotImplementedError + + def where_in(self, col: int, values: List): + raise NotImplementedError + + def where_eq(self, col: str, value): + raise NotImplementedError + + def where_neq(self, col: str, value): + raise NotImplementedError + + def window(self, row_idx: int, col_idx: int, row: int, col: int, radius: int): + raise NotImplementedError + + def remove_row(self, row_id: int): + raise NotImplementedError + + def add_row(self) -> int: + raise NotImplementedError + + def is_empty(self) -> bool: + raise NotImplementedError + +class DatastoreRecord: + def __init__(self, datastore, table: DataTable, row_id: int) -> None: + self.datastore = datastore + self.table = table + self.id = row_id + + def update(self, col: int, value): + self.table.update(self.id, col, value) + + def get(self, col: int): + return self.table.get(self.id)[col] + + def delete(self): + self.table.remove_row(self.id) + +class Datastore: + def __init__(self) -> None: + self._tables: Dict[str, DataTable] = {} + + def register_object_type(self, object_type: str, num_colums: int): + if object_type not in self._tables: + self._tables[object_type] = self._create_table(num_colums) + + def 
create_record(self, object_type: str) -> DatastoreRecord: + table = self._tables[object_type] + row_id = table.add_row() + return DatastoreRecord(self, table, row_id) + + def table(self, object_type: str) -> DataTable: + return self._tables[object_type] + + def _create_table(self, num_columns: int) -> DataTable: + raise NotImplementedError diff --git a/nmmo/datastore/id_allocator.py b/nmmo/datastore/id_allocator.py new file mode 100644 index 000000000..a93e8c1f1 --- /dev/null +++ b/nmmo/datastore/id_allocator.py @@ -0,0 +1,21 @@ +from ordered_set import OrderedSet + +class IdAllocator: + def __init__(self, max_id): + # Key 0 is reserved as padding + self.max_id = 1 + self.free = OrderedSet() + self.expand(max_id) + + def full(self): + return len(self.free) == 0 + + def remove(self, row_id): + self.free.add(row_id) + + def allocate(self): + return self.free.pop(0) + + def expand(self, max_id): + self.free.update(OrderedSet(range(self.max_id, max_id))) + self.max_id = max_id diff --git a/nmmo/datastore/numpy_datastore.py b/nmmo/datastore/numpy_datastore.py new file mode 100644 index 000000000..e737ad9cd --- /dev/null +++ b/nmmo/datastore/numpy_datastore.py @@ -0,0 +1,70 @@ +from typing import List + +import numpy as np + +from nmmo.datastore.datastore import Datastore, DataTable + + +class NumpyTable(DataTable): + def __init__(self, num_columns: int, initial_size: int, dtype=np.float32): + super().__init__(num_columns) + self._dtype = dtype + self._initial_size = initial_size + self._max_rows = 0 + self._data = np.zeros((0, self._num_columns), dtype=self._dtype) + self._expand(self._initial_size) + + def reset(self): + super().reset() # resetting _id_allocator + self._max_rows = 0 + self._data = np.zeros((0, self._num_columns), dtype=self._dtype) + self._expand(self._initial_size) + + def update(self, row_id: int, col: int, value): + self._data[row_id, col] = value + + def get(self, ids: List[int]): + return self._data[ids] + + def where_eq(self, col: int, value): + 
return self._data[self._data[:,col] == value] + + def where_neq(self, col: int, value): + return self._data[self._data[:,col] != value] + + def where_in(self, col: int, values: List): + return self._data[np.isin(self._data[:,col], values)] + + def window(self, row_idx: int, col_idx: int, row: int, col: int, radius: int): + return self._data[( + (np.abs(self._data[:,row_idx] - row) <= radius) & + (np.abs(self._data[:,col_idx] - col) <= radius) + ).ravel()] + + def add_row(self) -> int: + if self._id_allocator.full(): + self._expand(self._max_rows * 2) + row_id = self._id_allocator.allocate() + return row_id + + def remove_row(self, row_id: int) -> int: + self._id_allocator.remove(row_id) + self._data[row_id] = 0 + + def _expand(self, max_rows: int): + assert max_rows > self._max_rows + data = np.zeros((max_rows, self._num_columns), dtype=self._dtype) + data[:self._max_rows] = self._data + self._max_rows = max_rows + self._id_allocator.expand(max_rows) + self._data = data + + def is_empty(self) -> bool: + all_data_zero = np.sum(self._data)==0 + # 0th row is reserved as padding, so # of free ids is _max_rows-1 + all_id_free = len(self._id_allocator.free) == self._max_rows-1 + return all_data_zero and all_id_free + +class NumpyDatastore(Datastore): + def _create_table(self, num_columns: int) -> DataTable: + return NumpyTable(num_columns, 100) diff --git a/nmmo/datastore/serialized.py b/nmmo/datastore/serialized.py new file mode 100644 index 000000000..652280292 --- /dev/null +++ b/nmmo/datastore/serialized.py @@ -0,0 +1,120 @@ +from __future__ import annotations +from ast import Tuple + +import math +from types import SimpleNamespace +from typing import Dict, List +from nmmo.datastore.datastore import Datastore, DatastoreRecord + +""" +This code defines classes for serializing and deserializing data +in a structured way. 
+ +The SerializedAttribute class represents a single attribute of a +record and provides methods for updating and querying its value, +as well as enforcing minimum and maximum bounds on the value. + +The SerializedState class serves as a base class for creating +serialized representations of specific types of data, using a +list of attribute names to define the structure of the data. +The subclass method is a factory method for creating subclasses +of SerializedState that are tailored to specific types of data. +""" + +class SerializedAttribute(): + def __init__(self, + name: str, + datastore_record: DatastoreRecord, + column: int, min_val=-math.inf, max_val=math.inf) -> None: + self._name = name + self.datastore_record = datastore_record + self._column = column + self._min = min_val + self._max = max_val + self._val = 0 + + @property + def val(self): + return self._val + + def update(self, value): + value = min(self._max, max(self._min, value)) + + self.datastore_record.update(self._column, value) + self._val = value + + @property + def min(self): + return self._min + + @property + def max(self): + return self._max + + def increment(self, val=1, max_v=math.inf): + self.update(min(max_v, self.val + val)) + return self + + def decrement(self, val=1, min_v=-math.inf): + self.update(max(min_v, self.val - val)) + return self + + @property + def empty(self): + return self.val == 0 + + def __eq__(self, other): + return self.val == other + + def __ne__(self, other): + return self.val != other + + def __lt__(self, other): + return self.val < other + + def __le__(self, other): + return self.val <= other + + def __gt__(self, other): + return self.val > other + + def __ge__(self, other): + return self.val >= other + +class SerializedState(): + @staticmethod + def subclass(name: str, attributes: List[str]): + class Subclass(SerializedState): + _name = name + State = SimpleNamespace( + attr_name_to_col = {a: i for i, a in enumerate(attributes)}, + num_attributes = 
len(attributes), + table = lambda ds: ds.table(name) + ) + + def __init__(self, datastore: Datastore, + limits: Dict[str, Tuple[float, float]] = None): + + limits = limits or {} + self.datastore_record = datastore.create_record(name) + + for attr, col in self.State.attr_name_to_col.items(): + try: + setattr(self, attr, + SerializedAttribute(attr, self.datastore_record, col, + *limits.get(attr, (-math.inf, math.inf)))) + except Exception as exc: + raise RuntimeError('Failed to set attribute' + attr) from exc + + @classmethod + def parse_array(cls, data) -> SimpleNamespace: + # Takes in a data array and returns a SimpleNamespace object with + # attribute names as keys and corresponding values from the input + # data array. + assert len(data) == cls.State.num_attributes, \ + f"Expected {cls.State.num_attributes} attributes, got {len(data)}" + return SimpleNamespace(**{ + attr: data[col] for attr, col in cls.State.attr_name_to_col.items() + }) + + return Subclass diff --git a/nmmo/emulation.py b/nmmo/emulation.py deleted file mode 100644 index 9ebd1e6e6..000000000 --- a/nmmo/emulation.py +++ /dev/null @@ -1,139 +0,0 @@ -from pdb import set_trace as T -import numpy as np - -from collections import defaultdict -import itertools - -import gym - -import nmmo -from nmmo.infrastructure import DataType - -class SingleAgentEnv: - def __init__(self, env, idx, max_idx): - self.config = env.config - self.env = env - self.idx = idx - self.last = idx == max_idx - - def reset(self): - if not self.env.has_reset: - self.obs = self.env.reset() - - return self.obs[self.idx] - - def step(self, actions): - if self.last: - self.obs, self.rewards, self.dones, self.infos = self.env.step(actions) - - i = self.idx - return self.obs[i], self.rewards[i], self.dones[i], self.infos[i] - -def multiagent_to_singleagent(config): - assert config.EMULATE_CONST_PLAYER_N, "Wrapper requires constant num agents" - - base_env = nmmo.Env(config) - n = config.PLAYER_N - - return [SingleAgentEnv(base_env, i, 
n) for i in range(1, n+1)] - -def pad_const_nent(config, dummy_ob, obs, rewards, dones, infos): - for i in range(1, config.PLAYER_N+1): - if i not in obs: - obs[i] = dummy_ob - rewards[i] = 0 - infos[i] = {} - dones[i] = False - -def const_horizon(dones): - for agent in dones: - dones[agent] = True - - return dones - -def pack_atn_space(config): - actions = defaultdict(dict) - for atn in sorted(nmmo.Action.edges(config)): - for arg in sorted(atn.edges): - actions[atn][arg] = arg.N(config) - - n = 0 - flat_actions = {} - for atn, args in actions.items(): - ranges = [range(e) for e in args.values()] - for vals in itertools.product(*ranges): - flat_actions[n] = {atn: {arg: val for arg, val in zip(args, vals)}} - n += 1 - - return flat_actions - -def pack_obs_space(observation): - n = 0 - #for entity, obs in observation.items(): - for entity in observation: - obs = observation[entity] - #for attr_name, attr_box in obs.items(): - for attr_name in obs: - attr_box = obs[attr_name] - n += np.prod(observation[entity][attr_name].shape) - - return gym.spaces.Box( - low=-2**20, high=2**20, - shape=(int(n),), dtype=DataType.CONTINUOUS) - - -def batch_obs(config, obs): - batched = {} - for (entity_name,), entity in nmmo.io.stimulus.Serialized: - if not entity.enabled(config): - continue - - batched[entity_name] = {} - for dtype in 'Continuous Discrete N'.split(): - attr_obs = [obs[k][entity_name][dtype] for k in obs] - batched[entity_name][dtype] = np.stack(attr_obs, 0) - - return batched - -def pack_obs(obs): - packed = {} - for key in obs: - ary = [] - obs[key].items() - for ent_name, ent_attrs in obs[key].items(): - for attr_name, attr in ent_attrs.items(): - ary.append(attr.ravel()) - packed[key] = np.concatenate(ary) - - return packed - -def unpack_obs(config, packed_obs): - obs, idx = {}, 0 - batch = len(packed_obs) - for (entity_name,), entity in nmmo.io.stimulus.Serialized: - if not entity.enabled(config): - continue - - n_entity = entity.N(config) - n_continuous, 
n_discrete = 0, 0 - obs[entity_name] = {} - - for attribute_name, attribute in entity: - if attribute.CONTINUOUS: - n_continuous += 1 - if attribute.DISCRETE: - n_discrete += 1 - - inc = int(n_entity * n_continuous) - obs[entity_name]['Continuous'] = packed_obs[:, idx: idx + inc].reshape(batch, n_entity, n_continuous) - idx += inc - - inc = int(n_entity * n_discrete) - obs[entity_name]['Discrete'] = packed_obs[:, idx: idx + inc].reshape(batch, n_entity, n_discrete) - idx += inc - - inc = 1 - obs[entity_name]['N'] = packed_obs[:, idx: idx + inc].reshape(batch, 1) - idx += inc - - return obs diff --git a/nmmo/entity/entity.py b/nmmo/entity/entity.py index d4ae1e399..13a523b93 100644 --- a/nmmo/entity/entity.py +++ b/nmmo/entity/entity.py @@ -1,253 +1,344 @@ -from pdb import set_trace as T + +import math +from types import SimpleNamespace + import numpy as np -import nmmo -from nmmo.systems import skill, droptable, combat, equipment, inventory -from nmmo.lib import material, utils +from nmmo.core.config import Config +from nmmo.datastore.serialized import SerializedState +from nmmo.systems import inventory +from nmmo.lib.log import EventCode + +# pylint: disable=no-member +EntityState = SerializedState.subclass( + "Entity", [ + "id", + "npc_type", # 1 - passive, 2 - neutral, 3 - aggressive + "row", + "col", + + # Status + "damage", + "time_alive", + "freeze", + "item_level", + "attacker_id", + "latest_combat_tick", + "message", + + # Resources + "gold", + "health", + "food", + "water", + + # Combat + "melee_level", + "range_level", + "mage_level", + + # Skills + "fishing_level", + "herbalism_level", + "prospecting_level", + "carving_level", + "alchemy_level", + ]) + +EntityState.Limits = lambda config: { + **{ + "id": (-math.inf, math.inf), + "npc_type": (0, 4), + "row": (0, config.MAP_SIZE-1), + "col": (0, config.MAP_SIZE-1), + "damage": (0, math.inf), + "time_alive": (0, math.inf), + "freeze": (0, 3), + "item_level": (0, 5*config.NPC_LEVEL_MAX), + "attacker_id": 
(-np.inf, math.inf), + "latest_combat_tick": (0, math.inf), + "health": (0, config.PLAYER_BASE_HEALTH), + }, + **({ + "message": (0, config.COMMUNICATION_NUM_TOKENS), + } if config.COMMUNICATION_SYSTEM_ENABLED else {}), + **({ + "gold": (0, math.inf), + "food": (0, config.RESOURCE_BASE), + "water": (0, config.RESOURCE_BASE), + } if config.RESOURCE_SYSTEM_ENABLED else {}), + **({ + "melee_level": (0, config.PROGRESSION_LEVEL_MAX), + "range_level": (0, config.PROGRESSION_LEVEL_MAX), + "mage_level": (0, config.PROGRESSION_LEVEL_MAX), + "fishing_level": (0, config.PROGRESSION_LEVEL_MAX), + "herbalism_level": (0, config.PROGRESSION_LEVEL_MAX), + "prospecting_level": (0, config.PROGRESSION_LEVEL_MAX), + "carving_level": (0, config.PROGRESSION_LEVEL_MAX), + "alchemy_level": (0, config.PROGRESSION_LEVEL_MAX), + } if config.PROGRESSION_SYSTEM_ENABLED else {}), +} + +EntityState.Query = SimpleNamespace( + # Whole table + table=lambda ds: ds.table("Entity").where_neq( + EntityState.State.attr_name_to_col["id"], 0), + + # Single entity + by_id=lambda ds, id: ds.table("Entity").where_eq( + EntityState.State.attr_name_to_col["id"], id)[0], + + # Multiple entities + by_ids=lambda ds, ids: ds.table("Entity").where_in( + EntityState.State.attr_name_to_col["id"], ids), + + # Entities in a radius + window=lambda ds, r, c, radius: ds.table("Entity").window( + EntityState.State.attr_name_to_col["row"], + EntityState.State.attr_name_to_col["col"], + r, c, radius), +) class Resources: - def __init__(self, ent): - self.health = nmmo.Serialized.Entity.Health(ent.dataframe, ent.entID) - self.water = nmmo.Serialized.Entity.Water( ent.dataframe, ent.entID) - self.food = nmmo.Serialized.Entity.Food( ent.dataframe, ent.entID) + def __init__(self, ent, config): + self.config = config + self.health = ent.health + self.water = ent.water + self.food = ent.food - def update(self, realm, entity, actions): - config = realm.config + self.health.update(config.PLAYER_BASE_HEALTH) + if 
config.RESOURCE_SYSTEM_ENABLED: + self.water.update(config.RESOURCE_BASE) + self.food.update(config.RESOURCE_BASE) - if not config.RESOURCE_SYSTEM_ENABLED: - return + def update(self): + if not self.config.RESOURCE_SYSTEM_ENABLED: + return - self.water.max = config.RESOURCE_BASE - self.food.max = config.RESOURCE_BASE + regen = self.config.RESOURCE_HEALTH_RESTORE_FRACTION + thresh = self.config.RESOURCE_HEALTH_REGEN_THRESHOLD - regen = config.RESOURCE_HEALTH_RESTORE_FRACTION - thresh = config.RESOURCE_HEALTH_REGEN_THRESHOLD + food_thresh = self.food > thresh * self.config.RESOURCE_BASE + water_thresh = self.water > thresh * self.config.RESOURCE_BASE - food_thresh = self.food > thresh * config.RESOURCE_BASE - water_thresh = self.water > thresh * config.RESOURCE_BASE + if food_thresh and water_thresh: + restore = np.floor(self.health.max * regen) + self.health.increment(restore) - if food_thresh and water_thresh: - restore = np.floor(self.health.max * regen) - self.health.increment(restore) + if self.food.empty: + self.health.decrement(self.config.RESOURCE_STARVATION_RATE) - if self.food.empty: - self.health.decrement(config.RESOURCE_STARVATION_RATE) + if self.water.empty: + self.health.decrement(self.config.RESOURCE_DEHYDRATION_RATE) - if self.water.empty: - self.health.decrement(config.RESOURCE_DEHYDRATION_RATE) - - def packet(self): - data = {} - data['health'] = self.health.packet() - data['food'] = self.food.packet() - data['water'] = self.water.packet() - return data + def packet(self): + data = {} + data['health'] = { 'val': self.health.val, 'max': self.config.PLAYER_BASE_HEALTH } + data['food'] = { 'val': self.food.val, 'max': self.config.RESOURCE_BASE } + data['water'] = { 'val': self.water.val, 'max': self.config.RESOURCE_BASE } + return data class Status: - def __init__(self, ent): - self.config = ent.config - self.freeze = nmmo.Serialized.Entity.Freeze(ent.dataframe, ent.entID) - - def update(self, realm, entity, actions): - self.freeze.decrement() - - def 
packet(self): - data = {} - data['freeze'] = self.freeze.val - return data - -class History: - def __init__(self, ent): - self.actions = {} - self.attack = None - - self.origPos = ent.pos - self.exploration = 0 - self.playerKills = 0 - - self.damage_received = 0 - self.damage_inflicted = 0 - - self.damage = nmmo.Serialized.Entity.Damage( ent.dataframe, ent.entID) - self.timeAlive = nmmo.Serialized.Entity.TimeAlive(ent.dataframe, ent.entID) - - self.lastPos = None - - def update(self, realm, entity, actions): - self.attack = None - self.damage.update(0) - - self.actions = {} - if entity.entID in actions: - self.actions = actions[entity.entID] - - exploration = utils.linf(entity.pos, self.origPos) - self.exploration = max(exploration, self.exploration) - - self.timeAlive.increment() - - def packet(self): - data = {} - data['damage'] = self.damage.val - data['timeAlive'] = self.timeAlive.val - data['damage_inflicted'] = self.damage_inflicted - data['damage_received'] = self.damage_received - - if self.attack is not None: - data['attack'] = self.attack - - actions = {} - for atn, args in self.actions.items(): - atn_packet = {} - - #Avoid recursive player packet - if atn.__name__ == 'Attack': - continue - - for key, val in args.items(): - if hasattr(val, 'packet'): - atn_packet[key.__name__] = val.packet - else: - atn_packet[key.__name__] = val.__name__ - actions[atn.__name__] = atn_packet - data['actions'] = actions - - return data - -class Base: - def __init__(self, ent, pos, iden, name, color, pop): - self.name = name + str(iden) - self.color = color - r, c = pos - - self.r = nmmo.Serialized.Entity.R(ent.dataframe, ent.entID, r) - self.c = nmmo.Serialized.Entity.C(ent.dataframe, ent.entID, c) - - self.population = nmmo.Serialized.Entity.Population(ent.dataframe, ent.entID, pop) - self.self = nmmo.Serialized.Entity.Self( ent.dataframe, ent.entID, 1) - self.identity = nmmo.Serialized.Entity.ID( ent.dataframe, ent.entID, ent.entID) - self.level = 
nmmo.Serialized.Entity.Level( ent.dataframe, ent.entID, 3) - self.item_level = nmmo.Serialized.Entity.ItemLevel( ent.dataframe, ent.entID, 0) - self.gold = nmmo.Serialized.Entity.Gold( ent.dataframe, ent.entID, 0) - self.comm = nmmo.Serialized.Entity.Comm( ent.dataframe, ent.entID, 0) - - ent.dataframe.init(nmmo.Serialized.Entity, ent.entID, (r, c)) - - def update(self, realm, entity, actions): - self.level.update(combat.level(entity.skills)) - - if realm.config.EQUIPMENT_SYSTEM_ENABLED: - self.item_level.update(entity.equipment.total(lambda e: e.level)) - - if realm.config.EXCHANGE_SYSTEM_ENABLED: - self.gold.update(entity.inventory.gold.quantity.val) - - @property - def pos(self): - return self.r.val, self.c.val - - def packet(self): - data = {} - - data['r'] = self.r.val - data['c'] = self.c.val - data['name'] = self.name - data['level'] = self.level.val - data['item_level'] = self.item_level.val - data['color'] = self.color.packet() - data['population'] = self.population.val - data['self'] = self.self.val - - return data - -class Entity: - def __init__(self, realm, pos, iden, name, color, pop): - self.realm = realm - self.dataframe = realm.dataframe - self.config = realm.config - - self.policy = name - self.entID = iden - self.repr = None - self.vision = 5 + def __init__(self, ent): + self.freeze = ent.freeze - self.attacker = None - self.target = None - self.closest = None - self.spawnPos = pos - - self.attackerID = nmmo.Serialized.Entity.AttackerID(self.dataframe, self.entID, 0) - - #Submodules - self.base = Base(self, pos, iden, name, color, pop) - self.status = Status(self) - self.history = History(self) - self.resources = Resources(self) + def update(self): + if self.freeze.val > 0: + self.freeze.decrement(1) - self.inventory = inventory.Inventory(realm, self) + def packet(self): + data = {} + data['freeze'] = self.freeze.val + return data - def packet(self): - data = {} - - data['status'] = self.status.packet() - data['history'] = self.history.packet() - 
data['inventory'] = self.inventory.packet() - data['alive'] = self.alive - return data - - def update(self, realm, actions): - '''Update occurs after actions, e.g. does not include history''' - if self.history.damage == 0: - self.attacker = None - self.attackerID.update(0) - - self.base.update(realm, self, actions) - self.status.update(realm, self, actions) - self.history.update(realm, self, actions) +# NOTE: History.packet() is actively used in visulazing attacks +class History: + def __init__(self, ent): + self.actions = {} + self.attack = None + + self.starting_position = ent.pos + self.exploration = 0 + self.player_kills = 0 + + self.damage_received = 0 + self.damage_inflicted = 0 + + self.damage = ent.damage + self.time_alive = ent.time_alive + + self.last_pos = None + + def update(self, entity, actions): + self.attack = None + self.damage.update(0) + + self.actions = {} + if entity.ent_id in actions: + self.actions = actions[entity.ent_id] + + self.time_alive.increment() + + def packet(self): + data = {} + data['damage'] = self.damage.val + data['timeAlive'] = self.time_alive.val + data['damage_inflicted'] = self.damage_inflicted + data['damage_received'] = self.damage_received + + if self.attack is not None: + data['attack'] = self.attack + + # NOTE: the client seems to use actions for visualization + # but produces errors with the new actions. 
So we comment out these for now + # actions = {} + # for atn, args in self.actions.items(): + # atn_packet = {} + + # # Avoid recursive player packet + # if atn.__name__ == 'Attack': + # continue + + # for key, val in args.items(): + # if hasattr(val, 'packet'): + # atn_packet[key.__name__] = val.packet + # else: + # atn_packet[key.__name__] = val.__name__ + # actions[atn.__name__] = atn_packet + # data['actions'] = actions + data['actions'] = {} + + return data + +# pylint: disable=no-member +class Entity(EntityState): + def __init__(self, realm, pos, entity_id, name): + super().__init__(realm.datastore, EntityState.Limits(realm.config)) + + self.realm = realm + self.config: Config = realm.config + + self.policy = name + self.entity_id = entity_id + self.repr = None + + self.name = name + str(entity_id) + + self.row.update(pos[0]) + self.col.update(pos[1]) + self.id.update(entity_id) + + self.vision = self.config.PLAYER_VISION_RADIUS + + self.attacker = None + self.target = None + self.closest = None + self.spawn_pos = pos + + # Submodules + self.status = Status(self) + self.history = History(self) + self.resources = Resources(self, self.config) + self.inventory = inventory.Inventory(realm, self) + + @property + def ent_id(self): + return self.id.val + + def packet(self): + data = {} + + data['status'] = self.status.packet() + data['history'] = self.history.packet() + data['inventory'] = self.inventory.packet() + data['alive'] = self.alive + data['base'] = { + 'r': self.row.val, + 'c': self.col.val, + 'name': self.name, + 'level': self.attack_level, + 'item_level': self.item_level.val, + } + + return data + + def update(self, realm, actions): + '''Update occurs after actions, e.g. 
does not include history''' + if self.history.damage == 0: + self.attacker = None + self.attacker_id.update(0) + + if realm.config.EQUIPMENT_SYSTEM_ENABLED: + self.item_level.update(self.equipment.total(lambda e: e.level)) + + self.status.update() + self.history.update(self, actions) + + # Returns True if the entity is alive + def receive_damage(self, source, dmg): + self.history.damage_received += dmg + self.history.damage.update(dmg) + self.resources.health.decrement(dmg) + + if self.alive: + return True - def receiveDamage(self, source, dmg): - self.history.damage_received += dmg - self.history.damage.update(dmg) - self.resources.health.decrement(dmg) + # at this point, self is dead + if source: + source.history.player_kills += 1 + self.realm.event_log.record(EventCode.PLAYER_KILL, source, target=self) + + # if self is dead, unlist its items from the market regardless of looting + if self.config.EXCHANGE_SYSTEM_ENABLED: + for item in list(self.inventory.items): + self.realm.exchange.unlist_item(item) + + # if self is dead but no one can loot, destroy its items + if source is None or not source.is_player: # nobody or npcs cannot loot + if self.config.ITEM_SYSTEM_ENABLED: + for item in list(self.inventory.items): + item.destroy() + return False - if self.alive: - return True + # now, source can loot the dead self + return False - if source is None: - return True + # pylint: disable=unused-argument + def apply_damage(self, dmg, style): + self.history.damage_inflicted += dmg - if not source.isPlayer: - return True + @property + def pos(self): + return int(self.row.val), int(self.col.val) + @property + def alive(self): + if self.resources.health.empty: return False - def applyDamage(self, dmg, style): - self.history.damage_inflicted += dmg + return True - @property - def pos(self): - return self.base.pos + @property + def is_player(self) -> bool: + return False - @property - def alive(self): - if self.resources.health.empty: - return False + @property + def 
is_npc(self) -> bool: + return False - return True + @property + def attack_level(self) -> int: + melee = self.skills.melee.level.val + ranged = self.skills.range.level.val + mage = self.skills.mage.level.val - @property - def isPlayer(self) -> bool: - return False + return int(max(melee, ranged, mage)) - @property - def isNPC(self) -> bool: + @property + def in_combat(self) -> bool: + # NOTE: the initial latest_combat_tick is 0, and valid values are greater than 0 + if not self.config.COMBAT_SYSTEM_ENABLED or self.latest_combat_tick.val == 0: return False - @property - def level(self) -> int: - melee = self.skills.melee.level.val - ranged = self.skills.range.level.val - mage = self.skills.mage.level.val - - return int(max(melee, ranged, mage)) + return (self.realm.tick - self.latest_combat_tick.val) < self.config.COMBAT_STATUS_DURATION diff --git a/nmmo/entity/entity_manager.py b/nmmo/entity/entity_manager.py new file mode 100644 index 000000000..7315097bc --- /dev/null +++ b/nmmo/entity/entity_manager.py @@ -0,0 +1,156 @@ +from collections.abc import Mapping +from typing import Dict + +import numpy as np +from ordered_set import OrderedSet + +from nmmo.entity.entity import Entity +from nmmo.entity.npc import NPC +from nmmo.entity.player import Player +from nmmo.lib import spawn +from nmmo.systems import combat + + +class EntityGroup(Mapping): + def __init__(self, realm): + self.datastore = realm.datastore + self.realm = realm + self.config = realm.config + + self.entities: Dict[int, Entity] = {} + self.dead: Dict[int, Entity] = {} + + def __len__(self): + return len(self.entities) + + def __contains__(self, e): + return e in self.entities + + def __getitem__(self, key) -> Entity: + return self.entities[key] + + def __iter__(self) -> Entity: + yield from self.entities + + def items(self): + return self.entities.items() + + @property + def corporeal(self): + return {**self.entities, **self.dead} + + @property + def packet(self): + return {k: v.packet() for k, v in 
self.corporeal.items()} + + def reset(self): + for ent in self.entities.values(): + ent.datastore_record.delete() + + self.entities = {} + self.dead = {} + + def spawn(self, entity): + pos, ent_id = entity.pos, entity.id.val + self.realm.map.tiles[pos].add_entity(entity) + self.entities[ent_id] = entity + + def cull(self): + self.dead = {} + for ent_id in list(self.entities): + player = self.entities[ent_id] + if not player.alive: + r, c = player.pos + ent_id = player.ent_id + self.dead[ent_id] = player + + self.realm.map.tiles[r, c].remove_entity(ent_id) + self.entities[ent_id].datastore_record.delete() + del self.entities[ent_id] + + return self.dead + + def update(self, actions): + for entity in self.entities.values(): + entity.update(self.realm, actions) + + +class NPCManager(EntityGroup): + def __init__(self, realm): + super().__init__(realm) + self.next_id = -1 + self.spawn_dangers = [] + + def reset(self): + super().reset() + self.next_id = -1 + self.spawn_dangers = [] + + def spawn(self): + config = self.config + + if not config.NPC_SYSTEM_ENABLED: + return + + for _ in range(config.NPC_SPAWN_ATTEMPTS): + if len(self.entities) >= config.NPC_N: + break + + if self.spawn_dangers: + danger = self.spawn_dangers[-1] + r, c = combat.spawn(config, danger) + else: + center = config.MAP_CENTER + border = self.config.MAP_BORDER + # pylint: disable=unbalanced-tuple-unpacking + r, c = np.random.randint(border, center+border, 2).tolist() + + npc = NPC.spawn(self.realm, (r, c), self.next_id) + if npc: + super().spawn(npc) + self.next_id -= 1 + + if self.spawn_dangers: + self.spawn_dangers.pop() + + def cull(self): + for entity in super().cull().values(): + self.spawn_dangers.append(entity.spawn_danger) + + def actions(self, realm): + actions = {} + for idx, entity in self.entities.items(): + actions[idx] = entity.decide(realm) + return actions + +class PlayerManager(EntityGroup): + def __init__(self, realm): + super().__init__(realm) + self.loader = 
self.realm.config.PLAYER_LOADER + self.agents = None + self.spawned = None + + def reset(self): + super().reset() + self.agents = self.loader(self.config) + self.spawned = OrderedSet() + + def spawn_individual(self, r, c, idx): + agent = next(self.agents) + agent = agent(self.config, idx) + player = Player(self.realm, (r, c), agent) + super().spawn(player) + + def spawn(self): + idx = 0 + for r, c in spawn.spawn_concurrent(self.config): + idx += 1 + + if idx in self.entities: + continue + + if idx in self.spawned: + continue + + self.spawned.add(idx) + self.spawn_individual(r, c, idx) diff --git a/nmmo/entity/npc.py b/nmmo/entity/npc.py index a9fec1a97..2e431644f 100644 --- a/nmmo/entity/npc.py +++ b/nmmo/entity/npc.py @@ -1,169 +1,183 @@ -from pdb import set_trace as T -import numpy as np import random -import nmmo from nmmo.entity import entity -from nmmo.systems import combat, equipment, ai, combat, skill -from nmmo.lib.colors import Neon -from nmmo.systems import item as Item -from nmmo.systems import droptable from nmmo.io import action as Action +from nmmo.systems import combat, droptable +from nmmo.systems.ai import policy +from nmmo.systems import item as Item +from nmmo.systems import skill +from nmmo.systems.inventory import EquipmentSlot class Equipment: - def __init__(self, total, - melee_attack, range_attack, mage_attack, - melee_defense, range_defense, mage_defense): + def __init__(self, total, + melee_attack, range_attack, mage_attack, + melee_defense, range_defense, mage_defense): - self.level = total - self.ammunition = None + self.level = total + self.ammunition = EquipmentSlot() - self.melee_attack = melee_attack - self.range_attack = range_attack - self.mage_attack = mage_attack - self.melee_defense = melee_defense - self.range_defense = range_defense - self.mage_defense = mage_defense + self.melee_attack = melee_attack + self.range_attack = range_attack + self.mage_attack = mage_attack + self.melee_defense = melee_defense + self.range_defense = 
range_defense + self.mage_defense = mage_defense - def total(self, getter): - return getter(self) + def total(self, getter): + return getter(self) - @property - def packet(self): - packet = {} + # pylint: disable=R0801 + # Similar lines here and in inventory.py + @property + def packet(self): + packet = {} - packet['item_level'] = self.total + packet['item_level'] = self.total + packet['melee_attack'] = self.melee_attack + packet['range_attack'] = self.range_attack + packet['mage_attack'] = self.mage_attack + packet['melee_defense'] = self.melee_defense + packet['range_defense'] = self.range_defense + packet['mage_defense'] = self.mage_defense - packet['melee_attack'] = self.melee_attack - packet['range_attack'] = self.range_attack - packet['mage_attack'] = self.mage_attack - packet['melee_defense'] = self.melee_defense - packet['range_defense'] = self.range_defense - packet['mage_defense'] = self.mage_defense - - return packet + return packet +# pylint: disable=no-member class NPC(entity.Entity): - def __init__(self, realm, pos, iden, name, color, pop): - super().__init__(realm, pos, iden, name, color, pop) - self.skills = skill.Combat(realm, self) - self.realm = realm - - - def update(self, realm, actions): - super().update(realm, actions) - - if not self.alive: - return + def __init__(self, realm, pos, iden, name, npc_type): + super().__init__(realm, pos, iden, name) + self.skills = skill.Combat(realm, self) + self.realm = realm + self.last_action = None + self.droptable = None + self.spawn_danger = None + self.equipment = None + self.npc_type.update(npc_type) + + def update(self, realm, actions): + super().update(realm, actions) + + if not self.alive: + return + + self.resources.health.increment(1) + self.last_action = actions + + # Returns True if the entity is alive + def receive_damage(self, source, dmg): + if super().receive_damage(source, dmg): + return True - self.resources.health.increment(1) - self.lastAction = actions + # run the next lines if the npc 
is killed + # source receive gold & items in the droptable + # pylint: disable=no-member + source.gold.increment(self.gold.val) + self.gold.update(0) - def receiveDamage(self, source, dmg): - if super().receiveDamage(source, dmg): - return True + # TODO(kywch): make source receive the highest-level items first + # because source cannot take it if the inventory is full + # Also, destroy the remaining items if the source cannot take those + for item in self.droptable.roll(self.realm, self.attack_level): + if source.inventory.space: + source.inventory.receive(item) - for item in self.droptable.roll(self.realm, self.level): - if source.inventory.space: - source.inventory.receive(item) + return False - @staticmethod - def spawn(realm, pos, iden): - config = realm.config + @staticmethod + def spawn(realm, pos, iden): + config = realm.config - # Select AI Policy - danger = combat.danger(config, pos) - if danger >= config.NPC_SPAWN_AGGRESSIVE: - ent = Aggressive(realm, pos, iden) - elif danger >= config.NPC_SPAWN_NEUTRAL: - ent = PassiveAggressive(realm, pos, iden) - elif danger >= config.NPC_SPAWN_PASSIVE: - ent = Passive(realm, pos, iden) - else: - return + # Select AI Policy + danger = combat.danger(config, pos) + if danger >= config.NPC_SPAWN_AGGRESSIVE: + ent = Aggressive(realm, pos, iden) + elif danger >= config.NPC_SPAWN_NEUTRAL: + ent = PassiveAggressive(realm, pos, iden) + elif danger >= config.NPC_SPAWN_PASSIVE: + ent = Passive(realm, pos, iden) + else: + return None - ent.spawn_danger = danger + ent.spawn_danger = danger - # Select combat focus - style = random.choice((Action.Melee, Action.Range, Action.Mage)) - ent.skills.style = style + # Select combat focus + style = random.choice((Action.Melee, Action.Range, Action.Mage)) + ent.skills.style = style - # Compute level - level = 0 - if config.PROGRESSION_SYSTEM_ENABLED: - level_min = config.NPC_LEVEL_MIN - level_max = config.NPC_LEVEL_MAX - level = int(danger * (level_max - level_min) + level_min) + # Compute 
level + level = 0 + if config.PROGRESSION_SYSTEM_ENABLED: + level_min = config.NPC_LEVEL_MIN + level_max = config.NPC_LEVEL_MAX + level = int(danger * (level_max - level_min) + level_min) - # Set skill levels - if style == Action.Melee: - ent.skills.melee.setExpByLevel(level) - elif style == Action.Range: - ent.skills.range.setExpByLevel(level) - elif style == Action.Mage: - ent.skills.mage.setExpByLevel(level) + # Set skill levels + if style == Action.Melee: + ent.skills.melee.set_experience_by_level(level) + elif style == Action.Range: + ent.skills.range.set_experience_by_level(level) + elif style == Action.Mage: + ent.skills.mage.set_experience_by_level(level) - # Gold - if config.EXCHANGE_SYSTEM_ENABLED: - ent.inventory.gold.quantity.update(level) + # Gold + if config.EXCHANGE_SYSTEM_ENABLED: + # pylint: disable=no-member + ent.gold.update(level) - ent.droptable = droptable.Standard() + ent.droptable = droptable.Standard() - # Equipment to instantiate - if config.EQUIPMENT_SYSTEM_ENABLED: - lvl = level - random.random() - ilvl = int(5 * lvl) + # Equipment to instantiate + if config.EQUIPMENT_SYSTEM_ENABLED: + lvl = level - random.random() + ilvl = int(5 * lvl) - offense = int(config.NPC_BASE_DAMAGE + lvl*config.NPC_LEVEL_DAMAGE) - defense = int(config.NPC_BASE_DEFENSE + lvl*config.NPC_LEVEL_DEFENSE) + offense = int(config.NPC_BASE_DAMAGE + lvl*config.NPC_LEVEL_DAMAGE) + defense = int(config.NPC_BASE_DEFENSE + lvl*config.NPC_LEVEL_DEFENSE) - ent.equipment = Equipment(ilvl, offense, offense, offense, defense, defense, defense) + ent.equipment = Equipment(ilvl, offense, offense, offense, defense, defense, defense) - armor = [Item.Hat, Item.Top, Item.Bottom] - ent.droptable.add(random.choice(armor)) + armor = [Item.Hat, Item.Top, Item.Bottom] + ent.droptable.add(random.choice(armor)) - if config.PROFESSION_SYSTEM_ENABLED: - tools = [Item.Rod, Item.Gloves, Item.Pickaxe, Item.Chisel, Item.Arcane] - ent.droptable.add(random.choice(tools)) + if 
config.PROFESSION_SYSTEM_ENABLED: + tools = [Item.Rod, Item.Gloves, Item.Pickaxe, Item.Chisel, Item.Arcane] + ent.droptable.add(random.choice(tools)) - return ent + return ent - def packet(self): - data = super().packet() + def packet(self): + data = super().packet() - data['base'] = self.base.packet() - data['skills'] = self.skills.packet() - data['resource'] = {'health': self.resources.health.packet()} + data['skills'] = self.skills.packet() + data['resource'] = { 'health': { + 'val': self.resources.health.val, 'max': self.config.PLAYER_BASE_HEALTH } } - return data + return data - @property - def isNPC(self) -> bool: - return True + @property + def is_npc(self) -> bool: + return True class Passive(NPC): - def __init__(self, realm, pos, iden): - super().__init__(realm, pos, iden, 'Passive', Neon.GREEN, -1) - self.dataframe.init(nmmo.Serialized.Entity, iden, pos) + def __init__(self, realm, pos, iden): + super().__init__(realm, pos, iden, 'Passive', 1) - def decide(self, realm): - return ai.policy.passive(realm, self) + def decide(self, realm): + return policy.passive(realm, self) class PassiveAggressive(NPC): - def __init__(self, realm, pos, iden): - super().__init__(realm, pos, iden, 'Neutral', Neon.ORANGE, -2) - self.dataframe.init(nmmo.Serialized.Entity, iden, pos) + def __init__(self, realm, pos, iden): + super().__init__(realm, pos, iden, 'Neutral', 2) - def decide(self, realm): - return ai.policy.neutral(realm, self) + def decide(self, realm): + return policy.neutral(realm, self) class Aggressive(NPC): - def __init__(self, realm, pos, iden): - super().__init__(realm, pos, iden, 'Hostile', Neon.RED, -3) - self.dataframe.init(nmmo.Serialized.Entity, iden, pos) + def __init__(self, realm, pos, iden): + super().__init__(realm, pos, iden, 'Hostile', 3) - def decide(self, realm): - return ai.policy.hostile(realm, self) + def decide(self, realm): + return policy.hostile(realm, self) diff --git a/nmmo/entity/player.py b/nmmo/entity/player.py index 
fba68049a..cbb900295 100644 --- a/nmmo/entity/player.py +++ b/nmmo/entity/player.py @@ -1,137 +1,136 @@ -import numpy as np -from pdb import set_trace as T - -import nmmo -from nmmo.systems import ai, equipment, inventory -from nmmo.lib import material - from nmmo.systems.skill import Skills from nmmo.systems.achievement import Diary -from nmmo.systems import combat from nmmo.entity import entity +# pylint: disable=no-member class Player(entity.Entity): - def __init__(self, realm, pos, agent, color, pop): - super().__init__(realm, pos, agent.iden, agent.policy, color, pop) - self.agent = agent - self.pop = pop - - # Scripted hooks - self.target = None - self.food = None - self.water = None - self.vision = 7 - - # Logs - self.buys = 0 - self.sells = 0 - self.ration_consumed = 0 - self.poultice_consumed = 0 - self.ration_level_consumed = 0 - self.poultice_level_consumed = 0 - - # Submodules - self.skills = Skills(realm, self) - - self.diary = None - tasks = realm.config.TASKS - if tasks: - self.diary = Diary(tasks) - - self.dataframe.init(nmmo.Serialized.Entity, self.entID, self.pos) - - @property - def serial(self): - return self.population, self.entID - - @property - def isPlayer(self) -> bool: + def __init__(self, realm, pos, agent): + super().__init__(realm, pos, agent.iden, agent.policy) + + self.agent = agent + self.immortal = realm.config.IMMORTAL + + # Scripted hooks + self.target = None + self.vision = 7 + + # Logs + self.buys = 0 + self.sells = 0 + self.ration_consumed = 0 + self.poultice_consumed = 0 + self.ration_level_consumed = 0 + self.poultice_level_consumed = 0 + + # Submodules + self.skills = Skills(realm, self) + + # Gold: initialize with 1 gold, like the old nmmo + # CHECK ME: should the initial amount be in the config? 
+ if realm.config.EXCHANGE_SYSTEM_ENABLED: + self.gold.update(1) + + self.diary = None + tasks = realm.config.TASKS + if tasks: + self.diary = Diary(self, tasks) + + @property + def serial(self): + return self.ent_id + + @property + def is_player(self) -> bool: + return True + + @property + def level(self) -> int: + # a player's level is the max of all skills + # CHECK ME: the initial level is 1 because of Basic skills, + # which are harvesting food/water and don't progress + return max(e.level.val for e in self.skills.skills) + + def apply_damage(self, dmg, style): + super().apply_damage(dmg, style) + self.skills.apply_damage(style) + + # TODO(daveey): The returns for this function are a mess + def receive_damage(self, source, dmg): + if self.immortal: + return False + + # super().receive_damage returns True if self is alive after taking dmg + if super().receive_damage(source, dmg): return True - @property - def population(self): - if __debug__: - assert self.base.population.val == self.pop - return self.pop - - @property - def level(self) -> int: - return combat.level(self.skills) - - def applyDamage(self, dmg, style): - super().applyDamage(dmg, style) - self.skills.applyDamage(dmg, style) - - def receiveDamage(self, source, dmg): - if super().receiveDamage(source, dmg): - return True - - if not self.config.ITEM_SYSTEM_ENABLED: - return False - - for item in list(self.inventory._item_references): - if not item.quantity.val: - continue - - self.inventory.remove(item) - source.inventory.receive(item) - - if not super().receiveDamage(source, dmg): - if source: - source.history.playerKills += 1 - return - - self.skills.receiveDamage(dmg) - - @property - def equipment(self): - return self.inventory.equipment - - def packet(self): - data = super().packet() - - data['entID'] = self.entID - data['annID'] = self.population - - data['base'] = self.base.packet() - data['resource'] = self.resources.packet() - data['skills'] = self.skills.packet() - data['inventory'] = 
self.inventory.packet() - - return data - - def update(self, realm, actions): - '''Post-action update. Do not include history''' - super().update(realm, actions) - - # Spawsn battle royale style death fog - # Starts at 0 damage on the specified config tick - # Moves in from the edges by 1 damage per tile per tick - # So after 10 ticks, you take 10 damage at the edge and 1 damage - # 10 tiles in, 0 damage in farther - # This means all agents will be force killed around - # MAP_CENTER / 2 + 100 ticks after spawning - fog = self.config.PLAYER_DEATH_FOG - if fog is not None and self.realm.tick >= fog: - r, c = self.pos - cent = self.config.MAP_BORDER + self.config.MAP_CENTER // 2 - - # Distance from center of the map - dist = max(abs(r - cent), abs(c - cent)) - - # Safe final area - if dist > self.config.PLAYER_DEATH_FOG_FINAL_SIZE: - # Damage based on time and distance from center - time_dmg = self.config.PLAYER_DEATH_FOG_SPEED * (self.realm.tick - fog + 1) - dist_dmg = dist - self.config.MAP_CENTER // 2 - dmg = max(0, dist_dmg + time_dmg) - self.receiveDamage(None, dmg) - - if not self.alive: - return - - self.resources.update(realm, self, actions) - self.skills.update(realm, self) - - if self.diary: - self.diary.update(realm, self) + if not self.config.ITEM_SYSTEM_ENABLED: + return False + + # starting from here, source receive gold & inventory items + if self.config.EXCHANGE_SYSTEM_ENABLED: + source.gold.increment(self.gold.val) + self.gold.update(0) + + # TODO(kywch): make source receive the highest-level items first + # because source cannot take it if the inventory is full + # Also, destroy the remaining items if the source cannot take those + for item in list(self.inventory.items): + self.inventory.remove(item) + + # if source doesn't have space, inventory.receive() destroys the item + source.inventory.receive(item) + + # CHECK ME: this is an empty function. do we still need this? 
+ self.skills.receive_damage(dmg) + return False + + @property + def equipment(self): + return self.inventory.equipment + + def packet(self): + data = super().packet() + + data['entID'] = self.ent_id + + data['resource'] = self.resources.packet() + data['skills'] = self.skills.packet() + data['inventory'] = self.inventory.packet() + + return data + + def update(self, realm, actions): + '''Post-action update. Do not include history''' + super().update(realm, actions) + + # Spawsn battle royale style death fog + # Starts at 0 damage on the specified config tick + # Moves in from the edges by 1 damage per tile per tick + # So after 10 ticks, you take 10 damage at the edge and 1 damage + # 10 tiles in, 0 damage in farther + # This means all agents will be force killed around + # MAP_CENTER / 2 + 100 ticks after spawning + fog = self.config.PLAYER_DEATH_FOG + if fog is not None and self.realm.tick >= fog: + row, col = self.pos + cent = self.config.MAP_BORDER + self.config.MAP_CENTER // 2 + + # Distance from center of the map + dist = max(abs(row - cent), abs(col - cent)) + + # Safe final area + if dist > self.config.PLAYER_DEATH_FOG_FINAL_SIZE: + # Damage based on time and distance from center + time_dmg = self.config.PLAYER_DEATH_FOG_SPEED * (self.realm.tick - fog + 1) + dist_dmg = dist - self.config.MAP_CENTER // 2 + dmg = max(0, dist_dmg + time_dmg) + self.receive_damage(None, dmg) + + if not self.alive: + return + + self.resources.update() + self.skills.update() + + if self.diary: + self.diary.update(realm) diff --git a/nmmo/infrastructure.py b/nmmo/infrastructure.py deleted file mode 100644 index 4de687e63..000000000 --- a/nmmo/infrastructure.py +++ /dev/null @@ -1,279 +0,0 @@ -'''Infrastructure layer for representing agent observations - -Maintains a synchronized + serialized representation of agent observations in -flat tensors. This allows for fast observation processing as a set of tensor -slices instead of a lengthy traversal over hundreds of game properties. 
- -Synchronization bugs are notoriously difficult to track down: make sure -to follow the correct instantiation protocol, e.g. as used for defining -agent/tile observations, when adding new types observations to the code''' - -from pdb import set_trace as T -import numpy as np - -from collections import defaultdict - -import nmmo - -class DataType: - CONTINUOUS = np.float32 - DISCRETE = np.int32 - -class Index: - '''Lookup index of attribute names''' - def __init__(self, prealloc): - self.free = {idx for idx in range(1, prealloc)} - self.index = {} - self.back = {} - - def full(self): - return len(self.free) == 0 - - def remove(self, key): - row = self.index[key] - del self.index[key] - del self.back[row] - - self.free.add(row) - return row - - def update(self, key): - if key in self.index: - row = self.index[key] - else: - row = self.free.pop() - self.index[key] = row - self.back[row] = key - - return row - - def get(self, key): - return self.index[key] - - def teg(self, row): - return self.back[row] - - def expand(self, cur, nxt): - self.free.update({idx for idx in range(cur, nxt)}) - -class ContinuousTable: - '''Flat tensor representation for a set of continuous attributes''' - def __init__(self, config, obj, prealloc, dtype=DataType.CONTINUOUS): - self.config = config - self.dtype = dtype - self.cols = {} - self.nCols = 0 - - for (attribute,), attr in obj: - self.initAttr(attribute, attr) - - self.data = self.initData(prealloc, self.nCols) - - def initAttr(self, key, attr): - if attr.CONTINUOUS: - self.cols[key] = self.nCols - self.nCols += 1 - - def initData(self, nRows, nCols): - return np.zeros((nRows, nCols), dtype=self.dtype) - - def update(self, row, attr, val): - col = self.cols[attr] - self.data[row, col] = val - - def expand(self, cur, nxt): - data = self.initData(nxt, self.nCols) - data[:cur] = self.data - - self.data = data - self.nRows = nxt - - def get(self, rows, pad=None): - data = self.data[rows] - data[rows==0] = 0 - - if pad is not None: - 
data = np.pad(data, ((0, pad-len(data)), (0, 0))) - - return data - -class DiscreteTable(ContinuousTable): - '''Flat tensor representation for a set of discrete attributes''' - def __init__(self, config, obj, prealloc, dtype=DataType.DISCRETE): - self.discrete, self.cumsum = {}, 0 - super().__init__(config, obj, prealloc, dtype) - - def initAttr(self, key, attr): - if not attr.DISCRETE: - return - - self.cols[key] = self.nCols - - #Flat index - attr = attr(None, None, 0, config=self.config) - self.discrete[key] = self.cumsum - - self.cumsum += attr.max - attr.min + 1 - self.nCols += 1 - - def update(self, row, attr, val): - col = self.cols[attr] - self.data[row, col] = val + self.discrete[attr] - -class Grid: - '''Flat representation of tile/agent positions''' - def __init__(self, R, C): - self.data = np.zeros((R, C), dtype=np.int32) - - def zero(self, pos): - r, c = pos - self.data[r, c] = 0 - - def set(self, pos, val): - r, c = pos - self.data[r, c] = val - - def move(self, pos, nxt, row): - self.zero(pos) - self.set(nxt, row) - - def window(self, rStart, rEnd, cStart, cEnd): - crop = self.data[rStart:rEnd, cStart:cEnd].ravel() - return list(filter(lambda x: x != 0, crop)) - -class GridTables: - '''Combines a Grid + Index + Continuous and Discrete tables - - Together, these data structures provide a robust and efficient - flat tensor representation of an entire class of observations, - such as agents or tiles''' - def __init__(self, config, obj, pad, prealloc=1000, expansion=2): - self.grid = Grid(config.MAP_SIZE, config.MAP_SIZE) - self.continuous = ContinuousTable(config, obj, prealloc) - self.discrete = DiscreteTable(config, obj, prealloc) - self.index = Index(prealloc) - - self.nRows = prealloc - self.expansion = expansion - self.radius = config.PLAYER_VISION_RADIUS - self.pad = pad - - def get(self, ent, radius=None, entity=False): - if radius is None: - radius = self.radius - - r, c = ent.pos - cent = self.grid.data[r, c] - - if __debug__: - assert cent != 
0 - - rows = self.grid.window( - r-radius, r+radius+1, - c-radius, c+radius+1) - - #Self entity first - if entity: - rows.remove(cent) - rows.insert(0, cent) - - values = {'Continuous': self.continuous.get(rows, self.pad), - 'Discrete': self.discrete.get(rows, self.pad)} - - if entity: - ents = [self.index.teg(e) for e in rows] - if __debug__: - assert ents[0] == ent.entID - return values, ents - - return values - - def getFlat(self, keys): - if __debug__: - err = f'Dataframe got {len(keys)} keys with pad {self.pad}' - assert len(keys) <= self.pad, err - - rows = [self.index.get(key) for key in keys[:self.pad]] - values = {'Continuous': self.continuous.get(rows, self.pad), - 'Discrete': self.discrete.get(rows, self.pad)} - return values - - def update(self, obj, val): - key, attr = obj.key, obj.attr - if self.index.full(): - cur = self.nRows - self.nRows = cur * self.expansion - - self.index.expand(cur, self.nRows) - self.continuous.expand(cur, self.nRows) - self.discrete.expand(cur, self.nRows) - - row = self.index.update(key) - if obj.DISCRETE: - self.discrete.update(row, attr, val - obj.min) - if obj.CONTINUOUS: - self.continuous.update(row, attr, val) - - def move(self, key, pos, nxt): - row = self.index.get(key) - self.grid.move(pos, nxt, row) - - def init(self, key, pos): - if pos is None: - return - - row = self.index.get(key) - self.grid.set(pos, row) - - def remove(self, key, pos): - self.index.remove(key) - self.grid.zero(pos) - -class Dataframe: - '''Infrastructure wrapper class''' - def __init__(self, realm): - config = realm.config - self.config = config - self.data = defaultdict(dict) - - for (objKey,), obj in nmmo.Serialized: - if not obj.enabled(config): - continue - self.data[objKey] = GridTables(config, obj, pad=obj.N(config)) - - self.realm = realm - - def update(self, node, val): - self.data[node.obj].update(node, val) - - def remove(self, obj, key, pos): - self.data[obj.__name__].remove(key, pos) - - def init(self, obj, key, pos): - 
self.data[obj.__name__].init(key, pos) - - def move(self, obj, key, pos, nxt): - self.data[obj.__name__].move(key, pos, nxt) - - def get(self, ent): - stim = {} - - stim['Entity'], ents = self.data['Entity'].get(ent, entity=True) - stim['Entity']['N'] = np.array([len(ents)], dtype=np.int32) - - ent.targets = ents - stim['Tile'] = self.data['Tile'].get(ent) - stim['Tile']['N'] = np.array([self.config.PLAYER_VISION_DIAMETER], dtype=np.int32) - - #Current must have the same pad - if self.config.ITEM_SYSTEM_ENABLED: - items = ent.inventory.dataframeKeys - stim['Item'] = self.data['Item'].getFlat(items) - stim['Item']['N'] = np.array([len(items)], dtype=np.int32) - - if self.config.EXCHANGE_SYSTEM_ENABLED: - market = self.realm.exchange.dataframeKeys - stim['Market'] = self.data['Item'].getFlat(market) - stim['Market']['N'] = np.array([len(market)], dtype=np.int32) - - return stim diff --git a/nmmo/integrations.py b/nmmo/integrations.py deleted file mode 100644 index c4fa228a3..000000000 --- a/nmmo/integrations.py +++ /dev/null @@ -1,160 +0,0 @@ -from pdb import set_trace as T - -import nmmo -from nmmo import Env - -def rllib_env_cls(): - try: - from ray import rllib - except ImportError: - raise ImportError('Integrations depend on rllib. 
Install ray[rllib] and then retry') - class RLlibEnv(Env, rllib.MultiAgentEnv): - def __init__(self, config): - self.config = config['config'] - self.config.EMULATE_CONST_HORIZON = True - super().__init__(self.config) - - def render(self): - #Patrch of RLlib dupe rendering bug - if not self.config.RENDER: - return - - super().render() - - def step(self, actions): - obs, rewards, dones, infos = super().step(actions) - - population = len(self.realm.players) == 0 - hit_horizon = self.realm.tick >= self.config.EMULATE_CONST_HORIZON - - dones['__all__'] = False - if not self.config.RENDER and (hit_horizon or population): - dones['__all__'] = True - - return obs, rewards, dones, infos - - return RLlibEnv - -class SB3Env(Env): - def __init__(self, config, seed=None): - config.EMULATE_FLAT_OBS = True - config.EMULATE_FLAT_ATN = True - config.EMULATE_CONST_PLAYER_N = True - config.EMULATE_CONST_HORIZON = True - - super().__init__(config, seed=seed) - - def step(self, actions): - assert type(actions) == dict - - obs, rewards, dones, infos = super().step(actions) - - if self.realm.tick >= self.config.HORIZON or len(self.realm.players) == 0: - # Cheat logs into infos - stats = self.terminal() - stats = {**stats['Env'], **stats['Player'], **stats['Milestone'], **stats['Event']} - - infos[1]['logs'] = stats - - return obs, rewards, dones, infos - -class CleanRLEnv(SB3Env): - def __init__(self, config, seed=None): - super().__init__(config, seed=seed) - -def sb3_vec_envs(config_cls, num_envs, num_cpus): - try: - import supersuit as ss - except ImportError: - raise ImportError('SB3 integration depend on supersuit. 
Install and then retry') - - config = config_cls() - env = SB3Env(config) - - env = ss.pettingzoo_env_to_vec_env_v1(env) - env.black_death = True #We provide our own black_death emulation - env = ss.concat_vec_envs_v1(env, num_envs, num_cpus, - base_class='stable_baselines3') - - return env - -def cleanrl_vec_envs(config_classes, verbose=True): - '''Creates a vector environment object from a list of configs. - - Each subenv points to a single agent, but many agents can share the same env. - All envs must have the same observation and action space, but they can have - different numbers of agents''' - - try: - import supersuit as ss - except ImportError: - raise ImportError('CleanRL integration depend on supersuit. Install and then retry') - - def make_env_fn(config_cls): - '''Wraps the make_env fn to add a a config argument''' - def make_env(): - config = config_cls() - env = CleanRLEnv(config) - - env = ss.pettingzoo_env_to_vec_env_v1(env) - env.black_death = True #We provide our own black_death emulation - - env = ss.concat_vec_envs_v1(env, - config.NUM_ENVS // config.PLAYER_N, - config.NUM_CPUS, - base_class='gym') - - env.single_observation_space = env.observation_space - env.single_action_space = env.action_space - env.is_vector_env = True - - return env - return make_env - - dummy_env = None - all_envs = [] - - num_cpus = 0 - num_envs = 0 - num_agents = 0 - - if type(config_classes) != list: - config_classes = [config_classes] - - for idx, cls in enumerate(config_classes): - assert isinstance(cls, type), 'config_cls must be a type (did ytou pass an instance?)' - assert hasattr(cls, 'NUM_ENVS'), f'config class {cls} must define NUM_ENVS' - assert hasattr(cls, 'NUM_CPUS'), f'config class {cls} must define NUM_CPUS' - assert isinstance(cls, type), f'config class {cls} must be a type (did you pass an instance?)' - - if dummy_env is None: - config = cls() - dummy_env = CleanRLEnv(config) - - #neural = [e == nmmo.Agent for e in cls.PLAYERS] - #n_neural = sum(neural) 
/ len(neural) * config.NUM_ENVS - #assert int(n_neural) == n_neural, f'{sum(neural)} neural agents and {cls.PLAYER_N} classes' - #n_neural = int(n_neural) - - envs = make_env_fn(cls)#, n_neural) - all_envs.append(envs) - - # TODO: Find a cleaner way to specify env scale that enables multiple envs per CPU - # without having to pass multiple configs - num_cpus += cls.NUM_CPUS - num_envs += cls.NUM_CPUS - num_agents += cls.NUM_CPUS * cls.PLAYER_N - - - - envs = ss.vector.ProcConcatVec(all_envs, - dummy_env.observation_space(1), - dummy_env.action_space(1), - num_agents, - dummy_env.metadata) - envs.is_vector_env = True - - if verbose: - print(f'nmmo.integrations.cleanrl_vec_envs created {num_envs} envs across {num_cpus} cores') - - return envs diff --git a/nmmo/io/action.py b/nmmo/io/action.py index 1ddfbde6b..434c0e5c4 100644 --- a/nmmo/io/action.py +++ b/nmmo/io/action.py @@ -1,440 +1,711 @@ -from pdb import set_trace as T -from ordered_set import OrderedSet -import numpy as np +# CHECK ME: Should these be fixed as well? 
+# pylint: disable=no-method-argument,unused-argument,no-self-argument,no-member from enum import Enum, auto +from ordered_set import OrderedSet + +import numpy as np -import nmmo from nmmo.lib import utils from nmmo.lib.utils import staticproperty +from nmmo.systems.item import Item, Stack +from nmmo.lib.log import EventCode + class NodeType(Enum): - #Tree edges - STATIC = auto() #Traverses all edges without decisions - SELECTION = auto() #Picks an edge to follow + #Tree edges + STATIC = auto() #Traverses all edges without decisions + SELECTION = auto() #Picks an edge to follow - #Executable actions - ACTION = auto() #No arguments - CONSTANT = auto() #Constant argument - VARIABLE = auto() #Variable argument + #Executable actions + ACTION = auto() #No arguments + CONSTANT = auto() #Constant argument + VARIABLE = auto() #Variable argument class Node(metaclass=utils.IterableNameComparable): - @classmethod - def init(cls, config): - pass + @classmethod + def init(cls, config): + pass - @staticproperty - def edges(): - return [] + @staticproperty + def edges(): + return [] - #Fill these in - @staticproperty - def priority(): - return None + #Fill these in + @staticproperty + def priority(): + return None - @staticproperty - def type(): - return None + @staticproperty + def type(): + return None - @staticproperty - def leaf(): - return False + @staticproperty + def leaf(): + return False - @classmethod - def N(cls, config): - return len(cls.edges) + @classmethod + def N(cls, config): + return len(cls.edges) - def deserialize(realm, entity, index): - return index + def deserialize(realm, entity, index): + return index - def args(stim, entity, config): - return [] + def args(stim, entity, config): + return [] class Fixed: - pass + pass #ActionRoot class Action(Node): - nodeType = NodeType.SELECTION - hooked = False - - @classmethod - def init(cls, config): - # Sets up serialization domain - if Action.hooked: - return + nodeType = NodeType.SELECTION + hooked = False + + 
@classmethod + def init(cls, config): + # Sets up serialization domain + if Action.hooked: + return + + Action.hooked = True + + #Called upon module import (see bottom of file) + #Sets up serialization domain + def hook(config): + idx = 0 + arguments = [] + for action in Action.edges(config): + action.init(config) + for args in action.edges: + args.init(config) + if not 'edges' in args.__dict__: + continue + for arg in args.edges: + arguments.append(arg) + arg.serial = tuple([idx]) + arg.idx = idx + idx += 1 + Action.arguments = arguments + + @staticproperty + def n(): + return len(Action.arguments) + + # pylint: disable=invalid-overridden-method + @classmethod + def edges(cls, config): + '''List of valid actions''' + edges = [Move] + if config.COMBAT_SYSTEM_ENABLED: + edges.append(Attack) + if config.ITEM_SYSTEM_ENABLED: + edges += [Use, Give, Destroy] + if config.EXCHANGE_SYSTEM_ENABLED: + edges += [Buy, Sell, GiveGold] + if config.COMMUNICATION_SYSTEM_ENABLED: + edges.append(Comm) + return edges + + def args(stim, entity, config): + raise NotImplementedError - Action.hooked = True - - #Called upon module import (see bottom of file) - #Sets up serialization domain - def hook(config): - idx = 0 - arguments = [] - for action in Action.edges(config): - action.init(config) - for args in action.edges: - args.init(config) - if not 'edges' in args.__dict__: - continue - for arg in args.edges: - arguments.append(arg) - arg.serial = tuple([idx]) - arg.idx = idx - idx += 1 - Action.arguments = arguments - - @staticproperty - def n(): - return len(Action.arguments) - - @classmethod - def edges(cls, config): - '''List of valid actions''' - edges = [Move] - if config.COMBAT_SYSTEM_ENABLED: - edges.append(Attack) - if config.ITEM_SYSTEM_ENABLED: - edges += [Use] - if config.EXCHANGE_SYSTEM_ENABLED: - edges += [Buy, Sell] - if config.COMMUNICATION_SYSTEM_ENABLED: - edges.append(Comm) - return edges - - def args(stim, entity, config): - return nmmo.Serialized.edges class 
Move(Node): - priority = 1 - nodeType = NodeType.SELECTION - def call(env, entity, direction): - r, c = entity.pos - entID = entity.entID - entity.history.lastPos = (r, c) - rDelta, cDelta = direction.delta - rNew, cNew = r+rDelta, c+cDelta - - #One agent per cell - tile = env.map.tiles[rNew, cNew] - if tile.occupied and not tile.lava: - return - - if entity.status.freeze > 0: - return - - env.dataframe.move(nmmo.Serialized.Entity, entID, (r, c), (rNew, cNew)) - entity.base.r.update(rNew) - entity.base.c.update(cNew) - - env.map.tiles[r, c].delEnt(entID) - env.map.tiles[rNew, cNew].addEnt(entity) - - if env.map.tiles[rNew, cNew].lava: - entity.receiveDamage(None, entity.resources.health.val) - - @staticproperty - def edges(): - return [Direction] - - @staticproperty - def leaf(): - return True + priority = 60 + nodeType = NodeType.SELECTION + def call(realm, entity, direction): + if direction is None: + return + + assert entity.alive, "Dead entity cannot act" + + r, c = entity.pos + ent_id = entity.ent_id + entity.history.last_pos = (r, c) + r_delta, c_delta = direction.delta + r_new, c_new = r+r_delta, c+c_delta + + # CHECK ME: lava-jumping agents in the tutorial no longer works + if realm.map.tiles[r_new, c_new].impassible: + return + + if entity.status.freeze > 0: + return + + entity.row.update(r_new) + entity.col.update(c_new) + + realm.map.tiles[r, c].remove_entity(ent_id) + realm.map.tiles[r_new, c_new].add_entity(entity) + + # exploration record keeping. 
moved from entity.py, History.update() + dist_from_spawn = utils.linf(entity.spawn_pos, (r_new, c_new)) + if dist_from_spawn > entity.history.exploration: + entity.history.exploration = dist_from_spawn + if entity.is_player: + realm.event_log.record(EventCode.GO_FARTHEST, entity, + distance=dist_from_spawn) + + # CHECK ME: material.Impassible includes lava, so this line is not reachable + if realm.map.tiles[r_new, c_new].lava: + entity.receive_damage(None, entity.resources.health.val) + + @staticproperty + def edges(): + return [Direction] + + @staticproperty + def leaf(): + return True + + def enabled(config): + return True class Direction(Node): - argType = Fixed + argType = Fixed - @staticproperty - def edges(): - return [North, South, East, West] + @staticproperty + def edges(): + return [North, South, East, West, Stay] - def args(stim, entity, config): - return Direction.edges + def args(stim, entity, config): + return Direction.edges + + def deserialize(realm, entity, index): + return deserialize_fixed_arg(Direction, index) + +# a quick helper function +def deserialize_fixed_arg(arg, index): + if isinstance(index, (int, np.int64)): + if index < 0: + return None # so that the action will be discarded + val = min(index-1, len(arg.edges)-1) + return arg.edges[val] + + # if index is not int, it's probably already deserialized + if index not in arg.edges: + return None # so that the action will be discarded + return index class North(Node): - delta = (-1, 0) + delta = (-1, 0) class South(Node): - delta = (1, 0) + delta = (1, 0) class East(Node): - delta = (0, 1) + delta = (0, 1) class West(Node): - delta = (0, -1) + delta = (0, -1) +class Stay(Node): + delta = (0, 0) class Attack(Node): - priority = 0 - nodeType = NodeType.SELECTION - @staticproperty - def n(): - return 3 - - @staticproperty - def edges(): - return [Style, Target] - - @staticproperty - def leaf(): - return True - - def inRange(entity, stim, config, N): - R, C = stim.shape - R, C = R//2, C//2 - - 
rets = OrderedSet([entity]) - for r in range(R-N, R+N+1): - for c in range(C-N, C+N+1): - for e in stim[r, c].ents.values(): - rets.add(e) - - rets = list(rets) - return rets - - def l1(pos, cent): - r, c = pos - rCent, cCent = cent - return abs(r - rCent) + abs(c - cCent) - - def call(env, entity, style, targ): - config = env.config - - if entity.isPlayer and not config.COMBAT_SYSTEM_ENABLED: - return - - # Testing a spawn immunity against old agents to avoid spawn camping - immunity = config.COMBAT_SPAWN_IMMUNITY - if entity.isPlayer and targ.isPlayer and entity.history.timeAlive.val > immunity and targ.history.timeAlive < immunity: - return - - #Check if self targeted - if entity.entID == targ.entID: - return - - #ADDED: POPULATION IMMUNITY - if not config.COMBAT_FRIENDLY_FIRE and entity.isPlayer and entity.base.population.val == targ.base.population.val: - return - - #Check attack range - rng = style.attackRange(config) - start = np.array(entity.base.pos) - end = np.array(targ.base.pos) - dif = np.max(np.abs(start - end)) - - #Can't attack same cell or out of range - if dif == 0 or dif > rng: - return - - #Execute attack - entity.history.attack = {} - entity.history.attack['target'] = targ.entID - entity.history.attack['style'] = style.__name__ - targ.attacker = entity - targ.attackerID.update(entity.entID) - - from nmmo.systems import combat - dmg = combat.attack(env, entity, targ, style.skill) - - if style.freeze and dmg > 0: - targ.status.freeze.update(config.COMBAT_FREEZE_TIME) - - return dmg + priority = 50 + nodeType = NodeType.SELECTION + @staticproperty + def n(): + return 3 + + @staticproperty + def edges(): + return [Style, Target] + + @staticproperty + def leaf(): + return True + + def enabled(config): + return config.COMBAT_SYSTEM_ENABLED + + def in_range(entity, stim, config, N): + R, C = stim.shape + R, C = R//2, C//2 + + rets = OrderedSet([entity]) + for r in range(R-N, R+N+1): + for c in range(C-N, C+N+1): + for e in stim[r, 
c].entities.values(): + rets.add(e) + + rets = list(rets) + return rets + + # CHECK ME: do we need l1 distance function? + # systems/ai/utils.py also has various distance functions + # which we may want to clean up + # def l1(pos, cent): + # r, c = pos + # r_cent, c_cent = cent + # return abs(r - r_cent) + abs(c - c_cent) + + def call(realm, entity, style, target): + if style is None or target is None: + return None + + assert entity.alive, "Dead entity cannot act" + + config = realm.config + if entity.is_player and not config.COMBAT_SYSTEM_ENABLED: + return None + + # Testing a spawn immunity against old agents to avoid spawn camping + immunity = config.COMBAT_SPAWN_IMMUNITY + if entity.is_player and target.is_player and \ + target.history.time_alive < immunity < entity.history.time_alive.val: + return None + + #Check if self targeted + if entity.ent_id == target.ent_id: + return None + + #Can't attack out of range + if utils.linf(entity.pos, target.pos) > style.attack_range(config): + return None + + #Execute attack + entity.history.attack = {} + entity.history.attack['target'] = target.ent_id + entity.history.attack['style'] = style.__name__ + target.attacker = entity + target.attacker_id.update(entity.ent_id) + + from nmmo.systems import combat + dmg = combat.attack(realm, entity, target, style.skill) + + if style.freeze and dmg > 0: + target.status.freeze.update(config.COMBAT_FREEZE_TIME) + + # record the combat tick for both entities + # players and npcs both have latest_combat_tick in EntityState + for ent in [entity, target]: + ent.latest_combat_tick.update(realm.tick + 1) # because the tick is about to increment + + return dmg class Style(Node): - argType = Fixed - @staticproperty - def edges(): - return [Melee, Range, Mage] + argType = Fixed + @staticproperty + def edges(): + return [Melee, Range, Mage] - def args(stim, entity, config): - return Style.edges + def args(stim, entity, config): + return Style.edges + + def deserialize(realm, entity, index): + 
return deserialize_fixed_arg(Style, index) class Target(Node): - argType = None + argType = None - @classmethod - def N(cls, config): - #return config.WINDOW ** 2 - return config.PLAYER_N_OBS + @classmethod + def N(cls, config): + return config.PLAYER_N_OBS - def deserialize(realm, entity, index): - return realm.entity(index) + def deserialize(realm, entity, index: int): + # NOTE: index is the entity id + # CHECK ME: should index be renamed to ent_id? + return realm.entity_or_none(index) - def args(stim, entity, config): - #Should pass max range? - return Attack.inRange(entity, stim, config, None) + def args(stim, entity, config): + #Should pass max range? + return Attack.in_range(entity, stim, config, None) class Melee(Node): - nodeType = NodeType.ACTION - freeze=False + nodeType = NodeType.ACTION + freeze=False - def attackRange(config): - return config.COMBAT_MELEE_REACH + def attack_range(config): + return config.COMBAT_MELEE_REACH - def skill(entity): - return entity.skills.melee + def skill(entity): + return entity.skills.melee class Range(Node): - nodeType = NodeType.ACTION - freeze=False + nodeType = NodeType.ACTION + freeze=False - def attackRange(config): - return config.COMBAT_RANGE_REACH + def attack_range(config): + return config.COMBAT_RANGE_REACH - def skill(entity): - return entity.skills.range + def skill(entity): + return entity.skills.range class Mage(Node): - nodeType = NodeType.ACTION - freeze=False + nodeType = NodeType.ACTION + freeze=False + + def attack_range(config): + return config.COMBAT_MAGE_REACH + + def skill(entity): + return entity.skills.mage - def attackRange(config): - return config.COMBAT_MAGE_REACH - def skill(entity): - return entity.skills.mage +class InventoryItem(Node): + argType = None + + @classmethod + def N(cls, config): + return config.INVENTORY_N_OBS + + # TODO(kywch): What does args do? 
+ def args(stim, entity, config): + return stim.exchange.items() + + def deserialize(realm, entity, index: int): + # NOTE: index is from the inventory, NOT item id + inventory = Item.Query.owned_by(realm.datastore, entity.id.val) + + if index >= inventory.shape[0]: + return None + + item_id = inventory[index, Item.State.attr_name_to_col["id"]] + return realm.items[item_id] class Use(Node): - priority = 3 + priority = 10 + + @staticproperty + def edges(): + return [InventoryItem] + + def enabled(config): + return config.ITEM_SYSTEM_ENABLED + + def call(realm, entity, item): + if item is None or item.owner_id.val != entity.ent_id: + return + + assert entity.alive, "Dead entity cannot act" + assert entity.is_player, "Npcs cannot use an item" + assert item.quantity.val > 0, "Item quantity cannot be 0" # indicates item leak + + if not realm.config.ITEM_SYSTEM_ENABLED: + return + + if item not in entity.inventory: + return + + if entity.in_combat: # player cannot use item during combat + return + + # cannot use listed items or items that have higher level + if item.listed_price.val > 0 or item.level_gt(entity): + return - @staticproperty - def edges(): - return [Item] + item.use(entity) - def call(env, entity, item): - if item not in entity.inventory: - return +class Destroy(Node): + priority = 40 - return item.use(entity) + @staticproperty + def edges(): + return [InventoryItem] + + def enabled(config): + return config.ITEM_SYSTEM_ENABLED + + def call(realm, entity, item): + if item is None or item.owner_id.val != entity.ent_id: + return + + assert entity.alive, "Dead entity cannot act" + assert entity.is_player, "Npcs cannot destroy an item" + assert item.quantity.val > 0, "Item quantity cannot be 0" # indicates item leak + + if not realm.config.ITEM_SYSTEM_ENABLED: + return + + if item not in entity.inventory: + return + + if item.equipped.val: # cannot destroy equipped item + return + + if entity.in_combat: # player cannot destroy item during combat + return + + 
item.destroy() + + realm.event_log.record(EventCode.DESTROY_ITEM, entity) class Give(Node): - priority = 2 + priority = 30 + + @staticproperty + def edges(): + return [InventoryItem, Target] + + def enabled(config): + return config.ITEM_SYSTEM_ENABLED + + def call(realm, entity, item, target): + if item is None or item.owner_id.val != entity.ent_id or target is None: + return + + assert entity.alive, "Dead entity cannot act" + assert entity.is_player, "Npcs cannot give an item" + assert item.quantity.val > 0, "Item quantity cannot be 0" # indicates item leak + + config = realm.config + if not config.ITEM_SYSTEM_ENABLED: + return + + if not (target.is_player and target.alive): + return + + if item not in entity.inventory: + return + + # cannot give the equipped or listed item + if item.equipped.val or item.listed_price.val: + return + + if entity.in_combat: # player cannot give item during combat + return + + if not (config.ITEM_ALLOW_GIFT and + entity.ent_id != target.ent_id and # but not self + target.is_player and + entity.pos == target.pos): # the same tile + return + + if not target.inventory.space: + # receiver inventory is full - see if it has an ammo stack with the same sig + if isinstance(item, Stack): + if not target.inventory.has_stack(item.signature): + # no ammo stack with the same signature, so cannot give + return + else: # no space, and item is not ammo stack, so cannot give + return + + entity.inventory.remove(item) + target.inventory.receive(item) + + realm.event_log.record(EventCode.GIVE_ITEM, entity) + + +class GiveGold(Node): + priority = 30 + + @staticproperty + def edges(): + # CHECK ME: for now using Price to indicate the gold amount to give + return [Price, Target] + + def enabled(config): + return config.EXCHANGE_SYSTEM_ENABLED + + def call(realm, entity, amount, target): + if amount is None or target is None: + return - @staticproperty - def edges(): - return [Item, Target] + assert entity.alive, "Dead entity cannot act" + assert 
entity.is_player, "Npcs cannot give gold" - def call(env, entity, item, target): - if item not in entity.inventory: - return + config = realm.config + if not config.EXCHANGE_SYSTEM_ENABLED: + return - if not target.isPlayer: - return + if not (target.is_player and target.alive): + return - if not target.inventory.space: - return + if entity.in_combat: # player cannot give gold during combat + return - entity.inventory.remove(item, quantity=1) - item = type(item)(env, item.level.val) - target.inventory.receive(item) + if not (config.ITEM_ALLOW_GIFT and + entity.ent_id != target.ent_id and # but not self + target.is_player and + entity.pos == target.pos): # the same tile + return - return True + if not isinstance(amount, int): + amount = amount.val + if not (amount > 0 and entity.gold.val > 0): # no gold to give + return -class Item(Node): - argType = 'Entity' + amount = min(amount, entity.gold.val) - @classmethod - def N(cls, config): - return config.ITEM_N_OBS + entity.gold.decrement(amount) + target.gold.increment(amount) - def args(stim, entity, config): - return stim.exchange.items() + realm.event_log.record(EventCode.GIVE_GOLD, entity) - def deserialize(realm, entity, index): - return realm.items[index] + +class MarketItem(Node): + argType = None + + @classmethod + def N(cls, config): + return config.MARKET_N_OBS + + # TODO(kywch): What does args do? 
+ def args(stim, entity, config): + return stim.exchange.items() + + def deserialize(realm, entity, index: int): + # NOTE: index is from the market, NOT item id + market = Item.Query.for_sale(realm.datastore) + + if index >= market.shape[0]: + return None + + item_id = market[index, Item.State.attr_name_to_col["id"]] + return realm.items[item_id] class Buy(Node): - priority = 4 - argType = Fixed + priority = 20 + argType = Fixed - @staticproperty - def edges(): - return [Item] + @staticproperty + def edges(): + return [MarketItem] - def call(env, entity, item): - #Do not process exchange actions on death tick - if not entity.alive: - return + def enabled(config): + return config.EXCHANGE_SYSTEM_ENABLED - if not entity.inventory.space: - return + def call(realm, entity, item): + if item is None or item.owner_id.val == 0: + return + + assert entity.alive, "Dead entity cannot act" + assert entity.is_player, "Npcs cannot buy an item" + assert item.quantity.val > 0, "Item quantity cannot be 0" # indicates item leak + assert item.equipped.val == 0, 'Listed item must not be equipped' + + if not realm.config.EXCHANGE_SYSTEM_ENABLED: + return + + if entity.gold.val < item.listed_price.val: # not enough money + return + + if entity.ent_id == item.owner_id.val: # cannot buy own item + return + + if entity.in_combat: # player cannot buy item during combat + return + + if not entity.inventory.space: + # buyer inventory is full - see if it has an ammo stack with the same sig + if isinstance(item, Stack): + if not entity.inventory.has_stack(item.signature): + # no ammo stack with the same signature, so cannot give + return + else: # no space, and item is not ammo stack, so cannot give + return - return env.exchange.buy(env, entity, item) + # one can try to buy, but the listing might have gone (perhaps bought by other) + realm.exchange.buy(entity, item) class Sell(Node): - priority = 4 - argType = Fixed + priority = 70 + argType = Fixed - @staticproperty - def edges(): - return 
[Item, Price] + @staticproperty + def edges(): + return [InventoryItem, Price] - def call(env, entity, item, price): - #Do not process exchange actions on death tick - if not entity.alive: - return + def enabled(config): + return config.EXCHANGE_SYSTEM_ENABLED - # TODO: Find a better way to check this - # Should only occur when item is used on same tick - # Otherwise should not be possible - if item not in entity.inventory: - return + def call(realm, entity, item, price): + if item is None or item.owner_id.val != entity.ent_id or price is None: + return - if type(price) != int: - price = price.val + assert entity.alive, "Dead entity cannot act" + assert entity.is_player, "Npcs cannot sell an item" + assert item.quantity.val > 0, "Item quantity cannot be 0" # indicates item leak - return env.exchange.sell(env, entity, item, price) + if not realm.config.EXCHANGE_SYSTEM_ENABLED: + return + + if item not in entity.inventory: + return + + if entity.in_combat: # player cannot sell item during combat + return + + # cannot sell the equipped or listed item + if item.equipped.val or item.listed_price.val: + return + + if not isinstance(price, int): + price = price.val + + if not price > 0: + return + + realm.exchange.sell(entity, item, price, realm.tick) def init_discrete(values): - classes = [] - for i in values: - name = f'Discrete_{i}' - cls = type(name, (object,), {'val': i}) - classes.append(cls) - return classes + classes = [] + for i in values: + name = f'Discrete_{i}' + cls = type(name, (object,), {'val': i}) + classes.append(cls) + + return classes class Price(Node): - argType = Fixed + argType = Fixed - @classmethod - def init(cls, config): - Price.classes = init_discrete(list(range(100))) + @classmethod + def init(cls, config): + # gold should be > 0 + Price.classes = init_discrete(range(1, config.PRICE_N_OBS+1)) - @staticproperty - def edges(): - return Price.classes + @staticproperty + def edges(): + return Price.classes + + def args(stim, entity, config): + 
return Price.edges + + def deserialize(realm, entity, index): + return deserialize_fixed_arg(Price, index) - def args(stim, entity, config): - return Price.edges class Token(Node): - argType = Fixed + argType = Fixed + + @classmethod + def init(cls, config): + Token.classes = init_discrete(range(config.COMMUNICATION_NUM_TOKENS)) - @classmethod - def init(cls, config): - Comm.classes = init_discrete(range(config.COMMUNICATION_NUM_TOKENS)) + @staticproperty + def edges(): + return Token.classes - @staticproperty - def edges(): - return Comm.classes + def args(stim, entity, config): + return Token.edges + + def deserialize(realm, entity, index): + return deserialize_fixed_arg(Token, index) - def args(stim, entity, config): - return Comm.edges class Comm(Node): - argType = Fixed - priority = 0 + argType = Fixed + priority = 99 + + @staticproperty + def edges(): + return [Token] + + def enabled(config): + return config.COMMUNICATION_SYSTEM_ENABLED - @staticproperty - def edges(): - return [Token] + def call(realm, entity, token): + if token is None: + return - def call(env, entity, token): - entity.base.comm.update(token.val) + entity.message.update(token.val) #TODO: Solve AGI class BecomeSkynet: - pass + pass diff --git a/nmmo/io/stimulus.py b/nmmo/io/stimulus.py deleted file mode 100644 index ca0e03e3b..000000000 --- a/nmmo/io/stimulus.py +++ /dev/null @@ -1,467 +0,0 @@ -from pdb import set_trace as T -import numpy as np - -from nmmo.lib import utils - -class SerializedVariable: - CONTINUOUS = False - DISCRETE = False - def __init__(self, dataframe, key, val=None, config=None): - if config is None: - config = dataframe.config - - self.obj = str(self.__class__).split('.')[-2] - self.attr = self.__class__.__name__ - self.key = key - - self.min = 0 - self.max = np.inf - self.val = val - - self.dataframe = dataframe - self.init(config) - err = 'Must set a default val upon instantiation or init()' - assert self.val is not None, err - - #Update dataframe - if dataframe is 
not None: - self.update(self.val) - - #Defined for cleaner stim files - def init(self): - pass - - def packet(self): - return { - 'val': self.val, - 'max': self.max} - - def update(self, val): - self.val = min(max(val, self.min), self.max) - self.dataframe.update(self, self.val) - return self - - def increment(self, val=1): - self.update(self.val + val) - return self - - def decrement(self, val=1): - self.update(self.val - val) - return self - - @property - def empty(self): - return self.val == 0 - - def __add__(self, other): - self.increment(other) - return self - - def __sub__(self, other): - self.decrement(other) - return self - - def __eq__(self, other): - return self.val == other - - def __ne__(self, other): - return self.val != other - - def __lt__(self, other): - return self.val < other - - def __le__(self, other): - return self.val <= other - - def __gt__(self, other): - return self.val > other - - def __ge__(self, other): - return self.val >= other - -class Continuous(SerializedVariable): - CONTINUOUS = True - -class Discrete(Continuous): - DISCRETE = True - - -class Serialized(metaclass=utils.IterableNameComparable): - def dict(): - return {k[0] : v for k, v in dict(Stimulus).items()} - - class Entity(metaclass=utils.IterableNameComparable): - @staticmethod - def enabled(config): - return True - - @staticmethod - def N(config): - return config.PLAYER_N_OBS - - class Self(Discrete): - def init(self, config): - self.max = 1 - self.scale = 1.0 - - class ID(Continuous): - def init(self, config): - self.min = -np.inf - self.scale = 0.001 - - class AttackerID(Continuous): - def init(self, config): - self.min = -np.inf - self.scale = 0.001 - - class Level(Continuous): - def init(self, config): - self.scale = 0.05 - - class ItemLevel(Continuous): - def init(self, config): - self.scale = 0.025 - self.max = 5 * config.NPC_LEVEL_MAX - - class Comm(Discrete): - def init(self, config): - self.scale = 0.025 - self.max = 1 - if config.COMMUNICATION_SYSTEM_ENABLED: - 
self.max = config.COMMUNICATION_NUM_TOKENS - - class Population(Discrete): - def init(self, config): - self.min = -3 #NPC index - self.max = config.PLAYER_POLICIES - 1 - self.scale = 1.0 - - class R(Discrete): - def init(self, config): - self.min = 0 - self.max = config.MAP_SIZE - 1 - self.scale = 0.15 - - class C(Discrete): - def init(self, config): - self.min = 0 - self.max = config.MAP_SIZE - 1 - self.scale = 0.15 - - # Historical stats - class Damage(Continuous): - def init(self, config): - #This scale may eventually be too high - self.val = 0 - self.scale = 0.1 - - class TimeAlive(Continuous): - def init(self, config): - self.val = 0 - self.scale = 0.01 - - # Status effects - class Freeze(Continuous): - def init(self, config): - self.val = 0 - self.max = 3 - self.scale = 0.3 - - class Gold(Continuous): - def init(self, config): - self.val = 0 - self.scale = 0.01 - - # Resources -- Redo the max/min scaling. You can't change these - # after init without messing up the embeddings - class Health(Continuous): - def init(self, config): - self.val = config.PLAYER_BASE_HEALTH - self.max = config.PLAYER_BASE_HEALTH - self.scale = 0.1 - - class Food(Continuous): - def init(self, config): - if config.RESOURCE_SYSTEM_ENABLED: - self.val = config.RESOURCE_BASE - self.max = config.RESOURCE_BASE - else: - self.val = 1 - self.max = 1 - - self.scale = 0.01 - - class Water(Continuous): - def init(self, config): - if config.RESOURCE_SYSTEM_ENABLED: - self.val = config.RESOURCE_BASE - self.max = config.RESOURCE_BASE - else: - self.val = 1 - self.max = 1 - - self.scale = 0.01 - - class Melee(Continuous): - def init(self, config): - self.val = 1 - self.max = 1 - if config.PROGRESSION_SYSTEM_ENABLED: - self.max = config.PROGRESSION_LEVEL_MAX - - class Range(Continuous): - def init(self, config): - self.val = 1 - self.max = 1 - if config.PROGRESSION_SYSTEM_ENABLED: - self.max = config.PROGRESSION_LEVEL_MAX - - class Mage(Continuous): - def init(self, config): - self.val = 1 - 
self.max = 1 - if config.PROGRESSION_SYSTEM_ENABLED: - self.max = config.PROGRESSION_LEVEL_MAX - - class Fishing(Continuous): - def init(self, config): - self.val = 1 - self.max = 1 - if config.PROGRESSION_SYSTEM_ENABLED: - self.max = config.PROGRESSION_LEVEL_MAX - - class Herbalism(Continuous): - def init(self, config): - self.val = 1 - self.max = 1 - if config.PROGRESSION_SYSTEM_ENABLED: - self.max = config.PROGRESSION_LEVEL_MAX - - class Prospecting(Continuous): - def init(self, config): - self.val = 1 - self.max = 1 - if config.PROGRESSION_SYSTEM_ENABLED: - self.max = config.PROGRESSION_LEVEL_MAX - - class Carving(Continuous): - def init(self, config): - self.val = 1 - self.max = 1 - if config.PROGRESSION_SYSTEM_ENABLED: - self.max = config.PROGRESSION_LEVEL_MAX - - class Alchemy(Continuous): - def init(self, config): - self.val = 1 - self.max = 1 - if config.PROGRESSION_SYSTEM_ENABLED: - self.max = config.PROGRESSION_LEVEL_MAX - - class Tile(metaclass=utils.IterableNameComparable): - @staticmethod - def enabled(config): - return True - - @staticmethod - def N(config): - return config.MAP_N_OBS - - class NEnts(Continuous): - def init(self, config): - self.max = config.PLAYER_N - self.val = 0 - self.scale = 1.0 - - class Index(Discrete): - def init(self, config): - self.max = config.MAP_N_TILE - self.scale = 0.15 - - class R(Discrete): - def init(self, config): - self.max = config.MAP_SIZE - 1 - self.scale = 0.15 - - class C(Discrete): - def init(self, config): - self.max = config.MAP_SIZE - 1 - self.scale = 0.15 - - class Item(metaclass=utils.IterableNameComparable): - @staticmethod - def enabled(config): - return config.ITEM_SYSTEM_ENABLED - - @staticmethod - def N(config): - return config.ITEM_N_OBS - - class ID(Continuous): - def init(self, config): - self.scale = 0.001 - - class Index(Discrete): - def init(self, config): - self.max = config.ITEM_N + 1 - self.scale = 1.0 / self.max - - class Level(Continuous): - def init(self, config): - self.max = 99 - 
self.scale = 1.0 / self.max - - class Capacity(Continuous): - def init(self, config): - self.max = 99 - self.scale = 1.0 / self.max - - class Quantity(Continuous): - def init(self, config): - self.max = 99 - self.scale = 1.0 / self.max - - class Tradable(Discrete): - def init(self, config): - self.max = 1 - self.scale = 1.0 - - class MeleeAttack(Continuous): - def init(self, config): - self.max = 100 - self.scale = 1.0 / self.max - - class RangeAttack(Continuous): - def init(self, config): - self.max = 100 - self.scale = 1.0 / self.max - - class MageAttack(Continuous): - def init(self, config): - self.max = 100 - self.scale = 1.0 / self.max - - class MeleeDefense(Continuous): - def init(self, config): - self.max = 100 - self.scale = 1.0 / self.max - - class RangeDefense(Continuous): - def init(self, config): - self.max = 100 - self.scale = 1.0 / self.max - - class MageDefense(Continuous): - def init(self, config): - self.max = 100 - self.scale = 1.0 / self.max - - class HealthRestore(Continuous): - def init(self, config): - self.max = 100 - self.scale = 1.0 / self.max - - class ResourceRestore(Continuous): - def init(self, config): - self.max = 100 - self.scale = 1.0 / self.max - - class Price(Continuous): - def init(self, config): - self.scale = 0.01 - - class Equipped(Discrete): - def init(self, config): - self.scale = 1.0 - - # TODO: Figure out how to autogen this from Items - class Market(metaclass=utils.IterableNameComparable): - @staticmethod - def enabled(config): - return config.EXCHANGE_SYSTEM_ENABLED - - @staticmethod - def N(config): - return config.EXCHANGE_N_OBS - - class ID(Continuous): - def init(self, config): - self.scale = 0.001 - - class Index(Discrete): - def init(self, config): - self.max = config.ITEM_N + 1 - self.scale = 1.0 / self.max - - class Level(Continuous): - def init(self, config): - self.max = 99 - self.scale = 1.0 / self.max - - class Capacity(Continuous): - def init(self, config): - self.max = 99 - self.scale = 1.0 / self.max - - 
class Quantity(Continuous): - def init(self, config): - self.max = 99 - self.scale = 1.0 / self.max - - class Tradable(Discrete): - def init(self, config): - self.max = 1 - self.scale = 1.0 - - class MeleeAttack(Continuous): - def init(self, config): - self.max = 100 - self.scale = 1.0 / self.max - - class RangeAttack(Continuous): - def init(self, config): - self.max = 100 - self.scale = 1.0 / self.max - - class MageAttack(Continuous): - def init(self, config): - self.max = 100 - self.scale = 1.0 / self.max - - class MeleeDefense(Continuous): - def init(self, config): - self.max = 100 - self.scale = 1.0 / self.max - - class RangeDefense(Continuous): - def init(self, config): - self.max = 100 - self.scale = 1.0 / self.max - - class MageDefense(Continuous): - def init(self, config): - self.max = 100 - self.scale = 1.0 / self.max - - class HealthRestore(Continuous): - def init(self, config): - self.max = 100 - self.scale = 1.0 / self.max - - class ResourceRestore(Continuous): - def init(self, config): - self.max = 100 - self.scale = 1.0 / self.max - - class Price(Continuous): - def init(self, config): - self.scale = 0.01 - - class Equipped(Discrete): - def init(self, config): - self.scale = 1.0 - - -for objName, obj in Serialized: - for idx, (attrName, attr) in enumerate(obj): - attr.index = idx diff --git a/nmmo/lib/colors.py b/nmmo/lib/colors.py index 2db948899..37d4188ba 100644 --- a/nmmo/lib/colors.py +++ b/nmmo/lib/colors.py @@ -1,8 +1,9 @@ +# pylint: disable=all + #Various Enums used for handling materials, entity types, etc. #Data texture pairs are used for enums that require textures. #These textures are filled in by the Render class at run time. 
-from pdb import set_trace as T import numpy as np import colorsys @@ -58,7 +59,7 @@ class Tier: GOLD = Color('GOLD', '#ffae00') PLATINUM = Color('PLATINUM', '#cd75ff') DIAMOND = Color('DIAMOND', '#00bbbb') - + class Swatch: def colors(): '''Return list of swatch colors''' @@ -66,7 +67,7 @@ def colors(): def rand(): '''Return random swatch color''' - all_colors = colors() + all_colors = Swatch.colors() randInd = np.random.randint(0, len(all_colors)) return all_colors[randInd] @@ -87,7 +88,7 @@ class Neon(Swatch): FUCHSIA = Color('FUCHSIA', '#ff0080') SPRING = Color('SPRING', '#80ff80') SKY = Color('SKY', '#0080ff') - + WHITE = Color('WHITE', '#ffffff') GRAY = Color('GRAY', '#666666') BLACK = Color('BLACK', '#000000') diff --git a/nmmo/lib/event_log.py b/nmmo/lib/event_log.py new file mode 100644 index 000000000..b054ca458 --- /dev/null +++ b/nmmo/lib/event_log.py @@ -0,0 +1,169 @@ +from types import SimpleNamespace +from typing import List +from copy import deepcopy + +import numpy as np + +from nmmo.datastore.serialized import SerializedState +from nmmo.entity import Entity +from nmmo.systems.item import Item +from nmmo.lib.log import EventCode + +# pylint: disable=no-member +EventState = SerializedState.subclass("Event", [ + "id", # unique event id + "ent_id", + "tick", + + "event", + + "type", + "level", + "number", + "gold", + "target_ent", +]) + +EventAttr = EventState.State.attr_name_to_col + +EventState.Query = SimpleNamespace( + table=lambda ds: ds.table("Event").where_neq(EventAttr["id"], 0), + + by_event=lambda ds, event_code: ds.table("Event").where_eq( + EventAttr["event"], event_code), +) + +# defining col synonyms for different event types +ATTACK_COL_MAP = { + 'combat_style': EventAttr['type'], + 'damage': EventAttr['number'] } + +ITEM_COL_MAP = { + 'item_type': EventAttr['type'], + 'quantity': EventAttr['number'], + 'price': EventAttr['gold'] } + +LEVEL_COL_MAP = { 'skill': EventAttr['type'] } + +EXPLORE_COL_MAP = { 'distance': EventAttr['number'] } 
+ + +class EventLogger(EventCode): + def __init__(self, realm): + self.realm = realm + self.config = realm.config + self.datastore = realm.datastore + + self.valid_events = { val: evt for evt, val in EventCode.__dict__.items() + if isinstance(val, int) } + + # add synonyms to the attributes + self.attr_to_col = deepcopy(EventAttr) + self.attr_to_col.update(ATTACK_COL_MAP) + self.attr_to_col.update(ITEM_COL_MAP) + self.attr_to_col.update(LEVEL_COL_MAP) + self.attr_to_col.update(EXPLORE_COL_MAP) + + def reset(self): + EventState.State.table(self.datastore).reset() + + # define event logging + def _create_event(self, entity: Entity, event_code: int): + log = EventState(self.datastore) + log.id.update(log.datastore_record.id) + log.ent_id.update(entity.ent_id) + # the tick increase by 1 after executing all actions + log.tick.update(self.realm.tick+1) + log.event.update(event_code) + + return log + + def record(self, event_code: int, entity: Entity, **kwargs): + if event_code in [EventCode.EAT_FOOD, EventCode.DRINK_WATER, + EventCode.GIVE_ITEM, EventCode.DESTROY_ITEM, + EventCode.GIVE_GOLD]: + # Logs for these events are for counting only + self._create_event(entity, event_code) + return + + if event_code == EventCode.GO_FARTHEST: # use EXPLORE_COL_MAP + if ('distance' in kwargs and kwargs['distance'] > 0): + log = self._create_event(entity, event_code) + log.number.update(kwargs['distance']) + return + + if event_code == EventCode.SCORE_HIT: + # kwargs['combat_style'] should be Skill.CombatSkill + if ('combat_style' in kwargs and kwargs['combat_style'].SKILL_ID in [1, 2, 3]) & \ + ('damage' in kwargs and kwargs['damage'] >= 0): + log = self._create_event(entity, event_code) + log.type.update(kwargs['combat_style'].SKILL_ID) + log.number.update(kwargs['damage']) + return + + if event_code == EventCode.PLAYER_KILL: + if ('target' in kwargs and isinstance(kwargs['target'], Entity)): + target = kwargs['target'] + log = self._create_event(entity, event_code) + 
log.target_ent.update(target.ent_id) + + # CHECK ME: attack_level or "general" level?? need to clarify + log.level.update(target.attack_level) + return + + if event_code in [EventCode.CONSUME_ITEM, EventCode.HARVEST_ITEM, EventCode.EQUIP_ITEM]: + # CHECK ME: item types should be checked. For example, + # Only Ration and Poultice can be consumed + # Only Ration, Poultice, Scrap, Shaving, Shard can be produced + # The quantity should be 1 for all of these events + if ('item' in kwargs and isinstance(kwargs['item'], Item)): + item = kwargs['item'] + log = self._create_event(entity, event_code) + log.type.update(item.ITEM_TYPE_ID) + log.level.update(item.level.val) + log.number.update(item.quantity.val) + return + + if event_code in [EventCode.LIST_ITEM, EventCode.BUY_ITEM]: + if ('item' in kwargs and isinstance(kwargs['item'], Item)) & \ + ('price' in kwargs and kwargs['price'] > 0): + item = kwargs['item'] + log = self._create_event(entity, event_code) + log.type.update(item.ITEM_TYPE_ID) + log.level.update(item.level.val) + log.number.update(item.quantity.val) + log.gold.update(kwargs['price']) + return + + if event_code == EventCode.EARN_GOLD: + if ('amount' in kwargs and kwargs['amount'] > 0): + log = self._create_event(entity, event_code) + log.gold.update(kwargs['amount']) + return + + if event_code == EventCode.LEVEL_UP: + # kwargs['skill'] should be Skill.Skill + if ('skill' in kwargs and kwargs['skill'].SKILL_ID in range(1,9)) & \ + ('level' in kwargs and kwargs['level'] >= 0): + log = self._create_event(entity, event_code) + log.type.update(kwargs['skill'].SKILL_ID) + log.level.update(kwargs['level']) + return + + # If reached here, then something is wrong + # CHECK ME: The below should be commented out after debugging + raise ValueError(f"Event code: {event_code}", kwargs) + + def get_data(self, event_code=None, agents: List[int]=None): + if event_code is None: + event_data = EventState.Query.table(self.datastore).astype(np.int32) + elif event_code in 
self.valid_events: + event_data = EventState.Query.by_event(self.datastore, event_code).astype(np.int32) + else: + return None + + if agents: + flt_idx = np.in1d(event_data[:, EventAttr['ent_id']], agents) + return event_data[flt_idx] + + return event_data diff --git a/nmmo/lib/log.py b/nmmo/lib/log.py index 463b6cbb5..6ee72296e 100644 --- a/nmmo/lib/log.py +++ b/nmmo/lib/log.py @@ -1,107 +1,66 @@ -from pdb import set_trace as T from collections import defaultdict -from nmmo.lib import material -from copy import deepcopy -import os import logging -import numpy as np -import json, pickle -import time -from nmmo.lib import utils class Logger: - def __init__(self): - self.stats = defaultdict(list) - - def log(self, key, val): - try: - int_val = int(val) - except TypeError as e: - print(f'{val} must be int or float') - raise e - self.stats[key].append(val) - return True + def __init__(self): + self.stats = defaultdict(list) -class MilestoneLogger(Logger): - def __init__(self, log_file): - super().__init__() - logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO, filename=log_file, filemode='w') - - def log_min(self, key, val): - if key in self.stats and val >= self.stats[key][-1]: - return False - - self.log(key, val) - return True - - def log_max(self, key, val): - if key in self.stats and val <= self.stats[key][-1]: - return False - - self.log(key, val) - return True - -class Quill: - def __init__(self, config): - self.config = config - - self.env = Logger() - self.player = Logger() - self.event = Logger() - - self.shared = {} - - if config.LOG_MILESTONES: - self.milestone = MilestoneLogger(config.LOG_FILE) - - def register(self, key, fn): - assert key not in self.shared, f'Log key {key} already exists' - self.shared[key] = fn - - def log_env(self, key, val): - self.env.log(key, val) - - def log_player(self, key, val): - self.player.log(key, val) - - @property - def packet(self): - packet = {'Env': self.env.stats, - 'Player': self.player.stats} 
- - if self.config.LOG_EVENTS: - packet['Event'] = self.event.stats - else: - packet['Event'] = 'Unavailable: config.LOG_EVENTS = False' - - if self.config.LOG_MILESTONES: - packet['Milestone'] = self.event.stats - else: - packet['Milestone'] = 'Unavailable: config.LOG_MILESTONES = False' - - return packet - -#Log wrapper and benchmarker -class Benchmarker: - def __init__(self, logdir): - self.benchmarks = {} - - def wrap(self, func): - self.benchmarks[func] = Utils.BenchmarkTimer() - def wrapped(*args): - self.benchmarks[func].startRecord() - ret = func(*args) - self.benchmarks[func].stopRecord() - return ret - return wrapped - - def bench(self, tick): - if tick % 100 == 0: - for k, benchmark in self.benchmarks.items(): - bench = benchmark.benchmark() - print(k.__func__.__name__, 'Tick: ', tick, - ', Benchmark: ', bench, ', FPS: ', 1/bench) - + def log(self, key, val): + if not isinstance(val, (int, float)): + raise RuntimeError(f'{val} must be int or float') + + self.stats[key].append(val) + return True +class MilestoneLogger(Logger): + def __init__(self, log_file): + super().__init__() + logging.basicConfig(format='%(levelname)s:%(message)s', + level=logging.INFO, filename=log_file, filemode='w') + + def log_min(self, key, val): + if key in self.stats and val >= self.stats[key][-1]: + return False + + self.log(key, val) + return True + + def log_max(self, key, val): + if key in self.stats and val <= self.stats[key][-1]: + return False + + self.log(key, val) + return True + + +# CHECK ME: Is this a good place to put here? 
+# EventCode is used in many places, and I(kywch)'m putting it here +# to avoid a circular import, which happened a few times with event_log.py +class EventCode: + # Move + EAT_FOOD = 1 + DRINK_WATER = 2 + GO_FARTHEST = 3 # record when breaking the previous record + + # Attack + SCORE_HIT = 11 + PLAYER_KILL = 12 + + # Item + CONSUME_ITEM = 21 + GIVE_ITEM = 22 + DESTROY_ITEM = 23 + HARVEST_ITEM = 24 + EQUIP_ITEM = 25 + + # Exchange + GIVE_GOLD = 31 + LIST_ITEM = 32 + EARN_GOLD = 33 + BUY_ITEM = 34 + #SPEND_GOLD = 35 # BUY_ITEM, price has the same info + + # Level up + LEVEL_UP = 41 diff --git a/nmmo/lib/material.py b/nmmo/lib/material.py index 97afda716..eadee2f65 100644 --- a/nmmo/lib/material.py +++ b/nmmo/lib/material.py @@ -1,200 +1,205 @@ -from pdb import set_trace as T from nmmo.systems import item, droptable class Material: - capacity = 0 - tool = None - table = None + capacity = 0 + tool = None + table = None + index = None + respawn = 0 - def __init__(self, config): - pass + def __init__(self, config): + pass - def __eq__(self, mtl): - return self.index == mtl.index + def __eq__(self, mtl): + return self.index == mtl.index - def __equals__(self, mtl): - return self == mtl + def __equals__(self, mtl): + return self == mtl - def harvest(self): - return self.__class__.table + def harvest(self): + return self.__class__.table class Lava(Material): - tex = 'lava' - index = 0 + tex = 'lava' + index = 0 class Water(Material): - tex = 'water' - index = 1 + tex = 'water' + index = 1 - table = droptable.Empty() + table = droptable.Empty() - def __init__(self, config): - self.deplete = __class__ - self.respawn = 1.0 + def __init__(self, config): + self.deplete = __class__ + self.respawn = 1.0 class Grass(Material): - tex = 'grass' - index = 2 + tex = 'grass' + index = 2 class Scrub(Material): - tex = 'scrub' - index = 3 + tex = 'scrub' + index = 3 class Forest(Material): - tex = 'forest' - index = 4 + tex = 'forest' + index = 4 - deplete = Scrub - table = 
droptable.Empty() + deplete = Scrub + table = droptable.Empty() - def __init__(self, config): - if config.RESOURCE_SYSTEM_ENABLED: - self.capacity = config.RESOURCE_FOREST_CAPACITY - self.respawn = config.RESOURCE_FOREST_RESPAWN + def __init__(self, config): + if config.RESOURCE_SYSTEM_ENABLED: + self.capacity = config.RESOURCE_FOREST_CAPACITY + self.respawn = config.RESOURCE_FOREST_RESPAWN class Stone(Material): - tex = 'stone' - index = 5 + tex = 'stone' + index = 5 class Slag(Material): - tex = 'slag' - index = 6 + tex = 'slag' + index = 6 class Ore(Material): - tex = 'ore' - index = 7 + tex = 'ore' + index = 7 - deplete = Slag - tool = item.Pickaxe + deplete = Slag + tool = item.Pickaxe - def __init__(self, config): - cls = self.__class__ - if cls.table is None: - cls.table = droptable.Standard() - cls.table.add(item.Scrap) + def __init__(self, config): + cls = self.__class__ + if cls.table is None: + cls.table = droptable.Standard() + cls.table.add(item.Scrap) - if config.EQUIPMENT_SYSTEM_ENABLED: - cls.table.add(item.Wand, prob=config.WEAPON_DROP_PROB) + if config.EQUIPMENT_SYSTEM_ENABLED: + cls.table.add(item.Wand, prob=config.WEAPON_DROP_PROB) - self.capacity = config.PROFESSION_ORE_CAPACITY - self.respawn = config.PROFESSION_ORE_RESPAWN + if config.PROFESSION_SYSTEM_ENABLED: + self.capacity = config.PROFESSION_ORE_CAPACITY + self.respawn = config.PROFESSION_ORE_RESPAWN - tool = item.Pickaxe - deplete = Slag + tool = item.Pickaxe + deplete = Slag class Stump(Material): - tex = 'stump' - index = 8 + tex = 'stump' + index = 8 class Tree(Material): - tex = 'tree' - index = 9 + tex = 'tree' + index = 9 - deplete = Stump - tool = item.Chisel + deplete = Stump + tool = item.Chisel - def __init__(self, config): - cls = self.__class__ - if cls.table is None: - cls.table = droptable.Standard() - cls.table.add(item.Shaving) - if config.EQUIPMENT_SYSTEM_ENABLED: - cls.table.add(item.Sword, prob=config.WEAPON_DROP_PROB) + def __init__(self, config): + cls = 
self.__class__ + if cls.table is None: + cls.table = droptable.Standard() + cls.table.add(item.Shaving) + if config.EQUIPMENT_SYSTEM_ENABLED: + cls.table.add(item.Sword, prob=config.WEAPON_DROP_PROB) + if config.PROFESSION_SYSTEM_ENABLED: self.capacity = config.PROFESSION_TREE_CAPACITY self.respawn = config.PROFESSION_TREE_RESPAWN class Fragment(Material): - tex = 'fragment' - index = 10 + tex = 'fragment' + index = 10 class Crystal(Material): - tex = 'crystal' - index = 11 + tex = 'crystal' + index = 11 - deplete = Fragment - tool = item.Arcane + deplete = Fragment + tool = item.Arcane - def __init__(self, config): - cls = self.__class__ - if cls.table is None: - cls.table = droptable.Standard() - cls.table.add(item.Shard) - if config.EQUIPMENT_SYSTEM_ENABLED: - cls.table.add(item.Bow, prob=config.WEAPON_DROP_PROB) + def __init__(self, config): + cls = self.__class__ + if cls.table is None: + cls.table = droptable.Standard() + cls.table.add(item.Shard) + if config.EQUIPMENT_SYSTEM_ENABLED: + cls.table.add(item.Bow, prob=config.WEAPON_DROP_PROB) - if config.RESOURCE_SYSTEM_ENABLED: - self.capacity = config.PROFESSION_CRYSTAL_CAPACITY - self.respawn = config.PROFESSION_CRYSTAL_RESPAWN + if config.PROFESSION_SYSTEM_ENABLED: + self.capacity = config.PROFESSION_CRYSTAL_CAPACITY + self.respawn = config.PROFESSION_CRYSTAL_RESPAWN class Weeds(Material): - tex = 'weeds' - index = 12 + tex = 'weeds' + index = 12 class Herb(Material): - tex = 'herb' - index = 13 + tex = 'herb' + index = 13 - deplete = Weeds - tool = item.Gloves + deplete = Weeds + tool = item.Gloves - table = droptable.Standard() - table.add(item.Poultice) + table = droptable.Standard() + table.add(item.Poultice) - def __init__(self, config): - if config.RESOURCE_SYSTEM_ENABLED: - self.capacity = config.PROFESSION_HERB_CAPACITY - self.respawn = config.PROFESSION_HERB_RESPAWN + def __init__(self, config): + if config.PROFESSION_SYSTEM_ENABLED: + self.capacity = config.PROFESSION_HERB_CAPACITY + self.respawn = 
config.PROFESSION_HERB_RESPAWN class Ocean(Material): - tex = 'ocean' - index = 14 + tex = 'ocean' + index = 14 class Fish(Material): - tex = 'fish' - index = 15 + tex = 'fish' + index = 15 - deplete = Ocean - tool = item.Rod + deplete = Ocean + tool = item.Rod - table = droptable.Standard() - table.add(item.Ration) + table = droptable.Standard() + table.add(item.Ration) - def __init__(self, config): - if config.RESOURCE_SYSTEM_ENABLED: - self.capacity = config.PROFESSION_FISH_CAPACITY - self.respawn = config.PROFESSION_FISH_RESPAWN + def __init__(self, config): + if config.PROFESSION_SYSTEM_ENABLED: + self.capacity = config.PROFESSION_FISH_CAPACITY + self.respawn = config.PROFESSION_FISH_RESPAWN +# TODO: Fix lint errors +# pylint: disable=all class Meta(type): - def __init__(self, name, bases, dict): - self.indices = {mtl.index for mtl in self.materials} + def __init__(self, name, bases, dict): + self.indices = {mtl.index for mtl in self.materials} - def __iter__(self): - yield from self.materials + def __iter__(self): + yield from self.materials - def __contains__(self, mtl): - if isinstance(mtl, Material): - mtl = type(mtl) - if isinstance(mtl, type): - return mtl in self.materials - return mtl in self.indices + def __contains__(self, mtl): + if isinstance(mtl, Material): + mtl = type(mtl) + if isinstance(mtl, type): + return mtl in self.materials + return mtl in self.indices class All(metaclass=Meta): - '''List of all materials''' - materials = { - Lava, Water, Grass, Scrub, Forest, - Stone, Slag, Ore, Stump, Tree, - Fragment, Crystal, Weeds, Herb, Ocean, Fish} + '''List of all materials''' + materials = { + Lava, Water, Grass, Scrub, Forest, + Stone, Slag, Ore, Stump, Tree, + Fragment, Crystal, Weeds, Herb, Ocean, Fish} class Impassible(metaclass=Meta): - '''Materials that agents cannot walk through''' - materials = {Lava, Water, Stone, Ocean, Fish} + '''Materials that agents cannot walk through''' + materials = {Lava, Water, Stone, Ocean, Fish} class 
Habitable(metaclass=Meta): - '''Materials that agents cannot walk on''' - materials = {Grass, Scrub, Forest, Ore, Slag, Tree, Stump, Crystal, Fragment, Herb, Weeds} + '''Materials that agents can walk on''' + materials = {Grass, Scrub, Forest, Ore, Slag, Tree, Stump, Crystal, Fragment, Herb, Weeds} class Harvestable(metaclass=Meta): - '''Materials that agents can harvest''' - materials = {Water, Forest, Ore, Tree, Crystal, Herb, Fish} + '''Materials that agents can harvest''' + materials = {Water, Forest, Ore, Tree, Crystal, Herb, Fish} diff --git a/nmmo/lib/overlay.py b/nmmo/lib/overlay.py deleted file mode 100644 index eefbf5b54..000000000 --- a/nmmo/lib/overlay.py +++ /dev/null @@ -1,58 +0,0 @@ -from pdb import set_trace as T -import numpy as np -from scipy import signal - -def norm(ary, nStd=2): - assert type(ary) == np.ndarray, 'ary must be of type np.ndarray' - R, C = ary.shape - preprocessed = np.zeros_like(ary) - nonzero = ary[ary!= 0] - mean = np.mean(nonzero) - std = np.std(nonzero) - if std == 0: - std = 1 - for r in range(R): - for c in range(C): - val = ary[r, c] - if val != 0: - val = (val - mean) / (nStd * std) - val = np.clip(val+1, 0, 2)/2 - preprocessed[r, c] = val - return preprocessed - -def clip(ary): - assert type(ary) == np.ndarray, 'ary must be of type np.ndarray' - R, C = ary.shape - preprocessed = np.zeros_like(ary) - nonzero = ary[ary!= 0] - mmin = np.min(nonzero) - mmag = np.max(nonzero) - mmin - for r in range(R): - for c in range(C): - val = ary[r, c] - val = (val - mmin) / mmag - preprocessed[r, c] = val - return preprocessed - -def twoTone(ary, nStd=2, preprocess='norm', invert=False, periods=1): - assert preprocess in 'norm clip none'.split() - if preprocess == 'norm': - ary = norm(ary, nStd) - elif preprocess == 'clip': - ary = clip(ary) - - R, C = ary.shape - - colorized = np.zeros((R, C, 3)) - if periods != 1: - ary = np.abs(signal.sawtooth(periods*3.14159*ary)) - if invert: - colorized[:, :, 0] = ary - colorized[:, :, 1] = 
1-ary - else: - colorized[:, :, 0] = 1-ary - colorized[:, :, 1] = ary - - colorized *= (ary != 0)[:, :, None] - - return colorized diff --git a/nmmo/lib/priorityqueue.py b/nmmo/lib/priorityqueue.py index 35e86a637..7d3d0e3be 100644 --- a/nmmo/lib/priorityqueue.py +++ b/nmmo/lib/priorityqueue.py @@ -1,3 +1,5 @@ +# pylint: disable=all + import heapq, itertools import itertools diff --git a/nmmo/lib/rating.py b/nmmo/lib/rating.py index 15032c155..438fb92d9 100644 --- a/nmmo/lib/rating.py +++ b/nmmo/lib/rating.py @@ -1,4 +1,4 @@ -from pdb import set_trace as T +# pylint: disable=all from collections import defaultdict import numpy as np @@ -14,7 +14,7 @@ def rank(policy_ids, scores): # Double argsort returns ranks return np.argsort(np.argsort( - [-np.mean(vals) + 1e-8 * np.random.normal() for policy, vals in + [-np.mean(vals) + 1e-8 * np.random.normal() for policy, vals in sorted(agents.items())])).tolist() diff --git a/nmmo/lib/spawn.py b/nmmo/lib/spawn.py index 2e257eb92..e834e8910 100644 --- a/nmmo/lib/spawn.py +++ b/nmmo/lib/spawn.py @@ -1,156 +1,101 @@ -from pdb import set_trace as T import numpy as np - class SequentialLoader: - '''config.PLAYER_LOADER that spreads out agent populations''' - def __init__(self, config): - items = config.PLAYERS - for idx, itm in enumerate(items): - itm.policyID = idx - - self.items = items - self.idx = -1 - - def __iter__(self): - return self - - def __next__(self): - self.idx = (self.idx + 1) % len(self.items) - return self.idx, self.items[self.idx] - -class TeamLoader: - '''config.PLAYER_LOADER that loads agent populations adjacent''' - def __init__(self, config): - items = config.PLAYERS - self.team_size = config.PLAYER_N // len(items) - - for idx, itm in enumerate(items): - itm.policyID = idx - - self.items = items - self.idx = -1 + '''config.PLAYER_LOADER that spreads out agent populations''' + def __init__(self, config): + items = config.PLAYERS - def __iter__(self): - return self + self.items = items + self.idx = -1 - def 
__next__(self): - self.idx += 1 - team_idx = self.idx // self.team_size - return team_idx, self.items[team_idx] + def __iter__(self): + return self + def __next__(self): + self.idx = (self.idx + 1) % len(self.items) + return self.items[self.idx] def spawn_continuous(config): - '''Generates spawn positions for new agents - - Randomly selects spawn positions around - the borders of the square game map - - Returns: - tuple(int, int): - - position: - The position (row, col) to spawn the given agent - ''' - #Spawn at edges - mmax = config.MAP_CENTER + config.MAP_BORDER - mmin = config.MAP_BORDER - - var = np.random.randint(mmin, mmax) - fixed = np.random.choice([mmin, mmax]) - r, c = int(var), int(fixed) - if np.random.rand() > 0.5: - r, c = c, r - return (r, c) - -def old_spawn_concurrent(config): - '''Generates spawn positions for new agents - - Evenly spaces agents around the borders - of the square game map + '''Generates spawn positions for new agents - Returns: - tuple(int, int): + Randomly selects spawn positions around + the borders of the square game map - position: - The position (row, col) to spawn the given agent - ''' - - left = config.MAP_BORDER - right = config.MAP_CENTER + config.MAP_BORDER - rrange = np.arange(left+2, right, 4).tolist() + Returns: + tuple(int, int): - assert not config.MAP_CENTER % 4 - per_side = config.MAP_CENTER // 4 - - lows = (left+np.zeros(per_side, dtype=np.int)).tolist() - highs = (right+np.zeros(per_side, dtype=np.int)).tolist() + position: + The position (row, col) to spawn the given agent + ''' + #Spawn at edges + mmax = config.MAP_CENTER + config.MAP_BORDER + mmin = config.MAP_BORDER - s1 = list(zip(rrange, lows)) - s2 = list(zip(lows, rrange)) - s3 = list(zip(rrange, highs)) - s4 = list(zip(highs, rrange)) + var = np.random.randint(mmin, mmax) + fixed = np.random.choice([mmin, mmax]) + r, c = int(var), int(fixed) + if np.random.rand() > 0.5: + r, c = c, r + return (r, c) - ret = s1 + s2 + s3 + s4 - - # Shuffle needs porting 
to competition version - np.random.shuffle(ret) - - return ret def spawn_concurrent(config): - '''Generates spawn positions for new agents - - Evenly spaces agents around the borders - of the square game map - - Returns: - tuple(int, int): - - position: - The position (row, col) to spawn the given agent - ''' - team_size = config.PLAYER_TEAM_SIZE - team_n = len(config.PLAYERS) - teammate_sep = config.PLAYER_SPAWN_TEAMMATE_DISTANCE - - # Number of total border tiles - total_tiles = 4 * config.MAP_CENTER - - # Number of tiles, including within-team sep, occupied by each team - tiles_per_team = teammate_sep*(team_size-1) + team_size - - # Number of total tiles dedicated to separating teams - buffer_tiles = 0 - if team_n > 1: - buffer_tiles = total_tiles - tiles_per_team*team_n - - # Number of tiles between teams - team_sep = buffer_tiles // team_n - - # Accounts for lava borders in coord calcs - left = config.MAP_BORDER - right = config.MAP_CENTER + config.MAP_BORDER - lows = config.MAP_CENTER * [left] - highs = config.MAP_CENTER * [right] - inc = list(range(config.MAP_BORDER, config.MAP_CENTER+config.MAP_BORDER)) - - # All edge tiles in order - sides = [] - sides += list(zip(lows, inc)) - sides += list(zip(inc, highs)) - sides += list(zip(highs, inc[::-1])) - sides += list(zip(inc[::-1], lows)) - + '''Generates spawn positions for new agents + + Evenly spaces agents around the borders + of the square game map + + Returns: + tuple(int, int): + + position: + The position (row, col) to spawn the given agent + ''' + team_size = config.PLAYER_TEAM_SIZE + team_n = len(config.PLAYERS) + teammate_sep = config.PLAYER_SPAWN_TEAMMATE_DISTANCE + + # Number of total border tiles + total_tiles = 4 * config.MAP_CENTER + + # Number of tiles, including within-team sep, occupied by each team + tiles_per_team = teammate_sep*(team_size-1) + team_size + + # Number of total tiles dedicated to separating teams + buffer_tiles = 0 + if team_n > 1: + buffer_tiles = total_tiles - 
tiles_per_team*team_n + + # Number of tiles between teams + team_sep = buffer_tiles // team_n + + # Accounts for lava borders in coord calcs + left = config.MAP_BORDER + right = config.MAP_CENTER + config.MAP_BORDER + lows = config.MAP_CENTER * [left] + highs = config.MAP_CENTER * [right] + inc = list(range(config.MAP_BORDER, config.MAP_CENTER+config.MAP_BORDER)) + + # All edge tiles in order + sides = [] + sides += list(zip(lows, inc)) + sides += list(zip(inc, highs)) + sides += list(zip(highs, inc[::-1])) + sides += list(zip(inc[::-1], lows)) + np.random.shuffle(sides) + + if team_n > 1: # Space across and within teams spawn_positions = [] for idx in range(team_sep//2, len(sides), tiles_per_team+team_sep): - for offset in list(range(0, tiles_per_team, teammate_sep+1)): - if len(spawn_positions) >= config.PLAYER_N: - continue - - pos = sides[idx + offset] - spawn_positions.append(pos) + for offset in list(range(0, tiles_per_team, teammate_sep+1)): + if len(spawn_positions) >= config.PLAYER_N: + continue - return spawn_positions + pos = sides[idx + offset] + spawn_positions.append(pos) + else: + # team_n = 1: to fit 128 agents in a small map, ignore spacing and spawn randomly + spawn_positions = sides[:config.PLAYER_N] + return spawn_positions diff --git a/nmmo/lib/task.py b/nmmo/lib/task.py new file mode 100644 index 000000000..e25c64518 --- /dev/null +++ b/nmmo/lib/task.py @@ -0,0 +1,263 @@ +import json +import random +from typing import List + + +# pylint: disable=abstract-method, super-init-not-called + +class Task(): + def completed(self, realm, entity) -> bool: + raise NotImplementedError + + def description(self) -> List: + return self.__class__.__name__ + + def to_string(self) -> str: + return json.dumps(self.description()) + +############################################################### + +class TaskTarget(): + def __init__(self, name: str, agents: List[str]) -> None: + self._name = name + self._agents = agents + + def agents(self) -> List[int]: + return 
self._agents + + def description(self) -> List: + return self._name + + def member(self, member): + assert member < len(self._agents) + return TaskTarget(f"{self.description()}.{member}", [self._agents[member]]) + +class TargetTask(Task): + def __init__(self, target: TaskTarget) -> None: + self._target = target + + def description(self) -> List: + return [super().description(), self._target.description()] + + def completed(self, realm, entity) -> bool: + raise NotImplementedError + +############################################################### + +class TeamHelper(): + def __init__(self, agents: List[int], num_teams: int) -> None: + assert len(agents) % num_teams == 0 + self.team_size = len(agents) // num_teams + self._teams = [ + list(agents[i * self.team_size : (i+1) * self.team_size]) + for i in range(num_teams) + ] + self._agent_to_team = {a: tid for tid, t in enumerate(self._teams) for a in t} + + def own_team(self, agent_id: int) -> TaskTarget: + return TaskTarget("Team.Self", self._teams[self._agent_to_team[agent_id]]) + + def left_team(self, agent_id: int) -> TaskTarget: + return TaskTarget("Team.Left", self._teams[ + (self._agent_to_team[agent_id] -1) % len(self._teams) + ]) + + def right_team(self, agent_id: int) -> TaskTarget: + return TaskTarget("Team.Right", self._teams[ + (self._agent_to_team[agent_id] + 1) % len(self._teams) + ]) + + def all(self) -> TaskTarget: + return TaskTarget("All", list(self._agent_to_team.keys())) + +############################################################### + +class AND(Task): + def __init__(self, *tasks: Task) -> None: + super().__init__() + assert len(tasks) > 0 + self._tasks = tasks + + def completed(self, realm, entity) -> bool: + return all(t.completed(realm, entity) for t in self._tasks) + + def description(self) -> List: + return ["AND"] + [t.description() for t in self._tasks] + +class OR(Task): + def __init__(self, *tasks: Task) -> None: + super().__init__() + assert len(tasks) > 0 + self._tasks = tasks + + 
def completed(self, realm, entity) -> bool: + return any(t.completed(realm, entity) for t in self._tasks) + + def description(self) -> List: + return ["OR"] + [t.description() for t in self._tasks] + +class NOT(Task): + def __init__(self, task: Task) -> None: + super().__init__() + self._task = task + + def completed(self, realm, entity) -> bool: + return not self._task.completed(realm, entity) + + def description(self) -> List: + return ["NOT", self._task.description()] + +############################################################### + +class InflictDamage(TargetTask): + def __init__(self, target: TaskTarget, damage_type: int, quantity: int): + super().__init__(target) + self._damage_type = damage_type + self._quantity = quantity + + def completed(self, realm, entity) -> bool: + # TODO(daveey) damage_type is ignored, needs to be added to entity.history + return sum( + realm.players[a].history.damage_inflicted for a in self._target.agents() + ) >= self._quantity + + def description(self) -> List: + return super().description() + [self._damage_type, self._quantity] + +class Defend(TargetTask): + def __init__(self, target, num_steps) -> None: + super().__init__(target) + self._num_steps = num_steps + + def completed(self, realm, entity) -> bool: + # TODO(daveey) need a way to specify time horizon + return realm.tick >= self._num_steps and all( + realm.players[a].alive for a in self._target.agents() + ) + + def description(self) -> List: + return super().description() + [self._num_steps] + +class Inflict(TargetTask): + def __init__(self, target: TaskTarget, damage_type, quantity: int): + ''' + target: The team that is completing the task. Any agent may complete + damage_type: Can use skills.Melee/Range/Mage + quantity: Minimum damage to inflict in a single hit + ''' + +class Defeat(TargetTask): + def __init__(self, target: TaskTarget, entity_type, level: int): + ''' + target: The team that is completing the task. 
Any agent may complete + entity type: entity.Player or entity.NPC + level: minimum target level to defeat + ''' + +class Achieve(TargetTask): + def __init__(self, target: TaskTarget, skill, level: int): + ''' + target: The team that is completing the task. Any agent may complete. + skill: systems.skill to advance + level: level to reach + ''' + +class Harvest(TargetTask): + def __init__(self, target: TaskTarget, resource, level: int): + ''' + target: The team that is completing the task. Any agent may complete + resource: lib.material to harvest + level: minimum material level to harvest + ''' + +class Equip(Task): + def __init__(self, target: TaskTarget, item, level: int): + ''' + target: The team that is completing the task. Any agent may complete. + item: systems.item to equip + level: Minimum level of that item + ''' + +class Hoard(Task): + def __init__(self, target: TaskTarget, gold): + ''' + target: The team that is completing the task. Completed across the team + gold: reach this amount of gold held at one time (inventory.gold sum over team) + ''' + +class Group(Task): + def __init__(self, target: TaskTarget, num_teammates: int, distance: int): + ''' + target: The team that is completing the task. Completed across the team + num_teammates: Number of teammates to group together + distance: Max distance to nearest teammate + ''' + +class Spread(Task): + def __init__(self, target: TaskTarget, num_teammates: int, distance: int): + ''' + target: The team that is completing the task. Completed across the team + num_teammates: Number of teammates to group together + distance: Min distance to nearest teammate + ''' + +class Eliminate(Task): + def __init__(self, target: TaskTarget, opponent_team): + ''' + target: The team that is completing the task. 
Completed across the team + opponent_team: left/right/any team to be eliminated (all agents defeated) + ''' + +############################################################### + +class TaskSampler(): + def __init__(self) -> None: + self._task_specs = [] + self._task_spec_weights = [] + + def add_task_spec(self, task_class, param_space = None, weight: float = 1): + self._task_specs.append((task_class, param_space or [])) + self._task_spec_weights.append(weight) + + def sample(self, + min_clauses: int = 1, + max_clauses: int = 1, + min_clause_size: int = 1, + max_clause_size: int = 1, + not_p: float = 0.0) -> Task: + + clauses = [] + for _ in range(0, random.randint(min_clauses, max_clauses)): + task_specs = random.choices( + self._task_specs, + weights = self._task_spec_weights, + k = random.randint(min_clause_size, max_clause_size) + ) + tasks = [] + for task_class, task_param_space in task_specs: + task = task_class(*[random.choice(tp) for tp in task_param_space]) + if random.random() < not_p: + task = NOT(task) + tasks.append(task) + + if len(tasks) == 1: + clauses.append(tasks[0]) + else: + clauses.append(AND(*tasks)) + + if len(clauses) == 1: + return clauses[0] + + return OR(*clauses) + + @staticmethod + def create_default_task_sampler(team_helper: TeamHelper, agent_id: int): + neighbors = [team_helper.left_team(agent_id), team_helper.right_team(agent_id)] + own_team = team_helper.own_team(agent_id) + team_mates = [own_team.member(m) for m in range(team_helper.team_size)] + sampler = TaskSampler() + + sampler.add_task_spec(InflictDamage, [neighbors + [own_team], [0, 1, 2], [0, 100, 1000]]) + sampler.add_task_spec(Defend, [team_mates, [512, 1024]]) + + return sampler diff --git a/nmmo/lib/utils.py b/nmmo/lib/utils.py index b3c5943e0..e4ebd33cd 100644 --- a/nmmo/lib/utils.py +++ b/nmmo/lib/utils.py @@ -1,87 +1,86 @@ -from pdb import set_trace as T -import numpy as np +# pylint: disable=all -from collections import defaultdict, deque import inspect +from 
collections import deque + +import numpy as np + class staticproperty(property): - def __get__(self, cls, owner): - return self.fget.__get__(None, owner)() + def __get__(self, cls, owner): + return self.fget.__get__(None, owner)() class classproperty(object): - def __init__(self, f): - self.f = f - def __get__(self, obj, owner): - return self.f(owner) + def __init__(self, f): + self.f = f + def __get__(self, obj, owner): + return self.f(owner) class Iterable(type): - def __iter__(cls): - queue = deque(cls.__dict__.items()) - while len(queue) > 0: - name, attr = queue.popleft() - if type(name) != tuple: - name = tuple([name]) - if not inspect.isclass(attr): - continue - yield name, attr - - def values(cls): - return [e[1] for e in cls] + def __iter__(cls): + queue = deque(cls.__dict__.items()) + while len(queue) > 0: + name, attr = queue.popleft() + if type(name) != tuple: + name = tuple([name]) + if not inspect.isclass(attr): + continue + yield name, attr + + def values(cls): + return [e[1] for e in cls] class StaticIterable(type): - def __iter__(cls): - stack = list(cls.__dict__.items()) - stack.reverse() - for name, attr in stack: - if name == '__module__': - continue - if name.startswith('__'): - break - yield name, attr + def __iter__(cls): + stack = list(cls.__dict__.items()) + stack.reverse() + for name, attr in stack: + if name == '__module__': + continue + if name.startswith('__'): + break + yield name, attr class NameComparable(type): - def __hash__(self): - return hash(self.__name__) + def __hash__(self): + return hash(self.__name__) - def __eq__(self, other): - try: - return self.__name__ == other.__name__ - except: - print('Some sphinx bug makes this block doc calls. 
You should not see this in normal NMMO usage') + def __eq__(self, other): + return self.__name__ == other.__name__ - def __ne__(self, other): - return self.__name__ != other.__name__ + def __ne__(self, other): + return self.__name__ != other.__name__ - def __lt__(self, other): - return self.__name__ < other.__name__ + def __lt__(self, other): + return self.__name__ < other.__name__ - def __le__(self, other): - return self.__name__ <= other.__name__ + def __le__(self, other): + return self.__name__ <= other.__name__ - def __gt__(self, other): - return self.__name__ > other.__name__ + def __gt__(self, other): + return self.__name__ > other.__name__ - def __ge__(self, other): - return self.__name__ >= other.__name__ + def __ge__(self, other): + return self.__name__ >= other.__name__ class IterableNameComparable(Iterable, NameComparable): - pass + pass def seed(): - return int(np.random.randint(0, 2**32)) + return int(np.random.randint(0, 2**32)) def linf(pos1, pos2): - r1, c1 = pos1 - r2, c2 = pos2 - return max(abs(r1 - r2), abs(c1 - c2)) + # pos could be a single (r,c) or a vector of (r,c)s + diff = np.abs(np.array(pos1) - np.array(pos2)) + return np.max(diff, axis=len(diff.shape)-1) #Bounds checker -def inBounds(r, c, shape, border=0): - R, C = shape - return ( - r > border and - c > border and - r < R - border and - c < C - border - ) +def in_bounds(r, c, shape, border=0): + R, C = shape + return ( + r > border and + c > border and + r < R - border and + c < C - border + ) diff --git a/nmmo/overlay.py b/nmmo/overlay.py deleted file mode 100644 index e269dc9d1..000000000 --- a/nmmo/overlay.py +++ /dev/null @@ -1,171 +0,0 @@ -from pdb import set_trace as T -import numpy as np - -from nmmo.lib import overlay -from nmmo.lib.colors import Neon -from nmmo.systems import combat - - -class OverlayRegistry: - def __init__(self, config, realm): - '''Manager class for overlays - - Args: - config: A Config object - realm: An environment - ''' - self.initialized = False - - 
self.config = config - self.realm = realm - - self.overlays = { - 'counts': Counts, - 'skills': Skills, - 'wilderness': Wilderness} - - - def init(self, *args): - self.initialized = True - for cmd, overlay in self.overlays.items(): - self.overlays[cmd] = overlay(self.config, self.realm, *args) - return self - - def step(self, obs, pos, cmd): - '''Per-tick overlay updates - - Args: - obs: Observation returned by the environment - pos: Client camera focus position - cmd: User command returned by the client - ''' - if not self.initialized: - self.init() - - self.realm.overlayPos = pos - for overlay in self.overlays.values(): - overlay.update(obs) - - if cmd in self.overlays: - self.overlays[cmd].register(obs) - -class Overlay: - '''Define a overlay for visualization in the client - - Overlays are color images of the same size as the game map. - They are rendered over the environment with transparency and - can be used to gain insight about agent behaviors.''' - def __init__(self, config, realm, *args): - ''' - Args: - config: A Config object - realm: An environment - ''' - self.config = config - self.realm = realm - - self.size = config.MAP_SIZE - self.values = np.zeros((self.size, self.size)) - - def update(self, obs): - '''Compute per-tick updates to this overlay. Override per overlay. - - Args: - obs: Observation returned by the environment - ''' - pass - - def register(self): - '''Compute the overlay and register it within realm. 
Override per overlay.''' - pass - -class Skills(Overlay): - def __init__(self, config, realm, *args): - '''Indicates whether agents specialize in foraging or combat''' - super().__init__(config, realm) - self.nSkills = 2 - - self.values = np.zeros((self.size, self.size, self.nSkills)) - - def update(self, obs): - '''Computes a count-based exploration map by painting - tiles as agents walk over them''' - for entID, agent in self.realm.realm.players.items(): - r, c = agent.base.pos - - skillLvl = (agent.skills.food.level.val + agent.skills.water.level.val)/2.0 - combatLvl = combat.level(agent.skills) - - if skillLvl == 10 and combatLvl == 3: - continue - - self.values[r, c, 0] = skillLvl - self.values[r, c, 1] = combatLvl - - def register(self, obs): - values = np.zeros((self.size, self.size, self.nSkills)) - for idx in range(self.nSkills): - ary = self.values[:, :, idx] - vals = ary[ary != 0] - mean = np.mean(vals) - std = np.std(vals) - if std == 0: - std = 1 - - values[:, :, idx] = (ary - mean) / std - values[ary == 0] = 0 - - colors = np.array([Neon.BLUE.rgb, Neon.BLOOD.rgb]) - colorized = np.zeros((self.size, self.size, 3)) - amax = np.argmax(values, -1) - - for idx in range(self.nSkills): - colorized[amax == idx] = colors[idx] / 255 - colorized[values[:, :, idx] == 0] = 0 - - self.realm.register(colorized) - -class Counts(Overlay): - def __init__(self, config, realm, *args): - super().__init__(config, realm) - self.values = np.zeros((self.size, self.size, config.PLAYER_POLICIES)) - - def update(self, obs): - '''Computes a count-based exploration map by painting - tiles as agents walk over them''' - for entID, agent in self.realm.realm.players.items(): - pop = agent.base.population.val - r, c = agent.base.pos - self.values[r, c][pop] += 1 - - def register(self, obs): - colors = self.realm.realm.players.palette.colors - colors = np.array([colors[pop].rgb - for pop in range(self.config.PLAYER_POLICIES)]) - - colorized = self.values[:, :, :, None] * colors / 255 - 
colorized = np.sum(colorized, -2) - countSum = np.sum(self.values[:, :], -1) - data = overlay.norm(countSum)[..., None] - - countSum[countSum==0] = 1 - colorized = colorized * data / countSum[..., None] - - self.realm.register(colorized) - -class Wilderness(Overlay): - def init(self): - '''Computes the local wilderness level''' - data = np.zeros((self.size, self.size)) - for r in range(self.size): - for c in range(self.size): - data[r, c] = combat.wilderness(self.config, (r, c)) - - self.wildy = overlay.twoTone(data, preprocess='clip', invert=True, periods=5) - - def register(self, obs): - if not hasattr(self, 'wildy'): - print('Initializing Wilderness') - self.init() - - self.realm.register(self.wildy) diff --git a/nmmo/render/__init__.py b/nmmo/render/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/nmmo/render/overlay.py b/nmmo/render/overlay.py new file mode 100644 index 000000000..3b92a21f8 --- /dev/null +++ b/nmmo/render/overlay.py @@ -0,0 +1,153 @@ +import numpy as np + +from nmmo.lib.colors import Neon +from nmmo.systems import combat + +from .render_utils import normalize + +# pylint: disable=unused-argument +class OverlayRegistry: + def __init__(self, realm, renderer): + '''Manager class for overlays + + Args: + config: A Config object + realm: An environment + ''' + self.initialized = False + + self.realm = realm + self.config = realm.config + self.renderer = renderer + + self.overlays = { + #'counts': Counts, # TODO: change population to team + 'skills': Skills} + + def init(self, *args): + self.initialized = True + for cmd, overlay in self.overlays.items(): + self.overlays[cmd] = overlay(self.config, self.realm, self.renderer, *args) + return self + + def step(self, cmd): + '''Per-tick overlay updates + + Args: + cmd: User command returned by the client + ''' + if not self.initialized: + self.init() + + for overlay in self.overlays.values(): + overlay.update() + + if cmd in self.overlays: + self.overlays[cmd].register() + + 
+class Overlay: + '''Define a overlay for visualization in the client + + Overlays are color images of the same size as the game map. + They are rendered over the environment with transparency and + can be used to gain insight about agent behaviors.''' + def __init__(self, config, realm, renderer, *args): + ''' + Args: + config: A Config object + realm: An environment + ''' + self.config = config + self.realm = realm + self.renderer = renderer + + self.size = config.MAP_SIZE + self.values = np.zeros((self.size, self.size)) + + def update(self): + '''Compute per-tick updates to this overlay. Override per overlay. + + Args: + obs: Observation returned by the environment + ''' + + def register(self): + '''Compute the overlay and register it within realm. Override per overlay.''' + + +class Skills(Overlay): + def __init__(self, config, realm, renderer, *args): + '''Indicates whether agents specialize in foraging or combat''' + super().__init__(config, realm, renderer) + self.num_skill = 2 + + self.values = np.zeros((self.size, self.size, self.num_skill)) + + def update(self): + '''Computes a count-based exploration map by painting + tiles as agents walk over them''' + for agent in self.realm.players.values(): + r, c = agent.pos + + skill_lvl = (agent.skills.food.level.val + agent.skills.water.level.val)/2.0 + combat_lvl = combat.level(agent.skills) + + if skill_lvl == 10 and combat_lvl == 3: + continue + + self.values[r, c, 0] = skill_lvl + self.values[r, c, 1] = combat_lvl + + def register(self): + values = np.zeros((self.size, self.size, self.num_skill)) + for idx in range(self.num_skill): + ary = self.values[:, :, idx] + vals = ary[ary != 0] + mean = np.mean(vals) + std = np.std(vals) + if std == 0: + std = 1 + + values[:, :, idx] = (ary - mean) / std + values[ary == 0] = 0 + + colors = np.array([Neon.BLUE.rgb, Neon.BLOOD.rgb]) + colorized = np.zeros((self.size, self.size, 3)) + amax = np.argmax(values, -1) + + for idx in range(self.num_skill): + colorized[amax == 
idx] = colors[idx] / 255 + colorized[values[:, :, idx] == 0] = 0 + + self.renderer.register(colorized) + + +# CHECK ME: this was based on population, so disabling it for now +# We may want this back for the team-level analysis +class Counts(Overlay): + def __init__(self, config, realm, renderer, *args): + super().__init__(config, realm, renderer) + self.values = np.zeros((self.size, self.size, config.PLAYER_POLICIES)) + + def update(self): + '''Computes a count-based exploration map by painting + tiles as agents walk over them''' + for ent_id, agent in self.realm.players.items(): + r, c = agent.pos + self.values[r, c][ent_id] += 1 + + def register(self): + colors = self.realm.players.palette.colors + colors = np.array([colors[pop].rgb + for pop in range(self.config.PLAYER_POLICIES)]) + + colorized = self.values[:, :, :, None] * colors / 255 + colorized = np.sum(colorized, -2) + count_sum = np.sum(self.values[:, :], -1) + data = normalize(count_sum)[..., None] + + count_sum[count_sum==0] = 1 + colorized = colorized * data / count_sum[..., None] + + self.renderer.register(colorized) diff --git a/nmmo/render/render_client.py b/nmmo/render/render_client.py new file mode 100644 index 000000000..e61d88083 --- /dev/null +++ b/nmmo/render/render_client.py @@ -0,0 +1,68 @@ +from __future__ import annotations +import numpy as np + +from nmmo.render import websocket +from nmmo.render.overlay import OverlayRegistry +from nmmo.render.render_utils import patch_packet + + +# Render is external to the game +class WebsocketRenderer: + def __init__(self, realm=None) -> None: + self._client = websocket.Application(realm) + self.overlay_pos = [256, 256] + + self._realm = realm + + self.overlay = None + self.registry = OverlayRegistry(realm, renderer=self) if realm else None + + self.packet = None + + def render_packet(self, packet) -> None: + packet = { + 'pos': self.overlay_pos, + 'wilderness': 0, # obsolete, but maintained for compatibility + **packet } + + self.overlay_pos, _ = 
self._client.update(packet) + + def render_realm(self) -> None: + assert self._realm is not None, 'This function requires a realm' + assert self._realm.tick is not None, 'render before reset' + + packet = { + 'config': self._realm.config, + 'pos': self.overlay_pos, + 'wilderness': 0, + **self._realm.packet() + } + + # TODO: a hack to make the client work + packet = patch_packet(packet, self._realm) + + if self.overlay is not None: + packet['overlay'] = self.overlay + self.overlay = None + + # save the packet for investigation + self.packet = packet + + # pass the packet to renderer + pos, cmd = self._client.update(self.packet) + + self.overlay_pos = pos + self.registry.step(cmd) + + def register(self, overlay: np.ndarray) -> None: + '''Register an overlay to be sent to the client + + The intended use of this function is: User types overlay -> + client sends cmd to server -> server computes overlay update -> + register(overlay) -> overlay is sent to client -> overlay rendered + + Args: + overlay: A map-sized (self.size) array of floating point values + overlay must be a numpy array of dimension (*(env.size), 3) + ''' + self.overlay = overlay.tolist() diff --git a/nmmo/render/render_utils.py b/nmmo/render/render_utils.py new file mode 100644 index 000000000..47dad7f23 --- /dev/null +++ b/nmmo/render/render_utils.py @@ -0,0 +1,86 @@ +import numpy as np +from scipy import signal + +from nmmo.lib.colors import Neon + +# NOTE: added to fix json.dumps() cannot serialize numpy objects +# pylint: disable=inconsistent-return-statements +def np_encoder(obj): + if isinstance(obj, np.generic): + return obj.item() + +def normalize(ary: np.ndarray, norm_std=2): + R, C = ary.shape + preprocessed = np.zeros_like(ary) + nonzero = ary[ary!= 0] + mean = np.mean(nonzero) + std = np.std(nonzero) + if std == 0: + std = 1 + for r in range(R): + for c in range(C): + val = ary[r, c] + if val != 0: + val = (val - mean) / (norm_std * std) + val = np.clip(val+1, 0, 2)/2 + preprocessed[r, c] = 
val + return preprocessed + +def clip(ary: np.ndarray): + R, C = ary.shape + preprocessed = np.zeros_like(ary) + nonzero = ary[ary!= 0] + mmin = np.min(nonzero) + mmag = np.max(nonzero) - mmin + for r in range(R): + for c in range(C): + val = ary[r, c] + val = (val - mmin) / mmag + preprocessed[r, c] = val + return preprocessed + +def make_two_tone(ary, norm_std=2, preprocess='norm', invert=False, periods=1): + if preprocess == 'norm': + ary = normalize(ary, norm_std) + elif preprocess == 'clip': + ary = clip(ary) + + # if preprocess not in ['norm', 'clip'], assume no preprocessing + R, C = ary.shape + + colorized = np.zeros((R, C, 3)) + if periods != 1: + ary = np.abs(signal.sawtooth(periods*3.14159*ary)) + if invert: + colorized[:, :, 0] = ary + colorized[:, :, 1] = 1-ary + else: + colorized[:, :, 0] = 1-ary + colorized[:, :, 1] = ary + + colorized *= (ary != 0)[:, :, None] + + return colorized + +# TODO: this is a hack to make the client work +# by adding color, population, self to the packet +# integrating with team helper could make this neat +def patch_packet(packet, realm): + for ent_id in packet['player']: + packet['player'][ent_id]['base']['color'] = Neon.GREEN.packet() + # EntityAttr: population was changed to npc_type + packet['player'][ent_id]['base']['population'] = 0 + # old code: nmmo.Serialized.Entity.Self, no longer being used + packet['player'][ent_id]['base']['self'] = 1 + + npc_colors = { + 1: Neon.YELLOW.packet(), # passive npcs + 2: Neon.MAGENTA.packet(), # neutral npcs + 3: Neon.BLOOD.packet() } # aggressive npcs + for ent_id in packet['npc']: + npc = realm.npcs.corporeal[ent_id] + packet['npc'][ent_id]['base']['color'] = npc_colors[int(npc.npc_type.val)] + packet['npc'][ent_id]['base']['population'] = -int(npc.npc_type.val) # note negative + packet['npc'][ent_id]['base']['self'] = 1 + + return packet diff --git a/nmmo/render/replay_helper.py b/nmmo/render/replay_helper.py new file mode 100644 index 000000000..50858ed9c --- /dev/null +++ 
b/nmmo/render/replay_helper.py @@ -0,0 +1,100 @@ +import json +import lzma +import logging + +from .render_utils import np_encoder, patch_packet + + +class ReplayHelper: + @staticmethod + def create(realm): + if realm.config.SAVE_REPLAY: + return ReplayFileHelper(realm) + + return DummyReplayHelper() + + +class DummyReplayHelper(ReplayHelper): + def reset(self): + pass + + def update(self): + pass + + def save(self, save_path, compress): + pass + + +class ReplayFileHelper(ReplayHelper): + def __init__(self, realm=None): + self._realm = realm + self.packets = None + self.map = None + self._i = 0 + + def reset(self): + self.packets = [] + self.map = None + self._i = 0 + + def __len__(self): + return len(self.packets) + + def __iter__(self): + self._i = 0 + return self + + def __next__(self): + if self._i >= len(self.packets): + raise StopIteration + packet = self.packets[self._i] + packet['environment'] = self.map + self._i += 1 + return packet + + def update(self, packet=None): + if packet is None: + if self._realm is None: + return + # TODO: patch_packet is a hack. 
best to remove, if possible + packet = patch_packet(self._realm.packet(), self._realm) + + data = {} + for key, val in packet.items(): + if key == 'environment': + self.map = val + continue + if key == 'config': + continue + data[key] = val + + self.packets.append(data) + + def save(self, save_file, compress=True): + logging.info('Saving replay to %s ...', save_file) + + data = { + 'map': self.map, + 'packets': self.packets } + + data = json.dumps(data, default=np_encoder).encode('utf8') + if compress: + data = lzma.compress(data, format=lzma.FORMAT_ALONE) + + with open(save_file, 'wb') as out: + out.write(data) + + @classmethod + def load(cls, replay_file, decompress=True): + with open(replay_file, 'rb') as fp: + data = fp.read() + + if decompress: + data = lzma.decompress(data, format=lzma.FORMAT_ALONE) + data = json.loads(data.decode('utf-8')) + + replay_helper = ReplayFileHelper() + replay_helper.map = data['map'] + replay_helper.packets = data['packets'] + + return replay_helper diff --git a/nmmo/websocket.py b/nmmo/render/websocket.py similarity index 94% rename from nmmo/websocket.py rename to nmmo/render/websocket.py index 2690980ef..3647f51e1 100644 --- a/nmmo/websocket.py +++ b/nmmo/render/websocket.py @@ -1,12 +1,15 @@ -from pdb import set_trace as T +# pylint: disable=all + import numpy as np from signal import signal, SIGINT -import sys, os, json, pickle, time +import json +import os +import sys +import time import threading from twisted.internet import reactor -from twisted.internet.task import LoopingCall from twisted.python import log from twisted.web.server import Site from twisted.web.static import File @@ -15,6 +18,8 @@ WebSocketServerProtocol from autobahn.twisted.resource import WebSocketResource +from .render_utils import np_encoder + class GodswordServerProtocol(WebSocketServerProtocol): def __init__(self): super().__init__() @@ -76,7 +81,7 @@ def sendUpdate(self, data): packet['pos'] = data['pos'] packet['wilderness'] = data['wilderness'] 
packet['market'] = data['market'] - + print('Is Connected? : {}'.format(self.isConnected)) if not self.sent_environment: packet['map'] = data['environment'] @@ -88,9 +93,10 @@ def sendUpdate(self, data): packet['overlay'] = data['overlay'] print('SENDING OVERLAY: ', len(packet['overlay'])) - packet = json.dumps(packet).encode('utf8') + packet = json.dumps(packet, default=np_encoder).encode('utf8') self.sendMessage(packet, False) + class WSServerFactory(WebSocketServerFactory): def __init__(self, ip, realm): super().__init__(ip) @@ -108,7 +114,7 @@ def update(self, packet): uptime = np.round(self.tickRate*self.tick, 1) delta = time.time() - self.time print('Wall Clock: ', str(delta)[:5], 'Uptime: ', uptime, ', Tick: ', self.tick) - delta = self.tickRate - delta + delta = self.tickRate - delta if delta > 0: time.sleep(delta) self.time = time.time() @@ -134,7 +140,7 @@ def __init__(self, realm): port = 8080 self.factory = WSServerFactory(u'ws://localhost:{}'.format(port), realm) - self.factory.protocol = GodswordServerProtocol + self.factory.protocol = GodswordServerProtocol resource = WebSocketResource(self.factory) root = File(".") diff --git a/nmmo/scripting.py b/nmmo/scripting.py deleted file mode 100644 index a234e00a9..000000000 --- a/nmmo/scripting.py +++ /dev/null @@ -1,53 +0,0 @@ -from pdb import set_trace as T - -class Observation: - '''Unwraps observation tensors for use with scripted agents''' - def __init__(self, config, obs): - ''' - Args: - config: A forge.blade.core.Config object or subclass object - obs: An observation object from the environment - ''' - self.config = config - self.obs = obs - self.delta = config.PLAYER_VISION_RADIUS - self.tiles = self.obs['Tile']['Continuous'] - - n = int(self.obs['Entity']['N']) - self.agents = self.obs['Entity']['Continuous'][:n] - self.n = n - - if config.ITEM_SYSTEM_ENABLED: - n = int(self.obs['Item']['N']) - self.items = self.obs['Item']['Continuous'][:n] - - if config.EXCHANGE_SYSTEM_ENABLED: - n = 
int(self.obs['Market']['N']) - self.market = self.obs['Market']['Continuous'][:n] - - def tile(self, rDelta, cDelta): - '''Return the array object corresponding to a nearby tile - - Args: - rDelta: row offset from current agent - cDelta: col offset from current agent - - Returns: - Vector corresponding to the specified tile - ''' - return self.tiles[self.config.PLAYER_VISION_DIAMETER * (self.delta + rDelta) + self.delta + cDelta] - - @property - def agent(self): - '''Return the array object corresponding to the current agent''' - return self.agents[0] - - @staticmethod - def attribute(ary, attr): - '''Return an attribute of a game object - - Args: - ary: The array corresponding to a game object - attr: A forge.blade.io.stimulus.static stimulus class - ''' - return float(ary[attr.index]) diff --git a/nmmo/systems/achievement.py b/nmmo/systems/achievement.py index 15f972bf0..08eddfc4c 100644 --- a/nmmo/systems/achievement.py +++ b/nmmo/systems/achievement.py @@ -1,54 +1,41 @@ -from pdb import set_trace as T -from typing import Callable -from dataclasses import dataclass +from typing import List +from nmmo.lib.task import Task -@dataclass -class Task: - condition: Callable - target: float = None - reward: float = 0 - - -class Diary: - def __init__(self, tasks): - self.achievements = [] - for task in tasks: - self.achievements.append(Achievement(task.condition, task.target, task.reward)) - - @property - def completed(self): - return sum(a.completed for a in self.achievements) - - @property - def cumulative_reward(self, aggregate=True): - return sum(a.reward * a.completed for a in self.achievements) - - def update(self, realm, entity): - return {a.name: a.update(realm, entity) for a in self.achievements} +class Achievement: + def __init__(self, task: Task, reward: float): + self.completed = False + self.task = task + self.reward = reward + @property + def name(self): + return self.task.to_string() -class Achievement: - def __init__(self, condition, target, reward): - 
self.completed = False + def update(self, realm, entity): + if self.completed: + return 0 - self.condition = condition - self.target = target - self.reward = reward + if self.task.completed(realm, entity): + self.completed = True + return self.reward - @property - def name(self): - return '{}_{}'.format(self.condition.__name__, self.target) + return 0 - def update(self, realm, entity): - if self.completed: - return 0 +class Diary: + def __init__(self, agent, achievements: List[Achievement]): + self.agent = agent + self.achievements = achievements + self.rewards = {} - metric = self.condition(realm, entity) + @property + def completed(self): + return sum(a.completed for a in self.achievements) - if metric >= self.target: - self.completed = True - return self.reward + @property + def cumulative_reward(self): + return sum(a.reward * a.completed for a in self.achievements) - return 0 + def update(self, realm): + self.rewards = { a.name: a.update(realm, self.agent) for a in self.achievements } diff --git a/nmmo/systems/ai/__init__.py b/nmmo/systems/ai/__init__.py index 8f5b0d1aa..5c46b3697 100644 --- a/nmmo/systems/ai/__init__.py +++ b/nmmo/systems/ai/__init__.py @@ -1 +1,2 @@ -from . import utils, move, attack, behavior, policy +# pylint: disable=import-self +from . 
import utils, move, behavior, policy diff --git a/nmmo/systems/ai/attack.py b/nmmo/systems/ai/attack.py deleted file mode 100644 index 696d9b4c3..000000000 --- a/nmmo/systems/ai/attack.py +++ /dev/null @@ -1,3 +0,0 @@ -from pdb import set_trace as T -from nmmo.systems.ai import utils - diff --git a/nmmo/systems/ai/behavior.py b/nmmo/systems/ai/behavior.py index 31df58e81..85cbf1c26 100644 --- a/nmmo/systems/ai/behavior.py +++ b/nmmo/systems/ai/behavior.py @@ -1,8 +1,9 @@ -from pdb import set_trace as T +# pylint: disable=all + import numpy as np import nmmo -from nmmo.systems.ai import move, attack, utils +from nmmo.systems.ai import move, utils def update(entity): '''Update validity of tracked entities''' @@ -20,7 +21,7 @@ def update(entity): entity.food = None if not utils.validResource(entity, entity.water, entity.vision): entity.water = None - + def pathfind(realm, actions, entity, target): actions[nmmo.action.Move] = {nmmo.action.Direction: move.pathfind(realm.map.tiles, entity, target)} @@ -40,26 +41,6 @@ def explore(realm, actions, entity): tile = realm.map.tiles[rr, cc] pathfind(realm, actions, entity, tile) -def explore(config, ob, actions, spawnR, spawnC): - vision = config.NSTIM - sz = config.TERRAIN_SIZE - Entity = nmmo.Serialized.Entity - Tile = nmmo.Serialized.Tile - - agent = ob.agent - r = utils.Observation.attribute(agent, Entity.R) - c = utils.Observation.attribute(agent, Entity.C) - - centR, centC = sz//2, sz//2 - - vR, vC = centR-spawnR, centC-spawnC - - mmag = max(abs(vR), abs(vC)) - rr = int(np.round(vision*vR/mmag)) - cc = int(np.round(vision*vC/mmag)) - - pathfind(config, ob, actions, rr, cc) - def meander(realm, actions, entity): actions[nmmo.action.Move] = {nmmo.action.Direction: move.habitable(realm.map.tiles, entity)} @@ -72,7 +53,7 @@ def hunt(realm, actions, entity): direction = None if distance == 0: - direction = move.random() + direction = move.random_direction() elif distance > 1: direction = move.pathfind(realm.map.tiles, entity, 
entity.target) @@ -83,7 +64,7 @@ def hunt(realm, actions, entity): def attack(realm, actions, entity): distance = utils.lInfty(entity.pos, entity.target.pos) - if distance > entity.skills.style.attackRange(realm.config): + if distance > entity.skills.style.attack_range(realm.config): return actions[nmmo.action.Attack] = { diff --git a/nmmo/systems/ai/dynamic_programming.py b/nmmo/systems/ai/dynamic_programming.py deleted file mode 100644 index facc059f3..000000000 --- a/nmmo/systems/ai/dynamic_programming.py +++ /dev/null @@ -1,135 +0,0 @@ -from typing import List - -#from forge.blade.core import material -from nmmo.systems import ai - -import math - -import numpy as np - - -def map_to_rewards(tiles, entity) -> List[List[float]]: - lava_reward = stone_reward = water_reward = float('-inf') - forest_reward = 1.0 + math.pow( - (1 - entity.resources.food.val / entity.resources.food.max) * 15.0, - 1.25) - scrub_reward = 1.0 - around_water_reward = 1.0 + math.pow( - (1 - entity.resources.water.val / entity.resources.water.max) * 15.0, - 1.25) - - reward_matrix = np.full((len(tiles), len(tiles[0])), 0.0) - - for line in range(len(tiles)): - tile_line = tiles[line] - for column in range(len(tile_line)): - tile_val = tile_line[column].state.tex - if tile_val == 'lava': - reward_matrix[line][column] += lava_reward - - if tile_val == 'stone': - reward_matrix[line][column] += stone_reward - - if tile_val == 'forest': - reward_matrix[line][column] += forest_reward - - if tile_val == 'water': - reward_matrix[line][column] += water_reward - - #TODO: Make these comparisons work off of the water Enum type - #instead of string compare - if 'water' in ai.utils.adjacentMats(tiles, (line, column)): - reward_matrix[line][column] += around_water_reward - - if tile_val == 'scrub': - reward_matrix[line][column] += scrub_reward - - return reward_matrix - - -def compute_values(reward_matrix: List[List[float]]) -> List[List[float]]: - gamma_factor = 0.8 # look ahead ∈ [0, 1] - max_delta = 
0.01 # maximum allowed approximation - - value_matrix = np.full((len(reward_matrix), len(reward_matrix[0])), 0.0) - - delta = float('inf') - while delta > max_delta: - old_value_matrix = np.copy(value_matrix) - for line in range(len(reward_matrix)): - for column in range(len(reward_matrix[0])): - reward = reward_matrix[line][column] - value_matrix[line][ - column] = reward + gamma_factor * max_value_around( - (line, column), value_matrix) - - delta = np.amax( - np.abs(np.subtract(old_value_matrix, value_matrix))) - return value_matrix - - -def values_around(position: (int, int), value_matrix: List[List[float]]) -> ( - float, float, float, float): - line, column = position - - if line - 1 >= 0: - top_value = value_matrix[line - 1][column] - else: - top_value = float('-inf') - - if line + 1 < len(value_matrix): - bottom_value = value_matrix[line + 1][column] - else: - bottom_value = float('-inf') - - if column - 1 >= 0: - left_value = value_matrix[line][column - 1] - else: - left_value = float('-inf') - - if column + 1 < len(value_matrix[0]): - right_value = value_matrix[line][column + 1] - else: - right_value = float('-inf') - - return top_value, bottom_value, left_value, right_value - - -def max_value_around(position: (int, int), - value_matrix: List[List[float]]) -> float: - return max(values_around(position, value_matrix)) - - -def max_value_position_around(position: (int, int), - value_matrix: List[List[float]]) -> (int, int): - line, column = position - top_value, bottom_value, left_value, right_value = values_around(position, - value_matrix) - - max_value = max(top_value, bottom_value, left_value, right_value) - - if max_value == top_value: - return line - 1, column - elif max_value == bottom_value: - return line + 1, column - elif max_value == left_value: - return line, column - 1 - elif max_value == right_value: - return line, column + 1 - - -def max_value_direction_around(position: (int, int), - value_matrix: List[List[float]]) -> (int, int): - top_value, 
bottom_value, left_value, right_value = values_around(position, - value_matrix) - - max_value = max(top_value, bottom_value, left_value, right_value) - - if max_value == top_value: - return -1, 0 - elif max_value == bottom_value: - return 1, 0 - elif max_value == left_value: - return 0, -1 - elif max_value == right_value: - return 0, 1 diff --git a/nmmo/systems/ai/move.py b/nmmo/systems/ai/move.py index a8240349e..8b1968bd8 100644 --- a/nmmo/systems/ai/move.py +++ b/nmmo/systems/ai/move.py @@ -1,71 +1,68 @@ -from pdb import set_trace as T -import numpy as np -import random# as rand +# pylint: disable=R0401 -import nmmo +import random + +from nmmo.io import action from nmmo.systems.ai import utils -def rand(): - return random.choice(nmmo.action.Direction.edges) -def randomSafe(tiles, ent): - r, c = ent.base.pos - cands = [] - if not tiles[r-1, c].lava: - cands.append(nmmo.action.North) - if not tiles[r+1, c].lava: - cands.append(nmmo.action.South) - if not tiles[r, c-1].lava: - cands.append(nmmo.action.West) - if not tiles[r, c+1].lava: - cands.append(nmmo.action.East) - - return rand.choice(cands) +def random_direction(): + return random.choice(action.Direction.edges) + +def random_safe(tiles, ent): + r, c = ent.pos + cands = [] + if not tiles[r-1, c].lava: + cands.append(action.North) + if not tiles[r+1, c].lava: + cands.append(action.South) + if not tiles[r, c-1].lava: + cands.append(action.West) + if not tiles[r, c+1].lava: + cands.append(action.East) + + return random.choice(cands) def habitable(tiles, ent): - r, c = ent.base.pos - cands = [] - if tiles[r-1, c].vacant: - cands.append(nmmo.action.North) - if tiles[r+1, c].vacant: - cands.append(nmmo.action.South) - if tiles[r, c-1].vacant: - cands.append(nmmo.action.West) - if tiles[r, c+1].vacant: - cands.append(nmmo.action.East) - - if len(cands) == 0: - return nmmo.action.North + r, c = ent.pos + cands = [] + if tiles[r-1, c].habitable: + cands.append(action.North) + if tiles[r+1, c].habitable: + 
cands.append(action.South) + if tiles[r, c-1].habitable: + cands.append(action.West) + if tiles[r, c+1].habitable: + cands.append(action.East) + + if len(cands) == 0: + return action.North - return random.choice(cands) + return random.choice(cands) def towards(direction): - if direction == (-1, 0): - return nmmo.action.North - elif direction == (1, 0): - return nmmo.action.South - elif direction == (0, -1): - return nmmo.action.West - elif direction == (0, 1): - return nmmo.action.East - else: - return rand() + if direction == (-1, 0): + return action.North + if direction == (1, 0): + return action.South + if direction == (0, -1): + return action.West + if direction == (0, 1): + return action.East + + return random.choice(action.Direction.edges) def bullrush(ent, targ): - direction = utils.directionTowards(ent, targ) - return towards(direction) + direction = utils.directionTowards(ent, targ) + return towards(direction) def pathfind(tiles, ent, targ): - direction = utils.aStar(tiles, ent.pos, targ.pos) - return towards(direction) + direction = utils.aStar(tiles, ent.pos, targ.pos) + return towards(direction) def antipathfind(tiles, ent, targ): - er, ec = ent.pos - tr, tc = targ.pos - goal = (2*er - tr , 2*ec-tc) - direction = utils.aStar(tiles, ent.pos, goal) - return towards(direction) - - - - + er, ec = ent.pos + tr, tc = targ.pos + goal = (2*er - tr , 2*ec-tc) + direction = utils.aStar(tiles, ent.pos, goal) + return towards(direction) diff --git a/nmmo/systems/ai/policy.py b/nmmo/systems/ai/policy.py index 4c57ce5f1..ff5b6642e 100644 --- a/nmmo/systems/ai/policy.py +++ b/nmmo/systems/ai/policy.py @@ -1,39 +1,38 @@ -from pdb import set_trace as T from nmmo.systems.ai import behavior, utils def passive(realm, entity): - behavior.update(entity) - actions = {} + behavior.update(entity) + actions = {} - behavior.meander(realm, actions, entity) + behavior.meander(realm, actions, entity) - return actions + return actions def neutral(realm, entity): - 
behavior.update(entity) - actions = {} + behavior.update(entity) + actions = {} - if not entity.attacker: - behavior.meander(realm, actions, entity) - else: - entity.target = entity.attacker - behavior.hunt(realm, actions, entity) + if not entity.attacker: + behavior.meander(realm, actions, entity) + else: + entity.target = entity.attacker + behavior.hunt(realm, actions, entity) - return actions + return actions def hostile(realm, entity): - behavior.update(entity) - actions = {} + behavior.update(entity) + actions = {} - # This is probably slow - if not entity.target: - entity.target = utils.closestTarget(entity, realm.map.tiles, - rng=entity.vision) + # This is probably slow + if not entity.target: + entity.target = utils.closestTarget(entity, realm.map.tiles, + rng=entity.vision) - if not entity.target: - behavior.meander(realm, actions, entity) - else: - behavior.hunt(realm, actions, entity) + if not entity.target: + behavior.meander(realm, actions, entity) + else: + behavior.hunt(realm, actions, entity) - return actions + return actions diff --git a/nmmo/systems/ai/utils.py b/nmmo/systems/ai/utils.py index 8a51a949a..406476c93 100644 --- a/nmmo/systems/ai/utils.py +++ b/nmmo/systems/ai/utils.py @@ -1,14 +1,13 @@ -from pdb import set_trace as T +# pylint: disable=all + + +import heapq +from typing import Tuple + import numpy as np -import random -from nmmo.lib.utils import inBounds -from nmmo.systems import combat -from nmmo.lib import material -from queue import PriorityQueue, Queue +from nmmo.lib.utils import in_bounds -from nmmo.systems.ai.dynamic_programming import map_to_rewards, \ - compute_values, max_value_direction_around def validTarget(ent, targ, rng): if targ is None or not targ.alive: @@ -24,8 +23,8 @@ def validResource(ent, tile, rng): def directionTowards(ent, targ): - sr, sc = ent.base.pos - tr, tc = targ.base.pos + sr, sc = ent.pos + tr, tc = targ.pos if abs(sc - tc) > abs(sr - tr): direction = (0, np.sign(tc - sc)) @@ -36,23 +35,24 @@ def 
directionTowards(ent, targ): def closestTarget(ent, tiles, rng=1): - sr, sc = ent.base.pos + sr, sc = ent.pos for d in range(rng+1): for r in range(-d, d+1): - for e in tiles[sr+r, sc-d].ents.values(): + for e in tiles[sr+r, sc-d].entities.values(): if e is not ent and validTarget(ent, e, rng): return e - for e in tiles[sr + r, sc + d].ents.values(): + for e in tiles[sr + r, sc + d].entities.values(): if e is not ent and validTarget(ent, e, rng): return e - for e in tiles[sr - d, sc + r].ents.values(): + for e in tiles[sr - d, sc + r].entities.values(): if e is not ent and validTarget(ent, e, rng): return e - for e in tiles[sr + d, sc + r].ents.values(): + for e in tiles[sr + d, sc + r].entities.values(): if e is not ent and validTarget(ent, e, rng): return e def distance(ent, targ): - return l1(ent.pos, targ.pos) + # used in scripted/behavior.py, attack() to determine attack range + return lInfty(ent.pos, targ.pos) def lInf(ent, targ): sr, sc = ent.pos @@ -65,59 +65,12 @@ def adjacentPos(pos): return [(r - 1, c), (r, c - 1), (r + 1, c), (r, c + 1)] -def cropTilesAround(position: (int, int), horizon: int, tiles): +def cropTilesAround(position: Tuple[int, int], horizon: int, tiles): line, column = position return tiles[max(line - horizon, 0): min(line + horizon + 1, len(tiles)), max(column - horizon, 0): min(column + horizon + 1, len(tiles[0]))] - -def inSight(dr, dc, vision): - return ( - dr >= -vision and - dc >= -vision and - dr <= vision and - dc <= vision) - -def vacant(tile): - from nmmo.io.stimulus.static import Stimulus - Tile = Stimulus.Tile - occupied = Observation.attribute(tile, Tile.NEnts) - matl = Observation.attribute(tile, Tile.Index) - - lava = material.Lava.index - water = material.Water.index - grass = material.Grass.index - scrub = material.Scrub.index - forest = material.Forest.index - stone = material.Stone.index - orerock = material.Orerock.index - - return matl in (grass, scrub, forest) and not occupied - -def meander(obs): - from 
nmmo.io.stimulus.static import Stimulus - - agent = obs.agent - Entity = Stimulus.Entity - Tile = Stimulus.Tile - - r = Observation.attribute(agent, Entity.R) - c = Observation.attribute(agent, Entity.C) - - cands = [] - if vacant(obs.tile(-1, 0)): - cands.append((-1, 0)) - if vacant(obs.tile(1, 0)): - cands.append((1, 0)) - if vacant(obs.tile(0, -1)): - cands.append((0, -1)) - if vacant(obs.tile(0, 1)): - cands.append((0, 1)) - if not cands: - return (-1, 0) - return random.choice(cands) - # A* Search def l1(start, goal): sr, sc = start @@ -139,8 +92,7 @@ def aStar(tiles, start, goal, cutoff=100): if start == goal: return (0, 0) - pq = PriorityQueue() - pq.put((0, start)) + pq = [(0, start)] backtrace = {} cost = {start: 0} @@ -149,7 +101,7 @@ def aStar(tiles, start, goal, cutoff=100): closestHeuristic = l1(start, goal) closestCost = closestHeuristic - while not pq.empty(): + while pq: # Use approximate solution if budget exhausted cutoff -= 1 if cutoff <= 0: @@ -157,15 +109,13 @@ def aStar(tiles, start, goal, cutoff=100): goal = closestPos break - priority, cur = pq.get() + priority, cur = heapq.heappop(pq) if cur == goal: break for nxt in adjacentPos(cur): - if not inBounds(*nxt, tiles.shape): - continue - if tiles[nxt].occupied: + if not in_bounds(*nxt, tiles.shape): continue newCost = cost[cur] + 1 @@ -181,7 +131,7 @@ def aStar(tiles, start, goal, cutoff=100): closestHeuristic = heuristic closestCost = priority - pq.put((priority, nxt)) + heapq.heappush(pq, (priority, nxt)) backtrace[nxt] = cur while goal in backtrace and backtrace[goal] != start: @@ -195,7 +145,7 @@ def aStar(tiles, start, goal, cutoff=100): # Adjacency functions def adjacentTiles(tiles, ent): - r, c = ent.base.pos + r, c = ent.pos def adjacentDeltas(): @@ -216,17 +166,17 @@ def posSum(pos1, pos2): def adjacentEmptyPos(env, pos): return [p for p in adjacentPos(pos) - if inBounds(*p, env.size)] + if in_bounds(*p, env.size)] def adjacentTiles(env, pos): return [env.tiles[p] for p in 
adjacentPos(pos) - if inBounds(*p, env.size)] + if in_bounds(*p, env.size)] def adjacentMats(tiles, pos): return [type(tiles[p].state) for p in adjacentPos(pos) - if inBounds(*p, tiles.shape)] + if in_bounds(*p, tiles.shape)] def adjacencyDelMatPairs(env, pos): diff --git a/nmmo/systems/combat.py b/nmmo/systems/combat.py index 275cf3bf3..1666feed1 100644 --- a/nmmo/systems/combat.py +++ b/nmmo/systems/combat.py @@ -1,146 +1,162 @@ #Various utilities for managing combat, including hit/damage -from pdb import set_trace as T - import numpy as np -import logging from nmmo.systems import skill as Skill -from nmmo.systems import item as Item +from nmmo.lib.log import EventCode def level(skills): - return max(e.level.val for e in skills.skills) + return max(e.level.val for e in skills.skills) def damage_multiplier(config, skill, targ): - skills = [targ.skills.melee, targ.skills.range, targ.skills.mage] - exp = [s.exp for s in skills] + skills = [targ.skills.melee, targ.skills.range, targ.skills.mage] + exp = [s.exp for s in skills] - if max(exp) == min(exp): - return 1.0 + if max(exp) == min(exp): + return 1.0 - idx = np.argmax([exp]) - targ = skills[idx] + idx = np.argmax([exp]) + targ = skills[idx] - if type(skill) == targ.weakness: - return config.COMBAT_WEAKNESS_MULTIPLIER + if isinstance(skill, targ.weakness): + return config.COMBAT_WEAKNESS_MULTIPLIER - return 1.0 + return 1.0 -def attack(realm, player, target, skillFn): - config = player.config - skill = skillFn(player) - skill_type = type(skill) - skill_name = skill_type.__name__ +# pylint: disable=unnecessary-lambda-assignment +def attack(realm, player, target, skill_fn): + config = player.config + skill = skill_fn(player) + skill_type = type(skill) + skill_name = skill_type.__name__ - # Attacker and target levels - player_level = skill.level.val - target_level = level(target.skills) + # Per-style offense/defense + level_damage = 0 + if skill_type == Skill.Melee: + base_damage = config.COMBAT_MELEE_DAMAGE - # 
Ammunition usage - ammunition = player.equipment.ammunition - if ammunition is not None: - ammunition.fire(player) - - # Per-style offense/defense - level_damage = 0 - if skill_type == Skill.Melee: - base_damage = config.COMBAT_MELEE_DAMAGE - - if config.PROGRESSION_SYSTEM_ENABLED: - base_damage = config.PROGRESSION_MELEE_BASE_DAMAGE - level_damage = config.PROGRESSION_MELEE_LEVEL_DAMAGE - - offense_fn = lambda e: e.melee_attack - defense_fn = lambda e: e.melee_defense - elif skill_type == Skill.Range: - base_damage = config.COMBAT_RANGE_DAMAGE - - if config.PROGRESSION_SYSTEM_ENABLED: - base_damage = config.PROGRESSION_RANGE_BASE_DAMAGE - level_damage = config.PROGRESSION_RANGE_LEVEL_DAMAGE - - offense_fn = lambda e: e.range_attack - defense_fn = lambda e: e.range_defense - elif skill_type == Skill.Mage: - base_damage = config.COMBAT_MAGE_DAMAGE - - if config.PROGRESSION_SYSTEM_ENABLED: - base_damage = config.PROGRESSION_MAGE_BASE_DAMAGE - level_damage = config.PROGRESSION_MAGE_LEVEL_DAMAGE - - offense_fn = lambda e: e.mage_attack - defense_fn = lambda e: e.mage_defense - elif __debug__: - assert False, 'Attack skill must be Melee, Range, or Mage' - - # Compute modifiers - multiplier = damage_multiplier(config, skill, target) - skill_offense = base_damage + level_damage * skill.level.val - skill_defense = config.PROGRESSION_BASE_DEFENSE + config.PROGRESSION_LEVEL_DEFENSE*level(target.skills) + if config.PROGRESSION_SYSTEM_ENABLED: + base_damage = config.PROGRESSION_MELEE_BASE_DAMAGE + level_damage = config.PROGRESSION_MELEE_LEVEL_DAMAGE + + offense_fn = lambda e: e.melee_attack + defense_fn = lambda e: e.melee_defense + + elif skill_type == Skill.Range: + base_damage = config.COMBAT_RANGE_DAMAGE + + if config.PROGRESSION_SYSTEM_ENABLED: + base_damage = config.PROGRESSION_RANGE_BASE_DAMAGE + level_damage = config.PROGRESSION_RANGE_LEVEL_DAMAGE + + offense_fn = lambda e: e.range_attack + defense_fn = lambda e: e.range_defense + + elif skill_type == Skill.Mage: + 
base_damage = config.COMBAT_MAGE_DAMAGE + + if config.PROGRESSION_SYSTEM_ENABLED: + base_damage = config.PROGRESSION_MAGE_BASE_DAMAGE + level_damage = config.PROGRESSION_MAGE_LEVEL_DAMAGE + + offense_fn = lambda e: e.mage_attack + defense_fn = lambda e: e.mage_defense + + elif __debug__: + assert False, 'Attack skill must be Melee, Range, or Mage' + + # Compute modifiers + multiplier = damage_multiplier(config, skill, target) + skill_offense = base_damage + level_damage * skill.level.val + + if config.PROGRESSION_SYSTEM_ENABLED: + skill_defense = config.PROGRESSION_BASE_DEFENSE + \ + config.PROGRESSION_LEVEL_DEFENSE*level(target.skills) + else: + skill_defense = 0 + + if config.EQUIPMENT_SYSTEM_ENABLED: equipment_offense = player.equipment.total(offense_fn) equipment_defense = target.equipment.total(defense_fn) - # Total damage calculation - offense = skill_offense + equipment_offense - defense = skill_defense + equipment_defense - damage = config.COMBAT_DAMAGE_FORMULA(offense, defense, multiplier) - #damage = multiplier * (offense - defense) - damage = max(int(damage), 0) + # after tallying ammo damage, consume ammo (i.e., fire) + ammunition = player.equipment.ammunition.item + if ammunition is not None: + ammunition.fire(player) + + else: + equipment_offense = 0 + equipment_defense = 0 + + # Total damage calculation + offense = skill_offense + equipment_offense + defense = skill_defense + equipment_defense + damage = config.COMBAT_DAMAGE_FORMULA(offense, defense, multiplier) + #damage = multiplier * (offense - defense) + damage = max(int(damage), 0) + + if player.is_player: + equipment_level_offense = 0 + equipment_level_defense = 0 + if config.EQUIPMENT_SYSTEM_ENABLED: + equipment_level_offense = player.equipment.total(lambda e: e.level) + equipment_level_defense = target.equipment.total(lambda e: e.level) + + realm.event_log.record(EventCode.SCORE_HIT, player, + combat_style=skill_type, damage=damage) - if config.LOG_MILESTONES and player.isPlayer and 
realm.quill.milestone.log_max(f'Damage_{skill_name}', damage) and config.LOG_VERBOSE: - player_ilvl = player.equipment.total(lambda e: e.level) - target_ilvl = target.equipment.total(lambda e: e.level) + realm.log_milestone(f'Damage_{skill_name}', damage, + f'COMBAT: Inflicted {damage} {skill_name} damage ' + + f'(attack equip lvl {equipment_level_offense} vs ' + + f'defense equip lvl {equipment_level_defense})', + tags={"player_id": player.ent_id}) - logging.info(f'COMBAT: Inflicted {damage} {skill_name} damage (lvl {player_level} i{player_ilvl} vs lvl {target_level} i{target_ilvl})') + player.apply_damage(damage, skill.__class__.__name__.lower()) + target.receive_damage(player, damage) - player.applyDamage(damage, skill.__class__.__name__.lower()) - target.receiveDamage(player, damage) + return damage - return damage -def danger(config, pos, full=False): - border = config.MAP_BORDER - center = config.MAP_CENTER - r, c = pos - - #Distance from border - rDist = min(r - border, center + border - r - 1) - cDist = min(c - border, center + border - c - 1) - dist = min(rDist, cDist) - norm = 2 * dist / center +def danger(config, pos): + border = config.MAP_BORDER + center = config.MAP_CENTER + r, c = pos - if full: - return norm, mag + #Distance from border + r_dist = min(r - border, center + border - r - 1) + c_dist = min(c - border, center + border - c - 1) + dist = min(r_dist, c_dist) + norm = 2 * dist / center - return norm + return norm def spawn(config, dnger): - border = config.MAP_BORDER - center = config.MAP_CENTER - mid = center // 2 - - dist = dnger * center / 2 - max_offset = mid - dist - offset = mid + border + np.random.randint(-max_offset, max_offset) - - rng = np.random.rand() - if rng < 0.25: - r = border + dist - c = offset - elif rng < 0.5: - r = border + center - dist - 1 - c = offset - elif rng < 0.75: - c = border + dist - r = offset - else: - c = border + center - dist - 1 - r = offset - - if __debug__: - assert dnger == danger(config, (r,c)), 
'Agent spawned at incorrect radius' - - r = int(r) - c = int(c) - - return r, c + border = config.MAP_BORDER + center = config.MAP_CENTER + mid = center // 2 + + dist = dnger * center / 2 + max_offset = mid - dist + offset = mid + border + np.random.randint(-max_offset, max_offset) + + rng = np.random.rand() + if rng < 0.25: + r = border + dist + c = offset + elif rng < 0.5: + r = border + center - dist - 1 + c = offset + elif rng < 0.75: + c = border + dist + r = offset + else: + c = border + center - dist - 1 + r = offset + + if __debug__: + assert dnger == danger(config, (r,c)), 'Agent spawned at incorrect radius' + + r = int(r) + c = int(c) + + return r, c diff --git a/nmmo/systems/droptable.py b/nmmo/systems/droptable.py index e41f7cad1..6110d79f2 100644 --- a/nmmo/systems/droptable.py +++ b/nmmo/systems/droptable.py @@ -1,56 +1,54 @@ import numpy as np -class Empty(): - def roll(self, realm, level): - return [] - class Fixed(): - def __init__(self, item, amount=1): - self.item = item - self.amount = amount + def __init__(self, item): + self.item = item - def roll(self, realm, level): - return [self.item(realm, level, amount=amount)] + def roll(self, realm, level): + return [self.item(realm, level)] class Drop: - def __init__(self, item, amount, prob): - self.item = item - self.amount = amount - self.prob = prob + def __init__(self, item, prob): + self.item = item + self.prob = prob + + def roll(self, realm, level): + if np.random.rand() < self.prob: + return self.item(realm, level) - def roll(self, realm, level): - if np.random.rand() < self.prob: - return self.item(realm, level, quantity=self.amount) + return None class Standard: - def __init__(self): - self.drops = [] + def __init__(self): + self.drops = [] - def add(self, item, quant=1, prob=1.0): - self.drops += [Drop(item, quant, prob)] + def add(self, item, prob=1.0): + self.drops += [Drop(item, prob)] - def roll(self, realm, level): - ret = [] - for e in self.drops: - drop = e.roll(realm, level) - if 
drop is not None: - ret += [drop] - return ret + def roll(self, realm, level): + ret = [] + for e in self.drops: + drop = e.roll(realm, level) + if drop is not None: + ret += [drop] + return ret class Empty(Standard): - def roll(self, realm, level): - return [] + def roll(self, realm, level): + return [] class Ammunition(Standard): - def __init__(self, item): - self.item = item + def __init__(self, item): + super().__init__() + self.item = item - def roll(self, realm, level): - return [self.item(realm, level)] + def roll(self, realm, level): + return [self.item(realm, level)] class Consumable(Standard): - def __init__(self, item): - self.item = item + def __init__(self, item): + super().__init__() + self.item = item - def roll(self, realm, level): - return [self.item(realm, level)] + def roll(self, realm, level): + return [self.item(realm, level)] diff --git a/nmmo/systems/equipment.py b/nmmo/systems/equipment.py deleted file mode 100644 index 8decb6fe4..000000000 --- a/nmmo/systems/equipment.py +++ /dev/null @@ -1,56 +0,0 @@ -from pdb import set_trace as T -from nmmo.lib.colors import Tier - -class Loadout: - def __init__(self, chest=0, legs=0): - self.chestplate = Chestplate(chest) - self.platelegs = Platelegs(legs) - - @property - def defense(self): - return (self.chestplate.level + self.platelegs.level) // 2 - - def packet(self): - packet = {} - - packet['chestplate'] = self.chestplate.packet() - packet['platelegs'] = self.platelegs.packet() - - return packet - -class Armor: - def __init__(self, level): - self.level = level - - def packet(self): - packet = {} - - packet['level'] = self.level - packet['color'] = self.color.packet() - - return packet - - @property - def color(self): - if self.level == 0: - return Tier.BLACK - if self.level < 10: - return Tier.WOOD - elif self.level < 20: - return Tier.BRONZE - elif self.level < 40: - return Tier.SILVER - elif self.level < 60: - return Tier.GOLD - elif self.level < 80: - return Tier.PLATINUM - else: - return 
Tier.DIAMOND - - -class Chestplate(Armor): - pass - -class Platelegs(Armor): - pass - diff --git a/nmmo/systems/exchange.py b/nmmo/systems/exchange.py index 360c3e706..96d04c170 100644 --- a/nmmo/systems/exchange.py +++ b/nmmo/systems/exchange.py @@ -1,218 +1,159 @@ -from pdb import set_trace as T - -from collections import defaultdict, deque -from queue import PriorityQueue - -import inspect -import logging -import inspect - +from __future__ import annotations +from collections import deque import math -class Offer: - def __init__(self, seller, item): - self.seller = seller - self.item = item - - ''' - def __lt__(self, offer): - return self.price < offer.price - - def __le__(self, offer): - return self.price <= offer.price - - def __eq__(self, offer): - return self.price == offer.price - - def __ne__(self, offer): - return self.price != offer.price - - def __gt__(self, offer): - return self.price > offer.price - - def __ge__(self, offer): - return self.price >= offer.price - ''' - -#Why is the api so weird... -class Queue(deque): - def __init__(self): - super().__init__() - self.price = None - - def push(self, x): - self.appendleft(x) - - def peek(self): - if len(self) > 0: - return self[-1] - return None - -class ItemListings: - def __init__(self): - self.listings = PriorityQueue() - self.placeholder = None - self.item_number = 0 - self.alpha = 0.01 - self.volume = 0 - - self.step() - - def step(self): - #self.volume = 0 - pass - - @property - def price(self): - if not self.supply: - return - - price, item_number, seller = self.listings.get() - self.listings.put((price, item_number, seller)) - return price +from typing import Dict - @property - def supply(self): - return self.listings.qsize() +from nmmo.systems.item import Item, Stack +from nmmo.lib.log import EventCode - @property - def empty(self): - return self.listings.empty() +""" +The Exchange class is a simulation of an in-game item exchange. 
+It has several methods that allow players to list items for sale, +buy items, and remove expired listings. - def buy(self, buyer, max_price): - if not self.supply: - return +The _list_item() method is used to add a new item to the +exchange, and the unlist_item() method is used to remove +an item from the exchange. The step() method is used to +regularly check and remove expired listings. - price, item_number, seller = self.listings.get() +The sell() method allows a player to sell an item, and the buy() method +allows a player to purchase an item. The packet property returns a +dictionary that contains information about the items currently being +sold on the exchange, such as the maximum and minimum price, +the average price, and the total supply of the items. - if price > max_price or price > buyer.inventory.gold.quantity.val: - self.listings.put((price, item_number, seller)) - return - - seller.inventory.gold.quantity += price - buyer.inventory.gold.quantity -= price - - buyer.buys += 1 - seller.sells += 1 - self.volume += 1 - return price - - def sell(self, seller, price): - if price == 1 and not self.empty: - seller.inventory.gold.quantity += 1 - else: - self.listings.put((price, self.item_number, seller)) - self.item_number += 1 - - #print('Sell {}: {}'.format(item.__class__.__name__, price)) +""" +class ItemListing: + def __init__(self, item: Item, seller, price: int, tick: int): + self.item = item + self.seller = seller + self.price = price + self.tick = tick class Exchange: - def __init__(self): - self.item_listings = defaultdict(ItemListings) - - @property - def dataframeKeys(self): - keys = [] - for listings in self.item_listings.values(): - if listings.placeholder: - keys.append(listings.placeholder.instanceID) - - return keys - - @property - def dataframeVals(self): - vals = [] - for listings in self.item_listings.values(): - if listings.placeholder: - vals.append(listings.placeholder) - - return vals - - @property - def packet(self): - packet = {} - 
for (item_cls, level), listings in self.item_listings.items(): - key = f'{item_cls.__name__}_{level}' - - item = listings.placeholder - if item is None: - continue - - packet[key] = { - 'price': listings.price, - 'supply': listings.supply} - - return packet - - def step(self): - for item, listings in self.item_listings.items(): - listings.step() - - def available(self, item): - return self.item_listings[item].available() - - def buy(self, realm, buyer, item): - assert isinstance(item, object), f'{item} purchase is not an Item instance' - assert item.quantity.val == 1, f'{item} purchase has quantity {item.quantity.val}' - - #TODO: Handle ammo stacks - if not buyer.inventory.space: - return - - config = realm.config - level = item.level.val - - #Agents may try to buy an item at the same time - #Therefore the price has to be semi-variable - price = item.price.val - max_price = 1.1 * price - - item = type(item) - listings_key = (item, level) - listings = self.item_listings[listings_key] - - price = listings.buy(buyer, max_price) - if price is not None: - buyer.inventory.receive(listings.placeholder) - - if ((config.LOG_MILESTONES and realm.quill.milestone.log_max(f'Buy_{item.__name__}', level)) or - (config.LOG_EVENTS and realm.quill.event.log(f'Buy_{item.__name__}', level))) and config.LOG_VERBOSE: - logging.info(f'EXCHANGE: Bought level {level} {item.__name__} for {price} gold') - if ((config.LOG_MILESTONES and realm.quill.milestone.log_max(f'Transaction_Amount', price)) or - (config.LOG_EVENTS and realm.quill.event.log(f'Transaction_Amount', price))) and config.LOG_VERBOSE: - logging.info(f'EXCHANGE: Transaction of {price} gold (level {level} {item.__name__})') - - #Update placeholder - listings.placeholder = None - if listings.supply: - listings.placeholder = item(realm, level, price=listings.price) - - def sell(self, realm, seller, item, price): - assert isinstance(item, object), f'{item} for sale is not an Item instance' - assert item in seller.inventory, 
f'{item} for sale is not in {seller} inventory' - assert item.quantity.val > 0, f'{item} for sale has quantity {item.quantity.val}' - - if not item.tradable.val: - return - - config = realm.config - level = item.level.val - - #Remove from seller - seller.inventory.remove(item, quantity=1) - item = type(item) - - - if ((config.LOG_MILESTONES and realm.quill.milestone.log_max(f'Sell_{item.__name__}', level)) or (config.LOG_EVENTS and realm.quill.event.log(f'Sell_{item.__name__}', level))) and config.LOG_VERBOSE: - logging.info(f'EXCHANGE: Offered level {level} {item.__name__} for {price} gold') - - listings_key = (item, level) - listings = self.item_listings[listings_key] - current_price = listings.price - - #Update obs placeholder item - if listings.placeholder is None or (current_price is not None and price < current_price): - listings.placeholder = item(realm, level, price=price) - - #print('{} Sold {} x {} for {} ea.'.format(seller.base.name, quantity, item.__name__, price)) - listings.sell(seller, price) + def __init__(self, realm): + self._listings_queue: deque[(int, int)] = deque() # (item_id, tick) + self._item_listings: Dict[int, ItemListing] = {} + self._realm = realm + self._config = realm.config + + def _list_item(self, item: Item, seller, price: int, tick: int): + item.listed_price.update(price) + self._item_listings[item.id.val] = ItemListing(item, seller, price, tick) + self._listings_queue.append((item.id.val, tick)) + + def unlist_item(self, item: Item): + if item.id.val in self._item_listings: + self._unlist_item(item.id.val) + + def _unlist_item(self, item_id: int): + item = self._item_listings.pop(item_id).item + item.listed_price.update(0) + + def step(self, current_tick: int): + """ + Remove expired listings from the exchange's listings queue + and item listings dictionary. It takes in one parameter, + current_tick, which is the current time in the game. 
+ + The method starts by checking the oldest listing in the listings + queue using a while loop. If the current tick minus the + listing tick is less than or equal to the EXCHANGE_LISTING_DURATION + in the realm's configuration, the method breaks out of + the loop as the oldest listing has not expired. + If the oldest listing has expired, the method removes it from the + listings queue and the item listings dictionary. + + It then checks if the actual listing still exists and that + it is indeed expired. If it does exist and is expired, + it calls the _unlist_item method to remove the listing and update + the item's listed price. The process repeats until all expired listings + are removed from the queue and dictionary. + """ + + # Remove expired listings + while self._listings_queue: + (item_id, listing_tick) = self._listings_queue[0] + if current_tick - listing_tick <= self._config.EXCHANGE_LISTING_DURATION: + # Oldest listing has not expired + break + + # Remove expired listing from queue + self._listings_queue.popleft() + + # The actual listing might have been refreshed and is newer than the queue record. + # Or it might have already been removed. 
+ listing = self._item_listings.get(item_id) + if listing is not None and \ + current_tick - listing.tick > self._config.EXCHANGE_LISTING_DURATION: + self._unlist_item(item_id) + + def sell(self, seller, item: Item, price: int, tick: int): + assert isinstance( + item, object), f'{item} for sale is not an Item instance' + assert item in seller.inventory, f'{item} for sale is not in {seller} inventory' + assert item.quantity.val > 0, f'{item} for sale has quantity {item.quantity.val}' + assert item.listed_price.val == 0, 'Item is already listed' + assert item.equipped.val == 0, 'Item has been equiped so cannot be listed' + assert price > 0, 'Price must be larger than 0' + + self._list_item(item, seller, price, tick) + + self._realm.event_log.record(EventCode.LIST_ITEM, seller, item=item, price=price) + + self._realm.log_milestone(f'Sell_{item.__class__.__name__}', item.level.val, + f'EXCHANGE: Offered level {item.level.val} {item.__class__.__name__} for {price} gold', + tags={"player_id": seller.ent_id}) + + def buy(self, buyer, item: Item): + assert item.quantity.val > 0, f'{item} purchase has quantity {item.quantity.val}' + assert item.equipped.val == 0, 'Listed item must not be equipped' + assert buyer.gold.val >= item.listed_price.val, 'Buyer does not have enough gold' + assert buyer.ent_id != item.owner_id.val, 'One cannot buy their own items' + + if not buyer.inventory.space: + if isinstance(item, Stack): + if not buyer.inventory.has_stack(item.signature): + # no ammo stack with the same signature, so cannot buy + return + else: # no space, and item is not ammo stack, so cannot buy + return + + # item is not in the listing (perhaps bought by other) + if item.id.val not in self._item_listings: + return + + listing = self._item_listings[item.id.val] + price = item.listed_price.val + + self.unlist_item(item) + listing.seller.inventory.remove(item) + buyer.inventory.receive(item) + buyer.gold.decrement(price) + listing.seller.gold.increment(price) + + # 
TODO(kywch): tidy up the logs - milestone, event, etc ... + #self._realm.log_milestone(f'Buy_{item.__name__}', item.level.val) + #self._realm.log_milestone('Transaction_Amount', item.listed_price.val) + self._realm.event_log.record(EventCode.BUY_ITEM, buyer, item=item, price=price) + self._realm.event_log.record(EventCode.EARN_GOLD, listing.seller, amount=price) + + @property + def packet(self): + packet = {} + for listing in self._item_listings.values(): + item = listing.item + key = f'{item.__class__.__name__}_{item.level.val}' + max_price = max(packet.get(key, {}).get('max_price', -math.inf), listing.price) + min_price = min(packet.get(key, {}).get('min_price', math.inf), listing.price) + supply = packet.get(key, {}).get('supply', 0) + item.quantity.val + + packet[key] = { + 'max_price': max_price, + 'min_price': min_price, + 'price': (max_price + min_price) / 2, + 'supply': supply + } + + return packet diff --git a/nmmo/systems/experience.py b/nmmo/systems/experience.py index 3a1113a91..25be029fb 100644 --- a/nmmo/systems/experience.py +++ b/nmmo/systems/experience.py @@ -1,14 +1,13 @@ -from pdb import set_trace as T import numpy as np class ExperienceCalculator: - def __init__(self, num_levels=15): - self.exp = np.array([0] + [10*2**i for i in range(num_levels)]) + def __init__(self, num_levels=15): + self.exp = np.array([0] + [10*2**i for i in range(num_levels)]) - def expAtLevel(self, level): - return int(self.exp[level - 1]) + def exp_at_level(self, level): + return int(self.exp[level - 1]) - def levelAtExp(self, exp): - if exp >= self.exp[-1]: - return len(self.exp) - return np.argmin(exp >= self.exp) + def level_at_exp(self, exp): + if exp >= self.exp[-1]: + return len(self.exp) + return np.argmin(exp >= self.exp) diff --git a/nmmo/systems/inventory.py b/nmmo/systems/inventory.py index 5fcefa9a0..0b21226ae 100644 --- a/nmmo/systems/inventory.py +++ b/nmmo/systems/inventory.py @@ -1,183 +1,193 @@ -from pdb import set_trace as T -import numpy as np +from 
typing import Dict, Tuple from ordered_set import OrderedSet -import inspect -import logging from nmmo.systems import item as Item -from nmmo.systems import skill as Skill +class EquipmentSlot: + def __init__(self) -> None: + self.item = None -class Equipment: - def __init__(self, realm): - self.hat = None - self.top = None - self.bottom = None - - self.held = None - self.ammunition = None - - def total(self, lambda_getter): - items = [lambda_getter(e).val for e in self] - if not items: - return 0 - return sum(items) - - def __iter__(self): - for item in [self.hat, self.top, self.bottom, self.held, self.ammunition]: - if item is not None: - yield item - - def conditional_packet(self, packet, item_name, item): - if item: - packet[item_name] = item.packet - - @property - def item_level(self): - return self.total(lambda e: e.level) - - @property - def melee_attack(self): - return self.total(lambda e: e.melee_attack) - - @property - def range_attack(self): - return self.total(lambda e: e.range_attack) - - @property - def mage_attack(self): - return self.total(lambda e: e.mage_attack) - - @property - def melee_defense(self): - return self.total(lambda e: e.melee_defense) - - @property - def range_defense(self): - return self.total(lambda e: e.range_defense) - - @property - def mage_defense(self): - return self.total(lambda e: e.mage_defense) - - @property - def packet(self): - packet = {} - - self.conditional_packet(packet, 'hat', self.hat) - self.conditional_packet(packet, 'top', self.top) - self.conditional_packet(packet, 'bottom', self.bottom) - self.conditional_packet(packet, 'held', self.held) - self.conditional_packet(packet, 'ammunition', self.ammunition) + def equip(self, item: Item.Item) -> None: + self.item = item - packet['item_level'] = self.item_level + def unequip(self) -> None: + if self.item: + self.item.equipped.update(0) + self.item = None - packet['melee_attack'] = self.melee_attack - packet['range_attack'] = self.range_attack - packet['mage_attack'] 
= self.mage_attack - packet['melee_defense'] = self.melee_defense - packet['range_defense'] = self.range_defense - packet['mage_defense'] = self.mage_defense - - return packet +class Equipment: + def __init__(self): + self.hat = EquipmentSlot() + self.top = EquipmentSlot() + self.bottom = EquipmentSlot() + self.held = EquipmentSlot() + self.ammunition = EquipmentSlot() + + def total(self, lambda_getter): + items = [lambda_getter(e).val for e in self] + if not items: + return 0 + return sum(items) + + def __iter__(self): + for slot in [self.hat, self.top, self.bottom, self.held, self.ammunition]: + if slot.item is not None: + yield slot.item + + def conditional_packet(self, packet, slot_name: str, slot: EquipmentSlot): + if slot.item: + packet[slot_name] = slot.item.packet + + @property + def item_level(self): + return self.total(lambda e: e.level) + + @property + def melee_attack(self): + return self.total(lambda e: e.melee_attack) + + @property + def range_attack(self): + return self.total(lambda e: e.range_attack) + + @property + def mage_attack(self): + return self.total(lambda e: e.mage_attack) + + @property + def melee_defense(self): + return self.total(lambda e: e.melee_defense) + + @property + def range_defense(self): + return self.total(lambda e: e.range_defense) + + @property + def mage_defense(self): + return self.total(lambda e: e.mage_defense) + + @property + def packet(self): + packet = {} + + self.conditional_packet(packet, 'hat', self.hat) + self.conditional_packet(packet, 'top', self.top) + self.conditional_packet(packet, 'bottom', self.bottom) + self.conditional_packet(packet, 'held', self.held) + self.conditional_packet(packet, 'ammunition', self.ammunition) + + # pylint: disable=R0801 + # Similar lines here and in npc.py + packet['item_level'] = self.item_level + packet['melee_attack'] = self.melee_attack + packet['range_attack'] = self.range_attack + packet['mage_attack'] = self.mage_attack + packet['melee_defense'] = self.melee_defense + 
packet['range_defense'] = self.range_defense + packet['mage_defense'] = self.mage_defense + + return packet class Inventory: - def __init__(self, realm, entity): - config = realm.config - self.realm = realm - self.entity = entity - self.config = config + def __init__(self, realm, entity): + config = realm.config + self.realm = realm + self.entity = entity + self.config = config - self.equipment = Equipment(realm) - - if not config.ITEM_SYSTEM_ENABLED: - return + self.equipment = Equipment() + self.capacity = 0 + if config.ITEM_SYSTEM_ENABLED: self.capacity = config.ITEM_INVENTORY_CAPACITY - self.gold = Item.Gold(realm) - - self._item_stacks = {self.gold.signature: self.gold} - self._item_references = OrderedSet([self.gold]) - - @property - def space(self): - return self.capacity - len(self._item_references) - - @property - def dataframeKeys(self): - return [e.instanceID for e in self._item_references] - - def packet(self): - item_packet = [] - if self.config.ITEM_SYSTEM_ENABLED: - item_packet = [e.packet for e in self._item_references] - - return { - 'items': item_packet, - 'equipment': self.equipment.packet} - - def __iter__(self): - for item in self._item_references: - yield item - - def receive(self, item): - assert isinstance(item, Item.Item), f'{item} received is not an Item instance' - assert item not in self._item_references, f'{item} object received already in inventory' - assert not item.equipped.val, f'Received equipped item {item}' - #assert self.space, f'Out of space for {item}' - assert item.quantity.val, f'Received empty item {item}' - - config = self.config - - if isinstance(item, Item.Stack): - signature = item.signature - if signature in self._item_stacks: - stack = self._item_stacks[signature] - assert item.level.val == stack.level.val, f'{item} stack level mismatch' - stack.quantity += item.quantity.val - - if config.LOG_MILESTONES and isinstance(item, Item.Gold) and self.realm.quill.milestone.log_max(f'Wealth', self.gold.quantity.val) and 
config.LOG_VERBOSE: - logging.info(f'EXCHANGE: Total wealth {self.gold.quantity.val} gold') - - return - elif not self.space: - return - self._item_stacks[signature] = item + self._item_stacks: Dict[Tuple, Item.Stack] = {} + self.items: OrderedSet[Item.Item] = OrderedSet([]) + + @property + def space(self): + return self.capacity - len(self.items) + + def has_stack(self, signature: Tuple) -> bool: + return signature in self._item_stacks + + def packet(self): + item_packet = [] + if self.config.ITEM_SYSTEM_ENABLED: + item_packet = [e.packet for e in self.items] + + return { + 'items': item_packet, + 'equipment': self.equipment.packet} + + def __iter__(self): + for item in self.items: + yield item + + def receive(self, item: Item.Item): + assert isinstance(item, Item.Item), f'{item} received is not an Item instance' + assert item not in self.items, f'{item} object received already in inventory' + assert not item.equipped.val, f'Received equipped item {item}' + assert not item.listed_price.val, f'Received listed item {item}' + assert item.quantity.val, f'Received empty item {item}' + + if isinstance(item, Item.Stack): + signature = item.signature + if self.has_stack(signature): + stack = self._item_stacks[signature] + assert item.level.val == stack.level.val, f'{item} stack level mismatch' + stack.quantity.increment(item.quantity.val) + # destroy the original item instance after the transfer is complete + item.destroy() + return if not self.space: - return + # if no space thus cannot receive, just destroy the item + item.destroy() + return - if config.LOG_MILESTONES and self.realm.quill.milestone.log_max(f'Receive_{item.__class__.__name__}', item.level.val) and config.LOG_VERBOSE: - logging.info(f'INVENTORY: Received level {item.level.val} {item.__class__.__name__}') + self._item_stacks[signature] = item + if not self.space: + # if no space thus cannot receive, just destroy the item + item.destroy() + return - self._item_references.add(item) + 
self.realm.log_milestone(f'Receive_{item.__class__.__name__}', item.level.val, + f'INVENTORY: Received level {item.level.val} {item.__class__.__name__}', + tags={"player_id": self.entity.ent_id}) - def remove(self, item, quantity=None): - assert isinstance(item, Item.Item), f'{item} received is not an Item instance' - assert item in self._item_references, f'No item {item} to remove' + item.owner_id.update(self.entity.id.val) + self.items.add(item) - if item.equipped.val: - item.use(self.entity) + # pylint: disable=protected-access + def remove(self, item, quantity=None): + assert isinstance(item, Item.Item), f'{item} removing item is not an Item instance' + assert item in self.items, f'No item {item} to remove' - assert not item.equipped.val, f'Removing {item} while equipped' + if isinstance(item, Item.Equipment) and item.equipped.val: + item.unequip(item._slot(self.entity)) - if isinstance(item, Item.Stack): - signature = item.signature + if isinstance(item, Item.Stack): + signature = item.signature - assert item.signature in self._item_stacks, f'{item} stack to remove not in inventory' - stack = self._item_stacks[signature] + assert self.has_stack(item.signature), f'{item} stack to remove not in inventory' + stack = self._item_stacks[signature] - if quantity is None or stack.quantity.val == quantity: - self._item_references.remove(stack) - del self._item_stacks[signature] - return + if quantity is None or stack.quantity.val == quantity: + self._remove(stack) + del self._item_stacks[signature] + return - assert 0 < quantity <= stack.quantity.val, f'Invalid remove {quantity} x {item} ({stack.quantity.val} available)' - stack.quantity.val -= quantity + assert 0 < quantity <= stack.quantity.val, \ + f'Invalid remove {quantity} x {item} ({stack.quantity.val} available)' + stack.quantity.val -= quantity + return - return + self._remove(item) - self._item_references.remove(item) + def _remove(self, item): + self.realm.exchange.unlist_item(item) + item.owner_id.update(0) 
+ self.items.remove(item) diff --git a/nmmo/systems/item.py b/nmmo/systems/item.py index d3df7bfac..68079e082 100644 --- a/nmmo/systems/item.py +++ b/nmmo/systems/item.py @@ -1,434 +1,426 @@ -from pdb import set_trace as T +from __future__ import annotations -import logging -import random +import math +from abc import ABC +from types import SimpleNamespace +from typing import Dict -from nmmo.io.stimulus import Serialized -from nmmo.lib.colors import Tier -from nmmo.systems import combat - -class ItemID: - item_ids = {} - id_items = {} - - def register(cls, item_id): - if __debug__: - if cls in ItemID.item_ids: - assert ItemID.item_ids[cls] == item_id, f'Missmatched item_id assignment for class {cls}' - if item_id in ItemID.id_items: - assert ItemID.id_items[item_id] == cls, f'Missmatched class assignment for item_id {item_id}' - - ItemID.item_ids[cls] = item_id - ItemID.id_items[item_id] = cls - - def get(cls_or_id): - if type(cls_or_id) == int: - return ItemID.id_items[cls_or_id] - return ItemID.item_ids[cls_or_id] - -class Item: - ITEM_ID = None - INSTANCE_ID = 0 - def __init__(self, realm, level, - capacity=0, quantity=1, tradable=True, - melee_attack=0, range_attack=0, mage_attack=0, - melee_defense=0, range_defense=0, mage_defense=0, - health_restore=0, resource_restore=0, price=0): - - self.config = realm.config - self.realm = realm - - self.instanceID = Item.INSTANCE_ID - realm.items[self.instanceID] = self - - self.instance = Serialized.Item.ID(realm.dataframe, self.instanceID, Item.INSTANCE_ID) - self.index = Serialized.Item.Index(realm.dataframe, self.instanceID, self.ITEM_ID) - self.level = Serialized.Item.Level(realm.dataframe, self.instanceID, level) - self.capacity = Serialized.Item.Capacity(realm.dataframe, self.instanceID, capacity) - self.quantity = Serialized.Item.Quantity(realm.dataframe, self.instanceID, quantity) - self.tradable = Serialized.Item.Tradable(realm.dataframe, self.instanceID, tradable) - self.melee_attack = 
Serialized.Item.MeleeAttack(realm.dataframe, self.instanceID, melee_attack) - self.range_attack = Serialized.Item.RangeAttack(realm.dataframe, self.instanceID, range_attack) - self.mage_attack = Serialized.Item.MageAttack(realm.dataframe, self.instanceID, mage_attack) - self.melee_defense = Serialized.Item.MeleeDefense(realm.dataframe, self.instanceID, melee_defense) - self.range_defense = Serialized.Item.RangeDefense(realm.dataframe, self.instanceID, range_defense) - self.mage_defense = Serialized.Item.MageDefense(realm.dataframe, self.instanceID, mage_defense) - self.health_restore = Serialized.Item.HealthRestore(realm.dataframe, self.instanceID, health_restore) - self.resource_restore = Serialized.Item.ResourceRestore(realm.dataframe, self.instanceID, resource_restore) - self.price = Serialized.Item.Price(realm.dataframe, self.instanceID, price) - self.equipped = Serialized.Item.Equipped(realm.dataframe, self.instanceID, 0) - - realm.dataframe.init(Serialized.Item, self.instanceID, None) - - Item.INSTANCE_ID += 1 - if self.ITEM_ID is not None: - ItemID.register(self.__class__, item_id=self.ITEM_ID) - - @property - def signature(self): - return (self.index.val, self.level.val) - - @property - def packet(self): - return {'item': self.__class__.__name__, - 'level': self.level.val, - 'capacity': self.capacity.val, - 'quantity': self.quantity.val, - 'melee_attack': self.melee_attack.val, - 'range_attack': self.range_attack.val, - 'mage_attack': self.mage_attack.val, - 'melee_defense': self.melee_defense.val, - 'range_defense': self.range_defense.val, - 'mage_defense': self.mage_defense.val, - 'health_restore': self.health_restore.val, - 'resource_restore': self.resource_restore.val, - 'price': self.price.val} - - def use(self, entity): - return - #TODO: Warning? 
- #assert False, f'Use {type(self)} not defined' +from numpy import real -class Stack(): - pass +from nmmo.lib.colors import Tier -class Gold(Item, Stack): - ITEM_ID = 1 - def __init__(self, realm, **kwargs): - super().__init__(realm, level=0, tradable=False, **kwargs) +from nmmo.datastore.serialized import SerializedState +from nmmo.lib.colors import Tier +from nmmo.lib.log import EventCode + +# pylint: disable=no-member +ItemState = SerializedState.subclass("Item", [ + "id", + "type_id", + "owner_id", + + "level", + "capacity", + "quantity", + "melee_attack", + "range_attack", + "mage_attack", + "melee_defense", + "range_defense", + "mage_defense", + "health_restore", + "resource_restore", + "equipped", + + # Market + "listed_price", +]) + +# TODO: These limits should be defined in the config. +ItemState.Limits = lambda config: { + "id": (0, math.inf), + "type_id": (0, (config.ITEM_N + 1) if config.ITEM_SYSTEM_ENABLED else 0), + "owner_id": (-math.inf, math.inf), + "level": (0, 99), + "capacity": (0, 99), + "quantity": (0, math.inf), # NOTE: Ammunitions can be stacked infinitely + "melee_attack": (0, 100), + "range_attack": (0, 100), + "mage_attack": (0, 100), + "melee_defense": (0, 100), + "range_defense": (0, 100), + "mage_defense": (0, 100), + "health_restore": (0, 100), + "resource_restore": (0, 100), + "equipped": (0, 1), + "listed_price": (0, math.inf), +} + +ItemState.Query = SimpleNamespace( + table=lambda ds: ds.table("Item").where_neq( + ItemState.State.attr_name_to_col["id"], 0), + + by_id=lambda ds, id: ds.table("Item").where_eq( + ItemState.State.attr_name_to_col["id"], id), + + owned_by = lambda ds, id: ds.table("Item").where_eq( + ItemState.State.attr_name_to_col["owner_id"], id), + + for_sale = lambda ds: ds.table("Item").where_neq( + ItemState.State.attr_name_to_col["listed_price"], 0), +) + +class Item(ItemState): + ITEM_TYPE_ID = None + _item_type_id_to_class: Dict[int, type] = {} + + @staticmethod + def register(item_type): + assert 
item_type.ITEM_TYPE_ID is not None + if item_type.ITEM_TYPE_ID not in Item._item_type_id_to_class: + Item._item_type_id_to_class[item_type.ITEM_TYPE_ID] = item_type + + @staticmethod + def item_class(type_id: int): + return Item._item_type_id_to_class[type_id] + + def __init__(self, realm, level, + capacity=0, + melee_attack=0, range_attack=0, mage_attack=0, + melee_defense=0, range_defense=0, mage_defense=0, + health_restore=0, resource_restore=0): + + super().__init__(realm.datastore, ItemState.Limits(realm.config)) + self.realm = realm + self.config = realm.config + + Item.register(self.__class__) + + self.id.update(self.datastore_record.id) + self.type_id.update(self.ITEM_TYPE_ID) + self.level.update(level) + self.capacity.update(capacity) + # every item instance is created individually, i.e., quantity=1 + self.quantity.update(1) + self.melee_attack.update(melee_attack) + self.range_attack.update(range_attack) + self.mage_attack.update(mage_attack) + self.melee_defense.update(melee_defense) + self.range_defense.update(range_defense) + self.mage_defense.update(mage_defense) + self.health_restore.update(health_restore) + self.resource_restore.update(resource_restore) + realm.items[self.id.val] = self + + def destroy(self): + if self.owner_id.val in self.realm.players: + self.realm.players[self.owner_id.val].inventory.remove(self) + self.realm.items.pop(self.id.val, None) + self.datastore_record.delete() + + @property + def packet(self): + return {'item': self.__class__.__name__, + 'level': self.level.val, + 'capacity': self.capacity.val, + 'quantity': self.quantity.val, + 'melee_attack': self.melee_attack.val, + 'range_attack': self.range_attack.val, + 'mage_attack': self.mage_attack.val, + 'melee_defense': self.melee_defense.val, + 'range_defense': self.range_defense.val, + 'mage_defense': self.mage_defense.val, + 'health_restore': self.health_restore.val, + 'resource_restore': self.resource_restore.val, + } + + def _level(self, entity): + # this is for armors, 
ration, and poultice + # weapons and tools must override this with specific skills + return entity.level + + def level_gt(self, entity): + return self.level.val > self._level(entity) + + def use(self, entity) -> bool: + raise NotImplementedError + +class Stack: + @property + def signature(self): + return (self.type_id.val, self.level.val) class Equipment(Item): - @property - def packet(self): - packet = {'color': self.color.packet()} - return {**packet, **super().packet} - - @property - def color(self): - if self.level == 0: - return Tier.BLACK - if self.level < 10: - return Tier.WOOD - elif self.level < 20: - return Tier.BRONZE - elif self.level < 40: - return Tier.SILVER - elif self.level < 60: - return Tier.GOLD - elif self.level < 80: - return Tier.PLATINUM - else: - return Tier.DIAMOND - - def use(self, entity): - if self.equipped.val: - self.equipped.update(0) - equip = self.unequip(entity) - else: - self.equipped.update(1) - equip = self.equip(entity) - - config = self.config - if not config.LOG_MILESTONES or not entity.isPlayer: - return equip - - realm = self.realm - equipment = entity.equipment - item_name = self.__class__.__name__ - - if realm.quill.milestone.log_max(f'{item_name}_level', self.level.val) and config.LOG_VERBOSE: - logging.info(f'EQUIPMENT: Equipped level {self.level.val} {item_name}') - if realm.quill.milestone.log_max(f'Item_Level', equipment.item_level) and config.LOG_VERBOSE: - logging.info(f'EQUIPMENT: Item level {equipment.item_level}') - if realm.quill.milestone.log_max(f'Mage_Attack', equipment.mage_attack) and config.LOG_VERBOSE: - logging.info(f'EQUIPMENT: Mage attack {equipment.mage_attack}') - if realm.quill.milestone.log_max(f'Mage_Defense', equipment.mage_defense) and config.LOG_VERBOSE: - logging.info(f'EQUIPMENT: Mage defense {equipment.mage_defense}') - if realm.quill.milestone.log_max(f'Range_Attack', equipment.range_attack) and config.LOG_VERBOSE: - logging.info(f'EQUIPMENT: Range attack {equipment.range_attack}') - if 
realm.quill.milestone.log_max(f'Range_Defense', equipment.range_defense) and config.LOG_VERBOSE: - logging.info(f'EQUIPMENT: Range defense {equipment.range_defense}') - if realm.quill.milestone.log_max(f'Melee_Attack', equipment.melee_attack) and config.LOG_VERBOSE: - logging.info(f'EQUIPMENT: Melee attack {equipment.melee_attack}') - if realm.quill.milestone.log_max(f'Melee_Defense', equipment.melee_defense) and config.LOG_VERBOSE: - logging.info(f'EQUIPMENT: Melee defense {equipment.melee_defense}') - - return equip - -class Armor(Equipment): - def __init__(self, realm, level, **kwargs): - defense = realm.config.EQUIPMENT_ARMOR_BASE_DEFENSE + level*realm.config.EQUIPMENT_ARMOR_LEVEL_DEFENSE - super().__init__(realm, level, - melee_defense=defense, - range_defense=defense, - mage_defense=defense, - **kwargs) + @property + def packet(self): + packet = {'color': self.color.packet()} + return {**packet, **super().packet} + + @property + def color(self): + if self.level == 0: + return Tier.BLACK + if self.level < 10: + return Tier.WOOD + if self.level < 20: + return Tier.BRONZE + if self.level < 40: + return Tier.SILVER + if self.level < 60: + return Tier.GOLD + if self.level < 80: + return Tier.PLATINUM + return Tier.DIAMOND + + def unequip(self, equip_slot): + assert self.equipped.val == 1 + self.equipped.update(0) + equip_slot.unequip() + + def equip(self, entity, equip_slot): + assert self.equipped.val == 0 + if self._level(entity) < self.level.val: + return + self.equipped.update(1) + equip_slot.equip(self) + + if self.config.LOG_MILESTONES and entity.is_player and self.config.LOG_VERBOSE: + for (label, level) in [ + (f"{self.__class__.__name__}_Level", self.level.val), + ("Item_Level", entity.equipment.item_level), + ("Melee_Attack", entity.equipment.melee_attack), + ("Range_Attack", entity.equipment.range_attack), + ("Mage_Attack", entity.equipment.mage_attack), + ("Melee_Defense", entity.equipment.melee_defense), + ("Range_Defense", 
entity.equipment.range_defense), + ("Mage_Defense", entity.equipment.mage_defense)]: + + self.realm.log_milestone(label, level, f'EQUIPMENT: {label} {level}') + + def _slot(self, entity): + raise NotImplementedError + + def use(self, entity): + assert self in entity.inventory, "Item is not in entity's inventory" + assert self.listed_price == 0, "Listed item cannot be used" + assert self._level(entity) >= self.level.val, "Entity's level is not sufficient to use the item" + + if self.equipped.val: + self.unequip(self._slot(entity)) + else: + # always empty the slot first + self._slot(entity).unequip() + self.equip(entity, self._slot(entity)) + self.realm.event_log.record(EventCode.EQUIP_ITEM, entity, item=self) + +class Armor(Equipment, ABC): + def __init__(self, realm, level, **kwargs): + defense = realm.config.EQUIPMENT_ARMOR_BASE_DEFENSE + \ + level*realm.config.EQUIPMENT_ARMOR_LEVEL_DEFENSE + super().__init__(realm, level, + melee_defense=defense, + range_defense=defense, + mage_defense=defense, + **kwargs) class Hat(Armor): - ITEM_ID = 2 + ITEM_TYPE_ID = 2 + def _slot(self, entity): + return entity.inventory.equipment.hat +class Top(Armor): + ITEM_TYPE_ID = 3 + def _slot(self, entity): + return entity.inventory.equipment.top +class Bottom(Armor): + ITEM_TYPE_ID = 4 + def _slot(self, entity): + return entity.inventory.equipment.bottom - def equip(self, entity): - if entity.level < self.level.val: - return - if entity.inventory.equipment.hat: - entity.inventory.equipment.hat.use(entity) - entity.inventory.equipment.hat = self - def unequip(self, entity): - entity.inventory.equipment.hat = None +class Weapon(Equipment): + def __init__(self, realm, level, **kwargs): + super().__init__(realm, level, **kwargs) + self.attack = ( + realm.config.EQUIPMENT_WEAPON_BASE_DAMAGE + + level*realm.config.EQUIPMENT_WEAPON_LEVEL_DAMAGE) -class Top(Armor): - ITEM_ID = 3 + def _slot(self, entity): + return entity.inventory.equipment.held - def equip(self, entity): - if entity.level 
< self.level.val: - return - if entity.inventory.equipment.top: - entity.inventory.equipment.top.use(entity) - entity.inventory.equipment.top = self +class Sword(Weapon): + ITEM_TYPE_ID = 5 - def unequip(self, entity): - entity.inventory.equipment.top = None + def __init__(self, realm, level, **kwargs): + super().__init__(realm, level, **kwargs) + self.melee_attack.update(self.attack) -class Bottom(Armor): - ITEM_ID = 4 + def _level(self, entity): + return entity.skills.melee.level.val +class Bow(Weapon): + ITEM_TYPE_ID = 6 - def equip(self, entity): - if entity.level < self.level.val: - return - if entity.inventory.equipment.bottom: - entity.inventory.equipment.bottom.use(entity) - entity.inventory.equipment.bottom = self + def __init__(self, realm, level, **kwargs): + super().__init__(realm, level, **kwargs) + self.range_attack.update(self.attack) - def unequip(self, entity): - entity.inventory.equipment.bottom = None + def _level(self, entity): + return entity.skills.range.level.val +class Wand(Weapon): + ITEM_TYPE_ID = 7 -class Weapon(Equipment): - def __init__(self, realm, level, **kwargs): - super().__init__(realm, level, **kwargs) - self.attack = realm.config.EQUIPMENT_WEAPON_BASE_DAMAGE + level*realm.config.EQUIPMENT_WEAPON_LEVEL_DAMAGE + def __init__(self, realm, level, **kwargs): + super().__init__(realm, level, **kwargs) + self.mage_attack.update(self.attack) - def equip(self, entity): - if entity.inventory.equipment.held: - entity.inventory.equipment.held.use(entity) - entity.inventory.equipment.held = self + def _level(self, entity): + return entity.skills.mage.level.val - def unequip(self, entity): - entity.inventory.equipment.held = None -class Sword(Weapon): - ITEM_ID = 5 - def __init__(self, realm, level, **kwargs): - super().__init__(realm, level, **kwargs) - self.melee_attack.update(self.attack) - - def equip(self, entity): - if entity.skills.melee.level.val >= self.level.val: - super().equip(entity) - -class Bow(Weapon): - ITEM_ID = 6 - def 
__init__(self, realm, level, **kwargs): - super().__init__(realm, level, **kwargs) - self.range_attack.update(self.attack) - - def equip(self, entity): - if entity.skills.range.level.val >= self.level.val: - super().equip(entity) - -class Wand(Weapon): - ITEM_ID = 7 - def __init__(self, realm, level, **kwargs): - super().__init__(realm, level, **kwargs) - self.mage_attack.update(self.attack) - - def equip(self, entity): - if entity.skills.mage.level.val >= self.level.val: - super().equip(entity) - class Tool(Equipment): - def __init__(self, realm, level, **kwargs): - defense = realm.config.EQUIPMENT_TOOL_BASE_DEFENSE + level*realm.config.EQUIPMENT_TOOL_LEVEL_DEFENSE - super().__init__(realm, level, - melee_defense=defense, - range_defense=defense, - mage_defense=defense, - **kwargs) - - def equip(self, entity): - if entity.inventory.equipment.held: - entity.inventory.equipment.held.use(entity) - entity.inventory.equipment.held = self - - def unequip(self, entity): - entity.inventory.equipment.held = None - + def __init__(self, realm, level, **kwargs): + defense = realm.config.EQUIPMENT_TOOL_BASE_DEFENSE + \ + level*realm.config.EQUIPMENT_TOOL_LEVEL_DEFENSE + super().__init__(realm, level, + melee_defense=defense, + range_defense=defense, + mage_defense=defense, + **kwargs) + + def _slot(self, entity): + return entity.inventory.equipment.held class Rod(Tool): - ITEM_ID = 8 - def equip(self, entity): - if entity.skills.fishing.level >= self.level.val: - super().equip(entity) - return True - - return False - + ITEM_TYPE_ID = 8 + def _level(self, entity): + return entity.skills.fishing.level.val class Gloves(Tool): - ITEM_ID = 9 - def equip(self, entity): - if entity.skills.herbalism.level >= self.level.val: - super().equip(entity) - return True - - return False - + ITEM_TYPE_ID = 9 + def _level(self, entity): + return entity.skills.herbalism.level.val class Pickaxe(Tool): - ITEM_ID = 10 - def equip(self, entity): - if entity.skills.prospecting.level >= self.level.val: 
- super().equip(entity) - return True - - return False - + ITEM_TYPE_ID = 10 + def _level(self, entity): + return entity.skills.prospecting.level.val class Chisel(Tool): - ITEM_ID = 11 - def equip(self, entity): - if entity.skills.carving.level >= self.level.val: - super().equip(entity) - return True - - return False - + ITEM_TYPE_ID = 11 + def _level(self, entity): + return entity.skills.carving.level.val class Arcane(Tool): - ITEM_ID = 12 - def equip(self, entity): - if entity.skills.alchemy.level >= self.level.val: - super().equip(entity) - return True + ITEM_TYPE_ID = 12 + def _level(self, entity): + return entity.skills.alchemy.level.val - return False class Ammunition(Equipment, Stack): - def __init__(self, realm, level, **kwargs): - super().__init__(realm, level, **kwargs) - self.attack = realm.config.EQUIPMENT_AMMUNITION_BASE_DAMAGE + level*realm.config.EQUIPMENT_AMMUNITION_LEVEL_DAMAGE + def __init__(self, realm, level, **kwargs): + super().__init__(realm, level, **kwargs) + self.attack = ( + realm.config.EQUIPMENT_AMMUNITION_BASE_DAMAGE + + level*realm.config.EQUIPMENT_AMMUNITION_LEVEL_DAMAGE) - def equip(self, entity): - if entity.inventory.equipment.ammunition: - entity.inventory.equipment.ammunition.use(entity) - entity.inventory.equipment.ammunition = self + def _slot(self, entity): + return entity.inventory.equipment.ammunition - def unequip(self, entity): - entity.inventory.equipment.ammunition = None + def fire(self, entity) -> int: + assert self.equipped.val > 0, 'Ammunition not equipped' + assert self.quantity.val > 0, 'Used ammunition with 0 quantity' - def fire(self, entity): - if __debug__: - err = 'Used ammunition with 0 quantity' - assert self.quantity.val > 0, err + self.quantity.decrement() - self.quantity.decrement() + if self.quantity.val == 0: + entity.inventory.remove(self) + # delete this empty item instance from the datastore + self.destroy() - if self.quantity.val == 0: - entity.inventory.remove(self) + return self.damage - return 
self.damage - class Scrap(Ammunition): - ITEM_ID = 13 - def __init__(self, realm, level, **kwargs): - super().__init__(realm, level, **kwargs) - self.melee_attack.update(self.attack) + ITEM_TYPE_ID = 13 - def equip(self, entity): - if entity.skills.melee.level >= self.level.val: - super().equip(entity) - return True + def __init__(self, realm, level, **kwargs): + super().__init__(realm, level, **kwargs) + self.melee_attack.update(self.attack) - return False + def _level(self, entity): + return entity.skills.melee.level.val - @property - def damage(self): - return self.melee_attack.val + @property + def damage(self): + return self.melee_attack.val class Shaving(Ammunition): - ITEM_ID = 14 - def __init__(self, realm, level, **kwargs): - super().__init__(realm, level, **kwargs) - self.range_attack.update(self.attack) + ITEM_TYPE_ID = 14 - def equip(self, entity): - if entity.skills.range.level >= self.level.val: - super().equip(entity) - return True + def __init__(self, realm, level, **kwargs): + super().__init__(realm, level, **kwargs) + self.range_attack.update(self.attack) - return False + def _level(self, entity): + return entity.skills.range.level.val - @property - def damage(self): - return self.range_attack.val + @property + def damage(self): + return self.range_attack.val class Shard(Ammunition): - ITEM_ID = 15 - def __init__(self, realm, level, **kwargs): - super().__init__(realm, level, **kwargs) - self.mage_attack.update(self.attack) + ITEM_TYPE_ID = 15 - def equip(self, entity): - if entity.skills.mage.level >= self.level.val: - super().equip(entity) - return True + def __init__(self, realm, level, **kwargs): + super().__init__(realm, level, **kwargs) + self.mage_attack.update(self.attack) - return False + def _level(self, entity): + return entity.skills.mage.level.val - @property - def damage(self): - return self.mage_attack.val + @property + def damage(self): + return self.mage_attack.val -class Consumable(Item): - pass -class Ration(Consumable): - 
ITEM_ID = 16 - def __init__(self, realm, level, **kwargs): - restore = realm.config.PROFESSION_CONSUMABLE_RESTORE(level) - super().__init__(realm, level, resource_restore=restore, **kwargs) +# NOTE: Each consumable item (ration, poultice) cannot be stacked, +# so each item takes 1 inventory space +class Consumable(Item): + def use(self, entity) -> bool: + assert self in entity.inventory, "Item is not in entity's inventory" + assert self.listed_price == 0, "Listed item cannot be used" + assert self._level(entity) >= self.level.val, "Entity's level is not sufficient to use the item" - def use(self, entity): - if entity.level < self.level.val: - return False + self.realm.log_milestone( + f'Consumed_{self.__class__.__name__}', self.level.val, + f"PROF: Consumed {self.level.val} {self.__class__.__name__} " + f"by Entity level {entity.attack_level}", + tags={"player_id": entity.ent_id}) - if self.config.LOG_MILESTONES and self.realm.quill.milestone.log_max(f'Consumed_Ration', self.level.val) and self.config.LOG_VERBOSE: - logging.info(f'PROFESSION: Consumed level {self.level.val} ration') + self.realm.event_log.record(EventCode.CONSUME_ITEM, entity, item=self) - entity.resources.food.increment(self.resource_restore.val) - entity.resources.water.increment(self.resource_restore.val) + self._apply_effects(entity) + entity.inventory.remove(self) + self.destroy() + return True - entity.ration_level_consumed = max(entity.ration_level_consumed, self.level.val) - entity.ration_consumed += 1 +class Ration(Consumable): + ITEM_TYPE_ID = 16 - entity.inventory.remove(self) + def __init__(self, realm, level, **kwargs): + restore = 0 + if realm.config.PROFESSION_SYSTEM_ENABLED: + restore = realm.config.PROFESSION_CONSUMABLE_RESTORE(level) + super().__init__(realm, level, resource_restore=restore, **kwargs) - return True + def _apply_effects(self, entity): + entity.resources.food.increment(self.resource_restore.val) + entity.resources.water.increment(self.resource_restore.val) class 
Poultice(Consumable): - ITEM_ID = 17 + ITEM_TYPE_ID = 17 - def __init__(self, realm, level, **kwargs): + def __init__(self, realm, level, **kwargs): + restore = 0 + if realm.config.PROFESSION_SYSTEM_ENABLED: restore = realm.config.PROFESSION_CONSUMABLE_RESTORE(level) - super().__init__(realm, level, health_restore=restore, **kwargs) - - def use(self, entity): - if entity.level < self.level.val: - return False - - if self.config.LOG_MILESTONES and self.realm.quill.milestone.log_max(f'Consumed_Poultice', self.level.val) and self.config.LOG_VERBOSE: - logging.info(f'PROFESSION: Consumed level {self.level.val} poultice') - - entity.resources.health.increment(self.health_restore.val) - - entity.poultice_level_consumed = max(entity.poultice_level_consumed, self.level.val) - entity.poultice_consumed += 1 - - entity.inventory.remove(self) + super().__init__(realm, level, health_restore=restore, **kwargs) - return True + def _apply_effects(self, entity): + entity.resources.health.increment(self.health_restore.val) + entity.poultice_consumed += 1 + entity.poultice_level_consumed = max( + entity.poultice_level_consumed, self.level.val) diff --git a/nmmo/systems/skill.py b/nmmo/systems/skill.py index d887db1dd..447fe574a 100644 --- a/nmmo/systems/skill.py +++ b/nmmo/systems/skill.py @@ -1,214 +1,248 @@ -from pdb import set_trace as T -import numpy as np +from __future__ import annotations -from ordered_set import OrderedSet -import logging import abc -from nmmo.io.stimulus import Serialized -from nmmo.systems import experience, combat, ai +import numpy as np +from ordered_set import OrderedSet + from nmmo.lib import material +from nmmo.systems import combat, experience +from nmmo.lib.log import EventCode ### Infrastructure ### class SkillGroup: - def __init__(self, realm, entity): - self.config = realm.config - self.realm = realm + def __init__(self, realm, entity): + self.config = realm.config + self.realm = realm + self.entity = entity - self.expCalc = 
experience.ExperienceCalculator() - self.skills = OrderedSet() + self.experience_calculator = experience.ExperienceCalculator() + self.skills = OrderedSet() - def update(self, realm, entity): - for skill in self.skills: - skill.update(realm, entity) + def update(self): + for skill in self.skills: + skill.update() - def packet(self): - data = {} - for skill in self.skills: - data[skill.__class__.__name__.lower()] = skill.packet() + def packet(self): + data = {} + for skill in self.skills: + data[skill.__class__.__name__.lower()] = skill.packet() - return data + return data -class Skill: - skillItems = abc.ABCMeta +class Skill(abc.ABC): + def __init__(self, skill_group: SkillGroup): + self.realm = skill_group.realm + self.config = skill_group.config + self.entity = skill_group.entity - def __init__(self, realm, entity, skillGroup): - self.config = realm.config - self.realm = realm + self.experience_calculator = skill_group.experience_calculator + self.skill_group = skill_group + self.exp = 0 - self.expCalc = skillGroup.expCalc - self.exp = 0 + skill_group.skills.add(self) - skillGroup.skills.add(self) + def packet(self): + data = {} - def packet(self): - data = {} + data['exp'] = self.exp + data['level'] = self.level.val - data['exp'] = self.exp - data['level'] = self.level.val + return data - return data + def add_xp(self, xp): + self.exp += xp * self.config.PROGRESSION_BASE_XP_SCALE + new_level = int(self.experience_calculator.level_at_exp(self.exp)) - def add_xp(self, xp): - level = self.expCalc.levelAtExp(self.exp) - self.exp += xp * self.config.PROGRESSION_BASE_XP_SCALE + if new_level > self.level.val: + self.level.update(new_level) + self.realm.event_log.record(EventCode.LEVEL_UP, self.entity, + skill=self, level=new_level) - level = self.expCalc.levelAtExp(self.exp) - self.level.update(int(level)) + self.realm.log_milestone(f'Level_{self.__class__.__name__}', new_level, + f"PROGRESSION: Reached level {new_level} {self.__class__.__name__}", + tags={"player_id": 
self.entity.ent_id}) - if self.config.LOG_MILESTONES and self.realm.quill.milestone.log_max(f'Level_{self.__class__.__name__}', level) and self.config.LOG_VERBOSE: - logging.info(f'PROGRESSION: Reached level {level} {self.__class__.__name__}') + def set_experience_by_level(self, level): + self.exp = self.experience_calculator.level_at_exp(level) + self.level.update(int(level)) - def setExpByLevel(self, level): - self.exp = self.expCalc.expAtLevel(level) - self.level.update(int(level)) + @property + def level(self): + raise NotImplementedError(f"Skill {self.__class__.__name__} "\ + "does not implement 'level' property") ### Skill Bases ### class CombatSkill(Skill): - def update(self, realm, entity): - pass + def update(self): + pass -class NonCombatSkill(Skill): pass +class NonCombatSkill(Skill): + def __init__(self, skill_group: SkillGroup): + super().__init__(skill_group) + self._level = Lvl(1) + @property + def level(self): + return self._level class HarvestSkill(NonCombatSkill): - def processDrops(self, realm, entity, matl, dropTable): - level = 1 - tool = entity.equipment.held - if type(tool) == matl.tool: - level = tool.level.val - - #TODO: double-check drop table quantity - for drop in dropTable.roll(realm, level): - assert drop.level.val == level, 'Drop level does not match roll specification' - - if self.config.LOG_MILESTONES and realm.quill.milestone.log_max(f'Gather_{drop.__class__.__name__}', level) and self.config.LOG_VERBOSE: - logging.info(f'PROFESSION: Gathered level {level} {drop.__class__.__name__} (level {self.level.val} {self.__class__.__name__})') - - if entity.inventory.space: - entity.inventory.receive(drop) - - def harvest(self, realm, entity, matl, deplete=True): - r, c = entity.pos - if realm.map.tiles[r, c].state != matl: - return - - dropTable = realm.map.harvest(r, c, deplete) - if dropTable: - self.processDrops(realm, entity, matl, dropTable) - return True - - def harvestAdjacent(self, realm, entity, matl, deplete=True): - r, c = 
entity.pos - dropTable = None - - if realm.map.tiles[r-1, c].state == matl: - dropTable = realm.map.harvest(r-1, c, deplete) - if realm.map.tiles[r+1, c].state == matl: - dropTable = realm.map.harvest(r+1, c, deplete) - if realm.map.tiles[r, c-1].state == matl: - dropTable = realm.map.harvest(r, c-1, deplete) - if realm.map.tiles[r, c+1].state == matl: - dropTable = realm.map.harvest(r, c+1, deplete) - - if dropTable: - self.processDrops(realm, entity, matl, dropTable) - return True + def process_drops(self, matl, drop_table): + if not self.config.ITEM_SYSTEM_ENABLED: + return -class AmmunitionSkill(HarvestSkill): - def processDrops(self, realm, entity, matl, dropTable): - super().processDrops(realm, entity, matl, dropTable) + entity = self.entity - self.add_xp(self.config.PROGRESSION_AMMUNITION_XP_SCALE) + level = 1 + tool = entity.equipment.held + if matl.tool is not None and isinstance(tool, matl.tool): + level = tool.level.val + #TODO: double-check drop table quantity + for drop in drop_table.roll(self.realm, level): + assert drop.level.val == level, 'Drop level does not match roll specification' + + self.realm.log_milestone(f'Gather_{drop.__class__.__name__}', + level, f"PROFESSION: Gathered level {level} {drop.__class__.__name__} " + f"(level {self.level.val} {self.__class__.__name__})", + tags={"player_id": entity.ent_id}) + + if entity.inventory.space: + entity.inventory.receive(drop) + self.realm.event_log.record(EventCode.HARVEST_ITEM, entity, item=drop) + + def harvest(self, matl, deplete=True): + entity = self.entity + realm = self.realm + + r, c = entity.pos + if realm.map.tiles[r, c].state != matl: + return False + + drop_table = realm.map.harvest(r, c, deplete) + if drop_table: + self.process_drops(matl, drop_table) + + return drop_table + + def harvest_adjacent(self, matl, deplete=True): + entity = self.entity + realm = self.realm + + r, c = entity.pos + drop_table = None + + if realm.map.tiles[r-1, c].state == matl: + drop_table = 
realm.map.harvest(r-1, c, deplete) + if realm.map.tiles[r+1, c].state == matl: + drop_table = realm.map.harvest(r+1, c, deplete) + if realm.map.tiles[r, c-1].state == matl: + drop_table = realm.map.harvest(r, c-1, deplete) + if realm.map.tiles[r, c+1].state == matl: + drop_table = realm.map.harvest(r, c+1, deplete) + + if drop_table: + self.process_drops(matl, drop_table) + + return drop_table + +class AmmunitionSkill(HarvestSkill): + def process_drops(self, matl, drop_table): + super().process_drops(matl, drop_table) + if self.config.PROGRESSION_SYSTEM_ENABLED: + self.add_xp(self.config.PROGRESSION_AMMUNITION_XP_SCALE) -class ConsumableSkill(HarvestSkill): - def processDrops(self, realm, entity, matl, dropTable): - super().processDrops(realm, entity, matl, dropTable) - self.add_xp(self.config.PROGRESSION_CONSUMABLE_XP_SCALE) +class ConsumableSkill(HarvestSkill): + def process_drops(self, matl, drop_table): + super().process_drops(matl, drop_table) + if self.config.PROGRESSION_SYSTEM_ENABLED: + self.add_xp(self.config.PROGRESSION_CONSUMABLE_XP_SCALE) ### Skill groups ### class Basic(SkillGroup): - def __init__(self, realm, entity): - super().__init__(realm, entity) + def __init__(self, realm, entity): + super().__init__(realm, entity) - self.water = Water(realm, entity, self) - self.food = Food(realm, entity, self) + self.water = Water(self) + self.food = Food(self) - @property - def basicLevel(self): - return 0.5 * (self.water.level - + self.food.level) + @property + def basic_level(self): + return 0.5 * (self.water.level + + self.food.level) class Harvest(SkillGroup): - def __init__(self, realm, entity): - super().__init__(realm, entity) - - self.fishing = Fishing(realm, entity, self) - self.herbalism = Herbalism(realm, entity, self) - self.prospecting = Prospecting(realm, entity, self) - self.carving = Carving(realm, entity, self) - self.alchemy = Alchemy(realm, entity, self) - - @property - def harvestLevel(self): - return max(self.fishing.level, - 
self.herbalism.level, - self.prospecting.level, - self.carving.level, - self.alchemy.level) + def __init__(self, realm, entity): + super().__init__(realm, entity) + + self.fishing = Fishing(self) + self.herbalism = Herbalism(self) + self.prospecting = Prospecting(self) + self.carving = Carving(self) + self.alchemy = Alchemy(self) + + @property + def harvest_level(self): + return max(self.fishing.level, + self.herbalism.level, + self.prospecting.level, + self.carving.level, + self.alchemy.level) class Combat(SkillGroup): - def __init__(self, realm, entity): - super().__init__(realm, entity) + def __init__(self, realm, entity): + super().__init__(realm, entity) - self.melee = Melee(realm, entity, self) - self.range = Range(realm, entity, self) - self.mage = Mage(realm, entity, self) + self.melee = Melee(self) + self.range = Range(self) + self.mage = Mage(self) - def packet(self): - data = super().packet() - data['level'] = combat.level(self) + def packet(self): + data = super().packet() + data['level'] = combat.level(self) - return data + return data - @property - def combatLevel(self): - return max(self.melee.level, - self.range.level, - self.mage.level) + @property + def combat_level(self): + return max(self.melee.level, + self.range.level, + self.mage.level) - def applyDamage(self, dmg, style): - if not self.config.PROGRESSION_SYSTEM_ENABLED: - return - - config = self.config + def apply_damage(self, style): + if self.config.PROGRESSION_SYSTEM_ENABLED: skill = self.__dict__[style] - skill.add_xp(config.PROGRESSION_COMBAT_XP_SCALE) + skill.add_xp(self.config.PROGRESSION_COMBAT_XP_SCALE) - def receiveDamage(self, dmg): - pass + def receive_damage(self, dmg): + pass class Skills(Basic, Harvest, Combat): - pass + pass ### Skills ### class Melee(CombatSkill): - def __init__(self, realm, ent, skillGroup): - self.level = Serialized.Entity.Melee(ent.dataframe, ent.entID) - super().__init__(realm, ent, skillGroup) + SKILL_ID = 1 + + @property + def level(self): + return 
self.entity.melee_level class Range(CombatSkill): - def __init__(self, realm, ent, skillGroup): - self.level = Serialized.Entity.Range(ent.dataframe, ent.entID) - super().__init__(realm, ent, skillGroup) + SKILL_ID = 2 + + @property + def level(self): + return self.entity.range_level class Mage(CombatSkill): - def __init__(self, realm, ent, skillGroup): - self.level = Serialized.Entity.Mage(ent.dataframe, ent.entID) - super().__init__(realm, ent, skillGroup) + SKILL_ID = 3 + + @property + def level(self): + return self.entity.mage_level Melee.weakness = Mage Range.weakness = Melee @@ -216,94 +250,106 @@ def __init__(self, realm, ent, skillGroup): ### Individual Skills ### -class CombatSkill(Skill): pass - class Lvl: - def __init__(self, val): - self.val = val + def __init__(self, val): + self.val = val - def update(self, val): - self.val = val + def update(self, val): + self.val = val class Water(HarvestSkill): - def __init__(self, realm, entity, skillGroup): - self.level = Lvl(1) - super().__init__(realm, entity, skillGroup) + def update(self): + config = self.config + if not config.RESOURCE_SYSTEM_ENABLED: + return - def update(self, realm, entity): - config = self.config - if not config.RESOURCE_SYSTEM_ENABLED: - return + depletion = config.RESOURCE_DEPLETION_RATE + water = self.entity.resources.water + water.decrement(depletion) - depletion = config.RESOURCE_DEPLETION_RATE - water = entity.resources.water - water.decrement(depletion) + if self.config.IMMORTAL: + return - tiles = realm.map.tiles - if not self.harvestAdjacent(realm, entity, material.Water, deplete=False): - return + if not self.harvest_adjacent(material.Water, deplete=False): + return + + restore = np.floor(config.RESOURCE_BASE + * config.RESOURCE_HARVEST_RESTORE_FRACTION) + water.increment(restore) + + self.realm.event_log.record(EventCode.DRINK_WATER, self.entity) - restore = np.floor(config.RESOURCE_BASE - * config.RESOURCE_HARVEST_RESTORE_FRACTION) - water.increment(restore) class 
Food(HarvestSkill): - def __init__(self, realm, entity, skillGroup): - self.level = Lvl(1) - super().__init__(realm, entity, skillGroup) + def __init__(self, skill_group): + self._level = Lvl(1) + super().__init__(skill_group) + + def update(self): + config = self.config + if not config.RESOURCE_SYSTEM_ENABLED: + return - def update(self, realm, entity): - config = self.config - if not config.RESOURCE_SYSTEM_ENABLED: - return + depletion = config.RESOURCE_DEPLETION_RATE + food = self.entity.resources.food + food.decrement(depletion) - depletion = config.RESOURCE_DEPLETION_RATE - food = entity.resources.food - food.decrement(depletion) + if not self.harvest(material.Forest): + return - if not self.harvest(realm, entity, material.Forest): - return + restore = np.floor(config.RESOURCE_BASE + * config.RESOURCE_HARVEST_RESTORE_FRACTION) + food.increment(restore) + + self.realm.event_log.record(EventCode.EAT_FOOD, self.entity) - restore = np.floor(config.RESOURCE_BASE - * config.RESOURCE_HARVEST_RESTORE_FRACTION) - food.increment(restore) class Fishing(ConsumableSkill): - def __init__(self, realm, ent, skillGroup): - self.level = Serialized.Entity.Fishing(ent.dataframe, ent.entID) - super().__init__(realm, ent, skillGroup) + SKILL_ID = 4 + + @property + def level(self): + return self.entity.fishing_level - def update(self, realm, entity): - self.harvestAdjacent(realm, entity, material.Fish) + def update(self): + self.harvest_adjacent(material.Fish) class Herbalism(ConsumableSkill): - def __init__(self, realm, ent, skillGroup): - self.level = Serialized.Entity.Herbalism(ent.dataframe, ent.entID) - super().__init__(realm, ent, skillGroup) + SKILL_ID = 5 - def update(self, realm, entity): - self.harvest(realm, entity, material.Herb) + @property + def level(self): + return self.entity.herbalism_level + + def update(self): + self.harvest(material.Herb) class Prospecting(AmmunitionSkill): - def __init__(self, realm, ent, skillGroup): - self.level = 
Serialized.Entity.Prospecting(ent.dataframe, ent.entID) - super().__init__(realm, ent, skillGroup) + SKILL_ID = 6 + + @property + def level(self): + return self.entity.prospecting_level - def update(self, realm, entity): - self.harvest(realm, entity, material.Ore) + def update(self): + self.harvest(material.Ore) class Carving(AmmunitionSkill): - def __init__(self, realm, ent, skillGroup): - self.level = Serialized.Entity.Carving(ent.dataframe, ent.entID) - super().__init__(realm, ent, skillGroup) + SKILL_ID = 7 - def update(self, realm, entity): - self.harvest(realm, entity, material.Tree) + @property + def level(self): + return self.entity.carving_level + + def update(self,): + self.harvest(material.Tree) class Alchemy(AmmunitionSkill): - def __init__(self, realm, ent, skillGroup): - self.level = Serialized.Entity.Alchemy(ent.dataframe, ent.entID) - super().__init__(realm, ent, skillGroup) + SKILL_ID = 8 + + @property + def level(self): + return self.entity.alchemy_level - def update(self, realm, entity): - self.harvest(realm, entity, material.Crystal) + def update(self): + self.harvest(material.Crystal) diff --git a/nmmo/version.py b/nmmo/version.py index 1e522274f..afced1472 100644 --- a/nmmo/version.py +++ b/nmmo/version.py @@ -1 +1 @@ -__version__ = '1.6.0.7' +__version__ = '2.0.0' diff --git a/offline_dataset.py b/offline_dataset.py index 9dfb25544..38a7a6fb1 100644 --- a/offline_dataset.py +++ b/offline_dataset.py @@ -1,4 +1,5 @@ -from pdb import set_trace as T +# pylint: disable=all + import numpy as np import h5py @@ -42,13 +43,13 @@ def write(self, t, episode, obs=None, atn=None, rewards=None, dones=None): def write_vectorized(self, t, episode, obs=None, atn=None, rewards=None, dones=None): if obs is not None: - self.obs[t, episode_list] = obs + self.obs[t, episode] = obs if atn is not None: - self.atn[t, episode_list] = atn + self.atn[t, episode] = atn if rewards is not None: - self.rewards[t, episode_list] = rewards + self.rewards[t, episode] = rewards 
if dones is not None: - self.dones[t, episode_list] = dones + self.dones[t, episode] = dones EPISODES = 5 HORIZON = 16 diff --git a/quicktest.py b/quicktest.py new file mode 100644 index 000000000..6b8e507f1 --- /dev/null +++ b/quicktest.py @@ -0,0 +1,29 @@ +import nmmo +from nmmo.core.config import Medium + + +def create_config(base, *systems): + systems = (base, *systems) + name = '_'.join(cls.__name__ for cls in systems) + conf = type(name, systems, {})() + + conf.TERRAIN_TRAIN_MAPS = 1 + conf.TERRAIN_EVAL_MAPS = 1 + conf.IMMORTAL = True + conf.RENDER = True + + return conf + +def benchmark_config(base, nent, *systems): + conf = create_config(base, *systems) + conf.PLAYER_N = nent + env = nmmo.Env(conf) + env.reset() + + env.render() + while True: + env.step(actions={}) + env.render() + +benchmark_config(Medium, 100) + diff --git a/run_tests.sh b/run_tests.sh index c33234a01..27305a3e6 100755 --- a/run_tests.sh +++ b/run_tests.sh @@ -1 +1,2 @@ -pytest --benchmark-columns=ops,mean,stddev,min,max,iterations,rounds --benchmark-max-time=5 --benchmark-min-rounds=1 +pytest -rP --benchmark-columns=ops,mean,stddev,min,max,iterations,rounds --benchmark-max-time=5 --benchmark-min-rounds=1 --benchmark-sort=name tests/test_performance.py +#pytest --benchmark-columns=ops,mean,stddev,min,max,iterations,rounds --benchmark-max-time=5 --benchmark-min-rounds=1 --benchmark-sort=name diff --git a/save_atns.py b/save_atns.py index 657417ebc..fd0f69ddc 100644 --- a/save_atns.py +++ b/save_atns.py @@ -5,4 +5,4 @@ config = nmmo.config.Default() env = nmmo.integrations.CleanRLEnv(config, seed=42) actions = [{e: env.action_space(1).sample() for e in range(1, config.PLAYER_N+1)} for _ in range(HORIZON)] -np.save('actions.npy', actions) \ No newline at end of file +np.save('actions.npy', actions) diff --git a/scripted/__init__.py b/scripted/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/scripted/attack.py b/scripted/attack.py new file mode 100644 index 
000000000..0f2c916c0 --- /dev/null +++ b/scripted/attack.py @@ -0,0 +1,53 @@ +# pylint: disable=all + +import numpy as np + +import nmmo +from nmmo.core.observation import Observation +from nmmo.entity.entity import EntityState + +from scripted import utils + +def closestTarget(config, ob: Observation): + shortestDist = np.inf + closestAgent = None + + agent = ob.agent() + + start = (agent.row, agent.col) + + for target in ob.entities.values: + target = EntityState.parse_array(target) + if target.id == agent.id: + continue + + dist = utils.l1(start, (target.row, target.col)) + + if dist < shortestDist and dist != 0: + shortestDist = dist + closestAgent = target + + if closestAgent is None: + return None, None + + return closestAgent, shortestDist + +def attacker(config, ob: Observation): + agent = ob.agent() + + attacker_id = agent.attacker_id + + if attacker_id == 0: + return None, None + + target = ob.entity(attacker_id) + if target == None: + return None, None + + return target, utils.l1((agent.row, agent.col), (target.row, target.col)) + +def target(config, actions, style, targetID): + actions[nmmo.action.Attack] = { + nmmo.action.Style: style, + nmmo.action.Target: targetID} + diff --git a/scripted/baselines.py b/scripted/baselines.py new file mode 100644 index 000000000..376ef21c1 --- /dev/null +++ b/scripted/baselines.py @@ -0,0 +1,517 @@ +# pylint: disable=all + +from typing import Dict + +from collections import defaultdict +import random + +import nmmo +from nmmo import material +from nmmo.systems import skill +import nmmo.systems.item as item_system +from nmmo.lib import colors +from nmmo.io import action +from nmmo.core.observation import Observation + +from scripted import attack, move + +class Scripted(nmmo.Agent): + '''Template class for scripted models. 
+ + You may either subclass directly or mirror the __call__ function''' + scripted = True + color = colors.Neon.SKY + def __init__(self, config, idx): + ''' + Args: + config : A forge.blade.core.Config object or subclass object + ''' + super().__init__(config, idx) + self.health_max = config.PLAYER_BASE_HEALTH + + if config.RESOURCE_SYSTEM_ENABLED: + self.food_max = config.RESOURCE_BASE + self.water_max = config.RESOURCE_BASE + + self.spawnR = None + self.spawnC = None + + @property + def policy(self): + return self.__class__.__name__ + + @property + def forage_criterion(self) -> bool: + '''Return true if low on food or water''' + min_level = 7 * self.config.RESOURCE_DEPLETION_RATE + return self.me.food <= min_level or self.me.water <= min_level + + def forage(self): + '''Min/max food and water using Dijkstra's algorithm''' + move.forageDijkstra(self.config, self.ob, self.actions, self.food_max, self.water_max) + + def gather(self, resource): + '''BFS search for a particular resource''' + return move.gatherBFS(self.config, self.ob, self.actions, resource) + + def explore(self): + '''Route away from spawn''' + move.explore(self.config, self.ob, self.actions, self.me.row, self.me.col) + + @property + def downtime(self): + '''Return true if agent is not occupied with a high-priority action''' + return not self.forage_criterion and self.attacker is None + + def evade(self): + '''Target and path away from an attacker''' + move.evade(self.config, self.ob, self.actions, self.attacker) + self.target = self.attacker + self.targetID = self.attackerID + self.targetDist = self.attackerDist + + def attack(self): + '''Attack the current target''' + if self.target is not None: + assert self.targetID is not None + style = random.choice(self.style) + attack.target(self.config, self.actions, style, self.targetID) + + def target_weak(self): + '''Target the nearest agent if it is weak''' + if self.closest is None: + return False + + selfLevel = self.me.level + targLevel = 
max(self.closest.melee_level, self.closest.range_level, self.closest.mage_level) + + if self.closest.npc_type == 1 or \ + targLevel <= selfLevel <= 5 or \ + selfLevel >= targLevel + 3: + self.target = self.closest + self.targetID = self.closestID + self.targetDist = self.closestDist + + def scan_agents(self): + '''Scan the nearby area for agents''' + self.closest, self.closestDist = attack.closestTarget(self.config, self.ob) + self.attacker, self.attackerDist = attack.attacker(self.config, self.ob) + + self.closestID = None + if self.closest is not None: + self.closestID = self.closest.id + + self.attackerID = None + if self.attacker is not None: + self.attackerID = self.attacker.id + + self.target = None + self.targetID = None + self.targetDist = None + + def adaptive_control_and_targeting(self, explore=True): + '''Balanced foraging, evasion, and exploration''' + self.scan_agents() + + if self.attacker is not None: + self.evade() + return + + if self.fog_criterion: + self.explore() + elif self.forage_criterion or not explore: + self.forage() + else: + self.explore() + + self.target_weak() + + def process_inventory(self): + if not self.config.ITEM_SYSTEM_ENABLED: + return + + self.inventory = {} + self.best_items: Dict = {} + self.item_counts = defaultdict(int) + + self.item_levels = { + item_system.Hat.ITEM_TYPE_ID: self.level, + item_system.Top.ITEM_TYPE_ID: self.level, + item_system.Bottom.ITEM_TYPE_ID: self.level, + item_system.Sword.ITEM_TYPE_ID: self.me.melee_level, + item_system.Bow.ITEM_TYPE_ID: self.me.range_level, + item_system.Wand.ITEM_TYPE_ID: self.me.mage_level, + item_system.Rod.ITEM_TYPE_ID: self.me.fishing_level, + item_system.Gloves.ITEM_TYPE_ID: self.me.herbalism_level, + item_system.Pickaxe.ITEM_TYPE_ID: self.me.prospecting_level, + item_system.Chisel.ITEM_TYPE_ID: self.me.carving_level, + item_system.Arcane.ITEM_TYPE_ID: self.me.alchemy_level, + item_system.Scrap.ITEM_TYPE_ID: self.me.melee_level, + item_system.Shaving.ITEM_TYPE_ID: 
# Methods of the scripted-agent base class (its `class` line is above this chunk).
def upgrade_heuristic(self, current_level, upgrade_level, price):
  '''Score an offer as levels gained per unit of gold.

  Prices below 1 are clamped to 1, so a zero-priced listing cannot cause a
  division by zero.
  '''
  gain = upgrade_level - current_level
  cost = max(price, 1)
  return gain / cost

def process_market(self):
  '''Index the observed market and keep the best-scoring offer per item type.

  Fills ``self.market`` (item id -> item) with every listing, and
  ``self.best_heuristic`` (type id -> item) with the affordable, usable
  listing that has the highest ``upgrade_heuristic`` for its type.
  Does nothing when the exchange system is disabled.
  '''
  if not self.config.EXCHANGE_SYSTEM_ENABLED:
    return

  self.market = {}
  self.best_heuristic = {}

  for row in self.ob.market.values:
    offer = item_system.ItemState.parse_array(row)
    kind = offer.type_id

    # Every listing is recorded, even ones filtered out below
    self.market[offer.id] = offer

    # Skip what we cannot pay for
    if offer.listed_price > self.me.gold:
      continue

    # Skip what we are too low-level to equip
    if kind in self.item_levels and offer.level > self.item_levels[kind]:
      continue

    # Level of the best item of this type already owned (0 if none)
    owned = self.best_items.get(kind)
    base_level = 0 if owned is None else owned.level

    offer.heuristic = self.upgrade_heuristic(base_level, offer.level, offer.listed_price)

    # First candidate of a type wins by default; later ones must score strictly higher
    incumbent = self.best_heuristic.get(kind)
    if incumbent is None or offer.heuristic > incumbent.heuristic:
      self.best_heuristic[kind] = offer
# Methods of the scripted-agent base class (its `class` line is above this chunk).
def consume(self):
  '''Queue a Use action for a poultice or a ration when the agent needs one.

  A poultice is chosen when health is at or below half of ``health_max``;
  otherwise a ration is chosen when food or water is exactly 0. Items that
  are currently listed for sale are never used.

  Returns:
    True if a Use action was written to ``self.actions``, False otherwise.
    (Previously this always returned None, so a caller guarding on
    ``not self.consume()`` could not tell whether a consumable was used and
    would overwrite the queued Use action.)
  '''
  poultice = item_system.Poultice.ITEM_TYPE_ID
  ration = item_system.Ration.ITEM_TYPE_ID

  if self.me.health <= self.health_max // 2 and poultice in self.best_items:
    itm = self.best_items[poultice]
  elif (self.me.food == 0 or self.me.water == 0) and ration in self.best_items:
    itm = self.best_items[ration]
  else:
    return False

  if itm.listed_price:
    return False

  # InventoryItem needs where the item is (index) in the inventory
  self.actions[action.Use] = {
    action.InventoryItem: self.ob.inventory.index(itm.id)}
  return True

def sell(self, keep_k: dict, keep_best: set):
  '''List the first sellable inventory item on the market.

  Args:
    keep_k: type id -> number of units of that type to hold on to.
    keep_best: type ids whose best owned item must not be sold.

  Returns:
    The item a Sell action was queued for, or None if nothing qualified.
  '''
  for itm in self.inventory.values():
    assert itm.quantity > 0

    if itm.equipped or itm.listed_price:
      continue

    # Keep at least k units of stockpiled item types
    if itm.type_id in keep_k:
      owned = self.item_counts[itm.type_id]
      if owned <= keep_k[itm.type_id]:
        continue

    # Exists an equippable of the current class, best needs to be kept,
    # and this is the best item
    if itm.type_id in self.best_items and \
       itm.type_id in keep_best and \
       itm.id == self.best_items[itm.type_id].id:
      continue

    # Asking price is the item level, floored at 1
    price = int(max(itm.level, 1))
    self.actions[action.Sell] = {
      action.InventoryItem: self.ob.inventory.index(itm.id),
      action.Price: action.Price.edges[price-1]}  # Price starts from 1

    return itm

  return None
# Methods of the scripted-agent base class (its `class` line is above this chunk).
def exchange(self):
  '''Refresh the market view, then attempt one sale and one purchase.

  Uses ``self.supplies`` and ``self.wishlist``, which the Combat and Gather
  subclasses define. Does nothing when the exchange system is disabled.
  '''
  if not self.config.EXCHANGE_SYSTEM_ENABLED:
    return

  self.process_market()
  self.sell(keep_k=self.supplies, keep_best=self.wishlist)
  self.buy(buy_k=self.supplies, buy_upgrade=self.wishlist)

def use(self):
  '''Re-index the inventory, then consume a supply or equip wishlist gear.

  ``equip`` runs only when the equipment system is enabled and ``consume``
  returned a falsy value.
  '''
  self.process_inventory()
  if self.config.EQUIPMENT_SYSTEM_ENABLED and not self.consume():
    self.equip(items=self.wishlist)

def __call__(self, observation: Observation):
  '''Process observations and return actions

  Resets ``self.actions`` and caches the per-tick state (``ob``, ``me``,
  ``skills``, ``level``, spawn position, ``fog_criterion``) that the
  behavior methods read. Subclasses call this first and then fill in
  ``self.actions``.
  '''
  self.actions = {}

  self.ob = observation
  self.me = observation.agent()

  # combat level
  self.me.level = max(self.me.melee_level, self.me.range_level, self.me.mage_level)

  self.skills = {
    skill.Melee: self.me.melee_level,
    skill.Range: self.me.range_level,
    skill.Mage: self.me.mage_level,
    skill.Fishing: self.me.fishing_level,
    skill.Herbalism: self.me.herbalism_level,
    skill.Prospecting: self.me.prospecting_level,
    skill.Carving: self.me.carving_level,
    skill.Alchemy: self.me.alchemy_level
  }

  # TODO(kywch): need a consistent level variables
  # level for using armor, rations, and poultice:
  # the highest skill level, floored at 1. (This used min(1, ...), which
  # capped the value at 1 and made every item above level 1 look unusable.)
  self.level = max(1, max(self.skills.values()))

  # Remember where the agent was first observed
  if self.spawnR is None:
    self.spawnR = self.me.row
  if self.spawnC is None:
    self.spawnC = self.me.col

  # When to run from death fog in BR configs
  # NOTE(review): self.time_alive is not set anywhere in this chunk;
  # presumably provided by the base class -- confirm.
  self.fog_criterion = None
  if self.config.PLAYER_DEATH_FOG is not None:
    start_running = self.time_alive > self.config.PLAYER_DEATH_FOG - 64
    run_now = self.time_alive % max(1, int(1 / self.config.PLAYER_DEATH_FOG_SPEED))
    self.fog_criterion = start_running and run_now
'''Moves randomly''' + def __call__(self, obs): + super().__call__(obs) + + move.rand(self.config, self.ob, self.actions) + return self.actions + +class Meander(Scripted): + '''Moves randomly on safe terrain''' + def __call__(self, obs): + super().__call__(obs) + + move.meander(self.config, self.ob, self.actions) + return self.actions + +class Explore(Scripted): + '''Actively explores towards the center''' + def __call__(self, obs): + super().__call__(obs) + + self.explore() + + return self.actions + +class Forage(Scripted): + '''Forages using Dijkstra's algorithm and actively explores''' + def __call__(self, obs): + super().__call__(obs) + + if self.forage_criterion: + self.forage() + else: + self.explore() + + return self.actions + +class Combat(Scripted): + '''Forages, fights, and explores''' + def __init__(self, config, idx): + super().__init__(config, idx) + self.style = [action.Melee, action.Range, action.Mage] + + @property + def supplies(self): + return { + item_system.Ration.ITEM_TYPE_ID: 2, + item_system.Poultice.ITEM_TYPE_ID: 2, + self.ammo.ITEM_TYPE_ID: 10 + } + + @property + def wishlist(self): + return { + item_system.Hat.ITEM_TYPE_ID, + item_system.Top.ITEM_TYPE_ID, + item_system.Bottom.ITEM_TYPE_ID, + self.weapon.ITEM_TYPE_ID, + self.ammo.ITEM_TYPE_ID + } + + def __call__(self, obs): + super().__call__(obs) + self.use() + self.exchange() + + self.adaptive_control_and_targeting() + self.attack() + + return self.actions + +class Gather(Scripted): + '''Forages, fights, and explores''' + def __init__(self, config, idx): + super().__init__(config, idx) + self.resource = [material.Fish, material.Herb, material.Ore, material.Tree, material.Crystal] + + @property + def supplies(self): + return { + item_system.Ration.ITEM_TYPE_ID: 1, + item_system.Poultice.ITEM_TYPE_ID: 1 + } + + @property + def wishlist(self): + return { + item_system.Hat.ITEM_TYPE_ID, + item_system.Top.ITEM_TYPE_ID, + item_system.Bottom.ITEM_TYPE_ID, + self.tool.ITEM_TYPE_ID + } + + def 
__call__(self, obs): + super().__call__(obs) + self.use() + self.exchange() + + if self.forage_criterion: + self.forage() + elif self.fog_criterion or not self.gather(self.resource): + self.explore() + + return self.actions + +class Fisher(Gather): + def __init__(self, config, idx): + super().__init__(config, idx) + if config.SPECIALIZE: + self.resource = [material.Fish] + self.tool = item_system.Rod + +class Herbalist(Gather): + def __init__(self, config, idx): + super().__init__(config, idx) + if config.SPECIALIZE: + self.resource = [material.Herb] + self.tool = item_system.Gloves + +class Prospector(Gather): + def __init__(self, config, idx): + super().__init__(config, idx) + if config.SPECIALIZE: + self.resource = [material.Ore] + self.tool = item_system.Pickaxe + +class Carver(Gather): + def __init__(self, config, idx): + super().__init__(config, idx) + if config.SPECIALIZE: + self.resource = [material.Tree] + self.tool = item_system.Chisel + +class Alchemist(Gather): + def __init__(self, config, idx): + super().__init__(config, idx) + if config.SPECIALIZE: + self.resource = [material.Crystal] + self.tool = item_system.Arcane + +class Melee(Combat): + def __init__(self, config, idx): + super().__init__(config, idx) + if config.SPECIALIZE: + self.style = [action.Melee] + self.weapon = item_system.Sword + self.ammo = item_system.Scrap + +class Range(Combat): + def __init__(self, config, idx): + super().__init__(config, idx) + if config.SPECIALIZE: + self.style = [action.Range] + self.weapon = item_system.Bow + self.ammo = item_system.Shaving + +class Mage(Combat): + def __init__(self, config, idx): + super().__init__(config, idx) + if config.SPECIALIZE: + self.style = [action.Mage] + self.weapon = item_system.Wand + self.ammo = item_system.Shard diff --git a/scripted/behavior.py b/scripted/behavior.py new file mode 100644 index 000000000..c2d8753c2 --- /dev/null +++ b/scripted/behavior.py @@ -0,0 +1,62 @@ +# pylint: disable=all + +import nmmo +from 
nmmo.systems.ai import move, attack, utils + +def update(entity): + '''Update validity of tracked entities''' + if not utils.validTarget(entity, entity.attacker, entity.vision): + entity.attacker = None + if not utils.validTarget(entity, entity.target, entity.vision): + entity.target = None + if not utils.validTarget(entity, entity.closest, entity.vision): + entity.closest = None + + if entity.__class__.__name__ != 'Player': + return + + if not utils.validResource(entity, entity.food, entity.vision): + entity.food = None + if not utils.validResource(entity, entity.water, entity.vision): + entity.water = None + +def pathfind(config, ob, actions, rr, cc): + actions[nmmo.action.Move] = {nmmo.action.Direction: move.pathfind(config, ob, actions, rr, cc)} + +def meander(realm, actions, entity): + actions[nmmo.action.Move] = {nmmo.action.Direction: move.habitable(realm.map.tiles, entity)} + +def evade(realm, actions, entity): + actions[nmmo.action.Move] = {nmmo.action.Direction: move.antipathfind(realm.map.tiles, entity, entity.attacker)} + +def hunt(realm, actions, entity): + #Move args + distance = utils.distance(entity, entity.target) + + direction = None + if distance == 0: + direction = move.random_direction() + elif distance > 1: + direction = move.pathfind(realm.map.tiles, entity, entity.target) + + if direction is not None: + actions[nmmo.action.Move] = {nmmo.action.Direction: direction} + + attack(realm, actions, entity) + +def attack(realm, actions, entity): + distance = utils.distance(entity, entity.target) + if distance > entity.skills.style.attack_range(realm.config): + return + + actions[nmmo.action.Attack] = {nmmo.action.Style: entity.skills.style, + nmmo.action.Target: entity.target} + +def forageDP(realm, actions, entity): + direction = utils.forageDP(realm.map.tiles, entity) + actions[nmmo.action.Move] = {nmmo.action.Direction: move.towards(direction)} + +#def forageDijkstra(realm, actions, entity): +def forageDijkstra(config, ob, actions, food_max, 
water_max): + direction = utils.forageDijkstra(config, ob, food_max, water_max) + actions[nmmo.action.Move] = {nmmo.action.Direction: move.towards(direction)} diff --git a/scripted/move.py b/scripted/move.py new file mode 100644 index 000000000..3fbec0e10 --- /dev/null +++ b/scripted/move.py @@ -0,0 +1,320 @@ +# pylint: disable=all + +import numpy as np +import random + +import heapq + +from nmmo.io import action +from nmmo.core.observation import Observation +from nmmo.lib import material + +from scripted import utils + +def adjacentPos(pos): + r, c = pos + return [(r - 1, c), (r, c - 1), (r + 1, c), (r, c + 1)] + +def inSight(dr, dc, vision): + return ( + dr >= -vision and + dc >= -vision and + dr <= vision and + dc <= vision) + +def rand(config, ob, actions): + direction = random.choice(action.Direction.edges) + actions[action.Move] = {action.Direction: direction} + +def towards(direction): + if direction == (-1, 0): + return action.North + elif direction == (1, 0): + return action.South + elif direction == (0, -1): + return action.West + elif direction == (0, 1): + return action.East + else: + return random.choice(action.Direction.edges) + +def pathfind(config, ob, actions, rr, cc): + direction = aStar(config, ob, actions, rr, cc) + direction = towards(direction) + actions[action.Move] = {action.Direction: direction} + +def meander(config, ob, actions): + cands = [] + if ob.tile(-1, 0).material_id in material.Habitable: + cands.append((-1, 0)) + if ob.tile(1, 0).material_id in material.Habitable: + cands.append((1, 0)) + if ob.tile(0, -1).material_id in material.Habitable: + cands.append((0, -1)) + if ob.tile(0, 1).material_id in material.Habitable: + cands.append((0, 1)) + if not cands: + return (-1, 0) + + direction = random.choices(cands)[0] + direction = towards(direction) + actions[action.Move] = {action.Direction: direction} + +def explore(config, ob, actions, r, c): + vision = config.PLAYER_VISION_RADIUS + sz = config.MAP_SIZE + + centR, centC = sz//2, 
def evade(config, ob: Observation, actions, attacker):
  '''Move directly away from the attacker.

  pathfind/aStar take the goal as an offset from the agent, which sits at
  (0, 0). The point mirrored across the agent from the attacker is
  therefore ``agent - attacker`` in that frame. (The previous
  ``2*agent - attacker`` was the same point in absolute map coordinates,
  which made the flee direction independent of where the attacker was.)
  '''
  agent = ob.agent()

  rr = agent.row - attacker.row
  cc = agent.col - attacker.col

  pathfind(config, ob, actions, rr, cc)

def forageDijkstra(config, ob: Observation, actions, food_max, water_max, cutoff=100):
  '''Step toward the visible tile that maximises min(food, water) on arrival.

  Breadth-first expansion over habitable tiles within the vision radius,
  simulating one unit of food and water drained per step, half of
  ``food_max`` restored on forest tiles and half of ``water_max`` restored
  next to water. Ties on min(food, water) are broken by the larger
  max(food, water). Writes a Move action into ``actions``.

  Args:
    cutoff: maximum number of tiles expanded before settling for the best
      tile seen so far.
  '''
  from collections import deque  # local import keeps this edit self-contained

  vision = config.PLAYER_VISION_RADIUS
  agent = ob.agent()

  start = (0, 0)
  goal = start
  best = -1000

  reward = {start: (agent.food, agent.water)}
  backtrace = {start: None}

  queue = deque([start])

  while queue:
    cutoff -= 1
    if cutoff <= 0:
      break

    cur = queue.popleft()
    for nxt in adjacentPos(cur):
      if nxt in backtrace:
        continue

      if not inSight(*nxt, vision):
        continue

      matl = ob.tile(*nxt).material_id
      if matl not in material.Habitable:
        continue

      # Walking one tile costs one food and one water
      food, water = reward[cur]
      food = max(0, food - 1)
      water = max(0, water - 1)

      if matl == material.Forest.index:
        food = min(food + food_max//2, food_max)

      # Standing next to water refills water
      for pos in adjacentPos(nxt):
        if not inSight(*pos, vision):
          continue

        if ob.tile(*pos).material_id == material.Water.index:
          water = min(water + water_max//2, water_max)
          break

      reward[nxt] = (food, water)

      total = min(food, water)
      if total > best or (
          total == best and max(food, water) > max(reward[goal])):
        best = total
        goal = nxt

      queue.append(nxt)
      backtrace[nxt] = cur

  # Walk back to the first step of the path. If no tile beat the start,
  # goal ends up as None and towards() falls back to a random direction.
  while goal in backtrace and backtrace[goal] != start:
    goal = backtrace[goal]

  direction = towards(goal)
  actions[action.Move] = {action.Direction: direction}
def gatherAStar(config, ob, actions, resource, cutoff=100):
  '''Move one step along an A* path to the first visible tile of ``resource``.

  Returns:
    True if a Move action was written to ``actions``, False if no such tile
    is visible or A* produced no usable first step.
  '''
  resource_pos = findResource(config, ob, resource)
  if not resource_pos:
    return False

  rr, cc = resource_pos
  next_pos = aStar(config, ob, actions, rr, cc, cutoff=cutoff)
  if not next_pos or next_pos == (0, 0):
    return False

  direction = towards(next_pos)
  actions[action.Move] = {action.Direction: direction}
  return True

def gatherBFS(config, ob: Observation, actions, resource, cutoff=100):
  '''Breadth-first search for the nearest tile of any material in ``resource``.

  Args:
    resource: collection of material classes to look for.
    cutoff: maximum number of tiles expanded before giving up.

  Returns:
    True if a Move action toward a matching tile was written to
    ``actions``, False if none was found within the budget.

  The search stops at the first hit. (Previously the outer loop kept
  draining the queue after a hit and could run out of ``cutoff`` and return
  False even though a resource had been found.)
  '''
  from collections import deque  # local import keeps this edit self-contained

  vision = config.PLAYER_VISION_RADIUS

  start = (0, 0)
  backtrace = {start: None}
  queue = deque([start])

  wants_fish = material.Fish in resource
  wanted = tuple(e.index for e in resource)

  found = None
  while queue and found is None:
    cutoff -= 1
    if cutoff <= 0:
      return False

    cur = queue.popleft()
    for nxt in adjacentPos(cur):
      if nxt in backtrace:
        continue

      if not inSight(*nxt, vision):
        continue

      matl = ob.tile(*nxt).material_id

      # Fish is matched before the habitability filter below
      if wants_fish and matl == material.Fish.index:
        found = nxt
        backtrace[nxt] = cur
        break

      if matl not in material.Habitable:
        continue

      if matl in wanted:
        found = nxt
        backtrace[nxt] = cur
        break

      queue.append(nxt)
      backtrace[nxt] = cur

  # Ran out of tiles
  if found is None:
    return False

  # Walk back to the first step of the path
  step = found
  while step in backtrace and backtrace[step] != start:
    step = backtrace[step]

  direction = towards(step)
  actions[action.Move] = {action.Direction: direction}

  return True
def l1(start, goal):
  '''Manhattan distance between two (row, col) positions.'''
  d_row = goal[0] - start[0]
  d_col = goal[1] - start[1]
  return abs(d_row) + abs(d_col)

def l2(start, goal):
  '''Half of the Euclidean distance between two (row, col) positions.

  NOTE(review): the 0.5 factor is part of the existing behaviour and is
  kept as-is; confirm whether halving is intended.
  '''
  d_row = goal[0] - start[0]
  d_col = goal[1] - start[1]
  return 0.5*(d_row**2 + d_col**2)**0.5

def lInfty(start, goal):
  '''Chebyshev distance between two (row, col) positions.'''
  d_row = goal[0] - start[0]
  d_col = goal[1] - start[1]
  return max(abs(d_row), abs(d_col))

def adjacentPos(pos):
  '''The four neighbours of ``pos``, in the order up, left, down, right.'''
  row, col = pos
  up = (row - 1, col)
  left = (row, col - 1)
  down = (row + 1, col)
  right = (row, col + 1)
  return [up, left, down, right]

def adjacentDeltas():
  '''Unit offsets to the four neighbours, in the order up, down, right, left.'''
  up, down = (-1, 0), (1, 0)
  right, left = (0, 1), (0, -1)
  return [up, down, right, left]

def inSight(dr, dc, vision):
  '''True when offset (dr, dc) lies in the square of half-width ``vision``.'''
  rows_ok = dr >= -vision
  cols_ok = dc >= -vision
  return rows_ok and cols_ok and dr <= vision and dc <= vision
#!/bin/bash
# git-pr: squash the current topic branch onto the master branch, run the
# pre-git checks, push the result as a fresh branch and print a compare URL.
MASTER_BRANCH="v1.6"

# check if in master branch
# (the comparison previously used the literal word MASTER_BRANCH, so this
# guard could never trigger)
current_branch=$(git rev-parse --abbrev-ref HEAD)
if [ "$current_branch" = "$MASTER_BRANCH" ]; then
  echo "Please run 'git pr' from a topic branch."
  exit 1
fi

# check if there are any uncommitted changes
git_status=$(git status --porcelain)

if [ -n "$git_status" ]; then
  read -r -p "Uncommitted changes found. Commit before running 'git pr'? (y/n) " ans
  if [ "$ans" = "y" ]; then
    # -a must come before -m: "-m -a <text>" made "-a" the commit message
    # and treated the text as a pathspec.
    # NOTE: -a stages tracked files only; untracked files are left alone.
    git commit -a -m "Automatic commit for git-pr"
  else
    echo "Please commit or stash changes before running 'git pr'."
    exit 1
  fi
fi

# Merging master
echo "Merging master..."
git merge "origin/$MASTER_BRANCH"

# Checking pylint, xcxc, pytest without touching git
# (take the first match so that a second copy of the script cannot turn the
# path into a multi-line string)
PRE_GIT_CHECK=$(find . -name pre-git-check.sh | head -n 1)
if test -f "$PRE_GIT_CHECK"; then
  if ! "$PRE_GIT_CHECK"; then
    echo "pre-git-check.sh failed. Exiting."
    exit 1
  fi
else
  echo "Missing pre-git-check.sh. Exiting."
  exit 1
fi

# create a new branch from current branch and reset to master
echo "Creating and switching to new topic branch..."
git_user=$(git config user.email | cut -d'@' -f1)
branch_name="${git_user}-git-pr-$RANDOM-$RANDOM"
git checkout -b "$branch_name"
git reset --soft "origin/$MASTER_BRANCH"

# Verify that a commit message was added
echo "Verifying commit message..."
if ! git commit -a ; then
  echo "Commit message is empty. Exiting."
  exit 1
fi

# Push the topic branch to origin
echo "Pushing topic branch to origin..."
git push -u origin "$branch_name"

# Generate a Github pull request (just the url, not actually making a PR)
echo "Generating Github pull request..."
+pull_request_url="https://github.com/CarperAI/nmmo-environment/compare/$MASTER_BRANCH...CarperAI:nmmo-environment:$branch_name?expand=1" + +echo "Pull request URL: $pull_request_url" diff --git a/scripts/pre-git-check.sh b/scripts/pre-git-check.sh new file mode 100755 index 000000000..29376bea5 --- /dev/null +++ b/scripts/pre-git-check.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +echo +echo "Checking pylint, xcxc, pytest without touching git" +echo + +# Run linter +echo "--------------------------------------------------------------------" +echo "Running linter..." +files=$(git ls-files -m -o --exclude-standard '*.py') +for file in $files; do + if test -e $file; then + echo $file + if ! pylint --score=no --fail-under=10 $file; then + echo "Lint failed. Exiting." + exit 1 + fi + fi +done + +if ! pylint --recursive=y nmmo tests; then + echo "Lint failed. Exiting." + exit 1 +fi + +# Check if there are any "xcxc" strings in the code +echo "--------------------------------------------------------------------" +echo "Looking for xcxc..." +files=$(find . -name '*.py') +for file in $files; do + if grep -q 'xcxc' $file; then + echo "Found xcxc in $file!" >&2 + read -p "Do you like to stop here? (y/n) " ans + if [ "$ans" = "y" ]; then + exit 1 + fi + fi +done + +# Run unit tests +echo +echo "--------------------------------------------------------------------" +echo "Running unit tests..." +if ! pytest; then + echo "Unit tests failed. Exiting." + exit 1 +fi + +echo +echo "Pre-git checks look good!" 
+echo \ No newline at end of file diff --git a/scripts/requirements.txt b/scripts/requirements.txt index 3d754363f..2e6d30742 100644 --- a/scripts/requirements.txt +++ b/scripts/requirements.txt @@ -4,6 +4,7 @@ fire==0.4.0 gym==0.17.2 imageio==2.8.0 lovely-tensors==0.1.8 +lovely-numpy==0.1.8 matplotlib==3.1.3 numpy==1.21.1 pettingzoo==1.13.1 diff --git a/setup.py b/setup.py index e28a33196..9bdcaf56a 100644 --- a/setup.py +++ b/setup.py @@ -1,26 +1,18 @@ -from pdb import set_trace as T from itertools import chain from setuptools import find_packages, setup - REPO_URL = "https://github.com/neuralmmo/environment" extra = { 'docs': [ 'sphinx-rtd-theme==0.5.1', 'sphinxcontrib-youtube==1.0.1', - ], - 'cleanrl': [ - 'wandb==0.12.9', - 'supersuit==3.3.5', - 'pettingzoo==1.15.0', - 'gym==0.23.0', - 'tensorboard', - 'torch', - 'openskill', - ], - } + 'myst-parser==1.0.0', + 'sphinx-rtd-theme==0.5.1', + 'sphinx_design==0.4.1', + ], +} extra['all'] = list(set(chain.from_iterable(extra.values()))) @@ -33,19 +25,21 @@ packages=find_packages(), include_package_data=True, install_requires=[ + 'numpy==1.23.3', + 'scipy==1.10.0', + 'pytest==7.3.0', 'pytest-benchmark==3.4.1', 'fire==0.4.0', - 'setproctitle==1.1.10', - 'service-identity==21.1.0', 'autobahn==19.3.3', 'Twisted==19.2.0', 'vec-noise==1.1.4', - 'imageio==2.8.0', + 'imageio==2.23.0', 'tqdm==4.61.1', - 'lz4==4.0.0', 'h5py==3.7.0', - 'pettingzoo', - 'ordered-set', + 'pettingzoo==1.19.0', + 'gym==0.23.0', + 'pylint==2.16.0', + 'py==1.11.0' ], extras_require=extra, python_requires=">=3.7", @@ -60,7 +54,11 @@ "Intended Audience :: Developers", "Environment :: Console", "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", ], ) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff 
--git a/tests/action/test_ammo_use.py b/tests/action/test_ammo_use.py new file mode 100644 index 000000000..8d699dcf6 --- /dev/null +++ b/tests/action/test_ammo_use.py @@ -0,0 +1,361 @@ +import unittest +import logging +import numpy as np + +from tests.testhelpers import ScriptedTestTemplate, provide_item + +from nmmo.io import action +from nmmo.systems import item as Item +from nmmo.systems.item import ItemState + +RANDOM_SEED = 284 + +LOGFILE = 'tests/action/test_ammo_use.log' + +class TestAmmoUse(ScriptedTestTemplate): + # pylint: disable=protected-access,multiple-statements + + @classmethod + def setUpClass(cls): + super().setUpClass() + + # config specific to the tests here + cls.config.LOG_VERBOSE = False + if cls.config.LOG_VERBOSE: + logging.basicConfig(filename=LOGFILE, level=logging.INFO) + + def _assert_action_targets_zero(self, gym_obs): + mask = np.sum(gym_obs['ActionTargets'][action.GiveGold][action.Price]) \ + + np.sum(gym_obs['ActionTargets'][action.Buy][action.MarketItem]) + for atn in [action.Use, action.Give, action.Destroy, action.Sell]: + mask += np.sum(gym_obs['ActionTargets'][atn][action.InventoryItem]) + self.assertEqual(mask, 0) + + def test_ammo_fire_all(self): + env = self._setup_env(random_seed=RANDOM_SEED) + + # First tick actions: USE (equip) level-0 ammo + env.step({ ent_id: { action.Use: + { action.InventoryItem: env.obs[ent_id].inventory.sig(ent_ammo, 0) } + } for ent_id, ent_ammo in self.ammo.items() }) + + # check if the agents have equipped the ammo + for ent_id, ent_ammo in self.ammo.items(): + gym_obs = env.obs[ent_id].to_gym() + inventory = env.obs[ent_id].inventory + inv_idx = inventory.sig(ent_ammo, 0) + self.assertEqual(1, # True + ItemState.parse_array(inventory.values[inv_idx]).equipped) + + # check SELL InventoryItem mask -- one cannot sell equipped item + mask = gym_obs['ActionTargets'][action.Sell][action.InventoryItem][:inventory.len] > 0 + self.assertTrue(inventory.id(inv_idx) not in inventory.ids[mask]) + + # the 
agents must not be in combat status + self.assertFalse(env.realm.players[ent_id].in_combat) + + # Second tick actions: ATTACK other agents using ammo + # NOTE that agents 1 & 3's attack are invalid due to out-of-range + env.step({ ent_id: { action.Attack: + { action.Style: env.realm.players[ent_id].agent.style[0], + action.Target: (ent_id+1)%3+1 } } + for ent_id in self.ammo }) + + # check combat status: agents 2 (attacker) and 1 (target) are in combat + self.assertTrue(env.realm.players[2].in_combat) + self.assertTrue(env.realm.players[1].in_combat) + self.assertFalse(env.realm.players[3].in_combat) + + # check the action masks are all 0 during combat + for ent_id in [1, 2]: + self._assert_action_targets_zero(env.obs[ent_id].to_gym()) + + # check if the ammos were consumed + ammo_ids = [] + for ent_id, ent_ammo in self.ammo.items(): + inventory = env.obs[ent_id].inventory + inv_idx = inventory.sig(ent_ammo, 0) + item_info = ItemState.parse_array(inventory.values[inv_idx]) + if ent_id == 2: + # only agent 2's attack is valid and consume ammo + self.assertEqual(self.ammo_quantity - 1, item_info.quantity) + ammo_ids.append(inventory.id(inv_idx)) + else: + self.assertEqual(self.ammo_quantity, item_info.quantity) + + # Third tick actions: ATTACK again to use up all the ammo, except agent 3 + # NOTE that agent 3's attack command is invalid due to out-of-range + env.step({ ent_id: { action.Attack: + { action.Style: env.realm.players[ent_id].agent.style[0], + action.Target: (ent_id+1)%3+1 } } + for ent_id in self.ammo }) + + # agents 1 and 2's latest_combat_tick should be updated + self.assertEqual(env.realm.tick, env.realm.players[1].latest_combat_tick.val) + self.assertEqual(env.realm.tick, env.realm.players[2].latest_combat_tick.val) + self.assertEqual(0, env.realm.players[3].latest_combat_tick.val) + + # check if the ammos are depleted and the ammo slot is empty + ent_id = 2 + self.assertTrue(env.obs[ent_id].inventory.len == len(self.item_sig[ent_id]) - 1) + 
self.assertTrue(env.realm.players[ent_id].inventory.equipment.ammunition.item is None) + + for item_id in ammo_ids: + self.assertTrue(len(ItemState.Query.by_id(env.realm.datastore, item_id)) == 0) + self.assertTrue(item_id not in env.realm.items) + + # invalid attacks + for ent_id in [1, 3]: + # agent 3 gathered shaving, so the item count increased + #self.assertTrue(env.obs[ent_id].inventory.len == len(self.item_sig[ent_id])) + self.assertTrue(env.realm.players[ent_id].inventory.equipment.ammunition.item is not None) + + # after 3 ticks, combat status should be cleared + for _ in range(3): + env.step({ 0:0 }) # put dummy actions to prevent generating scripted actions + + for ent_id in [1, 2, 3]: + self.assertFalse(env.realm.players[ent_id].in_combat) + + # DONE + + def test_cannot_use_listed_items(self): + env = self._setup_env(random_seed=RANDOM_SEED) + + sell_price = 1 + + # provide extra scrap to range to make its inventory full + # but level-0 scrap overlaps with the listed item + ent_id = 2 + provide_item(env.realm, ent_id, Item.Scrap, level=0, quantity=3) + provide_item(env.realm, ent_id, Item.Scrap, level=1, quantity=3) + + # provide extra scrap to mage to make its inventory full + # there will be no overlapping item + ent_id = 3 + provide_item(env.realm, ent_id, Item.Scrap, level=5, quantity=3) + provide_item(env.realm, ent_id, Item.Scrap, level=7, quantity=3) + env.obs = env._compute_observations() + + # First tick actions: SELL level-0 ammo + env.step({ ent_id: { action.Sell: + { action.InventoryItem: env.obs[ent_id].inventory.sig(ent_ammo, 0), + action.Price: sell_price } } + for ent_id, ent_ammo in self.ammo.items() }) + + # check if the ammos were listed + for ent_id, ent_ammo in self.ammo.items(): + gym_obs = env.obs[ent_id].to_gym() + inventory = env.obs[ent_id].inventory + inv_idx = inventory.sig(ent_ammo, 0) + item_info = ItemState.parse_array(inventory.values[inv_idx]) + # ItemState data + self.assertEqual(sell_price, item_info.listed_price) + # 
Exchange listing + self.assertTrue(item_info.id in env.realm.exchange._item_listings) + self.assertTrue(item_info.id in env.obs[ent_id].market.ids) + + # check SELL InventoryItem mask -- one cannot sell listed item + mask = gym_obs['ActionTargets'][action.Sell][action.InventoryItem][:inventory.len] > 0 + self.assertTrue(inventory.id(inv_idx) not in inventory.ids[mask]) + + # check USE InventoryItem mask -- one cannot use listed item + mask = gym_obs['ActionTargets'][action.Use][action.InventoryItem][:inventory.len] > 0 + self.assertTrue(inventory.id(inv_idx) not in inventory.ids[mask]) + + # check BUY MarketItem mask -- there should be two ammo items in the market + mask = gym_obs['ActionTargets'][action.Buy][action.MarketItem][:inventory.len] > 0 + # agent 1 has inventory space + if ent_id == 1: self.assertTrue(sum(mask) == 2) + # agent 2's inventory is full but can buy level-0 scrap (existing ammo) + if ent_id == 2: self.assertTrue(sum(mask) == 1) + # agent 3's inventory is full without overlapping ammo + if ent_id == 3: self.assertTrue(sum(mask) == 0) + + # Second tick actions: USE ammo, which should NOT happen + env.step({ ent_id: { action.Use: + { action.InventoryItem: env.obs[ent_id].inventory.sig(ent_ammo, 0) } + } for ent_id, ent_ammo in self.ammo.items() }) + + # check if the agents have equipped the ammo + for ent_id, ent_ammo in self.ammo.items(): + inventory = env.obs[ent_id].inventory + inv_idx = inventory.sig(ent_ammo, 0) + self.assertEqual(0, # False + ItemState.parse_array(inventory.values[inv_idx]).equipped) + + # DONE + + def test_receive_extra_ammo_swap(self): + env = self._setup_env(random_seed=RANDOM_SEED) + + extra_ammo = 500 + scrap_lvl0 = (Item.Scrap, 0) + scrap_lvl1 = (Item.Scrap, 1) + scrap_lvl3 = (Item.Scrap, 3) + + def sig_int_tuple(sig): + return (sig[0].ITEM_TYPE_ID, sig[1]) + + for ent_id in self.policy: + # provide extra scrap + provide_item(env.realm, ent_id, Item.Scrap, level=0, quantity=extra_ammo) + provide_item(env.realm, 
ent_id, Item.Scrap, level=1, quantity=extra_ammo) + + # level up the agent 1 (Melee) to 2 + env.realm.players[1].skills.melee.level.update(2) + env.obs = env._compute_observations() + + # check inventory + for ent_id in self.ammo: + # realm data + inv_realm = { item.signature: item.quantity.val + for item in env.realm.players[ent_id].inventory.items + if isinstance(item, Item.Stack) } + self.assertTrue( sig_int_tuple(scrap_lvl0) in inv_realm ) + self.assertTrue( sig_int_tuple(scrap_lvl1) in inv_realm ) + self.assertEqual( inv_realm[sig_int_tuple(scrap_lvl1)], extra_ammo ) + + # item datastore + inv_obs = env.obs[ent_id].inventory + self.assertTrue(inv_obs.sig(*scrap_lvl0) is not None) + self.assertTrue(inv_obs.sig(*scrap_lvl1) is not None) + self.assertEqual( extra_ammo, + ItemState.parse_array(inv_obs.values[inv_obs.sig(*scrap_lvl1)]).quantity) + if ent_id == 1: + # if the ammo has the same signature, the quantity is added to the existing stack + self.assertEqual( inv_realm[sig_int_tuple(scrap_lvl0)], extra_ammo + self.ammo_quantity ) + self.assertEqual( extra_ammo + self.ammo_quantity, + ItemState.parse_array(inv_obs.values[inv_obs.sig(*scrap_lvl0)]).quantity) + # so there should be 1 more space + self.assertEqual( inv_obs.len, self.config.ITEM_INVENTORY_CAPACITY - 1) + + else: + # if the signature is different, it occupies a new inventory space + self.assertEqual( inv_realm[sig_int_tuple(scrap_lvl0)], extra_ammo ) + self.assertEqual( extra_ammo, + ItemState.parse_array(inv_obs.values[inv_obs.sig(*scrap_lvl0)]).quantity) + # thus the inventory is full + self.assertEqual( inv_obs.len, self.config.ITEM_INVENTORY_CAPACITY) + + if ent_id == 1: + gym_obs = env.obs[ent_id].to_gym() + # check USE InventoryItem mask + mask = gym_obs['ActionTargets'][action.Use][action.InventoryItem][:inv_obs.len] > 0 + # level-2 melee should be able to use level-0, level-1 scrap but not level-3 + self.assertTrue(inv_obs.id(inv_obs.sig(*scrap_lvl0)) in inv_obs.ids[mask]) + 
self.assertTrue(inv_obs.id(inv_obs.sig(*scrap_lvl1)) in inv_obs.ids[mask]) + self.assertTrue(inv_obs.id(inv_obs.sig(*scrap_lvl3)) not in inv_obs.ids[mask]) + + # First tick actions: USE (equip) level-0 ammo + # execute only the agent 1's action + ent_id = 1 + env.step({ ent_id: { action.Use: + { action.InventoryItem: env.obs[ent_id].inventory.sig(*scrap_lvl0) } }}) + + # check if the agents have equipped the ammo 0 + inv_obs = env.obs[ent_id].inventory + self.assertTrue(ItemState.parse_array(inv_obs.values[inv_obs.sig(*scrap_lvl0)]).equipped == 1) + self.assertTrue(ItemState.parse_array(inv_obs.values[inv_obs.sig(*scrap_lvl1)]).equipped == 0) + self.assertTrue(ItemState.parse_array(inv_obs.values[inv_obs.sig(*scrap_lvl3)]).equipped == 0) + + # Second tick actions: USE (equip) level-1 ammo + # this should unequip level-0 then equip level-1 ammo + env.step({ ent_id: { action.Use: + { action.InventoryItem: env.obs[ent_id].inventory.sig(*scrap_lvl1) } }}) + + # check if the agents have equipped the ammo 1 + inv_obs = env.obs[ent_id].inventory + self.assertTrue(ItemState.parse_array(inv_obs.values[inv_obs.sig(*scrap_lvl0)]).equipped == 0) + self.assertTrue(ItemState.parse_array(inv_obs.values[inv_obs.sig(*scrap_lvl1)]).equipped == 1) + self.assertTrue(ItemState.parse_array(inv_obs.values[inv_obs.sig(*scrap_lvl3)]).equipped == 0) + + # Third tick actions: USE (equip) level-3 ammo + # this should ignore USE action and leave level-1 ammo equipped + env.step({ ent_id: { action.Use: + { action.InventoryItem: env.obs[ent_id].inventory.sig(*scrap_lvl3) } }}) + + # check if the agents have equipped the ammo 1 + inv_obs = env.obs[ent_id].inventory + self.assertTrue(ItemState.parse_array(inv_obs.values[inv_obs.sig(*scrap_lvl0)]).equipped == 0) + self.assertTrue(ItemState.parse_array(inv_obs.values[inv_obs.sig(*scrap_lvl1)]).equipped == 1) + self.assertTrue(ItemState.parse_array(inv_obs.values[inv_obs.sig(*scrap_lvl3)]).equipped == 0) + + # DONE + + def 
test_use_ration_poultice(self): + # cannot use level-3 ration & poultice due to low level + # can use level-0 ration & poultice to increase food/water/health + env = self._setup_env(random_seed=RANDOM_SEED) + + # make food/water/health 20 + res_dec_tick = env.config.RESOURCE_DEPLETION_RATE + init_res = 20 + for ent_id in self.policy: + env.realm.players[ent_id].resources.food.update(init_res) + env.realm.players[ent_id].resources.water.update(init_res) + env.realm.players[ent_id].resources.health.update(init_res) + env.obs = env._compute_observations() + + """First tick: try to use level-3 ration & poultice""" + ration_lvl3 = (Item.Ration, 3) + poultice_lvl3 = (Item.Poultice, 3) + + actions = {} + ent_id = 1; actions[ent_id] = { action.Use: + { action.InventoryItem: env.obs[ent_id].inventory.sig(*ration_lvl3) } } + ent_id = 2; actions[ent_id] = { action.Use: + { action.InventoryItem: env.obs[ent_id].inventory.sig(*ration_lvl3) } } + ent_id = 3; actions[ent_id] = { action.Use: + { action.InventoryItem: env.obs[ent_id].inventory.sig(*poultice_lvl3) } } + + env.step(actions) + + # check if the agents have used the ration & poultice + for ent_id in [1, 2]: + # cannot use due to low level, so still in the inventory + self.assertFalse( env.obs[ent_id].inventory.sig(*ration_lvl3) is None) + + # failed to restore food/water, so no change + resources = env.realm.players[ent_id].resources + self.assertEqual( resources.food.val, init_res - res_dec_tick) + self.assertEqual( resources.water.val, init_res - res_dec_tick) + + ent_id = 3 # failed to use the item + self.assertFalse( env.obs[ent_id].inventory.sig(*poultice_lvl3) is None) + self.assertEqual( env.realm.players[ent_id].resources.health.val, init_res) + + """Second tick: try to use level-0 ration & poultice""" + ration_lvl0 = (Item.Ration, 0) + poultice_lvl0 = (Item.Poultice, 0) + + actions = {} + ent_id = 1; actions[ent_id] = { action.Use: + { action.InventoryItem: env.obs[ent_id].inventory.sig(*ration_lvl0) } } + 
ent_id = 2; actions[ent_id] = { action.Use: + { action.InventoryItem: env.obs[ent_id].inventory.sig(*ration_lvl0) } } + ent_id = 3; actions[ent_id] = { action.Use: + { action.InventoryItem: env.obs[ent_id].inventory.sig(*poultice_lvl0) } } + + env.step(actions) + + # check if the agents have successfully used the ration & poultice + restore = env.config.PROFESSION_CONSUMABLE_RESTORE(0) + for ent_id in [1, 2]: + # items should be gone + self.assertTrue( env.obs[ent_id].inventory.sig(*ration_lvl0) is None) + + # successfully restored food/water + resources = env.realm.players[ent_id].resources + self.assertEqual( resources.food.val, init_res + restore - 2*res_dec_tick) + self.assertEqual( resources.water.val, init_res + restore - 2*res_dec_tick) + + ent_id = 3 # successfully restored health + self.assertTrue( env.obs[ent_id].inventory.sig(*poultice_lvl0) is None) # item gone + self.assertEqual( env.realm.players[ent_id].resources.health.val, init_res + restore) + + # DONE + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/action/test_destroy_give_gold.py b/tests/action/test_destroy_give_gold.py new file mode 100644 index 000000000..b21da3f97 --- /dev/null +++ b/tests/action/test_destroy_give_gold.py @@ -0,0 +1,289 @@ +import unittest +import logging + +from tests.testhelpers import ScriptedTestTemplate, change_spawn_pos, provide_item + +from nmmo.io import action +from nmmo.systems import item as Item +from nmmo.systems.item import ItemState +from scripted import baselines + +RANDOM_SEED = 985 + +LOGFILE = 'tests/action/test_destroy_give_gold.log' + +class TestDestroyGiveGold(ScriptedTestTemplate): + # pylint: disable=protected-access,multiple-statements + + @classmethod + def setUpClass(cls): + super().setUpClass() + + # config specific to the tests here + cls.config.PLAYERS = [baselines.Melee, baselines.Range] + cls.config.PLAYER_N = 6 + + cls.policy = { 1:'Melee', 2:'Range', 3:'Melee', 4:'Range', 5:'Melee', 6:'Range' } + cls.spawn_locs = { 
1:(17,17), 2:(21,21), 3:(17,17), 4:(21,21), 5:(21,21), 6:(17,17) } + cls.ammo = { 1:Item.Scrap, 2:Item.Shaving, 3:Item.Scrap, + 4:Item.Shaving, 5:Item.Scrap, 6:Item.Shaving } + + cls.config.LOG_VERBOSE = False + if cls.config.LOG_VERBOSE: + logging.basicConfig(filename=LOGFILE, level=logging.INFO) + + def test_destroy(self): + env = self._setup_env(random_seed=RANDOM_SEED) + + # check if level-0 and level-3 ammo are in the correct place + for ent_id in self.policy: + for idx, lvl in enumerate(self.item_level): + assert self.item_sig[ent_id][idx] == (self.ammo[ent_id], lvl) + + # equipped items cannot be destroyed, i.e. that action will be ignored + # this should be marked in the mask too + + """ First tick """ # First tick actions: USE (equip) level-0 ammo + env.step({ ent_id: { action.Use: { action.InventoryItem: + env.obs[ent_id].inventory.sig(*self.item_sig[ent_id][0]) } # level-0 ammo + } for ent_id in self.policy }) + + # check if the agents have equipped the ammo + for ent_id in self.policy: + ent_obs = env.obs[ent_id] + inv_idx = ent_obs.inventory.sig(*self.item_sig[ent_id][0]) # level-0 ammo + self.assertEqual(1, # True + ItemState.parse_array(ent_obs.inventory.values[inv_idx]).equipped) + + # check Destroy InventoryItem mask -- one cannot destroy equipped item + for item_sig in self.item_sig[ent_id]: + if item_sig == (self.ammo[ent_id], 0): # level-0 ammo + self.assertFalse(self._check_inv_mask(ent_obs, action.Destroy, item_sig)) + else: + # other items can be destroyed + self.assertTrue(self._check_inv_mask(ent_obs, action.Destroy, item_sig)) + + """ Second tick """ # Second tick actions: DESTROY ammo + actions = {} + + for ent_id in self.policy: + if ent_id in [1, 2]: + # agent 1 & 2, destroy the level-3 ammos, which are valid + actions[ent_id] = { action.Destroy: + { action.InventoryItem: env.obs[ent_id].inventory.sig(*self.item_sig[ent_id][1]) } } + else: + # other agents: destroy the equipped level-0 ammos, which are invalid + actions[ent_id] = { 
action.Destroy: + { action.InventoryItem: env.obs[ent_id].inventory.sig(*self.item_sig[ent_id][0]) } } + env.step(actions) + + # check if the ammos were destroyed + for ent_id in self.policy: + if ent_id in [1, 2]: + inv_idx = env.obs[ent_id].inventory.sig(*self.item_sig[ent_id][1]) + self.assertTrue(inv_idx is None) # valid actions, thus destroyed + else: + inv_idx = env.obs[ent_id].inventory.sig(*self.item_sig[ent_id][0]) + self.assertTrue(inv_idx is not None) # invalid actions, thus not destroyed + + # DONE + + def test_give_tile_npc(self): + # cannot give to self (should be masked) + # cannot give if not on the same tile (should be masked) + # cannot give to the other team member (should be masked) + # cannot give to npc (should be masked) + env = self._setup_env(random_seed=RANDOM_SEED) + + # teleport the npc -1 to agent 5's location + change_spawn_pos(env.realm, -1, self.spawn_locs[5]) + env.obs = env._compute_observations() + + """ First tick actions """ + actions = {} + test_cond = {} + + # agent 1: give ammo to agent 3 (valid: the same team, same tile) + test_cond[1] = { 'tgt_id': 3, 'item_sig': self.item_sig[1][0], + 'ent_mask': True, 'inv_mask': True, 'valid': True } + # agent 2: give ammo to agent 2 (invalid: cannot give to self) + test_cond[2] = { 'tgt_id': 2, 'item_sig': self.item_sig[2][0], + 'ent_mask': False, 'inv_mask': True, 'valid': False } + # agent 4: give ammo to agent 5 (invalid: other tile) + test_cond[4] = { 'tgt_id': 6, 'item_sig': self.item_sig[4][0], + 'ent_mask': False, 'inv_mask': True, 'valid': False } + # agent 5: give ammo to npc -1 (invalid, should be masked) + test_cond[5] = { 'tgt_id': -1, 'item_sig': self.item_sig[5][0], + 'ent_mask': False, 'inv_mask': True, 'valid': False } + + actions = self._check_assert_make_action(env, action.Give, test_cond) + env.step(actions) + + # check the results + for ent_id, cond in test_cond.items(): + self.assertEqual( cond['valid'], + env.obs[ent_id].inventory.sig(*cond['item_sig']) is None) + 
+ if ent_id == 1: # agent 1 gave ammo stack to agent 3 + tgt_inv = env.obs[cond['tgt_id']].inventory + inv_idx = tgt_inv.sig(*cond['item_sig']) + self.assertEqual(2 * self.ammo_quantity, + ItemState.parse_array(tgt_inv.values[inv_idx]).quantity) + + # DONE + + def test_give_equipped_listed(self): + # cannot give equipped items (should be masked) + # cannot give listed items (should be masked) + env = self._setup_env(random_seed=RANDOM_SEED) + + """ First tick actions """ + actions = {} + + # agent 1: equip the ammo + ent_id = 1; item_sig = self.item_sig[ent_id][0] + self.assertTrue( + self._check_inv_mask(env.obs[ent_id], action.Use, item_sig)) + actions[ent_id] = { action.Use: { action.InventoryItem: + env.obs[ent_id].inventory.sig(*item_sig) } } + + # agent 2: list the ammo for sale + ent_id = 2; price = 5; item_sig = self.item_sig[ent_id][0] + self.assertTrue( + self._check_inv_mask(env.obs[ent_id], action.Sell, item_sig)) + actions[ent_id] = { action.Sell: { + action.InventoryItem: env.obs[ent_id].inventory.sig(*item_sig), + action.Price: price } } + + env.step(actions) + + # Check the first tick actions + # agent 1: equip the ammo + ent_id = 1; item_sig = self.item_sig[ent_id][0] + inv_idx = env.obs[ent_id].inventory.sig(*item_sig) + self.assertEqual(1, + ItemState.parse_array(env.obs[ent_id].inventory.values[inv_idx]).equipped) + + # agent 2: list the ammo for sale + ent_id = 2; price = 5; item_sig = self.item_sig[ent_id][0] + inv_idx = env.obs[ent_id].inventory.sig(*item_sig) + self.assertEqual(price, + ItemState.parse_array(env.obs[ent_id].inventory.values[inv_idx]).listed_price) + self.assertTrue(env.obs[ent_id].inventory.id(inv_idx) in env.obs[ent_id].market.ids) + + """ Second tick actions """ + actions = {} + test_cond = {} + + # agent 1: give equipped ammo to agent 3 (invalid: should be masked) + test_cond[1] = { 'tgt_id': 3, 'item_sig': self.item_sig[1][0], + 'ent_mask': True, 'inv_mask': False, 'valid': False } + # agent 2: give listed ammo to agent 
4 (invalid: should be masked) + test_cond[2] = { 'tgt_id': 4, 'item_sig': self.item_sig[2][0], + 'ent_mask': True, 'inv_mask': False, 'valid': False } + + actions = self._check_assert_make_action(env, action.Give, test_cond) + env.step(actions) + + # Check the second tick actions + # check the results + for ent_id, cond in test_cond.items(): + self.assertEqual( cond['valid'], + env.obs[ent_id].inventory.sig(*cond['item_sig']) is None) + + # DONE + + def test_give_full_inventory(self): + # cannot give to an agent with the full inventory, + # but it's possible if the agent has the same ammo stack + env = self._setup_env(random_seed=RANDOM_SEED) + + # make the inventory full for agents 1, 2 + extra_items = { (Item.Bottom, 0), (Item.Bottom, 3) } + for ent_id in [1, 2]: + for item_sig in extra_items: + self.item_sig[ent_id].append(item_sig) + provide_item(env.realm, ent_id, item_sig[0], item_sig[1], 1) + + env.obs = env._compute_observations() + + # check if the inventory is full + for ent_id in [1, 2]: + self.assertEqual(env.obs[ent_id].inventory.len, env.config.ITEM_INVENTORY_CAPACITY) + self.assertTrue(env.realm.players[ent_id].inventory.space == 0) + + """ First tick actions """ + actions = {} + test_cond = {} + + # agent 3: give ammo to agent 1 (the same ammo stack, so valid) + test_cond[3] = { 'tgt_id': 1, 'item_sig': self.item_sig[3][0], + 'ent_mask': True, 'inv_mask': True, 'valid': True } + # agent 4: give gloves to agent 2 (not the stack, so invalid) + test_cond[4] = { 'tgt_id': 2, 'item_sig': self.item_sig[4][4], + 'ent_mask': True, 'inv_mask': True, 'valid': False } + + actions = self._check_assert_make_action(env, action.Give, test_cond) + env.step(actions) + + # Check the first tick actions + # check the results + for ent_id, cond in test_cond.items(): + self.assertEqual( cond['valid'], + env.obs[ent_id].inventory.sig(*cond['item_sig']) is None) + + if ent_id == 3: # successfully gave the ammo stack to agent 1 + tgt_inv = env.obs[cond['tgt_id']].inventory 
+ inv_idx = tgt_inv.sig(*cond['item_sig']) + self.assertEqual(2 * self.ammo_quantity, + ItemState.parse_array(tgt_inv.values[inv_idx]).quantity) + + # DONE + + def test_give_gold(self): + # cannot give to an npc (should be masked) + # cannot give to self (should be masked) + # cannot give if not on the same tile (should be masked) + env = self._setup_env(random_seed=RANDOM_SEED) + + # teleport the npc -1 to agent 3's location + change_spawn_pos(env.realm, -1, self.spawn_locs[3]) + env.obs = env._compute_observations() + + test_cond = {} + + # NOTE: the below tests rely on the static execution order from 1 to N + # agent 1: give gold to agent 3 (valid: same tile) + test_cond[1] = { 'tgt_id': 3, 'gold': 1, 'ent_mask': True, + 'ent_gold': self.init_gold-1, 'tgt_gold': self.init_gold+1 } + # agent 2: give gold to agent 4 (valid: same tile) + test_cond[2] = { 'tgt_id': 4, 'gold': 100, 'ent_mask': True, + 'ent_gold': 0, 'tgt_gold': 2*self.init_gold } + # agent 3: give gold to npc -1 (invalid: cannot give to npc) + # ent_gold is self.init_gold+1 because (3) got 1 gold from (1) + test_cond[3] = { 'tgt_id': -1, 'gold': 1, 'ent_mask': False, + 'ent_gold': self.init_gold+1, 'tgt_gold': self.init_gold } + # agent 4: give -1 gold to 2 (invalid: cannot give minus gold) + # ent_gold is 2*self.init_gold because (4) got 5 gold from (2) + # tgt_gold is 0 because (2) gave all gold to (4) + test_cond[4] = { 'tgt_id': 2, 'gold': -1, 'ent_mask': True, + 'ent_gold': 2*self.init_gold, 'tgt_gold': 0 } + # agent 6: give gold to agent 4 (invalid: the other tile) + # tgt_gold is 2*self.init_gold because (4) got 5 gold from (2) + test_cond[6] = { 'tgt_id': 4, 'gold': 1, 'ent_mask': False, + 'ent_gold': self.init_gold, 'tgt_gold': 2*self.init_gold } + + actions = self._check_assert_make_action(env, action.GiveGold, test_cond) + env.step(actions) + + # check the results + for ent_id, cond in test_cond.items(): + self.assertEqual(cond['ent_gold'], env.realm.players[ent_id].gold.val) + if 
cond['tgt_id'] > 0: + self.assertEqual(cond['tgt_gold'], env.realm.players[cond['tgt_id']].gold.val) + + # DONE + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/action/test_monkey_action.py b/tests/action/test_monkey_action.py new file mode 100644 index 000000000..9b5d2e2c3 --- /dev/null +++ b/tests/action/test_monkey_action.py @@ -0,0 +1,86 @@ +import unittest +import random +from tqdm import tqdm + +import numpy as np + +from tests.testhelpers import ScriptedAgentTestConfig, ScriptedAgentTestEnv + +import nmmo + +# 30 seems to be enough to test variety of agent actions +TEST_HORIZON = 30 +RANDOM_SEED = random.randint(0, 1000000) + + +def make_random_actions(config, ent_obs): + assert 'ActionTargets' in ent_obs, 'ActionTargets is not provided in the obs' + actions = {} + + # atn, arg, val + for atn in sorted(nmmo.Action.edges(config)): + actions[atn] = {} + for arg in sorted(atn.edges, reverse=True): # intentionally doing wrong + mask = ent_obs['ActionTargets'][atn][arg] + actions[atn][arg] = 0 + if np.any(mask): + actions[atn][arg] += int(np.random.choice(np.where(mask)[0])) + + return actions + +# CHECK ME: this would be nice to include in the env._validate_actions() +def filter_item_actions(actions): + # when there are multiple actions on the same item, select one + flt_atns = {} + inventory_atn = {} # key: inventory idx, val: action + for atn in actions: + if atn in [nmmo.action.Use, nmmo.action.Sell, nmmo.action.Give, nmmo.action.Destroy]: + for arg, val in actions[atn].items(): + if arg == nmmo.action.InventoryItem: + if val not in inventory_atn: + inventory_atn[val] = [( atn, actions[atn] )] + else: + inventory_atn[val].append(( atn, actions[atn] )) + else: + flt_atns[atn] = actions[atn] + + # randomly select one action for each inventory item + for atns in inventory_atn.values(): + if len(atns) > 1: + picked = random.choice(atns) + flt_atns[picked[0]] = picked[1] + else: + flt_atns[atns[0][0]] = atns[0][1] + + return flt_atns + + +class 
TestMonkeyAction(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.config = ScriptedAgentTestConfig() + cls.config.PROVIDE_ACTION_TARGETS = True + + @staticmethod + # NOTE: this can also be used for sweeping random seeds + def rollout_with_seed(config, seed): + env = ScriptedAgentTestEnv(config) + obs = env.reset(seed=seed) + + for _ in tqdm(range(TEST_HORIZON)): + # sample random actions for each player + actions = {} + for ent_id in env.realm.players: + ent_atns = make_random_actions(config, obs[ent_id]) + actions[ent_id] = filter_item_actions(ent_atns) + obs, _, _, _ = env.step(actions) + + def test_monkey_action(self): + try: + self.rollout_with_seed(self.config, RANDOM_SEED) + except: # pylint: disable=bare-except + assert False, f"Monkey action failed. seed: {RANDOM_SEED}" + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/action/test_sell_buy.py b/tests/action/test_sell_buy.py new file mode 100644 index 000000000..d883c19ca --- /dev/null +++ b/tests/action/test_sell_buy.py @@ -0,0 +1,179 @@ +import unittest +import logging + +from tests.testhelpers import ScriptedTestTemplate, provide_item + +from nmmo.io import action +from nmmo.systems import item as Item +from nmmo.systems.item import ItemState +from scripted import baselines + +RANDOM_SEED = 985 + +LOGFILE = 'tests/action/test_sell_buy.log' + +class TestSellBuy(ScriptedTestTemplate): + # pylint: disable=protected-access,multiple-statements,unsubscriptable-object + + @classmethod + def setUpClass(cls): + super().setUpClass() + + # config specific to the tests here + cls.config.PLAYERS = [baselines.Melee, baselines.Range] + cls.config.PLAYER_N = 6 + + cls.policy = { 1:'Melee', 2:'Range', 3:'Melee', 4:'Range', 5:'Melee', 6:'Range' } + cls.ammo = { 1:Item.Scrap, 2:Item.Shaving, 3:Item.Scrap, + 4:Item.Shaving, 5:Item.Scrap, 6:Item.Shaving } + + cls.config.LOG_VERBOSE = False + if cls.config.LOG_VERBOSE: + logging.basicConfig(filename=LOGFILE, level=logging.INFO) + + + def 
test_sell_buy(self): + # cannot list an item with 0 price --> impossible to do this + # cannot list an equipped item for sale (should be masked) + # cannot buy an item with the full inventory, + # but it's possible if the agent has the same ammo stack + # cannot buy its own item (should be masked) + # cannot buy an item if gold is not enough (should be masked) + # cannot list an already listed item for sale (should be masked) + env = self._setup_env(random_seed=RANDOM_SEED) + + # make the inventory full for agents 1, 2 + extra_items = { (Item.Bottom, 0), (Item.Bottom, 3) } + for ent_id in [1, 2]: + for item_sig in extra_items: + self.item_sig[ent_id].append(item_sig) + provide_item(env.realm, ent_id, item_sig[0], item_sig[1], 1) + + env.obs = env._compute_observations() + + # check if the inventory is full + for ent_id in [1, 2]: + self.assertEqual(env.obs[ent_id].inventory.len, env.config.ITEM_INVENTORY_CAPACITY) + self.assertTrue(env.realm.players[ent_id].inventory.space == 0) + + """ First tick actions """ + # cannot list an item with 0 price + actions = {} + + # agent 1-2: equip the ammo + for ent_id in [1, 2]: + item_sig = self.item_sig[ent_id][0] + self.assertTrue( + self._check_inv_mask(env.obs[ent_id], action.Use, item_sig)) + actions[ent_id] = { action.Use: { action.InventoryItem: + env.obs[ent_id].inventory.sig(*item_sig) } } + + # agent 4: list the ammo for sale with price 0 + # the zero in action.Price is deserialized into Discrete_1, so it's valid + ent_id = 4; price = 0; item_sig = self.item_sig[ent_id][0] + actions[ent_id] = { action.Sell: { + action.InventoryItem: env.obs[ent_id].inventory.sig(*item_sig), + action.Price: action.Price.edges[price] } } + + env.step(actions) + + # Check the first tick actions + # agent 1-2: the ammo equipped, thus should be masked for sale + for ent_id in [1, 2]: + item_sig = self.item_sig[ent_id][0] + inv_idx = env.obs[ent_id].inventory.sig(*item_sig) + self.assertEqual(1, # equipped = true + 
ItemState.parse_array(env.obs[ent_id].inventory.values[inv_idx]).equipped) + self.assertFalse( # not allowed to list + self._check_inv_mask(env.obs[ent_id], action.Sell, item_sig)) + + """ Second tick actions """ + # listing the level-0 ammo with different prices + # cannot list an equipped item for sale (should be masked) + + listing_price = { 1:1, 2:5, 3:15, 5:2 } # gold + for ent_id, price in listing_price.items(): + item_sig = self.item_sig[ent_id][0] + actions[ent_id] = { action.Sell: { + action.InventoryItem: env.obs[ent_id].inventory.sig(*item_sig), + action.Price: action.Price.edges[price-1] } } + + env.step(actions) + + # Check the second tick actions + # agent 1-2: the ammo equipped, thus not listed for sale + # agent 3-5's ammos listed for sale + for ent_id, price in listing_price.items(): + item_id = env.obs[ent_id].inventory.id(0) + + if ent_id in [1, 2]: # failed to list for sale + self.assertFalse(item_id in env.obs[ent_id].market.ids) # not listed + self.assertEqual(0, + ItemState.parse_array(env.obs[ent_id].inventory.values[0]).listed_price) + + else: # should succeed to list for sale + self.assertTrue(item_id in env.obs[ent_id].market.ids) # listed + self.assertEqual(price, # sale price set + ItemState.parse_array(env.obs[ent_id].inventory.values[0]).listed_price) + + # should not buy mine + self.assertFalse( self._check_mkt_mask(env.obs[ent_id], item_id)) + + # should not list the same item twice + self.assertFalse( + self._check_inv_mask(env.obs[ent_id], action.Sell, self.item_sig[ent_id][0])) + + """ Third tick actions """ + # cannot buy an item with the full inventory, + # but it's possible if the agent has the same ammo stack + # cannot buy its own item (should be masked) + # cannot buy an item if gold is not enough (should be masked) + # cannot list an already listed item for sale (should be masked) + + test_cond = {} + + # agent 1: buy agent 5's ammo (valid: 1 has the same ammo stack) + # although 1's inventory is full, this action is valid 
+ agent5_ammo = env.obs[5].inventory.id(0) + test_cond[1] = { 'item_id': agent5_ammo, 'mkt_mask': True } + + # agent 2: buy agent 5's ammo (invalid: full space and no same stack) + test_cond[2] = { 'item_id': agent5_ammo, 'mkt_mask': False } + + # agent 4: cannot buy its own item (invalid) + test_cond[4] = { 'item_id': env.obs[4].inventory.id(0), 'mkt_mask': False } + + # agent 5: cannot buy agent 3's ammo (invalid: not enought gold) + test_cond[5] = { 'item_id': env.obs[3].inventory.id(0), 'mkt_mask': False } + + actions = self._check_assert_make_action(env, action.Buy, test_cond) + + # agent 3: list an already listed item for sale (try different price) + ent_id = 3; item_sig = self.item_sig[ent_id][0] + actions[ent_id] = { action.Sell: { + action.InventoryItem: env.obs[ent_id].inventory.sig(*item_sig), + action.Price: action.Price.edges[7] } } # try to set different price + + env.step(actions) + + # Check the third tick actions + # agent 1: buy agent 5's ammo (valid: 1 has the same ammo stack) + # agent 5's ammo should be gone + seller_id = 5; buyer_id = 1 + self.assertFalse( agent5_ammo in env.obs[seller_id].inventory.ids) + self.assertEqual( env.realm.players[seller_id].gold.val, # gold transfer + self.init_gold + listing_price[seller_id]) + self.assertEqual(2 * self.ammo_quantity, # ammo transfer + ItemState.parse_array(env.obs[buyer_id].inventory.values[0]).quantity) + self.assertEqual( env.realm.players[buyer_id].gold.val, # gold transfer + self.init_gold - listing_price[seller_id]) + + # agent 2-4: invalid buy, no exchange, thus the same money + for ent_id in [2, 3, 4]: + self.assertEqual( env.realm.players[ent_id].gold.val, self.init_gold) + + # DONE + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/conftest.py b/tests/conftest.py index aebd5d4fb..e47ec08ff 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,16 @@ + +#pylint: disable=unused-argument + +import logging +logging.basicConfig(level=logging.INFO, stream=None) + 
def pytest_benchmark_scale_unit(config, unit, benchmarks, best, worst, sort): - if unit == 'seconds': - prefix = 'millisec' - scale = 1000 - elif unit == 'operations': - prefix = '' - scale = 1 - else: - raise RuntimeError("Unexpected measurement unit %r" % unit) - return prefix, scale + if unit == 'seconds': + prefix = 'millisec' + scale = 1000 + elif unit == 'operations': + prefix = '' + scale = 1 + else: + raise RuntimeError(f"Unexpected measurement unit {unit}") + return prefix, scale diff --git a/tests/core/test_env.py b/tests/core/test_env.py new file mode 100644 index 000000000..ba1829051 --- /dev/null +++ b/tests/core/test_env.py @@ -0,0 +1,154 @@ + +import unittest +from typing import List + +import random +from tqdm import tqdm + +import nmmo +from nmmo.core.realm import Realm +from nmmo.core.tile import TileState +from nmmo.entity.entity import Entity, EntityState +from nmmo.systems.item import ItemState +from scripted import baselines + +# Allow private access for testing +# pylint: disable=protected-access + +# 30 seems to be enough to test variety of agent actions +TEST_HORIZON = 30 +RANDOM_SEED = random.randint(0, 10000) +# TODO: We should check that milestones have been reached, to make +# sure that the agents aren't just dying +class Config(nmmo.config.Small, nmmo.config.AllGameSystems): + RENDER = False + SPECIALIZE = True + PLAYERS = [ + baselines.Fisher, baselines.Herbalist, baselines.Prospector, + baselines.Carver, baselines.Alchemist, + baselines.Melee, baselines.Range, baselines.Mage] + +class TestEnv(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.config = Config() + cls.env = nmmo.Env(cls.config, RANDOM_SEED) + + def test_action_space(self): + action_space = self.env.action_space(0) + self.assertSetEqual( + set(action_space.keys()), + set(nmmo.Action.edges(self.config))) + + def test_observations(self): + obs = self.env.reset() + + self.assertEqual(obs.keys(), self.env.realm.players.keys()) + + for _ in 
tqdm(range(TEST_HORIZON)): + entity_locations = [ + [ev.row.val, ev.col.val, e] for e, ev in self.env.realm.players.entities.items() + ] + [ + [ev.row.val, ev.col.val, e] for e, ev in self.env.realm.npcs.entities.items() + ] + + for player_id, player_obs in obs.items(): + self._validate_tiles(player_obs, self.env.realm) + self._validate_entitites( + player_id, player_obs, self.env.realm, entity_locations) + self._validate_inventory(player_id, player_obs, self.env.realm) + self._validate_market(player_obs, self.env.realm) + obs, _, _, _ = self.env.step({}) + + def _validate_tiles(self, obs, realm: Realm): + for tile_obs in obs["Tile"]: + tile_obs = TileState.parse_array(tile_obs) + tile = realm.map.tiles[(int(tile_obs.row), int(tile_obs.col))] + for key, val in tile_obs.__dict__.items(): + if val != getattr(tile, key).val: + self.assertEqual(val, getattr(tile, key).val, + f"Mismatch for {key} in tile {tile_obs.row}, {tile_obs.col}") + + def _validate_entitites(self, player_id, obs, realm: Realm, entity_locations: List[List[int]]): + observed_entities = set() + + for entity_obs in obs["Entity"]: + entity_obs = EntityState.parse_array(entity_obs) + + if entity_obs.id == 0: + continue + + entity: Entity = realm.entity(entity_obs.id) + + observed_entities.add(entity.ent_id) + + for key, val in entity_obs.__dict__.items(): + if getattr(entity, key) is None: + raise ValueError(f"Entity {entity} has no attribute {key}") + self.assertEqual(val, getattr(entity, key).val, + f"Mismatch for {key} in entity {entity_obs.id}") + + # Make sure that we see entities IFF they are in our vision radius + row = realm.players.entities[player_id].row.val + col = realm.players.entities[player_id].col.val + visible_entities = { + e for r, c, e in entity_locations + if r >= row - realm.config.PLAYER_VISION_RADIUS + and c >= col - realm.config.PLAYER_VISION_RADIUS + and r <= row + realm.config.PLAYER_VISION_RADIUS + and c <= col + realm.config.PLAYER_VISION_RADIUS + } + 
self.assertSetEqual(visible_entities, observed_entities, + f"Mismatch between observed: {observed_entities} " \ + f"and visible {visible_entities} for player {player_id}, "\ + f" step {self.env.realm.tick}") + + def _validate_inventory(self, player_id, obs, realm: Realm): + self._validate_items( + {i.id.val: i for i in realm.players[player_id].inventory.items}, + obs["Inventory"] + ) + + def _validate_market(self, obs, realm: Realm): + self._validate_items( + {i.item.id.val: i.item for i in realm.exchange._item_listings.values()}, + obs["Market"] + ) + + def _validate_items(self, items_dict, item_obs): + item_obs = item_obs[item_obs[:,0] != 0] + if len(items_dict) != len(item_obs): + assert len(items_dict) == len(item_obs) + for item_ob in item_obs: + item_ob = ItemState.parse_array(item_ob) + item = items_dict[item_ob.id] + for key, val in item_ob.__dict__.items(): + self.assertEqual(val, getattr(item, key).val, + f"Mismatch for {key} in item {item_ob.id}: {val} != {getattr(item, key).val}") + + def test_clean_item_after_reset(self): + # use the separate env + new_env = nmmo.Env(self.config, RANDOM_SEED) + + # reset the environment after running + new_env.reset() + for _ in tqdm(range(TEST_HORIZON)): + new_env.step({}) + new_env.reset() + + # items are referenced in the realm.items, which must be empty + self.assertTrue(len(new_env.realm.items) == 0) + + # items are referenced in the exchange + self.assertTrue(len(new_env.realm.exchange._item_listings) == 0) + self.assertTrue(len(new_env.realm.exchange._listings_queue) == 0) + + # TODO(kywch): ItemState table is not empty after players/npcs.reset() + # but should be. Will fix this while debugging the item system. 
+ # So for now, ItemState table is cleared manually here, just to pass this test + ItemState.State.table(new_env.realm.datastore).reset() + + self.assertTrue(ItemState.State.table(new_env.realm.datastore).is_empty()) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/core/test_map_generation.py b/tests/core/test_map_generation.py new file mode 100644 index 000000000..8b699d537 --- /dev/null +++ b/tests/core/test_map_generation.py @@ -0,0 +1,28 @@ +import unittest +import os +import shutil + +import nmmo + +class TestMapGeneration(unittest.TestCase): + def test_insufficient_maps(self): + config = nmmo.config.Small() + config.MAP_N = 20 + + path_maps = os.path.join(config.PATH_CWD, config.PATH_MAPS) + shutil.rmtree(path_maps) + + # this generates 20 maps + nmmo.Env(config) + + # test if MAP_FORCE_GENERATION can be overriden + config.MAP_N = 30 + config.MAP_FORCE_GENERATION = False + + test_env = nmmo.Env(config) + test_env.reset(map_id = 25) + + # this should finish without error + +if __name__ == '__main__': + unittest.main() diff --git a/tests/core/test_tile.py b/tests/core/test_tile.py new file mode 100644 index 000000000..593ddad9c --- /dev/null +++ b/tests/core/test_tile.py @@ -0,0 +1,39 @@ +import unittest +import nmmo +from nmmo.core.tile import Tile, TileState +from nmmo.datastore.numpy_datastore import NumpyDatastore +from nmmo.lib import material + +class MockRealm: + def __init__(self): + self.datastore = NumpyDatastore() + self.datastore.register_object_type("Tile", TileState.State.num_attributes) + self.config = nmmo.config.Small() + +class MockEntity(): + def __init__(self, id): + self.ent_id = id + +class TestTile(unittest.TestCase): + def test_tile(self): + mock_realm = MockRealm() + tile = Tile(mock_realm, 10, 20) + + tile.reset(material.Forest, nmmo.config.Small()) + + self.assertEqual(tile.row.val, 10) + self.assertEqual(tile.col.val, 20) + self.assertEqual(tile.material_id.val, material.Forest.index) + + 
tile.add_entity(MockEntity(1)) + tile.add_entity(MockEntity(2)) + self.assertCountEqual(tile.entities.keys(), [1, 2]) + tile.remove_entity(1) + self.assertCountEqual(tile.entities.keys(), [2]) + + tile.harvest(True) + self.assertEqual(tile.depleted, True) + self.assertEqual(tile.material_id.val, material.Scrub.index) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/datastore/__init__.py b/tests/datastore/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/datastore/test_datastore.py b/tests/datastore/test_datastore.py new file mode 100644 index 000000000..d9bf7c6a0 --- /dev/null +++ b/tests/datastore/test_datastore.py @@ -0,0 +1,43 @@ +import unittest + +import numpy as np + +from nmmo.datastore.numpy_datastore import NumpyDatastore + + +class TestDatastore(unittest.TestCase): + + def testdatastore_record(self): + datastore = NumpyDatastore() + datastore.register_object_type("TestObject", 2) + c1 = 0 + c2 = 1 + + o = datastore.create_record("TestObject") + self.assertEqual([o.get(c1), o.get(c2)], [0, 0]) + + o.update(c1, 1) + o.update(c2, 2) + self.assertEqual([o.get(c1), o.get(c2)], [1, 2]) + + np.testing.assert_array_equal( + datastore.table("TestObject").get([o.id]), + np.array([[1, 2]])) + + o2 = datastore.create_record("TestObject") + o2.update(c2, 2) + np.testing.assert_array_equal( + datastore.table("TestObject").get([o.id, o2.id]), + np.array([[1, 2], [0, 2]])) + + np.testing.assert_array_equal( + datastore.table("TestObject").where_eq(c2, 2), + np.array([[1, 2], [0, 2]])) + + o.delete() + np.testing.assert_array_equal( + datastore.table("TestObject").where_eq(c2, 2), + np.array([[0, 2]])) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/datastore/test_id_allocator.py b/tests/datastore/test_id_allocator.py new file mode 100644 index 000000000..dd8310a6e --- /dev/null +++ b/tests/datastore/test_id_allocator.py @@ -0,0 +1,64 @@ +import unittest + +from nmmo.datastore.id_allocator import 
IdAllocator + +class TestIdAllocator(unittest.TestCase): + def test_id_allocator(self): + id_allocator = IdAllocator(10) + + for i in range(1, 10): + row_id = id_allocator.allocate() + self.assertEqual(i, row_id) + self.assertTrue(id_allocator.full()) + + id_allocator.remove(5) + id_allocator.remove(6) + id_allocator.remove(1) + self.assertFalse(id_allocator.full()) + + self.assertSetEqual( + set(id_allocator.allocate() for i in range(3)), + set([5, 6, 1]) + ) + self.assertTrue(id_allocator.full()) + + id_allocator.expand(11) + self.assertFalse(id_allocator.full()) + + self.assertEqual(id_allocator.allocate(), 10) + + with self.assertRaises(KeyError): + id_allocator.allocate() + + def test_id_reuse(self): + id_allocator = IdAllocator(10) + + for i in range(1, 10): + row_id = id_allocator.allocate() + self.assertEqual(i, row_id) + self.assertTrue(id_allocator.full()) + + id_allocator.remove(5) + id_allocator.remove(6) + id_allocator.remove(1) + self.assertFalse(id_allocator.full()) + + self.assertSetEqual( + set(id_allocator.allocate() for i in range(3)), + set([5, 6, 1]) + ) + self.assertTrue(id_allocator.full()) + + id_allocator.expand(11) + self.assertFalse(id_allocator.full()) + + self.assertEqual(id_allocator.allocate(), 10) + + with self.assertRaises(KeyError): + id_allocator.allocate() + + id_allocator.remove(10) + self.assertEqual(id_allocator.allocate(), 10) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/datastore/test_numpy_datastore.py b/tests/datastore/test_numpy_datastore.py new file mode 100644 index 000000000..2a4dca5a3 --- /dev/null +++ b/tests/datastore/test_numpy_datastore.py @@ -0,0 +1,47 @@ +import unittest + +import numpy as np + +from nmmo.datastore.numpy_datastore import NumpyTable + +# pylint: disable=protected-access +class TestNumpyTable(unittest.TestCase): + def test_continous_table(self): + table = NumpyTable(3, 10, np.float32) + table.update(2, 0, 2.1) + table.update(2, 1, 2.2) + table.update(5, 0, 5.1) + 
table.update(5, 2, 5.3) + np.testing.assert_array_equal( + table.get([1,2,5]), + np.array([[0, 0, 0], [2.1, 2.2, 0], [5.1, 0, 5.3]], dtype=np.float32) + ) + + def test_discrete_table(self): + table = NumpyTable(3, 10, np.int32) + table.update(2, 0, 11) + table.update(2, 1, 12) + table.update(5, 0, 51) + table.update(5, 2, 53) + np.testing.assert_array_equal( + table.get([1,2,5]), + np.array([[0, 0, 0], [11, 12, 0], [51, 0, 53]], dtype=np.int32) + ) + + def test_expand(self): + table = NumpyTable(3, 10, np.float32) + + table.update(2, 0, 2.1) + with self.assertRaises(IndexError): + table.update(10, 0, 10.1) + + table._expand(11) + table.update(10, 0, 10.1) + + np.testing.assert_array_equal( + table.get([10, 2]), + np.array([[10.1, 0, 0], [2.1, 0, 0]], dtype=np.float32) + ) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/entity/__init__.py b/tests/entity/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/entity/test_entity.py b/tests/entity/test_entity.py new file mode 100644 index 000000000..848bb7bb1 --- /dev/null +++ b/tests/entity/test_entity.py @@ -0,0 +1,59 @@ +import unittest +import nmmo +from nmmo.entity.entity import Entity, EntityState +from nmmo.datastore.numpy_datastore import NumpyDatastore + +class MockRealm: + def __init__(self): + self.config = nmmo.config.Default() + self.config.PLAYERS = range(100) + self.datastore = NumpyDatastore() + self.datastore.register_object_type("Entity", EntityState.State.num_attributes) + +# pylint: disable=no-member +class TestEntity(unittest.TestCase): + def test_entity(self): + realm = MockRealm() + entity_id = 123 + entity = Entity(realm, (10,20), entity_id, "name") + + self.assertEqual(entity.id.val, entity_id) + self.assertEqual(entity.row.val, 10) + self.assertEqual(entity.col.val, 20) + self.assertEqual(entity.damage.val, 0) + self.assertEqual(entity.time_alive.val, 0) + self.assertEqual(entity.freeze.val, 0) + self.assertEqual(entity.item_level.val, 0) + 
self.assertEqual(entity.attacker_id.val, 0) + self.assertEqual(entity.message.val, 0) + self.assertEqual(entity.gold.val, 0) + self.assertEqual(entity.health.val, realm.config.PLAYER_BASE_HEALTH) + self.assertEqual(entity.food.val, realm.config.RESOURCE_BASE) + self.assertEqual(entity.water.val, realm.config.RESOURCE_BASE) + self.assertEqual(entity.melee_level.val, 0) + self.assertEqual(entity.range_level.val, 0) + self.assertEqual(entity.mage_level.val, 0) + self.assertEqual(entity.fishing_level.val, 0) + self.assertEqual(entity.herbalism_level.val, 0) + self.assertEqual(entity.prospecting_level.val, 0) + self.assertEqual(entity.carving_level.val, 0) + self.assertEqual(entity.alchemy_level.val, 0) + + def test_query_by_ids(self): + realm = MockRealm() + entity_id = 123 + entity = Entity(realm, (10,20), entity_id, "name") + + entities = EntityState.Query.by_ids(realm.datastore, [entity_id]) + self.assertEqual(len(entities), 1) + self.assertEqual(entities[0][Entity.State.attr_name_to_col["id"]], entity_id) + self.assertEqual(entities[0][Entity.State.attr_name_to_col["row"]], 10) + self.assertEqual(entities[0][Entity.State.attr_name_to_col["col"]], 20) + + entity.food.update(11) + e_row = EntityState.Query.by_id(realm.datastore, entity_id) + self.assertEqual(e_row[Entity.State.attr_name_to_col["food"]], 11) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/lib/__init__.py b/tests/lib/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/lib/test_serialized.py b/tests/lib/test_serialized.py new file mode 100644 index 000000000..1c181567e --- /dev/null +++ b/tests/lib/test_serialized.py @@ -0,0 +1,56 @@ +from collections import defaultdict +import unittest + +from nmmo.datastore.serialized import SerializedState + +# pylint: disable=no-member,unused-argument,unsubscriptable-object + +FooState = SerializedState.subclass("FooState", [ + "a", "b", "col" +]) + +FooState.Limits = { + "a": (-10, 10), +} + +class 
MockDatastoreRecord(): + def __init__(self): + self._data = defaultdict(lambda: 0) + + def get(self, name): + return self._data[name] + + def update(self, name, value): + self._data[name] = value + +class MockDatastore(): + def create_record(self, name): + return MockDatastoreRecord() + + def register_object_type(self, name, attributes): + assert name == "FooState" + assert attributes == ["a", "b", "col"] + +class TestSerialized(unittest.TestCase): + + def test_serialized(self): + state = FooState(MockDatastore(), FooState.Limits) + + # initial value = 0 + self.assertEqual(state.a.val, 0) + + # if given value is within the range, set to the value + state.a.update(1) + self.assertEqual(state.a.val, 1) + + # if given a lower value than the min, set to min + a_min, a_max = FooState.Limits["a"] + state.a.update(a_min - 100) + self.assertEqual(state.a.val, a_min) + + # if given a higher value than the max, set to max + state.a.update(a_max + 100) + self.assertEqual(state.a.val, a_max) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/render/test_load_replay.py b/tests/render/test_load_replay.py new file mode 100644 index 000000000..5f3fe8203 --- /dev/null +++ b/tests/render/test_load_replay.py @@ -0,0 +1,20 @@ +'''Manual test for rendering replay''' + +if __name__ == '__main__': + import time + + # pylint: disable=import-error + from nmmo.render.render_client import WebsocketRenderer + from nmmo.render.replay_helper import ReplayFileHelper + + # open a client + renderer = WebsocketRenderer() + time.sleep(3) + + # load a replay + replay = ReplayFileHelper.load('replay_dev.json', decompress=False) + + # run the replay + for packet in replay: + renderer.render_packet(packet) + time.sleep(1) diff --git a/tests/render/test_render_save.py b/tests/render/test_render_save.py new file mode 100644 index 000000000..f1ce47971 --- /dev/null +++ b/tests/render/test_render_save.py @@ -0,0 +1,31 @@ +'''Manual test for render client connectivity''' + +if __name__ == 
'__main__': + import time + import random + import nmmo + + # pylint: disable=import-error + from nmmo.render.render_client import WebsocketRenderer + from tests.testhelpers import ScriptedAgentTestConfig + + TEST_HORIZON = 100 + RANDOM_SEED = random.randint(0, 9999) + + # config.RENDER option is gone, + # RENDER can be done without setting any config + config = ScriptedAgentTestConfig() + env = nmmo.Env(config) + + env.reset(seed=RANDOM_SEED) + + # the renderer is external to the env, so need to manually initiate it + renderer = WebsocketRenderer(env.realm) + + for tick in range(TEST_HORIZON): + env.step({}) + renderer.render_realm() + time.sleep(1) + + # save the packet: this is possible because config.SAVE_REPLAY = True + env.realm.save_replay(f'replay_seed_{RANDOM_SEED:04d}.json', compress=False) diff --git a/tests/systems/test_exchange.py b/tests/systems/test_exchange.py new file mode 100644 index 000000000..599be59b4 --- /dev/null +++ b/tests/systems/test_exchange.py @@ -0,0 +1,90 @@ +from types import SimpleNamespace +import unittest +import nmmo +from nmmo.datastore.numpy_datastore import NumpyDatastore +from nmmo.systems.exchange import Exchange +from nmmo.systems.item import ItemState +import nmmo.systems.item as item +import numpy as np + +class MockRealm: + def __init__(self): + self.config = nmmo.config.Default() + self.config.EXCHANGE_LISTING_DURATION = 3 + self.datastore = NumpyDatastore() + self.items = {} + self.datastore.register_object_type("Item", ItemState.State.num_attributes) + +class MockEntity: + def __init__(self) -> None: + self.items = [] + self.inventory = SimpleNamespace( + receive = lambda item: self.items.append(item), + remove = lambda item: self.items.remove(item) + ) + +class TestExchange(unittest.TestCase): + def test_listings(self): + realm = MockRealm() + exchange = Exchange(realm) + + entity_1 = MockEntity() + + hat_1 = item.Hat(realm, 1) + hat_2 = item.Hat(realm, 10) + entity_1.inventory.receive(hat_1) + 
entity_1.inventory.receive(hat_2) + self.assertEqual(len(entity_1.items), 2) + + tick = 0 + exchange._list_item(hat_1, entity_1, 10, tick) + self.assertEqual(len(exchange._item_listings), 1) + self.assertEqual(exchange._listings_queue[0], (hat_1.id.val, 0)) + + tick = 1 + exchange._list_item(hat_2, entity_1, 20, tick) + self.assertEqual(len(exchange._item_listings), 2) + self.assertEqual(exchange._listings_queue[0], (hat_1.id.val, 0)) + + tick = 4 + exchange.step(tick) + # hat_1 should expire and not be listed + self.assertEqual(len(exchange._item_listings), 1) + self.assertEqual(exchange._listings_queue[0], (hat_2.id.val, 1)) + + tick = 5 + exchange._list_item(hat_2, entity_1, 10, tick) + exchange.step(tick) + # hat_2 got re-listed, so should still be listed + self.assertEqual(len(exchange._item_listings), 1) + self.assertEqual(exchange._listings_queue[0], (hat_2.id.val, 5)) + + tick = 10 + exchange.step(tick) + self.assertEqual(len(exchange._item_listings), 0) + + def test_for_sale_items(self): + realm = MockRealm() + exchange = Exchange(realm) + entity_1 = MockEntity() + + hat_1 = item.Hat(realm, 1) + hat_2 = item.Hat(realm, 10) + exchange._list_item(hat_1, entity_1, 10, 0) + exchange._list_item(hat_2, entity_1, 20, 10) + + np.testing.assert_array_equal( + item.Item.Query.for_sale(realm.datastore)[:,0], [hat_1.id.val, hat_2.id.val]) + + # first listing should expire + exchange.step(10) + np.testing.assert_array_equal( + item.Item.Query.for_sale(realm.datastore)[:,0], [hat_2.id.val]) + + # second listing should expire + exchange.step(100) + np.testing.assert_array_equal( + item.Item.Query.for_sale(realm.datastore)[:,0], []) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/systems/test_item.py b/tests/systems/test_item.py new file mode 100644 index 000000000..bf86d323c --- /dev/null +++ b/tests/systems/test_item.py @@ -0,0 +1,69 @@ +import unittest +import numpy as np + +import nmmo +from nmmo.datastore.numpy_datastore 
import NumpyDatastore +from nmmo.systems.item import Hat, Top, ItemState + +class MockRealm: + def __init__(self): + self.config = nmmo.config.Default() + self.datastore = NumpyDatastore() + self.items = {} + self.datastore.register_object_type("Item", ItemState.State.num_attributes) + self.players = {} + +# pylint: disable=no-member +class TestItem(unittest.TestCase): + def test_item(self): + realm = MockRealm() + + hat_1 = Hat(realm, 1) + self.assertTrue(ItemState.Query.by_id(realm.datastore, hat_1.id.val) is not None) + self.assertEqual(hat_1.type_id.val, Hat.ITEM_TYPE_ID) + self.assertEqual(hat_1.level.val, 1) + self.assertEqual(hat_1.mage_defense.val, 10) + + hat_2 = Hat(realm, 10) + self.assertTrue(ItemState.Query.by_id(realm.datastore, hat_2.id.val) is not None) + self.assertEqual(hat_2.level.val, 10) + self.assertEqual(hat_2.melee_defense.val, 100) + + self.assertDictEqual(realm.items, {hat_1.id.val: hat_1, hat_2.id.val: hat_2}) + + # also test destroy + ids = [hat_1.id.val, hat_2.id.val] + hat_1.destroy() + hat_2.destroy() + # after destroy(), the datastore entry is gone, but the class still exsits + # make sure that after destroy the owner_id is 0, at least + self.assertTrue(hat_1.owner_id.val == 0) + self.assertTrue(hat_2.owner_id.val == 0) + for item_id in ids: + self.assertTrue(len(ItemState.Query.by_id(realm.datastore, item_id)) == 0) + self.assertDictEqual(realm.items, {}) + + # create a new item with the hat_1's id, but it must still be void + new_top = Top(realm, 3) + new_top.id.update(ids[0]) # hat_1's id + new_top.owner_id.update(100) + # make sure that the hat_1 is not linked to the new_top + self.assertTrue(hat_1.owner_id.val == 0) + + def test_owned_by(self): + realm = MockRealm() + + hat_1 = Hat(realm, 1) + hat_2 = Hat(realm, 10) + + hat_1.owner_id.update(1) + hat_2.owner_id.update(1) + + np.testing.assert_array_equal( + ItemState.Query.owned_by(realm.datastore, 1)[:,0], + [hat_1.id.val, hat_2.id.val]) + + 
self.assertEqual(Hat.Query.owned_by(realm.datastore, 2).size, 0) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_api.py b/tests/test_api.py deleted file mode 100644 index d0a97e0ca..000000000 --- a/tests/test_api.py +++ /dev/null @@ -1,20 +0,0 @@ -from pdb import set_trace as T - - -def test_import(): - import nmmo - -def test_env_creation(): - import nmmo - env = nmmo.Env() - env.reset() - env.step({}) - -def test_io(): - import nmmo - env = nmmo.Env() - env.observation_space(0) - env.action_space(0) - -if __name__ == '__main__': - test_io() diff --git a/tests/test_client.py b/tests/test_client.py deleted file mode 100644 index c0c33a5b0..000000000 --- a/tests/test_client.py +++ /dev/null @@ -1,15 +0,0 @@ -'''Manual test for client connectivity''' - -from pdb import set_trace as T -import pytest - -import nmmo - -if __name__ == '__main__': - env = nmmo.Env() - env.config.RENDER = True - - env.reset() - while True: - env.render() - env.step({}) diff --git a/tests/test_determinism.py b/tests/test_determinism.py index 60e2a0eef..9d4ca733f 100644 --- a/tests/test_determinism.py +++ b/tests/test_determinism.py @@ -1,47 +1,76 @@ -from pdb import set_trace as T -import numpy as np +#from pdb import set_trace as T +import unittest + +import logging import random +from tqdm import tqdm + +from tests.testhelpers import ScriptedAgentTestConfig, ScriptedAgentTestEnv +from tests.testhelpers import observations_are_equal, actions_are_equal + +# 30 seems to be enough to test variety of agent actions +TEST_HORIZON = 30 +RANDOM_SEED = random.randint(0, 10000) + + +class TestDeterminism(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.horizon = TEST_HORIZON + cls.rand_seed = RANDOM_SEED + cls.config = ScriptedAgentTestConfig() + env = ScriptedAgentTestEnv(cls.config) + + logging.info('TestDeterminism: Setting up the reference env with seed %s', str(cls.rand_seed)) + cls.init_obs_src = env.reset(seed=cls.rand_seed) + cls.actions_src = [] + 
logging.info('TestDeterminism: Running %s ticks', str(cls.horizon)) + for _ in tqdm(range(cls.horizon)): + nxt_obs_src, _, _, _ = env.step({}) + cls.actions_src.append(env.actions) + cls.final_obs_src = nxt_obs_src + npcs_src = {} + for nid, npc in list(env.realm.npcs.items()): + npcs_src[nid] = npc.packet() + cls.final_npcs_src = npcs_src + + logging.info('TestDeterminism: Setting up the replication env with seed %s', str(cls.rand_seed)) + cls.init_obs_rep = env.reset(seed=cls.rand_seed) + cls.actions_rep = [] + logging.info('TestDeterminism: Running %s ticks', str(cls.horizon)) + for _ in tqdm(range(cls.horizon)): + nxt_obs_rep, _, _, _ = env.step({}) + cls.actions_rep.append(env.actions) + cls.final_obs_rep = nxt_obs_rep + npcs_rep = {} + for nid, npc in list(env.realm.npcs.items()): + npcs_rep[nid] = npc.packet() + cls.final_npcs_rep = npcs_rep + + def test_func_are_observations_equal(self): + self.assertTrue(observations_are_equal(self.init_obs_src, self.init_obs_src)) + self.assertTrue(observations_are_equal(self.final_obs_src, self.final_obs_src)) + self.assertTrue(actions_are_equal(self.actions_src[0], self.actions_src[0])) + self.assertDictEqual(self.final_npcs_src, self.final_npcs_src) + + def test_compare_initial_observations(self): + # assertDictEqual CANNOT replace are_observations_equal + self.assertTrue(observations_are_equal(self.init_obs_src, self.init_obs_rep)) + #self.assertDictEqual(self.init_obs_src, self.init_obs_rep) + + def test_compare_actions(self): + self.assertEqual(len(self.actions_src), len(self.actions_rep)) + for t, action_src in enumerate(self.actions_src): + self.assertTrue(actions_are_equal(action_src, self.actions_rep[t])) + + def test_compare_final_observations(self): + # assertDictEqual CANNOT replace are_observations_equal + self.assertTrue(observations_are_equal(self.final_obs_src, self.final_obs_rep)) + #self.assertDictEqual(self.final_obs_src, self.final_obs_rep) + + def test_compare_final_npcs(self) : + 
self.assertDictEqual(self.final_npcs_src, self.final_npcs_rep) + -import nmmo - - -def test_determinism(): - config = nmmo.config.Default() - env1 = nmmo.Env(config, seed=42) - env1.reset() - for i in range(2): - obs1, _, _, _ = env1.step({}) - - config = nmmo.config.Default() - env2 = nmmo.Env(config, seed=42) - env2.reset() - for i in range(2): - obs2, _, _, _ = env2.step({}) - - npc1 = env1.realm.npcs.values() - npc2 = env2.realm.npcs.values() - - for n1, n2 in zip(npc1, npc2): - assert n1.pos == n2.pos - - assert list(obs1.keys()) == list(obs2.keys()) - keys = list(obs1.keys()) - for k in keys: - ent1 = obs1[k] - ent2 = obs2[k] - - obj = ent1.keys() - for o in obj: - obj1 = ent1[o] - obj2 = ent2[o] - - attrs = list(obj1) - for a in attrs: - attr1 = obj1[a] - attr2 = obj2[a] - - if np.sum(attr1 != attr2) > 0: - T() - assert np.sum(attr1 != attr2) == 0 - -test_determinism() \ No newline at end of file +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_deterministic_replay.py b/tests/test_deterministic_replay.py new file mode 100644 index 000000000..59d46644c --- /dev/null +++ b/tests/test_deterministic_replay.py @@ -0,0 +1,169 @@ +#from pdb import set_trace as T +import unittest + +import os +import glob +import pickle +import logging +import random +from typing import Any, Dict + +import numpy as np +from tqdm import tqdm + +from tests.testhelpers import ScriptedAgentTestConfig, ScriptedAgentTestEnv +from tests.testhelpers import observations_are_equal + +import nmmo + +TEST_HORIZON = 50 +LOCAL_REPLAY = 'tests/replay_local.pickle' + +def load_replay_file(replay_file): + # load the pickle file + with open(replay_file, 'rb') as handle: + ref_data = pickle.load(handle) + + logging.info('TestDetReplay: Loading the existing replay file with seed %s', + str(ref_data['seed'])) + + seed = ref_data['seed'] + config = ref_data['config'] + map_src = ref_data['map'] + init_obs = ref_data['init_obs'] + init_npcs = ref_data['init_npcs'] + med_obs = 
ref_data['med_obs'] + actions = ref_data['actions'] + final_obs = ref_data['final_obs'] + final_npcs = ref_data['final_npcs'] + + return seed, config, map_src, init_obs, init_npcs, med_obs, actions, final_obs, final_npcs + + +def make_actions_picklable(actions: Dict[int, Dict[str, Dict[str, Any]]]): + for eid in actions: + for atn, args in actions[eid].items(): + for arg, val in args.items(): + if arg == nmmo.io.action.Price and not isinstance(val, int): + # : + # convert Discrete_1 to 1 + actions[eid][atn][arg] = val.val + return actions + + +def generate_replay_file(replay_file, test_horizon): + # generate the new data with a new env + seed = random.randint(0, 10000) + logging.info('TestDetReplay: Creating a new replay file with seed %s', str(seed)) + config = ScriptedAgentTestConfig() + env_src = ScriptedAgentTestEnv(config, seed=seed) + init_obs = env_src.reset() + init_npcs = env_src.realm.npcs.packet + + # extract the map + map_src = np.zeros((config.MAP_SIZE, config.MAP_SIZE)) + for r in range(config.MAP_SIZE): + for c in range(config.MAP_SIZE): + map_src[r,c] = env_src.realm.map.tiles[r,c].material_id.val + + med_obs, actions = [], [] + logging.info('TestDetReplay: Running %s ticks', str(test_horizon)) + for _ in tqdm(range(test_horizon)): + nxt_obs, _, _, _ = env_src.step({}) + med_obs.append(nxt_obs) + actions.append(make_actions_picklable(env_src.actions)) + final_obs = nxt_obs + final_npcs = env_src.realm.npcs.packet + + # save to the file + with open(replay_file, 'wb') as handle: + ref_data = {} + ref_data['version'] = nmmo.__version__ # just in case + ref_data['seed'] = seed + ref_data['config'] = config + ref_data['map'] = map_src + ref_data['init_obs'] = init_obs + ref_data['init_npcs'] = init_npcs + ref_data['med_obs'] = med_obs + ref_data['actions'] = actions + ref_data['final_obs'] = final_obs + ref_data['final_npcs'] = final_npcs + + pickle.dump(ref_data, handle) + + return seed, config, map_src, init_obs, init_npcs, med_obs, actions, final_obs, 
final_npcs + + +class TestDeterministicReplay(unittest.TestCase): + + # CHECK ME: pausing the deterministic replay test while debugging actions/items + # because changes there would most likely to change the game play and make the test fail + __test__ = False + + @classmethod + def setUpClass(cls): + """ + First, check if there is a replay file on the repo that starts with 'replay_repo_' + If there is one, use it. + + Second, check if there a local replay file, which should be named 'replay_local.pickle' + If there is one, use it. If not create one. + + TODO: allow passing a different replay file + """ + # first, look for the repo replay file + replay_files = glob.glob(os.path.join('tests', 'replay_repo_*.pickle')) + if replay_files: + # there may be several, but we only take the first one [0] + cls.seed, cls.config, cls.map_src, cls.init_obs_src, cls.init_npcs_src, \ + cls.med_obs_src,cls.actions, cls.final_obs_src, cls.final_npcs_src = \ + load_replay_file(replay_files[0]) + else: + # if there is no repo replay file, then go with the default local file + if os.path.exists(LOCAL_REPLAY): + cls.seed, cls.config, cls.map_src, cls.init_obs_src, cls.init_npcs_src, \ + cls.med_obs_src, cls.actions, cls.final_obs_src, cls.final_npcs_src = \ + load_replay_file(LOCAL_REPLAY) + else: + cls.seed, cls.config, cls.map_src, cls.init_obs_src, cls.init_npcs_src, \ + cls.med_obs_src, cls.actions, cls.final_obs_src, cls.final_npcs_src = \ + generate_replay_file(LOCAL_REPLAY, TEST_HORIZON) + cls.horizon = len(cls.actions) + + logging.info('TestDetReplay: Setting up the replication env with seed %s', str(cls.seed)) + env_rep = ScriptedAgentTestEnv(cls.config, seed=cls.seed) + cls.init_obs_rep = env_rep.reset() + cls.init_npcs_rep = env_rep.realm.npcs.packet + + # extract the map + cls.map_rep = np.zeros((cls.config.MAP_SIZE, cls.config.MAP_SIZE)) + for r in range(cls.config.MAP_SIZE): + for c in range(cls.config.MAP_SIZE): + cls.map_rep[r,c] = 
env_rep.realm.map.tiles[r,c].material_id.val + + cls.med_obs_rep, cls.actions_rep = [], [] + logging.info('TestDetReplay: Running %s ticks', str(cls.horizon)) + for t in tqdm(range(cls.horizon)): + nxt_obs_rep, _, _, _ = env_rep.step(cls.actions[t]) + cls.med_obs_rep.append(nxt_obs_rep) + cls.final_obs_rep = nxt_obs_rep + cls.final_npcs_rep = env_rep.realm.npcs.packet + + def test_compare_maps(self): + self.assertEqual(np.sum(self.map_src != self.map_rep), 0) + + def test_compare_init_obs(self): + self.assertTrue(observations_are_equal(self.init_obs_src, self.init_obs_rep)) + + def test_compare_init_npcs(self): + self.assertTrue(observations_are_equal(self.init_npcs_src, self.init_npcs_rep)) + + def test_compare_final_obs(self): + self.assertTrue(observations_are_equal(self.final_obs_src, self.final_obs_rep)) + + def test_compare_final_npcs(self): + self.assertTrue(observations_are_equal(self.final_npcs_src, self.final_npcs_rep)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_emulation.py b/tests/test_emulation.py deleted file mode 100644 index eb9787c1e..000000000 --- a/tests/test_emulation.py +++ /dev/null @@ -1,85 +0,0 @@ -from pdb import set_trace as T -import numpy as np - -import nmmo - -def init_env(config_cls=nmmo.config.Small): - env = nmmo.Env(config_cls()) - obs = env.reset() - return env, obs - -def test_emulate_flat_obs(): - class Config(nmmo.config.Small): - EMULATE_FLAT_OBS = True - - init_env(Config) - -def test_emulate_flat_atn(): - class Config(nmmo.config.Small): - EMULATE_FLAT_ATN = True - - init_env(Config) - -def test_emulate_const_nent(): - class Config(nmmo.config.Small): - EMULATE_CONST_NENT = True - - init_env(Config) - -def test_all_emulation(): - class Config(nmmo.config.Small): - EMULATE_FLAT_OBS = True - EMULATE_FLAT_ATN = True - EMULATE_CONST_POP = True - - init_env(Config) - -def test_emulate_single_agent(): - class Config(nmmo.config.Small): - EMULATE_CONST_NENT = True - - config = Config() - envs = 
nmmo.emulation.multiagent_to_singleagent(config) - - for e in envs: - ob = e.reset() - for i in range(32): - ob, reward, done, info = e.step({}) - -def equals(config, batch1, batch2): - assert list(batch1.keys()) == list(batch2.keys()) - for (entity_name,), entity in nmmo.io.stimulus.Serialized: - if not entity.enabled(config): - continue - - batch1_attrs = batch1[entity_name] - batch2_attrs = batch2[entity_name] - - attr_keys = 'Continuous Discrete N'.split() - assert list(batch1_attrs.keys()) == list(batch2_attrs.keys()) == attr_keys - - for key in attr_keys: - assert np.array_equal(batch1_attrs[key], batch2_attrs[key]) - -def test_pack_unpack_obs(): - env, obs = init_env() - packed = nmmo.emulation.pack_obs(obs) - packed = np.vstack(list(packed.values())) - unpacked = nmmo.emulation.unpack_obs(env.config, packed) - batched = nmmo.emulation.batch_obs(env.config, obs) - - equals(env.config, unpacked, batched) - -def test_obs_pack_speed(benchmark): - env, obs = init_env() - benchmark(lambda: nmmo.emulation.pack_obs(obs)) - -def test_obs_unpack_speed(benchmark): - env, obs = init_env() - packed = nmmo.emulation.pack_obs(obs) - packed = np.vstack(list(packed.values())) - - benchmark(lambda: nmmo.emulation.unpack_obs(env.config, packed)) - -if __name__ == '__main__': - test_pack_unpack_obs() diff --git a/tests/test_eventlog.py b/tests/test_eventlog.py new file mode 100644 index 000000000..57d53647b --- /dev/null +++ b/tests/test_eventlog.py @@ -0,0 +1,123 @@ +import unittest + +import nmmo +from nmmo.datastore.numpy_datastore import NumpyDatastore +from nmmo.lib.event_log import EventState, EventLogger +from nmmo.lib.log import EventCode +from nmmo.entity.entity import Entity +from nmmo.systems.item import ItemState +from nmmo.systems.item import Scrap, Ration, Hat +from nmmo.systems import skill as Skill + + +class MockRealm: + def __init__(self): + self.config = nmmo.config.Default() + self.datastore = NumpyDatastore() + self.items = {} + 
self.datastore.register_object_type("Event", EventState.State.num_attributes) + self.datastore.register_object_type("Item", ItemState.State.num_attributes) + self.tick = 0 + + +class MockEntity(Entity): + # pylint: disable=super-init-not-called + def __init__(self, ent_id, **kwargs): + self.id = ent_id + self.level = kwargs.pop('attack_level', 0) + + @property + def ent_id(self): + return self.id + + @property + def attack_level(self): + return self.level + + +class TestEventLog(unittest.TestCase): + + def test_event_logging(self): + mock_realm = MockRealm() + event_log = EventLogger(mock_realm) + + mock_realm.tick = 0 # tick increase to 1 after all actions are processed + event_log.record(EventCode.EAT_FOOD, MockEntity(1)) + event_log.record(EventCode.DRINK_WATER, MockEntity(2)) + event_log.record(EventCode.SCORE_HIT, MockEntity(2), + combat_style=Skill.Melee, damage=50) + event_log.record(EventCode.PLAYER_KILL, MockEntity(3), + target=MockEntity(5, attack_level=5)) + + mock_realm.tick = 1 + event_log.record(EventCode.CONSUME_ITEM, MockEntity(4), + item=Ration(mock_realm, 8)) + event_log.record(EventCode.GIVE_ITEM, MockEntity(4)) + event_log.record(EventCode.DESTROY_ITEM, MockEntity(5)) + event_log.record(EventCode.HARVEST_ITEM, MockEntity(6), + item=Scrap(mock_realm, 3)) + + mock_realm.tick = 2 + event_log.record(EventCode.GIVE_GOLD, MockEntity(7)) + event_log.record(EventCode.LIST_ITEM, MockEntity(8), + item=Ration(mock_realm, 5), price=11) + event_log.record(EventCode.EARN_GOLD, MockEntity(9), amount=15) + event_log.record(EventCode.BUY_ITEM, MockEntity(10), + item=Scrap(mock_realm, 7), price=21) + #event_log.record(EventCode.SPEND_GOLD, env.realm.players[11], amount=25) + + mock_realm.tick = 3 + event_log.record(EventCode.LEVEL_UP, MockEntity(12), + skill=Skill.Fishing, level=3) + + mock_realm.tick = 4 + event_log.record(EventCode.GO_FARTHEST, MockEntity(12), distance=6) + event_log.record(EventCode.EQUIP_ITEM, MockEntity(12), + item=Hat(mock_realm, 4)) + + 
log_data = [list(row) for row in event_log.get_data()] + + self.assertListEqual(log_data, [ + [ 1, 1, 1, EventCode.EAT_FOOD, 0, 0, 0, 0, 0], + [ 2, 2, 1, EventCode.DRINK_WATER, 0, 0, 0, 0, 0], + [ 3, 2, 1, EventCode.SCORE_HIT, 1, 0, 50, 0, 0], + [ 4, 3, 1, EventCode.PLAYER_KILL, 0, 5, 0, 0, 5], + [ 5, 4, 2, EventCode.CONSUME_ITEM, 16, 8, 1, 0, 0], + [ 6, 4, 2, EventCode.GIVE_ITEM, 0, 0, 0, 0, 0], + [ 7, 5, 2, EventCode.DESTROY_ITEM, 0, 0, 0, 0, 0], + [ 8, 6, 2, EventCode.HARVEST_ITEM, 13, 3, 1, 0, 0], + [ 9, 7, 3, EventCode.GIVE_GOLD, 0, 0, 0, 0, 0], + [10, 8, 3, EventCode.LIST_ITEM, 16, 5, 1, 11, 0], + [11, 9, 3, EventCode.EARN_GOLD, 0, 0, 0, 15, 0], + [12, 10, 3, EventCode.BUY_ITEM, 13, 7, 1, 21, 0], + [13, 12, 4, EventCode.LEVEL_UP, 4, 3, 0, 0, 0], + [14, 12, 5, EventCode.GO_FARTHEST, 0, 0, 6, 0, 0], + [15, 12, 5, EventCode.EQUIP_ITEM, 2, 4, 1, 0, 0]]) + +if __name__ == '__main__': + unittest.main() + + """ + TEST_HORIZON = 50 + RANDOM_SEED = 338 + + from tests.testhelpers import ScriptedAgentTestConfig, ScriptedAgentTestEnv + + config = ScriptedAgentTestConfig() + env = ScriptedAgentTestEnv(config) + + env.reset(seed=RANDOM_SEED) + + from tqdm import tqdm + for tick in tqdm(range(TEST_HORIZON)): + env.step({}) + + # events to check + log = env.realm.event_log.get_data() + idx = (log[:,2] == tick+1) & (log[:,3] == EventCode.EQUIP_ITEM) + if sum(idx): + print(log[idx]) + print() + + print('done') + """ diff --git a/tests/test_performance.py b/tests/test_performance.py index 9c264a2d6..f27519e6c 100644 --- a/tests/test_performance.py +++ b/tests/test_performance.py @@ -1,100 +1,139 @@ -from pdb import set_trace as T -import pytest import nmmo -from nmmo.core.config import Config, Small, Large, Resource, Combat, Progression, NPC, AllGameSystems +from nmmo.core.config import (NPC, AllGameSystems, Combat, Communication, + Equipment, Exchange, Item, Medium, Profession, + Progression, Resource, Small, Terrain) +from scripted import baselines + # Test utils def 
create_and_reset(conf): - env = nmmo.Env(conf()) - env.reset(idx=1) + env = nmmo.Env(conf()) + env.reset(map_id=1) def create_config(base, *systems): - systems = (base, *systems) - name = '_'.join(cls.__name__ for cls in systems) + systems = (base, *systems) + name = '_'.join(cls.__name__ for cls in systems) - conf = type(name, systems, {})() + conf = type(name, systems, {})() - conf.TERRAIN_TRAIN_MAPS = 1 - conf.TERRAIN_EVAL_MAPS = 1 + conf.TERRAIN_TRAIN_MAPS = 1 + conf.TERRAIN_EVAL_MAPS = 1 + conf.IMMORTAL = True - return conf + return conf def benchmark_config(benchmark, base, nent, *systems): - conf = create_config(base, *systems) - conf.NENT = nent - - env = nmmo.Env(conf) - env.reset() + conf = create_config(base, *systems) + conf.PLAYER_N = nent + conf.PLAYERS = [baselines.Random] - benchmark(env.step, actions={}) - -def benchmark_env(benchmark, env, nent): - env.config.NENT = nent - env.reset() + env = nmmo.Env(conf) + env.reset() - benchmark(env.step, actions={}) + benchmark(env.step, actions={}) # Small map tests -- fast with greater coverage for individual game systems def test_small_env_creation(benchmark): - benchmark(lambda: nmmo.Env(Small())) + benchmark(lambda: nmmo.Env(Small())) def test_small_env_reset(benchmark): - env = nmmo.Env(Small()) - benchmark(lambda: env.reset(idx=1)) + config = Small() + config.PLAYERS = [baselines.Random] + env = nmmo.Env(config) + benchmark(lambda: env.reset(map_id=1)) + +def test_fps_base_small_1_pop(benchmark): + benchmark_config(benchmark, Small, 1) + +def test_fps_minimal_small_1_pop(benchmark): + benchmark_config(benchmark, Small, 1, Terrain, Resource, Combat, Progression) + +def test_fps_npc_small_1_pop(benchmark): + benchmark_config(benchmark, Small, 1, Terrain, Resource, Combat, Progression, NPC) + +def test_fps_test_small_1_pop(benchmark): + benchmark_config(benchmark, Small, 1, Terrain, Resource, Combat, Progression, Item, Exchange) -def test_fps_small_base_1_pop(benchmark): - benchmark_config(benchmark, 
Small, 1) +def test_fps_no_npc_small_1_pop(benchmark): + benchmark_config(benchmark, Small, 1, Terrain, Resource, + Combat, Progression, Item, Equipment, Profession, Exchange, Communication) -def test_fps_small_resource_1_pop(benchmark): - benchmark_config(benchmark, Small, 1, Resource) +def test_fps_all_small_1_pop(benchmark): + benchmark_config(benchmark, Small, 1, AllGameSystems) -def test_fps_small_combat_1_pop(benchmark): - benchmark_config(benchmark, Small, 1, Combat) +def test_fps_base_med_1_pop(benchmark): + benchmark_config(benchmark, Medium, 1) -def test_fps_small_progression_1_pop(benchmark): - benchmark_config(benchmark, Small, 1, Progression) +def test_fps_minimal_med_1_pop(benchmark): + benchmark_config(benchmark, Medium, 1, Terrain, Resource, Combat) -def test_fps_small_rcp_1_pop(benchmark): - benchmark_config(benchmark, Small, 1, Resource, Combat, Progression) +def test_fps_npc_med_1_pop(benchmark): + benchmark_config(benchmark, Medium, 1, Terrain, Resource, Combat, NPC) -def test_fps_small_npc_1_pop(benchmark): - benchmark_config(benchmark, Small, 1, NPC) +def test_fps_test_med_1_pop(benchmark): + benchmark_config(benchmark, Medium, 1, Terrain, Resource, Combat, Progression, Item, Exchange) -def test_fps_small_all_1_pop(benchmark): - benchmark_config(benchmark, Small, 1, AllGameSystems) +def test_fps_no_npc_med_1_pop(benchmark): + benchmark_config(benchmark, Medium, 1, Terrain, Resource, + Combat, Progression, Item, Equipment, Profession, Exchange, Communication) -def test_fps_small_rcp_100_pop(benchmark): - benchmark_config(benchmark, Small, 100, Resource, Combat, Progression) +def test_fps_all_med_1_pop(benchmark): + benchmark_config(benchmark, Medium, 1, AllGameSystems) -def test_fps_small_all_100_pop(benchmark): - benchmark_config(benchmark, Small, 100, AllGameSystems) +def test_fps_base_med_100_pop(benchmark): + benchmark_config(benchmark, Medium, 100) + +def test_fps_minimal_med_100_pop(benchmark): + benchmark_config(benchmark, Medium, 100, 
Terrain, Resource, Combat) + +def test_fps_npc_med_100_pop(benchmark): + benchmark_config(benchmark, Medium, 100, Terrain, Resource, Combat, NPC) + +def test_fps_test_med_100_pop(benchmark): + benchmark_config(benchmark, Medium, 100, Terrain, Resource, Combat, Progression, Item, Exchange) + +def test_fps_no_npc_med_100_pop(benchmark): + benchmark_config(benchmark, Medium, 100, Terrain, Resource, Combat, + Progression, Item, Equipment, Profession, Exchange, Communication) + +def test_fps_all_med_100_pop(benchmark): + benchmark_config(benchmark, Medium, 100, AllGameSystems) + + +''' +def benchmark_env(benchmark, env, nent): + env.config.PLAYER_N = nent + env.config.PLAYERS = [nmmo.agent.Random] + env.reset() + benchmark(env.step, actions={}) # Reuse large maps since we aren't benchmarking the reset function def test_large_env_creation(benchmark): - benchmark(lambda: nmmo.Env(Large())) + benchmark(lambda: nmmo.Env(Large())) def test_large_env_reset(benchmark): - env = nmmo.Env(Large()) - benchmark(lambda: env.reset(idx=1)) + env = nmmo.Env(Large()) + benchmark(lambda: env.reset(idx=1)) -LargeMapsRCP = nmmo.Env(create_config(Large, Resource, Combat, Progression)) +LargeMapsRCP = nmmo.Env(create_config(Large, Resource, Terrain, Combat, Progression)) LargeMapsAll = nmmo.Env(create_config(Large, AllGameSystems)) def test_fps_large_rcp_1_pop(benchmark): - benchmark_env(benchmark, LargeMapsRCP, 1) + benchmark_env(benchmark, LargeMapsRCP, 1) def test_fps_large_rcp_100_pop(benchmark): - benchmark_env(benchmark, LargeMapsRCP, 100) + benchmark_env(benchmark, LargeMapsRCP, 100) def test_fps_large_rcp_1000_pop(benchmark): - benchmark_env(benchmark, LargeMapsRCP, 1000) + benchmark_env(benchmark, LargeMapsRCP, 1000) def test_fps_large_all_1_pop(benchmark): - benchmark_env(benchmark, LargeMapsAll, 1) + benchmark_env(benchmark, LargeMapsAll, 1) def test_fps_large_all_100_pop(benchmark): - benchmark_env(benchmark, LargeMapsAll, 100) + benchmark_env(benchmark, LargeMapsAll, 100) def 
test_fps_large_all_1000_pop(benchmark): - benchmark_env(benchmark, LargeMapsAll, 1000) + benchmark_env(benchmark, LargeMapsAll, 1000) +''' diff --git a/tests/test_pettingzoo.py b/tests/test_pettingzoo.py index 600d1e0b9..3a8d78f30 100644 --- a/tests/test_pettingzoo.py +++ b/tests/test_pettingzoo.py @@ -1,9 +1,15 @@ -from pdb import set_trace as T -from pettingzoo.test import parallel_api_test import nmmo +from scripted import baselines def test_pettingzoo_api(): - env = nmmo.Env() - parallel_api_test(env, num_cycles=1000) + config = nmmo.config.Default() + config.PLAYERS = [baselines.Random] + # env = nmmo.Env(config) + # TODO: disabled due to Env not implementing the correct PettingZoo step() API + # parallel_api_test(env, num_cycles=1000) + + +if __name__ == '__main__': + test_pettingzoo_api() diff --git a/tests/test_rollout.py b/tests/test_rollout.py index e9209f166..35f3591f6 100644 --- a/tests/test_rollout.py +++ b/tests/test_rollout.py @@ -1,15 +1,14 @@ -from pdb import set_trace as T import nmmo - +from scripted.baselines import Random def test_rollout(): - config = nmmo.config.Default() - config.AGENTS = [nmmo.core.agent.Random] + config = nmmo.config.Default() + config.PLAYERS = [Random] - env = nmmo.Env() - env.reset() - for i in range(128): - env.step({}) + env = nmmo.Env(config) + env.reset() + for _ in range(128): + env.step({}) if __name__ == '__main__': - test_rollout() + test_rollout() diff --git a/tests/test_task.py b/tests/test_task.py new file mode 100644 index 000000000..f8c61040c --- /dev/null +++ b/tests/test_task.py @@ -0,0 +1,121 @@ +# pylint: disable=redefined-outer-name,super-init-not-called + +import logging +import unittest + +import nmmo +from nmmo.lib import task +from nmmo.systems import achievement +from nmmo.core.realm import Realm +from nmmo.entity.entity import Entity +from scripted.baselines import Sleeper + + +class Success(task.Task): + def completed(self, realm: Realm, entity: Entity) -> bool: + return True + +class 
Failure(task.Task): + def completed(self, realm: Realm, entity: Entity) -> bool: + return False + +class FakeTask(task.TargetTask): + def __init__(self, target: task.TaskTarget, param1: int, param2: float) -> None: + super().__init__(target) + self._param1 = param1 + self._param2 = param2 + + def completed(self, realm: Realm, entity: Entity) -> bool: + return False + + def description(self): + return [super().description(), self._param1, self._param2] + +# pylint: disable +class MockRealm(Realm): + def __init__(self): + pass + +class MockEntity(Entity): + def __init__(self): + pass + +realm = MockRealm() +entity = MockEntity() + +class TestTasks(unittest.TestCase): + + def test_operators(self): + self.assertFalse(task.AND(Success(), Failure(), Success()).completed(realm, entity)) + self.assertTrue(task.OR(Success(), Failure(), Success()).completed(realm, entity)) + self.assertTrue(task.AND(Success(), task.NOT(Failure()), Success()).completed(realm, entity)) + + def test_descriptions(self): + self.assertEqual( + task.AND(Success(), + task.NOT(task.OR(Success(), + FakeTask(task.TaskTarget("t1", []), 123, 3.45)))).description(), + ['AND', 'Success', ['NOT', ['OR', 'Success', [['FakeTask', 't1'], 123, 3.45]]]] + ) + + def test_team_helper(self): + team_helper = task.TeamHelper(range(1, 101), 5) + + self.assertSequenceEqual(team_helper.own_team(17).agents(), range(1, 21)) + self.assertSequenceEqual(team_helper.own_team(84).agents(), range(81, 101)) + + self.assertSequenceEqual(team_helper.left_team(84).agents(), range(61, 81)) + self.assertSequenceEqual(team_helper.right_team(84).agents(), range(1, 21)) + + self.assertSequenceEqual(team_helper.left_team(17).agents(), range(81, 101)) + self.assertSequenceEqual(team_helper.right_team(17).agents(), range(21, 41)) + + self.assertSequenceEqual(team_helper.all().agents(), range(1, 101)) + + def test_task_target(self): + task_target = task.TaskTarget("Foo", [1, 2, 8, 9]) + + 
self.assertEqual(task_target.member(2).description(), "Foo.2") + self.assertEqual(task_target.member(2).agents(), [8]) + + def test_sample(self): + sampler = task.TaskSampler() + + sampler.add_task_spec(Success) + sampler.add_task_spec(Failure) + sampler.add_task_spec(FakeTask, [ + [task.TaskTarget("t1", []), task.TaskTarget("t2", [])], + [1, 5, 10], + [0.1, 0.2, 0.3, 0.4] + ]) + + sampler.sample(max_clauses=5, max_clause_size=5, not_p=0.5) + + def test_default_sampler(self): + team_helper = task.TeamHelper(range(1, 101), 5) + sampler = task.TaskSampler.create_default_task_sampler(team_helper, 10) + + sampler.sample(max_clauses=5, max_clause_size=5, not_p=0.5) + + def test_completed_tasks_in_info(self): + config = nmmo.config.Default() + config.PLAYERS = [Sleeper] + config.TASKS = [ + achievement.Achievement(Success(), 10), + achievement.Achievement(Failure(), 100) + ] + + env = nmmo.Env(config) + + env.reset() + _, _, _, infos = env.step({}) + logging.info(infos) + self.assertEqual(infos[1][Success().to_string()], 10) + self.assertEqual(infos[1][Failure().to_string()], 0) + + _, _, _, infos = env.step({}) + self.assertEqual(infos[1][Success().to_string()], 0) + self.assertEqual(infos[1][Failure().to_string()], 0) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/testhelpers.py b/tests/testhelpers.py new file mode 100644 index 000000000..274828c68 --- /dev/null +++ b/tests/testhelpers.py @@ -0,0 +1,368 @@ +import logging +import unittest + +from copy import deepcopy +import numpy as np + +import nmmo + +from scripted import baselines +from nmmo.entity.entity import EntityState +from nmmo.io import action +from nmmo.systems import item as Item +from nmmo.core.realm import Realm + +# this function can be replaced by assertDictEqual +# but might be still useful for debugging +def actions_are_equal(source_atn, target_atn, debug=True): + + # compare the numbers and player ids + player_src = list(source_atn.keys()) + player_tgt = list(target_atn.keys()) 
+ if player_src != player_tgt: + if debug: + logging.error("players don't match") + return False + + # for each player, compare the actions + for ent_id in player_src: + atn1 = source_atn[ent_id] + atn2 = target_atn[ent_id] + + if list(atn1.keys()) != list(atn2.keys()): + if debug: + logging.error("action keys don't match. player: %s", str(ent_id)) + return False + + for atn, args in atn1.items(): + if atn2[atn] != args: + if debug: + logging.error("action args don't match. player: %s, action: %s", str(ent_id), str(atn)) + return False + + return True + + +# this function CANNOT be replaced by assertDictEqual +def observations_are_equal(source_obs, target_obs, debug=True): + + keys_src = list(source_obs.keys()) + keys_obs = list(target_obs.keys()) + if keys_src != keys_obs: + if debug: + logging.error("entities don't match") + return False + + for k in keys_src: + ent_src = source_obs[k] + ent_tgt = target_obs[k] + if list(ent_src.keys()) != list(ent_tgt.keys()): + if debug: + logging.error("entries don't match. key: %s", str(k)) + return False + + obj = ent_src.keys() + for o in obj: + + # ActionTargets causes a problem here, so skip it + if o == "ActionTargets": + continue + + obj_src = ent_src[o] + obj_tgt = ent_tgt[o] + if np.sum(obj_src != obj_tgt) > 0: + if debug: + logging.error("objects don't match. 
key: %s, obj: %s", str(k), str(o)) + return False + + return True + + +def player_total(env): + return sum(ent.gold.val for ent in env.realm.players.values()) + + +def count_actions(tick, actions): + cnt_action = {} + for atn in (action.Move, action.Attack, action.Sell, action.Use, action.Give, action.Buy): + cnt_action[atn] = 0 + + for ent_id in actions: + for atn, _ in actions[ent_id].items(): + if atn in cnt_action: + cnt_action[atn] += 1 + else: + cnt_action[atn] = 1 + + info_str = f"Tick: {tick}, acting agents: {len(actions)}, action counts " + \ + f"move: {cnt_action[action.Move]}, attack: {cnt_action[action.Attack]}, " + \ + f"sell: {cnt_action[action.Sell]}, use: {cnt_action[action.Use]}, " + \ + f"give: {cnt_action[action.Give]}, buy: {cnt_action[action.Buy]}" + logging.info(info_str) + + return cnt_action + + +class ScriptedAgentTestConfig(nmmo.config.Small, nmmo.config.AllGameSystems): + + __test__ = False + + LOG_ENV = True + + LOG_MILESTONES = True + LOG_EVENTS = False + LOG_VERBOSE = False + + SAVE_REPLAY = True + + SPECIALIZE = True + PLAYERS = [ + baselines.Fisher, baselines.Herbalist, + baselines.Prospector,baselines.Carver, baselines.Alchemist, + baselines.Melee, baselines.Range, baselines.Mage] + + +# pylint: disable=abstract-method,duplicate-code +class ScriptedAgentTestEnv(nmmo.Env): + ''' + EnvTest step() bypasses some differential treatments for scripted agents + To do so, actions of scripted must be serialized using the serialize_actions function above + ''' + __test__ = False + + def __init__(self, config: nmmo.config.Config, seed=None): + super().__init__(config=config, seed=seed) + + # all agent must be scripted agents when using ScriptedAgentTestEnv + for ent in self.realm.players.values(): + assert isinstance(ent.agent, baselines.Scripted), 'All agent must be scripted.' 
+ + # this is to cache the actions generated by scripted policies + self.actions = {} + + def reset(self, map_id=None, seed=None, options=None): + self.actions = {} + # manually resetting the EntityState, ItemState datastore tables + EntityState.State.table(self.realm.datastore).reset() + Item.ItemState.State.table(self.realm.datastore).reset() + return super().reset(map_id=map_id, seed=seed, options=options) + + def _compute_scripted_agent_actions(self, actions): + assert actions is not None, "actions must be provided, even it's {}" + # if actions are not provided, generate actions using the scripted policy + if actions == {}: + for eid, ent in self.realm.players.items(): + actions[eid] = ent.agent(self.obs[eid]) + + # cache the actions for replay before deserialization + self.actions = deepcopy(actions) + + # if actions are provided, just run ent.agent() to set the RNG to the same state + else: + # NOTE: This is a hack to set the random number generator to the same state + # since scripted agents also use RNG. Without this, the RNG is in different state, + # and the env.step() does not give the same results in the deterministic replay. + for eid, ent in self.realm.players.items(): + ent.agent(self.obs[eid]) + + return actions + + +def change_spawn_pos(realm: Realm, ent_id: int, new_pos): + # check if the position is valid + assert realm.map.tiles[new_pos].habitable, "Given pos is not habitable." 
+ assert realm.entity(ent_id), "No such entity in the realm" + + entity = realm.entity(ent_id) + old_pos = entity.pos + realm.map.tiles[old_pos].remove_entity(ent_id) + + # set to new pos + entity.row.update(new_pos[0]) + entity.col.update(new_pos[1]) + entity.spawn_pos = new_pos + realm.map.tiles[new_pos].add_entity(entity) + +def provide_item(realm: Realm, ent_id: int, + item: Item.Item, level: int, quantity: int): + for _ in range(quantity): + realm.players[ent_id].inventory.receive( + item(realm, level=level)) + + +# pylint: disable=invalid-name,protected-access +class ScriptedTestTemplate(unittest.TestCase): + + @classmethod + def setUpClass(cls): + # only use Combat agents + cls.config = ScriptedAgentTestConfig() + cls.config.PROVIDE_ACTION_TARGETS = True + + cls.config.PLAYERS = [baselines.Melee, baselines.Range, baselines.Mage] + cls.config.PLAYER_N = 3 + #cls.config.IMMORTAL = True + + # set up agents to test ammo use + cls.policy = { 1:'Melee', 2:'Range', 3:'Mage' } + # 1 cannot hit 3, 2 can hit 1, 3 cannot hit 2 + cls.spawn_locs = { 1:(17, 17), 2:(17, 19), 3:(21, 21) } + cls.ammo = { 1:Item.Scrap, 2:Item.Shaving, 3:Item.Shard } + cls.ammo_quantity = 2 + + # items to provide + cls.init_gold = 5 + cls.item_level = [0, 3] # 0 can be used, 3 cannot be used + cls.item_sig = {} + + def _make_item_sig(self): + item_sig = {} + for ent_id, ammo in self.ammo.items(): + item_sig[ent_id] = [] + for item in [ammo, Item.Top, Item.Gloves, Item.Ration, Item.Poultice]: + for lvl in self.item_level: + item_sig[ent_id].append((item, lvl)) + + return item_sig + + def _setup_env(self, random_seed, check_assert=True): + """ set up a new env and perform initial checks """ + env = ScriptedAgentTestEnv(self.config, seed=random_seed) + env.reset() + + # provide money for all + for ent_id in env.realm.players: + env.realm.players[ent_id].gold.update(self.init_gold) + + # provide items that are in item_sig + self.item_sig = self._make_item_sig() + for ent_id, items in 
self.item_sig.items(): + for item_sig in items: + if item_sig[0] == self.ammo[ent_id]: + provide_item(env.realm, ent_id, item_sig[0], item_sig[1], self.ammo_quantity) + else: + provide_item(env.realm, ent_id, item_sig[0], item_sig[1], 1) + + # teleport the players, if provided with specific locations + for ent_id, pos in self.spawn_locs.items(): + change_spawn_pos(env.realm, ent_id, pos) + + env.obs = env._compute_observations() + + if check_assert: + self._check_default_asserts(env) + + return env + + def _check_ent_mask(self, ent_obs, atn, target_id): + assert atn in [action.Give, action.GiveGold], "Invalid action" + gym_obs = ent_obs.to_gym() + mask = gym_obs['ActionTargets'][atn][action.Target][:ent_obs.entities.len] > 0 + + return target_id in ent_obs.entities.ids[mask] + + def _check_inv_mask(self, ent_obs, atn, item_sig): + assert atn in [action.Destroy, action.Give, action.Sell, action.Use], "Invalid action" + gym_obs = ent_obs.to_gym() + mask = gym_obs['ActionTargets'][atn][action.InventoryItem][:ent_obs.inventory.len] > 0 + inv_idx = ent_obs.inventory.sig(*item_sig) + + return ent_obs.inventory.id(inv_idx) in ent_obs.inventory.ids[mask] + + def _check_mkt_mask(self, ent_obs, item_id): + gym_obs = ent_obs.to_gym() + mask = gym_obs['ActionTargets'][action.Buy][action.MarketItem][:ent_obs.market.len] > 0 + + return item_id in ent_obs.market.ids[mask] + + def _check_default_asserts(self, env): + """ The below asserts are based on the hardcoded values in setUpClass() + This should not run when different values were used + """ + # check if the agents are in specified positions + for ent_id, pos in self.spawn_locs.items(): + self.assertEqual(env.realm.players[ent_id].pos, pos) + + for ent_id, sig_list in self.item_sig.items(): + # ammo instances are in the datastore and global item registry (realm) + inventory = env.obs[ent_id].inventory + self.assertTrue(inventory.len == len(sig_list)) + for inv_idx in range(inventory.len): + item_id = inventory.id(inv_idx) + 
self.assertTrue(Item.ItemState.Query.by_id(env.realm.datastore, item_id) is not None) + self.assertTrue(item_id in env.realm.items) + + for lvl in self.item_level: + inv_idx = inventory.sig(self.ammo[ent_id], lvl) + self.assertTrue(inv_idx is not None) + self.assertEqual(self.ammo_quantity, # provided 2 ammos + Item.ItemState.parse_array(inventory.values[inv_idx]).quantity) + + # check ActionTargets + ent_obs = env.obs[ent_id] + + if env.config.ITEM_SYSTEM_ENABLED: + # USE InventoryItem mask + for item_sig in sig_list: + if item_sig[1] == 0: + # items that can be used + self.assertTrue(self._check_inv_mask(ent_obs, action.Use, item_sig)) + else: + # items that are too high to use + self.assertFalse(self._check_inv_mask(ent_obs, action.Use, item_sig)) + + if env.config.EXCHANGE_SYSTEM_ENABLED: + # SELL InventoryItem mask + for item_sig in sig_list: + # the agent can sell anything now + self.assertTrue(self._check_inv_mask(ent_obs, action.Sell, item_sig)) + + # BUY MarketItem mask -- there is nothing on the market, so mask should be all 0 + self.assertTrue(len(env.obs[ent_id].market.ids) == 0) + + def _check_assert_make_action(self, env, atn, test_cond): + assert atn in [action.Give, action.GiveGold, action.Buy], "Invalid action" + actions = {} + for ent_id, cond in test_cond.items(): + ent_obs = env.obs[ent_id] + + if atn in [action.Give, action.GiveGold]: + # self should be always masked + self.assertFalse(self._check_ent_mask(ent_obs, atn, ent_id)) + + # check if the target is masked as expected + self.assertEqual( + cond['ent_mask'], + self._check_ent_mask(ent_obs, atn, cond['tgt_id']), + f"ent_id: {ent_id}, atn: {atn}, tgt_id: {cond['tgt_id']}" + ) + + if atn in [action.Give]: + self.assertEqual( + cond['inv_mask'], + self._check_inv_mask(ent_obs, atn, cond['item_sig']), + f"ent_id: {ent_id}, atn: {atn}, item_sig: {cond['item_sig']}" + ) + + if atn in [action.Buy]: + self.assertEqual( + cond['mkt_mask'], + self._check_mkt_mask(ent_obs, cond['item_id']), + 
f"ent_id: {ent_id}, atn: {atn}, item_id: {cond['item_id']}" + ) + + # append the actions + if atn == action.Give: + actions[ent_id] = { action.Give: { + action.InventoryItem: env.obs[ent_id].inventory.sig(*cond['item_sig']), + action.Target: cond['tgt_id'] } } + + elif atn == action.GiveGold: + actions[ent_id] = { action.GiveGold: + { action.Target: cond['tgt_id'], action.Price: cond['gold'] } } + + elif atn == action.Buy: + mkt_idx = ent_obs.market.index(cond['item_id']) + actions[ent_id] = { action.Buy: { action.MarketItem: mkt_idx } } + + return actions diff --git a/tools/task_generator.py b/tools/task_generator.py new file mode 100644 index 000000000..505b9c4e2 --- /dev/null +++ b/tools/task_generator.py @@ -0,0 +1,26 @@ +import argparse +import nmmo.lib.task as task + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--tasks", type=int, default=10) + parser.add_argument("--num_teams", type=int, default=10) + parser.add_argument("--team_size", type=int, default=1) + parser.add_argument("--min_clauses", type=int, default=1) + parser.add_argument("--max_clauses", type=int, default=1) + parser.add_argument("--min_clause_size", type=int, default=1) + parser.add_argument("--max_clause_size", type=int, default=1) + parser.add_argument("--not_p", type=float, default=0.5) + + flags = parser.parse_args() + + team_helper = task.TeamHelper(range(flags.team_size * flags.num_teams), flags.num_teams) + sampler = task.TaskSampler.create_default_task_sampler(team_helper, 0) + for i in range(flags.tasks): + task = sampler.sample( + flags.min_clauses, flags.max_clauses, + flags.min_clause_size, flags.max_clause_size, flags.not_p) + print(task.to_string()) + + +