# MarlGrid Envs creation

> Acts as the module `__init__` for MarlGrid environment creation.

In [None]:
#| default_exp envs.marl_grid.__init__

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from fastcore import *
from fastcore.utils import *

In [None]:
#| export
import inspect
import numpy.random as random
import sys
from gymnasium.envs.registration import register as gym_register

from mawm.envs.marl_grid.envs.findgoal_env.findgoal import FindGoalMultiGrid
from mawm.envs.marl_grid.envs.redbluedoors_envs.redbluedoors import RedBlueDoorsMultiGrid

from mawm.envs.marl_grid.agents  import GridAgentInterface
from mawm.envs.marl_grid.base import MultiGridEnv

# Handle to this very module object; `register_env` uses it with `setattr`
# to attach dynamically generated Gym env classes as module attributes.
this_module = sys.modules[__name__]

In [None]:
#| export
def get_env_name(env_cfg):
    """
    Derive the environment name string from an env config object.

    The name is assembled from these attributes of `env_cfg`:
    `num_agents`, `comm_len` / `discrete_comm` (only when `comm_len > 0`),
    `view_size` (only when it differs from 7) and `grid_size`.

    Only `env_type == 'c'` is accepted; anything else fails the assertion.
    """
    assert env_cfg.env_type == 'c'

    pieces = [f'MarlGrid-{env_cfg.num_agents}Agent']

    # Communication tag, e.g. "4C" or "4Ccont" for non-discrete messages.
    if env_cfg.comm_len > 0:
        comm_tag = f'{env_cfg.comm_len}C'
        if not env_cfg.discrete_comm:
            comm_tag += 'cont'
        pieces.append(comm_tag)

    # A view size of 7 is left out of the name.
    if env_cfg.view_size != 7:
        pieces.append(f'{env_cfg.view_size}Vs')

    pieces.append(f'{env_cfg.grid_size}x{env_cfg.grid_size}-v0')
    return ''.join(pieces)

In [None]:
#| export
def make_agents(
        n_agents,
        view_size,
        view_tile_size=8,
        view_offset=0,
        agent_color=None,
        observation_style='image',
        observe_position=False,
        observe_self_position=False,
        observe_done=False,
        observe_self_env_act=False,
        observe_t=False,
        restrict_actions=True,
        comm_dim=0,
        comm_len=1,
        discrete_comm=True,
        n_adversaries=0,
        see_through_walls=True,
        neutral_shape=True,
        can_overlap=True,
        observe_orientation=True
):
    """
    Create a list of `n_agents` `GridAgentInterface` objects.

    Each agent takes the next color of a fixed six-color palette unless
    `agent_color` is given, in which case every agent uses `agent_color`.
    `view_size` is either one value shared by all agents or a list holding
    one value per agent. `n_adversaries` agent indices are sampled at random
    without replacement and those agents get `is_adversary=1` (0 otherwise).
    All remaining arguments are forwarded unchanged to `GridAgentInterface`.

    Asserts that `n_agents` does not exceed the palette size and that a
    list-valued `view_size` has exactly `n_agents` entries.
    """
    palette = ['red', 'blue', 'purple', 'orange', 'olive', 'pink']
    assert n_agents <= len(palette)

    if isinstance(view_size, list):
        assert len(view_size) == n_agents
        per_agent_view = view_size
    else:
        per_agent_view = [view_size] * n_agents

    # assign roles differently in each environment
    adversary_ids = random.choice(list(range(n_agents)),
                                  n_adversaries,
                                  replace=False)

    # Settings that are identical for every agent.
    shared_settings = dict(
        neutral_shape=neutral_shape,
        can_overlap=can_overlap,
        view_tile_size=view_tile_size,
        view_offset=view_offset,
        observation_style=observation_style,
        observe_position=observe_position,
        observe_self_position=observe_self_position,
        observe_done=observe_done,
        observe_self_env_act=observe_self_env_act,
        observe_t=observe_t,
        restrict_actions=restrict_actions,
        see_through_walls=see_through_walls,
        comm_dim=comm_dim,
        comm_len=comm_len,
        discrete_comm=discrete_comm,
        n_agents=n_agents,
        observe_orientation=observe_orientation,
    )

    agents = []
    for idx, palette_color in enumerate(palette[:n_agents]):
        agents.append(GridAgentInterface(
            color=palette_color if agent_color is None else agent_color,
            view_size=per_agent_view[idx],
            is_adversary=1 if idx in adversary_ids else 0,
            **shared_settings
        ))
    return agents




