In [1]:
import numpy as np
import random
from gymnasium.spaces import Box, Space
from poke_env.player import Gen9EnvSinglePlayer, RandomPlayer, Player
from poke_env.environment.abstract_battle import AbstractBattle
from poke_env.teambuilder import Teambuilder

# from poke_env.player

In [2]:
from poke_env import AccountConfiguration, ServerConfiguration
from poke_env.teambuilder.teambuilder import Teambuilder


class QAgent(Player):
    def __init__(self, account_configuration: AccountConfiguration | None = None, *, avatar: str | None = None, battle_format: str = "gen9randombattle", log_level: int | None = None, max_concurrent_battles: int = 1, accept_open_team_sheet: bool = False, save_replays: bool | str = False, server_configuration: ServerConfiguration | None = None, start_timer_on_battle_start: bool = False, start_listening: bool = True, ping_interval: float | None = 20, ping_timeout: float | None = 20, team: str | Teambuilder | None = None):
        super().__init__(account_configuration, avatar=avatar, battle_format=battle_format, log_level=log_level, max_concurrent_battles=max_concurrent_battles, accept_open_team_sheet=accept_open_team_sheet, save_replays=save_replays, server_configuration=server_configuration, start_timer_on_battle_start=start_timer_on_battle_start, start_listening=start_listening, ping_interval=ping_interval, ping_timeout=ping_timeout, team=team)

        self.q_table = {}
        self.epsilon = 0.5
        self.gamma = 0.95
        self.alpha = 0.1
        self.last_state = None
        self.last_action = None
        self.current_state = None
        self.last_hp = 100

    def embed_moves(self, battle: AbstractBattle): # 2 to 13
        embedding = []
        for move in battle.available_moves:
            embedding += [move.base_power, move.type, move.category]

        return embedding

    def list_to_tuple(self, embedding):
        return tuple(embedding)

    def embed_pokemon(self, battle: AbstractBattle): # 0 and 1
        embedding = []

        # embedding += battle.active_pokemon.base_stats
        embedding += battle.active_pokemon.current_hp
        embedding += battle.active_pokemon.type

        return embedding

    def embed_battle(self, battle: AbstractBattle):
        """Return a list containing info about the game state"""
        embedding = (battle.opponent_active_pokemon.base_species,)
        # embedding += self.embed_pokemon(battle)
        # embedding += self.embed_moves(battle)
        # embedding += [battle.opponent_active_pokemon]

        return embedding

    def choose_move(self, battle):
        # print(self.q_table)
        encoding = self.embed_battle(battle)

        if battle.turn == 1:
            self.last_hp = 100

        self.last_state = self.current_state
        self.current_state = encoding
        # print(self.last_hp, battle.opponent_active_pokemon.current_hp, battle.turn)

        if battle.turn > 1:
            self.update_q_table(battle)

        if encoding in self.q_table:
            if random.random() < self.epsilon:
                best_move = np.argmax(self.q_table[encoding])

                self.last_action = best_move
                return self.select_move(best_move, battle)

        random_move = random.randint(0, 3)
        self.last_action = random_move

        return self.select_move(random_move, battle)

    def select_move(self, move, battle):
        return self.create_order(battle.available_moves[move])

    def change_epsilon(self, new):
        self.epsilon = new

    def new_q_table(self, table):
        self.q_table = table

    def calc_reward(self, battle):
        score = 0

        if battle.active_pokemon.fainted:
            score -= 50

        if battle.active_pokemon.fainted:
            score += 50

        hp_diff = (
            self.last_hp - battle.opponent_active_pokemon.current_hp
        )
        self.last_hp = battle.opponent_active_pokemon.current_hp

        return score + hp_diff

    def update_q_table(self, battle):
        if self.last_state not in self.q_table:
            self.q_table[self.last_state] = [0] * 4

        q_old = self.q_table[self.last_state][self.last_action]
        temporal_difference = self.calc_reward(battle) + (self.gamma * max(self.q_table[self.current_state])) - q_old
        q_update = q_old + (self.alpha * temporal_difference)
        # Now put that into our q-table
        self.q_table[self.last_state][self.last_action] = q_update

In [3]:
team_1 = """
Goodra (M) @ Assault Vest
Ability: Sap Sipper
EVs: 248 HP / 252 SpA / 8 Spe
Modest Nature
IVs: 0 Atk
- Dragon Pulse
- Flamethrower
- Sludge Wave
- Thunderbolt
"""

In [4]:
class RandomTeamFromPool(Teambuilder):
    def __init__(self, teams):
        self.packed_teams = []

        for team in teams:
            parsed_team = self.parse_showdown_team(team)
            packed_team = self.join_team(parsed_team)
            self.packed_teams.append(packed_team)

    def yield_team(self):
        return np.random.choice(self.packed_teams)

In [5]:
team_4 = """
Goodra (M) @ Assault Vest
Ability: Sap Sipper
Tera Type: Steel
EVs: 248 HP / 252 SpA / 8 Spe
Modest Nature
IVs: 0 Atk
- Dragon Pulse
- Flamethrower
- Sludge Wave
- Thunderbolt
"""
team_2 = """
Sylveon (M) @ Leftovers
Ability: Pixilate
EVs: 248 HP / 244 Def / 16 SpD
Calm Nature
IVs: 0 Atk
- Hyper Voice
- Quick Attack
- Protect
- Wish
"""
team_3 = """
Blastoise  
Ability: Torrent  
Tera Type: Water  
EVs: 252 HP / 252 SpA / 4 SpD  
Modest Nature  
IVs: 0 Atk  
- Ice Beam  
- Hydro Pump  
- Aura Sphere  
- Surf
"""

teams = [team_4, team_2, team_3]
custom_builder = RandomTeamFromPool(teams)

In [6]:
train = QAgent(team=team_1, battle_format="gen9ou")
p2 = RandomPlayer(battle_format="gen9ou", team=custom_builder)



In [7]:
for i in range(1000):
    await train.battle_against(p2, n_battles=1)

In [8]:
test = QAgent(team=team_1, battle_format="gen9ou")
test.new_q_table(train.q_table.copy())
test.change_epsilon(1)

In [9]:
for i in range(800):
    await test.battle_against(p2, n_battles=1)

In [10]:
test.n_won_battles, test.n_finished_battles

(695, 800)

In [11]:
test.q_table

{('blastoise',): [762.3854991692689,
  742.6792996519973,
  745.4965384955019,
  1015.2398853218008],
 ('goodra',): [417.7780578741145,
  441.86033945551515,
  417.7083596017996,
  419.00303411192436],
 ('sylveon',): [319.09990505528776,
  330.6946594154646,
  390.1186002961094,
  325.2066795387601]}