In [None]:
import numpy as np

from tqdm import tqdm

from gymnasium.spaces import Box, Space

from tabulate import tabulate

from poke_env.environment.abstract_battle import AbstractBattle
from poke_env.data import GenData
# Gen 9 type chart; passed as `type_chart=` when SimpleRLPlayer.embed_battle computes move damage multipliers.
type_chart = GenData(9).type_chart

from poke_env.player import (
    Gen9EnvSinglePlayer,
    MaxBasePowerPlayer,
    ObsType,
    RandomPlayer,
    SimpleHeuristicsPlayer,
    background_cross_evaluate,
    background_evaluate_player,
)

In [None]:
class SimpleRLPlayer(Gen9EnvSinglePlayer):
    """Gen 9 single-battle environment with a 10-dimensional float32 observation.

    Observation layout (see ``embed_battle`` / ``describe_embedding``):
      * [0:4]  base power / 100 of each available move, -1 for empty slots;
      * [4:8]  type damage multiplier of each move vs. the opponent's active
               Pokemon, 1 for empty slots or moves without a type;
      * [8]    fraction (out of 6) of our team that has fainted;
      * [9]    fraction (out of 6) of the opponent's team that has fainted.

    ``env_config`` holds the keyword arguments forwarded to
    ``Gen9EnvSinglePlayer``. If it contains ``'opponent'``, that value is treated
    as a player class and instantiated here with the same ``battle_format``.
    """

    def __init__(self, env_config):
        # Work on a copy so the caller's config (e.g. shared by RLlib) is untouched.
        self.env_config = env_config.copy()
        opponent_class = self.env_config.get('opponent')
        # TODO: add full opponent config to env_config. For now the opponent
        # class only receives the same battle format as this environment.
        if opponent_class is not None:
            battle_format = env_config.get('battle_format')
            self.env_config['opponent'] = opponent_class(battle_format=battle_format)

        super().__init__(**self.env_config)

    def calc_reward(self, last_battle, current_battle) -> float:
        """Delegate to ``reward_computing_helper`` with weights fainted=2, hp=1, victory=30."""
        return self.reward_computing_helper(
            current_battle,
            victory_value=30.0,
            fainted_value=2.0,
            hp_value=1.0,
        )

    def embed_battle(self, battle: AbstractBattle) -> ObsType:
        """Encode ``battle`` as the 10-component float32 vector described on the class."""
        opponent = battle.opponent_active_pokemon

        # -1 marks a slot with no available move.
        base_powers = np.full(4, -1.0)
        multipliers = np.ones(4)
        for slot, move in enumerate(battle.available_moves):
            base_powers[slot] = move.base_power / 100  # simple rescaling to ease learning
            if move.type:
                multipliers[slot] = move.type.damage_multiplier(
                    opponent.type_1,
                    opponent.type_2,
                    type_chart=type_chart,
                )

        # Fainted fractions assume six Pokemon per side.
        team_fainted = sum(1 for mon in battle.team.values() if mon.fainted) / 6
        opponent_fainted = sum(1 for mon in battle.opponent_team.values() if mon.fainted) / 6

        observation = np.concatenate(
            [base_powers, multipliers, [team_fainted, opponent_fainted]]
        )
        return observation.astype(np.float32)

    def describe_embedding(self) -> Space:
        """Box bounds: base power in [-1, 3], multipliers in [0, 4], fainted fractions in [0, 1]."""
        low = np.array([-1] * 4 + [0] * 4 + [0, 0], dtype=np.float32)
        high = np.array([3] * 4 + [4] * 4 + [1, 1], dtype=np.float32)
        return Box(low, high, dtype=np.float32)

In [None]:
## Create config
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.algorithms.algorithm import Algorithm


# Keyword arguments handed to each SimpleRLPlayer environment during training.
# 'opponent' is a player *class*; SimpleRLPlayer instantiates it with the same format.
train_env_config = {
    'opponent': RandomPlayer,
    'battle_format': "gen9randombattle",
    'start_challenging': True,
}
# Keyword arguments handed to each SimpleRLPlayer environment during evaluation.
eval_env_config = {
    'opponent': RandomPlayer,
    'battle_format': "gen9randombattle",
    'start_challenging': True,
}

# Each AlgorithmConfig setter returns the config itself, so the calls chain.
config = (
    DQNConfig()
    .environment(
        env=SimpleRLPlayer,
        env_config=train_env_config,
        disable_env_checking=True,
    )
    .training(
        # train_batch_size_per_learner=1,
        replay_buffer_config={
            "type": "MultiAgentPrioritizedReplayBuffer",
            "capacity": 10000,
        },
        num_steps_sampled_before_learning_starts=500,
        # Double Q-learning on; n-step returns, distributional atoms,
        # noisy nets and dueling heads left off.
        n_step=1,
        double_q=True,
        num_atoms=1,
        noisy=False,
        dueling=False,
    )
    # .learners(
    #     num_learners=1, # set to number of GPUs
    #     num_gpus_per_learner=1, # set to 1, if have at least 1 gpu
    # )
    # .env_runners()
    # .rollouts()
    .evaluation(
        evaluation_interval=5,  # evaluate every 5 training iterations
        # evaluation_num_workers=1,
        # evaluation_parallel_to_training=True,
        # evaluation_duration="auto",
        evaluation_config={
            "env_config": eval_env_config,
            "metrics_num_episodes_for_smoothing": 4,
            # Set explore True for policy gradient algorithms
            "explore": False,
        },
    )
)

