# Some preliminary checks

In [None]:
import os

# Ray reads RAY_DEDUP_LOGS into a module-level constant when it is imported, so
# the variable must be set *before* `import ray` for the driver to honor it.
# "0" disables Ray's log deduplication so every worker's output is printed.
os.environ["RAY_DEDUP_LOGS"] = "0"

import torch  # noqa: E402  (imported after setting the env var on purpose)
import tensorflow as tf  # noqa: E402  NOTE(review): not used in the visible cells; presumably only needed for FRAMEWORK="tf2"
import ray  # noqa: E402

# Report which accelerators / compiler backends PyTorch can use on this machine.
print("PyTorch Version:", torch.__version__)
print("CUDA Available:", torch.cuda.is_available())
print("CUDA Version:", torch.version.cuda)  # None for CPU-only builds

print("MPS Available:", torch.backends.mps.is_available())
print("onnx Available:", torch.onnx.is_onnxrt_backend_supported())
torch._dynamo.list_backends()

In [None]:
import psutil

# Logical CPUs on this machine; psutil returns None when it cannot determine it.
available_cpus = psutil.cpu_count()
print("Number of CPUs: ", available_cpus)

# Resources handed to `ray.init` and used to size the rollout workers
# (num_rollout_workers = num_cpus - num_learner - 1).
num_cpus = 34
num_gpus = 0
num_learner = 0

assert available_cpus is not None and num_cpus <= available_cpus, (
    f"num_cpus={num_cpus} exceeds the CPUs available on this machine "
    f"({available_cpus}); lower num_cpus."
)

# Environment and algorithm configuration

Some of the commented-out lines are preparation for using a future feature of RLlib (the RLModule API).

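Note: In multi-agent environments, `rollout_fragment_length` sets the batch size based on (across-agents) environment steps, not the steps of individual agents, which can result in unexpectedly large batches. This is RLlib's default `count_steps_by="env_steps"` behavior; the configuration below sets `count_steps_by="agent_steps"`.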


In [None]:
from ray.rllib.policy.policy import PolicySpec
#from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
#from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec

from ray.tune.registry import get_trainable_cls

from custom_env import CustomEnvironment
from config import run_config

ALGO = "PPO"
FRAMEWORK = "torch"  # "tf2" or "torch"

# Ids of the policies trained in this experiment; must match the values
# returned by `policy_mapping_fn`.
POLICY_IDS = ("prey", "predator")

# Driver-side environment instance: read for its observation/action spaces and
# used by `policy_mapping_fn` to look up each agent's type.
env = CustomEnvironment(run_config["env"])


def policy_mapping_fn(agent_id, *args, **kwargs):
    """Return the id of the policy controlling ``agent_id``.

    Agents whose ``agent_type`` is 0 are mapped to "prey", all others to
    "predator". Any extra positional/keyword arguments RLlib passes are ignored.

    NOTE(review): the lookup goes through the driver-side ``env`` above, which is
    presumably pickled along with this function to each rollout worker. This
    assumes every agent id seen during rollouts already exists in that instance's
    ``agents`` — confirm against CustomEnvironment if agents can spawn later.
    """
    return "prey" if env.agents[agent_id].agent_type == 0 else "predator"


def _make_policy_spec():
    """Build a PolicySpec using the env's spaces at index/key 0 and a gamma override."""
    return PolicySpec(
        policy_class=None,  # infer automatically from Algorithm
        observation_space=env.observation_space[0],  # if None infer automatically from env
        action_space=env.action_space[0],  # if None infer automatically from env
        config={"gamma": 0.85},  # main config plus this override
    )


config = (
    get_trainable_cls(ALGO)
    .get_default_config()
    .environment(CustomEnvironment, env_config=run_config["env"])
    .framework(FRAMEWORK)
    .training(
        num_sgd_iter=5,            # SGD epochs over each train batch
        sgd_minibatch_size=256,    # SGD minibatch size
        train_batch_size=524288,   # samples per training iteration (counted per count_steps_by below)
        model={
            "fcnet_hiddens": [64, 64, 64],
            #"use_attention": True,
            #"use_lstm": False,
            #"max_seq_len": 5,
            #"lstm_cell_size": 16,
        },
        #lr=tune.grid_search([0.01, 0.001, 0.0001])
    )
    .multi_agent(
        # One independent PolicySpec per policy id (prey and predator share spaces).
        policies={policy_id: _make_policy_spec() for policy_id in POLICY_IDS},
        policy_mapping_fn=policy_mapping_fn,
        policies_to_train=list(POLICY_IDS),
        count_steps_by="agent_steps",
    )
    #rl_module_api
    .experimental(_enable_new_api_stack=True)
    .rl_module(
#        rl_module_spec=MultiAgentRLModuleSpec(
#            module_specs={
#                "prey": SingleAgentRLModuleSpec(
#                    module_class=PPOTorchRLModule,
#                    observation_space=env.observation_space,
#                    action_space=env.action_space,
#                    model_config_dict={"fcnet_hiddens": [64, 64, 64]},
#                    catalog_class=PPOCatalog
#                ),
#                "predator": SingleAgentRLModuleSpec(
#                    module_class=PPOTorchRLModule,
#                    observation_space=env.observation_space,
#                    action_space=env.action_space,
#                    model_config_dict={"fcnet_hiddens": [64, 64, 64]},
#                    catalog_class=PPOCatalog
#                ),
#            }
#        ),
    )
    .rollouts(
        rollout_fragment_length="auto",  # explained here : https://docs.ray.io/en/latest/rllib/rllib-sample-collection.html
        batch_mode='truncate_episodes',
        num_rollout_workers=num_cpus - num_learner - 1,  # one CPU kept for the driver/local worker
        num_envs_per_worker=2,
        #create_env_on_local_worker=False,
    )
    .resources(
        #num_gpus = num_gpus,
        #num_gpus_per_worker=0,
        #num_cpus_per_worker=2,
        # learner workers when using learner api - doesn't work on arm (mac) yet
        #num_learner_workers=num_learner,
        #num_gpus_per_learner_worker=1, # always 1 for PPO
        #num_cpus_per_learner_worker=1,
    )
    .checkpointing(export_native_model_files=True)
)


# Training

In [None]:
from ray.rllib.policy.policy import Policy
from ray.rllib.algorithms.callbacks import DefaultCallbacks

# Set to an algorithm checkpoint directory (e.g. "<ray_results>/PPO_.../checkpoint_000134")
# to warm-start both policies from a previous run; None trains from scratch.
path_to_checkpoint = None


def restore_weights(path, policy_type):
    """Load one policy's weights from an RLlib algorithm checkpoint.

    Args:
        path: Algorithm checkpoint directory.
        policy_type: Policy id (e.g. "prey" or "predator"); its policy checkpoint
            is read from ``<path>/policies/<policy_type>``.

    Returns:
        The result of ``Policy.get_weights()`` for the restored policy.
    """
    checkpoint_path = os.path.join(path, "policies", policy_type)
    restored_policy = Policy.from_checkpoint(checkpoint_path)
    return restored_policy.get_weights()


if path_to_checkpoint is not None:
    class RestoreWeightsCallback(DefaultCallbacks):
        """Replace freshly initialized policy weights with those from `path_to_checkpoint`."""

        def on_algorithm_init(self, *, algorithm: "Algorithm", **kwargs) -> None:
            # Set both policies in a single call.
            # NOTE(review): depending on the Ray version, Algorithm.set_weights may
            # update only the local worker, leaving remote rollout workers on their
            # initial weights until the next weight sync — confirm.
            algorithm.set_weights({
                policy_id: restore_weights(path_to_checkpoint, policy_id)
                for policy_id in ("predator", "prey")
            })

    # Registering the callback on the config is all the launch cell needs.
    config.callbacks(RestoreWeightsCallback)



## Launch training

In [None]:
import os

from ray import train, tune
from ray.tune import Tuner
from ray.air.integrations.wandb import WandbLoggerCallback
from ray.rllib.utils.test_utils import check_learning_achieved

ray.init(
    num_cpus=num_cpus,
    num_gpus=num_gpus,
)

try:
    # Stop criterion: end the experiment after this many training iterations.
    stop = {
        "training_iteration": 2000,
        #"timesteps_total": 200000000,
    }

    # Weights & Biases logging. The API key comes from the WANDB_API_KEY
    # environment variable instead of being committed in the notebook; when it is
    # unset (None) WandbLoggerCallback falls back to its default key lookup.
    # NOTE(review): the key previously hardcoded here should be revoked/rotated.
    callbacks = [WandbLoggerCallback(
        project="marl-rllib",
        group="PPO",
        api_key=os.environ.get("WANDB_API_KEY"),
        log_config=True,
    )]

    # Checkpoint every 10 training iterations and once more at the end.
    checkpoint_config = train.CheckpointConfig(
        checkpoint_at_end=True,
        checkpoint_frequency=10,
    )

    # A single launch path for both fresh and warm-started runs: when
    # `path_to_checkpoint` is set, RestoreWeightsCallback is already registered on
    # `config` via `config.callbacks(...)`, so nothing extra is needed here. (It
    # is an RLlib callback class and must not go into Tune's `callbacks` list.)
    tuner = Tuner(
        ALGO,                  # Defined before
        param_space=config,    # Defined before
        run_config=train.RunConfig(
            stop=stop,
            verbose=3,
            callbacks=callbacks,
            checkpoint_config=checkpoint_config,
        ),
    )
    # Run the experiment
    results = tuner.fit()
finally:
    # Always release the Ray runtime, so re-running this cell after a failure
    # does not hit a "ray.init called twice" error.
    ray.shutdown()
