# Example from RLlib

# Our Environment

In [None]:
import os

import ray
from ray import air, tune
from ray.rllib.utils.test_utils import check_learning_achieved
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.algorithms.ppo import PPOConfig

from custom_env import CustomEnvironment
from config import run_config



class Args:
    """Run settings for the PPO training script.

    Every parameter is keyword-only and defaults to the value the script
    previously hard-coded, so a bare ``Args()`` yields the same settings.

    Attributes:
        run: Name of the trainable handed to ``tune.Tuner``.
        framework: Deep-learning framework passed to the algorithm config,
            ``"tf2"`` or ``"torch"``.
        stop_iters: Value used for the ``training_iteration`` stop criterion.
        stop_timesteps: Value used for the ``timesteps_total`` stop criterion.
        stop_reward: Value used for the ``episode_reward_mean`` stop
            criterion and for the learning check.
        as_test: When true, the script verifies after training that
            ``stop_reward`` was achieved.
    """

    def __init__(
        self,
        *,
        run: str = "PPO",
        framework: str = "torch",
        stop_iters: int = 50,
        stop_timesteps: int = 100000,
        stop_reward: float = 0.1,
        as_test: bool = False,
    ) -> None:
        self.run = run
        self.framework = framework  # "tf2" or "torch"
        self.stop_iters = stop_iters
        self.stop_timesteps = stop_timesteps
        self.stop_reward = stop_reward
        self.as_test = as_test

args = Args()

# ignore_reinit_error lets this cell be re-run in a session where Ray is
# still up (e.g. an earlier run aborted before reaching ray.shutdown()).
ray.init(ignore_reinit_error=True)

# Driver-side environment instance: supplies the observation/action spaces
# for the policy specs and the agent table used by the policy mapping below.
env = CustomEnvironment(run_config["env"])


def prey_predator_policy(agent_id, *_args, **_kwargs):
    """Map an agent id to its policy id.

    Agents whose ``agent_type`` is 0 use the "prey" policy; every other
    agent uses "predator". Extra positional/keyword arguments supplied by
    the caller (episode, worker, ...) are accepted and ignored.
    """
    return "prey" if env.agents[agent_id].agent_type == 0 else "predator"


config = (
    PPOConfig()
    .rollouts(rollout_fragment_length="auto", num_rollout_workers=0)
    .environment(CustomEnvironment, env_config=run_config["env"])
    .framework(args.framework)
    .training(num_sgd_iter=10, sgd_minibatch_size=256, train_batch_size=4000)
    .multi_agent(
        # Both policies are configured identically, so build them from one
        # template; each policy still gets its own PolicySpec and config dict.
        # NOTE(review): both use the spaces at index 0 -- confirm predators
        # really share the prey's observation/action spaces.
        policies={
            policy_id: PolicySpec(
                policy_class=None,  # infer automatically from Algorithm
                observation_space=env.observation_space[0],  # if None infer automatically from env
                action_space=env.action_space[0],  # if None infer automatically from env
                config={"gamma": 0.85},  # use main config plus <- this override here
            )
            for policy_id in ("prey", "predator")
        },
        policy_mapping_fn=prey_predator_policy,
        policies_to_train=["prey", "predator"],
    )
    .rl_module(_enable_rl_module_api=True)
    .training(_enable_learner_api=True)
    .resources(num_gpus=0)
)


# Stop criteria handed to Tune; a trial ends as soon as any one of them
# is reached.
stop = {
    "training_iteration": args.stop_iters,
    "timesteps_total": args.stop_timesteps,
    "episode_reward_mean": args.stop_reward,
}

try:
    tuner = tune.Tuner(
        args.run,
        param_space=config.to_dict(),
        run_config=air.RunConfig(stop=stop, verbose=3),
    )
    results = tuner.fit()

    if args.as_test:
        print("Checking if learning goals were achieved")
        # Signals failure by raising when the reward target was missed,
        # which is why the shutdown below lives in a `finally`.
        check_learning_achieved(results, args.stop_reward)
finally:
    # Always release the Ray runtime, even when training or the learning
    # check fails; otherwise the session leaks into later cells.
    ray.shutdown()



# A mockup minimal environment

import os

from ray.rllib.algorithms.ppo import PPOConfig
from env import Environment
from ray.rllib.policy.policy import PolicySpec

# Driver-side instance, used only to read the observation/action spaces.
env = Environment({})


def mock_env_policy(agent_id, episode=None, worker=None, **kwargs):
    """Map an agent id to its policy id: "0" for agent "0", otherwise "1".

    ``episode`` and ``worker`` are optional and any further keyword
    arguments are ignored, so the function keeps working whether the caller
    passes them positionally, by keyword, or not at all.
    """
    # NOTE(review): the comparison is against the string "0" -- confirm the
    # environment emits string agent ids rather than ints.
    return "0" if agent_id == "0" else "1"


config = (
    PPOConfig()
    .rollouts(rollout_fragment_length="auto", num_rollout_workers=0)
    .environment(Environment)
    .framework("torch")
    .multi_agent(
        # The two policies are configured identically; build them from one
        # template so each still gets its own PolicySpec and config dict.
        policies={
            policy_id: PolicySpec(
                policy_class=None,  # infer automatically from Algorithm
                observation_space=env.observation_space,  # infer automatically from env
                action_space=env.action_space,  # infer automatically from env
                config={"gamma": 0.85},  # use main config plus <- this override here
            )
            for policy_id in ("0", "1")
        },
        policy_mapping_fn=mock_env_policy,
    )
    .resources(num_gpus=0)
)

# Build the algorithm from the config and run a single training iteration.
my_ma_algo = config.build()
my_ma_algo.train()
