# Some preliminary checks

In [1]:
import torch
import os

#os.environ["RAY_DEDUP_LOGS"] = "0"

import ray 

# Report library versions and accelerator availability so that a run's
# software/hardware context is recorded in the notebook output.
for label, value in (
    ("Ray version :", ray.__version__),
    ("PyTorch Version:", torch.__version__),
    ("CUDA Available:", torch.cuda.is_available()),
    ("MPS Available:", torch.backends.mps.is_available()),
):
    print(label, value)

# Backends registered for torch.compile; as the last expression it is displayed.
torch._dynamo.list_backends()

Ray version : 2.39.0
PyTorch Version: 2.5.1
CUDA Available: False
MPS Available: True


['cudagraphs', 'inductor', 'onnxrt', 'openxla', 'tvm']

### Important: number of CPUs and GPUs available

In [2]:
import psutil

# Logical CPU count of this machine (psutil returns None if it cannot tell).
available_cpus = psutil.cpu_count()
print("Number of CPUs: ", available_cpus)

# Resources handed to ray.init() in the training cell.
num_cpus = 12   # CPUs Ray may schedule on
num_gpus = 0    # no CUDA device on this machine (see "CUDA Available" above)

# Mirrors `.learners(num_learners=4)` in the config cell; keep the two in sync.
num_learner = 4

assert available_cpus is not None, "psutil could not determine the CPU count"
assert num_cpus <= available_cpus, (
    f"num_cpus={num_cpus} exceeds the {available_cpus} CPUs on this machine"
)

Number of CPUs:  12


# Environment and algorithm configuration

Some of the commented-out lines are preparation for using an upcoming RLlib feature.

Note: in multi-agent environments, `rollout_fragment_length` sets the batch size in (across-agent) environment steps, not in the steps of individual agents, which can result in unexpectedly large batches.


In [3]:
from ray.rllib.policy.policy import PolicySpec
#from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
#from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec

from ray.tune.registry import get_trainable_cls

from particle_2d_env import Particle2dEnvironment
from particle_2d_env import MetricsCallbacks, RenderingCallbacks
from config import run_config

ALGO = "PPO"
FRAMEWORK = "torch"  # always "torch"

# Driver-side env instance, used only to read the observation/action spaces and
# the agent types needed by the multi-agent setup below.
env = Particle2dEnvironment(run_config["env"])


def policy_mapping_fn(agent_id, *args, **kwargs):
    """Map an agent to its module: agent_type 0 -> "prey", anything else -> "predator"."""
    # NOTE(review): this closes over the driver-side `env`, so that object is pickled
    # and shipped to the EnvRunners together with the function; agent types are
    # assumed identical across env instances — confirm.
    return "prey" if env.particule_agents[agent_id].agent_type == 0 else "predator"


