## Initialization

In [None]:
#@title
# Clone the VectorizedMultiAgentSimulator (VMAS) repository into the current
# working directory; the next cell changes into the resulting checkout.
! git clone https://github.com/proroklab/VectorizedMultiAgentSimulator.git

In [None]:
#@title
# Work from inside the cloned repository. The absolute path assumes the clone
# landed under /content (presumably a Colab runtime -- adjust when running elsewhere).
%cd /content/VectorizedMultiAgentSimulator

# Python requirements listed by the repository.
!pip install -r requirements.txt
# System packages: x11-utils, xvfb and imagemagick
# (presumably needed for off-screen rendering -- confirm against the VMAS docs).
!apt-get update
!apt-get install -y x11-utils 
!apt-get install -y xvfb
!apt-get install -y imagemagick
# Install the checked-out package itself in editable mode.
!pip install -e .

In [None]:
#@title
# Start a non-visible 1400x900 display through pyvirtualdisplay; presumably this
# lets the simulator render on a machine without a screen -- TODO confirm.
!pip install pyvirtualdisplay
import pyvirtualdisplay
display = pyvirtualdisplay.Display(visible=False, size=(1400, 900))
display.start()

## Run


In [None]:
import os

import torch

# Publish GPU availability through the RLLIB_NUM_GPUS environment variable as
# the string "1" (CUDA usable from torch) or "0" (CPU only); train() reads it.
has_gpu = torch.cuda.is_available()
os.environ["RLLIB_NUM_GPUS"] = str(int(has_gpu))

In [6]:
#  Copyright (c) 2022.
#  ProrokLab (https://www.proroklab.org/)
#  All rights reserved.

import os
from typing import Dict, Optional

import numpy as np
import ray
from ray import tune
from ray.rllib import BaseEnv, Policy, RolloutWorker
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.evaluation import Episode, MultiAgentEpisode
from ray.rllib.utils.typing import PolicyID
from ray.tune import register_env

import wandb
from vmas import make_env, Wrapper

# Name of the VMAS scenario to train on. It is also the id under which the
# environment is registered with RLlib and the value of the "env" config entry.
scenario_name = "balance"

# Scenario specific variables.
# When modifying this also modify env_config and env_creator
# (n_agents is forwarded to make_env as a scenario keyword argument).
n_agents = 4

# Common variables
continuous_actions = True  # forwarded to make_env
max_steps = 200  # forwarded to make_env as max_steps
# Used both as make_env's num_envs and as RLlib's num_envs_per_worker.
num_vectorized_envs = 96
# RLlib "num_workers"; also enters the per-worker GPU share computed in train().
num_workers = 1
# Device handed to make_env. Workers only get a GPU share when this is "cuda".
vmas_device = "cpu"  # or cuda


def env_creator(config: Dict):
    """Create the RLlib-wrapped VMAS environment described by ``config``.

    Args:
        config: Environment configuration. The entries ``scenario_name``,
            ``num_envs``, ``device``, ``continuous_actions``, ``max_steps``
            and the scenario specific ``n_agents`` are read and forwarded
            to ``make_env`` as keyword arguments of the same name.

    Returns:
        The object returned by ``make_env`` with ``wrapper=Wrapper.RLLIB``.

    Raises:
        KeyError: If ``config`` lacks one of the entries listed above.
    """
    forwarded_keys = (
        "scenario_name",
        "num_envs",
        "device",
        "continuous_actions",
        "max_steps",
        # Scenario specific variables
        "n_agents",
    )
    make_env_kwargs = {key: config[key] for key in forwarded_keys}
    return make_env(wrapper=Wrapper.RLLIB, **make_env_kwargs)


# Start Ray only when no runtime is initialized yet, so executing this code
# again in the same process does not call ray.init() a second time.
if not ray.is_initialized():
    ray.init()
    print("Ray init!")
# Register the creator under the scenario name; the "env" entry of the
# training config refers to the environment by this same name.
register_env(scenario_name, lambda config: env_creator(config))


