This notebook's code is adapted from https://github.com/hsahovic/poke-env/blob/master/examples/rl_with_new_open_ai_gym_wrapper.py, with the original keras-rl DQN replaced by OpenRL's DQN agent.

In [None]:
import numpy as np

from gymnasium.spaces import Box, Space
from gymnasium.utils.env_checker import check_env
# from gymnasium.envs.registration import register

from openrl.modules.common import DQNNet
from openrl.runners.common import DQNAgent
from openrl.envs.common import make

from tabulate import tabulate

from poke_env.environment.abstract_battle import AbstractBattle
from poke_env.data import GenData
# Gen 8 type chart shared by every SimpleRLPlayer instance.
# Use the cached `from_gen` factory rather than the constructor: building
# `GenData(8)` directly fails with ValueError once gen 8 data has already been
# initialised (for example when this cell is re-run).
type_chart = GenData.from_gen(8).type_chart

from poke_env.player import (
    Gen8EnvSinglePlayer,
    MaxBasePowerPlayer,
    ObsType,
    RandomPlayer,
    SimpleHeuristicsPlayer,
    background_cross_evaluate,
    background_evaluate_player,
)


In [None]:
class SimpleRLPlayer(Gen8EnvSinglePlayer):
    """Gen 8 single-player environment exposing a 10-component observation.

    The observation is laid out as:
      * [0:4]  base power of each available move divided by 100
               (-1 where no move fills the slot),
      * [4:8]  type damage multiplier of each available move against the
               opponent's active pokemon (1 where no move fills the slot),
      * [8]    fraction of our team that has fainted (count / 6),
      * [9]    fraction of the opponent's team that has fainted (count / 6).
    """

    # necessary for OpenRL's environment instantiation
    agent_num = 1

    def calc_reward(self, last_battle, current_battle) -> float:
        """Return the reward for `current_battle` via `reward_computing_helper`.

        `last_battle` is accepted for interface compatibility but not used.
        """
        return self.reward_computing_helper(
            current_battle,
            fainted_value=2.0,
            hp_value=1.0,
            victory_value=30.0,
        )

    def embed_battle(self, battle: AbstractBattle) -> ObsType:
        """Encode `battle` as a float32 vector with 10 components."""
        # Slots that no available move fills keep these defaults:
        # -1 for base power, 1 for the damage multiplier.
        base_powers = np.full(4, -1.0)
        multipliers = np.ones(4)

        for slot, move in enumerate(battle.available_moves):
            # Simple rescaling to facilitate learning
            base_powers[slot] = move.base_power / 100
            if not move.type:
                continue
            target = battle.opponent_active_pokemon
            multipliers[slot] = move.type.damage_multiplier(
                target.type_1,
                target.type_2,
                type_chart=type_chart,
            )

        # Fraction of each six-pokemon team that has fainted.
        own_fainted = sum(1 for mon in battle.team.values() if mon.fainted) / 6
        foe_fainted = (
            sum(1 for mon in battle.opponent_team.values() if mon.fainted) / 6
        )

        embedding = np.concatenate(
            (base_powers, multipliers, [own_fainted, foe_fainted])
        )
        return embedding.astype(np.float32)

    def describe_embedding(self) -> Space:
        """Return the `Box` bounding the vectors produced by `embed_battle`."""
        # Order: 4 scaled base powers, 4 damage multipliers, 2 fainted fractions.
        lower = np.array([-1] * 4 + [0] * 6, dtype=np.float32)
        upper = np.array([3] * 4 + [4] * 4 + [1] * 2, dtype=np.float32)
        return Box(lower, upper, dtype=np.float32)

In [None]:
## Create the base training environment used by `make_train_env` in a later cell.

# Optional sanity check (disabled): build a throwaway environment and run
# gymnasium's `check_env` on it to confirm the class follows the expected
# environment API.
# opponent = RandomPlayer(battle_format="gen8randombattle")
# test_env = SimpleRLPlayer(
#     battle_format="gen8randombattle", start_challenging=True, opponent=opponent
# )
# check_env(test_env)
# test_env.close()

