# Some preliminary checks

In [None]:
import torch
import tensorflow as tf
import os

# Ask Ray not to de-duplicate repeated log lines (the variable is set before
# Ray is started further down in the notebook).
os.environ.update({"RAY_DEDUP_LOGS": "0"})

# PyTorch build and CUDA support: each value is read, then printed.
torch_version = torch.__version__
print("PyTorch Version:", torch_version)
cuda_available = torch.cuda.is_available()
print("CUDA Available:", cuda_available)
cuda_version = torch.version.cuda
print("CUDA Version:", cuda_version)

# Availability of PyTorch's MPS backend.
mps_available = torch.backends.mps.is_available()
print("MPS Available:", mps_available)

In [None]:
import psutil

# Show the CPU count reported by psutil (no GPU count is queried here).
detected_cpu_count = psutil.cpu_count()
print("Number of CPUs: ", detected_cpu_count)


# Resource budget shared by the following cells.
num_cpus, num_gpus, num_learner_workers = 21, 8, 2

# Training

## Environment and algorithm configuration

Some of the commented-out lines are preparatory work for using a future RLlib feature.

In [None]:
from ray.rllib.policy.policy import PolicySpec
#from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
#from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.tune.registry import get_trainable_cls

from custom_env import CustomEnvironment
from config import run_config

# Algorithm and deep-learning framework handed to RLlib
ALGO = "PPO"
FRAMEWORK = "torch"  # "tf2" or "torch"

# Environment instance built from the run configuration. It is read below for
# the agents' spaces and, inside the policy mapping, for each agent's type.
env = CustomEnvironment(run_config["env"])

# One PolicySpec per species, both built the same way:
#   - policy_class=None    -> inferred automatically from the Algorithm
#   - observation/action   -> taken from index 0 of the env's spaces
#                             (None would mean: infer automatically from env)
#   - config               -> main config plus this gamma override
policy_specs = {
    species: PolicySpec(
        policy_class=None,
        observation_space=env.observation_space[0],
        action_space=env.action_space[0],
        config={"gamma": 0.85},
    )
    for species in ("prey", "predator")
}

# Algorithm config for the tuner, assembled one section at a time
# (same calls, same order as a single builder chain).
config = get_trainable_cls(ALGO).get_default_config()
config = config.environment(CustomEnvironment, env_config=run_config["env"])
config = config.framework(FRAMEWORK)
config = config.training(
    # NOTE(review): passed as False here and as True in the second
    # .training() call below; presumably the later call takes precedence - confirm.
    _enable_learner_api=False,
    num_sgd_iter=5,
    sgd_minibatch_size=256,             # the batch size
    train_batch_size=32768,             # the number of steps collected
    model={"fcnet_hiddens": [64, 64, 64]},
    #lr=tune.grid_search([0.01, 0.001, 0.0001])
)
config = config.multi_agent(
    policies=policy_specs,
    # Agents whose agent_type is 0 use the "prey" policy, all others "predator".
    policy_mapping_fn = lambda id, *arg, **karg: "prey" if env.agents[id].agent_type == 0 else "predator",
    policies_to_train=["prey", "predator"]
)
#rl_module_api
config = config.rl_module(
    _enable_rl_module_api=True,
#        rl_module_spec=MultiAgentRLModuleSpec(
#            module_specs={
#                "prey": SingleAgentRLModuleSpec(
#                    module_class=PPOTorchRLModule,
#                    observation_space=env.observation_space,
#                    action_space=env.action_space,
#                    model_config_dict={"fcnet_hiddens": [64, 64, 64]},
#                    catalog_class=PPOCatalog
#                ),
#                "predator": SingleAgentRLModuleSpec(
#                    module_class=PPOTorchRLModule,
#                    observation_space=env.observation_space,
#                    action_space=env.action_space,
#                    model_config_dict={"fcnet_hiddens": [64, 64, 64]},
#                    catalog_class=PPOCatalog
#                ),
#            }
#        ),
)
config = config.rollouts(
    rollout_fragment_length="auto", # explained here : https://docs.ray.io/en/latest/rllib/rllib-sample-collection.html
    batch_mode= 'truncate_episodes',
    num_rollout_workers=num_cpus-1,
    num_envs_per_worker=1,
    #create_env_on_local_worker=False,
)
# learner_api
config = config.training(
    _enable_learner_api=True,
)
config = config.resources(
    #num_gpus = num_gpus,
    #num_gpus_per_worker=0,
    #num_cpus_per_worker=2,
    # learner workers when using learner api - doesn't work on arm (mac) yet
    # NOTE(review): this passes num_gpus, not the notebook's separate
    # num_learner_workers variable - confirm that is intended.
    num_learner_workers=num_gpus,
    num_gpus_per_learner_worker=1, # always 1 for PPO
    #num_cpus_per_learner_worker=1,
)
config = config.checkpointing(export_native_model_files=True)

