# Some preliminary checks

In [1]:
import torch
import tensorflow as tf
import os

# Turn off Ray's log de-duplication so each rollout worker's messages are
# printed individually (must be set before Ray is initialised below).
os.environ["RAY_DEDUP_LOGS"] = "0"

# Report the PyTorch build and whether a CUDA device is usable.
for label, value in (
    ("PyTorch Version:", torch.__version__),
    ("CUDA Available:", torch.cuda.is_available()),
    ("CUDA Version:", torch.version.cuda),
):
    print(label, value)

PyTorch Version: 2.1.0
CUDA Available: False
CUDA Version: None


In [2]:
import psutil

# Size the Ray cluster from the current machine instead of hard-coding it, so
# the notebook also runs on hosts with a different core count or with GPUs.
num_cpus = psutil.cpu_count() or 1  # cpu_count() can return None
num_gpus = torch.cuda.device_count()

# print number of CPUs / GPUs
print("Number of CPUs: ", num_cpus)
print("Number of GPUs: ", num_gpus)

Number of CPUs:  12


# Training

In [3]:
from ray.rllib.policy.policy import PolicySpec
from ray.tune.registry import get_trainable_cls

from custom_env import CustomEnvironment
from config import run_config

## The RLlib configuration
class Args:
    """Minimal stand-in for the command-line arguments used by RLlib examples."""

    def __init__(self):
        self.run = "PPO"  # registered name of the RLlib algorithm to train
        self.framework = "torch"  # "tf2" or "torch"


args = Args()

## Generate the configuration
# Driver-side environment instance: used below to read the policy spaces and
# each agent's type for the policy mapping.
env = CustomEnvironment(run_config["env"])

POLICY_GAMMA = 0.85  # discount-factor override applied to both policies


def _make_policy_spec():
    """Return a PolicySpec using agent 0's spaces and the shared gamma override."""
    # NOTE(review): both policies take agent 0's observation/action spaces; this
    # is only correct if prey and predators share the same spaces - confirm
    # against CustomEnvironment.
    return PolicySpec(
        policy_class=None,  # infer automatically from the Algorithm
        observation_space=env.observation_space[0],  # if None, inferred from env
        action_space=env.action_space[0],  # if None, inferred from env
        config={"gamma": POLICY_GAMMA},  # main config plus this override
    )


def policy_mapping_fn(agent_id, *_args, **_kwargs):
    """Map agents whose agent_type is 0 to "prey" and all others to "predator"."""
    return "prey" if env.agents[agent_id].agent_type == 0 else "predator"


config = (
    get_trainable_cls(args.run)
    .get_default_config()
    .environment(CustomEnvironment, env_config=run_config["env"])
    .framework(args.framework)
    .training(_enable_learner_api=True, num_sgd_iter=10, sgd_minibatch_size=256, train_batch_size=20000)
    .multi_agent(
        policies={
            "prey": _make_policy_spec(),
            "predator": _make_policy_spec(),
        },
        policy_mapping_fn=policy_mapping_fn,
        policies_to_train=["prey", "predator"],
    )
    .rl_module(_enable_rl_module_api=True)
    .rollouts(
        rollout_fragment_length="auto",
        batch_mode='truncate_episodes',
        # presumably leaves one CPU for the driver / local worker
        num_rollout_workers=num_cpus - 1,
        num_envs_per_worker=1,
        #create_env_on_local_worker=False,
    )
    # Resources have to be set by hand for each machine (no automatic
    # adjustment is done here); with everything commented out, RLlib's
    # defaults are used.
    .resources(
        #num_gpus = num_gpus,
        #num_gpus_per_worker=0,
        #num_cpus_per_worker=2,
        # learner workers
        #num_learner_workers=num_gpus,
        #num_gpus_per_learner_worker=1,
        #num_cpus_per_learner_worker=0,
    )
    .checkpointing(export_native_model_files=True)
)
# Clear the legacy exploration config; with the RLModule API enabled above,
# exploration is presumably handled by the module itself - TODO confirm.
config.exploration_config = {}



