In [None]:
import numpy as np
import random
from gym.spaces import Box
from poke_env.player import Gen9EnvSinglePlayer
from poke_env.environment.abstract_battle import AbstractBattle
from poke_env.player.env_player import ObservationType

In [None]:
class QPlayer(Gen9EnvSinglePlayer):
    """Tabular Q-learning agent layered on top of ``Gen9EnvSinglePlayer``.

    Battle states are embedded as a 10-element float vector (see
    ``embed_battle``); the vector, converted to a tuple, is the key of
    ``q_table``. Each entry stores one Q-value per move slot.
    """

    # Number of move slots encoded in the embedding and tracked per Q-table row.
    _NUM_MOVE_SLOTS = 4
    # Exploration schedule applied by ``battle_end`` after every battle.
    _EPSILON_FLOOR = 0.01
    _EPSILON_DECAY = 0.995

    def __init__(self, battle_format="gen9ou", epsilon=1.0, alpha=0.1, gamma=0.9):
        """Create the agent.

        Args:
            battle_format: Format string forwarded to the base class.
            epsilon: Initial exploration rate (probability of a random action).
            alpha: Learning rate of the Q-value update.
            gamma: Discount factor applied to the best next-state Q-value.
        """
        super().__init__(battle_format=battle_format)
        self.epsilon = epsilon  # Exploration rate
        self.alpha = alpha  # Learning rate
        self.gamma = gamma  # Discount factor
        self.q_table = {}  # Maps state tuple -> array of per-action Q-values

    def calc_reward(self, last_battle, current_battle) -> float:
        """Return the reward for ``current_battle``.

        Delegates to the inherited ``reward_computing_helper``;
        ``last_battle`` is accepted but not used here.
        """
        return self.reward_computing_helper(
            current_battle, fainted_value=2.0, hp_value=1.0, victory_value=30.0
        )

    def embed_battle(self, battle: AbstractBattle) -> ObservationType:
        """Convert a battle into a 10-element ``float32`` feature vector.

        Layout:
            * 4 values: base power / 100 of each available move
              (-1 for an empty slot or a move without base power).
            * 4 values: damage multiplier of each available move's type
              against the opponent's active pokemon (1 by default).
            * 2 values: fraction (out of 6) of fainted pokemon in our
              team and in the opponent's team.
        """
        moves_base_power = -np.ones(self._NUM_MOVE_SLOTS)
        moves_dmg_multiplier = np.ones(self._NUM_MOVE_SLOTS)
        for i, move in enumerate(battle.available_moves):
            moves_base_power[i] = move.base_power / 100 if move.base_power else -1
            if move.type:
                moves_dmg_multiplier[i] = move.type.damage_multiplier(
                    battle.opponent_active_pokemon.type_1,
                    battle.opponent_active_pokemon.type_2,
                )

        fainted_mon_team = sum(1 for mon in battle.team.values() if mon.fainted) / 6
        fainted_mon_opponent = (
            sum(1 for mon in battle.opponent_team.values() if mon.fainted) / 6
        )

        final_vector = np.concatenate(
            [moves_base_power, moves_dmg_multiplier, [fainted_mon_team, fainted_mon_opponent]]
        )
        return np.float32(final_vector)

    def describe_embedding(self) -> Box:
        """Return the ``Box`` bounding the vectors produced by ``embed_battle``.

        Bounds follow the embedding layout: base powers in [-1, 3],
        damage multipliers in [0, 4], fainted fractions in [0, 1].
        """
        low = np.array([-1, -1, -1, -1, 0, 0, 0, 0, 0, 0], dtype=np.float32)
        high = np.array([3, 3, 3, 3, 4, 4, 4, 4, 1, 1], dtype=np.float32)
        return Box(low, high, dtype=np.float32)

    def _q_row(self, state):
        """Return the Q-value array for ``state``, creating a zero row if unseen."""
        key = tuple(state)  # Arrays are not hashable; tuples are
        if key not in self.q_table:
            self.q_table[key] = np.zeros(self._NUM_MOVE_SLOTS)
        return self.q_table[key]

    def select_action(self, battle):
        """Pick an action index for ``battle`` with an epsilon-greedy policy.

        With probability ``self.epsilon`` a uniformly random index in
        ``range(4)`` is returned; otherwise the index with the highest
        Q-value for the embedded state.
        """
        q_values = self._q_row(self.embed_battle(battle))

        if random.uniform(0, 1) < self.epsilon:
            return random.choice(range(self._NUM_MOVE_SLOTS))  # Explore
        return int(np.argmax(q_values))  # Exploit

    def update_q_values(self, state, action, reward, next_state, done=False):
        """Apply one Q-learning update for a single transition.

        Args:
            state: Embedding of the state the action was taken in.
            action: Index of the action taken.
            reward: Reward observed for the transition.
            next_state: Embedding of the resulting state.
            done: When true, the transition ended the battle, so no
                future value is bootstrapped from ``next_state``.
        """
        # Both rows are created on demand so an unseen ``state`` cannot
        # raise KeyError.
        q_values = self._q_row(state)
        next_q_values = self._q_row(next_state)

        future_value = 0.0 if done else self.gamma * np.max(next_q_values)
        temporal_difference = reward + future_value - q_values[action]
        q_values[action] += self.alpha * temporal_difference

    def battle_end(self, battle):
        """Decay the exploration rate once a battle is over.

        ``battle`` is accepted for the caller's convenience and is not used.
        """
        self.epsilon = max(self._EPSILON_FLOOR, self.epsilon * self._EPSILON_DECAY)

    def reset_battle(self):
        """Start a new battle by delegating to the inherited ``reset``.

        NOTE(review): ``train`` treats the return value as a battle object
        (it reads ``.finished`` and embeds it) -- confirm ``reset`` returns one.
        """
        return self.reset()  # Reset the battle environment

    def step(self, action):
        """Execute ``action`` by delegating to ``self.act``.

        NOTE(review): this shadows any ``step`` inherited from the base
        class, and ``train`` expects a battle object back -- confirm both
        against the base class API.
        """
        return self.act(action)  # Execute action in battle

    def train(self, num_episodes=1000):
        """Run ``num_episodes`` battles, updating the Q-table after every action."""
        for _ in range(num_episodes):
            battle = self.reset_battle()  # Start a new battle
            state = self.embed_battle(battle)  # Get initial state

            while not battle.finished:
                action = self.select_action(battle)

                next_battle = self.step(action)  # Perform action in battle
                next_state = self.embed_battle(next_battle)
                reward = self.calc_reward(battle, next_battle)

                self.update_q_values(
                    state, action, reward, next_state, done=next_battle.finished
                )

                # Advance BOTH the battle and its embedding so the next
                # action is chosen for the state that is actually updated.
                battle, state = next_battle, next_state

            self.battle_end(battle)  # Reduce exploration rate after each battle
