#### Start Local Pokemon Showdown Server
If you are using the dev container, the server starts automatically every time the container starts. Otherwise, `cd` into your pokemon-showdown directory and run:

`node pokemon-showdown start --no-security`

In [None]:
%load_ext tensorboard

In [None]:
# This needs to be run at the beginning or else testing will fail later.
from ray.rllib.utils.framework import try_import_tf
tf1, tf, tfv = try_import_tf()
tf1.enable_eager_execution()

import numpy as np

from tqdm import tqdm

import time, copy
from typing import List, Any, Union

from gymnasium.spaces import Box, Space

# from tabulate import tabulate

from poke_env.environment.abstract_battle import AbstractBattle
from poke_env.data import GenData
type_chart = GenData(9).type_chart
from poke_env.ps_client.account_configuration import AccountConfiguration, CONFIGURATION_FROM_PLAYER_COUNTER
from poke_env.player import (
    Gen9EnvSinglePlayer,
    MaxBasePowerPlayer,
    ObsType,
    RandomPlayer,
    SimpleHeuristicsPlayer
)

from ray.rllib.env.env_context import EnvContext

In [None]:
# helper functions

def create_account_configuration(username: str, unique_ids: Union[List[Any], None] = None) -> AccountConfiguration:
    """Build a password-less AccountConfiguration whose username is made unique.

    Args:
        username: Base name to decorate.
        unique_ids: Identifiers appended to ``username``, separated by spaces. When None,
            poke_env's ``CONFIGURATION_FROM_PLAYER_COUNTER`` entry for ``username`` is
            incremented and appended instead, behind an ``s`` marker ("single thread").
            That counter path is NOT thread-safe: pass explicit ids when running several
            threads / workers.

    Returns:
        ``AccountConfiguration(name, None)``, where ``name`` is at most 18 characters when the
        suffix allows it: on overflow the base ``username`` is shortened so the unique
        suffix is kept intact.

    NOTE(review): if Showdown normalises names by dropping spaces/punctuation, ids such as
    [1, 20] and [12, 0] would map to the same account — confirm before using 10+ workers or
    envs per worker.
    """
    max_len = 18

    if unique_ids is None:
        CONFIGURATION_FROM_PLAYER_COUNTER.update([username])
        count = CONFIGURATION_FROM_PLAYER_COUNTER[username]
        name = f"{username} s {count}"
        overflow = len(name) - max_len
        if overflow > 0:
            # Trim the base name and drop the space before the counter to fit.
            name = f"{username[:-overflow]} s{count}"
        return AccountConfiguration(name.strip(), None)

    id_suffix = ' '.join(str(uid) for uid in unique_ids)
    name = f"{username} {id_suffix}".strip()
    overflow = len(name) - max_len
    if overflow > 0:
        # Keep every id; shorten only the base name.
        name = f"{username[:-overflow]} {id_suffix}"
    return AccountConfiguration(name.strip(), None)

