In [None]:
import torch
import tensorflow as tf
import os

# Ray reads RAY_DEDUP_LOGS from the environment; "0" is the documented way to
# turn off its log de-duplication. It is set in this first cell, before the
# later cell that imports ray.
os.environ.update(RAY_DEDUP_LOGS="0")

# Report the installed PyTorch build and whether it can see a CUDA device.
for _label, _value in (
    ("PyTorch Version:", torch.__version__),
    ("CUDA Available:", torch.cuda.is_available()),
    ("CUDA Version:", torch.version.cuda),
):
    print(_label, _value)


In [None]:
import ray
from ray import air, tune
from ray.rllib.utils.test_utils import check_learning_achieved
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.algorithms.ppo import PPOConfig

from custom_env import CustomEnvironment
from config import run_config

from ray.air.integrations.wandb import WandbLoggerCallback

## The RLlib configuration
class Args:
    """Run settings shared by the experiment cells of this notebook.

    All parameters are keyword-only and default to the values that were
    previously hard-coded, so ``Args()`` behaves exactly as before.

    Args:
        run: Trainable name handed to ``tune.Tuner`` (e.g. ``"PPO"``).
        framework: Deep-learning framework passed to ``.framework()``;
            ``"tf2"`` or ``"torch"``.
        stop_iters: Value used for the ``training_iteration`` stop criterion.
        stop_timesteps: Value used for the ``timesteps_total`` stop criterion.
        stop_reward: Value used for the ``episode_reward_mean`` stop criterion
            and as the target passed to ``check_learning_achieved``.
        as_test: When true, the run cells call ``check_learning_achieved``
            after training.
    """

    def __init__(
        self,
        *,
        run="PPO",
        framework="torch",
        stop_iters=5,
        stop_timesteps=20000,
        stop_reward=0.1,
        as_test=False,
    ):
        self.run = run
        self.framework = framework  # "tf2" or "torch"
        self.stop_iters = stop_iters
        self.stop_timesteps = stop_timesteps
        self.stop_reward = stop_reward
        self.as_test = as_test

args = Args()

## Generate the configuration
# ignore_reinit_error makes this cell safe to re-run in a live kernel: a plain
# ray.init() raises if Ray is already initialized.
ray.init(ignore_reinit_error=True)

# Driver-side environment instance. It is only used here to read the spaces
# for the PolicySpecs below and by policy_mapping_fn to look up agent types.
env = CustomEnvironment(run_config["env"])

config = (
    PPOConfig()
    .environment(CustomEnvironment, env_config=run_config["env"])
    .framework(args.framework)
    .training(num_sgd_iter=10, sgd_minibatch_size=256, train_batch_size=4000)
    .multi_agent(
        policies={
            "prey": PolicySpec(
                policy_class=None,  # infer automatically from Algorithm
                observation_space=env.observation_space[0],  # if None infer automatically from env
                action_space=env.action_space[0],  # if None infer automatically from env
                config={"gamma": 0.85},  # use main config plus <- this override here
            ),
            # NOTE(review): the predator spec reuses the index-0 spaces, same
            # as prey. Confirm both agent types really share these spaces.
            "predator": PolicySpec(
                policy_class=None,
                observation_space=env.observation_space[0],
                action_space=env.action_space[0],
                config={"gamma": 0.85},
            ),
        },
        # agent_type == 0 -> "prey", anything else -> "predator".
        # The first parameter is named agent_id (not `id`) so the builtin is
        # not shadowed; RLlib supplies it as the first positional argument.
        # NOTE(review): this looks up the agent in the driver-side `env`
        # created above, not in the env the worker is stepping. Presumably
        # agent types are fixed per agent id - confirm.
        policy_mapping_fn=lambda agent_id, *arg, **karg: (
            "prey" if env.agents[agent_id].agent_type == 0 else "predator"
        ),
        policies_to_train=["prey", "predator"],
    )
    .rl_module(_enable_rl_module_api=True)
    .training(_enable_learner_api=True)
    .rollouts(
        rollout_fragment_length=200,
        batch_mode='truncate_episodes',
        num_rollout_workers=3,
    )
    .resources(
        # Use every GPU Ray reports for the cluster, or 0 if there is none.
        num_gpus=ray.cluster_resources().get('GPU', 0),
        num_gpus_per_worker=0,
        num_cpus_per_worker=3,
        # learner workers
        num_learner_workers=3,
        num_gpus_per_learner_worker=0,
        num_cpus_per_learner_worker=3,
    )
    .checkpointing(export_native_model_files=True)
)



