In [20]:
# JAX ecosystem imports (jax.numpy, jax.random, Haiku, Optax) plus NumPy

import jax.numpy as jnp
import numpy as np 
from jax import random 
import haiku as hk 
import optax 

In [61]:
# Environment: build the raw environment that is wrapped further down.
import gym

# NOTE(review): the 'ma_gym:' prefix presumably makes gym import the ma_gym
# package to register this id — confirm ma_gym is installed in the kernel.
env = gym.make('ma_gym:Checkers-v0')

In [59]:
class CentralControllerWrapper:
    """Expose a multi-agent environment through a single-controller interface.

    The wrapper enumerates every combination of per-agent discrete actions and
    assigns each combination an integer index, concatenates the per-agent
    observations into one joint observation, sums the per-agent rewards into a
    team reward, and reports the episode as done only when every agent is done.

    Attributes:
        env: the wrapped multi-agent environment.
        num_agents: value of ``ma_env.n_agents``.
        action_mapping: dict mapping joint-action index -> list with one
            action per agent.
        action_space: number of joint actions (an ``int``, not a gym space).
        observation_space: total length of the concatenated observation
            (an ``int``, not a gym space).
    """

    def __init__(self, ma_env):
        """Wrap ``ma_env``.

        Args:
            ma_env: environment exposing ``n_agents``, an indexable
                ``action_space`` whose entries have an ``n`` attribute,
                ``reset()`` returning one observation per agent, and
                ``step(actions)`` returning per-agent lists.

        Note:
            ``ma_env.reset()`` is called once here to measure the size of the
            joint observation.
        """
        self.env = ma_env
        self.num_agents = ma_env.n_agents
        self.action_mapping = self.enumerate_agent_actions()
        self.action_space = len(self.action_mapping)
        # Sum of the individual observation lengths, as a plain Python int.
        self.observation_space = int(sum(len(agent_obs) for agent_obs in ma_env.reset()))

    def reset(self):
        """Reset the wrapped environment and return the joint observation."""
        obs_n = self.env.reset()
        return self.create_joint_obs(obs_n)

    def step(self, joint_action):
        """Apply one joint action.

        Args:
            joint_action: integer index into ``action_mapping``. NumPy and JAX
                integer scalars are accepted; the value is coerced with
                ``int()`` before the lookup.

        Returns:
            Tuple ``(joint_obs, team_reward, team_done, info)`` where
            ``team_reward`` is the float sum of the per-agent rewards and
            ``team_done`` is True only if every agent is done.

        Raises:
            KeyError: if ``joint_action`` is not a valid joint-action index.
        """
        # action_mapping is a dict, so it must be indexed, not called. int()
        # makes array scalars (which may be unhashable) usable as keys.
        action = self.action_mapping[int(joint_action)]
        obs_n, reward_n, done_n, info = self.env.step(action)

        joint_obs = self.create_joint_obs(obs_n)
        # reward_n is a host-side Python sequence; jax.numpy reductions reject
        # list inputs, so reduce with NumPy and return a plain float.
        team_reward = float(np.sum(reward_n))
        team_done = all(done_n)

        return joint_obs, team_reward, team_done, info

    def enumerate_agent_actions(self):
        """Build the joint-action index -> per-agent action list mapping.

        The ordering is whatever the meshgrid/transpose/reshape expression
        below yields and is intentionally left unchanged so that indices keep
        their meaning. With two agents, agent 0's action varies slowest.

        Returns:
            dict mapping ``int`` index to a list of ``num_agents`` Python ints.
        """
        agent_actions = [np.arange(space.n) for space in self.env.action_space]
        enumerated_actions = (
            np.array(np.meshgrid(*agent_actions)).T.reshape(-1, self.num_agents)
        )
        # tolist() yields plain Python ints rather than NumPy scalars.
        return {index: action.tolist() for index, action in enumerate(enumerated_actions)}

    def create_joint_obs(self, env_obs):
        """Concatenate the per-agent observations along the last axis.

        Each observation is converted on its own, so agents whose observations
        differ in length are supported.

        Args:
            env_obs: sequence with one array-like observation per agent.

        Returns:
            ``np.ndarray`` holding the joint observation.
        """
        per_agent = [np.asarray(agent_obs) for agent_obs in env_obs]
        return np.concatenate(per_agent, axis=-1)
    

In [62]:
# Wrap only once: re-running this cell would otherwise pass the wrapper to its
# own constructor, which fails because the wrapper exposes `num_agents`, not
# the `n_agents` attribute the constructor reads.
if not isinstance(env, CentralControllerWrapper):
    env = CentralControllerWrapper(env)