In [None]:
class SimpleRLPlayer(Gen9EnvSinglePlayer):
    """Gen 9 single-battle poke_env environment configured from an RLlib ``env_config``.

    ``env_config`` (a plain dict or an RLlib ``EnvContext``) may contain:

    - ``player_config``: kwargs forwarded to ``Gen9EnvSinglePlayer.__init__``.
    - ``username``: base username for this player (defaults to this class's name).
    - ``opponent_class``: player class instantiated as the opponent (optional).
    - ``opponent_username``: base username for the opponent (defaults to the opponent
      class's name).
    - ``opponent_config``: kwargs for ``opponent_class``.
    """

    def __init__(self, env_config: Union[EnvContext, dict]):
        # Deep copy so the (shared) env_config is never mutated; a missing section counts as empty.
        player_config = copy.deepcopy(env_config.get('player_config') or {})

        # Ids that make usernames unique per env instance.
        if isinstance(env_config, EnvContext):
            # Worker index: which rollout worker runs this env (when there are several workers).
            # Vector index: which env this is on that worker (when there are several envs per worker).
            # TODO: Add trial index from env_config if it is added in future RLLib update.
            unique_ids = [env_config.worker_index, env_config.vector_index]
        else:
            # None makes create_account_configuration fall back to its player counter.
            unique_ids = None

        # A unique account configuration avoids Showdown name clashes between parallel envs.
        # It should only be pre-set when this player is created as a self-play opponent.
        if player_config.get('account_configuration') is None:
            username = env_config.get('username')
            if username is None:
                username = type(self).__name__
            player_config['account_configuration'] = create_account_configuration(
                username=username, unique_ids=unique_ids
            )

        opponent_class = env_config.get('opponent_class')
        if opponent_class is not None:
            # Mirrors player_config: deep copy, missing section counts as empty.
            opponent_config = copy.deepcopy(env_config.get('opponent_config') or {})

            # Should be None when using parallelisation / multiple workers.
            if opponent_config.get('account_configuration') is None:
                opponent_username = env_config.get('opponent_username')
                if opponent_username is None:
                    # opponent_class is already a class: use its own name. type(opponent_class)
                    # would be the metaclass (e.g. 'ABCMeta'), not the opponent's class name.
                    opponent_username = opponent_class.__name__
                opponent_config['account_configuration'] = create_account_configuration(
                    username=opponent_username, unique_ids=unique_ids
                )
            player_config['opponent'] = opponent_class(**opponent_config)

        super().__init__(**player_config)

    def calc_reward(self, last_battle, current_battle) -> float:
        """Reward from poke_env's reward_computing_helper on ``current_battle``.

        Weights: fainted 2.0, hp 1.0, victory 30.0. ``last_battle`` is unused but part of the
        method signature.
        """
        return self.reward_computing_helper(
            current_battle, fainted_value=2.0, hp_value=1.0, victory_value=30.0
        )

    def embed_battle(self, battle: AbstractBattle) -> ObsType:
        """Encode ``battle`` as a float32 vector of 10 components.

        - [0:4] base power / 100 of each available move; -1 marks an empty move slot
          (status moves get 0).
        - [4:8] type-effectiveness multiplier of each move against the opponent's active
          Pokémon; 1 for empty slots or moves without a type.
        - [8:10] fraction (out of 6) of fainted Pokémon on our team and on the opponent's team.

        Assumes at most 4 available moves and a non-None opponent_active_pokemon — TODO confirm.
        """
        moves_base_power = -np.ones(4)
        moves_dmg_multiplier = np.ones(4)
        for i, move in enumerate(battle.available_moves):
            # Simple rescaling to facilitate learning.
            moves_base_power[i] = move.base_power / 100
            if move.type:
                moves_dmg_multiplier[i] = move.type.damage_multiplier(
                    battle.opponent_active_pokemon.type_1,
                    battle.opponent_active_pokemon.type_2,
                    type_chart=type_chart,
                )

        fainted_mon_team = sum(mon.fainted for mon in battle.team.values()) / 6
        fainted_mon_opponent = sum(mon.fainted for mon in battle.opponent_team.values()) / 6

        final_vector = np.concatenate(
            [
                moves_base_power,
                moves_dmg_multiplier,
                [fainted_mon_team, fainted_mon_opponent],
            ]
        )
        return np.float32(final_vector)

    def describe_embedding(self) -> Space:
        """Observation space matching ``embed_battle``: bounds for each of the 10 components."""
        low = [-1, -1, -1, -1, 0, 0, 0, 0, 0, 0]
        high = [3, 3, 3, 3, 4, 4, 4, 4, 1, 1]
        return Box(
            np.array(low, dtype=np.float32),
            np.array(high, dtype=np.float32),
            dtype=np.float32,
        )

In [None]:
## Create config
from ray.rllib.algorithms.dqn import DQNConfig
from ray import tune, train
import os


def _make_env_config(role_prefix: str, opponent_tag: str) -> dict:
    """Return the env_config dict handed to every SimpleRLPlayer environment.

    'player_config' becomes the kwargs of SimpleRLPlayer's Gen9EnvSinglePlayer superclass
    __init__; 'opponent_config' becomes the kwargs of the opponent class.
    """
    return {
        'username': f"{role_prefix}_SimpleRL",
        'player_config': {
            'battle_format': "gen9randombattle",
            'start_challenging': True,
        },
        'opponent_class': MaxBasePowerPlayer,
        'opponent_username': f"{role_prefix}_{opponent_tag}",
        'opponent_config': {
            'battle_format': "gen9randombattle",
        },
    }


# Environments used for training.
train_env_config = _make_env_config("tr", "MaxBP")
# Environments used for evaluation.
eval_env_config = _make_env_config("ev", "MaxBP")