In [None]:
## Run the experiment
# Tune stops the trial as soon as ANY of the stop criteria below is reached.
tuner = tune.Tuner(
    args.run,
    param_space=config.to_dict(),
    run_config=air.RunConfig(
        stop={
            "training_iteration": args.stop_iters,
            "timesteps_total": args.stop_timesteps,
            "episode_reward_mean": args.stop_reward,
        },
        verbose=3,
        callbacks=[WandbLoggerCallback(
            project="marl-rllib",
            # Secrets must not live in the notebook: read the Weights & Biases
            # key from the environment (export WANDB_API_KEY before starting
            # the kernel). Any key that was previously committed here should
            # be revoked and regenerated.
            api_key=os.environ.get("WANDB_API_KEY"),
        )],
        checkpoint_config=air.CheckpointConfig(
            checkpoint_at_end=True,
        ),
    ),
)
results = tuner.fit()

if args.as_test:
    print("Checking if learning goals were achieved")
    check_learning_achieved(results, args.stop_reward)
# Ray is shut down here; the restore cell further down starts it again.
ray.shutdown()


In [None]:
# Name the ranking metric explicitly: get_best_result() without metric/mode
# only works while the experiment has a single trial (no default metric is
# configured on the Tuner in this notebook).
best_checkpoint = results.get_best_result(
    metric="episode_reward_mean", mode="max"
).checkpoint
best_checkpoint

In [None]:
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.policy.policy import Policy

# Id of the policy whose trained weights are carried into the second run and
# then frozen. It must be one of the ids declared in config.multi_agent(...),
# i.e. "prey" or "predator"; there is no policy called "policy_0" here.
# NOTE(review): "prey" (the first declared policy) is assumed to be the one to
# restore - switch to "predator" if that was the intent.
RESTORED_POLICY_ID = "prey"

# Ray was shut down at the end of the first run; ignore_reinit_error keeps this
# cell re-runnable if Ray is already up.
ray.init(ignore_reinit_error=True)

# The "_0" in the variable names below is historical; the policy actually
# restored is RESTORED_POLICY_ID.
policy_0_checkpoint = os.path.join(
    best_checkpoint.to_directory(), "policies", RESTORED_POLICY_ID
)
restored_policy_0 = Policy.from_checkpoint(policy_0_checkpoint)
restored_policy_0_weights = restored_policy_0.get_weights()
print("Starting new tune.Tuner().fit()")

# Start our actual experiment.
stop = {
    "episode_reward_mean": args.stop_reward,
    "timesteps_total": args.stop_timesteps,
    "training_iteration": args.stop_iters,
}


class RestoreWeightsCallback(DefaultCallbacks):
    """Load the previously trained weights into the restored policy on init."""

    def on_algorithm_init(self, *, algorithm: "Algorithm", **kwargs) -> None:
        # Overwrite only the restored policy; the other policies keep their
        # fresh initial weights.
        algorithm.set_weights({RESTORED_POLICY_ID: restored_policy_0_weights})


# Freeze the restored policy: train every configured policy except it.
policy_ids = list(config.policies)
config.policies_to_train = [pid for pid in policy_ids if pid != RESTORED_POLICY_ID]
config.callbacks(RestoreWeightsCallback)

results = tune.run(
    "PPO",
    stop=stop,
    config=config.to_dict(),
    verbose=1,
)

if args.as_test:
    check_learning_achieved(results, args.stop_reward)

ray.shutdown()