config = (
    get_trainable_cls(ALGO).get_default_config()
    .environment(Particle2dEnvironment, env_config=run_config["env"])
    .framework(FRAMEWORK)
    # New API stack: RLModule + Learner, EnvRunner + ConnectorV2.
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    # .callbacks(MetricsCallbacks)
    # .callbacks(RenderingCallbacks)
    # Learner hyperparameters.
    .training(
        # Passes over each train batch. The deprecated alias `num_sgd_iter=5` was also
        # passed here and conflicted with this value, so it has been removed.
        # NOTE(review): depending on the Ray version the alias may have overridden
        # `num_epochs` in earlier runs — confirm which value is intended.
        num_epochs=10,
        # Per-learner train batch size; with 4 learners one iteration uses
        # 4 x 512 = 2048 steps.
        train_batch_size_per_learner=512,
        model={
            "fcnet_hiddens": [128, 128, 128],
            # NOTE(review): "use_attention" is an old-API-stack (ModelV2) option and is
            # presumably ignored by the RLModule stack enabled above — confirm.
            "use_attention": True,
            # "use_lstm": False,
            # "max_seq_len": 5,
            # "lstm_cell_size": 16,
        },
        # lr=tune.grid_search([0.01, 0.001, 0.0001])
    )
    .multi_agent(
        # NOTE(review): per-policy `config` overrides in PolicySpec come from the old
        # API stack; confirm gamma=0.85 is really applied on the new stack (see
        # `algorithm_config_overrides_per_module`).
        policies={
            "prey": PolicySpec(
                policy_class=None,  # infer automatically from Algorithm
                observation_space=env.observation_space[0],  # if None infer automatically from env
                action_space=env.action_space[0],  # if None infer automatically from env
                config={"gamma": 0.85},  # use main config plus <- this override here
            ),
            "predator": PolicySpec(
                policy_class=None,
                observation_space=env.observation_space[0],
                action_space=env.action_space[0],
                config={"gamma": 0.85},
            ),
        },
        policy_mapping_fn=policy_mapping_fn,
        policies_to_train=["prey", "predator"],
        count_steps_by="agent_steps",
    )
    .learners(
        num_learners=4,  # keep in sync with `num_learner` in the resources cell
        num_cpus_per_learner=1,  # <- default
        num_gpus_per_learner=0,  # <- default
    )
    .resources(num_cpus_for_main_process=1)  # default is 1
    .env_runners(
        # "auto" for PPO, explained here: https://docs.ray.io/en/latest/rllib/rllib-sample-collection.html
        rollout_fragment_length="auto",
        batch_mode="truncate_episodes",
        num_env_runners=1,  # need 2 for IMPALA, 1 for PPO
        num_envs_per_env_runner=1,
    )
    .checkpointing(export_native_model_files=True)
)


  gym.logger.warn(
  gym.logger.warn(
  logger.warn(
  logger.warn(f"{pre} is not within the observation space.")


# Training

In [4]:
from ray.rllib.policy.policy import Policy
from ray.rllib.algorithms.callbacks import DefaultCallbacks

# Set to a checkpoint directory to warm-start both policies from a previous run, e.g.
#   os.path.join(os.getcwd(), "ray_results", "PPO_2024-05-18_00-08-19",
#                "PPO_Particle2dEnvironment_bb60c_00000_0_2024-05-18_00-08-19",
#                "checkpoint_000001")
# Leave as None to train from scratch.
path_to_checkpoint = None

def restore_weights(path, policy_type):
    """Load one policy from an RLlib checkpoint and return its weights.

    Args:
        path: checkpoint directory (e.g. ``.../checkpoint_000001``).
        policy_type: policy ID to restore; this notebook uses "predator" and "prey".

    Returns:
        The value of ``Policy.get_weights()`` for the restored policy.
    """
    # Expects the layout <checkpoint>/policies/<policy_id> read by Policy.from_checkpoint.
    # NOTE(review): the config in this notebook enables the new RLModule/Learner stack,
    # whose checkpoints may not contain a "policies/" directory — confirm before
    # setting path_to_checkpoint.
    checkpoint_path = os.path.join(path, "policies", policy_type)
    restored_policy = Policy.from_checkpoint(checkpoint_path)
    return restored_policy.get_weights()

if path_to_checkpoint is not None:
    class RestoreWeightsCallback(DefaultCallbacks):
        """Warm-start the "predator" and "prey" policies from `path_to_checkpoint`."""

        def on_algorithm_init(self, *, algorithm: "Algorithm", **kwargs) -> None:
            # One set_weights call per policy, predator first.
            # NOTE(review): on the new API stack the learners own the trainable weights;
            # confirm Algorithm.set_weights reaches them and is not overwritten by the
            # next weight sync.
            for module_id in ("predator", "prey"):
                weights = restore_weights(path_to_checkpoint, module_id)
                algorithm.set_weights({module_id: weights})

    # Attach the callback to the algorithm config so it runs when the Algorithm is built.
    config.callbacks(RestoreWeightsCallback)



## Launch training

In [None]:
from ray import train, tune
from ray.tune import Tuner
from ray.air.integrations.wandb import WandbLoggerCallback
from ray.rllib.utils.test_utils import check_learning_achieved
import os

# Stopping criterion: end the trial after this many training iterations.
stop = {
    "training_iteration": 1500,
    # "timesteps_total": 200000000,
}

# Weights & Biases logging. The API key comes from the WANDB_API_KEY environment
# variable when set, otherwise from wandb_api_key.txt in the working directory,
# so the secret never has to be written into the notebook.
api_key = os.environ.get("WANDB_API_KEY")
if not api_key:
    with open("wandb_api_key.txt", "r", encoding="utf-8") as file:
        api_key = file.read().strip()
callbacks = [
    WandbLoggerCallback(
        project="marl-rllib",
        group=ALGO,
        api_key=api_key,
        log_config=True,
        upload_checkpoints=True,
    ),
]

# Checkpoint every 10 iterations and once more when the trial stops.
checkpoint_config = train.CheckpointConfig(
    checkpoint_at_end=True,
    checkpoint_frequency=10,
)

# Results and checkpoints are written to <current working directory>/ray_results.
storage_path = os.path.join(os.getcwd(), "ray_results")

ray.init(
    num_cpus=num_cpus,
    num_gpus=num_gpus,
)
try:
    # One Tuner covers both fresh training and warm-starting: when
    # `path_to_checkpoint` is set, RestoreWeightsCallback is already registered on
    # `config` via `config.callbacks(...)`, so it travels with `param_space`.
    # It is an RLlib algorithm callback, not a Tune callback, and must NOT be
    # added to the Tune `callbacks` list below.
    tuner = tune.Tuner(
        ALGO,
        param_space=config,
        run_config=train.RunConfig(
            storage_path=storage_path,
            stop=stop,
            verbose=3,
            callbacks=callbacks,
            checkpoint_config=checkpoint_config,
        ),
    )
    results = tuner.fit()
finally:
    # Release the local Ray instance even if fitting fails or is interrupted,
    # so re-running this cell starts from a clean state.
    ray.shutdown()


2024-12-01 22:02:25,152	INFO worker.py:1819 -- Started a local Ray instance.
2024-12-01 22:02:25,794	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


0,1
Current time:,2024-12-01 22:06:50
Running for:,00:04:25.13
Memory:,24.1/64.0 GiB

Trial name,status,loc,iter,total time (s),num_env_steps_sample d_lifetime,num_episodes_lifetim e,num_env_steps_traine d_lifetime
PPO_Particle2dEnvironment_91321_00000,RUNNING,127.0.0.1:34265,13,239.373,26624,38,26624


[36m(_WrappedExecutable pid=34278)[0m Setting up process group for: env:// [rank=0, world_size=4]
[36m(PPO pid=34265)[0m Trainable.setup took 10.040 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.
[36m(PPO pid=34265)[0m Install gputil for GPU system monitoring.
[36m(_WandbLoggingActor pid=34296)[0m wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[36m(_WandbLoggingActor pid=34296)[0m wandb: Currently logged in as: tanguy-cazalets (tcazalet_airo). Use `wandb login --relogin` to force relogin
[36m(_WandbLoggingActor pid=34296)[0m wandb: Tracking run with wandb version 0.18.7
[36m(_WandbLoggingActor pid=34296)[0m wandb: Run data is saved locally in /private/tmp/ray/session_2024-12-01_22-02-24_528317_34233/artifacts/2024-12-01_22-02-25/PPO_2024-12-01_22-02-25/driver_artifacts/PPO_Particle2dEnvironment_91321_00000_0_2024-12-01_22-02-25/wandb/ru

Trial name,date,done,env_runners,fault_tolerance,hostname,iterations_since_restore,learners,node_ip,num_agent_steps_sampled_lifetime,num_env_steps_sampled_lifetime,num_env_steps_trained_lifetime,num_episodes_lifetime,perf,pid,time_since_restore,time_this_iter_s,time_total_s,timers,timestamp,training_iteration,trial_id
PPO_Particle2dEnvironment_91321_00000,2024-12-01_22-06-43,False,"{'agent_episode_returns_mean': {'22': -0.3422, '17': -0.047700000000000006, '11': -0.1582, '12': -0.09910000000000001, '14': -0.2768, '29': -0.1886, '0': -0.09559999999999999, '5': -0.033, '15': -0.12140000000000001, '20': -0.34090000000000004, '7': -0.4571, '23': -0.10880000000000001, '2': -0.011100000000000004, '1': -0.39880000000000004, '4': -0.4071, '28': -0.0917, '25': -0.48700000000000004, '30': 2.733600000000001, '16': -0.1284, '10': -0.25189999999999996, '19': -0.2658, '9': -0.0105, '8': -0.20740000000000003, '24': -0.3827, '18': -0.26289999999999997, '27': -0.1081, '13': -0.012600000000000002, '21': -0.22149999999999997, '31': 2.4890000000000017, '26': 0.0, '6': -0.20299999999999996, '3': -0.251}, 'num_agent_steps_sampled': {'23': 2048, '2': 2048, '1': 2048, '19': 2048, '9': 2048, '24': 2048, '18': 1619, '27': 2048, '16': 2048, '13': 2048, '10': 2048, '31': 2048, '6': 2048, '8': 1360, '3': 2048, '21': 2048, '17': 2048, '11': 2048, '26': 2048, '29': 2048, '0': 2048, '15': 2048, '20': 2048, '7': 2048, '22': 2039, '4': 1381, '12': 2030, '28': 2048, '25': 2048, '14': 2048, '30': 2048, '5': 2048}, 'episode_return_max': 0.0, 'num_agent_steps_sampled_lifetime': {'7': 179482, '23': 179219, '2': 186368, '1': 175889, '4': 177269, '19': 177811, '25': 175960, '18': 183893, '27': 180552, '16': 184816, '13': 186368, '10': 182402, '31': 186368, '6': 179155, '9': 186368, '8': 181040, '24': 178818, '21': 176928, '17': 183508, '11': 179900, '26': 186368, '29': 181506, '15': 178218, '3': 178439, '20': 177645, '22': 179485, '12': 184010, '28': 186332, '14': 179590, '30': 186368, '0': 185060, '5': 185120}, 'episode_return_min': -1.95999999999996, 'agent_steps': {'22': 655.64, '21': 651.82, '17': 682.87, '11': 687.31, '29': 672.61, '0': 686.95, '5': 697.94, '15': 677.5, '20': 652.87, '7': 665.86, '2': 700.0, '1': 662.26, '4': 635.0, '12': 692.84, '28': 699.66, '25': 641.35, '14': 672.8, '30': 700.0, '23': 680.36, 
'10': 688.54, '19': 664.27, '9': 700.0, '24': 669.4, '18': 679.94, '27': 677.92, '16': 694.21, '13': 700.0, '31': 700.0, '26': 700.0, '6': 680.76, '8': 680.08, '3': 646.81}, 'num_env_steps_sampled_lifetime': 186368, 'module_episode_returns_mean': {'predator': 2.9235000000000015, 'prey': -1.1293}, 'episode_len_max': 700, 'episode_return_mean': -0.7482999999999972, 'num_env_steps_sampled': 2048, 'num_module_steps_sampled': {'prey': 59629, 'predator': 4096}, 'episode_len_min': 700, 'episode_duration_sec_mean': 3.5693788301096356, 'num_module_steps_sampled_lifetime': {'prey': 5437519, 'predator': 372736}, 'num_episodes': 3, 'episode_len_mean': 700.0}","{'num_healthy_workers': 1, 'num_in_flight_async_reqs': 0, 'num_remote_worker_restarts': 0}",MacBook-Pro-de-Tanguy.local,13,"{'predator': {'entropy': 2.8857398629188538, 'curr_kl_coeff': 1.0125000476837158, 'gradients_default_optimizer_global_norm': 2.0316648483276367, 'policy_loss': -0.0591624160297215, 'vf_explained_var': 0.045291975140571594, 'default_optimizer_learning_rate': 5e-05, 'num_non_trainable_parameters': np.float64(0.0), 'vf_loss': 0.09536299761384726, 'total_loss': 0.05085933208465576, 'module_train_batch_size_mean': np.float64(1027.0128524198406), 'num_module_steps_trained': 4110, 'vf_loss_unclipped': 0.22920488473027945, 'mean_kl_loss': 0.01447778451256454, 'diff_num_grad_updates_vs_sampler_policy': 12.0, 'curr_entropy_coeff': 0.0, 'num_trainable_parameters': np.float64(157445.0)}, 'prey': {'default_optimizer_learning_rate': 5e-05, 'num_non_trainable_parameters': np.float64(0.0), 'vf_loss': 0.007650006533367559, 'total_loss': 0.04594599606934935, 'module_train_batch_size_mean': np.float64(15368.419181790803), 'num_module_steps_trained': 59841, 'vf_loss_unclipped': 0.007650006533367559, 'mean_kl_loss': 0.014044933952391148, 'entropy': 2.820167601108551, 'diff_num_grad_updates_vs_sampler_policy': 12.0, 'num_trainable_parameters': np.float64(157445.0), 'curr_entropy_coeff': 0.0, 'vf_explained_var': 
-0.4392000436782837, 'curr_kl_coeff': 0.15000000223517418, 'gradients_default_optimizer_global_norm': 0.7767826318740845, 'policy_loss': 0.03620670617965516}, '__all_modules__': {'num_trainable_parameters': np.float64(314890.0), 'num_env_steps_trained': 2048, 'learner_connector_timer': np.float64(0.48037207465868903), 'num_non_trainable_parameters': np.float64(0.0), 'num_module_steps_trained': 63951}}",127.0.0.1,"{'0': 26188, '1': 25363, '10': 26240, '11': 26029, '12': 26372, '13': 26624, '14': 25611, '15': 25839, '16': 26430, '17': 26052, '18': 25539, '19': 25431, '2': 26624, '20': 25050, '21': 25015, '22': 25136, '23': 25795, '24': 25600, '25': 24665, '26': 26624, '27': 25886, '28': 26606, '29': 25536, '3': 24847, '30': 26624, '31': 26624, '4': 23955, '5': 26520, '6': 25971, '7': 25273, '8': 25500, '9': 26624}",26624,26624,38,"{'cpu_util_percent': np.float64(27.149999999999995), 'ram_util_percent': np.float64(37.357692307692304)}",34265,239.373,18.291,239.373,"{'env_runner_sampling_timer': 11.129676189961314, 'learner_update_timer': 7.921476422495362, 'synch_weights': 0.003975342283382807, 'synch_env_connectors': 0.0033017806574292914, 'training_iteration_time_sec': 18.33869779109955, 'restore_workers_time_sec': 5.817413330078125e-06, 'training_step_time_sec': 18.33798348903656}",1733087203,13,91321_00000


[36m(PPO pid=34265)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/Users/tanguy/Code/Finebouche/collective_behavior/ray_results/PPO_2024-12-01_22-02-25/PPO_Particle2dEnvironment_91321_00000_0_2024-12-01_22-02-25/checkpoint_000000)
[36m(_WandbLoggingActor pid=34296)[0m wandb: Adding directory to artifact (/Users/tanguy/Code/Finebouche/collective_behavior/ray_results/PPO_2024-12-01_22-02-25/PPO_Particle2dEnvironment_91321_00000_0_2024-12-01_22-02-25/checkpoint_000000)... 
[36m(_WandbLoggingActor pid=34296)[0m Done. 0.2s
