#### Required Packages
PyTorch or TensorFlow, preferably CUDA-enabled to utilize the GPU\
"ray[rllib]", which requires Python 3.9 to 3.11. Note: ray[rllib] for Windows is currently in beta, and the Windows version for Python 3.11 is experimental.\
poke-env ipython ipykernel ipywidgets tqdm tensorboard

Note: RLlib is currently in the process of updating to 3.0, which switches to a new API stack. However, the new API stack is currently only available for the PPO and SAC algorithms, so this notebook uses the old API stack in order to use the DQN algorithm.

#### Start Local Pokemon Showdown Server
cd into your pokemon-showdown directory  
node pokemon-showdown start --no-security

In [None]:
%load_ext tensorboard

In [None]:
import numpy as np

from tqdm import tqdm

import time
from typing import List, Any, Union

from gymnasium.spaces import Box, Space

# from tabulate import tabulate

from poke_env.environment.abstract_battle import AbstractBattle
from poke_env.data import GenData
type_chart = GenData(9).type_chart
from poke_env.ps_client.account_configuration import AccountConfiguration, CONFIGURATION_FROM_PLAYER_COUNTER
from poke_env.player import (
    Gen9EnvSinglePlayer,
    MaxBasePowerPlayer,
    ObsType,
    RandomPlayer,
    SimpleHeuristicsPlayer
)

from ray.rllib.env.env_context import EnvContext

In [None]:
# helper functions

def create_account_configuration(username: str, unique_ids: Union[List[Any], None] = None) -> AccountConfiguration:
    """Build an AccountConfiguration whose username is unique for this env.

    Args:
        username: Base name to derive the unique username from.
        unique_ids: Values appended (space separated) to ``username``. When
            ``None``, a per-username counter supplies the suffix instead.

    Returns:
        An ``AccountConfiguration`` built from the stripped unique username
        and ``None`` as its second argument.
    """
    # Names longer than this get their ``username`` part shortened.
    # NOTE(review): presumably the server's username length limit - confirm.
    max_length = 18

    if unique_ids is not None:
        suffix = ' '.join(str(uid) for uid in unique_ids)
        candidate = f"{username} {suffix}".strip()
        overflow = len(candidate) - max_length
        if overflow > 0:
            # Drop the excess characters from the end of the base name only,
            # so the ids stay intact.
            candidate = f"{username[:-overflow]} {suffix}"
    else:
        # NOTE: This is not thread-safe! Provide unique ids for multithreading / multi workers.
        # The 's' in the resulting username indicates single thread.
        CONFIGURATION_FROM_PLAYER_COUNTER.update([username])
        count = CONFIGURATION_FROM_PLAYER_COUNTER[username]
        candidate = f"{username} s {count}"
        overflow = len(candidate) - max_length
        if overflow > 0:
            # The shortened form also drops the space between 's' and the count.
            candidate = f"{username[:-overflow]} s{count}"

    return AccountConfiguration(candidate.strip(), None)

