<a target="_blank" href="https://colab.research.google.com/github/CLAIR-LAB-TECHNION/MAC/blob/master/demo.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In [1]:
# Install the MAC package straight from its GitHub repository (no version pinned here).
!pip install "git+https://github.com/CLAIR-LAB-TECHNION/MAC/"
# Install the multi_taxi environment from GitHub, pinned to the 0.4.0 ref.
!pip install "git+https://github.com/CLAIR-LAB-TECHNION/multi_taxi@0.4.0"

In [2]:
from mac.coordination.coordinator import CentralizedCoordinator
from mac.env_wrapper import EnvWrapper
from mac.agent import Agent

In [3]:
class TaxiCentralizedCoordinator(CentralizedCoordinator):
    """Centralized coordinator used by the multi-taxi demo.

    On every step a single central agent is asked for a joint action, and the
    wrapped environment is advanced with that action. Logging hooks are
    intentionally left empty.
    """

    def get_ids(self):
        """Return the ``agents`` attribute of the wrapped environment."""
        wrapped_env = self.env_wrapper.env
        return wrapped_env.agents

    def init_log_data(self):
        """No-op: this demo does not keep any log data."""
        return None

    def log_step(self, step_data):
        """No-op: this demo does not log individual steps."""
        return None

    def get_initial_data(self):
        """Reset the environment and assemble the first step-data tuple.

        Returns:
            A ``(obs, rewards, terms, truncs, infos)`` tuple. ``obs`` and
            ``infos`` come from the wrapper's ``reset``; the other three are
            dicts keyed like ``obs`` holding ``0`` for every reward and
            ``False`` for every termination / truncation flag.
        """
        observations, infos = self.env_wrapper.reset()

        # Nothing has happened yet, so use neutral per-agent placeholders.
        rewards = dict.fromkeys(observations, 0)
        terminations = dict.fromkeys(observations, False)
        truncations = dict.fromkeys(observations, False)

        return observations, rewards, terminations, truncations, infos

    def run_step(self, step_data):
        """Advance the environment by one step.

        Args:
            step_data: The step data handed to the central agent so it can
                choose its action.

        Returns:
            A ``(next_step_data, joint_action)`` pair: whatever the wrapper's
            ``step`` returned, and the joint action that produced it.
        """
        chosen_action = self.central_agent.get_action(step_data)
        next_step_data = self.env_wrapper.step(chosen_action)
        return next_step_data, chosen_action

In [4]:
import time
from IPython.display import clear_output

# Pause, in seconds, that TaxiWrapper passes to time.sleep before redrawing the environment.
SLEEP_TIME = 0.2

class TaxiWrapper(EnvWrapper):
    """Environment wrapper that redraws the environment after every reset and step.

    Each call to ``reset`` or ``step`` is followed by a short pause, a clear
    of the current cell output, and a fresh ``render`` of the environment.
    """

    def _pause_and_redraw(self):
        """Sleep for ``SLEEP_TIME``, clear the cell output, and render the env."""
        time.sleep(SLEEP_TIME)
        # wait=True defers the clear until new output arrives.
        clear_output(wait=True)
        self.env.render()

    def get_agent_step_data(self, step_data, agent_id):
        """Pick ``agent_id``'s entry out of every element of ``step_data``.

        Args:
            step_data: Iterable of containers indexable by ``agent_id``.
            agent_id: Key used to index each element.

        Returns:
            A list with one entry per element of ``step_data``.
        """
        per_agent = []
        for component in step_data:
            per_agent.append(component[agent_id])
        return per_agent

    def is_done(self, step_data):
        """Return the environment's ``env_done()`` result; ``step_data`` is unused."""
        return self.env.env_done()

    def reset(self):
        """Reset the environment with ``return_info=True``, redraw, and return the result."""
        initial = self.env.reset(return_info=True)
        self._pause_and_redraw()
        return initial

    def step(self, action):
        """Apply ``action`` to the environment, redraw, and return the step result."""
        result = self.env.step(action)
        self._pause_and_redraw()
        return result

In [5]:
class TaxiRandomCentralAgent(Agent):
    """Central agent that samples an action for every agent from its action space."""

    def __init__(self, action_spaces):
        """Store the per-agent action spaces.

        Args:
            action_spaces: Container that iterates over agent ids and maps
                each id to an object exposing ``sample()``.
        """
        # NOTE(review): Agent.__init__ is not invoked here — confirm the base
        # class does not require initialization.
        self.action_spaces = action_spaces

    def get_observation(self, step_data):
        """Return the first element of the five-element ``step_data`` tuple."""
        observations, _rewards, _terms, _truncs, _infos = step_data
        return observations

    def get_action(self, step_data):
        """Build a joint action by sampling each agent's action space.

        Args:
            step_data: Ignored; the sampled actions do not depend on it.

        Returns:
            A dict mapping every agent id in ``self.action_spaces`` to the
            result of that space's ``sample()``.
        """
        joint_action = {}
        for agent_id in self.action_spaces:
            joint_action[agent_id] = self.action_spaces[agent_id].sample()
        return joint_action

In [6]:
from multi_taxi import multi_taxi_v0

# Create the parallel multi-taxi environment with five taxis and five passengers.
env = multi_taxi_v0.parallel_env(
    num_taxis=5,
    num_passengers=5,
    render_mode='human',
)

In [7]:
# Wrap the environment so that every reset/step is followed by a redraw.
env_wrapper = TaxiWrapper(env)

# Give the random central agent one action space per possible agent.
central_agent = TaxiRandomCentralAgent(
    {agent_id: env.action_space(agent_id) for agent_id in env.possible_agents}
)

# NOTE(review): the second positional argument is passed as None; its meaning
# is defined by CentralizedCoordinator, which is not visible in this notebook.
coordinator = TaxiCentralizedCoordinator(env_wrapper, None, central_agent)

In [8]:
# Run the coordinator with an argument of 100 — presumably the step limit (the
# rendered output below ends at "Step: 100"); confirm against CentralizedCoordinator.run.
coordinator.run(100)

+-----------------------+
| : |F: | : | : | : |F: |
| : | : : : | : | : |[41m [0m: |
|[37mP[0m: :[33mP[0m: : :[36mP[0m: : : : : : |
| : : : :[31mP[0m: | : :[37mD[0m:[43m[31mD[0m[0m: : |
| : : : : : | : : : :[42m [0m: |
|[32mD[0m:[33mD[0m:[36mD[0m: : : : : : : : : |
| | :G| | | :[46mG[0m|[47m [0m| | : |[32mP[0m|
+-----------------------+
Taxi0-YELLOW: Fuel: inf, Location: (3, 9), Engine: ON, Collided: False, Step: 100, ALIVE
Taxi1-RED: Fuel: inf, Location: (1, 10), Engine: ON, Collided: False, Step: 100, ALIVE
Taxi2-WHITE: Fuel: inf, Location: (6, 7), Engine: ON, Collided: False, Step: 100, ALIVE
Taxi3-GREEN: Fuel: inf, Location: (4, 10), Engine: ON, Collided: False, Step: 100, ALIVE
Taxi4-CYAN: Fuel: inf, Location: (6, 6), Engine: ON, Collided: False, Step: 100, ALIVE
Passenger0-YELLOW: Location: (2, 2), Destination: (5, 1)
Passenger1-RED: Location: (3, 4), Destination: (3, 9)
Passenger2-WHITE: Location: (2, 0), Destination: (3, 8)
Passenger3-GREEN: Locati