# Some preliminary checks

In [None]:
import torch
import tensorflow as tf
import os

# Set RAY_DEDUP_LOGS to "0" before Ray is imported/started so that the
# setting is already in the environment when Ray reads it.
os.environ["RAY_DEDUP_LOGS"] = "0"

# Sanity-check the PyTorch installation and its CUDA support.
print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
print(f"CUDA Version: {torch.version.cuda}")


# Training

In [None]:
import ray
from ray import air, tune
from ray.rllib.utils.test_utils import check_learning_achieved
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.algorithms.ppo import PPOConfig

from custom_env import CustomEnvironment
from config import run_config

from ray.air.integrations.wandb import WandbLoggerCallback

## The RLlib configuration
class Args:
    """Experiment settings shared by the training and retraining cells.

    Every setting is a keyword-only argument whose default reproduces the
    value that used to be hard-coded, so ``Args()`` behaves exactly as
    before while e.g. ``Args(stop_iters=50)`` overrides a single knob.

    Args:
        run: Name of the algorithm handed to Tune.
        framework: Deep-learning framework, "tf2" or "torch".
        stop_iters: Value used for the "training_iteration" stop criterion.
        stop_timesteps: Value used for the "timesteps_total" stop criterion.
        stop_reward: Value used for the "episode_reward_mean" stop criterion
            and as the target passed to ``check_learning_achieved``.
        as_test: When true, the run cells call ``check_learning_achieved``
            after training.
    """

    def __init__(
        self,
        *,
        run="PPO",
        framework="torch",
        stop_iters=5,
        stop_timesteps=20000,
        stop_reward=0.1,
        as_test=False,
    ):
        self.run = run
        self.framework = framework
        self.stop_iters = stop_iters
        self.stop_timesteps = stop_timesteps
        self.stop_reward = stop_reward
        self.as_test = as_test


args = Args()

## Generate the configuration
ray.init()

# Driver-side environment instance. It is used below to read the
# observation/action spaces and, inside the policy mapping function, to look
# up each agent's type.
env = CustomEnvironment(run_config["env"])


def _agent_policy_spec():
    """Return a fresh PolicySpec; prey and predator use identical settings."""
    return PolicySpec(
        policy_class=None,  # infer automatically from Algorithm
        observation_space=env.observation_space[0],  # if None infer automatically from env
        action_space=env.action_space[0],  # if None infer automatically from env
        config={"gamma": 0.85},  # use main config plus <- this override here
    )


def _policy_for_agent(agent_id, *_args, **_kwargs):
    """Map an agent to "prey" when its agent_type is 0, else to "predator"."""
    if env.agents[agent_id].agent_type == 0:
        return "prey"
    return "predator"


config = (
    PPOConfig()
    .environment(CustomEnvironment, env_config=run_config["env"])
    .framework(args.framework)
    .training(num_sgd_iter=10, sgd_minibatch_size=256, train_batch_size=4000)
    .multi_agent(
        policies={
            "prey": _agent_policy_spec(),
            "predator": _agent_policy_spec(),
        },
        policy_mapping_fn=_policy_for_agent,
        policies_to_train=["prey", "predator"],
    )
    .rl_module(_enable_rl_module_api=True)
    .training(_enable_learner_api=True)
    .rollouts(
        rollout_fragment_length=200,
        batch_mode="truncate_episodes",
        num_rollout_workers=3,
    )
    .resources(
        # Request every GPU the running Ray cluster reports (0 if none).
        num_gpus=ray.cluster_resources().get("GPU", 0),
        num_gpus_per_worker=0,
        num_cpus_per_worker=2,
        # learner workers
        num_learner_workers=2,
        num_gpus_per_learner_worker=0,
        num_cpus_per_learner_worker=2,
    )
    .checkpointing(export_native_model_files=True)
)