# Guide to RLLib parameters: https://docs.ray.io/en/latest/rllib/rllib-training.html#common-parameters
# DQN hyperparameters: https://docs.ray.io/en/latest/rllib/rllib-algorithms.html#deep-q-networks-dqn-rainbow-parametric-dqn
config = (
    DQNConfig()
    .environment(env=SimpleRLPlayer, env_config=train_env_config)
    # "tf2" for TensorFlow, "torch" for PyTorch. The dev container is set up for TensorFlow 2.13.
    .framework(framework="tf2")
    .resources(num_cpus_for_main_process=4, num_gpus=0)
    .learners(num_learners=0)  # num_gpus_per_learner=0
    .env_runners(
        # CPUs per env runner; on its own this barely speeds up sampling.
        num_cpus_per_env_runner=1,
        # Env runner workers; 0 would force rollouts onto the local worker.
        num_env_runners=4,
        # Envs per env runner; raising this speeds up sampling a lot.
        num_envs_per_env_runner=4,
        # Never cut an episode short when batching, so the batch size acts as a minimum
        # and batches may vary in size.
        batch_mode="complete_episodes",
        # validate_env_runners_after_construction=False  # validation creates envs it never closes.
        rollout_fragment_length=50,  # or "auto"
        explore=True,
        exploration_config={
            "type": "EpsilonGreedy",
            "initial_epsilon": 1.0,
            "final_epsilon": 0.0,
            "epsilon_timesteps": 80000,
        },
    )
    .training(
        gamma=0.8,
        lr=0.0002,  # tuning option: tune.loguniform(8e-5, 8e-3)
        optimizer={
            "weight_decay": 0.02,
            # "betas": [0.9, 0.999]  (default; may need tuning)
        },
        replay_buffer_config={
            "type": "MultiAgentPrioritizedReplayBuffer",
            "capacity": 100000,
        },
        num_steps_sampled_before_learning_starts=1000,
        v_min=-48,  # minimum reward
        v_max=48,  # maximum reward
        # n_step=1,
        double_q=False,  # tuning option: tune.grid_search([True, False])
        num_atoms=1,
        noisy=False,  # tuning option: tune.grid_search([True, False])
        dueling=False,  # tuning option: tune.grid_search([True, False])
        train_batch_size=1200,
    )
    .evaluation(
        evaluation_interval=1,
        evaluation_num_env_runners=4,
        # evaluation_parallel_to_training=True,
        evaluation_duration=30,
        evaluation_config={
            "env_config": eval_env_config,
            "metrics_num_episodes_for_smoothing": 4,
            # Set to True for policy gradient algorithms.
            "explore": False,
        },
    )
    # Let runs continue when an env runner fails for whatever reason.
    .fault_tolerance(recreate_failed_env_runners=True)
)

In [None]:
## Set stopping criteria for the trials
from ray.tune.stopper import CombinedStopper, MaximumIterationStopper, TrialPlateauStopper

# A trial stops as soon as either stopper fires:
#   - 120 training iterations have run, or
#   - the evaluation reward mean has plateaued (TrialPlateauStopper defaults).
stopper = CombinedStopper(
    MaximumIterationStopper(120),
    TrialPlateauStopper("evaluation/env_runners/episode_reward_mean"),
)

In [None]:
# Auto hyperparameter tuning

# Several environments only work inside a single trial. Usernames are derived from the worker
# index and env index in the EnvContext given to each SimpleRLPlayer, and those indices repeat
# from trial to trial, so concurrent trials would reuse the same usernames.

# ../results relative to the current working directory, as an absolute path.
results_dir = os.path.abspath(os.path.join(os.pardir, 'results'))

tune_config = tune.TuneConfig(
    num_samples=1,
    max_concurrent_trials=1,
    # Options left at their defaults:
    #   reuse_actors=True -> needs reset_config() defined and returning True for this algorithm.
    #   search_alg=...    -> default is random search (not ideal); better ones need an external
    #                        library: https://docs.ray.io/en/master/tune/key-concepts.html#tune-search-algorithms
    #   scheduler=...     -> with concurrent trials, ends or changes poorly performing trials early.
)