In [None]:
class SimpleRLPlayer(Gen9EnvSinglePlayer):
    """Gen9EnvSinglePlayer environment that is configured from an ``env_config`` mapping.

    Keys read from ``env_config`` (all optional):
        account_configuration: Ready-made account configuration for this player.
        username: Base name for a generated account configuration; defaults
            to this class's name.
        opponent_class: Class instantiated as the opponent. No opponent is
            created when absent.
        opponent_account_configuration: Ready-made account configuration for
            the opponent.
        opponent_username: Base name for the opponent's generated account
            configuration; defaults to the opponent class's name.
        opponent_config: Keyword arguments passed to ``opponent_class``.
        battle_format: Passed to ``opponent_class`` when ``opponent_config``
            is absent.
    """

    def __init__(self, env_config: Union[EnvContext, dict]):
        """Create the opponent (if requested) and initialise the environment.

        Args:
            env_config: Mapping with the keys described on the class. When it
                is an ``EnvContext``, its worker and vector indices are used
                to make the generated usernames unique.
        """
        # isinstance (rather than an exact type check) so that subclasses of
        # EnvContext are treated the same way.
        if isinstance(env_config, EnvContext):
            # Worker index is index of the rollout worker that this env is running on, when there are multiple workers.
            # Vector index is index of this env on this worker, when there are multiple envs per worker.
            unique_ids = [env_config.worker_index, env_config.vector_index]
        else:
            # None activates create_account_configuration's player counter.
            unique_ids = None

        # Create unique account configuration for this from username to prevent multithreading naming conflicts.
        # account_configuration should be None, unless this is created as opponent for self-play.
        account_configuration = env_config.get('account_configuration')
        if account_configuration is None:
            username = env_config.get('username')
            if username is None:
                username = type(self).__name__
            account_configuration = create_account_configuration(username=username, unique_ids=unique_ids)

        # If an opponent class is provided, instantiate it with its config.
        opponent = None
        opponent_class = env_config.get('opponent_class')
        if opponent_class is not None:
            # opponent account_configuration should be None when using parallelization or multiple workers.
            opponent_account_configuration = env_config.get('opponent_account_configuration')
            if opponent_account_configuration is None:
                opponent_username = env_config.get('opponent_username')
                if opponent_username is None:
                    # opponent_class is itself a class, so use its own name;
                    # type(opponent_class) would yield the metaclass's name.
                    opponent_username = opponent_class.__name__
                opponent_account_configuration = create_account_configuration(
                    username=opponent_username, unique_ids=unique_ids
                )

            opponent_config = env_config.get('opponent_config')
            if opponent_config is None:
                # No opponent config: reuse this env's battle format, but only
                # when one was given, so the opponent's own default is not
                # replaced by an explicit None.
                battle_format = env_config.get('battle_format')
                opponent_config = {} if battle_format is None else {'battle_format': battle_format}

            opponent = opponent_class(
                account_configuration=opponent_account_configuration,
                **opponent_config
            )

        # TODO: Figure out how to only pass arguments if they are not None
        super().__init__(
            opponent=opponent,
            account_configuration=account_configuration,
            # avatar = env_config.get('avatar'),
            # battle_format = env_config.get('battle_format'),
            # log_level = env_config.get('log_level'),
            # save_replays = env_config.get('save_replays'),
            # server_configuration = env_config.get('server_configuration'),
            # start_listening = env_config.get('start_listening'),
            # start_timer_on_battle_start = env_config.get('start_timer_on_battle_start'),
            # ping_interval = env_config.get('ping_interval'),
            # ping_timeout = env_config.get('ping_timeout'),
            # team = env_config.get('team'),
            # start_challenging = env_config.get('start_challenging')
        )
        # Try adding this somewhere in init.
        # self.reset_battles()

    def calc_reward(self, last_battle, current_battle) -> float:
        """Return the reward for ``current_battle``.

        Delegates to ``reward_computing_helper`` with fixed weights;
        ``last_battle`` is not used.
        """
        return self.reward_computing_helper(
            current_battle, fainted_value=2.0, hp_value=1.0, victory_value=30.0
        )

    def embed_battle(self, battle: AbstractBattle) -> ObsType:
        """Encode ``battle`` as a float32 vector with 10 components.

        Layout: 4 move base powers (divided by 100), 4 move damage
        multipliers against the opponent's active pokemon, then the fraction
        (out of 6) of fainted pokemon in the own team and the opponent team.
        """
        # -1 indicates that the move does not have a base power
        # or is not available
        moves_base_power = -np.ones(4)
        moves_dmg_multiplier = np.ones(4)
        for i, move in enumerate(battle.available_moves):
            # Simple rescaling to facilitate learning
            moves_base_power[i] = move.base_power / 100
            if move.type:
                moves_dmg_multiplier[i] = move.type.damage_multiplier(
                    battle.opponent_active_pokemon.type_1,
                    battle.opponent_active_pokemon.type_2,
                    type_chart=type_chart
                )

        # We count how many pokemons have fainted in each team
        fainted_mon_team = len([mon for mon in battle.team.values() if mon.fainted]) / 6
        fainted_mon_opponent = (
            len([mon for mon in battle.opponent_team.values() if mon.fainted]) / 6
        )

        # Final vector with 10 components
        final_vector = np.concatenate(
            [
                moves_base_power,
                moves_dmg_multiplier,
                [fainted_mon_team, fainted_mon_opponent],
            ]
        )
        return final_vector.astype(np.float32)

    def describe_embedding(self) -> Space:
        """Return the ``Box`` space bounding the vector built by ``embed_battle``."""
        # Order matches embed_battle: 4 base powers, 4 multipliers, 2 fainted fractions.
        low = [-1, -1, -1, -1, 0, 0, 0, 0, 0, 0]
        high = [3, 3, 3, 3, 4, 4, 4, 4, 1, 1]
        return Box(
            np.array(low, dtype=np.float32),
            np.array(high, dtype=np.float32),
            dtype=np.float32,
        )