In [None]:
#| export
def get_env_creator(
        env_class,
        n_agents,
        grid_size,
        view_size,
        view_tile_size=8,
        view_offset=0,
        agent_color=None,
        observation_style='image',
        observe_position=False,
        observe_self_position=False,
        observe_done=False,
        observe_self_env_act=False,
        observe_t=False,
        restrict_actions=True,
        comm_dim=0,
        comm_len=1,
        discrete_comm=True,
        n_adversaries=0,
        neutral_shape=True,
        can_overlap=True,
        env_kwargs=None
):
    """
    Return a factory `env_creator(env_config)` that builds `env_class` envs.

    Every call of the returned factory:
      1. creates a fresh list of agents with `make_agents` from the agent
         settings captured here,
      2. merges `env_kwargs` into `env_config`,
      3. sets `env_config['grid_size']` and `env_config['agents']`,
      4. returns `env_class(env_config)`.

    Note that `env_config` is updated in place, so the caller's dict is
    modified. `see_through_walls` and `observe_orientation` are not exposed
    here; `make_agents` uses its own defaults for them.

    Args:
        env_class: class called with the final config dict.
        n_agents, view_size, ..., can_overlap: forwarded to `make_agents`.
        grid_size: stored under the `'grid_size'` config key.
        env_kwargs: optional mapping of extra config entries merged into
            `env_config` on every call (defaults to no extra entries).
    """
    # `None` instead of a `{}` default avoids sharing one dict object between
    # all calls of this function.
    extra_config = {} if env_kwargs is None else env_kwargs

    def env_creator(env_config):
        agents = make_agents(n_agents,
                             view_size,
                             view_tile_size=view_tile_size,
                             view_offset=view_offset,
                             agent_color=agent_color,
                             observation_style=observation_style,
                             observe_position=observe_position,
                             observe_self_position=observe_self_position,
                             observe_done=observe_done,
                             observe_self_env_act=observe_self_env_act,
                             observe_t=observe_t,
                             restrict_actions=restrict_actions,
                             comm_dim=comm_dim,
                             comm_len=comm_len,
                             discrete_comm=discrete_comm,
                             n_adversaries=n_adversaries,
                             neutral_shape=neutral_shape,
                             can_overlap=can_overlap)
        # In-place update: explicit settings below win over `env_config`.
        env_config.update(extra_config)
        env_config['grid_size'] = grid_size
        env_config['agents'] = agents
        return env_class(env_config)

    return env_creator



In [None]:
#| export
def get_env_class_creator(
        env_class,
        n_agents,
        grid_size,
        view_size,
        view_tile_size=8,
        view_offset=0,
        agent_color=None,
        observation_style='image',
        observe_position=False,
        observe_self_position=False,
        observe_done=False,
        observe_self_env_act=False,
        observe_t=False,
        restrict_actions=True,
        comm_dim=0,
        comm_len=1,
        discrete_comm=True,
        n_adversaries=0,
        neutral_shape=True,
        can_overlap=True,
        env_kwargs=None
):
    """
    Return a zero-argument class `GymEnv` that produces `env_class` envs.

    Calling `GymEnv()` creates fresh agents with `make_agents`, builds a
    config dict from `env_kwargs` plus `'grid_size'` and `'agents'`, and
    returns an `env_class` instance initialised with `config=<that dict>`.

    Args:
        env_class: environment class to instantiate (and base of `GymEnv`).
        n_agents, view_size, ..., can_overlap: forwarded to `make_agents`.
        grid_size: stored under the `'grid_size'` config key.
        env_kwargs: optional mapping of extra config entries. It is copied
            for every instantiation and never modified.
    """
    class GymEnv(env_class):
        def __new__(cls):
            agents = make_agents(n_agents,
                                 view_size,
                                 view_tile_size=view_tile_size,
                                 view_offset=view_offset,
                                 agent_color=agent_color,
                                 observation_style=observation_style,
                                 observe_position=observe_position,
                                 observe_self_position=observe_self_position,
                                 observe_done=observe_done,
                                 observe_self_env_act=observe_self_env_act,
                                 observe_t=observe_t,
                                 restrict_actions=restrict_actions,
                                 comm_dim=comm_dim,
                                 comm_len=comm_len,
                                 discrete_comm=discrete_comm,
                                 n_adversaries=n_adversaries,
                                 neutral_shape=neutral_shape,
                                 can_overlap=can_overlap)
            # The object created is a plain `env_class` instance, not a
            # `GymEnv`; since it is not an instance of `cls`, Python does not
            # call `__init__` automatically, so it is invoked explicitly below.
            instance = super(env_class, GymEnv).__new__(env_class)

            # Copy per instantiation: aliasing `env_kwargs` would write
            # 'grid_size'/'agents' into the caller's (or a shared default)
            # dict and make all instances share a single config object.
            env_config = {} if env_kwargs is None else dict(env_kwargs)
            env_config['grid_size'] = grid_size
            env_config['agents'] = agents

            instance.__init__(config=env_config)
            return instance
    return GymEnv