In [None]:
## Run the experiment
# The Weights & Biases API key must never be committed with the notebook.
# It is read from the WANDB_API_KEY environment variable instead; when the
# variable is unset, None is passed and the callback has to find credentials
# on its own (e.g. a prior `wandb login`).
# NOTE(review): the key that was previously hard-coded here has been exposed
# in version control and should be revoked/rotated.
wandb_api_key = os.environ.get("WANDB_API_KEY")

tuner = tune.Tuner(
    args.run,
    param_space=config.to_dict(),
    run_config=air.RunConfig(
        # Stop criteria, taken from the shared experiment settings.
        stop={
            "training_iteration": args.stop_iters,
            "timesteps_total": args.stop_timesteps,
            "episode_reward_mean": args.stop_reward,
        },
        verbose=3,
        callbacks=[
            WandbLoggerCallback(
                project="marl-rllib",
                api_key=wandb_api_key,
            )
        ],
        # Always write a checkpoint when training ends; the later cells
        # restore the algorithm and the policy weights from it.
        checkpoint_config=air.CheckpointConfig(
            checkpoint_at_end=True,
        ),
    ),
)
results = tuner.fit()

if args.as_test:
    print("Checking if learning goals were achieved")
    check_learning_achieved(results, args.stop_reward)
ray.shutdown()


# Render episode 

### Retrieve checkpoint

In [None]:
# Pick the best result of the finished run and keep its checkpoint; the bare
# expression on the last line makes the notebook display it.
best_result = results.get_best_result()
best_checkpoint = best_result.checkpoint
best_checkpoint

In [None]:
from ray.rllib.algorithms.algorithm import Algorithm

# Rebuild the trained algorithm from the checkpoint selected above.
algo = Algorithm.from_checkpoint(best_checkpoint)

# After loading the algorithm, list the policy IDs held by the local worker.
local_worker = algo.workers.local_worker()
available_policy_ids = [*local_worker.policy_map.keys()]
print(f"Available Policy IDs: {available_policy_ids}")

### Run and plot

In [None]:
import numpy as np

def process_observations(observation, agent_ids, truncation=None):
    """Append one frame of agent positions to the module-level ``observations``.

    For every agent in ``agent_ids`` the elements at index 4 and 5 of its
    observation are recorded as ``loc_x`` / ``loc_y``; an agent that has no
    entry in ``observation`` is recorded at 0 for both.

    Args:
        observation: Mapping from agent id to that agent's observation.
        agent_ids: Iterable of all agent ids; fixes the order of each frame.
        truncation: Optional mapping from agent id to a truthy "truncated"
            flag. When it is missing or empty every agent counts as still
            in the game; otherwise it is indexed for every id in
            ``agent_ids``.

    Returns:
        The module-level ``observations`` dict, after the new frame has been
        appended to its "loc_x", "loc_y" and "still_in_the_game" lists.
    """
    xs, ys = [], []
    for agent in agent_ids:
        if agent in observation:
            xs.append(observation[agent][4])
            ys.append(observation[agent][5])
        else:
            # Agent absent from this step's observation: placeholder position.
            xs.append(0)
            ys.append(0)

    if not truncation:
        alive = [1 for _ in agent_ids]
    else:
        alive = [0 if truncation[agent] else 1 for agent in agent_ids]

    # NOTE: `observations` is not a parameter; it is the dict defined at
    # module level by the calling cell.
    for field, frame in (
        ("loc_x", xs),
        ("loc_y", ys),
        ("still_in_the_game", alive),
    ):
        observations[field].append(np.array(frame))

    return observations

# First policy ID reported by the restored algorithm. The rollout below does
# not use it: each agent's policy is selected from its agent type instead.
policy_id = available_policy_ids[0]

max_steps = 500
step_count = 0
# Per-frame history filled in by process_observations().
observations = {"loc_x": [], "loc_y": [], "still_in_the_game": []}

observation, _ = env.reset()
agent_ids = env._agent_ids
# Record the initial frame. process_observations() returns the history dict
# itself, so it is assigned back exactly as in the loop below (unpacking the
# dict into three names would only yield its key strings).
observations = process_observations(observation, agent_ids)