class EvaluationCallbacks(DefaultCallbacks):
    """Expose the values found in the step ``info`` as episode custom metrics.

    The info returned by ``episode.last_info_for()`` is treated as a two-level
    mapping. Every ``info[outer][inner]`` value is tracked under the flattened
    name ``"<outer>/<inner>"``: collected per step in ``episode.user_data`` and
    summed into ``episode.custom_metrics`` when the episode ends.
    """

    def on_episode_step(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        episode: MultiAgentEpisode,
        **kwargs,
    ):
        """Append the current step's info values to ``episode.user_data``.

        Args:
            worker: Unused.
            base_env: Unused.
            episode: Episode whose latest info is read and whose ``user_data``
                receives one list per ``"<outer>/<inner>"`` name.
            **kwargs: Ignored.
        """
        info = episode.last_info_for()
        # last_info_for() yields None when no info has been stored for the
        # agent; there is nothing to record then (same for an empty info).
        if not info:
            return
        for a_key, sub_info in info.items():
            for b_key in sub_info:
                # setdefault creates the list the first time a name is seen.
                episode.user_data.setdefault(f"{a_key}/{b_key}", []).append(
                    sub_info[b_key]
                )

    def on_episode_end(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[str, Policy],
        episode: MultiAgentEpisode,
        **kwargs,
    ):
        """Store the per-episode sum of every tracked value as a custom metric.

        Only the names present in the episode's last info are reported.

        Args:
            worker: Unused.
            base_env: Unused.
            policies: Unused.
            episode: Episode whose ``user_data`` lists are summed into
                ``custom_metrics`` under the same ``"<outer>/<inner>"`` names.
            **kwargs: Ignored.

        Raises:
            KeyError: If the last info contains a name that was never recorded
                by ``on_episode_step``.
        """
        info = episode.last_info_for()
        # Same guard as in on_episode_step: no info means no metrics to report.
        if not info:
            return
        for a_key, sub_info in info.items():
            for b_key in sub_info:
                metric_name = f"{a_key}/{b_key}"
                metric = np.array(episode.user_data[metric_name])
                # .item() turns the NumPy scalar into a plain Python number.
                episode.custom_metrics[metric_name] = np.sum(metric).item()


