In [10]:
import gym
import numpy as np
import gymnasium.spaces as spaces
from pettingzoo.mpe import simple_speaker_listener_v4
import warnings
warnings.filterwarnings("ignore")

In [11]:
class mpe_wrapper_for_pettingzoo():
    """Expose a PettingZoo parallel environment through a list-based API.

    The wrapped environment exchanges per-agent dictionaries; this wrapper
    converts them to plain lists ordered like ``agents_name`` so that
    index ``i`` always refers to the same agent.

    Attributes set by ``__init__``:
        env: the wrapped environment.
        continuous_actions: whether ``step`` rescales continuous actions
            (True) or takes the argmax of each action vector (False).
        observation_space / action_space: per-agent space lists.
        n: number of agents.
        agents_name: agent identifiers, in the order used by every list.
        obs_shape_n / act_shape_n: per-agent sizes as given by ``get_shape``.
    """

    def __init__(self, env, continuous_actions=False):
        # reset() and step() delegate to the wrapped environment, so it has
        # to be kept on the instance.
        self.env = env
        self.continuous_actions = continuous_actions
        self.observation_space = list(env.observation_spaces.values())
        self.action_space = list(env.action_spaces.values())
        assert len(self.observation_space) == len(self.action_space)
        self.n = len(self.observation_space)
        self.agents_name = list(env.observation_spaces.keys())
        self.obs_shape_n = [
            self.get_shape(space) for space in self.observation_space
        ]
        self.act_shape_n = [
            self.get_shape(space) for space in self.action_space
        ]

    def get_shape(self, input_space):
        """Return the size of a ``Box`` or ``Discrete`` space.

        Returns:
            The single dimension for a 1-D ``Box``, the full shape tuple for
            any other ``Box``, or the number of choices for a ``Discrete``.

        Raises:
            NotImplementedError: for any other kind of space.
        """
        if isinstance(input_space, spaces.Box):
            if len(input_space.shape) == 1:
                return input_space.shape[0]
            return input_space.shape
        if isinstance(input_space, spaces.Discrete):
            return input_space.n
        raise NotImplementedError(
            '[Error] shape is {}, not Box or Discrete'.format(input_space.shape)
        )

    def reset(self):
        """Reset the wrapped environment and return the observations as a list."""
        result = self.env.reset()
        # Some parallel-env versions return ``(observations, infos)`` and
        # others return only the observation dict; accept both.
        obs = result[0] if isinstance(result, tuple) else result
        return list(obs.values())

    def step(self, actions):
        """Advance the environment by one step.

        Args:
            actions: one action vector per agent, ordered like ``agents_name``.
                With ``continuous_actions`` each vector must lie in
                [-1.0, 1.0] and is rescaled to the agent's action bounds;
                otherwise the index of the largest entry is sent.

        Returns:
            Four lists ``(observations, rewards, dones, infos)``.
        """
        actions_dict = dict()
        for i, act in enumerate(actions):
            agent = self.agents_name[i]
            if self.continuous_actions:
                assert np.all(((act<=1.0 + 1e-3), (act>=-1.0 - 1e-3))), 'the action should be in range [-1.0, 1.0], but got {}'.format(act)
                high = self.action_space[i].high
                low = self.action_space[i].low
                # Linear map from [-1, 1] onto [low, high]; the clip absorbs
                # the small tolerance allowed by the assertion above.
                mapped_action = low + (act + 1.0) * ((high - low) / 2.0)
                actions_dict[agent] = np.clip(mapped_action, low, high)
            else:
                actions_dict[agent] = np.argmax(act)

        result = self.env.step(actions_dict)
        if len(result) == 5:
            # Five-value API: terminations and truncations arrive separately.
            # Either one ends the agent's episode, so merge them into "done".
            obs, reward, terminated, truncated, info = result
            done = {
                agent: bool(flag) or bool(truncated.get(agent, False))
                for agent, flag in terminated.items()
            }
        else:
            obs, reward, done, info = result
        return list(obs.values()), list(reward.values()), list(done.values()), list(info.values())
    

# Build the speaker-listener environment in parallel mode and wrap it so that
# observations, actions and rewards are exchanged as per-agent lists.
continuous_actions = False
env = simple_speaker_listener_v4.parallel_env(max_cycles=25, continuous_actions=continuous_actions)
# Hand the same flag to the wrapper: it decides whether step() rescales
# continuous actions or takes an argmax, so the two must always agree.
env = mpe_wrapper_for_pettingzoo(env, continuous_actions=continuous_actions)
# Sum of every agent's observation size and action size (presumably the input
# width of a centralized critic -- confirm against the model definition).
critic_in_dim = sum(env.obs_shape_n) + sum(env.act_shape_n)