# NOTE(review): the loop always runs for `max_steps` steps and never looks at
# `termination`; confirm the environment tolerates being stepped after an
# episode has ended.
while step_count < max_steps:
    # One action per agent present in the latest observation, computed with
    # the policy matching the agent's type (0 -> "prey", otherwise "predator").
    actions = {
        agent_id: algo.compute_single_action(
            agent_obs,
            policy_id="prey" if env.agents[agent_id].agent_type == 0 else "predator",
        )
        for agent_id, agent_obs in observation.items()
    }

    observation, _, termination, truncation, _ = env.step(actions)

    observations = process_observations(observation, agent_ids, truncation)

    step_count += 1

# Stack the per-frame arrays and scale the recorded positions by the stage
# size before handing them to the animation code.
stage_size = env.stage_size
observations["loc_x"] = np.array(observations["loc_x"]) * stage_size
observations["loc_y"] = np.array(observations["loc_y"]) * stage_size
observations["still_in_the_game"] = np.array(observations["still_in_the_game"])

env.close()

In [None]:
import importlib

import animation

# Reload the local `animation` module so edits made to it while the notebook
# is running are picked up, then build the animation from the recorded frames.
animation = importlib.reload(animation)
generate_animation = animation.generate_animation

ani = generate_animation(observations, env)

In [None]:
from IPython.display import HTML

# Render the animation to an HTML5 video and display it as the cell output.
video_markup = ani.to_html5_video()
HTML(video_markup)

# Retrain

In [None]:
from ray.rllib.policy.policy import Policy
from ray.rllib.algorithms.callbacks import DefaultCallbacks

def restore_policy_and_weights(policy_type):
    """Load one policy from the best checkpoint and return its weights.

    Args:
        policy_type: Name of the policy sub-directory below ``policies/`` in
            the checkpoint (this notebook uses "prey" and "predator").

    Returns:
        Whatever ``get_weights()`` returns for the restored policy.
    """
    # Reads the module-level `best_checkpoint`; every call materialises the
    # checkpoint through `to_directory()` again.
    checkpoint_root = best_checkpoint.to_directory()
    policy_dir = os.path.join(checkpoint_root, f"policies/{policy_type}")
    return Policy.from_checkpoint(policy_dir).get_weights()

# Weights of both trained policies, read back from the best checkpoint.
restored_policy_predator_weights = restore_policy_and_weights("predator")
restored_policy_prey_weights = restore_policy_and_weights("prey")

print("Starting new tune.Tuner().fit()")

# Ray may already be running at this point (the training cell shuts it down,
# but restoring the algorithm for rendering can start it again). A plain
# ray.init() raises RuntimeError in that case, so tolerate re-initialisation.
ray.init(ignore_reinit_error=True)

# Start our actual experiment.
stop = {
    "episode_reward_mean": args.stop_reward,
    "timesteps_total": args.stop_timesteps,
    "training_iteration": args.stop_iters,
}

class RestoreWeightsCallback(DefaultCallbacks):
    """Callback that seeds a new algorithm with the previously trained weights."""

    def on_algorithm_init(self, *, algorithm: "Algorithm", **kwargs) -> None:
        """Push the restored predator and prey weights into ``algorithm``.

        The weights come from the module-level
        ``restored_policy_predator_weights`` and
        ``restored_policy_prey_weights``; they are applied one policy at a
        time, predator first.
        """
        restored_weights = (
            ("predator", restored_policy_predator_weights),
            ("prey", restored_policy_prey_weights),
        )
        for policy_name, weights in restored_weights:
            algorithm.set_weights({policy_name: weights})

# Register the weight-restoring callback on the existing config object (the
# call updates `config` in place; its return value is not needed here).
config.callbacks(RestoreWeightsCallback)

# Launch the second training run with the same stop criteria as the first.
retrain_param_space = config.to_dict()
results = tune.run(
    "PPO",
    config=retrain_param_space,
    stop=stop,
    verbose=1,
)

if args.as_test:
    check_learning_achieved(results, args.stop_reward)

ray.shutdown()