In [None]:
#| export
def register_env(
        env_name,
        n_agents,
        grid_size,
        view_size,
        view_tile_size,
        comm_dim,
        comm_len,
        discrete_comm,
        n_adversaries,
        observation_style,
        observe_position,
        observe_self_position,
        observe_done,
        observe_self_env_act,
        observe_t,
        neutral_shape,
        can_overlap,
        use_gym_env=False,
        env_configs=None,
        clutter_density=0.15,
        env_type='c',
):
    """
    Create a MarlGrid environment and optionally register it with Gym.

    `env_type` selects the environment class:
      * `'c'`: `FindGoalMultiGrid`, agents built with `restrict_actions=True`;
      * `'d'`: `RedBlueDoorsMultiGrid`, agents built with
        `restrict_actions=False`; asserts `n_agents == 2` and
        `n_adversaries == 0`.

    When `use_gym_env` is true, a Gym-instantiable class is generated with
    `get_env_class_creator`, attached to this module as `env_0`, and
    registered under `env_name`.
    NOTE: the attribute name `env_0` is fixed, so a later registration
    replaces the class stored by an earlier one.

    Args:
        env_name: id used for the Gym registration.
        n_agents, view_size, ..., can_overlap: agent settings forwarded to
            the env creators.
        grid_size: grid size stored in the env config.
        use_gym_env: whether to also register the env with Gym.
        env_configs: optional config dict handed to the env creator, which
            updates it in place ('clutter_density', 'randomize_goal',
            'grid_size', 'agents'). A new dict is used when omitted.
        clutter_density: stored under the `'clutter_density'` config key.
        env_type: `'c'` or `'d'`, see above.

    Returns:
        The environment instance built by the env creator.

    Raises:
        ValueError: if `env_type` is neither `'c'` nor `'d'`.
        AssertionError: if the `'d'` constraints are violated.
    """
    if env_type == 'c':
        env_class = FindGoalMultiGrid
        restrict_actions = True
    elif env_type == 'd':
        env_class = RedBlueDoorsMultiGrid
        assert n_agents == 2
        assert n_adversaries == 0
        restrict_actions = False
    else:
        raise ValueError(f'env type {env_type} not supported')

    # Agent settings shared by both creators. `can_overlap` is part of this
    # set so the directly created env honours it, exactly like the Gym class.
    agent_settings = dict(
        n_agents=n_agents,
        grid_size=grid_size,
        view_size=view_size,
        view_tile_size=view_tile_size,
        comm_dim=comm_dim,
        comm_len=comm_len,
        discrete_comm=discrete_comm,
        n_adversaries=n_adversaries,
        observation_style=observation_style,
        observe_position=observe_position,
        observe_self_position=observe_self_position,
        observe_done=observe_done,
        observe_self_env_act=observe_self_env_act,
        observe_t=observe_t,
        restrict_actions=restrict_actions,
        neutral_shape=neutral_shape,
        can_overlap=can_overlap,
    )

    env_creator = get_env_creator(
        env_class,
        env_kwargs={
            'clutter_density': clutter_density,
            'randomize_goal': True,
        },
        **agent_settings
    )

    # register for Gym
    if use_gym_env:
        gym_env_class = get_env_class_creator(
            env_class,
            env_kwargs={
                'clutter_density': clutter_density,
                'randomize_goal': True,
            },
            **agent_settings
        )

        env_class_name = 'env_0'
        setattr(this_module, env_class_name, gym_env_class)
        # The entry point must name the module the class was attached to,
        # otherwise Gym cannot resolve it when the env is made.
        gym_register(env_name,
                     entry_point=f'{this_module.__name__}:{env_class_name}')

    # The creator writes into the config dict, so never share a default
    # dict object between calls.
    if env_configs is None:
        env_configs = {}

    env = env_creator(env_configs)
    return env




