### Some preliminary checks

In [None]:
import os

# Ray de-duplicates repeated log lines coming from its worker processes.
# Uncomment the assignment below to see every line. It is placed *before*
# `import ray` because Ray may read this variable when it is imported
# (NOTE(review): confirm for the installed Ray version); setting it after the
# import may have no effect.
# os.environ["RAY_DEDUP_LOGS"] = "0"

import torch
import ray

print("Ray version :", ray.__version__)
print("PyTorch Version:", torch.__version__)
print("CUDA Available:", torch.cuda.is_available())
print("MPS Available:", torch.backends.mps.is_available())

# Display the backend names that `torch.compile` accepts. `torch._dynamo` is a
# private module, so this call and its output may change between releases.
torch._dynamo.list_backends()

### To modify: number of CPUs and GPUs allocated to Ray

In [None]:
import psutil

import torch  # also imported in the first cell; repeated so this cell stands alone

# Logical CPU count reported by psutil; psutil returns None when it cannot
# determine it, so call it once and reuse the value.
available_cpus = psutil.cpu_count()
print("Number of CPUs: ", available_cpus)

# Resources handed to `ray.init` in the training cell -- edit for your machine.
num_cpus = 12
num_gpus = 0

# Only compare when a count is known; comparing against None would raise a
# confusing TypeError instead of a useful message.
if available_cpus is not None:
    assert num_cpus <= available_cpus, (
        f"num_cpus={num_cpus} exceeds the {available_cpus} CPUs reported by psutil"
    )

# Same sanity check for GPUs: Ray should not be told about more CUDA devices
# than PyTorch can see (0 on CPU-only or MPS-only machines).
visible_gpus = torch.cuda.device_count()
assert num_gpus <= visible_gpus, (
    f"num_gpus={num_gpus} exceeds the {visible_gpus} CUDA devices visible to PyTorch"
)

# Environment and algorithm configuration

In [None]:
from ray.rllib.policy.policy import PolicySpec

from ray.tune.registry import get_trainable_cls

from importlib import reload
import particle_2d_env
reload(particle_2d_env)
from particle_2d_env import Particle2dEnvironment
from particle_2d_env import MetricsCallbacks, RenderingCallbacks
from config import run_config

# Algorithm name, resolved to its trainable/config class via get_trainable_cls.
ALGO = "PPO"
# Deep-learning framework for RLlib; this notebook always uses "torch".
FRAMEWORK = "torch"
# Driver-side environment instance. The config below reads its observation /
# action spaces, and the policy-mapping logic reads its agents.
env = Particle2dEnvironment(run_config["env"])

def create_callbacks():
    """Return a new list holding one instance of each project callback class.

    The list is ``[RenderingCallbacks(), MetricsCallbacks()]`` (rendering
    first), and fresh instances are constructed on every call.

    NOTE(review): this helper is not used in the visible cells, and RLlib's
    ``.callbacks(...)`` setter is typically given callback *classes* rather
    than instances -- confirm before wiring this in.
    """
    rendering = RenderingCallbacks()
    metrics = MetricsCallbacks()
    return [rendering, metrics]

def map_agent_to_policy(agent_id, *args, **kwargs):
    """Return the policy ID ("prey" or "predator") that controls ``agent_id``.

    An agent whose ``agent_type`` is 0 in the notebook's ``env`` instance maps
    to "prey"; every other agent maps to "predator". Any extra positional or
    keyword arguments RLlib passes (e.g. the episode) are ignored.

    NOTE(review): the lookup uses the single driver-side ``env`` created above,
    not the env inside each env runner. This is only correct if agent IDs and
    their types are identical across env instances (e.g. not re-drawn on
    reset) -- confirm in particle_2d_env.
    """
    return "prey" if env.particule_agents[agent_id].agent_type == 0 else "predator"