# Tune stopping criterion: end the run after 50 training iterations.
stop = {"training_iteration": 50}

In [None]:
# Auto hyperparameter tuning
from ray import tune, train
import os

# Results (logs + checkpoints) go to a `results` directory in the parent of the CWD.
results_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir, 'results'))

checkpoint_config = train.CheckpointConfig(
    checkpoint_frequency=1,   # checkpoint every training iteration
    num_to_keep=10,           # retain only the 10 most recent checkpoints
    checkpoint_at_end=True,
)
run_config = train.RunConfig(
    name="DQN_SimpleRL_Test",
    storage_path=results_dir,
    stop=stop,
    checkpoint_config=checkpoint_config,
)

tuner = tune.Tuner(config.algo_class, param_space=config, run_config=run_config)
analysis = tuner.fit()

In [None]:
# Pick the trial with the highest mean training-episode return.
# NOTE(review): `Result.checkpoint` is presumably that trial's most recent
# checkpoint, not necessarily its best-scoring one — confirm this is intended.
best_result = analysis.get_best_result(
    metric="env_runners/episode_return_mean",
    mode="max",
)
eval_checkpoint = best_result.checkpoint

In [None]:
## Evaluate algorithm
n_battles = 100
# Restore the model checkpoint in eval_checkpoint
# TODO: stop .from_checkpoint from running env verification, it creates envs and does not close them properly.
# Alternatively, make a callback that closes them.
eval_alg = Algorithm.from_checkpoint(eval_checkpoint)


def evaluate_against(algorithm, opponent_class, num_battles, battle_format="gen9randombattle"):
    """Play ``num_battles`` episodes of ``algorithm`` against ``opponent_class``.

    A dedicated SimpleRLPlayer environment is created for the run and is always
    shut down with ``reset_env(restart=False)`` afterwards, even if a battle
    raises, so no challenge loop is left running in the background.

    Args:
        algorithm: restored RLlib algorithm; ``compute_single_action`` picks each action.
        opponent_class: poke_env player class to battle against.
        num_battles: number of episodes to play.
        battle_format: Showdown battle format for both players.

    Returns:
        ``(n_won_battles, n_finished_battles)`` as reported by the environment.
    """
    env = SimpleRLPlayer(env_config={
        'opponent': opponent_class,
        'battle_format': battle_format,
        'start_challenging': True,
    })
    try:
        for _ in tqdm(range(num_battles), leave=False):
            obs, info = env.reset()
            terminated = truncated = False
            # Keep stepping until the episode ends or the battle is absent/finished.
            while not (
                terminated
                or truncated
                or env.current_battle is None
                or env.current_battle.finished
            ):
                action = algorithm.compute_single_action(obs)
                obs, reward, terminated, truncated, info = env.step(action)
        # Read the counters before the cleanup below, in case it resets them.
        return env.n_won_battles, env.n_finished_battles
    finally:
        env.reset_env(restart=False)


# Evaluate against each baseline; every run closes its own environment.
for opponent_class, header in [
    (RandomPlayer, "\n\nResults against random player:"),
    (MaxBasePowerPlayer, "\nResults against Max Base Power player:"),
]:
    n_won, n_finished = evaluate_against(eval_alg, opponent_class, n_battles)
    print(header)
    print(f"DQN Evaluation: {n_won} victories out of {n_finished} episodes")

In [None]:
# Evaluate the player with included util method
# Random player
# eval_env_config = {
#     'opponent': RandomPlayer,
#     'battle_format': "gen9randombattle",
#     'start_challenging': True,
# }
# eval_env = SimpleRLPlayer(eval_env_config)


# n_challenges = 10
# placement_battles = 4
# eval_task = background_evaluate_player(
#     eval_env.agent, n_challenges, placement_battles
# )

# # Make a new config for evaluation from the train config
# eval_config = alg.config.copy(copy_frozen=False)
# # Increase duration of evaluation
# eval_config.evaluation_duration = n_challenges
# eval_config.evaluation_interval = 1
# eval_config.env = eval_env
# alg.reset_config(eval_config) # Unsure if this resets weights, I want to reset everything but that.

# alg.evaluate()
# print("Evaluation with included method:", eval_task.result())
# eval_env.reset_env(restart=False)

In [None]:
# Cross evaluate the player with included util method
# n_challenges = 50
# players = [
#     eval_env.agent,
#     RandomPlayer(battle_format="gen9randombattle"),
#     MaxBasePowerPlayer(battle_format="gen9randombattle"),
#     SimpleHeuristicsPlayer(battle_format="gen9randombattle"),
# ]
# cross_eval_task = background_cross_evaluate(players, n_challenges)
# dqn.test(
#     eval_env,
#     nb_episodes=n_challenges * (len(players) - 1),
#     verbose=False,
#     visualize=False,
# )
# cross_evaluation = cross_eval_task.result()
# table = [["-"] + [p.username for p in players]]
# for p_1, results in cross_evaluation.items():
#     table.append([p_1] + [cross_evaluation[p_1][p_2] for p_2 in results])
# print("Cross evaluation of DQN with baselines:")
# print(tabulate(table))
# eval_env.close()