In [1]:
from typing import Optional, Sequence, Tuple
from jax_party.env import JaxParty, PartyGenerator
from jax_party.env_types import Action

import jax
import jax.numpy as jnp
import chex
import itertools
from jax_party.utils import tree_slice

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Tuple, Union
from mava.types import Observation, ObservationGlobalState, State
from jumanji.env import Environment
from jumanji.wrappers import Wrapper
from jumanji.types import TimeStep
from jumanji import specs


class JumanjiMarlWrapper(Wrapper, ABC):
    """Base wrapper adapting a Jumanji multi-agent environment to Mava observations.

    Subclasses implement :meth:`modify_timestep` to convert the wrapped
    environment's timesteps. When ``add_global_state`` is true, every timestep
    returned by :meth:`reset` and :meth:`step` carries an
    ``ObservationGlobalState`` whose ``global_state`` comes from
    :meth:`get_global_state`.

    Args:
        env: The environment to wrap. ``__init__`` reads its ``num_agents`` and
            ``time_limit`` attributes.
        add_global_state: Whether to attach a global state to observations.
    """

    def __init__(self, env: Environment, add_global_state: bool):
        self.add_global_state = add_global_state
        super().__init__(env)
        self.num_agents = self._env.num_agents
        self.time_limit = self._env.time_limit

    @abstractmethod
    def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
        """Modify the timestep for `step` and `reset`."""
        pass

    def get_global_state(self, obs: Observation) -> chex.Array:
        """Default global state for environments without one: all observations concatenated.

        The agents' views are concatenated along axis 0 and the result is tiled
        so that every agent receives its own identical copy.

        NOTE(review): assumes ``obs.agents_view`` is (num_agents, num_obs_features),
        which yields a (num_agents, num_agents * num_obs_features) result matching
        ``observation_spec``. A 1-D ``agents_view`` would make ``jnp.concatenate``
        fail on zero-dimensional elements — confirm for each wrapped env.
        """
        global_state = jnp.concatenate(obs.agents_view, axis=0)
        global_state = jnp.tile(global_state, (self.num_agents, 1))
        return global_state

    def _finalize_timestep(self, timestep: TimeStep) -> TimeStep:
        """Apply `modify_timestep` and, if enabled, attach the global state.

        Shared by `reset` and `step` so both paths build observations identically.
        """
        timestep = self.modify_timestep(timestep)
        if not self.add_global_state:
            return timestep

        observation = ObservationGlobalState(
            global_state=self.get_global_state(timestep.observation),
            agents_view=timestep.observation.agents_view,
            action_mask=timestep.observation.action_mask,
            step_count=timestep.observation.step_count,
        )
        return timestep.replace(observation=observation)

    def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
        """Reset the environment.

        Args:
            key: PRNG key forwarded to the wrapped environment's ``reset``.

        Returns:
            The wrapped environment's new state and the processed first timestep.
        """
        state, timestep = self._env.reset(key)
        return state, self._finalize_timestep(timestep)

    def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]:
        """Step the environment.

        Args:
            state: Current state of the wrapped environment.
            action: Actions forwarded unchanged to the wrapped environment's ``step``.

        Returns:
            The next state and the processed timestep.
        """
        state, timestep = self._env.step(state, action)
        return state, self._finalize_timestep(timestep)

    @cached_property
    def observation_spec(
        self,
    ) -> specs.Spec[Union[Observation, ObservationGlobalState]]:
        """Specification of the observation of the environment.

        ``agents_view`` and ``action_mask`` specs come from the wrapped
        environment. ``step_count`` is a per-agent integer bounded by
        ``[0, time_limit]``. When ``add_global_state`` is set, a
        ``global_state`` spec of shape
        ``(num_agents, num_agents * num_obs_features)`` is added, where
        ``num_obs_features`` is the last dimension of the ``agents_view`` spec.
        """
        step_count = specs.BoundedArray(
            (self.num_agents,),
            int,
            jnp.zeros(self.num_agents, dtype=int),
            jnp.repeat(self.time_limit, self.num_agents),
            "step_count",
        )

        obs_spec = self._env.observation_spec
        obs_data = {
            "agents_view": obs_spec.agents_view,
            "action_mask": obs_spec.action_mask,
            "step_count": step_count,
        }

        if self.add_global_state:
            num_obs_features = obs_spec.agents_view.shape[-1]
            global_state = specs.Array(
                (self.num_agents, self.num_agents * num_obs_features),
                obs_spec.agents_view.dtype,
                "global_state",
            )
            obs_data["global_state"] = global_state
            return specs.Spec(ObservationGlobalState, "ObservationSpec", **obs_data)

        return specs.Spec(Observation, "ObservationSpec", **obs_data)

    @cached_property
    def action_dim(self) -> int:
        """Number of discrete actions per agent.

        Read from the first entry of the wrapped env's ``action_spec.num_values``.
        This presumes every agent has the same number of actions — confirm for
        each wrapped env.
        """
        return int(self._env.action_spec.num_values[0])