# Training environment: gen 8 random battles against a RandomPlayer opponent.
# The opponent must exist first because it is passed to the environment.
opponent = RandomPlayer(battle_format="gen8randombattle")
train_env_base = SimpleRLPlayer(
    battle_format="gen8randombattle", opponent=opponent, start_challenging=True
)

In [None]:
from openrl.envs.common import build_envs
from typing import Callable, List, Optional, Union
from gymnasium import Env
import copy
from openrl.configs.config import create_config_parser
# OpenRL config parser; a later cell uses it to load the DQN YAML config.
cfg_parser = create_config_parser()

# Custom factory used instead of registering the environment with gymnasium
# (i.e. instead of `register(id="SimpleRLPlayer/train_env", ...)`).

# Environments that were already built at module level and have not yet been
# handed out by `make_train_env`.
_unclaimed_train_envs = [train_env_base]


def make_train_env(id, render_mode, disable_env_checker, **kwargs):
    """Return a distinct training environment for each call.

    The signature matches the keyword arguments the environment builder passes
    to its `make` callable; none of them are needed to build a SimpleRLPlayer,
    so they are accepted and ignored.

    The pre-built `train_env_base` is handed out exactly once. Every further
    call builds a fresh SimpleRLPlayer with its own RandomPlayer opponent, so
    that a vectorised environment with `env_num > 1` does not end up stepping
    the same battle through several aliases of one object.
    """
    if _unclaimed_train_envs:
        return _unclaimed_train_envs.pop()
    return SimpleRLPlayer(
        battle_format="gen8randombattle",
        opponent=RandomPlayer(battle_format="gen8randombattle"),
        start_challenging=True,
    )
def make_train_envs(
    id: str,
    env_num: int = 1,
    render_mode: Optional[Union[str, List[str]]] = None,
    **kwargs,
) -> List[Callable[[], Env]]:
    """Build the list of environment factories for the training environments.

    Args:
        id: Environment id forwarded to `build_envs`.
        env_num: Number of environment factories to create.
        render_mode: Render mode (or list of modes) forwarded to `build_envs`.
        **kwargs: Extra options forwarded to `build_envs`. An optional
            `env_wrappers` entry is removed from them and passed on as
            `wrappers`.

    Returns:
        Whatever `build_envs` returns for `make_train_env`.
    """
    # Shallow-copy the wrapper collection so the caller's object is not shared.
    wrappers = copy.copy(kwargs.pop("env_wrappers", []))
    return build_envs(
        make=make_train_env,
        id=id,
        env_num=env_num,
        render_mode=render_mode,
        wrappers=wrappers,
        **kwargs,
    )

# Training environments: four of them, created through the custom
# `make_train_envs` factory rather than a registered gymnasium id.
train_envs = make(
    id="SimpleRLPlayer/train_env",
    env_num=4,
    make_custom_envs=make_train_envs,
)

In [None]:
# Compute dimensions
# n_action = train_envs.action_space.n
# input_shape = (1,) + train_envs.observation_space.shape

# Create model
# model = Sequential()
# model.add(layers.Dense(128, activation="elu", input_shape=input_shape))
# model.add(layers.Flatten())
# model.add(layers.Dense(64, activation="elu"))
# model.add(layers.Dense(n_action, activation="linear"))

# Defining the DQN
# memory = SequentialMemory(limit=10000, window_length=1)

# policy = LinearAnnealedPolicy(
#     EpsGreedyQPolicy(),
#     attr="eps",
#     value_max=1.0,
#     value_min=0.05,
#     value_test=0.0,
#     nb_steps=10000,
# )

# dqn = DQNAgent(
#     model=model,
#     nb_actions=n_action,
#     policy=policy,
#     memory=memory,
#     nb_steps_warmup=1000,
#     gamma=0.5,
#     target_model_update=1,
#     delta_clip=0.01,
#     enable_double_dqn=True,
# )
# dqn.compile(optimizers.Adam(learning_rate=0.00025), metrics=["mae"])

# TODO: Need to set config for dqn. Reference config layout and training script:
# https://github.com/OpenRL-Lab/openrl/blob/main/examples/cartpole/dqn_cartpole.yaml
# https://github.com/OpenRL-Lab/openrl/blob/main/examples/cartpole/train_dqn_beta.py