# Save a checkpoint every iteration but keep only the single best one, where "best" means the
# highest evaluation reward mean.
checkpoint_config = train.CheckpointConfig(
    checkpoint_frequency=1,
    checkpoint_score_attribute="evaluation/env_runners/episode_reward_mean",
    num_to_keep=1,
    checkpoint_score_order="max",
    # checkpoint_at_end=True
)

run_config = train.RunConfig(
    name="DQN_SimpleRL_v_MaxBP_1",
    storage_path=results_dir,
    stop=stopper,
    checkpoint_config=checkpoint_config,
)

tuner = tune.Tuner(
    config.algo_class,
    param_space=config,
    tune_config=tune_config,
    run_config=run_config,
)
analysis = tuner.fit()

# Give the training battles a moment to close before testing; otherwise Showdown may report
# "nametaken" because the test players reuse the training names.
time.sleep(1)

In [None]:
## Pick the best saved checkpoint across every trial and iteration of this run.
best_result = analysis.get_best_result(
    mode="max",
    metric="evaluation/env_runners/episode_reward_mean",
)
print(best_result)
test_checkpoint = best_result.checkpoint

In [None]:
## Manually load checkpoint from path
# If manually loading a checkpoint from a path, you can skip all above cells after SimpleRLPlayer class creation.
# The test_checkpoint path should end with the checkpoint_XXXXXX directory, where X's are the checkpoint number with leading 0s.

# test_checkpoint = "../results/DQN_SimpleRL_v_MaxBP_1/DQN_SimpleRLPlayer_9fc91_00010_10_train_batch_size=900_2024-07-26_13-50-33/checkpoint_000019"

In [None]:
## Restore the model checkpoint in test_checkpoint
from ray.rllib.algorithms.algorithm import Algorithm

# best_result.checkpoint is None when the run saved no checkpoint. Fail here with a clear
# message rather than with an opaque error from inside Algorithm.from_checkpoint.
if test_checkpoint is None:
    raise ValueError(
        "No checkpoint to restore: test_checkpoint is None. Make sure training saved a "
        "checkpoint, or set test_checkpoint manually to a checkpoint_XXXXXX directory."
    )

test_alg = Algorithm.from_checkpoint(test_checkpoint)

In [None]:
## Test algorithm against baseline players

n_battles = 100


def run_test_battles(opponent_class, opponent_username, results_header, num_battles):
    """Play ``num_battles`` battles of ``test_alg`` against ``opponent_class`` and print the wins.

    Args:
        opponent_class: poke_env player class to battle against.
        opponent_username: Base Showdown username for the opponent.
        results_header: Line printed before the win count.
        num_battles: Number of episodes to play.

    The test environment is always closed, even if a battle raises.
    """
    test_env_config = {
        'username': 'te_SimpleRL',
        'player_config': {
            'battle_format': "gen9randombattle",
            'start_challenging': True,
        },
        'opponent_class': opponent_class,
        'opponent_username': opponent_username,
        'opponent_config': {
            'battle_format': "gen9randombattle",
        },
    }

    test_env = SimpleRLPlayer(env_config=test_env_config)
    try:
        for _ in tqdm(range(num_battles), leave=False):
            terminated = truncated = False
            obs, info = test_env.reset()
            while (
                not terminated
                and not truncated
                and test_env.current_battle is not None
                and not test_env.current_battle.finished
            ):
                action = test_alg.compute_single_action(obs)
                obs, reward, terminated, truncated, info = test_env.step(action)

        print(results_header)
        print(
            f"DQN Test: {test_env.n_won_battles} victories out of {test_env.n_finished_battles} episodes"
        )
    finally:
        # Release the Showdown connection so the next test env can reuse the base username.
        test_env.close()


# (opponent class, opponent base username, results header) for each baseline.
test_opponents = [
    (RandomPlayer, 'RandomPlayer', "\n\nResults against random player:"),
    (MaxBasePowerPlayer, 'MaxBasePower', "\nResults against Max Base Power player:"),
    (SimpleHeuristicsPlayer, 'SimpleHeur', "\nResults against Simple Heuristic player:"),
]
for opponent_class, opponent_username, results_header in test_opponents:
    run_test_battles(opponent_class, opponent_username, results_header, n_battles)

In [None]:
## Close the test algorithm to free up resources.
# Run this only after the testing cell has finished; it calls Algorithm.stop() on the
# algorithm restored from test_checkpoint.
test_alg.stop()