In [None]:
## Create config
from ray.rllib.algorithms.dqn import DQNConfig
from ray import tune, train
import os


def _max_base_power_env_config(username):
    """Return a fresh env_config for a SimpleRLPlayer facing MaxBasePowerPlayer."""
    return {
        'username': username,
        'battle_format': "gen9randombattle",
        'start_challenging': True,
        'opponent_class': MaxBasePowerPlayer,
        'opponent_username': 'MaxBasePower',
        'opponent_config': {
            'battle_format': "gen9randombattle",
        },
    }


# Passed to each environment (SimpleRLPlayer) during training
train_env_config = _max_base_power_env_config('tr_SimpleRL')
# Passed to each environment (SimpleRLPlayer) during evaluation
eval_env_config = _max_base_power_env_config('ev_SimpleRL')

# Guide to RLLib parameters: https://docs.ray.io/en/latest/rllib/rllib-training.html#common-parameters
# For descriptions of the training hyperparameters, see:
# https://docs.ray.io/en/latest/rllib/rllib-algorithms.html#deep-q-networks-dqn-rainbow-parametric-dqn
config = (
    DQNConfig()
    .environment(env=SimpleRLPlayer, env_config=train_env_config)
    .resources(
        num_gpus=1,
        # num_learner_workers=8,
        # num_cpus_per_learner_worker=2,
        num_cpus_per_worker=1,
        num_cpus_for_local_worker=2,
    )
    .env_runners(
        # Number of workers to run environments. 0 forces rollouts onto the local worker.
        num_env_runners=20,
        num_envs_per_env_runner=1,
        # Don't cut off episodes before they finish when batching.
        # As a result, the batch size hyperparameter acts as a minimum and batches may vary in size.
        # batch_mode="complete_episodes",
        # Validation creates environments and does not close them, causes problems.
        # validate_env_runners_after_construction=False,
    )
    .training(
        # gamma=0.5,
        lr=0.00025,
        # lr=tune.loguniform(10**-6, 10**-4),
        replay_buffer_config=dict(
            type="MultiAgentPrioritizedReplayBuffer",
            capacity=50000,
        ),
        num_steps_sampled_before_learning_starts=1000,
        n_step=1,
        double_q=True,
        num_atoms=1,
        noisy=False,
        dueling=False,
        # train_batch_size=tune.grid_search([4, 8, 12, 16, 20, 32, 64])
    )
    .evaluation(
        evaluation_interval=4,
        # evaluation_num_workers=1,
        # evaluation_parallel_to_training=True,
        # evaluation_duration="auto",
        evaluation_config=dict(
            env_config=eval_env_config,
            metrics_num_episodes_for_smoothing=4,
            # Set explore True for policy gradient algorithms
            explore=False,
        ),
    )
)
# This setting allows runs to continue after a worker fails for whatever reason.
# config.recreate_failed_env_runners = True