In [None]:
#| export
def env_from_config(env_config, randomize_seed=True):
    """
    Instantiate an environment class selected by name from `env_config`.

    `env_config['env_class']` names a class found in this module's globals
    that subclasses `MultiGridEnv`; every other entry of `env_config` is
    passed to that class as a keyword argument. With `randomize_seed` set,
    a random integer offset is added to the `'seed'` entry (0 if absent).
    """
    # Collect every MultiGridEnv subclass visible at module level.
    possible_envs = {}
    for global_name, global_obj in globals().items():
        if inspect.isclass(global_obj) and issubclass(global_obj, MultiGridEnv):
            possible_envs[global_name] = global_obj

    env_class = possible_envs[env_config['env_class']]

    # Everything except the class selector becomes a constructor kwarg.
    env_kwargs = {}
    for key, value in env_config.items():
        if key == 'env_class':
            continue
        env_kwargs[key] = value

    if randomize_seed:
        base_seed = env_kwargs.get('seed', 0)
        env_kwargs['seed'] = base_seed + random.randint(0, 1337 * 1337)

    return env_class(**env_kwargs)


In [None]:
#| export
import gymnasium as gym
import numpy as np
from easydict import EasyDict as edict
# from mawm.envs.marl_grid._envs import register_env

def create_grid_world_env(env_cfg):
    """
    Build an env instance from the attributes of `env_cfg`.

    The env name is produced by `get_env_name`; the remaining settings are
    read from `env_cfg` and handed to `register_env`, whose return value is
    returned. `comm_dim` is fixed to 2, `n_adversaries` to 0 and Gym
    registration is disabled (`use_gym_env=False`).
    """
    # Attributes copied verbatim into the env config dict (key == attr name).
    config_keys = (
        'max_steps',
        'team_reward_multiplier',
        'team_reward_type',
        'team_reward_freq',
        'seed',
        'active_after_done',
        'discrete_position',
        'separate_rew_more',
        'info_gain_rew',
    )

    registration_args = dict(
        env_name=get_env_name(env_cfg),
        n_agents=env_cfg.num_agents,
        grid_size=env_cfg.grid_size,
        view_size=env_cfg.view_size,
        view_tile_size=env_cfg.view_tile_size,
        comm_dim=2,
        comm_len=env_cfg.comm_len,
        discrete_comm=env_cfg.discrete_comm,
        n_adversaries=0,
        observation_style=env_cfg.observation_style,
        observe_position=env_cfg.observe_position,
        observe_self_position=env_cfg.observe_self_position,
        observe_done=env_cfg.observe_done,
        observe_self_env_act=env_cfg.observe_self_env_act,
        observe_t=env_cfg.observe_t,
        neutral_shape=env_cfg.neutral_shape,
        can_overlap=env_cfg.can_overlap,
        use_gym_env=False,
        env_configs={key: getattr(env_cfg, key) for key in config_keys},
        clutter_density=env_cfg.clutter_density,
    )
    return register_env(**registration_args)





In [None]:
#| export
import numpy as np
import gymnasium as  gym
from mawm.envs.marl_grid.wrappers import DictObservationNormalizationWrapper,  GridWorldEvaluatorWrapper

def make_env(env_cfg, lock=None):
    """
    Use this to make Environments.

    Asserts that `env_cfg.env_name` starts with 'MarlGrid', builds the grid
    world env with `create_grid_world_env`, then wraps it first in
    `GridWorldEvaluatorWrapper` and then in
    `DictObservationNormalizationWrapper`. `lock` is accepted but not used
    by this function.
    """
    assert env_cfg.env_name.startswith('MarlGrid')

    grid_env = create_grid_world_env(env_cfg)
    evaluated_env = GridWorldEvaluatorWrapper(grid_env)
    return DictObservationNormalizationWrapper(evaluated_env)


In [None]:
#| hide
import nbdev; nbdev.nbdev_export()