In [None]:
# -*- coding: utf-8 -*-
from poke_env.player.env_player import Gen8EnvSinglePlayer

class SimpleRLPlayer(Gen8EnvSinglePlayer):
    """Environment player that encodes a battle as a 10-value vector and scores it.

    Observation layout produced by ``embed_battle``:

    * indices 0-3: base power of each available move, divided by 100
      (``-1`` when the slot has no available move);
    * indices 4-7: damage multiplier of each available move against the
      opponent's active pokemon (``1`` when unknown / no move);
    * index 8: fraction (out of 6) of the player's pokemons that have fainted;
    * index 9: fraction (out of 6) of the opponent's pokemons that have fainted.
    """

    def embed_battle(self, battle):
        """Build the observation vector for ``battle``.

        Args:
            battle: battle object exposing ``available_moves``,
                ``opponent_active_pokemon``, ``team`` and ``opponent_team``.

        Returns:
            A 1-D numpy array with 10 components (see the class docstring).
        """
        # Slots without an available move keep their defaults: -1 for the base
        # power ("no base power / not available") and 1 for the multiplier.
        base_powers = np.full(4, -1.0)
        type_multipliers = np.full(4, 1.0)

        for slot, move in enumerate(battle.available_moves):
            # Dividing by 100 is a simple rescaling to facilitate learning.
            base_powers[slot] = move.base_power / 100
            if not move.type:
                continue
            defender = battle.opponent_active_pokemon
            type_multipliers[slot] = move.type.damage_multiplier(
                defender.type_1,
                defender.type_2,
            )

        # NOTE: these are the fractions of pokemons that HAVE fainted (team
        # size is taken to be 6), not the fractions that are still standing.
        fainted_fraction_team = (
            sum(1 for mon in battle.team.values() if mon.fainted) / 6
        )
        fainted_fraction_opponent = (
            sum(1 for mon in battle.opponent_team.values() if mon.fainted) / 6
        )

        # Final vector with 10 components: 4 + 4 + 2.
        return np.concatenate(
            (
                base_powers,
                type_multipliers,
                [fainted_fraction_team, fainted_fraction_opponent],
            )
        )

    def compute_reward(self, battle) -> float:
        """Return the reward for ``battle`` as computed by ``reward_computing_helper``.

        The weights passed to the helper are fixed: 2 per fainted pokemon,
        1 for HP, 0.2 for status and 30 for victory.
        """
        reward_weights = dict(
            fainted_value=2,
            hp_value=1,
            status_value=0.2,
            victory_value=30,
        )
        return self.reward_computing_helper(battle, **reward_weights)

In [None]:
import abc
import asyncio

import numpy as np
from poke_env.utils import to_id_str
from tf_agents.environments import py_environment
from tf_agents.specs import array_spec, BoundedArraySpec
from tf_agents.trajectories import time_step as ts
class PokeEnv(py_environment.PyEnvironment):
    """Expose a poke-env environment player through the TF-Agents ``PyEnvironment`` API.

    The wrapped ``player`` must provide a sized ``action_space`` plus
    gym-style ``reset()`` and ``step(action)`` methods; this adapter converts
    their results into TF-Agents ``TimeStep`` objects.
    """

    def __init__(self, player, opponent=None) -> None:
        """Store the players and build the action / observation specs.

        Args:
            player: environment player to drive. Must expose ``action_space``
                (with a length), ``reset()`` and ``step(action)``.
            opponent: player to battle against. It is only stored on the
                instance; optional so that ``PokeEnv(player)`` is valid.
        """
        # Initialise the PyEnvironment base class (it keeps the current
        # time step used by the public reset()/step() wrappers).
        super().__init__()
        self.player = player
        self.opponent = opponent

        # Actions are integer indices into the player's action space.
        self._action_spec = BoundedArraySpec(
            shape=(),
            dtype=np.int32,
            minimum=0,
            maximum=len(self.player.action_space) - 1,
            name='action')

        # 10 components, matching the vector built by
        # SimpleRLPlayer.embed_battle.
        self._observation_spec = BoundedArraySpec(
            shape=(10,),
            dtype=np.float32,
            name='observation')

    @staticmethod
    async def launch_battles(player, opponent):
        """Run a single challenge between ``player`` and ``opponent``.

        ``player`` sends one challenge (after waiting on
        ``opponent.logged_in``) while ``opponent`` accepts one; both
        coroutines are awaited together.
        """
        await asyncio.gather(
            player.send_challenges(
                opponent=to_id_str(opponent.username),
                n_challenges=1,
                to_wait=opponent.logged_in,
            ),
            opponent.accept_challenges(
                opponent=to_id_str(player.username), n_challenges=1
            ),
        )

    def action_spec(self):
        """Return the spec describing valid actions."""
        return self._action_spec

    def observation_spec(self):
        """Return the spec describing observations."""
        return self._observation_spec

    def _reset(self):
        """Start a new episode and return its first ``TimeStep``."""
        observation = self.player.reset()
        # Cast to the dtype declared in the observation spec.
        return ts.restart(np.asarray(observation, dtype=np.float32))

    def _step(self, action):
        """Apply ``action`` and return the resulting ``TimeStep``.

        Assumes ``player.step`` returns the gym-style 4-tuple
        ``(observation, reward, done, info)`` -- TODO confirm against the
        installed poke-env version.
        """
        observation, reward, done, _info = self.player.step(int(action))
        observation = np.asarray(observation, dtype=np.float32)
        if done:
            return ts.termination(observation, reward)
        return ts.transition(observation, reward)
   

In [None]:
from tf_agents.environments import validate_py_environment
from poke_env.player.random_player import RandomPlayer

player = SimpleRLPlayer(battle_format="gen8randombattle")
opponent = RandomPlayer(battle_format="gen8randombattle")

# PokeEnv takes both the learning player and its opponent.
environment = PokeEnv(player, opponent)


def run_validation(_player, environment, episodes=5):
    """Check that ``environment`` follows its declared TF-Agents specs.

    Meant to be used as the ``env_algorithm`` callback of
    ``player.play_against``: the first positional argument is the player
    handed over by poke-env and is unused here.
    """
    validate_py_environment(environment, episodes=episodes)


# The environment can only be reset/stepped while battles are running, so the
# validation is executed inside play_against rather than after it returns.
# NOTE(review): assumes the poke-env signature
# play_against(env_algorithm, opponent, env_algorithm_kwargs) -- confirm
# against the installed version.
player.play_against(
    env_algorithm=run_validation,
    opponent=opponent,
    env_algorithm_kwargs={"environment": environment},
)