class PartyMARLWrapper(JumanjiMarlWrapper):
    """Mava-style wrapper around a :class:`JaxParty` environment.

    Timesteps from JaxParty are passed through untouched. Any optional global
    state is handled by the parent wrapper.

    Args:
        env: The JaxParty environment to wrap.
        add_global_state: Whether to attach a global state to observations
            (off by default).
    """

    def __init__(self, env: JaxParty, add_global_state: bool = False):
        super().__init__(env, add_global_state)
        # Type-only annotation: narrows the wrapped env for readers and type checkers.
        self._env: JaxParty

    def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
        """Return the JaxParty timestep as-is; no conversion is required."""
        return timestep

In [12]:
# Wrap a JaxParty env for Mava, then step once per ordering of the actions (0, 1, 2).
# Each step prints the proposed vs. masked actions, cumulative rewards, and ranking.
env = PartyMARLWrapper(JaxParty(generator=PartyGenerator, time_limit=4000))

key = jax.random.PRNGKey(0)
state, timestep = env.reset(key)

# Every permutation of [0, 1, 2], with one action per agent — 6 rows of 3 actions.
action_orderings = jnp.asarray(list(itertools.permutations(range(3))))

for proposed_actions in action_orderings:
    print("active_agents", state.active_agents)
    print("actions", proposed_actions)
    # `_get_valid_actions` is not defined on the wrapper classes above. Presumably
    # jumanji's Wrapper forwards it to the wrapped JaxParty. It rewrites actions
    # disallowed by `state.action_mask`; see the recorded output, e.g. [0 2 1] -> [0 2 0].
    valid_actions = env._get_valid_actions(proposed_actions, state.action_mask)
    print("corrected actions", valid_actions)
    state, timestep = env.step(state, valid_actions)
    print(state.cumulative_rewards, state.ranking)
    print("-" * 30)

active_agents [0 1 1]
actions [0 1 2]
corrected actions [0 1 2]
[0. 0. 5.] [2 0 1]
------------------------------
active_agents [1 1 0]
actions [0 2 1]
corrected actions [0 2 0]
[-1.  5.  5.] [1 2 0]
------------------------------
active_agents [1 0 1]
actions [1 0 2]
corrected actions [1 0 2]
[-1.  5. 10.] [2 1 0]
------------------------------
active_agents [0 1 1]
actions [1 2 0]
corrected actions [0 2 0]
[-1. 10.  9.] [1 2 0]
------------------------------
active_agents [1 0 1]
actions [2 0 1]
corrected actions [2 0 1]
[ 4. 10.  9.] [1 2 0]
------------------------------
active_agents [1 1 0]
actions [2 1 0]
corrected actions [2 1 0]
[ 9. 10.  9.] [1 0 2]
------------------------------


In [10]:
# Generate an example value conforming to the wrapper's observation spec to inspect its structure.
observation_spec = env.observation_spec
observation_spec.generate_value()

Observation(agents_view=Array([False, False, False], dtype=bool), action_mask=Array([False, False, False], dtype=bool), step_count=Array([0, 0, 0], dtype=int32))