## To load a previously trained policy

Optionally load a previously saved checkpoint and resume training from it. Leave `path_to_checkpoint` set to `None` to train from scratch.

In [None]:
from ray.rllib.policy.policy import Policy
from ray.rllib.algorithms.callbacks import DefaultCallbacks

# Path of the algorithm checkpoint to resume from; None leaves the restore
# callback below disabled.
path_to_checkpoint = None
def restore_weights(path_to_checkpoint, policy_type):
    """Return the weights saved for one policy inside an algorithm checkpoint.

    Args:
        path_to_checkpoint: Directory of the algorithm checkpoint.
        policy_type: Policy ID; its checkpoint is read from the
            ``policies/<policy_type>`` sub-directory.

    Returns:
        Whatever ``get_weights()`` returns for the restored policy.
    """
    policy_dir = os.path.join(path_to_checkpoint, f"policies/{policy_type}")
    return Policy.from_checkpoint(policy_dir).get_weights()

if path_to_checkpoint is not None:
    class RestoreWeightsCallback(DefaultCallbacks):
        """Seed the "predator" and "prey" policies with checkpointed weights.

        The weights are read inside ``on_algorithm_init`` instead of a custom
        ``__init__``: the base-class initialiser stays intact, and the
        checkpoints are only loaded where they are actually applied rather
        than every time the callback class is instantiated.
        """

        def on_algorithm_init(self, *, algorithm: "Algorithm", **kwargs) -> None:
            """Load both policies' weights and push them into ``algorithm``."""
            algorithm.set_weights({
                "predator": restore_weights(path_to_checkpoint, "predator"),
                "prey": restore_weights(path_to_checkpoint, "prey"),
            })

    # Register the callback class on the algorithm config built earlier.
    config.callbacks(RestoreWeightsCallback)



## Launch training

In [None]:
import ray 
from ray import train, tune
from ray.air.integrations.wandb import WandbLoggerCallback
from ray.rllib.utils.test_utils import check_learning_achieved

# Start Ray with the CPU/GPU budget defined earlier in the notebook.
ray.init(
    num_cpus=num_cpus, 
    num_gpus=num_gpus
)

print("num CPUS rays sees :", ray.cluster_resources().get('CPU', 0))
print("num GPUS rays sees :", ray.cluster_resources().get('GPU', 0))

# The Weights & Biases key is a secret and must never be written into the
# notebook (it would end up in version control and in shared copies).
# Export WANDB_API_KEY before starting the kernel. If it is unset, None is
# passed and the callback falls back on its own credential lookup.
wandb_api_key = os.environ.get("WANDB_API_KEY")

# Define experiment    
tuner = tune.Tuner(
    ALGO,                                                 # Defined before
    param_space=config,                                   # Defined before
    run_config=train.RunConfig(                           ## RUN CONFIG ##
        stop={                                            # Stop criteria
            "training_iteration": 5000,
            "timesteps_total": 200000,
        },
        verbose=3,
        callbacks=[WandbLoggerCallback(                   # To use Wandb
            project="marl-rllib", 
            group="PPO",
            api_key=wandb_api_key,
            log_config=True,
        )],
        checkpoint_config=train.CheckpointConfig(         # When to save the models 
            checkpoint_at_end=True,
            checkpoint_frequency=10
        ),
    ),
)

# Run the experiment; always release the Ray runtime, even when training
# fails, so the cell can be re-run without a stale Ray instance.
try:
    results = tuner.fit()
finally:
    ray.shutdown()


### Retrieve checkpoint