2023-10-08 23:05:01,234	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2023-10-08 23:05:01,289	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
  gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
  logger.warn(
  logger.warn(f"{pre} is not within the observation space.")


In [None]:
import ray 
from ray import air, tune
from ray.air.integrations.wandb import WandbLoggerCallback
from ray.rllib.utils.test_utils import check_learning_achieved

# `ignore_reinit_error` lets this cell be re-run while Ray is still up.
ray.init(num_cpus=num_cpus, num_gpus=num_gpus, ignore_reinit_error=True)

print("num CPUS rays sees :", ray.cluster_resources().get('CPU', 0))
print("num GPUS rays sees :", ray.cluster_resources().get('GPU', 0))

# Stopping criteria for the run; Tune stops the trial when any one is reached.
opti_config = {
    'stop_iters': 500,
    'stop_timesteps': 10000000,
    'stop_reward': 0.1,
    'as_test': False,  # if True, check afterwards that stop_reward was reached
}

## Run the experiment
tuner = tune.Tuner(
    args.run,
    param_space=config,
    run_config=air.RunConfig(
        stop={
            "training_iteration": opti_config["stop_iters"],
            "timesteps_total": opti_config["stop_timesteps"],
            "episode_reward_mean": opti_config["stop_reward"],
        },
        verbose=3,
        callbacks=[WandbLoggerCallback(
            project="marl-rllib",
            group="PPO",
            # Never commit credentials: read the key from the environment.
            # (The key that was previously hard-coded here should be revoked.)
            api_key=os.environ.get("WANDB_API_KEY"),
            log_config=True,
        )],
        checkpoint_config=air.CheckpointConfig(
            checkpoint_at_end=True,
            checkpoint_frequency=10,  # also checkpoint every 10 iterations
        ),
    ),
)
results = tuner.fit()

if opti_config["as_test"]:
    print("Checking if learning goals were achieved")
    check_learning_achieved(results, opti_config["stop_reward"])
ray.shutdown()


2023-10-08 23:05:03,374	INFO worker.py:1642 -- Started a local Ray instance.
2023-10-08 23:05:03,817	INFO tune.py:645 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


num CPUS rays sees : 12.0
num GPUS rays sees : 0


0,1
Current time:,2023-10-08 23:07:36
Running for:,00:02:32.73
Memory:,21.1/64.0 GiB

Trial name,status,loc,iter,total time (s),ts,reward,episode_reward_max,episode_reward_min,episode_len_mean
PPO_CustomEnvironment_59e38_00000,RUNNING,127.0.0.1:44998,2,122.455,40000,-15.7471,-13.1752,-18.0806,1000


[2m[36m(RolloutWorker pid=45001)[0m   gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
[2m[36m(RolloutWorker pid=45004)[0m   gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
[2m[36m(RolloutWorker pid=45005)[0m   gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
[2m[36m(RolloutWorker pid=45010)[0m   gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
[2m[36m(RolloutWorker pid=45006)[0m   gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
[2m[36m(RolloutWorker pid=45002)[0m   gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
[2m[36m(RolloutWorker pid=45008)[0m   gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
[2m[36m(RolloutWorker pid=45009)[0m   gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
[2m[36m(RolloutWorker pid=45007)[0m   gym.logger.warn(f"Box bound pre

Trial name,agent_timesteps_total,checkpoint_dir_name,connector_metrics,counters,custom_metrics,date,done,episode_len_mean,episode_media,episode_reward_max,episode_reward_mean,episode_reward_min,episodes_this_iter,episodes_total,hostname,info,iterations_since_restore,node_ip,num_agent_steps_sampled,num_agent_steps_trained,num_env_steps_sampled,num_env_steps_sampled_this_iter,num_env_steps_sampled_throughput_per_sec,num_env_steps_trained,num_env_steps_trained_this_iter,num_env_steps_trained_throughput_per_sec,num_faulty_episodes,num_healthy_workers,num_in_flight_async_reqs,num_remote_worker_restarts,num_steps_trained_this_iter,perf,pid,policy_reward_max,policy_reward_mean,policy_reward_min,sampler_perf,sampler_results,time_since_restore,time_this_iter_s,time_total_s,timers,timestamp,timesteps_total,training_iteration,trial_id
PPO_CustomEnvironment_59e38_00000,606329,,"{'ObsPreprocessorConnector_ms': 0.006701368274110736, 'StateBufferConnector_ms': 0.003612041473388672, 'ViewRequirementAgentConnector_ms': 0.4773656527201335}","{'num_env_steps_sampled': 40000, 'num_env_steps_trained': 0, 'num_agent_steps_sampled': 606329, 'num_agent_steps_trained': 0}",{},2023-10-08_23-07-17,False,1000,{},-13.1752,-15.7471,-18.0806,22,33,MacBook-Pro-de-Tanguy.local,"{'learner': {'__all__': {'num_agent_steps_trained': 512.0, 'num_env_steps_trained': 273597.0, 'total_loss': 0.27068477739164504}, 'prey': {'total_loss': 0.27068477739164504, 'policy_loss': -0.0013974461996663658, 'vf_loss': 0.04968034114312784, 'vf_loss_unclipped': 0.16782034388466477, 'vf_explained_var': -0.8716005432993887, 'entropy': 2.8677008319534583, 'mean_kl_loss': 0.0067088654718796175, 'default_optimizer_lr': 5.000000000000002e-05, 'curr_lr': 5e-05, 'curr_entropy_coeff': 0.0, 'curr_kl_coeff': 0.10000000149011612}, 'predator': {'total_loss': 0.22173099595255882, 'policy_loss': -0.004570574517268409, 'vf_loss': 0.22484674687812167, 'vf_loss_unclipped': 0.7706875287265978, 'vf_explained_var': -0.546073189635432, 'entropy': 2.6696010577063003, 'mean_kl_loss': 0.007274116323971793, 'default_optimizer_lr': 5.000000000000002e-05, 'curr_lr': 5e-05, 'curr_entropy_coeff': 0.0, 'curr_kl_coeff': 0.20000000298023224}}, 'num_env_steps_sampled': 40000, 'num_env_steps_trained': 0, 'num_agent_steps_sampled': 606329, 'num_agent_steps_trained': 0}",2,127.0.0.1,606329,0,40000,20000,321.196,0,0,0,0,11,0,0,0,"{'cpu_util_percent': 25.52934782608696, 'ram_util_percent': 32.62173913043478}",44998,"{'prey': -1.0251783437068376, 'predator': 68.89013974748939}","{'prey': -4.093891080340487, 'predator': 22.83063343309827}","{'prey': -11.080116384082547, 'predator': -1.1080296546613446}","{'mean_raw_obs_processing_ms': 2.1939569612040453, 'mean_inference_ms': 1.2477843403823796, 'mean_action_processing_ms': 0.5297219811578259, 'mean_env_wait_ms': 
0.5158033607569157, 'mean_env_render_ms': 0.0}","{'episode_reward_max': -13.175176681382728, 'episode_reward_min': -18.08055482624098, 'episode_reward_mean': -15.747099338910871, 'episode_len_mean': 1000.0, 'episode_media': {}, 'episodes_this_iter': 22, 'policy_reward_min': {'prey': -11.080116384082547, 'predator': -1.1080296546613446}, 'policy_reward_max': {'prey': -1.0251783437068376, 'predator': 68.89013974748939}, 'policy_reward_mean': {'prey': -4.093891080340487, 'predator': 22.83063343309827}, 'custom_metrics': {}, 'hist_stats': {'episode_reward': [-16.756132346618443, -14.501565454718891, -13.175176681382728, -18.08055482624098, -16.144062496855508, -13.824497821645368, -16.82879556655911, -16.255226058367786, -14.959426502814711, -14.573104679344599, -14.388639133580924, -15.547259474536718, -16.289217920126678, -15.34391587861511, -15.297621451297937, -17.190442560896557, -15.778501798514519, -14.406197921360242, -15.121161600087737, -16.990076999776697, -16.271901995584162, -15.08656900148258, -14.210986728809635, -16.935501941889058, -17.63501686421306, -16.751746859128286, -16.87354837408199, -14.110544765486335, -16.525795587064398, -16.096287366657762, -15.853084279926215, -15.27019188769467, -16.581525358699256], 'episode_lengths': [1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000], 'policy_prey_reward': [-1.0331809450425662, -1.0860311765285002, -1.0789160309113015, -10.222585812305075, -1.0885190013877946, -1.0739105231401669, -1.0732489156569536, -1.0921225519214195, -10.368044143944816, -1.0671625879372197, -1.0743417182025696, -1.0650602115897558, -1.0663941795219765, -1.0770590675955622, -1.0607061164649316, -1.0821987554556747, -1.0848810905756954, -10.420628465479055, -1.0681116237797124, -10.425461987005278, -1.0757866610307094, -10.157271320075582, -1.0609468503549733, -1.0731122445915158, 
-1.0555006932373276, -1.0689417432039938, -10.002679324257274, -10.829902065476357, -10.881944503631733, -10.999749005103649, -10.534504407751927, -11.03015036433386, -1.088265576768338, -1.0952239511781974, -10.267192763511902, -10.490237850936511, -10.327538104389735, -10.946902892854187, -1.0461746198133122, -10.387415904345465, -1.0613581861701888, -10.039393239729876, -10.875296853356485, -1.051760158678675, -10.720347902505472, -1.0756334594062185, -1.0645100181970246, -1.0899318246859064, -10.835528933298074, -1.1052971028622822, -1.0935394734006152, -1.05324173418818, -1.0847921509260492, -1.075193977480347, -1.0813992174008693, -1.0633979132813636, -1.056373391940204, -1.0762870498672583, -1.0649791260724908, -1.0595737924450823, -1.094810772502805, -1.070551043421539, -1.0889233094808686, -1.0894892644213041, -1.0968734273495644, -1.08028321533418, -1.1011121648661713, -1.0893459529983396, -1.0636918780305766, -1.046717367581855, -10.001578636852912, -1.065097808501408, -10.435214924681675, -1.0382654871442476, -10.546708025052403, -1.0546830216455978, -1.0660283451351482, -10.421723676368654, -1.078780986362438, -10.525737337785818, -10.282051633589887, -10.495801967362983, -11.041343584654998, -10.540902876047618, -1.072167454529644, -10.594919937685368, -10.698722929322676, -1.0777409789408703, -10.59864107627411, -1.0704507058293986, -11.069523123008477, -1.0696898476332704, -1.0657603951898384, -1.0592689197045273, -10.246358951575024, -10.520070486708974, -1.084551716538982, -1.0752379471859184, -1.074385990102115, -11.018031971886245, -1.0744173002660773, -1.049011022667481, -1.0580235560426003, -1.105692210906728, -1.050753924840243, -1.0685528394060921, -1.0974548755489446, -1.0709307731568498, -1.067206933308821, -1.0745401224521294, -1.0964873965080473, -1.0688704030264784, -10.605201167113501, -11.038410598668671, -1.0871369979025678, -1.0593136256754514, -10.191100188901954, -11.036677773165614, -1.059840007403642, -10.410413122707675, 
-1.086456617850566, -1.0848153059093792, -10.238008443105928, -1.0967298865843054, -10.029639907701181, -1.0977946536028251, -1.04952792370286, -10.14885455101178, -1.0738451948389245, -1.096330897602225, -10.491248218191128, -11.038143249770254, -1.0531462287490492, -1.0727804920049868, -1.0915315413079392, -1.0705191992724175, -10.346475414439308, -1.0675691012650572, -1.0885210634597642, -1.0626541854905296, -10.260028689813842, -1.0911778973047912, -1.0955455639095235, -1.0772161096820287, -10.762526730696404, -1.1000030670951262, -10.056710141223254, -10.1136117261094, -1.0797200374329123, -1.0698569554556527, -1.0875865715681536, -1.0747585710234329, -10.708892946959644, -1.0696853480482797, -1.0538281273238694, -1.0798889888222356, -10.077134661005415, -1.056110884205031, -10.546677954866807, -1.0594030584564555, -1.0616993124425214, -10.93242470427781, -10.544810870029686, -10.32867728424329, -10.487663347988416, -11.032988880008473, -1.0532305309962444, -1.032304231928622, -1.0675332959210646, -1.0701143048528192, -1.0829052839801632, -10.125207151610613, -1.0732650076783383, -1.0816046692982013, -10.106750850453206, -10.908353891548918, -1.0762415930680718, -1.0807956563674659, -10.468791558694804, -1.0892709236147555, -1.071031633294851, -10.223798484888022, -10.781359694306623, -10.314183767491823, -10.759060875615104, -1.0904938120830452, -1.101493877525135, -1.0943390850332124, -1.0954513902724794, -1.0874638838981416, -1.0820925768288494, -1.1324423847928926, -1.0855930919846002, -1.0530611151282576, -1.0734353140123216, -1.0824154252071485, -10.156379174043112, -1.0631446653753724, -10.443814870394803, -1.0584384778256806, -1.056463558517698, -1.0932530087144305, -1.078616884204146, -1.0561677939284764, -1.0695846397663573, -10.115308941063704, -1.0656875226190943, -1.0921905952732047, -10.65260864577935, -1.043370558268691, -10.275200712145809, -1.1030842716582419, -1.1237186264107342, -1.1017016166736695, -10.895568645003584, -1.095275205719677, 
-10.2220516197573, -1.0815797631074242, -10.690317332937543, -10.954565928086895, -1.07430776302436, -10.244988928689487, -1.0620254802046767, -1.0889384907138109, -1.0849331413533418, -1.099542304851256, -1.0775055240924467, -10.816821402818642, -1.1139692262020477, -1.0758072580973295, -1.0859942505342628, -1.0638160618382515, -1.0717224846739772, -1.058588739171481, -1.034445930937675, -1.0438263378370192, -1.0701232140873527, -10.233039964132711, -1.0649232267328612, -1.055990050708078, -10.418955918670664, -1.071139071235715, -1.088018817559707, -11.06552623613043, -10.778600559642424, -1.1126688317547029, -10.105082457500899, -1.1202502688682165, -1.0994461030915788, -1.079490075100534, -10.899697289723072, -1.0942535602906753, -10.782162337073043, -10.79416768407992, -1.0522895344591814, -10.727284288621458, -1.0797150696275517, -1.064747055315219, -10.268011536774102, -1.0769981289728325, -10.07542744137072, -10.734900029493206, -1.084700598854851, -1.0868192608492488, -10.204039863428642, -1.0568488364650064, -1.0973648517565793, -1.0728077711182864, -10.489902778106178, -1.0653807236561825, -1.0893663575192816, -10.081559907742147, -1.097795501862053, -10.24162414226307, -1.1134425345350745, -1.0975926132401295, -1.1018975612727933, -1.103875634928367, -10.47700120138679, -1.117644688299702, -1.1027771728192113, -1.111529853132419, -10.203379354260182, -10.958365719193896, -10.990315680954216, -1.0678439173837893, -10.605509207802507, -1.0710954197335492, -1.0848326036300446, -1.056003209170052, -1.063675972076112, -1.0251783437068376, -1.0601557793960052, -1.067128819325811, -10.557563348548003, -10.832154117090113, -1.082364045906554, -1.089852597606391, -1.0544241197694946, -1.058258686592715, -1.0986305796519686, -11.030445948977993, -10.269730307033916, -10.785521939128925, -10.61987471596054, -1.0781846508987525, -1.0970568297995427, -1.0683428921405729, -1.0747155335060072, -1.0971013906278453, -1.0453802238968126, -10.81460180311684, 
-1.1220632464763642, -10.886156256261005, -10.954616130474184, -10.437778855647435, -1.0782178645986653, -10.31036740562049, -1.0422205486829084, -10.300392699516927, -1.0970732989710428, -1.0914356243308754, -1.0835491738929817, -10.460555234660893, -1.0583068998970806, -10.643434764014234, -1.0541899988003718, -1.0680066339242633, -1.0712837761645477, -1.0608347260771638, -11.080116384082547, -10.178144175702183, -1.0942055754941147, -1.1097166428732128, -1.0612465318179742, -1.1070964554444882, -1.0784736836104596, -10.518948397713206, -1.105649819577462, -10.194769264054672, -1.0681956155739556, -10.00448676015266, -10.196998033101181, -1.0904792372508165, -1.08460157769112, -1.0522642262963628, -10.250799331155555, -1.0768496918496917, -1.1028430096892006, -1.1124953973289073, -1.0866590115567276, -1.0764072353099479, -1.077892774471119, -1.0631948094227426, -1.0881285177813302, -1.0689045967546262, -10.770643950154495, -1.0900806402824859, -1.073184652126815, -10.725107310351612, -1.0896536809882025, -1.105865653861617, -1.0610572042265085, -1.0573719842025784, -10.261347240616423, -1.0954742319918571, -1.0627651853433815, -1.0642439081008852, -1.1208572130240173, -1.0860202818862466, -1.0919104205410075, -1.074644416493442, -1.074148818718204, -1.1086528095312784, -1.0465381447954814, -1.072173292630277, -1.0957140561838072, -1.0819561797230328, -1.047599800749802, -1.0598242689960942, -1.106086770014188, -1.0693152445555658, -1.049747197663337, -10.018408074352514, -1.0931741488770654, -1.058224307875032, -10.553282968293228, -1.0815606376486249, -1.0736422516293267, -1.0753129752076267, -10.542678271757229, -1.0744645227535807, -1.0760884340381736, -1.0877113431059178, -1.0822539159846, -1.0520669762561417, -11.071470682856416, -10.124150261358897, -1.084819956964208, -1.0805294636689489, -1.0950709786462736, -1.0855809759740778, -1.0626456023869664, -1.078049753131857, -1.0792649714837772, -1.057968325860867, -10.404959376308058, -1.056982699726126, 
-1.0967754544663486, -10.648941789269553, -10.25461019837555, -1.076753815543661, -1.0906039378574213, -1.0673579935241597, -10.50694669934502, -1.0521762509081305, -10.104995355880456, -10.377745402744035, -1.0478684755599577, -1.0947743051787946, -1.1017748907847629, -1.0644419112718693, -1.1101896617159093, -1.0812190396822912, -1.067605814285048, -10.529119514951821, -1.0825226799334773, -1.070467573519321, -1.085133736241524, -10.564836081843872, -1.0680683499808779, -1.1077525642883637, -10.184100517755843, -1.0787584836300497, -1.0984898781580335, -1.0602201104993965, -1.070217796841316, -1.0838192378267666, -1.1037669669110257, -10.686615023317847, -1.0613742024079371, -10.909715937169691, -1.0712382439537511, -1.1054196747916665, -10.387303128216809, -1.0822213154582763, -10.163430604144821, -1.0995192392664201, -1.0512325209544706, -10.980254566671388, -10.800669105186888, -1.1225541252609037, -10.317138548700388, -1.082434751424946, -1.1152824515216304, -1.0982077658477494, -1.0824508662337542, -1.0974416827478752, -10.088605572152103, -1.0760360847461707, -1.1104979727825082, -10.739215327151687, -1.0612861143050403, -1.1213684820135252, -10.67590401463751, -10.095235243943245, -10.851877676193856, -10.94656307888867, -1.070046876501675, -1.05828544502263, -1.0559387881790852, -1.0723341473376231, -10.387781380216355, -10.550830681193416, -1.0846760573699317, -1.0954202589620736, -10.611353406974729, -1.051444250260534, -1.0681857704205857, -1.0794349667848908, -10.743637259139097, -1.1253872456642624, -1.0934218217845364, -1.0652221285053343, -10.809323540216358, -1.0708240635979969, -1.1057796789764367, -1.115845529813247, -1.1228142050919772, -10.39094324706491, -1.0711494600264087, -10.72732546017665, -1.1135275793684372, -10.703327741956784, -1.10196379357755], 'policy_predator_reward': [-1.1007851384663407, 18.871935773998683, 28.89376661509801, 38.89178426344177, 28.89644634745207, 68.89013974748939, 8.899654699330442, -1.1005303601196093, 
8.870033835753986, 18.894566945610507, 48.90514680767333, 38.890051882216625, 18.900563604227504, 18.891418193469878, 28.90572010590034, 18.87119066067836, 28.90157144799706, 18.887855161121767, 8.875849218106913, 38.89318198519854, 48.88512726223081, 18.89547623544951, 18.914888318781962, 28.88721003670337, 8.881436928630475, 28.874646138398266, 8.894550274449255, 28.88897860791711, 28.90632208498825, 28.894313989200477, 8.890647355619231, 8.885026060199605, 28.902797269932588, 38.880449676733825, 38.882233531547854, 18.89651678150211, 8.873765564168652, 48.89324075915313, 8.888858023738333, 18.89710516422296, 28.88927111735901, 38.881249335008214, 28.883584049905906, 18.887494453411584, 28.87119924045364, 28.890942184877005, 8.890537699783946, 18.88941551285859, -1.1080296546613446, 8.873563984769282, 18.88539706005956, -1.1011217447882673, -1.0922346314504443, 28.8955323677355, 38.925011095161366, 18.903904219900564, 8.882504041401841, 18.88618106957687, 28.904738380972702, 18.915322820801283, 28.876424012422714, 18.859584572363943, 28.89585882503507, 28.91335731551968, 38.86727688289295, 8.911690513367786]}, 'sampler_perf': {'mean_raw_obs_processing_ms': 2.1939569612040453, 'mean_inference_ms': 1.2477843403823796, 'mean_action_processing_ms': 0.5297219811578259, 'mean_env_wait_ms': 0.5158033607569157, 'mean_env_render_ms': 0.0}, 'num_faulty_episodes': 0, 'connector_metrics': {'ObsPreprocessorConnector_ms': 0.006701368274110736, 'StateBufferConnector_ms': 0.003612041473388672, 'ViewRequirementAgentConnector_ms': 0.4773656527201335}}",122.455,62.2698,122.455,"{'training_iteration_time_ms': 61225.002, 'sample_time_ms': 8838.516, 'synch_weights_time_ms': 7.191}",1696799237,40000,2,59e38_00000




# Render episode 

### Retrieve checkpoint

In [None]:
# The Tuner above sets no TuneConfig metric/mode, so get_best_result() needs
# them explicitly (ResultGrid raises without a metric to rank trials by).
best_checkpoint = results.get_best_result(
    metric="episode_reward_mean", mode="max"
).checkpoint
best_checkpoint

In [None]:
from ray.rllib.algorithms.algorithm import Algorithm

algo = Algorithm.from_checkpoint(best_checkpoint)

# Inspect which policies the restored algorithm's local worker holds.
local_worker = algo.workers.local_worker()
available_policy_ids = [*local_worker.policy_map.keys()]
print("Available Policy IDs:", available_policy_ids)

### Run and plot

In [None]:
import numpy as np

def process_observations(observation, agent_ids, truncation=None, observations=None):
    """Append one time step of agent positions to an observation log.

    Args:
        observation: per-agent observation dict from ``env.reset()`` /
            ``env.step()``. Entries 4 and 5 of each agent's observation are
            recorded as its x / y location (presumably normalised, since they
            are later scaled by ``env.stage_size`` - TODO confirm).
        agent_ids: all agent ids, in a fixed order; agents missing from
            ``observation`` are recorded at (0, 0).
        truncation: optional per-agent truncation dict from ``env.step()``;
            truncated agents are marked as no longer in the game.
        observations: log dict with ``"loc_x"``, ``"loc_y"`` and
            ``"still_in_the_game"`` lists; a fresh one is created if omitted.

    Returns:
        The ``observations`` log, with one new entry appended to each list.
    """
    if observations is None:
        observations = {"loc_x": [], "loc_y": [], "still_in_the_game": []}
    loc_x = [observation[key][4] if key in observation else 0 for key in agent_ids]
    loc_y = [observation[key][5] if key in observation else 0 for key in agent_ids]
    if truncation:
        # NOTE(review): only truncation marks an agent as out; if the env
        # signals removal through `termination` instead, check that dict too.
        still_in_the_game = [1 if not truncation[key] else 0 for key in agent_ids]
    else:
        still_in_the_game = [1 for _ in agent_ids]
    observations["loc_x"].append(np.array(loc_x))
    observations["loc_y"].append(np.array(loc_y))
    observations["still_in_the_game"].append(np.array(still_in_the_game))

    return observations


MAX_STEPS = 500  # upper bound on the number of rendered steps

step_count = 0
observations = {"loc_x": [], "loc_y": [], "still_in_the_game": []}

observation, _ = env.reset()
agent_ids = env._agent_ids
process_observations(observation, agent_ids, observations=observations)

while step_count < MAX_STEPS:
    # Query the policy each agent was trained with (same mapping as the config).
    actions = {
        key: algo.compute_single_action(
            value, policy_id="prey" if env.agents[key].agent_type == 0 else "predator"
        )
        for key, value in observation.items()
    }

    observation, _, termination, truncation, _ = env.step(actions)
    process_observations(observation, agent_ids, truncation, observations)
    step_count += 1

    # Stop early once the episode is over ("__all__" is RLlib's multi-agent
    # episode-done key); if the env never sets it we run MAX_STEPS steps.
    if termination.get("__all__") or truncation.get("__all__"):
        break

stage_size = env.stage_size
observations["loc_x"] = np.array(observations["loc_x"]) * stage_size
observations["loc_y"] = np.array(observations["loc_y"]) * stage_size
observations["still_in_the_game"] = np.array(observations["still_in_the_game"])

env.close()

In [None]:
import importlib

import animation

# Reload so edits to animation.py are picked up without restarting the kernel.
animation = importlib.reload(animation)
generate_animation = animation.generate_animation

ani = generate_animation(observations, env)

In [None]:
from IPython.display import HTML

# Render the animation to an HTML5 video tag and display it as the cell output.
video_html = ani.to_html5_video()
HTML(video_html)

# Retrain

In [None]:
from ray.rllib.policy.policy import Policy
from ray.rllib.algorithms.callbacks import DefaultCallbacks

def restore_policy_and_weights(policy_type, checkpoint_dir=None):
    """Load one policy from the best checkpoint and return its weights.

    Args:
        policy_type: policy id stored in the checkpoint ("prey" or "predator").
        checkpoint_dir: local directory of the algorithm checkpoint; defaults
            to ``best_checkpoint.to_directory()``.

    Returns:
        The restored policy's ``get_weights()`` result.
    """
    if checkpoint_dir is None:
        checkpoint_dir = best_checkpoint.to_directory()
    checkpoint_path = os.path.join(checkpoint_dir, "policies", policy_type)
    restored_policy = Policy.from_checkpoint(checkpoint_path)
    return restored_policy.get_weights()


# Materialise the checkpoint once and reuse it for both policies.
checkpoint_dir = best_checkpoint.to_directory()
restored_policy_predator_weights = restore_policy_and_weights("predator", checkpoint_dir)
restored_policy_prey_weights = restore_policy_and_weights("prey", checkpoint_dir)

print("Starting new tune.Tuner().fit()")

# Algorithm.from_checkpoint above may already have started Ray; tolerate that
# instead of failing with a "called ray.init twice" error.
ray.init(num_cpus=num_cpus, num_gpus=num_gpus, ignore_reinit_error=True)

# Start our actual experiment, reusing the stopping criteria of the first run
# (`Args` only defines `run` and `framework`).
stop = {
    "episode_reward_mean": opti_config["stop_reward"],
    "timesteps_total": opti_config["stop_timesteps"],
    "training_iteration": opti_config["stop_iters"],
}


class RestoreWeightsCallback(DefaultCallbacks):
    """Seed both policies with the restored weights when the algorithm is built."""

    def on_algorithm_init(self, *, algorithm: "Algorithm", **kwargs) -> None:
        # NOTE(review): the config enables the Learner API; confirm that
        # Algorithm.set_weights also reaches the learner group, otherwise the
        # first training step may overwrite these weights.
        algorithm.set_weights({
            "predator": restored_policy_predator_weights,
            "prey": restored_policy_prey_weights,
        })


# Note: this mutates the shared `config`; re-running the first training cell
# afterwards will also restore these weights.
config.callbacks(RestoreWeightsCallback)

# Same Tuner API as the first training run.
tuner = tune.Tuner(
    args.run,
    param_space=config,
    run_config=air.RunConfig(stop=stop, verbose=1),
)
results = tuner.fit()

if opti_config["as_test"]:
    check_learning_achieved(results, opti_config["stop_reward"])

ray.shutdown()