In [5]:
# imports

import asyncio
import json
import os
import matplotlib
import neptune
import nest_asyncio
import numpy as np
import pandas as pd
import time

from collections import defaultdict
from datetime import date
from itertools import product
from matplotlib import pyplot
from poke_env.environment.abstract_battle import AbstractBattle
from poke_env.player.battle_order import ForfeitBattleOrder
from poke_env.player.player import Player
from scipy.interpolate import griddata
from poke_env.player import Gen8EnvSinglePlayer, RandomPlayer
from poke_env import AccountConfiguration, ShowdownServerConfiguration
from poke_env.player import player as Sampleplayer
# from src.PlayerQLearning import Player as PlayerQLearning
import numpy as np
from stable_baselines3 import A2C
from gymnasium.spaces import Box
from poke_env.data import GenData

from poke_env.player import Gen9EnvSinglePlayer, RandomPlayer
import os

In [6]:
class SimpleRLPlayer(Gen9EnvSinglePlayer):
    """Gen 9 single-battle environment exposing a compact 10-value observation.

    The observation produced by ``embed_battle`` matches the space declared in
    ``describe_embedding`` (10 float32 components).
    """

    # Number of move slots encoded in the observation vector.
    _N_MOVE_SLOTS = 4

    def embed_battle(self, battle):
        """Encode ``battle`` as a flat float32 vector with 10 components.

        Layout:
            [0:4]  base power / 100 of each available move (-1 for an empty slot)
            [4:8]  type damage multiplier of each available move against the
                   opponent's active pokemon (1 for an empty slot or a move
                   without a type)
            [8]    number of fainted pokemon in ``battle.team`` divided by 6
            [9]    number of fainted pokemon in ``battle.opponent_team``
                   divided by 6

        Args:
            battle: the battle object to embed; it must expose
                ``available_moves``, ``opponent_active_pokemon``, ``team`` and
                ``opponent_team``.

        Returns:
            A ``numpy.ndarray`` of shape ``(10,)`` and dtype ``float32``.
        """
        # -1 indicates that the move does not have a base power
        # or is not available
        moves_base_power = -np.ones(self._N_MOVE_SLOTS)
        moves_dmg_multiplier = np.ones(self._N_MOVE_SLOTS)
        opponent_mon = battle.opponent_active_pokemon

        # zip with range() caps the loop at the number of slots, so an
        # unexpectedly long move list cannot index past the end of the arrays.
        for slot, move in zip(range(self._N_MOVE_SLOTS), battle.available_moves):
            # Simple rescaling to facilitate learning
            moves_base_power[slot] = move.base_power / 100
            if move.type:
                moves_dmg_multiplier[slot] = move.type.damage_multiplier(
                    opponent_mon.type_1,
                    opponent_mon.type_2,
                    type_chart=GEN_9_DATA.type_chart,
                )

        # Share of each team that HAS fainted (out of 6). The values are the
        # same as before; only the names/comments were corrected, since they
        # previously claimed to count non-fainted pokemon.
        # NOTE(review): kept as the fainted share so that previously saved
        # models (presumably trained on this embedding) see the same inputs.
        fainted_mon_team = sum(1 for mon in battle.team.values() if mon.fainted) / 6
        fainted_mon_opponent = (
            sum(1 for mon in battle.opponent_team.values() if mon.fainted) / 6
        )

        # Final vector with 10 components, cast to the dtype declared by
        # describe_embedding() so observations agree with the observation space.
        return np.concatenate(
            [
                moves_base_power,
                moves_dmg_multiplier,
                [fainted_mon_team, fainted_mon_opponent],
            ]
        ).astype(np.float32)

    def calc_reward(self, last_state, current_state) -> float:
        """Return the reward for ``current_state``.

        Delegates to ``reward_computing_helper`` with ``fainted_value=2``,
        ``hp_value=1`` and ``victory_value=30``. ``last_state`` is not used.
        """
        return self.reward_computing_helper(
            current_state, fainted_value=2, hp_value=1, victory_value=30
        )

    def describe_embedding(self):
        """Return the ``Box`` space bounding the vector built by ``embed_battle``.

        Bounds per component group: move base powers in [-1, 3], damage
        multipliers in [0, 4], fainted shares in [0, 1].
        """
        low = [-1, -1, -1, -1, 0, 0, 0, 0, 0, 0]
        high = [3, 3, 3, 3, 4, 4, 4, 4, 1, 1]
        return Box(
            np.array(low, dtype=np.float32),
            np.array(high, dtype=np.float32),
            dtype=np.float32,
        )


class MaxDamagePlayer(RandomPlayer):
    """Baseline opponent that always uses its highest base-power move."""

    def choose_move(self, battle):
        """Pick the available move with the largest ``base_power``.

        When ``battle.available_moves`` is empty, the decision is delegated
        to ``choose_random_move``.
        """
        candidates = battle.available_moves

        # Nothing to attack with: fall back to a random order.
        if not candidates:
            return self.choose_random_move(battle)

        # Otherwise attack with the strongest move currently available.
        strongest = max(candidates, key=lambda candidate: candidate.base_power)
        return self.create_order(strongest)

In [7]:
# --- Run configuration -------------------------------------------------------
# Number of environment steps for training.
NB_TRAINING_STEPS = 10_000
# Number of complete battles played during evaluation.
NB_EVALUATION_EPISODES = 10
# Path used to load the A2C model.
MODEL_PATH = "a2c_pokemon_model"

# Seed NumPy's global random generator.
np.random.seed(0)

# Gen 9 game data; SimpleRLPlayer.embed_battle reads its type chart.
GEN_9_DATA = GenData.from_gen(9)

In [20]:
# Build the evaluation environment: our RL agent plays against MaxDamagePlayer.
opponent = MaxDamagePlayer()
env_player = SimpleRLPlayer(opponent=opponent)

# Load the saved model when one exists; otherwise train a fresh one and save
# it, so the notebook also runs from a clean checkout. Previously a throwaway
# A2C instance was always constructed and immediately replaced by the loaded
# model.
# NOTE(review): Stable-Baselines3 is expected to append ".zip" when saving to
# a path without extension, hence both candidates are checked.
saved_model_exists = any(
    os.path.exists(candidate) for candidate in (MODEL_PATH, MODEL_PATH + ".zip")
)
if saved_model_exists:
    model = A2C.load(MODEL_PATH, env=env_player)
else:
    model = A2C("MlpPolicy", env_player, verbose=1)
    model.learn(total_timesteps=NB_TRAINING_STEPS)
    model.save(MODEL_PATH)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [21]:
# Evaluate the model for NB_EVALUATION_EPISODES complete battles.
finished_episodes = 0

# Best effort: clear the battle tracker so the win count printed below covers
# this evaluation run. A failure is reported instead of silently swallowed,
# and only Exception is caught so KeyboardInterrupt still stops the cell.
try:
    env_player.reset_battles()
except Exception as exc:
    print("Could not reset battles before evaluation:", repr(exc))

obs, _ = env_player.reset()
while True:
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env_player.step(action)

    # An episode is over when it terminated OR was truncated; ignoring the
    # truncation flag would keep stepping an already finished episode.
    done = terminated or truncated
    if done:
        finished_episodes += 1
        # Stop BEFORE resetting: resetting after the last episode would start
        # an extra battle that is never played to completion.
        if finished_episodes >= NB_EVALUATION_EPISODES:
            break
        obs, _ = env_player.reset()
print("Won", env_player.n_won_battles, "battles against", env_player._opponent)

Won 4 battles against <__main__.MaxDamagePlayer object at 0x00000280AEAC2350>