In [None]:
# Checkpoint attached to the result that Tune reports as the best one;
# leaving it as the last expression displays it in the cell output.
best_result = results.get_best_result()
best_checkpoint = best_result.checkpoint
best_checkpoint

In [None]:
import ray
from ray.rllib.algorithms.algorithm import Algorithm

path_to_checkpoint = best_checkpoint

# Rebuild the trained algorithm from the checkpoint (this does ray.init()).
algo = Algorithm.from_checkpoint(path_to_checkpoint)

# Policy IDs held by the restored algorithm's local worker.
local_policy_map = algo.workers.local_worker().policy_map
available_policy_ids = list(local_policy_map.keys())
print("Available Policy IDs:", available_policy_ids)

## Episode animation

In [None]:
import numpy as np

def process_observations(observation, agent_ids, truncation=None):
    """Append one time step to the module-level ``observations`` history.

    Args:
        observation: Mapping from agent ID to that agent's observation vector;
            entries 2, 3 and 4 are stored as x, y and heading.
        agent_ids: Iterable of every agent ID to record. Agents missing from
            ``observation`` are recorded with 0 for all three values.
        truncation: Optional mapping from agent ID to a truncated flag. When
            empty or None every agent is recorded as still in the game.

    Returns:
        The module-level ``observations`` dict, after appending one array per
        field ("loc_x", "loc_y", "heading", "still_in_the_game").
    """
    def gather(feature_index):
        # One value per agent, with a 0 placeholder for absent agents.
        return [
            observation[agent][feature_index] if agent in observation else 0
            for agent in agent_ids
        ]

    xs = gather(2)
    ys = gather(3)
    headings = gather(4)

    if not truncation:
        alive = [1 for _ in agent_ids]
    else:
        alive = [0 if truncation[agent] else 1 for agent in agent_ids]

    frame = (
        ("loc_x", xs),
        ("loc_y", ys),
        ("heading", headings),
        ("still_in_the_game", alive),
    )
    for field, values in frame:
        observations[field].append(np.array(values))

    return observations

# Per-step history of every agent's position, heading and in-game flag.
observations = {"loc_x": [], "loc_y": [], "heading": [], "still_in_the_game": []}

observation, _ = env.reset()
agent_ids = env._agent_ids
# process_observations returns the history dict itself, so keep it as one
# object (unpacking it into four names would only bind its key strings).
observations = process_observations(observation, agent_ids)
step_count = 1


# Roll the trained policies forward for at most 500 recorded steps.
while step_count < 500:
    # One action per agent present in the current observation, computed with
    # the policy matching that agent's type (0 -> "prey", otherwise "predator").
    actions = {
        key: algo.compute_single_action(
            value, policy_id="prey" if env.agents[key].agent_type == 0 else "predator"
        ) for key, value in observation.items()
    }
    
    observation, _, termination, truncation, _ = env.step(actions)
    
    observations = process_observations(observation, agent_ids, truncation)
    
    step_count += 1

    # RLlib multi-agent envs flag the end of an episode under the "__all__"
    # key; stop instead of stepping an environment that has already finished.
    if termination.get("__all__", False) or truncation.get("__all__", False):
        break

# Stack the per-step arrays and scale the positions by the stage size.
stage_size = env.stage_size
observations["loc_x"] = np.array(observations["loc_x"]) * stage_size
observations["loc_y"] = np.array(observations["loc_y"]) * stage_size
observations["heading"] = np.array(observations["heading"])
observations["still_in_the_game"] = np.array(observations["still_in_the_game"])

env.close()
ray.shutdown()


In [None]:
import importlib
import animation
# Re-execute the local animation module so that edits made to it since the
# first import are picked up without restarting the kernel. This must happen
# before the from-import below so that the refreshed function is bound.
importlib.reload(animation)

from animation import generate_animation_3d

# Build the animation from the recorded episode history at 20 frames per second.
ani = generate_animation_3d(observations, env, fps=20)

In [None]:
from IPython.display import HTML

# Render the animation inline as an HTML5 video.
# NOTE(review): presumably `ani` is a matplotlib animation, in which case
# to_html5_video() needs a working video writer such as ffmpeg - confirm.
HTML(ani.to_html5_video())

## Network visualization

In [None]:
# Weights of the first available policy, fetched once and reused below.
policy_weights = algo.get_policy(available_policy_ids[0]).get_weights()