class RenderingCallbacks(DefaultCallbacks):
    """Render one frame per step and attach the clip to the finished episode.

    Frames are buffered on the callback instance (not on the episode) and the
    buffer is reset whenever an episode ends.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Frames rendered since the last episode end.
        self.frames = []

    def on_episode_step(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Optional[Dict[PolicyID, Policy]] = None,
        episode: Episode,
        **kwargs,
    ) -> None:
        """Render the vector env in ``rgb_array`` mode and buffer the frame."""
        frame = base_env.vector_env.try_render_at(mode="rgb_array")
        self.frames.append(frame)

    def on_episode_end(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        episode: Episode,
        **kwargs,
    ) -> None:
        """Store the buffered frames as ``episode.media["rendering"]``.

        The frame stack is reordered with axes ``(0, 3, 1, 2)`` (presumably
        from frame/height/width/channel to channel-first -- confirm against
        the renderer) and wrapped in a ``wandb.Video`` whose frame rate is the
        inverse of the world's ``dt``.
        """
        clip = np.transpose(self.frames, (0, 3, 1, 2))
        frames_per_second = 1 / base_env.vector_env.env.world.dt
        episode.media["rendering"] = wandb.Video(
            clip, fps=frames_per_second, format="mp4"
        )
        # Start the next episode with an empty buffer.
        self.frames = []


def train():
    """Train PPO on the configured VMAS scenario through ``tune.run``.

    The GPU budget is read from the ``RLLIB_NUM_GPUS`` environment variable
    (treated as 0 when unset). The run stops after 400 training iterations,
    checkpoints every iteration and at the end, with
    ``keep_checkpoints_num=2`` scored on ``episode_reward_mean``. The value
    returned by ``tune.run`` is discarded.
    """
    gpu_budget = int(os.environ.get("RLLIB_NUM_GPUS", "0"))

    # Driver GPU: a token share as soon as any GPU is available.
    num_gpus = 0.001 if gpu_budget > 0 else 0

    # Rollout workers only get a GPU share when the simulator runs on CUDA.
    if vmas_device == "cuda":
        num_gpus_per_worker = (gpu_budget - num_gpus) / (num_workers + 1)
    else:
        num_gpus_per_worker = 0

    # Settings handed to env_creator when an environment is built.
    env_config = {
        "device": vmas_device,
        "num_envs": num_vectorized_envs,
        "scenario_name": scenario_name,
        "continuous_actions": continuous_actions,
        "max_steps": max_steps,
        # Scenario specific variables
        "n_agents": n_agents,
    }

    # Overrides applied for evaluation: a single, non-vectorized environment.
    evaluation_config = {
        "num_envs_per_worker": 1,
        "env_config": {
            "num_envs": 1,
        },
        # To also record videos during evaluation:
        # "callbacks": MultiCallbacks([RenderingCallbacks, EvaluationCallbacks]),
    }

    ppo_config = {
        "seed": 0,
        "framework": "torch",
        "env": scenario_name,
        "kl_coeff": 0.01,
        "kl_target": 0.01,
        "lambda": 0.9,
        "clip_param": 0.2,
        "vf_loss_coeff": 1,
        "vf_clip_param": float("inf"),
        "entropy_coeff": 0,
        "train_batch_size": 60000,
        "rollout_fragment_length": 125,
        "sgd_minibatch_size": 4096,
        "num_sgd_iter": 40,
        "num_gpus": num_gpus,
        "num_workers": num_workers,
        "num_gpus_per_worker": num_gpus_per_worker,
        "num_envs_per_worker": num_vectorized_envs,
        "lr": 5e-6,
        "gamma": 0.99,
        "use_gae": True,
        "use_critic": True,
        "batch_mode": "truncate_episodes",
        "env_config": env_config,
        "evaluation_interval": 5,
        "evaluation_duration": 1,
        "evaluation_num_workers": 0,
        "evaluation_parallel_to_training": False,
        "evaluation_config": evaluation_config,
        "callbacks": EvaluationCallbacks,
    }

    tune.run(
        PPOTrainer,
        stop={"training_iteration": 400},
        checkpoint_freq=1,
        keep_checkpoints_num=2,
        checkpoint_at_end=True,
        checkpoint_score_attr="episode_reward_mean",
        # To log to Weights & Biases, pass:
        # callbacks=[
        #     WandbLoggerCallback(
        #        project=f"{scenario_name}",
        #        api_key="",
        #    )
        # ],
        config=ppo_config,
    )


# Start training only when this code runs as the main module
# (i.e. not when it is imported from another module).
if __name__ == "__main__":
    train()

[2m[36m(pid=44668)[0m   'nearest': pil_image.NEAREST,
[2m[36m(pid=44668)[0m   'bilinear': pil_image.BILINEAR,
[2m[36m(pid=44668)[0m   'bicubic': pil_image.BICUBIC,
[2m[36m(pid=44668)[0m   if hasattr(pil_image, 'HAMMING'):
[2m[36m(pid=44668)[0m   if hasattr(pil_image, 'BOX'):
[2m[36m(pid=44668)[0m   if hasattr(pil_image, 'LANCZOS'):
[2m[36m(pid=44668)[0m   _nlv = LooseVersion(_np_version)
[2m[36m(pid=44668)[0m   other = LooseVersion(other)
[2m[36m(pid=44668)[0m   if LooseVersion(_np_version) >= LooseVersion("1.17.0"):
[2m[36m(pid=44668)[0m Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
[2m[36m(pid=44668)[0m   (np.object, string),
[2m[36m(pid=44668)[0m Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
[2m[36m(pid=44668)[0m   (np.bool, bool),
[2m[36m(pid=44668)[0m Deprecated in NumPy 1.20; for more deta

Trial name,agent_timesteps_total,counters,custom_metrics,date,done,episode_len_mean,episode_media,episode_reward_max,episode_reward_mean,episode_reward_min,episodes_this_iter,episodes_total,experiment_id,hostname,info,iterations_since_restore,node_ip,num_agent_steps_sampled,num_agent_steps_trained,num_env_steps_sampled,num_env_steps_sampled_this_iter,num_env_steps_trained,num_env_steps_trained_this_iter,num_faulty_episodes,num_healthy_workers,num_in_flight_async_reqs,num_remote_worker_restarts,num_steps_trained_this_iter,perf,pid,policy_reward_max,policy_reward_mean,policy_reward_min,sampler_perf,sampler_results,time_since_restore,time_this_iter_s,time_total_s,timers,timestamp,timesteps_since_restore,timesteps_total,training_iteration,trial_id,warmup_time
PPO_balance_591fb_00000,14160000,"{'num_env_steps_sampled': 14160000, 'num_env_steps_trained': 14160000, 'num_agent_steps_sampled': 14160000, 'num_agent_steps_trained': 14160000}","{'rewards/0_mean': 99.64384879943259, 'rewards/0_min': -6.2609405517578125, 'rewards/0_max': 141.73487854003906, 'rewards/1_mean': 99.64384879943259, 'rewards/1_min': -6.2609405517578125, 'rewards/1_max': 141.73487854003906, 'rewards/2_mean': 99.64384879943259, 'rewards/2_min': -6.2609405517578125, 'rewards/2_max': 141.73487854003906, 'rewards/3_mean': 99.64384879943259, 'rewards/3_min': -6.2609405517578125, 'rewards/3_max': 141.73487854003906, 'agent 0/pos_rew_mean': 99.64384879943259, 'agent 0/pos_rew_min': -6.2609405517578125, 'agent 0/pos_rew_max': 141.73487854003906, 'agent 0/ground_rew_mean': 0.0, 'agent 0/ground_rew_min': 0.0, 'agent 0/ground_rew_max': 0.0, 'agent 1/pos_rew_mean': 99.64384879943259, 'agent 1/pos_rew_min': -6.2609405517578125, 'agent 1/pos_rew_max': 141.73487854003906, 'agent 1/ground_rew_mean': 0.0, 'agent 1/ground_rew_min': 0.0, 'agent 1/ground_rew_max': 0.0, 'agent 2/pos_rew_mean': 99.64384879943259, 'agent 2/pos_rew_min': -6.2609405517578125, 'agent 2/pos_rew_max': 141.73487854003906, 'agent 2/ground_rew_mean': 0.0, 'agent 2/ground_rew_min': 0.0, 'agent 2/ground_rew_max': 0.0, 'agent 3/pos_rew_mean': 99.64384879943259, 'agent 3/pos_rew_min': -6.2609405517578125, 'agent 3/pos_rew_max': 141.73487854003906, 'agent 3/ground_rew_mean': 0.0, 'agent 3/ground_rew_min': 0.0, 'agent 3/ground_rew_max': 0.0}",2023-01-10_17-20-03,False,197.441,{},141.735,99.6438,-6.26094,311,71095,92d3e75facb64bf3a6a5db4ec0551b6e,wosersysydeMacBook-Pro.local,"{'learner': {'default_policy': {'learner_stats': {'allreduce_latency': 0.0, 'grad_gnorm': 20.392005970648356, 'cur_kl_coeff': 6.681911775230489e-54, 'cur_lr': 5e-06, 'total_loss': 2.9767663796033177, 'policy_loss': 0.0017569990937772672, 'vf_loss': 2.9750093873058048, 'vf_explained_var': 0.9874570196228368, 'kl': 0.0038214488556637402, 
'entropy': 1.807006427007062, 'entropy_coeff': 0.0}, 'model': {}, 'custom_metrics': {}, 'num_agent_steps_trained': 4096.0, 'num_grad_updates_lifetime': 131880.5, 'diff_num_grad_updates_vs_sampler_policy': 279.5}}, 'num_env_steps_sampled': 14160000, 'num_env_steps_trained': 14160000, 'num_agent_steps_sampled': 14160000, 'num_agent_steps_trained': 14160000}",236,127.0.0.1,14160000,14160000,14160000,60000,14160000,60000,0,1,0,0,60000,"{'cpu_util_percent': 24.53402777777778, 'ram_util_percent': 64.95416666666668}",44650,{},{},{},"{'mean_raw_obs_processing_ms': 44.55475229622829, 'mean_inference_ms': 3.412697662916898, 'mean_action_processing_ms': 19.299814500179682, 'mean_env_wait_ms': 29.72847054005817, 'mean_env_render_ms': 0.0}","{'episode_reward_max': 141.73487854003906, 'episode_reward_min': -6.2609405517578125, 'episode_reward_mean': 99.64384879943259, 'episode_len_mean': 197.44051446945338, 'episode_media': {}, 'episodes_this_iter': 311, 'policy_reward_min': {}, 'policy_reward_max': {}, 'policy_reward_mean': {}, 'custom_metrics': {'rewards/0_mean': 99.64384879943259, 'rewards/0_min': -6.2609405517578125, 'rewards/0_max': 141.73487854003906, 'rewards/1_mean': 99.64384879943259, 'rewards/1_min': -6.2609405517578125, 'rewards/1_max': 141.73487854003906, 'rewards/2_mean': 99.64384879943259, 'rewards/2_min': -6.2609405517578125, 'rewards/2_max': 141.73487854003906, 'rewards/3_mean': 99.64384879943259, 'rewards/3_min': -6.2609405517578125, 'rewards/3_max': 141.73487854003906, 'agent 0/pos_rew_mean': 99.64384879943259, 'agent 0/pos_rew_min': -6.2609405517578125, 'agent 0/pos_rew_max': 141.73487854003906, 'agent 0/ground_rew_mean': 0.0, 'agent 0/ground_rew_min': 0.0, 'agent 0/ground_rew_max': 0.0, 'agent 1/pos_rew_mean': 99.64384879943259, 'agent 1/pos_rew_min': -6.2609405517578125, 'agent 1/pos_rew_max': 141.73487854003906, 'agent 1/ground_rew_mean': 0.0, 'agent 1/ground_rew_min': 0.0, 'agent 1/ground_rew_max': 0.0, 'agent 2/pos_rew_mean': 99.64384879943259, 'agent 
2/pos_rew_min': -6.2609405517578125, 'agent 2/pos_rew_max': 141.73487854003906, 'agent 2/ground_rew_mean': 0.0, 'agent 2/ground_rew_min': 0.0, 'agent 2/ground_rew_max': 0.0, 'agent 3/pos_rew_mean': 99.64384879943259, 'agent 3/pos_rew_min': -6.2609405517578125, 'agent 3/pos_rew_max': 141.73487854003906, 'agent 3/ground_rew_mean': 0.0, 'agent 3/ground_rew_min': 0.0, 'agent 3/ground_rew_max': 0.0}, 'hist_stats': {'episode_reward': [117.07628440856934, 93.00224018096924, 116.0483570098877, 113.12100219726562, 109.36699295043945, 75.40217590332031, 123.17409896850586, 108.72024536132812, 115.05589008331299, 109.92096710205078, 79.85285186767578, 118.79481506347656, 130.62065410614014, 64.2946548461914, 58.69377899169922, 14.982879638671875, 93.89892578125, 98.98825073242188, 112.01651763916016, 119.82271575927734, 128.41402053833008, 79.90019226074219, 117.13457489013672, 117.35860443115234, 100.3802843093872, 95.09027290344238, 126.27787017822266, 126.06783294677734, 125.4280891418457, 135.2369384765625, 116.9592514038086, 106.66745281219482, 114.20149230957031, 105.5972900390625, 92.87841415405273, 63.0732536315918, 111.09952163696289, 115.36222839355469, 97.53299522399902, 112.32647895812988, 113.09326648712158, 5.3201904296875, 93.76148986816406, 116.50341606140137, 117.34215354919434, 125.92927169799805, 80.92390632629395, 57.2237548828125, 69.23227119445801, 87.05670166015625, 109.56949615478516, 105.14801788330078, 127.53852462768555, 117.05529022216797, 126.34617233276367, 70.83329010009766, 105.60393524169922, 99.51790237426758, 117.95056533813477, 110.15217590332031, 108.36830234527588, 74.40993881225586, 106.08880996704102, 111.71199035644531, 129.34356689453125, 117.26565551757812, 115.62464904785156, 60.24730682373047, 113.12211418151855, 85.47795104980469, 104.64946365356445, 84.17298603057861, 87.41825103759766, 101.14694881439209, 83.84781646728516, 117.29898071289062, 114.02300262451172, 60.63333511352539, 56.87834167480469, 89.46420669555664, 
82.97163200378418, 102.17825698852539, 87.85149765014648, 93.42079448699951, 115.1454963684082, 132.74742698669434, 109.44936752319336, 117.98036193847656, 119.08479690551758, 117.1807632446289, 115.49753189086914, 76.57540893554688, 108.33441162109375, 117.22947692871094, 132.00692558288574, 98.90068054199219, 118.39411926269531, 14.444847106933594, 108.29718017578125, 111.18363189697266, 116.71433639526367, 115.18730926513672, 101.27680206298828, 109.97918319702148, 57.97298812866211, 88.78361511230469, 108.71320343017578, 9.400802612304688, 91.9937686920166, 132.56836700439453, 120.79468536376953, 20.652801513671875, 141.26381492614746, 87.47630310058594, 120.73787689208984, 99.53800773620605, 117.64168548583984, 115.69616317749023, 121.49663543701172, 116.57160186767578, 82.01677703857422, 84.24993896484375, 98.28810214996338, 111.34930419921875, 63.37464904785156, 112.82472229003906, 87.77971649169922, 111.5920295715332, 82.38740539550781, 112.0631217956543, 102.03842163085938, 110.16379165649414, 120.22867584228516, 76.16581726074219, 85.29930877685547, 100.02140235900879, 122.97039699554443, 95.63889694213867, 97.79639625549316, 100.66753387451172, 113.52615356445312, 116.59344482421875, -6.2609405517578125, 118.02297973632812, 105.95014572143555, 118.93058776855469, 102.16992568969727, 42.452239990234375, 105.828857421875, 126.16203308105469, 104.24900817871094, 75.74565505981445, 138.75174713134766, 121.76358413696289, 113.9814224243164, 97.04019165039062, 56.02977752685547, 64.58610153198242, 1.580230712890625, 105.53238677978516, 81.27451705932617, 95.98115539550781, 24.524131774902344, 10.918853759765625, 131.31895446777344, 113.03656768798828, 111.95329666137695, 111.07004261016846, 117.0172233581543, 125.7429084777832, 87.4390058517456, 117.99787139892578, 101.94406795501709, 111.33692932128906, 114.73672866821289, 104.9920539855957, 55.10760498046875, 116.36240768432617, 99.28965759277344, 123.17544555664062, 118.1320686340332, 118.23887252807617, 
122.23132038116455, 115.03525066375732, 106.04404735565186, 114.39415550231934, 126.40606307983398, 1.1255950927734375, 87.20348358154297, 112.58648300170898, 82.70046615600586, 118.55970001220703, 119.95545196533203, 96.17192459106445, 112.03681182861328, 103.47395324707031, 111.6278076171875, 107.08552360534668, 113.71221160888672, 96.38127899169922, 70.23542022705078, 116.43207550048828, 94.46032333374023, 66.73311614990234, 120.61705780029297, 66.55838012695312, 89.37339210510254, 119.23163223266602, 117.31189727783203, 114.76184463500977, 113.97962951660156, 118.67873001098633, 116.92880249023438, 120.33320617675781, 119.79291343688965, -5.6340789794921875, 127.44141006469727, 92.38697814941406, 108.32819366455078, 73.40590286254883, 107.57826232910156, 103.09602355957031, 112.38995361328125, 64.8662338256836, 106.19150733947754, 117.78337478637695, 97.08245468139648, 124.4818115234375, 86.75025177001953, 125.18252182006836, 120.95628356933594, 128.63893699645996, 101.31093978881836, 109.28397369384766, 123.55352783203125, 102.82972717285156, 117.75715637207031, 106.3121109008789, 96.04522705078125, 116.45892333984375, 113.59090900421143, -4.5597686767578125, 110.57996368408203, 114.68016052246094, 89.32744598388672, 88.42433547973633, 137.8973159790039, 59.52796936035156, 41.957275390625, 113.18428421020508, 63.25364685058594, 102.19526863098145, 85.27239227294922, 117.37494659423828, 91.28302764892578, 113.89111328125, 107.20068550109863, 119.61580657958984, 107.76805877685547, 118.10227584838867, 54.0421142578125, 20.241790771484375, 78.71988105773926, 105.80561351776123, 99.95496559143066, 122.45042037963867, 29.008209228515625, 1.2701416015625, 118.01552963256836, 126.80803298950195, 115.02727317810059, 114.57110214233398, 115.73232078552246, 102.80484771728516, 129.0581932067871, 101.76990509033203, 141.73487854003906, 118.62461853027344, 60.67420959472656, 116.21189880371094, 75.29787826538086, 107.99716091156006, 56.375587463378906, 133.87681198120117, 
102.17572402954102, 129.68142318725586, 119.16714096069336, 118.52069473266602, 119.1226577758789, 17.066757202148438, 116.18502044677734, 117.27503967285156, 122.46009826660156, 123.42143249511719, 7.4161224365234375, 117.71757507324219, 93.61946105957031, 105.1984748840332, 93.37629890441895, 65.90878677368164, 112.88870239257812, 113.14134979248047, 115.05036544799805, 112.28030300140381, 115.25590133666992, 117.42527770996094, 95.26692962646484, 116.20146560668945, 135.28208923339844, 107.35464477539062, 85.88555145263672], 'episode_lengths': [200, 160, 200, 200, 200, 200, 200, 200, 196, 200, 200, 200, 199, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 170, 200, 200, 192, 200, 200, 200, 189, 200, 200, 200, 200, 200, 200, 200, 200, 176, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 188, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 140, 200, 177, 200, 200, 200, 200, 200, 200, 200, 200, 155, 166, 200, 188, 200, 200, 200, 200, 200, 200, 185, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 170, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 150, 200, 175, 200, 200, 200, 200, 200, 200, 200, 200, 200, 181, 167, 166, 159, 200, 200, 200, 200, 143, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 151, 200, 200, 200, 200, 196, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 181, 200, 165, 200, 200, 188, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 176, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 170, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 185, 200, 
200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200]}, 'sampler_perf': {'mean_raw_obs_processing_ms': 44.55475229622829, 'mean_inference_ms': 3.412697662916898, 'mean_action_processing_ms': 19.299814500179682, 'mean_env_wait_ms': 29.72847054005817, 'mean_env_render_ms': 0.0}, 'num_faulty_episodes': 0}",28554.7,101.724,28554.7,"{'training_iteration_time_ms': 99196.919, 'load_time_ms': 27.45, 'load_throughput': 2185772.317, 'learn_time_ms': 49205.141, 'learn_throughput': 1219.385, 'synch_weights_time_ms': 2.088}",1673342403,0,14160000,236,591fb_00000,9.7637


2023-01-10 17:20:28,786	ERROR tune.py:758 -- Trials did not complete: [PPO_balance_591fb_00000]
2023-01-10 17:20:28,787	INFO tune.py:763 -- Total run time: 28636.93 seconds (28636.69 seconds for the tuning loop).