# Stopping criteria handed to the tuner's run configuration.
stop = dict(
    episode_reward_mean=30,
    training_iteration=90,
)

In [None]:
# Auto hyperparameter tuning

# Results are stored in a 'results' directory next to the current working directory.
results_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir, 'results'))

checkpoint_config = train.CheckpointConfig(
    checkpoint_frequency=4,
    checkpoint_score_attribute="episode_reward_mean",
    num_to_keep=3,
    checkpoint_at_end=True,
)
run_config = train.RunConfig(
    name="DQN_SimpleRL_Test",
    storage_path=results_dir,
    stop=stop,
    checkpoint_config=checkpoint_config,
)
tune_config = tune.TuneConfig(
    metric="episode_reward_mean",
    mode="max",
    num_samples=8,
    # reuse_actors=True, # Needs reset_config() to be defined and return True for this algorithm
)

tuner = tune.Tuner(
    config.algo_class,
    param_space=config,
    tune_config=tune_config,
    run_config=run_config,
)
analysis = tuner.fit()

In [None]:
# TODO: free resources used by tune once training finished

# Pick the best trial using the metric/mode the tuner was configured with.
# To rank by something else, pass it explicitly, e.g.:
#   analysis.get_best_result(metric="env_runners/episode_return_mean", mode="max")
best_result = analysis.get_best_result()
test_checkpoint = best_result.checkpoint

# Alternatively, load a checkpoint from a path:
# test_checkpoint =

# Give the training battles time to finish closing before testing starts.
# Otherwise showdown reports a nametaken error, because the test players
# try to use the same names as in training.
time.sleep(3)

In [None]:
## Test algorithm against baseline players
from ray.rllib.algorithms.algorithm import Algorithm


def _run_test_battles(algorithm, opponent_class, opponent_username, header, n_battles):
    """Play test battles against one baseline opponent and print the win count.

    Args:
        algorithm: Restored algorithm used to choose actions.
        opponent_class: Player class the test environment battles against.
        opponent_username: Base username given to the opponent.
        header: Line printed above the results.
        n_battles: Number of battles (environment resets) to play.
    """
    env = SimpleRLPlayer(env_config={
        'username': 'te_SimpleRL',
        'battle_format': "gen9randombattle",
        'start_challenging': True,
        'opponent_class': opponent_class,
        'opponent_username': opponent_username,
        'opponent_config': {
            'battle_format': "gen9randombattle",
        },
    })
    # try/finally so the environment is closed even if a battle raises.
    try:
        for _ in tqdm(range(n_battles), leave=False):
            terminated = truncated = False
            obs, info = env.reset()
            # The None check must precede `.finished`; `or` short-circuits in order.
            while not (
                terminated
                or truncated
                or env.current_battle is None
                or env.current_battle.finished
            ):
                # explore=False: act greedily, matching the "explore": False
                # setting this notebook uses in its evaluation config.
                action = algorithm.compute_single_action(obs, explore=False)
                obs, reward, terminated, truncated, info = env.step(action)

        print(header)
        print(
            f"DQN Test: {env.n_won_battles} victories out of {env.n_finished_battles} episodes"
        )
    finally:
        env.close()


n_battles = 100
# Restore the model checkpoint in test_checkpoint
# TODO: stop .from_checkpoint from running env verification, it creates envs and does not close them properly.
# Alternatively, make a callback that closes them.
test_alg = Algorithm.from_checkpoint(test_checkpoint)

# (opponent class, opponent username, results header) for each baseline, in test order.
for baseline_class, baseline_username, results_header in (
    (RandomPlayer, 'RandomPlayer', "\n\nResults against random player:"),
    (MaxBasePowerPlayer, 'MaxBasePower', "\nResults against Max Base Power player:"),
    (SimpleHeuristicsPlayer, 'SimpleHeuristic', "\nResults against Simple Heuristic player:"),
):
    _run_test_battles(test_alg, baseline_class, baseline_username, results_header, n_battles)