config = (
    get_trainable_cls(ALGO).get_default_config()
    .environment(Particle2dEnvironment, env_config=run_config["env"])
    .framework(FRAMEWORK)
    # New API stack: RLModule + Learner, EnvRunner + ConnectorV2.
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    # Uncomment to attach the project's rendering callbacks:
    # .callbacks(RenderingCallbacks)
    # Learner hyperparameters.
    .training(
        # Discount factor. Both policies below also request gamma=0.85 through
        # PolicySpec.config, but per-policy config overrides are an old-API-stack
        # mechanism and may not be applied on the new stack enabled above, so
        # the same value is set globally here as well.
        gamma=0.85,
        num_epochs=10,
        # Per learner; the effective batch is num_learners times this value.
        train_batch_size_per_learner=512,
    )
    .rl_module(
        model_config={
            "fcnet_hiddens": [128, 128, 128],
            # NOTE(review): `use_attention` is an old-stack (ModelV2) option;
            # confirm the default new-stack RLModule actually honors it.
            "use_attention": True,
        },
    )
    .multi_agent(
        # Two policies with identical spaces and settings. Assumes the env's
        # spaces are indexable by agent 0 and shared by all agents -- TODO
        # confirm against Particle2dEnvironment.
        policies={
            policy_id: PolicySpec(
                policy_class=None,  # infer automatically from the Algorithm
                observation_space=env.observation_space[0],  # None would infer from env
                action_space=env.action_space[0],  # None would infer from env
                config={"gamma": 0.85},  # main config plus this override
            )
            for policy_id in ("prey", "predator")
        },
        policy_mapping_fn=map_agent_to_policy,
        policies_to_train=["prey", "predator"],
        count_steps_by="agent_steps",
    )
    .learners(
        num_learners=4,
        num_cpus_per_learner=1,  # default: 1
        num_gpus_per_learner=0,  # default: 0
    )
    .resources(num_cpus_for_main_process=1)  # default: 1
    .env_runners(
        # "auto" for PPO, explained here:
        # https://docs.ray.io/en/latest/rllib/rllib-sample-collection.html
        rollout_fragment_length="auto",
        batch_mode="truncate_episodes",
        num_env_runners=1,  # IMPALA needs 2; 1 is enough for PPO
        num_envs_per_env_runner=1,
    )
    .checkpointing(export_native_model_files=True)
)


# Training

In [None]:
# Name of a previous run folder inside `ray_results` to resume, e.g.
# "PPO_2024-12-18_20-12-15". Leave as None to start a new training run.
checkpoint_folder = None

## Launch training

In [None]:
from ray import train, tune
from ray.air.integrations.wandb import WandbLoggerCallback
import os

# Fail fast, before starting Ray, if the Weights & Biases key file is missing.
# Resolved now to an absolute path relative to the notebook's working directory.
wandb_api_key_file = os.path.abspath("wandb_api_key.txt")
if not os.path.isfile(wandb_api_key_file):
    raise FileNotFoundError(f"W&B API key file not found: {wandb_api_key_file}")

ray.init(
    num_cpus=num_cpus,
    num_gpus=num_gpus,
    # Re-running this cell while Ray is already up does not raise; Ray is then
    # not restarted, so changed resource settings would not take effect.
    ignore_reinit_error=True,
)

# Log to Weights & Biases. The callback reads the key from the file itself
# (`api_key_file`) instead of the key being read into a notebook variable.
callbacks = [
    WandbLoggerCallback(
        project="marl-rllib",
        group=ALGO,
        api_key_file=wandb_api_key_file,
        log_config=True,
        upload_checkpoints=True,
    ),
]

# Results and checkpoints go to <current working directory>/ray_results.
storage_path = os.path.join(os.getcwd(), "ray_results")

try:
    if checkpoint_folder is None:
        # Fresh run using the ALGO and config defined earlier.
        tuner = tune.Tuner(
            trainable=ALGO,
            param_space=config,
            run_config=train.RunConfig(
                storage_path=storage_path,
                stop={"training_iteration": 1500},
                verbose=3,
                callbacks=callbacks,
                checkpoint_config=train.CheckpointConfig(
                    checkpoint_at_end=True,
                    checkpoint_frequency=10,  # every 10 training iterations
                ),
            ),
        )
    else:
        # Resume a previous run that was interrupted or failed: unfinished and
        # errored trials are rescheduled from their latest checkpoints
        # (restart_errored=False keeps those checkpoints).
        path = os.path.join(storage_path, checkpoint_folder)
        tuner = tune.Tuner.restore(
            trainable=ALGO,
            path=path,
            resume_unfinished=True,
            resume_errored=True,
            restart_errored=False,
        )

    # Run the experiment.
    results = tuner.fit()
finally:
    # Always release Ray, even if setup or training raises or is interrupted,
    # so the next ray.init() in this kernel starts fresh with current settings.
    ray.shutdown()