# Parse the DQN settings from configs/basic_dqn.yaml.
dqn_cfg = cfg_parser.parse_args(["--config", "configs/basic_dqn.yaml"])

# NOTE(review): this stores the environment object itself on the parsed
# config; confirm that OpenRL expects `cfg.env` to hold that.
dqn_cfg.env = train_envs

# Network built for the training environments with the parsed config.
dqn_net = DQNNet(train_envs, cfg = dqn_cfg)

# Agent wrapping the network, bound to the same training environments.
dqn_agent = DQNAgent(dqn_net, env=train_envs)

# Training the model
dqn_agent.train(total_time_steps=10000)
# keras-rl equivalent from the original example (disabled):
# dqn.fit(train_env, nb_steps=10000)
# train_envs.close()

In [None]:
# Evaluating the model
# NOTE(review): the `dqn.test(...)` calls below are carried over from the
# keras-rl version of this example. No `dqn` object is assigned in the visible
# code (its construction is commented out in the training cell; the trained
# object there is `dqn_agent`), so this cell presumably raises NameError until
# the evaluation is ported to the OpenRL agent. Confirm and port.

# Fresh evaluation environment against a RandomPlayer opponent.
opponent = RandomPlayer(battle_format="gen8randombattle")
eval_env = SimpleRLPlayer(
    battle_format="gen8randombattle", opponent=opponent, start_challenging=True
)

print("Results against random player:")
dqn.test(eval_env, nb_episodes=100, verbose=False, visualize=False)
print(
    f"DQN Evaluation: {eval_env.n_won_battles} victories out of {eval_env.n_finished_battles} episodes"
)
# Re-point the same environment at a MaxBasePowerPlayer opponent and repeat.
second_opponent = MaxBasePowerPlayer(battle_format="gen8randombattle")
eval_env.reset_env(restart=True, opponent=second_opponent)
print("Results against max base power player:")
dqn.test(eval_env, nb_episodes=100, verbose=False, visualize=False)
print(
    f"DQN Evaluation: {eval_env.n_won_battles} victories out of {eval_env.n_finished_battles} episodes"
)
eval_env.reset_env(restart=False)

In [None]:
# Evaluate the player with included util method
# NOTE(review): `dqn` is not assigned anywhere in the visible code (the trained
# object is `dqn_agent`), and `.test(nb_episodes=...)` is the keras-rl call
# from the original example. This needs porting before the cell can run.
n_challenges = 250
placement_battles = 40
# Start the evaluation in the background; its result is collected below.
eval_task = background_evaluate_player(
    eval_env.agent, n_challenges, placement_battles
)
# Play the episodes while the background evaluation is running.
dqn.test(eval_env, nb_episodes=n_challenges, verbose=False, visualize=False)
print("Evaluation with included method:", eval_task.result())
eval_env.reset_env(restart=False)

In [None]:
# Cross evaluate the player with included util method
# NOTE(review): `dqn` is not assigned anywhere in the visible code (the trained
# object is `dqn_agent`), and `.test(nb_episodes=...)` is the keras-rl call
# from the original example. This needs porting before the cell can run.
n_challenges = 50
# The environment's agent plus three baseline players.
players = [
    eval_env.agent,
    RandomPlayer(battle_format="gen8randombattle"),
    MaxBasePowerPlayer(battle_format="gen8randombattle"),
    SimpleHeuristicsPlayer(battle_format="gen8randombattle"),
]
# Start the cross evaluation in the background; its result is collected below.
cross_eval_task = background_cross_evaluate(players, n_challenges)
# One batch of `n_challenges` episodes per baseline opponent.
dqn.test(
    eval_env,
    nb_episodes=n_challenges * (len(players) - 1),
    verbose=False,
    visualize=False,
)
cross_evaluation = cross_eval_task.result()
# Header row of usernames, then one row per player with its results.
table = [["-"] + [p.username for p in players]]
for p_1, results in cross_evaluation.items():
    table.append([p_1] + [cross_evaluation[p_1][p_2] for p_2 in results])
print("Cross evaluation of DQN with baselines:")
print(tabulate(table))
eval_env.close()