In [1]:
import asyncio

In [14]:
import numpy as np
import torch

In [8]:
from gym.spaces import Space, Box

In [2]:
from poke_env.environment import AbstractBattle
from poke_env.player import Player
from poke_env.player import Gen8EnvSinglePlayer

In [16]:
battle_format = 'gen8randombattle'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
class MaxDamagePlayer(Player):
    """ Adapted from poke-env documentation.

    References:
    [1] Creating a simple max damage player.
        https://poke-env.readthedocs.io/en/stable/examples/max_damage_player.html
    """
    
    def choose_move(self, battle: AbstractBattle):
        # If the player can attack, it will
        if battle.available_moves:
            # Finds the best move among available ones
            best_move = max(battle.available_moves, key=lambda move: move.base_power)
            return self.create_order(best_move)

        # If no attack is available, a random switch will be made
        else:
            return self.choose_random_move(battle)

In [9]:
class SimpleRLEnv(Gen8EnvSinglePlayer):
    """ Adapted from poke-env documentation.
    
    References:
    [1] Reinforcement learning with the OpenAI Gym wrapper.
        https://poke-env.readthedocs.io/en/stable/examples/rl_with_open_ai_gym_wrapper.html
    [2] Type Chart causing a KeyError only containing a type as the error message.
        https://github.com/hsahovic/poke-env/issues/484
    """

    def calc_reward(self, last_battle, current_battle) -> float:
        return self.reward_computing_helper(
            current_battle, fainted_value=2.0, hp_value=1.0, victory_value=30.0
        )

    def embed_battle(self, battle: AbstractBattle) -> np.ndarray:
        # -1 indicates that the move does not have a base power
        # or is not available
        moves_base_power = -np.ones(4)
        moves_dmg_multiplier = np.ones(4)
        for i, move in enumerate(battle.available_moves):
            # print(move)
            moves_base_power[i] = (
                move.base_power / 100
            )  # Simple rescaling to facilitate learning
            if move.type:
                moves_dmg_multiplier[i] = battle.opponent_active_pokemon.damage_multiplier(move)

        # We count how many pokemons have fainted in each team
        fainted_mon_team = len([mon for mon in battle.team.values() if mon.fainted]) / 6
        fainted_mon_opponent = (
            len([mon for mon in battle.opponent_team.values() if mon.fainted]) / 6
        )

        # Final vector with 10 components
        final_vector = np.concatenate(
            [
                moves_base_power,
                moves_dmg_multiplier,
                [fainted_mon_team, fainted_mon_opponent],
            ]
        )
        return np.float32(final_vector)

    def describe_embedding(self) -> Space:
        low = [-1, -1, -1, -1, 0, 0, 0, 0, 0, 0]
        high = [3, 3, 3, 3, 4, 4, 4, 4, 1, 1]
        return Box(
            np.array(low, dtype=np.float32),
            np.array(high, dtype=np.float32),
            dtype=np.float32,
        )
    

In [10]:
opponent = MaxDamagePlayer(battle_format=battle_format)

In [11]:
simple_env = SimpleRLEnv(
        battle_format=battle_format,
        opponent=opponent,
        start_challenging=True,
    )

In [80]:
simple_env.action_space.sample()

11

In [35]:
action = torch.tensor([[simple_env.action_space.sample()]], device=device, dtype=torch.long)
print(action)

tensor([[18]])


In [36]:
action.item()

18

In [37]:
observation, reward, terminated, truncated, _ = simple_env.step(action.item())

In [38]:
observation

array([1.1, 0.9, 1. , 0. , 1. , 0.5, 2. , 0.5, 0. , 0. ], dtype=float32)

In [39]:
reward

-0.375

In [40]:
terminated

False

In [41]:
truncated

False

In [42]:
# step thru with random actions
def step():
    simple_env.action_space.sample()
    action = torch.tensor([[simple_env.action_space.sample()]], device=device, dtype=torch.long)
    observation, reward, terminated, truncated, _ = simple_env.step(action.item())
    print(reward)

In [78]:
step()

RuntimeError: Battle is already finished, call reset

In [81]:
simple_env.reset()

(array([1.3 , 0.75, 0.  , 0.  , 0.25, 0.5 , 0.25, 0.5 , 0.  , 0.  ],
       dtype=float32),
 {})

In [82]:
step()

-0.1215708812260532