# Keep every entry EXCEPT those whose key contains "_value_branch".
nn_weights = {
    name: tensor
    for name, tensor in policy_weights.items()
    if "_value_branch" not in name
}

nn_weights.keys()

In [None]:
import numpy as np

from graph_tool.all import *

def plot_mlp(neural_network):
    """Draw a multilayer perceptron as a left-to-right layered graph.

    Args:
        neural_network: Mapping from parameter name to array. The function
            reads '_hidden_layers.0._model.0.weight', every key containing
            both "_hidden" and ".bias" (one hidden column each), every key
            containing both "_hidden" and ".weight" except those containing
            "layers.0", and '_logits._model.0.weight' /
            '_logits._model.0.bias'. Weight arrays are transposed before
            being indexed as [source neuron][target neuron].

    Returns:
        None. The picture is produced by ``graph_draw``.
    """
    graph = Graph(directed=True)

    # Property maps: vertex labels, edge pen widths, vertex coordinates.
    labels = graph.new_vertex_property("string")
    widths = graph.new_edge_property("double")
    coords = graph.new_vertex_property("vector<double>")

    input_weights_key = '_hidden_layers.0._model.0.weight'

    # Largest first dimension among the weight arrays; every column is
    # vertically centred against this height.
    tallest = max(len(neural_network[name]) for name in neural_network if 'weight' in name)

    def add_column(size, x):
        """Create `size` vertices at horizontal position `x`, first one on top."""
        column = [graph.add_vertex() for _ in range(size)]
        base_y = (tallest - size) / 2
        for rank, vertex in enumerate(column):
            coords[vertex] = (x, base_y + size - 1 - rank)
        return column

    def link(source, target, value):
        """Add one edge and store `value` as its pen width."""
        # NOTE(review): raw weights (possibly negative) are used as pen
        # widths; confirm graph_draw renders them as intended.
        widths[graph.add_edge(source, target)] = value

    ## VERTICES ##
    # Input column: one vertex per row of the transposed first weight array.
    input_column = add_column(len(neural_network[input_weights_key].T), 0)

    # Hidden columns: one per hidden bias vector, spaced 20 units apart.
    hidden_bias_names = [name for name in neural_network if ".bias" in name and "_hidden" in name]
    hidden_columns = []
    x_position = 20
    for bias_name in hidden_bias_names:
        hidden_columns.append(add_column(len(neural_network[bias_name]), x_position))
        x_position += 20

    # Output column, placed one step to the right of the last hidden column.
    output_column = add_column(len(neural_network['_logits._model.0.bias']), x_position)

    ## EDGES ##
    # Input -> first hidden column.
    first_layer = neural_network[input_weights_key].T
    for i, source in enumerate(input_column):
        for j, target in enumerate(hidden_columns[0]):
            link(source, target, first_layer[i][j])

    # Hidden -> hidden: the k-th remaining weight array joins columns k and k+1.
    inner_weight_names = [
        name for name in neural_network
        if ".weight" in name and "_hidden" in name and "layers.0" not in name
    ]
    for k, weight_name in enumerate(inner_weight_names):
        layer = neural_network[weight_name].T
        for i, source in enumerate(hidden_columns[k]):
            for j, target in enumerate(hidden_columns[k + 1]):
                link(source, target, layer[i][j])

    # Last hidden column -> output, iterating output neurons in the outer loop.
    logits_layer = neural_network['_logits._model.0.weight'].T
    for j, target in enumerate(output_column):
        for i, source in enumerate(hidden_columns[-1]):
            link(source, target, logits_layer[i][j])

    ## LABELS ##
    tagged_columns = (
        [("I", input_column)]
        + [("H", column) for column in hidden_columns]
        + [("O", output_column)]
    )
    for tag, column in tagged_columns:
        for vertex in column:
            labels[vertex] = tag

    # Draw the graph
    graph_draw(
        graph,
        pos=coords,
        vertex_text=labels,
        edge_text=None,
        edge_pen_width=widths,
        vertex_size=15,
        vertex_font_size=10,
        edge_font_size=10,
        output_size=(800, 800),
    )
    
# Example usage: draw the network described by the nn_weights dictionary
plot_mlp(nn_weights)